Add ability to replace a template
- This feature remains undocumented because it's not a good idea in most cases but it enables us to replace a template. This is especially useful when using sqlboiler as a library since testmain problematically loads config in it's own magical way, divorced from even the way sqlboiler in "normal" mode loads it. This enables replacement of that mechanism by replacing it's template.
This commit is contained in:
parent
bfba60eaad
commit
761efee9f0
4 changed files with 64 additions and 8 deletions
|
@ -198,7 +198,43 @@ func (s *State) initTemplates() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.TestMainTemplate, err = loadTemplate(filepath.Join(basePath, templatesTestMainDirectory), s.Config.DriverName+"_main.tpl")
|
testMain := s.Config.DriverName + "_main.tpl"
|
||||||
|
s.TestMainTemplate, err = loadTemplate(nil, testMain, filepath.Join(basePath, templatesTestMainDirectory, testMain))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.processReplacements()
|
||||||
|
}
|
||||||
|
|
||||||
|
// processReplacements loads any replacement templates
|
||||||
|
func (s *State) processReplacements() error {
|
||||||
|
for _, replace := range s.Config.Replacements {
|
||||||
|
fmt.Println(replace)
|
||||||
|
splits := strings.Split(replace, ":")
|
||||||
|
if len(splits) != 2 {
|
||||||
|
return errors.Errorf("replace parameters must have 2 arguments, given: %s", replace)
|
||||||
|
}
|
||||||
|
|
||||||
|
toReplace, replaceWith := splits[0], splits[1]
|
||||||
|
|
||||||
|
var err error
|
||||||
|
switch filepath.Dir(toReplace) {
|
||||||
|
case templatesDirectory:
|
||||||
|
_, err = loadTemplate(s.Templates.Template, toReplace, replaceWith)
|
||||||
|
case templatesSingletonDirectory:
|
||||||
|
_, err = loadTemplate(s.SingletonTemplates.Template, toReplace, replaceWith)
|
||||||
|
case templatesTestDirectory:
|
||||||
|
_, err = loadTemplate(s.TestTemplates.Template, toReplace, replaceWith)
|
||||||
|
case templatesSingletonTestDirectory:
|
||||||
|
_, err = loadTemplate(s.SingletonTestTemplates.Template, toReplace, replaceWith)
|
||||||
|
case templatesTestMainDirectory:
|
||||||
|
s.TestMainTemplate, err = loadTemplate(nil, toReplace, replaceWith)
|
||||||
|
default:
|
||||||
|
return errors.Errorf("replace file's directory not part of any known folder: %s", toReplace)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ type Config struct {
|
||||||
WhitelistTables []string
|
WhitelistTables []string
|
||||||
BlacklistTables []string
|
BlacklistTables []string
|
||||||
Tags []string
|
Tags []string
|
||||||
|
Replacements []string
|
||||||
Debug bool
|
Debug bool
|
||||||
NoTests bool
|
NoTests bool
|
||||||
NoHooks bool
|
NoHooks bool
|
||||||
|
|
|
@ -2,11 +2,13 @@ package boilingcore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/vattle/sqlboiler/bdb"
|
"github.com/vattle/sqlboiler/bdb"
|
||||||
"github.com/vattle/sqlboiler/queries"
|
"github.com/vattle/sqlboiler/queries"
|
||||||
"github.com/vattle/sqlboiler/strmangle"
|
"github.com/vattle/sqlboiler/strmangle"
|
||||||
|
@ -109,16 +111,24 @@ func loadTemplates(dir string) (*templateList, error) {
|
||||||
return &templateList{Template: tpl}, err
|
return &templateList{Template: tpl}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadTemplate loads a single template file.
|
// loadTemplate loads a single template, uses tpl as a base template if provided
|
||||||
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
// and creates a new base template if not.
|
||||||
pattern := filepath.Join(dir, filename)
|
func loadTemplate(tpl *template.Template, name, filename string) (*template.Template, error) {
|
||||||
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
|
b, err := ioutil.ReadFile(filename)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrapf(err, "failed reading template file: %s", filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tpl.Lookup(filename), err
|
if tpl == nil {
|
||||||
|
tpl = template.New(name)
|
||||||
|
} else {
|
||||||
|
tpl = tpl.New(name)
|
||||||
|
}
|
||||||
|
if tpl, err = tpl.Funcs(templateFunctions).Parse(string(b)); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "failed to parse template file: %s", filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tpl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// set is to stop duplication from named enums, allowing a template loop
|
// set is to stop duplication from named enums, allowing a template loop
|
||||||
|
|
9
main.go
9
main.go
|
@ -79,6 +79,7 @@ func main() {
|
||||||
rootCmd.PersistentFlags().StringSliceP("blacklist", "b", nil, "Do not include these tables in your generated package")
|
rootCmd.PersistentFlags().StringSliceP("blacklist", "b", nil, "Do not include these tables in your generated package")
|
||||||
rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package")
|
rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package")
|
||||||
rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml")
|
rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml")
|
||||||
|
rootCmd.PersistentFlags().StringSliceP("replace", "", nil, "Replace templates by directory: relpath/to_file.tpl:relpath/to_replacement.tpl")
|
||||||
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
|
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
|
||||||
rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files")
|
rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files")
|
||||||
rootCmd.PersistentFlags().BoolP("no-hooks", "", false, "Disable hooks feature for your models")
|
rootCmd.PersistentFlags().BoolP("no-hooks", "", false, "Disable hooks feature for your models")
|
||||||
|
@ -163,6 +164,14 @@ func preRun(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmdConfig.Replacements = viper.GetStringSlice("replace")
|
||||||
|
if len(cmdConfig.Replacements) == 1 && strings.ContainsRune(cmdConfig.Replacements[0], ',') {
|
||||||
|
cmdConfig.Replacements, err = cmd.PersistentFlags().GetStringSlice("replace")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if driverName == "postgres" {
|
if driverName == "postgres" {
|
||||||
cmdConfig.Postgres = boilingcore.PostgresConfig{
|
cmdConfig.Postgres = boilingcore.PostgresConfig{
|
||||||
User: viper.GetString("postgres.user"),
|
User: viper.GetString("postgres.user"),
|
||||||
|
|
Loading…
Reference in a new issue