More swaths of renaming awkward structs.

This commit is contained in:
Aaron 2016-03-01 20:05:25 -08:00
parent 4f619d4efe
commit 8a591ecdd6
19 changed files with 106 additions and 68 deletions

View file

@ -44,10 +44,10 @@ func boilRun(cmd *cobra.Command, args []string) {
// Prepend "struct" command to templateNames slice so it sits at top of sort // Prepend "struct" command to templateNames slice so it sits at top of sort
templateNames = append([]string{"struct"}, templateNames...) templateNames = append([]string{"struct"}, templateNames...)
for i := 0; i < len(cmdData.TablesInfo); i++ { for i := 0; i < len(cmdData.Columns); i++ {
data := tplData{ data := tplData{
TableName: cmdData.TableNames[i], Table: cmdData.Tables[i],
TableData: cmdData.TablesInfo[i], Columns: cmdData.Columns[i],
} }
var out [][]byte var out [][]byte

View file

@ -69,6 +69,7 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
"selectParamNames": selectParamNames, "selectParamNames": selectParamNames,
"insertParamNames": insertParamNames, "insertParamNames": insertParamNames,
"insertParamFlags": insertParamFlags, "insertParamFlags": insertParamFlags,
"scanParamNames": scanParamNames,
} }
var allCmd = &cobra.Command{ var allCmd = &cobra.Command{

View file

@ -16,8 +16,8 @@ type CobraRunFunc func(cmd *cobra.Command, args []string)
// the database driver chosen by the driver flag at runtime, and a pointer to the // the database driver chosen by the driver flag at runtime, and a pointer to the
// output file, if one is specified with a flag. // output file, if one is specified with a flag.
type CmdData struct { type CmdData struct {
TablesInfo [][]dbdrivers.DBColumn Tables []string
TableNames []string Columns [][]dbdrivers.DBColumn
PkgName string PkgName string
OutFolder string OutFolder string
DBDriver dbdrivers.DBDriver DBDriver dbdrivers.DBDriver
@ -25,8 +25,8 @@ type CmdData struct {
// tplData is used to pass data to the template // tplData is used to pass data to the template
type tplData struct { type tplData struct {
TableName string Table string
TableData []dbdrivers.DBColumn Columns []dbdrivers.DBColumn
} }
// errorQuit displays an error message and then exits the application. // errorQuit displays an error message and then exits the application.
@ -39,10 +39,10 @@ func errorQuit(err error) {
// It will generate the specific commands template and send it to outHandler for output. // It will generate the specific commands template and send it to outHandler for output.
func defaultRun(cmd *cobra.Command, args []string) { func defaultRun(cmd *cobra.Command, args []string) {
// Generate the template for every table // Generate the template for every table
for i := 0; i < len(cmdData.TablesInfo); i++ { for i := 0; i < len(cmdData.Columns); i++ {
data := tplData{ data := tplData{
TableName: cmdData.TableNames[i], Table: cmdData.Tables[i],
TableData: cmdData.TablesInfo[i], Columns: cmdData.Columns[i],
} }
// outHandler takes a slice of byte slices, so append the Template // outHandler takes a slice of byte slices, so append the Template
@ -72,7 +72,7 @@ func outHandler(output [][]byte, data *tplData) error {
} }
} }
} else { // If not using stdout, attempt to create the model file. } else { // If not using stdout, attempt to create the model file.
path := cmdData.OutFolder + "/" + data.TableName + ".go" path := cmdData.OutFolder + "/" + data.Table + ".go"
out, err := os.Create(path) out, err := os.Create(path)
if err != nil { if err != nil {
errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err)) errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err))

View file

@ -13,6 +13,10 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
const (
templatesDirectory = "/cmds/templates"
)
// cmdData is used globally by all commands to access the table schema data, // cmdData is used globally by all commands to access the table schema data,
// the database driver and the output file. cmdData is initialized by // the database driver and the output file. cmdData is initialized by
// the root SQLBoiler cobra command at run time, before other commands execute. // the root SQLBoiler cobra command at run time, before other commands execute.
@ -68,11 +72,11 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
errorQuit(fmt.Errorf("Unable to connect to the database: %s", err)) errorQuit(fmt.Errorf("Unable to connect to the database: %s", err))
} }
// Initialize the cmdData.TableNames // Initialize the cmdData.Tables
initTableNames() initTables()
// Initialize the cmdData.TablesInfo // Initialize the cmdData.Columns
initTablesInfo() initColumns()
// Initialize the package name // Initialize the package name
initPkgName() initPkgName()
@ -81,7 +85,7 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
initOutFolder() initOutFolder()
// Initialize the templates // Initialize the templates
templates, err = initTemplates() templates, err = initTemplates(templatesDirectory)
if err != nil { if err != nil {
errorQuit(fmt.Errorf("Unable to initialize templates: %s", err)) errorQuit(fmt.Errorf("Unable to initialize templates: %s", err))
} }
@ -113,46 +117,46 @@ func initDBDriver() {
} }
} }
// initTableNames will create a string slice out of the passed in table flag value // initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the // if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that result. // database to retrieve all "public" table names, and build a slice out of that result.
func initTableNames() { func initTables() {
// Retrieve the list of tables // Retrieve the list of tables
tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String() tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String()
if len(tn) != 0 { if len(tn) != 0 {
cmdData.TableNames = strings.Split(tn, ",") cmdData.Tables = strings.Split(tn, ",")
for i, name := range cmdData.TableNames { for i, name := range cmdData.Tables {
cmdData.TableNames[i] = strings.TrimSpace(name) cmdData.Tables[i] = strings.TrimSpace(name)
} }
} }
// If no table names are provided attempt to process all tables in database // If no table names are provided attempt to process all tables in database
if len(cmdData.TableNames) == 0 { if len(cmdData.Tables) == 0 {
// get all table names // get all table names
var err error var err error
cmdData.TableNames, err = cmdData.DBDriver.GetAllTableNames() cmdData.Tables, err = cmdData.DBDriver.GetAllTables()
if err != nil { if err != nil {
errorQuit(fmt.Errorf("Unable to get all table names: %s", err)) errorQuit(fmt.Errorf("Unable to get all table names: %s", err))
} }
if len(cmdData.TableNames) == 0 { if len(cmdData.Tables) == 0 {
errorQuit(errors.New("No tables found in database, migrate some tables first")) errorQuit(errors.New("No tables found in database, migrate some tables first"))
} }
} }
} }
// initTablesInfo builds a description of each table (column name, column type) // initColumns builds a description of each table (column name, column type)
// and assigns it to cmdData.TablesInfo, the slice of dbdrivers.DBColumn slices. // and assigns it to cmdData.Columns, the slice of dbdrivers.DBColumn slices.
func initTablesInfo() { func initColumns() {
// loop over table Names and build TablesInfo // loop over table Names and build Columns
for i := 0; i < len(cmdData.TableNames); i++ { for i := 0; i < len(cmdData.Tables); i++ {
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.TableNames[i]) tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.Tables[i])
if err != nil { if err != nil {
errorQuit(fmt.Errorf("Unable to get the table info: %s", err)) errorQuit(fmt.Errorf("Unable to get the table info: %s", err))
} }
cmdData.TablesInfo = append(cmdData.TablesInfo, tInfo) cmdData.Columns = append(cmdData.Columns, tInfo)
} }
} }
@ -178,13 +182,13 @@ func initOutFolder() {
// initTemplates loads all of the template files in the /cmds/templates directory // initTemplates loads all of the template files in the /cmds/templates directory
// and returns a slice of pointers to these templates. // and returns a slice of pointers to these templates.
func initTemplates() ([]*template.Template, error) { func initTemplates(dir string) ([]*template.Template, error) {
wd, err := os.Getwd() wd, err := os.Getwd()
if err != nil { if err != nil {
return nil, err return nil, err
} }
pattern := filepath.Join(wd, "/cmds/templates", "*.tpl") pattern := filepath.Join(wd, dir, "*.tpl")
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern) tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern)
if err != nil { if err != nil {

View file

@ -138,3 +138,15 @@ func selectParamNames(tableName string, columns []dbdrivers.DBColumn) string {
return strings.Join(selects, ", ") return strings.Join(selects, ", ")
} }
// scanParamNames takes a []DBColumn and returns a comma seperated
// list of parameter names for use in a db.Scan() call.
func scanParamNames(object string, columns []dbdrivers.DBColumn) string {
scans := make([]string, 0, len(columns))
for _, c := range columns {
statement := fmt.Sprintf("&%s.%s", object, titleCase(c.Name))
scans = append(scans, statement)
}
return strings.Join(scans, ", ")
}

View file

@ -71,3 +71,10 @@ func TestSelectParamFlags(t *testing.T) {
t.Error("Wrong output:", out) t.Error("Wrong output:", out)
} }
} }
func TestScanParams(t *testing.T) {
out := scanParamNames("object", testColumns)
if out != "&object.FriendColumn, &object.EnemyColumnThing" {
t.Error("Wrong output:", out)
}
}

View file

@ -1,12 +1,26 @@
{{- $tableName := .TableName -}} {{- $tableName := titleCase .Table -}}
// {{titleCase $tableName}}All retrieves all records. {{- $varName := camelCase $tableName -}}
func {{titleCase $tableName}}All(db boil.DB) ([]*{{titleCase $tableName}}, error) { // {{$tableName}}All retrieves all records.
{{$varName := camelCase $tableName -}} func {{$tableName}}All(db boil.DB) ([]*{{$tableName}}, error) {
var {{$varName}} []*{{titleCase $tableName}} var {{$varName}} []*{{$tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`)
rows, err := db.Query(`SELECT {{selectParamNames .Table .Columns}}`)
if err != nil { if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) return nil, fmt.Errorf("models: failed to query: %v", err)
}
for rows.Next() {
{{$varName}}Tmp := {{$tableName}}{}
if err := rows.Scan({{scanParamNames $varName .Columns}}); err != nil {
return nil, fmt.Errorf("models: failed to scan row: %v", err)
}
{{$varName}} = append({{$varName}}, {{$varName}}Tmp)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("models: failed to read rows: %v", err)
} }
return {{$varName}}, nil return {{$varName}}, nil

View file

@ -1,9 +1,9 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}}AllBy retrieves all records with the specified column values. // {{titleCase $tableName}}AllBy retrieves all records with the specified column values.
func {{titleCase $tableName}}AllBy(db boil.DB, columns map[string]interface{}) ([]*{{titleCase $tableName}}, error) { func {{titleCase $tableName}}AllBy(db boil.DB, columns map[string]interface{}) ([]*{{titleCase $tableName}}, error) {
{{$varName := camelCase $tableName -}} {{$varName := camelCase $tableName -}}
var {{$varName}} []*{{titleCase $tableName}} var {{$varName}} []*{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`) err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
if err != nil { if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}}Delete deletes a single record. // {{titleCase $tableName}}Delete deletes a single record.
func {{titleCase $tableName}}Delete(db boil.DB, id int) error { func {{titleCase $tableName}}Delete(db boil.DB, id int) error {
if id == nil { if id == nil {

View file

@ -1,11 +1,11 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}}FieldsAll retrieves the specified columns for all records. // {{titleCase $tableName}}FieldsAll retrieves the specified columns for all records.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve. // Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
// For example: friendName string `db:"friend_name"` // For example: friendName string `db:"friend_name"`
func {{titleCase $tableName}}FieldsAll(db boil.DB, results interface{}) error { func {{titleCase $tableName}}FieldsAll(db boil.DB, results interface{}) error {
{{$varName := camelCase $tableName -}} {{$varName := camelCase $tableName -}}
var {{$varName}} []*{{titleCase $tableName}} var {{$varName}} []*{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`) err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
if err != nil { if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}}FieldsAllBy retrieves the specified columns // {{titleCase $tableName}}FieldsAllBy retrieves the specified columns
// for all records with the specified column values. // for all records with the specified column values.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve. // Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
@ -6,7 +6,7 @@
func {{titleCase $tableName}}FieldsAllBy(db boil.DB, columns map[string]interface{}, results interface{}) error { func {{titleCase $tableName}}FieldsAllBy(db boil.DB, columns map[string]interface{}, results interface{}) error {
{{$varName := camelCase $tableName -}} {{$varName := camelCase $tableName -}}
var {{$varName}} []*{{titleCase $tableName}} var {{$varName}} []*{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`) err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
if err != nil { if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}}FieldsFind retrieves the specified columns for a single record by ID. // {{titleCase $tableName}}FieldsFind retrieves the specified columns for a single record by ID.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve. // Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
// For example: friendName string `db:"friend_name"` // For example: friendName string `db:"friend_name"`
@ -8,7 +8,7 @@ func {{titleCase $tableName}}FieldsFind(db boil.DB, id int, results interface{})
} }
{{$varName := camelCase $tableName}} {{$varName := camelCase $tableName}}
var {{$varName}} *{{titleCase $tableName}} var {{$varName}} *{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id) err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}}FieldsFindBy retrieves the specified columns // {{titleCase $tableName}}FieldsFindBy retrieves the specified columns
// for a single record with the specified column values. // for a single record with the specified column values.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve. // Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
@ -9,7 +9,7 @@ func {{titleCase $tableName}}FieldsFindBy(db boil.DB, columns map[string]interfa
} }
{{$varName := camelCase $tableName}} {{$varName := camelCase $tableName}}
var {{$varName}} *{{titleCase $tableName}} var {{$varName}} *{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id) err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}}Find retrieves a single record by ID. // {{titleCase $tableName}}Find retrieves a single record by ID.
func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}}, error) { func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}}, error) {
if id == 0 { if id == 0 {
@ -6,7 +6,7 @@ func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}}
} }
{{$varName := camelCase $tableName}} {{$varName := camelCase $tableName}}
var {{$varName}} *{{titleCase $tableName}} var {{$varName}} *{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id) err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}}FindBy retrieves a single record with the specified column values. // {{titleCase $tableName}}FindBy retrieves a single record with the specified column values.
func {{titleCase $tableName}}FindBy(db boil.DB, columns map[string]interface{}) (*{{titleCase $tableName}}, error) { func {{titleCase $tableName}}FindBy(db boil.DB, columns map[string]interface{}) (*{{titleCase $tableName}}, error) {
if id == 0 { if id == 0 {
@ -6,7 +6,7 @@ func {{titleCase $tableName}}FindBy(db boil.DB, columns map[string]interface{})
} }
{{$varName := camelCase $tableName}} {{$varName := camelCase $tableName}}
var {{$varName}} *{{titleCase $tableName}} var {{$varName}} *{{titleCase $tableName}}
err := db.Select(&{{$varName}}, fmt.Sprintf(`SELECT {{selectParamNames $tableName .TableData}} WHERE %s=$1`, column), value) err := db.Select(&{{$varName}}, fmt.Sprintf(`SELECT {{selectParamNames $tableName .Columns}} WHERE %s=$1`, column), value)
if err != nil { if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}}Insert inserts a single record. // {{titleCase $tableName}}Insert inserts a single record.
func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (int, error) { func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (int, error) {
if o == nil { if o == nil {
@ -6,7 +6,7 @@ func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (in
} }
var rowID int var rowID int
err := db.QueryRow(`INSERT INTO {{$tableName}} ({{insertParamNames .TableData}}) VALUES({{insertParamFlags .TableData}}) RETURNING id`) err := db.QueryRow(`INSERT INTO {{$tableName}} ({{insertParamNames .Columns}}) VALUES({{insertParamFlags .Columns}}) RETURNING id`)
if err != nil { if err != nil {
return 0, fmt.Errorf("model: unable to insert {{$tableName}}: %s", err) return 0, fmt.Errorf("model: unable to insert {{$tableName}}: %s", err)

View file

@ -1,7 +1,7 @@
{{- $tableName := .TableName -}} {{- $tableName := .Table -}}
// {{titleCase $tableName}} is an object representing the database table. // {{titleCase $tableName}} is an object representing the database table.
type {{titleCase $tableName}} struct { type {{titleCase $tableName}} struct {
{{range $key, $value := .TableData -}} {{range $key, $value := .Columns -}}
{{titleCase $value.Name}} {{$value.Type}} `db:"{{makeDBName $tableName $value.Name}}" json:"{{$value.Name}}"` {{titleCase $value.Name}} {{$value.Type}} `db:"{{makeDBName $tableName $value.Name}}" json:"{{$value.Name}}"`
{{end -}} {{end -}}
} }

View file

@ -4,12 +4,12 @@ package dbdrivers
// type of database connection. For example, queries to obtain schema data // type of database connection. For example, queries to obtain schema data
// will vary depending on what type of database software is in use. // will vary depending on what type of database software is in use.
// The goal of the DBDriver is to retrieve all table names in a database // The goal of the DBDriver is to retrieve all table names in a database
// using GetAllTableNames() if no table names are provided via flags, // using GetAllTables() if no table names are provided via flags,
// to handle the database connection using Open() and Close(), and to // to handle the database connection using Open() and Close(), and to
// build the table information using GetTableInfo() and ParseTableInfo() // build the table information using GetTableInfo() and ParseTableInfo()
type DBDriver interface { type DBDriver interface {
// GetAllTableNames connects to the database and retrieves all "public" table names // GetAllTables connects to the database and retrieves all "public" table names
GetAllTableNames() ([]string, error) GetAllTables() ([]string, error)
// GetTableInfo retrieves column information about the table. // GetTableInfo retrieves column information about the table.
GetTableInfo(tableName string) ([]DBColumn, error) GetTableInfo(tableName string) ([]DBColumn, error)

View file

@ -44,11 +44,11 @@ func (d *PostgresDriver) Close() {
d.dbConn.Close() d.dbConn.Close()
} }
// GetAllTableNames connects to the postgres database and // GetAllTables connects to the postgres database and
// retrieves all table names from the information_schema where the // retrieves all table names from the information_schema where the
// table schema is public. It excludes common migration tool tables // table schema is public. It excludes common migration tool tables
// such as gorp_migrations // such as gorp_migrations
func (d *PostgresDriver) GetAllTableNames() ([]string, error) { func (d *PostgresDriver) GetAllTables() ([]string, error) {
var tableNames []string var tableNames []string
rows, err := d.dbConn.Query(`select table_name from rows, err := d.dbConn.Query(`select table_name from