Refactor randomize

This commit is contained in:
Aaron L 2016-08-18 00:03:14 -07:00
parent f23092d4a5
commit f86d8d7ee9
2 changed files with 23 additions and 22 deletions

View file

@ -1,4 +1,5 @@
package boil // Package randomize has helpers for randomization of structs and fields
package randomize
import ( import (
"reflect" "reflect"
@ -54,9 +55,10 @@ func (s *Seed) nextInt() int {
return int(atomic.AddInt64((*int64)(s), 1)) return int(atomic.AddInt64((*int64)(s), 1))
} }
// RandomizeStruct takes an object and fills it with random data. // Struct gets its fields filled with random data based on the seed.
// It will ignore the fields in the blacklist. // It will ignore the fields in the blacklist.
func (s *Seed) RandomizeStruct(str interface{}, colTypes map[string]string, canBeNull bool, blacklist ...string) error { // It will ignore fields that have the struct tag boil:"-"
func Struct(s *Seed, str interface{}, colTypes map[string]string, canBeNull bool, blacklist ...string) error {
// Don't modify blacklist // Don't modify blacklist
copyBlacklist := make([]string, len(blacklist)) copyBlacklist := make([]string, len(blacklist))
copy(copyBlacklist, blacklist) copy(copyBlacklist, blacklist)
@ -98,13 +100,12 @@ func (s *Seed) RandomizeStruct(str interface{}, colTypes map[string]string, canB
continue continue
} }
tagVal, _ := getBoilTag(fieldTyp) if fieldTyp.Tag.Get("boil") == "-" {
if tagVal == "-" {
continue continue
} }
fieldDBType := colTypes[fieldTyp.Name] fieldDBType := colTypes[fieldTyp.Name]
if err := s.randomizeField(fieldVal, fieldDBType, canBeNull); err != nil { if err := randomizeField(s, fieldVal, fieldDBType, canBeNull); err != nil {
return err return err
} }
} }
@ -115,7 +116,7 @@ func (s *Seed) RandomizeStruct(str interface{}, colTypes map[string]string, canB
// randDate generates a random time.Time between 1850 and 2050. // randDate generates a random time.Time between 1850 and 2050.
// Only the Day/Month/Year columns are set so that Dates and DateTimes do // Only the Day/Month/Year columns are set so that Dates and DateTimes do
// not cause mismatches in the test data comparisons. // not cause mismatches in the test data comparisons.
func (s *Seed) randDate() time.Time { func randDate(s *Seed) time.Time {
t := time.Date( t := time.Date(
1850+s.nextInt()%160, 1850+s.nextInt()%160,
time.Month(1+(s.nextInt()%12)), time.Month(1+(s.nextInt()%12)),
@ -137,7 +138,7 @@ func (s *Seed) randDate() time.Time {
// If canBeNull is true: // If canBeNull is true:
// The value has the possibility of being null or non-zero at random. // The value has the possibility of being null or non-zero at random.
func (s *Seed) randomizeField(field reflect.Value, fieldType string, canBeNull bool) error { func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bool) error {
kind := field.Kind() kind := field.Kind()
typ := field.Type() typ := field.Type()
@ -195,13 +196,13 @@ func (s *Seed) randomizeField(field reflect.Value, fieldType string, canBeNull b
if isNull { if isNull {
value = getStructNullValue(typ) value = getStructNullValue(typ)
} else { } else {
value = s.getStructRandValue(typ) value = getStructRandValue(s, typ)
} }
} else { } else {
if isNull { if isNull {
value = getVariableZeroValue(kind) value = getVariableZeroValue(kind)
} else { } else {
value = s.getVariableRandValue(kind, typ) value = getVariableRandValue(s, kind, typ)
} }
} }
@ -257,16 +258,16 @@ func getStructNullValue(typ reflect.Type) interface{} {
// getStructRandValue returns a "random" value for the matching type. // getStructRandValue returns a "random" value for the matching type.
// The randomness is really an incrementation of the global seed, // The randomness is really an incrementation of the global seed,
// this is done to avoid duplicate key violations. // this is done to avoid duplicate key violations.
func (s *Seed) getStructRandValue(typ reflect.Type) interface{} { func getStructRandValue(s *Seed, typ reflect.Type) interface{} {
switch typ { switch typ {
case typeTime: case typeTime:
return s.randDate() return randDate(s)
case typeNullBool: case typeNullBool:
return null.NewBool(s.nextInt()%2 == 0, true) return null.NewBool(s.nextInt()%2 == 0, true)
case typeNullString: case typeNullString:
return null.NewString(s.randStr(1), true) return null.NewString(randStr(s, 1), true)
case typeNullTime: case typeNullTime:
return null.NewTime(s.randDate(), true) return null.NewTime(randDate(s), true)
case typeNullFloat32: case typeNullFloat32:
return null.NewFloat32(float32(s.nextInt()%10)/10.0+float32(s.nextInt()%10), true) return null.NewFloat32(float32(s.nextInt()%10)/10.0+float32(s.nextInt()%10), true)
case typeNullFloat64: case typeNullFloat64:
@ -337,7 +338,7 @@ func getVariableZeroValue(kind reflect.Kind) interface{} {
// getVariableRandValue returns a "random" value for the matching kind. // getVariableRandValue returns a "random" value for the matching kind.
// The randomness is really an incrementation of the global seed, // The randomness is really an incrementation of the global seed,
// this is done to avoid duplicate key violations. // this is done to avoid duplicate key violations.
func (s *Seed) getVariableRandValue(kind reflect.Kind, typ reflect.Type) interface{} { func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interface{} {
switch kind { switch kind {
case reflect.Float32: case reflect.Float32:
return float32(float32(s.nextInt()%10)/10.0 + float32(s.nextInt()%10)) return float32(float32(s.nextInt()%10)/10.0 + float32(s.nextInt()%10))
@ -366,13 +367,13 @@ func (s *Seed) getVariableRandValue(kind reflect.Kind, typ reflect.Type) interfa
case reflect.Bool: case reflect.Bool:
return true return true
case reflect.String: case reflect.String:
return s.randStr(1) return randStr(s, 1)
case reflect.Slice: case reflect.Slice:
sliceVal := typ.Elem() sliceVal := typ.Elem()
if sliceVal.Kind() != reflect.Uint8 { if sliceVal.Kind() != reflect.Uint8 {
return errors.Errorf("unsupported slice type: %T, was expecting byte slice.", typ.String()) return errors.Errorf("unsupported slice type: %T, was expecting byte slice.", typ.String())
} }
return s.randByteSlice(5 + s.nextInt()%20) return randByteSlice(s, 5+s.nextInt()%20)
} }
return nil return nil
@ -380,7 +381,7 @@ func (s *Seed) getVariableRandValue(kind reflect.Kind, typ reflect.Type) interfa
const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func (s *Seed) randStr(ln int) string { func randStr(s *Seed, ln int) string {
str := make([]byte, ln) str := make([]byte, ln)
for i := 0; i < ln; i++ { for i := 0; i < ln; i++ {
str[i] = byte(alphabet[s.nextInt()%len(alphabet)]) str[i] = byte(alphabet[s.nextInt()%len(alphabet)])
@ -389,7 +390,7 @@ func (s *Seed) randStr(ln int) string {
return string(str) return string(str)
} }
func (s *Seed) randByteSlice(ln int) []byte { func randByteSlice(s *Seed, ln int) []byte {
str := make([]byte, ln) str := make([]byte, ln)
for i := 0; i < ln; i++ { for i := 0; i < ln; i++ {
str[i] = byte(s.nextInt() % 256) str[i] = byte(s.nextInt() % 256)

View file

@ -1,4 +1,4 @@
package boil package randomize
import ( import (
"reflect" "reflect"
@ -51,7 +51,7 @@ func TestRandomizeStruct(t *testing.T) {
"NullInterval": "interval", "NullInterval": "interval",
} }
err := s.RandomizeStruct(&testStruct, fieldTypes, true, "Ignore") err := Struct(s, &testStruct, fieldTypes, true, "Ignore")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -134,7 +134,7 @@ func TestRandomizeField(t *testing.T) {
// Make sure we never get back values that would be considered null // Make sure we never get back values that would be considered null
// by the boil whitelist generator, or by the database driver // by the boil whitelist generator, or by the database driver
if err := s.randomizeField(field, typ, false); err != nil { if err := randomizeField(s, field, typ, false); err != nil {
t.Errorf("%d) %s", i, err) t.Errorf("%d) %s", i, err)
} }