Refactored init and fixed TestTemplates
This commit is contained in:
parent
f7a4ed0c54
commit
dd5d643d4f
3 changed files with 76 additions and 48 deletions
|
@ -2,6 +2,7 @@ package cmds
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io"
|
||||
|
@ -17,8 +18,7 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
|||
|
||||
func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
||||
if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) {
|
||||
fmt.Println("No template files located for generation")
|
||||
return nil
|
||||
return errors.New("No template files located for generation")
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
|
|
|
@ -58,36 +58,32 @@ func (cmdData *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) erro
|
|||
|
||||
// SQLBoilerPreRun performs the initialization tasks before the root command is run
|
||||
func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
|
||||
var err error
|
||||
|
||||
// Initialize package name
|
||||
cmdData.PkgName = cmd.PersistentFlags().Lookup("pkgname").Value.String()
|
||||
pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String()
|
||||
|
||||
err = initInterface(cmd, cmdData.Config, cmdData)
|
||||
if err != nil {
|
||||
return err
|
||||
// Retrieve driver flag
|
||||
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
|
||||
if driverName == "" {
|
||||
return errors.New("Must supply a driver flag.")
|
||||
}
|
||||
|
||||
// Connect to the driver database
|
||||
if err = cmdData.Interface.Open(); err != nil {
|
||||
return fmt.Errorf("Unable to connect to the database: %s", err)
|
||||
tableName := cmd.PersistentFlags().Lookup("table").Value.String()
|
||||
|
||||
outFolder := cmd.PersistentFlags().Lookup("folder").Value.String()
|
||||
if outFolder == "" {
|
||||
return fmt.Errorf("No output folder specified.")
|
||||
}
|
||||
|
||||
err = initTables(cmd, cmdData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to initialize tables: %s", err)
|
||||
}
|
||||
|
||||
err = initOutFolder(cmd, cmdData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return cmdData.initCmdData(pkgName, driverName, tableName, outFolder)
|
||||
}
|
||||
|
||||
// SQLBoilerRun executes every sqlboiler template and outputs them to files.
|
||||
// SQLBoilerRun is a proxy method for the run function
|
||||
func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
||||
return cmdData.run(true)
|
||||
}
|
||||
|
||||
// run executes the sqlboiler templates and outputs them to files.
|
||||
func (cmdData *CmdData) run(includeTests bool) error {
|
||||
for _, table := range cmdData.Tables {
|
||||
data := &tplData{
|
||||
Table: table,
|
||||
|
@ -100,32 +96,55 @@ func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
|
||||
// Generate the test templates
|
||||
if includeTests {
|
||||
if err := generateOutput(cmdData, data, true); err != nil {
|
||||
return fmt.Errorf("Unable to generate output: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmdData *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error {
|
||||
cmdData.OutFolder = outFolder
|
||||
cmdData.PkgName = pkgName
|
||||
|
||||
err := initInterface(driverName, cmdData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Connect to the driver database
|
||||
if err = cmdData.Interface.Open(); err != nil {
|
||||
return fmt.Errorf("Unable to connect to the database: %s", err)
|
||||
}
|
||||
|
||||
err = initTables(tableName, cmdData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to initialize tables: %s", err)
|
||||
}
|
||||
|
||||
err = initOutFolder(cmdData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initInterface attempts to set the cmdData Interface based off the passed in
|
||||
// driver flag value. If an invalid flag string is provided an error is returned.
|
||||
func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error {
|
||||
// Retrieve driver flag
|
||||
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
|
||||
if driverName == "" {
|
||||
return errors.New("Must supply a driver flag.")
|
||||
}
|
||||
|
||||
func initInterface(driverName string, cmdData *CmdData) error {
|
||||
// Create a driver based off driver flag
|
||||
switch driverName {
|
||||
case "postgres":
|
||||
cmdData.Interface = dbdrivers.NewPostgresDriver(
|
||||
cfg.Postgres.User,
|
||||
cfg.Postgres.Pass,
|
||||
cfg.Postgres.DBName,
|
||||
cfg.Postgres.Host,
|
||||
cfg.Postgres.Port,
|
||||
cmdData.Config.Postgres.User,
|
||||
cmdData.Config.Postgres.Pass,
|
||||
cmdData.Config.Postgres.DBName,
|
||||
cmdData.Config.Postgres.Host,
|
||||
cmdData.Config.Postgres.Port,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -140,12 +159,11 @@ func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error {
|
|||
// initTables will create a string slice out of the passed in table flag value
|
||||
// if one is provided. If no flag is provided, it will attempt to connect to the
|
||||
// database to retrieve all "public" table names, and build a slice out of that result.
|
||||
func initTables(cmd *cobra.Command, cmdData *CmdData) error {
|
||||
func initTables(tableName string, cmdData *CmdData) error {
|
||||
var tableNames []string
|
||||
tn := cmd.PersistentFlags().Lookup("table").Value.String()
|
||||
|
||||
if len(tn) != 0 {
|
||||
tableNames = strings.Split(tn, ",")
|
||||
if len(tableName) != 0 {
|
||||
tableNames = strings.Split(tableName, ",")
|
||||
for i, name := range tableNames {
|
||||
tableNames[i] = strings.TrimSpace(name)
|
||||
}
|
||||
|
@ -165,12 +183,7 @@ func initTables(cmd *cobra.Command, cmdData *CmdData) error {
|
|||
}
|
||||
|
||||
// initOutFolder creates the folder that will hold the generated output.
|
||||
func initOutFolder(cmd *cobra.Command, cmdData *CmdData) error {
|
||||
cmdData.OutFolder = cmd.PersistentFlags().Lookup("folder").Value.String()
|
||||
if cmdData.OutFolder == "" {
|
||||
return fmt.Errorf("No output folder specified.")
|
||||
}
|
||||
|
||||
func initOutFolder(cmdData *CmdData) error {
|
||||
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("Unable to make output folder: %s", err)
|
||||
}
|
||||
|
|
|
@ -53,13 +53,28 @@ func TestTemplates(t *testing.T) {
|
|||
t.Fatalf("Unable to initialize templates: %s", err)
|
||||
}
|
||||
|
||||
if len(cmdData.Templates) == 0 {
|
||||
t.Errorf("Templates is empty.")
|
||||
}
|
||||
|
||||
cmdData.TestTemplates, err = loadTemplates("templates_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to initialize templates: %s", err)
|
||||
}
|
||||
|
||||
if len(cmdData.Templates) == 0 {
|
||||
t.Errorf("Templates is empty.")
|
||||
}
|
||||
|
||||
cmdData.OutFolder, err = ioutil.TempDir("", "templates")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create tempdir: %s", err)
|
||||
}
|
||||
defer os.RemoveAll(cmdData.OutFolder)
|
||||
|
||||
cmdData.SQLBoilerRun(nil, []string{})
|
||||
if err = cmdData.run(true); err != nil {
|
||||
t.Errorf("Unable to run SQLBoilerRun: %s", err)
|
||||
}
|
||||
|
||||
tplFile := cmdData.OutFolder + "/templates_test.go"
|
||||
tplTestHandle, err := os.Create(tplFile)
|
||||
|
|
Loading…
Reference in a new issue