From e08eacad056d04f70e84181da7df61a823ffe0b7 Mon Sep 17 00:00:00 2001
From: Aaron L <aaron@bettercoder.net>
Date: Sat, 16 Jul 2016 23:55:15 -0700
Subject: [PATCH] Stop dividing templates up, and execute by name

- This allows us to use templates defined from anywhere in other
  templates.
---
 output.go         | 22 +++++++++----------
 sqlboiler.go      |  8 +++----
 sqlboiler_test.go |  8 +++----
 templates.go      | 43 +++++++++++++++++++++++++++----------
 templates_test.go | 54 ++++++++++++++++++++++++++++++++++++++---------
 5 files changed, 95 insertions(+), 40 deletions(-)

diff --git a/output.go b/output.go
index 8bd444a..66bb932 100644
--- a/output.go
+++ b/output.go
@@ -73,7 +73,7 @@ type executeTemplateData struct {
 	state *State
 	data  *templateData
 
-	templates templateList
+	templates *templateList
 
 	importSet      imports
 	importNamedSet map[string]imports
@@ -94,14 +94,14 @@ func executeTemplates(e executeTemplateData) error {
 	imps.standard = e.importSet.standard
 	imps.thirdParty = e.importSet.thirdParty
 
-	for _, template := range e.templates {
+	for _, tplName := range e.templates.Templates() {
 		if e.combineImportsOnType {
 			imps = combineTypeImports(imps, importsBasedOnType, e.data.Table.Columns)
 		}
 
-		resp, err := executeTemplate(template, e.data)
+		resp, err := executeTemplate(e.templates.Template, tplName, e.data)
 		if err != nil {
-			return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
+			return fmt.Errorf("Error generating template %s: %s", tplName, err)
 		}
 		out = append(out, resp)
 	}
@@ -122,13 +122,13 @@ func executeSingletonTemplates(e executeTemplateData) error {
 
 	rgxRemove := regexp.MustCompile(`[0-9]+_`)
 
-	for _, template := range e.templates {
-		resp, err := executeTemplate(template, e.data)
+	for _, tplName := range e.templates.Templates() {
+		resp, err := executeTemplate(e.templates.Template, tplName, e.data)
 		if err != nil {
-			return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
+			return fmt.Errorf("Error generating template %s: %s", tplName, err)
 		}
 
-		fName := template.Name()
+		fName := tplName
 		ext := filepath.Ext(fName)
 		fName = rgxRemove.ReplaceAllString(fName[:len(fName)-len(ext)], "")
 
@@ -163,7 +163,7 @@ func generateTestMainOutput(state *State, data *templateData) error {
 	imps.standard = defaultTestMainImports[state.Config.DriverName].standard
 	imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty
 
-	resp, err := executeTemplate(state.TestMainTemplate, data)
+	resp, err := executeTemplate(state.TestMainTemplate, state.TestMainTemplate.Name(), data)
 	if err != nil {
 		return err
 	}
@@ -213,9 +213,9 @@ var rgxSyntaxError = regexp.MustCompile(`(\d+):\d+: `)
 
 // executeTemplate takes a template and returns the output of the template
 // execution.
-func executeTemplate(t *template.Template, data *templateData) ([]byte, error) {
+func executeTemplate(t *template.Template, name string, data *templateData) ([]byte, error) {
 	var buf bytes.Buffer
-	if err := t.Execute(&buf, data); err != nil {
+	if err := t.ExecuteTemplate(&buf, name, data); err != nil {
 		return nil, errors.Wrap(err, "failed to execute template")
 	}
 
diff --git a/sqlboiler.go b/sqlboiler.go
index f3eab79..5f1d916 100644
--- a/sqlboiler.go
+++ b/sqlboiler.go
@@ -30,10 +30,10 @@ type State struct {
 	Driver bdb.Interface
 	Tables []bdb.Table
 
-	Templates              templateList
-	TestTemplates          templateList
-	SingletonTemplates     templateList
-	SingletonTestTemplates templateList
+	Templates              *templateList
+	TestTemplates          *templateList
+	SingletonTemplates     *templateList
+	SingletonTestTemplates *templateList
 
 	TestMainTemplate *template.Template
 }
diff --git a/sqlboiler_test.go b/sqlboiler_test.go
index 7b7ba8f..eed674a 100644
--- a/sqlboiler_test.go
+++ b/sqlboiler_test.go
@@ -96,7 +96,7 @@ func TestTemplates(t *testing.T) {
 		t.Fatalf("Unable to initialize templates: %s", err)
 	}
 
-	if len(state.Templates) == 0 {
+	if len(state.Templates.Templates()) == 0 {
 		t.Errorf("Templates is empty.")
 	}
 
@@ -105,7 +105,7 @@ func TestTemplates(t *testing.T) {
 		t.Fatalf("Unable to initialize singleton templates: %s", err)
 	}
 
-	if len(state.SingletonTemplates) == 0 {
+	if len(state.SingletonTemplates.Templates()) == 0 {
 		t.Errorf("SingletonTemplates is empty.")
 	}
 
@@ -114,7 +114,7 @@ func TestTemplates(t *testing.T) {
 		t.Fatalf("Unable to initialize templates: %s", err)
 	}
 
-	if len(state.Templates) == 0 {
+	if len(state.Templates.Templates()) == 0 {
 		t.Errorf("Templates is empty.")
 	}
 
@@ -128,7 +128,7 @@ func TestTemplates(t *testing.T) {
 		t.Fatalf("Unable to initialize single test templates: %s", err)
 	}
 
-	if len(state.SingletonTestTemplates) == 0 {
+	if len(state.SingletonTestTemplates.Templates()) == 0 {
 		t.Errorf("SingleTestTemplates is empty.")
 	}
 
diff --git a/templates.go b/templates.go
index 2ee3fd0..5c2ba5e 100644
--- a/templates.go
+++ b/templates.go
@@ -22,23 +22,27 @@ type templateData struct {
 	StringFuncs map[string]func(string) string
 }
 
-type templateList []*template.Template
+type templateList struct {
+	*template.Template
+}
 
-func (t templateList) Len() int {
+type templateNameList []string
+
+func (t templateNameList) Len() int {
 	return len(t)
 }
 
-func (t templateList) Swap(k, j int) {
+func (t templateNameList) Swap(k, j int) {
 	t[k], t[j] = t[j], t[k]
 }
 
-func (t templateList) Less(k, j int) bool {
+func (t templateNameList) Less(k, j int) bool {
 	// Make sure "struct" goes to the front
-	if t[k].Name() == "struct.tpl" {
+	if t[k] == "struct.tpl" {
 		return true
 	}
 
-	res := strings.Compare(t[k].Name(), t[j].Name())
+	res := strings.Compare(t[k], t[j])
 	if res <= 0 {
 		return true
 	}
@@ -46,8 +50,28 @@ func (t templateList) Less(k, j int) bool {
 	return false
 }
 
+// Templates returns the name of all the templates defined in the template list
+func (t templateList) Templates() []string {
+	tplList := t.Template.Templates()
+
+	if len(tplList) == 0 {
+		return nil
+	}
+
+	ret := make([]string, 0, len(tplList))
+	for _, tpl := range tplList {
+		if name := tpl.Name(); strings.HasSuffix(name, ".tpl") {
+			ret = append(ret, name)
+		}
+	}
+
+	sort.Sort(templateNameList(ret))
+
+	return ret
+}
+
 // loadTemplates loads all of the template files in the specified directory.
-func loadTemplates(dir string) (templateList, error) {
+func loadTemplates(dir string) (*templateList, error) {
 	wd, err := os.Getwd()
 	if err != nil {
 		return nil, err
@@ -60,10 +84,7 @@ func loadTemplates(dir string) (templateList, error) {
 		return nil, err
 	}
 
-	templates := templateList(tpl.Templates())
-	sort.Sort(templates)
-
-	return templates, err
+	return &templateList{Template: tpl}, err
 }
 
 // loadTemplate loads a single template file.
diff --git a/templates_test.go b/templates_test.go
index 760b80a..df45e1e 100644
--- a/templates_test.go
+++ b/templates_test.go
@@ -6,19 +6,21 @@ import (
 	"text/template"
 )
 
-func TestTemplateListSort(t *testing.T) {
-	templs := templateList{
-		template.New("bob.tpl"),
-		template.New("all.tpl"),
-		template.New("struct.tpl"),
-		template.New("ttt.tpl"),
+func TestTemplateNameListSort(t *testing.T) {
+	t.Parallel()
+
+	templs := templateNameList{
+		"bob.tpl",
+		"all.tpl",
+		"struct.tpl",
+		"ttt.tpl",
 	}
 
 	expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
 
 	for i, v := range templs {
-		if v.Name() != expected[i] {
-			t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
+		if v != expected[i] {
+			t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v)
 		}
 	}
 
@@ -27,8 +29,40 @@ func TestTemplateListSort(t *testing.T) {
 	sort.Sort(templs)
 
 	for i, v := range templs {
-		if v.Name() != expected[i] {
-			t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
+		if v != expected[i] {
+			t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v)
 		}
 	}
 }
+
+func TestTemplateList_Templates(t *testing.T) {
+	t.Parallel()
+
+	tpl := template.New("")
+	tpl.New("wat.tpl").Parse("hello")
+	tpl.New("que.tpl").Parse("there")
+	tpl.New("not").Parse("hello")
+
+	tplList := templateList{tpl}
+	foundWat, foundQue, foundNot := false, false, false
+	for _, n := range tplList.Templates() {
+		switch n {
+		case "wat.tpl":
+			foundWat = true
+		case "que.tpl":
+			foundQue = true
+		case "not":
+			foundNot = true
+		}
+	}
+
+	if !foundWat {
+		t.Error("want wat")
+	}
+	if !foundQue {
+		t.Error("want que")
+	}
+	if foundNot {
+		t.Error("don't want not")
+	}
+}