diff --git a/boil/query.go b/boil/query.go index 8f31482..ef8df2a 100644 --- a/boil/query.go +++ b/boil/query.go @@ -60,8 +60,17 @@ func ExecStatement(q *Query) (sql.Result, error) { return q.executor.Exec(qs, args...) } -// ExecQuery executes the query for the All finisher and returns multiple rows -func ExecQuery(q *Query) (*sql.Rows, error) { +// ExecQuery executes the query for the One finisher and returns a single row +func ExecQuery(q *Query) *sql.Row { + qs, args := buildQuery(q) + if DebugMode { + fmt.Fprintln(DebugWriter, qs) + } + return q.executor.QueryRow(qs, args...) +} + +// ExecQueryAll executes the query for the All finisher and returns multiple rows +func ExecQueryAll(q *Query) (*sql.Rows, error) { qs, args := buildQuery(q) if DebugMode { fmt.Fprintln(DebugWriter, qs) diff --git a/boil/reflect.go b/boil/reflect.go index 662c5a6..2f41efb 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -1,48 +1,17 @@ package boil import ( - "database/sql" "fmt" "reflect" + "strings" "github.com/nullbio/sqlboiler/strmangle" + "github.com/pkg/errors" ) -// Bind executes the query and inserts the -// result into the passed in object pointer -func (q *Query) Bind(obj interface{}) error { - return nil - /*typ := reflect.TypeOf(obj) - kind := typ.Kind() - - if kind != reflect.Ptr { - return fmt.Errorf("Bind not given a pointer to a slice or struct: %s", typ.String()) - } - - typ = typ.Elem() - kind = typ.Kind() - - if kind == reflect.Struct { - row := ExecQueryOne(q) - err := BindOne(row, q.selectCols, obj) - if err != nil { - return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err) - } - } else if kind == reflect.Slice { - rows, err := ExecQueryAll(q) - if err != nil { - return fmt.Errorf("Failed to execute Bind query for %s: %s", q.from, err) - } - err = BindAll(rows, q.selectCols, obj) - if err != nil { - return fmt.Errorf("Failed to Bind results to object provided for %s: %s", q.from, err) - } - } else { - return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", typ.String()) - } - - return nil*/ -} +var ( + bindAccepts = []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct} +) // BindP executes the query and inserts the // result into the passed in object pointer. @@ -53,60 +22,210 @@ func (q *Query) BindP(obj interface{}) { } } -// BindOne inserts the returned row columns into the -// passed in object pointer -func BindOne(row *sql.Row, selectCols []string, obj interface{}) error { - kind := reflect.ValueOf(obj).Kind() - if kind != reflect.Ptr { - return fmt.Errorf("BindOne given a non-pointer type") - } - - pointers := GetStructPointers(obj, selectCols...) - if err := row.Scan(pointers...); err != nil { - return fmt.Errorf("Unable to scan into pointers: %s", err) - } - - return nil -} - -// BindAll inserts the returned rows columns into the -// passed in slice of object pointers -func BindAll(rows *sql.Rows, selectCols []string, obj interface{}) error { - ptrSlice := reflect.ValueOf(obj) - typ := ptrSlice.Type() - ptrSlice = ptrSlice.Elem() +// Bind executes the query and inserts the +// result into the passed in object pointer +// +// Bind rules: +// - Struct tags control bind, in the form of: `boil:"name,bind"` +// - If the "name" part of the struct tag is specified, that will be used +// for binding instead of the snake_cased field name. +// - If the ",bind" option is specified on an struct field, it will be recursed +// into to look for fields for binding. +// - If the name of the struct tag is "-", this field will not be bound to +// +// Example Query: +// +// type JoinStruct struct { +// User1 *models.User `boil:"user,bind"` +// User2 *models.User `boil:"friend,bind"` +// // RandomData will not be recursed into to look for fields to +// // bind and will not be bound to because of the - for the name. +// RandomData myStruct `boil:"-"` +// // Date will not be recursed into to look for fields to bind because +// // it does not specify ,bind in the struct tag. But it can be bound to +// // as it does not specify a - for the name. +// Date time.Time +// } +// +// models.Users(qm.InnerJoin("users as friend on users.friend_id = friend.id")).Bind(&joinStruct) +func (q *Query) Bind(obj interface{}) error { + typ := reflect.TypeOf(obj) kind := typ.Kind() - var structTyp reflect.Type + var structType reflect.Type + var sliceType reflect.Type + var singular bool + + for i := 0; i < len(bindAccepts); i++ { + exp := bindAccepts[i] - for i, exp := range []reflect.Kind{reflect.Ptr, reflect.Slice, reflect.Ptr, reflect.Struct} { if i != 0 { typ = typ.Elem() kind = typ.Kind() } if kind != exp { - return fmt.Errorf("[%d] BindAll object type should be *[]*Type but was: %s", i, ptrSlice.Type().String()) + if exp == reflect.Slice || kind == reflect.Struct { + structType = typ + singular = true + break + } + + return errors.Errorf("obj type should be *[]*Type or *Type but was %q", reflect.TypeOf(obj).String()) } - if kind == reflect.Struct { - structTyp = typ + switch kind { + case reflect.Struct: + structType = typ + case reflect.Slice: + sliceType = typ } } + return bind(q, obj, structType, sliceType, singular) +} + +func bind(q *Query, obj interface{}, structType, sliceType reflect.Type, singular bool) error { + rows, err := ExecQueryAll(q) + if err != nil { + return errors.Wrap(err, "bind failed to execute query") + } + + cols, err := rows.Columns() + if err != nil { + return errors.Wrap(err, "bind failed to get column names") + } + + var ptrSlice reflect.Value + if !singular { + ptrSlice = reflect.ValueOf(obj) + } + for rows.Next() { - newStruct := reflect.New(structTyp) - pointers := GetStructPointers(newStruct.Interface(), selectCols...) - if err := rows.Scan(pointers...); err != nil { - return fmt.Errorf("Unable to scan into pointers: %s", err) + var newStruct reflect.Value + var pointers []interface{} + + if singular { + pointers, err = bindPtrs(obj, cols...) + } else { + newStruct = reflect.New(structType) + pointers, err = bindPtrs(newStruct.Interface(), cols...) + } + if err != nil { + return err } - ptrSlice.Set(reflect.Append(ptrSlice, newStruct)) + if err := rows.Scan(pointers...); err != nil { + return errors.Wrap(err, "failed to bind pointers to obj") + } + + if !singular { + ptrSlice.Set(reflect.Append(ptrSlice, newStruct)) + } } return nil } +func bindPtrs(obj interface{}, cols ...string) ([]interface{}, error) { + v := reflect.ValueOf(obj) + ptrs := make([]interface{}, len(cols)) + + for i, c := range cols { + names := strings.Split(c, ".") + + ptr, ok := findField(names, v) + if !ok { + return nil, errors.Errorf("bindPtrs failed to find field %s", c) + } + + ptrs[i] = ptr + } + + return ptrs, nil +} + +func findField(names []string, v reflect.Value) (interface{}, bool) { + fmt.Println("Names:", names) + fmt.Println("Type:", v.Type().String()) + + if !v.IsValid() || len(names) == 0 { + return nil, false + } + + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return nil, false + } + v = reflect.Indirect(v) + } + + if v.Kind() != reflect.Struct { + return nil, false + } + + name := strmangle.TitleCase(names[0]) + typ := v.Type() + + fi, ok := typ.FieldByName(name) + if ok { + fieldName, recurse := getBoilTag(fi) + if fieldName != "-" { + if recurse { + return findField(names[1:], v.FieldByName(name)) + } + + if len(names) == 1 { + field := v.FieldByName(name) + if field.Kind() != reflect.Ptr { + return field.Addr().Interface(), true + } + return field.Interface(), true + } + } + } + + n := typ.NumField() + for i := 0; i < n; i++ { + f := typ.Field(i) + fieldName, recurse := getBoilTag(f) + + fmt.Println(name, fieldName, recurse) + + if fieldName == "-" { + continue + } + + if recurse { + return findField(names, v.Field(i)) + } + + if fieldName == name { + fieldVal := v.Field(i) + if fieldVal.Kind() != reflect.Ptr { + return fieldVal.Addr().Interface(), true + } + return fieldVal.Interface(), true + } + } + + return nil, false +} + +func getBoilTag(field reflect.StructField) (name string, recurse bool) { + tag := field.Tag.Get("boil") + + if len(tag) != 0 { + tagTokens := strings.Split(tag, ",") + name = strmangle.TitleCase(tagTokens[0]) + recurse = len(tagTokens) > 1 && tagTokens[1] == "bind" + } else { + name = field.Name + } + + return name, recurse +} + // GetStructValues returns the values (as interface) of the matching columns in obj func GetStructValues(obj interface{}, columns ...string) []interface{} { ret := make([]interface{}, len(columns)) diff --git a/boil/reflect_test.go b/boil/reflect_test.go index 78f13b6..e33d870 100644 --- a/boil/reflect_test.go +++ b/boil/reflect_test.go @@ -11,16 +11,97 @@ func TestBind(t *testing.T) { t.Skip("Not implemented") } -func TestBindP(t *testing.T) { - t.Skip("Not implemented") +func TestBindPtrs_Easy(t *testing.T) { + t.Parallel() + + testStruct := struct { + ID int `boil:"identifier"` + Date time.Time + }{} + + cols := []string{"identifier", "date"} + ptrs, err := bindPtrs(&testStruct, cols...) + if err != nil { + t.Error(err) + } + + if ptrs[0].(*int) != &testStruct.ID { + t.Error("id is the wrong pointer") + } + if ptrs[1].(*time.Time) != &testStruct.Date { + t.Error("id is the wrong pointer") + } } -func TestBindOne(t *testing.T) { - t.Skip("Not implemented") +func TestBindPtrs_Recursive(t *testing.T) { + t.Parallel() + + testStruct := struct { + Happy struct { + ID int `boil:"identifier"` + } + Fun struct { + ID int + } `boil:",bind"` + }{} + + cols := []string{"id", "fun.id"} + ptrs, err := bindPtrs(&testStruct, cols...) + if err != nil { + t.Error(err) + } + + if ptrs[0].(*int) != &testStruct.Fun.ID { + t.Error("id is the wrong pointer") + } + if ptrs[1].(*int) != &testStruct.Fun.ID { + t.Error("id is the wrong pointer") + } } -func TestBindAll(t *testing.T) { - t.Skip("Not implemented") +func TestBindPtrs_RecursiveTags(t *testing.T) { + testStruct := struct { + Happy struct { + ID int `boil:"identifier"` + } `boil:",bind"` + Fun struct { + ID int `boil:"identification"` + } `boil:",bind"` + }{} + + cols := []string{"happy.identifier", "fun.identification"} + ptrs, err := bindPtrs(&testStruct, cols...) + if err != nil { + t.Error(err) + } + + if ptrs[0].(*int) != &testStruct.Happy.ID { + t.Error("id is the wrong pointer") + } + if ptrs[1].(*int) != &testStruct.Fun.ID { + t.Error("id is the wrong pointer") + } +} + +func TestBindPtrs_Ignore(t *testing.T) { + t.Parallel() + + testStruct := struct { + ID int `boil:"-"` + Happy struct { + ID int + } `boil:",bind"` + }{} + + cols := []string{"id"} + ptrs, err := bindPtrs(&testStruct, cols...) + if err != nil { + t.Error(err) + } + + if ptrs[0].(*int) != &testStruct.Happy.ID { + t.Error("id is the wrong pointer") + } } func TestGetStructValues(t *testing.T) {