Fix some problems with viper setup.

- Fix error reporting throughout the executable side of the project.
This commit is contained in:
Aaron L 2016-06-12 15:34:57 -07:00
parent e0f461014b
commit 57e20dfd72
4 changed files with 49 additions and 55 deletions

View file

@ -2,10 +2,10 @@ package main
// Config for the running of the commands // Config for the running of the commands
type Config struct { type Config struct {
DriverName string `toml:"driver_name"` DriverName string `toml:"driver_name"`
PkgName string `toml:"pkg_name"` PkgName string `toml:"pkg_name"`
OutFolder string `toml:"out_folder"` OutFolder string `toml:"out_folder"`
TableName string `toml:"table_name"` TableNames []string `toml:"table_names"`
Postgres PostgresConfig `toml:"postgres"` Postgres PostgresConfig `toml:"postgres"`
} }

View file

@ -1,6 +1,6 @@
package dbdrivers package dbdrivers
import "fmt" import "github.com/pkg/errors"
// Interface for a database driver. Functionality required to support a specific // Interface for a database driver. Functionality required to support a specific
// database type (eg, MySQL, Postgres etc.) // database type (eg, MySQL, Postgres etc.)
@ -61,8 +61,7 @@ func Tables(db Interface, names ...string) ([]Table, error) {
var err error var err error
if len(names) == 0 { if len(names) == 0 {
if names, err = db.TableNames(); err != nil { if names, err = db.TableNames(); err != nil {
fmt.Println("Unable to get table names.") return nil, errors.Wrap(err, "unable to get table names")
return nil, err
} }
} }
@ -71,8 +70,7 @@ func Tables(db Interface, names ...string) ([]Table, error) {
t := Table{Name: name} t := Table{Name: name}
if t.Columns, err = db.Columns(name); err != nil { if t.Columns, err = db.Columns(name); err != nil {
fmt.Println("Unable to get columns.") return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name)
return nil, err
} }
for i, c := range t.Columns { for i, c := range t.Columns {
@ -80,13 +78,11 @@ func Tables(db Interface, names ...string) ([]Table, error) {
} }
if t.PKey, err = db.PrimaryKeyInfo(name); err != nil { if t.PKey, err = db.PrimaryKeyInfo(name); err != nil {
fmt.Println("Unable to get primary key info.") return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name)
return nil, err
} }
if t.FKeys, err = db.ForeignKeyInfo(name); err != nil { if t.FKeys, err = db.ForeignKeyInfo(name); err != nil {
fmt.Println("Unable to get foreign key info.") return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name)
return nil, err
} }
setIsJoinTable(&t) setIsJoinTable(&t)

37
main.go
View file

@ -2,12 +2,13 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/pkg/errors"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -56,31 +57,42 @@ func main() {
} }
// Set up the cobra root command flags // Set up the cobra root command flags
rootCmd.PersistentFlags().StringSliceP("table", "t", nil, "Tables to generate models for, all tables if empty") rootCmd.PersistentFlags().StringSliceP("tables", "t", nil, "Tables to generate models for, all tables if empty")
rootCmd.PersistentFlags().StringP("output", "o", "output", "The name of the folder to output to") rootCmd.PersistentFlags().StringP("output", "o", "output", "The name of the folder to output to")
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package") rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
viper.BindPFlags(rootCmd.PersistentFlags()) viper.BindPFlags(rootCmd.PersistentFlags())
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
fmt.Println(err) fmt.Printf("\n%+v\n", err)
os.Exit(-1) os.Exit(1)
} }
} }
func preRun(cmd *cobra.Command, args []string) error { func preRun(cmd *cobra.Command, args []string) error {
var err error
if len(args) == 0 { if len(args) == 0 {
_ = cmd.Help() return errors.New("must provide a driver name")
fmt.Println("\nmust provide a driver")
os.Exit(1)
} }
cmdConfig = new(Config) cmdConfig = &Config{
DriverName: args[0],
OutFolder: viper.GetString("output"),
PkgName: viper.GetString("pkgname"),
}
cmdConfig.DriverName = args[0] // BUG: https://github.com/spf13/viper/issues/200
cmdConfig.TableName = viper.GetString("table") // Look up the value of TableNames directly from PFlags in Cobra if we
cmdConfig.OutFolder = viper.GetString("output") // detect a malformed value coming out of viper.
cmdConfig.PkgName = viper.GetString("pkgname") // Once the bug is fixed we'll be able to move this into the init above
cmdConfig.TableNames = viper.GetStringSlice("tables")
if len(cmdConfig.TableNames) == 1 && strings.HasPrefix(cmdConfig.TableNames[0], "[") {
cmdConfig.TableNames, err = cmd.PersistentFlags().GetStringSlice("tables")
if err != nil {
return err
}
}
if len(cmdConfig.DriverName) == 0 { if len(cmdConfig.DriverName) == 0 {
return errors.New("Must supply a driver flag.") return errors.New("Must supply a driver flag.")
@ -101,7 +113,6 @@ func preRun(cmd *cobra.Command, args []string) error {
spew.Dump(cmdConfig) spew.Dump(cmdConfig)
var err error
cmdState, err = New(cmdConfig) cmdState, err = New(cmdConfig)
return err return err
} }

View file

@ -3,13 +3,13 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"strings" "strings"
"text/template" "text/template"
"github.com/nullbio/sqlboiler/dbdrivers" "github.com/nullbio/sqlboiler/dbdrivers"
"github.com/pkg/errors"
) )
const ( const (
@ -50,22 +50,22 @@ func New(config *Config) (*State, error) {
// Connect to the driver database // Connect to the driver database
if err = s.Driver.Open(); err != nil { if err = s.Driver.Open(); err != nil {
return nil, fmt.Errorf("Unable to connect to the database: %s", err) return nil, errors.Wrap(err, "unable to connect to the database")
} }
err = s.initTables(config.TableName) err = s.initTables(config.TableNames)
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to initialize tables: %s", err) return nil, errors.Wrap(err, "unable to initialize tables")
} }
err = s.initOutFolder() err = s.initOutFolder()
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to initialize the output folder: %s", err) return nil, errors.Wrap(err, "unable to initialize the output folder")
} }
err = s.initTemplates() err = s.initTemplates()
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to initialize templates: %s", err) return nil, errors.Wrap(err, "unable to initialize templates")
} }
return s, nil return s, nil
@ -81,16 +81,16 @@ func (s *State) Run(includeTests bool) error {
} }
if err := generateSingletonOutput(s, singletonData); err != nil { if err := generateSingletonOutput(s, singletonData); err != nil {
return fmt.Errorf("Unable to generate singleton template output: %s", err) return errors.Wrap(err, "singleton template output")
} }
if includeTests { if includeTests {
if err := generateTestMainOutput(s, singletonData); err != nil { if err := generateTestMainOutput(s, singletonData); err != nil {
return fmt.Errorf("Unable to generate TestMain output: %s", err) return errors.Wrap(err, "unable to generate TestMain output")
} }
if err := generateSingletonTestOutput(s, singletonData); err != nil { if err := generateSingletonTestOutput(s, singletonData); err != nil {
return fmt.Errorf("Unable to generate singleton test template output: %s", err) return errors.Wrap(err, "unable to generate singleton test template output")
} }
} }
@ -107,13 +107,13 @@ func (s *State) Run(includeTests bool) error {
// Generate the regular templates // Generate the regular templates
if err := generateOutput(s, data); err != nil { if err := generateOutput(s, data); err != nil {
return fmt.Errorf("Unable to generate output: %s", err) return errors.Wrap(err, "unable to generate output")
} }
// Generate the test templates // Generate the test templates
if includeTests { if includeTests {
if err := generateTestOutput(s, data); err != nil { if err := generateTestOutput(s, data); err != nil {
return fmt.Errorf("Unable to generate test output: %s", err) return errors.Wrap(err, "unable to generate test output")
} }
} }
} }
@ -185,24 +185,15 @@ func (s *State) initDriver(driverName string) error {
// if one is provided. If no flag is provided, it will attempt to connect to the // if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that // database to retrieve all "public" table names, and build a slice out of that
// result. // result.
func (s *State) initTables(tableName string) error { func (s *State) initTables(tableNames []string) error {
var tableNames []string
if len(tableName) != 0 {
tableNames = strings.Split(tableName, ",")
for i, name := range tableNames {
tableNames[i] = strings.TrimSpace(name)
}
}
var err error var err error
s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...) s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...)
if err != nil { if err != nil {
return fmt.Errorf("Unable to get all table names: %s", err) return errors.Wrap(err, "unable to fetch table data")
} }
if len(s.Tables) == 0 { if len(s.Tables) == 0 {
return errors.New("No tables found in database, migrate some tables first") return errors.New("no tables found in database")
} }
if err := checkPKeys(s.Tables); err != nil { if err := checkPKeys(s.Tables); err != nil {
@ -214,11 +205,7 @@ func (s *State) initTables(tableName string) error {
// initOutFolder creates the folder that will hold the generated output. // initOutFolder creates the folder that will hold the generated output.
func (s *State) initOutFolder() error { func (s *State) initOutFolder() error {
if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil { return os.MkdirAll(s.Config.OutFolder, os.ModePerm)
return fmt.Errorf("Unable to make output folder: %s", err)
}
return nil
} }
// checkPKeys ensures every table has a primary key column // checkPKeys ensures every table has a primary key column
@ -231,7 +218,7 @@ func checkPKeys(tables []dbdrivers.Table) error {
} }
if len(missingPkey) != 0 { if len(missingPkey) != 0 {
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", ")) return fmt.Errorf("primary key missing in tables (%s)", strings.Join(missingPkey, ", "))
} }
return nil return nil