diff --git a/cmds/output.go b/cmds/output.go index b8d3658..2a26571 100644 --- a/cmds/output.go +++ b/cmds/output.go @@ -2,6 +2,7 @@ package cmds import ( "bytes" + "errors" "fmt" "go/format" "io" @@ -17,8 +18,7 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) { func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error { if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) { - fmt.Println("No template files located for generation") - return nil + return errors.New("No template files located for generation") } var out [][]byte diff --git a/cmds/sqlboiler.go b/cmds/sqlboiler.go index 23a17ed..7d49fee 100644 --- a/cmds/sqlboiler.go +++ b/cmds/sqlboiler.go @@ -58,36 +58,32 @@ func (cmdData *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) erro // SQLBoilerPreRun performs the initialization tasks before the root command is run func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error { - var err error - // Initialize package name - cmdData.PkgName = cmd.PersistentFlags().Lookup("pkgname").Value.String() + pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String() - err = initInterface(cmd, cmdData.Config, cmdData) - if err != nil { - return err + // Retrieve driver flag + driverName := cmd.PersistentFlags().Lookup("driver").Value.String() + if driverName == "" { + return errors.New("Must supply a driver flag.") } - // Connect to the driver database - if err = cmdData.Interface.Open(); err != nil { - return fmt.Errorf("Unable to connect to the database: %s", err) + tableName := cmd.PersistentFlags().Lookup("table").Value.String() + + outFolder := cmd.PersistentFlags().Lookup("folder").Value.String() + if outFolder == "" { + return fmt.Errorf("No output folder specified.") } - err = initTables(cmd, cmdData) - if err != nil { - return fmt.Errorf("Unable to initialize tables: %s", err) - } - - err = initOutFolder(cmd, cmdData) - if err != nil { - return fmt.Errorf("Unable to initialize the output folder: %s", err) - } - - return nil + return cmdData.initCmdData(pkgName, driverName, tableName, outFolder) } -// SQLBoilerRun executes every sqlboiler template and outputs them to files. +// SQLBoilerRun is a proxy method for the run function func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error { + return cmdData.run(true) +} + +// run executes the sqlboiler templates and outputs them to files. +func (cmdData *CmdData) run(includeTests bool) error { for _, table := range cmdData.Tables { data := &tplData{ Table: table, @@ -100,32 +96,55 @@ func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error { } // Generate the test templates - if err := generateOutput(cmdData, data, true); err != nil { - return fmt.Errorf("Unable to generate output: %s", err) + if includeTests { + if err := generateOutput(cmdData, data, true); err != nil { + return fmt.Errorf("Unable to generate output: %s", err) + } } } return nil } +func (cmdData *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error { + cmdData.OutFolder = outFolder + cmdData.PkgName = pkgName + + err := initInterface(driverName, cmdData) + if err != nil { + return err + } + + // Connect to the driver database + if err = cmdData.Interface.Open(); err != nil { + return fmt.Errorf("Unable to connect to the database: %s", err) + } + + err = initTables(tableName, cmdData) + if err != nil { + return fmt.Errorf("Unable to initialize tables: %s", err) + } + + err = initOutFolder(cmdData) + if err != nil { + return fmt.Errorf("Unable to initialize the output folder: %s", err) + } + + return nil +} + // initInterface attempts to set the cmdData Interface based off the passed in // driver flag value. If an invalid flag string is provided an error is returned. -func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error { - // Retrieve driver flag - driverName := cmd.PersistentFlags().Lookup("driver").Value.String() - if driverName == "" { - return errors.New("Must supply a driver flag.") - } - +func initInterface(driverName string, cmdData *CmdData) error { // Create a driver based off driver flag switch driverName { case "postgres": cmdData.Interface = dbdrivers.NewPostgresDriver( - cfg.Postgres.User, - cfg.Postgres.Pass, - cfg.Postgres.DBName, - cfg.Postgres.Host, - cfg.Postgres.Port, + cmdData.Config.Postgres.User, + cmdData.Config.Postgres.Pass, + cmdData.Config.Postgres.DBName, + cmdData.Config.Postgres.Host, + cmdData.Config.Postgres.Port, ) } @@ -140,12 +159,11 @@ func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error { // initTables will create a string slice out of the passed in table flag value // if one is provided. If no flag is provided, it will attempt to connect to the // database to retrieve all "public" table names, and build a slice out of that result. -func initTables(cmd *cobra.Command, cmdData *CmdData) error { +func initTables(tableName string, cmdData *CmdData) error { var tableNames []string - tn := cmd.PersistentFlags().Lookup("table").Value.String() - if len(tn) != 0 { - tableNames = strings.Split(tn, ",") + if len(tableName) != 0 { + tableNames = strings.Split(tableName, ",") for i, name := range tableNames { tableNames[i] = strings.TrimSpace(name) } @@ -165,12 +183,7 @@ func initTables(cmd *cobra.Command, cmdData *CmdData) error { } // initOutFolder creates the folder that will hold the generated output. -func initOutFolder(cmd *cobra.Command, cmdData *CmdData) error { - cmdData.OutFolder = cmd.PersistentFlags().Lookup("folder").Value.String() - if cmdData.OutFolder == "" { - return fmt.Errorf("No output folder specified.") - } - +func initOutFolder(cmdData *CmdData) error { if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil { return fmt.Errorf("Unable to make output folder: %s", err) } diff --git a/cmds/sqlboiler_test.go b/cmds/sqlboiler_test.go index e26bbe8..3929a52 100644 --- a/cmds/sqlboiler_test.go +++ b/cmds/sqlboiler_test.go @@ -53,13 +53,28 @@ func TestTemplates(t *testing.T) { t.Fatalf("Unable to initialize templates: %s", err) } + if len(cmdData.Templates) == 0 { + t.Errorf("Templates is empty.") + } + + cmdData.TestTemplates, err = loadTemplates("templates_test") + if err != nil { + t.Fatalf("Unable to initialize templates: %s", err) + } + + if len(cmdData.Templates) == 0 { + t.Errorf("Templates is empty.") + } + cmdData.OutFolder, err = ioutil.TempDir("", "templates") if err != nil { t.Fatalf("Unable to create tempdir: %s", err) } defer os.RemoveAll(cmdData.OutFolder) - cmdData.SQLBoilerRun(nil, []string{}) + if err = cmdData.run(true); err != nil { + t.Errorf("Unable to run SQLBoilerRun: %s", err) + } tplFile := cmdData.OutFolder + "/templates_test.go" tplTestHandle, err := os.Create(tplFile)