package main

import (
	"bytes"
	"io"
	"sort"
	"testing"
)

// NopWriteCloser wraps an io.Writer and adds a Close method that does
// nothing, so an in-memory buffer can stand in for an opened file.
// (The standard library's io.NopCloser only wraps readers.)
type NopWriteCloser struct {
	io.Writer
}

// Close implements io.Closer; it does nothing and always returns nil.
func (NopWriteCloser) Close() error {
	return nil
}

// nopCloser returns w as an io.WriteCloser whose Close is a no-op.
func nopCloser(w io.Writer) io.WriteCloser {
	return NopWriteCloser{w}
}

// stubTestHarnessFileOpen replaces the package-level testHarnessFileOpen
// hook with one that ignores the requested path and always returns w
// wrapped in a no-op closer. It returns a function that puts the original
// hook back; callers should defer it.
func stubTestHarnessFileOpen(w io.Writer) (restore func()) {
	saved := testHarnessFileOpen
	testHarnessFileOpen = func(_ string) (io.WriteCloser, error) {
		return nopCloser(w), nil
	}
	return func() { testHarnessFileOpen = saved }
}

func TestOutHandler(t *testing.T) {
	// t.Parallel() cannot be used: this test swaps the package-level
	// testHarnessFileOpen hook, restoring it when the test returns.
	buf := &bytes.Buffer{}
	restore := stubTestHarnessFileOpen(buf)
	defer restore()

	templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
	if err := outHandler("", "file.go", "patrick", imports{}, templateOutputs); err != nil {
		t.Fatal(err)
	}

	const want = "package patrick\n\nhello world\npatrick's dreams\n"
	if got := buf.String(); got != want {
		t.Errorf("wrong output:\n got %q\nwant %q", got, want)
	}
}

func TestOutHandlerFiles(t *testing.T) {
	// t.Parallel() cannot be used (here or in the subtests): every case
	// shares the package-level testHarnessFileOpen hook and one buffer.
	buf := &bytes.Buffer{}
	restore := stubTestHarnessFileOpen(buf)
	defer restore()

	templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}

	manyImports := imports{
		standard: importList{
			`"fmt"`,
			`"errors"`,
		},
		thirdParty: importList{
			`_ "github.com/lib/pq"`,
			`_ "github.com/gorilla/n"`,
			`"github.com/gorilla/mux"`,
			`"github.com/gorilla/websocket"`,
		},
	}
	// The lists are deliberately declared out of order and sorted here; the
	// expected output below lists them in importList's sort order (which,
	// judging by that output, ignores a leading `_ ` — confirm in Less).
	sort.Sort(manyImports.standard)
	sort.Sort(manyImports.thirdParty)

	tests := []struct {
		name string
		imps imports
		want string
	}{
		{
			name: "no imports",
			imps: imports{},
			want: "package patrick\n\nhello world\npatrick's dreams\n",
		},
		{
			name: "single standard import",
			imps: imports{standard: importList{`"fmt"`}},
			want: "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n",
		},
		{
			name: "single third-party import",
			imps: imports{thirdParty: importList{`"github.com/spf13/cobra"`}},
			want: "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n",
		},
		{
			// Written with explicit escapes so the tabs and blank lines
			// that separate the import groups are visible and cannot be
			// silently altered by an editor.
			name: "standard and third-party imports",
			imps: manyImports,
			want: "package patrick\n\n" +
				"import (\n" +
				"\t\"errors\"\n" +
				"\t\"fmt\"\n" +
				"\n" +
				"\t\"github.com/gorilla/mux\"\n" +
				"\t_ \"github.com/gorilla/n\"\n" +
				"\t\"github.com/gorilla/websocket\"\n" +
				"\t_ \"github.com/lib/pq\"\n" +
				")\n" +
				"\n" +
				"hello world\n" +
				"patrick's dreams\n",
		},
	}

	for _, tt := range tests {
		// Subtests run sequentially, so capturing tt is safe on any Go version.
		t.Run(tt.name, func(t *testing.T) {
			buf.Reset()
			if err := outHandler("folder", "file.go", "patrick", tt.imps, templateOutputs); err != nil {
				// Stop this case: comparing output after a failed write
				// would only add a second, misleading failure.
				t.Fatal(err)
			}
			if got := buf.String(); got != tt.want {
				t.Errorf("wrong output:\n got %q\nwant %q", got, tt.want)
			}
		})
	}
}