diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index 2c8341b..ecc964e 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -9,6 +9,7 @@ import ( _ "github.com/lib/pq" "github.com/pkg/errors" "github.com/vattle/sqlboiler/bdb" + "github.com/vattle/sqlboiler/strmangle" ) // PostgresDriver holds the database connection string and a handle @@ -83,13 +84,19 @@ func (p *PostgresDriver) UseLastInsertID() bool { // retrieves all table names from the information_schema where the // table schema is public. It excludes common migration tool tables // such as gorp_migrations -func (p *PostgresDriver) TableNames() ([]string, error) { +func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) { var names []string - rows, err := p.dbConn.Query(` - select table_name from information_schema.tables - where table_schema = 'public' and table_name not like '%migrations%' - `) + query := `select table_name from information_schema.tables where table_schema = 'public'` + if len(exclude) > 0 { + quoteStr := func(x string) string { + return `'` + x + `'` + } + exclude = strmangle.StringMap(quoteStr, exclude) + query = query + fmt.Sprintf("and table_name not in (%s);", strings.Join(exclude, ",")) + } + + rows, err := p.dbConn.Query(query) if err != nil { return nil, err diff --git a/bdb/interface.go b/bdb/interface.go index a27ce3d..9e4bfc7 100644 --- a/bdb/interface.go +++ b/bdb/interface.go @@ -6,7 +6,7 @@ import "github.com/pkg/errors" // Interface for a database driver. Functionality required to support a specific // database type (eg, MySQL, Postgres etc.) type Interface interface { - TableNames() ([]string, error) + TableNames(exclude []string) ([]string, error) Columns(tableName string) ([]Column, error) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) ForeignKeyInfo(tableName string) ([]ForeignKey, error) @@ -24,12 +24,12 @@ type Interface interface { Close() } -// Tables returns the table metadata for the given tables, or all tables if -// no tables are provided. -func Tables(db Interface) ([]Table, error) { +// Tables returns the metadata for all tables, minus the tables +// specified in the exclude slice. +func Tables(db Interface, exclude ...string) ([]Table, error) { var err error - names, err := db.TableNames() + names, err := db.TableNames(exclude) if err != nil { return nil, errors.Wrap(err, "unable to get table names") } diff --git a/bdb/interface_test.go b/bdb/interface_test.go index 662fb70..be90f18 100644 --- a/bdb/interface_test.go +++ b/bdb/interface_test.go @@ -7,7 +7,7 @@ import ( type testInterface struct{} -func (t testInterface) TableNames() ([]string, error) { +func (t testInterface) TableNames(exclude []string) ([]string, error) { return []string{"table1", "table2"}, nil } diff --git a/config.go b/config.go index 4a05617..6ee94e1 100644 --- a/config.go +++ b/config.go @@ -2,10 +2,11 @@ package main // Config for the running of the commands type Config struct { - DriverName string `toml:"driver_name"` - PkgName string `toml:"pkg_name"` - OutFolder string `toml:"out_folder"` - TableNames []string `toml:"table_names"` + DriverName string `toml:"driver_name"` + PkgName string `toml:"pkg_name"` + OutFolder string `toml:"out_folder"` + ExcludeTables []string `toml:"exclude"` + TableNames []string Postgres PostgresConfig `toml:"postgres"` } diff --git a/fakedb_test.go b/fakedb_test.go index e863d8b..95462d1 100644 --- a/fakedb_test.go +++ b/fakedb_test.go @@ -7,8 +7,9 @@ import ( type fakeDB int -func (fakeDB) TableNames() ([]string, error) { - return []string{"users", "videos", "contests", "notifications", "users_videos_tags"}, nil +func (fakeDB) TableNames(exclude []string) ([]string, error) { + tables := []string{"users", "videos", "contests", "notifications", "users_videos_tags"} + return strmangle.SetComplement(tables, exclude), nil } func (fakeDB) Columns(tableName string) ([]bdb.Column, error) { return map[string][]bdb.Column{ diff --git a/main.go b/main.go index f4b987e..7852021 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/kat-co/vala" "github.com/spf13/cobra" @@ -40,10 +41,9 @@ func main() { viper.AddConfigPath(p) } - if err := viper.ReadInConfig(); err != nil { - fmt.Printf("Cannot read or locate config file: %s\n", err) - os.Exit(1) - } + // Ignore errors here, fallback to other validation methods. + // Users can use environment variables if a config is not found. + _ = viper.ReadInConfig() // Set up the cobra root command var rootCmd = &cobra.Command{ @@ -62,6 +62,7 @@ func main() { // Set up the cobra root command flags rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to") rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") + rootCmd.PersistentFlags().StringP("exclude", "x", "", "Tables to be excluded from the generated package") viper.SetDefault("postgres.sslmode", "require") viper.SetDefault("postgres.port", "5432") @@ -100,6 +101,18 @@ func preRun(cmd *cobra.Command, args []string) error { PkgName: viper.GetString("pkgname"), } + // BUG: https://github.com/spf13/viper/issues/200 + // Look up the value of ExcludeTables directly from PFlags in Cobra if we + // detect a malformed value coming out of viper. + // Once the bug is fixed we'll be able to move this into the init above + cmdConfig.ExcludeTables = viper.GetStringSlice("exclude") + if len(cmdConfig.ExcludeTables) == 1 && strings.HasPrefix(cmdConfig.ExcludeTables[0], "[") { + cmdConfig.ExcludeTables, err = cmd.PersistentFlags().GetStringSlice("exclude") + if err != nil { + return err + } + } + if viper.IsSet("postgres.dbname") { cmdConfig.Postgres = PostgresConfig{ User: viper.GetString("postgres.user"), diff --git a/sqlboiler.go b/sqlboiler.go index ec313e3..b04ecee 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -53,7 +53,7 @@ func New(config *Config) (*State, error) { return nil, errors.Wrap(err, "unable to connect to the database") } - err = s.initTables() + err = s.initTables(config.ExcludeTables) if err != nil { return nil, errors.Wrap(err, "unable to initialize tables") } @@ -190,9 +190,9 @@ func (s *State) initDriver(driverName string) error { } // initTables retrieves all "public" schema table names from the database. -func (s *State) initTables() error { +func (s *State) initTables(exclude []string) error { var err error - s.Tables, err = bdb.Tables(s.Driver) + s.Tables, err = bdb.Tables(s.Driver, exclude...) if err != nil { return errors.Wrap(err, "unable to fetch table data") }