Fix some problems with viper setup.
- Fix error reporting throughout the executable side of the project.
This commit is contained in:
parent
e0f461014b
commit
57e20dfd72
4 changed files with 49 additions and 55 deletions
|
@ -5,7 +5,7 @@ type Config struct {
|
||||||
DriverName string `toml:"driver_name"`
|
DriverName string `toml:"driver_name"`
|
||||||
PkgName string `toml:"pkg_name"`
|
PkgName string `toml:"pkg_name"`
|
||||||
OutFolder string `toml:"out_folder"`
|
OutFolder string `toml:"out_folder"`
|
||||||
TableName string `toml:"table_name"`
|
TableNames []string `toml:"table_names"`
|
||||||
|
|
||||||
Postgres PostgresConfig `toml:"postgres"`
|
Postgres PostgresConfig `toml:"postgres"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
package dbdrivers
|
package dbdrivers
|
||||||
|
|
||||||
import "fmt"
|
import "github.com/pkg/errors"
|
||||||
|
|
||||||
// Interface for a database driver. Functionality required to support a specific
|
// Interface for a database driver. Functionality required to support a specific
|
||||||
// database type (eg, MySQL, Postgres etc.)
|
// database type (eg, MySQL, Postgres etc.)
|
||||||
|
@ -61,8 +61,7 @@ func Tables(db Interface, names ...string) ([]Table, error) {
|
||||||
var err error
|
var err error
|
||||||
if len(names) == 0 {
|
if len(names) == 0 {
|
||||||
if names, err = db.TableNames(); err != nil {
|
if names, err = db.TableNames(); err != nil {
|
||||||
fmt.Println("Unable to get table names.")
|
return nil, errors.Wrap(err, "unable to get table names")
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,8 +70,7 @@ func Tables(db Interface, names ...string) ([]Table, error) {
|
||||||
t := Table{Name: name}
|
t := Table{Name: name}
|
||||||
|
|
||||||
if t.Columns, err = db.Columns(name); err != nil {
|
if t.Columns, err = db.Columns(name); err != nil {
|
||||||
fmt.Println("Unable to get columns.")
|
return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, c := range t.Columns {
|
for i, c := range t.Columns {
|
||||||
|
@ -80,13 +78,11 @@ func Tables(db Interface, names ...string) ([]Table, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.PKey, err = db.PrimaryKeyInfo(name); err != nil {
|
if t.PKey, err = db.PrimaryKeyInfo(name); err != nil {
|
||||||
fmt.Println("Unable to get primary key info.")
|
return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.FKeys, err = db.ForeignKeyInfo(name); err != nil {
|
if t.FKeys, err = db.ForeignKeyInfo(name); err != nil {
|
||||||
fmt.Println("Unable to get foreign key info.")
|
return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
setIsJoinTable(&t)
|
setIsJoinTable(&t)
|
||||||
|
|
37
main.go
37
main.go
|
@ -2,12 +2,13 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
@ -56,31 +57,42 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up the cobra root command flags
|
// Set up the cobra root command flags
|
||||||
rootCmd.PersistentFlags().StringSliceP("table", "t", nil, "Tables to generate models for, all tables if empty")
|
rootCmd.PersistentFlags().StringSliceP("tables", "t", nil, "Tables to generate models for, all tables if empty")
|
||||||
rootCmd.PersistentFlags().StringP("output", "o", "output", "The name of the folder to output to")
|
rootCmd.PersistentFlags().StringP("output", "o", "output", "The name of the folder to output to")
|
||||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
|
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
|
||||||
|
|
||||||
viper.BindPFlags(rootCmd.PersistentFlags())
|
viper.BindPFlags(rootCmd.PersistentFlags())
|
||||||
|
|
||||||
if err := rootCmd.Execute(); err != nil {
|
if err := rootCmd.Execute(); err != nil {
|
||||||
fmt.Println(err)
|
fmt.Printf("\n%+v\n", err)
|
||||||
os.Exit(-1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func preRun(cmd *cobra.Command, args []string) error {
|
func preRun(cmd *cobra.Command, args []string) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
_ = cmd.Help()
|
return errors.New("must provide a driver name")
|
||||||
fmt.Println("\nmust provide a driver")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdConfig = new(Config)
|
cmdConfig = &Config{
|
||||||
|
DriverName: args[0],
|
||||||
|
OutFolder: viper.GetString("output"),
|
||||||
|
PkgName: viper.GetString("pkgname"),
|
||||||
|
}
|
||||||
|
|
||||||
cmdConfig.DriverName = args[0]
|
// BUG: https://github.com/spf13/viper/issues/200
|
||||||
cmdConfig.TableName = viper.GetString("table")
|
// Look up the value of TableNames directly from PFlags in Cobra if we
|
||||||
cmdConfig.OutFolder = viper.GetString("output")
|
// detect a malformed value coming out of viper.
|
||||||
cmdConfig.PkgName = viper.GetString("pkgname")
|
// Once the bug is fixed we'll be able to move this into the init above
|
||||||
|
cmdConfig.TableNames = viper.GetStringSlice("tables")
|
||||||
|
if len(cmdConfig.TableNames) == 1 && strings.HasPrefix(cmdConfig.TableNames[0], "[") {
|
||||||
|
cmdConfig.TableNames, err = cmd.PersistentFlags().GetStringSlice("tables")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(cmdConfig.DriverName) == 0 {
|
if len(cmdConfig.DriverName) == 0 {
|
||||||
return errors.New("Must supply a driver flag.")
|
return errors.New("Must supply a driver flag.")
|
||||||
|
@ -101,7 +113,6 @@ func preRun(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
spew.Dump(cmdConfig)
|
spew.Dump(cmdConfig)
|
||||||
|
|
||||||
var err error
|
|
||||||
cmdState, err = New(cmdConfig)
|
cmdState, err = New(cmdConfig)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
45
sqlboiler.go
45
sqlboiler.go
|
@ -3,13 +3,13 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -50,22 +50,22 @@ func New(config *Config) (*State, error) {
|
||||||
|
|
||||||
// Connect to the driver database
|
// Connect to the driver database
|
||||||
if err = s.Driver.Open(); err != nil {
|
if err = s.Driver.Open(); err != nil {
|
||||||
return nil, fmt.Errorf("Unable to connect to the database: %s", err)
|
return nil, errors.Wrap(err, "unable to connect to the database")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.initTables(config.TableName)
|
err = s.initTables(config.TableNames)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Unable to initialize tables: %s", err)
|
return nil, errors.Wrap(err, "unable to initialize tables")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.initOutFolder()
|
err = s.initOutFolder()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Unable to initialize the output folder: %s", err)
|
return nil, errors.Wrap(err, "unable to initialize the output folder")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.initTemplates()
|
err = s.initTemplates()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Unable to initialize templates: %s", err)
|
return nil, errors.Wrap(err, "unable to initialize templates")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
|
@ -81,16 +81,16 @@ func (s *State) Run(includeTests bool) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := generateSingletonOutput(s, singletonData); err != nil {
|
if err := generateSingletonOutput(s, singletonData); err != nil {
|
||||||
return fmt.Errorf("Unable to generate singleton template output: %s", err)
|
return errors.Wrap(err, "singleton template output")
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeTests {
|
if includeTests {
|
||||||
if err := generateTestMainOutput(s, singletonData); err != nil {
|
if err := generateTestMainOutput(s, singletonData); err != nil {
|
||||||
return fmt.Errorf("Unable to generate TestMain output: %s", err)
|
return errors.Wrap(err, "unable to generate TestMain output")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := generateSingletonTestOutput(s, singletonData); err != nil {
|
if err := generateSingletonTestOutput(s, singletonData); err != nil {
|
||||||
return fmt.Errorf("Unable to generate singleton test template output: %s", err)
|
return errors.Wrap(err, "unable to generate singleton test template output")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,13 +107,13 @@ func (s *State) Run(includeTests bool) error {
|
||||||
|
|
||||||
// Generate the regular templates
|
// Generate the regular templates
|
||||||
if err := generateOutput(s, data); err != nil {
|
if err := generateOutput(s, data); err != nil {
|
||||||
return fmt.Errorf("Unable to generate output: %s", err)
|
return errors.Wrap(err, "unable to generate output")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the test templates
|
// Generate the test templates
|
||||||
if includeTests {
|
if includeTests {
|
||||||
if err := generateTestOutput(s, data); err != nil {
|
if err := generateTestOutput(s, data); err != nil {
|
||||||
return fmt.Errorf("Unable to generate test output: %s", err)
|
return errors.Wrap(err, "unable to generate test output")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -185,24 +185,15 @@ func (s *State) initDriver(driverName string) error {
|
||||||
// if one is provided. If no flag is provided, it will attempt to connect to the
|
// if one is provided. If no flag is provided, it will attempt to connect to the
|
||||||
// database to retrieve all "public" table names, and build a slice out of that
|
// database to retrieve all "public" table names, and build a slice out of that
|
||||||
// result.
|
// result.
|
||||||
func (s *State) initTables(tableName string) error {
|
func (s *State) initTables(tableNames []string) error {
|
||||||
var tableNames []string
|
|
||||||
|
|
||||||
if len(tableName) != 0 {
|
|
||||||
tableNames = strings.Split(tableName, ",")
|
|
||||||
for i, name := range tableNames {
|
|
||||||
tableNames[i] = strings.TrimSpace(name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...)
|
s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Unable to get all table names: %s", err)
|
return errors.Wrap(err, "unable to fetch table data")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(s.Tables) == 0 {
|
if len(s.Tables) == 0 {
|
||||||
return errors.New("No tables found in database, migrate some tables first")
|
return errors.New("no tables found in database")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkPKeys(s.Tables); err != nil {
|
if err := checkPKeys(s.Tables); err != nil {
|
||||||
|
@ -214,11 +205,7 @@ func (s *State) initTables(tableName string) error {
|
||||||
|
|
||||||
// initOutFolder creates the folder that will hold the generated output.
|
// initOutFolder creates the folder that will hold the generated output.
|
||||||
func (s *State) initOutFolder() error {
|
func (s *State) initOutFolder() error {
|
||||||
if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil {
|
return os.MkdirAll(s.Config.OutFolder, os.ModePerm)
|
||||||
return fmt.Errorf("Unable to make output folder: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkPKeys ensures every table has a primary key column
|
// checkPKeys ensures every table has a primary key column
|
||||||
|
@ -231,7 +218,7 @@ func checkPKeys(tables []dbdrivers.Table) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(missingPkey) != 0 {
|
if len(missingPkey) != 0 {
|
||||||
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
|
return fmt.Errorf("primary key missing in tables (%s)", strings.Join(missingPkey, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
Loading…
Reference in a new issue