Added more reflect helpers

This commit is contained in:
Patrick O'brien 2016-07-09 02:39:36 +10:00
parent 08d168605f
commit 013b3ae0f8
9 changed files with 224 additions and 76 deletions

View file

@ -71,50 +71,55 @@ var (
rgxByteaDefaultValue = regexp.MustCompile(`(?i)\\x([0-9A-F]*)`) rgxByteaDefaultValue = regexp.MustCompile(`(?i)\\x([0-9A-F]*)`)
) )
// DefaultValue returns the Go converted value of the default value column // DefaultValues returns the Go converted values of the default value columns
func DefaultValue(column Column) string { func DefaultValues(columns []Column) []string {
defaultVal := "" var dVals []string
// Attempt to strip out the raw default value if its contained for _, c := range columns {
// within a Postgres type cast statement var dVal string
m := rgxRawDefaultValue.FindStringSubmatch(column.Default) // Attempt to strip out the raw default value if its contained
if len(m) > 1 { // within a Postgres type cast statement
defaultVal = m[len(m)-1] m := rgxRawDefaultValue.FindStringSubmatch(c.Default)
} else { if len(m) > 1 {
defaultVal = column.Default dVal = m[len(m)-1]
} else {
dVal = c.Default
}
switch c.Type {
case "null.Uint", "null.Uint8", "null.Uint16", "null.Uint32", "null.Uint64",
"null.Int", "null.Int8", "null.Int16", "null.Int32", "null.Int64",
"uint", "uint8", "uint16", "uint32", "uint64",
"int", "int8", "int16", "int32", "int64",
"null.Float32", "null.Float64", "float32", "float64":
dVals = append(dVals, dVal)
case "null.Bool", "bool":
m = rgxBoolDefaultValue.FindStringSubmatch(dVal)
if len(m) == 0 {
dVals = append(dVals, "false")
}
dVals = append(dVals, strings.ToLower(m[0]))
case "null.Time", "time.Time", "null.String", "string":
dVals = append(dVals, `"`+dVal+`"`)
case "[]byte":
m := rgxByteaDefaultValue.FindStringSubmatch(dVal)
if len(m) != 2 {
dVals = append(dVals, `[]byte{}`)
}
hexstr := m[1]
bs := make([]string, len(hexstr)/2)
count := 0
for i := 0; i < len(hexstr); i += 2 {
bs[count] = "0x" + hexstr[i:i+2]
count++
}
dVals = append(dVals, `[]byte{`+strings.Join(bs, ", ")+`}`)
default:
dVals = append(dVals, "")
}
} }
switch column.Type { return dVals
case "null.Uint", "null.Uint8", "null.Uint16", "null.Uint32", "null.Uint64",
"null.Int", "null.Int8", "null.Int16", "null.Int32", "null.Int64",
"uint", "uint8", "uint16", "uint32", "uint64",
"int", "int8", "int16", "int32", "int64",
"null.Float32", "null.Float64", "float32", "float64":
return defaultVal
case "null.Bool", "bool":
m = rgxBoolDefaultValue.FindStringSubmatch(defaultVal)
if len(m) == 0 {
return "false"
}
return strings.ToLower(m[0])
case "null.Time", "time.Time", "null.String", "string":
return `"` + defaultVal + `"`
case "[]byte":
m := rgxByteaDefaultValue.FindStringSubmatch(defaultVal)
if len(m) != 2 {
return `[]byte{}`
}
hexstr := m[1]
bs := make([]string, len(hexstr)/2)
c := 0
for i := 0; i < len(hexstr); i += 2 {
bs[c] = "0x" + hexstr[i:i+2]
c++
}
return `[]byte{` + strings.Join(bs, ", ") + `}`
default:
return ""
}
} }
// ZeroValue returns the zero value string of the column type // ZeroValue returns the zero value string of the column type
@ -134,7 +139,7 @@ func ZeroValue(column Column) string {
case "null.Time", "time.Time": case "null.Time", "time.Time":
return `time.Time{}` return `time.Time{}`
case "[]byte": case "[]byte":
return `[]byte{}` return `[]byte(nil)`
default: default:
return "" return ""
} }

View file

@ -52,21 +52,24 @@ func TestFilterColumnsByDefault(t *testing.T) {
} }
} }
func TestDefaultValue(t *testing.T) { func TestDefaultValues(t *testing.T) {
c := Column{} c := Column{}
c.Default = `\x12345678` c.Default = `\x12345678`
c.Type = "[]byte" c.Type = "[]byte"
res := DefaultValue(c) res := DefaultValues([]Column{c})
if res != `[]byte{0x12, 0x34, 0x56, 0x78}` { if len(res) != 1 {
t.Errorf("Expected res len 1, got %d", len(res))
}
if res[0] != `[]byte{0x12, 0x34, 0x56, 0x78}` {
t.Errorf("Invalid result: %#v", res) t.Errorf("Invalid result: %#v", res)
} }
c.Default = `\x` c.Default = `\x`
res = DefaultValue(c) res = DefaultValues([]Column{c})
if res != `[]byte{}` { if res[0] != `[]byte{}` {
t.Errorf("Invalid result: %#v", res) t.Errorf("Invalid result: %#v", res)
} }
} }

View file

@ -76,12 +76,12 @@ func TestSetWhere(t *testing.T) {
SetWhere(q, "x > $1 AND y > $2", 5, 3) SetWhere(q, "x > $1 AND y > $2", 5, 3)
if len(q.where) != 1 { if len(q.where) != 1 {
t.Errorf("Expected %d where slices, got %d", len(q.where)) t.Errorf("Expected %d where slices, got %d", 1, len(q.where))
} }
expect := "x > $1 AND y > $2" expect := "x > $1 AND y > $2"
if q.where[0].clause != expect { if q.where[0].clause != expect {
t.Errorf("Expected %s, got %s", expect, q.where) t.Errorf("Expected %s, got %v", expect, q.where)
} }
if len(q.where[0].args) != 2 { if len(q.where[0].args) != 2 {

View file

@ -6,6 +6,7 @@ import (
"math" "math"
"math/rand" "math/rand"
"reflect" "reflect"
"regexp"
"sort" "sort"
"time" "time"
@ -30,6 +31,8 @@ var (
typeNullBool = reflect.TypeOf(null.Bool{}) typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{}) typeNullTime = reflect.TypeOf(null.Time{})
typeTime = reflect.TypeOf(time.Time{}) typeTime = reflect.TypeOf(time.Time{})
rgxValidTime = regexp.MustCompile(`[2-9]+`)
) )
// Bind executes the query and inserts the // Bind executes the query and inserts the
@ -159,9 +162,12 @@ func checkType(obj interface{}) (reflect.Type, bool, error) {
return typ, isSlice, nil return typ, isSlice, nil
} }
// IsZeroValue checks if the variables with matching columns in obj are zero values // IsZeroValue checks if the variables with matching columns in obj
func isZeroValue(obj interface{}, columns ...string) bool { // are or are not zero values, depending on whether shouldZero is true or false
val := reflect.ValueOf(obj) func IsZeroValue(obj interface{}, shouldZero bool, columns ...string) []error {
val := reflect.Indirect(reflect.ValueOf(obj))
var errs []error
for _, c := range columns { for _, c := range columns {
field := val.FieldByName(strmangle.TitleCase(c)) field := val.FieldByName(strmangle.TitleCase(c))
if !field.IsValid() { if !field.IsValid() {
@ -169,12 +175,58 @@ func isZeroValue(obj interface{}, columns ...string) bool {
} }
zv := reflect.Zero(field.Type()) zv := reflect.Zero(field.Type())
if !reflect.DeepEqual(field.Interface(), zv.Interface()) { if shouldZero && !reflect.DeepEqual(field.Interface(), zv.Interface()) {
return false errs = append(errs, fmt.Errorf("Column with name %s is not zero value: %#v, %#v", c, field.Interface(), zv.Interface()))
} else if !shouldZero && reflect.DeepEqual(field.Interface(), zv.Interface()) {
errs = append(errs, fmt.Errorf("Column with name %s is zero value: %#v, %#v", c, field.Interface(), zv.Interface()))
} }
} }
return true return errs
}
// IsValueMatch checks whether the variables in obj with matching column names
// match the values in the values slice.
func IsValueMatch(obj interface{}, columns []string, values []interface{}) []error {
val := reflect.Indirect(reflect.ValueOf(obj))
var errs []error
for i, c := range columns {
field := val.FieldByName(strmangle.TitleCase(c))
if !field.IsValid() {
panic(fmt.Sprintf("Unable to find variable with column name %s", c))
}
typ := field.Type().String()
if typ == "time.Time" || typ == "null.Time" {
var timeField reflect.Value
var valTimeStr string
if typ == "time.Time" {
valTimeStr = values[i].(time.Time).String()
timeField = field
} else {
valTimeStr = values[i].(null.Time).Time.String()
timeField = field.FieldByName("Time")
validField := field.FieldByName("Valid")
if validField.Interface() != values[i].(null.Time).Valid {
errs = append(errs, fmt.Errorf("Null.Time column with name %s Valid field does not match: %v ≠ %v", c, values[i].(null.Time).Valid, validField.Interface()))
}
}
if (rgxValidTime.MatchString(valTimeStr) && timeField.Interface() == reflect.Zero(timeField.Type()).Interface()) ||
(!rgxValidTime.MatchString(valTimeStr) && timeField.Interface() != reflect.Zero(timeField.Type()).Interface()) {
errs = append(errs, fmt.Errorf("Time column with name %s Time field does not match: %v ≠ %v", c, values[i], timeField.Interface()))
}
continue
}
if !reflect.DeepEqual(field.Interface(), values[i]) {
errs = append(errs, fmt.Errorf("Column with name %s does not match value: %#v ≠ %#v", c, values[i], field.Interface()))
}
}
return errs
} }
// GetStructValues returns the values (as interface) of the matching columns in obj // GetStructValues returns the values (as interface) of the matching columns in obj

View file

@ -28,14 +28,16 @@ func TestIsZeroValue(t *testing.T) {
E int64 E int64
}{} }{}
if !isZeroValue(o, "A", "B", "C", "D", "E") { if errs := IsZeroValue(o, true, "A", "B", "C", "D", "E"); errs != nil {
t.Errorf("Expected all values to be zero values: %#v", o) for _, e := range errs {
t.Errorf("%s", e)
}
} }
colNames := []string{"A", "B", "C", "D", "E"} colNames := []string{"A", "B", "C", "D", "E"}
for _, c := range colNames { for _, c := range colNames {
if !isZeroValue(o, c) { if err := IsZeroValue(o, true, c); err != nil {
t.Errorf("Expected %s to be zero value: %#v", c, o) t.Errorf("Expected %s to be zero value: %s", c, err[0])
} }
} }
@ -45,9 +47,83 @@ func TestIsZeroValue(t *testing.T) {
o.D = null.NewInt64(2, false) o.D = null.NewInt64(2, false)
o.E = 5 o.E = 5
if errs := IsZeroValue(o, false, "A", "B", "C", "D", "E"); errs != nil {
for _, e := range errs {
t.Errorf("%s", e)
}
}
for _, c := range colNames { for _, c := range colNames {
if isZeroValue(o, c) { if err := IsZeroValue(o, false, c); err != nil {
t.Errorf("Expected %s to be non-zero value: %#v", c, o) t.Errorf("Expected %s to be non-zero value: %s", c, err[0])
}
}
}
func TestIsValueMatch(t *testing.T) {
var errs []error
var values []interface{}
o := struct {
A []byte
B time.Time
C null.Time
D null.Int64
E int64
}{}
values = []interface{}{
[]byte(nil),
time.Time{},
null.Time{},
null.Int64{},
int64(0),
}
cols := []string{"A", "B", "C", "D", "E"}
errs = IsValueMatch(o, cols, values)
if errs != nil {
for _, e := range errs {
t.Errorf("%s", e)
}
}
values = []interface{}{
[]byte("hi"),
time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC),
null.NewTime(time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC), true),
null.NewInt64(5, false),
int64(6),
}
errs = IsValueMatch(o, cols, values)
// Expect 6 errors
// 5 for each column and an additional 1 for the invalid Valid field match
if len(errs) != 6 {
t.Errorf("Expected 6 errors, got: %d", len(errs))
for _, e := range errs {
t.Errorf("%s", e)
}
}
o.A = []byte("hi")
o.B = time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC)
o.C = null.NewTime(time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC), true)
o.D = null.NewInt64(5, false)
o.E = 6
errs = IsValueMatch(o, cols, values)
if errs != nil {
for _, e := range errs {
t.Errorf("%s", e)
}
}
o.B = time.Date(2007, 11, 2, 2, 2, 2, 2, time.UTC)
errs = IsValueMatch(o, cols, values)
if errs != nil {
for _, e := range errs {
t.Errorf("%s", e)
} }
} }
} }

