Refactor output.

- Simplify several methods
- Gofmt full output of templates, not individual pieces
- Re-use a global buffer to use less memory during template generation
- Simplify the tests since the main test is responsible for checking
  everything.
This commit is contained in:
Aaron L 2016-09-24 00:51:02 -07:00
parent e9eda8fa1b
commit 9e4b5b750c
2 changed files with 131 additions and 199 deletions

205
output.go
View file

@ -5,8 +5,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"go/format" "go/format"
"io" "io/ioutil"
"os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"strconv" "strconv"
@ -15,11 +14,17 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var testHarnessStdout io.Writer = os.Stdout var (
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) { // templateByteBuffer is re-used by all template construction to avoid
file, err := os.Create(filename) // allocating more memory than is needed. This will later be a problem for
return file, err // concurrency, address it then.
} templateByteBuffer = &bytes.Buffer{}
rgxRemoveNumberedPrefix = regexp.MustCompile(`[0-9]+_`)
rgxSyntaxError = regexp.MustCompile(`(\d+):\d+: `)
testHarnessWriteFile = ioutil.WriteFile
)
// generateOutput builds the file output and sends it to outHandler for saving // generateOutput builds the file output and sends it to outHandler for saving
func generateOutput(state *State, data *templateData) error { func generateOutput(state *State, data *templateData) error {
@ -88,26 +93,27 @@ func executeTemplates(e executeTemplateData) error {
return nil return nil
} }
var out [][]byte out := templateByteBuffer
var imps imports out.Reset()
var imps imports
imps.standard = e.importSet.standard imps.standard = e.importSet.standard
imps.thirdParty = e.importSet.thirdParty imps.thirdParty = e.importSet.thirdParty
for _, tplName := range e.templates.Templates() { if e.combineImportsOnType {
if e.combineImportsOnType { imps = combineTypeImports(imps, importsBasedOnType, e.data.Table.Columns)
imps = combineTypeImports(imps, importsBasedOnType, e.data.Table.Columns) }
}
resp, err := executeTemplate(e.templates.Template, tplName, e.data) writePackageName(out, e.state.Config.PkgName)
if err != nil { writeImports(out, imps)
return errors.Wrapf(err, "Error generating template %s", tplName)
for _, tplName := range e.templates.Templates() {
if err := executeTemplate(out, e.templates.Template, tplName, e.data); err != nil {
return err
} }
out = append(out, resp)
} }
fName := e.data.Table.Name + e.fileSuffix fName := e.data.Table.Name + e.fileSuffix
err := outHandler(e.state.Config.OutFolder, fName, e.state.Config.PkgName, imps, out) if err := writeFile(e.state.Config.OutFolder, fName, out); err != nil {
if err != nil {
return err return err
} }
@ -119,31 +125,27 @@ func executeSingletonTemplates(e executeTemplateData) error {
return nil return nil
} }
rgxRemove := regexp.MustCompile(`[0-9]+_`) out := templateByteBuffer
for _, tplName := range e.templates.Templates() { for _, tplName := range e.templates.Templates() {
resp, err := executeTemplate(e.templates.Template, tplName, e.data) out.Reset()
if err != nil {
return errors.Wrapf(err, "Error generating template %s", tplName)
}
fName := tplName fName := tplName
ext := filepath.Ext(fName) ext := filepath.Ext(fName)
fName = rgxRemove.ReplaceAllString(fName[:len(fName)-len(ext)], "") fName = rgxRemoveNumberedPrefix.ReplaceAllString(fName[:len(fName)-len(ext)], "")
imps := imports{ imps := imports{
standard: e.importNamedSet[fName].standard, standard: e.importNamedSet[fName].standard,
thirdParty: e.importNamedSet[fName].thirdParty, thirdParty: e.importNamedSet[fName].thirdParty,
} }
err = outHandler( writePackageName(out, e.state.Config.PkgName)
e.state.Config.OutFolder, writeImports(out, imps)
fName+e.fileSuffix,
e.state.Config.PkgName, if err := executeTemplate(out, e.templates.Template, tplName, e.data); err != nil {
imps, return err
[][]byte{resp}, }
)
if err != nil { if err := writeFile(e.state.Config.OutFolder, fName+e.fileSuffix, out); err != nil {
return err return err
} }
} }
@ -156,95 +158,94 @@ func generateTestMainOutput(state *State, data *templateData) error {
return errors.New("No TestMain template located for generation") return errors.New("No TestMain template located for generation")
} }
var out [][]byte out := templateByteBuffer
var imps imports out.Reset()
var imps imports
imps.standard = defaultTestMainImports[state.Config.DriverName].standard imps.standard = defaultTestMainImports[state.Config.DriverName].standard
imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty
resp, err := executeTemplate(state.TestMainTemplate, state.TestMainTemplate.Name(), data) writePackageName(out, state.Config.PkgName)
if err != nil { writeImports(out, imps)
if err := executeTemplate(out, state.TestMainTemplate, state.TestMainTemplate.Name(), data); err != nil {
return err return err
} }
out = append(out, resp)
err = outHandler(state.Config.OutFolder, "main_test.go", state.Config.PkgName, imps, out) if err := writeFile(state.Config.OutFolder, "main_test.go", out); err != nil {
if err != nil {
return err return err
} }
return nil return nil
} }
func outHandler(outFolder string, fileName string, pkgName string, imps imports, contents [][]byte) error { // writePackageName writes the package name correctly, ignores errors
out := testHarnessStdout // since it's to the concrete buffer type which produces none
func writePackageName(out *bytes.Buffer, pkgName string) {
_, _ = fmt.Fprintf(out, "package %s\n\n", pkgName)
}
// writeImports writes the package imports correctly, ignores errors
// since it's to the concrete buffer type which produces none
func writeImports(out *bytes.Buffer, imps imports) {
if impStr := buildImportString(imps); len(impStr) > 0 {
_, _ = fmt.Fprintf(out, "%s\n", impStr)
}
}
// writeFile writes to the given folder and filename, formatting the buffer
// given.
func writeFile(outFolder string, fileName string, input *bytes.Buffer) error {
byt, err := formatBuffer(input)
if err != nil {
return err
}
path := filepath.Join(outFolder, fileName) path := filepath.Join(outFolder, fileName)
if err = testHarnessWriteFile(path, byt, 0666); err != nil {
outFile, err := testHarnessFileOpen(path) return errors.Wrapf(err, "failed to write output file %s", path)
if err != nil {
return errors.Wrapf(err, "Unable to create output file %s", path)
}
defer outFile.Close()
out = outFile
if _, err := fmt.Fprintf(out, "package %s\n\n", pkgName); err != nil {
return errors.Errorf("Unable to write package name %s to file: %s", pkgName, path)
}
impStr := buildImportString(imps)
if len(impStr) > 0 {
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
return errors.Wrap(err, "Unable to write imports to file handle")
}
}
for _, templateOutput := range contents {
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
return errors.Wrap(err, "Unable to write template output to file handle")
}
} }
return nil return nil
} }
var rgxSyntaxError = regexp.MustCompile(`(\d+):\d+: `)
// executeTemplate takes a template and returns the output of the template // executeTemplate takes a template and returns the output of the template
// execution. // execution.
func executeTemplate(t *template.Template, name string, data *templateData) ([]byte, error) { func executeTemplate(buf *bytes.Buffer, t *template.Template, name string, data *templateData) error {
var buf bytes.Buffer if err := t.ExecuteTemplate(buf, name, data); err != nil {
if err := t.ExecuteTemplate(&buf, name, data); err != nil { return errors.Wrapf(err, "failed to execute template: %s", name)
return nil, errors.Wrap(err, "failed to execute template")
} }
return nil
output, err := format.Source(buf.Bytes()) }
if err != nil {
matches := rgxSyntaxError.FindStringSubmatch(err.Error()) func formatBuffer(buf *bytes.Buffer) ([]byte, error) {
if matches == nil { output, err := format.Source(buf.Bytes())
return nil, errors.Wrap(err, "failed to format template") if err == nil {
} return output, nil
}
lineNum, _ := strconv.Atoi(matches[1])
scanner := bufio.NewScanner(&buf) matches := rgxSyntaxError.FindStringSubmatch(err.Error())
errBuf := &bytes.Buffer{} if matches == nil {
line := 0 return nil, errors.Wrap(err, "failed to format template")
for ; scanner.Scan(); line++ { }
if delta := line - lineNum; delta < -5 || delta > 5 {
continue lineNum, _ := strconv.Atoi(matches[1])
} scanner := bufio.NewScanner(buf)
errBuf := &bytes.Buffer{}
if line == lineNum { line := 1
errBuf.WriteString(">>> ") for ; scanner.Scan(); line++ {
} else { if delta := line - lineNum; delta < -5 || delta > 5 {
fmt.Fprintf(errBuf, "% 3d ", line) continue
} }
errBuf.Write(scanner.Bytes())
errBuf.WriteByte('\n') if line == lineNum {
} errBuf.WriteString(">>>> ")
} else {
return nil, errors.Wrapf(err, "failed to format template\n\n%s\n", errBuf.Bytes()) fmt.Fprintf(errBuf, "% 4d ", line)
} }
errBuf.Write(scanner.Bytes())
return output, nil errBuf.WriteByte('\n')
}
return nil, errors.Wrapf(err, "failed to format template\n\n%s\n", errBuf.Bytes())
} }

View file

@ -2,8 +2,10 @@ package main
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"sort" "os"
"strings"
"testing" "testing"
) )
@ -19,120 +21,49 @@ func nopCloser(w io.Writer) io.WriteCloser {
return NopWriteCloser{w} return NopWriteCloser{w}
} }
func TestOutHandler(t *testing.T) { func TestWriteFile(t *testing.T) {
// t.Parallel() cannot be used // t.Parallel() cannot be used
// set the function pointer back to its original value // set the function pointer back to its original value
// after we modify it for the test // after we modify it for the test
saveTestHarnessFileOpen := testHarnessFileOpen saveTestHarnessWriteFile := testHarnessWriteFile
defer func() { defer func() {
testHarnessFileOpen = saveTestHarnessFileOpen testHarnessWriteFile = saveTestHarnessWriteFile
}() }()
var output []byte
testHarnessWriteFile = func(_ string, in []byte, _ os.FileMode) error {
output = in
return nil
}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
testHarnessFileOpen = func(path string) (io.WriteCloser, error) { writePackageName(buf, "pkg")
return nopCloser(buf), nil fmt.Fprintf(buf, "func hello() {}\n\n\nfunc world() {\nreturn\n}\n\n\n\n")
}
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")} if err := writeFile("", "", buf); err != nil {
if err := outHandler("", "file.go", "patrick", imports{}, templateOutputs); err != nil {
t.Error(err) t.Error(err)
} }
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" { if string(output) != "package pkg\n\nfunc hello() {}\n\nfunc world() {\n\treturn\n}\n" {
t.Errorf("Wrong output: %q", out) t.Errorf("Wrong output: %q", output)
} }
} }
func TestOutHandlerFiles(t *testing.T) { func TestFormatBuffer(t *testing.T) {
// t.Parallel() cannot be used t.Parallel()
saveTestHarnessFileOpen := testHarnessFileOpen buf := &bytes.Buffer{}
defer func() {
testHarnessFileOpen = saveTestHarnessFileOpen
}()
file := &bytes.Buffer{} fmt.Fprintf(buf, "package pkg\n\nfunc() {a}\n")
testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
return nopCloser(file), nil // Only test error case - happy case is taken care of by template test
_, err := formatBuffer(buf)
if err == nil {
t.Error("want an error")
} }
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")} if txt := err.Error(); !strings.Contains(txt, ">>>> func() {a}") {
t.Error("got:\n", txt)
if err := outHandler("folder", "file.go", "patrick", imports{}, templateOutputs); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a1 := imports{
standard: importList{
`"fmt"`,
},
}
file = &bytes.Buffer{}
if err := outHandler("folder", "file.go", "patrick", a1, templateOutputs); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a2 := imports{
thirdParty: []string{
`"github.com/spf13/cobra"`,
},
}
file = &bytes.Buffer{}
if err := outHandler("folder", "file.go", "patrick", a2, templateOutputs); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a3 := imports{
standard: importList{
`"fmt"`,
`"errors"`,
},
thirdParty: importList{
`_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`,
`"github.com/gorilla/websocket"`,
},
}
file = &bytes.Buffer{}
sort.Sort(a3.standard)
sort.Sort(a3.thirdParty)
if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil {
t.Error(err)
}
expectedOut := `package patrick
import (
"errors"
"fmt"
"github.com/gorilla/mux"
_ "github.com/gorilla/n"
"github.com/gorilla/websocket"
_ "github.com/lib/pq"
)
hello world
patrick's dreams
`
if out := file.String(); out != expectedOut {
t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
} }
} }