// Source: sqlboiler/queries/eager_load.go

package queries
import (
"database/sql"
"fmt"
"reflect"
"strings"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/boil"
"github.com/vattle/sqlboiler/strmangle"
)
type loadRelationshipState struct {
exec boil.Executor
2016-09-03 08:53:37 +02:00
loaded map[string]struct{}
toLoad []string
}
// hasLoaded reports whether the relationship path made of the first depth+1
// segments of toLoad has already been eager loaded in this eagerLoad call.
func (l loadRelationshipState) hasLoaded(depth int) bool {
	key := l.buildKey(depth)
	_, found := l.loaded[key]
	return found
}
// setLoaded records that the relationship path made of the first depth+1
// segments of toLoad has been eager loaded.
func (l loadRelationshipState) setLoaded(depth int) {
	key := l.buildKey(depth)
	l.loaded[key] = struct{}{}
}
// buildKey joins the first depth+1 segments of toLoad with "." to form the
// cache key for that relationship path; e.g. depth 1 of [A B C] is "A.B".
func (l loadRelationshipState) buildKey(depth int) string {
	buf := strmangle.GetBuffer()
	// buf.String() below copies the bytes, so the buffer can safely be
	// returned to the pool once the key has been produced.
	defer strmangle.PutBuffer(buf)

	segments := l.toLoad[:depth+1]
	for i := range segments {
		if i > 0 {
			buf.WriteByte('.')
		}
		buf.WriteString(segments[i])
	}

	return buf.String()
}
// eagerLoad loads every relationship named in toLoad into obj.
//
// Each entry of toLoad is a dot-separated path of relationship names, e.g.
//
//	[]string{"Relationship", "Relationship.NestedRelationship"}
//
// Entries are processed in order and share one record of already-loaded
// paths, so the "Relationship" prefix above is loaded only once. The first
// error encountered aborts the whole load and is returned.
func eagerLoad(exec boil.Executor, toLoad []string, obj interface{}, bkind bindKind) error {
	state := loadRelationshipState{
		exec:   exec,
		loaded: make(map[string]struct{}),
	}

	for _, path := range toLoad {
		state.toLoad = strings.Split(path, ".")
		if err := state.loadRelationships(0, obj, bkind); err != nil {
			return err
		}
	}

	return nil
}
// loadRelationships dynamically calls the template generated eager load
// functions of the form:
//
// func (t *TableR) LoadRelationshipName(exec Executor, singular bool, obj interface{})
//
// The arguments to this function are:
// - t is not considered here, and is always passed nil. The function exists on a loaded
// struct to avoid a circular dependency with boil, and the receiver is ignored.
// - exec is used to perform additional queries that might be required for loading the relationships.
// - singular is passed in to identify whether or not this was a single object
// or a slice that must be loaded into.
// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind.
2016-09-05 13:33:18 +02:00
func (l loadRelationshipState) loadRelationships(depth int, obj interface{}, bkind bindKind) error {
2016-09-03 08:53:37 +02:00
typ := reflect.TypeOf(obj).Elem()
if bkind == kindPtrSliceStruct {
2016-09-03 08:53:37 +02:00
typ = typ.Elem().Elem()
}
loadingFrom := reflect.ValueOf(obj)
if loadingFrom.IsNil() {
return nil
}
loadingFrom = reflect.Indirect(loadingFrom)
fmt.Println("load rels", typ.String(), l.toLoad[depth:])
2016-09-05 13:33:18 +02:00
if !l.hasLoaded(depth) {
fmt.Println("!loaded", l.toLoad[depth])
2016-09-05 13:33:18 +02:00
current := l.toLoad[depth]
ln, found := typ.FieldByName(loaderStructName)
2016-09-03 08:53:37 +02:00
// It's possible a Loaders struct doesn't exist on the struct.
if !found {
return errors.Errorf("attempted to load %s but no L struct was found", current)
}
// Attempt to find the LoadRelationshipName function
2016-09-05 13:33:18 +02:00
loadMethod, found := ln.Type.MethodByName(loadMethodPrefix + current)
2016-09-03 08:53:37 +02:00
if !found {
return errors.Errorf("could not find %s%s method for eager loading", loadMethodPrefix, current)
}
// Hack to allow nil executors
2016-09-05 13:33:18 +02:00
execArg := reflect.ValueOf(l.exec)
2016-09-03 08:53:37 +02:00
if !execArg.IsValid() {
execArg = reflect.ValueOf((*sql.DB)(nil))
}
// Get a loader instance from anything we have, *struct, or *[]*struct
val := reflect.Indirect(loadingFrom)
if bkind == kindPtrSliceStruct {
val = reflect.Indirect(val.Index(0))
2016-09-03 08:53:37 +02:00
}
methodArgs := []reflect.Value{
val.FieldByName(loaderStructName),
execArg,
reflect.ValueOf(bkind == kindStruct),
2016-09-03 08:53:37 +02:00
reflect.ValueOf(obj),
}
resp := loadMethod.Func.Call(methodArgs)
if intf := resp[0].Interface(); intf != nil {
return errors.Wrapf(intf.(error), "failed to eager load %s", current)
}
2016-09-05 13:33:18 +02:00
l.setLoaded(depth)
} else {
fmt.Println("!loading", l.toLoad[depth])
2016-09-03 08:53:37 +02:00
}
// Check if we can stop
if depth+1 >= len(l.toLoad) {
2016-09-03 08:53:37 +02:00
return nil
}
// If it's singular we can just immediately call without looping
if bkind == kindStruct {
return l.loadRelationshipsRecurse(depth, loadingFrom)
2016-09-03 08:53:37 +02:00
}
// Loop over all eager loaded objects
ln := loadingFrom.Len()
2016-09-03 08:53:37 +02:00
if ln == 0 {
return nil
}
for i := 0; i < ln; i++ {
iter := reflect.Indirect(loadingFrom.Index(i))
2016-09-05 13:33:18 +02:00
if err := l.loadRelationshipsRecurse(depth, iter); err != nil {
2016-09-03 08:53:37 +02:00
return err
}
}
return nil
}
// loadRelationshipsRecurse is a helper function for taking a reflect.Value and
// Basically calls loadRelationships with: obj.R.EagerLoadedObj, and whether it's a string or slice
2016-09-05 13:33:18 +02:00
func (l loadRelationshipState) loadRelationshipsRecurse(depth int, obj reflect.Value) error {
// Get relationship struct, and if it's good to go, grab the value we just loaded.
relationshipStruct := obj.FieldByName(relationshipStructName)
if !relationshipStruct.IsValid() || relationshipStruct.IsNil() {
2016-09-05 13:33:18 +02:00
return errors.Errorf("could not traverse into loaded %s relationship to load more things", l.toLoad[depth])
2016-09-03 08:53:37 +02:00
}
loadedObject := reflect.Indirect(relationshipStruct).FieldByName(l.toLoad[depth])
fmt.Println("loadRecurse", l.toLoad[depth])
// Pop one off the queue
depth++
bkind := kindStruct
if reflect.Indirect(loadedObject).Kind() != reflect.Struct {
bkind = kindPtrSliceStruct
loadedObject = loadedObject.Addr()
2016-09-03 08:53:37 +02:00
}
return l.loadRelationships(depth, loadedObject.Interface(), bkind)
2016-09-03 08:53:37 +02:00
}