Fix some problems with viper setup.
- Fix error reporting throughout the executable side of the project.
This commit is contained in:
parent
e0f461014b
commit
57e20dfd72
4 changed files with 49 additions and 55 deletions
|
@ -2,10 +2,10 @@ package main
|
|||
|
||||
// Config for the running of the commands
|
||||
type Config struct {
|
||||
DriverName string `toml:"driver_name"`
|
||||
PkgName string `toml:"pkg_name"`
|
||||
OutFolder string `toml:"out_folder"`
|
||||
TableName string `toml:"table_name"`
|
||||
DriverName string `toml:"driver_name"`
|
||||
PkgName string `toml:"pkg_name"`
|
||||
OutFolder string `toml:"out_folder"`
|
||||
TableNames []string `toml:"table_names"`
|
||||
|
||||
Postgres PostgresConfig `toml:"postgres"`
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package dbdrivers
|
||||
|
||||
import "fmt"
|
||||
import "github.com/pkg/errors"
|
||||
|
||||
// Interface for a database driver. Functionality required to support a specific
|
||||
// database type (eg, MySQL, Postgres etc.)
|
||||
|
@ -61,8 +61,7 @@ func Tables(db Interface, names ...string) ([]Table, error) {
|
|||
var err error
|
||||
if len(names) == 0 {
|
||||
if names, err = db.TableNames(); err != nil {
|
||||
fmt.Println("Unable to get table names.")
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, "unable to get table names")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,8 +70,7 @@ func Tables(db Interface, names ...string) ([]Table, error) {
|
|||
t := Table{Name: name}
|
||||
|
||||
if t.Columns, err = db.Columns(name); err != nil {
|
||||
fmt.Println("Unable to get columns.")
|
||||
return nil, err
|
||||
return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name)
|
||||
}
|
||||
|
||||
for i, c := range t.Columns {
|
||||
|
@ -80,13 +78,11 @@ func Tables(db Interface, names ...string) ([]Table, error) {
|
|||
}
|
||||
|
||||
if t.PKey, err = db.PrimaryKeyInfo(name); err != nil {
|
||||
fmt.Println("Unable to get primary key info.")
|
||||
return nil, err
|
||||
return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name)
|
||||
}
|
||||
|
||||
if t.FKeys, err = db.ForeignKeyInfo(name); err != nil {
|
||||
fmt.Println("Unable to get foreign key info.")
|
||||
return nil, err
|
||||
return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name)
|
||||
}
|
||||
|
||||
setIsJoinTable(&t)
|
||||
|
|
37
main.go
37
main.go
|
@ -2,12 +2,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
@ -56,31 +57,42 @@ func main() {
|
|||
}
|
||||
|
||||
// Set up the cobra root command flags
|
||||
rootCmd.PersistentFlags().StringSliceP("table", "t", nil, "Tables to generate models for, all tables if empty")
|
||||
rootCmd.PersistentFlags().StringSliceP("tables", "t", nil, "Tables to generate models for, all tables if empty")
|
||||
rootCmd.PersistentFlags().StringP("output", "o", "output", "The name of the folder to output to")
|
||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
|
||||
|
||||
viper.BindPFlags(rootCmd.PersistentFlags())
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(-1)
|
||||
fmt.Printf("\n%+v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func preRun(cmd *cobra.Command, args []string) error {
|
||||
var err error
|
||||
|
||||
if len(args) == 0 {
|
||||
_ = cmd.Help()
|
||||
fmt.Println("\nmust provide a driver")
|
||||
os.Exit(1)
|
||||
return errors.New("must provide a driver name")
|
||||
}
|
||||
|
||||
cmdConfig = new(Config)
|
||||
cmdConfig = &Config{
|
||||
DriverName: args[0],
|
||||
OutFolder: viper.GetString("output"),
|
||||
PkgName: viper.GetString("pkgname"),
|
||||
}
|
||||
|
||||
cmdConfig.DriverName = args[0]
|
||||
cmdConfig.TableName = viper.GetString("table")
|
||||
cmdConfig.OutFolder = viper.GetString("output")
|
||||
cmdConfig.PkgName = viper.GetString("pkgname")
|
||||
// BUG: https://github.com/spf13/viper/issues/200
|
||||
// Look up the value of TableNames directly from PFlags in Cobra if we
|
||||
// detect a malformed value coming out of viper.
|
||||
// Once the bug is fixed we'll be able to move this into the init above
|
||||
cmdConfig.TableNames = viper.GetStringSlice("tables")
|
||||
if len(cmdConfig.TableNames) == 1 && strings.HasPrefix(cmdConfig.TableNames[0], "[") {
|
||||
cmdConfig.TableNames, err = cmd.PersistentFlags().GetStringSlice("tables")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(cmdConfig.DriverName) == 0 {
|
||||
return errors.New("Must supply a driver flag.")
|
||||
|
@ -101,7 +113,6 @@ func preRun(cmd *cobra.Command, args []string) error {
|
|||
|
||||
spew.Dump(cmdConfig)
|
||||
|
||||
var err error
|
||||
cmdState, err = New(cmdConfig)
|
||||
return err
|
||||
}
|
||||
|
|
45
sqlboiler.go
45
sqlboiler.go
|
@ -3,13 +3,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -50,22 +50,22 @@ func New(config *Config) (*State, error) {
|
|||
|
||||
// Connect to the driver database
|
||||
if err = s.Driver.Open(); err != nil {
|
||||
return nil, fmt.Errorf("Unable to connect to the database: %s", err)
|
||||
return nil, errors.Wrap(err, "unable to connect to the database")
|
||||
}
|
||||
|
||||
err = s.initTables(config.TableName)
|
||||
err = s.initTables(config.TableNames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize tables: %s", err)
|
||||
return nil, errors.Wrap(err, "unable to initialize tables")
|
||||
}
|
||||
|
||||
err = s.initOutFolder()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||
return nil, errors.Wrap(err, "unable to initialize the output folder")
|
||||
}
|
||||
|
||||
err = s.initTemplates()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize templates: %s", err)
|
||||
return nil, errors.Wrap(err, "unable to initialize templates")
|
||||
}
|
||||
|
||||
return s, nil
|
||||
|
@ -81,16 +81,16 @@ func (s *State) Run(includeTests bool) error {
|
|||
}
|
||||
|
||||
if err := generateSingletonOutput(s, singletonData); err != nil {
|
||||
return fmt.Errorf("Unable to generate singleton template output: %s", err)
|
||||
return errors.Wrap(err, "singleton template output")
|
||||
}
|
||||
|
||||
if includeTests {
|
||||
if err := generateTestMainOutput(s, singletonData); err != nil {
|
||||
return fmt.Errorf("Unable to generate TestMain output: %s", err)
|
||||
return errors.Wrap(err, "unable to generate TestMain output")
|
||||
}
|
||||
|
||||
if err := generateSingletonTestOutput(s, singletonData); err != nil {
|
||||
return fmt.Errorf("Unable to generate singleton test template output: %s", err)
|
||||
return errors.Wrap(err, "unable to generate singleton test template output")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,13 +107,13 @@ func (s *State) Run(includeTests bool) error {
|
|||
|
||||
// Generate the regular templates
|
||||
if err := generateOutput(s, data); err != nil {
|
||||
return fmt.Errorf("Unable to generate output: %s", err)
|
||||
return errors.Wrap(err, "unable to generate output")
|
||||
}
|
||||
|
||||
// Generate the test templates
|
||||
if includeTests {
|
||||
if err := generateTestOutput(s, data); err != nil {
|
||||
return fmt.Errorf("Unable to generate test output: %s", err)
|
||||
return errors.Wrap(err, "unable to generate test output")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -185,24 +185,15 @@ func (s *State) initDriver(driverName string) error {
|
|||
// if one is provided. If no flag is provided, it will attempt to connect to the
|
||||
// database to retrieve all "public" table names, and build a slice out of that
|
||||
// result.
|
||||
func (s *State) initTables(tableName string) error {
|
||||
var tableNames []string
|
||||
|
||||
if len(tableName) != 0 {
|
||||
tableNames = strings.Split(tableName, ",")
|
||||
for i, name := range tableNames {
|
||||
tableNames[i] = strings.TrimSpace(name)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) initTables(tableNames []string) error {
|
||||
var err error
|
||||
s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to get all table names: %s", err)
|
||||
return errors.Wrap(err, "unable to fetch table data")
|
||||
}
|
||||
|
||||
if len(s.Tables) == 0 {
|
||||
return errors.New("No tables found in database, migrate some tables first")
|
||||
return errors.New("no tables found in database")
|
||||
}
|
||||
|
||||
if err := checkPKeys(s.Tables); err != nil {
|
||||
|
@ -214,11 +205,7 @@ func (s *State) initTables(tableName string) error {
|
|||
|
||||
// initOutFolder creates the folder that will hold the generated output.
|
||||
func (s *State) initOutFolder() error {
|
||||
if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("Unable to make output folder: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return os.MkdirAll(s.Config.OutFolder, os.ModePerm)
|
||||
}
|
||||
|
||||
// checkPKeys ensures every table has a primary key column
|
||||
|
@ -231,7 +218,7 @@ func checkPKeys(tables []dbdrivers.Table) error {
|
|||
}
|
||||
|
||||
if len(missingPkey) != 0 {
|
||||
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
|
||||
return fmt.Errorf("primary key missing in tables (%s)", strings.Join(missingPkey, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Reference in a new issue