Add args to Having query mod

* Add more thorough golden tests
* Fix bug in update column ordering, now uses sort
This commit is contained in:
Patrick O'brien 2016-08-13 03:08:09 +10:00
parent d1265764dc
commit 4571184b7d
7 changed files with 75 additions and 38 deletions

View file

@ -1 +1 @@
SELECT * FROM "a" GROUP BY id, name HAVING id <> 1, length(name, 'utf8') > 5;
SELECT * FROM "a" WHERE (a=$1 or b=$2) AND (c=$3) GROUP BY id, name HAVING id <> ?, length(name, ?) > ?;

View file

@ -1 +1 @@
UPDATE thing happy, "fun", "stuff" SET ("col1", "col2", "fun"."col3") = ($1, $2, $3) WHERE (aa=$4 or bb=$5) OR (cc=$6) LIMIT 5;
UPDATE thing happy, "fun", "stuff" SET ("col2", "fun"."col3", "col1") = ($1, $2, $3) WHERE (aa=$4 or bb=$5 or cc=$6) OR (dd=$7 or ee=$8 or ff=$9 and gg=$10) LIMIT 5;

View file

@ -86,9 +86,9 @@ func OrderBy(clause string) QueryMod {
}
// Having allows you to specify a having clause for your statement
func Having(clause string) QueryMod {
func Having(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.ApplyHaving(q, clause)
boil.ApplyHaving(q, clause, args...)
}
}

View file

@ -29,7 +29,7 @@ type Query struct {
where []where
groupBy []string
orderBy []string
having []string
having []having
limit int
offset int
}
@ -40,6 +40,11 @@ type where struct {
args []interface{}
}
type having struct {
clause string
args []interface{}
}
type plainSQL struct {
sql string
args []interface{}
@ -213,13 +218,13 @@ func SetOrderBy(q *Query, clause string) {
}
// ApplyHaving on the query.
func ApplyHaving(q *Query, clause string) {
q.having = append(q.having, clause)
func ApplyHaving(q *Query, clause string, args ...interface{}) {
q.having = append(q.having, having{clause: clause, args: args})
}
// SetHaving on the query.
func SetHaving(q *Query, clause string) {
q.having = append([]string(nil), clause)
func SetHaving(q *Query, clause string, args ...interface{}) {
q.having = append([]having(nil), having{clause: clause, args: args})
}
// SetLimit on the query.

View file

@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"regexp"
"sort"
"strings"
"github.com/vattle/sqlboiler/strmangle"
@ -63,17 +64,20 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))
var args []interface{}
for _, j := range q.joins {
if j.kind != JoinInner {
panic("only inner joins are supported")
}
fmt.Fprintf(buf, " INNER JOIN %s", j.clause)
args = append(args, j.args...)
}
where, args := whereClause(q, 1)
where, whereArgs := whereClause(q, len(args)+1)
buf.WriteString(where)
writeModifiers(q, buf)
args = append(args, whereArgs...)
writeModifiers(q, buf, &args)
buf.WriteByte(';')
return buf, args
@ -88,7 +92,7 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
where, args := whereClause(q, 1)
buf.WriteString(where)
writeModifiers(q, buf)
writeModifiers(q, buf, &args)
buf.WriteByte(';')
@ -101,16 +105,22 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf.WriteString("UPDATE ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))
cols := make([]string, len(q.update))
args := make([]interface{}, len(q.update))
cols := make(sort.StringSlice, len(q.update))
var args []interface{}
count := 0
for name, value := range q.update {
cols[count] = strmangle.IdentQuote(name)
args[count] = value
for name := range q.update {
cols[count] = name
count++
}
cols.Sort()
for i := 0; i < len(cols); i++ {
args = append(args, q.update[cols[i]])
cols[i] = strmangle.IdentQuote(cols[i])
}
buf.WriteString(fmt.Sprintf(
" SET (%s) = (%s)",
strings.Join(cols, ", "),
@ -119,22 +129,29 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
where, whereArgs := whereClause(q, len(args)+1)
buf.WriteString(where)
args = append(args, whereArgs...)
writeModifiers(q, buf)
args = append(args, whereArgs...)
writeModifiers(q, buf, &args)
buf.WriteByte(';')
return buf, args
}
func writeModifiers(q *Query, buf *bytes.Buffer) {
func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) {
if len(q.groupBy) != 0 {
fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", "))
}
if len(q.having) != 0 {
fmt.Fprintf(buf, " HAVING %s", strings.Join(q.having, ", "))
fmt.Fprintf(buf, " HAVING ")
for i, j := range q.having {
if i > 0 {
fmt.Fprintf(buf, ", ")
}
fmt.Fprintf(buf, j.clause)
*args = append(*args, j.args...)
}
}
if len(q.orderBy) != 0 {

View file

@ -42,22 +42,29 @@ func TestBuildQuery(t *testing.T) {
{&Query{
from: []string{"a"},
groupBy: []string{"id", "name"},
having: []string{"id <> 1", "length(name, 'utf8') > 5"},
}, nil},
where: []where{
{clause: "a=? or b=?", args: []interface{}{1, 2}},
{clause: "c=?", args: []interface{}{3}},
},
having: []having{
{clause: "id <> ?", args: []interface{}{1}},
{clause: "length(name, ?) > ?", args: []interface{}{"utf8", 5}},
},
}, []interface{}{1, 2, 3, 1, "utf8", 5}},
{&Query{
delete: true,
from: []string{"thing happy", `upset as "sad"`, "fun", "thing as stuff", `"angry" as mad`},
where: []where{
where{clause: "a=?", args: []interface{}{}},
where{clause: "b=?", args: []interface{}{}},
where{clause: "c=?", args: []interface{}{}},
{clause: "a=?", args: []interface{}{}},
{clause: "b=?", args: []interface{}{}},
{clause: "c=?", args: []interface{}{}},
},
}, nil},
{&Query{
delete: true,
from: []string{"thing happy", `upset as "sad"`, "fun", "thing as stuff", `"angry" as mad`},
where: []where{
where{clause: "(id=? and thing=?) or stuff=?", args: []interface{}{}},
{clause: "(id=? and thing=?) or stuff=?", args: []interface{}{}},
},
limit: 5,
}, nil},
@ -69,11 +76,11 @@ func TestBuildQuery(t *testing.T) {
`"fun".col3`: 3,
},
where: []where{
where{clause: "aa=? or bb=?", orSeparator: true, args: []interface{}{4, 5}},
where{clause: "cc=?", args: []interface{}{6}},
{clause: "aa=? or bb=? or cc=?", orSeparator: true, args: []interface{}{4, 5, 6}},
{clause: "dd=? or ee=? or ff=? and gg=?", args: []interface{}{7, 8, 9, 10}},
},
limit: 5,
}, []interface{}{1, 2, 3, 4, 5, 6}},
}, []interface{}{2, 3, 1, 4, 5, 6, 7, 8, 9, 10}},
}
for i, test := range tests {

View file

@ -160,17 +160,25 @@ func TestHaving(t *testing.T) {
t.Parallel()
q := &Query{}
expect := "count(orders.order_id) > 10"
ApplyHaving(q, expect)
ApplyHaving(q, expect)
expect := "count(orders.order_id) > ?"
ApplyHaving(q, expect, 10)
ApplyHaving(q, expect, 10)
if len(q.having) != 2 && (q.having[0] != expect || q.having[1] != expect) {
t.Errorf("Expected %s, got %s %s", expect, q.having[0], q.having[1])
if len(q.having) != 2 {
t.Errorf("Expected 2, got %d", len(q.having))
}
SetHaving(q, expect)
if len(q.having) != 1 && q.having[0] != expect {
t.Errorf("Expected %s, got %s", expect, q.having[0])
if q.having[0].clause != expect || q.having[1].clause != expect {
t.Errorf("Expected %s, got %s %s", expect, q.having[0].clause, q.having[1].clause)
}
if q.having[0].args[0] != 10 || q.having[1].args[0] != 10 {
t.Errorf("Expected %v, got %v %v", 10, q.having[0].args[0], q.having[1].args[0])
}
SetHaving(q, expect, 10)
if len(q.having) != 1 && (q.having[0].clause != expect || q.having[0].args[0] != 10) {
t.Errorf("Expected %s, got %s %v", expect, q.having[0], q.having[0].args[0])
}
}