diff --git a/imports.go b/imports.go index 80463b1..fe23eee 100644 --- a/imports.go +++ b/imports.go @@ -180,6 +180,7 @@ var defaultTestTemplateImports = imports{ `"reflect"`, `"time"`, `"fmt"`, + `"bytes"`, }, thirdParty: importList{ `"github.com/pkg/errors"`, @@ -208,7 +209,6 @@ var defaultSingletonTestTemplateImports = map[string]imports{ `"os"`, `"strconv"`, `"math/rand"`, - `"bytes"`, }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, diff --git a/templates/09_update.tpl b/templates/09_update.tpl index 2a0a252..67278f6 100644 --- a/templates/09_update.tpl +++ b/templates/09_update.tpl @@ -32,6 +32,8 @@ func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... strin // No whitelist behavior: Without a whitelist, columns are inferred by the following rules: // - All columns are inferred to start with // - All primary keys are subtracted from this set +// Update does not automatically update the record in case of default values. Use .Reload() +// to refresh the records. func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string) error { if err := o.doBeforeUpdateHooks(); err != nil { return err @@ -42,7 +44,6 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string var values []interface{} wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) - if len(wl) == 0 { return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist") } @@ -56,11 +57,15 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string fmt.Fprintln(boil.DebugWriter, values) } - _, err = exec.Exec(query, values...) + result, err := exec.Exec(query, values...) if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to update {{.Table.Name}} row") } + if r, err := result.RowsAffected(); err == nil && r != 1 { + return errors.Errorf("failed to update single row, updated %d rows", r) + } + return o.doAfterUpdateHooks() } @@ -104,22 +109,23 @@ func (o {{$tableNameSingular}}Slice) UpdateAllP(exec boil.Executor, cols M) { // UpdateAll updates all rows with the specified column values, using an executor. func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for update all") - } - - if len(o) == 0 { + ln := int64(len(o)) + if ln == 0 { return nil } - colNames := make([]string, len(cols)) - var args []interface{} + if len(cols) == 0 { + return errors.New("{{.PkgName}}: update all requires at least one column argument") + } - count := 0 + colNames := make([]string, len(cols)) + args := make([]interface{}, len(cols)) + + i := 0 for name, value := range cols { - colNames[count] = strmangle.IdentQuote(name) - args = append(args, value) - count++ + colNames[i] = strmangle.IdentQuote(name) + args[i] = value + i++ } // Append all of the primary key values for each column @@ -138,10 +144,14 @@ func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error fmt.Fprintln(boil.DebugWriter, args...) } - _, err := exec.Exec(sql, args...) + result, err := exec.Exec(sql, args...) if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to update all in {{$varNameSingular}} slice") } + if r, err := result.RowsAffected(); err == nil && r != ln { + return errors.Errorf("failed to update %d rows, only affected %d", ln, r) + } + return nil } diff --git a/templates_test/helpers.tpl b/templates_test/helpers.tpl index 88ed76a..009e1df 100644 --- a/templates_test/helpers.tpl +++ b/templates_test/helpers.tpl @@ -4,6 +4,12 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} var {{$varNameSingular}}DBTypes = map[string]string{{"{"}}{{.Table.Columns | columnDBTypes | makeStringMap}}{{"}"}} +var ( + _ = bytes.Equal + _ = time.Second + _ = null.Bool.IsZero +) + func {{$varNameSingular}}CompareVals(o *{{$tableNameSingular}}, j *{{$tableNameSingular}}, equal bool, blacklist ...string) error { {{- range $key, $value := .Table.Columns -}} {{if eq $value.Type "null.Time"}} @@ -19,8 +25,8 @@ func {{$varNameSingular}}CompareVals(o *{{$tableNameSingular}}, j *{{$tableNameS return errors.New(fmt.Sprintf("Time {{$value.Name}} unexpected value, got:\nStruct: %#v\nResponse: %#v\n\n", o.{{titleCase $value.Name}}.Format("02/01/2006"), j.{{titleCase $value.Name}}.Format("02/01/2006"))) } {{else if eq $value.Type "[]byte"}} - if ((equal && !byteSliceEqual(o.{{titleCase $value.Name}}, j.{{titleCase $value.Name}})) || - (!equal && byteSliceEqual(o.{{titleCase $value.Name}}, j.{{titleCase $value.Name}}))) && + if ((equal && !bytes.Equal(o.{{titleCase $value.Name}}, j.{{titleCase $value.Name}})) || + (!equal && bytes.Equal(o.{{titleCase $value.Name}}, j.{{titleCase $value.Name}}))) && !strmangle.SetInclude("{{$value.Name}}", blacklist) { return errors.New(fmt.Sprintf("Expected {{$value.Name}} columns to match, got:\nStruct: %#v\nResponse: %#v\n\n", o.{{titleCase $value.Name}}, j.{{titleCase $value.Name}})) } diff --git a/templates_test/insert.tpl b/templates_test/insert.tpl index 97ba512..16bcfba 100644 --- a/templates_test/insert.tpl +++ b/templates_test/insert.tpl @@ -4,19 +4,8 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $parent := . -}} func Test{{$tableNamePlural}}Insert(t *testing.T) { - var err error - - var errs []error - _ = errs - - emptyTime := time.Time{}.String() - _ = emptyTime - - nullTime := null.NewTime(time.Time{}, true) - _ = nullTime - - o := make({{$tableNameSingular}}Slice, 3) - if err = boil.RandomizeSlice(&o, {{$varNameSingular}}DBTypes, true); err != nil { + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err := boil.RandomizeStruct({{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err) } @@ -24,77 +13,44 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) { if err != nil { t.Fatal(err) } + defer tx.Rollback() - for i := 0; i < len(o); i++ { - if err = o[i].Insert(tx); err != nil { - t.Errorf("Unable to insert {{$tableNameSingular}}:\n%#v\nErr: %s", o[i], err) - } + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) } - j := make({{$tableNameSingular}}Slice, 3) - // Perform all Find queries and assign result objects to slice for comparison - for i := 0; i < len(o); i++ { - j[i], err = {{$tableNameSingular}}Find(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o[i]." | join ", "}}) - if err != nil { - t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err) - } - err = {{$varNameSingular}}CompareVals(o[i], j[i], true); if err != nil { - t.Error(err) - } + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) } - _ = tx.Rollback() - tx, err = boil.Begin() + if count != 1 { + t.Error("want one record, got:", count) + } +} + +func Test{{$tableNamePlural}}InsertWhitelist(t *testing.T) { + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err := boil.RandomizeStruct({{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err) + } + + tx, err := boil.Begin() if err != nil { t.Fatal(err) } defer tx.Rollback() - item := &{{$tableNameSingular}}{} - boil.RandomizeValidatedStruct(item, {{$varNameSingular}}ValidatedColumns, {{$varNameSingular}}DBTypes) - if err = item.Insert(tx); err != nil { - t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err) + if err = {{$varNameSingular}}.Insert(tx, {{$varNameSingular}}Columns...); err != nil { + t.Error(err) } - for _, c := range {{$varNameSingular}}AutoIncrementColumns { - // Ensure the auto increment columns are returned in the object. - if errs = boil.IsZeroValue(item, false, c); errs != nil { - for _, e := range errs { - t.Errorf("Expected auto-increment columns to be greater than 0, err: %s\n", e) - } - } + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) } - defaultValues := []interface{}{{"{"}}{{.Table.Columns | filterColumnsBySimpleDefault | defaultValues | join ", "}}{{"}"}} - - // Ensure the simple default column values are returned correctly. - if len({{$varNameSingular}}ColumnsWithSimpleDefault) > 0 && len(defaultValues) > 0 { - if len({{$varNameSingular}}ColumnsWithSimpleDefault) != len(defaultValues) { - t.Fatalf("Mismatch between slice lengths: %d, %d", len({{$varNameSingular}}ColumnsWithSimpleDefault), len(defaultValues)) - } - - if errs = boil.IsValueMatch(item, {{$varNameSingular}}ColumnsWithSimpleDefault, defaultValues); errs != nil { - for _, e := range errs { - t.Errorf("Expected default value to match column value, err: %s\n", e); - } - } - } - - regularCols := []string{{"{"}}{{.Table.Columns | filterColumnsByAutoIncrement false | filterColumnsByDefault false | columnNames | stringMap $parent.StringFuncs.quoteWrap | join ", "}}{{"}"}} - - // Remove the validated columns, they can never be zero values - regularCols = strmangle.SetComplement(regularCols, {{$varNameSingular}}ValidatedColumns) - - // Ensure the non-defaultvalue columns and non-autoincrement columns are stored correctly as zero or null values. - for _, c := range regularCols { - rv := reflect.Indirect(reflect.ValueOf(item)) - field := rv.FieldByName(strmangle.TitleCase(c)) - - zv := reflect.Zero(field.Type()).Interface() - fv := field.Interface() - - if !reflect.DeepEqual(zv, fv) { - t.Errorf("Expected column %s to be zero value, got: %v, wanted: %v", c, fv, zv) - } + if count != 1 { + t.Error("want one record, got:", count) } } diff --git a/templates_test/singleton/boil_helpers.tpl b/templates_test/singleton/boil_helpers.tpl index e23874a..26f09d3 100644 --- a/templates_test/singleton/boil_helpers.tpl +++ b/templates_test/singleton/boil_helpers.tpl @@ -57,10 +57,3 @@ func getDBNameHash(input string) string { initDBNameRand(input) return randStr(40) } - -// byteSliceEqual calls bytes.Equal to check that two -// byte slices are equal. bytes.Equal is not used directly -// to avoid an unecessary conditional type import. -func byteSliceEqual(a []byte, b []byte) bool { - return bytes.Equal(a, b) -} diff --git a/templates_test/update.tpl b/templates_test/update.tpl index 30cb505..56080c1 100644 --- a/templates_test/update.tpl +++ b/templates_test/update.tpl @@ -3,95 +3,82 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func Test{{$tableNamePlural}}Update(t *testing.T) { - var err error - - item := {{$tableNameSingular}}{} - boil.RandomizeValidatedStruct(&item, {{$varNameSingular}}ValidatedColumns, {{$varNameSingular}}DBTypes) - if err = item.InsertG(); err != nil { - t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err) - } - - blacklistCols := strmangle.SetMerge({{$varNameSingular}}AutoIncrementColumns, {{$varNameSingular}}PrimaryKeyColumns) - if err = boil.RandomizeStruct(&item, {{$varNameSingular}}DBTypes, false, blacklistCols...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - whitelist := strmangle.SetComplement({{$varNameSingular}}Columns, {{$varNameSingular}}AutoIncrementColumns) - if err = item.UpdateG(whitelist...); err != nil { - t.Errorf("Unable to update {{$tableNameSingular}}: %s", err) - } - - var j *{{$tableNameSingular}} - j, err = {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "item." | join ", "}}) - if err != nil { - t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err) - } - - err = {{$varNameSingular}}CompareVals(&item, j, true); if err != nil { - t.Error(err) - } - - {{$varNamePlural}}DeleteAllRows(t) -} - -func Test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) { - var err error - - // insert random columns to test UpdateAll - o := make({{$tableNameSingular}}Slice, 3) - j := make({{$tableNameSingular}}Slice, 3) - - if err = boil.RandomizeSlice(&o, {{$varNameSingular}}DBTypes, false); err != nil { + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err := boil.RandomizeStruct({{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err) } - for i := 0; i < len(o); i++ { - if err = o[i].InsertG(); err != nil { - t.Errorf("Unable to insert {{$tableNameSingular}}:\n%#v\nErr: %s", o[i], err) - } - } - - vals := M{} - - tmp := {{$tableNameSingular}}{} - blacklist := strmangle.SetMerge({{$varNameSingular}}PrimaryKeyColumns, {{$varNameSingular}}UniqueColumns) - if err = boil.RandomizeStruct(&tmp, {{$varNameSingular}}DBTypes, false, blacklist...); err != nil { - t.Errorf("Unable to randomize struct {{$tableNameSingular}}: %s", err) - } - - // Build the columns and column values from the randomized struct - tmpVal := reflect.Indirect(reflect.ValueOf(tmp)) - nonBlacklist := strmangle.SetComplement({{$varNameSingular}}Columns, blacklist) - for _, col := range nonBlacklist { - vals[col] = tmpVal.FieldByName(strmangle.TitleCase(col)).Interface() - } - - err = o.UpdateAllG(vals) + tx, err := boil.Begin() if err != nil { - t.Errorf("Failed to update all for {{$tableNameSingular}}: %s", err) + t.Fatal(err) + } + defer tx.Rollback() + + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) } - for i := 0; i < len(o); i++ { - j[i], err = {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o[i]." | join ", "}}) - if err != nil { - t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err) - } - - err = {{$varNameSingular}}CompareVals(j[i], &tmp, true, blacklist...) - if err != nil { - t.Error(err) - } + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) } - for i := 0; i < len(o); i++ { - // Ensure Find found the correct primary key ID's - orig := boil.GetStructValues(o[i], {{$varNameSingular}}PrimaryKeyColumns...) - new := boil.GetStructValues(j[i], {{$varNameSingular}}PrimaryKeyColumns...) + if count != 1 { + t.Error("want one record, got:", count) + } - if !reflect.DeepEqual(orig, new) { - t.Errorf("object %d): primary keys do not match:\n\n%#v\n%#v", i, orig, new) - } - } + if err = boil.RandomizeStruct({{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err) + } - {{$varNamePlural}}DeleteAllRows(t) + if err = {{$varNameSingular}}.Update(tx); err != nil { + t.Error(err) + } +} + +func Test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) { + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err := boil.RandomizeStruct({{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err) + } + + tx, err := boil.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 1 { + t.Error("want one record, got:", count) + } + + if err = boil.RandomizeStruct({{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err) + } + + // Remove Primary keys and unique columns from what we plan to update + fields := strmangle.SetComplement( + {{$varNameSingular}}Columns, + strmangle.SetMerge({{$varNameSingular}}PrimaryKeyColumns, {{$varNameSingular}}UniqueColumns), + ) + + value := reflect.Indirect(reflect.ValueOf({{$varNameSingular}})) + updateMap := M{} + for _, col := range fields { + updateMap[col] = value.FieldByName(strmangle.TitleCase(col)).Interface() + } + + slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} + if err = slice.UpdateAll(tx, updateMap); err != nil { + t.Error(err) + } }