Move everything to better package structure
This commit is contained in:
parent
f6b4d3c6fd
commit
5149df8359
40 changed files with 241 additions and 86 deletions
1
queries/_fixtures/00.sql
Normal file
1
queries/_fixtures/00.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT * FROM "t";
|
1
queries/_fixtures/01.sql
Normal file
1
queries/_fixtures/01.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT * FROM "q" LIMIT 5 OFFSET 6;
|
1
queries/_fixtures/02.sql
Normal file
1
queries/_fixtures/02.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT * FROM "q" ORDER BY a ASC, b DESC;
|
1
queries/_fixtures/03.sql
Normal file
1
queries/_fixtures/03.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT count(*) as ab, thing as bd, "stuff" FROM "t";
|
1
queries/_fixtures/04.sql
Normal file
1
queries/_fixtures/04.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT count(*) as ab, thing as bd, "stuff" FROM "a", "b";
|
1
queries/_fixtures/05.sql
Normal file
1
queries/_fixtures/05.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "a"."happy" as "a.happy", "r"."fun" as "r.fun", "q" FROM happiness as a INNER JOIN rainbows r on a.id = r.happy_id;
|
1
queries/_fixtures/06.sql
Normal file
1
queries/_fixtures/06.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "a".* FROM happiness as a INNER JOIN rainbows r on a.id = r.happy_id;
|
1
queries/_fixtures/07.sql
Normal file
1
queries/_fixtures/07.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "videos".* FROM "videos" INNER JOIN (select id from users where deleted = $1) u on u.id = videos.user_id WHERE (videos.deleted = $2);
|
1
queries/_fixtures/08.sql
Normal file
1
queries/_fixtures/08.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT * FROM "a" WHERE (a=$1 or b=$2) AND (c=$3) GROUP BY id, name HAVING id <> $4, length(name, $5) > $6;
|
1
queries/_fixtures/09.sql
Normal file
1
queries/_fixtures/09.sql
Normal file
|
@ -0,0 +1 @@
|
|||
DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE (a=$1) AND (b=$2) AND (c=$3);
|
1
queries/_fixtures/10.sql
Normal file
1
queries/_fixtures/10.sql
Normal file
|
@ -0,0 +1 @@
|
|||
DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE ((id=$1 and thing=$2) or stuff=$3) LIMIT 5;
|
1
queries/_fixtures/11.sql
Normal file
1
queries/_fixtures/11.sql
Normal file
|
@ -0,0 +1 @@
|
|||
UPDATE thing happy, "fun", "stuff" SET ("col2", "fun"."col3", "col1") = ($1,$2,$3) WHERE (aa=$4 or bb=$5 or cc=$6) AND (dd=$7 or ee=$8 or ff=$9 and gg=$10) LIMIT 5;
|
1
queries/_fixtures/12.sql
Normal file
1
queries/_fixtures/12.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "cats".* FROM "cats" INNER JOIN dogs d on d.cat_id = cats.id;
|
1
queries/_fixtures/13.sql
Normal file
1
queries/_fixtures/13.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "c".* FROM cats c INNER JOIN dogs d on d.cat_id = cats.id;
|
1
queries/_fixtures/14.sql
Normal file
1
queries/_fixtures/14.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "c".* FROM cats as c INNER JOIN dogs d on d.cat_id = cats.id;
|
1
queries/_fixtures/15.sql
Normal file
1
queries/_fixtures/15.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "c".*, "d".* FROM cats as c, dogs as d INNER JOIN dogs d on d.cat_id = cats.id;
|
148
queries/eager_load.go
Normal file
148
queries/eager_load.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/vattle/sqlboiler/boil"
|
||||
"github.com/vattle/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
type loadRelationshipState struct {
|
||||
exec boil.Executor
|
||||
loaded map[string]struct{}
|
||||
toLoad []string
|
||||
}
|
||||
|
||||
func (l loadRelationshipState) hasLoaded(depth int) bool {
|
||||
_, ok := l.loaded[l.buildKey(depth)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (l loadRelationshipState) setLoaded(depth int) {
|
||||
l.loaded[l.buildKey(depth)] = struct{}{}
|
||||
}
|
||||
|
||||
func (l loadRelationshipState) buildKey(depth int) string {
|
||||
buf := strmangle.GetBuffer()
|
||||
|
||||
for i, piece := range l.toLoad[:depth+1] {
|
||||
if i != 0 {
|
||||
buf.WriteByte('.')
|
||||
}
|
||||
buf.WriteString(piece)
|
||||
}
|
||||
|
||||
str := buf.String()
|
||||
strmangle.PutBuffer(buf)
|
||||
return str
|
||||
}
|
||||
|
||||
// loadRelationships dynamically calls the template generated eager load
|
||||
// functions of the form:
|
||||
//
|
||||
// func (t *TableR) LoadRelationshipName(exec Executor, singular bool, obj interface{})
|
||||
//
|
||||
// The arguments to this function are:
|
||||
// - t is not considered here, and is always passed nil. The function exists on a loaded
|
||||
// struct to avoid a circular dependency with boil, and the receiver is ignored.
|
||||
// - exec is used to perform additional queries that might be required for loading the relationships.
|
||||
// - singular is passed in to identify whether or not this was a single object
|
||||
// or a slice that must be loaded into.
|
||||
// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind.
|
||||
//
|
||||
// It takes list of nested relationships to load.
|
||||
func (l loadRelationshipState) loadRelationships(depth int, obj interface{}, bkind bindKind) error {
|
||||
typ := reflect.TypeOf(obj).Elem()
|
||||
if bkind == kindPtrSliceStruct {
|
||||
typ = typ.Elem().Elem()
|
||||
}
|
||||
|
||||
if !l.hasLoaded(depth) {
|
||||
current := l.toLoad[depth]
|
||||
ln, found := typ.FieldByName(loaderStructName)
|
||||
// It's possible a Loaders struct doesn't exist on the struct.
|
||||
if !found {
|
||||
return errors.Errorf("attempted to load %s but no L struct was found", current)
|
||||
}
|
||||
|
||||
// Attempt to find the LoadRelationshipName function
|
||||
loadMethod, found := ln.Type.MethodByName(loadMethodPrefix + current)
|
||||
if !found {
|
||||
return errors.Errorf("could not find %s%s method for eager loading", loadMethodPrefix, current)
|
||||
}
|
||||
|
||||
// Hack to allow nil executors
|
||||
execArg := reflect.ValueOf(l.exec)
|
||||
if !execArg.IsValid() {
|
||||
execArg = reflect.ValueOf((*sql.DB)(nil))
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(obj).Elem()
|
||||
if bkind == kindPtrSliceStruct {
|
||||
val = val.Index(0).Elem()
|
||||
}
|
||||
|
||||
methodArgs := []reflect.Value{
|
||||
val.FieldByName(loaderStructName),
|
||||
execArg,
|
||||
reflect.ValueOf(bkind == kindStruct),
|
||||
reflect.ValueOf(obj),
|
||||
}
|
||||
resp := loadMethod.Func.Call(methodArgs)
|
||||
if intf := resp[0].Interface(); intf != nil {
|
||||
return errors.Wrapf(intf.(error), "failed to eager load %s", current)
|
||||
}
|
||||
|
||||
l.setLoaded(depth)
|
||||
}
|
||||
|
||||
// Pull one off the queue, continue if there's still some to go
|
||||
depth++
|
||||
if depth >= len(l.toLoad) {
|
||||
return nil
|
||||
}
|
||||
|
||||
loadedObject := reflect.ValueOf(obj)
|
||||
// If we eagerly loaded nothing
|
||||
if loadedObject.IsNil() {
|
||||
return nil
|
||||
}
|
||||
loadedObject = reflect.Indirect(loadedObject)
|
||||
|
||||
// If it's singular we can just immediately call without looping
|
||||
if bkind == kindStruct {
|
||||
return l.loadRelationshipsRecurse(depth, loadedObject)
|
||||
}
|
||||
|
||||
// Loop over all eager loaded objects
|
||||
ln := loadedObject.Len()
|
||||
if ln == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := 0; i < ln; i++ {
|
||||
iter := loadedObject.Index(i).Elem()
|
||||
if err := l.loadRelationshipsRecurse(depth, iter); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadRelationshipsRecurse is a helper function for taking a reflect.Value and
|
||||
// Basically calls loadRelationships with: obj.R.EagerLoadedObj, and whether it's a string or slice
|
||||
func (l loadRelationshipState) loadRelationshipsRecurse(depth int, obj reflect.Value) error {
|
||||
r := obj.FieldByName(relationshipStructName)
|
||||
if !r.IsValid() || r.IsNil() {
|
||||
return errors.Errorf("could not traverse into loaded %s relationship to load more things", l.toLoad[depth])
|
||||
}
|
||||
newObj := reflect.Indirect(r).FieldByName(l.toLoad[depth])
|
||||
bkind := kindStruct
|
||||
if reflect.Indirect(newObj).Kind() != reflect.Struct {
|
||||
bkind = kindPtrSliceStruct
|
||||
newObj = newObj.Addr()
|
||||
}
|
||||
return l.loadRelationships(depth, newObj.Interface(), bkind)
|
||||
}
|
202
queries/eager_load_test.go
Normal file
202
queries/eager_load_test.go
Normal file
|
@ -0,0 +1,202 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/vattle/sqlboiler/boil"
|
||||
)
|
||||
|
||||
var loadFunctionCalled bool
|
||||
var loadFunctionNestedCalled int
|
||||
|
||||
type testRStruct struct {
|
||||
}
|
||||
type testLStruct struct {
|
||||
}
|
||||
|
||||
type testNestedStruct struct {
|
||||
ID int
|
||||
R *testNestedRStruct
|
||||
L testNestedLStruct
|
||||
}
|
||||
type testNestedRStruct struct {
|
||||
ToEagerLoad *testNestedStruct
|
||||
}
|
||||
type testNestedLStruct struct {
|
||||
}
|
||||
|
||||
type testNestedSlice struct {
|
||||
ID int
|
||||
R *testNestedRSlice
|
||||
L testNestedLSlice
|
||||
}
|
||||
type testNestedRSlice struct {
|
||||
ToEagerLoad []*testNestedSlice
|
||||
}
|
||||
type testNestedLSlice struct {
|
||||
}
|
||||
|
||||
func (testLStruct) LoadTestOne(exec boil.Executor, singular bool, obj interface{}) error {
|
||||
loadFunctionCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (testNestedLStruct) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error {
|
||||
switch x := obj.(type) {
|
||||
case *testNestedStruct:
|
||||
x.R = &testNestedRStruct{
|
||||
&testNestedStruct{ID: 4},
|
||||
}
|
||||
case *[]*testNestedStruct:
|
||||
for _, r := range *x {
|
||||
r.R = &testNestedRStruct{
|
||||
&testNestedStruct{ID: 4},
|
||||
}
|
||||
}
|
||||
}
|
||||
loadFunctionNestedCalled++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (testNestedLSlice) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error {
|
||||
|
||||
switch x := obj.(type) {
|
||||
case *testNestedSlice:
|
||||
x.R = &testNestedRSlice{
|
||||
[]*testNestedSlice{{ID: 5}},
|
||||
}
|
||||
case *[]*testNestedSlice:
|
||||
for _, r := range *x {
|
||||
r.R = &testNestedRSlice{
|
||||
[]*testNestedSlice{{ID: 5}},
|
||||
}
|
||||
}
|
||||
}
|
||||
loadFunctionNestedCalled++
|
||||
return nil
|
||||
}
|
||||
|
||||
func testFakeState(toLoad ...string) loadRelationshipState {
|
||||
return loadRelationshipState{
|
||||
loaded: map[string]struct{}{},
|
||||
toLoad: toLoad,
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRelationshipsSlice(t *testing.T) {
|
||||
// t.Parallel() Function uses globals
|
||||
loadFunctionCalled = false
|
||||
|
||||
testSlice := []*struct {
|
||||
ID int
|
||||
R *testRStruct
|
||||
L testLStruct
|
||||
}{{}}
|
||||
|
||||
if err := testFakeState("TestOne").loadRelationships(0, &testSlice, kindPtrSliceStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !loadFunctionCalled {
|
||||
t.Errorf("Load function was not called for testSlice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRelationshipsSingular(t *testing.T) {
|
||||
// t.Parallel() Function uses globals
|
||||
loadFunctionCalled = false
|
||||
|
||||
testSingular := struct {
|
||||
ID int
|
||||
R *testRStruct
|
||||
L testLStruct
|
||||
}{}
|
||||
|
||||
if err := testFakeState("TestOne").loadRelationships(0, &testSingular, kindStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !loadFunctionCalled {
|
||||
t.Errorf("Load function was not called for singular")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRelationshipsSliceNested(t *testing.T) {
|
||||
// t.Parallel() Function uses globals
|
||||
testSlice := []*testNestedStruct{
|
||||
{
|
||||
ID: 2,
|
||||
},
|
||||
}
|
||||
loadFunctionNestedCalled = 0
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSlice, kindPtrSliceStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 3 {
|
||||
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
|
||||
}
|
||||
|
||||
testSliceSlice := []*testNestedSlice{
|
||||
{
|
||||
ID: 2,
|
||||
},
|
||||
}
|
||||
loadFunctionNestedCalled = 0
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSliceSlice, kindPtrSliceStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 3 {
|
||||
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRelationshipsSingularNested(t *testing.T) {
|
||||
// t.Parallel() Function uses globals
|
||||
testSingular := testNestedStruct{
|
||||
ID: 3,
|
||||
}
|
||||
loadFunctionNestedCalled = 0
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingular, kindStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 3 {
|
||||
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
|
||||
}
|
||||
|
||||
testSingularSlice := testNestedSlice{
|
||||
ID: 3,
|
||||
}
|
||||
loadFunctionNestedCalled = 0
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingularSlice, kindStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 3 {
|
||||
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRelationshipsNoReload(t *testing.T) {
|
||||
// t.Parallel() Function uses globals
|
||||
testSingular := testNestedStruct{
|
||||
ID: 3,
|
||||
R: &testNestedRStruct{
|
||||
&testNestedStruct{},
|
||||
},
|
||||
}
|
||||
|
||||
loadFunctionNestedCalled = 0
|
||||
state := loadRelationshipState{
|
||||
loaded: map[string]struct{}{
|
||||
"ToEagerLoad": {},
|
||||
"ToEagerLoad.ToEagerLoad": {},
|
||||
},
|
||||
toLoad: []string{"ToEagerLoad", "ToEagerLoad"},
|
||||
}
|
||||
|
||||
if err := state.loadRelationships(0, &testSingular, kindStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 0 {
|
||||
t.Error("didn't want this called")
|
||||
}
|
||||
}
|
31
queries/helpers.go
Normal file
31
queries/helpers.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/vattle/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
// NonZeroDefaultSet returns the fields included in the
|
||||
// defaults slice that are non zero values
|
||||
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
|
||||
c := make([]string, 0, len(defaults))
|
||||
|
||||
val := reflect.Indirect(reflect.ValueOf(obj))
|
||||
|
||||
for _, d := range defaults {
|
||||
fieldName := strmangle.TitleCase(d)
|
||||
field := val.FieldByName(fieldName)
|
||||
if !field.IsValid() {
|
||||
panic(fmt.Sprintf("Could not find field name %s in type %T", fieldName, obj))
|
||||
}
|
||||
|
||||
zero := reflect.Zero(field.Type())
|
||||
if !reflect.DeepEqual(zero.Interface(), field.Interface()) {
|
||||
c = append(c, d)
|
||||
}
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
67
queries/helpers_test.go
Normal file
67
queries/helpers_test.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gopkg.in/nullbio/null.v5"
|
||||
)
|
||||
|
||||
type testObj struct {
|
||||
ID int
|
||||
Name string `db:"TestHello"`
|
||||
HeadSize int
|
||||
}
|
||||
|
||||
func TestNonZeroDefaultSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type Anything struct {
|
||||
ID int
|
||||
Name string
|
||||
CreatedAt *time.Time
|
||||
UpdatedAt null.Time
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
Defaults []string
|
||||
Obj interface{}
|
||||
Ret []string
|
||||
}{
|
||||
{
|
||||
[]string{"id"},
|
||||
Anything{Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
[]string{"id"},
|
||||
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||
[]string{"id"},
|
||||
},
|
||||
{
|
||||
[]string{},
|
||||
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
[]string{"id", "created_at", "updated_at"},
|
||||
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||
[]string{"id"},
|
||||
},
|
||||
{
|
||||
[]string{"id", "created_at", "updated_at"},
|
||||
Anything{ID: 5, Name: "hi", CreatedAt: &now, UpdatedAt: null.Time{Valid: true, Time: time.Now()}},
|
||||
[]string{"id", "created_at", "updated_at"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
z := NonZeroDefaultSet(test.Defaults, test.Obj)
|
||||
if !reflect.DeepEqual(test.Ret, z) {
|
||||
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.Ret, z)
|
||||
}
|
||||
}
|
||||
}
|
145
queries/qm/query_mods.go
Normal file
145
queries/qm/query_mods.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
package qm
|
||||
|
||||
import "github.com/vattle/sqlboiler/queries"
|
||||
|
||||
// QueryMod to modify the query object
|
||||
type QueryMod func(q *queries.Query)
|
||||
|
||||
// Apply the query mods to the Query object
|
||||
func Apply(q *queries.Query, mods ...QueryMod) {
|
||||
for _, mod := range mods {
|
||||
mod(q)
|
||||
}
|
||||
}
|
||||
|
||||
// SQL allows you to execute a plain SQL statement
|
||||
func SQL(sql string, args ...interface{}) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.SetSQL(q, sql, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Load allows you to specify foreign key relationships to eager load
|
||||
// for your query. Passed in relationships need to be in the format
|
||||
// MyThing or MyThings.
|
||||
// Relationship name plurality is important, if your relationship is
|
||||
// singular, you need to specify the singular form and vice versa.
|
||||
func Load(relationships ...string) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.SetLoad(q, relationships...)
|
||||
}
|
||||
}
|
||||
|
||||
// InnerJoin on another table
|
||||
func InnerJoin(clause string, args ...interface{}) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendInnerJoin(q, clause, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Select specific columns opposed to all columns
|
||||
func Select(columns ...string) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendSelect(q, columns...)
|
||||
}
|
||||
}
|
||||
|
||||
// Where allows you to specify a where clause for your statement
|
||||
func Where(clause string, args ...interface{}) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendWhere(q, clause, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// And allows you to specify a where clause separated by an AND for your statement
|
||||
// And is a duplicate of the Where function, but allows for more natural looking
|
||||
// query mod chains, for example: (Where("a=?"), And("b=?"), Or("c=?")))
|
||||
func And(clause string, args ...interface{}) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendWhere(q, clause, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Or allows you to specify a where clause separated by an OR for your statement
|
||||
func Or(clause string, args ...interface{}) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendWhere(q, clause, args...)
|
||||
queries.SetLastWhereAsOr(q)
|
||||
}
|
||||
}
|
||||
|
||||
// WhereIn allows you to specify a "x IN (set)" clause for your where statement
|
||||
// Example clauses: "column in ?", "(column1,column2) in ?"
|
||||
func WhereIn(clause string, args ...interface{}) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendIn(q, clause, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// AndIn allows you to specify a "x IN (set)" clause separated by an AndIn
|
||||
// for your where statement. AndIn is a duplicate of the WhereIn function, but
|
||||
// allows for more natural looking query mod chains, for example:
|
||||
// (WhereIn("column1 in ?"), AndIn("column2 in ?"), OrIn("column3 in ?"))
|
||||
func AndIn(clause string, args ...interface{}) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendIn(q, clause, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// OrIn allows you to specify an IN clause separated by
|
||||
// an OR for your where statement
|
||||
func OrIn(clause string, args ...interface{}) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendIn(q, clause, args...)
|
||||
queries.SetLastInAsOr(q)
|
||||
}
|
||||
}
|
||||
|
||||
// GroupBy allows you to specify a group by clause for your statement
|
||||
func GroupBy(clause string) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendGroupBy(q, clause)
|
||||
}
|
||||
}
|
||||
|
||||
// OrderBy allows you to specify a order by clause for your statement
|
||||
func OrderBy(clause string) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendOrderBy(q, clause)
|
||||
}
|
||||
}
|
||||
|
||||
// Having allows you to specify a having clause for your statement
|
||||
func Having(clause string, args ...interface{}) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendHaving(q, clause, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// From allows to specify the table for your statement
|
||||
func From(from string) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.AppendFrom(q, from)
|
||||
}
|
||||
}
|
||||
|
||||
// Limit the number of returned rows
|
||||
func Limit(limit int) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.SetLimit(q, limit)
|
||||
}
|
||||
}
|
||||
|
||||
// Offset into the results
|
||||
func Offset(offset int) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.SetOffset(q, offset)
|
||||
}
|
||||
}
|
||||
|
||||
// For inserts a concurrency locking clause at the end of your statement
|
||||
func For(clause string) QueryMod {
|
||||
return func(q *queries.Query) {
|
||||
queries.SetFor(q, clause)
|
||||
}
|
||||
}
|
274
queries/query.go
Normal file
274
queries/query.go
Normal file
|
@ -0,0 +1,274 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/vattle/sqlboiler/boil"
|
||||
)
|
||||
|
||||
// joinKind is the type of join
|
||||
type joinKind int
|
||||
|
||||
// Join type constants
|
||||
const (
|
||||
JoinInner joinKind = iota
|
||||
JoinOuterLeft
|
||||
JoinOuterRight
|
||||
JoinNatural
|
||||
)
|
||||
|
||||
// Query holds the state for the built up query
|
||||
type Query struct {
|
||||
executor boil.Executor
|
||||
dialect *Dialect
|
||||
plainSQL plainSQL
|
||||
load []string
|
||||
delete bool
|
||||
update map[string]interface{}
|
||||
selectCols []string
|
||||
count bool
|
||||
from []string
|
||||
joins []join
|
||||
where []where
|
||||
in []in
|
||||
groupBy []string
|
||||
orderBy []string
|
||||
having []having
|
||||
limit int
|
||||
offset int
|
||||
forlock string
|
||||
}
|
||||
|
||||
// Dialect holds values that direct the query builder
|
||||
// how to build compatible queries for each database.
|
||||
// Each database driver needs to implement functions
|
||||
// that provide these values.
|
||||
type Dialect struct {
|
||||
// The left quote character for SQL identifiers
|
||||
LQ byte
|
||||
// The right quote character for SQL identifiers
|
||||
RQ byte
|
||||
// Bool flag indicating whether indexed
|
||||
// placeholders ($1) are used, or ? placeholders.
|
||||
IndexPlaceholders bool
|
||||
}
|
||||
|
||||
type where struct {
|
||||
clause string
|
||||
orSeparator bool
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
type in struct {
|
||||
clause string
|
||||
orSeparator bool
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
type having struct {
|
||||
clause string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
type plainSQL struct {
|
||||
sql string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
type join struct {
|
||||
kind joinKind
|
||||
clause string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
// SQL makes a plainSQL query, usually for use with bind
|
||||
func SQL(exec boil.Executor, query string, args ...interface{}) *Query {
|
||||
return &Query{
|
||||
executor: exec,
|
||||
plainSQL: plainSQL{
|
||||
sql: query,
|
||||
args: args,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SQLG makes a plainSQL query using the global boil.Executor, usually for use with bind
|
||||
func SQLG(query string, args ...interface{}) *Query {
|
||||
return SQL(boil.GetDB(), query, args...)
|
||||
}
|
||||
|
||||
// Exec executes a query that does not need a row returned
|
||||
func (q *Query) Exec() (sql.Result, error) {
|
||||
qs, args := buildQuery(q)
|
||||
if boil.DebugMode {
|
||||
fmt.Fprintln(boil.DebugWriter, qs)
|
||||
fmt.Fprintln(boil.DebugWriter, args)
|
||||
}
|
||||
return q.executor.Exec(qs, args...)
|
||||
}
|
||||
|
||||
// QueryRow executes the query for the One finisher and returns a row
|
||||
func (q *Query) QueryRow() *sql.Row {
|
||||
qs, args := buildQuery(q)
|
||||
if boil.DebugMode {
|
||||
fmt.Fprintln(boil.DebugWriter, qs)
|
||||
fmt.Fprintln(boil.DebugWriter, args)
|
||||
}
|
||||
return q.executor.QueryRow(qs, args...)
|
||||
}
|
||||
|
||||
// Query executes the query for the All finisher and returns multiple rows
|
||||
func (q *Query) Query() (*sql.Rows, error) {
|
||||
qs, args := buildQuery(q)
|
||||
if boil.DebugMode {
|
||||
fmt.Fprintln(boil.DebugWriter, qs)
|
||||
fmt.Fprintln(boil.DebugWriter, args)
|
||||
}
|
||||
return q.executor.Query(qs, args...)
|
||||
}
|
||||
|
||||
// ExecP executes a query that does not need a row returned
|
||||
// It will panic on error
|
||||
func (q *Query) ExecP() sql.Result {
|
||||
res, err := q.Exec()
|
||||
if err != nil {
|
||||
panic(boil.WrapErr(err))
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// QueryP executes the query for the All finisher and returns multiple rows
|
||||
// It will panic on error
|
||||
func (q *Query) QueryP() *sql.Rows {
|
||||
rows, err := q.Query()
|
||||
if err != nil {
|
||||
panic(boil.WrapErr(err))
|
||||
}
|
||||
|
||||
return rows
|
||||
}
|
||||
|
||||
// SetExecutor on the query.
|
||||
func SetExecutor(q *Query, exec boil.Executor) {
|
||||
q.executor = exec
|
||||
}
|
||||
|
||||
// GetExecutor on the query.
|
||||
func GetExecutor(q *Query) boil.Executor {
|
||||
return q.executor
|
||||
}
|
||||
|
||||
// SetDialect on the query.
|
||||
func SetDialect(q *Query, dialect *Dialect) {
|
||||
q.dialect = dialect
|
||||
}
|
||||
|
||||
// SetSQL on the query.
|
||||
func SetSQL(q *Query, sql string, args ...interface{}) {
|
||||
q.plainSQL = plainSQL{sql: sql, args: args}
|
||||
}
|
||||
|
||||
// SetLoad on the query.
|
||||
func SetLoad(q *Query, relationships ...string) {
|
||||
q.load = append([]string(nil), relationships...)
|
||||
}
|
||||
|
||||
// SetSelect on the query.
|
||||
func SetSelect(q *Query, sel []string) {
|
||||
q.selectCols = sel
|
||||
}
|
||||
|
||||
// SetCount on the query.
|
||||
func SetCount(q *Query) {
|
||||
q.count = true
|
||||
}
|
||||
|
||||
// SetDelete on the query.
|
||||
func SetDelete(q *Query) {
|
||||
q.delete = true
|
||||
}
|
||||
|
||||
// SetLimit on the query.
|
||||
func SetLimit(q *Query, limit int) {
|
||||
q.limit = limit
|
||||
}
|
||||
|
||||
// SetOffset on the query.
|
||||
func SetOffset(q *Query, offset int) {
|
||||
q.offset = offset
|
||||
}
|
||||
|
||||
// SetFor on the query.
|
||||
func SetFor(q *Query, clause string) {
|
||||
q.forlock = clause
|
||||
}
|
||||
|
||||
// SetUpdate on the query.
|
||||
func SetUpdate(q *Query, cols map[string]interface{}) {
|
||||
q.update = cols
|
||||
}
|
||||
|
||||
// AppendSelect on the query.
|
||||
func AppendSelect(q *Query, columns ...string) {
|
||||
q.selectCols = append(q.selectCols, columns...)
|
||||
}
|
||||
|
||||
// AppendFrom on the query.
|
||||
func AppendFrom(q *Query, from ...string) {
|
||||
q.from = append(q.from, from...)
|
||||
}
|
||||
|
||||
// SetFrom replaces the current from statements.
|
||||
func SetFrom(q *Query, from ...string) {
|
||||
q.from = append([]string(nil), from...)
|
||||
}
|
||||
|
||||
// AppendInnerJoin on the query.
|
||||
func AppendInnerJoin(q *Query, clause string, args ...interface{}) {
|
||||
q.joins = append(q.joins, join{clause: clause, kind: JoinInner, args: args})
|
||||
}
|
||||
|
||||
// AppendHaving on the query.
|
||||
func AppendHaving(q *Query, clause string, args ...interface{}) {
|
||||
q.having = append(q.having, having{clause: clause, args: args})
|
||||
}
|
||||
|
||||
// AppendWhere on the query.
|
||||
func AppendWhere(q *Query, clause string, args ...interface{}) {
|
||||
q.where = append(q.where, where{clause: clause, args: args})
|
||||
}
|
||||
|
||||
// AppendIn on the query.
|
||||
func AppendIn(q *Query, clause string, args ...interface{}) {
|
||||
q.in = append(q.in, in{clause: clause, args: args})
|
||||
}
|
||||
|
||||
// SetLastWhereAsOr sets the or separator for the tail "WHERE" in the slice
|
||||
func SetLastWhereAsOr(q *Query) {
|
||||
if len(q.where) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
q.where[len(q.where)-1].orSeparator = true
|
||||
}
|
||||
|
||||
// SetLastInAsOr sets the or separator for the tail "IN" in the slice
|
||||
func SetLastInAsOr(q *Query) {
|
||||
if len(q.in) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
q.in[len(q.in)-1].orSeparator = true
|
||||
}
|
||||
|
||||
// AppendGroupBy on the query.
|
||||
func AppendGroupBy(q *Query, clause string) {
|
||||
q.groupBy = append(q.groupBy, clause)
|
||||
}
|
||||
|
||||
// AppendOrderBy on the query.
|
||||
func AppendOrderBy(q *Query, clause string) {
|
||||
q.orderBy = append(q.orderBy, clause)
|
||||
}
|
573
queries/query_builders.go
Normal file
573
queries/query_builders.go
Normal file
|
@ -0,0 +1,573 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/vattle/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
var (
|
||||
rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`)
|
||||
rgxInClause = regexp.MustCompile(`^(?i)(.*[\s|\)|\?])IN([\s|\(|\?].*)$`)
|
||||
)
|
||||
|
||||
func buildQuery(q *Query) (string, []interface{}) {
|
||||
var buf *bytes.Buffer
|
||||
var args []interface{}
|
||||
|
||||
switch {
|
||||
case len(q.plainSQL.sql) != 0:
|
||||
return q.plainSQL.sql, q.plainSQL.args
|
||||
case q.delete:
|
||||
buf, args = buildDeleteQuery(q)
|
||||
case len(q.update) > 0:
|
||||
buf, args = buildUpdateQuery(q)
|
||||
default:
|
||||
buf, args = buildSelectQuery(q)
|
||||
}
|
||||
|
||||
defer strmangle.PutBuffer(buf)
|
||||
|
||||
// Cache the generated query for query object re-use
|
||||
bufStr := buf.String()
|
||||
q.plainSQL.sql = bufStr
|
||||
q.plainSQL.args = args
|
||||
|
||||
return bufStr, args
|
||||
}
|
||||
|
||||
func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
||||
buf := strmangle.GetBuffer()
|
||||
var args []interface{}
|
||||
|
||||
buf.WriteString("SELECT ")
|
||||
|
||||
if q.count {
|
||||
buf.WriteString("COUNT(")
|
||||
}
|
||||
|
||||
hasSelectCols := len(q.selectCols) != 0
|
||||
hasJoins := len(q.joins) != 0
|
||||
if hasJoins && hasSelectCols && !q.count {
|
||||
selectColsWithAs := writeAsStatements(q)
|
||||
// Don't identQuoteSlice - writeAsStatements does this
|
||||
buf.WriteString(strings.Join(selectColsWithAs, ", "))
|
||||
} else if hasSelectCols {
|
||||
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", "))
|
||||
} else if hasJoins && !q.count {
|
||||
selectColsWithStars := writeStars(q)
|
||||
buf.WriteString(strings.Join(selectColsWithStars, ", "))
|
||||
} else {
|
||||
buf.WriteByte('*')
|
||||
}
|
||||
|
||||
// close SQL COUNT function
|
||||
if q.count {
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
|
||||
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
|
||||
|
||||
if len(q.joins) > 0 {
|
||||
argsLen := len(args)
|
||||
joinBuf := strmangle.GetBuffer()
|
||||
for _, j := range q.joins {
|
||||
if j.kind != JoinInner {
|
||||
panic("only inner joins are supported")
|
||||
}
|
||||
fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause)
|
||||
args = append(args, j.args...)
|
||||
}
|
||||
var resp string
|
||||
if q.dialect.IndexPlaceholders {
|
||||
resp, _ = convertQuestionMarks(joinBuf.String(), argsLen+1)
|
||||
} else {
|
||||
resp = joinBuf.String()
|
||||
}
|
||||
fmt.Fprintf(buf, resp)
|
||||
strmangle.PutBuffer(joinBuf)
|
||||
}
|
||||
|
||||
where, whereArgs := whereClause(q, len(args)+1)
|
||||
buf.WriteString(where)
|
||||
if len(whereArgs) != 0 {
|
||||
args = append(args, whereArgs...)
|
||||
}
|
||||
|
||||
in, inArgs := inClause(q, len(args)+1)
|
||||
buf.WriteString(in)
|
||||
if len(inArgs) != 0 {
|
||||
args = append(args, inArgs...)
|
||||
}
|
||||
|
||||
writeModifiers(q, buf, &args)
|
||||
|
||||
buf.WriteByte(';')
|
||||
return buf, args
|
||||
}
|
||||
|
||||
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
||||
var args []interface{}
|
||||
buf := strmangle.GetBuffer()
|
||||
|
||||
buf.WriteString("DELETE FROM ")
|
||||
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
|
||||
|
||||
where, whereArgs := whereClause(q, 1)
|
||||
if len(whereArgs) != 0 {
|
||||
args = append(args, whereArgs)
|
||||
}
|
||||
buf.WriteString(where)
|
||||
|
||||
in, inArgs := inClause(q, len(args)+1)
|
||||
if len(inArgs) != 0 {
|
||||
args = append(args, inArgs...)
|
||||
}
|
||||
buf.WriteString(in)
|
||||
|
||||
writeModifiers(q, buf, &args)
|
||||
|
||||
buf.WriteByte(';')
|
||||
|
||||
return buf, args
|
||||
}
|
||||
|
||||
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
||||
buf := strmangle.GetBuffer()
|
||||
|
||||
buf.WriteString("UPDATE ")
|
||||
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
|
||||
|
||||
cols := make(sort.StringSlice, len(q.update))
|
||||
var args []interface{}
|
||||
|
||||
count := 0
|
||||
for name := range q.update {
|
||||
cols[count] = name
|
||||
count++
|
||||
}
|
||||
|
||||
cols.Sort()
|
||||
|
||||
for i := 0; i < len(cols); i++ {
|
||||
args = append(args, q.update[cols[i]])
|
||||
cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, cols[i])
|
||||
}
|
||||
|
||||
buf.WriteString(fmt.Sprintf(
|
||||
" SET (%s) = (%s)",
|
||||
strings.Join(cols, ", "),
|
||||
strmangle.Placeholders(q.dialect.IndexPlaceholders, len(cols), 1, 1)),
|
||||
)
|
||||
|
||||
where, whereArgs := whereClause(q, len(args)+1)
|
||||
if len(whereArgs) != 0 {
|
||||
args = append(args, whereArgs...)
|
||||
}
|
||||
buf.WriteString(where)
|
||||
|
||||
in, inArgs := inClause(q, len(args)+1)
|
||||
if len(inArgs) != 0 {
|
||||
args = append(args, inArgs...)
|
||||
}
|
||||
buf.WriteString(in)
|
||||
|
||||
writeModifiers(q, buf, &args)
|
||||
|
||||
buf.WriteByte(';')
|
||||
|
||||
return buf, args
|
||||
}
|
||||
|
||||
// BuildUpsertQueryMySQL builds a SQL statement string using the upsertData provided.
|
||||
func BuildUpsertQueryMySQL(dia Dialect, tableName string, update, whitelist []string) string {
|
||||
whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
|
||||
|
||||
buf := strmangle.GetBuffer()
|
||||
defer strmangle.PutBuffer(buf)
|
||||
|
||||
fmt.Fprintf(
|
||||
buf,
|
||||
"INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE ",
|
||||
tableName,
|
||||
strings.Join(whitelist, ", "),
|
||||
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
|
||||
)
|
||||
|
||||
for i, v := range update {
|
||||
if i != 0 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
|
||||
buf.WriteString(quoted)
|
||||
buf.WriteString(" = VALUES(")
|
||||
buf.WriteString(quoted)
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// BuildUpsertQueryPostgres builds a SQL statement string using the upsertData provided.
|
||||
func BuildUpsertQueryPostgres(dia Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string {
|
||||
conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict)
|
||||
whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
|
||||
ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret)
|
||||
|
||||
buf := strmangle.GetBuffer()
|
||||
defer strmangle.PutBuffer(buf)
|
||||
|
||||
fmt.Fprintf(
|
||||
buf,
|
||||
"INSERT INTO %s (%s) VALUES (%s) ON CONFLICT ",
|
||||
tableName,
|
||||
strings.Join(whitelist, ", "),
|
||||
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
|
||||
)
|
||||
|
||||
if !updateOnConflict || len(update) == 0 {
|
||||
buf.WriteString("DO NOTHING")
|
||||
} else {
|
||||
buf.WriteByte('(')
|
||||
buf.WriteString(strings.Join(conflict, ", "))
|
||||
buf.WriteString(") DO UPDATE SET ")
|
||||
|
||||
for i, v := range update {
|
||||
if i != 0 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
|
||||
buf.WriteString(quoted)
|
||||
buf.WriteString(" = EXCLUDED.")
|
||||
buf.WriteString(quoted)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ret) != 0 {
|
||||
buf.WriteString(" RETURNING ")
|
||||
buf.WriteString(strings.Join(ret, ", "))
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) {
|
||||
if len(q.groupBy) != 0 {
|
||||
fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", "))
|
||||
}
|
||||
|
||||
if len(q.having) != 0 {
|
||||
argsLen := len(*args)
|
||||
havingBuf := strmangle.GetBuffer()
|
||||
fmt.Fprintf(havingBuf, " HAVING ")
|
||||
for i, j := range q.having {
|
||||
if i > 0 {
|
||||
fmt.Fprintf(havingBuf, ", ")
|
||||
}
|
||||
fmt.Fprintf(havingBuf, j.clause)
|
||||
*args = append(*args, j.args...)
|
||||
}
|
||||
var resp string
|
||||
if q.dialect.IndexPlaceholders {
|
||||
resp, _ = convertQuestionMarks(havingBuf.String(), argsLen+1)
|
||||
} else {
|
||||
resp = havingBuf.String()
|
||||
}
|
||||
fmt.Fprintf(buf, resp)
|
||||
strmangle.PutBuffer(havingBuf)
|
||||
}
|
||||
|
||||
if len(q.orderBy) != 0 {
|
||||
buf.WriteString(" ORDER BY ")
|
||||
buf.WriteString(strings.Join(q.orderBy, ", "))
|
||||
}
|
||||
|
||||
if q.limit != 0 {
|
||||
fmt.Fprintf(buf, " LIMIT %d", q.limit)
|
||||
}
|
||||
if q.offset != 0 {
|
||||
fmt.Fprintf(buf, " OFFSET %d", q.offset)
|
||||
}
|
||||
|
||||
if len(q.forlock) != 0 {
|
||||
fmt.Fprintf(buf, " FOR %s", q.forlock)
|
||||
}
|
||||
}
|
||||
|
||||
func writeStars(q *Query) []string {
|
||||
cols := make([]string, len(q.from))
|
||||
for i, f := range q.from {
|
||||
toks := strings.Split(f, " ")
|
||||
if len(toks) == 1 {
|
||||
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, toks[0]))
|
||||
continue
|
||||
}
|
||||
|
||||
alias, name, ok := parseFromClause(toks)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(alias) != 0 {
|
||||
name = alias
|
||||
}
|
||||
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, name))
|
||||
}
|
||||
|
||||
return cols
|
||||
}
|
||||
|
||||
func writeAsStatements(q *Query) []string {
|
||||
cols := make([]string, len(q.selectCols))
|
||||
for i, col := range q.selectCols {
|
||||
if !rgxIdentifier.MatchString(col) {
|
||||
cols[i] = col
|
||||
continue
|
||||
}
|
||||
|
||||
toks := strings.Split(col, ".")
|
||||
if len(toks) == 1 {
|
||||
cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col)
|
||||
continue
|
||||
}
|
||||
|
||||
asParts := make([]string, len(toks))
|
||||
for j, tok := range toks {
|
||||
asParts[j] = strings.Trim(tok, `"`)
|
||||
}
|
||||
|
||||
cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col), strings.Join(asParts, "."))
|
||||
}
|
||||
|
||||
return cols
|
||||
}
|
||||
|
||||
// whereClause parses a where slice and converts it into a
|
||||
// single WHERE clause like:
|
||||
// WHERE (a=$1) AND (b=$2)
|
||||
//
|
||||
// startAt specifies what number placeholders start at
|
||||
func whereClause(q *Query, startAt int) (string, []interface{}) {
|
||||
if len(q.where) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
buf := strmangle.GetBuffer()
|
||||
defer strmangle.PutBuffer(buf)
|
||||
var args []interface{}
|
||||
|
||||
buf.WriteString(" WHERE ")
|
||||
for i, where := range q.where {
|
||||
if i != 0 {
|
||||
if where.orSeparator {
|
||||
buf.WriteString(" OR ")
|
||||
} else {
|
||||
buf.WriteString(" AND ")
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteString(fmt.Sprintf("(%s)", where.clause))
|
||||
args = append(args, where.args...)
|
||||
}
|
||||
|
||||
var resp string
|
||||
if q.dialect.IndexPlaceholders {
|
||||
resp, _ = convertQuestionMarks(buf.String(), startAt)
|
||||
} else {
|
||||
resp = buf.String()
|
||||
}
|
||||
|
||||
return resp, args
|
||||
}
|
||||
|
||||
// inClause parses an in slice and converts it into a
|
||||
// single IN clause, like:
|
||||
// WHERE ("a", "b") IN (($1,$2),($3,$4)).
|
||||
func inClause(q *Query, startAt int) (string, []interface{}) {
|
||||
if len(q.in) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
buf := strmangle.GetBuffer()
|
||||
defer strmangle.PutBuffer(buf)
|
||||
var args []interface{}
|
||||
|
||||
if len(q.where) == 0 {
|
||||
buf.WriteString(" WHERE ")
|
||||
}
|
||||
|
||||
for i, in := range q.in {
|
||||
ln := len(in.args)
|
||||
// We only prefix the OR and AND separators after the first
|
||||
// clause has been generated UNLESS there is already a where
|
||||
// clause that we have to add on to.
|
||||
if i != 0 || len(q.where) > 0 {
|
||||
if in.orSeparator {
|
||||
buf.WriteString(" OR ")
|
||||
} else {
|
||||
buf.WriteString(" AND ")
|
||||
}
|
||||
}
|
||||
|
||||
matches := rgxInClause.FindStringSubmatch(in.clause)
|
||||
// If we can't find any matches attempt a simple replace with 1 group.
|
||||
// Clauses that fit this criteria will not be able to contain ? in their
|
||||
// column name side, however if this case is being hit then the regexp
|
||||
// probably needs adjustment, or the user is passing in invalid clauses.
|
||||
if matches == nil {
|
||||
clause, count := convertInQuestionMarks(q.dialect.IndexPlaceholders, in.clause, startAt, 1, ln)
|
||||
buf.WriteString(clause)
|
||||
startAt = startAt + count
|
||||
} else {
|
||||
leftSide := strings.TrimSpace(matches[1])
|
||||
rightSide := strings.TrimSpace(matches[2])
|
||||
// If matches are found, we have to parse the left side (column side)
|
||||
// of the clause to determine how many columns they are using.
|
||||
// This number determines the groupAt for the convert function.
|
||||
cols := strings.Split(leftSide, ",")
|
||||
cols = strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, cols)
|
||||
groupAt := len(cols)
|
||||
|
||||
var leftClause string
|
||||
var leftCount int
|
||||
if q.dialect.IndexPlaceholders {
|
||||
leftClause, leftCount = convertQuestionMarks(strings.Join(cols, ","), startAt)
|
||||
} else {
|
||||
// Count the number of cols that are question marks, so we know
|
||||
// how much to offset convertInQuestionMarks by
|
||||
for _, v := range cols {
|
||||
if v == "?" {
|
||||
leftCount++
|
||||
}
|
||||
}
|
||||
leftClause = strings.Join(cols, ",")
|
||||
}
|
||||
rightClause, rightCount := convertInQuestionMarks(q.dialect.IndexPlaceholders, rightSide, startAt+leftCount, groupAt, ln-leftCount)
|
||||
buf.WriteString(leftClause)
|
||||
buf.WriteString(" IN ")
|
||||
buf.WriteString(rightClause)
|
||||
startAt = startAt + leftCount + rightCount
|
||||
}
|
||||
|
||||
args = append(args, in.args...)
|
||||
}
|
||||
|
||||
return buf.String(), args
|
||||
}
|
||||
|
||||
// convertInQuestionMarks finds the first unescaped occurrence of ? and swaps it
|
||||
// with a list of numbered placeholders, starting at startAt.
|
||||
// It uses groupAt to determine how many placeholders should be in each group,
|
||||
// for example, groupAt 2 would result in: (($1,$2),($3,$4))
|
||||
// and groupAt 1 would result in ($1,$2,$3,$4)
|
||||
func convertInQuestionMarks(indexPlaceholders bool, clause string, startAt, groupAt, total int) (string, int) {
|
||||
if startAt == 0 || len(clause) == 0 {
|
||||
panic("Not a valid start number.")
|
||||
}
|
||||
|
||||
paramBuf := strmangle.GetBuffer()
|
||||
defer strmangle.PutBuffer(paramBuf)
|
||||
|
||||
foundAt := -1
|
||||
for i := 0; i < len(clause); i++ {
|
||||
if (clause[i] == '?' && i == 0) || (clause[i] == '?' && clause[i-1] != '\\') {
|
||||
foundAt = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundAt == -1 {
|
||||
return strings.Replace(clause, `\?`, "?", -1), 0
|
||||
}
|
||||
|
||||
paramBuf.WriteString(clause[:foundAt])
|
||||
paramBuf.WriteByte('(')
|
||||
paramBuf.WriteString(strmangle.Placeholders(indexPlaceholders, total, startAt, groupAt))
|
||||
paramBuf.WriteByte(')')
|
||||
paramBuf.WriteString(clause[foundAt+1:])
|
||||
|
||||
// Remove all backslashes from escaped question-marks
|
||||
ret := strings.Replace(paramBuf.String(), `\?`, "?", -1)
|
||||
return ret, total
|
||||
}
|
||||
|
||||
// convertQuestionMarks converts each occurrence of ? with $<number>
|
||||
// where <number> is an incrementing digit starting at startAt.
|
||||
// If question-mark (?) is escaped using back-slash (\), it will be ignored.
|
||||
func convertQuestionMarks(clause string, startAt int) (string, int) {
|
||||
if startAt == 0 {
|
||||
panic("Not a valid start number.")
|
||||
}
|
||||
|
||||
paramBuf := strmangle.GetBuffer()
|
||||
defer strmangle.PutBuffer(paramBuf)
|
||||
paramIndex := 0
|
||||
total := 0
|
||||
|
||||
for {
|
||||
if paramIndex >= len(clause) {
|
||||
break
|
||||
}
|
||||
|
||||
clause = clause[paramIndex:]
|
||||
paramIndex = strings.IndexByte(clause, '?')
|
||||
|
||||
if paramIndex == -1 {
|
||||
paramBuf.WriteString(clause)
|
||||
break
|
||||
}
|
||||
|
||||
escapeIndex := strings.Index(clause, `\?`)
|
||||
if escapeIndex != -1 && paramIndex > escapeIndex {
|
||||
paramBuf.WriteString(clause[:escapeIndex] + "?")
|
||||
paramIndex++
|
||||
continue
|
||||
}
|
||||
|
||||
paramBuf.WriteString(clause[:paramIndex] + fmt.Sprintf("$%d", startAt))
|
||||
total++
|
||||
startAt++
|
||||
paramIndex++
|
||||
}
|
||||
|
||||
return paramBuf.String(), total
|
||||
}
|
||||
|
||||
// parseFromClause will parse something that looks like
|
||||
// a
|
||||
// a b
|
||||
// a as b
|
||||
func parseFromClause(toks []string) (alias, name string, ok bool) {
|
||||
if len(toks) > 3 {
|
||||
toks = toks[:3]
|
||||
}
|
||||
|
||||
sawIdent, sawAs := false, false
|
||||
for _, tok := range toks {
|
||||
if t := strings.ToLower(tok); sawIdent && t == "as" {
|
||||
sawAs = true
|
||||
continue
|
||||
} else if sawIdent && t == "on" {
|
||||
break
|
||||
}
|
||||
|
||||
if !rgxIdentifier.MatchString(tok) {
|
||||
break
|
||||
}
|
||||
|
||||
if sawIdent || sawAs {
|
||||
alias = strings.Trim(tok, `"`)
|
||||
break
|
||||
}
|
||||
|
||||
name = strings.Trim(tok, `"`)
|
||||
sawIdent = true
|
||||
ok = true
|
||||
}
|
||||
|
||||
return alias, name, ok
|
||||
}
|
547
queries/query_builders_test.go
Normal file
547
queries/query_builders_test.go
Normal file
|
@ -0,0 +1,547 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
var writeGoldenFiles = flag.Bool(
|
||||
"test.golden",
|
||||
false,
|
||||
"Write golden files.",
|
||||
)
|
||||
|
||||
func TestBuildQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
q *Query
|
||||
args []interface{}
|
||||
}{
|
||||
{&Query{from: []string{"t"}}, nil},
|
||||
{&Query{from: []string{"q"}, limit: 5, offset: 6}, nil},
|
||||
{&Query{from: []string{"q"}, orderBy: []string{"a ASC", "b DESC"}}, nil},
|
||||
{&Query{from: []string{"t"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
|
||||
{&Query{from: []string{"a", "b"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
|
||||
{&Query{
|
||||
selectCols: []string{"a.happy", "r.fun", "q"},
|
||||
from: []string{"happiness as a"},
|
||||
joins: []join{{clause: "rainbows r on a.id = r.happy_id"}},
|
||||
}, nil},
|
||||
{&Query{
|
||||
from: []string{"happiness as a"},
|
||||
joins: []join{{clause: "rainbows r on a.id = r.happy_id"}},
|
||||
}, nil},
|
||||
{&Query{
|
||||
from: []string{"videos"},
|
||||
joins: []join{{
|
||||
clause: "(select id from users where deleted = ?) u on u.id = videos.user_id",
|
||||
args: []interface{}{true},
|
||||
}},
|
||||
where: []where{{clause: "videos.deleted = ?", args: []interface{}{false}}},
|
||||
}, []interface{}{true, false}},
|
||||
{&Query{
|
||||
from: []string{"a"},
|
||||
groupBy: []string{"id", "name"},
|
||||
where: []where{
|
||||
{clause: "a=? or b=?", args: []interface{}{1, 2}},
|
||||
{clause: "c=?", args: []interface{}{3}},
|
||||
},
|
||||
having: []having{
|
||||
{clause: "id <> ?", args: []interface{}{1}},
|
||||
{clause: "length(name, ?) > ?", args: []interface{}{"utf8", 5}},
|
||||
},
|
||||
}, []interface{}{1, 2, 3, 1, "utf8", 5}},
|
||||
{&Query{
|
||||
delete: true,
|
||||
from: []string{"thing happy", `upset as "sad"`, "fun", "thing as stuff", `"angry" as mad`},
|
||||
where: []where{
|
||||
{clause: "a=?", args: []interface{}{}},
|
||||
{clause: "b=?", args: []interface{}{}},
|
||||
{clause: "c=?", args: []interface{}{}},
|
||||
},
|
||||
}, nil},
|
||||
{&Query{
|
||||
delete: true,
|
||||
from: []string{"thing happy", `upset as "sad"`, "fun", "thing as stuff", `"angry" as mad`},
|
||||
where: []where{
|
||||
{clause: "(id=? and thing=?) or stuff=?", args: []interface{}{}},
|
||||
},
|
||||
limit: 5,
|
||||
}, nil},
|
||||
{&Query{
|
||||
from: []string{"thing happy", `"fun"`, `stuff`},
|
||||
update: map[string]interface{}{
|
||||
"col1": 1,
|
||||
`"col2"`: 2,
|
||||
`"fun".col3`: 3,
|
||||
},
|
||||
where: []where{
|
||||
{clause: "aa=? or bb=? or cc=?", orSeparator: true, args: []interface{}{4, 5, 6}},
|
||||
{clause: "dd=? or ee=? or ff=? and gg=?", args: []interface{}{7, 8, 9, 10}},
|
||||
},
|
||||
limit: 5,
|
||||
}, []interface{}{2, 3, 1, 4, 5, 6, 7, 8, 9, 10}},
|
||||
{&Query{from: []string{"cats"}, joins: []join{{JoinInner, "dogs d on d.cat_id = cats.id", nil}}}, nil},
|
||||
{&Query{from: []string{"cats c"}, joins: []join{{JoinInner, "dogs d on d.cat_id = cats.id", nil}}}, nil},
|
||||
{&Query{from: []string{"cats as c"}, joins: []join{{JoinInner, "dogs d on d.cat_id = cats.id", nil}}}, nil},
|
||||
{&Query{from: []string{"cats as c", "dogs as d"}, joins: []join{{JoinInner, "dogs d on d.cat_id = cats.id", nil}}}, nil},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i))
|
||||
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
|
||||
out, args := buildQuery(test.q)
|
||||
|
||||
if *writeGoldenFiles {
|
||||
err := ioutil.WriteFile(filename, []byte(out), 0664)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write golden file %s: %s\n", filename, err)
|
||||
}
|
||||
t.Logf("wrote golden file: %s\n", filename)
|
||||
continue
|
||||
}
|
||||
|
||||
byt, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read golden file %q: %v", filename, err)
|
||||
}
|
||||
|
||||
if string(bytes.TrimSpace(byt)) != out {
|
||||
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, byt, out)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(args, test.args) {
|
||||
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, spew.Sdump(test.args), spew.Sdump(args))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteStars(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
In Query
|
||||
Out []string
|
||||
}{
|
||||
{
|
||||
In: Query{from: []string{`a`}},
|
||||
Out: []string{`"a".*`},
|
||||
},
|
||||
{
|
||||
In: Query{from: []string{`a as b`}},
|
||||
Out: []string{`"b".*`},
|
||||
},
|
||||
{
|
||||
In: Query{from: []string{`a as b`, `c`}},
|
||||
Out: []string{`"b".*`, `"c".*`},
|
||||
},
|
||||
{
|
||||
In: Query{from: []string{`a as b`, `c as d`}},
|
||||
Out: []string{`"b".*`, `"d".*`},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
test.In.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
|
||||
selects := writeStars(&test.In)
|
||||
if !reflect.DeepEqual(selects, test.Out) {
|
||||
t.Errorf("writeStar test fail %d\nwant: %v\ngot: %v", i, test.Out, selects)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhereClause(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
q Query
|
||||
expect string
|
||||
}{
|
||||
// Or("a=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{{clause: "a=?", orSeparator: true}},
|
||||
},
|
||||
expect: " WHERE (a=$1)",
|
||||
},
|
||||
// Where("a=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{{clause: "a=?"}},
|
||||
},
|
||||
expect: " WHERE (a=$1)",
|
||||
},
|
||||
// Where("(a=?)")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{{clause: "(a=?)"}},
|
||||
},
|
||||
expect: " WHERE ((a=$1))",
|
||||
},
|
||||
// Where("((a=? OR b=?))")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{{clause: "((a=? OR b=?))"}},
|
||||
},
|
||||
expect: " WHERE (((a=$1 OR b=$2)))",
|
||||
},
|
||||
// Where("(a=?)", Or("(b=?)")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{
|
||||
{clause: "(a=?)"},
|
||||
{clause: "(b=?)", orSeparator: true},
|
||||
},
|
||||
},
|
||||
expect: " WHERE ((a=$1)) OR ((b=$2))",
|
||||
},
|
||||
// Where("a=? OR b=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{{clause: "a=? OR b=?"}},
|
||||
},
|
||||
expect: " WHERE (a=$1 OR b=$2)",
|
||||
},
|
||||
// Where("a=?"), Where("b=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{{clause: "a=?"}, {clause: "b=?"}},
|
||||
},
|
||||
expect: " WHERE (a=$1) AND (b=$2)",
|
||||
},
|
||||
// Where("(a=? AND b=?) OR c=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{{clause: "(a=? AND b=?) OR c=?"}},
|
||||
},
|
||||
expect: " WHERE ((a=$1 AND b=$2) OR c=$3)",
|
||||
},
|
||||
// Where("a=? OR b=?"), Where("c=? OR d=? OR e=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{
|
||||
{clause: "(a=? OR b=?)"},
|
||||
{clause: "(c=? OR d=? OR e=?)"},
|
||||
},
|
||||
},
|
||||
expect: " WHERE ((a=$1 OR b=$2)) AND ((c=$3 OR d=$4 OR e=$5))",
|
||||
},
|
||||
// Where("(a=? AND b=?) OR (c=? AND d=? AND e=?) OR f=? OR f=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{
|
||||
{clause: "(a=? AND b=?) OR (c=? AND d=? AND e=?) OR f=? OR g=?"},
|
||||
},
|
||||
},
|
||||
expect: " WHERE ((a=$1 AND b=$2) OR (c=$3 AND d=$4 AND e=$5) OR f=$6 OR g=$7)",
|
||||
},
|
||||
// Where("(a=? AND b=?) OR (c=? AND d=? OR e=?) OR f=? OR g=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{
|
||||
{clause: "(a=? AND b=?) OR (c=? AND d=? OR e=?) OR f=? OR g=?"},
|
||||
},
|
||||
},
|
||||
expect: " WHERE ((a=$1 AND b=$2) OR (c=$3 AND d=$4 OR e=$5) OR f=$6 OR g=$7)",
|
||||
},
|
||||
// Where("a=? or b=?"), Or("c=? and d=?"), Or("e=? or f=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{
|
||||
{clause: "a=? or b=?", orSeparator: true},
|
||||
{clause: "c=? and d=?", orSeparator: true},
|
||||
{clause: "e=? or f=?", orSeparator: true},
|
||||
},
|
||||
},
|
||||
expect: " WHERE (a=$1 or b=$2) OR (c=$3 and d=$4) OR (e=$5 or f=$6)",
|
||||
},
|
||||
// Where("a=? or b=?"), Or("c=? and d=?"), Or("e=? or f=?")
|
||||
{
|
||||
q: Query{
|
||||
where: []where{
|
||||
{clause: "a=? or b=?"},
|
||||
{clause: "c=? and d=?", orSeparator: true},
|
||||
{clause: "e=? or f=?"},
|
||||
},
|
||||
},
|
||||
expect: " WHERE (a=$1 or b=$2) OR (c=$3 and d=$4) AND (e=$5 or f=$6)",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
|
||||
result, _ := whereClause(&test.q, 1)
|
||||
if result != test.expect {
|
||||
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInClause(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
q Query
|
||||
expect string
|
||||
args []interface{}
|
||||
}{
|
||||
{
|
||||
q: Query{
|
||||
in: []in{{clause: "a in ?", args: []interface{}{}, orSeparator: true}},
|
||||
},
|
||||
expect: ` WHERE "a" IN ()`,
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{{clause: "a in ?", args: []interface{}{1}, orSeparator: true}},
|
||||
},
|
||||
expect: ` WHERE "a" IN ($1)`,
|
||||
args: []interface{}{1},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{{clause: "a in ?", args: []interface{}{1, 2, 3}}},
|
||||
},
|
||||
expect: ` WHERE "a" IN ($1,$2,$3)`,
|
||||
args: []interface{}{1, 2, 3},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{{clause: "? in ?", args: []interface{}{1, 2, 3}}},
|
||||
},
|
||||
expect: " WHERE $1 IN ($2,$3)",
|
||||
args: []interface{}{1, 2, 3},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{{clause: "( ? , ? ) in ( ? )", orSeparator: true, args: []interface{}{"a", "b", 1, 2, 3, 4}}},
|
||||
},
|
||||
expect: " WHERE ( $1 , $2 ) IN ( (($3,$4),($5,$6)) )",
|
||||
args: []interface{}{"a", "b", 1, 2, 3, 4},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{{clause: `("a")in(?)`, orSeparator: true, args: []interface{}{1, 2, 3}}},
|
||||
},
|
||||
expect: ` WHERE ("a") IN (($1,$2,$3))`,
|
||||
args: []interface{}{1, 2, 3},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{{clause: `("a")in?`, args: []interface{}{1}}},
|
||||
},
|
||||
expect: ` WHERE ("a") IN ($1)`,
|
||||
args: []interface{}{1},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
where: []where{
|
||||
{clause: "a=?", args: []interface{}{1}},
|
||||
},
|
||||
in: []in{
|
||||
{clause: `?,?,"name" in ?`, orSeparator: true, args: []interface{}{"c", "d", 3, 4, 5, 6, 7, 8}},
|
||||
{clause: `?,?,"name" in ?`, orSeparator: true, args: []interface{}{"e", "f", 9, 10, 11, 12, 13, 14}},
|
||||
},
|
||||
},
|
||||
expect: ` OR $1,$2,"name" IN (($3,$4,$5),($6,$7,$8)) OR $9,$10,"name" IN (($11,$12,$13),($14,$15,$16))`,
|
||||
args: []interface{}{"c", "d", 3, 4, 5, 6, 7, 8, "e", "f", 9, 10, 11, 12, 13, 14},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{
|
||||
{clause: `("a")in`, args: []interface{}{1}},
|
||||
{clause: `("a")in?`, orSeparator: true, args: []interface{}{1}},
|
||||
},
|
||||
},
|
||||
expect: ` WHERE ("a")in OR ("a") IN ($1)`,
|
||||
args: []interface{}{1, 1},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{
|
||||
{clause: `\?,\? in \?`, args: []interface{}{1}},
|
||||
{clause: `\?,\?in \?`, orSeparator: true, args: []interface{}{1}},
|
||||
},
|
||||
},
|
||||
expect: ` WHERE ?,? IN ? OR ?,? IN ?`,
|
||||
args: []interface{}{1, 1},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{
|
||||
{clause: `("a")in`, args: []interface{}{1}},
|
||||
{clause: `("a") in thing`, args: []interface{}{1, 2, 3}},
|
||||
{clause: `("a")in?`, orSeparator: true, args: []interface{}{4, 5, 6}},
|
||||
},
|
||||
},
|
||||
expect: ` WHERE ("a")in AND ("a") IN thing OR ("a") IN ($1,$2,$3)`,
|
||||
args: []interface{}{1, 1, 2, 3, 4, 5, 6},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{
|
||||
{clause: `("a")in?`, orSeparator: true, args: []interface{}{4, 5, 6}},
|
||||
{clause: `("a") in thing`, args: []interface{}{1, 2, 3}},
|
||||
{clause: `("a")in`, args: []interface{}{1}},
|
||||
},
|
||||
},
|
||||
expect: ` WHERE ("a") IN ($1,$2,$3) AND ("a") IN thing AND ("a")in`,
|
||||
args: []interface{}{4, 5, 6, 1, 2, 3, 1},
|
||||
},
|
||||
{
|
||||
q: Query{
|
||||
in: []in{
|
||||
{clause: `("a")in?`, orSeparator: true, args: []interface{}{4, 5, 6}},
|
||||
{clause: `("a")in`, args: []interface{}{1}},
|
||||
{clause: `("a") in thing`, args: []interface{}{1, 2, 3}},
|
||||
},
|
||||
},
|
||||
expect: ` WHERE ("a") IN ($1,$2,$3) AND ("a")in AND ("a") IN thing`,
|
||||
args: []interface{}{4, 5, 6, 1, 1, 2, 3},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
|
||||
result, args := inClause(&test.q, 1)
|
||||
if result != test.expect {
|
||||
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result)
|
||||
}
|
||||
if !reflect.DeepEqual(args, test.args) {
|
||||
t.Errorf("%d) Mismatch between expected args:\n%#v\n%#v\n", i, test.args, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertQuestionMarks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
clause string
|
||||
start int
|
||||
expect string
|
||||
count int
|
||||
}{
|
||||
{clause: "hello friend", start: 1, expect: "hello friend", count: 0},
|
||||
{clause: "thing=?", start: 2, expect: "thing=$2", count: 1},
|
||||
{clause: "thing=? and stuff=? and happy=?", start: 2, expect: "thing=$2 and stuff=$3 and happy=$4", count: 3},
|
||||
{clause: `thing \? stuff`, start: 2, expect: `thing ? stuff`, count: 0},
|
||||
{clause: `thing \? stuff and happy \? fun`, start: 2, expect: `thing ? stuff and happy ? fun`, count: 0},
|
||||
{
|
||||
clause: `thing \? stuff ? happy \? and mad ? fun \? \? \?`,
|
||||
start: 2,
|
||||
expect: `thing ? stuff $2 happy ? and mad $3 fun ? ? ?`,
|
||||
count: 2,
|
||||
},
|
||||
{
|
||||
clause: `thing ? stuff ? happy \? fun \? ? ?`,
|
||||
start: 1,
|
||||
expect: `thing $1 stuff $2 happy ? fun ? $3 $4`,
|
||||
count: 4,
|
||||
},
|
||||
{clause: `?`, start: 1, expect: `$1`, count: 1},
|
||||
{clause: `???`, start: 1, expect: `$1$2$3`, count: 3},
|
||||
{clause: `\?`, start: 1, expect: `?`},
|
||||
{clause: `\?\?\?`, start: 1, expect: `???`},
|
||||
{clause: `\??\??\??`, start: 1, expect: `?$1?$2?$3`, count: 3},
|
||||
{clause: `?\??\??\?`, start: 1, expect: `$1?$2?$3?`, count: 3},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
res, count := convertQuestionMarks(test.clause, test.start)
|
||||
if res != test.expect {
|
||||
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res)
|
||||
}
|
||||
if count != test.count {
|
||||
t.Errorf("%d) Expected count %d, got %d", i, test.count, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertInQuestionMarks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
clause string
|
||||
start int
|
||||
group int
|
||||
total int
|
||||
expect string
|
||||
}{
|
||||
{clause: "?", expect: "(($1,$2,$3),($4,$5,$6),($7,$8,$9))", start: 1, total: 9, group: 3},
|
||||
{clause: "?", expect: "(($2,$3),($4))", start: 2, total: 3, group: 2},
|
||||
{clause: "hello friend", start: 1, expect: "hello friend", total: 0, group: 1},
|
||||
{clause: "thing ? thing", start: 2, expect: "thing ($2,$3) thing", total: 2, group: 1},
|
||||
{clause: "thing?thing", start: 2, expect: "thing($2)thing", total: 1, group: 1},
|
||||
{clause: `thing \? stuff`, start: 2, expect: `thing ? stuff`, total: 0, group: 1},
|
||||
{clause: `thing \? stuff and happy \? fun`, start: 2, expect: `thing ? stuff and happy ? fun`, total: 0, group: 1},
|
||||
{clause: "thing ? thing ? thing", start: 1, expect: "thing ($1,$2,$3) thing ? thing", total: 3, group: 1},
|
||||
{clause: `?`, start: 1, expect: `($1)`, total: 1, group: 1},
|
||||
{clause: `???`, start: 1, expect: `($1,$2,$3)??`, total: 3, group: 1},
|
||||
{clause: `\?`, start: 1, expect: `?`, total: 0, group: 1},
|
||||
{clause: `\?\?\?`, start: 1, expect: `???`, total: 0, group: 1},
|
||||
{clause: `\??\??\??`, start: 1, expect: `?($1,$2,$3)????`, total: 3, group: 1},
|
||||
{clause: `?\??\??\?`, start: 1, expect: `($1,$2,$3)?????`, total: 3, group: 1},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
res, count := convertInQuestionMarks(true, test.clause, test.start, test.group, test.total)
|
||||
if res != test.expect {
|
||||
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res)
|
||||
}
|
||||
if count != test.total {
|
||||
t.Errorf("%d) Expected %d, got %d", i, test.total, count)
|
||||
}
|
||||
}
|
||||
|
||||
res, count := convertInQuestionMarks(false, "?", 1, 3, 9)
|
||||
if res != "((?,?,?),(?,?,?),(?,?,?))" {
|
||||
t.Errorf("Mismatch between expected and result: %s", res)
|
||||
}
|
||||
if count != 9 {
|
||||
t.Errorf("Expected 9 results, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAsStatements(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
query := Query{
|
||||
selectCols: []string{
|
||||
`a`,
|
||||
`a.fun`,
|
||||
`"b"."fun"`,
|
||||
`"b".fun`,
|
||||
`b."fun"`,
|
||||
`a.clown.run`,
|
||||
`COUNT(a)`,
|
||||
},
|
||||
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
|
||||
}
|
||||
|
||||
expect := []string{
|
||||
`"a"`,
|
||||
`"a"."fun" as "a.fun"`,
|
||||
`"b"."fun" as "b.fun"`,
|
||||
`"b"."fun" as "b.fun"`,
|
||||
`"b"."fun" as "b.fun"`,
|
||||
`"a"."clown"."run" as "a.clown.run"`,
|
||||
`COUNT(a)`,
|
||||
}
|
||||
|
||||
gots := writeAsStatements(&query)
|
||||
|
||||
for i, got := range gots {
|
||||
if expect[i] != got {
|
||||
t.Errorf(`%d) want: %s, got: %s`, i, expect[i], got)
|
||||
}
|
||||
}
|
||||
}
|
438
queries/query_test.go
Normal file
438
queries/query_test.go
Normal file
|
@ -0,0 +1,438 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
SetLimit(q, 10)
|
||||
|
||||
expect := 10
|
||||
if q.limit != expect {
|
||||
t.Errorf("Expected %d, got %d", expect, q.limit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetOffset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
SetOffset(q, 10)
|
||||
|
||||
expect := 10
|
||||
if q.offset != expect {
|
||||
t.Errorf("Expected %d, got %d", expect, q.offset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetSQL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
SetSQL(q, "select * from thing", 5, 3)
|
||||
|
||||
if len(q.plainSQL.args) != 2 {
|
||||
t.Errorf("Expected len 2, got %d", len(q.plainSQL.args))
|
||||
}
|
||||
|
||||
if q.plainSQL.sql != "select * from thing" {
|
||||
t.Errorf("Was not expected string, got %s", q.plainSQL.sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLoad(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
SetLoad(q, "one", "two")
|
||||
|
||||
if len(q.load) != 2 {
|
||||
t.Errorf("Expected len 2, got %d", len(q.load))
|
||||
}
|
||||
|
||||
if q.load[0] != "one" || q.load[1] != "two" {
|
||||
t.Errorf("Was not expected string, got %s", q.load)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendWhere(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
expect := "x > $1 AND y > $2"
|
||||
AppendWhere(q, expect, 5, 3)
|
||||
AppendWhere(q, expect, 5, 3)
|
||||
|
||||
if len(q.where) != 2 {
|
||||
t.Errorf("%#v", q.where)
|
||||
}
|
||||
|
||||
if q.where[0].clause != expect || q.where[1].clause != expect {
|
||||
t.Errorf("Expected %s, got %#v", expect, q.where)
|
||||
}
|
||||
|
||||
if len(q.where[0].args) != 2 || len(q.where[0].args) != 2 {
|
||||
t.Errorf("arg length wrong: %#v", q.where)
|
||||
}
|
||||
|
||||
if q.where[0].args[0].(int) != 5 || q.where[0].args[1].(int) != 3 {
|
||||
t.Errorf("args wrong: %#v", q.where)
|
||||
}
|
||||
|
||||
q.where = []where{{clause: expect, args: []interface{}{5, 3}}}
|
||||
if q.where[0].clause != expect {
|
||||
t.Errorf("Expected %s, got %v", expect, q.where)
|
||||
}
|
||||
|
||||
if len(q.where[0].args) != 2 {
|
||||
t.Errorf("Expected %d args, got %d", 2, len(q.where[0].args))
|
||||
}
|
||||
|
||||
if q.where[0].args[0].(int) != 5 || q.where[0].args[1].(int) != 3 {
|
||||
t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.where[0].args)
|
||||
}
|
||||
|
||||
if len(q.where) != 1 {
|
||||
t.Errorf("%#v", q.where)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLastWhereAsOr(t *testing.T) {
|
||||
t.Parallel()
|
||||
q := &Query{}
|
||||
|
||||
AppendWhere(q, "")
|
||||
|
||||
if q.where[0].orSeparator {
|
||||
t.Errorf("Do not want or separator")
|
||||
}
|
||||
|
||||
SetLastWhereAsOr(q)
|
||||
|
||||
if len(q.where) != 1 {
|
||||
t.Errorf("Want len 1")
|
||||
}
|
||||
if !q.where[0].orSeparator {
|
||||
t.Errorf("Want or separator")
|
||||
}
|
||||
|
||||
AppendWhere(q, "")
|
||||
SetLastWhereAsOr(q)
|
||||
|
||||
if len(q.where) != 2 {
|
||||
t.Errorf("Want len 2")
|
||||
}
|
||||
if q.where[0].orSeparator != true {
|
||||
t.Errorf("Expected true")
|
||||
}
|
||||
if q.where[1].orSeparator != true {
|
||||
t.Errorf("Expected true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendIn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
expect := "col IN ?"
|
||||
AppendIn(q, expect, 5, 3)
|
||||
AppendIn(q, expect, 5, 3)
|
||||
|
||||
if len(q.in) != 2 {
|
||||
t.Errorf("%#v", q.in)
|
||||
}
|
||||
|
||||
if q.in[0].clause != expect || q.in[1].clause != expect {
|
||||
t.Errorf("Expected %s, got %#v", expect, q.in)
|
||||
}
|
||||
|
||||
if len(q.in[0].args) != 2 || len(q.in[0].args) != 2 {
|
||||
t.Errorf("arg length wrong: %#v", q.in)
|
||||
}
|
||||
|
||||
if q.in[0].args[0].(int) != 5 || q.in[0].args[1].(int) != 3 {
|
||||
t.Errorf("args wrong: %#v", q.in)
|
||||
}
|
||||
|
||||
q.in = []in{{clause: expect, args: []interface{}{5, 3}}}
|
||||
if q.in[0].clause != expect {
|
||||
t.Errorf("Expected %s, got %v", expect, q.in)
|
||||
}
|
||||
|
||||
if len(q.in[0].args) != 2 {
|
||||
t.Errorf("Expected %d args, got %d", 2, len(q.in[0].args))
|
||||
}
|
||||
|
||||
if q.in[0].args[0].(int) != 5 || q.in[0].args[1].(int) != 3 {
|
||||
t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.in[0].args)
|
||||
}
|
||||
|
||||
if len(q.in) != 1 {
|
||||
t.Errorf("%#v", q.in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLastInAsOr(t *testing.T) {
|
||||
t.Parallel()
|
||||
q := &Query{}
|
||||
|
||||
AppendIn(q, "")
|
||||
|
||||
if q.in[0].orSeparator {
|
||||
t.Errorf("Do not want or separator")
|
||||
}
|
||||
|
||||
SetLastInAsOr(q)
|
||||
|
||||
if len(q.in) != 1 {
|
||||
t.Errorf("Want len 1")
|
||||
}
|
||||
if !q.in[0].orSeparator {
|
||||
t.Errorf("Want or separator")
|
||||
}
|
||||
|
||||
AppendIn(q, "")
|
||||
SetLastInAsOr(q)
|
||||
|
||||
if len(q.in) != 2 {
|
||||
t.Errorf("Want len 2")
|
||||
}
|
||||
if q.in[0].orSeparator != true {
|
||||
t.Errorf("Expected true")
|
||||
}
|
||||
if q.in[1].orSeparator != true {
|
||||
t.Errorf("Expected true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendGroupBy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
expect := "col1, col2"
|
||||
AppendGroupBy(q, expect)
|
||||
AppendGroupBy(q, expect)
|
||||
|
||||
if len(q.groupBy) != 2 && (q.groupBy[0] != expect || q.groupBy[1] != expect) {
|
||||
t.Errorf("Expected %s, got %s %s", expect, q.groupBy[0], q.groupBy[1])
|
||||
}
|
||||
|
||||
q.groupBy = []string{expect}
|
||||
if len(q.groupBy) != 1 && q.groupBy[0] != expect {
|
||||
t.Errorf("Expected %s, got %s", expect, q.groupBy[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendOrderBy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
expect := "col1 desc, col2 asc"
|
||||
AppendOrderBy(q, expect)
|
||||
AppendOrderBy(q, expect)
|
||||
|
||||
if len(q.orderBy) != 2 && (q.orderBy[0] != expect || q.orderBy[1] != expect) {
|
||||
t.Errorf("Expected %s, got %s %s", expect, q.orderBy[0], q.orderBy[1])
|
||||
}
|
||||
|
||||
q.orderBy = []string{"col1 desc, col2 asc"}
|
||||
if len(q.orderBy) != 1 && q.orderBy[0] != expect {
|
||||
t.Errorf("Expected %s, got %s", expect, q.orderBy[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendHaving(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
expect := "count(orders.order_id) > ?"
|
||||
AppendHaving(q, expect, 10)
|
||||
AppendHaving(q, expect, 10)
|
||||
|
||||
if len(q.having) != 2 {
|
||||
t.Errorf("Expected 2, got %d", len(q.having))
|
||||
}
|
||||
|
||||
if q.having[0].clause != expect || q.having[1].clause != expect {
|
||||
t.Errorf("Expected %s, got %s %s", expect, q.having[0].clause, q.having[1].clause)
|
||||
}
|
||||
|
||||
if q.having[0].args[0] != 10 || q.having[1].args[0] != 10 {
|
||||
t.Errorf("Expected %v, got %v %v", 10, q.having[0].args[0], q.having[1].args[0])
|
||||
}
|
||||
|
||||
q.having = []having{{clause: expect, args: []interface{}{10}}}
|
||||
if len(q.having) != 1 && (q.having[0].clause != expect || q.having[0].args[0] != 10) {
|
||||
t.Errorf("Expected %s, got %s %v", expect, q.having[0], q.having[0].args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
AppendFrom(q, "videos a", "orders b")
|
||||
AppendFrom(q, "videos a", "orders b")
|
||||
|
||||
expect := []string{"videos a", "orders b", "videos a", "orders b"}
|
||||
if !reflect.DeepEqual(q.from, expect) {
|
||||
t.Errorf("Expected %s, got %s", expect, q.from)
|
||||
}
|
||||
|
||||
SetFrom(q, "videos a", "orders b")
|
||||
if !reflect.DeepEqual(q.from, expect[:2]) {
|
||||
t.Errorf("Expected %s, got %s", expect, q.from)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetSelect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{selectCols: []string{"hello"}}
|
||||
SetSelect(q, nil)
|
||||
|
||||
if q.selectCols != nil {
|
||||
t.Errorf("want nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
SetCount(q)
|
||||
|
||||
if q.count != true {
|
||||
t.Errorf("got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
SetUpdate(q, map[string]interface{}{"test": 5})
|
||||
|
||||
if q.update["test"] != 5 {
|
||||
t.Errorf("Wrong update, got %v", q.update)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
SetDelete(q)
|
||||
|
||||
if q.delete != true {
|
||||
t.Errorf("Expected %t, got %t", true, q.delete)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetExecutor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
d := &sql.DB{}
|
||||
SetExecutor(q, d)
|
||||
|
||||
if q.executor != d {
|
||||
t.Errorf("Expected executor to get set to d, but was: %#v", q.executor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendSelect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
AppendSelect(q, "col1", "col2")
|
||||
AppendSelect(q, "col1", "col2")
|
||||
|
||||
if len(q.selectCols) != 4 {
|
||||
t.Errorf("Expected selectCols len 4, got %d", len(q.selectCols))
|
||||
}
|
||||
|
||||
if q.selectCols[0] != `col1` && q.selectCols[1] != `col2` {
|
||||
t.Errorf("select cols value mismatch: %#v", q.selectCols)
|
||||
}
|
||||
if q.selectCols[2] != `col1` && q.selectCols[3] != `col2` {
|
||||
t.Errorf("select cols value mismatch: %#v", q.selectCols)
|
||||
}
|
||||
|
||||
q.selectCols = []string{"col1", "col2"}
|
||||
if q.selectCols[0] != `col1` && q.selectCols[1] != `col2` {
|
||||
t.Errorf("select cols value mismatch: %#v", q.selectCols)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := SQL(&sql.DB{}, "thing", 5)
|
||||
if q.plainSQL.sql != "thing" {
|
||||
t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql)
|
||||
}
|
||||
if q.plainSQL.args[0].(int) != 5 {
|
||||
t.Errorf("Expected 5, got %v", q.plainSQL.args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLG(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := SQLG("thing", 5)
|
||||
if q.plainSQL.sql != "thing" {
|
||||
t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql)
|
||||
}
|
||||
if q.plainSQL.args[0].(int) != 5 {
|
||||
t.Errorf("Expected 5, got %v", q.plainSQL.args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendInnerJoin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &Query{}
|
||||
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
|
||||
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
|
||||
|
||||
if len(q.joins) != 2 {
|
||||
t.Errorf("Expected len 1, got %d", len(q.joins))
|
||||
}
|
||||
|
||||
if q.joins[0].clause != "thing=$1 AND stuff=$2" {
|
||||
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
|
||||
}
|
||||
if q.joins[1].clause != "thing=$1 AND stuff=$2" {
|
||||
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
|
||||
}
|
||||
|
||||
if len(q.joins[0].args) != 2 {
|
||||
t.Errorf("Expected len 2, got %d", len(q.joins[0].args))
|
||||
}
|
||||
if len(q.joins[1].args) != 2 {
|
||||
t.Errorf("Expected len 2, got %d", len(q.joins[1].args))
|
||||
}
|
||||
|
||||
if q.joins[0].args[0] != 2 && q.joins[0].args[1] != 5 {
|
||||
t.Errorf("Invalid args values, got %#v", q.joins[0].args)
|
||||
}
|
||||
|
||||
q.joins = []join{{kind: JoinInner,
|
||||
clause: "thing=$1 AND stuff=$2",
|
||||
args: []interface{}{2, 5},
|
||||
}}
|
||||
|
||||
if len(q.joins) != 1 {
|
||||
t.Errorf("Expected len 1, got %d", len(q.joins))
|
||||
}
|
||||
|
||||
if q.joins[0].clause != "thing=$1 AND stuff=$2" {
|
||||
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
|
||||
}
|
||||
}
|
480
queries/reflect.go
Normal file
480
queries/reflect.go
Normal file
|
@ -0,0 +1,480 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/vattle/sqlboiler/boil"
|
||||
"github.com/vattle/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
var (
|
||||
bindAccepts = []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct}
|
||||
|
||||
mut sync.RWMutex
|
||||
bindingMaps = make(map[string][]uint64)
|
||||
structMaps = make(map[string]map[string]uint64)
|
||||
)
|
||||
|
||||
// Identifies what kind of object we're binding to
|
||||
type bindKind int
|
||||
|
||||
const (
|
||||
kindStruct bindKind = iota
|
||||
kindSliceStruct
|
||||
kindPtrSliceStruct
|
||||
)
|
||||
|
||||
const (
|
||||
loadMethodPrefix = "Load"
|
||||
relationshipStructName = "R"
|
||||
loaderStructName = "L"
|
||||
sentinel = uint64(255)
|
||||
)
|
||||
|
||||
// BindP executes the query and inserts the
|
||||
// result into the passed in object pointer.
|
||||
// It panics on error. See boil.Bind() documentation.
|
||||
func (q *Query) BindP(obj interface{}) {
|
||||
if err := q.Bind(obj); err != nil {
|
||||
panic(boil.WrapErr(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Bind executes the query and inserts the
|
||||
// result into the passed in object pointer
|
||||
//
|
||||
// Bind rules:
|
||||
// - Struct tags control bind, in the form of: `boil:"name,bind"`
|
||||
// - If "name" is omitted the sql column names that come back are TitleCased
|
||||
// and matched against the field name.
|
||||
// - If the "name" part of the struct tag is specified, the given name will
|
||||
// be used instead of the struct field name for binding.
|
||||
// - If the "name" of the struct tag is "-", this field will not be bound to.
|
||||
// - If the ",bind" option is specified on a struct field and that field
|
||||
// is a struct itself, it will be recursed into to look for fields for binding.
|
||||
//
|
||||
// Example Query:
|
||||
//
|
||||
// type JoinStruct struct {
|
||||
// // User1 can have it's struct fields bound to since it specifies
|
||||
// // ,bind in the struct tag, it will look specifically for
|
||||
// // fields that are prefixed with "user." returning from the query.
|
||||
// // For example "user.id" column name will bind to User1.ID
|
||||
// User1 *models.User `boil:"user,bind"`
|
||||
// // User2 will follow the same rules as noted above except it will use
|
||||
// // "friend." as the prefix it's looking for.
|
||||
// User2 *models.User `boil:"friend,bind"`
|
||||
// // RandomData will not be recursed into to look for fields to
|
||||
// // bind and will not be bound to because of the - for the name.
|
||||
// RandomData myStruct `boil:"-"`
|
||||
// // Date will not be recursed into to look for fields to bind because
|
||||
// // it does not specify ,bind in the struct tag. But it can be bound to
|
||||
// // as it does not specify a - for the name.
|
||||
// Date time.Time
|
||||
// }
|
||||
//
|
||||
// models.Users(qm.InnerJoin("users as friend on users.friend_id = friend.id")).Bind(&joinStruct)
|
||||
//
|
||||
// For custom objects that want to use eager loading, please see the
|
||||
// loadRelationships function.
|
||||
func Bind(rows *sql.Rows, obj interface{}) error {
|
||||
structType, sliceType, singular, err := bindChecks(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bind(rows, obj, structType, sliceType, singular)
|
||||
}
|
||||
|
||||
// Bind executes the query and inserts the
|
||||
// result into the passed in object pointer
|
||||
//
|
||||
// See documentation for boil.Bind()
|
||||
func (q *Query) Bind(obj interface{}) error {
|
||||
structType, sliceType, bkind, err := bindChecks(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := q.Query()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bind failed to execute query")
|
||||
}
|
||||
defer rows.Close()
|
||||
if res := bind(rows, obj, structType, sliceType, bkind); res != nil {
|
||||
return res
|
||||
}
|
||||
|
||||
if len(q.load) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := loadRelationshipState{
|
||||
exec: q.executor,
|
||||
loaded: map[string]struct{}{},
|
||||
}
|
||||
for _, toLoad := range q.load {
|
||||
state.toLoad = strings.Split(toLoad, ".")
|
||||
if err = state.loadRelationships(0, obj, bkind); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// bindChecks resolves information about the bind target, and errors if it's not an object
|
||||
// we can bind to.
|
||||
func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, bkind bindKind, err error) {
|
||||
typ := reflect.TypeOf(obj)
|
||||
kind := typ.Kind()
|
||||
|
||||
setErr := func() {
|
||||
err = errors.Errorf("obj type should be *Type, *[]Type, or *[]*Type but was %q", reflect.TypeOf(obj).String())
|
||||
}
|
||||
|
||||
for i := 0; ; i++ {
|
||||
switch i {
|
||||
case 0:
|
||||
if kind != reflect.Ptr {
|
||||
setErr()
|
||||
return
|
||||
}
|
||||
case 1:
|
||||
switch kind {
|
||||
case reflect.Struct:
|
||||
structType = typ
|
||||
bkind = kindStruct
|
||||
return
|
||||
case reflect.Slice:
|
||||
sliceType = typ
|
||||
default:
|
||||
setErr()
|
||||
return
|
||||
}
|
||||
case 2:
|
||||
switch kind {
|
||||
case reflect.Struct:
|
||||
structType = typ
|
||||
bkind = kindSliceStruct
|
||||
return
|
||||
case reflect.Ptr:
|
||||
default:
|
||||
setErr()
|
||||
return
|
||||
}
|
||||
case 3:
|
||||
if kind != reflect.Struct {
|
||||
setErr()
|
||||
return
|
||||
}
|
||||
structType = typ
|
||||
bkind = kindPtrSliceStruct
|
||||
return
|
||||
}
|
||||
|
||||
typ = typ.Elem()
|
||||
kind = typ.Kind()
|
||||
}
|
||||
}
|
||||
|
||||
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, bkind bindKind) error {
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bind failed to get column names")
|
||||
}
|
||||
|
||||
var ptrSlice reflect.Value
|
||||
switch bkind {
|
||||
case kindSliceStruct, kindPtrSliceStruct:
|
||||
ptrSlice = reflect.Indirect(reflect.ValueOf(obj))
|
||||
}
|
||||
|
||||
var strMapping map[string]uint64
|
||||
var sok bool
|
||||
var mapping []uint64
|
||||
var ok bool
|
||||
|
||||
typStr := structType.String()
|
||||
|
||||
mapKey := makeCacheKey(typStr, cols)
|
||||
mut.RLock()
|
||||
mapping, ok = bindingMaps[mapKey]
|
||||
if !ok {
|
||||
if strMapping, sok = structMaps[typStr]; !sok {
|
||||
strMapping = MakeStructMapping(structType)
|
||||
}
|
||||
}
|
||||
mut.RUnlock()
|
||||
|
||||
if !ok {
|
||||
mapping, err = BindMapping(structType, strMapping, cols)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mut.Lock()
|
||||
if !sok {
|
||||
structMaps[typStr] = strMapping
|
||||
}
|
||||
bindingMaps[mapKey] = mapping
|
||||
mut.Unlock()
|
||||
}
|
||||
|
||||
var oneStruct reflect.Value
|
||||
if bkind == kindSliceStruct {
|
||||
oneStruct = reflect.Indirect(reflect.New(structType))
|
||||
}
|
||||
|
||||
foundOne := false
|
||||
for rows.Next() {
|
||||
foundOne = true
|
||||
var newStruct reflect.Value
|
||||
var pointers []interface{}
|
||||
|
||||
switch bkind {
|
||||
case kindStruct:
|
||||
pointers = PtrsFromMapping(reflect.Indirect(reflect.ValueOf(obj)), mapping)
|
||||
case kindSliceStruct:
|
||||
pointers = PtrsFromMapping(oneStruct, mapping)
|
||||
case kindPtrSliceStruct:
|
||||
newStruct = reflect.New(structType)
|
||||
pointers = PtrsFromMapping(reflect.Indirect(newStruct), mapping)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rows.Scan(pointers...); err != nil {
|
||||
return errors.Wrap(err, "failed to bind pointers to obj")
|
||||
}
|
||||
|
||||
switch bkind {
|
||||
case kindSliceStruct:
|
||||
ptrSlice.Set(reflect.Append(ptrSlice, oneStruct))
|
||||
case kindPtrSliceStruct:
|
||||
ptrSlice.Set(reflect.Append(ptrSlice, newStruct))
|
||||
}
|
||||
}
|
||||
|
||||
if bkind == kindStruct && !foundOne {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindMapping creates a mapping that helps look up the pointer for the
|
||||
// column given.
|
||||
func BindMapping(typ reflect.Type, mapping map[string]uint64, cols []string) ([]uint64, error) {
|
||||
ptrs := make([]uint64, len(cols))
|
||||
|
||||
ColLoop:
|
||||
for i, c := range cols {
|
||||
name := strmangle.TitleCaseIdentifier(c)
|
||||
ptrMap, ok := mapping[name]
|
||||
if ok {
|
||||
ptrs[i] = ptrMap
|
||||
continue
|
||||
}
|
||||
|
||||
suffix := "." + name
|
||||
for maybeMatch, mapping := range mapping {
|
||||
if strings.HasSuffix(maybeMatch, suffix) {
|
||||
ptrs[i] = mapping
|
||||
continue ColLoop
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.Errorf("could not find struct field name in mapping: %s", name)
|
||||
}
|
||||
|
||||
return ptrs, nil
|
||||
}
|
||||
|
||||
// PtrsFromMapping expects to be passed an addressable struct and a mapping
|
||||
// of where to find things. It pulls the pointers out referred to by the mapping.
|
||||
func PtrsFromMapping(val reflect.Value, mapping []uint64) []interface{} {
|
||||
ptrs := make([]interface{}, len(mapping))
|
||||
for i, m := range mapping {
|
||||
ptrs[i] = ptrFromMapping(val, m, true).Interface()
|
||||
}
|
||||
return ptrs
|
||||
}
|
||||
|
||||
// ValuesFromMapping expects to be passed an addressable struct and a mapping
|
||||
// of where to find things. It pulls the pointers out referred to by the mapping.
|
||||
func ValuesFromMapping(val reflect.Value, mapping []uint64) []interface{} {
|
||||
ptrs := make([]interface{}, len(mapping))
|
||||
for i, m := range mapping {
|
||||
ptrs[i] = ptrFromMapping(val, m, false).Interface()
|
||||
}
|
||||
return ptrs
|
||||
}
|
||||
|
||||
// ptrFromMapping expects to be passed an addressable struct that it's looking
|
||||
// for things on.
|
||||
func ptrFromMapping(val reflect.Value, mapping uint64, addressOf bool) reflect.Value {
|
||||
for i := 0; i < 8; i++ {
|
||||
v := (mapping >> uint(i*8)) & sentinel
|
||||
|
||||
if v == sentinel {
|
||||
if addressOf && val.Kind() != reflect.Ptr {
|
||||
return val.Addr()
|
||||
} else if !addressOf && val.Kind() == reflect.Ptr {
|
||||
return reflect.Indirect(val)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
val = val.Field(int(v))
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = reflect.Indirect(val)
|
||||
}
|
||||
}
|
||||
|
||||
panic("could not find pointer from mapping")
|
||||
}
|
||||
|
||||
// MakeStructMapping creates a map of the struct to be able to quickly look
|
||||
// up its pointers and values by name.
|
||||
func MakeStructMapping(typ reflect.Type) map[string]uint64 {
|
||||
fieldMaps := make(map[string]uint64)
|
||||
makeStructMappingHelper(typ, "", 0, 0, fieldMaps)
|
||||
return fieldMaps
|
||||
}
|
||||
|
||||
func makeStructMappingHelper(typ reflect.Type, prefix string, current uint64, depth uint, fieldMaps map[string]uint64) {
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
n := typ.NumField()
|
||||
for i := 0; i < n; i++ {
|
||||
f := typ.Field(i)
|
||||
|
||||
tag, recurse := getBoilTag(f)
|
||||
if len(tag) == 0 {
|
||||
tag = f.Name
|
||||
} else if tag[0] == '-' {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(prefix) != 0 {
|
||||
tag = fmt.Sprintf("%s.%s", prefix, tag)
|
||||
}
|
||||
|
||||
if recurse {
|
||||
makeStructMappingHelper(f.Type, tag, current|uint64(i)<<depth, depth+8, fieldMaps)
|
||||
continue
|
||||
}
|
||||
|
||||
fieldMaps[tag] = current | (sentinel << (depth + 8)) | (uint64(i) << depth)
|
||||
}
|
||||
}
|
||||
|
||||
func getBoilTag(field reflect.StructField) (name string, recurse bool) {
|
||||
tag := field.Tag.Get("boil")
|
||||
name = field.Name
|
||||
|
||||
if len(tag) == 0 {
|
||||
return name, false
|
||||
}
|
||||
|
||||
ind := strings.IndexByte(tag, ',')
|
||||
if ind == -1 {
|
||||
return strmangle.TitleCase(tag), false
|
||||
} else if ind == 0 {
|
||||
return name, true
|
||||
}
|
||||
|
||||
nameFragment := tag[:ind]
|
||||
return strmangle.TitleCase(nameFragment), true
|
||||
}
|
||||
|
||||
func makeCacheKey(typ string, cols []string) string {
|
||||
buf := strmangle.GetBuffer()
|
||||
buf.WriteString(typ)
|
||||
for _, s := range cols {
|
||||
buf.WriteString(s)
|
||||
}
|
||||
mapKey := buf.String()
|
||||
strmangle.PutBuffer(buf)
|
||||
|
||||
return mapKey
|
||||
}
|
||||
|
||||
// GetStructValues returns the values (as interface) of the matching columns in obj
|
||||
func GetStructValues(obj interface{}, columns ...string) []interface{} {
|
||||
ret := make([]interface{}, len(columns))
|
||||
val := reflect.Indirect(reflect.ValueOf(obj))
|
||||
|
||||
for i, c := range columns {
|
||||
fieldName := strmangle.TitleCase(c)
|
||||
field := val.FieldByName(fieldName)
|
||||
if !field.IsValid() {
|
||||
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
|
||||
}
|
||||
ret[i] = field.Interface()
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetSliceValues returns the values (as interface) of the matching columns in obj.
|
||||
func GetSliceValues(slice []interface{}, columns ...string) []interface{} {
|
||||
ret := make([]interface{}, len(slice)*len(columns))
|
||||
|
||||
for i, obj := range slice {
|
||||
val := reflect.Indirect(reflect.ValueOf(obj))
|
||||
for j, c := range columns {
|
||||
fieldName := strmangle.TitleCase(c)
|
||||
field := val.FieldByName(fieldName)
|
||||
if !field.IsValid() {
|
||||
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
|
||||
}
|
||||
ret[i*len(columns)+j] = field.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetStructPointers returns a slice of pointers to the matching columns in obj
|
||||
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
|
||||
val := reflect.ValueOf(obj).Elem()
|
||||
|
||||
var ln int
|
||||
var getField func(reflect.Value, int) reflect.Value
|
||||
|
||||
if len(columns) == 0 {
|
||||
ln = val.NumField()
|
||||
getField = func(v reflect.Value, i int) reflect.Value {
|
||||
return v.Field(i)
|
||||
}
|
||||
} else {
|
||||
ln = len(columns)
|
||||
getField = func(v reflect.Value, i int) reflect.Value {
|
||||
return v.FieldByName(strmangle.TitleCase(columns[i]))
|
||||
}
|
||||
}
|
||||
|
||||
ret := make([]interface{}, ln)
|
||||
for i := 0; i < ln; i++ {
|
||||
field := getField(val, i)
|
||||
|
||||
if !field.IsValid() {
|
||||
// Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually
|
||||
// produce an Invalid value, so we make a hopefully safe assumption here.
|
||||
panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i])))
|
||||
}
|
||||
|
||||
ret[i] = field.Addr().Interface()
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
707
queries/reflect_test.go
Normal file
707
queries/reflect_test.go
Normal file
|
@ -0,0 +1,707 @@
|
|||
package queries
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gopkg.in/DATA-DOG/go-sqlmock.v1"
|
||||
"gopkg.in/nullbio/null.v5"
|
||||
)
|
||||
|
||||
func bin64(i uint64) string {
|
||||
str := strconv.FormatUint(i, 2)
|
||||
pad := 64 - len(str)
|
||||
if pad > 0 {
|
||||
str = strings.Repeat("0", pad) + str
|
||||
}
|
||||
|
||||
var newStr string
|
||||
for i := 0; i < len(str); i += 8 {
|
||||
if i != 0 {
|
||||
newStr += " "
|
||||
}
|
||||
newStr += str[i : i+8]
|
||||
}
|
||||
|
||||
return newStr
|
||||
}
|
||||
|
||||
type mockRowMaker struct {
|
||||
int
|
||||
rows []driver.Value
|
||||
}
|
||||
|
||||
func TestBindStruct(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testResults := struct {
|
||||
ID int
|
||||
Name string `boil:"test"`
|
||||
}{}
|
||||
|
||||
query := &Query{
|
||||
from: []string{"fun"},
|
||||
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
|
||||
}
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
ret := sqlmock.NewRows([]string{"id", "test"})
|
||||
ret.AddRow(driver.Value(int64(35)), driver.Value("pat"))
|
||||
mock.ExpectQuery(`SELECT \* FROM "fun";`).WillReturnRows(ret)
|
||||
|
||||
SetExecutor(query, db)
|
||||
err = query.Bind(&testResults)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if id := testResults.ID; id != 35 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
if name := testResults.Name; name != "pat" {
|
||||
t.Error("wrong name:", name)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindSlice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testResults := []struct {
|
||||
ID int
|
||||
Name string `boil:"test"`
|
||||
}{}
|
||||
|
||||
query := &Query{
|
||||
from: []string{"fun"},
|
||||
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
|
||||
}
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
ret := sqlmock.NewRows([]string{"id", "test"})
|
||||
ret.AddRow(driver.Value(int64(35)), driver.Value("pat"))
|
||||
ret.AddRow(driver.Value(int64(12)), driver.Value("cat"))
|
||||
mock.ExpectQuery(`SELECT \* FROM "fun";`).WillReturnRows(ret)
|
||||
|
||||
SetExecutor(query, db)
|
||||
err = query.Bind(&testResults)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if len(testResults) != 2 {
|
||||
t.Fatal("wrong number of results:", len(testResults))
|
||||
}
|
||||
if id := testResults[0].ID; id != 35 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
if name := testResults[0].Name; name != "pat" {
|
||||
t.Error("wrong name:", name)
|
||||
}
|
||||
|
||||
if id := testResults[1].ID; id != 12 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
if name := testResults[1].Name; name != "cat" {
|
||||
t.Error("wrong name:", name)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindPtrSlice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testResults := []*struct {
|
||||
ID int
|
||||
Name string `boil:"test"`
|
||||
}{}
|
||||
|
||||
query := &Query{
|
||||
from: []string{"fun"},
|
||||
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
|
||||
}
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
ret := sqlmock.NewRows([]string{"id", "test"})
|
||||
ret.AddRow(driver.Value(int64(35)), driver.Value("pat"))
|
||||
ret.AddRow(driver.Value(int64(12)), driver.Value("cat"))
|
||||
mock.ExpectQuery(`SELECT \* FROM "fun";`).WillReturnRows(ret)
|
||||
|
||||
SetExecutor(query, db)
|
||||
err = query.Bind(&testResults)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if len(testResults) != 2 {
|
||||
t.Fatal("wrong number of results:", len(testResults))
|
||||
}
|
||||
if id := testResults[0].ID; id != 35 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
if name := testResults[0].Name; name != "pat" {
|
||||
t.Error("wrong name:", name)
|
||||
}
|
||||
|
||||
if id := testResults[1].ID; id != 12 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
if name := testResults[1].Name; name != "cat" {
|
||||
t.Error("wrong name:", name)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func testMakeMapping(byt ...byte) uint64 {
|
||||
var x uint64
|
||||
for i, b := range byt {
|
||||
x |= uint64(b) << (uint(i) * 8)
|
||||
}
|
||||
x |= uint64(255) << uint(len(byt)*8)
|
||||
return x
|
||||
}
|
||||
|
||||
func TestMakeStructMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var testStruct = struct {
|
||||
LastName string `boil:"different"`
|
||||
AwesomeName string `boil:"awesome_name"`
|
||||
Face string `boil:"-"`
|
||||
Nose string
|
||||
|
||||
Nested struct {
|
||||
LastName string `boil:"different"`
|
||||
AwesomeName string `boil:"awesome_name"`
|
||||
Face string `boil:"-"`
|
||||
Nose string
|
||||
|
||||
Nested2 struct {
|
||||
Nose string
|
||||
} `boil:",bind"`
|
||||
} `boil:",bind"`
|
||||
}{}
|
||||
|
||||
got := MakeStructMapping(reflect.TypeOf(testStruct))
|
||||
|
||||
expectMap := map[string]uint64{
|
||||
"Different": testMakeMapping(0),
|
||||
"AwesomeName": testMakeMapping(1),
|
||||
"Nose": testMakeMapping(3),
|
||||
"Nested.Different": testMakeMapping(4, 0),
|
||||
"Nested.AwesomeName": testMakeMapping(4, 1),
|
||||
"Nested.Nose": testMakeMapping(4, 3),
|
||||
"Nested.Nested2.Nose": testMakeMapping(4, 4, 0),
|
||||
}
|
||||
|
||||
for expName, expVal := range expectMap {
|
||||
gotVal, ok := got[expName]
|
||||
if !ok {
|
||||
t.Errorf("%s) had no value", expName)
|
||||
continue
|
||||
}
|
||||
|
||||
if gotVal != expVal {
|
||||
t.Errorf("%s) wrong value,\nwant: %x (%s)\ngot: %x (%s)", expName, expVal, bin64(expVal), gotVal, bin64(gotVal))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPtrFromMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type NestedPtrs struct {
|
||||
Int int
|
||||
IntP *int
|
||||
NestedPtrsP *NestedPtrs
|
||||
}
|
||||
|
||||
val := &NestedPtrs{
|
||||
Int: 5,
|
||||
IntP: new(int),
|
||||
NestedPtrsP: &NestedPtrs{
|
||||
Int: 6,
|
||||
IntP: new(int),
|
||||
},
|
||||
}
|
||||
|
||||
v := ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(0), true)
|
||||
if got := *v.Interface().(*int); got != 5 {
|
||||
t.Error("flat int was wrong:", got)
|
||||
}
|
||||
v = ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(1), true)
|
||||
if got := *v.Interface().(*int); got != 0 {
|
||||
t.Error("flat pointer was wrong:", got)
|
||||
}
|
||||
v = ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(2, 0), true)
|
||||
if got := *v.Interface().(*int); got != 6 {
|
||||
t.Error("nested int was wrong:", got)
|
||||
}
|
||||
v = ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(2, 1), true)
|
||||
if got := *v.Interface().(*int); got != 0 {
|
||||
t.Error("nested pointer was wrong:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBoilTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type TestStruct struct {
|
||||
FirstName string `boil:"test_one,bind"`
|
||||
LastName string `boil:"test_two"`
|
||||
MiddleName string `boil:"middle_name,bind"`
|
||||
AwesomeName string `boil:"awesome_name"`
|
||||
Age string `boil:",bind"`
|
||||
Face string `boil:"-"`
|
||||
Nose string
|
||||
}
|
||||
|
||||
var structFields []reflect.StructField
|
||||
typ := reflect.TypeOf(TestStruct{})
|
||||
removeOk := func(thing reflect.StructField, ok bool) reflect.StructField {
|
||||
if !ok {
|
||||
panic("Exploded")
|
||||
}
|
||||
return thing
|
||||
}
|
||||
structFields = append(structFields, removeOk(typ.FieldByName("FirstName")))
|
||||
structFields = append(structFields, removeOk(typ.FieldByName("LastName")))
|
||||
structFields = append(structFields, removeOk(typ.FieldByName("MiddleName")))
|
||||
structFields = append(structFields, removeOk(typ.FieldByName("AwesomeName")))
|
||||
structFields = append(structFields, removeOk(typ.FieldByName("Age")))
|
||||
structFields = append(structFields, removeOk(typ.FieldByName("Face")))
|
||||
structFields = append(structFields, removeOk(typ.FieldByName("Nose")))
|
||||
|
||||
expect := []struct {
|
||||
Name string
|
||||
Recurse bool
|
||||
}{
|
||||
{"TestOne", true},
|
||||
{"TestTwo", false},
|
||||
{"MiddleName", true},
|
||||
{"AwesomeName", false},
|
||||
{"Age", true},
|
||||
{"-", false},
|
||||
{"Nose", false},
|
||||
}
|
||||
for i, s := range structFields {
|
||||
name, recurse := getBoilTag(s)
|
||||
if expect[i].Name != name {
|
||||
t.Errorf("Invalid name, expect %q, got %q", expect[i].Name, name)
|
||||
}
|
||||
if expect[i].Recurse != recurse {
|
||||
t.Errorf("Invalid recurse, expect %v, got %v", !recurse, recurse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindChecks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type useless struct {
|
||||
}
|
||||
|
||||
var tests = []struct {
|
||||
BKind bindKind
|
||||
Fail bool
|
||||
Obj interface{}
|
||||
}{
|
||||
{BKind: kindStruct, Fail: false, Obj: &useless{}},
|
||||
{BKind: kindSliceStruct, Fail: false, Obj: &[]useless{}},
|
||||
{BKind: kindPtrSliceStruct, Fail: false, Obj: &[]*useless{}},
|
||||
{Fail: true, Obj: 5},
|
||||
{Fail: true, Obj: useless{}},
|
||||
{Fail: true, Obj: []useless{}},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
str, sli, bk, err := bindChecks(test.Obj)
|
||||
|
||||
if err != nil {
|
||||
if !test.Fail {
|
||||
t.Errorf("%d) should not fail, got: %v", i, err)
|
||||
}
|
||||
continue
|
||||
} else if test.Fail {
|
||||
t.Errorf("%d) should fail, got: %v", i, bk)
|
||||
continue
|
||||
}
|
||||
|
||||
if s := str.Kind(); s != reflect.Struct {
|
||||
t.Error("struct kind was wrong:", s)
|
||||
}
|
||||
if test.BKind != kindStruct {
|
||||
if s := sli.Kind(); s != reflect.Slice {
|
||||
t.Error("slice kind was wrong:", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindSingular(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testResults := struct {
|
||||
ID int
|
||||
Name string `boil:"test"`
|
||||
}{}
|
||||
|
||||
query := &Query{
|
||||
from: []string{"fun"},
|
||||
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
|
||||
}
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
ret := sqlmock.NewRows([]string{"id", "test"})
|
||||
ret.AddRow(driver.Value(int64(35)), driver.Value("pat"))
|
||||
mock.ExpectQuery(`SELECT \* FROM "fun";`).WillReturnRows(ret)
|
||||
|
||||
SetExecutor(query, db)
|
||||
err = query.Bind(&testResults)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if id := testResults.ID; id != 35 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
if name := testResults.Name; name != "pat" {
|
||||
t.Error("wrong name:", name)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBind_InnerJoin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testResults := []*struct {
|
||||
Happy struct {
|
||||
ID int `boil:"identifier"`
|
||||
} `boil:",bind"`
|
||||
Fun struct {
|
||||
ID int `boil:"id"`
|
||||
} `boil:",bind"`
|
||||
}{}
|
||||
|
||||
query := &Query{
|
||||
from: []string{"fun"},
|
||||
joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}},
|
||||
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
|
||||
}
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
ret := sqlmock.NewRows([]string{"id"})
|
||||
ret.AddRow(driver.Value(int64(10)))
|
||||
ret.AddRow(driver.Value(int64(11)))
|
||||
mock.ExpectQuery(`SELECT "fun"\.\* FROM "fun" INNER JOIN happy as h on fun.id = h.fun_id;`).WillReturnRows(ret)
|
||||
|
||||
SetExecutor(query, db)
|
||||
err = query.Bind(&testResults)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if len(testResults) != 2 {
|
||||
t.Fatal("wrong number of results:", len(testResults))
|
||||
}
|
||||
if id := testResults[0].Happy.ID; id != 0 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
if id := testResults[0].Fun.ID; id != 10 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
|
||||
if id := testResults[1].Happy.ID; id != 0 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
if id := testResults[1].Fun.ID; id != 11 {
|
||||
t.Error("wrong ID:", id)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// func TestBind_InnerJoinSelect(t *testing.T) {
|
||||
// t.Parallel()
|
||||
//
|
||||
// testResults := []*struct {
|
||||
// Happy struct {
|
||||
// ID int
|
||||
// } `boil:"h,bind"`
|
||||
// Fun struct {
|
||||
// ID int
|
||||
// } `boil:",bind"`
|
||||
// }{}
|
||||
//
|
||||
// query := &Query{
|
||||
// selectCols: []string{"fun.id", "h.id"},
|
||||
// from: []string{"fun"},
|
||||
// joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}},
|
||||
// }
|
||||
//
|
||||
// db, mock, err := sqlmock.New()
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
//
|
||||
// ret := sqlmock.NewRows([]string{"fun.id", "h.id"})
|
||||
// ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11)))
|
||||
// ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13)))
|
||||
// mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret)
|
||||
//
|
||||
// SetExecutor(query, db)
|
||||
// err = query.Bind(&testResults)
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
//
|
||||
// if len(testResults) != 2 {
|
||||
// t.Fatal("wrong number of results:", len(testResults))
|
||||
// }
|
||||
// if id := testResults[0].Happy.ID; id != 11 {
|
||||
// t.Error("wrong ID:", id)
|
||||
// }
|
||||
// if id := testResults[0].Fun.ID; id != 10 {
|
||||
// t.Error("wrong ID:", id)
|
||||
// }
|
||||
//
|
||||
// if id := testResults[1].Happy.ID; id != 13 {
|
||||
// t.Error("wrong ID:", id)
|
||||
// }
|
||||
// if id := testResults[1].Fun.ID; id != 12 {
|
||||
// t.Error("wrong ID:", id)
|
||||
// }
|
||||
//
|
||||
// if err := mock.ExpectationsWereMet(); err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func TestBindPtrs_Easy(t *testing.T) {
|
||||
// t.Parallel()
|
||||
//
|
||||
// testStruct := struct {
|
||||
// ID int `boil:"identifier"`
|
||||
// Date time.Time
|
||||
// }{}
|
||||
//
|
||||
// cols := []string{"identifier", "date"}
|
||||
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
//
|
||||
// if ptrs[0].(*int) != &testStruct.ID {
|
||||
// t.Error("id is the wrong pointer")
|
||||
// }
|
||||
// if ptrs[1].(*time.Time) != &testStruct.Date {
|
||||
// t.Error("id is the wrong pointer")
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// func TestBindPtrs_Recursive(t *testing.T) {
|
||||
// t.Parallel()
|
||||
//
|
||||
// testStruct := struct {
|
||||
// Happy struct {
|
||||
// ID int `boil:"identifier"`
|
||||
// }
|
||||
// Fun struct {
|
||||
// ID int
|
||||
// } `boil:",bind"`
|
||||
// }{}
|
||||
//
|
||||
// cols := []string{"id", "fun.id"}
|
||||
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
//
|
||||
// if ptrs[0].(*int) != &testStruct.Fun.ID {
|
||||
// t.Error("id is the wrong pointer")
|
||||
// }
|
||||
// if ptrs[1].(*int) != &testStruct.Fun.ID {
|
||||
// t.Error("id is the wrong pointer")
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// func TestBindPtrs_RecursiveTags(t *testing.T) {
|
||||
// t.Parallel()
|
||||
//
|
||||
// testStruct := struct {
|
||||
// Happy struct {
|
||||
// ID int `boil:"identifier"`
|
||||
// } `boil:",bind"`
|
||||
// Fun struct {
|
||||
// ID int `boil:"identification"`
|
||||
// } `boil:",bind"`
|
||||
// }{}
|
||||
//
|
||||
// cols := []string{"happy.identifier", "fun.identification"}
|
||||
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
//
|
||||
// if ptrs[0].(*int) != &testStruct.Happy.ID {
|
||||
// t.Error("id is the wrong pointer")
|
||||
// }
|
||||
// if ptrs[1].(*int) != &testStruct.Fun.ID {
|
||||
// t.Error("id is the wrong pointer")
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// func TestBindPtrs_Ignore(t *testing.T) {
|
||||
// t.Parallel()
|
||||
//
|
||||
// testStruct := struct {
|
||||
// ID int `boil:"-"`
|
||||
// Happy struct {
|
||||
// ID int
|
||||
// } `boil:",bind"`
|
||||
// }{}
|
||||
//
|
||||
// cols := []string{"id"}
|
||||
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// }
|
||||
//
|
||||
// if ptrs[0].(*int) != &testStruct.Happy.ID {
|
||||
// t.Error("id is the wrong pointer")
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestGetStructValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
timeThing := time.Now()
|
||||
o := struct {
|
||||
TitleThing string
|
||||
Name string
|
||||
ID int
|
||||
Stuff int
|
||||
Things int
|
||||
Time time.Time
|
||||
NullBool null.Bool
|
||||
}{
|
||||
TitleThing: "patrick",
|
||||
Stuff: 10,
|
||||
Things: 0,
|
||||
Time: timeThing,
|
||||
NullBool: null.NewBool(true, false),
|
||||
}
|
||||
|
||||
vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool")
|
||||
if vals[0].(string) != "patrick" {
|
||||
t.Errorf("Want test, got %s", vals[0])
|
||||
}
|
||||
if vals[1].(string) != "" {
|
||||
t.Errorf("Want empty string, got %s", vals[1])
|
||||
}
|
||||
if vals[2].(int) != 0 {
|
||||
t.Errorf("Want 0, got %d", vals[2])
|
||||
}
|
||||
if vals[3].(int) != 10 {
|
||||
t.Errorf("Want 10, got %d", vals[3])
|
||||
}
|
||||
if vals[4].(int) != 0 {
|
||||
t.Errorf("Want 0, got %d", vals[4])
|
||||
}
|
||||
if !vals[5].(time.Time).Equal(timeThing) {
|
||||
t.Errorf("Want %s, got %s", o.Time, vals[5])
|
||||
}
|
||||
if !vals[6].(null.Bool).IsZero() {
|
||||
t.Errorf("Want %v, got %v", o.NullBool, vals[6])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSliceValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
o := []struct {
|
||||
ID int
|
||||
Name string
|
||||
}{
|
||||
{5, "a"},
|
||||
{6, "b"},
|
||||
}
|
||||
|
||||
in := make([]interface{}, len(o))
|
||||
in[0] = o[0]
|
||||
in[1] = o[1]
|
||||
|
||||
vals := GetSliceValues(in, "id", "name")
|
||||
if got := vals[0].(int); got != 5 {
|
||||
t.Error(got)
|
||||
}
|
||||
if got := vals[1].(string); got != "a" {
|
||||
t.Error(got)
|
||||
}
|
||||
if got := vals[2].(int); got != 6 {
|
||||
t.Error(got)
|
||||
}
|
||||
if got := vals[3].(string); got != "b" {
|
||||
t.Error(got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStructPointers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
o := struct {
|
||||
Title string
|
||||
ID *int
|
||||
}{
|
||||
Title: "patrick",
|
||||
}
|
||||
|
||||
ptrs := GetStructPointers(&o, "title", "id")
|
||||
*ptrs[0].(*string) = "test"
|
||||
if o.Title != "test" {
|
||||
t.Errorf("Expected test, got %s", o.Title)
|
||||
}
|
||||
x := 5
|
||||
*ptrs[1].(**int) = &x
|
||||
if *o.ID != 5 {
|
||||
t.Errorf("Expected 5, got %d", *o.ID)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue