Did more stripping, fixed broken things

* Fixed broken template
* Did some reorganizing
* Need to fix TestTemplates test
This commit is contained in:
Patrick O'brien 2016-03-28 18:17:41 +10:00
parent 27cafdd2fb
commit f7a4ed0c54
10 changed files with 348 additions and 354 deletions

View file

@ -8,7 +8,7 @@ import (
"github.com/BurntSushi/toml"
)
// sqlBoilerDefaultImports defines the list of default template imports.
// sqlBoilerImports defines the list of default template imports.
var sqlBoilerImports = imports{
standard: importList{
`"errors"`,
@ -19,7 +19,7 @@ var sqlBoilerImports = imports{
},
}
// sqlBoilerDefaultTestImports defines the list of default test template imports.
// sqlBoilerTestImports defines the list of default test template imports.
var sqlBoilerTestImports = imports{
standard: importList{
`"testing"`,
@ -58,7 +58,7 @@ var sqlBoilerTypeImports = map[string]imports{
},
}
// sqlBoilerConditionalDriverTestImports defines the test template imports
// sqlBoilerDriverTestImports defines the test template imports
// for the particular database interfaces
var sqlBoilerDriverTestImports = map[string]imports{
"postgres": imports{

View file

@ -96,3 +96,37 @@ func buildImportString(imps imports) []byte {
return buf.Bytes()
}
func combineStringSlices(a, b []string) []string {
c := make([]string, len(a)+len(b))
if len(a) > 0 {
copy(c, a)
}
if len(b) > 0 {
copy(c[len(a):], b)
}
return c
}
func removeDuplicates(dedup []string) []string {
if len(dedup) <= 1 {
return dedup
}
for i := 0; i < len(dedup)-1; i++ {
for j := i + 1; j < len(dedup); j++ {
if dedup[i] != dedup[j] {
continue
}
if j != len(dedup)-1 {
dedup[j] = dedup[len(dedup)-1]
j--
}
dedup = dedup[:len(dedup)-1]
}
}
return dedup
}

View file

@ -1,9 +1,7 @@
package cmds
import (
"bytes"
"fmt"
"io"
"reflect"
"sort"
"testing"
@ -11,135 +9,66 @@ import (
"github.com/pobri19/sqlboiler/dbdrivers"
)
func TestOutHandler(t *testing.T) {
buf := &bytes.Buffer{}
saveTestHarnessStdout := testHarnessStdout
testHarnessStdout = buf
defer func() {
testHarnessStdout = saveTestHarnessStdout
}()
data := tplData{
Table: dbdrivers.Table{
Name: "patrick",
},
}
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler(&CmdData{PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
t.Error(err)
}
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
}
type NopWriteCloser struct {
io.Writer
}
func (NopWriteCloser) Close() error {
return nil
}
func nopCloser(w io.Writer) io.WriteCloser {
return NopWriteCloser{w}
}
func TestOutHandlerFiles(t *testing.T) {
saveTestHarnessFileOpen := testHarnessFileOpen
defer func() {
testHarnessFileOpen = saveTestHarnessFileOpen
}()
file := &bytes.Buffer{}
testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
return nopCloser(file), nil
}
data := tplData{
Table: dbdrivers.Table{Name: "patrick"},
}
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a1 := imports{
func TestCombineTypeImports(t *testing.T) {
imports1 := imports{
standard: importList{
`"fmt"`,
},
}
file = &bytes.Buffer{}
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a1, false); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a2 := imports{
thirdparty: []string{
`"github.com/spf13/cobra"`,
},
}
file = &bytes.Buffer{}
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a2, false); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a3 := imports{
standard: importList{
`"fmt"`,
`"errors"`,
`"fmt"`,
},
thirdparty: importList{
`_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`,
`"github.com/gorilla/websocket"`,
`"github.com/pobri19/sqlboiler/boil"`,
},
}
file = &bytes.Buffer{}
sort.Sort(a3.standard)
sort.Sort(a3.thirdparty)
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a3, false); err != nil {
t.Error(err)
importsExpected := imports{
standard: importList{
`"errors"`,
`"fmt"`,
`"time"`,
},
thirdparty: importList{
`"github.com/pobri19/sqlboiler/boil"`,
`"gopkg.in/guregu/null.v3"`,
},
}
expectedOut := `package patrick
cols := []dbdrivers.Column{
dbdrivers.Column{
Type: "null.Time",
},
dbdrivers.Column{
Type: "null.Time",
},
dbdrivers.Column{
Type: "time.Time",
},
dbdrivers.Column{
Type: "null.Float",
},
}
import (
"errors"
"fmt"
res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols)
"github.com/gorilla/mux"
_ "github.com/gorilla/n"
"github.com/gorilla/websocket"
_ "github.com/lib/pq"
)
if !reflect.DeepEqual(res1, importsExpected) {
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
}
hello world
patrick's dreams
`
imports2 := imports{
standard: importList{
`"errors"`,
`"fmt"`,
`"time"`,
},
thirdparty: importList{
`"github.com/pobri19/sqlboiler/boil"`,
`"gopkg.in/guregu/null.v3"`,
},
}
if out := file.String(); out != expectedOut {
t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols)
if !reflect.DeepEqual(res2, importsExpected) {
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
}
}

View file

@ -1,71 +0,0 @@
package cmds
import (
"reflect"
"testing"
"github.com/pobri19/sqlboiler/dbdrivers"
)
func TestCombineTypeImports(t *testing.T) {
imports1 := imports{
standard: importList{
`"errors"`,
`"fmt"`,
},
thirdparty: importList{
`"github.com/pobri19/sqlboiler/boil"`,
},
}
importsExpected := imports{
standard: importList{
`"errors"`,
`"fmt"`,
`"time"`,
},
thirdparty: importList{
`"github.com/pobri19/sqlboiler/boil"`,
`"gopkg.in/guregu/null.v3"`,
},
}
cols := []dbdrivers.Column{
dbdrivers.Column{
Type: "null.Time",
},
dbdrivers.Column{
Type: "null.Time",
},
dbdrivers.Column{
Type: "time.Time",
},
dbdrivers.Column{
Type: "null.Float",
},
}
res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols)
if !reflect.DeepEqual(res1, importsExpected) {
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
}
imports2 := imports{
standard: importList{
`"errors"`,
`"fmt"`,
`"time"`,
},
thirdparty: importList{
`"github.com/pobri19/sqlboiler/boil"`,
`"gopkg.in/guregu/null.v3"`,
},
}
res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols)
if !reflect.DeepEqual(res2, importsExpected) {
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
}
}

112
cmds/output.go Normal file
View file

@ -0,0 +1,112 @@
package cmds
import (
"bytes"
"fmt"
"go/format"
"io"
"os"
"text/template"
)
var testHarnessStdout io.Writer = os.Stdout
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
file, err := os.Create(filename)
return file, err
}
func generateOutput(cmdData *CmdData, data *tplData, testOutput bool) error {
if (testOutput && len(cmdData.TestTemplates) == 0) || (!testOutput && len(cmdData.Templates) == 0) {
fmt.Println("No template files located for generation")
return nil
}
var out [][]byte
var imps imports
var tpls []*template.Template
if testOutput {
imps.standard = sqlBoilerTestImports.standard
imps.thirdparty = sqlBoilerTestImports.thirdparty
imps = combineImports(imps, sqlBoilerDriverTestImports[cmdData.DriverName])
tpls = cmdData.TestTemplates
} else {
imps.standard = sqlBoilerImports.standard
imps.thirdparty = sqlBoilerImports.thirdparty
tpls = cmdData.Templates
}
// Loop through and generate every individual template
for _, template := range tpls {
if !testOutput {
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
}
resp, err := generateTemplate(template, data)
if err != nil {
return err
}
out = append(out, resp)
}
if err := outHandler(cmdData, out, data, imps, testOutput); err != nil {
return err
}
return nil
}
// outHandler loops over each template in the slice of byte slices and builds an output file.
func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
out := testHarnessStdout
var path string
if len(cmdData.OutFolder) != 0 {
if testTemplate {
path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
} else {
path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
}
outFile, err := testHarnessFileOpen(path)
if err != nil {
return fmt.Errorf("Unable to create output file %s: %s", path, err)
}
defer outFile.Close()
out = outFile
}
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
}
impStr := buildImportString(imps)
if len(impStr) > 0 {
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
return fmt.Errorf("Unable to write imports to file handle: %v", err)
}
}
for _, templateOutput := range output {
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
return fmt.Errorf("Unable to write template output to file handle: %v", err)
}
}
return nil
}
// generateTemplate takes a template and returns the output of the template execution.
func generateTemplate(t *template.Template, data *tplData) ([]byte, error) {
var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil {
return nil, err
}
output, err := format.Source(buf.Bytes())
if err != nil {
return nil, err
}
return output, nil
}

142
cmds/output_test.go Normal file
View file

@ -0,0 +1,142 @@
package cmds
import (
"bytes"
"io"
"sort"
"testing"
"github.com/pobri19/sqlboiler/dbdrivers"
)
func TestOutHandler(t *testing.T) {
buf := &bytes.Buffer{}
saveTestHarnessStdout := testHarnessStdout
testHarnessStdout = buf
defer func() {
testHarnessStdout = saveTestHarnessStdout
}()
data := tplData{
Table: dbdrivers.Table{
Name: "patrick",
},
}
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler(&CmdData{PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
t.Error(err)
}
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
}
type NopWriteCloser struct {
io.Writer
}
func (NopWriteCloser) Close() error {
return nil
}
func nopCloser(w io.Writer) io.WriteCloser {
return NopWriteCloser{w}
}
func TestOutHandlerFiles(t *testing.T) {
saveTestHarnessFileOpen := testHarnessFileOpen
defer func() {
testHarnessFileOpen = saveTestHarnessFileOpen
}()
file := &bytes.Buffer{}
testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
return nopCloser(file), nil
}
data := tplData{
Table: dbdrivers.Table{Name: "patrick"},
}
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a1 := imports{
standard: importList{
`"fmt"`,
},
}
file = &bytes.Buffer{}
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a1, false); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a2 := imports{
thirdparty: []string{
`"github.com/spf13/cobra"`,
},
}
file = &bytes.Buffer{}
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a2, false); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a3 := imports{
standard: importList{
`"fmt"`,
`"errors"`,
},
thirdparty: importList{
`_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`,
`"github.com/gorilla/websocket"`,
},
}
file = &bytes.Buffer{}
sort.Sort(a3.standard)
sort.Sort(a3.thirdparty)
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a3, false); err != nil {
t.Error(err)
}
expectedOut := `package patrick
import (
"errors"
"fmt"
"github.com/gorilla/mux"
_ "github.com/gorilla/n"
"github.com/gorilla/websocket"
_ "github.com/lib/pq"
)
hello world
patrick's dreams
`
if out := file.String(); out != expectedOut {
t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
}
}

View file

@ -1,88 +0,0 @@
package cmds
import (
"fmt"
"io"
"os"
)
var testHarnessStdout io.Writer = os.Stdout
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
file, err := os.Create(filename)
return file, err
}
// outHandler loops over the slice of byte slices, outputting them to either
// the OutFile if it is specified with a flag, or to Stdout if no flag is specified.
func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
out := testHarnessStdout
var path string
if len(cmdData.OutFolder) != 0 {
if testTemplate {
path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
} else {
path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
}
outFile, err := testHarnessFileOpen(path)
if err != nil {
return fmt.Errorf("Unable to create output file %s: %s", path, err)
}
defer outFile.Close()
out = outFile
}
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
}
impStr := buildImportString(imps)
if len(impStr) > 0 {
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
return fmt.Errorf("Unable to write imports to file handle: %v", err)
}
}
for _, templateOutput := range output {
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
return fmt.Errorf("Unable to write template output to file handle: %v", err)
}
}
return nil
}
func combineStringSlices(a, b []string) []string {
c := make([]string, len(a)+len(b))
if len(a) > 0 {
copy(c, a)
}
if len(b) > 0 {
copy(c[len(a):], b)
}
return c
}
func removeDuplicates(dedup []string) []string {
if len(dedup) <= 1 {
return dedup
}
for i := 0; i < len(dedup)-1; i++ {
for j := i + 1; j < len(dedup); j++ {
if dedup[i] != dedup[j] {
continue
}
if j != len(dedup)-1 {
dedup[j] = dedup[len(dedup)-1]
j--
}
dedup = dedup[:len(dedup)-1]
}
}
return dedup
}

View file

@ -94,50 +94,14 @@ func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
PkgName: cmdData.PkgName,
}
var out [][]byte
var imps imports
imps.standard = sqlBoilerImports.standard
imps.thirdparty = sqlBoilerImports.thirdparty
// Loop through and generate every command template (excluding skipTemplates)
for _, template := range cmdData.Templates {
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
resp, err := generateTemplate(template, data)
if err != nil {
return err
}
out = append(out, resp)
// Generate the regular templates
if err := generateOutput(cmdData, data, false); err != nil {
return fmt.Errorf("Unable to generate test output: %s", err)
}
err := outHandler(cmdData, out, data, imps, false)
if err != nil {
return err
}
// Generate the test templates for all commands
if len(cmdData.TestTemplates) != 0 {
var testOut [][]byte
var testImps imports
testImps.standard = sqlBoilerTestImports.standard
testImps.thirdparty = sqlBoilerTestImports.thirdparty
testImps = combineImports(testImps, sqlBoilerDriverTestImports[cmdData.DriverName])
// Loop through and generate every command test template (excluding skipTemplates)
for _, template := range cmdData.TestTemplates {
resp, err := generateTemplate(template, data)
if err != nil {
return err
}
testOut = append(testOut, resp)
}
err = outHandler(cmdData, testOut, data, testImps, true)
if err != nil {
return err
}
// Generate the test templates
if err := generateOutput(cmdData, data, true); err != nil {
return fmt.Errorf("Unable to generate output: %s", err)
}
}

View file

@ -1,41 +1,13 @@
package cmds
import (
"bytes"
"fmt"
"go/format"
"strings"
"text/template"
"github.com/jinzhu/inflection"
"github.com/pobri19/sqlboiler/dbdrivers"
)
// generateTemplate generates the template associated to the passed in command name.
func generateTemplate(template *template.Template, data *tplData) ([]byte, error) {
output, err := processTemplate(template, data)
if err != nil {
return nil, fmt.Errorf("Unable to process the template %s for table %s: %s", template.Name(), data.Table.Name, err)
}
return output, nil
}
// processTemplate takes a template and returns the output of the template execution.
func processTemplate(t *template.Template, data *tplData) ([]byte, error) {
var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil {
return nil, err
}
output, err := format.Source(buf.Bytes())
if err != nil {
return nil, err
}
return output, nil
}
// plural converts singular words to plural words (eg: person to people)
func plural(name string) string {
splits := strings.Split(name, "_")

View file

@ -1,7 +1,7 @@
{{- $tableNameSingular := titleCaseSingular .Table -}}
{{- $dbName := singular .Table -}}
{{- $tableNamePlural := titleCasePlural .Table -}}
{{- $varNamePlural := camelCasePlural .Table -}}
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
{{- $dbName := singular .Table.Name -}}
{{- $tableNamePlural := titleCasePlural .Table.Name -}}
{{- $varNamePlural := camelCasePlural .Table.Name -}}
// {{$tableNamePlural}}All retrieves all records.
func Test{{$tableNamePlural}}All(t *testing.T) {