Add base_dir to make it less ornerous to use

This commit is contained in:
Aaron L 2016-08-23 21:50:14 -07:00
parent 278d9ab80a
commit 63fae21c51
4 changed files with 31 additions and 18 deletions

View file

@ -5,6 +5,7 @@ type Config struct {
DriverName string `toml:"driver_name"`
PkgName string `toml:"pkg_name"`
OutFolder string `toml:"out_folder"`
BaseDir string `toml:"base_dir"`
ExcludeTables []string `toml:"exclude"`
Postgres PostgresConfig `toml:"postgres"`

View file

@ -62,6 +62,7 @@ func main() {
// Set up the cobra root command flags
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory templates and templates_test folders are")
rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package")
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")

View file

@ -3,7 +3,9 @@
package main
import (
"go/build"
"os"
"path/filepath"
"strings"
"text/template"
@ -138,27 +140,32 @@ func (s *State) Cleanup() error {
func (s *State) initTemplates() error {
var err error
s.Templates, err = loadTemplates(templatesDirectory)
basePath, err := getBasePath(s.Config.BaseDir)
if err != nil {
return err
}
s.SingletonTemplates, err = loadTemplates(templatesSingletonDirectory)
s.Templates, err = loadTemplates(filepath.Join(basePath, templatesDirectory))
if err != nil {
return err
}
s.TestTemplates, err = loadTemplates(templatesTestDirectory)
s.SingletonTemplates, err = loadTemplates(filepath.Join(basePath, templatesSingletonDirectory))
if err != nil {
return err
}
s.SingletonTestTemplates, err = loadTemplates(templatesSingletonTestDirectory)
s.TestTemplates, err = loadTemplates(filepath.Join(basePath, templatesTestDirectory))
if err != nil {
return err
}
s.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, s.Config.DriverName+"_main.tpl")
s.SingletonTestTemplates, err = loadTemplates(filepath.Join(basePath, templatesSingletonTestDirectory))
if err != nil {
return err
}
s.TestMainTemplate, err = loadTemplate(filepath.Join(basePath, templatesTestMainDirectory), s.Config.DriverName+"_main.tpl")
if err != nil {
return err
}
@ -166,6 +173,21 @@ func (s *State) initTemplates() error {
return nil
}
var basePackage = "github.com/vattle/sqlboiler"
func getBasePath(baseDirConfig string) (string, error) {
if len(baseDirConfig) > 0 {
return baseDirConfig, nil
}
p, _ := build.Default.Import(basePackage, "", build.FindOnly)
if p != nil && len(p.Dir) > 0 {
return p.Dir, nil
}
return os.Getwd()
}
// initDriver attempts to set the state Interface based off the passed in
// driver flag value. If an invalid flag string is provided an error is returned.
func (s *State) initDriver(driverName string) error {

View file

@ -2,7 +2,6 @@ package main
import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
@ -73,12 +72,7 @@ func (t templateList) Templates() []string {
// loadTemplates loads all of the template files in the specified directory.
func loadTemplates(dir string) (*templateList, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, "*.tpl")
pattern := filepath.Join(dir, "*.tpl")
tpl, err := template.New("").Funcs(templateFunctions).ParseGlob(pattern)
if err != nil {
@ -90,12 +84,7 @@ func loadTemplates(dir string) (*templateList, error) {
// loadTemplate loads a single template file.
func loadTemplate(dir string, filename string) (*template.Template, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, filename)
pattern := filepath.Join(dir, filename)
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
if err != nil {