Add base_dir to make it less ornerous to use
This commit is contained in:
parent
278d9ab80a
commit
63fae21c51
4 changed files with 31 additions and 18 deletions
|
@ -5,6 +5,7 @@ type Config struct {
|
|||
DriverName string `toml:"driver_name"`
|
||||
PkgName string `toml:"pkg_name"`
|
||||
OutFolder string `toml:"out_folder"`
|
||||
BaseDir string `toml:"base_dir"`
|
||||
ExcludeTables []string `toml:"exclude"`
|
||||
|
||||
Postgres PostgresConfig `toml:"postgres"`
|
||||
|
|
1
main.go
1
main.go
|
@ -62,6 +62,7 @@ func main() {
|
|||
// Set up the cobra root command flags
|
||||
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
|
||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
|
||||
rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory templates and templates_test folders are")
|
||||
rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package")
|
||||
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
|
||||
|
||||
|
|
32
sqlboiler.go
32
sqlboiler.go
|
@ -3,7 +3,9 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"go/build"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
|
@ -138,27 +140,32 @@ func (s *State) Cleanup() error {
|
|||
func (s *State) initTemplates() error {
|
||||
var err error
|
||||
|
||||
s.Templates, err = loadTemplates(templatesDirectory)
|
||||
basePath, err := getBasePath(s.Config.BaseDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.SingletonTemplates, err = loadTemplates(templatesSingletonDirectory)
|
||||
s.Templates, err = loadTemplates(filepath.Join(basePath, templatesDirectory))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
||||
s.SingletonTemplates, err = loadTemplates(filepath.Join(basePath, templatesSingletonDirectory))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.SingletonTestTemplates, err = loadTemplates(templatesSingletonTestDirectory)
|
||||
s.TestTemplates, err = loadTemplates(filepath.Join(basePath, templatesTestDirectory))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, s.Config.DriverName+"_main.tpl")
|
||||
s.SingletonTestTemplates, err = loadTemplates(filepath.Join(basePath, templatesSingletonTestDirectory))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.TestMainTemplate, err = loadTemplate(filepath.Join(basePath, templatesTestMainDirectory), s.Config.DriverName+"_main.tpl")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -166,6 +173,21 @@ func (s *State) initTemplates() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
var basePackage = "github.com/vattle/sqlboiler"
|
||||
|
||||
func getBasePath(baseDirConfig string) (string, error) {
|
||||
if len(baseDirConfig) > 0 {
|
||||
return baseDirConfig, nil
|
||||
}
|
||||
|
||||
p, _ := build.Default.Import(basePackage, "", build.FindOnly)
|
||||
if p != nil && len(p.Dir) > 0 {
|
||||
return p.Dir, nil
|
||||
}
|
||||
|
||||
return os.Getwd()
|
||||
}
|
||||
|
||||
// initDriver attempts to set the state Interface based off the passed in
|
||||
// driver flag value. If an invalid flag string is provided an error is returned.
|
||||
func (s *State) initDriver(driverName string) error {
|
||||
|
|
15
templates.go
15
templates.go
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
@ -73,12 +72,7 @@ func (t templateList) Templates() []string {
|
|||
|
||||
// loadTemplates loads all of the template files in the specified directory.
|
||||
func loadTemplates(dir string) (*templateList, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pattern := filepath.Join(wd, dir, "*.tpl")
|
||||
pattern := filepath.Join(dir, "*.tpl")
|
||||
tpl, err := template.New("").Funcs(templateFunctions).ParseGlob(pattern)
|
||||
|
||||
if err != nil {
|
||||
|
@ -90,12 +84,7 @@ func loadTemplates(dir string) (*templateList, error) {
|
|||
|
||||
// loadTemplate loads a single template file.
|
||||
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pattern := filepath.Join(wd, dir, filename)
|
||||
pattern := filepath.Join(dir, filename)
|
||||
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
|
||||
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue