diff --git a/boil/testing.go b/boil/testing.go index fad72af..70e889d 100644 --- a/boil/testing.go +++ b/boil/testing.go @@ -8,7 +8,7 @@ import ( "sync/atomic" "time" - null "gopkg.in/nullbio/null.v4" + "gopkg.in/nullbio/null.v4" "github.com/pkg/errors" "github.com/satori/go.uuid" @@ -194,7 +194,7 @@ func (s *Seed) randomizeField(field reflect.Value, fieldType string, canBeNull b } } else { if isNull { - value = getVariableNullValue(kind) + value = getVariableZeroValue(kind) } else { value = s.getVariableRandValue(kind, typ) } @@ -205,6 +205,7 @@ func (s *Seed) randomizeField(field reflect.Value, fieldType string, canBeNull b } field.Set(reflect.ValueOf(value)) + return nil } @@ -290,8 +291,8 @@ func (s *Seed) getStructRandValue(typ reflect.Type) interface{} { return nil } -// getVariableNullValue for the matching type. -func getVariableNullValue(kind reflect.Kind) interface{} { +// getVariableZeroValue for the matching type. +func getVariableZeroValue(kind reflect.Kind) interface{} { switch kind { case reflect.Float32: return float32(0) diff --git a/boil/testing_test.go b/boil/testing_test.go index 8d18759..de3422f 100644 --- a/boil/testing_test.go +++ b/boil/testing_test.go @@ -1,6 +1,7 @@ package boil import ( + "reflect" "testing" "time" @@ -10,7 +11,7 @@ import ( func TestRandomizeStruct(t *testing.T) { t.Parallel() - s := new(Seed) + s := NewSeed() var testStruct = struct { Int int @@ -79,3 +80,67 @@ func TestRandomizeStruct(t *testing.T) { t.Errorf("the null values are not being randomized: %#v", testStruct) } } + +func TestRandomizeField(t *testing.T) { + t.Parallel() + + type RandomizeTest struct { + In interface{} + Typs []string + Out interface{} + } + + s := NewSeed() + inputs := []RandomizeTest{ + {In: &null.Bool{}, Out: null.Bool{}, Typs: []string{"boolean"}}, + {In: &null.String{}, Out: null.String{}, Typs: []string{"character", "uuid", "interval"}}, + {In: &null.Time{}, Out: null.Time{}, Typs: []string{"time"}}, + {In: &null.Float32{}, Out: null.Float32{}, Typs: []string{"real"}}, + {In: &null.Float64{}, Out: null.Float64{}, Typs: []string{"decimal"}}, + {In: &null.Int{}, Out: null.Int{}, Typs: []string{"integer"}}, + {In: &null.Int8{}, Out: null.Int8{}, Typs: []string{"integer"}}, + {In: &null.Int16{}, Out: null.Int16{}, Typs: []string{"smallint"}}, + {In: &null.Int32{}, Out: null.Int32{}, Typs: []string{"integer"}}, + {In: &null.Int64{}, Out: null.Int64{}, Typs: []string{"bigint"}}, + {In: &null.Uint{}, Out: null.Uint{}, Typs: []string{"integer"}}, + {In: &null.Uint8{}, Out: null.Uint8{}, Typs: []string{"integer"}}, + {In: &null.Uint16{}, Out: null.Uint16{}, Typs: []string{"integer"}}, + {In: &null.Uint32{}, Out: null.Uint32{}, Typs: []string{"integer"}}, + {In: &null.Uint64{}, Out: null.Uint64{}, Typs: []string{"integer"}}, + + {In: new(float32), Out: float32(0), Typs: []string{"real"}}, + {In: new(float64), Out: float64(0), Typs: []string{"numeric"}}, + {In: new(int), Out: int(0), Typs: []string{"integer"}}, + {In: new(int8), Out: int8(0), Typs: []string{"integer"}}, + {In: new(int16), Out: int16(0), Typs: []string{"smallserial"}}, + {In: new(int32), Out: int32(0), Typs: []string{"integer"}}, + {In: new(int64), Out: int64(0), Typs: []string{"bigserial"}}, + {In: new(uint), Out: uint(0), Typs: []string{"integer"}}, + {In: new(uint8), Out: uint8(0), Typs: []string{"integer"}}, + {In: new(uint16), Out: uint16(0), Typs: []string{"integer"}}, + {In: new(uint32), Out: uint32(0), Typs: []string{"integer"}}, + {In: new(uint64), Out: uint64(0), Typs: []string{"integer"}}, + + {In: new(bool), Out: false}, + {In: new(string), Out: ""}, + {In: new([]byte), Out: new([]byte)}, + {In: &time.Time{}, Out: &time.Time{}}, + } + + for i := 0; i < len(inputs); i++ { + for _, typ := range inputs[i].Typs { + val := reflect.Indirect(reflect.ValueOf(&inputs[i])) + field := val.FieldByName("In").Elem().Elem() + + // Make sure we never get back values that would be considered null + // by the boil whitelist generator, or by the database driver + if err := s.randomizeField(field, typ, false); err != nil { + t.Errorf("%d) %s", i, err) + } + + if inputs[i].In == inputs[i].Out { + t.Errorf("%d) Field should not be null, got: %v -- type: %s\n", i, inputs[i].In, typ) + } + } + } +}