diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index 17cf909..472e54c 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -274,7 +274,7 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "null.Float64" case "real": c.Type = "null.Float32" - case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": + case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": c.Type = "null.String" case "bytea": c.Type = "[]byte" diff --git a/boil/testing.go b/boil/testing.go index e597680..dbe1c4d 100644 --- a/boil/testing.go +++ b/boil/testing.go @@ -47,7 +47,7 @@ func (s *seed) nextInt() int { // RandomizeStruct takes an object and fills it with random data. // It will ignore the fields in the blacklist. -func RandomizeStruct(str interface{}, colTypes map[string]string, includeInvalid bool, blacklist ...string) error { +func RandomizeStruct(str interface{}, colTypes map[string]string, canBeNull bool, blacklist ...string) error { // Don't modify blacklist copyBlacklist := make([]string, len(blacklist)) copy(copyBlacklist, blacklist) @@ -90,7 +90,7 @@ func RandomizeStruct(str interface{}, colTypes map[string]string, includeInvalid } fieldDBType := colTypes[fieldTyp.Name] - if err := randomizeField(fieldVal, fieldDBType, includeInvalid); err != nil { + if err := randomizeField(fieldVal, fieldDBType, canBeNull); err != nil { return err } } @@ -116,170 +116,257 @@ func randDate(sd int) time.Time { return t } +// randomizeField changes the value at field to a "randomized" value. +// +// If canBeNull is false: +// The value will always be a non-null and non-zero value. + +// If canBeNull is true: +// The value has the possibility of being null or non-zero at random. func randomizeField(field reflect.Value, fieldType string, canBeNull bool) error { kind := field.Kind() typ := field.Type() - var newVal interface{} + var value interface{} + var isNull bool - if kind == reflect.Struct { - var notNull bool - if canBeNull { - notNull = rand.Intn(2) == 1 - } else { - notNull = false - } + // Validated columns always need to be set regardless of canBeNull, + // and they have to adhere to a strict value format. + validatedTypes := []string{"uuid", "interval"} + var foundValidated bool - switch typ { - case typeNullBool: - if notNull { - newVal = null.NewBool(sd.nextInt()%2 == 0, true) - } else { - newVal = null.NewBool(false, false) - } - case typeNullString: - if fieldType == "uuid" { - newVal = null.NewString(uuid.NewV4().String(), true) - } else if notNull { - switch fieldType { - case "interval": - newVal = null.NewString(strconv.Itoa((sd.nextInt()%26)+2)+" days", true) - default: - newVal = null.NewString(randStr(1, sd.nextInt()), true) - } - } else { - newVal = null.NewString("", false) - } - case typeNullTime: - if notNull { - newVal = null.NewTime(randDate(sd.nextInt()), true) - } else { - newVal = null.NewTime(time.Time{}, false) - } - case typeTime: - newVal = randDate(sd.nextInt()) - case typeNullFloat32: - if notNull { - newVal = null.NewFloat32(float32(sd.nextInt()%10)/10.0+float32(sd.nextInt()%10), true) - } else { - newVal = null.NewFloat32(0.0, false) - } - case typeNullFloat64: - if notNull { - newVal = null.NewFloat64(float64(sd.nextInt()%10)/10.0+float64(sd.nextInt()%10), true) - } else { - newVal = null.NewFloat64(0.0, false) - } - case typeNullInt: - if notNull { - newVal = null.NewInt(sd.nextInt(), true) - } else { - newVal = null.NewInt(0, false) - } - case typeNullInt8: - if notNull { - newVal = null.NewInt8(int8(sd.nextInt()), true) - } else { - newVal = null.NewInt8(0, false) - } - case typeNullInt16: - if notNull { - newVal = null.NewInt16(int16(sd.nextInt()), true) - } else { - newVal = null.NewInt16(0, false) - } - case typeNullInt32: - if notNull { - newVal = null.NewInt32(int32(sd.nextInt()), true) - } else { - newVal = null.NewInt32(0, false) - } - case typeNullInt64: - if notNull { - newVal = null.NewInt64(int64(sd.nextInt()), true) - } else { - newVal = null.NewInt64(0, false) - } - case typeNullUint: - if notNull { - newVal = null.NewUint(uint(sd.nextInt()), true) - } else { - newVal = null.NewUint(0, false) - } - case typeNullUint8: - if notNull { - newVal = null.NewUint8(uint8(sd.nextInt()), true) - } else { - newVal = null.NewUint8(0, false) - } - case typeNullUint16: - if notNull { - newVal = null.NewUint16(uint16(sd.nextInt()), true) - } else { - newVal = null.NewUint16(0, false) - } - case typeNullUint32: - if notNull { - newVal = null.NewUint32(uint32(sd.nextInt()), true) - } else { - newVal = null.NewUint32(0, false) - } - case typeNullUint64: - if notNull { - newVal = null.NewUint64(uint64(sd.nextInt()), true) - } else { - newVal = null.NewUint64(0, false) - } - } - } else { - switch kind { - case reflect.Float32: - newVal = float32(float32(sd.nextInt()%10)/10.0 + float32(sd.nextInt()%10)) - case reflect.Float64: - newVal = float64(float64(sd.nextInt()%10)/10.0 + float64(sd.nextInt()%10)) - case reflect.Int: - newVal = sd.nextInt() - case reflect.Int8: - newVal = int8(sd.nextInt()) - case reflect.Int16: - newVal = int16(sd.nextInt()) - case reflect.Int32: - newVal = int32(sd.nextInt()) - case reflect.Int64: - newVal = int64(sd.nextInt()) - case reflect.Uint: - newVal = uint(sd.nextInt()) - case reflect.Uint8: - newVal = uint8(sd.nextInt()) - case reflect.Uint16: - newVal = uint16(sd.nextInt()) - case reflect.Uint32: - newVal = uint32(sd.nextInt()) - case reflect.Uint64: - newVal = uint64(sd.nextInt()) - case reflect.Bool: - newVal = sd.nextInt()%2 == 0 - case reflect.String: - switch fieldType { - case "interval": - newVal = strconv.Itoa((sd.nextInt()%26)+2) + " days" - case "uuid": - newVal = uuid.NewV4().String() - default: - newVal = randStr(1, sd.nextInt()) - } - case reflect.Slice: - sliceVal := typ.Elem() - if sliceVal.Kind() != reflect.Uint8 { - return errors.Errorf("unsupported slice type: %T", typ.String()) - } - newVal = randByteSlice(5+rand.Intn(20), sd.nextInt()) - default: - return errors.Errorf("unsupported type: %T", typ.String()) + for _, validated := range validatedTypes { + if fieldType == validated { + foundValidated = true + break } } - field.Set(reflect.ValueOf(newVal)) + if foundValidated { + if kind == reflect.Struct { + switch typ { + case typeNullString: + if fieldType == "interval" { + value = null.NewString(strconv.Itoa((sd.nextInt()%26)+2)+" days", true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "uuid" { + value = null.NewString(uuid.NewV4().String(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + } + } else { + switch kind { + case reflect.String: + if fieldType == "interval" { + value = strconv.Itoa((sd.nextInt()%26)+2) + " days" + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "uuid" { + value = uuid.NewV4().String() + field.Set(reflect.ValueOf(value)) + return nil + } + } + } + } + + // Check the regular columns, these can be set or not set + // depending on the canBeNull flag. + if canBeNull { + // 1 in 3 chance of being null or zero value + isNull = rand.Intn(3) == 1 + } else { + // if canBeNull is false, then never return null values. + isNull = false + } + + // Retrieve the value to be returned + if kind == reflect.Struct { + if isNull { + value = getStructNullValue(typ) + } else { + value = getStructRandValue(typ) + } + } else { + if isNull { + value = getVariableNullValue(kind) + } else { + value = getVariableRandValue(kind, typ) + } + } + + if value == nil { + return errors.Errorf("unsupported type: %T", typ.String()) + } + + field.Set(reflect.ValueOf(value)) + return nil +} + +// getStructNullValue for the matching type. +func getStructNullValue(typ reflect.Type) interface{} { + switch typ { + case typeTime: + return time.Time{} + case typeNullBool: + return null.NewBool(false, false) + case typeNullString: + return null.NewString("", false) + case typeNullTime: + return null.NewTime(time.Time{}, false) + case typeNullFloat32: + return null.NewFloat32(0.0, false) + case typeNullFloat64: + return null.NewFloat64(0.0, false) + case typeNullInt: + return null.NewInt(0, false) + case typeNullInt8: + return null.NewInt8(0, false) + case typeNullInt16: + return null.NewInt16(0, false) + case typeNullInt32: + return null.NewInt32(0, false) + case typeNullInt64: + return null.NewInt64(0, false) + case typeNullUint: + return null.NewUint(0, false) + case typeNullUint8: + return null.NewUint8(0, false) + case typeNullUint16: + return null.NewUint16(0, false) + case typeNullUint32: + return null.NewUint32(0, false) + case typeNullUint64: + return null.NewUint64(0, false) + } + + return nil +} + +// getStructRandValue returns a "random" value for the matching type. +// The randomness is really an incrementation of the global seed, +// this is done to avoid duplicate key violations. +func getStructRandValue(typ reflect.Type) interface{} { + switch typ { + case typeTime: + return randDate(sd.nextInt()) + case typeNullBool: + return null.NewBool(sd.nextInt()%2 == 0, true) + case typeNullString: + return null.NewString(randStr(1, sd.nextInt()), true) + case typeNullTime: + return null.NewTime(randDate(sd.nextInt()), true) + case typeNullFloat32: + return null.NewFloat32(float32(sd.nextInt()%10)/10.0+float32(sd.nextInt()%10), true) + case typeNullFloat64: + return null.NewFloat64(float64(sd.nextInt()%10)/10.0+float64(sd.nextInt()%10), true) + case typeNullInt: + return null.NewInt(sd.nextInt(), true) + case typeNullInt8: + return null.NewInt8(int8(sd.nextInt()), true) + case typeNullInt16: + return null.NewInt16(int16(sd.nextInt()), true) + case typeNullInt32: + return null.NewInt32(int32(sd.nextInt()), true) + case typeNullInt64: + return null.NewInt64(int64(sd.nextInt()), true) + case typeNullUint: + return null.NewUint(uint(sd.nextInt()), true) + case typeNullUint8: + return null.NewUint8(uint8(sd.nextInt()), true) + case typeNullUint16: + return null.NewUint16(uint16(sd.nextInt()), true) + case typeNullUint32: + return null.NewUint32(uint32(sd.nextInt()), true) + case typeNullUint64: + return null.NewUint64(uint64(sd.nextInt()), true) + } + + return nil +} + +// getVariableNullValue for the matching type. +func getVariableNullValue(kind reflect.Kind) interface{} { + switch kind { + case reflect.Float32: + return float32(0) + case reflect.Float64: + return float64(0) + case reflect.Int: + return int(0) + case reflect.Int8: + return int8(0) + case reflect.Int16: + return int16(0) + case reflect.Int32: + return int32(0) + case reflect.Int64: + return int64(0) + case reflect.Uint: + return uint(0) + case reflect.Uint8: + return uint8(0) + case reflect.Uint16: + return uint16(0) + case reflect.Uint32: + return uint32(0) + case reflect.Uint64: + return uint64(0) + case reflect.Bool: + return false + case reflect.String: + return "" + case reflect.Slice: + return []byte(nil) + } + + return nil +} + +// getVariableRandValue returns a "random" value for the matching kind. +// The randomness is really an incrementation of the global seed, +// this is done to avoid duplicate key violations. +func getVariableRandValue(kind reflect.Kind, typ reflect.Type) interface{} { + switch kind { + case reflect.Float32: + return float32(float32(sd.nextInt()%10)/10.0 + float32(sd.nextInt()%10)) + case reflect.Float64: + return float64(float64(sd.nextInt()%10)/10.0 + float64(sd.nextInt()%10)) + case reflect.Int: + return sd.nextInt() + case reflect.Int8: + return int8(sd.nextInt()) + case reflect.Int16: + return int16(sd.nextInt()) + case reflect.Int32: + return int32(sd.nextInt()) + case reflect.Int64: + return int64(sd.nextInt()) + case reflect.Uint: + return uint(sd.nextInt()) + case reflect.Uint8: + return uint8(sd.nextInt()) + case reflect.Uint16: + return uint16(sd.nextInt()) + case reflect.Uint32: + return uint32(sd.nextInt()) + case reflect.Uint64: + return uint64(sd.nextInt()) + case reflect.Bool: + return true + case reflect.String: + return randStr(1, sd.nextInt()) + case reflect.Slice: + sliceVal := typ.Elem() + if sliceVal.Kind() != reflect.Uint8 { + return errors.Errorf("unsupported slice type: %T, was expecting byte slice.", typ.String()) + } + return randByteSlice(5+rand.Intn(20), sd.nextInt()) + } return nil } diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index f61e4c7..e354d00 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -21,8 +21,8 @@ func Test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { t.Fatal(err) } - boil.RandomizeStruct(&b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, true, "{{.ForeignColumn}}") - boil.RandomizeStruct(&c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, true, "{{.ForeignColumn}}") + boil.RandomizeStruct(&b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") + boil.RandomizeStruct(&c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") {{if .Nullable -}} a.{{.Column | titleCase}}.Valid = true {{- end}}