diff --git a/cmds/boil.go b/cmds/boil.go index b4a33e8..f4bfca9 100644 --- a/cmds/boil.go +++ b/cmds/boil.go @@ -13,12 +13,37 @@ var boilCmd = &cobra.Command{ // boilRun executes every sqlboiler command, starting with structs. func boilRun(cmd *cobra.Command, args []string) { + commandNames := buildCommandList() + + // Prepend "struct" command to templateNames slice so it sits at top of sort + commandNames = append([]string{"struct"}, commandNames...) + + for i := 0; i < len(cmdData.Columns); i++ { + data := tplData{ + Table: cmdData.Tables[i], + Columns: cmdData.Columns[i], + } + + var out [][]byte + // Loop through and generate every command template (excluding skipTemplates) + for _, command := range commandNames { + out = append(out, generateTemplate(command, &data)) + } + + err := outHandler(cmdData.OutFolder, out, &data) + if err != nil { + errorQuit(err) + } + } +} + +func buildCommandList() []string { // Exclude these commands from the output skipTemplates := []string{ "boil", } - var templateNames []string + var commandNames []string // Build a list of template names for _, c := range sqlBoilerCommands { @@ -34,31 +59,11 @@ func boilRun(cmd *cobra.Command, args []string) { } if !skip { - templateNames = append(templateNames, c.Name()) + commandNames = append(commandNames, c.Name()) } } // Sort all names alphabetically - sort.Strings(templateNames) - - // Prepend "struct" command to templateNames slice so it sits at top of sort - templateNames = append([]string{"struct"}, templateNames...) - - for i := 0; i < len(cmdData.Columns); i++ { - data := tplData{ - Table: cmdData.Tables[i], - Columns: cmdData.Columns[i], - } - - var out [][]byte - // Loop through and generate every command template (excluding skipTemplates) - for _, n := range templateNames { - out = append(out, generateTemplate(n, &data)) - } - - err := outHandler(out, &data) - if err != nil { - errorQuit(err) - } - } + sort.Strings(commandNames) + return commandNames } diff --git a/cmds/boil_test.go b/cmds/boil_test.go new file mode 100644 index 0000000..0e68b0e --- /dev/null +++ b/cmds/boil_test.go @@ -0,0 +1,38 @@ +package cmds + +import "testing" + +func TestBuildCommandList(t *testing.T) { + list := buildCommandList() + + skips := []string{"struct", "boil"} + + for _, item := range list { + for _, skipItem := range skips { + if item == skipItem { + t.Errorf("Did not expect to find: %s %#v", item, list) + } + } + } + +CommandNameLoop: + for cmdName := range sqlBoilerCommands { + for _, skipItem := range skips { + if cmdName == skipItem { + continue CommandNameLoop + } + } + + found := false + for _, item := range list { + if item == cmdName { + found = true + break + } + } + + if !found { + t.Error("Expected to find command name:", cmdName) + } + } +} diff --git a/cmds/shared.go b/cmds/shared.go index 54ccfb8..3870175 100644 --- a/cmds/shared.go +++ b/cmds/shared.go @@ -2,6 +2,7 @@ package cmds import ( "fmt" + "io" "os" "github.com/pobri19/sqlboiler/dbdrivers" @@ -49,44 +50,37 @@ func defaultRun(cmd *cobra.Command, args []string) { // execution output to a [][]byte before sending it to outHandler. out := [][]byte{generateTemplate(cmd.Name(), &data)} - err := outHandler(out, &data) + err := outHandler(cmdData.OutFolder, out, &data) if err != nil { errorQuit(fmt.Errorf("Unable to generate the template for command %s: %s", cmd.Name(), err)) } } } +var testHarnessStdout io.Writer = os.Stdout +var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) { + file, err := os.Create(filename) + return file, err +} + // outHandler loops over the slice of byte slices, outputting them to either // the OutFile if it is specified with a flag, or to Stdout if no flag is specified. -func outHandler(output [][]byte, data *tplData) error { - nl := []byte{'\n'} +func outHandler(outFolder string, output [][]byte, data *tplData) error { + out := testHarnessStdout - if cmdData.OutFolder == "" { - for _, v := range output { - if _, err := os.Stdout.Write(v); err != nil { - return err - } - - if _, err := os.Stdout.Write(nl); err != nil { - return err - } - } - } else { // If not using stdout, attempt to create the model file. - path := cmdData.OutFolder + "/" + data.Table + ".go" - out, err := os.Create(path) + if len(outFolder) != 0 { + path := outFolder + "/" + data.Table + ".go" + outFile, err := testHarnessFileOpen(path) if err != nil { errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err)) } + defer outFile.Close() + out = outFile + } - // Combine the slice of slice into a single byte slice. - var newOutput []byte - for _, v := range output { - newOutput = append(newOutput, v...) - newOutput = append(newOutput, nl...) - } - - if _, err := out.Write(newOutput); err != nil { - return err + for _, templateOutput := range output { + if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil { + errorQuit(fmt.Errorf("Unable to write template output to file handle: %v", err)) } } diff --git a/cmds/shared_test.go b/cmds/shared_test.go new file mode 100644 index 0000000..130cf57 --- /dev/null +++ b/cmds/shared_test.go @@ -0,0 +1,69 @@ +package cmds + +import ( + "bytes" + "io" + "testing" +) + +func TestOutHandler(t *testing.T) { + buf := &bytes.Buffer{} + + saveTestHarnessStdout := testHarnessStdout + testHarnessStdout = buf + defer func() { + testHarnessStdout = saveTestHarnessStdout + }() + + data := tplData{ + Table: "patrick", + } + + templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")} + + if err := outHandler("", templateOutputs, &data); err != nil { + t.Error(err) + } + + if out := buf.String(); out != "hello world\npatrick's dreams\n" { + t.Errorf("Wrong output: %q", out) + } +} + +type NopWriteCloser struct { + io.Writer +} + +func (NopWriteCloser) Close() error { + return nil +} + +func nopCloser(w io.Writer) io.WriteCloser { + return NopWriteCloser{w} +} + +func TestOutHandlerFiles(t *testing.T) { + saveTestHarnessFileOpen := testHarnessFileOpen + defer func() { + testHarnessFileOpen = saveTestHarnessFileOpen + }() + + file := &bytes.Buffer{} + testHarnessFileOpen = func(path string) (io.WriteCloser, error) { + return nopCloser(file), nil + } + + data := tplData{ + Table: "patrick", + } + + templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")} + + if err := outHandler("folder", templateOutputs, &data); err != nil { + t.Error(err) + } + + if out := file.String(); out != "hello world\npatrick's dreams\n" { + t.Errorf("Wrong output: %q", out) + } +} diff --git a/cmds/sqlboiler_test.go b/cmds/sqlboiler_test.go index 111cd17..acb35d2 100644 --- a/cmds/sqlboiler_test.go +++ b/cmds/sqlboiler_test.go @@ -1,16 +1,58 @@ package cmds -import "testing" +import "github.com/pobri19/sqlboiler/dbdrivers" + +func init() { + cmdData = &CmdData{ + Tables: []string{"patrick_table"}, + Columns: [][]dbdrivers.DBColumn{ + []dbdrivers.DBColumn{ + {Name: "patrick_column", IsNullable: false}, + }, + }, + PkgName: "patrick", + OutFolder: "", + DBDriver: nil, + } +} + +/* +var testHeader = `package main + +import ( +) +` func TestInitTemplates(t *testing.T) { - // TODO(pobr19): Fix this - t.Skip("There's some problem with this test") - templates, err := initTemplates() + templates, err := initTemplates("./templates") if err != nil { t.Errorf("Unable to init templates: %s", err) } - if len(templates) < 2 { - t.Errorf("Expected > 2 templates to be loaded from templates folder, only loaded: %d\n\n%#v", len(templates), templates) + testData := tplData{ + Table: "hello_world", + Columns: []dbdrivers.DBColumn{ + {Name: "hello_there", Type: "int64", IsNullable: true}, + {Name: "enemy_friend_list", Type: "string", IsNullable: false}, + }, + } + + for _, tpl := range templates { + file, err := ioutil.TempFile(os.TempDir(), "boilertemplatetest") + if err != nil { + t.Fatal(err) + } + + fmt.Fprintln(testHeader) + + if err = tpl.Execute(tpl, testData); err != nil { + t.Error(err) + } + + if err = file.Close(); err != nil { + t.Error(err) + } } } + +*/