Added underscores import stuff
This commit is contained in:
parent
8a9003770b
commit
d320bb6944
3 changed files with 157 additions and 14 deletions
10
cmds/boil.go
10
cmds/boil.go
|
@ -25,12 +25,18 @@ func boilRun(cmd *cobra.Command, args []string) {
|
|||
}
|
||||
|
||||
var out [][]byte
|
||||
var imps imports
|
||||
// Loop through and generate every command template (excluding skipTemplates)
|
||||
for _, command := range commandNames {
|
||||
for i, command := range commandNames {
|
||||
if i == 0 {
|
||||
imps = combineImports(sqlBoilerDefaultImports, sqlBoilerCustomImports[command])
|
||||
} else {
|
||||
imps = combineImports(imps, sqlBoilerCustomImports[command])
|
||||
}
|
||||
out = append(out, generateTemplate(command, &data))
|
||||
}
|
||||
|
||||
err := outHandler(cmdData.OutFolder, out, &data)
|
||||
err := outHandler(cmdData.OutFolder, out, &data, &imps)
|
||||
if err != nil {
|
||||
errorQuit(err)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -52,7 +53,8 @@ func defaultRun(cmd *cobra.Command, args []string) {
|
|||
// execution output to a [][]byte before sending it to outHandler.
|
||||
out := [][]byte{generateTemplate(cmd.Name(), &data)}
|
||||
|
||||
err := outHandler(cmdData.OutFolder, out, &data)
|
||||
imps := combineImports(sqlBoilerDefaultImports, sqlBoilerCustomImports[cmd.Name()])
|
||||
err := outHandler(cmdData.OutFolder, out, &data, &imps)
|
||||
if err != nil {
|
||||
errorQuit(fmt.Errorf("Unable to generate the template for command %s: %s", cmd.Name(), err))
|
||||
}
|
||||
|
@ -67,7 +69,7 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
|||
|
||||
// outHandler loops over the slice of byte slices, outputting them to either
|
||||
// the OutFile if it is specified with a flag, or to Stdout if no flag is specified.
|
||||
func outHandler(outFolder string, output [][]byte, data *tplData) error {
|
||||
func outHandler(outFolder string, output [][]byte, data *tplData, imps *imports) error {
|
||||
out := testHarnessStdout
|
||||
|
||||
if len(outFolder) != 0 {
|
||||
|
@ -80,6 +82,13 @@ func outHandler(outFolder string, output [][]byte, data *tplData) error {
|
|||
out = outFile
|
||||
}
|
||||
|
||||
impStr := buildImportString(imps)
|
||||
if len(impStr) > 0 {
|
||||
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
|
||||
errorQuit(fmt.Errorf("Unable to write imports to file handle: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
for _, templateOutput := range output {
|
||||
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
||||
errorQuit(fmt.Errorf("Unable to write template output to file handle: %v", err))
|
||||
|
@ -94,12 +103,44 @@ func combineImports(a, b imports) imports {
|
|||
|
||||
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
|
||||
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
|
||||
sort.Strings(c.standard)
|
||||
sort.Strings(c.thirdparty)
|
||||
|
||||
c.standard = sortImports(c.standard)
|
||||
c.thirdparty = sortImports(c.thirdparty)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// sortImports sorts the import strings alphabetically.
|
||||
// If the import begins with an underscore, it temporarily
|
||||
// strips it so that it does not impact the sort.
|
||||
func sortImports(data []string) []string {
|
||||
sorted := make([]string, len(data))
|
||||
copy(sorted, data)
|
||||
|
||||
var underscoreImports []string
|
||||
for i, v := range sorted {
|
||||
if string(v[0]) == "_" && len(v) > 1 {
|
||||
s := strings.Split(v, "_")
|
||||
underscoreImports = append(underscoreImports, s[1])
|
||||
sorted[i] = s[1]
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(sorted)
|
||||
|
||||
AddUnderscores:
|
||||
for i, v := range sorted {
|
||||
for _, underImp := range underscoreImports {
|
||||
if v == underImp {
|
||||
sorted[i] = "_" + sorted[i]
|
||||
continue AddUnderscores
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
func combineStringSlices(a, b []string) []string {
|
||||
c := make([]string, len(a)+len(b))
|
||||
if len(a) > 0 {
|
||||
|
@ -134,8 +175,12 @@ func removeDuplicates(dedup []string) []string {
|
|||
return dedup
|
||||
}
|
||||
|
||||
func buildImportString(imps imports) []byte {
|
||||
func buildImportString(imps *imports) []byte {
|
||||
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
|
||||
if stdlen+thirdlen < 1 {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
if stdlen+thirdlen == 1 {
|
||||
var imp string
|
||||
if stdlen == 1 {
|
||||
|
@ -149,15 +194,15 @@ func buildImportString(imps imports) []byte {
|
|||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("import (")
|
||||
for _, std := range imps.standard {
|
||||
fmt.Fprintf(buf, "\t%s", std)
|
||||
fmt.Fprintf(buf, "\n\t\"%s\"", std)
|
||||
}
|
||||
if stdlen != 0 && thirdlen != 0 {
|
||||
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
for _, third := range imps.thirdparty {
|
||||
fmt.Fprintf(buf, "\t%s", third)
|
||||
fmt.Fprintf(buf, "\n\t\"%s\"", third)
|
||||
}
|
||||
buf.WriteString(")")
|
||||
buf.WriteString("\n)\n")
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -22,7 +24,7 @@ func TestOutHandler(t *testing.T) {
|
|||
|
||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||
|
||||
if err := outHandler("", templateOutputs, &data); err != nil {
|
||||
if err := outHandler("", templateOutputs, &data, &imports{}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
|
@ -60,13 +62,103 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
|
||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||
|
||||
if err := outHandler("folder", templateOutputs, &data); err != nil {
|
||||
if err := outHandler("folder", templateOutputs, &data, &imports{}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if out := file.String(); out != "hello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
|
||||
a1 := imports{
|
||||
standard: []string{
|
||||
"fmt",
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
if err := outHandler("folder", templateOutputs, &data, &a1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "import \"fmt\"\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
|
||||
a2 := imports{
|
||||
thirdparty: []string{
|
||||
"github.com/spf13/cobra",
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
if err := outHandler("folder", templateOutputs, &data, &a2); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "import \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
|
||||
a3 := imports{
|
||||
standard: []string{"fmt", "errors"},
|
||||
thirdparty: []string{
|
||||
"_github.com/lib/pq",
|
||||
"_github.com/gorilla/n",
|
||||
"github.com/gorilla/mux",
|
||||
"github.com/gorilla/websocket",
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
if err := outHandler("folder", templateOutputs, &data, &a3); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "import \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortImports(t *testing.T) {
|
||||
a1 := []string{"fmt", "errors"}
|
||||
a2 := []string{
|
||||
"_github.com/lib/pq",
|
||||
"_github.com/gorilla/n",
|
||||
"github.com/gorilla/mux",
|
||||
"github.com/gorilla/websocket",
|
||||
}
|
||||
|
||||
a1Expected := []string{"errors", "fmt"}
|
||||
a2Expected := []string{
|
||||
"github.com/gorilla/mux",
|
||||
"_github.com/gorilla/n",
|
||||
"github.com/gorilla/websocket",
|
||||
"_github.com/lib/pq",
|
||||
}
|
||||
|
||||
result := sortImports(a1)
|
||||
if !reflect.DeepEqual(result, a1Expected) {
|
||||
fmt.Errorf("Expected res to match a1expected, got: %v", result)
|
||||
}
|
||||
|
||||
for i, v := range a1 {
|
||||
if v != a1Expected[i] {
|
||||
fmt.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
|
||||
result = sortImports(a2)
|
||||
if !reflect.DeepEqual(result, a2Expected) {
|
||||
fmt.Errorf("Expected res to match a2expected, got: %v", result)
|
||||
}
|
||||
|
||||
for i, v := range a2 {
|
||||
if v != a2Expected[i] {
|
||||
fmt.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(result)
|
||||
if reflect.DeepEqual(result, a2Expected) {
|
||||
fmt.Errorf("Expected res not to match a2expected when using sort.Strings.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildImportString(t *testing.T) {
|
||||
|
|
Loading…
Add table
Reference in a new issue