From c84d35d394af5371aba80093787ece9d011a5fe0 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Mon, 4 Apr 2016 20:28:58 +1000 Subject: [PATCH] Fixed compiler errors and broken tests --- cmds/config.go | 6 +- cmds/sqlboiler.go | 2 +- cmds/template_funcs.go | 64 ++++++++++++++----- cmds/template_funcs_test.go | 116 ++++++++++++++++++++++++++++++++--- cmds/templates/delete.tpl | 5 +- cmds/templates/update.tpl | 8 +-- dbdrivers/postgres_driver.go | 5 ++ 7 files changed, 172 insertions(+), 34 deletions(-) diff --git a/cmds/config.go b/cmds/config.go index b17a843..1116da8 100644 --- a/cmds/config.go +++ b/cmds/config.go @@ -85,7 +85,9 @@ var sqlBoilerTemplateFuncs = template.FuncMap{ "insertParamVariables": insertParamVariables, "scanParamNames": scanParamNames, "hasPrimaryKey": hasPrimaryKey, - "getPrimaryKey": getPrimaryKey, + "wherePrimaryKey": wherePrimaryKey, + "paramsPrimaryKey": paramsPrimaryKey, + "primaryKeyFlagIndex": primaryKeyFlagIndex, "updateParamNames": updateParamNames, "updateParamVariables": updateParamVariables, } @@ -104,6 +106,6 @@ func (c *CmdData) LoadConfigFile(filename string) error { return fmt.Errorf("Failed to decode toml configuration file: %s", err) } - cmdData.Config = cfg + c.Config = cfg return nil } diff --git a/cmds/sqlboiler.go b/cmds/sqlboiler.go index e42cf7a..e880910 100644 --- a/cmds/sqlboiler.go +++ b/cmds/sqlboiler.go @@ -170,7 +170,7 @@ func initTables(tableName string, cmdData *CmdData) error { } var err error - cmdData.Tables, err = cmdData.Interface.Tables(tableNames...) + cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...) if err != nil { return fmt.Errorf("Unable to get all table names: %s", err) } diff --git a/cmds/template_funcs.go b/cmds/template_funcs.go index 3511c33..839aa17 100644 --- a/cmds/template_funcs.go +++ b/cmds/template_funcs.go @@ -107,11 +107,11 @@ func makeDBName(tableName, colName string) string { // list of parameter names for the update statement template SET clause. // eg: col1=$1,col2=$2,col3=$3 // Note: updateParamNames will exclude the PRIMARY KEY column. -func updateParamNames(columns []dbdrivers.Column) string { +func updateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string { names := make([]string, 0, len(columns)) counter := 0 for _, c := range columns { - if c.IsPrimaryKey { + if isPrimaryKey(c.Name, pkeyColumns) { continue } counter++ @@ -124,11 +124,11 @@ func updateParamNames(columns []dbdrivers.Column) string { // comma seperated list of parameter variable names for the update statement. // eg: prefix("o."), column("name_id") -> "o.NameID, ..." // Note: updateParamVariables will exclude the PRIMARY KEY column. -func updateParamVariables(prefix string, columns []dbdrivers.Column) string { +func updateParamVariables(prefix string, columns []dbdrivers.Column, pkeyColumns []string) string { names := make([]string, 0, len(columns)) for _, c := range columns { - if c.IsPrimaryKey { + if isPrimaryKey(c.Name, pkeyColumns) { continue } n := prefix + titleCase(c.Name) @@ -138,6 +138,17 @@ func updateParamVariables(prefix string, columns []dbdrivers.Column) string { return strings.Join(names, ", ") } +// isPrimaryKey checks if the column is found in the primary key columns +func isPrimaryKey(col string, pkeyCols []string) bool { + for _, pkey := range pkeyCols { + if pkey == col { + return true + } + } + + return false +} + // insertParamNames takes a []Column and returns a comma seperated // list of parameter names for the insert statement template. func insertParamNames(columns []dbdrivers.Column) string { @@ -199,23 +210,48 @@ func scanParamNames(object string, columns []dbdrivers.Column) string { } // hasPrimaryKey returns true if one of the columns passed in is a primary key -func hasPrimaryKey(columns []dbdrivers.Column) bool { - for _, c := range columns { - if c.IsPrimaryKey { - return true +func hasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool { + if pKey == nil || len(pKey.Columns) == 0 { + return false + } + + return true +} + +// wherePrimaryKey returns the where clause using start as the $ flag index +// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" +func wherePrimaryKey(pkeyCols []string, start int) string { + var output string + for i, c := range pkeyCols { + output = fmt.Sprintf("%s%s=$%d", output, c, start) + start++ + + if i < len(pkeyCols)-1 { + output = fmt.Sprintf("%s AND ", output) } } - return false + return output } -// getPrimaryKey returns the primary key column name if one is present -func getPrimaryKey(columns []dbdrivers.Column) string { +// paramsPrimaryKey returns the parameters for the sql statement $ flags +// For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2" +func paramsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string { + names := make([]string, 0, len(columns)) + for _, c := range columns { - if c.IsPrimaryKey { - return c.Name + var n string + if shouldTitleCase { + n = prefix + titleCase(c) + } else { + n = prefix + c } + names = append(names, n) } - return "" + return strings.Join(names, ", ") +} + +func primaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int { + return len(regularCols) - len(pkeyCols) + 1 } diff --git a/cmds/template_funcs_test.go b/cmds/template_funcs_test.go index e17825a..e0e1e22 100644 --- a/cmds/template_funcs_test.go +++ b/cmds/template_funcs_test.go @@ -7,8 +7,8 @@ import ( ) var testColumns = []dbdrivers.Column{ - {Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false}, - {Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false}, + {Name: "friend_column", Type: "int", IsNullable: false}, + {Name: "enemy_column_thing", Type: "string", IsNullable: true}, } func TestSingular(t *testing.T) { @@ -99,12 +99,12 @@ func TestUpdateParamNames(t *testing.T) { t.Parallel() var testCols = []dbdrivers.Column{ - {Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true}, - {Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false}, - {Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false}, + {Name: "id", Type: "int", IsNullable: false}, + {Name: "friend_column", Type: "int", IsNullable: false}, + {Name: "enemy_column_thing", Type: "string", IsNullable: true}, } - out := updateParamNames(testCols) + out := updateParamNames(testCols, []string{"id"}) if out != "friend_column=$1,enemy_column_thing=$2" { t.Error("Wrong output:", out) } @@ -114,12 +114,12 @@ func TestUpdateParamVariables(t *testing.T) { t.Parallel() var testCols = []dbdrivers.Column{ - {Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true}, - {Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false}, - {Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false}, + {Name: "id", Type: "int", IsNullable: false}, + {Name: "friend_column", Type: "int", IsNullable: false}, + {Name: "enemy_column_thing", Type: "string", IsNullable: true}, } - out := updateParamVariables("o.", testCols) + out := updateParamVariables("o.", testCols, []string{"id"}) if out != "o.FriendColumn, o.EnemyColumnThing" { t.Error("Wrong output:", out) } @@ -167,3 +167,99 @@ func TestScanParams(t *testing.T) { t.Error("Wrong output:", out) } } + +func TestHasPrimaryKey(t *testing.T) { + t.Parallel() + + var pkey *dbdrivers.PrimaryKey + if hasPrimaryKey(pkey) { + t.Errorf("1) Expected false, got true") + } + + pkey = &dbdrivers.PrimaryKey{} + if hasPrimaryKey(pkey) { + t.Errorf("2) Expected false, got true") + } + + pkey.Columns = append(pkey.Columns, "test") + if !hasPrimaryKey(pkey) { + t.Errorf("3) Expected true, got false") + } +} + +func TestParamsPrimaryKey(t *testing.T) { + t.Parallel() + + tests := []struct { + Pkey dbdrivers.PrimaryKey + Prefix string + Should string + }{ + { + Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}}, + Prefix: "o.", Should: "o.ColOne", + }, + { + Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}}, + Prefix: "o.", Should: "o.ColOne, o.ColTwo", + }, + { + Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}}, + Prefix: "o.", Should: "o.ColOne, o.ColTwo, o.ColThree", + }, + } + + for i, test := range tests { + r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, true) + if r != test.Should { + t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) + } + } + + tests2 := []struct { + Pkey dbdrivers.PrimaryKey + Prefix string + Should string + }{ + { + Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}}, + Prefix: "o.", Should: "o.col_one", + }, + { + Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}}, + Prefix: "o.", Should: "o.col_one, o.col_two", + }, + { + Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}}, + Prefix: "o.", Should: "o.col_one, o.col_two, o.col_three", + }, + } + + for i, test := range tests2 { + r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, false) + if r != test.Should { + t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) + } + } +} + +func TestWherePrimaryKey(t *testing.T) { + t.Parallel() + + tests := []struct { + Pkey dbdrivers.PrimaryKey + Start int + Should string + }{ + {Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1"}}, Start: 2, Should: "col1=$2"}, + {Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2"}}, Start: 4, Should: "col1=$4 AND col2=$5"}, + {Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2", "col3"}}, Start: 4, Should: "col1=$4 AND col2=$5 AND col3=$6"}, + } + + for i, test := range tests { + r := wherePrimaryKey(&test.Pkey, test.Start) + if r != test.Should { + t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) + } + } +} diff --git a/cmds/templates/delete.tpl b/cmds/templates/delete.tpl index cf989f7..ec0ae34 100644 --- a/cmds/templates/delete.tpl +++ b/cmds/templates/delete.tpl @@ -13,12 +13,11 @@ func {{$tableNameSingular}}Delete(db boil.DB, id int) error { return nil } -{{if hasPrimaryKey .Table.Columns -}} +{{if hasPrimaryKey .Table.PKey -}} // Delete deletes a single {{$tableNameSingular}} record. // Delete will match against the primary key column to find the record to delete. func (o *{{$tableNameSingular}}) Delete(db boil.DB) error { - {{- $pkeyName := getPrimaryKey .Table.Columns -}} - _, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{$pkeyName}}=$1", o.{{titleCase $pkeyName}}) + _, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{wherePrimaryKey .Table.PKey.Columns 1}}", {{paramsPrimaryKey "o." .Table.PKey.Columns true}}) if err != nil { return fmt.Errorf("{{.PkgName}}: unable to delete from {{.Table.Name}}: %s", err) } diff --git a/cmds/templates/update.tpl b/cmds/templates/update.tpl index 870155a..e271fc3 100644 --- a/cmds/templates/update.tpl +++ b/cmds/templates/update.tpl @@ -15,16 +15,16 @@ func {{$tableNameSingular}}Update(db boil.DB, id int, columns map[string]interfa return nil } -{{if hasPrimaryKey .Table.Columns -}} +{{if hasPrimaryKey .Table.PKey -}} // Update updates a single {{$tableNameSingular}} record. // Update will match against the primary key column to find the record to update. // WARNING: This Update method will NOT ignore nil members. // If you pass in nil members, those columnns will be set to null. func (o *{{$tableNameSingular}}) Update(db boil.DB) error { - {{- $pkeyName := getPrimaryKey .Table.Columns -}} - _, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns}} WHERE {{$pkeyName}}=${{len .Table.Columns}}", {{updateParamVariables "o." .Table.Columns}}, o.{{titleCase $pkeyName}}) + {{$flagIndex := primaryKeyFlagIndex .Table.Columns .Table.PKey.Columns}} + _, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns .Table.PKey.Columns}} WHERE {{wherePrimaryKey .Table.PKey.Columns $flagIndex}}", {{updateParamVariables "o." .Table.Columns .Table.PKey.Columns}}, {{paramsPrimaryKey "o." .Table.PKey.Columns true}}) if err != nil { - return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row using primary key {{$pkeyName}}: %s", err) + return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row: %s", err) } return nil diff --git a/dbdrivers/postgres_driver.go b/dbdrivers/postgres_driver.go index aaafda0..931a8e0 100644 --- a/dbdrivers/postgres_driver.go +++ b/dbdrivers/postgres_driver.go @@ -134,6 +134,7 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) { return nil, err } + var columns []string for rows.Next() { var column string @@ -141,8 +142,12 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) { if err != nil { return nil, err } + + columns = append(columns, column) } + pkey.Columns = columns + if err = rows.Err(); err != nil { return nil, err }