Add ability to replace a template

- This feature remains undocumented because it's not a good idea in most
  cases but it enables us to replace a template. This is especially
  useful when using sqlboiler as a library since testmain
  problematically loads config in it's own magical way, divorced from
  even the way sqlboiler in "normal" mode loads it. This enables
  replacement of that mechanism by replacing it's template.
This commit is contained in:
Aaron L 2017-01-15 21:21:04 -08:00
parent bfba60eaad
commit 761efee9f0
4 changed files with 64 additions and 8 deletions

View file

@ -198,7 +198,43 @@ func (s *State) initTemplates() error {
return err
}
s.TestMainTemplate, err = loadTemplate(filepath.Join(basePath, templatesTestMainDirectory), s.Config.DriverName+"_main.tpl")
testMain := s.Config.DriverName + "_main.tpl"
s.TestMainTemplate, err = loadTemplate(nil, testMain, filepath.Join(basePath, templatesTestMainDirectory, testMain))
if err != nil {
return err
}
}
return s.processReplacements()
}
// processReplacements loads any replacement templates
func (s *State) processReplacements() error {
for _, replace := range s.Config.Replacements {
fmt.Println(replace)
splits := strings.Split(replace, ":")
if len(splits) != 2 {
return errors.Errorf("replace parameters must have 2 arguments, given: %s", replace)
}
toReplace, replaceWith := splits[0], splits[1]
var err error
switch filepath.Dir(toReplace) {
case templatesDirectory:
_, err = loadTemplate(s.Templates.Template, toReplace, replaceWith)
case templatesSingletonDirectory:
_, err = loadTemplate(s.SingletonTemplates.Template, toReplace, replaceWith)
case templatesTestDirectory:
_, err = loadTemplate(s.TestTemplates.Template, toReplace, replaceWith)
case templatesSingletonTestDirectory:
_, err = loadTemplate(s.SingletonTestTemplates.Template, toReplace, replaceWith)
case templatesTestMainDirectory:
s.TestMainTemplate, err = loadTemplate(nil, toReplace, replaceWith)
default:
return errors.Errorf("replace file's directory not part of any known folder: %s", toReplace)
}
if err != nil {
return err
}

View file

@ -10,6 +10,7 @@ type Config struct {
WhitelistTables []string
BlacklistTables []string
Tags []string
Replacements []string
Debug bool
NoTests bool
NoHooks bool

View file

@ -2,11 +2,13 @@ package boilingcore
import (
"fmt"
"io/ioutil"
"path/filepath"
"sort"
"strings"
"text/template"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/bdb"
"github.com/vattle/sqlboiler/queries"
"github.com/vattle/sqlboiler/strmangle"
@ -109,16 +111,24 @@ func loadTemplates(dir string) (*templateList, error) {
return &templateList{Template: tpl}, err
}
// loadTemplate loads a single template file.
func loadTemplate(dir string, filename string) (*template.Template, error) {
pattern := filepath.Join(dir, filename)
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
// loadTemplate loads a single template, uses tpl as a base template if provided
// and creates a new base template if not.
func loadTemplate(tpl *template.Template, name, filename string) (*template.Template, error) {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
return nil, errors.Wrapf(err, "failed reading template file: %s", filename)
}
return tpl.Lookup(filename), err
if tpl == nil {
tpl = template.New(name)
} else {
tpl = tpl.New(name)
}
if tpl, err = tpl.Funcs(templateFunctions).Parse(string(b)); err != nil {
return nil, errors.Wrapf(err, "failed to parse template file: %s", filename)
}
return tpl, nil
}
// set is to stop duplication from named enums, allowing a template loop

View file

@ -79,6 +79,7 @@ func main() {
rootCmd.PersistentFlags().StringSliceP("blacklist", "b", nil, "Do not include these tables in your generated package")
rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package")
rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml")
rootCmd.PersistentFlags().StringSliceP("replace", "", nil, "Replace templates by directory: relpath/to_file.tpl:relpath/to_replacement.tpl")
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files")
rootCmd.PersistentFlags().BoolP("no-hooks", "", false, "Disable hooks feature for your models")
@ -163,6 +164,14 @@ func preRun(cmd *cobra.Command, args []string) error {
}
}
cmdConfig.Replacements = viper.GetStringSlice("replace")
if len(cmdConfig.Replacements) == 1 && strings.ContainsRune(cmdConfig.Replacements[0], ',') {
cmdConfig.Replacements, err = cmd.PersistentFlags().GetStringSlice("replace")
if err != nil {
return err
}
}
if driverName == "postgres" {
cmdConfig.Postgres = boilingcore.PostgresConfig{
User: viper.GetString("postgres.user"),