Added template sorter, MainTest separation
* Moved MainTests into a dedicated folder * Added sorter for templates (sorts structs to top) * Fixed all the broken tests * Tidied up output.go functions * Only creating one MainTest (main_test.go) per app run now * Split test imports up into normal test imports & main test imports
This commit is contained in:
parent
c84d35d394
commit
d3aeb7375d
13 changed files with 399 additions and 198 deletions
|
@ -8,33 +8,6 @@ import (
|
|||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
// sqlBoilerImports defines the list of default template imports.
|
||||
var sqlBoilerImports = imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
},
|
||||
}
|
||||
|
||||
// sqlBoilerTestImports defines the list of default test template imports.
|
||||
var sqlBoilerTestImports = imports{
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
`"os"`,
|
||||
`"os/exec"`,
|
||||
`"fmt"`,
|
||||
`"io/ioutil"`,
|
||||
`"bytes"`,
|
||||
`"errors"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/BurntSushi/toml"`,
|
||||
},
|
||||
}
|
||||
|
||||
// sqlBoilerTypeImports imports are only included in the template output if the database
|
||||
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
|
||||
var sqlBoilerTypeImports = map[string]imports{
|
||||
|
@ -58,12 +31,40 @@ var sqlBoilerTypeImports = map[string]imports{
|
|||
},
|
||||
}
|
||||
|
||||
// sqlBoilerDriverTestImports defines the test template imports
|
||||
// for the particular database interfaces
|
||||
var sqlBoilerDriverTestImports = map[string]imports{
|
||||
// sqlBoilerImports defines the list of default template imports.
|
||||
var sqlBoilerImports = imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
},
|
||||
}
|
||||
|
||||
// sqlBoilerTestImports defines the list of default test template imports.
|
||||
var sqlBoilerTestImports = imports{
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
},
|
||||
}
|
||||
|
||||
var sqlBoilerTestMainImports = map[string]imports{
|
||||
"postgres": imports{
|
||||
standard: importList{`"database/sql"`},
|
||||
thirdparty: importList{`_ "github.com/lib/pq"`},
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
`"os"`,
|
||||
`"os/exec"`,
|
||||
`"fmt"`,
|
||||
`"io/ioutil"`,
|
||||
`"bytes"`,
|
||||
`"errors"`,
|
||||
`"database/sql"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/BurntSushi/toml"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -4,28 +4,10 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
func (i importList) Len() int {
|
||||
return len(i)
|
||||
}
|
||||
|
||||
func (i importList) Swap(k, j int) {
|
||||
i[k], i[j] = i[j], i[k]
|
||||
}
|
||||
|
||||
func (i importList) Less(k, j int) bool {
|
||||
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
||||
if res <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func combineImports(a, b imports) imports {
|
||||
var c imports
|
||||
|
||||
|
@ -78,7 +60,7 @@ func buildImportString(imps imports) []byte {
|
|||
} else {
|
||||
imp = imps.thirdparty[0]
|
||||
}
|
||||
return []byte(fmt.Sprintf(`import %s`, imp))
|
||||
return []byte(fmt.Sprintf("import %s", imp))
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
|
|
@ -3,7 +3,6 @@ package cmds
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
|
@ -72,51 +71,6 @@ func TestCombineTypeImports(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSortImports(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a1 := importList{
|
||||
`"fmt"`,
|
||||
`"errors"`,
|
||||
}
|
||||
a2 := importList{
|
||||
`_ "github.com/lib/pq"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/mux"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
}
|
||||
|
||||
a1Expected := importList{`"errors"`, `"fmt"`}
|
||||
a2Expected := importList{
|
||||
`"github.com/gorilla/mux"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
}
|
||||
|
||||
sort.Sort(a1)
|
||||
if !reflect.DeepEqual(a1, a1Expected) {
|
||||
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
||||
}
|
||||
|
||||
for i, v := range a1 {
|
||||
if v != a1Expected[i] {
|
||||
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
|
||||
sort.Sort(a2)
|
||||
if !reflect.DeepEqual(a2, a2Expected) {
|
||||
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
||||
}
|
||||
|
||||
for i, v := range a2 {
|
||||
if v != a2Expected[i] {
|
||||
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineImports(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
163
cmds/output.go
163
cmds/output.go
|
@ -7,6 +7,7 @@ import (
|
|||
"go/format"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
|
@ -16,31 +17,20 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
|||
return file, err
|
||||
}
|
||||
|
||||
func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
||||
if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) {
|
||||
// generateOutput builds the file output and sends it to outHandler for saving
|
||||
func generateOutput(cmdData *CmdData, data *tplData) error {
|
||||
if len(cmdData.Templates) == 0 {
|
||||
return errors.New("No template files located for generation")
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
var imps imports
|
||||
var tpls []*template.Template
|
||||
|
||||
if testOutput {
|
||||
imps.standard = sqlBoilerTestImports.standard
|
||||
imps.thirdparty = sqlBoilerTestImports.thirdparty
|
||||
imps = combineImports(imps, sqlBoilerDriverTestImports[cmdData.DriverName])
|
||||
tpls = cmdData.TestTemplates
|
||||
} else {
|
||||
imps.standard = sqlBoilerImports.standard
|
||||
imps.thirdparty = sqlBoilerImports.thirdparty
|
||||
tpls = cmdData.Templates
|
||||
}
|
||||
imps.standard = sqlBoilerImports.standard
|
||||
imps.thirdparty = sqlBoilerImports.thirdparty
|
||||
|
||||
// Loop through and generate every individual template
|
||||
for _, template := range tpls {
|
||||
if !testOutput {
|
||||
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
|
||||
}
|
||||
for _, template := range cmdData.Templates {
|
||||
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
|
||||
resp, err := generateTemplate(template, data)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -48,7 +38,63 @@ func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
|||
out = append(out, resp)
|
||||
}
|
||||
|
||||
if err := outHandler(cmdData, out, data, imps, testOutput); err != nil {
|
||||
fName := data.Table.Name + ".go"
|
||||
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateTestOutput builds the test file output and sends it to outHandler for saving
|
||||
func generateTestOutput(cmdData *CmdData, data *tplData) error {
|
||||
if len(cmdData.TestTemplates) == 0 {
|
||||
return errors.New("No template files located for generation")
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
var imps imports
|
||||
|
||||
imps.standard = sqlBoilerTestImports.standard
|
||||
imps.thirdparty = sqlBoilerTestImports.thirdparty
|
||||
|
||||
for _, template := range cmdData.TestTemplates {
|
||||
resp, err := generateTemplate(template, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out = append(out, resp)
|
||||
}
|
||||
|
||||
fName := data.Table.Name + "_test.go"
|
||||
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateTestMainOutput(cmdData *CmdData) error {
|
||||
if cmdData.TestMainTemplate == nil {
|
||||
return errors.New("No TestMain template located for generation")
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
var imps imports
|
||||
|
||||
imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard
|
||||
imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty
|
||||
|
||||
resp, err := generateTemplate(cmdData.TestMainTemplate, &tplData{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out = append(out, resp)
|
||||
|
||||
err = outHandler(cmdData.OutFolder, "main_test.go", cmdData.PkgName, imps, out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -56,28 +102,71 @@ func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
|||
}
|
||||
|
||||
// outHandler loops over each template in the slice of byte slices and builds an output file.
|
||||
func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
|
||||
// func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
|
||||
// out := testHarnessStdout
|
||||
//
|
||||
// var path string
|
||||
//
|
||||
// if len(cmdData.OutFolder) != 0 {
|
||||
// if testTemplate {
|
||||
// path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
|
||||
// } else {
|
||||
// path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
|
||||
// }
|
||||
//
|
||||
// outFile, err := testHarnessFileOpen(path)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("Unable to create output file %s: %s", path, err)
|
||||
// }
|
||||
// defer outFile.Close()
|
||||
// out = outFile
|
||||
// }
|
||||
//
|
||||
// if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
|
||||
// return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
|
||||
// }
|
||||
//
|
||||
// impStr := buildImportString(imps)
|
||||
// if len(impStr) > 0 {
|
||||
// if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
|
||||
// return fmt.Errorf("Unable to write imports to file handle: %v", err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// for _, templateOutput := range output {
|
||||
// if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
||||
// return fmt.Errorf("Unable to write template output to file handle: %v", err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// func outHandler(cmdData *CmdData, data *tplData, imps imports, output [][]byte, testTemplate bool) error {
|
||||
// var fileName string
|
||||
// if testTemplate == true {
|
||||
// fileName = data.Table.Name + "_test.go"
|
||||
// } else {
|
||||
// fileName = data.Table.Name + ".go"
|
||||
// }
|
||||
//
|
||||
// outGenerator()
|
||||
// }
|
||||
|
||||
func outHandler(outFolder string, fileName string, pkgName string, imps imports, contents [][]byte) error {
|
||||
out := testHarnessStdout
|
||||
|
||||
var path string
|
||||
path := filepath.Join(outFolder, fileName)
|
||||
|
||||
if len(cmdData.OutFolder) != 0 {
|
||||
if testTemplate {
|
||||
path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
|
||||
} else {
|
||||
path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
|
||||
}
|
||||
|
||||
outFile, err := testHarnessFileOpen(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to create output file %s: %s", path, err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
out = outFile
|
||||
outFile, err := testHarnessFileOpen(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to create output file %s: %s", path, err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
out = outFile
|
||||
|
||||
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
|
||||
return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
|
||||
if _, err := fmt.Fprintf(out, "package %s\n\n", pkgName); err != nil {
|
||||
return fmt.Errorf("Unable to write package name %s to file: %s", pkgName, path)
|
||||
}
|
||||
|
||||
impStr := buildImportString(imps)
|
||||
|
@ -87,7 +176,7 @@ func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports,
|
|||
}
|
||||
}
|
||||
|
||||
for _, templateOutput := range output {
|
||||
for _, templateOutput := range contents {
|
||||
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
||||
return fmt.Errorf("Unable to write template output to file handle: %v", err)
|
||||
}
|
||||
|
|
|
@ -5,36 +5,8 @@ import (
|
|||
"io"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
func TestOutHandler(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
saveTestHarnessStdout := testHarnessStdout
|
||||
testHarnessStdout = buf
|
||||
defer func() {
|
||||
testHarnessStdout = saveTestHarnessStdout
|
||||
}()
|
||||
|
||||
data := tplData{
|
||||
Table: dbdrivers.Table{
|
||||
Name: "patrick",
|
||||
},
|
||||
}
|
||||
|
||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||
|
||||
if err := outHandler(&CmdData{PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
type NopWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
@ -47,6 +19,30 @@ func nopCloser(w io.Writer) io.WriteCloser {
|
|||
return NopWriteCloser{w}
|
||||
}
|
||||
|
||||
func TestOutHandler(t *testing.T) {
|
||||
// set the function pointer back to its original value
|
||||
// after we modify it for the test
|
||||
saveTestHarnessFileOpen := testHarnessFileOpen
|
||||
defer func() {
|
||||
testHarnessFileOpen = saveTestHarnessFileOpen
|
||||
}()
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
|
||||
return nopCloser(buf), nil
|
||||
}
|
||||
|
||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||
|
||||
if err := outHandler("", "file.go", "patrick", imports{}, templateOutputs); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutHandlerFiles(t *testing.T) {
|
||||
saveTestHarnessFileOpen := testHarnessFileOpen
|
||||
defer func() {
|
||||
|
@ -58,13 +54,9 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
return nopCloser(file), nil
|
||||
}
|
||||
|
||||
data := tplData{
|
||||
Table: dbdrivers.Table{Name: "patrick"},
|
||||
}
|
||||
|
||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
||||
if err := outHandler("folder", "file.go", "patrick", imports{}, templateOutputs); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||
|
@ -78,7 +70,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a1, false); err != nil {
|
||||
if err := outHandler("folder", "file.go", "patrick", a1, templateOutputs); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
|
||||
|
@ -92,7 +84,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a2, false); err != nil {
|
||||
if err := outHandler("folder", "file.go", "patrick", a2, templateOutputs); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
||||
|
@ -116,7 +108,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
sort.Sort(a3.standard)
|
||||
sort.Sort(a3.thirdparty)
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a3, false); err != nil {
|
||||
if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
|
|
42
cmds/sorters.go
Normal file
42
cmds/sorters.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package cmds
|
||||
|
||||
import "strings"
|
||||
|
||||
func (i importList) Len() int {
|
||||
return len(i)
|
||||
}
|
||||
|
||||
func (i importList) Swap(k, j int) {
|
||||
i[k], i[j] = i[j], i[k]
|
||||
}
|
||||
|
||||
func (i importList) Less(k, j int) bool {
|
||||
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
||||
if res <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t templater) Len() int {
|
||||
return len(t)
|
||||
}
|
||||
|
||||
func (t templater) Swap(k, j int) {
|
||||
t[k], t[j] = t[j], t[k]
|
||||
}
|
||||
|
||||
func (t templater) Less(k, j int) bool {
|
||||
// Make sure "struct" goes to the front
|
||||
if t[k].Name() == "struct.tpl" {
|
||||
return true
|
||||
}
|
||||
|
||||
res := strings.Compare(t[k].Name(), t[j].Name())
|
||||
if res <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
82
cmds/sorters_test.go
Normal file
82
cmds/sorters_test.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
func TestSortImports(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a1 := importList{
|
||||
`"fmt"`,
|
||||
`"errors"`,
|
||||
}
|
||||
a2 := importList{
|
||||
`_ "github.com/lib/pq"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/mux"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
}
|
||||
|
||||
a1Expected := importList{`"errors"`, `"fmt"`}
|
||||
a2Expected := importList{
|
||||
`"github.com/gorilla/mux"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
}
|
||||
|
||||
sort.Sort(a1)
|
||||
if !reflect.DeepEqual(a1, a1Expected) {
|
||||
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
||||
}
|
||||
|
||||
for i, v := range a1 {
|
||||
if v != a1Expected[i] {
|
||||
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
|
||||
sort.Sort(a2)
|
||||
if !reflect.DeepEqual(a2, a2Expected) {
|
||||
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
||||
}
|
||||
|
||||
for i, v := range a2 {
|
||||
if v != a2Expected[i] {
|
||||
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortTemplates(t *testing.T) {
|
||||
templs := templater{
|
||||
template.New("bob.tpl"),
|
||||
template.New("all.tpl"),
|
||||
template.New("struct.tpl"),
|
||||
template.New("ttt.tpl"),
|
||||
}
|
||||
|
||||
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
|
||||
|
||||
for i, v := range templs {
|
||||
if v.Name() != expected[i] {
|
||||
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||
}
|
||||
}
|
||||
|
||||
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
|
||||
|
||||
sort.Sort(templs)
|
||||
|
||||
for i, v := range templs {
|
||||
if v.Name() != expected[i] {
|
||||
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||
}
|
||||
}
|
||||
|
||||
t.Errorf("cant sort templates")
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
|
@ -13,19 +14,26 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
templatesDirectory = "/cmds/templates"
|
||||
templatesTestDirectory = "/cmds/templates_test"
|
||||
templatesDirectory = "/cmds/templates"
|
||||
templatesTestDirectory = "/cmds/templates_test"
|
||||
templatesTestMainDirectory = "/cmds/templates_test/main_test"
|
||||
)
|
||||
|
||||
// LoadTemplates loads all template folders into the cmdData object.
|
||||
func (c *CmdData) LoadTemplates() error {
|
||||
func initTemplates(cmdData *CmdData) error {
|
||||
var err error
|
||||
c.Templates, err = loadTemplates(templatesDirectory)
|
||||
cmdData.Templates, err = loadTemplates(templatesDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
||||
cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filename := cmdData.DriverName + "_main.tpl"
|
||||
cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -47,7 +55,27 @@ func loadTemplates(dir string) ([]*template.Template, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return tpl.Templates(), err
|
||||
templates := templater(tpl.Templates())
|
||||
sort.Sort(templates)
|
||||
|
||||
return templates, err
|
||||
}
|
||||
|
||||
// loadTemplate loads a single template file.
|
||||
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pattern := filepath.Join(wd, dir, filename)
|
||||
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseFiles(pattern)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tpl.Lookup(filename), err
|
||||
}
|
||||
|
||||
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
|
||||
|
@ -84,6 +112,12 @@ func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
|||
|
||||
// run executes the sqlboiler templates and outputs them to files.
|
||||
func (c *CmdData) run(includeTests bool) error {
|
||||
if includeTests {
|
||||
if err := generateTestMainOutput(c); err != nil {
|
||||
return fmt.Errorf("Unable to generate TestMain output: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, table := range c.Tables {
|
||||
data := &tplData{
|
||||
Table: table,
|
||||
|
@ -91,13 +125,13 @@ func (c *CmdData) run(includeTests bool) error {
|
|||
}
|
||||
|
||||
// Generate the regular templates
|
||||
if err := generateOutput(c, data, false); err != nil {
|
||||
if err := generateOutput(c, data); err != nil {
|
||||
return fmt.Errorf("Unable to generate test output: %s", err)
|
||||
}
|
||||
|
||||
// Generate the test templates
|
||||
if includeTests {
|
||||
if err := generateOutput(c, data, true); err != nil {
|
||||
if err := generateTestOutput(c, data); err != nil {
|
||||
return fmt.Errorf("Unable to generate output: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -130,6 +164,11 @@ func (c *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string)
|
|||
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||
}
|
||||
|
||||
err = initTemplates(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to initialize templates: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -34,9 +34,23 @@ func init() {
|
|||
},
|
||||
},
|
||||
},
|
||||
PkgName: "patrick",
|
||||
OutFolder: "",
|
||||
Interface: nil,
|
||||
PkgName: "patrick",
|
||||
OutFolder: "",
|
||||
DriverName: "postgres",
|
||||
Interface: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadTemplate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
template, err := loadTemplate("templates_test/main_test", "postgres_main.tpl")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to loadTemplate: %s", err)
|
||||
}
|
||||
|
||||
if template == nil {
|
||||
t.Fatal("Unable to load template.")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,6 +79,11 @@ func TestTemplates(t *testing.T) {
|
|||
t.Errorf("Templates is empty.")
|
||||
}
|
||||
|
||||
cmdData.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to initialize templates: %s", err)
|
||||
}
|
||||
|
||||
cmdData.OutFolder, err = ioutil.TempDir("", "templates")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create tempdir: %s", err)
|
||||
|
@ -78,7 +97,7 @@ func TestTemplates(t *testing.T) {
|
|||
buf := bytes.Buffer{}
|
||||
buf2 := bytes.Buffer{}
|
||||
|
||||
cmd := exec.Command("go", "test")
|
||||
cmd := exec.Command("go", "test", "-c")
|
||||
cmd.Dir = cmdData.OutFolder
|
||||
cmd.Stderr = &buf
|
||||
cmd.Stdout = &buf2
|
||||
|
|
|
@ -257,7 +257,7 @@ func TestWherePrimaryKey(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, test := range tests {
|
||||
r := wherePrimaryKey(&test.Pkey, test.Start)
|
||||
r := wherePrimaryKey(test.Pkey.Columns, test.Start)
|
||||
if r != test.Should {
|
||||
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||
}
|
||||
|
|
|
@ -57,6 +57,7 @@ func LoadConfigFile(filename string) error {
|
|||
func TestMain(m *testing.M) {
|
||||
err := setup()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
code := m.Run()
|
||||
|
@ -111,5 +112,9 @@ func setup() error {
|
|||
|
||||
err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, cfg.Postgres.DBName, cfg.Postgres.Host, cfg.Postgres.Port)
|
||||
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, cfg.TestPostgres.DBName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -10,19 +10,22 @@ import (
|
|||
// CobraRunFunc declares the cobra.Command.Run function definition
|
||||
type CobraRunFunc func(cmd *cobra.Command, args []string)
|
||||
|
||||
type templater []*template.Template
|
||||
|
||||
// CmdData holds the table schema a slice of (column name, column type) slices.
|
||||
// It also holds a slice of all of the table names sqlboiler is generating against,
|
||||
// the database driver chosen by the driver flag at runtime, and a pointer to the
|
||||
// output file, if one is specified with a flag.
|
||||
type CmdData struct {
|
||||
Tables []dbdrivers.Table
|
||||
PkgName string
|
||||
OutFolder string
|
||||
Interface dbdrivers.Interface
|
||||
DriverName string
|
||||
Config *Config
|
||||
Templates []*template.Template
|
||||
TestTemplates []*template.Template
|
||||
Tables []dbdrivers.Table
|
||||
PkgName string
|
||||
OutFolder string
|
||||
Interface dbdrivers.Interface
|
||||
DriverName string
|
||||
Config *Config
|
||||
Templates templater
|
||||
TestTemplates templater
|
||||
TestMainTemplate *template.Template
|
||||
}
|
||||
|
||||
// tplData is used to pass data to the template
|
||||
|
|
7
main.go
7
main.go
|
@ -24,13 +24,6 @@ func main() {
|
|||
os.Exit(-1)
|
||||
}
|
||||
|
||||
// Load all templates
|
||||
err = cmdData.LoadTemplates()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to load templates: %s\n", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
// Set up the cobra root command
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "sqlboiler",
|
||||
|
|
Loading…
Add table
Reference in a new issue