diff --git a/config.go b/config.go index fc5ec7c..ac17593 100644 --- a/config.go +++ b/config.go @@ -2,10 +2,10 @@ package main // Config for the running of the commands type Config struct { - DriverName string `toml:"driver_name"` - PkgName string `toml:"pkg_name"` - OutFolder string `toml:"out_folder"` - TableName string `toml:"table_name"` + DriverName string `toml:"driver_name"` + PkgName string `toml:"pkg_name"` + OutFolder string `toml:"out_folder"` + TableNames []string `toml:"table_names"` Postgres PostgresConfig `toml:"postgres"` } diff --git a/dbdrivers/interface.go b/dbdrivers/interface.go index 50f9653..d406b60 100644 --- a/dbdrivers/interface.go +++ b/dbdrivers/interface.go @@ -1,6 +1,6 @@ package dbdrivers -import "fmt" +import "github.com/pkg/errors" // Interface for a database driver. Functionality required to support a specific // database type (eg, MySQL, Postgres etc.) @@ -61,8 +61,7 @@ func Tables(db Interface, names ...string) ([]Table, error) { var err error if len(names) == 0 { if names, err = db.TableNames(); err != nil { - fmt.Println("Unable to get table names.") - return nil, err + return nil, errors.Wrap(err, "unable to get table names") } } @@ -71,8 +70,7 @@ func Tables(db Interface, names ...string) ([]Table, error) { t := Table{Name: name} if t.Columns, err = db.Columns(name); err != nil { - fmt.Println("Unable to get columns.") - return nil, err + return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name) } for i, c := range t.Columns { @@ -80,13 +78,11 @@ func Tables(db Interface, names ...string) ([]Table, error) { } if t.PKey, err = db.PrimaryKeyInfo(name); err != nil { - fmt.Println("Unable to get primary key info.") - return nil, err + return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name) } if t.FKeys, err = db.ForeignKeyInfo(name); err != nil { - fmt.Println("Unable to get foreign key info.") - return nil, err + return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name) } setIsJoinTable(&t) diff --git a/main.go b/main.go index 622d028..ebd4da5 100644 --- a/main.go +++ b/main.go @@ -2,12 +2,13 @@ package main import ( - "errors" "fmt" "os" "path/filepath" + "strings" "github.com/davecgh/go-spew/spew" + "github.com/pkg/errors" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -56,31 +57,42 @@ func main() { } // Set up the cobra root command flags - rootCmd.PersistentFlags().StringSliceP("table", "t", nil, "Tables to generate models for, all tables if empty") + rootCmd.PersistentFlags().StringSliceP("tables", "t", nil, "Tables to generate models for, all tables if empty") rootCmd.PersistentFlags().StringP("output", "o", "output", "The name of the folder to output to") rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package") viper.BindPFlags(rootCmd.PersistentFlags()) if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(-1) + fmt.Printf("\n%+v\n", err) + os.Exit(1) } } func preRun(cmd *cobra.Command, args []string) error { + var err error + if len(args) == 0 { - _ = cmd.Help() - fmt.Println("\nmust provide a driver") - os.Exit(1) + return errors.New("must provide a driver name") } - cmdConfig = new(Config) + cmdConfig = &Config{ + DriverName: args[0], + OutFolder: viper.GetString("output"), + PkgName: viper.GetString("pkgname"), + } - cmdConfig.DriverName = args[0] - cmdConfig.TableName = viper.GetString("table") - cmdConfig.OutFolder = viper.GetString("output") - cmdConfig.PkgName = viper.GetString("pkgname") + // BUG: https://github.com/spf13/viper/issues/200 + // Look up the value of TableNames directly from PFlags in Cobra if we + // detect a malformed value coming out of viper. + // Once the bug is fixed we'll be able to move this into the init above + cmdConfig.TableNames = viper.GetStringSlice("tables") + if len(cmdConfig.TableNames) == 1 && strings.HasPrefix(cmdConfig.TableNames[0], "[") { + cmdConfig.TableNames, err = cmd.PersistentFlags().GetStringSlice("tables") + if err != nil { + return err + } + } if len(cmdConfig.DriverName) == 0 { return errors.New("Must supply a driver flag.") @@ -101,7 +113,6 @@ func preRun(cmd *cobra.Command, args []string) error { spew.Dump(cmdConfig) - var err error cmdState, err = New(cmdConfig) return err } diff --git a/sqlboiler.go b/sqlboiler.go index 2faf34f..1347962 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -3,13 +3,13 @@ package main import ( - "errors" "fmt" "os" "strings" "text/template" "github.com/nullbio/sqlboiler/dbdrivers" + "github.com/pkg/errors" ) const ( @@ -50,22 +50,22 @@ func New(config *Config) (*State, error) { // Connect to the driver database if err = s.Driver.Open(); err != nil { - return nil, fmt.Errorf("Unable to connect to the database: %s", err) + return nil, errors.Wrap(err, "unable to connect to the database") } - err = s.initTables(config.TableName) + err = s.initTables(config.TableNames) if err != nil { - return nil, fmt.Errorf("Unable to initialize tables: %s", err) + return nil, errors.Wrap(err, "unable to initialize tables") } err = s.initOutFolder() if err != nil { - return nil, fmt.Errorf("Unable to initialize the output folder: %s", err) + return nil, errors.Wrap(err, "unable to initialize the output folder") } err = s.initTemplates() if err != nil { - return nil, fmt.Errorf("Unable to initialize templates: %s", err) + return nil, errors.Wrap(err, "unable to initialize templates") } return s, nil @@ -81,16 +81,16 @@ func (s *State) Run(includeTests bool) error { } if err := generateSingletonOutput(s, singletonData); err != nil { - return fmt.Errorf("Unable to generate singleton template output: %s", err) + return errors.Wrap(err, "singleton template output") } if includeTests { if err := generateTestMainOutput(s, singletonData); err != nil { - return fmt.Errorf("Unable to generate TestMain output: %s", err) + return errors.Wrap(err, "unable to generate TestMain output") } if err := generateSingletonTestOutput(s, singletonData); err != nil { - return fmt.Errorf("Unable to generate singleton test template output: %s", err) + return errors.Wrap(err, "unable to generate singleton test template output") } } @@ -107,13 +107,13 @@ func (s *State) Run(includeTests bool) error { // Generate the regular templates if err := generateOutput(s, data); err != nil { - return fmt.Errorf("Unable to generate output: %s", err) + return errors.Wrap(err, "unable to generate output") } // Generate the test templates if includeTests { if err := generateTestOutput(s, data); err != nil { - return fmt.Errorf("Unable to generate test output: %s", err) + return errors.Wrap(err, "unable to generate test output") } } } @@ -185,24 +185,15 @@ func (s *State) initDriver(driverName string) error { // if one is provided. If no flag is provided, it will attempt to connect to the // database to retrieve all "public" table names, and build a slice out of that // result. -func (s *State) initTables(tableName string) error { - var tableNames []string - - if len(tableName) != 0 { - tableNames = strings.Split(tableName, ",") - for i, name := range tableNames { - tableNames[i] = strings.TrimSpace(name) - } - } - +func (s *State) initTables(tableNames []string) error { var err error s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...) if err != nil { - return fmt.Errorf("Unable to get all table names: %s", err) + return errors.Wrap(err, "unable to fetch table data") } if len(s.Tables) == 0 { - return errors.New("No tables found in database, migrate some tables first") + return errors.New("no tables found in database") } if err := checkPKeys(s.Tables); err != nil { @@ -214,11 +205,7 @@ func (s *State) initTables(tableName string) error { // initOutFolder creates the folder that will hold the generated output. func (s *State) initOutFolder() error { - if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil { - return fmt.Errorf("Unable to make output folder: %s", err) - } - - return nil + return os.MkdirAll(s.Config.OutFolder, os.ModePerm) } // checkPKeys ensures every table has a primary key column @@ -231,7 +218,7 @@ func checkPKeys(tables []dbdrivers.Table) error { } if len(missingPkey) != 0 { - return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", ")) + return fmt.Errorf("primary key missing in tables (%s)", strings.Join(missingPkey, ", ")) } return nil