Rewrite bind.
- Break bind down into separate functions - Implement naming inference in bind for multiple tables - Make One and All take the same code path mostly
This commit is contained in:
parent
1236072f05
commit
ee2b4e7990
3 changed files with 285 additions and 76 deletions
|
@ -60,8 +60,17 @@ func ExecStatement(q *Query) (sql.Result, error) {
|
|||
return q.executor.Exec(qs, args...)
|
||||
}
|
||||
|
||||
// ExecQuery executes the query for the All finisher and returns multiple rows
|
||||
func ExecQuery(q *Query) (*sql.Rows, error) {
|
||||
// ExecQuery executes the query for the One finisher and returns a single row
|
||||
func ExecQuery(q *Query) *sql.Row {
|
||||
qs, args := buildQuery(q)
|
||||
if DebugMode {
|
||||
fmt.Fprintln(DebugWriter, qs)
|
||||
}
|
||||
return q.executor.QueryRow(qs, args...)
|
||||
}
|
||||
|
||||
// ExecQueryAll executes the query for the All finisher and returns multiple rows
|
||||
func ExecQueryAll(q *Query) (*sql.Rows, error) {
|
||||
qs, args := buildQuery(q)
|
||||
if DebugMode {
|
||||
fmt.Fprintln(DebugWriter, qs)
|
||||
|
|
253
boil/reflect.go
253
boil/reflect.go
|
@ -1,48 +1,17 @@
|
|||
package boil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/nullbio/sqlboiler/strmangle"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Bind executes the query and inserts the
|
||||
// result into the passed in object pointer
|
||||
func (q *Query) Bind(obj interface{}) error {
|
||||
return nil
|
||||
/*typ := reflect.TypeOf(obj)
|
||||
kind := typ.Kind()
|
||||
|
||||
if kind != reflect.Ptr {
|
||||
return fmt.Errorf("Bind not given a pointer to a slice or struct: %s", typ.String())
|
||||
}
|
||||
|
||||
typ = typ.Elem()
|
||||
kind = typ.Kind()
|
||||
|
||||
if kind == reflect.Struct {
|
||||
row := ExecQueryOne(q)
|
||||
err := BindOne(row, q.selectCols, obj)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err)
|
||||
}
|
||||
} else if kind == reflect.Slice {
|
||||
rows, err := ExecQueryAll(q)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err)
|
||||
}
|
||||
err = BindAll(rows, q.selectCols, obj)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to Bind results to object provided for %s: %s", q.from, err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", typ.String())
|
||||
}
|
||||
|
||||
return nil*/
|
||||
}
|
||||
var (
|
||||
bindAccepts = []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct}
|
||||
)
|
||||
|
||||
// BindP executes the query and inserts the
|
||||
// result into the passed in object pointer.
|
||||
|
@ -53,60 +22,210 @@ func (q *Query) BindP(obj interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
// BindOne inserts the returned row columns into the
|
||||
// passed in object pointer
|
||||
func BindOne(row *sql.Row, selectCols []string, obj interface{}) error {
|
||||
kind := reflect.ValueOf(obj).Kind()
|
||||
if kind != reflect.Ptr {
|
||||
return fmt.Errorf("BindOne given a non-pointer type")
|
||||
}
|
||||
|
||||
pointers := GetStructPointers(obj, selectCols...)
|
||||
if err := row.Scan(pointers...); err != nil {
|
||||
return fmt.Errorf("Unable to scan into pointers: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindAll inserts the returned rows columns into the
|
||||
// passed in slice of object pointers
|
||||
func BindAll(rows *sql.Rows, selectCols []string, obj interface{}) error {
|
||||
ptrSlice := reflect.ValueOf(obj)
|
||||
typ := ptrSlice.Type()
|
||||
ptrSlice = ptrSlice.Elem()
|
||||
// Bind executes the query and inserts the
|
||||
// result into the passed in object pointer
|
||||
//
|
||||
// Bind rules:
|
||||
// - Struct tags control bind, in the form of: `boil:"name,bind"`
|
||||
// - If the "name" part of the struct tag is specified, that will be used
|
||||
// for binding instead of the snake_cased field name.
|
||||
// - If the ",bind" option is specified on an struct field, it will be recursed
|
||||
// into to look for fields for binding.
|
||||
// - If the name of the struct tag is "-", this field will not be bound to
|
||||
//
|
||||
// Example Query:
|
||||
//
|
||||
// type JoinStruct struct {
|
||||
// User1 *models.User `boil:"user,bind"`
|
||||
// User2 *models.User `boil:"friend,bind"`
|
||||
// // RandomData will not be recursed into to look for fields to
|
||||
// // bind and will not be bound to because of the - for the name.
|
||||
// RandomData myStruct `boil:"-"`
|
||||
// // Date will not be recursed into to look for fields to bind because
|
||||
// // it does not specify ,bind in the struct tag. But it can be bound to
|
||||
// // as it does not specify a - for the name.
|
||||
// Date time.Time
|
||||
// }
|
||||
//
|
||||
// models.Users(qm.InnerJoin("users as friend on users.friend_id = friend.id")).Bind(&joinStruct)
|
||||
func (q *Query) Bind(obj interface{}) error {
|
||||
typ := reflect.TypeOf(obj)
|
||||
kind := typ.Kind()
|
||||
|
||||
var structTyp reflect.Type
|
||||
var structType reflect.Type
|
||||
var sliceType reflect.Type
|
||||
var singular bool
|
||||
|
||||
for i := 0; i < len(bindAccepts); i++ {
|
||||
exp := bindAccepts[i]
|
||||
|
||||
for i, exp := range []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct} {
|
||||
if i != 0 {
|
||||
typ = typ.Elem()
|
||||
kind = typ.Kind()
|
||||
}
|
||||
|
||||
if kind != exp {
|
||||
return fmt.Errorf("[%d] BindAll object type should be *[]*Type but was: %s", i, ptrSlice.Type().String())
|
||||
if exp == reflect.Slice || kind == reflect.Struct {
|
||||
structType = typ
|
||||
singular = true
|
||||
break
|
||||
}
|
||||
|
||||
if kind == reflect.Struct {
|
||||
structTyp = typ
|
||||
return errors.Errorf("obj type should be *[]*Type or *Type but was %q", reflect.TypeOf(obj).String())
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Struct:
|
||||
structType = typ
|
||||
case reflect.Slice:
|
||||
sliceType = typ
|
||||
}
|
||||
}
|
||||
|
||||
return bind(q, obj, structType, sliceType, singular)
|
||||
}
|
||||
|
||||
func bind(q *Query, obj interface{}, structType, sliceType reflect.Type, singular bool) error {
|
||||
rows, err := ExecQueryAll(q)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bind failed to execute query")
|
||||
}
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bind failed to get column names")
|
||||
}
|
||||
|
||||
var ptrSlice reflect.Value
|
||||
if !singular {
|
||||
ptrSlice = reflect.ValueOf(obj)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
newStruct := reflect.New(structTyp)
|
||||
pointers := GetStructPointers(newStruct.Interface(), selectCols...)
|
||||
if err := rows.Scan(pointers...); err != nil {
|
||||
return fmt.Errorf("Unable to scan into pointers: %s", err)
|
||||
var newStruct reflect.Value
|
||||
var pointers []interface{}
|
||||
|
||||
if singular {
|
||||
pointers, err = bindPtrs(obj, cols...)
|
||||
} else {
|
||||
newStruct = reflect.New(structType)
|
||||
pointers, err = bindPtrs(newStruct.Interface(), cols...)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rows.Scan(pointers...); err != nil {
|
||||
return errors.Wrap(err, "failed to bind pointers to obj")
|
||||
}
|
||||
|
||||
if !singular {
|
||||
ptrSlice.Set(reflect.Append(ptrSlice, newStruct))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func bindPtrs(obj interface{}, cols ...string) ([]interface{}, error) {
|
||||
v := reflect.ValueOf(obj)
|
||||
ptrs := make([]interface{}, len(cols))
|
||||
|
||||
for i, c := range cols {
|
||||
names := strings.Split(c, ".")
|
||||
|
||||
ptr, ok := findField(names, v)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("bindPtrs failed to find field %s", c)
|
||||
}
|
||||
|
||||
ptrs[i] = ptr
|
||||
}
|
||||
|
||||
return ptrs, nil
|
||||
}
|
||||
|
||||
func findField(names []string, v reflect.Value) (interface{}, bool) {
|
||||
fmt.Println("Names:", names)
|
||||
fmt.Println("Type:", v.Type().String())
|
||||
|
||||
if !v.IsValid() || len(names) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
return nil, false
|
||||
}
|
||||
v = reflect.Indirect(v)
|
||||
}
|
||||
|
||||
if v.Kind() != reflect.Struct {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
name := strmangle.TitleCase(names[0])
|
||||
typ := v.Type()
|
||||
|
||||
fi, ok := typ.FieldByName(name)
|
||||
if ok {
|
||||
fieldName, recurse := getBoilTag(fi)
|
||||
if fieldName != "-" {
|
||||
if recurse {
|
||||
return findField(names[1:], v.FieldByName(name))
|
||||
}
|
||||
|
||||
if len(names) == 1 {
|
||||
field := v.FieldByName(name)
|
||||
if field.Kind() != reflect.Ptr {
|
||||
return field.Addr().Interface(), true
|
||||
}
|
||||
return field.Interface(), true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
n := typ.NumField()
|
||||
for i := 0; i < n; i++ {
|
||||
f := typ.Field(i)
|
||||
fieldName, recurse := getBoilTag(f)
|
||||
|
||||
fmt.Println(name, fieldName, recurse)
|
||||
|
||||
if fieldName == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
if recurse {
|
||||
return findField(names, v.Field(i))
|
||||
}
|
||||
|
||||
if fieldName == name {
|
||||
fieldVal := v.Field(i)
|
||||
if fieldVal.Kind() != reflect.Ptr {
|
||||
return fieldVal.Addr().Interface(), true
|
||||
}
|
||||
return fieldVal.Interface(), true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func getBoilTag(field reflect.StructField) (name string, recurse bool) {
|
||||
tag := field.Tag.Get("boil")
|
||||
|
||||
if len(tag) != 0 {
|
||||
tagTokens := strings.Split(tag, ",")
|
||||
name = strmangle.TitleCase(tagTokens[0])
|
||||
recurse = len(tagTokens) > 1 && tagTokens[1] == "bind"
|
||||
} else {
|
||||
name = field.Name
|
||||
}
|
||||
|
||||
return name, recurse
|
||||
}
|
||||
|
||||
// GetStructValues returns the values (as interface) of the matching columns in obj
|
||||
func GetStructValues(obj interface{}, columns ...string) []interface{} {
|
||||
ret := make([]interface{}, len(columns))
|
||||
|
|
|
@ -11,16 +11,97 @@ func TestBind(t *testing.T) {
|
|||
t.Skip("Not implemented")
|
||||
}
|
||||
|
||||
func TestBindP(t *testing.T) {
|
||||
t.Skip("Not implemented")
|
||||
func TestBindPtrs_Easy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testStruct := struct {
|
||||
ID int `boil:"identifier"`
|
||||
Date time.Time
|
||||
}{}
|
||||
|
||||
cols := []string{"identifier", "date"}
|
||||
ptrs, err := bindPtrs(&testStruct, cols...)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
func TestBindOne(t *testing.T) {
|
||||
t.Skip("Not implemented")
|
||||
if ptrs[0].(*int) != &testStruct.ID {
|
||||
t.Error("id is the wrong pointer")
|
||||
}
|
||||
if ptrs[1].(*time.Time) != &testStruct.Date {
|
||||
t.Error("id is the wrong pointer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindAll(t *testing.T) {
|
||||
t.Skip("Not implemented")
|
||||
func TestBindPtrs_Recursive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testStruct := struct {
|
||||
Happy struct {
|
||||
ID int `boil:"identifier"`
|
||||
}
|
||||
Fun struct {
|
||||
ID int
|
||||
} `boil:",bind"`
|
||||
}{}
|
||||
|
||||
cols := []string{"id", "fun.id"}
|
||||
ptrs, err := bindPtrs(&testStruct, cols...)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if ptrs[0].(*int) != &testStruct.Fun.ID {
|
||||
t.Error("id is the wrong pointer")
|
||||
}
|
||||
if ptrs[1].(*int) != &testStruct.Fun.ID {
|
||||
t.Error("id is the wrong pointer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindPtrs_RecursiveTags(t *testing.T) {
|
||||
testStruct := struct {
|
||||
Happy struct {
|
||||
ID int `boil:"identifier"`
|
||||
} `boil:",bind"`
|
||||
Fun struct {
|
||||
ID int `boil:"identification"`
|
||||
} `boil:",bind"`
|
||||
}{}
|
||||
|
||||
cols := []string{"happy.identifier", "fun.identification"}
|
||||
ptrs, err := bindPtrs(&testStruct, cols...)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if ptrs[0].(*int) != &testStruct.Happy.ID {
|
||||
t.Error("id is the wrong pointer")
|
||||
}
|
||||
if ptrs[1].(*int) != &testStruct.Fun.ID {
|
||||
t.Error("id is the wrong pointer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindPtrs_Ignore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testStruct := struct {
|
||||
ID int `boil:"-"`
|
||||
Happy struct {
|
||||
ID int
|
||||
} `boil:",bind"`
|
||||
}{}
|
||||
|
||||
cols := []string{"id"}
|
||||
ptrs, err := bindPtrs(&testStruct, cols...)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if ptrs[0].(*int) != &testStruct.Happy.ID {
|
||||
t.Error("id is the wrong pointer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStructValues(t *testing.T) {
|
||||
|
|
Loading…
Reference in a new issue