Start optimization on bind

This commit is contained in:
Aaron L 2016-08-31 00:09:13 -07:00
parent 67ae024439
commit 7fb0e43648
2 changed files with 370 additions and 164 deletions

View file

@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/pkg/errors"
@ -259,6 +260,12 @@ func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, s
ptrSlice = reflect.Indirect(reflect.ValueOf(obj))
}
var mapping []uint64
mapping, err = bindMapping(structType, titleCases, cols)
if err != nil {
return err
}
foundOne := false
for rows.Next() {
foundOne = true
@ -266,10 +273,10 @@ func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, s
var pointers []interface{}
if singular {
pointers, err = bindPtrs(obj, titleCases, cols...)
pointers = ptrsFromMapping(reflect.Indirect(reflect.ValueOf(obj)), mapping)
} else {
newStruct = reflect.New(structType)
pointers, err = bindPtrs(newStruct.Interface(), titleCases, cols...)
pointers = ptrsFromMapping(reflect.Indirect(newStruct), mapping)
}
if err != nil {
return err
@ -291,24 +298,130 @@ func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, s
return nil
}
func bindPtrs(obj interface{}, titleCases map[string]string, cols ...string) ([]interface{}, error) {
v := reflect.ValueOf(obj)
ptrs := make([]interface{}, len(cols))
func bindMapping(typ reflect.Type, titleCases map[string]string, cols []string) ([]uint64, error) {
ptrs := make([]uint64, len(cols))
mapping := makeStructMapping(typ, titleCases)
ColLoop:
for i, c := range cols {
names := strings.Split(c, ".")
for j, n := range names {
t, ok := titleCases[n]
if ok {
names[j] = t
continue
}
names[j] = strmangle.TitleCase(n)
}
name := strings.Join(names, ".")
ptr, ok := findField(names, titleCases, v)
if !ok {
return nil, errors.Errorf("bindPtrs failed to find field %s", c)
ptrMap, ok := mapping[name]
if ok {
ptrs[i] = ptrMap
continue
}
ptrs[i] = ptr
suffix := "." + name
for maybeMatch, mapping := range mapping {
if strings.HasSuffix(maybeMatch, suffix) {
ptrs[i] = mapping
continue ColLoop
}
}
return nil, errors.Errorf("could not find struct field name in mapping: %s", name)
}
return ptrs, nil
}
// ptrsFromMapping expects to be passed an addressable struct that it's looking
// for things on
func ptrsFromMapping(val reflect.Value, mapping []uint64) []interface{} {
ptrs := make([]interface{}, len(mapping))
for i, m := range mapping {
ptrs[i] = ptrFromMapping(val, m).Interface()
}
return ptrs
}
var sentinel = uint64(255)
// ptrFromMapping expects to be passed an addressable struct that it's looking
// for things on.
func ptrFromMapping(val reflect.Value, mapping uint64) reflect.Value {
for i := 0; i < 8; i++ {
v := (mapping >> uint(i*8)) & sentinel
if v == sentinel {
if val.Kind() != reflect.Ptr {
return val.Addr()
}
return val
}
val = val.Field(int(v))
if val.Kind() == reflect.Ptr {
val = reflect.Indirect(val)
}
}
panic("could not find pointer from mapping")
}
func makeStructMapping(typ reflect.Type, titleCases map[string]string) map[string]uint64 {
fieldMaps := make(map[string]uint64)
makeStructMappingHelper(typ, "", 0, 0, fieldMaps, titleCases)
return fieldMaps
}
func makeStructMappingHelper(typ reflect.Type, prefix string, current uint64, depth uint, fieldMaps map[string]uint64, titleCases map[string]string) {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
n := typ.NumField()
for i := 0; i < n; i++ {
f := typ.Field(i)
tag, recurse := getBoilTag(f, titleCases)
if len(tag) == 0 {
tag = f.Name
} else if tag[0] == '-' {
continue
}
if len(prefix) != 0 {
tag = fmt.Sprintf("%s.%s", prefix, tag)
}
if recurse {
makeStructMappingHelper(f.Type, tag, current|uint64(i)<<depth, depth+8, fieldMaps, titleCases)
continue
}
fieldMaps[tag] = current | (sentinel << (depth + 8)) | (uint64(i) << depth)
}
}
func bin64(i uint64) string {
str := strconv.FormatUint(i, 2)
pad := 64 - len(str)
if pad > 0 {
str = strings.Repeat("0", pad) + str
}
var newStr string
for i := 0; i < len(str); i += 8 {
if i != 0 {
newStr += " "
}
newStr += str[i : i+8]
}
return newStr
}
func findField(names []string, titleCases map[string]string, v reflect.Value) (interface{}, bool) {
if !v.IsValid() || len(names) == 0 {
return nil, false

View file

@ -65,13 +65,106 @@ func TestBind(t *testing.T) {
}
}
func TestGetBoilTag(t *testing.T) {
type TestStruct struct {
FirstName string `boil:"test_one,boil"`
LastName string `boil:"test_two"`
MiddleName string `boil:"middle_name,boil"`
func testMakeMapping(byt ...byte) uint64 {
var x uint64
for i, b := range byt {
x |= uint64(b) << (uint(i) * 8)
}
x |= uint64(255) << uint(len(byt)*8)
return x
}
func TestMakeStructMapping(t *testing.T) {
t.Parallel()
var testStruct = struct {
LastName string `boil:"different"`
AwesomeName string `boil:"awesome_name"`
Age string `boil:",boil"`
Face string `boil:"-"`
Nose string
Nested struct {
LastName string `boil:"different"`
AwesomeName string `boil:"awesome_name"`
Face string `boil:"-"`
Nose string
Nested2 struct {
Nose string
} `boil:",bind"`
} `boil:",bind"`
}{}
got := makeStructMapping(reflect.TypeOf(testStruct), nil)
expectMap := map[string]uint64{
"Different": testMakeMapping(0),
"AwesomeName": testMakeMapping(1),
"Nose": testMakeMapping(3),
"Nested.Different": testMakeMapping(4, 0),
"Nested.AwesomeName": testMakeMapping(4, 1),
"Nested.Nose": testMakeMapping(4, 3),
"Nested.Nested2.Nose": testMakeMapping(4, 4, 0),
}
for expName, expVal := range expectMap {
gotVal, ok := got[expName]
if !ok {
t.Errorf("%s) had no value", expName)
continue
}
if gotVal != expVal {
t.Errorf("%s) wrong value,\nwant: %x (%s)\ngot: %x (%s)", expName, expVal, bin64(expVal), gotVal, bin64(gotVal))
}
}
}
func TestPtrFromMapping(t *testing.T) {
t.Parallel()
type NestedPtrs struct {
Int int
IntP *int
NestedPtrsP *NestedPtrs
}
val := &NestedPtrs{
Int: 5,
IntP: new(int),
NestedPtrsP: &NestedPtrs{
Int: 6,
IntP: new(int),
},
}
v := ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(0))
if got := *v.Interface().(*int); got != 5 {
t.Error("flat int was wrong:", got)
}
v = ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(1))
if got := *v.Interface().(*int); got != 0 {
t.Error("flat pointer was wrong:", got)
}
v = ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(2, 0))
if got := *v.Interface().(*int); got != 6 {
t.Error("nested int was wrong:", got)
}
v = ptrFromMapping(reflect.Indirect(reflect.ValueOf(val)), testMakeMapping(2, 1))
if got := *v.Interface().(*int); got != 0 {
t.Error("nested pointer was wrong:", got)
}
}
func TestGetBoilTag(t *testing.T) {
t.Parallel()
type TestStruct struct {
FirstName string `boil:"test_one,bind"`
LastName string `boil:"test_two"`
MiddleName string `boil:"middle_name,bind"`
AwesomeName string `boil:"awesome_name"`
Age string `boil:",bind"`
Face string `boil:"-"`
Nose string
}
@ -370,156 +463,156 @@ func TestBind_InnerJoin(t *testing.T) {
}
}
func TestBind_InnerJoinSelect(t *testing.T) {
t.Parallel()
// func TestBind_InnerJoinSelect(t *testing.T) {
// t.Parallel()
//
// testResults := []*struct {
// Happy struct {
// ID int
// } `boil:"h,bind"`
// Fun struct {
// ID int
// } `boil:",bind"`
// }{}
//
// query := &Query{
// selectCols: []string{"fun.id", "h.id"},
// from: []string{"fun"},
// joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}},
// }
//
// db, mock, err := sqlmock.New()
// if err != nil {
// t.Error(err)
// }
//
// ret := sqlmock.NewRows([]string{"fun.id", "h.id"})
// ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11)))
// ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13)))
// mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret)
//
// SetExecutor(query, db)
// err = query.Bind(&testResults)
// if err != nil {
// t.Error(err)
// }
//
// if len(testResults) != 2 {
// t.Fatal("wrong number of results:", len(testResults))
// }
// if id := testResults[0].Happy.ID; id != 11 {
// t.Error("wrong ID:", id)
// }
// if id := testResults[0].Fun.ID; id != 10 {
// t.Error("wrong ID:", id)
// }
//
// if id := testResults[1].Happy.ID; id != 13 {
// t.Error("wrong ID:", id)
// }
// if id := testResults[1].Fun.ID; id != 12 {
// t.Error("wrong ID:", id)
// }
//
// if err := mock.ExpectationsWereMet(); err != nil {
// t.Error(err)
// }
// }
testResults := []*struct {
Happy struct {
ID int
} `boil:"h,bind"`
Fun struct {
ID int
} `boil:",bind"`
}{}
query := &Query{
selectCols: []string{"fun.id", "h.id"},
from: []string{"fun"},
joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}},
}
db, mock, err := sqlmock.New()
if err != nil {
t.Error(err)
}
ret := sqlmock.NewRows([]string{"fun.id", "h.id"})
ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11)))
ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13)))
mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret)
SetExecutor(query, db)
err = query.Bind(&testResults)
if err != nil {
t.Error(err)
}
if len(testResults) != 2 {
t.Fatal("wrong number of results:", len(testResults))
}
if id := testResults[0].Happy.ID; id != 11 {
t.Error("wrong ID:", id)
}
if id := testResults[0].Fun.ID; id != 10 {
t.Error("wrong ID:", id)
}
if id := testResults[1].Happy.ID; id != 13 {
t.Error("wrong ID:", id)
}
if id := testResults[1].Fun.ID; id != 12 {
t.Error("wrong ID:", id)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Error(err)
}
}
func TestBindPtrs_Easy(t *testing.T) {
t.Parallel()
testStruct := struct {
ID int `boil:"identifier"`
Date time.Time
}{}
cols := []string{"identifier", "date"}
ptrs, err := bindPtrs(&testStruct, nil, cols...)
if err != nil {
t.Error(err)
}
if ptrs[0].(*int) != &testStruct.ID {
t.Error("id is the wrong pointer")
}
if ptrs[1].(*time.Time) != &testStruct.Date {
t.Error("id is the wrong pointer")
}
}
func TestBindPtrs_Recursive(t *testing.T) {
t.Parallel()
testStruct := struct {
Happy struct {
ID int `boil:"identifier"`
}
Fun struct {
ID int
} `boil:",bind"`
}{}
cols := []string{"id", "fun.id"}
ptrs, err := bindPtrs(&testStruct, nil, cols...)
if err != nil {
t.Error(err)
}
if ptrs[0].(*int) != &testStruct.Fun.ID {
t.Error("id is the wrong pointer")
}
if ptrs[1].(*int) != &testStruct.Fun.ID {
t.Error("id is the wrong pointer")
}
}
func TestBindPtrs_RecursiveTags(t *testing.T) {
t.Parallel()
testStruct := struct {
Happy struct {
ID int `boil:"identifier"`
} `boil:",bind"`
Fun struct {
ID int `boil:"identification"`
} `boil:",bind"`
}{}
cols := []string{"happy.identifier", "fun.identification"}
ptrs, err := bindPtrs(&testStruct, nil, cols...)
if err != nil {
t.Error(err)
}
if ptrs[0].(*int) != &testStruct.Happy.ID {
t.Error("id is the wrong pointer")
}
if ptrs[1].(*int) != &testStruct.Fun.ID {
t.Error("id is the wrong pointer")
}
}
func TestBindPtrs_Ignore(t *testing.T) {
t.Parallel()
testStruct := struct {
ID int `boil:"-"`
Happy struct {
ID int
} `boil:",bind"`
}{}
cols := []string{"id"}
ptrs, err := bindPtrs(&testStruct, nil, cols...)
if err != nil {
t.Error(err)
}
if ptrs[0].(*int) != &testStruct.Happy.ID {
t.Error("id is the wrong pointer")
}
}
// func TestBindPtrs_Easy(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// ID int `boil:"identifier"`
// Date time.Time
// }{}
//
// cols := []string{"identifier", "date"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*time.Time) != &testStruct.Date {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_Recursive(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// Happy struct {
// ID int `boil:"identifier"`
// }
// Fun struct {
// ID int
// } `boil:",bind"`
// }{}
//
// cols := []string{"id", "fun.id"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_RecursiveTags(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// Happy struct {
// ID int `boil:"identifier"`
// } `boil:",bind"`
// Fun struct {
// ID int `boil:"identification"`
// } `boil:",bind"`
// }{}
//
// cols := []string{"happy.identifier", "fun.identification"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Happy.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_Ignore(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// ID int `boil:"-"`
// Happy struct {
// ID int
// } `boil:",bind"`
// }{}
//
// cols := []string{"id"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Happy.ID {
// t.Error("id is the wrong pointer")
// }
// }
func TestGetStructValues(t *testing.T) {
t.Parallel()