Fixed compiler errors and broken tests
This commit is contained in:
parent
8e0773b43e
commit
c84d35d394
7 changed files with 172 additions and 34 deletions
|
@ -85,7 +85,9 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
|
|||
"insertParamVariables": insertParamVariables,
|
||||
"scanParamNames": scanParamNames,
|
||||
"hasPrimaryKey": hasPrimaryKey,
|
||||
"getPrimaryKey": getPrimaryKey,
|
||||
"wherePrimaryKey": wherePrimaryKey,
|
||||
"paramsPrimaryKey": paramsPrimaryKey,
|
||||
"primaryKeyFlagIndex": primaryKeyFlagIndex,
|
||||
"updateParamNames": updateParamNames,
|
||||
"updateParamVariables": updateParamVariables,
|
||||
}
|
||||
|
@ -104,6 +106,6 @@ func (c *CmdData) LoadConfigFile(filename string) error {
|
|||
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
|
||||
}
|
||||
|
||||
cmdData.Config = cfg
|
||||
c.Config = cfg
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -170,7 +170,7 @@ func initTables(tableName string, cmdData *CmdData) error {
|
|||
}
|
||||
|
||||
var err error
|
||||
cmdData.Tables, err = cmdData.Interface.Tables(tableNames...)
|
||||
cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to get all table names: %s", err)
|
||||
}
|
||||
|
|
|
@ -107,11 +107,11 @@ func makeDBName(tableName, colName string) string {
|
|||
// list of parameter names for the update statement template SET clause.
|
||||
// eg: col1=$1,col2=$2,col3=$3
|
||||
// Note: updateParamNames will exclude the PRIMARY KEY column.
|
||||
func updateParamNames(columns []dbdrivers.Column) string {
|
||||
func updateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string {
|
||||
names := make([]string, 0, len(columns))
|
||||
counter := 0
|
||||
for _, c := range columns {
|
||||
if c.IsPrimaryKey {
|
||||
if isPrimaryKey(c.Name, pkeyColumns) {
|
||||
continue
|
||||
}
|
||||
counter++
|
||||
|
@ -124,11 +124,11 @@ func updateParamNames(columns []dbdrivers.Column) string {
|
|||
// comma seperated list of parameter variable names for the update statement.
|
||||
// eg: prefix("o."), column("name_id") -> "o.NameID, ..."
|
||||
// Note: updateParamVariables will exclude the PRIMARY KEY column.
|
||||
func updateParamVariables(prefix string, columns []dbdrivers.Column) string {
|
||||
func updateParamVariables(prefix string, columns []dbdrivers.Column, pkeyColumns []string) string {
|
||||
names := make([]string, 0, len(columns))
|
||||
|
||||
for _, c := range columns {
|
||||
if c.IsPrimaryKey {
|
||||
if isPrimaryKey(c.Name, pkeyColumns) {
|
||||
continue
|
||||
}
|
||||
n := prefix + titleCase(c.Name)
|
||||
|
@ -138,6 +138,17 @@ func updateParamVariables(prefix string, columns []dbdrivers.Column) string {
|
|||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
// isPrimaryKey checks if the column is found in the primary key columns
|
||||
func isPrimaryKey(col string, pkeyCols []string) bool {
|
||||
for _, pkey := range pkeyCols {
|
||||
if pkey == col {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// insertParamNames takes a []Column and returns a comma seperated
|
||||
// list of parameter names for the insert statement template.
|
||||
func insertParamNames(columns []dbdrivers.Column) string {
|
||||
|
@ -199,23 +210,48 @@ func scanParamNames(object string, columns []dbdrivers.Column) string {
|
|||
}
|
||||
|
||||
// hasPrimaryKey returns true if one of the columns passed in is a primary key
|
||||
func hasPrimaryKey(columns []dbdrivers.Column) bool {
|
||||
for _, c := range columns {
|
||||
if c.IsPrimaryKey {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func hasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool {
|
||||
if pKey == nil || len(pKey.Columns) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// getPrimaryKey returns the primary key column name if one is present
|
||||
func getPrimaryKey(columns []dbdrivers.Column) string {
|
||||
for _, c := range columns {
|
||||
if c.IsPrimaryKey {
|
||||
return c.Name
|
||||
return true
|
||||
}
|
||||
|
||||
// wherePrimaryKey returns the where clause using start as the $ flag index
|
||||
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
|
||||
func wherePrimaryKey(pkeyCols []string, start int) string {
|
||||
var output string
|
||||
for i, c := range pkeyCols {
|
||||
output = fmt.Sprintf("%s%s=$%d", output, c, start)
|
||||
start++
|
||||
|
||||
if i < len(pkeyCols)-1 {
|
||||
output = fmt.Sprintf("%s AND ", output)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
return output
|
||||
}
|
||||
|
||||
// paramsPrimaryKey returns the parameters for the sql statement $ flags
|
||||
// For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2"
|
||||
func paramsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string {
|
||||
names := make([]string, 0, len(columns))
|
||||
|
||||
for _, c := range columns {
|
||||
var n string
|
||||
if shouldTitleCase {
|
||||
n = prefix + titleCase(c)
|
||||
} else {
|
||||
n = prefix + c
|
||||
}
|
||||
names = append(names, n)
|
||||
}
|
||||
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
func primaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
|
||||
return len(regularCols) - len(pkeyCols) + 1
|
||||
}
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
)
|
||||
|
||||
var testColumns = []dbdrivers.Column{
|
||||
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
||||
{Name: "friend_column", Type: "int", IsNullable: false},
|
||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||
}
|
||||
|
||||
func TestSingular(t *testing.T) {
|
||||
|
@ -99,12 +99,12 @@ func TestUpdateParamNames(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
var testCols = []dbdrivers.Column{
|
||||
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
|
||||
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
||||
{Name: "id", Type: "int", IsNullable: false},
|
||||
{Name: "friend_column", Type: "int", IsNullable: false},
|
||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||
}
|
||||
|
||||
out := updateParamNames(testCols)
|
||||
out := updateParamNames(testCols, []string{"id"})
|
||||
if out != "friend_column=$1,enemy_column_thing=$2" {
|
||||
t.Error("Wrong output:", out)
|
||||
}
|
||||
|
@ -114,12 +114,12 @@ func TestUpdateParamVariables(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
var testCols = []dbdrivers.Column{
|
||||
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
|
||||
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
||||
{Name: "id", Type: "int", IsNullable: false},
|
||||
{Name: "friend_column", Type: "int", IsNullable: false},
|
||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||
}
|
||||
|
||||
out := updateParamVariables("o.", testCols)
|
||||
out := updateParamVariables("o.", testCols, []string{"id"})
|
||||
if out != "o.FriendColumn, o.EnemyColumnThing" {
|
||||
t.Error("Wrong output:", out)
|
||||
}
|
||||
|
@ -167,3 +167,99 @@ func TestScanParams(t *testing.T) {
|
|||
t.Error("Wrong output:", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasPrimaryKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var pkey *dbdrivers.PrimaryKey
|
||||
if hasPrimaryKey(pkey) {
|
||||
t.Errorf("1) Expected false, got true")
|
||||
}
|
||||
|
||||
pkey = &dbdrivers.PrimaryKey{}
|
||||
if hasPrimaryKey(pkey) {
|
||||
t.Errorf("2) Expected false, got true")
|
||||
}
|
||||
|
||||
pkey.Columns = append(pkey.Columns, "test")
|
||||
if !hasPrimaryKey(pkey) {
|
||||
t.Errorf("3) Expected true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParamsPrimaryKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
Pkey dbdrivers.PrimaryKey
|
||||
Prefix string
|
||||
Should string
|
||||
}{
|
||||
{
|
||||
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
|
||||
Prefix: "o.", Should: "o.ColOne",
|
||||
},
|
||||
{
|
||||
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
|
||||
Prefix: "o.", Should: "o.ColOne, o.ColTwo",
|
||||
},
|
||||
{
|
||||
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
|
||||
Prefix: "o.", Should: "o.ColOne, o.ColTwo, o.ColThree",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, true)
|
||||
if r != test.Should {
|
||||
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||
}
|
||||
}
|
||||
|
||||
tests2 := []struct {
|
||||
Pkey dbdrivers.PrimaryKey
|
||||
Prefix string
|
||||
Should string
|
||||
}{
|
||||
{
|
||||
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
|
||||
Prefix: "o.", Should: "o.col_one",
|
||||
},
|
||||
{
|
||||
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
|
||||
Prefix: "o.", Should: "o.col_one, o.col_two",
|
||||
},
|
||||
{
|
||||
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
|
||||
Prefix: "o.", Should: "o.col_one, o.col_two, o.col_three",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests2 {
|
||||
r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, false)
|
||||
if r != test.Should {
|
||||
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWherePrimaryKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
Pkey dbdrivers.PrimaryKey
|
||||
Start int
|
||||
Should string
|
||||
}{
|
||||
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1"}}, Start: 2, Should: "col1=$2"},
|
||||
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2"}}, Start: 4, Should: "col1=$4 AND col2=$5"},
|
||||
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2", "col3"}}, Start: 4, Should: "col1=$4 AND col2=$5 AND col3=$6"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
r := wherePrimaryKey(&test.Pkey, test.Start)
|
||||
if r != test.Should {
|
||||
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,12 +13,11 @@ func {{$tableNameSingular}}Delete(db boil.DB, id int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
{{if hasPrimaryKey .Table.Columns -}}
|
||||
{{if hasPrimaryKey .Table.PKey -}}
|
||||
// Delete deletes a single {{$tableNameSingular}} record.
|
||||
// Delete will match against the primary key column to find the record to delete.
|
||||
func (o *{{$tableNameSingular}}) Delete(db boil.DB) error {
|
||||
{{- $pkeyName := getPrimaryKey .Table.Columns -}}
|
||||
_, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{$pkeyName}}=$1", o.{{titleCase $pkeyName}})
|
||||
_, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{wherePrimaryKey .Table.PKey.Columns 1}}", {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("{{.PkgName}}: unable to delete from {{.Table.Name}}: %s", err)
|
||||
}
|
||||
|
|
|
@ -15,16 +15,16 @@ func {{$tableNameSingular}}Update(db boil.DB, id int, columns map[string]interfa
|
|||
return nil
|
||||
}
|
||||
|
||||
{{if hasPrimaryKey .Table.Columns -}}
|
||||
{{if hasPrimaryKey .Table.PKey -}}
|
||||
// Update updates a single {{$tableNameSingular}} record.
|
||||
// Update will match against the primary key column to find the record to update.
|
||||
// WARNING: This Update method will NOT ignore nil members.
|
||||
// If you pass in nil members, those columnns will be set to null.
|
||||
func (o *{{$tableNameSingular}}) Update(db boil.DB) error {
|
||||
{{- $pkeyName := getPrimaryKey .Table.Columns -}}
|
||||
_, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns}} WHERE {{$pkeyName}}=${{len .Table.Columns}}", {{updateParamVariables "o." .Table.Columns}}, o.{{titleCase $pkeyName}})
|
||||
{{$flagIndex := primaryKeyFlagIndex .Table.Columns .Table.PKey.Columns}}
|
||||
_, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns .Table.PKey.Columns}} WHERE {{wherePrimaryKey .Table.PKey.Columns $flagIndex}}", {{updateParamVariables "o." .Table.Columns .Table.PKey.Columns}}, {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row using primary key {{$pkeyName}}: %s", err)
|
||||
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -134,6 +134,7 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var columns []string
|
||||
for rows.Next() {
|
||||
var column string
|
||||
|
||||
|
@ -141,8 +142,12 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns = append(columns, column)
|
||||
}
|
||||
|
||||
pkey.Columns = columns
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue