Added template sorter, MainTest separation
* Moved MainTests into a dedicated folder * Added sorter for templates (sorts structs to top) * Fixed all the broken tests * Tidied up output.go functions * Only creating one MainTest (main_test.go) per app run now * Split test imports up into normal test imports & main test imports
This commit is contained in:
parent
c84d35d394
commit
d3aeb7375d
13 changed files with 399 additions and 198 deletions
|
@ -8,33 +8,6 @@ import (
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
)
|
)
|
||||||
|
|
||||||
// sqlBoilerImports defines the list of default template imports.
|
|
||||||
var sqlBoilerImports = imports{
|
|
||||||
standard: importList{
|
|
||||||
`"errors"`,
|
|
||||||
`"fmt"`,
|
|
||||||
},
|
|
||||||
thirdparty: importList{
|
|
||||||
`"github.com/pobri19/sqlboiler/boil"`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerTestImports defines the list of default test template imports.
|
|
||||||
var sqlBoilerTestImports = imports{
|
|
||||||
standard: importList{
|
|
||||||
`"testing"`,
|
|
||||||
`"os"`,
|
|
||||||
`"os/exec"`,
|
|
||||||
`"fmt"`,
|
|
||||||
`"io/ioutil"`,
|
|
||||||
`"bytes"`,
|
|
||||||
`"errors"`,
|
|
||||||
},
|
|
||||||
thirdparty: importList{
|
|
||||||
`"github.com/BurntSushi/toml"`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerTypeImports imports are only included in the template output if the database
|
// sqlBoilerTypeImports imports are only included in the template output if the database
|
||||||
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
|
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
|
||||||
var sqlBoilerTypeImports = map[string]imports{
|
var sqlBoilerTypeImports = map[string]imports{
|
||||||
|
@ -58,12 +31,40 @@ var sqlBoilerTypeImports = map[string]imports{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlBoilerDriverTestImports defines the test template imports
|
// sqlBoilerImports defines the list of default template imports.
|
||||||
// for the particular database interfaces
|
var sqlBoilerImports = imports{
|
||||||
var sqlBoilerDriverTestImports = map[string]imports{
|
standard: importList{
|
||||||
|
`"errors"`,
|
||||||
|
`"fmt"`,
|
||||||
|
},
|
||||||
|
thirdparty: importList{
|
||||||
|
`"github.com/pobri19/sqlboiler/boil"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlBoilerTestImports defines the list of default test template imports.
|
||||||
|
var sqlBoilerTestImports = imports{
|
||||||
|
standard: importList{
|
||||||
|
`"testing"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var sqlBoilerTestMainImports = map[string]imports{
|
||||||
"postgres": imports{
|
"postgres": imports{
|
||||||
standard: importList{`"database/sql"`},
|
standard: importList{
|
||||||
thirdparty: importList{`_ "github.com/lib/pq"`},
|
`"testing"`,
|
||||||
|
`"os"`,
|
||||||
|
`"os/exec"`,
|
||||||
|
`"fmt"`,
|
||||||
|
`"io/ioutil"`,
|
||||||
|
`"bytes"`,
|
||||||
|
`"errors"`,
|
||||||
|
`"database/sql"`,
|
||||||
|
},
|
||||||
|
thirdparty: importList{
|
||||||
|
`"github.com/BurntSushi/toml"`,
|
||||||
|
`_ "github.com/lib/pq"`,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,28 +4,10 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (i importList) Len() int {
|
|
||||||
return len(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i importList) Swap(k, j int) {
|
|
||||||
i[k], i[j] = i[j], i[k]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i importList) Less(k, j int) bool {
|
|
||||||
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
|
||||||
if res <= 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func combineImports(a, b imports) imports {
|
func combineImports(a, b imports) imports {
|
||||||
var c imports
|
var c imports
|
||||||
|
|
||||||
|
@ -78,7 +60,7 @@ func buildImportString(imps imports) []byte {
|
||||||
} else {
|
} else {
|
||||||
imp = imps.thirdparty[0]
|
imp = imps.thirdparty[0]
|
||||||
}
|
}
|
||||||
return []byte(fmt.Sprintf(`import %s`, imp))
|
return []byte(fmt.Sprintf("import %s", imp))
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package cmds
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
|
@ -72,51 +71,6 @@ func TestCombineTypeImports(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSortImports(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
a1 := importList{
|
|
||||||
`"fmt"`,
|
|
||||||
`"errors"`,
|
|
||||||
}
|
|
||||||
a2 := importList{
|
|
||||||
`_ "github.com/lib/pq"`,
|
|
||||||
`_ "github.com/gorilla/n"`,
|
|
||||||
`"github.com/gorilla/mux"`,
|
|
||||||
`"github.com/gorilla/websocket"`,
|
|
||||||
}
|
|
||||||
|
|
||||||
a1Expected := importList{`"errors"`, `"fmt"`}
|
|
||||||
a2Expected := importList{
|
|
||||||
`"github.com/gorilla/mux"`,
|
|
||||||
`_ "github.com/gorilla/n"`,
|
|
||||||
`"github.com/gorilla/websocket"`,
|
|
||||||
`_ "github.com/lib/pq"`,
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Sort(a1)
|
|
||||||
if !reflect.DeepEqual(a1, a1Expected) {
|
|
||||||
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, v := range a1 {
|
|
||||||
if v != a1Expected[i] {
|
|
||||||
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Sort(a2)
|
|
||||||
if !reflect.DeepEqual(a2, a2Expected) {
|
|
||||||
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, v := range a2 {
|
|
||||||
if v != a2Expected[i] {
|
|
||||||
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCombineImports(t *testing.T) {
|
func TestCombineImports(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
163
cmds/output.go
163
cmds/output.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"go/format"
|
"go/format"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"text/template"
|
"text/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,31 +17,20 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
||||||
return file, err
|
return file, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
// generateOutput builds the file output and sends it to outHandler for saving
|
||||||
if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) {
|
func generateOutput(cmdData *CmdData, data *tplData) error {
|
||||||
|
if len(cmdData.Templates) == 0 {
|
||||||
return errors.New("No template files located for generation")
|
return errors.New("No template files located for generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
var out [][]byte
|
var out [][]byte
|
||||||
var imps imports
|
var imps imports
|
||||||
var tpls []*template.Template
|
|
||||||
|
|
||||||
if testOutput {
|
imps.standard = sqlBoilerImports.standard
|
||||||
imps.standard = sqlBoilerTestImports.standard
|
imps.thirdparty = sqlBoilerImports.thirdparty
|
||||||
imps.thirdparty = sqlBoilerTestImports.thirdparty
|
|
||||||
imps = combineImports(imps, sqlBoilerDriverTestImports[cmdData.DriverName])
|
|
||||||
tpls = cmdData.TestTemplates
|
|
||||||
} else {
|
|
||||||
imps.standard = sqlBoilerImports.standard
|
|
||||||
imps.thirdparty = sqlBoilerImports.thirdparty
|
|
||||||
tpls = cmdData.Templates
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop through and generate every individual template
|
for _, template := range cmdData.Templates {
|
||||||
for _, template := range tpls {
|
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
|
||||||
if !testOutput {
|
|
||||||
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
|
|
||||||
}
|
|
||||||
resp, err := generateTemplate(template, data)
|
resp, err := generateTemplate(template, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -48,7 +38,63 @@ func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
||||||
out = append(out, resp)
|
out = append(out, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := outHandler(cmdData, out, data, imps, testOutput); err != nil {
|
fName := data.Table.Name + ".go"
|
||||||
|
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTestOutput builds the test file output and sends it to outHandler for saving
|
||||||
|
func generateTestOutput(cmdData *CmdData, data *tplData) error {
|
||||||
|
if len(cmdData.TestTemplates) == 0 {
|
||||||
|
return errors.New("No template files located for generation")
|
||||||
|
}
|
||||||
|
|
||||||
|
var out [][]byte
|
||||||
|
var imps imports
|
||||||
|
|
||||||
|
imps.standard = sqlBoilerTestImports.standard
|
||||||
|
imps.thirdparty = sqlBoilerTestImports.thirdparty
|
||||||
|
|
||||||
|
for _, template := range cmdData.TestTemplates {
|
||||||
|
resp, err := generateTemplate(template, data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
out = append(out, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
fName := data.Table.Name + "_test.go"
|
||||||
|
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestMainOutput(cmdData *CmdData) error {
|
||||||
|
if cmdData.TestMainTemplate == nil {
|
||||||
|
return errors.New("No TestMain template located for generation")
|
||||||
|
}
|
||||||
|
|
||||||
|
var out [][]byte
|
||||||
|
var imps imports
|
||||||
|
|
||||||
|
imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard
|
||||||
|
imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty
|
||||||
|
|
||||||
|
resp, err := generateTemplate(cmdData.TestMainTemplate, &tplData{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
out = append(out, resp)
|
||||||
|
|
||||||
|
err = outHandler(cmdData.OutFolder, "main_test.go", cmdData.PkgName, imps, out)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,28 +102,71 @@ func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// outHandler loops over each template in the slice of byte slices and builds an output file.
|
// outHandler loops over each template in the slice of byte slices and builds an output file.
|
||||||
func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
|
// func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
|
||||||
|
// out := testHarnessStdout
|
||||||
|
//
|
||||||
|
// var path string
|
||||||
|
//
|
||||||
|
// if len(cmdData.OutFolder) != 0 {
|
||||||
|
// if testTemplate {
|
||||||
|
// path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
|
||||||
|
// } else {
|
||||||
|
// path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// outFile, err := testHarnessFileOpen(path)
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("Unable to create output file %s: %s", path, err)
|
||||||
|
// }
|
||||||
|
// defer outFile.Close()
|
||||||
|
// out = outFile
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
|
||||||
|
// return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// impStr := buildImportString(imps)
|
||||||
|
// if len(impStr) > 0 {
|
||||||
|
// if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
|
||||||
|
// return fmt.Errorf("Unable to write imports to file handle: %v", err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// for _, templateOutput := range output {
|
||||||
|
// if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
||||||
|
// return fmt.Errorf("Unable to write template output to file handle: %v", err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// func outHandler(cmdData *CmdData, data *tplData, imps imports, output [][]byte, testTemplate bool) error {
|
||||||
|
// var fileName string
|
||||||
|
// if testTemplate == true {
|
||||||
|
// fileName = data.Table.Name + "_test.go"
|
||||||
|
// } else {
|
||||||
|
// fileName = data.Table.Name + ".go"
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// outGenerator()
|
||||||
|
// }
|
||||||
|
|
||||||
|
func outHandler(outFolder string, fileName string, pkgName string, imps imports, contents [][]byte) error {
|
||||||
out := testHarnessStdout
|
out := testHarnessStdout
|
||||||
|
|
||||||
var path string
|
path := filepath.Join(outFolder, fileName)
|
||||||
|
|
||||||
if len(cmdData.OutFolder) != 0 {
|
outFile, err := testHarnessFileOpen(path)
|
||||||
if testTemplate {
|
if err != nil {
|
||||||
path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
|
return fmt.Errorf("Unable to create output file %s: %s", path, err)
|
||||||
} else {
|
|
||||||
path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
|
|
||||||
}
|
|
||||||
|
|
||||||
outFile, err := testHarnessFileOpen(path)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Unable to create output file %s: %s", path, err)
|
|
||||||
}
|
|
||||||
defer outFile.Close()
|
|
||||||
out = outFile
|
|
||||||
}
|
}
|
||||||
|
defer outFile.Close()
|
||||||
|
out = outFile
|
||||||
|
|
||||||
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
|
if _, err := fmt.Fprintf(out, "package %s\n\n", pkgName); err != nil {
|
||||||
return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
|
return fmt.Errorf("Unable to write package name %s to file: %s", pkgName, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
impStr := buildImportString(imps)
|
impStr := buildImportString(imps)
|
||||||
|
@ -87,7 +176,7 @@ func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, templateOutput := range output {
|
for _, templateOutput := range contents {
|
||||||
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
||||||
return fmt.Errorf("Unable to write template output to file handle: %v", err)
|
return fmt.Errorf("Unable to write template output to file handle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,36 +5,8 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOutHandler(t *testing.T) {
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
|
|
||||||
saveTestHarnessStdout := testHarnessStdout
|
|
||||||
testHarnessStdout = buf
|
|
||||||
defer func() {
|
|
||||||
testHarnessStdout = saveTestHarnessStdout
|
|
||||||
}()
|
|
||||||
|
|
||||||
data := tplData{
|
|
||||||
Table: dbdrivers.Table{
|
|
||||||
Name: "patrick",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
|
||||||
|
|
||||||
if err := outHandler(&CmdData{PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
|
||||||
t.Errorf("Wrong output: %q", out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type NopWriteCloser struct {
|
type NopWriteCloser struct {
|
||||||
io.Writer
|
io.Writer
|
||||||
}
|
}
|
||||||
|
@ -47,6 +19,30 @@ func nopCloser(w io.Writer) io.WriteCloser {
|
||||||
return NopWriteCloser{w}
|
return NopWriteCloser{w}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOutHandler(t *testing.T) {
|
||||||
|
// set the function pointer back to its original value
|
||||||
|
// after we modify it for the test
|
||||||
|
saveTestHarnessFileOpen := testHarnessFileOpen
|
||||||
|
defer func() {
|
||||||
|
testHarnessFileOpen = saveTestHarnessFileOpen
|
||||||
|
}()
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
|
||||||
|
return nopCloser(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||||
|
|
||||||
|
if err := outHandler("", "file.go", "patrick", imports{}, templateOutputs); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||||
|
t.Errorf("Wrong output: %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOutHandlerFiles(t *testing.T) {
|
func TestOutHandlerFiles(t *testing.T) {
|
||||||
saveTestHarnessFileOpen := testHarnessFileOpen
|
saveTestHarnessFileOpen := testHarnessFileOpen
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -58,13 +54,9 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
return nopCloser(file), nil
|
return nopCloser(file), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
data := tplData{
|
|
||||||
Table: dbdrivers.Table{Name: "patrick"},
|
|
||||||
}
|
|
||||||
|
|
||||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||||
|
|
||||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
if err := outHandler("folder", "file.go", "patrick", imports{}, templateOutputs); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||||
|
@ -78,7 +70,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
}
|
}
|
||||||
file = &bytes.Buffer{}
|
file = &bytes.Buffer{}
|
||||||
|
|
||||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a1, false); err != nil {
|
if err := outHandler("folder", "file.go", "patrick", a1, templateOutputs); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
|
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
|
||||||
|
@ -92,7 +84,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
}
|
}
|
||||||
file = &bytes.Buffer{}
|
file = &bytes.Buffer{}
|
||||||
|
|
||||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a2, false); err != nil {
|
if err := outHandler("folder", "file.go", "patrick", a2, templateOutputs); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
||||||
|
@ -116,7 +108,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
sort.Sort(a3.standard)
|
sort.Sort(a3.standard)
|
||||||
sort.Sort(a3.thirdparty)
|
sort.Sort(a3.thirdparty)
|
||||||
|
|
||||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a3, false); err != nil {
|
if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
42
cmds/sorters.go
Normal file
42
cmds/sorters.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package cmds
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
func (i importList) Len() int {
|
||||||
|
return len(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i importList) Swap(k, j int) {
|
||||||
|
i[k], i[j] = i[j], i[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i importList) Less(k, j int) bool {
|
||||||
|
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
||||||
|
if res <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t templater) Len() int {
|
||||||
|
return len(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t templater) Swap(k, j int) {
|
||||||
|
t[k], t[j] = t[j], t[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t templater) Less(k, j int) bool {
|
||||||
|
// Make sure "struct" goes to the front
|
||||||
|
if t[k].Name() == "struct.tpl" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
res := strings.Compare(t[k].Name(), t[j].Name())
|
||||||
|
if res <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
82
cmds/sorters_test.go
Normal file
82
cmds/sorters_test.go
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
"text/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSortImports(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
a1 := importList{
|
||||||
|
`"fmt"`,
|
||||||
|
`"errors"`,
|
||||||
|
}
|
||||||
|
a2 := importList{
|
||||||
|
`_ "github.com/lib/pq"`,
|
||||||
|
`_ "github.com/gorilla/n"`,
|
||||||
|
`"github.com/gorilla/mux"`,
|
||||||
|
`"github.com/gorilla/websocket"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
a1Expected := importList{`"errors"`, `"fmt"`}
|
||||||
|
a2Expected := importList{
|
||||||
|
`"github.com/gorilla/mux"`,
|
||||||
|
`_ "github.com/gorilla/n"`,
|
||||||
|
`"github.com/gorilla/websocket"`,
|
||||||
|
`_ "github.com/lib/pq"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(a1)
|
||||||
|
if !reflect.DeepEqual(a1, a1Expected) {
|
||||||
|
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, v := range a1 {
|
||||||
|
if v != a1Expected[i] {
|
||||||
|
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(a2)
|
||||||
|
if !reflect.DeepEqual(a2, a2Expected) {
|
||||||
|
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, v := range a2 {
|
||||||
|
if v != a2Expected[i] {
|
||||||
|
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortTemplates(t *testing.T) {
|
||||||
|
templs := templater{
|
||||||
|
template.New("bob.tpl"),
|
||||||
|
template.New("all.tpl"),
|
||||||
|
template.New("struct.tpl"),
|
||||||
|
template.New("ttt.tpl"),
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
|
||||||
|
|
||||||
|
for i, v := range templs {
|
||||||
|
if v.Name() != expected[i] {
|
||||||
|
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
|
||||||
|
|
||||||
|
sort.Sort(templs)
|
||||||
|
|
||||||
|
for i, v := range templs {
|
||||||
|
if v.Name() != expected[i] {
|
||||||
|
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Errorf("cant sort templates")
|
||||||
|
}
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
@ -13,19 +14,26 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
templatesDirectory = "/cmds/templates"
|
templatesDirectory = "/cmds/templates"
|
||||||
templatesTestDirectory = "/cmds/templates_test"
|
templatesTestDirectory = "/cmds/templates_test"
|
||||||
|
templatesTestMainDirectory = "/cmds/templates_test/main_test"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoadTemplates loads all template folders into the cmdData object.
|
// LoadTemplates loads all template folders into the cmdData object.
|
||||||
func (c *CmdData) LoadTemplates() error {
|
func initTemplates(cmdData *CmdData) error {
|
||||||
var err error
|
var err error
|
||||||
c.Templates, err = loadTemplates(templatesDirectory)
|
cmdData.Templates, err = loadTemplates(templatesDirectory)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
filename := cmdData.DriverName + "_main.tpl"
|
||||||
|
cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -47,7 +55,27 @@ func loadTemplates(dir string) ([]*template.Template, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return tpl.Templates(), err
|
templates := templater(tpl.Templates())
|
||||||
|
sort.Sort(templates)
|
||||||
|
|
||||||
|
return templates, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadTemplate loads a single template file.
|
||||||
|
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pattern := filepath.Join(wd, dir, filename)
|
||||||
|
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseFiles(pattern)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tpl.Lookup(filename), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
|
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
|
||||||
|
@ -84,6 +112,12 @@ func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
// run executes the sqlboiler templates and outputs them to files.
|
// run executes the sqlboiler templates and outputs them to files.
|
||||||
func (c *CmdData) run(includeTests bool) error {
|
func (c *CmdData) run(includeTests bool) error {
|
||||||
|
if includeTests {
|
||||||
|
if err := generateTestMainOutput(c); err != nil {
|
||||||
|
return fmt.Errorf("Unable to generate TestMain output: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, table := range c.Tables {
|
for _, table := range c.Tables {
|
||||||
data := &tplData{
|
data := &tplData{
|
||||||
Table: table,
|
Table: table,
|
||||||
|
@ -91,13 +125,13 @@ func (c *CmdData) run(includeTests bool) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the regular templates
|
// Generate the regular templates
|
||||||
if err := generateOutput(c, data, false); err != nil {
|
if err := generateOutput(c, data); err != nil {
|
||||||
return fmt.Errorf("Unable to generate test output: %s", err)
|
return fmt.Errorf("Unable to generate test output: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the test templates
|
// Generate the test templates
|
||||||
if includeTests {
|
if includeTests {
|
||||||
if err := generateOutput(c, data, true); err != nil {
|
if err := generateTestOutput(c, data); err != nil {
|
||||||
return fmt.Errorf("Unable to generate output: %s", err)
|
return fmt.Errorf("Unable to generate output: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -130,6 +164,11 @@ func (c *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string)
|
||||||
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = initTemplates(c)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to initialize templates: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,9 +34,23 @@ func init() {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
PkgName: "patrick",
|
PkgName: "patrick",
|
||||||
OutFolder: "",
|
OutFolder: "",
|
||||||
Interface: nil,
|
DriverName: "postgres",
|
||||||
|
Interface: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadTemplate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
template, err := loadTemplate("templates_test/main_test", "postgres_main.tpl")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to loadTemplate: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if template == nil {
|
||||||
|
t.Fatal("Unable to load template.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,6 +79,11 @@ func TestTemplates(t *testing.T) {
|
||||||
t.Errorf("Templates is empty.")
|
t.Errorf("Templates is empty.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmdData.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to initialize templates: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
cmdData.OutFolder, err = ioutil.TempDir("", "templates")
|
cmdData.OutFolder, err = ioutil.TempDir("", "templates")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to create tempdir: %s", err)
|
t.Fatalf("Unable to create tempdir: %s", err)
|
||||||
|
@ -78,7 +97,7 @@ func TestTemplates(t *testing.T) {
|
||||||
buf := bytes.Buffer{}
|
buf := bytes.Buffer{}
|
||||||
buf2 := bytes.Buffer{}
|
buf2 := bytes.Buffer{}
|
||||||
|
|
||||||
cmd := exec.Command("go", "test")
|
cmd := exec.Command("go", "test", "-c")
|
||||||
cmd.Dir = cmdData.OutFolder
|
cmd.Dir = cmdData.OutFolder
|
||||||
cmd.Stderr = &buf
|
cmd.Stderr = &buf
|
||||||
cmd.Stdout = &buf2
|
cmd.Stdout = &buf2
|
||||||
|
|
|
@ -257,7 +257,7 @@ func TestWherePrimaryKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
r := wherePrimaryKey(&test.Pkey, test.Start)
|
r := wherePrimaryKey(test.Pkey.Columns, test.Start)
|
||||||
if r != test.Should {
|
if r != test.Should {
|
||||||
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,6 +57,7 @@ func LoadConfigFile(filename string) error {
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
err := setup()
|
err := setup()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
os.Exit(-1)
|
os.Exit(-1)
|
||||||
}
|
}
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
|
@ -111,5 +112,9 @@ func setup() error {
|
||||||
|
|
||||||
err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, cfg.Postgres.DBName, cfg.Postgres.Host, cfg.Postgres.Port)
|
err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, cfg.Postgres.DBName, cfg.Postgres.Host, cfg.Postgres.Port)
|
||||||
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, cfg.TestPostgres.DBName))
|
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, cfg.TestPostgres.DBName))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
|
@ -10,19 +10,22 @@ import (
|
||||||
// CobraRunFunc declares the cobra.Command.Run function definition
|
// CobraRunFunc declares the cobra.Command.Run function definition
|
||||||
type CobraRunFunc func(cmd *cobra.Command, args []string)
|
type CobraRunFunc func(cmd *cobra.Command, args []string)
|
||||||
|
|
||||||
|
type templater []*template.Template
|
||||||
|
|
||||||
// CmdData holds the table schema a slice of (column name, column type) slices.
|
// CmdData holds the table schema a slice of (column name, column type) slices.
|
||||||
// It also holds a slice of all of the table names sqlboiler is generating against,
|
// It also holds a slice of all of the table names sqlboiler is generating against,
|
||||||
// the database driver chosen by the driver flag at runtime, and a pointer to the
|
// the database driver chosen by the driver flag at runtime, and a pointer to the
|
||||||
// output file, if one is specified with a flag.
|
// output file, if one is specified with a flag.
|
||||||
type CmdData struct {
|
type CmdData struct {
|
||||||
Tables []dbdrivers.Table
|
Tables []dbdrivers.Table
|
||||||
PkgName string
|
PkgName string
|
||||||
OutFolder string
|
OutFolder string
|
||||||
Interface dbdrivers.Interface
|
Interface dbdrivers.Interface
|
||||||
DriverName string
|
DriverName string
|
||||||
Config *Config
|
Config *Config
|
||||||
Templates []*template.Template
|
Templates templater
|
||||||
TestTemplates []*template.Template
|
TestTemplates templater
|
||||||
|
TestMainTemplate *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
// tplData is used to pass data to the template
|
// tplData is used to pass data to the template
|
||||||
|
|
7
main.go
7
main.go
|
@ -24,13 +24,6 @@ func main() {
|
||||||
os.Exit(-1)
|
os.Exit(-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load all templates
|
|
||||||
err = cmdData.LoadTemplates()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Failed to load templates: %s\n", err)
|
|
||||||
os.Exit(-1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up the cobra root command
|
// Set up the cobra root command
|
||||||
var rootCmd = &cobra.Command{
|
var rootCmd = &cobra.Command{
|
||||||
Use: "sqlboiler",
|
Use: "sqlboiler",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue