2016-02-29 10:39:49 +01:00
|
|
|
package cmds
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"fmt"
|
|
|
|
"go/format"
|
|
|
|
"strings"
|
|
|
|
"text/template"
|
|
|
|
|
2016-03-18 12:26:48 +01:00
|
|
|
"github.com/jinzhu/inflection"
|
2016-02-29 10:39:49 +01:00
|
|
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
|
|
|
)
|
|
|
|
|
|
|
|
// generateTemplate generates the template associated to the passed in command name.
|
2016-03-01 15:20:13 +01:00
|
|
|
func generateTemplate(commandName string, data *tplData) []byte {
|
|
|
|
template := getTemplate(commandName)
|
2016-02-29 10:39:49 +01:00
|
|
|
|
|
|
|
if template == nil {
|
|
|
|
errorQuit(fmt.Errorf("Unable to find the template: %s", commandName+".tpl"))
|
|
|
|
}
|
|
|
|
|
2016-03-01 15:20:13 +01:00
|
|
|
output, err := processTemplate(template, data)
|
2016-02-29 10:39:49 +01:00
|
|
|
if err != nil {
|
2016-03-03 20:30:48 +01:00
|
|
|
errorQuit(fmt.Errorf("Unable to process the template %s for table %s: %s", template.Name(), data.Table, err))
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-01 15:20:13 +01:00
|
|
|
return output
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-01 15:20:13 +01:00
|
|
|
// getTemplate returns a pointer to the template matching the passed in name
|
|
|
|
func getTemplate(name string) *template.Template {
|
|
|
|
var tpl *template.Template
|
2016-02-29 10:39:49 +01:00
|
|
|
|
2016-03-01 15:20:13 +01:00
|
|
|
// Find the template that matches the passed in template name
|
|
|
|
for _, t := range templates {
|
|
|
|
if t.Name() == name+".tpl" {
|
|
|
|
tpl = t
|
|
|
|
break
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
2016-03-01 15:20:13 +01:00
|
|
|
}
|
2016-02-29 10:39:49 +01:00
|
|
|
|
2016-03-01 15:20:13 +01:00
|
|
|
return tpl
|
|
|
|
}
|
2016-02-29 10:39:49 +01:00
|
|
|
|
2016-03-01 15:20:13 +01:00
|
|
|
// processTemplate takes a template and returns the output of the template execution.
|
|
|
|
func processTemplate(t *template.Template, data *tplData) ([]byte, error) {
|
|
|
|
var buf bytes.Buffer
|
|
|
|
if err := t.Execute(&buf, data); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
output, err := format.Source(buf.Bytes())
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-01 15:20:13 +01:00
|
|
|
return output, nil
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-18 12:26:48 +01:00
|
|
|
// plural converts singular words to plural words (eg: person to people)
|
|
|
|
func plural(name string) string {
|
|
|
|
splits := strings.Split(name, "_")
|
|
|
|
splits[len(splits)-1] = inflection.Plural(splits[len(splits)-1])
|
|
|
|
return strings.Join(splits, "_")
|
|
|
|
}
|
|
|
|
|
|
|
|
// singular converts plural words to singular words (eg: people to person)
|
|
|
|
func singular(name string) string {
|
|
|
|
splits := strings.Split(name, "_")
|
|
|
|
splits[len(splits)-1] = inflection.Singular(splits[len(splits)-1])
|
|
|
|
return strings.Join(splits, "_")
|
|
|
|
}
|
|
|
|
|
2016-03-02 04:11:47 +01:00
|
|
|
// titleCase changes a snake-case variable name
|
|
|
|
// into a go styled object variable name of "ColumnName".
|
2016-03-01 17:34:57 +01:00
|
|
|
// titleCase also fully uppercases "ID" components of names, for example
|
2016-02-29 10:39:49 +01:00
|
|
|
// "column_name_id" to "ColumnNameID".
|
2016-03-01 17:34:57 +01:00
|
|
|
func titleCase(name string) string {
|
|
|
|
splits := strings.Split(name, "_")
|
2016-02-29 10:39:49 +01:00
|
|
|
|
2016-03-01 17:34:57 +01:00
|
|
|
for i, split := range splits {
|
|
|
|
if split == "id" {
|
|
|
|
splits[i] = "ID"
|
2016-02-29 10:39:49 +01:00
|
|
|
continue
|
|
|
|
}
|
2016-03-01 17:34:57 +01:00
|
|
|
|
|
|
|
splits[i] = strings.Title(split)
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-01 17:34:57 +01:00
|
|
|
return strings.Join(splits, "")
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-18 12:26:48 +01:00
|
|
|
// titleCaseSingular changes a snake-case variable name
|
|
|
|
// to a go styled object variable name of "ColumnName".
|
|
|
|
// titleCaseSingular also converts the last word in the
|
|
|
|
// variable name to a singularized version of itself.
|
|
|
|
func titleCaseSingular(name string) string {
|
|
|
|
return titleCase(singular(name))
|
|
|
|
}
|
|
|
|
|
|
|
|
// titleCasePlural changes a snake-case variable name
|
|
|
|
// to a go styled object variable name of "ColumnName".
|
|
|
|
// titleCasePlural also converts the last word in the
|
|
|
|
// variable name to a pluralized version of itself.
|
|
|
|
func titleCasePlural(name string) string {
|
|
|
|
return titleCase(plural(name))
|
|
|
|
}
|
|
|
|
|
2016-03-01 17:34:57 +01:00
|
|
|
// camelCase takes a variable name in the format of "var_name" and converts
|
2016-02-29 10:39:49 +01:00
|
|
|
// it into a go styled variable name of "varName".
|
2016-03-01 17:34:57 +01:00
|
|
|
// camelCase also fully uppercases "ID" components of names, for example
|
2016-02-29 10:39:49 +01:00
|
|
|
// "var_name_id" to "varNameID".
|
2016-03-01 17:34:57 +01:00
|
|
|
func camelCase(name string) string {
|
|
|
|
splits := strings.Split(name, "_")
|
2016-02-29 10:39:49 +01:00
|
|
|
|
2016-03-01 17:34:57 +01:00
|
|
|
for i, split := range splits {
|
|
|
|
if split == "id" && i > 0 {
|
2016-03-02 04:11:47 +01:00
|
|
|
splits[i] = "ID"
|
2016-02-29 10:39:49 +01:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if i == 0 {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2016-03-01 17:34:57 +01:00
|
|
|
splits[i] = strings.Title(split)
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-01 17:34:57 +01:00
|
|
|
return strings.Join(splits, "")
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-18 12:26:48 +01:00
|
|
|
// camelCaseSingular takes a variable name in the format of "var_name" and converts
|
|
|
|
// it into a go styled variable name of "varName".
|
|
|
|
// camelCaseSingular also converts the last word in the
|
|
|
|
// variable name to a singularized version of itself.
|
|
|
|
func camelCaseSingular(name string) string {
|
|
|
|
return camelCase(singular(name))
|
|
|
|
}
|
|
|
|
|
|
|
|
// camelCasePlural takes a variable name in the format of "var_name" and converts
|
|
|
|
// it into a go styled variable name of "varName".
|
|
|
|
// camelCasePlural also converts the last word in the
|
|
|
|
// variable name to a pluralized version of itself.
|
|
|
|
func camelCasePlural(name string) string {
|
|
|
|
return camelCase(plural(name))
|
|
|
|
}
|
|
|
|
|
2016-03-01 16:53:56 +01:00
|
|
|
// makeDBName takes a table name in the format of "table_name" and a
|
2016-02-29 10:39:49 +01:00
|
|
|
// column name in the format of "column_name" and returns a name used in the
|
|
|
|
// `db:""` component of an object in the format of "table_name_column_name"
|
2016-03-01 16:53:56 +01:00
|
|
|
func makeDBName(tableName, colName string) string {
|
2016-02-29 10:39:49 +01:00
|
|
|
return tableName + "_" + colName
|
|
|
|
}
|
|
|
|
|
2016-03-19 07:22:10 +01:00
|
|
|
// updateParamNames takes a []DBColumn and returns a comma seperated
|
|
|
|
// list of parameter names for the update statement template SET clause.
|
|
|
|
// eg: col1=$1,col2=$2,col3=$3
|
|
|
|
// Note: updateParamNames will exclude the PRIMARY KEY column.
|
|
|
|
func updateParamNames(columns []dbdrivers.DBColumn) string {
|
|
|
|
names := make([]string, 0, len(columns))
|
|
|
|
counter := 0
|
|
|
|
for _, c := range columns {
|
|
|
|
if c.IsPrimaryKey {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
counter++
|
|
|
|
names = append(names, fmt.Sprintf("%s=$%d", c.Name, counter))
|
|
|
|
}
|
|
|
|
return strings.Join(names, ",")
|
|
|
|
}
|
|
|
|
|
|
|
|
// updateParamVariables takes a prefix and a []DBColumns and returns a
|
|
|
|
// comma seperated list of parameter variable names for the update statement.
|
|
|
|
// eg: prefix("o."), column("name_id") -> "o.NameID, ..."
|
|
|
|
// Note: updateParamVariables will exclude the PRIMARY KEY column.
|
|
|
|
func updateParamVariables(prefix string, columns []dbdrivers.DBColumn) string {
|
|
|
|
names := make([]string, 0, len(columns))
|
|
|
|
|
|
|
|
for _, c := range columns {
|
|
|
|
if c.IsPrimaryKey {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
n := prefix + titleCase(c.Name)
|
|
|
|
names = append(names, n)
|
|
|
|
}
|
|
|
|
|
|
|
|
return strings.Join(names, ", ")
|
|
|
|
}
|
|
|
|
|
2016-03-01 17:34:57 +01:00
|
|
|
// insertParamNames takes a []DBColumn and returns a comma seperated
|
2016-02-29 10:39:49 +01:00
|
|
|
// list of parameter names for the insert statement template.
|
2016-03-01 17:34:57 +01:00
|
|
|
func insertParamNames(columns []dbdrivers.DBColumn) string {
|
|
|
|
names := make([]string, 0, len(columns))
|
|
|
|
for _, c := range columns {
|
|
|
|
names = append(names, c.Name)
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
2016-03-01 17:34:57 +01:00
|
|
|
return strings.Join(names, ", ")
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-01 17:34:57 +01:00
|
|
|
// insertParamFlags takes a []DBColumn and returns a comma seperated
|
2016-02-29 10:39:49 +01:00
|
|
|
// list of parameter flags for the insert statement template.
|
2016-03-01 17:34:57 +01:00
|
|
|
func insertParamFlags(columns []dbdrivers.DBColumn) string {
|
|
|
|
params := make([]string, 0, len(columns))
|
|
|
|
for i := range columns {
|
|
|
|
params = append(params, fmt.Sprintf("$%d", i+1))
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
2016-03-01 17:34:57 +01:00
|
|
|
return strings.Join(params, ", ")
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
|
|
|
|
2016-03-18 12:26:48 +01:00
|
|
|
// insertParamVariables takes a prefix and a []DBColumns and returns a
|
|
|
|
// comma seperated list of parameter variable names for the insert statement.
|
|
|
|
// For example: prefix("o."), column("name_id") -> "o.NameID, ..."
|
|
|
|
func insertParamVariables(prefix string, columns []dbdrivers.DBColumn) string {
|
|
|
|
names := make([]string, 0, len(columns))
|
|
|
|
|
|
|
|
for _, c := range columns {
|
|
|
|
n := prefix + titleCase(c.Name)
|
|
|
|
names = append(names, n)
|
|
|
|
}
|
|
|
|
|
|
|
|
return strings.Join(names, ", ")
|
|
|
|
}
|
|
|
|
|
2016-03-01 17:34:57 +01:00
|
|
|
// selectParamNames takes a []DBColumn and returns a comma seperated
|
2016-02-29 10:39:49 +01:00
|
|
|
// list of parameter names with for the select statement template.
|
|
|
|
// It also uses the table name to generate the "AS" part of the statement, for
|
|
|
|
// example: var_name AS table_name_var_name, ...
|
2016-03-02 04:11:47 +01:00
|
|
|
func selectParamNames(tableName string, columns []dbdrivers.DBColumn) string {
|
2016-03-01 17:34:57 +01:00
|
|
|
selects := make([]string, 0, len(columns))
|
|
|
|
for _, c := range columns {
|
2016-03-02 04:11:47 +01:00
|
|
|
statement := fmt.Sprintf("%s AS %s", c.Name, makeDBName(tableName, c.Name))
|
2016-03-01 17:34:57 +01:00
|
|
|
selects = append(selects, statement)
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
2016-03-01 17:34:57 +01:00
|
|
|
|
|
|
|
return strings.Join(selects, ", ")
|
2016-02-29 10:39:49 +01:00
|
|
|
}
|
2016-03-02 05:05:25 +01:00
|
|
|
|
|
|
|
// scanParamNames takes a []DBColumn and returns a comma seperated
|
|
|
|
// list of parameter names for use in a db.Scan() call.
|
|
|
|
func scanParamNames(object string, columns []dbdrivers.DBColumn) string {
|
|
|
|
scans := make([]string, 0, len(columns))
|
|
|
|
for _, c := range columns {
|
|
|
|
statement := fmt.Sprintf("&%s.%s", object, titleCase(c.Name))
|
|
|
|
scans = append(scans, statement)
|
|
|
|
}
|
|
|
|
|
|
|
|
return strings.Join(scans, ", ")
|
|
|
|
}
|
2016-03-18 16:27:55 +01:00
|
|
|
|
|
|
|
// hasPrimaryKey returns true if one of the columns passed in is a primary key
|
|
|
|
func hasPrimaryKey(columns []dbdrivers.DBColumn) bool {
|
|
|
|
for _, c := range columns {
|
|
|
|
if c.IsPrimaryKey {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
// getPrimaryKey returns the primary key column name if one is present
|
|
|
|
func getPrimaryKey(columns []dbdrivers.DBColumn) string {
|
|
|
|
for _, c := range columns {
|
|
|
|
if c.IsPrimaryKey {
|
|
|
|
return c.Name
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return ""
|
|
|
|
}
|