From 9e4b5b750c6966f2d70bfede73c62e8890abfdcd Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sat, 24 Sep 2016 00:51:02 -0700 Subject: [PATCH] Refactor output. - Simplify several methods - Gofmt full output of templates, not individual pieces - Re-use a global buffer to use less memory during template generation - Simplify the tests since the main test is responsible for checking everything. --- output.go | 205 +++++++++++++++++++++++++------------------------ output_test.go | 125 +++++++----------------------- 2 files changed, 131 insertions(+), 199 deletions(-) diff --git a/output.go b/output.go index d7d159a..d6914c3 100644 --- a/output.go +++ b/output.go @@ -5,8 +5,7 @@ import ( "bytes" "fmt" "go/format" - "io" - "os" + "io/ioutil" "path/filepath" "regexp" "strconv" @@ -15,11 +14,17 @@ import ( "github.com/pkg/errors" ) -var testHarnessStdout io.Writer = os.Stdout -var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) { - file, err := os.Create(filename) - return file, err -} +var ( + // templateByteBuffer is re-used by all template construction to avoid + // allocating more memory than is needed. This will later be a problem for + // concurrency, address it then. + templateByteBuffer = &bytes.Buffer{} + + rgxRemoveNumberedPrefix = regexp.MustCompile(`[0-9]+_`) + rgxSyntaxError = regexp.MustCompile(`(\d+):\d+: `) + + testHarnessWriteFile = ioutil.WriteFile +) // generateOutput builds the file output and sends it to outHandler for saving func generateOutput(state *State, data *templateData) error { @@ -88,26 +93,27 @@ func executeTemplates(e executeTemplateData) error { return nil } - var out [][]byte - var imps imports + out := templateByteBuffer + out.Reset() + var imps imports imps.standard = e.importSet.standard imps.thirdParty = e.importSet.thirdParty - for _, tplName := range e.templates.Templates() { - if e.combineImportsOnType { - imps = combineTypeImports(imps, importsBasedOnType, e.data.Table.Columns) - } + if e.combineImportsOnType { + imps = combineTypeImports(imps, importsBasedOnType, e.data.Table.Columns) + } - resp, err := executeTemplate(e.templates.Template, tplName, e.data) - if err != nil { - return errors.Wrapf(err, "Error generating template %s", tplName) + writePackageName(out, e.state.Config.PkgName) + writeImports(out, imps) + + for _, tplName := range e.templates.Templates() { + if err := executeTemplate(out, e.templates.Template, tplName, e.data); err != nil { + return err } - out = append(out, resp) } fName := e.data.Table.Name + e.fileSuffix - err := outHandler(e.state.Config.OutFolder, fName, e.state.Config.PkgName, imps, out) - if err != nil { + if err := writeFile(e.state.Config.OutFolder, fName, out); err != nil { return err } @@ -119,31 +125,27 @@ func executeSingletonTemplates(e executeTemplateData) error { return nil } - rgxRemove := regexp.MustCompile(`[0-9]+_`) - + out := templateByteBuffer for _, tplName := range e.templates.Templates() { - resp, err := executeTemplate(e.templates.Template, tplName, e.data) - if err != nil { - return errors.Wrapf(err, "Error generating template %s", tplName) - } + out.Reset() fName := tplName ext := filepath.Ext(fName) - fName = rgxRemove.ReplaceAllString(fName[:len(fName)-len(ext)], "") + fName = rgxRemoveNumberedPrefix.ReplaceAllString(fName[:len(fName)-len(ext)], "") imps := imports{ standard: e.importNamedSet[fName].standard, thirdParty: e.importNamedSet[fName].thirdParty, } - err = outHandler( - e.state.Config.OutFolder, - fName+e.fileSuffix, - e.state.Config.PkgName, - imps, - [][]byte{resp}, - ) - if err != nil { + writePackageName(out, e.state.Config.PkgName) + writeImports(out, imps) + + if err := executeTemplate(out, e.templates.Template, tplName, e.data); err != nil { + return err + } + + if err := writeFile(e.state.Config.OutFolder, fName+e.fileSuffix, out); err != nil { return err } } @@ -156,95 +158,94 @@ func generateTestMainOutput(state *State, data *templateData) error { return errors.New("No TestMain template located for generation") } - var out [][]byte - var imps imports + out := templateByteBuffer + out.Reset() + var imps imports imps.standard = defaultTestMainImports[state.Config.DriverName].standard imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty - resp, err := executeTemplate(state.TestMainTemplate, state.TestMainTemplate.Name(), data) - if err != nil { + writePackageName(out, state.Config.PkgName) + writeImports(out, imps) + + if err := executeTemplate(out, state.TestMainTemplate, state.TestMainTemplate.Name(), data); err != nil { return err } - out = append(out, resp) - err = outHandler(state.Config.OutFolder, "main_test.go", state.Config.PkgName, imps, out) - if err != nil { + if err := writeFile(state.Config.OutFolder, "main_test.go", out); err != nil { return err } return nil } -func outHandler(outFolder string, fileName string, pkgName string, imps imports, contents [][]byte) error { - out := testHarnessStdout +// writePackageName writes the package name correctly, ignores errors +// since it's to the concrete buffer type which produces none +func writePackageName(out *bytes.Buffer, pkgName string) { + _, _ = fmt.Fprintf(out, "package %s\n\n", pkgName) +} + +// writeImports writes the package imports correctly, ignores errors +// since it's to the concrete buffer type which produces none +func writeImports(out *bytes.Buffer, imps imports) { + if impStr := buildImportString(imps); len(impStr) > 0 { + _, _ = fmt.Fprintf(out, "%s\n", impStr) + } +} + +// writeFile writes to the given folder and filename, formatting the buffer +// given. +func writeFile(outFolder string, fileName string, input *bytes.Buffer) error { + byt, err := formatBuffer(input) + if err != nil { + return err + } path := filepath.Join(outFolder, fileName) - - outFile, err := testHarnessFileOpen(path) - if err != nil { - return errors.Wrapf(err, "Unable to create output file %s", path) - } - defer outFile.Close() - out = outFile - - if _, err := fmt.Fprintf(out, "package %s\n\n", pkgName); err != nil { - return errors.Errorf("Unable to write package name %s to file: %s", pkgName, path) - } - - impStr := buildImportString(imps) - if len(impStr) > 0 { - if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil { - return errors.Wrap(err, "Unable to write imports to file handle") - } - } - - for _, templateOutput := range contents { - if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil { - return errors.Wrap(err, "Unable to write template output to file handle") - } + if err = testHarnessWriteFile(path, byt, 0666); err != nil { + return errors.Wrapf(err, "failed to write output file %s", path) } return nil } -var rgxSyntaxError = regexp.MustCompile(`(\d+):\d+: `) - // executeTemplate takes a template and returns the output of the template // execution. -func executeTemplate(t *template.Template, name string, data *templateData) ([]byte, error) { - var buf bytes.Buffer - if err := t.ExecuteTemplate(&buf, name, data); err != nil { - return nil, errors.Wrap(err, "failed to execute template") +func executeTemplate(buf *bytes.Buffer, t *template.Template, name string, data *templateData) error { + if err := t.ExecuteTemplate(buf, name, data); err != nil { + return errors.Wrapf(err, "failed to execute template: %s", name) } - - output, err := format.Source(buf.Bytes()) - if err != nil { - matches := rgxSyntaxError.FindStringSubmatch(err.Error()) - if matches == nil { - return nil, errors.Wrap(err, "failed to format template") - } - - lineNum, _ := strconv.Atoi(matches[1]) - scanner := bufio.NewScanner(&buf) - errBuf := &bytes.Buffer{} - line := 0 - for ; scanner.Scan(); line++ { - if delta := line - lineNum; delta < -5 || delta > 5 { - continue - } - - if line == lineNum { - errBuf.WriteString(">>> ") - } else { - fmt.Fprintf(errBuf, "% 3d ", line) - } - errBuf.Write(scanner.Bytes()) - errBuf.WriteByte('\n') - } - - return nil, errors.Wrapf(err, "failed to format template\n\n%s\n", errBuf.Bytes()) - } - - return output, nil + return nil +} + +func formatBuffer(buf *bytes.Buffer) ([]byte, error) { + output, err := format.Source(buf.Bytes()) + if err == nil { + return output, nil + } + + matches := rgxSyntaxError.FindStringSubmatch(err.Error()) + if matches == nil { + return nil, errors.Wrap(err, "failed to format template") + } + + lineNum, _ := strconv.Atoi(matches[1]) + scanner := bufio.NewScanner(buf) + errBuf := &bytes.Buffer{} + line := 1 + for ; scanner.Scan(); line++ { + if delta := line - lineNum; delta < -5 || delta > 5 { + continue + } + + if line == lineNum { + errBuf.WriteString(">>>> ") + } else { + fmt.Fprintf(errBuf, "% 4d ", line) + } + errBuf.Write(scanner.Bytes()) + errBuf.WriteByte('\n') + } + + return nil, errors.Wrapf(err, "failed to format template\n\n%s\n", errBuf.Bytes()) } diff --git a/output_test.go b/output_test.go index eb4f772..3a33eca 100644 --- a/output_test.go +++ b/output_test.go @@ -2,8 +2,10 @@ package main import ( "bytes" + "fmt" "io" - "sort" + "os" + "strings" "testing" ) @@ -19,120 +21,49 @@ func nopCloser(w io.Writer) io.WriteCloser { return NopWriteCloser{w} } -func TestOutHandler(t *testing.T) { +func TestWriteFile(t *testing.T) { // t.Parallel() cannot be used // set the function pointer back to its original value // after we modify it for the test - saveTestHarnessFileOpen := testHarnessFileOpen + saveTestHarnessWriteFile := testHarnessWriteFile defer func() { - testHarnessFileOpen = saveTestHarnessFileOpen + testHarnessWriteFile = saveTestHarnessWriteFile }() + var output []byte + testHarnessWriteFile = func(_ string, in []byte, _ os.FileMode) error { + output = in + return nil + } + buf := &bytes.Buffer{} - testHarnessFileOpen = func(path string) (io.WriteCloser, error) { - return nopCloser(buf), nil - } + writePackageName(buf, "pkg") + fmt.Fprintf(buf, "func hello() {}\n\n\nfunc world() {\nreturn\n}\n\n\n\n") - templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")} - - if err := outHandler("", "file.go", "patrick", imports{}, templateOutputs); err != nil { + if err := writeFile("", "", buf); err != nil { t.Error(err) } - if out := buf.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" { - t.Errorf("Wrong output: %q", out) + if string(output) != "package pkg\n\nfunc hello() {}\n\nfunc world() {\n\treturn\n}\n" { + t.Errorf("Wrong output: %q", output) } } -func TestOutHandlerFiles(t *testing.T) { - // t.Parallel() cannot be used +func TestFormatBuffer(t *testing.T) { + t.Parallel() - saveTestHarnessFileOpen := testHarnessFileOpen - defer func() { - testHarnessFileOpen = saveTestHarnessFileOpen - }() + buf := &bytes.Buffer{} - file := &bytes.Buffer{} - testHarnessFileOpen = func(path string) (io.WriteCloser, error) { - return nopCloser(file), nil + fmt.Fprintf(buf, "package pkg\n\nfunc() {a}\n") + + // Only test error case - happy case is taken care of by template test + _, err := formatBuffer(buf) + if err == nil { + t.Error("want an error") } - templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")} - - if err := outHandler("folder", "file.go", "patrick", imports{}, templateOutputs); err != nil { - t.Error(err) - } - if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" { - t.Errorf("Wrong output: %q", out) - } - - a1 := imports{ - standard: importList{ - `"fmt"`, - }, - } - file = &bytes.Buffer{} - - if err := outHandler("folder", "file.go", "patrick", a1, templateOutputs); err != nil { - t.Error(err) - } - if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" { - t.Errorf("Wrong output: %q", out) - } - - a2 := imports{ - thirdParty: []string{ - `"github.com/spf13/cobra"`, - }, - } - file = &bytes.Buffer{} - - if err := outHandler("folder", "file.go", "patrick", a2, templateOutputs); err != nil { - t.Error(err) - } - if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" { - t.Errorf("Wrong output: %q", out) - } - - a3 := imports{ - standard: importList{ - `"fmt"`, - `"errors"`, - }, - thirdParty: importList{ - `_ "github.com/lib/pq"`, - `_ "github.com/gorilla/n"`, - `"github.com/gorilla/mux"`, - `"github.com/gorilla/websocket"`, - }, - } - file = &bytes.Buffer{} - - sort.Sort(a3.standard) - sort.Sort(a3.thirdParty) - - if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil { - t.Error(err) - } - - expectedOut := `package patrick - -import ( - "errors" - "fmt" - - "github.com/gorilla/mux" - _ "github.com/gorilla/n" - "github.com/gorilla/websocket" - _ "github.com/lib/pq" -) - -hello world -patrick's dreams -` - - if out := file.String(); out != expectedOut { - t.Errorf("Wrong output (len %d, len %d): \n\n%q\n\n%q", len(out), len(expectedOut), out, expectedOut) + if txt := err.Error(); !strings.Contains(txt, ">>>> func() {a}") { + t.Error("got:\n", txt) } }