Fixed compiler errors and broken tests
This commit is contained in:
parent
8e0773b43e
commit
c84d35d394
7 changed files with 172 additions and 34 deletions
|
@ -85,7 +85,9 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
|
||||||
"insertParamVariables": insertParamVariables,
|
"insertParamVariables": insertParamVariables,
|
||||||
"scanParamNames": scanParamNames,
|
"scanParamNames": scanParamNames,
|
||||||
"hasPrimaryKey": hasPrimaryKey,
|
"hasPrimaryKey": hasPrimaryKey,
|
||||||
"getPrimaryKey": getPrimaryKey,
|
"wherePrimaryKey": wherePrimaryKey,
|
||||||
|
"paramsPrimaryKey": paramsPrimaryKey,
|
||||||
|
"primaryKeyFlagIndex": primaryKeyFlagIndex,
|
||||||
"updateParamNames": updateParamNames,
|
"updateParamNames": updateParamNames,
|
||||||
"updateParamVariables": updateParamVariables,
|
"updateParamVariables": updateParamVariables,
|
||||||
}
|
}
|
||||||
|
@ -104,6 +106,6 @@ func (c *CmdData) LoadConfigFile(filename string) error {
|
||||||
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
|
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdData.Config = cfg
|
c.Config = cfg
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -170,7 +170,7 @@ func initTables(tableName string, cmdData *CmdData) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
cmdData.Tables, err = cmdData.Interface.Tables(tableNames...)
|
cmdData.Tables, err = dbdrivers.Tables(cmdData.Interface, tableNames...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Unable to get all table names: %s", err)
|
return fmt.Errorf("Unable to get all table names: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,11 +107,11 @@ func makeDBName(tableName, colName string) string {
|
||||||
// list of parameter names for the update statement template SET clause.
|
// list of parameter names for the update statement template SET clause.
|
||||||
// eg: col1=$1,col2=$2,col3=$3
|
// eg: col1=$1,col2=$2,col3=$3
|
||||||
// Note: updateParamNames will exclude the PRIMARY KEY column.
|
// Note: updateParamNames will exclude the PRIMARY KEY column.
|
||||||
func updateParamNames(columns []dbdrivers.Column) string {
|
func updateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
counter := 0
|
counter := 0
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
if c.IsPrimaryKey {
|
if isPrimaryKey(c.Name, pkeyColumns) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
counter++
|
counter++
|
||||||
|
@ -124,11 +124,11 @@ func updateParamNames(columns []dbdrivers.Column) string {
|
||||||
// comma seperated list of parameter variable names for the update statement.
|
// comma seperated list of parameter variable names for the update statement.
|
||||||
// eg: prefix("o."), column("name_id") -> "o.NameID, ..."
|
// eg: prefix("o."), column("name_id") -> "o.NameID, ..."
|
||||||
// Note: updateParamVariables will exclude the PRIMARY KEY column.
|
// Note: updateParamVariables will exclude the PRIMARY KEY column.
|
||||||
func updateParamVariables(prefix string, columns []dbdrivers.Column) string {
|
func updateParamVariables(prefix string, columns []dbdrivers.Column, pkeyColumns []string) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
if c.IsPrimaryKey {
|
if isPrimaryKey(c.Name, pkeyColumns) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
n := prefix + titleCase(c.Name)
|
n := prefix + titleCase(c.Name)
|
||||||
|
@ -138,6 +138,17 @@ func updateParamVariables(prefix string, columns []dbdrivers.Column) string {
|
||||||
return strings.Join(names, ", ")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isPrimaryKey checks if the column is found in the primary key columns
|
||||||
|
func isPrimaryKey(col string, pkeyCols []string) bool {
|
||||||
|
for _, pkey := range pkeyCols {
|
||||||
|
if pkey == col {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// insertParamNames takes a []Column and returns a comma seperated
|
// insertParamNames takes a []Column and returns a comma seperated
|
||||||
// list of parameter names for the insert statement template.
|
// list of parameter names for the insert statement template.
|
||||||
func insertParamNames(columns []dbdrivers.Column) string {
|
func insertParamNames(columns []dbdrivers.Column) string {
|
||||||
|
@ -199,23 +210,48 @@ func scanParamNames(object string, columns []dbdrivers.Column) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasPrimaryKey returns true if one of the columns passed in is a primary key
|
// hasPrimaryKey returns true if one of the columns passed in is a primary key
|
||||||
func hasPrimaryKey(columns []dbdrivers.Column) bool {
|
func hasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool {
|
||||||
for _, c := range columns {
|
if pKey == nil || len(pKey.Columns) == 0 {
|
||||||
if c.IsPrimaryKey {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPrimaryKey returns the primary key column name if one is present
|
// wherePrimaryKey returns the where clause using start as the $ flag index
|
||||||
func getPrimaryKey(columns []dbdrivers.Column) string {
|
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
|
||||||
|
func wherePrimaryKey(pkeyCols []string, start int) string {
|
||||||
|
var output string
|
||||||
|
for i, c := range pkeyCols {
|
||||||
|
output = fmt.Sprintf("%s%s=$%d", output, c, start)
|
||||||
|
start++
|
||||||
|
|
||||||
|
if i < len(pkeyCols)-1 {
|
||||||
|
output = fmt.Sprintf("%s AND ", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
|
// paramsPrimaryKey returns the parameters for the sql statement $ flags
|
||||||
|
// For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2"
|
||||||
|
func paramsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string {
|
||||||
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
if c.IsPrimaryKey {
|
var n string
|
||||||
return c.Name
|
if shouldTitleCase {
|
||||||
|
n = prefix + titleCase(c)
|
||||||
|
} else {
|
||||||
|
n = prefix + c
|
||||||
}
|
}
|
||||||
|
names = append(names, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return strings.Join(names, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func primaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
|
||||||
|
return len(regularCols) - len(pkeyCols) + 1
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,8 +7,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var testColumns = []dbdrivers.Column{
|
var testColumns = []dbdrivers.Column{
|
||||||
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
{Name: "friend_column", Type: "int", IsNullable: false},
|
||||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSingular(t *testing.T) {
|
func TestSingular(t *testing.T) {
|
||||||
|
@ -99,12 +99,12 @@ func TestUpdateParamNames(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var testCols = []dbdrivers.Column{
|
var testCols = []dbdrivers.Column{
|
||||||
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
|
{Name: "id", Type: "int", IsNullable: false},
|
||||||
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
{Name: "friend_column", Type: "int", IsNullable: false},
|
||||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
out := updateParamNames(testCols)
|
out := updateParamNames(testCols, []string{"id"})
|
||||||
if out != "friend_column=$1,enemy_column_thing=$2" {
|
if out != "friend_column=$1,enemy_column_thing=$2" {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
|
@ -114,12 +114,12 @@ func TestUpdateParamVariables(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var testCols = []dbdrivers.Column{
|
var testCols = []dbdrivers.Column{
|
||||||
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
|
{Name: "id", Type: "int", IsNullable: false},
|
||||||
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
{Name: "friend_column", Type: "int", IsNullable: false},
|
||||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
out := updateParamVariables("o.", testCols)
|
out := updateParamVariables("o.", testCols, []string{"id"})
|
||||||
if out != "o.FriendColumn, o.EnemyColumnThing" {
|
if out != "o.FriendColumn, o.EnemyColumnThing" {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
|
@ -167,3 +167,99 @@ func TestScanParams(t *testing.T) {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasPrimaryKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var pkey *dbdrivers.PrimaryKey
|
||||||
|
if hasPrimaryKey(pkey) {
|
||||||
|
t.Errorf("1) Expected false, got true")
|
||||||
|
}
|
||||||
|
|
||||||
|
pkey = &dbdrivers.PrimaryKey{}
|
||||||
|
if hasPrimaryKey(pkey) {
|
||||||
|
t.Errorf("2) Expected false, got true")
|
||||||
|
}
|
||||||
|
|
||||||
|
pkey.Columns = append(pkey.Columns, "test")
|
||||||
|
if !hasPrimaryKey(pkey) {
|
||||||
|
t.Errorf("3) Expected true, got false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParamsPrimaryKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
Pkey dbdrivers.PrimaryKey
|
||||||
|
Prefix string
|
||||||
|
Should string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
|
||||||
|
Prefix: "o.", Should: "o.ColOne",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
|
||||||
|
Prefix: "o.", Should: "o.ColOne, o.ColTwo",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
|
||||||
|
Prefix: "o.", Should: "o.ColOne, o.ColTwo, o.ColThree",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, true)
|
||||||
|
if r != test.Should {
|
||||||
|
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tests2 := []struct {
|
||||||
|
Pkey dbdrivers.PrimaryKey
|
||||||
|
Prefix string
|
||||||
|
Should string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
|
||||||
|
Prefix: "o.", Should: "o.col_one",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
|
||||||
|
Prefix: "o.", Should: "o.col_one, o.col_two",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
|
||||||
|
Prefix: "o.", Should: "o.col_one, o.col_two, o.col_three",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests2 {
|
||||||
|
r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, false)
|
||||||
|
if r != test.Should {
|
||||||
|
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWherePrimaryKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
Pkey dbdrivers.PrimaryKey
|
||||||
|
Start int
|
||||||
|
Should string
|
||||||
|
}{
|
||||||
|
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1"}}, Start: 2, Should: "col1=$2"},
|
||||||
|
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2"}}, Start: 4, Should: "col1=$4 AND col2=$5"},
|
||||||
|
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2", "col3"}}, Start: 4, Should: "col1=$4 AND col2=$5 AND col3=$6"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
r := wherePrimaryKey(&test.Pkey, test.Start)
|
||||||
|
if r != test.Should {
|
||||||
|
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -13,12 +13,11 @@ func {{$tableNameSingular}}Delete(db boil.DB, id int) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
{{if hasPrimaryKey .Table.Columns -}}
|
{{if hasPrimaryKey .Table.PKey -}}
|
||||||
// Delete deletes a single {{$tableNameSingular}} record.
|
// Delete deletes a single {{$tableNameSingular}} record.
|
||||||
// Delete will match against the primary key column to find the record to delete.
|
// Delete will match against the primary key column to find the record to delete.
|
||||||
func (o *{{$tableNameSingular}}) Delete(db boil.DB) error {
|
func (o *{{$tableNameSingular}}) Delete(db boil.DB) error {
|
||||||
{{- $pkeyName := getPrimaryKey .Table.Columns -}}
|
_, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{wherePrimaryKey .Table.PKey.Columns 1}}", {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
|
||||||
_, err := db.Exec("DELETE FROM {{.Table.Name}} WHERE {{$pkeyName}}=$1", o.{{titleCase $pkeyName}})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("{{.PkgName}}: unable to delete from {{.Table.Name}}: %s", err)
|
return fmt.Errorf("{{.PkgName}}: unable to delete from {{.Table.Name}}: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,16 +15,16 @@ func {{$tableNameSingular}}Update(db boil.DB, id int, columns map[string]interfa
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
{{if hasPrimaryKey .Table.Columns -}}
|
{{if hasPrimaryKey .Table.PKey -}}
|
||||||
// Update updates a single {{$tableNameSingular}} record.
|
// Update updates a single {{$tableNameSingular}} record.
|
||||||
// Update will match against the primary key column to find the record to update.
|
// Update will match against the primary key column to find the record to update.
|
||||||
// WARNING: This Update method will NOT ignore nil members.
|
// WARNING: This Update method will NOT ignore nil members.
|
||||||
// If you pass in nil members, those columnns will be set to null.
|
// If you pass in nil members, those columnns will be set to null.
|
||||||
func (o *{{$tableNameSingular}}) Update(db boil.DB) error {
|
func (o *{{$tableNameSingular}}) Update(db boil.DB) error {
|
||||||
{{- $pkeyName := getPrimaryKey .Table.Columns -}}
|
{{$flagIndex := primaryKeyFlagIndex .Table.Columns .Table.PKey.Columns}}
|
||||||
_, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns}} WHERE {{$pkeyName}}=${{len .Table.Columns}}", {{updateParamVariables "o." .Table.Columns}}, o.{{titleCase $pkeyName}})
|
_, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns .Table.PKey.Columns}} WHERE {{wherePrimaryKey .Table.PKey.Columns $flagIndex}}", {{updateParamVariables "o." .Table.Columns .Table.PKey.Columns}}, {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row using primary key {{$pkeyName}}: %s", err)
|
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -134,6 +134,7 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var columns []string
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var column string
|
var column string
|
||||||
|
|
||||||
|
@ -141,8 +142,12 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
columns = append(columns, column)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pkey.Columns = columns
|
||||||
|
|
||||||
if err = rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue