Finished insert template

* Removed functions in helpers not being used
This commit is contained in:
Patrick O'brien 2016-05-02 16:34:25 +10:00
parent a957bc3836
commit f059bdebf4
10 changed files with 576 additions and 98 deletions

View file

@ -49,6 +49,22 @@ func checkType(obj interface{}) (reflect.Type, bool, error) {
return typ, isSlice, nil return typ, isSlice, nil
} }
// GetStructValues returns the values (as interface) of the matching columns in obj
func GetStructValues(obj interface{}, columns ...string) []interface{} {
ret := make([]interface{}, len(columns))
val := reflect.ValueOf(obj)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
for i, c := range columns {
field := val.FieldByName(strmangle.TitleCase(c))
ret[i] = field.Interface()
}
return ret
}
// GetStructPointers returns a slice of pointers to the matching columns in obj // GetStructPointers returns a slice of pointers to the matching columns in obj
func GetStructPointers(obj interface{}, columns ...string) []interface{} { func GetStructPointers(obj interface{}, columns ...string) []interface{} {
val := reflect.ValueOf(obj).Elem() val := reflect.ValueOf(obj).Elem()

View file

@ -1,6 +1,54 @@
package boil package boil
import "testing" import (
"testing"
"time"
"github.com/guregu/null"
)
func TestGetStructValues(t *testing.T) {
t.Parallel()
timeThing := time.Now()
o := struct {
TitleThing string
Name string
ID int
Stuff int
Things int
Time time.Time
NullBool null.Bool
}{
TitleThing: "patrick",
Stuff: 10,
Things: 0,
Time: timeThing,
NullBool: null.NewBool(true, false),
}
vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool")
if vals[0].(string) != "patrick" {
t.Errorf("Want test, got %s", vals[0])
}
if vals[1].(string) != "" {
t.Errorf("Want empty string, got %s", vals[1])
}
if vals[2].(int) != 0 {
t.Errorf("Want 0, got %d", vals[2])
}
if vals[3].(int) != 10 {
t.Errorf("Want 10, got %d", vals[3])
}
if vals[4].(int) != 0 {
t.Errorf("Want 0, got %d", vals[4])
}
if !vals[5].(time.Time).Equal(timeThing) {
t.Errorf("Want %s, got %s", o.Time, vals[5])
}
if !vals[6].(null.Bool).IsZero() {
t.Errorf("Want %v, got %v", o.NullBool, vals[6])
}
}
func TestGetStructPointers(t *testing.T) { func TestGetStructPointers(t *testing.T) {
t.Parallel() t.Parallel()

View file

@ -7,8 +7,79 @@ import (
"sort" "sort"
"strings" "strings"
"unicode" "unicode"
"github.com/pobri19/sqlboiler/strmangle"
) )
// SetComplement subtracts the elements in b from a
func SetComplement(a []string, b []string) []string {
c := make([]string, 0, len(a))
for _, aVal := range a {
found := false
for _, bVal := range b {
if aVal == bVal {
found = true
break
}
}
if !found {
c = append(c, aVal)
}
}
return c
}
// SetIntersect returns the elements that are common in a and b
func SetIntersect(a []string, b []string) []string {
c := make([]string, 0, len(a))
for _, aVal := range a {
found := false
for _, bVal := range b {
if aVal == bVal {
found = true
break
}
}
if found {
c = append(c, aVal)
}
}
return c
}
// NonZeroDefaultSet returns the fields included in the
// defaults slice that are non zero values
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
c := make([]string, 0, len(defaults))
val := reflect.ValueOf(obj)
for _, d := range defaults {
fieldName := strmangle.TitleCase(d)
field := val.FieldByName(fieldName)
if !field.IsValid() {
panic(fmt.Sprintf("Could not find field name %s in type %T", fieldName, obj))
}
zero := reflect.Zero(field.Type())
if !reflect.DeepEqual(zero.Interface(), field.Interface()) {
c = append(c, d)
}
}
return c
}
// GenerateParamFlags generates the SQL statement parameter flags
// For example, $1,$2,$3 etc. It will start counting at startAt.
func GenerateParamFlags(colCount int, startAt int) string {
return strmangle.GenerateParamFlags(colCount, startAt)
}
// WherePrimaryKeyIn generates a "in" string for where queries // WherePrimaryKeyIn generates a "in" string for where queries
// For example: (col1, col2) IN (($1, $2), ($3, $4)) // For example: (col1, col2) IN (($1, $2), ($3, $4))
func WherePrimaryKeyIn(numRows int, keyNames ...string) string { func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
@ -81,15 +152,13 @@ func SelectNames(results interface{}) string {
// WhereClause returns the where clause for an sql statement // WhereClause returns the where clause for an sql statement
// eg: col1=$1 AND col2=$2 AND col3=$3 // eg: col1=$1 AND col2=$2 AND col3=$3
func WhereClause(columns map[string]interface{}) string { func WhereClause(columns []string) string {
names := make([]string, 0, len(columns)) names := make([]string, 0, len(columns))
for c := range columns { for _, c := range columns {
names = append(names, c) names = append(names, c)
} }
sort.Strings(names)
for i, c := range names { for i, c := range names {
names[i] = fmt.Sprintf("%s=$%d", c, i+1) names[i] = fmt.Sprintf("%s=$%d", c, i+1)
} }
@ -115,24 +184,6 @@ func Update(columns map[string]interface{}) string {
return strings.Join(names, ",") return strings.Join(names, ",")
} }
// WhereParams returns a list of sql parameter values for the query
func WhereParams(columns map[string]interface{}) []interface{} {
names := make([]string, 0, len(columns))
results := make([]interface{}, 0, len(columns))
for c := range columns {
names = append(names, c)
}
sort.Strings(names)
for _, c := range names {
results = append(results, columns[c])
}
return results
}
// SetParamNames takes a slice of columns and returns a comma seperated // SetParamNames takes a slice of columns and returns a comma seperated
// list of parameter names for a template statement SET clause. // list of parameter names for a template statement SET clause.
// eg: col1=$1,col2=$2,col3=$3 // eg: col1=$1,col2=$2,col3=$3

View file

@ -1,8 +1,11 @@
package boil package boil
import ( import (
"reflect"
"testing" "testing"
"time" "time"
"github.com/guregu/null"
) )
type testObj struct { type testObj struct {
@ -11,6 +14,134 @@ type testObj struct {
HeadSize int HeadSize int
} }
func TestSetComplement(t *testing.T) {
t.Parallel()
tests := []struct {
A []string
B []string
C []string
}{
{
[]string{"thing1", "thing2", "thing3"},
[]string{"thing2", "otherthing", "stuff"},
[]string{"thing1", "thing3"},
},
{
[]string{},
[]string{"thing1", "thing2"},
[]string{},
},
{
[]string{"thing1", "thing2"},
[]string{},
[]string{"thing1", "thing2"},
},
{
[]string{"thing1", "thing2"},
[]string{"thing1", "thing2"},
[]string{},
},
}
for i, test := range tests {
c := SetComplement(test.A, test.B)
if !reflect.DeepEqual(test.C, c) {
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, c)
}
}
}
func TestSetIntersect(t *testing.T) {
t.Parallel()
tests := []struct {
A []string
B []string
C []string
}{
{
[]string{"thing1", "thing2", "thing3"},
[]string{"thing2", "otherthing", "stuff"},
[]string{"thing2"},
},
{
[]string{},
[]string{"thing1", "thing2"},
[]string{},
},
{
[]string{"thing1", "thing2"},
[]string{},
[]string{},
},
{
[]string{"thing1", "thing2"},
[]string{"thing1", "thing2"},
[]string{"thing1", "thing2"},
},
}
for i, test := range tests {
c := SetIntersect(test.A, test.B)
if !reflect.DeepEqual(test.C, c) {
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, c)
}
}
}
func TestNonZeroDefaultSet(t *testing.T) {
t.Parallel()
type Anything struct {
ID int
Name string
CreatedAt *time.Time
UpdatedAt null.Time
}
now := time.Now()
tests := []struct {
Defaults []string
Obj interface{}
Ret []string
}{
{
[]string{"id"},
Anything{Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
[]string{},
},
{
[]string{"id"},
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
[]string{"id"},
},
{
[]string{},
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
[]string{},
},
{
[]string{"id", "created_at", "updated_at"},
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
[]string{"id"},
},
{
[]string{"id", "created_at", "updated_at"},
Anything{ID: 5, Name: "hi", CreatedAt: &now, UpdatedAt: null.Time{Valid: true, Time: time.Now()}},
[]string{"id", "created_at", "updated_at"},
},
}
for i, test := range tests {
z := NonZeroDefaultSet(test.Defaults, test.Obj)
if !reflect.DeepEqual(test.Ret, z) {
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.Ret, z)
}
}
}
func TestWherePrimaryKeyIn(t *testing.T) { func TestWherePrimaryKeyIn(t *testing.T) {
t.Parallel() t.Parallel()
@ -105,34 +236,15 @@ func TestSelectNames(t *testing.T) {
func TestWhereClause(t *testing.T) { func TestWhereClause(t *testing.T) {
t.Parallel() t.Parallel()
columns := map[string]interface{}{ columns := []string{
"name": "bob", "id",
"id": 5, "name",
"date": time.Now(), "date",
} }
result := WhereClause(columns) result := WhereClause(columns)
if result != `date=$1 AND id=$2 AND name=$3` { if result != `id=$1 AND name=$2 AND date=$3` {
t.Error("Result was wrong, got:", result) t.Error("Result was wrong, got:", result)
} }
} }
func TestWhereParams(t *testing.T) {
t.Parallel()
columns := map[string]interface{}{
"name": "bob",
"id": 5,
}
result := WhereParams(columns)
if result[0].(int) != 5 {
t.Error("Result[0] was wrong, got:", result[0])
}
if result[1].(string) != "bob" {
t.Error("Result[1] was wrong, got:", result[1])
}
}

View file

@ -37,6 +37,7 @@ var sqlBoilerImports = imports{
standard: importList{ standard: importList{
`"errors"`, `"errors"`,
`"fmt"`, `"fmt"`,
`"strings"`,
}, },
thirdparty: importList{ thirdparty: importList{
`"github.com/pobri19/sqlboiler/boil"`, `"github.com/pobri19/sqlboiler/boil"`,
@ -84,30 +85,33 @@ var sqlBoilerTestMainImports = map[string]imports{
// sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates. // sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates.
// If you wish to pass a new function into your own template, add a pointer to it here. // If you wish to pass a new function into your own template, add a pointer to it here.
var sqlBoilerTemplateFuncs = template.FuncMap{ var sqlBoilerTemplateFuncs = template.FuncMap{
"singular": strmangle.Singular, "singular": strmangle.Singular,
"plural": strmangle.Plural, "plural": strmangle.Plural,
"titleCase": strmangle.TitleCase, "titleCase": strmangle.TitleCase,
"titleCaseSingular": strmangle.TitleCaseSingular, "titleCaseSingular": strmangle.TitleCaseSingular,
"titleCasePlural": strmangle.TitleCasePlural, "titleCasePlural": strmangle.TitleCasePlural,
"camelCase": strmangle.CamelCase, "camelCase": strmangle.CamelCase,
"camelCaseSingular": strmangle.CamelCaseSingular, "camelCaseSingular": strmangle.CamelCaseSingular,
"camelCasePlural": strmangle.CamelCasePlural, "camelCasePlural": strmangle.CamelCasePlural,
"camelCaseCommaList": strmangle.CamelCaseCommaList, "camelCaseCommaList": strmangle.CamelCaseCommaList,
"commaList": strmangle.CommaList, "commaList": strmangle.CommaList,
"makeDBName": strmangle.MakeDBName, "makeDBName": strmangle.MakeDBName,
"selectParamNames": strmangle.SelectParamNames, "selectParamNames": strmangle.SelectParamNames,
"insertParamNames": strmangle.InsertParamNames, "insertParamNames": strmangle.InsertParamNames,
"insertParamFlags": strmangle.InsertParamFlags, "insertParamFlags": strmangle.InsertParamFlags,
"insertParamVariables": strmangle.InsertParamVariables, "insertParamVariables": strmangle.InsertParamVariables,
"scanParamNames": strmangle.ScanParamNames, "scanParamNames": strmangle.ScanParamNames,
"hasPrimaryKey": strmangle.HasPrimaryKey, "hasPrimaryKey": strmangle.HasPrimaryKey,
"primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig, "primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig,
"wherePrimaryKey": strmangle.WherePrimaryKey, "wherePrimaryKey": strmangle.WherePrimaryKey,
"paramsPrimaryKey": strmangle.ParamsPrimaryKey, "paramsPrimaryKey": strmangle.ParamsPrimaryKey,
"primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex, "primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex,
"updateParamNames": strmangle.UpdateParamNames, "updateParamNames": strmangle.UpdateParamNames,
"updateParamVariables": strmangle.UpdateParamVariables, "updateParamVariables": strmangle.UpdateParamVariables,
"primaryKeyStrList": strmangle.PrimaryKeyStrList, "primaryKeyStrList": strmangle.PrimaryKeyStrList,
"supportsResultObject": strmangle.SupportsResultObject,
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
} }
// LoadConfigFile loads the toml config file into the cfg object // LoadConfigFile loads the toml config file into the cfg object

View file

@ -229,6 +229,26 @@ func initTables(tableName string, cmdData *CmdData) error {
return errors.New("No tables found in database, migrate some tables first") return errors.New("No tables found in database, migrate some tables first")
} }
if err := checkPKeys(cmdData.Tables); err != nil {
return err
}
return nil
}
// checkPKeys ensures every table has a primary key column
func checkPKeys(tables []dbdrivers.Table) error {
var missingPkey []string
for _, t := range tables {
if t.PKey == nil {
missingPkey = append(missingPkey, t.Name)
}
}
if len(missingPkey) != 0 {
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
}
return nil return nil
} }

View file

@ -40,8 +40,23 @@ func init() {
{ {
Name: "spiderman", Name: "spiderman",
Columns: []dbdrivers.Column{ Columns: []dbdrivers.Column{
{Name: "id", Type: "int64", IsNullable: false},
},
PKey: &dbdrivers.PrimaryKey{
Name: "pkey_id",
Columns: []string{"id"},
},
},
{
Name: "spiderman_table_two",
Columns: []dbdrivers.Column{
{Name: "id", Type: "int64", IsNullable: false},
{Name: "patrick", Type: "string", IsNullable: false}, {Name: "patrick", Type: "string", IsNullable: false},
}, },
PKey: &dbdrivers.PrimaryKey{
Name: "pkey_id",
Columns: []string{"id"},
},
}, },
}, },
PkgName: "patrick", PkgName: "patrick",
@ -69,6 +84,10 @@ func TestTemplates(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
if err := checkPKeys(cmdData.Tables); err != nil {
t.Fatalf("%s", err)
}
// Initialize the templates // Initialize the templates
var err error var err error
cmdData.Templates, err = loadTemplates("templates") cmdData.Templates, err = loadTemplates("templates")

View file

@ -1,47 +1,86 @@
{{- if hasPrimaryKey .Table.PKey -}}
{{- $tableNameSingular := titleCaseSingular .Table.Name -}} {{- $tableNameSingular := titleCaseSingular .Table.Name -}}
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
// {{$tableNameSingular}}Insert inserts a single record. // {{$tableNameSingular}}Insert inserts a single record.
func (o *{{$tableNameSingular}}) Insert(whitelist ... string) error { func (o *{{$tableNameSingular}}) Insert(whitelist ... string) error {
if o == nil { return o.InsertX(boil.GetDB(), whitelist...)
return 0, errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
}
if err := o.doBeforeCreateHooks(); err != nil {
return 0, err
}
var rowID int
err := boil.GetDB().QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID)
if err != nil {
return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table.Name}}: %s", err)
}
if err := o.doAfterCreateHooks(); err != nil {
return 0, err
}
return rowID, nil
} }
var {{$varNameSingular}}DefaultInsertWhitelist = []string{{"{"}}{{filterColumnsByDefault .Table.Columns false}}{{"}"}}
var {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{filterColumnsByDefault .Table.Columns true}}{{"}"}}
var {{$varNameSingular}}AutoIncPrimaryKey = "{{autoIncPrimaryKey .Table.Columns .Table.PKey}}"
func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... string) error { func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... string) error {
if o == nil { if o == nil {
return 0, errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
} }
if len(whitelist) == 0 {
whitelist = {{$varNameSingular}}DefaultInsertWhitelist
}
nzDefaultSet := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o)
if len(nzDefaultSet) != 0 {
whitelist = append(nzDefaultSet, whitelist...)
}
// Only return the columns with default values that are not in the insert whitelist
returnColumns := boil.SetComplement({{$varNameSingular}}ColumnsWithDefault, whitelist)
var err error
if err := o.doBeforeCreateHooks(); err != nil { if err := o.doBeforeCreateHooks(); err != nil {
return 0, err return err
} }
var rowID int ins := fmt.Sprintf(`INSERT INTO {{.Table.Name}} (%s) VALUES (%s)`, strings.Join(whitelist, ","), boil.GenerateParamFlags(len(whitelist), 1))
err := boil.GetDB().QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID)
{{if supportsResultObject .DriverName}}
if len(returnColumns) != 0 {
result, err := exec.Exec(ins, boil.GetStructValues(o, whitelist...))
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
}
lastId, err := result.lastInsertId()
if err != nil || lastId == 0 {
sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, ","), boil.WhereClause(whitelist))
rows, err := exec.Query(sel, boil.GetStructValues(o, whitelist...))
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
}
defer rows.Close()
i := 0
ptrs := boil.GetStructPointers(o, returnColumns...)
for rows.Next() {
if err := rows.Scan(ptrs[i]); err != nil {
return fmt.Errorf("{{.PkgName}}: unable to get result of insert, scan failed for column %s index %d: %s\n\n%#v", returnColumns[i], i, err, ptrs)
}
i++
}
} else if {{$varNameSingular}}AutoIncPrimKey != "" {
sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s=$1`, strings.Join(returnColumns, ","), {{$varNameSingular}}AutoIncPrimaryKey, lastId)
}
} else {
_, err = exec.Exec(ins, boil.GetStructValues(o, whitelist...))
}
{{else}}
if len(returnColumns) != 0 {
ins = ins + fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ","))
err = exec.QueryRow(ins, boil.GetStructValues(o, whitelist...)).Scan(boil.GetStructPointers(o, returnColumns...))
} else {
_, err = exec.Exec(ins, {{insertParamVariables "o." .Table.Columns}})
}
{{end}}
if err != nil { if err != nil {
return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table.Name}}: %s", err) return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
} }
if err := o.doAfterCreateHooks(); err != nil { if err := o.doAfterCreateHooks(); err != nil {
return 0, err return err
} }
return rowID, nil return nil
} }
{{- end -}}

View file

@ -2,12 +2,15 @@ package strmangle
import ( import (
"fmt" "fmt"
"regexp"
"strings" "strings"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"github.com/pobri19/sqlboiler/dbdrivers" "github.com/pobri19/sqlboiler/dbdrivers"
) )
var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`)
// Plural converts singular words to plural words (eg: person to people) // Plural converts singular words to plural words (eg: person to people)
func Plural(name string) string { func Plural(name string) string {
splits := strings.Split(name, "_") splits := strings.Split(name, "_")
@ -247,6 +250,18 @@ func PrimaryKeyFuncSig(cols []dbdrivers.Column, pkeyCols []string) string {
return strings.Join(output, ", ") return strings.Join(output, ", ")
} }
// GenerateParamFlags generates the SQL statement parameter flags
// For example, $1,$2,$3 etc. It will start counting at startAt.
func GenerateParamFlags(colCount int, startAt int) string {
cols := make([]string, 0, colCount)
for i := startAt; i < colCount+startAt; i++ {
cols = append(cols, fmt.Sprintf("$%d", i))
}
return strings.Join(cols, ",")
}
// WherePrimaryKey returns the where clause using start as the $ flag index // WherePrimaryKey returns the where clause using start as the $ flag index
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
func WherePrimaryKey(pkeyCols []string, start int) string { func WherePrimaryKey(pkeyCols []string, start int) string {
@ -280,6 +295,26 @@ func PrimaryKeyStrList(pkeyCols []string) string {
return strings.Join(cols, ", ") return strings.Join(cols, ", ")
} }
// AutoIncPrimKey returns the auto-increment primary key column name or an empty string
func AutoIncPrimaryKey(cols []dbdrivers.Column, pkey *dbdrivers.PrimaryKey) string {
if pkey == nil {
return ""
}
for _, c := range cols {
if rgxAutoIncColumn.MatchString(c.Default) &&
c.IsNullable == false && c.Type == "int64" {
for _, p := range pkey.Columns {
if c.Name == p {
return p
}
}
}
}
return ""
}
// CommaList returns a comma seperated list: "col1, col2, col3" // CommaList returns a comma seperated list: "col1, col2, col3"
func CommaList(cols []string) string { func CommaList(cols []string) string {
return strings.Join(cols, ", ") return strings.Join(cols, ", ")
@ -307,3 +342,32 @@ func ParamsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) str
func PrimaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int { func PrimaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
return len(regularCols) - len(pkeyCols) + 1 return len(regularCols) - len(pkeyCols) + 1
} }
// SupportsResult returns whether the database driver supports the sql.Results
// interface, i.e. LastReturnId and RowsAffected
func SupportsResultObject(driverName string) bool {
switch driverName {
case "postgres":
return false
default:
return true
}
}
// FilterColumnsByDefault generates the list of columns that have default values
func FilterColumnsByDefault(columns []dbdrivers.Column, defaults bool) string {
var cols []string
for _, c := range columns {
if (defaults && len(c.Default) != 0) || (!defaults && len(c.Default) == 0) {
cols = append(cols, fmt.Sprintf(`"%s"`, c.Name))
}
}
return strings.Join(cols, `,`)
}
// DEFAULT WHITELIST: The things that are not default values. The things we want to insert all the time.
// WHITELIST: The things that we will NEVER return. The things that we will ALWAYS insert.
// DEFAULTS: The things that we will return (if not in WHITELIST)
// NON-ZEROS: The things that we will return (if not in WHITELIST)

View file

@ -11,6 +11,73 @@ var testColumns = []dbdrivers.Column{
{Name: "enemy_column_thing", Type: "string", IsNullable: true}, {Name: "enemy_column_thing", Type: "string", IsNullable: true},
} }
func TestAutoIncPrimaryKey(t *testing.T) {
t.Parallel()
var pkey *dbdrivers.PrimaryKey
var cols []dbdrivers.Column
r := AutoIncPrimaryKey(cols, pkey)
if r != "" {
t.Errorf("Expected empty string, got %s", r)
}
pkey = &dbdrivers.PrimaryKey{
Columns: []string{
"col1", "auto",
},
Name: "",
}
cols = []dbdrivers.Column{
{
Name: "thing",
IsNullable: true,
Type: "int64",
Default: "nextval('abc'::regclass)",
},
{
Name: "stuff",
IsNullable: false,
Type: "string",
Default: "nextval('abc'::regclass)",
},
{
Name: "other",
IsNullable: false,
Type: "int64",
Default: "nextval",
},
}
r = AutoIncPrimaryKey(cols, pkey)
if r != "" {
t.Errorf("Expected empty string, got %s", r)
}
cols = append(cols, dbdrivers.Column{
Name: "auto",
IsNullable: false,
Type: "int64",
Default: "nextval('abc'::regclass)",
})
r = AutoIncPrimaryKey(cols, pkey)
if r != "auto" {
t.Errorf("Expected empty string, got %s", r)
}
}
func TestGenerateParamFlags(t *testing.T) {
t.Parallel()
x := GenerateParamFlags(5, 1)
want := "$1,$2,$3,$4,$5"
if want != x {
t.Errorf("want %s, got %s", want, x)
}
}
func TestSingular(t *testing.T) { func TestSingular(t *testing.T) {
t.Parallel() t.Parallel()
@ -263,3 +330,41 @@ func TestWherePrimaryKey(t *testing.T) {
} }
} }
} }
func TestFilterColumnsByDefault(t *testing.T) {
t.Parallel()
cols := []dbdrivers.Column{
{
Name: "col1",
Default: "",
},
{
Name: "col2",
Default: "things",
},
{
Name: "col3",
Default: "",
},
{
Name: "col4",
Default: "things2",
},
}
res := FilterColumnsByDefault(cols, false)
if res != `"col1","col3"` {
t.Errorf("Invalid result: %s", res)
}
res = FilterColumnsByDefault(cols, true)
if res != `"col2","col4"` {
t.Errorf("Invalid result: %s", res)
}
res = FilterColumnsByDefault([]dbdrivers.Column{}, false)
if res != `` {
t.Errorf("Invalid result: %s", res)
}
}