// Package sqlboiler has types and methods useful for generating code that // acts as a fully dynamic ORM might. package main import ( "encoding/json" "fmt" "go/build" "os" "path/filepath" "regexp" "strings" "text/template" "github.com/pkg/errors" "github.com/vattle/sqlboiler/bdb" "github.com/vattle/sqlboiler/bdb/drivers" "github.com/vattle/sqlboiler/queries" "github.com/vattle/sqlboiler/strmangle" ) const ( templatesDirectory = "templates" templatesSingletonDirectory = "templates/singleton" templatesTestDirectory = "templates_test" templatesSingletonTestDirectory = "templates_test/singleton" templatesTestMainDirectory = "templates_test/main_test" ) // State holds the global data needed by most pieces to run type State struct { Config *Config Driver bdb.Interface Tables []bdb.Table Dialect queries.Dialect Templates *templateList TestTemplates *templateList SingletonTemplates *templateList SingletonTestTemplates *templateList TestMainTemplate *template.Template } // New creates a new state based off of the config func New(config *Config) (*State, error) { s := &State{ Config: config, } err := s.initDriver(config.DriverName) if err != nil { return nil, err } // Connect to the driver database if err = s.Driver.Open(); err != nil { return nil, errors.Wrap(err, "unable to connect to the database") } err = s.initTables(config.Schema, config.WhitelistTables, config.BlacklistTables) if err != nil { return nil, errors.Wrap(err, "unable to initialize tables") } if s.Config.Debug { b, err := json.Marshal(s.Tables) if err != nil { return nil, errors.Wrap(err, "unable to json marshal tables") } fmt.Printf("%s\n", b) } err = s.initOutFolder() if err != nil { return nil, errors.Wrap(err, "unable to initialize the output folder") } err = s.initTemplates() if err != nil { return nil, errors.Wrap(err, "unable to initialize templates") } err = s.initTags(config.Tags) if err != nil { return nil, errors.Wrap(err, "unable to initialize struct tags") } return s, nil } // Run executes the 
sqlboiler templates and outputs them to files based on the // state given. func (s *State) Run(includeTests bool) error { singletonData := &templateData{ Tables: s.Tables, Schema: s.Config.Schema, DriverName: s.Config.DriverName, UseLastInsertID: s.Driver.UseLastInsertID(), PkgName: s.Config.PkgName, NoHooks: s.Config.NoHooks, NoAutoTimestamps: s.Config.NoAutoTimestamps, Dialect: s.Dialect, LQ: strmangle.QuoteCharacter(s.Dialect.LQ), RQ: strmangle.QuoteCharacter(s.Dialect.RQ), StringFuncs: templateStringMappers, } if err := generateSingletonOutput(s, singletonData); err != nil { return errors.Wrap(err, "singleton template output") } if !s.Config.NoTests && includeTests { if err := generateTestMainOutput(s, singletonData); err != nil { return errors.Wrap(err, "unable to generate TestMain output") } if err := generateSingletonTestOutput(s, singletonData); err != nil { return errors.Wrap(err, "unable to generate singleton test template output") } } for _, table := range s.Tables { if table.IsJoinTable { continue } data := &templateData{ Tables: s.Tables, Table: table, Schema: s.Config.Schema, DriverName: s.Config.DriverName, UseLastInsertID: s.Driver.UseLastInsertID(), PkgName: s.Config.PkgName, NoHooks: s.Config.NoHooks, NoAutoTimestamps: s.Config.NoAutoTimestamps, Tags: s.Config.Tags, Dialect: s.Dialect, LQ: strmangle.QuoteCharacter(s.Dialect.LQ), RQ: strmangle.QuoteCharacter(s.Dialect.RQ), StringFuncs: templateStringMappers, } // Generate the regular templates if err := generateOutput(s, data); err != nil { return errors.Wrap(err, "unable to generate output") } // Generate the test templates if !s.Config.NoTests && includeTests { if err := generateTestOutput(s, data); err != nil { return errors.Wrap(err, "unable to generate test output") } } } return nil } // Cleanup closes any resources that must be closed func (s *State) Cleanup() error { s.Driver.Close() return nil } // initTemplates loads all template folders into the state object. 
func (s *State) initTemplates() error { var err error basePath, err := getBasePath(s.Config.BaseDir) if err != nil { return err } s.Templates, err = loadTemplates(filepath.Join(basePath, templatesDirectory)) if err != nil { return err } s.SingletonTemplates, err = loadTemplates(filepath.Join(basePath, templatesSingletonDirectory)) if err != nil { return err } if !s.Config.NoTests { s.TestTemplates, err = loadTemplates(filepath.Join(basePath, templatesTestDirectory)) if err != nil { return err } s.SingletonTestTemplates, err = loadTemplates(filepath.Join(basePath, templatesSingletonTestDirectory)) if err != nil { return err } s.TestMainTemplate, err = loadTemplate(filepath.Join(basePath, templatesTestMainDirectory), s.Config.DriverName+"_main.tpl") if err != nil { return err } } return nil } var basePackage = "github.com/vattle/sqlboiler" func getBasePath(baseDirConfig string) (string, error) { if len(baseDirConfig) > 0 { return baseDirConfig, nil } p, _ := build.Default.Import(basePackage, "", build.FindOnly) if p != nil && len(p.Dir) > 0 { return p.Dir, nil } return os.Getwd() } // initDriver attempts to set the state Interface based off the passed in // driver flag value. If an invalid flag string is provided an error is returned. 
func (s *State) initDriver(driverName string) error { // Create a driver based off driver flag switch driverName { case "postgres": s.Driver = drivers.NewPostgresDriver( s.Config.Postgres.User, s.Config.Postgres.Pass, s.Config.Postgres.DBName, s.Config.Postgres.Host, s.Config.Postgres.Port, s.Config.Postgres.SSLMode, ) case "mysql": s.Driver = drivers.NewMySQLDriver( s.Config.MySQL.User, s.Config.MySQL.Pass, s.Config.MySQL.DBName, s.Config.MySQL.Host, s.Config.MySQL.Port, s.Config.MySQL.SSLMode, ) case "mock": s.Driver = &drivers.MockDriver{} } if s.Driver == nil { return errors.New("An invalid driver name was provided") } s.Dialect.LQ = s.Driver.LeftQuote() s.Dialect.RQ = s.Driver.RightQuote() s.Dialect.IndexPlaceholders = s.Driver.IndexPlaceholders() return nil } // initTables retrieves all "public" schema table names from the database. func (s *State) initTables(schema string, whitelist, blacklist []string) error { var err error s.Tables, err = bdb.Tables(s.Driver, schema, whitelist, blacklist) if err != nil { return errors.Wrap(err, "unable to fetch table data") } if len(s.Tables) == 0 { return errors.New("no tables found in database") } if err := checkPKeys(s.Tables); err != nil { return err } return nil } // Tags must be in a format like: json, xml, etc. var rgxValidTag = regexp.MustCompile(`[a-zA-Z_\.]+`) // initTags removes duplicate tags and validates the format // of all user tags are simple strings without quotes: [a-zA-Z_\.]+ func (s *State) initTags(tags []string) error { s.Config.Tags = removeDuplicates(s.Config.Tags) for _, v := range s.Config.Tags { if !rgxValidTag.MatchString(v) { return errors.New("Invalid tag format %q supplied, only specify name, eg: xml") } } return nil } // initOutFolder creates the folder that will hold the generated output. 
func (s *State) initOutFolder() error {
	outDir := s.Config.OutFolder
	return os.MkdirAll(outDir, os.ModePerm)
}

// checkPKeys reports an error naming every table in tables whose PKey is nil;
// it returns nil when all tables have a primary key.
func checkPKeys(tables []bdb.Table) error {
	var missing []string
	for i := range tables {
		if tables[i].PKey == nil {
			missing = append(missing, tables[i].Name)
		}
	}

	if len(missing) == 0 {
		return nil
	}

	return errors.Errorf("primary key missing in tables (%s)", strings.Join(missing, ", "))
}