Add sql query mod, finish or querymod

* Add comments to kill lint errors
This commit is contained in:
Patrick O'brien 2016-07-19 14:07:12 +10:00
parent b447018220
commit f713e041ad
5 changed files with 130 additions and 34 deletions

View file

@ -3,6 +3,7 @@ package boil
// HookPoint is the point in time at which we hook // HookPoint is the point in time at which we hook
type HookPoint int type HookPoint int
// the hook point constants
const ( const (
HookAfterCreate HookPoint = iota + 1 HookAfterCreate HookPoint = iota + 1
HookAfterUpdate HookAfterUpdate

View file

@ -2,85 +2,110 @@ package qm
import "github.com/nullbio/sqlboiler/boil" import "github.com/nullbio/sqlboiler/boil"
// QueryMod to modify the query object
type QueryMod func(q *boil.Query) type QueryMod func(q *boil.Query)
// Apply the query mods to the Query object
func Apply(q *boil.Query, mods ...QueryMod) { func Apply(q *boil.Query, mods ...QueryMod) {
for _, mod := range mods { for _, mod := range mods {
mod(q) mod(q)
} }
} }
func Or(whereMods ...QueryMod) QueryMod { // SQL allows you to execute a plain SQL statement
func SQL(sql string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
if len(whereMods) < 2 { boil.SetSQL(q, sql, args...)
// error, needs to be at least 2 for an or
}
// add the where mods to query with or seperators
} }
} }
// Or surrounds where clauses to join them with OR as opposed to AND
func Or(whereMods ...QueryMod) QueryMod {
return func(q *boil.Query) {
if len(whereMods) < 2 {
panic("Or requires at least two arguments")
}
for _, w := range whereMods {
w(q)
boil.SetLastWhereAsOr(q)
}
}
}
// Limit the number of returned rows
func Limit(limit int) QueryMod { func Limit(limit int) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetLimit(q, limit) boil.SetLimit(q, limit)
} }
} }
// InnerJoin on another table
func InnerJoin(on string, args ...interface{}) QueryMod { func InnerJoin(on string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetInnerJoin(q, on, args...) boil.SetInnerJoin(q, on, args...)
} }
} }
// OuterJoin on another table
func OuterJoin(on string, args ...interface{}) QueryMod { func OuterJoin(on string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetOuterJoin(q, on, args...) boil.SetOuterJoin(q, on, args...)
} }
} }
// LeftOuterJoin on another table
func LeftOuterJoin(on string, args ...interface{}) QueryMod { func LeftOuterJoin(on string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetLeftOuterJoin(q, on, args...) boil.SetLeftOuterJoin(q, on, args...)
} }
} }
// RightOuterJoin on another table
func RightOuterJoin(on string, args ...interface{}) QueryMod { func RightOuterJoin(on string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetRightOuterJoin(q, on, args...) boil.SetRightOuterJoin(q, on, args...)
} }
} }
// Select specific columns opposed to all columns
func Select(columns ...string) QueryMod { func Select(columns ...string) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetSelect(q, columns...) boil.SetSelect(q, columns...)
} }
} }
// Where allows you to specify a where clause for your statement
func Where(clause string, args ...interface{}) QueryMod { func Where(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetWhere(q, clause, args...) boil.SetWhere(q, clause, args...)
} }
} }
// GroupBy allows you to specify a group by clause for your statement
func GroupBy(clause string) QueryMod { func GroupBy(clause string) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetGroupBy(q, clause) boil.SetGroupBy(q, clause)
} }
} }
// OrderBy allows you to specify a order by clause for your statement
func OrderBy(clause string) QueryMod { func OrderBy(clause string) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetOrderBy(q, clause) boil.SetOrderBy(q, clause)
} }
} }
// Having allows you to specify a having clause for your statement
func Having(clause string) QueryMod { func Having(clause string) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetHaving(q, clause) boil.SetHaving(q, clause)
} }
} }
func Table(table string) QueryMod { // From allows to specify the table for your statement
func From(from string) QueryMod {
return func(q *boil.Query) { return func(q *boil.Query) {
boil.SetTable(q, table) boil.SetFrom(q, from)
} }
} }

View file

@ -13,18 +13,25 @@ type where struct {
args []interface{} args []interface{}
} }
type plainSQL struct {
sql string
args []interface{}
}
type join struct { type join struct {
on string on string
args []interface{} args []interface{}
} }
// Query holds the state for the built up query
type Query struct { type Query struct {
executor Executor executor Executor
plainSQL plainSQL
delete bool delete bool
update map[string]interface{} update map[string]interface{}
selectCols []string selectCols []string
count bool count bool
table string from string
innerJoins []join innerJoins []join
outerJoins []join outerJoins []join
leftOuterJoins []join leftOuterJoins []join
@ -41,6 +48,8 @@ func buildQuery(q *Query) (string, []interface{}) {
var args []interface{} var args []interface{}
switch { switch {
case len(q.plainSQL.sql) != 0:
return q.plainSQL.sql, q.plainSQL.args
case q.delete: case q.delete:
buf, args = buildDeleteQuery(q) buf, args = buildDeleteQuery(q)
case len(q.update) > 0: case len(q.update) > 0:
@ -61,7 +70,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf.WriteString("COUNT(") buf.WriteString("COUNT(")
} }
if len(q.selectCols) > 0 { if len(q.selectCols) > 0 {
buf.WriteString(strings.Join(q.selectCols, ",")) buf.WriteString(`"` + strings.Join(q.selectCols, `","`) + `"`)
} else { } else {
buf.WriteByte('*') buf.WriteByte('*')
} }
@ -71,7 +80,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
} }
buf.WriteString(" FROM ") buf.WriteString(" FROM ")
fmt.Fprintf(buf, `"%s"`, q.table) fmt.Fprintf(buf, `"%s"`, q.from)
where, args := whereClause(q) where, args := whereClause(q)
buf.WriteString(where) buf.WriteString(where)
@ -84,7 +93,7 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
buf.WriteString("DELETE FROM ") buf.WriteString("DELETE FROM ")
fmt.Fprintf(buf, `"%s"`, q.table) fmt.Fprintf(buf, `"%s"`, q.from)
where, args := whereClause(q) where, args := whereClause(q)
buf.WriteString(where) buf.WriteString(where)
@ -128,6 +137,11 @@ func ExecQueryAll(q *Query) (*sql.Rows, error) {
return q.executor.Query(qs, args...) return q.executor.Query(qs, args...)
} }
// SetSQL on the query.
func SetSQL(q *Query, sql string, args ...interface{}) {
q.plainSQL = plainSQL{sql: sql, args: args}
}
// SetCount on the query. // SetCount on the query.
func SetCount(q *Query) { func SetCount(q *Query) {
q.count = true q.count = true
@ -160,9 +174,9 @@ func Select(q *Query) []string {
return cols return cols
} }
// SetTable on the query. // SetFrom on the query.
func SetTable(q *Query, table string) { func SetFrom(q *Query, from string) {
q.table = table q.from = from
} }
// SetInnerJoin on the query. // SetInnerJoin on the query.
@ -190,6 +204,11 @@ func SetWhere(q *Query, clause string, args ...interface{}) {
q.where = append(q.where, where{clause: clause, args: args}) q.where = append(q.where, where{clause: clause, args: args})
} }
// SetLastWhereAsOr sets the or seperator for the last element in the where slice
func SetLastWhereAsOr(q *Query) {
q.where[len(q.where)-1].orSeperator = true
}
// SetGroupBy on the query. // SetGroupBy on the query.
func SetGroupBy(q *Query, clause string) { func SetGroupBy(q *Query, clause string) {
q.groupBy = append(q.groupBy, clause) q.groupBy = append(q.groupBy, clause)
@ -211,21 +230,24 @@ func SetLimit(q *Query, limit int) {
} }
func whereClause(q *Query) (string, []interface{}) { func whereClause(q *Query) (string, []interface{}) {
if len(q.where) == 0 {
return "", nil
}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
var args []interface{} var args []interface{}
if len(q.where) > 0 { buf.WriteString(" WHERE ")
buf.WriteString(" WHERE ") for i := 0; i < len(q.where); i++ {
for i := 0; i < len(q.where); i++ { buf.WriteString(fmt.Sprintf("%s", q.where[i].clause))
buf.WriteString(fmt.Sprintf("%s", q.where[i].clause)) args = append(args, q.where[i].args...)
args = append(args, q.where[i].args...) if i >= len(q.where)-1 {
if i != len(q.where)-1 { continue
if q.where[i].orSeperator { }
buf.WriteString(" OR ") if q.where[i].orSeperator {
} else { buf.WriteString(" OR ")
buf.WriteString(" AND ") } else {
} buf.WriteString(" AND ")
}
} }
} }

View file

@ -26,7 +26,7 @@ func TestBuildQuery(t *testing.T) {
q *Query q *Query
args []interface{} args []interface{}
}{ }{
{&Query{table: "t"}, []interface{}{}}, {&Query{from: "t"}, []interface{}{}},
} }
for i, test := range tests { for i, test := range tests {
@ -57,6 +57,39 @@ func TestBuildQuery(t *testing.T) {
} }
} }
func TestSetLastWhereAsOr(t *testing.T) {
t.Parallel()
q := &Query{}
SetWhere(q, "")
if q.where[0].orSeperator {
t.Errorf("Do not want or seperator")
}
SetLastWhereAsOr(q)
if len(q.where) != 1 {
t.Errorf("Want len 1")
}
if !q.where[0].orSeperator {
t.Errorf("Want or seperator")
}
SetWhere(q, "")
SetLastWhereAsOr(q)
if len(q.where) != 2 {
t.Errorf("Want len 2")
}
if q.where[0].orSeperator != true {
t.Errorf("Expected true")
}
if q.where[1].orSeperator != true {
t.Errorf("Expected true")
}
}
func TestSetLimit(t *testing.T) { func TestSetLimit(t *testing.T) {
t.Parallel() t.Parallel()
@ -69,6 +102,21 @@ func TestSetLimit(t *testing.T) {
} }
} }
func TestSetSQL(t *testing.T) {
t.Parallel()
q := &Query{}
SetSQL(q, "select * from thing", 5, 3)
if len(q.plainSQL.args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.plainSQL.args))
}
if q.plainSQL.sql != "select * from thing" {
t.Errorf("Was not expected string, got %s", q.plainSQL.sql)
}
}
func TestSetWhere(t *testing.T) { func TestSetWhere(t *testing.T) {
t.Parallel() t.Parallel()
@ -133,11 +181,11 @@ func TestSetTable(t *testing.T) {
t.Parallel() t.Parallel()
q := &Query{} q := &Query{}
SetTable(q, "videos a, orders b") SetFrom(q, "videos a, orders b")
expect := "videos a, orders b" expect := "videos a, orders b"
if q.table != expect { if q.from != expect {
t.Errorf("Expected %s, got %s", expect, q.table) t.Errorf("Expected %s, got %s", expect, q.from)
} }
} }

View file

@ -25,16 +25,16 @@ func (q *Query) Bind(obj interface{}) error {
row := ExecQueryOne(q) row := ExecQueryOne(q)
err := BindOne(row, q.selectCols, obj) err := BindOne(row, q.selectCols, obj)
if err != nil { if err != nil {
return fmt.Errorf("Failed to execute Bind query for %s: %s", q.table, err) return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err)
} }
} else if kind == reflect.Slice { } else if kind == reflect.Slice {
rows, err := ExecQueryAll(q) rows, err := ExecQueryAll(q)
if err != nil { if err != nil {
return fmt.Errorf("Failed to execute Bind query for %s: %s", q.table, err) return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err)
} }
err = BindAll(rows, q.selectCols, obj) err = BindAll(rows, q.selectCols, obj)
if err != nil { if err != nil {
return fmt.Errorf("Failed to Bind results to object provided for %s: %s", q.table, err) return fmt.Errorf("Failed to Bind results to object provided for %s: %s", q.from, err)
} }
} else { } else {
return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", typ.String()) return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", typ.String())
@ -43,7 +43,7 @@ func (q *Query) Bind(obj interface{}) error {
return nil return nil
} }
// Bind executes the query and inserts the // BindP executes the query and inserts the
// result into the passed in object pointer. // result into the passed in object pointer.
// It panics on error. // It panics on error.
func (q *Query) BindP(obj interface{}) { func (q *Query) BindP(obj interface{}) {