Finished insert template
* Removed functions in helpers not being used
This commit is contained in:
parent
a957bc3836
commit
f059bdebf4
10 changed files with 576 additions and 98 deletions
16
boil/bind.go
16
boil/bind.go
|
@ -49,6 +49,22 @@ func checkType(obj interface{}) (reflect.Type, bool, error) {
|
|||
return typ, isSlice, nil
|
||||
}
|
||||
|
||||
// GetStructValues returns the values (as interface) of the matching columns in obj
|
||||
func GetStructValues(obj interface{}, columns ...string) []interface{} {
|
||||
ret := make([]interface{}, len(columns))
|
||||
val := reflect.ValueOf(obj)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
for i, c := range columns {
|
||||
field := val.FieldByName(strmangle.TitleCase(c))
|
||||
ret[i] = field.Interface()
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetStructPointers returns a slice of pointers to the matching columns in obj
|
||||
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
|
||||
val := reflect.ValueOf(obj).Elem()
|
||||
|
|
|
@ -1,6 +1,54 @@
|
|||
package boil
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/guregu/null"
|
||||
)
|
||||
|
||||
func TestGetStructValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
timeThing := time.Now()
|
||||
o := struct {
|
||||
TitleThing string
|
||||
Name string
|
||||
ID int
|
||||
Stuff int
|
||||
Things int
|
||||
Time time.Time
|
||||
NullBool null.Bool
|
||||
}{
|
||||
TitleThing: "patrick",
|
||||
Stuff: 10,
|
||||
Things: 0,
|
||||
Time: timeThing,
|
||||
NullBool: null.NewBool(true, false),
|
||||
}
|
||||
|
||||
vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool")
|
||||
if vals[0].(string) != "patrick" {
|
||||
t.Errorf("Want test, got %s", vals[0])
|
||||
}
|
||||
if vals[1].(string) != "" {
|
||||
t.Errorf("Want empty string, got %s", vals[1])
|
||||
}
|
||||
if vals[2].(int) != 0 {
|
||||
t.Errorf("Want 0, got %d", vals[2])
|
||||
}
|
||||
if vals[3].(int) != 10 {
|
||||
t.Errorf("Want 10, got %d", vals[3])
|
||||
}
|
||||
if vals[4].(int) != 0 {
|
||||
t.Errorf("Want 0, got %d", vals[4])
|
||||
}
|
||||
if !vals[5].(time.Time).Equal(timeThing) {
|
||||
t.Errorf("Want %s, got %s", o.Time, vals[5])
|
||||
}
|
||||
if !vals[6].(null.Bool).IsZero() {
|
||||
t.Errorf("Want %v, got %v", o.NullBool, vals[6])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStructPointers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -7,8 +7,79 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/pobri19/sqlboiler/strmangle"
|
||||
)
|
||||
|
||||
// SetComplement subtracts the elements in b from a
|
||||
func SetComplement(a []string, b []string) []string {
|
||||
c := make([]string, 0, len(a))
|
||||
|
||||
for _, aVal := range a {
|
||||
found := false
|
||||
for _, bVal := range b {
|
||||
if aVal == bVal {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
c = append(c, aVal)
|
||||
}
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// SetIntersect returns the elements that are common in a and b
|
||||
func SetIntersect(a []string, b []string) []string {
|
||||
c := make([]string, 0, len(a))
|
||||
|
||||
for _, aVal := range a {
|
||||
found := false
|
||||
for _, bVal := range b {
|
||||
if aVal == bVal {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
c = append(c, aVal)
|
||||
}
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// NonZeroDefaultSet returns the fields included in the
|
||||
// defaults slice that are non zero values
|
||||
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
|
||||
c := make([]string, 0, len(defaults))
|
||||
|
||||
val := reflect.ValueOf(obj)
|
||||
|
||||
for _, d := range defaults {
|
||||
fieldName := strmangle.TitleCase(d)
|
||||
field := val.FieldByName(fieldName)
|
||||
if !field.IsValid() {
|
||||
panic(fmt.Sprintf("Could not find field name %s in type %T", fieldName, obj))
|
||||
}
|
||||
|
||||
zero := reflect.Zero(field.Type())
|
||||
if !reflect.DeepEqual(zero.Interface(), field.Interface()) {
|
||||
c = append(c, d)
|
||||
}
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// GenerateParamFlags generates the SQL statement parameter flags
|
||||
// For example, $1,$2,$3 etc. It will start counting at startAt.
|
||||
func GenerateParamFlags(colCount int, startAt int) string {
|
||||
return strmangle.GenerateParamFlags(colCount, startAt)
|
||||
}
|
||||
|
||||
// WherePrimaryKeyIn generates a "in" string for where queries
|
||||
// For example: (col1, col2) IN (($1, $2), ($3, $4))
|
||||
func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
|
||||
|
@ -81,15 +152,13 @@ func SelectNames(results interface{}) string {
|
|||
|
||||
// WhereClause returns the where clause for an sql statement
|
||||
// eg: col1=$1 AND col2=$2 AND col3=$3
|
||||
func WhereClause(columns map[string]interface{}) string {
|
||||
func WhereClause(columns []string) string {
|
||||
names := make([]string, 0, len(columns))
|
||||
|
||||
for c := range columns {
|
||||
for _, c := range columns {
|
||||
names = append(names, c)
|
||||
}
|
||||
|
||||
sort.Strings(names)
|
||||
|
||||
for i, c := range names {
|
||||
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
|
||||
}
|
||||
|
@ -115,24 +184,6 @@ func Update(columns map[string]interface{}) string {
|
|||
return strings.Join(names, ",")
|
||||
}
|
||||
|
||||
// WhereParams returns a list of sql parameter values for the query
|
||||
func WhereParams(columns map[string]interface{}) []interface{} {
|
||||
names := make([]string, 0, len(columns))
|
||||
results := make([]interface{}, 0, len(columns))
|
||||
|
||||
for c := range columns {
|
||||
names = append(names, c)
|
||||
}
|
||||
|
||||
sort.Strings(names)
|
||||
|
||||
for _, c := range names {
|
||||
results = append(results, columns[c])
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// SetParamNames takes a slice of columns and returns a comma seperated
|
||||
// list of parameter names for a template statement SET clause.
|
||||
// eg: col1=$1,col2=$2,col3=$3
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
package boil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/guregu/null"
|
||||
)
|
||||
|
||||
type testObj struct {
|
||||
|
@ -11,6 +14,134 @@ type testObj struct {
|
|||
HeadSize int
|
||||
}
|
||||
|
||||
func TestSetComplement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
A []string
|
||||
B []string
|
||||
C []string
|
||||
}{
|
||||
{
|
||||
[]string{"thing1", "thing2", "thing3"},
|
||||
[]string{"thing2", "otherthing", "stuff"},
|
||||
[]string{"thing1", "thing3"},
|
||||
},
|
||||
{
|
||||
[]string{},
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{},
|
||||
[]string{"thing1", "thing2"},
|
||||
},
|
||||
{
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
c := SetComplement(test.A, test.B)
|
||||
if !reflect.DeepEqual(test.C, c) {
|
||||
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetIntersect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
A []string
|
||||
B []string
|
||||
C []string
|
||||
}{
|
||||
{
|
||||
[]string{"thing1", "thing2", "thing3"},
|
||||
[]string{"thing2", "otherthing", "stuff"},
|
||||
[]string{"thing2"},
|
||||
},
|
||||
{
|
||||
[]string{},
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{"thing1", "thing2"},
|
||||
[]string{"thing1", "thing2"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
c := SetIntersect(test.A, test.B)
|
||||
if !reflect.DeepEqual(test.C, c) {
|
||||
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonZeroDefaultSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type Anything struct {
|
||||
ID int
|
||||
Name string
|
||||
CreatedAt *time.Time
|
||||
UpdatedAt null.Time
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
Defaults []string
|
||||
Obj interface{}
|
||||
Ret []string
|
||||
}{
|
||||
{
|
||||
[]string{"id"},
|
||||
Anything{Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
[]string{"id"},
|
||||
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||
[]string{"id"},
|
||||
},
|
||||
{
|
||||
[]string{},
|
||||
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
[]string{"id", "created_at", "updated_at"},
|
||||
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||
[]string{"id"},
|
||||
},
|
||||
{
|
||||
[]string{"id", "created_at", "updated_at"},
|
||||
Anything{ID: 5, Name: "hi", CreatedAt: &now, UpdatedAt: null.Time{Valid: true, Time: time.Now()}},
|
||||
[]string{"id", "created_at", "updated_at"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
z := NonZeroDefaultSet(test.Defaults, test.Obj)
|
||||
if !reflect.DeepEqual(test.Ret, z) {
|
||||
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.Ret, z)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWherePrimaryKeyIn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -105,34 +236,15 @@ func TestSelectNames(t *testing.T) {
|
|||
func TestWhereClause(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
columns := map[string]interface{}{
|
||||
"name": "bob",
|
||||
"id": 5,
|
||||
"date": time.Now(),
|
||||
columns := []string{
|
||||
"id",
|
||||
"name",
|
||||
"date",
|
||||
}
|
||||
|
||||
result := WhereClause(columns)
|
||||
|
||||
if result != `date=$1 AND id=$2 AND name=$3` {
|
||||
if result != `id=$1 AND name=$2 AND date=$3` {
|
||||
t.Error("Result was wrong, got:", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhereParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
columns := map[string]interface{}{
|
||||
"name": "bob",
|
||||
"id": 5,
|
||||
}
|
||||
|
||||
result := WhereParams(columns)
|
||||
|
||||
if result[0].(int) != 5 {
|
||||
t.Error("Result[0] was wrong, got:", result[0])
|
||||
}
|
||||
|
||||
if result[1].(string) != "bob" {
|
||||
t.Error("Result[1] was wrong, got:", result[1])
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,6 +37,7 @@ var sqlBoilerImports = imports{
|
|||
standard: importList{
|
||||
`"errors"`,
|
||||
`"fmt"`,
|
||||
`"strings"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/pobri19/sqlboiler/boil"`,
|
||||
|
@ -108,6 +109,9 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
|
|||
"updateParamNames": strmangle.UpdateParamNames,
|
||||
"updateParamVariables": strmangle.UpdateParamVariables,
|
||||
"primaryKeyStrList": strmangle.PrimaryKeyStrList,
|
||||
"supportsResultObject": strmangle.SupportsResultObject,
|
||||
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
|
||||
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
|
||||
}
|
||||
|
||||
// LoadConfigFile loads the toml config file into the cfg object
|
||||
|
|
|
@ -229,6 +229,26 @@ func initTables(tableName string, cmdData *CmdData) error {
|
|||
return errors.New("No tables found in database, migrate some tables first")
|
||||
}
|
||||
|
||||
if err := checkPKeys(cmdData.Tables); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPKeys ensures every table has a primary key column
|
||||
func checkPKeys(tables []dbdrivers.Table) error {
|
||||
var missingPkey []string
|
||||
for _, t := range tables {
|
||||
if t.PKey == nil {
|
||||
missingPkey = append(missingPkey, t.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingPkey) != 0 {
|
||||
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -40,8 +40,23 @@ func init() {
|
|||
{
|
||||
Name: "spiderman",
|
||||
Columns: []dbdrivers.Column{
|
||||
{Name: "id", Type: "int64", IsNullable: false},
|
||||
},
|
||||
PKey: &dbdrivers.PrimaryKey{
|
||||
Name: "pkey_id",
|
||||
Columns: []string{"id"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "spiderman_table_two",
|
||||
Columns: []dbdrivers.Column{
|
||||
{Name: "id", Type: "int64", IsNullable: false},
|
||||
{Name: "patrick", Type: "string", IsNullable: false},
|
||||
},
|
||||
PKey: &dbdrivers.PrimaryKey{
|
||||
Name: "pkey_id",
|
||||
Columns: []string{"id"},
|
||||
},
|
||||
},
|
||||
},
|
||||
PkgName: "patrick",
|
||||
|
@ -69,6 +84,10 @@ func TestTemplates(t *testing.T) {
|
|||
t.SkipNow()
|
||||
}
|
||||
|
||||
if err := checkPKeys(cmdData.Tables); err != nil {
|
||||
t.Fatalf("%s", err)
|
||||
}
|
||||
|
||||
// Initialize the templates
|
||||
var err error
|
||||
cmdData.Templates, err = loadTemplates("templates")
|
||||
|
|
|
@ -1,47 +1,86 @@
|
|||
{{- if hasPrimaryKey .Table.PKey -}}
|
||||
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
|
||||
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
|
||||
// {{$tableNameSingular}}Insert inserts a single record.
|
||||
func (o *{{$tableNameSingular}}) Insert(whitelist ... string) error {
|
||||
if o == nil {
|
||||
return 0, errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
|
||||
return o.InsertX(boil.GetDB(), whitelist...)
|
||||
}
|
||||
|
||||
if err := o.doBeforeCreateHooks(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var rowID int
|
||||
err := boil.GetDB().QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table.Name}}: %s", err)
|
||||
}
|
||||
|
||||
if err := o.doAfterCreateHooks(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return rowID, nil
|
||||
}
|
||||
var {{$varNameSingular}}DefaultInsertWhitelist = []string{{"{"}}{{filterColumnsByDefault .Table.Columns false}}{{"}"}}
|
||||
var {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{filterColumnsByDefault .Table.Columns true}}{{"}"}}
|
||||
var {{$varNameSingular}}AutoIncPrimaryKey = "{{autoIncPrimaryKey .Table.Columns .Table.PKey}}"
|
||||
|
||||
func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... string) error {
|
||||
if o == nil {
|
||||
return 0, errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
|
||||
return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
|
||||
}
|
||||
|
||||
if len(whitelist) == 0 {
|
||||
whitelist = {{$varNameSingular}}DefaultInsertWhitelist
|
||||
}
|
||||
|
||||
nzDefaultSet := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o)
|
||||
if len(nzDefaultSet) != 0 {
|
||||
whitelist = append(nzDefaultSet, whitelist...)
|
||||
}
|
||||
|
||||
// Only return the columns with default values that are not in the insert whitelist
|
||||
returnColumns := boil.SetComplement({{$varNameSingular}}ColumnsWithDefault, whitelist)
|
||||
|
||||
var err error
|
||||
if err := o.doBeforeCreateHooks(); err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
|
||||
var rowID int
|
||||
err := boil.GetDB().QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID)
|
||||
ins := fmt.Sprintf(`INSERT INTO {{.Table.Name}} (%s) VALUES (%s)`, strings.Join(whitelist, ","), boil.GenerateParamFlags(len(whitelist), 1))
|
||||
|
||||
{{if supportsResultObject .DriverName}}
|
||||
if len(returnColumns) != 0 {
|
||||
result, err := exec.Exec(ins, boil.GetStructValues(o, whitelist...))
|
||||
if err != nil {
|
||||
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
|
||||
}
|
||||
|
||||
lastId, err := result.lastInsertId()
|
||||
if err != nil || lastId == 0 {
|
||||
sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, ","), boil.WhereClause(whitelist))
|
||||
rows, err := exec.Query(sel, boil.GetStructValues(o, whitelist...))
|
||||
if err != nil {
|
||||
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
i := 0
|
||||
ptrs := boil.GetStructPointers(o, returnColumns...)
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(ptrs[i]); err != nil {
|
||||
return fmt.Errorf("{{.PkgName}}: unable to get result of insert, scan failed for column %s index %d: %s\n\n%#v", returnColumns[i], i, err, ptrs)
|
||||
}
|
||||
i++
|
||||
}
|
||||
} else if {{$varNameSingular}}AutoIncPrimKey != "" {
|
||||
sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s=$1`, strings.Join(returnColumns, ","), {{$varNameSingular}}AutoIncPrimaryKey, lastId)
|
||||
}
|
||||
} else {
|
||||
_, err = exec.Exec(ins, boil.GetStructValues(o, whitelist...))
|
||||
}
|
||||
{{else}}
|
||||
if len(returnColumns) != 0 {
|
||||
ins = ins + fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ","))
|
||||
err = exec.QueryRow(ins, boil.GetStructValues(o, whitelist...)).Scan(boil.GetStructPointers(o, returnColumns...))
|
||||
} else {
|
||||
_, err = exec.Exec(ins, {{insertParamVariables "o." .Table.Columns}})
|
||||
}
|
||||
{{end}}
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table.Name}}: %s", err)
|
||||
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
|
||||
}
|
||||
|
||||
if err := o.doAfterCreateHooks(); err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
|
||||
return rowID, nil
|
||||
return nil
|
||||
}
|
||||
{{- end -}}
|
||||
|
|
|
@ -2,12 +2,15 @@ package strmangle
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
)
|
||||
|
||||
var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`)
|
||||
|
||||
// Plural converts singular words to plural words (eg: person to people)
|
||||
func Plural(name string) string {
|
||||
splits := strings.Split(name, "_")
|
||||
|
@ -247,6 +250,18 @@ func PrimaryKeyFuncSig(cols []dbdrivers.Column, pkeyCols []string) string {
|
|||
return strings.Join(output, ", ")
|
||||
}
|
||||
|
||||
// GenerateParamFlags generates the SQL statement parameter flags
|
||||
// For example, $1,$2,$3 etc. It will start counting at startAt.
|
||||
func GenerateParamFlags(colCount int, startAt int) string {
|
||||
cols := make([]string, 0, colCount)
|
||||
|
||||
for i := startAt; i < colCount+startAt; i++ {
|
||||
cols = append(cols, fmt.Sprintf("$%d", i))
|
||||
}
|
||||
|
||||
return strings.Join(cols, ",")
|
||||
}
|
||||
|
||||
// WherePrimaryKey returns the where clause using start as the $ flag index
|
||||
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
|
||||
func WherePrimaryKey(pkeyCols []string, start int) string {
|
||||
|
@ -280,6 +295,26 @@ func PrimaryKeyStrList(pkeyCols []string) string {
|
|||
return strings.Join(cols, ", ")
|
||||
}
|
||||
|
||||
// AutoIncPrimKey returns the auto-increment primary key column name or an empty string
|
||||
func AutoIncPrimaryKey(cols []dbdrivers.Column, pkey *dbdrivers.PrimaryKey) string {
|
||||
if pkey == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, c := range cols {
|
||||
if rgxAutoIncColumn.MatchString(c.Default) &&
|
||||
c.IsNullable == false && c.Type == "int64" {
|
||||
for _, p := range pkey.Columns {
|
||||
if c.Name == p {
|
||||
return p
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// CommaList returns a comma seperated list: "col1, col2, col3"
|
||||
func CommaList(cols []string) string {
|
||||
return strings.Join(cols, ", ")
|
||||
|
@ -307,3 +342,32 @@ func ParamsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) str
|
|||
func PrimaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
|
||||
return len(regularCols) - len(pkeyCols) + 1
|
||||
}
|
||||
|
||||
// SupportsResult returns whether the database driver supports the sql.Results
|
||||
// interface, i.e. LastReturnId and RowsAffected
|
||||
func SupportsResultObject(driverName string) bool {
|
||||
switch driverName {
|
||||
case "postgres":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// FilterColumnsByDefault generates the list of columns that have default values
|
||||
func FilterColumnsByDefault(columns []dbdrivers.Column, defaults bool) string {
|
||||
var cols []string
|
||||
|
||||
for _, c := range columns {
|
||||
if (defaults && len(c.Default) != 0) || (!defaults && len(c.Default) == 0) {
|
||||
cols = append(cols, fmt.Sprintf(`"%s"`, c.Name))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(cols, `,`)
|
||||
}
|
||||
|
||||
// DEFAULT WHITELIST: The things that are not default values. The things we want to insert all the time.
|
||||
// WHITELIST: The things that we will NEVER return. The things that we will ALWAYS insert.
|
||||
// DEFAULTS: The things that we will return (if not in WHITELIST)
|
||||
// NON-ZEROS: The things that we will return (if not in WHITELIST)
|
||||
|
|
|
@ -11,6 +11,73 @@ var testColumns = []dbdrivers.Column{
|
|||
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||
}
|
||||
|
||||
func TestAutoIncPrimaryKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var pkey *dbdrivers.PrimaryKey
|
||||
var cols []dbdrivers.Column
|
||||
|
||||
r := AutoIncPrimaryKey(cols, pkey)
|
||||
if r != "" {
|
||||
t.Errorf("Expected empty string, got %s", r)
|
||||
}
|
||||
|
||||
pkey = &dbdrivers.PrimaryKey{
|
||||
Columns: []string{
|
||||
"col1", "auto",
|
||||
},
|
||||
Name: "",
|
||||
}
|
||||
|
||||
cols = []dbdrivers.Column{
|
||||
{
|
||||
Name: "thing",
|
||||
IsNullable: true,
|
||||
Type: "int64",
|
||||
Default: "nextval('abc'::regclass)",
|
||||
},
|
||||
{
|
||||
Name: "stuff",
|
||||
IsNullable: false,
|
||||
Type: "string",
|
||||
Default: "nextval('abc'::regclass)",
|
||||
},
|
||||
{
|
||||
Name: "other",
|
||||
IsNullable: false,
|
||||
Type: "int64",
|
||||
Default: "nextval",
|
||||
},
|
||||
}
|
||||
|
||||
r = AutoIncPrimaryKey(cols, pkey)
|
||||
if r != "" {
|
||||
t.Errorf("Expected empty string, got %s", r)
|
||||
}
|
||||
|
||||
cols = append(cols, dbdrivers.Column{
|
||||
Name: "auto",
|
||||
IsNullable: false,
|
||||
Type: "int64",
|
||||
Default: "nextval('abc'::regclass)",
|
||||
})
|
||||
|
||||
r = AutoIncPrimaryKey(cols, pkey)
|
||||
if r != "auto" {
|
||||
t.Errorf("Expected empty string, got %s", r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateParamFlags(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
x := GenerateParamFlags(5, 1)
|
||||
want := "$1,$2,$3,$4,$5"
|
||||
if want != x {
|
||||
t.Errorf("want %s, got %s", want, x)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingular(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -263,3 +330,41 @@ func TestWherePrimaryKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterColumnsByDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cols := []dbdrivers.Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Default: "",
|
||||
},
|
||||
{
|
||||
Name: "col2",
|
||||
Default: "things",
|
||||
},
|
||||
{
|
||||
Name: "col3",
|
||||
Default: "",
|
||||
},
|
||||
{
|
||||
Name: "col4",
|
||||
Default: "things2",
|
||||
},
|
||||
}
|
||||
|
||||
res := FilterColumnsByDefault(cols, false)
|
||||
if res != `"col1","col3"` {
|
||||
t.Errorf("Invalid result: %s", res)
|
||||
}
|
||||
|
||||
res = FilterColumnsByDefault(cols, true)
|
||||
if res != `"col2","col4"` {
|
||||
t.Errorf("Invalid result: %s", res)
|
||||
}
|
||||
|
||||
res = FilterColumnsByDefault([]dbdrivers.Column{}, false)
|
||||
if res != `` {
|
||||
t.Errorf("Invalid result: %s", res)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue