Merge pull request #18 from nullbio/random-improvements

Random improvements
This commit is contained in:
Patrick O'Brien 2016-06-12 16:35:38 +10:00 committed by GitHub
commit b2ae6cb688
42 changed files with 900 additions and 946 deletions

5
.gitignore vendored
View file

@ -1,2 +1,3 @@
sqlboiler
config.toml
/sqlboiler
/cmd/sqlboiler/sqlboiler
sqlboiler.toml

View file

@ -5,12 +5,19 @@ import (
"os"
)
var (
// currentDB is a global database handle for the package
currentDB Executor
)
// Executor can perform SQL queries.
type Executor interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// Transactor can commit and rollback, on top of being able to execute queries.
type Transactor interface {
Commit() error
Rollback() error
@ -18,12 +25,11 @@ type Transactor interface {
Executor
}
// Creator starts transactions.
type Creator interface {
Begin() (*sql.Tx, error)
}
var currentDB Executor
// DebugMode is a flag controlling whether generated sql statements and
// debug information is outputted to the DebugWriter handle
//
@ -33,6 +39,7 @@ var DebugMode = false
// DebugWriter is where the debug output will be sent if DebugMode is true
var DebugWriter = os.Stdout
// Begin a transaction
func Begin() (Transactor, error) {
creator, ok := currentDB.(Creator)
if !ok {

91
cmd/sqlboiler/main.go Normal file
View file

@ -0,0 +1,91 @@
// Package main defines a command line interface for the sqlboiler package
package main
import (
"errors"
"fmt"
"os"
"github.com/nullbio/sqlboiler"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
func main() {
var err error
viper.SetConfigName("sqlboiler")
viper.AddConfigPath("$HOME/.sqlboiler")
viper.AddConfigPath(".")
err = viper.ReadInConfig()
if err != nil {
fmt.Printf("Failed to load config file: %s\n", err)
os.Exit(-1)
}
// Set up the cobra root command
var rootCmd = &cobra.Command{
Use: "sqlboiler",
Short: "SQL Boiler generates boilerplate structs and statements",
Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" +
`Complete documentation is available at http://github.com/nullbio/sqlboiler`,
PreRunE: preRun,
RunE: run,
PostRunE: postRun,
}
// Set up the cobra root command flags
rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder")
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
viper.BindPFlags(rootCmd.PersistentFlags())
if err := rootCmd.Execute(); err != nil {
fmt.Println("Failed to execute sqlboiler command:", err)
os.Exit(-1)
}
}
var state *sqlboiler.State
var config *sqlboiler.Config
func preRun(cmd *cobra.Command, args []string) error {
config = new(sqlboiler.Config)
config.DriverName = viper.GetString("driver")
config.TableName = viper.GetString("table")
config.OutFolder = viper.GetString("folder")
config.PkgName = viper.GetString("pkgname")
if len(config.DriverName) == 0 {
return errors.New("Must supply a driver flag.")
}
if len(config.OutFolder) == 0 {
return fmt.Errorf("No output folder specified.")
}
if viper.IsSet("postgres.dbname") {
config.Postgres = sqlboiler.PostgresConfig{
User: viper.GetString("postgres.user"),
Pass: viper.GetString("postgres.pass"),
Host: viper.GetString("postgres.host"),
Port: viper.GetInt("postgres.port"),
DBName: viper.GetString("postgres.dbname"),
}
}
var err error
state, err = sqlboiler.New(config)
return err
}
func run(cmd *cobra.Command, args []string) error {
return state.Run(true)
}
func postRun(cmd *cobra.Command, args []string) error {
return state.Cleanup()
}

View file

@ -1,186 +0,0 @@
package cmds
import (
"fmt"
"os"
"text/template"
"github.com/BurntSushi/toml"
"github.com/nullbio/sqlboiler/strmangle"
)
// sqlBoilerTypeImports imports are only included in the template output if the database
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
var sqlBoilerTypeImports = map[string]imports{
"null.Float32": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Float64": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int8": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int16": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int32": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int64": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint8": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint16": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint32": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint64": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.String": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Bool": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Time": imports{
thirdparty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"time.Time": imports{
standard: importList{`"time"`},
},
}
// sqlBoilerImports defines the list of default template imports.
var sqlBoilerImports = imports{
standard: importList{
`"errors"`,
`"fmt"`,
`"strings"`,
},
thirdparty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/nullbio/sqlboiler/boil/qs"`,
},
}
var sqlBoilerSinglesImports = map[string]imports{
"helpers": imports{
standard: importList{},
thirdparty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/nullbio/sqlboiler/boil/qs"`,
},
},
}
// sqlBoilerTestImports defines the list of default test template imports.
var sqlBoilerTestImports = imports{
standard: importList{
`"testing"`,
`"reflect"`,
},
thirdparty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
},
}
var sqlBoilerSinglesTestImports = map[string]imports{
"helper_funcs": imports{
standard: importList{
`"crypto/md5"`,
`"fmt"`,
`"os"`,
`"strconv"`,
`"math/rand"`,
},
thirdparty: importList{},
},
}
var sqlBoilerTestMainImports = map[string]imports{
"postgres": imports{
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"database/sql"`,
`"time"`,
`"math/rand"`,
},
thirdparty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/BurntSushi/toml"`,
`_ "github.com/lib/pq"`,
},
},
}
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
// If you wish to pass a new function into your own template, add a pointer to it here.
var sqlBoilerTemplateFuncs = template.FuncMap{
"singular": strmangle.Singular,
"plural": strmangle.Plural,
"titleCase": strmangle.TitleCase,
"titleCaseSingular": strmangle.TitleCaseSingular,
"titleCasePlural": strmangle.TitleCasePlural,
"titleCaseCommaList": strmangle.TitleCaseCommaList,
"camelCase": strmangle.CamelCase,
"camelCaseSingular": strmangle.CamelCaseSingular,
"camelCasePlural": strmangle.CamelCasePlural,
"camelCaseCommaList": strmangle.CamelCaseCommaList,
"columnsToStrings": strmangle.ColumnsToStrings,
"commaList": strmangle.CommaList,
"makeDBName": strmangle.MakeDBName,
"selectParamNames": strmangle.SelectParamNames,
"insertParamNames": strmangle.InsertParamNames,
"insertParamFlags": strmangle.InsertParamFlags,
"insertParamVariables": strmangle.InsertParamVariables,
"scanParamNames": strmangle.ScanParamNames,
"hasPrimaryKey": strmangle.HasPrimaryKey,
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
"wherePrimaryKey": strmangle.WherePrimaryKey,
"paramsPrimaryKey": strmangle.ParamsPrimaryKey,
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
"updateParamNames": strmangle.UpdateParamNames,
"updateParamVariables": strmangle.UpdateParamVariables,
"supportsResultObject": strmangle.SupportsResultObject,
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
"filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement,
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
"randDBStruct": strmangle.RandDBStruct,
"randDBStructSlice": strmangle.RandDBStructSlice,
}
// LoadConfigFile loads the toml config file into the cfg object
func (c *CmdData) LoadConfigFile(filename string) error {
cfg := &Config{}
_, err := toml.DecodeFile(filename, &cfg)
if os.IsNotExist(err) {
return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err)
}
if err != nil {
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
}
c.Config = cfg
return nil
}