View file

@ -193,7 +193,6 @@ var defaultSingletonTestTemplateImports = map[string]imports{
`"os"`, `"os"`,
`"strconv"`, `"strconv"`,
`"math/rand"`, `"math/rand"`,
`"regexp"`,
`"bytes"`, `"bytes"`,
}, },
thirdParty: importList{}, thirdParty: importList{},

View file

@ -145,5 +145,5 @@ var templateFunctions = template.FuncMap{
"columnNames": bdb.ColumnNames, "columnNames": bdb.ColumnNames,
"toManyRelationships": bdb.ToManyRelationships, "toManyRelationships": bdb.ToManyRelationships,
"zeroValue": bdb.ZeroValue, "zeroValue": bdb.ZeroValue,
"defaultValue": bdb.DefaultValue, "defaultValues": bdb.DefaultValues,
} }

View file

@ -3,8 +3,11 @@
{{- $tableNamePlural := .Table.Name | plural | titleCase -}} {{- $tableNamePlural := .Table.Name | plural | titleCase -}}
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $parent := .}}
func Test{{$tableNamePlural}}Insert(t *testing.T) { func Test{{$tableNamePlural}}Insert(t *testing.T) {
var err error var err error
var errs []error
emptyTime := time.Time{}.String()
{{$varNamePlural}}DeleteAllRows(t) {{$varNamePlural}}DeleteAllRows(t)
@ -42,22 +45,37 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) {
t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err) t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err)
} }
{{with .Table.Columns | filterColumnsByAutoIncrement true | columnNames}} {{with .Table.Columns | filterColumnsByAutoIncrement true | columnNames | stringMap $parent.StringFuncs.quoteWrap | join ", "}}
// Ensure the auto increment columns are returned in the object // Ensure the auto increment columns are returned in the object
{{range .}} if errs = boil.IsZeroValue(item, false, {{.}}); errs != nil {
if item.{{titleCase .}} <= 0 { for _, e := range errs {
t.Errorf("Expected the auto-increment columns to be greater than 0, got: %d", item.{{titleCase .}}) t.Errorf("Expected auto-increment columns to be greater than 0, err: %s\n", e)
}
} }
{{end}}
{{end}} {{end}}
emptyTime := time.Time{}.String()
{{with .Table.Columns | filterColumnsBySimpleDefault}} {{with .Table.Columns | filterColumnsBySimpleDefault}}
simpleDefaults := []string{{"{"}}{{. | columnNames | stringMap $parent.StringFuncs.quoteWrap | join ", "}}{{"}"}}
defaultValues := []interface{}{{"{"}}{{. | defaultValues | join ", "}}{{"}"}}
if len(simpleDefaults) != len(defaultValues) {
t.Fatalf("Mismatch between slice lengths: %d, %d", len(simpleDefaults), len(defaultValues))
}
if errs = boil.IsValueMatch(item, simpleDefaults, defaultValues); errs != nil {
for _, e := range errs {
t.Errorf("Expected default value to match column value, err: %s\n", e);
}
}
{{end}}
/*{{with .Table.Columns | filterColumnsBySimpleDefault}}
// Ensure the default value columns are returned in the object // Ensure the default value columns are returned in the object
{{range .}} {{range .}}
{{$tc := titleCase .Name}} {{$tc := titleCase .Name}}
{{$zv := zeroValue .}} {{$zv := zeroValue .}}
{{$dv := defaultValue .}} {{$dv := "false"}}
{{$ty := trimPrefix "null." .Type}} {{$ty := trimPrefix "null." .Type}}
{{if and (ne $ty "[]byte") .IsNullable}} {{if and (ne $ty "[]byte") .IsNullable}}
if item.{{$tc}}.Valid == false { if item.{{$tc}}.Valid == false {
@ -84,7 +102,7 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) {
} }
{{end}} {{end}}
{{end}} {{end}}
{{end}} {{end}}*/
{{with .Table.Columns | filterColumnsByAutoIncrement false | filterColumnsByDefault false}} {{with .Table.Columns | filterColumnsByAutoIncrement false | filterColumnsByDefault false}}
// Ensure the non-defaultvalue columns and non-autoincrement columns are stored correctly as zero or null values. // Ensure the non-defaultvalue columns and non-autoincrement columns are stored correctly as zero or null values.

View file

@ -1,10 +1,5 @@
var dbNameRand *rand.Rand var dbNameRand *rand.Rand
func isZeroTime(time string) bool {
re := regexp.MustCompile(`[2-9]+`)
return !re.MatchString(time)
}
func initDBNameRand(input string) { func initDBNameRand(input string) {
sum := md5.Sum([]byte(input)) sum := md5.Sum([]byte(input))