Add NullJSON and JSON types, fix randomize struct
This commit is contained in:
parent
ce8573eccd
commit
757cbde016
5 changed files with 224 additions and 8 deletions
|
@ -281,8 +281,10 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
|
|||
c.Type = "null.Float32"
|
||||
case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
|
||||
c.Type = "null.String"
|
||||
case "bytea", "json", "jsonb":
|
||||
c.Type = "[]byte"
|
||||
case "bytea":
|
||||
c.Type = "null.Bytes"
|
||||
case "json", "jsonb":
|
||||
c.Type = "null.JSON"
|
||||
case "boolean":
|
||||
c.Type = "null.Bool"
|
||||
case "date", "time", "timestamp without time zone", "timestamp with time zone":
|
||||
|
@ -305,7 +307,7 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
|
|||
case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
|
||||
c.Type = "string"
|
||||
case "json", "jsonb":
|
||||
c.Type = "json.RawMessage"
|
||||
c.Type = "types.JSON"
|
||||
case "bytea":
|
||||
c.Type = "[]byte"
|
||||
case "boolean":
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package randomize
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/satori/go.uuid"
|
||||
"github.com/vattle/sqlboiler/boil/types"
|
||||
"github.com/vattle/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
|
@ -32,11 +34,13 @@ var (
|
|||
typeNullString = reflect.TypeOf(null.String{})
|
||||
typeNullBool = reflect.TypeOf(null.Bool{})
|
||||
typeNullTime = reflect.TypeOf(null.Time{})
|
||||
typeNullBytes = reflect.TypeOf(null.Bytes{})
|
||||
typeNullJSON = reflect.TypeOf(null.JSON{})
|
||||
typeTime = reflect.TypeOf(time.Time{})
|
||||
typeJSON = reflect.TypeOf(types.JSON{})
|
||||
rgxValidTime = regexp.MustCompile(`[2-9]+`)
|
||||
|
||||
rgxValidTime = regexp.MustCompile(`[2-9]+`)
|
||||
|
||||
validatedTypes = []string{"uuid", "interval"}
|
||||
validatedTypes = []string{"uuid", "interval", "json", "jsonb"}
|
||||
)
|
||||
|
||||
// Seed is an atomic counter for pseudo-randomization structs. Using full
|
||||
|
@ -163,6 +167,10 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
|
|||
field.Set(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
case typeNullJSON:
|
||||
value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true)
|
||||
field.Set(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
switch kind {
|
||||
|
@ -178,6 +186,12 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
|
|||
return nil
|
||||
}
|
||||
}
|
||||
switch typ {
|
||||
case typeJSON:
|
||||
value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1)))
|
||||
field.Set(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -250,6 +264,8 @@ func getStructNullValue(typ reflect.Type) interface{} {
|
|||
return null.NewUint32(0, false)
|
||||
case typeNullUint64:
|
||||
return null.NewUint64(0, false)
|
||||
case typeNullBytes:
|
||||
return null.NewBytes(nil, false)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -292,6 +308,8 @@ func getStructRandValue(s *Seed, typ reflect.Type) interface{} {
|
|||
return null.NewUint32(uint32(s.nextInt()), true)
|
||||
case typeNullUint64:
|
||||
return null.NewUint64(uint64(s.nextInt()), true)
|
||||
case typeNullBytes:
|
||||
return null.NewBytes(randByteSlice(s, 16), true)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
77
boil/types/json.go
Normal file
77
boil/types/json.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// JSON is an alias for json.RawMessage, which is
|
||||
// a []byte underneath.
|
||||
// JSON implements Marshal and Unmarshal.
|
||||
type JSON json.RawMessage
|
||||
|
||||
// String output your JSON.
|
||||
func (j JSON) String() string {
|
||||
return string(j)
|
||||
}
|
||||
|
||||
// Unmarshal your JSON variable into dest.
|
||||
func (j JSON) Unmarshal(dest interface{}) error {
|
||||
return json.Unmarshal(j, dest)
|
||||
}
|
||||
|
||||
// Marshal obj into your JSON variable.
|
||||
func (j *JSON) Marshal(obj interface{}) error {
|
||||
res, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*j = res
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON sets *j to a copy of data.
|
||||
func (j *JSON) UnmarshalJSON(data []byte) error {
|
||||
if j == nil {
|
||||
return errors.New("JSON: UnmarshalJSON on nil pointer")
|
||||
}
|
||||
|
||||
*j = append((*j)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON returns j as the JSON encoding of j.
|
||||
func (j JSON) MarshalJSON() ([]byte, error) {
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// Value returns j as a value.
|
||||
// Unmarshal into RawMessage for validation.
|
||||
func (j JSON) Value() (driver.Value, error) {
|
||||
var r json.RawMessage
|
||||
if err := j.Unmarshal(&r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []byte(r), nil
|
||||
}
|
||||
|
||||
// Scan stores the src in *j.
|
||||
func (j *JSON) Scan(src interface{}) error {
|
||||
var source []byte
|
||||
|
||||
switch src.(type) {
|
||||
case string:
|
||||
source = []byte(src.(string))
|
||||
case []byte:
|
||||
source = src.([]byte)
|
||||
default:
|
||||
return errors.New("Incompatible type for JSON")
|
||||
}
|
||||
|
||||
*j = JSON(append((*j)[0:0], source...))
|
||||
|
||||
return nil
|
||||
}
|
119
boil/types/json_test.go
Normal file
119
boil/types/json_test.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJSONString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
j := JSON("hello")
|
||||
if j.String() != "hello" {
|
||||
t.Errorf("Expected %q, got %s", "hello", j.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONUnmarshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type JSONTest struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
var jt JSONTest
|
||||
|
||||
j := JSON(`{"Name":"hi","Age":15}`)
|
||||
err := j.Unmarshal(&jt)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if jt.Name != "hi" {
|
||||
t.Errorf("Expected %q, got %s", "hi", jt.Name)
|
||||
}
|
||||
if jt.Age != 15 {
|
||||
t.Errorf("Expected %v, got %v", 15, jt.Age)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONMarshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type JSONTest struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
jt := JSONTest{
|
||||
Name: "hi",
|
||||
Age: 15,
|
||||
}
|
||||
|
||||
var j JSON
|
||||
err := j.Marshal(jt)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if j.String() != `{"Name":"hi","Age":15}` {
|
||||
t.Errorf("expected %s, got %s", `{"Name":"hi","Age":15}`, j.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONUnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
j := JSON(nil)
|
||||
|
||||
err := j.UnmarshalJSON(JSON(`"hi"`))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if j.String() != `"hi"` {
|
||||
t.Errorf("Expected %q, got %s", "hi", j.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONMarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
j := JSON(`"hi"`)
|
||||
res, err := j.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(res, []byte(`"hi"`)) {
|
||||
t.Errorf("Expected %q, got %v", `"hi"`, res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
j := JSON(`{"Name":"hi","Age":15}`)
|
||||
v, err := j.Value()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(j, v.([]byte)) {
|
||||
t.Errorf("byte mismatch, %v %v", j, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONScan(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
j := JSON{}
|
||||
|
||||
err := j.Scan(`"hello"`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(j, []byte(`"hello"`)) {
|
||||
t.Errorf("bad []byte: %#v ≠ %#v\n", j, string([]byte(`"hello"`)))
|
||||
}
|
||||
}
|
|
@ -299,7 +299,7 @@ var importsBasedOnType = map[string]imports{
|
|||
"time.Time": {
|
||||
standard: importList{`"time"`},
|
||||
},
|
||||
"json.RawBytes": {
|
||||
standard: importList{`"encoding/json"`},
|
||||
"types.JSON": {
|
||||
thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`},
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue