From 94d36d7bf7fc1d711491a31415d9f7bad51b252e Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Wed, 17 Aug 2016 15:19:23 +1000 Subject: [PATCH] Begin IN implementation --- boil/qm/query_mods.go | 33 ++++++- boil/query.go | 100 +++++++++------------ boil/query_builders.go | 65 +++++++++++++- boil/query_builders_test.go | 171 ++++++++++++++++++++++++++++++++++++ boil/query_test.go | 125 ++++++++++++++++++++------ strmangle/strmangle.go | 9 ++ 6 files changed, 410 insertions(+), 93 deletions(-) diff --git a/boil/qm/query_mods.go b/boil/qm/query_mods.go index 0fe1c12..959172f 100644 --- a/boil/qm/query_mods.go +++ b/boil/qm/query_mods.go @@ -71,24 +71,51 @@ func Or(clause string, args ...interface{}) QueryMod { } } +// WhereIn allows you to specify a "x IN (set)" clause for your where statement +// Example clauses: "column in ?", "(column1,column2) in ?" +func WhereIn(clause string, args ...interface{}) QueryMod { + return func(q *boil.Query) { + boil.AppendIn(q, clause, args...) + } +} + +// AndIn allows you to specify a "x IN (set)" clause separated by an AndIn +// for your where statement. AndIn is a duplicate of the WhereIn function, but +// allows for more natural looking query mod chains, for example: +// (WhereIn("column1 in ?"), AndIn("column2 in ?"), OrIn("column3 in ?")) +func AndIn(clause string, args ...interface{}) QueryMod { + return func(q *boil.Query) { + boil.AppendIn(q, clause, args...) + } +} + +// OrIn allows you to specify an IN clause separated by +// an OR for your where statement +func OrIn(clause string, args ...interface{}) QueryMod { + return func(q *boil.Query) { + boil.SetLastInAsOr(q) + boil.AppendIn(q, clause, args...) + } +} + // GroupBy allows you to specify a group by clause for your statement func GroupBy(clause string) QueryMod { return func(q *boil.Query) { - boil.ApplyGroupBy(q, clause) + boil.AppendGroupBy(q, clause) } } // OrderBy allows you to specify a order by clause for your statement func OrderBy(clause string) QueryMod { return func(q *boil.Query) { - boil.ApplyOrderBy(q, clause) + boil.AppendOrderBy(q, clause) } } // Having allows you to specify a having clause for your statement func Having(clause string, args ...interface{}) QueryMod { return func(q *boil.Query) { - boil.ApplyHaving(q, clause, args...) + boil.AppendHaving(q, clause, args...) } } diff --git a/boil/query.go b/boil/query.go index 8070c54..146791a 100644 --- a/boil/query.go +++ b/boil/query.go @@ -27,6 +27,7 @@ type Query struct { from []string joins []join where []where + in []in groupBy []string orderBy []string having []having @@ -40,6 +41,12 @@ type where struct { args []interface{} } +type in struct { + clause string + orSeparator bool + args []interface{} +} + type having struct { clause string args []interface{} @@ -96,6 +103,11 @@ func ExecQueryAll(q *Query) (*sql.Rows, error) { return q.executor.Query(qs, args...) } +// SetExecutor on the query. +func SetExecutor(q *Query, exec Executor) { + q.executor = exec +} + // SetSQL on the query. func SetSQL(q *Query, sql string, args ...interface{}) { q.plainSQL = plainSQL{sql: sql, args: args} @@ -131,33 +143,26 @@ func SetDelete(q *Query) { q.delete = true } +// SetLimit on the query. +func SetLimit(q *Query, limit int) { + q.limit = limit +} + +// SetOffset on the query. +func SetOffset(q *Query, offset int) { + q.offset = offset +} + // SetUpdate on the query. func SetUpdate(q *Query, cols map[string]interface{}) { q.update = cols } -// SetExecutor on the query. -func SetExecutor(q *Query, exec Executor) { - q.executor = exec -} - // AppendSelect on the query. func AppendSelect(q *Query, columns ...string) { q.selectCols = append(q.selectCols, columns...) } -// SetSelect replaces the current select clause. -func SetSelect(q *Query, columns ...string) { - q.selectCols = append([]string(nil), columns...) -} - -// Select returns the select columns in the query. -func Select(q *Query) []string { - cols := make([]string, len(q.selectCols)) - copy(cols, q.selectCols) - return cols -} - // AppendFrom on the query. func AppendFrom(q *Query, from ...string) { q.from = append(q.from, from...) @@ -173,9 +178,9 @@ func AppendInnerJoin(q *Query, clause string, args ...interface{}) { q.joins = append(q.joins, join{clause: clause, kind: JoinInner, args: args}) } -// SetInnerJoin on the query. -func SetInnerJoin(q *Query, clause string, args ...interface{}) { - q.joins = append([]join(nil), join{clause: clause, kind: JoinInner, args: args}) +// AppendHaving on the query. +func AppendHaving(q *Query, clause string, args ...interface{}) { + q.having = append(q.having, having{clause: clause, args: args}) } // AppendWhere on the query. @@ -183,12 +188,12 @@ func AppendWhere(q *Query, clause string, args ...interface{}) { q.where = append(q.where, where{clause: clause, args: args}) } -// SetWhere on the query. -func SetWhere(q *Query, clause string, args ...interface{}) { - q.where = append([]where(nil), where{clause: clause, args: args}) +// AppendIn on the query. +func AppendIn(q *Query, clause string, args ...interface{}) { + q.in = append(q.in, in{clause: clause, args: args}) } -// SetLastWhereAsOr sets the or separator for the tail where in the slice +// SetLastWhereAsOr sets the or separator for the tail "WHERE" in the slice func SetLastWhereAsOr(q *Query) { if len(q.where) == 0 { return @@ -197,42 +202,21 @@ func SetLastWhereAsOr(q *Query) { q.where[len(q.where)-1].orSeparator = true } -// ApplyGroupBy on the query. -func ApplyGroupBy(q *Query, clause string) { +// SetLastInAsOr sets the or separator for the tail "IN" in the slice +func SetLastInAsOr(q *Query) { + if len(q.in) == 0 { + return + } + + q.in[len(q.in)-1].orSeparator = true +} + +// AppendGroupBy on the query. +func AppendGroupBy(q *Query, clause string) { q.groupBy = append(q.groupBy, clause) } -// SetGroupBy on the query. -func SetGroupBy(q *Query, clause string) { - q.groupBy = append([]string(nil), clause) -} - -// ApplyOrderBy on the query. -func ApplyOrderBy(q *Query, clause string) { +// AppendOrderBy on the query. +func AppendOrderBy(q *Query, clause string) { q.orderBy = append(q.orderBy, clause) } - -// SetOrderBy on the query. -func SetOrderBy(q *Query, clause string) { - q.orderBy = append([]string(nil), clause) -} - -// ApplyHaving on the query. -func ApplyHaving(q *Query, clause string, args ...interface{}) { - q.having = append(q.having, having{clause: clause, args: args}) -} - -// SetHaving on the query. -func SetHaving(q *Query, clause string, args ...interface{}) { - q.having = append([]having(nil), having{clause: clause, args: args}) -} - -// SetLimit on the query. -func SetLimit(q *Query, limit int) { - q.limit = limit -} - -// SetOffset on the query. -func SetOffset(q *Query, offset int) { - q.offset = offset -} diff --git a/boil/query_builders.go b/boil/query_builders.go index 187df78..7173e2a 100644 --- a/boil/query_builders.go +++ b/boil/query_builders.go @@ -83,8 +83,16 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { where, whereArgs := whereClause(q, len(args)+1) buf.WriteString(where) + if len(whereArgs) != 0 { + args = append(args, whereArgs...) + } + + in, inArgs := inClause(q, len(args)+1) + buf.WriteString(in) + if len(inArgs) != 0 { + args = append(args, inArgs...) + } - args = append(args, whereArgs...) writeModifiers(q, buf, &args) buf.WriteByte(';') @@ -92,14 +100,24 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { } func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { + var args []interface{} buf := strmangle.GetBuffer() buf.WriteString("DELETE FROM ") buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) - where, args := whereClause(q, 1) + where, whereArgs := whereClause(q, 1) + if len(whereArgs) != 0 { + args = append(args, whereArgs) + } buf.WriteString(where) + in, inArgs := inClause(q, len(args)+1) + if len(inArgs) != 0 { + args = append(args, inArgs...) + } + buf.WriteString(in) + writeModifiers(q, buf, &args) buf.WriteByte(';') @@ -136,9 +154,17 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { ) where, whereArgs := whereClause(q, len(args)+1) + if len(whereArgs) != 0 { + args = append(args, whereArgs...) + } buf.WriteString(where) - args = append(args, whereArgs...) + in, inArgs := inClause(q, len(args)+1) + if len(inArgs) != 0 { + args = append(args, inArgs...) + } + buf.WriteString(in) + writeModifiers(q, buf, &args) buf.WriteByte(';') @@ -260,6 +286,39 @@ func whereClause(q *Query, startAt int) (string, []interface{}) { return convertQuestionMarks(buf.String(), startAt), args } +func inClause(q *Query, startAt int) (string, []interface{}) { + if len(q.in) == 0 { + return "", nil + } + + buf := strmangle.GetBuffer() + defer strmangle.PutBuffer(buf) + var args []interface{} + + if len(q.where) == 0 { + buf.WriteString(" WHERE ") + } + for i := 0; i < len(q.in); i++ { + + } + + // regexp split thing so we have left side and right side + // split on )IN( / \sIN\s, combine them + + // buf.WriteString(convertQuestionMarks(leftSide, startAt)) + // buf.WriteString(" IN ") + // buf.WriteString(convertInQuestionMarks(rightSide, total, group, startAt+offset)) + + return "", args +} + +func convertInQuestionMarks(clause string, total, groupAt, startAt int) string { + return "" +} + +// convertQuestionMarks converts each occurence of ? with $ +// where is an incrementing digit starting at startAt. +// If question-mark (?) is escaped using back-slash (\), it will be ignored. func convertQuestionMarks(clause string, startAt int) string { if startAt == 0 { panic("Not a valid start number.") diff --git a/boil/query_builders_test.go b/boil/query_builders_test.go index 82dd90c..3b31095 100644 --- a/boil/query_builders_test.go +++ b/boil/query_builders_test.go @@ -334,6 +334,138 @@ func TestWhereClause(t *testing.T) { } } +func TestInClause(t *testing.T) { + t.Parallel() + + tests := []struct { + q Query + expect string + }{ + // Or("a=?") + { + q: Query{ + in: []in{{clause: "a in ?", args: []interface{}{1}, orSeparator: true}}, + }, + expect: " WHERE a IN ($1)", + }, + { + q: Query{ + in: []in{{clause: "a in ?", args: []interface{}{1, 2, 3}}}, + }, + expect: " WHERE a IN ($1,$2,$3)", + }, + // // Where("a=?") + // { + // q: Query{ + // where: []where{where{clause: "a=?"}}, + // }, + // expect: " WHERE (a=$1)", + // }, + // // Where("(a=?)") + // { + // q: Query{ + // where: []where{where{clause: "(a=?)"}}, + // }, + // expect: " WHERE ((a=$1))", + // }, + // // Where("((a=? OR b=?))") + // { + // q: Query{ + // where: []where{where{clause: "((a=? OR b=?))"}}, + // }, + // expect: " WHERE (((a=$1 OR b=$2)))", + // }, + // // Where("(a=?)", Or("(b=?)") + // { + // q: Query{ + // where: []where{ + // where{clause: "(a=?)", orSeparator: true}, + // where{clause: "(b=?)"}, + // }, + // }, + // expect: " WHERE ((a=$1)) OR ((b=$2))", + // }, + // // Where("a=? OR b=?") + // { + // q: Query{ + // where: []where{where{clause: "a=? OR b=?"}}, + // }, + // expect: " WHERE (a=$1 OR b=$2)", + // }, + // // Where("a=?"), Where("b=?") + // { + // q: Query{ + // where: []where{where{clause: "a=?"}, where{clause: "b=?"}}, + // }, + // expect: " WHERE (a=$1) AND (b=$2)", + // }, + // // Where("(a=? AND b=?) OR c=?") + // { + // q: Query{ + // where: []where{where{clause: "(a=? AND b=?) OR c=?"}}, + // }, + // expect: " WHERE ((a=$1 AND b=$2) OR c=$3)", + // }, + // // Where("a=? OR b=?"), Where("c=? OR d=? OR e=?") + // { + // q: Query{ + // where: []where{ + // where{clause: "(a=? OR b=?)"}, + // where{clause: "(c=? OR d=? OR e=?)"}, + // }, + // }, + // expect: " WHERE ((a=$1 OR b=$2)) AND ((c=$3 OR d=$4 OR e=$5))", + // }, + // // Where("(a=? AND b=?) OR (c=? AND d=? AND e=?) OR f=? OR f=?") + // { + // q: Query{ + // where: []where{ + // where{clause: "(a=? AND b=?) OR (c=? AND d=? AND e=?) OR f=? OR g=?"}, + // }, + // }, + // expect: " WHERE ((a=$1 AND b=$2) OR (c=$3 AND d=$4 AND e=$5) OR f=$6 OR g=$7)", + // }, + // // Where("(a=? AND b=?) OR (c=? AND d=? OR e=?) OR f=? OR g=?") + // { + // q: Query{ + // where: []where{ + // where{clause: "(a=? AND b=?) OR (c=? AND d=? OR e=?) OR f=? OR g=?"}, + // }, + // }, + // expect: " WHERE ((a=$1 AND b=$2) OR (c=$3 AND d=$4 OR e=$5) OR f=$6 OR g=$7)", + // }, + // // Where("a=? or b=?"), Or("c=? and d=?"), Or("e=? or f=?") + // { + // q: Query{ + // where: []where{ + // where{clause: "a=? or b=?", orSeparator: true}, + // where{clause: "c=? and d=?", orSeparator: true}, + // where{clause: "e=? or f=?", orSeparator: true}, + // }, + // }, + // expect: " WHERE (a=$1 or b=$2) OR (c=$3 and d=$4) OR (e=$5 or f=$6)", + // }, + // // Where("a=? or b=?"), Or("c=? and d=?"), Or("e=? or f=?") + // { + // q: Query{ + // where: []where{ + // where{clause: "a=? or b=?"}, + // where{clause: "c=? and d=?"}, + // where{clause: "e=? or f=?"}, + // }, + // }, + // expect: " WHERE (a=$1 or b=$2) AND (c=$3 and d=$4) AND (e=$5 or f=$6)", + // }, + } + + for i, test := range tests { + result, _ := inClause(&test.q, 1) + if result != test.expect { + t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result) + } + } +} + func TestConvertQuestionMarks(t *testing.T) { t.Parallel() @@ -373,6 +505,45 @@ func TestConvertQuestionMarks(t *testing.T) { } } +func TestConvertInQuestionMarks(t *testing.T) { + t.Parallel() + + tests := []struct { + clause string + start int + expect string + }{ + {clause: "hello friend", start: 1, expect: "hello friend"}, + {clause: "thing=?", start: 2, expect: "thing=$2"}, + {clause: "thing=? and stuff=? and happy=?", start: 2, expect: "thing=$2 and stuff=$3 and happy=$4"}, + {clause: `thing \? stuff`, start: 2, expect: `thing ? stuff`}, + {clause: `thing \? stuff and happy \? fun`, start: 2, expect: `thing ? stuff and happy ? fun`}, + { + clause: `thing \? stuff ? happy \? and mad ? fun \? \? \?`, + start: 2, + expect: `thing ? stuff $2 happy ? and mad $3 fun ? ? ?`, + }, + { + clause: `thing ? stuff ? happy \? fun \? ? ?`, + start: 1, + expect: `thing $1 stuff $2 happy ? fun ? $3 $4`, + }, + {clause: `?`, start: 1, expect: `$1`}, + {clause: `???`, start: 1, expect: `$1$2$3`}, + {clause: `\?`, start: 1, expect: `?`}, + {clause: `\?\?\?`, start: 1, expect: `???`}, + {clause: `\??\??\??`, start: 1, expect: `?$1?$2?$3`}, + {clause: `?\??\??\?`, start: 1, expect: `$1?$2?$3?`}, + } + + for i, test := range tests { + res := convertQuestionMarks(test.clause, test.start) + if res != test.expect { + t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res) + } + } +} + func TestWriteAsStatements(t *testing.T) { t.Parallel() diff --git a/boil/query_test.go b/boil/query_test.go index 18ad3f1..8952628 100644 --- a/boil/query_test.go +++ b/boil/query_test.go @@ -45,7 +45,7 @@ func TestSetSQL(t *testing.T) { } } -func TestWhere(t *testing.T) { +func TestAppendWhere(t *testing.T) { t.Parallel() q := &Query{} @@ -69,7 +69,7 @@ func TestWhere(t *testing.T) { t.Errorf("args wrong: %#v", q.where) } - SetWhere(q, expect, 5, 3) + q.where = []where{{clause: expect, args: []interface{}{5, 3}}} if q.where[0].clause != expect { t.Errorf("Expected %s, got %v", expect, q.where) } @@ -120,49 +120,124 @@ func TestSetLastWhereAsOr(t *testing.T) { } } -func TestGroupBy(t *testing.T) { +func TestAppendIn(t *testing.T) { + t.Parallel() + + q := &Query{} + expect := "col IN ?" + AppendIn(q, expect, 5, 3) + AppendIn(q, expect, 5, 3) + + if len(q.in) != 2 { + t.Errorf("%#v", q.in) + } + + if q.in[0].clause != expect || q.in[1].clause != expect { + t.Errorf("Expected %s, got %#v", expect, q.in) + } + + if len(q.in[0].args) != 2 || len(q.in[0].args) != 2 { + t.Errorf("arg length wrong: %#v", q.in) + } + + if q.in[0].args[0].(int) != 5 || q.in[0].args[1].(int) != 3 { + t.Errorf("args wrong: %#v", q.in) + } + + q.in = []in{{clause: expect, args: []interface{}{5, 3}}} + if q.in[0].clause != expect { + t.Errorf("Expected %s, got %v", expect, q.in) + } + + if len(q.in[0].args) != 2 { + t.Errorf("Expected %d args, got %d", 2, len(q.in[0].args)) + } + + if q.in[0].args[0].(int) != 5 || q.in[0].args[1].(int) != 3 { + t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.in[0].args) + } + + if len(q.in) != 1 { + t.Errorf("%#v", q.in) + } +} + +func TestSetLastInAsOr(t *testing.T) { + t.Parallel() + q := &Query{} + + AppendIn(q, "") + + if q.in[0].orSeparator { + t.Errorf("Do not want or separator") + } + + SetLastInAsOr(q) + + if len(q.in) != 1 { + t.Errorf("Want len 1") + } + if !q.in[0].orSeparator { + t.Errorf("Want or separator") + } + + AppendIn(q, "") + SetLastInAsOr(q) + + if len(q.in) != 2 { + t.Errorf("Want len 2") + } + if q.in[0].orSeparator != true { + t.Errorf("Expected true") + } + if q.in[1].orSeparator != true { + t.Errorf("Expected true") + } +} + +func TestAppendGroupBy(t *testing.T) { t.Parallel() q := &Query{} expect := "col1, col2" - ApplyGroupBy(q, expect) - ApplyGroupBy(q, expect) + AppendGroupBy(q, expect) + AppendGroupBy(q, expect) if len(q.groupBy) != 2 && (q.groupBy[0] != expect || q.groupBy[1] != expect) { t.Errorf("Expected %s, got %s %s", expect, q.groupBy[0], q.groupBy[1]) } - SetGroupBy(q, expect) + q.groupBy = []string{expect} if len(q.groupBy) != 1 && q.groupBy[0] != expect { t.Errorf("Expected %s, got %s", expect, q.groupBy[0]) } } -func TestOrderBy(t *testing.T) { +func TestAppendOrderBy(t *testing.T) { t.Parallel() q := &Query{} expect := "col1 desc, col2 asc" - ApplyOrderBy(q, expect) - ApplyOrderBy(q, expect) + AppendOrderBy(q, expect) + AppendOrderBy(q, expect) if len(q.orderBy) != 2 && (q.orderBy[0] != expect || q.orderBy[1] != expect) { t.Errorf("Expected %s, got %s %s", expect, q.orderBy[0], q.orderBy[1]) } - SetOrderBy(q, "col1 desc, col2 asc") + q.orderBy = []string{"col1 desc, col2 asc"} if len(q.orderBy) != 1 && q.orderBy[0] != expect { t.Errorf("Expected %s, got %s", expect, q.orderBy[0]) } } -func TestHaving(t *testing.T) { +func TestAppendHaving(t *testing.T) { t.Parallel() q := &Query{} expect := "count(orders.order_id) > ?" - ApplyHaving(q, expect, 10) - ApplyHaving(q, expect, 10) + AppendHaving(q, expect, 10) + AppendHaving(q, expect, 10) if len(q.having) != 2 { t.Errorf("Expected 2, got %d", len(q.having)) @@ -176,7 +251,7 @@ func TestHaving(t *testing.T) { t.Errorf("Expected %v, got %v %v", 10, q.having[0].args[0], q.having[1].args[0]) } - SetHaving(q, expect, 10) + q.having = []having{{clause: expect, args: []interface{}{10}}} if len(q.having) != 1 && (q.having[0].clause != expect || q.having[0].args[0] != 10) { t.Errorf("Expected %s, got %s %v", expect, q.having[0], q.having[0].args[0]) } @@ -307,24 +382,12 @@ func TestAppendSelect(t *testing.T) { t.Errorf("select cols value mismatch: %#v", q.selectCols) } - SetSelect(q, "col1", "col2") + q.selectCols = []string{"col1", "col2"} if q.selectCols[0] != `col1` && q.selectCols[1] != `col2` { t.Errorf("select cols value mismatch: %#v", q.selectCols) } } -func TestSelect(t *testing.T) { - t.Parallel() - - q := &Query{} - q.selectCols = []string{"one"} - - ret := Select(q) - if ret[0] != "one" { - t.Errorf("Expected %q, got %s", "one", ret[0]) - } -} - func TestSQL(t *testing.T) { t.Parallel() @@ -337,7 +400,7 @@ func TestSQL(t *testing.T) { } } -func TestInnerJoin(t *testing.T) { +func TestAppendInnerJoin(t *testing.T) { t.Parallel() q := &Query{} @@ -366,7 +429,11 @@ func TestInnerJoin(t *testing.T) { t.Errorf("Invalid args values, got %#v", q.joins[0].args) } - SetInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5) + q.joins = []join{{kind: JoinInner, + clause: "thing=$1 AND stuff=$2", + args: []interface{}{2, 5}, + }} + if len(q.joins) != 1 { t.Errorf("Expected len 1, got %d", len(q.joins)) } diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index ea85f81..56206ef 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -19,6 +19,15 @@ var ( smartQuoteRgx = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(\."?[_a-z][_a-z0-9]*"?)*(\.\*)?$`) ) +func init() { + // Override the uncountable inflections with an empty set. + // This way, people using words like Sheep will not have + // collisions with their model name (Sheep) and their + // function name (Sheep()). Instead, it will + // use the regular inflection rules: Sheep, Sheeps(). + inflection.SetUncountable([]string{}) +} + // IdentQuote attempts to quote simple identifiers in SQL tatements func IdentQuote(s string) string { if strings.ToLower(s) == "null" {