diff --git a/output.go b/output.go index c58056d..e51c690 100644 --- a/output.go +++ b/output.go @@ -19,55 +19,87 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) { // generateOutput builds the file output and sends it to outHandler for saving func generateOutput(state *State, data *templateData) error { - if len(state.Templates) == 0 { - return errors.New("No template files located for generation") - } - var out [][]byte - var imps imports - - imps.standard = defaultTemplateImports.standard - imps.thirdParty = defaultTemplateImports.thirdParty - - for _, template := range state.Templates { - imps = combineTypeImports(imps, importsBasedOnType, data.Table.Columns) - resp, err := generateTemplate(template, data) - if err != nil { - return fmt.Errorf("Error generating template %s: %s", template.Name(), err) - } - out = append(out, resp) - } - - fName := data.Table.Name + ".go" - err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out) - if err != nil { - return err - } - - return nil + return executeTemplates(executeTemplateData{ + state: state, + data: data, + templates: state.Templates, + importSet: defaultTemplateImports, + combineImportsOnType: true, + fileSuffix: ".go", + }) } // generateTestOutput builds the test file output and sends it to outHandler for saving func generateTestOutput(state *State, data *templateData) error { - if len(state.TestTemplates) == 0 { - return errors.New("No template files located for generation") - } + return executeTemplates(executeTemplateData{ + state: state, + data: data, + templates: state.TestTemplates, + importSet: defaultTestTemplateImports, + combineImportsOnType: false, + fileSuffix: "_test.go", + }) +} +// generateSingletonOutput processes the templates that should only be run +// one time. +func generateSingletonOutput(state *State, data *templateData) error { + return executeSingletonTemplates(executeTemplateData{ + state: state, + data: data, + templates: state.SingletonTemplates, + importNamedSet: defaultSingletonTemplateImports, + fileSuffix: ".go", + }) +} + +// generateSingletonTestOutput processes the templates that should only be run +// one time. +func generateSingletonTestOutput(state *State, data *templateData) error { + return executeSingletonTemplates(executeTemplateData{ + state: state, + data: data, + templates: state.SingletonTestTemplates, + importNamedSet: defaultSingletonTestTemplateImports, + fileSuffix: "_test.go", + }) +} + +type executeTemplateData struct { + state *State + data *templateData + + templates templateList + + importSet imports + importNamedSet map[string]imports + + combineImportsOnType bool + + fileSuffix string +} + +func executeTemplates(e executeTemplateData) error { var out [][]byte var imps imports - imps.standard = defaultTestTemplateImports.standard - imps.thirdParty = defaultTestTemplateImports.thirdParty + imps.standard = e.importSet.standard + imps.thirdParty = e.importSet.thirdParty - for _, template := range state.TestTemplates { - resp, err := generateTemplate(template, data) + for _, template := range e.templates { + if e.combineImportsOnType { + imps = combineTypeImports(imps, importsBasedOnType, e.data.Table.Columns) + } + + resp, err := executeTemplate(template, e.data) if err != nil { - return fmt.Errorf("Error generating test template %s: %s", template.Name(), err) + return fmt.Errorf("Error generating template %s: %s", template.Name(), err) } out = append(out, resp) } - fName := data.Table.Name + "_test.go" - err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out) + fName := e.data.Table.Name + e.fileSuffix + err := outHandler(e.state.Config.OutFolder, fName, e.state.Config.PkgName, imps, out) if err != nil { return err } @@ -75,22 +107,9 @@ func generateTestOutput(state *State, data *templateData) error { return nil } -// generateSingletonOutput processes the templates that should only be run -// one time. -func generateSingletonOutput(state *State) error { - if state.SingletonTemplates == nil { - return errors.New("No singleton templates located for generation") - } - - templateData := &templateData{ - PkgName: state.Config.PkgName, - DriverName: state.Config.DriverName, - } - - for _, template := range state.SingletonTemplates { - var imps imports - - resp, err := generateTemplate(template, templateData) +func executeSingletonTemplates(e executeTemplateData) error { + for _, template := range e.templates { + resp, err := executeTemplate(template, e.data) if err != nil { return fmt.Errorf("Error generating template %s: %s", template.Name(), err) } @@ -99,12 +118,18 @@ func generateSingletonOutput(state *State) error { ext := filepath.Ext(fName) fName = fName[0 : len(fName)-len(ext)] - imps.standard = defaultSingletonTemplateImports[fName].standard - imps.thirdParty = defaultSingletonTemplateImports[fName].thirdParty + imps := imports{ + standard: e.importNamedSet[fName].standard, + thirdParty: e.importNamedSet[fName].thirdParty, + } - fName = fName + ".go" - - err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp}) + err = outHandler( + e.state.Config.OutFolder, + fName+e.fileSuffix, + e.state.Config.PkgName, + imps, + [][]byte{resp}, + ) if err != nil { return err } @@ -113,45 +138,7 @@ func generateSingletonOutput(state *State) error { return nil } -// generateSingletonTestOutput processes the templates that should only be run -// one time. -func generateSingletonTestOutput(state *State) error { - if state.SingletonTestTemplates == nil { - return errors.New("No singleton test templates located for generation") - } - - templateData := &templateData{ - PkgName: state.Config.PkgName, - DriverName: state.Config.DriverName, - } - - for _, template := range state.SingletonTestTemplates { - var imps imports - - resp, err := generateTemplate(template, templateData) - if err != nil { - return fmt.Errorf("Error generating test template %s: %s", template.Name(), err) - } - - fName := template.Name() - ext := filepath.Ext(fName) - fName = fName[0 : len(fName)-len(ext)] - - imps.standard = defaultSingletonTestTemplateImports[fName].standard - imps.thirdParty = defaultSingletonTestTemplateImports[fName].thirdParty - - fName = fName + "_test.go" - - err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp}) - if err != nil { - return err - } - } - - return nil -} - -func generateTestMainOutput(state *State) error { +func generateTestMainOutput(state *State, data *templateData) error { if state.TestMainTemplate == nil { return errors.New("No TestMain template located for generation") } @@ -162,13 +149,7 @@ func generateTestMainOutput(state *State) error { imps.standard = defaultTestMainImports[state.Config.DriverName].standard imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty - templateData := &templateData{ - Tables: state.Tables, - PkgName: state.Config.PkgName, - DriverName: state.Config.DriverName, - } - - resp, err := generateTemplate(state.TestMainTemplate, templateData) + resp, err := executeTemplate(state.TestMainTemplate, data) if err != nil { return err } @@ -214,8 +195,9 @@ func outHandler(outFolder string, fileName string, pkgName string, imps imports, return nil } -// generateTemplate takes a template and returns the output of the template execution. -func generateTemplate(t *template.Template, data *templateData) ([]byte, error) { +// executeTemplate takes a template and returns the output of the template +// execution. +func executeTemplate(t *template.Template, data *templateData) ([]byte, error) { var buf bytes.Buffer if err := t.Execute(&buf, data); err != nil { return nil, err diff --git a/sqlboiler.go b/sqlboiler.go index effe69c..1ba6202 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -70,16 +70,22 @@ func New(config *Config) (*State, error) { // Run executes the sqlboiler templates and outputs them to files based on the // state given. func (s *State) Run(includeTests bool) error { - if err := generateSingletonOutput(s); err != nil { + singletonData := &templateData{ + Tables: s.Tables, + DriverName: s.Config.DriverName, + PkgName: s.Config.PkgName, + } + + if err := generateSingletonOutput(s, singletonData); err != nil { return fmt.Errorf("Unable to generate singleton template output: %s", err) } if includeTests { - if err := generateTestMainOutput(s); err != nil { + if err := generateTestMainOutput(s, singletonData); err != nil { return fmt.Errorf("Unable to generate TestMain output: %s", err) } - if err := generateSingletonTestOutput(s); err != nil { + if err := generateSingletonTestOutput(s, singletonData); err != nil { return fmt.Errorf("Unable to generate singleton test template output: %s", err) } }