diff --git a/bdb/column.go b/bdb/column.go index 56d6ff4..6ce1424 100644 --- a/bdb/column.go +++ b/bdb/column.go @@ -97,7 +97,7 @@ func FilterColumnsByValidated(columns []Column) []Column { var cols []Column for _, c := range columns { - if c.Validated == true { + if c.Validated { cols = append(cols, c) } } @@ -110,7 +110,7 @@ func FilterColumnsByUnique(columns []Column) []Column { var cols []Column for _, c := range columns { - if c.Unique == true { + if c.Unique { cols = append(cols, c) } } diff --git a/boil/query_builders.go b/boil/query_builders.go index 494ec29..187df78 100644 --- a/boil/query_builders.go +++ b/boil/query_builders.go @@ -320,7 +320,7 @@ func identifierMapping(q *Query) map[string]string { return ids } -// parseBits takes a set of tokens and looks for something of the form: +// parseIdentifierClause takes a set of tokens and looks for something of the form: // a b // a as b // where 'a' and 'b' are valid SQL identifiers diff --git a/imports.go b/imports.go index 37c8fb0..80463b1 100644 --- a/imports.go +++ b/imports.go @@ -157,9 +157,14 @@ var defaultTemplateImports = imports{ var defaultSingletonTemplateImports = map[string]imports{ "boil_helpers": { + standard: importList{ + `"fmt"`, + `"strings"`, + }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/boil/qm"`, + `"github.com/vattle/sqlboiler/strmangle"`, }, }, "boil_types": { diff --git a/strmangle/sets.go b/strmangle/sets.go index 2901adb..b60dc39 100644 --- a/strmangle/sets.go +++ b/strmangle/sets.go @@ -1,5 +1,43 @@ package strmangle +// UpdateColumnSet generates the set of columns to update for an update statement. +// if a whitelist is supplied, it's returned +// if a whitelist is missing then we begin with all columns +// then we remove the primary key columns +func UpdateColumnSet(allColumns, pkeyCols, whitelist []string) []string { + if len(whitelist) != 0 { + return whitelist + } + + return SetComplement(allColumns, pkeyCols) +} + +// InsertColumnSet generates the set of columns to insert and return for an insert statement +// the return columns are used to get values that are assigned within the database during the +// insert to keep the struct in sync with what's in the db. +// with a whitelist: +// - the whitelist is used for the insert columns +// - the return columns are the result of (columns with default values - the whitelist) +// without a whitelist: +// - start with columns without a default as these always need to be inserted +// - add all columns that have a default in the database but that are non-zero in the struct +// - the return columns are the result of (columns with default values - the previous set) +func InsertColumnSet(cols, defaults, noDefaults, nonZeroDefaults, whitelist []string) ([]string, []string) { + if len(whitelist) > 0 { + return whitelist, SetComplement(defaults, whitelist) + } + + var wl []string + wl = append(wl, noDefaults...) + wl = SetMerge(nonZeroDefaults, wl) + wl = SortByKeys(cols, wl) + + // Only return the columns with default values that are not in the insert whitelist + rc := SetComplement(defaults, wl) + + return wl, rc +} + // SetInclude checks to see if the string is found in the string slice func SetInclude(str string, slice []string) bool { for _, s := range slice { diff --git a/strmangle/sets_test.go b/strmangle/sets_test.go index 7bbc389..34e5a22 100644 --- a/strmangle/sets_test.go +++ b/strmangle/sets_test.go @@ -5,6 +5,86 @@ import ( "testing" ) +func TestUpdateColumnSet(t *testing.T) { + t.Parallel() + + tests := []struct { + Cols []string + PKeys []string + Whitelist []string + Out []string + }{ + {Cols: []string{"a", "b"}, PKeys: []string{"a"}, Out: []string{"b"}}, + {Cols: []string{"a", "b"}, PKeys: []string{"a"}, Whitelist: []string{"a"}, Out: []string{"a"}}, + {Cols: []string{"a", "b"}, PKeys: []string{"a"}, Whitelist: []string{"a", "b"}, Out: []string{"a", "b"}}, + } + + for i, test := range tests { + set := UpdateColumnSet(test.Cols, test.PKeys, test.Whitelist) + + if !reflect.DeepEqual(set, test.Out) { + t.Errorf("%d) set was wrong\nwant: %v\ngot: %v", i, test.Out, set) + } + } +} + +func TestInsertColumnSet(t *testing.T) { + t.Parallel() + + columns := []string{"a", "b", "c"} + defaults := []string{"a", "c"} + nodefaults := []string{"b"} + + tests := []struct { + Cols []string + Defaults []string + NoDefaults []string + NonZeroDefaults []string + Whitelist []string + Set []string + Ret []string + }{ + // No whitelist + {Set: []string{"b"}, Ret: []string{"a", "c"}}, + {Defaults: []string{}, NoDefaults: []string{"a", "b", "c"}, Set: []string{"a", "b", "c"}, Ret: []string{}}, + + // No whitelist + Nonzero defaults + {NonZeroDefaults: []string{"a"}, Set: []string{"a", "b"}, Ret: []string{"c"}}, + {NonZeroDefaults: []string{"c"}, Set: []string{"b", "c"}, Ret: []string{"a"}}, + + // Whitelist + {Whitelist: []string{"a"}, Set: []string{"a"}, Ret: []string{"c"}}, + {Whitelist: []string{"c"}, Set: []string{"c"}, Ret: []string{"a"}}, + {Whitelist: []string{"a", "c"}, Set: []string{"a", "c"}, Ret: []string{}}, + {Whitelist: []string{"a", "b", "c"}, Set: []string{"a", "b", "c"}, Ret: []string{}}, + + // Whitelist + Nonzero defaults (shouldn't care, same results as above) + {Whitelist: []string{"a"}, NonZeroDefaults: []string{"c"}, Set: []string{"a"}, Ret: []string{"c"}}, + {Whitelist: []string{"c"}, NonZeroDefaults: []string{"b"}, Set: []string{"c"}, Ret: []string{"a"}}, + } + + for i, test := range tests { + if test.Cols == nil { + test.Cols = columns + } + if test.Defaults == nil { + test.Defaults = defaults + } + if test.NoDefaults == nil { + test.NoDefaults = nodefaults + } + + set, ret := InsertColumnSet(test.Cols, test.Defaults, test.NoDefaults, test.NonZeroDefaults, test.Whitelist) + + if !reflect.DeepEqual(set, test.Set) { + t.Errorf("%d) set was wrong\nwant: %v\ngot: %v", i, test.Set, set) + } + if !reflect.DeepEqual(ret, test.Ret) { + t.Errorf("%d) ret was wrong\nwant: %v\ngot: %v", i, test.Ret, ret) + } + } +} + func TestSetInclude(t *testing.T) { t.Parallel() diff --git a/templates/08_insert.tpl b/templates/08_insert.tpl index 073354d..f1fc32d 100644 --- a/templates/08_insert.tpl +++ b/templates/08_insert.tpl @@ -31,7 +31,13 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") } - wl, returnColumns := o.generateInsertColumns(whitelist...) + wl, returnColumns := strmangle.InsertColumnSet( + {{$varNameSingular}}Columns, + {{$varNameSingular}}ColumnsWithDefault, + {{$varNameSingular}}ColumnsWithoutDefault, + boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), + whitelist, + ) var err error if err := o.doBeforeCreateHooks(); err != nil { @@ -85,31 +91,3 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string return o.doAfterCreateHooks() } - -// generateInsertColumns generates the whitelist columns and return columns for an insert statement -// the return columns are used to get values that are assigned within the database during the -// insert to keep the struct in sync with what's in the db. -// with a whitelist: -// - the whitelist is used for the insert columns -// - the return columns are the result of (columns with default values - the whitelist) -// without a whitelist: -// - start with columns without a default as these always need to be inserted -// - add all columns that have a default in the database but that are non-zero in the struct -// - the return columns are the result of (columns with default values - the previous set) -func (o *{{$tableNameSingular}}) generateInsertColumns(whitelist ...string) ([]string, []string) { - if len(whitelist) > 0 { - return whitelist, strmangle.SetComplement({{$varNameSingular}}ColumnsWithDefault, whitelist) - } - - var wl []string - - wl = append(wl, {{$varNameSingular}}ColumnsWithoutDefault...) - - wl = strmangle.SetMerge(boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), wl) - wl = strmangle.SortByKeys({{$varNameSingular}}Columns, wl) - - // Only return the columns with default values that are not in the insert whitelist - rc := strmangle.SetComplement({{$varNameSingular}}ColumnsWithDefault, wl) - - return wl, rc -} diff --git a/templates/09_update.tpl b/templates/09_update.tpl index f5ceaa9..2a0a252 100644 --- a/templates/09_update.tpl +++ b/templates/09_update.tpl @@ -41,7 +41,7 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string var query string var values []interface{} - wl := o.generateUpdateColumns(whitelist...) + wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) if len(wl) == 0 { return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist") @@ -145,15 +145,3 @@ func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error return nil } - -// generateUpdateColumns generates the whitelist columns for an update statement -// if a whitelist is supplied, it's returned -// if a whitelist is missing then we begin with all columns -// then we remove the primary key columns -func (o *{{$tableNameSingular}}) generateUpdateColumns(whitelist ...string) []string { - if len(whitelist) != 0 { - return whitelist - } - - return strmangle.SetComplement({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) -} diff --git a/templates/10_upsert.tpl b/templates/10_upsert.tpl index 542c8af..09d26e9 100644 --- a/templates/10_upsert.tpl +++ b/templates/10_upsert.tpl @@ -21,13 +21,31 @@ func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, update bool, confli } // Upsert attempts an insert using an executor, and does an update or ignore on conflict. -func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, update bool, conflictColumns []string, updateColumns []string, whitelist ...string) error { +func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) error { if o == nil { return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") } - columns := o.generateUpsertColumns(conflictColumns, updateColumns, whitelist) - query := o.generateUpsertQuery(update, columns) + var ret []string + whitelist, ret = strmangle.InsertColumnSet( + {{$varNameSingular}}Columns, + {{$varNameSingular}}ColumnsWithDefault, + {{$varNameSingular}}ColumnsWithoutDefault, + boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), + whitelist, + ) + update := strmangle.UpdateColumnSet( + {{$varNameSingular}}Columns, + {{$varNameSingular}}PrimaryKeyColumns, + updateColumns, + ) + conflict := conflictColumns + if len(conflict) == 0 { + conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) + copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) + } + + query := generateUpsertQuery("{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) var err error if err := o.doBeforeUpsertHooks(); err != nil { @@ -36,14 +54,18 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, update bool, conflic if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, columns.whitelist...)) + fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...)) } - if len(columns.returning) != 0 { - err = exec.QueryRow(query, boil.GetStructValues(o, columns.whitelist...)...).Scan(boil.GetStructPointers(o, columns.returning...)...) + {{- if .UseLastInsertID}} + return errors.New("don't know how to do this yet") + {{- else}} + if len(ret) != 0 { + err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...) } else { - _, err = exec.Exec(query, {{.Table.Columns | columnNames | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) + _, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...) } + {{- end}} if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") @@ -55,69 +77,3 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, update bool, conflic return nil } - -// generateUpsertColumns builds an upsertData object, using generated values when necessary. -func (o *{{$tableNameSingular}}) generateUpsertColumns(conflict []string, update []string, whitelist []string) upsertData { - var upsertCols upsertData - - upsertCols.whitelist, upsertCols.returning = o.generateInsertColumns(whitelist...) - - upsertCols.conflict = make([]string, len(conflict)) - upsertCols.update = make([]string, len(update)) - - // generates the ON CONFLICT() columns if none are provided - upsertCols.conflict = o.generateConflictColumns(conflict...) - - // generate the UPDATE SET columns if none are provided - upsertCols.update = o.generateUpdateColumns(update...) - - return upsertCols -} - -// generateConflictColumns returns the user provided columns. -// If no columns are provided, it returns the primary key columns. -func (o *{{$tableNameSingular}}) generateConflictColumns(columns ...string) []string { - if len(columns) != 0 { - return columns - } - - c := make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) - copy(c, {{$varNameSingular}}PrimaryKeyColumns) - - return c -} - -// generateUpsertQuery builds a SQL statement string using the upsertData provided. -func (o *{{$tableNameSingular}}) generateUpsertQuery(update bool, columns upsertData) string { - var set, query string - - conflict := strmangle.IdentQuoteSlice(columns.conflict) - whitelist := strmangle.IdentQuoteSlice(columns.whitelist) - returning := strmangle.IdentQuoteSlice(columns.returning) - - var sets []string - // Generate the UPDATE SET clause - for _, v := range columns.update { - quoted := strmangle.IdentQuote(v) - sets = append(sets, fmt.Sprintf("%s = EXCLUDED.%s", quoted, quoted)) - } - set = strings.Join(sets, ", ") - - query = fmt.Sprintf( - "INSERT INTO {{.Table.Name}} (%s) VALUES (%s) ON CONFLICT", - strings.Join(whitelist, ", "), - strmangle.Placeholders(len(whitelist), 1, 1), - ) - - if !update { - query = query + " DO NOTHING" - } else { - query = fmt.Sprintf("%s (%s) DO UPDATE SET %s", query, strings.Join(conflict, ", "), set) - } - - if len(columns.returning) != 0 { - query = fmt.Sprintf("%s RETURNING %s", query, strings.Join(returning, ", ")) - } - - return query -} diff --git a/templates/singleton/boil_helpers.tpl b/templates/singleton/boil_helpers.tpl index 6cb607d..7df71a2 100644 --- a/templates/singleton/boil_helpers.tpl +++ b/templates/singleton/boil_helpers.tpl @@ -11,3 +11,46 @@ func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *boil.Query { return q } + +// generateUpsertQuery builds a SQL statement string using the upsertData provided. +func generateUpsertQuery(tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string { + conflict = strmangle.IdentQuoteSlice(conflict) + whitelist = strmangle.IdentQuoteSlice(whitelist) + ret = strmangle.IdentQuoteSlice(ret) + + buf := strmangle.GetBuffer() + defer strmangle.PutBuffer(buf) + + fmt.Fprintf( + buf, + "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT ", + tableName, + strings.Join(whitelist, ", "), + strmangle.Placeholders(len(whitelist), 1, 1), + ) + + if !updateOnConflict { + buf.WriteString("DO NOTHING") + } else { + buf.WriteByte('(') + buf.WriteString(strings.Join(conflict, ", ")) + buf.WriteString(") DO UPDATE SET") + + for i, v := range update { + if i != 0 { + buf.WriteByte(',') + } + quoted := strmangle.IdentQuote(v) + buf.WriteString(quoted) + buf.WriteString(" = EXCLUDED.") + buf.WriteString(quoted) + } + } + + if len(ret) != 0 { + buf.WriteString(" RETURNING ") + buf.WriteString(strings.Join(ret, ", ")) + } + + return buf.String() +} diff --git a/templates/singleton/boil_types.tpl b/templates/singleton/boil_types.tpl index 52bdb78..b05ab34 100644 --- a/templates/singleton/boil_types.tpl +++ b/templates/singleton/boil_types.tpl @@ -1,13 +1,6 @@ // M type is for providing columns and column values to UpdateAll. type M map[string]interface{} -type upsertData struct { - conflict []string - update []string - whitelist []string - returning []string -} - // ErrSyncFail occurs during insert when the record could not be retrieved in // order to populate default value information. This usually happens when LastInsertId // fails or there was a primary key configuration that was not resolvable. diff --git a/templates_test/insert.tpl b/templates_test/insert.tpl index d26740e..97ba512 100644 --- a/templates_test/insert.tpl +++ b/templates_test/insert.tpl @@ -20,8 +20,13 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) { t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err) } + tx, err := boil.Begin() + if err != nil { + t.Fatal(err) + } + for i := 0; i < len(o); i++ { - if err = o[i].InsertG(); err != nil { + if err = o[i].Insert(tx); err != nil { t.Errorf("Unable to insert {{$tableNameSingular}}:\n%#v\nErr: %s", o[i], err) } } @@ -29,20 +34,25 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) { j := make({{$tableNameSingular}}Slice, 3) // Perform all Find queries and assign result objects to slice for comparison for i := 0; i < len(o); i++ { - j[i], err = {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o[i]." | join ", "}}) + j[i], err = {{$tableNameSingular}}Find(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o[i]." | join ", "}}) if err != nil { t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err) } - err = {{$varNameSingular}}CompareVals(o[i], j[i], true); if err != nil { + err = {{$varNameSingular}}CompareVals(o[i], j[i], true); if err != nil { t.Error(err) } } - {{$varNamePlural}}DeleteAllRows(t) + _ = tx.Rollback() + tx, err = boil.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() item := &{{$tableNameSingular}}{} boil.RandomizeValidatedStruct(item, {{$varNameSingular}}ValidatedColumns, {{$varNameSingular}}DBTypes) - if err = item.InsertG(); err != nil { + if err = item.Insert(tx); err != nil { t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err) } @@ -87,30 +97,4 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) { t.Errorf("Expected column %s to be zero value, got: %v, wanted: %v", c, fv, zv) } } - - item = &{{$tableNameSingular}}{} - - wl, rc := item.generateInsertColumns() - if !reflect.DeepEqual(rc, {{$varNameSingular}}ColumnsWithDefault) { - t.Errorf("Expected return columns to contain all columns with default values:\n\nGot: %v\nWanted: %v", rc, {{$varNameSingular}}ColumnsWithDefault) - } - - if !reflect.DeepEqual(wl, {{$varNameSingular}}ColumnsWithoutDefault) { - t.Errorf("Expected whitelist to contain all columns without default values:\n\nGot: %v\nWanted: %v", wl, {{$varNameSingular}}ColumnsWithoutDefault) - } - - if err = boil.RandomizeStruct(item, {{$varNameSingular}}DBTypes, false); err != nil { - t.Errorf("Unable to randomize item: %s", err) - } - - wl, rc = item.generateInsertColumns() - if len(rc) > 0 { - t.Errorf("Expected return columns to contain no columns:\n\nGot: %v", rc) - } - - if !reflect.DeepEqual(wl, {{$varNameSingular}}Columns) { - t.Errorf("Expected whitelist to contain all columns values:\n\nGot: %v\nWanted: %v", wl, {{$varNameSingular}}Columns) - } - - {{$varNamePlural}}DeleteAllRows(t) } diff --git a/templates_test/update.tpl b/templates_test/update.tpl index 98ae0c4..30cb505 100644 --- a/templates_test/update.tpl +++ b/templates_test/update.tpl @@ -31,16 +31,6 @@ func Test{{$tableNamePlural}}Update(t *testing.T) { t.Error(err) } - wl := item.generateUpdateColumns("test") - if len(wl) != 1 && wl[0] != "test" { - t.Errorf("Expected generateUpdateColumns whitelist to match expected whitelist") - } - - wl = item.generateUpdateColumns() - if len(wl) == 0 && len({{$varNameSingular}}ColumnsWithoutDefault) > 0 { - t.Errorf("Expected generateUpdateColumns to build a whitelist for {{$tableNameSingular}}, but got 0 results") - } - {{$varNamePlural}}DeleteAllRows(t) } diff --git a/templates_test/upsert.tpl b/templates_test/upsert.tpl index 83005f3..7d9c115 100644 --- a/templates_test/upsert.tpl +++ b/templates_test/upsert.tpl @@ -7,71 +7,22 @@ func Test{{$tableNamePlural}}Upsert(t *testing.T) { o := {{$tableNameSingular}}{} - columns := o.generateUpsertColumns([]string{"one", "two"}, []string{"three", "four"}, []string{"five", "six"}) - if columns.conflict[0] != "one" || columns.conflict[1] != "two" { - t.Errorf("Expected conflict to be %v, got %v", []string{"one", "two"}, columns.conflict) - } - - if columns.update[0] != "three" || columns.update[1] != "four" { - t.Errorf("Expected update to be %v, got %v", []string{"three", "four"}, columns.update) - } - - if columns.whitelist[0] != "five" || columns.whitelist[1] != "six" { - t.Errorf("Expected whitelist to be %v, got %v", []string{"five", "six"}, columns.whitelist) - } - - columns = o.generateUpsertColumns(nil, nil, nil) - if len(columns.whitelist) == 0 { - t.Errorf("Expected whitelist to contain columns, but got len 0") - } - - if len(columns.conflict) == 0 { - t.Errorf("Expected conflict to contain columns, but got len 0") - } - - if len(columns.update) == 0 { - t.Errorf("expected update to contain columns, but got len 0") - } - - upsertCols := upsertData{ - conflict: []string{"key1", `"key2"`}, - update: []string{"aaa", `"bbb"`}, - whitelist: []string{"thing", `"stuff"`}, - returning: []string{}, - } - - query := o.generateUpsertQuery(false, upsertCols) - expectedQuery := `INSERT INTO {{.Table.Name}} ("thing", "stuff") VALUES ($1,$2) ON CONFLICT DO NOTHING` - - if query != expectedQuery { - t.Errorf("Expected query mismatch:\n\n%s\n%s\n", query, expectedQuery) - } - - query = o.generateUpsertQuery(true, upsertCols) - expectedQuery = `INSERT INTO {{.Table.Name}} ("thing", "stuff") VALUES ($1,$2) ON CONFLICT ("key1", "key2") DO UPDATE SET "aaa" = EXCLUDED."aaa", "bbb" = EXCLUDED."bbb"` - - if query != expectedQuery { - t.Errorf("Expected query mismatch:\n\n%s\n%s\n", query, expectedQuery) - } - - upsertCols.returning = []string{"stuff"} - query = o.generateUpsertQuery(true, upsertCols) - expectedQuery = expectedQuery + ` RETURNING "stuff"` - - if query != expectedQuery { - t.Errorf("Expected query mismatch:\n\n%s\n%s\n", query, expectedQuery) - } - // Attempt the INSERT side of an UPSERT if err = boil.RandomizeStruct(&o, {{$varNameSingular}}DBTypes, true); err != nil { t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) } - if err = o.UpsertG(false, nil, nil); err != nil { + tx, err := boil.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + if err = o.Upsert(tx, false, nil, nil); err != nil { t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) } - compare, err := {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) + compare, err := {{$tableNameSingular}}Find(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) if err != nil { t.Errorf("Unable to find {{$tableNameSingular}}: %s", err) } @@ -84,17 +35,15 @@ func Test{{$tableNamePlural}}Upsert(t *testing.T) { t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) } - if err = o.UpsertG(true, nil, nil); err != nil { + if err = o.Upsert(tx, true, nil, nil); err != nil { t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) } - compare, err = {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) + compare, err = {{$tableNameSingular}}Find(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) if err != nil { t.Errorf("Unable to find {{$tableNameSingular}}: %s", err) } err = {{$varNameSingular}}CompareVals(&o, compare, true); if err != nil { t.Error(err) } - - {{$varNamePlural}}DeleteAllRows(t) }