From f713e041ade3177e1aa515d221fc41e10bc0a1f0 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Tue, 19 Jul 2016 14:07:12 +1000 Subject: [PATCH] Add sql query mod, finish or querymod * Add comments to kill lint errors --- boil/hooks.go | 1 + boil/qm/query_mods.go | 39 +++++++++++++++++++++++----- boil/query.go | 60 +++++++++++++++++++++++++++++-------------- boil/query_test.go | 56 +++++++++++++++++++++++++++++++++++++--- boil/reflect.go | 8 +++--- 5 files changed, 130 insertions(+), 34 deletions(-) diff --git a/boil/hooks.go b/boil/hooks.go index 7a2ad06..07cfbb4 100644 --- a/boil/hooks.go +++ b/boil/hooks.go @@ -3,6 +3,7 @@ package boil // HookPoint is the point in time at which we hook type HookPoint int +// the hook point constants const ( HookAfterCreate HookPoint = iota + 1 HookAfterUpdate diff --git a/boil/qm/query_mods.go b/boil/qm/query_mods.go index f0d2e0a..5ce3c8d 100644 --- a/boil/qm/query_mods.go +++ b/boil/qm/query_mods.go @@ -2,85 +2,110 @@ package qm import "github.com/nullbio/sqlboiler/boil" +// QueryMod to modify the query object type QueryMod func(q *boil.Query) +// Apply the query mods to the Query object func Apply(q *boil.Query, mods ...QueryMod) { for _, mod := range mods { mod(q) } } -func Or(whereMods ...QueryMod) QueryMod { +// SQL allows you to execute a plain SQL statement +func SQL(sql string, args ...interface{}) QueryMod { return func(q *boil.Query) { - if len(whereMods) < 2 { - // error, needs to be at least 2 for an or - } - // add the where mods to query with or seperators + boil.SetSQL(q, sql, args...) } } +// Or surrounds where clauses to join them with OR as opposed to AND +func Or(whereMods ...QueryMod) QueryMod { + return func(q *boil.Query) { + if len(whereMods) < 2 { + panic("Or requires at least two arguments") + } + + for _, w := range whereMods { + w(q) + boil.SetLastWhereAsOr(q) + } + } +} + +// Limit the number of returned rows func Limit(limit int) QueryMod { return func(q *boil.Query) { boil.SetLimit(q, limit) } } +// InnerJoin on another table func InnerJoin(on string, args ...interface{}) QueryMod { return func(q *boil.Query) { boil.SetInnerJoin(q, on, args...) } } +// OuterJoin on another table func OuterJoin(on string, args ...interface{}) QueryMod { return func(q *boil.Query) { boil.SetOuterJoin(q, on, args...) } } +// LeftOuterJoin on another table func LeftOuterJoin(on string, args ...interface{}) QueryMod { return func(q *boil.Query) { boil.SetLeftOuterJoin(q, on, args...) } } +// RightOuterJoin on another table func RightOuterJoin(on string, args ...interface{}) QueryMod { return func(q *boil.Query) { boil.SetRightOuterJoin(q, on, args...) } } +// Select specific columns opposed to all columns func Select(columns ...string) QueryMod { return func(q *boil.Query) { boil.SetSelect(q, columns...) } } +// Where allows you to specify a where clause for your statement func Where(clause string, args ...interface{}) QueryMod { return func(q *boil.Query) { boil.SetWhere(q, clause, args...) } } +// GroupBy allows you to specify a group by clause for your statement func GroupBy(clause string) QueryMod { return func(q *boil.Query) { boil.SetGroupBy(q, clause) } } +// OrderBy allows you to specify a order by clause for your statement func OrderBy(clause string) QueryMod { return func(q *boil.Query) { boil.SetOrderBy(q, clause) } } +// Having allows you to specify a having clause for your statement func Having(clause string) QueryMod { return func(q *boil.Query) { boil.SetHaving(q, clause) } } -func Table(table string) QueryMod { +// From allows to specify the table for your statement +func From(from string) QueryMod { return func(q *boil.Query) { - boil.SetTable(q, table) + boil.SetFrom(q, from) } } diff --git a/boil/query.go b/boil/query.go index 02de26b..ff44c9d 100644 --- a/boil/query.go +++ b/boil/query.go @@ -13,18 +13,25 @@ type where struct { args []interface{} } +type plainSQL struct { + sql string + args []interface{} +} + type join struct { on string args []interface{} } +// Query holds the state for the built up query type Query struct { executor Executor + plainSQL plainSQL delete bool update map[string]interface{} selectCols []string count bool - table string + from string innerJoins []join outerJoins []join leftOuterJoins []join @@ -41,6 +48,8 @@ func buildQuery(q *Query) (string, []interface{}) { var args []interface{} switch { + case len(q.plainSQL.sql) != 0: + return q.plainSQL.sql, q.plainSQL.args case q.delete: buf, args = buildDeleteQuery(q) case len(q.update) > 0: @@ -61,7 +70,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { buf.WriteString("COUNT(") } if len(q.selectCols) > 0 { - buf.WriteString(strings.Join(q.selectCols, ",")) + buf.WriteString(`"` + strings.Join(q.selectCols, `","`) + `"`) } else { buf.WriteByte('*') } @@ -71,7 +80,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { } buf.WriteString(" FROM ") - fmt.Fprintf(buf, `"%s"`, q.table) + fmt.Fprintf(buf, `"%s"`, q.from) where, args := whereClause(q) buf.WriteString(where) @@ -84,7 +93,7 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { buf := &bytes.Buffer{} buf.WriteString("DELETE FROM ") - fmt.Fprintf(buf, `"%s"`, q.table) + fmt.Fprintf(buf, `"%s"`, q.from) where, args := whereClause(q) buf.WriteString(where) @@ -128,6 +137,11 @@ func ExecQueryAll(q *Query) (*sql.Rows, error) { return q.executor.Query(qs, args...) } +// SetSQL on the query. +func SetSQL(q *Query, sql string, args ...interface{}) { + q.plainSQL = plainSQL{sql: sql, args: args} +} + // SetCount on the query. func SetCount(q *Query) { q.count = true @@ -160,9 +174,9 @@ func Select(q *Query) []string { return cols } -// SetTable on the query. -func SetTable(q *Query, table string) { - q.table = table +// SetFrom on the query. +func SetFrom(q *Query, from string) { + q.from = from } // SetInnerJoin on the query. @@ -190,6 +204,11 @@ func SetWhere(q *Query, clause string, args ...interface{}) { q.where = append(q.where, where{clause: clause, args: args}) } +// SetLastWhereAsOr sets the or seperator for the last element in the where slice +func SetLastWhereAsOr(q *Query) { + q.where[len(q.where)-1].orSeperator = true +} + // SetGroupBy on the query. func SetGroupBy(q *Query, clause string) { q.groupBy = append(q.groupBy, clause) @@ -211,21 +230,24 @@ func SetLimit(q *Query, limit int) { } func whereClause(q *Query) (string, []interface{}) { + if len(q.where) == 0 { + return "", nil + } + buf := &bytes.Buffer{} var args []interface{} - if len(q.where) > 0 { - buf.WriteString(" WHERE ") - for i := 0; i < len(q.where); i++ { - buf.WriteString(fmt.Sprintf("%s", q.where[i].clause)) - args = append(args, q.where[i].args...) - if i != len(q.where)-1 { - if q.where[i].orSeperator { - buf.WriteString(" OR ") - } else { - buf.WriteString(" AND ") - } - } + buf.WriteString(" WHERE ") + for i := 0; i < len(q.where); i++ { + buf.WriteString(fmt.Sprintf("%s", q.where[i].clause)) + args = append(args, q.where[i].args...) + if i >= len(q.where)-1 { + continue + } + if q.where[i].orSeperator { + buf.WriteString(" OR ") + } else { + buf.WriteString(" AND ") } } diff --git a/boil/query_test.go b/boil/query_test.go index 62adc35..cd5e297 100644 --- a/boil/query_test.go +++ b/boil/query_test.go @@ -26,7 +26,7 @@ func TestBuildQuery(t *testing.T) { q *Query args []interface{} }{ - {&Query{table: "t"}, []interface{}{}}, + {&Query{from: "t"}, []interface{}{}}, } for i, test := range tests { @@ -57,6 +57,39 @@ func TestBuildQuery(t *testing.T) { } } +func TestSetLastWhereAsOr(t *testing.T) { + t.Parallel() + q := &Query{} + + SetWhere(q, "") + + if q.where[0].orSeperator { + t.Errorf("Do not want or seperator") + } + + SetLastWhereAsOr(q) + + if len(q.where) != 1 { + t.Errorf("Want len 1") + } + if !q.where[0].orSeperator { + t.Errorf("Want or seperator") + } + + SetWhere(q, "") + SetLastWhereAsOr(q) + + if len(q.where) != 2 { + t.Errorf("Want len 2") + } + if q.where[0].orSeperator != true { + t.Errorf("Expected true") + } + if q.where[1].orSeperator != true { + t.Errorf("Expected true") + } +} + func TestSetLimit(t *testing.T) { t.Parallel() @@ -69,6 +102,21 @@ func TestSetLimit(t *testing.T) { } } +func TestSetSQL(t *testing.T) { + t.Parallel() + + q := &Query{} + SetSQL(q, "select * from thing", 5, 3) + + if len(q.plainSQL.args) != 2 { + t.Errorf("Expected len 2, got %d", len(q.plainSQL.args)) + } + + if q.plainSQL.sql != "select * from thing" { + t.Errorf("Was not expected string, got %s", q.plainSQL.sql) + } +} + func TestSetWhere(t *testing.T) { t.Parallel() @@ -133,11 +181,11 @@ func TestSetTable(t *testing.T) { t.Parallel() q := &Query{} - SetTable(q, "videos a, orders b") + SetFrom(q, "videos a, orders b") expect := "videos a, orders b" - if q.table != expect { - t.Errorf("Expected %s, got %s", expect, q.table) + if q.from != expect { + t.Errorf("Expected %s, got %s", expect, q.from) } } diff --git a/boil/reflect.go b/boil/reflect.go index 04188d2..3ec0371 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -25,16 +25,16 @@ func (q *Query) Bind(obj interface{}) error { row := ExecQueryOne(q) err := BindOne(row, q.selectCols, obj) if err != nil { - return fmt.Errorf("Failed to execute Bind query for %s: %s", q.table, err) + return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err) } } else if kind == reflect.Slice { rows, err := ExecQueryAll(q) if err != nil { - return fmt.Errorf("Failed to execute Bind query for %s: %s", q.table, err) + return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err) } err = BindAll(rows, q.selectCols, obj) if err != nil { - return fmt.Errorf("Failed to Bind results to object provided for %s: %s", q.table, err) + return fmt.Errorf("Failed to Bind results to object provided for %s: %s", q.from, err) } } else { return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", typ.String()) @@ -43,7 +43,7 @@ func (q *Query) Bind(obj interface{}) error { return nil } -// Bind executes the query and inserts the +// BindP executes the query and inserts the // result into the passed in object pointer. // It panics on error. func (q *Query) BindP(obj interface{}) {