Refactor RandomizeField

* Fix randomize arg in relationship helper
* Fix param name for RandomizeStruct
This commit is contained in:
Patrick O'brien 2016-08-15 01:32:57 +10:00
parent 6a040bc11c
commit 75e28d3f5b
3 changed files with 247 additions and 160 deletions

View file

@ -274,7 +274,7 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "null.Float64"
case "real":
c.Type = "null.Float32"
case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
c.Type = "null.String"
case "bytea":
c.Type = "[]byte"

View file

@ -47,7 +47,7 @@ func (s *seed) nextInt() int {
// RandomizeStruct takes an object and fills it with random data.
// It will ignore the fields in the blacklist.
func RandomizeStruct(str interface{}, colTypes map[string]string, includeInvalid bool, blacklist ...string) error {
func RandomizeStruct(str interface{}, colTypes map[string]string, canBeNull bool, blacklist ...string) error {
// Don't modify blacklist
copyBlacklist := make([]string, len(blacklist))
copy(copyBlacklist, blacklist)
@ -90,7 +90,7 @@ func RandomizeStruct(str interface{}, colTypes map[string]string, includeInvalid
}
fieldDBType := colTypes[fieldTyp.Name]
if err := randomizeField(fieldVal, fieldDBType, includeInvalid); err != nil {
if err := randomizeField(fieldVal, fieldDBType, canBeNull); err != nil {
return err
}
}
@ -116,170 +116,257 @@ func randDate(sd int) time.Time {
return t
}
// randomizeField changes the value at field to a "randomized" value.
//
// If canBeNull is false:
// The value will always be a non-null and non-zero value.
// If canBeNull is true:
// The value has the possibility of being null or non-zero at random.
func randomizeField(field reflect.Value, fieldType string, canBeNull bool) error {
kind := field.Kind()
typ := field.Type()
var newVal interface{}
var value interface{}
var isNull bool
if kind == reflect.Struct {
var notNull bool
if canBeNull {
notNull = rand.Intn(2) == 1
} else {
notNull = false
}
// Validated columns always need to be set regardless of canBeNull,
// and they have to adhere to a strict value format.
validatedTypes := []string{"uuid", "interval"}
var foundValidated bool
switch typ {
case typeNullBool:
if notNull {
newVal = null.NewBool(sd.nextInt()%2 == 0, true)
} else {
newVal = null.NewBool(false, false)
}
case typeNullString:
if fieldType == "uuid" {
newVal = null.NewString(uuid.NewV4().String(), true)
} else if notNull {
switch fieldType {
case "interval":
newVal = null.NewString(strconv.Itoa((sd.nextInt()%26)+2)+" days", true)
default:
newVal = null.NewString(randStr(1, sd.nextInt()), true)
}
} else {
newVal = null.NewString("", false)
}
case typeNullTime:
if notNull {
newVal = null.NewTime(randDate(sd.nextInt()), true)
} else {
newVal = null.NewTime(time.Time{}, false)
}
case typeTime:
newVal = randDate(sd.nextInt())
case typeNullFloat32:
if notNull {
newVal = null.NewFloat32(float32(sd.nextInt()%10)/10.0+float32(sd.nextInt()%10), true)
} else {
newVal = null.NewFloat32(0.0, false)
}
case typeNullFloat64:
if notNull {
newVal = null.NewFloat64(float64(sd.nextInt()%10)/10.0+float64(sd.nextInt()%10), true)
} else {
newVal = null.NewFloat64(0.0, false)
}
case typeNullInt:
if notNull {
newVal = null.NewInt(sd.nextInt(), true)
} else {
newVal = null.NewInt(0, false)
}
case typeNullInt8:
if notNull {
newVal = null.NewInt8(int8(sd.nextInt()), true)
} else {
newVal = null.NewInt8(0, false)
}
case typeNullInt16:
if notNull {
newVal = null.NewInt16(int16(sd.nextInt()), true)
} else {
newVal = null.NewInt16(0, false)
}
case typeNullInt32:
if notNull {
newVal = null.NewInt32(int32(sd.nextInt()), true)
} else {
newVal = null.NewInt32(0, false)
}
case typeNullInt64:
if notNull {
newVal = null.NewInt64(int64(sd.nextInt()), true)
} else {
newVal = null.NewInt64(0, false)
}
case typeNullUint:
if notNull {
newVal = null.NewUint(uint(sd.nextInt()), true)
} else {
newVal = null.NewUint(0, false)
}
case typeNullUint8:
if notNull {
newVal = null.NewUint8(uint8(sd.nextInt()), true)
} else {
newVal = null.NewUint8(0, false)
}
case typeNullUint16:
if notNull {
newVal = null.NewUint16(uint16(sd.nextInt()), true)
} else {
newVal = null.NewUint16(0, false)
}
case typeNullUint32:
if notNull {
newVal = null.NewUint32(uint32(sd.nextInt()), true)
} else {
newVal = null.NewUint32(0, false)
}
case typeNullUint64:
if notNull {
newVal = null.NewUint64(uint64(sd.nextInt()), true)
} else {
newVal = null.NewUint64(0, false)
}
}
} else {
switch kind {
case reflect.Float32:
newVal = float32(float32(sd.nextInt()%10)/10.0 + float32(sd.nextInt()%10))
case reflect.Float64:
newVal = float64(float64(sd.nextInt()%10)/10.0 + float64(sd.nextInt()%10))
case reflect.Int:
newVal = sd.nextInt()
case reflect.Int8:
newVal = int8(sd.nextInt())
case reflect.Int16:
newVal = int16(sd.nextInt())
case reflect.Int32:
newVal = int32(sd.nextInt())
case reflect.Int64:
newVal = int64(sd.nextInt())
case reflect.Uint:
newVal = uint(sd.nextInt())
case reflect.Uint8:
newVal = uint8(sd.nextInt())
case reflect.Uint16:
newVal = uint16(sd.nextInt())
case reflect.Uint32:
newVal = uint32(sd.nextInt())
case reflect.Uint64:
newVal = uint64(sd.nextInt())
case reflect.Bool:
newVal = sd.nextInt()%2 == 0
case reflect.String:
switch fieldType {
case "interval":
newVal = strconv.Itoa((sd.nextInt()%26)+2) + " days"
case "uuid":
newVal = uuid.NewV4().String()
default:
newVal = randStr(1, sd.nextInt())
}
case reflect.Slice:
sliceVal := typ.Elem()
if sliceVal.Kind() != reflect.Uint8 {
return errors.Errorf("unsupported slice type: %T", typ.String())
}
newVal = randByteSlice(5+rand.Intn(20), sd.nextInt())
default:
return errors.Errorf("unsupported type: %T", typ.String())
for _, validated := range validatedTypes {
if fieldType == validated {
foundValidated = true
break
}
}
field.Set(reflect.ValueOf(newVal))
if foundValidated {
if kind == reflect.Struct {
switch typ {
case typeNullString:
if fieldType == "interval" {
value = null.NewString(strconv.Itoa((sd.nextInt()%26)+2)+" days", true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "uuid" {
value = null.NewString(uuid.NewV4().String(), true)
field.Set(reflect.ValueOf(value))
return nil
}
}
} else {
switch kind {
case reflect.String:
if fieldType == "interval" {
value = strconv.Itoa((sd.nextInt()%26)+2) + " days"
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "uuid" {
value = uuid.NewV4().String()
field.Set(reflect.ValueOf(value))
return nil
}
}
}
}
// Check the regular columns, these can be set or not set
// depending on the canBeNull flag.
if canBeNull {
// 1 in 3 chance of being null or zero value
isNull = rand.Intn(3) == 1
} else {
// if canBeNull is false, then never return null values.
isNull = false
}
// Retrieve the value to be returned
if kind == reflect.Struct {
if isNull {
value = getStructNullValue(typ)
} else {
value = getStructRandValue(typ)
}
} else {
if isNull {
value = getVariableNullValue(kind)
} else {
value = getVariableRandValue(kind, typ)
}
}
if value == nil {
return errors.Errorf("unsupported type: %T", typ.String())
}
field.Set(reflect.ValueOf(value))
return nil
}
// getStructNullValue for the matching type.
func getStructNullValue(typ reflect.Type) interface{} {
switch typ {
case typeTime:
return time.Time{}
case typeNullBool:
return null.NewBool(false, false)
case typeNullString:
return null.NewString("", false)
case typeNullTime:
return null.NewTime(time.Time{}, false)
case typeNullFloat32:
return null.NewFloat32(0.0, false)
case typeNullFloat64:
return null.NewFloat64(0.0, false)
case typeNullInt:
return null.NewInt(0, false)
case typeNullInt8:
return null.NewInt8(0, false)
case typeNullInt16:
return null.NewInt16(0, false)
case typeNullInt32:
return null.NewInt32(0, false)
case typeNullInt64:
return null.NewInt64(0, false)
case typeNullUint:
return null.NewUint(0, false)
case typeNullUint8:
return null.NewUint8(0, false)
case typeNullUint16:
return null.NewUint16(0, false)
case typeNullUint32:
return null.NewUint32(0, false)
case typeNullUint64:
return null.NewUint64(0, false)
}
return nil
}
// getStructRandValue returns a "random" value for the matching type.
// The randomness is really an incrementation of the global seed,
// this is done to avoid duplicate key violations.
func getStructRandValue(typ reflect.Type) interface{} {
switch typ {
case typeTime:
return randDate(sd.nextInt())
case typeNullBool:
return null.NewBool(sd.nextInt()%2 == 0, true)
case typeNullString:
return null.NewString(randStr(1, sd.nextInt()), true)
case typeNullTime:
return null.NewTime(randDate(sd.nextInt()), true)
case typeNullFloat32:
return null.NewFloat32(float32(sd.nextInt()%10)/10.0+float32(sd.nextInt()%10), true)
case typeNullFloat64:
return null.NewFloat64(float64(sd.nextInt()%10)/10.0+float64(sd.nextInt()%10), true)
case typeNullInt:
return null.NewInt(sd.nextInt(), true)
case typeNullInt8:
return null.NewInt8(int8(sd.nextInt()), true)
case typeNullInt16:
return null.NewInt16(int16(sd.nextInt()), true)
case typeNullInt32:
return null.NewInt32(int32(sd.nextInt()), true)
case typeNullInt64:
return null.NewInt64(int64(sd.nextInt()), true)
case typeNullUint:
return null.NewUint(uint(sd.nextInt()), true)
case typeNullUint8:
return null.NewUint8(uint8(sd.nextInt()), true)
case typeNullUint16:
return null.NewUint16(uint16(sd.nextInt()), true)
case typeNullUint32:
return null.NewUint32(uint32(sd.nextInt()), true)
case typeNullUint64:
return null.NewUint64(uint64(sd.nextInt()), true)
}
return nil
}
// getVariableNullValue for the matching type.
func getVariableNullValue(kind reflect.Kind) interface{} {
switch kind {
case reflect.Float32:
return float32(0)
case reflect.Float64:
return float64(0)
case reflect.Int:
return int(0)
case reflect.Int8:
return int8(0)
case reflect.Int16:
return int16(0)
case reflect.Int32:
return int32(0)
case reflect.Int64:
return int64(0)
case reflect.Uint:
return uint(0)
case reflect.Uint8:
return uint8(0)
case reflect.Uint16:
return uint16(0)
case reflect.Uint32:
return uint32(0)
case reflect.Uint64:
return uint64(0)
case reflect.Bool:
return false
case reflect.String:
return ""
case reflect.Slice:
return []byte(nil)
}
return nil
}
// getVariableRandValue returns a "random" value for the matching kind.
// The randomness is really an incrementation of the global seed,
// this is done to avoid duplicate key violations.
func getVariableRandValue(kind reflect.Kind, typ reflect.Type) interface{} {
switch kind {
case reflect.Float32:
return float32(float32(sd.nextInt()%10)/10.0 + float32(sd.nextInt()%10))
case reflect.Float64:
return float64(float64(sd.nextInt()%10)/10.0 + float64(sd.nextInt()%10))
case reflect.Int:
return sd.nextInt()
case reflect.Int8:
return int8(sd.nextInt())
case reflect.Int16:
return int16(sd.nextInt())
case reflect.Int32:
return int32(sd.nextInt())
case reflect.Int64:
return int64(sd.nextInt())
case reflect.Uint:
return uint(sd.nextInt())
case reflect.Uint8:
return uint8(sd.nextInt())
case reflect.Uint16:
return uint16(sd.nextInt())
case reflect.Uint32:
return uint32(sd.nextInt())
case reflect.Uint64:
return uint64(sd.nextInt())
case reflect.Bool:
return true
case reflect.String:
return randStr(1, sd.nextInt())
case reflect.Slice:
sliceVal := typ.Elem()
if sliceVal.Kind() != reflect.Uint8 {
return errors.Errorf("unsupported slice type: %T, was expecting byte slice.", typ.String())
}
return randByteSlice(5+rand.Intn(20), sd.nextInt())
}
return nil
}

View file

@ -21,8 +21,8 @@ func Test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) {
t.Fatal(err)
}
boil.RandomizeStruct(&b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, true, "{{.ForeignColumn}}")
boil.RandomizeStruct(&c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, true, "{{.ForeignColumn}}")
boil.RandomizeStruct(&b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}")
boil.RandomizeStruct(&c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}")
{{if .Nullable -}}
a.{{.Column | titleCase}}.Valid = true
{{- end}}