Add extra type information everywhere for binding.
This commit is contained in:
parent
988cb2bf04
commit
08dc7a5cc1
4 changed files with 133 additions and 55 deletions
|
@ -52,9 +52,9 @@ func (l loadRelationshipState) buildKey(depth int) string {
|
|||
// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind.
|
||||
//
|
||||
// It takes list of nested relationships to load.
|
||||
func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, singular bool) error {
|
||||
func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, bkind bindKind) error {
|
||||
typ := reflect.TypeOf(obj).Elem()
|
||||
if !singular {
|
||||
if bkind == kindPtrSliceStruct {
|
||||
typ = typ.Elem().Elem()
|
||||
}
|
||||
|
||||
|
@ -79,14 +79,14 @@ func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, sin
|
|||
}
|
||||
|
||||
val := reflect.ValueOf(obj).Elem()
|
||||
if !singular {
|
||||
if bkind == kindPtrSliceStruct {
|
||||
val = val.Index(0).Elem()
|
||||
}
|
||||
|
||||
methodArgs := []reflect.Value{
|
||||
val.FieldByName(loaderStructName),
|
||||
execArg,
|
||||
reflect.ValueOf(singular),
|
||||
reflect.ValueOf(bkind == kindStruct),
|
||||
reflect.ValueOf(obj),
|
||||
}
|
||||
resp := loadMethod.Func.Call(methodArgs)
|
||||
|
@ -111,8 +111,8 @@ func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, sin
|
|||
loadedObject = reflect.Indirect(loadedObject)
|
||||
|
||||
// If it's singular we can just immediately call without looping
|
||||
if singular {
|
||||
return s.loadRelationshipsRecurse(depth, singular, loadedObject)
|
||||
if bkind == kindStruct {
|
||||
return s.loadRelationshipsRecurse(depth, loadedObject)
|
||||
}
|
||||
|
||||
// Loop over all eager loaded objects
|
||||
|
@ -122,7 +122,7 @@ func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, sin
|
|||
}
|
||||
for i := 0; i < ln; i++ {
|
||||
iter := loadedObject.Index(i).Elem()
|
||||
if err := s.loadRelationshipsRecurse(depth, singular, iter); err != nil {
|
||||
if err := s.loadRelationshipsRecurse(depth, iter); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -132,15 +132,16 @@ func (s loadRelationshipState) loadRelationships(depth int, obj interface{}, sin
|
|||
|
||||
// loadRelationshipsRecurse is a helper function for taking a reflect.Value and
|
||||
// Basically calls loadRelationships with: obj.R.EagerLoadedObj, and whether it's a string or slice
|
||||
func (s loadRelationshipState) loadRelationshipsRecurse(depth int, singular bool, obj reflect.Value) error {
|
||||
func (s loadRelationshipState) loadRelationshipsRecurse(depth int, obj reflect.Value) error {
|
||||
r := obj.FieldByName(relationshipStructName)
|
||||
if !r.IsValid() || r.IsNil() {
|
||||
return errors.Errorf("could not traverse into loaded %s relationship to load more things", s.toLoad[depth])
|
||||
}
|
||||
newObj := reflect.Indirect(r).FieldByName(s.toLoad[depth])
|
||||
singular = reflect.Indirect(newObj).Kind() == reflect.Struct
|
||||
if !singular {
|
||||
bkind := kindStruct
|
||||
if reflect.Indirect(newObj).Kind() != reflect.Struct {
|
||||
bkind = kindPtrSliceStruct
|
||||
newObj = newObj.Addr()
|
||||
}
|
||||
return s.loadRelationships(depth, newObj.Interface(), singular)
|
||||
return s.loadRelationships(depth, newObj.Interface(), bkind)
|
||||
}
|
||||
|
|
|
@ -89,7 +89,7 @@ func TestLoadRelationshipsSlice(t *testing.T) {
|
|||
L testLStruct
|
||||
}{{}}
|
||||
|
||||
if err := testFakeState("TestOne").loadRelationships(0, &testSlice, false); err != nil {
|
||||
if err := testFakeState("TestOne").loadRelationships(0, &testSlice, kindPtrSliceStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
|
@ -108,7 +108,7 @@ func TestLoadRelationshipsSingular(t *testing.T) {
|
|||
L testLStruct
|
||||
}{}
|
||||
|
||||
if err := testFakeState("TestOne").loadRelationships(0, &testSingular, true); err != nil {
|
||||
if err := testFakeState("TestOne").loadRelationships(0, &testSingular, kindStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
|
@ -125,7 +125,7 @@ func TestLoadRelationshipsSliceNested(t *testing.T) {
|
|||
},
|
||||
}
|
||||
loadFunctionNestedCalled = 0
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSlice, false); err != nil {
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSlice, kindPtrSliceStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 3 {
|
||||
|
@ -138,7 +138,7 @@ func TestLoadRelationshipsSliceNested(t *testing.T) {
|
|||
},
|
||||
}
|
||||
loadFunctionNestedCalled = 0
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSliceSlice, false); err != nil {
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSliceSlice, kindPtrSliceStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 3 {
|
||||
|
@ -152,7 +152,7 @@ func TestLoadRelationshipsSingularNested(t *testing.T) {
|
|||
ID: 3,
|
||||
}
|
||||
loadFunctionNestedCalled = 0
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingular, true); err != nil {
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingular, kindStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 3 {
|
||||
|
@ -163,7 +163,7 @@ func TestLoadRelationshipsSingularNested(t *testing.T) {
|
|||
ID: 3,
|
||||
}
|
||||
loadFunctionNestedCalled = 0
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingularSlice, true); err != nil {
|
||||
if err := testFakeState("ToEagerLoad", "ToEagerLoad", "ToEagerLoad").loadRelationships(0, &testSingularSlice, kindStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 3 {
|
||||
|
@ -189,7 +189,7 @@ func TestLoadRelationshipsNoReload(t *testing.T) {
|
|||
toLoad: []string{"ToEagerLoad", "ToEagerLoad"},
|
||||
}
|
||||
|
||||
if err := state.loadRelationships(0, &testSingular, true); err != nil {
|
||||
if err := state.loadRelationships(0, &testSingular, kindStruct); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if loadFunctionNestedCalled != 0 {
|
||||
|
|
108
boil/reflect.go
108
boil/reflect.go
|
@ -13,13 +13,20 @@ import (
|
|||
|
||||
var (
|
||||
bindAccepts = []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct}
|
||||
)
|
||||
|
||||
var (
|
||||
mut sync.RWMutex
|
||||
bindingMaps = make(map[string][]uint64)
|
||||
)
|
||||
|
||||
// Identifies what kind of object we're binding to
|
||||
type bindKind int
|
||||
|
||||
const (
|
||||
kindStruct bindKind = iota
|
||||
kindSliceStruct
|
||||
kindPtrSliceStruct
|
||||
)
|
||||
|
||||
const (
|
||||
loadMethodPrefix = "Load"
|
||||
relationshipStructName = "R"
|
||||
|
@ -87,7 +94,7 @@ func Bind(rows *sql.Rows, obj interface{}) error {
|
|||
//
|
||||
// See documentation for boil.Bind()
|
||||
func (q *Query) Bind(obj interface{}) error {
|
||||
structType, sliceType, singular, err := bindChecks(obj)
|
||||
structType, sliceType, bkind, err := bindChecks(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -97,17 +104,21 @@ func (q *Query) Bind(obj interface{}) error {
|
|||
return errors.Wrap(err, "bind failed to execute query")
|
||||
}
|
||||
defer rows.Close()
|
||||
if res := bind(rows, obj, structType, sliceType, singular); res != nil {
|
||||
if res := bind(rows, obj, structType, sliceType, bkind); res != nil {
|
||||
return res
|
||||
}
|
||||
|
||||
if len(q.load) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := loadRelationshipState{
|
||||
exec: q.executor,
|
||||
loaded: map[string]struct{}{},
|
||||
}
|
||||
for _, toLoad := range q.load {
|
||||
state.toLoad = strings.Split(toLoad, ".")
|
||||
if err = state.loadRelationships(0, obj, singular); err != nil {
|
||||
if err = state.loadRelationships(0, obj, bkind); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -116,47 +127,68 @@ func (q *Query) Bind(obj interface{}) error {
|
|||
|
||||
// bindChecks resolves information about the bind target, and errors if it's not an object
|
||||
// we can bind to.
|
||||
func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, singular bool, err error) {
|
||||
func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, bkind bindKind, err error) {
|
||||
typ := reflect.TypeOf(obj)
|
||||
kind := typ.Kind()
|
||||
|
||||
for i := 0; i < len(bindAccepts); i++ {
|
||||
exp := bindAccepts[i]
|
||||
|
||||
if i != 0 {
|
||||
typ = typ.Elem()
|
||||
kind = typ.Kind()
|
||||
}
|
||||
|
||||
if kind != exp {
|
||||
if exp == reflect.Slice || kind == reflect.Struct {
|
||||
structType = typ
|
||||
singular = true
|
||||
break
|
||||
}
|
||||
|
||||
return nil, nil, false, errors.Errorf("obj type should be *[]*Type or *Type but was %q", reflect.TypeOf(obj).String())
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Struct:
|
||||
structType = typ
|
||||
case reflect.Slice:
|
||||
sliceType = typ
|
||||
}
|
||||
setErr := func() {
|
||||
err = errors.Errorf("obj type should be *Type, *[]Type, or *[]*Type but was %q", reflect.TypeOf(obj).String())
|
||||
}
|
||||
|
||||
return structType, sliceType, singular, nil
|
||||
for i := 0; ; i++ {
|
||||
switch i {
|
||||
case 0:
|
||||
if kind != reflect.Ptr {
|
||||
setErr()
|
||||
return
|
||||
}
|
||||
case 1:
|
||||
switch kind {
|
||||
case reflect.Struct:
|
||||
structType = typ
|
||||
bkind = kindStruct
|
||||
return
|
||||
case reflect.Slice:
|
||||
sliceType = typ
|
||||
default:
|
||||
setErr()
|
||||
return
|
||||
}
|
||||
case 2:
|
||||
switch kind {
|
||||
case reflect.Struct:
|
||||
structType = typ
|
||||
bkind = kindSliceStruct
|
||||
return
|
||||
case reflect.Ptr:
|
||||
default:
|
||||
setErr()
|
||||
return
|
||||
}
|
||||
case 3:
|
||||
if kind != reflect.Struct {
|
||||
setErr()
|
||||
return
|
||||
}
|
||||
structType = typ
|
||||
bkind = kindPtrSliceStruct
|
||||
return
|
||||
}
|
||||
|
||||
typ = typ.Elem()
|
||||
kind = typ.Kind()
|
||||
}
|
||||
}
|
||||
|
||||
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, singular bool) error {
|
||||
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, bkind bindKind) error {
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bind failed to get column names")
|
||||
}
|
||||
|
||||
var ptrSlice reflect.Value
|
||||
if !singular {
|
||||
switch bkind {
|
||||
case kindPtrSliceStruct:
|
||||
ptrSlice = reflect.Indirect(reflect.ValueOf(obj))
|
||||
}
|
||||
|
||||
|
@ -185,9 +217,10 @@ func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, s
|
|||
var newStruct reflect.Value
|
||||
var pointers []interface{}
|
||||
|
||||
if singular {
|
||||
switch bkind {
|
||||
case kindStruct:
|
||||
pointers = ptrsFromMapping(reflect.Indirect(reflect.ValueOf(obj)), mapping)
|
||||
} else {
|
||||
case kindPtrSliceStruct:
|
||||
newStruct = reflect.New(structType)
|
||||
pointers = ptrsFromMapping(reflect.Indirect(newStruct), mapping)
|
||||
}
|
||||
|
@ -199,12 +232,13 @@ func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, s
|
|||
return errors.Wrap(err, "failed to bind pointers to obj")
|
||||
}
|
||||
|
||||
if !singular {
|
||||
switch bkind {
|
||||
case kindPtrSliceStruct:
|
||||
ptrSlice.Set(reflect.Append(ptrSlice, newStruct))
|
||||
}
|
||||
}
|
||||
|
||||
if singular && !foundOne {
|
||||
if bkind == kindStruct && !foundOne {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
|
|
|
@ -228,6 +228,49 @@ func TestGetBoilTag(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBindChecks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type useless struct {
|
||||
}
|
||||
|
||||
var tests = []struct {
|
||||
BKind bindKind
|
||||
Fail bool
|
||||
Obj interface{}
|
||||
}{
|
||||
{BKind: kindStruct, Fail: false, Obj: &useless{}},
|
||||
{BKind: kindSliceStruct, Fail: false, Obj: &[]useless{}},
|
||||
{BKind: kindPtrSliceStruct, Fail: false, Obj: &[]*useless{}},
|
||||
{Fail: true, Obj: 5},
|
||||
{Fail: true, Obj: useless{}},
|
||||
{Fail: true, Obj: []useless{}},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
str, sli, bk, err := bindChecks(test.Obj)
|
||||
|
||||
if err != nil {
|
||||
if !test.Fail {
|
||||
t.Errorf("%d) should not fail, got: %v", i, err)
|
||||
}
|
||||
continue
|
||||
} else if test.Fail {
|
||||
t.Errorf("%d) should fail, got: %v", i, bk)
|
||||
continue
|
||||
}
|
||||
|
||||
if s := str.Kind(); s != reflect.Struct {
|
||||
t.Error("struct kind was wrong:", s)
|
||||
}
|
||||
if test.BKind != kindStruct {
|
||||
if s := sli.Kind(); s != reflect.Slice {
|
||||
t.Error("slice kind was wrong:", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindSingular(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue