Refactor entire project :D

- Move most files to root
- Remove cmds directory in favor of cmd directory with binary
- Remove all cobra from main
This commit is contained in:
Aaron L 2016-06-11 18:25:00 -07:00
parent 612b670048
commit 8757c8a184
39 changed files with 790 additions and 940 deletions

View file

@ -1,186 +0,0 @@
package cmds
import (
"fmt"
"os"
"text/template"
"github.com/BurntSushi/toml"
"github.com/nullbio/sqlboiler/strmangle"
)
// sqlBoilerTypeImports imports are only included in the template output if the database
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
var sqlBoilerTypeImports = map[string]imports{
"null.Float32": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Float64": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int8": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int16": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int32": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int64": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint8": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint16": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint32": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint64": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.String": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Bool": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Time": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"time.Time": imports{
standard: importList{`"time"`},
},
}
// sqlBoilerImports defines the list of default template imports.
var sqlBoilerImports = imports{
standard: importList{
`"errors"`,
`"fmt"`,
`"strings"`,
},
thirdparty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/nullbio/sqlboiler/boil/qs"`,
},
}
var sqlBoilerSinglesImports = map[string]imports{
"helpers": imports{
standard: importList{},
thirdparty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/nullbio/sqlboiler/boil/qs"`,
},
},
}
// sqlBoilerTestImports defines the list of default test template imports.
var sqlBoilerTestImports = imports{
standard: importList{
`"testing"`,
`"reflect"`,
},
thirdparty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
},
}
var sqlBoilerSinglesTestImports = map[string]imports{
"helper_funcs": imports{
standard: importList{
`"crypto/md5"`,
`"fmt"`,
`"os"`,
`"strconv"`,
`"math/rand"`,
},
thirdparty: importList{},
},
}
var sqlBoilerTestMainImports = map[string]imports{
"postgres": imports{
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"database/sql"`,
`"time"`,
`"math/rand"`,
},
thirdparty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/BurntSushi/toml"`,
`_ "github.com/lib/pq"`,
},
},
}
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
// If you wish to pass a new function into your own template, add a pointer to it here.
var sqlBoilerTemplateFuncs = template.FuncMap{
"singular": strmangle.Singular,
"plural": strmangle.Plural,
"titleCase": strmangle.TitleCase,
"titleCaseSingular": strmangle.TitleCaseSingular,
"titleCasePlural": strmangle.TitleCasePlural,
"titleCaseCommaList": strmangle.TitleCaseCommaList,
"camelCase": strmangle.CamelCase,
"camelCaseSingular": strmangle.CamelCaseSingular,
"camelCasePlural": strmangle.CamelCasePlural,
"camelCaseCommaList": strmangle.CamelCaseCommaList,
"columnsToStrings": strmangle.ColumnsToStrings,
"commaList": strmangle.CommaList,
"makeDBName": strmangle.MakeDBName,
"selectParamNames": strmangle.SelectParamNames,
"insertParamNames": strmangle.InsertParamNames,
"insertParamFlags": strmangle.InsertParamFlags,
"insertParamVariables": strmangle.InsertParamVariables,
"scanParamNames": strmangle.ScanParamNames,
"hasPrimaryKey": strmangle.HasPrimaryKey,
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
"wherePrimaryKey": strmangle.WherePrimaryKey,
"paramsPrimaryKey": strmangle.ParamsPrimaryKey,
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
"updateParamNames": strmangle.UpdateParamNames,
"updateParamVariables": strmangle.UpdateParamVariables,
"supportsResultObject": strmangle.SupportsResultObject,
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
"filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement,
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
"randDBStruct": strmangle.RandDBStruct,
"randDBStructSlice": strmangle.RandDBStructSlice,
}
// LoadConfigFile loads the toml config file into the cfg object
func (c *CmdData) LoadConfigFile(filename string) error {
cfg := &Config{}
_, err := toml.DecodeFile(filename, &cfg)
if os.IsNotExist(err) {
return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err)
}
if err != nil {
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
}
c.Config = cfg
return nil
}

View file

@ -1,35 +0,0 @@
package cmds
import (
"io/ioutil"
"os"
"testing"
)
func TestLoadConfig(t *testing.T) {
t.Parallel()
cmdData := &CmdData{}
file, _ := ioutil.TempFile(os.TempDir(), "sqlboilercfgtest")
defer os.Remove(file.Name())
fContents := `[postgres]
host="localhost"
port=5432
user="user"
pass="pass"
dbname="mydb"`
file.WriteString(fContents)
err := cmdData.LoadConfigFile(file.Name())
if err != nil {
t.Errorf("Unable to load config file: %s", err)
}
if cmdData.Config.Postgres.Host != "localhost" ||
cmdData.Config.Postgres.User != "user" || cmdData.Config.Postgres.Pass != "pass" ||
cmdData.Config.Postgres.DBName != "mydb" || cmdData.Config.Postgres.Port != 5432 {
t.Errorf("Config failed to load properly, got: %#v", cmdData.Config.Postgres)
}
}

View file