View file

@ -1,35 +0,0 @@
package cmds
import (
"io/ioutil"
"os"
"testing"
)
func TestLoadConfig(t *testing.T) {
t.Parallel()
cmdData := &CmdData{}
file, _ := ioutil.TempFile(os.TempDir(), "sqlboilercfgtest")
defer os.Remove(file.Name())
fContents := `[postgres]
host="localhost"
port=5432
user="user"
pass="pass"
dbname="mydb"`
file.WriteString(fContents)
err := cmdData.LoadConfigFile(file.Name())
if err != nil {
t.Errorf("Unable to load config file: %s", err)
}
if cmdData.Config.Postgres.Host != "localhost" ||
cmdData.Config.Postgres.User != "user" || cmdData.Config.Postgres.Pass != "pass" ||
cmdData.Config.Postgres.DBName != "mydb" || cmdData.Config.Postgres.Port != 5432 {
t.Errorf("Config failed to load properly, got: %#v", cmdData.Config.Postgres)
}
}

View file

@ -1,114 +0,0 @@
package cmds
import (
"bytes"
"fmt"
"sort"
"github.com/nullbio/sqlboiler/dbdrivers"
)
func combineImports(a, b imports) imports {
var c imports
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty))
sort.Sort(c.standard)
sort.Sort(c.thirdparty)
return c
}
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
tmpImp := imports{
standard: make(importList, len(a.standard)),
thirdparty: make(importList, len(a.thirdparty)),
}
copy(tmpImp.standard, a.standard)
copy(tmpImp.thirdparty, a.thirdparty)
for _, col := range columns {
for key, imp := range b {
if col.Type == key {
tmpImp.standard = append(tmpImp.standard, imp.standard...)
tmpImp.thirdparty = append(tmpImp.thirdparty, imp.thirdparty...)
}
}
}
tmpImp.standard = removeDuplicates(tmpImp.standard)
tmpImp.thirdparty = removeDuplicates(tmpImp.thirdparty)
sort.Sort(tmpImp.standard)
sort.Sort(tmpImp.thirdparty)
return tmpImp
}
func buildImportString(imps imports) []byte {
stdlen, thirdlen := len(imps.standard), len(imps.thirdparty)
if stdlen+thirdlen < 1 {
return []byte{}
}
if stdlen+thirdlen == 1 {
var imp string
if stdlen == 1 {
imp = imps.standard[0]
} else {
imp = imps.thirdparty[0]
}
return []byte(fmt.Sprintf("import %s", imp))
}
buf := &bytes.Buffer{}
buf.WriteString("import (")
for _, std := range imps.standard {
fmt.Fprintf(buf, "\n\t%s", std)
}
if stdlen != 0 && thirdlen != 0 {
buf.WriteString("\n")
}
for _, third := range imps.thirdparty {
fmt.Fprintf(buf, "\n\t%s", third)
}
buf.WriteString("\n)\n")
return buf.Bytes()
}
func combineStringSlices(a, b []string) []string {
c := make([]string, len(a)+len(b))
if len(a) > 0 {
copy(c, a)
}
if len(b) > 0 {
copy(c[len(a):], b)
}
return c
}
func removeDuplicates(dedup []string) []string {
if len(dedup) <= 1 {
return dedup
}
for i := 0; i < len(dedup)-1; i++ {
for j := i + 1; j < len(dedup); j++ {
if dedup[i] != dedup[j] {
continue
}
if j != len(dedup)-1 {
dedup[j] = dedup[len(dedup)-1]
j--
}
dedup = dedup[:len(dedup)-1]
}
}
return dedup
}

View file

@ -1,42 +0,0 @@
package cmds
import "strings"
func (i importList) Len() int {
return len(i)
}
func (i importList) Swap(k, j int) {
i[k], i[j] = i[j], i[k]
}
func (i importList) Less(k, j int) bool {
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
if res <= 0 {
return true
}
return false
}
func (t templater) Len() int {
return len(t)
}
func (t templater) Swap(k, j int) {
t[k], t[j] = t[j], t[k]
}
func (t templater) Less(k, j int) bool {
// Make sure "struct" goes to the front
if t[k].Name() == "struct.tpl" {
return true
}
res := strings.Compare(t[k].Name(), t[j].Name())
if res <= 0 {
return true
}
return false
}

View file

@ -1,80 +0,0 @@
package cmds
import (
"reflect"
"sort"
"testing"
"text/template"
)
func TestSortImports(t *testing.T) {
t.Parallel()
a1 := importList{
`"fmt"`,
`"errors"`,
}
a2 := importList{
`_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`,
`"github.com/gorilla/websocket"`,
}
a1Expected := importList{`"errors"`, `"fmt"`}
a2Expected := importList{
`"github.com/gorilla/mux"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/websocket"`,
`_ "github.com/lib/pq"`,
}
sort.Sort(a1)
if !reflect.DeepEqual(a1, a1Expected) {
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
}
for i, v := range a1 {
if v != a1Expected[i] {
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
sort.Sort(a2)
if !reflect.DeepEqual(a2, a2Expected) {
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
}
for i, v := range a2 {
if v != a2Expected[i] {
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
}
func TestSortTemplates(t *testing.T) {
templs := templater{
template.New("bob.tpl"),
template.New("all.tpl"),
template.New("struct.tpl"),
template.New("ttt.tpl"),
}
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
for i, v := range templs {
if v.Name() != expected[i] {
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
}
}
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
sort.Sort(templs)
for i, v := range templs {
if v.Name() != expected[i] {
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
}
}
}

View file

@ -1,273 +0,0 @@
package cmds
import (
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"text/template"
"github.com/nullbio/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
)
const (
templatesDirectory = "/cmds/templates"
templatesSinglesDirectory = "/cmds/templates/singles"
templatesTestDirectory = "/cmds/templates_test"
templatesSinglesTestDirectory = "/cmds/templates_test/singles"
templatesTestMainDirectory = "/cmds/templates_test/main_test"
)
// LoadTemplates loads all template folders into the cmdData object.
func initTemplates(cmdData *CmdData) error {
var err error
cmdData.Templates, err = loadTemplates(templatesDirectory)
if err != nil {
return err
}
cmdData.SingleTemplates, err = loadTemplates(templatesSinglesDirectory)
if err != nil {
return err
}
cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
if err != nil {
return err
}
cmdData.SingleTestTemplates, err = loadTemplates(templatesSinglesTestDirectory)
if err != nil {
return err
}
filename := cmdData.DriverName + "_main.tpl"
cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename)
if err != nil {
return err
}
return nil
}
// loadTemplates loads all of the template files in the specified directory.
func loadTemplates(dir string) ([]*template.Template, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, "*.tpl")
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern)
if err != nil {
return nil, err
}
templates := templater(tpl.Templates())
sort.Sort(templates)
return templates, err
}
// loadTemplate loads a single template file.
func loadTemplate(dir string, filename string) (*template.Template, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, filename)
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseFiles(pattern)
if err != nil {
return nil, err
}
return tpl.Lookup(filename), err
}
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
func (c *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error {
c.Interface.Close()
return nil
}
// SQLBoilerPreRun performs the initialization tasks before the root command is run
func (c *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
// Initialize package name
pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String()
// Retrieve driver flag
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
if driverName == "" {
return errors.New("Must supply a driver flag.")
}
tableName := cmd.PersistentFlags().Lookup("table").Value.String()
outFolder := cmd.PersistentFlags().Lookup("folder").Value.String()
if outFolder == "" {
return fmt.Errorf("No output folder specified.")
}
return c.initCmdData(pkgName, driverName, tableName, outFolder)
}
// SQLBoilerRun is a proxy method for the run function
func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
return c.run(true)
}
// run executes the sqlboiler templates and outputs them to files.
func (c *CmdData) run(includeTests bool) error {
if err := generateSinglesOutput(c); err != nil {
return fmt.Errorf("Unable to generate single templates output: %s", err)
}
if includeTests {
if err := generateTestMainOutput(c); err != nil {
return fmt.Errorf("Unable to generate TestMain output: %s", err)
}
if err := generateSinglesTestOutput(c); err != nil {
return fmt.Errorf("Unable to generate single test templates output: %s", err)
}
}
for _, table := range c.Tables {
data := &tplData{
Table: table,
DriverName: c.DriverName,
PkgName: c.PkgName,
}
// Generate the regular templates
if err := generateOutput(c, data); err != nil {
return fmt.Errorf("Unable to generate output: %s", err)
}
// Generate the test templates
if includeTests {
if err := generateTestOutput(c, data); err != nil {
return fmt.Errorf("Unable to generate test output: %s", err)
}
}
}
return nil
}
func (c *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error {
c.OutFolder = outFolder
c.PkgName = pkgName
err := initInterface(driverName, c)
if err != nil {
return err
}
// Connect to the driver database
if err = c.Interface.Open(); err != nil {
return fmt.Errorf("Unable to connect to the database: %s", err)
}
err = initTables(tableName, c)
if err != nil {
return fmt.Errorf("Unable to initialize tables: %s", err)
}
err = initOutFolder(c)
if err != nil {
return fmt.Errorf("Unable to initialize the output folder: %s", err)
}
err = initTemplates(c)
if err != nil {
return fmt.Errorf("Unable to initialize templates: %s", err)
}
return nil
}
// initInterface attempts to set the cmdData Interface based off the passed in
// driver flag value. If an invalid flag string is provided an error is returned.
func initInterface(driverName string, cmdData *CmdData) error {
// Create a driver based off driver flag
switch driverName {
case "postgres":
cmdData.Interface = dbdrivers.NewPostgresDriver(
cmdData.Config.Postgres.User,
cmdData.Config.Postgres.Pass,
cmdData.Config.Postgres.DBName,
cmdData.Config.Postgres.Host,
cmdData.Config.Postgres.Port,
)
}
if cmdData.Interface == nil {
return errors.New("An invalid driver name was provided")
}
cmdData.DriverName = driverName
return nil
}
// initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that result.
func initTables(tableName string, cmdData *CmdData) error {
var tableNames []string
if len(tableName) != 0 {
tableNames = strings.Split(tableName, ",")
for i, name := range tableNames {
tableNames[i] = strings.TrimSpace(name)
}
}
var err error
cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...)
if err != nil {
return fmt.Errorf("Unable to get all table names: %s", err)
}
if len(cmdData.Tables) == 0 {
return errors.New("No tables found in database, migrate some tables first")
}
if err := checkPKeys(cmdData.Tables); err != nil {
return err
}
return nil
}
// checkPKeys ensures every table has a primary key column
func checkPKeys(tables []dbdrivers.Table) error {
var missingPkey []string
for _, t := range tables {
if t.PKey == nil {
missingPkey = append(missingPkey, t.Name)
}
}
if len(missingPkey) != 0 {
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
}
return nil
}
// initOutFolder creates the folder that will hold the generated output.
func initOutFolder(cmdData *CmdData) error {
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
return fmt.Errorf("Unable to make output folder: %s", err)
}
return nil
}

View file

@ -1,65 +0,0 @@
package cmds
import (
"text/template"
"github.com/nullbio/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
)
// CobraRunFunc declares the cobra.Command.Run function definition
type CobraRunFunc func(cmd *cobra.Command, args []string)
type templater []*template.Template
// CmdData holds the table schema a slice of (column name, column type) slices.
// It also holds a slice of all of the table names sqlboiler is generating against,
// the database driver chosen by the driver flag at runtime, and a pointer to the
// output file, if one is specified with a flag.
type CmdData struct {
Tables []dbdrivers.Table
PkgName string
OutFolder string
Interface dbdrivers.Interface
DriverName string
Config *Config
Templates templater
// SingleTemplates are only created once, not per table
SingleTemplates templater
TestTemplates templater
// SingleTestTemplates are only created once, not per table
SingleTestTemplates templater
//TestMainTemplate is only created once, not per table
TestMainTemplate *template.Template
}
// tplData is used to pass data to the template
type tplData struct {
Table dbdrivers.Table
DriverName string
PkgName string
Tables []string
}
type importList []string
// imports defines the optional standard imports and
// thirdparty imports (from github for example)
type imports struct {
standard importList
thirdparty importList
}
type PostgresCfg struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
}
type Config struct {
Postgres PostgresCfg `toml:"postgres"`
}

20
config.go Normal file
View file

@ -0,0 +1,20 @@
package sqlboiler
// Config for the running of the commands
type Config struct {
DriverName string `toml:"driver_name"`
PkgName string `toml:"pkg_name"`
OutFolder string `toml:"out_folder"`
TableName string `toml:"table_name"`
Postgres PostgresConfig `toml:"postgres"`
}
// PostgresConfig configures a postgres database
type PostgresConfig struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
}

262
imports.go Normal file
View file

