Added imports to output
This commit is contained in:
parent
d320bb6944
commit
4cc62fdf5a
4 changed files with 123 additions and 111 deletions
|
@ -17,12 +17,12 @@ type imports struct {
|
|||
// Imports that are defined
|
||||
var sqlBoilerDefaultImports = imports{
|
||||
standard: []string{
|
||||
"errors",
|
||||
"fmt",
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
},
|
||||
thirdparty: []string{
|
||||
"github.com/pobri19/sqlboiler/boil",
|
||||
"gopkg.in/guregu/null.v3",
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
`"gopkg.in/guregu/null.v3"`,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
71
cmds/imports.go
Normal file
71
cmds/imports.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ImportSorter []string
|
||||
|
||||
func (i ImportSorter) Len() int {
|
||||
return len(i)
|
||||
}
|
||||
|
||||
func (i ImportSorter) Swap(k, j int) {
|
||||
i[k], i[j] = i[j], i[k]
|
||||
}
|
||||
|
||||
func (i ImportSorter) Less(k, j int) bool {
|
||||
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
||||
if res <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func combineImports(a, b imports) imports {
|
||||
var c imports
|
||||
|
||||
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
|
||||
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
|
||||
|
||||
sort.Sort(ImportSorter(c.standard))
|
||||
sort.Sort(ImportSorter(c.thirdparty))
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func buildImportString(imps *imports) []byte {
|
||||
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
|
||||
if stdlen+thirdlen < 1 {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
if stdlen+thirdlen == 1 {
|
||||
var imp string
|
||||
if stdlen == 1 {
|
||||
imp = imps.standard[0]
|
||||
} else {
|
||||
imp = imps.thirdparty[0]
|
||||
}
|
||||
return []byte(fmt.Sprintf(`import %s`, imp))
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("import (")
|
||||
for _, std := range imps.standard {
|
||||
fmt.Fprintf(buf, "\n\t%s", std)
|
||||
}
|
||||
if stdlen != 0 && thirdlen != 0 {
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
for _, third := range imps.thirdparty {
|
||||
fmt.Fprintf(buf, "\n\t%s", third)
|
||||
}
|
||||
buf.WriteString("\n)\n")
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
|
@ -1,12 +1,9 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -98,49 +95,6 @@ func outHandler(outFolder string, output [][]byte, data *tplData, imps *imports)
|
|||
return nil
|
||||
}
|
||||
|
||||
func combineImports(a, b imports) imports {
|
||||
var c imports
|
||||
|
||||
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
|
||||
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
|
||||
|
||||
c.standard = sortImports(c.standard)
|
||||
c.thirdparty = sortImports(c.thirdparty)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// sortImports sorts the import strings alphabetically.
|
||||
// If the import begins with an underscore, it temporarily
|
||||
// strips it so that it does not impact the sort.
|
||||
func sortImports(data []string) []string {
|
||||
sorted := make([]string, len(data))
|
||||
copy(sorted, data)
|
||||
|
||||
var underscoreImports []string
|
||||
for i, v := range sorted {
|
||||
if string(v[0]) == "_" && len(v) > 1 {
|
||||
s := strings.Split(v, "_")
|
||||
underscoreImports = append(underscoreImports, s[1])
|
||||
sorted[i] = s[1]
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(sorted)
|
||||
|
||||
AddUnderscores:
|
||||
for i, v := range sorted {
|
||||
for _, underImp := range underscoreImports {
|
||||
if v == underImp {
|
||||
sorted[i] = "_" + sorted[i]
|
||||
continue AddUnderscores
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
func combineStringSlices(a, b []string) []string {
|
||||
c := make([]string, len(a)+len(b))
|
||||
if len(a) > 0 {
|
||||
|
@ -174,35 +128,3 @@ func removeDuplicates(dedup []string) []string {
|
|||
|
||||
return dedup
|
||||
}
|
||||
|
||||
func buildImportString(imps *imports) []byte {
|
||||
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
|
||||
if stdlen+thirdlen < 1 {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
if stdlen+thirdlen == 1 {
|
||||
var imp string
|
||||
if stdlen == 1 {
|
||||
imp = imps.standard[0]
|
||||
} else {
|
||||
imp = imps.thirdparty[0]
|
||||
}
|
||||
return []byte(fmt.Sprintf(`import "%s"`, imp))
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("import (")
|
||||
for _, std := range imps.standard {
|
||||
fmt.Fprintf(buf, "\n\t\"%s\"", std)
|
||||
}
|
||||
if stdlen != 0 && thirdlen != 0 {
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
for _, third := range imps.thirdparty {
|
||||
fmt.Fprintf(buf, "\n\t\"%s\"", third)
|
||||
}
|
||||
buf.WriteString("\n)\n")
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
|
||||
a1 := imports{
|
||||
standard: []string{
|
||||
"fmt",
|
||||
`"fmt"`,
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
@ -85,7 +85,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
|
||||
a2 := imports{
|
||||
thirdparty: []string{
|
||||
"github.com/spf13/cobra",
|
||||
`"github.com/spf13/cobra"`,
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
@ -98,44 +98,68 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
}
|
||||
|
||||
a3 := imports{
|
||||
standard: []string{"fmt", "errors"},
|
||||
standard: []string{
|
||||
`"fmt"`,
|
||||
`"errors"`,
|
||||
},
|
||||
thirdparty: []string{
|
||||
"_github.com/lib/pq",
|
||||
"_github.com/gorilla/n",
|
||||
"github.com/gorilla/mux",
|
||||
"github.com/gorilla/websocket",
|
||||
`_ "github.com/lib/pq"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/mux"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
sort.Sort(ImportSorter(a3.standard))
|
||||
sort.Sort(ImportSorter(a3.thirdparty))
|
||||
|
||||
if err := outHandler("folder", templateOutputs, &data, &a3); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "import \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %s", out)
|
||||
|
||||
expectedOut := `import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
_ "github.com/gorilla/n"
|
||||
"github.com/gorilla/websocket"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
hello world
|
||||
patrick's dreams
|
||||
`
|
||||
|
||||
if out := file.String(); out != expectedOut {
|
||||
t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortImports(t *testing.T) {
|
||||
a1 := []string{"fmt", "errors"}
|
||||
a1 := []string{
|
||||
`"fmt"`,
|
||||
`"errors"`,
|
||||
}
|
||||
a2 := []string{
|
||||
"_github.com/lib/pq",
|
||||
"_github.com/gorilla/n",
|
||||
"github.com/gorilla/mux",
|
||||
"github.com/gorilla/websocket",
|
||||
`_ "github.com/lib/pq"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/mux"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
}
|
||||
|
||||
a1Expected := []string{"errors", "fmt"}
|
||||
a2Expected := []string{
|
||||
"github.com/gorilla/mux",
|
||||
"_github.com/gorilla/n",
|
||||
"github.com/gorilla/websocket",
|
||||
"_github.com/lib/pq",
|
||||
`"github.com/gorilla/mux"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
}
|
||||
|
||||
result := sortImports(a1)
|
||||
if !reflect.DeepEqual(result, a1Expected) {
|
||||
fmt.Errorf("Expected res to match a1expected, got: %v", result)
|
||||
sort.Sort(ImportSorter(a1))
|
||||
if !reflect.DeepEqual(a1, a1Expected) {
|
||||
fmt.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
||||
}
|
||||
|
||||
for i, v := range a1 {
|
||||
|
@ -144,9 +168,9 @@ func TestSortImports(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
result = sortImports(a2)
|
||||
if !reflect.DeepEqual(result, a2Expected) {
|
||||
fmt.Errorf("Expected res to match a2expected, got: %v", result)
|
||||
sort.Sort(ImportSorter(a2))
|
||||
if !reflect.DeepEqual(a2, a2Expected) {
|
||||
fmt.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
||||
}
|
||||
|
||||
for i, v := range a2 {
|
||||
|
@ -154,11 +178,6 @@ func TestSortImports(t *testing.T) {
|
|||
fmt.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(result)
|
||||
if reflect.DeepEqual(result, a2Expected) {
|
||||
fmt.Errorf("Expected res not to match a2expected when using sort.Strings.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildImportString(t *testing.T) {
|
||||
|
|
Loading…
Reference in a new issue