Add ReloadAll for ObjectSlice

* Fix RandomizeSlice bug overwriting blacklisted columns
* Add GroupAt param to param flags generator
This commit is contained in:
Patrick O'brien 2016-08-08 23:30:29 +10:00
parent 6fc2ad8760
commit fa8e431349
8 changed files with 179 additions and 12 deletions

View file

@ -119,8 +119,15 @@ Outer:
// GenerateParamFlags generates the SQL statement parameter flags // GenerateParamFlags generates the SQL statement parameter flags
// For example, $1,$2,$3 etc. It will start counting at startAt. // For example, $1,$2,$3 etc. It will start counting at startAt.
func GenerateParamFlags(colCount int, startAt int) string { //
return strmangle.GenerateParamFlags(colCount, startAt) // If GroupAt is greater than 1, instead of returning $1,$2,$3
// it will return wrapped groups of param flags, for example:
//
// GroupAt(1): $1,$2,$3,$4,$5,$6
// GroupAt(2): ($1,$2),($3,$4),($5,$6)
// GroupAt(3): ($1,$2,$3),($4,$5,$6),($7,$8,$9)
func GenerateParamFlags(colCount int, startAt int, groupAt int) string {
return strmangle.GenerateParamFlags(colCount, startAt, groupAt)
} }
// WherePrimaryKeyIn generates a "in" string for where queries // WherePrimaryKeyIn generates a "in" string for where queries

View file

@ -140,7 +140,10 @@ func RandomizeSlice(obj interface{}, colTypes map[string]string, includeInvalid
for i := 0; i < ptrSlice.Len(); i++ { for i := 0; i < ptrSlice.Len(); i++ {
o := ptrSlice.Index(i) o := ptrSlice.Index(i)
o.Set(reflect.New(structTyp)) if o.IsNil() {
o.Set(reflect.New(structTyp))
}
if err := RandomizeStruct(o.Interface(), colTypes, includeInvalid, blacklist...); err != nil { if err := RandomizeStruct(o.Interface(), colTypes, includeInvalid, blacklist...); err != nil {
return err return err
} }

View file

@ -5,6 +5,7 @@
package strmangle package strmangle
import ( import (
"bytes"
"fmt" "fmt"
"math" "math"
"regexp" "regexp"
@ -208,14 +209,38 @@ func PrefixStringSlice(str string, strs []string) []string {
// GenerateParamFlags generates the SQL statement parameter flags // GenerateParamFlags generates the SQL statement parameter flags
// For example, $1,$2,$3 etc. It will start counting at startAt. // For example, $1,$2,$3 etc. It will start counting at startAt.
func GenerateParamFlags(colCount int, startAt int) string { //
cols := make([]string, 0, colCount) // If GroupAt is greater than 1, instead of returning $1,$2,$3
// it will return wrapped groups of param flags, for example:
//
// GroupAt(1): $1,$2,$3,$4,$5,$6
// GroupAt(2): ($1,$2),($3,$4),($5,$6)
// GroupAt(3): ($1,$2,$3),($4,$5,$6),($7,$8,$9)
func GenerateParamFlags(colCount int, startAt int, groupAt int) string {
var buf bytes.Buffer
for i := startAt; i < colCount+startAt; i++ { if groupAt > 1 {
cols = append(cols, fmt.Sprintf("$%d", i)) buf.WriteByte('(')
} }
return strings.Join(cols, ",") groupCounter := 0
for i := startAt; i < colCount+startAt; i++ {
groupCounter++
buf.WriteString(fmt.Sprintf("$%d", i))
if i+1 != colCount+startAt {
if groupAt > 1 && groupCounter == groupAt {
buf.WriteString("),(")
groupCounter = 0
} else {
buf.WriteByte(',')
}
}
}
if groupAt > 1 {
buf.WriteByte(')')
}
return buf.String()
} }
// WhereClause returns the where clause using start as the $ flag index // WhereClause returns the where clause using start as the $ flag index

View file

@ -83,8 +83,32 @@ func TestDriverUsesLastInsertID(t *testing.T) {
func TestGenerateParamFlags(t *testing.T) { func TestGenerateParamFlags(t *testing.T) {
t.Parallel() t.Parallel()
x := GenerateParamFlags(5, 1) x := GenerateParamFlags(1, 2, 1)
want := "$1,$2,$3,$4,$5" want := "$2"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = GenerateParamFlags(5, 1, 1)
want = "$1,$2,$3,$4,$5"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = GenerateParamFlags(6, 1, 2)
want = "($1,$2),($3,$4),($5,$6)"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = GenerateParamFlags(9, 1, 3)
want = "($1,$2,$3),($4,$5,$6),($7,$8,$9)"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = GenerateParamFlags(7, 1, 3)
want = "($1,$2,$3),($4,$5,$6),($7)"
if want != x { if want != x {
t.Errorf("want %s, got %s", want, x) t.Errorf("want %s, got %s", want, x)
} }

View file

@ -38,7 +38,7 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string
return err return err
} }
ins := fmt.Sprintf(`INSERT INTO {{.Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), boil.GenerateParamFlags(len(wl), 1)) ins := fmt.Sprintf(`INSERT INTO {{.Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), boil.GenerateParamFlags(len(wl), 1, 1))
{{if driverUsesLastInsertID .DriverName}} {{if driverUsesLastInsertID .DriverName}}
if len(returnColumns) != 0 { if len(returnColumns) != 0 {

View file

@ -104,7 +104,7 @@ func (o *{{$tableNameSingular}}) generateUpsertQuery(update bool, columns upsert
query = fmt.Sprintf( query = fmt.Sprintf(
`INSERT INTO {{.Table.Name}} (%s) VALUES (%s) ON CONFLICT`, `INSERT INTO {{.Table.Name}} (%s) VALUES (%s) ON CONFLICT`,
strings.Join(columns.whitelist, `, `), strings.Join(columns.whitelist, `, `),
boil.GenerateParamFlags(len(columns.whitelist), 1), boil.GenerateParamFlags(len(columns.whitelist), 1, 1),
) )
if !update { if !update {

View file

@ -1,5 +1,6 @@
{{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
// ReloadGP refetches the object from the database and panics on error. // ReloadGP refetches the object from the database and panics on error.
func (o *{{$tableNameSingular}}) ReloadGP() { func (o *{{$tableNameSingular}}) ReloadGP() {
if err := o.ReloadG(); err != nil { if err := o.ReloadG(); err != nil {
@ -34,3 +35,61 @@ func (o *{{$tableNameSingular}}) Reload(exec boil.Executor) error {
*o = *ret *o = *ret
return nil return nil
} }
func (o *{{$tableNameSingular}}Slice) ReloadAllGP() {
if err := o.ReloadAllG(); err != nil {
panic(boil.WrapErr(err))
}
}
func (o *{{$tableNameSingular}}Slice) ReloadAllP(exec boil.Executor) {
if err := o.ReloadAll(exec); err != nil {
panic(boil.WrapErr(err))
}
}
func (o *{{$tableNameSingular}}Slice) ReloadAllG() error {
if o == nil {
return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all")
}
return o.ReloadAll(boil.GetDB())
}
// ReloadAll refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice.
func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error {
if len(*o) == 0 {
return nil
}
{{$varNamePlural}} := {{$tableNameSingular}}Slice{}
var args []interface{}
for i := 0; i < len(*o); i++ {
args = append(args, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "(*o)[i]." | join ", "}})
}
sql := fmt.Sprintf(
`select {{.Table.Name}}.* from {{.Table.Name}} where (%s) in (%s)`,
strings.Join({{$varNameSingular}}PrimaryKeyColumns, ","),
strmangle.GenerateParamFlags(len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
q := boil.SQL(sql, args...)
boil.SetExecutor(q, exec)
err := q.Bind(&{{$varNamePlural}})
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice: %v", err)
}
*o = {{$varNamePlural}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args)
}
return nil
}

View file

@ -40,3 +40,52 @@ func Test{{$tableNamePlural}}Reload(t *testing.T) {
{{$varNamePlural}}DeleteAllRows(t) {{$varNamePlural}}DeleteAllRows(t)
} }
func Test{{$tableNamePlural}}ReloadAll(t *testing.T) {
var err error
o1 := make({{$tableNameSingular}}Slice, 3)
o2 := make({{$tableNameSingular}}Slice, 3)
if err = boil.RandomizeSlice(&o1, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err)
}
for i := 0; i < len(o1); i++ {
if err = o1[i].InsertG(); err != nil {
t.Errorf("Unable to insert {{$tableNameSingular}}:\n%#v\nErr: %s", o1[i], err)
}
}
for i := 0; i < len(o1); i++ {
o2[i], err = {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o1[i]." | join ", "}})
if err != nil {
t.Errorf("Unable to find {{$tableNameSingular}} row.")
}
{{$varNameSingular}}CompareVals(o1[i], o2[i], t)
}
// Randomize the struct values again, except for the primary key values, so we can call update.
err = boil.RandomizeSlice(&o1, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...)
if err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} slice excluding primary keys: %s", err)
}
colsWithoutPrimKeys := boil.SetComplement({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns)
for i := 0; i < len(o1); i++ {
if err = o1[i].UpdateG(colsWithoutPrimKeys...); err != nil {
t.Errorf("Unable to update the {{$tableNameSingular}} row: %s", err)
}
}
if err = o2.ReloadAllG(); err != nil {
t.Errorf("Unable to reload {{$tableNameSingular}} object: %s", err)
}
for i := 0; i < len(o1); i++ {
{{$varNameSingular}}CompareVals(o2[i], o1[i], t)
}
{{$varNamePlural}}DeleteAllRows(t)
}