Did more stripping, fixed broken things
* Fixed broken template * Did some reorganizing * Need to fix TestTemplates test
This commit is contained in:
parent
27cafdd2fb
commit
f7a4ed0c54
10 changed files with 348 additions and 354 deletions
|
@ -8,7 +8,7 @@ import (
|
|||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
// sqlBoilerDefaultImports defines the list of default template imports.
|
||||
// sqlBoilerImports defines the list of default template imports.
|
||||
var sqlBoilerImports = imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
|
@ -19,7 +19,7 @@ var sqlBoilerImports = imports{
|
|||
},
|
||||
}
|
||||
|
||||
// sqlBoilerDefaultTestImports defines the list of default test template imports.
|
||||
// sqlBoilerTestImports defines the list of default test template imports.
|
||||
var sqlBoilerTestImports = imports{
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
|
@ -58,7 +58,7 @@ var sqlBoilerTypeImports = map[string]imports{
|
|||
},
|
||||
}
|
||||
|
||||
// sqlBoilerConditionalDriverTestImports defines the test template imports
|
||||
// sqlBoilerDriverTestImports defines the test template imports
|
||||
// for the particular database interfaces
|
||||
var sqlBoilerDriverTestImports = map[string]imports{
|
||||
"postgres": imports{
|
||||
|
|
|
@ -96,3 +96,37 @@ func buildImportString(imps imports) []byte {
|
|||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func combineStringSlices(a, b []string) []string {
|
||||
c := make([]string, len(a)+len(b))
|
||||
if len(a) > 0 {
|
||||
copy(c, a)
|
||||
}
|
||||
if len(b) > 0 {
|
||||
copy(c[len(a):], b)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func removeDuplicates(dedup []string) []string {
|
||||
if len(dedup) <= 1 {
|
||||
return dedup
|
||||
}
|
||||
|
||||
for i := 0; i < len(dedup)-1; i++ {
|
||||
for j := i + 1; j < len(dedup); j++ {
|
||||
if dedup[i] != dedup[j] {
|
||||
continue
|
||||
}
|
||||
|
||||
if j != len(dedup)-1 {
|
||||
dedup[j] = dedup[len(dedup)-1]
|
||||
j--
|
||||
}
|
||||
dedup = dedup[:len(dedup)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return dedup
|
||||
}
|
|
@ -1,9 +1,7 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
@ -11,135 +9,66 @@ import (
|
|||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
func TestOutHandler(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
saveTestHarnessStdout := testHarnessStdout
|
||||
testHarnessStdout = buf
|
||||
defer func() {
|
||||
testHarnessStdout = saveTestHarnessStdout
|
||||
}()
|
||||
|
||||
data := tplData{
|
||||
Table: dbdrivers.Table{
|
||||
Name: "patrick",
|
||||
},
|
||||
}
|
||||
|
||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||
|
||||
if err := outHandler(&CmdData{PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
type NopWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (NopWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func nopCloser(w io.Writer) io.WriteCloser {
|
||||
return NopWriteCloser{w}
|
||||
}
|
||||
|
||||
func TestOutHandlerFiles(t *testing.T) {
|
||||
saveTestHarnessFileOpen := testHarnessFileOpen
|
||||
defer func() {
|
||||
testHarnessFileOpen = saveTestHarnessFileOpen
|
||||
}()
|
||||
|
||||
file := &bytes.Buffer{}
|
||||
testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
|
||||
return nopCloser(file), nil
|
||||
}
|
||||
|
||||
data := tplData{
|
||||
Table: dbdrivers.Table{Name: "patrick"},
|
||||
}
|
||||
|
||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
|
||||
a1 := imports{
|
||||
func TestCombineTypeImports(t *testing.T) {
|
||||
imports1 := imports{
|
||||
standard: importList{
|
||||
`"fmt"`,
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a1, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
|
||||
a2 := imports{
|
||||
thirdparty: []string{
|
||||
`"github.com/spf13/cobra"`,
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a2, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
|
||||
a3 := imports{
|
||||
standard: importList{
|
||||
`"fmt"`,
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`_ "github.com/lib/pq"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/mux"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
sort.Sort(a3.standard)
|
||||
sort.Sort(a3.thirdparty)
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a3, false); err != nil {
|
||||
t.Error(err)
|
||||
importsExpected := imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
`"time"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
`"gopkg.in/guregu/null.v3"`,
|
||||
},
|
||||
}
|
||||
|
||||
expectedOut := `package patrick
|
||||
cols := []dbdrivers.Column{
|
||||
dbdrivers.Column{
|
||||
Type: "null.Time",
|
||||
},
|
||||
dbdrivers.Column{
|
||||
Type: "null.Time",
|
||||
},
|
||||
dbdrivers.Column{
|
||||
Type: "time.Time",
|
||||
},
|
||||
dbdrivers.Column{
|
||||
Type: "null.Float",
|
||||
},
|
||||
}
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols)
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
_ "github.com/gorilla/n"
|
||||
"github.com/gorilla/websocket"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
if !reflect.DeepEqual(res1, importsExpected) {
|
||||
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
|
||||
}
|
||||
|
||||
hello world
|
||||
patrick's dreams
|
||||
`
|
||||
imports2 := imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
`"time"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
`"gopkg.in/guregu/null.v3"`,
|
||||
},
|
||||
}
|
||||
|
||||
if out := file.String(); out != expectedOut {
|
||||
t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
|
||||
res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols)
|
||||
|
||||
if !reflect.DeepEqual(res2, importsExpected) {
|
||||
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,71 +0,0 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
func TestCombineTypeImports(t *testing.T) {
|
||||
imports1 := imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
},
|
||||
}
|
||||
|
||||
importsExpected := imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
`"time"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
`"gopkg.in/guregu/null.v3"`,
|
||||
},
|
||||
}
|
||||
|
||||
cols := []dbdrivers.Column{
|
||||
dbdrivers.Column{
|
||||
Type: "null.Time",
|
||||
},
|
||||
dbdrivers.Column{
|
||||
Type: "null.Time",
|
||||
},
|
||||
dbdrivers.Column{
|
||||
Type: "time.Time",
|
||||
},
|
||||
dbdrivers.Column{
|
||||
Type: "null.Float",
|
||||
},
|
||||
}
|
||||
|
||||
res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols)
|
||||
|
||||
if !reflect.DeepEqual(res1, importsExpected) {
|
||||
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
|
||||
}
|
||||
|
||||
imports2 := imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
`"time"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
`"gopkg.in/guregu/null.v3"`,
|
||||
},
|
||||
}
|
||||
|
||||
res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols)
|
||||
|
||||
if !reflect.DeepEqual(res2, importsExpected) {
|
||||
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
|
||||
}
|
||||
}
|
112
cmds/output.go
Normal file
112
cmds/output.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io"
|
||||
"os"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
var testHarnessStdout io.Writer = os.Stdout
|
||||
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
||||
file, err := os.Create(filename)
|
||||
return file, err
|
||||
}
|
||||
|
||||
func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
|
||||
if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) {
|
||||
fmt.Println("No template files located for generation")
|
||||
return nil
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
var imps imports
|
||||
var tpls []*template.Template
|
||||
|
||||
if testOutput {
|
||||
imps.standard = sqlBoilerTestImports.standard
|
||||
imps.thirdparty = sqlBoilerTestImports.thirdparty
|
||||
imps = combineImports(imps, sqlBoilerDriverTestImports[cmdData.DriverName])
|
||||
tpls = cmdData.TestTemplates
|
||||
} else {
|
||||
imps.standard = sqlBoilerImports.standard
|
||||
imps.thirdparty = sqlBoilerImports.thirdparty
|
||||
tpls = cmdData.Templates
|
||||
}
|
||||
|
||||
// Loop through and generate every individual template
|
||||
for _, template := range tpls {
|
||||
if !testOutput {
|
||||
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
|
||||
}
|
||||
resp, err := generateTemplate(template, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out = append(out, resp)
|
||||
}
|
||||
|
||||
if err := outHandler(cmdData, out, data, imps, testOutput); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// outHandler loops over each template in the slice of byte slices and builds an output file.
|
||||
func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
|
||||
out := testHarnessStdout
|
||||
|
||||
var path string
|
||||
|
||||
if len(cmdData.OutFolder) != 0 {
|
||||
if testTemplate {
|
||||
path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
|
||||
} else {
|
||||
path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
|
||||
}
|
||||
|
||||
outFile, err := testHarnessFileOpen(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to create output file %s: %s", path, err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
out = outFile
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
|
||||
return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
|
||||
}
|
||||
|
||||
impStr := buildImportString(imps)
|
||||
if len(impStr) > 0 {
|
||||
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
|
||||
return fmt.Errorf("Unable to write imports to file handle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, templateOutput := range output {
|
||||
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
||||
return fmt.Errorf("Unable to write template output to file handle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateTemplate takes a template and returns the output of the template execution.
|
||||
func generateTemplate(t *template.Template, data *tplData) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := t.Execute(&buf, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
output, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
142
cmds/output_test.go
Normal file
142
cmds/output_test.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
func TestOutHandler(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
saveTestHarnessStdout := testHarnessStdout
|
||||
testHarnessStdout = buf
|
||||
defer func() {
|
||||
testHarnessStdout = saveTestHarnessStdout
|
||||
}()
|
||||
|
||||
data := tplData{
|
||||
Table: dbdrivers.Table{
|
||||
Name: "patrick",
|
||||
},
|
||||
}
|
||||
|
||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||
|
||||
if err := outHandler(&CmdData{PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
type NopWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (NopWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func nopCloser(w io.Writer) io.WriteCloser {
|
||||
return NopWriteCloser{w}
|
||||
}
|
||||
|
||||
func TestOutHandlerFiles(t *testing.T) {
|
||||
saveTestHarnessFileOpen := testHarnessFileOpen
|
||||
defer func() {
|
||||
testHarnessFileOpen = saveTestHarnessFileOpen
|
||||
}()
|
||||
|
||||
file := &bytes.Buffer{}
|
||||
testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
|
||||
return nopCloser(file), nil
|
||||
}
|
||||
|
||||
data := tplData{
|
||||
Table: dbdrivers.Table{Name: "patrick"},
|
||||
}
|
||||
|
||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
|
||||
a1 := imports{
|
||||
standard: importList{
|
||||
`"fmt"`,
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a1, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
|
||||
a2 := imports{
|
||||
thirdparty: []string{
|
||||
`"github.com/spf13/cobra"`,
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a2, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
||||
t.Errorf("Wrong output: %q", out)
|
||||
}
|
||||
|
||||
a3 := imports{
|
||||
standard: importList{
|
||||
`"fmt"`,
|
||||
`"errors"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`_ "github.com/lib/pq"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/mux"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
},
|
||||
}
|
||||
file = &bytes.Buffer{}
|
||||
|
||||
sort.Sort(a3.standard)
|
||||
sort.Sort(a3.thirdparty)
|
||||
|
||||
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a3, false); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
expectedOut := `package patrick
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
_ "github.com/gorilla/n"
|
||||
"github.com/gorilla/websocket"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
hello world
|
||||
patrick's dreams
|
||||
`
|
||||
|
||||
if out := file.String(); out != expectedOut {
|
||||
t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
|
||||
}
|
||||
}
|
|
@ -1,88 +0,0 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
var testHarnessStdout io.Writer = os.Stdout
|
||||
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
||||
file, err := os.Create(filename)
|
||||
return file, err
|
||||
}
|
||||
|
||||
// outHandler loops over the slice of byte slices, outputting them to either
|
||||
// the OutFile if it is specified with a flag, or to Stdout if no flag is specified.
|
||||
func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
|
||||
out := testHarnessStdout
|
||||
|
||||
var path string
|
||||
if len(cmdData.OutFolder) != 0 {
|
||||
if testTemplate {
|
||||
path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
|
||||
} else {
|
||||
path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
|
||||
}
|
||||
|
||||
outFile, err := testHarnessFileOpen(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to create output file %s: %s", path, err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
out = outFile
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
|
||||
return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
|
||||
}
|
||||
|
||||
impStr := buildImportString(imps)
|
||||
if len(impStr) > 0 {
|
||||
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
|
||||
return fmt.Errorf("Unable to write imports to file handle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, templateOutput := range output {
|
||||
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
||||
return fmt.Errorf("Unable to write template output to file handle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func combineStringSlices(a, b []string) []string {
|
||||
c := make([]string, len(a)+len(b))
|
||||
if len(a) > 0 {
|
||||
copy(c, a)
|
||||
}
|
||||
if len(b) > 0 {
|
||||
copy(c[len(a):], b)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func removeDuplicates(dedup []string) []string {
|
||||
if len(dedup) <= 1 {
|
||||
return dedup
|
||||
}
|
||||
|
||||
for i := 0; i < len(dedup)-1; i++ {
|
||||
for j := i + 1; j < len(dedup); j++ {
|
||||
if dedup[i] != dedup[j] {
|
||||
continue
|
||||
}
|
||||
|
||||
if j != len(dedup)-1 {
|
||||
dedup[j] = dedup[len(dedup)-1]
|
||||
j--
|
||||
}
|
||||
dedup = dedup[:len(dedup)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return dedup
|
||||
}
|
|
@ -94,50 +94,14 @@ func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
|||
PkgName: cmdData.PkgName,
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
var imps imports
|
||||
|
||||
imps.standard = sqlBoilerImports.standard
|
||||
imps.thirdparty = sqlBoilerImports.thirdparty
|
||||
|
||||
// Loop through and generate every command template (excluding skipTemplates)
|
||||
for _, template := range cmdData.Templates {
|
||||
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
|
||||
resp, err := generateTemplate(template, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out = append(out, resp)
|
||||
// Generate the regular templates
|
||||
if err := generateOutput(cmdData, data, false); err != nil {
|
||||
return fmt.Errorf("Unable to generate test output: %s", err)
|
||||
}
|
||||
|
||||
err := outHandler(cmdData, out, data, imps, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate the test templates for all commands
|
||||
if len(cmdData.TestTemplates) != 0 {
|
||||
var testOut [][]byte
|
||||
var testImps imports
|
||||
|
||||
testImps.standard = sqlBoilerTestImports.standard
|
||||
testImps.thirdparty = sqlBoilerTestImports.thirdparty
|
||||
|
||||
testImps = combineImports(testImps, sqlBoilerDriverTestImports[cmdData.DriverName])
|
||||
|
||||
// Loop through and generate every command test template (excluding skipTemplates)
|
||||
for _, template := range cmdData.TestTemplates {
|
||||
resp, err := generateTemplate(template, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
testOut = append(testOut, resp)
|
||||
}
|
||||
|
||||
err = outHandler(cmdData, testOut, data, testImps, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Generate the test templates
|
||||
if err := generateOutput(cmdData, data, true); err != nil {
|
||||
return fmt.Errorf("Unable to generate output: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,41 +1,13 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
// generateTemplate generates the template associated to the passed in command name.
|
||||
func generateTemplate(template *template.Template, data *tplData) ([]byte, error) {
|
||||
output, err := processTemplate(template, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to process the template %s for table %s: %s", template.Name(), data.Table.Name, err)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// processTemplate takes a template and returns the output of the template execution.
|
||||
func processTemplate(t *template.Template, data *tplData) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := t.Execute(&buf, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
output, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// plural converts singular words to plural words (eg: person to people)
|
||||
func plural(name string) string {
|
||||
splits := strings.Split(name, "_")
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{{- $tableNameSingular := titleCaseSingular .Table -}}
|
||||
{{- $dbName := singular .Table -}}
|
||||
{{- $tableNamePlural := titleCasePlural .Table -}}
|
||||
{{- $varNamePlural := camelCasePlural .Table -}}
|
||||
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
|
||||
{{- $dbName := singular .Table.Name -}}
|
||||
{{- $tableNamePlural := titleCasePlural .Table.Name -}}
|
||||
{{- $varNamePlural := camelCasePlural .Table.Name -}}
|
||||
// {{$tableNamePlural}}All retrieves all records.
|
||||
func Test{{$tableNamePlural}}All(t *testing.T) {
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue