Refactor output.

- Simplify several methods
- Gofmt full output of templates, not individual pieces
- Re-use a global buffer to use less memory during template generation
- Simplify the tests since the main test is responsible for checking
  everything.
This commit is contained in:
Aaron L 2016-09-24 00:51:02 -07:00
parent e9eda8fa1b
commit 9e4b5b750c
2 changed files with 131 additions and 199 deletions

157
output.go
View file

@ -5,8 +5,7 @@ import (
"bytes"
"fmt"
"go/format"
"io"
"os"
"io/ioutil"
"path/filepath"
"regexp"
"strconv"
@ -15,11 +14,17 @@ import (
"github.com/pkg/errors"
)
var testHarnessStdout io.Writer = os.Stdout
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
file, err := os.Create(filename)
return file, err
}
var (
// templateByteBuffer is re-used by all template construction to avoid
// allocating more memory than is needed. This will later be a problem for
// concurrency, address it then.
templateByteBuffer = &bytes.Buffer{}
rgxRemoveNumberedPrefix = regexp.MustCompile(`[0-9]+_`)
rgxSyntaxError = regexp.MustCompile(`(\d+):\d+: `)
testHarnessWriteFile = ioutil.WriteFile
)
// generateOutput builds the file output and sends it to outHandler for saving
func generateOutput(state *State, data *templateData) error {
@ -88,26 +93,27 @@ func executeTemplates(e executeTemplateData) error {
return nil
}
var out [][]byte
var imps imports
out := templateByteBuffer
out.Reset()
var imps imports
imps.standard = e.importSet.standard
imps.thirdParty = e.importSet.thirdParty
for _, tplName := range e.templates.Templates() {
if e.combineImportsOnType {
imps = combineTypeImports(imps, importsBasedOnType, e.data.Table.Columns)
}
resp, err := executeTemplate(e.templates.Template, tplName, e.data)
if err != nil {
return errors.Wrapf(err, "Error generating template %s", tplName)
writePackageName(out, e.state.Config.PkgName)
writeImports(out, imps)
for _, tplName := range e.templates.Templates() {
if err := executeTemplate(out, e.templates.Template, tplName, e.data); err != nil {
return err
}
out = append(out, resp)
}
fName := e.data.Table.Name + e.fileSuffix
err := outHandler(e.state.Config.OutFolder, fName, e.state.Config.PkgName, imps, out)
if err != nil {
if err := writeFile(e.state.Config.OutFolder, fName, out); err != nil {
return err
}
@ -119,31 +125,27 @@ func executeSingletonTemplates(e executeTemplateData) error {
return nil
}
rgxRemove := regexp.MustCompile(`[0-9]+_`)
out := templateByteBuffer
for _, tplName := range e.templates.Templates() {
resp, err := executeTemplate(e.templates.Template, tplName, e.data)
if err != nil {
return errors.Wrapf(err, "Error generating template %s", tplName)
}
out.Reset()
fName := tplName
ext := filepath.Ext(fName)
fName = rgxRemove.ReplaceAllString(fName[:len(fName)-len(ext)], "")
fName = rgxRemoveNumberedPrefix.ReplaceAllString(fName[:len(fName)-len(ext)], "")
imps := imports{
standard: e.importNamedSet[fName].standard,
thirdParty: e.importNamedSet[fName].thirdParty,
}
err = outHandler(
e.state.Config.OutFolder,
fName+e.fileSuffix,
e.state.Config.PkgName,
imps,
[][]byte{resp},
)
if err != nil {
writePackageName(out, e.state.Config.PkgName)
writeImports(out, imps)
if err := executeTemplate(out, e.templates.Template, tplName, e.data); err != nil {
return err
}
if err := writeFile(e.state.Config.OutFolder, fName+e.fileSuffix, out); err != nil {
return err
}
}
@ -156,95 +158,94 @@ func generateTestMainOutput(state *State, data *templateData) error {
return errors.New("No TestMain template located for generation")
}
var out [][]byte
var imps imports
out := templateByteBuffer
out.Reset()
var imps imports
imps.standard = defaultTestMainImports[state.Config.DriverName].standard
imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty
resp, err := executeTemplate(state.TestMainTemplate, state.TestMainTemplate.Name(), data)
if err != nil {
writePackageName(out, state.Config.PkgName)
writeImports(out, imps)
if err := executeTemplate(out, state.TestMainTemplate, state.TestMainTemplate.Name(), data); err != nil {
return err
}
out = append(out, resp)
err = outHandler(state.Config.OutFolder, "main_test.go", state.Config.PkgName, imps, out)
if err != nil {
if err := writeFile(state.Config.OutFolder, "main_test.go", out); err != nil {
return err
}
return nil
}
func outHandler(outFolder string, fileName string, pkgName string, imps imports, contents [][]byte) error {
out := testHarnessStdout
// writePackageName writes the package name correctly, ignores errors
// since it's to the concrete buffer type which produces none
func writePackageName(out *bytes.Buffer, pkgName string) {
_, _ = fmt.Fprintf(out, "package %s\n\n", pkgName)
}
// writeImports writes the package imports correctly, ignores errors
// since it's to the concrete buffer type which produces none
func writeImports(out *bytes.Buffer, imps imports) {
if impStr := buildImportString(imps); len(impStr) > 0 {
_, _ = fmt.Fprintf(out, "%s\n", impStr)
}
}
// writeFile writes to the given folder and filename, formatting the buffer
// given.
func writeFile(outFolder string, fileName string, input *bytes.Buffer) error {
byt, err := formatBuffer(input)
if err != nil {
return err
}
path := filepath.Join(outFolder, fileName)
outFile, err := testHarnessFileOpen(path)
if err != nil {
return errors.Wrapf(err, "Unable to create output file %s", path)
}
defer outFile.Close()
out = outFile
if _, err := fmt.Fprintf(out, "package %s\n\n", pkgName); err != nil {
return errors.Errorf("Unable to write package name %s to file: %s", pkgName, path)
}
impStr := buildImportString(imps)
if len(impStr) > 0 {
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
return errors.Wrap(err, "Unable to write imports to file handle")
}
}
for _, templateOutput := range contents {
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
return errors.Wrap(err, "Unable to write template output to file handle")
}
if err = testHarnessWriteFile(path, byt, 0666); err != nil {
return errors.Wrapf(err, "failed to write output file %s", path)
}
return nil
}
var rgxSyntaxError = regexp.MustCompile(`(\d+):\d+: `)
// executeTemplate takes a template and returns the output of the template
// execution.
func executeTemplate(t *template.Template, name string, data *templateData) ([]byte, error) {
var buf bytes.Buffer
if err := t.ExecuteTemplate(&buf, name, data); err != nil {
return nil, errors.Wrap(err, "failed to execute template")
func executeTemplate(buf *bytes.Buffer, t *template.Template, name string, data *templateData) error {
if err := t.ExecuteTemplate(buf, name, data); err != nil {
return errors.Wrapf(err, "failed to execute template: %s", name)
}
return nil
}
func formatBuffer(buf *bytes.Buffer) ([]byte, error) {
output, err := format.Source(buf.Bytes())
if err == nil {
return output, nil
}
output, err := format.Source(buf.Bytes())
if err != nil {
matches := rgxSyntaxError.FindStringSubmatch(err.Error())
if matches == nil {
return nil, errors.Wrap(err, "failed to format template")
}
lineNum, _ := strconv.Atoi(matches[1])
scanner := bufio.NewScanner(&buf)
scanner := bufio.NewScanner(buf)
errBuf := &bytes.Buffer{}
line := 0
line := 1
for ; scanner.Scan(); line++ {
if delta := line - lineNum; delta < -5 || delta > 5 {
continue
}
if line == lineNum {
errBuf.WriteString(">>> ")
errBuf.WriteString(">>>> ")
} else {
fmt.Fprintf(errBuf, "% 3d ", line)
fmt.Fprintf(errBuf, "% 4d ", line)
}
errBuf.Write(scanner.Bytes())
errBuf.WriteByte('\n')
}
return nil, errors.Wrapf(err, "failed to format template\n\n%s\n", errBuf.Bytes())
}
return output, nil
}

View file

@ -2,8 +2,10 @@ package main
import (
"bytes"
"fmt"
"io"
"sort"
"os"
"strings"
"testing"
)
@ -19,120 +21,49 @@ func nopCloser(w io.Writer) io.WriteCloser {
return NopWriteCloser{w}
}
func TestOutHandler(t *testing.T) {
func TestWriteFile(t *testing.T) {
// t.Parallel() cannot be used
// set the function pointer back to its original value
// after we modify it for the test
saveTestHarnessFileOpen := testHarnessFileOpen
saveTestHarnessWriteFile := testHarnessWriteFile
defer func() {
testHarnessFileOpen = saveTestHarnessFileOpen
testHarnessWriteFile = saveTestHarnessWriteFile
}()
var output []byte
testHarnessWriteFile = func(_ string, in []byte, _ os.FileMode) error {
output = in
return nil
}
buf := &bytes.Buffer{}
testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
return nopCloser(buf), nil
}
writePackageName(buf, "pkg")
fmt.Fprintf(buf, "func hello() {}\n\n\nfunc world() {\nreturn\n}\n\n\n\n")
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler("", "file.go", "patrick", imports{}, templateOutputs); err != nil {
if err := writeFile("", "", buf); err != nil {
t.Error(err)
}
if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
if string(output) != "package pkg\n\nfunc hello() {}\n\nfunc world() {\n\treturn\n}\n" {
t.Errorf("Wrong output: %q", output)
}
}
func TestOutHandlerFiles(t *testing.T) {
// t.Parallel() cannot be used
func TestFormatBuffer(t *testing.T) {
t.Parallel()
saveTestHarnessFileOpen := testHarnessFileOpen
defer func() {
testHarnessFileOpen = saveTestHarnessFileOpen
}()
buf := &bytes.Buffer{}
file := &bytes.Buffer{}
testHarnessFileOpen = func(path string) (io.WriteCloser, error) {
return nopCloser(file), nil
fmt.Fprintf(buf, "package pkg\n\nfunc() {a}\n")
// Only test error case - happy case is taken care of by template test
_, err := formatBuffer(buf)
if err == nil {
t.Error("want an error")
}
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler("folder", "file.go", "patrick", imports{}, templateOutputs); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a1 := imports{
standard: importList{
`"fmt"`,
},
}
file = &bytes.Buffer{}
if err := outHandler("folder", "file.go", "patrick", a1, templateOutputs); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a2 := imports{
thirdParty: []string{
`"github.com/spf13/cobra"`,
},
}
file = &bytes.Buffer{}
if err := outHandler("folder", "file.go", "patrick", a2, templateOutputs); err != nil {
t.Error(err)
}
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
t.Errorf("Wrong output: %q", out)
}
a3 := imports{
standard: importList{
`"fmt"`,
`"errors"`,
},
thirdParty: importList{
`_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`,
`"github.com/gorilla/websocket"`,
},
}
file = &bytes.Buffer{}
sort.Sort(a3.standard)
sort.Sort(a3.thirdParty)
if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil {
t.Error(err)
}
expectedOut := `package patrick
import (
"errors"
"fmt"
"github.com/gorilla/mux"
_ "github.com/gorilla/n"
"github.com/gorilla/websocket"
_ "github.com/lib/pq"
)
hello world
patrick's dreams
`
if out := file.String(); out != expectedOut {
t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut)
if txt := err.Error(); !strings.Contains(txt, ">>>> func() {a}") {
t.Error("got:\n", txt)
}
}