Add whitelist feature

This commit is contained in:
Patrick O'brien 2016-09-06 00:41:12 +10:00
parent d5eb79ae28
commit 7144d272bd
9 changed files with 36 additions and 22 deletions

View file

@ -225,6 +225,7 @@ not to pass them through the command line or environment variables:
| basedir | none |
| pkgname | "models" |
| output | "models" |
| whitelist | [ ] |
| exclude | [ ] |
| tag | [ ] |
| debug | false |
@ -261,6 +262,7 @@ sqlboiler postgres
Flags:
-b, --basedir string The base directory has the templates and templates_test folders
-d, --debug Debug mode prints stack traces on error
-w, --whitelist stringSlice Only include these tables in your generated package
-x, --exclude stringSlice Tables to be excluded from the generated package
--no-auto-timestamps Disable automatic timestamps for created_at/updated_at
--no-hooks Disable hooks feature for your models

View file

@ -9,7 +9,10 @@ import (
type MockDriver struct{}
// TableNames returns a list of mock table names
func (m *MockDriver) TableNames(exclude []string) ([]string, error) {
func (m *MockDriver) TableNames(whitelist, exclude []string) ([]string, error) {
if len(whitelist) > 0 {
return whitelist, nil
}
tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"}
return strmangle.SetComplement(tables, exclude), nil
}

View file

@ -9,7 +9,6 @@ import (
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/bdb"
"github.com/vattle/sqlboiler/strmangle"
)
// PostgresDriver holds the database connection string and a handle
@ -82,18 +81,15 @@ func (p *PostgresDriver) UseLastInsertID() bool {
// TableNames connects to the postgres database and
// retrieves all table names from the information_schema where the
// table schema is public. It excludes common migration tool tables
// such as gorp_migrations
func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) {
// table schema is public. It uses a whitelist and exclude list.
func (p *PostgresDriver) TableNames(whitelist, exclude []string) ([]string, error) {
var names []string
query := `select table_name from information_schema.tables where table_schema = 'public'`
if len(exclude) > 0 {
quoteStr := func(x string) string {
return `'` + x + `'`
}
exclude = strmangle.StringMap(quoteStr, exclude)
query = query + fmt.Sprintf("and table_name not in (%s);", strings.Join(exclude, ","))
if len(whitelist) > 0 {
query = query + fmt.Sprintf("and table_name in ('%s');", strings.Join(whitelist, "','"))
} else if len(exclude) > 0 {
query = query + fmt.Sprintf("and table_name not in ('%s');", strings.Join(exclude, "','"))
}
rows, err := p.dbConn.Query(query)

View file

@ -6,7 +6,7 @@ import "github.com/pkg/errors"
// Interface for a database driver. Functionality required to support a specific
// database type (eg, MySQL, Postgres etc.)
type Interface interface {
TableNames(exclude []string) ([]string, error)
TableNames(whitelist, exclude []string) ([]string, error)
Columns(tableName string) ([]Column, error)
PrimaryKeyInfo(tableName string) (*PrimaryKey, error)
ForeignKeyInfo(tableName string) ([]ForeignKey, error)
@ -26,10 +26,10 @@ type Interface interface {
// Tables returns the metadata for all tables, minus the tables
// specified in the exclude slice.
func Tables(db Interface, exclude ...string) ([]Table, error) {
func Tables(db Interface, whitelist, exclude []string) ([]Table, error) {
var err error
names, err := db.TableNames(exclude)
names, err := db.TableNames(whitelist, exclude)
if err != nil {
return nil, errors.Wrap(err, "unable to get table names")
}

View file

@ -13,7 +13,10 @@ func (m mockDriver) UseLastInsertID() bool { return false }
func (m mockDriver) Open() error { return nil }
func (m mockDriver) Close() {}
func (m mockDriver) TableNames(exclude []string) ([]string, error) {
func (m mockDriver) TableNames(whitelist, exclude []string) ([]string, error) {
if len(whitelist) > 0 {
return whitelist, nil
}
tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"}
return strmangle.SetComplement(tables, exclude), nil
}
@ -96,7 +99,7 @@ func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
func TestTables(t *testing.T) {
t.Parallel()
tables, err := Tables(mockDriver{})
tables, err := Tables(mockDriver{}, nil, nil)
if err != nil {
t.Error(err)
}

View file

@ -6,6 +6,7 @@ type Config struct {
PkgName string
OutFolder string
BaseDir string
WhitelistTables []string
ExcludeTables []string
Tags []string
Debug bool

View file

@ -64,6 +64,7 @@ func main() {
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory has the templates and templates_test folders")
rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package")
rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package")
rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml")
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files")
@ -126,6 +127,14 @@ func preRun(cmd *cobra.Command, args []string) error {
}
}
cmdConfig.WhitelistTables = viper.GetStringSlice("whitelist")
if len(cmdConfig.WhitelistTables) == 1 && strings.HasPrefix(cmdConfig.WhitelistTables[0], "[") {
cmdConfig.WhitelistTables, err = cmd.PersistentFlags().GetStringSlice("whitelist")
if err != nil {
return err
}
}
cmdConfig.Tags = viper.GetStringSlice("tag")
if len(cmdConfig.Tags) == 1 && strings.HasPrefix(cmdConfig.Tags[0], "[") {
cmdConfig.Tags, err = cmd.PersistentFlags().GetStringSlice("tag")

View file

@ -59,7 +59,7 @@ func New(config *Config) (*State, error) {
return nil, errors.Wrap(err, "unable to connect to the database")
}
err = s.initTables(config.ExcludeTables)
err = s.initTables(config.WhitelistTables, config.ExcludeTables)
if err != nil {
return nil, errors.Wrap(err, "unable to initialize tables")
}
@ -239,9 +239,9 @@ func (s *State) initDriver(driverName string) error {
}
// initTables retrieves all "public" schema table names from the database.
func (s *State) initTables(exclude []string) error {
func (s *State) initTables(whitelist, exclude []string) error {
var err error
s.Tables, err = bdb.Tables(s.Driver, exclude...)
s.Tables, err = bdb.Tables(s.Driver, whitelist, exclude)
if err != nil {
return errors.Wrap(err, "unable to fetch table data")
}

View file

@ -12,7 +12,7 @@ import (
func TestTextsFromForeignKey(t *testing.T) {
t.Parallel()
tables, err := bdb.Tables(&drivers.MockDriver{})
tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil)
if err != nil {
t.Fatal(err)
}
@ -81,7 +81,7 @@ func TestTextsFromForeignKey(t *testing.T) {
func TestTextsFromOneToOneRelationship(t *testing.T) {
t.Parallel()
tables, err := bdb.Tables(&drivers.MockDriver{})
tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil)
if err != nil {
t.Fatal(err)
}
@ -130,7 +130,7 @@ func TestTextsFromOneToOneRelationship(t *testing.T) {
func TestTextsFromRelationship(t *testing.T) {
t.Parallel()
tables, err := bdb.Tables(&drivers.MockDriver{})
tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil)
if err != nil {
t.Fatal(err)
}