From 4cc62fdf5ad1fc53ce2f097870012f99e45e9884 Mon Sep 17 00:00:00 2001
From: Patrick O'brien <pobri19@gmail.com>
Date: Thu, 3 Mar 2016 02:59:34 +1000
Subject: [PATCH] Added imports to output

---
 cmds/commands.go    |  8 ++---
 cmds/imports.go     | 71 +++++++++++++++++++++++++++++++++++++++++
 cmds/shared.go      | 78 ---------------------------------------------
 cmds/shared_test.go | 77 +++++++++++++++++++++++++++-----------------
 4 files changed, 123 insertions(+), 111 deletions(-)
 create mode 100644 cmds/imports.go

diff --git a/cmds/commands.go b/cmds/commands.go
index 4916612..6e92039 100644
--- a/cmds/commands.go
+++ b/cmds/commands.go
@@ -17,12 +17,12 @@ type imports struct {
 // Imports that are defined
 var sqlBoilerDefaultImports = imports{
 	standard: []string{
-		"errors",
-		"fmt",
+		`"errors"`,
+		`"fmt"`,
 	},
 	thirdparty: []string{
-		"github.com/pobri19/sqlboiler/boil",
-		"gopkg.in/guregu/null.v3",
+		`"github.com/pobri19/sqlboiler/boil"`,
+		`"gopkg.in/guregu/null.v3"`,
 	},
 }
 
diff --git a/cmds/imports.go b/cmds/imports.go
new file mode 100644
index 0000000..648033a
--- /dev/null
+++ b/cmds/imports.go
@@ -0,0 +1,71 @@
+package cmds
+
+import (
+	"bytes"
+	"fmt"
+	"sort"
+	"strings"
+)
+
+type ImportSorter []string
+
+func (i ImportSorter) Len() int {
+	return len(i)
+}
+
+func (i ImportSorter) Swap(k, j int) {
+	i[k], i[j] = i[j], i[k]
+}
+
+func (i ImportSorter) Less(k, j int) bool {
+	res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
+	if res <= 0 {
+		return true
+	}
+
+	return false
+}
+
+func combineImports(a, b imports) imports {
+	var c imports
+
+	c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
+	c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
+
+	sort.Sort(ImportSorter(c.standard))
+	sort.Sort(ImportSorter(c.thirdparty))
+
+	return c
+}
+
+func buildImportString(imps *imports) []byte {
+	stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
+	if stdlen+thirdlen < 1 {
+		return []byte{}
+	}
+
+	if stdlen+thirdlen == 1 {
+		var imp string
+		if stdlen == 1 {
+			imp = imps.standard[0]
+		} else {
+			imp = imps.thirdparty[0]
+		}
+		return []byte(fmt.Sprintf(`import %s`, imp))
+	}
+
+	buf := &bytes.Buffer{}
+	buf.WriteString("import (")
+	for _, std := range imps.standard {
+		fmt.Fprintf(buf, "\n\t%s", std)
+	}
+	if stdlen != 0 && thirdlen != 0 {
+		buf.WriteString("\n")
+	}
+	for _, third := range imps.thirdparty {
+		fmt.Fprintf(buf, "\n\t%s", third)
+	}
+	buf.WriteString("\n)\n")
+
+	return buf.Bytes()
+}
diff --git a/cmds/shared.go b/cmds/shared.go
index b8319f5..54b00b7 100644
--- a/cmds/shared.go
+++ b/cmds/shared.go
@@ -1,12 +1,9 @@
 package cmds
 
 import (
-	"bytes"
 	"fmt"
 	"io"
 	"os"
-	"sort"
-	"strings"
 
 	"github.com/pobri19/sqlboiler/dbdrivers"
 	"github.com/spf13/cobra"
@@ -98,49 +95,6 @@ func outHandler(outFolder string, output [][]byte, data *tplData, imps *imports)
 	return nil
 }
 
-func combineImports(a, b imports) imports {
-	var c imports
-
-	c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
-	c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
-
-	c.standard = sortImports(c.standard)
-	c.thirdparty = sortImports(c.thirdparty)
-
-	return c
-}
-
-// sortImports sorts the import strings alphabetically.
-// If the import begins with an underscore, it temporarily
-// strips it so that it does not impact the sort.
-func sortImports(data []string) []string {
-	sorted := make([]string, len(data))
-	copy(sorted, data)
-
-	var underscoreImports []string
-	for i, v := range sorted {
-		if string(v[0]) == "_" && len(v) > 1 {
-			s := strings.Split(v, "_")
-			underscoreImports = append(underscoreImports, s[1])
-			sorted[i] = s[1]
-		}
-	}
-
-	sort.Strings(sorted)
-
-AddUnderscores:
-	for i, v := range sorted {
-		for _, underImp := range underscoreImports {
-			if v == underImp {
-				sorted[i] = "_" + sorted[i]
-				continue AddUnderscores
-			}
-		}
-	}
-
-	return sorted
-}
-
 func combineStringSlices(a, b []string) []string {
 	c := make([]string, len(a)+len(b))
 	if len(a) > 0 {
@@ -174,35 +128,3 @@ func removeDuplicates(dedup []string) []string {
 
 	return dedup
 }
-
-func buildImportString(imps *imports) []byte {
-	stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
-	if stdlen+thirdlen < 1 {
-		return []byte{}
-	}
-
-	if stdlen+thirdlen == 1 {
-		var imp string
-		if stdlen == 1 {
-			imp = imps.standard[0]
-		} else {
-			imp = imps.thirdparty[0]
-		}
-		return []byte(fmt.Sprintf(`import "%s"`, imp))
-	}
-
-	buf := &bytes.Buffer{}
-	buf.WriteString("import (")
-	for _, std := range imps.standard {
-		fmt.Fprintf(buf, "\n\t\"%s\"", std)
-	}
-	if stdlen != 0 && thirdlen != 0 {
-		buf.WriteString("\n")
-	}
-	for _, third := range imps.thirdparty {
-		fmt.Fprintf(buf, "\n\t\"%s\"", third)
-	}
-	buf.WriteString("\n)\n")
-
-	return buf.Bytes()
-}
diff --git a/cmds/shared_test.go b/cmds/shared_test.go
index ffca52e..a28bdc0 100644
--- a/cmds/shared_test.go
+++ b/cmds/shared_test.go
@@ -71,7 +71,7 @@ func TestOutHandlerFiles(t *testing.T) {
 
 	a1 := imports{
 		standard: []string{
-			"fmt",
+			`"fmt"`,
 		},
 	}
 	file = &bytes.Buffer{}
@@ -85,7 +85,7 @@ func TestOutHandlerFiles(t *testing.T) {
 
 	a2 := imports{
 		thirdparty: []string{
-			"github.com/spf13/cobra",
+			`"github.com/spf13/cobra"`,
 		},
 	}
 	file = &bytes.Buffer{}
@@ -98,44 +98,68 @@ func TestOutHandlerFiles(t *testing.T) {
 	}
 
 	a3 := imports{
-		standard: []string{"fmt", "errors"},
+		standard: []string{
+			`"fmt"`,
+			`"errors"`,
+		},
 		thirdparty: []string{
-			"_github.com/lib/pq",
-			"_github.com/gorilla/n",
-			"github.com/gorilla/mux",
-			"github.com/gorilla/websocket",
+			`_ "github.com/lib/pq"`,
+			`_ "github.com/gorilla/n"`,
+			`"github.com/gorilla/mux"`,
+			`"github.com/gorilla/websocket"`,
 		},
 	}
 	file = &bytes.Buffer{}
 
+	sort.Sort(ImportSorter(a3.standard))
+	sort.Sort(ImportSorter(a3.thirdparty))
+
 	if err := outHandler("folder", templateOutputs, &data, &a3); err != nil {
 		t.Error(err)
 	}
-	if out := file.String(); out != "import \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
-		t.Errorf("Wrong output: %s", out)
+
+	expectedOut := `import (
+	"errors"
+	"fmt"
+
+	"github.com/gorilla/mux"
+	_ "github.com/gorilla/n"
+	"github.com/gorilla/websocket"
+	_ "github.com/lib/pq"
+)
+
+hello world
+patrick's dreams
+`
+
+	if out := file.String(); out != expectedOut {
+		t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
 	}
 }
 
 func TestSortImports(t *testing.T) {
-	a1 := []string{"fmt", "errors"}
+	a1 := []string{
+		`"fmt"`,
+		`"errors"`,
+	}
 	a2 := []string{
-		"_github.com/lib/pq",
-		"_github.com/gorilla/n",
-		"github.com/gorilla/mux",
-		"github.com/gorilla/websocket",
+		`_ "github.com/lib/pq"`,
+		`_ "github.com/gorilla/n"`,
+		`"github.com/gorilla/mux"`,
+		`"github.com/gorilla/websocket"`,
 	}
 
 	a1Expected := []string{"errors", "fmt"}
 	a2Expected := []string{
-		"github.com/gorilla/mux",
-		"_github.com/gorilla/n",
-		"github.com/gorilla/websocket",
-		"_github.com/lib/pq",
+		`"github.com/gorilla/mux"`,
+		`_ "github.com/gorilla/n"`,
+		`"github.com/gorilla/websocket"`,
+		`_ "github.com/lib/pq"`,
 	}
 
-	result := sortImports(a1)
-	if !reflect.DeepEqual(result, a1Expected) {
-		fmt.Errorf("Expected res to match a1expected, got: %v", result)
+	sort.Sort(ImportSorter(a1))
+	if !reflect.DeepEqual(a1, a1Expected) {
+		fmt.Errorf("Expected a1 to match a1Expected, got: %v", a1)
 	}
 
 	for i, v := range a1 {
@@ -144,9 +168,9 @@ func TestSortImports(t *testing.T) {
 		}
 	}
 
-	result = sortImports(a2)
-	if !reflect.DeepEqual(result, a2Expected) {
-		fmt.Errorf("Expected res to match a2expected, got: %v", result)
+	sort.Sort(ImportSorter(a2))
+	if !reflect.DeepEqual(a2, a2Expected) {
+		fmt.Errorf("Expected a2 to match a2expected, got: %v", a2)
 	}
 
 	for i, v := range a2 {
@@ -154,11 +178,6 @@ func TestSortImports(t *testing.T) {
 			fmt.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
 		}
 	}
-
-	sort.Strings(result)
-	if reflect.DeepEqual(result, a2Expected) {
-		fmt.Errorf("Expected res not to match a2expected when using sort.Strings.")
-	}
 }
 
 func TestBuildImportString(t *testing.T) {