diff --git a/boil/qs/query_mods.go b/boil/qs/query_mods.go index c2651a7..1c5ee78 100644 --- a/boil/qs/query_mods.go +++ b/boil/qs/query_mods.go @@ -4,6 +4,12 @@ import "github.com/pobri19/sqlboiler/boil" type QueryMod func(q *boil.Query) +func Apply(q *boil.Query, mods ...QueryMod) { + for _, mod := range mods { + mod(q) + } +} + func Limit(limit int) QueryMod { return func(q *boil.Query) { boil.SetLimit(q, limit) diff --git a/boil/query.go b/boil/query.go index d83502a..3123ef5 100644 --- a/boil/query.go +++ b/boil/query.go @@ -91,14 +91,8 @@ func ExecQueryAll(q *Query) (*sql.Rows, error) { return nil, nil } -func Apply(q *Query, mods ...func(q *Query)) { - for _, mod := range mods { - mod(q) - } -} - -func SetDelete(q *Query, flag bool) { - q.delete = flag +func SetDelete(q *Query) { + q.delete = true } func SetUpdate(q *Query, cols map[string]interface{}) { diff --git a/boil/query_test.go b/boil/query_test.go index 313d6e6..9950fc0 100644 --- a/boil/query_test.go +++ b/boil/query_test.go @@ -1,6 +1,7 @@ package boil import ( + "database/sql" "flag" "fmt" "io/ioutil" @@ -55,132 +56,240 @@ func TestBuildQuery(t *testing.T) { } } -// func TestApply(t *testing.T) { -// t.Parallel() -// -// q := &boil.Query{} -// qfn1 := Limit(10) -// qfn2 := Where("x > $1 AND y > $2", 5, 3) -// -// q.Apply(qfn1, qfn2) -// -// expect1 := 10 -// if q.limit != expect1 { -// t.Errorf("Expected %d, got %d", expect1, q.limit) -// } -// -// expect2 := "x > $1 AND y > $2" -// if len(q.where) != 1 { -// t.Errorf("Expected %d where slices, got %d", len(q.where)) -// } -// -// expect := "x > $1 AND y > $2" -// if q.where[0].clause != expect2 { -// t.Errorf("Expected %s, got %s", expect, q.where) -// } -// -// if len(q.where[0].args) != 2 { -// t.Errorf("Expected %d args, got %d", 2, len(q.where[0].args)) -// } -// -// if q.where[0].args[0].(int) != 5 || q.where[0].args[1].(int) != 3 { -// t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.where[0].args) -// } -// } -// -// func TestLimit(t *testing.T) { -// t.Parallel() -// -// q := &boil.Query{} -// qfn := Limit(10) -// -// qfn(q) -// -// expect := 10 -// if q.limit != expect { -// t.Errorf("Expected %d, got %d", expect, q.limit) -// } -// } -// -// func TestWhere(t *testing.T) { -// t.Parallel() -// -// q := &boil.Query{} -// qfn := Where("x > $1 AND y > $2", 5, 3) -// -// qfn(q) -// -// if len(q.where) != 1 { -// t.Errorf("Expected %d where slices, got %d", len(q.where)) -// } -// -// expect := "x > $1 AND y > $2" -// if q.where[0].clause != expect { -// t.Errorf("Expected %s, got %s", expect, q.where) -// } -// -// if len(q.where[0].args) != 2 { -// t.Errorf("Expected %d args, got %d", 2, len(q.where[0].args)) -// } -// -// if q.where[0].args[0].(int) != 5 || q.where[0].args[1].(int) != 3 { -// t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.where[0].args) -// } -// } -// -// func TestGroupBy(t *testing.T) { -// t.Parallel() -// -// q := &boil.Query{} -// qfn := GroupBy("col1, col2") -// -// qfn(q) -// -// expect := "col1, col2" -// if len(q.groupBy) != 1 && q.groupBy[0] != expect { -// t.Errorf("Expected %s, got %s", expect, q.groupBy[0]) -// } -// } -// -// func TestOrderBy(t *testing.T) { -// t.Parallel() -// -// q := &boil.Query{} -// qfn := OrderBy("col1 desc, col2 asc") -// -// qfn(q) -// -// expect := "col1 desc, col2 asc" -// if len(q.orderBy) != 1 && q.orderBy[0] != expect { -// t.Errorf("Expected %s, got %s", expect, q.orderBy[0]) -// } -// } -// -// func TestHaving(t *testing.T) { -// t.Parallel() -// -// q := &boil.Query{} -// qfn := Having("count(orders.order_id) > 10") -// -// qfn(q) -// -// expect := "count(orders.order_id) > 10" -// if len(q.having) != 1 && q.having[0] != expect { -// t.Errorf("Expected %s, got %s", expect, q.having[0]) -// } -// } -// -// func TestFrom(t *testing.T) { -// t.Parallel() -// -// q := &boil.Query{} -// qfn := From("videos a, orders b") -// -// qfn(q) -// -// expect := "videos a, orders b" -// if q.from != expect { -// t.Errorf("Expected %s, got %s", expect, q.from) -// } -// } -// } +func TestExecQuery(t *testing.T) { + t.Parallel() +} + +func TestExecQueryOne(t *testing.T) { + t.Parallel() +} + +func TestExecQueryAll(t *testing.T) { + t.Parallel() +} + +func TestSetLimit(t *testing.T) { + t.Parallel() + + q := &Query{} + SetLimit(q, 10) + + expect := 10 + if q.limit != expect { + t.Errorf("Expected %d, got %d", expect, q.limit) + } +} + +func TestSetWhere(t *testing.T) { + t.Parallel() + + q := &Query{} + SetWhere(q, "x > $1 AND y > $2", 5, 3) + + if len(q.where) != 1 { + t.Errorf("Expected %d where slices, got %d", len(q.where)) + } + + expect := "x > $1 AND y > $2" + if q.where[0].clause != expect { + t.Errorf("Expected %s, got %s", expect, q.where) + } + + if len(q.where[0].args) != 2 { + t.Errorf("Expected %d args, got %d", 2, len(q.where[0].args)) + } + + if q.where[0].args[0].(int) != 5 || q.where[0].args[1].(int) != 3 { + t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.where[0].args) + } +} + +func TestSetGroupBy(t *testing.T) { + t.Parallel() + + q := &Query{} + SetGroupBy(q, "col1, col2") + + expect := "col1, col2" + if len(q.groupBy) != 1 && q.groupBy[0] != expect { + t.Errorf("Expected %s, got %s", expect, q.groupBy[0]) + } +} + +func TestSetOrderBy(t *testing.T) { + t.Parallel() + + q := &Query{} + SetOrderBy(q, "col1 desc, col2 asc") + + expect := "col1 desc, col2 asc" + if len(q.orderBy) != 1 && q.orderBy[0] != expect { + t.Errorf("Expected %s, got %s", expect, q.orderBy[0]) + } +} + +func TestSetHaving(t *testing.T) { + t.Parallel() + + q := &Query{} + SetHaving(q, "count(orders.order_id) > 10") + + expect := "count(orders.order_id) > 10" + if len(q.having) != 1 && q.having[0] != expect { + t.Errorf("Expected %s, got %s", expect, q.having[0]) + } +} + +func TestSetFrom(t *testing.T) { + t.Parallel() + + q := &Query{} + SetFrom(q, "videos a, orders b") + + expect := "videos a, orders b" + if q.from != expect { + t.Errorf("Expected %s, got %s", expect, q.from) + } +} + +func TestSetDelete(t *testing.T) { + t.Parallel() + + q := &Query{} + SetDelete(q) + + if q.delete != true { + t.Errorf("Expected %t, got %t", true, q.delete) + } +} + +func TestSetUpdate(t *testing.T) { + t.Parallel() + + q := &Query{} + SetUpdate(q, map[string]interface{}{"col1": 1, "col2": 2}) + + if len(q.update) != 2 { + t.Errorf("Expected len 2, got %d", len(q.update)) + } + + if q.update["col1"] != 1 && q.update["col2"] != 2 { + t.Errorf("Value misatch: %#v", q.update) + } +} + +func TestSetExecutor(t *testing.T) { + t.Parallel() + + q := &Query{} + d := &sql.DB{} + SetExecutor(q, d) + + if q.executor != d { + t.Errorf("Expected executor to get set to d, but was: %#v", q.executor) + } +} + +func TestSetSelect(t *testing.T) { + t.Parallel() + + q := &Query{} + SetSelect(q, "col1", "col2") + + if len(q.selectCols) != 2 { + t.Errorf("Expected selectCols len 2, got %d", len(q.selectCols)) + } + + if q.selectCols[0] != "col1" && q.selectCols[1] != "col2" { + t.Errorf("select cols value mismatch: %#v", q.selectCols) + } +} + +func TestSetInnerJoin(t *testing.T) { + t.Parallel() + + q := &Query{} + SetInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5) + + if len(q.innerJoins) != 1 { + t.Errorf("Expected len 1, got %d", len(q.innerJoins)) + } + + if q.innerJoins[0].on != "thing=$1 AND stuff=$2" { + t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins) + } + + if len(q.innerJoins[0].args) != 2 { + t.Errorf("Expected len 2, got %d", len(q.innerJoins[0].args)) + } + + if q.innerJoins[0].args[0] != 2 && q.innerJoins[0].args[1] != 5 { + t.Errorf("Invalid args values, got %#v", q.innerJoins[0].args) + } +} + +func TestSetOuterJoin(t *testing.T) { + t.Parallel() + q := &Query{} + SetOuterJoin(q, "thing=$1 AND stuff=$2", 2, 5) + + if len(q.outerJoins) != 1 { + t.Errorf("Expected len 1, got %d", len(q.outerJoins)) + } + + if q.outerJoins[0].on != "thing=$1 AND stuff=$2" { + t.Errorf("Got invalid innerJoin on string: %#v", q.outerJoins) + } + + if len(q.outerJoins[0].args) != 2 { + t.Errorf("Expected len 2, got %d", len(q.outerJoins[0].args)) + } + + if q.outerJoins[0].args[0] != 2 && q.outerJoins[0].args[1] != 5 { + t.Errorf("Invalid args values, got %#v", q.outerJoins[0].args) + } +} + +func TestSetLeftOuterJoin(t *testing.T) { + t.Parallel() + q := &Query{} + SetLeftOuterJoin(q, "thing=$1 AND stuff=$2", 2, 5) + + if len(q.leftOuterJoins) != 1 { + t.Errorf("Expected len 1, got %d", len(q.leftOuterJoins)) + } + + if q.leftOuterJoins[0].on != "thing=$1 AND stuff=$2" { + t.Errorf("Got invalid innerJoin on string: %#v", q.leftOuterJoins) + } + + if len(q.leftOuterJoins[0].args) != 2 { + t.Errorf("Expected len 2, got %d", len(q.leftOuterJoins[0].args)) + } + + if q.leftOuterJoins[0].args[0] != 2 && q.leftOuterJoins[0].args[1] != 5 { + t.Errorf("Invalid args values, got %#v", q.leftOuterJoins[0].args) + } +} + +func TestSetRightOuterJoin(t *testing.T) { + t.Parallel() + q := &Query{} + SetRightOuterJoin(q, "thing=$1 AND stuff=$2", 2, 5) + + if len(q.rightOuterJoins) != 1 { + t.Errorf("Expected len 1, got %d", len(q.rightOuterJoins)) + } + + if q.rightOuterJoins[0].on != "thing=$1 AND stuff=$2" { + t.Errorf("Got invalid innerJoin on string: %#v", q.rightOuterJoins) + } + + if len(q.rightOuterJoins[0].args) != 2 { + t.Errorf("Expected len 2, got %d", len(q.rightOuterJoins[0].args)) + } + + if q.rightOuterJoins[0].args[0] != 2 && q.rightOuterJoins[0].args[1] != 5 { + t.Errorf("Invalid args values, got %#v", q.rightOuterJoins[0].args) + } +} diff --git a/cmds/templates/singles/helpers.tpl b/cmds/templates/singles/helpers.tpl index 1f8ac97..d8c1b8c 100644 --- a/cmds/templates/singles/helpers.tpl +++ b/cmds/templates/singles/helpers.tpl @@ -8,8 +8,8 @@ func NewQuery(mods ...qs.QueryMod) *boil.Query { // NewQueryX initializes a new Query using the passed in QueryMods func NewQueryX(executor boil.Executor, mods ...qs.QueryMod) *boil.Query { - q := &Query{executor: executor} - q.Apply(mods...) + q := &boil.Query{executor: executor} + qs.Apply(q, mods...) return q }