Added underscores import stuff

This commit is contained in:
Patrick O'brien 2016-03-03 01:18:26 +10:00
parent 8a9003770b
commit d320bb6944
3 changed files with 157 additions and 14 deletions

View file

@ -25,12 +25,18 @@ func boilRun(cmd *cobra.Command, args []string) {
}
var out [][]byte
var imps imports
// Loop through and generate every command template (excluding skipTemplates)
for _, command := range commandNames {
for i, command := range commandNames {
if i == 0 {
imps = combineImports(sqlBoilerDefaultImports, sqlBoilerCustomImports[command])
} else {
imps = combineImports(imps, sqlBoilerCustomImports[command])
}
out = append(out, generateTemplate(command, &data))
}
err := outHandler(cmdData.OutFolder, out, &data)
err := outHandler(cmdData.OutFolder, out, &data, &imps)
if err != nil {
errorQuit(err)
}

View file

@ -6,6 +6,7 @@ import (
"io"
"os"
"sort"
"strings"
"github.com/pobri19/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
@ -52,7 +53,8 @@ func defaultRun(cmd *cobra.Command, args []string) {
// execution output to a [][]byte before sending it to outHandler.
out := [][]byte{generateTemplate(cmd.Name(), &data)}
err := outHandler(cmdData.OutFolder, out, &data)
imps := combineImports(sqlBoilerDefaultImports, sqlBoilerCustomImports[cmd.Name()])
err := outHandler(cmdData.OutFolder, out, &data, &imps)
if err != nil {
errorQuit(fmt.Errorf("Unable to generate the template for command %s: %s", cmd.Name(), err))
}
@ -67,7 +69,7 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
// outHandler loops over the slice of byte slices, outputting them to either
// the OutFile if it is specified with a flag, or to Stdout if no flag is specified.
func outHandler(outFolder string, output [][]byte, data *tplData) error {
func outHandler(outFolder string, output [][]byte, data *tplData, imps *imports) error {
out := testHarnessStdout
if len(outFolder) != 0 {
@ -80,6 +82,13 @@ func outHandler(outFolder string, output [][]byte, data *tplData) error {
out = outFile
}
impStr := buildImportString(imps)
if len(impStr) > 0 {
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
errorQuit(fmt.Errorf("Unable to write imports to file handle: %v", err))
}
}
for _, templateOutput := range output {
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
errorQuit(fmt.Errorf("Unable to write template output to file handle: %v", err))
@ -94,12 +103,44 @@ func combineImports(a, b imports) imports {
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
sort.Strings(c.standard)
sort.Strings(c.thirdparty)
c.standard = sortImports(c.standard)
c.thirdparty = sortImports(c.thirdparty)
return c
}
// sortImports sorts the import strings alphabetically.
// If the import begins with an underscore, it temporarily
// strips it so that it does not impact the sort.
func sortImports(data []string) []string {
sorted := make([]string, len(data))
copy(sorted, data)
var underscoreImports []string
for i, v := range sorted {
if string(v[0]) == "_" && len(v) > 1 {
s := strings.Split(v, "_")
underscoreImports = append(underscoreImports, s[1])
sorted[i] = s[1]
}
}
sort.Strings(sorted)
AddUnderscores:
for i, v := range sorted {
for _, underImp := range underscoreImports {
if v == underImp {
sorted[i] = "_" + sorted[i]
continue AddUnderscores
}
}
}
return sorted
}
func combineStringSlices(a, b []string) []string {
c := make([]string, len(a)+len(b))
if len(a) > 0 {
@ -134,8 +175,12 @@ func removeDuplicates(dedup []string) []string {
return dedup
}
func buildImportString(imps imports) []byte {
func buildImportString(imps *imports) []byte {
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
if stdlen+thirdlen < 1 {
return []byte{}
}
if stdlen+thirdlen == 1 {
var imp string
if stdlen == 1 {
@ -149,15 +194,15 @@ func buildImportString(imps imports) []byte {
buf := &bytes.Buffer{}
buf.WriteString("import (")
for _, std := range imps.standard {
fmt.Fprintf(buf, "\t%s", std)
fmt.Fprintf(buf, "\n\t\"%s\"", std)
}
if stdlen != 0 && thirdlen != 0 {
buf.WriteString("\n")
}
for _, third := range imps.thirdparty {
fmt.Fprintf(buf, "\t%s", third)
fmt.Fprintf(buf, "\n\t\"%s\"", third)
}
buf.WriteString(")")
buf.WriteString("\n)\n")
return buf.Bytes()
}

View file

@ -4,6 +4,8 @@ import (
"bytes"
"fmt"
"io"
"reflect"
"sort"
"testing"
)
@ -22,7 +24,7 @@ func TestOutHandler(t *testing.T) {
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler("", templateOutputs, &data); err != nil {
if err := outHandler("", templateOutputs, &data, &imports{}); err != nil {
t.Error(err)
}
@ -60,13 +62,103 @@ func TestOutHandlerFiles(t *testing.T) {
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler("folder", templateOutputs, &data); err != nil {
if err := outHandler("folder", templateOutputs, &data, &imports{}); err != nil {
t.Error(err)
}
if out := file.String(); out != "hello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a1 := imports{
standard: []string{
"fmt",
},
}
file = &bytes.Buffer{}
if err := outHandler("folder", templateOutputs, &data, &a1); err != nil {
t.Error(err)
}
if out := file.String(); out != "import \"fmt\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a2 := imports{
thirdparty: []string{
"github.com/spf13/cobra",
},
}
file = &bytes.Buffer{}
if err := outHandler("folder", templateOutputs, &data, &a2); err != nil {
t.Error(err)
}
if out := file.String(); out != "import \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a3 := imports{
standard: []string{"fmt", "errors"},
thirdparty: []string{
"_github.com/lib/pq",
"_github.com/gorilla/n",
"github.com/gorilla/mux",
"github.com/gorilla/websocket",
},
}
file = &bytes.Buffer{}
if err := outHandler("folder", templateOutputs, &data, &a3); err != nil {
t.Error(err)
}
if out := file.String(); out != "import \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %s", out)
}
}
func TestSortImports(t *testing.T) {
a1 := []string{"fmt", "errors"}
a2 := []string{
"_github.com/lib/pq",
"_github.com/gorilla/n",
"github.com/gorilla/mux",
"github.com/gorilla/websocket",
}
a1Expected := []string{"errors", "fmt"}
a2Expected := []string{
"github.com/gorilla/mux",
"_github.com/gorilla/n",
"github.com/gorilla/websocket",
"_github.com/lib/pq",
}
result := sortImports(a1)
if !reflect.DeepEqual(result, a1Expected) {
fmt.Errorf("Expected res to match a1expected, got: %v", result)
}
for i, v := range a1 {
if v != a1Expected[i] {
fmt.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
result = sortImports(a2)
if !reflect.DeepEqual(result, a2Expected) {
fmt.Errorf("Expected res to match a2expected, got: %v", result)
}
for i, v := range a2 {
if v != a2Expected[i] {
fmt.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
sort.Strings(result)
if reflect.DeepEqual(result, a2Expected) {
fmt.Errorf("Expected res not to match a2expected when using sort.Strings.")
}
}
func TestBuildImportString(t *testing.T) {