Add sql query mod, finish or querymod

* Add comments to kill lint errors
This commit is contained in:
Patrick O'brien 2016-07-19 14:07:12 +10:00
parent b447018220
commit f713e041ad
5 changed files with 130 additions and 34 deletions

View file

@ -3,6 +3,7 @@ package boil
// HookPoint is the point in time at which we hook
type HookPoint int
// the hook point constants
const (
HookAfterCreate HookPoint = iota + 1
HookAfterUpdate

View file

@ -2,85 +2,110 @@ package qm
import "github.com/nullbio/sqlboiler/boil"
// QueryMod to modify the query object
type QueryMod func(q *boil.Query)
// Apply the query mods to the Query object
func Apply(q *boil.Query, mods ...QueryMod) {
for _, mod := range mods {
mod(q)
}
}
func Or(whereMods ...QueryMod) QueryMod {
// SQL allows you to execute a plain SQL statement
func SQL(sql string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
if len(whereMods) < 2 {
// error, needs to be at least 2 for an or
}
// add the where mods to query with or seperators
boil.SetSQL(q, sql, args...)
}
}
// Or surrounds where clauses to join them with OR as opposed to AND
func Or(whereMods ...QueryMod) QueryMod {
return func(q *boil.Query) {
if len(whereMods) < 2 {
panic("Or requires at least two arguments")
}
for _, w := range whereMods {
w(q)
boil.SetLastWhereAsOr(q)
}
}
}
// Limit the number of returned rows
func Limit(limit int) QueryMod {
return func(q *boil.Query) {
boil.SetLimit(q, limit)
}
}
// InnerJoin on another table
func InnerJoin(on string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.SetInnerJoin(q, on, args...)
}
}
// OuterJoin on another table
func OuterJoin(on string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.SetOuterJoin(q, on, args...)
}
}
// LeftOuterJoin on another table
func LeftOuterJoin(on string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.SetLeftOuterJoin(q, on, args...)
}
}
// RightOuterJoin on another table
func RightOuterJoin(on string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.SetRightOuterJoin(q, on, args...)
}
}
// Select specific columns opposed to all columns
func Select(columns ...string) QueryMod {
return func(q *boil.Query) {
boil.SetSelect(q, columns...)
}
}
// Where allows you to specify a where clause for your statement
func Where(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.SetWhere(q, clause, args...)
}
}
// GroupBy allows you to specify a group by clause for your statement
func GroupBy(clause string) QueryMod {
return func(q *boil.Query) {
boil.SetGroupBy(q, clause)
}
}
// OrderBy allows you to specify a order by clause for your statement
func OrderBy(clause string) QueryMod {
return func(q *boil.Query) {
boil.SetOrderBy(q, clause)
}
}
// Having allows you to specify a having clause for your statement
func Having(clause string) QueryMod {
return func(q *boil.Query) {
boil.SetHaving(q, clause)
}
}
func Table(table string) QueryMod {
// From allows to specify the table for your statement
func From(from string) QueryMod {
return func(q *boil.Query) {
boil.SetTable(q, table)
boil.SetFrom(q, from)
}
}

View file

@ -13,18 +13,25 @@ type where struct {
args []interface{}
}
type plainSQL struct {
sql string
args []interface{}
}
type join struct {
on string
args []interface{}
}
// Query holds the state for the built up query
type Query struct {
executor Executor
plainSQL plainSQL
delete bool
update map[string]interface{}
selectCols []string
count bool
table string
from string
innerJoins []join
outerJoins []join
leftOuterJoins []join
@ -41,6 +48,8 @@ func buildQuery(q *Query) (string, []interface{}) {
var args []interface{}
switch {
case len(q.plainSQL.sql) != 0:
return q.plainSQL.sql, q.plainSQL.args
case q.delete:
buf, args = buildDeleteQuery(q)
case len(q.update) > 0:
@ -61,7 +70,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf.WriteString("COUNT(")
}
if len(q.selectCols) > 0 {
buf.WriteString(strings.Join(q.selectCols, ","))
buf.WriteString(`"` + strings.Join(q.selectCols, `","`) + `"`)
} else {
buf.WriteByte('*')
}
@ -71,7 +80,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
}
buf.WriteString(" FROM ")
fmt.Fprintf(buf, `"%s"`, q.table)
fmt.Fprintf(buf, `"%s"`, q.from)
where, args := whereClause(q)
buf.WriteString(where)
@ -84,7 +93,7 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("DELETE FROM ")
fmt.Fprintf(buf, `"%s"`, q.table)
fmt.Fprintf(buf, `"%s"`, q.from)
where, args := whereClause(q)
buf.WriteString(where)
@ -128,6 +137,11 @@ func ExecQueryAll(q *Query) (*sql.Rows, error) {
return q.executor.Query(qs, args...)
}
// SetSQL on the query.
func SetSQL(q *Query, sql string, args ...interface{}) {
q.plainSQL = plainSQL{sql: sql, args: args}
}
// SetCount on the query.
func SetCount(q *Query) {
q.count = true
@ -160,9 +174,9 @@ func Select(q *Query) []string {
return cols
}
// SetTable on the query.
func SetTable(q *Query, table string) {
q.table = table
// SetFrom on the query.
func SetFrom(q *Query, from string) {
q.from = from
}
// SetInnerJoin on the query.
@ -190,6 +204,11 @@ func SetWhere(q *Query, clause string, args ...interface{}) {
q.where = append(q.where, where{clause: clause, args: args})
}
// SetLastWhereAsOr sets the or seperator for the last element in the where slice
func SetLastWhereAsOr(q *Query) {
q.where[len(q.where)-1].orSeperator = true
}
// SetGroupBy on the query.
func SetGroupBy(q *Query, clause string) {
q.groupBy = append(q.groupBy, clause)
@ -211,21 +230,24 @@ func SetLimit(q *Query, limit int) {
}
func whereClause(q *Query) (string, []interface{}) {
if len(q.where) == 0 {
return "", nil
}
buf := &bytes.Buffer{}
var args []interface{}
if len(q.where) > 0 {
buf.WriteString(" WHERE ")
for i := 0; i < len(q.where); i++ {
buf.WriteString(fmt.Sprintf("%s", q.where[i].clause))
args = append(args, q.where[i].args...)
if i != len(q.where)-1 {
if q.where[i].orSeperator {
buf.WriteString(" OR ")
} else {
buf.WriteString(" AND ")
}
}
buf.WriteString(" WHERE ")
for i := 0; i < len(q.where); i++ {
buf.WriteString(fmt.Sprintf("%s", q.where[i].clause))
args = append(args, q.where[i].args...)
if i >= len(q.where)-1 {
continue
}
if q.where[i].orSeperator {
buf.WriteString(" OR ")
} else {
buf.WriteString(" AND ")
}
}

View file

@ -26,7 +26,7 @@ func TestBuildQuery(t *testing.T) {
q *Query
args []interface{}
}{
{&Query{table: "t"}, []interface{}{}},
{&Query{from: "t"}, []interface{}{}},
}
for i, test := range tests {
@ -57,6 +57,39 @@ func TestBuildQuery(t *testing.T) {
}
}
func TestSetLastWhereAsOr(t *testing.T) {
t.Parallel()
q := &Query{}
SetWhere(q, "")
if q.where[0].orSeperator {
t.Errorf("Do not want or seperator")
}
SetLastWhereAsOr(q)
if len(q.where) != 1 {
t.Errorf("Want len 1")
}
if !q.where[0].orSeperator {
t.Errorf("Want or seperator")
}
SetWhere(q, "")
SetLastWhereAsOr(q)
if len(q.where) != 2 {
t.Errorf("Want len 2")
}
if q.where[0].orSeperator != true {
t.Errorf("Expected true")
}
if q.where[1].orSeperator != true {
t.Errorf("Expected true")
}
}
func TestSetLimit(t *testing.T) {
t.Parallel()
@ -69,6 +102,21 @@ func TestSetLimit(t *testing.T) {
}
}
func TestSetSQL(t *testing.T) {
t.Parallel()
q := &Query{}
SetSQL(q, "select * from thing", 5, 3)
if len(q.plainSQL.args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.plainSQL.args))
}
if q.plainSQL.sql != "select * from thing" {
t.Errorf("Was not expected string, got %s", q.plainSQL.sql)
}
}
func TestSetWhere(t *testing.T) {
t.Parallel()
@ -133,11 +181,11 @@ func TestSetTable(t *testing.T) {
t.Parallel()
q := &Query{}
SetTable(q, "videos a, orders b")
SetFrom(q, "videos a, orders b")
expect := "videos a, orders b"
if q.table != expect {
t.Errorf("Expected %s, got %s", expect, q.table)
if q.from != expect {
t.Errorf("Expected %s, got %s", expect, q.from)
}
}

View file

@ -25,16 +25,16 @@ func (q *Query) Bind(obj interface{}) error {
row := ExecQueryOne(q)
err := BindOne(row, q.selectCols, obj)
if err != nil {
return fmt.Errorf("Failed to execute Bind query for %s: %s", q.table, err)
return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err)
}
} else if kind == reflect.Slice {
rows, err := ExecQueryAll(q)
if err != nil {
return fmt.Errorf("Failed to execute Bind query for %s: %s", q.table, err)
return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err)
}
err = BindAll(rows, q.selectCols, obj)
if err != nil {
return fmt.Errorf("Failed to Bind results to object provided for %s: %s", q.table, err)
return fmt.Errorf("Failed to Bind results to object provided for %s: %s", q.from, err)
}
} else {
return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", typ.String())
@ -43,7 +43,7 @@ func (q *Query) Bind(obj interface{}) error {
return nil
}
// Bind executes the query and inserts the
// BindP executes the query and inserts the
// result into the passed in object pointer.
// It panics on error.
func (q *Query) BindP(obj interface{}) {