Initial nested eager load.

This commit is contained in:
Aaron L 2016-08-29 22:21:32 -07:00
parent 8ab617ef71
commit f2b8f39d47
2 changed files with 189 additions and 37 deletions

View file

@ -14,6 +14,11 @@ var (
bindAccepts = []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct}
)
const (
loadMethodPrefix = "Load"
relationshipStructName = "R"
)
// BindP executes the query and inserts the
// result into the passed in object pointer.
// It panics on error. See boil.Bind() documentation.
@ -100,11 +105,13 @@ func (q *Query) BindFast(obj interface{}, titleCases map[string]string) error {
return res
}
if len(q.load) == 0 {
return nil
for _, toLoad := range q.load {
toLoadFragments := strings.Split(toLoad, ".")
if err = loadRelationships(q.executor, toLoadFragments, obj, singular); err != nil {
return err
}
return q.loadRelationships(obj, singular)
}
return nil
}
// loadRelationships dynamically calls the template generated eager load
@ -119,33 +126,35 @@ func (q *Query) BindFast(obj interface{}, titleCases map[string]string) error {
// - singular is passed in to identify whether or not this was a single object
// or a slice that must be loaded into.
// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind.
func (q *Query) loadRelationships(obj interface{}, singular bool) error {
//
// It takes list of nested relationships to load.
func loadRelationships(exec Executor, toLoad []string, obj interface{}, singular bool) error {
typ := reflect.TypeOf(obj).Elem()
if !singular {
typ = typ.Elem().Elem()
}
rel, found := typ.FieldByName("R")
// If the users object has no loaded struct, it must be
// a custom object and we should not attempt to load any relationships.
current := toLoad[0]
r, found := typ.FieldByName(relationshipStructName)
// It's possible a Relationship struct doesn't exist on the struct.
if !found {
return errors.New("load query mod was used but bound struct contained no R field")
return errors.Errorf("attempted to load %s but no R struct was found", current)
}
for _, relationship := range q.load {
// Attempt to find the LoadRelationshipName function
loadMethod, found := rel.Type.MethodByName("Load" + relationship)
loadMethod, found := r.Type.MethodByName(loadMethodPrefix + current)
if !found {
return errors.Errorf("could not find Load%s method for eager loading", relationship)
return errors.Errorf("could not find %s%s method for eager loading", loadMethodPrefix, current)
}
execArg := reflect.ValueOf(q.executor)
// Hack to allow nil executors
execArg := reflect.ValueOf(exec)
if !execArg.IsValid() {
execArg = reflect.ValueOf((*sql.DB)(nil))
}
methodArgs := []reflect.Value{
reflect.Indirect(reflect.New(rel.Type)),
reflect.Indirect(reflect.New(r.Type)),
execArg,
reflect.ValueOf(singular),
reflect.ValueOf(obj),
@ -153,13 +162,54 @@ func (q *Query) loadRelationships(obj interface{}, singular bool) error {
resp := loadMethod.Func.Call(methodArgs)
if resp[0].Interface() != nil {
return resp[0].Interface().(error)
return errors.Wrapf(resp[0].Interface().(error), "failed to eager load %s", current)
}
// Pull one off the queue, continue if there's still some to go
toLoad = toLoad[1:]
if len(toLoad) == 0 {
return nil
}
loadedObject := reflect.ValueOf(obj)
// If we eagerly loaded nothing
if loadedObject.IsNil() {
return nil
}
loadedObject = loadedObject.Elem()
// If it's singular we can just immediately call without looping
if singular {
return loadRelationshipsRecurse(exec, current, toLoad, singular, loadedObject)
}
// Loop over all eager loaded objects
ln := loadedObject.Len()
if ln == 0 {
return nil
}
for i := 0; i < ln; i++ {
iter := loadedObject.Index(i).Elem()
if err := loadRelationshipsRecurse(exec, current, toLoad, singular, iter); err != nil {
return err
}
}
return nil
}
// loadRelationshipsRecurse is a helper function for taking a reflect.Value and
// Basically calls loadRelationships with: obj.R.EagerLoadedObj, and whether it's a string or slice
func loadRelationshipsRecurse(exec Executor, current string, toLoad []string, singular bool, obj reflect.Value) error {
r := obj.FieldByName(relationshipStructName)
if !r.IsValid() || r.IsNil() {
return errors.Errorf("could not traverse into loaded %s relationship to load more things", current)
}
newObj := r.Elem().FieldByName(current)
singular = newObj.Elem().Kind() == reflect.Struct
return loadRelationships(exec, toLoad, newObj.Interface(), singular)
}
// bindChecks resolves information about the bind target, and errors if it's not an object
// we can bind to.
func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, singular bool, err error) {

View file

@ -165,14 +165,64 @@ func TestBindSingular(t *testing.T) {
}
var loadFunctionCalled bool
var loadFunctionNestedCalled int
type testRStruct struct{}
type testRStruct struct {
}
type testNestedStruct struct {
ID int
R *testNestedRStruct
}
type testNestedRStruct struct {
ToEagerLoad *testNestedStruct
}
type testNestedSlice struct {
ID int
R *testNestedRSlice
}
type testNestedRSlice struct {
ToEagerLoad *[]*testNestedSlice
}
func (r *testRStruct) LoadTestOne(exec Executor, singular bool, obj interface{}) error {
loadFunctionCalled = true
return nil
}
func (r *testNestedRStruct) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error {
switch x := obj.(type) {
case *testNestedStruct:
x.R = &testNestedRStruct{
&testNestedStruct{ID: 5},
}
case *[]*testNestedStruct:
for _, r := range *x {
r.R = &testNestedRStruct{
&testNestedStruct{ID: 5},
}
}
}
loadFunctionNestedCalled++
return nil
}
func (r *testNestedRSlice) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error {
switch x := obj.(type) {
case *testNestedSlice:
newSlice := []*testNestedSlice{&testNestedSlice{ID: 5}}
x.R = &testNestedRSlice{&newSlice}
case *[]*testNestedSlice:
newSlice := []*testNestedSlice{&testNestedSlice{ID: 5}}
for _, r := range *x {
r.R = &testNestedRSlice{&newSlice}
}
}
loadFunctionNestedCalled++
return nil
}
func TestLoadRelationshipsSlice(t *testing.T) {
// t.Parallel() Function uses globals
loadFunctionCalled = false
@ -182,8 +232,7 @@ func TestLoadRelationshipsSlice(t *testing.T) {
R *testRStruct
}{}
q := Query{load: []string{"TestOne"}, executor: nil}
if err := q.loadRelationships(&testSlice, false); err != nil {
if err := loadRelationships(nil, []string{"TestOne"}, &testSlice, false); err != nil {
t.Error(err)
}
@ -201,8 +250,7 @@ func TestLoadRelationshipsSingular(t *testing.T) {
R *testRStruct
}{}
q := Query{load: []string{"TestOne"}, executor: nil}
if err := q.loadRelationships(&testSingular, true); err != nil {
if err := loadRelationships(nil, []string{"TestOne"}, &testSingular, true); err != nil {
t.Error(err)
}
@ -211,6 +259,60 @@ func TestLoadRelationshipsSingular(t *testing.T) {
}
}
func TestLoadRelationshipsSliceNested(t *testing.T) {
// t.Parallel() Function uses globals
testSlice := []*testNestedStruct{
{
ID: 5,
},
}
loadFunctionNestedCalled = 0
if err := loadRelationships(nil, []string{"ToEagerLoad", "ToEagerLoad", "ToEagerLoad"}, &testSlice, false); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
}
testSliceSlice := []*testNestedSlice{
{
ID: 5,
},
}
loadFunctionNestedCalled = 0
if err := loadRelationships(nil, []string{"ToEagerLoad", "ToEagerLoad", "ToEagerLoad"}, &testSliceSlice, false); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
}
}
func TestLoadRelationshipsSingularNested(t *testing.T) {
// t.Parallel() Function uses globals
testSingular := testNestedStruct{
ID: 5,
}
loadFunctionNestedCalled = 0
if err := loadRelationships(nil, []string{"ToEagerLoad", "ToEagerLoad", "ToEagerLoad"}, &testSingular, true); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
}
testSingularSlice := testNestedSlice{
ID: 5,
}
loadFunctionNestedCalled = 0
if err := loadRelationships(nil, []string{"ToEagerLoad", "ToEagerLoad", "ToEagerLoad"}, &testSingularSlice, true); err != nil {
t.Error(err)
}
if loadFunctionNestedCalled != 3 {
t.Error("Load function was called:", loadFunctionNestedCalled, "times")
}
}
func TestBind_InnerJoin(t *testing.T) {
t.Parallel()