Added more reflect helpers

This commit is contained in:
Patrick O'brien 2016-07-09 02:39:36 +10:00
parent 08d168605f
commit 013b3ae0f8
9 changed files with 224 additions and 76 deletions

View file

@ -71,50 +71,55 @@ var (
rgxByteaDefaultValue = regexp.MustCompile(`(?i)\\x([0-9A-F]*)`)
)
// DefaultValue returns the Go converted value of the default value column
func DefaultValue(column Column) string {
defaultVal := ""
// DefaultValues returns the Go converted values of the default value columns
func DefaultValues(columns []Column) []string {
var dVals []string
// Attempt to strip out the raw default value if its contained
// within a Postgres type cast statement
m := rgxRawDefaultValue.FindStringSubmatch(column.Default)
if len(m) > 1 {
defaultVal = m[len(m)-1]
} else {
defaultVal = column.Default
for _, c := range columns {
var dVal string
// Attempt to strip out the raw default value if its contained
// within a Postgres type cast statement
m := rgxRawDefaultValue.FindStringSubmatch(c.Default)
if len(m) > 1 {
dVal = m[len(m)-1]
} else {
dVal = c.Default
}
switch c.Type {
case "null.Uint", "null.Uint8", "null.Uint16", "null.Uint32", "null.Uint64",
"null.Int", "null.Int8", "null.Int16", "null.Int32", "null.Int64",
"uint", "uint8", "uint16", "uint32", "uint64",
"int", "int8", "int16", "int32", "int64",
"null.Float32", "null.Float64", "float32", "float64":
dVals = append(dVals, dVal)
case "null.Bool", "bool":
m = rgxBoolDefaultValue.FindStringSubmatch(dVal)
if len(m) == 0 {
dVals = append(dVals, "false")
}
dVals = append(dVals, strings.ToLower(m[0]))
case "null.Time", "time.Time", "null.String", "string":
dVals = append(dVals, `"`+dVal+`"`)
case "[]byte":
m := rgxByteaDefaultValue.FindStringSubmatch(dVal)
if len(m) != 2 {
dVals = append(dVals, `[]byte{}`)
}
hexstr := m[1]
bs := make([]string, len(hexstr)/2)
count := 0
for i := 0; i < len(hexstr); i += 2 {
bs[count] = "0x" + hexstr[i:i+2]
count++
}
dVals = append(dVals, `[]byte{`+strings.Join(bs, ", ")+`}`)
default:
dVals = append(dVals, "")
}
}
switch column.Type {
case "null.Uint", "null.Uint8", "null.Uint16", "null.Uint32", "null.Uint64",
"null.Int", "null.Int8", "null.Int16", "null.Int32", "null.Int64",
"uint", "uint8", "uint16", "uint32", "uint64",
"int", "int8", "int16", "int32", "int64",
"null.Float32", "null.Float64", "float32", "float64":
return defaultVal
case "null.Bool", "bool":
m = rgxBoolDefaultValue.FindStringSubmatch(defaultVal)
if len(m) == 0 {
return "false"
}
return strings.ToLower(m[0])
case "null.Time", "time.Time", "null.String", "string":
return `"` + defaultVal + `"`
case "[]byte":
m := rgxByteaDefaultValue.FindStringSubmatch(defaultVal)
if len(m) != 2 {
return `[]byte{}`
}
hexstr := m[1]
bs := make([]string, len(hexstr)/2)
c := 0
for i := 0; i < len(hexstr); i += 2 {
bs[c] = "0x" + hexstr[i:i+2]
c++
}
return `[]byte{` + strings.Join(bs, ", ") + `}`
default:
return ""
}
return dVals
}
// ZeroValue returns the zero value string of the column type
@ -134,7 +139,7 @@ func ZeroValue(column Column) string {
case "null.Time", "time.Time":
return `time.Time{}`
case "[]byte":
return `[]byte{}`
return `[]byte(nil)`
default:
return ""
}

View file

@ -52,21 +52,24 @@ func TestFilterColumnsByDefault(t *testing.T) {
}
}
func TestDefaultValue(t *testing.T) {
func TestDefaultValues(t *testing.T) {
c := Column{}
c.Default = `\x12345678`
c.Type = "[]byte"
res := DefaultValue(c)
if res != `[]byte{0x12, 0x34, 0x56, 0x78}` {
res := DefaultValues([]Column{c})
if len(res) != 1 {
t.Errorf("Expected res len 1, got %d", len(res))
}
if res[0] != `[]byte{0x12, 0x34, 0x56, 0x78}` {
t.Errorf("Invalid result: %#v", res)
}
c.Default = `\x`
res = DefaultValue(c)
if res != `[]byte{}` {
res = DefaultValues([]Column{c})
if res[0] != `[]byte{}` {
t.Errorf("Invalid result: %#v", res)
}
}

View file

@ -76,12 +76,12 @@ func TestSetWhere(t *testing.T) {
SetWhere(q, "x > $1 AND y > $2", 5, 3)
if len(q.where) != 1 {
t.Errorf("Expected %d where slices, got %d", len(q.where))
t.Errorf("Expected %d where slices, got %d", 1, len(q.where))
}
expect := "x > $1 AND y > $2"
if q.where[0].clause != expect {
t.Errorf("Expected %s, got %s", expect, q.where)
t.Errorf("Expected %s, got %v", expect, q.where)
}
if len(q.where[0].args) != 2 {

View file

@ -6,6 +6,7 @@ import (
"math"
"math/rand"
"reflect"
"regexp"
"sort"
"time"
@ -30,6 +31,8 @@ var (
typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{})
typeTime = reflect.TypeOf(time.Time{})
rgxValidTime = regexp.MustCompile(`[2-9]+`)
)
// Bind executes the query and inserts the
@ -159,9 +162,12 @@ func checkType(obj interface{}) (reflect.Type, bool, error) {
return typ, isSlice, nil
}
// IsZeroValue checks if the variables with matching columns in obj are zero values
func isZeroValue(obj interface{}, columns ...string) bool {
val := reflect.ValueOf(obj)
// IsZeroValue checks if the variables with matching columns in obj
// are or are not zero values, depending on whether shouldZero is true or false
func IsZeroValue(obj interface{}, shouldZero bool, columns ...string) []error {
val := reflect.Indirect(reflect.ValueOf(obj))
var errs []error
for _, c := range columns {
field := val.FieldByName(strmangle.TitleCase(c))
if !field.IsValid() {
@ -169,12 +175,58 @@ func isZeroValue(obj interface{}, columns ...string) bool {
}
zv := reflect.Zero(field.Type())
if !reflect.DeepEqual(field.Interface(), zv.Interface()) {
return false
if shouldZero && !reflect.DeepEqual(field.Interface(), zv.Interface()) {
errs = append(errs, fmt.Errorf("Column with name %s is not zero value: %#v, %#v", c, field.Interface(), zv.Interface()))
} else if !shouldZero && reflect.DeepEqual(field.Interface(), zv.Interface()) {
errs = append(errs, fmt.Errorf("Column with name %s is zero value: %#v, %#v", c, field.Interface(), zv.Interface()))
}
}
return true
return errs
}
// IsValueMatch checks whether the variables in obj with matching column names
// match the values in the values slice.
func IsValueMatch(obj interface{}, columns []string, values []interface{}) []error {
val := reflect.Indirect(reflect.ValueOf(obj))
var errs []error
for i, c := range columns {
field := val.FieldByName(strmangle.TitleCase(c))
if !field.IsValid() {
panic(fmt.Sprintf("Unable to find variable with column name %s", c))
}
typ := field.Type().String()
if typ == "time.Time" || typ == "null.Time" {
var timeField reflect.Value
var valTimeStr string
if typ == "time.Time" {
valTimeStr = values[i].(time.Time).String()
timeField = field
} else {
valTimeStr = values[i].(null.Time).Time.String()
timeField = field.FieldByName("Time")
validField := field.FieldByName("Valid")
if validField.Interface() != values[i].(null.Time).Valid {
errs = append(errs, fmt.Errorf("Null.Time column with name %s Valid field does not match: %v ≠ %v", c, values[i].(null.Time).Valid, validField.Interface()))
}
}
if (rgxValidTime.MatchString(valTimeStr) && timeField.Interface() == reflect.Zero(timeField.Type()).Interface()) ||
(!rgxValidTime.MatchString(valTimeStr) && timeField.Interface() != reflect.Zero(timeField.Type()).Interface()) {
errs = append(errs, fmt.Errorf("Time column with name %s Time field does not match: %v ≠ %v", c, values[i], timeField.Interface()))
}
continue
}
if !reflect.DeepEqual(field.Interface(), values[i]) {
errs = append(errs, fmt.Errorf("Column with name %s does not match value: %#v ≠ %#v", c, values[i], field.Interface()))
}
}
return errs
}
// GetStructValues returns the values (as interface) of the matching columns in obj

View file

@ -28,14 +28,16 @@ func TestIsZeroValue(t *testing.T) {
E int64
}{}
if !isZeroValue(o, "A", "B", "C", "D", "E") {
t.Errorf("Expected all values to be zero values: %#v", o)
if errs := IsZeroValue(o, true, "A", "B", "C", "D", "E"); errs != nil {
for _, e := range errs {
t.Errorf("%s", e)
}
}
colNames := []string{"A", "B", "C", "D", "E"}
for _, c := range colNames {
if !isZeroValue(o, c) {
t.Errorf("Expected %s to be zero value: %#v", c, o)
if err := IsZeroValue(o, true, c); err != nil {
t.Errorf("Expected %s to be zero value: %s", c, err[0])
}
}
@ -45,9 +47,83 @@ func TestIsZeroValue(t *testing.T) {
o.D = null.NewInt64(2, false)
o.E = 5
if errs := IsZeroValue(o, false, "A", "B", "C", "D", "E"); errs != nil {
for _, e := range errs {
t.Errorf("%s", e)
}
}
for _, c := range colNames {
if isZeroValue(o, c) {
t.Errorf("Expected %s to be non-zero value: %#v", c, o)
if err := IsZeroValue(o, false, c); err != nil {
t.Errorf("Expected %s to be non-zero value: %s", c, err[0])
}
}
}
func TestIsValueMatch(t *testing.T) {
var errs []error
var values []interface{}
o := struct {
A []byte
B time.Time
C null.Time
D null.Int64
E int64
}{}
values = []interface{}{
[]byte(nil),
time.Time{},
null.Time{},
null.Int64{},
int64(0),
}
cols := []string{"A", "B", "C", "D", "E"}
errs = IsValueMatch(o, cols, values)
if errs != nil {
for _, e := range errs {
t.Errorf("%s", e)
}
}
values = []interface{}{
[]byte("hi"),
time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC),
null.NewTime(time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC), true),
null.NewInt64(5, false),
int64(6),
}
errs = IsValueMatch(o, cols, values)
// Expect 6 errors
// 5 for each column and an additional 1 for the invalid Valid field match
if len(errs) != 6 {
t.Errorf("Expected 6 errors, got: %d", len(errs))
for _, e := range errs {
t.Errorf("%s", e)
}
}
o.A = []byte("hi")
o.B = time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC)
o.C = null.NewTime(time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC), true)
o.D = null.NewInt64(5, false)
o.E = 6
errs = IsValueMatch(o, cols, values)
if errs != nil {
for _, e := range errs {
t.Errorf("%s", e)
}
}
o.B = time.Date(2007, 11, 2, 2, 2, 2, 2, time.UTC)
errs = IsValueMatch(o, cols, values)
if errs != nil {
for _, e := range errs {
t.Errorf("%s", e)
}
}
}

View file

@ -193,7 +193,6 @@ var defaultSingletonTestTemplateImports = map[string]imports{
`"os"`,
`"strconv"`,
`"math/rand"`,
`"regexp"`,
`"bytes"`,
},
thirdParty: importList{},

View file

@ -145,5 +145,5 @@ var templateFunctions = template.FuncMap{
"columnNames": bdb.ColumnNames,
"toManyRelationships": bdb.ToManyRelationships,
"zeroValue": bdb.ZeroValue,
"defaultValue": bdb.DefaultValue,
"defaultValues": bdb.DefaultValues,
}

View file

@ -3,8 +3,11 @@
{{- $tableNamePlural := .Table.Name | plural | titleCase -}}
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $parent := .}}
func Test{{$tableNamePlural}}Insert(t *testing.T) {
var err error
var errs []error
emptyTime := time.Time{}.String()
{{$varNamePlural}}DeleteAllRows(t)
@ -42,22 +45,37 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) {
t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err)
}
{{with .Table.Columns | filterColumnsByAutoIncrement true | columnNames}}
{{with .Table.Columns | filterColumnsByAutoIncrement true | columnNames | stringMap $parent.StringFuncs.quoteWrap | join ", "}}
// Ensure the auto increment columns are returned in the object
{{range .}}
if item.{{titleCase .}} <= 0 {
t.Errorf("Expected the auto-increment columns to be greater than 0, got: %d", item.{{titleCase .}})
if errs = boil.IsZeroValue(item, false, {{.}}); errs != nil {
for _, e := range errs {
t.Errorf("Expected auto-increment columns to be greater than 0, err: %s\n", e)
}
}
{{end}}
{{end}}
emptyTime := time.Time{}.String()
{{with .Table.Columns | filterColumnsBySimpleDefault}}
simpleDefaults := []string{{"{"}}{{. | columnNames | stringMap $parent.StringFuncs.quoteWrap | join ", "}}{{"}"}}
defaultValues := []interface{}{{"{"}}{{. | defaultValues | join ", "}}{{"}"}}
if len(simpleDefaults) != len(defaultValues) {
t.Fatalf("Mismatch between slice lengths: %d, %d", len(simpleDefaults), len(defaultValues))
}
if errs = boil.IsValueMatch(item, simpleDefaults, defaultValues); errs != nil {
for _, e := range errs {
t.Errorf("Expected default value to match column value, err: %s\n", e);
}
}
{{end}}
/*{{with .Table.Columns | filterColumnsBySimpleDefault}}
// Ensure the default value columns are returned in the object
{{range .}}
{{$tc := titleCase .Name}}
{{$zv := zeroValue .}}
{{$dv := defaultValue .}}
{{$dv := "false"}}
{{$ty := trimPrefix "null." .Type}}
{{if and (ne $ty "[]byte") .IsNullable}}
if item.{{$tc}}.Valid == false {
@ -84,7 +102,7 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) {
}
{{end}}
{{end}}
{{end}}
{{end}}*/
{{with .Table.Columns | filterColumnsByAutoIncrement false | filterColumnsByDefault false}}
// Ensure the non-defaultvalue columns and non-autoincrement columns are stored correctly as zero or null values.

View file

@ -1,10 +1,5 @@
var dbNameRand *rand.Rand
func isZeroTime(time string) bool {
re := regexp.MustCompile(`[2-9]+`)
return !re.MatchString(time)
}
func initDBNameRand(input string) {
sum := md5.Sum([]byte(input))