Add NullJSON and JSON types, fix randomize struct

This commit is contained in:
Patrick O'brien 2016-09-08 19:07:33 +10:00
parent ce8573eccd
commit 757cbde016
5 changed files with 224 additions and 8 deletions

View file

@ -281,8 +281,10 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "null.Float32"
case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "null.String"
case "bytea", "json", "jsonb":
c.Type = "[]byte"
case "bytea":
c.Type = "null.Bytes"
case "json", "jsonb":
c.Type = "null.JSON"
case "boolean":
c.Type = "null.Bool"
case "date", "time", "timestamp without time zone", "timestamp with time zone":
@ -305,7 +307,7 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "string"
case "json", "jsonb":
c.Type = "json.RawMessage"
c.Type = "types.JSON"
case "bytea":
c.Type = "[]byte"
case "boolean":

View file

@ -2,6 +2,7 @@
package randomize
import (
"fmt"
"reflect"
"regexp"
"sort"
@ -13,6 +14,7 @@ import (
"github.com/pkg/errors"
"github.com/satori/go.uuid"
"github.com/vattle/sqlboiler/boil/types"
"github.com/vattle/sqlboiler/strmangle"
)
@ -32,11 +34,13 @@ var (
typeNullString = reflect.TypeOf(null.String{})
typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{})
typeNullBytes = reflect.TypeOf(null.Bytes{})
typeNullJSON = reflect.TypeOf(null.JSON{})
typeTime = reflect.TypeOf(time.Time{})
typeJSON = reflect.TypeOf(types.JSON{})
rgxValidTime = regexp.MustCompile(`[2-9]+`)
rgxValidTime = regexp.MustCompile(`[2-9]+`)
validatedTypes = []string{"uuid", "interval"}
validatedTypes = []string{"uuid", "interval", "json", "jsonb"}
)
// Seed is an atomic counter for pseudo-randomization structs. Using full
@ -163,6 +167,10 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
field.Set(reflect.ValueOf(value))
return nil
}
case typeNullJSON:
value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true)
field.Set(reflect.ValueOf(value))
return nil
}
} else {
switch kind {
@ -178,6 +186,12 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
return nil
}
}
switch typ {
case typeJSON:
value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1)))
field.Set(reflect.ValueOf(value))
return nil
}
}
}
@ -250,6 +264,8 @@ func getStructNullValue(typ reflect.Type) interface{} {
return null.NewUint32(0, false)
case typeNullUint64:
return null.NewUint64(0, false)
case typeNullBytes:
return null.NewBytes(nil, false)
}
return nil
@ -292,6 +308,8 @@ func getStructRandValue(s *Seed, typ reflect.Type) interface{} {
return null.NewUint32(uint32(s.nextInt()), true)
case typeNullUint64:
return null.NewUint64(uint64(s.nextInt()), true)
case typeNullBytes:
return null.NewBytes(randByteSlice(s, 16), true)
}
return nil

77
boil/types/json.go Normal file
View file

@ -0,0 +1,77 @@
package types
import (
"database/sql/driver"
"encoding/json"
"errors"
)
// JSON is an alias for json.RawMessage, which is
// a []byte underneath.
// JSON implements Marshal and Unmarshal.
type JSON json.RawMessage
// String output your JSON.
func (j JSON) String() string {
return string(j)
}
// Unmarshal your JSON variable into dest.
func (j JSON) Unmarshal(dest interface{}) error {
return json.Unmarshal(j, dest)
}
// Marshal obj into your JSON variable.
func (j *JSON) Marshal(obj interface{}) error {
res, err := json.Marshal(obj)
if err != nil {
return err
}
*j = res
return nil
}
// UnmarshalJSON sets *j to a copy of data.
func (j *JSON) UnmarshalJSON(data []byte) error {
if j == nil {
return errors.New("JSON: UnmarshalJSON on nil pointer")
}
*j = append((*j)[0:0], data...)
return nil
}
// MarshalJSON returns j as the JSON encoding of j.
func (j JSON) MarshalJSON() ([]byte, error) {
return j, nil
}
// Value returns j as a value.
// Unmarshal into RawMessage for validation.
func (j JSON) Value() (driver.Value, error) {
var r json.RawMessage
if err := j.Unmarshal(&r); err != nil {
return nil, err
}
return []byte(r), nil
}
// Scan stores the src in *j.
func (j *JSON) Scan(src interface{}) error {
var source []byte
switch src.(type) {
case string:
source = []byte(src.(string))
case []byte:
source = src.([]byte)
default:
return errors.New("Incompatible type for JSON")
}
*j = JSON(append((*j)[0:0], source...))
return nil
}

119
boil/types/json_test.go Normal file
View file

@ -0,0 +1,119 @@
package types
import (
"bytes"
"testing"
)
func TestJSONString(t *testing.T) {
t.Parallel()
j := JSON("hello")
if j.String() != "hello" {
t.Errorf("Expected %q, got %s", "hello", j.String())
}
}
func TestJSONUnmarshal(t *testing.T) {
t.Parallel()
type JSONTest struct {
Name string
Age int
}
var jt JSONTest
j := JSON(`{"Name":"hi","Age":15}`)
err := j.Unmarshal(&jt)
if err != nil {
t.Error(err)
}
if jt.Name != "hi" {
t.Errorf("Expected %q, got %s", "hi", jt.Name)
}
if jt.Age != 15 {
t.Errorf("Expected %v, got %v", 15, jt.Age)
}
}
func TestJSONMarshal(t *testing.T) {
t.Parallel()
type JSONTest struct {
Name string
Age int
}
jt := JSONTest{
Name: "hi",
Age: 15,
}
var j JSON
err := j.Marshal(jt)
if err != nil {
t.Error(err)
}
if j.String() != `{"Name":"hi","Age":15}` {
t.Errorf("expected %s, got %s", `{"Name":"hi","Age":15}`, j.String())
}
}
func TestJSONUnmarshalJSON(t *testing.T) {
t.Parallel()
j := JSON(nil)
err := j.UnmarshalJSON(JSON(`"hi"`))
if err != nil {
t.Error(err)
}
if j.String() != `"hi"` {
t.Errorf("Expected %q, got %s", "hi", j.String())
}
}
func TestJSONMarshalJSON(t *testing.T) {
t.Parallel()
j := JSON(`"hi"`)
res, err := j.MarshalJSON()
if err != nil {
t.Error(err)
}
if !bytes.Equal(res, []byte(`"hi"`)) {
t.Errorf("Expected %q, got %v", `"hi"`, res)
}
}
func TestJSONValue(t *testing.T) {
t.Parallel()
j := JSON(`{"Name":"hi","Age":15}`)
v, err := j.Value()
if err != nil {
t.Error(err)
}
if !bytes.Equal(j, v.([]byte)) {
t.Errorf("byte mismatch, %v %v", j, v)
}
}
func TestJSONScan(t *testing.T) {
t.Parallel()
j := JSON{}
err := j.Scan(`"hello"`)
if err != nil {
t.Error(err)
}
if !bytes.Equal(j, []byte(`"hello"`)) {
t.Errorf("bad []byte: %#v ≠ %#v\n", j, string([]byte(`"hello"`)))
}
}

View file

@ -299,7 +299,7 @@ var importsBasedOnType = map[string]imports{
"time.Time": {
standard: importList{`"time"`},
},
"json.RawBytes": {
standard: importList{`"encoding/json"`},
"types.JSON": {
thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`},
},
}