diff --git a/null/.gitignore b/null/.gitignore new file mode 100644 index 0000000..a4131fb --- /dev/null +++ b/null/.gitignore @@ -0,0 +1,2 @@ +coverage.out +/.idea diff --git a/null/LICENSE b/null/LICENSE new file mode 100644 index 0000000..fe95b73 --- /dev/null +++ b/null/LICENSE @@ -0,0 +1,11 @@ +Copyright for portions of project null-extended are held by *Greg Roseberry, 2014* as part of project null. +All other copyright for project null-extended are held by *Patrick O'Brien, 2016*. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/null/README.md b/null/README.md new file mode 100644 index 0000000..4e0e359 --- /dev/null +++ b/null/README.md @@ -0,0 +1,87 @@ +## null-extended [![GoDoc](https://godoc.org/github.com/nullbio/null?status.svg)](https://godoc.org/github.com/nullbio/null) [![Coverage](http://gocover.io/_badge/github.com/nullbio/null)](http://gocover.io/github.com/nullbio/null) + +null-extended is a library with reasonable options for dealing with nullable SQL and JSON values + +Types in `null` will only be considered null on null input, and will JSON encode to `null`. + +All types implement `sql.Scanner` and `driver.Valuer`, so you can use this library in place of `sql.NullXXX`. All types also implement: `encoding.TextMarshaler`, `encoding.TextUnmarshaler`, `json.Marshaler`, `json.Unmarshaler` and `sql.Scanner`. + +--- + +Install: + +`go get -u "gopkg.in/nullbio/null.v6"` + +### null package + +`import "gopkg.in/nullbio/null.v6"` + +The following are all types supported in this package. All types will marshal to JSON null if Invalid or SQL source data is null. + +#### null.JSON +Nullable []byte. + +Will marshal to JSON null if Invalid. []byte{} input will not produce an Invalid JSON, but []byte(nil) will. This should be used for storing raw JSON in the database. + +Also has `null.JSON.Marshal` and `null.JSON.Unmarshal` helpers to marshal and unmarshal foreign objects. + +#### null.Bytes +Nullable []byte. + +[]byte{} input will not produce an Invalid Bytes, but []byte(nil) will. This should be used for storing binary data (bytea in PSQL for example) in the database. + +#### null.String +Nullable string. + +#### null.Byte +Nullable byte. + +#### null.Bool +Nullable bool. + +#### null.Time +Nullable time.Time + +Marshals to JSON null if SQL source data is null. Uses `time.Time`'s marshaler. + +#### null.Float32 +Nullable float32. + +#### null.Float64 +Nullable float64. + +#### null.Int +Nullable int. + +#### null.Int8 +Nullable int8. + +#### null.Int16 +Nullable int16. + +#### null.Int32 +Nullable int32. + +#### null.Int64 +Nullable int64. + +#### null.Uint +Nullable uint. + +#### null.Uint8 +Nullable uint8. + +#### null.Uint16 +Nullable uint16. + +#### null.Uint32 +Nullable int32. + +#### null.Int64 +Nullable uint64. + +### Bugs +`json`'s `",omitempty"` struct tag does not work correctly right now. It will never omit a null or empty String. This might be [fixed eventually](https://github.com/golang/go/issues/4357). + +### License +BSD diff --git a/null/bool.go b/null/bool.go new file mode 100644 index 0000000..2873e40 --- /dev/null +++ b/null/bool.go @@ -0,0 +1,133 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "errors" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Bool is a nullable bool. +type Bool struct { + Bool bool + Valid bool +} + +// NewBool creates a new Bool +func NewBool(b bool, valid bool) Bool { + return Bool{ + Bool: b, + Valid: valid, + } +} + +// BoolFrom creates a new Bool that will always be valid. +func BoolFrom(b bool) Bool { + return NewBool(b, true) +} + +// BoolFromPtr creates a new Bool that will be null if f is nil. +func BoolFromPtr(b *bool) Bool { + if b == nil { + return NewBool(false, false) + } + return NewBool(*b, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (b *Bool) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + b.Bool = false + b.Valid = false + return nil + } + + if err := json.Unmarshal(data, &b.Bool); err != nil { + return err + } + + b.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (b *Bool) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + b.Valid = false + return nil + } + + str := string(text) + switch str { + case "true": + b.Bool = true + case "false": + b.Bool = false + default: + b.Valid = false + return errors.New("invalid input:" + str) + } + b.Valid = true + return nil +} + +// MarshalJSON implements json.Marshaler. +func (b Bool) MarshalJSON() ([]byte, error) { + if !b.Valid { + return NullBytes, nil + } + if !b.Bool { + return []byte("false"), nil + } + return []byte("true"), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (b Bool) MarshalText() ([]byte, error) { + if !b.Valid { + return []byte{}, nil + } + if !b.Bool { + return []byte("false"), nil + } + return []byte("true"), nil +} + +// SetValid changes this Bool's value and also sets it to be non-null. +func (b *Bool) SetValid(v bool) { + b.Bool = v + b.Valid = true +} + +// Ptr returns a pointer to this Bool's value, or a nil pointer if this Bool is null. +func (b Bool) Ptr() *bool { + if !b.Valid { + return nil + } + return &b.Bool +} + +// IsNull returns true for invalid Bools, for future omitempty support (Go 1.4?) +func (b Bool) IsNull() bool { + return !b.Valid +} + +// Scan implements the Scanner interface. +func (b *Bool) Scan(value interface{}) error { + if value == nil { + b.Bool, b.Valid = false, false + return nil + } + b.Valid = true + return convert.ConvertAssign(&b.Bool, value) +} + +// Value implements the driver Valuer interface. +func (b Bool) Value() (driver.Value, error) { + if !b.Valid { + return nil, nil + } + return b.Bool, nil +} diff --git a/null/bool_test.go b/null/bool_test.go new file mode 100644 index 0000000..5d30eba --- /dev/null +++ b/null/bool_test.go @@ -0,0 +1,196 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + boolJSON = []byte(`true`) + falseJSON = []byte(`false`) +) + +func TestBoolFrom(t *testing.T) { + b := BoolFrom(true) + assertBool(t, b, "BoolFrom()") + + zero := BoolFrom(false) + if !zero.Valid { + t.Error("BoolFrom(false)", "is invalid, but should be valid") + } +} + +func TestBoolFromPtr(t *testing.T) { + n := true + bptr := &n + b := BoolFromPtr(bptr) + assertBool(t, b, "BoolFromPtr()") + + null := BoolFromPtr(nil) + assertNullBool(t, null, "BoolFromPtr(nil)") +} + +func TestUnmarshalBool(t *testing.T) { + var b Bool + err := json.Unmarshal(boolJSON, &b) + maybePanic(err) + assertBool(t, b, "bool json") + + var null Bool + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullBool(t, null, "null json") + + var badType Bool + err = json.Unmarshal(intJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullBool(t, badType, "wrong type json") + + var invalid Bool + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } +} + +func TestTextUnmarshalBool(t *testing.T) { + var b Bool + err := b.UnmarshalText([]byte("true")) + maybePanic(err) + assertBool(t, b, "UnmarshalText() bool") + + var zero Bool + err = zero.UnmarshalText([]byte("false")) + maybePanic(err) + assertFalseBool(t, zero, "UnmarshalText() false") + + var blank Bool + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullBool(t, blank, "UnmarshalText() empty bool") + + var invalid Bool + err = invalid.UnmarshalText([]byte(":D")) + if err == nil { + panic("err should not be nil") + } + assertNullBool(t, invalid, "invalid json") +} + +func TestMarshalBool(t *testing.T) { + b := BoolFrom(true) + data, err := json.Marshal(b) + maybePanic(err) + assertJSONEquals(t, data, "true", "non-empty json marshal") + + zero := NewBool(false, true) + data, err = json.Marshal(zero) + maybePanic(err) + assertJSONEquals(t, data, "false", "zero json marshal") + + // invalid values should be encoded as null + null := NewBool(false, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalBoolText(t *testing.T) { + b := BoolFrom(true) + data, err := b.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "true", "non-empty text marshal") + + zero := NewBool(false, true) + data, err = zero.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "false", "zero text marshal") + + // invalid values should be encoded as null + null := NewBool(false, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestBoolPointer(t *testing.T) { + b := BoolFrom(true) + ptr := b.Ptr() + if *ptr != true { + t.Errorf("bad %s bool: %#v ≠ %v\n", "pointer", ptr, true) + } + + null := NewBool(false, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s bool: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestBoolIsNull(t *testing.T) { + b := BoolFrom(true) + if b.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewBool(false, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewBool(false, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestBoolSetValid(t *testing.T) { + change := NewBool(false, false) + assertNullBool(t, change, "SetValid()") + change.SetValid(true) + assertBool(t, change, "SetValid()") +} + +func TestBoolScan(t *testing.T) { + var b Bool + err := b.Scan(true) + maybePanic(err) + assertBool(t, b, "scanned bool") + + var null Bool + err = null.Scan(nil) + maybePanic(err) + assertNullBool(t, null, "scanned null") +} + +func assertBool(t *testing.T, b Bool, from string) { + if b.Bool != true { + t.Errorf("bad %s bool: %v ≠ %v\n", from, b.Bool, true) + } + if !b.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertFalseBool(t *testing.T, b Bool, from string) { + if b.Bool != false { + t.Errorf("bad %s bool: %v ≠ %v\n", from, b.Bool, false) + } + if !b.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullBool(t *testing.T, b Bool, from string) { + if b.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/byte.go b/null/byte.go new file mode 100644 index 0000000..4c72602 --- /dev/null +++ b/null/byte.go @@ -0,0 +1,135 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "errors" +) + +// Byte is an nullable int. +type Byte struct { + Byte byte + Valid bool +} + +// NewByte creates a new Byte +func NewByte(b byte, valid bool) Byte { + return Byte{ + Byte: b, + Valid: valid, + } +} + +// ByteFrom creates a new Byte that will always be valid. +func ByteFrom(b byte) Byte { + return NewByte(b, true) +} + +// ByteFromPtr creates a new Byte that be null if i is nil. +func ByteFromPtr(b *byte) Byte { + if b == nil { + return NewByte(0, false) + } + return NewByte(*b, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (b *Byte) UnmarshalJSON(data []byte) error { + if len(data) == 0 || bytes.Equal(data, NullBytes) { + b.Valid = false + b.Byte = 0 + return nil + } + + var x string + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + if len(x) > 1 { + return errors.New("json: cannot convert to byte, text len is greater than one") + } + + b.Byte = x[0] + b.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (b *Byte) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + b.Valid = false + return nil + } + + if len(text) > 1 { + return errors.New("text: cannot convert to byte, text len is greater than one") + } + + b.Valid = true + b.Byte = text[0] + return nil +} + +// MarshalJSON implements json.Marshaler. +func (b Byte) MarshalJSON() ([]byte, error) { + if !b.Valid { + return NullBytes, nil + } + return []byte{'"', b.Byte, '"'}, nil +} + +// MarshalText implements encoding.TextMarshaler. +func (b Byte) MarshalText() ([]byte, error) { + if !b.Valid { + return []byte{}, nil + } + return []byte{b.Byte}, nil +} + +// SetValid changes this Byte's value and also sets it to be non-null. +func (b *Byte) SetValid(n byte) { + b.Byte = n + b.Valid = true +} + +// Ptr returns a pointer to this Byte's value, or a nil pointer if this Byte is null. +func (b Byte) Ptr() *byte { + if !b.Valid { + return nil + } + return &b.Byte +} + +// IsNull returns true for invalid Bytes, for future omitempty support (Go 1.4?) +func (b Byte) IsNull() bool { + return !b.Valid +} + +// Scan implements the Scanner interface. +func (b *Byte) Scan(value interface{}) error { + if value == nil { + b.Byte, b.Valid = 0, false + return nil + } + + val := value.(string) + if len(val) == 0 { + b.Valid = false + b.Byte = 0 + return nil + } + + b.Valid = true + b.Byte = byte(val[0]) + return nil +} + +// Value implements the driver Valuer interface. +func (b Byte) Value() (driver.Value, error) { + if !b.Valid { + return nil, nil + } + return []byte{b.Byte}, nil +} diff --git a/null/byte_test.go b/null/byte_test.go new file mode 100644 index 0000000..4c7dc81 --- /dev/null +++ b/null/byte_test.go @@ -0,0 +1,168 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + byteJSON = []byte(`"b"`) +) + +func TestByteFrom(t *testing.T) { + i := ByteFrom('b') + assertByte(t, i, "ByteFrom()") + + zero := ByteFrom(0) + if !zero.Valid { + t.Error("ByteFrom(0)", "is invalid, but should be valid") + } +} + +func TestByteFromPtr(t *testing.T) { + n := byte('b') + iptr := &n + i := ByteFromPtr(iptr) + assertByte(t, i, "ByteFromPtr()") + + null := ByteFromPtr(nil) + assertNullByte(t, null, "ByteFromPtr(nil)") +} + +func TestUnmarshalByte(t *testing.T) { + var null Byte + err := json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullByte(t, null, "null json") + + var badType Byte + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullByte(t, badType, "wrong type json") + + var invalid Byte + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullByte(t, invalid, "invalid json") +} + +func TestUnmarshalNonByteegerNumber(t *testing.T) { + var i Byte + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to int") + } +} + +func TestTextUnmarshalByte(t *testing.T) { + var i Byte + err := i.UnmarshalText([]byte("b")) + maybePanic(err) + assertByte(t, i, "UnmarshalText() int") + + var blank Byte + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullByte(t, blank, "UnmarshalText() empty int") +} + +func TestMarshalByte(t *testing.T) { + i := ByteFrom('b') + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, `"b"`, "non-empty json marshal") + + // invalid values should be encoded as null + null := NewByte(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalByteText(t *testing.T) { + i := ByteFrom('b') + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "b", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewByte(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestBytePointer(t *testing.T) { + i := ByteFrom('b') + ptr := i.Ptr() + if *ptr != 'b' { + t.Errorf("bad %s int: %#v ≠ %d\n", "pointer", ptr, 'b') + } + + null := NewByte(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s int: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestByteIsNull(t *testing.T) { + i := ByteFrom('b') + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewByte(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewByte(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestByteSetValid(t *testing.T) { + change := NewByte(0, false) + assertNullByte(t, change, "SetValid()") + change.SetValid('b') + assertByte(t, change, "SetValid()") +} + +func TestByteScan(t *testing.T) { + var i Byte + err := i.Scan("b") + maybePanic(err) + assertByte(t, i, "scanned int") + + var null Byte + err = null.Scan(nil) + maybePanic(err) + assertNullByte(t, null, "scanned null") +} + +func assertByte(t *testing.T, i Byte, from string) { + if i.Byte != 'b' { + t.Errorf("bad %s int: %d ≠ %d\n", from, i.Byte, 'b') + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullByte(t *testing.T, i Byte, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/bytes.go b/null/bytes.go new file mode 100644 index 0000000..833e729 --- /dev/null +++ b/null/bytes.go @@ -0,0 +1,124 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + + "gopkg.in/nullbio/null.v6/convert" +) + +// NullBytes is a global byte slice of JSON null +var NullBytes = []byte("null") + +// Bytes is a nullable []byte. +type Bytes struct { + Bytes []byte + Valid bool +} + +// NewBytes creates a new Bytes +func NewBytes(b []byte, valid bool) Bytes { + return Bytes{ + Bytes: b, + Valid: valid, + } +} + +// BytesFrom creates a new Bytes that will be invalid if nil. +func BytesFrom(b []byte) Bytes { + return NewBytes(b, b != nil) +} + +// BytesFromPtr creates a new Bytes that will be invalid if nil. +func BytesFromPtr(b *[]byte) Bytes { + if b == nil { + return NewBytes(nil, false) + } + n := NewBytes(*b, true) + return n +} + +// UnmarshalJSON implements json.Unmarshaler. +func (b *Bytes) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + b.Valid = false + b.Bytes = nil + return nil + } + + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + b.Bytes = []byte(s) + b.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (b *Bytes) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + b.Bytes = nil + b.Valid = false + } else { + b.Bytes = append(b.Bytes[0:0], text...) + b.Valid = true + } + + return nil +} + +// MarshalJSON implements json.Marshaler. +func (b Bytes) MarshalJSON() ([]byte, error) { + if len(b.Bytes) == 0 || b.Bytes == nil { + return NullBytes, nil + } + return b.Bytes, nil +} + +// MarshalText implements encoding.TextMarshaler. +func (b Bytes) MarshalText() ([]byte, error) { + if !b.Valid { + return nil, nil + } + return b.Bytes, nil +} + +// SetValid changes this Bytes's value and also sets it to be non-null. +func (b *Bytes) SetValid(n []byte) { + b.Bytes = n + b.Valid = true +} + +// Ptr returns a pointer to this Bytes's value, or a nil pointer if this Bytes is null. +func (b Bytes) Ptr() *[]byte { + if !b.Valid { + return nil + } + return &b.Bytes +} + +// IsNull returns true for null or zero Bytes's, for future omitempty support (Go 1.4?) +func (b Bytes) IsNull() bool { + return !b.Valid +} + +// Scan implements the Scanner interface. +func (b *Bytes) Scan(value interface{}) error { + if value == nil { + b.Bytes, b.Valid = []byte{}, false + return nil + } + b.Valid = true + return convert.ConvertAssign(&b.Bytes, value) +} + +// Value implements the driver Valuer interface. +func (b Bytes) Value() (driver.Value, error) { + if !b.Valid { + return nil, nil + } + return b.Bytes, nil +} diff --git a/null/bytes_test.go b/null/bytes_test.go new file mode 100644 index 0000000..6bf5fbc --- /dev/null +++ b/null/bytes_test.go @@ -0,0 +1,167 @@ +package null + +import ( + "bytes" + "encoding/json" + "testing" +) + +var ( + bytesJSON = []byte(`"hello"`) +) + +func TestBytesFrom(t *testing.T) { + i := BytesFrom([]byte(`hello`)) + assertBytes(t, i, "BytesFrom()") + + zero := BytesFrom(nil) + if zero.Valid { + t.Error("BytesFrom(nil)", "is valid, but should be invalid") + } + + zero = BytesFrom([]byte{}) + if !zero.Valid { + t.Error("BytesFrom([]byte{})", "is invalid, but should be valid") + } +} + +func TestBytesFromPtr(t *testing.T) { + n := []byte(`hello`) + iptr := &n + i := BytesFromPtr(iptr) + assertBytes(t, i, "BytesFromPtr()") + + null := BytesFromPtr(nil) + assertNullBytes(t, null, "BytesFromPtr(nil)") +} + +func TestUnmarshalBytes(t *testing.T) { + var i Bytes + err := json.Unmarshal(bytesJSON, &i) + maybePanic(err) + assertBytes(t, i, "[]byte json") + + var ni Bytes + err = ni.UnmarshalJSON([]byte{}) + if err == nil { + t.Errorf("Expected error") + } + + var null Bytes + err = null.UnmarshalJSON([]byte("null")) + if null.Valid == true { + t.Errorf("expected Valid to be false, got true") + } + if null.Bytes != nil { + t.Errorf("Expected Bytes to be nil, but was not: %#v %#v", null.Bytes, []byte(`null`)) + } +} + +func TestTextUnmarshalBytes(t *testing.T) { + var i Bytes + err := i.UnmarshalText([]byte(`hello`)) + maybePanic(err) + assertBytes(t, i, "UnmarshalText() []byte") + + var blank Bytes + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullBytes(t, blank, "UnmarshalText() empty []byte") +} + +func TestMarshalBytes(t *testing.T) { + i := BytesFrom([]byte(`"hello"`)) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, `"hello"`, "non-empty json marshal") + + // invalid values should be encoded as null + null := NewBytes(nil, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalBytesText(t *testing.T) { + i := BytesFrom([]byte(`"hello"`)) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, `"hello"`, "non-empty text marshal") + + // invalid values should be encoded as null + null := NewBytes(nil, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestBytesPointer(t *testing.T) { + i := BytesFrom([]byte(`"hello"`)) + ptr := i.Ptr() + if !bytes.Equal(*ptr, []byte(`"hello"`)) { + t.Errorf("bad %s []byte: %#v ≠ %s\n", "pointer", ptr, `"hello"`) + } + + null := NewBytes(nil, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s []byte: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestBytesIsNull(t *testing.T) { + i := BytesFrom([]byte(`"hello"`)) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewBytes(nil, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewBytes(nil, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestBytesSetValid(t *testing.T) { + change := NewBytes(nil, false) + assertNullBytes(t, change, "SetValid()") + change.SetValid([]byte(`hello`)) + assertBytes(t, change, "SetValid()") +} + +func TestBytesScan(t *testing.T) { + var i Bytes + err := i.Scan(`hello`) + maybePanic(err) + assertBytes(t, i, "Scan() []byte") + + var null Bytes + err = null.Scan(nil) + maybePanic(err) + assertNullBytes(t, null, "scanned null") +} + +func assertBytes(t *testing.T, i Bytes, from string) { + if !bytes.Equal(i.Bytes, []byte("hello")) { + t.Errorf("bad %s []byte: %v ≠ %v\n", from, string(i.Bytes), string([]byte(`hello`))) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullBytes(t *testing.T, i Bytes, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/convert/convert.go b/null/convert/convert.go new file mode 100644 index 0000000..b231ebf --- /dev/null +++ b/null/convert/convert.go @@ -0,0 +1,266 @@ +package convert + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Type conversions for Scan. +// These functions are copied from database/sql/convert.go build 1.6.2 + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" +) + +var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +// ConvertAssign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func ConvertAssign(dest, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *interface{}: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case time.Time: + switch d := dest.(type) { + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + } + case nil: + switch d := dest.(type) { + case *interface{}: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = nil + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *sql.RawBytes: + sv = reflect.ValueOf(src) + if b, ok := asBytes([]byte(*d)[:0], sv); ok { + *d = sql.RawBytes(b) + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *interface{}: + *d = src + return nil + } + + if scanner, ok := dest.(sql.Scanner); ok { + return scanner.Scan(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Ptr { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + dv.Set(sv) + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + switch dv.Kind() { + case reflect.Ptr: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } else { + dv.Set(reflect.New(dv.Type().Elem())) + return ConvertAssign(dv.Interface(), src) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + } + + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) +} + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } else { + c := make([]byte, len(b)) + copy(c, b) + return c + } +} + +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} diff --git a/null/convert/convert_test.go b/null/convert/convert_test.go new file mode 100644 index 0000000..f837485 --- /dev/null +++ b/null/convert/convert_test.go @@ -0,0 +1,382 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// These functions are copied from database/sql/convert_test.go build 1.6.2 + +package convert + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "runtime" + "testing" + "time" +) + +var someTime = time.Unix(123, 0) +var answer int64 = 42 + +type userDefined float64 + +type userDefinedSlice []int + +type conversionTest struct { + s, d interface{} // source and destination + + // following are used if they're non-zero + wantint int64 + wantuint uint64 + wantstr string + wantbytes []byte + wantraw sql.RawBytes + wantf32 float32 + wantf64 float64 + wanttime time.Time + wantbool bool // used if d is of type *bool + wanterr string + wantiface interface{} + wantptr *int64 // if non-nil, *d's pointed value must be equal to *wantptr + wantnil bool // if true, *d must be *int64(nil) + wantusrdef userDefined +} + +// Target variables for scanning into. +var ( + scanstr string + scanbytes []byte + scanraw sql.RawBytes + scanint int + scanint8 int8 + scanint16 int16 + scanint32 int32 + scanuint8 uint8 + scanuint16 uint16 + scanbool bool + scanf32 float32 + scanf64 float64 + scantime time.Time + scanptr *int64 + scaniface interface{} +) + +var conversionTests = []conversionTest{ + // Exact conversions (destination pointer type matches source type) + {s: "foo", d: &scanstr, wantstr: "foo"}, + {s: 123, d: &scanint, wantint: 123}, + {s: someTime, d: &scantime, wanttime: someTime}, + + // To strings + {s: "string", d: &scanstr, wantstr: "string"}, + {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, + {s: 123, d: &scanstr, wantstr: "123"}, + {s: int8(123), d: &scanstr, wantstr: "123"}, + {s: int64(123), d: &scanstr, wantstr: "123"}, + {s: uint8(123), d: &scanstr, wantstr: "123"}, + {s: uint16(123), d: &scanstr, wantstr: "123"}, + {s: uint32(123), d: &scanstr, wantstr: "123"}, + {s: uint64(123), d: &scanstr, wantstr: "123"}, + {s: 1.5, d: &scanstr, wantstr: "1.5"}, + + // From time.Time: + {s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"}, + {s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"}, + {s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"}, + {s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"}, + {s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")}, + {s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()}, + + // To []byte + {s: nil, d: &scanbytes, wantbytes: nil}, + {s: "string", d: &scanbytes, wantbytes: []byte("string")}, + {s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")}, + {s: 123, d: &scanbytes, wantbytes: []byte("123")}, + {s: int8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: int64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint16(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint32(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")}, + + // To sql.RawBytes + {s: nil, d: &scanraw, wantraw: nil}, + {s: []byte("byteslice"), d: &scanraw, wantraw: sql.RawBytes("byteslice")}, + {s: 123, d: &scanraw, wantraw: sql.RawBytes("123")}, + {s: int8(123), d: &scanraw, wantraw: sql.RawBytes("123")}, + {s: int64(123), d: &scanraw, wantraw: sql.RawBytes("123")}, + {s: uint8(123), d: &scanraw, wantraw: sql.RawBytes("123")}, + {s: uint16(123), d: &scanraw, wantraw: sql.RawBytes("123")}, + {s: uint32(123), d: &scanraw, wantraw: sql.RawBytes("123")}, + {s: uint64(123), d: &scanraw, wantraw: sql.RawBytes("123")}, + {s: 1.5, d: &scanraw, wantraw: sql.RawBytes("1.5")}, + + // Strings to integers + {s: "255", d: &scanuint8, wantuint: 255}, + {s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"}, + {s: "256", d: &scanuint16, wantuint: 256}, + {s: "-1", d: &scanint, wantint: -1}, + {s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"}, + + // int64 to smaller integers + {s: int64(5), d: &scanuint8, wantuint: 5}, + {s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"}, + {s: int64(256), d: &scanuint16, wantuint: 256}, + {s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"}, + + // True bools + {s: true, d: &scanbool, wantbool: true}, + {s: "True", d: &scanbool, wantbool: true}, + {s: "TRUE", d: &scanbool, wantbool: true}, + {s: "1", d: &scanbool, wantbool: true}, + {s: 1, d: &scanbool, wantbool: true}, + {s: int64(1), d: &scanbool, wantbool: true}, + {s: uint16(1), d: &scanbool, wantbool: true}, + + // False bools + {s: false, d: &scanbool, wantbool: false}, + {s: "false", d: &scanbool, wantbool: false}, + {s: "FALSE", d: &scanbool, wantbool: false}, + {s: "0", d: &scanbool, wantbool: false}, + {s: 0, d: &scanbool, wantbool: false}, + {s: int64(0), d: &scanbool, wantbool: false}, + {s: uint16(0), d: &scanbool, wantbool: false}, + + // Not bools + {s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`}, + {s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`}, + + // Floats + {s: float64(1.5), d: &scanf64, wantf64: float64(1.5)}, + {s: int64(1), d: &scanf64, wantf64: float64(1)}, + {s: float64(1.5), d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf64, wantf64: float64(1.5)}, + + // Pointers + {s: interface{}(nil), d: &scanptr, wantnil: true}, + {s: int64(42), d: &scanptr, wantptr: &answer}, + + // To interface{} + {s: float64(1.5), d: &scaniface, wantiface: float64(1.5)}, + {s: int64(1), d: &scaniface, wantiface: int64(1)}, + {s: "str", d: &scaniface, wantiface: "str"}, + {s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")}, + {s: true, d: &scaniface, wantiface: true}, + {s: nil, d: &scaniface}, + {s: []byte(nil), d: &scaniface, wantiface: []byte(nil)}, + + // To a user-defined type + {s: 1.5, d: new(userDefined), wantusrdef: 1.5}, + {s: int64(123), d: new(userDefined), wantusrdef: 123}, + {s: "1.5", d: new(userDefined), wantusrdef: 1.5}, + {s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *convert.userDefinedSlice`}, + + // Other errors + {s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`}, +} + +func intPtrValue(intptr interface{}) interface{} { + return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int() +} + +func intValue(intptr interface{}) int64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Int() +} + +func uintValue(intptr interface{}) uint64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Uint() +} + +func float64Value(ptr interface{}) float64 { + return *(ptr.(*float64)) +} + +func float32Value(ptr interface{}) float32 { + return *(ptr.(*float32)) +} + +func getTimeValue(ptr interface{}) time.Time { + return *(ptr.(*time.Time)) +} + +func TestConversions(t *testing.T) { + for n, ct := range conversionTests { + err := ConvertAssign(ct.d, ct.s) + errstr := "" + if err != nil { + errstr = err.Error() + } + errf := func(format string, args ...interface{}) { + base := fmt.Sprintf("ConvertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d) + t.Errorf(base+format, args...) + } + if errstr != ct.wanterr { + errf("got error %q, want error %q", errstr, ct.wanterr) + } + if ct.wantstr != "" && ct.wantstr != scanstr { + errf("want string %q, got %q", ct.wantstr, scanstr) + } + if ct.wantint != 0 && ct.wantint != intValue(ct.d) { + errf("want int %d, got %d", ct.wantint, intValue(ct.d)) + } + if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) { + errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d)) + } + if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d)) + } + if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d)) + } + if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" { + errf("want bool %v, got %v", ct.wantbool, *bp) + } + if !ct.wanttime.IsNull() && !ct.wanttime.Equal(getTimeValue(ct.d)) { + errf("want time %v, got %v", ct.wanttime, getTimeValue(ct.d)) + } + if ct.wantnil && *ct.d.(**int64) != nil { + errf("want nil, got %v", intPtrValue(ct.d)) + } + if ct.wantptr != nil { + if *ct.d.(**int64) == nil { + errf("want pointer to %v, got nil", *ct.wantptr) + } else if *ct.wantptr != intPtrValue(ct.d) { + errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d)) + } + } + if ifptr, ok := ct.d.(*interface{}); ok { + if !reflect.DeepEqual(ct.wantiface, scaniface) { + errf("want interface %#v, got %#v", ct.wantiface, scaniface) + continue + } + if srcBytes, ok := ct.s.([]byte); ok { + dstBytes := (*ifptr).([]byte) + if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] { + errf("copy into interface{} didn't copy []byte data") + } + } + } + if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) { + errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined)) + } + } +} + +func TestNullString(t *testing.T) { + var ns sql.NullString + ConvertAssign(&ns, []byte("foo")) + if !ns.Valid { + t.Errorf("expecting not null") + } + if ns.String != "foo" { + t.Errorf("expecting foo; got %q", ns.String) + } + ConvertAssign(&ns, nil) + if ns.Valid { + t.Errorf("expecting null on nil") + } + if ns.String != "" { + t.Errorf("expecting blank on nil; got %q", ns.String) + } +} + +type valueConverterTest struct { + c driver.ValueConverter + in, out interface{} + err string +} + +var valueConverterTests = []valueConverterTest{ + {driver.DefaultParameterConverter, sql.NullString{"hi", true}, "hi", ""}, + {driver.DefaultParameterConverter, sql.NullString{"", false}, nil, ""}, +} + +func TestValueConverters(t *testing.T) { + for i, tt := range valueConverterTests { + out, err := tt.c.ConvertValue(tt.in) + goterr := "" + if err != nil { + goterr = err.Error() + } + if goterr != tt.err { + t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q", + i, tt.c, tt.in, tt.in, goterr, tt.err) + } + if tt.err != "" { + continue + } + if !reflect.DeepEqual(out, tt.out) { + t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)", + i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out) + } + } +} + +// Tests that assigning to sql.RawBytes doesn't allocate (and also works). +func TestRawBytesAllocs(t *testing.T) { + var tests = []struct { + name string + in interface{} + want string + }{ + {"uint64", uint64(12345678), "12345678"}, + {"uint32", uint32(1234), "1234"}, + {"uint16", uint16(12), "12"}, + {"uint8", uint8(1), "1"}, + {"uint", uint(123), "123"}, + {"int", int(123), "123"}, + {"int8", int8(1), "1"}, + {"int16", int16(12), "12"}, + {"int32", int32(1234), "1234"}, + {"int64", int64(12345678), "12345678"}, + {"float32", float32(1.5), "1.5"}, + {"float64", float64(64), "64"}, + {"bool", false, "false"}, + } + + buf := make(sql.RawBytes, 10) + test := func(name string, in interface{}, want string) { + if err := ConvertAssign(&buf, in); err != nil { + t.Fatalf("%s: ConvertAssign = %v", name, err) + } + match := len(buf) == len(want) + if match { + for i, b := range buf { + if want[i] != b { + match = false + break + } + } + } + if !match { + t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want)) + } + } + + n := testing.AllocsPerRun(100, func() { + for _, tt := range tests { + test(tt.name, tt.in, tt.want) + } + }) + + // The numbers below are only valid for 64-bit interface word sizes, + // and gc. With 32-bit words there are more convT2E allocs, and + // with gccgo, only pointers currently go in interface data. + // So only care on amd64 gc for now. + measureAllocs := runtime.GOARCH == "amd64" && runtime.Compiler == "gc" + + if n > 0.5 && measureAllocs { + t.Fatalf("allocs = %v; want 0", n) + } + + // This one involves a convT2E allocation, string -> interface{} + n = testing.AllocsPerRun(100, func() { + test("string", "foo", "foo") + }) + if n > 1.5 && measureAllocs { + t.Fatalf("allocs = %v; want max 1", n) + } +} diff --git a/null/float32.go b/null/float32.go new file mode 100644 index 0000000..3e524de --- /dev/null +++ b/null/float32.go @@ -0,0 +1,123 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Float32 is a nullable float32. +type Float32 struct { + Float32 float32 + Valid bool +} + +// NewFloat32 creates a new Float32 +func NewFloat32(f float32, valid bool) Float32 { + return Float32{ + Float32: f, + Valid: valid, + } +} + +// Float32From creates a new Float32 that will always be valid. +func Float32From(f float32) Float32 { + return NewFloat32(f, true) +} + +// Float32FromPtr creates a new Float32 that be null if f is nil. +func Float32FromPtr(f *float32) Float32 { + if f == nil { + return NewFloat32(0, false) + } + return NewFloat32(*f, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (f *Float32) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + f.Valid = false + f.Float32 = 0 + return nil + } + + var x float64 + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + f.Float32 = float32(x) + f.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (f *Float32) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + f.Valid = false + return nil + } + var err error + res, err := strconv.ParseFloat(string(text), 32) + f.Valid = err == nil + if f.Valid { + f.Float32 = float32(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (f Float32) MarshalJSON() ([]byte, error) { + if !f.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatFloat(float64(f.Float32), 'f', -1, 32)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (f Float32) MarshalText() ([]byte, error) { + if !f.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatFloat(float64(f.Float32), 'f', -1, 32)), nil +} + +// SetValid changes this Float32's value and also sets it to be non-null. +func (f *Float32) SetValid(n float32) { + f.Float32 = n + f.Valid = true +} + +// Ptr returns a pointer to this Float32's value, or a nil pointer if this Float32 is null. +func (f Float32) Ptr() *float32 { + if !f.Valid { + return nil + } + return &f.Float32 +} + +// IsNull returns true for invalid Float32s, for future omitempty support (Go 1.4?) +func (f Float32) IsNull() bool { + return !f.Valid +} + +// Scan implements the Scanner interface. +func (f *Float32) Scan(value interface{}) error { + if value == nil { + f.Float32, f.Valid = 0, false + return nil + } + f.Valid = true + return convert.ConvertAssign(&f.Float32, value) +} + +// Value implements the driver Valuer interface. +func (f Float32) Value() (driver.Value, error) { + if !f.Valid { + return nil, nil + } + return float64(f.Float32), nil +} diff --git a/null/float32_test.go b/null/float32_test.go new file mode 100644 index 0000000..66f49fc --- /dev/null +++ b/null/float32_test.go @@ -0,0 +1,164 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + float32JSON = []byte(`1.2345`) +) + +func TestFloat32From(t *testing.T) { + f := Float32From(1.2345) + assertFloat32(t, f, "Float32From()") + + zero := Float32From(0) + if !zero.Valid { + t.Error("Float32From(0)", "is invalid, but should be valid") + } +} + +func TestFloat32FromPtr(t *testing.T) { + n := float32(1.2345) + iptr := &n + f := Float32FromPtr(iptr) + assertFloat32(t, f, "Float32FromPtr()") + + null := Float32FromPtr(nil) + assertNullFloat32(t, null, "Float32FromPtr(nil)") +} + +func TestUnmarshalFloat32(t *testing.T) { + var f Float32 + err := json.Unmarshal(float32JSON, &f) + maybePanic(err) + assertFloat32(t, f, "float32 json") + + var null Float32 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullFloat32(t, null, "null json") + + var badType Float32 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullFloat32(t, badType, "wrong type json") + + var invalid Float32 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } +} + +func TestTextUnmarshalFloat32(t *testing.T) { + var f Float32 + err := f.UnmarshalText([]byte("1.2345")) + maybePanic(err) + assertFloat32(t, f, "UnmarshalText() float32") + + var blank Float32 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullFloat32(t, blank, "UnmarshalText() empty float32") +} + +func TestMarshalFloat32(t *testing.T) { + f := Float32From(1.2345) + data, err := json.Marshal(f) + maybePanic(err) + assertJSONEquals(t, data, "1.2345", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewFloat32(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalFloat32Text(t *testing.T) { + f := Float32From(1.2345) + data, err := f.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "1.2345", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewFloat32(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestFloat32Pointer(t *testing.T) { + f := Float32From(1.2345) + ptr := f.Ptr() + if *ptr != 1.2345 { + t.Errorf("bad %s float32: %#v ≠ %v\n", "pointer", ptr, 1.2345) + } + + null := NewFloat32(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s float32: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestFloat32IsNull(t *testing.T) { + f := Float32From(1.2345) + if f.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewFloat32(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewFloat32(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestFloat32SetValid(t *testing.T) { + change := NewFloat32(0, false) + assertNullFloat32(t, change, "SetValid()") + change.SetValid(1.2345) + assertFloat32(t, change, "SetValid()") +} + +func TestFloat32Scan(t *testing.T) { + var f Float32 + err := f.Scan(1.2345) + maybePanic(err) + assertFloat32(t, f, "scanned float32") + + var null Float32 + err = null.Scan(nil) + maybePanic(err) + assertNullFloat32(t, null, "scanned null") +} + +func assertFloat32(t *testing.T, f Float32, from string) { + if f.Float32 != 1.2345 { + t.Errorf("bad %s float32: %f ≠ %f\n", from, f.Float32, 1.2345) + } + if !f.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullFloat32(t *testing.T, f Float32, from string) { + if f.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/float64.go b/null/float64.go new file mode 100644 index 0000000..9bf3dee --- /dev/null +++ b/null/float64.go @@ -0,0 +1,118 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Float64 is a nullable float64. +type Float64 struct { + Float64 float64 + Valid bool +} + +// NewFloat64 creates a new Float64 +func NewFloat64(f float64, valid bool) Float64 { + return Float64{ + Float64: f, + Valid: valid, + } +} + +// Float64From creates a new Float64 that will always be valid. +func Float64From(f float64) Float64 { + return NewFloat64(f, true) +} + +// Float64FromPtr creates a new Float64 that be null if f is nil. +func Float64FromPtr(f *float64) Float64 { + if f == nil { + return NewFloat64(0, false) + } + return NewFloat64(*f, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (f *Float64) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + f.Float64 = 0 + f.Valid = false + return nil + } + + if err := json.Unmarshal(data, &f.Float64); err != nil { + return err + } + + f.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (f *Float64) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + f.Valid = false + return nil + } + var err error + f.Float64, err = strconv.ParseFloat(string(text), 64) + f.Valid = err == nil + return err +} + +// MarshalJSON implements json.Marshaler. +func (f Float64) MarshalJSON() ([]byte, error) { + if !f.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatFloat(f.Float64, 'f', -1, 64)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (f Float64) MarshalText() ([]byte, error) { + if !f.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatFloat(f.Float64, 'f', -1, 64)), nil +} + +// SetValid changes this Float64's value and also sets it to be non-null. +func (f *Float64) SetValid(n float64) { + f.Float64 = n + f.Valid = true +} + +// Ptr returns a pointer to this Float64's value, or a nil pointer if this Float64 is null. +func (f Float64) Ptr() *float64 { + if !f.Valid { + return nil + } + return &f.Float64 +} + +// IsNull returns true for invalid Float64s, for future omitempty support (Go 1.4?) +func (f Float64) IsNull() bool { + return !f.Valid +} + +// Scan implements the Scanner interface. +func (f *Float64) Scan(value interface{}) error { + if value == nil { + f.Float64, f.Valid = 0, false + return nil + } + f.Valid = true + return convert.ConvertAssign(&f.Float64, value) +} + +// Value implements the driver Valuer interface. +func (f Float64) Value() (driver.Value, error) { + if !f.Valid { + return nil, nil + } + return f.Float64, nil +} diff --git a/null/float64_test.go b/null/float64_test.go new file mode 100644 index 0000000..a04f479 --- /dev/null +++ b/null/float64_test.go @@ -0,0 +1,164 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + float64JSON = []byte(`1.2345`) +) + +func TestFloat64From(t *testing.T) { + f := Float64From(1.2345) + assertFloat64(t, f, "Float64From()") + + zero := Float64From(0) + if !zero.Valid { + t.Error("Float64From(0)", "is invalid, but should be valid") + } +} + +func TestFloat64FromPtr(t *testing.T) { + n := float64(1.2345) + iptr := &n + f := Float64FromPtr(iptr) + assertFloat64(t, f, "Float64FromPtr()") + + null := Float64FromPtr(nil) + assertNullFloat64(t, null, "Float64FromPtr(nil)") +} + +func TestUnmarshalFloat64(t *testing.T) { + var f Float64 + err := json.Unmarshal(float64JSON, &f) + maybePanic(err) + assertFloat64(t, f, "float64 json") + + var null Float64 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullFloat64(t, null, "null json") + + var badType Float64 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullFloat64(t, badType, "wrong type json") + + var invalid Float64 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } +} + +func TestTextUnmarshalFloat64(t *testing.T) { + var f Float64 + err := f.UnmarshalText([]byte("1.2345")) + maybePanic(err) + assertFloat64(t, f, "UnmarshalText() float64") + + var blank Float64 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullFloat64(t, blank, "UnmarshalText() empty float64") +} + +func TestMarshalFloat64(t *testing.T) { + f := Float64From(1.2345) + data, err := json.Marshal(f) + maybePanic(err) + assertJSONEquals(t, data, "1.2345", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewFloat64(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalFloat64Text(t *testing.T) { + f := Float64From(1.2345) + data, err := f.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "1.2345", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewFloat64(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestFloat64Pointer(t *testing.T) { + f := Float64From(1.2345) + ptr := f.Ptr() + if *ptr != 1.2345 { + t.Errorf("bad %s float64: %#v ≠ %v\n", "pointer", ptr, 1.2345) + } + + null := NewFloat64(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s float64: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestFloat64IsNull(t *testing.T) { + f := Float64From(1.2345) + if f.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewFloat64(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewFloat64(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestFloat64SetValid(t *testing.T) { + change := NewFloat64(0, false) + assertNullFloat64(t, change, "SetValid()") + change.SetValid(1.2345) + assertFloat64(t, change, "SetValid()") +} + +func TestFloat64Scan(t *testing.T) { + var f Float64 + err := f.Scan(1.2345) + maybePanic(err) + assertFloat64(t, f, "scanned float64") + + var null Float64 + err = null.Scan(nil) + maybePanic(err) + assertNullFloat64(t, null, "scanned null") +} + +func assertFloat64(t *testing.T, f Float64, from string) { + if f.Float64 != 1.2345 { + t.Errorf("bad %s float64: %f ≠ %f\n", from, f.Float64, 1.2345) + } + if !f.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullFloat64(t *testing.T, f Float64, from string) { + if f.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/int.go b/null/int.go new file mode 100644 index 0000000..f0cd8c9 --- /dev/null +++ b/null/int.go @@ -0,0 +1,123 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Int is an nullable int. +type Int struct { + Int int + Valid bool +} + +// NewInt creates a new Int +func NewInt(i int, valid bool) Int { + return Int{ + Int: i, + Valid: valid, + } +} + +// IntFrom creates a new Int that will always be valid. +func IntFrom(i int) Int { + return NewInt(i, true) +} + +// IntFromPtr creates a new Int that be null if i is nil. +func IntFromPtr(i *int) Int { + if i == nil { + return NewInt(0, false) + } + return NewInt(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (i *Int) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + i.Valid = false + i.Int = 0 + return nil + } + + var x int64 + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + i.Int = int(x) + i.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (i *Int) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + i.Valid = false + return nil + } + var err error + res, err := strconv.ParseInt(string(text), 10, 0) + i.Valid = err == nil + if i.Valid { + i.Int = int(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (i Int) MarshalJSON() ([]byte, error) { + if !i.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatInt(int64(i.Int), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i Int) MarshalText() ([]byte, error) { + if !i.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(int64(i.Int), 10)), nil +} + +// SetValid changes this Int's value and also sets it to be non-null. +func (i *Int) SetValid(n int) { + i.Int = n + i.Valid = true +} + +// Ptr returns a pointer to this Int's value, or a nil pointer if this Int is null. +func (i Int) Ptr() *int { + if !i.Valid { + return nil + } + return &i.Int +} + +// IsNull returns true for invalid Ints, for future omitempty support (Go 1.4?) +func (i Int) IsNull() bool { + return !i.Valid +} + +// Scan implements the Scanner interface. +func (i *Int) Scan(value interface{}) error { + if value == nil { + i.Int, i.Valid = 0, false + return nil + } + i.Valid = true + return convert.ConvertAssign(&i.Int, value) +} + +// Value implements the driver Valuer interface. +func (i Int) Value() (driver.Value, error) { + if !i.Valid { + return nil, nil + } + return int64(i.Int), nil +} diff --git a/null/int16.go b/null/int16.go new file mode 100644 index 0000000..e89c42a --- /dev/null +++ b/null/int16.go @@ -0,0 +1,129 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "fmt" + "math" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Int16 is an nullable int16. +type Int16 struct { + Int16 int16 + Valid bool +} + +// NewInt16 creates a new Int16 +func NewInt16(i int16, valid bool) Int16 { + return Int16{ + Int16: i, + Valid: valid, + } +} + +// Int16From creates a new Int16 that will always be valid. +func Int16From(i int16) Int16 { + return NewInt16(i, true) +} + +// Int16FromPtr creates a new Int16 that be null if i is nil. +func Int16FromPtr(i *int16) Int16 { + if i == nil { + return NewInt16(0, false) + } + return NewInt16(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (i *Int16) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + i.Valid = false + i.Int16 = 0 + return nil + } + + var x int64 + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + if x > math.MaxInt16 { + return fmt.Errorf("json: %d overflows max int16 value", x) + } + + i.Int16 = int16(x) + i.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (i *Int16) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + i.Valid = false + return nil + } + var err error + res, err := strconv.ParseInt(string(text), 10, 16) + i.Valid = err == nil + if i.Valid { + i.Int16 = int16(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (i Int16) MarshalJSON() ([]byte, error) { + if !i.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatInt(int64(i.Int16), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i Int16) MarshalText() ([]byte, error) { + if !i.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(int64(i.Int16), 10)), nil +} + +// SetValid changes this Int16's value and also sets it to be non-null. +func (i *Int16) SetValid(n int16) { + i.Int16 = n + i.Valid = true +} + +// Ptr returns a pointer to this Int16's value, or a nil pointer if this Int16 is null. +func (i Int16) Ptr() *int16 { + if !i.Valid { + return nil + } + return &i.Int16 +} + +// IsNull returns true for invalid Int16's, for future omitempty support (Go 1.4?) +func (i Int16) IsNull() bool { + return !i.Valid +} + +// Scan implements the Scanner interface. +func (i *Int16) Scan(value interface{}) error { + if value == nil { + i.Int16, i.Valid = 0, false + return nil + } + i.Valid = true + return convert.ConvertAssign(&i.Int16, value) +} + +// Value implements the driver Valuer interface. +func (i Int16) Value() (driver.Value, error) { + if !i.Valid { + return nil, nil + } + return int64(i.Int16), nil +} diff --git a/null/int16_test.go b/null/int16_test.go new file mode 100644 index 0000000..ad2b8a5 --- /dev/null +++ b/null/int16_test.go @@ -0,0 +1,190 @@ +package null + +import ( + "encoding/json" + "math" + "strconv" + "testing" +) + +var ( + int16JSON = []byte(`32766`) +) + +func TestInt16From(t *testing.T) { + i := Int16From(32766) + assertInt16(t, i, "Int16From()") + + zero := Int16From(0) + if !zero.Valid { + t.Error("Int16From(0)", "is invalid, but should be valid") + } +} + +func TestInt16FromPtr(t *testing.T) { + n := int16(32766) + iptr := &n + i := Int16FromPtr(iptr) + assertInt16(t, i, "Int16FromPtr()") + + null := Int16FromPtr(nil) + assertNullInt16(t, null, "Int16FromPtr(nil)") +} + +func TestUnmarshalInt16(t *testing.T) { + var i Int16 + err := json.Unmarshal(int16JSON, &i) + maybePanic(err) + assertInt16(t, i, "int16 json") + + var null Int16 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullInt16(t, null, "null json") + + var badType Int16 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullInt16(t, badType, "wrong type json") + + var invalid Int16 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullInt16(t, invalid, "invalid json") +} + +func TestUnmarshalNonIntegerNumber16(t *testing.T) { + var i Int16 + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to int16") + } +} + +func TestUnmarshalInt16Overflow(t *testing.T) { + int16Overflow := uint16(math.MaxInt16) + + // Max int16 should decode successfully + var i Int16 + err := json.Unmarshal([]byte(strconv.FormatUint(uint64(int16Overflow), 10)), &i) + maybePanic(err) + // Attempt to overflow + int16Overflow++ + err = json.Unmarshal([]byte(strconv.FormatUint(uint64(int16Overflow), 10)), &i) + if err == nil { + panic("err should be present; decoded value overflows int16") + } +} + +func TestTextUnmarshalInt16(t *testing.T) { + var i Int16 + err := i.UnmarshalText([]byte("32766")) + maybePanic(err) + assertInt16(t, i, "UnmarshalText() int16") + + var blank Int16 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullInt16(t, blank, "UnmarshalText() empty int16") +} + +func TestMarshalInt16(t *testing.T) { + i := Int16From(32766) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "32766", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewInt16(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalInt16Text(t *testing.T) { + i := Int16From(32766) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "32766", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewInt16(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestInt16Pointer(t *testing.T) { + i := Int16From(32766) + ptr := i.Ptr() + if *ptr != 32766 { + t.Errorf("bad %s int16: %#v ≠ %d\n", "pointer", ptr, 32766) + } + + null := NewInt16(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s int16: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestInt16IsNull(t *testing.T) { + i := Int16From(32766) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewInt16(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewInt16(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestInt16SetValid(t *testing.T) { + change := NewInt16(0, false) + assertNullInt16(t, change, "SetValid()") + change.SetValid(32766) + assertInt16(t, change, "SetValid()") +} + +func TestInt16Scan(t *testing.T) { + var i Int16 + err := i.Scan(32766) + maybePanic(err) + assertInt16(t, i, "scanned int16") + + var null Int16 + err = null.Scan(nil) + maybePanic(err) + assertNullInt16(t, null, "scanned null") +} + +func assertInt16(t *testing.T, i Int16, from string) { + if i.Int16 != 32766 { + t.Errorf("bad %s int16: %d ≠ %d\n", from, i.Int16, 32766) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullInt16(t *testing.T, i Int16, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/int32.go b/null/int32.go new file mode 100644 index 0000000..00f23a3 --- /dev/null +++ b/null/int32.go @@ -0,0 +1,129 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "fmt" + "math" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Int32 is an nullable int32. +type Int32 struct { + Int32 int32 + Valid bool +} + +// NewInt32 creates a new Int32 +func NewInt32(i int32, valid bool) Int32 { + return Int32{ + Int32: i, + Valid: valid, + } +} + +// Int32From creates a new Int32 that will always be valid. +func Int32From(i int32) Int32 { + return NewInt32(i, true) +} + +// Int32FromPtr creates a new Int32 that be null if i is nil. +func Int32FromPtr(i *int32) Int32 { + if i == nil { + return NewInt32(0, false) + } + return NewInt32(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (i *Int32) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + i.Valid = false + i.Int32 = 0 + return nil + } + + var x int64 + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + if x > math.MaxInt32 { + return fmt.Errorf("json: %d overflows max int32 value", x) + } + + i.Int32 = int32(x) + i.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (i *Int32) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + i.Valid = false + return nil + } + var err error + res, err := strconv.ParseInt(string(text), 10, 32) + i.Valid = err == nil + if i.Valid { + i.Int32 = int32(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (i Int32) MarshalJSON() ([]byte, error) { + if !i.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatInt(int64(i.Int32), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i Int32) MarshalText() ([]byte, error) { + if !i.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(int64(i.Int32), 10)), nil +} + +// SetValid changes this Int32's value and also sets it to be non-null. +func (i *Int32) SetValid(n int32) { + i.Int32 = n + i.Valid = true +} + +// Ptr returns a pointer to this Int32's value, or a nil pointer if this Int32 is null. +func (i Int32) Ptr() *int32 { + if !i.Valid { + return nil + } + return &i.Int32 +} + +// IsNull returns true for invalid Int32's, for future omitempty support (Go 1.4?) +func (i Int32) IsNull() bool { + return !i.Valid +} + +// Scan implements the Scanner interface. +func (i *Int32) Scan(value interface{}) error { + if value == nil { + i.Int32, i.Valid = 0, false + return nil + } + i.Valid = true + return convert.ConvertAssign(&i.Int32, value) +} + +// Value implements the driver Valuer interface. +func (i Int32) Value() (driver.Value, error) { + if !i.Valid { + return nil, nil + } + return int64(i.Int32), nil +} diff --git a/null/int32_test.go b/null/int32_test.go new file mode 100644 index 0000000..ebf6b9a --- /dev/null +++ b/null/int32_test.go @@ -0,0 +1,191 @@ +package null + +import ( + "encoding/json" + "math" + "strconv" + "testing" +) + +var ( + int32JSON = []byte(`2147483646`) +) + +func TestInt32From(t *testing.T) { + i := Int32From(2147483646) + assertInt32(t, i, "Int32From()") + + zero := Int32From(0) + if !zero.Valid { + t.Error("Int32From(0)", "is invalid, but should be valid") + } +} + +func TestInt32FromPtr(t *testing.T) { + n := int32(2147483646) + iptr := &n + i := Int32FromPtr(iptr) + assertInt32(t, i, "Int32FromPtr()") + + null := Int32FromPtr(nil) + assertNullInt32(t, null, "Int32FromPtr(nil)") +} + +func TestUnmarshalInt32(t *testing.T) { + var i Int32 + err := json.Unmarshal(int32JSON, &i) + maybePanic(err) + assertInt32(t, i, "int32 json") + + var null Int32 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullInt32(t, null, "null json") + + var badType Int32 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullInt32(t, badType, "wrong type json") + + var invalid Int32 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullInt32(t, invalid, "invalid json") +} + +func TestUnmarshalNonIntegerNumber32(t *testing.T) { + var i Int32 + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to int32") + } +} + +func TestUnmarshalInt32Overflow(t *testing.T) { + int32Overflow := uint32(math.MaxInt32) + + // Max int32 should decode successfully + var i Int32 + err := json.Unmarshal([]byte(strconv.FormatUint(uint64(int32Overflow), 10)), &i) + maybePanic(err) + + // Attempt to overflow + int32Overflow++ + err = json.Unmarshal([]byte(strconv.FormatUint(uint64(int32Overflow), 10)), &i) + if err == nil { + panic("err should be present; decoded value overflows int32") + } +} + +func TestTextUnmarshalInt32(t *testing.T) { + var i Int32 + err := i.UnmarshalText([]byte("2147483646")) + maybePanic(err) + assertInt32(t, i, "UnmarshalText() int32") + + var blank Int32 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullInt32(t, blank, "UnmarshalText() empty int32") +} + +func TestMarshalInt32(t *testing.T) { + i := Int32From(2147483646) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "2147483646", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewInt32(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalInt32Text(t *testing.T) { + i := Int32From(2147483646) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "2147483646", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewInt32(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestInt32Pointer(t *testing.T) { + i := Int32From(2147483646) + ptr := i.Ptr() + if *ptr != 2147483646 { + t.Errorf("bad %s int32: %#v ≠ %d\n", "pointer", ptr, 2147483646) + } + + null := NewInt32(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s int32: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestInt32IsNull(t *testing.T) { + i := Int32From(2147483646) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewInt32(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewInt32(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestInt32SetValid(t *testing.T) { + change := NewInt32(0, false) + assertNullInt32(t, change, "SetValid()") + change.SetValid(2147483646) + assertInt32(t, change, "SetValid()") +} + +func TestInt32Scan(t *testing.T) { + var i Int32 + err := i.Scan(2147483646) + maybePanic(err) + assertInt32(t, i, "scanned int32") + + var null Int32 + err = null.Scan(nil) + maybePanic(err) + assertNullInt32(t, null, "scanned null") +} + +func assertInt32(t *testing.T, i Int32, from string) { + if i.Int32 != 2147483646 { + t.Errorf("bad %s int32: %d ≠ %d\n", from, i.Int32, 2147483646) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullInt32(t *testing.T, i Int32, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/int64.go b/null/int64.go new file mode 100644 index 0000000..c1c463e --- /dev/null +++ b/null/int64.go @@ -0,0 +1,118 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Int64 is an nullable int64. +type Int64 struct { + Int64 int64 + Valid bool +} + +// NewInt64 creates a new Int64 +func NewInt64(i int64, valid bool) Int64 { + return Int64{ + Int64: i, + Valid: valid, + } +} + +// Int64From creates a new Int64 that will always be valid. +func Int64From(i int64) Int64 { + return NewInt64(i, true) +} + +// Int64FromPtr creates a new Int64 that be null if i is nil. +func Int64FromPtr(i *int64) Int64 { + if i == nil { + return NewInt64(0, false) + } + return NewInt64(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (i *Int64) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + i.Valid = false + i.Int64 = 0 + return nil + } + + if err := json.Unmarshal(data, &i.Int64); err != nil { + return err + } + + i.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (i *Int64) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + i.Valid = false + return nil + } + var err error + i.Int64, err = strconv.ParseInt(string(text), 10, 64) + i.Valid = err == nil + return err +} + +// MarshalJSON implements json.Marshaler. +func (i Int64) MarshalJSON() ([]byte, error) { + if !i.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatInt(i.Int64, 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i Int64) MarshalText() ([]byte, error) { + if !i.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(i.Int64, 10)), nil +} + +// SetValid changes this Int64's value and also sets it to be non-null. +func (i *Int64) SetValid(n int64) { + i.Int64 = n + i.Valid = true +} + +// Ptr returns a pointer to this Int64's value, or a nil pointer if this Int64 is null. +func (i Int64) Ptr() *int64 { + if !i.Valid { + return nil + } + return &i.Int64 +} + +// IsNull returns true for invalid Int64's, for future omitempty support (Go 1.4?) +func (i Int64) IsNull() bool { + return !i.Valid +} + +// Scan implements the Scanner interface. +func (i *Int64) Scan(value interface{}) error { + if value == nil { + i.Int64, i.Valid = 0, false + return nil + } + i.Valid = true + return convert.ConvertAssign(&i.Int64, value) +} + +// Value implements the driver Valuer interface. +func (i Int64) Value() (driver.Value, error) { + if !i.Valid { + return nil, nil + } + return i.Int64, nil +} diff --git a/null/int64_test.go b/null/int64_test.go new file mode 100644 index 0000000..4a31acb --- /dev/null +++ b/null/int64_test.go @@ -0,0 +1,191 @@ +package null + +import ( + "encoding/json" + "math" + "strconv" + "testing" +) + +var ( + int64JSON = []byte(`9223372036854775806`) +) + +func TestInt64From(t *testing.T) { + i := Int64From(9223372036854775806) + assertInt64(t, i, "Int64From()") + + zero := Int64From(0) + if !zero.Valid { + t.Error("Int64From(0)", "is invalid, but should be valid") + } +} + +func TestInt64FromPtr(t *testing.T) { + n := int64(9223372036854775806) + iptr := &n + i := Int64FromPtr(iptr) + assertInt64(t, i, "Int64FromPtr()") + + null := Int64FromPtr(nil) + assertNullInt64(t, null, "Int64FromPtr(nil)") +} + +func TestUnmarshalInt64(t *testing.T) { + var i Int64 + err := json.Unmarshal(int64JSON, &i) + maybePanic(err) + assertInt64(t, i, "int64 json") + + var null Int64 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullInt64(t, null, "null json") + + var badType Int64 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullInt64(t, badType, "wrong type json") + + var invalid Int64 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullInt64(t, invalid, "invalid json") +} + +func TestUnmarshalNonIntegerNumber64(t *testing.T) { + var i Int64 + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to int64") + } +} + +func TestUnmarshalInt64Overflow(t *testing.T) { + int64Overflow := uint64(math.MaxInt64) + + // Max int64 should decode successfully + var i Int64 + err := json.Unmarshal([]byte(strconv.FormatUint(uint64(int64Overflow), 10)), &i) + maybePanic(err) + + // Attempt to overflow + int64Overflow++ + err = json.Unmarshal([]byte(strconv.FormatUint(uint64(int64Overflow), 10)), &i) + if err == nil { + panic("err should be present; decoded value overflows int64") + } +} + +func TestTextUnmarshalInt64(t *testing.T) { + var i Int64 + err := i.UnmarshalText([]byte("9223372036854775806")) + maybePanic(err) + assertInt64(t, i, "UnmarshalText() int64") + + var blank Int64 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullInt64(t, blank, "UnmarshalText() empty int64") +} + +func TestMarshalInt64(t *testing.T) { + i := Int64From(9223372036854775806) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "9223372036854775806", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewInt64(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalInt64Text(t *testing.T) { + i := Int64From(9223372036854775806) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "9223372036854775806", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewInt64(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestInt64Pointer(t *testing.T) { + i := Int64From(9223372036854775806) + ptr := i.Ptr() + if *ptr != 9223372036854775806 { + t.Errorf("bad %s int64: %#v ≠ %d\n", "pointer", ptr, 9223372036854775806) + } + + null := NewInt64(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s int64: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestInt64IsNull(t *testing.T) { + i := Int64From(9223372036854775806) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewInt64(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewInt64(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestInt64SetValid(t *testing.T) { + change := NewInt64(0, false) + assertNullInt64(t, change, "SetValid()") + change.SetValid(9223372036854775806) + assertInt64(t, change, "SetValid()") +} + +func TestInt64Scan(t *testing.T) { + var i Int64 + err := i.Scan(9223372036854775806) + maybePanic(err) + assertInt64(t, i, "scanned int64") + + var null Int64 + err = null.Scan(nil) + maybePanic(err) + assertNullInt64(t, null, "scanned null") +} + +func assertInt64(t *testing.T, i Int64, from string) { + if i.Int64 != 9223372036854775806 { + t.Errorf("bad %s int64: %d ≠ %d\n", from, i.Int64, 9223372036854775806) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullInt64(t *testing.T, i Int64, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/int8.go b/null/int8.go new file mode 100644 index 0000000..0cffba0 --- /dev/null +++ b/null/int8.go @@ -0,0 +1,129 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "fmt" + "math" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Int8 is an nullable int8. +type Int8 struct { + Int8 int8 + Valid bool +} + +// NewInt8 creates a new Int8 +func NewInt8(i int8, valid bool) Int8 { + return Int8{ + Int8: i, + Valid: valid, + } +} + +// Int8From creates a new Int8 that will always be valid. +func Int8From(i int8) Int8 { + return NewInt8(i, true) +} + +// Int8FromPtr creates a new Int8 that be null if i is nil. +func Int8FromPtr(i *int8) Int8 { + if i == nil { + return NewInt8(0, false) + } + return NewInt8(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (i *Int8) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + i.Valid = false + i.Int8 = 0 + return nil + } + + var x int64 + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + if x > math.MaxInt8 { + return fmt.Errorf("json: %d overflows max int8 value", x) + } + + i.Int8 = int8(x) + i.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (i *Int8) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + i.Valid = false + return nil + } + var err error + res, err := strconv.ParseInt(string(text), 10, 8) + i.Valid = err == nil + if i.Valid { + i.Int8 = int8(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (i Int8) MarshalJSON() ([]byte, error) { + if !i.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatInt(int64(i.Int8), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i Int8) MarshalText() ([]byte, error) { + if !i.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(int64(i.Int8), 10)), nil +} + +// SetValid changes this Int8's value and also sets it to be non-null. +func (i *Int8) SetValid(n int8) { + i.Int8 = n + i.Valid = true +} + +// Ptr returns a pointer to this Int8's value, or a nil pointer if this Int8 is null. +func (i Int8) Ptr() *int8 { + if !i.Valid { + return nil + } + return &i.Int8 +} + +// IsNull returns true for invalid Int8's, for future omitempty support (Go 1.4?) +func (i Int8) IsNull() bool { + return !i.Valid +} + +// Scan implements the Scanner interface. +func (i *Int8) Scan(value interface{}) error { + if value == nil { + i.Int8, i.Valid = 0, false + return nil + } + i.Valid = true + return convert.ConvertAssign(&i.Int8, value) +} + +// Value implements the driver Valuer interface. +func (i Int8) Value() (driver.Value, error) { + if !i.Valid { + return nil, nil + } + return int64(i.Int8), nil +} diff --git a/null/int8_test.go b/null/int8_test.go new file mode 100644 index 0000000..520defb --- /dev/null +++ b/null/int8_test.go @@ -0,0 +1,191 @@ +package null + +import ( + "encoding/json" + "math" + "strconv" + "testing" +) + +var ( + int8JSON = []byte(`126`) +) + +func TestInt8From(t *testing.T) { + i := Int8From(126) + assertInt8(t, i, "Int8From()") + + zero := Int8From(0) + if !zero.Valid { + t.Error("Int8From(0)", "is invalid, but should be valid") + } +} + +func TestInt8FromPtr(t *testing.T) { + n := int8(126) + iptr := &n + i := Int8FromPtr(iptr) + assertInt8(t, i, "Int8FromPtr()") + + null := Int8FromPtr(nil) + assertNullInt8(t, null, "Int8FromPtr(nil)") +} + +func TestUnmarshalInt8(t *testing.T) { + var i Int8 + err := json.Unmarshal(int8JSON, &i) + maybePanic(err) + assertInt8(t, i, "int8 json") + + var null Int8 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullInt8(t, null, "null json") + + var badType Int8 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullInt8(t, badType, "wrong type json") + + var invalid Int8 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullInt8(t, invalid, "invalid json") +} + +func TestUnmarshalNonIntegerNumber8(t *testing.T) { + var i Int8 + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to int8") + } +} + +func TestUnmarshalInt8Overflow(t *testing.T) { + int8Overflow := uint8(math.MaxInt8) + + // Max int8 should decode successfully + var i Int8 + err := json.Unmarshal([]byte(strconv.FormatUint(uint64(int8Overflow), 10)), &i) + maybePanic(err) + + // Attempt to overflow + int8Overflow++ + err = json.Unmarshal([]byte(strconv.FormatUint(uint64(int8Overflow), 10)), &i) + if err == nil { + panic("err should be present; decoded value overflows int8") + } +} + +func TestTextUnmarshalInt8(t *testing.T) { + var i Int8 + err := i.UnmarshalText([]byte("126")) + maybePanic(err) + assertInt8(t, i, "UnmarshalText() int8") + + var blank Int8 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullInt8(t, blank, "UnmarshalText() empty int8") +} + +func TestMarshalInt8(t *testing.T) { + i := Int8From(126) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "126", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewInt8(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalInt8Text(t *testing.T) { + i := Int8From(126) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "126", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewInt8(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestInt8Pointer(t *testing.T) { + i := Int8From(126) + ptr := i.Ptr() + if *ptr != 126 { + t.Errorf("bad %s int8: %#v ≠ %d\n", "pointer", ptr, 126) + } + + null := NewInt8(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s int8: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestInt8IsNull(t *testing.T) { + i := Int8From(126) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewInt8(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewInt8(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestInt8SetValid(t *testing.T) { + change := NewInt8(0, false) + assertNullInt8(t, change, "SetValid()") + change.SetValid(126) + assertInt8(t, change, "SetValid()") +} + +func TestInt8Scan(t *testing.T) { + var i Int8 + err := i.Scan(126) + maybePanic(err) + assertInt8(t, i, "scanned int8") + + var null Int8 + err = null.Scan(nil) + maybePanic(err) + assertNullInt8(t, null, "scanned null") +} + +func assertInt8(t *testing.T, i Int8, from string) { + if i.Int8 != 126 { + t.Errorf("bad %s int8: %d ≠ %d\n", from, i.Int8, 126) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullInt8(t *testing.T, i Int8, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/int_test.go b/null/int_test.go new file mode 100644 index 0000000..6a94f15 --- /dev/null +++ b/null/int_test.go @@ -0,0 +1,173 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + intJSON = []byte(`12345`) +) + +func TestIntFrom(t *testing.T) { + i := IntFrom(12345) + assertInt(t, i, "IntFrom()") + + zero := IntFrom(0) + if !zero.Valid { + t.Error("IntFrom(0)", "is invalid, but should be valid") + } +} + +func TestIntFromPtr(t *testing.T) { + n := int(12345) + iptr := &n + i := IntFromPtr(iptr) + assertInt(t, i, "IntFromPtr()") + + null := IntFromPtr(nil) + assertNullInt(t, null, "IntFromPtr(nil)") +} + +func TestUnmarshalInt(t *testing.T) { + var i Int + err := json.Unmarshal(intJSON, &i) + maybePanic(err) + assertInt(t, i, "int json") + + var null Int + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullInt(t, null, "null json") + + var badType Int + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullInt(t, badType, "wrong type json") + + var invalid Int + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullInt(t, invalid, "invalid json") +} + +func TestUnmarshalNonIntegerNumber(t *testing.T) { + var i Int + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to int") + } +} + +func TestTextUnmarshalInt(t *testing.T) { + var i Int + err := i.UnmarshalText([]byte("12345")) + maybePanic(err) + assertInt(t, i, "UnmarshalText() int") + + var blank Int + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullInt(t, blank, "UnmarshalText() empty int") +} + +func TestMarshalInt(t *testing.T) { + i := IntFrom(12345) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "12345", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewInt(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalIntText(t *testing.T) { + i := IntFrom(12345) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "12345", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewInt(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestIntPointer(t *testing.T) { + i := IntFrom(12345) + ptr := i.Ptr() + if *ptr != 12345 { + t.Errorf("bad %s int: %#v ≠ %d\n", "pointer", ptr, 12345) + } + + null := NewInt(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s int: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestIntIsNull(t *testing.T) { + i := IntFrom(12345) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewInt(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewInt(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestIntSetValid(t *testing.T) { + change := NewInt(0, false) + assertNullInt(t, change, "SetValid()") + change.SetValid(12345) + assertInt(t, change, "SetValid()") +} + +func TestIntScan(t *testing.T) { + var i Int + err := i.Scan(12345) + maybePanic(err) + assertInt(t, i, "scanned int") + + var null Int + err = null.Scan(nil) + maybePanic(err) + assertNullInt(t, null, "scanned null") +} + +func assertInt(t *testing.T, i Int, from string) { + if i.Int != 12345 { + t.Errorf("bad %s int: %d ≠ %d\n", from, i.Int, 12345) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullInt(t *testing.T, i Int, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/json.go b/null/json.go new file mode 100644 index 0000000..f5668c9 --- /dev/null +++ b/null/json.go @@ -0,0 +1,157 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + + "gopkg.in/nullbio/null.v6/convert" +) + +// JSON is a nullable []byte. +type JSON struct { + JSON []byte + Valid bool +} + +// NewJSON creates a new JSON +func NewJSON(b []byte, valid bool) JSON { + return JSON{ + JSON: b, + Valid: valid, + } +} + +// JSONFrom creates a new JSON that will be invalid if nil. +func JSONFrom(b []byte) JSON { + return NewJSON(b, b != nil) +} + +// JSONFromPtr creates a new JSON that will be invalid if nil. +func JSONFromPtr(b *[]byte) JSON { + if b == nil { + return NewJSON(nil, false) + } + n := NewJSON(*b, true) + return n +} + +// Unmarshal will unmarshal your JSON stored in +// your JSON object and store the result in the +// value pointed to by dest. +func (j JSON) Unmarshal(dest interface{}) error { + if dest == nil { + return errors.New("destination is nil, not a valid pointer to an object") + } + + // Call our implementation of + // JSON MarshalJSON through json.Marshal + // to get the value of the JSON object + res, err := json.Marshal(j) + if err != nil { + return err + } + + return json.Unmarshal(res, dest) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (j *JSON) UnmarshalJSON(data []byte) error { + if data == nil { + return fmt.Errorf("json: cannot unmarshal nil into Go value of type null.JSON") + } + + if bytes.Equal(data, NullBytes) { + j.JSON = NullBytes + j.Valid = false + return nil + } + + j.Valid = true + j.JSON = make([]byte, len(data)) + copy(j.JSON, data) + + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (j *JSON) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + j.JSON = nil + j.Valid = false + } else { + j.JSON = append(j.JSON[0:0], text...) + j.Valid = true + } + + return nil +} + +// Marshal will marshal the passed in object, +// and store it in the JSON member on the JSON object. +func (j *JSON) Marshal(obj interface{}) error { + res, err := json.Marshal(obj) + if err != nil { + return err + } + + // Call our implementation of + // JSON UnmarshalJSON through json.Unmarshal + // to set the result to the JSON object + return json.Unmarshal(res, j) +} + +// MarshalJSON implements json.Marshaler. +func (j JSON) MarshalJSON() ([]byte, error) { + if len(j.JSON) == 0 || j.JSON == nil { + return NullBytes, nil + } + return j.JSON, nil +} + +// MarshalText implements encoding.TextMarshaler. +func (j JSON) MarshalText() ([]byte, error) { + if !j.Valid { + return nil, nil + } + return j.JSON, nil +} + +// SetValid changes this JSON's value and also sets it to be non-null. +func (j *JSON) SetValid(n []byte) { + j.JSON = n + j.Valid = true +} + +// Ptr returns a pointer to this JSON's value, or a nil pointer if this JSON is null. +func (j JSON) Ptr() *[]byte { + if !j.Valid { + return nil + } + return &j.JSON +} + +// IsNull returns true for null or zero JSON's, for future omitempty support (Go 1.4?) +func (j JSON) IsNull() bool { + return !j.Valid +} + +// Scan implements the Scanner interface. +func (j *JSON) Scan(value interface{}) error { + if value == nil { + j.JSON, j.Valid = []byte{}, false + return nil + } + j.Valid = true + return convert.ConvertAssign(&j.JSON, value) +} + +// Value implements the driver Valuer interface. +func (j JSON) Value() (driver.Value, error) { + if !j.Valid { + return nil, nil + } + return j.JSON, nil +} diff --git a/null/json_test.go b/null/json_test.go new file mode 100644 index 0000000..41d759f --- /dev/null +++ b/null/json_test.go @@ -0,0 +1,238 @@ +package null + +import ( + "bytes" + "encoding/json" + "testing" +) + +var ( + jsonJSON = []byte(`"hello"`) +) + +func TestJSONFrom(t *testing.T) { + i := JSONFrom([]byte(`"hello"`)) + assertJSON(t, i, "JSONFrom()") + + zero := JSONFrom(nil) + if zero.Valid { + t.Error("JSONFrom(nil)", "is valid, but should be invalid") + } + + zero = JSONFrom([]byte{}) + if !zero.Valid { + t.Error("JSONFrom([]byte{})", "is invalid, but should be valid") + } +} + +func TestJSONFromPtr(t *testing.T) { + n := []byte(`"hello"`) + iptr := &n + i := JSONFromPtr(iptr) + assertJSON(t, i, "JSONFromPtr()") + + null := JSONFromPtr(nil) + assertNullJSON(t, null, "JSONFromPtr(nil)") +} + +type Test struct { + Name string + Age int +} + +func TestMarshal(t *testing.T) { + var i JSON + + test := &Test{Name: "hello", Age: 15} + + err := i.Marshal(test) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(i.JSON, []byte(`{"Name":"hello","Age":15}`)) { + t.Errorf("Mismatch between received and expected, got: %s", string(i.JSON)) + } + if i.Valid == false { + t.Error("Expected valid true, got Valid false") + } + + err = i.Marshal(nil) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(i.JSON, []byte("null")) { + t.Errorf("Expected null, but got %s", string(i.JSON)) + } + if i.Valid == true { + t.Error("Expected Valid false, got Valid true") + } +} + +func TestUnmarshal(t *testing.T) { + var i JSON + + test := &Test{} + + err := i.Unmarshal(test) + if err != nil { + t.Error(err) + } + + x := &Test{Name: "hello", Age: 15} + err = i.Marshal(x) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(i.JSON, []byte(`{"Name":"hello","Age":15}`)) { + t.Errorf("Mismatch between received and expected, got: %s", string(i.JSON)) + } + + err = i.Unmarshal(test) + if err != nil { + t.Error(err) + } + + if test.Age != 15 { + t.Errorf("Expected 15, got %d", test.Age) + } + if test.Name != "hello" { + t.Errorf("Expected name, got %s", test.Name) + } +} + +func TestUnmarshalJSON(t *testing.T) { + var i JSON + err := json.Unmarshal(jsonJSON, &i) + maybePanic(err) + assertJSON(t, i, "[]byte json") + + var ni JSON + err = ni.UnmarshalJSON([]byte{}) + if ni.Valid == false { + t.Errorf("expected Valid to be true, got false") + } + if !bytes.Equal(ni.JSON, nil) { + t.Errorf("Expected JSON to be nil, but was not: %#v %#v", ni.JSON, []byte(nil)) + } + + var null JSON + err = null.UnmarshalJSON(nil) + if ni.Valid == false { + t.Errorf("expected Valid to be true, got false") + } + if !bytes.Equal(null.JSON, nil) { + t.Errorf("Expected JSON to be []byte nil, but was not: %#v %#v", null.JSON, []byte(nil)) + } +} + +func TestTextUnmarshalJSON(t *testing.T) { + var i JSON + err := i.UnmarshalText([]byte(`"hello"`)) + maybePanic(err) + assertJSON(t, i, "UnmarshalText() []byte") + + var blank JSON + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullJSON(t, blank, "UnmarshalText() empty []byte") +} + +func TestMarshalJSON(t *testing.T) { + i := JSONFrom([]byte(`"hello"`)) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, `"hello"`, "non-empty json marshal") + + // invalid values should be encoded as null + null := NewJSON(nil, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalJSONText(t *testing.T) { + i := JSONFrom([]byte(`"hello"`)) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, `"hello"`, "non-empty text marshal") + + // invalid values should be encoded as null + null := NewJSON(nil, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestJSONPointer(t *testing.T) { + i := JSONFrom([]byte(`"hello"`)) + ptr := i.Ptr() + if !bytes.Equal(*ptr, []byte(`"hello"`)) { + t.Errorf("bad %s []byte: %#v ≠ %s\n", "pointer", ptr, `"hello"`) + } + + null := NewJSON(nil, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s []byte: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestJSONIsNull(t *testing.T) { + i := JSONFrom([]byte(`"hello"`)) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewJSON(nil, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewJSON(nil, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestJSONSetValid(t *testing.T) { + change := NewJSON(nil, false) + assertNullJSON(t, change, "SetValid()") + change.SetValid([]byte(`"hello"`)) + assertJSON(t, change, "SetValid()") +} + +func TestJSONScan(t *testing.T) { + var i JSON + err := i.Scan(`"hello"`) + maybePanic(err) + assertJSON(t, i, "scanned []byte") + + var null JSON + err = null.Scan(nil) + maybePanic(err) + assertNullJSON(t, null, "scanned null") +} + +func assertJSON(t *testing.T, i JSON, from string) { + if !bytes.Equal(i.JSON, []byte(`"hello"`)) { + t.Errorf("bad %s []byte: %#v ≠ %#v\n", from, string(i.JSON), string([]byte(`"hello"`))) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullJSON(t *testing.T, i JSON, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/nullable.go b/null/nullable.go new file mode 100644 index 0000000..2f105df --- /dev/null +++ b/null/nullable.go @@ -0,0 +1,5 @@ +package null + +type Nullable interface { + IsNull() bool +} diff --git a/null/string.go b/null/string.go new file mode 100644 index 0000000..fc9333a --- /dev/null +++ b/null/string.go @@ -0,0 +1,117 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + + "gopkg.in/nullbio/null.v6/convert" +) + +// String is a nullable string. It supports SQL and JSON serialization. +type String struct { + String string + Valid bool +} + +// StringFrom creates a new String that will never be blank. +func StringFrom(s string) String { + return NewString(s, true) +} + +// StringFromPtr creates a new String that be null if s is nil. +func StringFromPtr(s *string) String { + if s == nil { + return NewString("", false) + } + return NewString(*s, true) +} + +// NewString creates a new String +func NewString(s string, valid bool) String { + return String{ + String: s, + Valid: valid, + } +} + +// UnmarshalJSON implements json.Unmarshaler. +func (s *String) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + s.String = "" + s.Valid = false + return nil + } + + if err := json.Unmarshal(data, &s.String); err != nil { + return err + } + + s.Valid = true + return nil +} + +// MarshalJSON implements json.Marshaler. +func (s String) MarshalJSON() ([]byte, error) { + if !s.Valid { + return NullBytes, nil + } + return json.Marshal(s.String) +} + +// MarshalText implements encoding.TextMarshaler. +func (s String) MarshalText() ([]byte, error) { + if !s.Valid { + return []byte{}, nil + } + return []byte(s.String), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (s *String) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + s.Valid = false + return nil + } + + s.String = string(text) + s.Valid = true + return nil +} + +// SetValid changes this String's value and also sets it to be non-null. +func (s *String) SetValid(v string) { + s.String = v + s.Valid = true +} + +// Ptr returns a pointer to this String's value, or a nil pointer if this String is null. +func (s String) Ptr() *string { + if !s.Valid { + return nil + } + return &s.String +} + +// IsNull returns true for null strings, for potential future omitempty support. +func (s String) IsNull() bool { + return !s.Valid +} + +// Scan implements the Scanner interface. +func (s *String) Scan(value interface{}) error { + if value == nil { + s.String, s.Valid = "", false + return nil + } + s.Valid = true + return convert.ConvertAssign(&s.String, value) +} + +// Value implements the driver Valuer interface. +func (s String) Value() (driver.Value, error) { + if !s.Valid { + return nil, nil + } + return s.String, nil +} diff --git a/null/string_test.go b/null/string_test.go new file mode 100644 index 0000000..633e632 --- /dev/null +++ b/null/string_test.go @@ -0,0 +1,206 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + stringJSON = []byte(`"test"`) + blankStringJSON = []byte(`""`) + + nullJSON = []byte(`null`) + invalidJSON = []byte(`:)`) +) + +type stringInStruct struct { + Test String `json:"test,omitempty"` +} + +func TestStringFrom(t *testing.T) { + str := StringFrom("test") + assertStr(t, str, "StringFrom() string") + + zero := StringFrom("") + if !zero.Valid { + t.Error("StringFrom(0)", "is invalid, but should be valid") + } +} + +func TestStringFromPtr(t *testing.T) { + s := "test" + sptr := &s + str := StringFromPtr(sptr) + assertStr(t, str, "StringFromPtr() string") + + null := StringFromPtr(nil) + assertNullStr(t, null, "StringFromPtr(nil)") +} + +func TestUnmarshalString(t *testing.T) { + var str String + err := json.Unmarshal(stringJSON, &str) + maybePanic(err) + assertStr(t, str, "string json") + + var blank String + err = json.Unmarshal(blankStringJSON, &blank) + maybePanic(err) + if !blank.Valid { + t.Error("blank string should be valid") + } + + var null String + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullStr(t, null, "null json") + + var badType String + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullStr(t, badType, "wrong type json") + + var invalid String + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullStr(t, invalid, "invalid json") +} + +func TestTextUnmarshalString(t *testing.T) { + var str String + err := str.UnmarshalText([]byte("test")) + maybePanic(err) + assertStr(t, str, "UnmarshalText() string") + + var null String + err = null.UnmarshalText([]byte("")) + maybePanic(err) + assertNullStr(t, null, "UnmarshalText() empty string") +} + +func TestMarshalString(t *testing.T) { + str := StringFrom("test") + data, err := json.Marshal(str) + maybePanic(err) + assertJSONEquals(t, data, `"test"`, "non-empty json marshal") + data, err = str.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "test", "non-empty text marshal") + + // empty values should be encoded as an empty string + zero := StringFrom("") + data, err = json.Marshal(zero) + maybePanic(err) + assertJSONEquals(t, data, `""`, "empty json marshal") + data, err = zero.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "string marshal text") + + null := StringFromPtr(nil) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, `null`, "null json marshal") + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "string marshal text") +} + +// Tests omitempty... broken until Go 1.4 +// func TestMarshalStringInStruct(t *testing.T) { +// obj := stringInStruct{Test: StringFrom("")} +// data, err := json.Marshal(obj) +// maybePanic(err) +// assertJSONEquals(t, data, `{}`, "null string in struct") +// } + +func TestStringPointer(t *testing.T) { + str := StringFrom("test") + ptr := str.Ptr() + if *ptr != "test" { + t.Errorf("bad %s string: %#v ≠ %s\n", "pointer", ptr, "test") + } + + null := NewString("", false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s string: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestStringIsNull(t *testing.T) { + str := StringFrom("test") + if str.IsNull() { + t.Errorf("IsNull() should be false") + } + + blank := StringFrom("") + if blank.IsNull() { + t.Errorf("IsNull() should be false") + } + + empty := NewString("", true) + if empty.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := StringFromPtr(nil) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + var testInt interface{} + testInt = empty + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestStringSetValid(t *testing.T) { + change := NewString("", false) + assertNullStr(t, change, "SetValid()") + change.SetValid("test") + assertStr(t, change, "SetValid()") +} + +func TestStringScan(t *testing.T) { + var str String + err := str.Scan("test") + maybePanic(err) + assertStr(t, str, "scanned string") + + var null String + err = null.Scan(nil) + maybePanic(err) + assertNullStr(t, null, "scanned null") +} + +func maybePanic(err error) { + if err != nil { + panic(err) + } +} + +func assertStr(t *testing.T, s String, from string) { + if s.String != "test" { + t.Errorf("bad %s string: %s ≠ %s\n", from, s.String, "test") + } + if !s.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullStr(t *testing.T, s String, from string) { + if s.Valid { + t.Error(from, "is valid, but should be invalid") + } +} + +func assertJSONEquals(t *testing.T, data []byte, cmp string, from string) { + if string(data) != cmp { + t.Errorf("bad %s data: %s ≠ %s\n", from, data, cmp) + } +} diff --git a/null/time.go b/null/time.go new file mode 100644 index 0000000..91d05be --- /dev/null +++ b/null/time.go @@ -0,0 +1,123 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "fmt" + "time" +) + +// Time is a nullable time.Time. It supports SQL and JSON serialization. +type Time struct { + Time time.Time + Valid bool +} + +// NewTime creates a new Time. +func NewTime(t time.Time, valid bool) Time { + return Time{ + Time: t, + Valid: valid, + } +} + +// TimeFrom creates a new Time that will always be valid. +func TimeFrom(t time.Time) Time { + return NewTime(t, true) +} + +// TimeFromPtr creates a new Time that will be null if t is nil. +func TimeFromPtr(t *time.Time) Time { + if t == nil { + return NewTime(time.Time{}, false) + } + return NewTime(*t, true) +} + +// MarshalJSON implements json.Marshaler. +func (t Time) MarshalJSON() ([]byte, error) { + if !t.Valid { + return NullBytes, nil + } + return t.Time.MarshalJSON() +} + +// UnmarshalJSON implements json.Unmarshaler. +func (t *Time) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + t.Valid = false + t.Time = time.Time{} + return nil + } + + if err := t.Time.UnmarshalJSON(data); err != nil { + return err + } + + t.Valid = true + return nil +} + +// MarshalText implements encoding.TextMarshaler. +func (t Time) MarshalText() ([]byte, error) { + if !t.Valid { + return NullBytes, nil + } + return t.Time.MarshalText() +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (t *Time) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + t.Valid = false + return nil + } + if err := t.Time.UnmarshalText(text); err != nil { + return err + } + t.Valid = true + return nil +} + +// SetValid changes this Time's value and sets it to be non-null. +func (t *Time) SetValid(v time.Time) { + t.Time = v + t.Valid = true +} + +// Ptr returns a pointer to this Time's value, or a nil pointer if this Time is null. +func (t Time) Ptr() *time.Time { + if !t.Valid { + return nil + } + return &t.Time +} + +// IsNull returns true for invalid Times, for future omitempty support (Go 1.4?) +func (t Time) IsNull() bool { + return !t.Valid +} + +// Scan implements the Scanner interface. +func (t *Time) Scan(value interface{}) error { + var err error + switch x := value.(type) { + case time.Time: + t.Time = x + case nil: + t.Valid = false + return nil + default: + err = fmt.Errorf("null: cannot scan type %T into null.Time: %v", value, value) + } + t.Valid = err == nil + return err +} + +// Value implements the driver Valuer interface. +func (t Time) Value() (driver.Value, error) { + if !t.Valid { + return nil, nil + } + return t.Time, nil +} diff --git a/null/time_test.go b/null/time_test.go new file mode 100644 index 0000000..a74e012 --- /dev/null +++ b/null/time_test.go @@ -0,0 +1,178 @@ +package null + +import ( + "encoding/json" + "testing" + "time" +) + +var ( + timeString = "2012-12-21T21:21:21Z" + timeJSON = []byte(`"` + timeString + `"`) + nullTimeJSON = []byte(`null`) + timeValue, _ = time.Parse(time.RFC3339, timeString) + badObject = []byte(`{"hello": "world"}`) +) + +func TestUnmarshalTimeJSON(t *testing.T) { + var ti Time + err := json.Unmarshal(timeJSON, &ti) + maybePanic(err) + assertTime(t, ti, "UnmarshalJSON() json") + + var null Time + err = json.Unmarshal(nullTimeJSON, &null) + maybePanic(err) + assertNullTime(t, null, "null time json") + + var invalid Time + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*time.ParseError); !ok { + t.Errorf("expected json.ParseError, not %T", err) + } + assertNullTime(t, invalid, "invalid from object json") + + var bad Time + err = json.Unmarshal(badObject, &bad) + if err == nil { + t.Errorf("expected error: bad object") + } + assertNullTime(t, bad, "bad from object json") + + var wrongType Time + err = json.Unmarshal(intJSON, &wrongType) + if err == nil { + t.Errorf("expected error: wrong type JSON") + } + assertNullTime(t, wrongType, "wrong type object json") +} + +func TestUnmarshalTimeText(t *testing.T) { + ti := TimeFrom(timeValue) + txt, err := ti.MarshalText() + maybePanic(err) + assertJSONEquals(t, txt, timeString, "marshal text") + + var unmarshal Time + err = unmarshal.UnmarshalText(txt) + maybePanic(err) + assertTime(t, unmarshal, "unmarshal text") + + var invalid Time + err = invalid.UnmarshalText([]byte("hello world")) + if err == nil { + t.Error("expected error") + } + assertNullTime(t, invalid, "bad string") +} + +func TestMarshalTime(t *testing.T) { + ti := TimeFrom(timeValue) + data, err := json.Marshal(ti) + maybePanic(err) + assertJSONEquals(t, data, string(timeJSON), "non-empty json marshal") + + ti.Valid = false + data, err = json.Marshal(ti) + maybePanic(err) + assertJSONEquals(t, data, string(nullJSON), "null json marshal") +} + +func TestTimeFrom(t *testing.T) { + ti := TimeFrom(timeValue) + assertTime(t, ti, "TimeFrom() time.Time") +} + +func TestTimeFromPtr(t *testing.T) { + ti := TimeFromPtr(&timeValue) + assertTime(t, ti, "TimeFromPtr() time") + + null := TimeFromPtr(nil) + assertNullTime(t, null, "TimeFromPtr(nil)") +} + +func TestTimeSetValid(t *testing.T) { + var ti time.Time + change := NewTime(ti, false) + assertNullTime(t, change, "SetValid()") + change.SetValid(timeValue) + assertTime(t, change, "SetValid()") +} + +func TestTimeIsNull(t *testing.T) { + ti := TimeFrom(time.Now()) + if ti.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewTime(time.Now(), false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewTime(time.Time{}, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestTimePointer(t *testing.T) { + ti := TimeFrom(timeValue) + ptr := ti.Ptr() + if *ptr != timeValue { + t.Errorf("bad %s time: %#v ≠ %v\n", "pointer", ptr, timeValue) + } + + var nt time.Time + null := NewTime(nt, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s time: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestTimeScanValue(t *testing.T) { + var ti Time + err := ti.Scan(timeValue) + maybePanic(err) + assertTime(t, ti, "scanned time") + if v, err := ti.Value(); v != timeValue || err != nil { + t.Error("bad value or err:", v, err) + } + + var null Time + err = null.Scan(nil) + maybePanic(err) + assertNullTime(t, null, "scanned null") + if v, err := null.Value(); v != nil || err != nil { + t.Error("bad value or err:", v, err) + } + + var wrong Time + err = wrong.Scan(int64(42)) + if err == nil { + t.Error("expected error") + } + assertNullTime(t, wrong, "scanned wrong") +} + +func assertTime(t *testing.T, ti Time, from string) { + if ti.Time != timeValue { + t.Errorf("bad %v time: %v ≠ %v\n", from, ti.Time, timeValue) + } + if !ti.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullTime(t *testing.T, ti Time, from string) { + if ti.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/uint.go b/null/uint.go new file mode 100644 index 0000000..87bc95d --- /dev/null +++ b/null/uint.go @@ -0,0 +1,123 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Uint is an nullable uint. +type Uint struct { + Uint uint + Valid bool +} + +// NewUint creates a new Uint +func NewUint(i uint, valid bool) Uint { + return Uint{ + Uint: i, + Valid: valid, + } +} + +// UintFrom creates a new Uint that will always be valid. +func UintFrom(i uint) Uint { + return NewUint(i, true) +} + +// UintFromPtr creates a new Uint that be null if i is nil. +func UintFromPtr(i *uint) Uint { + if i == nil { + return NewUint(0, false) + } + return NewUint(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (u *Uint) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + u.Valid = false + u.Uint = 0 + return nil + } + + var x uint64 + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + u.Uint = uint(x) + u.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (u *Uint) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + u.Valid = false + return nil + } + var err error + res, err := strconv.ParseUint(string(text), 10, 0) + u.Valid = err == nil + if u.Valid { + u.Uint = uint(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (u Uint) MarshalJSON() ([]byte, error) { + if !u.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatUint(uint64(u.Uint), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (u Uint) MarshalText() ([]byte, error) { + if !u.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatUint(uint64(u.Uint), 10)), nil +} + +// SetValid changes this Uint's value and also sets it to be non-null. +func (u *Uint) SetValid(n uint) { + u.Uint = n + u.Valid = true +} + +// Ptr returns a pointer to this Uint's value, or a nil pointer if this Uint is null. +func (u Uint) Ptr() *uint { + if !u.Valid { + return nil + } + return &u.Uint +} + +// IsNull returns true for invalid Uints, for future omitempty support (Go 1.4?) +func (u Uint) IsNull() bool { + return !u.Valid +} + +// Scan implements the Scanner interface. +func (u *Uint) Scan(value interface{}) error { + if value == nil { + u.Uint, u.Valid = 0, false + return nil + } + u.Valid = true + return convert.ConvertAssign(&u.Uint, value) +} + +// Value implements the driver Valuer interface. +func (u Uint) Value() (driver.Value, error) { + if !u.Valid { + return nil, nil + } + return int64(u.Uint), nil +} diff --git a/null/uint16.go b/null/uint16.go new file mode 100644 index 0000000..9c5cb02 --- /dev/null +++ b/null/uint16.go @@ -0,0 +1,129 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "fmt" + "math" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Uint16 is an nullable uint16. +type Uint16 struct { + Uint16 uint16 + Valid bool +} + +// NewUint16 creates a new Uint16 +func NewUint16(i uint16, valid bool) Uint16 { + return Uint16{ + Uint16: i, + Valid: valid, + } +} + +// Uint16From creates a new Uint16 that will always be valid. +func Uint16From(i uint16) Uint16 { + return NewUint16(i, true) +} + +// Uint16FromPtr creates a new Uint16 that be null if i is nil. +func Uint16FromPtr(i *uint16) Uint16 { + if i == nil { + return NewUint16(0, false) + } + return NewUint16(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (u *Uint16) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + u.Valid = false + u.Uint16 = 0 + return nil + } + + var x uint64 + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + if x > math.MaxUint16 { + return fmt.Errorf("json: %d overflows max uint8 value", x) + } + + u.Uint16 = uint16(x) + u.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (u *Uint16) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + u.Valid = false + return nil + } + var err error + res, err := strconv.ParseUint(string(text), 10, 16) + u.Valid = err == nil + if u.Valid { + u.Uint16 = uint16(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (u Uint16) MarshalJSON() ([]byte, error) { + if !u.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatUint(uint64(u.Uint16), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (u Uint16) MarshalText() ([]byte, error) { + if !u.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatUint(uint64(u.Uint16), 10)), nil +} + +// SetValid changes this Uint16's value and also sets it to be non-null. +func (u *Uint16) SetValid(n uint16) { + u.Uint16 = n + u.Valid = true +} + +// Ptr returns a pointer to this Uint16's value, or a nil pointer if this Uint16 is null. +func (u Uint16) Ptr() *uint16 { + if !u.Valid { + return nil + } + return &u.Uint16 +} + +// IsNull returns true for invalid Uint16's, for future omitempty support (Go 1.4?) +func (u Uint16) IsNull() bool { + return !u.Valid +} + +// Scan implements the Scanner interface. +func (u *Uint16) Scan(value interface{}) error { + if value == nil { + u.Uint16, u.Valid = 0, false + return nil + } + u.Valid = true + return convert.ConvertAssign(&u.Uint16, value) +} + +// Value implements the driver Valuer interface. +func (u Uint16) Value() (driver.Value, error) { + if !u.Valid { + return nil, nil + } + return int64(u.Uint16), nil +} diff --git a/null/uint16_test.go b/null/uint16_test.go new file mode 100644 index 0000000..17f7475 --- /dev/null +++ b/null/uint16_test.go @@ -0,0 +1,191 @@ +package null + +import ( + "encoding/json" + "math" + "strconv" + "testing" +) + +var ( + uint16JSON = []byte(`65534`) +) + +func TestUint16From(t *testing.T) { + i := Uint16From(65534) + assertUint16(t, i, "Uint16From()") + + zero := Uint16From(0) + if !zero.Valid { + t.Error("Uint16From(0)", "is invalid, but should be valid") + } +} + +func TestUint16FromPtr(t *testing.T) { + n := uint16(65534) + iptr := &n + i := Uint16FromPtr(iptr) + assertUint16(t, i, "Uint16FromPtr()") + + null := Uint16FromPtr(nil) + assertNullUint16(t, null, "Uint16FromPtr(nil)") +} + +func TestUnmarshalUint16(t *testing.T) { + var i Uint16 + err := json.Unmarshal(uint16JSON, &i) + maybePanic(err) + assertUint16(t, i, "uint16 json") + + var null Uint16 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullUint16(t, null, "null json") + + var badType Uint16 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullUint16(t, badType, "wrong type json") + + var invalid Uint16 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullUint16(t, invalid, "invalid json") +} + +func TestUnmarshalNonUintegerNumber16(t *testing.T) { + var i Uint16 + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to uint16") + } +} + +func TestUnmarshalUint16Overflow(t *testing.T) { + uint16Overflow := int64(math.MaxUint16) + + // Max uint16 should decode successfully + var i Uint16 + err := json.Unmarshal([]byte(strconv.FormatUint(uint64(uint16Overflow), 10)), &i) + maybePanic(err) + + // Attempt to overflow + uint16Overflow++ + err = json.Unmarshal([]byte(strconv.FormatUint(uint64(uint16Overflow), 10)), &i) + if err == nil { + panic("err should be present; decoded value overflows uint16") + } +} + +func TestTextUnmarshalUint16(t *testing.T) { + var i Uint16 + err := i.UnmarshalText([]byte("65534")) + maybePanic(err) + assertUint16(t, i, "UnmarshalText() uint16") + + var blank Uint16 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullUint16(t, blank, "UnmarshalText() empty uint16") +} + +func TestMarshalUint16(t *testing.T) { + i := Uint16From(65534) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "65534", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewUint16(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalUint16Text(t *testing.T) { + i := Uint16From(65534) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "65534", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewUint16(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestUint16Pointer(t *testing.T) { + i := Uint16From(65534) + ptr := i.Ptr() + if *ptr != 65534 { + t.Errorf("bad %s uint16: %#v ≠ %d\n", "pointer", ptr, 65534) + } + + null := NewUint16(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s uint16: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestUint16IsNull(t *testing.T) { + i := Uint16From(65534) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewUint16(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewUint16(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestUint16SetValid(t *testing.T) { + change := NewUint16(0, false) + assertNullUint16(t, change, "SetValid()") + change.SetValid(65534) + assertUint16(t, change, "SetValid()") +} + +func TestUint16Scan(t *testing.T) { + var i Uint16 + err := i.Scan(65534) + maybePanic(err) + assertUint16(t, i, "scanned uint16") + + var null Uint16 + err = null.Scan(nil) + maybePanic(err) + assertNullUint16(t, null, "scanned null") +} + +func assertUint16(t *testing.T, i Uint16, from string) { + if i.Uint16 != 65534 { + t.Errorf("bad %s uint16: %d ≠ %d\n", from, i.Uint16, 65534) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullUint16(t *testing.T, i Uint16, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/uint32.go b/null/uint32.go new file mode 100644 index 0000000..f1f90d1 --- /dev/null +++ b/null/uint32.go @@ -0,0 +1,129 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "fmt" + "math" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Uint32 is an nullable uint32. +type Uint32 struct { + Uint32 uint32 + Valid bool +} + +// NewUint32 creates a new Uint32 +func NewUint32(i uint32, valid bool) Uint32 { + return Uint32{ + Uint32: i, + Valid: valid, + } +} + +// Uint32From creates a new Uint32 that will always be valid. +func Uint32From(i uint32) Uint32 { + return NewUint32(i, true) +} + +// Uint32FromPtr creates a new Uint32 that be null if i is nil. +func Uint32FromPtr(i *uint32) Uint32 { + if i == nil { + return NewUint32(0, false) + } + return NewUint32(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (u *Uint32) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + u.Valid = false + u.Uint32 = 0 + return nil + } + + var x uint64 + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + if x > math.MaxUint32 { + return fmt.Errorf("json: %d overflows max uint32 value", x) + } + + u.Uint32 = uint32(x) + u.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (u *Uint32) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + u.Valid = false + return nil + } + var err error + res, err := strconv.ParseUint(string(text), 10, 32) + u.Valid = err == nil + if u.Valid { + u.Uint32 = uint32(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (u Uint32) MarshalJSON() ([]byte, error) { + if !u.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatUint(uint64(u.Uint32), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (u Uint32) MarshalText() ([]byte, error) { + if !u.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatUint(uint64(u.Uint32), 10)), nil +} + +// SetValid changes this Uint32's value and also sets it to be non-null. +func (u *Uint32) SetValid(n uint32) { + u.Uint32 = n + u.Valid = true +} + +// Ptr returns a pointer to this Uint32's value, or a nil pointer if this Uint32 is null. +func (u Uint32) Ptr() *uint32 { + if !u.Valid { + return nil + } + return &u.Uint32 +} + +// IsNull returns true for invalid Uint32's, for future omitempty support (Go 1.4?) +func (u Uint32) IsNull() bool { + return !u.Valid +} + +// Scan implements the Scanner interface. +func (u *Uint32) Scan(value interface{}) error { + if value == nil { + u.Uint32, u.Valid = 0, false + return nil + } + u.Valid = true + return convert.ConvertAssign(&u.Uint32, value) +} + +// Value implements the driver Valuer interface. +func (u Uint32) Value() (driver.Value, error) { + if !u.Valid { + return nil, nil + } + return uint64(u.Uint32), nil +} diff --git a/null/uint32_test.go b/null/uint32_test.go new file mode 100644 index 0000000..52570f6 --- /dev/null +++ b/null/uint32_test.go @@ -0,0 +1,191 @@ +package null + +import ( + "encoding/json" + "math" + "strconv" + "testing" +) + +var ( + uint32JSON = []byte(`4294967294`) +) + +func TestUint32From(t *testing.T) { + i := Uint32From(4294967294) + assertUint32(t, i, "Uint32From()") + + zero := Uint32From(0) + if !zero.Valid { + t.Error("Uint32From(0)", "is invalid, but should be valid") + } +} + +func TestUint32FromPtr(t *testing.T) { + n := uint32(4294967294) + iptr := &n + i := Uint32FromPtr(iptr) + assertUint32(t, i, "Uint32FromPtr()") + + null := Uint32FromPtr(nil) + assertNullUint32(t, null, "Uint32FromPtr(nil)") +} + +func TestUnmarshalUint32(t *testing.T) { + var i Uint32 + err := json.Unmarshal(uint32JSON, &i) + maybePanic(err) + assertUint32(t, i, "uint32 json") + + var null Uint32 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullUint32(t, null, "null json") + + var badType Uint32 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullUint32(t, badType, "wrong type json") + + var invalid Uint32 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullUint32(t, invalid, "invalid json") +} + +func TestUnmarshalNonUintegerNumber32(t *testing.T) { + var i Uint32 + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to uint32") + } +} + +func TestUnmarshalUint32Overflow(t *testing.T) { + uint32Overflow := int64(math.MaxUint32) + + // Max uint32 should decode successfully + var i Uint32 + err := json.Unmarshal([]byte(strconv.FormatUint(uint64(uint32Overflow), 10)), &i) + maybePanic(err) + + // Attempt to overflow + uint32Overflow++ + err = json.Unmarshal([]byte(strconv.FormatUint(uint64(uint32Overflow), 10)), &i) + if err == nil { + panic("err should be present; decoded value overflows uint32") + } +} + +func TestTextUnmarshalUint32(t *testing.T) { + var i Uint32 + err := i.UnmarshalText([]byte("4294967294")) + maybePanic(err) + assertUint32(t, i, "UnmarshalText() uint32") + + var blank Uint32 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullUint32(t, blank, "UnmarshalText() empty uint32") +} + +func TestMarshalUint32(t *testing.T) { + i := Uint32From(4294967294) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "4294967294", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewUint32(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalUint32Text(t *testing.T) { + i := Uint32From(4294967294) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "4294967294", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewUint32(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestUint32Pointer(t *testing.T) { + i := Uint32From(4294967294) + ptr := i.Ptr() + if *ptr != 4294967294 { + t.Errorf("bad %s uint32: %#v ≠ %d\n", "pointer", ptr, 4294967294) + } + + null := NewUint32(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s uint32: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestUint32IsNull(t *testing.T) { + i := Uint32From(4294967294) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewUint32(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewUint32(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestUint32SetValid(t *testing.T) { + change := NewUint32(0, false) + assertNullUint32(t, change, "SetValid()") + change.SetValid(4294967294) + assertUint32(t, change, "SetValid()") +} + +func TestUint32Scan(t *testing.T) { + var i Uint32 + err := i.Scan(4294967294) + maybePanic(err) + assertUint32(t, i, "scanned uint32") + + var null Uint32 + err = null.Scan(nil) + maybePanic(err) + assertNullUint32(t, null, "scanned null") +} + +func assertUint32(t *testing.T, i Uint32, from string) { + if i.Uint32 != 4294967294 { + t.Errorf("bad %s uint32: %d ≠ %d\n", from, i.Uint32, 4294967294) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullUint32(t *testing.T, i Uint32, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/uint64.go b/null/uint64.go new file mode 100644 index 0000000..f0ff839 --- /dev/null +++ b/null/uint64.go @@ -0,0 +1,121 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Uint64 is an nullable uint64. +type Uint64 struct { + Uint64 uint64 + Valid bool +} + +// NewUint64 creates a new Uint64 +func NewUint64(i uint64, valid bool) Uint64 { + return Uint64{ + Uint64: i, + Valid: valid, + } +} + +// Uint64From creates a new Uint64 that will always be valid. +func Uint64From(i uint64) Uint64 { + return NewUint64(i, true) +} + +// Uint64FromPtr creates a new Uint64 that be null if i is nil. +func Uint64FromPtr(i *uint64) Uint64 { + if i == nil { + return NewUint64(0, false) + } + return NewUint64(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (u *Uint64) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + u.Uint64 = 0 + u.Valid = false + return nil + } + + if err := json.Unmarshal(data, &u.Uint64); err != nil { + return err + } + + u.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (u *Uint64) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + u.Valid = false + return nil + } + var err error + res, err := strconv.ParseUint(string(text), 10, 64) + u.Valid = err == nil + if u.Valid { + u.Uint64 = uint64(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (u Uint64) MarshalJSON() ([]byte, error) { + if !u.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatUint(u.Uint64, 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (u Uint64) MarshalText() ([]byte, error) { + if !u.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatUint(u.Uint64, 10)), nil +} + +// SetValid changes this Uint64's value and also sets it to be non-null. +func (u *Uint64) SetValid(n uint64) { + u.Uint64 = n + u.Valid = true +} + +// Ptr returns a pointer to this Uint64's value, or a nil pointer if this Uint64 is null. +func (u Uint64) Ptr() *uint64 { + if !u.Valid { + return nil + } + return &u.Uint64 +} + +// IsNull returns true for invalid Uint64's, for future omitempty support (Go 1.4?) +func (u Uint64) IsNull() bool { + return !u.Valid +} + +// Scan implements the Scanner interface. +func (u *Uint64) Scan(value interface{}) error { + if value == nil { + u.Uint64, u.Valid = 0, false + return nil + } + u.Valid = true + return convert.ConvertAssign(&u.Uint64, value) +} + +// Value implements the driver Valuer interface. +func (u Uint64) Value() (driver.Value, error) { + if !u.Valid { + return nil, nil + } + return int64(u.Uint64), nil +} diff --git a/null/uint64_test.go b/null/uint64_test.go new file mode 100644 index 0000000..0d11778 --- /dev/null +++ b/null/uint64_test.go @@ -0,0 +1,173 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + uint64JSON = []byte(`18446744073709551614`) +) + +func TestUint64From(t *testing.T) { + i := Uint64From(18446744073709551614) + assertUint64(t, i, "Uint64From()") + + zero := Uint64From(0) + if !zero.Valid { + t.Error("Uint64From(0)", "is invalid, but should be valid") + } +} + +func TestUint64FromPtr(t *testing.T) { + n := uint64(18446744073709551614) + iptr := &n + i := Uint64FromPtr(iptr) + assertUint64(t, i, "Uint64FromPtr()") + + null := Uint64FromPtr(nil) + assertNullUint64(t, null, "Uint64FromPtr(nil)") +} + +func TestUnmarshalUint64(t *testing.T) { + var i Uint64 + err := json.Unmarshal(uint64JSON, &i) + maybePanic(err) + assertUint64(t, i, "uint64 json") + + var null Uint64 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullUint64(t, null, "null json") + + var badType Uint64 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullUint64(t, badType, "wrong type json") + + var invalid Uint64 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullUint64(t, invalid, "invalid json") +} + +func TestUnmarshalNonUintegerNumber64(t *testing.T) { + var i Uint64 + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to uint64") + } +} + +func TestTextUnmarshalUint64(t *testing.T) { + var i Uint64 + err := i.UnmarshalText([]byte("18446744073709551614")) + maybePanic(err) + assertUint64(t, i, "UnmarshalText() uint64") + + var blank Uint64 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullUint64(t, blank, "UnmarshalText() empty uint64") +} + +func TestMarshalUint64(t *testing.T) { + i := Uint64From(18446744073709551614) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "18446744073709551614", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewUint64(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalUint64Text(t *testing.T) { + i := Uint64From(18446744073709551614) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "18446744073709551614", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewUint64(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestUint64Pointer(t *testing.T) { + i := Uint64From(18446744073709551614) + ptr := i.Ptr() + if *ptr != 18446744073709551614 { + t.Errorf("bad %s uint64: %#v ≠ %d\n", "pointer", ptr, uint64(18446744073709551614)) + } + + null := NewUint64(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s uint64: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestUint64IsNull(t *testing.T) { + i := Uint64From(18446744073709551614) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewUint64(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewUint64(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestUint64SetValid(t *testing.T) { + change := NewUint64(0, false) + assertNullUint64(t, change, "SetValid()") + change.SetValid(18446744073709551614) + assertUint64(t, change, "SetValid()") +} + +func TestUint64Scan(t *testing.T) { + var i Uint64 + err := i.Scan(uint64(18446744073709551614)) + maybePanic(err) + assertUint64(t, i, "scanned uint64") + + var null Uint64 + err = null.Scan(nil) + maybePanic(err) + assertNullUint64(t, null, "scanned null") +} + +func assertUint64(t *testing.T, i Uint64, from string) { + if i.Uint64 != 18446744073709551614 { + t.Errorf("bad %s uint64: %d ≠ %d\n", from, i.Uint64, uint64(18446744073709551614)) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullUint64(t *testing.T, i Uint64, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/uint8.go b/null/uint8.go new file mode 100644 index 0000000..4475c37 --- /dev/null +++ b/null/uint8.go @@ -0,0 +1,129 @@ +package null + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "fmt" + "math" + "strconv" + + "gopkg.in/nullbio/null.v6/convert" +) + +// Uint8 is an nullable uint8. +type Uint8 struct { + Uint8 uint8 + Valid bool +} + +// NewUint8 creates a new Uint8 +func NewUint8(i uint8, valid bool) Uint8 { + return Uint8{ + Uint8: i, + Valid: valid, + } +} + +// Uint8From creates a new Uint8 that will always be valid. +func Uint8From(i uint8) Uint8 { + return NewUint8(i, true) +} + +// Uint8FromPtr creates a new Uint8 that be null if i is nil. +func Uint8FromPtr(i *uint8) Uint8 { + if i == nil { + return NewUint8(0, false) + } + return NewUint8(*i, true) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (u *Uint8) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, NullBytes) { + u.Valid = false + u.Uint8 = 0 + return nil + } + + var x uint64 + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + if x > math.MaxUint8 { + return fmt.Errorf("json: %d overflows max uint8 value", x) + } + + u.Uint8 = uint8(x) + u.Valid = true + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (u *Uint8) UnmarshalText(text []byte) error { + if text == nil || len(text) == 0 { + u.Valid = false + return nil + } + var err error + res, err := strconv.ParseUint(string(text), 10, 8) + u.Valid = err == nil + if u.Valid { + u.Uint8 = uint8(res) + } + return err +} + +// MarshalJSON implements json.Marshaler. +func (u Uint8) MarshalJSON() ([]byte, error) { + if !u.Valid { + return NullBytes, nil + } + return []byte(strconv.FormatUint(uint64(u.Uint8), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (u Uint8) MarshalText() ([]byte, error) { + if !u.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatUint(uint64(u.Uint8), 10)), nil +} + +// SetValid changes this Uint8's value and also sets it to be non-null. +func (u *Uint8) SetValid(n uint8) { + u.Uint8 = n + u.Valid = true +} + +// Ptr returns a pointer to this Uint8's value, or a nil pointer if this Uint8 is null. +func (u Uint8) Ptr() *uint8 { + if !u.Valid { + return nil + } + return &u.Uint8 +} + +// IsNull returns true for invalid Uint8's, for future omitempty support (Go 1.4?) +func (u Uint8) IsNull() bool { + return !u.Valid +} + +// Scan implements the Scanner interface. +func (u *Uint8) Scan(value interface{}) error { + if value == nil { + u.Uint8, u.Valid = 0, false + return nil + } + u.Valid = true + return convert.ConvertAssign(&u.Uint8, value) +} + +// Value implements the driver Valuer interface. +func (u Uint8) Value() (driver.Value, error) { + if !u.Valid { + return nil, nil + } + return int64(u.Uint8), nil +} diff --git a/null/uint8_test.go b/null/uint8_test.go new file mode 100644 index 0000000..33681ee --- /dev/null +++ b/null/uint8_test.go @@ -0,0 +1,191 @@ +package null + +import ( + "encoding/json" + "math" + "strconv" + "testing" +) + +var ( + uint8JSON = []byte(`254`) +) + +func TestUint8From(t *testing.T) { + i := Uint8From(254) + assertUint8(t, i, "Uint8From()") + + zero := Uint8From(0) + if !zero.Valid { + t.Error("Uint8From(0)", "is invalid, but should be valid") + } +} + +func TestUint8FromPtr(t *testing.T) { + n := uint8(254) + iptr := &n + i := Uint8FromPtr(iptr) + assertUint8(t, i, "Uint8FromPtr()") + + null := Uint8FromPtr(nil) + assertNullUint8(t, null, "Uint8FromPtr(nil)") +} + +func TestUnmarshalUint8(t *testing.T) { + var i Uint8 + err := json.Unmarshal(uint8JSON, &i) + maybePanic(err) + assertUint8(t, i, "uint8 json") + + var null Uint8 + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullUint8(t, null, "null json") + + var badType Uint8 + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullUint8(t, badType, "wrong type json") + + var invalid Uint8 + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullUint8(t, invalid, "invalid json") +} + +func TestUnmarshalNonUintegerNumber8(t *testing.T) { + var i Uint8 + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-integer number coerced to uint8") + } +} + +func TestUnmarshalUint8Overflow(t *testing.T) { + uint8Overflow := int64(math.MaxUint8) + + // Max uint8 should decode successfully + var i Uint8 + err := json.Unmarshal([]byte(strconv.FormatUint(uint64(uint8Overflow), 10)), &i) + maybePanic(err) + + // Attempt to overflow + uint8Overflow++ + err = json.Unmarshal([]byte(strconv.FormatUint(uint64(uint8Overflow), 10)), &i) + if err == nil { + panic("err should be present; decoded value overflows uint8") + } +} + +func TestTextUnmarshalUint8(t *testing.T) { + var i Uint8 + err := i.UnmarshalText([]byte("254")) + maybePanic(err) + assertUint8(t, i, "UnmarshalText() uint8") + + var blank Uint8 + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullUint8(t, blank, "UnmarshalText() empty uint8") +} + +func TestMarshalUint8(t *testing.T) { + i := Uint8From(254) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "254", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewUint8(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalUint8Text(t *testing.T) { + i := Uint8From(254) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "254", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewUint8(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestUint8Pointer(t *testing.T) { + i := Uint8From(254) + ptr := i.Ptr() + if *ptr != 254 { + t.Errorf("bad %s uint8: %#v ≠ %d\n", "pointer", ptr, 254) + } + + null := NewUint8(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s uint8: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestUint8IsNull(t *testing.T) { + i := Uint8From(254) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewUint8(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewUint8(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestUint8SetValid(t *testing.T) { + change := NewUint8(0, false) + assertNullUint8(t, change, "SetValid()") + change.SetValid(254) + assertUint8(t, change, "SetValid()") +} + +func TestUint8Scan(t *testing.T) { + var i Uint8 + err := i.Scan(254) + maybePanic(err) + assertUint8(t, i, "scanned uint8") + + var null Uint8 + err = null.Scan(nil) + maybePanic(err) + assertNullUint8(t, null, "scanned null") +} + +func assertUint8(t *testing.T, i Uint8, from string) { + if i.Uint8 != 254 { + t.Errorf("bad %s uint8: %d ≠ %d\n", from, i.Uint8, 254) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullUint8(t *testing.T, i Uint8, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/null/uint_test.go b/null/uint_test.go new file mode 100644 index 0000000..e5d911d --- /dev/null +++ b/null/uint_test.go @@ -0,0 +1,173 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + uintJSON = []byte(`12345`) +) + +func TestUintFrom(t *testing.T) { + i := UintFrom(12345) + assertUint(t, i, "UintFrom()") + + zero := UintFrom(0) + if !zero.Valid { + t.Error("UintFrom(0)", "is invalid, but should be valid") + } +} + +func TestUintFromPtr(t *testing.T) { + n := uint(12345) + iptr := &n + i := UintFromPtr(iptr) + assertUint(t, i, "UintFromPtr()") + + null := UintFromPtr(nil) + assertNullUint(t, null, "UintFromPtr(nil)") +} + +func TestUnmarshalUint(t *testing.T) { + var i Uint + err := json.Unmarshal(uintJSON, &i) + maybePanic(err) + assertUint(t, i, "uint json") + + var null Uint + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullUint(t, null, "null json") + + var badType Uint + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullUint(t, badType, "wrong type json") + + var invalid Uint + err = invalid.UnmarshalJSON(invalidJSON) + if _, ok := err.(*json.SyntaxError); !ok { + t.Errorf("expected json.SyntaxError, not %T", err) + } + assertNullUint(t, invalid, "invalid json") +} + +func TestUnmarshalNonUintegerNumber(t *testing.T) { + var i Uint + err := json.Unmarshal(float64JSON, &i) + if err == nil { + panic("err should be present; non-uinteger number coerced to uint") + } +} + +func TestTextUnmarshalUint(t *testing.T) { + var i Uint + err := i.UnmarshalText([]byte("12345")) + maybePanic(err) + assertUint(t, i, "UnmarshalText() uint") + + var blank Uint + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullUint(t, blank, "UnmarshalText() empty uint") +} + +func TestMarshalUint(t *testing.T) { + i := UintFrom(12345) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "12345", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewUint(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalUintText(t *testing.T) { + i := UintFrom(12345) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "12345", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewUint(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestUintPointer(t *testing.T) { + i := UintFrom(12345) + ptr := i.Ptr() + if *ptr != 12345 { + t.Errorf("bad %s uint: %#v ≠ %d\n", "pointer", ptr, 12345) + } + + null := NewUint(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s uint: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestUintIsNull(t *testing.T) { + i := UintFrom(12345) + if i.IsNull() { + t.Errorf("IsNull() should be false") + } + + null := NewUint(0, false) + if !null.IsNull() { + t.Errorf("IsNull() should be true") + } + + zero := NewUint(0, true) + if zero.IsNull() { + t.Errorf("IsNull() should be false") + } + + var testInt interface{} + testInt = zero + if _, ok := testInt.(Nullable); !ok { + t.Errorf("Nullable interface should be implemented") + } +} + +func TestUintSetValid(t *testing.T) { + change := NewUint(0, false) + assertNullUint(t, change, "SetValid()") + change.SetValid(12345) + assertUint(t, change, "SetValid()") +} + +func TestUintScan(t *testing.T) { + var i Uint + err := i.Scan(12345) + maybePanic(err) + assertUint(t, i, "scanned uint") + + var null Uint + err = null.Scan(nil) + maybePanic(err) + assertNullUint(t, null, "scanned null") +} + +func assertUint(t *testing.T, i Uint, from string) { + if i.Uint != 12345 { + t.Errorf("bad %s uint: %d ≠ %d\n", from, i.Uint, 12345) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullUint(t *testing.T, i Uint, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +}