diff --git a/boil/bind.go b/boil/bind.go index ba6c8ee..cc090b9 100644 --- a/boil/bind.go +++ b/boil/bind.go @@ -49,6 +49,22 @@ func checkType(obj interface{}) (reflect.Type, bool, error) { return typ, isSlice, nil } +// GetStructValues returns the values (as interface) of the matching columns in obj +func GetStructValues(obj interface{}, columns ...string) []interface{} { + ret := make([]interface{}, len(columns)) + val := reflect.ValueOf(obj) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + for i, c := range columns { + field := val.FieldByName(strmangle.TitleCase(c)) + ret[i] = field.Interface() + } + + return ret +} + // GetStructPointers returns a slice of pointers to the matching columns in obj func GetStructPointers(obj interface{}, columns ...string) []interface{} { val := reflect.ValueOf(obj).Elem() diff --git a/boil/bind_test.go b/boil/bind_test.go index d3f7463..e4e530e 100644 --- a/boil/bind_test.go +++ b/boil/bind_test.go @@ -1,6 +1,54 @@ package boil -import "testing" +import ( + "testing" + "time" + + "github.com/guregu/null" +) + +func TestGetStructValues(t *testing.T) { + t.Parallel() + timeThing := time.Now() + o := struct { + TitleThing string + Name string + ID int + Stuff int + Things int + Time time.Time + NullBool null.Bool + }{ + TitleThing: "patrick", + Stuff: 10, + Things: 0, + Time: timeThing, + NullBool: null.NewBool(true, false), + } + + vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool") + if vals[0].(string) != "patrick" { + t.Errorf("Want test, got %s", vals[0]) + } + if vals[1].(string) != "" { + t.Errorf("Want empty string, got %s", vals[1]) + } + if vals[2].(int) != 0 { + t.Errorf("Want 0, got %d", vals[2]) + } + if vals[3].(int) != 10 { + t.Errorf("Want 10, got %d", vals[3]) + } + if vals[4].(int) != 0 { + t.Errorf("Want 0, got %d", vals[4]) + } + if !vals[5].(time.Time).Equal(timeThing) { + t.Errorf("Want %s, got %s", o.Time, vals[5]) + } + if !vals[6].(null.Bool).IsZero() { + t.Errorf("Want %v, got %v", o.NullBool, vals[6]) + } +} func TestGetStructPointers(t *testing.T) { t.Parallel() diff --git a/boil/helpers.go b/boil/helpers.go index dd19b98..c1103dd 100644 --- a/boil/helpers.go +++ b/boil/helpers.go @@ -7,8 +7,79 @@ import ( "sort" "strings" "unicode" + + "github.com/pobri19/sqlboiler/strmangle" ) +// SetComplement subtracts the elements in b from a +func SetComplement(a []string, b []string) []string { + c := make([]string, 0, len(a)) + + for _, aVal := range a { + found := false + for _, bVal := range b { + if aVal == bVal { + found = true + break + } + } + if !found { + c = append(c, aVal) + } + } + + return c +} + +// SetIntersect returns the elements that are common in a and b +func SetIntersect(a []string, b []string) []string { + c := make([]string, 0, len(a)) + + for _, aVal := range a { + found := false + for _, bVal := range b { + if aVal == bVal { + found = true + break + } + } + if found { + c = append(c, aVal) + } + } + + return c +} + +// NonZeroDefaultSet returns the fields included in the +// defaults slice that are non zero values +func NonZeroDefaultSet(defaults []string, obj interface{}) []string { + c := make([]string, 0, len(defaults)) + + val := reflect.ValueOf(obj) + + for _, d := range defaults { + fieldName := strmangle.TitleCase(d) + field := val.FieldByName(fieldName) + if !field.IsValid() { + panic(fmt.Sprintf("Could not find field name %s in type %T", fieldName, obj)) + } + + zero := reflect.Zero(field.Type()) + if !reflect.DeepEqual(zero.Interface(), field.Interface()) { + c = append(c, d) + } + } + + return c +} + +// GenerateParamFlags generates the SQL statement parameter flags +// For example, $1,$2,$3 etc. It will start counting at startAt. +func GenerateParamFlags(colCount int, startAt int) string { + return strmangle.GenerateParamFlags(colCount, startAt) +} + // WherePrimaryKeyIn generates a "in" string for where queries // For example: (col1, col2) IN (($1, $2), ($3, $4)) func WherePrimaryKeyIn(numRows int, keyNames ...string) string { @@ -81,15 +152,13 @@ func SelectNames(results interface{}) string { // WhereClause returns the where clause for an sql statement // eg: col1=$1 AND col2=$2 AND col3=$3 -func WhereClause(columns map[string]interface{}) string { +func WhereClause(columns []string) string { names := make([]string, 0, len(columns)) - for c := range columns { + for _, c := range columns { names = append(names, c) } - sort.Strings(names) - for i, c := range names { names[i] = fmt.Sprintf("%s=$%d", c, i+1) } @@ -115,24 +184,6 @@ func Update(columns map[string]interface{}) string { return strings.Join(names, ",") } -// WhereParams returns a list of sql parameter values for the query -func WhereParams(columns map[string]interface{}) []interface{} { - names := make([]string, 0, len(columns)) - results := make([]interface{}, 0, len(columns)) - - for c := range columns { - names = append(names, c) - } - - sort.Strings(names) - - for _, c := range names { - results = append(results, columns[c]) - } - - return results -} - // SetParamNames takes a slice of columns and returns a comma seperated // list of parameter names for a template statement SET clause. // eg: col1=$1,col2=$2,col3=$3 diff --git a/boil/helpers_test.go b/boil/helpers_test.go index 988f56b..23acb01 100644 --- a/boil/helpers_test.go +++ b/boil/helpers_test.go @@ -1,8 +1,11 @@ package boil import ( + "reflect" "testing" "time" + + "github.com/guregu/null" ) type testObj struct { @@ -11,6 +14,134 @@ type testObj struct { HeadSize int } +func TestSetComplement(t *testing.T) { + t.Parallel() + + tests := []struct { + A []string + B []string + C []string + }{ + { + []string{"thing1", "thing2", "thing3"}, + []string{"thing2", "otherthing", "stuff"}, + []string{"thing1", "thing3"}, + }, + { + []string{}, + []string{"thing1", "thing2"}, + []string{}, + }, + { + []string{"thing1", "thing2"}, + []string{}, + []string{"thing1", "thing2"}, + }, + { + []string{"thing1", "thing2"}, + []string{"thing1", "thing2"}, + []string{}, + }, + } + + for i, test := range tests { + c := SetComplement(test.A, test.B) + if !reflect.DeepEqual(test.C, c) { + t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, c) + } + } +} + +func TestSetIntersect(t *testing.T) { + t.Parallel() + + tests := []struct { + A []string + B []string + C []string + }{ + { + []string{"thing1", "thing2", "thing3"}, + []string{"thing2", "otherthing", "stuff"}, + []string{"thing2"}, + }, + { + []string{}, + []string{"thing1", "thing2"}, + []string{}, + }, + { + []string{"thing1", "thing2"}, + []string{}, + []string{}, + }, + { + []string{"thing1", "thing2"}, + []string{"thing1", "thing2"}, + []string{"thing1", "thing2"}, + }, + } + + for i, test := range tests { + c := SetIntersect(test.A, test.B) + if !reflect.DeepEqual(test.C, c) { + t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, c) + } + } +} + +func TestNonZeroDefaultSet(t *testing.T) { + t.Parallel() + + type Anything struct { + ID int + Name string + CreatedAt *time.Time + UpdatedAt null.Time + } + + now := time.Now() + + tests := []struct { + Defaults []string + Obj interface{} + Ret []string + }{ + { + []string{"id"}, + Anything{Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}}, + []string{}, + }, + { + []string{"id"}, + Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}}, + []string{"id"}, + }, + { + []string{}, + Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}}, + []string{}, + }, + { + []string{"id", "created_at", "updated_at"}, + Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}}, + []string{"id"}, + }, + { + []string{"id", "created_at", "updated_at"}, + Anything{ID: 5, Name: "hi", CreatedAt: &now, UpdatedAt: null.Time{Valid: true, Time: time.Now()}}, + []string{"id", "created_at", "updated_at"}, + }, + } + + for i, test := range tests { + z := NonZeroDefaultSet(test.Defaults, test.Obj) + if !reflect.DeepEqual(test.Ret, z) { + t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.Ret, z) + } + } +} + func TestWherePrimaryKeyIn(t *testing.T) { t.Parallel() @@ -105,34 +236,15 @@ func TestSelectNames(t *testing.T) { func TestWhereClause(t *testing.T) { t.Parallel() - columns := map[string]interface{}{ - "name": "bob", - "id": 5, - "date": time.Now(), + columns := []string{ + "id", + "name", + "date", } result := WhereClause(columns) - if result != `date=$1 AND id=$2 AND name=$3` { + if result != `id=$1 AND name=$2 AND date=$3` { t.Error("Result was wrong, got:", result) } } - -func TestWhereParams(t *testing.T) { - t.Parallel() - - columns := map[string]interface{}{ - "name": "bob", - "id": 5, - } - - result := WhereParams(columns) - - if result[0].(int) != 5 { - t.Error("Result[0] was wrong, got:", result[0]) - } - - if result[1].(string) != "bob" { - t.Error("Result[1] was wrong, got:", result[1]) - } -} diff --git a/cmds/config.go b/cmds/config.go index 73eb54e..d4360ec 100644 --- a/cmds/config.go +++ b/cmds/config.go @@ -37,6 +37,7 @@ var sqlBoilerImports = imports{ standard: importList{ `"errors"`, `"fmt"`, + `"strings"`, }, thirdparty: importList{ `"github.com/pobri19/sqlboiler/boil"`, @@ -84,30 +85,33 @@ var sqlBoilerTestMainImports = map[string]imports{ // sqlBoilerTemplateFuncs is a map of all the functions that get passed into the templates. // If you wish to pass a new function into your own template, add a pointer to it here. var sqlBoilerTemplateFuncs = template.FuncMap{ - "singular": strmangle.Singular, - "plural": strmangle.Plural, - "titleCase": strmangle.TitleCase, - "titleCaseSingular": strmangle.TitleCaseSingular, - "titleCasePlural": strmangle.TitleCasePlural, - "camelCase": strmangle.CamelCase, - "camelCaseSingular": strmangle.CamelCaseSingular, - "camelCasePlural": strmangle.CamelCasePlural, - "camelCaseCommaList": strmangle.CamelCaseCommaList, - "commaList": strmangle.CommaList, - "makeDBName": strmangle.MakeDBName, - "selectParamNames": strmangle.SelectParamNames, - "insertParamNames": strmangle.InsertParamNames, - "insertParamFlags": strmangle.InsertParamFlags, - "insertParamVariables": strmangle.InsertParamVariables, - "scanParamNames": strmangle.ScanParamNames, - "hasPrimaryKey": strmangle.HasPrimaryKey, - "primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig, - "wherePrimaryKey": strmangle.WherePrimaryKey, - "paramsPrimaryKey": strmangle.ParamsPrimaryKey, - "primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex, - "updateParamNames": strmangle.UpdateParamNames, - "updateParamVariables": strmangle.UpdateParamVariables, - "primaryKeyStrList": strmangle.PrimaryKeyStrList, + "singular": strmangle.Singular, + "plural": strmangle.Plural, + "titleCase": strmangle.TitleCase, + "titleCaseSingular": strmangle.TitleCaseSingular, + "titleCasePlural": strmangle.TitleCasePlural, + "camelCase": strmangle.CamelCase, + "camelCaseSingular": strmangle.CamelCaseSingular, + "camelCasePlural": strmangle.CamelCasePlural, + "camelCaseCommaList": strmangle.CamelCaseCommaList, + "commaList": strmangle.CommaList, + "makeDBName": strmangle.MakeDBName, + "selectParamNames": strmangle.SelectParamNames, + "insertParamNames": strmangle.InsertParamNames, + "insertParamFlags": strmangle.InsertParamFlags, + "insertParamVariables": strmangle.InsertParamVariables, + "scanParamNames": strmangle.ScanParamNames, + "hasPrimaryKey": strmangle.HasPrimaryKey, + "primaryKeyFuncSig": strmangle.PrimaryKeyFuncSig, + "wherePrimaryKey": strmangle.WherePrimaryKey, + "paramsPrimaryKey": strmangle.ParamsPrimaryKey, + "primaryKeyFlagIndex": strmangle.PrimaryKeyFlagIndex, + "updateParamNames": strmangle.UpdateParamNames, + "updateParamVariables": strmangle.UpdateParamVariables, + "primaryKeyStrList": strmangle.PrimaryKeyStrList, + "supportsResultObject": strmangle.SupportsResultObject, + "filterColumnsByDefault": strmangle.FilterColumnsByDefault, + "autoIncPrimaryKey": strmangle.AutoIncPrimaryKey, } // LoadConfigFile loads the toml config file into the cfg object diff --git a/cmds/sqlboiler.go b/cmds/sqlboiler.go index 38345e7..a95e2ff 100644 --- a/cmds/sqlboiler.go +++ b/cmds/sqlboiler.go @@ -229,6 +229,26 @@ func initTables(tableName string, cmdData *CmdData) error { return errors.New("No tables found in database, migrate some tables first") } + if err := checkPKeys(cmdData.Tables); err != nil { + return err + } + + return nil +} + +// checkPKeys ensures every table has a primary key column +func checkPKeys(tables []dbdrivers.Table) error { + var missingPkey []string + for _, t := range tables { + if t.PKey == nil { + missingPkey = append(missingPkey, t.Name) + } + } + + if len(missingPkey) != 0 { + return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", ")) + } + return nil } diff --git a/cmds/sqlboiler_test.go b/cmds/sqlboiler_test.go index 5672466..df17394 100644 --- a/cmds/sqlboiler_test.go +++ b/cmds/sqlboiler_test.go @@ -40,8 +40,23 @@ func init() { { Name: "spiderman", Columns: []dbdrivers.Column{ + {Name: "id", Type: "int64", IsNullable: false}, + }, + PKey: &dbdrivers.PrimaryKey{ + Name: "pkey_id", + Columns: []string{"id"}, + }, + }, + { + Name: "spiderman_table_two", + Columns: []dbdrivers.Column{ + {Name: "id", Type: "int64", IsNullable: false}, {Name: "patrick", Type: "string", IsNullable: false}, }, + PKey: &dbdrivers.PrimaryKey{ + Name: "pkey_id", + Columns: []string{"id"}, + }, }, }, PkgName: "patrick", @@ -69,6 +84,10 @@ func TestTemplates(t *testing.T) { t.SkipNow() } + if err := checkPKeys(cmdData.Tables); err != nil { + t.Fatalf("%s", err) + } + // Initialize the templates var err error cmdData.Templates, err = loadTemplates("templates") diff --git a/cmds/templates/insert.tpl b/cmds/templates/insert.tpl index 030848e..fcfb165 100644 --- a/cmds/templates/insert.tpl +++ b/cmds/templates/insert.tpl @@ -1,47 +1,86 @@ +{{- if hasPrimaryKey .Table.PKey -}} {{- $tableNameSingular := titleCaseSingular .Table.Name -}} +{{- $varNameSingular := camelCaseSingular .Table.Name -}} // {{$tableNameSingular}}Insert inserts a single record. func (o *{{$tableNameSingular}}) Insert(whitelist ... string) error { - if o == nil { - return 0, errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") - } - - if err := o.doBeforeCreateHooks(); err != nil { - return 0, err - } - - var rowID int - err := boil.GetDB().QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID) - - if err != nil { - return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table.Name}}: %s", err) - } - - if err := o.doAfterCreateHooks(); err != nil { - return 0, err - } - - return rowID, nil + return o.InsertX(boil.GetDB(), whitelist...) } +var {{$varNameSingular}}DefaultInsertWhitelist = []string{{"{"}}{{filterColumnsByDefault .Table.Columns false}}{{"}"}} +var {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{filterColumnsByDefault .Table.Columns true}}{{"}"}} +var {{$varNameSingular}}AutoIncPrimaryKey = "{{autoIncPrimaryKey .Table.Columns .Table.PKey}}" + func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... string) error { if o == nil { - return 0, errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") + return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") } + if len(whitelist) == 0 { + whitelist = {{$varNameSingular}}DefaultInsertWhitelist + } + + nzDefaultSet := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o) + if len(nzDefaultSet) != 0 { + whitelist = append(nzDefaultSet, whitelist...) + } + + // Only return the columns with default values that are not in the insert whitelist + returnColumns := boil.SetComplement({{$varNameSingular}}ColumnsWithDefault, whitelist) + + var err error if err := o.doBeforeCreateHooks(); err != nil { - return 0, err + return err } - var rowID int - err := boil.GetDB().QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID) + ins := fmt.Sprintf(`INSERT INTO {{.Table.Name}} (%s) VALUES (%s)`, strings.Join(whitelist, ","), boil.GenerateParamFlags(len(whitelist), 1)) + + {{if supportsResultObject .DriverName}} + if len(returnColumns) != 0 { + result, err := exec.Exec(ins, boil.GetStructValues(o, whitelist...)) + if err != nil { + return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err) + } + + lastId, err := result.lastInsertId() + if err != nil || lastId == 0 { + sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, ","), boil.WhereClause(whitelist)) + rows, err := exec.Query(sel, boil.GetStructValues(o, whitelist...)) + if err != nil { + return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err) + } + defer rows.Close() + + i := 0 + ptrs := boil.GetStructPointers(o, returnColumns...) + for rows.Next() { + if err := rows.Scan(ptrs[i]); err != nil { + return fmt.Errorf("{{.PkgName}}: unable to get result of insert, scan failed for column %s index %d: %s\n\n%#v", returnColumns[i], i, err, ptrs) + } + i++ + } + } else if {{$varNameSingular}}AutoIncPrimKey != "" { + sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s=$1`, strings.Join(returnColumns, ","), {{$varNameSingular}}AutoIncPrimaryKey, lastId) + } + } else { + _, err = exec.Exec(ins, boil.GetStructValues(o, whitelist...)) + } + {{else}} + if len(returnColumns) != 0 { + ins = ins + fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ",")) + err = exec.QueryRow(ins, boil.GetStructValues(o, whitelist...)).Scan(boil.GetStructPointers(o, returnColumns...)) + } else { + _, err = exec.Exec(ins, {{insertParamVariables "o." .Table.Columns}}) + } + {{end}} if err != nil { - return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table.Name}}: %s", err) + return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err) } if err := o.doAfterCreateHooks(); err != nil { - return 0, err + return err } - return rowID, nil + return nil } +{{- end -}} diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 73a2d81..74731a9 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -2,12 +2,15 @@ package strmangle import ( "fmt" + "regexp" "strings" "github.com/jinzhu/inflection" "github.com/pobri19/sqlboiler/dbdrivers" ) +var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`) + // Plural converts singular words to plural words (eg: person to people) func Plural(name string) string { splits := strings.Split(name, "_") @@ -247,6 +250,18 @@ func PrimaryKeyFuncSig(cols []dbdrivers.Column, pkeyCols []string) string { return strings.Join(output, ", ") } +// GenerateParamFlags generates the SQL statement parameter flags +// For example, $1,$2,$3 etc. It will start counting at startAt. +func GenerateParamFlags(colCount int, startAt int) string { + cols := make([]string, 0, colCount) + + for i := startAt; i < colCount+startAt; i++ { + cols = append(cols, fmt.Sprintf("$%d", i)) + } + + return strings.Join(cols, ",") +} + // WherePrimaryKey returns the where clause using start as the $ flag index // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" func WherePrimaryKey(pkeyCols []string, start int) string { @@ -280,6 +295,26 @@ func PrimaryKeyStrList(pkeyCols []string) string { return strings.Join(cols, ", ") } +// AutoIncPrimKey returns the auto-increment primary key column name or an empty string +func AutoIncPrimaryKey(cols []dbdrivers.Column, pkey *dbdrivers.PrimaryKey) string { + if pkey == nil { + return "" + } + + for _, c := range cols { + if rgxAutoIncColumn.MatchString(c.Default) && + c.IsNullable == false && c.Type == "int64" { + for _, p := range pkey.Columns { + if c.Name == p { + return p + } + } + } + } + + return "" +} + // CommaList returns a comma seperated list: "col1, col2, col3" func CommaList(cols []string) string { return strings.Join(cols, ", ") @@ -307,3 +342,32 @@ func ParamsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) str func PrimaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int { return len(regularCols) - len(pkeyCols) + 1 } + +// SupportsResult returns whether the database driver supports the sql.Results +// interface, i.e. LastReturnId and RowsAffected +func SupportsResultObject(driverName string) bool { + switch driverName { + case "postgres": + return false + default: + return true + } +} + +// FilterColumnsByDefault generates the list of columns that have default values +func FilterColumnsByDefault(columns []dbdrivers.Column, defaults bool) string { + var cols []string + + for _, c := range columns { + if (defaults && len(c.Default) != 0) || (!defaults && len(c.Default) == 0) { + cols = append(cols, fmt.Sprintf(`"%s"`, c.Name)) + } + } + + return strings.Join(cols, `,`) +} + +// DEFAULT WHITELIST: The things that are not default values. The things we want to insert all the time. +// WHITELIST: The things that we will NEVER return. The things that we will ALWAYS insert. +// DEFAULTS: The things that we will return (if not in WHITELIST) +// NON-ZEROS: The things that we will return (if not in WHITELIST) diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index 1c53ed2..c122a4b 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -11,6 +11,73 @@ var testColumns = []dbdrivers.Column{ {Name: "enemy_column_thing", Type: "string", IsNullable: true}, } +func TestAutoIncPrimaryKey(t *testing.T) { + t.Parallel() + + var pkey *dbdrivers.PrimaryKey + var cols []dbdrivers.Column + + r := AutoIncPrimaryKey(cols, pkey) + if r != "" { + t.Errorf("Expected empty string, got %s", r) + } + + pkey = &dbdrivers.PrimaryKey{ + Columns: []string{ + "col1", "auto", + }, + Name: "", + } + + cols = []dbdrivers.Column{ + { + Name: "thing", + IsNullable: true, + Type: "int64", + Default: "nextval('abc'::regclass)", + }, + { + Name: "stuff", + IsNullable: false, + Type: "string", + Default: "nextval('abc'::regclass)", + }, + { + Name: "other", + IsNullable: false, + Type: "int64", + Default: "nextval", + }, + } + + r = AutoIncPrimaryKey(cols, pkey) + if r != "" { + t.Errorf("Expected empty string, got %s", r) + } + + cols = append(cols, dbdrivers.Column{ + Name: "auto", + IsNullable: false, + Type: "int64", + Default: "nextval('abc'::regclass)", + }) + + r = AutoIncPrimaryKey(cols, pkey) + if r != "auto" { + t.Errorf("Expected empty string, got %s", r) + } +} + +func TestGenerateParamFlags(t *testing.T) { + t.Parallel() + + x := GenerateParamFlags(5, 1) + want := "$1,$2,$3,$4,$5" + if want != x { + t.Errorf("want %s, got %s", want, x) + } +} + func TestSingular(t *testing.T) { t.Parallel() @@ -263,3 +330,41 @@ func TestWherePrimaryKey(t *testing.T) { } } } + +func TestFilterColumnsByDefault(t *testing.T) { + t.Parallel() + + cols := []dbdrivers.Column{ + { + Name: "col1", + Default: "", + }, + { + Name: "col2", + Default: "things", + }, + { + Name: "col3", + Default: "", + }, + { + Name: "col4", + Default: "things2", + }, + } + + res := FilterColumnsByDefault(cols, false) + if res != `"col1","col3"` { + t.Errorf("Invalid result: %s", res) + } + + res = FilterColumnsByDefault(cols, true) + if res != `"col2","col4"` { + t.Errorf("Invalid result: %s", res) + } + + res = FilterColumnsByDefault([]dbdrivers.Column{}, false) + if res != `` { + t.Errorf("Invalid result: %s", res) + } +}