diff --git a/randomize/randomize.go b/randomize/randomize.go index 63a7237..0da7aed 100644 --- a/randomize/randomize.go +++ b/randomize/randomize.go @@ -4,6 +4,7 @@ package randomize import ( "database/sql" "fmt" + "math/rand" "reflect" "regexp" "sort" @@ -34,6 +35,7 @@ var ( typeNullUint32 = reflect.TypeOf(null.Uint32{}) typeNullUint64 = reflect.TypeOf(null.Uint64{}) typeNullString = reflect.TypeOf(null.String{}) + typeNullByte = reflect.TypeOf(null.Byte{}) typeNullBool = reflect.TypeOf(null.Bool{}) typeNullTime = reflect.TypeOf(null.Time{}) typeNullBytes = reflect.TypeOf(null.Bytes{}) @@ -341,7 +343,7 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo // only get zero values for non byte slices // to stop mysql from being a jerk if isNull && kind != reflect.Slice { - value = getVariableZeroValue(s, kind) + value = getVariableZeroValue(s, kind, typ) } else { value = getVariableRandValue(s, kind, typ) } @@ -457,6 +459,8 @@ func getStructNullValue(s *Seed, typ reflect.Type) interface{} { return null.NewUint64(0, false) case typeNullBytes: return null.NewBytes(nil, false) + case typeNullByte: + return null.NewByte(byte(0), false) } return nil @@ -501,13 +505,21 @@ func getStructRandValue(s *Seed, typ reflect.Type) interface{} { return null.NewUint64(uint64(s.nextInt()), true) case typeNullBytes: return null.NewBytes(randByteSlice(s, 1), true) + case typeNullByte: + return null.NewByte(byte(rand.Intn(125-65)+65), true) } return nil } // getVariableZeroValue for the matching type. -func getVariableZeroValue(s *Seed, kind reflect.Kind) interface{} { +func getVariableZeroValue(s *Seed, kind reflect.Kind, typ reflect.Type) interface{} { + switch typ.String() { + case "types.Byte": + // Decimal 65 is 'A'. 0 is not a valid UTF8, so cannot use a zero value here. + return types.Byte(65) + } + switch kind { case reflect.Float32: return float32(0) @@ -548,6 +560,11 @@ func getVariableZeroValue(s *Seed, kind reflect.Kind) interface{} { // The randomness is really an incrementation of the global seed, // this is done to avoid duplicate key violations. func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interface{} { + switch typ.String() { + case "types.Byte": + return types.Byte(rand.Intn(125-65) + 65) + } + switch kind { case reflect.Float32: return float32(float32(s.nextInt()%10)/10.0 + float32(s.nextInt()%10)) diff --git a/testdata/postgres_test_schema.sql b/testdata/postgres_test_schema.sql index ac2a24b..c046d29 100644 --- a/testdata/postgres_test_schema.sql +++ b/testdata/postgres_test_schema.sql @@ -31,14 +31,14 @@ CREATE TABLE magic ( nonbyte_four CHAR(1) NOT NULL DEFAULT 'b', nonbyte_five CHAR(1000), nonbyte_six CHAR(1000) NULL, - nonbyte_seven CHAR(1000) NOT NULL, + nonbyte_seven CHAR(1000) NOT NULL, nonbyte_eight CHAR(1000) NULL DEFAULT 'a', nonbyte_nine CHAR(1000) NOT NULL DEFAULT 'b', byte_zero "char", byte_one "char" NULL, - byte_two "char" NOT NULL, - byte_three "char" NULL DEFAULT 'a', + byte_two "char" NULL DEFAULT 'a', + byte_three "char" NOT NULL, byte_four "char" NOT NULL DEFAULT 'b', big_int_zero bigint,