Finish UpdateAll query builder

* Add modifiers to delete builder
* Update golden file tests
* Add startAt to whereClause
This commit is contained in:
Patrick O'brien 2016-08-11 18:23:47 +10:00
parent c3f8cff117
commit e3f319346f
10 changed files with 240 additions and 40 deletions

View file

@ -1 +1 @@
DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE ((id=$1 and thing=$2) or stuff=$3);
DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE ((id=$1 and thing=$2) or stuff=$3) LIMIT 5;

1
boil/_fixtures/10.sql Normal file
View file

@ -0,0 +1 @@
UPDATE thing happy, "fun", "stuff" SET ("col1", "col2", "fun"."col3") VALUES ($1, $2, $3) WHERE (aa=$4 or bb=$5) OR (cc=$6) LIMIT 5;

View file

@ -70,9 +70,65 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
fmt.Fprintf(buf, " INNER JOIN %s", j.clause)
}
where, args := whereClause(q)
where, args := whereClause(q, 1)
buf.WriteString(where)
writeModifiers(q, buf)
buf.WriteByte(';')
return buf, args
}
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("DELETE FROM ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))
where, args := whereClause(q, 1)
buf.WriteString(where)
writeModifiers(q, buf)
buf.WriteByte(';')
return buf, args
}
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("UPDATE ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))
cols := make([]string, len(q.update))
args := make([]interface{}, len(q.update))
count := 0
for name, value := range q.update {
cols[count] = strmangle.IdentQuote(name)
args[count] = value
count++
}
buf.WriteString(fmt.Sprintf(
" SET (%s) VALUES (%s)",
strings.Join(cols, ", "),
strmangle.Placeholders(len(cols), 1, 1)),
)
where, whereArgs := whereClause(q, len(args)+1)
buf.WriteString(where)
args = append(args, whereArgs...)
writeModifiers(q, buf)
buf.WriteByte(';')
return buf, args
}
func writeModifiers(q *Query, buf *bytes.Buffer) {
if len(q.groupBy) != 0 {
fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", "))
}
@ -92,9 +148,6 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
if q.offset != 0 {
fmt.Fprintf(buf, " OFFSET %d", q.offset)
}
buf.WriteByte(';')
return buf, args
}
func writeStars(q *Query) []string {
@ -144,28 +197,12 @@ func writeAsStatements(q *Query) []string {
return cols
}
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("DELETE FROM ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))
where, args := whereClause(q)
buf.WriteString(where)
buf.WriteByte(';')
return buf, args
}
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteByte(';')
return buf, nil
}
func whereClause(q *Query) (string, []interface{}) {
// whereClause parses a where slice and converts it into a
// single WHERE clause like:
// WHERE (a=$1) AND (b=$2)
//
// startAt specifies what number placeholders start at
func whereClause(q *Query, startAt int) (string, []interface{}) {
if len(q.where) == 0 {
return "", nil
}
@ -211,7 +248,35 @@ func whereClause(q *Query) (string, []interface{}) {
paramIndex++
}
return paramBuf.String(), args
return convertQuestionMarks(buf.String(), startAt), args
}
func convertQuestionMarks(clause string, startAt int) string {
if startAt == 0 {
panic("Not a valid start number.")
}
paramBuf := &bytes.Buffer{}
paramIndex := 0
for ; ; startAt++ {
if paramIndex >= len(clause) {
break
}
clause = clause[paramIndex:]
paramIndex = strings.IndexByte(clause, '?')
if paramIndex == -1 {
paramBuf.WriteString(clause)
break
}
paramBuf.WriteString(clause[:paramIndex] + fmt.Sprintf("$%d", startAt))
paramIndex++
}
return paramBuf.String()
}
// identifierMapping creates a map of all identifiers to potential model names

View file

@ -59,7 +59,21 @@ func TestBuildQuery(t *testing.T) {
where: []where{
where{clause: "(id=? and thing=?) or stuff=?", args: []interface{}{}},
},
limit: 5,
}, nil},
{&Query{
from: []string{"thing happy", `"fun"`, `stuff`},
update: map[string]interface{}{
"col1": 1,
`"col2"`: 2,
`"fun".col3`: 3,
},
where: []where{
where{clause: "aa=? or bb=?", orSeparator: true, args: []interface{}{4, 5}},
where{clause: "cc=?", args: []interface{}{6}},
},
limit: 5,
}, []interface{}{1, 2, 3, 4, 5, 6}},
}
for i, test := range tests {
@ -297,7 +311,7 @@ func TestWhereClause(t *testing.T) {
}
for i, test := range tests {
result, _ := whereClause(&test.q)
result, _ := whereClause(&test.q, 1)
if result != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result)
}

View file

@ -208,20 +208,24 @@ func PrefixStringSlice(str string, strs []string) []string {
}
// Placeholders generates the SQL statement placeholders for in queries.
// For example, ($1,$2,$3),($4,$5,$6) etc.
// For example, ($1, $2, $3), ($4, $5, $6) etc.
// It will start counting placeholders at "start".
func Placeholders(count int, start int, group int) string {
var buf bytes.Buffer
if start == 0 || group == 0 {
panic("Invalid start or group numbers supplied.")
}
if group > 1 {
buf.WriteByte('(')
}
for i := 0; i < count; i++ {
if i != 0 {
if group > 1 && i%group == 0 {
buf.WriteString(`),(`)
buf.WriteString("), (")
} else {
buf.WriteByte(',')
buf.WriteString(", ")
}
}
buf.WriteString(fmt.Sprintf("$%d", start+i))

View file

@ -68,7 +68,14 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string
return nil
}
// UpdateAll updates all rows with matching column names.
// UpdateAllP updates all rows with matching column names, and panics on error.
func (q {{$varNameSingular}}Query) UpdateAllP(cols M) {
if err := q.UpdateAll(cols); err != nil {
panic(boil.WrapErr(err))
}
}
// UpdateAll updates all rows with the specified column values.
func (q {{$varNameSingular}}Query) UpdateAll(cols M) error {
boil.SetUpdate(q.Query, cols)
@ -80,13 +87,72 @@ func (q {{$varNameSingular}}Query) UpdateAll(cols M) error {
return nil
}
// UpdateAllP updates all rows with matching column names, and panics on error.
func (q {{$varNameSingular}}Query) UpdateAllP(cols M) {
if err := q.UpdateAll(cols); err != nil {
// UpdateAllG updates all rows with the specified column values.
func (o {{$tableNameSingular}}Slice) UpdateAllG(cols M) error {
return o.UpdateAll(boil.GetDB(), cols)
}
// UpdateAllGP updates all rows with the specified column values, and panics on error.
func (o {{$tableNameSingular}}Slice) UpdateAllGP(cols M) {
if err := o.UpdateAll(boil.GetDB(), cols); err != nil {
panic(boil.WrapErr(err))
}
}
// UpdateAllP updates all rows with the specified column values, and panics on error.
func (o {{$tableNameSingular}}Slice) UpdateAllP(exec boil.Executor, cols M) {
if err := o.UpdateAll(exec, cols); err != nil {
panic(boil.WrapErr(err))
}
}
// UpdateAll updates all rows with the specified column values, using an executor.
func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error {
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for update all")
}
if len(o) == 0 {
return nil
}
colNames := make([]string, len(cols))
var args []interface{}
count := 0
for name, value := range cols {
colNames[count] = name
args = append(args, value)
count++
}
// Append all of the primary key values for each column
args = append(args, o.inPrimaryKeyArgs())
sql := fmt.Sprintf(
`UPDATE {{.Table.Name}} SET (%s) VALUES (%s) WHERE (%s) IN (%s)`,
strings.Join(colNames, ", "),
strmangle.Placeholders(len(args), 1, 1),
strings.Join({{$varNameSingular}}PrimaryKeyColumns, ","),
strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(args)+1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
q := boil.SQL(sql, args...)
boil.SetExecutor(q, exec)
_, err := boil.ExecQuery(q)
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to update all in {{$varNameSingular}} slice: %s", err)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, )
}
return nil
}
// generateUpdateColumns generates the whitelist columns for an update statement
// if a whitelist is supplied, it's returned
// if a whitelist is missing then we begin with all columns

View file

@ -1,6 +1,3 @@
// M type is for providing where filters to Where helpers.
type M map[string]interface{}
// NewQueryG initializes a new Query using the passed in QueryMods
func NewQueryG(mods ...qm.QueryMod) *boil.Query {
return NewQuery(boil.GetDB(), mods...)

View file

@ -1,3 +1,6 @@
// M type is for providing columns and column values to UpdateAll.
type M map[string]interface{}
type upsertData struct {
conflict []string
update []string

View file

@ -28,8 +28,11 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) {
j := make({{$tableNameSingular}}Slice, 3)
// Perform all Find queries and assign result objects to slice for comparison
for i := 0; i < len(j); i++ {
for i := 0; i < len(o); i++ {
j[i], err = {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o[i]." | join ", "}})
if err != nil {
t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err)
}
{{$varNameSingular}}CompareVals(o[i], j[i], t)
}

View file

@ -41,3 +41,50 @@ func Test{{$tableNamePlural}}Update(t *testing.T) {
{{$varNamePlural}}DeleteAllRows(t)
}
func Test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) {
var err error
// insert random columns to test UpdateAll
o := make({{$tableNameSingular}}Slice, 3)
j := make({{$tableNameSingular}}Slice, 3)
if err = boil.RandomizeSlice(&o, {{$varNameSingular}}DBTypes, false); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err)
}
for i := 0; i < len(o); i++ {
if err = o[i].InsertG(); err != nil {
t.Errorf("Unable to insert {{$tableNameSingular}}:\n%#v\nErr: %s", o[i], err)
}
}
vals := M{}
tmp := {{$tableNameSingular}}{}
if err = boil.RandomizeStruct(&tmp, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize struct {{$tableNameSingular}}: %s", err)
}
// Build the columns and column values from the randomized struct
tmpVal := reflect.Indirect(reflect.ValueOf(tmp))
nonPrimKeys := boil.SetComplement({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns)
for _, col := range nonPrimKeys {
vals[col] = tmpVal.FieldByName(strmangle.TitleCase(col)).Interface()
}
err = o.UpdateAllG(vals)
if err != nil {
t.Errorf("Failed to update all for {{$tableNameSingular}}: %s", err)
}
for i := 0; i < len(o); i++ {
j[i], err = {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o[i]." | join ", "}})
if err != nil {
t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err)
}
{{$varNameSingular}}CompareVals(o[i], &tmp, t)
}
{{$varNamePlural}}DeleteAllRows(t)
}