Added imports to output
This commit is contained in:
parent
d320bb6944
commit
4cc62fdf5a
4 changed files with 123 additions and 111 deletions
|
@ -17,12 +17,12 @@ type imports struct {
|
||||||
// Imports that are defined
|
// Imports that are defined
|
||||||
var sqlBoilerDefaultImports = imports{
|
var sqlBoilerDefaultImports = imports{
|
||||||
standard: []string{
|
standard: []string{
|
||||||
"errors",
|
`"errors"`,
|
||||||
"fmt",
|
`"fmt"`,
|
||||||
},
|
},
|
||||||
thirdparty: []string{
|
thirdparty: []string{
|
||||||
"github.com/pobri19/sqlboiler/boil",
|
`"github.com/pobri19/sqlboiler/boil"`,
|
||||||
"gopkg.in/guregu/null.v3",
|
`"gopkg.in/guregu/null.v3"`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
71
cmds/imports.go
Normal file
71
cmds/imports.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImportSorter []string
|
||||||
|
|
||||||
|
func (i ImportSorter) Len() int {
|
||||||
|
return len(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i ImportSorter) Swap(k, j int) {
|
||||||
|
i[k], i[j] = i[j], i[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i ImportSorter) Less(k, j int) bool {
|
||||||
|
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
||||||
|
if res <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func combineImports(a, b imports) imports {
|
||||||
|
var c imports
|
||||||
|
|
||||||
|
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
|
||||||
|
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
|
||||||
|
|
||||||
|
sort.Sort(ImportSorter(c.standard))
|
||||||
|
sort.Sort(ImportSorter(c.thirdparty))
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildImportString(imps *imports) []byte {
|
||||||
|
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
|
||||||
|
if stdlen+thirdlen < 1 {
|
||||||
|
return []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stdlen+thirdlen == 1 {
|
||||||
|
var imp string
|
||||||
|
if stdlen == 1 {
|
||||||
|
imp = imps.standard[0]
|
||||||
|
} else {
|
||||||
|
imp = imps.thirdparty[0]
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf(`import %s`, imp))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
buf.WriteString("import (")
|
||||||
|
for _, std := range imps.standard {
|
||||||
|
fmt.Fprintf(buf, "\n\t%s", std)
|
||||||
|
}
|
||||||
|
if stdlen != 0 && thirdlen != 0 {
|
||||||
|
buf.WriteString("\n")
|
||||||
|
}
|
||||||
|
for _, third := range imps.thirdparty {
|
||||||
|
fmt.Fprintf(buf, "\n\t%s", third)
|
||||||
|
}
|
||||||
|
buf.WriteString("\n)\n")
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
|
@ -1,12 +1,9 @@
|
||||||
package cmds
|
package cmds
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
@ -98,49 +95,6 @@ func outHandler(outFolder string, output [][]byte, data *tplData, imps *imports)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func combineImports(a, b imports) imports {
|
|
||||||
var c imports
|
|
||||||
|
|
||||||
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
|
|
||||||
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
|
|
||||||
|
|
||||||
c.standard = sortImports(c.standard)
|
|
||||||
c.thirdparty = sortImports(c.thirdparty)
|
|
||||||
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// sortImports sorts the import strings alphabetically.
|
|
||||||
// If the import begins with an underscore, it temporarily
|
|
||||||
// strips it so that it does not impact the sort.
|
|
||||||
func sortImports(data []string) []string {
|
|
||||||
sorted := make([]string, len(data))
|
|
||||||
copy(sorted, data)
|
|
||||||
|
|
||||||
var underscoreImports []string
|
|
||||||
for i, v := range sorted {
|
|
||||||
if string(v[0]) == "_" && len(v) > 1 {
|
|
||||||
s := strings.Split(v, "_")
|
|
||||||
underscoreImports = append(underscoreImports, s[1])
|
|
||||||
sorted[i] = s[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(sorted)
|
|
||||||
|
|
||||||
AddUnderscores:
|
|
||||||
for i, v := range sorted {
|
|
||||||
for _, underImp := range underscoreImports {
|
|
||||||
if v == underImp {
|
|
||||||
sorted[i] = "_" + sorted[i]
|
|
||||||
continue AddUnderscores
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return sorted
|
|
||||||
}
|
|
||||||
|
|
||||||
func combineStringSlices(a, b []string) []string {
|
func combineStringSlices(a, b []string) []string {
|
||||||
c := make([]string, len(a)+len(b))
|
c := make([]string, len(a)+len(b))
|
||||||
if len(a) > 0 {
|
if len(a) > 0 {
|
||||||
|
@ -174,35 +128,3 @@ func removeDuplicates(dedup []string) []string {
|
||||||
|
|
||||||
return dedup
|
return dedup
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildImportString(imps *imports) []byte {
|
|
||||||
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
|
|
||||||
if stdlen+thirdlen < 1 {
|
|
||||||
return []byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if stdlen+thirdlen == 1 {
|
|
||||||
var imp string
|
|
||||||
if stdlen == 1 {
|
|
||||||
imp = imps.standard[0]
|
|
||||||
} else {
|
|
||||||
imp = imps.thirdparty[0]
|
|
||||||
}
|
|
||||||
return []byte(fmt.Sprintf(`import "%s"`, imp))
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
buf.WriteString("import (")
|
|
||||||
for _, std := range imps.standard {
|
|
||||||
fmt.Fprintf(buf, "\n\t\"%s\"", std)
|
|
||||||
}
|
|
||||||
if stdlen != 0 && thirdlen != 0 {
|
|
||||||
buf.WriteString("\n")
|
|
||||||
}
|
|
||||||
for _, third := range imps.thirdparty {
|
|
||||||
fmt.Fprintf(buf, "\n\t\"%s\"", third)
|
|
||||||
}
|
|
||||||
buf.WriteString("\n)\n")
|
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
|
@ -71,7 +71,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
|
|
||||||
a1 := imports{
|
a1 := imports{
|
||||||
standard: []string{
|
standard: []string{
|
||||||
"fmt",
|
`"fmt"`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
file = &bytes.Buffer{}
|
file = &bytes.Buffer{}
|
||||||
|
@ -85,7 +85,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
|
|
||||||
a2 := imports{
|
a2 := imports{
|
||||||
thirdparty: []string{
|
thirdparty: []string{
|
||||||
"github.com/spf13/cobra",
|
`"github.com/spf13/cobra"`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
file = &bytes.Buffer{}
|
file = &bytes.Buffer{}
|
||||||
|
@ -98,44 +98,68 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
a3 := imports{
|
a3 := imports{
|
||||||
standard: []string{"fmt", "errors"},
|
standard: []string{
|
||||||
|
`"fmt"`,
|
||||||
|
`"errors"`,
|
||||||
|
},
|
||||||
thirdparty: []string{
|
thirdparty: []string{
|
||||||
"_github.com/lib/pq",
|
`_ "github.com/lib/pq"`,
|
||||||
"_github.com/gorilla/n",
|
`_ "github.com/gorilla/n"`,
|
||||||
"github.com/gorilla/mux",
|
`"github.com/gorilla/mux"`,
|
||||||
"github.com/gorilla/websocket",
|
`"github.com/gorilla/websocket"`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
file = &bytes.Buffer{}
|
file = &bytes.Buffer{}
|
||||||
|
|
||||||
|
sort.Sort(ImportSorter(a3.standard))
|
||||||
|
sort.Sort(ImportSorter(a3.thirdparty))
|
||||||
|
|
||||||
if err := outHandler("folder", templateOutputs, &data, &a3); err != nil {
|
if err := outHandler("folder", templateOutputs, &data, &a3); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if out := file.String(); out != "import \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
|
||||||
t.Errorf("Wrong output: %s", out)
|
expectedOut := `import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
_ "github.com/gorilla/n"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
hello world
|
||||||
|
patrick's dreams
|
||||||
|
`
|
||||||
|
|
||||||
|
if out := file.String(); out != expectedOut {
|
||||||
|
t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSortImports(t *testing.T) {
|
func TestSortImports(t *testing.T) {
|
||||||
a1 := []string{"fmt", "errors"}
|
a1 := []string{
|
||||||
|
`"fmt"`,
|
||||||
|
`"errors"`,
|
||||||
|
}
|
||||||
a2 := []string{
|
a2 := []string{
|
||||||
"_github.com/lib/pq",
|
`_ "github.com/lib/pq"`,
|
||||||
"_github.com/gorilla/n",
|
`_ "github.com/gorilla/n"`,
|
||||||
"github.com/gorilla/mux",
|
`"github.com/gorilla/mux"`,
|
||||||
"github.com/gorilla/websocket",
|
`"github.com/gorilla/websocket"`,
|
||||||
}
|
}
|
||||||
|
|
||||||
a1Expected := []string{"errors", "fmt"}
|
a1Expected := []string{"errors", "fmt"}
|
||||||
a2Expected := []string{
|
a2Expected := []string{
|
||||||
"github.com/gorilla/mux",
|
`"github.com/gorilla/mux"`,
|
||||||
"_github.com/gorilla/n",
|
`_ "github.com/gorilla/n"`,
|
||||||
"github.com/gorilla/websocket",
|
`"github.com/gorilla/websocket"`,
|
||||||
"_github.com/lib/pq",
|
`_ "github.com/lib/pq"`,
|
||||||
}
|
}
|
||||||
|
|
||||||
result := sortImports(a1)
|
sort.Sort(ImportSorter(a1))
|
||||||
if !reflect.DeepEqual(result, a1Expected) {
|
if !reflect.DeepEqual(a1, a1Expected) {
|
||||||
fmt.Errorf("Expected res to match a1expected, got: %v", result)
|
fmt.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, v := range a1 {
|
for i, v := range a1 {
|
||||||
|
@ -144,9 +168,9 @@ func TestSortImports(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = sortImports(a2)
|
sort.Sort(ImportSorter(a2))
|
||||||
if !reflect.DeepEqual(result, a2Expected) {
|
if !reflect.DeepEqual(a2, a2Expected) {
|
||||||
fmt.Errorf("Expected res to match a2expected, got: %v", result)
|
fmt.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, v := range a2 {
|
for i, v := range a2 {
|
||||||
|
@ -154,11 +178,6 @@ func TestSortImports(t *testing.T) {
|
||||||
fmt.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
fmt.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(result)
|
|
||||||
if reflect.DeepEqual(result, a2Expected) {
|
|
||||||
fmt.Errorf("Expected res not to match a2expected when using sort.Strings.")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildImportString(t *testing.T) {
|
func TestBuildImportString(t *testing.T) {
|
||||||
|
|
Loading…
Reference in a new issue