Finished most of the stripping
This commit is contained in:
parent
c74ed4e75f
commit
27cafdd2fb
15 changed files with 401 additions and 723 deletions
101
cmds/boil.go
101
cmds/boil.go
|
@ -1,101 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
var boilCmd = &cobra.Command{
|
|
||||||
Use: "boil",
|
|
||||||
Short: "Generates ALL templates by running every command alphabetically",
|
|
||||||
}
|
|
||||||
|
|
||||||
// boilRun executes every sqlboiler command, starting with structs.
|
|
||||||
func boilRun(cmd *cobra.Command, args []string) {
|
|
||||||
commandNames := buildCommandList()
|
|
||||||
|
|
||||||
// Prepend "struct" command to templateNames slice so it sits at top of sort
|
|
||||||
commandNames = append([]string{"struct"}, commandNames...)
|
|
||||||
|
|
||||||
// Create a testCommandNames with "driverName_main" prepended to the front for the test templates
|
|
||||||
// the main template initializes all of the testing assets
|
|
||||||
testCommandNames := append([]string{cmdData.DriverName + "_main"}, commandNames...)
|
|
||||||
|
|
||||||
for _, table := range cmdData.Tables {
|
|
||||||
data := &tplData{
|
|
||||||
Table: table,
|
|
||||||
PkgName: cmdData.PkgName,
|
|
||||||
}
|
|
||||||
|
|
||||||
var out [][]byte
|
|
||||||
var imps imports
|
|
||||||
|
|
||||||
imps.standard = sqlBoilerDefaultImports.standard
|
|
||||||
imps.thirdparty = sqlBoilerDefaultImports.thirdparty
|
|
||||||
|
|
||||||
// Loop through and generate every command template (excluding skipTemplates)
|
|
||||||
for _, command := range commandNames {
|
|
||||||
imps = combineImports(imps, sqlBoilerCustomImports[command])
|
|
||||||
imps = combineConditionalTypeImports(imps, sqlBoilerConditionalTypeImports, data.Table.Columns)
|
|
||||||
out = append(out, generateTemplate(command, data))
|
|
||||||
}
|
|
||||||
|
|
||||||
err := outHandler(cmdData.OutFolder, out, data, imps, false)
|
|
||||||
if err != nil {
|
|
||||||
errorQuit(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the test templates for all commands
|
|
||||||
if len(testTemplates) != 0 {
|
|
||||||
var testOut [][]byte
|
|
||||||
var testImps imports
|
|
||||||
|
|
||||||
testImps.standard = sqlBoilerDefaultTestImports.standard
|
|
||||||
testImps.thirdparty = sqlBoilerDefaultTestImports.thirdparty
|
|
||||||
|
|
||||||
testImps = combineImports(testImps, sqlBoilerConditionalDriverTestImports[cmdData.DriverName])
|
|
||||||
|
|
||||||
// Loop through and generate every command test template (excluding skipTemplates)
|
|
||||||
for _, command := range testCommandNames {
|
|
||||||
testImps = combineImports(testImps, sqlBoilerCustomTestImports[command])
|
|
||||||
testOut = append(testOut, generateTestTemplate(command, data))
|
|
||||||
}
|
|
||||||
|
|
||||||
err = outHandler(cmdData.OutFolder, testOut, data, testImps, true)
|
|
||||||
if err != nil {
|
|
||||||
errorQuit(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildCommandList() []string {
|
|
||||||
// Exclude these commands from the output
|
|
||||||
skipCommands := []string{
|
|
||||||
"boil",
|
|
||||||
"struct",
|
|
||||||
}
|
|
||||||
|
|
||||||
var commandNames []string
|
|
||||||
|
|
||||||
// Build a list of template names
|
|
||||||
for _, c := range sqlBoilerCommands {
|
|
||||||
skip := false
|
|
||||||
for _, s := range skipCommands {
|
|
||||||
// Skip name if it's in the exclude list.
|
|
||||||
if s == c.Name() {
|
|
||||||
skip = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !skip {
|
|
||||||
commandNames = append(commandNames, c.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort all names alphabetically
|
|
||||||
sort.Strings(commandNames)
|
|
||||||
return commandNames
|
|
||||||
}
|
|
|
@ -1,38 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestBuildCommandList(t *testing.T) {
|
|
||||||
list := buildCommandList()
|
|
||||||
|
|
||||||
skips := []string{"struct", "boil"}
|
|
||||||
|
|
||||||
for _, item := range list {
|
|
||||||
for _, skipItem := range skips {
|
|
||||||
if item == skipItem {
|
|
||||||
t.Errorf("Did not expect to find: %s %#v", item, list)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandNameLoop:
|
|
||||||
for cmdName := range sqlBoilerCommands {
|
|
||||||
for _, skipItem := range skips {
|
|
||||||
if cmdName == skipItem {
|
|
||||||
continue CommandNameLoop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
found := false
|
|
||||||
for _, item := range list {
|
|
||||||
if item == cmdName {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
t.Error("Expected to find command name:", cmdName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
194
cmds/commands.go
194
cmds/commands.go
|
@ -1,194 +0,0 @@
|
||||||
package cmds
|
|
||||||
|
|
||||||
import (
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
type importList []string
|
|
||||||
|
|
||||||
// imports defines the optional standard imports and
|
|
||||||
// thirdparty imports (from github for example)
|
|
||||||
type imports struct {
|
|
||||||
standard importList
|
|
||||||
thirdparty importList
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerDefaultImports defines the list of default template imports.
|
|
||||||
var sqlBoilerDefaultImports = imports{
|
|
||||||
standard: importList{
|
|
||||||
`"errors"`,
|
|
||||||
`"fmt"`,
|
|
||||||
},
|
|
||||||
thirdparty: importList{
|
|
||||||
`"github.com/pobri19/sqlboiler/boil"`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerDefaultTestImports defines the list of default test template imports.
|
|
||||||
var sqlBoilerDefaultTestImports = imports{
|
|
||||||
standard: importList{
|
|
||||||
`"testing"`,
|
|
||||||
`"os"`,
|
|
||||||
`"os/exec"`,
|
|
||||||
`"fmt"`,
|
|
||||||
`"io/ioutil"`,
|
|
||||||
`"bytes"`,
|
|
||||||
`"errors"`,
|
|
||||||
},
|
|
||||||
thirdparty: importList{
|
|
||||||
`"github.com/BurntSushi/toml"`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerConditionalTypeImports imports are only included in the template output
|
|
||||||
// if the database requires one of the following special types. Check TranslateColumnType
|
|
||||||
// to see the type assignments.
|
|
||||||
var sqlBoilerConditionalTypeImports = map[string]imports{
|
|
||||||
"null.Int": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
|
||||||
},
|
|
||||||
"null.String": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
|
||||||
},
|
|
||||||
"null.Bool": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
|
||||||
},
|
|
||||||
"null.Float": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
|
||||||
},
|
|
||||||
"null.Time": imports{
|
|
||||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
|
||||||
},
|
|
||||||
"time.Time": imports{
|
|
||||||
standard: importList{`"time"`},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerConditionalDriverTestImports defines the test template imports
|
|
||||||
// for the particular database interfaces
|
|
||||||
var sqlBoilerConditionalDriverTestImports = map[string]imports{
|
|
||||||
"postgres": imports{
|
|
||||||
standard: importList{`"database/sql"`},
|
|
||||||
thirdparty: importList{`_ "github.com/lib/pq"`},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var sqlBoilerCustomImports map[string]imports
|
|
||||||
var sqlBoilerCustomTestImports map[string]imports
|
|
||||||
|
|
||||||
// sqlBoilerCommands points each command to its cobra.Command variable.
|
|
||||||
//
|
|
||||||
// If you would like to add your own custom command, add it to this
|
|
||||||
// map, and point it to your <commandName>Cmd variable.
|
|
||||||
//
|
|
||||||
// Command names should match the template file name (without the file extension).
|
|
||||||
var sqlBoilerCommands = map[string]*cobra.Command{
|
|
||||||
// Command to generate all commands
|
|
||||||
"boil": boilCmd,
|
|
||||||
// Struct commands
|
|
||||||
"struct": structCmd,
|
|
||||||
// Insert commands
|
|
||||||
"insert": insertCmd,
|
|
||||||
// Select commands
|
|
||||||
"all": allCmd,
|
|
||||||
"where": whereCmd,
|
|
||||||
"select": selectCmd,
|
|
||||||
"selectwhere": selectWhereCmd,
|
|
||||||
"find": findCmd,
|
|
||||||
"findselect": findSelectCmd,
|
|
||||||
// Delete commands
|
|
||||||
"delete": deleteCmd,
|
|
||||||
// Update commands
|
|
||||||
"update": updateCmd,
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerCommandRuns points each command to its custom run function.
|
|
||||||
// If a run function is not defined here, it will use the defaultRun() default run function.
|
|
||||||
var sqlBoilerCommandRuns = map[string]CobraRunFunc{
|
|
||||||
"boil": boilRun,
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
|
|
||||||
// If you wish to pass a new function into your own template, add a pointer to it here.
|
|
||||||
var sqlBoilerTemplateFuncs = template.FuncMap{
|
|
||||||
"singular": singular,
|
|
||||||
"plural": plural,
|
|
||||||
"titleCase": titleCase,
|
|
||||||
"titleCaseSingular": titleCaseSingular,
|
|
||||||
"titleCasePlural": titleCasePlural,
|
|
||||||
"camelCase": camelCase,
|
|
||||||
"camelCaseSingular": camelCaseSingular,
|
|
||||||
"camelCasePlural": camelCasePlural,
|
|
||||||
"makeDBName": makeDBName,
|
|
||||||
"selectParamNames": selectParamNames,
|
|
||||||
"insertParamNames": insertParamNames,
|
|
||||||
"insertParamFlags": insertParamFlags,
|
|
||||||
"insertParamVariables": insertParamVariables,
|
|
||||||
"scanParamNames": scanParamNames,
|
|
||||||
"hasPrimaryKey": hasPrimaryKey,
|
|
||||||
"getPrimaryKey": getPrimaryKey,
|
|
||||||
"updateParamNames": updateParamNames,
|
|
||||||
"updateParamVariables": updateParamVariables,
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Struct commands */
|
|
||||||
|
|
||||||
var structCmd = &cobra.Command{
|
|
||||||
Use: "struct",
|
|
||||||
Short: "Generate structs from table definitions",
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Insert commands */
|
|
||||||
|
|
||||||
var insertCmd = &cobra.Command{
|
|
||||||
Use: "insert",
|
|
||||||
Short: "Generate insert statement helpers from table definitions",
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Select commands */
|
|
||||||
|
|
||||||
var allCmd = &cobra.Command{
|
|
||||||
Use: "all",
|
|
||||||
Short: "Generate a helper to select all records",
|
|
||||||
}
|
|
||||||
|
|
||||||
var whereCmd = &cobra.Command{
|
|
||||||
Use: "where",
|
|
||||||
Short: "Generate a helper to select all records with specific column values",
|
|
||||||
}
|
|
||||||
|
|
||||||
var selectCmd = &cobra.Command{
|
|
||||||
Use: "select",
|
|
||||||
Short: "Generate a helper to select specific fields of all records",
|
|
||||||
}
|
|
||||||
|
|
||||||
var selectWhereCmd = &cobra.Command{
|
|
||||||
Use: "selectwhere",
|
|
||||||
Short: "Generate a helper to select specific fields of records with specific column values",
|
|
||||||
}
|
|
||||||
|
|
||||||
var findCmd = &cobra.Command{
|
|
||||||
Use: "find",
|
|
||||||
Short: "Generate a helper to select a single record by ID",
|
|
||||||
}
|
|
||||||
|
|
||||||
var findSelectCmd = &cobra.Command{
|
|
||||||
Use: "findselect",
|
|
||||||
Short: "Generate a helper to select specific fields of a record by ID",
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Delete commands */
|
|
||||||
|
|
||||||
var deleteCmd = &cobra.Command{
|
|
||||||
Use: "delete",
|
|
||||||
Short: "Generate delete statement helpers from table definitions",
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Update commands */
|
|
||||||
|
|
||||||
var updateCmd = &cobra.Command{
|
|
||||||
Use: "update",
|
|
||||||
Short: "Generate update statement helpers from table definitions",
|
|
||||||
}
|
|
112
cmds/config.go
112
cmds/config.go
|
@ -3,49 +3,107 @@ package cmds
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PostgresCfg struct {
|
// sqlBoilerDefaultImports defines the list of default template imports.
|
||||||
User string `toml:"user"`
|
var sqlBoilerImports = imports{
|
||||||
Pass string `toml:"pass"`
|
standard: importList{
|
||||||
Host string `toml:"host"`
|
`"errors"`,
|
||||||
Port int `toml:"port"`
|
`"fmt"`,
|
||||||
DBName string `toml:"dbname"`
|
},
|
||||||
|
thirdparty: importList{
|
||||||
|
`"github.com/pobri19/sqlboiler/boil"`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
// sqlBoilerDefaultTestImports defines the list of default test template imports.
|
||||||
Postgres PostgresCfg `toml:"postgres"`
|
var sqlBoilerTestImports = imports{
|
||||||
TestPostgres *PostgresCfg `toml:"postgres_test"`
|
standard: importList{
|
||||||
|
`"testing"`,
|
||||||
|
`"os"`,
|
||||||
|
`"os/exec"`,
|
||||||
|
`"fmt"`,
|
||||||
|
`"io/ioutil"`,
|
||||||
|
`"bytes"`,
|
||||||
|
`"errors"`,
|
||||||
|
},
|
||||||
|
thirdparty: importList{
|
||||||
|
`"github.com/BurntSushi/toml"`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg *Config
|
// sqlBoilerTypeImports imports are only included in the template output if the database
|
||||||
|
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
|
||||||
|
var sqlBoilerTypeImports = map[string]imports{
|
||||||
|
"null.Int": imports{
|
||||||
|
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||||
|
},
|
||||||
|
"null.String": imports{
|
||||||
|
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||||
|
},
|
||||||
|
"null.Bool": imports{
|
||||||
|
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||||
|
},
|
||||||
|
"null.Float": imports{
|
||||||
|
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||||
|
},
|
||||||
|
"null.Time": imports{
|
||||||
|
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||||
|
},
|
||||||
|
"time.Time": imports{
|
||||||
|
standard: importList{`"time"`},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlBoilerConditionalDriverTestImports defines the test template imports
|
||||||
|
// for the particular database interfaces
|
||||||
|
var sqlBoilerDriverTestImports = map[string]imports{
|
||||||
|
"postgres": imports{
|
||||||
|
standard: importList{`"database/sql"`},
|
||||||
|
thirdparty: importList{`_ "github.com/lib/pq"`},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
|
||||||
|
// If you wish to pass a new function into your own template, add a pointer to it here.
|
||||||
|
var sqlBoilerTemplateFuncs = template.FuncMap{
|
||||||
|
"singular": singular,
|
||||||
|
"plural": plural,
|
||||||
|
"titleCase": titleCase,
|
||||||
|
"titleCaseSingular": titleCaseSingular,
|
||||||
|
"titleCasePlural": titleCasePlural,
|
||||||
|
"camelCase": camelCase,
|
||||||
|
"camelCaseSingular": camelCaseSingular,
|
||||||
|
"camelCasePlural": camelCasePlural,
|
||||||
|
"makeDBName": makeDBName,
|
||||||
|
"selectParamNames": selectParamNames,
|
||||||
|
"insertParamNames": insertParamNames,
|
||||||
|
"insertParamFlags": insertParamFlags,
|
||||||
|
"insertParamVariables": insertParamVariables,
|
||||||
|
"scanParamNames": scanParamNames,
|
||||||
|
"hasPrimaryKey": hasPrimaryKey,
|
||||||
|
"getPrimaryKey": getPrimaryKey,
|
||||||
|
"updateParamNames": updateParamNames,
|
||||||
|
"updateParamVariables": updateParamVariables,
|
||||||
|
}
|
||||||
|
|
||||||
// LoadConfigFile loads the toml config file into the cfg object
|
// LoadConfigFile loads the toml config file into the cfg object
|
||||||
func LoadConfigFile(filename string) {
|
func (cmdData *CmdData) LoadConfigFile(filename string) error {
|
||||||
|
cfg := &Config{}
|
||||||
|
|
||||||
_, err := toml.DecodeFile(filename, &cfg)
|
_, err := toml.DecodeFile(filename, &cfg)
|
||||||
|
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
fmt.Printf("Failed to find the toml configuration file %s: %s", filename, err)
|
return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Failed to decode toml configuration file:", err)
|
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If any of the test cfg variables are not present then test TestPostgres to nil
|
cmdData.Config = cfg
|
||||||
//
|
return nil
|
||||||
// As a safety precaution, set TestPostgres to nil if
|
|
||||||
// the dbname is the same as the cfg dbname. This will prevent the test
|
|
||||||
// from erasing the production database tables if someone accidently
|
|
||||||
// configures the config.toml incorrectly.
|
|
||||||
if cfg.TestPostgres != nil {
|
|
||||||
if cfg.TestPostgres.User == "" || cfg.TestPostgres.Pass == "" ||
|
|
||||||
cfg.TestPostgres.Host == "" || cfg.TestPostgres.Port == 0 ||
|
|
||||||
cfg.TestPostgres.DBName == "" || cfg.Postgres.DBName == cfg.TestPostgres.DBName {
|
|
||||||
cfg.TestPostgres = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,13 +9,11 @@ import (
|
||||||
func TestLoadConfig(t *testing.T) {
|
func TestLoadConfig(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
cmdData := &CmdData{}
|
||||||
|
|
||||||
file, _ := ioutil.TempFile(os.TempDir(), "sqlboilercfgtest")
|
file, _ := ioutil.TempFile(os.TempDir(), "sqlboilercfgtest")
|
||||||
defer os.Remove(file.Name())
|
defer os.Remove(file.Name())
|
||||||
|
|
||||||
if cfg != nil {
|
|
||||||
t.Errorf("Expected cfgs to be empty for the time being.")
|
|
||||||
}
|
|
||||||
|
|
||||||
fContents := `[postgres]
|
fContents := `[postgres]
|
||||||
host="localhost"
|
host="localhost"
|
||||||
port=5432
|
port=5432
|
||||||
|
@ -24,37 +22,14 @@ func TestLoadConfig(t *testing.T) {
|
||||||
dbname="mydb"`
|
dbname="mydb"`
|
||||||
|
|
||||||
file.WriteString(fContents)
|
file.WriteString(fContents)
|
||||||
LoadConfigFile(file.Name())
|
err := cmdData.LoadConfigFile(file.Name())
|
||||||
|
if err != nil {
|
||||||
if cfg.TestPostgres != nil || cfg.Postgres.Host != "localhost" ||
|
t.Errorf("Unable to load config file: %s", err)
|
||||||
cfg.Postgres.User != "user" || cfg.Postgres.Pass != "pass" ||
|
|
||||||
cfg.Postgres.DBName != "mydb" || cfg.Postgres.Port != 5432 {
|
|
||||||
t.Errorf("Config failed to load properly, got: %#v", cfg.Postgres)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fContents = `
|
if cmdData.Config.Postgres.Host != "localhost" ||
|
||||||
[postgres_test]
|
cmdData.Config.Postgres.User != "user" || cmdData.Config.Postgres.Pass != "pass" ||
|
||||||
host="localhost"
|
cmdData.Config.Postgres.DBName != "mydb" || cmdData.Config.Postgres.Port != 5432 {
|
||||||
port=5432
|
t.Errorf("Config failed to load properly, got: %#v", cmdData.Config.Postgres)
|
||||||
user="testuser"
|
|
||||||
pass="testpass"`
|
|
||||||
|
|
||||||
file.WriteString(fContents)
|
|
||||||
LoadConfigFile(file.Name())
|
|
||||||
|
|
||||||
if cfg.TestPostgres != nil {
|
|
||||||
t.Errorf("Test config failed to load properly, got: %#v", cfg.Postgres)
|
|
||||||
}
|
|
||||||
|
|
||||||
fContents = `
|
|
||||||
dbname="testmydb"`
|
|
||||||
|
|
||||||
file.WriteString(fContents)
|
|
||||||
LoadConfigFile(file.Name())
|
|
||||||
|
|
||||||
if cfg.TestPostgres.DBName != "testmydb" || cfg.TestPostgres.Host != "localhost" ||
|
|
||||||
cfg.TestPostgres.User != "testuser" || cfg.TestPostgres.Pass != "testpass" ||
|
|
||||||
cfg.TestPostgres.Port != 5432 {
|
|
||||||
t.Errorf("Test config failed to load properly, got: %#v", cfg.Postgres)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ func combineImports(a, b imports) imports {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func combineConditionalTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
|
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
|
||||||
tmpImp := imports{
|
tmpImp := imports{
|
||||||
standard: make(importList, len(a.standard)),
|
standard: make(importList, len(a.standard)),
|
||||||
thirdparty: make(importList, len(a.thirdparty)),
|
thirdparty: make(importList, len(a.thirdparty)),
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCombineConditionalTypeImports(t *testing.T) {
|
func TestCombineTypeImports(t *testing.T) {
|
||||||
imports1 := imports{
|
imports1 := imports{
|
||||||
standard: importList{
|
standard: importList{
|
||||||
`"errors"`,
|
`"errors"`,
|
||||||
|
@ -45,7 +45,7 @@ func TestCombineConditionalTypeImports(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
res1 := combineConditionalTypeImports(imports1, sqlBoilerConditionalTypeImports, cols)
|
res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols)
|
||||||
|
|
||||||
if !reflect.DeepEqual(res1, importsExpected) {
|
if !reflect.DeepEqual(res1, importsExpected) {
|
||||||
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
|
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
|
||||||
|
@ -63,7 +63,7 @@ func TestCombineConditionalTypeImports(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
res2 := combineConditionalTypeImports(imports2, sqlBoilerConditionalTypeImports, cols)
|
res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols)
|
||||||
|
|
||||||
if !reflect.DeepEqual(res2, importsExpected) {
|
if !reflect.DeepEqual(res2, importsExpected) {
|
||||||
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
|
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)
|
||||||
|
|
|
@ -4,69 +4,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CobraRunFunc declares the cobra.Command.Run function definition
|
|
||||||
type CobraRunFunc func(cmd *cobra.Command, args []string)
|
|
||||||
|
|
||||||
// CmdData holds the table schema a slice of (column name, column type) slices.
|
|
||||||
// It also holds a slice of all of the table names sqlboiler is generating against,
|
|
||||||
// the database driver chosen by the driver flag at runtime, and a pointer to the
|
|
||||||
// output file, if one is specified with a flag.
|
|
||||||
type CmdData struct {
|
|
||||||
Tables []dbdrivers.Table
|
|
||||||
PkgName string
|
|
||||||
OutFolder string
|
|
||||||
Interface dbdrivers.Interface
|
|
||||||
DriverName string
|
|
||||||
}
|
|
||||||
|
|
||||||
// tplData is used to pass data to the template
|
|
||||||
type tplData struct {
|
|
||||||
Table dbdrivers.Table
|
|
||||||
PkgName string
|
|
||||||
}
|
|
||||||
|
|
||||||
// errorQuit displays an error message and then exits the application.
|
|
||||||
func errorQuit(err error) {
|
|
||||||
fmt.Println(fmt.Sprintf("Error: %s\n---\n\nRun 'sqlboiler --help' for usage.", err))
|
|
||||||
os.Exit(-1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultRun is the default function passed to the commands cobra.Command.Run.
|
|
||||||
// It will generate the specific commands template and send it to outHandler for output.
|
|
||||||
func defaultRun(cmd *cobra.Command, args []string) {
|
|
||||||
// Generate the template for every table
|
|
||||||
for _, t := range cmdData.Tables {
|
|
||||||
data := &tplData{
|
|
||||||
Table: t,
|
|
||||||
PkgName: cmdData.PkgName,
|
|
||||||
}
|
|
||||||
|
|
||||||
templater(cmd, data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// templater generates the template by passing it the tplData object.
|
|
||||||
// Once the template is generated, it will add the imports to the output stream
|
|
||||||
// and output the contents of the template with the added bits (imports and package declaration).
|
|
||||||
func templater(cmd *cobra.Command, data *tplData) {
|
|
||||||
// outHandler takes a slice of byte slices, so append the Template
|
|
||||||
// execution output to a [][]byte before sending it to outHandler.
|
|
||||||
out := [][]byte{generateTemplate(cmd.Name(), data)}
|
|
||||||
|
|
||||||
imps := combineImports(sqlBoilerDefaultImports, sqlBoilerCustomImports[cmd.Name()])
|
|
||||||
imps = combineConditionalTypeImports(imps, sqlBoilerConditionalTypeImports, data.Table.Columns)
|
|
||||||
|
|
||||||
err := outHandler(cmdData.OutFolder, out, data, imps, false)
|
|
||||||
if err != nil {
|
|
||||||
errorQuit(fmt.Errorf("Unable to generate the template for command %s: %s", cmd.Name(), err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var testHarnessStdout io.Writer = os.Stdout
|
var testHarnessStdout io.Writer = os.Stdout
|
||||||
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
||||||
file, err := os.Create(filename)
|
file, err := os.Create(filename)
|
||||||
|
@ -75,39 +14,39 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
|
||||||
|
|
||||||
// outHandler loops over the slice of byte slices, outputting them to either
|
// outHandler loops over the slice of byte slices, outputting them to either
|
||||||
// the OutFile if it is specified with a flag, or to Stdout if no flag is specified.
|
// the OutFile if it is specified with a flag, or to Stdout if no flag is specified.
|
||||||
func outHandler(outFolder string, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
|
func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
|
||||||
out := testHarnessStdout
|
out := testHarnessStdout
|
||||||
|
|
||||||
var path string
|
var path string
|
||||||
if len(outFolder) != 0 {
|
if len(cmdData.OutFolder) != 0 {
|
||||||
if testTemplate {
|
if testTemplate {
|
||||||
path = outFolder + "/" + data.Table.Name + "_test.go"
|
path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
|
||||||
} else {
|
} else {
|
||||||
path = outFolder + "/" + data.Table.Name + ".go"
|
path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
|
||||||
}
|
}
|
||||||
|
|
||||||
outFile, err := testHarnessFileOpen(path)
|
outFile, err := testHarnessFileOpen(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err))
|
return fmt.Errorf("Unable to create output file %s: %s", path, err)
|
||||||
}
|
}
|
||||||
defer outFile.Close()
|
defer outFile.Close()
|
||||||
out = outFile
|
out = outFile
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
|
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path))
|
return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
impStr := buildImportString(imps)
|
impStr := buildImportString(imps)
|
||||||
if len(impStr) > 0 {
|
if len(impStr) > 0 {
|
||||||
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
|
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to write imports to file handle: %v", err))
|
return fmt.Errorf("Unable to write imports to file handle: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, templateOutput := range output {
|
for _, templateOutput := range output {
|
||||||
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to write template output to file handle: %v", err))
|
return fmt.Errorf("Unable to write template output to file handle: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ func TestOutHandler(t *testing.T) {
|
||||||
|
|
||||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||||
|
|
||||||
if err := outHandler("", templateOutputs, &data, imports{}, false); err != nil {
|
if err := outHandler(&CmdData{PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
|
|
||||||
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
|
||||||
|
|
||||||
if err := outHandler("folder", templateOutputs, &data, imports{}, false); err != nil {
|
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
|
||||||
|
@ -80,7 +80,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
}
|
}
|
||||||
file = &bytes.Buffer{}
|
file = &bytes.Buffer{}
|
||||||
|
|
||||||
if err := outHandler("folder", templateOutputs, &data, a1, false); err != nil {
|
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a1, false); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
|
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
|
||||||
|
@ -94,7 +94,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
}
|
}
|
||||||
file = &bytes.Buffer{}
|
file = &bytes.Buffer{}
|
||||||
|
|
||||||
if err := outHandler("folder", templateOutputs, &data, a2, false); err != nil {
|
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a2, false); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
|
||||||
|
@ -118,7 +118,7 @@ func TestOutHandlerFiles(t *testing.T) {
|
||||||
sort.Sort(a3.standard)
|
sort.Sort(a3.standard)
|
||||||
sort.Sort(a3.thirdparty)
|
sort.Sort(a3.thirdparty)
|
||||||
|
|
||||||
if err := outHandler("folder", templateOutputs, &data, a3, false); err != nil {
|
if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a3, false); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
@ -18,166 +17,24 @@ const (
|
||||||
templatesTestDirectory = "/cmds/templates_test"
|
templatesTestDirectory = "/cmds/templates_test"
|
||||||
)
|
)
|
||||||
|
|
||||||
// cmdData is used globally by all commands to access the table schema data,
|
// LoadTemplates loads all template folders into the cmdData object.
|
||||||
// the database driver and the output file. cmdData is initialized by
|
func (cmdData *CmdData) LoadTemplates() error {
|
||||||
// the root SQLBoiler cobra command at run time, before other commands execute.
|
|
||||||
var cmdData *CmdData
|
|
||||||
|
|
||||||
// templates holds a slice of pointers to all templates in the templates directory.
|
|
||||||
var templates []*template.Template
|
|
||||||
|
|
||||||
// testTemplates holds a slice of pointers to all test templates in the templates directory.
|
|
||||||
var testTemplates []*template.Template
|
|
||||||
|
|
||||||
// SQLBoiler is the root app command
|
|
||||||
var SQLBoiler = &cobra.Command{
|
|
||||||
Use: "sqlboiler",
|
|
||||||
Short: "SQL Boiler generates boilerplate structs and statements",
|
|
||||||
Long: "SQL Boiler generates boilerplate structs and statements.\n" +
|
|
||||||
`Complete documentation is available at http://github.com/pobri19/sqlboiler`,
|
|
||||||
}
|
|
||||||
|
|
||||||
// init initializes the sqlboiler flags, such as driver, table, and output file.
|
|
||||||
// It also sets the global preRun hook and postRun hook. Every command will execute
|
|
||||||
// these hooks before and after running to initialize the shared state.
|
|
||||||
func init() {
|
|
||||||
SQLBoiler.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
|
|
||||||
SQLBoiler.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
|
|
||||||
SQLBoiler.PersistentFlags().StringP("folder", "f", "", "The name of the output folder. If not specified will output to stdout")
|
|
||||||
SQLBoiler.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
|
|
||||||
SQLBoiler.PersistentPreRun = sqlBoilerPreRun
|
|
||||||
SQLBoiler.PersistentPostRun = sqlBoilerPostRun
|
|
||||||
|
|
||||||
// Initialize the SQLBoiler commands and hook the custom Run functions
|
|
||||||
initCommands(SQLBoiler, sqlBoilerCommands, sqlBoilerCommandRuns)
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerPostRun cleans up the output file and database connection once
|
|
||||||
// all commands are finished running.
|
|
||||||
func sqlBoilerPostRun(cmd *cobra.Command, args []string) {
|
|
||||||
cmdData.Interface.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlBoilerPreRun executes before all commands start running. Its job is to
|
|
||||||
// initialize the shared state object (cmdData). Initialization tasks include
|
|
||||||
// assigning the database driver based off the driver flag and opening the database connection,
|
|
||||||
// retrieving a list of all the tables in the database (if specific tables are not provided
|
|
||||||
// via a flag), building the table schema for use in the templates, and opening the output file
|
|
||||||
// handle if one is specified with a flag.
|
|
||||||
func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
|
|
||||||
var err error
|
var err error
|
||||||
cmdData = &CmdData{}
|
cmdData.Templates, err = loadTemplates(templatesDirectory)
|
||||||
|
|
||||||
// Initialize the cmdData.Interface
|
|
||||||
initInterface()
|
|
||||||
|
|
||||||
// Connect to the driver database
|
|
||||||
if err = cmdData.Interface.Open(); err != nil {
|
|
||||||
errorQuit(fmt.Errorf("Unable to connect to the database: %s", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize the cmdData.Tables
|
|
||||||
initTables()
|
|
||||||
|
|
||||||
// Initialize the package name
|
|
||||||
initPkgName()
|
|
||||||
|
|
||||||
// Initialize the cmdData.OutFile
|
|
||||||
initOutFolder()
|
|
||||||
|
|
||||||
// Initialize the templates
|
|
||||||
templates, err = initTemplates(templatesDirectory)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to initialize templates: %s", err))
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the test templates if the OutFolder
|
cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
|
||||||
if cmdData.OutFolder != "" && cfg.TestPostgres != nil {
|
|
||||||
testTemplates, err = initTemplates(templatesTestDirectory)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to initialize test templates: %s", err))
|
return err
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// initInterface attempts to set the cmdData Interface based off the passed in
|
// loadTemplates loads all of the template files in the specified directory.
|
||||||
// driver flag value. If an invalid flag string is provided the program will exit.
|
func loadTemplates(dir string) ([]*template.Template, error) {
|
||||||
func initInterface() {
|
|
||||||
// Retrieve driver flag
|
|
||||||
driverName := SQLBoiler.PersistentFlags().Lookup("driver").Value.String()
|
|
||||||
if driverName == "" {
|
|
||||||
errorQuit(errors.New("Must supply a driver flag."))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a driver based off driver flag
|
|
||||||
switch driverName {
|
|
||||||
case "postgres":
|
|
||||||
cmdData.Interface = dbdrivers.NewPostgresDriver(
|
|
||||||
cfg.Postgres.User,
|
|
||||||
cfg.Postgres.Pass,
|
|
||||||
cfg.Postgres.DBName,
|
|
||||||
cfg.Postgres.Host,
|
|
||||||
cfg.Postgres.Port,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmdData.Interface == nil {
|
|
||||||
errorQuit(errors.New("An invalid driver name was provided"))
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdData.DriverName = driverName
|
|
||||||
}
|
|
||||||
|
|
||||||
// initTables will create a string slice out of the passed in table flag value
|
|
||||||
// if one is provided. If no flag is provided, it will attempt to connect to the
|
|
||||||
// database to retrieve all "public" table names, and build a slice out of that result.
|
|
||||||
func initTables() {
|
|
||||||
// Retrieve the list of tables
|
|
||||||
tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String()
|
|
||||||
|
|
||||||
var tableNames []string
|
|
||||||
|
|
||||||
if len(tn) != 0 {
|
|
||||||
tableNames = strings.Split(tn, ",")
|
|
||||||
for i, name := range tableNames {
|
|
||||||
tableNames[i] = strings.TrimSpace(name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
cmdData.Tables, err = cmdData.Interface.Tables(tableNames...)
|
|
||||||
if err != nil {
|
|
||||||
errorQuit(fmt.Errorf("Unable to get all table names: %s", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(cmdData.Tables) == 0 {
|
|
||||||
errorQuit(errors.New("No tables found in database, migrate some tables first"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize the package name provided by the flag
|
|
||||||
func initPkgName() {
|
|
||||||
cmdData.PkgName = SQLBoiler.PersistentFlags().Lookup("pkgname").Value.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// initOutFile opens a file handle to the file name specified by the out flag.
|
|
||||||
// If no file name is provided, out will remain nil and future output will be
|
|
||||||
// piped to Stdout instead of to a file.
|
|
||||||
func initOutFolder() {
|
|
||||||
// open the out file filehandle
|
|
||||||
cmdData.OutFolder = SQLBoiler.PersistentFlags().Lookup("folder").Value.String()
|
|
||||||
if cmdData.OutFolder == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
|
|
||||||
errorQuit(fmt.Errorf("Unable to make output folder: %s", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// initTemplates loads all of the template files in the /cmds/templates directory
|
|
||||||
// and returns a slice of pointers to these templates.
|
|
||||||
func initTemplates(dir string) ([]*template.Template, error) {
|
|
||||||
wd, err := os.Getwd()
|
wd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -193,39 +50,166 @@ func initTemplates(dir string) ([]*template.Template, error) {
|
||||||
return tpl.Templates(), err
|
return tpl.Templates(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// initCommands loads all of the commands in the sqlBoilerCommands and hooks their run functions.
|
// SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
|
||||||
func initCommands(rootCmd *cobra.Command, commands map[string]*cobra.Command, commandRuns map[string]CobraRunFunc) {
|
func (cmdData *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error {
|
||||||
var commandNames []string
|
cmdData.Interface.Close()
|
||||||
|
return nil
|
||||||
// Build a list of command names to alphabetically sort them for ordered loading.
|
}
|
||||||
for _, c := range commands {
|
|
||||||
// Skip the boil command load, we do it manually below.
|
// SQLBoilerPreRun performs the initialization tasks before the root command is run
|
||||||
if c.Name() == "boil" {
|
func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
|
||||||
continue
|
var err error
|
||||||
}
|
|
||||||
|
// Initialize package name
|
||||||
commandNames = append(commandNames, c.Name())
|
cmdData.PkgName = cmd.PersistentFlags().Lookup("pkgname").Value.String()
|
||||||
}
|
|
||||||
|
err = initInterface(cmd, cmdData.Config, cmdData)
|
||||||
// Initialize the "boil" command first, and manually. It should be at the top of the help file.
|
if err != nil {
|
||||||
commands["boil"].Run = commandRuns["boil"]
|
return err
|
||||||
rootCmd.AddCommand(commands["boil"])
|
}
|
||||||
|
|
||||||
// Load commands alphabetically. This ensures proper order of help file.
|
// Connect to the driver database
|
||||||
sort.Strings(commandNames)
|
if err = cmdData.Interface.Open(); err != nil {
|
||||||
|
return fmt.Errorf("Unable to connect to the database: %s", err)
|
||||||
// Loop every command name, load it and hook it to its Run handler
|
}
|
||||||
for _, c := range commandNames {
|
|
||||||
// If there is a commandRun for the command (matched by name)
|
err = initTables(cmd, cmdData)
|
||||||
// then set the Run hook
|
if err != nil {
|
||||||
r, ok := commandRuns[c]
|
return fmt.Errorf("Unable to initialize tables: %s", err)
|
||||||
if ok {
|
}
|
||||||
commands[c].Run = r
|
|
||||||
} else {
|
err = initOutFolder(cmd, cmdData)
|
||||||
commands[c].Run = defaultRun // Load default run if no custom run is found
|
if err != nil {
|
||||||
}
|
return fmt.Errorf("Unable to initialize the output folder: %s", err)
|
||||||
|
}
|
||||||
// Add the command
|
|
||||||
rootCmd.AddCommand(commands[c])
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SQLBoilerRun executes every sqlboiler template and outputs them to files.
|
||||||
|
func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
||||||
|
for _, table := range cmdData.Tables {
|
||||||
|
data := &tplData{
|
||||||
|
Table: table,
|
||||||
|
PkgName: cmdData.PkgName,
|
||||||
|
}
|
||||||
|
|
||||||
|
var out [][]byte
|
||||||
|
var imps imports
|
||||||
|
|
||||||
|
imps.standard = sqlBoilerImports.standard
|
||||||
|
imps.thirdparty = sqlBoilerImports.thirdparty
|
||||||
|
|
||||||
|
// Loop through and generate every command template (excluding skipTemplates)
|
||||||
|
for _, template := range cmdData.Templates {
|
||||||
|
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
|
||||||
|
resp, err := generateTemplate(template, data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
out = append(out, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := outHandler(cmdData, out, data, imps, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the test templates for all commands
|
||||||
|
if len(cmdData.TestTemplates) != 0 {
|
||||||
|
var testOut [][]byte
|
||||||
|
var testImps imports
|
||||||
|
|
||||||
|
testImps.standard = sqlBoilerTestImports.standard
|
||||||
|
testImps.thirdparty = sqlBoilerTestImports.thirdparty
|
||||||
|
|
||||||
|
testImps = combineImports(testImps, sqlBoilerDriverTestImports[cmdData.DriverName])
|
||||||
|
|
||||||
|
// Loop through and generate every command test template (excluding skipTemplates)
|
||||||
|
for _, template := range cmdData.TestTemplates {
|
||||||
|
resp, err := generateTemplate(template, data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
testOut = append(testOut, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = outHandler(cmdData, testOut, data, testImps, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initInterface attempts to set the cmdData Interface based off the passed in
|
||||||
|
// driver flag value. If an invalid flag string is provided an error is returned.
|
||||||
|
func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error {
|
||||||
|
// Retrieve driver flag
|
||||||
|
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
|
||||||
|
if driverName == "" {
|
||||||
|
return errors.New("Must supply a driver flag.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a driver based off driver flag
|
||||||
|
switch driverName {
|
||||||
|
case "postgres":
|
||||||
|
cmdData.Interface = dbdrivers.NewPostgresDriver(
|
||||||
|
cfg.Postgres.User,
|
||||||
|
cfg.Postgres.Pass,
|
||||||
|
cfg.Postgres.DBName,
|
||||||
|
cfg.Postgres.Host,
|
||||||
|
cfg.Postgres.Port,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmdData.Interface == nil {
|
||||||
|
return errors.New("An invalid driver name was provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdData.DriverName = driverName
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initTables will create a string slice out of the passed in table flag value
|
||||||
|
// if one is provided. If no flag is provided, it will attempt to connect to the
|
||||||
|
// database to retrieve all "public" table names, and build a slice out of that result.
|
||||||
|
func initTables(cmd *cobra.Command, cmdData *CmdData) error {
|
||||||
|
var tableNames []string
|
||||||
|
tn := cmd.PersistentFlags().Lookup("table").Value.String()
|
||||||
|
|
||||||
|
if len(tn) != 0 {
|
||||||
|
tableNames = strings.Split(tn, ",")
|
||||||
|
for i, name := range tableNames {
|
||||||
|
tableNames[i] = strings.TrimSpace(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
cmdData.Tables, err = cmdData.Interface.Tables(tableNames...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to get all table names: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cmdData.Tables) == 0 {
|
||||||
|
return errors.New("No tables found in database, migrate some tables first")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initOutFolder creates the folder that will hold the generated output.
|
||||||
|
func initOutFolder(cmd *cobra.Command, cmdData *CmdData) error {
|
||||||
|
cmdData.OutFolder = cmd.PersistentFlags().Lookup("folder").Value.String()
|
||||||
|
if cmdData.OutFolder == "" {
|
||||||
|
return fmt.Errorf("No output folder specified.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
|
||||||
|
return fmt.Errorf("Unable to make output folder: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,8 @@ import (
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var cmdData *CmdData
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
cmdData = &CmdData{
|
cmdData = &CmdData{
|
||||||
Tables: []dbdrivers.Table{
|
Tables: []dbdrivers.Table{
|
||||||
|
@ -46,7 +48,7 @@ func TestTemplates(t *testing.T) {
|
||||||
|
|
||||||
// Initialize the templates
|
// Initialize the templates
|
||||||
var err error
|
var err error
|
||||||
templates, err = initTemplates("templates")
|
cmdData.Templates, err = loadTemplates("templates")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to initialize templates: %s", err)
|
t.Fatalf("Unable to initialize templates: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -57,7 +59,7 @@ func TestTemplates(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(cmdData.OutFolder)
|
defer os.RemoveAll(cmdData.OutFolder)
|
||||||
|
|
||||||
boilRun(sqlBoilerCommands["boil"], []string{})
|
cmdData.SQLBoilerRun(nil, []string{})
|
||||||
|
|
||||||
tplFile := cmdData.OutFolder + "/templates_test.go"
|
tplFile := cmdData.OutFolder + "/templates_test.go"
|
||||||
tplTestHandle, err := os.Create(tplFile)
|
tplTestHandle, err := os.Create(tplFile)
|
||||||
|
|
|
@ -12,65 +12,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// generateTemplate generates the template associated to the passed in command name.
|
// generateTemplate generates the template associated to the passed in command name.
|
||||||
func generateTemplate(commandName string, data *tplData) []byte {
|
func generateTemplate(template *template.Template, data *tplData) ([]byte, error) {
|
||||||
template := getTemplate(commandName)
|
|
||||||
|
|
||||||
if template == nil {
|
|
||||||
errorQuit(fmt.Errorf("Unable to find the template: %s", commandName+".tpl"))
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := processTemplate(template, data)
|
output, err := processTemplate(template, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to process the template %s for table %s: %s", template.Name(), data.Table.Name, err))
|
return nil, fmt.Errorf("Unable to process the template %s for table %s: %s", template.Name(), data.Table.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return output
|
return output, nil
|
||||||
}
|
|
||||||
|
|
||||||
// generateTestTemplate generates the test template associated to the passed in command name.
|
|
||||||
func generateTestTemplate(commandName string, data *tplData) []byte {
|
|
||||||
template := getTestTemplate(commandName)
|
|
||||||
|
|
||||||
if template == nil {
|
|
||||||
return []byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := processTemplate(template, data)
|
|
||||||
if err != nil {
|
|
||||||
errorQuit(fmt.Errorf("Unable to process the test template %s for table %s: %s", template.Name(), data.Table.Name, err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return output
|
|
||||||
}
|
|
||||||
|
|
||||||
// getTemplate returns a pointer to the template matching the passed in name
|
|
||||||
func getTemplate(name string) *template.Template {
|
|
||||||
var tpl *template.Template
|
|
||||||
|
|
||||||
// Find the template that matches the passed in template name
|
|
||||||
for _, t := range templates {
|
|
||||||
if t.Name() == name+".tpl" {
|
|
||||||
tpl = t
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tpl
|
|
||||||
}
|
|
||||||
|
|
||||||
// getTemplate returns a pointer to the template matching the passed in name
|
|
||||||
func getTestTemplate(name string) *template.Template {
|
|
||||||
var tpl *template.Template
|
|
||||||
|
|
||||||
// Find the template that matches the passed in template name
|
|
||||||
for _, t := range testTemplates {
|
|
||||||
if t.Name() == name+".tpl" {
|
|
||||||
tpl = t
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tpl
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processTemplate takes a template and returns the output of the template execution.
|
// processTemplate takes a template and returns the output of the template execution.
|
||||||
|
|
53
cmds/types.go
Normal file
53
cmds/types.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CobraRunFunc declares the cobra.Command.Run function definition
|
||||||
|
type CobraRunFunc func(cmd *cobra.Command, args []string)
|
||||||
|
|
||||||
|
// CmdData holds the table schema a slice of (column name, column type) slices.
|
||||||
|
// It also holds a slice of all of the table names sqlboiler is generating against,
|
||||||
|
// the database driver chosen by the driver flag at runtime, and a pointer to the
|
||||||
|
// output file, if one is specified with a flag.
|
||||||
|
type CmdData struct {
|
||||||
|
Tables []dbdrivers.Table
|
||||||
|
PkgName string
|
||||||
|
OutFolder string
|
||||||
|
Interface dbdrivers.Interface
|
||||||
|
DriverName string
|
||||||
|
Config *Config
|
||||||
|
Templates []*template.Template
|
||||||
|
TestTemplates []*template.Template
|
||||||
|
}
|
||||||
|
|
||||||
|
// tplData is used to pass data to the template
|
||||||
|
type tplData struct {
|
||||||
|
Table dbdrivers.Table
|
||||||
|
PkgName string
|
||||||
|
}
|
||||||
|
|
||||||
|
type importList []string
|
||||||
|
|
||||||
|
// imports defines the optional standard imports and
|
||||||
|
// thirdparty imports (from github for example)
|
||||||
|
type imports struct {
|
||||||
|
standard importList
|
||||||
|
thirdparty importList
|
||||||
|
}
|
||||||
|
|
||||||
|
type PostgresCfg struct {
|
||||||
|
User string `toml:"user"`
|
||||||
|
Pass string `toml:"pass"`
|
||||||
|
Host string `toml:"host"`
|
||||||
|
Port int `toml:"port"`
|
||||||
|
DBName string `toml:"dbname"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Postgres PostgresCfg `toml:"postgres"`
|
||||||
|
}
|
|
@ -50,6 +50,7 @@ func (p *PostgresDriver) Tables(names ...string) ([]Table, error) {
|
||||||
var err error
|
var err error
|
||||||
if len(names) == 0 {
|
if len(names) == 0 {
|
||||||
if names, err = p.tableNames(); err != nil {
|
if names, err = p.tableNames(); err != nil {
|
||||||
|
fmt.Println("Unable to get table names.")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -59,14 +60,17 @@ func (p *PostgresDriver) Tables(names ...string) ([]Table, error) {
|
||||||
t := Table{Name: name}
|
t := Table{Name: name}
|
||||||
|
|
||||||
if t.Columns, err = p.columns(name); err != nil {
|
if t.Columns, err = p.columns(name); err != nil {
|
||||||
|
fmt.Println("Unable to get columnss.")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.PKey, err = p.primaryKeyInfo(name); err != nil {
|
if t.PKey, err = p.primaryKeyInfo(name); err != nil {
|
||||||
|
fmt.Println("Unable to get primary key info.")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.FKeys, err = p.foreignKeyInfo(name); err != nil {
|
if t.FKeys, err = p.foreignKeyInfo(name); err != nil {
|
||||||
|
fmt.Println("Unable to get foreign key info.")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,14 +157,25 @@ func (p *PostgresDriver) primaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
||||||
pkey := &PrimaryKey{}
|
pkey := &PrimaryKey{}
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
query := ``
|
query := `SELECT conname
|
||||||
|
FROM pg_constraint
|
||||||
|
WHERE conrelid =
|
||||||
|
(SELECT oid
|
||||||
|
FROM pg_class
|
||||||
|
WHERE relname LIKE $1)
|
||||||
|
AND contype='p';`
|
||||||
|
|
||||||
row := p.dbConn.QueryRow(query, tableName)
|
row := p.dbConn.QueryRow(query, tableName)
|
||||||
if err = row.Scan(&pkey.Name); err != nil {
|
if err = row.Scan(&pkey.Name); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
queryColumns := ``
|
queryColumns := `SELECT a.attname AS column
|
||||||
|
FROM pg_index i
|
||||||
|
JOIN pg_attribute a ON a.attrelid = i.indrelid
|
||||||
|
AND a.attnum = ANY(i.indkey)
|
||||||
|
WHERE i.indrelid = $1::regclass
|
||||||
|
AND i.indisprimary;`
|
||||||
|
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
if rows, err = p.dbConn.Query(queryColumns, tableName); err != nil {
|
if rows, err = p.dbConn.Query(queryColumns, tableName); err != nil {
|
||||||
|
@ -197,7 +212,7 @@ func (p *PostgresDriver) foreignKeyInfo(tableName string) ([]ForeignKey, error)
|
||||||
FROM information_schema.table_constraints as tc
|
FROM information_schema.table_constraints as tc
|
||||||
JOIN information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name
|
JOIN information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name
|
||||||
JOIN information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name
|
JOIN information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name
|
||||||
WHERE source_table = $1, tc.constraint_type = 'FOREIGN KEY';`
|
WHERE tc.table_name = $1 AND tc.constraint_type = 'FOREIGN KEY';`
|
||||||
|
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
|
|
45
main.go
45
main.go
|
@ -10,15 +10,52 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/pobri19/sqlboiler/cmds"
|
"github.com/pobri19/sqlboiler/cmds"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Load the config.toml file
|
var err error
|
||||||
cmds.LoadConfigFile("config.toml")
|
cmdData := &cmds.CmdData{}
|
||||||
|
|
||||||
|
// Load the "config.toml" global config
|
||||||
|
err = cmdData.LoadConfigFile("config.toml")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to load config file: %s\n", err)
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load all templates
|
||||||
|
err = cmdData.LoadTemplates()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to load templates: %s\n", err)
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the cobra root command
|
||||||
|
var rootCmd = &cobra.Command{
|
||||||
|
Use: "sqlboiler",
|
||||||
|
Short: "SQL Boiler generates boilerplate structs and statements",
|
||||||
|
Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" +
|
||||||
|
`Complete documentation is available at http://github.com/pobri19/sqlboiler`,
|
||||||
|
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return cmdData.SQLBoilerPreRun(cmd, args)
|
||||||
|
},
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return cmdData.SQLBoilerRun(cmd, args)
|
||||||
|
},
|
||||||
|
PostRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return cmdData.SQLBoilerPostRun(cmd, args)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the cobra root command flags
|
||||||
|
rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
|
||||||
|
rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
|
||||||
|
rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder")
|
||||||
|
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
|
||||||
|
|
||||||
// Execute SQLBoiler
|
// Execute SQLBoiler
|
||||||
if err := cmds.SQLBoiler.Execute(); err != nil {
|
if err := rootCmd.Execute(); err != nil {
|
||||||
fmt.Printf("Failed to execute SQLBoiler: %s", err)
|
|
||||||
os.Exit(-1)
|
os.Exit(-1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue