Began implementing the ORM prototype

* Hooks, query mods, and query
* Update and UpdateX
This commit is contained in:
Patrick O'brien 2016-04-13 23:51:58 +10:00
parent d949f68ed0
commit d89d23e673
27 changed files with 448 additions and 69 deletions

5
boil/bind.go Normal file
View file

@ -0,0 +1,5 @@
package boil
func (q *Query) Bind() {
}

1
boil/bind_test.go Normal file
View file

@ -0,0 +1 @@
package boil

41
boil/db.go Normal file
View file

@ -0,0 +1,41 @@
package boil
import "database/sql"
type Executor interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type Transactor interface {
Commit() error
Rollback() error
Executor
}
type Creator interface {
Begin() (*sql.Tx, error)
}
var currentDB Executor
func Begin() (Transactor, error) {
creator, ok := currentDB.(Creator)
if !ok {
panic("Your database handle does not support transactions.")
}
return creator.Begin()
}
// SetDB initializes the database handle for all template db interactions
func SetDB(db Executor) {
currentDB = db
}
// GetDB retrieves the global state database handle
func GetDB() Executor {
return currentDB
}

16
boil/db_test.go Normal file
View file

@ -0,0 +1,16 @@
package boil
import (
"database/sql"
"testing"
)
func TestGetSetDB(t *testing.T) {
t.Parallel()
SetDB(&sql.DB{})
if GetDB() == nil {
t.Errorf("Expected GetDB to return a database handle, got nil")
}
}

View file

@ -9,6 +9,9 @@ import (
"unicode"
)
// M type is for providing where filters to Where helpers.
type M map[string]interface{}
// SelectNames returns the column names for a select statement
// Eg: col1, col2, col3
func SelectNames(results interface{}) string {
@ -36,9 +39,9 @@ func SelectNames(results interface{}) string {
return strings.Join(names, ", ")
}
// Where returns the where clause for an sql statement
// WhereClause returns the where clause for an sql statement
// eg: col1=$1 AND col2=$2 AND col3=$3
func Where(columns map[string]interface{}) string {
func WhereClause(columns map[string]interface{}) string {
names := make([]string, 0, len(columns))
for c := range columns {
@ -90,6 +93,35 @@ func WhereParams(columns map[string]interface{}) []interface{} {
return results
}
// SetParamNames takes a slice of columns and returns a comma seperated
// list of parameter names for a template statement SET clause.
// eg: col1=$1,col2=$2,col3=$3
func SetParamNames(columns []string) string {
names := make([]string, 0, len(columns))
counter := 0
for _, c := range columns {
counter++
names = append(names, fmt.Sprintf("%s=$%d", c, counter))
}
return strings.Join(names, ",")
}
// WherePrimaryKey returns the where clause using start as the $ flag index
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
func WherePrimaryKey(start int, pkeys ...string) string {
var output string
for i, c := range pkeys {
output = fmt.Sprintf("%s%s=$%d", output, c, start)
start++
if i < len(pkeys)-1 {
output = fmt.Sprintf("%s AND ", output)
}
}
return output
}
// goVarToSQLName converts a go variable name to a column name
// example: HelloFriendID to hello_friend_id
func goVarToSQLName(name string) string {

View file

@ -50,7 +50,7 @@ func TestSelectNames(t *testing.T) {
}
}
func TestWhere(t *testing.T) {
func TestWhereClause(t *testing.T) {
t.Parallel()
columns := map[string]interface{}{
@ -59,7 +59,7 @@ func TestWhere(t *testing.T) {
"date": time.Now(),
}
result := Where(columns)
result := WhereClause(columns)
if result != `date=$1 AND id=$2 AND name=$3` {
t.Error("Result was wrong, got:", result)

11
boil/hooks.go Normal file
View file

@ -0,0 +1,11 @@
package boil
// HookPoint is the point in time at which we hook
type HookPoint int
const (
HookAfterCreate HookPoint = iota + 1
HookAfterUpdate
HookBeforeCreate
HookBeforeUpdate
)

1
boil/hooks_test.go Normal file
View file

@ -0,0 +1 @@
package boil

25
boil/query.go Normal file
View file

@ -0,0 +1,25 @@
package boil
type where struct {
clause string
args []interface{}
}
type Query struct {
limit int
where []where
executor Executor
groupBy []string
orderBy []string
having []string
from string
}
func (q *Query) buildQuery() string {
return ""
}
// makes a new empty query ?????
func New() {
}

56
boil/query_mods.go Normal file
View file

@ -0,0 +1,56 @@
package boil
type QueryMod func(q *Query)
func (q *Query) Apply(mods ...QueryMod) {
for _, mod := range mods {
mod(q)
}
}
func DB(e Executor) QueryMod {
return func(q *Query) {
q.executor = e
}
}
func Limit(limit int) QueryMod {
return func(q *Query) {
q.limit = limit
}
}
func Where(clause string, args ...interface{}) QueryMod {
return func(q *Query) {
w := where{
clause: clause,
args: args,
}
q.where = append(q.where, w)
}
}
func GroupBy(clause string) QueryMod {
return func(q *Query) {
q.groupBy = append(q.groupBy, clause)
}
}
func OrderBy(clause string) QueryMod {
return func(q *Query) {
q.orderBy = append(q.orderBy, clause)
}
}
func Having(clause string) QueryMod {
return func(q *Query) {
q.having = append(q.having, clause)
}
}
func From(clause string) QueryMod {
return func(q *Query) {
q.from = clause
}
}

137
boil/query_mods_test.go Normal file
View file

@ -0,0 +1,137 @@
package boil
import "testing"
func TestApply(t *testing.T) {
t.Parallel()
q := &Query{}
qfn1 := Limit(10)
qfn2 := Where("x > $1 AND y > $2", 5, 3)
q.Apply(qfn1, qfn2)
expect1 := 10
if q.limit != expect1 {
t.Errorf("Expected %d, got %d", expect1, q.limit)
}
expect2 := "x > $1 AND y > $2"
if len(q.where) != 1 {
t.Errorf("Expected %d where slices, got %d", len(q.where))
}
expect := "x > $1 AND y > $2"
if q.where[0].clause != expect2 {
t.Errorf("Expected %s, got %s", expect, q.where)
}
if len(q.where[0].args) != 2 {
t.Errorf("Expected %d args, got %d", 2, len(q.where[0].args))
}
if q.where[0].args[0].(int) != 5 || q.where[0].args[1].(int) != 3 {
t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.where[0].args)
}
}
func TestDB(t *testing.T) {
t.Parallel()
}
func TestLimit(t *testing.T) {
t.Parallel()
q := &Query{}
qfn := Limit(10)
qfn(q)
expect := 10
if q.limit != expect {
t.Errorf("Expected %d, got %d", expect, q.limit)
}
}
func TestWhere(t *testing.T) {
t.Parallel()
q := &Query{}
qfn := Where("x > $1 AND y > $2", 5, 3)
qfn(q)
if len(q.where) != 1 {
t.Errorf("Expected %d where slices, got %d", len(q.where))
}
expect := "x > $1 AND y > $2"
if q.where[0].clause != expect {
t.Errorf("Expected %s, got %s", expect, q.where)
}
if len(q.where[0].args) != 2 {
t.Errorf("Expected %d args, got %d", 2, len(q.where[0].args))
}
if q.where[0].args[0].(int) != 5 || q.where[0].args[1].(int) != 3 {
t.Errorf("Args not set correctly, expected 5 & 3, got: %#v", q.where[0].args)
}
}
func TestGroupBy(t *testing.T) {
t.Parallel()
q := &Query{}
qfn := GroupBy("col1, col2")
qfn(q)
expect := "col1, col2"
if len(q.groupBy) != 1 && q.groupBy[0] != expect {
t.Errorf("Expected %s, got %s", expect, q.groupBy[0])
}
}
func TestOrderBy(t *testing.T) {
t.Parallel()
q := &Query{}
qfn := OrderBy("col1 desc, col2 asc")
qfn(q)
expect := "col1 desc, col2 asc"
if len(q.orderBy) != 1 && q.orderBy[0] != expect {
t.Errorf("Expected %s, got %s", expect, q.orderBy[0])
}
}
func TestHaving(t *testing.T) {
t.Parallel()
q := &Query{}
qfn := Having("count(orders.order_id) > 10")
qfn(q)
expect := "count(orders.order_id) > 10"
if len(q.having) != 1 && q.having[0] != expect {
t.Errorf("Expected %s, got %s", expect, q.having[0])
}
}
func TestFrom(t *testing.T) {
t.Parallel()
q := &Query{}
qfn := From("videos a, orders b")
qfn(q)
expect := "videos a, orders b"
if q.from != expect {
t.Errorf("Expected %s, got %s", expect, q.from)
}
}

1
boil/query_test.go Normal file
View file

@ -0,0 +1 @@
package boil

View file

@ -1,14 +0,0 @@
package boil
import "database/sql"
// DB implements the functions necessary for the templates to function.
type DB interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
Select(dest interface{}, query string, args ...interface{}) error
}
// M type is for providing where filters to Where helpers.
type M map[string]interface{}

View file

@ -80,6 +80,7 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
"camelCase": camelCase,
"camelCaseSingular": camelCaseSingular,
"camelCasePlural": camelCasePlural,
"commaList": commaList,
"makeDBName": makeDBName,
"selectParamNames": selectParamNames,
"insertParamNames": insertParamNames,

View file

@ -33,7 +33,7 @@ func generateOutput(cmdData *CmdData, data *tplData) error {
imps = combineTypeImports(imps, sqlBoilerTypeImports, data.Table.Columns)
resp, err := generateTemplate(template, data)
if err != nil {
return err
return fmt.Errorf("Error generating template %s: %s", template.Name(), err)
}
out = append(out, resp)
}
@ -62,7 +62,7 @@ func generateTestOutput(cmdData *CmdData, data *tplData) error {
for _, template := range cmdData.TestTemplates {
resp, err := generateTemplate(template, data)
if err != nil {
return err
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
}
out = append(out, resp)
}

View file

@ -126,13 +126,13 @@ func (c *CmdData) run(includeTests bool) error {
// Generate the regular templates
if err := generateOutput(c, data); err != nil {
return fmt.Errorf("Unable to generate test output: %s", err)
return fmt.Errorf("Unable to generate output: %s", err)
}
// Generate the test templates
if includeTests {
if err := generateTestOutput(c, data); err != nil {
return fmt.Errorf("Unable to generate output: %s", err)
return fmt.Errorf("Unable to generate test output: %s", err)
}
}
}

View file

@ -234,6 +234,11 @@ func wherePrimaryKey(pkeyCols []string, start int) string {
return output
}
// commaList returns a comma seperated list: "col1, col2, col3"
func commaList(cols []string) string {
return strings.Join(cols, ", ")
}
// paramsPrimaryKey returns the parameters for the sql statement $ flags
// For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2"
func paramsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string {

View file

@ -3,10 +3,10 @@
{{- $tableNamePlural := titleCasePlural .Table.Name -}}
{{- $varNamePlural := camelCasePlural .Table.Name -}}
// {{$tableNamePlural}}All retrieves all records.
func {{$tableNamePlural}}All(db boil.DB) ([]*{{$tableNameSingular}}, error) {
func {{$tableNamePlural}}All() ([]*{{$tableNameSingular}}, error) {
var {{$varNamePlural}} []*{{$tableNameSingular}}
rows, err := db.Query(`SELECT {{selectParamNames $dbName .Table.Columns}} FROM {{.Table.Name}}`)
rows, err := boil.GetDB().Query(`SELECT {{selectParamNames $dbName .Table.Columns}} FROM {{.Table.Name}}`)
if err != nil {
return nil, fmt.Errorf("{{.PkgName}}: failed to query: %v", err)
}

View file

@ -2,12 +2,12 @@
{{- $dbName := singular .Table.Name -}}
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
// {{$tableNameSingular}}Find retrieves a single record by ID.
func {{$tableNameSingular}}Find(db boil.DB, id int) (*{{$tableNameSingular}}, error) {
func {{$tableNameSingular}}Find(id int) (*{{$tableNameSingular}}, error) {
if id == 0 {
return nil, errors.New("{{.PkgName}}: no id provided for {{.Table.Name}} select")
}
var {{$varNameSingular}} *{{$tableNameSingular}}
err := db.Select(&{{$varNameSingular}}, `SELECT {{selectParamNames $dbName .Table.Columns}} WHERE id=$1`, id)
err := boil.GetDB().Select(&{{$varNameSingular}}, `SELECT {{selectParamNames $dbName .Table.Columns}} WHERE id=$1`, id)
if err != nil {
return nil, fmt.Errorf("{{.PkgName}}: unable to select from {{.Table.Name}}: %s", err)

View file

@ -2,13 +2,13 @@
// {{$tableNameSingular}}FindSelect retrieves the specified columns for a single record by ID.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
// For example: friendName string `db:"friend_name"`
func {{$tableNameSingular}}FindSelect(db boil.DB, id int, results interface{}) error {
func {{$tableNameSingular}}FindSelect(id int, results interface{}) error {
if id == 0 {
return errors.New("{{.PkgName}}: no id provided for {{.Table.Name}} select")
}
query := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE id=$1`, boil.SelectNames(results))
err := db.Select(results, query, id)
err := boil.GetDB().Select(results, query, id)
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to select from {{.Table.Name}}: %s", err)

65
cmds/templates/hooks.tpl Normal file
View file

@ -0,0 +1,65 @@
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
type {{$tableNameSingular}}Hook func(*{{$tableNameSingular}}) error
var {{$varNameSingular}}BeforeCreateHooks []{{$tableNameSingular}}Hook
var {{$varNameSingular}}BeforeUpdateHooks []{{$tableNameSingular}}Hook
var {{$varNameSingular}}AfterCreateHooks []{{$tableNameSingular}}Hook
var {{$varNameSingular}}AfterUpdateHooks []{{$tableNameSingular}}Hook
// doBeforeCreateHooks executes all "before create" hooks.
func (o *{{$tableNameSingular}}) doBeforeCreateHooks() (err error) {
for _, hook := range {{$varNameSingular}}BeforeCreateHooks {
if err := hook(o); err == nil {
return err
}
}
return nil
}
// doBeforeUpdateHooks executes all "before Update" hooks.
func (o *{{$tableNameSingular}}) doBeforeUpdateHooks() (err error) {
for _, hook := range {{$varNameSingular}}BeforeUpdateHooks {
if err := hook(o); err == nil {
return err
}
}
return nil
}
// doAfterCreateHooks executes all "after create" hooks.
func (o *{{$tableNameSingular}}) doAfterCreateHooks() (err error) {
for _, hook := range {{$varNameSingular}}AfterCreateHooks {
if err := hook(o); err == nil {
return err
}
}
return nil
}
// doAfterUpdateHooks executes all "after Update" hooks.
func (o *{{$tableNameSingular}}) doAfterUpdateHooks() (err error) {
for _, hook := range {{$varNameSingular}}AfterUpdateHooks {
if err := hook(o); err == nil {
return err
}
}
return nil
}
func (o *{{$tableNameSingular}}) {{$tableNameSingular}}AddHook(hookPoint boil.HookPoint, {{$varNameSingular}}Hook {{$tableNameSingular}}Hook) {
switch hookPoint {
case boil.HookBeforeCreate:
{{$varNameSingular}}BeforeCreateHooks = append({{$varNameSingular}}BeforeCreateHooks, {{$varNameSingular}}Hook)
case boil.HookBeforeUpdate:
{{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook)
case boil.HookAfterCreate:
{{$varNameSingular}}AfterCreateHooks = append({{$varNameSingular}}AfterCreateHooks, {{$varNameSingular}}Hook)
case boil.HookAfterUpdate:
{{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook)
}
}

View file

@ -1,16 +1,24 @@
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
// {{$tableNameSingular}}Insert inserts a single record.
func {{$tableNameSingular}}Insert(db boil.DB, o *{{$tableNameSingular}}) (int, error) {
func (o *{{$tableNameSingular}}) Insert() (int, error) {
if o == nil {
return 0, errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
}
if err := o.doBeforeCreateHooks(); err != nil {
return 0, err
}
var rowID int
err := db.QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID)
err := boil.GetDB().QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID)
if err != nil {
return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table.Name}}: %s", err)
}
if err := o.doAfterCreateHooks(); err != nil {
return 0, err
}
return rowID, nil
}

View file

@ -2,9 +2,9 @@
// {{$tableNamePlural}}Select retrieves the specified columns for all records.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
// For example: friendName string `db:"friend_name"`
func {{$tableNamePlural}}Select(db boil.DB, results interface{}) error {
func {{$tableNamePlural}}Select(results interface{}) error {
query := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}}`, boil.SelectNames(results))
err := db.Select(results, query)
err := boil.GetDB().Select(results, query)
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to select from {{.Table.Name}}: %s", err)

View file

@ -1,8 +1,8 @@
{{- $tableNamePlural := titleCasePlural .Table.Name -}}
// {{$tableNamePlural}}SelectWhere retrieves all records with the specified column values.
func {{$tableNamePlural}}SelectWhere(db boil.DB, results interface{}, columns map[string]interface{}) error {
query := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, boil.SelectNames(results), boil.Where(columns))
err := db.Select(results, query, boil.WhereParams(columns)...)
func {{$tableNamePlural}}SelectWhere(results interface{}, columns map[string]interface{}) error {
query := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, boil.SelectNames(results), boil.WhereClause(columns))
err := boil.GetDB().Select(results, query, boil.WhereParams(columns)...)
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to select from {{.Table.Name}}: %s", err)

View file

@ -1,32 +1,36 @@
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
// {{$tableNameSingular}}Update updates a single record.
func {{$tableNameSingular}}Update(db boil.DB, id int, columns map[string]interface{}) error {
if id == 0 {
return errors.New("{{.PkgName}}: no id provided for {{.Table.Name}} update")
}
query := fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE id=$%d`, boil.Update(columns), len(columns))
_, err := db.Exec(query, id, boil.WhereParams(columns))
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to update row with ID %d in {{.Table.Name}}: %s", id, err)
}
return nil
}
{{if hasPrimaryKey .Table.PKey -}}
// Update updates a single {{$tableNameSingular}} record.
// whitelist is a list of column_name's that should be updated.
// Update will match against the primary key column to find the record to update.
// WARNING: This Update method will NOT ignore nil members.
// If you pass in nil members, those columnns will be set to null.
func (o *{{$tableNameSingular}}) Update(db boil.DB) error {
func (o *{{$tableNameSingular}}) UpdateX(db Executor, whitelist ... string) error {
if err := o.doBeforeUpdateHooks(); err != nil {
return err
}
{{$flagIndex := primaryKeyFlagIndex .Table.Columns .Table.PKey.Columns}}
_, err := db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns .Table.PKey.Columns}} WHERE {{wherePrimaryKey .Table.PKey.Columns $flagIndex}}", {{updateParamVariables "o." .Table.Columns .Table.PKey.Columns}}, {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
var err error
if len(whitelist) == 0 {
_, err = db.Exec("UPDATE {{.Table.Name}} SET {{updateParamNames .Table.Columns .Table.PKey.Columns}} WHERE {{wherePrimaryKey .Table.PKey.Columns $flagIndex}}", {{updateParamVariables "o." .Table.Columns .Table.PKey.Columns}}, {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
} else {
query := fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), WherePrimaryKey(len(whitelist)+1, {{commaList .Table.PKey.Columns}}))
_, err = db.Exec(query, {{updateParamVariables "o." .Table.Columns .Table.PKey.Columns}}, {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
}
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row: %s", err)
}
if err := o.doAfterUpdateHooks(); err != nil {
return err
}
return nil
}
func (o *{{$tableNameSingular}}) Update(whitelist ... string) error {
return o.UpdateX(boil.GetDB(), whitelist...)
}
{{- end}}

View file

@ -1,16 +0,0 @@
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
{{- $dbName := singular .Table.Name -}}
{{- $tableNamePlural := titleCasePlural .Table.Name -}}
{{- $varNamePlural := camelCasePlural .Table.Name -}}
// {{$tableNamePlural}}Where retrieves all records with the specified column values.
func {{$tableNamePlural}}Where(db boil.DB, columns map[string]interface{}) ([]*{{$tableNameSingular}}, error) {
var {{$varNamePlural}} []*{{$tableNameSingular}}
query := fmt.Sprintf(`SELECT {{selectParamNames $dbName .Table.Columns}} FROM {{.Table.Name}} WHERE %s`, boil.Where(columns))
err := db.Select(&{{$varNamePlural}}, query, boil.WhereParams(columns)...)
if err != nil {
return nil, fmt.Errorf("{{.PkgName}}: unable to select from {{.Table.Name}}: %s", err)
}
return {{$varNamePlural}}, nil
}