Finish update tests

* Add SetMerge helper
* Add quotes around all column names in statements
* Fix blacklist bug in RandomizeStruct
* Split up update helper
This commit is contained in:
Patrick O'brien 2016-07-15 20:14:47 +10:00
parent 08feff45d7
commit 81883d5f75
6 changed files with 155 additions and 39 deletions

View file

@ -51,6 +51,26 @@ func SetIntersect(a []string, b []string) []string {
return c
}
// SetMerge will return a merged slice without duplicates
func SetMerge(a []string, b []string) []string {
var x, merged []string
x = append(x, a...)
x = append(x, b...)
check := map[string]bool{}
for _, v := range x {
if check[v] == true {
continue
}
merged = append(merged, v)
check[v] = true
}
return merged
}
// NonZeroDefaultSet returns the fields included in the
// defaults slice that are non zero values
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
@ -104,7 +124,7 @@ func GenerateParamFlags(colCount int, startAt int) string {
}
// WherePrimaryKeyIn generates a "in" string for where queries
// For example: (col1, col2) IN (($1, $2), ($3, $4))
// For example: ("col1","col2") IN (($1,$2), ($3,$4))
func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
in := &bytes.Buffer{}
@ -147,7 +167,7 @@ func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
}
// SelectNames returns the column names for a select statement
// Eg: col1, col2, col3
// Eg: "col1", "col2", "col3"
func SelectNames(results interface{}) string {
var names []string
@ -164,14 +184,14 @@ func SelectNames(results interface{}) string {
name = goVarToSQLName(field.Name)
}
names = append(names, name)
names = append(names, fmt.Sprintf(`"%s"`, name))
}
return strings.Join(names, ", ")
}
// WhereClause returns the where clause for an sql statement
// eg: col1=$1 AND col2=$2 AND col3=$3
// eg: "col1"=$1 AND "col2"=$2 AND "col3"=$3
func WhereClause(columns []string) string {
names := make([]string, 0, len(columns))
@ -180,14 +200,14 @@ func WhereClause(columns []string) string {
}
for i, c := range names {
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
names[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
}
return strings.Join(names, " AND ")
}
// Update returns the column list for an update statement SET clause
// eg: col1=$1,col2=$2,col3=$3
// eg: "col1"=$1, "col2"=$2, "col3"=$3
func Update(columns map[string]interface{}) string {
names := make([]string, 0, len(columns))
@ -198,23 +218,23 @@ func Update(columns map[string]interface{}) string {
sort.Strings(names)
for i, c := range names {
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
names[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
}
return strings.Join(names, ",")
return strings.Join(names, ", ")
}
// SetParamNames takes a slice of columns and returns a comma seperated
// list of parameter names for a template statement SET clause.
// eg: col1=$1,col2=$2,col3=$3
// eg: "col1"=$1, "col2"=$2, "col3"=$3
func SetParamNames(columns []string) string {
names := make([]string, 0, len(columns))
counter := 0
for _, c := range columns {
counter++
names = append(names, fmt.Sprintf("%s=$%d", c, counter))
names = append(names, fmt.Sprintf(`"%s"=$%d`, c, counter))
}
return strings.Join(names, ",")
return strings.Join(names, ", ")
}
// WherePrimaryKey returns the where clause using start as the $ flag index
@ -222,7 +242,7 @@ func SetParamNames(columns []string) string {
func WherePrimaryKey(start int, pkeys ...string) string {
var output string
for i, c := range pkeys {
output = fmt.Sprintf("%s%s=$%d", output, c, start)
output = fmt.Sprintf(`%s"%s"=$%d`, output, c, start)
start++
if i < len(pkeys)-1 {

View file

@ -90,6 +90,49 @@ func TestSetIntersect(t *testing.T) {
}
}
func TestSetMerge(t *testing.T) {
t.Parallel()
tests := []struct {
A []string
B []string
C []string
}{
{
[]string{"thing1", "thing2", "thing3"},
[]string{"thing1", "thing3", "thing4"},
[]string{"thing1", "thing2", "thing3", "thing4"},
},
{
[]string{},
[]string{"thing1", "thing2"},
[]string{"thing1", "thing2"},
},
{
[]string{"thing1", "thing2"},
[]string{},
[]string{"thing1", "thing2"},
},
{
[]string{"thing1", "thing2", "thing3"},
[]string{"thing1", "thing2", "thing3"},
[]string{"thing1", "thing2", "thing3"},
},
{
[]string{"thing1", "thing2"},
[]string{"thing3", "thing4"},
[]string{"thing1", "thing2", "thing3", "thing4"},
},
}
for i, test := range tests {
m := SetMerge(test.A, test.B)
if !reflect.DeepEqual(test.C, m) {
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, m)
}
}
}
func TestNonZeroDefaultSet(t *testing.T) {
t.Parallel()

View file

@ -180,10 +180,15 @@ func RandomizeStruct(str interface{}, colTypes map[string]string, includeInvalid
fieldVal := value.Field(i)
fieldTyp := typ.Field(i)
found := sort.Search(len(blacklist), func(i int) bool {
return blacklist[i] == fieldTyp.Name
})
if found != len(blacklist) {
var found bool
for _, v := range blacklist {
if strmangle.TitleCase(v) == fieldTyp.Name {
found = true
break
}
}
if found {
continue
}

View file

@ -3,12 +3,10 @@
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", "}}
// Update a single {{$tableNameSingular}} record. It takes a whitelist of
// column_name's that should be updated. The primary key will be used to find
// the record to update.
// WARNING: Update does NOT ignore nil members - only the whitelist can be used
// to control the set of columns that will be saved.
func (o *{{$tableNameSingular}}) Update(whitelist ... string) error {
// Update a single {{$tableNameSingular}} record.
// Update takes a whitelist of column names that should be updated.
// The primary key will be used to find the record to update.
func (o *{{$tableNameSingular}}) Update(whitelist ...string) error {
return o.UpdateX(boil.GetDB(), whitelist...)
}
@ -28,28 +26,24 @@ func (o *{{$tableNameSingular}}) UpdateAtX(exec boil.Executor, {{$pkArgs}}, whit
return err
}
if len(whitelist) == 0 {
cols := {{$varNameSingular}}ColumnsWithoutDefault
cols = append(boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), cols...)
// Subtract primary keys and autoincrement columns
cols = boil.SetComplement(cols, {{$varNameSingular}}PrimaryKeyColumns)
cols = boil.SetComplement(cols, {{$varNameSingular}}AutoIncrementColumns)
whitelist = make([]string, len(cols))
copy(whitelist, cols)
}
var err error
var query string
if len(whitelist) != 0 {
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}))
_, err = exec.Exec(query, boil.GetStructValues(o, whitelist...), {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
var values []interface{}
wl := o.generateUpdateColumns(whitelist...)
if len(wl) != 0 {
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(wl), boil.WherePrimaryKey(len(wl)+1, {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}))
values = boil.GetStructValues(o, wl...)
values = append(values, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
_, err = exec.Exec(query, values...)
} else {
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build a whitelist for row: %s", err)
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist")
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values)
}
if err != nil {
@ -73,3 +67,22 @@ func (q {{$varNameSingular}}Query) UpdateAll(cols M) error {
return nil
}
// generateUpdateColumns generates the whitelist columns for an update statement
func (o *{{$tableNameSingular}}) generateUpdateColumns(whitelist ...string) []string {
if len(whitelist) != 0 {
return whitelist
}
var wl []string
cols := {{$varNameSingular}}ColumnsWithoutDefault
cols = append(boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), cols...)
// Subtract primary keys and autoincrement columns
cols = boil.SetComplement(cols, {{$varNameSingular}}PrimaryKeyColumns)
cols = boil.SetComplement(cols, {{$varNameSingular}}AutoIncrementColumns)
wl = make([]string, len(cols))
copy(wl, cols)
return wl
}

View file

@ -4,5 +4,40 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func Test{{$tableNamePlural}}Update(t *testing.T) {
t.Skip("test update not implemented")
var err error
item := {{$tableNameSingular}}{}
if err = item.Insert(); err != nil {
t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err)
}
blacklistCols := boil.SetMerge({{$varNameSingular}}AutoIncrementColumns, {{$varNameSingular}}PrimaryKeyColumns)
if err = boil.RandomizeStruct(&item, {{$varNameSingular}}DBTypes, false, blacklistCols...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
whitelist := boil.SetComplement({{$varNameSingular}}Columns, {{$varNameSingular}}AutoIncrementColumns)
if err = item.Update(whitelist...); err != nil {
t.Errorf("Unable to update {{$tableNameSingular}}: %s", err)
}
var j *{{$tableNameSingular}}
j, err = {{$tableNameSingular}}Find({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "item." | join ", "}})
if err != nil {
t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err)
}
{{$varNameSingular}}CompareVals(&item, j, t)
wl := item.generateUpdateColumns("test")
if len(wl) != 1 && wl[0] != "test" {
t.Errorf("Expected generateUpdateColumns whitelist to match expected whitelist")
}
wl = item.generateUpdateColumns()
if len(wl) == 0 && len({{$varNameSingular}}ColumnsWithoutDefault) > 0 {
t.Errorf("Expected generateUpdateColumns to build a whitelist for {{$tableNameSingular}}, but got 0 results")
}
{{$varNamePlural}}DeleteAllRows(t)
}