From 8806e76d9f1bbd15068571fa5eb86809df65da3a Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 18 Aug 2016 03:56:00 +1000 Subject: [PATCH] Nearly finished relationship bind helper * If only reflection would be nice --- boil/qm/query_mods.go | 31 +++++++++++++++++-------- boil/query.go | 6 +++++ boil/query_test.go | 15 ++++++++++++ boil/reflect.go | 54 +++++++++++++++++++++++++++++++++++++++++-- boil/reflect_test.go | 49 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 143 insertions(+), 12 deletions(-) diff --git a/boil/qm/query_mods.go b/boil/qm/query_mods.go index 0fe1c12..f723380 100644 --- a/boil/qm/query_mods.go +++ b/boil/qm/query_mods.go @@ -19,17 +19,14 @@ func SQL(sql string, args ...interface{}) QueryMod { } } -// Limit the number of returned rows -func Limit(limit int) QueryMod { +// Load allows you to specify foreign key relationships to eager load +// for your query. Passed in relationships need to be in the format +// MyThing or MyThings. +// Relationship name plurality is important, if your relationship is +// singular, you need to specify the singular form and vice versa. +func Load(relationships ...string) QueryMod { return func(q *boil.Query) { - boil.SetLimit(q, limit) - } -} - -// Offset into the results -func Offset(offset int) QueryMod { - return func(q *boil.Query) { - boil.SetOffset(q, offset) + boil.SetLoad(q, relationships...) } } @@ -99,6 +96,20 @@ func From(from string) QueryMod { } } +// Limit the number of returned rows +func Limit(limit int) QueryMod { + return func(q *boil.Query) { + boil.SetLimit(q, limit) + } +} + +// Offset into the results +func Offset(offset int) QueryMod { + return func(q *boil.Query) { + boil.SetOffset(q, offset) + } +} + // Count turns the query into a counting calculation func Count() QueryMod { return func(q *boil.Query) { diff --git a/boil/query.go b/boil/query.go index 8070c54..698a5cf 100644 --- a/boil/query.go +++ b/boil/query.go @@ -20,6 +20,7 @@ const ( type Query struct { executor Executor plainSQL plainSQL + load []string delete bool update map[string]interface{} selectCols []string @@ -101,6 +102,11 @@ func SetSQL(q *Query, sql string, args ...interface{}) { q.plainSQL = plainSQL{sql: sql, args: args} } +// SetLoad on the query. +func SetLoad(q *Query, relationships ...string) { + q.load = append([]string(nil), relationships...) +} + // SetCount on the query. func SetCount(q *Query) { q.modFunction = "COUNT" diff --git a/boil/query_test.go b/boil/query_test.go index 18ad3f1..1e0ede2 100644 --- a/boil/query_test.go +++ b/boil/query_test.go @@ -45,6 +45,21 @@ func TestSetSQL(t *testing.T) { } } +func TestSetLoad(t *testing.T) { + t.Parallel() + + q := &Query{} + SetLoad(q, "one", "two") + + if len(q.load) != 2 { + t.Errorf("Expected len 2, got %d", len(q.load)) + } + + if q.load[0] != "one" || q.load[1] != "two" { + t.Errorf("Was not expected string, got %s", q.load) + } +} + func TestWhere(t *testing.T) { t.Parallel() diff --git a/boil/reflect.go b/boil/reflect.go index 8cb0f55..4769d39 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -6,6 +6,7 @@ import ( "reflect" "strings" + "github.com/davecgh/go-spew/spew" "github.com/pkg/errors" "github.com/vattle/sqlboiler/strmangle" ) @@ -29,7 +30,8 @@ func (q *Query) BindP(obj interface{}) { // Bind rules: // - Struct tags control bind, in the form of: `boil:"name,bind"` // - If the "name" part of the struct tag is specified, that will be used -// for binding instead of the snake_cased field name. +// - f the "name" part of the tag is specified, it is used for binding +// (the columns returned are title cased and matched). // - If the ",bind" option is specified on an struct field, it will be recursed // into to look for fields for binding. // - If the name of the struct tag is "-", this field will not be bound to @@ -75,7 +77,55 @@ func (q *Query) Bind(obj interface{}) error { } defer rows.Close() - return bind(rows, obj, structType, sliceType, singular) + if res := bind(rows, obj, structType, sliceType, singular); res != nil { + return res + } + + if len(q.load) == 0 { + return nil + } + + return q.loadRelationships(obj, singular) +} + +// loadRelationships calls the template generated eager load functions +// (LoadTableName()) using reflection, to eager load the relationships +// into the users Relationships struct attached to their object. +func (q *Query) loadRelationships(obj interface{}, singular bool) error { + typ := reflect.TypeOf(obj).Elem().Elem() + if !singular { + typ = typ.Elem() + } + + rel, found := typ.FieldByName("Relationships") + // If the users object has no Relationships struct, it must be + // a custom object and we should not attempt to load any relationships. + if !found { + return nil + } + + for _, relationship := range q.load { + // Attempt to find the LoadRelationshipName function + loadMethod, found := rel.Type.MethodByName("Load" + relationship) + if !found { + return errors.Errorf("could not find Load%s method for eager loading", relationship) + } + spew.Dump(reflect.New(rel.Type).Interface().(**testRelationshipsStruct)) + spew.Dump(reflect.Indirect(reflect.New(rel.Type)).Interface().(*testRelationshipsStruct)) + methodArgs := []reflect.Value{ + reflect.New(rel.Type), + reflect.ValueOf(q.executor), + reflect.ValueOf(singular), + reflect.ValueOf(obj), + } + + resp := loadMethod.Func.Call(methodArgs) + if resp[0].Interface() != nil { + return resp[0].Interface().(error) + } + } + + return nil } // bindChecks resolves information about the bind target, and errors if it's not an object diff --git a/boil/reflect_test.go b/boil/reflect_test.go index 846cb03..f96ca5c 100644 --- a/boil/reflect_test.go +++ b/boil/reflect_test.go @@ -103,6 +103,55 @@ func TestBindSingular(t *testing.T) { } } +var loadFunctionCalled bool + +type testRelationshipsStruct struct{} + +func (r *testRelationshipsStruct) LoadTestOne(exec Executor, singular bool, obj interface{}) error { + loadFunctionCalled = true + return nil +} + +func TestLoadRelationships(t *testing.T) { + t.Parallel() + + testSingular := []*struct { + ID int + Relationships *testRelationshipsStruct + }{} + + testSlice := []*struct { + ID int + Relationships *testRelationshipsStruct + }{} + + exec, _, err := sqlmock.New() + if err != nil { + t.Error(err) + } + + q := Query{ + load: []string{"TestOne"}, + executor: exec, + } + + if err := q.loadRelationships(testSlice, true); err != nil { + t.Error(err) + } + + if loadFunctionCalled == false { + t.Errorf("Load function was not called for testSlice") + } + + if err := q.loadRelationships(testSingular, false); err != nil { + t.Error(err) + } + + if loadFunctionCalled == false { + t.Errorf("Load function was not called for testSlice") + } +} + func TestBind_InnerJoin(t *testing.T) { t.Parallel()