Fixed compiler errors and broken tests

This commit is contained in:
Patrick O'brien 2016-04-04 20:28:58 +10:00
parent 8e0773b43e
commit c84d35d394
7 changed files with 172 additions and 34 deletions

View file

@ -85,7 +85,9 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
"insertParamVariables": insertParamVariables, "insertParamVariables": insertParamVariables,
"scanParamNames": scanParamNames, "scanParamNames": scanParamNames,
"hasPrimaryKey": hasPrimaryKey, "hasPrimaryKey": hasPrimaryKey,
"getPrimaryKey": getPrimaryKey, "wherePrimaryKey": wherePrimaryKey,
"paramsPrimaryKey": paramsPrimaryKey,
"primaryKeyFlagIndex": primaryKeyFlagIndex,
"updateParamNames": updateParamNames, "updateParamNames": updateParamNames,
"updateParamVariables": updateParamVariables, "updateParamVariables": updateParamVariables,
} }
@ -104,6 +106,6 @@ func (c *CmdData) LoadConfigFile(filename string) error {
return fmt.Errorf("Failed to decode toml configuration file: %s", err) return fmt.Errorf("Failed to decode toml configuration file: %s", err)
} }
cmdData.Config = cfg c.Config = cfg
return nil return nil
} }

View file

@ -170,7 +170,7 @@ func initTables(tableName string, cmdData *CmdData) error {
} }
var err error var err error
cmdData.Tables, err = cmdData.Interface.Tables(tableNames...) cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...)
if err != nil { if err != nil {
return fmt.Errorf("Unable to get all table names: %s", err) return fmt.Errorf("Unable to get all table names: %s", err)
} }

View file

@ -107,11 +107,11 @@ func makeDBName(tableName, colName string) string {
// list of parameter names for the update statement template SET clause. // list of parameter names for the update statement template SET clause.
// eg: col1=$1,col2=$2,col3=$3 // eg: col1=$1,col2=$2,col3=$3
// Note: updateParamNames will exclude the PRIMARY KEY column. // Note: updateParamNames will exclude the PRIMARY KEY column.
func updateParamNames(columns []dbdrivers.Column) string { func updateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string {
names := make([]string, 0, len(columns)) names := make([]string, 0, len(columns))
counter := 0 counter := 0
for _, c := range columns { for _, c := range columns {
if c.IsPrimaryKey { if isPrimaryKey(c.Name, pkeyColumns) {
continue continue
} }
counter++ counter++
@ -124,11 +124,11 @@ func updateParamNames(columns []dbdrivers.Column) string {
// comma seperated list of parameter variable names for the update statement. // comma seperated list of parameter variable names for the update statement.
// eg: prefix("o."), column("name_id") -> "o.NameID, ..." // eg: prefix("o."), column("name_id") -> "o.NameID, ..."
// Note: updateParamVariables will exclude the PRIMARY KEY column. // Note: updateParamVariables will exclude the PRIMARY KEY column.
func updateParamVariables(prefix string, columns []dbdrivers.Column) string { func updateParamVariables(prefix string, columns []dbdrivers.Column, pkeyColumns []string) string {
names := make([]string, 0, len(columns)) names := make([]string, 0, len(columns))
for _, c := range columns { for _, c := range columns {
if c.IsPrimaryKey { if isPrimaryKey(c.Name, pkeyColumns) {
continue continue
} }
n := prefix + titleCase(c.Name) n := prefix + titleCase(c.Name)
@ -138,6 +138,17 @@ func updateParamVariables(prefix string, columns []dbdrivers.Column) string {
return strings.Join(names, ", ") return strings.Join(names, ", ")
} }
// isPrimaryKey checks if the column is found in the primary key columns
func isPrimaryKey(col string, pkeyCols []string) bool {
for _, pkey := range pkeyCols {
if pkey == col {
return true
}
}
return false
}
// insertParamNames takes a []Column and returns a comma seperated // insertParamNames takes a []Column and returns a comma seperated
// list of parameter names for the insert statement template. // list of parameter names for the insert statement template.
func insertParamNames(columns []dbdrivers.Column) string { func insertParamNames(columns []dbdrivers.Column) string {
@ -199,23 +210,48 @@ func scanParamNames(object string, columns []dbdrivers.Column) string {
} }
// hasPrimaryKey returns true if one of the columns passed in is a primary key // hasPrimaryKey returns true if one of the columns passed in is a primary key
func hasPrimaryKey(columns []dbdrivers.Column) bool { func hasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool {
for _, c := range columns { if pKey == nil || len(pKey.Columns) == 0 {
if c.IsPrimaryKey {
return true
}
}
return false return false
}
return true
} }
// getPrimaryKey returns the primary key column name if one is present // wherePrimaryKey returns the where clause using start as the $ flag index
func getPrimaryKey(columns []dbdrivers.Column) string { // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
func wherePrimaryKey(pkeyCols []string, start int) string {
var output string
for i, c := range pkeyCols {
output = fmt.Sprintf("%s%s=$%d", output, c, start)
start++
if i < len(pkeyCols)-1 {
output = fmt.Sprintf("%s AND ", output)
}
}
return output
}
// paramsPrimaryKey returns the parameters for the sql statement $ flags
// For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2"
func paramsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string {
names := make([]string, 0, len(columns))
for _, c := range columns { for _, c := range columns {
if c.IsPrimaryKey { var n string
return c.Name if shouldTitleCase {
n = prefix + titleCase(c)
} else {
n = prefix + c
} }
names = append(names, n)
} }
return "" return strings.Join(names, ", ")
}
func primaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
return len(regularCols) - len(pkeyCols) + 1
} }

View file

@ -7,8 +7,8 @@ import (
) )
var testColumns = []dbdrivers.Column{ var testColumns = []dbdrivers.Column{
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false}, {Name: "friend_column", Type: "int", IsNullable: false},
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false}, {Name: "enemy_column_thing", Type: "string", IsNullable: true},
} }
func TestSingular(t *testing.T) { func TestSingular(t *testing.T) {
@ -99,12 +99,12 @@ func TestUpdateParamNames(t *testing.T) {
t.Parallel() t.Parallel()
var testCols = []dbdrivers.Column{ var testCols = []dbdrivers.Column{
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true}, {Name: "id", Type: "int", IsNullable: false},
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false}, {Name: "friend_column", Type: "int", IsNullable: false},
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false}, {Name: "enemy_column_thing", Type: "string", IsNullable: true},
} }
out := updateParamNames(testCols) out := updateParamNames(testCols, []string{"id"})
if out != "friend_column=$1,enemy_column_thing=$2" { if out != "friend_column=$1,enemy_column_thing=$2" {
t.Error("Wrong output:", out) t.Error("Wrong output:", out)
} }
@ -114,12 +114,12 @@ func TestUpdateParamVariables(t *testing.T) {
t.Parallel() t.Parallel()
var testCols = []dbdrivers.Column{ var testCols = []dbdrivers.Column{
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true}, {Name: "id", Type: "int", IsNullable: false},
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false}, {Name: "friend_column", Type: "int", IsNullable: false},
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false}, {Name: "enemy_column_thing", Type: "string", IsNullable: true},
} }
out := updateParamVariables("o.", testCols) out := updateParamVariables("o.", testCols, []string{"id"})
if out != "o.FriendColumn, o.EnemyColumnThing" { if out != "o.FriendColumn, o.EnemyColumnThing" {
t.Error("Wrong output:", out) t.Error("Wrong output:", out)
} }
@ -167,3 +167,99 @@ func TestScanParams(t *testing.T) {
t.Error("Wrong output:", out) t.Error("Wrong output:", out)
} }
} }
func TestHasPrimaryKey(t *testing.T) {
t.Parallel()
var pkey *dbdrivers.PrimaryKey
if hasPrimaryKey(pkey) {
t.Errorf("1) Expected false, got true")
}
pkey = &dbdrivers.PrimaryKey{}
if hasPrimaryKey(pkey) {
t.Errorf("2) Expected false, got true")
}
pkey.Columns = append(pkey.Columns, "test")
if !hasPrimaryKey(pkey) {
t.Errorf("3) Expected true, got false")
}
}
func TestParamsPrimaryKey(t *testing.T) {
t.Parallel()
tests := []struct {
Pkey dbdrivers.PrimaryKey
Prefix string
Should string
}{
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
Prefix: "o.", Should: "o.ColOne",
},
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
Prefix: "o.", Should: "o.ColOne, o.ColTwo",
},
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
Prefix: "o.", Should: "o.ColOne, o.ColTwo, o.ColThree",
},
}
for i, test := range tests {
r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, true)
if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
}
}
tests2 := []struct {
Pkey dbdrivers.PrimaryKey
Prefix string
Should string
}{
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
Prefix: "o.", Should: "o.col_one",
},
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
Prefix: "o.", Should: "o.col_one, o.col_two",
},
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
Prefix: "o.", Should: "o.col_one, o.col_two, o.col_three",
},
}
for i, test := range tests2 {
r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, false)
if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
}
}
}
func TestWherePrimaryKey(t *testing.T) {
t.Parallel()
tests := []struct {
Pkey dbdrivers.PrimaryKey
Start int
Should string
}{
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1"}}, Start: 2, Should: "col1=$2"},
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2"}}, Start: 4, Should: "col1=$4 AND col2=$5"},
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2", "col3"}}, Start: 4, Should: "col1=$4 AND col2=$5 AND col3=$6"},
}
for i, test := range tests {
r := wherePrimaryKey(&test.Pkey, test.Start)
if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
}
}
}

View file

@ -13,12 +13,11 @@ func {{$tableNameSingular}}Delete(db boil.DB, id int) error {
return nil return nil
} }
{{if hasPrimaryKey .Table.Columns -}} {{if hasPrimaryKey .Table.PKey -}}
// Delete deletes a single {{$tableNameSingular}} record. // Delete deletes a single {{$tableNameSingular}} record.
// Delete will match against the primary key column to find the record to delete. // Delete will match against the primary key column to find the record to delete.
func (o *{{$tableNameSingular}}) Delete(db boil.DB) error { func (o *{{$tableNameSingular}}) Delete(db boil.DB) error {
{{- $pkeyName := getPrimaryKey .Table.Columns -}} _, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{wherePrimaryKey .Table.PKey.Columns 1}}", {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
_, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{$pkeyName}}=$1", o.{{titleCase $pkeyName}})
if err != nil { if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to delete from {{.Table.Name}}: %s", err) return fmt.Errorf("{{.PkgName}}: unable to delete from {{.Table.Name}}: %s", err)
} }

View file

@ -15,16 +15,16 @@ func {{$tableNameSingular}}Update(db boil.DB, id int, columns map[string]interfa
return nil return nil
} }
{{if hasPrimaryKey .Table.Columns -}} {{if hasPrimaryKey .Table.PKey -}}
// Update updates a single {{$tableNameSingular}} record. // Update updates a single {{$tableNameSingular}} record.
// Update will match against the primary key column to find the record to update. // Update will match against the primary key column to find the record to update.
// WARNING: This Update method will NOT ignore nil members. // WARNING: This Update method will NOT ignore nil members.
// If you pass in nil members, those columnns will be set to null. // If you pass in nil members, those columnns will be set to null.
func (o *{{$tableNameSingular}}) Update(db boil.DB) error { func (o *{{$tableNameSingular}}) Update(db boil.DB) error {
{{- $pkeyName := getPrimaryKey .Table.Columns -}} {{$flagIndex := primaryKeyFlagIndex .Table.Columns .Table.PKey.Columns}}
_, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns}} WHERE {{$pkeyName}}=${{len .Table.Columns}}", {{updateParamVariables "o." .Table.Columns}}, o.{{titleCase $pkeyName}}) _, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns .Table.PKey.Columns}} WHERE {{wherePrimaryKey .Table.PKey.Columns $flagIndex}}", {{updateParamVariables "o." .Table.Columns .Table.PKey.Columns}}, {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
if err != nil { if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row using primary key {{$pkeyName}}: %s", err) return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row: %s", err)
} }
return nil return nil

View file

@ -134,6 +134,7 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
return nil, err return nil, err
} }
var columns []string
for rows.Next() { for rows.Next() {
var column string var column string
@ -141,8 +142,12 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
columns = append(columns, column)
} }
pkey.Columns = columns
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }