Add fast path for binding by caching TitleCase

This commit is contained in:
Aaron L 2016-08-23 23:20:41 -07:00
parent d130354890
commit 28e485603a
17 changed files with 104 additions and 49 deletions

View file

@ -251,6 +251,7 @@ func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, err
var fkey bdb.ForeignKey
var sourceTable string
fkey.Table = tableName
err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn)
if err != nil {
return nil, err

View file

@ -15,6 +15,7 @@ type PrimaryKey struct {
// ForeignKey represents a foreign key constraint in a database
type ForeignKey struct {
Table string
Name string
Column string
Nullable bool

View file

@ -4,6 +4,7 @@ package bdb
// local table has no id, and the foreign table has an id that matches a column
// in the local table.
type ToManyRelationship struct {
Table string
Column string
Nullable bool
Unique bool
@ -57,6 +58,7 @@ func buildRelationship(localTable Table, foreignKey ForeignKey, foreignTable Tab
if !foreignTable.IsJoinTable {
col := localTable.GetColumn(foreignKey.ForeignColumn)
return ToManyRelationship{
Table: foreignKey.Table,
Column: foreignKey.ForeignColumn,
Nullable: col.Nullable,
Unique: col.Unique,
@ -70,6 +72,7 @@ func buildRelationship(localTable Table, foreignKey ForeignKey, foreignTable Tab
col := foreignTable.GetColumn(foreignKey.Column)
relationship := ToManyRelationship{
Table: foreignKey.Table,
Column: foreignKey.ForeignColumn,
Nullable: col.Nullable,
Unique: col.Unique,

View file

@ -9,13 +9,18 @@ import (
// NonZeroDefaultSet returns the fields included in the
// defaults slice that are non zero values
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
func NonZeroDefaultSet(defaults []string, titleCases map[string]string, obj interface{}) []string {
c := make([]string, 0, len(defaults))
val := reflect.Indirect(reflect.ValueOf(obj))
for _, d := range defaults {
fieldName := strmangle.TitleCase(d)
var fieldName string
if titleCases == nil {
fieldName = strmangle.TitleCase(d)
} else {
fieldName = titleCases[d]
}
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("Could not find field name %s in type %T", fieldName, obj))

View file

@ -59,7 +59,7 @@ func TestNonZeroDefaultSet(t *testing.T) {
}
for i, test := range tests {
z := NonZeroDefaultSet(test.Defaults, test.Obj)
z := NonZeroDefaultSet(test.Defaults, nil, test.Obj)
if !reflect.DeepEqual(test.Ret, z) {
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.Ret, z)
}

View file

@ -61,13 +61,18 @@ func (q *Query) BindP(obj interface{}) {
// For custom objects that want to use eager loading, please see the
// loadRelationships function.
func Bind(rows *sql.Rows, obj interface{}) error {
return BindFast(rows, obj, nil)
}
// BindFast uses a lookup table for column_name to ColumnName to avoid TitleCase.
func BindFast(rows *sql.Rows, obj interface{}, titleCases map[string]string) error {
structType, sliceType, singular, err := bindChecks(obj)
if err != nil {
return err
}
return bind(rows, obj, structType, sliceType, singular)
return bind(rows, obj, structType, sliceType, singular, titleCases)
}
// Bind executes the query and inserts the
@ -75,6 +80,11 @@ func Bind(rows *sql.Rows, obj interface{}) error {
//
// See documentation for boil.Bind()
func (q *Query) Bind(obj interface{}) error {
return q.BindFast(obj, nil)
}
// BindFast uses a lookup table for column_name to ColumnName to avoid TitleCase.
func (q *Query) BindFast(obj interface{}, titleCases map[string]string) error {
structType, sliceType, singular, err := bindChecks(obj)
if err != nil {
return err
@ -86,7 +96,7 @@ func (q *Query) Bind(obj interface{}) error {
}
defer rows.Close()
if res := bind(rows, obj, structType, sliceType, singular); res != nil {
if res := bind(rows, obj, structType, sliceType, singular, titleCases); res != nil {
return res
}
@ -185,7 +195,7 @@ func bindChecks(obj interface{}) (structType reflect.Type, sliceType reflect.Typ
return structType, sliceType, singular, nil
}
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, singular bool) error {
func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, singular bool, titleCases map[string]string) error {
cols, err := rows.Columns()
if err != nil {
return errors.Wrap(err, "bind failed to get column names")
@ -203,10 +213,10 @@ func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, s
var pointers []interface{}
if singular {
pointers, err = bindPtrs(obj, cols...)
pointers, err = bindPtrs(obj, titleCases, cols...)
} else {
newStruct = reflect.New(structType)
pointers, err = bindPtrs(newStruct.Interface(), cols...)
pointers, err = bindPtrs(newStruct.Interface(), titleCases, cols...)
}
if err != nil {
return err
@ -228,14 +238,14 @@ func bind(rows *sql.Rows, obj interface{}, structType, sliceType reflect.Type, s
return nil
}
func bindPtrs(obj interface{}, cols ...string) ([]interface{}, error) {
func bindPtrs(obj interface{}, titleCases map[string]string, cols ...string) ([]interface{}, error) {
v := reflect.ValueOf(obj)
ptrs := make([]interface{}, len(cols))
for i, c := range cols {
names := strings.Split(c, ".")
ptr, ok := findField(names, v)
ptr, ok := findField(names, titleCases, v)
if !ok {
return nil, errors.Errorf("bindPtrs failed to find field %s", c)
}
@ -246,7 +256,7 @@ func bindPtrs(obj interface{}, cols ...string) ([]interface{}, error) {
return ptrs, nil
}
func findField(names []string, v reflect.Value) (interface{}, bool) {
func findField(names []string, titleCases map[string]string, v reflect.Value) (interface{}, bool) {
if !v.IsValid() || len(names) == 0 {
return nil, false
}
@ -262,13 +272,18 @@ func findField(names []string, v reflect.Value) (interface{}, bool) {
return nil, false
}
name := strmangle.TitleCase(names[0])
var name string
var ok bool
name, ok = titleCases[names[0]]
if !ok {
name = strmangle.TitleCase(names[0])
}
typ := v.Type()
n := typ.NumField()
for i := 0; i < n; i++ {
f := typ.Field(i)
fieldName, recurse := getBoilTag(f)
fieldName, recurse := getBoilTag(f, titleCases)
if fieldName == "-" {
continue
@ -278,7 +293,7 @@ func findField(names []string, v reflect.Value) (interface{}, bool) {
if fieldName == name {
names = names[1:]
}
if ptr, ok := findField(names, v.Field(i)); ok {
if ptr, ok := findField(names, titleCases, v.Field(i)); ok {
return ptr, ok
}
}
@ -297,12 +312,16 @@ func findField(names []string, v reflect.Value) (interface{}, bool) {
return nil, false
}
func getBoilTag(field reflect.StructField) (name string, recurse bool) {
func getBoilTag(field reflect.StructField, titleCases map[string]string) (name string, recurse bool) {
tag := field.Tag.Get("boil")
if len(tag) != 0 {
tagTokens := strings.Split(tag, ",")
name = strmangle.TitleCase(tagTokens[0])
var ok bool
name, ok = titleCases[tagTokens[0]]
if !ok {
name = strmangle.TitleCase(tagTokens[0])
}
recurse = len(tagTokens) > 1 && tagTokens[1] == "bind"
}
@ -314,14 +333,20 @@ func getBoilTag(field reflect.StructField) (name string, recurse bool) {
}
// GetStructValues returns the values (as interface) of the matching columns in obj
func GetStructValues(obj interface{}, columns ...string) []interface{} {
func GetStructValues(obj interface{}, titleCases map[string]string, columns ...string) []interface{} {
ret := make([]interface{}, len(columns))
val := reflect.Indirect(reflect.ValueOf(obj))
for i, c := range columns {
field := val.FieldByName(strmangle.TitleCase(c))
var fieldName string
if titleCases == nil {
fieldName = strmangle.TitleCase(c)
} else {
fieldName = titleCases[c]
}
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", strmangle.TitleCase(c), obj))
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
}
ret[i] = field.Interface()
}
@ -330,15 +355,22 @@ func GetStructValues(obj interface{}, columns ...string) []interface{} {
}
// GetSliceValues returns the values (as interface) of the matching columns in obj.
func GetSliceValues(slice []interface{}, columns ...string) []interface{} {
func GetSliceValues(slice []interface{}, titleCases map[string]string, columns ...string) []interface{} {
ret := make([]interface{}, len(slice)*len(columns))
for i, obj := range slice {
val := reflect.Indirect(reflect.ValueOf(obj))
for j, c := range columns {
field := val.FieldByName(strmangle.TitleCase(c))
var fieldName string
if titleCases == nil {
fieldName = strmangle.TitleCase(c)
} else {
fieldName = titleCases[c]
}
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", strmangle.TitleCase(c), obj))
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
}
ret[i*len(columns)+j] = field.Interface()
}
@ -348,7 +380,7 @@ func GetSliceValues(slice []interface{}, columns ...string) []interface{} {
}
// GetStructPointers returns a slice of pointers to the matching columns in obj
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
func GetStructPointers(obj interface{}, titleCases map[string]string, columns ...string) []interface{} {
val := reflect.ValueOf(obj).Elem()
var ln int
@ -362,7 +394,14 @@ func GetStructPointers(obj interface{}, columns ...string) []interface{} {
} else {
ln = len(columns)
getField = func(v reflect.Value, i int) reflect.Value {
return v.FieldByName(strmangle.TitleCase(columns[i]))
var fieldName string
if titleCases == nil {
fieldName = strmangle.TitleCase(columns[i])
} else {
fieldName = titleCases[columns[i]]
}
return v.FieldByName(fieldName)
}
}

View file

@ -270,7 +270,7 @@ func TestBindPtrs_Easy(t *testing.T) {
}{}
cols := []string{"identifier", "date"}
ptrs, err := bindPtrs(&testStruct, cols...)
ptrs, err := bindPtrs(&testStruct, nil, cols...)
if err != nil {
t.Error(err)
}
@ -296,7 +296,7 @@ func TestBindPtrs_Recursive(t *testing.T) {
}{}
cols := []string{"id", "fun.id"}
ptrs, err := bindPtrs(&testStruct, cols...)
ptrs, err := bindPtrs(&testStruct, nil, cols...)
if err != nil {
t.Error(err)
}
@ -322,7 +322,7 @@ func TestBindPtrs_RecursiveTags(t *testing.T) {
}{}
cols := []string{"happy.identifier", "fun.identification"}
ptrs, err := bindPtrs(&testStruct, cols...)
ptrs, err := bindPtrs(&testStruct, nil, cols...)
if err != nil {
t.Error(err)
}
@ -346,7 +346,7 @@ func TestBindPtrs_Ignore(t *testing.T) {
}{}
cols := []string{"id"}
ptrs, err := bindPtrs(&testStruct, cols...)
ptrs, err := bindPtrs(&testStruct, nil, cols...)
if err != nil {
t.Error(err)
}
@ -376,7 +376,7 @@ func TestGetStructValues(t *testing.T) {
NullBool: null.NewBool(true, false),
}
vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool")
vals := GetStructValues(&o, nil, "title_thing", "name", "id", "stuff", "things", "time", "null_bool")
if vals[0].(string) != "patrick" {
t.Errorf("Want test, got %s", vals[0])
}
@ -415,7 +415,7 @@ func TestGetSliceValues(t *testing.T) {
in[0] = o[0]
in[1] = o[1]
vals := GetSliceValues(in, "id", "name")
vals := GetSliceValues(in, nil, "id", "name")
if got := vals[0].(int); got != 5 {
t.Error(got)
}
@ -440,7 +440,7 @@ func TestGetStructPointers(t *testing.T) {
Title: "patrick",
}
ptrs := GetStructPointers(&o, "title", "id")
ptrs := GetStructPointers(&o, nil, "title", "id")
*ptrs[0].(*string) = "test"
if o.Title != "test" {
t.Errorf("Expected test, got %s", o.Title)

View file

@ -5,6 +5,11 @@ var (
{{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}}
{{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}}
{{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}}
{{$varNameSingular}}TitleCases = map[string]string{
{{range $col := .Table.Columns | columnNames -}}
"{{$col}}": "{{titleCase $col}}",
{{end -}}
}
)
type (

View file

@ -16,7 +16,7 @@ func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) {
boil.SetLimit(q.Query, 1)
err := q.Bind(o)
err := q.BindFast(o, {{$varNameSingular}}TitleCases)
if err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, sql.ErrNoRows
@ -41,7 +41,7 @@ func (q {{$varNameSingular}}Query) AllP() {{$tableNameSingular}}Slice {
func (q {{$varNameSingular}}Query) All() ({{$tableNameSingular}}Slice, error) {
var o {{$tableNameSingular}}Slice
err := q.Bind(&o)
err := q.BindFast(&o, {{$varNameSingular}}TitleCases)
if err != nil {
return nil, errors.Wrap(err, "{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice")
}

View file

@ -40,7 +40,7 @@ func (r *{{.LocalTable.NameGo}}Loaded) Load{{.Function.Name}}(e boil.Executor, s
defer results.Close()
var resultSlice []*{{.ForeignTable.NameGo}}
if err = boil.Bind(results, &resultSlice); err != nil {
if err = boil.BindFast(results, &resultSlice, {{.ForeignKey.Table | singular | camelCase}}TitleCases); err != nil {
return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}")
}

View file

@ -76,7 +76,7 @@ func (r *{{$rel.LocalTable.NameGo}}Loaded) Load{{$rel.Function.Name}}(e boil.Exe
return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}")
}
{{else -}}
if err = boil.Bind(results, &resultSlice); err != nil {
if err = boil.BindFast(results, &resultSlice, {{$dot.Table.Name | singular | camelCase}}TitleCases); err != nil {
return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}")
}
{{end}}

View file

@ -35,7 +35,7 @@ func {{$tableNameSingular}}Find(exec boil.Executor, {{$pkArgs}}, selectCols ...s
q := boil.SQL(query, {{$pkNames | join ", "}})
boil.SetExecutor(q, exec)
err := q.Bind({{$varNameSingular}}Obj)
err := q.BindFast({{$varNameSingular}}Obj, {{$varNameSingular}}TitleCases)
if err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, sql.ErrNoRows

View file

@ -35,7 +35,7 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string
{{$varNameSingular}}Columns,
{{$varNameSingular}}ColumnsWithDefault,
{{$varNameSingular}}ColumnsWithoutDefault,
boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o),
boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, {{$varNameSingular}}TitleCases, o),
whitelist,
)
@ -49,10 +49,10 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string
{{if .UseLastInsertID}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, ins)
fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, wl...))
fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, {{$varNameSingular}}TitleCases, wl...))
}
result, err := exec.Exec(ins, boil.GetStructValues(o, wl...)...)
result, err := exec.Exec(ins, boil.GetStructValues(o, {{$varNameSingular}}TitleCases, wl...)...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}")
}
@ -67,21 +67,21 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string
}
sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, {{$varNameSingular}}AutoIncPrimaryKeys))
err = exec.QueryRow(sel, lastID).Scan(boil.GetStructPointers(o, returnColumns...))
err = exec.QueryRow(sel, lastID).Scan(boil.GetStructPointers(o, {{$varNameSingular}}TitleCases, returnColumns...))
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}")
}
{{else}}
if len(returnColumns) != 0 {
ins = ins + fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ","))
err = exec.QueryRow(ins, boil.GetStructValues(o, wl...)...).Scan(boil.GetStructPointers(o, returnColumns...)...)
err = exec.QueryRow(ins, boil.GetStructValues(o, {{$varNameSingular}}TitleCases, wl...)...).Scan(boil.GetStructPointers(o, {{$varNameSingular}}TitleCases, returnColumns...)...)
} else {
_, err = exec.Exec(ins, boil.GetStructValues(o, wl...)...)
_, err = exec.Exec(ins, boil.GetStructValues(o, {{$varNameSingular}}TitleCases, wl...)...)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, ins)
fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, wl...))
fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, {{$varNameSingular}}TitleCases, wl...))
}
if err != nil {

View file

@ -49,7 +49,7 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string
}
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns))
values = boil.GetStructValues(o, wl...)
values = boil.GetStructValues(o, {{$varNameSingular}}TitleCases, wl...)
values = append(values, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
if boil.DebugMode {

View file

@ -31,7 +31,7 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, updateOnConflict boo
{{$varNameSingular}}Columns,
{{$varNameSingular}}ColumnsWithDefault,
{{$varNameSingular}}ColumnsWithoutDefault,
boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o),
boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, {{$varNameSingular}}TitleCases, o),
whitelist,
)
update := strmangle.UpdateColumnSet(
@ -54,16 +54,16 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, updateOnConflict boo
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...))
fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, {{$varNameSingular}}TitleCases, whitelist...))
}
{{- if .UseLastInsertID}}
return errors.New("don't know how to do this yet")
{{- else}}
if len(ret) != 0 {
err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...)
err = exec.QueryRow(query, boil.GetStructValues(o, {{$varNameSingular}}TitleCases, whitelist...)...).Scan(boil.GetStructPointers(o, {{$varNameSingular}}TitleCases, ret...)...)
} else {
_, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...)
_, err = exec.Exec(query, boil.GetStructValues(o, {{$varNameSingular}}TitleCases, whitelist...)...)
}
{{- end}}

View file

@ -75,7 +75,7 @@ func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error {
q := boil.SQL(sql, args...)
boil.SetExecutor(q, exec)
err := q.Bind(&{{$varNamePlural}})
err := q.BindFast(&{{$varNamePlural}}, {{$varNameSingular}}TitleCases)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice")
}

View file

@ -77,6 +77,7 @@ func textsFromForeignKey(packageName string, tables []bdb.Table, table bdb.Table
func textsFromOneToOneRelationship(packageName string, tables []bdb.Table, table bdb.Table, toMany bdb.ToManyRelationship) RelationshipToOneTexts {
fkey := bdb.ForeignKey{
Table: toMany.Table,
Name: "none",
Column: toMany.Column,
Nullable: toMany.Nullable,