diff --git a/bool.go b/bool.go index 0ad2ea4..be147c8 100644 --- a/bool.go +++ b/bool.go @@ -74,7 +74,7 @@ func (b *Bool) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (b Bool) MarshalJSON() ([]byte, error) { if !b.Valid { - return []byte("null"), nil + return NullBytes, nil } if !b.Bool { return []byte("false"), nil diff --git a/byte.go b/byte.go new file mode 100644 index 0000000..4967c09 --- /dev/null +++ b/byte.go @@ -0,0 +1,128 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "errors" + + "gopkg.in/nullbio/null.v5/convert" +) + +// Byte is an nullable int. +type Byte struct { + Byte byte + Valid bool +} + +// NewByte creates a new Byte +func NewByte(b byte, valid bool) Byte { + return Byte{ + Byte: b, + Valid: valid, + } +} + +// ByteFrom creates a new Byte that will always be valid. +func ByteFrom(b byte) Byte { + return NewByte(b, true) +} + +// ByteFromPtr creates a new Byte that be null if i is nil. +func ByteFromPtr(b *byte) Byte { + if b == nil { + return NewByte(0, false) + } + return NewByte(*b, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (b *Byte) UnmarshalJSON(data []byte) error { + if len(data) == 0 || bytes.Equal(data, NullBytes) { + b.Valid = false + b.Byte = 0 + return nil + } + + var x string + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + if len(x) > 1 { + return errors.New("json: cannot convert to byte, text len is greater than one") + } + + b.Byte = []byte(x)[0] + b.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (b *Byte) UnmarshalText(text []byte) error { + if len(text) == 0 { + b.Valid = false + return nil + } + + if len(text) > 1 { + return errors.New("text: cannot convert to byte, text len is greater than one") + } + + b.Valid = true + b.Byte = text[0] + return nil +} + +// MarshalJSON implements json.Marshaler. +func (b Byte) MarshalJSON() ([]byte, error) { + if !b.Valid { + return NullBytes, nil + } + return []byte{b.Byte}, nil +} + +// MarshalText implements encoding.TextMarshaler. +func (b Byte) MarshalText() ([]byte, error) { + if !b.Valid { + return []byte{}, nil + } + return []byte{b.Byte}, nil +} + +// SetValid changes this Byte's value and also sets it to be non-null. +func (b *Byte) SetValid(n byte) { + b.Byte = n + b.Valid = true +} + +// Ptr returns a pointer to this Byte's value, or a nil pointer if this Byte is null. +func (b Byte) Ptr() *byte { + if !b.Valid { + return nil + } + return &b.Byte +} + +// IsZero returns true for invalid Bytes, for future omitempty support (Go 1.4?) +func (b Byte) IsZero() bool { + return !b.Valid +} + +// Scan implements the Scanner interface. +func (b *Byte) Scan(value interface{}) error { + if value == nil { + b.Byte, b.Valid = 0, false + return nil + } + b.Valid = true + return convert.ConvertAssign(&b.Byte, value) +} + +// Value implements the driver Valuer interface. +func (b Byte) Value() (driver.Value, error) { + if !b.Valid { + return nil, nil + } + return []byte{b.Byte}, nil +} diff --git a/byte_test.go b/byte_test.go new file mode 100644 index 0000000..608ec5b --- /dev/null +++ b/byte_test.go @@ -0,0 +1,172 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + byteJSON = []byte(`"b"`) +) + +func TestByteFrom(t *testing.T) { + i := ByteFrom(12345) + assertByte(t, i, "ByteFrom()") + + zero := ByteFrom(0) + if !zero.Valid { + t.Error("ByteFrom(0)", "is invalid, but should be valid") + } +} + +func TestByteFromPtr(t *testing.T) { + n := int(12345) + iptr := &n + i := ByteFromPtr(iptr) + assertByte(t, i, "ByteFromPtr()") + + null := ByteFromPtr(nil) + assertNullByte(t, null, "ByteFromPtr(nil)") +} + +func TestUnmarshalByte(t *testing.T) { + var i Byte + err := json.Unmarshal(intJSON, &i) + maybePanic(err) + assertByte(t, i, "int json") + + var null Byte + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullByte(t, null, "null json") + + var badType Byte + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullByte(t, badType, "wrong type json") + + var invalid Byte + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullByte(t, invalid, "invalid json") +} + +func TestUnmarshalNonByteegerNumber(t *testing.T) { + var i Byte + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to int") + } +} + +func TestTextUnmarshalByte(t *testing.T) { + var i Byte + err := i.UnmarshalText([]byte("12345")) + maybePanic(err) + assertByte(t, i, "UnmarshalText() int") + + var blank Byte + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullByte(t, blank, "UnmarshalText() empty int") + + var null Byte + err = null.UnmarshalText([]byte("null")) + maybePanic(err) + assertNullByte(t, null, `UnmarshalText() "null"`) +} + +func TestMarshalByte(t *testing.T) { + i := ByteFrom(12345) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "12345", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewByte(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalByteText(t *testing.T) { + i := ByteFrom(12345) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "12345", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewByte(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestBytePointer(t *testing.T) { + i := ByteFrom(12345) + ptr := i.Ptr() + if *ptr != 12345 { + t.Errorf("bad %s int: %#v ≠ %d\n", "pointer", ptr, 12345) + } + + null := NewByte(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s int: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestByteIsZero(t *testing.T) { + i := ByteFrom(12345) + if i.IsZero() { + t.Errorf("IsZero() should be false") + } + + null := NewByte(0, false) + if !null.IsZero() { + t.Errorf("IsZero() should be true") + } + + zero := NewByte(0, true) + if zero.IsZero() { + t.Errorf("IsZero() should be false") + } +} + +func TestByteSetValid(t *testing.T) { + change := NewByte(0, false) + assertNullByte(t, change, "SetValid()") + change.SetValid(12345) + assertByte(t, change, "SetValid()") +} + +func TestByteScan(t *testing.T) { + var i Byte + err := i.Scan(12345) + maybePanic(err) + assertByte(t, i, "scanned int") + + var null Byte + err = null.Scan(nil) + maybePanic(err) + assertNullByte(t, null, "scanned null") +} + +func assertByte(t *testing.T, i Byte, from string) { + if i.Byte != 12345 { + t.Errorf("bad %s int: %d ≠ %d\n", from, i.Byte, 12345) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullByte(t *testing.T, i Byte, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/bytes.go b/bytes.go index b59736f..1d89fe6 100644 --- a/bytes.go +++ b/bytes.go @@ -73,7 +73,7 @@ func (b *Bytes) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (b Bytes) MarshalJSON() ([]byte, error) { if len(b.Bytes) == 0 || b.Bytes == nil { - return []byte("null"), nil + return NullBytes, nil } return b.Bytes, nil } diff --git a/float32.go b/float32.go index fba24c0..32883c1 100644 --- a/float32.go +++ b/float32.go @@ -56,8 +56,7 @@ func (f *Float32) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (f *Float32) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { f.Valid = false return nil } @@ -73,7 +72,7 @@ func (f *Float32) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (f Float32) MarshalJSON() ([]byte, error) { if !f.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatFloat(float64(f.Float32), 'f', -1, 32)), nil } diff --git a/float64.go b/float64.go index 90e02f9..3c64e61 100644 --- a/float64.go +++ b/float64.go @@ -54,8 +54,7 @@ func (f *Float64) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (f *Float64) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { f.Valid = false return nil } @@ -68,7 +67,7 @@ func (f *Float64) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (f Float64) MarshalJSON() ([]byte, error) { if !f.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatFloat(f.Float64, 'f', -1, 64)), nil } diff --git a/int.go b/int.go index ba69e23..854aab3 100644 --- a/int.go +++ b/int.go @@ -56,8 +56,7 @@ func (i *Int) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (i *Int) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { i.Valid = false return nil } @@ -73,7 +72,7 @@ func (i *Int) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (i Int) MarshalJSON() ([]byte, error) { if !i.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatInt(int64(i.Int), 10)), nil } diff --git a/int16.go b/int16.go index 119a4cc..2ac9f99 100644 --- a/int16.go +++ b/int16.go @@ -62,8 +62,7 @@ func (i *Int16) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (i *Int16) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { i.Valid = false return nil } @@ -79,7 +78,7 @@ func (i *Int16) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (i Int16) MarshalJSON() ([]byte, error) { if !i.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatInt(int64(i.Int16), 10)), nil } diff --git a/int32.go b/int32.go index f4c897d..ad3e132 100644 --- a/int32.go +++ b/int32.go @@ -62,8 +62,7 @@ func (i *Int32) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (i *Int32) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { i.Valid = false return nil } @@ -79,7 +78,7 @@ func (i *Int32) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (i Int32) MarshalJSON() ([]byte, error) { if !i.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatInt(int64(i.Int32), 10)), nil } diff --git a/int64.go b/int64.go index 2157899..29759dd 100644 --- a/int64.go +++ b/int64.go @@ -54,8 +54,7 @@ func (i *Int64) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (i *Int64) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { i.Valid = false return nil } @@ -68,7 +67,7 @@ func (i *Int64) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (i Int64) MarshalJSON() ([]byte, error) { if !i.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatInt(i.Int64, 10)), nil } diff --git a/int8.go b/int8.go index 6036ef7..19eece3 100644 --- a/int8.go +++ b/int8.go @@ -62,8 +62,7 @@ func (i *Int8) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (i *Int8) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { i.Valid = false return nil } @@ -79,7 +78,7 @@ func (i *Int8) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (i Int8) MarshalJSON() ([]byte, error) { if !i.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatInt(int64(i.Int8), 10)), nil } diff --git a/json.go b/json.go index 4720981..8d7e89a 100644 --- a/json.go +++ b/json.go @@ -64,7 +64,7 @@ func (j *JSON) UnmarshalJSON(data []byte) error { } if bytes.Equal(data, NullBytes) { - j.JSON = []byte("null") + j.JSON = NullBytes j.Valid = false return nil } @@ -106,7 +106,7 @@ func (j *JSON) Marshal(obj interface{}) error { // MarshalJSON implements json.Marshaler. func (j JSON) MarshalJSON() ([]byte, error) { if len(j.JSON) == 0 || j.JSON == nil { - return []byte("null"), nil + return NullBytes, nil } return j.JSON, nil } diff --git a/string.go b/string.go index 97c764e..9f868a9 100644 --- a/string.go +++ b/string.go @@ -54,7 +54,7 @@ func (s *String) UnmarshalJSON(data []byte) error { // MarshalJSON implements json.Marshaler. func (s String) MarshalJSON() ([]byte, error) { if !s.Valid { - return []byte("null"), nil + return NullBytes, nil } return json.Marshal(s.String) } diff --git a/time.go b/time.go index efc5d67..3cd0af1 100644 --- a/time.go +++ b/time.go @@ -37,7 +37,7 @@ func TimeFromPtr(t *time.Time) Time { // MarshalJSON implements json.Marshaler. func (t Time) MarshalJSON() ([]byte, error) { if !t.Valid { - return []byte("null"), nil + return NullBytes, nil } return t.Time.MarshalJSON() } @@ -61,15 +61,14 @@ func (t *Time) UnmarshalJSON(data []byte) error { // MarshalText implements encoding.TextMarshaler. func (t Time) MarshalText() ([]byte, error) { if !t.Valid { - return []byte("null"), nil + return NullBytes, nil } return t.Time.MarshalText() } // UnmarshalText implements encoding.TextUnmarshaler. func (t *Time) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { t.Valid = false return nil } diff --git a/uint.go b/uint.go index d8c4b53..4331494 100644 --- a/uint.go +++ b/uint.go @@ -56,8 +56,7 @@ func (u *Uint) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (u *Uint) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { u.Valid = false return nil } @@ -73,7 +72,7 @@ func (u *Uint) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (u Uint) MarshalJSON() ([]byte, error) { if !u.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatUint(uint64(u.Uint), 10)), nil } diff --git a/uint16.go b/uint16.go index 74b11ae..4edb211 100644 --- a/uint16.go +++ b/uint16.go @@ -62,8 +62,7 @@ func (u *Uint16) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (u *Uint16) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { u.Valid = false return nil } @@ -79,7 +78,7 @@ func (u *Uint16) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (u Uint16) MarshalJSON() ([]byte, error) { if !u.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatUint(uint64(u.Uint16), 10)), nil } diff --git a/uint32.go b/uint32.go index 25c52be..920b105 100644 --- a/uint32.go +++ b/uint32.go @@ -62,8 +62,7 @@ func (u *Uint32) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (u *Uint32) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { u.Valid = false return nil } @@ -79,7 +78,7 @@ func (u *Uint32) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (u Uint32) MarshalJSON() ([]byte, error) { if !u.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatUint(uint64(u.Uint32), 10)), nil } diff --git a/uint64.go b/uint64.go index 4c3cf5d..df2f26e 100644 --- a/uint64.go +++ b/uint64.go @@ -54,8 +54,7 @@ func (u *Uint64) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (u *Uint64) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { u.Valid = false return nil } @@ -71,7 +70,7 @@ func (u *Uint64) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (u Uint64) MarshalJSON() ([]byte, error) { if !u.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatUint(u.Uint64, 10)), nil } diff --git a/uint8.go b/uint8.go index 2aeef53..76ed02a 100644 --- a/uint8.go +++ b/uint8.go @@ -62,8 +62,7 @@ func (u *Uint8) UnmarshalJSON(data []byte) error { // UnmarshalText implements encoding.TextUnmarshaler. func (u *Uint8) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { + if len(text) == 0 || bytes.Equal(text, NullBytes) { u.Valid = false return nil } @@ -79,7 +78,7 @@ func (u *Uint8) UnmarshalText(text []byte) error { // MarshalJSON implements json.Marshaler. func (u Uint8) MarshalJSON() ([]byte, error) { if !u.Valid { - return []byte("null"), nil + return NullBytes, nil } return []byte(strconv.FormatUint(uint64(u.Uint8), 10)), nil } @@ -122,9 +121,9 @@ func (u *Uint8) Scan(value interface{}) error { } // Value implements the driver Valuer interface. -func (n Uint8) Value() (driver.Value, error) { - if !n.Valid { +func (u Uint8) Value() (driver.Value, error) { + if !u.Valid { return nil, nil } - return int64(n.Uint8), nil + return int64(u.Uint8), nil }