Finished most of the stripping

This commit is contained in:
Patrick O'brien 2016-03-28 01:03:14 +10:00
parent c74ed4e75f
commit 27cafdd2fb
15 changed files with 401 additions and 723 deletions

View file

@ -1,101 +0,0 @@
package cmds
import (
"sort"
"github.com/spf13/cobra"
)
var boilCmd = &cobra.Command{
Use: "boil",
Short: "Generates ALL templates by running every command alphabetically",
}
// boilRun executes every sqlboiler command, starting with structs.
func boilRun(cmd *cobra.Command, args []string) {
commandNames := buildCommandList()
// Prepend "struct" command to templateNames slice so it sits at top of sort
commandNames = append([]string{"struct"}, commandNames...)
// Create a testCommandNames with "driverName_main" prepended to the front for the test templates
// the main template initializes all of the testing assets
testCommandNames := append([]string{cmdData.DriverName + "_main"}, commandNames...)
for _, table := range cmdData.Tables {
data := &tplData{
Table: table,
PkgName: cmdData.PkgName,
}
var out [][]byte
var imps imports
imps.standard = sqlBoilerDefaultImports.standard
imps.thirdparty = sqlBoilerDefaultImports.thirdparty
// Loop through and generate every command template (excluding skipTemplates)
for _, command := range commandNames {
imps = combineImports(imps, sqlBoilerCustomImports[command])
imps = combineConditionalTypeImports(imps, sqlBoilerConditionalTypeImports, data.Table.Columns)
out = append(out, generateTemplate(command, data))
}
err := outHandler(cmdData.OutFolder, out, data, imps, false)
if err != nil {
errorQuit(err)
}
// Generate the test templates for all commands
if len(testTemplates) != 0 {
var testOut [][]byte
var testImps imports
testImps.standard = sqlBoilerDefaultTestImports.standard
testImps.thirdparty = sqlBoilerDefaultTestImports.thirdparty
testImps = combineImports(testImps, sqlBoilerConditionalDriverTestImports[cmdData.DriverName])
// Loop through and generate every command test template (excluding skipTemplates)
for _, command := range testCommandNames {
testImps = combineImports(testImps, sqlBoilerCustomTestImports[command])
testOut = append(testOut, generateTestTemplate(command, data))
}
err = outHandler(cmdData.OutFolder, testOut, data, testImps, true)
if err != nil {
errorQuit(err)
}
}
}
}
func buildCommandList() []string {
// Exclude these commands from the output
skipCommands := []string{
"boil",
"struct",
}
var commandNames []string
// Build a list of template names
for _, c := range sqlBoilerCommands {
skip := false
for _, s := range skipCommands {
// Skip name if it's in the exclude list.
if s == c.Name() {
skip = true
break
}
}
if !skip {
commandNames = append(commandNames, c.Name())
}
}
// Sort all names alphabetically
sort.Strings(commandNames)
return commandNames
}

View file

@ -1,38 +0,0 @@
package cmds
import "testing"
func TestBuildCommandList(t *testing.T) {
list := buildCommandList()
skips := []string{"struct", "boil"}
for _, item := range list {
for _, skipItem := range skips {
if item == skipItem {
t.Errorf("Did not expect to find: %s %#v", item, list)
}
}
}
CommandNameLoop:
for cmdName := range sqlBoilerCommands {
for _, skipItem := range skips {
if cmdName == skipItem {
continue CommandNameLoop
}
}
found := false
for _, item := range list {
if item == cmdName {
found = true
break
}
}
if !found {
t.Error("Expected to find command name:", cmdName)
}
}
}

View file

@ -1,194 +0,0 @@
package cmds
import (
"text/template"
"github.com/spf13/cobra"
)
type importList []string
// imports defines the optional standard imports and
// thirdparty imports (from github for example)
type imports struct {
standard importList
thirdparty importList
}
// sqlBoilerDefaultImports defines the list of default template imports.
var sqlBoilerDefaultImports = imports{
standard: importList{
`"errors"`,
`"fmt"`,
},
thirdparty: importList{
`"github.com/pobri19/sqlboiler/boil"`,
},
}
// sqlBoilerDefaultTestImports defines the list of default test template imports.
var sqlBoilerDefaultTestImports = imports{
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"errors"`,
},
thirdparty: importList{
`"github.com/BurntSushi/toml"`,
},
}
// sqlBoilerConditionalTypeImports imports are only included in the template output
// if the database requires one of the following special types. Check TranslateColumnType
// to see the type assignments.
var sqlBoilerConditionalTypeImports = map[string]imports{
"null.Int": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"null.String": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"null.Bool": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"null.Float": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"null.Time": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"time.Time": imports{
standard: importList{`"time"`},
},
}
// sqlBoilerConditionalDriverTestImports defines the test template imports
// for the particular database interfaces
var sqlBoilerConditionalDriverTestImports = map[string]imports{
"postgres": imports{
standard: importList{`"database/sql"`},
thirdparty: importList{`_ "github.com/lib/pq"`},
},
}
var sqlBoilerCustomImports map[string]imports
var sqlBoilerCustomTestImports map[string]imports
// sqlBoilerCommands points each command to its cobra.Command variable.
//
// If you would like to add your own custom command, add it to this
// map, and point it to your <commandName>Cmd variable.
//
// Command names should match the template file name (without the file extension).
var sqlBoilerCommands = map[string]*cobra.Command{
// Command to generate all commands
"boil": boilCmd,
// Struct commands
"struct": structCmd,
// Insert commands
"insert": insertCmd,
// Select commands
"all": allCmd,
"where": whereCmd,
"select": selectCmd,
"selectwhere": selectWhereCmd,
"find": findCmd,
"findselect": findSelectCmd,
// Delete commands
"delete": deleteCmd,
// Update commands
"update": updateCmd,
}
// sqlBoilerCommandRuns points each command to its custom run function.
// If a run function is not defined here, it will use the defaultRun() default run function.
var sqlBoilerCommandRuns = map[string]CobraRunFunc{
"boil": boilRun,
}
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
// If you wish to pass a new function into your own template, add a pointer to it here.
var sqlBoilerTemplateFuncs = template.FuncMap{
"singular": singular,
"plural": plural,
"titleCase": titleCase,
"titleCaseSingular": titleCaseSingular,
"titleCasePlural": titleCasePlural,
"camelCase": camelCase,
"camelCaseSingular": camelCaseSingular,
"camelCasePlural": camelCasePlural,
"makeDBName": makeDBName,
"selectParamNames": selectParamNames,
"insertParamNames": insertParamNames,
"insertParamFlags": insertParamFlags,
"insertParamVariables": insertParamVariables,
"scanParamNames": scanParamNames,
"hasPrimaryKey": hasPrimaryKey,
"getPrimaryKey": getPrimaryKey,
"updateParamNames": updateParamNames,
"updateParamVariables": updateParamVariables,
}
/* Struct commands */
var structCmd = &cobra.Command{
Use: "struct",
Short: "Generate structs from table definitions",
}
/* Insert commands */
var insertCmd = &cobra.Command{
Use: "insert",
Short: "Generate insert statement helpers from table definitions",
}
/* Select commands */
var allCmd = &cobra.Command{
Use: "all",
Short: "Generate a helper to select all records",
}
var whereCmd = &cobra.Command{
Use: "where",
Short: "Generate a helper to select all records with specific column values",
}
var selectCmd = &cobra.Command{
Use: "select",
Short: "Generate a helper to select specific fields of all records",
}
var selectWhereCmd = &cobra.Command{
Use: "selectwhere",
Short: "Generate a helper to select specific fields of records with specific column values",
}
var findCmd = &cobra.Command{
Use: "find",
Short: "Generate a helper to select a single record by ID",
}
var findSelectCmd = &cobra.Command{
Use: "findselect",
Short: "Generate a helper to select specific fields of a record by ID",
}
/* Delete commands */
var deleteCmd = &cobra.Command{
Use: "delete",
Short: "Generate delete statement helpers from table definitions",
}
/* Update commands */
var updateCmd = &cobra.Command{
Use: "update",
Short: "Generate update statement helpers from table definitions",
}

View file

@ -3,49 +3,107 @@ package cmds
import ( import (
"fmt" "fmt"
"os" "os"
"text/template"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
) )
type PostgresCfg struct { // sqlBoilerDefaultImports defines the list of default template imports.
User string `toml:"user"` var sqlBoilerImports = imports{
Pass string `toml:"pass"` standard: importList{
Host string `toml:"host"` `"errors"`,
Port int `toml:"port"` `"fmt"`,
DBName string `toml:"dbname"` },
thirdparty: importList{
`"github.com/pobri19/sqlboiler/boil"`,
},
} }
type Config struct { // sqlBoilerDefaultTestImports defines the list of default test template imports.
Postgres PostgresCfg `toml:"postgres"` var sqlBoilerTestImports = imports{
TestPostgres *PostgresCfg `toml:"postgres_test"` standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"errors"`,
},
thirdparty: importList{
`"github.com/BurntSushi/toml"`,
},
} }
var cfg *Config // sqlBoilerTypeImports imports are only included in the template output if the database
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
var sqlBoilerTypeImports = map[string]imports{
"null.Int": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"null.String": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"null.Bool": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"null.Float": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"null.Time": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"time.Time": imports{
standard: importList{`"time"`},
},
}
// sqlBoilerConditionalDriverTestImports defines the test template imports
// for the particular database interfaces
var sqlBoilerDriverTestImports = map[string]imports{
"postgres": imports{
standard: importList{`"database/sql"`},
thirdparty: importList{`_ "github.com/lib/pq"`},
},
}
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
// If you wish to pass a new function into your own template, add a pointer to it here.
var sqlBoilerTemplateFuncs = template.FuncMap{
"singular": singular,
"plural": plural,
"titleCase": titleCase,
"titleCaseSingular": titleCaseSingular,
"titleCasePlural": titleCasePlural,
"camelCase": camelCase,
"camelCaseSingular": camelCaseSingular,
"camelCasePlural": camelCasePlural,
"makeDBName": makeDBName,
"selectParamNames": selectParamNames,
"insertParamNames": insertParamNames,
"insertParamFlags": insertParamFlags,
"insertParamVariables": insertParamVariables,
"scanParamNames": scanParamNames,
"hasPrimaryKey": hasPrimaryKey,
"getPrimaryKey": getPrimaryKey,
"updateParamNames": updateParamNames,
"updateParamVariables": updateParamVariables,
}
// LoadConfigFile loads the toml config file into the cfg object // LoadConfigFile loads the toml config file into the cfg object
func LoadConfigFile(filename string) { func (cmdData *CmdData) LoadConfigFile(filename string) error {
cfg := &Config{}
_, err := toml.DecodeFile(filename, &cfg) _, err := toml.DecodeFile(filename, &cfg)
if os.IsNotExist(err) { if os.IsNotExist(err) {
fmt.Printf("Failed to find the toml configuration file %s: %s", filename, err) return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err)
return
} }
if err != nil { if err != nil {
fmt.Println("Failed to decode toml configuration file:", err) return fmt.Errorf("Failed to decode toml configuration file: %s", err)
} }
// If any of the test cfg variables are not present then test TestPostgres to nil cmdData.Config = cfg
// return nil
// As a safety precaution, set TestPostgres to nil if
// the dbname is the same as the cfg dbname. This will prevent the test
// from erasing the production database tables if someone accidently
// configures the config.toml incorrectly.
if cfg.TestPostgres != nil {
if cfg.TestPostgres.User == "" || cfg.TestPostgres.Pass == "" ||
cfg.TestPostgres.Host == "" || cfg.TestPostgres.Port == 0 ||
cfg.TestPostgres.DBName == "" || cfg.Postgres.DBName == cfg.TestPostgres.DBName {
cfg.TestPostgres = nil
}
}
} }

View file

@ -9,13 +9,11 @@ import (
func TestLoadConfig(t *testing.T) { func TestLoadConfig(t *testing.T) {
t.Parallel() t.Parallel()
cmdData := &CmdData{}
file, _ := ioutil.TempFile(os.TempDir(), "sqlboilercfgtest") file, _ := ioutil.TempFile(os.TempDir(), "sqlboilercfgtest")
defer os.Remove(file.Name()) defer os.Remove(file.Name())
if cfg != nil {
t.Errorf("Expected cfgs to be empty for the time being.")
}
fContents := `[postgres] fContents := `[postgres]
host="localhost" host="localhost"
port=5432 port=5432
@ -24,37 +22,14 @@ func TestLoadConfig(t *testing.T) {
dbname="mydb"` dbname="mydb"`
file.WriteString(fContents) file.WriteString(fContents)
LoadConfigFile(file.Name()) err := cmdData.LoadConfigFile(file.Name())
if err != nil {
if cfg.TestPostgres != nil || cfg.Postgres.Host != "localhost" || t.Errorf("Unable to load config file: %s", err)
cfg.Postgres.User != "user" || cfg.Postgres.Pass != "pass" ||
cfg.Postgres.DBName != "mydb" || cfg.Postgres.Port != 5432 {
t.Errorf("Config failed to load properly, got: %#v", cfg.Postgres)
} }
fContents = ` if cmdData.Config.Postgres.Host != "localhost" ||
[postgres_test] cmdData.Config.Postgres.User != "user" || cmdData.Config.Postgres.Pass != "pass" ||
host="localhost" cmdData.Config.Postgres.DBName != "mydb" || cmdData.Config.Postgres.Port != 5432 {
port=5432 t.Errorf("Config failed to load properly, got: %#v", cmdData.Config.Postgres)
user="testuser"
pass="testpass"`
file.WriteString(fContents)
LoadConfigFile(file.Name())
if cfg.TestPostgres != nil {
t.Errorf("Test config failed to load properly, got: %#v", cfg.Postgres)
}
fContents = `
dbname="testmydb"`
file.WriteString(fContents)
LoadConfigFile(file.Name())
if cfg.TestPostgres.DBName != "testmydb" || cfg.TestPostgres.Host != "localhost" ||
cfg.TestPostgres.User != "testuser" || cfg.TestPostgres.Pass != "testpass" ||
cfg.TestPostgres.Port != 5432 {
t.Errorf("Test config failed to load properly, got: %#v", cfg.Postgres)
} }
} }

View file

@ -38,7 +38,7 @@ func combineImports(a, b imports) imports {
return c return c
} }
func combineConditionalTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports { func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
tmpImp := imports{ tmpImp := imports{
standard: make(importList, len(a.standard)), standard: make(importList, len(a.standard)),
thirdparty: make(importList, len(a.thirdparty)), thirdparty: make(importList, len(a.thirdparty)),

View file

@ -7,7 +7,7 @@ import (
"github.com/pobri19/sqlboiler/dbdrivers" "github.com/pobri19/sqlboiler/dbdrivers"
) )
func TestCombineConditionalTypeImports(t *testing.T) { func TestCombineTypeImports(t *testing.T) {
imports1 := imports{ imports1 := imports{
standard: importList{ standard: importList{
`"errors"`, `"errors"`,
@ -45,7 +45,7 @@ func TestCombineConditionalTypeImports(t *testing.T) {
}, },
} }
res1 := combineConditionalTypeImports(imports1, sqlBoilerConditionalTypeImports, cols) res1 := combineTypeImports(imports1, sqlBoilerTypeImports, cols)
if !reflect.DeepEqual(res1, importsExpected) { if !reflect.DeepEqual(res1, importsExpected) {
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1) t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
@ -63,7 +63,7 @@ func TestCombineConditionalTypeImports(t *testing.T) {
}, },
} }
res2 := combineConditionalTypeImports(imports2, sqlBoilerConditionalTypeImports, cols) res2 := combineTypeImports(imports2, sqlBoilerTypeImports, cols)
if !reflect.DeepEqual(res2, importsExpected) { if !reflect.DeepEqual(res2, importsExpected) {
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1) t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)

View file

@ -4,69 +4,8 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"github.com/pobri19/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
) )
// CobraRunFunc declares the cobra.Command.Run function definition
type CobraRunFunc func(cmd *cobra.Command, args []string)
// CmdData holds the table schema a slice of (column name, column type) slices.
// It also holds a slice of all of the table names sqlboiler is generating against,
// the database driver chosen by the driver flag at runtime, and a pointer to the
// output file, if one is specified with a flag.
type CmdData struct {
Tables []dbdrivers.Table
PkgName string
OutFolder string
Interface dbdrivers.Interface
DriverName string
}
// tplData is used to pass data to the template
type tplData struct {
Table dbdrivers.Table
PkgName string
}
// errorQuit displays an error message and then exits the application.
func errorQuit(err error) {
fmt.Println(fmt.Sprintf("Error: %s\n---\n\nRun 'sqlboiler --help' for usage.", err))
os.Exit(-1)
}
// defaultRun is the default function passed to the commands cobra.Command.Run.
// It will generate the specific commands template and send it to outHandler for output.
func defaultRun(cmd *cobra.Command, args []string) {
// Generate the template for every table
for _, t := range cmdData.Tables {
data := &tplData{
Table: t,
PkgName: cmdData.PkgName,
}
templater(cmd, data)
}
}
// templater generates the template by passing it the tplData object.
// Once the template is generated, it will add the imports to the output stream
// and output the contents of the template with the added bits (imports and package declaration).
func templater(cmd *cobra.Command, data *tplData) {
// outHandler takes a slice of byte slices, so append the Template
// execution output to a [][]byte before sending it to outHandler.
out := [][]byte{generateTemplate(cmd.Name(), data)}
imps := combineImports(sqlBoilerDefaultImports, sqlBoilerCustomImports[cmd.Name()])
imps = combineConditionalTypeImports(imps, sqlBoilerConditionalTypeImports, data.Table.Columns)
err := outHandler(cmdData.OutFolder, out, data, imps, false)
if err != nil {
errorQuit(fmt.Errorf("Unable to generate the template for command %s: %s", cmd.Name(), err))
}
}
var testHarnessStdout io.Writer = os.Stdout var testHarnessStdout io.Writer = os.Stdout
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) { var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
file, err := os.Create(filename) file, err := os.Create(filename)
@ -75,39 +14,39 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
// outHandler loops over the slice of byte slices, outputting them to either // outHandler loops over the slice of byte slices, outputting them to either
// the OutFile if it is specified with a flag, or to Stdout if no flag is specified. // the OutFile if it is specified with a flag, or to Stdout if no flag is specified.
func outHandler(outFolder string, output [][]byte, data *tplData, imps imports, testTemplate bool) error { func outHandler(cmdData *CmdData, output [][]byte, data *tplData, imps imports, testTemplate bool) error {
out := testHarnessStdout out := testHarnessStdout
var path string var path string
if len(outFolder) != 0 { if len(cmdData.OutFolder) != 0 {
if testTemplate { if testTemplate {
path = outFolder + "/" + data.Table.Name + "_test.go" path = cmdData.OutFolder + "/" + data.Table.Name + "_test.go"
} else { } else {
path = outFolder + "/" + data.Table.Name + ".go" path = cmdData.OutFolder + "/" + data.Table.Name + ".go"
} }
outFile, err := testHarnessFileOpen(path) outFile, err := testHarnessFileOpen(path)
if err != nil { if err != nil {
errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err)) return fmt.Errorf("Unable to create output file %s: %s", path, err)
} }
defer outFile.Close() defer outFile.Close()
out = outFile out = outFile
} }
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil { if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
errorQuit(fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)) return fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path)
} }
impStr := buildImportString(imps) impStr := buildImportString(imps)
if len(impStr) > 0 { if len(impStr) > 0 {
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil { if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
errorQuit(fmt.Errorf("Unable to write imports to file handle: %v", err)) return fmt.Errorf("Unable to write imports to file handle: %v", err)
} }
} }
for _, templateOutput := range output { for _, templateOutput := range output {
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil { if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
errorQuit(fmt.Errorf("Unable to write template output to file handle: %v", err)) return fmt.Errorf("Unable to write template output to file handle: %v", err)
} }
} }

View file

@ -28,7 +28,7 @@ func TestOutHandler(t *testing.T) {
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")} templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler("", templateOutputs, &data, imports{}, false); err != nil { if err := outHandler(&CmdData{PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
t.Error(err) t.Error(err)
} }
@ -66,7 +66,7 @@ func TestOutHandlerFiles(t *testing.T) {
templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")} templateOutputs := [][]byte{[]byte("hello world"), []byte("patrick's dreams")}
if err := outHandler("folder", templateOutputs, &data, imports{}, false); err != nil { if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, imports{}, false); err != nil {
t.Error(err) t.Error(err)
} }
if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" { if out := file.String(); out != "package patrick\n\nhello world\npatrick's dreams\n" {
@ -80,7 +80,7 @@ func TestOutHandlerFiles(t *testing.T) {
} }
file = &bytes.Buffer{} file = &bytes.Buffer{}
if err := outHandler("folder", templateOutputs, &data, a1, false); err != nil { if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a1, false); err != nil {
t.Error(err) t.Error(err)
} }
if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" { if out := file.String(); out != "package patrick\n\nimport \"fmt\"\nhello world\npatrick's dreams\n" {
@ -94,7 +94,7 @@ func TestOutHandlerFiles(t *testing.T) {
} }
file = &bytes.Buffer{} file = &bytes.Buffer{}
if err := outHandler("folder", templateOutputs, &data, a2, false); err != nil { if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a2, false); err != nil {
t.Error(err) t.Error(err)
} }
if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" { if out := file.String(); out != "package patrick\n\nimport \"github.com/spf13/cobra\"\nhello world\npatrick's dreams\n" {
@ -118,7 +118,7 @@ func TestOutHandlerFiles(t *testing.T) {
sort.Sort(a3.standard) sort.Sort(a3.standard)
sort.Sort(a3.thirdparty) sort.Sort(a3.thirdparty)
if err := outHandler("folder", templateOutputs, &data, a3, false); err != nil { if err := outHandler(&CmdData{OutFolder: "folder", PkgName: "patrick"}, templateOutputs, &data, a3, false); err != nil {
t.Error(err) t.Error(err)
} }

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"sort"
"strings" "strings"
"text/template" "text/template"
@ -18,166 +17,24 @@ const (
templatesTestDirectory = "/cmds/templates_test" templatesTestDirectory = "/cmds/templates_test"
) )
// cmdData is used globally by all commands to access the table schema data, // LoadTemplates loads all template folders into the cmdData object.
// the database driver and the output file. cmdData is initialized by func (cmdData *CmdData) LoadTemplates() error {
// the root SQLBoiler cobra command at run time, before other commands execute.
var cmdData *CmdData
// templates holds a slice of pointers to all templates in the templates directory.
var templates []*template.Template
// testTemplates holds a slice of pointers to all test templates in the templates directory.
var testTemplates []*template.Template
// SQLBoiler is the root app command
var SQLBoiler = &cobra.Command{
Use: "sqlboiler",
Short: "SQL Boiler generates boilerplate structs and statements",
Long: "SQL Boiler generates boilerplate structs and statements.\n" +
`Complete documentation is available at http://github.com/pobri19/sqlboiler`,
}
// init initializes the sqlboiler flags, such as driver, table, and output file.
// It also sets the global preRun hook and postRun hook. Every command will execute
// these hooks before and after running to initialize the shared state.
func init() {
SQLBoiler.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
SQLBoiler.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
SQLBoiler.PersistentFlags().StringP("folder", "f", "", "The name of the output folder. If not specified will output to stdout")
SQLBoiler.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
SQLBoiler.PersistentPreRun = sqlBoilerPreRun
SQLBoiler.PersistentPostRun = sqlBoilerPostRun
// Initialize the SQLBoiler commands and hook the custom Run functions
initCommands(SQLBoiler, sqlBoilerCommands, sqlBoilerCommandRuns)
}
// sqlBoilerPostRun cleans up the output file and database connection once
// all commands are finished running.
func sqlBoilerPostRun(cmd *cobra.Command, args []string) {
cmdData.Interface.Close()
}
// sqlBoilerPreRun executes before all commands start running. Its job is to
// initialize the shared state object (cmdData). Initialization tasks include
// assigning the database driver based off the driver flag and opening the database connection,
// retrieving a list of all the tables in the database (if specific tables are not provided
// via a flag), building the table schema for use in the templates, and opening the output file
// handle if one is specified with a flag.
func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
var err error var err error
cmdData = &CmdData{} cmdData.Templates, err = loadTemplates(templatesDirectory)
// Initialize the cmdData.Interface
initInterface()
// Connect to the driver database
if err = cmdData.Interface.Open(); err != nil {
errorQuit(fmt.Errorf("Unable to connect to the database: %s", err))
}
// Initialize the cmdData.Tables
initTables()
// Initialize the package name
initPkgName()
// Initialize the cmdData.OutFile
initOutFolder()
// Initialize the templates
templates, err = initTemplates(templatesDirectory)
if err != nil { if err != nil {
errorQuit(fmt.Errorf("Unable to initialize templates: %s", err)) return err
} }
// Initialize the test templates if the OutFolder cmdData.TestTemplates, err = loadTemplates(templatesTestDirectory)
if cmdData.OutFolder != "" && cfg.TestPostgres != nil {
testTemplates, err = initTemplates(templatesTestDirectory)
if err != nil { if err != nil {
errorQuit(fmt.Errorf("Unable to initialize test templates: %s", err)) return err
}
} }
return nil
} }
// initInterface attempts to set the cmdData Interface based off the passed in // loadTemplates loads all of the template files in the specified directory.
// driver flag value. If an invalid flag string is provided the program will exit. func loadTemplates(dir string) ([]*template.Template, error) {
func initInterface() {
// Retrieve driver flag
driverName := SQLBoiler.PersistentFlags().Lookup("driver").Value.String()
if driverName == "" {
errorQuit(errors.New("Must supply a driver flag."))
}
// Create a driver based off driver flag
switch driverName {
case "postgres":
cmdData.Interface = dbdrivers.NewPostgresDriver(
cfg.Postgres.User,
cfg.Postgres.Pass,
cfg.Postgres.DBName,
cfg.Postgres.Host,
cfg.Postgres.Port,
)
}
if cmdData.Interface == nil {
errorQuit(errors.New("An invalid driver name was provided"))
}
cmdData.DriverName = driverName
}
// initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that result.
func initTables() {
// Retrieve the list of tables
tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String()
var tableNames []string
if len(tn) != 0 {
tableNames = strings.Split(tn, ",")
for i, name := range tableNames {
tableNames[i] = strings.TrimSpace(name)
}
}
var err error
cmdData.Tables, err = cmdData.Interface.Tables(tableNames...)
if err != nil {
errorQuit(fmt.Errorf("Unable to get all table names: %s", err))
}
if len(cmdData.Tables) == 0 {
errorQuit(errors.New("No tables found in database, migrate some tables first"))
}
}
// Initialize the package name provided by the flag
func initPkgName() {
cmdData.PkgName = SQLBoiler.PersistentFlags().Lookup("pkgname").Value.String()
}
// initOutFile opens a file handle to the file name specified by the out flag.
// If no file name is provided, out will remain nil and future output will be
// piped to Stdout instead of to a file.
func initOutFolder() {
// open the out file filehandle
cmdData.OutFolder = SQLBoiler.PersistentFlags().Lookup("folder").Value.String()
if cmdData.OutFolder == "" {
return
}
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
errorQuit(fmt.Errorf("Unable to make output folder: %s", err))
}
}
// initTemplates loads all of the template files in the /cmds/templates directory
// and returns a slice of pointers to these templates.
func initTemplates(dir string) ([]*template.Template, error) {
wd, err := os.Getwd() wd, err := os.Getwd()
if err != nil { if err != nil {
return nil, err return nil, err
@ -193,39 +50,166 @@ func initTemplates(dir string) ([]*template.Template, error) {
return tpl.Templates(), err return tpl.Templates(), err
} }
// initCommands loads all of the commands in the sqlBoilerCommands and hooks their run functions. // SQLBoilerPostRun cleans up the output file and db connection once all cmds are finished.
func initCommands(rootCmd *cobra.Command, commands map[string]*cobra.Command, commandRuns map[string]CobraRunFunc) { func (cmdData *CmdData) SQLBoilerPostRun(cmd *cobra.Command, args []string) error {
var commandNames []string cmdData.Interface.Close()
return nil
// Build a list of command names to alphabetically sort them for ordered loading. }
for _, c := range commands {
// Skip the boil command load, we do it manually below. // SQLBoilerPreRun performs the initialization tasks before the root command is run
if c.Name() == "boil" { func (cmdData *CmdData) SQLBoilerPreRun(cmd *cobra.Command, args []string) error {
continue var err error
}
// Initialize package name
commandNames = append(commandNames, c.Name()) cmdData.PkgName = cmd.PersistentFlags().Lookup("pkgname").Value.String()
}
err = initInterface(cmd, cmdData.Config, cmdData)
// Initialize the "boil" command first, and manually. It should be at the top of the help file. if err != nil {
commands["boil"].Run = commandRuns["boil"] return err
rootCmd.AddCommand(commands["boil"]) }
// Load commands alphabetically. This ensures proper order of help file. // Connect to the driver database
sort.Strings(commandNames) if err = cmdData.Interface.Open(); err != nil {
return fmt.Errorf("Unable to connect to the database: %s", err)
// Loop every command name, load it and hook it to its Run handler }
for _, c := range commandNames {
// If there is a commandRun for the command (matched by name) err = initTables(cmd, cmdData)
// then set the Run hook if err != nil {
r, ok := commandRuns[c] return fmt.Errorf("Unable to initialize tables: %s", err)
if ok { }
commands[c].Run = r
} else { err = initOutFolder(cmd, cmdData)
commands[c].Run = defaultRun // Load default run if no custom run is found if err != nil {
} return fmt.Errorf("Unable to initialize the output folder: %s", err)
}
// Add the command
rootCmd.AddCommand(commands[c]) return nil
} }
// SQLBoilerRun executes every sqlboiler template and outputs them to files.
func (cmdData *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
for _, table := range cmdData.Tables {
data := &tplData{
Table: table,
PkgName: cmdData.PkgName,
}
var out [][]byte
var imps imports
imps.standard = sqlBoilerImports.standard
imps.thirdparty = sqlBoilerImports.thirdparty
// Loop through and generate every command template (excluding skipTemplates)
for _, template := range cmdData.Templates {
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
resp, err := generateTemplate(template, data)
if err != nil {
return err
}
out = append(out, resp)
}
err := outHandler(cmdData, out, data, imps, false)
if err != nil {
return err
}
// Generate the test templates for all commands
if len(cmdData.TestTemplates) != 0 {
var testOut [][]byte
var testImps imports
testImps.standard = sqlBoilerTestImports.standard
testImps.thirdparty = sqlBoilerTestImports.thirdparty
testImps = combineImports(testImps, sqlBoilerDriverTestImports[cmdData.DriverName])
// Loop through and generate every command test template (excluding skipTemplates)
for _, template := range cmdData.TestTemplates {
resp, err := generateTemplate(template, data)
if err != nil {
return err
}
testOut = append(testOut, resp)
}
err = outHandler(cmdData, testOut, data, testImps, true)
if err != nil {
return err
}
}
}
return nil
}
// initInterface attempts to set the cmdData Interface based off the passed in
// driver flag value. If an invalid flag string is provided an error is returned.
func initInterface(cmd *cobra.Command, cfg *Config, cmdData *CmdData) error {
// Retrieve driver flag
driverName := cmd.PersistentFlags().Lookup("driver").Value.String()
if driverName == "" {
return errors.New("Must supply a driver flag.")
}
// Create a driver based off driver flag
switch driverName {
case "postgres":
cmdData.Interface = dbdrivers.NewPostgresDriver(
cfg.Postgres.User,
cfg.Postgres.Pass,
cfg.Postgres.DBName,
cfg.Postgres.Host,
cfg.Postgres.Port,
)
}
if cmdData.Interface == nil {
return errors.New("An invalid driver name was provided")
}
cmdData.DriverName = driverName
return nil
}
// initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that result.
func initTables(cmd *cobra.Command, cmdData *CmdData) error {
var tableNames []string
tn := cmd.PersistentFlags().Lookup("table").Value.String()
if len(tn) != 0 {
tableNames = strings.Split(tn, ",")
for i, name := range tableNames {
tableNames[i] = strings.TrimSpace(name)
}
}
var err error
cmdData.Tables, err = cmdData.Interface.Tables(tableNames...)
if err != nil {
return fmt.Errorf("Unable to get all table names: %s", err)
}
if len(cmdData.Tables) == 0 {
return errors.New("No tables found in database, migrate some tables first")
}
return nil
}
// initOutFolder creates the folder that will hold the generated output.
func initOutFolder(cmd *cobra.Command, cmdData *CmdData) error {
cmdData.OutFolder = cmd.PersistentFlags().Lookup("folder").Value.String()
if cmdData.OutFolder == "" {
return fmt.Errorf("No output folder specified.")
}
if err := os.MkdirAll(cmdData.OutFolder, os.ModePerm); err != nil {
return fmt.Errorf("Unable to make output folder: %s", err)
}
return nil
} }

View file

@ -11,6 +11,8 @@ import (
"github.com/pobri19/sqlboiler/dbdrivers" "github.com/pobri19/sqlboiler/dbdrivers"
) )
var cmdData *CmdData
func init() { func init() {
cmdData = &CmdData{ cmdData = &CmdData{
Tables: []dbdrivers.Table{ Tables: []dbdrivers.Table{
@ -46,7 +48,7 @@ func TestTemplates(t *testing.T) {
// Initialize the templates // Initialize the templates
var err error var err error
templates, err = initTemplates("templates") cmdData.Templates, err = loadTemplates("templates")
if err != nil { if err != nil {
t.Fatalf("Unable to initialize templates: %s", err) t.Fatalf("Unable to initialize templates: %s", err)
} }
@ -57,7 +59,7 @@ func TestTemplates(t *testing.T) {
} }
defer os.RemoveAll(cmdData.OutFolder) defer os.RemoveAll(cmdData.OutFolder)
boilRun(sqlBoilerCommands["boil"], []string{}) cmdData.SQLBoilerRun(nil, []string{})
tplFile := cmdData.OutFolder + "/templates_test.go" tplFile := cmdData.OutFolder + "/templates_test.go"
tplTestHandle, err := os.Create(tplFile) tplTestHandle, err := os.Create(tplFile)

View file

@ -12,65 +12,13 @@ import (
) )
// generateTemplate generates the template associated to the passed in command name. // generateTemplate generates the template associated to the passed in command name.
func generateTemplate(commandName string, data *tplData) []byte { func generateTemplate(template *template.Template, data *tplData) ([]byte, error) {
template := getTemplate(commandName)
if template == nil {
errorQuit(fmt.Errorf("Unable to find the template: %s", commandName+".tpl"))
}
output, err := processTemplate(template, data) output, err := processTemplate(template, data)
if err != nil { if err != nil {
errorQuit(fmt.Errorf("Unable to process the template %s for table %s: %s", template.Name(), data.Table.Name, err)) return nil, fmt.Errorf("Unable to process the template %s for table %s: %s", template.Name(), data.Table.Name, err)
} }
return output return output, nil
}
// generateTestTemplate generates the test template associated to the passed in command name.
func generateTestTemplate(commandName string, data *tplData) []byte {
template := getTestTemplate(commandName)
if template == nil {
return []byte{}
}
output, err := processTemplate(template, data)
if err != nil {
errorQuit(fmt.Errorf("Unable to process the test template %s for table %s: %s", template.Name(), data.Table.Name, err))
}
return output
}
// getTemplate returns a pointer to the template matching the passed in name
func getTemplate(name string) *template.Template {
var tpl *template.Template
// Find the template that matches the passed in template name
for _, t := range templates {
if t.Name() == name+".tpl" {
tpl = t
break
}
}
return tpl
}
// getTemplate returns a pointer to the template matching the passed in name
func getTestTemplate(name string) *template.Template {
var tpl *template.Template
// Find the template that matches the passed in template name
for _, t := range testTemplates {
if t.Name() == name+".tpl" {
tpl = t
break
}
}
return tpl
} }
// processTemplate takes a template and returns the output of the template execution. // processTemplate takes a template and returns the output of the template execution.

53
cmds/types.go Normal file
View file

@ -0,0 +1,53 @@
package cmds
import (
"text/template"
"github.com/pobri19/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
)
// CobraRunFunc declares the cobra.Command.Run function definition
type CobraRunFunc func(cmd *cobra.Command, args []string)
// CmdData holds the table schema a slice of (column name, column type) slices.
// It also holds a slice of all of the table names sqlboiler is generating against,
// the database driver chosen by the driver flag at runtime, and a pointer to the
// output file, if one is specified with a flag.
type CmdData struct {
Tables []dbdrivers.Table
PkgName string
OutFolder string
Interface dbdrivers.Interface
DriverName string
Config *Config
Templates []*template.Template
TestTemplates []*template.Template
}
// tplData is used to pass data to the template
type tplData struct {
Table dbdrivers.Table
PkgName string
}
type importList []string
// imports defines the optional standard imports and
// thirdparty imports (from github for example)
type imports struct {
standard importList
thirdparty importList
}
type PostgresCfg struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
}
type Config struct {
Postgres PostgresCfg `toml:"postgres"`
}

View file

@ -50,6 +50,7 @@ func (p *PostgresDriver) Tables(names ...string) ([]Table, error) {
var err error var err error
if len(names) == 0 { if len(names) == 0 {
if names, err = p.tableNames(); err != nil { if names, err = p.tableNames(); err != nil {
fmt.Println("Unable to get table names.")
return nil, err return nil, err
} }
} }
@ -59,14 +60,17 @@ func (p *PostgresDriver) Tables(names ...string) ([]Table, error) {
t := Table{Name: name} t := Table{Name: name}
if t.Columns, err = p.columns(name); err != nil { if t.Columns, err = p.columns(name); err != nil {
fmt.Println("Unable to get columnss.")
return nil, err return nil, err
} }
if t.PKey, err = p.primaryKeyInfo(name); err != nil { if t.PKey, err = p.primaryKeyInfo(name); err != nil {
fmt.Println("Unable to get primary key info.")
return nil, err return nil, err
} }
if t.FKeys, err = p.foreignKeyInfo(name); err != nil { if t.FKeys, err = p.foreignKeyInfo(name); err != nil {
fmt.Println("Unable to get foreign key info.")
return nil, err return nil, err
} }
@ -153,14 +157,25 @@ func (p *PostgresDriver) primaryKeyInfo(tableName string) (*PrimaryKey, error) {
pkey := &PrimaryKey{} pkey := &PrimaryKey{}
var err error var err error
query := `` query := `SELECT conname
FROM pg_constraint
WHERE conrelid =
(SELECT oid
FROM pg_class
WHERE relname LIKE $1)
AND contype='p';`
row := p.dbConn.QueryRow(query, tableName) row := p.dbConn.QueryRow(query, tableName)
if err = row.Scan(&pkey.Name); err != nil { if err = row.Scan(&pkey.Name); err != nil {
return nil, err return nil, err
} }
queryColumns := `` queryColumns := `SELECT a.attname AS column
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid
AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = $1::regclass
AND i.indisprimary;`
var rows *sql.Rows var rows *sql.Rows
if rows, err = p.dbConn.Query(queryColumns, tableName); err != nil { if rows, err = p.dbConn.Query(queryColumns, tableName); err != nil {
@ -197,7 +212,7 @@ func (p *PostgresDriver) foreignKeyInfo(tableName string) ([]ForeignKey, error)
FROM information_schema.table_constraints as tc FROM information_schema.table_constraints as tc
JOIN information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name JOIN information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name
JOIN information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name JOIN information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name
WHERE source_table = $1, tc.constraint_type = 'FOREIGN KEY';` WHERE tc.table_name = $1 AND tc.constraint_type = 'FOREIGN KEY';`
var rows *sql.Rows var rows *sql.Rows
var err error var err error

45
main.go
View file

@ -10,15 +10,52 @@ import (
"os" "os"
"github.com/pobri19/sqlboiler/cmds" "github.com/pobri19/sqlboiler/cmds"
"github.com/spf13/cobra"
) )
func main() { func main() {
// Load the config.toml file var err error
cmds.LoadConfigFile("config.toml") cmdData := &cmds.CmdData{}
// Load the "config.toml" global config
err = cmdData.LoadConfigFile("config.toml")
if err != nil {
fmt.Printf("Failed to load config file: %s\n", err)
os.Exit(-1)
}
// Load all templates
err = cmdData.LoadTemplates()
if err != nil {
fmt.Printf("Failed to load templates: %s\n", err)
os.Exit(-1)
}
// Set up the cobra root command
var rootCmd = &cobra.Command{
Use: "sqlboiler",
Short: "SQL Boiler generates boilerplate structs and statements",
Long: "SQL Boiler generates boilerplate structs and statements from the template files.\n" +
`Complete documentation is available at http://github.com/pobri19/sqlboiler`,
PreRunE: func(cmd *cobra.Command, args []string) error {
return cmdData.SQLBoilerPreRun(cmd, args)
},
RunE: func(cmd *cobra.Command, args []string) error {
return cmdData.SQLBoilerRun(cmd, args)
},
PostRunE: func(cmd *cobra.Command, args []string) error {
return cmdData.SQLBoilerPostRun(cmd, args)
},
}
// Set up the cobra root command flags
rootCmd.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml (mandatory)")
rootCmd.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
rootCmd.PersistentFlags().StringP("folder", "f", "output", "The name of the output folder")
rootCmd.PersistentFlags().StringP("pkgname", "p", "model", "The name you wish to assign to your generated package")
// Execute SQLBoiler // Execute SQLBoiler
if err := cmds.SQLBoiler.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
fmt.Printf("Failed to execute SQLBoiler: %s", err)
os.Exit(-1) os.Exit(-1)
} }
} }