@ -1,114 +0,0 @@
package cmds
import (
"bytes"
"fmt"
"sort"
"github.com/nullbio/sqlboiler/dbdrivers"
)
func combineImports(a, b imports) imports {
var c imports
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
sort.Sort(c.standard)
sort.Sort(c.thirdparty)
return c
}
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
tmpImp := imports{
standard: make(importList, len(a.standard)),
thirdparty: make(importList, len(a.thirdparty)),
}
copy(tmpImp.standard, a.standard)
copy(tmpImp.thirdparty, a.thirdparty)
for _, col := range columns {
for key, imp := range b {
if col.Type == key {
tmpImp.standard = append(tmpImp.standard, imp.standard...)
tmpImp.thirdparty = append(tmpImp.thirdparty, imp.thirdparty...)
}
}
}
tmpImp.standard = removeDuplicates(tmpImp.standard)
tmpImp.thirdparty = removeDuplicates(tmpImp.thirdparty)
sort.Sort(tmpImp.standard)
sort.Sort(tmpImp.thirdparty)
return tmpImp
}
func buildImportString(imps imports) []byte {
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
if stdlen+thirdlen < 1 {
return []byte{}
}
if stdlen+thirdlen == 1 {
var imp string
if stdlen == 1 {
imp = imps.standard[0]
} else {
imp = imps.thirdparty[0]
}
return []byte(fmt.Sprintf("import %s", imp))
}
buf := &bytes.Buffer{}
buf.WriteString("import (")
for _, std := range imps.standard {
fmt.Fprintf(buf, "\n\t%s", std)
}
if stdlen != 0 && thirdlen != 0 {
buf.WriteString("\n")
}
for _, third := range imps.thirdparty {
fmt.Fprintf(buf, "\n\t%s", third)
}
buf.WriteString("\n)\n")
return buf.Bytes()
}
func combineStringSlices(a, b []string) []string {
c := make([]string, len(a)+len(b))
if len(a) > 0 {
copy(c, a)
}
if len(b) > 0 {
copy(c[len(a):], b)
}
return c
}
func removeDuplicates(dedup []string) []string {
if len(dedup) <= 1 {
return dedup
}
for i := 0; i < len(dedup)-1; i++ {
for j := i + 1; j < len(dedup); j++ {
if dedup[i] != dedup[j] {
continue
}
if j != len(dedup)-1 {
dedup[j] = dedup[len(dedup)-1]
j--
}
dedup = dedup[:len(dedup)-1]
}
}
return dedup
}

View file

@ -1,42 +0,0 @@
package cmds
import "strings"
func (i importList) Len() int {
return len(i)
}
func (i importList) Swap(k, j int) {
i[k], i[j] = i[j], i[k]
}
func (i importList) Less(k, j int) bool {
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
if res <= 0 {
return true
}
return false
}
func (t templater) Len() int {
return len(t)
}
func (t templater) Swap(k, j int) {
t[k], t[j] = t[j], t[k]
}
func (t templater) Less(k, j int) bool {
// Make sure "struct" goes to the front
if t[k].Name() == "struct.tpl" {
return true
}
res := strings.Compare(t[k].Name(), t[j].Name())
if res <= 0 {
return true
}
return false
}

View file

@ -1,80 +0,0 @@
package cmds
import (
"reflect"
"sort"
"testing"
"text/template"
)
func TestSortImports(t *testing.T) {
t.Parallel()
a1 := importList{
`"fmt"`,
`"errors"`,
}
a2 := importList{
`_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`,
`"github.com/gorilla/websocket"`,
}
a1Expected := importList{`"errors"`, `"fmt"`}
a2Expected := importList{
`"github.com/gorilla/mux"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/websocket"`,
`_ "github.com/lib/pq"`,
}
sort.Sort(a1)
if !reflect.DeepEqual(a1, a1Expected) {
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
}
for i, v := range a1 {
if v != a1Expected[i] {
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
sort.Sort(a2)
if !reflect.DeepEqual(a2, a2Expected) {
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
}
for i, v := range a2 {
if v != a2Expected[i] {
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
}
func TestSortTemplates(t *testing.T) {
templs := templater{
template.New("bob.tpl"),
template.New("all.tpl"),
template.New("struct.tpl"),
template.New("ttt.tpl"),
}
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
for i, v := range templs {
if v.Name() != expected[i] {
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
}
}
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
sort.Sort(templs)
for i, v := range templs {
if v.Name() != expected[i] {
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
}
}
}

View file

@ -1,277 +0,0 @@
package cmds
import (
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"text/template"
"github.com/nullbio/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
)
const (
templatesDirectory = "cmds/templates"
templatesSinglesDirectory = "cmds/templates/singleton"
templatesTestDirectory = "cmds/templates_test"
templatesSinglesTestDirectory = "cmds/templates_test/singleton"
templatesTestMainDirectory = "cmds/templates_test/main_test"
)
// LoadTemplates loads all template folders into the cmdData object.
func initTemplates(cmdData *CmdData) error {
var err error
cmdData.Templates, err = loadTemplates(templatesDirectory)
if err != nil {
return err
}
cmdData.SingleTemplates, err = loadTemplates(templatesSinglesDirectory)
if err != nil {
return err
}
cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
if err != nil {
return err
}
cmdData.SingleTestTemplates, err = loadTemplates(templatesSinglesTestDirectory)
if err != nil {
return err
}
filename := cmdData.DriverName + "_main.tpl"
cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename)
if err != nil {
return err
}
return nil
}
// loadTemplates loads all of the template files in the specified directory.
func loadTemplates(dir string) ([]*template.Template, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, "*.tpl")
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern)
if err != nil {
return nil, err
}
templates := templater(tpl.Templates())
sort.Sort(templates)
return templates, err
}
// loadTemplate loads a single template file.
func loadTemplate(dir string, filename string) (*template.Template, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, filename)
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseFiles(pattern)
if err != nil {
return nil, err
}
return tpl.Lookup(filename), err
}
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
func (c *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error {
c.Interface.Close()
return nil
}
// SQLBoilerPreRun performs the initialization tasks before the root command is run
func (c *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
// Initialize package name
pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String()
// Retrieve driver flag
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
if driverName == "" {
return errors.New("Must supply a driver flag.")
}
tableName := cmd.PersistentFlags().Lookup("table").Value.String()
outFolder := cmd.PersistentFlags().Lookup("folder").Value.String()
if outFolder == "" {
return fmt.Errorf("No output folder specified.")
}
return c.initCmdData(pkgName, driverName, tableName, outFolder)
}
// SQLBoilerRun is a proxy method for the run function
func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
return c.run(true)
}
// run executes the sqlboiler templates and outputs them to files.
func (c *CmdData) run(includeTests bool) error {
if err := generateSingletonOutput(c); err != nil {
return fmt.Errorf("Unable to generate singleton template output: %s", err)
}
if includeTests {
if err := generateTestMainOutput(c); err != nil {
return fmt.Errorf("Unable to generate TestMain output: %s", err)
}
if err := generateSingletonTestOutput(c); err != nil {
return fmt.Errorf("Unable to generate singleton test template output: %s", err)
}
}
for _, table := range c.Tables {
if table.IsJoinTable {
continue
}
data := &tplData{
Table: table,
DriverName: c.DriverName,
PkgName: c.PkgName,
}
// Generate the regular templates
if err := generateOutput(c, data); err != nil {
return fmt.Errorf("Unable to generate output: %s", err)
}
// Generate the test templates
if includeTests {
if err := generateTestOutput(c, data); err != nil {
return fmt.Errorf("Unable to generate test output: %s", err)
}
}
}
return nil
}
func (c *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error {
c.OutFolder = outFolder
c.PkgName = pkgName
err := initInterface(driverName, c)
if err != nil {
return err
}
// Connect to the driver database
if err = c.Interface.Open(); err != nil {
return fmt.Errorf("Unable to connect to the database: %s", err)
}
err = initTables(tableName, c)
if err != nil {
return fmt.Errorf("Unable to initialize tables: %s", err)
}
err = initOutFolder(c)
if err != nil {
return fmt.Errorf("Unable to initialize the output folder: %s", err)
}
err = initTemplates(c)
if err != nil {
return fmt.Errorf("Unable to initialize templates: %s", err)
}
return nil
}
// initInterface attempts to set the cmdData Interface based off the passed in
// driver flag value. If an invalid flag string is provided an error is returned.
func initInterface(driverName string, cmdData *CmdData) error {
// Create a driver based off driver flag
switch driverName {
case "postgres":
cmdData.Interface = dbdrivers.NewPostgresDriver(
cmdData.Config.Postgres.User,
cmdData.Config.Postgres.Pass,
cmdData.Config.Postgres.DBName,
cmdData.Config.Postgres.Host,
cmdData.Config.Postgres.Port,
)
}
if cmdData.Interface == nil {
return errors.New("An invalid driver name was provided")
}
cmdData.DriverName = driverName
return nil
}
// initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that result.
func initTables(tableName string, cmdData *CmdData) error {
var tableNames []string
if len(tableName) != 0 {
tableNames = strings.Split(tableName, ",")
for i, name := range tableNames {
tableNames[i] = strings.TrimSpace(name)
}
}
var err error
cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...)
if err != nil {
return fmt.Errorf("Unable to get all table names: %s", err)
}
if len(cmdData.Tables) == 0 {
return errors.New("No tables found in database, migrate some tables first")
}
if err := checkPKeys(cmdData.Tables); err != nil {
return err
}
return nil
}
// checkPKeys ensures every table has a primary key column
func checkPKeys(tables []dbdrivers.Table) error {
var missingPkey []string
for _, t := range tables {
if t.PKey == nil {
missingPkey = append(missingPkey, t.Name)
}
}
if len(missingPkey) != 0 {
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
}
return nil
}
// initOutFolder creates the folder that will hold the generated output.
func initOutFolder(cmdData *CmdData) error {
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
return fmt.Errorf("Unable to make output folder: %s", err)
}
return nil
}

View file

@ -1,67 +0,0 @@
package cmds
import (
"text/template"
"github.com/nullbio/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
)
// CobraRunFunc declares the cobra.Command.Run function definition
type CobraRunFunc func(cmd *cobra.Command, args []string)
type templater []*template.Template
// CmdData holds the table schema a slice of (column name, column type) slices.
// It also holds a slice of all of the table names sqlboiler is generating against,
// the database driver chosen by the driver flag at runtime, and a pointer to the
// output file, if one is specified with a flag.
type CmdData struct {
Tables []dbdrivers.Table
PkgName string
OutFolder string
Interface dbdrivers.Interface
DriverName string
Config *Config
Templates templater
// SingleTemplates are only created once, not per table
SingleTemplates templater
TestTemplates templater
// SingleTestTemplates are only created once, not per table
SingleTestTemplates templater
//TestMainTemplate is only created once, not per table
TestMainTemplate *template.Template
}
// tplData is used to pass data to the template
type tplData struct {
Table dbdrivers.Table
DriverName string
PkgName string
Tables []string
}
type importList []string
// imports defines the optional standard imports and
// thirdparty imports (from github for example)
type imports struct {
standard importList
thirdparty importList
}
// PostgresCfg configures a postgres database
type PostgresCfg struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
}
// Config is loaded from a file
type Config struct {
Postgres PostgresCfg `toml:"postgres"`
}

41
config.go Normal file
View file

@ -0,0 +1,41 @@
package sqlboiler
import (
"text/template"
"github.com/nullbio/sqlboiler/dbdrivers"
)
// State holds the global data needed by most pieces to run
type State struct {
Config *Config
Driver dbdrivers.Interface
Tables []dbdrivers.Table
Templates templateList
TestTemplates templateList
SingletonTemplates templateList
SingletonTestTemplates templateList
TestMainTemplate *template.Template
}
// Config for the running of the commands
type Config struct {
DriverName string `toml:"driver_name"`
PkgName string `toml:"pkg_name"`
OutFolder string `toml:"out_folder"`
TableName string `toml:"table_name"`
Postgres PostgresConfig `toml:"postgres"`
}
// PostgresConfig configures a postgres database
type PostgresConfig struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
}

262
imports.go Normal file
View file

