Finish update tests
* Add SetMerge helper * Add quotes around all column names in statements * Fix blacklist bug in RandomizeStruct * Split up update helper
This commit is contained in:
parent
08feff45d7
commit
81883d5f75
6 changed files with 155 additions and 39 deletions
|
@ -51,6 +51,26 @@ func SetIntersect(a []string, b []string) []string {
|
|||
return c
|
||||
}
|
||||
|
||||
// SetMerge will return a merged slice without duplicates
|
||||
func SetMerge(a []string, b []string) []string {
|
||||
var x, merged []string
|
||||
|
||||
x = append(x, a...)
|
||||
x = append(x, b...)
|
||||
|
||||
check := map[string]bool{}
|
||||
for _, v := range x {
|
||||
if check[v] == true {
|
||||
continue
|
||||
}
|
||||
|
||||
merged = append(merged, v)
|
||||
check[v] = true
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// NonZeroDefaultSet returns the fields included in the
|
||||
// defaults slice that are non zero values
|
||||
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
|
||||
|
@ -104,7 +124,7 @@ func GenerateParamFlags(colCount int, startAt int) string {
|
|||
}
|
||||
|
||||
// WherePrimaryKeyIn generates a "in" string for where queries
|
||||
// For example: (col1, col2) IN (($1, $2), ($3, $4))
|
||||
// For example: ("col1","col2") IN (($1,$2), ($3,$4))
|
||||
func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
|
||||
in := &bytes.Buffer{}
|
||||
|
||||
|
@ -147,7 +167,7 @@ func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
|
|||
}
|
||||
|
||||
// SelectNames returns the column names for a select statement
|
||||
// Eg: col1, col2, col3
|
||||
// Eg: "col1", "col2", "col3"
|
||||
func SelectNames(results interface{}) string {
|
||||
var names []string
|
||||
|
||||
|
@ -164,14 +184,14 @@ func SelectNames(results interface{}) string {
|
|||
name = goVarToSQLName(field.Name)
|
||||
}
|
||||
|
||||
names = append(names, name)
|
||||
names = append(names, fmt.Sprintf(`"%s"`, name))
|
||||
}
|
||||
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
// WhereClause returns the where clause for an sql statement
|
||||
// eg: col1=$1 AND col2=$2 AND col3=$3
|
||||
// eg: "col1"=$1 AND "col2"=$2 AND "col3"=$3
|
||||
func WhereClause(columns []string) string {
|
||||
names := make([]string, 0, len(columns))
|
||||
|
||||
|
@ -180,14 +200,14 @@ func WhereClause(columns []string) string {
|
|||
}
|
||||
|
||||
for i, c := range names {
|
||||
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
|
||||
names[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
|
||||
}
|
||||
|
||||
return strings.Join(names, " AND ")
|
||||
}
|
||||
|
||||
// Update returns the column list for an update statement SET clause
|
||||
// eg: col1=$1,col2=$2,col3=$3
|
||||
// eg: "col1"=$1, "col2"=$2, "col3"=$3
|
||||
func Update(columns map[string]interface{}) string {
|
||||
names := make([]string, 0, len(columns))
|
||||
|
||||
|
@ -198,23 +218,23 @@ func Update(columns map[string]interface{}) string {
|
|||
sort.Strings(names)
|
||||
|
||||
for i, c := range names {
|
||||
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
|
||||
names[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
|
||||
}
|
||||
|
||||
return strings.Join(names, ",")
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
// SetParamNames takes a slice of columns and returns a comma seperated
|
||||
// list of parameter names for a template statement SET clause.
|
||||
// eg: col1=$1,col2=$2,col3=$3
|
||||
// eg: "col1"=$1, "col2"=$2, "col3"=$3
|
||||
func SetParamNames(columns []string) string {
|
||||
names := make([]string, 0, len(columns))
|
||||
counter := 0
|
||||
for _, c := range columns {
|
||||
counter++
|
||||
names = append(names, fmt.Sprintf("%s=$%d", c, counter))
|
||||
names = append(names, fmt.Sprintf(`"%s"=$%d`, c, counter))
|
||||
}
|
||||
return strings.Join(names, ",")
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
// WherePrimaryKey returns the where clause using start as the $ flag index
|
||||
|
@ -222,7 +242,7 @@ func SetParamNames(columns []string) string {
|
|||
func WherePrimaryKey(start int, pkeys ...string) string {
|
||||
var output string
|
||||
for i, c := range pkeys {
|
||||
output = fmt.Sprintf("%s%s=$%d", output, c, start)
|
||||
output = fmt.Sprintf(`%s"%s"=$%d`, output, c, start)
|
||||
start++
|
||||
|
||||
if i < len(pkeys)-1 {
|
||||
|
|
|
@ -90,6 +90,49 @@ func TestSetIntersect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSetMerge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
A []string
|
||||
B []string
|
||||
C []string
|
||||
}{
|
||||
{
|
||||
[]string{"thing1", "thing2", "thing3"},
|
||||
[]string{"thing1", "thing3", "thing4"},
|
||||
[]string{"thing1", "thing2", "thing3", "thing4"},
|
||||
},
|
||||
{
|
||||
[]string{},
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{"thing1", "thing2"},
|
||||
},
|
||||
{
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{},
|
||||
[]string{"thing1", "thing2"},
|
||||
},
|
||||
{
|
||||
[]string{"thing1", "thing2", "thing3"},
|
||||
[]string{"thing1", "thing2", "thing3"},
|
||||
[]string{"thing1", "thing2", "thing3"},
|
||||
},
|
||||
{
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{"thing3", "thing4"},
|
||||
[]string{"thing1", "thing2", "thing3", "thing4"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
m := SetMerge(test.A, test.B)
|
||||
if !reflect.DeepEqual(test.C, m) {
|
||||
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonZeroDefaultSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -180,10 +180,15 @@ func RandomizeStruct(str interface{}, colTypes map[string]string, includeInvalid
|
|||
fieldVal := value.Field(i)
|
||||
fieldTyp := typ.Field(i)
|
||||
|
||||
found := sort.Search(len(blacklist), func(i int) bool {
|
||||
return blacklist[i] == fieldTyp.Name
|
||||
})
|
||||
if found != len(blacklist) {
|
||||
var found bool
|
||||
for _, v := range blacklist {
|
||||
if strmangle.TitleCase(v) == fieldTyp.Name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -3,12 +3,10 @@
|
|||
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
|
||||
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
|
||||
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", "}}
|
||||
// Update a single {{$tableNameSingular}} record. It takes a whitelist of
|
||||
// column_name's that should be updated. The primary key will be used to find
|
||||
// the record to update.
|
||||
// WARNING: Update does NOT ignore nil members - only the whitelist can be used
|
||||
// to control the set of columns that will be saved.
|
||||
func (o *{{$tableNameSingular}}) Update(whitelist ... string) error {
|
||||
// Update a single {{$tableNameSingular}} record.
|
||||
// Update takes a whitelist of column names that should be updated.
|
||||
// The primary key will be used to find the record to update.
|
||||
func (o *{{$tableNameSingular}}) Update(whitelist ...string) error {
|
||||
return o.UpdateX(boil.GetDB(), whitelist...)
|
||||
}
|
||||
|
||||
|
@ -28,28 +26,24 @@ func (o *{{$tableNameSingular}}) UpdateAtX(exec boil.Executor, {{$pkArgs}}, whit
|
|||
return err
|
||||
}
|
||||
|
||||
if len(whitelist) == 0 {
|
||||
cols := {{$varNameSingular}}ColumnsWithoutDefault
|
||||
cols = append(boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), cols...)
|
||||
// Subtract primary keys and autoincrement columns
|
||||
cols = boil.SetComplement(cols, {{$varNameSingular}}PrimaryKeyColumns)
|
||||
cols = boil.SetComplement(cols, {{$varNameSingular}}AutoIncrementColumns)
|
||||
|
||||
whitelist = make([]string, len(cols))
|
||||
copy(whitelist, cols)
|
||||
}
|
||||
|
||||
var err error
|
||||
var query string
|
||||
if len(whitelist) != 0 {
|
||||
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}))
|
||||
_, err = exec.Exec(query, boil.GetStructValues(o, whitelist...), {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
|
||||
var values []interface{}
|
||||
|
||||
wl := o.generateUpdateColumns(whitelist...)
|
||||
|
||||
if len(wl) != 0 {
|
||||
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(wl), boil.WherePrimaryKey(len(wl)+1, {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}))
|
||||
values = boil.GetStructValues(o, wl...)
|
||||
values = append(values, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
|
||||
_, err = exec.Exec(query, values...)
|
||||
} else {
|
||||
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build a whitelist for row: %s", err)
|
||||
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist")
|
||||
}
|
||||
|
||||
if boil.DebugMode {
|
||||
fmt.Fprintln(boil.DebugWriter, query)
|
||||
fmt.Fprintln(boil.DebugWriter, values)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -73,3 +67,22 @@ func (q {{$varNameSingular}}Query) UpdateAll(cols M) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateUpdateColumns generates the whitelist columns for an update statement
|
||||
func (o *{{$tableNameSingular}}) generateUpdateColumns(whitelist ...string) []string {
|
||||
if len(whitelist) != 0 {
|
||||
return whitelist
|
||||
}
|
||||
|
||||
var wl []string
|
||||
cols := {{$varNameSingular}}ColumnsWithoutDefault
|
||||
cols = append(boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), cols...)
|
||||
// Subtract primary keys and autoincrement columns
|
||||
cols = boil.SetComplement(cols, {{$varNameSingular}}PrimaryKeyColumns)
|
||||
cols = boil.SetComplement(cols, {{$varNameSingular}}AutoIncrementColumns)
|
||||
|
||||
wl = make([]string, len(cols))
|
||||
copy(wl, cols)
|
||||
|
||||
return wl
|
||||
}
|
||||
|
|
|
@ -4,5 +4,40 @@
|
|||
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
|
||||
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
|
||||
func Test{{$tableNamePlural}}Update(t *testing.T) {
|
||||
t.Skip("test update not implemented")
|
||||
var err error
|
||||
|
||||
item := {{$tableNameSingular}}{}
|
||||
if err = item.Insert(); err != nil {
|
||||
t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err)
|
||||
}
|
||||
|
||||
blacklistCols := boil.SetMerge({{$varNameSingular}}AutoIncrementColumns, {{$varNameSingular}}PrimaryKeyColumns)
|
||||
if err = boil.RandomizeStruct(&item, {{$varNameSingular}}DBTypes, false, blacklistCols...); err != nil {
|
||||
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
|
||||
}
|
||||
|
||||
whitelist := boil.SetComplement({{$varNameSingular}}Columns, {{$varNameSingular}}AutoIncrementColumns)
|
||||
if err = item.Update(whitelist...); err != nil {
|
||||
t.Errorf("Unable to update {{$tableNameSingular}}: %s", err)
|
||||
}
|
||||
|
||||
var j *{{$tableNameSingular}}
|
||||
j, err = {{$tableNameSingular}}Find({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "item." | join ", "}})
|
||||
if err != nil {
|
||||
t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err)
|
||||
}
|
||||
|
||||
{{$varNameSingular}}CompareVals(&item, j, t)
|
||||
|
||||
wl := item.generateUpdateColumns("test")
|
||||
if len(wl) != 1 && wl[0] != "test" {
|
||||
t.Errorf("Expected generateUpdateColumns whitelist to match expected whitelist")
|
||||
}
|
||||
|
||||
wl = item.generateUpdateColumns()
|
||||
if len(wl) == 0 && len({{$varNameSingular}}ColumnsWithoutDefault) > 0 {
|
||||
t.Errorf("Expected generateUpdateColumns to build a whitelist for {{$tableNameSingular}}, but got 0 results")
|
||||
}
|
||||
|
||||
{{$varNamePlural}}DeleteAllRows(t)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue