From c84d35d394af5371aba80093787ece9d011a5fe0 Mon Sep 17 00:00:00 2001
From: Patrick O'brien <pobri19@gmail.com>
Date: Mon, 4 Apr 2016 20:28:58 +1000
Subject: [PATCH] Fixed compiler errors and broken tests

---
 cmds/config.go               |   6 +-
 cmds/sqlboiler.go            |   2 +-
 cmds/template_funcs.go       |  64 ++++++++++++++-----
 cmds/template_funcs_test.go  | 116 ++++++++++++++++++++++++++++++++---
 cmds/templates/delete.tpl    |   5 +-
 cmds/templates/update.tpl    |   8 +--
 dbdrivers/postgres_driver.go |   5 ++
 7 files changed, 172 insertions(+), 34 deletions(-)

diff --git a/cmds/config.go b/cmds/config.go
index b17a843..1116da8 100644
--- a/cmds/config.go
+++ b/cmds/config.go
@@ -85,7 +85,9 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
 	"insertParamVariables": insertParamVariables,
 	"scanParamNames":       scanParamNames,
 	"hasPrimaryKey":        hasPrimaryKey,
-	"getPrimaryKey":        getPrimaryKey,
+	"wherePrimaryKey":      wherePrimaryKey,
+	"paramsPrimaryKey":     paramsPrimaryKey,
+	"primaryKeyFlagIndex":  primaryKeyFlagIndex,
 	"updateParamNames":     updateParamNames,
 	"updateParamVariables": updateParamVariables,
 }
@@ -104,6 +106,6 @@ func (c *CmdData) LoadConfigFile(filename string) error {
 		return fmt.Errorf("Failed to decode toml configuration file: %s", err)
 	}
 
-	cmdData.Config = cfg
+	c.Config = cfg
 	return nil
 }
diff --git a/cmds/sqlboiler.go b/cmds/sqlboiler.go
index e42cf7a..e880910 100644
--- a/cmds/sqlboiler.go
+++ b/cmds/sqlboiler.go
@@ -170,7 +170,7 @@ func initTables(tableName string, cmdData *CmdData) error {
 	}
 
 	var err error
-	cmdData.Tables, err = cmdData.Interface.Tables(tableNames...)
+	cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...)
 	if err != nil {
 		return fmt.Errorf("Unable to get all table names: %s", err)
 	}
diff --git a/cmds/template_funcs.go b/cmds/template_funcs.go
index 3511c33..839aa17 100644
--- a/cmds/template_funcs.go
+++ b/cmds/template_funcs.go
@@ -107,11 +107,11 @@ func makeDBName(tableName, colName string) string {
 // list of parameter names for the update statement template SET clause.
 // eg: col1=$1,col2=$2,col3=$3
 // Note: updateParamNames will exclude the PRIMARY KEY column.
-func updateParamNames(columns []dbdrivers.Column) string {
+func updateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string {
 	names := make([]string, 0, len(columns))
 	counter := 0
 	for _, c := range columns {
-		if c.IsPrimaryKey {
+		if isPrimaryKey(c.Name, pkeyColumns) {
 			continue
 		}
 		counter++
@@ -124,11 +124,11 @@ func updateParamNames(columns []dbdrivers.Column) string {
 // comma seperated list of parameter variable names for the update statement.
 // eg: prefix("o."), column("name_id") -> "o.NameID, ..."
 // Note: updateParamVariables will exclude the PRIMARY KEY column.
-func updateParamVariables(prefix string, columns []dbdrivers.Column) string {
+func updateParamVariables(prefix string, columns []dbdrivers.Column, pkeyColumns []string) string {
 	names := make([]string, 0, len(columns))
 
 	for _, c := range columns {
-		if c.IsPrimaryKey {
+		if isPrimaryKey(c.Name, pkeyColumns) {
 			continue
 		}
 		n := prefix + titleCase(c.Name)
@@ -138,6 +138,17 @@ func updateParamVariables(prefix string, columns []dbdrivers.Column) string {
 	return strings.Join(names, ", ")
 }
 
+// isPrimaryKey checks if the column is found in the primary key columns
+func isPrimaryKey(col string, pkeyCols []string) bool {
+	for _, pkey := range pkeyCols {
+		if pkey == col {
+			return true
+		}
+	}
+
+	return false
+}
+
 // insertParamNames takes a []Column and returns a comma seperated
 // list of parameter names for the insert statement template.
 func insertParamNames(columns []dbdrivers.Column) string {
@@ -199,23 +210,48 @@ func scanParamNames(object string, columns []dbdrivers.Column) string {
 }
 
 // hasPrimaryKey returns true if one of the columns passed in is a primary key
-func hasPrimaryKey(columns []dbdrivers.Column) bool {
-	for _, c := range columns {
-		if c.IsPrimaryKey {
-			return true
+func hasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool {
+	if pKey == nil || len(pKey.Columns) == 0 {
+		return false
+	}
+
+	return true
+}
+
+// wherePrimaryKey returns the where clause using start as the $ flag index
+// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
+func wherePrimaryKey(pkeyCols []string, start int) string {
+	var output string
+	for i, c := range pkeyCols {
+		output = fmt.Sprintf("%s%s=$%d", output, c, start)
+		start++
+
+		if i < len(pkeyCols)-1 {
+			output = fmt.Sprintf("%s AND ", output)
 		}
 	}
 
-	return false
+	return output
 }
 
-// getPrimaryKey returns the primary key column name if one is present
-func getPrimaryKey(columns []dbdrivers.Column) string {
+// paramsPrimaryKey returns the parameters for the sql statement $ flags
+// For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2"
+func paramsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string {
+	names := make([]string, 0, len(columns))
+
 	for _, c := range columns {
-		if c.IsPrimaryKey {
-			return c.Name
+		var n string
+		if shouldTitleCase {
+			n = prefix + titleCase(c)
+		} else {
+			n = prefix + c
 		}
+		names = append(names, n)
 	}
 
-	return ""
+	return strings.Join(names, ", ")
+}
+
+func primaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
+	return len(regularCols) - len(pkeyCols) + 1
 }
diff --git a/cmds/template_funcs_test.go b/cmds/template_funcs_test.go
index e17825a..e0e1e22 100644
--- a/cmds/template_funcs_test.go
+++ b/cmds/template_funcs_test.go
@@ -7,8 +7,8 @@ import (
 )
 
 var testColumns = []dbdrivers.Column{
-	{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
-	{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
+	{Name: "friend_column", Type: "int", IsNullable: false},
+	{Name: "enemy_column_thing", Type: "string", IsNullable: true},
 }
 
 func TestSingular(t *testing.T) {
@@ -99,12 +99,12 @@ func TestUpdateParamNames(t *testing.T) {
 	t.Parallel()
 
 	var testCols = []dbdrivers.Column{
-		{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
-		{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
-		{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
+		{Name: "id", Type: "int", IsNullable: false},
+		{Name: "friend_column", Type: "int", IsNullable: false},
+		{Name: "enemy_column_thing", Type: "string", IsNullable: true},
 	}
 
-	out := updateParamNames(testCols)
+	out := updateParamNames(testCols, []string{"id"})
 	if out != "friend_column=$1,enemy_column_thing=$2" {
 		t.Error("Wrong output:", out)
 	}
@@ -114,12 +114,12 @@ func TestUpdateParamVariables(t *testing.T) {
 	t.Parallel()
 
 	var testCols = []dbdrivers.Column{
-		{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
-		{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
-		{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
+		{Name: "id", Type: "int", IsNullable: false},
+		{Name: "friend_column", Type: "int", IsNullable: false},
+		{Name: "enemy_column_thing", Type: "string", IsNullable: true},
 	}
 
-	out := updateParamVariables("o.", testCols)
+	out := updateParamVariables("o.", testCols, []string{"id"})
 	if out != "o.FriendColumn, o.EnemyColumnThing" {
 		t.Error("Wrong output:", out)
 	}
@@ -167,3 +167,99 @@ func TestScanParams(t *testing.T) {
 		t.Error("Wrong output:", out)
 	}
 }
+
+func TestHasPrimaryKey(t *testing.T) {
+	t.Parallel()
+
+	var pkey *dbdrivers.PrimaryKey
+	if hasPrimaryKey(pkey) {
+		t.Errorf("1) Expected false, got true")
+	}
+
+	pkey = &dbdrivers.PrimaryKey{}
+	if hasPrimaryKey(pkey) {
+		t.Errorf("2) Expected false, got true")
+	}
+
+	pkey.Columns = append(pkey.Columns, "test")
+	if !hasPrimaryKey(pkey) {
+		t.Errorf("3) Expected true, got false")
+	}
+}
+
+func TestParamsPrimaryKey(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		Pkey   dbdrivers.PrimaryKey
+		Prefix string
+		Should string
+	}{
+		{
+			Pkey:   dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
+			Prefix: "o.", Should: "o.ColOne",
+		},
+		{
+			Pkey:   dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
+			Prefix: "o.", Should: "o.ColOne, o.ColTwo",
+		},
+		{
+			Pkey:   dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
+			Prefix: "o.", Should: "o.ColOne, o.ColTwo, o.ColThree",
+		},
+	}
+
+	for i, test := range tests {
+		r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, true)
+		if r != test.Should {
+			t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
+		}
+	}
+
+	tests2 := []struct {
+		Pkey   dbdrivers.PrimaryKey
+		Prefix string
+		Should string
+	}{
+		{
+			Pkey:   dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
+			Prefix: "o.", Should: "o.col_one",
+		},
+		{
+			Pkey:   dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
+			Prefix: "o.", Should: "o.col_one, o.col_two",
+		},
+		{
+			Pkey:   dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
+			Prefix: "o.", Should: "o.col_one, o.col_two, o.col_three",
+		},
+	}
+
+	for i, test := range tests2 {
+		r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, false)
+		if r != test.Should {
+			t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
+		}
+	}
+}
+
+func TestWherePrimaryKey(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		Pkey   dbdrivers.PrimaryKey
+		Start  int
+		Should string
+	}{
+		{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1"}}, Start: 2, Should: "col1=$2"},
+		{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2"}}, Start: 4, Should: "col1=$4 AND col2=$5"},
+		{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2", "col3"}}, Start: 4, Should: "col1=$4 AND col2=$5 AND col3=$6"},
+	}
+
+	for i, test := range tests {
+		r := wherePrimaryKey(&test.Pkey, test.Start)
+		if r != test.Should {
+			t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
+		}
+	}
+}
diff --git a/cmds/templates/delete.tpl b/cmds/templates/delete.tpl
index cf989f7..ec0ae34 100644
--- a/cmds/templates/delete.tpl
+++ b/cmds/templates/delete.tpl
@@ -13,12 +13,11 @@ func {{$tableNameSingular}}Delete(db boil.DB, id int) error {
   return nil
 }
 
-{{if hasPrimaryKey .Table.Columns -}}
+{{if hasPrimaryKey .Table.PKey -}}
 // Delete deletes a single {{$tableNameSingular}} record.
 // Delete will match against the primary key column to find the record to delete.
 func (o *{{$tableNameSingular}}) Delete(db boil.DB) error {
-  {{- $pkeyName := getPrimaryKey .Table.Columns -}}
-  _, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{$pkeyName}}=$1", o.{{titleCase $pkeyName}})
+  _, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{wherePrimaryKey .Table.PKey.Columns 1}}", {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
   if err != nil {
     return fmt.Errorf("{{.PkgName}}: unable to delete from {{.Table.Name}}: %s", err)
   }
diff --git a/cmds/templates/update.tpl b/cmds/templates/update.tpl
index 870155a..e271fc3 100644
--- a/cmds/templates/update.tpl
+++ b/cmds/templates/update.tpl
@@ -15,16 +15,16 @@ func {{$tableNameSingular}}Update(db boil.DB, id int, columns map[string]interfa
   return nil
 }
 
-{{if hasPrimaryKey .Table.Columns -}}
+{{if hasPrimaryKey .Table.PKey -}}
 // Update updates a single {{$tableNameSingular}} record.
 // Update will match against the primary key column to find the record to update.
 // WARNING: This Update method will NOT ignore nil members.
 // If you pass in nil members, those columnns will be set to null.
 func (o *{{$tableNameSingular}}) Update(db boil.DB) error {
-  {{- $pkeyName := getPrimaryKey .Table.Columns -}}
-  _, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns}} WHERE {{$pkeyName}}=${{len .Table.Columns}}", {{updateParamVariables "o." .Table.Columns}}, o.{{titleCase $pkeyName}})
+  {{$flagIndex := primaryKeyFlagIndex .Table.Columns .Table.PKey.Columns}}
+  _, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns .Table.PKey.Columns}} WHERE {{wherePrimaryKey .Table.PKey.Columns $flagIndex}}", {{updateParamVariables "o." .Table.Columns .Table.PKey.Columns}}, {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
   if err != nil {
-    return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row using primary key {{$pkeyName}}: %s", err)
+    return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row: %s", err)
   }
 
   return nil
diff --git a/dbdrivers/postgres_driver.go b/dbdrivers/postgres_driver.go
index aaafda0..931a8e0 100644
--- a/dbdrivers/postgres_driver.go
+++ b/dbdrivers/postgres_driver.go
@@ -134,6 +134,7 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
 		return nil, err
 	}
 
+	var columns []string
 	for rows.Next() {
 		var column string
 
@@ -141,8 +142,12 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
 		if err != nil {
 			return nil, err
 		}
+
+		columns = append(columns, column)
 	}
 
+	pkey.Columns = columns
+
 	if err = rows.Err(); err != nil {
 		return nil, err
 	}