Added imports to output

This commit is contained in:
Patrick O'brien 2016-03-03 02:59:34 +10:00
parent d320bb6944
commit 4cc62fdf5a
4 changed files with 123 additions and 111 deletions

View file

@ -17,12 +17,12 @@ type imports struct {
// Imports that are defined // Imports that are defined
var sqlBoilerDefaultImports = imports{ var sqlBoilerDefaultImports = imports{
standard: []string{ standard: []string{
"errors", `"errors"`,
"fmt", `"fmt"`,
}, },
thirdparty: []string{ thirdparty: []string{
"github.com/pobri19/sqlboiler/boil", `"github.com/pobri19/sqlboiler/boil"`,
"gopkg.in/guregu/null.v3", `"gopkg.in/guregu/null.v3"`,
}, },
} }

71
cmds/imports.go Normal file
View file

@ -0,0 +1,71 @@
package cmds
import (
"bytes"
"fmt"
"sort"
"strings"
)
type ImportSorter []string
func (i ImportSorter) Len() int {
return len(i)
}
func (i ImportSorter) Swap(k, j int) {
i[k], i[j] = i[j], i[k]
}
func (i ImportSorter) Less(k, j int) bool {
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
if res <= 0 {
return true
}
return false
}
func combineImports(a, b imports) imports {
var c imports
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
sort.Sort(ImportSorter(c.standard))
sort.Sort(ImportSorter(c.thirdparty))
return c
}
func buildImportString(imps *imports) []byte {
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
if stdlen+thirdlen < 1 {
return []byte{}
}
if stdlen+thirdlen == 1 {
var imp string
if stdlen == 1 {
imp = imps.standard[0]
} else {
imp = imps.thirdparty[0]
}
return []byte(fmt.Sprintf(`import %s`, imp))
}
buf := &bytes.Buffer{}
buf.WriteString("import (")
for _, std := range imps.standard {
fmt.Fprintf(buf, "\n\t%s", std)
}
if stdlen != 0 && thirdlen != 0 {
buf.WriteString("\n")
}
for _, third := range imps.thirdparty {
fmt.Fprintf(buf, "\n\t%s", third)
}
buf.WriteString("\n)\n")
return buf.Bytes()
}

View file

@ -1,12 +1,9 @@
package cmds package cmds
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"os" "os"
"sort"
"strings"
"github.com/pobri19/sqlboiler/dbdrivers" "github.com/pobri19/sqlboiler/dbdrivers"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -98,49 +95,6 @@ func outHandler(outFolder string, output [][]byte, data *tplData, imps *imports)
return nil return nil
} }
func combineImports(a, b imports) imports {
var c imports
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
c.standard = sortImports(c.standard)
c.thirdparty = sortImports(c.thirdparty)
return c
}
// sortImports sorts the import strings alphabetically.
// If the import begins with an underscore, it temporarily
// strips it so that it does not impact the sort.
func sortImports(data []string) []string {
sorted := make([]string, len(data))
copy(sorted, data)
var underscoreImports []string
for i, v := range sorted {
if string(v[0]) == "_" && len(v) > 1 {
s := strings.Split(v, "_")
underscoreImports = append(underscoreImports, s[1])
sorted[i] = s[1]
}
}
sort.Strings(sorted)
AddUnderscores:
for i, v := range sorted {
for _, underImp := range underscoreImports {
if v == underImp {
sorted[i] = "_" + sorted[i]
continue AddUnderscores
}
}
}
return sorted
}
func combineStringSlices(a, b []string) []string { func combineStringSlices(a, b []string) []string {
c := make([]string, len(a)+len(b)) c := make([]string, len(a)+len(b))
if len(a) > 0 { if len(a) > 0 {
@ -174,35 +128,3 @@ func removeDuplicates(dedup []string) []string {
return dedup return dedup
} }
func buildImportString(imps *imports) []byte {
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
if stdlen+thirdlen < 1 {
return []byte{}
}
if stdlen+thirdlen == 1 {
var imp string
if stdlen == 1 {
imp = imps.standard[0]
} else {
imp = imps.thirdparty[0]
}
return []byte(fmt.Sprintf(`import "%s"`, imp))
}
buf := &bytes.Buffer{}
buf.WriteString("import (")
for _, std := range imps.standard {
fmt.Fprintf(buf, "\n\t\"%s\"", std)
}
if stdlen != 0 && thirdlen != 0 {
buf.WriteString("\n")
}
for _, third := range imps.thirdparty {
fmt.Fprintf(buf, "\n\t\"%s\"", third)
}
buf.WriteString("\n)\n")
return buf.Bytes()
}

View file

@ -71,7 +71,7 @@ func TestOutHandlerFiles(t *testing.T) {
a1 := imports{ a1 := imports{
standard: []string{ standard: []string{
"fmt", `"fmt"`,
}, },
} }
file = &bytes.Buffer{} file = &bytes.Buffer{}
@ -85,7 +85,7 @@ func TestOutHandlerFiles(t *testing.T) {
a2 := imports{ a2 := imports{
thirdparty: []string{ thirdparty: []string{
"github.com/spf13/cobra", `"github.com/spf13/cobra"`,
}, },
} }
file = &bytes.Buffer{} file = &bytes.Buffer{}
@ -98,44 +98,68 @@ func TestOutHandlerFiles(t *testing.T) {
} }
a3 := imports{ a3 := imports{
standard: []string{"fmt", "errors"}, standard: []string{
`"fmt"`,
`"errors"`,
},
thirdparty: []string{ thirdparty: []string{
"_github.com/lib/pq", `_ "github.com/lib/pq"`,
"_github.com/gorilla/n", `_ "github.com/gorilla/n"`,
"github.com/gorilla/mux", `"github.com/gorilla/mux"`,
"github.com/gorilla/websocket", `"github.com/gorilla/websocket"`,
}, },
} }
file = &bytes.Buffer{} file = &bytes.Buffer{}
sort.Sort(ImportSorter(a3.standard))
sort.Sort(ImportSorter(a3.thirdparty))
if err := outHandler("folder", templateOutputs, &data, &a3); err != nil { if err := outHandler("folder", templateOutputs, &data, &a3); err != nil {
t.Error(err) t.Error(err)
} }
if out := file.String(); out != "import \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %s", out) expectedOut := `import (
"errors"
"fmt"
"github.com/gorilla/mux"
_ "github.com/gorilla/n"
"github.com/gorilla/websocket"
_ "github.com/lib/pq"
)
hello world
patrick's dreams
`
if out := file.String(); out != expectedOut {
t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
} }
} }
func TestSortImports(t *testing.T) { func TestSortImports(t *testing.T) {
a1 := []string{"fmt", "errors"} a1 := []string{
`"fmt"`,
`"errors"`,
}
a2 := []string{ a2 := []string{
"_github.com/lib/pq", `_ "github.com/lib/pq"`,
"_github.com/gorilla/n", `_ "github.com/gorilla/n"`,
"github.com/gorilla/mux", `"github.com/gorilla/mux"`,
"github.com/gorilla/websocket", `"github.com/gorilla/websocket"`,
} }
a1Expected := []string{"errors", "fmt"} a1Expected := []string{"errors", "fmt"}
a2Expected := []string{ a2Expected := []string{
"github.com/gorilla/mux", `"github.com/gorilla/mux"`,
"_github.com/gorilla/n", `_ "github.com/gorilla/n"`,
"github.com/gorilla/websocket", `"github.com/gorilla/websocket"`,
"_github.com/lib/pq", `_ "github.com/lib/pq"`,
} }
result := sortImports(a1) sort.Sort(ImportSorter(a1))
if !reflect.DeepEqual(result, a1Expected) { if !reflect.DeepEqual(a1, a1Expected) {
fmt.Errorf("Expected res to match a1expected, got: %v", result) fmt.Errorf("Expected a1 to match a1Expected, got: %v", a1)
} }
for i, v := range a1 { for i, v := range a1 {
@ -144,9 +168,9 @@ func TestSortImports(t *testing.T) {
} }
} }
result = sortImports(a2) sort.Sort(ImportSorter(a2))
if !reflect.DeepEqual(result, a2Expected) { if !reflect.DeepEqual(a2, a2Expected) {
fmt.Errorf("Expected res to match a2expected, got: %v", result) fmt.Errorf("Expected a2 to match a2expected, got: %v", a2)
} }
for i, v := range a2 { for i, v := range a2 {
@ -154,11 +178,6 @@ func TestSortImports(t *testing.T) {
fmt.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i]) fmt.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
} }
} }
sort.Strings(result)
if reflect.DeepEqual(result, a2Expected) {
fmt.Errorf("Expected res not to match a2expected when using sort.Strings.")
}
} }
func TestBuildImportString(t *testing.T) { func TestBuildImportString(t *testing.T) {