2016-02-23 09:27:32 +01:00
|
|
|
package cmds
|
|
|
|
|
|
|
|
import (
|
|
|
|
"errors"
|
2016-02-29 04:30:54 +01:00
|
|
|
"fmt"
|
2016-02-24 06:40:07 +01:00
|
|
|
"os"
|
2016-02-29 04:30:54 +01:00
|
|
|
"path/filepath"
|
2016-02-23 09:27:32 +01:00
|
|
|
"strings"
|
2016-02-29 04:30:54 +01:00
|
|
|
"text/template"
|
2016-02-23 09:27:32 +01:00
|
|
|
|
|
|
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
|
|
|
"github.com/spf13/cobra"
|
|
|
|
)
|
|
|
|
|
2016-03-02 05:05:25 +01:00
|
|
|
const (
	// templatesDirectory is the folder holding the regular templates. It is
	// joined onto the process working directory by loadTemplates.
	templatesDirectory = "/cmds/templates"
	// templatesTestDirectory is the folder holding the test templates. It is
	// resolved the same way as templatesDirectory.
	templatesTestDirectory = "/cmds/templates_test"
)
|
|
|
|
|
2016-03-27 17:03:14 +02:00
|
|
|
// LoadTemplates loads all template folders into the cmdData object.
|
|
|
|
func (cmdData *CmdData) LoadTemplates() error {
|
|
|
|
var err error
|
|
|
|
cmdData.Templates, err = loadTemplates(templatesDirectory)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2016-02-29 04:30:54 +01:00
|
|
|
|
2016-03-27 17:03:14 +02:00
|
|
|
cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2016-03-21 06:15:14 +01:00
|
|
|
|
2016-03-27 17:03:14 +02:00
|
|
|
return nil
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-27 17:03:14 +02:00
|
|
|
// loadTemplates loads all of the template files in the specified directory.
|
|
|
|
func loadTemplates(dir string) ([]*template.Template, error) {
|
|
|
|
wd, err := os.Getwd()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
pattern := filepath.Join(wd, dir, "*.tpl")
|
|
|
|
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return tpl.Templates(), err
|
2016-02-23 09:27:32 +01:00
|
|
|
}
|
|
|
|
|
2016-03-27 17:03:14 +02:00
|
|
|
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
|
|
|
|
func (cmdData *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error {
|
2016-03-23 05:25:57 +01:00
|
|
|
cmdData.Interface.Close()
|
2016-03-27 17:03:14 +02:00
|
|
|
return nil
|
2016-02-24 06:40:07 +01:00
|
|
|
}
|
|
|
|
|
2016-03-27 17:03:14 +02:00
|
|
|
// SQLBoilerPreRun performs the initialization tasks before the root command is run
|
|
|
|
func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
|
|
|
|
// Initialize package name
|
2016-04-03 08:01:26 +02:00
|
|
|
pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String()
|
2016-03-27 17:03:14 +02:00
|
|
|
|
2016-04-03 08:01:26 +02:00
|
|
|
// Retrieve driver flag
|
|
|
|
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
|
|
|
|
if driverName == "" {
|
|
|
|
return errors.New("Must supply a driver flag.")
|
2016-02-24 09:53:34 +01:00
|
|
|
}
|
|
|
|
|
2016-04-03 08:01:26 +02:00
|
|
|
tableName := cmd.PersistentFlags().Lookup("table").Value.String()
|
2016-02-29 04:30:54 +01:00
|
|
|
|
2016-04-03 08:01:26 +02:00
|
|
|
outFolder := cmd.PersistentFlags().Lookup("folder").Value.String()
|
|
|
|
if outFolder == "" {
|
|
|
|
return fmt.Errorf("No output folder specified.")
|
2016-02-29 04:30:54 +01:00
|
|
|
}
|
2016-03-21 06:15:14 +01:00
|
|
|
|
2016-04-03 08:01:26 +02:00
|
|
|
return cmdData.initCmdData(pkgName, driverName, tableName, outFolder)
|
2016-03-27 17:03:14 +02:00
|
|
|
}
|
|
|
|
|
2016-04-03 08:01:26 +02:00
|
|
|
// SQLBoilerRun is a proxy method for the run function
|
2016-03-27 17:03:14 +02:00
|
|
|
func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
2016-04-03 08:01:26 +02:00
|
|
|
return cmdData.run(true)
|
|
|
|
}
|
|
|
|
|
|
|
|
// run executes the sqlboiler templates and outputs them to files.
|
|
|
|
func (cmdData *CmdData) run(includeTests bool) error {
|
2016-03-27 17:03:14 +02:00
|
|
|
for _, table := range cmdData.Tables {
|
|
|
|
data := &tplData{
|
|
|
|
Table: table,
|
|
|
|
PkgName: cmdData.PkgName,
|
|
|
|
}
|
|
|
|
|
2016-03-28 10:17:41 +02:00
|
|
|
// Generate the regular templates
|
|
|
|
if err := generateOutput(cmdData, data, false); err != nil {
|
|
|
|
return fmt.Errorf("Unable to generate test output: %s", err)
|
2016-03-27 17:03:14 +02:00
|
|
|
}
|
|
|
|
|
2016-03-28 10:17:41 +02:00
|
|
|
// Generate the test templates
|
2016-04-03 08:01:26 +02:00
|
|
|
if includeTests {
|
|
|
|
if err := generateOutput(cmdData, data, true); err != nil {
|
|
|
|
return fmt.Errorf("Unable to generate output: %s", err)
|
|
|
|
}
|
2016-03-23 04:03:35 +01:00
|
|
|
}
|
2016-03-21 06:15:14 +01:00
|
|
|
}
|
2016-03-27 17:03:14 +02:00
|
|
|
|
|
|
|
return nil
|
2016-02-24 09:53:34 +01:00
|
|
|
}
|
|
|
|
|
2016-04-03 08:01:26 +02:00
|
|
|
func (cmdData *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error {
|
|
|
|
cmdData.OutFolder = outFolder
|
|
|
|
cmdData.PkgName = pkgName
|
|
|
|
|
|
|
|
err := initInterface(driverName, cmdData)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Connect to the driver database
|
|
|
|
if err = cmdData.Interface.Open(); err != nil {
|
|
|
|
return fmt.Errorf("Unable to connect to the database: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
err = initTables(tableName, cmdData)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("Unable to initialize tables: %s", err)
|
2016-02-23 09:27:32 +01:00
|
|
|
}
|
|
|
|
|
2016-04-03 08:01:26 +02:00
|
|
|
err = initOutFolder(cmdData)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// initInterface attempts to set the cmdData Interface based off the passed in
|
|
|
|
// driver flag value. If an invalid flag string is provided an error is returned.
|
|
|
|
func initInterface(driverName string, cmdData *CmdData) error {
|
2016-02-23 09:27:32 +01:00
|
|
|
// Create a driver based off driver flag
|
|
|
|
switch driverName {
|
|
|
|
case "postgres":
|
2016-03-23 05:25:57 +01:00
|
|
|
cmdData.Interface = dbdrivers.NewPostgresDriver(
|
2016-04-03 08:01:26 +02:00
|
|
|
cmdData.Config.Postgres.User,
|
|
|
|
cmdData.Config.Postgres.Pass,
|
|
|
|
cmdData.Config.Postgres.DBName,
|
|
|
|
cmdData.Config.Postgres.Host,
|
|
|
|
cmdData.Config.Postgres.Port,
|
2016-02-23 09:27:32 +01:00
|
|
|
)
|
|
|
|
}
|
|
|
|
|
2016-03-23 05:25:57 +01:00
|
|
|
if cmdData.Interface == nil {
|
2016-03-27 17:03:14 +02:00
|
|
|
return errors.New("An invalid driver name was provided")
|
2016-02-23 09:27:32 +01:00
|
|
|
}
|
2016-03-26 07:54:55 +01:00
|
|
|
|
|
|
|
cmdData.DriverName = driverName
|
2016-03-27 17:03:14 +02:00
|
|
|
return nil
|
2016-02-24 09:53:34 +01:00
|
|
|
}
|
2016-02-23 09:27:32 +01:00
|
|
|
|
2016-03-02 05:05:25 +01:00
|
|
|
// initTables will create a string slice out of the passed in table flag value
|
2016-02-24 09:53:34 +01:00
|
|
|
// if one is provided. If no flag is provided, it will attempt to connect to the
|
|
|
|
// database to retrieve all "public" table names, and build a slice out of that result.
|
2016-04-03 08:01:26 +02:00
|
|
|
func initTables(tableName string, cmdData *CmdData) error {
|
2016-03-23 06:05:23 +01:00
|
|
|
var tableNames []string
|
|
|
|
|
2016-04-03 08:01:26 +02:00
|
|
|
if len(tableName) != 0 {
|
|
|
|
tableNames = strings.Split(tableName, ",")
|
2016-03-23 06:05:23 +01:00
|
|
|
for i, name := range tableNames {
|
|
|
|
tableNames[i] = strings.TrimSpace(name)
|
2016-02-23 09:27:32 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-03-23 06:05:23 +01:00
|
|
|
var err error
|
|
|
|
cmdData.Tables, err = cmdData.Interface.Tables(tableNames...)
|
|
|
|
if err != nil {
|
2016-03-27 17:03:14 +02:00
|
|
|
return fmt.Errorf("Unable to get all table names: %s", err)
|
2016-02-23 09:27:32 +01:00
|
|
|
}
|
|
|
|
|
2016-03-23 06:05:23 +01:00
|
|
|
if len(cmdData.Tables) == 0 {
|
2016-03-27 17:03:14 +02:00
|
|
|
return errors.New("No tables found in database, migrate some tables first")
|
2016-02-23 09:27:32 +01:00
|
|
|
}
|
2016-02-24 06:40:07 +01:00
|
|
|
|
2016-03-27 17:03:14 +02:00
|
|
|
return nil
|
2016-03-01 15:20:13 +01:00
|
|
|
}
|
|
|
|
|
2016-03-27 17:03:14 +02:00
|
|
|
// initOutFolder creates the folder that will hold the generated output.
|
2016-04-03 08:01:26 +02:00
|
|
|
func initOutFolder(cmdData *CmdData) error {
|
2016-03-01 15:20:13 +01:00
|
|
|
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
|
2016-03-27 17:03:14 +02:00
|
|
|
return fmt.Errorf("Unable to make output folder: %s", err)
|
2016-02-29 12:45:28 +01:00
|
|
|
}
|
|
|
|
|
2016-03-27 17:03:14 +02:00
|
|
|
return nil
|
2016-02-29 06:49:26 +01:00
|
|
|
}
|