Move everything to better package structure

This commit is contained in:
Aaron L 2016-09-14 20:45:09 -07:00
commit 5149df8359
40 changed files with 241 additions and 86 deletions

1
queries/_fixtures/00.sql Normal file
View file

@ -0,0 +1 @@
SELECT * FROM "t";

1
queries/_fixtures/01.sql Normal file
View file

@ -0,0 +1 @@
SELECT * FROM "q" LIMIT 5 OFFSET 6;

1
queries/_fixtures/02.sql Normal file
View file

@ -0,0 +1 @@
SELECT * FROM "q" ORDER BY a ASC, b DESC;

1
queries/_fixtures/03.sql Normal file
View file

@ -0,0 +1 @@
SELECT count(*) as ab, thing as bd, "stuff" FROM "t";

1
queries/_fixtures/04.sql Normal file
View file

@ -0,0 +1 @@
SELECT count(*) as ab, thing as bd, "stuff" FROM "a", "b";

1
queries/_fixtures/05.sql Normal file
View file

@ -0,0 +1 @@
SELECT "a"."happy" as "a.happy", "r"."fun" as "r.fun", "q" FROM happiness as a INNER JOIN rainbows r on a.id = r.happy_id;

1
queries/_fixtures/06.sql Normal file
View file

@ -0,0 +1 @@
SELECT "a".* FROM happiness as a INNER JOIN rainbows r on a.id = r.happy_id;

1
queries/_fixtures/07.sql Normal file
View file

@ -0,0 +1 @@
SELECT "videos".* FROM "videos" INNER JOIN (select id from users where deleted = $1) u on u.id = videos.user_id WHERE (videos.deleted = $2);

1
queries/_fixtures/08.sql Normal file
View file

@ -0,0 +1 @@
SELECT * FROM "a" WHERE (a=$1 or b=$2) AND (c=$3) GROUP BY id, name HAVING id <> $4, length(name, $5) > $6;

1
queries/_fixtures/09.sql Normal file
View file

@ -0,0 +1 @@
DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE (a=$1) AND (b=$2) AND (c=$3);

1
queries/_fixtures/10.sql Normal file
View file

@ -0,0 +1 @@
DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE ((id=$1 and thing=$2) or stuff=$3) LIMIT 5;

1
queries/_fixtures/11.sql Normal file
View file

@ -0,0 +1 @@
UPDATE thing happy, "fun", "stuff" SET ("col2", "fun"."col3", "col1") = ($1,$2,$3) WHERE (aa=$4 or bb=$5 or cc=$6) AND (dd=$7 or ee=$8 or ff=$9 and gg=$10) LIMIT 5;

1
queries/_fixtures/12.sql Normal file
View file

@ -0,0 +1 @@
SELECT "cats".* FROM "cats" INNER JOIN dogs d on d.cat_id = cats.id;

1
queries/_fixtures/13.sql Normal file
View file

@ -0,0 +1 @@
SELECT "c".* FROM cats c INNER JOIN dogs d on d.cat_id = cats.id;

1
queries/_fixtures/14.sql Normal file
View file

@ -0,0 +1 @@
SELECT "c".* FROM cats as c INNER JOIN dogs d on d.cat_id = cats.id;

1
queries/_fixtures/15.sql Normal file
View file

@ -0,0 +1 @@
SELECT "c".*, "d".* FROM cats as c, dogs as d INNER JOIN dogs d on d.cat_id = cats.id;

148
queries/eager_load.go Normal file
View file

@ -0,0 +1,148 @@
package queries
import (
"database/sql"
"reflect"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/boil"
"github.com/vattle/sqlboiler/strmangle"
)
type loadRelationshipState struct {
exec boil.Executor
loaded map[string]struct{}
toLoad []string
}
func (l loadRelationshipState) hasLoaded(depth int) bool {
_, ok := l.loaded[l.buildKey(depth)]
return ok
}
func (l loadRelationshipState) setLoaded(depth int) {
l.loaded[l.buildKey(depth)] = struct{}{}
}
func (l loadRelationshipState) buildKey(depth int) string {
buf := strmangle.GetBuffer()
for i, piece := range l.toLoad[:depth+1] {
if i != 0 {
buf.WriteByte('.')
}
buf.WriteString(piece)
}
str := buf.String()
strmangle.PutBuffer(buf)
return str
}
// loadRelationships dynamically calls the template generated eager load
// functions of the form:
//
// func (t *TableR) LoadRelationshipName(exec Executor, singular bool, obj interface{})
//
// The arguments to this function are:
// - t is not considered here, and is always passed nil. The function exists on a loaded
// struct to avoid a circular dependency with boil, and the receiver is ignored.
// - exec is used to perform additional queries that might be required for loading the relationships.
// - singular is passed in to identify whether or not this was a single object
// or a slice that must be loaded into.
// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind.
//
// It takes list of nested relationships to load.
func (l loadRelationshipState) loadRelationships(depth int, obj interface{}, bkind bindKind) error {
typ := reflect.TypeOf(obj).Elem()
if bkind == kindPtrSliceStruct {
typ = typ.Elem().Elem()
}
if !l.hasLoaded(depth) {
current := l.toLoad[depth]
ln, found := typ.FieldByName(loaderStructName)
// It's possible a Loaders struct doesn't exist on the struct.
if !found {
return errors.Errorf("attempted to load %s but no L struct was found", current)
}
// Attempt to find the LoadRelationshipName function
loadMethod, found := ln.Type.MethodByName(loadMethodPrefix + current)
if !found {
return errors.Errorf("could not find %s%s method for eager loading", loadMethodPrefix, current)
}
// Hack to allow nil executors
execArg := reflect.ValueOf(l.exec)
if !execArg.IsValid() {
execArg = reflect.ValueOf((*sql.DB)(nil))
}
val := reflect.ValueOf(obj).Elem()
if bkind == kindPtrSliceStruct {
val = val.Index(0).Elem()
}
methodArgs := []reflect.Value{
val.FieldByName(loaderStructName),
execArg,
reflect.ValueOf(bkind == kindStruct),
reflect.ValueOf(obj),
}
resp := loadMethod.Func.Call(methodArgs)
if intf := resp[0].Interface(); intf != nil {
return errors.Wrapf(intf.(error), "failed to eager load %s", current)
}
l.setLoaded(depth)
}
// Pull one off the queue, continue if there's still some to go
depth++
if depth >= len(l.toLoad) {
return nil
}
loadedObject := reflect.ValueOf(obj)
// If we eagerly loaded nothing
if loadedObject.IsNil() {
return nil
}
loadedObject = reflect.Indirect(loadedObject)
// If it's singular we can just immediately call without looping
if bkind == kindStruct {
return l.loadRelationshipsRecurse(depth, loadedObject)
}
// Loop over all eager loaded objects
ln := loadedObject.Len()
if ln == 0 {
return nil
}
for i := 0; i < ln; i++ {
iter := loadedObject.Index(i).Elem()
if err := l.loadRelationshipsRecurse(depth, iter); err != nil {
return err
}
}
return nil
}
// loadRelationshipsRecurse is a helper function for taking a reflect.Value and
// Basically calls loadRelationships with: obj.R.EagerLoadedObj, and whether it's a string or slice
func (l loadRelationshipState) loadRelationshipsRecurse(depth int, obj reflect.Value) error {
r := obj.FieldByName(relationshipStructName)
if !r.IsValid() || r.IsNil() {
return errors.Errorf("could not traverse into loaded %s relationship to load more things", l.toLoad[depth])
}
newObj := reflect.Indirect(r).FieldByName(l.toLoad[depth])
bkind := kindStruct
if reflect.Indirect(newObj).Kind() != reflect.Struct {
bkind = kindPtrSliceStruct
newObj = newObj.Addr()
}
return l.loadRelationships(depth, newObj.Interface(), bkind)
}

202
queries/eager_load_test.go Normal file
View file

@ -0,0 +1,202 @@
package queries
import (
"testing"
"github.com/vattle/sqlboiler/boil"
)
var loadFunctionCalled bool
var loadFunctionNestedCalled int
type testRStruct struct {
}
type testLStruct struct {
}
type testNestedStruct struct {
ID int
R *testNestedRStruct
L testNestedLStruct
}
type testNestedRStruct struct {
ToEagerLoad *testNestedStruct
}
type testNestedLStruct struct {
}
type testNestedSlice struct {
ID int
R *testNestedRSlice
L testNestedLSlice
}
type testNestedRSlice struct {
ToEagerLoad []*testNestedSlice
}
type testNestedLSlice struct {
}
func (testLStruct) LoadTestOne(exec boil.Executor, singular bool, obj interface{}) error {
loadFunctionCalled = true
return nil
}
func (testNestedLStruct) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error {
switch x := obj.(type) {
case *testNestedStruct:
x.R = &testNestedRStruct{
&testNestedStruct{ID: 4},
}
case *[]*testNestedStruct:
for _, r := range *x {
r.R = &testNestedRStruct{
&testNestedStruct{ID: 4},
}
}
}
loadFunctionNestedCalled++
return nil
}
func (testNestedLSlice) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error {
switch x := obj.(type) {
case *testNestedSlice:
x.R = &testNestedRSlice{
[]*testNestedSlice{{ID: 5}},
}
case *[]*testNestedSlice:
for _, r := range *x {
r.R = &testNestedRSlice{
[]*testNestedSlice{{ID: 5}},
}
}
}
loadFunctionNestedCalled++
return nil
}
func testFakeState(toLoad ...string) loadRelationshipState {
return loadRelationshipState{
loaded: map[string]struct{}{},
toLoad: toLoad,
}
}
func TestLoadRelationshipsSlice(t *testing.T) {
// t.Parallel() Function uses globals
loadFunctionCalled = false
testSlice := []*struct {
ID int
R *testRStruct
L testLStruct
}{{}}
if err := testFakeState("TestOne").loadRelationships(0, &testSlice, kindPtrSliceStruct); err != nil {
t.Error(err)
}
if !loadFunctionCalled {
t.Errorf("Load function was not called for testSlice")
}
}
func TestLoadRelationshipsSingular(t *testing.T) {
// t.Parallel() Function uses globals
loadFunctionCalled = false
testSingular := struct {
ID int
R *testRStruct
L testLStruct
}{}
if err := testFakeState("TestOne").loadRelationships(0, &testSingular, kindStruct); err != nil {
t.Error(err)
}
if !loadFunctionCalled {
t.Errorf("Load function was not called for singular")
}
}
func TestLoadRelationshipsSliceNested(t *testing.T) {
// t.Parallel() Function uses globals
testSlice := []*testNestedStruct{
{
ID: 2,
},
}
loadFunctionNestedCalled = 0
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSlice, kindPtrSliceStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
}
testSliceSlice := []*testNestedSlice{
{
ID: 2,
},
}
loadFunctionNestedCalled = 0
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSliceSlice, kindPtrSliceStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
}
}
func TestLoadRelationshipsSingularNested(t *testing.T) {
// t.Parallel() Function uses globals
testSingular := testNestedStruct{
ID: 3,
}
loadFunctionNestedCalled = 0
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingular, kindStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
}
testSingularSlice := testNestedSlice{
ID: 3,
}
loadFunctionNestedCalled = 0
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingularSlice, kindStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
}
}
func TestLoadRelationshipsNoReload(t *testing.T) {
// t.Parallel() Function uses globals
testSingular := testNestedStruct{
ID: 3,
R: &testNestedRStruct{
&testNestedStruct{},
},
}
loadFunctionNestedCalled = 0
state := loadRelationshipState{
loaded: map[string]struct{}{
"ToEagerLoad": {},
"ToEagerLoad.ToEagerLoad": {},
},
toLoad: []string{"ToEagerLoad", "ToEagerLoad"},
}
if err := state.loadRelationships(0, &testSingular, kindStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 0 {
t.Error("didn't want this called")
}
}

31
queries/helpers.go Normal file
View file

@ -0,0 +1,31 @@
package queries
import (
"fmt"
"reflect"
"github.com/vattle/sqlboiler/strmangle"
)
// NonZeroDefaultSet returns the fields included in the
// defaults slice that are non zero values
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
c := make([]string, 0, len(defaults))
val := reflect.Indirect(reflect.ValueOf(obj))
for _, d := range defaults {
fieldName := strmangle.TitleCase(d)
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("Could not find field name %s in type %T", fieldName, obj))
}
zero := reflect.Zero(field.Type())
if !reflect.DeepEqual(zero.Interface(), field.Interface()) {
c = append(c, d)
}
}
return c
}

67
queries/helpers_test.go Normal file
View file

@ -0,0 +1,67 @@
package queries
import (
"reflect"
"testing"
"time"
"gopkg.in/nullbio/null.v5"
)
type testObj struct {
ID int
Name string `db:"TestHello"`
HeadSize int
}
func TestNonZeroDefaultSet(t *testing.T) {
t.Parallel()
type Anything struct {
ID int
Name string
CreatedAt *time.Time
UpdatedAt null.Time
}
now := time.Now()
tests := []struct {
Defaults []string
Obj interface{}
Ret []string
}{
{
[]string{"id"},
Anything{Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
[]string{},
},
{
[]string{"id"},
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
[]string{"id"},
},
{
[]string{},
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
[]string{},
},
{
[]string{"id", "created_at", "updated_at"},
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
[]string{"id"},
},
{
[]string{"id", "created_at", "updated_at"},
Anything{ID: 5, Name: "hi", CreatedAt: &now, UpdatedAt: null.Time{Valid: true, Time: time.Now()}},
[]string{"id", "created_at", "updated_at"},
},
}
for i, test := range tests {
z := NonZeroDefaultSet(test.Defaults, test.Obj)
if !reflect.DeepEqual(test.Ret, z) {
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.Ret, z)
}
}
}

145
queries/qm/query_mods.go Normal file
View file

@ -0,0 +1,145 @@
package qm
import "github.com/vattle/sqlboiler/queries"
// QueryMod to modify the query object
type QueryMod func(q *queries.Query)
// Apply the query mods to the Query object
func Apply(q *queries.Query, mods ...QueryMod) {
for _, mod := range mods {
mod(q)
}
}
// SQL allows you to execute a plain SQL statement
func SQL(sql string, args ...interface{}) QueryMod {
return func(q *queries.Query) {
queries.SetSQL(q, sql, args...)
}
}
// Load allows you to specify foreign key relationships to eager load
// for your query. Passed in relationships need to be in the format
// MyThing or MyThings.
// Relationship name plurality is important, if your relationship is
// singular, you need to specify the singular form and vice versa.
func Load(relationships ...string) QueryMod {
return func(q *queries.Query) {
queries.SetLoad(q, relationships...)
}
}
// InnerJoin on another table
func InnerJoin(clause string, args ...interface{}) QueryMod {
return func(q *queries.Query) {
queries.AppendInnerJoin(q, clause, args...)
}
}
// Select specific columns opposed to all columns
func Select(columns ...string) QueryMod {
return func(q *queries.Query) {
queries.AppendSelect(q, columns...)
}
}
// Where allows you to specify a where clause for your statement
func Where(clause string, args ...interface{}) QueryMod {
return func(q *queries.Query) {
queries.AppendWhere(q, clause, args...)
}
}
// And allows you to specify a where clause separated by an AND for your statement
// And is a duplicate of the Where function, but allows for more natural looking
// query mod chains, for example: (Where("a=?"), And("b=?"), Or("c=?")))
func And(clause string, args ...interface{}) QueryMod {
return func(q *queries.Query) {
queries.AppendWhere(q, clause, args...)
}
}
// Or allows you to specify a where clause separated by an OR for your statement
func Or(clause string, args ...interface{}) QueryMod {
return func(q *queries.Query) {
queries.AppendWhere(q, clause, args...)
queries.SetLastWhereAsOr(q)
}
}
// WhereIn allows you to specify a "x IN (set)" clause for your where statement
// Example clauses: "column in ?", "(column1,column2) in ?"
func WhereIn(clause string, args ...interface{}) QueryMod {
return func(q *queries.Query) {
queries.AppendIn(q, clause, args...)
}
}
// AndIn allows you to specify a "x IN (set)" clause separated by an AndIn
// for your where statement. AndIn is a duplicate of the WhereIn function, but
// allows for more natural looking query mod chains, for example:
// (WhereIn("column1 in ?"), AndIn("column2 in ?"), OrIn("column3 in ?"))
func AndIn(clause string, args ...interface{}) QueryMod {
return func(q *queries.Query) {
queries.AppendIn(q, clause, args...)
}
}
// OrIn allows you to specify an IN clause separated by
// an OR for your where statement
func OrIn(clause string, args ...interface{}) QueryMod {
return func(q *queries.Query) {
queries.AppendIn(q, clause, args...)
queries.SetLastInAsOr(q)
}
}
// GroupBy allows you to specify a group by clause for your statement
func GroupBy(clause string) QueryMod {
return func(q *queries.Query) {
queries.AppendGroupBy(q, clause)
}
}
// OrderBy allows you to specify a order by clause for your statement
func OrderBy(clause string) QueryMod {
return func(q *queries.Query) {
queries.AppendOrderBy(q, clause)
}
}
// Having allows you to specify a having clause for your statement
func Having(clause string, args ...interface{}) QueryMod {
return func(q *queries.Query) {
queries.AppendHaving(q, clause, args...)
}
}
// From allows to specify the table for your statement
func From(from string) QueryMod {
return func(q *queries.Query) {
queries.AppendFrom(q, from)
}
}
// Limit the number of returned rows
func Limit(limit int) QueryMod {
return func(q *queries.Query) {
queries.SetLimit(q, limit)
}
}
// Offset into the results
func Offset(offset int) QueryMod {
return func(q *queries.Query) {
queries.SetOffset(q, offset)
}
}
// For inserts a concurrency locking clause at the end of your statement
func For(clause string) QueryMod {
return func(q *queries.Query) {
queries.SetFor(q, clause)
}
}

274
queries/query.go Normal file
View file

@ -0,0 +1,274 @@
package queries
import (
"database/sql"
"fmt"
"github.com/vattle/sqlboiler/boil"
)
// joinKind is the type of join
type joinKind int
// Join type constants
const (
JoinInner joinKind = iota
JoinOuterLeft
JoinOuterRight
JoinNatural
)
// Query holds the state for the built up query
type Query struct {
executor boil.Executor
dialect *Dialect
plainSQL plainSQL
load []string
delete bool
update map[string]interface{}
selectCols []string
count bool
from []string
joins []join
where []where
in []in
groupBy []string
orderBy []string
having []having
limit int
offset int
forlock string
}
// Dialect holds values that direct the query builder
// how to build compatible queries for each database.
// Each database driver needs to implement functions
// that provide these values.
type Dialect struct {
// The left quote character for SQL identifiers
LQ byte
// The right quote character for SQL identifiers
RQ byte
// Bool flag indicating whether indexed
// placeholders ($1) are used, or ? placeholders.
IndexPlaceholders bool
}
type where struct {
clause string
orSeparator bool
args []interface{}
}
type in struct {
clause string
orSeparator bool
args []interface{}
}
type having struct {
clause string
args []interface{}
}
type plainSQL struct {
sql string
args []interface{}
}
type join struct {
kind joinKind
clause string
args []interface{}
}
// SQL makes a plainSQL query, usually for use with bind
func SQL(exec boil.Executor, query string, args ...interface{}) *Query {
return &Query{
executor: exec,
plainSQL: plainSQL{
sql: query,
args: args,
},
}
}
// SQLG makes a plainSQL query using the global boil.Executor, usually for use with bind
func SQLG(query string, args ...interface{}) *Query {
return SQL(boil.GetDB(), query, args...)
}
// Exec executes a query that does not need a row returned
func (q *Query) Exec() (sql.Result, error) {
qs, args := buildQuery(q)
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, qs)
fmt.Fprintln(boil.DebugWriter, args)
}
return q.executor.Exec(qs, args...)
}
// QueryRow executes the query for the One finisher and returns a row
func (q *Query) QueryRow() *sql.Row {
qs, args := buildQuery(q)
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, qs)
fmt.Fprintln(boil.DebugWriter, args)
}
return q.executor.QueryRow(qs, args...)
}
// Query executes the query for the All finisher and returns multiple rows
func (q *Query) Query() (*sql.Rows, error) {
qs, args := buildQuery(q)
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, qs)
fmt.Fprintln(boil.DebugWriter, args)
}
return q.executor.Query(qs, args...)
}
// ExecP executes a query that does not need a row returned
// It will panic on error
func (q *Query) ExecP() sql.Result {
res, err := q.Exec()
if err != nil {
panic(boil.WrapErr(err))
}
return res
}
// QueryP executes the query for the All finisher and returns multiple rows
// It will panic on error
func (q *Query) QueryP() *sql.Rows {
rows, err := q.Query()
if err != nil {
panic(boil.WrapErr(err))
}
return rows
}
// SetExecutor on the query.
func SetExecutor(q *Query, exec boil.Executor) {
q.executor = exec
}
// GetExecutor on the query.
func GetExecutor(q *Query) boil.Executor {
return q.executor
}
// SetDialect on the query.
func SetDialect(q *Query, dialect *Dialect) {
q.dialect = dialect
}
// SetSQL on the query.
func SetSQL(q *Query, sql string, args ...interface{}) {
q.plainSQL = plainSQL{sql: sql, args: args}
}
// SetLoad on the query.
func SetLoad(q *Query, relationships ...string) {
q.load = append([]string(nil), relationships...)
}
// SetSelect on the query.
func SetSelect(q *Query, sel []string) {
q.selectCols = sel
}
// SetCount on the query.
func SetCount(q *Query) {
q.count = true
}
// SetDelete on the query.
func SetDelete(q *Query) {
q.delete = true
}
// SetLimit on the query.
func SetLimit(q *Query, limit int) {
q.limit = limit
}
// SetOffset on the query.
func SetOffset(q *Query, offset int) {
q.offset = offset
}
// SetFor on the query.
func SetFor(q *Query, clause string) {
q.forlock = clause
}
// SetUpdate on the query.
func SetUpdate(q *Query, cols map[string]interface{}) {
q.update = cols
}
// AppendSelect on the query.
func AppendSelect(q *Query, columns ...string) {
q.selectCols = append(q.selectCols, columns...)
}
// AppendFrom on the query.
func AppendFrom(q *Query, from ...string) {
q.from = append(q.from, from...)
}
// SetFrom replaces the current from statements.
func SetFrom(q *Query, from ...string) {
q.from = append([]string(nil), from...)
}
// AppendInnerJoin on the query.
func AppendInnerJoin(q *Query, clause string, args ...interface{}) {
q.joins = append(q.joins, join{clause: clause, kind: JoinInner, args: args})
}
// AppendHaving on the query.
func AppendHaving(q *Query, clause string, args ...interface{}) {
q.having = append(q.having, having{clause: clause, args: args})
}
// AppendWhere on the query.
func AppendWhere(q *Query, clause string, args ...interface{}) {
q.where = append(q.where, where{clause: clause, args: args})
}
// AppendIn on the query.
func AppendIn(q *Query, clause string, args ...interface{}) {
q.in = append(q.in, in{clause: clause, args: args})
}
// SetLastWhereAsOr sets the or separator for the tail "WHERE" in the slice
func SetLastWhereAsOr(q *Query) {
if len(q.where) == 0 {
return
}
q.where[len(q.where)-1].orSeparator = true
}
// SetLastInAsOr sets the or separator for the tail "IN" in the slice
func SetLastInAsOr(q *Query) {
if len(q.in) == 0 {
return
}
q.in[len(q.in)-1].orSeparator = true
}
// AppendGroupBy on the query.
func AppendGroupBy(q *Query, clause string) {
q.groupBy = append(q.groupBy, clause)
}
// AppendOrderBy on the query.
func AppendOrderBy(q *Query, clause string) {
q.orderBy = append(q.orderBy, clause)
}

573
queries/query_builders.go Normal file
View file

@ -0,0 +1,573 @@
package queries
import (
"bytes"
"fmt"
"regexp"
"sort"
"strings"
"github.com/vattle/sqlboiler/strmangle"
)
var (
rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`)
rgxInClause = regexp.MustCompile(`^(?i)(.*[\s|\)|\?])IN([\s|\(|\?].*)$`)
)
func buildQuery(q *Query) (string, []interface{}) {
var buf *bytes.Buffer
var args []interface{}
switch {
case len(q.plainSQL.sql) != 0:
return q.plainSQL.sql, q.plainSQL.args
case q.delete:
buf, args = buildDeleteQuery(q)
case len(q.update) > 0:
buf, args = buildUpdateQuery(q)
default:
buf, args = buildSelectQuery(q)
}
defer strmangle.PutBuffer(buf)
// Cache the generated query for query object re-use
bufStr := buf.String()
q.plainSQL.sql = bufStr
q.plainSQL.args = args
return bufStr, args
}
func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := strmangle.GetBuffer()
var args []interface{}
buf.WriteString("SELECT ")
if q.count {
buf.WriteString("COUNT(")
}
hasSelectCols := len(q.selectCols) != 0
hasJoins := len(q.joins) != 0
if hasJoins && hasSelectCols && !q.count {
selectColsWithAs := writeAsStatements(q)
// Don't identQuoteSlice - writeAsStatements does this
buf.WriteString(strings.Join(selectColsWithAs, ", "))
} else if hasSelectCols {
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", "))
} else if hasJoins && !q.count {
selectColsWithStars := writeStars(q)
buf.WriteString(strings.Join(selectColsWithStars, ", "))
} else {
buf.WriteByte('*')
}
// close SQL COUNT function
if q.count {
buf.WriteByte(')')
}
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
if len(q.joins) > 0 {
argsLen := len(args)
joinBuf := strmangle.GetBuffer()
for _, j := range q.joins {
if j.kind != JoinInner {
panic("only inner joins are supported")
}
fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause)
args = append(args, j.args...)
}
var resp string
if q.dialect.IndexPlaceholders {
resp, _ = convertQuestionMarks(joinBuf.String(), argsLen+1)
} else {
resp = joinBuf.String()
}
fmt.Fprintf(buf, resp)
strmangle.PutBuffer(joinBuf)
}
where, whereArgs := whereClause(q, len(args)+1)
buf.WriteString(where)
if len(whereArgs) != 0 {
args = append(args, whereArgs...)
}
in, inArgs := inClause(q, len(args)+1)
buf.WriteString(in)
if len(inArgs) != 0 {
args = append(args, inArgs...)
}
writeModifiers(q, buf, &args)
buf.WriteByte(';')
return buf, args
}
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
var args []interface{}
buf := strmangle.GetBuffer()
buf.WriteString("DELETE FROM ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
where, whereArgs := whereClause(q, 1)
if len(whereArgs) != 0 {
args = append(args, whereArgs)
}
buf.WriteString(where)
in, inArgs := inClause(q, len(args)+1)
if len(inArgs) != 0 {
args = append(args, inArgs...)
}
buf.WriteString(in)
writeModifiers(q, buf, &args)
buf.WriteByte(';')
return buf, args
}
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := strmangle.GetBuffer()
buf.WriteString("UPDATE ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
cols := make(sort.StringSlice, len(q.update))
var args []interface{}
count := 0
for name := range q.update {
cols[count] = name
count++
}
cols.Sort()
for i := 0; i < len(cols); i++ {
args = append(args, q.update[cols[i]])
cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, cols[i])
}
buf.WriteString(fmt.Sprintf(
" SET (%s) = (%s)",
strings.Join(cols, ", "),
strmangle.Placeholders(q.dialect.IndexPlaceholders, len(cols), 1, 1)),
)
where, whereArgs := whereClause(q, len(args)+1)
if len(whereArgs) != 0 {
args = append(args, whereArgs...)
}
buf.WriteString(where)
in, inArgs := inClause(q, len(args)+1)
if len(inArgs) != 0 {
args = append(args, inArgs...)
}
buf.WriteString(in)
writeModifiers(q, buf, &args)
buf.WriteByte(';')
return buf, args
}
// BuildUpsertQueryMySQL builds a SQL statement string using the upsertData provided.
func BuildUpsertQueryMySQL(dia Dialect, tableName string, update, whitelist []string) string {
whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
buf := strmangle.GetBuffer()
defer strmangle.PutBuffer(buf)
fmt.Fprintf(
buf,
"INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE ",
tableName,
strings.Join(whitelist, ", "),
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
)
for i, v := range update {
if i != 0 {
buf.WriteByte(',')
}
quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
buf.WriteString(quoted)
buf.WriteString(" = VALUES(")
buf.WriteString(quoted)
buf.WriteByte(')')
}
return buf.String()
}
// BuildUpsertQueryPostgres builds a SQL statement string using the upsertData provided.
func BuildUpsertQueryPostgres(dia Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string {
conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict)
whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret)
buf := strmangle.GetBuffer()
defer strmangle.PutBuffer(buf)
fmt.Fprintf(
buf,
"INSERT INTO %s (%s) VALUES (%s) ON CONFLICT ",
tableName,
strings.Join(whitelist, ", "),
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
)
if !updateOnConflict || len(update) == 0 {
buf.WriteString("DO NOTHING")
} else {
buf.WriteByte('(')
buf.WriteString(strings.Join(conflict, ", "))
buf.WriteString(") DO UPDATE SET ")
for i, v := range update {
if i != 0 {
buf.WriteByte(',')
}
quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
buf.WriteString(quoted)
buf.WriteString(" = EXCLUDED.")
buf.WriteString(quoted)
}
}
if len(ret) != 0 {
buf.WriteString(" RETURNING ")
buf.WriteString(strings.Join(ret, ", "))
}
return buf.String()
}
func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) {
if len(q.groupBy) != 0 {
fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", "))
}
if len(q.having) != 0 {
argsLen := len(*args)
havingBuf := strmangle.GetBuffer()
fmt.Fprintf(havingBuf, " HAVING ")
for i, j := range q.having {
if i > 0 {
fmt.Fprintf(havingBuf, ", ")
}
fmt.Fprintf(havingBuf, j.clause)
*args = append(*args, j.args...)
}
var resp string
if q.dialect.IndexPlaceholders {
resp, _ = convertQuestionMarks(havingBuf.String(), argsLen+1)
} else {
resp = havingBuf.String()
}
fmt.Fprintf(buf, resp)
strmangle.PutBuffer(havingBuf)
}
if len(q.orderBy) != 0 {
buf.WriteString(" ORDER BY ")
buf.WriteString(strings.Join(q.orderBy, ", "))
}
if q.limit != 0 {
fmt.Fprintf(buf, " LIMIT %d", q.limit)
}
if q.offset != 0 {
fmt.Fprintf(buf, " OFFSET %d", q.offset)
}
if len(q.forlock) != 0 {
fmt.Fprintf(buf, " FOR %s", q.forlock)
}
}
func writeStars(q *Query) []string {
cols := make([]string, len(q.from))
for i, f := range q.from {
toks := strings.Split(f, " ")
if len(toks) == 1 {
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, toks[0]))
continue
}
alias, name, ok := parseFromClause(toks)
if !ok {
return nil
}
if len(alias) != 0 {
name = alias
}
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, name))
}
return cols
}
func writeAsStatements(q *Query) []string {
cols := make([]string, len(q.selectCols))
for i, col := range q.selectCols {
if !rgxIdentifier.MatchString(col) {
cols[i] = col
continue
}
toks := strings.Split(col, ".")
if len(toks) == 1 {
cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col)
continue
}
asParts := make([]string, len(toks))
for j, tok := range toks {
asParts[j] = strings.Trim(tok, `"`)
}
cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col), strings.Join(asParts, "."))
}
return cols
}
// whereClause parses a where slice and converts it into a
// single WHERE clause like:
// WHERE (a=$1) AND (b=$2)
//
// startAt specifies what number placeholders start at
func whereClause(q *Query, startAt int) (string, []interface{}) {
if len(q.where) == 0 {
return "", nil
}
buf := strmangle.GetBuffer()
defer strmangle.PutBuffer(buf)
var args []interface{}
buf.WriteString(" WHERE ")
for i, where := range q.where {
if i != 0 {
if where.orSeparator {
buf.WriteString(" OR ")
} else {
buf.WriteString(" AND ")
}
}
buf.WriteString(fmt.Sprintf("(%s)", where.clause))
args = append(args, where.args...)
}
var resp string
if q.dialect.IndexPlaceholders {
resp, _ = convertQuestionMarks(buf.String(), startAt)
} else {
resp = buf.String()
}
return resp, args
}
// inClause parses an in slice and converts it into a
// single IN clause, like:
// WHERE ("a", "b") IN (($1,$2),($3,$4)).
func inClause(q *Query, startAt int) (string, []interface{}) {
if len(q.in) == 0 {
return "", nil
}
buf := strmangle.GetBuffer()
defer strmangle.PutBuffer(buf)
var args []interface{}
if len(q.where) == 0 {
buf.WriteString(" WHERE ")
}
for i, in := range q.in {
ln := len(in.args)
// We only prefix the OR and AND separators after the first
// clause has been generated UNLESS there is already a where
// clause that we have to add on to.
if i != 0 || len(q.where) > 0 {
if in.orSeparator {
buf.WriteString(" OR ")
} else {
buf.WriteString(" AND ")
}
}
matches := rgxInClause.FindStringSubmatch(in.clause)
// If we can't find any matches attempt a simple replace with 1 group.
// Clauses that fit this criteria will not be able to contain ? in their
// column name side, however if this case is being hit then the regexp
// probably needs adjustment, or the user is passing in invalid clauses.
if matches == nil {
clause, count := convertInQuestionMarks(q.dialect.IndexPlaceholders, in.clause, startAt, 1, ln)
buf.WriteString(clause)
startAt = startAt + count
} else {
leftSide := strings.TrimSpace(matches[1])
rightSide := strings.TrimSpace(matches[2])
// If matches are found, we have to parse the left side (column side)
// of the clause to determine how many columns they are using.
// This number determines the groupAt for the convert function.
cols := strings.Split(leftSide, ",")
cols = strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, cols)
groupAt := len(cols)
var leftClause string
var leftCount int
if q.dialect.IndexPlaceholders {
leftClause, leftCount = convertQuestionMarks(strings.Join(cols, ","), startAt)
} else {
// Count the number of cols that are question marks, so we know
// how much to offset convertInQuestionMarks by
for _, v := range cols {
if v == "?" {
leftCount++
}
}
leftClause = strings.Join(cols, ",")
}
rightClause, rightCount := convertInQuestionMarks(q.dialect.IndexPlaceholders, rightSide, startAt+leftCount, groupAt, ln-leftCount)
buf.WriteString(leftClause)
buf.WriteString(" IN ")
buf.WriteString(rightClause)
startAt = startAt + leftCount + rightCount
}
args = append(args, in.args...)
}
return buf.String(), args
}
// convertInQuestionMarks finds the first unescaped occurrence of ? and swaps it
// with a list of numbered placeholders, starting at startAt.
// It uses groupAt to determine how many placeholders should be in each group,
// for example, groupAt 2 would result in: (($1,$2),($3,$4))
// and groupAt 1 would result in ($1,$2,$3,$4)
func convertInQuestionMarks(indexPlaceholders bool, clause string, startAt, groupAt, total int) (string, int) {
if startAt == 0 || len(clause) == 0 {
panic("Not a valid start number.")
}
paramBuf := strmangle.GetBuffer()
defer strmangle.PutBuffer(paramBuf)
foundAt := -1
for i := 0; i < len(clause); i++ {
if (clause[i] == '?' && i == 0) || (clause[i] == '?' && clause[i-1] != '\\') {
foundAt = i
break
}
}
if foundAt == -1 {
return strings.Replace(clause, `\?`, "?", -1), 0
}
paramBuf.WriteString(clause[:foundAt])
paramBuf.WriteByte('(')
paramBuf.WriteString(strmangle.Placeholders(indexPlaceholders, total, startAt, groupAt))
paramBuf.WriteByte(')')
paramBuf.WriteString(clause[foundAt+1:])
// Remove all backslashes from escaped question-marks
ret := strings.Replace(paramBuf.String(), `\?`, "?", -1)
return ret, total
}
// convertQuestionMarks converts each occurrence of ? with $<number>
// where <number> is an incrementing digit starting at startAt.
// If question-mark (?) is escaped using back-slash (\), it will be ignored.
func convertQuestionMarks(clause string, startAt int) (string, int) {
if startAt == 0 {
panic("Not a valid start number.")
}
paramBuf := strmangle.GetBuffer()
defer strmangle.PutBuffer(paramBuf)
paramIndex := 0
total := 0
for {
if paramIndex >= len(clause) {
break
}
clause = clause[paramIndex:]
paramIndex = strings.IndexByte(clause, '?')
if paramIndex == -1 {
paramBuf.WriteString(clause)
break
}
escapeIndex := strings.Index(clause, `\?`)
if escapeIndex != -1 && paramIndex > escapeIndex {
paramBuf.WriteString(clause[:escapeIndex] + "?")
paramIndex++
continue
}
paramBuf.WriteString(clause[:paramIndex] + fmt.Sprintf("$%d", startAt))
total++
startAt++
paramIndex++
}
return paramBuf.String(), total
}
// parseFromClause will parse something that looks like
// a
// a b
// a as b
func parseFromClause(toks []string) (alias, name string, ok bool) {
if len(toks) > 3 {
toks = toks[:3]
}
sawIdent, sawAs := false, false
for _, tok := range toks {
if t := strings.ToLower(tok); sawIdent && t == "as" {
sawAs = true
continue
} else if sawIdent && t == "on" {
break
}
if !rgxIdentifier.MatchString(tok) {
break
}
if sawIdent || sawAs {
alias = strings.Trim(tok, `"`)
break
}
name = strings.Trim(tok, `"`)
sawIdent = true
ok = true
}
return alias, name, ok
}

View file

@ -0,0 +1,547 @@
package queries
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"path/filepath"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
)
var writeGoldenFiles = flag.Bool(
"test.golden",
false,
"Write golden files.",
)
func TestBuildQuery(t *testing.T) {
t.Parallel()
tests := []struct {
q *Query
args []interface{}
}{
{&Query{from: []string{"t"}}, nil},
{&Query{from: []string{"q"}, limit: 5, offset: 6}, nil},
{&Query{from: []string{"q"}, orderBy: []string{"a ASC", "b DESC"}}, nil},
{&Query{from: []string{"t"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
{&Query{from: []string{"a", "b"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
{&Query{
selectCols: []string{"a.happy", "r.fun", "q"},
from: []string{"happiness as a"},
joins: []join{{clause: "rainbows r on a.id = r.happy_id"}},
}, nil},
{&Query{
from: []string{"happiness as a"},
joins: []join{{clause: "rainbows r on a.id = r.happy_id"}},
}, nil},
{&Query{
from: []string{"videos"},
joins: []join{{
clause: "(select id from users where deleted = ?) u on u.id = videos.user_id",
args: []interface{}{true},
}},
where: []where{{clause: "videos.deleted = ?", args: []interface{}{false}}},
}, []interface{}{true, false}},
{&Query{
from: []string{"a"},
groupBy: []string{"id", "name"},
where: []where{
{clause: "a=? or b=?", args: []interface{}{1, 2}},
{clause: "c=?", args: []interface{}{3}},
},
having: []having{
{clause: "id <> ?", args: []interface{}{1}},
{clause: "length(name, ?) > ?", args: []interface{}{"utf8", 5}},
},
}, []interface{}{1, 2, 3, 1, "utf8", 5}},
{&Query{
delete: true,
from: []string{"thing happy", `upset as "sad"`, "fun", "thing as stuff", `"angry" as mad`},
where: []where{
{clause: "a=?", args: []interface{}{}},
{clause: "b=?", args: []interface{}{}},
{clause: "c=?", args: []interface{}{}},
},
}, nil},
{&Query{
delete: true,
from: []string{"thing happy", `upset as "sad"`, "fun", "thing as stuff", `"angry" as mad`},
where: []where{
{clause: "(id=? and thing=?) or stuff=?", args: []interface{}{}},
},
limit: 5,
}, nil},
{&Query{
from: []string{"thing happy", `"fun"`, `stuff`},
update: map[string]interface{}{
"col1": 1,
`"col2"`: 2,
`"fun".col3`: 3,
},
where: []where{
{clause: "aa=? or bb=? or cc=?", orSeparator: true, args: []interface{}{4, 5, 6}},
{clause: "dd=? or ee=? or ff=? and gg=?", args: []interface{}{7, 8, 9, 10}},
},
limit: 5,
}, []interface{}{2, 3, 1, 4, 5, 6, 7, 8, 9, 10}},
{&Query{from: []string{"cats"}, joins: []join{{JoinInner, "dogs d on d.cat_id = cats.id", nil}}}, nil},
{&Query{from: []string{"cats c"}, joins: []join{{JoinInner, "dogs d on d.cat_id = cats.id", nil}}}, nil},
{&Query{from: []string{"cats as c"}, joins: []join{{JoinInner, "dogs d on d.cat_id = cats.id", nil}}}, nil},
{&Query{from: []string{"cats as c", "dogs as d"}, joins: []join{{JoinInner, "dogs d on d.cat_id = cats.id", nil}}}, nil},
}
for i, test := range tests {
filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i))
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
out, args := buildQuery(test.q)
if *writeGoldenFiles {
err := ioutil.WriteFile(filename, []byte(out), 0664)
if err != nil {
t.Fatalf("Failed to write golden file %s: %s\n", filename, err)
}
t.Logf("wrote golden file: %s\n", filename)
continue
}
byt, err := ioutil.ReadFile(filename)
if err != nil {
t.Fatalf("Failed to read golden file %q: %v", filename, err)
}
if string(bytes.TrimSpace(byt)) != out {
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, byt, out)
}
if !reflect.DeepEqual(args, test.args) {
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, spew.Sdump(test.args), spew.Sdump(args))
}
}
}
func TestWriteStars(t *testing.T) {
t.Parallel()
tests := []struct {
In Query
Out []string
}{
{
In: Query{from: []string{`a`}},
Out: []string{`"a".*`},
},
{
In: Query{from: []string{`a as b`}},
Out: []string{`"b".*`},
},
{
In: Query{from: []string{`a as b`, `c`}},
Out: []string{`"b".*`, `"c".*`},
},
{
In: Query{from: []string{`a as b`, `c as d`}},
Out: []string{`"b".*`, `"d".*`},
},
}
for i, test := range tests {
test.In.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
selects := writeStars(&test.In)
if !reflect.DeepEqual(selects, test.Out) {
t.Errorf("writeStar test fail %d\nwant: %v\ngot: %v", i, test.Out, selects)
}
}
}
func TestWhereClause(t *testing.T) {
t.Parallel()
tests := []struct {
q Query
expect string
}{
// Or("a=?")
{
q: Query{
where: []where{{clause: "a=?", orSeparator: true}},
},
expect: " WHERE (a=$1)",
},
// Where("a=?")
{
q: Query{
where: []where{{clause: "a=?"}},
},
expect: " WHERE (a=$1)",
},
// Where("(a=?)")
{
q: Query{
where: []where{{clause: "(a=?)"}},
},
expect: " WHERE ((a=$1))",
},
// Where("((a=? OR b=?))")
{
q: Query{
where: []where{{clause: "((a=? OR b=?))"}},
},
expect: " WHERE (((a=$1 OR b=$2)))",
},
// Where("(a=?)", Or("(b=?)")
{
q: Query{
where: []where{
{clause: "(a=?)"},
{clause: "(b=?)", orSeparator: true},
},
},
expect: " WHERE ((a=$1)) OR ((b=$2))",
},
// Where("a=? OR b=?")
{
q: Query{
where: []where{{clause: "a=? OR b=?"}},
},
expect: " WHERE (a=$1 OR b=$2)",
},
// Where("a=?"), Where("b=?")
{
q: Query{
where: []where{{clause: "a=?"}, {clause: "b=?"}},
},
expect: " WHERE (a=$1) AND (b=$2)",
},
// Where("(a=? AND b=?) OR c=?")
{
q: Query{
where: []where{{clause: "(a=? AND b=?) OR c=?"}},
},
expect: " WHERE ((a=$1 AND b=$2) OR c=$3)",
},
// Where("a=? OR b=?"), Where("c=? OR d=? OR e=?")
{
q: Query{
where: []where{
{clause: "(a=? OR b=?)"},
{clause: "(c=? OR d=? OR e=?)"},
},
},
expect: " WHERE ((a=$1 OR b=$2)) AND ((c=$3 OR d=$4 OR e=$5))",
},
// Where("(a=? AND b=?) OR (c=? AND d=? AND e=?) OR f=? OR f=?")
{
q: Query{
where: []where{
{clause: "(a=? AND b=?) OR (c=? AND d=? AND e=?) OR f=? OR g=?"},
},
},
expect: " WHERE ((a=$1 AND b=$2) OR (c=$3 AND d=$4 AND e=$5) OR f=$6 OR g=$7)",
},
// Where("(a=? AND b=?) OR (c=? AND d=? OR e=?) OR f=? OR g=?")
{
q: Query{
where: []where{
{clause: "(a=? AND b=?) OR (c=? AND d=? OR e=?) OR f=? OR g=?"},
},
},
expect: " WHERE ((a=$1 AND b=$2) OR (c=$3 AND d=$4 OR e=$5) OR f=$6 OR g=$7)",
},
// Where("a=? or b=?"), Or("c=? and d=?"), Or("e=? or f=?")
{
q: Query{
where: []where{
{clause: "a=? or b=?", orSeparator: true},
{clause: "c=? and d=?", orSeparator: true},
{clause: "e=? or f=?", orSeparator: true},
},
},
expect: " WHERE (a=$1 or b=$2) OR (c=$3 and d=$4) OR (e=$5 or f=$6)",
},
// Where("a=? or b=?"), Or("c=? and d=?"), Or("e=? or f=?")
{
q: Query{
where: []where{
{clause: "a=? or b=?"},
{clause: "c=? and d=?", orSeparator: true},
{clause: "e=? or f=?"},
},
},
expect: " WHERE (a=$1 or b=$2) OR (c=$3 and d=$4) AND (e=$5 or f=$6)",
},
}
for i, test := range tests {
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
result, _ := whereClause(&test.q, 1)
if result != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result)
}
}
}
func TestInClause(t *testing.T) {
t.Parallel()
tests := []struct {
q Query
expect string
args []interface{}
}{
{
q: Query{
in: []in{{clause: "a in ?", args: []interface{}{}, orSeparator: true}},
},
expect: ` WHERE "a" IN ()`,
},
{
q: Query{
in: []in{{clause: "a in ?", args: []interface{}{1}, orSeparator: true}},
},
expect: ` WHERE "a" IN ($1)`,
args: []interface{}{1},
},
{
q: Query{
in: []in{{clause: "a in ?", args: []interface{}{1, 2, 3}}},
},
expect: ` WHERE "a" IN ($1,$2,$3)`,
args: []interface{}{1, 2, 3},
},
{
q: Query{
in: []in{{clause: "? in ?", args: []interface{}{1, 2, 3}}},
},
expect: " WHERE $1 IN ($2,$3)",
args: []interface{}{1, 2, 3},
},
{
q: Query{
in: []in{{clause: "( ? , ? ) in ( ? )", orSeparator: true, args: []interface{}{"a", "b", 1, 2, 3, 4}}},
},
expect: " WHERE ( $1 , $2 ) IN ( (($3,$4),($5,$6)) )",
args: []interface{}{"a", "b", 1, 2, 3, 4},
},
{
q: Query{
in: []in{{clause: `("a")in(?)`, orSeparator: true, args: []interface{}{1, 2, 3}}},
},
expect: ` WHERE ("a") IN (($1,$2,$3))`,
args: []interface{}{1, 2, 3},
},
{
q: Query{
in: []in{{clause: `("a")in?`, args: []interface{}{1}}},
},
expect: ` WHERE ("a") IN ($1)`,
args: []interface{}{1},
},
{
q: Query{
where: []where{
{clause: "a=?", args: []interface{}{1}},
},
in: []in{
{clause: `?,?,"name" in ?`, orSeparator: true, args: []interface{}{"c", "d", 3, 4, 5, 6, 7, 8}},
{clause: `?,?,"name" in ?`, orSeparator: true, args: []interface{}{"e", "f", 9, 10, 11, 12, 13, 14}},
},
},
expect: ` OR $1,$2,"name" IN (($3,$4,$5),($6,$7,$8)) OR $9,$10,"name" IN (($11,$12,$13),($14,$15,$16))`,
args: []interface{}{"c", "d", 3, 4, 5, 6, 7, 8, "e", "f", 9, 10, 11, 12, 13, 14},
},
{
q: Query{
in: []in{
{clause: `("a")in`, args: []interface{}{1}},
{clause: `("a")in?`, orSeparator: true, args: []interface{}{1}},
},
},
expect: ` WHERE ("a")in OR ("a") IN ($1)`,
args: []interface{}{1, 1},
},
{
q: Query{
in: []in{
{clause: `\?,\? in \?`, args: []interface{}{1}},
{clause: `\?,\?in \?`, orSeparator: true, args: []interface{}{1}},
},
},
expect: ` WHERE ?,? IN ? OR ?,? IN ?`,
args: []interface{}{1, 1},
},
{
q: Query{
in: []in{
{clause: `("a")in`, args: []interface{}{1}},
{clause: `("a") in thing`, args: []interface{}{1, 2, 3}},
{clause: `("a")in?`, orSeparator: true, args: []interface{}{4, 5, 6}},
},
},
expect: ` WHERE ("a")in AND ("a") IN thing OR ("a") IN ($1,$2,$3)`,
args: []interface{}{1, 1, 2, 3, 4, 5, 6},
},
{
q: Query{
in: []in{
{clause: `("a")in?`, orSeparator: true, args: []interface{}{4, 5, 6}},
{clause: `("a") in thing`, args: []interface{}{1, 2, 3}},
{clause: `("a")in`, args: []interface{}{1}},
},
},
expect: ` WHERE ("a") IN ($1,$2,$3) AND ("a") IN thing AND ("a")in`,
args: []interface{}{4, 5, 6, 1, 2, 3, 1},
},
{
q: Query{
in: []in{
{clause: `("a")in?`, orSeparator: true, args: []interface{}{4, 5, 6}},
{clause: `("a")in`, args: []interface{}{1}},
{clause: `("a") in thing`, args: []interface{}{1, 2, 3}},
},
},
expect: ` WHERE ("a") IN ($1,$2,$3) AND ("a")in AND ("a") IN thing`,
args: []interface{}{4, 5, 6, 1, 1, 2, 3},
},
}
for i, test := range tests {
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
result, args := inClause(&test.q, 1)
if result != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result)
}
if !reflect.DeepEqual(args, test.args) {
t.Errorf("%d) Mismatch between expected args:\n%#v\n%#v\n", i, test.args, args)
}
}
}
func TestConvertQuestionMarks(t *testing.T) {
t.Parallel()
tests := []struct {
clause string
start int
expect string
count int
}{
{clause: "hello friend", start: 1, expect: "hello friend", count: 0},
{clause: "thing=?", start: 2, expect: "thing=$2", count: 1},
{clause: "thing=? and stuff=? and happy=?", start: 2, expect: "thing=$2 and stuff=$3 and happy=$4", count: 3},
{clause: `thing \? stuff`, start: 2, expect: `thing ? stuff`, count: 0},
{clause: `thing \? stuff and happy \? fun`, start: 2, expect: `thing ? stuff and happy ? fun`, count: 0},
{
clause: `thing \? stuff ? happy \? and mad ? fun \? \? \?`,
start: 2,
expect: `thing ? stuff $2 happy ? and mad $3 fun ? ? ?`,
count: 2,
},
{
clause: `thing ? stuff ? happy \? fun \? ? ?`,
start: 1,
expect: `thing $1 stuff $2 happy ? fun ? $3 $4`,
count: 4,
},
{clause: `?`, start: 1, expect: `$1`, count: 1},
{clause: `???`, start: 1, expect: `$1$2$3`, count: 3},
{clause: `\?`, start: 1, expect: `?`},
{clause: `\?\?\?`, start: 1, expect: `???`},
{clause: `\??\??\??`, start: 1, expect: `?$1?$2?$3`, count: 3},
{clause: `?\??\??\?`, start: 1, expect: `$1?$2?$3?`, count: 3},
}
for i, test := range tests {
res, count := convertQuestionMarks(test.clause, test.start)
if res != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res)
}
if count != test.count {
t.Errorf("%d) Expected count %d, got %d", i, test.count, count)
}
}
}
func TestConvertInQuestionMarks(t *testing.T) {
t.Parallel()
tests := []struct {
clause string
start int
group int
total int
expect string
}{
{clause: "?", expect: "(($1,$2,$3),($4,$5,$6),($7,$8,$9))", start: 1, total: 9, group: 3},
{clause: "?", expect: "(($2,$3),($4))", start: 2, total: 3, group: 2},
{clause: "hello friend", start: 1, expect: "hello friend", total: 0, group: 1},
{clause: "thing ? thing", start: 2, expect: "thing ($2,$3) thing", total: 2, group: 1},
{clause: "thing?thing", start: 2, expect: "thing($2)thing", total: 1, group: 1},
{clause: `thing \? stuff`, start: 2, expect: `thing ? stuff`, total: 0, group: 1},
{clause: `thing \? stuff and happy \? fun`, start: 2, expect: `thing ? stuff and happy ? fun`, total: 0, group: 1},
{clause: "thing ? thing ? thing", start: 1, expect: "thing ($1,$2,$3) thing ? thing", total: 3, group: 1},
{clause: `?`, start: 1, expect: `($1)`, total: 1, group: 1},
{clause: `???`, start: 1, expect: `($1,$2,$3)??`, total: 3, group: 1},
{clause: `\?`, start: 1, expect: `?`, total: 0, group: 1},
{clause: `\?\?\?`, start: 1, expect: `???`, total: 0, group: 1},
{clause: `\??\??\??`, start: 1, expect: `?($1,$2,$3)????`, total: 3, group: 1},
{clause: `?\??\??\?`, start: 1, expect: `($1,$2,$3)?????`, total: 3, group: 1},
}
for i, test := range tests {
res, count := convertInQuestionMarks(true, test.clause, test.start, test.group, test.total)
if res != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res)
}
if count != test.total {
t.Errorf("%d) Expected %d, got %d", i, test.total, count)
}
}
res, count := convertInQuestionMarks(false, "?", 1, 3, 9)
if res != "((?,?,?),(?,?,?),(?,?,?))" {
t.Errorf("Mismatch between expected and result: %s", res)
}
if count != 9 {
t.Errorf("Expected 9 results, got %d", count)
}
}
func TestWriteAsStatements(t *testing.T) {
t.Parallel()
query := Query{
selectCols: []string{
`a`,
`a.fun`,
`"b"."fun"`,
`"b".fun`,
`b."fun"`,
`a.clown.run`,
`COUNT(a)`,
},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
expect := []string{
`"a"`,
`"a"."fun" as "a.fun"`,
`"b"."fun" as "b.fun"`,
`"b"."fun" as "b.fun"`,
`"b"."fun" as "b.fun"`,
`"a"."clown"."run" as "a.clown.run"`,
`COUNT(a)`,
}
gots := writeAsStatements(&query)
for i, got := range gots {
if expect[i] != got {
t.Errorf(`%d) want: %s, got: %s`, i, expect[i], got)
}
}
}

438
queries/query_test.go Normal file
View file

@ -0,0 +1,438 @@
package queries
import (
"database/sql"
"reflect"
"testing"
)
func TestSetLimit(t *testing.T) {
t.Parallel()
q := &Query{}
SetLimit(q, 10)
expect := 10
if q.limit != expect {
t.Errorf("Expected %d, got %d", expect, q.limit)
}
}
func TestSetOffset(t *testing.T) {
t.Parallel()
q := &Query{}
SetOffset(q, 10)
expect := 10
if q.offset != expect {
t.Errorf("Expected %d, got %d", expect, q.offset)
}
}
func TestSetSQL(t *testing.T) {
t.Parallel()
q := &Query{}
SetSQL(q, "select * from thing", 5, 3)
if len(q.plainSQL.args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.plainSQL.args))
}
if q.plainSQL.sql != "select * from thing" {
t.Errorf("Was not expected string, got %s", q.plainSQL.sql)
}
}
func TestSetLoad(t *testing.T) {
t.Parallel()
q := &Query{}
SetLoad(q, "one", "two")
if len(q.load) != 2 {
t.Errorf("Expected len 2, got %d", len(q.load))
}
if q.load[0] != "one" || q.load[1] != "two" {
t.Errorf("Was not expected string, got %s", q.load)
}
}
func TestAppendWhere(t *testing.T) {
t.Parallel()
q := &Query{}
expect := "x > $1 AND y > $2"
AppendWhere(q, expect, 5, 3)
AppendWhere(q, expect, 5, 3)
if len(q.where) != 2 {
t.Errorf("%#v", q.where)
}
if q.where[0].clause != expect || q.where[1].clause != expect {
t.Errorf("Expected %s, got %#v", expect, q.where)
}
if len(q.where[0].args) != 2 || len(q.where[0].args) != 2 {
t.Errorf("arg length wrong: %#v", q.where)
}
if q.where[0].args[0].(int) != 5 || q.where[0].args[1].(int) != 3 {
t.Errorf("args wrong: %#v", q.where)
}
q.where = []where{{clause: expect, args: []interface{}{5, 3}}}
if q.where[0].clause != expect {
t.Errorf("Expected %s, got %v", expect, q.where)
}
if len(q.where[0].args) != 2 {
t.Errorf("Expected %d args, got %d", 2, len(q.where[0].args))
}
if q.where[0].args[0].(int) != 5 || q.where[0].args[1].(int) != 3 {
t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.where[0].args)
}
if len(q.where) != 1 {
t.Errorf("%#v", q.where)
}
}
func TestSetLastWhereAsOr(t *testing.T) {
t.Parallel()
q := &Query{}
AppendWhere(q, "")
if q.where[0].orSeparator {
t.Errorf("Do not want or separator")
}
SetLastWhereAsOr(q)
if len(q.where) != 1 {
t.Errorf("Want len 1")
}
if !q.where[0].orSeparator {
t.Errorf("Want or separator")
}
AppendWhere(q, "")
SetLastWhereAsOr(q)
if len(q.where) != 2 {
t.Errorf("Want len 2")
}
if q.where[0].orSeparator != true {
t.Errorf("Expected true")
}
if q.where[1].orSeparator != true {
t.Errorf("Expected true")
}
}
func TestAppendIn(t *testing.T) {
t.Parallel()
q := &Query{}
expect := "col IN ?"
AppendIn(q, expect, 5, 3)
AppendIn(q, expect, 5, 3)
if len(q.in) != 2 {
t.Errorf("%#v", q.in)
}
if q.in[0].clause != expect || q.in[1].clause != expect {
t.Errorf("Expected %s, got %#v", expect, q.in)
}
if len(q.in[0].args) != 2 || len(q.in[0].args) != 2 {
t.Errorf("arg length wrong: %#v", q.in)
}
if q.in[0].args[0].(int) != 5 || q.in[0].args[1].(int) != 3 {
t.Errorf("args wrong: %#v", q.in)
}
q.in = []in{{clause: expect, args: []interface{}{5, 3}}}
if q.in[0].clause != expect {
t.Errorf("Expected %s, got %v", expect, q.in)
}
if len(q.in[0].args) != 2 {
t.Errorf("Expected %d args, got %d", 2, len(q.in[0].args))
}
if q.in[0].args[0].(int) != 5 || q.in[0].args[1].(int) != 3 {
t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.in[0].args)
}
if len(q.in) != 1 {
t.Errorf("%#v", q.in)
}
}
func TestSetLastInAsOr(t *testing.T) {
t.Parallel()
q := &Query{}
AppendIn(q, "")
if q.in[0].orSeparator {
t.Errorf("Do not want or separator")
}
SetLastInAsOr(q)
if len(q.in) != 1 {
t.Errorf("Want len 1")
}
if !q.in[0].orSeparator {
t.Errorf("Want or separator")
}
AppendIn(q, "")
SetLastInAsOr(q)
if len(q.in) != 2 {
t.Errorf("Want len 2")
}
if q.in[0].orSeparator != true {
t.Errorf("Expected true")
}
if q.in[1].orSeparator != true {
t.Errorf("Expected true")
}
}
func TestAppendGroupBy(t *testing.T) {
t.Parallel()
q := &Query{}
expect := "col1, col2"
AppendGroupBy(q, expect)
AppendGroupBy(q, expect)
if len(q.groupBy) != 2 && (q.groupBy[0] != expect || q.groupBy[1] != expect) {
t.Errorf("Expected %s, got %s %s", expect, q.groupBy[0], q.groupBy[1])
}
q.groupBy = []string{expect}
if len(q.groupBy) != 1 && q.groupBy[0] != expect {
t.Errorf("Expected %s, got %s", expect, q.groupBy[0])
}
}
func TestAppendOrderBy(t *testing.T) {
t.Parallel()
q := &Query{}
expect := "col1 desc, col2 asc"
AppendOrderBy(q, expect)
AppendOrderBy(q, expect)
if len(q.orderBy) != 2 && (q.orderBy[0] != expect || q.orderBy[1] != expect) {
t.Errorf("Expected %s, got %s %s", expect, q.orderBy[0], q.orderBy[1])
}
q.orderBy = []string{"col1 desc, col2 asc"}
if len(q.orderBy) != 1 && q.orderBy[0] != expect {
t.Errorf("Expected %s, got %s", expect, q.orderBy[0])
}
}
func TestAppendHaving(t *testing.T) {
t.Parallel()
q := &Query{}
expect := "count(orders.order_id) > ?"
AppendHaving(q, expect, 10)
AppendHaving(q, expect, 10)
if len(q.having) != 2 {
t.Errorf("Expected 2, got %d", len(q.having))
}
if q.having[0].clause != expect || q.having[1].clause != expect {
t.Errorf("Expected %s, got %s %s", expect, q.having[0].clause, q.having[1].clause)
}
if q.having[0].args[0] != 10 || q.having[1].args[0] != 10 {
t.Errorf("Expected %v, got %v %v", 10, q.having[0].args[0], q.having[1].args[0])
}
q.having = []having{{clause: expect, args: []interface{}{10}}}
if len(q.having) != 1 && (q.having[0].clause != expect || q.having[0].args[0] != 10) {
t.Errorf("Expected %s, got %s %v", expect, q.having[0], q.having[0].args[0])
}
}
func TestFrom(t *testing.T) {
t.Parallel()
q := &Query{}
AppendFrom(q, "videos a", "orders b")
AppendFrom(q, "videos a", "orders b")
expect := []string{"videos a", "orders b", "videos a", "orders b"}
if !reflect.DeepEqual(q.from, expect) {
t.Errorf("Expected %s, got %s", expect, q.from)
}
SetFrom(q, "videos a", "orders b")
if !reflect.DeepEqual(q.from, expect[:2]) {
t.Errorf("Expected %s, got %s", expect, q.from)
}
}
func TestSetSelect(t *testing.T) {
t.Parallel()
q := &Query{selectCols: []string{"hello"}}
SetSelect(q, nil)
if q.selectCols != nil {
t.Errorf("want nil")
}
}
func TestSetCount(t *testing.T) {
t.Parallel()
q := &Query{}
SetCount(q)
if q.count != true {
t.Errorf("got false")
}
}
func TestSetUpdate(t *testing.T) {
t.Parallel()
q := &Query{}
SetUpdate(q, map[string]interface{}{"test": 5})
if q.update["test"] != 5 {
t.Errorf("Wrong update, got %v", q.update)
}
}
func TestSetDelete(t *testing.T) {
t.Parallel()
q := &Query{}
SetDelete(q)
if q.delete != true {
t.Errorf("Expected %t, got %t", true, q.delete)
}
}
func TestSetExecutor(t *testing.T) {
t.Parallel()
q := &Query{}
d := &sql.DB{}
SetExecutor(q, d)
if q.executor != d {
t.Errorf("Expected executor to get set to d, but was: %#v", q.executor)
}
}
func TestAppendSelect(t *testing.T) {
t.Parallel()
q := &Query{}
AppendSelect(q, "col1", "col2")
AppendSelect(q, "col1", "col2")
if len(q.selectCols) != 4 {
t.Errorf("Expected selectCols len 4, got %d", len(q.selectCols))
}
if q.selectCols[0] != `col1` && q.selectCols[1] != `col2` {
t.Errorf("select cols value mismatch: %#v", q.selectCols)
}
if q.selectCols[2] != `col1` && q.selectCols[3] != `col2` {
t.Errorf("select cols value mismatch: %#v", q.selectCols)
}
q.selectCols = []string{"col1", "col2"}
if q.selectCols[0] != `col1` && q.selectCols[1] != `col2` {
t.Errorf("select cols value mismatch: %#v", q.selectCols)
}
}
func TestSQL(t *testing.T) {
t.Parallel()
q := SQL(&sql.DB{}, "thing", 5)
if q.plainSQL.sql != "thing" {
t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql)
}
if q.plainSQL.args[0].(int) != 5 {
t.Errorf("Expected 5, got %v", q.plainSQL.args[0])
}
}
func TestSQLG(t *testing.T) {
t.Parallel()
q := SQLG("thing", 5)
if q.plainSQL.sql != "thing" {
t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql)
}
if q.plainSQL.args[0].(int) != 5 {
t.Errorf("Expected 5, got %v", q.plainSQL.args[0])
}
}
func TestAppendInnerJoin(t *testing.T) {
t.Parallel()
q := &Query{}
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
if len(q.joins) != 2 {
t.Errorf("Expected len 1, got %d", len(q.joins))
}
if q.joins[0].clause != "thing=$1 AND stuff=$2" {
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
}
if q.joins[1].clause != "thing=$1 AND stuff=$2" {
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
}
if len(q.joins[0].args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.joins[0].args))
}
if len(q.joins[1].args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.joins[1].args))
}
if q.joins[0].args[0] != 2 && q.joins[0].args[1] != 5 {
t.Errorf("Invalid args values, got %#v", q.joins[0].args)
}
q.joins = []join{{kind: JoinInner,
clause: "thing=$1 AND stuff=$2",
args: []interface{}{2, 5},
}}
if len(q.joins) != 1 {
t.Errorf("Expected len 1, got %d", len(q.joins))
}
if q.joins[0].clause != "thing=$1 AND stuff=$2" {
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
}
}

480
queries/reflect.go Normal file
View file

@ -0,0 +1,480 @@
package queries
import (
"database/sql"
"fmt"
"reflect"
"strings"
"sync"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/boil"
"github.com/vattle/sqlboiler/strmangle"
)
var (
bindAccepts = []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct}
mut sync.RWMutex
bindingMaps = make(map[string][]uint64)
structMaps = make(map[string]map[string]uint64)
)
// Identifies what kind of object we're binding to
type bindKind int
const (
kindStruct bindKind = iota
kindSliceStruct
kindPtrSliceStruct
)
const (
loadMethodPrefix = "Load"
relationshipStructName = "R"
loaderStructName = "L"
sentinel = uint64(255)
)
// BindP executes the query and inserts the
// result into the passed in object pointer.
// It panics on error. See boil.Bind() documentation.
func (q *Query) BindP(obj interface{}) {
if err := q.Bind(obj); err != nil {
panic(boil.WrapErr(err))
}
}
// Bind executes the query and inserts the
// result into the passed in object pointer
//
// Bind rules:
// - Struct tags control bind, in the form of: `boil:"name,bind"`
// - If "name" is omitted the sql column names that come back are TitleCased
// and matched against the field name.
// - If the "name" part of the struct tag is specified, the given name will
// be used instead of the struct field name for binding.
// - If the "name" of the struct tag is "-", this field will not be bound to.
// - If the ",bind" option is specified on a struct field and that field
// is a struct itself, it will be recursed into to look for fields for binding.
//
// Example Query:
//
// type JoinStruct struct {
// // User1 can have it's struct fields bound to since it specifies
// // ,bind in the struct tag, it will look specifically for
// // fields that are prefixed with "user." returning from the query.
// // For example "user.id" column name will bind to User1.ID
// User1 *models.User `boil:"user,bind"`
// // User2 will follow the same rules as noted above except it will use
// // "friend." as the prefix it's looking for.
// User2 *models.User `boil:"friend,bind"`
// // RandomData will not be recursed into to look for fields to
// // bind and will not be bound to because of the - for the name.
// RandomData myStruct `boil:"-"`
// // Date will not be recursed into to look for fields to bind because
// // it does not specify ,bind in the struct tag. But it can be bound to
// // as it does not specify a - for the name.
// Date time.Time
// }
//
// models.Users(qm.InnerJoin("users as friend on users.friend_id = friend.id")).Bind(&joinStruct)
//
// For custom objects that want to use eager loading, please see the
// loadRelationships function.
func Bind(rows *sql.Rows, obj interface{}) error {
structType, sliceType, singular, err := bindChecks(obj)
if err != nil {
return err
}
return bind(rows, obj, structType, sliceType, singular)
}
// Bind executes the query and inserts the
// result into the passed in object pointer
//
// See documentation for boil.Bind()
func (q *Query) Bind(obj interface{}) error {
structType, sliceType, bkind, err := bindChecks(obj)
if err != nil {
return err
}
rows, err := q.Query()
if err != nil {
return errors.Wrap(err, "bind failed to execute query")
}
defer rows.Close()
if res := bind(rows, obj, structType, sliceType, bkind); res != nil {
return res
}
if len(q.load) == 0 {
return nil
}
state := loadRelationshipState{
exec: q.executor,
loaded: map[string]struct{}{},
}
for _, toLoad := range q.load {
state.toLoad = strings.Split(toLoad, ".")
if err = state.loadRelationships(0, obj, bkind); err != nil {
return err
}
}
return nil
}
// bindChecks resolves information about the bind target, and errors if it's not an object
// we can bind to.
func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, bkind bindKind, err error) {
typ := reflect.TypeOf(obj)
kind := typ.Kind()
setErr := func() {
err = errors.Errorf("obj type should be *Type, *[]Type, or *[]*Type but was %q", reflect.TypeOf(obj).String())
}
for i := 0; ; i++ {
switch i {
case 0:
if kind != reflect.Ptr {
setErr()
return
}
case 1:
switch kind {
case reflect.Struct:
structType = typ
bkind = kindStruct
return
case reflect.Slice:
sliceType = typ
default:
setErr()
return
}
case 2:
switch kind {
case reflect.Struct:
structType = typ
bkind = kindSliceStruct
return
case reflect.Ptr:
default:
setErr()
return
}
case 3:
if kind != reflect.Struct {
setErr()
return
}
structType = typ
bkind = kindPtrSliceStruct
return
}
typ = typ.Elem()
kind = typ.Kind()
}
}
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, bkind bindKind) error {
cols, err := rows.Columns()
if err != nil {
return errors.Wrap(err, "bind failed to get column names")
}
var ptrSlice reflect.Value
switch bkind {
case kindSliceStruct, kindPtrSliceStruct:
ptrSlice = reflect.Indirect(reflect.ValueOf(obj))
}
var strMapping map[string]uint64
var sok bool
var mapping []uint64
var ok bool
typStr := structType.String()
mapKey := makeCacheKey(typStr, cols)
mut.RLock()
mapping, ok = bindingMaps[mapKey]
if !ok {
if strMapping, sok = structMaps[typStr]; !sok {
strMapping = MakeStructMapping(structType)
}
}
mut.RUnlock()
if !ok {
mapping, err = BindMapping(structType, strMapping, cols)
if err != nil {
return err
}
mut.Lock()
if !sok {
structMaps[typStr] = strMapping
}
bindingMaps[mapKey] = mapping
mut.Unlock()
}
var oneStruct reflect.Value
if bkind == kindSliceStruct {
oneStruct = reflect.Indirect(reflect.New(structType))
}
foundOne := false
for rows.Next() {
foundOne = true
var newStruct reflect.Value
var pointers []interface{}
switch bkind {
case kindStruct:
pointers = PtrsFromMapping(reflect.Indirect(reflect.ValueOf(obj)), mapping)
case kindSliceStruct:
pointers = PtrsFromMapping(oneStruct, mapping)
case kindPtrSliceStruct:
newStruct = reflect.New(structType)
pointers = PtrsFromMapping(reflect.Indirect(newStruct), mapping)
}
if err != nil {
return err
}
if err := rows.Scan(pointers...); err != nil {
return errors.Wrap(err, "failed to bind pointers to obj")
}
switch bkind {
case kindSliceStruct:
ptrSlice.Set(reflect.Append(ptrSlice, oneStruct))
case kindPtrSliceStruct:
ptrSlice.Set(reflect.Append(ptrSlice, newStruct))
}
}
if bkind == kindStruct && !foundOne {
return sql.ErrNoRows
}
return nil
}
// BindMapping creates a mapping that helps look up the pointer for the
// column given.
func BindMapping(typ reflect.Type, mapping map[string]uint64, cols []string) ([]uint64, error) {
ptrs := make([]uint64, len(cols))
ColLoop:
for i, c := range cols {
name := strmangle.TitleCaseIdentifier(c)
ptrMap, ok := mapping[name]
if ok {
ptrs[i] = ptrMap
continue
}
suffix := "." + name
for maybeMatch, mapping := range mapping {
if strings.HasSuffix(maybeMatch, suffix) {
ptrs[i] = mapping
continue ColLoop
}
}
return nil, errors.Errorf("could not find struct field name in mapping: %s", name)
}
return ptrs, nil
}
// PtrsFromMapping expects to be passed an addressable struct and a mapping
// of where to find things. It pulls the pointers out referred to by the mapping.
func PtrsFromMapping(val reflect.Value, mapping []uint64) []interface{} {
ptrs := make([]interface{}, len(mapping))
for i, m := range mapping {
ptrs[i] = ptrFromMapping(val, m, true).Interface()
}
return ptrs
}
// ValuesFromMapping expects to be passed an addressable struct and a mapping
// of where to find things. It pulls the pointers out referred to by the mapping.
func ValuesFromMapping(val reflect.Value, mapping []uint64) []interface{} {
ptrs := make([]interface{}, len(mapping))
for i, m := range mapping {
ptrs[i] = ptrFromMapping(val, m, false).Interface()
}
return ptrs
}
// ptrFromMapping expects to be passed an addressable struct that it's looking
// for things on.
func ptrFromMapping(val reflect.Value, mapping uint64, addressOf bool) reflect.Value {
for i := 0; i < 8; i++ {
v := (mapping >> uint(i*8)) & sentinel
if v == sentinel {
if addressOf && val.Kind() != reflect.Ptr {
return val.Addr()
} else if !addressOf && val.Kind() == reflect.Ptr {
return reflect.Indirect(val)
}
return val
}
val = val.Field(int(v))
if val.Kind() == reflect.Ptr {
val = reflect.Indirect(val)
}
}
panic("could not find pointer from mapping")
}
// MakeStructMapping creates a map of the struct to be able to quickly look
// up its pointers and values by name.
func MakeStructMapping(typ reflect.Type) map[string]uint64 {
fieldMaps := make(map[string]uint64)
makeStructMappingHelper(typ, "", 0, 0, fieldMaps)
return fieldMaps
}
func makeStructMappingHelper(typ reflect.Type, prefix string, current uint64, depth uint, fieldMaps map[string]uint64) {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
n := typ.NumField()
for i := 0; i < n; i++ {
f := typ.Field(i)
tag, recurse := getBoilTag(f)
if len(tag) == 0 {
tag = f.Name
} else if tag[0] == '-' {
continue
}
if len(prefix) != 0 {
tag = fmt.Sprintf("%s.%s", prefix, tag)
}
if recurse {
makeStructMappingHelper(f.Type, tag, current|uint64(i)<<depth, depth+8, fieldMaps)
continue
}
fieldMaps[tag] = current | (sentinel << (depth + 8)) | (uint64(i) << depth)
}
}
func getBoilTag(field reflect.StructField) (name string, recurse bool) {
tag := field.Tag.Get("boil")
name = field.Name
if len(tag) == 0 {
return name, false
}
ind := strings.IndexByte(tag, ',')
if ind == -1 {
return strmangle.TitleCase(tag), false
} else if ind == 0 {
return name, true
}
nameFragment := tag[:ind]
return strmangle.TitleCase(nameFragment), true
}
func makeCacheKey(typ string, cols []string) string {
buf := strmangle.GetBuffer()
buf.WriteString(typ)
for _, s := range cols {
buf.WriteString(s)
}
mapKey := buf.String()
strmangle.PutBuffer(buf)
return mapKey
}
// GetStructValues returns the values (as interface) of the matching columns in obj
func GetStructValues(obj interface{}, columns ...string) []interface{} {
ret := make([]interface{}, len(columns))
val := reflect.Indirect(reflect.ValueOf(obj))
for i, c := range columns {
fieldName := strmangle.TitleCase(c)
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
}
ret[i] = field.Interface()
}
return ret
}
// GetSliceValues returns the values (as interface) of the matching columns in obj.
func GetSliceValues(slice []interface{}, columns ...string) []interface{} {
ret := make([]interface{}, len(slice)*len(columns))
for i, obj := range slice {
val := reflect.Indirect(reflect.ValueOf(obj))
for j, c := range columns {
fieldName := strmangle.TitleCase(c)
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
}
ret[i*len(columns)+j] = field.Interface()
}
}
return ret
}
// GetStructPointers returns a slice of pointers to the matching columns in obj
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
val := reflect.ValueOf(obj).Elem()
var ln int
var getField func(reflect.Value, int) reflect.Value
if len(columns) == 0 {
ln = val.NumField()
getField = func(v reflect.Value, i int) reflect.Value {
return v.Field(i)
}
} else {
ln = len(columns)
getField = func(v reflect.Value, i int) reflect.Value {
return v.FieldByName(strmangle.TitleCase(columns[i]))
}
}
ret := make([]interface{}, ln)
for i := 0; i < ln; i++ {
field := getField(val, i)
if !field.IsValid() {
// Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually
// produce an Invalid value, so we make a hopefully safe assumption here.
panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i])))
}
ret[i] = field.Addr().Interface()
}
return ret
}