@ -0,0 +1,262 @@
package sqlboiler
import (
"bytes"
"fmt"
"sort"
"strings"
"github.com/nullbio/sqlboiler/dbdrivers"
)
// imports defines the optional standard imports and
// thirdParty imports (from github for example)
type imports struct {
standard importList
thirdParty importList
}
// importList is a list of import names
type importList []string
func (i importList) Len() int {
return len(i)
}
func (i importList) Swap(k, j int) {
i[k], i[j] = i[j], i[k]
}
func (i importList) Less(k, j int) bool {
res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ "))
if res <= 0 {
return true
}
return false
}
func combineImports(a, b imports) imports {
var c imports
c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard))
c.thirdParty = removeDuplicates(combineStringSlices(a.thirdParty, b.thirdParty))
sort.Sort(c.standard)
sort.Sort(c.thirdParty)
return c
}
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
tmpImp := imports{
standard: make(importList, len(a.standard)),
thirdParty: make(importList, len(a.thirdParty)),
}
copy(tmpImp.standard, a.standard)
copy(tmpImp.thirdParty, a.thirdParty)
for _, col := range columns {
for key, imp := range b {
if col.Type == key {
tmpImp.standard = append(tmpImp.standard, imp.standard...)
tmpImp.thirdParty = append(tmpImp.thirdParty, imp.thirdParty...)
}
}
}
tmpImp.standard = removeDuplicates(tmpImp.standard)
tmpImp.thirdParty = removeDuplicates(tmpImp.thirdParty)
sort.Sort(tmpImp.standard)
sort.Sort(tmpImp.thirdParty)
return tmpImp
}
func buildImportString(imps imports) []byte {
stdlen, thirdlen := len(imps.standard), len(imps.thirdParty)
if stdlen+thirdlen < 1 {
return []byte{}
}
if stdlen+thirdlen == 1 {
var imp string
if stdlen == 1 {
imp = imps.standard[0]
} else {
imp = imps.thirdParty[0]
}
return []byte(fmt.Sprintf("import %s", imp))
}
buf := &bytes.Buffer{}
buf.WriteString("import (")
for _, std := range imps.standard {
fmt.Fprintf(buf, "\n\t%s", std)
}
if stdlen != 0 && thirdlen != 0 {
buf.WriteString("\n")
}
for _, third := range imps.thirdParty {
fmt.Fprintf(buf, "\n\t%s", third)
}
buf.WriteString("\n)\n")
return buf.Bytes()
}
func combineStringSlices(a, b []string) []string {
c := make([]string, len(a)+len(b))
if len(a) > 0 {
copy(c, a)
}
if len(b) > 0 {
copy(c[len(a):], b)
}
return c
}
func removeDuplicates(dedup []string) []string {
if len(dedup) <= 1 {
return dedup
}
for i := 0; i < len(dedup)-1; i++ {
for j := i + 1; j < len(dedup); j++ {
if dedup[i] != dedup[j] {
continue
}
if j != len(dedup)-1 {
dedup[j] = dedup[len(dedup)-1]
j--
}
dedup = dedup[:len(dedup)-1]
}
}
return dedup
}
var defaultTemplateImports = imports{
standard: importList{
`"errors"`,
`"fmt"`,
`"strings"`,
},
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/nullbio/sqlboiler/boil/qs"`,
},
}
var defaultSingletonTemplateImports = map[string]imports{
"helpers": imports{
standard: importList{},
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/nullbio/sqlboiler/boil/qs"`,
},
},
}
var defaultTestTemplateImports = imports{
standard: importList{
`"testing"`,
`"reflect"`,
},
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
},
}
var defaultSingletonTestTemplateImports = map[string]imports{
"helper_funcs": imports{
standard: importList{
`"crypto/md5"`,
`"fmt"`,
`"os"`,
`"strconv"`,
`"math/rand"`,
},
thirdParty: importList{},
},
}
var defaultTestMainImports = map[string]imports{
"postgres": imports{
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"database/sql"`,
`"time"`,
`"math/rand"`,
},
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"github.com/BurntSushi/toml"`,
`_ "github.com/lib/pq"`,
},
},
}
// importsBasedOnType imports are only included in the template output if the
// database requires one of the following special types. Check
// TranslateColumnType to see the type assignments.
var importsBasedOnType = map[string]imports{
"null.Float32": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Float64": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int8": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int16": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int32": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int64": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint8": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint16": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint32": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint64": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.String": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Bool": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Time": imports{
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"time.Time": imports{
standard: importList{`"time"`},
},
}

View file

@ -1,20 +1,66 @@
package cmds
package sqlboiler
import (
"fmt"
"reflect"
"sort"
"testing"
"github.com/nullbio/sqlboiler/dbdrivers"
)
func TestImportsSort(t *testing.T) {
t.Parallel()
a1 := importList{
`"fmt"`,
`"errors"`,
}
a2 := importList{
`_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`,
`"github.com/gorilla/websocket"`,
}
a1Expected := importList{`"errors"`, `"fmt"`}
a2Expected := importList{
`"github.com/gorilla/mux"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/websocket"`,
`_ "github.com/lib/pq"`,
}
sort.Sort(a1)
if !reflect.DeepEqual(a1, a1Expected) {
t.Errorf("Expected a1 to match a1Expected, got: %v", a1)
}
for i, v := range a1 {
if v != a1Expected[i] {
t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
sort.Sort(a2)
if !reflect.DeepEqual(a2, a2Expected) {
t.Errorf("Expected a2 to match a2expected, got: %v", a2)
}
for i, v := range a2 {
if v != a2Expected[i] {
t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i])
}
}
}
func TestCombineTypeImports(t *testing.T) {
imports1 := imports{
standard: importList{
`"errors"`,
`"fmt"`,
},
thirdparty: importList{
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
},
}
@ -25,7 +71,7 @@ func TestCombineTypeImports(t *testing.T) {
`"fmt"`,
`"time"`,
},
thirdparty: importList{
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v4"`,
},
@ -46,7 +92,7 @@ func TestCombineTypeImports(t *testing.T) {
},
}
res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols)
res1 := combineTypeImports(imports1, importsBasedOnType, cols)
if !reflect.DeepEqual(res1, importsExpected) {
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
@ -58,13 +104,13 @@ func TestCombineTypeImports(t *testing.T) {
`"fmt"`,
`"time"`,
},
thirdparty: importList{
thirdParty: importList{
`"github.com/nullbio/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v4"`,
},
}
res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols)
res2 := combineTypeImports(imports2, importsBasedOnType, cols)
if !reflect.DeepEqual(res2, importsExpected) {
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
@ -76,11 +122,11 @@ func TestCombineImports(t *testing.T) {
a := imports{
standard: importList{"fmt"},
thirdparty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"},
thirdParty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"},
}
b := imports{
standard: importList{"os"},
thirdparty: importList{"github.com/nullbio/sqlboiler"},
thirdParty: importList{"github.com/nullbio/sqlboiler"},
}
c := combineImports(a, b)
@ -88,8 +134,8 @@ func TestCombineImports(t *testing.T) {
if c.standard[0] != "fmt" && c.standard[1] != "os" {
t.Errorf("Wanted: fmt, os got: %#v", c.standard)
}
if c.thirdparty[0] != "github.com/nullbio/sqlboiler" && c.thirdparty[1] != "gopkg.in/nullbio/null.v4" {
t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdparty)
if c.thirdParty[0] != "github.com/nullbio/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v4" {
t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdParty)
}
}

54
main.go
View file

@ -1,54 +0,0 @@
/*
SQLBoiler is a tool to generate Go boilerplate code for database interactions.
So far this includes struct definitions and database statement helper functions.
*/
package main
import (
"fmt"
"os"
"github.com/nullbio/sqlboiler/cmds"
"github.com/spf13/cobra"
)
func main() {
var err error
cmdData := &cmds.CmdData{}
// Load the "config.toml" global config
err = cmdData.LoadConfigFile("config.toml")
if err != nil {
fmt.Printf("Failed to load config file: %s\n", err)
os.Exit(-1)
}
// Set up the cobra root command
var rootCmd = &cobra.Command{
Use: "sqlboiler",
Short: "SQL Boiler generates boilerplate structs and statements",
Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" +
`Complete documentation is available at http://github.com/nullbio/sqlboiler`,
PreRunE: func(cmd *cobra.Command, args []string) error {
return cmdData.SQLBoilerPreRun(cmd, args)
},
RunE: func(cmd *cobra.Command, args []string) error {
return cmdData.SQLBoilerRun(cmd, args)
},
PostRunE: func(cmd *cobra.Command, args []string) error {
return cmdData.SQLBoilerPostRun(cmd, args)
},
}
// Set up the cobra root command flags
rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder")
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
// Execute SQLBoiler
if err := rootCmd.Execute(); err != nil {
os.Exit(-1)
}
}

View file

@ -1,4 +1,4 @@
package cmds
package sqlboiler
import (
"bytes"
@ -18,18 +18,18 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
}
// generateOutput builds the file output and sends it to outHandler for saving
func generateOutput(cmdData *CmdData, data *tplData) error {
if len(cmdData.Templates) == 0 {
func generateOutput(state *State, data *templateData) error {
if len(state.Templates) == 0 {
return errors.New("No template files located for generation")
}
var out [][]byte
var imps imports
imps.standard = sqlBoilerImports.standard
imps.thirdparty = sqlBoilerImports.thirdparty
imps.standard = defaultTemplateImports.standard
imps.thirdParty = defaultTemplateImports.thirdParty
for _, template := range cmdData.Templates {
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
for _, template := range state.Templates {
imps = combineTypeImports(imps, importsBasedOnType, data.Table.Columns)
resp, err := generateTemplate(template, data)
if err != nil {
return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
@ -38,7 +38,7 @@ func generateOutput(cmdData *CmdData, data *tplData) error {
}
fName := data.Table.Name + ".go"
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out)
if err != nil {
return err
}
@ -47,18 +47,18 @@ func generateOutput(cmdData *CmdData, data *tplData) error {
}
// generateTestOutput builds the test file output and sends it to outHandler for saving
func generateTestOutput(cmdData *CmdData, data *tplData) error {
if len(cmdData.TestTemplates) == 0 {
func generateTestOutput(state *State, data *templateData) error {
if len(state.TestTemplates) == 0 {
return errors.New("No template files located for generation")
}
var out [][]byte
var imps imports
imps.standard = sqlBoilerTestImports.standard
imps.thirdparty = sqlBoilerTestImports.thirdparty
imps.standard = defaultTestTemplateImports.standard
imps.thirdParty = defaultTestTemplateImports.thirdParty
for _, template := range cmdData.TestTemplates {
for _, template := range state.TestTemplates {
resp, err := generateTemplate(template, data)
if err != nil {
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
@ -67,7 +67,7 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error {
}
fName := data.Table.Name + "_test.go"
err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out)
err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out)
if err != nil {
return err
}
@ -75,20 +75,22 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error {
return nil
}
func generateSinglesOutput(cmdData *CmdData) error {
if cmdData.SingleTemplates == nil {
return errors.New("No single templates located for generation")
// generateSingletonOutput processes the templates that should only be run
// one time.
func generateSingletonOutput(state *State) error {
if state.SingletonTemplates == nil {
return errors.New("No singleton templates located for generation")
}
tplData := &tplData{
PkgName: cmdData.PkgName,
DriverName: cmdData.DriverName,
templateData := &templateData{
PkgName: state.Config.PkgName,
DriverName: state.Config.DriverName,
}
for _, template := range cmdData.SingleTemplates {
for _, template := range state.SingletonTemplates {
var imps imports
resp, err := generateTemplate(template, tplData)
resp, err := generateTemplate(template, templateData)
if err != nil {
return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
}
@ -97,12 +99,12 @@ func generateSinglesOutput(cmdData *CmdData) error {
ext := filepath.Ext(fName)
fName = fName[0 : len(fName)-len(ext)]
imps.standard = sqlBoilerSinglesImports[fName].standard
imps.thirdparty = sqlBoilerSinglesImports[fName].thirdparty
imps.standard = defaultSingletonTemplateImports[fName].standard
imps.thirdParty = defaultSingletonTemplateImports[fName].thirdParty
fName = fName + ".go"
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp})
err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp})
if err != nil {
return err
}
@ -111,20 +113,22 @@ func generateSinglesOutput(cmdData *CmdData) error {
return nil
}
func generateSinglesTestOutput(cmdData *CmdData) error {
if cmdData.SingleTestTemplates == nil {
return errors.New("No single test templates located for generation")
// generateSingletonTestOutput processes the templates that should only be run
// one time.
func generateSingletonTestOutput(state *State) error {
if state.SingletonTestTemplates == nil {
return errors.New("No singleton test templates located for generation")
}
tplData := &tplData{
PkgName: cmdData.PkgName,
DriverName: cmdData.DriverName,
templateData := &templateData{
PkgName: state.Config.PkgName,
DriverName: state.Config.DriverName,
}
for _, template := range cmdData.SingleTestTemplates {
for _, template := range state.SingletonTestTemplates {
var imps imports
resp, err := generateTemplate(template, tplData)
resp, err := generateTemplate(template, templateData)
if err != nil {
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
}
@ -133,12 +137,12 @@ func generateSinglesTestOutput(cmdData *CmdData) error {
ext := filepath.Ext(fName)
fName = fName[0 : len(fName)-len(ext)]
imps.standard = sqlBoilerSinglesTestImports[fName].standard
imps.thirdparty = sqlBoilerSinglesTestImports[fName].thirdparty
imps.standard = defaultSingletonTestTemplateImports[fName].standard
imps.thirdParty = defaultSingletonTestTemplateImports[fName].thirdParty
fName = fName + "_test.go"
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp})
err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp})
if err != nil {
return err
}
@ -147,35 +151,30 @@ func generateSinglesTestOutput(cmdData *CmdData) error {
return nil
}
func generateTestMainOutput(cmdData *CmdData) error {
if cmdData.TestMainTemplate == nil {
func generateTestMainOutput(state *State) error {
if state.TestMainTemplate == nil {
return errors.New("No TestMain template located for generation")
}
var out [][]byte
var imps imports
imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard
imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty
imps.standard = defaultTestMainImports[state.Config.DriverName].standard
imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty
var tables []string
for _, v := range cmdData.Tables {
tables = append(tables, v.Name)
templateData := &templateData{
Tables: state.Tables,
PkgName: state.Config.PkgName,
DriverName: state.Config.DriverName,
}
tplData := &tplData{
PkgName: cmdData.PkgName,
DriverName: cmdData.DriverName,
Tables: tables,
}
resp, err := generateTemplate(cmdData.TestMainTemplate, tplData)
resp, err := generateTemplate(state.TestMainTemplate, templateData)
if err != nil {
return err
}
out = append(out, resp)
err = outHandler(cmdData.OutFolder, "main_test.go", cmdData.PkgName, imps, out)
err = outHandler(state.Config.OutFolder, "main_test.go", state.Config.PkgName, imps, out)
if err != nil {
return err
}
@ -216,7 +215,7 @@ func outHandler(outFolder string, fileName string, pkgName string, imps imports,
}
// generateTemplate takes a template and returns the output of the template execution.
func generateTemplate(t *template.Template, data *tplData) ([]byte, error) {
func generateTemplate(t *template.Template, data *templateData) ([]byte, error) {
var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil {
return nil, err

View file

@ -1,4 +1,4 @@
package cmds
package sqlboiler
import (
"bytes"
@ -78,7 +78,7 @@ func TestOutHandlerFiles(t *testing.T) {
}
a2 := imports{
thirdparty: []string{
thirdParty: []string{
`"github.com/spf13/cobra"`,
},
}
@ -96,7 +96,7 @@ func TestOutHandlerFiles(t *testing.T) {
`"fmt"`,
`"errors"`,
},
thirdparty: importList{
thirdParty: importList{
`_ "github.com/lib/pq"`,
`_ "github.com/gorilla/n"`,
`"github.com/gorilla/mux"`,
@ -106,7 +106,7 @@ func TestOutHandlerFiles(t *testing.T) {
file = &bytes.Buffer{}
sort.Sort(a3.standard)
sort.Sort(a3.thirdparty)
sort.Sort(a3.thirdParty)
if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil {
t.Error(err)

223
sqlboiler.go Normal file
View file

@ -0,0 +1,223 @@
// Package sqlboiler has types and methods useful for generating code that
// acts as a fully dynamic ORM might.
package sqlboiler
import (
"errors"
"fmt"
"html/template"
"os"
"strings"
"github.com/nullbio/sqlboiler/dbdrivers"
)
const (
templatesDirectory = "cmds/templates"
templatesSingletonDirectory = "cmds/templates/singleton"
templatesTestDirectory = "cmds/templates_test"
templatesSingletonTestDirectory = "cmds/templates_test/singleton"
)
// State holds the global data needed by most pieces to run
type State struct {
Config *Config
Driver dbdrivers.Interface
Tables []dbdrivers.Table
Templates templateList
TestTemplates templateList
SingletonTemplates templateList
SingletonTestTemplates templateList
TestMainTemplate *template.Template
}
// New creates a new state based off of the config
func New(config *Config) (*State, error) {
s := &State{}
err := s.initDriver(config.DriverName)
if err != nil {
return nil, err
}
// Connect to the driver database
if err = s.Driver.Open(); err != nil {
return nil, fmt.Errorf("Unable to connect to the database: %s", err)
}
err = s.initTables(config.TableName)
if err != nil {
return nil, fmt.Errorf("Unable to initialize tables: %s", err)
}
err = s.initOutFolder()
if err != nil {
return nil, fmt.Errorf("Unable to initialize the output folder: %s", err)
}
err = s.initTemplates()
if err != nil {
return nil, fmt.Errorf("Unable to initialize templates: %s", err)
}
return s, nil
}
// Run executes the sqlboiler templates and outputs them to files based on the
// state given.
func (s *State) Run(includeTests bool) error {
if err := generateSingletonOutput(s); err != nil {
return fmt.Errorf("Unable to generate singleton template output: %s", err)
}
if includeTests {
if err := generateTestMainOutput(s); err != nil {
return fmt.Errorf("Unable to generate TestMain output: %s", err)
}
if err := generateSingletonTestOutput(s); err != nil {
return fmt.Errorf("Unable to generate singleton test template output: %s", err)
}
}
for _, table := range s.Tables {
if table.IsJoinTable {
continue
}
data := &templateData{
Table: table,
DriverName: s.Config.DriverName,
PkgName: s.Config.PkgName,
}
// Generate the regular templates
if err := generateOutput(s, data); err != nil {
return fmt.Errorf("Unable to generate output: %s", err)
}
// Generate the test templates
if includeTests {
if err := generateTestOutput(s, data); err != nil {
return fmt.Errorf("Unable to generate test output: %s", err)
}
}
}
return nil
}
// Cleanup closes any resources that must be closed
func (s *State) Cleanup() error {
s.Driver.Close()
return nil
}
// initTemplates loads all template folders into the state object.
func (s *State) initTemplates() error {
var err error
s.Templates, err = loadTemplates(templatesDirectory)
if err != nil {
return err
}
s.SingletonTemplates, err = loadTemplates(templatesSingletonDirectory)
if err != nil {
return err
}
s.TestTemplates, err = loadTemplates(templatesTestDirectory)
if err != nil {
return err
}
s.SingletonTestTemplates, err = loadTemplates(templatesSingletonTestDirectory)
if err != nil {
return err
}
return nil
}
// initDriver attempts to set the state Interface based off the passed in
// driver flag value. If an invalid flag string is provided an error is returned.
func (s *State) initDriver(driverName string) error {
// Create a driver based off driver flag
switch driverName {
case "postgres":
s.Driver = dbdrivers.NewPostgresDriver(
s.Config.Postgres.User,
s.Config.Postgres.Pass,
s.Config.Postgres.DBName,
s.Config.Postgres.Host,
s.Config.Postgres.Port,
)
}
if s.Driver == nil {
return errors.New("An invalid driver name was provided")
}
return nil
}
// initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that
// result.
func (s *State) initTables(tableName string) error {
var tableNames []string
if len(tableName) != 0 {
tableNames = strings.Split(tableName, ",")
for i, name := range tableNames {
tableNames[i] = strings.TrimSpace(name)
}
}
var err error
s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...)
if err != nil {
return fmt.Errorf("Unable to get all table names: %s", err)
}
if len(s.Tables) == 0 {
return errors.New("No tables found in database, migrate some tables first")
}
if err := checkPKeys(s.Tables); err != nil {
return err
}
return nil
}
// initOutFolder creates the folder that will hold the generated output.
func (s *State) initOutFolder() error {
if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil {
return fmt.Errorf("Unable to make output folder: %s", err)
}
return nil
}
// checkPKeys ensures every table has a primary key column
func checkPKeys(tables []dbdrivers.Table) error {
var missingPkey []string
for _, t := range tables {
if t.PKey == nil {
missingPkey = append(missingPkey, t.Name)
}
}
if len(missingPkey) != 0 {
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
}
return nil
}

View file

@ -1,4 +1,4 @@
package cmds
package sqlboiler
import (
"bufio"
@ -15,11 +15,11 @@ import (
"github.com/nullbio/sqlboiler/dbdrivers"
)
var cmdData *CmdData
var state *State
var rgxHasSpaces = regexp.MustCompile(`^\s+`)
func init() {
cmdData = &CmdData{
state = &State{
Tables: []dbdrivers.Table{
{
Name: "patrick_table",
@ -59,10 +59,11 @@ func init() {
},
},
},
PkgName: "patrick",
OutFolder: "",
DriverName: "postgres",
Interface: nil,
Config: &Config{
PkgName: "patrick",
OutFolder: "",
DriverName: "postgres",
},
}
}
@ -84,63 +85,63 @@ func TestTemplates(t *testing.T) {
t.SkipNow()
}
if err := checkPKeys(cmdData.Tables); err != nil {
if err := checkPKeys(state.Tables); err != nil {
t.Fatalf("%s", err)
}
// Initialize the templates
var err error
cmdData.Templates, err = loadTemplates("templates")
state.Templates, err = loadTemplates("templates")
if err != nil {
t.Fatalf("Unable to initialize templates: %s", err)
}
if len(cmdData.Templates) == 0 {
if len(state.Templates) == 0 {
t.Errorf("Templates is empty.")
}
cmdData.SingleTemplates, err = loadTemplates("templates/singles")
state.SingletonTemplates, err = loadTemplates("templates/singleton")
if err != nil {
t.Fatalf("Unable to initialize single templates: %s", err)
t.Fatalf("Unable to initialize singleton templates: %s", err)
}
if len(cmdData.SingleTemplates) == 0 {
t.Errorf("SingleTemplates is empty.")
if len(state.SingletonTemplates) == 0 {
t.Errorf("SingletonTemplates is empty.")
}
cmdData.TestTemplates, err = loadTemplates("templates_test")
state.TestTemplates, err = loadTemplates("templates_test")
if err != nil {
t.Fatalf("Unable to initialize templates: %s", err)
}
if len(cmdData.Templates) == 0 {
if len(state.Templates) == 0 {
t.Errorf("Templates is empty.")
}
cmdData.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl")
state.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl")
if err != nil {
t.Fatalf("Unable to initialize templates: %s", err)
}
cmdData.OutFolder, err = ioutil.TempDir("", "templates")
state.Config.OutFolder, err = ioutil.TempDir("", "templates")
if err != nil {
t.Fatalf("Unable to create tempdir: %s", err)
}
defer os.RemoveAll(cmdData.OutFolder)
defer os.RemoveAll(state.Config.OutFolder)
if err = cmdData.run(true); err != nil {
if err = state.Run(true); err != nil {
t.Errorf("Unable to run SQLBoilerRun: %s", err)
}
buf := &bytes.Buffer{}
cmd := exec.Command("go", "test", "-c")
cmd.Dir = cmdData.OutFolder
cmd.Dir = state.Config.OutFolder
cmd.Stderr = buf
if err = cmd.Run(); err != nil {
t.Errorf("go test cmd execution failed: %s", err)
outputCompileErrors(buf, cmdData.OutFolder)
outputCompileErrors(buf, state.Config.OutFolder)
fmt.Println()
}
}

119
templates.go Normal file
View file

@ -0,0 +1,119 @@
package sqlboiler
import (
"os"
"path/filepath"
"sort"
"strings"
"text/template"
"github.com/nullbio/sqlboiler/dbdrivers"
"github.com/nullbio/sqlboiler/strmangle"
)
// templateData for sqlboiler templates
type templateData struct {
Tables []dbdrivers.Table
Table dbdrivers.Table
DriverName string
PkgName string
}
type templateList []*template.Template
func (t templateList) Len() int {
return len(t)
}
func (t templateList) Swap(k, j int) {
t[k], t[j] = t[j], t[k]
}
func (t templateList) Less(k, j int) bool {
// Make sure "struct" goes to the front
if t[k].Name() == "struct.tpl" {
return true
}
res := strings.Compare(t[k].Name(), t[j].Name())
if res <= 0 {
return true
}
return false
}
// loadTemplates loads all of the template files in the specified directory.
func loadTemplates(dir string) (templateList, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, "*.tpl")
tpl, err := template.New("").Funcs(templateFunctions).ParseGlob(pattern)
if err != nil {
return nil, err
}
templates := templateList(tpl.Templates())
sort.Sort(templates)
return templates, err
}
// loadTemplate loads a single template file.
func loadTemplate(dir string, filename string) (*template.Template, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, dir, filename)
tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern)
if err != nil {
return nil, err
}
return tpl.Lookup(filename), err
}
// templateFunctions is a map of all the functions that get passed into the
// templates. If you wish to pass a new function into your own template,
// add a function pointer here.
var templateFunctions = template.FuncMap{
"singular": strmangle.Singular,
"plural": strmangle.Plural,
"titleCase": strmangle.TitleCase,
"titleCaseSingular": strmangle.TitleCaseSingular,
"titleCasePlural": strmangle.TitleCasePlural,
"titleCaseCommaList": strmangle.TitleCaseCommaList,
"camelCase": strmangle.CamelCase,
"camelCaseSingular": strmangle.CamelCaseSingular,
"camelCasePlural": strmangle.CamelCasePlural,
"camelCaseCommaList": strmangle.CamelCaseCommaList,
"columnsToStrings": strmangle.ColumnsToStrings,
"commaList": strmangle.CommaList,
"makeDBName": strmangle.MakeDBName,
"selectParamNames": strmangle.SelectParamNames,
"insertParamNames": strmangle.InsertParamNames,
"insertParamFlags": strmangle.InsertParamFlags,
"insertParamVariables": strmangle.InsertParamVariables,
"scanParamNames": strmangle.ScanParamNames,
"hasPrimaryKey": strmangle.HasPrimaryKey,
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
"wherePrimaryKey": strmangle.WherePrimaryKey,
"paramsPrimaryKey": strmangle.ParamsPrimaryKey,
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
"updateParamNames": strmangle.UpdateParamNames,
"updateParamVariables": strmangle.UpdateParamVariables,
"supportsResultObject": strmangle.SupportsResultObject,
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
"filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement,
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
"randDBStruct": strmangle.RandDBStruct,
"randDBStructSlice": strmangle.RandDBStructSlice,
}

View file

@ -1,11 +1,11 @@
{{if hasPrimaryKey .Table.PKey -}}
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
// Update updates a single {{$tableNameSingular}} record.
// whitelist is a list of column_name's that should be updated.
// Update will match against the primary key column to find the record to update.
// WARNING: This Update method will NOT ignore nil members.
// If you pass in nil members, those columnns will be set to null.
// Update a single {{$tableNameSingular}} record. It takes a whitelist of
// column_name's that should be updated. The primary key will be used to find
// the record to update.
// WARNING: Update does NOT ignore nil members - only the whitelist can be used
// to control the set of columns that will be saved.
func (o *{{$tableNameSingular}}) Update(whitelist ... string) error {
return o.UpdateX(boil.GetDB(), whitelist...)
}

34
templates_test.go Normal file
View file

@ -0,0 +1,34 @@
package sqlboiler
import (
"sort"
"testing"
"text/template"
)
func TestTemplateListSort(t *testing.T) {
templs := templateList{
template.New("bob.tpl"),
template.New("all.tpl"),
template.New("struct.tpl"),
template.New("ttt.tpl"),
}
expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"}
for i, v := range templs {
if v.Name() != expected[i] {
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
}
}
expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"}
sort.Sort(templs)
for i, v := range templs {
if v.Name() != expected[i] {
t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name())
}
}
}