Add base_dir to make it less ornerous to use
This commit is contained in:
parent
278d9ab80a
commit
63fae21c51
4 changed files with 31 additions and 18 deletions
|
@ -5,6 +5,7 @@ type Config struct {
|
||||||
DriverName string `toml:"driver_name"`
|
DriverName string `toml:"driver_name"`
|
||||||
PkgName string `toml:"pkg_name"`
|
PkgName string `toml:"pkg_name"`
|
||||||
OutFolder string `toml:"out_folder"`
|
OutFolder string `toml:"out_folder"`
|
||||||
|
BaseDir string `toml:"base_dir"`
|
||||||
ExcludeTables []string `toml:"exclude"`
|
ExcludeTables []string `toml:"exclude"`
|
||||||
|
|
||||||
Postgres PostgresConfig `toml:"postgres"`
|
Postgres PostgresConfig `toml:"postgres"`
|
||||||
|
|
1
main.go
1
main.go
|
@ -62,6 +62,7 @@ func main() {
|
||||||
// Set up the cobra root command flags
|
// Set up the cobra root command flags
|
||||||
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
|
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
|
||||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
|
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
|
||||||
|
rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory templates and templates_test folders are")
|
||||||
rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package")
|
rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package")
|
||||||
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
|
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
|
||||||
|
|
||||||
|
|
32
sqlboiler.go
32
sqlboiler.go
|
@ -3,7 +3,9 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"go/build"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
@ -138,27 +140,32 @@ func (s *State) Cleanup() error {
|
||||||
func (s *State) initTemplates() error {
|
func (s *State) initTemplates() error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
s.Templates, err = loadTemplates(templatesDirectory)
|
basePath, err := getBasePath(s.Config.BaseDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SingletonTemplates, err = loadTemplates(templatesSingletonDirectory)
|
s.Templates, err = loadTemplates(filepath.Join(basePath, templatesDirectory))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
s.SingletonTemplates, err = loadTemplates(filepath.Join(basePath, templatesSingletonDirectory))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SingletonTestTemplates, err = loadTemplates(templatesSingletonTestDirectory)
|
s.TestTemplates, err = loadTemplates(filepath.Join(basePath, templatesTestDirectory))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, s.Config.DriverName+"_main.tpl")
|
s.SingletonTestTemplates, err = loadTemplates(filepath.Join(basePath, templatesSingletonTestDirectory))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.TestMainTemplate, err = loadTemplate(filepath.Join(basePath, templatesTestMainDirectory), s.Config.DriverName+"_main.tpl")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -166,6 +173,21 @@ func (s *State) initTemplates() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var basePackage = "github.com/vattle/sqlboiler"
|
||||||
|
|
||||||
|
func getBasePath(baseDirConfig string) (string, error) {
|
||||||
|
if len(baseDirConfig) > 0 {
|
||||||
|
return baseDirConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p, _ := build.Default.Import(basePackage, "", build.FindOnly)
|
||||||
|
if p != nil && len(p.Dir) > 0 {
|
||||||
|
return p.Dir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.Getwd()
|
||||||
|
}
|
||||||
|
|
||||||
// initDriver attempts to set the state Interface based off the passed in
|
// initDriver attempts to set the state Interface based off the passed in
|
||||||
// driver flag value. If an invalid flag string is provided an error is returned.
|
// driver flag value. If an invalid flag string is provided an error is returned.
|
||||||
func (s *State) initDriver(driverName string) error {
|
func (s *State) initDriver(driverName string) error {
|
||||||
|
|
15
templates.go
15
templates.go
|
@ -2,7 +2,6 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -73,12 +72,7 @@ func (t templateList) Templates() []string {
|
||||||
|
|
||||||
// loadTemplates loads all of the template files in the specified directory.
|
// loadTemplates loads all of the template files in the specified directory.
|
||||||
func loadTemplates(dir string) (*templateList, error) {
|
func loadTemplates(dir string) (*templateList, error) {
|
||||||
wd, err := os.Getwd()
|
pattern := filepath.Join(dir, "*.tpl")
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pattern := filepath.Join(wd, dir, "*.tpl")
|
|
||||||
tpl, err := template.New("").Funcs(templateFunctions).ParseGlob(pattern)
|
tpl, err := template.New("").Funcs(templateFunctions).ParseGlob(pattern)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -90,12 +84,7 @@ func loadTemplates(dir string) (*templateList, error) {
|
||||||
|
|
||||||
// loadTemplate loads a single template file.
|
// loadTemplate loads a single template file.
|
||||||
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
||||||
wd, err := os.Getwd()
|
pattern := filepath.Join(dir, filename)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pattern := filepath.Join(wd, dir, filename)
|
|
||||||
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
|
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Add table
Reference in a new issue