sqlboiler/bdb/drivers/postgres.go
Aaron L 5f7bee14a0 Relax the definition of unique on columns (psql)
- Fix formatting on giant postgres query
- Invert the section that looks for unique constraints to look for a
  constraint and join to the constraint_column_usage to be more in line
  with the mysql query and also it makes a little bit more sense.
- Add more checks to ensure the schema is being enforced in the postgres
  side of things as several pieces in the unique checks were missing it
  which would lead to false positives.
- Fix #77 for postgres
2017-01-03 19:43:01 -08:00

424 lines
12 KiB
Go

package drivers
import (
"database/sql"
"fmt"
"strings"
// Side-effect import sql driver
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/bdb"
"github.com/vattle/sqlboiler/strmangle"
)
// PostgresDriver holds the database connection string and a handle
// to the database connection.
type PostgresDriver struct {
connStr string
dbConn *sql.DB
}
// NewPostgresDriver takes the database connection details as parameters and
// returns a pointer to a PostgresDriver object. Note that it is required to
// call PostgresDriver.Open() and PostgresDriver.Close() to open and close
// the database connection once an object has been obtained.
func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
driver := PostgresDriver{
connStr: PostgresBuildQueryString(user, pass, dbname, host, port, sslmode),
}
return &driver
}
// PostgresBuildQueryString builds a query string.
func PostgresBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
parts := []string{}
if len(user) != 0 {
parts = append(parts, fmt.Sprintf("user=%s", user))
}
if len(pass) != 0 {
parts = append(parts, fmt.Sprintf("password=%s", pass))
}
if len(dbname) != 0 {
parts = append(parts, fmt.Sprintf("dbname=%s", dbname))
}
if len(host) != 0 {
parts = append(parts, fmt.Sprintf("host=%s", host))
}
if port != 0 {
parts = append(parts, fmt.Sprintf("port=%d", port))
}
if len(sslmode) != 0 {
parts = append(parts, fmt.Sprintf("sslmode=%s", sslmode))
}
return strings.Join(parts, " ")
}
// Open opens the database connection using the connection string
func (p *PostgresDriver) Open() error {
var err error
p.dbConn, err = sql.Open("postgres", p.connStr)
if err != nil {
return err
}
return nil
}
// Close closes the database connection
func (p *PostgresDriver) Close() {
p.dbConn.Close()
}
// UseLastInsertID returns false for postgres
func (p *PostgresDriver) UseLastInsertID() bool {
return false
}
// TableNames connects to the postgres database and
// retrieves all table names from the information_schema where the
// table schema is schema. It uses a whitelist and blacklist.
func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
var names []string
query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = $1`)
args := []interface{}{schema}
if len(whitelist) > 0 {
query += fmt.Sprintf(" and table_name in (%s);", strmangle.Placeholders(true, len(whitelist), 2, 1))
for _, w := range whitelist {
args = append(args, w)
}
} else if len(blacklist) > 0 {
query += fmt.Sprintf(" and table_name not in (%s);", strmangle.Placeholders(true, len(blacklist), 2, 1))
for _, b := range blacklist {
args = append(args, b)
}
}
rows, err := p.dbConn.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, err
}
names = append(names, name)
}
return names, nil
}
// Columns takes a table name and attempts to retrieve the table information
// from the database information_schema.columns. It retrieves the column names
// and column types and returns those as a []Column after TranslateColumnType()
// converts the SQL types to Go types, for example: "varchar" to "string"
func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
var columns []bdb.Column
rows, err := p.dbConn.Query(`
select
c.column_name,
(
case when c.data_type = 'USER-DEFINED' and c.udt_name <> 'hstore'
then
(
select 'enum.' || c.udt_name || '(''' || string_agg(labels.label, ''',''') || ''')'
from (
select pg_enum.enumlabel as label
from pg_enum
where pg_enum.enumtypid =
(
select typelem
from pg_type
where pg_type.typtype = 'b' and pg_type.typname = ('_' || c.udt_name)
limit 1
)
order by pg_enum.enumsortorder
) as labels
)
else c.data_type
end
) as column_type,
c.udt_name,
e.data_type as array_type,
c.column_default,
c.is_nullable = 'YES' as is_nullable,
(select exists(
select 1
from information_schema.table_constraints tc
inner join information_schema.constraint_column_usage as ccu on tc.constraint_name = ccu.constraint_name
where tc.table_schema = $1 and tc.constraint_type = 'UNIQUE' and ccu.constraint_schema = $1 and ccu.table_name = c.table_name and ccu.column_name = c.column_name and
(select count(*) from information_schema.constraint_column_usage where constraint_schema = $1 and constraint_name = tc.constraint_name) = 1
)) OR
(select exists(
select 1
from pg_indexes pgix
inner join pg_class pgc on pgix.indexname = pgc.relname and pgc.relkind = 'i' and pgc.relnatts = 1
inner join pg_index pgi on pgi.indexrelid = pgc.oid
inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey)
where
pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true
)) as is_unique
from information_schema.columns as c
left join information_schema.element_types e
on ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier)
= (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier))
where c.table_name = $2 and c.table_schema = $1;
`, schema, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var colName, colType, udtName string
var defaultValue, arrayType *string
var nullable, unique bool
if err := rows.Scan(&colName, &colType, &udtName, &arrayType, &defaultValue, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
}
column := bdb.Column{
Name: colName,
DBType: colType,
ArrType: arrayType,
UDTName: udtName,
Nullable: nullable,
Unique: unique,
}
if defaultValue != nil {
column.Default = *defaultValue
}
columns = append(columns, column)
}
return columns, nil
}
// PrimaryKeyInfo looks up the primary key for a table.
func (p *PostgresDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) {
pkey := &bdb.PrimaryKey{}
var err error
query := `
select tc.constraint_name
from information_schema.table_constraints as tc
where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = $2;`
row := p.dbConn.QueryRow(query, tableName, schema)
if err = row.Scan(&pkey.Name); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
queryColumns := `
select kcu.column_name
from information_schema.key_column_usage as kcu
where constraint_name = $1 and table_schema = $2;`
var rows *sql.Rows
if rows, err = p.dbConn.Query(queryColumns, pkey.Name, schema); err != nil {
return nil, err
}
defer rows.Close()
var columns []string
for rows.Next() {
var column string
err = rows.Scan(&column)
if err != nil {
return nil, err
}
columns = append(columns, column)
}
if err = rows.Err(); err != nil {
return nil, err
}
pkey.Columns = columns
return pkey, nil
}
// ForeignKeyInfo retrieves the foreign keys for a given table name.
func (p *PostgresDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) {
var fkeys []bdb.ForeignKey
query := `
select
tc.constraint_name,
kcu.table_name as source_table,
kcu.column_name as source_column,
ccu.table_name as dest_table,
ccu.column_name as dest_column
from information_schema.table_constraints as tc
inner join information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name and tc.constraint_schema = kcu.constraint_schema
inner join information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name and tc.constraint_schema = ccu.constraint_schema
where tc.table_name = $1 and tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = $2;`
var rows *sql.Rows
var err error
if rows, err = p.dbConn.Query(query, tableName, schema); err != nil {
return nil, err
}
for rows.Next() {
var fkey bdb.ForeignKey
var sourceTable string
fkey.Table = tableName
err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn)
if err != nil {
return nil, err
}
fkeys = append(fkeys, fkey)
}
if err = rows.Err(); err != nil {
return nil, err
}
return fkeys, nil
}
// TranslateColumnType converts postgres database types to Go types, for example
// "varchar" to "string" and "bigint" to "int64". It returns this parsed data
// as a Column object.
func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
if c.Nullable {
switch c.DBType {
case "bigint", "bigserial":
c.Type = "null.Int64"
case "integer", "serial":
c.Type = "null.Int"
case "smallint", "smallserial":
c.Type = "null.Int16"
case "decimal", "numeric", "double precision":
c.Type = "null.Float64"
case "real":
c.Type = "null.Float32"
case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "null.String"
case `"char"`:
c.Type = "null.Byte"
case "bytea":
c.Type = "null.Bytes"
case "json", "jsonb":
c.Type = "null.JSON"
case "boolean":
c.Type = "null.Bool"
case "date", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "null.Time"
case "ARRAY":
if c.ArrType == nil {
panic("unable to get postgres ARRAY underlying type")
}
c.Type = getArrayType(c)
// Make DBType something like ARRAYinteger for parsing with randomize.Struct
c.DBType = c.DBType + *c.ArrType
case "USER-DEFINED":
if c.UDTName == "hstore" {
c.Type = "types.HStore"
c.DBType = "hstore"
} else {
c.Type = "string"
fmt.Printf("Warning: Incompatible data type detected: %s\n", c.UDTName)
}
default:
c.Type = "null.String"
}
} else {
switch c.DBType {
case "bigint", "bigserial":
c.Type = "int64"
case "integer", "serial":
c.Type = "int"
case "smallint", "smallserial":
c.Type = "int16"
case "decimal", "numeric", "double precision":
c.Type = "float64"
case "real":
c.Type = "float32"
case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "string"
case `"char"`:
c.Type = "types.Byte"
case "json", "jsonb":
c.Type = "types.JSON"
case "bytea":
c.Type = "[]byte"
case "boolean":
c.Type = "bool"
case "date", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "time.Time"
case "ARRAY":
c.Type = getArrayType(c)
// Make DBType something like ARRAYinteger for parsing with randomize.Struct
c.DBType = c.DBType + *c.ArrType
case "USER-DEFINED":
if c.UDTName == "hstore" {
c.Type = "types.HStore"
c.DBType = "hstore"
} else {
c.Type = "string"
fmt.Printf("Warning: Incompatible data type detected: %s\n", c.UDTName)
}
default:
c.Type = "string"
}
}
return c
}
// getArrayType returns the correct boil.Array type for each database type
func getArrayType(c bdb.Column) string {
switch *c.ArrType {
case "bigint", "bigserial", "integer", "serial", "smallint", "smallserial":
return "types.Int64Array"
case "bytea":
return "types.BytesArray"
case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
return "types.StringArray"
case "boolean":
return "types.BoolArray"
case "decimal", "numeric", "double precision", "real":
return "types.Float64Array"
default:
return "types.StringArray"
}
}
// RightQuote is the quoting character for the right side of the identifier
func (p *PostgresDriver) RightQuote() byte {
return '"'
}
// LeftQuote is the quoting character for the left side of the identifier
func (p *PostgresDriver) LeftQuote() byte {
return '"'
}
// IndexPlaceholders returns true to indicate PSQL supports indexed placeholders
func (p *PostgresDriver) IndexPlaceholders() bool {
return true
}