Refactored init and fixed TestTemplates
This commit is contained in:
parent
f7a4ed0c54
commit
dd5d643d4f
3 changed files with 76 additions and 48 deletions
|
@ -2,6 +2,7 @@ package cmds
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/format"
|
"go/format"
|
||||||
"io"
|
"io"
|
||||||
|
@ -17,8 +18,7 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
||||||
|
|
||||||
func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
||||||
if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) {
|
if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) {
|
||||||
fmt.Println("No template files located for generation")
|
return errors.New("No template files located for generation")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var out [][]byte
|
var out [][]byte
|
||||||
|
|
|
@ -58,36 +58,32 @@ func (cmdData *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) erro
|
||||||
|
|
||||||
// SQLBoilerPreRun performs the initialization tasks before the root command is run
|
// SQLBoilerPreRun performs the initialization tasks before the root command is run
|
||||||
func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
|
func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
|
||||||
var err error
|
|
||||||
|
|
||||||
// Initialize package name
|
// Initialize package name
|
||||||
cmdData.PkgName = cmd.PersistentFlags().Lookup("pkgname").Value.String()
|
pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String()
|
||||||
|
|
||||||
err = initInterface(cmd, cmdData.Config, cmdData)
|
// Retrieve driver flag
|
||||||
if err != nil {
|
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
|
||||||
return err
|
if driverName == "" {
|
||||||
|
return errors.New("Must supply a driver flag.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to the driver database
|
tableName := cmd.PersistentFlags().Lookup("table").Value.String()
|
||||||
if err = cmdData.Interface.Open(); err != nil {
|
|
||||||
return fmt.Errorf("Unable to connect to the database: %s", err)
|
outFolder := cmd.PersistentFlags().Lookup("folder").Value.String()
|
||||||
|
if outFolder == "" {
|
||||||
|
return fmt.Errorf("No output folder specified.")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = initTables(cmd, cmdData)
|
return cmdData.initCmdData(pkgName, driverName, tableName, outFolder)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Unable to initialize tables: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = initOutFolder(cmd, cmdData)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SQLBoilerRun executes every sqlboiler template and outputs them to files.
|
// SQLBoilerRun is a proxy method for the run function
|
||||||
func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
||||||
|
return cmdData.run(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// run executes the sqlboiler templates and outputs them to files.
|
||||||
|
func (cmdData *CmdData) run(includeTests bool) error {
|
||||||
for _, table := range cmdData.Tables {
|
for _, table := range cmdData.Tables {
|
||||||
data := &tplData{
|
data := &tplData{
|
||||||
Table: table,
|
Table: table,
|
||||||
|
@ -100,32 +96,55 @@ func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the test templates
|
// Generate the test templates
|
||||||
if err := generateOutput(cmdData, data, true); err != nil {
|
if includeTests {
|
||||||
return fmt.Errorf("Unable to generate output: %s", err)
|
if err := generateOutput(cmdData, data, true); err != nil {
|
||||||
|
return fmt.Errorf("Unable to generate output: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cmdData *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error {
|
||||||
|
cmdData.OutFolder = outFolder
|
||||||
|
cmdData.PkgName = pkgName
|
||||||
|
|
||||||
|
err := initInterface(driverName, cmdData)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to the driver database
|
||||||
|
if err = cmdData.Interface.Open(); err != nil {
|
||||||
|
return fmt.Errorf("Unable to connect to the database: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = initTables(tableName, cmdData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to initialize tables: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = initOutFolder(cmdData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// initInterface attempts to set the cmdData Interface based off the passed in
|
// initInterface attempts to set the cmdData Interface based off the passed in
|
||||||
// driver flag value. If an invalid flag string is provided an error is returned.
|
// driver flag value. If an invalid flag string is provided an error is returned.
|
||||||
func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error {
|
func initInterface(driverName string, cmdData *CmdData) error {
|
||||||
// Retrieve driver flag
|
|
||||||
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
|
|
||||||
if driverName == "" {
|
|
||||||
return errors.New("Must supply a driver flag.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a driver based off driver flag
|
// Create a driver based off driver flag
|
||||||
switch driverName {
|
switch driverName {
|
||||||
case "postgres":
|
case "postgres":
|
||||||
cmdData.Interface = dbdrivers.NewPostgresDriver(
|
cmdData.Interface = dbdrivers.NewPostgresDriver(
|
||||||
cfg.Postgres.User,
|
cmdData.Config.Postgres.User,
|
||||||
cfg.Postgres.Pass,
|
cmdData.Config.Postgres.Pass,
|
||||||
cfg.Postgres.DBName,
|
cmdData.Config.Postgres.DBName,
|
||||||
cfg.Postgres.Host,
|
cmdData.Config.Postgres.Host,
|
||||||
cfg.Postgres.Port,
|
cmdData.Config.Postgres.Port,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,12 +159,11 @@ func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error {
|
||||||
// initTables will create a string slice out of the passed in table flag value
|
// initTables will create a string slice out of the passed in table flag value
|
||||||
// if one is provided. If no flag is provided, it will attempt to connect to the
|
// if one is provided. If no flag is provided, it will attempt to connect to the
|
||||||
// database to retrieve all "public" table names, and build a slice out of that result.
|
// database to retrieve all "public" table names, and build a slice out of that result.
|
||||||
func initTables(cmd *cobra.Command, cmdData *CmdData) error {
|
func initTables(tableName string, cmdData *CmdData) error {
|
||||||
var tableNames []string
|
var tableNames []string
|
||||||
tn := cmd.PersistentFlags().Lookup("table").Value.String()
|
|
||||||
|
|
||||||
if len(tn) != 0 {
|
if len(tableName) != 0 {
|
||||||
tableNames = strings.Split(tn, ",")
|
tableNames = strings.Split(tableName, ",")
|
||||||
for i, name := range tableNames {
|
for i, name := range tableNames {
|
||||||
tableNames[i] = strings.TrimSpace(name)
|
tableNames[i] = strings.TrimSpace(name)
|
||||||
}
|
}
|
||||||
|
@ -165,12 +183,7 @@ func initTables(cmd *cobra.Command, cmdData *CmdData) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// initOutFolder creates the folder that will hold the generated output.
|
// initOutFolder creates the folder that will hold the generated output.
|
||||||
func initOutFolder(cmd *cobra.Command, cmdData *CmdData) error {
|
func initOutFolder(cmdData *CmdData) error {
|
||||||
cmdData.OutFolder = cmd.PersistentFlags().Lookup("folder").Value.String()
|
|
||||||
if cmdData.OutFolder == "" {
|
|
||||||
return fmt.Errorf("No output folder specified.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
|
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
|
||||||
return fmt.Errorf("Unable to make output folder: %s", err)
|
return fmt.Errorf("Unable to make output folder: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,13 +53,28 @@ func TestTemplates(t *testing.T) {
|
||||||
t.Fatalf("Unable to initialize templates: %s", err)
|
t.Fatalf("Unable to initialize templates: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(cmdData.Templates) == 0 {
|
||||||
|
t.Errorf("Templates is empty.")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdData.TestTemplates, err = loadTemplates("templates_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to initialize templates: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cmdData.Templates) == 0 {
|
||||||
|
t.Errorf("Templates is empty.")
|
||||||
|
}
|
||||||
|
|
||||||
cmdData.OutFolder, err = ioutil.TempDir("", "templates")
|
cmdData.OutFolder, err = ioutil.TempDir("", "templates")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to create tempdir: %s", err)
|
t.Fatalf("Unable to create tempdir: %s", err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(cmdData.OutFolder)
|
defer os.RemoveAll(cmdData.OutFolder)
|
||||||
|
|
||||||
cmdData.SQLBoilerRun(nil, []string{})
|
if err = cmdData.run(true); err != nil {
|
||||||
|
t.Errorf("Unable to run SQLBoilerRun: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
tplFile := cmdData.OutFolder + "/templates_test.go"
|
tplFile := cmdData.OutFolder + "/templates_test.go"
|
||||||
tplTestHandle, err := os.Create(tplFile)
|
tplTestHandle, err := os.Create(tplFile)
|
||||||
|
|
Loading…
Reference in a new issue