sqlboiler/boil/reflect.go
2016-07-09 02:39:36 +10:00

546 lines
14 KiB
Go

package boil
import (
"database/sql"
"fmt"
"math"
"math/rand"
"reflect"
"regexp"
"sort"
"time"
"github.com/nullbio/sqlboiler/strmangle"
"gopkg.in/nullbio/null.v4"
)
// Reflected types of the nullable wrappers from gopkg.in/nullbio/null.v4.
// randomizeField switches on these to choose a random-value generator for
// struct-typed fields.
var (
	typeNullBool   = reflect.TypeOf(null.Bool{})
	typeNullString = reflect.TypeOf(null.String{})
	typeNullTime   = reflect.TypeOf(null.Time{})

	typeNullFloat32 = reflect.TypeOf(null.Float32{})
	typeNullFloat64 = reflect.TypeOf(null.Float64{})

	typeNullInt   = reflect.TypeOf(null.Int{})
	typeNullInt8  = reflect.TypeOf(null.Int8{})
	typeNullInt16 = reflect.TypeOf(null.Int16{})
	typeNullInt32 = reflect.TypeOf(null.Int32{})
	typeNullInt64 = reflect.TypeOf(null.Int64{})

	typeNullUint   = reflect.TypeOf(null.Uint{})
	typeNullUint8  = reflect.TypeOf(null.Uint8{})
	typeNullUint16 = reflect.TypeOf(null.Uint16{})
	typeNullUint32 = reflect.TypeOf(null.Uint32{})
	typeNullUint64 = reflect.TypeOf(null.Uint64{})
)

// typeTime is the reflected type of time.Time, which randomizeField also
// special-cases among struct-typed fields.
var typeTime = reflect.TypeOf(time.Time{})

// rgxValidTime is used by IsValueMatch to guess whether a formatted time
// string describes a non-zero time: the zero time.Time formats as
// "0001-01-01 00:00:00 +0000 UTC", which contains no digit from 2 to 9.
// NOTE(review): this is a heuristic — a real time whose String() form uses
// only the digits 0 and 1 would be treated as the zero time.
var rgxValidTime = regexp.MustCompile(`[2-9]+`)
// Bind executes the query and inserts the result into the passed in object
// pointer.
//
// obj must be a non-nil pointer to either a struct, in which case a single
// row is fetched and bound with BindOne, or a slice, in which case all rows
// are fetched and bound with BindAll (which expects *[]*T). Any other
// argument, including an untyped or typed nil, is reported as an error
// instead of panicking.
func (q *Query) Bind(obj interface{}) error {
	if obj == nil {
		return fmt.Errorf("Bind not given a pointer to a slice or struct: <nil>")
	}

	val := reflect.ValueOf(obj)
	typ := val.Type()
	if typ.Kind() != reflect.Ptr {
		return fmt.Errorf("Bind not given a pointer to a slice or struct: %s", typ.String())
	}
	if val.IsNil() {
		return fmt.Errorf("Bind given a nil pointer: %s", typ.String())
	}

	elemTyp := typ.Elem()
	switch elemTyp.Kind() {
	case reflect.Struct:
		row := ExecQueryOne(q)
		if err := BindOne(row, q.selectCols, obj); err != nil {
			return fmt.Errorf("Failed to execute Bind query for %s: %s", q.table, err)
		}
	case reflect.Slice:
		rows, err := ExecQueryAll(q)
		if err != nil {
			return fmt.Errorf("Failed to execute Bind query for %s: %s", q.table, err)
		}
		// Release the result set even if binding fails part-way through;
		// Close is idempotent, so this is safe if BindAll closes it too.
		defer rows.Close()

		if err := BindAll(rows, q.selectCols, obj); err != nil {
			return fmt.Errorf("Failed to Bind results to object provided for %s: %s", q.table, err)
		}
	default:
		return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", elemTyp.String())
	}
	return nil
}
// BindOne scans row into the fields of the struct that obj points to.
//
// selectCols names the columns in the order they appear in row. When it is
// empty, GetStructPointers maps the row onto every struct field in
// declaration order.
//
// It returns an error if obj is not a non-nil pointer to a struct, or if the
// scan fails.
func BindOne(row *sql.Row, selectCols []string, obj interface{}) error {
	val := reflect.ValueOf(obj)
	if val.Kind() != reflect.Ptr {
		return fmt.Errorf("BindOne given a non-pointer type")
	}
	// GetStructPointers dereferences obj and looks up fields on it, which
	// panics for nil pointers and non-struct targets, so reject those here.
	if val.IsNil() {
		return fmt.Errorf("BindOne given a nil pointer: %s", val.Type().String())
	}
	if val.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("BindOne given a pointer to a non-struct type: %s", val.Type().String())
	}

	pointers := GetStructPointers(obj, selectCols...)
	if err := row.Scan(pointers...); err != nil {
		return fmt.Errorf("Unable to scan into pointers: %s", err)
	}
	return nil
}
// BindAll scans every remaining row of rows into a newly allocated struct and
// appends a pointer to it to the slice that obj points to.
//
// obj must be a non-nil pointer to a slice of struct pointers (*[]*T).
// selectCols names the columns in the order they appear in each row. When it
// is empty, GetStructPointers maps each row onto every struct field in
// declaration order.
//
// BindAll closes rows before returning, on success and on every error path.
// Close is idempotent, so callers that also close rows are unaffected.
// Errors reported by the row iteration itself (rows.Err) are returned as
// well as scan errors.
func BindAll(rows *sql.Rows, selectCols []string, obj interface{}) error {
	if rows != nil {
		defer rows.Close()
	}

	val := reflect.ValueOf(obj)
	if !val.IsValid() {
		return fmt.Errorf("[0] BindAll object type should be *[]*Type but was: <nil>")
	}
	objTyp := val.Type()

	// Walk *[]*T one level at a time. Elem is only called once the current
	// level's kind is confirmed, so malformed input produces an error rather
	// than a reflect panic.
	structTyp := objTyp
	for i, exp := range []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct} {
		if i != 0 {
			structTyp = structTyp.Elem()
		}
		if structTyp.Kind() != exp {
			return fmt.Errorf("[%d] BindAll object type should be *[]*Type but was: %s", i, objTyp.String())
		}
	}
	if val.IsNil() {
		return fmt.Errorf("BindAll given a nil pointer: %s", objTyp.String())
	}
	ptrSlice := val.Elem()

	for rows.Next() {
		newStruct := reflect.New(structTyp)
		pointers := GetStructPointers(newStruct.Interface(), selectCols...)
		if err := rows.Scan(pointers...); err != nil {
			return fmt.Errorf("Unable to scan into pointers: %s", err)
		}
		ptrSlice.Set(reflect.Append(ptrSlice, newStruct))
	}
	// Next returning false can mean either "done" or "failed"; tell them apart.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("Error iterating over rows: %s", err)
	}
	return nil
}
// checkType validates that obj is either a pointer to a struct (*T) or a
// pointer to a slice of struct pointers (*[]*T).
//
// On success it returns the struct type T and whether obj was the slice form.
// Only the dynamic type of obj is inspected, so a typed nil pointer such as
// (*T)(nil) is accepted. An untyped nil obj is reported as an error rather
// than panicking.
func checkType(obj interface{}) (reflect.Type, bool, error) {
	if obj == nil {
		return nil, false, fmt.Errorf("Bind must be given pointers to structs but got: nil")
	}

	typ := reflect.TypeOf(obj)
	kind := typ.Kind()
	if kind != reflect.Ptr {
		return nil, false, fmt.Errorf("Bind must be given pointers to structs but got type: %s, kind: %s", typ.String(), kind)
	}

	typ = typ.Elem()
	kind = typ.Kind()
	switch kind {
	case reflect.Struct:
		return typ, false, nil
	case reflect.Slice:
		// Validate the slice element below.
	default:
		return nil, false, fmt.Errorf("Bind was given an invalid object must be []*T or *T but got type: %s, kind: %s", typ.String(), kind)
	}

	typ = typ.Elem()
	kind = typ.Kind()
	if kind != reflect.Ptr {
		return nil, false, fmt.Errorf("Bind must be given pointers to structs but got type: %s, kind: %s", typ.String(), kind)
	}

	typ = typ.Elem()
	kind = typ.Kind()
	if kind != reflect.Struct {
		return nil, false, fmt.Errorf("Bind must be a struct but got type: %s, kind: %s", typ.String(), kind)
	}
	return typ, true, nil
}
// IsZeroValue compares each named column's field in obj against its type's
// zero value and returns one error per column that disagrees with
// shouldZero. With shouldZero true, non-zero fields are reported. With
// shouldZero false, zero fields are reported.
//
// obj may be a struct or a pointer to one. Column names are mapped to field
// names with strmangle.TitleCase. A column with no matching field causes a
// panic.
func IsZeroValue(obj interface{}, shouldZero bool, columns ...string) []error {
	structVal := reflect.Indirect(reflect.ValueOf(obj))

	var problems []error
	for _, col := range columns {
		f := structVal.FieldByName(strmangle.TitleCase(col))
		if !f.IsValid() {
			panic(fmt.Sprintf("Unable to find variable with column name %s", col))
		}

		actual := f.Interface()
		zero := reflect.Zero(f.Type()).Interface()
		isZero := reflect.DeepEqual(actual, zero)

		switch {
		case shouldZero && !isZero:
			problems = append(problems, fmt.Errorf("Column with name %s is not zero value: %#v, %#v", col, actual, zero))
		case !shouldZero && isZero:
			problems = append(problems, fmt.Errorf("Column with name %s is zero value: %#v, %#v", col, actual, zero))
		}
	}
	return problems
}
// IsValueMatch checks whether the fields in obj that correspond to columns
// hold the values in values (matched by index). It returns one error per
// mismatch.
//
// obj may be a struct or a pointer to one. Column names are mapped to field
// names with strmangle.TitleCase, and an unknown column causes a panic.
//
// time.Time and null.Time fields are not compared for exact equality.
// Instead, the field must be the zero time exactly when the expected value's
// String() form contains no digit 2-9 (see rgxValidTime). For null.Time the
// Valid flag must also match. All other fields are compared with
// reflect.DeepEqual.
//
// If values is shorter than columns, or a time column is given a value of the
// wrong type, an error is reported instead of panicking.
func IsValueMatch(obj interface{}, columns []string, values []interface{}) []error {
	if len(values) < len(columns) {
		return []error{fmt.Errorf("IsValueMatch given %d columns but only %d values", len(columns), len(values))}
	}

	val := reflect.Indirect(reflect.ValueOf(obj))
	var errs []error
	for i, c := range columns {
		field := val.FieldByName(strmangle.TitleCase(c))
		if !field.IsValid() {
			panic(fmt.Sprintf("Unable to find variable with column name %s", c))
		}

		var timeField reflect.Value
		var valTimeStr string
		switch field.Type() {
		case typeTime:
			expected, ok := values[i].(time.Time)
			if !ok {
				errs = append(errs, fmt.Errorf("Time column with name %s given a non-time.Time value: %#v", c, values[i]))
				continue
			}
			valTimeStr = expected.String()
			timeField = field
		case typeNullTime:
			expected, ok := values[i].(null.Time)
			if !ok {
				errs = append(errs, fmt.Errorf("Null.Time column with name %s given a non-null.Time value: %#v", c, values[i]))
				continue
			}
			valTimeStr = expected.Time.String()
			timeField = field.FieldByName("Time")
			if validField := field.FieldByName("Valid"); validField.Interface() != expected.Valid {
				errs = append(errs, fmt.Errorf("Null.Time column with name %s Valid field does not match: %v ≠ %v", c, expected.Valid, validField.Interface()))
			}
		default:
			if !reflect.DeepEqual(field.Interface(), values[i]) {
				errs = append(errs, fmt.Errorf("Column with name %s does not match value: %#v ≠ %#v", c, values[i], field.Interface()))
			}
			continue
		}

		// The expected value "looks set" when its string form contains a
		// digit 2-9; the field must be non-zero in exactly that case.
		expectSet := rgxValidTime.MatchString(valTimeStr)
		isZero := timeField.Interface() == reflect.Zero(timeField.Type()).Interface()
		if expectSet == isZero {
			errs = append(errs, fmt.Errorf("Time column with name %s Time field does not match: %v ≠ %v", c, values[i], timeField.Interface()))
		}
	}
	return errs
}
// GetStructValues returns the values (as interface{}) of the fields in obj
// that correspond to columns, in the same order as columns.
//
// obj may be a struct or a pointer to one. Column names are mapped to field
// names with strmangle.TitleCase. A column with no matching field causes a
// panic that names the column, as in IsZeroValue and IsValueMatch, instead of
// reflect's generic zero-Value panic.
func GetStructValues(obj interface{}, columns ...string) []interface{} {
	ret := make([]interface{}, len(columns))
	val := reflect.Indirect(reflect.ValueOf(obj))
	for i, c := range columns {
		field := val.FieldByName(strmangle.TitleCase(c))
		if !field.IsValid() {
			panic(fmt.Sprintf("Unable to find variable with column name %s", c))
		}
		ret[i] = field.Interface()
	}
	return ret
}
// GetStructPointers returns pointers to fields of the struct that obj points
// to, suitable for passing to Scan.
//
// With no columns, it returns a pointer to every field in declaration order.
// Otherwise, each column name is mapped to a field name with
// strmangle.TitleCase and the pointers are returned in column order. A column
// with no matching field causes a panic.
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
	structVal := reflect.ValueOf(obj).Elem()

	if len(columns) == 0 {
		all := make([]interface{}, structVal.NumField())
		for idx := range all {
			all[idx] = structVal.Field(idx).Addr().Interface()
		}
		return all
	}

	ptrs := make([]interface{}, 0, len(columns))
	for _, col := range columns {
		name := strmangle.TitleCase(col)
		f := structVal.FieldByName(name)
		if !f.IsValid() {
			panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, name))
		}
		ptrs = append(ptrs, f.Addr().Interface())
	}
	return ptrs
}
// RandomizeSlice takes a pointer to a slice of pointers to structs (*[]*T)
// and replaces every element with a freshly allocated struct filled with
// random data by RandomizeStruct. The slice's length is not changed.
// Fields named in blacklist are left at their zero values.
//
// An error is returned if obj is not a non-nil *[]*T, or if RandomizeStruct
// fails for any element.
func RandomizeSlice(obj interface{}, blacklist ...string) error {
	val := reflect.ValueOf(obj)
	if !val.IsValid() {
		return fmt.Errorf("[0] RandomizeSlice object type should be *[]*Type but was: <nil>")
	}
	objTyp := val.Type()

	// Walk *[]*T one level at a time. Elem is only called once the current
	// level's kind is confirmed, so malformed input produces an error rather
	// than a reflect panic.
	structTyp := objTyp
	for i, exp := range []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct} {
		if i != 0 {
			structTyp = structTyp.Elem()
		}
		if structTyp.Kind() != exp {
			return fmt.Errorf("[%d] RandomizeSlice object type should be *[]*Type but was: %s", i, objTyp.String())
		}
	}
	if val.IsNil() {
		return fmt.Errorf("RandomizeSlice given a nil pointer: %s", objTyp.String())
	}
	ptrSlice := val.Elem()

	for i := 0; i < ptrSlice.Len(); i++ {
		elem := ptrSlice.Index(i)
		elem.Set(reflect.New(structTyp))
		if err := RandomizeStruct(elem.Interface(), blacklist...); err != nil {
			return err
		}
	}
	return nil
}
// RandomizeStruct fills the struct that str points to with random data, one
// field at a time via randomizeField.
//
// Fields whose Go field name appears in blacklist are left untouched. The
// caller's blacklist slice is not modified, because a sorted copy is used for
// lookups.
//
// It returns an error if str is not a pointer to a struct, or the first error
// reported by randomizeField.
func RandomizeStruct(str interface{}, blacklist ...string) error {
	// Sort a copy so the caller's slice keeps its order.
	sorted := make([]string, len(blacklist))
	copy(sorted, blacklist)
	sort.Strings(sorted)

	value := reflect.ValueOf(str)
	if value.Kind() != reflect.Ptr {
		return fmt.Errorf("Outer element should be a pointer, given a non-pointer: %T", str)
	}

	value = value.Elem()
	if value.Kind() != reflect.Struct {
		return fmt.Errorf("Inner element should be a struct, given a non-struct: %T", str)
	}

	typ := value.Type()
	for i := 0; i < value.NumField(); i++ {
		name := typ.Field(i).Name
		// SearchStrings returns the insertion point for name. The field is
		// blacklisted only if the entry at that point is an exact match.
		if idx := sort.SearchStrings(sorted, name); idx < len(sorted) && sorted[idx] == name {
			continue
		}
		if err := randomizeField(value.Field(i)); err != nil {
			return err
		}
	}
	return nil
}
// randDate returns a random UTC date at midnight. The year is in [1850, 2049],
// the month is any of the twelve, and the day of month is in [1, 25], which
// is valid for every month.
//
// Only the year, month and day components are randomised so that Dates and
// DateTimes do not cause mismatches in the test data comparisons.
func randDate() time.Time {
	year := 1850 + rand.Intn(200)
	month := time.Month(1 + rand.Intn(12))
	day := 1 + rand.Intn(25)
	return time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
}
// randomizeField assigns a random value to field, which must be settable.
//
// Struct-typed fields are matched by exact type. time.Time always gets a date
// from randDate. Each null.* type gets, on a coin flip, either a Valid random
// value or its invalid zero value.
//
// Other fields are matched by kind: floats, signed and unsigned integers,
// bool, string, and byte slices. The generated value is converted to the
// field's type before assignment, so named types with a supported underlying
// kind (e.g. `type Status string`) also work.
//
// An error is returned for any other type.
func randomizeField(field reflect.Value) error {
	kind := field.Kind()
	typ := field.Type()

	var newVal interface{}
	if kind == reflect.Struct {
		b := rand.Intn(2) == 1
		switch typ {
		case typeNullBool:
			if b {
				newVal = null.NewBool(rand.Intn(2) == 1, b)
			} else {
				newVal = null.NewBool(false, false)
			}
		case typeNullString:
			if b {
				newVal = null.NewString(randStr(1), b)
			} else {
				newVal = null.NewString("", false)
			}
		case typeNullTime:
			if b {
				newVal = null.NewTime(randDate(), b)
			} else {
				newVal = null.NewTime(time.Time{}, false)
			}
		case typeTime:
			newVal = randDate()
		case typeNullFloat32:
			if b {
				newVal = null.NewFloat32(float32(rand.Intn(9))/10.0+float32(rand.Intn(9)), b)
			} else {
				newVal = null.NewFloat32(0.0, false)
			}
		case typeNullFloat64:
			if b {
				newVal = null.NewFloat64(float64(rand.Intn(9))/10.0+float64(rand.Intn(9)), b)
			} else {
				newVal = null.NewFloat64(0.0, false)
			}
		case typeNullInt:
			if b {
				newVal = null.NewInt(rand.Int(), b)
			} else {
				newVal = null.NewInt(0, false)
			}
		case typeNullInt8:
			if b {
				newVal = null.NewInt8(int8(rand.Intn(int(math.MaxInt8))), b)
			} else {
				newVal = null.NewInt8(0, false)
			}
		case typeNullInt16:
			if b {
				newVal = null.NewInt16(int16(rand.Intn(int(math.MaxInt16))), b)
			} else {
				newVal = null.NewInt16(0, false)
			}
		case typeNullInt32:
			if b {
				newVal = null.NewInt32(rand.Int31(), b)
			} else {
				newVal = null.NewInt32(0, false)
			}
		case typeNullInt64:
			if b {
				newVal = null.NewInt64(rand.Int63(), b)
			} else {
				newVal = null.NewInt64(0, false)
			}
		case typeNullUint:
			if b {
				newVal = null.NewUint(uint(rand.Int()), b)
			} else {
				newVal = null.NewUint(0, false)
			}
		case typeNullUint8:
			if b {
				newVal = null.NewUint8(uint8(rand.Intn(int(math.MaxInt8))), b)
			} else {
				newVal = null.NewUint8(0, false)
			}
		case typeNullUint16:
			if b {
				newVal = null.NewUint16(uint16(rand.Intn(int(math.MaxInt16))), b)
			} else {
				newVal = null.NewUint16(0, false)
			}
		case typeNullUint32:
			if b {
				newVal = null.NewUint32(uint32(rand.Int31()), b)
			} else {
				newVal = null.NewUint32(0, false)
			}
		case typeNullUint64:
			if b {
				newVal = null.NewUint64(uint64(rand.Int63()), b)
			} else {
				newVal = null.NewUint64(0, false)
			}
		default:
			// Without this, newVal stays nil and field.Set panics on an
			// invalid reflect.Value.
			return fmt.Errorf("unsupported type: %s", typ.String())
		}
	} else {
		switch kind {
		case reflect.Float32:
			newVal = float32(rand.Intn(9))/10.0 + float32(rand.Intn(9))
		case reflect.Float64:
			newVal = float64(rand.Intn(9))/10.0 + float64(rand.Intn(9))
		case reflect.Int:
			newVal = rand.Int()
		case reflect.Int8:
			newVal = int8(rand.Intn(int(math.MaxInt8)))
		case reflect.Int16:
			newVal = int16(rand.Intn(int(math.MaxInt16)))
		case reflect.Int32:
			newVal = rand.Int31()
		case reflect.Int64:
			newVal = rand.Int63()
		case reflect.Uint:
			newVal = uint(rand.Int())
		case reflect.Uint8:
			newVal = uint8(rand.Intn(int(math.MaxInt8)))
		case reflect.Uint16:
			newVal = uint16(rand.Intn(int(math.MaxInt16)))
		case reflect.Uint32:
			newVal = uint32(rand.Int31())
		case reflect.Uint64:
			newVal = uint64(rand.Int63())
		case reflect.Bool:
			newVal = rand.Intn(2) == 1
		case reflect.String:
			newVal = randStr(1)
		case reflect.Slice:
			if typ.Elem().Kind() != reflect.Uint8 {
				return fmt.Errorf("unsupported slice type: %s", typ.String())
			}
			newVal = randByteSlice(5 + rand.Intn(20))
		default:
			return fmt.Errorf("unsupported type: %s", typ.String())
		}
	}

	newValue := reflect.ValueOf(newVal)
	if newValue.Type() != typ {
		// Generated values use the builtin type for the kind. A named field
		// type needs an explicit conversion, or Set would panic on
		// assignability.
		if !newValue.Type().ConvertibleTo(typ) {
			return fmt.Errorf("unsupported type: %s", typ.String())
		}
		newValue = newValue.Convert(typ)
	}
	field.Set(newValue)
	return nil
}
// alphabet is the character set randStr draws from: the ASCII digits, then
// the lower-case letters, then the upper-case letters.
const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

// randStr returns a string of ln characters, each chosen uniformly at random
// from alphabet.
func randStr(ln int) string {
	buf := make([]byte, ln)
	for idx := range buf {
		buf[idx] = alphabet[rand.Intn(len(alphabet))]
	}
	return string(buf)
}
// randByteSlice returns a new slice of ln bytes, each drawn uniformly from
// the full range 0-255.
func randByteSlice(ln int) []byte {
	out := make([]byte, ln)
	for idx := range out {
		out[idx] = byte(rand.Intn(256))
	}
	return out
}