diff --git a/queries/reflect.go b/queries/reflect.go index bee0a12..1a1e6cb 100644 --- a/queries/reflect.go +++ b/queries/reflect.go @@ -407,74 +407,3 @@ func makeCacheKey(typ string, cols []string) string { return mapKey } - -// GetStructValues returns the values (as interface) of the matching columns in obj -func GetStructValues(obj interface{}, columns ...string) []interface{} { - ret := make([]interface{}, len(columns)) - val := reflect.Indirect(reflect.ValueOf(obj)) - - for i, c := range columns { - fieldName := strmangle.TitleCase(c) - field := val.FieldByName(fieldName) - if !field.IsValid() { - panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj)) - } - ret[i] = field.Interface() - } - - return ret -} - -// GetSliceValues returns the values (as interface) of the matching columns in obj. -func GetSliceValues(slice []interface{}, columns ...string) []interface{} { - ret := make([]interface{}, len(slice)*len(columns)) - - for i, obj := range slice { - val := reflect.Indirect(reflect.ValueOf(obj)) - for j, c := range columns { - fieldName := strmangle.TitleCase(c) - field := val.FieldByName(fieldName) - if !field.IsValid() { - panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj)) - } - ret[i*len(columns)+j] = field.Interface() - } - } - - return ret -} - -// GetStructPointers returns a slice of pointers to the matching columns in obj -func GetStructPointers(obj interface{}, columns ...string) []interface{} { - val := reflect.ValueOf(obj).Elem() - - var ln int - var getField func(reflect.Value, int) reflect.Value - - if len(columns) == 0 { - ln = val.NumField() - getField = func(v reflect.Value, i int) reflect.Value { - return v.Field(i) - } - } else { - ln = len(columns) - getField = func(v reflect.Value, i int) reflect.Value { - return v.FieldByName(strmangle.TitleCase(columns[i])) - } - } - - ret := make([]interface{}, ln) - for i := 0; i < ln; i++ { - field := getField(val, i) - - if !field.IsValid() { - // Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually - // produce an Invalid value, so we make a hopefully safe assumption here. - panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i]))) - } - - ret[i] = field.Addr().Interface() - } - - return ret -} diff --git a/queries/reflect_test.go b/queries/reflect_test.go index 7641557..a1484dc 100644 --- a/queries/reflect_test.go +++ b/queries/reflect_test.go @@ -6,10 +6,8 @@ import ( "strconv" "strings" "testing" - "time" "gopkg.in/DATA-DOG/go-sqlmock.v1" - "gopkg.in/nullbio/null.v5" ) func bin64(i uint64) string { @@ -609,99 +607,3 @@ func TestBind_InnerJoin(t *testing.T) { // t.Error("id is the wrong pointer") // } // } - -func TestGetStructValues(t *testing.T) { - t.Parallel() - - timeThing := time.Now() - o := struct { - TitleThing string - Name string - ID int - Stuff int - Things int - Time time.Time - NullBool null.Bool - }{ - TitleThing: "patrick", - Stuff: 10, - Things: 0, - Time: timeThing, - NullBool: null.NewBool(true, false), - } - - vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool") - if vals[0].(string) != "patrick" { - t.Errorf("Want test, got %s", vals[0]) - } - if vals[1].(string) != "" { - t.Errorf("Want empty string, got %s", vals[1]) - } - if vals[2].(int) != 0 { - t.Errorf("Want 0, got %d", vals[2]) - } - if vals[3].(int) != 10 { - t.Errorf("Want 10, got %d", vals[3]) - } - if vals[4].(int) != 0 { - t.Errorf("Want 0, got %d", vals[4]) - } - if !vals[5].(time.Time).Equal(timeThing) { - t.Errorf("Want %s, got %s", o.Time, vals[5]) - } - if !vals[6].(null.Bool).IsZero() { - t.Errorf("Want %v, got %v", o.NullBool, vals[6]) - } -} - -func TestGetSliceValues(t *testing.T) { - t.Parallel() - - o := []struct { - ID int - Name string - }{ - {5, "a"}, - {6, "b"}, - } - - in := make([]interface{}, len(o)) - in[0] = o[0] - in[1] = o[1] - - vals := GetSliceValues(in, "id", "name") - if got := vals[0].(int); got != 5 { - t.Error(got) - } - if got := vals[1].(string); got != "a" { - t.Error(got) - } - if got := vals[2].(int); got != 6 { - t.Error(got) - } - if got := vals[3].(string); got != "b" { - t.Error(got) - } -} - -func TestGetStructPointers(t *testing.T) { - t.Parallel() - - o := struct { - Title string - ID *int - }{ - Title: "patrick", - } - - ptrs := GetStructPointers(&o, "title", "id") - *ptrs[0].(*string) = "test" - if o.Title != "test" { - t.Errorf("Expected test, got %s", o.Title) - } - x := 5 - *ptrs[1].(**int) = &x - if *o.ID != 5 { - t.Errorf("Expected 5, got %d", *o.ID) - } -}