Add extra type information everywhere for binding.

This commit is contained in:
Aaron L 2016-09-03 10:33:28 -07:00
parent 988cb2bf04
commit 08dc7a5cc1
4 changed files with 133 additions and 55 deletions

View file

@ -52,9 +52,9 @@ func (l loadRelationshipState) buildKey(depth int) string {
// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind.
//
// It takes list of nested relationships to load.
func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, singular bool) error {
func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, bkind bindKind) error {
typ := reflect.TypeOf(obj).Elem()
if !singular {
if bkind == kindPtrSliceStruct {
typ = typ.Elem().Elem()
}
@ -79,14 +79,14 @@ func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, sin
}
val := reflect.ValueOf(obj).Elem()
if !singular {
if bkind == kindPtrSliceStruct {
val = val.Index(0).Elem()
}
methodArgs := []reflect.Value{
val.FieldByName(loaderStructName),
execArg,
reflect.ValueOf(singular),
reflect.ValueOf(bkind == kindStruct),
reflect.ValueOf(obj),
}
resp := loadMethod.Func.Call(methodArgs)
@ -111,8 +111,8 @@ func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, sin
loadedObject = reflect.Indirect(loadedObject)
// If it's singular we can just immediately call without looping
if singular {
return s.loadRelationshipsRecurse(depth, singular, loadedObject)
if bkind == kindStruct {
return s.loadRelationshipsRecurse(depth, loadedObject)
}
// Loop over all eager loaded objects
@ -122,7 +122,7 @@ func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, sin
}
for i := 0; i < ln; i++ {
iter := loadedObject.Index(i).Elem()
if err := s.loadRelationshipsRecurse(depth, singular, iter); err != nil {
if err := s.loadRelationshipsRecurse(depth, iter); err != nil {
return err
}
}
@ -132,15 +132,16 @@ func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, sin
// loadRelationshipsRecurse is a helper function for taking a reflect.Value and
// Basically calls loadRelationships with: obj.R.EagerLoadedObj, and whether it's a string or slice
func (s loadRelationshipState) loadRelationshipsRecurse(depth int, singular bool, obj reflect.Value) error {
func (s loadRelationshipState) loadRelationshipsRecurse(depth int, obj reflect.Value) error {
r := obj.FieldByName(relationshipStructName)
if !r.IsValid() || r.IsNil() {
return errors.Errorf("could not traverse into loaded %s relationship to load more things", s.toLoad[depth])
}
newObj := reflect.Indirect(r).FieldByName(s.toLoad[depth])
singular = reflect.Indirect(newObj).Kind() == reflect.Struct
if !singular {
bkind := kindStruct
if reflect.Indirect(newObj).Kind() != reflect.Struct {
bkind = kindPtrSliceStruct
newObj = newObj.Addr()
}
return s.loadRelationships(depth, newObj.Interface(), singular)
return s.loadRelationships(depth, newObj.Interface(), bkind)
}

View file

@ -89,7 +89,7 @@ func TestLoadRelationshipsSlice(t *testing.T) {
L testLStruct
}{{}}
if err := testFakeState("TestOne").loadRelationships(0, &testSlice, false); err != nil {
if err := testFakeState("TestOne").loadRelationships(0, &testSlice, kindPtrSliceStruct); err != nil {
t.Error(err)
}
@ -108,7 +108,7 @@ func TestLoadRelationshipsSingular(t *testing.T) {
L testLStruct
}{}
if err := testFakeState("TestOne").loadRelationships(0, &testSingular, true); err != nil {
if err := testFakeState("TestOne").loadRelationships(0, &testSingular, kindStruct); err != nil {
t.Error(err)
}
@ -125,7 +125,7 @@ func TestLoadRelationshipsSliceNested(t *testing.T) {
},
}
loadFunctionNestedCalled = 0
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSlice, false); err != nil {
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSlice, kindPtrSliceStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
@ -138,7 +138,7 @@ func TestLoadRelationshipsSliceNested(t *testing.T) {
},
}
loadFunctionNestedCalled = 0
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSliceSlice, false); err != nil {
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSliceSlice, kindPtrSliceStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
@ -152,7 +152,7 @@ func TestLoadRelationshipsSingularNested(t *testing.T) {
ID: 3,
}
loadFunctionNestedCalled = 0
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingular, true); err != nil {
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingular, kindStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
@ -163,7 +163,7 @@ func TestLoadRelationshipsSingularNested(t *testing.T) {
ID: 3,
}
loadFunctionNestedCalled = 0
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingularSlice, true); err != nil {
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingularSlice, kindStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
@ -189,7 +189,7 @@ func TestLoadRelationshipsNoReload(t *testing.T) {
toLoad: []string{"ToEagerLoad", "ToEagerLoad"},
}
if err := state.loadRelationships(0, &testSingular, true); err != nil {
if err := state.loadRelationships(0, &testSingular, kindStruct); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 0 {

View file

@ -13,13 +13,20 @@ import (
var (
bindAccepts = []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct}
)
var (
mut sync.RWMutex
bindingMaps = make(map[string][]uint64)
)
// Identifies what kind of object we're binding to
type bindKind int
const (
kindStruct bindKind = iota
kindSliceStruct
kindPtrSliceStruct
)
const (
loadMethodPrefix = "Load"
relationshipStructName = "R"
@ -87,7 +94,7 @@ func Bind(rows *sql.Rows, obj interface{}) error {
//
// See documentation for boil.Bind()
func (q *Query) Bind(obj interface{}) error {
structType, sliceType, singular, err := bindChecks(obj)
structType, sliceType, bkind, err := bindChecks(obj)
if err != nil {
return err
}
@ -97,17 +104,21 @@ func (q *Query) Bind(obj interface{}) error {
return errors.Wrap(err, "bind failed to execute query")
}
defer rows.Close()
if res := bind(rows, obj, structType, sliceType, singular); res != nil {
if res := bind(rows, obj, structType, sliceType, bkind); res != nil {
return res
}
if len(q.load) == 0 {
return nil
}
state := loadRelationshipState{
exec: q.executor,
loaded: map[string]struct{}{},
}
for _, toLoad := range q.load {
state.toLoad = strings.Split(toLoad, ".")
if err = state.loadRelationships(0, obj, singular); err != nil {
if err = state.loadRelationships(0, obj, bkind); err != nil {
return err
}
}
@ -116,47 +127,68 @@ func (q *Query) Bind(obj interface{}) error {
// bindChecks resolves information about the bind target, and errors if it's not an object
// we can bind to.
func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, singular bool, err error) {
func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, bkind bindKind, err error) {
typ := reflect.TypeOf(obj)
kind := typ.Kind()
for i := 0; i < len(bindAccepts); i++ {
exp := bindAccepts[i]
if i != 0 {
typ = typ.Elem()
kind = typ.Kind()
}
if kind != exp {
if exp == reflect.Slice || kind == reflect.Struct {
structType = typ
singular = true
break
}
return nil, nil, false, errors.Errorf("obj type should be *[]*Type or *Type but was %q", reflect.TypeOf(obj).String())
}
switch kind {
case reflect.Struct:
structType = typ
case reflect.Slice:
sliceType = typ
}
setErr := func() {
err = errors.Errorf("obj type should be *Type, *[]Type, or *[]*Type but was %q", reflect.TypeOf(obj).String())
}
return structType, sliceType, singular, nil
for i := 0; ; i++ {
switch i {
case 0:
if kind != reflect.Ptr {
setErr()
return
}
case 1:
switch kind {
case reflect.Struct:
structType = typ
bkind = kindStruct
return
case reflect.Slice:
sliceType = typ
default:
setErr()
return
}
case 2:
switch kind {
case reflect.Struct:
structType = typ
bkind = kindSliceStruct
return
case reflect.Ptr:
default:
setErr()
return
}
case 3:
if kind != reflect.Struct {
setErr()
return
}
structType = typ
bkind = kindPtrSliceStruct
return
}
typ = typ.Elem()
kind = typ.Kind()
}
}
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, singular bool) error {
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, bkind bindKind) error {
cols, err := rows.Columns()
if err != nil {
return errors.Wrap(err, "bind failed to get column names")
}
var ptrSlice reflect.Value
if !singular {
switch bkind {
case kindPtrSliceStruct:
ptrSlice = reflect.Indirect(reflect.ValueOf(obj))
}
@ -185,9 +217,10 @@ func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, s
var newStruct reflect.Value
var pointers []interface{}
if singular {
switch bkind {
case kindStruct:
pointers = ptrsFromMapping(reflect.Indirect(reflect.ValueOf(obj)), mapping)
} else {
case kindPtrSliceStruct:
newStruct = reflect.New(structType)
pointers = ptrsFromMapping(reflect.Indirect(newStruct), mapping)
}
@ -199,12 +232,13 @@ func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, s
return errors.Wrap(err, "failed to bind pointers to obj")
}
if !singular {
switch bkind {
case kindPtrSliceStruct:
ptrSlice.Set(reflect.Append(ptrSlice, newStruct))
}
}
if singular && !foundOne {
if bkind == kindStruct && !foundOne {
return sql.ErrNoRows
}

View file

@ -228,6 +228,49 @@ func TestGetBoilTag(t *testing.T) {
}
}
func TestBindChecks(t *testing.T) {
t.Parallel()
type useless struct {
}
var tests = []struct {
BKind bindKind
Fail bool
Obj interface{}
}{
{BKind: kindStruct, Fail: false, Obj: &useless{}},
{BKind: kindSliceStruct, Fail: false, Obj: &[]useless{}},
{BKind: kindPtrSliceStruct, Fail: false, Obj: &[]*useless{}},
{Fail: true, Obj: 5},
{Fail: true, Obj: useless{}},
{Fail: true, Obj: []useless{}},
}
for i, test := range tests {
str, sli, bk, err := bindChecks(test.Obj)
if err != nil {
if !test.Fail {
t.Errorf("%d) should not fail, got: %v", i, err)
}
continue
} else if test.Fail {
t.Errorf("%d) should fail, got: %v", i, bk)
continue
}
if s := str.Kind(); s != reflect.Struct {
t.Error("struct kind was wrong:", s)
}
if test.BKind != kindStruct {
if s := sli.Kind(); s != reflect.Slice {
t.Error("slice kind was wrong:", s)
}
}
}
}
func TestBindSingular(t *testing.T) {
t.Parallel()