More swaths of renaming awkward structs.

This commit is contained in:
Aaron 2016-03-01 20:05:25 -08:00
parent 4f619d4efe
commit 8a591ecdd6
19 changed files with 106 additions and 68 deletions

View file

@ -44,10 +44,10 @@ func boilRun(cmd *cobra.Command, args []string) {
// Prepend "struct" command to templateNames slice so it sits at top of sort
templateNames = append([]string{"struct"}, templateNames...)
for i := 0; i < len(cmdData.TablesInfo); i++ {
for i := 0; i < len(cmdData.Columns); i++ {
data := tplData{
TableName: cmdData.TableNames[i],
TableData: cmdData.TablesInfo[i],
Table: cmdData.Tables[i],
Columns: cmdData.Columns[i],
}
var out [][]byte

View file

@ -69,6 +69,7 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
"selectParamNames": selectParamNames,
"insertParamNames": insertParamNames,
"insertParamFlags": insertParamFlags,
"scanParamNames": scanParamNames,
}
var allCmd = &cobra.Command{

View file

@ -16,17 +16,17 @@ type CobraRunFunc func(cmd *cobra.Command, args []string)
// the database driver chosen by the driver flag at runtime, and a pointer to the
// output file, if one is specified with a flag.
type CmdData struct {
TablesInfo [][]dbdrivers.DBColumn
TableNames []string
PkgName string
OutFolder string
DBDriver dbdrivers.DBDriver
Tables []string
Columns [][]dbdrivers.DBColumn
PkgName string
OutFolder string
DBDriver dbdrivers.DBDriver
}
// tplData is used to pass data to the template
type tplData struct {
TableName string
TableData []dbdrivers.DBColumn
Table string
Columns []dbdrivers.DBColumn
}
// errorQuit displays an error message and then exits the application.
@ -39,10 +39,10 @@ func errorQuit(err error) {
// It will generate the specific commands template and send it to outHandler for output.
func defaultRun(cmd *cobra.Command, args []string) {
// Generate the template for every table
for i := 0; i < len(cmdData.TablesInfo); i++ {
for i := 0; i < len(cmdData.Columns); i++ {
data := tplData{
TableName: cmdData.TableNames[i],
TableData: cmdData.TablesInfo[i],
Table: cmdData.Tables[i],
Columns: cmdData.Columns[i],
}
// outHandler takes a slice of byte slices, so append the Template
@ -72,7 +72,7 @@ func outHandler(output [][]byte, data *tplData) error {
}
}
} else { // If not using stdout, attempt to create the model file.
path := cmdData.OutFolder + "/" + data.TableName + ".go"
path := cmdData.OutFolder + "/" + data.Table + ".go"
out, err := os.Create(path)
if err != nil {
errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err))

View file

@ -13,6 +13,10 @@ import (
"github.com/spf13/cobra"
)
const (
templatesDirectory = "/cmds/templates"
)
// cmdData is used globally by all commands to access the table schema data,
// the database driver and the output file. cmdData is initialized by
// the root SQLBoiler cobra command at run time, before other commands execute.
@ -68,11 +72,11 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
errorQuit(fmt.Errorf("Unable to connect to the database: %s", err))
}
// Initialize the cmdData.TableNames
initTableNames()
// Initialize the cmdData.Tables
initTables()
// Initialize the cmdData.TablesInfo
initTablesInfo()
// Initialize the cmdData.Columns
initColumns()
// Initialize the package name
initPkgName()
@ -81,7 +85,7 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
initOutFolder()
// Initialize the templates
templates, err = initTemplates()
templates, err = initTemplates(templatesDirectory)
if err != nil {
errorQuit(fmt.Errorf("Unable to initialize templates: %s", err))
}
@ -113,46 +117,46 @@ func initDBDriver() {
}
}
// initTableNames will create a string slice out of the passed in table flag value
// initTables will create a string slice out of the passed in table flag value
// if one is provided. If no flag is provided, it will attempt to connect to the
// database to retrieve all "public" table names, and build a slice out of that result.
func initTableNames() {
func initTables() {
// Retrieve the list of tables
tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String()
if len(tn) != 0 {
cmdData.TableNames = strings.Split(tn, ",")
for i, name := range cmdData.TableNames {
cmdData.TableNames[i] = strings.TrimSpace(name)
cmdData.Tables = strings.Split(tn, ",")
for i, name := range cmdData.Tables {
cmdData.Tables[i] = strings.TrimSpace(name)
}
}
// If no table names are provided attempt to process all tables in database
if len(cmdData.TableNames) == 0 {
if len(cmdData.Tables) == 0 {
// get all table names
var err error
cmdData.TableNames, err = cmdData.DBDriver.GetAllTableNames()
cmdData.Tables, err = cmdData.DBDriver.GetAllTables()
if err != nil {
errorQuit(fmt.Errorf("Unable to get all table names: %s", err))
}
if len(cmdData.TableNames) == 0 {
if len(cmdData.Tables) == 0 {
errorQuit(errors.New("No tables found in database, migrate some tables first"))
}
}
}
// initTablesInfo builds a description of each table (column name, column type)
// and assigns it to cmdData.TablesInfo, the slice of dbdrivers.DBColumn slices.
func initTablesInfo() {
// loop over table Names and build TablesInfo
for i := 0; i < len(cmdData.TableNames); i++ {
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.TableNames[i])
// initColumns builds a description of each table (column name, column type)
// and assigns it to cmdData.Columns, the slice of dbdrivers.DBColumn slices.
func initColumns() {
// loop over table Names and build Columns
for i := 0; i < len(cmdData.Tables); i++ {
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.Tables[i])
if err != nil {
errorQuit(fmt.Errorf("Unable to get the table info: %s", err))
}
cmdData.TablesInfo = append(cmdData.TablesInfo, tInfo)
cmdData.Columns = append(cmdData.Columns, tInfo)
}
}
@ -178,13 +182,13 @@ func initOutFolder() {
// initTemplates loads all of the template files in the /cmds/templates directory
// and returns a slice of pointers to these templates.
func initTemplates() ([]*template.Template, error) {
func initTemplates(dir string) ([]*template.Template, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
pattern := filepath.Join(wd, "/cmds/templates", "*.tpl")
pattern := filepath.Join(wd, dir, "*.tpl")
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern)
if err != nil {

View file

@ -138,3 +138,15 @@ func selectParamNames(tableName string, columns []dbdrivers.DBColumn) string {
return strings.Join(selects, ", ")
}
// scanParamNames takes a []DBColumn and returns a comma seperated
// list of parameter names for use in a db.Scan() call.
func scanParamNames(object string, columns []dbdrivers.DBColumn) string {
scans := make([]string, 0, len(columns))
for _, c := range columns {
statement := fmt.Sprintf("&%s.%s", object, titleCase(c.Name))
scans = append(scans, statement)
}
return strings.Join(scans, ", ")
}

View file

@ -71,3 +71,10 @@ func TestSelectParamFlags(t *testing.T) {
t.Error("Wrong output:", out)
}
}
func TestScanParams(t *testing.T) {
out := scanParamNames("object", testColumns)
if out != "&object.FriendColumn, &object.EnemyColumnThing" {
t.Error("Wrong output:", out)
}
}

View file

@ -1,13 +1,27 @@
{{- $tableName := .TableName -}}
// {{titleCase $tableName}}All retrieves all records.
func {{titleCase $tableName}}All(db boil.DB) ([]*{{titleCase $tableName}}, error) {
{{$varName := camelCase $tableName -}}
var {{$varName}} []*{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`)
{{- $tableName := titleCase .Table -}}
{{- $varName := camelCase $tableName -}}
// {{$tableName}}All retrieves all records.
func {{$tableName}}All(db boil.DB) ([]*{{$tableName}}, error) {
var {{$varName}} []*{{$tableName}}
rows, err := db.Query(`SELECT {{selectParamNames .Table .Columns}}`)
if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
return nil, fmt.Errorf("models: failed to query: %v", err)
}
return {{$varName}}, nil
for rows.Next() {
{{$varName}}Tmp := {{$tableName}}{}
if err := rows.Scan({{scanParamNames $varName .Columns}}); err != nil {
return nil, fmt.Errorf("models: failed to scan row: %v", err)
}
{{$varName}} = append({{$varName}}, {{$varName}}Tmp)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("models: failed to read rows: %v", err)
}
return {{$varName}}, nil
}

View file

@ -1,9 +1,9 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}}AllBy retrieves all records with the specified column values.
func {{titleCase $tableName}}AllBy(db boil.DB, columns map[string]interface{}) ([]*{{titleCase $tableName}}, error) {
{{$varName := camelCase $tableName -}}
var {{$varName}} []*{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`)
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}}Delete deletes a single record.
func {{titleCase $tableName}}Delete(db boil.DB, id int) error {
if id == nil {

View file

@ -1,11 +1,11 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}}FieldsAll retrieves the specified columns for all records.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
// For example: friendName string `db:"friend_name"`
func {{titleCase $tableName}}FieldsAll(db boil.DB, results interface{}) error {
{{$varName := camelCase $tableName -}}
var {{$varName}} []*{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`)
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}}FieldsAllBy retrieves the specified columns
// for all records with the specified column values.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
@ -6,7 +6,7 @@
func {{titleCase $tableName}}FieldsAllBy(db boil.DB, columns map[string]interface{}, results interface{}) error {
{{$varName := camelCase $tableName -}}
var {{$varName}} []*{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`)
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}}FieldsFind retrieves the specified columns for a single record by ID.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
// For example: friendName string `db:"friend_name"`
@ -8,7 +8,7 @@ func {{titleCase $tableName}}FieldsFind(db boil.DB, id int, results interface{})
}
{{$varName := camelCase $tableName}}
var {{$varName}} *{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id)
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}}FieldsFindBy retrieves the specified columns
// for a single record with the specified column values.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
@ -9,7 +9,7 @@ func {{titleCase $tableName}}FieldsFindBy(db boil.DB, columns map[string]interfa
}
{{$varName := camelCase $tableName}}
var {{$varName}} *{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id)
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}}Find retrieves a single record by ID.
func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}}, error) {
if id == 0 {
@ -6,7 +6,7 @@ func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}}
}
{{$varName := camelCase $tableName}}
var {{$varName}} *{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id)
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}}FindBy retrieves a single record with the specified column values.
func {{titleCase $tableName}}FindBy(db boil.DB, columns map[string]interface{}) (*{{titleCase $tableName}}, error) {
if id == 0 {
@ -6,7 +6,7 @@ func {{titleCase $tableName}}FindBy(db boil.DB, columns map[string]interface{})
}
{{$varName := camelCase $tableName}}
var {{$varName}} *{{titleCase $tableName}}
err := db.Select(&{{$varName}}, fmt.Sprintf(`SELECT {{selectParamNames $tableName .TableData}} WHERE %s=$1`, column), value)
err := db.Select(&{{$varName}}, fmt.Sprintf(`SELECT {{selectParamNames $tableName .Columns}} WHERE %s=$1`, column), value)
if err != nil {
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)

View file

@ -1,4 +1,4 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}}Insert inserts a single record.
func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (int, error) {
if o == nil {
@ -6,7 +6,7 @@ func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (in
}
var rowID int
err := db.QueryRow(`INSERT INTO {{$tableName}} ({{insertParamNames .TableData}}) VALUES({{insertParamFlags .TableData}}) RETURNING id`)
err := db.QueryRow(`INSERT INTO {{$tableName}} ({{insertParamNames .Columns}}) VALUES({{insertParamFlags .Columns}}) RETURNING id`)
if err != nil {
return 0, fmt.Errorf("model: unable to insert {{$tableName}}: %s", err)

View file

@ -1,7 +1,7 @@
{{- $tableName := .TableName -}}
{{- $tableName := .Table -}}
// {{titleCase $tableName}} is an object representing the database table.
type {{titleCase $tableName}} struct {
{{range $key, $value := .TableData -}}
{{range $key, $value := .Columns -}}
{{titleCase $value.Name}} {{$value.Type}} `db:"{{makeDBName $tableName $value.Name}}" json:"{{$value.Name}}"`
{{end -}}
}

View file

@ -4,12 +4,12 @@ package dbdrivers
// type of database connection. For example, queries to obtain schema data
// will vary depending on what type of database software is in use.
// The goal of the DBDriver is to retrieve all table names in a database
// using GetAllTableNames() if no table names are provided via flags,
// using GetAllTables() if no table names are provided via flags,
// to handle the database connection using Open() and Close(), and to
// build the table information using GetTableInfo() and ParseTableInfo()
type DBDriver interface {
// GetAllTableNames connects to the database and retrieves all "public" table names
GetAllTableNames() ([]string, error)
// GetAllTables connects to the database and retrieves all "public" table names
GetAllTables() ([]string, error)
// GetTableInfo retrieves column information about the table.
GetTableInfo(tableName string) ([]DBColumn, error)

View file

@ -44,11 +44,11 @@ func (d *PostgresDriver) Close() {
d.dbConn.Close()
}
// GetAllTableNames connects to the postgres database and
// GetAllTables connects to the postgres database and
// retrieves all table names from the information_schema where the
// table schema is public. It excludes common migration tool tables
// such as gorp_migrations
func (d *PostgresDriver) GetAllTableNames() ([]string, error) {
func (d *PostgresDriver) GetAllTables() ([]string, error) {
var tableNames []string
rows, err := d.dbConn.Query(`select table_name from