sqlboiler/boil/reflect.go

454 lines
11 KiB
Go
Raw Normal View History

package boil
import (
"database/sql"
"fmt"
"reflect"
"strings"
2016-08-31 09:47:35 +02:00
"sync"
"github.com/pkg/errors"
2016-08-09 09:59:30 +02:00
"github.com/vattle/sqlboiler/strmangle"
)
var (
bindAccepts = []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct}
2016-09-01 06:18:30 +02:00
mut sync.RWMutex
bindingMaps = make(map[string][]uint64)
)
// bindKind classifies the shape of the object that query results are bound
// into. bindChecks computes it from the caller's obj.
type bindKind int

// The bind target shapes recognised by bindChecks.
const (
	kindStruct         bindKind = 0 // obj is *Type
	kindSliceStruct    bindKind = 1 // obj is *[]Type
	kindPtrSliceStruct bindKind = 2 // obj is *[]*Type
)
// Identifiers consumed by the eager-loading code elsewhere in this package.
// Presumably they are the loader method-name prefix and the names of the
// relationship and loader struct fields on generated models (their uses are
// not shown here).
const (
	loadMethodPrefix       = "Load"
	relationshipStructName = "R"
	loaderStructName       = "L"
)

// sentinel terminates an encoded field path in the uint64 mappings built by
// makeStructMappingHelper and decoded by ptrFromMapping. It also serves as
// the 8-bit mask used to extract each field index from such a path.
const sentinel uint64 = 255
// BindP is the panicking variant of Query.Bind: it executes the query and
// binds the result into obj. Any error is wrapped with WrapErr and raised as
// a panic. See boil.Bind() for the binding rules.
func (q *Query) BindP(obj interface{}) {
	err := q.Bind(obj)
	if err == nil {
		return
	}
	panic(WrapErr(err))
}
// Bind executes the query and inserts the
// result into the passed in object pointer
//
// Bind rules:
2016-08-18 03:52:42 +02:00
// - Struct tags control bind, in the form of: `boil:"name,bind"`
2016-08-23 05:14:46 +02:00
// - If "name" is omitted the sql column names that come back are TitleCased
// and matched against the field name.
// - If the "name" part of the struct tag is specified, the given name will
// be used instead of the struct field name for binding.
// - If the "name" of the struct tag is "-", this field will not be bound to.
// - If the ",bind" option is specified on a struct field and that field
// is a struct itself, it will be recursed into to look for fields for binding.
//
// Example Query:
//
// type JoinStruct struct {
2016-08-23 05:14:46 +02:00
// // User1 can have it's struct fields bound to since it specifies
// // ,bind in the struct tag, it will look specifically for
// // fields that are prefixed with "user." returning from the query.
// // For example "user.id" column name will bind to User1.ID
// User1 *models.User `boil:"user,bind"`
2016-08-23 05:14:46 +02:00
// // User2 will follow the same rules as noted above except it will use
// // "friend." as the prefix it's looking for.
// User2 *models.User `boil:"friend,bind"`
// // RandomData will not be recursed into to look for fields to
// // bind and will not be bound to because of the - for the name.
// RandomData myStruct `boil:"-"`
// // Date will not be recursed into to look for fields to bind because
// // it does not specify ,bind in the struct tag. But it can be bound to
// // as it does not specify a - for the name.
// Date time.Time
// }
//
// models.Users(qm.InnerJoin("users as friend on users.friend_id = friend.id")).Bind(&joinStruct)
2016-08-18 03:52:42 +02:00
//
// For custom objects that want to use eager loading, please see the
// loadRelationships function.
func Bind(rows *sql.Rows, obj interface{}) error {
structType, sliceType, singular, err := bindChecks(obj)
if err != nil {
return err
}
2016-09-02 09:09:42 +02:00
return bind(rows, obj, structType, sliceType, singular)
}
// Bind executes the query and inserts the
// result into the passed in object pointer
//
// See documentation for boil.Bind()
func (q *Query) Bind(obj interface{}) error {
structType, sliceType, bkind, err := bindChecks(obj)
if err != nil {
return err
}
rows, err := ExecQueryAll(q)
if err != nil {
return errors.Wrap(err, "bind failed to execute query")
}
defer rows.Close()
if res := bind(rows, obj, structType, sliceType, bkind); res != nil {
return res
}
if len(q.load) == 0 {
return nil
}
2016-09-03 08:53:37 +02:00
state := loadRelationshipState{
exec: q.executor,
loaded: map[string]struct{}{},
}
2016-09-03 08:53:37 +02:00
for _, toLoad := range q.load {
state.toLoad = strings.Split(toLoad, ".")
if err = state.loadRelationships(0, obj, bkind); err != nil {
2016-08-30 07:21:32 +02:00
return err
}
}
return nil
}
// bindChecks resolves information about the bind target. It errors if obj is
// not something Bind can scan into.
//
// Accepted shapes (Type must be a struct type):
//
//	*Type    -> structType = Type, sliceType = nil,     bkind = kindStruct
//	*[]Type  -> structType = Type, sliceType = []Type,  bkind = kindSliceStruct
//	*[]*Type -> structType = Type, sliceType = []*Type, bkind = kindPtrSliceStruct
//
// Any other shape, including a nil obj, yields an error rather than a panic.
// On error the other results are zero values.
func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Type, bkind bindKind, err error) {
	wrongType := func() error {
		// %T renders the same text as reflect.TypeOf(obj).String(). It also
		// yields "<nil>" for a nil interface instead of dereferencing a nil
		// reflect.Type.
		return errors.Errorf("obj type should be *Type, *[]Type, or *[]*Type but was %q", fmt.Sprintf("%T", obj))
	}

	typ := reflect.TypeOf(obj)
	if typ == nil || typ.Kind() != reflect.Ptr {
		return nil, nil, 0, wrongType()
	}

	typ = typ.Elem()
	switch typ.Kind() {
	case reflect.Struct:
		return typ, nil, kindStruct, nil
	case reflect.Slice:
		// Inspect the element type below.
	default:
		return nil, nil, 0, wrongType()
	}

	elem := typ.Elem()
	switch elem.Kind() {
	case reflect.Struct:
		return elem, typ, kindSliceStruct, nil
	case reflect.Ptr:
		if elem.Elem().Kind() == reflect.Struct {
			return elem.Elem(), typ, kindPtrSliceStruct, nil
		}
	}
	return nil, nil, 0, wrongType()
}
// bind scans every row of rows into obj according to bkind, which is the
// shape computed by bindChecks:
//
//   - kindStruct: each row is scanned into *obj. If rows yields no row at
//     all, sql.ErrNoRows is returned.
//   - kindSliceStruct: each row is scanned into one reused scratch struct,
//     and its value is appended to *obj.
//   - kindPtrSliceStruct: each row is scanned into a freshly allocated
//     struct, and its pointer is appended to *obj.
//
// The column -> field mapping is cached in bindingMaps (guarded by mut),
// keyed by struct type name and column list. An error reported by rows after
// iteration is returned rather than dropped.
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, bkind bindKind) error {
	cols, err := rows.Columns()
	if err != nil {
		return errors.Wrap(err, "bind failed to get column names")
	}

	// For the slice kinds this is the slice that rows are appended to. For
	// kindStruct it is the struct each row is scanned into.
	target := reflect.Indirect(reflect.ValueOf(obj))

	mapKey := makeCacheKey(structType.String(), cols)
	mut.RLock()
	mapping, ok := bindingMaps[mapKey]
	mut.RUnlock()
	if !ok {
		if mapping, err = bindMapping(structType, cols); err != nil {
			return err
		}
		// Concurrent callers may both build and store a mapping for the same
		// key; the last write wins.
		mut.Lock()
		bindingMaps[mapKey] = mapping
		mut.Unlock()
	}

	var oneStruct reflect.Value
	if bkind == kindSliceStruct {
		oneStruct = reflect.Indirect(reflect.New(structType))
	}

	foundOne := false
	for rows.Next() {
		foundOne = true
		var newStruct reflect.Value
		var pointers []interface{}

		switch bkind {
		case kindStruct:
			pointers = ptrsFromMapping(target, mapping)
		case kindSliceStruct:
			pointers = ptrsFromMapping(oneStruct, mapping)
		case kindPtrSliceStruct:
			newStruct = reflect.New(structType)
			pointers = ptrsFromMapping(reflect.Indirect(newStruct), mapping)
		}

		if err = rows.Scan(pointers...); err != nil {
			return errors.Wrap(err, "failed to bind pointers to obj")
		}

		switch bkind {
		case kindSliceStruct:
			target.Set(reflect.Append(target, oneStruct))
		case kindPtrSliceStruct:
			target.Set(reflect.Append(target, newStruct))
		}
	}

	// rows.Next returns false both at the end of the result set and on
	// failure. Only rows.Err distinguishes the two; check it before deciding
	// that "no rows" means sql.ErrNoRows.
	if err = rows.Err(); err != nil {
		return errors.Wrap(err, "bind failed to iterate over rows")
	}

	if bkind == kindStruct && !foundOne {
		return sql.ErrNoRows
	}
	return nil
}
2016-09-03 08:53:37 +02:00
// bindMapping creates a mapping that helps look up the pointer for the
// column given.
2016-09-02 09:09:42 +02:00
func bindMapping(typ reflect.Type, cols []string) ([]uint64, error) {
2016-08-31 09:09:13 +02:00
ptrs := make([]uint64, len(cols))
2016-09-02 09:09:42 +02:00
mapping := makeStructMapping(typ)
2016-08-31 09:09:13 +02:00
ColLoop:
for i, c := range cols {
2016-09-02 09:09:42 +02:00
name := strmangle.TitleCaseIdentifier(c)
2016-08-31 09:09:13 +02:00
ptrMap, ok := mapping[name]
if ok {
ptrs[i] = ptrMap
continue
2016-06-07 06:38:17 +02:00
}
2016-08-31 09:09:13 +02:00
suffix := "." + name
for maybeMatch, mapping := range mapping {
if strings.HasSuffix(maybeMatch, suffix) {
ptrs[i] = mapping
continue ColLoop
}
}
return nil, errors.Errorf("could not find struct field name in mapping: %s", name)
}
return ptrs, nil
}
2016-08-31 09:09:13 +02:00
// ptrsFromMapping expects to be passed an addressable struct that it's looking
2016-09-03 08:53:37 +02:00
// for things on.
2016-08-31 09:09:13 +02:00
func ptrsFromMapping(val reflect.Value, mapping []uint64) []interface{} {
ptrs := make([]interface{}, len(mapping))
for i, m := range mapping {
ptrs[i] = ptrFromMapping(val, m).Interface()
}
return ptrs
}
// ptrFromMapping follows one encoded field path (see makeStructMappingHelper)
// from the addressable struct val and returns a pointer to the field it
// names.
//
// The path is read 8 bits at a time, least-significant first. Each slot is a
// field index until the sentinel slot is reached. Pointer fields along the
// way are dereferenced with reflect.Indirect. A nil pointer there yields a
// zero Value, and the next Field/Addr call on it panics, so such pointers
// must be non-nil. The function panics if no sentinel appears within 8 slots.
func ptrFromMapping(val reflect.Value, mapping uint64) reflect.Value {
	for shift := uint(0); shift < 64; shift += 8 {
		idx := (mapping >> shift) & sentinel
		if idx == sentinel {
			if val.Kind() == reflect.Ptr {
				return val
			}
			return val.Addr()
		}

		field := val.Field(int(idx))
		if field.Kind() == reflect.Ptr {
			field = reflect.Indirect(field)
		}
		val = field
	}
	panic("could not find pointer from mapping")
}
2016-09-02 09:09:42 +02:00
func makeStructMapping(typ reflect.Type) map[string]uint64 {
2016-08-31 09:09:13 +02:00
fieldMaps := make(map[string]uint64)
2016-09-02 09:09:42 +02:00
makeStructMappingHelper(typ, "", 0, 0, fieldMaps)
2016-08-31 09:09:13 +02:00
return fieldMaps
}
2016-09-02 09:09:42 +02:00
func makeStructMappingHelper(typ reflect.Type, prefix string, current uint64, depth uint, fieldMaps map[string]uint64) {
2016-08-31 09:09:13 +02:00
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
n := typ.NumField()
for i := 0; i < n; i++ {
f := typ.Field(i)
2016-09-02 09:09:42 +02:00
tag, recurse := getBoilTag(f)
2016-08-31 09:09:13 +02:00
if len(tag) == 0 {
tag = f.Name
} else if tag[0] == '-' {
continue
}
if len(prefix) != 0 {
tag = fmt.Sprintf("%s.%s", prefix, tag)
}
if recurse {
2016-09-02 09:09:42 +02:00
makeStructMappingHelper(f.Type, tag, current|uint64(i)<<depth, depth+8, fieldMaps)
2016-08-31 09:09:13 +02:00
continue
}
fieldMaps[tag] = current | (sentinel << (depth + 8)) | (uint64(i) << depth)
}
}
2016-09-02 09:09:42 +02:00
func getBoilTag(field reflect.StructField) (name string, recurse bool) {
tag := field.Tag.Get("boil")
2016-08-26 04:26:58 +02:00
name = field.Name
2016-08-26 04:26:58 +02:00
if len(tag) == 0 {
return name, false
}
ind := strings.IndexByte(tag, ',')
if ind == -1 {
2016-09-02 09:09:42 +02:00
return strmangle.TitleCase(tag), false
2016-08-26 04:26:58 +02:00
} else if ind == 0 {
return name, true
}
2016-08-26 04:26:58 +02:00
nameFragment := tag[:ind]
2016-09-02 09:09:42 +02:00
return strmangle.TitleCase(nameFragment), true
}
2016-09-01 06:18:30 +02:00
func makeCacheKey(typ string, cols []string) string {
buf := strmangle.GetBuffer()
buf.WriteString(typ)
for _, s := range cols {
buf.WriteString(s)
}
mapKey := buf.String()
strmangle.PutBuffer(buf)
return mapKey
}
// GetStructValues returns the values (as interface) of the matching columns in obj
2016-09-02 09:09:42 +02:00
func GetStructValues(obj interface{}, columns ...string) []interface{} {
ret := make([]interface{}, len(columns))
val := reflect.Indirect(reflect.ValueOf(obj))
for i, c := range columns {
2016-09-02 09:09:42 +02:00
fieldName := strmangle.TitleCase(c)
field := val.FieldByName(fieldName)
2016-08-15 14:43:10 +02:00
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
2016-08-15 14:43:10 +02:00
}
ret[i] = field.Interface()
}
return ret
}
// GetSliceValues returns the values (as interface) of the matching columns in obj.
2016-09-02 09:09:42 +02:00
func GetSliceValues(slice []interface{}, columns ...string) []interface{} {
ret := make([]interface{}, len(slice)*len(columns))
for i, obj := range slice {
val := reflect.Indirect(reflect.ValueOf(obj))
for j, c := range columns {
2016-09-02 09:09:42 +02:00
fieldName := strmangle.TitleCase(c)
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
}
ret[i*len(columns)+j] = field.Interface()
}
}
return ret
}
// GetStructPointers returns a slice of pointers to the matching columns in obj
2016-09-02 09:09:42 +02:00
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
val := reflect.ValueOf(obj).Elem()
2016-08-06 23:37:55 +02:00
var ln int
var getField func(reflect.Value, int) reflect.Value
2016-06-07 06:38:17 +02:00
if len(columns) == 0 {
2016-08-06 23:37:55 +02:00
ln = val.NumField()
getField = func(v reflect.Value, i int) reflect.Value {
return v.Field(i)
}
} else {
ln = len(columns)
getField = func(v reflect.Value, i int) reflect.Value {
2016-09-02 09:09:42 +02:00
return v.FieldByName(strmangle.TitleCase(columns[i]))
2016-06-07 06:38:17 +02:00
}
}
2016-08-06 23:37:55 +02:00
ret := make([]interface{}, ln)
for i := 0; i < ln; i++ {
field := getField(val, i)
if !field.IsValid() {
2016-08-06 23:37:55 +02:00
// Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually
// produce an Invalid value, so we make a hopefully safe assumption here.
panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i])))
}
2016-08-06 23:37:55 +02:00
ret[i] = field.Addr().Interface()
}
return ret
}