Merge branch 'dev'

This commit is contained in:
Patrick O'brien 2016-09-15 19:52:44 +10:00
commit 65dd15de09
99 changed files with 6516 additions and 3332 deletions

159
README.md
View file

@ -80,23 +80,28 @@ Table of Contents
### Features
- Full model generation
- High performance through generation
- Extremely fast code generation
- High performance through generation & intelligent caching
- Uses boil.Executor (simple interface, sql.DB, sqlx.DB etc. compatible)
- Easy workflow (models can always be regenerated, full auto-complete)
- Strongly typed querying (usually no converting or binding to pointers)
- Hooks (Before/After Create/Select/Update/Delete/Upsert)
- Automatic CreatedAt/UpdatedAt
- Table whitelist/blacklist
- Relationships/Associations
- Eager loading (recursive)
- Custom struct tags
- Schema support
- Transactions
- Raw SQL fallback
- Compatibility tests (Run against your own DB schema)
- Debug logging
- Postgres 1d arrays, json, hstore & more
### Supported Databases
- PostgreSQL
- MySQL
*Note: Seeking contributors for other database engines.*
@ -203,30 +208,32 @@ order:
- `$XDG_CONFIG_HOME/sqlboiler/`
- `$HOME/.config/sqlboiler/`
We require you pass in the `postgres` configuration via the configuration file rather than env vars.
There is no command line argument support for database configuration. Values given under the `postgres`
block are passed directly to the [pq](github.com/lib/pq) driver. Here is a rundown of all the different
We require you pass in your `postgres` and `mysql` database configuration via the configuration file rather than env vars.
There is no command line argument support for database configuration. Values given under the `postgres` and `mysql`
block are passed directly to the postgres and mysql drivers. Here is a rundown of all the different
values that can go in that section:
| Name | Required | Default |
| --- | --- | --- |
| dbname | yes | none |
| host | yes | none |
| port | no | 5432 |
| user | yes | none |
| pass | no | none |
| sslmode | no | "require" |
| Name | Required | Postgres Default | MySQL Default |
| --- | --- | --- | --- |
| dbname | yes | none | none |
| host | yes | none | none |
| port | no | 5432 | 3306 |
| user | yes | none | none |
| pass | no | none | none |
| sslmode | no | "require" | "true" |
You can also pass in these top level configuration values if you would prefer
not to pass them through the command line or environment variables:
| Name | Default |
| --- | --- |
| Name | Defaults |
| ------------------ | --------- |
| basedir | none |
| schema | "public" *(or dbname for mysql)* |
| pkgname | "models" |
| output | "models" |
| exclude | [ ] |
| tag | [ ] |
| whitelist | [] |
| blacklist | [] |
| tag | [] |
| debug | false |
| no-hooks | false |
| no-tests | false |
@ -256,23 +263,26 @@ Usage:
sqlboiler [flags] <driver>
Examples:
sqlboiler postgres
sqlboiler postgres
sqlboiler mysql
Flags:
-b, --basedir string The base directory has the templates and templates_test folders
-b, --blacklist stringSlice Do not include these tables in your generated package
-w, --whitelist stringSlice Only include these tables in your generated package
-s, --schema string The name of your database schema, for databases that support real schemas (default "public")
-p, --pkgname string The name you wish to assign to your generated package (default "models")
-o, --output string The name of the folder to output to (default "models")
-t, --tag stringSlice Struct tags to be included on your models in addition to json, yaml, toml
-d, --debug Debug mode prints stack traces on error
-x, --exclude stringSlice Tables to be excluded from the generated package
--basedir string The base directory has the templates and templates_test folders
--no-auto-timestamps Disable automatic timestamps for created_at/updated_at
--no-hooks Disable hooks feature for your models
--no-tests Disable generated go test files
-o, --output string The name of the folder to output to (default "models")
-p, --pkgname string The name you wish to assign to your generated package (default "models")
-t, --tag stringSlice Struct tags to be included on your models in addition to json, yaml, toml
```
Follow the steps below to do some basic model generation. Once we've generated
our models, we can run the compatibility tests which will exercise the entirety
of the generated code. This way we can ensure that our database is compatible
Follow the steps below to do some basic model generation. Once you've generated
your models, you can run the compatibility tests which will exercise the entirety
of the generated code. This way you can ensure that your database is compatible
with SQLBoiler. If you find there are some failing tests, please check the
[Diagnosing Problems](#diagnosing-problems) section.
@ -281,8 +291,7 @@ with SQLBoiler. If you find there are some failing tests, please check the
sqlboiler -x goose_migrations postgres
# Run the generated tests
go test ./models # This requires an administrator postgres user because of some
# voodoo we do to disable triggers for the generated test db
go test ./models
```
## Diagnosing Problems
@ -292,7 +301,7 @@ The most common causes of problems and panics are:
- Forgetting to exclude tables you do not want included in your generation, like migration tables.
- Tables without a primary key. All tables require one.
- Forgetting to put foreign key constraints on your columns that reference other tables.
- The compatibility tests that run against your own DB schema require a superuser, ensure the user
- The compatibility tests require privileges to create a database for testing purposes, ensure the user
supplied in your `sqlboiler.toml` config has adequate privileges.
- A nil or closed database handle. Ensure your passed in `boil.Executor` is not nil.
- If you decide to use the `G` variant of functions instead, make sure you've initialized your
@ -345,9 +354,8 @@ ALTER TABLE pilot_languages ADD CONSTRAINT pilots_fkey FOREIGN KEY (pilot_id) RE
ALTER TABLE pilot_languages ADD CONSTRAINT languages_fkey FOREIGN KEY (language_id) REFERENCES languages(id);
```
The generated model structs for this schema look like the following. Note that I've included the relationship
structs as well so you can see how it all pieces together, but these are unexported and not something you should
ever need to touch directly:
The generated model structs for this schema look like the following. Note that we've included the relationship
structs as well so you can see how it all pieces together:
```go
type Pilot struct {
@ -355,6 +363,7 @@ type Pilot struct {
Name string `boil:"name" json:"name" toml:"name" yaml:"name"`
R *pilotR `boil:"-" json:"-" toml:"-" yaml:"-"`
L pilotR `boil:"-" json:"-" toml:"-" yaml:"-"`
}
type pilotR struct {
@ -371,6 +380,7 @@ type Jet struct {
Color string `boil:"color" json:"color" toml:"color" yaml:"color"`
R *jetR `boil:"-" json:"-" toml:"-" yaml:"-"`
L jetR `boil:"-" json:"-" toml:"-" yaml:"-"`
}
type jetR struct {
@ -382,6 +392,7 @@ type Language struct {
Language string `boil:"language" json:"language" toml:"language" yaml:"language"`
R *languageR `boil:"-" json:"-" toml:"-" yaml:"-"`
L languageR `boil:"-" json:"-" toml:"-" yaml:"-"`
}
type languageR struct {
@ -414,7 +425,7 @@ Note: You can set the timezone for this feature by calling `boil.SetLocation()`
This is somewhat of a work around until we can devise a better solution in a later version.
* **Update**
* The `updated_at` column will always be set to `time.Now()`. If you need to override
this value you will need to fall back to another method in the meantime: `boil.SQL()`,
this value you will need to fall back to another method in the meantime: `queries.Raw()`,
overriding `updated_at` in all of your objects using a hook, or create your own wrapper.
* **Upsert**
* `created_at` will be set automatically if it is a zero value, otherwise your supplied value
@ -452,36 +463,8 @@ err := models.NewQuery(db, From("pilots")).All()
As you can see, [Query Mods](#query-mods) allow you to modify your queries, and [Finishers](#finishers)
allow you to execute the final action.
If you plan on executing the same query with the same values using the query builder,
you should do so like the following to utilize caching:
```go
// Instead of this:
for i := 0; i < 10; i++ {
pilots := models.Pilots(qm.Where("id > ?", 5), qm.Limit(5)).All()
}
// You should do this
query := models.Pilots(qm.Where("id > ?", 5), qm.Limit(5))
for i := 0; i < 10; i++ {
pilots := query.All()
}
// Every execution of All() after the first will use a cached version of
// the built query that short circuits the query builder all together.
// This allows you to save on performance.
// Just something to be aware of: query mods don't store pointers, so if
// your passed in variable's value changes, your generated query will not change.
```
Note: You will see exported `boil.SetX` methods in the boil package. These should not be used on query
objects because they will break caching. Unfortunately these had to be exported due to some circular
dependency issues, but they're not functionality we want exposed. If you want a different
query object, generate a new one.
Take a look at our [Relationships Query Building](#relationships) section for some additional query
building information.
We also generate query building helper methods for your relationships as well. Take a look at our
[Relationships Query Building](#relationships) section for some additional query building information.
### Query Mod System
@ -575,26 +558,31 @@ UpdateAll(models.M{"name": "John", "age": 23}) // Update all rows matching the b
DeleteAll() // Delete all rows matching the built query.
Exists() // Returns a bool indicating whether the row(s) for the built query exists.
Bind(&myObj) // Bind the results of a query to your own struct object.
Exec() // Execute an SQL query that does not require any rows returned.
QueryRow() // Execute an SQL query expected to return only a single row.
Query() // Execute an SQL query expected to return multiple rows.
```
### Raw Query
We provide `boil.SQL()` for executing raw queries. Generally you will want to use `Bind()` with
We provide `queries.Raw()` for executing raw queries. Generally you will want to use `Bind()` with
this, like the following:
```go
err := boil.SQL(db, "select * from pilots where id=$1", 5).Bind(&obj)
err := queries.Raw(db, "select * from pilots where id=$1", 5).Bind(&obj)
```
You can use your own structs or a generated struct as a parameter to Bind. Bind supports both
a single object for single row queries and a slice of objects for multiple row queries.
You also have `models.NewQuery()` at your disposal if you would still like to use [Query Build](#query-building)
but would like to build against a non-generated model.
`queries.Raw()` also has a method that can execute a query without binding to an object, if required.
You also have `models.NewQuery()` at your disposal if you would still like to use [Query Building](#query-building)
in combination with your own custom, non-generated model.
### Binding
For a comprehensive ruleset for `Bind()` you can refer to our [godoc](https://godoc.org/github.com/vattle/sqlboiler/boil#Bind).
For a comprehensive ruleset for `Bind()` you can refer to our [godoc](https://godoc.org/github.com/vattle/sqlboiler/queries#Bind).
The `Bind()` [Finisher](#finisher) allows the results of a query built with
the [Raw SQL](#raw-query) method or the [Query Builder](#query-building) methods to be bound
@ -613,7 +601,7 @@ type PilotAndJet struct {
var paj PilotAndJet
// Use a raw query
err := boil.SQL(`
err := queries.Raw(`
select pilots.id as "pilots.id", pilots.name as "pilots.name",
jets.id as "jets.id", jets.pilot_id as "jets.pilot_id",
jets.age as "jets.age", jets.name as "jets.name", jets.color as "jets.color"
@ -641,7 +629,7 @@ var info JetInfo
err := models.NewQuery(db, Select("sum(age) as age_sum", "count(*) as juicy_count", From("jets"))).Bind(&info)
// Use a raw query
err := boil.SQL(`select sum(age) as "age_sum", count(*) as "juicy_count" from jets`).Bind(&info)
err := queries.Raw(`select sum(age) as "age_sum", count(*) as "juicy_count" from jets`).Bind(&info)
```
We support the following struct tag modes for `Bind()` control:
@ -905,8 +893,8 @@ err := p1.Insert(db) // Insert the first pilot with name "Larry"
// p1 now has an ID field set to 1
var p2 models.Pilot
p2.Name "Borris"
err := p2.Insert(db) // Insert the second pilot with name "Borris"
p2.Name "Boris"
err := p2.Insert(db) // Insert the second pilot with name "Boris"
// p2 now has an ID field set to 2
var p3 models.Pilot
@ -999,8 +987,13 @@ p1.Name = "Hogan"
err := p1.Upsert(db, true, []string{"id"}, []string{"name"}, "id", "name")
```
The `updateOnConflict` argument allows you to specify whether you would like Postgres
to perform a `DO NOTHING` on conflict, opposed to a `DO UPDATE`. For MySQL, this param will not be generated.
The `conflictColumns` argument allows you to specify the `ON CONFLICT` columns for Postgres.
For MySQL, this param will not be generated.
Note: Passing a different set of column values to the update component is not currently supported.
If this feature is important to you let us know and we can consider adding something for this.
### Reload
In the event that your objects get out of sync with the database for whatever reason,
@ -1010,7 +1003,7 @@ attached to the objects.
```go
pilot, _ := models.FindPilot(db, 1)
// > Object becomes out of sync for some reason
// > Object becomes out of sync for some reason, perhaps async processing
// Refresh the object with the latest data from the db
err := pilot.Reload(db)
@ -1051,10 +1044,26 @@ The generated models might import a couple of packages that are not on your syst
`cd` into your generated models directory and type `go get -u -t` to fetch them. You will only need
to run this command once, not per generation.
#### How should I handle multiple schemas?
If your database uses multiple schemas you should generate a new package for each of your schemas.
Note that this only applies to databases that use real, SQL standard schemas (like PostgreSQL), not
fake schemas (like MySQL).
#### How do I use types.BytesArray for Postgres bytea arrays?
Only "escaped format" is supported for types.BytesArray. This means that your byte slice needs to have
a format of "\\x00" (4 bytes per byte) opposed to "\x00" (1 byte per byte). This is to maintain compatibility
with all Postgres drivers. Example:
`x := types.BytesArray{0: []byte("\\x68\\x69")}`
Please note that multi-dimensional Postgres ARRAY types are not supported at this time.
#### Where is the homepage?
The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler)
[Golang ORM](https://github.com/vattle/sqlboiler) generator is located at: https://github.com/vattle/sqlboiler
The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) [Golang ORM](https://github.com/vattle/sqlboiler)
generator is located at: https://github.com/vattle/sqlboiler
## Benchmarks

View file

@ -5,6 +5,11 @@ import "github.com/vattle/sqlboiler/strmangle"
// Column holds information about a database column.
// Types are Go types, converted by TranslateColumnType.
type Column struct {
// ArrType is the underlying data type of the Postgres
// ARRAY type. See here:
// https://www.postgresql.org/docs/9.1/static/infoschema-element-types.html
ArrType *string
UDTName string
Name string
Type string
DBType string

View file

@ -9,13 +9,16 @@ import (
type MockDriver struct{}
// TableNames returns a list of mock table names
func (m *MockDriver) TableNames(exclude []string) ([]string, error) {
func (m *MockDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
if len(whitelist) > 0 {
return whitelist, nil
}
tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"}
return strmangle.SetComplement(tables, exclude), nil
return strmangle.SetComplement(tables, blacklist), nil
}
// Columns returns a list of mock columns
func (m *MockDriver) Columns(tableName string) ([]bdb.Column, error) {
func (m *MockDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
return map[string][]bdb.Column{
"pilots": {
{Name: "id", Type: "int", DBType: "integer"},
@ -56,7 +59,7 @@ func (m *MockDriver) Columns(tableName string) ([]bdb.Column, error) {
}
// ForeignKeyInfo returns a list of mock foreignkeys
func (m *MockDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) {
func (m *MockDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) {
return map[string][]bdb.ForeignKey{
"jets": {
{Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true},
@ -79,7 +82,7 @@ func (m *MockDriver) TranslateColumnType(c bdb.Column) bdb.Column {
}
// PrimaryKeyInfo returns mock primary key info for the passed in table name
func (m *MockDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) {
func (m *MockDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) {
return map[string]*bdb.PrimaryKey{
"pilots": {
Name: "pilot_id_pkey",
@ -120,3 +123,18 @@ func (m *MockDriver) Open() error { return nil }
// Close mimics a database close call
func (m *MockDriver) Close() {}
// RightQuote is the quoting character for the right side of the identifier
func (m *MockDriver) RightQuote() byte {
return '"'
}
// LeftQuote is the quoting character for the left side of the identifier
func (m *MockDriver) LeftQuote() byte {
return '"'
}
// IndexPlaceholders returns true to indicate fake support of indexed placeholders
func (m *MockDriver) IndexPlaceholders() bool {
return false
}

321
bdb/drivers/mysql.go Normal file
View file

@ -0,0 +1,321 @@
package drivers
import (
"database/sql"
"fmt"
"strconv"
"strings"
"github.com/go-sql-driver/mysql"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/bdb"
)
// MySQLDriver holds the database connection string and a handle
// to the database connection.
type MySQLDriver struct {
connStr string
dbConn *sql.DB
}
// NewMySQLDriver takes the database connection details as parameters and
// returns a pointer to a MySQLDriver object. Note that it is required to
// call MySQLDriver.Open() and MySQLDriver.Close() to open and close
// the database connection once an object has been obtained.
func NewMySQLDriver(user, pass, dbname, host string, port int, sslmode string) *MySQLDriver {
driver := MySQLDriver{
connStr: MySQLBuildQueryString(user, pass, dbname, host, port, sslmode),
}
return &driver
}
// MySQLBuildQueryString builds a query string for MySQL.
func MySQLBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
var config mysql.Config
config.User = user
if len(pass) != 0 {
config.Passwd = pass
}
config.DBName = dbname
config.Net = "tcp"
config.Addr = host
if port == 0 {
port = 3306
}
config.Addr += ":" + strconv.Itoa(port)
config.TLSConfig = sslmode
return config.FormatDSN()
}
// Open opens the database connection using the connection string
func (m *MySQLDriver) Open() error {
var err error
m.dbConn, err = sql.Open("mysql", m.connStr)
if err != nil {
return err
}
return nil
}
// Close closes the database connection
func (m *MySQLDriver) Close() {
m.dbConn.Close()
}
// UseLastInsertID returns false for postgres
func (m *MySQLDriver) UseLastInsertID() bool {
return true
}
// TableNames connects to the postgres database and
// retrieves all table names from the information_schema where the
// table schema is public.
func (m *MySQLDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
var names []string
query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ?`)
args := []interface{}{schema}
if len(whitelist) > 0 {
query += fmt.Sprintf(" and table_name in (%s);", strings.Repeat(",?", len(whitelist))[1:])
for _, w := range whitelist {
args = append(args, w)
}
} else if len(blacklist) > 0 {
query += fmt.Sprintf(" and table_name not in (%s);", strings.Repeat(",?", len(blacklist))[1:])
for _, b := range blacklist {
args = append(args, b)
}
}
rows, err := m.dbConn.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, err
}
names = append(names, name)
}
return names, nil
}
// Columns takes a table name and attempts to retrieve the table information
// from the database information_schema.columns. It retrieves the column names
// and column types and returns those as a []Column after TranslateColumnType()
// converts the SQL types to Go types, for example: "varchar" to "string"
func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
var columns []bdb.Column
rows, err := m.dbConn.Query(`
select column_name, data_type, if(extra = 'auto_increment','auto_increment', column_default), is_nullable,
exists (
select c.column_name
from information_schema.table_constraints tc
inner join information_schema.key_column_usage kcu
on tc.constraint_name = kcu.constraint_name and tc.table_name = kcu.table_name and tc.table_schema = kcu.table_schema
where c.column_name = kcu.column_name and tc.table_name = c.table_name and
(tc.constraint_type = 'PRIMARY KEY' or tc.constraint_type = 'UNIQUE')
) as is_unique
from information_schema.columns as c
where table_name = ? and table_schema = ?;
`, tableName, schema)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var colName, colType, colDefault, nullable string
var unique bool
var defaultPtr *string
if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
}
if defaultPtr != nil && *defaultPtr != "NULL" {
colDefault = *defaultPtr
}
column := bdb.Column{
Name: colName,
DBType: colType,
Default: colDefault,
Nullable: nullable == "YES",
Unique: unique,
}
columns = append(columns, column)
}
return columns, nil
}
// PrimaryKeyInfo looks up the primary key for a table.
func (m *MySQLDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) {
pkey := &bdb.PrimaryKey{}
var err error
query := `
select tc.constraint_name
from information_schema.table_constraints as tc
where tc.table_name = ? and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = ?;`
row := m.dbConn.QueryRow(query, tableName, schema)
if err = row.Scan(&pkey.Name); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
queryColumns := `
select kcu.column_name
from information_schema.key_column_usage as kcu
where table_name = ? and constraint_name = ? and table_schema = ?;`
var rows *sql.Rows
if rows, err = m.dbConn.Query(queryColumns, tableName, pkey.Name, schema); err != nil {
return nil, err
}
defer rows.Close()
var columns []string
for rows.Next() {
var column string
err = rows.Scan(&column)
if err != nil {
return nil, err
}
columns = append(columns, column)
}
if err = rows.Err(); err != nil {
return nil, err
}
pkey.Columns = columns
return pkey, nil
}
// ForeignKeyInfo retrieves the foreign keys for a given table name.
func (m *MySQLDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) {
var fkeys []bdb.ForeignKey
query := `
select constraint_name, table_name, column_name, referenced_table_name, referenced_column_name
from information_schema.key_column_usage
where table_schema = ? and referenced_table_schema = ? and table_name = ?
`
var rows *sql.Rows
var err error
if rows, err = m.dbConn.Query(query, schema, schema, tableName); err != nil {
return nil, err
}
for rows.Next() {
var fkey bdb.ForeignKey
var sourceTable string
fkey.Table = tableName
err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn)
if err != nil {
return nil, err
}
fkeys = append(fkeys, fkey)
}
if err = rows.Err(); err != nil {
return nil, err
}
return fkeys, nil
}
// TranslateColumnType converts postgres database types to Go types, for example
// "varchar" to "string" and "bigint" to "int64". It returns this parsed data
// as a Column object.
func (m *MySQLDriver) TranslateColumnType(c bdb.Column) bdb.Column {
if c.Nullable {
switch c.DBType {
case "tinyint":
c.Type = "null.Int8"
case "smallint":
c.Type = "null.Int16"
case "mediumint", "int", "integer":
c.Type = "null.Int"
case "bigint":
c.Type = "null.Int64"
case "float":
c.Type = "null.Float32"
case "double", "double precision", "real":
c.Type = "null.Float64"
case "boolean", "bool":
c.Type = "null.Bool"
case "date", "datetime", "timestamp", "time":
c.Type = "null.Time"
case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob":
c.Type = "null.Bytes"
case "json":
c.Type = "types.JSON"
default:
c.Type = "null.String"
}
} else {
switch c.DBType {
case "tinyint":
c.Type = "int8"
case "smallint":
c.Type = "int16"
case "mediumint", "int", "integer":
c.Type = "int"
case "bigint":
c.Type = "null.Int64"
case "float":
c.Type = "float32"
case "double", "double precision", "real":
c.Type = "float64"
case "boolean", "bool":
c.Type = "bool"
case "date", "datetime", "timestamp", "time":
c.Type = "time.Time"
case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob":
c.Type = "[]byte"
case "json":
c.Type = "types.JSON"
default:
c.Type = "string"
}
}
return c
}
// RightQuote is the quoting character for the right side of the identifier
func (m *MySQLDriver) RightQuote() byte {
return '`'
}
// LeftQuote is the quoting character for the left side of the identifier
func (m *MySQLDriver) LeftQuote() byte {
return '`'
}
// IndexPlaceholders returns false to indicate MySQL doesnt support indexed placeholders
func (m *MySQLDriver) IndexPlaceholders() bool {
return false
}

View file

@ -19,23 +19,20 @@ type PostgresDriver struct {
dbConn *sql.DB
}
// validatedTypes are types that cannot be zero values in the database.
var validatedTypes = []string{"uuid"}
// NewPostgresDriver takes the database connection details as parameters and
// returns a pointer to a PostgresDriver object. Note that it is required to
// call PostgresDriver.Open() and PostgresDriver.Close() to open and close
// the database connection once an object has been obtained.
func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
driver := PostgresDriver{
connStr: BuildQueryString(user, pass, dbname, host, port, sslmode),
connStr: PostgresBuildQueryString(user, pass, dbname, host, port, sslmode),
}
return &driver
}
// BuildQueryString for Postgres
func BuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
// PostgresBuildQueryString builds a query string.
func PostgresBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
parts := []string{}
if len(user) != 0 {
parts = append(parts, fmt.Sprintf("user=%s", user))
@ -82,21 +79,25 @@ func (p *PostgresDriver) UseLastInsertID() bool {
// TableNames connects to the postgres database and
// retrieves all table names from the information_schema where the
// table schema is public. It excludes common migration tool tables
// such as gorp_migrations
func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) {
// table schema is schema. It uses a whitelist and blacklist.
func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
var names []string
query := `select table_name from information_schema.tables where table_schema = 'public'`
if len(exclude) > 0 {
quoteStr := func(x string) string {
return `'` + x + `'`
query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = $1`)
args := []interface{}{schema}
if len(whitelist) > 0 {
query += fmt.Sprintf(" and table_name in (%s);", strmangle.Placeholders(true, len(whitelist), 2, 1))
for _, w := range whitelist {
args = append(args, w)
}
} else if len(blacklist) > 0 {
query += fmt.Sprintf(" and table_name not in (%s);", strmangle.Placeholders(true, len(blacklist), 2, 1))
for _, b := range blacklist {
args = append(args, b)
}
exclude = strmangle.StringMap(quoteStr, exclude)
query = query + fmt.Sprintf("and table_name not in (%s);", strings.Join(exclude, ","))
}
rows, err := p.dbConn.Query(query)
rows, err := p.dbConn.Query(query, args...)
if err != nil {
return nil, err
@ -118,11 +119,11 @@ func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) {
// from the database information_schema.columns. It retrieves the column names
// and column types and returns those as a []Column after TranslateColumnType()
// converts the SQL types to Go types, for example: "varchar" to "string"
func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) {
func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
var columns []bdb.Column
rows, err := p.dbConn.Query(`
select column_name, data_type, column_default, is_nullable,
select column_name, c.data_type, e.data_type, column_default, c.udt_name, is_nullable,
(select exists(
select 1
from information_schema.constraint_column_usage as ccu
@ -136,11 +137,13 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) {
inner join pg_index pgi on pgi.indexrelid = pgc.oid
inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey)
where
pgix.schemaname = 'public' and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true
pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true
)) as is_unique
from information_schema.columns as c
where table_name=$1 and table_schema = 'public';
`, tableName)
from information_schema.columns as c LEFT JOIN information_schema.element_types e
ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier)
= (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier))
where c.table_name=$2 and c.table_schema = $1;
`, schema, tableName)
if err != nil {
return nil, err
@ -148,10 +151,11 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) {
defer rows.Close()
for rows.Next() {
var colName, colType, colDefault, nullable string
var colName, udtName, colType, colDefault, nullable string
var elementType *string
var unique bool
var defaultPtr *string
if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil {
if err := rows.Scan(&colName, &colType, &elementType, &defaultPtr, &udtName, &nullable, &unique); err != nil {
return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
}
@ -162,12 +166,13 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) {
}
column := bdb.Column{
Name: colName,
DBType: colType,
Default: colDefault,
Nullable: nullable == "YES",
Unique: unique,
Validated: isValidated(colType),
Name: colName,
DBType: colType,
ArrType: elementType,
UDTName: udtName,
Default: colDefault,
Nullable: nullable == "YES",
Unique: unique,
}
columns = append(columns, column)
}
@ -176,16 +181,16 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) {
}
// PrimaryKeyInfo looks up the primary key for a table.
func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) {
func (p *PostgresDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) {
pkey := &bdb.PrimaryKey{}
var err error
query := `
select tc.constraint_name
from information_schema.table_constraints as tc
where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = 'public';`
where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = $2;`
row := p.dbConn.QueryRow(query, tableName)
row := p.dbConn.QueryRow(query, tableName, schema)
if err = row.Scan(&pkey.Name); err != nil {
if err == sql.ErrNoRows {
return nil, nil
@ -196,10 +201,10 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, erro
queryColumns := `
select kcu.column_name
from information_schema.key_column_usage as kcu
where constraint_name = $1 and table_schema = 'public';`
where constraint_name = $1 and table_schema = $2;`
var rows *sql.Rows
if rows, err = p.dbConn.Query(queryColumns, pkey.Name); err != nil {
if rows, err = p.dbConn.Query(queryColumns, pkey.Name, schema); err != nil {
return nil, err
}
defer rows.Close()
@ -226,7 +231,7 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, erro
}
// ForeignKeyInfo retrieves the foreign keys for a given table name.
func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) {
func (p *PostgresDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) {
var fkeys []bdb.ForeignKey
query := `
@ -239,11 +244,11 @@ func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, err
from information_schema.table_constraints as tc
inner join information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name
inner join information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name
where tc.table_name = $1 and tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = 'public';`
where tc.table_name = $1 and tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = $2;`
var rows *sql.Rows
var err error
if rows, err = p.dbConn.Query(query, tableName); err != nil {
if rows, err = p.dbConn.Query(query, tableName, schema); err != nil {
return nil, err
}
@ -279,18 +284,35 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "null.Int"
case "smallint", "smallserial":
c.Type = "null.Int16"
case "decimal", "numeric", "double precision", "money":
case "decimal", "numeric", "double precision":
c.Type = "null.Float64"
case "real":
c.Type = "null.Float32"
case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "null.String"
case "bytea":
c.Type = "[]byte"
c.Type = "null.Bytes"
case "json", "jsonb":
c.Type = "null.JSON"
case "boolean":
c.Type = "null.Bool"
case "date", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "null.Time"
case "ARRAY":
if c.ArrType == nil {
panic("unable to get postgres ARRAY underlying type")
}
c.Type = getArrayType(c)
// Make DBType something like ARRAYinteger for parsing with randomize.Struct
c.DBType = c.DBType + *c.ArrType
case "USER-DEFINED":
if c.UDTName == "hstore" {
c.Type = "types.HStore"
c.DBType = "hstore"
} else {
c.Type = "string"
fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName)
}
default:
c.Type = "null.String"
}
@ -302,18 +324,32 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
c.Type = "int"
case "smallint", "smallserial":
c.Type = "int16"
case "decimal", "numeric", "double precision", "money":
case "decimal", "numeric", "double precision":
c.Type = "float64"
case "real":
c.Type = "float32"
case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
c.Type = "string"
case "json", "jsonb":
c.Type = "types.JSON"
case "bytea":
c.Type = "[]byte"
case "boolean":
c.Type = "bool"
case "date", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "time.Time"
case "ARRAY":
c.Type = getArrayType(c)
// Make DBType something like ARRAYinteger for parsing with randomize.Struct
c.DBType = c.DBType + *c.ArrType
case "USER-DEFINED":
if c.UDTName == "hstore" {
c.Type = "types.HStore"
c.DBType = "hstore"
} else {
c.Type = "string"
fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName)
}
default:
c.Type = "string"
}
@ -322,13 +358,35 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column {
return c
}
// isValidated checks if the database type is in the validatedTypes list.
func isValidated(typ string) bool {
for _, v := range validatedTypes {
if v == typ {
return true
}
// getArrayType returns the correct boil.Array type for each database type
func getArrayType(c bdb.Column) string {
switch *c.ArrType {
case "bigint", "bigserial", "integer", "serial", "smallint", "smallserial":
return "types.Int64Array"
case "bytea":
return "types.BytesArray"
case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
return "types.StringArray"
case "boolean":
return "types.BoolArray"
case "decimal", "numeric", "double precision", "real":
return "types.Float64Array"
default:
return "types.StringArray"
}
return false
}
// RightQuote is the quoting character for the right side of the identifier
func (p *PostgresDriver) RightQuote() byte {
return '"'
}
// LeftQuote is the quoting character for the left side of the identifier
func (p *PostgresDriver) LeftQuote() byte {
return '"'
}
// IndexPlaceholders returns true to indicate PSQL supports indexed placeholders
func (p *PostgresDriver) IndexPlaceholders() bool {
return true
}

View file

@ -6,10 +6,10 @@ import "github.com/pkg/errors"
// Interface for a database driver. Functionality required to support a specific
// database type (eg, MySQL, Postgres etc.)
type Interface interface {
TableNames(exclude []string) ([]string, error)
Columns(tableName string) ([]Column, error)
PrimaryKeyInfo(tableName string) (*PrimaryKey, error)
ForeignKeyInfo(tableName string) ([]ForeignKey, error)
TableNames(schema string, whitelist, blacklist []string) ([]string, error)
Columns(schema, tableName string) ([]Column, error)
PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error)
ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error)
// TranslateColumnType takes a Database column type and returns a go column type.
TranslateColumnType(Column) Column
@ -22,23 +22,32 @@ type Interface interface {
Open() error
// Close the database connection
Close()
// Dialect helpers, these provide the values that will go into
// a queries.Dialect, so the query builder knows how to support
// your database driver properly.
LeftQuote() byte
RightQuote() byte
IndexPlaceholders() bool
}
// Tables returns the metadata for all tables, minus the tables
// specified in the exclude slice.
func Tables(db Interface, exclude ...string) ([]Table, error) {
// specified in the blacklist.
func Tables(db Interface, schema string, whitelist, blacklist []string) ([]Table, error) {
var err error
names, err := db.TableNames(exclude)
names, err := db.TableNames(schema, whitelist, blacklist)
if err != nil {
return nil, errors.Wrap(err, "unable to get table names")
}
var tables []Table
for _, name := range names {
t := Table{Name: name}
t := Table{
Name: name,
}
if t.Columns, err = db.Columns(name); err != nil {
if t.Columns, err = db.Columns(schema, name); err != nil {
return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name)
}
@ -46,11 +55,11 @@ func Tables(db Interface, exclude ...string) ([]Table, error) {
t.Columns[i] = db.TranslateColumnType(c)
}
if t.PKey, err = db.PrimaryKeyInfo(name); err != nil {
if t.PKey, err = db.PrimaryKeyInfo(schema, name); err != nil {
return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name)
}
if t.FKeys, err = db.ForeignKeyInfo(name); err != nil {
if t.FKeys, err = db.ForeignKeyInfo(schema, name); err != nil {
return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name)
}

View file

@ -6,20 +6,23 @@ import (
"github.com/vattle/sqlboiler/strmangle"
)
type mockDriver struct{}
type testMockDriver struct{}
func (m mockDriver) TranslateColumnType(c Column) Column { return c }
func (m mockDriver) UseLastInsertID() bool { return false }
func (m mockDriver) Open() error { return nil }
func (m mockDriver) Close() {}
func (m testMockDriver) TranslateColumnType(c Column) Column { return c }
func (m testMockDriver) UseLastInsertID() bool { return false }
func (m testMockDriver) Open() error { return nil }
func (m testMockDriver) Close() {}
func (m mockDriver) TableNames(exclude []string) ([]string, error) {
func (m testMockDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
if len(whitelist) > 0 {
return whitelist, nil
}
tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"}
return strmangle.SetComplement(tables, exclude), nil
return strmangle.SetComplement(tables, blacklist), nil
}
// Columns returns a list of mock columns
func (m mockDriver) Columns(tableName string) ([]Column, error) {
func (m testMockDriver) Columns(schema, tableName string) ([]Column, error) {
return map[string][]Column{
"pilots": {
{Name: "id", Type: "int", DBType: "integer"},
@ -61,7 +64,7 @@ func (m mockDriver) Columns(tableName string) ([]Column, error) {
}
// ForeignKeyInfo returns a list of mock foreignkeys
func (m mockDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) {
func (m testMockDriver) ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error) {
return map[string][]ForeignKey{
"jets": {
{Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true},
@ -81,7 +84,7 @@ func (m mockDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) {
}
// PrimaryKeyInfo returns mock primary key info for the passed in table name
func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
func (m testMockDriver) PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error) {
return map[string]*PrimaryKey{
"pilots": {Name: "pilot_id_pkey", Columns: []string{"id"}},
"airports": {Name: "airport_id_pkey", Columns: []string{"id"}},
@ -93,10 +96,25 @@ func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
}[tableName], nil
}
// RightQuote is the quoting character for the right side of the identifier
func (m testMockDriver) RightQuote() byte {
return '"'
}
// LeftQuote is the quoting character for the left side of the identifier
func (m testMockDriver) LeftQuote() byte {
return '"'
}
// IndexPlaceholders returns true to indicate fake support of indexed placeholders
func (m testMockDriver) IndexPlaceholders() bool {
return false
}
func TestTables(t *testing.T) {
t.Parallel()
tables, err := Tables(mockDriver{})
tables, err := Tables(testMockDriver{}, "public", nil, nil)
if err != nil {
t.Error(err)
}

View file

@ -3,6 +3,7 @@ package bdb
import (
"fmt"
"regexp"
"strings"
)
var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`)
@ -79,3 +80,30 @@ func SQLColDefinitions(cols []Column, names []string) SQLColumnDefs {
return ret
}
// AutoIncPrimaryKey returns the auto-increment primary key column name or an
// empty string. Primary key columns with default values are presumed
// to be auto-increment, because pkeys need to be unique and a static
// default value would cause collisions.
func AutoIncPrimaryKey(cols []Column, pkey *PrimaryKey) *Column {
if pkey == nil {
return nil
}
for _, pkeyColumn := range pkey.Columns {
for _, c := range cols {
if c.Name != pkeyColumn {
continue
}
if c.Default != "auto_increment" || c.Nullable ||
!(strings.HasPrefix(c.Type, "int") || strings.HasPrefix(c.Type, "uint")) {
continue
}
return &c
}
}
return nil
}

View file

@ -4,8 +4,11 @@ import "fmt"
// Table metadata from the database schema.
type Table struct {
Name string
Columns []Column
Name string
// For dbs with real schemas, like Postgres.
// Example value: "schema_name"."table_name"
SchemaName string
Columns []Column
PKey *PrimaryKey
FKeys []ForeignKey

View file

@ -3,10 +3,12 @@ package main
// Config for the running of the commands
type Config struct {
DriverName string
Schema string
PkgName string
OutFolder string
BaseDir string
ExcludeTables []string
WhitelistTables []string
BlacklistTables []string
Tags []string
Debug bool
NoTests bool
@ -14,6 +16,7 @@ type Config struct {
NoAutoTimestamps bool
Postgres PostgresConfig
MySQL MySQLConfig
}
// PostgresConfig configures a postgres database
@ -25,3 +28,13 @@ type PostgresConfig struct {
DBName string
SSLMode string
}
// MySQLConfig configures a mysql database
type MySQLConfig struct {
User string
Pass string
Host string
Port int
DBName string
SSLMode string
}

View file

@ -153,7 +153,8 @@ var defaultTemplateImports = imports{
thirdParty: importList{
`"github.com/pkg/errors"`,
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/boil/qm"`,
`"github.com/vattle/sqlboiler/queries"`,
`"github.com/vattle/sqlboiler/queries/qm"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
}
@ -162,7 +163,8 @@ var defaultSingletonTemplateImports = map[string]imports{
"boil_queries": {
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/boil/qm"`,
`"github.com/vattle/sqlboiler/queries"`,
`"github.com/vattle/sqlboiler/queries/qm"`,
},
},
"boil_types": {
@ -180,29 +182,38 @@ var defaultTestTemplateImports = imports{
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/boil/randomize"`,
`"github.com/vattle/sqlboiler/randomize"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
}
var defaultSingletonTestTemplateImports = map[string]imports{
"boil_viper_test": {
"boil_main_test": {
standard: importList{
`"database/sql"`,
`"flag"`,
`"fmt"`,
`"math/rand"`,
`"os"`,
`"path/filepath"`,
`"testing"`,
`"time"`,
},
thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
},
},
"boil_queries_test": {
standard: importList{
`"crypto/md5"`,
`"bytes"`,
`"fmt"`,
`"os"`,
`"strconv"`,
`"io"`,
`"io/ioutil"`,
`"math/rand"`,
`"regexp"`,
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
@ -218,27 +229,42 @@ var defaultSingletonTestTemplateImports = map[string]imports{
var defaultTestMainImports = map[string]imports{
"postgres": {
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"flag"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"database/sql"`,
`"path/filepath"`,
`"time"`,
`"math/rand"`,
`"fmt"`,
`"io"`,
`"io/ioutil"`,
`"os"`,
`"os/exec"`,
`"strings"`,
},
thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
`"github.com/vattle/sqlboiler/randomize"`,
`_ "github.com/lib/pq"`,
},
},
"mysql": {
standard: importList{
`"bytes"`,
`"database/sql"`,
`"fmt"`,
`"io"`,
`"io/ioutil"`,
`"os"`,
`"os/exec"`,
`"strings"`,
},
thirdParty: importList{
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
`"github.com/vattle/sqlboiler/randomize"`,
`_ "github.com/go-sql-driver/mysql"`,
},
},
}
// importsBasedOnType imports are only included in the template output if the
@ -246,51 +272,75 @@ var defaultTestMainImports = map[string]imports{
// TranslateColumnType to see the type assignments.
var importsBasedOnType = map[string]imports{
"null.Float32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Float64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Int": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Int8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Int16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Int32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Int64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Uint": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Uint8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Uint16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Uint32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Uint64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.String": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Bool": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Time": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.JSON": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"null.Bytes": {
thirdParty: importList{`"gopkg.in/nullbio/null.v5"`},
},
"time.Time": {
standard: importList{`"time"`},
},
"types.JSON": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.BytesArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Int64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Float64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.BoolArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Hstore": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
}

View file

@ -75,7 +75,7 @@ func TestCombineTypeImports(t *testing.T) {
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v4"`,
`"gopkg.in/nullbio/null.v5"`,
},
}
@ -108,7 +108,7 @@ func TestCombineTypeImports(t *testing.T) {
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"gopkg.in/nullbio/null.v4"`,
`"gopkg.in/nullbio/null.v5"`,
},
}
@ -124,7 +124,7 @@ func TestCombineImports(t *testing.T) {
a := imports{
standard: importList{"fmt"},
thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v4"},
thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v5"},
}
b := imports{
standard: importList{"os"},
@ -136,8 +136,8 @@ func TestCombineImports(t *testing.T) {
if c.standard[0] != "fmt" && c.standard[1] != "os" {
t.Errorf("Wanted: fmt, os got: %#v", c.standard)
}
if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v4" {
t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdParty)
if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v5" {
t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v5 got: %#v", c.thirdParty)
}
}

85
main.go
View file

@ -2,7 +2,6 @@
package main
import (
"errors"
"fmt"
"os"
"path/filepath"
@ -61,9 +60,11 @@ func main() {
// Set up the cobra root command flags
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
rootCmd.PersistentFlags().StringP("schema", "s", "public", "The name of your database schema, for databases that support real schemas")
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory has the templates and templates_test folders")
rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package")
rootCmd.PersistentFlags().StringP("basedir", "", "", "The base directory has the templates and templates_test folders")
rootCmd.PersistentFlags().StringSliceP("blacklist", "b", nil, "Do not include these tables in your generated package")
rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package")
rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml")
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files")
@ -72,6 +73,9 @@ func main() {
viper.SetDefault("postgres.sslmode", "require")
viper.SetDefault("postgres.port", "5432")
viper.SetDefault("mysql.sslmode", "true")
viper.SetDefault("mysql.port", "3306")
viper.BindPFlags(rootCmd.PersistentFlags())
viper.AutomaticEnv()
@ -79,7 +83,7 @@ func main() {
if e, ok := err.(commandFailure); ok {
fmt.Printf("Error: %v\n\n", string(e))
rootCmd.Help()
} else if !cmdConfig.Debug {
} else if !viper.GetBool("debug") {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Printf("Error: %+v\n", err)
@ -107,6 +111,7 @@ func preRun(cmd *cobra.Command, args []string) error {
cmdConfig = &Config{
DriverName: driverName,
OutFolder: viper.GetString("output"),
Schema: viper.GetString("schema"),
PkgName: viper.GetString("pkgname"),
Debug: viper.GetBool("debug"),
NoTests: viper.GetBool("no-tests"),
@ -115,12 +120,20 @@ func preRun(cmd *cobra.Command, args []string) error {
}
// BUG: https://github.com/spf13/viper/issues/200
// Look up the value of ExcludeTables & Tags directly from PFlags in Cobra if we
// Look up the value of blacklist, whitelist & tags directly from PFlags in Cobra if we
// detect a malformed value coming out of viper.
// Once the bug is fixed we'll be able to move this into the init above
cmdConfig.ExcludeTables = viper.GetStringSlice("exclude")
if len(cmdConfig.ExcludeTables) == 1 && strings.HasPrefix(cmdConfig.ExcludeTables[0], "[") {
cmdConfig.ExcludeTables, err = cmd.PersistentFlags().GetStringSlice("exclude")
cmdConfig.BlacklistTables = viper.GetStringSlice("blacklist")
if len(cmdConfig.BlacklistTables) == 1 && strings.HasPrefix(cmdConfig.BlacklistTables[0], "[") {
cmdConfig.BlacklistTables, err = cmd.PersistentFlags().GetStringSlice("blacklist")
if err != nil {
return err
}
}
cmdConfig.WhitelistTables = viper.GetStringSlice("whitelist")
if len(cmdConfig.WhitelistTables) == 1 && strings.HasPrefix(cmdConfig.WhitelistTables[0], "[") {
cmdConfig.WhitelistTables, err = cmd.PersistentFlags().GetStringSlice("whitelist")
if err != nil {
return err
}
@ -134,7 +147,7 @@ func preRun(cmd *cobra.Command, args []string) error {
}
}
if viper.IsSet("postgres.dbname") {
if driverName == "postgres" {
cmdConfig.Postgres = PostgresConfig{
User: viper.GetString("postgres.user"),
Pass: viper.GetString("postgres.pass"),
@ -144,10 +157,17 @@ func preRun(cmd *cobra.Command, args []string) error {
SSLMode: viper.GetString("postgres.sslmode"),
}
// Set the default SSLMode value
// BUG: https://github.com/spf13/viper/issues/71
// Despite setting defaults, nested values don't get defaults
// Set them manually
if cmdConfig.Postgres.SSLMode == "" {
viper.Set("postgres.sslmode", "require")
cmdConfig.Postgres.SSLMode = viper.GetString("postgres.sslmode")
cmdConfig.Postgres.SSLMode = "require"
viper.Set("postgres.sslmode", cmdConfig.Postgres.SSLMode)
}
if cmdConfig.Postgres.Port == 0 {
cmdConfig.Postgres.Port = 5432
viper.Set("postgres.port", cmdConfig.Postgres.Port)
}
err = vala.BeginValidation().Validate(
@ -161,8 +181,45 @@ func preRun(cmd *cobra.Command, args []string) error {
if err != nil {
return commandFailure(err.Error())
}
} else if driverName == "postgres" {
return errors.New("postgres driver requires a postgres section in your config file")
}
if driverName == "mysql" {
cmdConfig.MySQL = MySQLConfig{
User: viper.GetString("mysql.user"),
Pass: viper.GetString("mysql.pass"),
Host: viper.GetString("mysql.host"),
Port: viper.GetInt("mysql.port"),
DBName: viper.GetString("mysql.dbname"),
SSLMode: viper.GetString("mysql.sslmode"),
}
// MySQL doesn't have schemas, just databases
cmdConfig.Schema = cmdConfig.MySQL.DBName
// BUG: https://github.com/spf13/viper/issues/71
// Despite setting defaults, nested values don't get defaults
// Set them manually
if cmdConfig.MySQL.SSLMode == "" {
cmdConfig.MySQL.SSLMode = "true"
viper.Set("mysql.sslmode", cmdConfig.MySQL.SSLMode)
}
if cmdConfig.MySQL.Port == 0 {
cmdConfig.MySQL.Port = 3306
viper.Set("mysql.port", cmdConfig.MySQL.Port)
}
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(cmdConfig.MySQL.User, "mysql.user"),
vala.StringNotEmpty(cmdConfig.MySQL.Host, "mysql.host"),
vala.Not(vala.Equals(cmdConfig.MySQL.Port, 0, "mysql.port")),
vala.StringNotEmpty(cmdConfig.MySQL.DBName, "mysql.dbname"),
vala.StringNotEmpty(cmdConfig.MySQL.SSLMode, "mysql.sslmode"),
).Check()
if err != nil {
return commandFailure(err.Error())
}
}
cmdState, err = New(cmdConfig)

View file

@ -1,15 +1,16 @@
package boil
package queries
import (
"database/sql"
"reflect"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/boil"
"github.com/vattle/sqlboiler/strmangle"
)
type loadRelationshipState struct {
exec Executor
exec boil.Executor
loaded map[string]struct{}
toLoad []string
}

View file

@ -1,6 +1,10 @@
package boil
package queries
import "testing"
import (
"testing"
"github.com/vattle/sqlboiler/boil"
)
var loadFunctionCalled bool
var loadFunctionNestedCalled int
@ -32,12 +36,12 @@ type testNestedRSlice struct {
type testNestedLSlice struct {
}
func (testLStruct) LoadTestOne(exec Executor, singular bool, obj interface{}) error {
func (testLStruct) LoadTestOne(exec boil.Executor, singular bool, obj interface{}) error {
loadFunctionCalled = true
return nil
}
func (testNestedLStruct) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error {
func (testNestedLStruct) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error {
switch x := obj.(type) {
case *testNestedStruct:
x.R = &testNestedRStruct{
@ -54,7 +58,7 @@ func (testNestedLStruct) LoadToEagerLoad(exec Executor, singular bool, obj inter
return nil
}
func (testNestedLSlice) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error {
func (testNestedLSlice) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error {
switch x := obj.(type) {
case *testNestedSlice:

View file

@ -1,4 +1,4 @@
package boil
package queries
import (
"fmt"

View file

@ -1,11 +1,11 @@
package boil
package queries
import (
"reflect"
"testing"
"time"
"gopkg.in/nullbio/null.v4"
"gopkg.in/nullbio/null.v5"
)
type testObj struct {

View file

@ -1,12 +1,12 @@
package qm
import "github.com/vattle/sqlboiler/boil"
import "github.com/vattle/sqlboiler/queries"
// QueryMod to modify the query object
type QueryMod func(q *boil.Query)
type QueryMod func(q *queries.Query)
// Apply the query mods to the Query object
func Apply(q *boil.Query, mods ...QueryMod) {
func Apply(q *queries.Query, mods ...QueryMod) {
for _, mod := range mods {
mod(q)
}
@ -14,8 +14,8 @@ func Apply(q *boil.Query, mods ...QueryMod) {
// SQL allows you to execute a plain SQL statement
func SQL(sql string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.SetSQL(q, sql, args...)
return func(q *queries.Query) {
queries.SetSQL(q, sql, args...)
}
}
@ -25,29 +25,29 @@ func SQL(sql string, args ...interface{}) QueryMod {
// Relationship name plurality is important, if your relationship is
// singular, you need to specify the singular form and vice versa.
func Load(relationships ...string) QueryMod {
return func(q *boil.Query) {
boil.SetLoad(q, relationships...)
return func(q *queries.Query) {
queries.SetLoad(q, relationships...)
}
}
// InnerJoin on another table
func InnerJoin(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.AppendInnerJoin(q, clause, args...)
return func(q *queries.Query) {
queries.AppendInnerJoin(q, clause, args...)
}
}
// Select specific columns opposed to all columns
func Select(columns ...string) QueryMod {
return func(q *boil.Query) {
boil.AppendSelect(q, columns...)
return func(q *queries.Query) {
queries.AppendSelect(q, columns...)
}
}
// Where allows you to specify a where clause for your statement
func Where(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.AppendWhere(q, clause, args...)
return func(q *queries.Query) {
queries.AppendWhere(q, clause, args...)
}
}
@ -55,24 +55,24 @@ func Where(clause string, args ...interface{}) QueryMod {
// And is a duplicate of the Where function, but allows for more natural looking
// query mod chains, for example: (Where("a=?"), And("b=?"), Or("c=?")))
func And(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.AppendWhere(q, clause, args...)
return func(q *queries.Query) {
queries.AppendWhere(q, clause, args...)
}
}
// Or allows you to specify a where clause separated by an OR for your statement
func Or(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.AppendWhere(q, clause, args...)
boil.SetLastWhereAsOr(q)
return func(q *queries.Query) {
queries.AppendWhere(q, clause, args...)
queries.SetLastWhereAsOr(q)
}
}
// WhereIn allows you to specify a "x IN (set)" clause for your where statement
// Example clauses: "column in ?", "(column1,column2) in ?"
func WhereIn(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.AppendIn(q, clause, args...)
return func(q *queries.Query) {
queries.AppendIn(q, clause, args...)
}
}
@ -81,65 +81,65 @@ func WhereIn(clause string, args ...interface{}) QueryMod {
// allows for more natural looking query mod chains, for example:
// (WhereIn("column1 in ?"), AndIn("column2 in ?"), OrIn("column3 in ?"))
func AndIn(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.AppendIn(q, clause, args...)
return func(q *queries.Query) {
queries.AppendIn(q, clause, args...)
}
}
// OrIn allows you to specify an IN clause separated by
// an OR for your where statement
func OrIn(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.AppendIn(q, clause, args...)
boil.SetLastInAsOr(q)
return func(q *queries.Query) {
queries.AppendIn(q, clause, args...)
queries.SetLastInAsOr(q)
}
}
// GroupBy allows you to specify a group by clause for your statement
func GroupBy(clause string) QueryMod {
return func(q *boil.Query) {
boil.AppendGroupBy(q, clause)
return func(q *queries.Query) {
queries.AppendGroupBy(q, clause)
}
}
// OrderBy allows you to specify a order by clause for your statement
func OrderBy(clause string) QueryMod {
return func(q *boil.Query) {
boil.AppendOrderBy(q, clause)
return func(q *queries.Query) {
queries.AppendOrderBy(q, clause)
}
}
// Having allows you to specify a having clause for your statement
func Having(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.AppendHaving(q, clause, args...)
return func(q *queries.Query) {
queries.AppendHaving(q, clause, args...)
}
}
// From allows to specify the table for your statement
func From(from string) QueryMod {
return func(q *boil.Query) {
boil.AppendFrom(q, from)
return func(q *queries.Query) {
queries.AppendFrom(q, from)
}
}
// Limit the number of returned rows
func Limit(limit int) QueryMod {
return func(q *boil.Query) {
boil.SetLimit(q, limit)
return func(q *queries.Query) {
queries.SetLimit(q, limit)
}
}
// Offset into the results
func Offset(offset int) QueryMod {
return func(q *boil.Query) {
boil.SetOffset(q, offset)
return func(q *queries.Query) {
queries.SetOffset(q, offset)
}
}
// For inserts a concurrency locking clause at the end of your statement
func For(clause string) QueryMod {
return func(q *boil.Query) {
boil.SetFor(q, clause)
return func(q *queries.Query) {
queries.SetFor(q, clause)
}
}

View file

@ -1,8 +1,10 @@
package boil
package queries
import (
"database/sql"
"fmt"
"github.com/vattle/sqlboiler/boil"
)
// joinKind is the type of join
@ -18,8 +20,9 @@ const (
// Query holds the state for the built up query
type Query struct {
executor Executor
plainSQL plainSQL
executor boil.Executor
dialect *Dialect
rawSQL rawSQL
load []string
delete bool
update map[string]interface{}
@ -37,6 +40,20 @@ type Query struct {
forlock string
}
// Dialect holds values that direct the query builder
// how to build compatible queries for each database.
// Each database driver needs to implement functions
// that provide these values.
type Dialect struct {
// The left quote character for SQL identifiers
LQ byte
// The right quote character for SQL identifiers
RQ byte
// Bool flag indicating whether indexed
// placeholders ($1) are used, or ? placeholders.
IndexPlaceholders bool
}
type where struct {
clause string
orSeparator bool
@ -54,7 +71,7 @@ type having struct {
args []interface{}
}
type plainSQL struct {
type rawSQL struct {
sql string
args []interface{}
}
@ -65,65 +82,92 @@ type join struct {
args []interface{}
}
// SQL makes a plainSQL query, usually for use with bind
func SQL(exec Executor, query string, args ...interface{}) *Query {
// Raw makes a raw query, usually for use with bind
func Raw(exec boil.Executor, query string, args ...interface{}) *Query {
return &Query{
executor: exec,
plainSQL: plainSQL{
rawSQL: rawSQL{
sql: query,
args: args,
},
}
}
// SQLG makes a plainSQL query using the global Executor, usually for use with bind
func SQLG(query string, args ...interface{}) *Query {
return SQL(GetDB(), query, args...)
// RawG makes a raw query using the global boil.Executor, usually for use with bind
func RawG(query string, args ...interface{}) *Query {
return Raw(boil.GetDB(), query, args...)
}
// ExecQuery executes a query that does not need a row returned
func ExecQuery(q *Query) (sql.Result, error) {
// Exec executes a query that does not need a row returned
func (q *Query) Exec() (sql.Result, error) {
qs, args := buildQuery(q)
if DebugMode {
fmt.Fprintln(DebugWriter, qs)
fmt.Fprintln(DebugWriter, args)
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, qs)
fmt.Fprintln(boil.DebugWriter, args)
}
return q.executor.Exec(qs, args...)
}
// ExecQueryOne executes the query for the One finisher and returns a row
func ExecQueryOne(q *Query) *sql.Row {
// QueryRow executes the query for the One finisher and returns a row
func (q *Query) QueryRow() *sql.Row {
qs, args := buildQuery(q)
if DebugMode {
fmt.Fprintln(DebugWriter, qs)
fmt.Fprintln(DebugWriter, args)
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, qs)
fmt.Fprintln(boil.DebugWriter, args)
}
return q.executor.QueryRow(qs, args...)
}
// ExecQueryAll executes the query for the All finisher and returns multiple rows
func ExecQueryAll(q *Query) (*sql.Rows, error) {
// Query executes the query for the All finisher and returns multiple rows
func (q *Query) Query() (*sql.Rows, error) {
qs, args := buildQuery(q)
if DebugMode {
fmt.Fprintln(DebugWriter, qs)
fmt.Fprintln(DebugWriter, args)
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, qs)
fmt.Fprintln(boil.DebugWriter, args)
}
return q.executor.Query(qs, args...)
}
// ExecP executes a query that does not need a row returned
// It will panic on error
func (q *Query) ExecP() sql.Result {
res, err := q.Exec()
if err != nil {
panic(boil.WrapErr(err))
}
return res
}
// QueryP executes the query for the All finisher and returns multiple rows
// It will panic on error
func (q *Query) QueryP() *sql.Rows {
rows, err := q.Query()
if err != nil {
panic(boil.WrapErr(err))
}
return rows
}
// SetExecutor on the query.
func SetExecutor(q *Query, exec Executor) {
func SetExecutor(q *Query, exec boil.Executor) {
q.executor = exec
}
// GetExecutor on the query.
func GetExecutor(q *Query) Executor {
func GetExecutor(q *Query) boil.Executor {
return q.executor
}
// SetDialect on the query.
func SetDialect(q *Query, dialect *Dialect) {
q.dialect = dialect
}
// SetSQL on the query.
func SetSQL(q *Query, sql string, args ...interface{}) {
q.plainSQL = plainSQL{sql: sql, args: args}
q.rawSQL = rawSQL{sql: sql, args: args}
}
// SetLoad on the query.
@ -131,6 +175,11 @@ func SetLoad(q *Query, relationships ...string) {
q.load = append([]string(nil), relationships...)
}
// SetSelect on the query.
func SetSelect(q *Query, sel []string) {
q.selectCols = sel
}
// SetCount on the query.
func SetCount(q *Query) {
q.count = true

View file

@ -1,4 +1,4 @@
package boil
package queries
import (
"bytes"
@ -20,8 +20,8 @@ func buildQuery(q *Query) (string, []interface{}) {
var args []interface{}
switch {
case len(q.plainSQL.sql) != 0:
return q.plainSQL.sql, q.plainSQL.args
case len(q.rawSQL.sql) != 0:
return q.rawSQL.sql, q.rawSQL.args
case q.delete:
buf, args = buildDeleteQuery(q)
case len(q.update) > 0:
@ -34,8 +34,8 @@ func buildQuery(q *Query) (string, []interface{}) {
// Cache the generated query for query object re-use
bufStr := buf.String()
q.plainSQL.sql = bufStr
q.plainSQL.args = args
q.rawSQL.sql = bufStr
q.rawSQL.args = args
return bufStr, args
}
@ -57,8 +57,8 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
// Don't identQuoteSlice - writeAsStatements does this
buf.WriteString(strings.Join(selectColsWithAs, ", "))
} else if hasSelectCols {
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), ", "))
} else if hasJoins {
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", "))
} else if hasJoins && !q.count {
selectColsWithStars := writeStars(q)
buf.WriteString(strings.Join(selectColsWithStars, ", "))
} else {
@ -70,7 +70,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf.WriteByte(')')
}
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
if len(q.joins) > 0 {
argsLen := len(args)
@ -82,7 +82,12 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause)
args = append(args, j.args...)
}
resp, _ := convertQuestionMarks(joinBuf.String(), argsLen+1)
var resp string
if q.dialect.IndexPlaceholders {
resp, _ = convertQuestionMarks(joinBuf.String(), argsLen+1)
} else {
resp = joinBuf.String()
}
fmt.Fprintf(buf, resp)
strmangle.PutBuffer(joinBuf)
}
@ -110,7 +115,7 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := strmangle.GetBuffer()
buf.WriteString("DELETE FROM ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
where, whereArgs := whereClause(q, 1)
if len(whereArgs) != 0 {
@ -135,7 +140,7 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := strmangle.GetBuffer()
buf.WriteString("UPDATE ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
cols := make(sort.StringSlice, len(q.update))
var args []interface{}
@ -150,13 +155,13 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
for i := 0; i < len(cols); i++ {
args = append(args, q.update[cols[i]])
cols[i] = strmangle.IdentQuote(cols[i])
cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, cols[i])
}
buf.WriteString(fmt.Sprintf(
" SET (%s) = (%s)",
strings.Join(cols, ", "),
strmangle.Placeholders(len(cols), 1, 1)),
strmangle.Placeholders(q.dialect.IndexPlaceholders, len(cols), 1, 1)),
)
where, whereArgs := whereClause(q, len(args)+1)
@ -178,11 +183,40 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
return buf, args
}
// BuildUpsertQuery builds a SQL statement string using the upsertData provided.
func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string {
conflict = strmangle.IdentQuoteSlice(conflict)
whitelist = strmangle.IdentQuoteSlice(whitelist)
ret = strmangle.IdentQuoteSlice(ret)
// BuildUpsertQueryMySQL builds a SQL statement string using the upsertData provided.
func BuildUpsertQueryMySQL(dia Dialect, tableName string, update, whitelist []string) string {
whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
buf := strmangle.GetBuffer()
defer strmangle.PutBuffer(buf)
fmt.Fprintf(
buf,
"INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE ",
tableName,
strings.Join(whitelist, ", "),
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
)
for i, v := range update {
if i != 0 {
buf.WriteByte(',')
}
quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
buf.WriteString(quoted)
buf.WriteString(" = VALUES(")
buf.WriteString(quoted)
buf.WriteByte(')')
}
return buf.String()
}
// BuildUpsertQueryPostgres builds a SQL statement string using the upsertData provided.
func BuildUpsertQueryPostgres(dia Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string {
conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict)
whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret)
buf := strmangle.GetBuffer()
defer strmangle.PutBuffer(buf)
@ -192,7 +226,7 @@ func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conf
"INSERT INTO %s (%s) VALUES (%s) ON CONFLICT ",
tableName,
strings.Join(whitelist, ", "),
strmangle.Placeholders(len(whitelist), 1, 1),
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
)
if !updateOnConflict || len(update) == 0 {
@ -206,7 +240,7 @@ func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conf
if i != 0 {
buf.WriteByte(',')
}
quoted := strmangle.IdentQuote(v)
quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
buf.WriteString(quoted)
buf.WriteString(" = EXCLUDED.")
buf.WriteString(quoted)
@ -237,7 +271,12 @@ func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) {
fmt.Fprintf(havingBuf, j.clause)
*args = append(*args, j.args...)
}
resp, _ := convertQuestionMarks(havingBuf.String(), argsLen+1)
var resp string
if q.dialect.IndexPlaceholders {
resp, _ = convertQuestionMarks(havingBuf.String(), argsLen+1)
} else {
resp = havingBuf.String()
}
fmt.Fprintf(buf, resp)
strmangle.PutBuffer(havingBuf)
}
@ -264,7 +303,7 @@ func writeStars(q *Query) []string {
for i, f := range q.from {
toks := strings.Split(f, " ")
if len(toks) == 1 {
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(toks[0]))
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, toks[0]))
continue
}
@ -276,7 +315,7 @@ func writeStars(q *Query) []string {
if len(alias) != 0 {
name = alias
}
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(name))
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, name))
}
return cols
@ -292,7 +331,7 @@ func writeAsStatements(q *Query) []string {
toks := strings.Split(col, ".")
if len(toks) == 1 {
cols[i] = strmangle.IdentQuote(col)
cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col)
continue
}
@ -301,7 +340,7 @@ func writeAsStatements(q *Query) []string {
asParts[j] = strings.Trim(tok, `"`)
}
cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(col), strings.Join(asParts, "."))
cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col), strings.Join(asParts, "."))
}
return cols
@ -335,7 +374,13 @@ func whereClause(q *Query, startAt int) (string, []interface{}) {
args = append(args, where.args...)
}
resp, _ := convertQuestionMarks(buf.String(), startAt)
var resp string
if q.dialect.IndexPlaceholders {
resp, _ = convertQuestionMarks(buf.String(), startAt)
} else {
resp = buf.String()
}
return resp, args
}
@ -374,7 +419,7 @@ func inClause(q *Query, startAt int) (string, []interface{}) {
// column name side, however if this case is being hit then the regexp
// probably needs adjustment, or the user is passing in invalid clauses.
if matches == nil {
clause, count := convertInQuestionMarks(in.clause, startAt, 1, ln)
clause, count := convertInQuestionMarks(q.dialect.IndexPlaceholders, in.clause, startAt, 1, ln)
buf.WriteString(clause)
startAt = startAt + count
} else {
@ -384,11 +429,24 @@ func inClause(q *Query, startAt int) (string, []interface{}) {
// of the clause to determine how many columns they are using.
// This number determines the groupAt for the convert function.
cols := strings.Split(leftSide, ",")
cols = strmangle.IdentQuoteSlice(cols)
cols = strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, cols)
groupAt := len(cols)
leftClause, leftCount := convertQuestionMarks(strings.Join(cols, ","), startAt)
rightClause, rightCount := convertInQuestionMarks(rightSide, startAt+leftCount, groupAt, ln-leftCount)
var leftClause string
var leftCount int
if q.dialect.IndexPlaceholders {
leftClause, leftCount = convertQuestionMarks(strings.Join(cols, ","), startAt)
} else {
// Count the number of cols that are question marks, so we know
// how much to offset convertInQuestionMarks by
for _, v := range cols {
if v == "?" {
leftCount++
}
}
leftClause = strings.Join(cols, ",")
}
rightClause, rightCount := convertInQuestionMarks(q.dialect.IndexPlaceholders, rightSide, startAt+leftCount, groupAt, ln-leftCount)
buf.WriteString(leftClause)
buf.WriteString(" IN ")
buf.WriteString(rightClause)
@ -406,7 +464,7 @@ func inClause(q *Query, startAt int) (string, []interface{}) {
// It uses groupAt to determine how many placeholders should be in each group,
// for example, groupAt 2 would result in: (($1,$2),($3,$4))
// and groupAt 1 would result in ($1,$2,$3,$4)
func convertInQuestionMarks(clause string, startAt, groupAt, total int) (string, int) {
func convertInQuestionMarks(indexPlaceholders bool, clause string, startAt, groupAt, total int) (string, int) {
if startAt == 0 || len(clause) == 0 {
panic("Not a valid start number.")
}
@ -428,7 +486,7 @@ func convertInQuestionMarks(clause string, startAt, groupAt, total int) (string,
paramBuf.WriteString(clause[:foundAt])
paramBuf.WriteByte('(')
paramBuf.WriteString(strmangle.Placeholders(total, startAt, groupAt))
paramBuf.WriteString(strmangle.Placeholders(indexPlaceholders, total, startAt, groupAt))
paramBuf.WriteByte(')')
paramBuf.WriteString(clause[foundAt+1:])

View file

@ -1,4 +1,4 @@
package boil
package queries
import (
"bytes"
@ -97,6 +97,7 @@ func TestBuildQuery(t *testing.T) {
for i, test := range tests {
filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i))
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
out, args := buildQuery(test.q)
if *writeGoldenFiles {
@ -149,6 +150,7 @@ func TestWriteStars(t *testing.T) {
}
for i, test := range tests {
test.In.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
selects := writeStars(&test.In)
if !reflect.DeepEqual(selects, test.Out) {
t.Errorf("writeStar test fail %d\nwant: %v\ngot: %v", i, test.Out, selects)
@ -275,6 +277,7 @@ func TestWhereClause(t *testing.T) {
}
for i, test := range tests {
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
result, _ := whereClause(&test.q, 1)
if result != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result)
@ -407,6 +410,7 @@ func TestInClause(t *testing.T) {
}
for i, test := range tests {
test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
result, args := inClause(&test.q, 1)
if result != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result)
@ -489,7 +493,7 @@ func TestConvertInQuestionMarks(t *testing.T) {
}
for i, test := range tests {
res, count := convertInQuestionMarks(test.clause, test.start, test.group, test.total)
res, count := convertInQuestionMarks(true, test.clause, test.start, test.group, test.total)
if res != test.expect {
t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res)
}
@ -497,6 +501,14 @@ func TestConvertInQuestionMarks(t *testing.T) {
t.Errorf("%d) Expected %d, got %d", i, test.total, count)
}
}
res, count := convertInQuestionMarks(false, "?", 1, 3, 9)
if res != "((?,?,?),(?,?,?),(?,?,?))" {
t.Errorf("Mismatch between expected and result: %s", res)
}
if count != 9 {
t.Errorf("Expected 9 results, got %d", count)
}
}
func TestWriteAsStatements(t *testing.T) {
@ -512,6 +524,7 @@ func TestWriteAsStatements(t *testing.T) {
`a.clown.run`,
`COUNT(a)`,
},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
expect := []string{

View file

@ -1,4 +1,4 @@
package boil
package queries
import (
"database/sql"
@ -36,12 +36,12 @@ func TestSetSQL(t *testing.T) {
q := &Query{}
SetSQL(q, "select * from thing", 5, 3)
if len(q.plainSQL.args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.plainSQL.args))
if len(q.rawSQL.args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.rawSQL.args))
}
if q.plainSQL.sql != "select * from thing" {
t.Errorf("Was not expected string, got %s", q.plainSQL.sql)
if q.rawSQL.sql != "select * from thing" {
t.Errorf("Was not expected string, got %s", q.rawSQL.sql)
}
}
@ -290,6 +290,17 @@ func TestFrom(t *testing.T) {
}
}
func TestSetSelect(t *testing.T) {
t.Parallel()
q := &Query{selectCols: []string{"hello"}}
SetSelect(q, nil)
if q.selectCols != nil {
t.Errorf("want nil")
}
}
func TestSetCount(t *testing.T) {
t.Parallel()
@ -362,24 +373,24 @@ func TestAppendSelect(t *testing.T) {
func TestSQL(t *testing.T) {
t.Parallel()
q := SQL(&sql.DB{}, "thing", 5)
if q.plainSQL.sql != "thing" {
t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql)
q := Raw(&sql.DB{}, "thing", 5)
if q.rawSQL.sql != "thing" {
t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql)
}
if q.plainSQL.args[0].(int) != 5 {
t.Errorf("Expected 5, got %v", q.plainSQL.args[0])
if q.rawSQL.args[0].(int) != 5 {
t.Errorf("Expected 5, got %v", q.rawSQL.args[0])
}
}
func TestSQLG(t *testing.T) {
t.Parallel()
q := SQLG("thing", 5)
if q.plainSQL.sql != "thing" {
t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql)
q := RawG("thing", 5)
if q.rawSQL.sql != "thing" {
t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql)
}
if q.plainSQL.args[0].(int) != 5 {
t.Errorf("Expected 5, got %v", q.plainSQL.args[0])
if q.rawSQL.args[0].(int) != 5 {
t.Errorf("Expected 5, got %v", q.rawSQL.args[0])
}
}

View file

@ -1,4 +1,4 @@
package boil
package queries
import (
"database/sql"
@ -8,6 +8,7 @@ import (
"sync"
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/boil"
"github.com/vattle/sqlboiler/strmangle"
)
@ -40,7 +41,7 @@ const (
// It panics on error. See boil.Bind() documentation.
func (q *Query) BindP(obj interface{}) {
if err := q.Bind(obj); err != nil {
panic(WrapErr(err))
panic(boil.WrapErr(err))
}
}
@ -100,7 +101,7 @@ func (q *Query) Bind(obj interface{}) error {
return err
}
rows, err := ExecQueryAll(q)
rows, err := q.Query()
if err != nil {
return errors.Wrap(err, "bind failed to execute query")
}
@ -322,8 +323,10 @@ func ptrFromMapping(val reflect.Value, mapping uint64, addressOf bool) reflect.V
v := (mapping >> uint(i*8)) & sentinel
if v == sentinel {
if val.Kind() != reflect.Ptr {
if addressOf && val.Kind() != reflect.Ptr {
return val.Addr()
} else if !addressOf && val.Kind() == reflect.Ptr {
return reflect.Indirect(val)
}
return val
}
@ -404,74 +407,3 @@ func makeCacheKey(typ string, cols []string) string {
return mapKey
}
// GetStructValues returns the values (as interface) of the matching columns in obj
func GetStructValues(obj interface{}, columns ...string) []interface{} {
ret := make([]interface{}, len(columns))
val := reflect.Indirect(reflect.ValueOf(obj))
for i, c := range columns {
fieldName := strmangle.TitleCase(c)
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
}
ret[i] = field.Interface()
}
return ret
}
// GetSliceValues returns the values (as interface) of the matching columns in obj.
func GetSliceValues(slice []interface{}, columns ...string) []interface{} {
ret := make([]interface{}, len(slice)*len(columns))
for i, obj := range slice {
val := reflect.Indirect(reflect.ValueOf(obj))
for j, c := range columns {
fieldName := strmangle.TitleCase(c)
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj))
}
ret[i*len(columns)+j] = field.Interface()
}
}
return ret
}
// GetStructPointers returns a slice of pointers to the matching columns in obj
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
val := reflect.ValueOf(obj).Elem()
var ln int
var getField func(reflect.Value, int) reflect.Value
if len(columns) == 0 {
ln = val.NumField()
getField = func(v reflect.Value, i int) reflect.Value {
return v.Field(i)
}
} else {
ln = len(columns)
getField = func(v reflect.Value, i int) reflect.Value {
return v.FieldByName(strmangle.TitleCase(columns[i]))
}
}
ret := make([]interface{}, ln)
for i := 0; i < ln; i++ {
field := getField(val, i)
if !field.IsValid() {
// Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually
// produce an Invalid value, so we make a hopefully safe assumption here.
panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i])))
}
ret[i] = field.Addr().Interface()
}
return ret
}

View file

@ -1,4 +1,4 @@
package boil
package queries
import (
"database/sql/driver"
@ -6,10 +6,8 @@ import (
"strconv"
"strings"
"testing"
"time"
"gopkg.in/DATA-DOG/go-sqlmock.v1"
"gopkg.in/nullbio/null.v4"
)
func bin64(i uint64) string {
@ -44,7 +42,8 @@ func TestBindStruct(t *testing.T) {
}{}
query := &Query{
from: []string{"fun"},
from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
@ -83,7 +82,8 @@ func TestBindSlice(t *testing.T) {
}{}
query := &Query{
from: []string{"fun"},
from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
@ -133,7 +133,8 @@ func TestBindPtrSlice(t *testing.T) {
}{}
query := &Query{
from: []string{"fun"},
from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
@ -265,6 +266,76 @@ func TestPtrFromMapping(t *testing.T) {
}
}
func TestValuesFromMapping(t *testing.T) {
t.Parallel()
type NestedPtrs struct {
Int int
IntP *int
NestedPtrsP *NestedPtrs
}
val := &NestedPtrs{
Int: 5,
IntP: new(int),
NestedPtrsP: &NestedPtrs{
Int: 6,
IntP: new(int),
},
}
mapping := []uint64{testMakeMapping(0), testMakeMapping(1), testMakeMapping(2, 0), testMakeMapping(2, 1)}
v := ValuesFromMapping(reflect.Indirect(reflect.ValueOf(val)), mapping)
if got := v[0].(int); got != 5 {
t.Error("flat int was wrong:", got)
}
if got := v[1].(int); got != 0 {
t.Error("flat pointer was wrong:", got)
}
if got := v[2].(int); got != 6 {
t.Error("nested int was wrong:", got)
}
if got := v[3].(int); got != 0 {
t.Error("nested pointer was wrong:", got)
}
}
func TestPtrsFromMapping(t *testing.T) {
t.Parallel()
type NestedPtrs struct {
Int int
IntP *int
NestedPtrsP *NestedPtrs
}
val := &NestedPtrs{
Int: 5,
IntP: new(int),
NestedPtrsP: &NestedPtrs{
Int: 6,
IntP: new(int),
},
}
mapping := []uint64{testMakeMapping(0), testMakeMapping(1), testMakeMapping(2, 0), testMakeMapping(2, 1)}
v := PtrsFromMapping(reflect.Indirect(reflect.ValueOf(val)), mapping)
if got := *v[0].(*int); got != 5 {
t.Error("flat int was wrong:", got)
}
if got := *v[1].(*int); got != 0 {
t.Error("flat pointer was wrong:", got)
}
if got := *v[2].(*int); got != 6 {
t.Error("nested int was wrong:", got)
}
if got := *v[3].(*int); got != 0 {
t.Error("nested pointer was wrong:", got)
}
}
func TestGetBoilTag(t *testing.T) {
t.Parallel()
@ -369,7 +440,8 @@ func TestBindSingular(t *testing.T) {
}{}
query := &Query{
from: []string{"fun"},
from: []string{"fun"},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
@ -412,8 +484,9 @@ func TestBind_InnerJoin(t *testing.T) {
}{}
query := &Query{
from: []string{"fun"},
joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}},
from: []string{"fun"},
joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}},
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
}
db, mock, err := sqlmock.New()
@ -454,249 +527,59 @@ func TestBind_InnerJoin(t *testing.T) {
}
}
// func TestBind_InnerJoinSelect(t *testing.T) {
// t.Parallel()
//
// testResults := []*struct {
// Happy struct {
// ID int
// } `boil:"h,bind"`
// Fun struct {
// ID int
// } `boil:",bind"`
// }{}
//
// query := &Query{
// selectCols: []string{"fun.id", "h.id"},
// from: []string{"fun"},
// joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}},
// }
//
// db, mock, err := sqlmock.New()
// if err != nil {
// t.Error(err)
// }
//
// ret := sqlmock.NewRows([]string{"fun.id", "h.id"})
// ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11)))
// ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13)))
// mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret)
//
// SetExecutor(query, db)
// err = query.Bind(&testResults)
// if err != nil {
// t.Error(err)
// }
//
// if len(testResults) != 2 {
// t.Fatal("wrong number of results:", len(testResults))
// }
// if id := testResults[0].Happy.ID; id != 11 {
// t.Error("wrong ID:", id)
// }
// if id := testResults[0].Fun.ID; id != 10 {
// t.Error("wrong ID:", id)
// }
//
// if id := testResults[1].Happy.ID; id != 13 {
// t.Error("wrong ID:", id)
// }
// if id := testResults[1].Fun.ID; id != 12 {
// t.Error("wrong ID:", id)
// }
//
// if err := mock.ExpectationsWereMet(); err != nil {
// t.Error(err)
// }
// }
// func TestBindPtrs_Easy(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// ID int `boil:"identifier"`
// Date time.Time
// }{}
//
// cols := []string{"identifier", "date"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*time.Time) != &testStruct.Date {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_Recursive(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// Happy struct {
// ID int `boil:"identifier"`
// }
// Fun struct {
// ID int
// } `boil:",bind"`
// }{}
//
// cols := []string{"id", "fun.id"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_RecursiveTags(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// Happy struct {
// ID int `boil:"identifier"`
// } `boil:",bind"`
// Fun struct {
// ID int `boil:"identification"`
// } `boil:",bind"`
// }{}
//
// cols := []string{"happy.identifier", "fun.identification"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Happy.ID {
// t.Error("id is the wrong pointer")
// }
// if ptrs[1].(*int) != &testStruct.Fun.ID {
// t.Error("id is the wrong pointer")
// }
// }
//
// func TestBindPtrs_Ignore(t *testing.T) {
// t.Parallel()
//
// testStruct := struct {
// ID int `boil:"-"`
// Happy struct {
// ID int
// } `boil:",bind"`
// }{}
//
// cols := []string{"id"}
// ptrs, err := bindPtrs(&testStruct, nil, cols...)
// if err != nil {
// t.Error(err)
// }
//
// if ptrs[0].(*int) != &testStruct.Happy.ID {
// t.Error("id is the wrong pointer")
// }
// }
func TestGetStructValues(t *testing.T) {
func TestBind_InnerJoinSelect(t *testing.T) {
t.Parallel()
timeThing := time.Now()
o := struct {
TitleThing string
Name string
ID int
Stuff int
Things int
Time time.Time
NullBool null.Bool
}{
TitleThing: "patrick",
Stuff: 10,
Things: 0,
Time: timeThing,
NullBool: null.NewBool(true, false),
testResults := []*struct {
Happy struct {
ID int
} `boil:"h,bind"`
Fun struct {
ID int
} `boil:",bind"`
}{}
query := &Query{
dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true},
selectCols: []string{"fun.id", "h.id"},
from: []string{"fun"},
joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}},
}
vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool")
if vals[0].(string) != "patrick" {
t.Errorf("Want test, got %s", vals[0])
db, mock, err := sqlmock.New()
if err != nil {
t.Error(err)
}
if vals[1].(string) != "" {
t.Errorf("Want empty string, got %s", vals[1])
ret := sqlmock.NewRows([]string{"fun.id", "h.id"})
ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11)))
ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13)))
mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret)
SetExecutor(query, db)
err = query.Bind(&testResults)
if err != nil {
t.Error(err)
}
if vals[2].(int) != 0 {
t.Errorf("Want 0, got %d", vals[2])
if len(testResults) != 2 {
t.Fatal("wrong number of results:", len(testResults))
}
if vals[3].(int) != 10 {
t.Errorf("Want 10, got %d", vals[3])
if id := testResults[0].Happy.ID; id != 11 {
t.Error("wrong ID:", id)
}
if vals[4].(int) != 0 {
t.Errorf("Want 0, got %d", vals[4])
if id := testResults[0].Fun.ID; id != 10 {
t.Error("wrong ID:", id)
}
if !vals[5].(time.Time).Equal(timeThing) {
t.Errorf("Want %s, got %s", o.Time, vals[5])
if id := testResults[1].Happy.ID; id != 13 {
t.Error("wrong ID:", id)
}
if !vals[6].(null.Bool).IsZero() {
t.Errorf("Want %v, got %v", o.NullBool, vals[6])
}
}
func TestGetSliceValues(t *testing.T) {
t.Parallel()
o := []struct {
ID int
Name string
}{
{5, "a"},
{6, "b"},
}
in := make([]interface{}, len(o))
in[0] = o[0]
in[1] = o[1]
vals := GetSliceValues(in, "id", "name")
if got := vals[0].(int); got != 5 {
t.Error(got)
}
if got := vals[1].(string); got != "a" {
t.Error(got)
}
if got := vals[2].(int); got != 6 {
t.Error(got)
}
if got := vals[3].(string); got != "b" {
t.Error(got)
}
}
func TestGetStructPointers(t *testing.T) {
t.Parallel()
o := struct {
Title string
ID *int
}{
Title: "patrick",
}
ptrs := GetStructPointers(&o, "title", "id")
*ptrs[0].(*string) = "test"
if o.Title != "test" {
t.Errorf("Expected test, got %s", o.Title)
}
x := 5
*ptrs[1].(**int) = &x
if *o.ID != 5 {
t.Errorf("Expected 5, got %d", *o.ID)
if id := testResults[1].Fun.ID; id != 12 {
t.Error("wrong ID:", id)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Error(err)
}
}

121
randomize/random.go Normal file
View file

@ -0,0 +1,121 @@
package randomize
import (
"crypto/md5"
"fmt"
"math/rand"
)
const alphabetAll = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const alphabetLowerAlpha = "abcdefghijklmnopqrstuvwxyz"
func randStr(s *Seed, ln int) string {
str := make([]byte, ln)
for i := 0; i < ln; i++ {
str[i] = byte(alphabetAll[s.nextInt()%len(alphabetAll)])
}
return string(str)
}
func randByteSlice(s *Seed, ln int) []byte {
str := make([]byte, ln)
for i := 0; i < ln; i++ {
str[i] = byte(s.nextInt() % 256)
}
return str
}
func randPoint() string {
a := rand.Intn(100)
b := a + 1
return fmt.Sprintf("(%d,%d)", a, b)
}
func randBox() string {
a := rand.Intn(100)
b := a + 1
c := a + 2
d := a + 3
return fmt.Sprintf("(%d,%d),(%d,%d)", a, b, c, d)
}
func randCircle() string {
a, b, c := rand.Intn(100), rand.Intn(100), rand.Intn(100)
return fmt.Sprintf("((%d,%d),%d)", a, b, c)
}
func randNetAddr() string {
return fmt.Sprintf(
"%d.%d.%d.%d",
rand.Intn(254)+1,
rand.Intn(254)+1,
rand.Intn(254)+1,
rand.Intn(254)+1,
)
}
func randMacAddr() string {
buf := make([]byte, 6)
_, err := rand.Read(buf)
if err != nil {
panic(err)
}
// Set the local bit
buf[0] |= 2
return fmt.Sprintf(
"%02x:%02x:%02x:%02x:%02x:%02x",
buf[0], buf[1], buf[2], buf[3], buf[4], buf[5],
)
}
func randLsn() string {
a := rand.Int63n(9000000)
b := rand.Int63n(9000000)
return fmt.Sprintf("%d/%d", a, b)
}
func randTxID() string {
// Order of integers is relevant
a := rand.Intn(200) + 100
b := a + 100
c := a
d := a + 50
return fmt.Sprintf("%d:%d:%d,%d", a, b, c, d)
}
func randMoney(s *Seed) string {
return fmt.Sprintf("%d.00", s.nextInt())
}
// StableDBName takes a database name in, and generates
// a random string using the database name as the rand Seed.
// getDBNameHash is used to generate unique test database names.
func StableDBName(input string) string {
return randStrFromSource(stableSource(input), 40)
}
// stableSource takes an input value, and produces a random
// seed from it that will produce very few collisions in
// a 40 character random string made from a different alphabet.
func stableSource(input string) *rand.Rand {
sum := md5.Sum([]byte(input))
var seed int64
for i, byt := range sum {
seed ^= int64(byt) << uint((i*4)%64)
}
return rand.New(rand.NewSource(seed))
}
func randStrFromSource(r *rand.Rand, length int) string {
ln := len(alphabetLowerAlpha)
output := make([]rune, length)
for i := 0; i < length; i++ {
output[i] = rune(alphabetLowerAlpha[r.Intn(ln)])
}
return string(output)
}

19
randomize/random_test.go Normal file
View file

@ -0,0 +1,19 @@
package randomize
import "testing"
func TestStableDBName(t *testing.T) {
t.Parallel()
db := "awesomedb"
one, two := StableDBName(db), StableDBName(db)
if len(one) != 40 {
t.Error("want 40 characters:", len(one), one)
}
if one != two {
t.Error("it should always produce the same value")
}
}

View file

@ -2,41 +2,58 @@
package randomize
import (
"database/sql"
"fmt"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
"gopkg.in/nullbio/null.v4"
"gopkg.in/nullbio/null.v5"
"github.com/pkg/errors"
"github.com/satori/go.uuid"
"github.com/vattle/sqlboiler/strmangle"
"github.com/vattle/sqlboiler/types"
)
var (
typeNullFloat32 = reflect.TypeOf(null.Float32{})
typeNullFloat64 = reflect.TypeOf(null.Float64{})
typeNullInt = reflect.TypeOf(null.Int{})
typeNullInt8 = reflect.TypeOf(null.Int8{})
typeNullInt16 = reflect.TypeOf(null.Int16{})
typeNullInt32 = reflect.TypeOf(null.Int32{})
typeNullInt64 = reflect.TypeOf(null.Int64{})
typeNullUint = reflect.TypeOf(null.Uint{})
typeNullUint8 = reflect.TypeOf(null.Uint8{})
typeNullUint16 = reflect.TypeOf(null.Uint16{})
typeNullUint32 = reflect.TypeOf(null.Uint32{})
typeNullUint64 = reflect.TypeOf(null.Uint64{})
typeNullString = reflect.TypeOf(null.String{})
typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{})
typeTime = reflect.TypeOf(time.Time{})
typeNullFloat32 = reflect.TypeOf(null.Float32{})
typeNullFloat64 = reflect.TypeOf(null.Float64{})
typeNullInt = reflect.TypeOf(null.Int{})
typeNullInt8 = reflect.TypeOf(null.Int8{})
typeNullInt16 = reflect.TypeOf(null.Int16{})
typeNullInt32 = reflect.TypeOf(null.Int32{})
typeNullInt64 = reflect.TypeOf(null.Int64{})
typeNullUint = reflect.TypeOf(null.Uint{})
typeNullUint8 = reflect.TypeOf(null.Uint8{})
typeNullUint16 = reflect.TypeOf(null.Uint16{})
typeNullUint32 = reflect.TypeOf(null.Uint32{})
typeNullUint64 = reflect.TypeOf(null.Uint64{})
typeNullString = reflect.TypeOf(null.String{})
typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{})
typeNullBytes = reflect.TypeOf(null.Bytes{})
typeNullJSON = reflect.TypeOf(null.JSON{})
typeTime = reflect.TypeOf(time.Time{})
typeJSON = reflect.TypeOf(types.JSON{})
typeInt64Array = reflect.TypeOf(types.Int64Array{})
typeBytesArray = reflect.TypeOf(types.BytesArray{})
typeBoolArray = reflect.TypeOf(types.BoolArray{})
typeFloat64Array = reflect.TypeOf(types.Float64Array{})
typeStringArray = reflect.TypeOf(types.StringArray{})
typeHStore = reflect.TypeOf(types.HStore{})
rgxValidTime = regexp.MustCompile(`[2-9]+`)
rgxValidTime = regexp.MustCompile(`[2-9]+`)
validatedTypes = []string{"uuid", "interval"}
validatedTypes = []string{
"inet", "line", "uuid", "interval",
"json", "jsonb", "box", "cidr", "circle",
"lseg", "macaddr", "path", "pg_lsn", "point",
"polygon", "txid_snapshot", "money", "hstore",
}
)
// Seed is an atomic counter for pseudo-randomization structs. Using full
@ -163,7 +180,59 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "box" || fieldType == "line" || fieldType == "lseg" ||
fieldType == "path" || fieldType == "polygon" {
value = null.NewString(randBox(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "cidr" || fieldType == "inet" {
value = null.NewString(randNetAddr(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "macaddr" {
value = null.NewString(randMacAddr(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "circle" {
value = null.NewString(randCircle(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "pg_lsn" {
value = null.NewString(randLsn(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "point" {
value = null.NewString(randPoint(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "txid_snapshot" {
value = null.NewString(randTxID(), true)
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "money" {
value = null.NewString(randMoney(s), true)
field.Set(reflect.ValueOf(value))
return nil
}
case typeNullJSON:
value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true)
field.Set(reflect.ValueOf(value))
return nil
case typeHStore:
value := types.HStore{}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
field.Set(reflect.ValueOf(value))
return nil
}
} else {
switch kind {
case reflect.String:
@ -177,6 +246,59 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "box" || fieldType == "line" || fieldType == "lseg" ||
fieldType == "path" || fieldType == "polygon" {
value = randBox()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "cidr" || fieldType == "inet" {
value = randNetAddr()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "macaddr" {
value = randMacAddr()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "circle" {
value = randCircle()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "pg_lsn" {
value = randLsn()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "point" {
value = randPoint()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "txid_snapshot" {
value = randTxID()
field.Set(reflect.ValueOf(value))
return nil
}
if fieldType == "money" {
value = randMoney(s)
field.Set(reflect.ValueOf(value))
return nil
}
}
switch typ {
case typeJSON:
value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1)))
field.Set(reflect.ValueOf(value))
return nil
case typeHStore:
value := types.HStore{}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0}
field.Set(reflect.ValueOf(value))
return nil
}
}
}
@ -191,8 +313,11 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
isNull = false
}
// Retrieve the value to be returned
if kind == reflect.Struct {
// If it's a Postgres array, treat it like one
if strings.HasPrefix(fieldType, "ARRAY") {
value = getArrayRandValue(s, typ, fieldType)
// Retrieve the value to be returned
} else if kind == reflect.Struct {
if isNull {
value = getStructNullValue(typ)
} else {
@ -215,6 +340,69 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo
return nil
}
func getArrayRandValue(s *Seed, typ reflect.Type, fieldType string) interface{} {
fieldType = strings.TrimLeft(fieldType, "ARRAY")
switch typ {
case typeInt64Array:
return types.Int64Array{int64(s.nextInt()), int64(s.nextInt())}
case typeFloat64Array:
return types.Float64Array{float64(s.nextInt()), float64(s.nextInt())}
case typeBoolArray:
return types.BoolArray{s.nextInt()%2 == 0, s.nextInt()%2 == 0, s.nextInt()%2 == 0}
case typeStringArray:
if fieldType == "interval" {
value := strconv.Itoa((s.nextInt()%26)+2) + " days"
return types.StringArray{value, value}
}
if fieldType == "uuid" {
value := uuid.NewV4().String()
return types.StringArray{value, value}
}
if fieldType == "box" || fieldType == "line" || fieldType == "lseg" ||
fieldType == "path" || fieldType == "polygon" {
value := randBox()
return types.StringArray{value, value}
}
if fieldType == "cidr" || fieldType == "inet" {
value := randNetAddr()
return types.StringArray{value, value}
}
if fieldType == "macaddr" {
value := randMacAddr()
return types.StringArray{value, value}
}
if fieldType == "circle" {
value := randCircle()
return types.StringArray{value, value}
}
if fieldType == "pg_lsn" {
value := randLsn()
return types.StringArray{value, value}
}
if fieldType == "point" {
value := randPoint()
return types.StringArray{value, value}
}
if fieldType == "txid_snapshot" {
value := randTxID()
return types.StringArray{value, value}
}
if fieldType == "money" {
value := randMoney(s)
return types.StringArray{value, value}
}
if fieldType == "json" || fieldType == "jsonb" {
value := []byte(fmt.Sprintf(`"%s"`, randStr(s, 1)))
return types.StringArray{string(value)}
}
return types.StringArray{randStr(s, 4), randStr(s, 4), randStr(s, 4)}
case typeBytesArray:
return types.BytesArray{randByteSlice(s, 4), randByteSlice(s, 4), randByteSlice(s, 4)}
}
return nil
}
// getStructNullValue for the matching type.
func getStructNullValue(typ reflect.Type) interface{} {
switch typ {
@ -250,6 +438,8 @@ func getStructNullValue(typ reflect.Type) interface{} {
return null.NewUint32(0, false)
case typeNullUint64:
return null.NewUint64(0, false)
case typeNullBytes:
return null.NewBytes(nil, false)
}
return nil
@ -292,6 +482,8 @@ func getStructRandValue(s *Seed, typ reflect.Type) interface{} {
return null.NewUint32(uint32(s.nextInt()), true)
case typeNullUint64:
return null.NewUint64(uint64(s.nextInt()), true)
case typeNullBytes:
return null.NewBytes(randByteSlice(s, 16), true)
}
return nil
@ -378,23 +570,3 @@ func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interfac
return nil
}
const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func randStr(s *Seed, ln int) string {
str := make([]byte, ln)
for i := 0; i < ln; i++ {
str[i] = byte(alphabet[s.nextInt()%len(alphabet)])
}
return string(str)
}
func randByteSlice(s *Seed, ln int) []byte {
str := make([]byte, ln)
for i := 0; i < ln; i++ {
str[i] = byte(s.nextInt() % 256)
}
return str
}

View file

@ -5,7 +5,7 @@ import (
"testing"
"time"
"gopkg.in/nullbio/null.v4"
"gopkg.in/nullbio/null.v5"
)
func TestRandomizeStruct(t *testing.T) {

View file

@ -15,7 +15,8 @@ import (
"github.com/pkg/errors"
"github.com/vattle/sqlboiler/bdb"
"github.com/vattle/sqlboiler/bdb/drivers"
"github.com/vattle/sqlboiler/boil"
"github.com/vattle/sqlboiler/queries"
"github.com/vattle/sqlboiler/strmangle"
)
const (
@ -32,8 +33,9 @@ const (
type State struct {
Config *Config
Driver bdb.Interface
Tables []bdb.Table
Driver bdb.Interface
Tables []bdb.Table
Dialect queries.Dialect
Templates *templateList
TestTemplates *templateList
@ -59,7 +61,7 @@ func New(config *Config) (*State, error) {
return nil, errors.Wrap(err, "unable to connect to the database")
}
err = s.initTables(config.ExcludeTables)
err = s.initTables(config.Schema, config.WhitelistTables, config.BlacklistTables)
if err != nil {
return nil, errors.Wrap(err, "unable to initialize tables")
}
@ -69,8 +71,7 @@ func New(config *Config) (*State, error) {
if err != nil {
return nil, errors.Wrap(err, "unable to json marshal tables")
}
boil.DebugWriter.Write(b)
fmt.Fprintln(boil.DebugWriter)
fmt.Printf("%s\n", b)
}
err = s.initOutFolder()
@ -96,11 +97,15 @@ func New(config *Config) (*State, error) {
func (s *State) Run(includeTests bool) error {
singletonData := &templateData{
Tables: s.Tables,
Schema: s.Config.Schema,
DriverName: s.Config.DriverName,
UseLastInsertID: s.Driver.UseLastInsertID(),
PkgName: s.Config.PkgName,
NoHooks: s.Config.NoHooks,
NoAutoTimestamps: s.Config.NoAutoTimestamps,
Dialect: s.Dialect,
LQ: strmangle.QuoteCharacter(s.Dialect.LQ),
RQ: strmangle.QuoteCharacter(s.Dialect.RQ),
StringFuncs: templateStringMappers,
}
@ -127,12 +132,16 @@ func (s *State) Run(includeTests bool) error {
data := &templateData{
Tables: s.Tables,
Table: table,
Schema: s.Config.Schema,
DriverName: s.Config.DriverName,
UseLastInsertID: s.Driver.UseLastInsertID(),
PkgName: s.Config.PkgName,
NoHooks: s.Config.NoHooks,
NoAutoTimestamps: s.Config.NoAutoTimestamps,
Tags: s.Config.Tags,
Dialect: s.Dialect,
LQ: strmangle.QuoteCharacter(s.Dialect.LQ),
RQ: strmangle.QuoteCharacter(s.Dialect.RQ),
StringFuncs: templateStringMappers,
}
@ -227,6 +236,15 @@ func (s *State) initDriver(driverName string) error {
s.Config.Postgres.Port,
s.Config.Postgres.SSLMode,
)
case "mysql":
s.Driver = drivers.NewMySQLDriver(
s.Config.MySQL.User,
s.Config.MySQL.Pass,
s.Config.MySQL.DBName,
s.Config.MySQL.Host,
s.Config.MySQL.Port,
s.Config.MySQL.SSLMode,
)
case "mock":
s.Driver = &drivers.MockDriver{}
}
@ -235,13 +253,17 @@ func (s *State) initDriver(driverName string) error {
return errors.New("An invalid driver name was provided")
}
s.Dialect.LQ = s.Driver.LeftQuote()
s.Dialect.RQ = s.Driver.RightQuote()
s.Dialect.IndexPlaceholders = s.Driver.IndexPlaceholders()
return nil
}
// initTables retrieves all "public" schema table names from the database.
func (s *State) initTables(exclude []string) error {
func (s *State) initTables(schema string, whitelist, blacklist []string) error {
var err error
s.Tables, err = bdb.Tables(s.Driver, exclude...)
s.Tables, err = bdb.Tables(s.Driver, schema, whitelist, blacklist)
if err != nil {
return errors.Wrap(err, "unable to fetch table data")
}

View file

@ -37,10 +37,10 @@ func TestNew(t *testing.T) {
}()
config := &Config{
DriverName: "mock",
PkgName: "models",
OutFolder: out,
ExcludeTables: []string{"hangars"},
DriverName: "mock",
PkgName: "models",
OutFolder: out,
BlacklistTables: []string{"hangars"},
}
state, err = New(config)

View file

@ -22,6 +22,7 @@ var uppercaseWords = map[string]struct{}{
"id": {},
"uid": {},
"uuid": {},
"json": {},
}
func init() {
@ -33,9 +34,21 @@ func init() {
boilRuleset = newBoilRuleset()
}
// SchemaTable returns a table name with a schema prefixed if
// using a database that supports real schemas, for example,
// for Postgres: "schema_name"."table_name", versus
// simply "table_name" for MySQL (because it does not support real schemas)
func SchemaTable(lq, rq string, driver string, schema string, table string) string {
if driver == "postgres" && schema != "public" {
return fmt.Sprintf(`%s%s%s.%s%s%s`, lq, schema, rq, lq, table, rq)
}
return fmt.Sprintf(`%s%s%s`, lq, table, rq)
}
// IdentQuote attempts to quote simple identifiers in SQL tatements
func IdentQuote(s string) string {
if strings.ToLower(s) == "null" {
func IdentQuote(lq byte, rq byte, s string) string {
if strings.ToLower(s) == "null" || s == "?" {
return s
}
@ -52,28 +65,28 @@ func IdentQuote(s string) string {
buf.WriteByte('.')
}
if strings.HasPrefix(split, `"`) || strings.HasSuffix(split, `"`) || split == "*" {
if split[0] == lq || split[len(split)-1] == rq || split == "*" {
buf.WriteString(split)
continue
}
buf.WriteByte('"')
buf.WriteByte(lq)
buf.WriteString(split)
buf.WriteByte('"')
buf.WriteByte(rq)
}
return buf.String()
}
// IdentQuoteSlice applies IdentQuote to a slice.
func IdentQuoteSlice(s []string) []string {
func IdentQuoteSlice(lq byte, rq byte, s []string) []string {
if len(s) == 0 {
return s
}
strs := make([]string, len(s))
for i, str := range s {
strs[i] = IdentQuote(str)
strs[i] = IdentQuote(lq, rq, str)
}
return strs
@ -105,6 +118,16 @@ func Identifier(in int) string {
return cols.String()
}
// QuoteCharacter returns a string that allows the quote character
// to be embedded into a Go string that uses double quotes:
func QuoteCharacter(q byte) string {
if q == '"' {
return `\"`
}
return string(q)
}
// Plural converts singular words to plural words (eg: person to people)
func Plural(name string) string {
buf := GetBuffer()
@ -368,7 +391,8 @@ func PrefixStringSlice(str string, strs []string) []string {
// Placeholders generates the SQL statement placeholders for in queries.
// For example, ($1,$2,$3),($4,$5,$6) etc.
// It will start counting placeholders at "start".
func Placeholders(count int, start int, group int) string {
// If indexPlaceholders is false, it will convert to ? instead of $1 etc.
func Placeholders(indexPlaceholders bool, count int, start int, group int) string {
buf := GetBuffer()
defer PutBuffer(buf)
@ -387,7 +411,11 @@ func Placeholders(count int, start int, group int) string {
buf.WriteByte(',')
}
}
buf.WriteString(fmt.Sprintf("$%d", start+i))
if indexPlaceholders {
buf.WriteString(fmt.Sprintf("$%d", start+i))
} else {
buf.WriteByte('?')
}
}
if group > 1 {
buf.WriteByte(')')
@ -399,14 +427,19 @@ func Placeholders(count int, start int, group int) string {
// SetParamNames takes a slice of columns and returns a comma separated
// list of parameter names for a template statement SET clause.
// eg: "col1"=$1, "col2"=$2, "col3"=$3
func SetParamNames(columns []string) string {
func SetParamNames(lq, rq string, start int, columns []string) string {
buf := GetBuffer()
defer PutBuffer(buf)
for i, c := range columns {
buf.WriteString(fmt.Sprintf(`"%s"=$%d`, c, i+1))
if start != 0 {
buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, i+start))
} else {
buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq))
}
if i < len(columns)-1 {
buf.WriteString(", ")
buf.WriteByte(',')
}
}
@ -415,16 +448,17 @@ func SetParamNames(columns []string) string {
// WhereClause returns the where clause using start as the $ flag index
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
func WhereClause(start int, cols []string) string {
if start == 0 {
panic("0 is not a valid start number for whereClause")
}
func WhereClause(lq, rq string, start int, cols []string) string {
buf := GetBuffer()
defer PutBuffer(buf)
for i, c := range cols {
buf.WriteString(fmt.Sprintf(`"%s"=$%d`, c, start+i))
if start != 0 {
buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, start+i))
} else {
buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq))
}
if i < len(cols)-1 {
buf.WriteString(" AND ")
}

View file

@ -29,7 +29,7 @@ func TestIdentQuote(t *testing.T) {
}
for _, test := range tests {
if got := IdentQuote(test.In); got != test.Out {
if got := IdentQuote('"', '"', test.In); got != test.Out {
t.Errorf("want: %s, got: %s", test.Out, got)
}
}
@ -38,7 +38,7 @@ func TestIdentQuote(t *testing.T) {
func TestIdentQuoteSlice(t *testing.T) {
t.Parallel()
ret := IdentQuoteSlice([]string{`thing`, `null`})
ret := IdentQuoteSlice('"', '"', []string{`thing`, `null`})
if ret[0] != `"thing"` {
t.Error(ret[0])
}
@ -69,34 +69,60 @@ func TestIdentifier(t *testing.T) {
}
}
func TestQuoteCharacter(t *testing.T) {
t.Parallel()
if QuoteCharacter('[') != "[" {
t.Error("want just the normal quote character")
}
if QuoteCharacter('`') != "`" {
t.Error("want just the normal quote character")
}
if QuoteCharacter('"') != `\"` {
t.Error("want an escaped character")
}
}
func TestPlaceholders(t *testing.T) {
t.Parallel()
x := Placeholders(1, 2, 1)
x := Placeholders(true, 1, 2, 1)
want := "$2"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = Placeholders(5, 1, 1)
x = Placeholders(true, 5, 1, 1)
want = "$1,$2,$3,$4,$5"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = Placeholders(6, 1, 2)
x = Placeholders(false, 5, 1, 1)
want = "?,?,?,?,?"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = Placeholders(true, 6, 1, 2)
want = "($1,$2),($3,$4),($5,$6)"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = Placeholders(9, 1, 3)
want = "($1,$2,$3),($4,$5,$6),($7,$8,$9)"
x = Placeholders(true, 6, 1, 2)
want = "($1,$2),($3,$4),($5,$6)"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = Placeholders(7, 1, 3)
x = Placeholders(false, 9, 1, 3)
want = "(?,?,?),(?,?,?),(?,?,?)"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
x = Placeholders(true, 7, 1, 3)
want = "($1,$2,$3),($4,$5,$6),($7)"
if want != x {
t.Errorf("want %s, got %s", want, x)
@ -291,6 +317,28 @@ func TestPrefixStringSlice(t *testing.T) {
}
}
func TestSetParamNames(t *testing.T) {
t.Parallel()
tests := []struct {
Cols []string
Start int
Should string
}{
{Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=?,"col2"=?`},
{Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`},
{Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4,"col2"=$5`},
{Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4,"col2"=$5,"col3"=$6`},
}
for i, test := range tests {
r := SetParamNames(`"`, `"`, test.Start, test.Cols)
if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
}
}
}
func TestWhereClause(t *testing.T) {
t.Parallel()
@ -299,13 +347,14 @@ func TestWhereClause(t *testing.T) {
Start int
Should string
}{
{Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=? AND "col2"=?`},
{Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`},
{Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5`},
{Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5 AND "col3"=$6`},
}
for i, test := range tests {
r := WhereClause(test.Start, test.Cols)
r := WhereClause(`"`, `"`, test.Start, test.Cols)
if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
}

View file

@ -8,21 +8,45 @@ import (
"text/template"
"github.com/vattle/sqlboiler/bdb"
"github.com/vattle/sqlboiler/queries"
"github.com/vattle/sqlboiler/strmangle"
)
// templateData for sqlboiler templates
type templateData struct {
Tables []bdb.Table
Table bdb.Table
DriverName string
UseLastInsertID bool
PkgName string
Tables []bdb.Table
Table bdb.Table
// Controls what names are output
PkgName string
Schema string
// Controls which code is output (mysql vs postgres ...)
DriverName string
UseLastInsertID bool
// Turn off auto timestamps or hook generation
NoHooks bool
NoAutoTimestamps bool
Tags []string
// Tags control which
Tags []string
// StringFuncs are usable in templates with stringMap
StringFuncs map[string]func(string) string
// Dialect controls quoting
Dialect queries.Dialect
LQ string
RQ string
}
func (t templateData) Quotes(s string) string {
return fmt.Sprintf("%s%s%s", t.LQ, s, t.RQ)
}
func (t templateData) SchemaTable(table string) string {
return strmangle.SchemaTable(t.LQ, t.RQ, t.DriverName, t.Schema, table)
}
type templateList struct {
@ -113,7 +137,7 @@ var templateStringMappers = map[string]func(string) string{
// add a function pointer here.
var templateFunctions = template.FuncMap{
// String ops
"quoteWrap": func(a string) string { return fmt.Sprintf(`"%s"`, a) },
"quoteWrap": func(s string) string { return fmt.Sprintf(`"%s"`, s) },
"id": strmangle.Identifier,
// Pluralization
@ -150,6 +174,7 @@ var templateFunctions = template.FuncMap{
// dbdrivers ops
"filterColumnsByDefault": bdb.FilterColumnsByDefault,
"autoIncPrimaryKey": bdb.AutoIncPrimaryKey,
"sqlColDefinitions": bdb.SQLColDefinitions,
"columnNames": bdb.ColumnNames,
"columnDBTypes": bdb.ColumnDBTypes,

View file

@ -1,5 +1,5 @@
{{- define "relationship_to_one_struct_helper" -}}
{{.Function.Name}} *{{.ForeignTable.NameGo}}
{{.Function.Name}} *{{.ForeignTable.NameGo}}
{{- end -}}
{{- $dot := . -}}
@ -8,30 +8,30 @@
{{- $modelNameCamel := $tableNameSingular | camelCase -}}
// {{$modelName}} is an object representing the database table.
type {{$modelName}} struct {
{{range $column := .Table.Columns -}}
{{titleCase $column.Name}} {{$column.Type}} `{{generateTags $dot.Tags $column.Name}}boil:"{{$column.Name}}" json:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}" toml:"{{$column.Name}}" yaml:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}"`
{{end -}}
{{- if .Table.IsJoinTable -}}
{{- else}}
R *{{$modelNameCamel}}R `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"`
L {{$modelNameCamel}}L `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"`
{{end -}}
{{range $column := .Table.Columns -}}
{{titleCase $column.Name}} {{$column.Type}} `{{generateTags $dot.Tags $column.Name}}boil:"{{$column.Name}}" json:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}" toml:"{{$column.Name}}" yaml:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}"`
{{end -}}
{{- if .Table.IsJoinTable -}}
{{- else}}
R *{{$modelNameCamel}}R `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"`
L {{$modelNameCamel}}L `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"`
{{end -}}
}
{{- if .Table.IsJoinTable -}}
{{- else}}
// {{$modelNameCamel}}R is where relationships are stored.
type {{$modelNameCamel}}R struct {
{{range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_struct_helper" $rel}}
{{end -}}
{{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_struct_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table .)}}
{{else -}}
{{- $rel := textsFromRelationship $dot.Tables $dot.Table . -}}
{{$rel.Function.Name}} {{$rel.ForeignTable.Slice}}
{{range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_struct_helper" $rel}}
{{end -}}
{{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_struct_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table .)}}
{{else -}}
{{- $rel := textsFromRelationship $dot.Tables $dot.Table . -}}
{{$rel.Function.Name}} {{$rel.ForeignTable.Slice}}
{{end -}}{{/* if ForeignColumnUnique */}}
{{- end -}}{{/* range tomany */}}
}

View file

@ -3,31 +3,36 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $tableNameSingular := .Table.Name | singular | titleCase -}}
var (
{{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}}
{{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}}
{{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}}
{{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}}
{{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}}
{{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}}
{{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}}
{{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}}
)
type (
{{$tableNameSingular}}Slice []*{{$tableNameSingular}}
{{if eq .NoHooks false -}}
{{$tableNameSingular}}Hook func(boil.Executor, *{{$tableNameSingular}}) error
{{- end}}
// {{$tableNameSingular}}Slice is an alias for a slice of pointers to {{$tableNameSingular}}.
// This should generally be used opposed to []{{$tableNameSingular}}.
{{$tableNameSingular}}Slice []*{{$tableNameSingular}}
{{if eq .NoHooks false -}}
// {{$tableNameSingular}}Hook is the signature for custom {{$tableNameSingular}} hook methods
{{$tableNameSingular}}Hook func(boil.Executor, *{{$tableNameSingular}}) error
{{- end}}
{{$varNameSingular}}Query struct {
*boil.Query
}
{{$varNameSingular}}Query struct {
*queries.Query
}
)
// Cache for insert and update
// Cache for insert, update and upsert
var (
{{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{})
{{$varNameSingular}}Mapping = boil.MakeStructMapping({{$varNameSingular}}Type)
{{$varNameSingular}}InsertCacheMut sync.RWMutex
{{$varNameSingular}}InsertCache = make(map[string]insertCache)
{{$varNameSingular}}UpdateCacheMut sync.RWMutex
{{$varNameSingular}}UpdateCache = make(map[string]updateCache)
{{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{})
{{$varNameSingular}}Mapping = queries.MakeStructMapping({{$varNameSingular}}Type)
{{$varNameSingular}}InsertCacheMut sync.RWMutex
{{$varNameSingular}}InsertCache = make(map[string]insertCache)
{{$varNameSingular}}UpdateCacheMut sync.RWMutex
{{$varNameSingular}}UpdateCache = make(map[string]updateCache)
{{$varNameSingular}}UpsertCacheMut sync.RWMutex
{{$varNameSingular}}UpsertCache = make(map[string]insertCache)
)
// Force time package dependency for automated UpdatedAt/CreatedAt.

View file

@ -14,123 +14,124 @@ var {{$varNameSingular}}AfterUpsertHooks []{{$tableNameSingular}}Hook
// doBeforeInsertHooks executes all "before insert" hooks.
func (o *{{$tableNameSingular}}) doBeforeInsertHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}BeforeInsertHooks {
if err := hook(exec, o); err != nil {
return err
}
}
for _, hook := range {{$varNameSingular}}BeforeInsertHooks {
if err := hook(exec, o); err != nil {
return err
}
}
return nil
return nil
}
// doBeforeUpdateHooks executes all "before Update" hooks.
func (o *{{$tableNameSingular}}) doBeforeUpdateHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}BeforeUpdateHooks {
if err := hook(exec, o); err != nil {
return err
}
}
for _, hook := range {{$varNameSingular}}BeforeUpdateHooks {
if err := hook(exec, o); err != nil {
return err
}
}
return nil
return nil
}
// doBeforeDeleteHooks executes all "before Delete" hooks.
func (o *{{$tableNameSingular}}) doBeforeDeleteHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}BeforeDeleteHooks {
if err := hook(exec, o); err != nil {
return err
}
}
for _, hook := range {{$varNameSingular}}BeforeDeleteHooks {
if err := hook(exec, o); err != nil {
return err
}
}
return nil
return nil
}
// doBeforeUpsertHooks executes all "before Upsert" hooks.
func (o *{{$tableNameSingular}}) doBeforeUpsertHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}BeforeUpsertHooks {
if err := hook(exec, o); err != nil {
return err
}
}
for _, hook := range {{$varNameSingular}}BeforeUpsertHooks {
if err := hook(exec, o); err != nil {
return err
}
}
return nil
return nil
}
// doAfterInsertHooks executes all "after Insert" hooks.
func (o *{{$tableNameSingular}}) doAfterInsertHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterInsertHooks {
if err := hook(exec, o); err != nil {
return err
}
}
for _, hook := range {{$varNameSingular}}AfterInsertHooks {
if err := hook(exec, o); err != nil {
return err
}
}
return nil
return nil
}
// doAfterSelectHooks executes all "after Select" hooks.
func (o *{{$tableNameSingular}}) doAfterSelectHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterSelectHooks {
if err := hook(exec, o); err != nil {
return err
}
}
for _, hook := range {{$varNameSingular}}AfterSelectHooks {
if err := hook(exec, o); err != nil {
return err
}
}
return nil
return nil
}
// doAfterUpdateHooks executes all "after Update" hooks.
func (o *{{$tableNameSingular}}) doAfterUpdateHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterUpdateHooks {
if err := hook(exec, o); err != nil {
return err
}
}
for _, hook := range {{$varNameSingular}}AfterUpdateHooks {
if err := hook(exec, o); err != nil {
return err
}
}
return nil
return nil
}
// doAfterDeleteHooks executes all "after Delete" hooks.
func (o *{{$tableNameSingular}}) doAfterDeleteHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterDeleteHooks {
if err := hook(exec, o); err != nil {
return err
}
}
for _, hook := range {{$varNameSingular}}AfterDeleteHooks {
if err := hook(exec, o); err != nil {
return err
}
}
return nil
return nil
}
// doAfterUpsertHooks executes all "after Upsert" hooks.
func (o *{{$tableNameSingular}}) doAfterUpsertHooks(exec boil.Executor) (err error) {
for _, hook := range {{$varNameSingular}}AfterUpsertHooks {
if err := hook(exec, o); err != nil {
return err
}
}
for _, hook := range {{$varNameSingular}}AfterUpsertHooks {
if err := hook(exec, o); err != nil {
return err
}
}
return nil
return nil
}
// Add{{$tableNameSingular}}Hook registers your hook function for all future operations.
func Add{{$tableNameSingular}}Hook(hookPoint boil.HookPoint, {{$varNameSingular}}Hook {{$tableNameSingular}}Hook) {
switch hookPoint {
case boil.BeforeInsertHook:
{{$varNameSingular}}BeforeInsertHooks = append({{$varNameSingular}}BeforeInsertHooks, {{$varNameSingular}}Hook)
case boil.BeforeUpdateHook:
{{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook)
case boil.BeforeDeleteHook:
{{$varNameSingular}}BeforeDeleteHooks = append({{$varNameSingular}}BeforeDeleteHooks, {{$varNameSingular}}Hook)
case boil.BeforeUpsertHook:
{{$varNameSingular}}BeforeUpsertHooks = append({{$varNameSingular}}BeforeUpsertHooks, {{$varNameSingular}}Hook)
case boil.AfterInsertHook:
{{$varNameSingular}}AfterInsertHooks = append({{$varNameSingular}}AfterInsertHooks, {{$varNameSingular}}Hook)
case boil.AfterSelectHook:
{{$varNameSingular}}AfterSelectHooks = append({{$varNameSingular}}AfterSelectHooks, {{$varNameSingular}}Hook)
case boil.AfterUpdateHook:
{{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook)
case boil.AfterDeleteHook:
{{$varNameSingular}}AfterDeleteHooks = append({{$varNameSingular}}AfterDeleteHooks, {{$varNameSingular}}Hook)
case boil.AfterUpsertHook:
{{$varNameSingular}}AfterUpsertHooks = append({{$varNameSingular}}AfterUpsertHooks, {{$varNameSingular}}Hook)
}
switch hookPoint {
case boil.BeforeInsertHook:
{{$varNameSingular}}BeforeInsertHooks = append({{$varNameSingular}}BeforeInsertHooks, {{$varNameSingular}}Hook)
case boil.BeforeUpdateHook:
{{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook)
case boil.BeforeDeleteHook:
{{$varNameSingular}}BeforeDeleteHooks = append({{$varNameSingular}}BeforeDeleteHooks, {{$varNameSingular}}Hook)
case boil.BeforeUpsertHook:
{{$varNameSingular}}BeforeUpsertHooks = append({{$varNameSingular}}BeforeUpsertHooks, {{$varNameSingular}}Hook)
case boil.AfterInsertHook:
{{$varNameSingular}}AfterInsertHooks = append({{$varNameSingular}}AfterInsertHooks, {{$varNameSingular}}Hook)
case boil.AfterSelectHook:
{{$varNameSingular}}AfterSelectHooks = append({{$varNameSingular}}AfterSelectHooks, {{$varNameSingular}}Hook)
case boil.AfterUpdateHook:
{{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook)
case boil.AfterDeleteHook:
{{$varNameSingular}}AfterDeleteHooks = append({{$varNameSingular}}AfterDeleteHooks, {{$varNameSingular}}Hook)
case boil.AfterUpsertHook:
{{$varNameSingular}}AfterUpsertHooks = append({{$varNameSingular}}AfterUpsertHooks, {{$varNameSingular}}Hook)
}
}
{{- end}}

View file

@ -2,114 +2,115 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
// OneP returns a single {{$varNameSingular}} record from the query, and panics on error.
func (q {{$varNameSingular}}Query) OneP() (*{{$tableNameSingular}}) {
o, err := q.One()
if err != nil {
panic(boil.WrapErr(err))
}
o, err := q.One()
if err != nil {
panic(boil.WrapErr(err))
}
return o
return o
}
// One returns a single {{$varNameSingular}} record from the query.
func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) {
o := &{{$tableNameSingular}}{}
o := &{{$tableNameSingular}}{}
boil.SetLimit(q.Query, 1)
queries.SetLimit(q.Query, 1)
err := q.Bind(o)
if err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, sql.ErrNoRows
}
return nil, errors.Wrap(err, "{{.PkgName}}: failed to execute a one query for {{.Table.Name}}")
}
err := q.Bind(o)
if err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, sql.ErrNoRows
}
return nil, errors.Wrap(err, "{{.PkgName}}: failed to execute a one query for {{.Table.Name}}")
}
{{if not .NoHooks -}}
if err := o.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil {
return o, err
}
{{- end}}
{{if not .NoHooks -}}
if err := o.doAfterSelectHooks(queries.GetExecutor(q.Query)); err != nil {
return o, err
}
{{- end}}
return o, nil
return o, nil
}
// AllP returns all {{$tableNameSingular}} records from the query, and panics on error.
func (q {{$varNameSingular}}Query) AllP() {{$tableNameSingular}}Slice {
o, err := q.All()
if err != nil {
panic(boil.WrapErr(err))
}
o, err := q.All()
if err != nil {
panic(boil.WrapErr(err))
}
return o
return o
}
// All returns all {{$tableNameSingular}} records from the query.
func (q {{$varNameSingular}}Query) All() ({{$tableNameSingular}}Slice, error) {
var o {{$tableNameSingular}}Slice
var o {{$tableNameSingular}}Slice
err := q.Bind(&o)
if err != nil {
return nil, errors.Wrap(err, "{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice")
}
err := q.Bind(&o)
if err != nil {
return nil, errors.Wrap(err, "{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice")
}
{{if not .NoHooks -}}
if len({{$varNameSingular}}AfterSelectHooks) != 0 {
for _, obj := range o {
if err := obj.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil {
return o, err
}
}
}
{{- end}}
{{if not .NoHooks -}}
if len({{$varNameSingular}}AfterSelectHooks) != 0 {
for _, obj := range o {
if err := obj.doAfterSelectHooks(queries.GetExecutor(q.Query)); err != nil {
return o, err
}
}
}
{{- end}}
return o, nil
return o, nil
}
// CountP returns the count of all {{$tableNameSingular}} records in the query, and panics on error.
func (q {{$varNameSingular}}Query) CountP() int64 {
c, err := q.Count()
if err != nil {
panic(boil.WrapErr(err))
}
c, err := q.Count()
if err != nil {
panic(boil.WrapErr(err))
}
return c
return c
}
// Count returns the count of all {{$tableNameSingular}} records in the query.
func (q {{$varNameSingular}}Query) Count() (int64, error) {
var count int64
var count int64
boil.SetCount(q.Query)
queries.SetSelect(q.Query, nil)
queries.SetCount(q.Query)
err := boil.ExecQueryOne(q.Query).Scan(&count)
if err != nil {
return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows")
}
err := q.Query.QueryRow().Scan(&count)
if err != nil {
return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows")
}
return count, nil
return count, nil
}
// Exists checks if the row exists in the table, and panics on error.
func (q {{$varNameSingular}}Query) ExistsP() bool {
e, err := q.Exists()
if err != nil {
panic(boil.WrapErr(err))
}
e, err := q.Exists()
if err != nil {
panic(boil.WrapErr(err))
}
return e
return e
}
// Exists checks if the row exists in the table.
func (q {{$varNameSingular}}Query) Exists() (bool, error) {
var count int64
var count int64
boil.SetCount(q.Query)
boil.SetLimit(q.Query, 1)
queries.SetCount(q.Query)
queries.SetLimit(q.Query, 1)
err := boil.ExecQueryOne(q.Query).Scan(&count)
if err != nil {
return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists")
}
err := q.Query.QueryRow().Scan(&count)
if err != nil {
return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists")
}
return count > 0, nil
return count > 0, nil
}

View file

@ -1,30 +1,34 @@
{{- define "relationship_to_one_helper" -}}
{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}}
{{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}}
{{- with .Rel -}}{{/* Rel holds the text helper data, passed in through preserveDot */}}
{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}}
// {{.Function.Name}}G pointed to by the foreign key.
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query {
return {{.Function.Receiver}}.{{.Function.Name}}(boil.GetDB(), mods...)
return {{.Function.Receiver}}.{{.Function.Name}}(boil.GetDB(), mods...)
}
// {{.Function.Name}} pointed to by the foreign key.
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) ({{$varNameSingular}}Query) {
queryMods := []qm.QueryMod{
qm.Where("{{.ForeignTable.ColumnName}}=$1", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}),
}
queryMods := []qm.QueryMod{
qm.Where("{{.ForeignTable.ColumnName}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}),
}
queryMods = append(queryMods, mods...)
queryMods = append(queryMods, mods...)
query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...)
boil.SetFrom(query.Query, "{{.ForeignTable.Name}}")
query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...)
queries.SetFrom(query.Query, "{{.ForeignTable.Name | $dot.SchemaTable}}")
return query
return query
}
{{- end -}}{{/* end with */}}
{{end -}}{{/* end define */}}
{{end -}}
{{- /* Begin execution of template for one-to-one relationship */ -}}
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_helper" $rel -}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}}
{{- end -}}
{{- end -}}

View file

@ -1,46 +1,51 @@
{{- /* Begin execution of template for many-to-one or many-to-many relationship helper */ -}}
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . -}}
{{- $table := .Table -}}
{{- range .Table.ToManyRelationships -}}
{{- $varNameSingular := .ForeignTable | singular | camelCase -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}}
{{- else -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}}
{{- $dot := . -}}
{{- $table := .Table -}}
{{- range .Table.ToManyRelationships -}}
{{- $varNameSingular := .ForeignTable | singular | camelCase -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- /* Begin execution of template for many-to-one relationship. */ -}}
{{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}}
{{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}}
{{- else -}}
{{- /* Begin execution of template for many-to-many relationship. */ -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}}
{{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}}
// {{$rel.Function.Name}}G retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}}
{{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query {
return {{$rel.Function.Receiver}}.{{$rel.Function.Name}}(boil.GetDB(), mods...)
return {{$rel.Function.Receiver}}.{{$rel.Function.Name}}(boil.GetDB(), mods...)
}
// {{$rel.Function.Name}} retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} with an executor
{{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query {
queryMods := []qm.QueryMod{
qm.Select(`"{{id 0}}".*`),
}
queryMods := []qm.QueryMod{
qm.Select("{{id 0 | $dot.Quotes}}.*"),
}
if len(mods) != 0 {
queryMods = append(queryMods, mods...)
}
if len(mods) != 0 {
queryMods = append(queryMods, mods...)
}
{{if .ToJoinTable -}}
queryMods = append(queryMods,
qm.InnerJoin(`"{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}"`),
qm.Where(`"{{id 1}}"."{{.JoinLocalColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}),
)
{{else -}}
queryMods = append(queryMods,
qm.Where(`"{{id 0}}"."{{.ForeignColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}),
)
{{end}}
{{if .ToJoinTable -}}
queryMods = append(queryMods,
qm.InnerJoin("{{.JoinTable | $dot.SchemaTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}}"),
qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}),
)
{{else -}}
queryMods = append(queryMods,
qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}),
)
{{end}}
query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...)
boil.SetFrom(query.Query, `"{{.ForeignTable}}" as "{{id 0}}"`)
return query
query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...)
queries.SetFrom(query.Query, "{{$schemaForeignTable}} as {{id 0 | $dot.Quotes}}")
return query
}
{{end -}}{{- /* if unique foreign key */ -}}
{{- end -}}{{- /* range relationships */ -}}
{{- end -}}{{- /* outer if join table */ -}}
{{- end -}}{{- /* if isJoinTable */ -}}

View file

@ -1,91 +1,93 @@
{{- define "relationship_to_one_eager_helper" -}}
{{- $varNameSingular := .Dot.Table.Name | singular | camelCase -}}
{{- $noHooks := .Dot.NoHooks -}}
{{- with .Rel -}}
{{- $arg := printf "maybe%s" .LocalTable.NameGo -}}
{{- $slice := printf "%sSlice" .LocalTable.NameGo -}}
{{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}}
{{- $varNameSingular := $dot.Table.Name | singular | camelCase -}}
{{- with .Rel -}}
{{- $arg := printf "maybe%s" .LocalTable.NameGo -}}
{{- $slice := printf "%sSlice" .LocalTable.NameGo -}}
// Load{{.Function.Name}} allows an eager lookup of values, cached into the
// loaded structs of the objects.
func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error {
var slice []*{{.LocalTable.NameGo}}
var object *{{.LocalTable.NameGo}}
var slice []*{{.LocalTable.NameGo}}
var object *{{.LocalTable.NameGo}}
count := 1
if singular {
object = {{$arg}}.(*{{.LocalTable.NameGo}})
} else {
slice = *{{$arg}}.(*{{$slice}})
count = len(slice)
}
count := 1
if singular {
object = {{$arg}}.(*{{.LocalTable.NameGo}})
} else {
slice = *{{$arg}}.(*{{$slice}})
count = len(slice)
}
args := make([]interface{}, count)
if singular {
args[0] = object.{{.LocalTable.ColumnNameGo}}
} else {
for i, obj := range slice {
args[i] = obj.{{.LocalTable.ColumnNameGo}}
}
}
args := make([]interface{}, count)
if singular {
args[0] = object.{{.LocalTable.ColumnNameGo}}
} else {
for i, obj := range slice {
args[i] = obj.{{.LocalTable.ColumnNameGo}}
}
}
query := fmt.Sprintf(
`select * from "{{.ForeignKey.ForeignTable}}" where "{{.ForeignKey.ForeignColumn}}" in (%s)`,
strmangle.Placeholders(count, 1, 1),
)
query := fmt.Sprintf(
"select * from {{.ForeignKey.ForeignTable | $dot.SchemaTable}} where {{.ForeignKey.ForeignColumn | $dot.Quotes}} in (%s)",
strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1),
)
if boil.DebugMode {
fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args)
}
if boil.DebugMode {
fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args)
}
results, err := e.Query(query, args...)
if err != nil {
return errors.Wrap(err, "failed to eager load {{.ForeignTable.NameGo}}")
}
defer results.Close()
results, err := e.Query(query, args...)
if err != nil {
return errors.Wrap(err, "failed to eager load {{.ForeignTable.NameGo}}")
}
defer results.Close()
var resultSlice []*{{.ForeignTable.NameGo}}
if err = boil.Bind(results, &resultSlice); err != nil {
return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}")
}
var resultSlice []*{{.ForeignTable.NameGo}}
if err = queries.Bind(results, &resultSlice); err != nil {
return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}")
}
{{if not $noHooks -}}
if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 {
for _, obj := range resultSlice {
if err := obj.doAfterSelectHooks(e); err != nil {
return err
}
}
}
{{- end}}
{{if not $dot.NoHooks -}}
if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 {
for _, obj := range resultSlice {
if err := obj.doAfterSelectHooks(e); err != nil {
return err
}
}
}
{{- end}}
if singular && len(resultSlice) != 0 {
if object.R == nil {
object.R = &{{$varNameSingular}}R{}
}
object.R.{{.Function.Name}} = resultSlice[0]
return nil
}
if singular && len(resultSlice) != 0 {
if object.R == nil {
object.R = &{{$varNameSingular}}R{}
}
object.R.{{.Function.Name}} = resultSlice[0]
return nil
}
for _, foreign := range resultSlice {
for _, local := range slice {
if local.{{.Function.LocalAssignment}} == foreign.{{.Function.ForeignAssignment}} {
if local.R == nil {
local.R = &{{$varNameSingular}}R{}
}
local.R.{{.Function.Name}} = foreign
break
}
}
}
for _, foreign := range resultSlice {
for _, local := range slice {
if local.{{.Function.LocalAssignment}} == foreign.{{.Function.ForeignAssignment}} {
if local.R == nil {
local.R = &{{$varNameSingular}}R{}
}
local.R.{{.Function.Name}} = foreign
break
}
}
}
return nil
return nil
}
{{- end -}}
{{end -}}
{{- end -}}{{- /* end with */ -}}
{{end -}}{{- /* end define */ -}}
{{- /* Begin execution of template for one-to-one eager load */ -}}
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_eager_helper" (preserveDot $dot $rel) -}}
{{- end -}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}}
{{- end -}}
{{end}}

View file

@ -1,136 +1,141 @@
{{- /* Begin execution of template for many-to-one or many-to-many eager load */ -}}
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . -}}
{{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}}
{{- else -}}
{{- $varNameSingular := $dot.Table.Name | singular | camelCase -}}
{{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}}
{{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}}
{{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}}
{{- $dot := . -}}
{{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- /* Begin execution of template for many-to-one eager load */ -}}
{{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}}
{{- else -}}
{{- /* Begin execution of template for many-to-many eager load */ -}}
{{- $varNameSingular := $dot.Table.Name | singular | camelCase -}}
{{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}}
{{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}}
{{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}}
{{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}}
// Load{{$txt.Function.Name}} allows an eager lookup of values, cached into the
// loaded structs of the objects.
func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error {
var slice []*{{$txt.LocalTable.NameGo}}
var object *{{$txt.LocalTable.NameGo}}
var slice []*{{$txt.LocalTable.NameGo}}
var object *{{$txt.LocalTable.NameGo}}
count := 1
if singular {
object = {{$arg}}.(*{{$txt.LocalTable.NameGo}})
} else {
slice = *{{$arg}}.(*{{$slice}})
count = len(slice)
}
count := 1
if singular {
object = {{$arg}}.(*{{$txt.LocalTable.NameGo}})
} else {
slice = *{{$arg}}.(*{{$slice}})
count = len(slice)
}
args := make([]interface{}, count)
if singular {
args[0] = object.{{.Column | titleCase}}
} else {
for i, obj := range slice {
args[i] = obj.{{.Column | titleCase}}
}
}
args := make([]interface{}, count)
if singular {
args[0] = object.{{.Column | titleCase}}
} else {
for i, obj := range slice {
args[i] = obj.{{.Column | titleCase}}
}
}
{{if .ToJoinTable -}}
query := fmt.Sprintf(
`select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from "{{.ForeignTable}}" as "{{id 0}}" inner join "{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`,
strmangle.Placeholders(count, 1, 1),
)
{{else -}}
query := fmt.Sprintf(
`select * from "{{.ForeignTable}}" where "{{.ForeignColumn}}" in (%s)`,
strmangle.Placeholders(count, 1, 1),
)
{{end -}}
{{if .ToJoinTable -}}
{{- $schemaJoinTable := .JoinTable | $dot.SchemaTable -}}
query := fmt.Sprintf(
"select {{id 0 | $dot.Quotes}}.*, {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} from {{$schemaForeignTable}} as {{id 0 | $dot.Quotes}} inner join {{$schemaJoinTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}} where {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} in (%s)",
strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1),
)
{{else -}}
query := fmt.Sprintf(
"select * from {{$schemaForeignTable}} where {{.ForeignColumn | $dot.Quotes}} in (%s)",
strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1),
)
{{end -}}
if boil.DebugMode {
fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args)
}
if boil.DebugMode {
fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args)
}
results, err := e.Query(query, args...)
if err != nil {
return errors.Wrap(err, "failed to eager load {{.ForeignTable}}")
}
defer results.Close()
results, err := e.Query(query, args...)
if err != nil {
return errors.Wrap(err, "failed to eager load {{.ForeignTable}}")
}
defer results.Close()
var resultSlice []*{{$txt.ForeignTable.NameGo}}
{{if .ToJoinTable -}}
{{- $foreignTable := getTable $dot.Tables .ForeignTable -}}
{{- $joinTable := getTable $dot.Tables .JoinTable -}}
{{- $localCol := $joinTable.GetColumn .JoinLocalColumn}}
var localJoinCols []{{$localCol.Type}}
for results.Next() {
one := new({{$txt.ForeignTable.NameGo}})
var localJoinCol {{$localCol.Type}}
var resultSlice []*{{$txt.ForeignTable.NameGo}}
{{if .ToJoinTable -}}
{{- $foreignTable := getTable $dot.Tables .ForeignTable -}}
{{- $joinTable := getTable $dot.Tables .JoinTable -}}
{{- $localCol := $joinTable.GetColumn .JoinLocalColumn}}
var localJoinCols []{{$localCol.Type}}
for results.Next() {
one := new({{$txt.ForeignTable.NameGo}})
var localJoinCol {{$localCol.Type}}
err = results.Scan({{$foreignTable.Columns | columnNames | stringMap $dot.StringFuncs.titleCase | prefixStringSlice "&one." | join ", "}}, &localJoinCol)
if err = results.Err(); err != nil {
return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}")
}
err = results.Scan({{$foreignTable.Columns | columnNames | stringMap $dot.StringFuncs.titleCase | prefixStringSlice "&one." | join ", "}}, &localJoinCol)
if err = results.Err(); err != nil {
return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}")
}
resultSlice = append(resultSlice, one)
localJoinCols = append(localJoinCols, localJoinCol)
}
resultSlice = append(resultSlice, one)
localJoinCols = append(localJoinCols, localJoinCol)
}
if err = results.Err(); err != nil {
return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}")
}
{{else -}}
if err = boil.Bind(results, &resultSlice); err != nil {
return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}")
}
{{end}}
if err = results.Err(); err != nil {
return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}")
}
{{else -}}
if err = queries.Bind(results, &resultSlice); err != nil {
return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}")
}
{{end}}
{{if not $dot.NoHooks -}}
if len({{.ForeignTable | singular | camelCase}}AfterSelectHooks) != 0 {
for _, obj := range resultSlice {
if err := obj.doAfterSelectHooks(e); err != nil {
return err
}
}
}
{{if not $dot.NoHooks -}}
if len({{.ForeignTable | singular | camelCase}}AfterSelectHooks) != 0 {
for _, obj := range resultSlice {
if err := obj.doAfterSelectHooks(e); err != nil {
return err
}
}
}
{{- end}}
if singular {
if object.R == nil {
object.R = &{{$varNameSingular}}R{}
}
object.R.{{$txt.Function.Name}} = resultSlice
return nil
}
{{- end}}
if singular {
if object.R == nil {
object.R = &{{$varNameSingular}}R{}
}
object.R.{{$txt.Function.Name}} = resultSlice
return nil
}
{{if .ToJoinTable -}}
for i, foreign := range resultSlice {
localJoinCol := localJoinCols[i]
for _, local := range slice {
if local.{{$txt.Function.LocalAssignment}} == localJoinCol {
if local.R == nil {
local.R = &{{$varNameSingular}}R{}
}
local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign)
break
}
}
}
{{else -}}
for _, foreign := range resultSlice {
for _, local := range slice {
if local.{{$txt.Function.LocalAssignment}} == foreign.{{$txt.Function.ForeignAssignment}} {
if local.R == nil {
local.R = &{{$varNameSingular}}R{}
}
local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign)
break
}
}
}
{{end}}
{{if .ToJoinTable -}}
for i, foreign := range resultSlice {
localJoinCol := localJoinCols[i]
for _, local := range slice {
if local.{{$txt.Function.LocalAssignment}} == localJoinCol {
if local.R == nil {
local.R = &{{$varNameSingular}}R{}
}
local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign)
break
}
}
}
{{else -}}
for _, foreign := range resultSlice {
for _, local := range slice {
if local.{{$txt.Function.LocalAssignment}} == foreign.{{$txt.Function.ForeignAssignment}} {
if local.R == nil {
local.R = &{{$varNameSingular}}R{}
}
local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign)
break
}
}
}
{{end}}
return nil
return nil
}
{{end -}}{{/* if ForeignColumnUnique */}}
{{- end -}}{{/* range tomany */}}
{{- end -}}{{/* if isjointable */}}
{{- end -}}{{/* if IsJoinTable */}}

View file

@ -1,101 +1,105 @@
{{- define "relationship_to_one_setops_helper" -}}
{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}}
{{- $localNameSingular := .ForeignKey.Table | singular | camelCase}}
{{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}}
{{- with .Rel -}}
{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}}
{{- $localNameSingular := .ForeignKey.Table | singular | camelCase}}
// Set{{.Function.Name}} of the {{.ForeignKey.Table | singular}} to the related item.
// Sets {{.Function.Receiver}}.R.{{.Function.Name}} to related.
// Adds {{.Function.Receiver}} to related.R.{{.Function.ForeignName}}.
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Set{{.Function.Name}}(exec boil.Executor, insert bool, related *{{.ForeignTable.NameGo}}) error {
var err error
if insert {
if err = related.Insert(exec); err != nil {
return errors.Wrap(err, "failed to insert into foreign table")
}
}
var err error
if insert {
if err = related.Insert(exec); err != nil {
return errors.Wrap(err, "failed to insert into foreign table")
}
}
oldVal := {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}
{{.Function.Receiver}}.{{.Function.LocalAssignment}} = related.{{.Function.ForeignAssignment}}
if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil {
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} = oldVal
return errors.Wrap(err, "failed to update local table")
}
oldVal := {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}
{{.Function.Receiver}}.{{.Function.LocalAssignment}} = related.{{.Function.ForeignAssignment}}
if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil {
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} = oldVal
return errors.Wrap(err, "failed to update local table")
}
if {{.Function.Receiver}}.R == nil {
{{.Function.Receiver}}.R = &{{$localNameSingular}}R{
{{.Function.Name}}: related,
}
} else {
{{.Function.Receiver}}.R.{{.Function.Name}} = related
}
if {{.Function.Receiver}}.R == nil {
{{.Function.Receiver}}.R = &{{$localNameSingular}}R{
{{.Function.Name}}: related,
}
} else {
{{.Function.Receiver}}.R.{{.Function.Name}} = related
}
{{if (or .ForeignKey.Unique .Function.OneToOne) -}}
if related.R == nil {
related.R = &{{$varNameSingular}}R{
{{.Function.ForeignName}}: {{.Function.Receiver}},
}
} else {
related.R.{{.Function.ForeignName}} = {{.Function.Receiver}}
}
{{else -}}
if related.R == nil {
related.R = &{{$varNameSingular}}R{
{{.Function.ForeignName}}: {{.LocalTable.NameGo}}Slice{{"{"}}{{.Function.Receiver}}{{"}"}},
}
} else {
related.R.{{.Function.ForeignName}} = append(related.R.{{.Function.ForeignName}}, {{.Function.Receiver}})
}
{{end -}}
{{if (or .ForeignKey.Unique .Function.OneToOne) -}}
if related.R == nil {
related.R = &{{$varNameSingular}}R{
{{.Function.ForeignName}}: {{.Function.Receiver}},
}
} else {
related.R.{{.Function.ForeignName}} = {{.Function.Receiver}}
}
{{else -}}
if related.R == nil {
related.R = &{{$varNameSingular}}R{
{{.Function.ForeignName}}: {{.LocalTable.NameGo}}Slice{{"{"}}{{.Function.Receiver}}{{"}"}},
}
} else {
related.R.{{.Function.ForeignName}} = append(related.R.{{.Function.ForeignName}}, {{.Function.Receiver}})
}
{{end -}}
{{if .ForeignKey.Nullable}}
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true
{{end -}}
return nil
{{if .ForeignKey.Nullable}}
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true
{{end -}}
return nil
}
{{- if .ForeignKey.Nullable}}
{{- if .ForeignKey.Nullable}}
// Remove{{.Function.Name}} relationship.
// Sets {{.Function.Receiver}}.R.{{.Function.Name}} to nil.
// Removes {{.Function.Receiver}} from all passed in related items' relationships struct (Optional).
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Remove{{.Function.Name}}(exec boil.Executor, related *{{.ForeignTable.NameGo}}) error {
var err error
var err error
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = false
if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil {
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true
return errors.Wrap(err, "failed to update local table")
}
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = false
if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil {
{{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true
return errors.Wrap(err, "failed to update local table")
}
{{.Function.Receiver}}.R.{{.Function.Name}} = nil
if related == nil || related.R == nil {
return nil
}
{{.Function.Receiver}}.R.{{.Function.Name}} = nil
if related == nil || related.R == nil {
return nil
}
{{if .ForeignKey.Unique -}}
related.R.{{.Function.ForeignName}} = nil
{{else -}}
for i, ri := range related.R.{{.Function.ForeignName}} {
if {{.Function.Receiver}}.{{.Function.LocalAssignment}} != ri.{{.Function.LocalAssignment}} {
continue
}
{{if .ForeignKey.Unique -}}
related.R.{{.Function.ForeignName}} = nil
{{else -}}
for i, ri := range related.R.{{.Function.ForeignName}} {
if {{.Function.Receiver}}.{{.Function.LocalAssignment}} != ri.{{.Function.LocalAssignment}} {
continue
}
ln := len(related.R.{{.Function.ForeignName}})
if ln > 1 && i < ln-1 {
related.R.{{.Function.ForeignName}}[i] = related.R.{{.Function.ForeignName}}[ln-1]
}
related.R.{{.Function.ForeignName}} = related.R.{{.Function.ForeignName}}[:ln-1]
break
}
{{end -}}
ln := len(related.R.{{.Function.ForeignName}})
if ln > 1 && i < ln-1 {
related.R.{{.Function.ForeignName}}[i] = related.R.{{.Function.ForeignName}}[ln-1]
}
related.R.{{.Function.ForeignName}} = related.R.{{.Function.ForeignName}}[:ln-1]
break
}
{{end -}}
return nil
return nil
}
{{end -}}
{{- end -}}
{{- end -}}{{/* if foreignkey nullable */}}
{{end -}}{{/* end with */}}
{{- end -}}{{/* end define */}}
{{- /* Begin execution of template for one-to-one setops */ -}}
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_setops_helper" $rel -}}
{{- end -}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}}
{{- end -}}
{{- end -}}

View file

@ -1,91 +1,93 @@
{{- /* Begin execution of template for many-to-one or many-to-many setops */ -}}
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . -}}
{{- $table := .Table -}}
{{- range .Table.ToManyRelationships -}}
{{- $varNameSingular := .ForeignTable | singular | camelCase -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_setops_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}}
{{- else -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}}
{{- $localNameSingular := .Table | singular | camelCase -}}
{{- $foreignNameSingular := .ForeignTable | singular | camelCase}}
{{- $dot := . -}}
{{- $table := .Table -}}
{{- range .Table.ToManyRelationships -}}
{{- $varNameSingular := .ForeignTable | singular | camelCase -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- /* Begin execution of template for many-to-one setops */ -}}
{{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}}
{{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}}
{{- else -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}}
{{- $localNameSingular := .Table | singular | camelCase -}}
{{- $foreignNameSingular := .ForeignTable | singular | camelCase}}
// Add{{$rel.Function.Name}} adds the given related objects to the existing relationships
// of the {{$table.Name | singular}}, optionally inserting them as new records.
// Appends related to {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}.
// Sets related.R.{{$rel.Function.ForeignName}} appropriately.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error {
var err error
for _, rel := range related {
{{if not .ToJoinTable -}}
rel.{{$rel.Function.ForeignAssignment}} = {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}}
{{if .ForeignColumnNullable -}}
rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = true
{{end -}}
{{end -}}
if insert {
if err = rel.Insert(exec); err != nil {
return errors.Wrap(err, "failed to insert into foreign table")
}
}{{if not .ToJoinTable}} else {
if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil {
return errors.Wrap(err, "failed to update foreign table")
}
}{{end -}}
}
var err error
for _, rel := range related {
{{if not .ToJoinTable -}}
rel.{{$rel.Function.ForeignAssignment}} = {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}}
{{if .ForeignColumnNullable -}}
rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = true
{{end -}}
{{end -}}
if insert {
if err = rel.Insert(exec); err != nil {
return errors.Wrap(err, "failed to insert into foreign table")
}
}{{if not .ToJoinTable}} else {
if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil {
return errors.Wrap(err, "failed to update foreign table")
}
}{{end -}}
}
{{if .ToJoinTable -}}
for _, rel := range related {
query := `insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}}
{{if .ToJoinTable -}}
for _, rel := range related {
query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}"
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values)
}
_, err = exec.Exec(query, values...)
if err != nil {
return errors.Wrap(err, "failed to insert into join table")
}
}
{{end -}}
_, err = exec.Exec(query, values...)
if err != nil {
return errors.Wrap(err, "failed to insert into join table")
}
}
{{end -}}
if {{$rel.Function.Receiver}}.R == nil {
{{$rel.Function.Receiver}}.R = &{{$localNameSingular}}R{
{{$rel.Function.Name}}: related,
}
} else {
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = append({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}, related...)
}
if {{$rel.Function.Receiver}}.R == nil {
{{$rel.Function.Receiver}}.R = &{{$localNameSingular}}R{
{{$rel.Function.Name}}: related,
}
} else {
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = append({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}, related...)
}
{{if .ToJoinTable -}}
for _, rel := range related {
if rel.R == nil {
rel.R = &{{$foreignNameSingular}}R{
{{$rel.Function.ForeignName}}: {{$rel.LocalTable.NameGo}}Slice{{"{"}}{{$rel.Function.Receiver}}{{"}"}},
}
} else {
rel.R.{{$rel.Function.ForeignName}} = append(rel.R.{{$rel.Function.ForeignName}}, {{$rel.Function.Receiver}})
}
}
{{else -}}
for _, rel := range related {
if rel.R == nil {
rel.R = &{{$foreignNameSingular}}R{
{{$rel.Function.ForeignName}}: {{$rel.Function.Receiver}},
}
} else {
rel.R.{{$rel.Function.ForeignName}} = {{$rel.Function.Receiver}}
}
}
{{end -}}
{{if .ToJoinTable -}}
for _, rel := range related {
if rel.R == nil {
rel.R = &{{$foreignNameSingular}}R{
{{$rel.Function.ForeignName}}: {{$rel.LocalTable.NameGo}}Slice{{"{"}}{{$rel.Function.Receiver}}{{"}"}},
}
} else {
rel.R.{{$rel.Function.ForeignName}} = append(rel.R.{{$rel.Function.ForeignName}}, {{$rel.Function.Receiver}})
}
}
{{else -}}
for _, rel := range related {
if rel.R == nil {
rel.R = &{{$foreignNameSingular}}R{
{{$rel.Function.ForeignName}}: {{$rel.Function.Receiver}},
}
} else {
rel.R.{{$rel.Function.ForeignName}} = {{$rel.Function.Receiver}}
}
}
{{end -}}
return nil
return nil
}
{{- if (or .ForeignColumnNullable .ToJoinTable)}}
{{- if (or .ForeignColumnNullable .ToJoinTable)}}
// Set{{$rel.Function.Name}} removes all previously related items of the
// {{$table.Name | singular}} replacing them completely with the passed
// in related items, optionally inserting them as new records.
@ -93,126 +95,126 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function
// Replaces {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} with related.
// Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error {
{{if .ToJoinTable -}}
query := `delete from "{{.JoinTable}}" where "{{.JoinLocalColumn}}" = $1`
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}}
{{else -}}
query := `update "{{.ForeignTable}}" set "{{.ForeignColumn}}" = null where "{{.ForeignColumn}}" = $1`
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}}
{{end -}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values)
}
{{if .ToJoinTable -}}
query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}"
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}}
{{else -}}
query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}"
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}}
{{end -}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values)
}
_, err := exec.Exec(query, values...)
if err != nil {
return errors.Wrap(err, "failed to remove relationships before set")
}
_, err := exec.Exec(query, values...)
if err != nil {
return errors.Wrap(err, "failed to remove relationships before set")
}
{{if .ToJoinTable -}}
remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related)
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil
{{else -}}
if {{$rel.Function.Receiver}}.R != nil {
for _, rel := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} {
rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false
if rel.R == nil {
continue
}
{{if .ToJoinTable -}}
remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related)
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil
{{else -}}
if {{$rel.Function.Receiver}}.R != nil {
for _, rel := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} {
rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false
if rel.R == nil {
continue
}
rel.R.{{$rel.Function.ForeignName}} = nil
}
rel.R.{{$rel.Function.ForeignName}} = nil
}
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil
}
{{end -}}
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil
}
{{end -}}
return {{$rel.Function.Receiver}}.Add{{$rel.Function.Name}}(exec, insert, related...)
return {{$rel.Function.Receiver}}.Add{{$rel.Function.Name}}(exec, insert, related...)
}
// Remove{{$rel.Function.Name}} relationships from objects passed in.
// Removes related items from R.{{$rel.Function.Name}} (uses pointer comparison, removal does not keep order)
// Sets related.R.{{$rel.Function.ForeignName}}.
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Function.Name}}(exec boil.Executor, related ...*{{$rel.ForeignTable.NameGo}}) error {
var err error
{{if .ToJoinTable -}}
query := fmt.Sprintf(
`delete from "{{.JoinTable}}" where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`,
strmangle.Placeholders(len(related), 1, 1),
)
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}}
var err error
{{if .ToJoinTable -}}
query := fmt.Sprintf(
"delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}} and {{.JoinForeignColumn | $dot.Quotes}} in (%s)",
strmangle.Placeholders(dialect.IndexPlaceholders, len(related), 1, 1),
)
values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, values)
}
_, err = exec.Exec(query, values...)
if err != nil {
return errors.Wrap(err, "failed to remove relationships before set")
}
{{else -}}
for _, rel := range related {
rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false
{{if not .ToJoinTable -}}
if rel.R != nil {
rel.R.{{$rel.Function.ForeignName}} = nil
}
{{end -}}
if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil {
return err
}
}
{{end -}}
_, err = exec.Exec(query, values...)
if err != nil {
return errors.Wrap(err, "failed to remove relationships before set")
}
{{else -}}
for _, rel := range related {
rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false
{{if not .ToJoinTable -}}
if rel.R != nil {
rel.R.{{$rel.Function.ForeignName}} = nil
}
{{end -}}
if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil {
return err
}
}
{{end -}}
{{if .ToJoinTable -}}
remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related)
{{end -}}
if {{$rel.Function.Receiver}}.R == nil {
return nil
}
{{if .ToJoinTable -}}
remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related)
{{end -}}
if {{$rel.Function.Receiver}}.R == nil {
return nil
}
for _, rel := range related {
for i, ri := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} {
if rel != ri {
continue
}
for _, rel := range related {
for i, ri := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} {
if rel != ri {
continue
}
ln := len({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}})
if ln > 1 && i < ln-1 {
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[i] = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[ln-1]
}
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[:ln-1]
break
}
}
ln := len({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}})
if ln > 1 && i < ln-1 {
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[i] = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[ln-1]
}
{{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[:ln-1]
break
}
}
return nil
return nil
}
{{if .ToJoinTable -}}
{{if .ToJoinTable -}}
func remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}, related []*{{$rel.ForeignTable.NameGo}}) {
for _, rel := range related {
if rel.R == nil {
continue
}
for i, ri := range rel.R.{{$rel.Function.ForeignName}} {
if {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} != ri.{{$rel.Function.LocalAssignment}} {
continue
}
for _, rel := range related {
if rel.R == nil {
continue
}
for i, ri := range rel.R.{{$rel.Function.ForeignName}} {
if {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} != ri.{{$rel.Function.LocalAssignment}} {
continue
}
ln := len(rel.R.{{$rel.Function.ForeignName}})
if ln > 1 && i < ln-1 {
rel.R.{{$rel.Function.ForeignName}}[i] = rel.R.{{$rel.Function.ForeignName}}[ln-1]
}
rel.R.{{$rel.Function.ForeignName}} = rel.R.{{$rel.Function.ForeignName}}[:ln-1]
break
}
}
ln := len(rel.R.{{$rel.Function.ForeignName}})
if ln > 1 && i < ln-1 {
rel.R.{{$rel.Function.ForeignName}}[i] = rel.R.{{$rel.Function.ForeignName}}[ln-1]
}
rel.R.{{$rel.Function.ForeignName}} = rel.R.{{$rel.Function.ForeignName}}[:ln-1]
break
}
}
}
{{end -}}{{- /* if join table */ -}}
{{- end -}}{{- /* if nullable foreign key */ -}}
{{- end -}}{{- /* if unique foreign key */ -}}
{{- end -}}{{- /* range relationships */ -}}
{{- end -}}{{- /* outer if join table */ -}}
{{end -}}{{- /* if ToJoinTable */ -}}
{{- end -}}{{- /* if nullable foreign key */ -}}
{{- end -}}{{- /* if unique foreign key */ -}}
{{- end -}}{{- /* range relationships */ -}}
{{- end -}}{{- /* if IsJoinTable */ -}}

View file

@ -3,11 +3,11 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
// {{$tableNamePlural}}G retrieves all records.
func {{$tableNamePlural}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query {
return {{$tableNamePlural}}(boil.GetDB(), mods...)
return {{$tableNamePlural}}(boil.GetDB(), mods...)
}
// {{$tableNamePlural}} retrieves all the records using an executor.
func {{$tableNamePlural}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query {
mods = append(mods, qm.From("{{.Table.Name}}"))
return {{$varNameSingular}}Query{NewQuery(exec, mods...)}
mods = append(mods, qm.From("{{.Table.Name | .SchemaTable}}"))
return {{$varNameSingular}}Query{NewQuery(exec, mods...)}
}

View file

@ -4,53 +4,53 @@
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}}
// {{$tableNameSingular}}FindG retrieves a single record by ID.
// Find{{$tableNameSingular}}G retrieves a single record by ID.
func Find{{$tableNameSingular}}G({{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) {
return Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...)
return Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...)
}
// {{$tableNameSingular}}FindGP retrieves a single record by ID, and panics on error.
// Find{{$tableNameSingular}}GP retrieves a single record by ID, and panics on error.
func Find{{$tableNameSingular}}GP({{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} {
retobj, err := Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...)
if err != nil {
panic(boil.WrapErr(err))
}
retobj, err := Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...)
if err != nil {
panic(boil.WrapErr(err))
}
return retobj
return retobj
}
// {{$tableNameSingular}}Find retrieves a single record by ID with an executor.
// Find{{$tableNameSingular}} retrieves a single record by ID with an executor.
// If selectCols is empty Find will return all columns.
func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) {
{{$varNameSingular}}Obj := &{{$tableNameSingular}}{}
{{$varNameSingular}}Obj := &{{$tableNameSingular}}{}
sel := "*"
if len(selectCols) > 0 {
sel = strings.Join(strmangle.IdentQuoteSlice(selectCols), ",")
}
query := fmt.Sprintf(
`select %s from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}}`, sel,
)
sel := "*"
if len(selectCols) > 0 {
sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",")
}
query := fmt.Sprintf(
"select %s from {{.Table.Name | .SchemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}", sel,
)
q := boil.SQL(exec, query, {{$pkNames | join ", "}})
q := queries.Raw(exec, query, {{$pkNames | join ", "}})
err := q.Bind({{$varNameSingular}}Obj)
if err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, sql.ErrNoRows
}
return nil, errors.Wrap(err, "{{.PkgName}}: unable to select from {{.Table.Name}}")
}
err := q.Bind({{$varNameSingular}}Obj)
if err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, sql.ErrNoRows
}
return nil, errors.Wrap(err, "{{.PkgName}}: unable to select from {{.Table.Name}}")
}
return {{$varNameSingular}}Obj, nil
return {{$varNameSingular}}Obj, nil
}
// {{$tableNameSingular}}FindP retrieves a single record by ID with an executor, and panics on error.
// Find{{$tableNameSingular}}P retrieves a single record by ID with an executor, and panics on error.
func Find{{$tableNameSingular}}P(exec boil.Executor, {{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} {
retobj, err := Find{{$tableNameSingular}}(exec, {{$pkNames | join ", "}}, selectCols...)
if err != nil {
panic(boil.WrapErr(err))
}
retobj, err := Find{{$tableNameSingular}}(exec, {{$pkNames | join ", "}}, selectCols...)
if err != nil {
panic(boil.WrapErr(err))
}
return retobj
return retobj
}

View file

@ -1,24 +1,25 @@
{{- $tableNameSingular := .Table.Name | singular | titleCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// InsertG a single record. See Insert for whitelist behavior description.
func (o *{{$tableNameSingular}}) InsertG(whitelist ... string) error {
return o.Insert(boil.GetDB(), whitelist...)
return o.Insert(boil.GetDB(), whitelist...)
}
// InsertGP a single record, and panics on error. See Insert for whitelist
// behavior description.
func (o *{{$tableNameSingular}}) InsertGP(whitelist ... string) {
if err := o.Insert(boil.GetDB(), whitelist...); err != nil {
panic(boil.WrapErr(err))
}
if err := o.Insert(boil.GetDB(), whitelist...); err != nil {
panic(boil.WrapErr(err))
}
}
// InsertP a single record using an executor, and panics on error. See Insert
// for whitelist behavior description.
func (o *{{$tableNameSingular}}) InsertP(exec boil.Executor, whitelist ... string) {
if err := o.Insert(exec, whitelist...); err != nil {
panic(boil.WrapErr(err))
}
if err := o.Insert(exec, whitelist...); err != nil {
panic(boil.WrapErr(err))
}
}
// Insert a single record using an executor.
@ -27,115 +28,131 @@ func (o *{{$tableNameSingular}}) InsertP(exec boil.Executor, whitelist ... strin
// - All columns without a default value are included (i.e. name, age)
// - All columns with a default, but non-zero are included (i.e. health = 75)
func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string) error {
if o == nil {
return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
}
if o == nil {
return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
}
var err error
{{- template "timestamp_insert_helper" . }}
var err error
{{- template "timestamp_insert_helper" . }}
{{if not .NoHooks -}}
if err := o.doBeforeInsertHooks(exec); err != nil {
return err
}
{{- end}}
{{if not .NoHooks -}}
if err := o.doBeforeInsertHooks(exec); err != nil {
return err
}
{{- end}}
nzDefaults := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o)
nzDefaults := queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o)
key := makeCacheKey(whitelist, nzDefaults)
{{$varNameSingular}}InsertCacheMut.RLock()
cache, cached := {{$varNameSingular}}InsertCache[key]
{{$varNameSingular}}InsertCacheMut.RUnlock()
key := makeCacheKey(whitelist, nzDefaults)
{{$varNameSingular}}InsertCacheMut.RLock()
cache, cached := {{$varNameSingular}}InsertCache[key]
{{$varNameSingular}}InsertCacheMut.RUnlock()
if !cached {
wl, returnColumns := strmangle.InsertColumnSet(
{{$varNameSingular}}Columns,
{{$varNameSingular}}ColumnsWithDefault,
{{$varNameSingular}}ColumnsWithoutDefault,
nzDefaults,
whitelist,
)
if !cached {
wl, returnColumns := strmangle.InsertColumnSet(
{{$varNameSingular}}Columns,
{{$varNameSingular}}ColumnsWithDefault,
{{$varNameSingular}}ColumnsWithoutDefault,
nzDefaults,
whitelist,
)
cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl)
if err != nil {
return err
}
cache.retMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns)
if err != nil {
return err
}
cache.query = fmt.Sprintf(`INSERT INTO {{.Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), strmangle.Placeholders(len(wl), 1, 1))
cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl)
if err != nil {
return err
}
cache.retMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns)
if err != nil {
return err
}
cache.query = fmt.Sprintf("INSERT INTO {{$schemaTable}} ({{.LQ}}%s{{.RQ}}) VALUES (%s)", strings.Join(wl, "{{.LQ}},{{.RQ}}"), strmangle.Placeholders(dialect.IndexPlaceholders, len(wl), 1, 1))
if len(cache.retMapping) != 0 {
{{if .UseLastInsertID -}}
cache.retQuery = fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, {{$varNameSingular}}PrimaryKeyColumns))
{{else -}}
cache.query += fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ","))
{{end -}}
}
}
if len(cache.retMapping) != 0 {
{{if .UseLastInsertID -}}
cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns))
{{else -}}
cache.query += fmt.Sprintf(" RETURNING {{.LQ}}%s{{.RQ}}", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"))
{{end -}}
}
}
value := reflect.Indirect(reflect.ValueOf(o))
vals := boil.ValuesFromMapping(value, cache.valueMapping)
{{if .UseLastInsertID}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, vals)
}
value := reflect.Indirect(reflect.ValueOf(o))
vals := queries.ValuesFromMapping(value, cache.valueMapping)
{{if .UseLastInsertID}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, vals)
}
result, err := exec.Exec(ins, vals...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}")
}
result, err := exec.Exec(cache.query, vals...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}")
}
var lastID int64
var identifierCols []interface{}
if len(cache.retMapping) == 0 {
{{if not .NoHooks -}}
return o.doAfterInsertHooks(exec)
{{else -}}
return nil
{{end -}}
}
if len(cache.retMapping) == 0 {
goto CacheNoHooks
}
lastID, err := result.LastInsertId()
if err != nil || lastID == 0 || len({{$varNameSingular}}PrimaryKeyColumns) != 1 {
return ErrSyncFail
}
lastID, err = result.LastInsertId()
if err != nil {
return ErrSyncFail
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.retQuery)
fmt.Fprintln(boil.DebugWriter, lastID)
}
if lastID != 0 {
{{- $colName := index .Table.PKey.Columns 0 -}}
{{- $col := .Table.GetColumn $colName -}}
o.{{$colName | singular | titleCase}} = {{$col.Type}}(lastID)
identifierCols = []interface{}{lastID}
} else {
identifierCols = []interface{}{
{{range .Table.PKey.Columns -}}
o.{{. | singular | titleCase}},
{{end -}}
}
}
err = exec.QueryRow(cache.retQuery, lastID).Scan(boil.PtrsFromMapping(value, cache.retMapping)...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}")
}
{{else}}
if len(cache.retMapping) != 0 {
err = exec.QueryRow(cache.query, vals...).Scan(boil.PtrsFromMapping(value, cache.retMapping)...)
} else {
_, err = exec.Exec(cache.query, vals...)
}
if lastID != 0 && len(cache.retMapping) == 1 {
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.retQuery)
fmt.Fprintln(boil.DebugWriter, identifierCols...)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, vals)
}
err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(queries.PtrsFromMapping(value, cache.retMapping)...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}")
}
}
{{else}}
if len(cache.retMapping) != 0 {
err = exec.QueryRow(cache.query, vals...).Scan(queries.PtrsFromMapping(value, cache.retMapping)...)
} else {
_, err = exec.Exec(cache.query, vals...)
}
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}")
}
{{end}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, vals)
}
if !cached {
{{$varNameSingular}}InsertCacheMut.Lock()
{{$varNameSingular}}InsertCache[key] = cache
{{$varNameSingular}}InsertCacheMut.Unlock()
}
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}")
}
{{end}}
{{if .UseLastInsertID -}}
CacheNoHooks:
{{- end}}
if !cached {
{{$varNameSingular}}InsertCacheMut.Lock()
{{$varNameSingular}}InsertCache[key] = cache
{{$varNameSingular}}InsertCacheMut.Unlock()
}
{{if not .NoHooks -}}
return o.doAfterInsertHooks(exec)
{{- else -}}
return nil
{{- end}}
{{if not .NoHooks -}}
return o.doAfterInsertHooks(exec)
{{- else -}}
return nil
{{- end}}
}

View file

@ -3,28 +3,29 @@
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// UpdateG a single {{$tableNameSingular}} record. See Update for
// whitelist behavior description.
func (o *{{$tableNameSingular}}) UpdateG(whitelist ...string) error {
return o.Update(boil.GetDB(), whitelist...)
return o.Update(boil.GetDB(), whitelist...)
}
// UpdateGP a single {{$tableNameSingular}} record.
// UpdateGP takes a whitelist of column names that should be updated.
// Panics on error. See Update for whitelist behavior description.
func (o *{{$tableNameSingular}}) UpdateGP(whitelist ...string) {
if err := o.Update(boil.GetDB(), whitelist...); err != nil {
panic(boil.WrapErr(err))
}
if err := o.Update(boil.GetDB(), whitelist...); err != nil {
panic(boil.WrapErr(err))
}
}
// UpdateP uses an executor to update the {{$tableNameSingular}}, and panics on error.
// See Update for whitelist behavior description.
func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... string) {
err := o.Update(exec, whitelist...)
if err != nil {
panic(boil.WrapErr(err))
}
err := o.Update(exec, whitelist...)
if err != nil {
panic(boil.WrapErr(err))
}
}
// Update uses an executor to update the {{$tableNameSingular}}.
@ -35,146 +36,147 @@ func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... strin
// Update does not automatically update the record in case of default values. Use .Reload()
// to refresh the records.
func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string) error {
{{- template "timestamp_update_helper" . -}}
{{- template "timestamp_update_helper" . -}}
var err error
{{if not .NoHooks -}}
if err = o.doBeforeUpdateHooks(exec); err != nil {
return err
}
{{end -}}
var err error
{{if not .NoHooks -}}
if err = o.doBeforeUpdateHooks(exec); err != nil {
return err
}
{{end -}}
key := makeCacheKey(whitelist, nil)
{{$varNameSingular}}UpdateCacheMut.RLock()
cache, cached := {{$varNameSingular}}UpdateCache[key]
{{$varNameSingular}}UpdateCacheMut.RUnlock()
key := makeCacheKey(whitelist, nil)
{{$varNameSingular}}UpdateCacheMut.RLock()
cache, cached := {{$varNameSingular}}UpdateCache[key]
{{$varNameSingular}}UpdateCacheMut.RUnlock()
if !cached {
wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist)
if !cached {
wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist)
cache.query = fmt.Sprintf(`UPDATE "{{.Table.Name}}" SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns))
cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...))
if err != nil {
return err
}
}
cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s",
strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, wl),
strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}len(wl)+1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns),
)
cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...))
if err != nil {
return err
}
}
if len(cache.valueMapping) == 0 {
return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist")
}
if len(cache.valueMapping) == 0 {
return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist")
}
values := boil.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping)
values := queries.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping)
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, values)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, values)
}
result, err := exec.Exec(cache.query, values...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to update {{.Table.Name}} row")
}
result, err := exec.Exec(cache.query, values...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to update {{.Table.Name}} row")
}
if r, err := result.RowsAffected(); err == nil && r != 1 {
return errors.Errorf("failed to update single row, updated %d rows", r)
}
if r, err := result.RowsAffected(); err == nil && r != 1 {
return errors.Errorf("failed to update single row, updated %d rows", r)
}
if !cached {
{{$varNameSingular}}UpdateCacheMut.Lock()
{{$varNameSingular}}UpdateCache[key] = cache
{{$varNameSingular}}UpdateCacheMut.Unlock()
}
if !cached {
{{$varNameSingular}}UpdateCacheMut.Lock()
{{$varNameSingular}}UpdateCache[key] = cache
{{$varNameSingular}}UpdateCacheMut.Unlock()
}
{{if not .NoHooks -}}
return o.doAfterUpdateHooks(exec)
{{- else -}}
return nil
{{- end}}
{{if not .NoHooks -}}
return o.doAfterUpdateHooks(exec)
{{- else -}}
return nil
{{- end}}
}
// UpdateAllP updates all rows with matching column names, and panics on error.
func (q {{$varNameSingular}}Query) UpdateAllP(cols M) {
if err := q.UpdateAll(cols); err != nil {
panic(boil.WrapErr(err))
}
if err := q.UpdateAll(cols); err != nil {
panic(boil.WrapErr(err))
}
}
// UpdateAll updates all rows with the specified column values.
func (q {{$varNameSingular}}Query) UpdateAll(cols M) error {
boil.SetUpdate(q.Query, cols)
queries.SetUpdate(q.Query, cols)
_, err := boil.ExecQuery(q.Query)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}")
}
_, err := q.Query.Exec()
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}")
}
return nil
return nil
}
// UpdateAllG updates all rows with the specified column values.
func (o {{$tableNameSingular}}Slice) UpdateAllG(cols M) error {
return o.UpdateAll(boil.GetDB(), cols)
return o.UpdateAll(boil.GetDB(), cols)
}
// UpdateAllGP updates all rows with the specified column values, and panics on error.
func (o {{$tableNameSingular}}Slice) UpdateAllGP(cols M) {
if err := o.UpdateAll(boil.GetDB(), cols); err != nil {
panic(boil.WrapErr(err))
}
if err := o.UpdateAll(boil.GetDB(), cols); err != nil {
panic(boil.WrapErr(err))
}
}
// UpdateAllP updates all rows with the specified column values, and panics on error.
func (o {{$tableNameSingular}}Slice) UpdateAllP(exec boil.Executor, cols M) {
if err := o.UpdateAll(exec, cols); err != nil {
panic(boil.WrapErr(err))
}
if err := o.UpdateAll(exec, cols); err != nil {
panic(boil.WrapErr(err))
}
}
// UpdateAll updates all rows with the specified column values, using an executor.
func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error {
ln := int64(len(o))
if ln == 0 {
return nil
}
ln := int64(len(o))
if ln == 0 {
return nil
}
if len(cols) == 0 {
return errors.New("{{.PkgName}}: update all requires at least one column argument")
}
if len(cols) == 0 {
return errors.New("{{.PkgName}}: update all requires at least one column argument")
}
colNames := make([]string, len(cols))
args := make([]interface{}, len(cols))
colNames := make([]string, len(cols))
args := make([]interface{}, len(cols))
i := 0
for name, value := range cols {
colNames[i] = strmangle.IdentQuote(name)
args[i] = value
i++
}
i := 0
for name, value := range cols {
colNames[i] = name
args[i] = value
i++
}
// Append all of the primary key values for each column
args = append(args, o.inPrimaryKeyArgs()...)
// Append all of the primary key values for each column
args = append(args, o.inPrimaryKeyArgs()...)
sql := fmt.Sprintf(
`UPDATE {{.Table.Name}} SET (%s) = (%s) WHERE (%s) IN (%s)`,
strings.Join(colNames, ", "),
strmangle.Placeholders(len(colNames), 1, 1),
strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","),
strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
sql := fmt.Sprintf(
"UPDATE {{$schemaTable}} SET %s WHERE ({{.LQ}}{{.Table.PKey.Columns | join (printf "%s,%s" .LQ .RQ)}}{{.RQ}}) IN (%s)",
strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, colNames),
strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args...)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args...)
}
result, err := exec.Exec(sql, args...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to update all in {{$varNameSingular}} slice")
}
result, err := exec.Exec(sql, args...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to update all in {{$varNameSingular}} slice")
}
if r, err := result.RowsAffected(); err == nil && r != ln {
return errors.Errorf("failed to update %d rows, only affected %d", ln, r)
}
if r, err := result.RowsAffected(); err == nil && r != ln {
return errors.Errorf("failed to update %d rows, only affected %d", ln, r)
}
return nil
return nil
}

View file

@ -1,85 +1,188 @@
{{- $tableNameSingular := .Table.Name | singular | titleCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// UpsertG attempts an insert, and does an update or ignore on conflict.
func (o *{{$tableNameSingular}}) UpsertG(updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) error {
return o.Upsert(boil.GetDB(), updateOnConflict, conflictColumns, updateColumns, whitelist...)
func (o *{{$tableNameSingular}}) UpsertG({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error {
return o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...)
}
// UpsertGP attempts an insert, and does an update or ignore on conflict. Panics on error.
func (o *{{$tableNameSingular}}) UpsertGP(updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) {
if err := o.Upsert(boil.GetDB(), updateOnConflict, conflictColumns, updateColumns, whitelist...); err != nil {
panic(boil.WrapErr(err))
}
func (o *{{$tableNameSingular}}) UpsertGP({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) {
if err := o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil {
panic(boil.WrapErr(err))
}
}
// UpsertP attempts an insert using an executor, and does an update or ignore on conflict.
// UpsertP panics on error.
func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) {
if err := o.Upsert(exec, updateOnConflict, conflictColumns, updateColumns, whitelist...); err != nil {
panic(boil.WrapErr(err))
}
func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) {
if err := o.Upsert(exec, {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil {
panic(boil.WrapErr(err))
}
}
// Upsert attempts an insert using an executor, and does an update or ignore on conflict.
func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) error {
if o == nil {
return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert")
}
func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error {
if o == nil {
return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert")
}
{{- template "timestamp_upsert_helper" . }}
{{- template "timestamp_upsert_helper" . }}
{{if not .NoHooks -}}
if err := o.doBeforeUpsertHooks(exec); err != nil {
return err
}
{{- end}}
{{if not .NoHooks -}}
if err := o.doBeforeUpsertHooks(exec); err != nil {
return err
}
{{- end}}
var err error
var ret []string
whitelist, ret = strmangle.InsertColumnSet(
{{$varNameSingular}}Columns,
{{$varNameSingular}}ColumnsWithDefault,
{{$varNameSingular}}ColumnsWithoutDefault,
boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o),
whitelist,
)
update := strmangle.UpdateColumnSet(
{{$varNameSingular}}Columns,
{{$varNameSingular}}PrimaryKeyColumns,
updateColumns,
)
conflict := conflictColumns
if len(conflict) == 0 {
conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns))
copy(conflict, {{$varNameSingular}}PrimaryKeyColumns)
}
// Build cache key in-line uglily - mysql vs postgres problems
buf := strmangle.GetBuffer()
{{if ne .DriverName "mysql" -}}
if updateOnConflict {
buf.WriteByte('t')
} else {
buf.WriteByte('f')
}
buf.WriteByte('.')
for _, c := range conflictColumns {
buf.WriteString(c)
}
buf.WriteByte('.')
{{end -}}
for _, c := range updateColumns {
buf.WriteString(c)
}
buf.WriteByte('.')
for _, c := range whitelist {
buf.WriteString(c)
}
key := buf.String()
strmangle.PutBuffer(buf)
query := boil.BuildUpsertQuery("{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist)
{{$varNameSingular}}UpsertCacheMut.RLock()
cache, cached := {{$varNameSingular}}UpsertCache[key]
{{$varNameSingular}}UpsertCacheMut.RUnlock()
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...))
}
var err error
{{- if .UseLastInsertID}}
return errors.New("don't know how to do this yet")
{{- else}}
if len(ret) != 0 {
err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...)
} else {
_, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...)
}
{{- end}}
if !cached {
var ret []string
whitelist, ret = strmangle.InsertColumnSet(
{{$varNameSingular}}Columns,
{{$varNameSingular}}ColumnsWithDefault,
{{$varNameSingular}}ColumnsWithoutDefault,
queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o),
whitelist,
)
update := strmangle.UpdateColumnSet(
{{$varNameSingular}}Columns,
{{$varNameSingular}}PrimaryKeyColumns,
updateColumns,
)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}")
}
{{if ne .DriverName "mysql" -}}
var conflict []string
if len(conflictColumns) == 0 {
conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns))
copy(conflict, {{$varNameSingular}}PrimaryKeyColumns)
}
cache.query = queries.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, whitelist)
{{- else -}}
cache.query = queries.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist)
cache.retQuery = fmt.Sprintf(
"SELECT %s FROM {{.LQ}}{{.Table.Name}}{{.RQ}} WHERE {{whereClause .LQ .RQ 0 .Table.PKey.Columns}}",
strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, ret), ","),
)
{{- end}}
{{if not .NoHooks -}}
if err := o.doAfterUpsertHooks(exec); err != nil {
return err
}
{{- end}}
cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, whitelist)
if err != nil {
return err
}
if len(ret) != 0 {
cache.retMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, ret)
if err != nil {
return err
}
}
}
return nil
value := reflect.Indirect(reflect.ValueOf(o))
values := queries.ValuesFromMapping(value, cache.valueMapping)
var returns []interface{}
if len(cache.retMapping) != 0 {
returns = queries.PtrsFromMapping(value, cache.retMapping)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.query)
fmt.Fprintln(boil.DebugWriter, values)
}
{{- if .UseLastInsertID}}
result, err := exec.Exec(cache.query, values...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}")
}
if len(cache.retMapping) == 0 {
{{if not .NoHooks -}}
return o.doAfterUpsertHooks(exec)
{{else -}}
return nil
{{end -}}
}
lastID, err := result.LastInsertId()
if err != nil {
return ErrSyncFail
}
var identifierCols []interface{}
if lastID != 0 {
{{- $colName := index .Table.PKey.Columns 0 -}}
{{- $col := .Table.GetColumn $colName -}}
o.{{$colName | singular | titleCase}} = {{$col.Type}}(lastID)
identifierCols = []interface{}{lastID}
} else {
identifierCols = []interface{}{
{{range .Table.PKey.Columns -}}
o.{{. | singular | titleCase}},
{{end -}}
}
}
if lastID != 0 && len(cache.retMapping) == 1 {
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, cache.retQuery)
fmt.Fprintln(boil.DebugWriter, identifierCols...)
}
err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(returns...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}")
}
}
{{- else}}
if len(cache.retMapping) != 0 {
err = exec.QueryRow(cache.query, values...).Scan(returns...)
} else {
_, err = exec.Exec(cache.query, values...)
}
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}")
}
{{- end}}
if !cached {
{{$varNameSingular}}UpsertCacheMut.Lock()
{{$varNameSingular}}UpsertCache[key] = cache
{{$varNameSingular}}UpsertCacheMut.Unlock()
}
{{if not .NoHooks -}}
return o.doAfterUpsertHooks(exec)
{{- else -}}
return nil
{{- end}}
}

View file

@ -1,161 +1,162 @@
{{- $tableNameSingular := .Table.Name | singular | titleCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// DeleteP deletes a single {{$tableNameSingular}} record with an executor.
// DeleteP will match against the primary key column to find the record to delete.
// Panics on error.
func (o *{{$tableNameSingular}}) DeleteP(exec boil.Executor) {
if err := o.Delete(exec); err != nil {
panic(boil.WrapErr(err))
}
if err := o.Delete(exec); err != nil {
panic(boil.WrapErr(err))
}
}
// DeleteG deletes a single {{$tableNameSingular}} record.
// DeleteG will match against the primary key column to find the record to delete.
func (o *{{$tableNameSingular}}) DeleteG() error {
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for deletion")
}
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for deletion")
}
return o.Delete(boil.GetDB())
return o.Delete(boil.GetDB())
}
// DeleteGP deletes a single {{$tableNameSingular}} record.
// DeleteGP will match against the primary key column to find the record to delete.
// Panics on error.
func (o *{{$tableNameSingular}}) DeleteGP() {
if err := o.DeleteG(); err != nil {
panic(boil.WrapErr(err))
}
if err := o.DeleteG(); err != nil {
panic(boil.WrapErr(err))
}
}
// Delete deletes a single {{$tableNameSingular}} record with an executor.
// Delete will match against the primary key column to find the record to delete.
func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error {
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for delete")
}
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for delete")
}
{{if not .NoHooks -}}
if err := o.doBeforeDeleteHooks(exec); err != nil {
return err
}
{{- end}}
{{if not .NoHooks -}}
if err := o.doBeforeDeleteHooks(exec); err != nil {
return err
}
{{- end}}
args := o.inPrimaryKeyArgs()
args := o.inPrimaryKeyArgs()
sql := `DELETE FROM {{.Table.Name}} WHERE {{whereClause 1 .Table.PKey.Columns}}`
sql := "DELETE FROM {{$schemaTable}} WHERE {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}"
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args...)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args...)
}
_, err := exec.Exec(sql, args...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to delete from {{.Table.Name}}")
}
_, err := exec.Exec(sql, args...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to delete from {{.Table.Name}}")
}
{{if not .NoHooks -}}
if err := o.doAfterDeleteHooks(exec); err != nil {
return err
}
{{- end}}
{{if not .NoHooks -}}
if err := o.doAfterDeleteHooks(exec); err != nil {
return err
}
{{- end}}
return nil
return nil
}
// DeleteAllP deletes all rows, and panics on error.
func (q {{$varNameSingular}}Query) DeleteAllP() {
if err := q.DeleteAll(); err != nil {
panic(boil.WrapErr(err))
}
if err := q.DeleteAll(); err != nil {
panic(boil.WrapErr(err))
}
}
// DeleteAll deletes all matching rows.
func (q {{$varNameSingular}}Query) DeleteAll() error {
if q.Query == nil {
return errors.New("{{.PkgName}}: no {{$varNameSingular}}Query provided for delete all")
}
if q.Query == nil {
return errors.New("{{.PkgName}}: no {{$varNameSingular}}Query provided for delete all")
}
boil.SetDelete(q.Query)
queries.SetDelete(q.Query)
_, err := boil.ExecQuery(q.Query)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}")
}
_, err := q.Query.Exec()
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}")
}
return nil
return nil
}
// DeleteAll deletes all rows in the slice, and panics on error.
// DeleteAllGP deletes all rows in the slice, and panics on error.
func (o {{$tableNameSingular}}Slice) DeleteAllGP() {
if err := o.DeleteAllG(); err != nil {
panic(boil.WrapErr(err))
}
if err := o.DeleteAllG(); err != nil {
panic(boil.WrapErr(err))
}
}
// DeleteAllG deletes all rows in the slice.
func (o {{$tableNameSingular}}Slice) DeleteAllG() error {
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all")
}
return o.DeleteAll(boil.GetDB())
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all")
}
return o.DeleteAll(boil.GetDB())
}
// DeleteAllP deletes all rows in the slice, using an executor, and panics on error.
func (o {{$tableNameSingular}}Slice) DeleteAllP(exec boil.Executor) {
if err := o.DeleteAll(exec); err != nil {
panic(boil.WrapErr(err))
}
if err := o.DeleteAll(exec); err != nil {
panic(boil.WrapErr(err))
}
}
// DeleteAll deletes all rows in the slice, using an executor.
func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error {
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all")
}
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all")
}
if len(o) == 0 {
return nil
}
if len(o) == 0 {
return nil
}
{{if not .NoHooks -}}
if len({{$varNameSingular}}BeforeDeleteHooks) != 0 {
for _, obj := range o {
if err := obj.doBeforeDeleteHooks(exec); err != nil {
return err
}
}
}
{{- end}}
{{if not .NoHooks -}}
if len({{$varNameSingular}}BeforeDeleteHooks) != 0 {
for _, obj := range o {
if err := obj.doBeforeDeleteHooks(exec); err != nil {
return err
}
}
}
{{- end}}
args := o.inPrimaryKeyArgs()
args := o.inPrimaryKeyArgs()
sql := fmt.Sprintf(
`DELETE FROM {{.Table.Name}} WHERE (%s) IN (%s)`,
strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","),
strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
sql := fmt.Sprintf(
"DELETE FROM {{$schemaTable}} WHERE (%s) IN (%s)",
strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","),
strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args)
}
_, err := exec.Exec(sql, args...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{$varNameSingular}} slice")
}
_, err := exec.Exec(sql, args...)
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{$varNameSingular}} slice")
}
{{if not .NoHooks -}}
if len({{$varNameSingular}}AfterDeleteHooks) != 0 {
for _, obj := range o {
if err := obj.doAfterDeleteHooks(exec); err != nil {
return err
}
}
}
{{- end}}
{{if not .NoHooks -}}
if len({{$varNameSingular}}AfterDeleteHooks) != 0 {
for _, obj := range o {
if err := obj.doAfterDeleteHooks(exec); err != nil {
return err
}
}
}
{{- end}}
return nil
return nil
}

View file

@ -1,85 +1,94 @@
{{- $tableNameSingular := .Table.Name | singular | titleCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// ReloadGP refetches the object from the database and panics on error.
func (o *{{$tableNameSingular}}) ReloadGP() {
if err := o.ReloadG(); err != nil {
panic(boil.WrapErr(err))
}
if err := o.ReloadG(); err != nil {
panic(boil.WrapErr(err))
}
}
// ReloadP refetches the object from the database with an executor. Panics on error.
func (o *{{$tableNameSingular}}) ReloadP(exec boil.Executor) {
if err := o.Reload(exec); err != nil {
panic(boil.WrapErr(err))
}
if err := o.Reload(exec); err != nil {
panic(boil.WrapErr(err))
}
}
// ReloadG refetches the object from the database using the primary keys.
func (o *{{$tableNameSingular}}) ReloadG() error {
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for reload")
}
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for reload")
}
return o.Reload(boil.GetDB())
return o.Reload(boil.GetDB())
}
// Reload refetches the object from the database
// using the primary keys with an executor.
func (o *{{$tableNameSingular}}) Reload(exec boil.Executor) error {
ret, err := Find{{$tableNameSingular}}(exec, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
if err != nil {
return err
}
ret, err := Find{{$tableNameSingular}}(exec, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
if err != nil {
return err
}
*o = *ret
return nil
*o = *ret
return nil
}
// ReloadAllGP refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice.
// Panics on error.
func (o *{{$tableNameSingular}}Slice) ReloadAllGP() {
if err := o.ReloadAllG(); err != nil {
panic(boil.WrapErr(err))
}
if err := o.ReloadAllG(); err != nil {
panic(boil.WrapErr(err))
}
}
// ReloadAllP refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice.
// Panics on error.
func (o *{{$tableNameSingular}}Slice) ReloadAllP(exec boil.Executor) {
if err := o.ReloadAll(exec); err != nil {
panic(boil.WrapErr(err))
}
if err := o.ReloadAll(exec); err != nil {
panic(boil.WrapErr(err))
}
}
// ReloadAllG refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice.
func (o *{{$tableNameSingular}}Slice) ReloadAllG() error {
if o == nil {
return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all")
}
if o == nil {
return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all")
}
return o.ReloadAll(boil.GetDB())
return o.ReloadAll(boil.GetDB())
}
// ReloadAll refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice.
func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error {
if o == nil || len(*o) == 0 {
return nil
}
if o == nil || len(*o) == 0 {
return nil
}
{{$varNamePlural}} := {{$tableNameSingular}}Slice{}
args := o.inPrimaryKeyArgs()
{{$varNamePlural}} := {{$tableNameSingular}}Slice{}
args := o.inPrimaryKeyArgs()
sql := fmt.Sprintf(
`SELECT {{.Table.Name}}.* FROM {{.Table.Name}} WHERE (%s) IN (%s)`,
strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","),
strmangle.Placeholders(len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
sql := fmt.Sprintf(
"SELECT {{$schemaTable}}.* FROM {{$schemaTable}} WHERE (%s) IN (%s)",
strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","),
strmangle.Placeholders(dialect.IndexPlaceholders, len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
q := boil.SQL(exec, sql, args...)
q := queries.Raw(exec, sql, args...)
err := q.Bind(&{{$varNamePlural}})
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice")
}
err := q.Bind(&{{$varNamePlural}})
if err != nil {
return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice")
}
*o = {{$varNamePlural}}
*o = {{$varNamePlural}}
return nil
return nil
}

View file

@ -2,48 +2,49 @@
{{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}}
{{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}}
{{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}}
{{- $schemaTable := .Table.Name | .SchemaTable -}}
// {{$tableNameSingular}}Exists checks if the {{$tableNameSingular}} row exists.
func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) {
var exists bool
var exists bool
sql := `select exists(select 1 from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}} limit 1)`
sql := "select exists(select 1 from {{$schemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}} limit 1)"
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, {{$pkNames | join ", "}})
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, {{$pkNames | join ", "}})
}
row := exec.QueryRow(sql, {{$pkNames | join ", "}})
row := exec.QueryRow(sql, {{$pkNames | join ", "}})
err := row.Scan(&exists)
if err != nil {
return false, errors.Wrap(err, "{{.PkgName}}: unable to check if {{.Table.Name}} exists")
}
err := row.Scan(&exists)
if err != nil {
return false, errors.Wrap(err, "{{.PkgName}}: unable to check if {{.Table.Name}} exists")
}
return exists, nil
return exists, nil
}
// {{$tableNameSingular}}ExistsG checks if the {{$tableNameSingular}} row exists.
func {{$tableNameSingular}}ExistsG({{$pkArgs}}) (bool, error) {
return {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}})
return {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}})
}
// {{$tableNameSingular}}ExistsGP checks if the {{$tableNameSingular}} row exists. Panics on error.
func {{$tableNameSingular}}ExistsGP({{$pkArgs}}) bool {
e, err := {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}})
if err != nil {
panic(boil.WrapErr(err))
}
e, err := {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}})
if err != nil {
panic(boil.WrapErr(err))
}
return e
return e
}
// {{$tableNameSingular}}ExistsP checks if the {{$tableNameSingular}} row exists. Panics on error.
func {{$tableNameSingular}}ExistsP(exec boil.Executor, {{$pkArgs}}) bool {
e, err := {{$tableNameSingular}}Exists(exec, {{$pkNames | join ", "}})
if err != nil {
panic(boil.WrapErr(err))
}
e, err := {{$tableNameSingular}}Exists(exec, {{$pkNames | join ", "}})
if err != nil {
panic(boil.WrapErr(err))
}
return e
return e
}

View file

@ -1,23 +1,23 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $tableNameSingular := .Table.Name | singular | titleCase -}}
func (o {{$tableNameSingular}}) inPrimaryKeyArgs() []interface{} {
var args []interface{}
var args []interface{}
{{- range $key, $value := .Table.PKey.Columns }}
args = append(args, o.{{titleCase $value}})
{{ end -}}
{{- range $key, $value := .Table.PKey.Columns }}
args = append(args, o.{{titleCase $value}})
{{ end -}}
return args
return args
}
func (o {{$tableNameSingular}}Slice) inPrimaryKeyArgs() []interface{} {
var args []interface{}
var args []interface{}
for i := 0; i < len(o); i++ {
{{- range $key, $value := .Table.PKey.Columns }}
args = append(args, o[i].{{titleCase $value}})
{{ end -}}
}
for i := 0; i < len(o); i++ {
{{- range $key, $value := .Table.PKey.Columns }}
args = append(args, o[i].{{titleCase $value}})
{{ end -}}
}
return args
return args
}

View file

@ -1,82 +1,82 @@
{{- define "timestamp_insert_helper" -}}
{{- if not .NoAutoTimestamps -}}
{{- $colNames := .Table.Columns | columnNames -}}
{{if containsAny $colNames "created_at" "updated_at"}}
currTime := time.Now().In(boil.GetLocation())
{{range $ind, $col := .Table.Columns}}
{{- if eq $col.Name "created_at" -}}
{{- if $col.Nullable}}
if o.CreatedAt.Time.IsZero() {
o.CreatedAt.Time = currTime
o.CreatedAt.Valid = true
}
{{- else}}
if o.CreatedAt.IsZero() {
o.CreatedAt = currTime
}
{{- end -}}
{{- end -}}
{{- if eq $col.Name "updated_at" -}}
{{- if $col.Nullable}}
if o.UpdatedAt.Time.IsZero() {
o.UpdatedAt.Time = currTime
o.UpdatedAt.Valid = true
}
{{- else}}
if o.UpdatedAt.IsZero() {
o.UpdatedAt = currTime
}
{{- end -}}
{{- end -}}
{{end}}
{{end}}
{{- end}}
{{- if not .NoAutoTimestamps -}}
{{- $colNames := .Table.Columns | columnNames -}}
{{if containsAny $colNames "created_at" "updated_at"}}
currTime := time.Now().In(boil.GetLocation())
{{range $ind, $col := .Table.Columns}}
{{- if eq $col.Name "created_at" -}}
{{- if $col.Nullable}}
if o.CreatedAt.Time.IsZero() {
o.CreatedAt.Time = currTime
o.CreatedAt.Valid = true
}
{{- else}}
if o.CreatedAt.IsZero() {
o.CreatedAt = currTime
}
{{- end -}}
{{- end -}}
{{- if eq $col.Name "updated_at" -}}
{{- if $col.Nullable}}
if o.UpdatedAt.Time.IsZero() {
o.UpdatedAt.Time = currTime
o.UpdatedAt.Valid = true
}
{{- else}}
if o.UpdatedAt.IsZero() {
o.UpdatedAt = currTime
}
{{- end -}}
{{- end -}}
{{end}}
{{end}}
{{- end}}
{{- end -}}
{{- define "timestamp_update_helper" -}}
{{- if not .NoAutoTimestamps -}}
{{- $colNames := .Table.Columns | columnNames -}}
{{if containsAny $colNames "updated_at"}}
currTime := time.Now().In(boil.GetLocation())
{{range $ind, $col := .Table.Columns}}
{{- if eq $col.Name "updated_at" -}}
{{- if $col.Nullable}}
o.UpdatedAt.Time = currTime
o.UpdatedAt.Valid = true
{{- else}}
o.UpdatedAt = currTime
{{- end -}}
{{- end -}}
{{end}}
{{end}}
{{- end}}
{{- if not .NoAutoTimestamps -}}
{{- $colNames := .Table.Columns | columnNames -}}
{{if containsAny $colNames "updated_at"}}
currTime := time.Now().In(boil.GetLocation())
{{range $ind, $col := .Table.Columns}}
{{- if eq $col.Name "updated_at" -}}
{{- if $col.Nullable}}
o.UpdatedAt.Time = currTime
o.UpdatedAt.Valid = true
{{- else}}
o.UpdatedAt = currTime
{{- end -}}
{{- end -}}
{{end}}
{{end}}
{{- end}}
{{end -}}
{{- define "timestamp_upsert_helper" -}}
{{- if not .NoAutoTimestamps -}}
{{- $colNames := .Table.Columns | columnNames -}}
{{if containsAny $colNames "created_at" "updated_at"}}
currTime := time.Now().In(boil.GetLocation())
{{range $ind, $col := .Table.Columns}}
{{- if eq $col.Name "created_at" -}}
{{- if $col.Nullable}}
if o.CreatedAt.Time.IsZero() {
o.CreatedAt.Time = currTime
o.CreatedAt.Valid = true
}
{{- else}}
if o.CreatedAt.IsZero() {
o.CreatedAt = currTime
}
{{- end -}}
{{- end -}}
{{- if eq $col.Name "updated_at" -}}
{{- if $col.Nullable}}
o.UpdatedAt.Time = currTime
o.UpdatedAt.Valid = true
{{- else}}
o.UpdatedAt = currTime
{{- end -}}
{{- end -}}
{{end}}
{{end}}
{{- end}}
{{- if not .NoAutoTimestamps -}}
{{- $colNames := .Table.Columns | columnNames -}}
{{if containsAny $colNames "created_at" "updated_at"}}
currTime := time.Now().In(boil.GetLocation())
{{range $ind, $col := .Table.Columns}}
{{- if eq $col.Name "created_at" -}}
{{- if $col.Nullable}}
if o.CreatedAt.Time.IsZero() {
o.CreatedAt.Time = currTime
o.CreatedAt.Valid = true
}
{{- else}}
if o.CreatedAt.IsZero() {
o.CreatedAt = currTime
}
{{- end -}}
{{- end -}}
{{- if eq $col.Name "updated_at" -}}
{{- if $col.Nullable}}
o.UpdatedAt.Time = currTime
o.UpdatedAt.Valid = true
{{- else}}
o.UpdatedAt = currTime
{{- end -}}
{{- end -}}
{{end}}
{{end}}
{{- end}}
{{end -}}

View file

@ -1,12 +1,19 @@
var dialect = queries.Dialect{
LQ: 0x{{printf "%x" .Dialect.LQ}},
RQ: 0x{{printf "%x" .Dialect.RQ}},
IndexPlaceholders: {{.Dialect.IndexPlaceholders}},
}
// NewQueryG initializes a new Query using the passed in QueryMods
func NewQueryG(mods ...qm.QueryMod) *boil.Query {
func NewQueryG(mods ...qm.QueryMod) *queries.Query {
return NewQuery(boil.GetDB(), mods...)
}
// NewQuery initializes a new Query using the passed in QueryMods
func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *boil.Query {
q := &boil.Query{}
boil.SetExecutor(q, exec)
func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *queries.Query {
q := &queries.Query{}
queries.SetExecutor(q, exec)
queries.SetDialect(q, &dialect)
qm.Apply(q, mods...)
return q

View file

@ -6,33 +6,32 @@ type M map[string]interface{}
// fails or there was a primary key configuration that was not resolvable.
var ErrSyncFail = errors.New("{{.PkgName}}: failed to synchronize data after insert")
type insertCache struct{
query string
retQuery string
valueMapping []uint64
retMapping []uint64
type insertCache struct {
query string
retQuery string
valueMapping []uint64
retMapping []uint64
}
type updateCache struct{
query string
valueMapping []uint64
type updateCache struct {
query string
valueMapping []uint64
}
func makeCacheKey(wl, nzDefaults []string) string {
buf := strmangle.GetBuffer()
buf := strmangle.GetBuffer()
for _, w := range wl {
buf.WriteString(w)
}
if len(nzDefaults) != 0 {
buf.WriteByte('.')
}
for _, nz := range nzDefaults {
buf.WriteString(nz)
}
for _, w := range wl {
buf.WriteString(w)
}
if len(nzDefaults) != 0 {
buf.WriteByte('.')
}
for _, nz := range nzDefaults {
buf.WriteString(nz)
}
str := buf.String()
strmangle.PutBuffer(buf)
return str
str := buf.String()
strmangle.PutBuffer(buf)
return str
}

View file

@ -3,11 +3,11 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}(t *testing.T) {
t.Parallel()
t.Parallel()
query := {{$tableNamePlural}}(nil)
query := {{$tableNamePlural}}(nil)
if query.Query == nil {
t.Error("expected a query, got nothing")
}
if query.Query == nil {
t.Error("expected a query, got nothing")
}
}

View file

@ -3,93 +3,93 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Delete(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
if err = {{$varNameSingular}}.Delete(tx); err != nil {
t.Error(err)
}
if err = {{$varNameSingular}}.Delete(tx); err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 0 {
t.Error("want zero records, got:", count)
}
if count != 0 {
t.Error("want zero records, got:", count)
}
}
func test{{$tableNamePlural}}QueryDeleteAll(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
if err = {{$tableNamePlural}}(tx).DeleteAll(); err != nil {
t.Error(err)
}
if err = {{$tableNamePlural}}(tx).DeleteAll(); err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 0 {
t.Error("want zero records, got:", count)
}
if count != 0 {
t.Error("want zero records, got:", count)
}
}
func test{{$tableNamePlural}}SliceDeleteAll(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}}
slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}}
if err = slice.DeleteAll(tx); err != nil {
t.Error(err)
}
if err = slice.DeleteAll(tx); err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 0 {
t.Error("want zero records, got:", count)
}
if count != 0 {
t.Error("want zero records, got:", count)
}
}

View file

@ -3,27 +3,27 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Exists(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
{{$pkeyArgs := .Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", " -}}
e, err := {{$tableNameSingular}}Exists(tx, {{$pkeyArgs}})
if err != nil {
t.Errorf("Unable to check if {{$tableNameSingular}} exists: %s", err)
}
if e != true {
t.Errorf("Expected {{$tableNameSingular}}ExistsG to return true, but got false.")
}
{{$pkeyArgs := .Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", " -}}
e, err := {{$tableNameSingular}}Exists(tx, {{$pkeyArgs}})
if err != nil {
t.Errorf("Unable to check if {{$tableNameSingular}} exists: %s", err)
}
if e != true {
t.Errorf("Expected {{$tableNameSingular}}ExistsG to return true, but got false.")
}
}

View file

@ -3,27 +3,27 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Find(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
{{$varNameSingular}}Found, err := Find{{$tableNameSingular}}(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", "}})
if err != nil {
t.Error(err)
}
{{$varNameSingular}}Found, err := Find{{$tableNameSingular}}(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", "}})
if err != nil {
t.Error(err)
}
if {{$varNameSingular}}Found == nil {
t.Error("want a record, got nil")
}
if {{$varNameSingular}}Found == nil {
t.Error("want a record, got nil")
}
}

View file

@ -3,111 +3,111 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Bind(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
if err = {{$tableNamePlural}}(tx).Bind({{$varNameSingular}}); err != nil {
t.Error(err)
}
if err = {{$tableNamePlural}}(tx).Bind({{$varNameSingular}}); err != nil {
t.Error(err)
}
}
func test{{$tableNamePlural}}One(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
if x, err := {{$tableNamePlural}}(tx).One(); err != nil {
t.Error(err)
} else if x == nil {
t.Error("expected to get a non nil record")
}
if x, err := {{$tableNamePlural}}(tx).One(); err != nil {
t.Error(err)
} else if x == nil {
t.Error("expected to get a non nil record")
}
}
func test{{$tableNamePlural}}All(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}}One := &{{$tableNameSingular}}{}
{{$varNameSingular}}Two := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}}One := &{{$tableNameSingular}}{}
{{$varNameSingular}}Two := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}One.Insert(tx); err != nil {
t.Error(err)
}
if err = {{$varNameSingular}}Two.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}One.Insert(tx); err != nil {
t.Error(err)
}
if err = {{$varNameSingular}}Two.Insert(tx); err != nil {
t.Error(err)
}
slice, err := {{$tableNamePlural}}(tx).All()
if err != nil {
t.Error(err)
}
slice, err := {{$tableNamePlural}}(tx).All()
if err != nil {
t.Error(err)
}
if len(slice) != 2 {
t.Error("want 2 records, got:", len(slice))
}
if len(slice) != 2 {
t.Error("want 2 records, got:", len(slice))
}
}
func test{{$tableNamePlural}}Count(t *testing.T) {
t.Parallel()
t.Parallel()
var err error
seed := randomize.NewSeed()
{{$varNameSingular}}One := &{{$tableNameSingular}}{}
{{$varNameSingular}}Two := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
var err error
seed := randomize.NewSeed()
{{$varNameSingular}}One := &{{$tableNameSingular}}{}
{{$varNameSingular}}Two := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}One.Insert(tx); err != nil {
t.Error(err)
}
if err = {{$varNameSingular}}Two.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}One.Insert(tx); err != nil {
t.Error(err)
}
if err = {{$varNameSingular}}Two.Insert(tx); err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 2 {
t.Error("want 2 records, got:", count)
}
if count != 2 {
t.Error("want 2 records, got:", count)
}
}

View file

@ -5,57 +5,57 @@
var {{$varNameSingular}}DBTypes = map[string]string{{"{"}}{{.Table.Columns | columnDBTypes | makeStringMap}}{{"}"}}
func test{{$tableNamePlural}}InPrimaryKeyArgs(t *testing.T) {
t.Parallel()
t.Parallel()
var err error
var o {{$tableNameSingular}}
o = {{$tableNameSingular}}{}
var err error
var o {{$tableNameSingular}}
o = {{$tableNameSingular}}{}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &o, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Could not randomize struct: %s", err)
}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &o, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Could not randomize struct: %s", err)
}
args := o.inPrimaryKeyArgs()
args := o.inPrimaryKeyArgs()
if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) {
t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns), len(args))
}
if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) {
t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns), len(args))
}
{{range $key, $value := .Table.PKey.Columns}}
if o.{{titleCase $value}} != args[{{$key}}] {
t.Errorf("Expected args[{{$key}}] to be value of o.{{titleCase $value}}, but got %#v", args[{{$key}}])
}
{{- end}}
{{range $key, $value := .Table.PKey.Columns}}
if o.{{titleCase $value}} != args[{{$key}}] {
t.Errorf("Expected args[{{$key}}] to be value of o.{{titleCase $value}}, but got %#v", args[{{$key}}])
}
{{- end}}
}
func test{{$tableNamePlural}}SliceInPrimaryKeyArgs(t *testing.T) {
t.Parallel()
t.Parallel()
var err error
o := make({{$tableNameSingular}}Slice, 3)
var err error
o := make({{$tableNameSingular}}Slice, 3)
seed := randomize.NewSeed()
for i := range o {
o[i] = &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, o[i], {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Could not randomize struct: %s", err)
}
}
seed := randomize.NewSeed()
for i := range o {
o[i] = &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, o[i], {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Could not randomize struct: %s", err)
}
}
args := o.inPrimaryKeyArgs()
args := o.inPrimaryKeyArgs()
if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) * 3 {
t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns) * 3, len(args))
}
if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) * 3 {
t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns) * 3, len(args))
}
argC := 0
for i := 0; i < 3; i++ {
{{range $key, $value := .Table.PKey.Columns}}
if o[i].{{titleCase $value}} != args[argC] {
t.Errorf("Expected args[%d] to be value of o.{{titleCase $value}}, but got %#v", i, args[i])
}
argC++
{{- end}}
}
argC := 0
for i := 0; i < 3; i++ {
{{range $key, $value := .Table.PKey.Columns}}
if o[i].{{titleCase $value}} != args[argC] {
t.Errorf("Expected args[%d] to be value of o.{{titleCase $value}}, but got %#v", i, args[i])
}
argC++
{{- end}}
}
}

View file

@ -4,142 +4,142 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func {{$varNameSingular}}BeforeInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{}
return nil
*o = {{$tableNameSingular}}{}
return nil
}
func {{$varNameSingular}}AfterInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{}
return nil
*o = {{$tableNameSingular}}{}
return nil
}
func {{$varNameSingular}}AfterSelectHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{}
return nil
*o = {{$tableNameSingular}}{}
return nil
}
func {{$varNameSingular}}BeforeUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{}
return nil
*o = {{$tableNameSingular}}{}
return nil
}
func {{$varNameSingular}}AfterUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{}
return nil
*o = {{$tableNameSingular}}{}
return nil
}
func {{$varNameSingular}}BeforeDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{}
return nil
*o = {{$tableNameSingular}}{}
return nil
}
func {{$varNameSingular}}AfterDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{}
return nil
*o = {{$tableNameSingular}}{}
return nil
}
func {{$varNameSingular}}BeforeUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{}
return nil
*o = {{$tableNameSingular}}{}
return nil
}
func {{$varNameSingular}}AfterUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error {
*o = {{$tableNameSingular}}{}
return nil
*o = {{$tableNameSingular}}{}
return nil
}
func test{{$tableNamePlural}}Hooks(t *testing.T) {
t.Parallel()
t.Parallel()
var err error
var err error
empty := &{{$tableNameSingular}}{}
o := &{{$tableNameSingular}}{}
empty := &{{$tableNameSingular}}{}
o := &{{$tableNameSingular}}{}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, o, {{$varNameSingular}}DBTypes, false); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} object: %s", err)
}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, o, {{$varNameSingular}}DBTypes, false); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} object: %s", err)
}
Add{{$tableNameSingular}}Hook(boil.BeforeInsertHook, {{$varNameSingular}}BeforeInsertHook)
if err = o.doBeforeInsertHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeInsertHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeInsertHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}BeforeInsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeInsertHook, {{$varNameSingular}}BeforeInsertHook)
if err = o.doBeforeInsertHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeInsertHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeInsertHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}BeforeInsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterInsertHook, {{$varNameSingular}}AfterInsertHook)
if err = o.doAfterInsertHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterInsertHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterInsertHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterInsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterInsertHook, {{$varNameSingular}}AfterInsertHook)
if err = o.doAfterInsertHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterInsertHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterInsertHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterInsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterSelectHook, {{$varNameSingular}}AfterSelectHook)
if err = o.doAfterSelectHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterSelectHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterSelectHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterSelectHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterSelectHook, {{$varNameSingular}}AfterSelectHook)
if err = o.doAfterSelectHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterSelectHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterSelectHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterSelectHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeUpdateHook, {{$varNameSingular}}BeforeUpdateHook)
if err = o.doBeforeUpdateHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeUpdateHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeUpdateHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}BeforeUpdateHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeUpdateHook, {{$varNameSingular}}BeforeUpdateHook)
if err = o.doBeforeUpdateHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeUpdateHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeUpdateHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}BeforeUpdateHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterUpdateHook, {{$varNameSingular}}AfterUpdateHook)
if err = o.doAfterUpdateHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterUpdateHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterUpdateHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterUpdateHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterUpdateHook, {{$varNameSingular}}AfterUpdateHook)
if err = o.doAfterUpdateHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterUpdateHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterUpdateHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterUpdateHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeDeleteHook, {{$varNameSingular}}BeforeDeleteHook)
if err = o.doBeforeDeleteHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeDeleteHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeDeleteHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}BeforeDeleteHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeDeleteHook, {{$varNameSingular}}BeforeDeleteHook)
if err = o.doBeforeDeleteHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeDeleteHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeDeleteHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}BeforeDeleteHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterDeleteHook, {{$varNameSingular}}AfterDeleteHook)
if err = o.doAfterDeleteHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterDeleteHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterDeleteHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterDeleteHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterDeleteHook, {{$varNameSingular}}AfterDeleteHook)
if err = o.doAfterDeleteHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterDeleteHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterDeleteHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterDeleteHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeUpsertHook, {{$varNameSingular}}BeforeUpsertHook)
if err = o.doBeforeUpsertHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeUpsertHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeUpsertHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}BeforeUpsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.BeforeUpsertHook, {{$varNameSingular}}BeforeUpsertHook)
if err = o.doBeforeUpsertHooks(nil); err != nil {
t.Errorf("Unable to execute doBeforeUpsertHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected BeforeUpsertHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}BeforeUpsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterUpsertHook, {{$varNameSingular}}AfterUpsertHook)
if err = o.doAfterUpsertHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterUpsertHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterUpsertHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterUpsertHooks = []{{$tableNameSingular}}Hook{}
Add{{$tableNameSingular}}Hook(boil.AfterUpsertHook, {{$varNameSingular}}AfterUpsertHook)
if err = o.doAfterUpsertHooks(nil); err != nil {
t.Errorf("Unable to execute doAfterUpsertHooks: %s", err)
}
if !reflect.DeepEqual(o, empty) {
t.Errorf("Expected AfterUpsertHook function to empty object, but got: %#v", o)
}
{{$varNameSingular}}AfterUpsertHooks = []{{$tableNameSingular}}Hook{}
}
{{- end}}

View file

@ -4,53 +4,53 @@
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
{{- $parent := . -}}
func test{{$tableNamePlural}}Insert(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 1 {
t.Error("want one record, got:", count)
}
if count != 1 {
t.Error("want one record, got:", count)
}
}
func test{{$tableNamePlural}}InsertWhitelist(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx, {{$varNameSingular}}Columns...); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx, {{$varNameSingular}}Columns...); err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 1 {
t.Error("want one record, got:", count)
}
if count != 1 {
t.Error("want one record, got:", count)
}
}

View file

@ -0,0 +1,166 @@
type mysqlTester struct {
dbConn *sql.DB
dbName string
host string
user string
pass string
sslmode string
port int
optionFile string
testDBName string
}
func init() {
dbMain = &mysqlTester{}
}
func (m *mysqlTester) setup() error {
var err error
m.dbName = viper.GetString("mysql.dbname")
m.host = viper.GetString("mysql.host")
m.user = viper.GetString("mysql.user")
m.pass = viper.GetString("mysql.pass")
m.port = viper.GetInt("mysql.port")
m.sslmode = viper.GetString("mysql.sslmode")
// Create a randomized db name.
m.testDBName = randomize.StableDBName(m.dbName)
if err = m.makeOptionFile(); err != nil {
return errors.Wrap(err, "couldn't make option file")
}
if err = m.dropTestDB(); err != nil {
return err
}
if err = m.createTestDB(); err != nil {
return err
}
dumpCmd := exec.Command("mysqldump", m.defaultsFile(), "--no-data", m.dbName)
createCmd := exec.Command("mysql", m.defaultsFile(), "--database", m.testDBName)
r, w := io.Pipe()
dumpCmd.Stdout = w
createCmd.Stdin = newFKeyDestroyer(rgxMySQLkey, r)
if err = dumpCmd.Start(); err != nil {
return errors.Wrap(err, "failed to start mysqldump command")
}
if err = createCmd.Start(); err != nil {
return errors.Wrap(err, "failed to start mysql command")
}
if err = dumpCmd.Wait(); err != nil {
fmt.Println(err)
return errors.Wrap(err, "failed to wait for mysqldump command")
}
w.Close() // After dumpCmd is done, close the write end of the pipe
if err = createCmd.Wait(); err != nil {
fmt.Println(err)
return errors.Wrap(err, "failed to wait for mysql command")
}
return nil
}
func (m *mysqlTester) sslMode(mode string) string {
switch mode {
case "true":
return "REQUIRED"
case "false":
return "DISABLED"
default:
return "PREFERRED"
}
}
func (m *mysqlTester) defaultsFile() string {
return fmt.Sprintf("--defaults-file=%s", m.optionFile)
}
func (m *mysqlTester) makeOptionFile() error {
tmp, err := ioutil.TempFile("", "optionfile")
if err != nil {
return errors.Wrap(err, "failed to create option file")
}
fmt.Fprintln(tmp, "[client]")
fmt.Fprintf(tmp, "host=%s\n", m.host)
fmt.Fprintf(tmp, "port=%d\n", m.port)
fmt.Fprintf(tmp, "user=%s\n", m.user)
fmt.Fprintf(tmp, "password=%s\n", m.pass)
fmt.Fprintf(tmp, "ssl-mode=%s\n", m.sslMode(m.sslmode))
fmt.Fprintln(tmp, "[mysqldump]")
fmt.Fprintf(tmp, "host=%s\n", m.host)
fmt.Fprintf(tmp, "port=%d\n", m.port)
fmt.Fprintf(tmp, "user=%s\n", m.user)
fmt.Fprintf(tmp, "password=%s\n", m.pass)
fmt.Fprintf(tmp, "ssl-mode=%s\n", m.sslMode(m.sslmode))
m.optionFile = tmp.Name()
return tmp.Close()
}
func (m *mysqlTester) createTestDB() error {
sql := fmt.Sprintf("create database %s;", m.testDBName)
return m.runCmd(sql, "mysql")
}
func (m *mysqlTester) dropTestDB() error {
sql := fmt.Sprintf("drop database if exists %s;", m.testDBName)
return m.runCmd(sql, "mysql")
}
func (m *mysqlTester) teardown() error {
if m.dbConn != nil {
m.dbConn.Close()
}
if err := m.dropTestDB(); err != nil {
return err
}
return os.Remove(m.optionFile)
}
func (m *mysqlTester) runCmd(stdin, command string, args ...string) error {
args = append([]string{m.defaultsFile()}, args...)
cmd := exec.Command(command, args...)
cmd.Stdin = strings.NewReader(stdin)
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
cmd.Stdout = stdout
cmd.Stderr = stderr
if err := cmd.Run(); err != nil {
fmt.Println("failed running:", command, args)
fmt.Println(stdout.String())
fmt.Println(stderr.String())
return err
}
return nil
}
func (m *mysqlTester) conn() (*sql.DB, error) {
if m.dbConn != nil {
return m.dbConn, nil
}
var err error
m.dbConn, err = sql.Open("mysql", drivers.MySQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode))
if err != nil {
return nil, err
}
return m.dbConn, nil
}

View file

@ -1,275 +1,166 @@
type PostgresCfg struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
SSLMode string `toml:"sslmode"`
type pgTester struct {
dbConn *sql.DB
dbName string
host string
user string
pass string
sslmode string
port int
pgPassFile string
testDBName string
}
type Config struct {
Postgres PostgresCfg `toml:"postgres"`
}
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
var err error
if err = setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-2)
}
var code int
if err = disableTriggers(); err != nil {
fmt.Println("Unable to disable triggers:", err)
} else {
boil.SetDB(dbConn)
code = m.Run()
}
if err = teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-3)
}
os.Exit(code)
}
// disableTriggers is used to disable foreign key constraints for every table.
// If this is not used we cannot test inserts due to foreign key constraint errors.
func disableTriggers() error {
var stmts []string
{{range .Tables}}
stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`)
{{- end}}
if len(stmts) == 0 {
return nil
}
var err error
for _, s := range stmts {
_, err = dbConn.Exec(s)
if err != nil {
return err
}
}
return nil
}
// teardown executes cleanup tasks when the tests finish running
func teardown() error {
err := dropTestDB()
return err
}
// dropTestDB switches its connection to the template1 database temporarily
// so that it can drop the test database without causing "in use" conflicts.
// The template1 database should be present on all default postgres installations.
func dropTestDB() error {
var err error
if dbConn != nil {
if err = dbConn.Close(); err != nil {
return err
}
}
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil {
return err
}
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
if err != nil {
return err
}
return dbConn.Close()
}
// DBConnect connects to a database and returns the handle.
func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) {
connStr := drivers.BuildQueryString(user, pass, dbname, host, port, sslmode)
return sql.Open("postgres", connStr)
func init() {
dbMain = &pgTester{}
}
// setup dumps the database schema and imports it into a temporary randomly
// generated test database so that tests can be run against it using the
// generated sqlboiler ORM package.
func setup() error {
var err error
func (p *pgTester) setup() error {
var err error
// Initialize Viper and load the config file
err = InitViper()
if err != nil {
return errors.Wrap(err, "Unable to load config file")
}
p.dbName = viper.GetString("postgres.dbname")
p.host = viper.GetString("postgres.host")
p.user = viper.GetString("postgres.user")
p.pass = viper.GetString("postgres.pass")
p.port = viper.GetInt("postgres.port")
p.sslmode = viper.GetString("postgres.sslmode")
// Create a randomized db name.
p.testDBName = randomize.StableDBName(p.dbName)
viper.SetDefault("postgres.sslmode", "require")
viper.SetDefault("postgres.port", "5432")
if err = p.makePGPassFile(); err != nil {
return err
}
// Create a randomized test configuration object.
testCfg.Postgres.Host = viper.GetString("postgres.host")
testCfg.Postgres.Port = viper.GetInt("postgres.port")
testCfg.Postgres.User = viper.GetString("postgres.user")
testCfg.Postgres.Pass = viper.GetString("postgres.pass")
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
if err = p.dropTestDB(); err != nil {
return err
}
if err = p.createTestDB(); err != nil {
return err
}
// Set the default SSLMode value
if testCfg.Postgres.SSLMode == "" {
viper.Set("postgres.sslmode", "require")
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
}
dumpCmd := exec.Command("pg_dump", "--schema-only", p.dbName)
dumpCmd.Env = append(os.Environ(), p.pgEnv()...)
createCmd := exec.Command("psql", p.testDBName)
createCmd.Env = append(os.Environ(), p.pgEnv()...)
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
).Check()
r, w := io.Pipe()
dumpCmd.Stdout = w
createCmd.Stdin = newFKeyDestroyer(rgxPGFkey, r)
if err != nil {
return errors.Wrap(err, "Unable to load testCfg")
}
if err = dumpCmd.Start(); err != nil {
return errors.Wrap(err, "failed to start pg_dump command")
}
if err = createCmd.Start(); err != nil {
return errors.Wrap(err, "failed to start psql command")
}
err = dropTestDB()
if err != nil {
fmt.Printf("%#v\n", err)
return err
}
if err = dumpCmd.Wait(); err != nil {
fmt.Println(err)
return errors.Wrap(err, "failed to wait for pg_dump command")
}
fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema")
if err != nil {
return errors.Wrap(err, "Unable to create sqlboiler schema tmp file")
}
defer os.Remove(fhSchema.Name())
w.Close() // After dumpCmd is done, close the write end of the pipe
passDir, err := ioutil.TempDir(os.TempDir(), "sqlboiler")
if err != nil {
return errors.Wrap(err, "Unable to create sqlboiler tmp dir for postgres pw file")
}
defer os.RemoveAll(passDir)
if err = createCmd.Wait(); err != nil {
fmt.Println(err)
return errors.Wrap(err, "failed to wait for psql command")
}
// Write the postgres user password to a tmp file for pg_dump
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
viper.GetString("postgres.host"),
viper.GetInt("postgres.port"),
viper.GetString("postgres.dbname"),
viper.GetString("postgres.user"),
))
if pw := viper.GetString("postgres.pass"); len(pw) > 0 {
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw))
}
passFilePath := filepath.Join(passDir, "pwfile")
err = ioutil.WriteFile(passFilePath, pwBytes, 0600)
if err != nil {
return errors.Wrap(err, "Unable to create pwfile in passDir")
}
// The params for the pg_dump command to dump the database schema
params := []string{
fmt.Sprintf(`--host=%s`, viper.GetString("postgres.host")),
fmt.Sprintf(`--port=%d`, viper.GetInt("postgres.port")),
fmt.Sprintf(`--username=%s`, viper.GetString("postgres.user")),
"--schema-only",
viper.GetString("postgres.dbname"),
}
// Dump the database schema into the sqlboilerschema tmp file
errBuf := bytes.Buffer{}
cmd := exec.Command("pg_dump", params...)
cmd.Stderr = &errBuf
cmd.Stdout = fhSchema
cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, passFilePath))
if err := cmd.Run(); err != nil {
fmt.Printf("pg_dump exec failed: %s\n\n%s\n", err, errBuf.String())
return err
}
dbConn, err = DBConnect(
viper.GetString("postgres.user"),
viper.GetString("postgres.pass"),
viper.GetString("postgres.dbname"),
viper.GetString("postgres.host"),
viper.GetInt("postgres.port"),
viper.GetString("postgres.sslmode"),
)
if err != nil {
return err
}
// Create the randomly generated database
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName))
if err != nil {
return err
}
// Close the old connection so we can reconnect to the test database
if err = dbConn.Close(); err != nil {
return err
}
// Connect to the generated test db
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil {
return err
}
// Write the test config credentials to a tmp file for pg_dump
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
testCfg.Postgres.Host,
testCfg.Postgres.Port,
testCfg.Postgres.DBName,
testCfg.Postgres.User,
))
if len(testCfg.Postgres.Pass) > 0 {
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass))
}
testPassFilePath := passDir + "/testpwfile"
err = ioutil.WriteFile(testPassFilePath, testPwBytes, 0600)
if err != nil {
return errors.Wrapf(err, "Unable to create testpwfile in passDir")
}
// The params for the psql schema import command
params = []string{
fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName),
fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host),
fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port),
fmt.Sprintf(`--username=%s`, testCfg.Postgres.User),
fmt.Sprintf(`--file=%s`, fhSchema.Name()),
}
// Import the database schema into the generated database.
// It is now ready to be used by the generated ORM package for testing.
outBuf := bytes.Buffer{}
cmd = exec.Command("psql", params...)
cmd.Stderr = &errBuf
cmd.Stdout = &outBuf
cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, testPassFilePath))
if err = cmd.Run(); err != nil {
fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String())
}
return nil
return nil
}
func (p *pgTester) runCmd(stdin, command string, args ...string) error {
cmd := exec.Command(command, args...)
cmd.Env = append(os.Environ(), p.pgEnv()...)
if len(stdin) != 0 {
cmd.Stdin = strings.NewReader(stdin)
}
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
cmd.Stdout = stdout
cmd.Stderr = stderr
if err := cmd.Run(); err != nil {
fmt.Println("failed running:", command, args)
fmt.Println(stdout.String())
fmt.Println(stderr.String())
return err
}
return nil
}
func (p *pgTester) pgEnv() []string {
return []string{
fmt.Sprintf("PGHOST=%s", p.host),
fmt.Sprintf("PGPORT=%d", p.port),
fmt.Sprintf("PGUSER=%s", p.user),
fmt.Sprintf("PGPASS=%s", p.pgPassFile),
}
}
func (p *pgTester) makePGPassFile() error {
tmp, err := ioutil.TempFile("", "pgpass")
if err != nil {
return errors.Wrap(err, "failed to create option file")
}
fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.dbName, p.user)
if len(p.pass) != 0 {
fmt.Fprintf(tmp, ":%s", p.pass)
}
fmt.Fprintln(tmp)
fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user)
if len(p.pass) != 0 {
fmt.Fprintf(tmp, ":%s", p.pass)
}
fmt.Fprintln(tmp)
p.pgPassFile = tmp.Name()
return tmp.Close()
}
func (p *pgTester) createTestDB() error {
return p.runCmd("", "createdb", p.testDBName)
}
func (p *pgTester) dropTestDB() error {
return p.runCmd("", "dropdb", "--if-exists", p.testDBName)
}
// teardown executes cleanup tasks when the tests finish running
func (p *pgTester) teardown() error {
var err error
if err = p.dbConn.Close(); err != nil {
return err
}
p.dbConn = nil
if err = p.dropTestDB(); err != nil {
return err
}
return os.Remove(p.pgPassFile)
}
func (p *pgTester) conn() (*sql.DB, error) {
if p.dbConn != nil {
return p.dbConn, nil
}
var err error
p.dbConn, err = sql.Open("postgres", drivers.PostgresBuildQueryString(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode))
if err != nil {
return nil, err
}
return p.dbConn, nil
}

View file

@ -1,98 +1,98 @@
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . }}
{{- $table := .Table }}
{{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- $dot := . }}
{{- $table := .Table }}
{{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}}
{{- else -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}}
{{- else -}}
{{- $rel := textsFromRelationship $dot.Tables $table . -}}
func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) {
var err error
tx := MustTx(boil.Begin())
defer tx.Rollback()
var err error
tx := MustTx(boil.Begin())
defer tx.Rollback()
var a {{$rel.LocalTable.NameGo}}
var b, c {{$rel.ForeignTable.NameGo}}
var a {{$rel.LocalTable.NameGo}}
var b, c {{$rel.ForeignTable.NameGo}}
if err := a.Insert(tx); err != nil {
t.Fatal(err)
}
if err := a.Insert(tx); err != nil {
t.Fatal(err)
}
seed := randomize.NewSeed()
randomize.Struct(seed, &b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}")
randomize.Struct(seed, &c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}")
{{if .Nullable -}}
a.{{.Column | titleCase}}.Valid = true
{{- end}}
{{- if .ForeignColumnNullable -}}
b.{{.ForeignColumn | titleCase}}.Valid = true
c.{{.ForeignColumn | titleCase}}.Valid = true
{{- end}}
{{if not .ToJoinTable -}}
b.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}}
c.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}}
{{- end}}
if err = b.Insert(tx); err != nil {
t.Fatal(err)
}
if err = c.Insert(tx); err != nil {
t.Fatal(err)
}
seed := randomize.NewSeed()
randomize.Struct(seed, &b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}")
randomize.Struct(seed, &c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}")
{{if .Nullable -}}
a.{{.Column | titleCase}}.Valid = true
{{- end}}
{{- if .ForeignColumnNullable -}}
b.{{.ForeignColumn | titleCase}}.Valid = true
c.{{.ForeignColumn | titleCase}}.Valid = true
{{- end}}
{{if not .ToJoinTable -}}
b.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}}
c.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}}
{{- end}}
if err = b.Insert(tx); err != nil {
t.Fatal(err)
}
if err = c.Insert(tx); err != nil {
t.Fatal(err)
}
{{if .ToJoinTable -}}
_, err = tx.Exec(`insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}})
if err != nil {
t.Fatal(err)
}
_, err = tx.Exec(`insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}})
if err != nil {
t.Fatal(err)
}
{{end}}
{{if .ToJoinTable -}}
_, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}})
if err != nil {
t.Fatal(err)
}
_, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}})
if err != nil {
t.Fatal(err)
}
{{end}}
{{$varname := .ForeignTable | singular | camelCase -}}
{{$varname}}, err := a.{{$rel.Function.Name}}(tx).All()
if err != nil {
t.Fatal(err)
}
{{$varname := .ForeignTable | singular | camelCase -}}
{{$varname}}, err := a.{{$rel.Function.Name}}(tx).All()
if err != nil {
t.Fatal(err)
}
bFound, cFound := false, false
for _, v := range {{$varname}} {
if v.{{$rel.Function.ForeignAssignment}} == b.{{$rel.Function.ForeignAssignment}} {
bFound = true
}
if v.{{$rel.Function.ForeignAssignment}} == c.{{$rel.Function.ForeignAssignment}} {
cFound = true
}
}
bFound, cFound := false, false
for _, v := range {{$varname}} {
if v.{{$rel.Function.ForeignAssignment}} == b.{{$rel.Function.ForeignAssignment}} {
bFound = true
}
if v.{{$rel.Function.ForeignAssignment}} == c.{{$rel.Function.ForeignAssignment}} {
cFound = true
}
}
if !bFound {
t.Error("expected to find b")
}
if !cFound {
t.Error("expected to find c")
}
if !bFound {
t.Error("expected to find b")
}
if !cFound {
t.Error("expected to find c")
}
slice := {{$rel.LocalTable.NameGo}}Slice{&a}
if err = a.L.Load{{$rel.Function.Name}}(tx, false, &slice); err != nil {
t.Fatal(err)
}
if got := len(a.R.{{$rel.Function.Name}}); got != 2 {
t.Error("number of eager loaded records wrong, got:", got)
}
slice := {{$rel.LocalTable.NameGo}}Slice{&a}
if err = a.L.Load{{$rel.Function.Name}}(tx, false, &slice); err != nil {
t.Fatal(err)
}
if got := len(a.R.{{$rel.Function.Name}}); got != 2 {
t.Error("number of eager loaded records wrong, got:", got)
}
a.R.{{$rel.Function.Name}} = nil
if err = a.L.Load{{$rel.Function.Name}}(tx, true, &a); err != nil {
t.Fatal(err)
}
if got := len(a.R.{{$rel.Function.Name}}); got != 2 {
t.Error("number of eager loaded records wrong, got:", got)
}
a.R.{{$rel.Function.Name}} = nil
if err = a.L.Load{{$rel.Function.Name}}(tx, true, &a); err != nil {
t.Fatal(err)
}
if got := len(a.R.{{$rel.Function.Name}}); got != 2 {
t.Error("number of eager loaded records wrong, got:", got)
}
if t.Failed() {
t.Logf("%#v", {{$varname}})
}
if t.Failed() {
t.Logf("%#v", {{$varname}})
}
}
{{end -}}{{- /* if unique */ -}}

View file

@ -1,306 +1,306 @@
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . -}}
{{- $table := .Table -}}
{{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- $dot := . -}}
{{- $table := .Table -}}
{{- range .Table.ToManyRelationships -}}
{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}}
{{- template "relationship_to_one_setops_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}}
{{- else -}}
{{- $varNameSingular := .Table | singular | camelCase -}}
{{- $foreignVarNameSingular := .ForeignTable | singular | camelCase -}}
{{- $rel := textsFromRelationship $dot.Tables $table .}}
{{- else -}}
{{- $varNameSingular := .Table | singular | camelCase -}}
{{- $foreignVarNameSingular := .ForeignTable | singular | camelCase -}}
{{- $rel := textsFromRelationship $dot.Tables $table .}}
func test{{$rel.LocalTable.NameGo}}ToManyAddOp{{$rel.Function.Name}}(t *testing.T) {
var err error
var err error
tx := MustTx(boil.Begin())
defer tx.Rollback()
tx := MustTx(boil.Begin())
defer tx.Rollback()
var a {{$rel.LocalTable.NameGo}}
var b, c, d, e {{$rel.ForeignTable.NameGo}}
var a {{$rel.LocalTable.NameGo}}
var b, c, d, e {{$rel.ForeignTable.NameGo}}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e}
for _, x := range foreigners {
if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e}
for _, x := range foreigners {
if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
}
if err := a.Insert(tx); err != nil {
t.Fatal(err)
}
if err = b.Insert(tx); err != nil {
t.Fatal(err)
}
if err = c.Insert(tx); err != nil {
t.Fatal(err)
}
if err := a.Insert(tx); err != nil {
t.Fatal(err)
}
if err = b.Insert(tx); err != nil {
t.Fatal(err)
}
if err = c.Insert(tx); err != nil {
t.Fatal(err)
}
foreignersSplitByInsertion := [][]*{{$rel.ForeignTable.NameGo}}{
{&b, &c},
{&d, &e},
}
foreignersSplitByInsertion := [][]*{{$rel.ForeignTable.NameGo}}{
{&b, &c},
{&d, &e},
}
for i, x := range foreignersSplitByInsertion {
err = a.Add{{$rel.Function.Name}}(tx, i != 0, x...)
if err != nil {
t.Fatal(err)
}
for i, x := range foreignersSplitByInsertion {
err = a.Add{{$rel.Function.Name}}(tx, i != 0, x...)
if err != nil {
t.Fatal(err)
}
first := x[0]
second := x[1]
{{- if .ToJoinTable}}
first := x[0]
second := x[1]
{{- if .ToJoinTable}}
if first.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice")
}
if second.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice")
}
{{- else}}
if first.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice")
}
if second.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice")
}
{{- else}}
if a.{{$rel.Function.LocalAssignment}} != first.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, first.{{$rel.Function.ForeignAssignment}})
}
if a.{{$rel.Function.LocalAssignment}} != second.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, second.{{$rel.Function.ForeignAssignment}})
}
if a.{{$rel.Function.LocalAssignment}} != first.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, first.{{$rel.Function.ForeignAssignment}})
}
if a.{{$rel.Function.LocalAssignment}} != second.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, second.{{$rel.Function.ForeignAssignment}})
}
if first.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign slice")
}
if second.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign slice")
}
{{- end}}
if first.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign slice")
}
if second.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign slice")
}
{{- end}}
if a.R.{{$rel.Function.Name}}[i*2] != first {
t.Error("relationship struct slice not set to correct value")
}
if a.R.{{$rel.Function.Name}}[i*2+1] != second {
t.Error("relationship struct slice not set to correct value")
}
if a.R.{{$rel.Function.Name}}[i*2] != first {
t.Error("relationship struct slice not set to correct value")
}
if a.R.{{$rel.Function.Name}}[i*2+1] != second {
t.Error("relationship struct slice not set to correct value")
}
count, err := a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if want := int64((i+1)*2); count != want {
t.Error("want", want, "got", count)
}
}
count, err := a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if want := int64((i+1)*2); count != want {
t.Error("want", want, "got", count)
}
}
}
{{- if (or .ForeignColumnNullable .ToJoinTable)}}
func test{{$rel.LocalTable.NameGo}}ToManySetOp{{$rel.Function.Name}}(t *testing.T) {
var err error
var err error
tx := MustTx(boil.Begin())
defer tx.Rollback()
tx := MustTx(boil.Begin())
defer tx.Rollback()
var a {{$rel.LocalTable.NameGo}}
var b, c, d, e {{$rel.ForeignTable.NameGo}}
var a {{$rel.LocalTable.NameGo}}
var b, c, d, e {{$rel.ForeignTable.NameGo}}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e}
for _, x := range foreigners {
if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e}
for _, x := range foreigners {
if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
}
if err = a.Insert(tx); err != nil {
t.Fatal(err)
}
if err = b.Insert(tx); err != nil {
t.Fatal(err)
}
if err = c.Insert(tx); err != nil {
t.Fatal(err)
}
if err = a.Insert(tx); err != nil {
t.Fatal(err)
}
if err = b.Insert(tx); err != nil {
t.Fatal(err)
}
if err = c.Insert(tx); err != nil {
t.Fatal(err)
}
err = a.Set{{$rel.Function.Name}}(tx, false, &b, &c)
if err != nil {
t.Fatal(err)
}
err = a.Set{{$rel.Function.Name}}(tx, false, &b, &c)
if err != nil {
t.Fatal(err)
}
count, err := a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Error("count was wrong:", count)
}
count, err := a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Error("count was wrong:", count)
}
err = a.Set{{$rel.Function.Name}}(tx, true, &d, &e)
if err != nil {
t.Fatal(err)
}
err = a.Set{{$rel.Function.Name}}(tx, true, &d, &e)
if err != nil {
t.Fatal(err)
}
count, err = a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Error("count was wrong:", count)
}
count, err = a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Error("count was wrong:", count)
}
{{- if .ToJoinTable}}
{{- if .ToJoinTable}}
if len(b.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice")
}
if len(c.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice")
}
if d.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice")
}
if e.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice")
}
{{- else}}
if len(b.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice")
}
if len(c.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice")
}
if d.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice")
}
if e.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the slice")
}
{{- else}}
if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want b's foreign key value to be nil")
}
if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want c's foreign key value to be nil")
}
if a.{{$rel.Function.LocalAssignment}} != d.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, d.{{$rel.Function.ForeignAssignment}})
}
if a.{{$rel.Function.LocalAssignment}} != e.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, e.{{$rel.Function.ForeignAssignment}})
}
if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want b's foreign key value to be nil")
}
if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want c's foreign key value to be nil")
}
if a.{{$rel.Function.LocalAssignment}} != d.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, d.{{$rel.Function.ForeignAssignment}})
}
if a.{{$rel.Function.LocalAssignment}} != e.{{$rel.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, e.{{$rel.Function.ForeignAssignment}})
}
if b.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct")
}
if c.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct")
}
if d.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign struct")
}
if e.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign struct")
}
{{- end}}
if b.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct")
}
if c.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct")
}
if d.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign struct")
}
if e.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship was not added properly to the foreign struct")
}
{{- end}}
if a.R.{{$rel.Function.Name}}[0] != &d {
t.Error("relationship struct slice not set to correct value")
}
if a.R.{{$rel.Function.Name}}[1] != &e {
t.Error("relationship struct slice not set to correct value")
}
if a.R.{{$rel.Function.Name}}[0] != &d {
t.Error("relationship struct slice not set to correct value")
}
if a.R.{{$rel.Function.Name}}[1] != &e {
t.Error("relationship struct slice not set to correct value")
}
}
func test{{$rel.LocalTable.NameGo}}ToManyRemoveOp{{$rel.Function.Name}}(t *testing.T) {
var err error
var err error
tx := MustTx(boil.Begin())
defer tx.Rollback()
tx := MustTx(boil.Begin())
defer tx.Rollback()
var a {{$rel.LocalTable.NameGo}}
var b, c, d, e {{$rel.ForeignTable.NameGo}}
var a {{$rel.LocalTable.NameGo}}
var b, c, d, e {{$rel.ForeignTable.NameGo}}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e}
for _, x := range foreigners {
if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e}
for _, x := range foreigners {
if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
}
if err := a.Insert(tx); err != nil {
t.Fatal(err)
}
if err := a.Insert(tx); err != nil {
t.Fatal(err)
}
err = a.Add{{$rel.Function.Name}}(tx, true, foreigners...)
if err != nil {
t.Fatal(err)
}
err = a.Add{{$rel.Function.Name}}(tx, true, foreigners...)
if err != nil {
t.Fatal(err)
}
count, err := a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if count != 4 {
t.Error("count was wrong:", count)
}
count, err := a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if count != 4 {
t.Error("count was wrong:", count)
}
err = a.Remove{{$rel.Function.Name}}(tx, foreigners[:2]...)
if err != nil {
t.Fatal(err)
}
err = a.Remove{{$rel.Function.Name}}(tx, foreigners[:2]...)
if err != nil {
t.Fatal(err)
}
count, err = a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Error("count was wrong:", count)
}
count, err = a.{{$rel.Function.Name}}(tx).Count()
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Error("count was wrong:", count)
}
{{- if .ToJoinTable}}
{{- if .ToJoinTable}}
if len(b.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice")
}
if len(c.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice")
}
if d.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the foreign struct")
}
if e.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the foreign struct")
}
{{- else}}
if len(b.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice")
}
if len(c.R.{{$rel.Function.ForeignName}}) != 0 {
t.Error("relationship was not removed properly from the slice")
}
if d.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the foreign struct")
}
if e.R.{{$rel.Function.ForeignName}}[0] != &a {
t.Error("relationship was not added properly to the foreign struct")
}
{{- else}}
if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want b's foreign key value to be nil")
}
if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want c's foreign key value to be nil")
}
if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want b's foreign key value to be nil")
}
if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid {
t.Error("want c's foreign key value to be nil")
}
if b.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct")
}
if c.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct")
}
if d.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship to a should have been preserved")
}
if e.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship to a should have been preserved")
}
{{- end}}
if b.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct")
}
if c.R.{{$rel.Function.ForeignName}} != nil {
t.Error("relationship was not removed properly from the foreign struct")
}
if d.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship to a should have been preserved")
}
if e.R.{{$rel.Function.ForeignName}} != &a {
t.Error("relationship to a should have been preserved")
}
{{- end}}
if len(a.R.{{$rel.Function.Name}}) != 2 {
t.Error("should have preserved two relationships")
}
if len(a.R.{{$rel.Function.Name}}) != 2 {
t.Error("should have preserved two relationships")
}
// Removal doesn't do a stable deletion for performance so we have to flip the order
if a.R.{{$rel.Function.Name}}[1] != &d {
t.Error("relationship to d should have been preserved")
}
if a.R.{{$rel.Function.Name}}[0] != &e {
t.Error("relationship to e should have been preserved")
}
// Removal doesn't do a stable deletion for performance so we have to flip the order
if a.R.{{$rel.Function.Name}}[1] != &d {
t.Error("relationship to d should have been preserved")
}
if a.R.{{$rel.Function.Name}}[0] != &e {
t.Error("relationship to e should have been preserved")
}
}
{{end -}}
{{- end -}}{{- /* if unique foreign key */ -}}

View file

@ -1,69 +1,69 @@
{{- define "relationship_to_one_test_helper"}}
func test{{.LocalTable.NameGo}}ToOne{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) {
tx := MustTx(boil.Begin())
defer tx.Rollback()
tx := MustTx(boil.Begin())
defer tx.Rollback()
var foreign {{.ForeignTable.NameGo}}
var local {{.LocalTable.NameGo}}
{{if .ForeignKey.Nullable -}}
local.{{.ForeignKey.Column | titleCase}}.Valid = true
{{end}}
{{- if .ForeignKey.ForeignColumnNullable -}}
foreign.{{.ForeignKey.ForeignColumn | titleCase}}.Valid = true
{{end}}
var foreign {{.ForeignTable.NameGo}}
var local {{.LocalTable.NameGo}}
{{if .ForeignKey.Nullable -}}
local.{{.ForeignKey.Column | titleCase}}.Valid = true
{{end}}
{{- if .ForeignKey.ForeignColumnNullable -}}
foreign.{{.ForeignKey.ForeignColumn | titleCase}}.Valid = true
{{end}}
{{if not .Function.OneToOne -}}
if err := foreign.Insert(tx); err != nil {
t.Fatal(err)
}
{{if not .Function.OneToOne -}}
if err := foreign.Insert(tx); err != nil {
t.Fatal(err)
}
local.{{.Function.LocalAssignment}} = foreign.{{.Function.ForeignAssignment}}
if err := local.Insert(tx); err != nil {
t.Fatal(err)
}
{{else -}}
if err := local.Insert(tx); err != nil {
t.Fatal(err)
}
local.{{.Function.LocalAssignment}} = foreign.{{.Function.ForeignAssignment}}
if err := local.Insert(tx); err != nil {
t.Fatal(err)
}
{{else -}}
if err := local.Insert(tx); err != nil {
t.Fatal(err)
}
foreign.{{.Function.ForeignAssignment}} = local.{{.Function.LocalAssignment}}
if err := foreign.Insert(tx); err != nil {
t.Fatal(err)
}
{{end -}}
foreign.{{.Function.ForeignAssignment}} = local.{{.Function.LocalAssignment}}
if err := foreign.Insert(tx); err != nil {
t.Fatal(err)
}
{{end -}}
check, err := local.{{.Function.Name}}(tx).One()
if err != nil {
t.Fatal(err)
}
check, err := local.{{.Function.Name}}(tx).One()
if err != nil {
t.Fatal(err)
}
if check.{{.Function.ForeignAssignment}} != foreign.{{.Function.ForeignAssignment}} {
t.Errorf("want: %v, got %v", foreign.{{.Function.ForeignAssignment}}, check.{{.Function.ForeignAssignment}})
}
if check.{{.Function.ForeignAssignment}} != foreign.{{.Function.ForeignAssignment}} {
t.Errorf("want: %v, got %v", foreign.{{.Function.ForeignAssignment}}, check.{{.Function.ForeignAssignment}})
}
slice := {{.LocalTable.NameGo}}Slice{&local}
if err = local.L.Load{{.Function.Name}}(tx, false, &slice); err != nil {
t.Fatal(err)
}
if local.R.{{.Function.Name}} == nil {
t.Error("struct should have been eager loaded")
}
slice := {{.LocalTable.NameGo}}Slice{&local}
if err = local.L.Load{{.Function.Name}}(tx, false, &slice); err != nil {
t.Fatal(err)
}
if local.R.{{.Function.Name}} == nil {
t.Error("struct should have been eager loaded")
}
local.R.{{.Function.Name}} = nil
if err = local.L.Load{{.Function.Name}}(tx, true, &local); err != nil {
t.Fatal(err)
}
if local.R.{{.Function.Name}} == nil {
t.Error("struct should have been eager loaded")
}
local.R.{{.Function.Name}} = nil
if err = local.L.Load{{.Function.Name}}(tx, true, &local); err != nil {
t.Fatal(err)
}
if local.R.{{.Function.Name}} == nil {
t.Error("struct should have been eager loaded")
}
}
{{end -}}
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}}
{{- template "relationship_to_one_test_helper" $rel -}}
{{end -}}
{{- end -}}

View file

@ -2,131 +2,131 @@
{{- $varNameSingular := .ForeignKey.Table | singular | camelCase -}}
{{- $foreignVarNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}}
func test{{.LocalTable.NameGo}}ToOneSetOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) {
var err error
var err error
tx := MustTx(boil.Begin())
defer tx.Rollback()
tx := MustTx(boil.Begin())
defer tx.Rollback()
var a {{.LocalTable.NameGo}}
var b, c {{.ForeignTable.NameGo}}
var a {{.LocalTable.NameGo}}
var b, c {{.ForeignTable.NameGo}}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
if err = randomize.Struct(seed, &c, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
if err = randomize.Struct(seed, &c, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
if err := a.Insert(tx); err != nil {
t.Fatal(err)
}
if err = b.Insert(tx); err != nil {
t.Fatal(err)
}
if err := a.Insert(tx); err != nil {
t.Fatal(err)
}
if err = b.Insert(tx); err != nil {
t.Fatal(err)
}
for i, x := range []*{{.ForeignTable.NameGo}}{&b, &c} {
err = a.Set{{.Function.Name}}(tx, i != 0, x)
if err != nil {
t.Fatal(err)
}
for i, x := range []*{{.ForeignTable.NameGo}}{&b, &c} {
err = a.Set{{.Function.Name}}(tx, i != 0, x)
if err != nil {
t.Fatal(err)
}
if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}})
}
if a.R.{{.Function.Name}} != x {
t.Error("relationship struct not set to correct value")
}
if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}})
}
if a.R.{{.Function.Name}} != x {
t.Error("relationship struct not set to correct value")
}
zero := reflect.Zero(reflect.TypeOf(a.{{.Function.LocalAssignment}}))
reflect.Indirect(reflect.ValueOf(&a.{{.Function.LocalAssignment}})).Set(zero)
zero := reflect.Zero(reflect.TypeOf(a.{{.Function.LocalAssignment}}))
reflect.Indirect(reflect.ValueOf(&a.{{.Function.LocalAssignment}})).Set(zero)
if err = a.Reload(tx); err != nil {
t.Fatal("failed to reload", err)
}
if err = a.Reload(tx); err != nil {
t.Fatal("failed to reload", err)
}
if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}, x.{{.Function.ForeignAssignment}})
}
if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} {
t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}, x.{{.Function.ForeignAssignment}})
}
{{if .ForeignKey.Unique -}}
if x.R.{{.Function.ForeignName}} != &a {
t.Error("failed to append to foreign relationship struct")
}
{{else -}}
if x.R.{{.Function.ForeignName}}[0] != &a {
t.Error("failed to append to foreign relationship struct")
}
{{end -}}
}
{{if .ForeignKey.Unique -}}
if x.R.{{.Function.ForeignName}} != &a {
t.Error("failed to append to foreign relationship struct")
}
{{else -}}
if x.R.{{.Function.ForeignName}}[0] != &a {
t.Error("failed to append to foreign relationship struct")
}
{{end -}}
}
}
{{- if .ForeignKey.Nullable}}
func test{{.LocalTable.NameGo}}ToOneRemoveOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) {
var err error
var err error
tx := MustTx(boil.Begin())
defer tx.Rollback()
tx := MustTx(boil.Begin())
defer tx.Rollback()
var a {{.LocalTable.NameGo}}
var b {{.ForeignTable.NameGo}}
var a {{.LocalTable.NameGo}}
var b {{.ForeignTable.NameGo}}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
seed := randomize.NewSeed()
if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil {
t.Fatal(err)
}
if err = a.Insert(tx); err != nil {
t.Fatal(err)
}
if err = a.Insert(tx); err != nil {
t.Fatal(err)
}
if err = a.Set{{.Function.Name}}(tx, true, &b); err != nil {
t.Fatal(err)
}
if err = a.Set{{.Function.Name}}(tx, true, &b); err != nil {
t.Fatal(err)
}
if err = a.Remove{{.Function.Name}}(tx, &b); err != nil {
t.Error("failed to remove relationship")
}
if err = a.Remove{{.Function.Name}}(tx, &b); err != nil {
t.Error("failed to remove relationship")
}
count, err := a.{{.Function.Name}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 0 {
t.Error("want no relationships remaining")
}
count, err := a.{{.Function.Name}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 0 {
t.Error("want no relationships remaining")
}
if a.R.{{.Function.Name}} != nil {
t.Error("R struct entry should be nil")
}
if a.R.{{.Function.Name}} != nil {
t.Error("R struct entry should be nil")
}
if a.{{.LocalTable.ColumnNameGo}}.Valid {
t.Error("R struct entry should be nil")
}
if a.{{.LocalTable.ColumnNameGo}}.Valid {
t.Error("R struct entry should be nil")
}
{{if .ForeignKey.Unique -}}
if b.R.{{.Function.ForeignName}} != nil {
t.Error("failed to remove a from b's relationships")
}
{{else -}}
if len(b.R.{{.Function.ForeignName}}) != 0 {
t.Error("failed to remove a from b's relationships")
}
{{end -}}
{{if .ForeignKey.Unique -}}
if b.R.{{.Function.ForeignName}} != nil {
t.Error("failed to remove a from b's relationships")
}
{{else -}}
if len(b.R.{{.Function.ForeignName}}) != 0 {
t.Error("failed to remove a from b's relationships")
}
{{end -}}
}
{{end -}}
{{- end -}}
{{- if .Table.IsJoinTable -}}
{{- else -}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table .}}
{{- $dot := . -}}
{{- range .Table.FKeys -}}
{{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table .}}
{{template "relationship_to_one_setops_test_helper" $rel -}}
{{- end -}}

View file

@ -3,45 +3,45 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Reload(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
if err = {{$varNameSingular}}.Reload(tx); err != nil {
t.Error(err)
}
if err = {{$varNameSingular}}.Reload(tx); err != nil {
t.Error(err)
}
}
func test{{$tableNamePlural}}ReloadAll(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}}
slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}}
if err = slice.ReloadAll(tx); err != nil {
t.Error(err)
}
if err = slice.ReloadAll(tx); err != nil {
t.Error(err)
}
}

View file

@ -3,27 +3,27 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Select(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
slice, err := {{$tableNamePlural}}(tx).All()
if err != nil {
t.Error(err)
}
slice, err := {{$tableNamePlural}}(tx).All()
if err != nil {
t.Error(err)
}
if len(slice) != 1 {
t.Error("want one record, got:", len(slice))
}
if len(slice) != 1 {
t.Error("want one record, got:", len(slice))
}
}

View file

@ -0,0 +1,131 @@
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
var (
dbMain tester
)
type tester interface {
setup() error
conn() (*sql.DB, error)
teardown() error
}
func TestMain(m *testing.M) {
if dbMain == nil {
fmt.Println("no dbMain tester interface was ready")
os.Exit(-1)
}
rand.Seed(time.Now().UnixNano())
var err error
// Load configuration
err = initViper()
if err != nil {
fmt.Println("unable to load config file")
os.Exit(-2)
}
setConfigDefaults()
if err := validateConfig("{{.DriverName}}"); err != nil {
fmt.Println("failed to validate config", err)
os.Exit(-3)
}
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
if err = dbMain.setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-4)
}
conn, err := dbMain.conn()
if err != nil {
fmt.Println("failed to get connection:", err)
}
var code int
boil.SetDB(conn)
code = m.Run()
if err = dbMain.teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-5)
}
os.Exit(code)
}
func initViper() error {
var err error
viper.SetConfigName("sqlboiler")
configHome := os.Getenv("XDG_CONFIG_HOME")
homePath := os.Getenv("HOME")
wd, err := os.Getwd()
if err != nil {
wd = "../"
} else {
wd = wd + "/.."
}
configPaths := []string{wd}
if len(configHome) > 0 {
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
} else {
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
}
for _, p := range configPaths {
viper.AddConfigPath(p)
}
// Ignore errors here, fall back to defaults and validation to provide errs
_ = viper.ReadInConfig()
viper.AutomaticEnv()
return nil
}
// setConfigDefaults is only necessary because of bugs in viper, noted in main
func setConfigDefaults() {
if viper.GetString("postgres.sslmode") == "" {
viper.Set("postgres.sslmode", "require")
}
if viper.GetInt("postgres.port") == 0 {
viper.Set("postgres.port", 5432)
}
if viper.GetString("mysql.sslmode") == "" {
viper.Set("mysql.sslmode", "true")
}
if viper.GetInt("mysql.port") == 0 {
viper.Set("mysql.port", 3306)
}
}
func validateConfig(driverName string) error {
if driverName == "postgres" {
return vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
).Check()
}
if driverName == "mysql" {
return vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
).Check()
}
return errors.New("not a valid driver name")
}

View file

@ -7,53 +7,31 @@ func MustTx(transactor boil.Transactor, err error) boil.Transactor {
return transactor
}
func initDBNameRand(input string) {
sum := md5.Sum([]byte(input))
var rgxPGFkey = regexp.MustCompile(`(?m)^ALTER TABLE ONLY .*\n\s+ADD CONSTRAINT .*? FOREIGN KEY .*?;\n`)
var rgxMySQLkey = regexp.MustCompile(`(?m)((,\n)?\s+CONSTRAINT.*?FOREIGN KEY.*?\n)+`)
var sumInt string
for _, v := range sum {
sumInt = sumInt + strconv.Itoa(int(v))
func newFKeyDestroyer(regex *regexp.Regexp, reader io.Reader) io.Reader {
return &fKeyDestroyer{
reader: reader,
rgx: regex,
}
}
// Cut integer to 18 digits to ensure no int64 overflow.
sumInt = sumInt[:18]
type fKeyDestroyer struct {
reader io.Reader
buf *bytes.Buffer
rgx *regexp.Regexp
}
sumTmp := sumInt
for i, v := range sumInt {
if v == '0' {
sumTmp = sumInt[i+1:]
continue
func (f *fKeyDestroyer) Read(b []byte) (int, error) {
if f.buf == nil {
all, err := ioutil.ReadAll(f.reader)
if err != nil {
return 0, err
}
break
f.buf = bytes.NewBuffer(f.rgx.ReplaceAll(all, []byte{}))
}
sumInt = sumTmp
randSeed, err := strconv.ParseInt(sumInt, 0, 64)
if err != nil {
fmt.Printf("Unable to parse sumInt: %s", err)
os.Exit(-1)
}
dbNameRand = rand.New(rand.NewSource(randSeed))
}
var alphabetChars = "abcdefghijklmnopqrstuvwxyz"
func randStr(length int) string {
c := len(alphabetChars)
output := make([]rune, length)
for i := 0; i < length; i++ {
output[i] = rune(alphabetChars[dbNameRand.Intn(c)])
}
return string(output)
}
// getDBNameHash takes a database name in, and generates
// a random string using the database name as the rand Seed.
// getDBNameHash is used to generate unique test database names.
func getDBNameHash(input string) string {
initDBNameRand(input)
return randStr(40)
return f.buf.Read(b)
}

View file

@ -1,37 +0,0 @@
var (
testCfg *Config
dbConn *sql.DB
)
func InitViper() error {
var err error
testCfg = &Config{}
viper.SetConfigName("sqlboiler")
configHome := os.Getenv("XDG_CONFIG_HOME")
homePath := os.Getenv("HOME")
wd, err := os.Getwd()
if err != nil {
wd = "../"
} else {
wd = wd + "/.."
}
configPaths := []string{wd}
if len(configHome) > 0 {
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
} else {
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
}
for _, p := range configPaths {
viper.AddConfigPath(p)
}
// Ignore errors here, fall back to defaults and validation to provide errs
_ = viper.ReadInConfig()
viper.AutomaticEnv()
return nil
}

View file

@ -3,97 +3,97 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Update(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 1 {
t.Error("want one record, got:", count)
}
if count != 1 {
t.Error("want one record, got:", count)
}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
// If table only contains primary key columns, we need to pass
// them into a whitelist to get a valid test result,
// otherwise the Update method will error because it will not be able to
// generate a whitelist (due to it excluding primary key columns).
if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) {
if err = {{$varNameSingular}}.Update(tx, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Error(err)
}
} else {
if err = {{$varNameSingular}}.Update(tx); err != nil {
t.Error(err)
}
}
// If table only contains primary key columns, we need to pass
// them into a whitelist to get a valid test result,
// otherwise the Update method will error because it will not be able to
// generate a whitelist (due to it excluding primary key columns).
if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) {
if err = {{$varNameSingular}}.Update(tx, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Error(err)
}
} else {
if err = {{$varNameSingular}}.Update(tx); err != nil {
t.Error(err)
}
}
}
func test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) {
t.Parallel()
t.Parallel()
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
{{$varNameSingular}} := &{{$tableNameSingular}}{}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Insert(tx); err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 1 {
t.Error("want one record, got:", count)
}
if count != 1 {
t.Error("want one record, got:", count)
}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
// Remove Primary keys and unique columns from what we plan to update
var fields []string
if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) {
fields = {{$varNameSingular}}Columns
} else {
fields = strmangle.SetComplement(
{{$varNameSingular}}Columns,
{{$varNameSingular}}PrimaryKeyColumns,
)
}
// Remove Primary keys and unique columns from what we plan to update
var fields []string
if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) {
fields = {{$varNameSingular}}Columns
} else {
fields = strmangle.SetComplement(
{{$varNameSingular}}Columns,
{{$varNameSingular}}PrimaryKeyColumns,
)
}
value := reflect.Indirect(reflect.ValueOf({{$varNameSingular}}))
updateMap := M{}
for _, col := range fields {
updateMap[col] = value.FieldByName(strmangle.TitleCase(col)).Interface()
}
updateMap := M{}
for _, col := range fields {
updateMap[col] = value.FieldByName(strmangle.TitleCase(col)).Interface()
}
slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}}
if err = slice.UpdateAll(tx, updateMap); err != nil {
t.Error(err)
}
slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}}
if err = slice.UpdateAll(tx, updateMap); err != nil {
t.Error(err)
}
}

View file

@ -3,44 +3,47 @@
{{- $varNamePlural := .Table.Name | plural | camelCase -}}
{{- $varNameSingular := .Table.Name | singular | camelCase -}}
func test{{$tableNamePlural}}Upsert(t *testing.T) {
t.Parallel()
{{if not (eq .DriverName "postgres") -}}
t.Skip("not implemented for {{.DriverName}}")
{{end -}}
t.Parallel()
seed := randomize.NewSeed()
var err error
// Attempt the INSERT side of an UPSERT
{{$varNameSingular}} := {{$tableNameSingular}}{}
if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
seed := randomize.NewSeed()
var err error
// Attempt the INSERT side of an UPSERT
{{$varNameSingular}} := {{$tableNameSingular}}{}
if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Upsert(tx, false, nil, nil); err != nil {
t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err)
}
tx := MustTx(boil.Begin())
defer tx.Rollback()
if err = {{$varNameSingular}}.Upsert(tx, {{if eq .DriverName "postgres"}}false, nil, {{end}}nil); err != nil {
t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 1 {
t.Error("want one record, got:", count)
}
count, err := {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 1 {
t.Error("want one record, got:", count)
}
// Attempt the UPDATE side of an UPSERT
if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
// Attempt the UPDATE side of an UPSERT
if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil {
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err)
}
if err = {{$varNameSingular}}.Upsert(tx, true, nil, nil); err != nil {
t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err)
}
if err = {{$varNameSingular}}.Upsert(tx, {{if eq .DriverName "postgres"}}true, nil, {{end}}nil); err != nil {
t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err)
}
count, err = {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 1 {
t.Error("want one record, got:", count)
}
count, err = {{$tableNamePlural}}(tx).Count()
if err != nil {
t.Error(err)
}
if count != 1 {
t.Error("want one record, got:", count)
}
}

View file

@ -93,7 +93,46 @@ CREATE TABLE magic (
strange_three timestamp without time zone default (now() at time zone 'utc'),
strange_four timestamp with time zone default (now() at time zone 'utc'),
strange_five interval NOT NULL DEFAULT '21 days',
strange_six interval NULL DEFAULT '23 hours'
strange_six interval NULL DEFAULT '23 hours',
aa json NULL,
bb json NOT NULL,
cc jsonb NULL,
dd jsonb NOT NULL,
ee box NULL,
ff box NOT NULL,
gg cidr NULL,
hh cidr NOT NULL,
ii circle NULL,
jj circle NOT NULL,
kk double precision NULL,
ll double precision NOT NULL,
mm inet NULL,
nn inet NOT NULL,
oo line NULL,
pp line NOT NULL,
qq lseg NULL,
rr lseg NOT NULL,
ss macaddr NULL,
tt macaddr NOT NULL,
uu money NULL,
vv money NOT NULL,
ww path NULL,
xx path NOT NULL,
yy pg_lsn NULL,
zz pg_lsn NOT NULL,
aaa point NULL,
bbb point NOT NULL,
ccc polygon NULL,
ddd polygon NOT NULL,
eee tsquery NULL,
fff tsquery NOT NULL,
ggg tsvector NULL,
hhh tsvector NOT NULL,
iii txid_snapshot NULL,
jjj txid_snapshot NOT NULL,
kkk xml NULL,
lll xml NOT NULL
);
create table owner (
@ -136,12 +175,6 @@ create table spider_toys (
primary key (spider_id)
);
/*
Test:
* Variations of capitalization
* Single value columns
* Primary key as only value
*/
create table pals (
pal character varying,
primary key (pal)
@ -161,3 +194,22 @@ create table enemies (
enemies character varying,
primary key (enemies)
);
create table fun_arrays (
id serial,
fun_one integer[] null,
fun_two integer[] not null,
fun_three boolean[] null,
fun_four boolean[] not null,
fun_five varchar[] null,
fun_six varchar[] not null,
fun_seven decimal[] null,
fun_eight decimal[] not null,
fun_nine bytea[] null,
fun_ten bytea[] not null,
fun_eleven jsonb[] null,
fun_twelve jsonb[] not null,
fun_thirteen json[] null,
fun_fourteen json[] not null,
primary key (id)
)

View file

@ -12,7 +12,7 @@ import (
func TestTextsFromForeignKey(t *testing.T) {
t.Parallel()
tables, err := bdb.Tables(&drivers.MockDriver{})
tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil)
if err != nil {
t.Fatal(err)
}
@ -81,7 +81,7 @@ func TestTextsFromForeignKey(t *testing.T) {
func TestTextsFromOneToOneRelationship(t *testing.T) {
t.Parallel()
tables, err := bdb.Tables(&drivers.MockDriver{})
tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil)
if err != nil {
t.Fatal(err)
}
@ -130,7 +130,7 @@ func TestTextsFromOneToOneRelationship(t *testing.T) {
func TestTextsFromRelationship(t *testing.T) {
t.Parallel()
tables, err := bdb.Tables(&drivers.MockDriver{})
tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil)
if err != nil {
t.Fatal(err)
}

719
types/array.go Normal file
View file

@ -0,0 +1,719 @@
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package types
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding/hex"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
var typeByteSlice = reflect.TypeOf([]byte{})
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
func encode(x interface{}) []byte {
switch v := x.(type) {
case int64:
return strconv.AppendInt(nil, v, 10)
case float64:
return strconv.AppendFloat(nil, v, 'f', -1, 64)
case []byte:
return encodeBytes(v)
case string:
return []byte(v)
case bool:
return strconv.AppendBool(nil, v)
case time.Time:
return formatTimestamp(v)
default:
panic(fmt.Errorf("encode: unknown type for %T", v))
}
}
// FormatTimestamp formats t into Postgres' text format for timestamps.
func formatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go.
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
bc := false
if t.Year() <= 0 {
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
t = t.AddDate((-t.Year())*2+1, 0, 0)
bc = true
}
b := []byte(t.Format(time.RFC3339Nano))
_, offset := t.Zone()
offset = offset % 60
if offset != 0 {
// RFC3339Nano already printed the minus sign
if offset < 0 {
offset = -offset
}
b = append(b, ':')
if offset < 10 {
b = append(b, '0')
}
b = strconv.AppendInt(b, int64(offset), 10)
}
if bc {
b = append(b, " BC"...)
}
return b
}
func encodeBytes(v []byte) (result []byte) {
for _, b := range v {
if b == '\\' {
result = append(result, '\\', '\\')
} else if b < 0x20 || b > 0x7e {
result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
} else {
result = append(result, b)
}
}
return result
}
// Parse a bytea value received from the server. Both "hex" and the legacy
// "escape" format are supported.
func parseBytes(s []byte) (result []byte, err error) {
if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
// bytea_output = hex
s = s[2:] // trim off leading "\\x"
result = make([]byte, hex.DecodedLen(len(s)))
_, err := hex.Decode(result, s)
if err != nil {
return nil, err
}
} else {
for len(s) > 0 {
if s[0] == '\\' {
// escaped '\\'
if len(s) >= 2 && s[1] == '\\' {
result = append(result, '\\')
s = s[2:]
continue
}
// '\\' followed by an octal number
if len(s) < 4 {
return nil, fmt.Errorf("invalid bytea sequence %v", s)
}
r, err := strconv.ParseInt(string(s[1:4]), 8, 9)
if err != nil {
return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
}
result = append(result, byte(r))
s = s[4:]
} else {
// We hit an unescaped, raw byte. Try to read in as many as
// possible in one go.
i := bytes.IndexByte(s, '\\')
if i == -1 {
result = append(result, s...)
break
}
result = append(result, s[:i]...)
s = s[i:]
}
}
}
return result, nil
}
// Array returns the optimal driver.Valuer and sql.Scanner for an array or
// slice of any dimension.
//
// For example:
// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
//
// var x []sql.NullInt64
// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x))
//
// Scanning multi-dimensional arrays is not supported. Arrays where the lower
// bound is not one (such as `[0:0]={1}') are not supported.
func Array(a interface{}) interface {
driver.Valuer
sql.Scanner
} {
switch a := a.(type) {
case []bool:
return (*BoolArray)(&a)
case []float64:
return (*Float64Array)(&a)
case []int64:
return (*Int64Array)(&a)
case []string:
return (*StringArray)(&a)
case *[]bool:
return (*BoolArray)(a)
case *[]float64:
return (*Float64Array)(a)
case *[]int64:
return (*Int64Array)(a)
case *[]string:
return (*StringArray)(a)
default:
panic(fmt.Sprintf("boil: invalid type received %T", a))
}
}
// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner
// to override the array delimiter used by GenericArray.
type ArrayDelimiter interface {
// ArrayDelimiter returns the delimiter character(s) for this element's type.
ArrayDelimiter() string
}
// BoolArray represents a one-dimensional array of the PostgreSQL boolean type.
type BoolArray []bool
// Scan implements the sql.Scanner interface.
func (a *BoolArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to BoolArray", src)
}
func (a *BoolArray) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "BoolArray")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(BoolArray, len(elems))
for i, v := range elems {
if len(v) != 1 {
return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v)
}
switch v[0] {
case 't':
b[i] = true
case 'f':
b[i] = false
default:
return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a BoolArray) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be exactly two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1+2*n)
for i := 0; i < n; i++ {
b[2*i] = ','
if a[i] {
b[1+2*i] = 't'
} else {
b[1+2*i] = 'f'
}
}
b[0] = '{'
b[2*n] = '}'
return string(b), nil
}
return "{}", nil
}
// BytesArray represents a one-dimensional array of the PostgreSQL bytea type.
type BytesArray [][]byte
// Scan implements the sql.Scanner interface.
func (a *BytesArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to BytesArray", src)
}
func (a *BytesArray) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "BytesArray")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(BytesArray, len(elems))
for i, v := range elems {
b[i], err = parseBytes(v)
if err != nil {
return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface. It uses the "hex" format which
// is only supported on PostgreSQL 9.0 or newer.
func (a BytesArray) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, 2*N bytes of quotes,
// 3*N bytes of hex formatting, and N-1 bytes of delimiters.
size := 1 + 6*n
for _, x := range a {
size += hex.EncodedLen(len(x))
}
b := make([]byte, size)
for i, s := 0, b; i < n; i++ {
o := copy(s, `,"\\x`)
o += hex.Encode(s[o:], a[i])
s[o] = '"'
s = s[o+1:]
}
b[0] = '{'
b[size-1] = '}'
return string(b), nil
}
return "{}", nil
}
// Float64Array represents a one-dimensional array of the PostgreSQL double
// precision type.
type Float64Array []float64
// Scan implements the sql.Scanner interface.
func (a *Float64Array) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to Float64Array", src)
}
func (a *Float64Array) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Float64Array")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Float64Array, len(elems))
for i, v := range elems {
if b[i], err = strconv.ParseFloat(string(v), 64); err != nil {
return fmt.Errorf("boil: parsing array element index %d: %v", i, err)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Float64Array) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendFloat(b, a[0], 'f', -1, 64)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendFloat(b, a[i], 'f', -1, 64)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
type Int64Array []int64
// Scan implements the sql.Scanner interface.
func (a *Int64Array) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to Int64Array", src)
}
func (a *Int64Array) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Int64Array")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Int64Array, len(elems))
for i, v := range elems {
if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil {
return fmt.Errorf("boil: parsing array element index %d: %v", i, err)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Int64Array) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendInt(b, a[0], 10)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendInt(b, a[i], 10)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// StringArray represents a one-dimensional array of the PostgreSQL character types.
type StringArray []string
// Scan implements the sql.Scanner interface.
func (a *StringArray) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
}
return fmt.Errorf("boil: cannot convert %T to StringArray", src)
}
func (a *StringArray) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "StringArray")
if err != nil {
return err
}
if len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(StringArray, len(elems))
for i, v := range elems {
if b[i] = string(v); v == nil {
return fmt.Errorf("boil: parsing array element index %d: cannot convert nil to string", i)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a StringArray) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, 2*N bytes of quotes,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+3*n)
b[0] = '{'
b = appendArrayQuotedBytes(b, []byte(a[0]))
for i := 1; i < n; i++ {
b = append(b, ',')
b = appendArrayQuotedBytes(b, []byte(a[i]))
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// appendArray appends rv to the buffer, returning the extended buffer and
// the delimiter used between elements.
//
// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice.
func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
var del string
var err error
b = append(b, '{')
if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil {
return b, del, err
}
for i := 1; i < n; i++ {
b = append(b, del...)
if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil {
return b, del, err
}
}
return append(b, '}'), del, nil
}
// appendArrayElement appends rv to the buffer, returning the extended buffer
// and the delimiter to use before the next element.
//
// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted
// using driver.DefaultParameterConverter and the resulting []byte or string
// is double-quoted.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
if k := rv.Kind(); k == reflect.Array || k == reflect.Slice {
if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
if n := rv.Len(); n > 0 {
return appendArray(b, rv, n)
}
return b, "", nil
}
}
var del = ","
var err error
var iv interface{} = rv.Interface()
if ad, ok := iv.(ArrayDelimiter); ok {
del = ad.ArrayDelimiter()
}
if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil {
return b, del, err
}
switch v := iv.(type) {
case nil:
return append(b, "NULL"...), del, nil
case []byte:
return appendArrayQuotedBytes(b, v), del, nil
case string:
return appendArrayQuotedBytes(b, []byte(v)), del, nil
}
b, err = appendValue(b, iv)
return b, del, err
}
func appendArrayQuotedBytes(b, v []byte) []byte {
b = append(b, '"')
for {
i := bytes.IndexAny(v, `"\`)
if i < 0 {
b = append(b, v...)
break
}
if i > 0 {
b = append(b, v[:i]...)
}
b = append(b, '\\', v[i])
v = v[i+1:]
}
return append(b, '"')
}
func appendValue(b []byte, v driver.Value) ([]byte, error) {
return append(b, encode(v)...), nil
}
// parseArray extracts the dimensions and elements of an array represented in
// text format. Only representations emitted by the backend are supported.
// Notably, whitespace around brackets and delimiters is significant, and NULL
// is case-sensitive.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
var depth, i int
if len(src) < 1 || src[0] != '{' {
return nil, nil, fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '{', 0)
}
Open:
for i < len(src) {
switch src[i] {
case '{':
depth++
i++
case '}':
elems = make([][]byte, 0)
goto Close
default:
break Open
}
}
dims = make([]int, i)
Element:
for i < len(src) {
switch src[i] {
case '{':
depth++
dims[depth-1] = 0
i++
case '"':
var elem = []byte{}
var escape bool
for i++; i < len(src); i++ {
if escape {
elem = append(elem, src[i])
escape = false
} else {
switch src[i] {
default:
elem = append(elem, src[i])
case '\\':
escape = true
case '"':
elems = append(elems, elem)
i++
break Element
}
}
}
default:
for start := i; i < len(src); i++ {
if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
elem := src[start:i]
if len(elem) == 0 {
return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i)
}
if bytes.Equal(elem, []byte("NULL")) {
elem = nil
}
elems = append(elems, elem)
break Element
}
}
}
}
for i < len(src) {
if bytes.HasPrefix(src[i:], del) {
dims[depth-1]++
i += len(del)
goto Element
} else if src[i] == '}' {
dims[depth-1]++
depth--
i++
} else {
return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
Close:
for i < len(src) {
if src[i] == '}' && depth > 0 {
depth--
i++
} else {
return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
if depth > 0 {
err = fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '}', i)
}
if err == nil {
for _, d := range dims {
if (len(elems) % d) != 0 {
err = fmt.Errorf("boil: multidimensional arrays must have elements with matching dimensions")
}
}
}
return
}
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
dims, elems, err := parseArray(src, del)
if err != nil {
return nil, err
}
if len(dims) > 1 {
return nil, fmt.Errorf("boil: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
}
return elems, err
}

800
types/array_test.go Normal file
View file

@ -0,0 +1,800 @@
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package types
import (
"database/sql"
"database/sql/driver"
"math/rand"
"reflect"
"strings"
"testing"
)
func TestParseArray(t *testing.T) {
for _, tt := range []struct {
input string
delim string
dims []int
elems [][]byte
}{
{`{}`, `,`, nil, [][]byte{}},
{`{NULL}`, `,`, []int{1}, [][]byte{nil}},
{`{a}`, `,`, []int{1}, [][]byte{{'a'}}},
{`{a,b}`, `,`, []int{2}, [][]byte{{'a'}, {'b'}}},
{`{{a,b}}`, `,`, []int{1, 2}, [][]byte{{'a'}, {'b'}}},
{`{{a},{b}}`, `,`, []int{2, 1}, [][]byte{{'a'}, {'b'}}},
{`{{{a,b},{c,d},{e,f}}}`, `,`, []int{1, 3, 2}, [][]byte{
{'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
}},
{`{""}`, `,`, []int{1}, [][]byte{{}}},
{`{","}`, `,`, []int{1}, [][]byte{{','}}},
{`{",",","}`, `,`, []int{2}, [][]byte{{','}, {','}}},
{`{{",",","}}`, `,`, []int{1, 2}, [][]byte{{','}, {','}}},
{`{{","},{","}}`, `,`, []int{2, 1}, [][]byte{{','}, {','}}},
{`{{{",",","},{",",","},{",",","}}}`, `,`, []int{1, 3, 2}, [][]byte{
{','}, {','}, {','}, {','}, {','}, {','},
}},
{`{"\"}"}`, `,`, []int{1}, [][]byte{{'"', '}'}}},
{`{"\"","\""}`, `,`, []int{2}, [][]byte{{'"'}, {'"'}}},
{`{{"\"","\""}}`, `,`, []int{1, 2}, [][]byte{{'"'}, {'"'}}},
{`{{"\""},{"\""}}`, `,`, []int{2, 1}, [][]byte{{'"'}, {'"'}}},
{`{{{"\"","\""},{"\"","\""},{"\"","\""}}}`, `,`, []int{1, 3, 2}, [][]byte{
{'"'}, {'"'}, {'"'}, {'"'}, {'"'}, {'"'},
}},
{`{axyzb}`, `xyz`, []int{2}, [][]byte{{'a'}, {'b'}}},
} {
dims, elems, err := parseArray([]byte(tt.input), []byte(tt.delim))
if err != nil {
t.Fatalf("Expected no error for %q, got %q", tt.input, err)
}
if !reflect.DeepEqual(dims, tt.dims) {
t.Errorf("Expected %v dimensions for %q, got %v", tt.dims, tt.input, dims)
}
if !reflect.DeepEqual(elems, tt.elems) {
t.Errorf("Expected %v elements for %q, got %v", tt.elems, tt.input, elems)
}
}
}
func TestParseArrayError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "expected '{' at offset 0"},
{`x`, "expected '{' at offset 0"},
{`}`, "expected '{' at offset 0"},
{`{`, "expected '}' at offset 1"},
{`{{}`, "expected '}' at offset 3"},
{`{}}`, "unexpected '}' at offset 2"},
{`{,}`, "unexpected ',' at offset 1"},
{`{,x}`, "unexpected ',' at offset 1"},
{`{x,}`, "unexpected '}' at offset 3"},
{`{""x}`, "unexpected 'x' at offset 3"},
{`{{a},{b,c}}`, "multidimensional arrays must have elements with matching dimensions"},
} {
_, _, err := parseArray([]byte(tt.input), []byte{','})
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
}
}
func TestArrayScanner(t *testing.T) {
var s sql.Scanner
s = Array(&[]bool{})
if _, ok := s.(*BoolArray); !ok {
t.Errorf("Expected *BoolArray, got %T", s)
}
s = Array(&[]float64{})
if _, ok := s.(*Float64Array); !ok {
t.Errorf("Expected *Float64Array, got %T", s)
}
s = Array(&[]int64{})
if _, ok := s.(*Int64Array); !ok {
t.Errorf("Expected *Int64Array, got %T", s)
}
s = Array(&[]string{})
if _, ok := s.(*StringArray); !ok {
t.Errorf("Expected *StringArray, got %T", s)
}
}
func TestArrayValuer(t *testing.T) {
var v driver.Valuer
v = Array([]bool{})
if _, ok := v.(*BoolArray); !ok {
t.Errorf("Expected *BoolArray, got %T", v)
}
v = Array([]float64{})
if _, ok := v.(*Float64Array); !ok {
t.Errorf("Expected *Float64Array, got %T", v)
}
v = Array([]int64{})
if _, ok := v.(*Int64Array); !ok {
t.Errorf("Expected *Int64Array, got %T", v)
}
v = Array([]string{})
if _, ok := v.(*StringArray); !ok {
t.Errorf("Expected *StringArray, got %T", v)
}
}
func TestBoolArrayScanUnsupported(t *testing.T) {
var arr BoolArray
err := arr.Scan(1)
if err == nil {
t.Fatal("Expected error when scanning from int")
}
if !strings.Contains(err.Error(), "int to BoolArray") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var BoolArrayStringTests = []struct {
str string
arr BoolArray
}{
{`{}`, BoolArray{}},
{`{t}`, BoolArray{true}},
{`{f,t}`, BoolArray{false, true}},
}
func TestBoolArrayScanBytes(t *testing.T) {
for _, tt := range BoolArrayStringTests {
bytes := []byte(tt.str)
arr := BoolArray{true, true, true}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkBoolArrayScanBytes(b *testing.B) {
var a BoolArray
var x interface{} = []byte(`{t,f,t,f,t,f,t,f,t,f}`)
for i := 0; i < b.N; i++ {
a = BoolArray{}
a.Scan(x)
}
}
func TestBoolArrayScanString(t *testing.T) {
for _, tt := range BoolArrayStringTests {
arr := BoolArray{true, true, true}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestBoolArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{t},{f}}`, "cannot convert ARRAY[2][1] to BoolArray"},
{`{NULL}`, `could not parse boolean array index 0: invalid boolean ""`},
{`{a}`, `could not parse boolean array index 0: invalid boolean "a"`},
{`{t,b}`, `could not parse boolean array index 1: invalid boolean "b"`},
{`{t,f,cd}`, `could not parse boolean array index 2: invalid boolean "cd"`},
} {
arr := BoolArray{true, true, true}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, BoolArray{true, true, true}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestBoolArrayValue(t *testing.T) {
result, err := BoolArray(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = BoolArray([]bool{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = BoolArray([]bool{false, true, false}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{f,t,f}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkBoolArrayValue(b *testing.B) {
rand.Seed(1)
x := make([]bool, 10)
for i := 0; i < len(x); i++ {
x[i] = rand.Intn(2) == 0
}
a := BoolArray(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}
func TestBytesArrayScanUnsupported(t *testing.T) {
var arr BytesArray
err := arr.Scan(1)
if err == nil {
t.Fatal("Expected error when scanning from int")
}
if !strings.Contains(err.Error(), "int to BytesArray") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var BytesArrayStringTests = []struct {
str string
arr BytesArray
}{
{`{}`, BytesArray{}},
{`{NULL}`, BytesArray{nil}},
{`{"\\xfeff"}`, BytesArray{{'\xFE', '\xFF'}}},
{`{"\\xdead","\\xbeef"}`, BytesArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}},
}
func TestBytesArrayScanBytes(t *testing.T) {
for _, tt := range BytesArrayStringTests {
bytes := []byte(tt.str)
arr := BytesArray{{2}, {6}, {0, 0}}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkBytesArrayScanBytes(b *testing.B) {
var a BytesArray
var x interface{} = []byte(`{"\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff"}`)
for i := 0; i < b.N; i++ {
a = BytesArray{}
a.Scan(x)
}
}
func TestBytesArrayScanString(t *testing.T) {
for _, tt := range BytesArrayStringTests {
arr := BytesArray{{2}, {6}, {0, 0}}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestBytesArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{"\\xfeff"},{"\\xbeef"}}`, "cannot convert ARRAY[2][1] to BytesArray"},
{`{"\\abc"}`, "could not parse bytea array index 0: could not parse bytea value"},
} {
arr := BytesArray{{2}, {6}, {0, 0}}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, BytesArray{{2}, {6}, {0, 0}}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestBytesArrayValue(t *testing.T) {
result, err := BytesArray(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = BytesArray([][]byte{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = BytesArray([][]byte{{'\xDE', '\xAD', '\xBE', '\xEF'}, {'\xFE', '\xFF'}, {}}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{"\\xdeadbeef","\\xfeff","\\x"}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkBytesArrayValue(b *testing.B) {
rand.Seed(1)
x := make([][]byte, 10)
for i := 0; i < len(x); i++ {
x[i] = make([]byte, len(x))
for j := 0; j < len(x); j++ {
x[i][j] = byte(rand.Int())
}
}
a := BytesArray(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}
func TestFloat64ArrayScanUnsupported(t *testing.T) {
var arr Float64Array
err := arr.Scan(true)
if err == nil {
t.Fatal("Expected error when scanning from bool")
}
if !strings.Contains(err.Error(), "bool to Float64Array") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var Float64ArrayStringTests = []struct {
str string
arr Float64Array
}{
{`{}`, Float64Array{}},
{`{1.2}`, Float64Array{1.2}},
{`{3.456,7.89}`, Float64Array{3.456, 7.89}},
{`{3,1,2}`, Float64Array{3, 1, 2}},
}
func TestFloat64ArrayScanBytes(t *testing.T) {
for _, tt := range Float64ArrayStringTests {
bytes := []byte(tt.str)
arr := Float64Array{5, 5, 5}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkFloat64ArrayScanBytes(b *testing.B) {
var a Float64Array
var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`)
for i := 0; i < b.N; i++ {
a = Float64Array{}
a.Scan(x)
}
}
func TestFloat64ArrayScanString(t *testing.T) {
for _, tt := range Float64ArrayStringTests {
arr := Float64Array{5, 5, 5}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestFloat64ArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float64Array"},
{`{NULL}`, "parsing array element index 0:"},
{`{a}`, "parsing array element index 0:"},
{`{5.6,a}`, "parsing array element index 1:"},
{`{5.6,7.8,a}`, "parsing array element index 2:"},
} {
arr := Float64Array{5, 5, 5}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, Float64Array{5, 5, 5}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestFloat64ArrayValue(t *testing.T) {
result, err := Float64Array(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = Float64Array([]float64{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = Float64Array([]float64{1.2, 3.4, 5.6}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkFloat64ArrayValue(b *testing.B) {
rand.Seed(1)
x := make([]float64, 10)
for i := 0; i < len(x); i++ {
x[i] = rand.NormFloat64()
}
a := Float64Array(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}
func TestInt64ArrayScanUnsupported(t *testing.T) {
var arr Int64Array
err := arr.Scan(true)
if err == nil {
t.Fatal("Expected error when scanning from bool")
}
if !strings.Contains(err.Error(), "bool to Int64Array") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var Int64ArrayStringTests = []struct {
str string
arr Int64Array
}{
{`{}`, Int64Array{}},
{`{12}`, Int64Array{12}},
{`{345,678}`, Int64Array{345, 678}},
}
func TestInt64ArrayScanBytes(t *testing.T) {
for _, tt := range Int64ArrayStringTests {
bytes := []byte(tt.str)
arr := Int64Array{5, 5, 5}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkInt64ArrayScanBytes(b *testing.B) {
var a Int64Array
var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`)
for i := 0; i < b.N; i++ {
a = Int64Array{}
a.Scan(x)
}
}
func TestInt64ArrayScanString(t *testing.T) {
for _, tt := range Int64ArrayStringTests {
arr := Int64Array{5, 5, 5}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestInt64ArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int64Array"},
{`{NULL}`, "parsing array element index 0:"},
{`{a}`, "parsing array element index 0:"},
{`{5,a}`, "parsing array element index 1:"},
{`{5,6,a}`, "parsing array element index 2:"},
} {
arr := Int64Array{5, 5, 5}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, Int64Array{5, 5, 5}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestInt64ArrayValue(t *testing.T) {
result, err := Int64Array(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = Int64Array([]int64{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = Int64Array([]int64{1, 2, 3}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkInt64ArrayValue(b *testing.B) {
rand.Seed(1)
x := make([]int64, 10)
for i := 0; i < len(x); i++ {
x[i] = rand.Int63()
}
a := Int64Array(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}
func TestStringArrayScanUnsupported(t *testing.T) {
var arr StringArray
err := arr.Scan(true)
if err == nil {
t.Fatal("Expected error when scanning from bool")
}
if !strings.Contains(err.Error(), "bool to StringArray") {
t.Errorf("Expected type to be mentioned when scanning, got %q", err)
}
}
var StringArrayStringTests = []struct {
str string
arr StringArray
}{
{`{}`, StringArray{}},
{`{t}`, StringArray{"t"}},
{`{f,1}`, StringArray{"f", "1"}},
{`{"a\\b","c d",","}`, StringArray{"a\\b", "c d", ","}},
}
func TestStringArrayScanBytes(t *testing.T) {
for _, tt := range StringArrayStringTests {
bytes := []byte(tt.str)
arr := StringArray{"x", "x", "x"}
err := arr.Scan(bytes)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", bytes, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr)
}
}
}
func BenchmarkStringArrayScanBytes(b *testing.B) {
var a StringArray
var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`)
var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`)
for i := 0; i < b.N; i++ {
a = StringArray{}
a.Scan(x)
a = StringArray{}
a.Scan(y)
}
}
func TestStringArrayScanString(t *testing.T) {
for _, tt := range StringArrayStringTests {
arr := StringArray{"x", "x", "x"}
err := arr.Scan(tt.str)
if err != nil {
t.Fatalf("Expected no error for %q, got %v", tt.str, err)
}
if !reflect.DeepEqual(arr, tt.arr) {
t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr)
}
}
}
func TestStringArrayScanError(t *testing.T) {
for _, tt := range []struct {
input, err string
}{
{``, "unable to parse array"},
{`{`, "unable to parse array"},
{`{{a},{b}}`, "cannot convert ARRAY[2][1] to StringArray"},
{`{NULL}`, "parsing array element index 0: cannot convert nil to string"},
{`{a,NULL}`, "parsing array element index 1: cannot convert nil to string"},
{`{a,b,NULL}`, "parsing array element index 2: cannot convert nil to string"},
} {
arr := StringArray{"x", "x", "x"}
err := arr.Scan(tt.input)
if err == nil {
t.Fatalf("Expected error for %q, got none", tt.input)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err)
}
if !reflect.DeepEqual(arr, StringArray{"x", "x", "x"}) {
t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr)
}
}
}
func TestStringArrayValue(t *testing.T) {
result, err := StringArray(nil).Value()
if err != nil {
t.Fatalf("Expected no error for nil, got %v", err)
}
if result != nil {
t.Errorf("Expected nil, got %q", result)
}
result, err = StringArray([]string{}).Value()
if err != nil {
t.Fatalf("Expected no error for empty, got %v", err)
}
if expected := `{}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected empty, got %q", result)
}
result, err = StringArray([]string{`a`, `\b`, `c"`, `d,e`}).Value()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if expected := `{"a","\\b","c\"","d,e"}`; !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func BenchmarkStringArrayValue(b *testing.B) {
x := make([]string, 10)
for i := 0; i < len(x); i++ {
x[i] = strings.Repeat(`abc"def\ghi`, 5)
}
a := StringArray(x)
for i := 0; i < b.N; i++ {
a.Value()
}
}

135
types/hstore.go Normal file
View file

@ -0,0 +1,135 @@
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package types
import (
"database/sql"
"database/sql/driver"
"strings"
)
// HStore is a wrapper for transferring HStore values back and forth easily.
type HStore map[string]sql.NullString
// escapes and quotes hstore keys/values
// s should be a sql.NullString or string
func hQuote(s interface{}) string {
var str string
switch v := s.(type) {
case sql.NullString:
if !v.Valid {
return "NULL"
}
str = v.String
case string:
str = v
default:
panic("not a string or sql.NullString")
}
str = strings.Replace(str, "\\", "\\\\", -1)
return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"`
}
// Scan implements the Scanner interface.
//
// Note h is reallocated before the scan to clear existing values. If the
// hstore column's database value is NULL, then h is set to nil instead.
func (h *HStore) Scan(value interface{}) error {
if value == nil {
h = nil
return nil
}
*h = make(map[string]sql.NullString)
var b byte
pair := [][]byte{{}, {}}
pi := 0
inQuote := false
didQuote := false
sawSlash := false
bindex := 0
for bindex, b = range value.([]byte) {
if sawSlash {
pair[pi] = append(pair[pi], b)
sawSlash = false
continue
}
switch b {
case '\\':
sawSlash = true
continue
case '"':
inQuote = !inQuote
if !didQuote {
didQuote = true
}
continue
default:
if !inQuote {
switch b {
case ' ', '\t', '\n', '\r':
continue
case '=':
continue
case '>':
pi = 1
didQuote = false
continue
case ',':
s := string(pair[1])
if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" {
(*h)[string(pair[0])] = sql.NullString{String: "", Valid: false}
} else {
(*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true}
}
pair[0] = []byte{}
pair[1] = []byte{}
pi = 0
continue
}
}
}
pair[pi] = append(pair[pi], b)
}
if bindex > 0 {
s := string(pair[1])
if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" {
(*h)[string(pair[0])] = sql.NullString{String: "", Valid: false}
} else {
(*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true}
}
}
return nil
}
// Value implements the driver Valuer interface. Note if h is nil, the
// database column value will be set to NULL.
func (h HStore) Value() (driver.Value, error) {
if h == nil {
return nil, nil
}
parts := []string{}
for key, val := range h {
thispart := hQuote(key) + "=>" + hQuote(val)
parts = append(parts, thispart)
}
return []byte(strings.Join(parts, ",")), nil
}

77
types/json.go Normal file
View file

@ -0,0 +1,77 @@
package types
import (
"database/sql/driver"
"encoding/json"
"errors"
)
// JSON is an alias for json.RawMessage, which is
// a []byte underneath.
// JSON implements Marshal and Unmarshal.
type JSON json.RawMessage
// String output your JSON.
func (j JSON) String() string {
return string(j)
}
// Unmarshal your JSON variable into dest.
func (j JSON) Unmarshal(dest interface{}) error {
return json.Unmarshal(j, dest)
}
// Marshal obj into your JSON variable.
func (j *JSON) Marshal(obj interface{}) error {
res, err := json.Marshal(obj)
if err != nil {
return err
}
*j = res
return nil
}
// UnmarshalJSON sets *j to a copy of data.
func (j *JSON) UnmarshalJSON(data []byte) error {
if j == nil {
return errors.New("JSON: UnmarshalJSON on nil pointer")
}
*j = append((*j)[0:0], data...)
return nil
}
// MarshalJSON returns j as the JSON encoding of j.
func (j JSON) MarshalJSON() ([]byte, error) {
return j, nil
}
// Value returns j as a value.
// Unmarshal into RawMessage for validation.
func (j JSON) Value() (driver.Value, error) {
var r json.RawMessage
if err := j.Unmarshal(&r); err != nil {
return nil, err
}
return []byte(r), nil
}
// Scan stores the src in *j.
func (j *JSON) Scan(src interface{}) error {
var source []byte
switch src.(type) {
case string:
source = []byte(src.(string))
case []byte:
source = src.([]byte)
default:
return errors.New("Incompatible type for JSON")
}
*j = JSON(append((*j)[0:0], source...))
return nil
}

119
types/json_test.go Normal file
View file

@ -0,0 +1,119 @@
package types
import (
"bytes"
"testing"
)
func TestJSONString(t *testing.T) {
t.Parallel()
j := JSON("hello")
if j.String() != "hello" {
t.Errorf("Expected %q, got %s", "hello", j.String())
}
}
func TestJSONUnmarshal(t *testing.T) {
t.Parallel()
type JSONTest struct {
Name string
Age int
}
var jt JSONTest
j := JSON(`{"Name":"hi","Age":15}`)
err := j.Unmarshal(&jt)
if err != nil {
t.Error(err)
}
if jt.Name != "hi" {
t.Errorf("Expected %q, got %s", "hi", jt.Name)
}
if jt.Age != 15 {
t.Errorf("Expected %v, got %v", 15, jt.Age)
}
}
func TestJSONMarshal(t *testing.T) {
t.Parallel()
type JSONTest struct {
Name string
Age int
}
jt := JSONTest{
Name: "hi",
Age: 15,
}
var j JSON
err := j.Marshal(jt)
if err != nil {
t.Error(err)
}
if j.String() != `{"Name":"hi","Age":15}` {
t.Errorf("expected %s, got %s", `{"Name":"hi","Age":15}`, j.String())
}
}
func TestJSONUnmarshalJSON(t *testing.T) {
t.Parallel()
j := JSON(nil)
err := j.UnmarshalJSON(JSON(`"hi"`))
if err != nil {
t.Error(err)
}
if j.String() != `"hi"` {
t.Errorf("Expected %q, got %s", "hi", j.String())
}
}
func TestJSONMarshalJSON(t *testing.T) {
t.Parallel()
j := JSON(`"hi"`)
res, err := j.MarshalJSON()
if err != nil {
t.Error(err)
}
if !bytes.Equal(res, []byte(`"hi"`)) {
t.Errorf("Expected %q, got %v", `"hi"`, res)
}
}
func TestJSONValue(t *testing.T) {
t.Parallel()
j := JSON(`{"Name":"hi","Age":15}`)
v, err := j.Value()
if err != nil {
t.Error(err)
}
if !bytes.Equal(j, v.([]byte)) {
t.Errorf("byte mismatch, %v %v", j, v)
}
}
func TestJSONScan(t *testing.T) {
t.Parallel()
j := JSON{}
err := j.Scan(`"hello"`)
if err != nil {
t.Error(err)
}
if !bytes.Equal(j, []byte(`"hello"`)) {
t.Errorf("bad []byte: %#v ≠ %#v\n", j, string([]byte(`"hello"`)))
}
}