Finish update tests
* Add SetMerge helper * Add quotes around all column names in statements * Fix blacklist bug in RandomizeStruct * Split up update helper
This commit is contained in:
parent
08feff45d7
commit
81883d5f75
6 changed files with 155 additions and 39 deletions
|
@ -51,6 +51,26 @@ func SetIntersect(a []string, b []string) []string {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMerge will return a merged slice without duplicates
|
||||||
|
func SetMerge(a []string, b []string) []string {
|
||||||
|
var x, merged []string
|
||||||
|
|
||||||
|
x = append(x, a...)
|
||||||
|
x = append(x, b...)
|
||||||
|
|
||||||
|
check := map[string]bool{}
|
||||||
|
for _, v := range x {
|
||||||
|
if check[v] == true {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
merged = append(merged, v)
|
||||||
|
check[v] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
// NonZeroDefaultSet returns the fields included in the
|
// NonZeroDefaultSet returns the fields included in the
|
||||||
// defaults slice that are non zero values
|
// defaults slice that are non zero values
|
||||||
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
|
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
|
||||||
|
@ -104,7 +124,7 @@ func GenerateParamFlags(colCount int, startAt int) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WherePrimaryKeyIn generates a "in" string for where queries
|
// WherePrimaryKeyIn generates a "in" string for where queries
|
||||||
// For example: (col1, col2) IN (($1, $2), ($3, $4))
|
// For example: ("col1","col2") IN (($1,$2), ($3,$4))
|
||||||
func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
|
func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
|
||||||
in := &bytes.Buffer{}
|
in := &bytes.Buffer{}
|
||||||
|
|
||||||
|
@ -147,7 +167,7 @@ func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectNames returns the column names for a select statement
|
// SelectNames returns the column names for a select statement
|
||||||
// Eg: col1, col2, col3
|
// Eg: "col1", "col2", "col3"
|
||||||
func SelectNames(results interface{}) string {
|
func SelectNames(results interface{}) string {
|
||||||
var names []string
|
var names []string
|
||||||
|
|
||||||
|
@ -164,14 +184,14 @@ func SelectNames(results interface{}) string {
|
||||||
name = goVarToSQLName(field.Name)
|
name = goVarToSQLName(field.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
names = append(names, name)
|
names = append(names, fmt.Sprintf(`"%s"`, name))
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(names, ", ")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// WhereClause returns the where clause for an sql statement
|
// WhereClause returns the where clause for an sql statement
|
||||||
// eg: col1=$1 AND col2=$2 AND col3=$3
|
// eg: "col1"=$1 AND "col2"=$2 AND "col3"=$3
|
||||||
func WhereClause(columns []string) string {
|
func WhereClause(columns []string) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
|
@ -180,14 +200,14 @@ func WhereClause(columns []string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, c := range names {
|
for i, c := range names {
|
||||||
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
|
names[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(names, " AND ")
|
return strings.Join(names, " AND ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update returns the column list for an update statement SET clause
|
// Update returns the column list for an update statement SET clause
|
||||||
// eg: col1=$1,col2=$2,col3=$3
|
// eg: "col1"=$1, "col2"=$2, "col3"=$3
|
||||||
func Update(columns map[string]interface{}) string {
|
func Update(columns map[string]interface{}) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
|
@ -198,23 +218,23 @@ func Update(columns map[string]interface{}) string {
|
||||||
sort.Strings(names)
|
sort.Strings(names)
|
||||||
|
|
||||||
for i, c := range names {
|
for i, c := range names {
|
||||||
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
|
names[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(names, ",")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetParamNames takes a slice of columns and returns a comma seperated
|
// SetParamNames takes a slice of columns and returns a comma seperated
|
||||||
// list of parameter names for a template statement SET clause.
|
// list of parameter names for a template statement SET clause.
|
||||||
// eg: col1=$1,col2=$2,col3=$3
|
// eg: "col1"=$1, "col2"=$2, "col3"=$3
|
||||||
func SetParamNames(columns []string) string {
|
func SetParamNames(columns []string) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
counter := 0
|
counter := 0
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
counter++
|
counter++
|
||||||
names = append(names, fmt.Sprintf("%s=$%d", c, counter))
|
names = append(names, fmt.Sprintf(`"%s"=$%d`, c, counter))
|
||||||
}
|
}
|
||||||
return strings.Join(names, ",")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// WherePrimaryKey returns the where clause using start as the $ flag index
|
// WherePrimaryKey returns the where clause using start as the $ flag index
|
||||||
|
@ -222,7 +242,7 @@ func SetParamNames(columns []string) string {
|
||||||
func WherePrimaryKey(start int, pkeys ...string) string {
|
func WherePrimaryKey(start int, pkeys ...string) string {
|
||||||
var output string
|
var output string
|
||||||
for i, c := range pkeys {
|
for i, c := range pkeys {
|
||||||
output = fmt.Sprintf("%s%s=$%d", output, c, start)
|
output = fmt.Sprintf(`%s"%s"=$%d`, output, c, start)
|
||||||
start++
|
start++
|
||||||
|
|
||||||
if i < len(pkeys)-1 {
|
if i < len(pkeys)-1 {
|
||||||
|
|
|
@ -90,6 +90,49 @@ func TestSetIntersect(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetMerge(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
A []string
|
||||||
|
B []string
|
||||||
|
C []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2", "thing3"},
|
||||||
|
[]string{"thing1", "thing3", "thing4"},
|
||||||
|
[]string{"thing1", "thing2", "thing3", "thing4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{},
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{},
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2", "thing3"},
|
||||||
|
[]string{"thing1", "thing2", "thing3"},
|
||||||
|
[]string{"thing1", "thing2", "thing3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{"thing3", "thing4"},
|
||||||
|
[]string{"thing1", "thing2", "thing3", "thing4"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
m := SetMerge(test.A, test.B)
|
||||||
|
if !reflect.DeepEqual(test.C, m) {
|
||||||
|
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNonZeroDefaultSet(t *testing.T) {
|
func TestNonZeroDefaultSet(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
|
@ -180,10 +180,15 @@ func RandomizeStruct(str interface{}, colTypes map[string]string, includeInvalid
|
||||||
fieldVal := value.Field(i)
|
fieldVal := value.Field(i)
|
||||||
fieldTyp := typ.Field(i)
|
fieldTyp := typ.Field(i)
|
||||||
|
|
||||||
found := sort.Search(len(blacklist), func(i int) bool {
|
var found bool
|
||||||
return blacklist[i] == fieldTyp.Name
|
for _, v := range blacklist {
|
||||||
})
|
if strmangle.TitleCase(v) == fieldTyp.Name {
|
||||||
if found != len(blacklist) {
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if found {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,12 +3,10 @@
|
||||||
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
|
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
|
||||||
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
|
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
|
||||||
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", "}}
|
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", "}}
|
||||||
// Update a single {{$tableNameSingular}} record. It takes a whitelist of
|
// Update a single {{$tableNameSingular}} record.
|
||||||
// column_name's that should be updated. The primary key will be used to find
|
// Update takes a whitelist of column names that should be updated.
|
||||||
// the record to update.
|
// The primary key will be used to find the record to update.
|
||||||
// WARNING: Update does NOT ignore nil members - only the whitelist can be used
|
func (o *{{$tableNameSingular}}) Update(whitelist ...string) error {
|
||||||
// to control the set of columns that will be saved.
|
|
||||||
func (o *{{$tableNameSingular}}) Update(whitelist ... string) error {
|
|
||||||
return o.UpdateX(boil.GetDB(), whitelist...)
|
return o.UpdateX(boil.GetDB(), whitelist...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,28 +26,24 @@ func (o *{{$tableNameSingular}}) UpdateAtX(exec boil.Executor, {{$pkArgs}}, whit
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(whitelist) == 0 {
|
|
||||||
cols := {{$varNameSingular}}ColumnsWithoutDefault
|
|
||||||
cols = append(boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), cols...)
|
|
||||||
// Subtract primary keys and autoincrement columns
|
|
||||||
cols = boil.SetComplement(cols, {{$varNameSingular}}PrimaryKeyColumns)
|
|
||||||
cols = boil.SetComplement(cols, {{$varNameSingular}}AutoIncrementColumns)
|
|
||||||
|
|
||||||
whitelist = make([]string, len(cols))
|
|
||||||
copy(whitelist, cols)
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var query string
|
var query string
|
||||||
if len(whitelist) != 0 {
|
var values []interface{}
|
||||||
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}))
|
|
||||||
_, err = exec.Exec(query, boil.GetStructValues(o, whitelist...), {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
|
wl := o.generateUpdateColumns(whitelist...)
|
||||||
|
|
||||||
|
if len(wl) != 0 {
|
||||||
|
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(wl), boil.WherePrimaryKey(len(wl)+1, {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}))
|
||||||
|
values = boil.GetStructValues(o, wl...)
|
||||||
|
values = append(values, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
|
||||||
|
_, err = exec.Exec(query, values...)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build a whitelist for row: %s", err)
|
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist")
|
||||||
}
|
}
|
||||||
|
|
||||||
if boil.DebugMode {
|
if boil.DebugMode {
|
||||||
fmt.Fprintln(boil.DebugWriter, query)
|
fmt.Fprintln(boil.DebugWriter, query)
|
||||||
|
fmt.Fprintln(boil.DebugWriter, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -73,3 +67,22 @@ func (q {{$varNameSingular}}Query) UpdateAll(cols M) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateUpdateColumns generates the whitelist columns for an update statement
|
||||||
|
func (o *{{$tableNameSingular}}) generateUpdateColumns(whitelist ...string) []string {
|
||||||
|
if len(whitelist) != 0 {
|
||||||
|
return whitelist
|
||||||
|
}
|
||||||
|
|
||||||
|
var wl []string
|
||||||
|
cols := {{$varNameSingular}}ColumnsWithoutDefault
|
||||||
|
cols = append(boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), cols...)
|
||||||
|
// Subtract primary keys and autoincrement columns
|
||||||
|
cols = boil.SetComplement(cols, {{$varNameSingular}}PrimaryKeyColumns)
|
||||||
|
cols = boil.SetComplement(cols, {{$varNameSingular}}AutoIncrementColumns)
|
||||||
|
|
||||||
|
wl = make([]string, len(cols))
|
||||||
|
copy(wl, cols)
|
||||||
|
|
||||||
|
return wl
|
||||||
|
}
|
||||||
|
|
|
@ -4,5 +4,40 @@
|
||||||
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
|
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
|
||||||
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
|
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
|
||||||
func Test{{$tableNamePlural}}Update(t *testing.T) {
|
func Test{{$tableNamePlural}}Update(t *testing.T) {
|
||||||
t.Skip("test update not implemented")
|
var err error
|
||||||
|
|
||||||
|
item := {{$tableNameSingular}}{}
|
||||||
|
if err = item.Insert(); err != nil {
|
||||||
|
t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
blacklistCols := boil.SetMerge({{$varNameSingular}}AutoIncrementColumns, {{$varNameSingular}}PrimaryKeyColumns)
|
||||||
|
if err = boil.RandomizeStruct(&item, {{$varNameSingular}}DBTypes, false, blacklistCols...); err != nil {
|
||||||
|
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
whitelist := boil.SetComplement({{$varNameSingular}}Columns, {{$varNameSingular}}AutoIncrementColumns)
|
||||||
|
if err = item.Update(whitelist...); err != nil {
|
||||||
|
t.Errorf("Unable to update {{$tableNameSingular}}: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var j *{{$tableNameSingular}}
|
||||||
|
j, err = {{$tableNameSingular}}Find({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "item." | join ", "}})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
{{$varNameSingular}}CompareVals(&item, j, t)
|
||||||
|
|
||||||
|
wl := item.generateUpdateColumns("test")
|
||||||
|
if len(wl) != 1 && wl[0] != "test" {
|
||||||
|
t.Errorf("Expected generateUpdateColumns whitelist to match expected whitelist")
|
||||||
|
}
|
||||||
|
|
||||||
|
wl = item.generateUpdateColumns()
|
||||||
|
if len(wl) == 0 && len({{$varNameSingular}}ColumnsWithoutDefault) > 0 {
|
||||||
|
t.Errorf("Expected generateUpdateColumns to build a whitelist for {{$tableNameSingular}}, but got 0 results")
|
||||||
|
}
|
||||||
|
|
||||||
|
{{$varNamePlural}}DeleteAllRows(t)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue