diff --git a/int.go b/int.go new file mode 100644 index 0000000..b6d387d --- /dev/null +++ b/int.go @@ -0,0 +1,93 @@ +package null + +import ( + "database/sql" + "encoding/json" + "strconv" +) + +// Int is a nullable int. +type Int struct { + sql.NullInt64 +} + +// IntFrom creates a new Int that will always be valid. +func IntFrom(i int64) Int { + return NewInt(i, true) +} + +// NewInt creates a new Int +func NewInt(i int64, valid bool) Int { + return Int{ + NullInt64: sql.NullInt64{ + Int64: i, + Valid: valid, + }, + } +} + +// UnmarshalJSON implements json.Unmarshaler. +// It supports number and null input. +// 0 will not be considered a null Int. +// It also supports unmarshalling a sql.NullInt64. +func (i *Int) UnmarshalJSON(data []byte) error { + var err error + var v interface{} + json.Unmarshal(data, &v) + switch x := v.(type) { + case float64: + i.Int64 = int64(x) + case map[string]interface{}: + err = json.Unmarshal(data, &i.NullInt64) + case nil: + i.Valid = false + return nil + } + i.Valid = err == nil + return err +} + +// UnmarshalText implements encoding.TextUnmarshaler. +// It will unmarshal to a null Int if the input is a blank or not an integer. +// It will return an error if the input is not an integer, blank, or "null". +func (i *Int) UnmarshalText(text []byte) error { + str := string(text) + if str == "" || str == "null" { + i.Valid = false + return nil + } + var err error + i.Int64, err = strconv.ParseInt(string(text), 10, 64) + i.Valid = err == nil + return err +} + +// MarshalJSON implements json.Marshaler. +// It will encode null if this Int is null. +func (i Int) MarshalJSON() ([]byte, error) { + if !i.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(i.Int64, 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +// It will encode a blank string if this Int is null. +func (i Int) MarshalText() ([]byte, error) { + if !i.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(i.Int64, 10)), nil +} + +func (i Int) Pointer() *int64 { + if !i.Valid { + return nil + } + return &i.Int64 +} + +// IsZero returns true for invalid Ints, for future omitempty support (Go 1.4?) +func (i Int) IsZero() bool { + return !i.Valid +} diff --git a/int_test.go b/int_test.go new file mode 100644 index 0000000..eb6f5f8 --- /dev/null +++ b/int_test.go @@ -0,0 +1,139 @@ +package null + +import ( + "encoding/json" + "testing" +) + +var ( + intJSON = []byte(`12345`) + nullIntJSON = []byte(`{"Int64":12345,"Valid":true}`) +) + +func TestIntFrom(t *testing.T) { + i := IntFrom(12345) + assertInt(t, i, "IntFrom()") + + zero := IntFrom(0) + if !zero.Valid { + t.Error("IntFrom(0)", "is invalid, but should be valid") + } +} + +func TestUnmarshalInt(t *testing.T) { + var i Int + err := json.Unmarshal(intJSON, &i) + maybePanic(err) + assertInt(t, i, "int json") + + var ni Int + err = json.Unmarshal(nullIntJSON, &ni) + maybePanic(err) + assertInt(t, ni, "sq.NullInt64 json") + + var null Int + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullInt(t, null, "null json") +} + +func TestTextUnmarshalInt(t *testing.T) { + var i Int + err := i.UnmarshalText([]byte("12345")) + maybePanic(err) + assertInt(t, i, "UnmarshalText() int") + + var blank Int + err = blank.UnmarshalText([]byte("")) + maybePanic(err) + assertNullInt(t, blank, "UnmarshalText() empty int") + + var null Int + err = null.UnmarshalText([]byte("null")) + maybePanic(err) + assertNullInt(t, null, `UnmarshalText() "null"`) +} + +func TestMarshalInt(t *testing.T) { + i := IntFrom(12345) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "12345", "non-empty json marshal") + + // invalid values should be encoded as null + null := NewInt(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") +} + +func TestMarshalIntText(t *testing.T) { + i := IntFrom(12345) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "12345", "non-empty text marshal") + + // invalid values should be encoded as null + null := NewInt(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") +} + +func TestIntPointer(t *testing.T) { + i := IntFrom(12345) + ptr := i.Pointer() + if *ptr != 12345 { + t.Errorf("bad %s int: %#v ≠ %s\n", "pointer", ptr, 12345) + } + + null := NewInt(0, false) + ptr = null.Pointer() + if ptr != nil { + t.Errorf("bad %s int: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } +} + +func TestIntIsZero(t *testing.T) { + i := IntFrom(12345) + if i.IsZero() { + t.Errorf("IsZero() should be false") + } + + null := NewInt(0, false) + if !null.IsZero() { + t.Errorf("IsZero() should be true") + } + + zero := NewInt(0, true) + if zero.IsZero() { + t.Errorf("IsZero() should be false") + } +} + +func TestIntScan(t *testing.T) { + var i Int + err := i.Scan(12345) + maybePanic(err) + assertInt(t, i, "scanned int") + + var null Int + err = null.Scan(nil) + maybePanic(err) + assertNullInt(t, null, "scanned null") +} + +func assertInt(t *testing.T, i Int, from string) { + if i.Int64 != 12345 { + t.Errorf("bad %s int: %d ≠ %d\n", from, i.Int64, 12345) + } + if !i.Valid { + t.Error(from, "is invalid, but should be valid") + } +} + +func assertNullInt(t *testing.T, i Int, from string) { + if i.Valid { + t.Error(from, "is valid, but should be invalid") + } +} diff --git a/string.go b/string.go index 5404ae8..6e54c52 100644 --- a/string.go +++ b/string.go @@ -33,9 +33,9 @@ func (s *String) UnmarshalJSON(data []byte) error { var err error var v interface{} json.Unmarshal(data, &v) - switch v.(type) { + switch x := v.(type) { case string: - err = json.Unmarshal(data, &s.String) + s.String = x case map[string]interface{}: err = json.Unmarshal(data, &s.NullString) case nil: @@ -65,7 +65,7 @@ func (s *String) UnmarshalText(text []byte) error { // Pointer returns a pointer to this String's value, or a nil pointer if this String is null. func (s String) Pointer() *string { - if s.String == "" { + if !s.Valid { return nil } return &s.String diff --git a/string_test.go b/string_test.go index a8556b6..abba506 100644 --- a/string_test.go +++ b/string_test.go @@ -18,44 +18,44 @@ type stringInStruct struct { func TestStringFrom(t *testing.T) { str := StringFrom("test") - assert(t, str, "StringFrom() string") + assertStr(t, str, "StringFrom() string") null := StringFrom("") - assertNull(t, null, "StringFrom() empty string") + assertNullStr(t, null, "StringFrom() empty string") } func TestUnmarshalString(t *testing.T) { var str String err := json.Unmarshal(stringJSON, &str) maybePanic(err) - assert(t, str, "string json") + assertStr(t, str, "string json") var ns String err = json.Unmarshal(nullStringJSON, &ns) maybePanic(err) - assert(t, ns, "null string object json") + assertStr(t, ns, "sql.NullString json") var blank String err = json.Unmarshal(blankStringJSON, &blank) maybePanic(err) - assertNull(t, blank, "blank string json") + assertNullStr(t, blank, "blank string json") var null String err = json.Unmarshal(nullJSON, &null) maybePanic(err) - assertNull(t, null, "null json") + assertNullStr(t, null, "null json") } func TestTextUnmarshalString(t *testing.T) { var str String err := str.UnmarshalText([]byte("test")) maybePanic(err) - assert(t, str, "UnmarshalText() string") + assertStr(t, str, "UnmarshalText() string") var null String err = null.UnmarshalText([]byte("")) maybePanic(err) - assertNull(t, null, "UnmarshalText() empty string") + assertNullStr(t, null, "UnmarshalText() empty string") } func TestMarshalString(t *testing.T) { @@ -79,7 +79,7 @@ func TestMarshalString(t *testing.T) { // assertJSONEquals(t, data, `{}`, "null string in struct") // } -func TestPointer(t *testing.T) { +func TestStringPointer(t *testing.T) { str := StringFrom("test") ptr := str.Pointer() if *ptr != "test" { @@ -89,11 +89,11 @@ func TestPointer(t *testing.T) { null := StringFrom("") ptr = null.Pointer() if ptr != nil { - t.Errorf("bad %s: %#v ≠ %s\n", "nil pointer", ptr, "nil") + t.Errorf("bad %s string: %#v ≠ %s\n", "nil pointer", ptr, "nil") } } -func TestIsZero(t *testing.T) { +func TestStringIsZero(t *testing.T) { str := StringFrom("test") if str.IsZero() { t.Errorf("IsZero() should be false") @@ -110,16 +110,16 @@ func TestIsZero(t *testing.T) { } } -func TestScan(t *testing.T) { +func TestStringScan(t *testing.T) { var str String err := str.Scan("test") maybePanic(err) - assert(t, str, "scanned string") + assertStr(t, str, "scanned string") var null String err = null.Scan(nil) maybePanic(err) - assertNull(t, null, "scanned null") + assertNullStr(t, null, "scanned null") } func maybePanic(err error) { @@ -128,7 +128,7 @@ func maybePanic(err error) { } } -func assert(t *testing.T, s String, from string) { +func assertStr(t *testing.T, s String, from string) { if s.String != "test" { t.Errorf("bad %s string: %s ≠ %s\n", from, s.String, "test") } @@ -137,7 +137,7 @@ func assert(t *testing.T, s String, from string) { } } -func assertNull(t *testing.T, s String, from string) { +func assertNullStr(t *testing.T, s String, from string) { if s.Valid { t.Error(from, "is valid, but should be invalid") }