sqlboiler/cmds/sqlboiler.go
Patrick O'brien 7aba7104a5 Fix null package imports, finish Bind
* Fix randomizeStruct time randomization
* Defer close sql.Rows
* Begin Delete test template
2016-06-08 15:45:34 +10:00

274 lines
6.9 KiB
Go

package cmds
import (
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"text/template"
"github.com/nullbio/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
)
// Locations of the template folders used by the generator. loadTemplates and
// loadTemplate join these onto the process's current working directory, so
// they are resolved relative to wherever the binary is run from.
const (
// templatesDirectory is loaded into CmdData.Templates.
templatesDirectory = "/cmds/templates"
// templatesSinglesDirectory is loaded into CmdData.SingleTemplates.
templatesSinglesDirectory = "/cmds/templates/singles"
// templatesTestDirectory is loaded into CmdData.TestTemplates.
templatesTestDirectory = "/cmds/templates_test"
// templatesSinglesTestDirectory is loaded into CmdData.SingleTestTemplates.
templatesSinglesTestDirectory = "/cmds/templates_test/singles"
// templatesTestMainDirectory holds the per-driver "<driver>_main.tpl" file
// that is loaded into CmdData.TestMainTemplate.
templatesTestMainDirectory = "/cmds/templates_test/main_test"
)
// initTemplates fills every template field of cmdData. The four template
// folders are loaded with loadTemplates, then the driver-specific TestMain
// template ("<DriverName>_main.tpl") is loaded with loadTemplate, so
// cmdData.DriverName must already be set. The first failure is returned as-is
// and the remaining fields are left untouched.
func initTemplates(cmdData *CmdData) error {
	var err error

	if cmdData.Templates, err = loadTemplates(templatesDirectory); err != nil {
		return err
	}

	if cmdData.SingleTemplates, err = loadTemplates(templatesSinglesDirectory); err != nil {
		return err
	}

	if cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory); err != nil {
		return err
	}

	if cmdData.SingleTestTemplates, err = loadTemplates(templatesSinglesTestDirectory); err != nil {
		return err
	}

	// The TestMain template is chosen by driver name, e.g. "postgres_main.tpl".
	mainTplName := cmdData.DriverName + "_main.tpl"
	cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, mainTplName)
	return err
}
// loadTemplates parses every "*.tpl" file found in dir, where dir is resolved
// relative to the current working directory. All templates are parsed with
// sqlBoilerTemplateFuncs registered, and the result is sorted using the
// templater type's sort.Interface implementation before being returned.
//
// An error is returned if the working directory cannot be determined or if
// ParseGlob fails (which includes the pattern matching no files).
func loadTemplates(dir string) ([]*template.Template, error) {
	cwd, err := os.Getwd()
	if err != nil {
		return nil, err
	}

	glob := filepath.Join(cwd, dir, "*.tpl")
	root := template.New("").Funcs(sqlBoilerTemplateFuncs)

	parsed, err := root.ParseGlob(glob)
	if err != nil {
		return nil, err
	}

	sorted := templater(parsed.Templates())
	sort.Sort(sorted)

	return sorted, nil
}
// loadTemplate parses the single file dir/filename, with dir resolved relative
// to the current working directory, and returns the template registered under
// filename. sqlBoilerTemplateFuncs is made available to the template.
//
// ParseFiles names a template after the base name of its file, so filename is
// expected to be a bare file name; if it contains directory components the
// lookup finds nothing and a nil template is returned with a nil error.
func loadTemplate(dir string, filename string) (*template.Template, error) {
	cwd, err := os.Getwd()
	if err != nil {
		return nil, err
	}

	path := filepath.Join(cwd, dir, filename)

	set, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseFiles(path)
	if err != nil {
		return nil, err
	}

	return set.Lookup(filename), nil
}
// SQLBoilerPostRun closes the database driver held in c.Interface once all
// cmds are finished. The cmd and args parameters are not used, and the method
// always returns nil.
func (c *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error {
// NOTE(review): any value returned by Close is discarded here — confirm
// against the dbdrivers.Interface definition whether it can report an error.
c.Interface.Close()
return nil
}
// SQLBoilerPreRun performs the initialization tasks before the root command
// is run. It reads the "pkgname", "driver", "table" and "folder" persistent
// flags from cmd, rejects an empty driver or output folder, and hands the
// values to initCmdData. The args parameter is not used.
func (c *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
	// flagValue returns the current string value of a persistent flag. The
	// flag is assumed to be registered on cmd; Lookup returns nil otherwise.
	flagValue := func(name string) string {
		return cmd.PersistentFlags().Lookup(name).Value.String()
	}

	pkgName := flagValue("pkgname")

	driverName := flagValue("driver")
	if len(driverName) == 0 {
		return errors.New("Must supply a driver flag.")
	}

	tableName := flagValue("table")

	outFolder := flagValue("folder")
	if len(outFolder) == 0 {
		return errors.New("No output folder specified.")
	}

	return c.initCmdData(pkgName, driverName, tableName, outFolder)
}
// SQLBoilerRun is a proxy method for the run function. It always calls run
// with includeTests set to true, so test output is generated alongside the
// regular output. The cmd and args parameters are not used.
func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
return c.run(true)
}
// run drives the generation steps in a fixed order:
//
//  1. the single templates output,
//  2. when includeTests is set, the TestMain output and the single test
//     templates output,
//  3. for every table in c.Tables, the regular output followed (when
//     includeTests is set) by the test output.
//
// Generation stops at the first failing step, whose error is returned with a
// message identifying that step.
func (c *CmdData) run(includeTests bool) error {
	if err := generateSinglesOutput(c); err != nil {
		return fmt.Errorf("Unable to generate single templates output: %s", err)
	}

	if includeTests {
		if err := generateTestMainOutput(c); err != nil {
			return fmt.Errorf("Unable to generate TestMain output: %s", err)
		}

		if err := generateSinglesTestOutput(c); err != nil {
			return fmt.Errorf("Unable to generate single test templates output: %s", err)
		}
	}

	for i := range c.Tables {
		// The same template data is shared by a table's regular and test
		// templates.
		data := &tplData{
			Table:      c.Tables[i],
			DriverName: c.DriverName,
			PkgName:    c.PkgName,
		}

		if err := generateOutput(c, data); err != nil {
			return fmt.Errorf("Unable to generate output: %s", err)
		}

		if !includeTests {
			continue
		}

		if err := generateTestOutput(c, data); err != nil {
			return fmt.Errorf("Unable to generate test output: %s", err)
		}
	}

	return nil
}
// initCmdData prepares c for a generation run. It records the output folder
// and package name, selects the database driver named by driverName, opens the
// database connection, loads the tables named in tableName, creates the output
// folder and finally loads the templates.
//
// The first failing step aborts initialization and its error is returned. If
// the failure happens after the connection has been opened, the connection is
// closed again before returning: the post-run hook that normally closes it is
// not expected to run when pre-run initialization fails, so it would otherwise
// be left open.
func (c *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error {
	c.OutFolder = outFolder
	c.PkgName = pkgName

	if err := initInterface(driverName, c); err != nil {
		return err
	}

	// Connect to the driver database
	if err := c.Interface.Open(); err != nil {
		return fmt.Errorf("Unable to connect to the database: %s", err)
	}

	// From here on the connection is open; release it on any early return.
	initialized := false
	defer func() {
		if !initialized {
			c.Interface.Close()
		}
	}()

	if err := initTables(tableName, c); err != nil {
		return fmt.Errorf("Unable to initialize tables: %s", err)
	}

	if err := initOutFolder(c); err != nil {
		return fmt.Errorf("Unable to initialize the output folder: %s", err)
	}

	if err := initTemplates(c); err != nil {
		return fmt.Errorf("Unable to initialize templates: %s", err)
	}

	initialized = true
	return nil
}
// initInterface selects the database driver for cmdData from the driver flag
// value. "postgres" is the only name handled here; it builds a Postgres driver
// from the cmdData.Config.Postgres settings. If cmdData.Interface is still nil
// afterwards an error is returned, otherwise driverName is stored in
// cmdData.DriverName.
func initInterface(driverName string, cmdData *CmdData) error {
	// Create a driver based off driver flag
	if driverName == "postgres" {
		pg := cmdData.Config.Postgres
		cmdData.Interface = dbdrivers.NewPostgresDriver(pg.User, pg.Pass, pg.DBName, pg.Host, pg.Port)
	}

	if cmdData.Interface == nil {
		return errors.New("An invalid driver name was provided")
	}

	cmdData.DriverName = driverName
	return nil
}
// initTables loads table information into cmdData.Tables. tableName is the raw
// table flag value: when non-empty it is split on commas and each entry is
// trimmed of surrounding whitespace before being passed to dbdrivers.Tables.
// When it is empty no names are passed at all; presumably dbdrivers.Tables
// then returns every table in the database — verify against that function.
//
// An error is returned if the lookup fails, if no tables come back, or if any
// returned table lacks a primary key (see checkPKeys).
func initTables(tableName string, cmdData *CmdData) error {
	var names []string
	if tableName != "" {
		for _, part := range strings.Split(tableName, ",") {
			names = append(names, strings.TrimSpace(part))
		}
	}

	var err error
	if cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, names...); err != nil {
		return fmt.Errorf("Unable to get all table names: %s", err)
	}

	if len(cmdData.Tables) == 0 {
		return errors.New("No tables found in database, migrate some tables first")
	}

	return checkPKeys(cmdData.Tables)
}
// checkPKeys ensures every table has a primary key. It collects the name of
// each table whose PKey field is nil and, if there are any, returns a single
// error listing all of them (comma separated). It returns nil when every table
// has a primary key, including when tables is empty.
func checkPKeys(tables []dbdrivers.Table) error {
	var missingPkey []string
	for _, t := range tables {
		if t.PKey == nil {
			missingPkey = append(missingPkey, t.Name)
		}
	}

	if len(missingPkey) == 0 {
		return nil
	}

	return fmt.Errorf("Cannot continue until the following tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
}
// initOutFolder makes sure the directory named by cmdData.OutFolder exists,
// creating it and any missing parent directories with os.ModePerm permissions.
// An already existing directory is not an error.
func initOutFolder(cmdData *CmdData) error {
	err := os.MkdirAll(cmdData.OutFolder, os.ModePerm)
	if err == nil {
		return nil
	}

	return fmt.Errorf("Unable to make output folder: %s", err)
}