package cmds import ( "errors" "fmt" "os" "path/filepath" "sort" "strings" "text/template" "github.com/pobri19/sqlboiler/dbdrivers" "github.com/spf13/cobra" ) const ( templatesDirectory = "/cmds/templates" templatesSinglesDirectory = "/cmds/templates/singles" templatesTestDirectory = "/cmds/templates_test" templatesTestMainDirectory = "/cmds/templates_test/main_test" ) // LoadTemplates loads all template folders into the cmdData object. func initTemplates(cmdData *CmdData) error { var err error cmdData.Templates, err = loadTemplates(templatesDirectory) if err != nil { return err } cmdData.SingleTemplates, err = loadTemplates(templatesSinglesDirectory) if err != nil { return err } cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory) if err != nil { return err } filename := cmdData.DriverName + "_main.tpl" cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename) if err != nil { return err } return nil } // loadTemplates loads all of the template files in the specified directory. func loadTemplates(dir string) ([]*template.Template, error) { wd, err := os.Getwd() if err != nil { return nil, err } pattern := filepath.Join(wd, dir, "*.tpl") tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern) if err != nil { return nil, err } templates := templater(tpl.Templates()) sort.Sort(templates) return templates, err } // loadTemplate loads a single template file. func loadTemplate(dir string, filename string) (*template.Template, error) { wd, err := os.Getwd() if err != nil { return nil, err } pattern := filepath.Join(wd, dir, filename) tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseFiles(pattern) if err != nil { return nil, err } return tpl.Lookup(filename), err } // SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished. func (c *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error { c.Interface.Close() return nil } // SQLBoilerPreRun performs the initialization tasks before the root command is run func (c *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error { // Initialize package name pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String() // Retrieve driver flag driverName := cmd.PersistentFlags().Lookup("driver").Value.String() if driverName == "" { return errors.New("Must supply a driver flag.") } tableName := cmd.PersistentFlags().Lookup("table").Value.String() outFolder := cmd.PersistentFlags().Lookup("folder").Value.String() if outFolder == "" { return fmt.Errorf("No output folder specified.") } return c.initCmdData(pkgName, driverName, tableName, outFolder) } // SQLBoilerRun is a proxy method for the run function func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error { return c.run(true) } // run executes the sqlboiler templates and outputs them to files. func (c *CmdData) run(includeTests bool) error { if includeTests { if err := generateTestMainOutput(c); err != nil { return fmt.Errorf("Unable to generate TestMain output: %s", err) } } if err := generateSinglesOutput(c); err != nil { return fmt.Errorf("Unable to generate single templates output: %s", err) } for _, table := range c.Tables { data := &tplData{ Table: table, DriverName: c.DriverName, PkgName: c.PkgName, } // Generate the regular templates if err := generateOutput(c, data); err != nil { return fmt.Errorf("Unable to generate output: %s", err) } // Generate the test templates if includeTests { if err := generateTestOutput(c, data); err != nil { return fmt.Errorf("Unable to generate test output: %s", err) } } } return nil } func (c *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error { c.OutFolder = outFolder c.PkgName = pkgName err := initInterface(driverName, c) if err != nil { return err } // Connect to the driver database if err = c.Interface.Open(); err != nil { return fmt.Errorf("Unable to connect to the database: %s", err) } err = initTables(tableName, c) if err != nil { return fmt.Errorf("Unable to initialize tables: %s", err) } err = initOutFolder(c) if err != nil { return fmt.Errorf("Unable to initialize the output folder: %s", err) } err = initTemplates(c) if err != nil { return fmt.Errorf("Unable to initialize templates: %s", err) } return nil } // initInterface attempts to set the cmdData Interface based off the passed in // driver flag value. If an invalid flag string is provided an error is returned. func initInterface(driverName string, cmdData *CmdData) error { // Create a driver based off driver flag switch driverName { case "postgres": cmdData.Interface = dbdrivers.NewPostgresDriver( cmdData.Config.Postgres.User, cmdData.Config.Postgres.Pass, cmdData.Config.Postgres.DBName, cmdData.Config.Postgres.Host, cmdData.Config.Postgres.Port, ) } if cmdData.Interface == nil { return errors.New("An invalid driver name was provided") } cmdData.DriverName = driverName return nil } // initTables will create a string slice out of the passed in table flag value // if one is provided. If no flag is provided, it will attempt to connect to the // database to retrieve all "public" table names, and build a slice out of that result. func initTables(tableName string, cmdData *CmdData) error { var tableNames []string if len(tableName) != 0 { tableNames = strings.Split(tableName, ",") for i, name := range tableNames { tableNames[i] = strings.TrimSpace(name) } } var err error cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...) if err != nil { return fmt.Errorf("Unable to get all table names: %s", err) } if len(cmdData.Tables) == 0 { return errors.New("No tables found in database, migrate some tables first") } if err := checkPKeys(cmdData.Tables); err != nil { return err } return nil } // checkPKeys ensures every table has a primary key column func checkPKeys(tables []dbdrivers.Table) error { var missingPkey []string for _, t := range tables { if t.PKey == nil { missingPkey = append(missingPkey, t.Name) } } if len(missingPkey) != 0 { return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", ")) } return nil } // initOutFolder creates the folder that will hold the generated output. func initOutFolder(cmdData *CmdData) error { if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil { return fmt.Errorf("Unable to make output folder: %s", err) } return nil }