From 4cc62fdf5ad1fc53ce2f097870012f99e45e9884 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 3 Mar 2016 02:59:34 +1000 Subject: [PATCH] Added imports to output --- cmds/commands.go | 8 ++--- cmds/imports.go | 71 +++++++++++++++++++++++++++++++++++++++++ cmds/shared.go | 78 --------------------------------------------- cmds/shared_test.go | 77 +++++++++++++++++++++++++++----------------- 4 files changed, 123 insertions(+), 111 deletions(-) create mode 100644 cmds/imports.go diff --git a/cmds/commands.go b/cmds/commands.go index 4916612..6e92039 100644 --- a/cmds/commands.go +++ b/cmds/commands.go @@ -17,12 +17,12 @@ type imports struct { // Imports that are defined var sqlBoilerDefaultImports = imports{ standard: []string{ - "errors", - "fmt", + `"errors"`, + `"fmt"`, }, thirdparty: []string{ - "github.com/pobri19/sqlboiler/boil", - "gopkg.in/guregu/null.v3", + `"github.com/pobri19/sqlboiler/boil"`, + `"gopkg.in/guregu/null.v3"`, }, } diff --git a/cmds/imports.go b/cmds/imports.go new file mode 100644 index 0000000..648033a --- /dev/null +++ b/cmds/imports.go @@ -0,0 +1,71 @@ +package cmds + +import ( + "bytes" + "fmt" + "sort" + "strings" +) + +type ImportSorter []string + +func (i ImportSorter) Len() int { + return len(i) +} + +func (i ImportSorter) Swap(k, j int) { + i[k], i[j] = i[j], i[k] +} + +func (i ImportSorter) Less(k, j int) bool { + res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ ")) + if res <= 0 { + return true + } + + return false +} + +func combineImports(a, b imports) imports { + var c imports + + c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard)) + c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty)) + + sort.Sort(ImportSorter(c.standard)) + sort.Sort(ImportSorter(c.thirdparty)) + + return c +} + +func buildImportString(imps *imports) []byte { + stdlen, thirdlen := len(imps.standard), len(imps.thirdparty) + if stdlen+thirdlen < 1 { + return []byte{} + } + + if stdlen+thirdlen == 1 { + var imp string + if stdlen == 1 { + imp = imps.standard[0] + } else { + imp = imps.thirdparty[0] + } + return []byte(fmt.Sprintf(`import %s`, imp)) + } + + buf := &bytes.Buffer{} + buf.WriteString("import (") + for _, std := range imps.standard { + fmt.Fprintf(buf, "\n\t%s", std) + } + if stdlen != 0 && thirdlen != 0 { + buf.WriteString("\n") + } + for _, third := range imps.thirdparty { + fmt.Fprintf(buf, "\n\t%s", third) + } + buf.WriteString("\n)\n") + + return buf.Bytes() +} diff --git a/cmds/shared.go b/cmds/shared.go index b8319f5..54b00b7 100644 --- a/cmds/shared.go +++ b/cmds/shared.go @@ -1,12 +1,9 @@ package cmds import ( - "bytes" "fmt" "io" "os" - "sort" - "strings" "github.com/pobri19/sqlboiler/dbdrivers" "github.com/spf13/cobra" @@ -98,49 +95,6 @@ func outHandler(outFolder string, output [][]byte, data *tplData, imps *imports) return nil } -func combineImports(a, b imports) imports { - var c imports - - c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard)) - c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty)) - - c.standard = sortImports(c.standard) - c.thirdparty = sortImports(c.thirdparty) - - return c -} - -// sortImports sorts the import strings alphabetically. -// If the import begins with an underscore, it temporarily -// strips it so that it does not impact the sort. -func sortImports(data []string) []string { - sorted := make([]string, len(data)) - copy(sorted, data) - - var underscoreImports []string - for i, v := range sorted { - if string(v[0]) == "_" && len(v) > 1 { - s := strings.Split(v, "_") - underscoreImports = append(underscoreImports, s[1]) - sorted[i] = s[1] - } - } - - sort.Strings(sorted) - -AddUnderscores: - for i, v := range sorted { - for _, underImp := range underscoreImports { - if v == underImp { - sorted[i] = "_" + sorted[i] - continue AddUnderscores - } - } - } - - return sorted -} - func combineStringSlices(a, b []string) []string { c := make([]string, len(a)+len(b)) if len(a) > 0 { @@ -174,35 +128,3 @@ func removeDuplicates(dedup []string) []string { return dedup } - -func buildImportString(imps *imports) []byte { - stdlen, thirdlen := len(imps.standard), len(imps.thirdparty) - if stdlen+thirdlen < 1 { - return []byte{} - } - - if stdlen+thirdlen == 1 { - var imp string - if stdlen == 1 { - imp = imps.standard[0] - } else { - imp = imps.thirdparty[0] - } - return []byte(fmt.Sprintf(`import "%s"`, imp)) - } - - buf := &bytes.Buffer{} - buf.WriteString("import (") - for _, std := range imps.standard { - fmt.Fprintf(buf, "\n\t\"%s\"", std) - } - if stdlen != 0 && thirdlen != 0 { - buf.WriteString("\n") - } - for _, third := range imps.thirdparty { - fmt.Fprintf(buf, "\n\t\"%s\"", third) - } - buf.WriteString("\n)\n") - - return buf.Bytes() -} diff --git a/cmds/shared_test.go b/cmds/shared_test.go index ffca52e..a28bdc0 100644 --- a/cmds/shared_test.go +++ b/cmds/shared_test.go @@ -71,7 +71,7 @@ func TestOutHandlerFiles(t *testing.T) { a1 := imports{ standard: []string{ - "fmt", + `"fmt"`, }, } file = &bytes.Buffer{} @@ -85,7 +85,7 @@ func TestOutHandlerFiles(t *testing.T) { a2 := imports{ thirdparty: []string{ - "github.com/spf13/cobra", + `"github.com/spf13/cobra"`, }, } file = &bytes.Buffer{} @@ -98,44 +98,68 @@ func TestOutHandlerFiles(t *testing.T) { } a3 := imports{ - standard: []string{"fmt", "errors"}, + standard: []string{ + `"fmt"`, + `"errors"`, + }, thirdparty: []string{ - "_github.com/lib/pq", - "_github.com/gorilla/n", - "github.com/gorilla/mux", - "github.com/gorilla/websocket", + `_ "github.com/lib/pq"`, + `_ "github.com/gorilla/n"`, + `"github.com/gorilla/mux"`, + `"github.com/gorilla/websocket"`, }, } file = &bytes.Buffer{} + sort.Sort(ImportSorter(a3.standard)) + sort.Sort(ImportSorter(a3.thirdparty)) + if err := outHandler("folder", templateOutputs, &data, &a3); err != nil { t.Error(err) } - if out := file.String(); out != "import \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" { - t.Errorf("Wrong output: %s", out) + + expectedOut := `import ( + "errors" + "fmt" + + "github.com/gorilla/mux" + _ "github.com/gorilla/n" + "github.com/gorilla/websocket" + _ "github.com/lib/pq" +) + +hello world +patrick's dreams +` + + if out := file.String(); out != expectedOut { + t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut) } } func TestSortImports(t *testing.T) { - a1 := []string{"fmt", "errors"} + a1 := []string{ + `"fmt"`, + `"errors"`, + } a2 := []string{ - "_github.com/lib/pq", - "_github.com/gorilla/n", - "github.com/gorilla/mux", - "github.com/gorilla/websocket", + `_ "github.com/lib/pq"`, + `_ "github.com/gorilla/n"`, + `"github.com/gorilla/mux"`, + `"github.com/gorilla/websocket"`, } a1Expected := []string{"errors", "fmt"} a2Expected := []string{ - "github.com/gorilla/mux", - "_github.com/gorilla/n", - "github.com/gorilla/websocket", - "_github.com/lib/pq", + `"github.com/gorilla/mux"`, + `_ "github.com/gorilla/n"`, + `"github.com/gorilla/websocket"`, + `_ "github.com/lib/pq"`, } - result := sortImports(a1) - if !reflect.DeepEqual(result, a1Expected) { - fmt.Errorf("Expected res to match a1expected, got: %v", result) + sort.Sort(ImportSorter(a1)) + if !reflect.DeepEqual(a1, a1Expected) { + fmt.Errorf("Expected a1 to match a1Expected, got: %v", a1) } for i, v := range a1 { @@ -144,9 +168,9 @@ func TestSortImports(t *testing.T) { } } - result = sortImports(a2) - if !reflect.DeepEqual(result, a2Expected) { - fmt.Errorf("Expected res to match a2expected, got: %v", result) + sort.Sort(ImportSorter(a2)) + if !reflect.DeepEqual(a2, a2Expected) { + fmt.Errorf("Expected a2 to match a2expected, got: %v", a2) } for i, v := range a2 { @@ -154,11 +178,6 @@ func TestSortImports(t *testing.T) { fmt.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i]) } } - - sort.Strings(result) - if reflect.DeepEqual(result, a2Expected) { - fmt.Errorf("Expected res not to match a2expected when using sort.Strings.") - } } func TestBuildImportString(t *testing.T) {