@ -0,0 +1,262 @@
package sqlboiler
import (
"bytes"
"fmt"
"sort"
"strings"
"github.com/nullbio/sqlboiler/dbdrivers"
)
// imports defines the optional standard imports and
// thirdParty imports (from github for example)
type imports struct {
standard importList
thirdParty importList
}
// importList is a list of import names
type importList []string
func (i importList) Len() int {
return len(i)
}
func (i importList) Swap(k, j int) {
i[k], i[j] = i[j], i[k]
}
func (i importList) Less(k, j int) bool {
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
if res <= 0 {
return true
}
return false
}
func combineImports(a, b imports) imports {
var c imports
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
c.thirdParty = removeDuplicates(combineStringSlices(a.thirdParty, b.thirdParty))
sort.Sort(c.standard)
sort.Sort(c.thirdParty)
return c
}
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
tmpImp := imports{
standard: make(importList, len(a.standard)),
thirdParty: make(importList, len(a.thirdParty)),
}
copy(tmpImp.standard, a.standard)
copy(tmpImp.thirdParty, a.thirdParty)
for _, col := range columns {
for key, imp := range b {
if col.Type == key {
tmpImp.standard = append(tmpImp.standard, imp.standard...)
tmpImp.thirdParty = append(tmpImp.thirdParty, imp.thirdParty...)
}
}
}
tmpImp.standard = removeDuplicates(tmpImp.standard)
tmpImp.thirdParty = removeDuplicates(tmpImp.thirdParty)
sort.Sort(tmpImp.standard)
sort.Sort(tmpImp.thirdParty)
return tmpImp
}
func buildImportString(imps imports) []byte {
stdlen, thirdlen := len(imps.standard), len(imps.thirdParty)
if stdlen+thirdlen < 1 {
return []byte{}
}
if stdlen+thirdlen == 1 {
var imp string
if stdlen == 1 {
imp = imps.standard[0]
} else {
imp = imps.thirdParty[0]
}
return []byte(fmt.Sprintf("import %s", imp))
}
buf := &bytes.Buffer{}
buf.WriteString("import (")
for _, std := range imps.standard {
fmt.Fprintf(buf, "\n\t%s", std)
}
if stdlen != 0 && thirdlen != 0 {
buf.WriteString("\n")
}
for _, third := range imps.thirdParty {
fmt.Fprintf(buf, "\n\t%s", third)
}
buf.WriteString("\n)\n")
return buf.Bytes()
}
func combineStringSlices(a, b []string) []string {
c := make([]string, len(a)+len(b))
if len(a) > 0 {
copy(c, a)
}
if len(b) > 0 {
copy(c[len(a):], b)
}
return c
}
func removeDuplicates(dedup []string) []string {
if len(dedup) <= 1 {
return dedup
}
for i := 0; i < len(dedup)-1; i++ {
for j := i + 1; j < len(dedup); j++ {
if dedup[i] != dedup[j] {
continue
}
if j != len(dedup)-1 {
dedup[j] = dedup[len(dedup)-1]
j--
}
dedup = dedup[:len(dedup)-1]
}
}
return dedup
}
var defaultTemplateImports = imports{
standard: importList{
`"errors"`,
`"fmt"`,
`"strings"`,
},
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/nullbio/sqlboiler/boil/qs"`,
},
}
var defaultSingletonTemplateImports = map[string]imports{
"helpers": imports{
standard: importList{},
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/nullbio/sqlboiler/boil/qs"`,
},
},
}
var defaultTestTemplateImports = imports{
standard: importList{
`"testing"`,
`"reflect"`,
},
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
},
}
var defaultSingletonTestTemplateImports = map[string]imports{
"helper_funcs": imports{
standard: importList{
`"crypto/md5"`,
`"fmt"`,
`"os"`,
`"strconv"`,
`"math/rand"`,
},
thirdParty: importList{},
},
}
var defaultTestMainImports = map[string]imports{
"postgres": imports{
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"database/sql"`,
`"time"`,
`"math/rand"`,
},
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/BurntSushi/toml"`,
`_ "github.com/lib/pq"`,
},
},
}
// importsBasedOnType imports are only included in the template output if the
// database requires one of the following special types. Check
// TranslateColumnType to see the type assignments.
var importsBasedOnType = map[string]imports{
"null.Float32": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Float64": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int8": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int16": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int32": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int64": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint8": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint16": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint32": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint64": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.String": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Bool": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Time": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"time.Time": imports{
standard: importList{`"time"`},
},
}

View file

@ -1,20 +1,66 @@
package cmds package sqlboiler
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"sort"
"testing" "testing"
"github.com/nullbio/sqlboiler/dbdrivers" "github.com/nullbio/sqlboiler/dbdrivers"
) )
func TestImportsSort(t *testing.T) {
t.Parallel()
a1 := importList{
`"fmt"`,
`"errors"`,
}
a2 := importList{
`_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`,
`"github.com/gorilla/websocket"`,
}
a1Expected := importList{`"errors"`, `"fmt"`}
a2Expected := importList{
`"github.com/gorilla/mux"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/websocket"`,
`_ "github.com/lib/pq"`,
}
sort.Sort(a1)
if !reflect.DeepEqual(a1, a1Expected) {
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
}
for i, v := range a1 {
if v != a1Expected[i] {
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
sort.Sort(a2)
if !reflect.DeepEqual(a2, a2Expected) {
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
}
for i, v := range a2 {
if v != a2Expected[i] {
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
}
func TestCombineTypeImports(t *testing.T) { func TestCombineTypeImports(t *testing.T) {
imports1 := imports{ imports1 := imports{
standard: importList{ standard: importList{
`"errors"`, `"errors"`,
`"fmt"`, `"fmt"`,
}, },
thirdparty: importList{ thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`, `"github.com/nullbio/sqlboiler/boil"`,
}, },
} }
@ -25,7 +71,7 @@ func TestCombineTypeImports(t *testing.T) {
`"fmt"`, `"fmt"`,
`"time"`, `"time"`,
}, },
thirdparty: importList{ thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`, `"github.com/nullbio/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v4"`, `"gopkg.in/nullbio/null.v4"`,
}, },
@ -46,7 +92,7 @@ func TestCombineTypeImports(t *testing.T) {
}, },
} }
res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols) res1 := combineTypeImports(imports1, importsBasedOnType, cols)
if !reflect.DeepEqual(res1, importsExpected) { if !reflect.DeepEqual(res1, importsExpected) {
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1) t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
@ -58,13 +104,13 @@ func TestCombineTypeImports(t *testing.T) {
`"fmt"`, `"fmt"`,
`"time"`, `"time"`,
}, },
thirdparty: importList{ thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`, `"github.com/nullbio/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v4"`, `"gopkg.in/nullbio/null.v4"`,
}, },
} }
res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols) res2 := combineTypeImports(imports2, importsBasedOnType, cols)
if !reflect.DeepEqual(res2, importsExpected) { if !reflect.DeepEqual(res2, importsExpected) {
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1) t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
@ -76,11 +122,11 @@ func TestCombineImports(t *testing.T) {
a := imports{ a := imports{
standard: importList{"fmt"}, standard: importList{"fmt"},
thirdparty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"}, thirdParty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"},
} }
b := imports{ b := imports{
standard: importList{"os"}, standard: importList{"os"},
thirdparty: importList{"github.com/nullbio/sqlboiler"}, thirdParty: importList{"github.com/nullbio/sqlboiler"},
} }
c := combineImports(a, b) c := combineImports(a, b)
@ -88,8 +134,8 @@ func TestCombineImports(t *testing.T) {
if c.standard[0] != "fmt" && c.standard[1] != "os" { if c.standard[0] != "fmt" && c.standard[1] != "os" {
t.Errorf("Wanted: fmt, os got: %#v", c.standard) t.Errorf("Wanted: fmt, os got: %#v", c.standard)
} }
if c.thirdparty[0] != "github.com/nullbio/sqlboiler" && c.thirdparty[1] != "gopkg.in/nullbio/null.v4" { if c.thirdParty[0] != "github.com/nullbio/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v4" {
t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdparty) t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdParty)
} }
} }

54
main.go
View file

@ -1,54 +0,0 @@
/*
SQLBoiler is a tool to generate Go boilerplate code for database interactions.
So far this includes struct definitions and database statement helper functions.
*/
package main
import (
"fmt"
"os"
"github.com/nullbio/sqlboiler/cmds"
"github.com/spf13/cobra"
)
func main() {
var err error
cmdData := &cmds.CmdData{}
// Load the "config.toml" global config
err = cmdData.LoadConfigFile("config.toml")
if err != nil {
fmt.Printf("Failed to load config file: %s\n", err)
os.Exit(-1)
}
// Set up the cobra root command
var rootCmd = &cobra.Command{
Use: "sqlboiler",
Short: "SQL Boiler generates boilerplate structs and statements",
Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" +
`Complete documentation is available at http://github.com/nullbio/sqlboiler`,
PreRunE: func(cmd *cobra.Command, args []string) error {
return cmdData.SQLBoilerPreRun(cmd, args)
},
RunE: func(cmd *cobra.Command, args []string) error {
return cmdData.SQLBoilerRun(cmd, args)
},
PostRunE: func(cmd *cobra.Command, args []string) error {
return cmdData.SQLBoilerPostRun(cmd, args)
},
}
// Set up the cobra root command flags
rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder")
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
// Execute SQLBoiler
if err := rootCmd.Execute(); err != nil {
os.Exit(-1)
}
}

View file

@ -1,4 +1,4 @@
package cmds package sqlboiler
import ( import (
"bytes" "bytes"
@ -18,18 +18,18 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
} }
// generateOutput builds the file output and sends it to outHandler for saving // generateOutput builds the file output and sends it to outHandler for saving
func generateOutput(cmdData *CmdData, data *tplData) error { func generateOutput(state *State, data *templateData) error {
if len(cmdData.Templates) == 0 { if len(state.Templates) == 0 {
return errors.New("No template files located for generation") return errors.New("No template files located for generation")
} }
var out [][]byte var out [][]byte
var imps imports var imps imports
imps.standard = sqlBoilerImports.standard imps.standard = defaultTemplateImports.standard
imps.thirdparty = sqlBoilerImports.thirdparty imps.thirdParty = defaultTemplateImports.thirdParty
for _, template := range cmdData.Templates { for _, template := range state.Templates {
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns) imps = combineTypeImports(imps, importsBasedOnType, data.Table.Columns)
resp, err := generateTemplate(template, data) resp, err := generateTemplate(template, data)
if err != nil { if err != nil {
return fmt.Errorf("Error generating template %s: %s", template.Name(), err) return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
@ -38,7 +38,7 @@ func generateOutput(cmdData *CmdData, data *tplData) error {
} }
fName := data.Table.Name + ".go" fName := data.Table.Name + ".go"
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out) err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out)
if err != nil { if err != nil {
return err return err
} }
@ -47,18 +47,18 @@ func generateOutput(cmdData *CmdData, data *tplData) error {
} }
// generateTestOutput builds the test file output and sends it to outHandler for saving // generateTestOutput builds the test file output and sends it to outHandler for saving
func generateTestOutput(cmdData *CmdData, data *tplData) error { func generateTestOutput(state *State, data *templateData) error {
if len(cmdData.TestTemplates) == 0 { if len(state.TestTemplates) == 0 {
return errors.New("No template files located for generation") return errors.New("No template files located for generation")
} }
var out [][]byte var out [][]byte
var imps imports var imps imports
imps.standard = sqlBoilerTestImports.standard imps.standard = defaultTestTemplateImports.standard
imps.thirdparty = sqlBoilerTestImports.thirdparty imps.thirdParty = defaultTestTemplateImports.thirdParty
for _, template := range cmdData.TestTemplates { for _, template := range state.TestTemplates {
resp, err := generateTemplate(template, data) resp, err := generateTemplate(template, data)
if err != nil { if err != nil {
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err) return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
@ -67,7 +67,7 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error {
} }
fName := data.Table.Name + "_test.go" fName := data.Table.Name + "_test.go"
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out) err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out)
if err != nil { if err != nil {
return err return err
} }
@ -77,20 +77,20 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error {
// generateSingletonOutput processes the templates that should only be run // generateSingletonOutput processes the templates that should only be run
// one time. // one time.
func generateSingletonOutput(cmdData *CmdData) error { func generateSingletonOutput(state *State) error {
if cmdData.SingleTemplates == nil { if state.SingletonTemplates == nil {
return errors.New("No singleton templates located for generation") return errors.New("No singleton templates located for generation")
} }
tplData := &tplData{ templateData := &templateData{
PkgName: cmdData.PkgName, PkgName: state.Config.PkgName,
DriverName: cmdData.DriverName, DriverName: state.Config.DriverName,
} }
for _, template := range cmdData.SingleTemplates { for _, template := range state.SingletonTemplates {
var imps imports var imps imports
resp, err := generateTemplate(template, tplData) resp, err := generateTemplate(template, templateData)
if err != nil { if err != nil {
return fmt.Errorf("Error generating template %s: %s", template.Name(), err) return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
} }
@ -99,12 +99,12 @@ func generateSingletonOutput(cmdData *CmdData) error {
ext := filepath.Ext(fName) ext := filepath.Ext(fName)
fName = fName[0 : len(fName)-len(ext)] fName = fName[0 : len(fName)-len(ext)]
imps.standard = sqlBoilerSinglesImports[fName].standard imps.standard = defaultSingletonTemplateImports[fName].standard
imps.thirdparty = sqlBoilerSinglesImports[fName].thirdparty imps.thirdParty = defaultSingletonTemplateImports[fName].thirdParty
fName = fName + ".go" fName = fName + ".go"
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp}) err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp})
if err != nil { if err != nil {
return err return err
} }
@ -115,20 +115,20 @@ func generateSingletonOutput(cmdData *CmdData) error {
// generateSingletonTestOutput processes the templates that should only be run // generateSingletonTestOutput processes the templates that should only be run
// one time. // one time.
func generateSingletonTestOutput(cmdData *CmdData) error { func generateSingletonTestOutput(state *State) error {
if cmdData.SingleTestTemplates == nil { if state.SingletonTestTemplates == nil {
return errors.New("No singleton test templates located for generation") return errors.New("No singleton test templates located for generation")
} }
tplData := &tplData{ templateData := &templateData{
PkgName: cmdData.PkgName, PkgName: state.Config.PkgName,
DriverName: cmdData.DriverName, DriverName: state.Config.DriverName,
} }
for _, template := range cmdData.SingleTestTemplates { for _, template := range state.SingletonTestTemplates {
var imps imports var imps imports
resp, err := generateTemplate(template, tplData) resp, err := generateTemplate(template, templateData)
if err != nil { if err != nil {
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err) return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
} }
@ -137,12 +137,12 @@ func generateSingletonTestOutput(cmdData *CmdData) error {
ext := filepath.Ext(fName) ext := filepath.Ext(fName)
fName = fName[0 : len(fName)-len(ext)] fName = fName[0 : len(fName)-len(ext)]
imps.standard = sqlBoilerSinglesTestImports[fName].standard imps.standard = defaultSingletonTestTemplateImports[fName].standard
imps.thirdparty = sqlBoilerSinglesTestImports[fName].thirdparty imps.thirdParty = defaultSingletonTestTemplateImports[fName].thirdParty
fName = fName + "_test.go" fName = fName + "_test.go"
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp}) err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp})
if err != nil { if err != nil {
return err return err
} }
@ -151,35 +151,30 @@ func generateSingletonTestOutput(cmdData *CmdData) error {
return nil return nil
} }
func generateTestMainOutput(cmdData *CmdData) error { func generateTestMainOutput(state *State) error {
if cmdData.TestMainTemplate == nil { if state.TestMainTemplate == nil {
return errors.New("No TestMain template located for generation") return errors.New("No TestMain template located for generation")
} }
var out [][]byte var out [][]byte
var imps imports var imps imports
imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard imps.standard = defaultTestMainImports[state.Config.DriverName].standard
imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty
var tables []string templateData := &templateData{
for _, v := range cmdData.Tables { Tables: state.Tables,
tables = append(tables, v.Name) PkgName: state.Config.PkgName,
DriverName: state.Config.DriverName,
} }
tplData := &tplData{ resp, err := generateTemplate(state.TestMainTemplate, templateData)
PkgName: cmdData.PkgName,
DriverName: cmdData.DriverName,
Tables: tables,
}
resp, err := generateTemplate(cmdData.TestMainTemplate, tplData)
if err != nil { if err != nil {
return err return err
} }
out = append(out, resp) out = append(out, resp)
err = outHandler(cmdData.OutFolder, "main_test.go", cmdData.PkgName, imps, out) err = outHandler(state.Config.OutFolder, "main_test.go", state.Config.PkgName, imps, out)
if err != nil { if err != nil {
return err return err
} }
@ -220,7 +215,7 @@ func outHandler(outFolder string, fileName string, pkgName string, imps imports,
} }
// generateTemplate takes a template and returns the output of the template execution. // generateTemplate takes a template and returns the output of the template execution.
func generateTemplate(t *template.Template, data *tplData) ([]byte, error) { func generateTemplate(t *template.Template, data *templateData) ([]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil { if err := t.Execute(&buf, data); err != nil {
return nil, err return nil, err

View file

@ -1,4 +1,4 @@
package cmds package sqlboiler
import ( import (
"bytes" "bytes"
@ -78,7 +78,7 @@ func TestOutHandlerFiles(t *testing.T) {
} }
a2 := imports{ a2 := imports{
thirdparty: []string{ thirdParty: []string{
`"github.com/spf13/cobra"`, `"github.com/spf13/cobra"`,
}, },
} }
@ -96,7 +96,7 @@ func TestOutHandlerFiles(t *testing.T) {
`"fmt"`, `"fmt"`,
`"errors"`, `"errors"`,
}, },
thirdparty: importList{ thirdParty: importList{
`_ "github.com/lib/pq"`, `_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`, `_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`, `"github.com/gorilla/mux"`,
@ -106,7 +106,7 @@ func TestOutHandlerFiles(t *testing.T) {
file = &bytes.Buffer{} file = &bytes.Buffer{}
sort.Sort(a3.standard) sort.Sort(a3.standard)
sort.Sort(a3.thirdparty) sort.Sort(a3.thirdParty)
if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil { if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil {
t.Error(err) t.Error(err)

207
sqlboiler.go Normal file
View file

@ -0,0 +1,207 @@
// Package sqlboiler has types and methods useful for generating code that
// acts as a fully dynamic ORM might.
package sqlboiler
import (
"errors"
"fmt"
"os"
"strings"
"github.com/nullbio/sqlboiler/dbdrivers"
)
const (
templatesDirectory = "cmds/templates"
templatesSingletonDirectory = "cmds/templates/singleton"
templatesTestDirectory = "cmds/templates_test"
templatesSingletonTestDirectory = "cmds/templates_test/singleton"
)
// New creates a new state based off of the config
func New(config *Config) (*State, error) {
s := &State{}
err := s.initDriver(config.DriverName)
if err != nil {
return nil, err
}
// Connect to the driver database
if err = s.Driver.Open(); err != nil {
return nil, fmt.Errorf("Unable to connect to the database: %s", err)
}
err = s.initTables(config.TableName)
if err != nil {
return nil, fmt.Errorf("Unable to initialize tables: %s", err)
}
err = s.initOutFolder()
if err != nil {
return nil, fmt.Errorf("Unable to initialize the output folder: %s", err)
}
err = s.initTemplates()
if err != nil {
return nil, fmt.Errorf("Unable to initialize templates: %s", err)
}
return s, nil
}
// Run executes the sqlboiler templates and outputs them to files based on the
// state given.
func (s *State) Run(includeTests bool) error {
if err := generateSingletonOutput(s); err != nil {
return fmt.Errorf("Unable to generate singleton template output: %s", err)
}
if includeTests {
if err := generateTestMainOutput(s); err != nil {
return fmt.Errorf("Unable to generate TestMain output: %s", err)
}
if err := generateSingletonTestOutput(s); err != nil {
return fmt.Errorf("Unable to generate singleton test template output: %s", err)
}
}
for _, table := range s.Tables {
if table.IsJoinTable {
continue
}
data := &templateData{
Table: table,
DriverName: s.Config.DriverName,
PkgName: s.Config.PkgName,
}
// Generate the regular templates
if err := generateOutput(s, data); err != nil {
return fmt.Errorf("Unable to generate output: %s", err)
}
// Generate the test templates
if includeTests {
if err := generateTestOutput(s, data); err != nil {
return fmt.Errorf("Unable to generate test output: %s", err)
}
}
}
return nil
}
// Cleanup closes any resources that must be closed
func (s *State) Cleanup() error {
s.Driver.Close()
return nil
}
// initTemplates loads all template folders into the state object.
func (s *State) initTemplates() error {
var err error
s.Templates, err = loadTemplates(templatesDirectory)
if err != nil {
return err
}
s.SingletonTemplates, err = loadTemplates(templatesSingletonDirectory)
if err != nil {
return err
}
s.TestTemplates, err = loadTemplates(templatesTestDirectory)
if err != nil {
return err
}
s.SingletonTestTemplates, err = loadTemplates(templatesSingletonTestDirectory)
if err != nil {
return err
}
return nil
}
// initDriver attempts to set the state Interface based off the passed in
// driver flag value. If an invalid flag string is provided an error is returned.
func (s *State) initDriver(driverName string) error {
// Create a driver based off driver flag
switch driverName {
case "postgres":
s.Driver = dbdrivers.NewPostgresDriver(
s.Config.Postgres.User,
s.Config.Postgres.Pass,
s.Config.Postgres.DBName,
s.Config.Postgres.Host,
s.Config.Postgres.Port,
)
}
if s.Driver == nil {
return errors.New("An invalid driver name was provided")
}
return nil
}
// initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that
// result.
func (s *State) initTables(tableName string) error {
var tableNames []string
if len(tableName) != 0 {
tableNames = strings.Split(tableName, ",")
for i, name := range tableNames {
tableNames[i] = strings.TrimSpace(name)
}
}
var err error
s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...)
if err != nil {
return fmt.Errorf("Unable to get all table names: %s", err)
}
if len(s.Tables) == 0 {
return errors.New("No tables found in database, migrate some tables first")
}
if err := checkPKeys(s.Tables); err != nil {
return err
}
return nil
}
// initOutFolder creates the folder that will hold the generated output.
func (s *State) initOutFolder() error {
if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil {
return fmt.Errorf("Unable to make output folder: %s", err)
}
return nil
}
// checkPKeys ensures every table has a primary key column
func checkPKeys(tables []dbdrivers.Table) error {
var missingPkey []string
for _, t := range tables {
if t.PKey == nil {
missingPkey = append(missingPkey, t.Name)
}
}
if len(missingPkey) != 0 {
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
}
return nil
}

View file

@ -1,4 +1,4 @@
package cmds package sqlboiler
import ( import (
"bufio" "bufio"
@ -15,11 +15,11 @@ import (
"github.com/nullbio/sqlboiler/dbdrivers" "github.com/nullbio/sqlboiler/dbdrivers"
) )
var cmdData *CmdData var state *State
var rgxHasSpaces = regexp.MustCompile(`^\s+`) var rgxHasSpaces = regexp.MustCompile(`^\s+`)
func init() { func init() {
cmdData = &CmdData{ state = &State{
Tables: []dbdrivers.Table{ Tables: []dbdrivers.Table{
{ {
Name: "patrick_table", Name: "patrick_table",
@ -59,10 +59,11 @@ func init() {
}, },
}, },
}, },
PkgName: "patrick", Config: &Config{
OutFolder: "", PkgName: "patrick",
DriverName: "postgres", OutFolder: "",
Interface: nil, DriverName: "postgres",
},
} }
} }
@ -84,63 +85,63 @@ func TestTemplates(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
if err := checkPKeys(cmdData.Tables); err != nil { if err := checkPKeys(state.Tables); err != nil {
t.Fatalf("%s", err) t.Fatalf("%s", err)
} }
// Initialize the templates // Initialize the templates
var err error var err error
cmdData.Templates, err = loadTemplates("templates") state.Templates, err = loadTemplates("templates")
if err != nil { if err != nil {
t.Fatalf("Unable to initialize templates: %s", err) t.Fatalf("Unable to initialize templates: %s", err)
} }
if len(cmdData.Templates) == 0 { if len(state.Templates) == 0 {
t.Errorf("Templates is empty.") t.Errorf("Templates is empty.")
} }
cmdData.SingleTemplates, err = loadTemplates("templates/singleton") state.SingletonTemplates, err = loadTemplates("templates/singleton")
if err != nil { if err != nil {
t.Fatalf("Unable to initialize singleton templates: %s", err) t.Fatalf("Unable to initialize singleton templates: %s", err)
} }
if len(cmdData.SingleTemplates) == 0 { if len(state.SingletonTemplates) == 0 {
t.Errorf("SingleTemplates is empty.") t.Errorf("SingletonTemplates is empty.")
} }
cmdData.TestTemplates, err = loadTemplates("templates_test") state.TestTemplates, err = loadTemplates("templates_test")
if err != nil { if err != nil {
t.Fatalf("Unable to initialize templates: %s", err) t.Fatalf("Unable to initialize templates: %s", err)
} }
if len(cmdData.Templates) == 0 { if len(state.Templates) == 0 {
t.Errorf("Templates is empty.") t.Errorf("Templates is empty.")
} }
cmdData.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl") state.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl")
if err != nil { if err != nil {
t.Fatalf("Unable to initialize templates: %s", err) t.Fatalf("Unable to initialize templates: %s", err)
} }
cmdData.OutFolder, err = ioutil.TempDir("", "templates") state.Config.OutFolder, err = ioutil.TempDir("", "templates")
if err != nil { if err != nil {
t.Fatalf("Unable to create tempdir: %s", err) t.Fatalf("Unable to create tempdir: %s", err)
} }
defer os.RemoveAll(cmdData.OutFolder) defer os.RemoveAll(state.Config.OutFolder)
if err = cmdData.run(true); err != nil { if err = state.Run(true); err != nil {
t.Errorf("Unable to run SQLBoilerRun: %s", err) t.Errorf("Unable to run SQLBoilerRun: %s", err)
} }
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
cmd := exec.Command("go", "test", "-c") cmd := exec.Command("go", "test", "-c")
cmd.Dir = cmdData.OutFolder cmd.Dir = state.Config.OutFolder
cmd.Stderr = buf cmd.Stderr = buf
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
t.Errorf("go test cmd execution failed: %s", err) t.Errorf("go test cmd execution failed: %s", err)
outputCompileErrors(buf, cmdData.OutFolder) outputCompileErrors(buf, state.Config.OutFolder)
fmt.Println() fmt.Println()
} }
} }

119
templates.go Normal file
View file

@ -0,0 +1,119 @@
package sqlboiler
import (
"os"
"path/filepath"
"sort"
"strings"
"text/template"
"github.com/nullbio/sqlboiler/dbdrivers"
"github.com/nullbio/sqlboiler/strmangle"
)
// templateData for sqlboiler templates
type templateData struct {
Tables []dbdrivers.Table
Table dbdrivers.Table
DriverName string
PkgName string
}
type templateList []*template.Template
func (t templateList) Len() int {
return len(t)
}
func (t templateList) Swap(k, j int) {
t[k], t[j] = t[j], t[k]
}
func (t templateList) Less(k, j int) bool {
// Make sure "struct" goes to the front
if t[k].Name() == "struct.tpl" {
return true
}
res := strings.Compare(t[k].Name(), t[j].Name())
if res <= 0 {
return true
}
return false
}
// loadTemplates loads all of the template files in the specified directory.
func loadTemplates(dir string) (templateList, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, "*.tpl")
tpl, err := template.New("").Funcs(templateFunctions).ParseGlob(pattern)
if err != nil {
return nil, err
}
templates := templateList(tpl.Templates())
sort.Sort(templates)
return templates, err
}
// loadTemplate loads a single template file.
func loadTemplate(dir string, filename string) (*template.Template, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, filename)
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
if err != nil {
return nil, err
}
return tpl.Lookup(filename), err
}
// templateFunctions is a map of all the functions that get passed into the
// templates. If you wish to pass a new function into your own template,
// add a function pointer here.
var templateFunctions = template.FuncMap{
"singular": strmangle.Singular,
"plural": strmangle.Plural,
"titleCase": strmangle.TitleCase,
"titleCaseSingular": strmangle.TitleCaseSingular,
"titleCasePlural": strmangle.TitleCasePlural,
"titleCaseCommaList": strmangle.TitleCaseCommaList,
"camelCase": strmangle.CamelCase,
"camelCaseSingular": strmangle.CamelCaseSingular,
"camelCasePlural": strmangle.CamelCasePlural,
"camelCaseCommaList": strmangle.CamelCaseCommaList,
"columnsToStrings": strmangle.ColumnsToStrings,
"commaList": strmangle.CommaList,
"makeDBName": strmangle.MakeDBName,
"selectParamNames": strmangle.SelectParamNames,
"insertParamNames": strmangle.InsertParamNames,
"insertParamFlags": strmangle.InsertParamFlags,
"insertParamVariables": strmangle.InsertParamVariables,
"scanParamNames": strmangle.ScanParamNames,
"hasPrimaryKey": strmangle.HasPrimaryKey,
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
"wherePrimaryKey": strmangle.WherePrimaryKey,
"paramsPrimaryKey": strmangle.ParamsPrimaryKey,
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
"updateParamNames": strmangle.UpdateParamNames,
"updateParamVariables": strmangle.UpdateParamVariables,
"supportsResultObject": strmangle.SupportsResultObject,
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
"filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement,
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
"randDBStruct": strmangle.RandDBStruct,
"randDBStructSlice": strmangle.RandDBStructSlice,
}

34
templates_test.go Normal file
View file

@ -0,0 +1,34 @@
package sqlboiler
import (
"sort"
"testing"
"text/template"
)
func TestTemplateListSort(t *testing.T) {
templs := templateList{
template.New("bob.tpl"),
template.New("all.tpl"),
template.New("struct.tpl"),
template.New("ttt.tpl"),
}
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
for i, v := range templs {
if v.Name() != expected[i] {
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
}
}
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
sort.Sort(templs)
for i, v := range templs {
if v.Name() != expected[i] {
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
}
}
}