sqlboiler/cmds/sqlboiler.go

180 lines
4.6 KiB
Go
Raw Normal View History

package cmds
import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"text/template"

	"github.com/pobri19/sqlboiler/dbdrivers"
	"github.com/spf13/cobra"
)
// Template directory locations. loadTemplates joins each of these onto the
// process working directory (os.Getwd) and globs for "*.tpl" inside it.
const (
// templatesDirectory is the directory LoadTemplates reads to populate
// CmdData.Templates.
templatesDirectory = "/cmds/templates"
// templatesTestDirectory is the directory LoadTemplates reads to populate
// CmdData.TestTemplates.
templatesTestDirectory = "/cmds/templates_test"
)
2016-03-27 17:03:14 +02:00
// LoadTemplates loads all template folders into the cmdData object.
func (cmdData *CmdData) LoadTemplates() error {
var err error
cmdData.Templates, err = loadTemplates(templatesDirectory)
if err != nil {
return err
}
2016-03-27 17:03:14 +02:00
cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
if err != nil {
return err
}
2016-03-27 17:03:14 +02:00
return nil
}
2016-03-27 17:03:14 +02:00
// loadTemplates loads all of the template files in the specified directory.
//
// dir is resolved relative to the current working directory, and every file
// matching "*.tpl" inside it is parsed with sqlBoilerTemplateFuncs available.
//
// The returned templates are sorted by name. text/template makes no promise
// about the order of Templates(), so without sorting the slice order (and
// anything derived from iterating it) could differ from run to run.
//
// An error is returned if the working directory cannot be determined or if
// ParseGlob fails (which includes the pattern matching no files).
func loadTemplates(dir string) ([]*template.Template, error) {
	wd, err := os.Getwd()
	if err != nil {
		return nil, err
	}

	pattern := filepath.Join(wd, dir, "*.tpl")
	tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern)
	if err != nil {
		return nil, err
	}

	templates := tpl.Templates()
	sort.Sort(templatesByName(templates))

	return templates, nil
}

// templatesByName implements sort.Interface, ordering templates by Name().
type templatesByName []*template.Template

func (t templatesByName) Len() int           { return len(t) }
func (t templatesByName) Swap(i, j int)      { t[i], t[j] = t[j], t[i] }
func (t templatesByName) Less(i, j int) bool { return t[i].Name() < t[j].Name() }
2016-03-27 17:03:14 +02:00
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
func (cmdData *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error {
2016-03-23 05:25:57 +01:00
cmdData.Interface.Close()
2016-03-27 17:03:14 +02:00
return nil
}
2016-03-27 17:03:14 +02:00
// SQLBoilerPreRun performs the initialization tasks before the root command is run
func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
var err error
2016-03-27 17:03:14 +02:00
// Initialize package name
cmdData.PkgName = cmd.PersistentFlags().Lookup("pkgname").Value.String()
err = initInterface(cmd, cmdData.Config, cmdData)
if err != nil {
return err
}
// Connect to the driver database
2016-03-23 05:25:57 +01:00
if err = cmdData.Interface.Open(); err != nil {
2016-03-27 17:03:14 +02:00
return fmt.Errorf("Unable to connect to the database: %s", err)
}
2016-03-27 17:03:14 +02:00
err = initTables(cmd, cmdData)
if err != nil {
return fmt.Errorf("Unable to initialize tables: %s", err)
}
2016-03-27 17:03:14 +02:00
err = initOutFolder(cmd, cmdData)
if err != nil {
2016-03-27 17:03:14 +02:00
return fmt.Errorf("Unable to initialize the output folder: %s", err)
}
2016-03-27 17:03:14 +02:00
return nil
}
// SQLBoilerRun executes every sqlboiler template and outputs them to files.
func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
for _, table := range cmdData.Tables {
data := &tplData{
Table: table,
PkgName: cmdData.PkgName,
}
// Generate the regular templates
if err := generateOutput(cmdData, data, false); err != nil {
return fmt.Errorf("Unable to generate test output: %s", err)
2016-03-27 17:03:14 +02:00
}
// Generate the test templates
if err := generateOutput(cmdData, data, true); err != nil {
return fmt.Errorf("Unable to generate output: %s", err)
}
}
2016-03-27 17:03:14 +02:00
return nil
}
2016-03-23 05:25:57 +01:00
// initInterface attempts to set the cmdData Interface based off the passed in
2016-03-27 17:03:14 +02:00
// driver flag value. If an invalid flag string is provided an error is returned.
func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error {
// Retrieve driver flag
2016-03-27 17:03:14 +02:00
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
if driverName == "" {
2016-03-27 17:03:14 +02:00
return errors.New("Must supply a driver flag.")
}
// Create a driver based off driver flag
switch driverName {
case "postgres":
2016-03-23 05:25:57 +01:00
cmdData.Interface = dbdrivers.NewPostgresDriver(
cfg.Postgres.User,
cfg.Postgres.Pass,
cfg.Postgres.DBName,
cfg.Postgres.Host,
cfg.Postgres.Port,
)
}
2016-03-23 05:25:57 +01:00
if cmdData.Interface == nil {
2016-03-27 17:03:14 +02:00
return errors.New("An invalid driver name was provided")
}
cmdData.DriverName = driverName
2016-03-27 17:03:14 +02:00
return nil
}
// initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that result.
2016-03-27 17:03:14 +02:00
func initTables(cmd *cobra.Command, cmdData *CmdData) error {
var tableNames []string
2016-03-27 17:03:14 +02:00
tn := cmd.PersistentFlags().Lookup("table").Value.String()
if len(tn) != 0 {
tableNames = strings.Split(tn, ",")
for i, name := range tableNames {
tableNames[i] = strings.TrimSpace(name)
}
}
var err error
cmdData.Tables, err = cmdData.Interface.Tables(tableNames...)
if err != nil {
2016-03-27 17:03:14 +02:00
return fmt.Errorf("Unable to get all table names: %s", err)
}
if len(cmdData.Tables) == 0 {
2016-03-27 17:03:14 +02:00
return errors.New("No tables found in database, migrate some tables first")
}
2016-03-27 17:03:14 +02:00
return nil
}
2016-03-27 17:03:14 +02:00
// initOutFolder creates the folder that will hold the generated output.
func initOutFolder(cmd *cobra.Command, cmdData *CmdData) error {
cmdData.OutFolder = cmd.PersistentFlags().Lookup("folder").Value.String()
if cmdData.OutFolder == "" {
2016-03-27 17:03:14 +02:00
return fmt.Errorf("No output folder specified.")
}
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
2016-03-27 17:03:14 +02:00
return fmt.Errorf("Unable to make output folder: %s", err)
}
2016-03-27 17:03:14 +02:00
return nil
}