707
queries/reflect_test.go Normal file
View file

@ -0,0 +1,707 @@
package queries
import (
"database/sql/driver"
"reflect"
"strconv"
"strings"
"testing"
"time"
"gopkg.in/DATA-DOG/go-sqlmock.v1"
"gopkg.in/nullbio/null.v5"
)
func bin64(i uint64) string {
str := strconv.FormatUint(i, 2)
pad := 64 - len(str)
if pad > 0 {
str = strings.Repeat("0", pad) + str
}
var newStr string
for i := 0; i < len(str); i += 8 {
if i != 0 {
newStr += " "
}
newStr += str[i : i+8]
}
return newStr
}
type mockRowMaker struct {
int
rows []driver.Value
}
func TestBindStruct(t *testing.T) {
t.Parallel()
testResults := struct {
ID int
Name string `boil:"test"`
}{}
query := &Query{
from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
if err != nil {
t.Error(err)
}
ret := sqlmock.NewRows([]string{"id", "test"})
ret.AddRow(driver.Value(int64(35)), driver.Value("pat"))
mock.ExpectQuery(`SELECT \* FROM "fun";`).WillReturnRows(ret)
SetExecutor(query, db)
err = query.Bind(&testResults)
if err != nil {
t.Error(err)
}
if id := testResults.ID; id != 35 {
t.Error("wrong ID:", id)
}
if name := testResults.Name; name != "pat" {
t.Error("wrong name:", name)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Error(err)
}
}
func TestBindSlice(t *testing.T) {
t.Parallel()
testResults := []struct {
ID int
Name string `boil:"test"`
}{}
query := &Query{
from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
if err != nil {
t.Error(err)
}
ret := sqlmock.NewRows([]string{"id", "test"})
ret.AddRow(driver.Value(int64(35)), driver.Value("pat"))
ret.AddRow(driver.Value(int64(12)), driver.Value("cat"))
mock.ExpectQuery(`SELECT \* FROM "fun";`).WillReturnRows(ret)
SetExecutor(query, db)
err = query.Bind(&testResults)
if err != nil {
t.Error(err)
}
if len(testResults) != 2 {
t.Fatal("wrong number of results:", len(testResults))
}
if id := testResults[0].ID; id != 35 {
t.Error("wrong ID:", id)
}
if name := testResults[0].Name; name != "pat" {
t.Error("wrong name:", name)
}
if id := testResults[1].ID; id != 12 {
t.Error("wrong ID:", id)
}
if name := testResults[1].Name; name != "cat" {
t.Error("wrong name:", name)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Error(err)
}
}
func TestBindPtrSlice(t *testing.T) {
t.Parallel()
testResults := []*struct {
ID int
Name string `boil:"test"`
}{}
query := &Query{
from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
if err != nil {
t.Error(err)
}
ret := sqlmock.NewRows([]string{"id", "test"})
ret.AddRow(driver.Value(int64(35)), driver.Value("pat"))
ret.AddRow(driver.Value(int64(12)), driver.Value("cat"))
mock.ExpectQuery(`SELECT \* FROM "fun";`).WillReturnRows(ret)
SetExecutor(query, db)
err = query.Bind(&testResults)
if err != nil {
t.Error(err)
}
if len(testResults) != 2 {
t.Fatal("wrong number of results:", len(testResults))
}
if id := testResults[0].ID; id != 35 {
t.Error("wrong ID:", id)
}
if name := testResults[0].Name; name != "pat" {
t.Error("wrong name:", name)
}
if id := testResults[1].ID; id != 12 {
t.Error("wrong ID:", id)
}
if name := testResults[1].Name; name != "cat" {
t.Error("wrong name:", name)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Error(err)
}
}
func testMakeMapping(byt ...byte) uint64 {
var x uint64
for i, b := range byt {
x |= uint64(b) << (uint(i) * 8)
}
x |= uint64(255) << uint(len(byt)*8)
return x
}
func TestMakeStructMapping(t *testing.T) {
t.Parallel()
var testStruct = struct {
LastName string `boil:"different"`
AwesomeName string `boil:"awesome_name"`
Face string `boil:"-"`
Nose string
Nested struct {
LastName string `boil:"different"`
AwesomeName string `boil:"awesome_name"`
Face string `boil:"-"`
Nose string
Nested2 struct {
Nose string
} `boil:",bind"`
} `boil:",bind"`
}{}
got := MakeStructMapping(reflect.TypeOf(testStruct))
expectMap := map[string]uint64{
"Different": testMakeMapping(0),
"AwesomeName": testMakeMapping(1),
"Nose": testMakeMapping(3),
"Nested.Different": testMakeMapping(4, 0),
"Nested.AwesomeName": testMakeMapping(4, 1),
"Nested.Nose": testMakeMapping(4, 3),
"Nested.Nested2.Nose": testMakeMapping(4, 4, 0),
}
for expName, expVal := range expectMap {
gotVal, ok := got[expName]
if !ok {
t.Errorf("%s) had no value", expName)
continue
}
if gotVal != expVal {
t.Errorf("%s) wrong value,\nwant: %x (%s)\ngot: %x (%s)", expName, expVal, bin64(expVal), gotVal, bin64(gotVal))
}
}
}
func TestPtrFromMapping(t *testing.T) {
t.Parallel()
type NestedPtrs struct {
Int int
IntP *int
NestedPtrsP *NestedPtrs
}
val := &NestedPtrs{
Int: 5,
IntP: new(int),
NestedPtrsP: &NestedPtrs{
Int: 6,
IntP: new(int),
},
}
v := ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(0), true)
if got := *v.Interface().(*int); got != 5 {
t.Error("flat int was wrong:", got)
}
v = ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(1), true)
if got := *v.Interface().(*int); got != 0 {
t.Error("flat pointer was wrong:", got)
}
v = ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(2, 0), true)
if got := *v.Interface().(*int); got != 6 {
t.Error("nested int was wrong:", got)
}
v = ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(2, 1), true)
if got := *v.Interface().(*int); got != 0 {
t.Error("nested pointer was wrong:", got)
}
}
func TestGetBoilTag(t *testing.T) {
t.Parallel()
type TestStruct struct {
FirstName string `boil:"test_one,bind"`
LastName string `boil:"test_two"`
MiddleName string `boil:"middle_name,bind"`
AwesomeName string `boil:"awesome_name"`
Age string `boil:",bind"`
Face string `boil:"-"`
Nose string
}
var structFields []reflect.StructField
typ := reflect.TypeOf(TestStruct{})
removeOk := func(thing reflect.StructField, ok bool) reflect.StructField {
if !ok {
panic("Exploded")
}
return thing
}
structFields = append(structFields, removeOk(typ.FieldByName("FirstName")))
structFields = append(structFields, removeOk(typ.FieldByName("LastName")))
structFields = append(structFields, removeOk(typ.FieldByName("MiddleName")))
structFields = append(structFields, removeOk(typ.FieldByName("AwesomeName")))
structFields = append(structFields, removeOk(typ.FieldByName("Age")))
structFields = append(structFields, removeOk(typ.FieldByName("Face")))
structFields = append(structFields, removeOk(typ.FieldByName("Nose")))
expect := []struct {
Name string
Recurse bool
}{
{"TestOne", true},
{"TestTwo", false},
{"MiddleName", true},
{"AwesomeName", false},
{"Age", true},
{"-", false},
{"Nose", false},
}
for i, s := range structFields {
name, recurse := getBoilTag(s)
if expect[i].Name != name {
t.Errorf("Invalid name, expect %q, got %q", expect[i].Name, name)
}
if expect[i].Recurse != recurse {
t.Errorf("Invalid recurse, expect %v, got %v", !recurse, recurse)
}
}
}
func TestBindChecks(t *testing.T) {
t.Parallel()
type useless struct {
}
var tests = []struct {
BKind bindKind
Fail bool
Obj interface{}
}{
{BKind: kindStruct, Fail: false, Obj: &useless{}},
{BKind: kindSliceStruct, Fail: false, Obj: &[]useless{}},
{BKind: kindPtrSliceStruct, Fail: false, Obj: &[]*useless{}},
{Fail: true, Obj: 5},
{Fail: true, Obj: useless{}},
{Fail: true, Obj: []useless{}},
}
for i, test := range tests {
str, sli, bk, err := bindChecks(test.Obj)
if err != nil {
if !test.Fail {
t.Errorf("%d) should not fail, got: %v", i, err)
}
continue
} else if test.Fail {
t.Errorf("%d) should fail, got: %v", i, bk)
continue
}
if s := str.Kind(); s != reflect.Struct {
t.Error("struct kind was wrong:", s)
}
if test.BKind != kindStruct {
if s := sli.Kind(); s != reflect.Slice {
t.Error("slice kind was wrong:", s)
}
}
}
}
func TestBindSingular(t *testing.T) {
t.Parallel()
testResults := struct {
ID int
Name string `boil:"test"`
}{}
query := &Query{
from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
if err != nil {
t.Error(err)
}
ret := sqlmock.NewRows([]string{"id", "test"})
ret.AddRow(driver.Value(int64(35)), driver.Value("pat"))
mock.ExpectQuery(`SELECT \* FROM "fun";`).WillReturnRows(ret)
SetExecutor(query, db)
err = query.Bind(&testResults)
if err != nil {
t.Error(err)
}
if id := testResults.ID; id != 35 {
t.Error("wrong ID:", id)
}
if name := testResults.Name; name != "pat" {
t.Error("wrong name:", name)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Error(err)
}
}
func TestBind_InnerJoin(t *testing.T) {
t.Parallel()
testResults := []*struct {
Happy struct {
ID int `boil:"identifier"`
} `boil:",bind"`
Fun struct {
ID int `boil:"id"`
} `boil:",bind"`
}{}
query := &Query{
from: []string{"fun"},
joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
if err != nil {
t.Error(err)
}
ret := sqlmock.NewRows([]string{"id"})
ret.AddRow(driver.Value(int64(10)))
ret.AddRow(driver.Value(int64(11)))
mock.ExpectQuery(`SELECT "fun"\.\* FROM "fun" INNER JOIN happy as h on fun.id = h.fun_id;`).WillReturnRows(ret)
SetExecutor(query, db)
err = query.Bind(&testResults)
if err != nil {
t.Error(err)
}
if len(testResults) != 2 {
t.Fatal("wrong number of results:", len(testResults))
}
if id := testResults[0].Happy.ID; id != 0 {
t.Error("wrong ID:", id)
}
if id := testResults[0].Fun.ID; id != 10 {
t.Error("wrong ID:", id)
}
if id := testResults[1].Happy.ID; id != 0 {
t.Error("wrong ID:", id)
}
if id := testResults[1].Fun.ID; id != 11 {
t.Error("wrong ID:", id)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Error(err)
}
}
// func TestBind_InnerJoinSelect(t *testing.T) {
// t.Parallel()
//
// testResults := []*struct {
// Happy struct {
// ID int
// } `boil:"h,bind"`
// Fun struct {
// ID int
// } `boil:",bind"`
// }{}
//
// query := &Query{
// selectCols: []string{"fun.id", "h.id"},
// from: []string{"fun"},
// joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}},
// }
//
// db, mock, err := sqlmock.New()
// if err != nil {
// t.Error(err)
// }
//
// ret := sqlmock.NewRows([]string{"fun.id", "h.id"})
// ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11)))
// ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13)))
// mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret)
//
// SetExecutor(query, db)
// err = query.Bind(&testResults)
// if err != nil {
// t.Error(err)
// }
//
// if len(testResults) != 2 {
// t.Fatal("wrong number of results:", len(testResults))
// }
// if id := testResults[0].Happy.ID; id != 11 {
// t.Error("wrong ID:", id)
// }
// if id := testResults[0].Fun.ID; id != 10 {
// t.Error("wrong ID:", id)
// }
//
// if id := testResults[1].Happy.ID; id != 13 {
// t.Error("wrong ID:", id)
// }
// if id := testResults[1].Fun.ID; id != 12 {
// t.Error("wrong ID:", id)
// }
//
// if err := mock.ExpectationsWereMet(); err != nil {
// t.Error(err)
// }
// }
// func TestBindPtrs_Easy(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// ID int `boil:"identifier"`
// Date time.Time
// }{}
//
// cols := []string{"identifier", "date"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*time.Time) != &testStruct.Date {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_Recursive(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// Happy struct {
// ID int `boil:"identifier"`
// }
// Fun struct {
// ID int
// } `boil:",bind"`
// }{}
//
// cols := []string{"id", "fun.id"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_RecursiveTags(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// Happy struct {
// ID int `boil:"identifier"`
// } `boil:",bind"`
// Fun struct {
// ID int `boil:"identification"`
// } `boil:",bind"`
// }{}
//
// cols := []string{"happy.identifier", "fun.identification"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Happy.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_Ignore(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// ID int `boil:"-"`
// Happy struct {
// ID int
// } `boil:",bind"`
// }{}
//
// cols := []string{"id"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Happy.ID {
// t.Error("id is the wrong pointer")
// }
// }
func TestGetStructValues(t *testing.T) {
t.Parallel()
timeThing := time.Now()
o := struct {
TitleThing string
Name string
ID int
Stuff int
Things int
Time time.Time
NullBool null.Bool
}{
TitleThing: "patrick",
Stuff: 10,
Things: 0,
Time: timeThing,
NullBool: null.NewBool(true, false),
}
vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool")
if vals[0].(string) != "patrick" {
t.Errorf("Want test, got %s", vals[0])
}
if vals[1].(string) != "" {
t.Errorf("Want empty string, got %s", vals[1])
}
if vals[2].(int) != 0 {
t.Errorf("Want 0, got %d", vals[2])
}
if vals[3].(int) != 10 {
t.Errorf("Want 10, got %d", vals[3])
}
if vals[4].(int) != 0 {
t.Errorf("Want 0, got %d", vals[4])
}
if !vals[5].(time.Time).Equal(timeThing) {
t.Errorf("Want %s, got %s", o.Time, vals[5])
}
if !vals[6].(null.Bool).IsZero() {
t.Errorf("Want %v, got %v", o.NullBool, vals[6])
}
}
func TestGetSliceValues(t *testing.T) {
t.Parallel()
o := []struct {
ID int
Name string
}{
{5, "a"},
{6, "b"},
}
in := make([]interface{}, len(o))
in[0] = o[0]
in[1] = o[1]
vals := GetSliceValues(in, "id", "name")
if got := vals[0].(int); got != 5 {
t.Error(got)
}
if got := vals[1].(string); got != "a" {
t.Error(got)
}
if got := vals[2].(int); got != 6 {
t.Error(got)
}
if got := vals[3].(string); got != "b" {
t.Error(got)
}
}
func TestGetStructPointers(t *testing.T) {
t.Parallel()
o := struct {
Title string
ID *int
}{
Title: "patrick",
}
ptrs := GetStructPointers(&o, "title", "id")
*ptrs[0].(*string) = "test"
if o.Title != "test" {
t.Errorf("Expected test, got %s", o.Title)
}
x := 5
*ptrs[1].(**int) = &x
if *o.ID != 5 {
t.Errorf("Expected 5, got %d", *o.ID)
}
}