From f179ea54c8512a77bdb6328f90e96b52ca31819e Mon Sep 17 00:00:00 2001
From: Aaron <aaron@bettercoder.net>
Date: Tue, 1 Mar 2016 21:37:14 -0800
Subject: [PATCH] Add tests to a lot of things.

- Clean up the outputHandler function.
---
 cmds/boil.go           | 53 +++++++++++++++++---------------
 cmds/boil_test.go      | 38 +++++++++++++++++++++++
 cmds/shared.go         | 44 ++++++++++++---------------
 cmds/shared_test.go    | 69 ++++++++++++++++++++++++++++++++++++++++++
 cmds/sqlboiler_test.go | 54 +++++++++++++++++++++++++++++----
 5 files changed, 203 insertions(+), 55 deletions(-)
 create mode 100644 cmds/boil_test.go
 create mode 100644 cmds/shared_test.go

diff --git a/cmds/boil.go b/cmds/boil.go
index b4a33e8..f4bfca9 100644
--- a/cmds/boil.go
+++ b/cmds/boil.go
@@ -13,12 +13,37 @@ var boilCmd = &cobra.Command{
 
 // boilRun executes every sqlboiler command, starting with structs.
 func boilRun(cmd *cobra.Command, args []string) {
+	commandNames := buildCommandList()
+
+	// Prepend "struct" command to templateNames slice so it sits at top of sort
+	commandNames = append([]string{"struct"}, commandNames...)
+
+	for i := 0; i < len(cmdData.Columns); i++ {
+		data := tplData{
+			Table:   cmdData.Tables[i],
+			Columns: cmdData.Columns[i],
+		}
+
+		var out [][]byte
+		// Loop through and generate every command template (excluding skipTemplates)
+		for _, command := range commandNames {
+			out = append(out, generateTemplate(command, &data))
+		}
+
+		err := outHandler(cmdData.OutFolder, out, &data)
+		if err != nil {
+			errorQuit(err)
+		}
+	}
+}
+
+func buildCommandList() []string {
 	// Exclude these commands from the output
 	skipTemplates := []string{
 		"boil",
 	}
 
-	var templateNames []string
+	var commandNames []string
 
 	// Build a list of template names
 	for _, c := range sqlBoilerCommands {
@@ -34,31 +59,11 @@ func boilRun(cmd *cobra.Command, args []string) {
 		}
 
 		if !skip {
-			templateNames = append(templateNames, c.Name())
+			commandNames = append(commandNames, c.Name())
 		}
 	}
 
 	// Sort all names alphabetically
-	sort.Strings(templateNames)
-
-	// Prepend "struct" command to templateNames slice so it sits at top of sort
-	templateNames = append([]string{"struct"}, templateNames...)
-
-	for i := 0; i < len(cmdData.Columns); i++ {
-		data := tplData{
-			Table:   cmdData.Tables[i],
-			Columns: cmdData.Columns[i],
-		}
-
-		var out [][]byte
-		// Loop through and generate every command template (excluding skipTemplates)
-		for _, n := range templateNames {
-			out = append(out, generateTemplate(n, &data))
-		}
-
-		err := outHandler(out, &data)
-		if err != nil {
-			errorQuit(err)
-		}
-	}
+	sort.Strings(commandNames)
+	return commandNames
 }
diff --git a/cmds/boil_test.go b/cmds/boil_test.go
new file mode 100644
index 0000000..0e68b0e
--- /dev/null
+++ b/cmds/boil_test.go
@@ -0,0 +1,38 @@
+package cmds
+
+import "testing"
+
+func TestBuildCommandList(t *testing.T) {
+	list := buildCommandList()
+
+	skips := []string{"struct", "boil"}
+
+	for _, item := range list {
+		for _, skipItem := range skips {
+			if item == skipItem {
+				t.Errorf("Did not expect to find: %s %#v", item, list)
+			}
+		}
+	}
+
+CommandNameLoop:
+	for cmdName := range sqlBoilerCommands {
+		for _, skipItem := range skips {
+			if cmdName == skipItem {
+				continue CommandNameLoop
+			}
+		}
+
+		found := false
+		for _, item := range list {
+			if item == cmdName {
+				found = true
+				break
+			}
+		}
+
+		if !found {
+			t.Error("Expected to find command name:", cmdName)
+		}
+	}
+}
diff --git a/cmds/shared.go b/cmds/shared.go
index 54ccfb8..3870175 100644
--- a/cmds/shared.go
+++ b/cmds/shared.go
@@ -2,6 +2,7 @@ package cmds
 
 import (
 	"fmt"
+	"io"
 	"os"
 
 	"github.com/pobri19/sqlboiler/dbdrivers"
@@ -49,44 +50,37 @@ func defaultRun(cmd *cobra.Command, args []string) {
 		// execution output to a [][]byte before sending it to outHandler.
 		out := [][]byte{generateTemplate(cmd.Name(), &data)}
 
-		err := outHandler(out, &data)
+		err := outHandler(cmdData.OutFolder, out, &data)
 		if err != nil {
 			errorQuit(fmt.Errorf("Unable to generate the template for command %s: %s", cmd.Name(), err))
 		}
 	}
 }
 
+var testHarnessStdout io.Writer = os.Stdout
+var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
+	file, err := os.Create(filename)
+	return file, err
+}
+
 // outHandler loops over the slice of byte slices, outputting them to either
 // the OutFile if it is specified with a flag, or to Stdout if no flag is specified.
