sqlboiler/boil/randomize/randomize.go
2016-09-09 00:35:43 +10:00

569 lines
14 KiB
Go

// Package randomize has helpers for filling structs and their fields with pseudo-random data.
package randomize
import (
"fmt"
"math/rand"
"reflect"
"regexp"
"sort"
"strconv"
"sync/atomic"
"time"
"gopkg.in/nullbio/null.v5"
"github.com/pkg/errors"
"github.com/satori/go.uuid"
"github.com/vattle/sqlboiler/boil/types"
"github.com/vattle/sqlboiler/strmangle"
)
// Reflected types of the struct and named types this package knows how to
// populate. randomizeField, getStructNullValue and getStructRandValue compare
// a field's reflect.Type against these to pick a generator.
var (
	// Nullable floating point wrappers.
	typeNullFloat32 = reflect.TypeOf(null.Float32{})
	typeNullFloat64 = reflect.TypeOf(null.Float64{})

	// Nullable signed integer wrappers.
	typeNullInt   = reflect.TypeOf(null.Int{})
	typeNullInt8  = reflect.TypeOf(null.Int8{})
	typeNullInt16 = reflect.TypeOf(null.Int16{})
	typeNullInt32 = reflect.TypeOf(null.Int32{})
	typeNullInt64 = reflect.TypeOf(null.Int64{})

	// Nullable unsigned integer wrappers.
	typeNullUint   = reflect.TypeOf(null.Uint{})
	typeNullUint8  = reflect.TypeOf(null.Uint8{})
	typeNullUint16 = reflect.TypeOf(null.Uint16{})
	typeNullUint32 = reflect.TypeOf(null.Uint32{})
	typeNullUint64 = reflect.TypeOf(null.Uint64{})

	// Remaining nullable wrappers.
	typeNullString = reflect.TypeOf(null.String{})
	typeNullBool   = reflect.TypeOf(null.Bool{})
	typeNullTime   = reflect.TypeOf(null.Time{})
	typeNullBytes  = reflect.TypeOf(null.Bytes{})
	typeNullJSON   = reflect.TypeOf(null.JSON{})

	// Non-nullable types that still need special handling.
	typeTime = reflect.TypeOf(time.Time{})
	typeJSON = reflect.TypeOf(types.JSON{})
)

var (
	// NOTE(review): rgxValidTime is not referenced anywhere in this file; it
	// may be used elsewhere in the package (e.g. tests) — confirm before
	// removing it.
	rgxValidTime = regexp.MustCompile(`[2-9]+`)

	// validatedTypes lists the database column types whose values must follow
	// a strict textual format. randomizeField always populates fields of these
	// column types, regardless of canBeNull.
	validatedTypes = []string{
		"inet", "line", "uuid", "interval",
		"json", "jsonb", "box", "cidr", "circle",
		"lseg", "macaddr", "path", "pg_lsn", "point",
		"polygon", "txid_snapshot", "money",
	}
)
// Seed is an atomic counter used for pseudo-randomization of structs. Fully
// random values collide too easily in a domain where uniqueness is an
// important factor, so generated values are derived from a counter that only
// ever increases instead.
type Seed int64

// NewSeed creates a new seed whose counter starts at the current Unix time
// (in seconds).
func NewSeed() *Seed {
	seed := Seed(time.Now().Unix())
	return &seed
}

// nextInt atomically increments the counter and returns the new value,
// converted to int (so it is truncated where int is 32 bits wide). Because
// the increment is atomic, concurrent callers never observe the same value.
func (s *Seed) nextInt() int {
	next := atomic.AddInt64((*int64)(s), 1)
	return int(next)
}
// Struct fills the fields of the struct pointed to by str with pseudo-random
// data derived from the seed s.
//
// colTypes maps Go field names to database column types; the column type is
// passed to randomizeField so that strictly formatted column types get valid
// values. canBeNull allows fields to be randomly left null/zero.
//
// A field is left untouched when:
//   - its name equals strmangle.TitleCase of any blacklist entry,
//   - it carries the struct tag boil:"-", or
//   - it is unexported (reflect cannot set it).
//
// The caller's blacklist slice is never modified. An error is returned if str
// is not a pointer to a struct, or if any field has an unsupported type.
func Struct(s *Seed, str interface{}, colTypes map[string]string, canBeNull bool, blacklist ...string) error {
	// Check if it's a pointer.
	value := reflect.ValueOf(str)
	if value.Kind() != reflect.Ptr {
		return errors.Errorf("Outer element should be a pointer, given a non-pointer: %T", str)
	}

	// Check if it points at a struct.
	value = value.Elem()
	if value.Kind() != reflect.Struct {
		return errors.Errorf("Inner element should be a struct, given a non-struct: %T", str)
	}

	// Title-case the blacklist once (rather than once per field) into a fresh
	// slice, and sort it so field names can be found by binary search.
	skip := make([]string, len(blacklist))
	for i, name := range blacklist {
		skip[i] = strmangle.TitleCase(name)
	}
	sort.Strings(skip)

	typ := value.Type()
	for i := 0; i < value.NumField(); i++ {
		fieldTyp := typ.Field(i)

		if j := sort.SearchStrings(skip, fieldTyp.Name); j < len(skip) && skip[j] == fieldTyp.Name {
			continue
		}
		if fieldTyp.Tag.Get("boil") == "-" {
			continue
		}

		fieldVal := value.Field(i)
		// Unexported fields cannot be set through reflection; setting them
		// would panic, so skip them.
		if !fieldVal.CanSet() {
			continue
		}

		if err := randomizeField(s, fieldVal, colTypes[fieldTyp.Name], canBeNull); err != nil {
			return err
		}
	}

	return nil
}
// randDate generates a pseudo-random time.Time at midnight UTC, drawing three
// consecutive seed values for the year, month and day.
//
// The year falls in [1850, 2009], the month in [1, 12] and the day in [1, 25].
// Capping the day at 25 keeps every generated date valid for every month, so
// time.Date never rolls over into the following month. These ranges assume
// the seed counter is non-negative (as it is when created by NewSeed).
//
// Only the Day/Month/Year fields are set so that Dates and DateTimes do not
// cause mismatches in the test data comparisons.
func randDate(s *Seed) time.Time {
	year := 1850 + s.nextInt()%160
	month := time.Month(1 + s.nextInt()%12)
	day := 1 + s.nextInt()%25
	return time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
}
// randomizeField changes the value at field to a "randomized" value.
//
// fieldType is the database column type of the field. For validated column
// types (see validatedTypes) the field is always set, regardless of
// canBeNull, to a value in the strict format that type requires. This applies
// when the field is a string, null.String, types.JSON or null.JSON.
//
// Otherwise:
//   - If canBeNull is false, the value will always be non-null (and, for most
//     types, non-zero).
//   - If canBeNull is true, there is a 1 in 3 chance that the value is
//     null/zero instead.
//
// Named types whose kind matches a supported kind (e.g. `type Status int`)
// are populated by converting the generated value. An error is returned for
// unsupported field types instead of letting reflect panic.
func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bool) error {
	kind := field.Kind()
	typ := field.Type()

	// Validated columns always need to be set regardless of canBeNull,
	// and they have to adhere to a strict value format.
	if strmangle.SetInclude(fieldType, validatedTypes) {
		switch {
		case typ == typeNullString:
			if str, ok := randValidatedString(s, fieldType); ok {
				field.Set(reflect.ValueOf(null.NewString(str, true)))
				return nil
			}
		case typ == typeNullJSON:
			field.Set(reflect.ValueOf(null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true)))
			return nil
		case kind == reflect.String:
			if str, ok := randValidatedString(s, fieldType); ok {
				// SetString also works for named string types, where Set
				// with a plain string value would panic.
				field.SetString(str)
				return nil
			}
		case typ == typeJSON:
			field.SetBytes([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))))
			return nil
		}
	}

	// Regular columns can be set or not set depending on the canBeNull flag:
	// when allowed, there is a 1 in 3 chance of being null or the zero value.
	isNull := canBeNull && s.nextInt()%3 == 0

	var value interface{}
	switch {
	case kind == reflect.Struct && isNull:
		value = getStructNullValue(typ)
	case kind == reflect.Struct:
		value = getStructRandValue(s, typ)
	case isNull:
		value = getVariableZeroValue(kind)
	default:
		value = getVariableRandValue(s, kind, typ)
	}
	if value == nil {
		return errors.Errorf("unsupported type: %s", typ.String())
	}

	// reflect.Value.Set panics if the value cannot be stored in the field, so
	// check first. This covers a generated value of the right kind but the
	// wrong type (e.g. a []byte zero value for a []int field), and an error
	// value produced in place of a result.
	rv := reflect.ValueOf(value)
	switch {
	case rv.Type().AssignableTo(typ):
	case rv.Kind() == kind && rv.Type().ConvertibleTo(typ):
		// Named types such as `type Status int` share the kind of the
		// generated value and can simply be converted.
		rv = rv.Convert(typ)
	default:
		if err, ok := value.(error); ok {
			return err
		}
		return errors.Errorf("unsupported type: %s", typ.String())
	}

	field.Set(rv)
	return nil
}

// randValidatedString returns a pseudo-random value in the textual format
// required by the validated column type fieldType. ok is false for any other
// fieldType (among validatedTypes, that is json and jsonb). In that case
// nothing is generated and no seed value is consumed.
func randValidatedString(s *Seed, fieldType string) (str string, ok bool) {
	switch fieldType {
	case "interval":
		return strconv.Itoa((s.nextInt()%26)+2) + " days", true
	case "uuid":
		return uuid.NewV4().String(), true
	case "box", "line", "lseg", "path", "polygon":
		return randBox(), true
	case "cidr", "inet":
		return randNetAddr(), true
	case "macaddr":
		return randMacAddr(), true
	case "circle":
		return randCircle(), true
	case "pg_lsn":
		return randLsn(), true
	case "point":
		return randPoint(), true
	case "txid_snapshot":
		return randTxID(), true
	case "money":
		return randMoney(s), true
	}
	return "", false
}
// getStructNullValue returns the null value for the supported struct type typ
// (the zero time for time.Time), or nil when typ is not one of the supported
// types.
func getStructNullValue(typ reflect.Type) interface{} {
	switch typ {
	case typeTime,
		typeNullBool, typeNullString, typeNullTime, typeNullBytes,
		typeNullFloat32, typeNullFloat64,
		typeNullInt, typeNullInt8, typeNullInt16, typeNullInt32, typeNullInt64,
		typeNullUint, typeNullUint8, typeNullUint16, typeNullUint32, typeNullUint64:
		// For each of these types the zero value is the null value: the
		// payload is zeroed and Valid is false, exactly what
		// null.NewXxx(<zero>, false) builds.
		return reflect.Zero(typ).Interface()
	default:
		return nil
	}
}
// getStructRandValue returns a valid, "random" value for the supported struct
// type typ, or nil when typ is not supported. The randomness is really an
// incrementation of the seed; this is done to avoid duplicate key violations.
// Integer types receive the next seed value truncated to their width.
func getStructRandValue(s *Seed, typ reflect.Type) interface{} {
	switch typ {
	case typeTime:
		return randDate(s)
	case typeNullTime:
		return null.NewTime(randDate(s), true)
	case typeNullBool:
		return null.NewBool(s.nextInt()%2 == 0, true)
	case typeNullString:
		return null.NewString(randStr(s, 1), true)
	case typeNullBytes:
		return null.NewBytes(randByteSlice(s, 16), true)
	case typeNullFloat32:
		tenths := s.nextInt() % 10
		units := s.nextInt() % 10
		return null.NewFloat32(float32(tenths)/10.0+float32(units), true)
	case typeNullFloat64:
		tenths := s.nextInt() % 10
		units := s.nextInt() % 10
		return null.NewFloat64(float64(tenths)/10.0+float64(units), true)
	case typeNullInt, typeNullInt8, typeNullInt16, typeNullInt32, typeNullInt64,
		typeNullUint, typeNullUint8, typeNullUint16, typeNullUint32, typeNullUint64:
		n := s.nextInt()
		switch typ {
		case typeNullInt:
			return null.NewInt(n, true)
		case typeNullInt8:
			return null.NewInt8(int8(n), true)
		case typeNullInt16:
			return null.NewInt16(int16(n), true)
		case typeNullInt32:
			return null.NewInt32(int32(n), true)
		case typeNullInt64:
			return null.NewInt64(int64(n), true)
		case typeNullUint:
			return null.NewUint(uint(n), true)
		case typeNullUint8:
			return null.NewUint8(uint8(n), true)
		case typeNullUint16:
			return null.NewUint16(uint16(n), true)
		case typeNullUint32:
			return null.NewUint32(uint32(n), true)
		default: // typeNullUint64
			return null.NewUint64(uint64(n), true)
		}
	}
	return nil
}
// getVariableZeroValue returns the zero value of the basic Go type that
// corresponds to kind, or nil for unsupported kinds. Every slice kind maps to
// a nil []byte, regardless of the slice's element type.
func getVariableZeroValue(kind reflect.Kind) interface{} {
	zeros := map[reflect.Kind]interface{}{
		reflect.Bool:    false,
		reflect.String:  "",
		reflect.Slice:   []byte(nil),
		reflect.Float32: float32(0),
		reflect.Float64: float64(0),
		reflect.Int:     int(0),
		reflect.Int8:    int8(0),
		reflect.Int16:   int16(0),
		reflect.Int32:   int32(0),
		reflect.Int64:   int64(0),
		reflect.Uint:    uint(0),
		reflect.Uint8:   uint8(0),
		reflect.Uint16:  uint16(0),
		reflect.Uint32:  uint32(0),
		reflect.Uint64:  uint64(0),
	}
	// A missing key yields a nil interface, signalling "unsupported".
	return zeros[kind]
}
// getVariableRandValue returns a "random" value of the basic Go type that
// corresponds to kind, or nil if kind is not supported. The randomness is
// really an incrementation of the seed; this is done to avoid duplicate key
// violations. Integer kinds receive the next seed value truncated to their
// width; bools are always true.
//
// typ is consulted only for slices. []byte is the only supported slice type.
// For any other element type nil is returned so the caller reports the field
// as unsupported. Returning an error value here would be treated as the field
// value itself.
func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interface{} {
	switch kind {
	case reflect.Float32:
		return float32(float32(s.nextInt()%10)/10.0 + float32(s.nextInt()%10))
	case reflect.Float64:
		return float64(float64(s.nextInt()%10)/10.0 + float64(s.nextInt()%10))
	case reflect.Int:
		return s.nextInt()
	case reflect.Int8:
		return int8(s.nextInt())
	case reflect.Int16:
		return int16(s.nextInt())
	case reflect.Int32:
		return int32(s.nextInt())
	case reflect.Int64:
		return int64(s.nextInt())
	case reflect.Uint:
		return uint(s.nextInt())
	case reflect.Uint8:
		return uint8(s.nextInt())
	case reflect.Uint16:
		return uint16(s.nextInt())
	case reflect.Uint32:
		return uint32(s.nextInt())
	case reflect.Uint64:
		return uint64(s.nextInt())
	case reflect.Bool:
		return true
	case reflect.String:
		return randStr(s, 1)
	case reflect.Slice:
		if typ.Elem().Kind() != reflect.Uint8 {
			return nil
		}
		// Between 5 and 24 bytes long.
		return randByteSlice(s, 5+s.nextInt()%20)
	}
	return nil
}
// alphabet holds the characters randStr draws from.
const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

// randStr builds a string of ln characters. Each character is picked from
// alphabet by the next seed value, modulo the alphabet length.
func randStr(s *Seed, ln int) string {
	chars := make([]byte, 0, ln)
	for len(chars) < ln {
		chars = append(chars, alphabet[s.nextInt()%len(alphabet)])
	}
	return string(chars)
}
// randByteSlice returns ln bytes. Each byte is the next seed value modulo 256.
func randByteSlice(s *Seed, ln int) []byte {
	out := make([]byte, ln)
	for idx := range out {
		out[idx] = byte(s.nextInt() % 256)
	}
	return out
}
// randPoint returns a point literal "(x,y)" with x in [0, 100) and y = x+1,
// drawn from the global math/rand source.
func randPoint() string {
	x := rand.Intn(100)
	return fmt.Sprintf("(%d,%d)", x, x+1)
}
// randBox returns two distinct points "(a,a+1),(a+2,a+3)" with a in [0, 100),
// drawn from the global math/rand source. The same text is used for box,
// line, lseg, path and polygon columns.
func randBox() string {
	base := rand.Intn(100)
	return fmt.Sprintf("(%d,%d),(%d,%d)", base, base+1, base+2, base+3)
}
// randCircle returns a circle literal "((x,y),r)". The center coordinates and
// the radius are each in [0, 100) and drawn, in that order, from the global
// math/rand source.
func randCircle() string {
	x := rand.Intn(100)
	y := rand.Intn(100)
	radius := rand.Intn(100)
	return fmt.Sprintf("((%d,%d),%d)", x, y, radius)
}
// randNetAddr returns a dotted-quad IPv4 address whose four octets are each
// in [1, 254], drawn in order from the global math/rand source.
func randNetAddr() string {
	var octets [4]int
	for i := range octets {
		octets[i] = rand.Intn(254) + 1
	}
	return fmt.Sprintf("%d.%d.%d.%d", octets[0], octets[1], octets[2], octets[3])
}
// randMacAddr returns six random octets from the global math/rand source,
// formatted as lowercase colon-separated hex (xx:xx:xx:xx:xx:xx). The locally
// administered bit of the first octet is always set. It panics only if
// reading random bytes fails.
func randMacAddr() string {
	var octets [6]byte
	// math/rand's Read is documented to always return a nil error; the check
	// is kept as a safeguard.
	if _, err := rand.Read(octets[:]); err != nil {
		panic(err)
	}

	// Set the local bit.
	octets[0] |= 0x02

	return fmt.Sprintf(
		"%02x:%02x:%02x:%02x:%02x:%02x",
		octets[0], octets[1], octets[2], octets[3], octets[4], octets[5],
	)
}
// randLsn returns a log sequence number literal "hi/lo", where both halves are
// in [0, 9000000) and drawn in order from the global math/rand source. The
// decimal digits produced are also valid hex digits.
func randLsn() string {
	hi, lo := rand.Int63n(9000000), rand.Int63n(9000000)
	return fmt.Sprintf("%d/%d", hi, lo)
}
// randTxID returns a txid_snapshot literal "xmin:xmax:xip1,xip2" with xmin in
// [100, 300), xmax = xmin+100 and in-progress ids xmin and xmin+50. The order
// of integers is relevant: the in-progress ids must lie in [xmin, xmax).
func randTxID() string {
	xmin := rand.Intn(200) + 100
	xmax := xmin + 100
	return fmt.Sprintf("%d:%d:%d,%d", xmin, xmax, xmin, xmin+50)
}
// randMoney returns the next seed value as a money amount with zero cents,
// e.g. "1473379543.00".
func randMoney(s *Seed) string {
	whole := strconv.Itoa(s.nextInt())
	return whole + ".00"
}