Add array types and hstore types

This commit is contained in:
Patrick O'brien 2016-09-12 03:40:59 +10:00
parent 793522650c
commit e62dfe369f
9 changed files with 2321 additions and 64 deletions

View file

@ -1061,6 +1061,16 @@ If your database uses multiple schemas you should generate a new package for eac
Note that this only applies to databases that use real, SQL standard schemas (like PostgreSQL), not Note that this only applies to databases that use real, SQL standard schemas (like PostgreSQL), not
fake schemas (like MySQL). fake schemas (like MySQL).
#### How do I use types.BytesArray for Postgres bytea arrays?
Only "escaped format" is supported for types.BytesArray. This means that your byte slice needs to have
a format of "\\x00" (4 bytes per byte) opposed to "\x00" (1 byte per byte). This is to maintain compatibility
with all Postgres drivers. Example:
`x := types.BytesArray{0: []byte("\\x68\\x69")}`
Please note that multi-dimensional Postgres ARRAY types are not supported at this time.
#### Where is the homepage? #### Where is the homepage?
The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) [Golang ORM](https://github.com/vattle/sqlboiler) generator is located at: https://github.com/vattle/sqlboiler The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) [Golang ORM](https://github.com/vattle/sqlboiler) generator is located at: https://github.com/vattle/sqlboiler

View file

@ -5,6 +5,11 @@ import "github.com/vattle/sqlboiler/strmangle"
// Column holds information about a database column. // Column holds information about a database column.
// Types are Go types, converted by TranslateColumnType. // Types are Go types, converted by TranslateColumnType.
type Column struct { type Column struct {
// ArrType is the underlying data type of the Postgres
// ARRAY type. See here:
// https://www.postgresql.org/docs/9.1/static/infoschema-element-types.html
ArrType *string
UDTName string
Name string Name string
Type string Type string
DBType string DBType string

View file

@ -148,12 +148,11 @@ func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
} }
column := bdb.Column{ column := bdb.Column{
Name: colName, Name: colName,
DBType: colType, DBType: colType,
Default: colDefault, Default: colDefault,
Nullable: nullable == "YES", Nullable: nullable == "YES",
Unique: unique, Unique: unique,
Validated: psqlIsValidated(colType),
} }
columns = append(columns, column) columns = append(columns, column)
} }
@ -306,19 +305,6 @@ func (m *MySQLDriver) TranslateColumnType(c bdb.Column) bdb.Column {
return c return c
} }
var mySQLValidatedTypes = []string{}
// isValidated checks if the database type is in the validatedTypes list.
func mySQLIsValidated(typ string) bool {
for _, v := range mySQLValidatedTypes {
if v == typ {
return true
}
}
return false
}
// RightQuote is the quoting character for the right side of the identifier // RightQuote is the quoting character for the right side of the identifier
func (m *MySQLDriver) RightQuote() byte { func (m *MySQLDriver) RightQuote() byte {
return '`' return '`'

View file

@ -19,9 +19,6 @@ type PostgresDriver struct {
dbConn *sql.DB dbConn *sql.DB
} }
// validatedTypes are types that cannot be zero values in the database.
var psqlValidatedTypes = []string{"uuid"}
// NewPostgresDriver takes the database connection details as parameters and // NewPostgresDriver takes the database connection details as parameters and
// returns a pointer to a PostgresDriver object. Note that it is required to // returns a pointer to a PostgresDriver object. Note that it is required to
// call PostgresDriver.Open() and PostgresDriver.Close() to open and close // call PostgresDriver.Open() and PostgresDriver.Close() to open and close
@ -126,7 +123,7 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error)
var columns []bdb.Column var columns []bdb.Column
rows, err := p.dbConn.Query(` rows, err := p.dbConn.Query(`
select column_name, data_type, column_default, is_nullable, select column_name, c.data_type, e.data_type, column_default, c.udt_name, is_nullable,
(select exists( (select exists(
select 1 select 1
from information_schema.constraint_column_usage as ccu from information_schema.constraint_column_usage as ccu
@ -142,8 +139,10 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error)
where where
pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true
)) as is_unique )) as is_unique
from information_schema.columns as c from information_schema.columns as c LEFT JOIN information_schema.element_types e
where table_name=$2 and table_schema = $1; ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier)
= (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier))
where c.table_name=$2 and c.table_schema = $1;
`, schema, tableName) `, schema, tableName)
if err != nil { if err != nil {
@ -152,10 +151,11 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error)
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var colName, colType, colDefault, nullable string var colName, udtName, colType, colDefault, nullable string
var elementType *string
var unique bool var unique bool
var defaultPtr *string var defaultPtr *string
if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil { if err := rows.Scan(&colName, &colType, &elementType, &defaultPtr, &udtName, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
} }
@ -166,12 +166,13 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error)
} }
column := bdb.Column{ column := bdb.Column{
Name: colName, Name: colName,
DBType: colType, DBType: colType,
Default: colDefault, ArrType: elementType,
Nullable: nullable == "YES", UDTName: udtName,
Unique: unique, Default: colDefault,
Validated: psqlIsValidated(colType), Nullable: nullable == "YES",
Unique: unique,
} }
columns = append(columns, column) columns = append(columns, column)
} }
@ -297,6 +298,21 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "null.Bool" c.Type = "null.Bool"
case "date", "time", "timestamp without time zone", "timestamp with time zone": case "date", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "null.Time" c.Type = "null.Time"
case "ARRAY":
if c.ArrType == nil {
panic("unable to get postgres ARRAY underlying type")
}
c.Type = getArrayType(c)
// Make DBType something like ARRAYinteger for parsing with randomize.Struct
c.DBType = c.DBType + *c.ArrType
case "USER-DEFINED":
if c.UDTName == "hstore" {
c.Type = "types.Hstore"
c.DBType = "hstore"
} else {
c.Type = "string"
fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName)
}
default: default:
c.Type = "null.String" c.Type = "null.String"
} }
@ -322,6 +338,18 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "bool" c.Type = "bool"
case "date", "time", "timestamp without time zone", "timestamp with time zone": case "date", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "time.Time" c.Type = "time.Time"
case "ARRAY":
c.Type = getArrayType(c)
// Make DBType something like ARRAYinteger for parsing with randomize.Struct
c.DBType = c.DBType + *c.ArrType
case "USER-DEFINED":
if c.UDTName == "hstore" {
c.Type = "types.Hstore"
c.DBType = "hstore"
} else {
c.Type = "string"
fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName)
}
default: default:
c.Type = "string" c.Type = "string"
} }
@ -330,15 +358,22 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
return c return c
} }
// isValidated checks if the database type is in the validatedTypes list. // getArrayType returns the correct boil.Array type for each database type
func psqlIsValidated(typ string) bool { func getArrayType(c bdb.Column) string {
for _, v := range psqlValidatedTypes { switch *c.ArrType {
if v == typ { case "bigint", "bigserial", "integer", "serial", "smallint", "smallserial":
return true return "types.Int64Array"
} case "bytea":
return "types.BytesArray"
case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
return "types.StringArray"
case "bool":
return "types.BoolArray"
case "decimal", "numeric", "double precision", "real":
return "types.Float64Array"
default:
return "types.GenericArray"
} }
return false
} }
// RightQuote is the quoting character for the right side of the identifier // RightQuote is the quoting character for the right side of the identifier

View file

@ -2,17 +2,20 @@
package randomize package randomize
import ( import (
"database/sql"
"fmt" "fmt"
"math/rand" "math/rand"
"reflect" "reflect"
"regexp" "regexp"
"sort" "sort"
"strconv" "strconv"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
"gopkg.in/nullbio/null.v5" "gopkg.in/nullbio/null.v5"
"github.com/lib/pq/hstore"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/satori/go.uuid" "github.com/satori/go.uuid"
"github.com/vattle/sqlboiler/boil/types" "github.com/vattle/sqlboiler/boil/types"
@ -20,32 +23,39 @@ import (
) )
var ( var (
typeNullFloat32 = reflect.TypeOf(null.Float32{}) typeNullFloat32 = reflect.TypeOf(null.Float32{})
typeNullFloat64 = reflect.TypeOf(null.Float64{}) typeNullFloat64 = reflect.TypeOf(null.Float64{})
typeNullInt = reflect.TypeOf(null.Int{}) typeNullInt = reflect.TypeOf(null.Int{})
typeNullInt8 = reflect.TypeOf(null.Int8{}) typeNullInt8 = reflect.TypeOf(null.Int8{})
typeNullInt16 = reflect.TypeOf(null.Int16{}) typeNullInt16 = reflect.TypeOf(null.Int16{})
typeNullInt32 = reflect.TypeOf(null.Int32{}) typeNullInt32 = reflect.TypeOf(null.Int32{})
typeNullInt64 = reflect.TypeOf(null.Int64{}) typeNullInt64 = reflect.TypeOf(null.Int64{})
typeNullUint = reflect.TypeOf(null.Uint{}) typeNullUint = reflect.TypeOf(null.Uint{})
typeNullUint8 = reflect.TypeOf(null.Uint8{}) typeNullUint8 = reflect.TypeOf(null.Uint8{})
typeNullUint16 = reflect.TypeOf(null.Uint16{}) typeNullUint16 = reflect.TypeOf(null.Uint16{})
typeNullUint32 = reflect.TypeOf(null.Uint32{}) typeNullUint32 = reflect.TypeOf(null.Uint32{})
typeNullUint64 = reflect.TypeOf(null.Uint64{}) typeNullUint64 = reflect.TypeOf(null.Uint64{})
typeNullString = reflect.TypeOf(null.String{}) typeNullString = reflect.TypeOf(null.String{})
typeNullBool = reflect.TypeOf(null.Bool{}) typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{}) typeNullTime = reflect.TypeOf(null.Time{})
typeNullBytes = reflect.TypeOf(null.Bytes{}) typeNullBytes = reflect.TypeOf(null.Bytes{})
typeNullJSON = reflect.TypeOf(null.JSON{}) typeNullJSON = reflect.TypeOf(null.JSON{})
typeTime = reflect.TypeOf(time.Time{}) typeTime = reflect.TypeOf(time.Time{})
typeJSON = reflect.TypeOf(types.JSON{}) typeJSON = reflect.TypeOf(types.JSON{})
rgxValidTime = regexp.MustCompile(`[2-9]+`) typeInt64Array = reflect.TypeOf(types.Int64Array{})
typeBytesArray = reflect.TypeOf(types.BytesArray{})
typeBoolArray = reflect.TypeOf(types.BoolArray{})
typeFloat64Array = reflect.TypeOf(types.Float64Array{})
typeStringArray = reflect.TypeOf(types.StringArray{})
typeGenericArray = reflect.TypeOf(types.GenericArray{})
typeHstore = reflect.TypeOf(types.Hstore{})
rgxValidTime = regexp.MustCompile(`[2-9]+`)
validatedTypes = []string{ validatedTypes = []string{
"inet", "line", "uuid", "interval", "inet", "line", "uuid", "interval",
"json", "jsonb", "box", "cidr", "circle", "json", "jsonb", "box", "cidr", "circle",
"lseg", "macaddr", "path", "pg_lsn", "point", "lseg", "macaddr", "path", "pg_lsn", "point",
"polygon", "txid_snapshot", "money", "polygon", "txid_snapshot", "money", "hstore",
} }
) )
@ -218,7 +228,14 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true) value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true)
field.Set(reflect.ValueOf(value)) field.Set(reflect.ValueOf(value))
return nil return nil
case typeHstore:
value := hstore.Hstore{Map: map[string]sql.NullString{}}
value.Map[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
value.Map[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
field.Set(reflect.ValueOf(value))
return nil
} }
} else { } else {
switch kind { switch kind {
case reflect.String: case reflect.String:
@ -279,6 +296,12 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1))) value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1)))
field.Set(reflect.ValueOf(value)) field.Set(reflect.ValueOf(value))
return nil return nil
case typeHstore:
value := types.Hstore{}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
field.Set(reflect.ValueOf(value))
return nil
} }
} }
} }
@ -293,8 +316,15 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
isNull = false isNull = false
} }
// Retrieve the value to be returned // If it's a Postgres array, treat it like one
if kind == reflect.Struct { if strings.HasPrefix(fieldType, "ARRAY") {
if isNull {
value = getArrayNullValue(typ)
} else {
value = getArrayRandValue(s, typ)
}
// Retrieve the value to be returned
} else if kind == reflect.Struct {
if isNull { if isNull {
value = getStructNullValue(typ) value = getStructNullValue(typ)
} else { } else {
@ -317,6 +347,45 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
return nil return nil
} }
func getArrayNullValue(typ reflect.Type) interface{} {
fmt.Println(typ)
switch typ {
case typeInt64Array:
return types.Int64Array{}
case typeFloat64Array:
return types.Float64Array{}
case typeBoolArray:
return types.BoolArray{}
case typeStringArray:
return types.StringArray{}
case typeBytesArray:
return types.BytesArray{}
case typeGenericArray:
return types.GenericArray{}
}
return nil
}
func getArrayRandValue(s *Seed, typ reflect.Type) interface{} {
switch typ {
case typeInt64Array:
return types.Int64Array{int64(s.nextInt()), int64(s.nextInt())}
case typeFloat64Array:
return types.Float64Array{float64(s.nextInt()), float64(s.nextInt())}
case typeBoolArray:
return types.BoolArray{s.nextInt()%2 == 0, s.nextInt()%2 == 0, s.nextInt()%2 == 0}
case typeStringArray:
return types.StringArray{randStr(s, 4), randStr(s, 4), randStr(s, 4)}
case typeBytesArray:
return types.BytesArray{randByteSlice(s, 4), randByteSlice(s, 4), randByteSlice(s, 4)}
case typeGenericArray:
return types.GenericArray{A: []types.JSON{randJSON(s, 4), randJSON(s, 4), randJSON(s, 4)}}
}
return nil
}
// getStructNullValue for the matching type. // getStructNullValue for the matching type.
func getStructNullValue(typ reflect.Type) interface{} { func getStructNullValue(typ reflect.Type) interface{} {
switch typ { switch typ {
@ -505,6 +574,17 @@ func randByteSlice(s *Seed, ln int) []byte {
return str return str
} }
func randJSON(s *Seed, ln int) types.JSON {
str := make(types.JSON, ln)
str[0] = '"'
for i := 1; i < ln-1; i++ {
str[i] = byte(s.nextInt() % 256)
}
str[ln-1] = '"'
return str
}
func randPoint() string { func randPoint() string {
a := rand.Intn(100) a := rand.Intn(100)
b := a + 1 b := a + 1

863
boil/types/array.go Normal file
View file

@ -0,0 +1,863 @@
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package types
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding/hex"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
var typeByteSlice = reflect.TypeOf([]byte{})
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
func encode(x interface{}) []byte {
switch v := x.(type) {
case int64:
return strconv.AppendInt(nil, v, 10)
case float64:
return strconv.AppendFloat(nil, v, 'f', -1, 64)
case []byte:
return encodeBytes(v)
case string:
return []byte(v)
case bool:
return strconv.AppendBool(nil, v)
case time.Time:
return formatTimestamp(v)
default:
panic(fmt.Errorf("encode: unknown type for %T", v))
}
}
// FormatTimestamp formats t into Postgres' text format for timestamps.
func formatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go.
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
bc := false
if t.Year() <= 0 {
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
t = t.AddDate((-t.Year())*2+1, 0, 0)
bc = true
}
b := []byte(t.Format(time.RFC3339Nano))
_, offset := t.Zone()
offset = offset % 60
if offset != 0 {
// RFC3339Nano already printed the minus sign
if offset < 0 {
offset = -offset
}
b = append(b, ':')
if offset < 10 {
b = append(b, '0')
}
b = strconv.AppendInt(b, int64(offset), 10)
}
if bc {
b = append(b, " BC"...)
}
return b
}
func encodeBytes(v []byte) (result []byte) {
for _, b := range v {
if b == '\\' {
result = append(result, '\\', '\\')
} else if b < 0x20 || b > 0x7e {
result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
} else {
result = append(result, b)
}
}
return result
}
// Parse a bytea value received from the server. Both "hex" and the legacy
// "escape" format are supported.
func parseBytes(s []byte) (result []byte, err error) {
if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
// bytea_output = hex
s = s[2:] // trim off leading "\\x"
result = make([]byte, hex.DecodedLen(len(s)))
_, err := hex.Decode(result, s)
if err != nil {
return nil, err
}
} else {
for len(s) > 0 {
if s[0] == '\\' {
// escaped '\\'
if len(s) >= 2 && s[1] == '\\' {
result = append(result, '\\')
s = s[2:]
continue
}
// '\\' followed by an octal number
if len(s) < 4 {
return nil, fmt.Errorf("invalid bytea sequence %v", s)
}
r, err := strconv.ParseInt(string(s[1:4]), 8, 9)
if err != nil {
return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
}
result = append(result, byte(r))
s = s[4:]
} else {
// We hit an unescaped, raw byte. Try to read in as many as
// possible in one go.
i := bytes.IndexByte(s, '\\')
if i == -1 {
result = append(result, s...)
break
}
result = append(result, s[:i]...)
s = s[i:]
}
}
}
return result, nil
}
// Array returns the optimal driver.Valuer and sql.Scanner for an array or
// slice of any dimension.
//
// For example:
// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
//
// var x []sql.NullInt64
// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x))
//
// Scanning multi-dimensional arrays is not supported. Arrays where the lower
// bound is not one (such as `[0:0]={1}') are not supported.
func Array(a interface{}) interface {
driver.Valuer
sql.Scanner
} {
switch a := a.(type) {
case []bool:
return (*BoolArray)(&a)
case []float64:
return (*Float64Array)(&a)
case []int64:
return (*Int64Array)(&a)
case []string:
return (*StringArray)(&a)
case *[]bool:
return (*BoolArray)(a)
case *[]float64:
return (*Float64Array)(a)
case *[]int64:
return (*Int64Array)(a)
case *[]string:
return (*StringArray)(a)
}
return GenericArray{a}
}
// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner
// to override the array delimiter used by GenericArray.
type ArrayDelimiter interface {
// ArrayDelimiter returns the delimiter character(s) for this element's type.
ArrayDelimiter() string
}
// BoolArray represents a one-dimensional array of the PostgreSQL boolean type.
type BoolArray []bool
// Scan implements the sql.Scanner interface.
func (a *BoolArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("pq: cannot convert %T to BoolArray", src)
}
func (a *BoolArray) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "BoolArray")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(BoolArray, len(elems))
for i, v := range elems {
if len(v) != 1 {
return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
}
switch v[0] {
case 't':
b[i] = true
case 'f':
b[i] = false
default:
return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a BoolArray) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be exactly two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1+2*n)
for i := 0; i < n; i++ {
b[2*i] = ','
if a[i] {
b[1+2*i] = 't'
} else {
b[1+2*i] = 'f'
}
}
b[0] = '{'
b[2*n] = '}'
return string(b), nil
}
return "{}", nil
}
// BytesArray represents a one-dimensional array of the PostgreSQL bytea type.
type BytesArray [][]byte
// Scan implements the sql.Scanner interface.
func (a *BytesArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("pq: cannot convert %T to BytesArray", src)
}
func (a *BytesArray) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "BytesArray")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(BytesArray, len(elems))
for i, v := range elems {
b[i], err = parseBytes(v)
if err != nil {
return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface. It uses the "hex" format which
// is only supported on PostgreSQL 9.0 or newer.
func (a BytesArray) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, 2*N bytes of quotes,
// 3*N bytes of hex formatting, and N-1 bytes of delimiters.
size := 1 + 6*n
for _, x := range a {
size += hex.EncodedLen(len(x))
}
b := make([]byte, size)
for i, s := 0, b; i < n; i++ {
o := copy(s, `,"\\x`)
o += hex.Encode(s[o:], a[i])
s[o] = '"'
s = s[o+1:]
}
b[0] = '{'
b[size-1] = '}'
return string(b), nil
}
return "{}", nil
}
// Float64Array represents a one-dimensional array of the PostgreSQL double
// precision type.
type Float64Array []float64
// Scan implements the sql.Scanner interface.
func (a *Float64Array) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("pq: cannot convert %T to Float64Array", src)
}
func (a *Float64Array) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Float64Array")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Float64Array, len(elems))
for i, v := range elems {
if b[i], err = strconv.ParseFloat(string(v), 64); err != nil {
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Float64Array) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendFloat(b, a[0], 'f', -1, 64)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendFloat(b, a[i], 'f', -1, 64)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// GenericArray implements the driver.Valuer and sql.Scanner interfaces for
// an array or slice of any dimension.
type GenericArray struct{ A interface{} }
func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) {
var assign func([]byte, reflect.Value) error
var del = ","
// TODO calculate the assign function for other types
// TODO repeat this section on the element type of arrays or slices (multidimensional)
{
if reflect.PtrTo(rt).Implements(typeSQLScanner) {
// dest is always addressable because it is an element of a slice.
assign = func(src []byte, dest reflect.Value) (err error) {
ss := dest.Addr().Interface().(sql.Scanner)
if src == nil {
err = ss.Scan(nil)
} else {
err = ss.Scan(src)
}
return
}
goto FoundType
}
assign = func([]byte, reflect.Value) error {
return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt)
}
}
FoundType:
if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok {
del = ad.ArrayDelimiter()
}
return rt, assign, del
}
// Scan implements the sql.Scanner interface.
func (a GenericArray) Scan(src interface{}) error {
dpv := reflect.ValueOf(a.A)
switch {
case dpv.Kind() != reflect.Ptr:
return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
case dpv.IsNil():
return fmt.Errorf("pq: destination %T is nil", a.A)
}
dv := dpv.Elem()
switch dv.Kind() {
case reflect.Slice:
case reflect.Array:
default:
return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
}
switch src := src.(type) {
case []byte:
return a.scanBytes(src, dv)
case string:
return a.scanBytes([]byte(src), dv)
}
return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type())
}
func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error {
dtype, assign, del := a.evaluateDestination(dv.Type().Elem())
dims, elems, err := parseArray(src, []byte(del))
if err != nil {
return err
}
// TODO allow multidimensional
if len(dims) > 1 {
return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented",
strings.Replace(fmt.Sprint(dims), " ", "][", -1))
}
// Treat a zero-dimensional array like an array with a single dimension of zero.
if len(dims) == 0 {
dims = append(dims, 0)
}
for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() {
switch rt.Kind() {
case reflect.Slice:
case reflect.Array:
if rt.Len() != dims[i] {
return fmt.Errorf("pq: cannot convert ARRAY%s to %s",
strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type())
}
default:
// TODO handle multidimensional
}
}
values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems))
for i, e := range elems {
if err := assign(e, values.Index(i)); err != nil {
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
}
}
// TODO handle multidimensional
switch dv.Kind() {
case reflect.Slice:
dv.Set(values.Slice(0, dims[0]))
case reflect.Array:
for i := 0; i < dims[0]; i++ {
dv.Index(i).Set(values.Index(i))
}
}
return nil
}
// Value implements the driver.Valuer interface.
func (a GenericArray) Value() (driver.Value, error) {
if a.A == nil {
return nil, nil
}
rv := reflect.ValueOf(a.A)
if k := rv.Kind(); k != reflect.Array && k != reflect.Slice {
return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A)
}
if n := rv.Len(); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 0, 1+2*n)
b, _, err := appendArray(b, rv, n)
return string(b), err
}
return "{}", nil
}
// Int64Array represents a one-dimensional array of the PostgreSQL integer types.
type Int64Array []int64
// Scan implements the sql.Scanner interface.
func (a *Int64Array) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("pq: cannot convert %T to Int64Array", src)
}
func (a *Int64Array) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Int64Array")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Int64Array, len(elems))
for i, v := range elems {
if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil {
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Int64Array) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendInt(b, a[0], 10)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendInt(b, a[i], 10)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// StringArray represents a one-dimensional array of the PostgreSQL character types.
type StringArray []string
// Scan implements the sql.Scanner interface.
func (a *StringArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("pq: cannot convert %T to StringArray", src)
}
func (a *StringArray) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "StringArray")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(StringArray, len(elems))
for i, v := range elems {
if b[i] = string(v); v == nil {
return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a StringArray) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, 2*N bytes of quotes,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+3*n)
b[0] = '{'
b = appendArrayQuotedBytes(b, []byte(a[0]))
for i := 1; i < n; i++ {
b = append(b, ',')
b = appendArrayQuotedBytes(b, []byte(a[i]))
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// appendArray appends rv to the buffer, returning the extended buffer and
// the delimiter used between elements.
//
// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice.
func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
var del string
var err error
b = append(b, '{')
if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil {
return b, del, err
}
for i := 1; i < n; i++ {
b = append(b, del...)
if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil {
return b, del, err
}
}
return append(b, '}'), del, nil
}
// appendArrayElement appends rv to the buffer, returning the extended buffer
// and the delimiter to use before the next element.
//
// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted
// using driver.DefaultParameterConverter and the resulting []byte or string
// is double-quoted.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
if k := rv.Kind(); k == reflect.Array || k == reflect.Slice {
if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
if n := rv.Len(); n > 0 {
return appendArray(b, rv, n)
}
return b, "", nil
}
}
var del = ","
var err error
var iv interface{} = rv.Interface()
if ad, ok := iv.(ArrayDelimiter); ok {
del = ad.ArrayDelimiter()
}
if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil {
return b, del, err
}
switch v := iv.(type) {
case nil:
return append(b, "NULL"...), del, nil
case []byte:
return appendArrayQuotedBytes(b, v), del, nil
case string:
return appendArrayQuotedBytes(b, []byte(v)), del, nil
}
b, err = appendValue(b, iv)
return b, del, err
}
func appendArrayQuotedBytes(b, v []byte) []byte {
b = append(b, '"')
for {
i := bytes.IndexAny(v, `"\`)
if i < 0 {
b = append(b, v...)
break
}
if i > 0 {
b = append(b, v[:i]...)
}
b = append(b, '\\', v[i])
v = v[i+1:]
}
return append(b, '"')
}
func appendValue(b []byte, v driver.Value) ([]byte, error) {
return append(b, encode(v)...), nil
}
// parseArray extracts the dimensions and elements of an array represented in
// text format. Only representations emitted by the backend are supported.
// Notably, whitespace around brackets and delimiters is significant, and NULL
// is case-sensitive.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
var depth, i int
if len(src) < 1 || src[0] != '{' {
return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0)
}
Open:
for i < len(src) {
switch src[i] {
case '{':
depth++
i++
case '}':
elems = make([][]byte, 0)
goto Close
default:
break Open
}
}
dims = make([]int, i)
Element:
for i < len(src) {
switch src[i] {
case '{':
depth++
dims[depth-1] = 0
i++
case '"':
var elem = []byte{}
var escape bool
for i++; i < len(src); i++ {
if escape {
elem = append(elem, src[i])
escape = false
} else {
switch src[i] {
default:
elem = append(elem, src[i])
case '\\':
escape = true
case '"':
elems = append(elems, elem)
i++
break Element
}
}
}
default:
for start := i; i < len(src); i++ {
if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
elem := src[start:i]
if len(elem) == 0 {
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
}
if bytes.Equal(elem, []byte("NULL")) {
elem = nil
}
elems = append(elems, elem)
break Element
}
}
}
}
for i < len(src) {
if bytes.HasPrefix(src[i:], del) {
dims[depth-1]++
i += len(del)
goto Element
} else if src[i] == '}' {
dims[depth-1]++
depth--
i++
} else {
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
Close:
for i < len(src) {
if src[i] == '}' && depth > 0 {
depth--
i++
} else {
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
if depth > 0 {
err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i)
}
if err == nil {
for _, d := range dims {
if (len(elems) % d) != 0 {
err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions")
}
}
}
return
}
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
dims, elems, err := parseArray(src, del)
if err != nil {
return nil, err
}
if len(dims) > 1 {
return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
}
return elems, err
}

1125
boil/types/array_test.go Normal file

File diff suppressed because it is too large Load diff

135
boil/types/hstore.go Normal file
View file

@ -0,0 +1,135 @@
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package types
import (
"database/sql"
"database/sql/driver"
"strings"
)
// Hstore is a wrapper for transferring Hstore values back and forth easily.
type Hstore map[string]sql.NullString
// escapes and quotes hstore keys/values
// s should be a sql.NullString or string
func hQuote(s interface{}) string {
var str string
switch v := s.(type) {
case sql.NullString:
if !v.Valid {
return "NULL"
}
str = v.String
case string:
str = v
default:
panic("not a string or sql.NullString")
}
str = strings.Replace(str, "\\", "\\\\", -1)
return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"`
}
// Scan implements the Scanner interface.
//
// Note h is reallocated before the scan to clear existing values. If the
// hstore column's database value is NULL, then h is set to nil instead.
func (h *Hstore) Scan(value interface{}) error {
if value == nil {
h = nil
return nil
}
*h = make(map[string]sql.NullString)
var b byte
pair := [][]byte{{}, {}}
pi := 0
inQuote := false
didQuote := false
sawSlash := false
bindex := 0
for bindex, b = range value.([]byte) {
if sawSlash {
pair[pi] = append(pair[pi], b)
sawSlash = false
continue
}
switch b {
case '\\':
sawSlash = true
continue
case '"':
inQuote = !inQuote
if !didQuote {
didQuote = true
}
continue
default:
if !inQuote {
switch b {
case ' ', '\t', '\n', '\r':
continue
case '=':
continue
case '>':
pi = 1
didQuote = false
continue
case ',':
s := string(pair[1])
if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" {
(*h)[string(pair[0])] = sql.NullString{String: "", Valid: false}
} else {
(*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true}
}
pair[0] = []byte{}
pair[1] = []byte{}
pi = 0
continue
}
}
}
pair[pi] = append(pair[pi], b)
}
if bindex > 0 {
s := string(pair[1])
if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" {
(*h)[string(pair[0])] = sql.NullString{String: "", Valid: false}
} else {
(*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true}
}
}
return nil
}
// Value implements the driver Valuer interface. Note if h is nil, the
// database column value will be set to NULL.
func (h Hstore) Value() (driver.Value, error) {
if h == nil {
return nil, nil
}
parts := []string{}
for key, val := range h {
thispart := hQuote(key) + "=>" + hQuote(val)
parts = append(parts, thispart)
}
return []byte(strings.Join(parts, ",")), nil
}

View file

@ -302,4 +302,22 @@ var importsBasedOnType = map[string]imports{
"types.JSON": { "types.JSON": {
thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`},
}, },
"types.BytesArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`},
},
"types.GenericArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`},
},
"types.Int64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`},
},
"types.Float64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`},
},
"types.BoolArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`},
},
"types.Hstore": {
thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`},
},
} }