Merge branch 'dev'

This commit is contained in:
Patrick O'brien 2016-09-15 19:52:44 +10:00
commit 65dd15de09
99 changed files with 6516 additions and 3332 deletions

159
README.md
View file

@ -80,23 +80,28 @@ Table of Contents
### Features ### Features
- Full model generation - Full model generation
- High performance through generation
- Extremely fast code generation - Extremely fast code generation
- High performance through generation & intelligent caching
- Uses boil.Executor (simple interface, sql.DB, sqlx.DB etc. compatible) - Uses boil.Executor (simple interface, sql.DB, sqlx.DB etc. compatible)
- Easy workflow (models can always be regenerated, full auto-complete) - Easy workflow (models can always be regenerated, full auto-complete)
- Strongly typed querying (usually no converting or binding to pointers) - Strongly typed querying (usually no converting or binding to pointers)
- Hooks (Before/After Create/Select/Update/Delete/Upsert) - Hooks (Before/After Create/Select/Update/Delete/Upsert)
- Automatic CreatedAt/UpdatedAt - Automatic CreatedAt/UpdatedAt
- Table whitelist/blacklist
- Relationships/Associations - Relationships/Associations
- Eager loading (recursive) - Eager loading (recursive)
- Custom struct tags
- Schema support
- Transactions - Transactions
- Raw SQL fallback - Raw SQL fallback
- Compatibility tests (Run against your own DB schema) - Compatibility tests (Run against your own DB schema)
- Debug logging - Debug logging
- Postgres 1d arrays, json, hstore & more
### Supported Databases ### Supported Databases
- PostgreSQL - PostgreSQL
- MySQL
*Note: Seeking contributors for other database engines.* *Note: Seeking contributors for other database engines.*
@ -203,30 +208,32 @@ order:
- `$XDG_CONFIG_HOME/sqlboiler/` - `$XDG_CONFIG_HOME/sqlboiler/`
- `$HOME/.config/sqlboiler/` - `$HOME/.config/sqlboiler/`
We require you pass in the `postgres` configuration via the configuration file rather than env vars. We require you pass in your `postgres` and `mysql` database configuration via the configuration file rather than env vars.
There is no command line argument support for database configuration. Values given under the `postgres` There is no command line argument support for database configuration. Values given under the `postgres` and `mysql`
block are passed directly to the [pq](github.com/lib/pq) driver. Here is a rundown of all the different block are passed directly to the postgres and mysql drivers. Here is a rundown of all the different
values that can go in that section: values that can go in that section:
| Name | Required | Default | | Name | Required | Postgres Default | MySQL Default |
| --- | --- | --- | | --- | --- | --- | --- |
| dbname | yes | none | | dbname | yes | none | none |
| host | yes | none | | host | yes | none | none |
| port | no | 5432 | | port | no | 5432 | 3306 |
| user | yes | none | | user | yes | none | none |
| pass | no | none | | pass | no | none | none |
| sslmode | no | "require" | | sslmode | no | "require" | "true" |
You can also pass in these top level configuration values if you would prefer You can also pass in these top level configuration values if you would prefer
not to pass them through the command line or environment variables: not to pass them through the command line or environment variables:
| Name | Default | | Name | Defaults |
| --- | --- | | ------------------ | --------- |
| basedir | none | | basedir | none |
| schema | "public" *(or dbname for mysql)* |
| pkgname | "models" | | pkgname | "models" |
| output | "models" | | output | "models" |
| exclude | [ ] | | whitelist | [] |
| tag | [ ] | | blacklist | [] |
| tag | [] |
| debug | false | | debug | false |
| no-hooks | false | | no-hooks | false |
| no-tests | false | | no-tests | false |
@ -256,23 +263,26 @@ Usage:
sqlboiler [flags] <driver> sqlboiler [flags] <driver>
Examples: Examples:
sqlboiler postgres sqlboiler postgres
sqlboiler mysql
Flags: Flags:
-b, --basedir string The base directory has the templates and templates_test folders -b, --blacklist stringSlice Do not include these tables in your generated package
-w, --whitelist stringSlice Only include these tables in your generated package
-s, --schema string The name of your database schema, for databases that support real schemas (default "public")
-p, --pkgname string The name you wish to assign to your generated package (default "models")
-o, --output string The name of the folder to output to (default "models")
-t, --tag stringSlice Struct tags to be included on your models in addition to json, yaml, toml
-d, --debug Debug mode prints stack traces on error -d, --debug Debug mode prints stack traces on error
-x, --exclude stringSlice Tables to be excluded from the generated package --basedir string The base directory has the templates and templates_test folders
--no-auto-timestamps Disable automatic timestamps for created_at/updated_at --no-auto-timestamps Disable automatic timestamps for created_at/updated_at
--no-hooks Disable hooks feature for your models --no-hooks Disable hooks feature for your models
--no-tests Disable generated go test files --no-tests Disable generated go test files
-o, --output string The name of the folder to output to (default "models")
-p, --pkgname string The name you wish to assign to your generated package (default "models")
-t, --tag stringSlice Struct tags to be included on your models in addition to json, yaml, toml
``` ```
Follow the steps below to do some basic model generation. Once we've generated Follow the steps below to do some basic model generation. Once you've generated
our models, we can run the compatibility tests which will exercise the entirety your models, you can run the compatibility tests which will exercise the entirety
of the generated code. This way we can ensure that our database is compatible of the generated code. This way you can ensure that your database is compatible
with SQLBoiler. If you find there are some failing tests, please check the with SQLBoiler. If you find there are some failing tests, please check the
[Diagnosing Problems](#diagnosing-problems) section. [Diagnosing Problems](#diagnosing-problems) section.
@ -281,8 +291,7 @@ with SQLBoiler. If you find there are some failing tests, please check the
sqlboiler -x goose_migrations postgres sqlboiler -x goose_migrations postgres
# Run the generated tests # Run the generated tests
go test ./models # This requires an administrator postgres user because of some go test ./models
# voodoo we do to disable triggers for the generated test db
``` ```
## Diagnosing Problems ## Diagnosing Problems
@ -292,7 +301,7 @@ The most common causes of problems and panics are:
- Forgetting to exclude tables you do not want included in your generation, like migration tables. - Forgetting to exclude tables you do not want included in your generation, like migration tables.
- Tables without a primary key. All tables require one. - Tables without a primary key. All tables require one.
- Forgetting to put foreign key constraints on your columns that reference other tables. - Forgetting to put foreign key constraints on your columns that reference other tables.
- The compatibility tests that run against your own DB schema require a superuser, ensure the user - The compatibility tests require privileges to create a database for testing purposes, ensure the user
supplied in your `sqlboiler.toml` config has adequate privileges. supplied in your `sqlboiler.toml` config has adequate privileges.
- A nil or closed database handle. Ensure your passed in `boil.Executor` is not nil. - A nil or closed database handle. Ensure your passed in `boil.Executor` is not nil.
- If you decide to use the `G` variant of functions instead, make sure you've initialized your - If you decide to use the `G` variant of functions instead, make sure you've initialized your
@ -345,9 +354,8 @@ ALTER TABLE pilot_languages ADD CONSTRAINT pilots_fkey FOREIGN KEY (pilot_id) RE
ALTER TABLE pilot_languages ADD CONSTRAINT languages_fkey FOREIGN KEY (language_id) REFERENCES languages(id); ALTER TABLE pilot_languages ADD CONSTRAINT languages_fkey FOREIGN KEY (language_id) REFERENCES languages(id);
``` ```
The generated model structs for this schema look like the following. Note that I've included the relationship The generated model structs for this schema look like the following. Note that we've included the relationship
structs as well so you can see how it all pieces together, but these are unexported and not something you should structs as well so you can see how it all pieces together:
ever need to touch directly:
```go ```go
type Pilot struct { type Pilot struct {
@ -355,6 +363,7 @@ type Pilot struct {
Name string `boil:"name" json:"name" toml:"name" yaml:"name"` Name string `boil:"name" json:"name" toml:"name" yaml:"name"`
R *pilotR `boil:"-" json:"-" toml:"-" yaml:"-"` R *pilotR `boil:"-" json:"-" toml:"-" yaml:"-"`
L pilotR `boil:"-" json:"-" toml:"-" yaml:"-"`
} }
type pilotR struct { type pilotR struct {
@ -371,6 +380,7 @@ type Jet struct {
Color string `boil:"color" json:"color" toml:"color" yaml:"color"` Color string `boil:"color" json:"color" toml:"color" yaml:"color"`
R *jetR `boil:"-" json:"-" toml:"-" yaml:"-"` R *jetR `boil:"-" json:"-" toml:"-" yaml:"-"`
L jetR `boil:"-" json:"-" toml:"-" yaml:"-"`
} }
type jetR struct { type jetR struct {
@ -382,6 +392,7 @@ type Language struct {
Language string `boil:"language" json:"language" toml:"language" yaml:"language"` Language string `boil:"language" json:"language" toml:"language" yaml:"language"`
R *languageR `boil:"-" json:"-" toml:"-" yaml:"-"` R *languageR `boil:"-" json:"-" toml:"-" yaml:"-"`
L languageR `boil:"-" json:"-" toml:"-" yaml:"-"`
} }
type languageR struct { type languageR struct {
@ -414,7 +425,7 @@ Note: You can set the timezone for this feature by calling `boil.SetLocation()`
This is somewhat of a work around until we can devise a better solution in a later version. This is somewhat of a work around until we can devise a better solution in a later version.
* **Update** * **Update**
* The `updated_at` column will always be set to `time.Now()`. If you need to override * The `updated_at` column will always be set to `time.Now()`. If you need to override
this value you will need to fall back to another method in the meantime: `boil.SQL()`, this value you will need to fall back to another method in the meantime: `queries.Raw()`,
overriding `updated_at` in all of your objects using a hook, or create your own wrapper. overriding `updated_at` in all of your objects using a hook, or create your own wrapper.
* **Upsert** * **Upsert**
* `created_at` will be set automatically if it is a zero value, otherwise your supplied value * `created_at` will be set automatically if it is a zero value, otherwise your supplied value
@ -452,36 +463,8 @@ err := models.NewQuery(db, From("pilots")).All()
As you can see, [Query Mods](#query-mods) allow you to modify your queries, and [Finishers](#finishers) As you can see, [Query Mods](#query-mods) allow you to modify your queries, and [Finishers](#finishers)
allow you to execute the final action. allow you to execute the final action.
If you plan on executing the same query with the same values using the query builder, We also generate query building helper methods for your relationships as well. Take a look at our
you should do so like the following to utilize caching: [Relationships Query Building](#relationships) section for some additional query building information.
```go
// Instead of this:
for i := 0; i < 10; i++ {
pilots := models.Pilots(qm.Where("id > ?", 5), qm.Limit(5)).All()
}
// You should do this
query := models.Pilots(qm.Where("id > ?", 5), qm.Limit(5))
for i := 0; i < 10; i++ {
pilots := query.All()
}
// Every execution of All() after the first will use a cached version of
// the built query that short circuits the query builder all together.
// This allows you to save on performance.
// Just something to be aware of: query mods don't store pointers, so if
// your passed in variable's value changes, your generated query will not change.
```
Note: You will see exported `boil.SetX` methods in the boil package. These should not be used on query
objects because they will break caching. Unfortunately these had to be exported due to some circular
dependency issues, but they're not functionality we want exposed. If you want a different
query object, generate a new one.
Take a look at our [Relationships Query Building](#relationships) section for some additional query
building information.
### Query Mod System ### Query Mod System
@ -575,26 +558,31 @@ UpdateAll(models.M{"name": "John", "age": 23}) // Update all rows matching the b
DeleteAll() // Delete all rows matching the built query. DeleteAll() // Delete all rows matching the built query.
Exists() // Returns a bool indicating whether the row(s) for the built query exists. Exists() // Returns a bool indicating whether the row(s) for the built query exists.
Bind(&myObj) // Bind the results of a query to your own struct object. Bind(&myObj) // Bind the results of a query to your own struct object.
Exec() // Execute an SQL query that does not require any rows returned.
QueryRow() // Execute an SQL query expected to return only a single row.
Query() // Execute an SQL query expected to return multiple rows.
``` ```
### Raw Query ### Raw Query
We provide `boil.SQL()` for executing raw queries. Generally you will want to use `Bind()` with We provide `queries.Raw()` for executing raw queries. Generally you will want to use `Bind()` with
this, like the following: this, like the following:
```go ```go
err := boil.SQL(db, "select * from pilots where id=$1", 5).Bind(&obj) err := queries.Raw(db, "select * from pilots where id=$1", 5).Bind(&obj)
``` ```
You can use your own structs or a generated struct as a parameter to Bind. Bind supports both You can use your own structs or a generated struct as a parameter to Bind. Bind supports both
a single object for single row queries and a slice of objects for multiple row queries. a single object for single row queries and a slice of objects for multiple row queries.
You also have `models.NewQuery()` at your disposal if you would still like to use [Query Build](#query-building) `queries.Raw()` also has a method that can execute a query without binding to an object, if required.
but would like to build against a non-generated model.
You also have `models.NewQuery()` at your disposal if you would still like to use [Query Building](#query-building)
in combination with your own custom, non-generated model.
### Binding ### Binding
For a comprehensive ruleset for `Bind()` you can refer to our [godoc](https://godoc.org/github.com/vattle/sqlboiler/boil#Bind). For a comprehensive ruleset for `Bind()` you can refer to our [godoc](https://godoc.org/github.com/vattle/sqlboiler/queries#Bind).
The `Bind()` [Finisher](#finisher) allows the results of a query built with The `Bind()` [Finisher](#finisher) allows the results of a query built with
the [Raw SQL](#raw-query) method or the [Query Builder](#query-building) methods to be bound the [Raw SQL](#raw-query) method or the [Query Builder](#query-building) methods to be bound
@ -613,7 +601,7 @@ type PilotAndJet struct {
var paj PilotAndJet var paj PilotAndJet
// Use a raw query // Use a raw query
err := boil.SQL(` err := queries.Raw(`
select pilots.id as "pilots.id", pilots.name as "pilots.name", select pilots.id as "pilots.id", pilots.name as "pilots.name",
jets.id as "jets.id", jets.pilot_id as "jets.pilot_id", jets.id as "jets.id", jets.pilot_id as "jets.pilot_id",
jets.age as "jets.age", jets.name as "jets.name", jets.color as "jets.color" jets.age as "jets.age", jets.name as "jets.name", jets.color as "jets.color"
@ -641,7 +629,7 @@ var info JetInfo
err := models.NewQuery(db, Select("sum(age) as age_sum", "count(*) as juicy_count", From("jets"))).Bind(&info) err := models.NewQuery(db, Select("sum(age) as age_sum", "count(*) as juicy_count", From("jets"))).Bind(&info)
// Use a raw query // Use a raw query
err := boil.SQL(`select sum(age) as "age_sum", count(*) as "juicy_count" from jets`).Bind(&info) err := queries.Raw(`select sum(age) as "age_sum", count(*) as "juicy_count" from jets`).Bind(&info)
``` ```
We support the following struct tag modes for `Bind()` control: We support the following struct tag modes for `Bind()` control:
@ -905,8 +893,8 @@ err := p1.Insert(db) // Insert the first pilot with name "Larry"
// p1 now has an ID field set to 1 // p1 now has an ID field set to 1
var p2 models.Pilot var p2 models.Pilot
p2.Name "Borris" p2.Name "Boris"
err := p2.Insert(db) // Insert the second pilot with name "Borris" err := p2.Insert(db) // Insert the second pilot with name "Boris"
// p2 now has an ID field set to 2 // p2 now has an ID field set to 2
var p3 models.Pilot var p3 models.Pilot
@ -999,8 +987,13 @@ p1.Name = "Hogan"
err := p1.Upsert(db, true, []string{"id"}, []string{"name"}, "id", "name") err := p1.Upsert(db, true, []string{"id"}, []string{"name"}, "id", "name")
``` ```
The `updateOnConflict` argument allows you to specify whether you would like Postgres
to perform a `DO NOTHING` on conflict, opposed to a `DO UPDATE`. For MySQL, this param will not be generated.
The `conflictColumns` argument allows you to specify the `ON CONFLICT` columns for Postgres.
For MySQL, this param will not be generated.
Note: Passing a different set of column values to the update component is not currently supported. Note: Passing a different set of column values to the update component is not currently supported.
If this feature is important to you let us know and we can consider adding something for this.
### Reload ### Reload
In the event that your objects get out of sync with the database for whatever reason, In the event that your objects get out of sync with the database for whatever reason,
@ -1010,7 +1003,7 @@ attached to the objects.
```go ```go
pilot, _ := models.FindPilot(db, 1) pilot, _ := models.FindPilot(db, 1)
// > Object becomes out of sync for some reason // > Object becomes out of sync for some reason, perhaps async processing
// Refresh the object with the latest data from the db // Refresh the object with the latest data from the db
err := pilot.Reload(db) err := pilot.Reload(db)
@ -1051,10 +1044,26 @@ The generated models might import a couple of packages that are not on your syst
`cd` into your generated models directory and type `go get -u -t` to fetch them. You will only need `cd` into your generated models directory and type `go get -u -t` to fetch them. You will only need
to run this command once, not per generation. to run this command once, not per generation.
#### How should I handle multiple schemas?
If your database uses multiple schemas you should generate a new package for each of your schemas.
Note that this only applies to databases that use real, SQL standard schemas (like PostgreSQL), not
fake schemas (like MySQL).
#### How do I use types.BytesArray for Postgres bytea arrays?
Only "escaped format" is supported for types.BytesArray. This means that your byte slice needs to have
a format of "\\x00" (4 bytes per byte) opposed to "\x00" (1 byte per byte). This is to maintain compatibility
with all Postgres drivers. Example:
`x := types.BytesArray{0: []byte("\\x68\\x69")}`
Please note that multi-dimensional Postgres ARRAY types are not supported at this time.
#### Where is the homepage? #### Where is the homepage?
The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) [Golang ORM](https://github.com/vattle/sqlboiler)
[Golang ORM](https://github.com/vattle/sqlboiler) generator is located at: https://github.com/vattle/sqlboiler generator is located at: https://github.com/vattle/sqlboiler
## Benchmarks ## Benchmarks

View file

@ -5,6 +5,11 @@ import "github.com/vattle/sqlboiler/strmangle"
// Column holds information about a database column. // Column holds information about a database column.
// Types are Go types, converted by TranslateColumnType. // Types are Go types, converted by TranslateColumnType.
type Column struct { type Column struct {
// ArrType is the underlying data type of the Postgres
// ARRAY type. See here:
// https://www.postgresql.org/docs/9.1/static/infoschema-element-types.html
ArrType *string
UDTName string
Name string Name string
Type string Type string
DBType string DBType string

View file

@ -9,13 +9,16 @@ import (
type MockDriver struct{} type MockDriver struct{}
// TableNames returns a list of mock table names // TableNames returns a list of mock table names
func (m *MockDriver) TableNames(exclude []string) ([]string, error) { func (m *MockDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
if len(whitelist) > 0 {
return whitelist, nil
}
tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"} tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"}
return strmangle.SetComplement(tables, exclude), nil return strmangle.SetComplement(tables, blacklist), nil
} }
// Columns returns a list of mock columns // Columns returns a list of mock columns
func (m *MockDriver) Columns(tableName string) ([]bdb.Column, error) { func (m *MockDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
return map[string][]bdb.Column{ return map[string][]bdb.Column{
"pilots": { "pilots": {
{Name: "id", Type: "int", DBType: "integer"}, {Name: "id", Type: "int", DBType: "integer"},
@ -56,7 +59,7 @@ func (m *MockDriver) Columns(tableName string) ([]bdb.Column, error) {
} }
// ForeignKeyInfo returns a list of mock foreignkeys // ForeignKeyInfo returns a list of mock foreignkeys
func (m *MockDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) { func (m *MockDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) {
return map[string][]bdb.ForeignKey{ return map[string][]bdb.ForeignKey{
"jets": { "jets": {
{Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true}, {Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true},
@ -79,7 +82,7 @@ func (m *MockDriver) TranslateColumnType(c bdb.Column) bdb.Column {
} }
// PrimaryKeyInfo returns mock primary key info for the passed in table name // PrimaryKeyInfo returns mock primary key info for the passed in table name
func (m *MockDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) { func (m *MockDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) {
return map[string]*bdb.PrimaryKey{ return map[string]*bdb.PrimaryKey{
"pilots": { "pilots": {
Name: "pilot_id_pkey", Name: "pilot_id_pkey",
@ -120,3 +123,18 @@ func (m *MockDriver) Open() error { return nil }
// Close mimics a database close call // Close mimics a database close call
func (m *MockDriver) Close() {} func (m *MockDriver) Close() {}
// RightQuote is the quoting character for the right side of the identifier
func (m *MockDriver) RightQuote() byte {
return '"'
}
// LeftQuote is the quoting character for the left side of the identifier
func (m *MockDriver) LeftQuote() byte {
return '"'
}
// IndexPlaceholders returns true to indicate fake support of indexed placeholders
func (m *MockDriver) IndexPlaceholders() bool {
return false
}

321
bdb/drivers/mysql.go Normal file
View file

@ -0,0 +1,321 @@
package drivers
import (
"database/sql"
"fmt"
"strconv"
"strings"
"github.com/go-sql-driver/mysql"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/bdb"
)
// MySQLDriver holds the database connection string and a handle
// to the database connection.
type MySQLDriver struct {
connStr string
dbConn *sql.DB
}
// NewMySQLDriver takes the database connection details as parameters and
// returns a pointer to a MySQLDriver object. Note that it is required to
// call MySQLDriver.Open() and MySQLDriver.Close() to open and close
// the database connection once an object has been obtained.
func NewMySQLDriver(user, pass, dbname, host string, port int, sslmode string) *MySQLDriver {
driver := MySQLDriver{
connStr: MySQLBuildQueryString(user, pass, dbname, host, port, sslmode),
}
return &driver
}
// MySQLBuildQueryString builds a query string for MySQL.
func MySQLBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
var config mysql.Config
config.User = user
if len(pass) != 0 {
config.Passwd = pass
}
config.DBName = dbname
config.Net = "tcp"
config.Addr = host
if port == 0 {
port = 3306
}
config.Addr += ":" + strconv.Itoa(port)
config.TLSConfig = sslmode
return config.FormatDSN()
}
// Open opens the database connection using the connection string
func (m *MySQLDriver) Open() error {
var err error
m.dbConn, err = sql.Open("mysql", m.connStr)
if err != nil {
return err
}
return nil
}
// Close closes the database connection
func (m *MySQLDriver) Close() {
m.dbConn.Close()
}
// UseLastInsertID returns false for postgres
func (m *MySQLDriver) UseLastInsertID() bool {
return true
}
// TableNames connects to the postgres database and
// retrieves all table names from the information_schema where the
// table schema is public.
func (m *MySQLDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
var names []string
query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ?`)
args := []interface{}{schema}
if len(whitelist) > 0 {
query += fmt.Sprintf(" and table_name in (%s);", strings.Repeat(",?", len(whitelist))[1:])
for _, w := range whitelist {
args = append(args, w)
}
} else if len(blacklist) > 0 {
query += fmt.Sprintf(" and table_name not in (%s);", strings.Repeat(",?", len(blacklist))[1:])
for _, b := range blacklist {
args = append(args, b)
}
}
rows, err := m.dbConn.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, err
}
names = append(names, name)
}
return names, nil
}
// Columns takes a table name and attempts to retrieve the table information
// from the database information_schema.columns. It retrieves the column names
// and column types and returns those as a []Column after TranslateColumnType()
// converts the SQL types to Go types, for example: "varchar" to "string"
func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
var columns []bdb.Column
rows, err := m.dbConn.Query(`
select column_name, data_type, if(extra = 'auto_increment','auto_increment', column_default), is_nullable,
exists (
select c.column_name
from information_schema.table_constraints tc
inner join information_schema.key_column_usage kcu
on tc.constraint_name = kcu.constraint_name and tc.table_name = kcu.table_name and tc.table_schema = kcu.table_schema
where c.column_name = kcu.column_name and tc.table_name = c.table_name and
(tc.constraint_type = 'PRIMARY KEY' or tc.constraint_type = 'UNIQUE')
) as is_unique
from information_schema.columns as c
where table_name = ? and table_schema = ?;
`, tableName, schema)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var colName, colType, colDefault, nullable string
var unique bool
var defaultPtr *string
if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
}
if defaultPtr != nil && *defaultPtr != "NULL" {
colDefault = *defaultPtr
}
column := bdb.Column{
Name: colName,
DBType: colType,
Default: colDefault,
Nullable: nullable == "YES",
Unique: unique,
}
columns = append(columns, column)
}
return columns, nil
}
// PrimaryKeyInfo looks up the primary key for a table.
func (m *MySQLDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) {
pkey := &bdb.PrimaryKey{}
var err error
query := `
select tc.constraint_name
from information_schema.table_constraints as tc
where tc.table_name = ? and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = ?;`
row := m.dbConn.QueryRow(query, tableName, schema)
if err = row.Scan(&pkey.Name); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
queryColumns := `
select kcu.column_name
from information_schema.key_column_usage as kcu
where table_name = ? and constraint_name = ? and table_schema = ?;`
var rows *sql.Rows
if rows, err = m.dbConn.Query(queryColumns, tableName, pkey.Name, schema); err != nil {
return nil, err
}
defer rows.Close()
var columns []string
for rows.Next() {
var column string
err = rows.Scan(&column)
if err != nil {
return nil, err
}
columns = append(columns, column)
}
if err = rows.Err(); err != nil {
return nil, err
}
pkey.Columns = columns
return pkey, nil
}
// ForeignKeyInfo retrieves the foreign keys for a given table name.
func (m *MySQLDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) {
var fkeys []bdb.ForeignKey
query := `
select constraint_name, table_name, column_name, referenced_table_name, referenced_column_name
from information_schema.key_column_usage
where table_schema = ? and referenced_table_schema = ? and table_name = ?
`
var rows *sql.Rows
var err error
if rows, err = m.dbConn.Query(query, schema, schema, tableName); err != nil {
return nil, err
}
for rows.Next() {
var fkey bdb.ForeignKey
var sourceTable string
fkey.Table = tableName
err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn)
if err != nil {
return nil, err
}
fkeys = append(fkeys, fkey)
}
if err = rows.Err(); err != nil {
return nil, err
}
return fkeys, nil
}
// TranslateColumnType converts postgres database types to Go types, for example
// "varchar" to "string" and "bigint" to "int64". It returns this parsed data
// as a Column object.
func (m *MySQLDriver) TranslateColumnType(c bdb.Column) bdb.Column {
if c.Nullable {
switch c.DBType {
case "tinyint":
c.Type = "null.Int8"
case "smallint":
c.Type = "null.Int16"
case "mediumint", "int", "integer":
c.Type = "null.Int"
case "bigint":
c.Type = "null.Int64"
case "float":
c.Type = "null.Float32"
case "double", "double precision", "real":
c.Type = "null.Float64"
case "boolean", "bool":
c.Type = "null.Bool"
case "date", "datetime", "timestamp", "time":
c.Type = "null.Time"
case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob":
c.Type = "null.Bytes"
case "json":
c.Type = "types.JSON"
default:
c.Type = "null.String"
}
} else {
switch c.DBType {
case "tinyint":
c.Type = "int8"
case "smallint":
c.Type = "int16"
case "mediumint", "int", "integer":
c.Type = "int"
case "bigint":
c.Type = "null.Int64"
case "float":
c.Type = "float32"
case "double", "double precision", "real":
c.Type = "float64"
case "boolean", "bool":
c.Type = "bool"
case "date", "datetime", "timestamp", "time":
c.Type = "time.Time"
case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob":
c.Type = "[]byte"
case "json":
c.Type = "types.JSON"
default:
c.Type = "string"
}
}
return c
}
// RightQuote is the quoting character for the right side of the identifier
func (m *MySQLDriver) RightQuote() byte {
return '`'
}
// LeftQuote is the quoting character for the left side of the identifier
func (m *MySQLDriver) LeftQuote() byte {
return '`'
}
// IndexPlaceholders returns false to indicate MySQL doesnt support indexed placeholders
func (m *MySQLDriver) IndexPlaceholders() bool {
return false
}

View file

@ -19,23 +19,20 @@ type PostgresDriver struct {
dbConn *sql.DB dbConn *sql.DB
} }
// validatedTypes are types that cannot be zero values in the database.
var validatedTypes = []string{"uuid"}
// NewPostgresDriver takes the database connection details as parameters and // NewPostgresDriver takes the database connection details as parameters and
// returns a pointer to a PostgresDriver object. Note that it is required to // returns a pointer to a PostgresDriver object. Note that it is required to
// call PostgresDriver.Open() and PostgresDriver.Close() to open and close // call PostgresDriver.Open() and PostgresDriver.Close() to open and close
// the database connection once an object has been obtained. // the database connection once an object has been obtained.
func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver { func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
driver := PostgresDriver{ driver := PostgresDriver{
connStr: BuildQueryString(user, pass, dbname, host, port, sslmode), connStr: PostgresBuildQueryString(user, pass, dbname, host, port, sslmode),
} }
return &driver return &driver
} }
// BuildQueryString for Postgres // PostgresBuildQueryString builds a query string.
func BuildQueryString(user, pass, dbname, host string, port int, sslmode string) string { func PostgresBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
parts := []string{} parts := []string{}
if len(user) != 0 { if len(user) != 0 {
parts = append(parts, fmt.Sprintf("user=%s", user)) parts = append(parts, fmt.Sprintf("user=%s", user))
@ -82,21 +79,25 @@ func (p *PostgresDriver) UseLastInsertID() bool {
// TableNames connects to the postgres database and // TableNames connects to the postgres database and
// retrieves all table names from the information_schema where the // retrieves all table names from the information_schema where the
// table schema is public. It excludes common migration tool tables // table schema is schema. It uses a whitelist and blacklist.
// such as gorp_migrations func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) {
var names []string var names []string
query := `select table_name from information_schema.tables where table_schema = 'public'` query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = $1`)
if len(exclude) > 0 { args := []interface{}{schema}
quoteStr := func(x string) string { if len(whitelist) > 0 {
return `'` + x + `'` query += fmt.Sprintf(" and table_name in (%s);", strmangle.Placeholders(true, len(whitelist), 2, 1))
for _, w := range whitelist {
args = append(args, w)
}
} else if len(blacklist) > 0 {
query += fmt.Sprintf(" and table_name not in (%s);", strmangle.Placeholders(true, len(blacklist), 2, 1))
for _, b := range blacklist {
args = append(args, b)
} }
exclude = strmangle.StringMap(quoteStr, exclude)
query = query + fmt.Sprintf("and table_name not in (%s);", strings.Join(exclude, ","))
} }
rows, err := p.dbConn.Query(query) rows, err := p.dbConn.Query(query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -118,11 +119,11 @@ func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) {
// from the database information_schema.columns. It retrieves the column names // from the database information_schema.columns. It retrieves the column names
// and column types and returns those as a []Column after TranslateColumnType() // and column types and returns those as a []Column after TranslateColumnType()
// converts the SQL types to Go types, for example: "varchar" to "string" // converts the SQL types to Go types, for example: "varchar" to "string"
func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) { func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
var columns []bdb.Column var columns []bdb.Column
rows, err := p.dbConn.Query(` rows, err := p.dbConn.Query(`
select column_name, data_type, column_default, is_nullable, select column_name, c.data_type, e.data_type, column_default, c.udt_name, is_nullable,
(select exists( (select exists(
select 1 select 1
from information_schema.constraint_column_usage as ccu from information_schema.constraint_column_usage as ccu
@ -136,11 +137,13 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) {
inner join pg_index pgi on pgi.indexrelid = pgc.oid inner join pg_index pgi on pgi.indexrelid = pgc.oid
inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey) inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey)
where where
pgix.schemaname = 'public' and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true
)) as is_unique )) as is_unique
from information_schema.columns as c from information_schema.columns as c LEFT JOIN information_schema.element_types e
where table_name=$1 and table_schema = 'public'; ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier)
`, tableName) = (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier))
where c.table_name=$2 and c.table_schema = $1;
`, schema, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -148,10 +151,11 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) {
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var colName, colType, colDefault, nullable string var colName, udtName, colType, colDefault, nullable string
var elementType *string
var unique bool var unique bool
var defaultPtr *string var defaultPtr *string
if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil { if err := rows.Scan(&colName, &colType, &elementType, &defaultPtr, &udtName, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
} }
@ -162,12 +166,13 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) {
} }
column := bdb.Column{ column := bdb.Column{
Name: colName, Name: colName,
DBType: colType, DBType: colType,
Default: colDefault, ArrType: elementType,
Nullable: nullable == "YES", UDTName: udtName,
Unique: unique, Default: colDefault,
Validated: isValidated(colType), Nullable: nullable == "YES",
Unique: unique,
} }
columns = append(columns, column) columns = append(columns, column)
} }
@ -176,16 +181,16 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) {
} }
// PrimaryKeyInfo looks up the primary key for a table. // PrimaryKeyInfo looks up the primary key for a table.
func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) { func (p *PostgresDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) {
pkey := &bdb.PrimaryKey{} pkey := &bdb.PrimaryKey{}
var err error var err error
query := ` query := `
select tc.constraint_name select tc.constraint_name
from information_schema.table_constraints as tc from information_schema.table_constraints as tc
where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = 'public';` where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = $2;`
row := p.dbConn.QueryRow(query, tableName) row := p.dbConn.QueryRow(query, tableName, schema)
if err = row.Scan(&pkey.Name); err != nil { if err = row.Scan(&pkey.Name); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
@ -196,10 +201,10 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, erro
queryColumns := ` queryColumns := `
select kcu.column_name select kcu.column_name
from information_schema.key_column_usage as kcu from information_schema.key_column_usage as kcu
where constraint_name = $1 and table_schema = 'public';` where constraint_name = $1 and table_schema = $2;`
var rows *sql.Rows var rows *sql.Rows
if rows, err = p.dbConn.Query(queryColumns, pkey.Name); err != nil { if rows, err = p.dbConn.Query(queryColumns, pkey.Name, schema); err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
@ -226,7 +231,7 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, erro
} }
// ForeignKeyInfo retrieves the foreign keys for a given table name. // ForeignKeyInfo retrieves the foreign keys for a given table name.
func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) { func (p *PostgresDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) {
var fkeys []bdb.ForeignKey var fkeys []bdb.ForeignKey
query := ` query := `
@ -239,11 +244,11 @@ func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, err
from information_schema.table_constraints as tc from information_schema.table_constraints as tc
inner join information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name inner join information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name
inner join information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name inner join information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name
where tc.table_name = $1 and tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = 'public';` where tc.table_name = $1 and tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = $2;`
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if rows, err = p.dbConn.Query(query, tableName); err != nil { if rows, err = p.dbConn.Query(query, tableName, schema); err != nil {
return nil, err return nil, err
} }
@ -279,18 +284,35 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "null.Int" c.Type = "null.Int"
case "smallint", "smallserial": case "smallint", "smallserial":
c.Type = "null.Int16" c.Type = "null.Int16"
case "decimal", "numeric", "double precision", "money": case "decimal", "numeric", "double precision":
c.Type = "null.Float64" c.Type = "null.Float64"
case "real": case "real":
c.Type = "null.Float32" c.Type = "null.Float32"
case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "null.String" c.Type = "null.String"
case "bytea": case "bytea":
c.Type = "[]byte" c.Type = "null.Bytes"
case "json", "jsonb":
c.Type = "null.JSON"
case "boolean": case "boolean":
c.Type = "null.Bool" c.Type = "null.Bool"
case "date", "time", "timestamp without time zone", "timestamp with time zone": case "date", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "null.Time" c.Type = "null.Time"
case "ARRAY":
if c.ArrType == nil {
panic("unable to get postgres ARRAY underlying type")
}
c.Type = getArrayType(c)
// Make DBType something like ARRAYinteger for parsing with randomize.Struct
c.DBType = c.DBType + *c.ArrType
case "USER-DEFINED":
if c.UDTName == "hstore" {
c.Type = "types.HStore"
c.DBType = "hstore"
} else {
c.Type = "string"
fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName)
}
default: default:
c.Type = "null.String" c.Type = "null.String"
} }
@ -302,18 +324,32 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "int" c.Type = "int"
case "smallint", "smallserial": case "smallint", "smallserial":
c.Type = "int16" c.Type = "int16"
case "decimal", "numeric", "double precision", "money": case "decimal", "numeric", "double precision":
c.Type = "float64" c.Type = "float64"
case "real": case "real":
c.Type = "float32" c.Type = "float32"
case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "string" c.Type = "string"
case "json", "jsonb":
c.Type = "types.JSON"
case "bytea": case "bytea":
c.Type = "[]byte" c.Type = "[]byte"
case "boolean": case "boolean":
c.Type = "bool" c.Type = "bool"
case "date", "time", "timestamp without time zone", "timestamp with time zone": case "date", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "time.Time" c.Type = "time.Time"
case "ARRAY":
c.Type = getArrayType(c)
// Make DBType something like ARRAYinteger for parsing with randomize.Struct
c.DBType = c.DBType + *c.ArrType
case "USER-DEFINED":
if c.UDTName == "hstore" {
c.Type = "types.HStore"
c.DBType = "hstore"
} else {
c.Type = "string"
fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName)
}
default: default:
c.Type = "string" c.Type = "string"
} }
@ -322,13 +358,35 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
return c return c
} }
// isValidated checks if the database type is in the validatedTypes list. // getArrayType returns the correct boil.Array type for each database type
func isValidated(typ string) bool { func getArrayType(c bdb.Column) string {
for _, v := range validatedTypes { switch *c.ArrType {
if v == typ { case "bigint", "bigserial", "integer", "serial", "smallint", "smallserial":
return true return "types.Int64Array"
} case "bytea":
return "types.BytesArray"
case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
return "types.StringArray"
case "boolean":
return "types.BoolArray"
case "decimal", "numeric", "double precision", "real":
return "types.Float64Array"
default:
return "types.StringArray"
} }
}
return false
// RightQuote is the quoting character for the right side of the identifier
func (p *PostgresDriver) RightQuote() byte {
return '"'
}
// LeftQuote is the quoting character for the left side of the identifier
func (p *PostgresDriver) LeftQuote() byte {
return '"'
}
// IndexPlaceholders returns true to indicate PSQL supports indexed placeholders
func (p *PostgresDriver) IndexPlaceholders() bool {
return true
} }

View file

@ -6,10 +6,10 @@ import "github.com/pkg/errors"
// Interface for a database driver. Functionality required to support a specific // Interface for a database driver. Functionality required to support a specific
// database type (eg, MySQL, Postgres etc.) // database type (eg, MySQL, Postgres etc.)
type Interface interface { type Interface interface {
TableNames(exclude []string) ([]string, error) TableNames(schema string, whitelist, blacklist []string) ([]string, error)
Columns(tableName string) ([]Column, error) Columns(schema, tableName string) ([]Column, error)
PrimaryKeyInfo(tableName string) (*PrimaryKey, error) PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error)
ForeignKeyInfo(tableName string) ([]ForeignKey, error) ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error)
// TranslateColumnType takes a Database column type and returns a go column type. // TranslateColumnType takes a Database column type and returns a go column type.
TranslateColumnType(Column) Column TranslateColumnType(Column) Column
@ -22,23 +22,32 @@ type Interface interface {
Open() error Open() error
// Close the database connection // Close the database connection
Close() Close()
// Dialect helpers, these provide the values that will go into
// a queries.Dialect, so the query builder knows how to support
// your database driver properly.
LeftQuote() byte
RightQuote() byte
IndexPlaceholders() bool
} }
// Tables returns the metadata for all tables, minus the tables // Tables returns the metadata for all tables, minus the tables
// specified in the exclude slice. // specified in the blacklist.
func Tables(db Interface, exclude ...string) ([]Table, error) { func Tables(db Interface, schema string, whitelist, blacklist []string) ([]Table, error) {
var err error var err error
names, err := db.TableNames(exclude) names, err := db.TableNames(schema, whitelist, blacklist)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "unable to get table names") return nil, errors.Wrap(err, "unable to get table names")
} }
var tables []Table var tables []Table
for _, name := range names { for _, name := range names {
t := Table{Name: name} t := Table{
Name: name,
}
if t.Columns, err = db.Columns(name); err != nil { if t.Columns, err = db.Columns(schema, name); err != nil {
return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name) return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name)
} }
@ -46,11 +55,11 @@ func Tables(db Interface, exclude ...string) ([]Table, error) {
t.Columns[i] = db.TranslateColumnType(c) t.Columns[i] = db.TranslateColumnType(c)
} }
if t.PKey, err = db.PrimaryKeyInfo(name); err != nil { if t.PKey, err = db.PrimaryKeyInfo(schema, name); err != nil {
return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name) return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name)
} }
if t.FKeys, err = db.ForeignKeyInfo(name); err != nil { if t.FKeys, err = db.ForeignKeyInfo(schema, name); err != nil {
return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name) return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name)
} }

View file

@ -6,20 +6,23 @@ import (
"github.com/vattle/sqlboiler/strmangle" "github.com/vattle/sqlboiler/strmangle"
) )
type mockDriver struct{} type testMockDriver struct{}
func (m mockDriver) TranslateColumnType(c Column) Column { return c } func (m testMockDriver) TranslateColumnType(c Column) Column { return c }
func (m mockDriver) UseLastInsertID() bool { return false } func (m testMockDriver) UseLastInsertID() bool { return false }
func (m mockDriver) Open() error { return nil } func (m testMockDriver) Open() error { return nil }
func (m mockDriver) Close() {} func (m testMockDriver) Close() {}
func (m mockDriver) TableNames(exclude []string) ([]string, error) { func (m testMockDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
if len(whitelist) > 0 {
return whitelist, nil
}
tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"} tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"}
return strmangle.SetComplement(tables, exclude), nil return strmangle.SetComplement(tables, blacklist), nil
} }
// Columns returns a list of mock columns // Columns returns a list of mock columns
func (m mockDriver) Columns(tableName string) ([]Column, error) { func (m testMockDriver) Columns(schema, tableName string) ([]Column, error) {
return map[string][]Column{ return map[string][]Column{
"pilots": { "pilots": {
{Name: "id", Type: "int", DBType: "integer"}, {Name: "id", Type: "int", DBType: "integer"},
@ -61,7 +64,7 @@ func (m mockDriver) Columns(tableName string) ([]Column, error) {
} }
// ForeignKeyInfo returns a list of mock foreignkeys // ForeignKeyInfo returns a list of mock foreignkeys
func (m mockDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) { func (m testMockDriver) ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error) {
return map[string][]ForeignKey{ return map[string][]ForeignKey{
"jets": { "jets": {
{Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true}, {Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true},
@ -81,7 +84,7 @@ func (m mockDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) {
} }
// PrimaryKeyInfo returns mock primary key info for the passed in table name // PrimaryKeyInfo returns mock primary key info for the passed in table name
func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) { func (m testMockDriver) PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error) {
return map[string]*PrimaryKey{ return map[string]*PrimaryKey{
"pilots": {Name: "pilot_id_pkey", Columns: []string{"id"}}, "pilots": {Name: "pilot_id_pkey", Columns: []string{"id"}},
"airports": {Name: "airport_id_pkey", Columns: []string{"id"}}, "airports": {Name: "airport_id_pkey", Columns: []string{"id"}},
@ -93,10 +96,25 @@ func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
}[tableName], nil }[tableName], nil
} }
// RightQuote is the quoting character for the right side of the identifier
func (m testMockDriver) RightQuote() byte {
return '"'
}
// LeftQuote is the quoting character for the left side of the identifier
func (m testMockDriver) LeftQuote() byte {
return '"'
}
// IndexPlaceholders returns true to indicate fake support of indexed placeholders
func (m testMockDriver) IndexPlaceholders() bool {
return false
}
func TestTables(t *testing.T) { func TestTables(t *testing.T) {
t.Parallel() t.Parallel()
tables, err := Tables(mockDriver{}) tables, err := Tables(testMockDriver{}, "public", nil, nil)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }

View file

@ -3,6 +3,7 @@ package bdb
import ( import (
"fmt" "fmt"
"regexp" "regexp"
"strings"
) )
var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`) var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`)
@ -79,3 +80,30 @@ func SQLColDefinitions(cols []Column, names []string) SQLColumnDefs {
return ret return ret
} }
// AutoIncPrimaryKey returns the auto-increment primary key column name or an
// empty string. Primary key columns with default values are presumed
// to be auto-increment, because pkeys need to be unique and a static
// default value would cause collisions.
func AutoIncPrimaryKey(cols []Column, pkey *PrimaryKey) *Column {
if pkey == nil {
return nil
}
for _, pkeyColumn := range pkey.Columns {
for _, c := range cols {
if c.Name != pkeyColumn {
continue
}
if c.Default != "auto_increment" || c.Nullable ||
!(strings.HasPrefix(c.Type, "int") || strings.HasPrefix(c.Type, "uint")) {
continue
}
return &c
}
}
return nil
}

View file

@ -4,8 +4,11 @@ import "fmt"
// Table metadata from the database schema. // Table metadata from the database schema.
type Table struct { type Table struct {
Name string Name string
Columns []Column // For dbs with real schemas, like Postgres.
// Example value: "schema_name"."table_name"
SchemaName string
Columns []Column
PKey *PrimaryKey PKey *PrimaryKey
FKeys []ForeignKey FKeys []ForeignKey

View file

@ -3,10 +3,12 @@ package main
// Config for the running of the commands // Config for the running of the commands
type Config struct { type Config struct {
DriverName string DriverName string
Schema string
PkgName string PkgName string
OutFolder string OutFolder string
BaseDir string BaseDir string
ExcludeTables []string WhitelistTables []string
BlacklistTables []string
Tags []string Tags []string
Debug bool Debug bool
NoTests bool NoTests bool
@ -14,6 +16,7 @@ type Config struct {
NoAutoTimestamps bool NoAutoTimestamps bool
Postgres PostgresConfig Postgres PostgresConfig
MySQL MySQLConfig
} }
// PostgresConfig configures a postgres database // PostgresConfig configures a postgres database
@ -25,3 +28,13 @@ type PostgresConfig struct {
DBName string DBName string
SSLMode string SSLMode string
} }
// MySQLConfig configures a mysql database
type MySQLConfig struct {
User string
Pass string
Host string
Port int
DBName string
SSLMode string
}

View file

@ -153,7 +153,8 @@ var defaultTemplateImports = imports{
thirdParty: importList{ thirdParty: importList{
`"github.com/pkg/errors"`, `"github.com/pkg/errors"`,
`"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/boil/qm"`, `"github.com/vattle/sqlboiler/queries"`,
`"github.com/vattle/sqlboiler/queries/qm"`,
`"github.com/vattle/sqlboiler/strmangle"`, `"github.com/vattle/sqlboiler/strmangle"`,
}, },
} }
@ -162,7 +163,8 @@ var defaultSingletonTemplateImports = map[string]imports{
"boil_queries": { "boil_queries": {
thirdParty: importList{ thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/boil/qm"`, `"github.com/vattle/sqlboiler/queries"`,
`"github.com/vattle/sqlboiler/queries/qm"`,
}, },
}, },
"boil_types": { "boil_types": {
@ -180,29 +182,38 @@ var defaultTestTemplateImports = imports{
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/boil/randomize"`, `"github.com/vattle/sqlboiler/randomize"`,
`"github.com/vattle/sqlboiler/strmangle"`, `"github.com/vattle/sqlboiler/strmangle"`,
}, },
} }
var defaultSingletonTestTemplateImports = map[string]imports{ var defaultSingletonTestTemplateImports = map[string]imports{
"boil_viper_test": { "boil_main_test": {
standard: importList{ standard: importList{
`"database/sql"`, `"database/sql"`,
`"flag"`,
`"fmt"`,
`"math/rand"`,
`"os"`, `"os"`,
`"path/filepath"`, `"path/filepath"`,
`"testing"`,
`"time"`,
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`, `"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
}, },
}, },
"boil_queries_test": { "boil_queries_test": {
standard: importList{ standard: importList{
`"crypto/md5"`, `"bytes"`,
`"fmt"`, `"fmt"`,
`"os"`, `"io"`,
`"strconv"`, `"io/ioutil"`,
`"math/rand"`, `"math/rand"`,
`"regexp"`,
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/boil"`,
@ -218,27 +229,42 @@ var defaultSingletonTestTemplateImports = map[string]imports{
var defaultTestMainImports = map[string]imports{ var defaultTestMainImports = map[string]imports{
"postgres": { "postgres": {
standard: importList{ standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"flag"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`, `"bytes"`,
`"database/sql"`, `"database/sql"`,
`"path/filepath"`, `"fmt"`,
`"time"`, `"io"`,
`"math/rand"`, `"io/ioutil"`,
`"os"`,
`"os/exec"`,
`"strings"`,
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`, `"github.com/pkg/errors"`,
`"github.com/spf13/viper"`, `"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`, `"github.com/vattle/sqlboiler/bdb/drivers"`,
`"github.com/vattle/sqlboiler/randomize"`,
`_ "github.com/lib/pq"`, `_ "github.com/lib/pq"`,
}, },
}, },
"mysql": {
standard: importList{
`"bytes"`,
`"database/sql"`,
`"fmt"`,
`"io"`,
`"io/ioutil"`,
`"os"`,
`"os/exec"`,
`"strings"`,
},
thirdParty: importList{
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
`"github.com/vattle/sqlboiler/randomize"`,
`_ "github.com/go-sql-driver/mysql"`,
},
},
} }
// importsBasedOnType imports are only included in the template output if the // importsBasedOnType imports are only included in the template output if the
@ -246,51 +272,75 @@ var defaultTestMainImports = map[string]imports{
// TranslateColumnType to see the type assignments. // TranslateColumnType to see the type assignments.
var importsBasedOnType = map[string]imports{ var importsBasedOnType = map[string]imports{
"null.Float32": { "null.Float32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Float64": { "null.Float64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Int": { "null.Int": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Int8": { "null.Int8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Int16": { "null.Int16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Int32": { "null.Int32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Int64": { "null.Int64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Uint": { "null.Uint": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Uint8": { "null.Uint8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Uint16": { "null.Uint16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Uint32": { "null.Uint32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Uint64": { "null.Uint64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.String": { "null.String": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Bool": { "null.Bool": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"null.Time": { "null.Time": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.JSON": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Bytes": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
}, },
"time.Time": { "time.Time": {
standard: importList{`"time"`}, standard: importList{`"time"`},
}, },
"types.JSON": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.BytesArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Int64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Float64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.BoolArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Hstore": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
} }

View file

@ -75,7 +75,7 @@ func TestCombineTypeImports(t *testing.T) {
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v4"`, `"gopkg.in/nullbio/null.v5"`,
}, },
} }
@ -108,7 +108,7 @@ func TestCombineTypeImports(t *testing.T) {
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v4"`, `"gopkg.in/nullbio/null.v5"`,
}, },
} }
@ -124,7 +124,7 @@ func TestCombineImports(t *testing.T) {
a := imports{ a := imports{
standard: importList{"fmt"}, standard: importList{"fmt"},
thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v4"}, thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v5"},
} }
b := imports{ b := imports{
standard: importList{"os"}, standard: importList{"os"},
@ -136,8 +136,8 @@ func TestCombineImports(t *testing.T) {
if c.standard[0] != "fmt" && c.standard[1] != "os" { if c.standard[0] != "fmt" && c.standard[1] != "os" {
t.Errorf("Wanted: fmt, os got: %#v", c.standard) t.Errorf("Wanted: fmt, os got: %#v", c.standard)
} }
if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v4" { if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v5" {
t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdParty) t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v5 got: %#v", c.thirdParty)
} }
} }

85
main.go
View file

@ -2,7 +2,6 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -61,9 +60,11 @@ func main() {
// Set up the cobra root command flags // Set up the cobra root command flags
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to") rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
rootCmd.PersistentFlags().StringP("schema", "s", "public", "The name of your database schema, for databases that support real schemas")
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory has the templates and templates_test folders") rootCmd.PersistentFlags().StringP("basedir", "", "", "The base directory has the templates and templates_test folders")
rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package") rootCmd.PersistentFlags().StringSliceP("blacklist", "b", nil, "Do not include these tables in your generated package")
rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package")
rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml") rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml")
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error") rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files") rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files")
@ -72,6 +73,9 @@ func main() {
viper.SetDefault("postgres.sslmode", "require") viper.SetDefault("postgres.sslmode", "require")
viper.SetDefault("postgres.port", "5432") viper.SetDefault("postgres.port", "5432")
viper.SetDefault("mysql.sslmode", "true")
viper.SetDefault("mysql.port", "3306")
viper.BindPFlags(rootCmd.PersistentFlags()) viper.BindPFlags(rootCmd.PersistentFlags())
viper.AutomaticEnv() viper.AutomaticEnv()
@ -79,7 +83,7 @@ func main() {
if e, ok := err.(commandFailure); ok { if e, ok := err.(commandFailure); ok {
fmt.Printf("Error: %v\n\n", string(e)) fmt.Printf("Error: %v\n\n", string(e))
rootCmd.Help() rootCmd.Help()
} else if !cmdConfig.Debug { } else if !viper.GetBool("debug") {
fmt.Printf("Error: %v\n", err) fmt.Printf("Error: %v\n", err)
} else { } else {
fmt.Printf("Error: %+v\n", err) fmt.Printf("Error: %+v\n", err)
@ -107,6 +111,7 @@ func preRun(cmd *cobra.Command, args []string) error {
cmdConfig = &Config{ cmdConfig = &Config{
DriverName: driverName, DriverName: driverName,
OutFolder: viper.GetString("output"), OutFolder: viper.GetString("output"),
Schema: viper.GetString("schema"),
PkgName: viper.GetString("pkgname"), PkgName: viper.GetString("pkgname"),
Debug: viper.GetBool("debug"), Debug: viper.GetBool("debug"),
NoTests: viper.GetBool("no-tests"), NoTests: viper.GetBool("no-tests"),
@ -115,12 +120,20 @@ func preRun(cmd *cobra.Command, args []string) error {
} }
// BUG: https://github.com/spf13/viper/issues/200 // BUG: https://github.com/spf13/viper/issues/200
// Look up the value of ExcludeTables & Tags directly from PFlags in Cobra if we // Look up the value of blacklist, whitelist & tags directly from PFlags in Cobra if we
// detect a malformed value coming out of viper. // detect a malformed value coming out of viper.
// Once the bug is fixed we'll be able to move this into the init above // Once the bug is fixed we'll be able to move this into the init above
cmdConfig.ExcludeTables = viper.GetStringSlice("exclude") cmdConfig.BlacklistTables = viper.GetStringSlice("blacklist")
if len(cmdConfig.ExcludeTables) == 1 && strings.HasPrefix(cmdConfig.ExcludeTables[0], "[") { if len(cmdConfig.BlacklistTables) == 1 && strings.HasPrefix(cmdConfig.BlacklistTables[0], "[") {
cmdConfig.ExcludeTables, err = cmd.PersistentFlags().GetStringSlice("exclude") cmdConfig.BlacklistTables, err = cmd.PersistentFlags().GetStringSlice("blacklist")
if err != nil {
return err
}
}
cmdConfig.WhitelistTables = viper.GetStringSlice("whitelist")
if len(cmdConfig.WhitelistTables) == 1 && strings.HasPrefix(cmdConfig.WhitelistTables[0], "[") {
cmdConfig.WhitelistTables, err = cmd.PersistentFlags().GetStringSlice("whitelist")
if err != nil { if err != nil {
return err return err
} }
@ -134,7 +147,7 @@ func preRun(cmd *cobra.Command, args []string) error {
} }
} }
if viper.IsSet("postgres.dbname") { if driverName == "postgres" {
cmdConfig.Postgres = PostgresConfig{ cmdConfig.Postgres = PostgresConfig{
User: viper.GetString("postgres.user"), User: viper.GetString("postgres.user"),
Pass: viper.GetString("postgres.pass"), Pass: viper.GetString("postgres.pass"),
@ -144,10 +157,17 @@ func preRun(cmd *cobra.Command, args []string) error {
SSLMode: viper.GetString("postgres.sslmode"), SSLMode: viper.GetString("postgres.sslmode"),
} }
// Set the default SSLMode value // BUG: https://github.com/spf13/viper/issues/71
// Despite setting defaults, nested values don't get defaults
// Set them manually
if cmdConfig.Postgres.SSLMode == "" { if cmdConfig.Postgres.SSLMode == "" {
viper.Set("postgres.sslmode", "require") cmdConfig.Postgres.SSLMode = "require"
cmdConfig.Postgres.SSLMode = viper.GetString("postgres.sslmode") viper.Set("postgres.sslmode", cmdConfig.Postgres.SSLMode)
}
if cmdConfig.Postgres.Port == 0 {
cmdConfig.Postgres.Port = 5432
viper.Set("postgres.port", cmdConfig.Postgres.Port)
} }
err = vala.BeginValidation().Validate( err = vala.BeginValidation().Validate(
@ -161,8 +181,45 @@ func preRun(cmd *cobra.Command, args []string) error {
if err != nil { if err != nil {
return commandFailure(err.Error()) return commandFailure(err.Error())
} }
} else if driverName == "postgres" { }
return errors.New("postgres driver requires a postgres section in your config file")
if driverName == "mysql" {
cmdConfig.MySQL = MySQLConfig{
User: viper.GetString("mysql.user"),
Pass: viper.GetString("mysql.pass"),
Host: viper.GetString("mysql.host"),
Port: viper.GetInt("mysql.port"),
DBName: viper.GetString("mysql.dbname"),
SSLMode: viper.GetString("mysql.sslmode"),
}
// MySQL doesn't have schemas, just databases
cmdConfig.Schema = cmdConfig.MySQL.DBName
// BUG: https://github.com/spf13/viper/issues/71
// Despite setting defaults, nested values don't get defaults
// Set them manually
if cmdConfig.MySQL.SSLMode == "" {
cmdConfig.MySQL.SSLMode = "true"
viper.Set("mysql.sslmode", cmdConfig.MySQL.SSLMode)
}
if cmdConfig.MySQL.Port == 0 {
cmdConfig.MySQL.Port = 3306
viper.Set("mysql.port", cmdConfig.MySQL.Port)
}
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(cmdConfig.MySQL.User, "mysql.user"),
vala.StringNotEmpty(cmdConfig.MySQL.Host, "mysql.host"),
vala.Not(vala.Equals(cmdConfig.MySQL.Port, 0, "mysql.port")),
vala.StringNotEmpty(cmdConfig.MySQL.DBName, "mysql.dbname"),
vala.StringNotEmpty(cmdConfig.MySQL.SSLMode, "mysql.sslmode"),
).Check()
if err != nil {
return commandFailure(err.Error())
}
} }
cmdState, err = New(cmdConfig) cmdState, err = New(cmdConfig)

View file

@ -1,15 +1,16 @@
package boil package queries
import ( import (
"database/sql" "database/sql"
"reflect" "reflect"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/vattle/sqlboiler/boil"
"github.com/vattle/sqlboiler/strmangle" "github.com/vattle/sqlboiler/strmangle"
) )
type loadRelationshipState struct { type loadRelationshipState struct {
exec Executor exec boil.Executor
loaded map[string]struct{} loaded map[string]struct{}
toLoad []string toLoad []string
} }

View file

@ -1,6 +1,10 @@
package boil package queries
import "testing" import (
"testing"
"github.com/vattle/sqlboiler/boil"
)
var loadFunctionCalled bool var loadFunctionCalled bool
var loadFunctionNestedCalled int var loadFunctionNestedCalled int
@ -32,12 +36,12 @@ type testNestedRSlice struct {
type testNestedLSlice struct { type testNestedLSlice struct {
} }
func (testLStruct) LoadTestOne(exec Executor, singular bool, obj interface{}) error { func (testLStruct) LoadTestOne(exec boil.Executor, singular bool, obj interface{}) error {
loadFunctionCalled = true loadFunctionCalled = true
return nil return nil
} }
func (testNestedLStruct) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error { func (testNestedLStruct) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error {
switch x := obj.(type) { switch x := obj.(type) {
case *testNestedStruct: case *testNestedStruct:
x.R = &testNestedRStruct{ x.R = &testNestedRStruct{
@ -54,7 +58,7 @@ func (testNestedLStruct) LoadToEagerLoad(exec Executor, singular bool, obj inter
return nil return nil
} }
func (testNestedLSlice) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error { func (testNestedLSlice) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error {
switch x := obj.(type) { switch x := obj.(type) {
case *testNestedSlice: case *testNestedSlice:

View file

@ -1,4 +1,4 @@
package boil package queries
import ( import (
"fmt" "fmt"

View file

@ -1,11 +1,11 @@
package boil package queries
import ( import (
"reflect" "reflect"
"testing" "testing"
"time" "time"
"gopkg.in/nullbio/null.v4" "gopkg.in/nullbio/null.v5"
) )
type testObj struct { type testObj struct {

View file

@ -1,12 +1,12 @@
package qm package qm
import "github.com/vattle/sqlboiler/boil" import "github.com/vattle/sqlboiler/queries"
// QueryMod to modify the query object // QueryMod to modify the query object
type QueryMod func(q *boil.Query) type QueryMod func(q *queries.Query)
// Apply the query mods to the Query object // Apply the query mods to the Query object
func Apply(q *boil.Query, mods ...QueryMod) { func Apply(q *queries.Query, mods ...QueryMod) {
for _, mod := range mods { for _, mod := range mods {
mod(q) mod(q)
} }
@ -14,8 +14,8 @@ func Apply(q *boil.Query, mods ...QueryMod) {
// SQL allows you to execute a plain SQL statement // SQL allows you to execute a plain SQL statement
func SQL(sql string, args ...interface{}) QueryMod { func SQL(sql string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.SetSQL(q, sql, args...) queries.SetSQL(q, sql, args...)
} }
} }
@ -25,29 +25,29 @@ func SQL(sql string, args ...interface{}) QueryMod {
// Relationship name plurality is important, if your relationship is // Relationship name plurality is important, if your relationship is
// singular, you need to specify the singular form and vice versa. // singular, you need to specify the singular form and vice versa.
func Load(relationships ...string) QueryMod { func Load(relationships ...string) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.SetLoad(q, relationships...) queries.SetLoad(q, relationships...)
} }
} }
// InnerJoin on another table // InnerJoin on another table
func InnerJoin(clause string, args ...interface{}) QueryMod { func InnerJoin(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendInnerJoin(q, clause, args...) queries.AppendInnerJoin(q, clause, args...)
} }
} }
// Select specific columns opposed to all columns // Select specific columns opposed to all columns
func Select(columns ...string) QueryMod { func Select(columns ...string) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendSelect(q, columns...) queries.AppendSelect(q, columns...)
} }
} }
// Where allows you to specify a where clause for your statement // Where allows you to specify a where clause for your statement
func Where(clause string, args ...interface{}) QueryMod { func Where(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendWhere(q, clause, args...) queries.AppendWhere(q, clause, args...)
} }
} }
@ -55,24 +55,24 @@ func Where(clause string, args ...interface{}) QueryMod {
// And is a duplicate of the Where function, but allows for more natural looking // And is a duplicate of the Where function, but allows for more natural looking
// query mod chains, for example: (Where("a=?"), And("b=?"), Or("c=?"))) // query mod chains, for example: (Where("a=?"), And("b=?"), Or("c=?")))
func And(clause string, args ...interface{}) QueryMod { func And(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendWhere(q, clause, args...) queries.AppendWhere(q, clause, args...)
} }
} }
// Or allows you to specify a where clause separated by an OR for your statement // Or allows you to specify a where clause separated by an OR for your statement
func Or(clause string, args ...interface{}) QueryMod { func Or(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendWhere(q, clause, args...) queries.AppendWhere(q, clause, args...)
boil.SetLastWhereAsOr(q) queries.SetLastWhereAsOr(q)
} }
} }
// WhereIn allows you to specify a "x IN (set)" clause for your where statement // WhereIn allows you to specify a "x IN (set)" clause for your where statement
// Example clauses: "column in ?", "(column1,column2) in ?" // Example clauses: "column in ?", "(column1,column2) in ?"
func WhereIn(clause string, args ...interface{}) QueryMod { func WhereIn(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendIn(q, clause, args...) queries.AppendIn(q, clause, args...)
} }
} }
@ -81,65 +81,65 @@ func WhereIn(clause string, args ...interface{}) QueryMod {
// allows for more natural looking query mod chains, for example: // allows for more natural looking query mod chains, for example:
// (WhereIn("column1 in ?"), AndIn("column2 in ?"), OrIn("column3 in ?")) // (WhereIn("column1 in ?"), AndIn("column2 in ?"), OrIn("column3 in ?"))
func AndIn(clause string, args ...interface{}) QueryMod { func AndIn(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendIn(q, clause, args...) queries.AppendIn(q, clause, args...)
} }
} }
// OrIn allows you to specify an IN clause separated by // OrIn allows you to specify an IN clause separated by
// an OR for your where statement // an OR for your where statement
func OrIn(clause string, args ...interface{}) QueryMod { func OrIn(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendIn(q, clause, args...) queries.AppendIn(q, clause, args...)
boil.SetLastInAsOr(q) queries.SetLastInAsOr(q)
} }
} }
// GroupBy allows you to specify a group by clause for your statement // GroupBy allows you to specify a group by clause for your statement
func GroupBy(clause string) QueryMod { func GroupBy(clause string) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendGroupBy(q, clause) queries.AppendGroupBy(q, clause)
} }
} }
// OrderBy allows you to specify a order by clause for your statement // OrderBy allows you to specify a order by clause for your statement
func OrderBy(clause string) QueryMod { func OrderBy(clause string) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendOrderBy(q, clause) queries.AppendOrderBy(q, clause)
} }
} }
// Having allows you to specify a having clause for your statement // Having allows you to specify a having clause for your statement
func Having(clause string, args ...interface{}) QueryMod { func Having(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendHaving(q, clause, args...) queries.AppendHaving(q, clause, args...)
} }
} }
// From allows to specify the table for your statement // From allows to specify the table for your statement
func From(from string) QueryMod { func From(from string) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.AppendFrom(q, from) queries.AppendFrom(q, from)
} }
} }
// Limit the number of returned rows // Limit the number of returned rows
func Limit(limit int) QueryMod { func Limit(limit int) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.SetLimit(q, limit) queries.SetLimit(q, limit)
} }
} }
// Offset into the results // Offset into the results
func Offset(offset int) QueryMod { func Offset(offset int) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.SetOffset(q, offset) queries.SetOffset(q, offset)
} }
} }
// For inserts a concurrency locking clause at the end of your statement // For inserts a concurrency locking clause at the end of your statement
func For(clause string) QueryMod { func For(clause string) QueryMod {
return func(q *boil.Query) { return func(q *queries.Query) {
boil.SetFor(q, clause) queries.SetFor(q, clause)
} }
} }

View file

@ -1,8 +1,10 @@
package boil package queries
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/vattle/sqlboiler/boil"
) )
// joinKind is the type of join // joinKind is the type of join
@ -18,8 +20,9 @@ const (
// Query holds the state for the built up query // Query holds the state for the built up query
type Query struct { type Query struct {
executor Executor executor boil.Executor
plainSQL plainSQL dialect *Dialect
rawSQL rawSQL
load []string load []string
delete bool delete bool
update map[string]interface{} update map[string]interface{}
@ -37,6 +40,20 @@ type Query struct {
forlock string forlock string
} }
// Dialect holds values that direct the query builder
// how to build compatible queries for each database.
// Each database driver needs to implement functions
// that provide these values.
type Dialect struct {
// The left quote character for SQL identifiers
LQ byte
// The right quote character for SQL identifiers
RQ byte
// Bool flag indicating whether indexed
// placeholders ($1) are used, or ? placeholders.
IndexPlaceholders bool
}
type where struct { type where struct {
clause string clause string
orSeparator bool orSeparator bool
@ -54,7 +71,7 @@ type having struct {
args []interface{} args []interface{}
} }
type plainSQL struct { type rawSQL struct {
sql string sql string
args []interface{} args []interface{}
} }
@ -65,65 +82,92 @@ type join struct {
args []interface{} args []interface{}
} }
// SQL makes a plainSQL query, usually for use with bind // Raw makes a raw query, usually for use with bind
func SQL(exec Executor, query string, args ...interface{}) *Query { func Raw(exec boil.Executor, query string, args ...interface{}) *Query {
return &Query{ return &Query{
executor: exec, executor: exec,
plainSQL: plainSQL{ rawSQL: rawSQL{
sql: query, sql: query,
args: args, args: args,
}, },
} }
} }
// SQLG makes a plainSQL query using the global Executor, usually for use with bind // RawG makes a raw query using the global boil.Executor, usually for use with bind
func SQLG(query string, args ...interface{}) *Query { func RawG(query string, args ...interface{}) *Query {
return SQL(GetDB(), query, args...) return Raw(boil.GetDB(), query, args...)
} }
// ExecQuery executes a query that does not need a row returned // Exec executes a query that does not need a row returned
func ExecQuery(q *Query) (sql.Result, error) { func (q *Query) Exec() (sql.Result, error) {
qs, args := buildQuery(q) qs, args := buildQuery(q)
if DebugMode { if boil.DebugMode {
fmt.Fprintln(DebugWriter, qs) fmt.Fprintln(boil.DebugWriter, qs)
fmt.Fprintln(DebugWriter, args) fmt.Fprintln(boil.DebugWriter, args)
} }
return q.executor.Exec(qs, args...) return q.executor.Exec(qs, args...)
} }
// ExecQueryOne executes the query for the One finisher and returns a row // QueryRow executes the query for the One finisher and returns a row
func ExecQueryOne(q *Query) *sql.Row { func (q *Query) QueryRow() *sql.Row {
qs, args := buildQuery(q) qs, args := buildQuery(q)
if DebugMode { if boil.DebugMode {
fmt.Fprintln(DebugWriter, qs) fmt.Fprintln(boil.DebugWriter, qs)
fmt.Fprintln(DebugWriter, args) fmt.Fprintln(boil.DebugWriter, args)
} }
return q.executor.QueryRow(qs, args...) return q.executor.QueryRow(qs, args...)
} }
// ExecQueryAll executes the query for the All finisher and returns multiple rows // Query executes the query for the All finisher and returns multiple rows
func ExecQueryAll(q *Query) (*sql.Rows, error) { func (q *Query) Query() (*sql.Rows, error) {
qs, args := buildQuery(q) qs, args := buildQuery(q)
if DebugMode { if boil.DebugMode {
fmt.Fprintln(DebugWriter, qs) fmt.Fprintln(boil.DebugWriter, qs)
fmt.Fprintln(DebugWriter, args) fmt.Fprintln(boil.DebugWriter, args)
} }
return q.executor.Query(qs, args...) return q.executor.Query(qs, args...)
} }
// ExecP executes a query that does not need a row returned
// It will panic on error
func (q *Query) ExecP() sql.Result {
res, err := q.Exec()
if err != nil {
panic(boil.WrapErr(err))
}
return res
}
// QueryP executes the query for the All finisher and returns multiple rows
// It will panic on error
func (q *Query) QueryP() *sql.Rows {
rows, err := q.Query()
if err != nil {
panic(boil.WrapErr(err))
}
return rows
}
// SetExecutor on the query. // SetExecutor on the query.
func SetExecutor(q *Query, exec Executor) { func SetExecutor(q *Query, exec boil.Executor) {
q.executor = exec q.executor = exec
} }
// GetExecutor on the query. // GetExecutor on the query.
func GetExecutor(q *Query) Executor { func GetExecutor(q *Query) boil.Executor {
return q.executor return q.executor
} }
// SetDialect on the query.
func SetDialect(q *Query, dialect *Dialect) {
q.dialect = dialect
}
// SetSQL on the query. // SetSQL on the query.
func SetSQL(q *Query, sql string, args ...interface{}) { func SetSQL(q *Query, sql string, args ...interface{}) {
q.plainSQL = plainSQL{sql: sql, args: args} q.rawSQL = rawSQL{sql: sql, args: args}
} }
// SetLoad on the query. // SetLoad on the query.
@ -131,6 +175,11 @@ func SetLoad(q *Query, relationships ...string) {
q.load = append([]string(nil), relationships...) q.load = append([]string(nil), relationships...)
} }
// SetSelect on the query.
func SetSelect(q *Query, sel []string) {
q.selectCols = sel
}
// SetCount on the query. // SetCount on the query.
func SetCount(q *Query) { func SetCount(q *Query) {
q.count = true q.count = true

View file

@ -1,4 +1,4 @@
package boil package queries
import ( import (
"bytes" "bytes"
@ -20,8 +20,8 @@ func buildQuery(q *Query) (string, []interface{}) {
var args []interface{} var args []interface{}
switch { switch {
case len(q.plainSQL.sql) != 0: case len(q.rawSQL.sql) != 0:
return q.plainSQL.sql, q.plainSQL.args return q.rawSQL.sql, q.rawSQL.args
case q.delete: case q.delete:
buf, args = buildDeleteQuery(q) buf, args = buildDeleteQuery(q)
case len(q.update) > 0: case len(q.update) > 0:
@ -34,8 +34,8 @@ func buildQuery(q *Query) (string, []interface{}) {
// Cache the generated query for query object re-use // Cache the generated query for query object re-use
bufStr := buf.String() bufStr := buf.String()
q.plainSQL.sql = bufStr q.rawSQL.sql = bufStr
q.plainSQL.args = args q.rawSQL.args = args
return bufStr, args return bufStr, args
} }
@ -57,8 +57,8 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
// Don't identQuoteSlice - writeAsStatements does this // Don't identQuoteSlice - writeAsStatements does this
buf.WriteString(strings.Join(selectColsWithAs, ", ")) buf.WriteString(strings.Join(selectColsWithAs, ", "))
} else if hasSelectCols { } else if hasSelectCols {
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), ", ")) buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", "))
} else if hasJoins { } else if hasJoins && !q.count {
selectColsWithStars := writeStars(q) selectColsWithStars := writeStars(q)
buf.WriteString(strings.Join(selectColsWithStars, ", ")) buf.WriteString(strings.Join(selectColsWithStars, ", "))
} else { } else {
@ -70,7 +70,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf.WriteByte(')') buf.WriteByte(')')
} }
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
if len(q.joins) > 0 { if len(q.joins) > 0 {
argsLen := len(args) argsLen := len(args)
@ -82,7 +82,12 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause) fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause)
args = append(args, j.args...) args = append(args, j.args...)
} }
resp, _ := convertQuestionMarks(joinBuf.String(), argsLen+1) var resp string
if q.dialect.IndexPlaceholders {
resp, _ = convertQuestionMarks(joinBuf.String(), argsLen+1)
} else {
resp = joinBuf.String()
}
fmt.Fprintf(buf, resp) fmt.Fprintf(buf, resp)
strmangle.PutBuffer(joinBuf) strmangle.PutBuffer(joinBuf)
} }
@ -110,7 +115,7 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := strmangle.GetBuffer() buf := strmangle.GetBuffer()
buf.WriteString("DELETE FROM ") buf.WriteString("DELETE FROM ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
where, whereArgs := whereClause(q, 1) where, whereArgs := whereClause(q, 1)
if len(whereArgs) != 0 { if len(whereArgs) != 0 {
@ -135,7 +140,7 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := strmangle.GetBuffer() buf := strmangle.GetBuffer()
buf.WriteString("UPDATE ") buf.WriteString("UPDATE ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
cols := make(sort.StringSlice, len(q.update)) cols := make(sort.StringSlice, len(q.update))
var args []interface{} var args []interface{}
@ -150,13 +155,13 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
for i := 0; i < len(cols); i++ { for i := 0; i < len(cols); i++ {
args = append(args, q.update[cols[i]]) args = append(args, q.update[cols[i]])
cols[i] = strmangle.IdentQuote(cols[i]) cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, cols[i])
} }
buf.WriteString(fmt.Sprintf( buf.WriteString(fmt.Sprintf(
" SET (%s) = (%s)", " SET (%s) = (%s)",
strings.Join(cols, ", "), strings.Join(cols, ", "),
strmangle.Placeholders(len(cols), 1, 1)), strmangle.Placeholders(q.dialect.IndexPlaceholders, len(cols), 1, 1)),
) )
where, whereArgs := whereClause(q, len(args)+1) where, whereArgs := whereClause(q, len(args)+1)
@ -178,11 +183,40 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
return buf, args return buf, args
} }
// BuildUpsertQuery builds a SQL statement string using the upsertData provided. // BuildUpsertQueryMySQL builds a SQL statement string using the upsertData provided.
func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string { func BuildUpsertQueryMySQL(dia Dialect, tableName string, update, whitelist []string) string {
conflict = strmangle.IdentQuoteSlice(conflict) whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
whitelist = strmangle.IdentQuoteSlice(whitelist)
ret = strmangle.IdentQuoteSlice(ret) buf := strmangle.GetBuffer()
defer strmangle.PutBuffer(buf)
fmt.Fprintf(
buf,
"INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE ",
tableName,
strings.Join(whitelist, ", "),
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
)
for i, v := range update {
if i != 0 {
buf.WriteByte(',')
}
quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
buf.WriteString(quoted)
buf.WriteString(" = VALUES(")
buf.WriteString(quoted)
buf.WriteByte(')')
}
return buf.String()
}
// BuildUpsertQueryPostgres builds a SQL statement string using the upsertData provided.
func BuildUpsertQueryPostgres(dia Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string {
conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict)
whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret)
buf := strmangle.GetBuffer() buf := strmangle.GetBuffer()
defer strmangle.PutBuffer(buf) defer strmangle.PutBuffer(buf)
@ -192,7 +226,7 @@ func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conf
"INSERT INTO %s (%s) VALUES (%s) ON CONFLICT ", "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT ",
tableName, tableName,
strings.Join(whitelist, ", "), strings.Join(whitelist, ", "),
strmangle.Placeholders(len(whitelist), 1, 1), strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
) )
if !updateOnConflict || len(update) == 0 { if !updateOnConflict || len(update) == 0 {
@ -206,7 +240,7 @@ func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conf
if i != 0 { if i != 0 {
buf.WriteByte(',') buf.WriteByte(',')
} }
quoted := strmangle.IdentQuote(v) quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
buf.WriteString(quoted) buf.WriteString(quoted)
buf.WriteString(" = EXCLUDED.") buf.WriteString(" = EXCLUDED.")
buf.WriteString(quoted) buf.WriteString(quoted)
@ -237,7 +271,12 @@ func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) {
fmt.Fprintf(havingBuf, j.clause) fmt.Fprintf(havingBuf, j.clause)
*args = append(*args, j.args...) *args = append(*args, j.args...)
} }
resp, _ := convertQuestionMarks(havingBuf.String(), argsLen+1) var resp string
if q.dialect.IndexPlaceholders {
resp, _ = convertQuestionMarks(havingBuf.String(), argsLen+1)
} else {
resp = havingBuf.String()
}
fmt.Fprintf(buf, resp) fmt.Fprintf(buf, resp)
strmangle.PutBuffer(havingBuf) strmangle.PutBuffer(havingBuf)
} }
@ -264,7 +303,7 @@ func writeStars(q *Query) []string {
for i, f := range q.from { for i, f := range q.from {
toks := strings.Split(f, " ") toks := strings.Split(f, " ")
if len(toks) == 1 { if len(toks) == 1 {
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(toks[0])) cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, toks[0]))
continue continue
} }
@ -276,7 +315,7 @@ func writeStars(q *Query) []string {
if len(alias) != 0 { if len(alias) != 0 {
name = alias name = alias
} }
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(name)) cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, name))
} }
return cols return cols
@ -292,7 +331,7 @@ func writeAsStatements(q *Query) []string {
toks := strings.Split(col, ".") toks := strings.Split(col, ".")
if len(toks) == 1 { if len(toks) == 1 {
cols[i] = strmangle.IdentQuote(col) cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col)
continue continue
} }
@ -301,7 +340,7 @@ func writeAsStatements(q *Query) []string {
asParts[j] = strings.Trim(tok, `"`) asParts[j] = strings.Trim(tok, `"`)
} }
cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(col), strings.Join(asParts, ".")) cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col), strings.Join(asParts, "."))
} }
return cols return cols
@ -335,7 +374,13 @@ func whereClause(q *Query, startAt int) (string, []interface{}) {
args = append(args, where.args...) args = append(args, where.args...)
} }
resp, _ := convertQuestionMarks(buf.String(), startAt) var resp string
if q.dialect.IndexPlaceholders {
resp, _ = convertQuestionMarks(buf.String(), startAt)
} else {
resp = buf.String()
}
return resp, args return resp, args
} }
@ -374,7 +419,7 @@ func inClause(q *Query, startAt int) (string, []interface{}) {
// column name side, however if this case is being hit then the regexp // column name side, however if this case is being hit then the regexp
// probably needs adjustment, or the user is passing in invalid clauses. // probably needs adjustment, or the user is passing in invalid clauses.
if matches == nil { if matches == nil {
clause, count := convertInQuestionMarks(in.clause, startAt, 1, ln) clause, count := convertInQuestionMarks(q.dialect.IndexPlaceholders, in.clause, startAt, 1, ln)
buf.WriteString(clause) buf.WriteString(clause)
startAt = startAt + count startAt = startAt + count
} else { } else {
@ -384,11 +429,24 @@ func inClause(q *Query, startAt int) (string, []interface{}) {
// of the clause to determine how many columns they are using. // of the clause to determine how many columns they are using.
// This number determines the groupAt for the convert function. // This number determines the groupAt for the convert function.
cols := strings.Split(leftSide, ",") cols := strings.Split(leftSide, ",")
cols = strmangle.IdentQuoteSlice(cols) cols = strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, cols)
groupAt := len(cols) groupAt := len(cols)
leftClause, leftCount := convertQuestionMarks(strings.Join(cols, ","), startAt) var leftClause string
rightClause, rightCount := convertInQuestionMarks(rightSide, startAt+leftCount, groupAt, ln-leftCount) var leftCount int
if q.dialect.IndexPlaceholders {
leftClause, leftCount = convertQuestionMarks(strings.Join(cols, ","), startAt)
} else {
// Count the number of cols that are question marks, so we know
// how much to offset convertInQuestionMarks by
for _, v := range cols {
if v == "?" {
leftCount++
}
}
leftClause = strings.Join(cols, ",")
}
rightClause, rightCount := convertInQuestionMarks(q.dialect.IndexPlaceholders, rightSide, startAt+leftCount, groupAt, ln-leftCount)
buf.WriteString(leftClause) buf.WriteString(leftClause)
buf.WriteString(" IN ") buf.WriteString(" IN ")
buf.WriteString(rightClause) buf.WriteString(rightClause)
@ -406,7 +464,7 @@ func inClause(q *Query, startAt int) (string, []interface{}) {
// It uses groupAt to determine how many placeholders should be in each group, // It uses groupAt to determine how many placeholders should be in each group,
// for example, groupAt 2 would result in: (($1,$2),($3,$4)) // for example, groupAt 2 would result in: (($1,$2),($3,$4))
// and groupAt 1 would result in ($1,$2,$3,$4) // and groupAt 1 would result in ($1,$2,$3,$4)
func convertInQuestionMarks(clause string, startAt, groupAt, total int) (string, int) { func convertInQuestionMarks(indexPlaceholders bool, clause string, startAt, groupAt, total int) (string, int) {
if startAt == 0 || len(clause) == 0 { if startAt == 0 || len(clause) == 0 {
panic("Not a valid start number.") panic("Not a valid start number.")
} }
@ -428,7 +486,7 @@ func convertInQuestionMarks(clause string, startAt, groupAt, total int) (string,
paramBuf.WriteString(clause[:foundAt]) paramBuf.WriteString(clause[:foundAt])
paramBuf.WriteByte('(') paramBuf.WriteByte('(')
paramBuf.WriteString(strmangle.Placeholders(total, startAt, groupAt)) paramBuf.WriteString(strmangle.Placeholders(indexPlaceholders, total, startAt, groupAt))
paramBuf.WriteByte(')') paramBuf.WriteByte(')')
paramBuf.WriteString(clause[foundAt+1:]) paramBuf.WriteString(clause[foundAt+1:])

View file

@ -1,4 +1,4 @@
package boil package queries
import ( import (
"bytes" "bytes"
@ -97,6 +97,7 @@ func TestBuildQuery(t *testing.T) {
for i, test := range tests { for i, test := range tests {
filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i)) filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i))
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
out, args := buildQuery(test.q) out, args := buildQuery(test.q)
if *writeGoldenFiles { if *writeGoldenFiles {
@ -149,6 +150,7 @@ func TestWriteStars(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
test.In.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
selects := writeStars(&test.In) selects := writeStars(&test.In)
if !reflect.DeepEqual(selects, test.Out) { if !reflect.DeepEqual(selects, test.Out) {
t.Errorf("writeStar test fail %d\nwant: %v\ngot: %v", i, test.Out, selects) t.Errorf("writeStar test fail %d\nwant: %v\ngot: %v", i, test.Out, selects)
@ -275,6 +277,7 @@ func TestWhereClause(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
result, _ := whereClause(&test.q, 1) result, _ := whereClause(&test.q, 1)
if result != test.expect { if result != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result) t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result)
@ -407,6 +410,7 @@ func TestInClause(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
result, args := inClause(&test.q, 1) result, args := inClause(&test.q, 1)
if result != test.expect { if result != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result) t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result)
@ -489,7 +493,7 @@ func TestConvertInQuestionMarks(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
res, count := convertInQuestionMarks(test.clause, test.start, test.group, test.total) res, count := convertInQuestionMarks(true, test.clause, test.start, test.group, test.total)
if res != test.expect { if res != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res) t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res)
} }
@ -497,6 +501,14 @@ func TestConvertInQuestionMarks(t *testing.T) {
t.Errorf("%d) Expected %d, got %d", i, test.total, count) t.Errorf("%d) Expected %d, got %d", i, test.total, count)
} }
} }
res, count := convertInQuestionMarks(false, "?", 1, 3, 9)
if res != "((?,?,?),(?,?,?),(?,?,?))" {
t.Errorf("Mismatch between expected and result: %s", res)
}
if count != 9 {
t.Errorf("Expected 9 results, got %d", count)
}
} }
func TestWriteAsStatements(t *testing.T) { func TestWriteAsStatements(t *testing.T) {
@ -512,6 +524,7 @@ func TestWriteAsStatements(t *testing.T) {
`a.clown.run`, `a.clown.run`,
`COUNT(a)`, `COUNT(a)`,
}, },
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
} }
expect := []string{ expect := []string{

View file

@ -1,4 +1,4 @@
package boil package queries
import ( import (
"database/sql" "database/sql"
@ -36,12 +36,12 @@ func TestSetSQL(t *testing.T) {
q := &Query{} q := &Query{}
SetSQL(q, "select * from thing", 5, 3) SetSQL(q, "select * from thing", 5, 3)
if len(q.plainSQL.args) != 2 { if len(q.rawSQL.args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.plainSQL.args)) t.Errorf("Expected len 2, got %d", len(q.rawSQL.args))
} }
if q.plainSQL.sql != "select * from thing" { if q.rawSQL.sql != "select * from thing" {
t.Errorf("Was not expected string, got %s", q.plainSQL.sql) t.Errorf("Was not expected string, got %s", q.rawSQL.sql)
} }
} }
@ -290,6 +290,17 @@ func TestFrom(t *testing.T) {
} }
} }
func TestSetSelect(t *testing.T) {
t.Parallel()
q := &Query{selectCols: []string{"hello"}}
SetSelect(q, nil)
if q.selectCols != nil {
t.Errorf("want nil")
}
}
func TestSetCount(t *testing.T) { func TestSetCount(t *testing.T) {
t.Parallel() t.Parallel()
@ -362,24 +373,24 @@ func TestAppendSelect(t *testing.T) {
func TestSQL(t *testing.T) { func TestSQL(t *testing.T) {
t.Parallel() t.Parallel()
q := SQL(&sql.DB{}, "thing", 5) q := Raw(&sql.DB{}, "thing", 5)
if q.plainSQL.sql != "thing" { if q.rawSQL.sql != "thing" {
t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql) t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql)
} }
if q.plainSQL.args[0].(int) != 5 { if q.rawSQL.args[0].(int) != 5 {
t.Errorf("Expected 5, got %v", q.plainSQL.args[0]) t.Errorf("Expected 5, got %v", q.rawSQL.args[0])
} }
} }
func TestSQLG(t *testing.T) { func TestSQLG(t *testing.T) {
t.Parallel() t.Parallel()
q := SQLG("thing", 5) q := RawG("thing", 5)
if q.plainSQL.sql != "thing" { if q.rawSQL.sql != "thing" {
t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql) t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql)
} }
if q.plainSQL.args[0].(int) != 5 { if q.rawSQL.args[0].(int) != 5 {
t.Errorf("Expected 5, got %v", q.plainSQL.args[0]) t.Errorf("Expected 5, got %v", q.rawSQL.args[0])
} }
} }

View file

@ -1,4 +1,4 @@
package boil package queries
import ( import (
"database/sql" "database/sql"
@ -8,6 +8,7 @@ import (
"sync" "sync"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/vattle/sqlboiler/boil"
"github.com/vattle/sqlboiler/strmangle" "github.com/vattle/sqlboiler/strmangle"
) )
@ -40,7 +41,7 @@ const (
// It panics on error. See boil.Bind() documentation. // It panics on error. See boil.Bind() documentation.
func (q *Query) BindP(obj interface{}) { func (q *Query) BindP(obj interface{}) {
if err := q.Bind(obj); err != nil { if err := q.Bind(obj); err != nil {
panic(WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
@ -100,7 +101,7 @@ func (q *Query) Bind(obj interface{}) error {
return err return err
} }
rows, err := ExecQueryAll(q) rows, err := q.Query()
if err != nil { if err != nil {
return errors.Wrap(err, "bind failed to execute query") return errors.Wrap(err, "bind failed to execute query")
} }
@ -322,8 +323,10 @@ func ptrFromMapping(val reflect.Value, mapping uint64, addressOf bool) reflect.V
v := (mapping >> uint(i*8)) & sentinel v := (mapping >> uint(i*8)) & sentinel
if v == sentinel { if v == sentinel {
if val.Kind() != reflect.Ptr { if addressOf && val.Kind() != reflect.Ptr {
return val.Addr() return val.Addr()
} else if !addressOf && val.Kind() == reflect.Ptr {
return reflect.Indirect(val)
} }
return val return val
} }
@ -404,74 +407,3 @@ func makeCacheKey(typ string, cols []string) string {
return mapKey return mapKey
} }
// GetStructValues returns the values (as interface) of the matching columns in obj
func GetStructValues(obj interface{}, columns ...string) []interface{} {
ret := make([]interface{}, len(columns))
val := reflect.Indirect(reflect.ValueOf(obj))
for i, c := range columns {
fieldName := strmangle.TitleCase(c)
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
}
ret[i] = field.Interface()
}
return ret
}
// GetSliceValues returns the values (as interface) of the matching columns in obj.
func GetSliceValues(slice []interface{}, columns ...string) []interface{} {
ret := make([]interface{}, len(slice)*len(columns))
for i, obj := range slice {
val := reflect.Indirect(reflect.ValueOf(obj))
for j, c := range columns {
fieldName := strmangle.TitleCase(c)
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
}
ret[i*len(columns)+j] = field.Interface()
}
}
return ret
}
// GetStructPointers returns a slice of pointers to the matching columns in obj
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
val := reflect.ValueOf(obj).Elem()
var ln int
var getField func(reflect.Value, int) reflect.Value
if len(columns) == 0 {
ln = val.NumField()
getField = func(v reflect.Value, i int) reflect.Value {
return v.Field(i)
}
} else {
ln = len(columns)
getField = func(v reflect.Value, i int) reflect.Value {
return v.FieldByName(strmangle.TitleCase(columns[i]))
}
}
ret := make([]interface{}, ln)
for i := 0; i < ln; i++ {
field := getField(val, i)
if !field.IsValid() {
// Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually
// produce an Invalid value, so we make a hopefully safe assumption here.
panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i])))
}
ret[i] = field.Addr().Interface()
}
return ret
}

View file

@ -1,4 +1,4 @@
package boil package queries
import ( import (
"database/sql/driver" "database/sql/driver"
@ -6,10 +6,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"gopkg.in/DATA-DOG/go-sqlmock.v1" "gopkg.in/DATA-DOG/go-sqlmock.v1"
"gopkg.in/nullbio/null.v4"
) )
func bin64(i uint64) string { func bin64(i uint64) string {
@ -44,7 +42,8 @@ func TestBindStruct(t *testing.T) {
}{} }{}
query := &Query{ query := &Query{
from: []string{"fun"}, from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
} }
db, mock, err := sqlmock.New() db, mock, err := sqlmock.New()
@ -83,7 +82,8 @@ func TestBindSlice(t *testing.T) {
}{} }{}
query := &Query{ query := &Query{
from: []string{"fun"}, from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
} }
db, mock, err := sqlmock.New() db, mock, err := sqlmock.New()
@ -133,7 +133,8 @@ func TestBindPtrSlice(t *testing.T) {
}{} }{}
query := &Query{ query := &Query{
from: []string{"fun"}, from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
} }
db, mock, err := sqlmock.New() db, mock, err := sqlmock.New()
@ -265,6 +266,76 @@ func TestPtrFromMapping(t *testing.T) {
} }
} }
func TestValuesFromMapping(t *testing.T) {
t.Parallel()
type NestedPtrs struct {
Int int
IntP *int
NestedPtrsP *NestedPtrs
}
val := &NestedPtrs{
Int: 5,
IntP: new(int),
NestedPtrsP: &NestedPtrs{
Int: 6,
IntP: new(int),
},
}
mapping := []uint64{testMakeMapping(0), testMakeMapping(1), testMakeMapping(2, 0), testMakeMapping(2, 1)}
v := ValuesFromMapping(reflect.Indirect(reflect.ValueOf(val)), mapping)
if got := v[0].(int); got != 5 {
t.Error("flat int was wrong:", got)
}
if got := v[1].(int); got != 0 {
t.Error("flat pointer was wrong:", got)
}
if got := v[2].(int); got != 6 {
t.Error("nested int was wrong:", got)
}
if got := v[3].(int); got != 0 {
t.Error("nested pointer was wrong:", got)
}
}
func TestPtrsFromMapping(t *testing.T) {
t.Parallel()
type NestedPtrs struct {
Int int
IntP *int
NestedPtrsP *NestedPtrs
}
val := &NestedPtrs{
Int: 5,
IntP: new(int),
NestedPtrsP: &NestedPtrs{
Int: 6,
IntP: new(int),
},
}
mapping := []uint64{testMakeMapping(0), testMakeMapping(1), testMakeMapping(2, 0), testMakeMapping(2, 1)}
v := PtrsFromMapping(reflect.Indirect(reflect.ValueOf(val)), mapping)
if got := *v[0].(*int); got != 5 {
t.Error("flat int was wrong:", got)
}
if got := *v[1].(*int); got != 0 {
t.Error("flat pointer was wrong:", got)
}
if got := *v[2].(*int); got != 6 {
t.Error("nested int was wrong:", got)
}
if got := *v[3].(*int); got != 0 {
t.Error("nested pointer was wrong:", got)
}
}
func TestGetBoilTag(t *testing.T) { func TestGetBoilTag(t *testing.T) {
t.Parallel() t.Parallel()
@ -369,7 +440,8 @@ func TestBindSingular(t *testing.T) {
}{} }{}
query := &Query{ query := &Query{
from: []string{"fun"}, from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
} }
db, mock, err := sqlmock.New() db, mock, err := sqlmock.New()
@ -412,8 +484,9 @@ func TestBind_InnerJoin(t *testing.T) {
}{} }{}
query := &Query{ query := &Query{
from: []string{"fun"}, from: []string{"fun"},
joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}}, joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
} }
db, mock, err := sqlmock.New() db, mock, err := sqlmock.New()
@ -454,249 +527,59 @@ func TestBind_InnerJoin(t *testing.T) {
} }
} }
// func TestBind_InnerJoinSelect(t *testing.T) { func TestBind_InnerJoinSelect(t *testing.T) {
// t.Parallel()
//
// testResults := []*struct {
// Happy struct {
// ID int
// } `boil:"h,bind"`
// Fun struct {
// ID int
// } `boil:",bind"`
// }{}
//
// query := &Query{
// selectCols: []string{"fun.id", "h.id"},
// from: []string{"fun"},
// joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}},
// }
//
// db, mock, err := sqlmock.New()
// if err != nil {
// t.Error(err)
// }
//
// ret := sqlmock.NewRows([]string{"fun.id", "h.id"})
// ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11)))
// ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13)))
// mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret)
//
// SetExecutor(query, db)
// err = query.Bind(&testResults)
// if err != nil {
// t.Error(err)
// }
//
// if len(testResults) != 2 {
// t.Fatal("wrong number of results:", len(testResults))
// }
// if id := testResults[0].Happy.ID; id != 11 {
// t.Error("wrong ID:", id)
// }
// if id := testResults[0].Fun.ID; id != 10 {
// t.Error("wrong ID:", id)
// }
//
// if id := testResults[1].Happy.ID; id != 13 {
// t.Error("wrong ID:", id)
// }
// if id := testResults[1].Fun.ID; id != 12 {
// t.Error("wrong ID:", id)
// }
//
// if err := mock.ExpectationsWereMet(); err != nil {
// t.Error(err)
// }
// }
// func TestBindPtrs_Easy(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// ID int `boil:"identifier"`
// Date time.Time
// }{}
//
// cols := []string{"identifier", "date"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*time.Time) != &testStruct.Date {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_Recursive(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// Happy struct {
// ID int `boil:"identifier"`
// }
// Fun struct {
// ID int
// } `boil:",bind"`
// }{}
//
// cols := []string{"id", "fun.id"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_RecursiveTags(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// Happy struct {
// ID int `boil:"identifier"`
// } `boil:",bind"`
// Fun struct {
// ID int `boil:"identification"`
// } `boil:",bind"`
// }{}
//
// cols := []string{"happy.identifier", "fun.identification"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Happy.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_Ignore(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// ID int `boil:"-"`
// Happy struct {
// ID int
// } `boil:",bind"`
// }{}
//
// cols := []string{"id"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Happy.ID {
// t.Error("id is the wrong pointer")
// }
// }
func TestGetStructValues(t *testing.T) {
t.Parallel() t.Parallel()
timeThing := time.Now() testResults := []*struct {
o := struct { Happy struct {
TitleThing string ID int
Name string } `boil:"h,bind"`
ID int Fun struct {
Stuff int ID int
Things int } `boil:",bind"`
Time time.Time }{}
NullBool null.Bool
}{ query := &Query{
TitleThing: "patrick", dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
Stuff: 10, selectCols: []string{"fun.id", "h.id"},
Things: 0, from: []string{"fun"},
Time: timeThing, joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}},
NullBool: null.NewBool(true, false),
} }
vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool") db, mock, err := sqlmock.New()
if vals[0].(string) != "patrick" { if err != nil {
t.Errorf("Want test, got %s", vals[0]) t.Error(err)
} }
if vals[1].(string) != "" {
t.Errorf("Want empty string, got %s", vals[1]) ret := sqlmock.NewRows([]string{"fun.id", "h.id"})
ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11)))
ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13)))
mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret)
SetExecutor(query, db)
err = query.Bind(&testResults)
if err != nil {
t.Error(err)
} }
if vals[2].(int) != 0 {
t.Errorf("Want 0, got %d", vals[2]) if len(testResults) != 2 {
t.Fatal("wrong number of results:", len(testResults))
} }
if vals[3].(int) != 10 { if id := testResults[0].Happy.ID; id != 11 {
t.Errorf("Want 10, got %d", vals[3]) t.Error("wrong ID:", id)
} }
if vals[4].(int) != 0 { if id := testResults[0].Fun.ID; id != 10 {
t.Errorf("Want 0, got %d", vals[4]) t.Error("wrong ID:", id)
} }
if !vals[5].(time.Time).Equal(timeThing) {
t.Errorf("Want %s, got %s", o.Time, vals[5]) if id := testResults[1].Happy.ID; id != 13 {
t.Error("wrong ID:", id)
} }
if !vals[6].(null.Bool).IsZero() { if id := testResults[1].Fun.ID; id != 12 {
t.Errorf("Want %v, got %v", o.NullBool, vals[6]) t.Error("wrong ID:", id)
} }
}
if err := mock.ExpectationsWereMet(); err != nil {
func TestGetSliceValues(t *testing.T) { t.Error(err)
t.Parallel()
o := []struct {
ID int
Name string
}{
{5, "a"},
{6, "b"},
}
in := make([]interface{}, len(o))
in[0] = o[0]
in[1] = o[1]
vals := GetSliceValues(in, "id", "name")
if got := vals[0].(int); got != 5 {
t.Error(got)
}
if got := vals[1].(string); got != "a" {
t.Error(got)
}
if got := vals[2].(int); got != 6 {
t.Error(got)
}
if got := vals[3].(string); got != "b" {
t.Error(got)
}
}
func TestGetStructPointers(t *testing.T) {
t.Parallel()
o := struct {
Title string
ID *int
}{
Title: "patrick",
}
ptrs := GetStructPointers(&o, "title", "id")
*ptrs[0].(*string) = "test"
if o.Title != "test" {
t.Errorf("Expected test, got %s", o.Title)
}
x := 5
*ptrs[1].(**int) = &x
if *o.ID != 5 {
t.Errorf("Expected 5, got %d", *o.ID)
} }
} }

121
randomize/random.go Normal file
View file

@ -0,0 +1,121 @@
package randomize
import (
"crypto/md5"
"fmt"
"math/rand"
)
const alphabetAll = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const alphabetLowerAlpha = "abcdefghijklmnopqrstuvwxyz"
func randStr(s *Seed, ln int) string {
str := make([]byte, ln)
for i := 0; i < ln; i++ {
str[i] = byte(alphabetAll[s.nextInt()%len(alphabetAll)])
}
return string(str)
}
func randByteSlice(s *Seed, ln int) []byte {
str := make([]byte, ln)
for i := 0; i < ln; i++ {
str[i] = byte(s.nextInt() % 256)
}
return str
}
func randPoint() string {
a := rand.Intn(100)
b := a + 1
return fmt.Sprintf("(%d,%d)", a, b)
}
func randBox() string {
a := rand.Intn(100)
b := a + 1
c := a + 2
d := a + 3
return fmt.Sprintf("(%d,%d),(%d,%d)", a, b, c, d)
}
func randCircle() string {
a, b, c := rand.Intn(100), rand.Intn(100), rand.Intn(100)
return fmt.Sprintf("((%d,%d),%d)", a, b, c)
}
func randNetAddr() string {
return fmt.Sprintf(
"%d.%d.%d.%d",
rand.Intn(254)+1,
rand.Intn(254)+1,
rand.Intn(254)+1,
rand.Intn(254)+1,
)
}
func randMacAddr() string {
buf := make([]byte, 6)
_, err := rand.Read(buf)
if err != nil {
panic(err)
}
// Set the local bit
buf[0] |= 2
return fmt.Sprintf(
"%02x:%02x:%02x:%02x:%02x:%02x",
buf[0], buf[1], buf[2], buf[3], buf[4], buf[5],
)
}
func randLsn() string {
a := rand.Int63n(9000000)
b := rand.Int63n(9000000)
return fmt.Sprintf("%d/%d", a, b)
}
func randTxID() string {
// Order of integers is relevant
a := rand.Intn(200) + 100
b := a + 100
c := a
d := a + 50
return fmt.Sprintf("%d:%d:%d,%d", a, b, c, d)
}
func randMoney(s *Seed) string {
return fmt.Sprintf("%d.00", s.nextInt())
}
// StableDBName takes a database name in, and generates
// a random string using the database name as the rand Seed.
// getDBNameHash is used to generate unique test database names.
func StableDBName(input string) string {
return randStrFromSource(stableSource(input), 40)
}
// stableSource takes an input value, and produces a random
// seed from it that will produce very few collisions in
// a 40 character random string made from a different alphabet.
func stableSource(input string) *rand.Rand {
sum := md5.Sum([]byte(input))
var seed int64
for i, byt := range sum {
seed ^= int64(byt) << uint((i*4)%64)
}
return rand.New(rand.NewSource(seed))
}
func randStrFromSource(r *rand.Rand, length int) string {
ln := len(alphabetLowerAlpha)
output := make([]rune, length)
for i := 0; i < length; i++ {
output[i] = rune(alphabetLowerAlpha[r.Intn(ln)])
}
return string(output)
}

19
randomize/random_test.go Normal file
View file

@ -0,0 +1,19 @@
package randomize
import "testing"
func TestStableDBName(t *testing.T) {
t.Parallel()
db := "awesomedb"
one, two := StableDBName(db), StableDBName(db)
if len(one) != 40 {
t.Error("want 40 characters:", len(one), one)
}
if one != two {
t.Error("it should always produce the same value")
}
}

View file

@ -2,41 +2,58 @@
package randomize package randomize
import ( import (
"database/sql"
"fmt"
"reflect" "reflect"
"regexp" "regexp"
"sort" "sort"
"strconv" "strconv"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
"gopkg.in/nullbio/null.v4" "gopkg.in/nullbio/null.v5"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/satori/go.uuid" "github.com/satori/go.uuid"
"github.com/vattle/sqlboiler/strmangle" "github.com/vattle/sqlboiler/strmangle"
"github.com/vattle/sqlboiler/types"
) )
var ( var (
typeNullFloat32 = reflect.TypeOf(null.Float32{}) typeNullFloat32 = reflect.TypeOf(null.Float32{})
typeNullFloat64 = reflect.TypeOf(null.Float64{}) typeNullFloat64 = reflect.TypeOf(null.Float64{})
typeNullInt = reflect.TypeOf(null.Int{}) typeNullInt = reflect.TypeOf(null.Int{})
typeNullInt8 = reflect.TypeOf(null.Int8{}) typeNullInt8 = reflect.TypeOf(null.Int8{})
typeNullInt16 = reflect.TypeOf(null.Int16{}) typeNullInt16 = reflect.TypeOf(null.Int16{})
typeNullInt32 = reflect.TypeOf(null.Int32{}) typeNullInt32 = reflect.TypeOf(null.Int32{})
typeNullInt64 = reflect.TypeOf(null.Int64{}) typeNullInt64 = reflect.TypeOf(null.Int64{})
typeNullUint = reflect.TypeOf(null.Uint{}) typeNullUint = reflect.TypeOf(null.Uint{})
typeNullUint8 = reflect.TypeOf(null.Uint8{}) typeNullUint8 = reflect.TypeOf(null.Uint8{})
typeNullUint16 = reflect.TypeOf(null.Uint16{}) typeNullUint16 = reflect.TypeOf(null.Uint16{})
typeNullUint32 = reflect.TypeOf(null.Uint32{}) typeNullUint32 = reflect.TypeOf(null.Uint32{})
typeNullUint64 = reflect.TypeOf(null.Uint64{}) typeNullUint64 = reflect.TypeOf(null.Uint64{})
typeNullString = reflect.TypeOf(null.String{}) typeNullString = reflect.TypeOf(null.String{})
typeNullBool = reflect.TypeOf(null.Bool{}) typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{}) typeNullTime = reflect.TypeOf(null.Time{})
typeTime = reflect.TypeOf(time.Time{}) typeNullBytes = reflect.TypeOf(null.Bytes{})
typeNullJSON = reflect.TypeOf(null.JSON{})
typeTime = reflect.TypeOf(time.Time{})
typeJSON = reflect.TypeOf(types.JSON{})
typeInt64Array = reflect.TypeOf(types.Int64Array{})
typeBytesArray = reflect.TypeOf(types.BytesArray{})
typeBoolArray = reflect.TypeOf(types.BoolArray{})
typeFloat64Array = reflect.TypeOf(types.Float64Array{})
typeStringArray = reflect.TypeOf(types.StringArray{})
typeHStore = reflect.TypeOf(types.HStore{})
rgxValidTime = regexp.MustCompile(`[2-9]+`)
rgxValidTime = regexp.MustCompile(`[2-9]+`) validatedTypes = []string{
"inet", "line", "uuid", "interval",
validatedTypes = []string{"uuid", "interval"} "json", "jsonb", "box", "cidr", "circle",
"lseg", "macaddr", "path", "pg_lsn", "point",
"polygon", "txid_snapshot", "money", "hstore",
}
) )
// Seed is an atomic counter for pseudo-randomization structs. Using full // Seed is an atomic counter for pseudo-randomization structs. Using full
@ -163,7 +180,59 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
field.Set(reflect.ValueOf(value)) field.Set(reflect.ValueOf(value))
return nil return nil
} }
if fieldType == "box" || fieldType == "line" || fieldType == "lseg" ||
fieldType == "path" || fieldType == "polygon" {
value = null.NewString(randBox(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "cidr" || fieldType == "inet" {
value = null.NewString(randNetAddr(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "macaddr" {
value = null.NewString(randMacAddr(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "circle" {
value = null.NewString(randCircle(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "pg_lsn" {
value = null.NewString(randLsn(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "point" {
value = null.NewString(randPoint(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "txid_snapshot" {
value = null.NewString(randTxID(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "money" {
value = null.NewString(randMoney(s), true)
field.Set(reflect.ValueOf(value))
return nil
}
case typeNullJSON:
value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true)
field.Set(reflect.ValueOf(value))
return nil
case typeHStore:
value := types.HStore{}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
field.Set(reflect.ValueOf(value))
return nil
} }
} else { } else {
switch kind { switch kind {
case reflect.String: case reflect.String:
@ -177,6 +246,59 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
field.Set(reflect.ValueOf(value)) field.Set(reflect.ValueOf(value))
return nil return nil
} }
if fieldType == "box" || fieldType == "line" || fieldType == "lseg" ||
fieldType == "path" || fieldType == "polygon" {
value = randBox()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "cidr" || fieldType == "inet" {
value = randNetAddr()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "macaddr" {
value = randMacAddr()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "circle" {
value = randCircle()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "pg_lsn" {
value = randLsn()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "point" {
value = randPoint()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "txid_snapshot" {
value = randTxID()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "money" {
value = randMoney(s)
field.Set(reflect.ValueOf(value))
return nil
}
}
switch typ {
case typeJSON:
value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1)))
field.Set(reflect.ValueOf(value))
return nil
case typeHStore:
value := types.HStore{}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
field.Set(reflect.ValueOf(value))
return nil
} }
} }
} }
@ -191,8 +313,11 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
isNull = false isNull = false
} }
// Retrieve the value to be returned // If it's a Postgres array, treat it like one
if kind == reflect.Struct { if strings.HasPrefix(fieldType, "ARRAY") {
value = getArrayRandValue(s, typ, fieldType)
// Retrieve the value to be returned
} else if kind == reflect.Struct {
if isNull { if isNull {
value = getStructNullValue(typ) value = getStructNullValue(typ)
} else { } else {
@ -215,6 +340,69 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
return nil return nil
} }
func getArrayRandValue(s *Seed, typ reflect.Type, fieldType string) interface{} {
fieldType = strings.TrimLeft(fieldType, "ARRAY")
switch typ {
case typeInt64Array:
return types.Int64Array{int64(s.nextInt()), int64(s.nextInt())}
case typeFloat64Array:
return types.Float64Array{float64(s.nextInt()), float64(s.nextInt())}
case typeBoolArray:
return types.BoolArray{s.nextInt()%2 == 0, s.nextInt()%2 == 0, s.nextInt()%2 == 0}
case typeStringArray:
if fieldType == "interval" {
value := strconv.Itoa((s.nextInt()%26)+2) + " days"
return types.StringArray{value, value}
}
if fieldType == "uuid" {
value := uuid.NewV4().String()
return types.StringArray{value, value}
}
if fieldType == "box" || fieldType == "line" || fieldType == "lseg" ||
fieldType == "path" || fieldType == "polygon" {
value := randBox()
return types.StringArray{value, value}
}
if fieldType == "cidr" || fieldType == "inet" {
value := randNetAddr()
return types.StringArray{value, value}
}
if fieldType == "macaddr" {
value := randMacAddr()
return types.StringArray{value, value}
}
if fieldType == "circle" {
value := randCircle()
return types.StringArray{value, value}
}
if fieldType == "pg_lsn" {
value := randLsn()
return types.StringArray{value, value}
}
if fieldType == "point" {
value := randPoint()
return types.StringArray{value, value}
}
if fieldType == "txid_snapshot" {
value := randTxID()
return types.StringArray{value, value}
}
if fieldType == "money" {
value := randMoney(s)
return types.StringArray{value, value}
}
if fieldType == "json" || fieldType == "jsonb" {
value := []byte(fmt.Sprintf(`"%s"`, randStr(s, 1)))
return types.StringArray{string(value)}
}
return types.StringArray{randStr(s, 4), randStr(s, 4), randStr(s, 4)}
case typeBytesArray:
return types.BytesArray{randByteSlice(s, 4), randByteSlice(s, 4), randByteSlice(s, 4)}
}
return nil
}
// getStructNullValue for the matching type. // getStructNullValue for the matching type.
func getStructNullValue(typ reflect.Type) interface{} { func getStructNullValue(typ reflect.Type) interface{} {
switch typ { switch typ {
@ -250,6 +438,8 @@ func getStructNullValue(typ reflect.Type) interface{} {
return null.NewUint32(0, false) return null.NewUint32(0, false)
case typeNullUint64: case typeNullUint64:
return null.NewUint64(0, false) return null.NewUint64(0, false)
case typeNullBytes:
return null.NewBytes(nil, false)
} }
return nil return nil
@ -292,6 +482,8 @@ func getStructRandValue(s *Seed, typ reflect.Type) interface{} {
return null.NewUint32(uint32(s.nextInt()), true) return null.NewUint32(uint32(s.nextInt()), true)
case typeNullUint64: case typeNullUint64:
return null.NewUint64(uint64(s.nextInt()), true) return null.NewUint64(uint64(s.nextInt()), true)
case typeNullBytes:
return null.NewBytes(randByteSlice(s, 16), true)
} }
return nil return nil
@ -378,23 +570,3 @@ func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interfac
return nil return nil
} }
const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func randStr(s *Seed, ln int) string {
str := make([]byte, ln)
for i := 0; i < ln; i++ {
str[i] = byte(alphabet[s.nextInt()%len(alphabet)])
}
return string(str)
}
func randByteSlice(s *Seed, ln int) []byte {
str := make([]byte, ln)
for i := 0; i < ln; i++ {
str[i] = byte(s.nextInt() % 256)
}
return str
}

View file

@ -5,7 +5,7 @@ import (
"testing" "testing"
"time" "time"
"gopkg.in/nullbio/null.v4" "gopkg.in/nullbio/null.v5"
) )
func TestRandomizeStruct(t *testing.T) { func TestRandomizeStruct(t *testing.T) {

View file

@ -15,7 +15,8 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/vattle/sqlboiler/bdb" "github.com/vattle/sqlboiler/bdb"
"github.com/vattle/sqlboiler/bdb/drivers" "github.com/vattle/sqlboiler/bdb/drivers"
"github.com/vattle/sqlboiler/boil" "github.com/vattle/sqlboiler/queries"
"github.com/vattle/sqlboiler/strmangle"
) )
const ( const (
@ -32,8 +33,9 @@ const (
type State struct { type State struct {
Config *Config Config *Config
Driver bdb.Interface Driver bdb.Interface
Tables []bdb.Table Tables []bdb.Table
Dialect queries.Dialect
Templates *templateList Templates *templateList
TestTemplates *templateList TestTemplates *templateList
@ -59,7 +61,7 @@ func New(config *Config) (*State, error) {
return nil, errors.Wrap(err, "unable to connect to the database") return nil, errors.Wrap(err, "unable to connect to the database")
} }
err = s.initTables(config.ExcludeTables) err = s.initTables(config.Schema, config.WhitelistTables, config.BlacklistTables)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "unable to initialize tables") return nil, errors.Wrap(err, "unable to initialize tables")
} }
@ -69,8 +71,7 @@ func New(config *Config) (*State, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "unable to json marshal tables") return nil, errors.Wrap(err, "unable to json marshal tables")
} }
boil.DebugWriter.Write(b) fmt.Printf("%s\n", b)
fmt.Fprintln(boil.DebugWriter)
} }
err = s.initOutFolder() err = s.initOutFolder()
@ -96,11 +97,15 @@ func New(config *Config) (*State, error) {
func (s *State) Run(includeTests bool) error { func (s *State) Run(includeTests bool) error {
singletonData := &templateData{ singletonData := &templateData{
Tables: s.Tables, Tables: s.Tables,
Schema: s.Config.Schema,
DriverName: s.Config.DriverName, DriverName: s.Config.DriverName,
UseLastInsertID: s.Driver.UseLastInsertID(), UseLastInsertID: s.Driver.UseLastInsertID(),
PkgName: s.Config.PkgName, PkgName: s.Config.PkgName,
NoHooks: s.Config.NoHooks, NoHooks: s.Config.NoHooks,
NoAutoTimestamps: s.Config.NoAutoTimestamps, NoAutoTimestamps: s.Config.NoAutoTimestamps,
Dialect: s.Dialect,
LQ: strmangle.QuoteCharacter(s.Dialect.LQ),
RQ: strmangle.QuoteCharacter(s.Dialect.RQ),
StringFuncs: templateStringMappers, StringFuncs: templateStringMappers,
} }
@ -127,12 +132,16 @@ func (s *State) Run(includeTests bool) error {
data := &templateData{ data := &templateData{
Tables: s.Tables, Tables: s.Tables,
Table: table, Table: table,
Schema: s.Config.Schema,
DriverName: s.Config.DriverName, DriverName: s.Config.DriverName,
UseLastInsertID: s.Driver.UseLastInsertID(), UseLastInsertID: s.Driver.UseLastInsertID(),
PkgName: s.Config.PkgName, PkgName: s.Config.PkgName,
NoHooks: s.Config.NoHooks, NoHooks: s.Config.NoHooks,
NoAutoTimestamps: s.Config.NoAutoTimestamps, NoAutoTimestamps: s.Config.NoAutoTimestamps,
Tags: s.Config.Tags, Tags: s.Config.Tags,
Dialect: s.Dialect,
LQ: strmangle.QuoteCharacter(s.Dialect.LQ),
RQ: strmangle.QuoteCharacter(s.Dialect.RQ),
StringFuncs: templateStringMappers, StringFuncs: templateStringMappers,
} }
@ -227,6 +236,15 @@ func (s *State) initDriver(driverName string) error {
s.Config.Postgres.Port, s.Config.Postgres.Port,
s.Config.Postgres.SSLMode, s.Config.Postgres.SSLMode,
) )
case "mysql":
s.Driver = drivers.NewMySQLDriver(
s.Config.MySQL.User,
s.Config.MySQL.Pass,
s.Config.MySQL.DBName,
s.Config.MySQL.Host,
s.Config.MySQL.Port,
s.Config.MySQL.SSLMode,
)
case "mock": case "mock":
s.Driver = &drivers.MockDriver{} s.Driver = &drivers.MockDriver{}
} }
@ -235,13 +253,17 @@ func (s *State) initDriver(driverName string) error {
return errors.New("An invalid driver name was provided") return errors.New("An invalid driver name was provided")
} }
s.Dialect.LQ = s.Driver.LeftQuote()
s.Dialect.RQ = s.Driver.RightQuote()
s.Dialect.IndexPlaceholders = s.Driver.IndexPlaceholders()
return nil return nil
} }
// initTables retrieves all "public" schema table names from the database. // initTables retrieves all "public" schema table names from the database.
func (s *State) initTables(exclude []string) error { func (s *State) initTables(schema string, whitelist, blacklist []string) error {
var err error var err error
s.Tables, err = bdb.Tables(s.Driver, exclude...) s.Tables, err = bdb.Tables(s.Driver, schema, whitelist, blacklist)
if err != nil { if err != nil {
return errors.Wrap(err, "unable to fetch table data") return errors.Wrap(err, "unable to fetch table data")
} }

View file

@ -37,10 +37,10 @@ func TestNew(t *testing.T) {
}() }()
config := &Config{ config := &Config{
DriverName: "mock", DriverName: "mock",
PkgName: "models", PkgName: "models",
OutFolder: out, OutFolder: out,
ExcludeTables: []string{"hangars"}, BlacklistTables: []string{"hangars"},
} }
state, err = New(config) state, err = New(config)

View file

@ -22,6 +22,7 @@ var uppercaseWords = map[string]struct{}{
"id": {}, "id": {},
"uid": {}, "uid": {},
"uuid": {}, "uuid": {},
"json": {},
} }
func init() { func init() {
@ -33,9 +34,21 @@ func init() {
boilRuleset = newBoilRuleset() boilRuleset = newBoilRuleset()
} }
// SchemaTable returns a table name with a schema prefixed if
// using a database that supports real schemas, for example,
// for Postgres: "schema_name"."table_name", versus
// simply "table_name" for MySQL (because it does not support real schemas)
func SchemaTable(lq, rq string, driver string, schema string, table string) string {
if driver == "postgres" && schema != "public" {
return fmt.Sprintf(`%s%s%s.%s%s%s`, lq, schema, rq, lq, table, rq)
}
return fmt.Sprintf(`%s%s%s`, lq, table, rq)
}
// IdentQuote attempts to quote simple identifiers in SQL tatements // IdentQuote attempts to quote simple identifiers in SQL tatements
func IdentQuote(s string) string { func IdentQuote(lq byte, rq byte, s string) string {
if strings.ToLower(s) == "null" { if strings.ToLower(s) == "null" || s == "?" {
return s return s
} }
@ -52,28 +65,28 @@ func IdentQuote(s string) string {
buf.WriteByte('.') buf.WriteByte('.')
} }
if strings.HasPrefix(split, `"`) || strings.HasSuffix(split, `"`) || split == "*" { if split[0] == lq || split[len(split)-1] == rq || split == "*" {
buf.WriteString(split) buf.WriteString(split)
continue continue
} }
buf.WriteByte('"') buf.WriteByte(lq)
buf.WriteString(split) buf.WriteString(split)
buf.WriteByte('"') buf.WriteByte(rq)
} }
return buf.String() return buf.String()
} }
// IdentQuoteSlice applies IdentQuote to a slice. // IdentQuoteSlice applies IdentQuote to a slice.
func IdentQuoteSlice(s []string) []string { func IdentQuoteSlice(lq byte, rq byte, s []string) []string {
if len(s) == 0 { if len(s) == 0 {
return s return s
} }
strs := make([]string, len(s)) strs := make([]string, len(s))
for i, str := range s { for i, str := range s {
strs[i] = IdentQuote(str) strs[i] = IdentQuote(lq, rq, str)
} }
return strs return strs
@ -105,6 +118,16 @@ func Identifier(in int) string {
return cols.String() return cols.String()
} }
// QuoteCharacter returns a string that allows the quote character
// to be embedded into a Go string that uses double quotes:
func QuoteCharacter(q byte) string {
if q == '"' {
return `\"`
}
return string(q)
}
// Plural converts singular words to plural words (eg: person to people) // Plural converts singular words to plural words (eg: person to people)
func Plural(name string) string { func Plural(name string) string {
buf := GetBuffer() buf := GetBuffer()
@ -368,7 +391,8 @@ func PrefixStringSlice(str string, strs []string) []string {
// Placeholders generates the SQL statement placeholders for in queries. // Placeholders generates the SQL statement placeholders for in queries.
// For example, ($1,$2,$3),($4,$5,$6) etc. // For example, ($1,$2,$3),($4,$5,$6) etc.
// It will start counting placeholders at "start". // It will start counting placeholders at "start".
func Placeholders(count int, start int, group int) string { // If indexPlaceholders is false, it will convert to ? instead of $1 etc.
func Placeholders(indexPlaceholders bool, count int, start int, group int) string {
buf := GetBuffer() buf := GetBuffer()
defer PutBuffer(buf) defer PutBuffer(buf)
@ -387,7 +411,11 @@ func Placeholders(count int, start int, group int) string {
buf.WriteByte(',') buf.WriteByte(',')
} }
} }
buf.WriteString(fmt.Sprintf("$%d", start+i)) if indexPlaceholders {
buf.WriteString(fmt.Sprintf("$%d", start+i))
} else {
buf.WriteByte('?')
}
} }
if group > 1 { if group > 1 {
buf.WriteByte(')') buf.WriteByte(')')
@ -399,14 +427,19 @@ func Placeholders(count int, start int, group int) string {
// SetParamNames takes a slice of columns and returns a comma separated // SetParamNames takes a slice of columns and returns a comma separated
// list of parameter names for a template statement SET clause. // list of parameter names for a template statement SET clause.
// eg: "col1"=$1, "col2"=$2, "col3"=$3 // eg: "col1"=$1, "col2"=$2, "col3"=$3
func SetParamNames(columns []string) string { func SetParamNames(lq, rq string, start int, columns []string) string {
buf := GetBuffer() buf := GetBuffer()
defer PutBuffer(buf) defer PutBuffer(buf)
for i, c := range columns { for i, c := range columns {
buf.WriteString(fmt.Sprintf(`"%s"=$%d`, c, i+1)) if start != 0 {
buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, i+start))
} else {
buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq))
}
if i < len(columns)-1 { if i < len(columns)-1 {
buf.WriteString(", ") buf.WriteByte(',')
} }
} }
@ -415,16 +448,17 @@ func SetParamNames(columns []string) string {
// WhereClause returns the where clause using start as the $ flag index // WhereClause returns the where clause using start as the $ flag index
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
func WhereClause(start int, cols []string) string { func WhereClause(lq, rq string, start int, cols []string) string {
if start == 0 {
panic("0 is not a valid start number for whereClause")
}
buf := GetBuffer() buf := GetBuffer()
defer PutBuffer(buf) defer PutBuffer(buf)
for i, c := range cols { for i, c := range cols {
buf.WriteString(fmt.Sprintf(`"%s"=$%d`, c, start+i)) if start != 0 {
buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, start+i))
} else {
buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq))
}
if i < len(cols)-1 { if i < len(cols)-1 {
buf.WriteString(" AND ") buf.WriteString(" AND ")
} }

View file

@ -29,7 +29,7 @@ func TestIdentQuote(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
if got := IdentQuote(test.In); got != test.Out { if got := IdentQuote('"', '"', test.In); got != test.Out {
t.Errorf("want: %s, got: %s", test.Out, got) t.Errorf("want: %s, got: %s", test.Out, got)
} }
} }
@ -38,7 +38,7 @@ func TestIdentQuote(t *testing.T) {
func TestIdentQuoteSlice(t *testing.T) { func TestIdentQuoteSlice(t *testing.T) {
t.Parallel() t.Parallel()
ret := IdentQuoteSlice([]string{`thing`, `null`}) ret := IdentQuoteSlice('"', '"', []string{`thing`, `null`})
if ret[0] != `"thing"` { if ret[0] != `"thing"` {
t.Error(ret[0]) t.Error(ret[0])
} }
@ -69,34 +69,60 @@ func TestIdentifier(t *testing.T) {
} }
} }
func TestQuoteCharacter(t *testing.T) {
t.Parallel()
if QuoteCharacter('[') != "[" {
t.Error("want just the normal quote character")
}
if QuoteCharacter('`') != "`" {
t.Error("want just the normal quote character")
}
if QuoteCharacter('"') != `\"` {
t.Error("want an escaped character")
}
}
func TestPlaceholders(t *testing.T) { func TestPlaceholders(t *testing.T) {
t.Parallel() t.Parallel()
x := Placeholders(1, 2, 1) x := Placeholders(true, 1, 2, 1)
want := "$2" want := "$2"
if want != x { if want != x {
t.Errorf("want %s, got %s", want, x) t.Errorf("want %s, got %s", want, x)
} }
x = Placeholders(5, 1, 1) x = Placeholders(true, 5, 1, 1)
want = "$1,$2,$3,$4,$5" want = "$1,$2,$3,$4,$5"
if want != x { if want != x {
t.Errorf("want %s, got %s", want, x) t.Errorf("want %s, got %s", want, x)
} }
x = Placeholders(6, 1, 2) x = Placeholders(false, 5, 1, 1)
want = "?,?,?,?,?"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = Placeholders(true, 6, 1, 2)
want = "($1,$2),($3,$4),($5,$6)" want = "($1,$2),($3,$4),($5,$6)"
if want != x { if want != x {
t.Errorf("want %s, got %s", want, x) t.Errorf("want %s, got %s", want, x)
} }
x = Placeholders(9, 1, 3) x = Placeholders(true, 6, 1, 2)
want = "($1,$2,$3),($4,$5,$6),($7,$8,$9)" want = "($1,$2),($3,$4),($5,$6)"
if want != x { if want != x {
t.Errorf("want %s, got %s", want, x) t.Errorf("want %s, got %s", want, x)
} }
x = Placeholders(7, 1, 3) x = Placeholders(false, 9, 1, 3)
want = "(?,?,?),(?,?,?),(?,?,?)"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = Placeholders(true, 7, 1, 3)
want = "($1,$2,$3),($4,$5,$6),($7)" want = "($1,$2,$3),($4,$5,$6),($7)"
if want != x { if want != x {
t.Errorf("want %s, got %s", want, x) t.Errorf("want %s, got %s", want, x)
@ -291,6 +317,28 @@ func TestPrefixStringSlice(t *testing.T) {
} }
} }
func TestSetParamNames(t *testing.T) {
t.Parallel()
tests := []struct {
Cols []string
Start int
Should string
}{
{Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=?,"col2"=?`},
{Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`},
{Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4,"col2"=$5`},
{Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4,"col2"=$5,"col3"=$6`},
}
for i, test := range tests {
r := SetParamNames(`"`, `"`, test.Start, test.Cols)
if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
}
}
}
func TestWhereClause(t *testing.T) { func TestWhereClause(t *testing.T) {
t.Parallel() t.Parallel()
@ -299,13 +347,14 @@ func TestWhereClause(t *testing.T) {
Start int Start int
Should string Should string
}{ }{
{Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=? AND "col2"=?`},
{Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`}, {Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`},
{Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5`}, {Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5`},
{Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5 AND "col3"=$6`}, {Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5 AND "col3"=$6`},
} }
for i, test := range tests { for i, test := range tests {
r := WhereClause(test.Start, test.Cols) r := WhereClause(`"`, `"`, test.Start, test.Cols)
if r != test.Should { if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
} }

View file

@ -8,21 +8,45 @@ import (
"text/template" "text/template"
"github.com/vattle/sqlboiler/bdb" "github.com/vattle/sqlboiler/bdb"
"github.com/vattle/sqlboiler/queries"
"github.com/vattle/sqlboiler/strmangle" "github.com/vattle/sqlboiler/strmangle"
) )
// templateData for sqlboiler templates // templateData for sqlboiler templates
type templateData struct { type templateData struct {
Tables []bdb.Table Tables []bdb.Table
Table bdb.Table Table bdb.Table
DriverName string
UseLastInsertID bool // Controls what names are output
PkgName string PkgName string
Schema string
// Controls which code is output (mysql vs postgres ...)
DriverName string
UseLastInsertID bool
// Turn off auto timestamps or hook generation
NoHooks bool NoHooks bool
NoAutoTimestamps bool NoAutoTimestamps bool
Tags []string
// Tags control which
Tags []string
// StringFuncs are usable in templates with stringMap
StringFuncs map[string]func(string) string StringFuncs map[string]func(string) string
// Dialect controls quoting
Dialect queries.Dialect
LQ string
RQ string
}
func (t templateData) Quotes(s string) string {
return fmt.Sprintf("%s%s%s", t.LQ, s, t.RQ)
}
func (t templateData) SchemaTable(table string) string {
return strmangle.SchemaTable(t.LQ, t.RQ, t.DriverName, t.Schema, table)
} }
type templateList struct { type templateList struct {
@ -113,7 +137,7 @@ var templateStringMappers = map[string]func(string) string{
// add a function pointer here. // add a function pointer here.
var templateFunctions = template.FuncMap{ var templateFunctions = template.FuncMap{
// String ops // String ops
"quoteWrap": func(a string) string { return fmt.Sprintf(`"%s"`, a) }, "quoteWrap": func(s string) string { return fmt.Sprintf(`"%s"`, s) },
"id": strmangle.Identifier, "id": strmangle.Identifier,
// Pluralization // Pluralization
@ -150,6 +174,7 @@ var templateFunctions = template.FuncMap{
// dbdrivers ops // dbdrivers ops
"filterColumnsByDefault": bdb.FilterColumnsByDefault, "filterColumnsByDefault": bdb.FilterColumnsByDefault,
"autoIncPrimaryKey": bdb.AutoIncPrimaryKey,
"sqlColDefinitions": bdb.SQLColDefinitions, "sqlColDefinitions": bdb.SQLColDefinitions,
"columnNames": bdb.ColumnNames, "columnNames": bdb.ColumnNames,
"columnDBTypes": bdb.ColumnDBTypes, "columnDBTypes": bdb.ColumnDBTypes,

View file

@ -1,5 +1,5 @@
{{- define "relationship_to_one_struct_helper" -}} {{- define "relationship_to_one_struct_helper" -}}
{{.Function.Name}} *{{.ForeignTable.NameGo}} {{.Function.Name}} *{{.ForeignTable.NameGo}}
{{- end -}} {{- end -}}
{{- $dot := . -}} {{- $dot := . -}}
@ -8,30 +8,30 @@
{{- $modelNameCamel := $tableNameSingular | camelCase -}} {{- $modelNameCamel := $tableNameSingular | camelCase -}}
// {{$modelName}} is an object representing the database table. // {{$modelName}} is an object representing the database table.
type {{$modelName}} struct { type {{$modelName}} struct {
{{range $column := .Table.Columns -}} {{range $column := .Table.Columns -}}
{{titleCase $column.Name}} {{$column.Type}} `{{generateTags $dot.Tags $column.Name}}boil:"{{$column.Name}}" json:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}" toml:"{{$column.Name}}" yaml:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}"` {{titleCase $column.Name}} {{$column.Type}} `{{generateTags $dot.Tags $column.Name}}boil:"{{$column.Name}}" json:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}" toml:"{{$column.Name}}" yaml:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}"`
{{end -}} {{end -}}
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else}} {{- else}}
R *{{$modelNameCamel}}R `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` R *{{$modelNameCamel}}R `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"`
L {{$modelNameCamel}}L `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` L {{$modelNameCamel}}L `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"`
{{end -}} {{end -}}
} }
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else}} {{- else}}
// {{$modelNameCamel}}R is where relationships are stored. // {{$modelNameCamel}}R is where relationships are stored.
type {{$modelNameCamel}}R struct { type {{$modelNameCamel}}R struct {
{{range .Table.FKeys -}} {{range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_struct_helper" $rel}} {{- template "relationship_to_one_struct_helper" $rel}}
{{end -}} {{end -}}
{{- range .Table.ToManyRelationships -}} {{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_struct_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table .)}} {{- template "relationship_to_one_struct_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table .)}}
{{else -}} {{else -}}
{{- $rel := textsFromRelationship $dot.Tables $dot.Table . -}} {{- $rel := textsFromRelationship $dot.Tables $dot.Table . -}}
{{$rel.Function.Name}} {{$rel.ForeignTable.Slice}} {{$rel.Function.Name}} {{$rel.ForeignTable.Slice}}
{{end -}}{{/* if ForeignColumnUnique */}} {{end -}}{{/* if ForeignColumnUnique */}}
{{- end -}}{{/* range tomany */}} {{- end -}}{{/* range tomany */}}
} }

View file

@ -3,31 +3,36 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}}
var ( var (
{{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} {{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}}
{{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} {{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}}
{{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}}
{{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} {{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}}
) )
type ( type (
{{$tableNameSingular}}Slice []*{{$tableNameSingular}} // {{$tableNameSingular}}Slice is an alias for a slice of pointers to {{$tableNameSingular}}.
{{if eq .NoHooks false -}} // This should generally be used opposed to []{{$tableNameSingular}}.
{{$tableNameSingular}}Hook func(boil.Executor, *{{$tableNameSingular}}) error {{$tableNameSingular}}Slice []*{{$tableNameSingular}}
{{- end}} {{if eq .NoHooks false -}}
// {{$tableNameSingular}}Hook is the signature for custom {{$tableNameSingular}} hook methods
{{$tableNameSingular}}Hook func(boil.Executor, *{{$tableNameSingular}}) error
{{- end}}
{{$varNameSingular}}Query struct { {{$varNameSingular}}Query struct {
*boil.Query *queries.Query
} }
) )
// Cache for insert and update // Cache for insert, update and upsert
var ( var (
{{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{}) {{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{})
{{$varNameSingular}}Mapping = boil.MakeStructMapping({{$varNameSingular}}Type) {{$varNameSingular}}Mapping = queries.MakeStructMapping({{$varNameSingular}}Type)
{{$varNameSingular}}InsertCacheMut sync.RWMutex {{$varNameSingular}}InsertCacheMut sync.RWMutex
{{$varNameSingular}}InsertCache = make(map[string]insertCache) {{$varNameSingular}}InsertCache = make(map[string]insertCache)
{{$varNameSingular}}UpdateCacheMut sync.RWMutex {{$varNameSingular}}UpdateCacheMut sync.RWMutex
{{$varNameSingular}}UpdateCache = make(map[string]updateCache) {{$varNameSingular}}UpdateCache = make(map[string]updateCache)
{{$varNameSingular}}UpsertCacheMut sync.RWMutex
{{$varNameSingular}}UpsertCache = make(map[string]insertCache)
) )
// Force time package dependency for automated UpdatedAt/CreatedAt. // Force time package dependency for automated UpdatedAt/CreatedAt.

View file

@ -14,123 +14,124 @@ var {{$varNameSingular}}AfterUpsertHooks []{{$tableNameSingular}}Hook
// doBeforeInsertHooks executes all "before insert" hooks. // doBeforeInsertHooks executes all "before insert" hooks.
func (o *{{$tableNameSingular}}) doBeforeInsertHooks(exec boil.Executor) (err error) { func (o *{{$tableNameSingular}}) doBeforeInsertHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}BeforeInsertHooks { for _, hook := range {{$varNameSingular}}BeforeInsertHooks {
if err := hook(exec, o); err != nil { if err := hook(exec, o); err != nil {
return err return err
} }
} }
return nil return nil
} }
// doBeforeUpdateHooks executes all "before Update" hooks. // doBeforeUpdateHooks executes all "before Update" hooks.
func (o *{{$tableNameSingular}}) doBeforeUpdateHooks(exec boil.Executor) (err error) { func (o *{{$tableNameSingular}}) doBeforeUpdateHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}BeforeUpdateHooks { for _, hook := range {{$varNameSingular}}BeforeUpdateHooks {
if err := hook(exec, o); err != nil { if err := hook(exec, o); err != nil {
return err return err
} }
} }
return nil return nil
} }
// doBeforeDeleteHooks executes all "before Delete" hooks. // doBeforeDeleteHooks executes all "before Delete" hooks.
func (o *{{$tableNameSingular}}) doBeforeDeleteHooks(exec boil.Executor) (err error) { func (o *{{$tableNameSingular}}) doBeforeDeleteHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}BeforeDeleteHooks { for _, hook := range {{$varNameSingular}}BeforeDeleteHooks {
if err := hook(exec, o); err != nil { if err := hook(exec, o); err != nil {
return err return err
} }
} }
return nil return nil
} }
// doBeforeUpsertHooks executes all "before Upsert" hooks. // doBeforeUpsertHooks executes all "before Upsert" hooks.
func (o *{{$tableNameSingular}}) doBeforeUpsertHooks(exec boil.Executor) (err error) { func (o *{{$tableNameSingular}}) doBeforeUpsertHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}BeforeUpsertHooks { for _, hook := range {{$varNameSingular}}BeforeUpsertHooks {
if err := hook(exec, o); err != nil { if err := hook(exec, o); err != nil {
return err return err
} }
} }
return nil return nil
} }
// doAfterInsertHooks executes all "after Insert" hooks. // doAfterInsertHooks executes all "after Insert" hooks.
func (o *{{$tableNameSingular}}) doAfterInsertHooks(exec boil.Executor) (err error) { func (o *{{$tableNameSingular}}) doAfterInsertHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterInsertHooks { for _, hook := range {{$varNameSingular}}AfterInsertHooks {
if err := hook(exec, o); err != nil { if err := hook(exec, o); err != nil {
return err return err
} }
} }
return nil return nil
} }
// doAfterSelectHooks executes all "after Select" hooks. // doAfterSelectHooks executes all "after Select" hooks.
func (o *{{$tableNameSingular}}) doAfterSelectHooks(exec boil.Executor) (err error) { func (o *{{$tableNameSingular}}) doAfterSelectHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterSelectHooks { for _, hook := range {{$varNameSingular}}AfterSelectHooks {
if err := hook(exec, o); err != nil { if err := hook(exec, o); err != nil {
return err return err
} }
} }
return nil return nil
} }
// doAfterUpdateHooks executes all "after Update" hooks. // doAfterUpdateHooks executes all "after Update" hooks.
func (o *{{$tableNameSingular}}) doAfterUpdateHooks(exec boil.Executor) (err error) { func (o *{{$tableNameSingular}}) doAfterUpdateHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterUpdateHooks { for _, hook := range {{$varNameSingular}}AfterUpdateHooks {
if err := hook(exec, o); err != nil { if err := hook(exec, o); err != nil {
return err return err
} }
} }
return nil return nil
} }
// doAfterDeleteHooks executes all "after Delete" hooks. // doAfterDeleteHooks executes all "after Delete" hooks.
func (o *{{$tableNameSingular}}) doAfterDeleteHooks(exec boil.Executor) (err error) { func (o *{{$tableNameSingular}}) doAfterDeleteHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterDeleteHooks { for _, hook := range {{$varNameSingular}}AfterDeleteHooks {
if err := hook(exec, o); err != nil { if err := hook(exec, o); err != nil {
return err return err
} }
} }
return nil return nil
} }
// doAfterUpsertHooks executes all "after Upsert" hooks. // doAfterUpsertHooks executes all "after Upsert" hooks.
func (o *{{$tableNameSingular}}) doAfterUpsertHooks(exec boil.Executor) (err error) { func (o *{{$tableNameSingular}}) doAfterUpsertHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterUpsertHooks { for _, hook := range {{$varNameSingular}}AfterUpsertHooks {
if err := hook(exec, o); err != nil { if err := hook(exec, o); err != nil {
return err return err
} }
} }
return nil return nil
} }
// Add{{$tableNameSingular}}Hook registers your hook function for all future operations.
func Add{{$tableNameSingular}}Hook(hookPoint boil.HookPoint, {{$varNameSingular}}Hook {{$tableNameSingular}}Hook) { func Add{{$tableNameSingular}}Hook(hookPoint boil.HookPoint, {{$varNameSingular}}Hook {{$tableNameSingular}}Hook) {
switch hookPoint { switch hookPoint {
case boil.BeforeInsertHook: case boil.BeforeInsertHook:
{{$varNameSingular}}BeforeInsertHooks = append({{$varNameSingular}}BeforeInsertHooks, {{$varNameSingular}}Hook) {{$varNameSingular}}BeforeInsertHooks = append({{$varNameSingular}}BeforeInsertHooks, {{$varNameSingular}}Hook)
case boil.BeforeUpdateHook: case boil.BeforeUpdateHook:
{{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook) {{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook)
case boil.BeforeDeleteHook: case boil.BeforeDeleteHook:
{{$varNameSingular}}BeforeDeleteHooks = append({{$varNameSingular}}BeforeDeleteHooks, {{$varNameSingular}}Hook) {{$varNameSingular}}BeforeDeleteHooks = append({{$varNameSingular}}BeforeDeleteHooks, {{$varNameSingular}}Hook)
case boil.BeforeUpsertHook: case boil.BeforeUpsertHook:
{{$varNameSingular}}BeforeUpsertHooks = append({{$varNameSingular}}BeforeUpsertHooks, {{$varNameSingular}}Hook) {{$varNameSingular}}BeforeUpsertHooks = append({{$varNameSingular}}BeforeUpsertHooks, {{$varNameSingular}}Hook)
case boil.AfterInsertHook: case boil.AfterInsertHook:
{{$varNameSingular}}AfterInsertHooks = append({{$varNameSingular}}AfterInsertHooks, {{$varNameSingular}}Hook) {{$varNameSingular}}AfterInsertHooks = append({{$varNameSingular}}AfterInsertHooks, {{$varNameSingular}}Hook)
case boil.AfterSelectHook: case boil.AfterSelectHook:
{{$varNameSingular}}AfterSelectHooks = append({{$varNameSingular}}AfterSelectHooks, {{$varNameSingular}}Hook) {{$varNameSingular}}AfterSelectHooks = append({{$varNameSingular}}AfterSelectHooks, {{$varNameSingular}}Hook)
case boil.AfterUpdateHook: case boil.AfterUpdateHook:
{{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook) {{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook)
case boil.AfterDeleteHook: case boil.AfterDeleteHook:
{{$varNameSingular}}AfterDeleteHooks = append({{$varNameSingular}}AfterDeleteHooks, {{$varNameSingular}}Hook) {{$varNameSingular}}AfterDeleteHooks = append({{$varNameSingular}}AfterDeleteHooks, {{$varNameSingular}}Hook)
case boil.AfterUpsertHook: case boil.AfterUpsertHook:
{{$varNameSingular}}AfterUpsertHooks = append({{$varNameSingular}}AfterUpsertHooks, {{$varNameSingular}}Hook) {{$varNameSingular}}AfterUpsertHooks = append({{$varNameSingular}}AfterUpsertHooks, {{$varNameSingular}}Hook)
} }
} }
{{- end}} {{- end}}

View file

@ -2,114 +2,115 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
// OneP returns a single {{$varNameSingular}} record from the query, and panics on error. // OneP returns a single {{$varNameSingular}} record from the query, and panics on error.
func (q {{$varNameSingular}}Query) OneP() (*{{$tableNameSingular}}) { func (q {{$varNameSingular}}Query) OneP() (*{{$tableNameSingular}}) {
o, err := q.One() o, err := q.One()
if err != nil { if err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
return o return o
} }
// One returns a single {{$varNameSingular}} record from the query. // One returns a single {{$varNameSingular}} record from the query.
func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) { func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) {
o := &{{$tableNameSingular}}{} o := &{{$tableNameSingular}}{}
boil.SetLimit(q.Query, 1) queries.SetLimit(q.Query, 1)
err := q.Bind(o) err := q.Bind(o)
if err != nil { if err != nil {
if errors.Cause(err) == sql.ErrNoRows { if errors.Cause(err) == sql.ErrNoRows {
return nil, sql.ErrNoRows return nil, sql.ErrNoRows
} }
return nil, errors.Wrap(err, "{{.PkgName}}: failed to execute a one query for {{.Table.Name}}") return nil, errors.Wrap(err, "{{.PkgName}}: failed to execute a one query for {{.Table.Name}}")
} }
{{if not .NoHooks -}} {{if not .NoHooks -}}
if err := o.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { if err := o.doAfterSelectHooks(queries.GetExecutor(q.Query)); err != nil {
return o, err return o, err
} }
{{- end}} {{- end}}
return o, nil return o, nil
} }
// AllP returns all {{$tableNameSingular}} records from the query, and panics on error. // AllP returns all {{$tableNameSingular}} records from the query, and panics on error.
func (q {{$varNameSingular}}Query) AllP() {{$tableNameSingular}}Slice { func (q {{$varNameSingular}}Query) AllP() {{$tableNameSingular}}Slice {
o, err := q.All() o, err := q.All()
if err != nil { if err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
return o return o
} }
// All returns all {{$tableNameSingular}} records from the query. // All returns all {{$tableNameSingular}} records from the query.
func (q {{$varNameSingular}}Query) All() ({{$tableNameSingular}}Slice, error) { func (q {{$varNameSingular}}Query) All() ({{$tableNameSingular}}Slice, error) {
var o {{$tableNameSingular}}Slice var o {{$tableNameSingular}}Slice
err := q.Bind(&o) err := q.Bind(&o)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice") return nil, errors.Wrap(err, "{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice")
} }
{{if not .NoHooks -}} {{if not .NoHooks -}}
if len({{$varNameSingular}}AfterSelectHooks) != 0 { if len({{$varNameSingular}}AfterSelectHooks) != 0 {
for _, obj := range o { for _, obj := range o {
if err := obj.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { if err := obj.doAfterSelectHooks(queries.GetExecutor(q.Query)); err != nil {
return o, err return o, err
} }
} }
} }
{{- end}} {{- end}}
return o, nil return o, nil
} }
// CountP returns the count of all {{$tableNameSingular}} records in the query, and panics on error. // CountP returns the count of all {{$tableNameSingular}} records in the query, and panics on error.
func (q {{$varNameSingular}}Query) CountP() int64 { func (q {{$varNameSingular}}Query) CountP() int64 {
c, err := q.Count() c, err := q.Count()
if err != nil { if err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
return c return c
} }
// Count returns the count of all {{$tableNameSingular}} records in the query. // Count returns the count of all {{$tableNameSingular}} records in the query.
func (q {{$varNameSingular}}Query) Count() (int64, error) { func (q {{$varNameSingular}}Query) Count() (int64, error) {
var count int64 var count int64
boil.SetCount(q.Query) queries.SetSelect(q.Query, nil)
queries.SetCount(q.Query)
err := boil.ExecQueryOne(q.Query).Scan(&count) err := q.Query.QueryRow().Scan(&count)
if err != nil { if err != nil {
return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows") return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows")
} }
return count, nil return count, nil
} }
// Exists checks if the row exists in the table, and panics on error. // Exists checks if the row exists in the table, and panics on error.
func (q {{$varNameSingular}}Query) ExistsP() bool { func (q {{$varNameSingular}}Query) ExistsP() bool {
e, err := q.Exists() e, err := q.Exists()
if err != nil { if err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
return e return e
} }
// Exists checks if the row exists in the table. // Exists checks if the row exists in the table.
func (q {{$varNameSingular}}Query) Exists() (bool, error) { func (q {{$varNameSingular}}Query) Exists() (bool, error) {
var count int64 var count int64
boil.SetCount(q.Query) queries.SetCount(q.Query)
boil.SetLimit(q.Query, 1) queries.SetLimit(q.Query, 1)
err := boil.ExecQueryOne(q.Query).Scan(&count) err := q.Query.QueryRow().Scan(&count)
if err != nil { if err != nil {
return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists") return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists")
} }
return count > 0, nil return count > 0, nil
} }

View file

@ -1,30 +1,34 @@
{{- define "relationship_to_one_helper" -}} {{- define "relationship_to_one_helper" -}}
{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}}
{{- with .Rel -}}{{/* Rel holds the text helper data, passed in through preserveDot */}}
{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}}
// {{.Function.Name}}G pointed to by the foreign key. // {{.Function.Name}}G pointed to by the foreign key.
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query {
return {{.Function.Receiver}}.{{.Function.Name}}(boil.GetDB(), mods...) return {{.Function.Receiver}}.{{.Function.Name}}(boil.GetDB(), mods...)
} }
// {{.Function.Name}} pointed to by the foreign key. // {{.Function.Name}} pointed to by the foreign key.
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) ({{$varNameSingular}}Query) { func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) ({{$varNameSingular}}Query) {
queryMods := []qm.QueryMod{ queryMods := []qm.QueryMod{
qm.Where("{{.ForeignTable.ColumnName}}=$1", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}), qm.Where("{{.ForeignTable.ColumnName}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}),
} }
queryMods = append(queryMods, mods...) queryMods = append(queryMods, mods...)
query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...)
boil.SetFrom(query.Query, "{{.ForeignTable.Name}}") queries.SetFrom(query.Query, "{{.ForeignTable.Name | $dot.SchemaTable}}")
return query return query
} }
{{- end -}}{{/* end with */}}
{{end -}}{{/* end define */}}
{{end -}} {{- /* Begin execution of template for one-to-one relationship */ -}}
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . -}} {{- $dot := . -}}
{{- range .Table.FKeys -}} {{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_helper" $rel -}} {{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}}
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}

View file

@ -1,46 +1,51 @@
{{- /* Begin execution of template for many-to-one or many-to-many relationship helper */ -}}
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . -}} {{- $dot := . -}}
{{- $table := .Table -}} {{- $table := .Table -}}
{{- range .Table.ToManyRelationships -}} {{- range .Table.ToManyRelationships -}}
{{- $varNameSingular := .ForeignTable | singular | camelCase -}} {{- $varNameSingular := .ForeignTable | singular | camelCase -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} {{- /* Begin execution of template for many-to-one relationship. */ -}}
{{- else -}} {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}} {{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}}
{{- else -}}
{{- /* Begin execution of template for many-to-many relationship. */ -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}}
{{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}}
// {{$rel.Function.Name}}G retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} // {{$rel.Function.Name}}G retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}}
{{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query {
return {{$rel.Function.Receiver}}.{{$rel.Function.Name}}(boil.GetDB(), mods...) return {{$rel.Function.Receiver}}.{{$rel.Function.Name}}(boil.GetDB(), mods...)
} }
// {{$rel.Function.Name}} retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} with an executor // {{$rel.Function.Name}} retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} with an executor
{{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query {
queryMods := []qm.QueryMod{ queryMods := []qm.QueryMod{
qm.Select(`"{{id 0}}".*`), qm.Select("{{id 0 | $dot.Quotes}}.*"),
} }
if len(mods) != 0 { if len(mods) != 0 {
queryMods = append(queryMods, mods...) queryMods = append(queryMods, mods...)
} }
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
queryMods = append(queryMods, queryMods = append(queryMods,
qm.InnerJoin(`"{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}"`), qm.InnerJoin("{{.JoinTable | $dot.SchemaTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}}"),
qm.Where(`"{{id 1}}"."{{.JoinLocalColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}),
) )
{{else -}} {{else -}}
queryMods = append(queryMods, queryMods = append(queryMods,
qm.Where(`"{{id 0}}"."{{.ForeignColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}),
) )
{{end}} {{end}}
query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...) query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...)
boil.SetFrom(query.Query, `"{{.ForeignTable}}" as "{{id 0}}"`) queries.SetFrom(query.Query, "{{$schemaForeignTable}} as {{id 0 | $dot.Quotes}}")
return query return query
} }
{{end -}}{{- /* if unique foreign key */ -}} {{end -}}{{- /* if unique foreign key */ -}}
{{- end -}}{{- /* range relationships */ -}} {{- end -}}{{- /* range relationships */ -}}
{{- end -}}{{- /* outer if join table */ -}} {{- end -}}{{- /* if isJoinTable */ -}}

View file

@ -1,91 +1,93 @@
{{- define "relationship_to_one_eager_helper" -}} {{- define "relationship_to_one_eager_helper" -}}
{{- $varNameSingular := .Dot.Table.Name | singular | camelCase -}} {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}}
{{- $noHooks := .Dot.NoHooks -}} {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}}
{{- with .Rel -}} {{- with .Rel -}}
{{- $arg := printf "maybe%s" .LocalTable.NameGo -}} {{- $arg := printf "maybe%s" .LocalTable.NameGo -}}
{{- $slice := printf "%sSlice" .LocalTable.NameGo -}} {{- $slice := printf "%sSlice" .LocalTable.NameGo -}}
// Load{{.Function.Name}} allows an eager lookup of values, cached into the // Load{{.Function.Name}} allows an eager lookup of values, cached into the
// loaded structs of the objects. // loaded structs of the objects.
func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error {
var slice []*{{.LocalTable.NameGo}} var slice []*{{.LocalTable.NameGo}}
var object *{{.LocalTable.NameGo}} var object *{{.LocalTable.NameGo}}
count := 1 count := 1
if singular { if singular {
object = {{$arg}}.(*{{.LocalTable.NameGo}}) object = {{$arg}}.(*{{.LocalTable.NameGo}})
} else { } else {
slice = *{{$arg}}.(*{{$slice}}) slice = *{{$arg}}.(*{{$slice}})
count = len(slice) count = len(slice)
} }
args := make([]interface{}, count) args := make([]interface{}, count)
if singular { if singular {
args[0] = object.{{.LocalTable.ColumnNameGo}} args[0] = object.{{.LocalTable.ColumnNameGo}}
} else { } else {
for i, obj := range slice { for i, obj := range slice {
args[i] = obj.{{.LocalTable.ColumnNameGo}} args[i] = obj.{{.LocalTable.ColumnNameGo}}
} }
} }
query := fmt.Sprintf( query := fmt.Sprintf(
`select * from "{{.ForeignKey.ForeignTable}}" where "{{.ForeignKey.ForeignColumn}}" in (%s)`, "select * from {{.ForeignKey.ForeignTable | $dot.SchemaTable}} where {{.ForeignKey.ForeignColumn | $dot.Quotes}} in (%s)",
strmangle.Placeholders(count, 1, 1), strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1),
) )
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args)
} }
results, err := e.Query(query, args...) results, err := e.Query(query, args...)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to eager load {{.ForeignTable.NameGo}}") return errors.Wrap(err, "failed to eager load {{.ForeignTable.NameGo}}")
} }
defer results.Close() defer results.Close()
var resultSlice []*{{.ForeignTable.NameGo}} var resultSlice []*{{.ForeignTable.NameGo}}
if err = boil.Bind(results, &resultSlice); err != nil { if err = queries.Bind(results, &resultSlice); err != nil {
return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}") return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}")
} }
{{if not $noHooks -}} {{if not $dot.NoHooks -}}
if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 { if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 {
for _, obj := range resultSlice { for _, obj := range resultSlice {
if err := obj.doAfterSelectHooks(e); err != nil { if err := obj.doAfterSelectHooks(e); err != nil {
return err return err
} }
} }
} }
{{- end}} {{- end}}
if singular && len(resultSlice) != 0 { if singular && len(resultSlice) != 0 {
if object.R == nil { if object.R == nil {
object.R = &{{$varNameSingular}}R{} object.R = &{{$varNameSingular}}R{}
} }
object.R.{{.Function.Name}} = resultSlice[0] object.R.{{.Function.Name}} = resultSlice[0]
return nil return nil
} }
for _, foreign := range resultSlice { for _, foreign := range resultSlice {
for _, local := range slice { for _, local := range slice {
if local.{{.Function.LocalAssignment}} == foreign.{{.Function.ForeignAssignment}} { if local.{{.Function.LocalAssignment}} == foreign.{{.Function.ForeignAssignment}} {
if local.R == nil { if local.R == nil {
local.R = &{{$varNameSingular}}R{} local.R = &{{$varNameSingular}}R{}
} }
local.R.{{.Function.Name}} = foreign local.R.{{.Function.Name}} = foreign
break break
} }
} }
} }
return nil return nil
} }
{{- end -}} {{- end -}}{{- /* end with */ -}}
{{end -}} {{end -}}{{- /* end define */ -}}
{{- /* Begin execution of template for one-to-one eager load */ -}}
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . -}} {{- $dot := . -}}
{{- range .Table.FKeys -}} {{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_eager_helper" (preserveDot $dot $rel) -}} {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}}
{{- end -}} {{- end -}}
{{end}} {{end}}

View file

@ -1,136 +1,141 @@
{{- /* Begin execution of template for many-to-one or many-to-many eager load */ -}}
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . -}} {{- $dot := . -}}
{{- range .Table.ToManyRelationships -}} {{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}} {{- /* Begin execution of template for many-to-one eager load */ -}}
{{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}}
{{- else -}} {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}}
{{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} {{- else -}}
{{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}} {{- /* Begin execution of template for many-to-many eager load */ -}}
{{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}} {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}}
{{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}} {{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}}
{{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}}
{{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}}
{{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}}
// Load{{$txt.Function.Name}} allows an eager lookup of values, cached into the // Load{{$txt.Function.Name}} allows an eager lookup of values, cached into the
// loaded structs of the objects. // loaded structs of the objects.
func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error {
var slice []*{{$txt.LocalTable.NameGo}} var slice []*{{$txt.LocalTable.NameGo}}
var object *{{$txt.LocalTable.NameGo}} var object *{{$txt.LocalTable.NameGo}}
count := 1 count := 1
if singular { if singular {
object = {{$arg}}.(*{{$txt.LocalTable.NameGo}}) object = {{$arg}}.(*{{$txt.LocalTable.NameGo}})
} else { } else {
slice = *{{$arg}}.(*{{$slice}}) slice = *{{$arg}}.(*{{$slice}})
count = len(slice) count = len(slice)
} }
args := make([]interface{}, count) args := make([]interface{}, count)
if singular { if singular {
args[0] = object.{{.Column | titleCase}} args[0] = object.{{.Column | titleCase}}
} else { } else {
for i, obj := range slice { for i, obj := range slice {
args[i] = obj.{{.Column | titleCase}} args[i] = obj.{{.Column | titleCase}}
} }
} }
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
query := fmt.Sprintf( {{- $schemaJoinTable := .JoinTable | $dot.SchemaTable -}}
`select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from "{{.ForeignTable}}" as "{{id 0}}" inner join "{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, query := fmt.Sprintf(
strmangle.Placeholders(count, 1, 1), "select {{id 0 | $dot.Quotes}}.*, {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} from {{$schemaForeignTable}} as {{id 0 | $dot.Quotes}} inner join {{$schemaJoinTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}} where {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} in (%s)",
) strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1),
{{else -}} )
query := fmt.Sprintf( {{else -}}
`select * from "{{.ForeignTable}}" where "{{.ForeignColumn}}" in (%s)`, query := fmt.Sprintf(
strmangle.Placeholders(count, 1, 1), "select * from {{$schemaForeignTable}} where {{.ForeignColumn | $dot.Quotes}} in (%s)",
) strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1),
{{end -}} )
{{end -}}
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args)
} }
results, err := e.Query(query, args...) results, err := e.Query(query, args...)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to eager load {{.ForeignTable}}") return errors.Wrap(err, "failed to eager load {{.ForeignTable}}")
} }
defer results.Close() defer results.Close()
var resultSlice []*{{$txt.ForeignTable.NameGo}} var resultSlice []*{{$txt.ForeignTable.NameGo}}
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
{{- $foreignTable := getTable $dot.Tables .ForeignTable -}} {{- $foreignTable := getTable $dot.Tables .ForeignTable -}}
{{- $joinTable := getTable $dot.Tables .JoinTable -}} {{- $joinTable := getTable $dot.Tables .JoinTable -}}
{{- $localCol := $joinTable.GetColumn .JoinLocalColumn}} {{- $localCol := $joinTable.GetColumn .JoinLocalColumn}}
var localJoinCols []{{$localCol.Type}} var localJoinCols []{{$localCol.Type}}
for results.Next() { for results.Next() {
one := new({{$txt.ForeignTable.NameGo}}) one := new({{$txt.ForeignTable.NameGo}})
var localJoinCol {{$localCol.Type}} var localJoinCol {{$localCol.Type}}
err = results.Scan({{$foreignTable.Columns | columnNames | stringMap $dot.StringFuncs.titleCase | prefixStringSlice "&one." | join ", "}}, &localJoinCol) err = results.Scan({{$foreignTable.Columns | columnNames | stringMap $dot.StringFuncs.titleCase | prefixStringSlice "&one." | join ", "}}, &localJoinCol)
if err = results.Err(); err != nil { if err = results.Err(); err != nil {
return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}")
} }
resultSlice = append(resultSlice, one) resultSlice = append(resultSlice, one)
localJoinCols = append(localJoinCols, localJoinCol) localJoinCols = append(localJoinCols, localJoinCol)
} }
if err = results.Err(); err != nil { if err = results.Err(); err != nil {
return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}")
} }
{{else -}} {{else -}}
if err = boil.Bind(results, &resultSlice); err != nil { if err = queries.Bind(results, &resultSlice); err != nil {
return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}") return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}")
} }
{{end}} {{end}}
{{if not $dot.NoHooks -}} {{if not $dot.NoHooks -}}
if len({{.ForeignTable | singular | camelCase}}AfterSelectHooks) != 0 { if len({{.ForeignTable | singular | camelCase}}AfterSelectHooks) != 0 {
for _, obj := range resultSlice { for _, obj := range resultSlice {
if err := obj.doAfterSelectHooks(e); err != nil { if err := obj.doAfterSelectHooks(e); err != nil {
return err return err
} }
} }
} }
{{- end}} {{- end}}
if singular { if singular {
if object.R == nil { if object.R == nil {
object.R = &{{$varNameSingular}}R{} object.R = &{{$varNameSingular}}R{}
} }
object.R.{{$txt.Function.Name}} = resultSlice object.R.{{$txt.Function.Name}} = resultSlice
return nil return nil
} }
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
for i, foreign := range resultSlice { for i, foreign := range resultSlice {
localJoinCol := localJoinCols[i] localJoinCol := localJoinCols[i]
for _, local := range slice { for _, local := range slice {
if local.{{$txt.Function.LocalAssignment}} == localJoinCol { if local.{{$txt.Function.LocalAssignment}} == localJoinCol {
if local.R == nil { if local.R == nil {
local.R = &{{$varNameSingular}}R{} local.R = &{{$varNameSingular}}R{}
} }
local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign)
break break
} }
} }
} }
{{else -}} {{else -}}
for _, foreign := range resultSlice { for _, foreign := range resultSlice {
for _, local := range slice { for _, local := range slice {
if local.{{$txt.Function.LocalAssignment}} == foreign.{{$txt.Function.ForeignAssignment}} { if local.{{$txt.Function.LocalAssignment}} == foreign.{{$txt.Function.ForeignAssignment}} {
if local.R == nil { if local.R == nil {
local.R = &{{$varNameSingular}}R{} local.R = &{{$varNameSingular}}R{}
} }
local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign)
break break
} }
} }
} }
{{end}} {{end}}
return nil return nil
} }
{{end -}}{{/* if ForeignColumnUnique */}} {{end -}}{{/* if ForeignColumnUnique */}}
{{- end -}}{{/* range tomany */}} {{- end -}}{{/* range tomany */}}
{{- end -}}{{/* if isjointable */}} {{- end -}}{{/* if IsJoinTable */}}

View file

@ -1,101 +1,105 @@
{{- define "relationship_to_one_setops_helper" -}} {{- define "relationship_to_one_setops_helper" -}}
{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} {{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}}
{{- $localNameSingular := .ForeignKey.Table | singular | camelCase}} {{- with .Rel -}}
{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}}
{{- $localNameSingular := .ForeignKey.Table | singular | camelCase}}
// Set{{.Function.Name}} of the {{.ForeignKey.Table | singular}} to the related item. // Set{{.Function.Name}} of the {{.ForeignKey.Table | singular}} to the related item.
// Sets {{.Function.Receiver}}.R.{{.Function.Name}} to related. // Sets {{.Function.Receiver}}.R.{{.Function.Name}} to related.
// Adds {{.Function.Receiver}} to related.R.{{.Function.ForeignName}}. // Adds {{.Function.Receiver}} to related.R.{{.Function.ForeignName}}.
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Set{{.Function.Name}}(exec boil.Executor, insert bool, related *{{.ForeignTable.NameGo}}) error { func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Set{{.Function.Name}}(exec boil.Executor, insert bool, related *{{.ForeignTable.NameGo}}) error {
var err error var err error
if insert { if insert {
if err = related.Insert(exec); err != nil { if err = related.Insert(exec); err != nil {
return errors.Wrap(err, "failed to insert into foreign table") return errors.Wrap(err, "failed to insert into foreign table")
} }
} }
oldVal := {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} oldVal := {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}
{{.Function.Receiver}}.{{.Function.LocalAssignment}} = related.{{.Function.ForeignAssignment}} {{.Function.Receiver}}.{{.Function.LocalAssignment}} = related.{{.Function.ForeignAssignment}}
if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil {
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} = oldVal {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} = oldVal
return errors.Wrap(err, "failed to update local table") return errors.Wrap(err, "failed to update local table")
} }
if {{.Function.Receiver}}.R == nil { if {{.Function.Receiver}}.R == nil {
{{.Function.Receiver}}.R = &{{$localNameSingular}}R{ {{.Function.Receiver}}.R = &{{$localNameSingular}}R{
{{.Function.Name}}: related, {{.Function.Name}}: related,
} }
} else { } else {
{{.Function.Receiver}}.R.{{.Function.Name}} = related {{.Function.Receiver}}.R.{{.Function.Name}} = related
} }
{{if (or .ForeignKey.Unique .Function.OneToOne) -}} {{if (or .ForeignKey.Unique .Function.OneToOne) -}}
if related.R == nil { if related.R == nil {
related.R = &{{$varNameSingular}}R{ related.R = &{{$varNameSingular}}R{
{{.Function.ForeignName}}: {{.Function.Receiver}}, {{.Function.ForeignName}}: {{.Function.Receiver}},
} }
} else { } else {
related.R.{{.Function.ForeignName}} = {{.Function.Receiver}} related.R.{{.Function.ForeignName}} = {{.Function.Receiver}}
} }
{{else -}} {{else -}}
if related.R == nil { if related.R == nil {
related.R = &{{$varNameSingular}}R{ related.R = &{{$varNameSingular}}R{
{{.Function.ForeignName}}: {{.LocalTable.NameGo}}Slice{{"{"}}{{.Function.Receiver}}{{"}"}}, {{.Function.ForeignName}}: {{.LocalTable.NameGo}}Slice{{"{"}}{{.Function.Receiver}}{{"}"}},
} }
} else { } else {
related.R.{{.Function.ForeignName}} = append(related.R.{{.Function.ForeignName}}, {{.Function.Receiver}}) related.R.{{.Function.ForeignName}} = append(related.R.{{.Function.ForeignName}}, {{.Function.Receiver}})
} }
{{end -}} {{end -}}
{{if .ForeignKey.Nullable}} {{if .ForeignKey.Nullable}}
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true
{{end -}} {{end -}}
return nil return nil
} }
{{- if .ForeignKey.Nullable}}
{{- if .ForeignKey.Nullable}}
// Remove{{.Function.Name}} relationship. // Remove{{.Function.Name}} relationship.
// Sets {{.Function.Receiver}}.R.{{.Function.Name}} to nil. // Sets {{.Function.Receiver}}.R.{{.Function.Name}} to nil.
// Removes {{.Function.Receiver}} from all passed in related items' relationships struct (Optional). // Removes {{.Function.Receiver}} from all passed in related items' relationships struct (Optional).
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Remove{{.Function.Name}}(exec boil.Executor, related *{{.ForeignTable.NameGo}}) error { func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Remove{{.Function.Name}}(exec boil.Executor, related *{{.ForeignTable.NameGo}}) error {
var err error var err error
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = false {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = false
if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil {
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true
return errors.Wrap(err, "failed to update local table") return errors.Wrap(err, "failed to update local table")
} }
{{.Function.Receiver}}.R.{{.Function.Name}} = nil {{.Function.Receiver}}.R.{{.Function.Name}} = nil
if related == nil || related.R == nil { if related == nil || related.R == nil {
return nil return nil
} }
{{if .ForeignKey.Unique -}} {{if .ForeignKey.Unique -}}
related.R.{{.Function.ForeignName}} = nil related.R.{{.Function.ForeignName}} = nil
{{else -}} {{else -}}
for i, ri := range related.R.{{.Function.ForeignName}} { for i, ri := range related.R.{{.Function.ForeignName}} {
if {{.Function.Receiver}}.{{.Function.LocalAssignment}} != ri.{{.Function.LocalAssignment}} { if {{.Function.Receiver}}.{{.Function.LocalAssignment}} != ri.{{.Function.LocalAssignment}} {
continue continue
} }
ln := len(related.R.{{.Function.ForeignName}}) ln := len(related.R.{{.Function.ForeignName}})
if ln > 1 && i < ln-1 { if ln > 1 && i < ln-1 {
related.R.{{.Function.ForeignName}}[i] = related.R.{{.Function.ForeignName}}[ln-1] related.R.{{.Function.ForeignName}}[i] = related.R.{{.Function.ForeignName}}[ln-1]
} }
related.R.{{.Function.ForeignName}} = related.R.{{.Function.ForeignName}}[:ln-1] related.R.{{.Function.ForeignName}} = related.R.{{.Function.ForeignName}}[:ln-1]
break break
} }
{{end -}} {{end -}}
return nil return nil
} }
{{end -}} {{- end -}}{{/* if foreignkey nullable */}}
{{- end -}} {{end -}}{{/* end with */}}
{{- end -}}{{/* end define */}}
{{- /* Begin execution of template for one-to-one setops */ -}}
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . -}} {{- $dot := . -}}
{{- range .Table.FKeys -}} {{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_setops_helper" $rel -}} {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}}
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}

View file

@ -1,91 +1,93 @@
{{- /* Begin execution of template for many-to-one or many-to-many setops */ -}}
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . -}} {{- $dot := . -}}
{{- $table := .Table -}} {{- $table := .Table -}}
{{- range .Table.ToManyRelationships -}} {{- range .Table.ToManyRelationships -}}
{{- $varNameSingular := .ForeignTable | singular | camelCase -}} {{- $varNameSingular := .ForeignTable | singular | camelCase -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_setops_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} {{- /* Begin execution of template for many-to-one setops */ -}}
{{- else -}} {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}} {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}}
{{- $localNameSingular := .Table | singular | camelCase -}} {{- else -}}
{{- $foreignNameSingular := .ForeignTable | singular | camelCase}} {{- $rel := textsFromRelationship $dot.Tables $table . -}}
{{- $localNameSingular := .Table | singular | camelCase -}}
{{- $foreignNameSingular := .ForeignTable | singular | camelCase}}
// Add{{$rel.Function.Name}} adds the given related objects to the existing relationships // Add{{$rel.Function.Name}} adds the given related objects to the existing relationships
// of the {{$table.Name | singular}}, optionally inserting them as new records. // of the {{$table.Name | singular}}, optionally inserting them as new records.
// Appends related to {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}. // Appends related to {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}.
// Sets related.R.{{$rel.Function.ForeignName}} appropriately. // Sets related.R.{{$rel.Function.ForeignName}} appropriately.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error {
var err error var err error
for _, rel := range related { for _, rel := range related {
{{if not .ToJoinTable -}} {{if not .ToJoinTable -}}
rel.{{$rel.Function.ForeignAssignment}} = {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} rel.{{$rel.Function.ForeignAssignment}} = {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}}
{{if .ForeignColumnNullable -}} {{if .ForeignColumnNullable -}}
rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = true rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = true
{{end -}} {{end -}}
{{end -}} {{end -}}
if insert { if insert {
if err = rel.Insert(exec); err != nil { if err = rel.Insert(exec); err != nil {
return errors.Wrap(err, "failed to insert into foreign table") return errors.Wrap(err, "failed to insert into foreign table")
} }
}{{if not .ToJoinTable}} else { }{{if not .ToJoinTable}} else {
if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil {
return errors.Wrap(err, "failed to update foreign table") return errors.Wrap(err, "failed to update foreign table")
} }
}{{end -}} }{{end -}}
} }
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
for _, rel := range related { for _, rel := range related {
query := `insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)` query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}"
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}}
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query) fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values) fmt.Fprintln(boil.DebugWriter, values)
} }
_, err = exec.Exec(query, values...) _, err = exec.Exec(query, values...)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to insert into join table") return errors.Wrap(err, "failed to insert into join table")
} }
} }
{{end -}} {{end -}}
if {{$rel.Function.Receiver}}.R == nil { if {{$rel.Function.Receiver}}.R == nil {
{{$rel.Function.Receiver}}.R = &{{$localNameSingular}}R{ {{$rel.Function.Receiver}}.R = &{{$localNameSingular}}R{
{{$rel.Function.Name}}: related, {{$rel.Function.Name}}: related,
} }
} else { } else {
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = append({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}, related...) {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = append({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}, related...)
} }
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
for _, rel := range related { for _, rel := range related {
if rel.R == nil { if rel.R == nil {
rel.R = &{{$foreignNameSingular}}R{ rel.R = &{{$foreignNameSingular}}R{
{{$rel.Function.ForeignName}}: {{$rel.LocalTable.NameGo}}Slice{{"{"}}{{$rel.Function.Receiver}}{{"}"}}, {{$rel.Function.ForeignName}}: {{$rel.LocalTable.NameGo}}Slice{{"{"}}{{$rel.Function.Receiver}}{{"}"}},
} }
} else { } else {
rel.R.{{$rel.Function.ForeignName}} = append(rel.R.{{$rel.Function.ForeignName}}, {{$rel.Function.Receiver}}) rel.R.{{$rel.Function.ForeignName}} = append(rel.R.{{$rel.Function.ForeignName}}, {{$rel.Function.Receiver}})
} }
} }
{{else -}} {{else -}}
for _, rel := range related { for _, rel := range related {
if rel.R == nil { if rel.R == nil {
rel.R = &{{$foreignNameSingular}}R{ rel.R = &{{$foreignNameSingular}}R{
{{$rel.Function.ForeignName}}: {{$rel.Function.Receiver}}, {{$rel.Function.ForeignName}}: {{$rel.Function.Receiver}},
} }
} else { } else {
rel.R.{{$rel.Function.ForeignName}} = {{$rel.Function.Receiver}} rel.R.{{$rel.Function.ForeignName}} = {{$rel.Function.Receiver}}
} }
} }
{{end -}} {{end -}}
return nil return nil
} }
{{- if (or .ForeignColumnNullable .ToJoinTable)}}
{{- if (or .ForeignColumnNullable .ToJoinTable)}}
// Set{{$rel.Function.Name}} removes all previously related items of the // Set{{$rel.Function.Name}} removes all previously related items of the
// {{$table.Name | singular}} replacing them completely with the passed // {{$table.Name | singular}} replacing them completely with the passed
// in related items, optionally inserting them as new records. // in related items, optionally inserting them as new records.
@ -93,126 +95,126 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function
// Replaces {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} with related. // Replaces {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} with related.
// Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly. // Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error {
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
query := `delete from "{{.JoinTable}}" where "{{.JoinLocalColumn}}" = $1` query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}"
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}}
{{else -}} {{else -}}
query := `update "{{.ForeignTable}}" set "{{.ForeignColumn}}" = null where "{{.ForeignColumn}}" = $1` query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}"
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}}
{{end -}} {{end -}}
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query) fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values) fmt.Fprintln(boil.DebugWriter, values)
} }
_, err := exec.Exec(query, values...) _, err := exec.Exec(query, values...)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to remove relationships before set") return errors.Wrap(err, "failed to remove relationships before set")
} }
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related)
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil
{{else -}} {{else -}}
if {{$rel.Function.Receiver}}.R != nil { if {{$rel.Function.Receiver}}.R != nil {
for _, rel := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { for _, rel := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} {
rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false
if rel.R == nil { if rel.R == nil {
continue continue
} }
rel.R.{{$rel.Function.ForeignName}} = nil rel.R.{{$rel.Function.ForeignName}} = nil
} }
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil
} }
{{end -}} {{end -}}
return {{$rel.Function.Receiver}}.Add{{$rel.Function.Name}}(exec, insert, related...) return {{$rel.Function.Receiver}}.Add{{$rel.Function.Name}}(exec, insert, related...)
} }
// Remove{{$rel.Function.Name}} relationships from objects passed in. // Remove{{$rel.Function.Name}} relationships from objects passed in.
// Removes related items from R.{{$rel.Function.Name}} (uses pointer comparison, removal does not keep order) // Removes related items from R.{{$rel.Function.Name}} (uses pointer comparison, removal does not keep order)
// Sets related.R.{{$rel.Function.ForeignName}}. // Sets related.R.{{$rel.Function.ForeignName}}.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Function.Name}}(exec boil.Executor, related ...*{{$rel.ForeignTable.NameGo}}) error { func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Function.Name}}(exec boil.Executor, related ...*{{$rel.ForeignTable.NameGo}}) error {
var err error var err error
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
query := fmt.Sprintf( query := fmt.Sprintf(
`delete from "{{.JoinTable}}" where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`, "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}} and {{.JoinForeignColumn | $dot.Quotes}} in (%s)",
strmangle.Placeholders(len(related), 1, 1), strmangle.Placeholders(dialect.IndexPlaceholders, len(related), 1, 1),
) )
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}}
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query) fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values) fmt.Fprintln(boil.DebugWriter, values)
} }
_, err = exec.Exec(query, values...) _, err = exec.Exec(query, values...)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to remove relationships before set") return errors.Wrap(err, "failed to remove relationships before set")
} }
{{else -}} {{else -}}
for _, rel := range related { for _, rel := range related {
rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false
{{if not .ToJoinTable -}} {{if not .ToJoinTable -}}
if rel.R != nil { if rel.R != nil {
rel.R.{{$rel.Function.ForeignName}} = nil rel.R.{{$rel.Function.ForeignName}} = nil
} }
{{end -}} {{end -}}
if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil {
return err return err
} }
} }
{{end -}} {{end -}}
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related)
{{end -}} {{end -}}
if {{$rel.Function.Receiver}}.R == nil { if {{$rel.Function.Receiver}}.R == nil {
return nil return nil
} }
for _, rel := range related { for _, rel := range related {
for i, ri := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { for i, ri := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} {
if rel != ri { if rel != ri {
continue continue
} }
ln := len({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}) ln := len({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}})
if ln > 1 && i < ln-1 { if ln > 1 && i < ln-1 {
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[i] = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[ln-1] {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[i] = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[ln-1]
} }
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[:ln-1] {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[:ln-1]
break break
} }
} }
return nil return nil
} }
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
func remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}, related []*{{$rel.ForeignTable.NameGo}}) { func remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}, related []*{{$rel.ForeignTable.NameGo}}) {
for _, rel := range related { for _, rel := range related {
if rel.R == nil { if rel.R == nil {
continue continue
} }
for i, ri := range rel.R.{{$rel.Function.ForeignName}} { for i, ri := range rel.R.{{$rel.Function.ForeignName}} {
if {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} != ri.{{$rel.Function.LocalAssignment}} { if {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} != ri.{{$rel.Function.LocalAssignment}} {
continue continue
} }
ln := len(rel.R.{{$rel.Function.ForeignName}}) ln := len(rel.R.{{$rel.Function.ForeignName}})
if ln > 1 && i < ln-1 { if ln > 1 && i < ln-1 {
rel.R.{{$rel.Function.ForeignName}}[i] = rel.R.{{$rel.Function.ForeignName}}[ln-1] rel.R.{{$rel.Function.ForeignName}}[i] = rel.R.{{$rel.Function.ForeignName}}[ln-1]
} }
rel.R.{{$rel.Function.ForeignName}} = rel.R.{{$rel.Function.ForeignName}}[:ln-1] rel.R.{{$rel.Function.ForeignName}} = rel.R.{{$rel.Function.ForeignName}}[:ln-1]
break break
} }
} }
} }
{{end -}}{{- /* if join table */ -}} {{end -}}{{- /* if ToJoinTable */ -}}
{{- end -}}{{- /* if nullable foreign key */ -}} {{- end -}}{{- /* if nullable foreign key */ -}}
{{- end -}}{{- /* if unique foreign key */ -}} {{- end -}}{{- /* if unique foreign key */ -}}
{{- end -}}{{- /* range relationships */ -}} {{- end -}}{{- /* range relationships */ -}}
{{- end -}}{{- /* outer if join table */ -}} {{- end -}}{{- /* if IsJoinTable */ -}}

View file

@ -3,11 +3,11 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
// {{$tableNamePlural}}G retrieves all records. // {{$tableNamePlural}}G retrieves all records.
func {{$tableNamePlural}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { func {{$tableNamePlural}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query {
return {{$tableNamePlural}}(boil.GetDB(), mods...) return {{$tableNamePlural}}(boil.GetDB(), mods...)
} }
// {{$tableNamePlural}} retrieves all the records using an executor. // {{$tableNamePlural}} retrieves all the records using an executor.
func {{$tableNamePlural}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { func {{$tableNamePlural}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query {
mods = append(mods, qm.From("{{.Table.Name}}")) mods = append(mods, qm.From("{{.Table.Name | .SchemaTable}}"))
return {{$varNameSingular}}Query{NewQuery(exec, mods...)} return {{$varNameSingular}}Query{NewQuery(exec, mods...)}
} }

View file

@ -4,53 +4,53 @@
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}}
// {{$tableNameSingular}}FindG retrieves a single record by ID. // Find{{$tableNameSingular}}G retrieves a single record by ID.
func Find{{$tableNameSingular}}G({{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) { func Find{{$tableNameSingular}}G({{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) {
return Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) return Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...)
} }
// {{$tableNameSingular}}FindGP retrieves a single record by ID, and panics on error. // Find{{$tableNameSingular}}GP retrieves a single record by ID, and panics on error.
func Find{{$tableNameSingular}}GP({{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} { func Find{{$tableNameSingular}}GP({{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} {
retobj, err := Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) retobj, err := Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...)
if err != nil { if err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
return retobj return retobj
} }
// {{$tableNameSingular}}Find retrieves a single record by ID with an executor. // Find{{$tableNameSingular}} retrieves a single record by ID with an executor.
// If selectCols is empty Find will return all columns. // If selectCols is empty Find will return all columns.
func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) { func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) {
{{$varNameSingular}}Obj := &{{$tableNameSingular}}{} {{$varNameSingular}}Obj := &{{$tableNameSingular}}{}
sel := "*" sel := "*"
if len(selectCols) > 0 { if len(selectCols) > 0 {
sel = strings.Join(strmangle.IdentQuoteSlice(selectCols), ",") sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",")
} }
query := fmt.Sprintf( query := fmt.Sprintf(
`select %s from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}}`, sel, "select %s from {{.Table.Name | .SchemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}", sel,
) )
q := boil.SQL(exec, query, {{$pkNames | join ", "}}) q := queries.Raw(exec, query, {{$pkNames | join ", "}})
err := q.Bind({{$varNameSingular}}Obj) err := q.Bind({{$varNameSingular}}Obj)
if err != nil { if err != nil {
if errors.Cause(err) == sql.ErrNoRows { if errors.Cause(err) == sql.ErrNoRows {
return nil, sql.ErrNoRows return nil, sql.ErrNoRows
} }
return nil, errors.Wrap(err, "{{.PkgName}}: unable to select from {{.Table.Name}}") return nil, errors.Wrap(err, "{{.PkgName}}: unable to select from {{.Table.Name}}")
} }
return {{$varNameSingular}}Obj, nil return {{$varNameSingular}}Obj, nil
} }
// {{$tableNameSingular}}FindP retrieves a single record by ID with an executor, and panics on error. // Find{{$tableNameSingular}}P retrieves a single record by ID with an executor, and panics on error.
func Find{{$tableNameSingular}}P(exec boil.Executor, {{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} { func Find{{$tableNameSingular}}P(exec boil.Executor, {{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} {
retobj, err := Find{{$tableNameSingular}}(exec, {{$pkNames | join ", "}}, selectCols...) retobj, err := Find{{$tableNameSingular}}(exec, {{$pkNames | join ", "}}, selectCols...)
if err != nil { if err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
return retobj return retobj
} }

View file

@ -1,24 +1,25 @@
{{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// InsertG a single record. See Insert for whitelist behavior description. // InsertG a single record. See Insert for whitelist behavior description.
func (o *{{$tableNameSingular}}) InsertG(whitelist ... string) error { func (o *{{$tableNameSingular}}) InsertG(whitelist ... string) error {
return o.Insert(boil.GetDB(), whitelist...) return o.Insert(boil.GetDB(), whitelist...)
} }
// InsertGP a single record, and panics on error. See Insert for whitelist // InsertGP a single record, and panics on error. See Insert for whitelist
// behavior description. // behavior description.
func (o *{{$tableNameSingular}}) InsertGP(whitelist ... string) { func (o *{{$tableNameSingular}}) InsertGP(whitelist ... string) {
if err := o.Insert(boil.GetDB(), whitelist...); err != nil { if err := o.Insert(boil.GetDB(), whitelist...); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// InsertP a single record using an executor, and panics on error. See Insert // InsertP a single record using an executor, and panics on error. See Insert
// for whitelist behavior description. // for whitelist behavior description.
func (o *{{$tableNameSingular}}) InsertP(exec boil.Executor, whitelist ... string) { func (o *{{$tableNameSingular}}) InsertP(exec boil.Executor, whitelist ... string) {
if err := o.Insert(exec, whitelist...); err != nil { if err := o.Insert(exec, whitelist...); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// Insert a single record using an executor. // Insert a single record using an executor.
@ -27,115 +28,131 @@ func (o *{{$tableNameSingular}}) InsertP(exec boil.Executor, whitelist ... strin
// - All columns without a default value are included (i.e. name, age) // - All columns without a default value are included (i.e. name, age)
// - All columns with a default, but non-zero are included (i.e. health = 75) // - All columns with a default, but non-zero are included (i.e. health = 75)
func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string) error { func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string) error {
if o == nil { if o == nil {
return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
} }
var err error var err error
{{- template "timestamp_insert_helper" . }} {{- template "timestamp_insert_helper" . }}
{{if not .NoHooks -}} {{if not .NoHooks -}}
if err := o.doBeforeInsertHooks(exec); err != nil { if err := o.doBeforeInsertHooks(exec); err != nil {
return err return err
} }
{{- end}} {{- end}}
nzDefaults := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o) nzDefaults := queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o)
key := makeCacheKey(whitelist, nzDefaults) key := makeCacheKey(whitelist, nzDefaults)
{{$varNameSingular}}InsertCacheMut.RLock() {{$varNameSingular}}InsertCacheMut.RLock()
cache, cached := {{$varNameSingular}}InsertCache[key] cache, cached := {{$varNameSingular}}InsertCache[key]
{{$varNameSingular}}InsertCacheMut.RUnlock() {{$varNameSingular}}InsertCacheMut.RUnlock()
if !cached { if !cached {
wl, returnColumns := strmangle.InsertColumnSet( wl, returnColumns := strmangle.InsertColumnSet(
{{$varNameSingular}}Columns, {{$varNameSingular}}Columns,
{{$varNameSingular}}ColumnsWithDefault, {{$varNameSingular}}ColumnsWithDefault,
{{$varNameSingular}}ColumnsWithoutDefault, {{$varNameSingular}}ColumnsWithoutDefault,
nzDefaults, nzDefaults,
whitelist, whitelist,
) )
cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl) cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl)
if err != nil { if err != nil {
return err return err
} }
cache.retMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns) cache.retMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns)
if err != nil { if err != nil {
return err return err
} }
cache.query = fmt.Sprintf(`INSERT INTO {{.Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), strmangle.Placeholders(len(wl), 1, 1)) cache.query = fmt.Sprintf("INSERT INTO {{$schemaTable}} ({{.LQ}}%s{{.RQ}}) VALUES (%s)", strings.Join(wl, "{{.LQ}},{{.RQ}}"), strmangle.Placeholders(dialect.IndexPlaceholders, len(wl), 1, 1))
if len(cache.retMapping) != 0 { if len(cache.retMapping) != 0 {
{{if .UseLastInsertID -}} {{if .UseLastInsertID -}}
cache.retQuery = fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, {{$varNameSingular}}PrimaryKeyColumns)) cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns))
{{else -}} {{else -}}
cache.query += fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ",")) cache.query += fmt.Sprintf(" RETURNING {{.LQ}}%s{{.RQ}}", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"))
{{end -}} {{end -}}
} }
} }
value := reflect.Indirect(reflect.ValueOf(o)) value := reflect.Indirect(reflect.ValueOf(o))
vals := boil.ValuesFromMapping(value, cache.valueMapping) vals := queries.ValuesFromMapping(value, cache.valueMapping)
{{if .UseLastInsertID}} {{if .UseLastInsertID}}
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query) fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, vals) fmt.Fprintln(boil.DebugWriter, vals)
} }
result, err := exec.Exec(ins, vals...) result, err := exec.Exec(cache.query, vals...)
if err != nil { if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}")
} }
if len(cache.retMapping) == 0 { var lastID int64
{{if not .NoHooks -}} var identifierCols []interface{}
return o.doAfterInsertHooks(exec)
{{else -}}
return nil
{{end -}}
}
lastID, err := result.LastInsertId() if len(cache.retMapping) == 0 {
if err != nil || lastID == 0 || len({{$varNameSingular}}PrimaryKeyColumns) != 1 { goto CacheNoHooks
return ErrSyncFail }
}
if boil.DebugMode { lastID, err = result.LastInsertId()
fmt.Fprintln(boil.DebugWriter, cache.retQuery) if err != nil {
fmt.Fprintln(boil.DebugWriter, lastID) return ErrSyncFail
} }
err = exec.QueryRow(cache.retQuery, lastID).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) if lastID != 0 {
if err != nil { {{- $colName := index .Table.PKey.Columns 0 -}}
return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") {{- $col := .Table.GetColumn $colName -}}
} o.{{$colName | singular | titleCase}} = {{$col.Type}}(lastID)
{{else}} identifierCols = []interface{}{lastID}
if len(cache.retMapping) != 0 { } else {
err = exec.QueryRow(cache.query, vals...).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) identifierCols = []interface{}{
} else { {{range .Table.PKey.Columns -}}
_, err = exec.Exec(cache.query, vals...) o.{{. | singular | titleCase}},
} {{end -}}
}
}
if boil.DebugMode { if lastID != 0 && len(cache.retMapping) == 1 {
fmt.Fprintln(boil.DebugWriter, cache.query) if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, vals) fmt.Fprintln(boil.DebugWriter, cache.retQuery)
} fmt.Fprintln(boil.DebugWriter, identifierCols...)
}
if err != nil { err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(queries.PtrsFromMapping(value, cache.retMapping)...)
return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") if err != nil {
} return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}")
{{end}} }
}
{{else}}
if len(cache.retMapping) != 0 {
err = exec.QueryRow(cache.query, vals...).Scan(queries.PtrsFromMapping(value, cache.retMapping)...)
} else {
_, err = exec.Exec(cache.query, vals...)
}
if !cached { if boil.DebugMode {
{{$varNameSingular}}InsertCacheMut.Lock() fmt.Fprintln(boil.DebugWriter, cache.query)
{{$varNameSingular}}InsertCache[key] = cache fmt.Fprintln(boil.DebugWriter, vals)
{{$varNameSingular}}InsertCacheMut.Unlock() }
}
{{if not .NoHooks -}} if err != nil {
return o.doAfterInsertHooks(exec) return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}")
{{- else -}} }
return nil {{end}}
{{- end}} {{if .UseLastInsertID -}}
CacheNoHooks:
{{- end}}
if !cached {
{{$varNameSingular}}InsertCacheMut.Lock()
{{$varNameSingular}}InsertCache[key] = cache
{{$varNameSingular}}InsertCacheMut.Unlock()
}
{{if not .NoHooks -}}
return o.doAfterInsertHooks(exec)
{{- else -}}
return nil
{{- end}}
} }

View file

@ -3,28 +3,29 @@
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// UpdateG a single {{$tableNameSingular}} record. See Update for // UpdateG a single {{$tableNameSingular}} record. See Update for
// whitelist behavior description. // whitelist behavior description.
func (o *{{$tableNameSingular}}) UpdateG(whitelist ...string) error { func (o *{{$tableNameSingular}}) UpdateG(whitelist ...string) error {
return o.Update(boil.GetDB(), whitelist...) return o.Update(boil.GetDB(), whitelist...)
} }
// UpdateGP a single {{$tableNameSingular}} record. // UpdateGP a single {{$tableNameSingular}} record.
// UpdateGP takes a whitelist of column names that should be updated. // UpdateGP takes a whitelist of column names that should be updated.
// Panics on error. See Update for whitelist behavior description. // Panics on error. See Update for whitelist behavior description.
func (o *{{$tableNameSingular}}) UpdateGP(whitelist ...string) { func (o *{{$tableNameSingular}}) UpdateGP(whitelist ...string) {
if err := o.Update(boil.GetDB(), whitelist...); err != nil { if err := o.Update(boil.GetDB(), whitelist...); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// UpdateP uses an executor to update the {{$tableNameSingular}}, and panics on error. // UpdateP uses an executor to update the {{$tableNameSingular}}, and panics on error.
// See Update for whitelist behavior description. // See Update for whitelist behavior description.
func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... string) { func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... string) {
err := o.Update(exec, whitelist...) err := o.Update(exec, whitelist...)
if err != nil { if err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// Update uses an executor to update the {{$tableNameSingular}}. // Update uses an executor to update the {{$tableNameSingular}}.
@ -35,146 +36,147 @@ func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... strin
// Update does not automatically update the record in case of default values. Use .Reload() // Update does not automatically update the record in case of default values. Use .Reload()
// to refresh the records. // to refresh the records.
func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string) error { func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string) error {
{{- template "timestamp_update_helper" . -}} {{- template "timestamp_update_helper" . -}}
var err error var err error
{{if not .NoHooks -}} {{if not .NoHooks -}}
if err = o.doBeforeUpdateHooks(exec); err != nil { if err = o.doBeforeUpdateHooks(exec); err != nil {
return err return err
} }
{{end -}} {{end -}}
key := makeCacheKey(whitelist, nil) key := makeCacheKey(whitelist, nil)
{{$varNameSingular}}UpdateCacheMut.RLock() {{$varNameSingular}}UpdateCacheMut.RLock()
cache, cached := {{$varNameSingular}}UpdateCache[key] cache, cached := {{$varNameSingular}}UpdateCache[key]
{{$varNameSingular}}UpdateCacheMut.RUnlock() {{$varNameSingular}}UpdateCacheMut.RUnlock()
if !cached { if !cached {
wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist)
cache.query = fmt.Sprintf(`UPDATE "{{.Table.Name}}" SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s",
cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, wl),
if err != nil { strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}len(wl)+1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns),
return err )
} cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...))
} if err != nil {
return err
}
}
if len(cache.valueMapping) == 0 { if len(cache.valueMapping) == 0 {
return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist") return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist")
} }
values := boil.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping) values := queries.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping)
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query) fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, values) fmt.Fprintln(boil.DebugWriter, values)
} }
result, err := exec.Exec(cache.query, values...) result, err := exec.Exec(cache.query, values...)
if err != nil { if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to update {{.Table.Name}} row") return errors.Wrap(err, "{{.PkgName}}: unable to update {{.Table.Name}} row")
} }
if r, err := result.RowsAffected(); err == nil && r != 1 { if r, err := result.RowsAffected(); err == nil && r != 1 {
return errors.Errorf("failed to update single row, updated %d rows", r) return errors.Errorf("failed to update single row, updated %d rows", r)
} }
if !cached { if !cached {
{{$varNameSingular}}UpdateCacheMut.Lock() {{$varNameSingular}}UpdateCacheMut.Lock()
{{$varNameSingular}}UpdateCache[key] = cache {{$varNameSingular}}UpdateCache[key] = cache
{{$varNameSingular}}UpdateCacheMut.Unlock() {{$varNameSingular}}UpdateCacheMut.Unlock()
} }
{{if not .NoHooks -}} {{if not .NoHooks -}}
return o.doAfterUpdateHooks(exec) return o.doAfterUpdateHooks(exec)
{{- else -}} {{- else -}}
return nil return nil
{{- end}} {{- end}}
} }
// UpdateAllP updates all rows with matching column names, and panics on error. // UpdateAllP updates all rows with matching column names, and panics on error.
func (q {{$varNameSingular}}Query) UpdateAllP(cols M) { func (q {{$varNameSingular}}Query) UpdateAllP(cols M) {
if err := q.UpdateAll(cols); err != nil { if err := q.UpdateAll(cols); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// UpdateAll updates all rows with the specified column values. // UpdateAll updates all rows with the specified column values.
func (q {{$varNameSingular}}Query) UpdateAll(cols M) error { func (q {{$varNameSingular}}Query) UpdateAll(cols M) error {
boil.SetUpdate(q.Query, cols) queries.SetUpdate(q.Query, cols)
_, err := boil.ExecQuery(q.Query) _, err := q.Query.Exec()
if err != nil { if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}") return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}")
} }
return nil return nil
} }
// UpdateAllG updates all rows with the specified column values. // UpdateAllG updates all rows with the specified column values.
func (o {{$tableNameSingular}}Slice) UpdateAllG(cols M) error { func (o {{$tableNameSingular}}Slice) UpdateAllG(cols M) error {
return o.UpdateAll(boil.GetDB(), cols) return o.UpdateAll(boil.GetDB(), cols)
} }
// UpdateAllGP updates all rows with the specified column values, and panics on error. // UpdateAllGP updates all rows with the specified column values, and panics on error.
func (o {{$tableNameSingular}}Slice) UpdateAllGP(cols M) { func (o {{$tableNameSingular}}Slice) UpdateAllGP(cols M) {
if err := o.UpdateAll(boil.GetDB(), cols); err != nil { if err := o.UpdateAll(boil.GetDB(), cols); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// UpdateAllP updates all rows with the specified column values, and panics on error. // UpdateAllP updates all rows with the specified column values, and panics on error.
func (o {{$tableNameSingular}}Slice) UpdateAllP(exec boil.Executor, cols M) { func (o {{$tableNameSingular}}Slice) UpdateAllP(exec boil.Executor, cols M) {
if err := o.UpdateAll(exec, cols); err != nil { if err := o.UpdateAll(exec, cols); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// UpdateAll updates all rows with the specified column values, using an executor. // UpdateAll updates all rows with the specified column values, using an executor.
func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error { func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error {
ln := int64(len(o)) ln := int64(len(o))
if ln == 0 { if ln == 0 {
return nil return nil
} }
if len(cols) == 0 { if len(cols) == 0 {
return errors.New("{{.PkgName}}: update all requires at least one column argument") return errors.New("{{.PkgName}}: update all requires at least one column argument")
} }
colNames := make([]string, len(cols)) colNames := make([]string, len(cols))
args := make([]interface{}, len(cols)) args := make([]interface{}, len(cols))
i := 0 i := 0
for name, value := range cols { for name, value := range cols {
colNames[i] = strmangle.IdentQuote(name) colNames[i] = name
args[i] = value args[i] = value
i++ i++
} }
// Append all of the primary key values for each column // Append all of the primary key values for each column
args = append(args, o.inPrimaryKeyArgs()...) args = append(args, o.inPrimaryKeyArgs()...)
sql := fmt.Sprintf( sql := fmt.Sprintf(
`UPDATE {{.Table.Name}} SET (%s) = (%s) WHERE (%s) IN (%s)`, "UPDATE {{$schemaTable}} SET %s WHERE ({{.LQ}}{{.Table.PKey.Columns | join (printf "%s,%s" .LQ .RQ)}}{{.RQ}}) IN (%s)",
strings.Join(colNames, ", "), strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, colNames),
strmangle.Placeholders(len(colNames), 1, 1), strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)),
strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), )
strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql) fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args...) fmt.Fprintln(boil.DebugWriter, args...)
} }
result, err := exec.Exec(sql, args...) result, err := exec.Exec(sql, args...)
if err != nil { if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to update all in {{$varNameSingular}} slice") return errors.Wrap(err, "{{.PkgName}}: unable to update all in {{$varNameSingular}} slice")
} }
if r, err := result.RowsAffected(); err == nil && r != ln { if r, err := result.RowsAffected(); err == nil && r != ln {
return errors.Errorf("failed to update %d rows, only affected %d", ln, r) return errors.Errorf("failed to update %d rows, only affected %d", ln, r)
} }
return nil return nil
} }

View file

@ -1,85 +1,188 @@
{{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// UpsertG attempts an insert, and does an update or ignore on conflict. // UpsertG attempts an insert, and does an update or ignore on conflict.
func (o *{{$tableNameSingular}}) UpsertG(updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) error { func (o *{{$tableNameSingular}}) UpsertG({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error {
return o.Upsert(boil.GetDB(), updateOnConflict, conflictColumns, updateColumns, whitelist...) return o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...)
} }
// UpsertGP attempts an insert, and does an update or ignore on conflict. Panics on error. // UpsertGP attempts an insert, and does an update or ignore on conflict. Panics on error.
func (o *{{$tableNameSingular}}) UpsertGP(updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) { func (o *{{$tableNameSingular}}) UpsertGP({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) {
if err := o.Upsert(boil.GetDB(), updateOnConflict, conflictColumns, updateColumns, whitelist...); err != nil { if err := o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// UpsertP attempts an insert using an executor, and does an update or ignore on conflict. // UpsertP attempts an insert using an executor, and does an update or ignore on conflict.
// UpsertP panics on error. // UpsertP panics on error.
func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) { func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) {
if err := o.Upsert(exec, updateOnConflict, conflictColumns, updateColumns, whitelist...); err != nil { if err := o.Upsert(exec, {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// Upsert attempts an insert using an executor, and does an update or ignore on conflict. // Upsert attempts an insert using an executor, and does an update or ignore on conflict.
func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) error { func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error {
if o == nil { if o == nil {
return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert")
} }
{{- template "timestamp_upsert_helper" . }} {{- template "timestamp_upsert_helper" . }}
{{if not .NoHooks -}} {{if not .NoHooks -}}
if err := o.doBeforeUpsertHooks(exec); err != nil { if err := o.doBeforeUpsertHooks(exec); err != nil {
return err return err
} }
{{- end}} {{- end}}
var err error // Build cache key in-line uglily - mysql vs postgres problems
var ret []string buf := strmangle.GetBuffer()
whitelist, ret = strmangle.InsertColumnSet( {{if ne .DriverName "mysql" -}}
{{$varNameSingular}}Columns, if updateOnConflict {
{{$varNameSingular}}ColumnsWithDefault, buf.WriteByte('t')
{{$varNameSingular}}ColumnsWithoutDefault, } else {
boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), buf.WriteByte('f')
whitelist, }
) buf.WriteByte('.')
update := strmangle.UpdateColumnSet( for _, c := range conflictColumns {
{{$varNameSingular}}Columns, buf.WriteString(c)
{{$varNameSingular}}PrimaryKeyColumns, }
updateColumns, buf.WriteByte('.')
) {{end -}}
conflict := conflictColumns for _, c := range updateColumns {
if len(conflict) == 0 { buf.WriteString(c)
conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) }
copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) buf.WriteByte('.')
} for _, c := range whitelist {
buf.WriteString(c)
}
key := buf.String()
strmangle.PutBuffer(buf)
query := boil.BuildUpsertQuery("{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) {{$varNameSingular}}UpsertCacheMut.RLock()
cache, cached := {{$varNameSingular}}UpsertCache[key]
{{$varNameSingular}}UpsertCacheMut.RUnlock()
if boil.DebugMode { var err error
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...))
}
{{- if .UseLastInsertID}} if !cached {
return errors.New("don't know how to do this yet") var ret []string
{{- else}} whitelist, ret = strmangle.InsertColumnSet(
if len(ret) != 0 { {{$varNameSingular}}Columns,
err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...) {{$varNameSingular}}ColumnsWithDefault,
} else { {{$varNameSingular}}ColumnsWithoutDefault,
_, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...) queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o),
} whitelist,
{{- end}} )
update := strmangle.UpdateColumnSet(
{{$varNameSingular}}Columns,
{{$varNameSingular}}PrimaryKeyColumns,
updateColumns,
)
if err != nil { {{if ne .DriverName "mysql" -}}
return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") var conflict []string
} if len(conflictColumns) == 0 {
conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns))
copy(conflict, {{$varNameSingular}}PrimaryKeyColumns)
}
cache.query = queries.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, whitelist)
{{- else -}}
cache.query = queries.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist)
cache.retQuery = fmt.Sprintf(
"SELECT %s FROM {{.LQ}}{{.Table.Name}}{{.RQ}} WHERE {{whereClause .LQ .RQ 0 .Table.PKey.Columns}}",
strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, ret), ","),
)
{{- end}}
{{if not .NoHooks -}} cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, whitelist)
if err := o.doAfterUpsertHooks(exec); err != nil { if err != nil {
return err return err
} }
{{- end}} if len(ret) != 0 {
cache.retMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, ret)
if err != nil {
return err
}
}
}
return nil value := reflect.Indirect(reflect.ValueOf(o))
values := queries.ValuesFromMapping(value, cache.valueMapping)
var returns []interface{}
if len(cache.retMapping) != 0 {
returns = queries.PtrsFromMapping(value, cache.retMapping)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, values)
}
{{- if .UseLastInsertID}}
result, err := exec.Exec(cache.query, values...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}")
}
if len(cache.retMapping) == 0 {
{{if not .NoHooks -}}
return o.doAfterUpsertHooks(exec)
{{else -}}
return nil
{{end -}}
}
lastID, err := result.LastInsertId()
if err != nil {
return ErrSyncFail
}
var identifierCols []interface{}
if lastID != 0 {
{{- $colName := index .Table.PKey.Columns 0 -}}
{{- $col := .Table.GetColumn $colName -}}
o.{{$colName | singular | titleCase}} = {{$col.Type}}(lastID)
identifierCols = []interface{}{lastID}
} else {
identifierCols = []interface{}{
{{range .Table.PKey.Columns -}}
o.{{. | singular | titleCase}},
{{end -}}
}
}
if lastID != 0 && len(cache.retMapping) == 1 {
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.retQuery)
fmt.Fprintln(boil.DebugWriter, identifierCols...)
}
err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(returns...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}")
}
}
{{- else}}
if len(cache.retMapping) != 0 {
err = exec.QueryRow(cache.query, values...).Scan(returns...)
} else {
_, err = exec.Exec(cache.query, values...)
}
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}")
}
{{- end}}
if !cached {
{{$varNameSingular}}UpsertCacheMut.Lock()
{{$varNameSingular}}UpsertCache[key] = cache
{{$varNameSingular}}UpsertCacheMut.Unlock()
}
{{if not .NoHooks -}}
return o.doAfterUpsertHooks(exec)
{{- else -}}
return nil
{{- end}}
} }

View file

@ -1,161 +1,162 @@
{{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// DeleteP deletes a single {{$tableNameSingular}} record with an executor. // DeleteP deletes a single {{$tableNameSingular}} record with an executor.
// DeleteP will match against the primary key column to find the record to delete. // DeleteP will match against the primary key column to find the record to delete.
// Panics on error. // Panics on error.
func (o *{{$tableNameSingular}}) DeleteP(exec boil.Executor) { func (o *{{$tableNameSingular}}) DeleteP(exec boil.Executor) {
if err := o.Delete(exec); err != nil { if err := o.Delete(exec); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// DeleteG deletes a single {{$tableNameSingular}} record. // DeleteG deletes a single {{$tableNameSingular}} record.
// DeleteG will match against the primary key column to find the record to delete. // DeleteG will match against the primary key column to find the record to delete.
func (o *{{$tableNameSingular}}) DeleteG() error { func (o *{{$tableNameSingular}}) DeleteG() error {
if o == nil { if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for deletion") return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for deletion")
} }
return o.Delete(boil.GetDB()) return o.Delete(boil.GetDB())
} }
// DeleteGP deletes a single {{$tableNameSingular}} record. // DeleteGP deletes a single {{$tableNameSingular}} record.
// DeleteGP will match against the primary key column to find the record to delete. // DeleteGP will match against the primary key column to find the record to delete.
// Panics on error. // Panics on error.
func (o *{{$tableNameSingular}}) DeleteGP() { func (o *{{$tableNameSingular}}) DeleteGP() {
if err := o.DeleteG(); err != nil { if err := o.DeleteG(); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// Delete deletes a single {{$tableNameSingular}} record with an executor. // Delete deletes a single {{$tableNameSingular}} record with an executor.
// Delete will match against the primary key column to find the record to delete. // Delete will match against the primary key column to find the record to delete.
func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error { func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error {
if o == nil { if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for delete") return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for delete")
} }
{{if not .NoHooks -}} {{if not .NoHooks -}}
if err := o.doBeforeDeleteHooks(exec); err != nil { if err := o.doBeforeDeleteHooks(exec); err != nil {
return err return err
} }
{{- end}} {{- end}}
args := o.inPrimaryKeyArgs() args := o.inPrimaryKeyArgs()
sql := `DELETE FROM {{.Table.Name}} WHERE {{whereClause 1 .Table.PKey.Columns}}` sql := "DELETE FROM {{$schemaTable}} WHERE {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}"
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql) fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args...) fmt.Fprintln(boil.DebugWriter, args...)
} }
_, err := exec.Exec(sql, args...) _, err := exec.Exec(sql, args...)
if err != nil { if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to delete from {{.Table.Name}}") return errors.Wrap(err, "{{.PkgName}}: unable to delete from {{.Table.Name}}")
} }
{{if not .NoHooks -}} {{if not .NoHooks -}}
if err := o.doAfterDeleteHooks(exec); err != nil { if err := o.doAfterDeleteHooks(exec); err != nil {
return err return err
} }
{{- end}} {{- end}}
return nil return nil
} }
// DeleteAllP deletes all rows, and panics on error. // DeleteAllP deletes all rows, and panics on error.
func (q {{$varNameSingular}}Query) DeleteAllP() { func (q {{$varNameSingular}}Query) DeleteAllP() {
if err := q.DeleteAll(); err != nil { if err := q.DeleteAll(); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// DeleteAll deletes all matching rows. // DeleteAll deletes all matching rows.
func (q {{$varNameSingular}}Query) DeleteAll() error { func (q {{$varNameSingular}}Query) DeleteAll() error {
if q.Query == nil { if q.Query == nil {
return errors.New("{{.PkgName}}: no {{$varNameSingular}}Query provided for delete all") return errors.New("{{.PkgName}}: no {{$varNameSingular}}Query provided for delete all")
} }
boil.SetDelete(q.Query) queries.SetDelete(q.Query)
_, err := boil.ExecQuery(q.Query) _, err := q.Query.Exec()
if err != nil { if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}") return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}")
} }
return nil return nil
} }
// DeleteAll deletes all rows in the slice, and panics on error. // DeleteAllGP deletes all rows in the slice, and panics on error.
func (o {{$tableNameSingular}}Slice) DeleteAllGP() { func (o {{$tableNameSingular}}Slice) DeleteAllGP() {
if err := o.DeleteAllG(); err != nil { if err := o.DeleteAllG(); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// DeleteAllG deletes all rows in the slice. // DeleteAllG deletes all rows in the slice.
func (o {{$tableNameSingular}}Slice) DeleteAllG() error { func (o {{$tableNameSingular}}Slice) DeleteAllG() error {
if o == nil { if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all")
} }
return o.DeleteAll(boil.GetDB()) return o.DeleteAll(boil.GetDB())
} }
// DeleteAllP deletes all rows in the slice, using an executor, and panics on error. // DeleteAllP deletes all rows in the slice, using an executor, and panics on error.
func (o {{$tableNameSingular}}Slice) DeleteAllP(exec boil.Executor) { func (o {{$tableNameSingular}}Slice) DeleteAllP(exec boil.Executor) {
if err := o.DeleteAll(exec); err != nil { if err := o.DeleteAll(exec); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// DeleteAll deletes all rows in the slice, using an executor. // DeleteAll deletes all rows in the slice, using an executor.
func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error { func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error {
if o == nil { if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all")
} }
if len(o) == 0 { if len(o) == 0 {
return nil return nil
} }
{{if not .NoHooks -}} {{if not .NoHooks -}}
if len({{$varNameSingular}}BeforeDeleteHooks) != 0 { if len({{$varNameSingular}}BeforeDeleteHooks) != 0 {
for _, obj := range o { for _, obj := range o {
if err := obj.doBeforeDeleteHooks(exec); err != nil { if err := obj.doBeforeDeleteHooks(exec); err != nil {
return err return err
} }
} }
} }
{{- end}} {{- end}}
args := o.inPrimaryKeyArgs() args := o.inPrimaryKeyArgs()
sql := fmt.Sprintf( sql := fmt.Sprintf(
`DELETE FROM {{.Table.Name}} WHERE (%s) IN (%s)`, "DELETE FROM {{$schemaTable}} WHERE (%s) IN (%s)",
strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","),
strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)),
) )
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql) fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args) fmt.Fprintln(boil.DebugWriter, args)
} }
_, err := exec.Exec(sql, args...) _, err := exec.Exec(sql, args...)
if err != nil { if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{$varNameSingular}} slice") return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{$varNameSingular}} slice")
} }
{{if not .NoHooks -}} {{if not .NoHooks -}}
if len({{$varNameSingular}}AfterDeleteHooks) != 0 { if len({{$varNameSingular}}AfterDeleteHooks) != 0 {
for _, obj := range o { for _, obj := range o {
if err := obj.doAfterDeleteHooks(exec); err != nil { if err := obj.doAfterDeleteHooks(exec); err != nil {
return err return err
} }
} }
} }
{{- end}} {{- end}}
return nil return nil
} }

View file

@ -1,85 +1,94 @@
{{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// ReloadGP refetches the object from the database and panics on error. // ReloadGP refetches the object from the database and panics on error.
func (o *{{$tableNameSingular}}) ReloadGP() { func (o *{{$tableNameSingular}}) ReloadGP() {
if err := o.ReloadG(); err != nil { if err := o.ReloadG(); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// ReloadP refetches the object from the database with an executor. Panics on error. // ReloadP refetches the object from the database with an executor. Panics on error.
func (o *{{$tableNameSingular}}) ReloadP(exec boil.Executor) { func (o *{{$tableNameSingular}}) ReloadP(exec boil.Executor) {
if err := o.Reload(exec); err != nil { if err := o.Reload(exec); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// ReloadG refetches the object from the database using the primary keys. // ReloadG refetches the object from the database using the primary keys.
func (o *{{$tableNameSingular}}) ReloadG() error { func (o *{{$tableNameSingular}}) ReloadG() error {
if o == nil { if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for reload") return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for reload")
} }
return o.Reload(boil.GetDB()) return o.Reload(boil.GetDB())
} }
// Reload refetches the object from the database // Reload refetches the object from the database
// using the primary keys with an executor. // using the primary keys with an executor.
func (o *{{$tableNameSingular}}) Reload(exec boil.Executor) error { func (o *{{$tableNameSingular}}) Reload(exec boil.Executor) error {
ret, err := Find{{$tableNameSingular}}(exec, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) ret, err := Find{{$tableNameSingular}}(exec, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
if err != nil { if err != nil {
return err return err
} }
*o = *ret *o = *ret
return nil return nil
} }
// ReloadAllGP refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice.
// Panics on error.
func (o *{{$tableNameSingular}}Slice) ReloadAllGP() { func (o *{{$tableNameSingular}}Slice) ReloadAllGP() {
if err := o.ReloadAllG(); err != nil { if err := o.ReloadAllG(); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// ReloadAllP refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice.
// Panics on error.
func (o *{{$tableNameSingular}}Slice) ReloadAllP(exec boil.Executor) { func (o *{{$tableNameSingular}}Slice) ReloadAllP(exec boil.Executor) {
if err := o.ReloadAll(exec); err != nil { if err := o.ReloadAll(exec); err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
} }
// ReloadAllG refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice.
func (o *{{$tableNameSingular}}Slice) ReloadAllG() error { func (o *{{$tableNameSingular}}Slice) ReloadAllG() error {
if o == nil { if o == nil {
return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all") return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all")
} }
return o.ReloadAll(boil.GetDB()) return o.ReloadAll(boil.GetDB())
} }
// ReloadAll refetches every row with matching primary key column values // ReloadAll refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice. // and overwrites the original object slice with the newly updated slice.
func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error { func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error {
if o == nil || len(*o) == 0 { if o == nil || len(*o) == 0 {
return nil return nil
} }
{{$varNamePlural}} := {{$tableNameSingular}}Slice{} {{$varNamePlural}} := {{$tableNameSingular}}Slice{}
args := o.inPrimaryKeyArgs() args := o.inPrimaryKeyArgs()
sql := fmt.Sprintf( sql := fmt.Sprintf(
`SELECT {{.Table.Name}}.* FROM {{.Table.Name}} WHERE (%s) IN (%s)`, "SELECT {{$schemaTable}}.* FROM {{$schemaTable}} WHERE (%s) IN (%s)",
strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","),
strmangle.Placeholders(len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), strmangle.Placeholders(dialect.IndexPlaceholders, len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)),
) )
q := boil.SQL(exec, sql, args...) q := queries.Raw(exec, sql, args...)
err := q.Bind(&{{$varNamePlural}}) err := q.Bind(&{{$varNamePlural}})
if err != nil { if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice") return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice")
} }
*o = {{$varNamePlural}} *o = {{$varNamePlural}}
return nil return nil
} }

View file

@ -2,48 +2,49 @@
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// {{$tableNameSingular}}Exists checks if the {{$tableNameSingular}} row exists. // {{$tableNameSingular}}Exists checks if the {{$tableNameSingular}} row exists.
func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) { func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) {
var exists bool var exists bool
sql := `select exists(select 1 from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}} limit 1)` sql := "select exists(select 1 from {{$schemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}} limit 1)"
if boil.DebugMode { if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql) fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, {{$pkNames | join ", "}}) fmt.Fprintln(boil.DebugWriter, {{$pkNames | join ", "}})
} }
row := exec.QueryRow(sql, {{$pkNames | join ", "}}) row := exec.QueryRow(sql, {{$pkNames | join ", "}})
err := row.Scan(&exists) err := row.Scan(&exists)
if err != nil { if err != nil {
return false, errors.Wrap(err, "{{.PkgName}}: unable to check if {{.Table.Name}} exists") return false, errors.Wrap(err, "{{.PkgName}}: unable to check if {{.Table.Name}} exists")
} }
return exists, nil return exists, nil
} }
// {{$tableNameSingular}}ExistsG checks if the {{$tableNameSingular}} row exists. // {{$tableNameSingular}}ExistsG checks if the {{$tableNameSingular}} row exists.
func {{$tableNameSingular}}ExistsG({{$pkArgs}}) (bool, error) { func {{$tableNameSingular}}ExistsG({{$pkArgs}}) (bool, error) {
return {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) return {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}})
} }
// {{$tableNameSingular}}ExistsGP checks if the {{$tableNameSingular}} row exists. Panics on error. // {{$tableNameSingular}}ExistsGP checks if the {{$tableNameSingular}} row exists. Panics on error.
func {{$tableNameSingular}}ExistsGP({{$pkArgs}}) bool { func {{$tableNameSingular}}ExistsGP({{$pkArgs}}) bool {
e, err := {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) e, err := {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}})
if err != nil { if err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
return e return e
} }
// {{$tableNameSingular}}ExistsP checks if the {{$tableNameSingular}} row exists. Panics on error. // {{$tableNameSingular}}ExistsP checks if the {{$tableNameSingular}} row exists. Panics on error.
func {{$tableNameSingular}}ExistsP(exec boil.Executor, {{$pkArgs}}) bool { func {{$tableNameSingular}}ExistsP(exec boil.Executor, {{$pkArgs}}) bool {
e, err := {{$tableNameSingular}}Exists(exec, {{$pkNames | join ", "}}) e, err := {{$tableNameSingular}}Exists(exec, {{$pkNames | join ", "}})
if err != nil { if err != nil {
panic(boil.WrapErr(err)) panic(boil.WrapErr(err))
} }
return e return e
} }

View file

@ -1,23 +1,23 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}}
func (o {{$tableNameSingular}}) inPrimaryKeyArgs() []interface{} { func (o {{$tableNameSingular}}) inPrimaryKeyArgs() []interface{} {
var args []interface{} var args []interface{}
{{- range $key, $value := .Table.PKey.Columns }} {{- range $key, $value := .Table.PKey.Columns }}
args = append(args, o.{{titleCase $value}}) args = append(args, o.{{titleCase $value}})
{{ end -}} {{ end -}}
return args return args
} }
func (o {{$tableNameSingular}}Slice) inPrimaryKeyArgs() []interface{} { func (o {{$tableNameSingular}}Slice) inPrimaryKeyArgs() []interface{} {
var args []interface{} var args []interface{}
for i := 0; i < len(o); i++ { for i := 0; i < len(o); i++ {
{{- range $key, $value := .Table.PKey.Columns }} {{- range $key, $value := .Table.PKey.Columns }}
args = append(args, o[i].{{titleCase $value}}) args = append(args, o[i].{{titleCase $value}})
{{ end -}} {{ end -}}
} }
return args return args
} }

View file

@ -1,82 +1,82 @@
{{- define "timestamp_insert_helper" -}} {{- define "timestamp_insert_helper" -}}
{{- if not .NoAutoTimestamps -}} {{- if not .NoAutoTimestamps -}}
{{- $colNames := .Table.Columns | columnNames -}} {{- $colNames := .Table.Columns | columnNames -}}
{{if containsAny $colNames "created_at" "updated_at"}} {{if containsAny $colNames "created_at" "updated_at"}}
currTime := time.Now().In(boil.GetLocation()) currTime := time.Now().In(boil.GetLocation())
{{range $ind, $col := .Table.Columns}} {{range $ind, $col := .Table.Columns}}
{{- if eq $col.Name "created_at" -}} {{- if eq $col.Name "created_at" -}}
{{- if $col.Nullable}} {{- if $col.Nullable}}
if o.CreatedAt.Time.IsZero() { if o.CreatedAt.Time.IsZero() {
o.CreatedAt.Time = currTime o.CreatedAt.Time = currTime
o.CreatedAt.Valid = true o.CreatedAt.Valid = true
} }
{{- else}} {{- else}}
if o.CreatedAt.IsZero() { if o.CreatedAt.IsZero() {
o.CreatedAt = currTime o.CreatedAt = currTime
} }
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
{{- if eq $col.Name "updated_at" -}} {{- if eq $col.Name "updated_at" -}}
{{- if $col.Nullable}} {{- if $col.Nullable}}
if o.UpdatedAt.Time.IsZero() { if o.UpdatedAt.Time.IsZero() {
o.UpdatedAt.Time = currTime o.UpdatedAt.Time = currTime
o.UpdatedAt.Valid = true o.UpdatedAt.Valid = true
} }
{{- else}} {{- else}}
if o.UpdatedAt.IsZero() { if o.UpdatedAt.IsZero() {
o.UpdatedAt = currTime o.UpdatedAt = currTime
} }
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
{{end}} {{end}}
{{end}} {{end}}
{{- end}} {{- end}}
{{- end -}} {{- end -}}
{{- define "timestamp_update_helper" -}} {{- define "timestamp_update_helper" -}}
{{- if not .NoAutoTimestamps -}} {{- if not .NoAutoTimestamps -}}
{{- $colNames := .Table.Columns | columnNames -}} {{- $colNames := .Table.Columns | columnNames -}}
{{if containsAny $colNames "updated_at"}} {{if containsAny $colNames "updated_at"}}
currTime := time.Now().In(boil.GetLocation()) currTime := time.Now().In(boil.GetLocation())
{{range $ind, $col := .Table.Columns}} {{range $ind, $col := .Table.Columns}}
{{- if eq $col.Name "updated_at" -}} {{- if eq $col.Name "updated_at" -}}
{{- if $col.Nullable}} {{- if $col.Nullable}}
o.UpdatedAt.Time = currTime o.UpdatedAt.Time = currTime
o.UpdatedAt.Valid = true o.UpdatedAt.Valid = true
{{- else}} {{- else}}
o.UpdatedAt = currTime o.UpdatedAt = currTime
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
{{end}} {{end}}
{{end}} {{end}}
{{- end}} {{- end}}
{{end -}} {{end -}}
{{- define "timestamp_upsert_helper" -}} {{- define "timestamp_upsert_helper" -}}
{{- if not .NoAutoTimestamps -}} {{- if not .NoAutoTimestamps -}}
{{- $colNames := .Table.Columns | columnNames -}} {{- $colNames := .Table.Columns | columnNames -}}
{{if containsAny $colNames "created_at" "updated_at"}} {{if containsAny $colNames "created_at" "updated_at"}}
currTime := time.Now().In(boil.GetLocation()) currTime := time.Now().In(boil.GetLocation())
{{range $ind, $col := .Table.Columns}} {{range $ind, $col := .Table.Columns}}
{{- if eq $col.Name "created_at" -}} {{- if eq $col.Name "created_at" -}}
{{- if $col.Nullable}} {{- if $col.Nullable}}
if o.CreatedAt.Time.IsZero() { if o.CreatedAt.Time.IsZero() {
o.CreatedAt.Time = currTime o.CreatedAt.Time = currTime
o.CreatedAt.Valid = true o.CreatedAt.Valid = true
} }
{{- else}} {{- else}}
if o.CreatedAt.IsZero() { if o.CreatedAt.IsZero() {
o.CreatedAt = currTime o.CreatedAt = currTime
} }
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
{{- if eq $col.Name "updated_at" -}} {{- if eq $col.Name "updated_at" -}}
{{- if $col.Nullable}} {{- if $col.Nullable}}
o.UpdatedAt.Time = currTime o.UpdatedAt.Time = currTime
o.UpdatedAt.Valid = true o.UpdatedAt.Valid = true
{{- else}} {{- else}}
o.UpdatedAt = currTime o.UpdatedAt = currTime
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
{{end}} {{end}}
{{end}} {{end}}
{{- end}} {{- end}}
{{end -}} {{end -}}

View file

@ -1,12 +1,19 @@
var dialect = queries.Dialect{
LQ: 0x{{printf "%x" .Dialect.LQ}},
RQ: 0x{{printf "%x" .Dialect.RQ}},
IndexPlaceholders: {{.Dialect.IndexPlaceholders}},
}
// NewQueryG initializes a new Query using the passed in QueryMods // NewQueryG initializes a new Query using the passed in QueryMods
func NewQueryG(mods ...qm.QueryMod) *boil.Query { func NewQueryG(mods ...qm.QueryMod) *queries.Query {
return NewQuery(boil.GetDB(), mods...) return NewQuery(boil.GetDB(), mods...)
} }
// NewQuery initializes a new Query using the passed in QueryMods // NewQuery initializes a new Query using the passed in QueryMods
func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *boil.Query { func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *queries.Query {
q := &boil.Query{} q := &queries.Query{}
boil.SetExecutor(q, exec) queries.SetExecutor(q, exec)
queries.SetDialect(q, &dialect)
qm.Apply(q, mods...) qm.Apply(q, mods...)
return q return q

View file

@ -6,33 +6,32 @@ type M map[string]interface{}
// fails or there was a primary key configuration that was not resolvable. // fails or there was a primary key configuration that was not resolvable.
var ErrSyncFail = errors.New("{{.PkgName}}: failed to synchronize data after insert") var ErrSyncFail = errors.New("{{.PkgName}}: failed to synchronize data after insert")
type insertCache struct{ type insertCache struct {
query string query string
retQuery string retQuery string
valueMapping []uint64 valueMapping []uint64
retMapping []uint64 retMapping []uint64
} }
type updateCache struct{ type updateCache struct {
query string query string
valueMapping []uint64 valueMapping []uint64
} }
func makeCacheKey(wl, nzDefaults []string) string { func makeCacheKey(wl, nzDefaults []string) string {
buf := strmangle.GetBuffer() buf := strmangle.GetBuffer()
for _, w := range wl { for _, w := range wl {
buf.WriteString(w) buf.WriteString(w)
} }
if len(nzDefaults) != 0 { if len(nzDefaults) != 0 {
buf.WriteByte('.') buf.WriteByte('.')
} }
for _, nz := range nzDefaults { for _, nz := range nzDefaults {
buf.WriteString(nz) buf.WriteString(nz)
} }
str := buf.String() str := buf.String()
strmangle.PutBuffer(buf) strmangle.PutBuffer(buf)
return str return str
} }

View file

@ -3,11 +3,11 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}(t *testing.T) { func test{{$tableNamePlural}}(t *testing.T) {
t.Parallel() t.Parallel()
query := {{$tableNamePlural}}(nil) query := {{$tableNamePlural}}(nil)
if query.Query == nil { if query.Query == nil {
t.Error("expected a query, got nothing") t.Error("expected a query, got nothing")
} }
} }

View file

@ -3,93 +3,93 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Delete(t *testing.T) { func test{{$tableNamePlural}}Delete(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
if err = {{$varNameSingular}}.Delete(tx); err != nil { if err = {{$varNameSingular}}.Delete(tx); err != nil {
t.Error(err) t.Error(err)
} }
count, err := {{$tableNamePlural}}(tx).Count() count, err := {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 0 { if count != 0 {
t.Error("want zero records, got:", count) t.Error("want zero records, got:", count)
} }
} }
func test{{$tableNamePlural}}QueryDeleteAll(t *testing.T) { func test{{$tableNamePlural}}QueryDeleteAll(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
if err = {{$tableNamePlural}}(tx).DeleteAll(); err != nil { if err = {{$tableNamePlural}}(tx).DeleteAll(); err != nil {
t.Error(err) t.Error(err)
} }
count, err := {{$tableNamePlural}}(tx).Count() count, err := {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 0 { if count != 0 {
t.Error("want zero records, got:", count) t.Error("want zero records, got:", count)
} }
} }
func test{{$tableNamePlural}}SliceDeleteAll(t *testing.T) { func test{{$tableNamePlural}}SliceDeleteAll(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}}
if err = slice.DeleteAll(tx); err != nil { if err = slice.DeleteAll(tx); err != nil {
t.Error(err) t.Error(err)
} }
count, err := {{$tableNamePlural}}(tx).Count() count, err := {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 0 { if count != 0 {
t.Error("want zero records, got:", count) t.Error("want zero records, got:", count)
} }
} }

View file

@ -3,27 +3,27 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Exists(t *testing.T) { func test{{$tableNamePlural}}Exists(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
{{$pkeyArgs := .Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", " -}} {{$pkeyArgs := .Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", " -}}
e, err := {{$tableNameSingular}}Exists(tx, {{$pkeyArgs}}) e, err := {{$tableNameSingular}}Exists(tx, {{$pkeyArgs}})
if err != nil { if err != nil {
t.Errorf("Unable to check if {{$tableNameSingular}} exists: %s", err) t.Errorf("Unable to check if {{$tableNameSingular}} exists: %s", err)
} }
if e != true { if e != true {
t.Errorf("Expected {{$tableNameSingular}}ExistsG to return true, but got false.") t.Errorf("Expected {{$tableNameSingular}}ExistsG to return true, but got false.")
} }
} }

View file

@ -3,27 +3,27 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Find(t *testing.T) { func test{{$tableNamePlural}}Find(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
{{$varNameSingular}}Found, err := Find{{$tableNameSingular}}(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", "}}) {{$varNameSingular}}Found, err := Find{{$tableNameSingular}}(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", "}})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if {{$varNameSingular}}Found == nil { if {{$varNameSingular}}Found == nil {
t.Error("want a record, got nil") t.Error("want a record, got nil")
} }
} }

View file

@ -3,111 +3,111 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Bind(t *testing.T) { func test{{$tableNamePlural}}Bind(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
if err = {{$tableNamePlural}}(tx).Bind({{$varNameSingular}}); err != nil { if err = {{$tableNamePlural}}(tx).Bind({{$varNameSingular}}); err != nil {
t.Error(err) t.Error(err)
} }
} }
func test{{$tableNamePlural}}One(t *testing.T) { func test{{$tableNamePlural}}One(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
if x, err := {{$tableNamePlural}}(tx).One(); err != nil { if x, err := {{$tableNamePlural}}(tx).One(); err != nil {
t.Error(err) t.Error(err)
} else if x == nil { } else if x == nil {
t.Error("expected to get a non nil record") t.Error("expected to get a non nil record")
} }
} }
func test{{$tableNamePlural}}All(t *testing.T) { func test{{$tableNamePlural}}All(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}}One := &{{$tableNameSingular}}{} {{$varNameSingular}}One := &{{$tableNameSingular}}{}
{{$varNameSingular}}Two := &{{$tableNameSingular}}{} {{$varNameSingular}}Two := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}One.Insert(tx); err != nil { if err = {{$varNameSingular}}One.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
if err = {{$varNameSingular}}Two.Insert(tx); err != nil { if err = {{$varNameSingular}}Two.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
slice, err := {{$tableNamePlural}}(tx).All() slice, err := {{$tableNamePlural}}(tx).All()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if len(slice) != 2 { if len(slice) != 2 {
t.Error("want 2 records, got:", len(slice)) t.Error("want 2 records, got:", len(slice))
} }
} }
func test{{$tableNamePlural}}Count(t *testing.T) { func test{{$tableNamePlural}}Count(t *testing.T) {
t.Parallel() t.Parallel()
var err error var err error
seed := randomize.NewSeed() seed := randomize.NewSeed()
{{$varNameSingular}}One := &{{$tableNameSingular}}{} {{$varNameSingular}}One := &{{$tableNameSingular}}{}
{{$varNameSingular}}Two := &{{$tableNameSingular}}{} {{$varNameSingular}}Two := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}One.Insert(tx); err != nil { if err = {{$varNameSingular}}One.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
if err = {{$varNameSingular}}Two.Insert(tx); err != nil { if err = {{$varNameSingular}}Two.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
count, err := {{$tableNamePlural}}(tx).Count() count, err := {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 2 { if count != 2 {
t.Error("want 2 records, got:", count) t.Error("want 2 records, got:", count)
} }
} }

View file

@ -5,57 +5,57 @@
var {{$varNameSingular}}DBTypes = map[string]string{{"{"}}{{.Table.Columns | columnDBTypes | makeStringMap}}{{"}"}} var {{$varNameSingular}}DBTypes = map[string]string{{"{"}}{{.Table.Columns | columnDBTypes | makeStringMap}}{{"}"}}
func test{{$tableNamePlural}}InPrimaryKeyArgs(t *testing.T) { func test{{$tableNamePlural}}InPrimaryKeyArgs(t *testing.T) {
t.Parallel() t.Parallel()
var err error var err error
var o {{$tableNameSingular}} var o {{$tableNameSingular}}
o = {{$tableNameSingular}}{} o = {{$tableNameSingular}}{}
seed := randomize.NewSeed() seed := randomize.NewSeed()
if err = randomize.Struct(seed, &o, {{$varNameSingular}}DBTypes, true); err != nil { if err = randomize.Struct(seed, &o, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Could not randomize struct: %s", err) t.Errorf("Could not randomize struct: %s", err)
} }
args := o.inPrimaryKeyArgs() args := o.inPrimaryKeyArgs()
if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) { if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) {
t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns), len(args)) t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns), len(args))
} }
{{range $key, $value := .Table.PKey.Columns}} {{range $key, $value := .Table.PKey.Columns}}
if o.{{titleCase $value}} != args[{{$key}}] { if o.{{titleCase $value}} != args[{{$key}}] {
t.Errorf("Expected args[{{$key}}] to be value of o.{{titleCase $value}}, but got %#v", args[{{$key}}]) t.Errorf("Expected args[{{$key}}] to be value of o.{{titleCase $value}}, but got %#v", args[{{$key}}])
} }
{{- end}} {{- end}}
} }
func test{{$tableNamePlural}}SliceInPrimaryKeyArgs(t *testing.T) { func test{{$tableNamePlural}}SliceInPrimaryKeyArgs(t *testing.T) {
t.Parallel() t.Parallel()
var err error var err error
o := make({{$tableNameSingular}}Slice, 3) o := make({{$tableNameSingular}}Slice, 3)
seed := randomize.NewSeed() seed := randomize.NewSeed()
for i := range o { for i := range o {
o[i] = &{{$tableNameSingular}}{} o[i] = &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, o[i], {{$varNameSingular}}DBTypes, true); err != nil { if err = randomize.Struct(seed, o[i], {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Could not randomize struct: %s", err) t.Errorf("Could not randomize struct: %s", err)
} }
} }
args := o.inPrimaryKeyArgs() args := o.inPrimaryKeyArgs()
if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) * 3 { if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) * 3 {
t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns) * 3, len(args)) t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns) * 3, len(args))
} }
argC := 0 argC := 0
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
{{range $key, $value := .Table.PKey.Columns}} {{range $key, $value := .Table.PKey.Columns}}
if o[i].{{titleCase $value}} != args[argC] { if o[i].{{titleCase $value}} != args[argC] {
t.Errorf("Expected args[%d] to be value of o.{{titleCase $value}}, but got %#v", i, args[i]) t.Errorf("Expected args[%d] to be value of o.{{titleCase $value}}, but got %#v", i, args[i])
} }
argC++ argC++
{{- end}} {{- end}}
} }
} }

View file

@ -4,142 +4,142 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func {{$varNameSingular}}BeforeInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { func {{$varNameSingular}}BeforeInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{} *o = {{$tableNameSingular}}{}
return nil return nil
} }
func {{$varNameSingular}}AfterInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { func {{$varNameSingular}}AfterInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{} *o = {{$tableNameSingular}}{}
return nil return nil
} }
func {{$varNameSingular}}AfterSelectHook(e boil.Executor, o *{{$tableNameSingular}}) error { func {{$varNameSingular}}AfterSelectHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{} *o = {{$tableNameSingular}}{}
return nil return nil
} }
func {{$varNameSingular}}BeforeUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error { func {{$varNameSingular}}BeforeUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{} *o = {{$tableNameSingular}}{}
return nil return nil
} }
func {{$varNameSingular}}AfterUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error { func {{$varNameSingular}}AfterUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{} *o = {{$tableNameSingular}}{}
return nil return nil
} }
func {{$varNameSingular}}BeforeDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error { func {{$varNameSingular}}BeforeDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{} *o = {{$tableNameSingular}}{}
return nil return nil
} }
func {{$varNameSingular}}AfterDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error { func {{$varNameSingular}}AfterDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{} *o = {{$tableNameSingular}}{}
return nil return nil
} }
func {{$varNameSingular}}BeforeUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { func {{$varNameSingular}}BeforeUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{} *o = {{$tableNameSingular}}{}
return nil return nil
} }
func {{$varNameSingular}}AfterUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { func {{$varNameSingular}}AfterUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{} *o = {{$tableNameSingular}}{}
return nil return nil
} }
func test{{$tableNamePlural}}Hooks(t *testing.T) { func test{{$tableNamePlural}}Hooks(t *testing.T) {
t.Parallel() t.Parallel()
var err error var err error
empty := &{{$tableNameSingular}}{} empty := &{{$tableNameSingular}}{}
o := &{{$tableNameSingular}}{} o := &{{$tableNameSingular}}{}
seed := randomize.NewSeed() seed := randomize.NewSeed()
if err = randomize.Struct(seed, o, {{$varNameSingular}}DBTypes, false); err != nil { if err = randomize.Struct(seed, o, {{$varNameSingular}}DBTypes, false); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} object: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} object: %s", err)
} }
Add{{$tableNameSingular}}Hook(boil.BeforeInsertHook, {{$varNameSingular}}BeforeInsertHook) Add{{$tableNameSingular}}Hook(boil.BeforeInsertHook, {{$varNameSingular}}BeforeInsertHook)
if err = o.doBeforeInsertHooks(nil); err != nil { if err = o.doBeforeInsertHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeInsertHooks: %s", err) t.Errorf("Unable to execute doBeforeInsertHooks: %s", err)
} }
if !reflect.DeepEqual(o, empty) { if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeInsertHook function to empty object, but got: %#v", o) t.Errorf("Expected BeforeInsertHook function to empty object, but got: %#v", o)
} }
{{$varNameSingular}}BeforeInsertHooks = []{{$tableNameSingular}}Hook{} {{$varNameSingular}}BeforeInsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterInsertHook, {{$varNameSingular}}AfterInsertHook) Add{{$tableNameSingular}}Hook(boil.AfterInsertHook, {{$varNameSingular}}AfterInsertHook)
if err = o.doAfterInsertHooks(nil); err != nil { if err = o.doAfterInsertHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterInsertHooks: %s", err) t.Errorf("Unable to execute doAfterInsertHooks: %s", err)
} }
if !reflect.DeepEqual(o, empty) { if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterInsertHook function to empty object, but got: %#v", o) t.Errorf("Expected AfterInsertHook function to empty object, but got: %#v", o)
} }
{{$varNameSingular}}AfterInsertHooks = []{{$tableNameSingular}}Hook{} {{$varNameSingular}}AfterInsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterSelectHook, {{$varNameSingular}}AfterSelectHook) Add{{$tableNameSingular}}Hook(boil.AfterSelectHook, {{$varNameSingular}}AfterSelectHook)
if err = o.doAfterSelectHooks(nil); err != nil { if err = o.doAfterSelectHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterSelectHooks: %s", err) t.Errorf("Unable to execute doAfterSelectHooks: %s", err)
} }
if !reflect.DeepEqual(o, empty) { if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterSelectHook function to empty object, but got: %#v", o) t.Errorf("Expected AfterSelectHook function to empty object, but got: %#v", o)
} }
{{$varNameSingular}}AfterSelectHooks = []{{$tableNameSingular}}Hook{} {{$varNameSingular}}AfterSelectHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeUpdateHook, {{$varNameSingular}}BeforeUpdateHook) Add{{$tableNameSingular}}Hook(boil.BeforeUpdateHook, {{$varNameSingular}}BeforeUpdateHook)
if err = o.doBeforeUpdateHooks(nil); err != nil { if err = o.doBeforeUpdateHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeUpdateHooks: %s", err) t.Errorf("Unable to execute doBeforeUpdateHooks: %s", err)
} }
if !reflect.DeepEqual(o, empty) { if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeUpdateHook function to empty object, but got: %#v", o) t.Errorf("Expected BeforeUpdateHook function to empty object, but got: %#v", o)
} }
{{$varNameSingular}}BeforeUpdateHooks = []{{$tableNameSingular}}Hook{} {{$varNameSingular}}BeforeUpdateHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterUpdateHook, {{$varNameSingular}}AfterUpdateHook) Add{{$tableNameSingular}}Hook(boil.AfterUpdateHook, {{$varNameSingular}}AfterUpdateHook)
if err = o.doAfterUpdateHooks(nil); err != nil { if err = o.doAfterUpdateHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterUpdateHooks: %s", err) t.Errorf("Unable to execute doAfterUpdateHooks: %s", err)
} }
if !reflect.DeepEqual(o, empty) { if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterUpdateHook function to empty object, but got: %#v", o) t.Errorf("Expected AfterUpdateHook function to empty object, but got: %#v", o)
} }
{{$varNameSingular}}AfterUpdateHooks = []{{$tableNameSingular}}Hook{} {{$varNameSingular}}AfterUpdateHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeDeleteHook, {{$varNameSingular}}BeforeDeleteHook) Add{{$tableNameSingular}}Hook(boil.BeforeDeleteHook, {{$varNameSingular}}BeforeDeleteHook)
if err = o.doBeforeDeleteHooks(nil); err != nil { if err = o.doBeforeDeleteHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeDeleteHooks: %s", err) t.Errorf("Unable to execute doBeforeDeleteHooks: %s", err)
} }
if !reflect.DeepEqual(o, empty) { if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeDeleteHook function to empty object, but got: %#v", o) t.Errorf("Expected BeforeDeleteHook function to empty object, but got: %#v", o)
} }
{{$varNameSingular}}BeforeDeleteHooks = []{{$tableNameSingular}}Hook{} {{$varNameSingular}}BeforeDeleteHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterDeleteHook, {{$varNameSingular}}AfterDeleteHook) Add{{$tableNameSingular}}Hook(boil.AfterDeleteHook, {{$varNameSingular}}AfterDeleteHook)
if err = o.doAfterDeleteHooks(nil); err != nil { if err = o.doAfterDeleteHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterDeleteHooks: %s", err) t.Errorf("Unable to execute doAfterDeleteHooks: %s", err)
} }
if !reflect.DeepEqual(o, empty) { if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterDeleteHook function to empty object, but got: %#v", o) t.Errorf("Expected AfterDeleteHook function to empty object, but got: %#v", o)
} }
{{$varNameSingular}}AfterDeleteHooks = []{{$tableNameSingular}}Hook{} {{$varNameSingular}}AfterDeleteHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeUpsertHook, {{$varNameSingular}}BeforeUpsertHook) Add{{$tableNameSingular}}Hook(boil.BeforeUpsertHook, {{$varNameSingular}}BeforeUpsertHook)
if err = o.doBeforeUpsertHooks(nil); err != nil { if err = o.doBeforeUpsertHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeUpsertHooks: %s", err) t.Errorf("Unable to execute doBeforeUpsertHooks: %s", err)
} }
if !reflect.DeepEqual(o, empty) { if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeUpsertHook function to empty object, but got: %#v", o) t.Errorf("Expected BeforeUpsertHook function to empty object, but got: %#v", o)
} }
{{$varNameSingular}}BeforeUpsertHooks = []{{$tableNameSingular}}Hook{} {{$varNameSingular}}BeforeUpsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterUpsertHook, {{$varNameSingular}}AfterUpsertHook) Add{{$tableNameSingular}}Hook(boil.AfterUpsertHook, {{$varNameSingular}}AfterUpsertHook)
if err = o.doAfterUpsertHooks(nil); err != nil { if err = o.doAfterUpsertHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterUpsertHooks: %s", err) t.Errorf("Unable to execute doAfterUpsertHooks: %s", err)
} }
if !reflect.DeepEqual(o, empty) { if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterUpsertHook function to empty object, but got: %#v", o) t.Errorf("Expected AfterUpsertHook function to empty object, but got: %#v", o)
} }
{{$varNameSingular}}AfterUpsertHooks = []{{$tableNameSingular}}Hook{} {{$varNameSingular}}AfterUpsertHooks = []{{$tableNameSingular}}Hook{}
} }
{{- end}} {{- end}}

View file

@ -4,53 +4,53 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $parent := . -}} {{- $parent := . -}}
func test{{$tableNamePlural}}Insert(t *testing.T) { func test{{$tableNamePlural}}Insert(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
count, err := {{$tableNamePlural}}(tx).Count() count, err := {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 1 { if count != 1 {
t.Error("want one record, got:", count) t.Error("want one record, got:", count)
} }
} }
func test{{$tableNamePlural}}InsertWhitelist(t *testing.T) { func test{{$tableNamePlural}}InsertWhitelist(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx, {{$varNameSingular}}Columns...); err != nil { if err = {{$varNameSingular}}.Insert(tx, {{$varNameSingular}}Columns...); err != nil {
t.Error(err) t.Error(err)
} }
count, err := {{$tableNamePlural}}(tx).Count() count, err := {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 1 { if count != 1 {
t.Error("want one record, got:", count) t.Error("want one record, got:", count)
} }
} }

View file

@ -0,0 +1,166 @@
type mysqlTester struct {
dbConn *sql.DB
dbName string
host string
user string
pass string
sslmode string
port int
optionFile string
testDBName string
}
func init() {
dbMain = &mysqlTester{}
}
func (m *mysqlTester) setup() error {
var err error
m.dbName = viper.GetString("mysql.dbname")
m.host = viper.GetString("mysql.host")
m.user = viper.GetString("mysql.user")
m.pass = viper.GetString("mysql.pass")
m.port = viper.GetInt("mysql.port")
m.sslmode = viper.GetString("mysql.sslmode")
// Create a randomized db name.
m.testDBName = randomize.StableDBName(m.dbName)
if err = m.makeOptionFile(); err != nil {
return errors.Wrap(err, "couldn't make option file")
}
if err = m.dropTestDB(); err != nil {
return err
}
if err = m.createTestDB(); err != nil {
return err
}
dumpCmd := exec.Command("mysqldump", m.defaultsFile(), "--no-data", m.dbName)
createCmd := exec.Command("mysql", m.defaultsFile(), "--database", m.testDBName)
r, w := io.Pipe()
dumpCmd.Stdout = w
createCmd.Stdin = newFKeyDestroyer(rgxMySQLkey, r)
if err = dumpCmd.Start(); err != nil {
return errors.Wrap(err, "failed to start mysqldump command")
}
if err = createCmd.Start(); err != nil {
return errors.Wrap(err, "failed to start mysql command")
}
if err = dumpCmd.Wait(); err != nil {
fmt.Println(err)
return errors.Wrap(err, "failed to wait for mysqldump command")
}
w.Close() // After dumpCmd is done, close the write end of the pipe
if err = createCmd.Wait(); err != nil {
fmt.Println(err)
return errors.Wrap(err, "failed to wait for mysql command")
}
return nil
}
func (m *mysqlTester) sslMode(mode string) string {
switch mode {
case "true":
return "REQUIRED"
case "false":
return "DISABLED"
default:
return "PREFERRED"
}
}
func (m *mysqlTester) defaultsFile() string {
return fmt.Sprintf("--defaults-file=%s", m.optionFile)
}
func (m *mysqlTester) makeOptionFile() error {
tmp, err := ioutil.TempFile("", "optionfile")
if err != nil {
return errors.Wrap(err, "failed to create option file")
}
fmt.Fprintln(tmp, "[client]")
fmt.Fprintf(tmp, "host=%s\n", m.host)
fmt.Fprintf(tmp, "port=%d\n", m.port)
fmt.Fprintf(tmp, "user=%s\n", m.user)
fmt.Fprintf(tmp, "password=%s\n", m.pass)
fmt.Fprintf(tmp, "ssl-mode=%s\n", m.sslMode(m.sslmode))
fmt.Fprintln(tmp, "[mysqldump]")
fmt.Fprintf(tmp, "host=%s\n", m.host)
fmt.Fprintf(tmp, "port=%d\n", m.port)
fmt.Fprintf(tmp, "user=%s\n", m.user)
fmt.Fprintf(tmp, "password=%s\n", m.pass)
fmt.Fprintf(tmp, "ssl-mode=%s\n", m.sslMode(m.sslmode))
m.optionFile = tmp.Name()
return tmp.Close()
}
func (m *mysqlTester) createTestDB() error {
sql := fmt.Sprintf("create database %s;", m.testDBName)
return m.runCmd(sql, "mysql")
}
func (m *mysqlTester) dropTestDB() error {
sql := fmt.Sprintf("drop database if exists %s;", m.testDBName)
return m.runCmd(sql, "mysql")
}
func (m *mysqlTester) teardown() error {
if m.dbConn != nil {
m.dbConn.Close()
}
if err := m.dropTestDB(); err != nil {
return err
}
return os.Remove(m.optionFile)
}
func (m *mysqlTester) runCmd(stdin, command string, args ...string) error {
args = append([]string{m.defaultsFile()}, args...)
cmd := exec.Command(command, args...)
cmd.Stdin = strings.NewReader(stdin)
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
cmd.Stdout = stdout
cmd.Stderr = stderr
if err := cmd.Run(); err != nil {
fmt.Println("failed running:", command, args)
fmt.Println(stdout.String())
fmt.Println(stderr.String())
return err
}
return nil
}
func (m *mysqlTester) conn() (*sql.DB, error) {
if m.dbConn != nil {
return m.dbConn, nil
}
var err error
m.dbConn, err = sql.Open("mysql", drivers.MySQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode))
if err != nil {
return nil, err
}
return m.dbConn, nil
}

View file

@ -1,275 +1,166 @@
type PostgresCfg struct { type pgTester struct {
User string `toml:"user"` dbConn *sql.DB
Pass string `toml:"pass"`
Host string `toml:"host"` dbName string
Port int `toml:"port"` host string
DBName string `toml:"dbname"` user string
SSLMode string `toml:"sslmode"` pass string
sslmode string
port int
pgPassFile string
testDBName string
} }
type Config struct { func init() {
Postgres PostgresCfg `toml:"postgres"` dbMain = &pgTester{}
}
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
var err error
if err = setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-2)
}
var code int
if err = disableTriggers(); err != nil {
fmt.Println("Unable to disable triggers:", err)
} else {
boil.SetDB(dbConn)
code = m.Run()
}
if err = teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-3)
}
os.Exit(code)
}
// disableTriggers is used to disable foreign key constraints for every table.
// If this is not used we cannot test inserts due to foreign key constraint errors.
func disableTriggers() error {
var stmts []string
{{range .Tables}}
stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`)
{{- end}}
if len(stmts) == 0 {
return nil
}
var err error
for _, s := range stmts {
_, err = dbConn.Exec(s)
if err != nil {
return err
}
}
return nil
}
// teardown executes cleanup tasks when the tests finish running
func teardown() error {
err := dropTestDB()
return err
}
// dropTestDB switches its connection to the template1 database temporarily
// so that it can drop the test database without causing "in use" conflicts.
// The template1 database should be present on all default postgres installations.
func dropTestDB() error {
var err error
if dbConn != nil {
if err = dbConn.Close(); err != nil {
return err
}
}
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil {
return err
}
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
if err != nil {
return err
}
return dbConn.Close()
}
// DBConnect connects to a database and returns the handle.
func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) {
connStr := drivers.BuildQueryString(user, pass, dbname, host, port, sslmode)
return sql.Open("postgres", connStr)
} }
// setup dumps the database schema and imports it into a temporary randomly // setup dumps the database schema and imports it into a temporary randomly
// generated test database so that tests can be run against it using the // generated test database so that tests can be run against it using the
// generated sqlboiler ORM package. // generated sqlboiler ORM package.
func setup() error { func (p *pgTester) setup() error {
var err error var err error
// Initialize Viper and load the config file p.dbName = viper.GetString("postgres.dbname")
err = InitViper() p.host = viper.GetString("postgres.host")
if err != nil { p.user = viper.GetString("postgres.user")
return errors.Wrap(err, "Unable to load config file") p.pass = viper.GetString("postgres.pass")
} p.port = viper.GetInt("postgres.port")
p.sslmode = viper.GetString("postgres.sslmode")
// Create a randomized db name.
p.testDBName = randomize.StableDBName(p.dbName)
viper.SetDefault("postgres.sslmode", "require") if err = p.makePGPassFile(); err != nil {
viper.SetDefault("postgres.port", "5432") return err
}
// Create a randomized test configuration object. if err = p.dropTestDB(); err != nil {
testCfg.Postgres.Host = viper.GetString("postgres.host") return err
testCfg.Postgres.Port = viper.GetInt("postgres.port") }
testCfg.Postgres.User = viper.GetString("postgres.user") if err = p.createTestDB(); err != nil {
testCfg.Postgres.Pass = viper.GetString("postgres.pass") return err
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname")) }
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
// Set the default SSLMode value dumpCmd := exec.Command("pg_dump", "--schema-only", p.dbName)
if testCfg.Postgres.SSLMode == "" { dumpCmd.Env = append(os.Environ(), p.pgEnv()...)
viper.Set("postgres.sslmode", "require") createCmd := exec.Command("psql", p.testDBName)
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode") createCmd.Env = append(os.Environ(), p.pgEnv()...)
}
err = vala.BeginValidation().Validate( r, w := io.Pipe()
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"), dumpCmd.Stdout = w
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"), createCmd.Stdin = newFKeyDestroyer(rgxPGFkey, r)
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
).Check()
if err != nil { if err = dumpCmd.Start(); err != nil {
return errors.Wrap(err, "Unable to load testCfg") return errors.Wrap(err, "failed to start pg_dump command")
} }
if err = createCmd.Start(); err != nil {
return errors.Wrap(err, "failed to start psql command")
}
err = dropTestDB() if err = dumpCmd.Wait(); err != nil {
if err != nil { fmt.Println(err)
fmt.Printf("%#v\n", err) return errors.Wrap(err, "failed to wait for pg_dump command")
return err }
}
fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema") w.Close() // After dumpCmd is done, close the write end of the pipe
if err != nil {
return errors.Wrap(err, "Unable to create sqlboiler schema tmp file")
}
defer os.Remove(fhSchema.Name())
passDir, err := ioutil.TempDir(os.TempDir(), "sqlboiler") if err = createCmd.Wait(); err != nil {
if err != nil { fmt.Println(err)
return errors.Wrap(err, "Unable to create sqlboiler tmp dir for postgres pw file") return errors.Wrap(err, "failed to wait for psql command")
} }
defer os.RemoveAll(passDir)
// Write the postgres user password to a tmp file for pg_dump return nil
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
viper.GetString("postgres.host"),
viper.GetInt("postgres.port"),
viper.GetString("postgres.dbname"),
viper.GetString("postgres.user"),
))
if pw := viper.GetString("postgres.pass"); len(pw) > 0 {
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw))
}
passFilePath := filepath.Join(passDir, "pwfile")
err = ioutil.WriteFile(passFilePath, pwBytes, 0600)
if err != nil {
return errors.Wrap(err, "Unable to create pwfile in passDir")
}
// The params for the pg_dump command to dump the database schema
params := []string{
fmt.Sprintf(`--host=%s`, viper.GetString("postgres.host")),
fmt.Sprintf(`--port=%d`, viper.GetInt("postgres.port")),
fmt.Sprintf(`--username=%s`, viper.GetString("postgres.user")),
"--schema-only",
viper.GetString("postgres.dbname"),
}
// Dump the database schema into the sqlboilerschema tmp file
errBuf := bytes.Buffer{}
cmd := exec.Command("pg_dump", params...)
cmd.Stderr = &errBuf
cmd.Stdout = fhSchema
cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, passFilePath))
if err := cmd.Run(); err != nil {
fmt.Printf("pg_dump exec failed: %s\n\n%s\n", err, errBuf.String())
return err
}
dbConn, err = DBConnect(
viper.GetString("postgres.user"),
viper.GetString("postgres.pass"),
viper.GetString("postgres.dbname"),
viper.GetString("postgres.host"),
viper.GetInt("postgres.port"),
viper.GetString("postgres.sslmode"),
)
if err != nil {
return err
}
// Create the randomly generated database
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName))
if err != nil {
return err
}
// Close the old connection so we can reconnect to the test database
if err = dbConn.Close(); err != nil {
return err
}
// Connect to the generated test db
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil {
return err
}
// Write the test config credentials to a tmp file for pg_dump
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
testCfg.Postgres.Host,
testCfg.Postgres.Port,
testCfg.Postgres.DBName,
testCfg.Postgres.User,
))
if len(testCfg.Postgres.Pass) > 0 {
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass))
}
testPassFilePath := passDir + "/testpwfile"
err = ioutil.WriteFile(testPassFilePath, testPwBytes, 0600)
if err != nil {
return errors.Wrapf(err, "Unable to create testpwfile in passDir")
}
// The params for the psql schema import command
params = []string{
fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName),
fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host),
fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port),
fmt.Sprintf(`--username=%s`, testCfg.Postgres.User),
fmt.Sprintf(`--file=%s`, fhSchema.Name()),
}
// Import the database schema into the generated database.
// It is now ready to be used by the generated ORM package for testing.
outBuf := bytes.Buffer{}
cmd = exec.Command("psql", params...)
cmd.Stderr = &errBuf
cmd.Stdout = &outBuf
cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, testPassFilePath))
if err = cmd.Run(); err != nil {
fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String())
}
return nil
} }
func (p *pgTester) runCmd(stdin, command string, args ...string) error {
cmd := exec.Command(command, args...)
cmd.Env = append(os.Environ(), p.pgEnv()...)
if len(stdin) != 0 {
cmd.Stdin = strings.NewReader(stdin)
}
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
cmd.Stdout = stdout
cmd.Stderr = stderr
if err := cmd.Run(); err != nil {
fmt.Println("failed running:", command, args)
fmt.Println(stdout.String())
fmt.Println(stderr.String())
return err
}
return nil
}
func (p *pgTester) pgEnv() []string {
return []string{
fmt.Sprintf("PGHOST=%s", p.host),
fmt.Sprintf("PGPORT=%d", p.port),
fmt.Sprintf("PGUSER=%s", p.user),
fmt.Sprintf("PGPASS=%s", p.pgPassFile),
}
}
func (p *pgTester) makePGPassFile() error {
tmp, err := ioutil.TempFile("", "pgpass")
if err != nil {
return errors.Wrap(err, "failed to create option file")
}
fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.dbName, p.user)
if len(p.pass) != 0 {
fmt.Fprintf(tmp, ":%s", p.pass)
}
fmt.Fprintln(tmp)
fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user)
if len(p.pass) != 0 {
fmt.Fprintf(tmp, ":%s", p.pass)
}
fmt.Fprintln(tmp)
p.pgPassFile = tmp.Name()
return tmp.Close()
}
func (p *pgTester) createTestDB() error {
return p.runCmd("", "createdb", p.testDBName)
}
func (p *pgTester) dropTestDB() error {
return p.runCmd("", "dropdb", "--if-exists", p.testDBName)
}
// teardown executes cleanup tasks when the tests finish running
func (p *pgTester) teardown() error {
var err error
if err = p.dbConn.Close(); err != nil {
return err
}
p.dbConn = nil
if err = p.dropTestDB(); err != nil {
return err
}
return os.Remove(p.pgPassFile)
}
func (p *pgTester) conn() (*sql.DB, error) {
if p.dbConn != nil {
return p.dbConn, nil
}
var err error
p.dbConn, err = sql.Open("postgres", drivers.PostgresBuildQueryString(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode))
if err != nil {
return nil, err
}
return p.dbConn, nil
}

View file

@ -1,98 +1,98 @@
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . }} {{- $dot := . }}
{{- $table := .Table }} {{- $table := .Table }}
{{- range .Table.ToManyRelationships -}} {{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} {{- template "relationship_to_one_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}}
{{- else -}} {{- else -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}} {{- $rel := textsFromRelationship $dot.Tables $table . -}}
func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) {
var err error var err error
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
var a {{$rel.LocalTable.NameGo}} var a {{$rel.LocalTable.NameGo}}
var b, c {{$rel.ForeignTable.NameGo}} var b, c {{$rel.ForeignTable.NameGo}}
if err := a.Insert(tx); err != nil { if err := a.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
seed := randomize.NewSeed() seed := randomize.NewSeed()
randomize.Struct(seed, &b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") randomize.Struct(seed, &b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}")
randomize.Struct(seed, &c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") randomize.Struct(seed, &c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}")
{{if .Nullable -}} {{if .Nullable -}}
a.{{.Column | titleCase}}.Valid = true a.{{.Column | titleCase}}.Valid = true
{{- end}} {{- end}}
{{- if .ForeignColumnNullable -}} {{- if .ForeignColumnNullable -}}
b.{{.ForeignColumn | titleCase}}.Valid = true b.{{.ForeignColumn | titleCase}}.Valid = true
c.{{.ForeignColumn | titleCase}}.Valid = true c.{{.ForeignColumn | titleCase}}.Valid = true
{{- end}} {{- end}}
{{if not .ToJoinTable -}} {{if not .ToJoinTable -}}
b.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} b.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}}
c.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} c.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}}
{{- end}} {{- end}}
if err = b.Insert(tx); err != nil { if err = b.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = c.Insert(tx); err != nil { if err = c.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
{{if .ToJoinTable -}} {{if .ToJoinTable -}}
_, err = tx.Exec(`insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = tx.Exec(`insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
{{end}} {{end}}
{{$varname := .ForeignTable | singular | camelCase -}} {{$varname := .ForeignTable | singular | camelCase -}}
{{$varname}}, err := a.{{$rel.Function.Name}}(tx).All() {{$varname}}, err := a.{{$rel.Function.Name}}(tx).All()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
bFound, cFound := false, false bFound, cFound := false, false
for _, v := range {{$varname}} { for _, v := range {{$varname}} {
if v.{{$rel.Function.ForeignAssignment}} == b.{{$rel.Function.ForeignAssignment}} { if v.{{$rel.Function.ForeignAssignment}} == b.{{$rel.Function.ForeignAssignment}} {
bFound = true bFound = true
} }
if v.{{$rel.Function.ForeignAssignment}} == c.{{$rel.Function.ForeignAssignment}} { if v.{{$rel.Function.ForeignAssignment}} == c.{{$rel.Function.ForeignAssignment}} {
cFound = true cFound = true
} }
} }
if !bFound { if !bFound {
t.Error("expected to find b") t.Error("expected to find b")
} }
if !cFound { if !cFound {
t.Error("expected to find c") t.Error("expected to find c")
} }
slice := {{$rel.LocalTable.NameGo}}Slice{&a} slice := {{$rel.LocalTable.NameGo}}Slice{&a}
if err = a.L.Load{{$rel.Function.Name}}(tx, false, &slice); err != nil { if err = a.L.Load{{$rel.Function.Name}}(tx, false, &slice); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if got := len(a.R.{{$rel.Function.Name}}); got != 2 { if got := len(a.R.{{$rel.Function.Name}}); got != 2 {
t.Error("number of eager loaded records wrong, got:", got) t.Error("number of eager loaded records wrong, got:", got)
} }
a.R.{{$rel.Function.Name}} = nil a.R.{{$rel.Function.Name}} = nil
if err = a.L.Load{{$rel.Function.Name}}(tx, true, &a); err != nil { if err = a.L.Load{{$rel.Function.Name}}(tx, true, &a); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if got := len(a.R.{{$rel.Function.Name}}); got != 2 { if got := len(a.R.{{$rel.Function.Name}}); got != 2 {
t.Error("number of eager loaded records wrong, got:", got) t.Error("number of eager loaded records wrong, got:", got)
} }
if t.Failed() { if t.Failed() {
t.Logf("%#v", {{$varname}}) t.Logf("%#v", {{$varname}})
} }
} }
{{end -}}{{- /* if unique */ -}} {{end -}}{{- /* if unique */ -}}

View file

@ -1,306 +1,306 @@
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . -}} {{- $dot := . -}}
{{- $table := .Table -}} {{- $table := .Table -}}
{{- range .Table.ToManyRelationships -}} {{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_setops_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} {{- template "relationship_to_one_setops_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}}
{{- else -}} {{- else -}}
{{- $varNameSingular := .Table | singular | camelCase -}} {{- $varNameSingular := .Table | singular | camelCase -}}
{{- $foreignVarNameSingular := .ForeignTable | singular | camelCase -}} {{- $foreignVarNameSingular := .ForeignTable | singular | camelCase -}}
{{- $rel := textsFromRelationship $dot.Tables $table .}} {{- $rel := textsFromRelationship $dot.Tables $table .}}
func test{{$rel.LocalTable.NameGo}}ToManyAddOp{{$rel.Function.Name}}(t *testing.T) { func test{{$rel.LocalTable.NameGo}}ToManyAddOp{{$rel.Function.Name}}(t *testing.T) {
var err error var err error
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
var a {{$rel.LocalTable.NameGo}} var a {{$rel.LocalTable.NameGo}}
var b, c, d, e {{$rel.ForeignTable.NameGo}} var b, c, d, e {{$rel.ForeignTable.NameGo}}
seed := randomize.NewSeed() seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e}
for _, x := range foreigners { for _, x := range foreigners {
if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
if err := a.Insert(tx); err != nil { if err := a.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = b.Insert(tx); err != nil { if err = b.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = c.Insert(tx); err != nil { if err = c.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
foreignersSplitByInsertion := [][]*{{$rel.ForeignTable.NameGo}}{ foreignersSplitByInsertion := [][]*{{$rel.ForeignTable.NameGo}}{
{&b, &c}, {&b, &c},
{&d, &e}, {&d, &e},
} }
for i, x := range foreignersSplitByInsertion { for i, x := range foreignersSplitByInsertion {
err = a.Add{{$rel.Function.Name}}(tx, i != 0, x...) err = a.Add{{$rel.Function.Name}}(tx, i != 0, x...)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
first := x[0] first := x[0]
second := x[1] second := x[1]
{{- if .ToJoinTable}} {{- if .ToJoinTable}}
if first.R.{{$rel.Function.ForeignName}}[0] != &a { if first.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice") t.Error("relationship was not added properly to the slice")
} }
if second.R.{{$rel.Function.ForeignName}}[0] != &a { if second.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice") t.Error("relationship was not added properly to the slice")
} }
{{- else}} {{- else}}
if a.{{$rel.Function.LocalAssignment}} != first.{{$rel.Function.ForeignAssignment}} { if a.{{$rel.Function.LocalAssignment}} != first.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, first.{{$rel.Function.ForeignAssignment}}) t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, first.{{$rel.Function.ForeignAssignment}})
} }
if a.{{$rel.Function.LocalAssignment}} != second.{{$rel.Function.ForeignAssignment}} { if a.{{$rel.Function.LocalAssignment}} != second.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, second.{{$rel.Function.ForeignAssignment}}) t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, second.{{$rel.Function.ForeignAssignment}})
} }
if first.R.{{$rel.Function.ForeignName}} != &a { if first.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign slice") t.Error("relationship was not added properly to the foreign slice")
} }
if second.R.{{$rel.Function.ForeignName}} != &a { if second.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign slice") t.Error("relationship was not added properly to the foreign slice")
} }
{{- end}} {{- end}}
if a.R.{{$rel.Function.Name}}[i*2] != first { if a.R.{{$rel.Function.Name}}[i*2] != first {
t.Error("relationship struct slice not set to correct value") t.Error("relationship struct slice not set to correct value")
} }
if a.R.{{$rel.Function.Name}}[i*2+1] != second { if a.R.{{$rel.Function.Name}}[i*2+1] != second {
t.Error("relationship struct slice not set to correct value") t.Error("relationship struct slice not set to correct value")
} }
count, err := a.{{$rel.Function.Name}}(tx).Count() count, err := a.{{$rel.Function.Name}}(tx).Count()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if want := int64((i+1)*2); count != want { if want := int64((i+1)*2); count != want {
t.Error("want", want, "got", count) t.Error("want", want, "got", count)
} }
} }
} }
{{- if (or .ForeignColumnNullable .ToJoinTable)}} {{- if (or .ForeignColumnNullable .ToJoinTable)}}
func test{{$rel.LocalTable.NameGo}}ToManySetOp{{$rel.Function.Name}}(t *testing.T) { func test{{$rel.LocalTable.NameGo}}ToManySetOp{{$rel.Function.Name}}(t *testing.T) {
var err error var err error
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
var a {{$rel.LocalTable.NameGo}} var a {{$rel.LocalTable.NameGo}}
var b, c, d, e {{$rel.ForeignTable.NameGo}} var b, c, d, e {{$rel.ForeignTable.NameGo}}
seed := randomize.NewSeed() seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e}
for _, x := range foreigners { for _, x := range foreigners {
if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
if err = a.Insert(tx); err != nil { if err = a.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = b.Insert(tx); err != nil { if err = b.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = c.Insert(tx); err != nil { if err = c.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = a.Set{{$rel.Function.Name}}(tx, false, &b, &c) err = a.Set{{$rel.Function.Name}}(tx, false, &b, &c)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
count, err := a.{{$rel.Function.Name}}(tx).Count() count, err := a.{{$rel.Function.Name}}(tx).Count()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if count != 2 { if count != 2 {
t.Error("count was wrong:", count) t.Error("count was wrong:", count)
} }
err = a.Set{{$rel.Function.Name}}(tx, true, &d, &e) err = a.Set{{$rel.Function.Name}}(tx, true, &d, &e)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
count, err = a.{{$rel.Function.Name}}(tx).Count() count, err = a.{{$rel.Function.Name}}(tx).Count()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if count != 2 { if count != 2 {
t.Error("count was wrong:", count) t.Error("count was wrong:", count)
} }
{{- if .ToJoinTable}} {{- if .ToJoinTable}}
if len(b.R.{{$rel.Function.ForeignName}}) != 0 { if len(b.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice") t.Error("relationship was not removed properly from the slice")
} }
if len(c.R.{{$rel.Function.ForeignName}}) != 0 { if len(c.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice") t.Error("relationship was not removed properly from the slice")
} }
if d.R.{{$rel.Function.ForeignName}}[0] != &a { if d.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice") t.Error("relationship was not added properly to the slice")
} }
if e.R.{{$rel.Function.ForeignName}}[0] != &a { if e.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice") t.Error("relationship was not added properly to the slice")
} }
{{- else}} {{- else}}
if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want b's foreign key value to be nil") t.Error("want b's foreign key value to be nil")
} }
if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want c's foreign key value to be nil") t.Error("want c's foreign key value to be nil")
} }
if a.{{$rel.Function.LocalAssignment}} != d.{{$rel.Function.ForeignAssignment}} { if a.{{$rel.Function.LocalAssignment}} != d.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, d.{{$rel.Function.ForeignAssignment}}) t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, d.{{$rel.Function.ForeignAssignment}})
} }
if a.{{$rel.Function.LocalAssignment}} != e.{{$rel.Function.ForeignAssignment}} { if a.{{$rel.Function.LocalAssignment}} != e.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, e.{{$rel.Function.ForeignAssignment}}) t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, e.{{$rel.Function.ForeignAssignment}})
} }
if b.R.{{$rel.Function.ForeignName}} != nil { if b.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct") t.Error("relationship was not removed properly from the foreign struct")
} }
if c.R.{{$rel.Function.ForeignName}} != nil { if c.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct") t.Error("relationship was not removed properly from the foreign struct")
} }
if d.R.{{$rel.Function.ForeignName}} != &a { if d.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign struct") t.Error("relationship was not added properly to the foreign struct")
} }
if e.R.{{$rel.Function.ForeignName}} != &a { if e.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign struct") t.Error("relationship was not added properly to the foreign struct")
} }
{{- end}} {{- end}}
if a.R.{{$rel.Function.Name}}[0] != &d { if a.R.{{$rel.Function.Name}}[0] != &d {
t.Error("relationship struct slice not set to correct value") t.Error("relationship struct slice not set to correct value")
} }
if a.R.{{$rel.Function.Name}}[1] != &e { if a.R.{{$rel.Function.Name}}[1] != &e {
t.Error("relationship struct slice not set to correct value") t.Error("relationship struct slice not set to correct value")
} }
} }
func test{{$rel.LocalTable.NameGo}}ToManyRemoveOp{{$rel.Function.Name}}(t *testing.T) { func test{{$rel.LocalTable.NameGo}}ToManyRemoveOp{{$rel.Function.Name}}(t *testing.T) {
var err error var err error
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
var a {{$rel.LocalTable.NameGo}} var a {{$rel.LocalTable.NameGo}}
var b, c, d, e {{$rel.ForeignTable.NameGo}} var b, c, d, e {{$rel.ForeignTable.NameGo}}
seed := randomize.NewSeed() seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e}
for _, x := range foreigners { for _, x := range foreigners {
if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
if err := a.Insert(tx); err != nil { if err := a.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = a.Add{{$rel.Function.Name}}(tx, true, foreigners...) err = a.Add{{$rel.Function.Name}}(tx, true, foreigners...)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
count, err := a.{{$rel.Function.Name}}(tx).Count() count, err := a.{{$rel.Function.Name}}(tx).Count()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if count != 4 { if count != 4 {
t.Error("count was wrong:", count) t.Error("count was wrong:", count)
} }
err = a.Remove{{$rel.Function.Name}}(tx, foreigners[:2]...) err = a.Remove{{$rel.Function.Name}}(tx, foreigners[:2]...)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
count, err = a.{{$rel.Function.Name}}(tx).Count() count, err = a.{{$rel.Function.Name}}(tx).Count()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if count != 2 { if count != 2 {
t.Error("count was wrong:", count) t.Error("count was wrong:", count)
} }
{{- if .ToJoinTable}} {{- if .ToJoinTable}}
if len(b.R.{{$rel.Function.ForeignName}}) != 0 { if len(b.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice") t.Error("relationship was not removed properly from the slice")
} }
if len(c.R.{{$rel.Function.ForeignName}}) != 0 { if len(c.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice") t.Error("relationship was not removed properly from the slice")
} }
if d.R.{{$rel.Function.ForeignName}}[0] != &a { if d.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the foreign struct") t.Error("relationship was not added properly to the foreign struct")
} }
if e.R.{{$rel.Function.ForeignName}}[0] != &a { if e.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the foreign struct") t.Error("relationship was not added properly to the foreign struct")
} }
{{- else}} {{- else}}
if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want b's foreign key value to be nil") t.Error("want b's foreign key value to be nil")
} }
if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want c's foreign key value to be nil") t.Error("want c's foreign key value to be nil")
} }
if b.R.{{$rel.Function.ForeignName}} != nil { if b.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct") t.Error("relationship was not removed properly from the foreign struct")
} }
if c.R.{{$rel.Function.ForeignName}} != nil { if c.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct") t.Error("relationship was not removed properly from the foreign struct")
} }
if d.R.{{$rel.Function.ForeignName}} != &a { if d.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship to a should have been preserved") t.Error("relationship to a should have been preserved")
} }
if e.R.{{$rel.Function.ForeignName}} != &a { if e.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship to a should have been preserved") t.Error("relationship to a should have been preserved")
} }
{{- end}} {{- end}}
if len(a.R.{{$rel.Function.Name}}) != 2 { if len(a.R.{{$rel.Function.Name}}) != 2 {
t.Error("should have preserved two relationships") t.Error("should have preserved two relationships")
} }
// Removal doesn't do a stable deletion for performance so we have to flip the order // Removal doesn't do a stable deletion for performance so we have to flip the order
if a.R.{{$rel.Function.Name}}[1] != &d { if a.R.{{$rel.Function.Name}}[1] != &d {
t.Error("relationship to d should have been preserved") t.Error("relationship to d should have been preserved")
} }
if a.R.{{$rel.Function.Name}}[0] != &e { if a.R.{{$rel.Function.Name}}[0] != &e {
t.Error("relationship to e should have been preserved") t.Error("relationship to e should have been preserved")
} }
} }
{{end -}} {{end -}}
{{- end -}}{{- /* if unique foreign key */ -}} {{- end -}}{{- /* if unique foreign key */ -}}

View file

@ -1,69 +1,69 @@
{{- define "relationship_to_one_test_helper"}} {{- define "relationship_to_one_test_helper"}}
func test{{.LocalTable.NameGo}}ToOne{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) { func test{{.LocalTable.NameGo}}ToOne{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) {
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
var foreign {{.ForeignTable.NameGo}} var foreign {{.ForeignTable.NameGo}}
var local {{.LocalTable.NameGo}} var local {{.LocalTable.NameGo}}
{{if .ForeignKey.Nullable -}} {{if .ForeignKey.Nullable -}}
local.{{.ForeignKey.Column | titleCase}}.Valid = true local.{{.ForeignKey.Column | titleCase}}.Valid = true
{{end}} {{end}}
{{- if .ForeignKey.ForeignColumnNullable -}} {{- if .ForeignKey.ForeignColumnNullable -}}
foreign.{{.ForeignKey.ForeignColumn | titleCase}}.Valid = true foreign.{{.ForeignKey.ForeignColumn | titleCase}}.Valid = true
{{end}} {{end}}
{{if not .Function.OneToOne -}} {{if not .Function.OneToOne -}}
if err := foreign.Insert(tx); err != nil { if err := foreign.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
local.{{.Function.LocalAssignment}} = foreign.{{.Function.ForeignAssignment}} local.{{.Function.LocalAssignment}} = foreign.{{.Function.ForeignAssignment}}
if err := local.Insert(tx); err != nil { if err := local.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
{{else -}} {{else -}}
if err := local.Insert(tx); err != nil { if err := local.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
foreign.{{.Function.ForeignAssignment}} = local.{{.Function.LocalAssignment}} foreign.{{.Function.ForeignAssignment}} = local.{{.Function.LocalAssignment}}
if err := foreign.Insert(tx); err != nil { if err := foreign.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
{{end -}} {{end -}}
check, err := local.{{.Function.Name}}(tx).One() check, err := local.{{.Function.Name}}(tx).One()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if check.{{.Function.ForeignAssignment}} != foreign.{{.Function.ForeignAssignment}} { if check.{{.Function.ForeignAssignment}} != foreign.{{.Function.ForeignAssignment}} {
t.Errorf("want: %v, got %v", foreign.{{.Function.ForeignAssignment}}, check.{{.Function.ForeignAssignment}}) t.Errorf("want: %v, got %v", foreign.{{.Function.ForeignAssignment}}, check.{{.Function.ForeignAssignment}})
} }
slice := {{.LocalTable.NameGo}}Slice{&local} slice := {{.LocalTable.NameGo}}Slice{&local}
if err = local.L.Load{{.Function.Name}}(tx, false, &slice); err != nil { if err = local.L.Load{{.Function.Name}}(tx, false, &slice); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if local.R.{{.Function.Name}} == nil { if local.R.{{.Function.Name}} == nil {
t.Error("struct should have been eager loaded") t.Error("struct should have been eager loaded")
} }
local.R.{{.Function.Name}} = nil local.R.{{.Function.Name}} = nil
if err = local.L.Load{{.Function.Name}}(tx, true, &local); err != nil { if err = local.L.Load{{.Function.Name}}(tx, true, &local); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if local.R.{{.Function.Name}} == nil { if local.R.{{.Function.Name}} == nil {
t.Error("struct should have been eager loaded") t.Error("struct should have been eager loaded")
} }
} }
{{end -}} {{end -}}
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . -}} {{- $dot := . -}}
{{- range .Table.FKeys -}} {{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_test_helper" $rel -}} {{- template "relationship_to_one_test_helper" $rel -}}
{{end -}} {{end -}}
{{- end -}} {{- end -}}

View file

@ -2,131 +2,131 @@
{{- $varNameSingular := .ForeignKey.Table | singular | camelCase -}} {{- $varNameSingular := .ForeignKey.Table | singular | camelCase -}}
{{- $foreignVarNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} {{- $foreignVarNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}}
func test{{.LocalTable.NameGo}}ToOneSetOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) { func test{{.LocalTable.NameGo}}ToOneSetOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) {
var err error var err error
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
var a {{.LocalTable.NameGo}} var a {{.LocalTable.NameGo}}
var b, c {{.ForeignTable.NameGo}} var b, c {{.ForeignTable.NameGo}}
seed := randomize.NewSeed() seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = randomize.Struct(seed, &c, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, &c, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := a.Insert(tx); err != nil { if err := a.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = b.Insert(tx); err != nil { if err = b.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
for i, x := range []*{{.ForeignTable.NameGo}}{&b, &c} { for i, x := range []*{{.ForeignTable.NameGo}}{&b, &c} {
err = a.Set{{.Function.Name}}(tx, i != 0, x) err = a.Set{{.Function.Name}}(tx, i != 0, x)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}) t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}})
} }
if a.R.{{.Function.Name}} != x { if a.R.{{.Function.Name}} != x {
t.Error("relationship struct not set to correct value") t.Error("relationship struct not set to correct value")
} }
zero := reflect.Zero(reflect.TypeOf(a.{{.Function.LocalAssignment}})) zero := reflect.Zero(reflect.TypeOf(a.{{.Function.LocalAssignment}}))
reflect.Indirect(reflect.ValueOf(&a.{{.Function.LocalAssignment}})).Set(zero) reflect.Indirect(reflect.ValueOf(&a.{{.Function.LocalAssignment}})).Set(zero)
if err = a.Reload(tx); err != nil { if err = a.Reload(tx); err != nil {
t.Fatal("failed to reload", err) t.Fatal("failed to reload", err)
} }
if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}, x.{{.Function.ForeignAssignment}}) t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}, x.{{.Function.ForeignAssignment}})
} }
{{if .ForeignKey.Unique -}} {{if .ForeignKey.Unique -}}
if x.R.{{.Function.ForeignName}} != &a { if x.R.{{.Function.ForeignName}} != &a {
t.Error("failed to append to foreign relationship struct") t.Error("failed to append to foreign relationship struct")
} }
{{else -}} {{else -}}
if x.R.{{.Function.ForeignName}}[0] != &a { if x.R.{{.Function.ForeignName}}[0] != &a {
t.Error("failed to append to foreign relationship struct") t.Error("failed to append to foreign relationship struct")
} }
{{end -}} {{end -}}
} }
} }
{{- if .ForeignKey.Nullable}} {{- if .ForeignKey.Nullable}}
func test{{.LocalTable.NameGo}}ToOneRemoveOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) { func test{{.LocalTable.NameGo}}ToOneRemoveOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) {
var err error var err error
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
var a {{.LocalTable.NameGo}} var a {{.LocalTable.NameGo}}
var b {{.ForeignTable.NameGo}} var b {{.ForeignTable.NameGo}}
seed := randomize.NewSeed() seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = a.Insert(tx); err != nil { if err = a.Insert(tx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = a.Set{{.Function.Name}}(tx, true, &b); err != nil { if err = a.Set{{.Function.Name}}(tx, true, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = a.Remove{{.Function.Name}}(tx, &b); err != nil { if err = a.Remove{{.Function.Name}}(tx, &b); err != nil {
t.Error("failed to remove relationship") t.Error("failed to remove relationship")
} }
count, err := a.{{.Function.Name}}(tx).Count() count, err := a.{{.Function.Name}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 0 { if count != 0 {
t.Error("want no relationships remaining") t.Error("want no relationships remaining")
} }
if a.R.{{.Function.Name}} != nil { if a.R.{{.Function.Name}} != nil {
t.Error("R struct entry should be nil") t.Error("R struct entry should be nil")
} }
if a.{{.LocalTable.ColumnNameGo}}.Valid { if a.{{.LocalTable.ColumnNameGo}}.Valid {
t.Error("R struct entry should be nil") t.Error("R struct entry should be nil")
} }
{{if .ForeignKey.Unique -}} {{if .ForeignKey.Unique -}}
if b.R.{{.Function.ForeignName}} != nil { if b.R.{{.Function.ForeignName}} != nil {
t.Error("failed to remove a from b's relationships") t.Error("failed to remove a from b's relationships")
} }
{{else -}} {{else -}}
if len(b.R.{{.Function.ForeignName}}) != 0 { if len(b.R.{{.Function.ForeignName}}) != 0 {
t.Error("failed to remove a from b's relationships") t.Error("failed to remove a from b's relationships")
} }
{{end -}} {{end -}}
} }
{{end -}} {{end -}}
{{- end -}} {{- end -}}
{{- if .Table.IsJoinTable -}} {{- if .Table.IsJoinTable -}}
{{- else -}} {{- else -}}
{{- $dot := . -}} {{- $dot := . -}}
{{- range .Table.FKeys -}} {{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table .}} {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table .}}
{{template "relationship_to_one_setops_test_helper" $rel -}} {{template "relationship_to_one_setops_test_helper" $rel -}}
{{- end -}} {{- end -}}

View file

@ -3,45 +3,45 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Reload(t *testing.T) { func test{{$tableNamePlural}}Reload(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
if err = {{$varNameSingular}}.Reload(tx); err != nil { if err = {{$varNameSingular}}.Reload(tx); err != nil {
t.Error(err) t.Error(err)
} }
} }
func test{{$tableNamePlural}}ReloadAll(t *testing.T) { func test{{$tableNamePlural}}ReloadAll(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}}
if err = slice.ReloadAll(tx); err != nil { if err = slice.ReloadAll(tx); err != nil {
t.Error(err) t.Error(err)
} }
} }

View file

@ -3,27 +3,27 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Select(t *testing.T) { func test{{$tableNamePlural}}Select(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
slice, err := {{$tableNamePlural}}(tx).All() slice, err := {{$tableNamePlural}}(tx).All()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if len(slice) != 1 { if len(slice) != 1 {
t.Error("want one record, got:", len(slice)) t.Error("want one record, got:", len(slice))
} }
} }

View file

@ -0,0 +1,131 @@
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
var (
dbMain tester
)
type tester interface {
setup() error
conn() (*sql.DB, error)
teardown() error
}
func TestMain(m *testing.M) {
if dbMain == nil {
fmt.Println("no dbMain tester interface was ready")
os.Exit(-1)
}
rand.Seed(time.Now().UnixNano())
var err error
// Load configuration
err = initViper()
if err != nil {
fmt.Println("unable to load config file")
os.Exit(-2)
}
setConfigDefaults()
if err := validateConfig("{{.DriverName}}"); err != nil {
fmt.Println("failed to validate config", err)
os.Exit(-3)
}
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
if err = dbMain.setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-4)
}
conn, err := dbMain.conn()
if err != nil {
fmt.Println("failed to get connection:", err)
}
var code int
boil.SetDB(conn)
code = m.Run()
if err = dbMain.teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-5)
}
os.Exit(code)
}
func initViper() error {
var err error
viper.SetConfigName("sqlboiler")
configHome := os.Getenv("XDG_CONFIG_HOME")
homePath := os.Getenv("HOME")
wd, err := os.Getwd()
if err != nil {
wd = "../"
} else {
wd = wd + "/.."
}
configPaths := []string{wd}
if len(configHome) > 0 {
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
} else {
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
}
for _, p := range configPaths {
viper.AddConfigPath(p)
}
// Ignore errors here, fall back to defaults and validation to provide errs
_ = viper.ReadInConfig()
viper.AutomaticEnv()
return nil
}
// setConfigDefaults is only necessary because of bugs in viper, noted in main
func setConfigDefaults() {
if viper.GetString("postgres.sslmode") == "" {
viper.Set("postgres.sslmode", "require")
}
if viper.GetInt("postgres.port") == 0 {
viper.Set("postgres.port", 5432)
}
if viper.GetString("mysql.sslmode") == "" {
viper.Set("mysql.sslmode", "true")
}
if viper.GetInt("mysql.port") == 0 {
viper.Set("mysql.port", 3306)
}
}
func validateConfig(driverName string) error {
if driverName == "postgres" {
return vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
).Check()
}
if driverName == "mysql" {
return vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
).Check()
}
return errors.New("not a valid driver name")
}

View file

@ -7,53 +7,31 @@ func MustTx(transactor boil.Transactor, err error) boil.Transactor {
return transactor return transactor
} }
func initDBNameRand(input string) { var rgxPGFkey = regexp.MustCompile(`(?m)^ALTER TABLE ONLY .*\n\s+ADD CONSTRAINT .*? FOREIGN KEY .*?;\n`)
sum := md5.Sum([]byte(input)) var rgxMySQLkey = regexp.MustCompile(`(?m)((,\n)?\s+CONSTRAINT.*?FOREIGN KEY.*?\n)+`)
var sumInt string func newFKeyDestroyer(regex *regexp.Regexp, reader io.Reader) io.Reader {
for _, v := range sum { return &fKeyDestroyer{
sumInt = sumInt + strconv.Itoa(int(v)) reader: reader,
rgx: regex,
} }
}
// Cut integer to 18 digits to ensure no int64 overflow. type fKeyDestroyer struct {
sumInt = sumInt[:18] reader io.Reader
buf *bytes.Buffer
rgx *regexp.Regexp
}
sumTmp := sumInt func (f *fKeyDestroyer) Read(b []byte) (int, error) {
for i, v := range sumInt { if f.buf == nil {
if v == '0' { all, err := ioutil.ReadAll(f.reader)
sumTmp = sumInt[i+1:] if err != nil {
continue return 0, err
} }
break
f.buf = bytes.NewBuffer(f.rgx.ReplaceAll(all, []byte{}))
} }
sumInt = sumTmp return f.buf.Read(b)
randSeed, err := strconv.ParseInt(sumInt, 0, 64)
if err != nil {
fmt.Printf("Unable to parse sumInt: %s", err)
os.Exit(-1)
}
dbNameRand = rand.New(rand.NewSource(randSeed))
}
var alphabetChars = "abcdefghijklmnopqrstuvwxyz"
func randStr(length int) string {
c := len(alphabetChars)
output := make([]rune, length)
for i := 0; i < length; i++ {
output[i] = rune(alphabetChars[dbNameRand.Intn(c)])
}
return string(output)
}
// getDBNameHash takes a database name in, and generates
// a random string using the database name as the rand Seed.
// getDBNameHash is used to generate unique test database names.
func getDBNameHash(input string) string {
initDBNameRand(input)
return randStr(40)
} }

View file

@ -1,37 +0,0 @@
var (
testCfg *Config
dbConn *sql.DB
)
func InitViper() error {
var err error
testCfg = &Config{}
viper.SetConfigName("sqlboiler")
configHome := os.Getenv("XDG_CONFIG_HOME")
homePath := os.Getenv("HOME")
wd, err := os.Getwd()
if err != nil {
wd = "../"
} else {
wd = wd + "/.."
}
configPaths := []string{wd}
if len(configHome) > 0 {
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
} else {
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
}
for _, p := range configPaths {
viper.AddConfigPath(p)
}
// Ignore errors here, fall back to defaults and validation to provide errs
_ = viper.ReadInConfig()
viper.AutomaticEnv()
return nil
}

View file

@ -3,97 +3,97 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Update(t *testing.T) { func test{{$tableNamePlural}}Update(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
count, err := {{$tableNamePlural}}(tx).Count() count, err := {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 1 { if count != 1 {
t.Error("want one record, got:", count) t.Error("want one record, got:", count)
} }
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
// If table only contains primary key columns, we need to pass // If table only contains primary key columns, we need to pass
// them into a whitelist to get a valid test result, // them into a whitelist to get a valid test result,
// otherwise the Update method will error because it will not be able to // otherwise the Update method will error because it will not be able to
// generate a whitelist (due to it excluding primary key columns). // generate a whitelist (due to it excluding primary key columns).
if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) {
if err = {{$varNameSingular}}.Update(tx, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { if err = {{$varNameSingular}}.Update(tx, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Error(err) t.Error(err)
} }
} else { } else {
if err = {{$varNameSingular}}.Update(tx); err != nil { if err = {{$varNameSingular}}.Update(tx); err != nil {
t.Error(err) t.Error(err)
} }
} }
} }
func test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) { func test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) {
t.Parallel() t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{} {{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil { if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err) t.Error(err)
} }
count, err := {{$tableNamePlural}}(tx).Count() count, err := {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 1 { if count != 1 {
t.Error("want one record, got:", count) t.Error("want one record, got:", count)
} }
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
// Remove Primary keys and unique columns from what we plan to update // Remove Primary keys and unique columns from what we plan to update
var fields []string var fields []string
if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) {
fields = {{$varNameSingular}}Columns fields = {{$varNameSingular}}Columns
} else { } else {
fields = strmangle.SetComplement( fields = strmangle.SetComplement(
{{$varNameSingular}}Columns, {{$varNameSingular}}Columns,
{{$varNameSingular}}PrimaryKeyColumns, {{$varNameSingular}}PrimaryKeyColumns,
) )
} }
value := reflect.Indirect(reflect.ValueOf({{$varNameSingular}})) value := reflect.Indirect(reflect.ValueOf({{$varNameSingular}}))
updateMap := M{} updateMap := M{}
for _, col := range fields { for _, col := range fields {
updateMap[col] = value.FieldByName(strmangle.TitleCase(col)).Interface() updateMap[col] = value.FieldByName(strmangle.TitleCase(col)).Interface()
} }
slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}}
if err = slice.UpdateAll(tx, updateMap); err != nil { if err = slice.UpdateAll(tx, updateMap); err != nil {
t.Error(err) t.Error(err)
} }
} }

View file

@ -3,44 +3,47 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Upsert(t *testing.T) { func test{{$tableNamePlural}}Upsert(t *testing.T) {
t.Parallel() {{if not (eq .DriverName "postgres") -}}
t.Skip("not implemented for {{.DriverName}}")
{{end -}}
t.Parallel()
seed := randomize.NewSeed() seed := randomize.NewSeed()
var err error var err error
// Attempt the INSERT side of an UPSERT // Attempt the INSERT side of an UPSERT
{{$varNameSingular}} := {{$tableNameSingular}}{} {{$varNameSingular}} := {{$tableNameSingular}}{}
if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
tx := MustTx(boil.Begin()) tx := MustTx(boil.Begin())
defer tx.Rollback() defer tx.Rollback()
if err = {{$varNameSingular}}.Upsert(tx, false, nil, nil); err != nil { if err = {{$varNameSingular}}.Upsert(tx, {{if eq .DriverName "postgres"}}false, nil, {{end}}nil); err != nil {
t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err)
} }
count, err := {{$tableNamePlural}}(tx).Count() count, err := {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 1 { if count != 1 {
t.Error("want one record, got:", count) t.Error("want one record, got:", count)
} }
// Attempt the UPDATE side of an UPSERT // Attempt the UPDATE side of an UPSERT
if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
} }
if err = {{$varNameSingular}}.Upsert(tx, true, nil, nil); err != nil { if err = {{$varNameSingular}}.Upsert(tx, {{if eq .DriverName "postgres"}}true, nil, {{end}}nil); err != nil {
t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err)
} }
count, err = {{$tableNamePlural}}(tx).Count() count, err = {{$tableNamePlural}}(tx).Count()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if count != 1 { if count != 1 {
t.Error("want one record, got:", count) t.Error("want one record, got:", count)
} }
} }

View file

@ -93,7 +93,46 @@ CREATE TABLE magic (
strange_three timestamp without time zone default (now() at time zone 'utc'), strange_three timestamp without time zone default (now() at time zone 'utc'),
strange_four timestamp with time zone default (now() at time zone 'utc'), strange_four timestamp with time zone default (now() at time zone 'utc'),
strange_five interval NOT NULL DEFAULT '21 days', strange_five interval NOT NULL DEFAULT '21 days',
strange_six interval NULL DEFAULT '23 hours' strange_six interval NULL DEFAULT '23 hours',
aa json NULL,
bb json NOT NULL,
cc jsonb NULL,
dd jsonb NOT NULL,
ee box NULL,
ff box NOT NULL,
gg cidr NULL,
hh cidr NOT NULL,
ii circle NULL,
jj circle NOT NULL,
kk double precision NULL,
ll double precision NOT NULL,
mm inet NULL,
nn inet NOT NULL,
oo line NULL,
pp line NOT NULL,
qq lseg NULL,
rr lseg NOT NULL,
ss macaddr NULL,
tt macaddr NOT NULL,
uu money NULL,
vv money NOT NULL,
ww path NULL,
xx path NOT NULL,
yy pg_lsn NULL,
zz pg_lsn NOT NULL,
aaa point NULL,
bbb point NOT NULL,
ccc polygon NULL,
ddd polygon NOT NULL,
eee tsquery NULL,
fff tsquery NOT NULL,
ggg tsvector NULL,
hhh tsvector NOT NULL,
iii txid_snapshot NULL,
jjj txid_snapshot NOT NULL,
kkk xml NULL,
lll xml NOT NULL
); );
create table owner ( create table owner (
@ -136,12 +175,6 @@ create table spider_toys (
primary key (spider_id) primary key (spider_id)
); );
/*
Test:
* Variations of capitalization
* Single value columns
* Primary key as only value
*/
create table pals ( create table pals (
pal character varying, pal character varying,
primary key (pal) primary key (pal)
@ -161,3 +194,22 @@ create table enemies (
enemies character varying, enemies character varying,
primary key (enemies) primary key (enemies)
); );
create table fun_arrays (
id serial,
fun_one integer[] null,
fun_two integer[] not null,
fun_three boolean[] null,
fun_four boolean[] not null,
fun_five varchar[] null,
fun_six varchar[] not null,
fun_seven decimal[] null,
fun_eight decimal[] not null,
fun_nine bytea[] null,
fun_ten bytea[] not null,
fun_eleven jsonb[] null,
fun_twelve jsonb[] not null,
fun_thirteen json[] null,
fun_fourteen json[] not null,
primary key (id)
)

View file

@ -12,7 +12,7 @@ import (
func TestTextsFromForeignKey(t *testing.T) { func TestTextsFromForeignKey(t *testing.T) {
t.Parallel() t.Parallel()
tables, err := bdb.Tables(&drivers.MockDriver{}) tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -81,7 +81,7 @@ func TestTextsFromForeignKey(t *testing.T) {
func TestTextsFromOneToOneRelationship(t *testing.T) { func TestTextsFromOneToOneRelationship(t *testing.T) {
t.Parallel() t.Parallel()
tables, err := bdb.Tables(&drivers.MockDriver{}) tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -130,7 +130,7 @@ func TestTextsFromOneToOneRelationship(t *testing.T) {
func TestTextsFromRelationship(t *testing.T) { func TestTextsFromRelationship(t *testing.T) {
t.Parallel() t.Parallel()
tables, err := bdb.Tables(&drivers.MockDriver{}) tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

719
types/array.go Normal file
View file

@ -0,0 +1,719 @@
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package types
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding/hex"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
var typeByteSlice = reflect.TypeOf([]byte{})
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
func encode(x interface{}) []byte {
switch v := x.(type) {
case int64:
return strconv.AppendInt(nil, v, 10)
case float64:
return strconv.AppendFloat(nil, v, 'f', -1, 64)
case []byte:
return encodeBytes(v)
case string:
return []byte(v)
case bool:
return strconv.AppendBool(nil, v)
case time.Time:
return formatTimestamp(v)
default:
panic(fmt.Errorf("encode: unknown type for %T", v))
}
}
// FormatTimestamp formats t into Postgres' text format for timestamps.
func formatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go.
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
bc := false
if t.Year() <= 0 {
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
t = t.AddDate((-t.Year())*2+1, 0, 0)
bc = true
}
b := []byte(t.Format(time.RFC3339Nano))
_, offset := t.Zone()
offset = offset % 60
if offset != 0 {
// RFC3339Nano already printed the minus sign
if offset < 0 {
offset = -offset
}
b = append(b, ':')
if offset < 10 {
b = append(b, '0')
}
b = strconv.AppendInt(b, int64(offset), 10)
}
if bc {
b = append(b, " BC"...)
}
return b
}
func encodeBytes(v []byte) (result []byte) {
for _, b := range v {
if b == '\\' {
result = append(result, '\\', '\\')
} else if b < 0x20 || b > 0x7e {
result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
} else {
result = append(result, b)
}
}
return result
}
// Parse a bytea value received from the server. Both "hex" and the legacy
// "escape" format are supported.
func parseBytes(s []byte) (result []byte, err error) {
if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
// bytea_output = hex
s = s[2:] // trim off leading "\\x"
result = make([]byte, hex.DecodedLen(len(s)))
_, err := hex.Decode(result, s)
if err != nil {
return nil, err
}
} else {
for len(s) > 0 {
if s[0] == '\\' {
// escaped '\\'
if len(s) >= 2 && s[1] == '\\' {
result = append(result, '\\')
s = s[2:]
continue
}
// '\\' followed by an octal number
if len(s) < 4 {
return nil, fmt.Errorf("invalid bytea sequence %v", s)
}
r, err := strconv.ParseInt(string(s[1:4]), 8, 9)
if err != nil {
return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
}
result = append(result, byte(r))
s = s[4:]
} else {
// We hit an unescaped, raw byte. Try to read in as many as
// possible in one go.
i := bytes.IndexByte(s, '\\')
if i == -1 {
result = append(result, s...)
break
}
result = append(result, s[:i]...)
s = s[i:]
}
}
}
return result, nil
}
// Array returns the optimal driver.Valuer and sql.Scanner for an array or
// slice of any dimension.
//
// For example:
// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
//
// var x []sql.NullInt64
// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x))
//
// Scanning multi-dimensional arrays is not supported. Arrays where the lower
// bound is not one (such as `[0:0]={1}') are not supported.
func Array(a interface{}) interface {
driver.Valuer
sql.Scanner
} {
switch a := a.(type) {
case []bool:
return (*BoolArray)(&a)
case []float64:
return (*Float64Array)(&a)
case []int64:
return (*Int64Array)(&a)
case []string:
return (*StringArray)(&a)
case *[]bool:
return (*BoolArray)(a)
case *[]float64:
return (*Float64Array)(a)
case *[]int64:
return (*Int64Array)(a)
case *[]string:
return (*StringArray)(a)
default:
panic(fmt.Sprintf("boil: invalid type received %T", a))
}
}
// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner
// to override the array delimiter used by GenericArray.
type ArrayDelimiter interface {
// ArrayDelimiter returns the delimiter character(s) for this element's type.
ArrayDelimiter() string
}
// BoolArray represents a one-dimensional array of the PostgreSQL boolean type.
type BoolArray []bool
// Scan implements the sql.Scanner interface.
func (a *BoolArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to BoolArray", src)
}
func (a *BoolArray) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "BoolArray")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(BoolArray, len(elems))
for i, v := range elems {
if len(v) != 1 {
return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v)
}
switch v[0] {
case 't':
b[i] = true
case 'f':
b[i] = false
default:
return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a BoolArray) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be exactly two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1+2*n)
for i := 0; i < n; i++ {
b[2*i] = ','
if a[i] {
b[1+2*i] = 't'
} else {
b[1+2*i] = 'f'
}
}
b[0] = '{'
b[2*n] = '}'
return string(b), nil
}
return "{}", nil
}
// BytesArray represents a one-dimensional array of the PostgreSQL bytea type.
type BytesArray [][]byte
// Scan implements the sql.Scanner interface.
func (a *BytesArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to BytesArray", src)
}
func (a *BytesArray) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "BytesArray")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(BytesArray, len(elems))
for i, v := range elems {
b[i], err = parseBytes(v)
if err != nil {
return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface. It uses the "hex" format which
// is only supported on PostgreSQL 9.0 or newer.
func (a BytesArray) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, 2*N bytes of quotes,
// 3*N bytes of hex formatting, and N-1 bytes of delimiters.
size := 1 + 6*n
for _, x := range a {
size += hex.EncodedLen(len(x))
}
b := make([]byte, size)
for i, s := 0, b; i < n; i++ {
o := copy(s, `,"\\x`)
o += hex.Encode(s[o:], a[i])
s[o] = '"'
s = s[o+1:]
}
b[0] = '{'
b[size-1] = '}'
return string(b), nil
}
return "{}", nil
}
// Float64Array represents a one-dimensional array of the PostgreSQL double
// precision type.
type Float64Array []float64
// Scan implements the sql.Scanner interface.
func (a *Float64Array) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to Float64Array", src)
}
func (a *Float64Array) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Float64Array")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Float64Array, len(elems))
for i, v := range elems {
if b[i], err = strconv.ParseFloat(string(v), 64); err != nil {
return fmt.Errorf("boil: parsing array element index %d: %v", i, err)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Float64Array) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendFloat(b, a[0], 'f', -1, 64)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendFloat(b, a[i], 'f', -1, 64)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
type Int64Array []int64
// Scan implements the sql.Scanner interface.
func (a *Int64Array) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to Int64Array", src)
}
func (a *Int64Array) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Int64Array")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Int64Array, len(elems))
for i, v := range elems {
if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil {
return fmt.Errorf("boil: parsing array element index %d: %v", i, err)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Int64Array) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendInt(b, a[0], 10)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendInt(b, a[i], 10)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// StringArray represents a one-dimensional array of the PostgreSQL character types.
type StringArray []string
// Scan implements the sql.Scanner interface.
func (a *StringArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to StringArray", src)
}
func (a *StringArray) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "StringArray")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(StringArray, len(elems))
for i, v := range elems {
if b[i] = string(v); v == nil {
return fmt.Errorf("boil: parsing array element index %d: cannot convert nil to string", i)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a StringArray) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, 2*N bytes of quotes,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+3*n)
b[0] = '{'
b = appendArrayQuotedBytes(b, []byte(a[0]))
for i := 1; i < n; i++ {
b = append(b, ',')
b = appendArrayQuotedBytes(b, []byte(a[i]))
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// appendArray appends rv to the buffer, returning the extended buffer and
// the delimiter used between elements.
//
// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice.
func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
var del string
var err error
b = append(b, '{')
if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil {
return b, del, err
}
for i := 1; i < n; i++ {
b = append(b, del...)
if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil {
return b, del, err
}
}
return append(b, '}'), del, nil
}
// appendArrayElement appends rv to the buffer, returning the extended buffer
// and the delimiter to use before the next element.
//
// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted
// using driver.DefaultParameterConverter and the resulting []byte or string
// is double-quoted.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
if k := rv.Kind(); k == reflect.Array || k == reflect.Slice {
if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
if n := rv.Len(); n > 0 {
return appendArray(b, rv, n)
}
return b, "", nil
}
}
var del = ","
var err error
var iv interface{} = rv.Interface()
if ad, ok := iv.(ArrayDelimiter); ok {
del = ad.ArrayDelimiter()
}
if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil {
return b, del, err
}
switch v := iv.(type) {
case nil:
return append(b, "NULL"...), del, nil
case []byte:
return appendArrayQuotedBytes(b, v), del, nil
case string:
return appendArrayQuotedBytes(b, []byte(v)), del, nil
}
b, err = appendValue(b, iv)
return b, del, err
}
func appendArrayQuotedBytes(b, v []byte) []byte {
b = append(b, '"')
for {
i := bytes.IndexAny(v, `"\`)
if i < 0 {
b = append(b, v...)
break
}
if i > 0 {
b = append(b, v[:i]...)
}
b = append(b, '\\', v[i])
v = v[i+1:]
}
return append(b, '"')
}
func appendValue(b []byte, v driver.Value) ([]byte, error) {
return append(b, encode(v)...), nil
}
// parseArray extracts the dimensions and elements of an array represented in
// text format. Only representations emitted by the backend are supported.
// Notably, whitespace around brackets and delimiters is significant, and NULL
// is case-sensitive.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
var depth, i int
if len(src) < 1 || src[0] != '{' {
return nil, nil, fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '{', 0)
}
Open:
for i < len(src) {
switch src[i] {
case '{':
depth++
i++
case '}':
elems = make([][]byte, 0)
goto Close
default:
break Open
}
}
dims = make([]int, i)
Element:
for i < len(src) {
switch src[i] {
case '{':
depth++
dims[depth-1] = 0
i++
case '"':
var elem = []byte{}
var escape bool
for i++; i < len(src); i++ {
if escape {
elem = append(elem, src[i])
escape = false
} else {
switch src[i] {
default:
elem = append(elem, src[i])
case '\\':
escape = true
case '"':
elems = append(elems, elem)
i++
break Element
}
}
}
default:
for start := i; i < len(src); i++ {
if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
elem := src[start:i]
if len(elem) == 0 {
return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i)
}
if bytes.Equal(elem, []byte("NULL")) {
elem = nil
}
elems = append(elems, elem)
break Element
}
}
}
}
for i < len(src) {
if bytes.HasPrefix(src[i:], del) {
dims[depth-1]++
i += len(del)
goto Element
} else if src[i] == '}' {
dims[depth-1]++
depth--
i++
} else {
return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
Close:
for i < len(src) {
if src[i] == '}' && depth > 0 {
depth--
i++
} else {
return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
if depth > 0 {
err = fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '}', i)
}
if err == nil {
for _, d := range dims {
if (len(elems) % d) != 0 {
err = fmt.Errorf("boil: multidimensional arrays must have elements with matching dimensions")
}
}
}
return
}
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
dims, elems, err := parseArray(src, del)
if err != nil {
return nil, err
}
if len(dims) > 1 {
return nil, fmt.Errorf("boil: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
}
return elems, err
}

800
types/array_test.go Normal file
View file

@ -0,0 +1,800 @@
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package types
import (
"database/sql"
"database/sql/driver"
"math/rand"
"reflect"
"strings"
"testing"
)
func TestParseArray(t *testing.T) {
for _, tt := range []struct {
input string
delim string
dims []int
elems [][]byte
}{
{`{}`, `,`, nil, [][]byte{}},
{`{NULL}`, `,`, []int{1}, [][]byte{nil}},
{`{a}`, `,`, []int{1}, [][]byte{{'a'}}},
{`{a,b}`, `,`, []int{2}, [][]byte{{'a'}, {'b'}}},
{`{{a,b}}`, `,`, []int{1, 2}, [][]byte{{'a'}, {'b'}}},
{`{{a},{b}}`, `,`, []int{2, 1}, [][]byte{{'a'}, {'b'}}},
{`{{{a,b},{c,d},{e,f}}}`, `,`, []int{1, 3, 2}, [][]byte{
{'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
}},
{`{""}`, `,`, []int{1}, [][]byte{{}}},
{`{","}`, `,`, []int{1}, [][]byte{{','}}},
{`{",",","}`, `,`, []int{2}, [][]byte{{','}, {','}}},
{`{{",",","}}`, `,`, []int{1, 2}, [][]byte{{','}, {','}}},
{`{{","},{","}}`, `,`, []int{2, 1}, [][]byte{{','}, {','}}},
{`{{{",",","},{",",","},{",",","}}}`, `,`, []int{1, 3, 2}, [][]byte{
{','}, {','}, {','}, {','}, {','}, {','},
}},
{`{"\"}"}`, `,`, []int{1}, [][]byte{{'"', '}'}}},
{`{"\"","\""}`, `,`, []int{2}, [][]byte{{'"'}, {'"'}}},
{`{{"\"","\""}}`, `,`, []int{1, 2}, [][]byte{{'"'}, {'"'}}},
{`{{"\""},{"\""}}`, `,`, []int{2, 1}, [][]byte{{'"'}, {'"'}}},
{`{{{"\"","\""},{"\"","\""},{"\"","\""}}}`, `,`, []int{1, 3, 2}, [][]byte{
{'"'}, {'"'}, {'"'}, {'"'}, {'"'}, {'"'},
}},
{`{axyzb}`, `xyz`, []int{2}, [][]byte{{'a'}, {'b'}}},
} {
dims, elems, err := parseArray([]byte(tt.input), []byte(tt.delim))
if err != nil {
t.Fatalf("Expected no error for %q, got %q", tt.input, err)
}
if !reflect.DeepEqual(dims, tt.dims) {
t.Errorf("Expected %v dimensions for %q, got %v", tt.dims, tt.input, dims)
}
if !reflect.DeepEqual(elems, tt.elems) {
t.Errorf("Expected %v elements for %q, got %v", tt.elems, tt.input, elems)
}
}
}
func TestParseArrayError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "expected '{' at offset 0"},
{`x`, "expected '{' at offset 0"},
{`}`, "expected '{' at offset 0"},
{`{`, "expected '}' at offset 1"},
{`{{}`, "expected '}' at offset 3"},
{`{}}`, "unexpected '}' at offset 2"},
{`{,}`, "unexpected ',' at offset 1"},
{`{,x}`, "unexpected ',' at offset 1"},
{`{x,}`, "unexpected '}' at offset 3"},
{`{""x}`, "unexpected 'x' at offset 3"},
{`{{a},{b,c}}`, "multidimensional arrays must have elements with matching dimensions"},
} {
_, _, err := parseArray([]byte(tt.input), []byte{','})
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
}
}
func TestArrayScanner(t *testing.T) {
var s sql.Scanner
s = Array(&[]bool{})
if _, ok := s.(*BoolArray); !ok {
t.Errorf("Expected *BoolArray, got %T", s)
}
s = Array(&[]float64{})
if _, ok := s.(*Float64Array); !ok {
t.Errorf("Expected *Float64Array, got %T", s)
}
s = Array(&[]int64{})
if _, ok := s.(*Int64Array); !ok {
t.Errorf("Expected *Int64Array, got %T", s)
}
s = Array(&[]string{})
if _, ok := s.(*StringArray); !ok {
t.Errorf("Expected *StringArray, got %T", s)
}
}
func TestArrayValuer(t *testing.T) {
var v driver.Valuer
v = Array([]bool{})
if _, ok := v.(*BoolArray); !ok {
t.Errorf("Expected *BoolArray, got %T", v)
}
v = Array([]float64{})
if _, ok := v.(*Float64Array); !ok {
t.Errorf("Expected *Float64Array, got %T", v)
}
v = Array([]int64{})
if _, ok := v.(*Int64Array); !ok {
t.Errorf("Expected *Int64Array, got %T", v)
}
v = Array([]string{})
if _, ok := v.(*StringArray); !ok {
t.Errorf("Expected *StringArray, got %T", v)
}
}
func TestBoolArrayScanUnsupported(t *testing.T) {
var arr BoolArray
err := arr.Scan(1)
if err == nil {
t.Fatal("Expected error when scanning from int")
}
if !strings.Contains(err.Error(), "int to BoolArray") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var BoolArrayStringTests = []struct {
str string
arr BoolArray
}{
{`{}`, BoolArray{}},
{`{t}`, BoolArray{true}},
{`{f,t}`, BoolArray{false, true}},
}
func TestBoolArrayScanBytes(t *testing.T) {
for _, tt := range BoolArrayStringTests {
bytes := []byte(tt.str)
arr := BoolArray{true, true, true}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkBoolArrayScanBytes(b *testing.B) {
var a BoolArray
var x interface{} = []byte(`{t,f,t,f,t,f,t,f,t,f}`)
for i := 0; i < b.N; i++ {
a = BoolArray{}
a.Scan(x)
}
}
func TestBoolArrayScanString(t *testing.T) {
for _, tt := range BoolArrayStringTests {
arr := BoolArray{true, true, true}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestBoolArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{t},{f}}`, "cannot convert ARRAY[2][1] to BoolArray"},
{`{NULL}`, `could not parse boolean array index 0: invalid boolean ""`},
{`{a}`, `could not parse boolean array index 0: invalid boolean "a"`},
{`{t,b}`, `could not parse boolean array index 1: invalid boolean "b"`},
{`{t,f,cd}`, `could not parse boolean array index 2: invalid boolean "cd"`},
} {
arr := BoolArray{true, true, true}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, BoolArray{true, true, true}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestBoolArrayValue(t *testing.T) {
result, err := BoolArray(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = BoolArray([]bool{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = BoolArray([]bool{false, true, false}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{f,t,f}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkBoolArrayValue(b *testing.B) {
rand.Seed(1)
x := make([]bool, 10)
for i := 0; i < len(x); i++ {
x[i] = rand.Intn(2) == 0
}
a := BoolArray(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}
func TestBytesArrayScanUnsupported(t *testing.T) {
var arr BytesArray
err := arr.Scan(1)
if err == nil {
t.Fatal("Expected error when scanning from int")
}
if !strings.Contains(err.Error(), "int to BytesArray") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var BytesArrayStringTests = []struct {
str string
arr BytesArray
}{
{`{}`, BytesArray{}},
{`{NULL}`, BytesArray{nil}},
{`{"\\xfeff"}`, BytesArray{{'\xFE', '\xFF'}}},
{`{"\\xdead","\\xbeef"}`, BytesArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}},
}
func TestBytesArrayScanBytes(t *testing.T) {
for _, tt := range BytesArrayStringTests {
bytes := []byte(tt.str)
arr := BytesArray{{2}, {6}, {0, 0}}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkBytesArrayScanBytes(b *testing.B) {
var a BytesArray
var x interface{} = []byte(`{"\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff"}`)
for i := 0; i < b.N; i++ {
a = BytesArray{}
a.Scan(x)
}
}
func TestBytesArrayScanString(t *testing.T) {
for _, tt := range BytesArrayStringTests {
arr := BytesArray{{2}, {6}, {0, 0}}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestBytesArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{"\\xfeff"},{"\\xbeef"}}`, "cannot convert ARRAY[2][1] to BytesArray"},
{`{"\\abc"}`, "could not parse bytea array index 0: could not parse bytea value"},
} {
arr := BytesArray{{2}, {6}, {0, 0}}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, BytesArray{{2}, {6}, {0, 0}}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestBytesArrayValue(t *testing.T) {
result, err := BytesArray(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = BytesArray([][]byte{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = BytesArray([][]byte{{'\xDE', '\xAD', '\xBE', '\xEF'}, {'\xFE', '\xFF'}, {}}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{"\\xdeadbeef","\\xfeff","\\x"}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkBytesArrayValue(b *testing.B) {
rand.Seed(1)
x := make([][]byte, 10)
for i := 0; i < len(x); i++ {
x[i] = make([]byte, len(x))
for j := 0; j < len(x); j++ {
x[i][j] = byte(rand.Int())
}
}
a := BytesArray(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}
func TestFloat64ArrayScanUnsupported(t *testing.T) {
var arr Float64Array
err := arr.Scan(true)
if err == nil {
t.Fatal("Expected error when scanning from bool")
}
if !strings.Contains(err.Error(), "bool to Float64Array") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var Float64ArrayStringTests = []struct {
str string
arr Float64Array
}{
{`{}`, Float64Array{}},
{`{1.2}`, Float64Array{1.2}},
{`{3.456,7.89}`, Float64Array{3.456, 7.89}},
{`{3,1,2}`, Float64Array{3, 1, 2}},
}
func TestFloat64ArrayScanBytes(t *testing.T) {
for _, tt := range Float64ArrayStringTests {
bytes := []byte(tt.str)
arr := Float64Array{5, 5, 5}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkFloat64ArrayScanBytes(b *testing.B) {
var a Float64Array
var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`)
for i := 0; i < b.N; i++ {
a = Float64Array{}
a.Scan(x)
}
}
func TestFloat64ArrayScanString(t *testing.T) {
for _, tt := range Float64ArrayStringTests {
arr := Float64Array{5, 5, 5}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestFloat64ArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float64Array"},
{`{NULL}`, "parsing array element index 0:"},
{`{a}`, "parsing array element index 0:"},
{`{5.6,a}`, "parsing array element index 1:"},
{`{5.6,7.8,a}`, "parsing array element index 2:"},
} {
arr := Float64Array{5, 5, 5}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, Float64Array{5, 5, 5}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestFloat64ArrayValue(t *testing.T) {
result, err := Float64Array(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = Float64Array([]float64{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = Float64Array([]float64{1.2, 3.4, 5.6}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkFloat64ArrayValue(b *testing.B) {
rand.Seed(1)
x := make([]float64, 10)
for i := 0; i < len(x); i++ {
x[i] = rand.NormFloat64()
}
a := Float64Array(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}
func TestInt64ArrayScanUnsupported(t *testing.T) {
var arr Int64Array
err := arr.Scan(true)
if err == nil {
t.Fatal("Expected error when scanning from bool")
}
if !strings.Contains(err.Error(), "bool to Int64Array") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var Int64ArrayStringTests = []struct {
str string
arr Int64Array
}{
{`{}`, Int64Array{}},
{`{12}`, Int64Array{12}},
{`{345,678}`, Int64Array{345, 678}},
}
func TestInt64ArrayScanBytes(t *testing.T) {
for _, tt := range Int64ArrayStringTests {
bytes := []byte(tt.str)
arr := Int64Array{5, 5, 5}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkInt64ArrayScanBytes(b *testing.B) {
var a Int64Array
var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`)
for i := 0; i < b.N; i++ {
a = Int64Array{}
a.Scan(x)
}
}
func TestInt64ArrayScanString(t *testing.T) {
for _, tt := range Int64ArrayStringTests {
arr := Int64Array{5, 5, 5}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestInt64ArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int64Array"},
{`{NULL}`, "parsing array element index 0:"},
{`{a}`, "parsing array element index 0:"},
{`{5,a}`, "parsing array element index 1:"},
{`{5,6,a}`, "parsing array element index 2:"},
} {
arr := Int64Array{5, 5, 5}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, Int64Array{5, 5, 5}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestInt64ArrayValue(t *testing.T) {
result, err := Int64Array(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = Int64Array([]int64{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = Int64Array([]int64{1, 2, 3}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkInt64ArrayValue(b *testing.B) {
rand.Seed(1)
x := make([]int64, 10)
for i := 0; i < len(x); i++ {
x[i] = rand.Int63()
}
a := Int64Array(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}
func TestStringArrayScanUnsupported(t *testing.T) {
var arr StringArray
err := arr.Scan(true)
if err == nil {
t.Fatal("Expected error when scanning from bool")
}
if !strings.Contains(err.Error(), "bool to StringArray") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var StringArrayStringTests = []struct {
str string
arr StringArray
}{
{`{}`, StringArray{}},
{`{t}`, StringArray{"t"}},
{`{f,1}`, StringArray{"f", "1"}},
{`{"a\\b","c d",","}`, StringArray{"a\\b", "c d", ","}},
}
func TestStringArrayScanBytes(t *testing.T) {
for _, tt := range StringArrayStringTests {
bytes := []byte(tt.str)
arr := StringArray{"x", "x", "x"}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkStringArrayScanBytes(b *testing.B) {
var a StringArray
var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`)
var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`)
for i := 0; i < b.N; i++ {
a = StringArray{}
a.Scan(x)
a = StringArray{}
a.Scan(y)
}
}
func TestStringArrayScanString(t *testing.T) {
for _, tt := range StringArrayStringTests {
arr := StringArray{"x", "x", "x"}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestStringArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{a},{b}}`, "cannot convert ARRAY[2][1] to StringArray"},
{`{NULL}`, "parsing array element index 0: cannot convert nil to string"},
{`{a,NULL}`, "parsing array element index 1: cannot convert nil to string"},
{`{a,b,NULL}`, "parsing array element index 2: cannot convert nil to string"},
} {
arr := StringArray{"x", "x", "x"}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, StringArray{"x", "x", "x"}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestStringArrayValue(t *testing.T) {
result, err := StringArray(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = StringArray([]string{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = StringArray([]string{`a`, `\b`, `c"`, `d,e`}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{"a","\\b","c\"","d,e"}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkStringArrayValue(b *testing.B) {
x := make([]string, 10)
for i := 0; i < len(x); i++ {
x[i] = strings.Repeat(`abc"def\ghi`, 5)
}
a := StringArray(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}

135
types/hstore.go Normal file
View file

@ -0,0 +1,135 @@
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package types
import (
"database/sql"
"database/sql/driver"
"strings"
)
// HStore is a wrapper for transferring HStore values back and forth easily.
type HStore map[string]sql.NullString
// escapes and quotes hstore keys/values
// s should be a sql.NullString or string
func hQuote(s interface{}) string {
var str string
switch v := s.(type) {
case sql.NullString:
if !v.Valid {
return "NULL"
}
str = v.String
case string:
str = v
default:
panic("not a string or sql.NullString")
}
str = strings.Replace(str, "\\", "\\\\", -1)
return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"`
}
// Scan implements the Scanner interface.
//
// Note h is reallocated before the scan to clear existing values. If the
// hstore column's database value is NULL, then h is set to nil instead.
func (h *HStore) Scan(value interface{}) error {
if value == nil {
h = nil
return nil
}
*h = make(map[string]sql.NullString)
var b byte
pair := [][]byte{{}, {}}
pi := 0
inQuote := false
didQuote := false
sawSlash := false
bindex := 0
for bindex, b = range value.([]byte) {
if sawSlash {
pair[pi] = append(pair[pi], b)
sawSlash = false
continue
}
switch b {
case '\\':
sawSlash = true
continue
case '"':
inQuote = !inQuote
if !didQuote {
didQuote = true
}
continue
default:
if !inQuote {
switch b {
case ' ', '\t', '\n', '\r':
continue
case '=':
continue
case '>':
pi = 1
didQuote = false
continue
case ',':
s := string(pair[1])
if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" {
(*h)[string(pair[0])] = sql.NullString{String: "", Valid: false}
} else {
(*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true}
}
pair[0] = []byte{}
pair[1] = []byte{}
pi = 0
continue
}
}
}
pair[pi] = append(pair[pi], b)
}
if bindex > 0 {
s := string(pair[1])
if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" {
(*h)[string(pair[0])] = sql.NullString{String: "", Valid: false}
} else {
(*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true}
}
}
return nil
}
// Value implements the driver Valuer interface. Note if h is nil, the
// database column value will be set to NULL.
func (h HStore) Value() (driver.Value, error) {
if h == nil {
return nil, nil
}
parts := []string{}
for key, val := range h {
thispart := hQuote(key) + "=>" + hQuote(val)
parts = append(parts, thispart)
}
return []byte(strings.Join(parts, ",")), nil
}

77
types/json.go Normal file
View file

@ -0,0 +1,77 @@
package types
import (
"database/sql/driver"
"encoding/json"
"errors"
)
// JSON is an alias for json.RawMessage, which is
// a []byte underneath.
// JSON implements Marshal and Unmarshal.
type JSON json.RawMessage
// String output your JSON.
func (j JSON) String() string {
return string(j)
}
// Unmarshal your JSON variable into dest.
func (j JSON) Unmarshal(dest interface{}) error {
return json.Unmarshal(j, dest)
}
// Marshal obj into your JSON variable.
func (j *JSON) Marshal(obj interface{}) error {
res, err := json.Marshal(obj)
if err != nil {
return err
}
*j = res
return nil
}
// UnmarshalJSON sets *j to a copy of data.
func (j *JSON) UnmarshalJSON(data []byte) error {
if j == nil {
return errors.New("JSON: UnmarshalJSON on nil pointer")
}
*j = append((*j)[0:0], data...)
return nil
}
// MarshalJSON returns j as the JSON encoding of j.
func (j JSON) MarshalJSON() ([]byte, error) {
return j, nil
}
// Value returns j as a value.
// Unmarshal into RawMessage for validation.
func (j JSON) Value() (driver.Value, error) {
var r json.RawMessage
if err := j.Unmarshal(&r); err != nil {
return nil, err
}
return []byte(r), nil
}
// Scan stores the src in *j.
func (j *JSON) Scan(src interface{}) error {
var source []byte
switch src.(type) {
case string:
source = []byte(src.(string))
case []byte:
source = src.([]byte)
default:
return errors.New("Incompatible type for JSON")
}
*j = JSON(append((*j)[0:0], source...))
return nil
}

119
types/json_test.go Normal file
View file

@ -0,0 +1,119 @@
package types
import (
"bytes"
"testing"
)
func TestJSONString(t *testing.T) {
t.Parallel()
j := JSON("hello")
if j.String() != "hello" {
t.Errorf("Expected %q, got %s", "hello", j.String())
}
}
func TestJSONUnmarshal(t *testing.T) {
t.Parallel()
type JSONTest struct {
Name string
Age int
}
var jt JSONTest
j := JSON(`{"Name":"hi","Age":15}`)
err := j.Unmarshal(&jt)
if err != nil {
t.Error(err)
}
if jt.Name != "hi" {
t.Errorf("Expected %q, got %s", "hi", jt.Name)
}
if jt.Age != 15 {
t.Errorf("Expected %v, got %v", 15, jt.Age)
}
}
func TestJSONMarshal(t *testing.T) {
t.Parallel()
type JSONTest struct {
Name string
Age int
}
jt := JSONTest{
Name: "hi",
Age: 15,
}
var j JSON
err := j.Marshal(jt)
if err != nil {
t.Error(err)
}
if j.String() != `{"Name":"hi","Age":15}` {
t.Errorf("expected %s, got %s", `{"Name":"hi","Age":15}`, j.String())
}
}
func TestJSONUnmarshalJSON(t *testing.T) {
t.Parallel()
j := JSON(nil)
err := j.UnmarshalJSON(JSON(`"hi"`))
if err != nil {
t.Error(err)
}
if j.String() != `"hi"` {
t.Errorf("Expected %q, got %s", "hi", j.String())
}
}
func TestJSONMarshalJSON(t *testing.T) {
t.Parallel()
j := JSON(`"hi"`)
res, err := j.MarshalJSON()
if err != nil {
t.Error(err)
}
if !bytes.Equal(res, []byte(`"hi"`)) {
t.Errorf("Expected %q, got %v", `"hi"`, res)
}
}
func TestJSONValue(t *testing.T) {
t.Parallel()
j := JSON(`{"Name":"hi","Age":15}`)
v, err := j.Value()
if err != nil {
t.Error(err)
}
if !bytes.Equal(j, v.([]byte)) {
t.Errorf("byte mismatch, %v %v", j, v)
}
}
func TestJSONScan(t *testing.T) {
t.Parallel()
j := JSON{}
err := j.Scan(`"hello"`)
if err != nil {
t.Error(err)
}
if !bytes.Equal(j, []byte(`"hello"`)) {
t.Errorf("bad []byte: %#v ≠ %#v\n", j, string([]byte(`"hello"`)))
}
}