Finish update tests

* Add SetMerge helper
* Add quotes around all column names in statements
* Fix blacklist bug in RandomizeStruct
* Split up update helper
This commit is contained in:
Patrick O'brien 2016-07-15 20:14:47 +10:00
parent 08feff45d7
commit 81883d5f75
6 changed files with 155 additions and 39 deletions

View file

@ -51,6 +51,26 @@ func SetIntersect(a []string, b []string) []string {
return c return c
} }
// SetMerge will return a merged slice without duplicates
func SetMerge(a []string, b []string) []string {
var x, merged []string
x = append(x, a...)
x = append(x, b...)
check := map[string]bool{}
for _, v := range x {
if check[v] == true {
continue
}
merged = append(merged, v)
check[v] = true
}
return merged
}
// NonZeroDefaultSet returns the fields included in the // NonZeroDefaultSet returns the fields included in the
// defaults slice that are non zero values // defaults slice that are non zero values
func NonZeroDefaultSet(defaults []string, obj interface{}) []string { func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
@ -104,7 +124,7 @@ func GenerateParamFlags(colCount int, startAt int) string {
} }
// WherePrimaryKeyIn generates a "in" string for where queries // WherePrimaryKeyIn generates a "in" string for where queries
// For example: (col1, col2) IN (($1, $2), ($3, $4)) // For example: ("col1","col2") IN (($1,$2), ($3,$4))
func WherePrimaryKeyIn(numRows int, keyNames ...string) string { func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
in := &bytes.Buffer{} in := &bytes.Buffer{}
@ -147,7 +167,7 @@ func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
} }
// SelectNames returns the column names for a select statement // SelectNames returns the column names for a select statement
// Eg: col1, col2, col3 // Eg: "col1", "col2", "col3"
func SelectNames(results interface{}) string { func SelectNames(results interface{}) string {
var names []string var names []string
@ -164,14 +184,14 @@ func SelectNames(results interface{}) string {
name = goVarToSQLName(field.Name) name = goVarToSQLName(field.Name)
} }
names = append(names, name) names = append(names, fmt.Sprintf(`"%s"`, name))
} }
return strings.Join(names, ", ") return strings.Join(names, ", ")
} }
// WhereClause returns the where clause for an sql statement // WhereClause returns the where clause for an sql statement
// eg: col1=$1 AND col2=$2 AND col3=$3 // eg: "col1"=$1 AND "col2"=$2 AND "col3"=$3
func WhereClause(columns []string) string { func WhereClause(columns []string) string {
names := make([]string, 0, len(columns)) names := make([]string, 0, len(columns))
@ -180,14 +200,14 @@ func WhereClause(columns []string) string {
} }
for i, c := range names { for i, c := range names {
names[i] = fmt.Sprintf("%s=$%d", c, i+1) names[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
} }
return strings.Join(names, " AND ") return strings.Join(names, " AND ")
} }
// Update returns the column list for an update statement SET clause // Update returns the column list for an update statement SET clause
// eg: col1=$1,col2=$2,col3=$3 // eg: "col1"=$1, "col2"=$2, "col3"=$3
func Update(columns map[string]interface{}) string { func Update(columns map[string]interface{}) string {
names := make([]string, 0, len(columns)) names := make([]string, 0, len(columns))
@ -198,23 +218,23 @@ func Update(columns map[string]interface{}) string {
sort.Strings(names) sort.Strings(names)
for i, c := range names { for i, c := range names {
names[i] = fmt.Sprintf("%s=$%d", c, i+1) names[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
} }
return strings.Join(names, ",") return strings.Join(names, ", ")
} }
// SetParamNames takes a slice of columns and returns a comma seperated // SetParamNames takes a slice of columns and returns a comma seperated
// list of parameter names for a template statement SET clause. // list of parameter names for a template statement SET clause.
// eg: col1=$1,col2=$2,col3=$3 // eg: "col1"=$1, "col2"=$2, "col3"=$3
func SetParamNames(columns []string) string { func SetParamNames(columns []string) string {
names := make([]string, 0, len(columns)) names := make([]string, 0, len(columns))
counter := 0 counter := 0
for _, c := range columns { for _, c := range columns {
counter++ counter++
names = append(names, fmt.Sprintf("%s=$%d", c, counter)) names = append(names, fmt.Sprintf(`"%s"=$%d`, c, counter))
} }
return strings.Join(names, ",") return strings.Join(names, ", ")
} }
// WherePrimaryKey returns the where clause using start as the $ flag index // WherePrimaryKey returns the where clause using start as the $ flag index
@ -222,7 +242,7 @@ func SetParamNames(columns []string) string {
func WherePrimaryKey(start int, pkeys ...string) string { func WherePrimaryKey(start int, pkeys ...string) string {
var output string var output string
for i, c := range pkeys { for i, c := range pkeys {
output = fmt.Sprintf("%s%s=$%d", output, c, start) output = fmt.Sprintf(`%s"%s"=$%d`, output, c, start)
start++ start++
if i < len(pkeys)-1 { if i < len(pkeys)-1 {

View file

@ -90,6 +90,49 @@ func TestSetIntersect(t *testing.T) {
} }
} }
func TestSetMerge(t *testing.T) {
t.Parallel()
tests := []struct {
A []string
B []string
C []string
}{
{
[]string{"thing1", "thing2", "thing3"},
[]string{"thing1", "thing3", "thing4"},
[]string{"thing1", "thing2", "thing3", "thing4"},
},
{
[]string{},
[]string{"thing1", "thing2"},
[]string{"thing1", "thing2"},
},
{
[]string{"thing1", "thing2"},
[]string{},
[]string{"thing1", "thing2"},
},
{
[]string{"thing1", "thing2", "thing3"},
[]string{"thing1", "thing2", "thing3"},
[]string{"thing1", "thing2", "thing3"},
},
{
[]string{"thing1", "thing2"},
[]string{"thing3", "thing4"},
[]string{"thing1", "thing2", "thing3", "thing4"},
},
}
for i, test := range tests {
m := SetMerge(test.A, test.B)
if !reflect.DeepEqual(test.C, m) {
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, m)
}
}
}
func TestNonZeroDefaultSet(t *testing.T) { func TestNonZeroDefaultSet(t *testing.T) {
t.Parallel() t.Parallel()

View file

@ -180,10 +180,15 @@ func RandomizeStruct(str interface{}, colTypes map[string]string, includeInvalid
fieldVal := value.Field(i) fieldVal := value.Field(i)
fieldTyp := typ.Field(i) fieldTyp := typ.Field(i)
found := sort.Search(len(blacklist), func(i int) bool { var found bool
return blacklist[i] == fieldTyp.Name for _, v := range blacklist {
}) if strmangle.TitleCase(v) == fieldTyp.Name {
if found != len(blacklist) { found = true
break
}
}
if found {
continue continue
} }

View file

@ -3,12 +3,10 @@
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", "}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", "}}
// Update a single {{$tableNameSingular}} record. It takes a whitelist of // Update a single {{$tableNameSingular}} record.
// column_name's that should be updated. The primary key will be used to find // Update takes a whitelist of column names that should be updated.
// the record to update. // The primary key will be used to find the record to update.
// WARNING: Update does NOT ignore nil members - only the whitelist can be used func (o *{{$tableNameSingular}}) Update(whitelist ...string) error {
// to control the set of columns that will be saved.
func (o *{{$tableNameSingular}}) Update(whitelist ... string) error {
return o.UpdateX(boil.GetDB(), whitelist...) return o.UpdateX(boil.GetDB(), whitelist...)
} }
@ -28,28 +26,24 @@ func (o *{{$tableNameSingular}}) UpdateAtX(exec boil.Executor, {{$pkArgs}}, whit
return err return err
} }
if len(whitelist) == 0 {
cols := {{$varNameSingular}}ColumnsWithoutDefault
cols = append(boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), cols...)
// Subtract primary keys and autoincrement columns
cols = boil.SetComplement(cols, {{$varNameSingular}}PrimaryKeyColumns)
cols = boil.SetComplement(cols, {{$varNameSingular}}AutoIncrementColumns)
whitelist = make([]string, len(cols))
copy(whitelist, cols)
}
var err error var err error
var query string var query string
if len(whitelist) != 0 { var values []interface{}
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}))
_, err = exec.Exec(query, boil.GetStructValues(o, whitelist...), {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) wl := o.generateUpdateColumns(whitelist...)
if len(wl) != 0 {
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(wl), boil.WherePrimaryKey(len(wl)+1, {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}))
values = boil.GetStructValues(o, wl...)
values = append(values, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
_, err = exec.Exec(query, values...)
} else { } else {
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build a whitelist for row: %s", err) return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist")
} }
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query) fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values)
} }
if err != nil { if err != nil {
@ -73,3 +67,22 @@ func (q {{$varNameSingular}}Query) UpdateAll(cols M) error {
return nil return nil
} }
// generateUpdateColumns generates the whitelist columns for an update statement
func (o *{{$tableNameSingular}}) generateUpdateColumns(whitelist ...string) []string {
if len(whitelist) != 0 {
return whitelist
}
var wl []string
cols := {{$varNameSingular}}ColumnsWithoutDefault
cols = append(boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), cols...)
// Subtract primary keys and autoincrement columns
cols = boil.SetComplement(cols, {{$varNameSingular}}PrimaryKeyColumns)
cols = boil.SetComplement(cols, {{$varNameSingular}}AutoIncrementColumns)
wl = make([]string, len(cols))
copy(wl, cols)
return wl
}

View file

@ -4,5 +4,40 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func Test{{$tableNamePlural}}Update(t *testing.T) { func Test{{$tableNamePlural}}Update(t *testing.T) {
t.Skip("test update not implemented") var err error
item := {{$tableNameSingular}}{}
if err = item.Insert(); err != nil {
t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err)
}
blacklistCols := boil.SetMerge({{$varNameSingular}}AutoIncrementColumns, {{$varNameSingular}}PrimaryKeyColumns)
if err = boil.RandomizeStruct(&item, {{$varNameSingular}}DBTypes, false, blacklistCols...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
whitelist := boil.SetComplement({{$varNameSingular}}Columns, {{$varNameSingular}}AutoIncrementColumns)
if err = item.Update(whitelist...); err != nil {
t.Errorf("Unable to update {{$tableNameSingular}}: %s", err)
}
var j *{{$tableNameSingular}}
j, err = {{$tableNameSingular}}Find({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "item." | join ", "}})
if err != nil {
t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err)
}
{{$varNameSingular}}CompareVals(&item, j, t)
wl := item.generateUpdateColumns("test")
if len(wl) != 1 && wl[0] != "test" {
t.Errorf("Expected generateUpdateColumns whitelist to match expected whitelist")
}
wl = item.generateUpdateColumns()
if len(wl) == 0 && len({{$varNameSingular}}ColumnsWithoutDefault) > 0 {
t.Errorf("Expected generateUpdateColumns to build a whitelist for {{$tableNameSingular}}, but got 0 results")
}
{{$varNamePlural}}DeleteAllRows(t)
} }