diff --git a/cmds/shared.go b/cmds/shared.go index 3870175..3151a84 100644 --- a/cmds/shared.go +++ b/cmds/shared.go @@ -1,9 +1,11 @@ package cmds import ( + "bytes" "fmt" "io" "os" + "sort" "github.com/pobri19/sqlboiler/dbdrivers" "github.com/spf13/cobra" @@ -86,3 +88,76 @@ func outHandler(outFolder string, output [][]byte, data *tplData) error { return nil } + +func combineImports(a, b imports) imports { + var c imports + + c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard)) + c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty)) + sort.Strings(c.standard) + sort.Strings(c.thirdparty) + + return c +} + +func combineStringSlices(a, b []string) []string { + c := make([]string, len(a)+len(b)) + if len(a) > 0 { + copy(c, a) + } + if len(b) > 0 { + copy(c[len(a):], b) + } + + return c +} + +func removeDuplicates(dedup []string) []string { + if len(dedup) <= 1 { + return dedup + } + + for i := 0; i < len(dedup)-1; i++ { + for j := i + 1; j < len(dedup); j++ { + if dedup[i] != dedup[j] { + continue + } + + if j != len(dedup)-1 { + dedup[j] = dedup[len(dedup)-1] + j-- + } + dedup = dedup[:len(dedup)-1] + } + } + + return dedup +} + +func buildImportString(imps imports) []byte { + stdlen, thirdlen := len(imps.standard), len(imps.thirdparty) + if stdlen+thirdlen == 1 { + var imp string + if stdlen == 1 { + imp = imps.standard[0] + } else { + imp = imps.thirdparty[0] + } + return []byte(fmt.Sprintf(`import "%s"`, imp)) + } + + buf := &bytes.Buffer{} + buf.WriteString("import (") + for _, std := range imps.standard { + fmt.Fprintf(buf, "\t%s", std) + } + if stdlen != 0 && thirdlen != 0 { + + } + for _, third := range imps.thirdparty { + fmt.Fprintf(buf, "\t%s", third) + } + buf.WriteString(")") + + return buf.Bytes() +} diff --git a/cmds/shared_test.go b/cmds/shared_test.go index 130cf57..369df38 100644 --- a/cmds/shared_test.go +++ b/cmds/shared_test.go @@ -2,6 +2,7 @@ package cmds import ( "bytes" + "fmt" "io" "testing" ) @@ -67,3 +68,102 @@ func TestOutHandlerFiles(t *testing.T) { t.Errorf("Wrong output: %q", out) } } + +func TestBuildImportString(t *testing.T) { +} + +func TestCombineImports(t *testing.T) { + a := imports{ + standard: []string{"fmt"}, + thirdparty: []string{"github.com/pobri19/sqlboiler", "gopkg.in/guregu/null.v3"}, + } + b := imports{ + standard: []string{"os"}, + thirdparty: []string{"github.com/pobri19/sqlboiler"}, + } + + c := combineImports(a, b) + + if c.standard[0] != "fmt" && c.standard[1] != "os" { + t.Errorf("Wanted: fmt, os got: %#v", c.standard) + } + if c.thirdparty[0] != "github.com/pobri19/sqlboiler" && c.thirdparty[1] != "gopkg.in/guregu/null.v3" { + t.Errorf("Wanted: github.com/pobri19/sqlboiler, gopkg.in/guregu/null.v3 got: %#v", c.thirdparty) + } +} + +func TestRemoveDuplicates(t *testing.T) { + hasDups := func(possible []string) error { + for i := 0; i < len(possible)-1; i++ { + for j := i + 1; j < len(possible); j++ { + if possible[i] == possible[j] { + return fmt.Errorf("found duplicate: %s [%d] [%d]", possible[i], i, j) + } + } + } + + return nil + } + + if len(removeDuplicates([]string{})) != 0 { + t.Error("It should have returned an empty slice") + } + + oneItem := []string{"patrick"} + slice := removeDuplicates(oneItem) + if ln := len(slice); ln != 1 { + t.Error("Length was wrong:", ln) + } else if oneItem[0] != slice[0] { + t.Errorf("Slices differ: %#v %#v", oneItem, slice) + } + + slice = removeDuplicates([]string{"hello", "patrick", "hello"}) + if ln := len(slice); ln != 2 { + t.Error("Length was wrong:", ln) + } + if err := hasDups(slice); err != nil { + t.Error(err) + } + + slice = removeDuplicates([]string{"five", "patrick", "hello", "hello", "patrick", "hello", "hello"}) + if ln := len(slice); ln != 3 { + t.Error("Length was wrong:", ln) + } + if err := hasDups(slice); err != nil { + t.Error(err) + } +} + +func TestCombineStringSlices(t *testing.T) { + var a, b []string + slice := combineStringSlices(a, b) + if ln := len(slice); ln != 0 { + t.Error("Len was wrong:", ln) + } + + a = []string{"1", "2"} + slice = combineStringSlices(a, b) + if ln := len(slice); ln != 2 { + t.Error("Len was wrong:", ln) + } else if slice[0] != a[0] || slice[1] != a[1] { + t.Errorf("Slice mismatch: %#v %#v", a, slice) + } + + b = a + a = nil + slice = combineStringSlices(a, b) + if ln := len(slice); ln != 2 { + t.Error("Len was wrong:", ln) + } else if slice[0] != b[0] || slice[1] != b[1] { + t.Errorf("Slice mismatch: %#v %#v", b, slice) + } + + a = b + b = []string{"3", "4"} + slice = combineStringSlices(a, b) + if ln := len(slice); ln != 4 { + t.Error("Len was wrong:", ln) + } else if slice[0] != a[0] || slice[1] != a[1] || slice[2] != b[0] || slice[3] != b[1] { + t.Errorf("Slice mismatch: %#v + %#v != #%v", a, b, slice) + } +}