Moved template functions to strmangle package
* Finished Find and FindX
This commit is contained in:
parent
ba8793ec1a
commit
c17e48c14a
7 changed files with 186 additions and 116 deletions
boil
cmds
strmangle
25
boil/bind.go
25
boil/bind.go
|
@ -3,17 +3,14 @@ package boil
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/pobri19/sqlboiler/strmangle"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (q *Query) Bind(obj interface{}) error {
|
func (q *Query) Bind(obj interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStructPointers(obj interface{}, columns ...string) []interface{} {
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkType(obj interface{}) (reflect.Type, bool, error) {
|
func checkType(obj interface{}) (reflect.Type, bool, error) {
|
||||||
val := reflect.ValueOf(obj)
|
val := reflect.ValueOf(obj)
|
||||||
typ := val.Type()
|
typ := val.Type()
|
||||||
|
@ -51,3 +48,21 @@ func checkType(obj interface{}) (reflect.Type, bool, error) {
|
||||||
|
|
||||||
return typ, isSlice, nil
|
return typ, isSlice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStructPointers returns a slice of pointers to the matching columns in obj
|
||||||
|
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
|
||||||
|
val := reflect.ValueOf(obj).Elem()
|
||||||
|
ret := make([]interface{}, len(columns))
|
||||||
|
|
||||||
|
for i, c := range columns {
|
||||||
|
field := val.FieldByName(strmangle.TitleCase(c))
|
||||||
|
if !field.IsValid() {
|
||||||
|
panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(c)))
|
||||||
|
}
|
||||||
|
|
||||||
|
field = field.Addr()
|
||||||
|
ret[i] = field.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,28 @@ package boil
|
||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
|
func TestGetStructPointers(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
o := struct {
|
||||||
|
Title string
|
||||||
|
ID *int
|
||||||
|
}{
|
||||||
|
Title: "patrick",
|
||||||
|
}
|
||||||
|
|
||||||
|
ptrs := GetStructPointers(&o, "title", "id")
|
||||||
|
*ptrs[0].(*string) = "test"
|
||||||
|
if o.Title != "test" {
|
||||||
|
t.Errorf("Expected test, got %s", o.Title)
|
||||||
|
}
|
||||||
|
x := 5
|
||||||
|
*ptrs[1].(**int) = &x
|
||||||
|
if *o.ID != 5 {
|
||||||
|
t.Errorf("Expected 5, got %d", *o.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCheckType(t *testing.T) {
|
func TestCheckType(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
|
"github.com/pobri19/sqlboiler/strmangle"
|
||||||
)
|
)
|
||||||
|
|
||||||
// sqlBoilerTypeImports imports are only included in the template output if the database
|
// sqlBoilerTypeImports imports are only included in the template output if the database
|
||||||
|
@ -83,28 +84,30 @@ var sqlBoilerTestMainImports = map[string]imports{
|
||||||
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
|
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
|
||||||
// If you wish to pass a new function into your own template, add a pointer to it here.
|
// If you wish to pass a new function into your own template, add a pointer to it here.
|
||||||
var sqlBoilerTemplateFuncs = template.FuncMap{
|
var sqlBoilerTemplateFuncs = template.FuncMap{
|
||||||
"singular": singular,
|
"singular": strmangle.Singular,
|
||||||
"plural": plural,
|
"plural": strmangle.Plural,
|
||||||
"titleCase": titleCase,
|
"titleCase": strmangle.TitleCase,
|
||||||
"titleCaseSingular": titleCaseSingular,
|
"titleCaseSingular": strmangle.TitleCaseSingular,
|
||||||
"titleCasePlural": titleCasePlural,
|
"titleCasePlural": strmangle.TitleCasePlural,
|
||||||
"camelCase": camelCase,
|
"camelCase": strmangle.CamelCase,
|
||||||
"camelCaseSingular": camelCaseSingular,
|
"camelCaseSingular": strmangle.CamelCaseSingular,
|
||||||
"camelCasePlural": camelCasePlural,
|
"camelCasePlural": strmangle.CamelCasePlural,
|
||||||
"commaList": commaList,
|
"camelCaseCommaList": strmangle.CamelCaseCommaList,
|
||||||
"makeDBName": makeDBName,
|
"commaList": strmangle.CommaList,
|
||||||
"selectParamNames": selectParamNames,
|
"makeDBName": strmangle.MakeDBName,
|
||||||
"insertParamNames": insertParamNames,
|
"selectParamNames": strmangle.SelectParamNames,
|
||||||
"insertParamFlags": insertParamFlags,
|
"insertParamNames": strmangle.InsertParamNames,
|
||||||
"insertParamVariables": insertParamVariables,
|
"insertParamFlags": strmangle.InsertParamFlags,
|
||||||
"scanParamNames": scanParamNames,
|
"insertParamVariables": strmangle.InsertParamVariables,
|
||||||
"hasPrimaryKey": hasPrimaryKey,
|
"scanParamNames": strmangle.ScanParamNames,
|
||||||
"wherePrimaryKey": wherePrimaryKey,
|
"hasPrimaryKey": strmangle.HasPrimaryKey,
|
||||||
"paramsPrimaryKey": paramsPrimaryKey,
|
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
|
||||||
"primaryKeyFlagIndex": primaryKeyFlagIndex,
|
"wherePrimaryKey": strmangle.WherePrimaryKey,
|
||||||
"updateParamNames": updateParamNames,
|
"paramsPrimaryKey": strmangle.ParamsPrimaryKey,
|
||||||
"updateParamVariables": updateParamVariables,
|
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
|
||||||
"primaryKeyStrList": primaryKeyStrList,
|
"updateParamNames": strmangle.UpdateParamNames,
|
||||||
|
"updateParamVariables": strmangle.UpdateParamVariables,
|
||||||
|
"primaryKeyStrList": strmangle.PrimaryKeyStrList,
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfigFile loads the toml config file into the cfg object
|
// LoadConfigFile loads the toml config file into the cfg object
|
||||||
|
|
|
@ -32,6 +32,10 @@ func init() {
|
||||||
{Name: "fun_time", Type: "time.Time", IsNullable: false},
|
{Name: "fun_time", Type: "time.Time", IsNullable: false},
|
||||||
{Name: "cool_stuff_forever", Type: "[]byte", IsNullable: false},
|
{Name: "cool_stuff_forever", Type: "[]byte", IsNullable: false},
|
||||||
},
|
},
|
||||||
|
PKey: &dbdrivers.PrimaryKey{
|
||||||
|
Name: "pkey_thing",
|
||||||
|
Columns: []string{"id", "fun_id"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "spiderman",
|
Name: "spiderman",
|
||||||
|
|
|
@ -1,32 +1,28 @@
|
||||||
|
{{- if hasPrimaryKey .Table.PKey -}}
|
||||||
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
|
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
|
||||||
{{- $dbName := singular .Table.Name -}}
|
{{- $dbName := singular .Table.Name -}}
|
||||||
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
|
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
|
||||||
// {{$tableNameSingular}}Find retrieves a single record by ID.
|
// {{$tableNameSingular}}Find retrieves a single record by ID.
|
||||||
func {{$tableNameSingular}}Find(id int64, columns ...string) (*{{$tableNameSingular}}, error) {
|
func {{$tableNameSingular}}Find({{primaryKeyFuncSig .Table.Columns .Table.PKey.Columns}}, columns ...string) (*{{$tableNameSingular}}, error) {
|
||||||
return {{$tableNameSingular}}FindX(boil.GetDB(), id, columns...)
|
return {{$tableNameSingular}}FindX(boil.GetDB(), {{camelCaseCommaList .Table.PKey.Columns}}, columns...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func {{$tableNameSingular}}FindX(exec boil.Executor, id int64, columns ...string) (*{{$tableNameSingular}}, error) {
|
func {{$tableNameSingular}}FindX(exec boil.Executor, {{primaryKeyFuncSig .Table.Columns .Table.PKey.Columns}}, columns ...string) (*{{$tableNameSingular}}, error) {
|
||||||
if id == 0 {
|
|
||||||
return nil, errors.New("{{.PkgName}}: no id provided for {{.Table.Name}} select")
|
|
||||||
}
|
|
||||||
|
|
||||||
var {{$varNameSingular}} *{{$tableNameSingular}}
|
var {{$varNameSingular}} *{{$tableNameSingular}}
|
||||||
mods := []qs.QueryMod{
|
mods := []qs.QueryMod{
|
||||||
qs.Select(columns...),
|
qs.Select(columns...),
|
||||||
qs.From("{{.Table.Name}}"),
|
qs.From("{{.Table.Name}}"),
|
||||||
qs.Where("id=$1", id),
|
qs.Where("{{wherePrimaryKey .Table.PKey.Columns 1}}", {{camelCaseCommaList .Table.PKey.Columns}}),
|
||||||
}
|
}
|
||||||
|
|
||||||
q := NewQueryX(exec, mods...)
|
q := NewQueryX(exec, mods...)
|
||||||
|
|
||||||
err := boil.ExecQueryOne(q).Scan(
|
err := boil.ExecQueryOne(q).Scan(boil.GetStructPointers(&{{$varNameSingular}}, columns...)...)
|
||||||
)
|
|
||||||
|
|
||||||
//GetStructPointers({{$varNameSingular}}, columnsthings)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("{{.PkgName}}: unable to select from {{.Table.Name}}: %s", err)
|
return nil, fmt.Errorf("{{.PkgName}}: unable to select from {{.Table.Name}}: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return {{$varNameSingular}}, nil
|
return {{$varNameSingular}}, nil
|
||||||
}
|
}
|
||||||
|
{{- end -}}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package cmds
|
package strmangle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -8,25 +8,25 @@ import (
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// plural converts singular words to plural words (eg: person to people)
|
// Plural converts singular words to plural words (eg: person to people)
|
||||||
func plural(name string) string {
|
func Plural(name string) string {
|
||||||
splits := strings.Split(name, "_")
|
splits := strings.Split(name, "_")
|
||||||
splits[len(splits)-1] = inflection.Plural(splits[len(splits)-1])
|
splits[len(splits)-1] = inflection.Plural(splits[len(splits)-1])
|
||||||
return strings.Join(splits, "_")
|
return strings.Join(splits, "_")
|
||||||
}
|
}
|
||||||
|
|
||||||
// singular converts plural words to singular words (eg: people to person)
|
// Singular converts plural words to singular words (eg: people to person)
|
||||||
func singular(name string) string {
|
func Singular(name string) string {
|
||||||
splits := strings.Split(name, "_")
|
splits := strings.Split(name, "_")
|
||||||
splits[len(splits)-1] = inflection.Singular(splits[len(splits)-1])
|
splits[len(splits)-1] = inflection.Singular(splits[len(splits)-1])
|
||||||
return strings.Join(splits, "_")
|
return strings.Join(splits, "_")
|
||||||
}
|
}
|
||||||
|
|
||||||
// titleCase changes a snake-case variable name
|
// TitleCase changes a snake-case variable name
|
||||||
// into a go styled object variable name of "ColumnName".
|
// into a go styled object variable name of "ColumnName".
|
||||||
// titleCase also fully uppercases "ID" components of names, for example
|
// titleCase also fully uppercases "ID" components of names, for example
|
||||||
// "column_name_id" to "ColumnNameID".
|
// "column_name_id" to "ColumnNameID".
|
||||||
func titleCase(name string) string {
|
func TitleCase(name string) string {
|
||||||
splits := strings.Split(name, "_")
|
splits := strings.Split(name, "_")
|
||||||
|
|
||||||
for i, split := range splits {
|
for i, split := range splits {
|
||||||
|
@ -41,27 +41,27 @@ func titleCase(name string) string {
|
||||||
return strings.Join(splits, "")
|
return strings.Join(splits, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// titleCaseSingular changes a snake-case variable name
|
// TitleCaseSingular changes a snake-case variable name
|
||||||
// to a go styled object variable name of "ColumnName".
|
// to a go styled object variable name of "ColumnName".
|
||||||
// titleCaseSingular also converts the last word in the
|
// titleCaseSingular also converts the last word in the
|
||||||
// variable name to a singularized version of itself.
|
// variable name to a singularized version of itself.
|
||||||
func titleCaseSingular(name string) string {
|
func TitleCaseSingular(name string) string {
|
||||||
return titleCase(singular(name))
|
return TitleCase(Singular(name))
|
||||||
}
|
}
|
||||||
|
|
||||||
// titleCasePlural changes a snake-case variable name
|
// TitleCasePlural changes a snake-case variable name
|
||||||
// to a go styled object variable name of "ColumnName".
|
// to a go styled object variable name of "ColumnName".
|
||||||
// titleCasePlural also converts the last word in the
|
// titleCasePlural also converts the last word in the
|
||||||
// variable name to a pluralized version of itself.
|
// variable name to a pluralized version of itself.
|
||||||
func titleCasePlural(name string) string {
|
func TitleCasePlural(name string) string {
|
||||||
return titleCase(plural(name))
|
return TitleCase(Plural(name))
|
||||||
}
|
}
|
||||||
|
|
||||||
// camelCase takes a variable name in the format of "var_name" and converts
|
// CamelCase takes a variable name in the format of "var_name" and converts
|
||||||
// it into a go styled variable name of "varName".
|
// it into a go styled variable name of "varName".
|
||||||
// camelCase also fully uppercases "ID" components of names, for example
|
// camelCase also fully uppercases "ID" components of names, for example
|
||||||
// "var_name_id" to "varNameID".
|
// "var_name_id" to "varNameID".
|
||||||
func camelCase(name string) string {
|
func CamelCase(name string) string {
|
||||||
splits := strings.Split(name, "_")
|
splits := strings.Split(name, "_")
|
||||||
|
|
||||||
for i, split := range splits {
|
for i, split := range splits {
|
||||||
|
@ -80,38 +80,50 @@ func camelCase(name string) string {
|
||||||
return strings.Join(splits, "")
|
return strings.Join(splits, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// camelCaseSingular takes a variable name in the format of "var_name" and converts
|
// CamelCaseSingular takes a variable name in the format of "var_name" and converts
|
||||||
// it into a go styled variable name of "varName".
|
// it into a go styled variable name of "varName".
|
||||||
// camelCaseSingular also converts the last word in the
|
// CamelCaseSingular also converts the last word in the
|
||||||
// variable name to a singularized version of itself.
|
// variable name to a singularized version of itself.
|
||||||
func camelCaseSingular(name string) string {
|
func CamelCaseSingular(name string) string {
|
||||||
return camelCase(singular(name))
|
return CamelCase(Singular(name))
|
||||||
}
|
}
|
||||||
|
|
||||||
// camelCasePlural takes a variable name in the format of "var_name" and converts
|
// CamelCasePlural takes a variable name in the format of "var_name" and converts
|
||||||
// it into a go styled variable name of "varName".
|
// it into a go styled variable name of "varName".
|
||||||
// camelCasePlural also converts the last word in the
|
// CamelCasePlural also converts the last word in the
|
||||||
// variable name to a pluralized version of itself.
|
// variable name to a pluralized version of itself.
|
||||||
func camelCasePlural(name string) string {
|
func CamelCasePlural(name string) string {
|
||||||
return camelCase(plural(name))
|
return CamelCase(Plural(name))
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeDBName takes a table name in the format of "table_name" and a
|
// CamelCaseCommaList generates a list of comma seperated camel cased column names
|
||||||
|
// example: thingName, stuffName, etc
|
||||||
|
func CamelCaseCommaList(pkeyColumns []string) string {
|
||||||
|
var output []string
|
||||||
|
|
||||||
|
for _, c := range pkeyColumns {
|
||||||
|
output = append(output, CamelCase(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(output, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeDBName takes a table name in the format of "table_name" and a
|
||||||
// column name in the format of "column_name" and returns a name used in the
|
// column name in the format of "column_name" and returns a name used in the
|
||||||
// `db:""` component of an object in the format of "table_name_column_name"
|
// `db:""` component of an object in the format of "table_name_column_name"
|
||||||
func makeDBName(tableName, colName string) string {
|
func MakeDBName(tableName, colName string) string {
|
||||||
return tableName + "_" + colName
|
return tableName + "_" + colName
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateParamNames takes a []Column and returns a comma seperated
|
// UpdateParamNames takes a []Column and returns a comma seperated
|
||||||
// list of parameter names for the update statement template SET clause.
|
// list of parameter names for the update statement template SET clause.
|
||||||
// eg: col1=$1,col2=$2,col3=$3
|
// eg: col1=$1,col2=$2,col3=$3
|
||||||
// Note: updateParamNames will exclude the PRIMARY KEY column.
|
// Note: updateParamNames will exclude the PRIMARY KEY column.
|
||||||
func updateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string {
|
func UpdateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
counter := 0
|
counter := 0
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
if isPrimaryKey(c.Name, pkeyColumns) {
|
if IsPrimaryKey(c.Name, pkeyColumns) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
counter++
|
counter++
|
||||||
|
@ -120,26 +132,26 @@ func updateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string {
|
||||||
return strings.Join(names, ",")
|
return strings.Join(names, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateParamVariables takes a prefix and a []Columns and returns a
|
// UpdateParamVariables takes a prefix and a []Columns and returns a
|
||||||
// comma seperated list of parameter variable names for the update statement.
|
// comma seperated list of parameter variable names for the update statement.
|
||||||
// eg: prefix("o."), column("name_id") -> "o.NameID, ..."
|
// eg: prefix("o."), column("name_id") -> "o.NameID, ..."
|
||||||
// Note: updateParamVariables will exclude the PRIMARY KEY column.
|
// Note: updateParamVariables will exclude the PRIMARY KEY column.
|
||||||
func updateParamVariables(prefix string, columns []dbdrivers.Column, pkeyColumns []string) string {
|
func UpdateParamVariables(prefix string, columns []dbdrivers.Column, pkeyColumns []string) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
if isPrimaryKey(c.Name, pkeyColumns) {
|
if IsPrimaryKey(c.Name, pkeyColumns) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
n := prefix + titleCase(c.Name)
|
n := prefix + TitleCase(c.Name)
|
||||||
names = append(names, n)
|
names = append(names, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(names, ", ")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPrimaryKey checks if the column is found in the primary key columns
|
// IsPrimaryKey checks if the column is found in the primary key columns
|
||||||
func isPrimaryKey(col string, pkeyCols []string) bool {
|
func IsPrimaryKey(col string, pkeyCols []string) bool {
|
||||||
for _, pkey := range pkeyCols {
|
for _, pkey := range pkeyCols {
|
||||||
if pkey == col {
|
if pkey == col {
|
||||||
return true
|
return true
|
||||||
|
@ -149,9 +161,9 @@ func isPrimaryKey(col string, pkeyCols []string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertParamNames takes a []Column and returns a comma seperated
|
// InsertParamNames takes a []Column and returns a comma seperated
|
||||||
// list of parameter names for the insert statement template.
|
// list of parameter names for the insert statement template.
|
||||||
func insertParamNames(columns []dbdrivers.Column) string {
|
func InsertParamNames(columns []dbdrivers.Column) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
names = append(names, c.Name)
|
names = append(names, c.Name)
|
||||||
|
@ -159,9 +171,9 @@ func insertParamNames(columns []dbdrivers.Column) string {
|
||||||
return strings.Join(names, ", ")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertParamFlags takes a []Column and returns a comma seperated
|
// InsertParamFlags takes a []Column and returns a comma seperated
|
||||||
// list of parameter flags for the insert statement template.
|
// list of parameter flags for the insert statement template.
|
||||||
func insertParamFlags(columns []dbdrivers.Column) string {
|
func InsertParamFlags(columns []dbdrivers.Column) string {
|
||||||
params := make([]string, 0, len(columns))
|
params := make([]string, 0, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
params = append(params, fmt.Sprintf("$%d", i+1))
|
params = append(params, fmt.Sprintf("$%d", i+1))
|
||||||
|
@ -169,48 +181,48 @@ func insertParamFlags(columns []dbdrivers.Column) string {
|
||||||
return strings.Join(params, ", ")
|
return strings.Join(params, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertParamVariables takes a prefix and a []Columns and returns a
|
// InsertParamVariables takes a prefix and a []Columns and returns a
|
||||||
// comma seperated list of parameter variable names for the insert statement.
|
// comma seperated list of parameter variable names for the insert statement.
|
||||||
// For example: prefix("o."), column("name_id") -> "o.NameID, ..."
|
// For example: prefix("o."), column("name_id") -> "o.NameID, ..."
|
||||||
func insertParamVariables(prefix string, columns []dbdrivers.Column) string {
|
func InsertParamVariables(prefix string, columns []dbdrivers.Column) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
n := prefix + titleCase(c.Name)
|
n := prefix + TitleCase(c.Name)
|
||||||
names = append(names, n)
|
names = append(names, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(names, ", ")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectParamNames takes a []Column and returns a comma seperated
|
// SelectParamNames takes a []Column and returns a comma seperated
|
||||||
// list of parameter names with for the select statement template.
|
// list of parameter names with for the select statement template.
|
||||||
// It also uses the table name to generate the "AS" part of the statement, for
|
// It also uses the table name to generate the "AS" part of the statement, for
|
||||||
// example: var_name AS table_name_var_name, ...
|
// example: var_name AS table_name_var_name, ...
|
||||||
func selectParamNames(tableName string, columns []dbdrivers.Column) string {
|
func SelectParamNames(tableName string, columns []dbdrivers.Column) string {
|
||||||
selects := make([]string, 0, len(columns))
|
selects := make([]string, 0, len(columns))
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
statement := fmt.Sprintf("%s AS %s", c.Name, makeDBName(tableName, c.Name))
|
statement := fmt.Sprintf("%s AS %s", c.Name, MakeDBName(tableName, c.Name))
|
||||||
selects = append(selects, statement)
|
selects = append(selects, statement)
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(selects, ", ")
|
return strings.Join(selects, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// scanParamNames takes a []Column and returns a comma seperated
|
// ScanParamNames takes a []Column and returns a comma seperated
|
||||||
// list of parameter names for use in a db.Scan() call.
|
// list of parameter names for use in a db.Scan() call.
|
||||||
func scanParamNames(object string, columns []dbdrivers.Column) string {
|
func ScanParamNames(object string, columns []dbdrivers.Column) string {
|
||||||
scans := make([]string, 0, len(columns))
|
scans := make([]string, 0, len(columns))
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
statement := fmt.Sprintf("&%s.%s", object, titleCase(c.Name))
|
statement := fmt.Sprintf("&%s.%s", object, TitleCase(c.Name))
|
||||||
scans = append(scans, statement)
|
scans = append(scans, statement)
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(scans, ", ")
|
return strings.Join(scans, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasPrimaryKey returns true if one of the columns passed in is a primary key
|
// HasPrimaryKey returns true if one of the columns passed in is a primary key
|
||||||
func hasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool {
|
func HasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool {
|
||||||
if pKey == nil || len(pKey.Columns) == 0 {
|
if pKey == nil || len(pKey.Columns) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -218,9 +230,26 @@ func hasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// wherePrimaryKey returns the where clause using start as the $ flag index
|
// PrimaryKeyFuncSig generates the function signature parameters.
|
||||||
|
// example: id int64, thingName string
|
||||||
|
func PrimaryKeyFuncSig(cols []dbdrivers.Column, pkeyCols []string) string {
|
||||||
|
var output []string
|
||||||
|
|
||||||
|
for _, pk := range pkeyCols {
|
||||||
|
for _, c := range cols {
|
||||||
|
if pk == c.Name {
|
||||||
|
output = append(output, fmt.Sprintf("%s %s", CamelCase(pk), c.Type))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(output, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// WherePrimaryKey returns the where clause using start as the $ flag index
|
||||||
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
|
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
|
||||||
func wherePrimaryKey(pkeyCols []string, start int) string {
|
func WherePrimaryKey(pkeyCols []string, start int) string {
|
||||||
var output string
|
var output string
|
||||||
|
|
||||||
cols := make([]string, len(pkeyCols))
|
cols := make([]string, len(pkeyCols))
|
||||||
|
@ -238,9 +267,9 @@ func wherePrimaryKey(pkeyCols []string, start int) string {
|
||||||
return output
|
return output
|
||||||
}
|
}
|
||||||
|
|
||||||
// primaryKeyStrList returns a list of primary key column names in strings
|
// PrimaryKeyStrList returns a list of primary key column names in strings
|
||||||
// For example: "col1", "col2", "col3"
|
// For example: "col1", "col2", "col3"
|
||||||
func primaryKeyStrList(pkeyCols []string) string {
|
func PrimaryKeyStrList(pkeyCols []string) string {
|
||||||
cols := make([]string, len(pkeyCols))
|
cols := make([]string, len(pkeyCols))
|
||||||
copy(cols, pkeyCols)
|
copy(cols, pkeyCols)
|
||||||
|
|
||||||
|
@ -251,20 +280,20 @@ func primaryKeyStrList(pkeyCols []string) string {
|
||||||
return strings.Join(cols, ", ")
|
return strings.Join(cols, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// commaList returns a comma seperated list: "col1, col2, col3"
|
// CommaList returns a comma seperated list: "col1, col2, col3"
|
||||||
func commaList(cols []string) string {
|
func CommaList(cols []string) string {
|
||||||
return strings.Join(cols, ", ")
|
return strings.Join(cols, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// paramsPrimaryKey returns the parameters for the sql statement $ flags
|
// ParamsPrimaryKey returns the parameters for the sql statement $ flags
|
||||||
// For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2"
|
// For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2"
|
||||||
func paramsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string {
|
func ParamsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
var n string
|
var n string
|
||||||
if shouldTitleCase {
|
if shouldTitleCase {
|
||||||
n = prefix + titleCase(c)
|
n = prefix + TitleCase(c)
|
||||||
} else {
|
} else {
|
||||||
n = prefix + c
|
n = prefix + c
|
||||||
}
|
}
|
||||||
|
@ -274,6 +303,7 @@ func paramsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) str
|
||||||
return strings.Join(names, ", ")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func primaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
|
// PrimaryKeyFlagIndex generates the primary key column flag number for the query params
|
||||||
|
func PrimaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
|
||||||
return len(regularCols) - len(pkeyCols) + 1
|
return len(regularCols) - len(pkeyCols) + 1
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package cmds
|
package strmangle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -24,7 +24,7 @@ func TestSingular(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
if out := singular(test.In); out != test.Out {
|
if out := Singular(test.In); out != test.Out {
|
||||||
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
|
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ func TestPlural(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
if out := plural(test.In); out != test.Out {
|
if out := Plural(test.In); out != test.Out {
|
||||||
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
|
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -62,7 +62,7 @@ func TestTitleCase(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
if out := titleCase(test.In); out != test.Out {
|
if out := TitleCase(test.In); out != test.Out {
|
||||||
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
|
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -81,7 +81,7 @@ func TestCamelCase(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
if out := camelCase(test.In); out != test.Out {
|
if out := CamelCase(test.In); out != test.Out {
|
||||||
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
|
t.Errorf("[%d] (%s) Out was wrong: %q, want: %q", i, test.In, out, test.Out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -90,7 +90,7 @@ func TestCamelCase(t *testing.T) {
|
||||||
func TestMakeDBName(t *testing.T) {
|
func TestMakeDBName(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
if out := makeDBName("a", "b"); out != "a_b" {
|
if out := MakeDBName("a", "b"); out != "a_b" {
|
||||||
t.Error("Out was wrong:", out)
|
t.Error("Out was wrong:", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ func TestUpdateParamNames(t *testing.T) {
|
||||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
out := updateParamNames(testCols, []string{"id"})
|
out := UpdateParamNames(testCols, []string{"id"})
|
||||||
if out != "friend_column=$1,enemy_column_thing=$2" {
|
if out != "friend_column=$1,enemy_column_thing=$2" {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,7 @@ func TestUpdateParamVariables(t *testing.T) {
|
||||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
out := updateParamVariables("o.", testCols, []string{"id"})
|
out := UpdateParamVariables("o.", testCols, []string{"id"})
|
||||||
if out != "o.FriendColumn, o.EnemyColumnThing" {
|
if out != "o.FriendColumn, o.EnemyColumnThing" {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
|
@ -128,7 +128,7 @@ func TestUpdateParamVariables(t *testing.T) {
|
||||||
func TestInsertParamNames(t *testing.T) {
|
func TestInsertParamNames(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
out := insertParamNames(testColumns)
|
out := InsertParamNames(testColumns)
|
||||||
if out != "friend_column, enemy_column_thing" {
|
if out != "friend_column, enemy_column_thing" {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
|
@ -137,14 +137,14 @@ func TestInsertParamNames(t *testing.T) {
|
||||||
func TestInsertParamFlags(t *testing.T) {
|
func TestInsertParamFlags(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
out := insertParamFlags(testColumns)
|
out := InsertParamFlags(testColumns)
|
||||||
if out != "$1, $2" {
|
if out != "$1, $2" {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInsertParamVariables(t *testing.T) {
|
func TestInsertParamVariables(t *testing.T) {
|
||||||
out := insertParamVariables("o.", testColumns)
|
out := InsertParamVariables("o.", testColumns)
|
||||||
if out != "o.FriendColumn, o.EnemyColumnThing" {
|
if out != "o.FriendColumn, o.EnemyColumnThing" {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
|
@ -153,7 +153,7 @@ func TestInsertParamVariables(t *testing.T) {
|
||||||
func TestSelectParamFlags(t *testing.T) {
|
func TestSelectParamFlags(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
out := selectParamNames("table", testColumns)
|
out := SelectParamNames("table", testColumns)
|
||||||
if out != "friend_column AS table_friend_column, enemy_column_thing AS table_enemy_column_thing" {
|
if out != "friend_column AS table_friend_column, enemy_column_thing AS table_enemy_column_thing" {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
|
@ -162,7 +162,7 @@ func TestSelectParamFlags(t *testing.T) {
|
||||||
func TestScanParams(t *testing.T) {
|
func TestScanParams(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
out := scanParamNames("object", testColumns)
|
out := ScanParamNames("object", testColumns)
|
||||||
if out != "&object.FriendColumn, &object.EnemyColumnThing" {
|
if out != "&object.FriendColumn, &object.EnemyColumnThing" {
|
||||||
t.Error("Wrong output:", out)
|
t.Error("Wrong output:", out)
|
||||||
}
|
}
|
||||||
|
@ -172,17 +172,17 @@ func TestHasPrimaryKey(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var pkey *dbdrivers.PrimaryKey
|
var pkey *dbdrivers.PrimaryKey
|
||||||
if hasPrimaryKey(pkey) {
|
if HasPrimaryKey(pkey) {
|
||||||
t.Errorf("1) Expected false, got true")
|
t.Errorf("1) Expected false, got true")
|
||||||
}
|
}
|
||||||
|
|
||||||
pkey = &dbdrivers.PrimaryKey{}
|
pkey = &dbdrivers.PrimaryKey{}
|
||||||
if hasPrimaryKey(pkey) {
|
if HasPrimaryKey(pkey) {
|
||||||
t.Errorf("2) Expected false, got true")
|
t.Errorf("2) Expected false, got true")
|
||||||
}
|
}
|
||||||
|
|
||||||
pkey.Columns = append(pkey.Columns, "test")
|
pkey.Columns = append(pkey.Columns, "test")
|
||||||
if !hasPrimaryKey(pkey) {
|
if !HasPrimaryKey(pkey) {
|
||||||
t.Errorf("3) Expected true, got false")
|
t.Errorf("3) Expected true, got false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -210,7 +210,7 @@ func TestParamsPrimaryKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, true)
|
r := ParamsPrimaryKey(test.Prefix, test.Pkey.Columns, true)
|
||||||
if r != test.Should {
|
if r != test.Should {
|
||||||
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||||
}
|
}
|
||||||
|
@ -236,7 +236,7 @@ func TestParamsPrimaryKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests2 {
|
for i, test := range tests2 {
|
||||||
r := paramsPrimaryKey(test.Prefix, test.Pkey.Columns, false)
|
r := ParamsPrimaryKey(test.Prefix, test.Pkey.Columns, false)
|
||||||
if r != test.Should {
|
if r != test.Should {
|
||||||
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||||
}
|
}
|
||||||
|
@ -257,7 +257,7 @@ func TestWherePrimaryKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
r := wherePrimaryKey(test.Pkey.Columns, test.Start)
|
r := WherePrimaryKey(test.Pkey.Columns, test.Start)
|
||||||
if r != test.Should {
|
if r != test.Should {
|
||||||
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
|
||||||
}
|
}
|
Loading…
Add table
Add a link
Reference in a new issue