sqlboiler/boil/reflect.go
Aaron L 80c08ce20a Make Bind act on a sql.Rows
- This change was important to be able to call boil.Bind() from
  contexts other than a query. We don't use query mods everywhere, and we
  shouldn't. This allows us to use Bind independently of query mods.
2016-08-16 22:14:07 -07:00

294 lines
6.9 KiB
Go

package boil
import (
"database/sql"
"fmt"
"reflect"
"strings"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/strmangle"
)
// bindAccepts is the expected kind at each level of a *[]*Type bind target,
// outermost first. bindChecks walks obj's type against this sequence.
var bindAccepts = []reflect.Kind{
	reflect.Ptr,    // *[]*Type
	reflect.Slice,  //  []*Type
	reflect.Ptr,    //    *Type
	reflect.Struct, //     Type
}
// BindP is the panicking form of Bind. It executes the query and scans the
// result into obj. If anything fails, it panics with WrapErr(err).
// See boil.Bind() for the supported targets and struct tag rules.
func (q *Query) BindP(obj interface{}) {
	err := q.Bind(obj)
	if err == nil {
		return
	}
	panic(WrapErr(err))
}
// Bind scans the rows of an already-executed query into obj.
//
// obj must be a pointer to a struct (*Type) or a pointer to a slice of struct
// pointers (*[]*Type):
//   - For *[]*Type, a new Type is allocated for every row and appended to the
//     slice.
//   - For *Type, every row is scanned into the same struct, and sql.ErrNoRows
//     is returned if rows yields nothing.
//
// Bind does not close rows; that remains the caller's responsibility.
//
// Bind rules:
//   - Struct tags control bind, in the form of: `boil:"name,bind"`
//   - If the "name" part of the struct tag is specified, that will be used
//     for binding instead of the snake_cased field name.
//   - If the ",bind" option is specified on a struct field, it will be recursed
//     into to look for fields for binding. Pointer fields are only recursed
//     into when they are non-nil; nil pointers are not allocated.
//   - If the name of the struct tag is "-", this field will not be bound to.
//
// Example using rows directly:
//
//	rows, err := db.Query("select * from users")
//	if err != nil {
//		return err
//	}
//	defer rows.Close()
//
//	var users []*models.User
//	if err := boil.Bind(rows, &users); err != nil {
//		return err
//	}
//
// Example struct for a join query:
//
//	type JoinStruct struct {
//		User1 *models.User `boil:"user,bind"`
//		User2 *models.User `boil:"friend,bind"`
//		// RandomData will not be recursed into to look for fields to
//		// bind and will not be bound to because of the - for the name.
//		RandomData myStruct `boil:"-"`
//		// Date will not be recursed into to look for fields to bind because
//		// it does not specify ,bind in the struct tag. But it can be bound to
//		// as it does not specify a - for the name.
//		Date time.Time
//	}
//
//	models.Users(qm.InnerJoin("users as friend on users.friend_id = friend.id")).Bind(&joinStruct)
func Bind(rows *sql.Rows, obj interface{}) error {
	// Validate the target before touching rows so bad targets fail fast.
	structType, sliceType, singular, err := bindChecks(obj)
	if err != nil {
		return err
	}

	return bind(rows, obj, structType, sliceType, singular)
}
// Bind validates obj as a bind target, executes q, and scans the resulting
// rows into obj. The rows are closed before Bind returns.
//
// See boil.Bind() for the supported target types and struct tag rules.
func (q *Query) Bind(obj interface{}) error {
	// Check the target first so an unusable obj never costs a round trip.
	structTyp, sliceTyp, isSingle, checkErr := bindChecks(obj)
	if checkErr != nil {
		return checkErr
	}

	rows, queryErr := ExecQueryAll(q)
	if queryErr != nil {
		return errors.Wrap(queryErr, "bind failed to execute query")
	}
	defer rows.Close()

	return bind(rows, obj, structTyp, sliceTyp, isSingle)
}
// bindChecks validates obj as a bind target and describes its shape.
//
// Accepted targets:
//   - *Type (pointer to a struct): structType = Type, singular = true,
//     sliceType = nil.
//   - *[]*Type (pointer to a slice of struct pointers): structType = Type,
//     sliceType = []*Type, singular = false.
//
// Anything else, including a nil obj, a struct value, *[]Type, or a pointer
// to a non-struct, yields an error.
func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, singular bool, err error) {
	if obj == nil {
		// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic.
		return nil, nil, false, errors.New("obj type should be *[]*Type or *Type but was nil")
	}

	typ := reflect.TypeOf(obj)
	kind := typ.Kind()
	// Walk obj's type one indirection at a time, comparing each level against
	// the expected *[]*Type shape in bindAccepts.
	for i, exp := range bindAccepts {
		if i != 0 {
			typ = typ.Elem()
			kind = typ.Kind()
		}

		if kind != exp {
			// Finding a struct exactly where the slice was expected means obj
			// is *Type, a singular target. Every other mismatch is rejected;
			// the condition must be && so that struct values, *[]Type and
			// pointers to non-structs are not mistaken for singular targets.
			if exp == reflect.Slice && kind == reflect.Struct {
				structType = typ
				singular = true
				break
			}
			return nil, nil, false, errors.Errorf("obj type should be *[]*Type or *Type but was %q", reflect.TypeOf(obj).String())
		}

		switch kind {
		case reflect.Struct:
			structType = typ
		case reflect.Slice:
			sliceType = typ
		}
	}

	return structType, sliceType, singular, nil
}
// bind scans every row of rows into obj, resolving column names to scan
// destinations with bindPtrs. It does not close rows.
//
// When singular is true, obj is expected to be a pointer to a struct. Each row
// is scanned into that same struct, so the last row wins. sql.ErrNoRows is
// returned if rows yielded nothing.
//
// Otherwise obj is expected to be a pointer to a slice of *structType. A new
// structType is allocated per row and appended to the slice. sliceType is
// accepted for symmetry with bindChecks but is not read here.
//
// Errors reported by rows.Err after iteration are returned, so a failed
// iteration is never mistaken for an empty or short result.
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, singular bool) error {
	cols, err := rows.Columns()
	if err != nil {
		return errors.Wrap(err, "bind failed to get column names")
	}

	var ptrSlice reflect.Value
	if !singular {
		ptrSlice = reflect.Indirect(reflect.ValueOf(obj))
	}

	foundOne := false
	for rows.Next() {
		foundOne = true

		var newStruct reflect.Value
		var pointers []interface{}
		if singular {
			pointers, err = bindPtrs(obj, cols...)
		} else {
			newStruct = reflect.New(structType)
			pointers, err = bindPtrs(newStruct.Interface(), cols...)
		}
		if err != nil {
			return err
		}

		if err := rows.Scan(pointers...); err != nil {
			return errors.Wrap(err, "failed to bind pointers to obj")
		}

		if !singular {
			ptrSlice.Set(reflect.Append(ptrSlice, newStruct))
		}
	}

	// rows.Next returns false both at the end of the result set and on error.
	// Only rows.Err distinguishes the two cases.
	if err := rows.Err(); err != nil {
		return errors.Wrap(err, "bind failed to iterate rows")
	}

	if singular && !foundOne {
		return sql.ErrNoRows
	}
	return nil
}
// bindPtrs returns one scan destination inside obj for each entry of cols, in
// the same order. A dotted column name such as "friend.id" is split on "."
// and resolved by findField. It fails on the first column that cannot be
// resolved.
func bindPtrs(obj interface{}, cols ...string) ([]interface{}, error) {
	root := reflect.ValueOf(obj)
	dests := make([]interface{}, 0, len(cols))

	for _, col := range cols {
		dest, found := findField(strings.Split(col, "."), root)
		if !found {
			return nil, errors.Errorf("bindPtrs failed to find field %s", col)
		}
		dests = append(dests, dest)
	}

	return dests, nil
}
// findField resolves a column path (a column name already split on ".") to a
// scan destination inside v. v may be a struct or a non-nil pointer to one.
// It reports false when no bindable field matches or when v or names is
// empty.
//
// Fields of v are examined in declaration order:
//   - a field whose boil tag name is "-" is ignored;
//   - a field tagged ",bind" is searched recursively. If its name matches
//     names[0], that component is consumed first; otherwise the full path is
//     searched inside it. Nil pointer fields are not descended into.
//   - otherwise a field matches when its name equals TitleCase(names[0]) and
//     names has no further components.
//
// For a matching non-pointer field, a pointer to the field is returned. For a
// pointer field, the pointer value itself is returned.
func findField(names []string, v reflect.Value) (interface{}, bool) {
	if !v.IsValid() || len(names) == 0 {
		return nil, false
	}

	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return nil, false
		}
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return nil, false
	}

	name := strmangle.TitleCase(names[0])
	typ := v.Type()
	n := typ.NumField()
	for i := 0; i < n; i++ {
		fieldName, recurse := getBoilTag(typ.Field(i))
		if fieldName == "-" {
			continue
		}

		if recurse {
			// Use a separate slice for the recursive path. Reassigning names
			// here would leak the shortened path into the leaf check below and
			// into later fields whenever this recursion fails (e.g. a nil
			// pointer field). That would bind "user.id" to the User pointer
			// itself.
			rest := names
			if fieldName == name {
				rest = names[1:]
			}
			if ptr, ok := findField(rest, v.Field(i)); ok {
				return ptr, true
			}
		}

		if fieldName != name || len(names) > 1 {
			continue
		}

		fieldVal := v.Field(i)
		// Unexported fields cannot be handed out via reflection (Interface
		// would panic), so treat them as non-matching.
		if !fieldVal.CanInterface() {
			continue
		}
		if fieldVal.Kind() != reflect.Ptr {
			return fieldVal.Addr().Interface(), true
		}
		return fieldVal.Interface(), true
	}

	return nil, false
}
// getBoilTag parses the `boil:"name,bind"` struct tag of field.
//
// name is the TitleCased tag name. It falls back to the Go field name when the
// tag is absent or its name part is empty. recurse is true only when the
// option immediately after the name is exactly "bind".
func getBoilTag(field reflect.StructField) (name string, recurse bool) {
	if tag := field.Tag.Get("boil"); tag != "" {
		// Only the name and the first option are inspected, so splitting into
		// at most three parts is enough.
		parts := strings.SplitN(tag, ",", 3)
		name = strmangle.TitleCase(parts[0])
		recurse = len(parts) >= 2 && parts[1] == "bind"
	}

	if name == "" {
		name = field.Name
	}
	return name, recurse
}
// GetStructValues returns the values of obj's fields that correspond to
// columns, in the same order as columns. Each column is mapped to a field
// name with strmangle.TitleCase. obj may be a struct or a pointer to one.
// It panics if any column has no matching field.
func GetStructValues(obj interface{}, columns ...string) []interface{} {
	structVal := reflect.Indirect(reflect.ValueOf(obj))
	values := make([]interface{}, 0, len(columns))

	for _, col := range columns {
		fieldName := strmangle.TitleCase(col)
		fieldVal := structVal.FieldByName(fieldName)
		if !fieldVal.IsValid() {
			panic(fmt.Sprintf("Unable to find field with name: %s\n%#v", fieldName, obj))
		}
		values = append(values, fieldVal.Interface())
	}

	return values
}
// GetStructPointers returns a slice of pointers to the matching columns in obj
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
val := reflect.ValueOf(obj).Elem()
var ln int
var getField func(reflect.Value, int) reflect.Value
if len(columns) == 0 {
ln = val.NumField()
getField = func(v reflect.Value, i int) reflect.Value {
return v.Field(i)
}
} else {
ln = len(columns)
getField = func(v reflect.Value, i int) reflect.Value {
return v.FieldByName(strmangle.TitleCase(columns[i]))
}
}
ret := make([]interface{}, ln)
for i := 0; i < ln; i++ {
field := getField(val, i)
if !field.IsValid() {
// Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually
// produce an Invalid value, so we make a hopefully safe assumption here.
panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i])))
}
ret[i] = field.Addr().Interface()
}
return ret
}