Merge branch 'master' of github.com:nullbio/sqlboiler
This commit is contained in:
commit
3352220c53
42 changed files with 903 additions and 949 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
||||||
sqlboiler
|
/sqlboiler
|
||||||
config.toml
|
/cmd/sqlboiler/sqlboiler
|
||||||
|
sqlboiler.toml
|
||||||
|
|
11
boil/db.go
11
boil/db.go
|
@ -5,12 +5,19 @@ import (
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// currentDB is a global database handle for the package
|
||||||
|
currentDB Executor
|
||||||
|
)
|
||||||
|
|
||||||
|
// Executor can perform SQL queries.
|
||||||
type Executor interface {
|
type Executor interface {
|
||||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
QueryRow(query string, args ...interface{}) *sql.Row
|
QueryRow(query string, args ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transactor can commit and rollback, on top of being able to execute queries.
|
||||||
type Transactor interface {
|
type Transactor interface {
|
||||||
Commit() error
|
Commit() error
|
||||||
Rollback() error
|
Rollback() error
|
||||||
|
@ -18,12 +25,11 @@ type Transactor interface {
|
||||||
Executor
|
Executor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creator starts transactions.
|
||||||
type Creator interface {
|
type Creator interface {
|
||||||
Begin() (*sql.Tx, error)
|
Begin() (*sql.Tx, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var currentDB Executor
|
|
||||||
|
|
||||||
// DebugMode is a flag controlling whether generated sql statements and
|
// DebugMode is a flag controlling whether generated sql statements and
|
||||||
// debug information is outputted to the DebugWriter handle
|
// debug information is outputted to the DebugWriter handle
|
||||||
//
|
//
|
||||||
|
@ -33,6 +39,7 @@ var DebugMode = false
|
||||||
// DebugWriter is where the debug output will be sent if DebugMode is true
|
// DebugWriter is where the debug output will be sent if DebugMode is true
|
||||||
var DebugWriter = os.Stdout
|
var DebugWriter = os.Stdout
|
||||||
|
|
||||||
|
// Begin a transaction
|
||||||
func Begin() (Transactor, error) {
|
func Begin() (Transactor, error) {
|
||||||
creator, ok := currentDB.(Creator)
|
creator, ok := currentDB.(Creator)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
91
cmd/sqlboiler/main.go
Normal file
91
cmd/sqlboiler/main.go
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
// Package main defines a command line interface for the sqlboiler package
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/nullbio/sqlboiler"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
viper.SetConfigName("sqlboiler")
|
||||||
|
viper.AddConfigPath("$HOME/.sqlboiler")
|
||||||
|
viper.AddConfigPath(".")
|
||||||
|
|
||||||
|
err = viper.ReadInConfig()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to load config file: %s\n", err)
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the cobra root command
|
||||||
|
var rootCmd = &cobra.Command{
|
||||||
|
Use: "sqlboiler",
|
||||||
|
Short: "SQL Boiler generates boilerplate structs and statements",
|
||||||
|
Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" +
|
||||||
|
`Complete documentation is available at http://github.com/nullbio/sqlboiler`,
|
||||||
|
PreRunE: preRun,
|
||||||
|
RunE: run,
|
||||||
|
PostRunE: postRun,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the cobra root command flags
|
||||||
|
rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
|
||||||
|
rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
|
||||||
|
rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder")
|
||||||
|
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
|
||||||
|
|
||||||
|
viper.BindPFlags(rootCmd.PersistentFlags())
|
||||||
|
|
||||||
|
if err := rootCmd.Execute(); err != nil {
|
||||||
|
fmt.Println("Failed to execute sqlboiler command:", err)
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var state *sqlboiler.State
|
||||||
|
var config *sqlboiler.Config
|
||||||
|
|
||||||
|
func preRun(cmd *cobra.Command, args []string) error {
|
||||||
|
config = new(sqlboiler.Config)
|
||||||
|
|
||||||
|
config.DriverName = viper.GetString("driver")
|
||||||
|
config.TableName = viper.GetString("table")
|
||||||
|
config.OutFolder = viper.GetString("folder")
|
||||||
|
config.PkgName = viper.GetString("pkgname")
|
||||||
|
|
||||||
|
if len(config.DriverName) == 0 {
|
||||||
|
return errors.New("Must supply a driver flag.")
|
||||||
|
}
|
||||||
|
if len(config.OutFolder) == 0 {
|
||||||
|
return fmt.Errorf("No output folder specified.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if viper.IsSet("postgres.dbname") {
|
||||||
|
config.Postgres = sqlboiler.PostgresConfig{
|
||||||
|
User: viper.GetString("postgres.user"),
|
||||||
|
Pass: viper.GetString("postgres.pass"),
|
||||||
|
Host: viper.GetString("postgres.host"),
|
||||||
|
Port: viper.GetInt("postgres.port"),
|
||||||
|
DBName: viper.GetString("postgres.dbname"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
state, err = sqlboiler.New(config)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func run(cmd *cobra.Command, args []string) error {
|
||||||
|
return state.Run(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func postRun(cmd *cobra.Command, args []string) error {
|
||||||
|
return state.Cleanup()
|
||||||
|
}
|
187
cmds/config.go
187
cmds/config.go
|
@ -1,187 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
|
||||||
"github.com/nullbio/sqlboiler/strmangle"
|
|
||||||
)
|
|
||||||
|
|
||||||
// sqlBoilerTypeImports imports are only included in the template output if the database
|
|
||||||
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
|
|
||||||
var sqlBoilerTypeImports = map[string]imports{
|
|
||||||
"null.Float32": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Float64": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Int": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Int8": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Int16": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Int32": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Int64": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Uint": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Uint8": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Uint16": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Uint32": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Uint64": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.String": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Bool": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"null.Time": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
|
||||||
},
|
|
||||||
"time.Time": imports{
|
|
||||||
standard: importList{`"time"`},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerImports defines the list of default template imports.
|
|
||||||
var sqlBoilerImports = imports{
|
|
||||||
standard: importList{
|
|
||||||
`"errors"`,
|
|
||||||
`"fmt"`,
|
|
||||||
`"strings"`,
|
|
||||||
},
|
|
||||||
thirdparty: importList{
|
|
||||||
`"github.com/nullbio/sqlboiler/boil"`,
|
|
||||||
`"github.com/nullbio/sqlboiler/boil/qs"`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var sqlBoilerSinglesImports = map[string]imports{
|
|
||||||
"helpers": imports{
|
|
||||||
standard: importList{},
|
|
||||||
thirdparty: importList{
|
|
||||||
`"github.com/nullbio/sqlboiler/boil"`,
|
|
||||||
`"github.com/nullbio/sqlboiler/boil/qs"`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerTestImports defines the list of default test template imports.
|
|
||||||
var sqlBoilerTestImports = imports{
|
|
||||||
standard: importList{
|
|
||||||
`"testing"`,
|
|
||||||
`"reflect"`,
|
|
||||||
},
|
|
||||||
thirdparty: importList{
|
|
||||||
`"github.com/nullbio/sqlboiler/boil"`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var sqlBoilerSinglesTestImports = map[string]imports{
|
|
||||||
"helper_funcs": imports{
|
|
||||||
standard: importList{
|
|
||||||
`"crypto/md5"`,
|
|
||||||
`"fmt"`,
|
|
||||||
`"os"`,
|
|
||||||
`"strconv"`,
|
|
||||||
`"math/rand"`,
|
|
||||||
`"bytes"`,
|
|
||||||
},
|
|
||||||
thirdparty: importList{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var sqlBoilerTestMainImports = map[string]imports{
|
|
||||||
"postgres": imports{
|
|
||||||
standard: importList{
|
|
||||||
`"testing"`,
|
|
||||||
`"os"`,
|
|
||||||
`"os/exec"`,
|
|
||||||
`"fmt"`,
|
|
||||||
`"io/ioutil"`,
|
|
||||||
`"bytes"`,
|
|
||||||
`"database/sql"`,
|
|
||||||
`"time"`,
|
|
||||||
`"math/rand"`,
|
|
||||||
},
|
|
||||||
thirdparty: importList{
|
|
||||||
`"github.com/nullbio/sqlboiler/boil"`,
|
|
||||||
`"github.com/BurntSushi/toml"`,
|
|
||||||
`_ "github.com/lib/pq"`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
|
|
||||||
// If you wish to pass a new function into your own template, add a pointer to it here.
|
|
||||||
var sqlBoilerTemplateFuncs = template.FuncMap{
|
|
||||||
"singular": strmangle.Singular,
|
|
||||||
"plural": strmangle.Plural,
|
|
||||||
"titleCase": strmangle.TitleCase,
|
|
||||||
"titleCaseSingular": strmangle.TitleCaseSingular,
|
|
||||||
"titleCasePlural": strmangle.TitleCasePlural,
|
|
||||||
"titleCaseCommaList": strmangle.TitleCaseCommaList,
|
|
||||||
"camelCase": strmangle.CamelCase,
|
|
||||||
"camelCaseSingular": strmangle.CamelCaseSingular,
|
|
||||||
"camelCasePlural": strmangle.CamelCasePlural,
|
|
||||||
"camelCaseCommaList": strmangle.CamelCaseCommaList,
|
|
||||||
"columnsToStrings": strmangle.ColumnsToStrings,
|
|
||||||
"commaList": strmangle.CommaList,
|
|
||||||
"makeDBName": strmangle.MakeDBName,
|
|
||||||
"selectParamNames": strmangle.SelectParamNames,
|
|
||||||
"insertParamNames": strmangle.InsertParamNames,
|
|
||||||
"insertParamFlags": strmangle.InsertParamFlags,
|
|
||||||
"insertParamVariables": strmangle.InsertParamVariables,
|
|
||||||
"scanParamNames": strmangle.ScanParamNames,
|
|
||||||
"hasPrimaryKey": strmangle.HasPrimaryKey,
|
|
||||||
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
|
|
||||||
"wherePrimaryKey": strmangle.WherePrimaryKey,
|
|
||||||
"paramsPrimaryKey": strmangle.ParamsPrimaryKey,
|
|
||||||
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
|
|
||||||
"updateParamNames": strmangle.UpdateParamNames,
|
|
||||||
"updateParamVariables": strmangle.UpdateParamVariables,
|
|
||||||
"supportsResultObject": strmangle.SupportsResultObject,
|
|
||||||
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
|
|
||||||
"filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement,
|
|
||||||
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
|
|
||||||
|
|
||||||
"randDBStruct": strmangle.RandDBStruct,
|
|
||||||
"randDBStructSlice": strmangle.RandDBStructSlice,
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadConfigFile loads the toml config file into the cfg object
|
|
||||||
func (c *CmdData) LoadConfigFile(filename string) error {
|
|
||||||
cfg := &Config{}
|
|
||||||
|
|
||||||
_, err := toml.DecodeFile(filename, &cfg)
|
|
||||||
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Config = cfg
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,35 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLoadConfig(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cmdData := &CmdData{}
|
|
||||||
|
|
||||||
file, _ := ioutil.TempFile(os.TempDir(), "sqlboilercfgtest")
|
|
||||||
defer os.Remove(file.Name())
|
|
||||||
|
|
||||||
fContents := `[postgres]
|
|
||||||
host="localhost"
|
|
||||||
port=5432
|
|
||||||
user="user"
|
|
||||||
pass="pass"
|
|
||||||
dbname="mydb"`
|
|
||||||
|
|
||||||
file.WriteString(fContents)
|
|
||||||
err := cmdData.LoadConfigFile(file.Name())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Unable to load config file: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmdData.Config.Postgres.Host != "localhost" ||
|
|
||||||
cmdData.Config.Postgres.User != "user" || cmdData.Config.Postgres.Pass != "pass" ||
|
|
||||||
cmdData.Config.Postgres.DBName != "mydb" || cmdData.Config.Postgres.Port != 5432 {
|
|
||||||
t.Errorf("Config failed to load properly, got: %#v", cmdData.Config.Postgres)
|
|
||||||
}
|
|
||||||
}
|
|
114
cmds/helpers.go
114
cmds/helpers.go
|
@ -1,114 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
|
||||||
)
|
|
||||||
|
|
||||||
func combineImports(a, b imports) imports {
|
|
||||||
var c imports
|
|
||||||
|
|
||||||
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
|
|
||||||
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
|
|
||||||
|
|
||||||
sort.Sort(c.standard)
|
|
||||||
sort.Sort(c.thirdparty)
|
|
||||||
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
|
|
||||||
tmpImp := imports{
|
|
||||||
standard: make(importList, len(a.standard)),
|
|
||||||
thirdparty: make(importList, len(a.thirdparty)),
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(tmpImp.standard, a.standard)
|
|
||||||
copy(tmpImp.thirdparty, a.thirdparty)
|
|
||||||
|
|
||||||
for _, col := range columns {
|
|
||||||
for key, imp := range b {
|
|
||||||
if col.Type == key {
|
|
||||||
tmpImp.standard = append(tmpImp.standard, imp.standard...)
|
|
||||||
tmpImp.thirdparty = append(tmpImp.thirdparty, imp.thirdparty...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpImp.standard = removeDuplicates(tmpImp.standard)
|
|
||||||
tmpImp.thirdparty = removeDuplicates(tmpImp.thirdparty)
|
|
||||||
|
|
||||||
sort.Sort(tmpImp.standard)
|
|
||||||
sort.Sort(tmpImp.thirdparty)
|
|
||||||
|
|
||||||
return tmpImp
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildImportString(imps imports) []byte {
|
|
||||||
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
|
|
||||||
if stdlen+thirdlen < 1 {
|
|
||||||
return []byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if stdlen+thirdlen == 1 {
|
|
||||||
var imp string
|
|
||||||
if stdlen == 1 {
|
|
||||||
imp = imps.standard[0]
|
|
||||||
} else {
|
|
||||||
imp = imps.thirdparty[0]
|
|
||||||
}
|
|
||||||
return []byte(fmt.Sprintf("import %s", imp))
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
buf.WriteString("import (")
|
|
||||||
for _, std := range imps.standard {
|
|
||||||
fmt.Fprintf(buf, "\n\t%s", std)
|
|
||||||
}
|
|
||||||
if stdlen != 0 && thirdlen != 0 {
|
|
||||||
buf.WriteString("\n")
|
|
||||||
}
|
|
||||||
for _, third := range imps.thirdparty {
|
|
||||||
fmt.Fprintf(buf, "\n\t%s", third)
|
|
||||||
}
|
|
||||||
buf.WriteString("\n)\n")
|
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func combineStringSlices(a, b []string) []string {
|
|
||||||
c := make([]string, len(a)+len(b))
|
|
||||||
if len(a) > 0 {
|
|
||||||
copy(c, a)
|
|
||||||
}
|
|
||||||
if len(b) > 0 {
|
|
||||||
copy(c[len(a):], b)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeDuplicates(dedup []string) []string {
|
|
||||||
if len(dedup) <= 1 {
|
|
||||||
return dedup
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(dedup)-1; i++ {
|
|
||||||
for j := i + 1; j < len(dedup); j++ {
|
|
||||||
if dedup[i] != dedup[j] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if j != len(dedup)-1 {
|
|
||||||
dedup[j] = dedup[len(dedup)-1]
|
|
||||||
j--
|
|
||||||
}
|
|
||||||
dedup = dedup[:len(dedup)-1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return dedup
|
|
||||||
}
|
|
|
@ -1,42 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import "strings"
|
|
||||||
|
|
||||||
func (i importList) Len() int {
|
|
||||||
return len(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i importList) Swap(k, j int) {
|
|
||||||
i[k], i[j] = i[j], i[k]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i importList) Less(k, j int) bool {
|
|
||||||
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
|
||||||
if res <= 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t templater) Len() int {
|
|
||||||
return len(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t templater) Swap(k, j int) {
|
|
||||||
t[k], t[j] = t[j], t[k]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t templater) Less(k, j int) bool {
|
|
||||||
// Make sure "struct" goes to the front
|
|
||||||
if t[k].Name() == "struct.tpl" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
res := strings.Compare(t[k].Name(), t[j].Name())
|
|
||||||
if res <= 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -1,80 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
"text/template"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSortImports(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
a1 := importList{
|
|
||||||
`"fmt"`,
|
|
||||||
`"errors"`,
|
|
||||||
}
|
|
||||||
a2 := importList{
|
|
||||||
`_ "github.com/lib/pq"`,
|
|
||||||
`_ "github.com/gorilla/n"`,
|
|
||||||
`"github.com/gorilla/mux"`,
|
|
||||||
`"github.com/gorilla/websocket"`,
|
|
||||||
}
|
|
||||||
|
|
||||||
a1Expected := importList{`"errors"`, `"fmt"`}
|
|
||||||
a2Expected := importList{
|
|
||||||
`"github.com/gorilla/mux"`,
|
|
||||||
`_ "github.com/gorilla/n"`,
|
|
||||||
`"github.com/gorilla/websocket"`,
|
|
||||||
`_ "github.com/lib/pq"`,
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Sort(a1)
|
|
||||||
if !reflect.DeepEqual(a1, a1Expected) {
|
|
||||||
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, v := range a1 {
|
|
||||||
if v != a1Expected[i] {
|
|
||||||
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Sort(a2)
|
|
||||||
if !reflect.DeepEqual(a2, a2Expected) {
|
|
||||||
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, v := range a2 {
|
|
||||||
if v != a2Expected[i] {
|
|
||||||
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSortTemplates(t *testing.T) {
|
|
||||||
templs := templater{
|
|
||||||
template.New("bob.tpl"),
|
|
||||||
template.New("all.tpl"),
|
|
||||||
template.New("struct.tpl"),
|
|
||||||
template.New("ttt.tpl"),
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
|
|
||||||
|
|
||||||
for i, v := range templs {
|
|
||||||
if v.Name() != expected[i] {
|
|
||||||
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
|
|
||||||
|
|
||||||
sort.Sort(templs)
|
|
||||||
|
|
||||||
for i, v := range templs {
|
|
||||||
if v.Name() != expected[i] {
|
|
||||||
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,273 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
templatesDirectory = "/cmds/templates"
|
|
||||||
templatesSinglesDirectory = "/cmds/templates/singles"
|
|
||||||
|
|
||||||
templatesTestDirectory = "/cmds/templates_test"
|
|
||||||
templatesSinglesTestDirectory = "/cmds/templates_test/singles"
|
|
||||||
templatesTestMainDirectory = "/cmds/templates_test/main_test"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LoadTemplates loads all template folders into the cmdData object.
|
|
||||||
func initTemplates(cmdData *CmdData) error {
|
|
||||||
var err error
|
|
||||||
cmdData.Templates, err = loadTemplates(templatesDirectory)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdData.SingleTemplates, err = loadTemplates(templatesSinglesDirectory)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdData.SingleTestTemplates, err = loadTemplates(templatesSinglesTestDirectory)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
filename := cmdData.DriverName + "_main.tpl"
|
|
||||||
cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadTemplates loads all of the template files in the specified directory.
|
|
||||||
func loadTemplates(dir string) ([]*template.Template, error) {
|
|
||||||
wd, err := os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pattern := filepath.Join(wd, dir, "*.tpl")
|
|
||||||
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
templates := templater(tpl.Templates())
|
|
||||||
sort.Sort(templates)
|
|
||||||
|
|
||||||
return templates, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadTemplate loads a single template file.
|
|
||||||
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
|
||||||
wd, err := os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pattern := filepath.Join(wd, dir, filename)
|
|
||||||
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseFiles(pattern)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tpl.Lookup(filename), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
|
|
||||||
func (c *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error {
|
|
||||||
c.Interface.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLBoilerPreRun performs the initialization tasks before the root command is run
|
|
||||||
func (c *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
|
|
||||||
// Initialize package name
|
|
||||||
pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String()
|
|
||||||
|
|
||||||
// Retrieve driver flag
|
|
||||||
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
|
|
||||||
if driverName == "" {
|
|
||||||
return errors.New("Must supply a driver flag.")
|
|
||||||
}
|
|
||||||
|
|
||||||
tableName := cmd.PersistentFlags().Lookup("table").Value.String()
|
|
||||||
|
|
||||||
outFolder := cmd.PersistentFlags().Lookup("folder").Value.String()
|
|
||||||
if outFolder == "" {
|
|
||||||
return fmt.Errorf("No output folder specified.")
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.initCmdData(pkgName, driverName, tableName, outFolder)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLBoilerRun is a proxy method for the run function
|
|
||||||
func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
|
||||||
return c.run(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// run executes the sqlboiler templates and outputs them to files.
|
|
||||||
func (c *CmdData) run(includeTests bool) error {
|
|
||||||
if err := generateSinglesOutput(c); err != nil {
|
|
||||||
return fmt.Errorf("Unable to generate single templates output: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if includeTests {
|
|
||||||
if err := generateTestMainOutput(c); err != nil {
|
|
||||||
return fmt.Errorf("Unable to generate TestMain output: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := generateSinglesTestOutput(c); err != nil {
|
|
||||||
return fmt.Errorf("Unable to generate single test templates output: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, table := range c.Tables {
|
|
||||||
data := &tplData{
|
|
||||||
Table: table,
|
|
||||||
DriverName: c.DriverName,
|
|
||||||
PkgName: c.PkgName,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the regular templates
|
|
||||||
if err := generateOutput(c, data); err != nil {
|
|
||||||
return fmt.Errorf("Unable to generate output: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the test templates
|
|
||||||
if includeTests {
|
|
||||||
if err := generateTestOutput(c, data); err != nil {
|
|
||||||
return fmt.Errorf("Unable to generate test output: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error {
|
|
||||||
c.OutFolder = outFolder
|
|
||||||
c.PkgName = pkgName
|
|
||||||
|
|
||||||
err := initInterface(driverName, c)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to the driver database
|
|
||||||
if err = c.Interface.Open(); err != nil {
|
|
||||||
return fmt.Errorf("Unable to connect to the database: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = initTables(tableName, c)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Unable to initialize tables: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = initOutFolder(c)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = initTemplates(c)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Unable to initialize templates: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// initInterface attempts to set the cmdData Interface based off the passed in
|
|
||||||
// driver flag value. If an invalid flag string is provided an error is returned.
|
|
||||||
func initInterface(driverName string, cmdData *CmdData) error {
|
|
||||||
// Create a driver based off driver flag
|
|
||||||
switch driverName {
|
|
||||||
case "postgres":
|
|
||||||
cmdData.Interface = dbdrivers.NewPostgresDriver(
|
|
||||||
cmdData.Config.Postgres.User,
|
|
||||||
cmdData.Config.Postgres.Pass,
|
|
||||||
cmdData.Config.Postgres.DBName,
|
|
||||||
cmdData.Config.Postgres.Host,
|
|
||||||
cmdData.Config.Postgres.Port,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmdData.Interface == nil {
|
|
||||||
return errors.New("An invalid driver name was provided")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdData.DriverName = driverName
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// initTables will create a string slice out of the passed in table flag value
|
|
||||||
// if one is provided. If no flag is provided, it will attempt to connect to the
|
|
||||||
// database to retrieve all "public" table names, and build a slice out of that result.
|
|
||||||
func initTables(tableName string, cmdData *CmdData) error {
|
|
||||||
var tableNames []string
|
|
||||||
|
|
||||||
if len(tableName) != 0 {
|
|
||||||
tableNames = strings.Split(tableName, ",")
|
|
||||||
for i, name := range tableNames {
|
|
||||||
tableNames[i] = strings.TrimSpace(name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Unable to get all table names: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(cmdData.Tables) == 0 {
|
|
||||||
return errors.New("No tables found in database, migrate some tables first")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkPKeys(cmdData.Tables); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkPKeys ensures every table has a primary key column
|
|
||||||
func checkPKeys(tables []dbdrivers.Table) error {
|
|
||||||
var missingPkey []string
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.PKey == nil {
|
|
||||||
missingPkey = append(missingPkey, t.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(missingPkey) != 0 {
|
|
||||||
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// initOutFolder creates the folder that will hold the generated output.
|
|
||||||
func initOutFolder(cmdData *CmdData) error {
|
|
||||||
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
|
|
||||||
return fmt.Errorf("Unable to make output folder: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,65 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import (
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CobraRunFunc declares the cobra.Command.Run function definition
|
|
||||||
type CobraRunFunc func(cmd *cobra.Command, args []string)
|
|
||||||
|
|
||||||
type templater []*template.Template
|
|
||||||
|
|
||||||
// CmdData holds the table schema a slice of (column name, column type) slices.
|
|
||||||
// It also holds a slice of all of the table names sqlboiler is generating against,
|
|
||||||
// the database driver chosen by the driver flag at runtime, and a pointer to the
|
|
||||||
// output file, if one is specified with a flag.
|
|
||||||
type CmdData struct {
|
|
||||||
Tables []dbdrivers.Table
|
|
||||||
PkgName string
|
|
||||||
OutFolder string
|
|
||||||
Interface dbdrivers.Interface
|
|
||||||
DriverName string
|
|
||||||
Config *Config
|
|
||||||
|
|
||||||
Templates templater
|
|
||||||
// SingleTemplates are only created once, not per table
|
|
||||||
SingleTemplates templater
|
|
||||||
|
|
||||||
TestTemplates templater
|
|
||||||
// SingleTestTemplates are only created once, not per table
|
|
||||||
SingleTestTemplates templater
|
|
||||||
//TestMainTemplate is only created once, not per table
|
|
||||||
TestMainTemplate *template.Template
|
|
||||||
}
|
|
||||||
|
|
||||||
// tplData is used to pass data to the template
|
|
||||||
type tplData struct {
|
|
||||||
Table dbdrivers.Table
|
|
||||||
DriverName string
|
|
||||||
PkgName string
|
|
||||||
Tables []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type importList []string
|
|
||||||
|
|
||||||
// imports defines the optional standard imports and
|
|
||||||
// thirdparty imports (from github for example)
|
|
||||||
type imports struct {
|
|
||||||
standard importList
|
|
||||||
thirdparty importList
|
|
||||||
}
|
|
||||||
|
|
||||||
type PostgresCfg struct {
|
|
||||||
User string `toml:"user"`
|
|
||||||
Pass string `toml:"pass"`
|
|
||||||
Host string `toml:"host"`
|
|
||||||
Port int `toml:"port"`
|
|
||||||
DBName string `toml:"dbname"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
Postgres PostgresCfg `toml:"postgres"`
|
|
||||||
}
|
|
20
config.go
Normal file
20
config.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package sqlboiler
|
||||||
|
|
||||||
|
// Config for the running of the commands
|
||||||
|
type Config struct {
|
||||||
|
DriverName string `toml:"driver_name"`
|
||||||
|
PkgName string `toml:"pkg_name"`
|
||||||
|
OutFolder string `toml:"out_folder"`
|
||||||
|
TableName string `toml:"table_name"`
|
||||||
|
|
||||||
|
Postgres PostgresConfig `toml:"postgres"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostgresConfig configures a postgres database
|
||||||
|
type PostgresConfig struct {
|
||||||
|
User string `toml:"user"`
|
||||||
|
Pass string `toml:"pass"`
|
||||||
|
Host string `toml:"host"`
|
||||||
|
Port int `toml:"port"`
|
||||||
|
DBName string `toml:"dbname"`
|
||||||
|
}
|
263
imports.go
Normal file
263
imports.go
Normal file
|
@ -0,0 +1,263 @@
|
||||||
|
package sqlboiler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||||
|
)
|
||||||
|
|
||||||
|
// imports defines the optional standard imports and
|
||||||
|
// thirdParty imports (from github for example)
|
||||||
|
type imports struct {
|
||||||
|
standard importList
|
||||||
|
thirdParty importList
|
||||||
|
}
|
||||||
|
|
||||||
|
// importList is a list of import names
|
||||||
|
type importList []string
|
||||||
|
|
||||||
|
func (i importList) Len() int {
|
||||||
|
return len(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i importList) Swap(k, j int) {
|
||||||
|
i[k], i[j] = i[j], i[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i importList) Less(k, j int) bool {
|
||||||
|
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
||||||
|
if res <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func combineImports(a, b imports) imports {
|
||||||
|
var c imports
|
||||||
|
|
||||||
|
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
|
||||||
|
c.thirdParty = removeDuplicates(combineStringSlices(a.thirdParty, b.thirdParty))
|
||||||
|
|
||||||
|
sort.Sort(c.standard)
|
||||||
|
sort.Sort(c.thirdParty)
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
|
||||||
|
tmpImp := imports{
|
||||||
|
standard: make(importList, len(a.standard)),
|
||||||
|
thirdParty: make(importList, len(a.thirdParty)),
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(tmpImp.standard, a.standard)
|
||||||
|
copy(tmpImp.thirdParty, a.thirdParty)
|
||||||
|
|
||||||
|
for _, col := range columns {
|
||||||
|
for key, imp := range b {
|
||||||
|
if col.Type == key {
|
||||||
|
tmpImp.standard = append(tmpImp.standard, imp.standard...)
|
||||||
|
tmpImp.thirdParty = append(tmpImp.thirdParty, imp.thirdParty...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpImp.standard = removeDuplicates(tmpImp.standard)
|
||||||
|
tmpImp.thirdParty = removeDuplicates(tmpImp.thirdParty)
|
||||||
|
|
||||||
|
sort.Sort(tmpImp.standard)
|
||||||
|
sort.Sort(tmpImp.thirdParty)
|
||||||
|
|
||||||
|
return tmpImp
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildImportString(imps imports) []byte {
|
||||||
|
stdlen, thirdlen := len(imps.standard), len(imps.thirdParty)
|
||||||
|
if stdlen+thirdlen < 1 {
|
||||||
|
return []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stdlen+thirdlen == 1 {
|
||||||
|
var imp string
|
||||||
|
if stdlen == 1 {
|
||||||
|
imp = imps.standard[0]
|
||||||
|
} else {
|
||||||
|
imp = imps.thirdParty[0]
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf("import %s", imp))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
buf.WriteString("import (")
|
||||||
|
for _, std := range imps.standard {
|
||||||
|
fmt.Fprintf(buf, "\n\t%s", std)
|
||||||
|
}
|
||||||
|
if stdlen != 0 && thirdlen != 0 {
|
||||||
|
buf.WriteString("\n")
|
||||||
|
}
|
||||||
|
for _, third := range imps.thirdParty {
|
||||||
|
fmt.Fprintf(buf, "\n\t%s", third)
|
||||||
|
}
|
||||||
|
buf.WriteString("\n)\n")
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func combineStringSlices(a, b []string) []string {
|
||||||
|
c := make([]string, len(a)+len(b))
|
||||||
|
if len(a) > 0 {
|
||||||
|
copy(c, a)
|
||||||
|
}
|
||||||
|
if len(b) > 0 {
|
||||||
|
copy(c[len(a):], b)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeDuplicates(dedup []string) []string {
|
||||||
|
if len(dedup) <= 1 {
|
||||||
|
return dedup
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(dedup)-1; i++ {
|
||||||
|
for j := i + 1; j < len(dedup); j++ {
|
||||||
|
if dedup[i] != dedup[j] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if j != len(dedup)-1 {
|
||||||
|
dedup[j] = dedup[len(dedup)-1]
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
dedup = dedup[:len(dedup)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dedup
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultTemplateImports = imports{
|
||||||
|
standard: importList{
|
||||||
|
`"errors"`,
|
||||||
|
`"fmt"`,
|
||||||
|
`"strings"`,
|
||||||
|
},
|
||||||
|
thirdParty: importList{
|
||||||
|
`"github.com/nullbio/sqlboiler/boil"`,
|
||||||
|
`"github.com/nullbio/sqlboiler/boil/qs"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultSingletonTemplateImports = map[string]imports{
|
||||||
|
"helpers": imports{
|
||||||
|
standard: importList{},
|
||||||
|
thirdParty: importList{
|
||||||
|
`"github.com/nullbio/sqlboiler/boil"`,
|
||||||
|
`"github.com/nullbio/sqlboiler/boil/qs"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultTestTemplateImports = imports{
|
||||||
|
standard: importList{
|
||||||
|
`"testing"`,
|
||||||
|
`"reflect"`,
|
||||||
|
},
|
||||||
|
thirdParty: importList{
|
||||||
|
`"github.com/nullbio/sqlboiler/boil"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultSingletonTestTemplateImports = map[string]imports{
|
||||||
|
"helper_funcs": imports{
|
||||||
|
standard: importList{
|
||||||
|
`"crypto/md5"`,
|
||||||
|
`"fmt"`,
|
||||||
|
`"os"`,
|
||||||
|
`"strconv"`,
|
||||||
|
`"math/rand"`,
|
||||||
|
`"bytes"`,
|
||||||
|
},
|
||||||
|
thirdParty: importList{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultTestMainImports = map[string]imports{
|
||||||
|
"postgres": imports{
|
||||||
|
standard: importList{
|
||||||
|
`"testing"`,
|
||||||
|
`"os"`,
|
||||||
|
`"os/exec"`,
|
||||||
|
`"fmt"`,
|
||||||
|
`"io/ioutil"`,
|
||||||
|
`"bytes"`,
|
||||||
|
`"database/sql"`,
|
||||||
|
`"time"`,
|
||||||
|
`"math/rand"`,
|
||||||
|
},
|
||||||
|
thirdParty: importList{
|
||||||
|
`"github.com/nullbio/sqlboiler/boil"`,
|
||||||
|
`"github.com/BurntSushi/toml"`,
|
||||||
|
`_ "github.com/lib/pq"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// importsBasedOnType imports are only included in the template output if the
|
||||||
|
// database requires one of the following special types. Check
|
||||||
|
// TranslateColumnType to see the type assignments.
|
||||||
|
var importsBasedOnType = map[string]imports{
|
||||||
|
"null.Float32": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Float64": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Int": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Int8": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Int16": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Int32": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Int64": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Uint": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Uint8": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Uint16": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Uint32": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Uint64": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.String": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Bool": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"null.Time": imports{
|
||||||
|
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||||
|
},
|
||||||
|
"time.Time": imports{
|
||||||
|
standard: importList{`"time"`},
|
||||||
|
},
|
||||||
|
}
|
|
@ -1,20 +1,66 @@
|
||||||
package cmds
|
package sqlboiler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestImportsSort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
a1 := importList{
|
||||||
|
`"fmt"`,
|
||||||
|
`"errors"`,
|
||||||
|
}
|
||||||
|
a2 := importList{
|
||||||
|
`_ "github.com/lib/pq"`,
|
||||||
|
`_ "github.com/gorilla/n"`,
|
||||||
|
`"github.com/gorilla/mux"`,
|
||||||
|
`"github.com/gorilla/websocket"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
a1Expected := importList{`"errors"`, `"fmt"`}
|
||||||
|
a2Expected := importList{
|
||||||
|
`"github.com/gorilla/mux"`,
|
||||||
|
`_ "github.com/gorilla/n"`,
|
||||||
|
`"github.com/gorilla/websocket"`,
|
||||||
|
`_ "github.com/lib/pq"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(a1)
|
||||||
|
if !reflect.DeepEqual(a1, a1Expected) {
|
||||||
|
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, v := range a1 {
|
||||||
|
if v != a1Expected[i] {
|
||||||
|
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(a2)
|
||||||
|
if !reflect.DeepEqual(a2, a2Expected) {
|
||||||
|
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, v := range a2 {
|
||||||
|
if v != a2Expected[i] {
|
||||||
|
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCombineTypeImports(t *testing.T) {
|
func TestCombineTypeImports(t *testing.T) {
|
||||||
imports1 := imports{
|
imports1 := imports{
|
||||||
standard: importList{
|
standard: importList{
|
||||||
`"errors"`,
|
`"errors"`,
|
||||||
`"fmt"`,
|
`"fmt"`,
|
||||||
},
|
},
|
||||||
thirdparty: importList{
|
thirdParty: importList{
|
||||||
`"github.com/nullbio/sqlboiler/boil"`,
|
`"github.com/nullbio/sqlboiler/boil"`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -25,7 +71,7 @@ func TestCombineTypeImports(t *testing.T) {
|
||||||
`"fmt"`,
|
`"fmt"`,
|
||||||
`"time"`,
|
`"time"`,
|
||||||
},
|
},
|
||||||
thirdparty: importList{
|
thirdParty: importList{
|
||||||
`"github.com/nullbio/sqlboiler/boil"`,
|
`"github.com/nullbio/sqlboiler/boil"`,
|
||||||
`"gopkg.in/nullbio/null.v4"`,
|
`"gopkg.in/nullbio/null.v4"`,
|
||||||
},
|
},
|
||||||
|
@ -46,7 +92,7 @@ func TestCombineTypeImports(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols)
|
res1 := combineTypeImports(imports1, importsBasedOnType, cols)
|
||||||
|
|
||||||
if !reflect.DeepEqual(res1, importsExpected) {
|
if !reflect.DeepEqual(res1, importsExpected) {
|
||||||
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
|
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
|
||||||
|
@ -58,13 +104,13 @@ func TestCombineTypeImports(t *testing.T) {
|
||||||
`"fmt"`,
|
`"fmt"`,
|
||||||
`"time"`,
|
`"time"`,
|
||||||
},
|
},
|
||||||
thirdparty: importList{
|
thirdParty: importList{
|
||||||
`"github.com/nullbio/sqlboiler/boil"`,
|
`"github.com/nullbio/sqlboiler/boil"`,
|
||||||
`"gopkg.in/nullbio/null.v4"`,
|
`"gopkg.in/nullbio/null.v4"`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols)
|
res2 := combineTypeImports(imports2, importsBasedOnType, cols)
|
||||||
|
|
||||||
if !reflect.DeepEqual(res2, importsExpected) {
|
if !reflect.DeepEqual(res2, importsExpected) {
|
||||||
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
|
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
|
||||||
|
@ -76,11 +122,11 @@ func TestCombineImports(t *testing.T) {
|
||||||
|
|
||||||
a := imports{
|
a := imports{
|
||||||
standard: importList{"fmt"},
|
standard: importList{"fmt"},
|
||||||
thirdparty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"},
|
thirdParty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"},
|
||||||
}
|
}
|
||||||
b := imports{
|
b := imports{
|
||||||
standard: importList{"os"},
|
standard: importList{"os"},
|
||||||
thirdparty: importList{"github.com/nullbio/sqlboiler"},
|
thirdParty: importList{"github.com/nullbio/sqlboiler"},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := combineImports(a, b)
|
c := combineImports(a, b)
|
||||||
|
@ -88,8 +134,8 @@ func TestCombineImports(t *testing.T) {
|
||||||
if c.standard[0] != "fmt" && c.standard[1] != "os" {
|
if c.standard[0] != "fmt" && c.standard[1] != "os" {
|
||||||
t.Errorf("Wanted: fmt, os got: %#v", c.standard)
|
t.Errorf("Wanted: fmt, os got: %#v", c.standard)
|
||||||
}
|
}
|
||||||
if c.thirdparty[0] != "github.com/nullbio/sqlboiler" && c.thirdparty[1] != "gopkg.in/nullbio/null.v4" {
|
if c.thirdParty[0] != "github.com/nullbio/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v4" {
|
||||||
t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdparty)
|
t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdParty)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
54
main.go
54
main.go
|
@ -1,54 +0,0 @@
|
||||||
/*
|
|
||||||
SQLBoiler is a tool to generate Go boilerplate code for database interactions.
|
|
||||||
So far this includes struct definitions and database statement helper functions.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/nullbio/sqlboiler/cmds"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
var err error
|
|
||||||
cmdData := &cmds.CmdData{}
|
|
||||||
|
|
||||||
// Load the "config.toml" global config
|
|
||||||
err = cmdData.LoadConfigFile("config.toml")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Failed to load config file: %s\n", err)
|
|
||||||
os.Exit(-1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up the cobra root command
|
|
||||||
var rootCmd = &cobra.Command{
|
|
||||||
Use: "sqlboiler",
|
|
||||||
Short: "SQL Boiler generates boilerplate structs and statements",
|
|
||||||
Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" +
|
|
||||||
`Complete documentation is available at http://github.com/nullbio/sqlboiler`,
|
|
||||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
return cmdData.SQLBoilerPreRun(cmd, args)
|
|
||||||
},
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
return cmdData.SQLBoilerRun(cmd, args)
|
|
||||||
},
|
|
||||||
PostRunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
return cmdData.SQLBoilerPostRun(cmd, args)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up the cobra root command flags
|
|
||||||
rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
|
|
||||||
rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
|
|
||||||
rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder")
|
|
||||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
|
|
||||||
|
|
||||||
// Execute SQLBoiler
|
|
||||||
if err := rootCmd.Execute(); err != nil {
|
|
||||||
os.Exit(-1)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,4 +1,4 @@
|
||||||
package cmds
|
package sqlboiler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -18,18 +18,18 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateOutput builds the file output and sends it to outHandler for saving
|
// generateOutput builds the file output and sends it to outHandler for saving
|
||||||
func generateOutput(cmdData *CmdData, data *tplData) error {
|
func generateOutput(state *State, data *templateData) error {
|
||||||
if len(cmdData.Templates) == 0 {
|
if len(state.Templates) == 0 {
|
||||||
return errors.New("No template files located for generation")
|
return errors.New("No template files located for generation")
|
||||||
}
|
}
|
||||||
var out [][]byte
|
var out [][]byte
|
||||||
var imps imports
|
var imps imports
|
||||||
|
|
||||||
imps.standard = sqlBoilerImports.standard
|
imps.standard = defaultTemplateImports.standard
|
||||||
imps.thirdparty = sqlBoilerImports.thirdparty
|
imps.thirdParty = defaultTemplateImports.thirdParty
|
||||||
|
|
||||||
for _, template := range cmdData.Templates {
|
for _, template := range state.Templates {
|
||||||
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
|
imps = combineTypeImports(imps, importsBasedOnType, data.Table.Columns)
|
||||||
resp, err := generateTemplate(template, data)
|
resp, err := generateTemplate(template, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
|
return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
|
||||||
|
@ -38,7 +38,7 @@ func generateOutput(cmdData *CmdData, data *tplData) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
fName := data.Table.Name + ".go"
|
fName := data.Table.Name + ".go"
|
||||||
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
|
err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -47,18 +47,18 @@ func generateOutput(cmdData *CmdData, data *tplData) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateTestOutput builds the test file output and sends it to outHandler for saving
|
// generateTestOutput builds the test file output and sends it to outHandler for saving
|
||||||
func generateTestOutput(cmdData *CmdData, data *tplData) error {
|
func generateTestOutput(state *State, data *templateData) error {
|
||||||
if len(cmdData.TestTemplates) == 0 {
|
if len(state.TestTemplates) == 0 {
|
||||||
return errors.New("No template files located for generation")
|
return errors.New("No template files located for generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
var out [][]byte
|
var out [][]byte
|
||||||
var imps imports
|
var imps imports
|
||||||
|
|
||||||
imps.standard = sqlBoilerTestImports.standard
|
imps.standard = defaultTestTemplateImports.standard
|
||||||
imps.thirdparty = sqlBoilerTestImports.thirdparty
|
imps.thirdParty = defaultTestTemplateImports.thirdParty
|
||||||
|
|
||||||
for _, template := range cmdData.TestTemplates {
|
for _, template := range state.TestTemplates {
|
||||||
resp, err := generateTemplate(template, data)
|
resp, err := generateTemplate(template, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
|
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
|
||||||
|
@ -67,7 +67,7 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
fName := data.Table.Name + "_test.go"
|
fName := data.Table.Name + "_test.go"
|
||||||
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
|
err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -75,20 +75,22 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateSinglesOutput(cmdData *CmdData) error {
|
// generateSingletonOutput processes the templates that should only be run
|
||||||
if cmdData.SingleTemplates == nil {
|
// one time.
|
||||||
return errors.New("No single templates located for generation")
|
func generateSingletonOutput(state *State) error {
|
||||||
|
if state.SingletonTemplates == nil {
|
||||||
|
return errors.New("No singleton templates located for generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
tplData := &tplData{
|
templateData := &templateData{
|
||||||
PkgName: cmdData.PkgName,
|
PkgName: state.Config.PkgName,
|
||||||
DriverName: cmdData.DriverName,
|
DriverName: state.Config.DriverName,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, template := range cmdData.SingleTemplates {
|
for _, template := range state.SingletonTemplates {
|
||||||
var imps imports
|
var imps imports
|
||||||
|
|
||||||
resp, err := generateTemplate(template, tplData)
|
resp, err := generateTemplate(template, templateData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
|
return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
|
||||||
}
|
}
|
||||||
|
@ -97,12 +99,12 @@ func generateSinglesOutput(cmdData *CmdData) error {
|
||||||
ext := filepath.Ext(fName)
|
ext := filepath.Ext(fName)
|
||||||
fName = fName[0 : len(fName)-len(ext)]
|
fName = fName[0 : len(fName)-len(ext)]
|
||||||
|
|
||||||
imps.standard = sqlBoilerSinglesImports[fName].standard
|
imps.standard = defaultSingletonTemplateImports[fName].standard
|
||||||
imps.thirdparty = sqlBoilerSinglesImports[fName].thirdparty
|
imps.thirdParty = defaultSingletonTemplateImports[fName].thirdParty
|
||||||
|
|
||||||
fName = fName + ".go"
|
fName = fName + ".go"
|
||||||
|
|
||||||
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp})
|
err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -111,20 +113,22 @@ func generateSinglesOutput(cmdData *CmdData) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateSinglesTestOutput(cmdData *CmdData) error {
|
// generateSingletonTestOutput processes the templates that should only be run
|
||||||
if cmdData.SingleTestTemplates == nil {
|
// one time.
|
||||||
return errors.New("No single test templates located for generation")
|
func generateSingletonTestOutput(state *State) error {
|
||||||
|
if state.SingletonTestTemplates == nil {
|
||||||
|
return errors.New("No singleton test templates located for generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
tplData := &tplData{
|
templateData := &templateData{
|
||||||
PkgName: cmdData.PkgName,
|
PkgName: state.Config.PkgName,
|
||||||
DriverName: cmdData.DriverName,
|
DriverName: state.Config.DriverName,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, template := range cmdData.SingleTestTemplates {
|
for _, template := range state.SingletonTestTemplates {
|
||||||
var imps imports
|
var imps imports
|
||||||
|
|
||||||
resp, err := generateTemplate(template, tplData)
|
resp, err := generateTemplate(template, templateData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
|
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
|
||||||
}
|
}
|
||||||
|
@ -133,12 +137,12 @@ func generateSinglesTestOutput(cmdData *CmdData) error {
|
||||||
ext := filepath.Ext(fName)
|
ext := filepath.Ext(fName)
|
||||||
fName = fName[0 : len(fName)-len(ext)]
|
fName = fName[0 : len(fName)-len(ext)]
|
||||||
|
|
||||||
imps.standard = sqlBoilerSinglesTestImports[fName].standard
|
imps.standard = defaultSingletonTestTemplateImports[fName].standard
|
||||||
imps.thirdparty = sqlBoilerSinglesTestImports[fName].thirdparty
|
imps.thirdParty = defaultSingletonTestTemplateImports[fName].thirdParty
|
||||||
|
|
||||||
fName = fName + "_test.go"
|
fName = fName + "_test.go"
|
||||||
|
|
||||||
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp})
|
err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -147,35 +151,30 @@ func generateSinglesTestOutput(cmdData *CmdData) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateTestMainOutput(cmdData *CmdData) error {
|
func generateTestMainOutput(state *State) error {
|
||||||
if cmdData.TestMainTemplate == nil {
|
if state.TestMainTemplate == nil {
|
||||||
return errors.New("No TestMain template located for generation")
|
return errors.New("No TestMain template located for generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
var out [][]byte
|
var out [][]byte
|
||||||
var imps imports
|
var imps imports
|
||||||
|
|
||||||
imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard
|
imps.standard = defaultTestMainImports[state.Config.DriverName].standard
|
||||||
imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty
|
imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty
|
||||||
|
|
||||||
var tables []string
|
templateData := &templateData{
|
||||||
for _, v := range cmdData.Tables {
|
Tables: state.Tables,
|
||||||
tables = append(tables, v.Name)
|
PkgName: state.Config.PkgName,
|
||||||
|
DriverName: state.Config.DriverName,
|
||||||
}
|
}
|
||||||
|
|
||||||
tplData := &tplData{
|
resp, err := generateTemplate(state.TestMainTemplate, templateData)
|
||||||
PkgName: cmdData.PkgName,
|
|
||||||
DriverName: cmdData.DriverName,
|
|
||||||
Tables: tables,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := generateTemplate(cmdData.TestMainTemplate, tplData)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
out = append(out, resp)
|
out = append(out, resp)
|
||||||
|
|
||||||
err = outHandler(cmdData.OutFolder, "main_test.go", cmdData.PkgName, imps, out)
|
err = outHandler(state.Config.OutFolder, "main_test.go", state.Config.PkgName, imps, out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -216,7 +215,7 @@ func outHandler(outFolder string, fileName string, pkgName string, imps imports,
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateTemplate takes a template and returns the output of the template execution.
|
// generateTemplate takes a template and returns the output of the template execution.
|
||||||
func generateTemplate(t *template.Template, data *tplData) ([]byte, error) {
|
func generateTemplate(t *template.Template, data *templateData) ([]byte, error) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
if err := t.Execute(&buf, data); err != nil {
|
if err := t.Execute(&buf, data); err != nil {
|
||||||
return nil, err
|
return nil, err
|
|
@ -1,4 +1,4 @@
|
||||||
package cmds
|
package sqlboiler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -78,7 +78,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
a2 := imports{
|
a2 := imports{
|
||||||
thirdparty: []string{
|
thirdParty: []string{
|
||||||
`"github.com/spf13/cobra"`,
|
`"github.com/spf13/cobra"`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -96,7 +96,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
`"fmt"`,
|
`"fmt"`,
|
||||||
`"errors"`,
|
`"errors"`,
|
||||||
},
|
},
|
||||||
thirdparty: importList{
|
thirdParty: importList{
|
||||||
`_ "github.com/lib/pq"`,
|
`_ "github.com/lib/pq"`,
|
||||||
`_ "github.com/gorilla/n"`,
|
`_ "github.com/gorilla/n"`,
|
||||||
`"github.com/gorilla/mux"`,
|
`"github.com/gorilla/mux"`,
|
||||||
|
@ -106,7 +106,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
file = &bytes.Buffer{}
|
file = &bytes.Buffer{}
|
||||||
|
|
||||||
sort.Sort(a3.standard)
|
sort.Sort(a3.standard)
|
||||||
sort.Sort(a3.thirdparty)
|
sort.Sort(a3.thirdParty)
|
||||||
|
|
||||||
if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil {
|
if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
223
sqlboiler.go
Normal file
223
sqlboiler.go
Normal file
|
@ -0,0 +1,223 @@
|
||||||
|
// Package sqlboiler has types and methods useful for generating code that
|
||||||
|
// acts as a fully dynamic ORM might.
|
||||||
|
package sqlboiler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
templatesDirectory = "cmds/templates"
|
||||||
|
templatesSingletonDirectory = "cmds/templates/singleton"
|
||||||
|
|
||||||
|
templatesTestDirectory = "cmds/templates_test"
|
||||||
|
templatesSingletonTestDirectory = "cmds/templates_test/singleton"
|
||||||
|
)
|
||||||
|
|
||||||
|
// State holds the global data needed by most pieces to run
|
||||||
|
type State struct {
|
||||||
|
Config *Config
|
||||||
|
|
||||||
|
Driver dbdrivers.Interface
|
||||||
|
Tables []dbdrivers.Table
|
||||||
|
|
||||||
|
Templates templateList
|
||||||
|
TestTemplates templateList
|
||||||
|
SingletonTemplates templateList
|
||||||
|
SingletonTestTemplates templateList
|
||||||
|
|
||||||
|
TestMainTemplate *template.Template
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new state based off of the config
|
||||||
|
func New(config *Config) (*State, error) {
|
||||||
|
s := &State{}
|
||||||
|
|
||||||
|
err := s.initDriver(config.DriverName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to the driver database
|
||||||
|
if err = s.Driver.Open(); err != nil {
|
||||||
|
return nil, fmt.Errorf("Unable to connect to the database: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.initTables(config.TableName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Unable to initialize tables: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.initOutFolder()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.initTemplates()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Unable to initialize templates: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run executes the sqlboiler templates and outputs them to files based on the
|
||||||
|
// state given.
|
||||||
|
func (s *State) Run(includeTests bool) error {
|
||||||
|
if err := generateSingletonOutput(s); err != nil {
|
||||||
|
return fmt.Errorf("Unable to generate singleton template output: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeTests {
|
||||||
|
if err := generateTestMainOutput(s); err != nil {
|
||||||
|
return fmt.Errorf("Unable to generate TestMain output: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := generateSingletonTestOutput(s); err != nil {
|
||||||
|
return fmt.Errorf("Unable to generate singleton test template output: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range s.Tables {
|
||||||
|
if table.IsJoinTable {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data := &templateData{
|
||||||
|
Table: table,
|
||||||
|
DriverName: s.Config.DriverName,
|
||||||
|
PkgName: s.Config.PkgName,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the regular templates
|
||||||
|
if err := generateOutput(s, data); err != nil {
|
||||||
|
return fmt.Errorf("Unable to generate output: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the test templates
|
||||||
|
if includeTests {
|
||||||
|
if err := generateTestOutput(s, data); err != nil {
|
||||||
|
return fmt.Errorf("Unable to generate test output: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup closes any resources that must be closed
|
||||||
|
func (s *State) Cleanup() error {
|
||||||
|
s.Driver.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initTemplates loads all template folders into the state object.
|
||||||
|
func (s *State) initTemplates() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
s.Templates, err = loadTemplates(templatesDirectory)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SingletonTemplates, err = loadTemplates(templatesSingletonDirectory)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SingletonTestTemplates, err = loadTemplates(templatesSingletonTestDirectory)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initDriver attempts to set the state Interface based off the passed in
|
||||||
|
// driver flag value. If an invalid flag string is provided an error is returned.
|
||||||
|
func (s *State) initDriver(driverName string) error {
|
||||||
|
// Create a driver based off driver flag
|
||||||
|
switch driverName {
|
||||||
|
case "postgres":
|
||||||
|
s.Driver = dbdrivers.NewPostgresDriver(
|
||||||
|
s.Config.Postgres.User,
|
||||||
|
s.Config.Postgres.Pass,
|
||||||
|
s.Config.Postgres.DBName,
|
||||||
|
s.Config.Postgres.Host,
|
||||||
|
s.Config.Postgres.Port,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Driver == nil {
|
||||||
|
return errors.New("An invalid driver name was provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initTables will create a string slice out of the passed in table flag value
|
||||||
|
// if one is provided. If no flag is provided, it will attempt to connect to the
|
||||||
|
// database to retrieve all "public" table names, and build a slice out of that
|
||||||
|
// result.
|
||||||
|
func (s *State) initTables(tableName string) error {
|
||||||
|
var tableNames []string
|
||||||
|
|
||||||
|
if len(tableName) != 0 {
|
||||||
|
tableNames = strings.Split(tableName, ",")
|
||||||
|
for i, name := range tableNames {
|
||||||
|
tableNames[i] = strings.TrimSpace(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to get all table names: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.Tables) == 0 {
|
||||||
|
return errors.New("No tables found in database, migrate some tables first")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkPKeys(s.Tables); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initOutFolder creates the folder that will hold the generated output.
|
||||||
|
func (s *State) initOutFolder() error {
|
||||||
|
if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil {
|
||||||
|
return fmt.Errorf("Unable to make output folder: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkPKeys ensures every table has a primary key column
|
||||||
|
func checkPKeys(tables []dbdrivers.Table) error {
|
||||||
|
var missingPkey []string
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.PKey == nil {
|
||||||
|
missingPkey = append(missingPkey, t.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(missingPkey) != 0 {
|
||||||
|
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package cmds
|
package sqlboiler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
@ -15,11 +15,11 @@ import (
|
||||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||||
)
|
)
|
||||||
|
|
||||||
var cmdData *CmdData
|
var state *State
|
||||||
var rgxHasSpaces = regexp.MustCompile(`^\s+`)
|
var rgxHasSpaces = regexp.MustCompile(`^\s+`)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
cmdData = &CmdData{
|
state = &State{
|
||||||
Tables: []dbdrivers.Table{
|
Tables: []dbdrivers.Table{
|
||||||
{
|
{
|
||||||
Name: "patrick_table",
|
Name: "patrick_table",
|
||||||
|
@ -59,10 +59,11 @@ func init() {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Config: &Config{
|
||||||
PkgName: "patrick",
|
PkgName: "patrick",
|
||||||
OutFolder: "",
|
OutFolder: "",
|
||||||
DriverName: "postgres",
|
DriverName: "postgres",
|
||||||
Interface: nil,
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,72 +85,72 @@ func TestTemplates(t *testing.T) {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkPKeys(cmdData.Tables); err != nil {
|
if err := checkPKeys(state.Tables); err != nil {
|
||||||
t.Fatalf("%s", err)
|
t.Fatalf("%s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the templates
|
// Initialize the templates
|
||||||
var err error
|
var err error
|
||||||
cmdData.Templates, err = loadTemplates("templates")
|
state.Templates, err = loadTemplates("templates")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to initialize templates: %s", err)
|
t.Fatalf("Unable to initialize templates: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cmdData.Templates) == 0 {
|
if len(state.Templates) == 0 {
|
||||||
t.Errorf("Templates is empty.")
|
t.Errorf("Templates is empty.")
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdData.SingleTemplates, err = loadTemplates("templates/singles")
|
state.SingletonTemplates, err = loadTemplates("templates/singleton")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to initialize single templates: %s", err)
|
t.Fatalf("Unable to initialize singleton templates: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cmdData.SingleTemplates) == 0 {
|
if len(state.SingletonTemplates) == 0 {
|
||||||
t.Errorf("SingleTemplates is empty.")
|
t.Errorf("SingletonTemplates is empty.")
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdData.TestTemplates, err = loadTemplates("templates_test")
|
state.TestTemplates, err = loadTemplates("templates_test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to initialize templates: %s", err)
|
t.Fatalf("Unable to initialize templates: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cmdData.Templates) == 0 {
|
if len(state.Templates) == 0 {
|
||||||
t.Errorf("Templates is empty.")
|
t.Errorf("Templates is empty.")
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdData.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl")
|
state.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to initialize templates: %s", err)
|
t.Fatalf("Unable to initialize templates: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdData.SingleTestTemplates, err = loadTemplates("templates_test/singles")
|
state.SingletonTestTemplates, err = loadTemplates("templates_test/singleton")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to initialize single test templates: %s", err)
|
t.Fatalf("Unable to initialize single test templates: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cmdData.SingleTestTemplates) == 0 {
|
if len(state.SingletonTestTemplates) == 0 {
|
||||||
t.Errorf("SingleTestTemplates is empty.")
|
t.Errorf("SingleTestTemplates is empty.")
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdData.OutFolder, err = ioutil.TempDir("", "templates")
|
state.Config.OutFolder, err = ioutil.TempDir("", "templates")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to create tempdir: %s", err)
|
t.Fatalf("Unable to create tempdir: %s", err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(cmdData.OutFolder)
|
defer os.RemoveAll(state.Config.OutFolder)
|
||||||
|
|
||||||
if err = cmdData.run(true); err != nil {
|
if err = state.Run(true); err != nil {
|
||||||
t.Errorf("Unable to run SQLBoilerRun: %s", err)
|
t.Errorf("Unable to run SQLBoilerRun: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
|
||||||
cmd := exec.Command("go", "test", "-c")
|
cmd := exec.Command("go", "test", "-c")
|
||||||
cmd.Dir = cmdData.OutFolder
|
cmd.Dir = state.Config.OutFolder
|
||||||
cmd.Stderr = buf
|
cmd.Stderr = buf
|
||||||
|
|
||||||
if err = cmd.Run(); err != nil {
|
if err = cmd.Run(); err != nil {
|
||||||
t.Errorf("go test cmd execution failed: %s", err)
|
t.Errorf("go test cmd execution failed: %s", err)
|
||||||
outputCompileErrors(buf, cmdData.OutFolder)
|
outputCompileErrors(buf, state.Config.OutFolder)
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
}
|
}
|
||||||
}
|
}
|
119
templates.go
Normal file
119
templates.go
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
package sqlboiler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||||
|
"github.com/nullbio/sqlboiler/strmangle"
|
||||||
|
)
|
||||||
|
|
||||||
|
// templateData for sqlboiler templates
|
||||||
|
type templateData struct {
|
||||||
|
Tables []dbdrivers.Table
|
||||||
|
Table dbdrivers.Table
|
||||||
|
DriverName string
|
||||||
|
PkgName string
|
||||||
|
}
|
||||||
|
|
||||||
|
type templateList []*template.Template
|
||||||
|
|
||||||
|
func (t templateList) Len() int {
|
||||||
|
return len(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t templateList) Swap(k, j int) {
|
||||||
|
t[k], t[j] = t[j], t[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t templateList) Less(k, j int) bool {
|
||||||
|
// Make sure "struct" goes to the front
|
||||||
|
if t[k].Name() == "struct.tpl" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
res := strings.Compare(t[k].Name(), t[j].Name())
|
||||||
|
if res <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadTemplates loads all of the template files in the specified directory.
|
||||||
|
func loadTemplates(dir string) (templateList, error) {
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pattern := filepath.Join(wd, dir, "*.tpl")
|
||||||
|
tpl, err := template.New("").Funcs(templateFunctions).ParseGlob(pattern)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
templates := templateList(tpl.Templates())
|
||||||
|
sort.Sort(templates)
|
||||||
|
|
||||||
|
return templates, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadTemplate loads a single template file.
|
||||||
|
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pattern := filepath.Join(wd, dir, filename)
|
||||||
|
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tpl.Lookup(filename), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// templateFunctions is a map of all the functions that get passed into the
|
||||||
|
// templates. If you wish to pass a new function into your own template,
|
||||||
|
// add a function pointer here.
|
||||||
|
var templateFunctions = template.FuncMap{
|
||||||
|
"singular": strmangle.Singular,
|
||||||
|
"plural": strmangle.Plural,
|
||||||
|
"titleCase": strmangle.TitleCase,
|
||||||
|
"titleCaseSingular": strmangle.TitleCaseSingular,
|
||||||
|
"titleCasePlural": strmangle.TitleCasePlural,
|
||||||
|
"titleCaseCommaList": strmangle.TitleCaseCommaList,
|
||||||
|
"camelCase": strmangle.CamelCase,
|
||||||
|
"camelCaseSingular": strmangle.CamelCaseSingular,
|
||||||
|
"camelCasePlural": strmangle.CamelCasePlural,
|
||||||
|
"camelCaseCommaList": strmangle.CamelCaseCommaList,
|
||||||
|
"columnsToStrings": strmangle.ColumnsToStrings,
|
||||||
|
"commaList": strmangle.CommaList,
|
||||||
|
"makeDBName": strmangle.MakeDBName,
|
||||||
|
"selectParamNames": strmangle.SelectParamNames,
|
||||||
|
"insertParamNames": strmangle.InsertParamNames,
|
||||||
|
"insertParamFlags": strmangle.InsertParamFlags,
|
||||||
|
"insertParamVariables": strmangle.InsertParamVariables,
|
||||||
|
"scanParamNames": strmangle.ScanParamNames,
|
||||||
|
"hasPrimaryKey": strmangle.HasPrimaryKey,
|
||||||
|
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
|
||||||
|
"wherePrimaryKey": strmangle.WherePrimaryKey,
|
||||||
|
"paramsPrimaryKey": strmangle.ParamsPrimaryKey,
|
||||||
|
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
|
||||||
|
"updateParamNames": strmangle.UpdateParamNames,
|
||||||
|
"updateParamVariables": strmangle.UpdateParamVariables,
|
||||||
|
"supportsResultObject": strmangle.SupportsResultObject,
|
||||||
|
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
|
||||||
|
"filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement,
|
||||||
|
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
|
||||||
|
|
||||||
|
"randDBStruct": strmangle.RandDBStruct,
|
||||||
|
"randDBStructSlice": strmangle.RandDBStructSlice,
|
||||||
|
}
|
|
@ -1,11 +1,11 @@
|
||||||
{{if hasPrimaryKey .Table.PKey -}}
|
{{if hasPrimaryKey .Table.PKey -}}
|
||||||
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
|
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
|
||||||
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
|
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
|
||||||
// Update updates a single {{$tableNameSingular}} record.
|
// Update a single {{$tableNameSingular}} record. It takes a whitelist of
|
||||||
// whitelist is a list of column_name's that should be updated.
|
// column_name's that should be updated. The primary key will be used to find
|
||||||
// Update will match against the primary key column to find the record to update.
|
// the record to update.
|
||||||
// WARNING: This Update method will NOT ignore nil members.
|
// WARNING: Update does NOT ignore nil members - only the whitelist can be used
|
||||||
// If you pass in nil members, those columnns will be set to null.
|
// to control the set of columns that will be saved.
|
||||||
func (o *{{$tableNameSingular}}) Update(whitelist ... string) error {
|
func (o *{{$tableNameSingular}}) Update(whitelist ... string) error {
|
||||||
return o.UpdateX(boil.GetDB(), whitelist...)
|
return o.UpdateX(boil.GetDB(), whitelist...)
|
||||||
}
|
}
|
34
templates_test.go
Normal file
34
templates_test.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package sqlboiler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
"text/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTemplateListSort(t *testing.T) {
|
||||||
|
templs := templateList{
|
||||||
|
template.New("bob.tpl"),
|
||||||
|
template.New("all.tpl"),
|
||||||
|
template.New("struct.tpl"),
|
||||||
|
template.New("ttt.tpl"),
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
|
||||||
|
|
||||||
|
for i, v := range templs {
|
||||||
|
if v.Name() != expected[i] {
|
||||||
|
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
|
||||||
|
|
||||||
|
sort.Sort(templs)
|
||||||
|
|
||||||
|
for i, v := range templs {
|
||||||
|
if v.Name() != expected[i] {
|
||||||
|
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue