Refactor randomize
This commit is contained in:
parent
f23092d4a5
commit
f86d8d7ee9
2 changed files with 23 additions and 22 deletions
|
@ -1,4 +1,5 @@
|
|||
package boil
|
||||
// Package randomize has helpers for randomization of structs and fields
|
||||
package randomize
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
@ -54,9 +55,10 @@ func (s *Seed) nextInt() int {
|
|||
return int(atomic.AddInt64((*int64)(s), 1))
|
||||
}
|
||||
|
||||
// RandomizeStruct takes an object and fills it with random data.
|
||||
// Struct gets its fields filled with random data based on the seed.
|
||||
// It will ignore the fields in the blacklist.
|
||||
func (s *Seed) RandomizeStruct(str interface{}, colTypes map[string]string, canBeNull bool, blacklist ...string) error {
|
||||
// It will ignore fields that have the struct tag boil:"-"
|
||||
func Struct(s *Seed, str interface{}, colTypes map[string]string, canBeNull bool, blacklist ...string) error {
|
||||
// Don't modify blacklist
|
||||
copyBlacklist := make([]string, len(blacklist))
|
||||
copy(copyBlacklist, blacklist)
|
||||
|
@ -98,13 +100,12 @@ func (s *Seed) RandomizeStruct(str interface{}, colTypes map[string]string, canB
|
|||
continue
|
||||
}
|
||||
|
||||
tagVal, _ := getBoilTag(fieldTyp)
|
||||
if tagVal == "-" {
|
||||
if fieldTyp.Tag.Get("boil") == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldDBType := colTypes[fieldTyp.Name]
|
||||
if err := s.randomizeField(fieldVal, fieldDBType, canBeNull); err != nil {
|
||||
if err := randomizeField(s, fieldVal, fieldDBType, canBeNull); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -115,7 +116,7 @@ func (s *Seed) RandomizeStruct(str interface{}, colTypes map[string]string, canB
|
|||
// randDate generates a random time.Time between 1850 and 2050.
|
||||
// Only the Day/Month/Year columns are set so that Dates and DateTimes do
|
||||
// not cause mismatches in the test data comparisons.
|
||||
func (s *Seed) randDate() time.Time {
|
||||
func randDate(s *Seed) time.Time {
|
||||
t := time.Date(
|
||||
1850+s.nextInt()%160,
|
||||
time.Month(1+(s.nextInt()%12)),
|
||||
|
@ -137,7 +138,7 @@ func (s *Seed) randDate() time.Time {
|
|||
|
||||
// If canBeNull is true:
|
||||
// The value has the possibility of being null or non-zero at random.
|
||||
func (s *Seed) randomizeField(field reflect.Value, fieldType string, canBeNull bool) error {
|
||||
func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bool) error {
|
||||
kind := field.Kind()
|
||||
typ := field.Type()
|
||||
|
||||
|
@ -195,13 +196,13 @@ func (s *Seed) randomizeField(field reflect.Value, fieldType string, canBeNull b
|
|||
if isNull {
|
||||
value = getStructNullValue(typ)
|
||||
} else {
|
||||
value = s.getStructRandValue(typ)
|
||||
value = getStructRandValue(s, typ)
|
||||
}
|
||||
} else {
|
||||
if isNull {
|
||||
value = getVariableZeroValue(kind)
|
||||
} else {
|
||||
value = s.getVariableRandValue(kind, typ)
|
||||
value = getVariableRandValue(s, kind, typ)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -257,16 +258,16 @@ func getStructNullValue(typ reflect.Type) interface{} {
|
|||
// getStructRandValue returns a "random" value for the matching type.
|
||||
// The randomness is really an incrementation of the global seed,
|
||||
// this is done to avoid duplicate key violations.
|
||||
func (s *Seed) getStructRandValue(typ reflect.Type) interface{} {
|
||||
func getStructRandValue(s *Seed, typ reflect.Type) interface{} {
|
||||
switch typ {
|
||||
case typeTime:
|
||||
return s.randDate()
|
||||
return randDate(s)
|
||||
case typeNullBool:
|
||||
return null.NewBool(s.nextInt()%2 == 0, true)
|
||||
case typeNullString:
|
||||
return null.NewString(s.randStr(1), true)
|
||||
return null.NewString(randStr(s, 1), true)
|
||||
case typeNullTime:
|
||||
return null.NewTime(s.randDate(), true)
|
||||
return null.NewTime(randDate(s), true)
|
||||
case typeNullFloat32:
|
||||
return null.NewFloat32(float32(s.nextInt()%10)/10.0+float32(s.nextInt()%10), true)
|
||||
case typeNullFloat64:
|
||||
|
@ -337,7 +338,7 @@ func getVariableZeroValue(kind reflect.Kind) interface{} {
|
|||
// getVariableRandValue returns a "random" value for the matching kind.
|
||||
// The randomness is really an incrementation of the global seed,
|
||||
// this is done to avoid duplicate key violations.
|
||||
func (s *Seed) getVariableRandValue(kind reflect.Kind, typ reflect.Type) interface{} {
|
||||
func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interface{} {
|
||||
switch kind {
|
||||
case reflect.Float32:
|
||||
return float32(float32(s.nextInt()%10)/10.0 + float32(s.nextInt()%10))
|
||||
|
@ -366,13 +367,13 @@ func (s *Seed) getVariableRandValue(kind reflect.Kind, typ reflect.Type) interfa
|
|||
case reflect.Bool:
|
||||
return true
|
||||
case reflect.String:
|
||||
return s.randStr(1)
|
||||
return randStr(s, 1)
|
||||
case reflect.Slice:
|
||||
sliceVal := typ.Elem()
|
||||
if sliceVal.Kind() != reflect.Uint8 {
|
||||
return errors.Errorf("unsupported slice type: %T, was expecting byte slice.", typ.String())
|
||||
}
|
||||
return s.randByteSlice(5 + s.nextInt()%20)
|
||||
return randByteSlice(s, 5+s.nextInt()%20)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -380,7 +381,7 @@ func (s *Seed) getVariableRandValue(kind reflect.Kind, typ reflect.Type) interfa
|
|||
|
||||
const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
func (s *Seed) randStr(ln int) string {
|
||||
func randStr(s *Seed, ln int) string {
|
||||
str := make([]byte, ln)
|
||||
for i := 0; i < ln; i++ {
|
||||
str[i] = byte(alphabet[s.nextInt()%len(alphabet)])
|
||||
|
@ -389,7 +390,7 @@ func (s *Seed) randStr(ln int) string {
|
|||
return string(str)
|
||||
}
|
||||
|
||||
func (s *Seed) randByteSlice(ln int) []byte {
|
||||
func randByteSlice(s *Seed, ln int) []byte {
|
||||
str := make([]byte, ln)
|
||||
for i := 0; i < ln; i++ {
|
||||
str[i] = byte(s.nextInt() % 256)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package boil
|
||||
package randomize
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
@ -51,7 +51,7 @@ func TestRandomizeStruct(t *testing.T) {
|
|||
"NullInterval": "interval",
|
||||
}
|
||||
|
||||
err := s.RandomizeStruct(&testStruct, fieldTypes, true, "Ignore")
|
||||
err := Struct(s, &testStruct, fieldTypes, true, "Ignore")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -134,7 +134,7 @@ func TestRandomizeField(t *testing.T) {
|
|||
|
||||
// Make sure we never get back values that would be considered null
|
||||
// by the boil whitelist generator, or by the database driver
|
||||
if err := s.randomizeField(field, typ, false); err != nil {
|
||||
if err := randomizeField(s, field, typ, false); err != nil {
|
||||
t.Errorf("%d) %s", i, err)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue