Add driver support for enums

This commit is contained in:
Aaron L 2016-11-09 23:06:09 -08:00
parent 8d68f936e5
commit cb6de17ea6
2 changed files with 71 additions and 41 deletions

View file

@ -121,7 +121,11 @@ func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
var columns []bdb.Column
rows, err := m.dbConn.Query(`
select column_name, data_type, if(extra = 'auto_increment','auto_increment', column_default), is_nullable,
select
c.column_name,
if(c.data_type = 'enum', c.column_type, c.data_type),
if(extra = 'auto_increment','auto_increment', c.column_default),
c.is_nullable = 'YES',
exists (
select c.column_name
from information_schema.table_constraints tc
@ -140,24 +144,23 @@ func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
defer rows.Close()
for rows.Next() {
var colName, colType, colDefault, nullable string
var unique bool
var defaultPtr *string
if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil {
var colName, colType string
var nullable, unique bool
var defaultValue *string
if err := rows.Scan(&colName, &colType, &defaultValue, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
}
if defaultPtr != nil && *defaultPtr != "NULL" {
colDefault = *defaultPtr
}
column := bdb.Column{
Name: colName,
DBType: colType,
Default: colDefault,
Nullable: nullable == "YES",
Nullable: nullable,
Unique: unique,
}
if defaultValue != nil && *defaultValue != "NULL" {
column.Default = *defaultValue
}
columns = append(columns, column)
}

View file

@ -6,6 +6,7 @@ import (
"strings"
// Side-effect import sql driver
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/bdb"
@ -123,7 +124,35 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error)
var columns []bdb.Column
rows, err := p.dbConn.Query(`
select column_name, c.data_type, e.data_type, column_default, c.udt_name, is_nullable,
select
c.column_name,
(
case when c.data_type = 'USER-DEFINED' and c.udt_name <> 'hstore'
then
(
select 'enum(''' || string_agg(labels.label, ''',''') || ''')'
from (
select pg_enum.enumlabel as label
from pg_enum
where pg_enum.enumtypid =
(
select typelem
from pg_type
where pg_type.typtype = 'b' and pg_type.typname = ('_' || c.udt_name)
limit 1
)
order by pg_enum.enumsortorder
) as labels
)
else c.data_type
end
) as column_type,
c.udt_name,
e.data_type as array_type,
c.column_default,
c.is_nullable = 'YES' as is_nullable,
(select exists(
select 1
from information_schema.constraint_column_usage as ccu
@ -139,8 +168,10 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error)
where
pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true
)) as is_unique
from information_schema.columns as c LEFT JOIN information_schema.element_types e
ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier)
from information_schema.columns as c
left join information_schema.element_types e
on ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier)
= (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier))
where c.table_name=$2 and c.table_schema = $1;
`, schema, tableName)
@ -151,29 +182,25 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error)
defer rows.Close()
for rows.Next() {
var colName, udtName, colType, colDefault, nullable string
var elementType *string
var unique bool
var defaultPtr *string
if err := rows.Scan(&colName, &colType, &elementType, &defaultPtr, &udtName, &nullable, &unique); err != nil {
var colName, colType, udtName string
var defaultValue, arrayType *string
var nullable, unique bool
if err := rows.Scan(&colName, &colType, &udtName, &arrayType, &defaultValue, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
}
if defaultPtr == nil {
colDefault = ""
} else {
colDefault = *defaultPtr
}
column := bdb.Column{
Name: colName,
DBType: colType,
ArrType: elementType,
ArrType: arrayType,
UDTName: udtName,
Default: colDefault,
Nullable: nullable == "YES",
Nullable: nullable,
Unique: unique,
}
if defaultValue != nil {
column.Default = *defaultValue
}
columns = append(columns, column)
}