Refactored init and fixed TestTemplates

This commit is contained in:
Patrick O'brien 2016-04-03 16:01:26 +10:00
parent f7a4ed0c54
commit dd5d643d4f
3 changed files with 76 additions and 48 deletions

View file

@ -2,6 +2,7 @@ package cmds
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"go/format" "go/format"
"io" "io"
@ -17,8 +18,7 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error { func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) { if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) {
fmt.Println("No template files located for generation") return errors.New("No template files located for generation")
return nil
} }
var out [][]byte var out [][]byte

View file

@ -58,36 +58,32 @@ func (cmdData *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) erro
// SQLBoilerPreRun performs the initialization tasks before the root command is run // SQLBoilerPreRun performs the initialization tasks before the root command is run
func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error { func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
var err error
// Initialize package name // Initialize package name
cmdData.PkgName = cmd.PersistentFlags().Lookup("pkgname").Value.String() pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String()
err = initInterface(cmd, cmdData.Config, cmdData) // Retrieve driver flag
if err != nil { driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
return err if driverName == "" {
return errors.New("Must supply a driver flag.")
} }
// Connect to the driver database tableName := cmd.PersistentFlags().Lookup("table").Value.String()
if err = cmdData.Interface.Open(); err != nil {
return fmt.Errorf("Unable to connect to the database: %s", err) outFolder := cmd.PersistentFlags().Lookup("folder").Value.String()
if outFolder == "" {
return fmt.Errorf("No output folder specified.")
} }
err = initTables(cmd, cmdData) return cmdData.initCmdData(pkgName, driverName, tableName, outFolder)
if err != nil {
return fmt.Errorf("Unable to initialize tables: %s", err)
}
err = initOutFolder(cmd, cmdData)
if err != nil {
return fmt.Errorf("Unable to initialize the output folder: %s", err)
}
return nil
} }
// SQLBoilerRun executes every sqlboiler template and outputs them to files. // SQLBoilerRun is a proxy method for the run function
func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error { func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
return cmdData.run(true)
}
// run executes the sqlboiler templates and outputs them to files.
func (cmdData *CmdData) run(includeTests bool) error {
for _, table := range cmdData.Tables { for _, table := range cmdData.Tables {
data := &tplData{ data := &tplData{
Table: table, Table: table,
@ -100,32 +96,55 @@ func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
} }
// Generate the test templates // Generate the test templates
if err := generateOutput(cmdData, data, true); err != nil { if includeTests {
return fmt.Errorf("Unable to generate output: %s", err) if err := generateOutput(cmdData, data, true); err != nil {
return fmt.Errorf("Unable to generate output: %s", err)
}
} }
} }
return nil return nil
} }
func (cmdData *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error {
cmdData.OutFolder = outFolder
cmdData.PkgName = pkgName
err := initInterface(driverName, cmdData)
if err != nil {
return err
}
// Connect to the driver database
if err = cmdData.Interface.Open(); err != nil {
return fmt.Errorf("Unable to connect to the database: %s", err)
}
err = initTables(tableName, cmdData)
if err != nil {
return fmt.Errorf("Unable to initialize tables: %s", err)
}
err = initOutFolder(cmdData)
if err != nil {
return fmt.Errorf("Unable to initialize the output folder: %s", err)
}
return nil
}
// initInterface attempts to set the cmdData Interface based off the passed in // initInterface attempts to set the cmdData Interface based off the passed in
// driver flag value. If an invalid flag string is provided an error is returned. // driver flag value. If an invalid flag string is provided an error is returned.
func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error { func initInterface(driverName string, cmdData *CmdData) error {
// Retrieve driver flag
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
if driverName == "" {
return errors.New("Must supply a driver flag.")
}
// Create a driver based off driver flag // Create a driver based off driver flag
switch driverName { switch driverName {
case "postgres": case "postgres":
cmdData.Interface = dbdrivers.NewPostgresDriver( cmdData.Interface = dbdrivers.NewPostgresDriver(
cfg.Postgres.User, cmdData.Config.Postgres.User,
cfg.Postgres.Pass, cmdData.Config.Postgres.Pass,
cfg.Postgres.DBName, cmdData.Config.Postgres.DBName,
cfg.Postgres.Host, cmdData.Config.Postgres.Host,
cfg.Postgres.Port, cmdData.Config.Postgres.Port,
) )
} }
@ -140,12 +159,11 @@ func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error {
// initTables will create a string slice out of the passed in table flag value // initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the // if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that result. // database to retrieve all "public" table names, and build a slice out of that result.
func initTables(cmd *cobra.Command, cmdData *CmdData) error { func initTables(tableName string, cmdData *CmdData) error {
var tableNames []string var tableNames []string
tn := cmd.PersistentFlags().Lookup("table").Value.String()
if len(tn) != 0 { if len(tableName) != 0 {
tableNames = strings.Split(tn, ",") tableNames = strings.Split(tableName, ",")
for i, name := range tableNames { for i, name := range tableNames {
tableNames[i] = strings.TrimSpace(name) tableNames[i] = strings.TrimSpace(name)
} }
@ -165,12 +183,7 @@ func initTables(cmd *cobra.Command, cmdData *CmdData) error {
} }
// initOutFolder creates the folder that will hold the generated output. // initOutFolder creates the folder that will hold the generated output.
func initOutFolder(cmd *cobra.Command, cmdData *CmdData) error { func initOutFolder(cmdData *CmdData) error {
cmdData.OutFolder = cmd.PersistentFlags().Lookup("folder").Value.String()
if cmdData.OutFolder == "" {
return fmt.Errorf("No output folder specified.")
}
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil { if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
return fmt.Errorf("Unable to make output folder: %s", err) return fmt.Errorf("Unable to make output folder: %s", err)
} }

View file

@ -53,13 +53,28 @@ func TestTemplates(t *testing.T) {
t.Fatalf("Unable to initialize templates: %s", err) t.Fatalf("Unable to initialize templates: %s", err)
} }
if len(cmdData.Templates) == 0 {
t.Errorf("Templates is empty.")
}
cmdData.TestTemplates, err = loadTemplates("templates_test")
if err != nil {
t.Fatalf("Unable to initialize templates: %s", err)
}
if len(cmdData.Templates) == 0 {
t.Errorf("Templates is empty.")
}
cmdData.OutFolder, err = ioutil.TempDir("", "templates") cmdData.OutFolder, err = ioutil.TempDir("", "templates")
if err != nil { if err != nil {
t.Fatalf("Unable to create tempdir: %s", err) t.Fatalf("Unable to create tempdir: %s", err)
} }
defer os.RemoveAll(cmdData.OutFolder) defer os.RemoveAll(cmdData.OutFolder)
cmdData.SQLBoilerRun(nil, []string{}) if err = cmdData.run(true); err != nil {
t.Errorf("Unable to run SQLBoilerRun: %s", err)
}
tplFile := cmdData.OutFolder + "/templates_test.go" tplFile := cmdData.OutFolder + "/templates_test.go"
tplTestHandle, err := os.Create(tplFile) tplTestHandle, err := os.Create(tplFile)