-func outHandler(output [][]byte, data *tplData) error {
-	nl := []byte{'\n'}
+func outHandler(outFolder string, output [][]byte, data *tplData) error {
+	out := testHarnessStdout
 
-	if cmdData.OutFolder == "" {
-		for _, v := range output {
-			if _, err := os.Stdout.Write(v); err != nil {
-				return err
-			}
-
-			if _, err := os.Stdout.Write(nl); err != nil {
-				return err
-			}
-		}
-	} else { // If not using stdout, attempt to create the model file.
-		path := cmdData.OutFolder + "/" + data.Table + ".go"
-		out, err := os.Create(path)
+	if len(outFolder) != 0 {
+		path := outFolder + "/" + data.Table + ".go"
+		outFile, err := testHarnessFileOpen(path)
 		if err != nil {
 			errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err))
 		}
+		defer outFile.Close()
+		out = outFile
+	}
 
-		// Combine the slice of slice into a single byte slice.
-		var newOutput []byte
-		for _, v := range output {
-			newOutput = append(newOutput, v...)
-			newOutput = append(newOutput, nl...)
-		}
-
-		if _, err := out.Write(newOutput); err != nil {
-			return err
+	for _, templateOutput := range output {
+		if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
+			errorQuit(fmt.Errorf("Unable to write template output to file handle: %v", err))
 		}
 	}
 
diff --git a/cmds/shared_test.go b/cmds/shared_test.go
new file mode 100644
index 0000000..130cf57
--- /dev/null
+++ b/cmds/shared_test.go
@@ -0,0 +1,69 @@
+package cmds
+
+import (
+	"bytes"
+	"io"
+	"testing"
+)
+
+func TestOutHandler(t *testing.T) {
+	buf := &bytes.Buffer{}
+
+	saveTestHarnessStdout := testHarnessStdout
+	testHarnessStdout = buf
+	defer func() {
+		testHarnessStdout = saveTestHarnessStdout
+	}()
+
+	data := tplData{
+		Table: "patrick",
+	}
+
+	templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
+
+	if err := outHandler("", templateOutputs, &data); err != nil {
+		t.Error(err)
+	}
+
+	if out := buf.String(); out != "hello world\npatrick's dreams\n" {
+		t.Errorf("Wrong output: %q", out)
+	}
+}
+
+type NopWriteCloser struct {
+	io.Writer
+}
+
+func (NopWriteCloser) Close() error {
+	return nil
+}
+
+func nopCloser(w io.Writer) io.WriteCloser {
+	return NopWriteCloser{w}
+}
+
+func TestOutHandlerFiles(t *testing.T) {
+	saveTestHarnessFileOpen := testHarnessFileOpen
+	defer func() {
+		testHarnessFileOpen = saveTestHarnessFileOpen
+	}()
+
+	file := &bytes.Buffer{}
+	testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
+		return nopCloser(file), nil
+	}
+
+	data := tplData{
+		Table: "patrick",
+	}
+
+	templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
+
+	if err := outHandler("folder", templateOutputs, &data); err != nil {
+		t.Error(err)
+	}
+
+	if out := file.String(); out != "hello world\npatrick's dreams\n" {
+		t.Errorf("Wrong output: %q", out)
+	}
+}
diff --git a/cmds/sqlboiler_test.go b/cmds/sqlboiler_test.go
index 111cd17..acb35d2 100644
--- a/cmds/sqlboiler_test.go
+++ b/cmds/sqlboiler_test.go
@@ -1,16 +1,58 @@
 package cmds
 
-import "testing"
+import "github.com/pobri19/sqlboiler/dbdrivers"
+
+func init() {
+	cmdData = &CmdData{
+		Tables: []string{"patrick_table"},
+		Columns: [][]dbdrivers.DBColumn{
+			[]dbdrivers.DBColumn{
+				{Name: "patrick_column", IsNullable: false},
+			},
+		},
+		PkgName:   "patrick",
+		OutFolder: "",
+		DBDriver:  nil,
+	}
+}
+
+/*
+var testHeader = `package main
+
+import (
+)
+`
 
 func TestInitTemplates(t *testing.T) {
-	// TODO(pobr19): Fix this
-	t.Skip("There's some problem with this test")
-	templates, err := initTemplates()
+	templates, err := initTemplates("./templates")
 	if err != nil {
 		t.Errorf("Unable to init templates: %s", err)
 	}
 
-	if len(templates) < 2 {
-		t.Errorf("Expected > 2 templates to be loaded from templates folder, only loaded: %d\n\n%#v", len(templates), templates)
+	testData := tplData{
+		Table: "hello_world",
+		Columns: []dbdrivers.DBColumn{
+			{Name: "hello_there", Type: "int64", IsNullable: true},
+			{Name: "enemy_friend_list", Type: "string", IsNullable: false},
+		},
+	}
+
+	for _, tpl := range templates {
+		file, err := ioutil.TempFile(os.TempDir(), "boilertemplatetest")
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		fmt.Fprintln(testHeader)
+
+		if err = tpl.Execute(tpl, testData); err != nil {
+			t.Error(err)
+		}
+
+		if err = file.Close(); err != nil {
+			t.Error(err)
+		}
 	}
 }
+
+*/