Merge branch 'master' of github.com:nullbio/sqlboiler
This commit is contained in:
commit
3352220c53
42 changed files with 903 additions and 949 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
|||
sqlboiler
|
||||
config.toml
|
||||
/sqlboiler
|
||||
/cmd/sqlboiler/sqlboiler
|
||||
sqlboiler.toml
|
||||
|
|
11
boil/db.go
11
boil/db.go
|
@ -5,12 +5,19 @@ import (
|
|||
"os"
|
||||
)
|
||||
|
||||
var (
|
||||
// currentDB is a global database handle for the package
|
||||
currentDB Executor
|
||||
)
|
||||
|
||||
// Executor can perform SQL queries.
|
||||
type Executor interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// Transactor can commit and rollback, on top of being able to execute queries.
|
||||
type Transactor interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
|
@ -18,12 +25,11 @@ type Transactor interface {
|
|||
Executor
|
||||
}
|
||||
|
||||
// Creator starts transactions.
|
||||
type Creator interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
var currentDB Executor
|
||||
|
||||
// DebugMode is a flag controlling whether generated sql statements and
|
||||
// debug information is outputted to the DebugWriter handle
|
||||
//
|
||||
|
@ -33,6 +39,7 @@ var DebugMode = false
|
|||
// DebugWriter is where the debug output will be sent if DebugMode is true
|
||||
var DebugWriter = os.Stdout
|
||||
|
||||
// Begin a transaction
|
||||
func Begin() (Transactor, error) {
|
||||
creator, ok := currentDB.(Creator)
|
||||
if !ok {
|
||||
|
|
91
cmd/sqlboiler/main.go
Normal file
91
cmd/sqlboiler/main.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
// Package main defines a command line interface for the sqlboiler package
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/nullbio/sqlboiler"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
|
||||
viper.SetConfigName("sqlboiler")
|
||||
viper.AddConfigPath("$HOME/.sqlboiler")
|
||||
viper.AddConfigPath(".")
|
||||
|
||||
err = viper.ReadInConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to load config file: %s\n", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
// Set up the cobra root command
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "sqlboiler",
|
||||
Short: "SQL Boiler generates boilerplate structs and statements",
|
||||
Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" +
|
||||
`Complete documentation is available at http://github.com/nullbio/sqlboiler`,
|
||||
PreRunE: preRun,
|
||||
RunE: run,
|
||||
PostRunE: postRun,
|
||||
}
|
||||
|
||||
// Set up the cobra root command flags
|
||||
rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
|
||||
rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
|
||||
rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder")
|
||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
|
||||
|
||||
viper.BindPFlags(rootCmd.PersistentFlags())
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
fmt.Println("Failed to execute sqlboiler command:", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
}
|
||||
|
||||
var state *sqlboiler.State
|
||||
var config *sqlboiler.Config
|
||||
|
||||
func preRun(cmd *cobra.Command, args []string) error {
|
||||
config = new(sqlboiler.Config)
|
||||
|
||||
config.DriverName = viper.GetString("driver")
|
||||
config.TableName = viper.GetString("table")
|
||||
config.OutFolder = viper.GetString("folder")
|
||||
config.PkgName = viper.GetString("pkgname")
|
||||
|
||||
if len(config.DriverName) == 0 {
|
||||
return errors.New("Must supply a driver flag.")
|
||||
}
|
||||
if len(config.OutFolder) == 0 {
|
||||
return fmt.Errorf("No output folder specified.")
|
||||
}
|
||||
|
||||
if viper.IsSet("postgres.dbname") {
|
||||
config.Postgres = sqlboiler.PostgresConfig{
|
||||
User: viper.GetString("postgres.user"),
|
||||
Pass: viper.GetString("postgres.pass"),
|
||||
Host: viper.GetString("postgres.host"),
|
||||
Port: viper.GetInt("postgres.port"),
|
||||
DBName: viper.GetString("postgres.dbname"),
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
state, err = sqlboiler.New(config)
|
||||
return err
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, args []string) error {
|
||||
return state.Run(true)
|
||||
}
|
||||
|
||||
func postRun(cmd *cobra.Command, args []string) error {
|
||||
return state.Cleanup()
|
||||
}
|
187
cmds/config.go
187
cmds/config.go
|
@ -1,187 +0,0 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"text/template"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"github.com/nullbio/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
// sqlBoilerTypeImports imports are only included in the template output if the database
|
||||
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
|
||||
var sqlBoilerTypeImports = map[string]imports{
|
||||
"null.Float32": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Float64": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int8": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int16": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int32": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int64": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint8": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint16": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint32": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint64": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.String": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Bool": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Time": imports{
|
||||
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"time.Time": imports{
|
||||
standard: importList{`"time"`},
|
||||
},
|
||||
}
|
||||
|
||||
// sqlBoilerImports defines the list of default template imports.
|
||||
var sqlBoilerImports = imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
`"strings"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
`"github.com/nullbio/sqlboiler/boil/qs"`,
|
||||
},
|
||||
}
|
||||
|
||||
var sqlBoilerSinglesImports = map[string]imports{
|
||||
"helpers": imports{
|
||||
standard: importList{},
|
||||
thirdparty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
`"github.com/nullbio/sqlboiler/boil/qs"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// sqlBoilerTestImports defines the list of default test template imports.
|
||||
var sqlBoilerTestImports = imports{
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
`"reflect"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
},
|
||||
}
|
||||
|
||||
var sqlBoilerSinglesTestImports = map[string]imports{
|
||||
"helper_funcs": imports{
|
||||
standard: importList{
|
||||
`"crypto/md5"`,
|
||||
`"fmt"`,
|
||||
`"os"`,
|
||||
`"strconv"`,
|
||||
`"math/rand"`,
|
||||
`"bytes"`,
|
||||
},
|
||||
thirdparty: importList{},
|
||||
},
|
||||
}
|
||||
|
||||
var sqlBoilerTestMainImports = map[string]imports{
|
||||
"postgres": imports{
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
`"os"`,
|
||||
`"os/exec"`,
|
||||
`"fmt"`,
|
||||
`"io/ioutil"`,
|
||||
`"bytes"`,
|
||||
`"database/sql"`,
|
||||
`"time"`,
|
||||
`"math/rand"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
`"github.com/BurntSushi/toml"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
|
||||
// If you wish to pass a new function into your own template, add a pointer to it here.
|
||||
var sqlBoilerTemplateFuncs = template.FuncMap{
|
||||
"singular": strmangle.Singular,
|
||||
"plural": strmangle.Plural,
|
||||
"titleCase": strmangle.TitleCase,
|
||||
"titleCaseSingular": strmangle.TitleCaseSingular,
|
||||
"titleCasePlural": strmangle.TitleCasePlural,
|
||||
"titleCaseCommaList": strmangle.TitleCaseCommaList,
|
||||
"camelCase": strmangle.CamelCase,
|
||||
"camelCaseSingular": strmangle.CamelCaseSingular,
|
||||
"camelCasePlural": strmangle.CamelCasePlural,
|
||||
"camelCaseCommaList": strmangle.CamelCaseCommaList,
|
||||
"columnsToStrings": strmangle.ColumnsToStrings,
|
||||
"commaList": strmangle.CommaList,
|
||||
"makeDBName": strmangle.MakeDBName,
|
||||
"selectParamNames": strmangle.SelectParamNames,
|
||||
"insertParamNames": strmangle.InsertParamNames,
|
||||
"insertParamFlags": strmangle.InsertParamFlags,
|
||||
"insertParamVariables": strmangle.InsertParamVariables,
|
||||
"scanParamNames": strmangle.ScanParamNames,
|
||||
"hasPrimaryKey": strmangle.HasPrimaryKey,
|
||||
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
|
||||
"wherePrimaryKey": strmangle.WherePrimaryKey,
|
||||
"paramsPrimaryKey": strmangle.ParamsPrimaryKey,
|
||||
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
|
||||
"updateParamNames": strmangle.UpdateParamNames,
|
||||
"updateParamVariables": strmangle.UpdateParamVariables,
|
||||
"supportsResultObject": strmangle.SupportsResultObject,
|
||||
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
|
||||
"filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement,
|
||||
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
|
||||
|
||||
"randDBStruct": strmangle.RandDBStruct,
|
||||
"randDBStructSlice": strmangle.RandDBStructSlice,
|
||||
}
|
||||
|
||||
// LoadConfigFile loads the toml config file into the cfg object
|
||||
func (c *CmdData) LoadConfigFile(filename string) error {
|
||||
cfg := &Config{}
|
||||
|
||||
_, err := toml.DecodeFile(filename, &cfg)
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
|
||||
}
|
||||
|
||||
c.Config = cfg
|
||||
return nil
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmdData := &CmdData{}
|
||||
|
||||
file, _ := ioutil.TempFile(os.TempDir(), "sqlboilercfgtest")
|
||||
defer os.Remove(file.Name())
|
||||
|
||||
fContents := `[postgres]
|
||||
host="localhost"
|
||||
port=5432
|
||||
user="user"
|
||||
pass="pass"
|
||||
dbname="mydb"`
|
||||
|
||||
file.WriteString(fContents)
|
||||
err := cmdData.LoadConfigFile(file.Name())
|
||||
if err != nil {
|
||||
t.Errorf("Unable to load config file: %s", err)
|
||||
}
|
||||
|
||||
if cmdData.Config.Postgres.Host != "localhost" ||
|
||||
cmdData.Config.Postgres.User != "user" || cmdData.Config.Postgres.Pass != "pass" ||
|
||||
cmdData.Config.Postgres.DBName != "mydb" || cmdData.Config.Postgres.Port != 5432 {
|
||||
t.Errorf("Config failed to load properly, got: %#v", cmdData.Config.Postgres)
|
||||
}
|
||||
}
|
114
cmds/helpers.go
114
cmds/helpers.go
|
@ -1,114 +0,0 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
func combineImports(a, b imports) imports {
|
||||
var c imports
|
||||
|
||||
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
|
||||
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
|
||||
|
||||
sort.Sort(c.standard)
|
||||
sort.Sort(c.thirdparty)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
|
||||
tmpImp := imports{
|
||||
standard: make(importList, len(a.standard)),
|
||||
thirdparty: make(importList, len(a.thirdparty)),
|
||||
}
|
||||
|
||||
copy(tmpImp.standard, a.standard)
|
||||
copy(tmpImp.thirdparty, a.thirdparty)
|
||||
|
||||
for _, col := range columns {
|
||||
for key, imp := range b {
|
||||
if col.Type == key {
|
||||
tmpImp.standard = append(tmpImp.standard, imp.standard...)
|
||||
tmpImp.thirdparty = append(tmpImp.thirdparty, imp.thirdparty...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tmpImp.standard = removeDuplicates(tmpImp.standard)
|
||||
tmpImp.thirdparty = removeDuplicates(tmpImp.thirdparty)
|
||||
|
||||
sort.Sort(tmpImp.standard)
|
||||
sort.Sort(tmpImp.thirdparty)
|
||||
|
||||
return tmpImp
|
||||
}
|
||||
|
||||
func buildImportString(imps imports) []byte {
|
||||
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
|
||||
if stdlen+thirdlen < 1 {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
if stdlen+thirdlen == 1 {
|
||||
var imp string
|
||||
if stdlen == 1 {
|
||||
imp = imps.standard[0]
|
||||
} else {
|
||||
imp = imps.thirdparty[0]
|
||||
}
|
||||
return []byte(fmt.Sprintf("import %s", imp))
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("import (")
|
||||
for _, std := range imps.standard {
|
||||
fmt.Fprintf(buf, "\n\t%s", std)
|
||||
}
|
||||
if stdlen != 0 && thirdlen != 0 {
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
for _, third := range imps.thirdparty {
|
||||
fmt.Fprintf(buf, "\n\t%s", third)
|
||||
}
|
||||
buf.WriteString("\n)\n")
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func combineStringSlices(a, b []string) []string {
|
||||
c := make([]string, len(a)+len(b))
|
||||
if len(a) > 0 {
|
||||
copy(c, a)
|
||||
}
|
||||
if len(b) > 0 {
|
||||
copy(c[len(a):], b)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func removeDuplicates(dedup []string) []string {
|
||||
if len(dedup) <= 1 {
|
||||
return dedup
|
||||
}
|
||||
|
||||
for i := 0; i < len(dedup)-1; i++ {
|
||||
for j := i + 1; j < len(dedup); j++ {
|
||||
if dedup[i] != dedup[j] {
|
||||
continue
|
||||
}
|
||||
|
||||
if j != len(dedup)-1 {
|
||||
dedup[j] = dedup[len(dedup)-1]
|
||||
j--
|
||||
}
|
||||
dedup = dedup[:len(dedup)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return dedup
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
package cmds
|
||||
|
||||
import "strings"
|
||||
|
||||
func (i importList) Len() int {
|
||||
return len(i)
|
||||
}
|
||||
|
||||
func (i importList) Swap(k, j int) {
|
||||
i[k], i[j] = i[j], i[k]
|
||||
}
|
||||
|
||||
func (i importList) Less(k, j int) bool {
|
||||
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
||||
if res <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t templater) Len() int {
|
||||
return len(t)
|
||||
}
|
||||
|
||||
func (t templater) Swap(k, j int) {
|
||||
t[k], t[j] = t[j], t[k]
|
||||
}
|
||||
|
||||
func (t templater) Less(k, j int) bool {
|
||||
// Make sure "struct" goes to the front
|
||||
if t[k].Name() == "struct.tpl" {
|
||||
return true
|
||||
}
|
||||
|
||||
res := strings.Compare(t[k].Name(), t[j].Name())
|
||||
if res <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
|
@ -1,80 +0,0 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
func TestSortImports(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a1 := importList{
|
||||
`"fmt"`,
|
||||
`"errors"`,
|
||||
}
|
||||
a2 := importList{
|
||||
`_ "github.com/lib/pq"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/mux"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
}
|
||||
|
||||
a1Expected := importList{`"errors"`, `"fmt"`}
|
||||
a2Expected := importList{
|
||||
`"github.com/gorilla/mux"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
}
|
||||
|
||||
sort.Sort(a1)
|
||||
if !reflect.DeepEqual(a1, a1Expected) {
|
||||
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
||||
}
|
||||
|
||||
for i, v := range a1 {
|
||||
if v != a1Expected[i] {
|
||||
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
|
||||
sort.Sort(a2)
|
||||
if !reflect.DeepEqual(a2, a2Expected) {
|
||||
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
||||
}
|
||||
|
||||
for i, v := range a2 {
|
||||
if v != a2Expected[i] {
|
||||
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortTemplates(t *testing.T) {
|
||||
templs := templater{
|
||||
template.New("bob.tpl"),
|
||||
template.New("all.tpl"),
|
||||
template.New("struct.tpl"),
|
||||
template.New("ttt.tpl"),
|
||||
}
|
||||
|
||||
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
|
||||
|
||||
for i, v := range templs {
|
||||
if v.Name() != expected[i] {
|
||||
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||
}
|
||||
}
|
||||
|
||||
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
|
||||
|
||||
sort.Sort(templs)
|
||||
|
||||
for i, v := range templs {
|
||||
if v.Name() != expected[i] {
|
||||
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,273 +0,0 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
templatesDirectory = "/cmds/templates"
|
||||
templatesSinglesDirectory = "/cmds/templates/singles"
|
||||
|
||||
templatesTestDirectory = "/cmds/templates_test"
|
||||
templatesSinglesTestDirectory = "/cmds/templates_test/singles"
|
||||
templatesTestMainDirectory = "/cmds/templates_test/main_test"
|
||||
)
|
||||
|
||||
// LoadTemplates loads all template folders into the cmdData object.
|
||||
func initTemplates(cmdData *CmdData) error {
|
||||
var err error
|
||||
cmdData.Templates, err = loadTemplates(templatesDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmdData.SingleTemplates, err = loadTemplates(templatesSinglesDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmdData.SingleTestTemplates, err = loadTemplates(templatesSinglesTestDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filename := cmdData.DriverName + "_main.tpl"
|
||||
cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadTemplates loads all of the template files in the specified directory.
|
||||
func loadTemplates(dir string) ([]*template.Template, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pattern := filepath.Join(wd, dir, "*.tpl")
|
||||
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
templates := templater(tpl.Templates())
|
||||
sort.Sort(templates)
|
||||
|
||||
return templates, err
|
||||
}
|
||||
|
||||
// loadTemplate loads a single template file.
|
||||
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pattern := filepath.Join(wd, dir, filename)
|
||||
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseFiles(pattern)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tpl.Lookup(filename), err
|
||||
}
|
||||
|
||||
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
|
||||
func (c *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error {
|
||||
c.Interface.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SQLBoilerPreRun performs the initialization tasks before the root command is run
|
||||
func (c *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
|
||||
// Initialize package name
|
||||
pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String()
|
||||
|
||||
// Retrieve driver flag
|
||||
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
|
||||
if driverName == "" {
|
||||
return errors.New("Must supply a driver flag.")
|
||||
}
|
||||
|
||||
tableName := cmd.PersistentFlags().Lookup("table").Value.String()
|
||||
|
||||
outFolder := cmd.PersistentFlags().Lookup("folder").Value.String()
|
||||
if outFolder == "" {
|
||||
return fmt.Errorf("No output folder specified.")
|
||||
}
|
||||
|
||||
return c.initCmdData(pkgName, driverName, tableName, outFolder)
|
||||
}
|
||||
|
||||
// SQLBoilerRun is a proxy method for the run function
|
||||
func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
||||
return c.run(true)
|
||||
}
|
||||
|
||||
// run executes the sqlboiler templates and outputs them to files.
|
||||
func (c *CmdData) run(includeTests bool) error {
|
||||
if err := generateSinglesOutput(c); err != nil {
|
||||
return fmt.Errorf("Unable to generate single templates output: %s", err)
|
||||
}
|
||||
|
||||
if includeTests {
|
||||
if err := generateTestMainOutput(c); err != nil {
|
||||
return fmt.Errorf("Unable to generate TestMain output: %s", err)
|
||||
}
|
||||
|
||||
if err := generateSinglesTestOutput(c); err != nil {
|
||||
return fmt.Errorf("Unable to generate single test templates output: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, table := range c.Tables {
|
||||
data := &tplData{
|
||||
Table: table,
|
||||
DriverName: c.DriverName,
|
||||
PkgName: c.PkgName,
|
||||
}
|
||||
|
||||
// Generate the regular templates
|
||||
if err := generateOutput(c, data); err != nil {
|
||||
return fmt.Errorf("Unable to generate output: %s", err)
|
||||
}
|
||||
|
||||
// Generate the test templates
|
||||
if includeTests {
|
||||
if err := generateTestOutput(c, data); err != nil {
|
||||
return fmt.Errorf("Unable to generate test output: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error {
|
||||
c.OutFolder = outFolder
|
||||
c.PkgName = pkgName
|
||||
|
||||
err := initInterface(driverName, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Connect to the driver database
|
||||
if err = c.Interface.Open(); err != nil {
|
||||
return fmt.Errorf("Unable to connect to the database: %s", err)
|
||||
}
|
||||
|
||||
err = initTables(tableName, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to initialize tables: %s", err)
|
||||
}
|
||||
|
||||
err = initOutFolder(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||
}
|
||||
|
||||
err = initTemplates(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to initialize templates: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initInterface attempts to set the cmdData Interface based off the passed in
|
||||
// driver flag value. If an invalid flag string is provided an error is returned.
|
||||
func initInterface(driverName string, cmdData *CmdData) error {
|
||||
// Create a driver based off driver flag
|
||||
switch driverName {
|
||||
case "postgres":
|
||||
cmdData.Interface = dbdrivers.NewPostgresDriver(
|
||||
cmdData.Config.Postgres.User,
|
||||
cmdData.Config.Postgres.Pass,
|
||||
cmdData.Config.Postgres.DBName,
|
||||
cmdData.Config.Postgres.Host,
|
||||
cmdData.Config.Postgres.Port,
|
||||
)
|
||||
}
|
||||
|
||||
if cmdData.Interface == nil {
|
||||
return errors.New("An invalid driver name was provided")
|
||||
}
|
||||
|
||||
cmdData.DriverName = driverName
|
||||
return nil
|
||||
}
|
||||
|
||||
// initTables will create a string slice out of the passed in table flag value
|
||||
// if one is provided. If no flag is provided, it will attempt to connect to the
|
||||
// database to retrieve all "public" table names, and build a slice out of that result.
|
||||
func initTables(tableName string, cmdData *CmdData) error {
|
||||
var tableNames []string
|
||||
|
||||
if len(tableName) != 0 {
|
||||
tableNames = strings.Split(tableName, ",")
|
||||
for i, name := range tableNames {
|
||||
tableNames[i] = strings.TrimSpace(name)
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to get all table names: %s", err)
|
||||
}
|
||||
|
||||
if len(cmdData.Tables) == 0 {
|
||||
return errors.New("No tables found in database, migrate some tables first")
|
||||
}
|
||||
|
||||
if err := checkPKeys(cmdData.Tables); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPKeys ensures every table has a primary key column
|
||||
func checkPKeys(tables []dbdrivers.Table) error {
|
||||
var missingPkey []string
|
||||
for _, t := range tables {
|
||||
if t.PKey == nil {
|
||||
missingPkey = append(missingPkey, t.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingPkey) != 0 {
|
||||
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initOutFolder creates the folder that will hold the generated output.
|
||||
func initOutFolder(cmdData *CmdData) error {
|
||||
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("Unable to make output folder: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,65 +0,0 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"text/template"
|
||||
|
||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// CobraRunFunc declares the cobra.Command.Run function definition
|
||||
type CobraRunFunc func(cmd *cobra.Command, args []string)
|
||||
|
||||
type templater []*template.Template
|
||||
|
||||
// CmdData holds the table schema a slice of (column name, column type) slices.
|
||||
// It also holds a slice of all of the table names sqlboiler is generating against,
|
||||
// the database driver chosen by the driver flag at runtime, and a pointer to the
|
||||
// output file, if one is specified with a flag.
|
||||
type CmdData struct {
|
||||
Tables []dbdrivers.Table
|
||||
PkgName string
|
||||
OutFolder string
|
||||
Interface dbdrivers.Interface
|
||||
DriverName string
|
||||
Config *Config
|
||||
|
||||
Templates templater
|
||||
// SingleTemplates are only created once, not per table
|
||||
SingleTemplates templater
|
||||
|
||||
TestTemplates templater
|
||||
// SingleTestTemplates are only created once, not per table
|
||||
SingleTestTemplates templater
|
||||
//TestMainTemplate is only created once, not per table
|
||||
TestMainTemplate *template.Template
|
||||
}
|
||||
|
||||
// tplData is used to pass data to the template
|
||||
type tplData struct {
|
||||
Table dbdrivers.Table
|
||||
DriverName string
|
||||
PkgName string
|
||||
Tables []string
|
||||
}
|
||||
|
||||
type importList []string
|
||||
|
||||
// imports defines the optional standard imports and
|
||||
// thirdparty imports (from github for example)
|
||||
type imports struct {
|
||||
standard importList
|
||||
thirdparty importList
|
||||
}
|
||||
|
||||
type PostgresCfg struct {
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Postgres PostgresCfg `toml:"postgres"`
|
||||
}
|
20
config.go
Normal file
20
config.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package sqlboiler
|
||||
|
||||
// Config for the running of the commands
|
||||
type Config struct {
|
||||
DriverName string `toml:"driver_name"`
|
||||
PkgName string `toml:"pkg_name"`
|
||||
OutFolder string `toml:"out_folder"`
|
||||
TableName string `toml:"table_name"`
|
||||
|
||||
Postgres PostgresConfig `toml:"postgres"`
|
||||
}
|
||||
|
||||
// PostgresConfig configures a postgres database
|
||||
type PostgresConfig struct {
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
}
|
263
imports.go
Normal file
263
imports.go
Normal file
|
@ -0,0 +1,263 @@
|
|||
package sqlboiler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
// imports defines the optional standard imports and
|
||||
// thirdParty imports (from github for example)
|
||||
type imports struct {
|
||||
standard importList
|
||||
thirdParty importList
|
||||
}
|
||||
|
||||
// importList is a list of import names
|
||||
type importList []string
|
||||
|
||||
func (i importList) Len() int {
|
||||
return len(i)
|
||||
}
|
||||
|
||||
func (i importList) Swap(k, j int) {
|
||||
i[k], i[j] = i[j], i[k]
|
||||
}
|
||||
|
||||
func (i importList) Less(k, j int) bool {
|
||||
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
|
||||
if res <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func combineImports(a, b imports) imports {
|
||||
var c imports
|
||||
|
||||
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
|
||||
c.thirdParty = removeDuplicates(combineStringSlices(a.thirdParty, b.thirdParty))
|
||||
|
||||
sort.Sort(c.standard)
|
||||
sort.Sort(c.thirdParty)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
|
||||
tmpImp := imports{
|
||||
standard: make(importList, len(a.standard)),
|
||||
thirdParty: make(importList, len(a.thirdParty)),
|
||||
}
|
||||
|
||||
copy(tmpImp.standard, a.standard)
|
||||
copy(tmpImp.thirdParty, a.thirdParty)
|
||||
|
||||
for _, col := range columns {
|
||||
for key, imp := range b {
|
||||
if col.Type == key {
|
||||
tmpImp.standard = append(tmpImp.standard, imp.standard...)
|
||||
tmpImp.thirdParty = append(tmpImp.thirdParty, imp.thirdParty...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tmpImp.standard = removeDuplicates(tmpImp.standard)
|
||||
tmpImp.thirdParty = removeDuplicates(tmpImp.thirdParty)
|
||||
|
||||
sort.Sort(tmpImp.standard)
|
||||
sort.Sort(tmpImp.thirdParty)
|
||||
|
||||
return tmpImp
|
||||
}
|
||||
|
||||
func buildImportString(imps imports) []byte {
|
||||
stdlen, thirdlen := len(imps.standard), len(imps.thirdParty)
|
||||
if stdlen+thirdlen < 1 {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
if stdlen+thirdlen == 1 {
|
||||
var imp string
|
||||
if stdlen == 1 {
|
||||
imp = imps.standard[0]
|
||||
} else {
|
||||
imp = imps.thirdParty[0]
|
||||
}
|
||||
return []byte(fmt.Sprintf("import %s", imp))
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("import (")
|
||||
for _, std := range imps.standard {
|
||||
fmt.Fprintf(buf, "\n\t%s", std)
|
||||
}
|
||||
if stdlen != 0 && thirdlen != 0 {
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
for _, third := range imps.thirdParty {
|
||||
fmt.Fprintf(buf, "\n\t%s", third)
|
||||
}
|
||||
buf.WriteString("\n)\n")
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func combineStringSlices(a, b []string) []string {
|
||||
c := make([]string, len(a)+len(b))
|
||||
if len(a) > 0 {
|
||||
copy(c, a)
|
||||
}
|
||||
if len(b) > 0 {
|
||||
copy(c[len(a):], b)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func removeDuplicates(dedup []string) []string {
|
||||
if len(dedup) <= 1 {
|
||||
return dedup
|
||||
}
|
||||
|
||||
for i := 0; i < len(dedup)-1; i++ {
|
||||
for j := i + 1; j < len(dedup); j++ {
|
||||
if dedup[i] != dedup[j] {
|
||||
continue
|
||||
}
|
||||
|
||||
if j != len(dedup)-1 {
|
||||
dedup[j] = dedup[len(dedup)-1]
|
||||
j--
|
||||
}
|
||||
dedup = dedup[:len(dedup)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return dedup
|
||||
}
|
||||
|
||||
var defaultTemplateImports = imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
`"strings"`,
|
||||
},
|
||||
thirdParty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
`"github.com/nullbio/sqlboiler/boil/qs"`,
|
||||
},
|
||||
}
|
||||
|
||||
var defaultSingletonTemplateImports = map[string]imports{
|
||||
"helpers": imports{
|
||||
standard: importList{},
|
||||
thirdParty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
`"github.com/nullbio/sqlboiler/boil/qs"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var defaultTestTemplateImports = imports{
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
`"reflect"`,
|
||||
},
|
||||
thirdParty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
},
|
||||
}
|
||||
|
||||
var defaultSingletonTestTemplateImports = map[string]imports{
|
||||
"helper_funcs": imports{
|
||||
standard: importList{
|
||||
`"crypto/md5"`,
|
||||
`"fmt"`,
|
||||
`"os"`,
|
||||
`"strconv"`,
|
||||
`"math/rand"`,
|
||||
`"bytes"`,
|
||||
},
|
||||
thirdParty: importList{},
|
||||
},
|
||||
}
|
||||
|
||||
var defaultTestMainImports = map[string]imports{
|
||||
"postgres": imports{
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
`"os"`,
|
||||
`"os/exec"`,
|
||||
`"fmt"`,
|
||||
`"io/ioutil"`,
|
||||
`"bytes"`,
|
||||
`"database/sql"`,
|
||||
`"time"`,
|
||||
`"math/rand"`,
|
||||
},
|
||||
thirdParty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
`"github.com/BurntSushi/toml"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// importsBasedOnType imports are only included in the template output if the
|
||||
// database requires one of the following special types. Check
|
||||
// TranslateColumnType to see the type assignments.
|
||||
var importsBasedOnType = map[string]imports{
|
||||
"null.Float32": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Float64": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int8": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int16": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int32": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Int64": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint8": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint16": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint32": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Uint64": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.String": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Bool": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"null.Time": imports{
|
||||
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
|
||||
},
|
||||
"time.Time": imports{
|
||||
standard: importList{`"time"`},
|
||||
},
|
||||
}
|
|
@ -1,20 +1,66 @@
|
|||
package cmds
|
||||
package sqlboiler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
func TestImportsSort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a1 := importList{
|
||||
`"fmt"`,
|
||||
`"errors"`,
|
||||
}
|
||||
a2 := importList{
|
||||
`_ "github.com/lib/pq"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/mux"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
}
|
||||
|
||||
a1Expected := importList{`"errors"`, `"fmt"`}
|
||||
a2Expected := importList{
|
||||
`"github.com/gorilla/mux"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/websocket"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
}
|
||||
|
||||
sort.Sort(a1)
|
||||
if !reflect.DeepEqual(a1, a1Expected) {
|
||||
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
|
||||
}
|
||||
|
||||
for i, v := range a1 {
|
||||
if v != a1Expected[i] {
|
||||
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
|
||||
sort.Sort(a2)
|
||||
if !reflect.DeepEqual(a2, a2Expected) {
|
||||
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
|
||||
}
|
||||
|
||||
for i, v := range a2 {
|
||||
if v != a2Expected[i] {
|
||||
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineTypeImports(t *testing.T) {
|
||||
imports1 := imports{
|
||||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
thirdParty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
},
|
||||
}
|
||||
|
@ -25,7 +71,7 @@ func TestCombineTypeImports(t *testing.T) {
|
|||
`"fmt"`,
|
||||
`"time"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
thirdParty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
`"gopkg.in/nullbio/null.v4"`,
|
||||
},
|
||||
|
@ -46,7 +92,7 @@ func TestCombineTypeImports(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols)
|
||||
res1 := combineTypeImports(imports1, importsBasedOnType, cols)
|
||||
|
||||
if !reflect.DeepEqual(res1, importsExpected) {
|
||||
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
|
||||
|
@ -58,13 +104,13 @@ func TestCombineTypeImports(t *testing.T) {
|
|||
`"fmt"`,
|
||||
`"time"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
thirdParty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
`"gopkg.in/nullbio/null.v4"`,
|
||||
},
|
||||
}
|
||||
|
||||
res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols)
|
||||
res2 := combineTypeImports(imports2, importsBasedOnType, cols)
|
||||
|
||||
if !reflect.DeepEqual(res2, importsExpected) {
|
||||
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
|
||||
|
@ -76,11 +122,11 @@ func TestCombineImports(t *testing.T) {
|
|||
|
||||
a := imports{
|
||||
standard: importList{"fmt"},
|
||||
thirdparty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"},
|
||||
thirdParty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"},
|
||||
}
|
||||
b := imports{
|
||||
standard: importList{"os"},
|
||||
thirdparty: importList{"github.com/nullbio/sqlboiler"},
|
||||
thirdParty: importList{"github.com/nullbio/sqlboiler"},
|
||||
}
|
||||
|
||||
c := combineImports(a, b)
|
||||
|
@ -88,8 +134,8 @@ func TestCombineImports(t *testing.T) {
|
|||
if c.standard[0] != "fmt" && c.standard[1] != "os" {
|
||||
t.Errorf("Wanted: fmt, os got: %#v", c.standard)
|
||||
}
|
||||
if c.thirdparty[0] != "github.com/nullbio/sqlboiler" && c.thirdparty[1] != "gopkg.in/nullbio/null.v4" {
|
||||
t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdparty)
|
||||
if c.thirdParty[0] != "github.com/nullbio/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v4" {
|
||||
t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdParty)
|
||||
}
|
||||
}
|
||||
|
54
main.go
54
main.go
|
@ -1,54 +0,0 @@
|
|||
/*
|
||||
SQLBoiler is a tool to generate Go boilerplate code for database interactions.
|
||||
So far this includes struct definitions and database statement helper functions.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/nullbio/sqlboiler/cmds"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
cmdData := &cmds.CmdData{}
|
||||
|
||||
// Load the "config.toml" global config
|
||||
err = cmdData.LoadConfigFile("config.toml")
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to load config file: %s\n", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
// Set up the cobra root command
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "sqlboiler",
|
||||
Short: "SQL Boiler generates boilerplate structs and statements",
|
||||
Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" +
|
||||
`Complete documentation is available at http://github.com/nullbio/sqlboiler`,
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmdData.SQLBoilerPreRun(cmd, args)
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmdData.SQLBoilerRun(cmd, args)
|
||||
},
|
||||
PostRunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmdData.SQLBoilerPostRun(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
// Set up the cobra root command flags
|
||||
rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
|
||||
rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
|
||||
rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder")
|
||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
|
||||
|
||||
// Execute SQLBoiler
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
os.Exit(-1)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package cmds
|
||||
package sqlboiler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -18,18 +18,18 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
|||
}
|
||||
|
||||
// generateOutput builds the file output and sends it to outHandler for saving
|
||||
func generateOutput(cmdData *CmdData, data *tplData) error {
|
||||
if len(cmdData.Templates) == 0 {
|
||||
func generateOutput(state *State, data *templateData) error {
|
||||
if len(state.Templates) == 0 {
|
||||
return errors.New("No template files located for generation")
|
||||
}
|
||||
var out [][]byte
|
||||
var imps imports
|
||||
|
||||
imps.standard = sqlBoilerImports.standard
|
||||
imps.thirdparty = sqlBoilerImports.thirdparty
|
||||
imps.standard = defaultTemplateImports.standard
|
||||
imps.thirdParty = defaultTemplateImports.thirdParty
|
||||
|
||||
for _, template := range cmdData.Templates {
|
||||
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
|
||||
for _, template := range state.Templates {
|
||||
imps = combineTypeImports(imps, importsBasedOnType, data.Table.Columns)
|
||||
resp, err := generateTemplate(template, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
|
||||
|
@ -38,7 +38,7 @@ func generateOutput(cmdData *CmdData, data *tplData) error {
|
|||
}
|
||||
|
||||
fName := data.Table.Name + ".go"
|
||||
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
|
||||
err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -47,18 +47,18 @@ func generateOutput(cmdData *CmdData, data *tplData) error {
|
|||
}
|
||||
|
||||
// generateTestOutput builds the test file output and sends it to outHandler for saving
|
||||
func generateTestOutput(cmdData *CmdData, data *tplData) error {
|
||||
if len(cmdData.TestTemplates) == 0 {
|
||||
func generateTestOutput(state *State, data *templateData) error {
|
||||
if len(state.TestTemplates) == 0 {
|
||||
return errors.New("No template files located for generation")
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
var imps imports
|
||||
|
||||
imps.standard = sqlBoilerTestImports.standard
|
||||
imps.thirdparty = sqlBoilerTestImports.thirdparty
|
||||
imps.standard = defaultTestTemplateImports.standard
|
||||
imps.thirdParty = defaultTestTemplateImports.thirdParty
|
||||
|
||||
for _, template := range cmdData.TestTemplates {
|
||||
for _, template := range state.TestTemplates {
|
||||
resp, err := generateTemplate(template, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
|
||||
|
@ -67,7 +67,7 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error {
|
|||
}
|
||||
|
||||
fName := data.Table.Name + "_test.go"
|
||||
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
|
||||
err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -75,20 +75,22 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func generateSinglesOutput(cmdData *CmdData) error {
|
||||
if cmdData.SingleTemplates == nil {
|
||||
return errors.New("No single templates located for generation")
|
||||
// generateSingletonOutput processes the templates that should only be run
|
||||
// one time.
|
||||
func generateSingletonOutput(state *State) error {
|
||||
if state.SingletonTemplates == nil {
|
||||
return errors.New("No singleton templates located for generation")
|
||||
}
|
||||
|
||||
tplData := &tplData{
|
||||
PkgName: cmdData.PkgName,
|
||||
DriverName: cmdData.DriverName,
|
||||
templateData := &templateData{
|
||||
PkgName: state.Config.PkgName,
|
||||
DriverName: state.Config.DriverName,
|
||||
}
|
||||
|
||||
for _, template := range cmdData.SingleTemplates {
|
||||
for _, template := range state.SingletonTemplates {
|
||||
var imps imports
|
||||
|
||||
resp, err := generateTemplate(template, tplData)
|
||||
resp, err := generateTemplate(template, templateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
|
||||
}
|
||||
|
@ -97,12 +99,12 @@ func generateSinglesOutput(cmdData *CmdData) error {
|
|||
ext := filepath.Ext(fName)
|
||||
fName = fName[0 : len(fName)-len(ext)]
|
||||
|
||||
imps.standard = sqlBoilerSinglesImports[fName].standard
|
||||
imps.thirdparty = sqlBoilerSinglesImports[fName].thirdparty
|
||||
imps.standard = defaultSingletonTemplateImports[fName].standard
|
||||
imps.thirdParty = defaultSingletonTemplateImports[fName].thirdParty
|
||||
|
||||
fName = fName + ".go"
|
||||
|
||||
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp})
|
||||
err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -111,20 +113,22 @@ func generateSinglesOutput(cmdData *CmdData) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func generateSinglesTestOutput(cmdData *CmdData) error {
|
||||
if cmdData.SingleTestTemplates == nil {
|
||||
return errors.New("No single test templates located for generation")
|
||||
// generateSingletonTestOutput processes the templates that should only be run
|
||||
// one time.
|
||||
func generateSingletonTestOutput(state *State) error {
|
||||
if state.SingletonTestTemplates == nil {
|
||||
return errors.New("No singleton test templates located for generation")
|
||||
}
|
||||
|
||||
tplData := &tplData{
|
||||
PkgName: cmdData.PkgName,
|
||||
DriverName: cmdData.DriverName,
|
||||
templateData := &templateData{
|
||||
PkgName: state.Config.PkgName,
|
||||
DriverName: state.Config.DriverName,
|
||||
}
|
||||
|
||||
for _, template := range cmdData.SingleTestTemplates {
|
||||
for _, template := range state.SingletonTestTemplates {
|
||||
var imps imports
|
||||
|
||||
resp, err := generateTemplate(template, tplData)
|
||||
resp, err := generateTemplate(template, templateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
|
||||
}
|
||||
|
@ -133,12 +137,12 @@ func generateSinglesTestOutput(cmdData *CmdData) error {
|
|||
ext := filepath.Ext(fName)
|
||||
fName = fName[0 : len(fName)-len(ext)]
|
||||
|
||||
imps.standard = sqlBoilerSinglesTestImports[fName].standard
|
||||
imps.thirdparty = sqlBoilerSinglesTestImports[fName].thirdparty
|
||||
imps.standard = defaultSingletonTestTemplateImports[fName].standard
|
||||
imps.thirdParty = defaultSingletonTestTemplateImports[fName].thirdParty
|
||||
|
||||
fName = fName + "_test.go"
|
||||
|
||||
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp})
|
||||
err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -147,35 +151,30 @@ func generateSinglesTestOutput(cmdData *CmdData) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func generateTestMainOutput(cmdData *CmdData) error {
|
||||
if cmdData.TestMainTemplate == nil {
|
||||
func generateTestMainOutput(state *State) error {
|
||||
if state.TestMainTemplate == nil {
|
||||
return errors.New("No TestMain template located for generation")
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
var imps imports
|
||||
|
||||
imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard
|
||||
imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty
|
||||
imps.standard = defaultTestMainImports[state.Config.DriverName].standard
|
||||
imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty
|
||||
|
||||
var tables []string
|
||||
for _, v := range cmdData.Tables {
|
||||
tables = append(tables, v.Name)
|
||||
templateData := &templateData{
|
||||
Tables: state.Tables,
|
||||
PkgName: state.Config.PkgName,
|
||||
DriverName: state.Config.DriverName,
|
||||
}
|
||||
|
||||
tplData := &tplData{
|
||||
PkgName: cmdData.PkgName,
|
||||
DriverName: cmdData.DriverName,
|
||||
Tables: tables,
|
||||
}
|
||||
|
||||
resp, err := generateTemplate(cmdData.TestMainTemplate, tplData)
|
||||
resp, err := generateTemplate(state.TestMainTemplate, templateData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out = append(out, resp)
|
||||
|
||||
err = outHandler(cmdData.OutFolder, "main_test.go", cmdData.PkgName, imps, out)
|
||||
err = outHandler(state.Config.OutFolder, "main_test.go", state.Config.PkgName, imps, out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -216,7 +215,7 @@ func outHandler(outFolder string, fileName string, pkgName string, imps imports,
|
|||
}
|
||||
|
||||
// generateTemplate takes a template and returns the output of the template execution.
|
||||
func generateTemplate(t *template.Template, data *tplData) ([]byte, error) {
|
||||
func generateTemplate(t *template.Template, data *templateData) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := t.Execute(&buf, data); err != nil {
|
||||
return nil, err
|
|
@ -1,4 +1,4 @@
|
|||
package cmds
|
||||
package sqlboiler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -78,7 +78,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
}
|
||||
|
||||
a2 := imports{
|
||||
thirdparty: []string{
|
||||
thirdParty: []string{
|
||||
`"github.com/spf13/cobra"`,
|
||||
},
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
`"fmt"`,
|
||||
`"errors"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
thirdParty: importList{
|
||||
`_ "github.com/lib/pq"`,
|
||||
`_ "github.com/gorilla/n"`,
|
||||
`"github.com/gorilla/mux"`,
|
||||
|
@ -106,7 +106,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
|||
file = &bytes.Buffer{}
|
||||
|
||||
sort.Sort(a3.standard)
|
||||
sort.Sort(a3.thirdparty)
|
||||
sort.Sort(a3.thirdParty)
|
||||
|
||||
if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil {
|
||||
t.Error(err)
|
223
sqlboiler.go
Normal file
223
sqlboiler.go
Normal file
|
@ -0,0 +1,223 @@
|
|||
// Package sqlboiler has types and methods useful for generating code that
|
||||
// acts as a fully dynamic ORM might.
|
||||
package sqlboiler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
const (
|
||||
templatesDirectory = "cmds/templates"
|
||||
templatesSingletonDirectory = "cmds/templates/singleton"
|
||||
|
||||
templatesTestDirectory = "cmds/templates_test"
|
||||
templatesSingletonTestDirectory = "cmds/templates_test/singleton"
|
||||
)
|
||||
|
||||
// State holds the global data needed by most pieces to run
|
||||
type State struct {
|
||||
Config *Config
|
||||
|
||||
Driver dbdrivers.Interface
|
||||
Tables []dbdrivers.Table
|
||||
|
||||
Templates templateList
|
||||
TestTemplates templateList
|
||||
SingletonTemplates templateList
|
||||
SingletonTestTemplates templateList
|
||||
|
||||
TestMainTemplate *template.Template
|
||||
}
|
||||
|
||||
// New creates a new state based off of the config
|
||||
func New(config *Config) (*State, error) {
|
||||
s := &State{}
|
||||
|
||||
err := s.initDriver(config.DriverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Connect to the driver database
|
||||
if err = s.Driver.Open(); err != nil {
|
||||
return nil, fmt.Errorf("Unable to connect to the database: %s", err)
|
||||
}
|
||||
|
||||
err = s.initTables(config.TableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize tables: %s", err)
|
||||
}
|
||||
|
||||
err = s.initOutFolder()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||
}
|
||||
|
||||
err = s.initTemplates()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize templates: %s", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Run executes the sqlboiler templates and outputs them to files based on the
|
||||
// state given.
|
||||
func (s *State) Run(includeTests bool) error {
|
||||
if err := generateSingletonOutput(s); err != nil {
|
||||
return fmt.Errorf("Unable to generate singleton template output: %s", err)
|
||||
}
|
||||
|
||||
if includeTests {
|
||||
if err := generateTestMainOutput(s); err != nil {
|
||||
return fmt.Errorf("Unable to generate TestMain output: %s", err)
|
||||
}
|
||||
|
||||
if err := generateSingletonTestOutput(s); err != nil {
|
||||
return fmt.Errorf("Unable to generate singleton test template output: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, table := range s.Tables {
|
||||
if table.IsJoinTable {
|
||||
continue
|
||||
}
|
||||
|
||||
data := &templateData{
|
||||
Table: table,
|
||||
DriverName: s.Config.DriverName,
|
||||
PkgName: s.Config.PkgName,
|
||||
}
|
||||
|
||||
// Generate the regular templates
|
||||
if err := generateOutput(s, data); err != nil {
|
||||
return fmt.Errorf("Unable to generate output: %s", err)
|
||||
}
|
||||
|
||||
// Generate the test templates
|
||||
if includeTests {
|
||||
if err := generateTestOutput(s, data); err != nil {
|
||||
return fmt.Errorf("Unable to generate test output: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup closes any resources that must be closed
|
||||
func (s *State) Cleanup() error {
|
||||
s.Driver.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// initTemplates loads all template folders into the state object.
|
||||
func (s *State) initTemplates() error {
|
||||
var err error
|
||||
|
||||
s.Templates, err = loadTemplates(templatesDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.SingletonTemplates, err = loadTemplates(templatesSingletonDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.SingletonTestTemplates, err = loadTemplates(templatesSingletonTestDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initDriver attempts to set the state Interface based off the passed in
|
||||
// driver flag value. If an invalid flag string is provided an error is returned.
|
||||
func (s *State) initDriver(driverName string) error {
|
||||
// Create a driver based off driver flag
|
||||
switch driverName {
|
||||
case "postgres":
|
||||
s.Driver = dbdrivers.NewPostgresDriver(
|
||||
s.Config.Postgres.User,
|
||||
s.Config.Postgres.Pass,
|
||||
s.Config.Postgres.DBName,
|
||||
s.Config.Postgres.Host,
|
||||
s.Config.Postgres.Port,
|
||||
)
|
||||
}
|
||||
|
||||
if s.Driver == nil {
|
||||
return errors.New("An invalid driver name was provided")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initTables will create a string slice out of the passed in table flag value
|
||||
// if one is provided. If no flag is provided, it will attempt to connect to the
|
||||
// database to retrieve all "public" table names, and build a slice out of that
|
||||
// result.
|
||||
func (s *State) initTables(tableName string) error {
|
||||
var tableNames []string
|
||||
|
||||
if len(tableName) != 0 {
|
||||
tableNames = strings.Split(tableName, ",")
|
||||
for i, name := range tableNames {
|
||||
tableNames[i] = strings.TrimSpace(name)
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to get all table names: %s", err)
|
||||
}
|
||||
|
||||
if len(s.Tables) == 0 {
|
||||
return errors.New("No tables found in database, migrate some tables first")
|
||||
}
|
||||
|
||||
if err := checkPKeys(s.Tables); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initOutFolder creates the folder that will hold the generated output.
|
||||
func (s *State) initOutFolder() error {
|
||||
if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("Unable to make output folder: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPKeys ensures every table has a primary key column
|
||||
func checkPKeys(tables []dbdrivers.Table) error {
|
||||
var missingPkey []string
|
||||
for _, t := range tables {
|
||||
if t.PKey == nil {
|
||||
missingPkey = append(missingPkey, t.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingPkey) != 0 {
|
||||
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package cmds
|
||||
package sqlboiler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
|
@ -15,11 +15,11 @@ import (
|
|||
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
var cmdData *CmdData
|
||||
var state *State
|
||||
var rgxHasSpaces = regexp.MustCompile(`^\s+`)
|
||||
|
||||
func init() {
|
||||
cmdData = &CmdData{
|
||||
state = &State{
|
||||
Tables: []dbdrivers.Table{
|
||||
{
|
||||
Name: "patrick_table",
|
||||
|
@ -59,10 +59,11 @@ func init() {
|
|||
},
|
||||
},
|
||||
},
|
||||
PkgName: "patrick",
|
||||
OutFolder: "",
|
||||
DriverName: "postgres",
|
||||
Interface: nil,
|
||||
Config: &Config{
|
||||
PkgName: "patrick",
|
||||
OutFolder: "",
|
||||
DriverName: "postgres",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -84,72 +85,72 @@ func TestTemplates(t *testing.T) {
|
|||
t.SkipNow()
|
||||
}
|
||||
|
||||
if err := checkPKeys(cmdData.Tables); err != nil {
|
||||
if err := checkPKeys(state.Tables); err != nil {
|
||||
t.Fatalf("%s", err)
|
||||
}
|
||||
|
||||
// Initialize the templates
|
||||
var err error
|
||||
cmdData.Templates, err = loadTemplates("templates")
|
||||
state.Templates, err = loadTemplates("templates")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to initialize templates: %s", err)
|
||||
}
|
||||
|
||||
if len(cmdData.Templates) == 0 {
|
||||
if len(state.Templates) == 0 {
|
||||
t.Errorf("Templates is empty.")
|
||||
}
|
||||
|
||||
cmdData.SingleTemplates, err = loadTemplates("templates/singles")
|
||||
state.SingletonTemplates, err = loadTemplates("templates/singleton")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to initialize single templates: %s", err)
|
||||
t.Fatalf("Unable to initialize singleton templates: %s", err)
|
||||
}
|
||||
|
||||
if len(cmdData.SingleTemplates) == 0 {
|
||||
t.Errorf("SingleTemplates is empty.")
|
||||
if len(state.SingletonTemplates) == 0 {
|
||||
t.Errorf("SingletonTemplates is empty.")
|
||||
}
|
||||
|
||||
cmdData.TestTemplates, err = loadTemplates("templates_test")
|
||||
state.TestTemplates, err = loadTemplates("templates_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to initialize templates: %s", err)
|
||||
}
|
||||
|
||||
if len(cmdData.Templates) == 0 {
|
||||
if len(state.Templates) == 0 {
|
||||
t.Errorf("Templates is empty.")
|
||||
}
|
||||
|
||||
cmdData.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl")
|
||||
state.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to initialize templates: %s", err)
|
||||
}
|
||||
|
||||
cmdData.SingleTestTemplates, err = loadTemplates("templates_test/singles")
|
||||
state.SingletonTestTemplates, err = loadTemplates("templates_test/singleton")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to initialize single test templates: %s", err)
|
||||
}
|
||||
|
||||
if len(cmdData.SingleTestTemplates) == 0 {
|
||||
if len(state.SingletonTestTemplates) == 0 {
|
||||
t.Errorf("SingleTestTemplates is empty.")
|
||||
}
|
||||
|
||||
cmdData.OutFolder, err = ioutil.TempDir("", "templates")
|
||||
state.Config.OutFolder, err = ioutil.TempDir("", "templates")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create tempdir: %s", err)
|
||||
}
|
||||
defer os.RemoveAll(cmdData.OutFolder)
|
||||
defer os.RemoveAll(state.Config.OutFolder)
|
||||
|
||||
if err = cmdData.run(true); err != nil {
|
||||
if err = state.Run(true); err != nil {
|
||||
t.Errorf("Unable to run SQLBoilerRun: %s", err)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
cmd := exec.Command("go", "test", "-c")
|
||||
cmd.Dir = cmdData.OutFolder
|
||||
cmd.Dir = state.Config.OutFolder
|
||||
cmd.Stderr = buf
|
||||
|
||||
if err = cmd.Run(); err != nil {
|
||||
t.Errorf("go test cmd execution failed: %s", err)
|
||||
outputCompileErrors(buf, cmdData.OutFolder)
|
||||
outputCompileErrors(buf, state.Config.OutFolder)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
119
templates.go
Normal file
119
templates.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package sqlboiler
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/nullbio/sqlboiler/dbdrivers"
|
||||
"github.com/nullbio/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
// templateData for sqlboiler templates
|
||||
type templateData struct {
|
||||
Tables []dbdrivers.Table
|
||||
Table dbdrivers.Table
|
||||
DriverName string
|
||||
PkgName string
|
||||
}
|
||||
|
||||
type templateList []*template.Template
|
||||
|
||||
func (t templateList) Len() int {
|
||||
return len(t)
|
||||
}
|
||||
|
||||
func (t templateList) Swap(k, j int) {
|
||||
t[k], t[j] = t[j], t[k]
|
||||
}
|
||||
|
||||
func (t templateList) Less(k, j int) bool {
|
||||
// Make sure "struct" goes to the front
|
||||
if t[k].Name() == "struct.tpl" {
|
||||
return true
|
||||
}
|
||||
|
||||
res := strings.Compare(t[k].Name(), t[j].Name())
|
||||
if res <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// loadTemplates loads all of the template files in the specified directory.
|
||||
func loadTemplates(dir string) (templateList, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pattern := filepath.Join(wd, dir, "*.tpl")
|
||||
tpl, err := template.New("").Funcs(templateFunctions).ParseGlob(pattern)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
templates := templateList(tpl.Templates())
|
||||
sort.Sort(templates)
|
||||
|
||||
return templates, err
|
||||
}
|
||||
|
||||
// loadTemplate loads a single template file.
|
||||
func loadTemplate(dir string, filename string) (*template.Template, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pattern := filepath.Join(wd, dir, filename)
|
||||
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tpl.Lookup(filename), err
|
||||
}
|
||||
|
||||
// templateFunctions is a map of all the functions that get passed into the
|
||||
// templates. If you wish to pass a new function into your own template,
|
||||
// add a function pointer here.
|
||||
var templateFunctions = template.FuncMap{
|
||||
"singular": strmangle.Singular,
|
||||
"plural": strmangle.Plural,
|
||||
"titleCase": strmangle.TitleCase,
|
||||
"titleCaseSingular": strmangle.TitleCaseSingular,
|
||||
"titleCasePlural": strmangle.TitleCasePlural,
|
||||
"titleCaseCommaList": strmangle.TitleCaseCommaList,
|
||||
"camelCase": strmangle.CamelCase,
|
||||
"camelCaseSingular": strmangle.CamelCaseSingular,
|
||||
"camelCasePlural": strmangle.CamelCasePlural,
|
||||
"camelCaseCommaList": strmangle.CamelCaseCommaList,
|
||||
"columnsToStrings": strmangle.ColumnsToStrings,
|
||||
"commaList": strmangle.CommaList,
|
||||
"makeDBName": strmangle.MakeDBName,
|
||||
"selectParamNames": strmangle.SelectParamNames,
|
||||
"insertParamNames": strmangle.InsertParamNames,
|
||||
"insertParamFlags": strmangle.InsertParamFlags,
|
||||
"insertParamVariables": strmangle.InsertParamVariables,
|
||||
"scanParamNames": strmangle.ScanParamNames,
|
||||
"hasPrimaryKey": strmangle.HasPrimaryKey,
|
||||
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
|
||||
"wherePrimaryKey": strmangle.WherePrimaryKey,
|
||||
"paramsPrimaryKey": strmangle.ParamsPrimaryKey,
|
||||
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
|
||||
"updateParamNames": strmangle.UpdateParamNames,
|
||||
"updateParamVariables": strmangle.UpdateParamVariables,
|
||||
"supportsResultObject": strmangle.SupportsResultObject,
|
||||
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
|
||||
"filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement,
|
||||
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
|
||||
|
||||
"randDBStruct": strmangle.RandDBStruct,
|
||||
"randDBStructSlice": strmangle.RandDBStructSlice,
|
||||
}
|
|
@ -1,11 +1,11 @@
|
|||
{{if hasPrimaryKey .Table.PKey -}}
|
||||
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
|
||||
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
|
||||
// Update updates a single {{$tableNameSingular}} record.
|
||||
// whitelist is a list of column_name's that should be updated.
|
||||
// Update will match against the primary key column to find the record to update.
|
||||
// WARNING: This Update method will NOT ignore nil members.
|
||||
// If you pass in nil members, those columnns will be set to null.
|
||||
// Update a single {{$tableNameSingular}} record. It takes a whitelist of
|
||||
// column_name's that should be updated. The primary key will be used to find
|
||||
// the record to update.
|
||||
// WARNING: Update does NOT ignore nil members - only the whitelist can be used
|
||||
// to control the set of columns that will be saved.
|
||||
func (o *{{$tableNameSingular}}) Update(whitelist ... string) error {
|
||||
return o.UpdateX(boil.GetDB(), whitelist...)
|
||||
}
|
34
templates_test.go
Normal file
34
templates_test.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package sqlboiler
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
func TestTemplateListSort(t *testing.T) {
|
||||
templs := templateList{
|
||||
template.New("bob.tpl"),
|
||||
template.New("all.tpl"),
|
||||
template.New("struct.tpl"),
|
||||
template.New("ttt.tpl"),
|
||||
}
|
||||
|
||||
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
|
||||
|
||||
for i, v := range templs {
|
||||
if v.Name() != expected[i] {
|
||||
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||
}
|
||||
}
|
||||
|
||||
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
|
||||
|
||||
sort.Sort(templs)
|
||||
|
||||
for i, v := range templs {
|
||||
if v.Name() != expected[i] {
|
||||
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue