Merge branch 'dev'

This commit is contained in:
Aaron L 2016-11-12 00:08:09 -08:00
commit ca748b070d
25 changed files with 701 additions and 93 deletions

View file

@ -70,10 +70,11 @@ Table of Contents
* [Upsert](#upsert)
* [Reload](#reload)
* [Exists](#exists)
* [Enums](#enums)
* [FAQ](#faq)
* [Won't compiling models for a huge database be very slow?](#wont-compiling-models-for-a-huge-database-be-very-slow)
* [Missing imports for generated package](#missing-imports-for-generated-package)
* [Benchmarks](#benchmarks)
* [Benchmarks](#benchmarks)
## About SQL Boiler
@ -97,6 +98,7 @@ Table of Contents
- Debug logging
- Schemas support
- 1d arrays, json, hstore & more
- Enum types
### Supported Databases
@ -121,8 +123,8 @@ if err != nil {
return err
}
// If you don't want to pass in db to all generated methods
// you can use boil.SetDB to set it globally, and then use
// If you don't want to pass in db to all generated methods
// you can use boil.SetDB to set it globally, and then use
// the G variant methods like so:
boil.SetDB(db)
users, err := models.UsersG().All()
@ -178,7 +180,7 @@ fmt.Println(len(users.R.FavoriteMovies))
* Go 1.6 minimum, and Go 1.7 for compatibility tests.
* Table names and column names should use `snake_case` format.
* We require `snake_case` table names and column names. This is a recommended default in Postgres,
* We require `snake_case` table names and column names. This is a recommended default in Postgres,
and we agree that it's good form, so we're enforcing this format for all drivers for the time being.
* Join tables should use a *composite primary key*.
* For join tables to be used transparently for relationships your join table must have
@ -1048,6 +1050,43 @@ exists, err := jet.Pilot(db).Exists()
exists, err := models.Pilots(db, Where("id=?", 5)).Exists()
```
### Enums
If your MySQL or Postgres tables use enums we will generate constants that hold their values
that you can use in your queries. For example:
```
CREATE TYPE workday AS ENUM('monday', 'tuesday', 'wednesday', 'thursday', 'friday');
CREATE TABLE event_one (
id serial PRIMARY KEY NOT NULL,
name VARCHAR(255),
day workday NOT NULL
);
```
An enum type defined like the above, being used by a table, will generate the following enums:
```go
const (
WorkdayMonday = "monday"
WorkdayTuesday = "tuesday"
WorkdayWednesday = "wednesday"
WorkdayThursday = "thursday"
WorkdayFriday = "friday"
)
```
For Postgres we use `enum type name + title cased` value to generate the const variable name.
For MySQL we use `table name + column name + title cased value` to generate the const variable name.
Note: If your enum holds a value we cannot parse correctly due, to non-alphabet characters for example,
it may not be generated. In this event, you will receive errors in your generated tests because
the value randomizer in the test suite does not know how to generate valid enum values. You will
still be able to use your generated library, and it will still work as expected, but the only way
to get the tests to pass in this event is to either use a parsable enum value or use a regular column
instead of an enum.
## FAQ
#### Won't compiling models for a huge database be very slow?

View file

@ -1,6 +1,10 @@
package bdb
import "github.com/vattle/sqlboiler/strmangle"
import (
"strings"
"github.com/vattle/sqlboiler/strmangle"
)
// Column holds information about a database column.
// Types are Go types, converted by TranslateColumnType.
@ -54,3 +58,16 @@ func FilterColumnsByDefault(defaults bool, columns []Column) []Column {
return cols
}
// FilterColumnsByEnum generates the list of columns that are enum values.
func FilterColumnsByEnum(columns []Column) []Column {
var cols []Column
for _, c := range columns {
if strings.HasPrefix(c.DBType, "enum") {
cols = append(cols, c)
}
}
return cols
}

View file

@ -66,3 +66,23 @@ func TestFilterColumnsByDefault(t *testing.T) {
t.Errorf("Invalid result: %#v", res)
}
}
func TestFilterColumnsByEnum(t *testing.T) {
t.Parallel()
cols := []Column{
{Name: "col1", DBType: "enum('hello')"},
{Name: "col2", DBType: "enum('hello','there')"},
{Name: "col3", DBType: "enum"},
{Name: "col4", DBType: ""},
{Name: "col5", DBType: "int"},
}
res := FilterColumnsByEnum(cols)
if res[0].Name != `col1` {
t.Errorf("Invalid result: %#v", res)
}
if res[1].Name != `col2` {
t.Errorf("Invalid result: %#v", res)
}
}

View file

@ -121,7 +121,11 @@ func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
var columns []bdb.Column
rows, err := m.dbConn.Query(`
select column_name, data_type, if(extra = 'auto_increment','auto_increment', column_default), is_nullable,
select
c.column_name,
if(c.data_type = 'enum', c.column_type, c.data_type),
if(extra = 'auto_increment','auto_increment', c.column_default),
c.is_nullable = 'YES',
exists (
select c.column_name
from information_schema.table_constraints tc
@ -140,24 +144,23 @@ func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
defer rows.Close()
for rows.Next() {
var colName, colType, colDefault, nullable string
var unique bool
var defaultPtr *string
if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil {
var colName, colType string
var nullable, unique bool
var defaultValue *string
if err := rows.Scan(&colName, &colType, &defaultValue, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
}
if defaultPtr != nil && *defaultPtr != "NULL" {
colDefault = *defaultPtr
}
column := bdb.Column{
Name: colName,
DBType: colType,
Default: colDefault,
Nullable: nullable == "YES",
Nullable: nullable,
Unique: unique,
}
if defaultValue != nil && *defaultValue != "NULL" {
column.Default = *defaultValue
}
columns = append(columns, column)
}

View file

@ -6,6 +6,7 @@ import (
"strings"
// Side-effect import sql driver
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/bdb"
@ -123,25 +124,55 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error)
var columns []bdb.Column
rows, err := p.dbConn.Query(`
select column_name, c.data_type, e.data_type, column_default, c.udt_name, is_nullable,
(select exists(
select 1
select
c.column_name,
(
case when c.data_type = 'USER-DEFINED' and c.udt_name <> 'hstore'
then
(
select 'enum.' || c.udt_name || '(''' || string_agg(labels.label, ''',''') || ''')'
from (
select pg_enum.enumlabel as label
from pg_enum
where pg_enum.enumtypid =
(
select typelem
from pg_type
where pg_type.typtype = 'b' and pg_type.typname = ('_' || c.udt_name)
limit 1
)
order by pg_enum.enumsortorder
) as labels
)
else c.data_type
end
) as column_type,
c.udt_name,
e.data_type as array_type,
c.column_default,
c.is_nullable = 'YES' as is_nullable,
(select exists(
select 1
from information_schema.constraint_column_usage as ccu
inner join information_schema.table_constraints tc on ccu.constraint_name = tc.constraint_name
where ccu.table_name = c.table_name and ccu.column_name = c.column_name and tc.constraint_type = 'UNIQUE'
inner join information_schema.table_constraints tc on ccu.constraint_name = tc.constraint_name
where ccu.table_name = c.table_name and ccu.column_name = c.column_name and tc.constraint_type = 'UNIQUE'
)) OR (select exists(
select 1
from
pg_indexes pgix
inner join pg_class pgc on pgix.indexname = pgc.relname and pgc.relkind = 'i'
inner join pg_index pgi on pgi.indexrelid = pgc.oid
inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey)
where
pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true
select 1
from
pg_indexes pgix
inner join pg_class pgc on pgix.indexname = pgc.relname and pgc.relkind = 'i'
inner join pg_index pgi on pgi.indexrelid = pgc.oid
inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey)
where
pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true
)) as is_unique
from information_schema.columns as c LEFT JOIN information_schema.element_types e
ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier)
= (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier))
from information_schema.columns as c
left join information_schema.element_types e
on ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier)
= (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier))
where c.table_name=$2 and c.table_schema = $1;
`, schema, tableName)
@ -151,29 +182,25 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error)
defer rows.Close()
for rows.Next() {
var colName, udtName, colType, colDefault, nullable string
var elementType *string
var unique bool
var defaultPtr *string
if err := rows.Scan(&colName, &colType, &elementType, &defaultPtr, &udtName, &nullable, &unique); err != nil {
var colName, colType, udtName string
var defaultValue, arrayType *string
var nullable, unique bool
if err := rows.Scan(&colName, &colType, &udtName, &arrayType, &defaultValue, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
}
if defaultPtr == nil {
colDefault = ""
} else {
colDefault = *defaultPtr
}
column := bdb.Column{
Name: colName,
DBType: colType,
ArrType: elementType,
ArrType: arrayType,
UDTName: udtName,
Default: colDefault,
Nullable: nullable == "YES",
Nullable: nullable,
Unique: unique,
}
if defaultValue != nil {
column.Default = *defaultValue
}
columns = append(columns, column)
}
@ -290,6 +317,8 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "null.Float32"
case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "null.String"
case `"char"`:
c.Type = "null.Byte"
case "bytea":
c.Type = "null.Bytes"
case "json", "jsonb":
@ -330,6 +359,8 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "float32"
case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "string"
case `"char"`:
c.Type = "types.Byte"
case "json", "jsonb":
c.Type = "types.JSON"
case "bytea":

View file

@ -1,8 +1,7 @@
test:
pre:
- mkdir -p /home/ubuntu/.go_workspace/src/github.com/jstemmer
- git clone git@github.com:nullbio/go-junit-report.git /home/ubuntu/.go_workspace/src/github.com/jstemmer/go-junit-report
- go install github.com/jstemmer/go-junit-report
- go get -u github.com/jstemmer/go-junit-report
- echo -e "[postgres]\nhost=\"localhost\"\nport=5432\nuser=\"ubuntu\"\ndbname=\"sqlboiler\"\n[mysql]\nhost=\"localhost\"\nport=3306\nuser=\"ubuntu\"\ndbname=\"sqlboiler\"\nsslmode=\"false\"" > sqlboiler.toml
- createdb -U ubuntu sqlboiler
- psql -U ubuntu sqlboiler < ./testdata/postgres_test_schema.sql

View file

@ -274,55 +274,55 @@ var defaultTestMainImports = map[string]imports{
// TranslateColumnType to see the type assignments.
var importsBasedOnType = map[string]imports{
"null.Float32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Float64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.String": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Bool": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Time": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.JSON": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Bytes": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"time.Time": {
standard: importList{`"time"`},

View file

@ -75,7 +75,7 @@ func TestCombineTypeImports(t *testing.T) {
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v5"`,
`"gopkg.in/nullbio/null.v6"`,
},
}
@ -108,7 +108,7 @@ func TestCombineTypeImports(t *testing.T) {
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v5"`,
`"gopkg.in/nullbio/null.v6"`,
},
}
@ -124,7 +124,7 @@ func TestCombineImports(t *testing.T) {
a := imports{
standard: importList{"fmt"},
thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v5"},
thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v6"},
}
b := imports{
standard: importList{"os"},
@ -136,8 +136,8 @@ func TestCombineImports(t *testing.T) {
if c.standard[0] != "fmt" && c.standard[1] != "os" {
t.Errorf("Wanted: fmt, os got: %#v", c.standard)
}
if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v5" {
t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v5 got: %#v", c.thirdParty)
if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v6" {
t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v6 got: %#v", c.thirdParty)
}
}

View file

@ -125,6 +125,7 @@ func preRun(cmd *cobra.Command, args []string) error {
OutFolder: viper.GetString("output"),
Schema: viper.GetString("schema"),
PkgName: viper.GetString("pkgname"),
BaseDir: viper.GetString("basedir"),
Debug: viper.GetBool("debug"),
NoTests: viper.GetBool("no-tests"),
NoHooks: viper.GetBool("no-hooks"),

View file

@ -71,7 +71,7 @@ func eagerLoad(exec boil.Executor, toLoad []string, obj interface{}, bkind bindK
// - t is not considered here, and is always passed nil. The function exists on a loaded
// struct to avoid a circular dependency with boil, and the receiver is ignored.
// - exec is used to perform additional queries that might be required for loading the relationships.
// - singular is passed in to identify whether or not this was a single object
// - bkind is passed in to identify whether or not this was a single object
// or a slice that must be loaded into.
// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind.
//

View file

@ -5,7 +5,7 @@ import (
"testing"
"time"
"gopkg.in/nullbio/null.v5"
null "gopkg.in/nullbio/null.v6"
)
type testObj struct {

View file

@ -4,6 +4,7 @@ package randomize
import (
"database/sql"
"fmt"
"math/rand"
"reflect"
"regexp"
"sort"
@ -12,7 +13,7 @@ import (
"sync/atomic"
"time"
"gopkg.in/nullbio/null.v5"
null "gopkg.in/nullbio/null.v6"
"github.com/pkg/errors"
"github.com/satori/go.uuid"
@ -34,6 +35,7 @@ var (
typeNullUint32 = reflect.TypeOf(null.Uint32{})
typeNullUint64 = reflect.TypeOf(null.Uint64{})
typeNullString = reflect.TypeOf(null.String{})
typeNullByte = reflect.TypeOf(null.Byte{})
typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{})
typeNullBytes = reflect.TypeOf(null.Bytes{})
@ -156,9 +158,26 @@ func randDate(s *Seed) time.Time {
// If canBeNull is true:
// The value has the possibility of being null or non-zero at random.
func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bool) error {
kind := field.Kind()
typ := field.Type()
if strings.HasPrefix(fieldType, "enum") {
enum, err := randEnumValue(fieldType)
if err != nil {
return err
}
if kind == reflect.Struct {
val := null.NewString(enum, rand.Intn(1) == 0)
field.Set(reflect.ValueOf(val))
} else {
field.Set(reflect.ValueOf(enum))
}
return nil
}
var value interface{}
var isNull bool
@ -341,7 +360,7 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
// only get zero values for non byte slices
// to stop mysql from being a jerk
if isNull && kind != reflect.Slice {
value = getVariableZeroValue(s, kind)
value = getVariableZeroValue(s, kind, typ)
} else {
value = getVariableRandValue(s, kind, typ)
}
@ -457,6 +476,8 @@ func getStructNullValue(s *Seed, typ reflect.Type) interface{} {
return null.NewUint64(0, false)
case typeNullBytes:
return null.NewBytes(nil, false)
case typeNullByte:
return null.NewByte(byte(0), false)
}
return nil
@ -501,13 +522,21 @@ func getStructRandValue(s *Seed, typ reflect.Type) interface{} {
return null.NewUint64(uint64(s.nextInt()), true)
case typeNullBytes:
return null.NewBytes(randByteSlice(s, 1), true)
case typeNullByte:
return null.NewByte(byte(rand.Intn(125-65)+65), true)
}
return nil
}
// getVariableZeroValue for the matching type.
func getVariableZeroValue(s *Seed, kind reflect.Kind) interface{} {
func getVariableZeroValue(s *Seed, kind reflect.Kind, typ reflect.Type) interface{} {
switch typ.String() {
case "types.Byte":
// Decimal 65 is 'A'. 0 is not a valid UTF8, so cannot use a zero value here.
return types.Byte(65)
}
switch kind {
case reflect.Float32:
return float32(0)
@ -548,6 +577,11 @@ func getVariableZeroValue(s *Seed, kind reflect.Kind) interface{} {
// The randomness is really an incrementation of the global seed,
// this is done to avoid duplicate key violations.
func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interface{} {
switch typ.String() {
case "types.Byte":
return types.Byte(rand.Intn(125-65) + 65)
}
switch kind {
case reflect.Float32:
return float32(float32(s.nextInt()%10)/10.0 + float32(s.nextInt()%10))
@ -587,3 +621,12 @@ func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interfac
return nil
}
func randEnumValue(enum string) (string, error) {
vals := strmangle.ParseEnumVals(enum)
if vals == nil || len(vals) == 0 {
return "", fmt.Errorf("unable to parse enum string: %s", enum)
}
return vals[rand.Intn(len(vals)-1)], nil
}

View file

@ -5,7 +5,7 @@ import (
"testing"
"time"
"gopkg.in/nullbio/null.v5"
null "gopkg.in/nullbio/null.v6"
)
func TestRandomizeStruct(t *testing.T) {
@ -144,3 +144,28 @@ func TestRandomizeField(t *testing.T) {
}
}
}
func TestRandEnumValue(t *testing.T) {
t.Parallel()
enum1 := "enum.workday('monday','tuesday')"
enum2 := "enum('monday','tuesday')"
r1, err := randEnumValue(enum1)
if err != nil {
t.Error(err)
}
if r1 != "monday" && r1 != "tuesday" {
t.Errorf("Expected monday or tueday, got: %q", r1)
}
r2, err := randEnumValue(enum2)
if err != nil {
t.Error(err)
}
if r2 != "monday" && r2 != "tuesday" {
t.Errorf("Expected monday or tueday, got: %q", r2)
}
}

View file

@ -16,15 +16,31 @@ import (
var (
idAlphabet = []byte("abcdefghijklmnopqrstuvwxyz")
smartQuoteRgx = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(\."?[_a-z][_a-z0-9]*"?)*(\.\*)?$`)
rgxEnum = regexp.MustCompile(`^enum(\.[a-z_]+)?\((,?'[^']+')+\)$`)
rgxEnumIsOK = regexp.MustCompile(`^(?i)[a-z][a-z0-9_]*$`)
rgxEnumShouldTitle = regexp.MustCompile(`^[a-z][a-z0-9_]*$`)
)
var uppercaseWords = map[string]struct{}{
"guid": {},
"id": {},
"ip": {},
"uid": {},
"uuid": {},
"json": {},
"acl": {},
"api": {},
"ascii": {},
"cpu": {},
"eof": {},
"guid": {},
"id": {},
"ip": {},
"json": {},
"ram": {},
"sla": {},
"udp": {},
"ui": {},
"uid": {},
"uuid": {},
"uri": {},
"url": {},
"utf8": {},
}
func init() {
@ -364,7 +380,7 @@ func MakeStringMap(types map[string]string) string {
c := 0
for _, k := range keys {
v := types[k]
buf.WriteString(fmt.Sprintf(`"%s": "%s"`, k, v))
buf.WriteString(fmt.Sprintf("`%s`: `%s`", k, v))
if c < len(types)-1 {
buf.WriteString(", ")
}
@ -562,3 +578,55 @@ func GenerateIgnoreTags(tags []string) string {
return buf.String()
}
// ParseEnumVals returns the values from an enum string
//
// Postgres and MySQL drivers return different values
// psql: enum.enum_name('values'...)
// mysql: enum('values'...)
func ParseEnumVals(s string) []string {
if !rgxEnum.MatchString(s) {
return nil
}
startIndex := strings.IndexByte(s, '(')
s = s[startIndex+2 : len(s)-2]
return strings.Split(s, "','")
}
// ParseEnumName returns the name portion of an enum if it exists
//
// Postgres and MySQL drivers return different values
// psql: enum.enum_name('values'...)
// mysql: enum('values'...)
// In the case of mysql, the name will never return anything
func ParseEnumName(s string) string {
if !rgxEnum.MatchString(s) {
return ""
}
endIndex := strings.IndexByte(s, '(')
s = s[:endIndex]
startIndex := strings.IndexByte(s, '.')
if startIndex < 0 {
return ""
}
return s[startIndex+1:]
}
// IsEnumNormal checks a set of eval values to see if they're "normal"
func IsEnumNormal(values []string) bool {
for _, v := range values {
if !rgxEnumIsOK.MatchString(v) {
return false
}
}
return true
}
// ShouldTitleCaseEnum checks a value to see if it's title-case-able
func ShouldTitleCaseEnum(value string) bool {
return rgxEnumShouldTitle.MatchString(value)
}

View file

@ -291,8 +291,8 @@ func TestMakeStringMap(t *testing.T) {
r = MakeStringMap(m)
e1 := `"TestOne": "interval", "TestTwo": "integer"`
e2 := `"TestTwo": "integer", "TestOne": "interval"`
e1 := "`TestOne`: `interval`, `TestTwo`: `integer`"
e2 := "`TestTwo`: `integer`, `TestOne`: `interval`"
if r != e1 && r != e2 {
t.Errorf("Got %s", r)
@ -513,3 +513,70 @@ func TestGenerateIgnoreTags(t *testing.T) {
t.Errorf("expected %s, got %s", exp, tags)
}
}
func TestParseEnum(t *testing.T) {
t.Parallel()
tests := []struct {
Enum string
Name string
Vals []string
}{
{"enum('one')", "", []string{"one"}},
{"enum('one','two')", "", []string{"one", "two"}},
{"enum.working('one')", "working", []string{"one"}},
{"enum.wor_king('one','two')", "wor_king", []string{"one", "two"}},
}
for i, test := range tests {
name := ParseEnumName(test.Enum)
vals := ParseEnumVals(test.Enum)
if name != test.Name {
t.Errorf("%d) name was wrong, want: %s got: %s (%s)", i, test.Name, name, test.Enum)
}
for j, v := range test.Vals {
if v != vals[j] {
t.Errorf("%d.%d) value was wrong, want: %s got: %s (%s)", i, j, v, vals[j], test.Enum)
}
}
}
}
func TestIsEnumNormal(t *testing.T) {
t.Parallel()
tests := []struct {
Vals []string
Ok bool
}{
{[]string{"o1ne", "two2"}, true},
{[]string{"one", "t#wo2"}, false},
{[]string{"1one", "two2"}, false},
}
for i, test := range tests {
if got := IsEnumNormal(test.Vals); got != test.Ok {
t.Errorf("%d) want: %t got: %t, %#v", i, test.Ok, got, test.Vals)
}
}
}
func TestShouldTitleCaseEnum(t *testing.T) {
t.Parallel()
tests := []struct {
Val string
Ok bool
}{
{"hello_there0", true},
{"hEllo", false},
{"_hello", false},
{"0hello", false},
}
for i, test := range tests {
if got := ShouldTitleCaseEnum(test.Val); got != test.Ok {
t.Errorf("%d) want: %t got: %t, %v", i, test.Ok, got, test.Val)
}
}
}

View file

@ -121,6 +121,28 @@ func loadTemplate(dir string, filename string) (*template.Template, error) {
return tpl.Lookup(filename), err
}
// set is to stop duplication from named enums, allowing a template loop
// to keep some state
type once map[string]struct{}
func newOnce() once {
return make(once)
}
func (o once) Has(s string) bool {
_, ok := o[s]
return ok
}
func (o once) Put(s string) bool {
if _, ok := o[s]; ok {
return false
}
o[s] = struct{}{}
return true
}
// templateStringMappers are placed into the data to make it easy to use the
// stringMap function.
var templateStringMappers = map[string]func(string) string{
@ -157,6 +179,15 @@ var templateFunctions = template.FuncMap{
"generateTags": strmangle.GenerateTags,
"generateIgnoreTags": strmangle.GenerateIgnoreTags,
// Enum ops
"parseEnumName": strmangle.ParseEnumName,
"parseEnumVals": strmangle.ParseEnumVals,
"isEnumNormal": strmangle.IsEnumNormal,
"shouldTitleCaseEnum": strmangle.ShouldTitleCaseEnum,
"onceNew": newOnce,
"oncePut": once.Put,
"onceHas": once.Has,
// String Map ops
"makeStringMap": strmangle.MakeStringMap,
@ -173,6 +204,7 @@ var templateFunctions = template.FuncMap{
// dbdrivers ops
"filterColumnsByDefault": bdb.FilterColumnsByDefault,
"filterColumnsByEnum": bdb.FilterColumnsByEnum,
"sqlColDefinitions": bdb.SQLColDefinitions,
"columnNames": bdb.ColumnNames,
"columnDBTypes": bdb.ColumnDBTypes,

View file

@ -22,11 +22,15 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula
args := make([]interface{}, count)
if singular {
object.R = &{{$varNameSingular}}R{}
if object.R == nil {
object.R = &{{$varNameSingular}}R{}
}
args[0] = object.{{$txt.LocalTable.ColumnNameGo}}
} else {
for i, obj := range slice {
obj.R = &{{$varNameSingular}}R{}
if obj.R == nil {
obj.R = &{{$varNameSingular}}R{}
}
args[i] = obj.{{$txt.LocalTable.ColumnNameGo}}
}
}

View file

@ -22,11 +22,15 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula
args := make([]interface{}, count)
if singular {
object.R = &{{$varNameSingular}}R{}
if object.R == nil {
object.R = &{{$varNameSingular}}R{}
}
args[0] = object.{{$txt.LocalTable.ColumnNameGo}}
} else {
for i, obj := range slice {
obj.R = &{{$varNameSingular}}R{}
if obj.R == nil {
obj.R = &{{$varNameSingular}}R{}
}
args[i] = obj.{{$txt.LocalTable.ColumnNameGo}}
}
}

View file

@ -23,11 +23,15 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula
args := make([]interface{}, count)
if singular {
object.R = &{{$varNameSingular}}R{}
if object.R == nil {
object.R = &{{$varNameSingular}}R{}
}
args[0] = object.{{.Column | titleCase}}
} else {
for i, obj := range slice {
obj.R = &{{$varNameSingular}}R{}
if obj.R == nil {
obj.R = &{{$varNameSingular}}R{}
}
args[i] = obj.{{.Column | titleCase}}
}
}

View file

@ -35,3 +35,52 @@ func makeCacheKey(wl, nzDefaults []string) string {
strmangle.PutBuffer(buf)
return str
}
{{/*
The following is a little bit of black magic and deserves some explanation
Because postgres and mysql define enums completely differently (one at the
database level as a custom datatype, and one at the table column level as
a unique thing per table)... There's a chance the enum is named (postgres)
and not (mysql). So we can't do this per table so this code is here.
We loop through each table and column looking for enums. If it's named, we
then use some disgusting magic to write state during the template compile to
the "once" map. This lets named enums only be defined once if they're referenced
multiple times in many (or even the same) tables.
Then we check if all it's values are normal, if they are we create the enum
output, if not we output a friendly error message as a comment to aid in
debugging.
Postgres output looks like: EnumNameEnumValue = "enumvalue"
MySQL output looks like: TableNameColNameEnumValue = "enumvalue"
It only titlecases the EnumValue portion if it's snake-cased.
*/}}
{{$dot := . -}}
{{$once := onceNew}}
{{- range $table := .Tables -}}
{{- range $col := $table.Columns | filterColumnsByEnum -}}
{{- $name := parseEnumName $col.DBType -}}
{{- $vals := parseEnumVals $col.DBType -}}
{{- $isNamed := ne (len $name) 0}}
{{- if and $isNamed (onceHas $once $name) -}}
{{- else -}}
{{- if $isNamed -}}
{{$_ := oncePut $once $name}}
{{- end -}}
{{- if and (gt (len $vals) 0) (isEnumNormal $vals)}}
// Enum values for {{if $isNamed}}{{$name}}{{else}}{{$table.Name}}.{{$col.Name}}{{end}}
const (
{{- range $val := $vals -}}
{{- if $isNamed}}{{titleCase $name}}{{else}}{{titleCase $table.Name}}{{titleCase $col.Name}}{{end -}}
{{if shouldTitleCaseEnum $val}}{{titleCase $val}}{{else}}{{$val}}{{end}} = "{{$val}}"
{{end -}}
)
{{- else}}
// Enum values for {{if $isNamed}}{{$name}}{{else}}{{$table.Name}}.{{$col.Name}}{{end}} are not proper Go identifiers, cannot emit constants
{{- end -}}
{{- end -}}
{{- end -}}
{{- end -}}

View file

@ -1,3 +1,23 @@
CREATE TABLE event_one (
id serial PRIMARY KEY NOT NULL,
name VARCHAR(255),
day enum('monday','tuesday','wednesday')
);
CREATE TABLE event_two (
id serial PRIMARY KEY NOT NULL,
name VARCHAR(255),
face enum('happy','sad','bitter')
);
CREATE TABLE event_three (
id serial PRIMARY KEY NOT NULL,
name VARCHAR(255),
face enum('happy','sad','bitter'),
mood enum('happy','sad','bitter'),
day enum('monday','tuesday','wednesday')
);
CREATE TABLE magic (
id int PRIMARY KEY NOT NULL AUTO_INCREMENT,
id_two int NOT NULL,

View file

@ -1,3 +1,33 @@
CREATE TYPE workday AS ENUM('monday', 'tuesday', 'wednesday', 'thursday', 'friday');
CREATE TYPE faceyface AS ENUM('angry', 'hungry', 'bitter');
CREATE TABLE event_one (
id serial PRIMARY KEY NOT NULL,
name VARCHAR(255),
day workday NOT NULL
);
CREATE TABLE event_two (
id serial PRIMARY KEY NOT NULL,
name VARCHAR(255),
day workday NOT NULL
);
CREATE TABLE event_three (
id serial PRIMARY KEY NOT NULL,
name VARCHAR(255),
day workday NOT NULL,
face faceyface NOT NULL,
thing workday NULL,
stuff faceyface NULL
);
CREATE TABLE facey (
id serial PRIMARY KEY NOT NULL,
name VARCHAR(255),
face faceyface NOT NULL
);
CREATE TABLE magic (
id serial PRIMARY KEY NOT NULL,
id_two serial NOT NULL,
@ -24,6 +54,23 @@ CREATE TABLE magic (
string_ten VARCHAR(1000) NULL DEFAULT '',
string_eleven VARCHAR(1000) NOT NULL DEFAULT '',
nonbyte_zero CHAR(1),
nonbyte_one CHAR(1) NULL,
nonbyte_two CHAR(1) NOT NULL,
nonbyte_three CHAR(1) NULL DEFAULT 'a',
nonbyte_four CHAR(1) NOT NULL DEFAULT 'b',
nonbyte_five CHAR(1000),
nonbyte_six CHAR(1000) NULL,
nonbyte_seven CHAR(1000) NOT NULL,
nonbyte_eight CHAR(1000) NULL DEFAULT 'a',
nonbyte_nine CHAR(1000) NOT NULL DEFAULT 'b',
byte_zero "char",
byte_one "char" NULL,
byte_two "char" NULL DEFAULT 'a',
byte_three "char" NOT NULL,
byte_four "char" NOT NULL DEFAULT 'b',
big_int_zero bigint,
big_int_one bigint NULL,
big_int_two bigint NOT NULL,

61
types/byte.go Normal file
View file

@ -0,0 +1,61 @@
package types
import (
"database/sql/driver"
"encoding/json"
"errors"
)
// Byte is an alias for byte.
// Byte implements Marshal and Unmarshal.
type Byte byte
// String output your byte.
func (b Byte) String() string {
return string(b)
}
// UnmarshalJSON sets *b to a copy of data.
func (b *Byte) UnmarshalJSON(data []byte) error {
if b == nil {
return errors.New("json: unmarshal json on nil pointer to byte")
}
var x string
if err := json.Unmarshal(data, &x); err != nil {
return err
}
if len(x) > 1 {
return errors.New("json: cannot convert to byte, text len is greater than one")
}
*b = Byte(x[0])
return nil
}
// MarshalJSON returns the JSON encoding of b.
func (b Byte) MarshalJSON() ([]byte, error) {
return []byte{'"', byte(b), '"'}, nil
}
// Value returns b as a driver.Value.
func (b Byte) Value() (driver.Value, error) {
return []byte{byte(b)}, nil
}
// Scan stores the src in *b.
func (b *Byte) Scan(src interface{}) error {
switch src.(type) {
case uint8:
*b = Byte(src.(uint8))
case string:
*b = Byte(src.(string)[0])
case []byte:
*b = Byte(src.([]byte)[0])
default:
return errors.New("incompatible type for byte")
}
return nil
}

74
types/byte_test.go Normal file
View file

@ -0,0 +1,74 @@
package types
import (
"bytes"
"encoding/json"
"testing"
)
func TestByteString(t *testing.T) {
t.Parallel()
b := Byte('b')
if b.String() != "b" {
t.Errorf("Expected %q, got %s", "b", b.String())
}
}
func TestByteUnmarshal(t *testing.T) {
t.Parallel()
var b Byte
err := json.Unmarshal([]byte(`"b"`), &b)
if err != nil {
t.Error(err)
}
if b != 'b' {
t.Errorf("Expected %q, got %s", "b", b)
}
}
func TestByteMarshal(t *testing.T) {
t.Parallel()
b := Byte('b')
res, err := json.Marshal(&b)
if err != nil {
t.Error(err)
}
if !bytes.Equal(res, []byte(`"b"`)) {
t.Errorf("expected %s, got %s", `"b"`, b.String())
}
}
func TestByteValue(t *testing.T) {
t.Parallel()
b := Byte('b')
v, err := b.Value()
if err != nil {
t.Error(err)
}
if !bytes.Equal([]byte{byte(b)}, v.([]byte)) {
t.Errorf("byte mismatch, %v %v", b, v)
}
}
func TestByteScan(t *testing.T) {
t.Parallel()
var b Byte
s := "b"
err := b.Scan(s)
if err != nil {
t.Error(err)
}
if !bytes.Equal([]byte{byte(b)}, []byte{'b'}) {
t.Errorf("bad []byte: %#v ≠ %#v\n", b, "b")
}
}

View file

@ -35,7 +35,7 @@ func (j *JSON) Marshal(obj interface{}) error {
// UnmarshalJSON sets *j to a copy of data.
func (j *JSON) UnmarshalJSON(data []byte) error {
if j == nil {
return errors.New("JSON: UnmarshalJSON on nil pointer")
return errors.New("json: unmarshal json on nil pointer to json")
}
*j = append((*j)[0:0], data...)
@ -68,7 +68,7 @@ func (j *JSON) Scan(src interface{}) error {
case []byte:
source = src.([]byte)
default:
return errors.New("Incompatible type for JSON")
return errors.New("incompatible type for json")
}
*j = JSON(append((*j)[0:0], source...))