diff --git a/boil/qm/query_mods.go b/boil/qm/query_mods.go index 959172f..e7fc934 100644 --- a/boil/qm/query_mods.go +++ b/boil/qm/query_mods.go @@ -19,17 +19,14 @@ func SQL(sql string, args ...interface{}) QueryMod { } } -// Limit the number of returned rows -func Limit(limit int) QueryMod { +// Load allows you to specify foreign key relationships to eager load +// for your query. Passed in relationships need to be in the format +// MyThing or MyThings. +// Relationship name plurality is important, if your relationship is +// singular, you need to specify the singular form and vice versa. +func Load(relationships ...string) QueryMod { return func(q *boil.Query) { - boil.SetLimit(q, limit) - } -} - -// Offset into the results -func Offset(offset int) QueryMod { - return func(q *boil.Query) { - boil.SetOffset(q, offset) + boil.SetLoad(q, relationships...) } } @@ -126,6 +123,20 @@ func From(from string) QueryMod { } } +// Limit the number of returned rows +func Limit(limit int) QueryMod { + return func(q *boil.Query) { + boil.SetLimit(q, limit) + } +} + +// Offset into the results +func Offset(offset int) QueryMod { + return func(q *boil.Query) { + boil.SetOffset(q, offset) + } +} + // Count turns the query into a counting calculation func Count() QueryMod { return func(q *boil.Query) { diff --git a/boil/query.go b/boil/query.go index 146791a..1d7a3e5 100644 --- a/boil/query.go +++ b/boil/query.go @@ -20,6 +20,7 @@ const ( type Query struct { executor Executor plainSQL plainSQL + load []string delete bool update map[string]interface{} selectCols []string @@ -113,6 +114,11 @@ func SetSQL(q *Query, sql string, args ...interface{}) { q.plainSQL = plainSQL{sql: sql, args: args} } +// SetLoad on the query. +func SetLoad(q *Query, relationships ...string) { + q.load = append([]string(nil), relationships...) +} + // SetCount on the query. func SetCount(q *Query) { q.modFunction = "COUNT" diff --git a/boil/query_test.go b/boil/query_test.go index 8952628..9721ddf 100644 --- a/boil/query_test.go +++ b/boil/query_test.go @@ -45,6 +45,21 @@ func TestSetSQL(t *testing.T) { } } +func TestSetLoad(t *testing.T) { + t.Parallel() + + q := &Query{} + SetLoad(q, "one", "two") + + if len(q.load) != 2 { + t.Errorf("Expected len 2, got %d", len(q.load)) + } + + if q.load[0] != "one" || q.load[1] != "two" { + t.Errorf("Was not expected string, got %s", q.load) + } +} + func TestAppendWhere(t *testing.T) { t.Parallel() diff --git a/boil/reflect.go b/boil/reflect.go index e58dccd..5db1e19 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -16,7 +16,7 @@ var ( // BindP executes the query and inserts the // result into the passed in object pointer. -// It panics on error. +// It panics on error. See boil.Bind() documentation. func (q *Query) BindP(obj interface{}) { if err := q.Bind(obj); err != nil { panic(WrapErr(err)) @@ -27,12 +27,13 @@ func (q *Query) BindP(obj interface{}) { // result into the passed in object pointer // // Bind rules: -// - Struct tags control bind, in the form of: `boil:"name,bind"` -// - If the "name" part of the struct tag is specified, that will be used -// for binding instead of the snake_cased field name. -// - If the ",bind" option is specified on an struct field, it will be recursed -// into to look for fields for binding. -// - If the name of the struct tag is "-", this field will not be bound to +// - Struct tags control bind, in the form of: `boil:"name,bind"` +// - If the "name" part of the struct tag is specified, that will be used +// - f the "name" part of the tag is specified, it is used for binding +// (the columns returned are title cased and matched). +// - If the ",bind" option is specified on an struct field, it will be recursed +// into to look for fields for binding. +// - If the name of the struct tag is "-", this field will not be bound to // // Example Query: // @@ -49,14 +50,105 @@ func (q *Query) BindP(obj interface{}) { // } // // models.Users(qm.InnerJoin("users as friend on users.friend_id = friend.id")).Bind(&joinStruct) +// +// For custom objects that want to use eager loading, please see the +// loadRelationships function. +func Bind(rows *sql.Rows, obj interface{}) error { + structType, sliceType, singular, err := bindChecks(obj) + + if err != nil { + return err + } + + return bind(rows, obj, structType, sliceType, singular) +} + +// Bind executes the query and inserts the +// result into the passed in object pointer +// +// See documentation for boil.Bind() func (q *Query) Bind(obj interface{}) error { + structType, sliceType, singular, err := bindChecks(obj) + if err != nil { + return err + } + + rows, err := ExecQueryAll(q) + if err != nil { + return errors.Wrap(err, "bind failed to execute query") + } + defer rows.Close() + + if res := bind(rows, obj, structType, sliceType, singular); res != nil { + return res + } + + if len(q.load) == 0 { + return nil + } + + return q.loadRelationships(obj, singular) +} + +// loadRelationships dynamically calls the template generated eager load +// functions of the form: +// +// func (t *TableRelationships) LoadRelationshipName(exec Executor, singular bool, obj interface{}) +// +// The arguments to this function are: +// - t is not considered here, and is always passed nil. The function exists on a relationships +// struct to avoid a circular dependency with boil, and the receiver is ignored. +// - exec is used to perform additional queries that might be required for loading the relationships. +// - singular is passed in to identify whether or not this was a single object +// or a slice that must be loaded into. +// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind. +func (q *Query) loadRelationships(obj interface{}, singular bool) error { + typ := reflect.TypeOf(obj).Elem() + if !singular { + typ = typ.Elem() + } + + rel, found := typ.FieldByName("Relationships") + // If the users object has no Relationships struct, it must be + // a custom object and we should not attempt to load any relationships. + if !found { + return errors.New("load query mod was used but bound struct contained no relationship field") + } + + for _, relationship := range q.load { + // Attempt to find the LoadRelationshipName function + loadMethod, found := rel.Type.MethodByName("Load" + relationship) + if !found { + return errors.Errorf("could not find Load%s method for eager loading", relationship) + } + + execArg := reflect.ValueOf(q.executor) + if !execArg.IsValid() { + execArg = reflect.ValueOf((*sql.DB)(nil)) + } + + methodArgs := []reflect.Value{ + reflect.Indirect(reflect.New(rel.Type)), + execArg, + reflect.ValueOf(singular), + reflect.ValueOf(obj), + } + + resp := loadMethod.Func.Call(methodArgs) + if resp[0].Interface() != nil { + return resp[0].Interface().(error) + } + } + + return nil +} + +// bindChecks resolves information about the bind target, and errors if it's not an object +// we can bind to. +func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, singular bool, err error) { typ := reflect.TypeOf(obj) kind := typ.Kind() - var structType reflect.Type - var sliceType reflect.Type - var singular bool - for i := 0; i < len(bindAccepts); i++ { exp := bindAccepts[i] @@ -72,7 +164,7 @@ func (q *Query) Bind(obj interface{}) error { break } - return errors.Errorf("obj type should be *[]*Type or *Type but was %q", reflect.TypeOf(obj).String()) + return nil, nil, false, errors.Errorf("obj type should be *[]*Type or *Type but was %q", reflect.TypeOf(obj).String()) } switch kind { @@ -83,16 +175,10 @@ func (q *Query) Bind(obj interface{}) error { } } - return bind(q, obj, structType, sliceType, singular) + return structType, sliceType, singular, nil } -func bind(q *Query, obj interface{}, structType, sliceType reflect.Type, singular bool) error { - rows, err := ExecQueryAll(q) - if err != nil { - return errors.Wrap(err, "bind failed to execute query") - } - defer rows.Close() - +func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, singular bool) error { cols, err := rows.Columns() if err != nil { return errors.Wrap(err, "bind failed to get column names") @@ -228,7 +314,7 @@ func GetStructValues(obj interface{}, columns ...string) []interface{} { for i, c := range columns { field := val.FieldByName(strmangle.TitleCase(c)) if !field.IsValid() { - panic(fmt.Sprintf("Unable to find field with name: %s\n%#v", strmangle.TitleCase(c), obj)) + panic(fmt.Sprintf("unable to find field with name: %s\n%#v", strmangle.TitleCase(c), obj)) } ret[i] = field.Interface() } @@ -236,6 +322,24 @@ func GetStructValues(obj interface{}, columns ...string) []interface{} { return ret } +// GetSliceValues returns the values (as interface) of the matching columns in obj. +func GetSliceValues(slice []interface{}, columns ...string) []interface{} { + ret := make([]interface{}, len(slice)*len(columns)) + + for i, obj := range slice { + val := reflect.Indirect(reflect.ValueOf(obj)) + for j, c := range columns { + field := val.FieldByName(strmangle.TitleCase(c)) + if !field.IsValid() { + panic(fmt.Sprintf("unable to find field with name: %s\n%#v", strmangle.TitleCase(c), obj)) + } + ret[i*len(columns)+j] = field.Interface() + } + } + + return ret +} + // GetStructPointers returns a slice of pointers to the matching columns in obj func GetStructPointers(obj interface{}, columns ...string) []interface{} { val := reflect.ValueOf(obj).Elem() diff --git a/boil/reflect_test.go b/boil/reflect_test.go index b26ab5e..6409887 100644 --- a/boil/reflect_test.go +++ b/boil/reflect_test.go @@ -103,6 +103,53 @@ func TestBindSingular(t *testing.T) { } } +var loadFunctionCalled bool + +type testRelationshipsStruct struct{} + +func (r *testRelationshipsStruct) LoadTestOne(exec Executor, singular bool, obj interface{}) error { + loadFunctionCalled = true + return nil +} + +func TestLoadRelationshipsSlice(t *testing.T) { + // t.Parallel() Function uses globals + loadFunctionCalled = false + + testSlice := []*struct { + ID int + Relationships *testRelationshipsStruct + }{} + + q := Query{load: []string{"TestOne"}, executor: nil} + if err := q.loadRelationships(testSlice, false); err != nil { + t.Error(err) + } + + if !loadFunctionCalled { + t.Errorf("Load function was not called for testSlice") + } +} + +func TestLoadRelationshipsSingular(t *testing.T) { + // t.Parallel() Function uses globals + loadFunctionCalled = false + + testSingular := &struct { + ID int + Relationships *testRelationshipsStruct + }{} + + q := Query{load: []string{"TestOne"}, executor: nil} + if err := q.loadRelationships(testSingular, true); err != nil { + t.Error(err) + } + + if !loadFunctionCalled { + t.Errorf("Load function was not called for singular") + } +} + func TestBind_InnerJoin(t *testing.T) { t.Parallel() @@ -352,6 +399,36 @@ func TestGetStructValues(t *testing.T) { } } +func TestGetSliceValues(t *testing.T) { + t.Parallel() + + o := []struct { + ID int + Name string + }{ + {5, "a"}, + {6, "b"}, + } + + in := make([]interface{}, len(o)) + in[0] = o[0] + in[1] = o[1] + + vals := GetSliceValues(in, "id", "name") + if got := vals[0].(int); got != 5 { + t.Error(got) + } + if got := vals[1].(string); got != "a" { + t.Error(got) + } + if got := vals[2].(int); got != 6 { + t.Error(got) + } + if got := vals[3].(string); got != "b" { + t.Error(got) + } +} + func TestGetStructPointers(t *testing.T) { t.Parallel() diff --git a/boil/testing.go b/boil/testing.go index 70e889d..61ec9b4 100644 --- a/boil/testing.go +++ b/boil/testing.go @@ -98,6 +98,11 @@ func (s *Seed) RandomizeStruct(str interface{}, colTypes map[string]string, canB continue } + tagVal, _ := getBoilTag(fieldTyp) + if tagVal == "-" { + continue + } + fieldDBType := colTypes[fieldTyp.Name] if err := s.randomizeField(fieldVal, fieldDBType, canBeNull); err != nil { return err diff --git a/templates.go b/templates.go index 5f402dd..71381e8 100644 --- a/templates.go +++ b/templates.go @@ -157,4 +157,5 @@ var templateFunctions = template.FuncMap{ "sqlColDefinitions": bdb.SQLColDefinitions, "columnNames": bdb.ColumnNames, "columnDBTypes": bdb.ColumnDBTypes, + "getTable": bdb.GetTable, } diff --git a/templates/00_struct.tpl b/templates/00_struct.tpl index 089689f..267d917 100644 --- a/templates/00_struct.tpl +++ b/templates/00_struct.tpl @@ -11,7 +11,7 @@ type {{$modelName}} struct { {{end -}} {{- if .Table.IsJoinTable -}} {{- else}} - //Relationships *{{$modelName}}Relationships `boil:"-" json:"-" toml:"-" yaml:"-"` + Relationships *{{$modelName}}Relationships `boil:"-" json:"-" toml:"-" yaml:"-"` {{end -}} } diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index 43c3131..cdd7c7b 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -48,7 +48,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{if .ToJoinTable -}} queryMods = append(queryMods, - qm.InnerJoin(`"{{.JoinTable}}" as "{{id 1}}" on "{{id 1}}"."{{.JoinForeignColumn}}" = "{{id 0}}"."{{.ForeignColumn}}"`), + qm.InnerJoin(`"{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}"`), qm.Where(`"{{id 1}}"."{{.JoinLocalColumn}}"=$1`, {{.Column | titleCase | printf "%s.%s" $rel.Function.Receiver }}), ) {{else -}} diff --git a/templates/relationship_to_many_eager.tpl b/templates/relationship_to_many_eager.tpl new file mode 100644 index 0000000..bd0d02a --- /dev/null +++ b/templates/relationship_to_many_eager.tpl @@ -0,0 +1,124 @@ +{{- if .Table.IsJoinTable -}} +{{- else}} +{{- $dot := . -}} +{{- range .Table.ToManyRelationships -}} +{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- template "relationship_to_one_eager_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table .) -}} +{{- else -}} + {{- $rel := textsFromRelationship $dot.Tables $dot.Table . -}} + {{- $arg := printf "maybe%s" $rel.LocalTable.NameGo -}} + {{- $slice := printf "%sSlice" $rel.LocalTable.NameGo}} +// Load{{$rel.Function.Name}} allows an eager lookup of values, cached into the +// relationships structs of the objects. +func (r *{{$rel.LocalTable.NameGo}}Relationships) Load{{$rel.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { + var slice []*{{$rel.LocalTable.NameGo}} + var object *{{$rel.LocalTable.NameGo}} + + count := 1 + if singular { + object = {{$arg}}.(*{{$rel.LocalTable.NameGo}}) + } else { + slice = {{$arg}}.({{$slice}}) + count = len(slice) + } + + args := make([]interface{}, count) + if singular { + args[0] = object.{{.Column | titleCase}} + } else { + for i, obj := range slice { + args[i] = obj.{{.Column | titleCase}} + } + } + + {{if .ToJoinTable -}} + query := fmt.Sprintf( + `select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from "{{.ForeignTable}}" as "{{id 0}}" inner join "{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, + strmangle.Placeholders(count, 1, 1), + ) + {{else -}} + query := fmt.Sprintf( + `select * from "{{.ForeignTable}}" where "{{.ForeignColumn}}" in (%s)`, + strmangle.Placeholders(count, 1, 1), + ) + {{end -}} + + if boil.DebugMode { + fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) + } + + results, err := e.Query(query, args...) + if err != nil { + return errors.Wrap(err, "failed to eager load {{.ForeignTable}}") + } + defer results.Close() + + var resultSlice []*{{$rel.ForeignTable.NameGo}} + {{if .ToJoinTable -}} + {{- $foreignTable := getTable $dot.Tables .ForeignTable -}} + {{- $joinTable := getTable $dot.Tables .JoinTable -}} + {{- $localCol := $joinTable.GetColumn .JoinLocalColumn}} + var localJoinCols []{{$localCol.Type}} + for results.Next() { + one := new({{$rel.ForeignTable.NameGo}}) + var localJoinCol {{$localCol.Type}} + + err = results.Scan({{$foreignTable.Columns | columnNames | stringMap $dot.StringFuncs.titleCase | prefixStringSlice "&one." | join ", "}}, &localJoinCol) + if err = results.Err(); err != nil { + return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") + } + + resultSlice = append(resultSlice, one) + localJoinCols = append(localJoinCols, localJoinCol) + } + + if err = results.Err(); err != nil { + return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") + } + {{else -}} + if err = boil.Bind(results, &resultSlice); err != nil { + return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}") + } + {{end}} + + if singular { + if object.Relationships == nil { + object.Relationships = &{{$rel.LocalTable.NameGo}}Relationships{} + } + object.Relationships.{{$rel.Function.Name}} = resultSlice + return nil + } + + {{if .ToJoinTable -}} + for i, foreign := range resultSlice { + localJoinCol := localJoinCols[i] + for _, local := range slice { + if local.{{$rel.Function.LocalAssignment}} == localJoinCol { + if local.Relationships == nil { + local.Relationships = &{{$rel.LocalTable.NameGo}}Relationships{} + } + local.Relationships.{{$rel.Function.Name}} = append(local.Relationships.{{$rel.Function.Name}}, foreign) + break + } + } + } + {{else -}} + for _, foreign := range resultSlice { + for _, local := range slice { + if local.{{$rel.Function.LocalAssignment}} == foreign.{{$rel.Function.ForeignAssignment}} { + if local.Relationships == nil { + local.Relationships = &{{$rel.LocalTable.NameGo}}Relationships{} + } + local.Relationships.{{$rel.Function.Name}} = append(local.Relationships.{{$rel.Function.Name}}, foreign) + break + } + } + } + {{end}} + + return nil +} + +{{end -}}{{/* if ForeignColumnUnique */}} +{{- end -}}{{/* range tomany */}} +{{- end -}}{{/* if isjointable */}} diff --git a/templates/relationship_to_one_eager.tpl b/templates/relationship_to_one_eager.tpl new file mode 100644 index 0000000..6cc0475 --- /dev/null +++ b/templates/relationship_to_one_eager.tpl @@ -0,0 +1,77 @@ +{{- define "relationship_to_one_eager_helper" -}} + {{- $arg := printf "maybe%s" .LocalTable.NameGo -}} + {{- $slice := printf "%sSlice" .LocalTable.NameGo}} +// Load{{.Function.Name}} allows an eager lookup of values, cached into the +// relationships structs of the objects. +func (r *{{.LocalTable.NameGo}}Relationships) Load{{.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { + var slice []*{{.LocalTable.NameGo}} + var object *{{.LocalTable.NameGo}} + + count := 1 + if singular { + object = {{$arg}}.(*{{.LocalTable.NameGo}}) + } else { + slice = {{$arg}}.({{$slice}}) + count = len(slice) + } + + args := make([]interface{}, count) + if singular { + args[0] = object.{{.LocalTable.ColumnNameGo}} + } else { + for i, obj := range slice { + args[i] = obj.{{.LocalTable.ColumnNameGo}} + } + } + + query := fmt.Sprintf( + `select * from "{{.ForeignKey.ForeignTable}}" where "{{.ForeignKey.ForeignColumn}}" in (%s)`, + strmangle.Placeholders(count, 1, 1), + ) + + if boil.DebugMode { + fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) + } + + results, err := e.Query(query, args...) + if err != nil { + return errors.Wrap(err, "failed to eager load {{.ForeignTable.NameGo}}") + } + defer results.Close() + + var resultSlice []*{{.ForeignTable.NameGo}} + if err = boil.Bind(results, &resultSlice); err != nil { + return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}") + } + + if singular && len(resultSlice) != 0 { + if object.Relationships == nil { + object.Relationships = &{{.LocalTable.NameGo}}Relationships{} + } + object.Relationships.{{.Function.Name}} = resultSlice[0] + return nil + } + + for _, foreign := range resultSlice { + for _, local := range slice { + if local.{{.Function.LocalAssignment}} == foreign.{{.Function.ForeignAssignment}} { + if local.Relationships == nil { + local.Relationships = &{{.LocalTable.NameGo}}Relationships{} + } + local.Relationships.{{.Function.Name}} = foreign + break + } + } + } + + return nil +} +{{- end -}} +{{- if .Table.IsJoinTable -}} +{{- else -}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} +{{- template "relationship_to_one_eager_helper" $rel -}} +{{end -}} +{{- end -}} diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index 138f703..21451d9 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -74,6 +74,21 @@ func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { t.Error("expected to find c") } + if err = a.Relationships.Load{{$rel.Function.Name}}(tx, false, {{$rel.LocalTable.NameGo}}Slice{&a}); err != nil { + t.Fatal(err) + } + if got := len(a.Relationships.{{$rel.Function.Name}}); got != 2 { + t.Error("number of eager loaded records wrong, got:", got) + } + + a.Relationships.{{$rel.Function.Name}} = nil + if err = a.Relationships.Load{{$rel.Function.Name}}(tx, true, &a); err != nil { + t.Fatal(err) + } + if got := len(a.Relationships.{{$rel.Function.Name}}); got != 2 { + t.Error("number of eager loaded records wrong, got:", got) + } + if t.Failed() { t.Logf("%#v", {{$varname}}) } diff --git a/templates_test/relationship_to_one.tpl b/templates_test/relationship_to_one.tpl index ab4c5f2..237d3af 100644 --- a/templates_test/relationship_to_one.tpl +++ b/templates_test/relationship_to_one.tpl @@ -40,6 +40,21 @@ func test{{.LocalTable.NameGo}}ToOne{{.ForeignTable.NameGo}}_{{.Function.Name}}( if check.{{.Function.ForeignAssignment}} != foreign.{{.Function.ForeignAssignment}} { t.Errorf("want: %v, got %v", foreign.{{.Function.ForeignAssignment}}, check.{{.Function.ForeignAssignment}}) } + + if err = local.Relationships.Load{{.Function.Name}}(tx, false, {{.LocalTable.NameGo}}Slice{&local}); err != nil { + t.Fatal(err) + } + if local.Relationships.{{.Function.Name}} == nil { + t.Error("struct should have been eager loaded") + } + + local.Relationships.{{.Function.Name}} = nil + if err = local.Relationships.Load{{.Function.Name}}(tx, true, &local); err != nil { + t.Fatal(err) + } + if local.Relationships.{{.Function.Name}} == nil { + t.Error("struct should have been eager loaded") + } } {{end -}} diff --git a/templates_test/singleton/boil_test_suite.tpl b/templates_test/singleton/boil_test_suite.tpl index f759357..19900ec 100644 --- a/templates_test/singleton/boil_test_suite.tpl +++ b/templates_test/singleton/boil_test_suite.tpl @@ -136,9 +136,9 @@ func TestInsert(t *testing.T) { {{- end -}} } -// The relationship tests cannot be run in parallel +// TestToMany tests cannot be run in parallel // or postgres deadlocks will occur. -func TestRelationships(t *testing.T) { +func TestToMany(t *testing.T) { {{- $dot := .}} {{- range $index, $table := .Tables}} {{- $tableName := $table.Name | plural | titleCase -}} @@ -147,16 +147,31 @@ func TestRelationships(t *testing.T) { {{- range $table.ToManyRelationships -}} {{- $rel := textsFromRelationship $dot.Tables $table . -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} - {{- $funcName := $rel.LocalTable.NameGo -}} - t.Run("{{$rel.ForeignTable.NameGo}}ToOne", test{{$rel.ForeignTable.NameGo}}ToOne{{$rel.LocalTable.NameGo}}_{{$funcName}}) + {{- $oneToOne := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}} + t.Run("{{$oneToOne.LocalTable.NameGo}}OneToOne{{$oneToOne.ForeignTable.NameGo}}_{{$oneToOne.Function.Name}}", test{{$oneToOne.LocalTable.NameGo}}ToOne{{$oneToOne.ForeignTable.NameGo}}_{{$oneToOne.Function.Name}}) {{else -}} - t.Run("{{$rel.LocalTable.NameGo}}ToMany", test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}) + t.Run("{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}", test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}) {{end -}}{{- /* if unique */ -}} {{- end -}}{{- /* range */ -}} {{- end -}}{{- /* outer if join table */ -}} {{- end -}}{{- /* outer tables range */ -}} } +// TestToOne tests cannot be run in parallel +// or postgres deadlocks will occur. +func TestToOne(t *testing.T) { + {{- $dot := . -}} +{{- range $index, $table := .Tables}} + {{- if $table.IsJoinTable -}} + {{- else -}} + {{- range $table.FKeys -}} + {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $table . -}} + t.Run("{{$rel.LocalTable.NameGo}}To{{$rel.ForeignTable.NameGo}}_{{$rel.Function.Name}}", test{{$rel.LocalTable.NameGo}}ToOne{{$rel.ForeignTable.NameGo}}_{{$rel.Function.Name}}) + {{end -}}{{- /* fkey range */ -}} + {{- end -}}{{- /* if join table */ -}} +{{- end -}}{{- /* tables range */ -}} +} + func TestReload(t *testing.T) { {{- range $index, $table := .Tables}} {{- if $table.IsJoinTable -}}