From 8298da6f48c82b5bd93ba6a7673be38cb1b7fb2d Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sat, 6 Aug 2016 14:37:55 -0700 Subject: [PATCH] Refactor GetStructPointers --- boil/reflect.go | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/boil/reflect.go b/boil/reflect.go index 3ec0371..25a633e 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -122,27 +122,33 @@ func GetStructValues(obj interface{}, columns ...string) []interface{} { // GetStructPointers returns a slice of pointers to the matching columns in obj func GetStructPointers(obj interface{}, columns ...string) []interface{} { val := reflect.ValueOf(obj).Elem() - var ret []interface{} + + var ln int + var getField func(reflect.Value, int) reflect.Value if len(columns) == 0 { - fieldsLen := val.NumField() - ret = make([]interface{}, fieldsLen) - for i := 0; i < fieldsLen; i++ { - ret[i] = val.Field(i).Addr().Interface() + ln = val.NumField() + getField = func(v reflect.Value, i int) reflect.Value { + return v.Field(i) + } + } else { + ln = len(columns) + getField = func(v reflect.Value, i int) reflect.Value { + return v.FieldByName(strmangle.TitleCase(columns[i])) } - return ret } - ret = make([]interface{}, len(columns)) + ret := make([]interface{}, ln) + for i := 0; i < ln; i++ { + field := getField(val, i) - for i, c := range columns { - field := val.FieldByName(strmangle.TitleCase(c)) if !field.IsValid() { - panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(c))) + // Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually + // produce an Invalid value, so we make a hopefully safe assumption here. + panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i]))) } - field = field.Addr() - ret[i] = field.Interface() + ret[i] = field.Addr().Interface() } return ret