Add whitelist feature
This commit is contained in:
parent
d5eb79ae28
commit
7144d272bd
9 changed files with 36 additions and 22 deletions
|
@ -225,6 +225,7 @@ not to pass them through the command line or environment variables:
|
|||
| basedir | none |
|
||||
| pkgname | "models" |
|
||||
| output | "models" |
|
||||
| whitelist | [ ] |
|
||||
| exclude | [ ] |
|
||||
| tag | [ ] |
|
||||
| debug | false |
|
||||
|
@ -261,6 +262,7 @@ sqlboiler postgres
|
|||
Flags:
|
||||
-b, --basedir string The base directory has the templates and templates_test folders
|
||||
-d, --debug Debug mode prints stack traces on error
|
||||
-w, --whitelist stringSlice Only include these tables in your generated package
|
||||
-x, --exclude stringSlice Tables to be excluded from the generated package
|
||||
--no-auto-timestamps Disable automatic timestamps for created_at/updated_at
|
||||
--no-hooks Disable hooks feature for your models
|
||||
|
|
|
@ -9,7 +9,10 @@ import (
|
|||
type MockDriver struct{}
|
||||
|
||||
// TableNames returns a list of mock table names
|
||||
func (m *MockDriver) TableNames(exclude []string) ([]string, error) {
|
||||
func (m *MockDriver) TableNames(whitelist, exclude []string) ([]string, error) {
|
||||
if len(whitelist) > 0 {
|
||||
return whitelist, nil
|
||||
}
|
||||
tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"}
|
||||
return strmangle.SetComplement(tables, exclude), nil
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/vattle/sqlboiler/bdb"
|
||||
"github.com/vattle/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
// PostgresDriver holds the database connection string and a handle
|
||||
|
@ -82,18 +81,15 @@ func (p *PostgresDriver) UseLastInsertID() bool {
|
|||
|
||||
// TableNames connects to the postgres database and
|
||||
// retrieves all table names from the information_schema where the
|
||||
// table schema is public. It excludes common migration tool tables
|
||||
// such as gorp_migrations
|
||||
func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) {
|
||||
// table schema is public. It uses a whitelist and exclude list.
|
||||
func (p *PostgresDriver) TableNames(whitelist, exclude []string) ([]string, error) {
|
||||
var names []string
|
||||
|
||||
query := `select table_name from information_schema.tables where table_schema = 'public'`
|
||||
if len(exclude) > 0 {
|
||||
quoteStr := func(x string) string {
|
||||
return `'` + x + `'`
|
||||
}
|
||||
exclude = strmangle.StringMap(quoteStr, exclude)
|
||||
query = query + fmt.Sprintf("and table_name not in (%s);", strings.Join(exclude, ","))
|
||||
if len(whitelist) > 0 {
|
||||
query = query + fmt.Sprintf("and table_name in ('%s');", strings.Join(whitelist, "','"))
|
||||
} else if len(exclude) > 0 {
|
||||
query = query + fmt.Sprintf("and table_name not in ('%s');", strings.Join(exclude, "','"))
|
||||
}
|
||||
|
||||
rows, err := p.dbConn.Query(query)
|
||||
|
|
|
@ -6,7 +6,7 @@ import "github.com/pkg/errors"
|
|||
// Interface for a database driver. Functionality required to support a specific
|
||||
// database type (eg, MySQL, Postgres etc.)
|
||||
type Interface interface {
|
||||
TableNames(exclude []string) ([]string, error)
|
||||
TableNames(whitelist, exclude []string) ([]string, error)
|
||||
Columns(tableName string) ([]Column, error)
|
||||
PrimaryKeyInfo(tableName string) (*PrimaryKey, error)
|
||||
ForeignKeyInfo(tableName string) ([]ForeignKey, error)
|
||||
|
@ -26,10 +26,10 @@ type Interface interface {
|
|||
|
||||
// Tables returns the metadata for all tables, minus the tables
|
||||
// specified in the exclude slice.
|
||||
func Tables(db Interface, exclude ...string) ([]Table, error) {
|
||||
func Tables(db Interface, whitelist, exclude []string) ([]Table, error) {
|
||||
var err error
|
||||
|
||||
names, err := db.TableNames(exclude)
|
||||
names, err := db.TableNames(whitelist, exclude)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to get table names")
|
||||
}
|
||||
|
|
|
@ -13,7 +13,10 @@ func (m mockDriver) UseLastInsertID() bool { return false }
|
|||
func (m mockDriver) Open() error { return nil }
|
||||
func (m mockDriver) Close() {}
|
||||
|
||||
func (m mockDriver) TableNames(exclude []string) ([]string, error) {
|
||||
func (m mockDriver) TableNames(whitelist, exclude []string) ([]string, error) {
|
||||
if len(whitelist) > 0 {
|
||||
return whitelist, nil
|
||||
}
|
||||
tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"}
|
||||
return strmangle.SetComplement(tables, exclude), nil
|
||||
}
|
||||
|
@ -96,7 +99,7 @@ func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
|||
func TestTables(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tables, err := Tables(mockDriver{})
|
||||
tables, err := Tables(mockDriver{}, nil, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ type Config struct {
|
|||
PkgName string
|
||||
OutFolder string
|
||||
BaseDir string
|
||||
WhitelistTables []string
|
||||
ExcludeTables []string
|
||||
Tags []string
|
||||
Debug bool
|
||||
|
|
9
main.go
9
main.go
|
@ -64,6 +64,7 @@ func main() {
|
|||
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
|
||||
rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory has the templates and templates_test folders")
|
||||
rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package")
|
||||
rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package")
|
||||
rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml")
|
||||
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
|
||||
rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files")
|
||||
|
@ -126,6 +127,14 @@ func preRun(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
}
|
||||
|
||||
cmdConfig.WhitelistTables = viper.GetStringSlice("whitelist")
|
||||
if len(cmdConfig.WhitelistTables) == 1 && strings.HasPrefix(cmdConfig.WhitelistTables[0], "[") {
|
||||
cmdConfig.WhitelistTables, err = cmd.PersistentFlags().GetStringSlice("whitelist")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
cmdConfig.Tags = viper.GetStringSlice("tag")
|
||||
if len(cmdConfig.Tags) == 1 && strings.HasPrefix(cmdConfig.Tags[0], "[") {
|
||||
cmdConfig.Tags, err = cmd.PersistentFlags().GetStringSlice("tag")
|
||||
|
|
|
@ -59,7 +59,7 @@ func New(config *Config) (*State, error) {
|
|||
return nil, errors.Wrap(err, "unable to connect to the database")
|
||||
}
|
||||
|
||||
err = s.initTables(config.ExcludeTables)
|
||||
err = s.initTables(config.WhitelistTables, config.ExcludeTables)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to initialize tables")
|
||||
}
|
||||
|
@ -239,9 +239,9 @@ func (s *State) initDriver(driverName string) error {
|
|||
}
|
||||
|
||||
// initTables retrieves all "public" schema table names from the database.
|
||||
func (s *State) initTables(exclude []string) error {
|
||||
func (s *State) initTables(whitelist, exclude []string) error {
|
||||
var err error
|
||||
s.Tables, err = bdb.Tables(s.Driver, exclude...)
|
||||
s.Tables, err = bdb.Tables(s.Driver, whitelist, exclude)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "unable to fetch table data")
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
func TestTextsFromForeignKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tables, err := bdb.Tables(&drivers.MockDriver{})
|
||||
tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ func TestTextsFromForeignKey(t *testing.T) {
|
|||
func TestTextsFromOneToOneRelationship(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tables, err := bdb.Tables(&drivers.MockDriver{})
|
||||
tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -130,7 +130,7 @@ func TestTextsFromOneToOneRelationship(t *testing.T) {
|
|||
func TestTextsFromRelationship(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tables, err := bdb.Tables(&drivers.MockDriver{})
|
||||
tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue