Add exclude/blacklist tables flag
This commit is contained in:
parent
79dfcf3ebf
commit
1e67965482
7 changed files with 46 additions and 24 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/vattle/sqlboiler/bdb"
|
||||
"github.com/vattle/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
// PostgresDriver holds the database connection string and a handle
|
||||
|
@ -83,13 +84,19 @@ func (p *PostgresDriver) UseLastInsertID() bool {
|
|||
// retrieves all table names from the information_schema where the
|
||||
// table schema is public. It excludes common migration tool tables
|
||||
// such as gorp_migrations
|
||||
func (p *PostgresDriver) TableNames() ([]string, error) {
|
||||
func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) {
|
||||
var names []string
|
||||
|
||||
rows, err := p.dbConn.Query(`
|
||||
select table_name from information_schema.tables
|
||||
where table_schema = 'public' and table_name not like '%migrations%'
|
||||
`)
|
||||
query := `select table_name from information_schema.tables where table_schema = 'public'`
|
||||
if len(exclude) > 0 {
|
||||
quoteStr := func(x string) string {
|
||||
return `'` + x + `'`
|
||||
}
|
||||
exclude = strmangle.StringMap(quoteStr, exclude)
|
||||
query = query + fmt.Sprintf("and table_name not in (%s);", strings.Join(exclude, ","))
|
||||
}
|
||||
|
||||
rows, err := p.dbConn.Query(query)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -6,7 +6,7 @@ import "github.com/pkg/errors"
|
|||
// Interface for a database driver. Functionality required to support a specific
|
||||
// database type (eg, MySQL, Postgres etc.)
|
||||
type Interface interface {
|
||||
TableNames() ([]string, error)
|
||||
TableNames(exclude []string) ([]string, error)
|
||||
Columns(tableName string) ([]Column, error)
|
||||
PrimaryKeyInfo(tableName string) (*PrimaryKey, error)
|
||||
ForeignKeyInfo(tableName string) ([]ForeignKey, error)
|
||||
|
@ -24,12 +24,12 @@ type Interface interface {
|
|||
Close()
|
||||
}
|
||||
|
||||
// Tables returns the table metadata for the given tables, or all tables if
|
||||
// no tables are provided.
|
||||
func Tables(db Interface) ([]Table, error) {
|
||||
// Tables returns the metadata for all tables, minus the tables
|
||||
// specified in the exclude slice.
|
||||
func Tables(db Interface, exclude ...string) ([]Table, error) {
|
||||
var err error
|
||||
|
||||
names, err := db.TableNames()
|
||||
names, err := db.TableNames(exclude)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to get table names")
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
|
||||
type testInterface struct{}
|
||||
|
||||
func (t testInterface) TableNames() ([]string, error) {
|
||||
func (t testInterface) TableNames(exclude []string) ([]string, error) {
|
||||
return []string{"table1", "table2"}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,10 +2,11 @@ package main
|
|||
|
||||
// Config for the running of the commands
|
||||
type Config struct {
|
||||
DriverName string `toml:"driver_name"`
|
||||
PkgName string `toml:"pkg_name"`
|
||||
OutFolder string `toml:"out_folder"`
|
||||
TableNames []string `toml:"table_names"`
|
||||
DriverName string `toml:"driver_name"`
|
||||
PkgName string `toml:"pkg_name"`
|
||||
OutFolder string `toml:"out_folder"`
|
||||
ExcludeTables []string `toml:"exclude"`
|
||||
TableNames []string
|
||||
|
||||
Postgres PostgresConfig `toml:"postgres"`
|
||||
}
|
||||
|
|
|
@ -7,8 +7,9 @@ import (
|
|||
|
||||
type fakeDB int
|
||||
|
||||
func (fakeDB) TableNames() ([]string, error) {
|
||||
return []string{"users", "videos", "contests", "notifications", "users_videos_tags"}, nil
|
||||
func (fakeDB) TableNames(exclude []string) ([]string, error) {
|
||||
tables := []string{"users", "videos", "contests", "notifications", "users_videos_tags"}
|
||||
return strmangle.SetComplement(tables, exclude), nil
|
||||
}
|
||||
func (fakeDB) Columns(tableName string) ([]bdb.Column, error) {
|
||||
return map[string][]bdb.Column{
|
||||
|
|
21
main.go
21
main.go
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/kat-co/vala"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -40,10 +41,9 @@ func main() {
|
|||
viper.AddConfigPath(p)
|
||||
}
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
fmt.Printf("Cannot read or locate config file: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Ignore errors here, fallback to other validation methods.
|
||||
// Users can use environment variables if a config is not found.
|
||||
_ = viper.ReadInConfig()
|
||||
|
||||
// Set up the cobra root command
|
||||
var rootCmd = &cobra.Command{
|
||||
|
@ -62,6 +62,7 @@ func main() {
|
|||
// Set up the cobra root command flags
|
||||
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
|
||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
|
||||
rootCmd.PersistentFlags().StringP("exclude", "x", "", "Tables to be excluded from the generated package")
|
||||
|
||||
viper.SetDefault("postgres.sslmode", "require")
|
||||
viper.SetDefault("postgres.port", "5432")
|
||||
|
@ -100,6 +101,18 @@ func preRun(cmd *cobra.Command, args []string) error {
|
|||
PkgName: viper.GetString("pkgname"),
|
||||
}
|
||||
|
||||
// BUG: https://github.com/spf13/viper/issues/200
|
||||
// Look up the value of ExcludeTables directly from PFlags in Cobra if we
|
||||
// detect a malformed value coming out of viper.
|
||||
// Once the bug is fixed we'll be able to move this into the init above
|
||||
cmdConfig.ExcludeTables = viper.GetStringSlice("exclude")
|
||||
if len(cmdConfig.ExcludeTables) == 1 && strings.HasPrefix(cmdConfig.ExcludeTables[0], "[") {
|
||||
cmdConfig.ExcludeTables, err = cmd.PersistentFlags().GetStringSlice("exclude")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if viper.IsSet("postgres.dbname") {
|
||||
cmdConfig.Postgres = PostgresConfig{
|
||||
User: viper.GetString("postgres.user"),
|
||||
|
|
|
@ -53,7 +53,7 @@ func New(config *Config) (*State, error) {
|
|||
return nil, errors.Wrap(err, "unable to connect to the database")
|
||||
}
|
||||
|
||||
err = s.initTables()
|
||||
err = s.initTables(config.ExcludeTables)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to initialize tables")
|
||||
}
|
||||
|
@ -190,9 +190,9 @@ func (s *State) initDriver(driverName string) error {
|
|||
}
|
||||
|
||||
// initTables retrieves all "public" schema table names from the database.
|
||||
func (s *State) initTables() error {
|
||||
func (s *State) initTables(exclude []string) error {
|
||||
var err error
|
||||
s.Tables, err = bdb.Tables(s.Driver)
|
||||
s.Tables, err = bdb.Tables(s.Driver, exclude...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "unable to fetch table data")
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue