Added inflection library, fixed all templates

* Added inflection to all templates
* Fixed broken template code (for the most part)
This commit is contained in:
Patrick O'brien 2016-03-18 21:26:48 +10:00
parent b7ab642732
commit 2f9c936c6a
12 changed files with 186 additions and 67 deletions

View file

@ -61,13 +61,20 @@ var sqlBoilerCommandRuns = map[string]CobraRunFunc{
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
// If you wish to pass a new function into your own template, add a pointer to it here.
var sqlBoilerTemplateFuncs = template.FuncMap{
"titleCase": titleCase,
"camelCase": camelCase,
"makeDBName": makeDBName,
"selectParamNames": selectParamNames,
"insertParamNames": insertParamNames,
"insertParamFlags": insertParamFlags,
"scanParamNames": scanParamNames,
"singular": singular,
"plural": plural,
"titleCase": titleCase,
"titleCaseSingular": titleCaseSingular,
"titleCasePlural": titleCasePlural,
"camelCase": camelCase,
"camelCaseSingular": camelCaseSingular,
"camelCasePlural": camelCasePlural,
"makeDBName": makeDBName,
"selectParamNames": selectParamNames,
"insertParamNames": insertParamNames,
"insertParamFlags": insertParamFlags,
"insertParamVariables": insertParamVariables,
"scanParamNames": scanParamNames,
}
/* Struct commands */

View file

@ -7,6 +7,7 @@ import (
"strings"
"text/template"
"github.com/jinzhu/inflection"
"github.com/pobri19/sqlboiler/dbdrivers"
)
@ -56,6 +57,20 @@ func processTemplate(t *template.Template, data *tplData) ([]byte, error) {
return output, nil
}
// plural converts singular words to plural words (eg: person to people)
func plural(name string) string {
splits := strings.Split(name, "_")
splits[len(splits)-1] = inflection.Plural(splits[len(splits)-1])
return strings.Join(splits, "_")
}
// singular converts plural words to singular words (eg: people to person)
func singular(name string) string {
splits := strings.Split(name, "_")
splits[len(splits)-1] = inflection.Singular(splits[len(splits)-1])
return strings.Join(splits, "_")
}
// titleCase changes a snake-case variable name
// into a go styled object variable name of "ColumnName".
// titleCase also fully uppercases "ID" components of names, for example
@ -75,6 +90,22 @@ func titleCase(name string) string {
return strings.Join(splits, "")
}
// titleCaseSingular changes a snake-case variable name
// to a go styled object variable name of "ColumnName".
// titleCaseSingular also converts the last word in the
// variable name to a singularized version of itself.
func titleCaseSingular(name string) string {
return titleCase(singular(name))
}
// titleCasePlural changes a snake-case variable name
// to a go styled object variable name of "ColumnName".
// titleCasePlural also converts the last word in the
// variable name to a pluralized version of itself.
func titleCasePlural(name string) string {
return titleCase(plural(name))
}
// camelCase takes a variable name in the format of "var_name" and converts
// it into a go styled variable name of "varName".
// camelCase also fully uppercases "ID" components of names, for example
@ -98,6 +129,22 @@ func camelCase(name string) string {
return strings.Join(splits, "")
}
// camelCaseSingular takes a variable name in the format of "var_name" and converts
// it into a go styled variable name of "varName".
// camelCaseSingular also converts the last word in the
// variable name to a singularized version of itself.
func camelCaseSingular(name string) string {
return camelCase(singular(name))
}
// camelCasePlural takes a variable name in the format of "var_name" and converts
// it into a go styled variable name of "varName".
// camelCasePlural also converts the last word in the
// variable name to a pluralized version of itself.
func camelCasePlural(name string) string {
return camelCase(plural(name))
}
// makeDBName takes a table name in the format of "table_name" and a
// column name in the format of "column_name" and returns a name used in the
// `db:""` component of an object in the format of "table_name_column_name"
@ -125,6 +172,20 @@ func insertParamFlags(columns []dbdrivers.DBColumn) string {
return strings.Join(params, ", ")
}
// insertParamVariables takes a prefix and a []DBColumns and returns a
// comma seperated list of parameter variable names for the insert statement.
// For example: prefix("o."), column("name_id") -> "o.NameID, ..."
func insertParamVariables(prefix string, columns []dbdrivers.DBColumn) string {
names := make([]string, 0, len(columns))
for _, c := range columns {
n := prefix + titleCase(c.Name)
names = append(names, n)
}
return strings.Join(names, ", ")
}
// selectParamNames takes a []DBColumn and returns a comma seperated
// list of parameter names with for the select statement template.
// It also uses the table name to generate the "AS" part of the statement, for

View file

@ -11,6 +11,44 @@ var testColumns = []dbdrivers.DBColumn{
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
}
func TestSingular(t *testing.T) {
t.Parallel()
tests := []struct {
In string
Out string
}{
{"hello_people", "hello_person"},
{"hello_person", "hello_person"},
{"friends", "friend"},
}
for i, test := range tests {
if out := singular(test.In); out != test.Out {
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
}
}
}
func TestPlural(t *testing.T) {
t.Parallel()
tests := []struct {
In string
Out string
}{
{"hello_person", "hello_people"},
{"friend", "friends"},
{"friends", "friends"},
}
for i, test := range tests {
if out := plural(test.In); out != test.Out {
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
}
}
}
func TestTitleCase(t *testing.T) {
t.Parallel()
@ -75,6 +113,13 @@ func TestInsertParamFlags(t *testing.T) {
}
}
func TestInsertParamVariables(t *testing.T) {
out := insertParamVariables("o.", testColumns)
if out != "o.FriendColumn, o.EnemyColumnThing" {
t.Error("Wrong output:", out)
}
}
func TestSelectParamFlags(t *testing.T) {
t.Parallel()

View file

@ -1,28 +1,30 @@
{{- $tableName := titleCase .Table -}}
{{- $varName := camelCase .Table -}}
// {{$tableName}}All retrieves all records.
func {{$tableName}}All(db boil.DB) ([]*{{$tableName}}, error) {
var {{$varName}} []*{{$tableName}}
{{- $tableNameSingular := titleCaseSingular .Table -}}
{{- $dbName := singular .Table -}}
{{- $tableNamePlural := titleCasePlural .Table -}}
{{- $varNamePlural := camelCasePlural .Table -}}
// {{$tableNamePlural}}All retrieves all records.
func {{$tableNamePlural}}All(db boil.DB) ([]*{{$tableNameSingular}}, error) {
var {{$varNamePlural}} []*{{$tableNameSingular}}
rows, err := db.Query(`SELECT {{selectParamNames .Table .Columns}} FROM {{.Table}}`)
rows, err := db.Query(`SELECT {{selectParamNames $dbName .Columns}} FROM {{.Table}}`)
if err != nil {
return nil, fmt.Errorf("{{.PkgName}}: failed to query: %v", err)
}
for rows.Next() {
{{- $tmpVarName := (print $varName "Tmp") -}}
{{$varName}}Tmp := {{$tableName}}{}
{{- $tmpVarName := (print $varNamePlural "Tmp") -}}
{{$varNamePlural}}Tmp := {{$tableNameSingular}}{}
if err := rows.Scan({{scanParamNames $tmpVarName .Columns}}); err != nil {
return nil, fmt.Errorf("{{.PkgName}}: failed to scan row: %v", err)
}
{{$varName}} = append({{$varName}}, {{$varName}}Tmp)
{{$varNamePlural}} = append({{$varNamePlural}}, {{$varNamePlural}}Tmp)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("{{.PkgName}}: failed to read rows: %v", err)
}
return {{$varName}}, nil
return {{$varNamePlural}}, nil
}

View file

@ -1,13 +1,13 @@
{{- $tableName := .Table -}}
// {{titleCase $tableName}}Delete deletes a single record.
func {{titleCase $tableName}}Delete(db boil.DB, id int) error {
{{- $tableNameSingular := titleCaseSingular .Table -}}
// {{$tableNameSingular}}Delete deletes a single record.
func {{$tableNameSingular}}Delete(db boil.DB, id int) error {
if id == nil {
return nil, errors.New("{{.PkgName}}: no id provided for {{$tableName}} delete")
return nil, errors.New("{{.PkgName}}: no id provided for {{.Table}} delete")
}
err := db.Exec("DELETE FROM {{$tableName}} WHERE id=$1", id)
err := db.Exec("DELETE FROM {{.Table}} WHERE id=$1", id)
if err != nil {
return errors.New("{{.PkgName}}: unable to delete from {{$tableName}}: %s", err)
return errors.New("{{.PkgName}}: unable to delete from {{.Table}}: %s", err)
}
return nil

View file

@ -1,16 +1,17 @@
{{- $tableName := .Table -}}
// {{titleCase $tableName}}Find retrieves a single record by ID.
func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}}, error) {
{{- $tableNameSingular := titleCaseSingular .Table -}}
{{- $dbName := singular .Table -}}
{{- $varNameSingular := camelCaseSingular .Table -}}
// {{$tableNameSingular}}Find retrieves a single record by ID.
func {{$tableNameSingular}}Find(db boil.DB, id int) (*{{$tableNameSingular}}, error) {
if id == 0 {
return nil, errors.New("{{.PkgName}}: no id provided for {{$tableName}} select")
return nil, errors.New("{{.PkgName}}: no id provided for {{.Table}} select")
}
{{$varName := camelCase $tableName}}
var {{$varName}} *{{titleCase $tableName}}
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
var {{$varNameSingular}} *{{$tableNameSingular}}
err := db.Select(&{{$varNameSingular}}, `SELECT {{selectParamNames $dbName .Columns}} WHERE id=$1`, id)
if err != nil {
return nil, fmt.Errorf("{{.PkgName}}: unable to select from {{$tableName}}: %s", err)
return nil, fmt.Errorf("{{.PkgName}}: unable to select from {{.Table}}: %s", err)
}
return {{$varName}}, nil
return {{$varNameSingular}}, nil
}

View file

@ -1,17 +1,17 @@
{{- $tableName := .Table -}}
// {{titleCase $tableName}}FindSelect retrieves the specified columns for a single record by ID.
{{- $tableNameSingular := titleCaseSingular .Table -}}
// {{$tableNameSingular}}FindSelect retrieves the specified columns for a single record by ID.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
// For example: friendName string `db:"friend_name"`
func {{titleCase $tableName}}FindSelect(db boil.DB, id int, results interface{}) error {
func {{$tableNameSingular}}FindSelect(db boil.DB, id int, results interface{}) error {
if id == 0 {
return nil, errors.New("{{.PkgName}}: no id provided for {{$tableName}} select")
return nil, errors.New("{{.PkgName}}: no id provided for {{.Table}} select")
}
query := fmt.Sprintf(`SELECT %s FROM {{$tableName}} WHERE id=$1`, boil.SelectNames(results))
query := fmt.Sprintf(`SELECT %s FROM {{.Table}} WHERE id=$1`, boil.SelectNames(results))
err := db.Select(results, query, id)
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to select from {{$tableName}}: %s", err)
return fmt.Errorf("{{.PkgName}}: unable to select from {{.Table}}: %s", err)
}
return nil

View file

@ -1,15 +1,15 @@
{{- $tableName := .Table -}}
// {{titleCase $tableName}}Insert inserts a single record.
func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (int, error) {
{{- $tableNameSingular := titleCaseSingular .Table -}}
// {{$tableNameSingular}}Insert inserts a single record.
func {{$tableNameSingular}}Insert(db boil.DB, o *{{$tableNameSingular}}) (int, error) {
if o == nil {
return 0, errors.New("{{.PkgName}}: no {{$tableName}} provided for insertion")
return 0, errors.New("{{.PkgName}}: no {{.Table}} provided for insertion")
}
var rowID int
err := db.QueryRow(`INSERT INTO {{$tableName}} ({{insertParamNames .Columns}}) VALUES({{insertParamFlags .Columns}}) RETURNING id`)
err := db.QueryRow(`INSERT INTO {{.Table}} ({{insertParamNames .Columns}}) VALUES({{insertParamFlags .Columns}}) RETURNING id`, {{insertParamVariables "o." .Columns}}).Scan(&rowID)
if err != nil {
return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{$tableName}}: %s", err)
return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table}}: %s", err)
}
return rowID, nil

View file

@ -1,13 +1,13 @@
{{- $tableName := .Table -}}
// {{titleCase $tableName}}Select retrieves the specified columns for all records.
{{- $tableNamePlural := titleCasePlural .Table -}}
// {{$tableNamePlural}}Select retrieves the specified columns for all records.
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
// For example: friendName string `db:"friend_name"`
func {{titleCase $tableName}}Select(db boil.DB, results interface{}) error {
query := fmt.Sprintf(`SELECT %s FROM {{$tableName}}`, boil.SelectNames(results))
func {{$tableNamePlural}}Select(db boil.DB, results interface{}) error {
query := fmt.Sprintf(`SELECT %s FROM {{.Table}}`, boil.SelectNames(results))
err := db.Select(results, query)
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to select from {{$tableName}}: %s", err)
return fmt.Errorf("{{.PkgName}}: unable to select from {{.Table}}: %s", err)
}
return nil

View file

@ -1,11 +1,11 @@
{{- $tableName := .Table -}}
// {{titleCase $tableName}}SelectWhere retrieves all records with the specified column values.
func {{titleCase $tableName}}SelectWhere(db boil.DB, results interface{}, columns map[string]interface{}) error {
query := fmt.Sprintf(`SELECT %s FROM {{$tableName}} WHERE %s`, boil.SelectNames(results), boil.Where(columns))
{{- $tableNamePlural := titleCasePlural .Table -}}
// {{$tableNamePlural}}SelectWhere retrieves all records with the specified column values.
func {{$tableNamePlural}}SelectWhere(db boil.DB, results interface{}, columns map[string]interface{}) error {
query := fmt.Sprintf(`SELECT %s FROM {{.Table}} WHERE %s`, boil.SelectNames(results), boil.Where(columns))
err := db.Select(results, query, boil.WhereParams(columns)...)
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to select from {{$tableName}}: %s", err)
return fmt.Errorf("{{.PkgName}}: unable to select from {{.Table}}: %s", err)
}
return nil

View file

@ -1,7 +1,8 @@
{{- $tableName := .Table -}}
// {{titleCase $tableName}} is an object representing the database table.
type {{titleCase $tableName}} struct {
{{- $tableNameSingular := titleCaseSingular .Table -}}
{{- $dbName := singular .Table -}}
// {{$tableNameSingular}} is an object representing the database table.
type {{$tableNameSingular}} struct {
{{range $key, $value := .Columns -}}
{{titleCase $value.Name}} {{$value.Type}} `db:"{{makeDBName $tableName $value.Name}}" json:"{{$value.Name}}"`
{{titleCase $value.Name}} {{$value.Type}} `db:"{{makeDBName $dbName $value.Name}}" json:"{{$value.Name}}"`
{{end -}}
}

View file

@ -1,14 +1,16 @@
{{- $tableName := .Table -}}
{{- $varName := camelCase $tableName -}}
// {{titleCase $tableName}}Where retrieves all records with the specified column values.
func {{titleCase $tableName}}Where(db boil.DB, columns map[string]interface{}) ([]*{{titleCase $tableName}}, error) {
var {{$varName}} []*{{titleCase $tableName}}
query := fmt.Sprintf(`SELECT {{selectParamNames $tableName .Columns}} FROM {{$tableName}} WHERE %s`, boil.Where(columns))
err := db.Select(&{{$varName}}, query, boil.WhereParams(columns)...)
{{- $tableNameSingular := titleCaseSingular .Table -}}
{{- $dbName := singular .Table -}}
{{- $tableNamePlural := titleCasePlural .Table -}}
{{- $varNamePlural := camelCasePlural .Table -}}
// {{$tableNamePlural}}Where retrieves all records with the specified column values.
func {{$tableNamePlural}}Where(db boil.DB, columns map[string]interface{}) ([]*{{$tableNameSingular}}, error) {
var {{$varNamePlural}} []*{{$tableNameSingular}}
query := fmt.Sprintf(`SELECT {{selectParamNames $dbName .Columns}} FROM {{.Table}} WHERE %s`, boil.Where(columns))
err := db.Select(&{{$varNamePlural}}, query, boil.WhereParams(columns)...)
if err != nil {
return nil, fmt.Errorf("{{.PkgName}}: unable to select from {{$tableName}}: %s", err)
return nil, fmt.Errorf("{{.PkgName}}: unable to select from {{.Table}}: %s", err)
}
return {{$varName}}, nil
return {{$varNamePlural}}, nil
}