From 8a591ecdd6441bf30474e60a1c486ba879cc589c Mon Sep 17 00:00:00 2001 From: Aaron Date: Tue, 1 Mar 2016 20:05:25 -0800 Subject: [PATCH] More swaths of renaming awkward structs. --- cmds/boil.go | 6 ++--- cmds/commands.go | 1 + cmds/shared.go | 22 +++++++-------- cmds/sqlboiler.go | 48 ++++++++++++++++++--------------- cmds/template_funcs.go | 12 +++++++++ cmds/template_funcs_test.go | 7 +++++ cmds/templates/all.tpl | 30 +++++++++++++++------ cmds/templates/allby.tpl | 4 +-- cmds/templates/delete.tpl | 2 +- cmds/templates/fieldsall.tpl | 4 +-- cmds/templates/fieldsallby.tpl | 4 +-- cmds/templates/fieldsfind.tpl | 4 +-- cmds/templates/fieldsfindby.tpl | 4 +-- cmds/templates/find.tpl | 4 +-- cmds/templates/findby.tpl | 4 +-- cmds/templates/insert.tpl | 4 +-- cmds/templates/struct.tpl | 4 +-- dbdrivers/db_driver.go | 6 ++--- dbdrivers/postgres_driver.go | 4 +-- 19 files changed, 106 insertions(+), 68 deletions(-) diff --git a/cmds/boil.go b/cmds/boil.go index 8e0fc45..b4a33e8 100644 --- a/cmds/boil.go +++ b/cmds/boil.go @@ -44,10 +44,10 @@ func boilRun(cmd *cobra.Command, args []string) { // Prepend "struct" command to templateNames slice so it sits at top of sort templateNames = append([]string{"struct"}, templateNames...) - for i := 0; i < len(cmdData.TablesInfo); i++ { + for i := 0; i < len(cmdData.Columns); i++ { data := tplData{ - TableName: cmdData.TableNames[i], - TableData: cmdData.TablesInfo[i], + Table: cmdData.Tables[i], + Columns: cmdData.Columns[i], } var out [][]byte diff --git a/cmds/commands.go b/cmds/commands.go index b4a7062..4916612 100644 --- a/cmds/commands.go +++ b/cmds/commands.go @@ -69,6 +69,7 @@ var sqlBoilerTemplateFuncs = template.FuncMap{ "selectParamNames": selectParamNames, "insertParamNames": insertParamNames, "insertParamFlags": insertParamFlags, + "scanParamNames": scanParamNames, } var allCmd = &cobra.Command{ diff --git a/cmds/shared.go b/cmds/shared.go index 185de8a..54ccfb8 100644 --- a/cmds/shared.go +++ b/cmds/shared.go @@ -16,17 +16,17 @@ type CobraRunFunc func(cmd *cobra.Command, args []string) // the database driver chosen by the driver flag at runtime, and a pointer to the // output file, if one is specified with a flag. type CmdData struct { - TablesInfo [][]dbdrivers.DBColumn - TableNames []string - PkgName string - OutFolder string - DBDriver dbdrivers.DBDriver + Tables []string + Columns [][]dbdrivers.DBColumn + PkgName string + OutFolder string + DBDriver dbdrivers.DBDriver } // tplData is used to pass data to the template type tplData struct { - TableName string - TableData []dbdrivers.DBColumn + Table string + Columns []dbdrivers.DBColumn } // errorQuit displays an error message and then exits the application. @@ -39,10 +39,10 @@ func errorQuit(err error) { // It will generate the specific commands template and send it to outHandler for output. func defaultRun(cmd *cobra.Command, args []string) { // Generate the template for every table - for i := 0; i < len(cmdData.TablesInfo); i++ { + for i := 0; i < len(cmdData.Columns); i++ { data := tplData{ - TableName: cmdData.TableNames[i], - TableData: cmdData.TablesInfo[i], + Table: cmdData.Tables[i], + Columns: cmdData.Columns[i], } // outHandler takes a slice of byte slices, so append the Template @@ -72,7 +72,7 @@ func outHandler(output [][]byte, data *tplData) error { } } } else { // If not using stdout, attempt to create the model file. - path := cmdData.OutFolder + "/" + data.TableName + ".go" + path := cmdData.OutFolder + "/" + data.Table + ".go" out, err := os.Create(path) if err != nil { errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err)) diff --git a/cmds/sqlboiler.go b/cmds/sqlboiler.go index 1cfb028..a056a8b 100644 --- a/cmds/sqlboiler.go +++ b/cmds/sqlboiler.go @@ -13,6 +13,10 @@ import ( "github.com/spf13/cobra" ) +const ( + templatesDirectory = "/cmds/templates" +) + // cmdData is used globally by all commands to access the table schema data, // the database driver and the output file. cmdData is initialized by // the root SQLBoiler cobra command at run time, before other commands execute. @@ -68,11 +72,11 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) { errorQuit(fmt.Errorf("Unable to connect to the database: %s", err)) } - // Initialize the cmdData.TableNames - initTableNames() + // Initialize the cmdData.Tables + initTables() - // Initialize the cmdData.TablesInfo - initTablesInfo() + // Initialize the cmdData.Columns + initColumns() // Initialize the package name initPkgName() @@ -81,7 +85,7 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) { initOutFolder() // Initialize the templates - templates, err = initTemplates() + templates, err = initTemplates(templatesDirectory) if err != nil { errorQuit(fmt.Errorf("Unable to initialize templates: %s", err)) } @@ -113,46 +117,46 @@ func initDBDriver() { } } -// initTableNames will create a string slice out of the passed in table flag value +// initTables will create a string slice out of the passed in table flag value // if one is provided. If no flag is provided, it will attempt to connect to the // database to retrieve all "public" table names, and build a slice out of that result. -func initTableNames() { +func initTables() { // Retrieve the list of tables tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String() if len(tn) != 0 { - cmdData.TableNames = strings.Split(tn, ",") - for i, name := range cmdData.TableNames { - cmdData.TableNames[i] = strings.TrimSpace(name) + cmdData.Tables = strings.Split(tn, ",") + for i, name := range cmdData.Tables { + cmdData.Tables[i] = strings.TrimSpace(name) } } // If no table names are provided attempt to process all tables in database - if len(cmdData.TableNames) == 0 { + if len(cmdData.Tables) == 0 { // get all table names var err error - cmdData.TableNames, err = cmdData.DBDriver.GetAllTableNames() + cmdData.Tables, err = cmdData.DBDriver.GetAllTables() if err != nil { errorQuit(fmt.Errorf("Unable to get all table names: %s", err)) } - if len(cmdData.TableNames) == 0 { + if len(cmdData.Tables) == 0 { errorQuit(errors.New("No tables found in database, migrate some tables first")) } } } -// initTablesInfo builds a description of each table (column name, column type) -// and assigns it to cmdData.TablesInfo, the slice of dbdrivers.DBColumn slices. -func initTablesInfo() { - // loop over table Names and build TablesInfo - for i := 0; i < len(cmdData.TableNames); i++ { - tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.TableNames[i]) +// initColumns builds a description of each table (column name, column type) +// and assigns it to cmdData.Columns, the slice of dbdrivers.DBColumn slices. +func initColumns() { + // loop over table Names and build Columns + for i := 0; i < len(cmdData.Tables); i++ { + tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.Tables[i]) if err != nil { errorQuit(fmt.Errorf("Unable to get the table info: %s", err)) } - cmdData.TablesInfo = append(cmdData.TablesInfo, tInfo) + cmdData.Columns = append(cmdData.Columns, tInfo) } } @@ -178,13 +182,13 @@ func initOutFolder() { // initTemplates loads all of the template files in the /cmds/templates directory // and returns a slice of pointers to these templates. -func initTemplates() ([]*template.Template, error) { +func initTemplates(dir string) ([]*template.Template, error) { wd, err := os.Getwd() if err != nil { return nil, err } - pattern := filepath.Join(wd, "/cmds/templates", "*.tpl") + pattern := filepath.Join(wd, dir, "*.tpl") tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern) if err != nil { diff --git a/cmds/template_funcs.go b/cmds/template_funcs.go index 29ad21f..0bbc01b 100644 --- a/cmds/template_funcs.go +++ b/cmds/template_funcs.go @@ -138,3 +138,15 @@ func selectParamNames(tableName string, columns []dbdrivers.DBColumn) string { return strings.Join(selects, ", ") } + +// scanParamNames takes a []DBColumn and returns a comma seperated +// list of parameter names for use in a db.Scan() call. +func scanParamNames(object string, columns []dbdrivers.DBColumn) string { + scans := make([]string, 0, len(columns)) + for _, c := range columns { + statement := fmt.Sprintf("&%s.%s", object, titleCase(c.Name)) + scans = append(scans, statement) + } + + return strings.Join(scans, ", ") +} diff --git a/cmds/template_funcs_test.go b/cmds/template_funcs_test.go index 77b9952..96b33ce 100644 --- a/cmds/template_funcs_test.go +++ b/cmds/template_funcs_test.go @@ -71,3 +71,10 @@ func TestSelectParamFlags(t *testing.T) { t.Error("Wrong output:", out) } } + +func TestScanParams(t *testing.T) { + out := scanParamNames("object", testColumns) + if out != "&object.FriendColumn, &object.EnemyColumnThing" { + t.Error("Wrong output:", out) + } +} diff --git a/cmds/templates/all.tpl b/cmds/templates/all.tpl index 54f0ffd..d5ec46b 100644 --- a/cmds/templates/all.tpl +++ b/cmds/templates/all.tpl @@ -1,13 +1,27 @@ -{{- $tableName := .TableName -}} -// {{titleCase $tableName}}All retrieves all records. -func {{titleCase $tableName}}All(db boil.DB) ([]*{{titleCase $tableName}}, error) { - {{$varName := camelCase $tableName -}} - var {{$varName}} []*{{titleCase $tableName}} - err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`) +{{- $tableName := titleCase .Table -}} +{{- $varName := camelCase $tableName -}} +// {{$tableName}}All retrieves all records. +func {{$tableName}}All(db boil.DB) ([]*{{$tableName}}, error) { + var {{$varName}} []*{{$tableName}} + rows, err := db.Query(`SELECT {{selectParamNames .Table .Columns}}`) if err != nil { - return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) + return nil, fmt.Errorf("models: failed to query: %v", err) } - return {{$varName}}, nil + for rows.Next() { + {{$varName}}Tmp := {{$tableName}}{} + + if err := rows.Scan({{scanParamNames $varName .Columns}}); err != nil { + return nil, fmt.Errorf("models: failed to scan row: %v", err) + } + + {{$varName}} = append({{$varName}}, {{$varName}}Tmp) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("models: failed to read rows: %v", err) + } + + return {{$varName}}, nil } diff --git a/cmds/templates/allby.tpl b/cmds/templates/allby.tpl index d3d6340..6a22ccb 100644 --- a/cmds/templates/allby.tpl +++ b/cmds/templates/allby.tpl @@ -1,9 +1,9 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}}AllBy retrieves all records with the specified column values. func {{titleCase $tableName}}AllBy(db boil.DB, columns map[string]interface{}) ([]*{{titleCase $tableName}}, error) { {{$varName := camelCase $tableName -}} var {{$varName}} []*{{titleCase $tableName}} - err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`) + err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`) if err != nil { return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) diff --git a/cmds/templates/delete.tpl b/cmds/templates/delete.tpl index db0bb6a..03384d3 100644 --- a/cmds/templates/delete.tpl +++ b/cmds/templates/delete.tpl @@ -1,4 +1,4 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}}Delete deletes a single record. func {{titleCase $tableName}}Delete(db boil.DB, id int) error { if id == nil { diff --git a/cmds/templates/fieldsall.tpl b/cmds/templates/fieldsall.tpl index 4cd0ac0..5796b16 100644 --- a/cmds/templates/fieldsall.tpl +++ b/cmds/templates/fieldsall.tpl @@ -1,11 +1,11 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}}FieldsAll retrieves the specified columns for all records. // Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve. // For example: friendName string `db:"friend_name"` func {{titleCase $tableName}}FieldsAll(db boil.DB, results interface{}) error { {{$varName := camelCase $tableName -}} var {{$varName}} []*{{titleCase $tableName}} - err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`) + err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`) if err != nil { return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) diff --git a/cmds/templates/fieldsallby.tpl b/cmds/templates/fieldsallby.tpl index 0eb24c8..16a7fbd 100644 --- a/cmds/templates/fieldsallby.tpl +++ b/cmds/templates/fieldsallby.tpl @@ -1,4 +1,4 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}}FieldsAllBy retrieves the specified columns // for all records with the specified column values. // Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve. @@ -6,7 +6,7 @@ func {{titleCase $tableName}}FieldsAllBy(db boil.DB, columns map[string]interface{}, results interface{}) error { {{$varName := camelCase $tableName -}} var {{$varName}} []*{{titleCase $tableName}} - err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`) + err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`) if err != nil { return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) diff --git a/cmds/templates/fieldsfind.tpl b/cmds/templates/fieldsfind.tpl index 7b47824..f8b181a 100644 --- a/cmds/templates/fieldsfind.tpl +++ b/cmds/templates/fieldsfind.tpl @@ -1,4 +1,4 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}}FieldsFind retrieves the specified columns for a single record by ID. // Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve. // For example: friendName string `db:"friend_name"` @@ -8,7 +8,7 @@ func {{titleCase $tableName}}FieldsFind(db boil.DB, id int, results interface{}) } {{$varName := camelCase $tableName}} var {{$varName}} *{{titleCase $tableName}} - err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id) + err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id) if err != nil { return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) diff --git a/cmds/templates/fieldsfindby.tpl b/cmds/templates/fieldsfindby.tpl index 4e3d114..6c9ce42 100644 --- a/cmds/templates/fieldsfindby.tpl +++ b/cmds/templates/fieldsfindby.tpl @@ -1,4 +1,4 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}}FieldsFindBy retrieves the specified columns // for a single record with the specified column values. // Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve. @@ -9,7 +9,7 @@ func {{titleCase $tableName}}FieldsFindBy(db boil.DB, columns map[string]interfa } {{$varName := camelCase $tableName}} var {{$varName}} *{{titleCase $tableName}} - err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id) + err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id) if err != nil { return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) diff --git a/cmds/templates/find.tpl b/cmds/templates/find.tpl index fe42fb7..0a5fd32 100644 --- a/cmds/templates/find.tpl +++ b/cmds/templates/find.tpl @@ -1,4 +1,4 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}}Find retrieves a single record by ID. func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}}, error) { if id == 0 { @@ -6,7 +6,7 @@ func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}} } {{$varName := camelCase $tableName}} var {{$varName}} *{{titleCase $tableName}} - err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id) + err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id) if err != nil { return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) diff --git a/cmds/templates/findby.tpl b/cmds/templates/findby.tpl index ad856e3..004d75e 100644 --- a/cmds/templates/findby.tpl +++ b/cmds/templates/findby.tpl @@ -1,4 +1,4 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}}FindBy retrieves a single record with the specified column values. func {{titleCase $tableName}}FindBy(db boil.DB, columns map[string]interface{}) (*{{titleCase $tableName}}, error) { if id == 0 { @@ -6,7 +6,7 @@ func {{titleCase $tableName}}FindBy(db boil.DB, columns map[string]interface{}) } {{$varName := camelCase $tableName}} var {{$varName}} *{{titleCase $tableName}} - err := db.Select(&{{$varName}}, fmt.Sprintf(`SELECT {{selectParamNames $tableName .TableData}} WHERE %s=$1`, column), value) + err := db.Select(&{{$varName}}, fmt.Sprintf(`SELECT {{selectParamNames $tableName .Columns}} WHERE %s=$1`, column), value) if err != nil { return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) diff --git a/cmds/templates/insert.tpl b/cmds/templates/insert.tpl index 16a3b3c..0860fd7 100644 --- a/cmds/templates/insert.tpl +++ b/cmds/templates/insert.tpl @@ -1,4 +1,4 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}}Insert inserts a single record. func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (int, error) { if o == nil { @@ -6,7 +6,7 @@ func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (in } var rowID int - err := db.QueryRow(`INSERT INTO {{$tableName}} ({{insertParamNames .TableData}}) VALUES({{insertParamFlags .TableData}}) RETURNING id`) + err := db.QueryRow(`INSERT INTO {{$tableName}} ({{insertParamNames .Columns}}) VALUES({{insertParamFlags .Columns}}) RETURNING id`) if err != nil { return 0, fmt.Errorf("model: unable to insert {{$tableName}}: %s", err) diff --git a/cmds/templates/struct.tpl b/cmds/templates/struct.tpl index 11a6bfc..5ad76b8 100644 --- a/cmds/templates/struct.tpl +++ b/cmds/templates/struct.tpl @@ -1,7 +1,7 @@ -{{- $tableName := .TableName -}} +{{- $tableName := .Table -}} // {{titleCase $tableName}} is an object representing the database table. type {{titleCase $tableName}} struct { - {{range $key, $value := .TableData -}} + {{range $key, $value := .Columns -}} {{titleCase $value.Name}} {{$value.Type}} `db:"{{makeDBName $tableName $value.Name}}" json:"{{$value.Name}}"` {{end -}} } diff --git a/dbdrivers/db_driver.go b/dbdrivers/db_driver.go index 058f81c..82d6126 100644 --- a/dbdrivers/db_driver.go +++ b/dbdrivers/db_driver.go @@ -4,12 +4,12 @@ package dbdrivers // type of database connection. For example, queries to obtain schema data // will vary depending on what type of database software is in use. // The goal of the DBDriver is to retrieve all table names in a database -// using GetAllTableNames() if no table names are provided via flags, +// using GetAllTables() if no table names are provided via flags, // to handle the database connection using Open() and Close(), and to // build the table information using GetTableInfo() and ParseTableInfo() type DBDriver interface { - // GetAllTableNames connects to the database and retrieves all "public" table names - GetAllTableNames() ([]string, error) + // GetAllTables connects to the database and retrieves all "public" table names + GetAllTables() ([]string, error) // GetTableInfo retrieves column information about the table. GetTableInfo(tableName string) ([]DBColumn, error) diff --git a/dbdrivers/postgres_driver.go b/dbdrivers/postgres_driver.go index aef2576..fc7bbdd 100644 --- a/dbdrivers/postgres_driver.go +++ b/dbdrivers/postgres_driver.go @@ -44,11 +44,11 @@ func (d *PostgresDriver) Close() { d.dbConn.Close() } -// GetAllTableNames connects to the postgres database and +// GetAllTables connects to the postgres database and // retrieves all table names from the information_schema where the // table schema is public. It excludes common migration tool tables // such as gorp_migrations -func (d *PostgresDriver) GetAllTableNames() ([]string, error) { +func (d *PostgresDriver) GetAllTables() ([]string, error) { var tableNames []string rows, err := d.dbConn.Query(`select table_name from