diff --git a/boil/query.go b/boil/query.go index 78bc3d0..81372c9 100644 --- a/boil/query.go +++ b/boil/query.go @@ -116,7 +116,7 @@ func ExecQueryOne(q *Query) *sql.Row { if DebugMode { fmt.Fprintln(DebugWriter, qs) } - return q.executor.QueryRow(qs, args) + return q.executor.QueryRow(qs, args...) } // ExecQueryAll executes the query for the All finisher and returns multiple rows @@ -125,7 +125,7 @@ func ExecQueryAll(q *Query) (*sql.Rows, error) { if DebugMode { fmt.Fprintln(DebugWriter, qs) } - return q.executor.Query(qs, args) + return q.executor.Query(qs, args...) } func SetCount(q *Query) { @@ -148,6 +148,12 @@ func SetSelect(q *Query, columns ...string) { q.selectCols = append(q.selectCols, columns...) } +func GetSelect(q *Query) []string { + cols := make([]string, len(q.selectCols)) + copy(cols, q.selectCols) + return cols +} + func SetTable(q *Query, table string) { q.table = table } diff --git a/boil/reflect.go b/boil/reflect.go index 47ed838..469df2a 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -40,13 +40,44 @@ func (q *Query) Bind(obj interface{}) error { // BindOne inserts the returned row columns into the // passed in object pointer -func BindOne(row *sql.Row, obj interface{}) error { +func BindOne(row *sql.Row, selectCols []string, obj interface{}) error { + kind := reflect.ValueOf(obj).Kind() + if kind != reflect.Ptr { + return fmt.Errorf("BindOne given a non-pointer type") + } + + pointers := GetStructPointers(obj, selectCols...) + if err := row.Scan(pointers...); err != nil { + return fmt.Errorf("Unable to scan into pointers: %s", err) + } + return nil } // BindAll inserts the returned rows columns into the // passed in slice of object pointers -func BindAll(rows *sql.Rows, obj interface{}) error { +func BindAll(rows *sql.Rows, selectCols []string, obj interface{}) error { + val := reflect.ValueOf(obj) + typ := reflect.TypeOf(obj) + kind := val.Kind() + + if kind != reflect.Slice { + return fmt.Errorf("BindAll given a non-slice type") + } + + spare := reflect.New(typ.Elem().Elem()) + fmt.Printf("%T, %#v, %s\n", spare, spare, spare.Type().String()) + + index := 0 + for rows.Next() { + val = reflect.Append(val, spare) + pointers := GetStructPointers(val.Index(index), selectCols...) + if err := rows.Scan(pointers...); err != nil { + return fmt.Errorf("Unable to scan into pointers: %s", err) + } + index++ + } + return nil } diff --git a/boil/reflect_test.go b/boil/reflect_test.go index 55346e8..ab2c8bb 100644 --- a/boil/reflect_test.go +++ b/boil/reflect_test.go @@ -7,6 +7,18 @@ import ( "github.com/guregu/null" ) +func TestBind(t *testing.T) { + +} + +func TestBindOne(t *testing.T) { + +} + +func TestBindAll(t *testing.T) { + +} + func TestGetStructValues(t *testing.T) { t.Parallel() timeThing := time.Now() diff --git a/cmds/templates/finishers.tpl b/cmds/templates/finishers.tpl index b3628c5..94cafa6 100644 --- a/cmds/templates/finishers.tpl +++ b/cmds/templates/finishers.tpl @@ -8,7 +8,7 @@ func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) { boil.SetLimit(q.Query, 1) res := boil.ExecQueryOne(q.Query) - err := boil.BindOne(res, o) + err := boil.BindOne(res, boil.GetSelect(q.Query), o) if err != nil { return nil, fmt.Errorf("{{.PkgName}}: failed to execute a one query for {{.Table.Name}}: %s", err) } @@ -17,14 +17,14 @@ func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) { } func (q {{$varNameSingular}}Query) All() ({{$varNameSingular}}Slice, error) { - var o []*{{$tableNameSingular}} + var o *{{$varNameSingular}}Slice res, err := boil.ExecQueryAll(q.Query) if err != nil { return nil, fmt.Errorf("{{.PkgName}}: failed to execute an all query for {{.Table.Name}}: %s", err) } - err = boil.BindAll(res, o) + err = boil.BindAll(res, boil.GetSelect(q.Query), o) if err != nil { return nil, fmt.Errorf("{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice: %s", err) }