From 8757c8a1845cf2281ddec30eae678bdf058857a8 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sat, 11 Jun 2016 18:25:00 -0700 Subject: [PATCH] Refactor entire project :D - Move most files to root - Remove cmds directory in favor of cmd directory with binary - Remove all cobra from main --- cmds/config.go | 186 ------------ cmds/config_test.go | 35 --- cmds/helpers.go | 114 ------- cmds/sorters.go | 42 --- cmds/sorters_test.go | 80 ----- cmds/sqlboiler.go | 277 ------------------ cmds/types.go | 67 ----- config.go | 41 +++ imports.go | 262 +++++++++++++++++ cmds/helpers_test.go => imports_test.go | 66 ++++- main.go | 54 ---- cmds/output.go => output.go | 95 +++--- cmds/output_test.go => output_test.go | 8 +- sqlboiler.go | 207 +++++++++++++ cmds/sqlboiler_test.go => sqlboiler_test.go | 43 +-- templates.go | 119 ++++++++ {cmds/templates => templates}/all.tpl | 0 {cmds/templates => templates}/delete.tpl | 0 {cmds/templates => templates}/find.tpl | 0 {cmds/templates => templates}/finishers.tpl | 0 {cmds/templates => templates}/helpers.tpl | 0 {cmds/templates => templates}/hooks.tpl | 0 {cmds/templates => templates}/insert.tpl | 0 .../singleton/helpers.tpl | 0 {cmds/templates => templates}/struct.tpl | 0 {cmds/templates => templates}/update.tpl | 0 templates_test.go | 34 +++ .../templates_test => templates_test}/all.tpl | 0 .../delete.tpl | 0 .../find.tpl | 0 .../finishers.tpl | 0 .../helpers.tpl | 0 .../hooks.tpl | 0 .../insert.tpl | 0 .../main_test/postgres_main.tpl | 0 .../select.tpl | 0 .../singleton/helper_funcs.tpl | 0 .../struct.tpl | 0 .../update.tpl | 0 39 files changed, 790 insertions(+), 940 deletions(-) delete mode 100644 cmds/config.go delete mode 100644 cmds/config_test.go delete mode 100644 cmds/helpers.go delete mode 100644 cmds/sorters.go delete mode 100644 cmds/sorters_test.go delete mode 100644 cmds/sqlboiler.go delete mode 100644 cmds/types.go create mode 100644 config.go create mode 100644 imports.go rename cmds/helpers_test.go => imports_test.go (70%) delete mode 100644 main.go rename cmds/output.go => output.go (59%) rename cmds/output_test.go => output_test.go (96%) create mode 100644 sqlboiler.go rename cmds/sqlboiler_test.go => sqlboiler_test.go (82%) create mode 100644 templates.go rename {cmds/templates => templates}/all.tpl (100%) rename {cmds/templates => templates}/delete.tpl (100%) rename {cmds/templates => templates}/find.tpl (100%) rename {cmds/templates => templates}/finishers.tpl (100%) rename {cmds/templates => templates}/helpers.tpl (100%) rename {cmds/templates => templates}/hooks.tpl (100%) rename {cmds/templates => templates}/insert.tpl (100%) rename {cmds/templates => templates}/singleton/helpers.tpl (100%) rename {cmds/templates => templates}/struct.tpl (100%) rename {cmds/templates => templates}/update.tpl (100%) create mode 100644 templates_test.go rename {cmds/templates_test => templates_test}/all.tpl (100%) rename {cmds/templates_test => templates_test}/delete.tpl (100%) rename {cmds/templates_test => templates_test}/find.tpl (100%) rename {cmds/templates_test => templates_test}/finishers.tpl (100%) rename {cmds/templates_test => templates_test}/helpers.tpl (100%) rename {cmds/templates_test => templates_test}/hooks.tpl (100%) rename {cmds/templates_test => templates_test}/insert.tpl (100%) rename {cmds/templates_test => templates_test}/main_test/postgres_main.tpl (100%) rename {cmds/templates_test => templates_test}/select.tpl (100%) rename {cmds/templates_test => templates_test}/singleton/helper_funcs.tpl (100%) rename {cmds/templates_test => templates_test}/struct.tpl (100%) rename {cmds/templates_test => templates_test}/update.tpl (100%) diff --git a/cmds/config.go b/cmds/config.go deleted file mode 100644 index 3d7fb91..0000000 --- a/cmds/config.go +++ /dev/null @@ -1,186 +0,0 @@ -package cmds - -import ( - "fmt" - "os" - "text/template" - - "github.com/BurntSushi/toml" - "github.com/nullbio/sqlboiler/strmangle" -) - -// sqlBoilerTypeImports imports are only included in the template output if the database -// requires one of the following special types. Check TranslateColumnType to see the type assignments. -var sqlBoilerTypeImports = map[string]imports{ - "null.Float32": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Float64": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Int": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Int8": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Int16": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Int32": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Int64": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Uint": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Uint8": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Uint16": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Uint32": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Uint64": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.String": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Bool": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "null.Time": imports{ - thirdparty: importList{`"gopkg.in/nullbio/null.v4"`}, - }, - "time.Time": imports{ - standard: importList{`"time"`}, - }, -} - -// sqlBoilerImports defines the list of default template imports. -var sqlBoilerImports = imports{ - standard: importList{ - `"errors"`, - `"fmt"`, - `"strings"`, - }, - thirdparty: importList{ - `"github.com/nullbio/sqlboiler/boil"`, - `"github.com/nullbio/sqlboiler/boil/qs"`, - }, -} - -var sqlBoilerSinglesImports = map[string]imports{ - "helpers": imports{ - standard: importList{}, - thirdparty: importList{ - `"github.com/nullbio/sqlboiler/boil"`, - `"github.com/nullbio/sqlboiler/boil/qs"`, - }, - }, -} - -// sqlBoilerTestImports defines the list of default test template imports. -var sqlBoilerTestImports = imports{ - standard: importList{ - `"testing"`, - `"reflect"`, - }, - thirdparty: importList{ - `"github.com/nullbio/sqlboiler/boil"`, - }, -} - -var sqlBoilerSinglesTestImports = map[string]imports{ - "helper_funcs": imports{ - standard: importList{ - `"crypto/md5"`, - `"fmt"`, - `"os"`, - `"strconv"`, - `"math/rand"`, - }, - thirdparty: importList{}, - }, -} - -var sqlBoilerTestMainImports = map[string]imports{ - "postgres": imports{ - standard: importList{ - `"testing"`, - `"os"`, - `"os/exec"`, - `"fmt"`, - `"io/ioutil"`, - `"bytes"`, - `"database/sql"`, - `"time"`, - `"math/rand"`, - }, - thirdparty: importList{ - `"github.com/nullbio/sqlboiler/boil"`, - `"github.com/BurntSushi/toml"`, - `_ "github.com/lib/pq"`, - }, - }, -} - -// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates. -// If you wish to pass a new function into your own template, add a pointer to it here. -var sqlBoilerTemplateFuncs = template.FuncMap{ - "singular": strmangle.Singular, - "plural": strmangle.Plural, - "titleCase": strmangle.TitleCase, - "titleCaseSingular": strmangle.TitleCaseSingular, - "titleCasePlural": strmangle.TitleCasePlural, - "titleCaseCommaList": strmangle.TitleCaseCommaList, - "camelCase": strmangle.CamelCase, - "camelCaseSingular": strmangle.CamelCaseSingular, - "camelCasePlural": strmangle.CamelCasePlural, - "camelCaseCommaList": strmangle.CamelCaseCommaList, - "columnsToStrings": strmangle.ColumnsToStrings, - "commaList": strmangle.CommaList, - "makeDBName": strmangle.MakeDBName, - "selectParamNames": strmangle.SelectParamNames, - "insertParamNames": strmangle.InsertParamNames, - "insertParamFlags": strmangle.InsertParamFlags, - "insertParamVariables": strmangle.InsertParamVariables, - "scanParamNames": strmangle.ScanParamNames, - "hasPrimaryKey": strmangle.HasPrimaryKey, - "primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig, - "wherePrimaryKey": strmangle.WherePrimaryKey, - "paramsPrimaryKey": strmangle.ParamsPrimaryKey, - "primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex, - "updateParamNames": strmangle.UpdateParamNames, - "updateParamVariables": strmangle.UpdateParamVariables, - "supportsResultObject": strmangle.SupportsResultObject, - "filterColumnsByDefault": strmangle.FilterColumnsByDefault, - "filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement, - "autoIncPrimaryKey": strmangle.AutoIncPrimaryKey, - - "randDBStruct": strmangle.RandDBStruct, - "randDBStructSlice": strmangle.RandDBStructSlice, -} - -// LoadConfigFile loads the toml config file into the cfg object -func (c *CmdData) LoadConfigFile(filename string) error { - cfg := &Config{} - - _, err := toml.DecodeFile(filename, &cfg) - - if os.IsNotExist(err) { - return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err) - } - - if err != nil { - return fmt.Errorf("Failed to decode toml configuration file: %s", err) - } - - c.Config = cfg - return nil -} diff --git a/cmds/config_test.go b/cmds/config_test.go deleted file mode 100644 index 9630c09..0000000 --- a/cmds/config_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package cmds - -import ( - "io/ioutil" - "os" - "testing" -) - -func TestLoadConfig(t *testing.T) { - t.Parallel() - - cmdData := &CmdData{} - - file, _ := ioutil.TempFile(os.TempDir(), "sqlboilercfgtest") - defer os.Remove(file.Name()) - - fContents := `[postgres] - host="localhost" - port=5432 - user="user" - pass="pass" - dbname="mydb"` - - file.WriteString(fContents) - err := cmdData.LoadConfigFile(file.Name()) - if err != nil { - t.Errorf("Unable to load config file: %s", err) - } - - if cmdData.Config.Postgres.Host != "localhost" || - cmdData.Config.Postgres.User != "user" || cmdData.Config.Postgres.Pass != "pass" || - cmdData.Config.Postgres.DBName != "mydb" || cmdData.Config.Postgres.Port != 5432 { - t.Errorf("Config failed to load properly, got: %#v", cmdData.Config.Postgres) - } -} diff --git a/cmds/helpers.go b/cmds/helpers.go deleted file mode 100644 index 29f42e9..0000000 --- a/cmds/helpers.go +++ /dev/null @@ -1,114 +0,0 @@ -package cmds - -import ( - "bytes" - "fmt" - "sort" - - "github.com/nullbio/sqlboiler/dbdrivers" -) - -func combineImports(a, b imports) imports { - var c imports - - c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard)) - c.thirdparty = removeDuplicates(combineStringSlices(a.thirdparty, b.thirdparty)) - - sort.Sort(c.standard) - sort.Sort(c.thirdparty) - - return c -} - -func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports { - tmpImp := imports{ - standard: make(importList, len(a.standard)), - thirdparty: make(importList, len(a.thirdparty)), - } - - copy(tmpImp.standard, a.standard) - copy(tmpImp.thirdparty, a.thirdparty) - - for _, col := range columns { - for key, imp := range b { - if col.Type == key { - tmpImp.standard = append(tmpImp.standard, imp.standard...) - tmpImp.thirdparty = append(tmpImp.thirdparty, imp.thirdparty...) - } - } - } - - tmpImp.standard = removeDuplicates(tmpImp.standard) - tmpImp.thirdparty = removeDuplicates(tmpImp.thirdparty) - - sort.Sort(tmpImp.standard) - sort.Sort(tmpImp.thirdparty) - - return tmpImp -} - -func buildImportString(imps imports) []byte { - stdlen, thirdlen := len(imps.standard), len(imps.thirdparty) - if stdlen+thirdlen < 1 { - return []byte{} - } - - if stdlen+thirdlen == 1 { - var imp string - if stdlen == 1 { - imp = imps.standard[0] - } else { - imp = imps.thirdparty[0] - } - return []byte(fmt.Sprintf("import %s", imp)) - } - - buf := &bytes.Buffer{} - buf.WriteString("import (") - for _, std := range imps.standard { - fmt.Fprintf(buf, "\n\t%s", std) - } - if stdlen != 0 && thirdlen != 0 { - buf.WriteString("\n") - } - for _, third := range imps.thirdparty { - fmt.Fprintf(buf, "\n\t%s", third) - } - buf.WriteString("\n)\n") - - return buf.Bytes() -} - -func combineStringSlices(a, b []string) []string { - c := make([]string, len(a)+len(b)) - if len(a) > 0 { - copy(c, a) - } - if len(b) > 0 { - copy(c[len(a):], b) - } - - return c -} - -func removeDuplicates(dedup []string) []string { - if len(dedup) <= 1 { - return dedup - } - - for i := 0; i < len(dedup)-1; i++ { - for j := i + 1; j < len(dedup); j++ { - if dedup[i] != dedup[j] { - continue - } - - if j != len(dedup)-1 { - dedup[j] = dedup[len(dedup)-1] - j-- - } - dedup = dedup[:len(dedup)-1] - } - } - - return dedup -} diff --git a/cmds/sorters.go b/cmds/sorters.go deleted file mode 100644 index 1332eca..0000000 --- a/cmds/sorters.go +++ /dev/null @@ -1,42 +0,0 @@ -package cmds - -import "strings" - -func (i importList) Len() int { - return len(i) -} - -func (i importList) Swap(k, j int) { - i[k], i[j] = i[j], i[k] -} - -func (i importList) Less(k, j int) bool { - res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ ")) - if res <= 0 { - return true - } - - return false -} - -func (t templater) Len() int { - return len(t) -} - -func (t templater) Swap(k, j int) { - t[k], t[j] = t[j], t[k] -} - -func (t templater) Less(k, j int) bool { - // Make sure "struct" goes to the front - if t[k].Name() == "struct.tpl" { - return true - } - - res := strings.Compare(t[k].Name(), t[j].Name()) - if res <= 0 { - return true - } - - return false -} diff --git a/cmds/sorters_test.go b/cmds/sorters_test.go deleted file mode 100644 index ccbac65..0000000 --- a/cmds/sorters_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package cmds - -import ( - "reflect" - "sort" - "testing" - "text/template" -) - -func TestSortImports(t *testing.T) { - t.Parallel() - - a1 := importList{ - `"fmt"`, - `"errors"`, - } - a2 := importList{ - `_ "github.com/lib/pq"`, - `_ "github.com/gorilla/n"`, - `"github.com/gorilla/mux"`, - `"github.com/gorilla/websocket"`, - } - - a1Expected := importList{`"errors"`, `"fmt"`} - a2Expected := importList{ - `"github.com/gorilla/mux"`, - `_ "github.com/gorilla/n"`, - `"github.com/gorilla/websocket"`, - `_ "github.com/lib/pq"`, - } - - sort.Sort(a1) - if !reflect.DeepEqual(a1, a1Expected) { - t.Errorf("Expected a1 to match a1Expected, got: %v", a1) - } - - for i, v := range a1 { - if v != a1Expected[i] { - t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i]) - } - } - - sort.Sort(a2) - if !reflect.DeepEqual(a2, a2Expected) { - t.Errorf("Expected a2 to match a2expected, got: %v", a2) - } - - for i, v := range a2 { - if v != a2Expected[i] { - t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i]) - } - } -} - -func TestSortTemplates(t *testing.T) { - templs := templater{ - template.New("bob.tpl"), - template.New("all.tpl"), - template.New("struct.tpl"), - template.New("ttt.tpl"), - } - - expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"} - - for i, v := range templs { - if v.Name() != expected[i] { - t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name()) - } - } - - expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"} - - sort.Sort(templs) - - for i, v := range templs { - if v.Name() != expected[i] { - t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name()) - } - } -} diff --git a/cmds/sqlboiler.go b/cmds/sqlboiler.go deleted file mode 100644 index 0a2bbf4..0000000 --- a/cmds/sqlboiler.go +++ /dev/null @@ -1,277 +0,0 @@ -package cmds - -import ( - "errors" - "fmt" - "os" - "path/filepath" - "sort" - "strings" - "text/template" - - "github.com/nullbio/sqlboiler/dbdrivers" - "github.com/spf13/cobra" -) - -const ( - templatesDirectory = "cmds/templates" - templatesSinglesDirectory = "cmds/templates/singleton" - - templatesTestDirectory = "cmds/templates_test" - templatesSinglesTestDirectory = "cmds/templates_test/singleton" - templatesTestMainDirectory = "cmds/templates_test/main_test" -) - -// LoadTemplates loads all template folders into the cmdData object. -func initTemplates(cmdData *CmdData) error { - var err error - cmdData.Templates, err = loadTemplates(templatesDirectory) - if err != nil { - return err - } - - cmdData.SingleTemplates, err = loadTemplates(templatesSinglesDirectory) - if err != nil { - return err - } - - cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory) - if err != nil { - return err - } - - cmdData.SingleTestTemplates, err = loadTemplates(templatesSinglesTestDirectory) - if err != nil { - return err - } - - filename := cmdData.DriverName + "_main.tpl" - cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename) - if err != nil { - return err - } - - return nil -} - -// loadTemplates loads all of the template files in the specified directory. -func loadTemplates(dir string) ([]*template.Template, error) { - wd, err := os.Getwd() - if err != nil { - return nil, err - } - - pattern := filepath.Join(wd, dir, "*.tpl") - tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern) - - if err != nil { - return nil, err - } - - templates := templater(tpl.Templates()) - sort.Sort(templates) - - return templates, err -} - -// loadTemplate loads a single template file. -func loadTemplate(dir string, filename string) (*template.Template, error) { - wd, err := os.Getwd() - if err != nil { - return nil, err - } - - pattern := filepath.Join(wd, dir, filename) - tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseFiles(pattern) - - if err != nil { - return nil, err - } - - return tpl.Lookup(filename), err -} - -// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished. -func (c *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error { - c.Interface.Close() - return nil -} - -// SQLBoilerPreRun performs the initialization tasks before the root command is run -func (c *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error { - // Initialize package name - pkgName := cmd.PersistentFlags().Lookup("pkgname").Value.String() - - // Retrieve driver flag - driverName := cmd.PersistentFlags().Lookup("driver").Value.String() - if driverName == "" { - return errors.New("Must supply a driver flag.") - } - - tableName := cmd.PersistentFlags().Lookup("table").Value.String() - - outFolder := cmd.PersistentFlags().Lookup("folder").Value.String() - if outFolder == "" { - return fmt.Errorf("No output folder specified.") - } - - return c.initCmdData(pkgName, driverName, tableName, outFolder) -} - -// SQLBoilerRun is a proxy method for the run function -func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error { - return c.run(true) -} - -// run executes the sqlboiler templates and outputs them to files. -func (c *CmdData) run(includeTests bool) error { - if err := generateSingletonOutput(c); err != nil { - return fmt.Errorf("Unable to generate singleton template output: %s", err) - } - - if includeTests { - if err := generateTestMainOutput(c); err != nil { - return fmt.Errorf("Unable to generate TestMain output: %s", err) - } - - if err := generateSingletonTestOutput(c); err != nil { - return fmt.Errorf("Unable to generate singleton test template output: %s", err) - } - } - - for _, table := range c.Tables { - if table.IsJoinTable { - continue - } - - data := &tplData{ - Table: table, - DriverName: c.DriverName, - PkgName: c.PkgName, - } - - // Generate the regular templates - if err := generateOutput(c, data); err != nil { - return fmt.Errorf("Unable to generate output: %s", err) - } - - // Generate the test templates - if includeTests { - if err := generateTestOutput(c, data); err != nil { - return fmt.Errorf("Unable to generate test output: %s", err) - } - } - } - - return nil -} - -func (c *CmdData) initCmdData(pkgName, driverName, tableName, outFolder string) error { - c.OutFolder = outFolder - c.PkgName = pkgName - - err := initInterface(driverName, c) - if err != nil { - return err - } - - // Connect to the driver database - if err = c.Interface.Open(); err != nil { - return fmt.Errorf("Unable to connect to the database: %s", err) - } - - err = initTables(tableName, c) - if err != nil { - return fmt.Errorf("Unable to initialize tables: %s", err) - } - - err = initOutFolder(c) - if err != nil { - return fmt.Errorf("Unable to initialize the output folder: %s", err) - } - - err = initTemplates(c) - if err != nil { - return fmt.Errorf("Unable to initialize templates: %s", err) - } - - return nil -} - -// initInterface attempts to set the cmdData Interface based off the passed in -// driver flag value. If an invalid flag string is provided an error is returned. -func initInterface(driverName string, cmdData *CmdData) error { - // Create a driver based off driver flag - switch driverName { - case "postgres": - cmdData.Interface = dbdrivers.NewPostgresDriver( - cmdData.Config.Postgres.User, - cmdData.Config.Postgres.Pass, - cmdData.Config.Postgres.DBName, - cmdData.Config.Postgres.Host, - cmdData.Config.Postgres.Port, - ) - } - - if cmdData.Interface == nil { - return errors.New("An invalid driver name was provided") - } - - cmdData.DriverName = driverName - return nil -} - -// initTables will create a string slice out of the passed in table flag value -// if one is provided. If no flag is provided, it will attempt to connect to the -// database to retrieve all "public" table names, and build a slice out of that result. -func initTables(tableName string, cmdData *CmdData) error { - var tableNames []string - - if len(tableName) != 0 { - tableNames = strings.Split(tableName, ",") - for i, name := range tableNames { - tableNames[i] = strings.TrimSpace(name) - } - } - - var err error - cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...) - if err != nil { - return fmt.Errorf("Unable to get all table names: %s", err) - } - - if len(cmdData.Tables) == 0 { - return errors.New("No tables found in database, migrate some tables first") - } - - if err := checkPKeys(cmdData.Tables); err != nil { - return err - } - - return nil -} - -// checkPKeys ensures every table has a primary key column -func checkPKeys(tables []dbdrivers.Table) error { - var missingPkey []string - for _, t := range tables { - if t.PKey == nil { - missingPkey = append(missingPkey, t.Name) - } - } - - if len(missingPkey) != 0 { - return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", ")) - } - - return nil -} - -// initOutFolder creates the folder that will hold the generated output. -func initOutFolder(cmdData *CmdData) error { - if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil { - return fmt.Errorf("Unable to make output folder: %s", err) - } - - return nil -} diff --git a/cmds/types.go b/cmds/types.go deleted file mode 100644 index 2445f03..0000000 --- a/cmds/types.go +++ /dev/null @@ -1,67 +0,0 @@ -package cmds - -import ( - "text/template" - - "github.com/nullbio/sqlboiler/dbdrivers" - "github.com/spf13/cobra" -) - -// CobraRunFunc declares the cobra.Command.Run function definition -type CobraRunFunc func(cmd *cobra.Command, args []string) - -type templater []*template.Template - -// CmdData holds the table schema a slice of (column name, column type) slices. -// It also holds a slice of all of the table names sqlboiler is generating against, -// the database driver chosen by the driver flag at runtime, and a pointer to the -// output file, if one is specified with a flag. -type CmdData struct { - Tables []dbdrivers.Table - PkgName string - OutFolder string - Interface dbdrivers.Interface - DriverName string - Config *Config - - Templates templater - // SingleTemplates are only created once, not per table - SingleTemplates templater - - TestTemplates templater - // SingleTestTemplates are only created once, not per table - SingleTestTemplates templater - //TestMainTemplate is only created once, not per table - TestMainTemplate *template.Template -} - -// tplData is used to pass data to the template -type tplData struct { - Table dbdrivers.Table - DriverName string - PkgName string - Tables []string -} - -type importList []string - -// imports defines the optional standard imports and -// thirdparty imports (from github for example) -type imports struct { - standard importList - thirdparty importList -} - -// PostgresCfg configures a postgres database -type PostgresCfg struct { - User string `toml:"user"` - Pass string `toml:"pass"` - Host string `toml:"host"` - Port int `toml:"port"` - DBName string `toml:"dbname"` -} - -// Config is loaded from a file -type Config struct { - Postgres PostgresCfg `toml:"postgres"` -} diff --git a/config.go b/config.go new file mode 100644 index 0000000..a2470ab --- /dev/null +++ b/config.go @@ -0,0 +1,41 @@ +package sqlboiler + +import ( + "text/template" + + "github.com/nullbio/sqlboiler/dbdrivers" +) + +// State holds the global data needed by most pieces to run +type State struct { + Config *Config + + Driver dbdrivers.Interface + Tables []dbdrivers.Table + + Templates templateList + TestTemplates templateList + SingletonTemplates templateList + SingletonTestTemplates templateList + + TestMainTemplate *template.Template +} + +// Config for the running of the commands +type Config struct { + DriverName string `toml:"driver_name"` + PkgName string `toml:"pkg_name"` + OutFolder string `toml:"out_folder"` + TableName string `toml:"table_name"` + + Postgres PostgresConfig `toml:"postgres"` +} + +// PostgresConfig configures a postgres database +type PostgresConfig struct { + User string `toml:"user"` + Pass string `toml:"pass"` + Host string `toml:"host"` + Port int `toml:"port"` + DBName string `toml:"dbname"` +} diff --git a/imports.go b/imports.go new file mode 100644 index 0000000..3ec1879 --- /dev/null +++ b/imports.go @@ -0,0 +1,262 @@ +package sqlboiler + +import ( + "bytes" + "fmt" + "sort" + "strings" + + "github.com/nullbio/sqlboiler/dbdrivers" +) + +// imports defines the optional standard imports and +// thirdParty imports (from github for example) +type imports struct { + standard importList + thirdParty importList +} + +// importList is a list of import names +type importList []string + +func (i importList) Len() int { + return len(i) +} + +func (i importList) Swap(k, j int) { + i[k], i[j] = i[j], i[k] +} + +func (i importList) Less(k, j int) bool { + res := strings.Compare(strings.TrimLeft(i[k], "_ "), strings.TrimLeft(i[j], "_ ")) + if res <= 0 { + return true + } + + return false +} + +func combineImports(a, b imports) imports { + var c imports + + c.standard = removeDuplicates(combineStringSlices(a.standard, b.standard)) + c.thirdParty = removeDuplicates(combineStringSlices(a.thirdParty, b.thirdParty)) + + sort.Sort(c.standard) + sort.Sort(c.thirdParty) + + return c +} + +func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports { + tmpImp := imports{ + standard: make(importList, len(a.standard)), + thirdParty: make(importList, len(a.thirdParty)), + } + + copy(tmpImp.standard, a.standard) + copy(tmpImp.thirdParty, a.thirdParty) + + for _, col := range columns { + for key, imp := range b { + if col.Type == key { + tmpImp.standard = append(tmpImp.standard, imp.standard...) + tmpImp.thirdParty = append(tmpImp.thirdParty, imp.thirdParty...) + } + } + } + + tmpImp.standard = removeDuplicates(tmpImp.standard) + tmpImp.thirdParty = removeDuplicates(tmpImp.thirdParty) + + sort.Sort(tmpImp.standard) + sort.Sort(tmpImp.thirdParty) + + return tmpImp +} + +func buildImportString(imps imports) []byte { + stdlen, thirdlen := len(imps.standard), len(imps.thirdParty) + if stdlen+thirdlen < 1 { + return []byte{} + } + + if stdlen+thirdlen == 1 { + var imp string + if stdlen == 1 { + imp = imps.standard[0] + } else { + imp = imps.thirdParty[0] + } + return []byte(fmt.Sprintf("import %s", imp)) + } + + buf := &bytes.Buffer{} + buf.WriteString("import (") + for _, std := range imps.standard { + fmt.Fprintf(buf, "\n\t%s", std) + } + if stdlen != 0 && thirdlen != 0 { + buf.WriteString("\n") + } + for _, third := range imps.thirdParty { + fmt.Fprintf(buf, "\n\t%s", third) + } + buf.WriteString("\n)\n") + + return buf.Bytes() +} + +func combineStringSlices(a, b []string) []string { + c := make([]string, len(a)+len(b)) + if len(a) > 0 { + copy(c, a) + } + if len(b) > 0 { + copy(c[len(a):], b) + } + + return c +} + +func removeDuplicates(dedup []string) []string { + if len(dedup) <= 1 { + return dedup + } + + for i := 0; i < len(dedup)-1; i++ { + for j := i + 1; j < len(dedup); j++ { + if dedup[i] != dedup[j] { + continue + } + + if j != len(dedup)-1 { + dedup[j] = dedup[len(dedup)-1] + j-- + } + dedup = dedup[:len(dedup)-1] + } + } + + return dedup +} + +var defaultTemplateImports = imports{ + standard: importList{ + `"errors"`, + `"fmt"`, + `"strings"`, + }, + thirdParty: importList{ + `"github.com/nullbio/sqlboiler/boil"`, + `"github.com/nullbio/sqlboiler/boil/qs"`, + }, +} + +var defaultSingletonTemplateImports = map[string]imports{ + "helpers": imports{ + standard: importList{}, + thirdParty: importList{ + `"github.com/nullbio/sqlboiler/boil"`, + `"github.com/nullbio/sqlboiler/boil/qs"`, + }, + }, +} + +var defaultTestTemplateImports = imports{ + standard: importList{ + `"testing"`, + `"reflect"`, + }, + thirdParty: importList{ + `"github.com/nullbio/sqlboiler/boil"`, + }, +} + +var defaultSingletonTestTemplateImports = map[string]imports{ + "helper_funcs": imports{ + standard: importList{ + `"crypto/md5"`, + `"fmt"`, + `"os"`, + `"strconv"`, + `"math/rand"`, + }, + thirdParty: importList{}, + }, +} + +var defaultTestMainImports = map[string]imports{ + "postgres": imports{ + standard: importList{ + `"testing"`, + `"os"`, + `"os/exec"`, + `"fmt"`, + `"io/ioutil"`, + `"bytes"`, + `"database/sql"`, + `"time"`, + `"math/rand"`, + }, + thirdParty: importList{ + `"github.com/nullbio/sqlboiler/boil"`, + `"github.com/BurntSushi/toml"`, + `_ "github.com/lib/pq"`, + }, + }, +} + +// importsBasedOnType imports are only included in the template output if the +// database requires one of the following special types. Check +// TranslateColumnType to see the type assignments. +var importsBasedOnType = map[string]imports{ + "null.Float32": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Float64": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Int": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Int8": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Int16": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Int32": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Int64": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Uint": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Uint8": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Uint16": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Uint32": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Uint64": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.String": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Bool": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "null.Time": imports{ + thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + }, + "time.Time": imports{ + standard: importList{`"time"`}, + }, +} diff --git a/cmds/helpers_test.go b/imports_test.go similarity index 70% rename from cmds/helpers_test.go rename to imports_test.go index 7b1f1b5..03dce51 100644 --- a/cmds/helpers_test.go +++ b/imports_test.go @@ -1,20 +1,66 @@ -package cmds +package sqlboiler import ( "fmt" "reflect" + "sort" "testing" "github.com/nullbio/sqlboiler/dbdrivers" ) +func TestImportsSort(t *testing.T) { + t.Parallel() + + a1 := importList{ + `"fmt"`, + `"errors"`, + } + a2 := importList{ + `_ "github.com/lib/pq"`, + `_ "github.com/gorilla/n"`, + `"github.com/gorilla/mux"`, + `"github.com/gorilla/websocket"`, + } + + a1Expected := importList{`"errors"`, `"fmt"`} + a2Expected := importList{ + `"github.com/gorilla/mux"`, + `_ "github.com/gorilla/n"`, + `"github.com/gorilla/websocket"`, + `_ "github.com/lib/pq"`, + } + + sort.Sort(a1) + if !reflect.DeepEqual(a1, a1Expected) { + t.Errorf("Expected a1 to match a1Expected, got: %v", a1) + } + + for i, v := range a1 { + if v != a1Expected[i] { + t.Errorf("Expected a1[%d] to match a1Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i]) + } + } + + sort.Sort(a2) + if !reflect.DeepEqual(a2, a2Expected) { + t.Errorf("Expected a2 to match a2expected, got: %v", a2) + } + + for i, v := range a2 { + if v != a2Expected[i] { + t.Errorf("Expected a2[%d] to match a2Expected[%d]:\n%s\n%s\n", i, i, v, a1Expected[i]) + } + } +} + func TestCombineTypeImports(t *testing.T) { imports1 := imports{ standard: importList{ `"errors"`, `"fmt"`, }, - thirdparty: importList{ + thirdParty: importList{ `"github.com/nullbio/sqlboiler/boil"`, }, } @@ -25,7 +71,7 @@ func TestCombineTypeImports(t *testing.T) { `"fmt"`, `"time"`, }, - thirdparty: importList{ + thirdParty: importList{ `"github.com/nullbio/sqlboiler/boil"`, `"gopkg.in/nullbio/null.v4"`, }, @@ -46,7 +92,7 @@ func TestCombineTypeImports(t *testing.T) { }, } - res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols) + res1 := combineTypeImports(imports1, importsBasedOnType, cols) if !reflect.DeepEqual(res1, importsExpected) { t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1) @@ -58,13 +104,13 @@ func TestCombineTypeImports(t *testing.T) { `"fmt"`, `"time"`, }, - thirdparty: importList{ + thirdParty: importList{ `"github.com/nullbio/sqlboiler/boil"`, `"gopkg.in/nullbio/null.v4"`, }, } - res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols) + res2 := combineTypeImports(imports2, importsBasedOnType, cols) if !reflect.DeepEqual(res2, importsExpected) { t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1) @@ -76,11 +122,11 @@ func TestCombineImports(t *testing.T) { a := imports{ standard: importList{"fmt"}, - thirdparty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"}, + thirdParty: importList{"github.com/nullbio/sqlboiler", "gopkg.in/nullbio/null.v4"}, } b := imports{ standard: importList{"os"}, - thirdparty: importList{"github.com/nullbio/sqlboiler"}, + thirdParty: importList{"github.com/nullbio/sqlboiler"}, } c := combineImports(a, b) @@ -88,8 +134,8 @@ func TestCombineImports(t *testing.T) { if c.standard[0] != "fmt" && c.standard[1] != "os" { t.Errorf("Wanted: fmt, os got: %#v", c.standard) } - if c.thirdparty[0] != "github.com/nullbio/sqlboiler" && c.thirdparty[1] != "gopkg.in/nullbio/null.v4" { - t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdparty) + if c.thirdParty[0] != "github.com/nullbio/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v4" { + t.Errorf("Wanted: github.com/nullbio/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdParty) } } diff --git a/main.go b/main.go deleted file mode 100644 index 3dbcde3..0000000 --- a/main.go +++ /dev/null @@ -1,54 +0,0 @@ -/* -SQLBoiler is a tool to generate Go boilerplate code for database interactions. -So far this includes struct definitions and database statement helper functions. -*/ - -package main - -import ( - "fmt" - "os" - - "github.com/nullbio/sqlboiler/cmds" - "github.com/spf13/cobra" -) - -func main() { - var err error - cmdData := &cmds.CmdData{} - - // Load the "config.toml" global config - err = cmdData.LoadConfigFile("config.toml") - if err != nil { - fmt.Printf("Failed to load config file: %s\n", err) - os.Exit(-1) - } - - // Set up the cobra root command - var rootCmd = &cobra.Command{ - Use: "sqlboiler", - Short: "SQL Boiler generates boilerplate structs and statements", - Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" + - `Complete documentation is available at http://github.com/nullbio/sqlboiler`, - PreRunE: func(cmd *cobra.Command, args []string) error { - return cmdData.SQLBoilerPreRun(cmd, args) - }, - RunE: func(cmd *cobra.Command, args []string) error { - return cmdData.SQLBoilerRun(cmd, args) - }, - PostRunE: func(cmd *cobra.Command, args []string) error { - return cmdData.SQLBoilerPostRun(cmd, args) - }, - } - - // Set up the cobra root command flags - rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)") - rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names") - rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder") - rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package") - - // Execute SQLBoiler - if err := rootCmd.Execute(); err != nil { - os.Exit(-1) - } -} diff --git a/cmds/output.go b/output.go similarity index 59% rename from cmds/output.go rename to output.go index ac530d8..c58056d 100644 --- a/cmds/output.go +++ b/output.go @@ -1,4 +1,4 @@ -package cmds +package sqlboiler import ( "bytes" @@ -18,18 +18,18 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) { } // generateOutput builds the file output and sends it to outHandler for saving -func generateOutput(cmdData *CmdData, data *tplData) error { - if len(cmdData.Templates) == 0 { +func generateOutput(state *State, data *templateData) error { + if len(state.Templates) == 0 { return errors.New("No template files located for generation") } var out [][]byte var imps imports - imps.standard = sqlBoilerImports.standard - imps.thirdparty = sqlBoilerImports.thirdparty + imps.standard = defaultTemplateImports.standard + imps.thirdParty = defaultTemplateImports.thirdParty - for _, template := range cmdData.Templates { - imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns) + for _, template := range state.Templates { + imps = combineTypeImports(imps, importsBasedOnType, data.Table.Columns) resp, err := generateTemplate(template, data) if err != nil { return fmt.Errorf("Error generating template %s: %s", template.Name(), err) @@ -38,7 +38,7 @@ func generateOutput(cmdData *CmdData, data *tplData) error { } fName := data.Table.Name + ".go" - err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out) + err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out) if err != nil { return err } @@ -47,18 +47,18 @@ func generateOutput(cmdData *CmdData, data *tplData) error { } // generateTestOutput builds the test file output and sends it to outHandler for saving -func generateTestOutput(cmdData *CmdData, data *tplData) error { - if len(cmdData.TestTemplates) == 0 { +func generateTestOutput(state *State, data *templateData) error { + if len(state.TestTemplates) == 0 { return errors.New("No template files located for generation") } var out [][]byte var imps imports - imps.standard = sqlBoilerTestImports.standard - imps.thirdparty = sqlBoilerTestImports.thirdparty + imps.standard = defaultTestTemplateImports.standard + imps.thirdParty = defaultTestTemplateImports.thirdParty - for _, template := range cmdData.TestTemplates { + for _, template := range state.TestTemplates { resp, err := generateTemplate(template, data) if err != nil { return fmt.Errorf("Error generating test template %s: %s", template.Name(), err) @@ -67,7 +67,7 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error { } fName := data.Table.Name + "_test.go" - err := outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, out) + err := outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, out) if err != nil { return err } @@ -77,20 +77,20 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error { // generateSingletonOutput processes the templates that should only be run // one time. -func generateSingletonOutput(cmdData *CmdData) error { - if cmdData.SingleTemplates == nil { +func generateSingletonOutput(state *State) error { + if state.SingletonTemplates == nil { return errors.New("No singleton templates located for generation") } - tplData := &tplData{ - PkgName: cmdData.PkgName, - DriverName: cmdData.DriverName, + templateData := &templateData{ + PkgName: state.Config.PkgName, + DriverName: state.Config.DriverName, } - for _, template := range cmdData.SingleTemplates { + for _, template := range state.SingletonTemplates { var imps imports - resp, err := generateTemplate(template, tplData) + resp, err := generateTemplate(template, templateData) if err != nil { return fmt.Errorf("Error generating template %s: %s", template.Name(), err) } @@ -99,12 +99,12 @@ func generateSingletonOutput(cmdData *CmdData) error { ext := filepath.Ext(fName) fName = fName[0 : len(fName)-len(ext)] - imps.standard = sqlBoilerSinglesImports[fName].standard - imps.thirdparty = sqlBoilerSinglesImports[fName].thirdparty + imps.standard = defaultSingletonTemplateImports[fName].standard + imps.thirdParty = defaultSingletonTemplateImports[fName].thirdParty fName = fName + ".go" - err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp}) + err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp}) if err != nil { return err } @@ -115,20 +115,20 @@ func generateSingletonOutput(cmdData *CmdData) error { // generateSingletonTestOutput processes the templates that should only be run // one time. -func generateSingletonTestOutput(cmdData *CmdData) error { - if cmdData.SingleTestTemplates == nil { +func generateSingletonTestOutput(state *State) error { + if state.SingletonTestTemplates == nil { return errors.New("No singleton test templates located for generation") } - tplData := &tplData{ - PkgName: cmdData.PkgName, - DriverName: cmdData.DriverName, + templateData := &templateData{ + PkgName: state.Config.PkgName, + DriverName: state.Config.DriverName, } - for _, template := range cmdData.SingleTestTemplates { + for _, template := range state.SingletonTestTemplates { var imps imports - resp, err := generateTemplate(template, tplData) + resp, err := generateTemplate(template, templateData) if err != nil { return fmt.Errorf("Error generating test template %s: %s", template.Name(), err) } @@ -137,12 +137,12 @@ func generateSingletonTestOutput(cmdData *CmdData) error { ext := filepath.Ext(fName) fName = fName[0 : len(fName)-len(ext)] - imps.standard = sqlBoilerSinglesTestImports[fName].standard - imps.thirdparty = sqlBoilerSinglesTestImports[fName].thirdparty + imps.standard = defaultSingletonTestTemplateImports[fName].standard + imps.thirdParty = defaultSingletonTestTemplateImports[fName].thirdParty fName = fName + "_test.go" - err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp}) + err = outHandler(state.Config.OutFolder, fName, state.Config.PkgName, imps, [][]byte{resp}) if err != nil { return err } @@ -151,35 +151,30 @@ func generateSingletonTestOutput(cmdData *CmdData) error { return nil } -func generateTestMainOutput(cmdData *CmdData) error { - if cmdData.TestMainTemplate == nil { +func generateTestMainOutput(state *State) error { + if state.TestMainTemplate == nil { return errors.New("No TestMain template located for generation") } var out [][]byte var imps imports - imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard - imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty + imps.standard = defaultTestMainImports[state.Config.DriverName].standard + imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty - var tables []string - for _, v := range cmdData.Tables { - tables = append(tables, v.Name) + templateData := &templateData{ + Tables: state.Tables, + PkgName: state.Config.PkgName, + DriverName: state.Config.DriverName, } - tplData := &tplData{ - PkgName: cmdData.PkgName, - DriverName: cmdData.DriverName, - Tables: tables, - } - - resp, err := generateTemplate(cmdData.TestMainTemplate, tplData) + resp, err := generateTemplate(state.TestMainTemplate, templateData) if err != nil { return err } out = append(out, resp) - err = outHandler(cmdData.OutFolder, "main_test.go", cmdData.PkgName, imps, out) + err = outHandler(state.Config.OutFolder, "main_test.go", state.Config.PkgName, imps, out) if err != nil { return err } @@ -220,7 +215,7 @@ func outHandler(outFolder string, fileName string, pkgName string, imps imports, } // generateTemplate takes a template and returns the output of the template execution. -func generateTemplate(t *template.Template, data *tplData) ([]byte, error) { +func generateTemplate(t *template.Template, data *templateData) ([]byte, error) { var buf bytes.Buffer if err := t.Execute(&buf, data); err != nil { return nil, err diff --git a/cmds/output_test.go b/output_test.go similarity index 96% rename from cmds/output_test.go rename to output_test.go index e946d74..969f01f 100644 --- a/cmds/output_test.go +++ b/output_test.go @@ -1,4 +1,4 @@ -package cmds +package sqlboiler import ( "bytes" @@ -78,7 +78,7 @@ func TestOutHandlerFiles(t *testing.T) { } a2 := imports{ - thirdparty: []string{ + thirdParty: []string{ `"github.com/spf13/cobra"`, }, } @@ -96,7 +96,7 @@ func TestOutHandlerFiles(t *testing.T) { `"fmt"`, `"errors"`, }, - thirdparty: importList{ + thirdParty: importList{ `_ "github.com/lib/pq"`, `_ "github.com/gorilla/n"`, `"github.com/gorilla/mux"`, @@ -106,7 +106,7 @@ func TestOutHandlerFiles(t *testing.T) { file = &bytes.Buffer{} sort.Sort(a3.standard) - sort.Sort(a3.thirdparty) + sort.Sort(a3.thirdParty) if err := outHandler("folder", "file.go", "patrick", a3, templateOutputs); err != nil { t.Error(err) diff --git a/sqlboiler.go b/sqlboiler.go new file mode 100644 index 0000000..09d88c4 --- /dev/null +++ b/sqlboiler.go @@ -0,0 +1,207 @@ +// Package sqlboiler has types and methods useful for generating code that +// acts as a fully dynamic ORM might. +package sqlboiler + +import ( + "errors" + "fmt" + "os" + "strings" + + "github.com/nullbio/sqlboiler/dbdrivers" +) + +const ( + templatesDirectory = "cmds/templates" + templatesSingletonDirectory = "cmds/templates/singleton" + + templatesTestDirectory = "cmds/templates_test" + templatesSingletonTestDirectory = "cmds/templates_test/singleton" +) + +// New creates a new state based off of the config +func New(config *Config) (*State, error) { + s := &State{} + + err := s.initDriver(config.DriverName) + if err != nil { + return nil, err + } + + // Connect to the driver database + if err = s.Driver.Open(); err != nil { + return nil, fmt.Errorf("Unable to connect to the database: %s", err) + } + + err = s.initTables(config.TableName) + if err != nil { + return nil, fmt.Errorf("Unable to initialize tables: %s", err) + } + + err = s.initOutFolder() + if err != nil { + return nil, fmt.Errorf("Unable to initialize the output folder: %s", err) + } + + err = s.initTemplates() + if err != nil { + return nil, fmt.Errorf("Unable to initialize templates: %s", err) + } + + return s, nil +} + +// Run executes the sqlboiler templates and outputs them to files based on the +// state given. +func (s *State) Run(includeTests bool) error { + if err := generateSingletonOutput(s); err != nil { + return fmt.Errorf("Unable to generate singleton template output: %s", err) + } + + if includeTests { + if err := generateTestMainOutput(s); err != nil { + return fmt.Errorf("Unable to generate TestMain output: %s", err) + } + + if err := generateSingletonTestOutput(s); err != nil { + return fmt.Errorf("Unable to generate singleton test template output: %s", err) + } + } + + for _, table := range s.Tables { + if table.IsJoinTable { + continue + } + + data := &templateData{ + Table: table, + DriverName: s.Config.DriverName, + PkgName: s.Config.PkgName, + } + + // Generate the regular templates + if err := generateOutput(s, data); err != nil { + return fmt.Errorf("Unable to generate output: %s", err) + } + + // Generate the test templates + if includeTests { + if err := generateTestOutput(s, data); err != nil { + return fmt.Errorf("Unable to generate test output: %s", err) + } + } + } + + return nil +} + +// Cleanup closes any resources that must be closed +func (s *State) Cleanup() error { + s.Driver.Close() + return nil +} + +// initTemplates loads all template folders into the state object. +func (s *State) initTemplates() error { + var err error + + s.Templates, err = loadTemplates(templatesDirectory) + if err != nil { + return err + } + + s.SingletonTemplates, err = loadTemplates(templatesSingletonDirectory) + if err != nil { + return err + } + + s.TestTemplates, err = loadTemplates(templatesTestDirectory) + if err != nil { + return err + } + + s.SingletonTestTemplates, err = loadTemplates(templatesSingletonTestDirectory) + if err != nil { + return err + } + + return nil +} + +// initDriver attempts to set the state Interface based off the passed in +// driver flag value. If an invalid flag string is provided an error is returned. +func (s *State) initDriver(driverName string) error { + // Create a driver based off driver flag + switch driverName { + case "postgres": + s.Driver = dbdrivers.NewPostgresDriver( + s.Config.Postgres.User, + s.Config.Postgres.Pass, + s.Config.Postgres.DBName, + s.Config.Postgres.Host, + s.Config.Postgres.Port, + ) + } + + if s.Driver == nil { + return errors.New("An invalid driver name was provided") + } + + return nil +} + +// initTables will create a string slice out of the passed in table flag value +// if one is provided. If no flag is provided, it will attempt to connect to the +// database to retrieve all "public" table names, and build a slice out of that +// result. +func (s *State) initTables(tableName string) error { + var tableNames []string + + if len(tableName) != 0 { + tableNames = strings.Split(tableName, ",") + for i, name := range tableNames { + tableNames[i] = strings.TrimSpace(name) + } + } + + var err error + s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...) + if err != nil { + return fmt.Errorf("Unable to get all table names: %s", err) + } + + if len(s.Tables) == 0 { + return errors.New("No tables found in database, migrate some tables first") + } + + if err := checkPKeys(s.Tables); err != nil { + return err + } + + return nil +} + +// initOutFolder creates the folder that will hold the generated output. +func (s *State) initOutFolder() error { + if err := os.MkdirAll(s.Config.OutFolder, os.ModePerm); err != nil { + return fmt.Errorf("Unable to make output folder: %s", err) + } + + return nil +} + +// checkPKeys ensures every table has a primary key column +func checkPKeys(tables []dbdrivers.Table) error { + var missingPkey []string + for _, t := range tables { + if t.PKey == nil { + missingPkey = append(missingPkey, t.Name) + } + } + + if len(missingPkey) != 0 { + return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", ")) + } + + return nil +} diff --git a/cmds/sqlboiler_test.go b/sqlboiler_test.go similarity index 82% rename from cmds/sqlboiler_test.go rename to sqlboiler_test.go index 9482498..0d29ae4 100644 --- a/cmds/sqlboiler_test.go +++ b/sqlboiler_test.go @@ -1,4 +1,4 @@ -package cmds +package sqlboiler import ( "bufio" @@ -15,11 +15,11 @@ import ( "github.com/nullbio/sqlboiler/dbdrivers" ) -var cmdData *CmdData +var state *State var rgxHasSpaces = regexp.MustCompile(`^\s+`) func init() { - cmdData = &CmdData{ + state = &State{ Tables: []dbdrivers.Table{ { Name: "patrick_table", @@ -59,10 +59,11 @@ func init() { }, }, }, - PkgName: "patrick", - OutFolder: "", - DriverName: "postgres", - Interface: nil, + Config: &Config{ + PkgName: "patrick", + OutFolder: "", + DriverName: "postgres", + }, } } @@ -84,63 +85,63 @@ func TestTemplates(t *testing.T) { t.SkipNow() } - if err := checkPKeys(cmdData.Tables); err != nil { + if err := checkPKeys(state.Tables); err != nil { t.Fatalf("%s", err) } // Initialize the templates var err error - cmdData.Templates, err = loadTemplates("templates") + state.Templates, err = loadTemplates("templates") if err != nil { t.Fatalf("Unable to initialize templates: %s", err) } - if len(cmdData.Templates) == 0 { + if len(state.Templates) == 0 { t.Errorf("Templates is empty.") } - cmdData.SingleTemplates, err = loadTemplates("templates/singleton") + state.SingletonTemplates, err = loadTemplates("templates/singleton") if err != nil { t.Fatalf("Unable to initialize singleton templates: %s", err) } - if len(cmdData.SingleTemplates) == 0 { - t.Errorf("SingleTemplates is empty.") + if len(state.SingletonTemplates) == 0 { + t.Errorf("SingletonTemplates is empty.") } - cmdData.TestTemplates, err = loadTemplates("templates_test") + state.TestTemplates, err = loadTemplates("templates_test") if err != nil { t.Fatalf("Unable to initialize templates: %s", err) } - if len(cmdData.Templates) == 0 { + if len(state.Templates) == 0 { t.Errorf("Templates is empty.") } - cmdData.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl") + state.TestMainTemplate, err = loadTemplate("templates_test/main_test", "postgres_main.tpl") if err != nil { t.Fatalf("Unable to initialize templates: %s", err) } - cmdData.OutFolder, err = ioutil.TempDir("", "templates") + state.Config.OutFolder, err = ioutil.TempDir("", "templates") if err != nil { t.Fatalf("Unable to create tempdir: %s", err) } - defer os.RemoveAll(cmdData.OutFolder) + defer os.RemoveAll(state.Config.OutFolder) - if err = cmdData.run(true); err != nil { + if err = state.Run(true); err != nil { t.Errorf("Unable to run SQLBoilerRun: %s", err) } buf := &bytes.Buffer{} cmd := exec.Command("go", "test", "-c") - cmd.Dir = cmdData.OutFolder + cmd.Dir = state.Config.OutFolder cmd.Stderr = buf if err = cmd.Run(); err != nil { t.Errorf("go test cmd execution failed: %s", err) - outputCompileErrors(buf, cmdData.OutFolder) + outputCompileErrors(buf, state.Config.OutFolder) fmt.Println() } } diff --git a/templates.go b/templates.go new file mode 100644 index 0000000..2624c6e --- /dev/null +++ b/templates.go @@ -0,0 +1,119 @@ +package sqlboiler + +import ( + "os" + "path/filepath" + "sort" + "strings" + "text/template" + + "github.com/nullbio/sqlboiler/dbdrivers" + "github.com/nullbio/sqlboiler/strmangle" +) + +// templateData for sqlboiler templates +type templateData struct { + Tables []dbdrivers.Table + Table dbdrivers.Table + DriverName string + PkgName string +} + +type templateList []*template.Template + +func (t templateList) Len() int { + return len(t) +} + +func (t templateList) Swap(k, j int) { + t[k], t[j] = t[j], t[k] +} + +func (t templateList) Less(k, j int) bool { + // Make sure "struct" goes to the front + if t[k].Name() == "struct.tpl" { + return true + } + + res := strings.Compare(t[k].Name(), t[j].Name()) + if res <= 0 { + return true + } + + return false +} + +// loadTemplates loads all of the template files in the specified directory. +func loadTemplates(dir string) (templateList, error) { + wd, err := os.Getwd() + if err != nil { + return nil, err + } + + pattern := filepath.Join(wd, dir, "*.tpl") + tpl, err := template.New("").Funcs(templateFunctions).ParseGlob(pattern) + + if err != nil { + return nil, err + } + + templates := templateList(tpl.Templates()) + sort.Sort(templates) + + return templates, err +} + +// loadTemplate loads a single template file. +func loadTemplate(dir string, filename string) (*template.Template, error) { + wd, err := os.Getwd() + if err != nil { + return nil, err + } + + pattern := filepath.Join(wd, dir, filename) + tpl, err := template.New("").Funcs(templateFunctions).ParseFiles(pattern) + + if err != nil { + return nil, err + } + + return tpl.Lookup(filename), err +} + +// templateFunctions is a map of all the functions that get passed into the +// templates. If you wish to pass a new function into your own template, +// add a function pointer here. +var templateFunctions = template.FuncMap{ + "singular": strmangle.Singular, + "plural": strmangle.Plural, + "titleCase": strmangle.TitleCase, + "titleCaseSingular": strmangle.TitleCaseSingular, + "titleCasePlural": strmangle.TitleCasePlural, + "titleCaseCommaList": strmangle.TitleCaseCommaList, + "camelCase": strmangle.CamelCase, + "camelCaseSingular": strmangle.CamelCaseSingular, + "camelCasePlural": strmangle.CamelCasePlural, + "camelCaseCommaList": strmangle.CamelCaseCommaList, + "columnsToStrings": strmangle.ColumnsToStrings, + "commaList": strmangle.CommaList, + "makeDBName": strmangle.MakeDBName, + "selectParamNames": strmangle.SelectParamNames, + "insertParamNames": strmangle.InsertParamNames, + "insertParamFlags": strmangle.InsertParamFlags, + "insertParamVariables": strmangle.InsertParamVariables, + "scanParamNames": strmangle.ScanParamNames, + "hasPrimaryKey": strmangle.HasPrimaryKey, + "primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig, + "wherePrimaryKey": strmangle.WherePrimaryKey, + "paramsPrimaryKey": strmangle.ParamsPrimaryKey, + "primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex, + "updateParamNames": strmangle.UpdateParamNames, + "updateParamVariables": strmangle.UpdateParamVariables, + "supportsResultObject": strmangle.SupportsResultObject, + "filterColumnsByDefault": strmangle.FilterColumnsByDefault, + "filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement, + "autoIncPrimaryKey": strmangle.AutoIncPrimaryKey, + + "randDBStruct": strmangle.RandDBStruct, + "randDBStructSlice": strmangle.RandDBStructSlice, +} diff --git a/cmds/templates/all.tpl b/templates/all.tpl similarity index 100% rename from cmds/templates/all.tpl rename to templates/all.tpl diff --git a/cmds/templates/delete.tpl b/templates/delete.tpl similarity index 100% rename from cmds/templates/delete.tpl rename to templates/delete.tpl diff --git a/cmds/templates/find.tpl b/templates/find.tpl similarity index 100% rename from cmds/templates/find.tpl rename to templates/find.tpl diff --git a/cmds/templates/finishers.tpl b/templates/finishers.tpl similarity index 100% rename from cmds/templates/finishers.tpl rename to templates/finishers.tpl diff --git a/cmds/templates/helpers.tpl b/templates/helpers.tpl similarity index 100% rename from cmds/templates/helpers.tpl rename to templates/helpers.tpl diff --git a/cmds/templates/hooks.tpl b/templates/hooks.tpl similarity index 100% rename from cmds/templates/hooks.tpl rename to templates/hooks.tpl diff --git a/cmds/templates/insert.tpl b/templates/insert.tpl similarity index 100% rename from cmds/templates/insert.tpl rename to templates/insert.tpl diff --git a/cmds/templates/singleton/helpers.tpl b/templates/singleton/helpers.tpl similarity index 100% rename from cmds/templates/singleton/helpers.tpl rename to templates/singleton/helpers.tpl diff --git a/cmds/templates/struct.tpl b/templates/struct.tpl similarity index 100% rename from cmds/templates/struct.tpl rename to templates/struct.tpl diff --git a/cmds/templates/update.tpl b/templates/update.tpl similarity index 100% rename from cmds/templates/update.tpl rename to templates/update.tpl diff --git a/templates_test.go b/templates_test.go new file mode 100644 index 0000000..2806f29 --- /dev/null +++ b/templates_test.go @@ -0,0 +1,34 @@ +package sqlboiler + +import ( + "sort" + "testing" + "text/template" +) + +func TestTemplateListSort(t *testing.T) { + templs := templateList{ + template.New("bob.tpl"), + template.New("all.tpl"), + template.New("struct.tpl"), + template.New("ttt.tpl"), + } + + expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"} + + for i, v := range templs { + if v.Name() != expected[i] { + t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name()) + } + } + + expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"} + + sort.Sort(templs) + + for i, v := range templs { + if v.Name() != expected[i] { + t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v.Name()) + } + } +} diff --git a/cmds/templates_test/all.tpl b/templates_test/all.tpl similarity index 100% rename from cmds/templates_test/all.tpl rename to templates_test/all.tpl diff --git a/cmds/templates_test/delete.tpl b/templates_test/delete.tpl similarity index 100% rename from cmds/templates_test/delete.tpl rename to templates_test/delete.tpl diff --git a/cmds/templates_test/find.tpl b/templates_test/find.tpl similarity index 100% rename from cmds/templates_test/find.tpl rename to templates_test/find.tpl diff --git a/cmds/templates_test/finishers.tpl b/templates_test/finishers.tpl similarity index 100% rename from cmds/templates_test/finishers.tpl rename to templates_test/finishers.tpl diff --git a/cmds/templates_test/helpers.tpl b/templates_test/helpers.tpl similarity index 100% rename from cmds/templates_test/helpers.tpl rename to templates_test/helpers.tpl diff --git a/cmds/templates_test/hooks.tpl b/templates_test/hooks.tpl similarity index 100% rename from cmds/templates_test/hooks.tpl rename to templates_test/hooks.tpl diff --git a/cmds/templates_test/insert.tpl b/templates_test/insert.tpl similarity index 100% rename from cmds/templates_test/insert.tpl rename to templates_test/insert.tpl diff --git a/cmds/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl similarity index 100% rename from cmds/templates_test/main_test/postgres_main.tpl rename to templates_test/main_test/postgres_main.tpl diff --git a/cmds/templates_test/select.tpl b/templates_test/select.tpl similarity index 100% rename from cmds/templates_test/select.tpl rename to templates_test/select.tpl diff --git a/cmds/templates_test/singleton/helper_funcs.tpl b/templates_test/singleton/helper_funcs.tpl similarity index 100% rename from cmds/templates_test/singleton/helper_funcs.tpl rename to templates_test/singleton/helper_funcs.tpl diff --git a/cmds/templates_test/struct.tpl b/templates_test/struct.tpl similarity index 100% rename from cmds/templates_test/struct.tpl rename to templates_test/struct.tpl diff --git a/cmds/templates_test/update.tpl b/templates_test/update.tpl similarity index 100% rename from cmds/templates_test/update.tpl rename to templates_test/update.tpl