Added more reflect helpers
This commit is contained in:
parent
08d168605f
commit
013b3ae0f8
9 changed files with 224 additions and 76 deletions
|
@ -71,50 +71,55 @@ var (
|
|||
rgxByteaDefaultValue = regexp.MustCompile(`(?i)\\x([0-9A-F]*)`)
|
||||
)
|
||||
|
||||
// DefaultValue returns the Go converted value of the default value column
|
||||
func DefaultValue(column Column) string {
|
||||
defaultVal := ""
|
||||
// DefaultValues returns the Go converted values of the default value columns
|
||||
func DefaultValues(columns []Column) []string {
|
||||
var dVals []string
|
||||
|
||||
// Attempt to strip out the raw default value if its contained
|
||||
// within a Postgres type cast statement
|
||||
m := rgxRawDefaultValue.FindStringSubmatch(column.Default)
|
||||
if len(m) > 1 {
|
||||
defaultVal = m[len(m)-1]
|
||||
} else {
|
||||
defaultVal = column.Default
|
||||
for _, c := range columns {
|
||||
var dVal string
|
||||
// Attempt to strip out the raw default value if its contained
|
||||
// within a Postgres type cast statement
|
||||
m := rgxRawDefaultValue.FindStringSubmatch(c.Default)
|
||||
if len(m) > 1 {
|
||||
dVal = m[len(m)-1]
|
||||
} else {
|
||||
dVal = c.Default
|
||||
}
|
||||
|
||||
switch c.Type {
|
||||
case "null.Uint", "null.Uint8", "null.Uint16", "null.Uint32", "null.Uint64",
|
||||
"null.Int", "null.Int8", "null.Int16", "null.Int32", "null.Int64",
|
||||
"uint", "uint8", "uint16", "uint32", "uint64",
|
||||
"int", "int8", "int16", "int32", "int64",
|
||||
"null.Float32", "null.Float64", "float32", "float64":
|
||||
dVals = append(dVals, dVal)
|
||||
case "null.Bool", "bool":
|
||||
m = rgxBoolDefaultValue.FindStringSubmatch(dVal)
|
||||
if len(m) == 0 {
|
||||
dVals = append(dVals, "false")
|
||||
}
|
||||
dVals = append(dVals, strings.ToLower(m[0]))
|
||||
case "null.Time", "time.Time", "null.String", "string":
|
||||
dVals = append(dVals, `"`+dVal+`"`)
|
||||
case "[]byte":
|
||||
m := rgxByteaDefaultValue.FindStringSubmatch(dVal)
|
||||
if len(m) != 2 {
|
||||
dVals = append(dVals, `[]byte{}`)
|
||||
}
|
||||
hexstr := m[1]
|
||||
bs := make([]string, len(hexstr)/2)
|
||||
count := 0
|
||||
for i := 0; i < len(hexstr); i += 2 {
|
||||
bs[count] = "0x" + hexstr[i:i+2]
|
||||
count++
|
||||
}
|
||||
dVals = append(dVals, `[]byte{`+strings.Join(bs, ", ")+`}`)
|
||||
default:
|
||||
dVals = append(dVals, "")
|
||||
}
|
||||
}
|
||||
|
||||
switch column.Type {
|
||||
case "null.Uint", "null.Uint8", "null.Uint16", "null.Uint32", "null.Uint64",
|
||||
"null.Int", "null.Int8", "null.Int16", "null.Int32", "null.Int64",
|
||||
"uint", "uint8", "uint16", "uint32", "uint64",
|
||||
"int", "int8", "int16", "int32", "int64",
|
||||
"null.Float32", "null.Float64", "float32", "float64":
|
||||
return defaultVal
|
||||
case "null.Bool", "bool":
|
||||
m = rgxBoolDefaultValue.FindStringSubmatch(defaultVal)
|
||||
if len(m) == 0 {
|
||||
return "false"
|
||||
}
|
||||
return strings.ToLower(m[0])
|
||||
case "null.Time", "time.Time", "null.String", "string":
|
||||
return `"` + defaultVal + `"`
|
||||
case "[]byte":
|
||||
m := rgxByteaDefaultValue.FindStringSubmatch(defaultVal)
|
||||
if len(m) != 2 {
|
||||
return `[]byte{}`
|
||||
}
|
||||
hexstr := m[1]
|
||||
bs := make([]string, len(hexstr)/2)
|
||||
c := 0
|
||||
for i := 0; i < len(hexstr); i += 2 {
|
||||
bs[c] = "0x" + hexstr[i:i+2]
|
||||
c++
|
||||
}
|
||||
return `[]byte{` + strings.Join(bs, ", ") + `}`
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
return dVals
|
||||
}
|
||||
|
||||
// ZeroValue returns the zero value string of the column type
|
||||
|
@ -134,7 +139,7 @@ func ZeroValue(column Column) string {
|
|||
case "null.Time", "time.Time":
|
||||
return `time.Time{}`
|
||||
case "[]byte":
|
||||
return `[]byte{}`
|
||||
return `[]byte(nil)`
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -52,21 +52,24 @@ func TestFilterColumnsByDefault(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDefaultValue(t *testing.T) {
|
||||
func TestDefaultValues(t *testing.T) {
|
||||
c := Column{}
|
||||
|
||||
c.Default = `\x12345678`
|
||||
c.Type = "[]byte"
|
||||
|
||||
res := DefaultValue(c)
|
||||
if res != `[]byte{0x12, 0x34, 0x56, 0x78}` {
|
||||
res := DefaultValues([]Column{c})
|
||||
if len(res) != 1 {
|
||||
t.Errorf("Expected res len 1, got %d", len(res))
|
||||
}
|
||||
if res[0] != `[]byte{0x12, 0x34, 0x56, 0x78}` {
|
||||
t.Errorf("Invalid result: %#v", res)
|
||||
}
|
||||
|
||||
c.Default = `\x`
|
||||
|
||||
res = DefaultValue(c)
|
||||
if res != `[]byte{}` {
|
||||
res = DefaultValues([]Column{c})
|
||||
if res[0] != `[]byte{}` {
|
||||
t.Errorf("Invalid result: %#v", res)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,12 +76,12 @@ func TestSetWhere(t *testing.T) {
|
|||
SetWhere(q, "x > $1 AND y > $2", 5, 3)
|
||||
|
||||
if len(q.where) != 1 {
|
||||
t.Errorf("Expected %d where slices, got %d", len(q.where))
|
||||
t.Errorf("Expected %d where slices, got %d", 1, len(q.where))
|
||||
}
|
||||
|
||||
expect := "x > $1 AND y > $2"
|
||||
if q.where[0].clause != expect {
|
||||
t.Errorf("Expected %s, got %s", expect, q.where)
|
||||
t.Errorf("Expected %s, got %v", expect, q.where)
|
||||
}
|
||||
|
||||
if len(q.where[0].args) != 2 {
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"math"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
|
@ -30,6 +31,8 @@ var (
|
|||
typeNullBool = reflect.TypeOf(null.Bool{})
|
||||
typeNullTime = reflect.TypeOf(null.Time{})
|
||||
typeTime = reflect.TypeOf(time.Time{})
|
||||
|
||||
rgxValidTime = regexp.MustCompile(`[2-9]+`)
|
||||
)
|
||||
|
||||
// Bind executes the query and inserts the
|
||||
|
@ -159,9 +162,12 @@ func checkType(obj interface{}) (reflect.Type, bool, error) {
|
|||
return typ, isSlice, nil
|
||||
}
|
||||
|
||||
// IsZeroValue checks if the variables with matching columns in obj are zero values
|
||||
func isZeroValue(obj interface{}, columns ...string) bool {
|
||||
val := reflect.ValueOf(obj)
|
||||
// IsZeroValue checks if the variables with matching columns in obj
|
||||
// are or are not zero values, depending on whether shouldZero is true or false
|
||||
func IsZeroValue(obj interface{}, shouldZero bool, columns ...string) []error {
|
||||
val := reflect.Indirect(reflect.ValueOf(obj))
|
||||
|
||||
var errs []error
|
||||
for _, c := range columns {
|
||||
field := val.FieldByName(strmangle.TitleCase(c))
|
||||
if !field.IsValid() {
|
||||
|
@ -169,12 +175,58 @@ func isZeroValue(obj interface{}, columns ...string) bool {
|
|||
}
|
||||
|
||||
zv := reflect.Zero(field.Type())
|
||||
if !reflect.DeepEqual(field.Interface(), zv.Interface()) {
|
||||
return false
|
||||
if shouldZero && !reflect.DeepEqual(field.Interface(), zv.Interface()) {
|
||||
errs = append(errs, fmt.Errorf("Column with name %s is not zero value: %#v, %#v", c, field.Interface(), zv.Interface()))
|
||||
} else if !shouldZero && reflect.DeepEqual(field.Interface(), zv.Interface()) {
|
||||
errs = append(errs, fmt.Errorf("Column with name %s is zero value: %#v, %#v", c, field.Interface(), zv.Interface()))
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return errs
|
||||
}
|
||||
|
||||
// IsValueMatch checks whether the variables in obj with matching column names
|
||||
// match the values in the values slice.
|
||||
func IsValueMatch(obj interface{}, columns []string, values []interface{}) []error {
|
||||
val := reflect.Indirect(reflect.ValueOf(obj))
|
||||
|
||||
var errs []error
|
||||
for i, c := range columns {
|
||||
field := val.FieldByName(strmangle.TitleCase(c))
|
||||
if !field.IsValid() {
|
||||
panic(fmt.Sprintf("Unable to find variable with column name %s", c))
|
||||
}
|
||||
|
||||
typ := field.Type().String()
|
||||
if typ == "time.Time" || typ == "null.Time" {
|
||||
var timeField reflect.Value
|
||||
var valTimeStr string
|
||||
if typ == "time.Time" {
|
||||
valTimeStr = values[i].(time.Time).String()
|
||||
timeField = field
|
||||
} else {
|
||||
valTimeStr = values[i].(null.Time).Time.String()
|
||||
timeField = field.FieldByName("Time")
|
||||
validField := field.FieldByName("Valid")
|
||||
if validField.Interface() != values[i].(null.Time).Valid {
|
||||
errs = append(errs, fmt.Errorf("Null.Time column with name %s Valid field does not match: %v ≠ %v", c, values[i].(null.Time).Valid, validField.Interface()))
|
||||
}
|
||||
}
|
||||
|
||||
if (rgxValidTime.MatchString(valTimeStr) && timeField.Interface() == reflect.Zero(timeField.Type()).Interface()) ||
|
||||
(!rgxValidTime.MatchString(valTimeStr) && timeField.Interface() != reflect.Zero(timeField.Type()).Interface()) {
|
||||
errs = append(errs, fmt.Errorf("Time column with name %s Time field does not match: %v ≠ %v", c, values[i], timeField.Interface()))
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(field.Interface(), values[i]) {
|
||||
errs = append(errs, fmt.Errorf("Column with name %s does not match value: %#v ≠ %#v", c, values[i], field.Interface()))
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// GetStructValues returns the values (as interface) of the matching columns in obj
|
||||
|
|
|
@ -28,14 +28,16 @@ func TestIsZeroValue(t *testing.T) {
|
|||
E int64
|
||||
}{}
|
||||
|
||||
if !isZeroValue(o, "A", "B", "C", "D", "E") {
|
||||
t.Errorf("Expected all values to be zero values: %#v", o)
|
||||
if errs := IsZeroValue(o, true, "A", "B", "C", "D", "E"); errs != nil {
|
||||
for _, e := range errs {
|
||||
t.Errorf("%s", e)
|
||||
}
|
||||
}
|
||||
|
||||
colNames := []string{"A", "B", "C", "D", "E"}
|
||||
for _, c := range colNames {
|
||||
if !isZeroValue(o, c) {
|
||||
t.Errorf("Expected %s to be zero value: %#v", c, o)
|
||||
if err := IsZeroValue(o, true, c); err != nil {
|
||||
t.Errorf("Expected %s to be zero value: %s", c, err[0])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,9 +47,83 @@ func TestIsZeroValue(t *testing.T) {
|
|||
o.D = null.NewInt64(2, false)
|
||||
o.E = 5
|
||||
|
||||
if errs := IsZeroValue(o, false, "A", "B", "C", "D", "E"); errs != nil {
|
||||
for _, e := range errs {
|
||||
t.Errorf("%s", e)
|
||||
}
|
||||
}
|
||||
|
||||
for _, c := range colNames {
|
||||
if isZeroValue(o, c) {
|
||||
t.Errorf("Expected %s to be non-zero value: %#v", c, o)
|
||||
if err := IsZeroValue(o, false, c); err != nil {
|
||||
t.Errorf("Expected %s to be non-zero value: %s", c, err[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValueMatch(t *testing.T) {
|
||||
var errs []error
|
||||
var values []interface{}
|
||||
|
||||
o := struct {
|
||||
A []byte
|
||||
B time.Time
|
||||
C null.Time
|
||||
D null.Int64
|
||||
E int64
|
||||
}{}
|
||||
|
||||
values = []interface{}{
|
||||
[]byte(nil),
|
||||
time.Time{},
|
||||
null.Time{},
|
||||
null.Int64{},
|
||||
int64(0),
|
||||
}
|
||||
|
||||
cols := []string{"A", "B", "C", "D", "E"}
|
||||
errs = IsValueMatch(o, cols, values)
|
||||
if errs != nil {
|
||||
for _, e := range errs {
|
||||
t.Errorf("%s", e)
|
||||
}
|
||||
}
|
||||
|
||||
values = []interface{}{
|
||||
[]byte("hi"),
|
||||
time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC),
|
||||
null.NewTime(time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC), true),
|
||||
null.NewInt64(5, false),
|
||||
int64(6),
|
||||
}
|
||||
|
||||
errs = IsValueMatch(o, cols, values)
|
||||
// Expect 6 errors
|
||||
// 5 for each column and an additional 1 for the invalid Valid field match
|
||||
if len(errs) != 6 {
|
||||
t.Errorf("Expected 6 errors, got: %d", len(errs))
|
||||
for _, e := range errs {
|
||||
t.Errorf("%s", e)
|
||||
}
|
||||
}
|
||||
|
||||
o.A = []byte("hi")
|
||||
o.B = time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC)
|
||||
o.C = null.NewTime(time.Date(2007, 11, 2, 1, 1, 1, 1, time.UTC), true)
|
||||
o.D = null.NewInt64(5, false)
|
||||
o.E = 6
|
||||
|
||||
errs = IsValueMatch(o, cols, values)
|
||||
if errs != nil {
|
||||
for _, e := range errs {
|
||||
t.Errorf("%s", e)
|
||||
}
|
||||
}
|
||||
|
||||
o.B = time.Date(2007, 11, 2, 2, 2, 2, 2, time.UTC)
|
||||
errs = IsValueMatch(o, cols, values)
|
||||
if errs != nil {
|
||||
for _, e := range errs {
|
||||
t.Errorf("%s", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -193,7 +193,6 @@ var defaultSingletonTestTemplateImports = map[string]imports{
|
|||
`"os"`,
|
||||
`"strconv"`,
|
||||
`"math/rand"`,
|
||||
`"regexp"`,
|
||||
`"bytes"`,
|
||||
},
|
||||
thirdParty: importList{},
|
||||
|
|
|
@ -145,5 +145,5 @@ var templateFunctions = template.FuncMap{
|
|||
"columnNames": bdb.ColumnNames,
|
||||
"toManyRelationships": bdb.ToManyRelationships,
|
||||
"zeroValue": bdb.ZeroValue,
|
||||
"defaultValue": bdb.DefaultValue,
|
||||
"defaultValues": bdb.DefaultValues,
|
||||
}
|
||||
|
|
|
@ -3,8 +3,11 @@
|
|||
{{- $tableNamePlural := .Table.Name | plural | titleCase -}}
|
||||
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
|
||||
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
|
||||
{{- $parent := .}}
|
||||
func Test{{$tableNamePlural}}Insert(t *testing.T) {
|
||||
var err error
|
||||
var errs []error
|
||||
emptyTime := time.Time{}.String()
|
||||
|
||||
{{$varNamePlural}}DeleteAllRows(t)
|
||||
|
||||
|
@ -42,22 +45,37 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) {
|
|||
t.Errorf("Unable to insert zero-value item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err)
|
||||
}
|
||||
|
||||
{{with .Table.Columns | filterColumnsByAutoIncrement true | columnNames}}
|
||||
{{with .Table.Columns | filterColumnsByAutoIncrement true | columnNames | stringMap $parent.StringFuncs.quoteWrap | join ", "}}
|
||||
// Ensure the auto increment columns are returned in the object
|
||||
{{range .}}
|
||||
if item.{{titleCase .}} <= 0 {
|
||||
t.Errorf("Expected the auto-increment columns to be greater than 0, got: %d", item.{{titleCase .}})
|
||||
if errs = boil.IsZeroValue(item, false, {{.}}); errs != nil {
|
||||
for _, e := range errs {
|
||||
t.Errorf("Expected auto-increment columns to be greater than 0, err: %s\n", e)
|
||||
}
|
||||
}
|
||||
{{end}}
|
||||
{{end}}
|
||||
|
||||
emptyTime := time.Time{}.String()
|
||||
{{with .Table.Columns | filterColumnsBySimpleDefault}}
|
||||
simpleDefaults := []string{{"{"}}{{. | columnNames | stringMap $parent.StringFuncs.quoteWrap | join ", "}}{{"}"}}
|
||||
defaultValues := []interface{}{{"{"}}{{. | defaultValues | join ", "}}{{"}"}}
|
||||
|
||||
if len(simpleDefaults) != len(defaultValues) {
|
||||
t.Fatalf("Mismatch between slice lengths: %d, %d", len(simpleDefaults), len(defaultValues))
|
||||
}
|
||||
|
||||
if errs = boil.IsValueMatch(item, simpleDefaults, defaultValues); errs != nil {
|
||||
for _, e := range errs {
|
||||
t.Errorf("Expected default value to match column value, err: %s\n", e);
|
||||
}
|
||||
}
|
||||
{{end}}
|
||||
|
||||
|
||||
/*{{with .Table.Columns | filterColumnsBySimpleDefault}}
|
||||
// Ensure the default value columns are returned in the object
|
||||
{{range .}}
|
||||
{{$tc := titleCase .Name}}
|
||||
{{$zv := zeroValue .}}
|
||||
{{$dv := defaultValue .}}
|
||||
{{$dv := "false"}}
|
||||
{{$ty := trimPrefix "null." .Type}}
|
||||
{{if and (ne $ty "[]byte") .IsNullable}}
|
||||
if item.{{$tc}}.Valid == false {
|
||||
|
@ -84,7 +102,7 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) {
|
|||
}
|
||||
{{end}}
|
||||
{{end}}
|
||||
{{end}}
|
||||
{{end}}*/
|
||||
|
||||
{{with .Table.Columns | filterColumnsByAutoIncrement false | filterColumnsByDefault false}}
|
||||
// Ensure the non-defaultvalue columns and non-autoincrement columns are stored correctly as zero or null values.
|
||||
|
|
|
@ -1,10 +1,5 @@
|
|||
var dbNameRand *rand.Rand
|
||||
|
||||
func isZeroTime(time string) bool {
|
||||
re := regexp.MustCompile(`[2-9]+`)
|
||||
return !re.MatchString(time)
|
||||
}
|
||||
|
||||
func initDBNameRand(input string) {
|
||||
sum := md5.Sum([]byte(input))
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue