Refactor dbdrivers into two packages

- Break dbdrivers into bdb and drivers
- Break each type in dbdrivers into it's own file set.
This commit is contained in:
Aaron L 2016-06-22 23:09:56 -07:00
parent 84a160f3a4
commit 60f6080e73
18 changed files with 351 additions and 404 deletions

46
bdb/column.go Normal file
View file

@ -0,0 +1,46 @@
package bdb
// Column holds information about a database column.
// Types are Go types, converted by TranslateColumnType.
type Column struct {
Name string
Type string
Default string
IsNullable bool
}
// ColumnNames of the columns.
func ColumnNames(cols []Column) []string {
names := make([]string, len(cols))
for i, c := range cols {
names[i] = c.Name
}
return names
}
// FilterColumnsByDefault generates the list of columns that have default values
func FilterColumnsByDefault(columns []Column, defaults bool) []Column {
var cols []Column
for _, c := range columns {
if (defaults && len(c.Default) != 0) || (!defaults && len(c.Default) == 0) {
cols = append(cols, c)
}
}
return cols
}
// FilterColumnsByAutoIncrement generates the list of auto increment columns
func FilterColumnsByAutoIncrement(columns []Column) []Column {
var cols []Column
for _, c := range columns {
if rgxAutoIncColumn.MatchString(c.Default) {
cols = append(cols, c)
}
}
return cols
}

77
bdb/column_test.go Normal file
View file

@ -0,0 +1,77 @@
package bdb
import (
"strings"
"testing"
)
func TestColumnNames(t *testing.T) {
t.Parallel()
cols := []Column{
Column{Name: "one"},
Column{Name: "two"},
Column{Name: "three"},
}
out := strings.Join(ColumnNames(cols), " ")
if out != "one two three" {
t.Error("output was wrong:", out)
}
}
func TestFilterColumnsByDefault(t *testing.T) {
t.Parallel()
cols := []Column{
{Name: "col1", Default: ""},
{Name: "col2", Default: "things"},
{Name: "col3", Default: ""},
{Name: "col4", Default: "things2"},
}
res := FilterColumnsByDefault(cols, false)
if res[0].Name != `col1` {
t.Errorf("Invalid result: %#v", res)
}
if res[1].Name != `col3` {
t.Errorf("Invalid result: %#v", res)
}
res = FilterColumnsByDefault(cols, true)
if res[0].Name != `col2` {
t.Errorf("Invalid result: %#v", res)
}
if res[1].Name != `col4` {
t.Errorf("Invalid result: %#v", res)
}
res = FilterColumnsByDefault([]Column{}, false)
if res != nil {
t.Errorf("Invalid result: %#v", res)
}
}
func TestFilterColumnsByAutoIncrement(t *testing.T) {
t.Parallel()
cols := []Column{
{Name: "col1", Default: `nextval("thing"::thing)`},
{Name: "col2", Default: "things"},
{Name: "col3", Default: ""},
{Name: "col4", Default: `nextval("thing"::thing)`},
}
res := FilterColumnsByAutoIncrement(cols)
if res[0].Name != `col1` {
t.Errorf("Invalid result: %#v", res)
}
if res[1].Name != `col4` {
t.Errorf("Invalid result: %#v", res)
}
res = FilterColumnsByAutoIncrement([]Column{})
if res != nil {
t.Errorf("Invalid result: %#v", res)
}
}

View file

@ -1,4 +1,4 @@
package dbdrivers package drivers
import ( import (
"database/sql" "database/sql"

View file

@ -1,4 +1,5 @@
package dbdrivers // Package bdb supplies the sql(b)oiler (d)ata(b)ase abstractions.
package bdb
import "github.com/pkg/errors" import "github.com/pkg/errors"
@ -20,41 +21,6 @@ type Interface interface {
Close() Close()
} }
// Table metadata from the database schema.
type Table struct {
Name string
Columns []Column
PKey *PrimaryKey
FKeys []ForeignKey
IsJoinTable bool
}
// Column holds information about a database column.
// Types are Go types, converted by TranslateColumnType.
type Column struct {
Name string
Type string
Default string
IsNullable bool
}
// PrimaryKey represents a primary key constraint in a database
type PrimaryKey struct {
Name string
Columns []string
}
// ForeignKey represents a foreign key constraint in a database
type ForeignKey struct {
Name string
Column string
ForeignTable string
ForeignColumn string
}
// Tables returns the table metadata for the given tables, or all tables if // Tables returns the table metadata for the given tables, or all tables if
// no tables are provided. // no tables are provided.
func Tables(db Interface, names ...string) ([]Table, error) { func Tables(db Interface, names ...string) ([]Table, error) {
@ -94,7 +60,7 @@ func Tables(db Interface, names ...string) ([]Table, error) {
} }
// setIsJoinTable iff there are: // setIsJoinTable iff there are:
// There is a composite primary key involving two columns // A composite primary key involving two columns
// Both primary key columns are also foreign keys // Both primary key columns are also foreign keys
func setIsJoinTable(t *Table) { func setIsJoinTable(t *Table) {
if t.PKey == nil || len(t.PKey.Columns) != 2 || len(t.FKeys) < 2 { if t.PKey == nil || len(t.PKey.Columns) != 2 || len(t.FKeys) < 2 {

View file

@ -1,4 +1,4 @@
package dbdrivers package bdb
import ( import (
"reflect" "reflect"

77
bdb/keys.go Normal file
View file

@ -0,0 +1,77 @@
package bdb
import (
"fmt"
"regexp"
"strings"
)
var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`)
// PrimaryKey represents a primary key constraint in a database
type PrimaryKey struct {
Name string
Columns []string
}
// ForeignKey represents a foreign key constraint in a database
type ForeignKey struct {
Name string
Column string
ForeignTable string
ForeignColumn string
}
// SQLColumnDef formats a column name and type like an SQL column definition.
type SQLColumnDef struct {
Name string
Type string
}
func (s SQLColumnDef) String() string {
return fmt.Sprintf("%s %s", s.Name, s.Type)
}
// SQLColDefinitions creates a definition in sql format for a column
// example: id int64, thingName string
func SQLColDefinitions(cols []Column, names []string) []SQLColumnDef {
ret := make([]SQLColumnDef, len(names))
for i, n := range names {
for _, c := range cols {
if n != c.Name {
continue
}
ret[i] = SQLColumnDef{Name: n, Type: c.Type}
}
}
return ret
}
// AutoIncPrimaryKey returns the auto-increment primary key column name or an
// empty string.
func AutoIncPrimaryKey(cols []Column, pkey *PrimaryKey) (col Column, ok bool) {
if pkey == nil {
return col, false
}
for _, pkeyColumn := range pkey.Columns {
for _, c := range cols {
if c.Name != pkeyColumn {
continue
}
if !rgxAutoIncColumn.MatchString(c.Default) || c.IsNullable ||
!(strings.HasPrefix(c.Type, "int") || strings.HasPrefix(c.Type, "uint")) {
continue
}
return c, true
}
}
return col, false
}

85
bdb/keys_test.go Normal file
View file

@ -0,0 +1,85 @@
package bdb
import "testing"
func TestSQLColDefinitions(t *testing.T) {
t.Parallel()
cols := []Column{
{Name: "one", Type: "int64"},
{Name: "two", Type: "string"},
{Name: "three", Type: "string"},
}
defs := SQLColDefinitions(cols, []string{"one"})
if len(defs) != 1 {
t.Error("wrong number of defs:", len(defs))
}
if got := defs[0].String(); got != "one int64" {
t.Error("wrong def:", got)
}
defs = SQLColDefinitions(cols, []string{"one", "three"})
if len(defs) != 2 {
t.Error("wrong number of defs:", len(defs))
}
if got := defs[0].String(); got != "one int64" {
t.Error("wrong def:", got)
}
if got := defs[1].String(); got != "three string" {
t.Error("wrong def:", got)
}
}
func TestAutoIncPrimaryKey(t *testing.T) {
t.Parallel()
tests := map[string]struct {
Ok bool
Expect Column
Pkey *PrimaryKey
Columns []Column
}{
"nillcase": {
Ok: false,
Pkey: nil,
Columns: nil,
},
"easycase": {
Ok: true,
Expect: Column{Name: "one", Type: "int32", IsNullable: false, Default: `nextval('abc'::regclass)`},
Pkey: &PrimaryKey{Name: "pkey", Columns: []string{"one"}},
Columns: []Column{Column{Name: "one", Type: "int32", IsNullable: false, Default: `nextval('abc'::regclass)`}},
},
"missingcase": {
Ok: false,
Pkey: &PrimaryKey{Name: "pkey", Columns: []string{"two"}},
Columns: []Column{Column{Name: "one", Type: "int32", IsNullable: false, Default: `nextval('abc'::regclass)`}},
},
"wrongtype": {
Ok: false,
Pkey: &PrimaryKey{Name: "pkey", Columns: []string{"one"}},
Columns: []Column{Column{Name: "one", Type: "string", IsNullable: false, Default: `nextval('abc'::regclass)`}},
},
"nodefault": {
Ok: false,
Pkey: &PrimaryKey{Name: "pkey", Columns: []string{"one"}},
Columns: []Column{Column{Name: "one", Type: "string", IsNullable: false, Default: ``}},
},
"nullable": {
Ok: false,
Pkey: &PrimaryKey{Name: "pkey", Columns: []string{"one"}},
Columns: []Column{Column{Name: "one", Type: "string", IsNullable: true, Default: `nextval('abc'::regclass)`}},
},
}
for testName, test := range tests {
pkey, ok := AutoIncPrimaryKey(test.Columns, test.Pkey)
if ok != test.Ok {
t.Errorf("%s) found state was wrong, want: %t, got: %t", testName, test.Ok, ok)
}
if pkey != test.Expect {
t.Errorf("%s) wrong primary key, want: %#v, got %#v", testName, test.Expect, pkey)
}
}
}

View file

@ -1,10 +1,4 @@
package dbdrivers package bdb
import (
"fmt"
"github.com/nullbio/sqlboiler/strmangle"
)
// ToManyRelationship describes a relationship between two tables where the // ToManyRelationship describes a relationship between two tables where the
// local table has no id, and the foreign table has an id that matches a column // local table has no id, and the foreign table has an id that matches a column
@ -30,19 +24,19 @@ func ToManyRelationships(table string, tables []Table) []ToManyRelationship {
continue continue
} }
singularName := strmangle.Singular(table) // singularName := strmangle.Singular(table)
standardColName := fmt.Sprintf("%s_id", singularName) // standardColName := fmt.Sprintf("%s_id", singularName)
relationship := ToManyRelationship{ relationship := ToManyRelationship{
ForeignTable: t.Name, ForeignTable: t.Name,
ForeignColumn: f.Column, ForeignColumn: f.Column,
} }
if standardColName == f.ForeignColumn { // if standardColName == f.ForeignColumn {
relationship.Name = strmangle.TitleCase(strmangle.Plural(name)) // relationship.Name = table
} else { // } else {
relationship.Name = strmangle.TitleCase(strmangle.Plural(name)) // relationship.Name = table
} // }
relationships = append(relationships, relationship) relationships = append(relationships, relationship)
} }

View file

@ -1,4 +1,4 @@
package dbdrivers package bdb
import "testing" import "testing"
@ -9,15 +9,15 @@ func TestToManyRelationships(t *testing.T) {
Table{ Table{
Name: "videos", Name: "videos",
FKeys: []ForeignKey{ FKeys: []ForeignKey{
{Name: "videos_user_id_fk", Column: "user_id", ForeignTable: "users", ForeignKey: "id"}, {Name: "videos_user_id_fk", Column: "user_id", ForeignTable: "users", ForeignColumn: "id"},
{Name: "videos_contest_id_fk", Column: "contest_id", ForeignTable: "contests", ForeignKey: "id"}, {Name: "videos_contest_id_fk", Column: "contest_id", ForeignTable: "contests", ForeignColumn: "id"},
}, },
}, },
Table{ Table{
Name: "notifications", Name: "notifications",
FKeys: []ForeignKey{ FKeys: []ForeignKey{
{Name: "notifications_user_id_fk", Column: "user_id", ForeignTable: "users", ForeignKey: "id"}, {Name: "notifications_user_id_fk", Column: "user_id", ForeignTable: "users", ForeignColumn: "id"},
{Name: "notifications_source_id_fk", Column: "source_id", ForeignTable: "users", ForeignKey: "id"}, {Name: "notifications_source_id_fk", Column: "source_id", ForeignTable: "users", ForeignColumn: "id"},
}, },
}, },
} }

12
bdb/table.go Normal file
View file

@ -0,0 +1,12 @@
package bdb
// Table metadata from the database schema.
type Table struct {
Name string
Columns []Column
PKey *PrimaryKey
FKeys []ForeignKey
IsJoinTable bool
}

View file

@ -1,7 +0,0 @@
package dbdrivers
// isJoinTable is true if table has at least 2 foreign keys and
// the two foreign keys are involved in a primary composite key
func isJoinTable(t Table) bool {
return false
}

View file

@ -6,7 +6,7 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/nullbio/sqlboiler/dbdrivers" "github.com/nullbio/sqlboiler/bdb"
) )
// imports defines the optional standard imports and // imports defines the optional standard imports and
@ -48,7 +48,7 @@ func combineImports(a, b imports) imports {
return c return c
} }
func combineTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports { func combineTypeImports(a imports, b map[string]imports, columns []bdb.Column) imports {
tmpImp := imports{ tmpImp := imports{
standard: make(importList, len(a.standard)), standard: make(importList, len(a.standard)),
thirdParty: make(importList, len(a.thirdParty)), thirdParty: make(importList, len(a.thirdParty)),

View file

@ -6,7 +6,7 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/nullbio/sqlboiler/dbdrivers" "github.com/nullbio/sqlboiler/bdb"
) )
func TestImportsSort(t *testing.T) { func TestImportsSort(t *testing.T) {
@ -77,17 +77,17 @@ func TestCombineTypeImports(t *testing.T) {
}, },
} }
cols := []dbdrivers.Column{ cols := []bdb.Column{
dbdrivers.Column{ bdb.Column{
Type: "null.Time", Type: "null.Time",
}, },
dbdrivers.Column{ bdb.Column{
Type: "null.Time", Type: "null.Time",
}, },
dbdrivers.Column{ bdb.Column{
Type: "time.Time", Type: "time.Time",
}, },
dbdrivers.Column{ bdb.Column{
Type: "null.Float", Type: "null.Float",
}, },
} }

View file

@ -8,7 +8,7 @@ import (
"strings" "strings"
"text/template" "text/template"
"github.com/nullbio/sqlboiler/dbdrivers" "github.com/nullbio/sqlboiler/bdb"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -26,8 +26,8 @@ const (
type State struct { type State struct {
Config *Config Config *Config
Driver dbdrivers.Interface Driver bdb.Interface
Tables []dbdrivers.Table Tables []bdb.Table
Templates templateList Templates templateList
TestTemplates templateList TestTemplates templateList
@ -170,7 +170,7 @@ func (s *State) initDriver(driverName string) error {
// Create a driver based off driver flag // Create a driver based off driver flag
switch driverName { switch driverName {
case "postgres": case "postgres":
s.Driver = dbdrivers.NewPostgresDriver( s.Driver = bdb.NewPostgresDriver(
s.Config.Postgres.User, s.Config.Postgres.User,
s.Config.Postgres.Pass, s.Config.Postgres.Pass,
s.Config.Postgres.DBName, s.Config.Postgres.DBName,
@ -192,7 +192,7 @@ func (s *State) initDriver(driverName string) error {
// result. // result.
func (s *State) initTables(tableNames []string) error { func (s *State) initTables(tableNames []string) error {
var err error var err error
s.Tables, err = dbdrivers.Tables(s.Driver, tableNames...) s.Tables, err = bdb.Tables(s.Driver, tableNames...)
if err != nil { if err != nil {
return errors.Wrap(err, "unable to fetch table data") return errors.Wrap(err, "unable to fetch table data")
} }
@ -214,7 +214,7 @@ func (s *State) initOutFolder() error {
} }
// checkPKeys ensures every table has a primary key column // checkPKeys ensures every table has a primary key column
func checkPKeys(tables []dbdrivers.Table) error { func checkPKeys(tables []bdb.Table) error {
var missingPkey []string var missingPkey []string
for _, t := range tables { for _, t := range tables {
if t.PKey == nil { if t.PKey == nil {

View file

@ -12,7 +12,7 @@ import (
"strconv" "strconv"
"testing" "testing"
"github.com/nullbio/sqlboiler/dbdrivers" "github.com/nullbio/sqlboiler/bdb"
) )
var state *State var state *State
@ -20,10 +20,10 @@ var rgxHasSpaces = regexp.MustCompile(`^\s+`)
func init() { func init() {
state = &State{ state = &State{
Tables: []dbdrivers.Table{ Tables: []bdb.Table{
{ {
Name: "patrick_table", Name: "patrick_table",
Columns: []dbdrivers.Column{ Columns: []bdb.Column{
{Name: "patrick_column", Type: "string", IsNullable: false}, {Name: "patrick_column", Type: "string", IsNullable: false},
{Name: "aaron_column", Type: "null.String", IsNullable: true}, {Name: "aaron_column", Type: "null.String", IsNullable: true},
{Name: "id", Type: "null.Int", IsNullable: true}, {Name: "id", Type: "null.Int", IsNullable: true},
@ -32,28 +32,28 @@ func init() {
{Name: "fun_time", Type: "time.Time", IsNullable: false}, {Name: "fun_time", Type: "time.Time", IsNullable: false},
{Name: "cool_stuff_forever", Type: "[]byte", IsNullable: false}, {Name: "cool_stuff_forever", Type: "[]byte", IsNullable: false},
}, },
PKey: &dbdrivers.PrimaryKey{ PKey: &bdb.PrimaryKey{
Name: "pkey_thing", Name: "pkey_thing",
Columns: []string{"id", "fun_id"}, Columns: []string{"id", "fun_id"},
}, },
}, },
{ {
Name: "spiderman", Name: "spiderman",
Columns: []dbdrivers.Column{ Columns: []bdb.Column{
{Name: "id", Type: "int64", IsNullable: false}, {Name: "id", Type: "int64", IsNullable: false},
}, },
PKey: &dbdrivers.PrimaryKey{ PKey: &bdb.PrimaryKey{
Name: "pkey_id", Name: "pkey_id",
Columns: []string{"id"}, Columns: []string{"id"},
}, },
}, },
{ {
Name: "spiderman_table_two", Name: "spiderman_table_two",
Columns: []dbdrivers.Column{ Columns: []bdb.Column{
{Name: "id", Type: "int64", IsNullable: false}, {Name: "id", Type: "int64", IsNullable: false},
{Name: "patrick", Type: "string", IsNullable: false}, {Name: "patrick", Type: "string", IsNullable: false},
}, },
PKey: &dbdrivers.PrimaryKey{ PKey: &bdb.PrimaryKey{
Name: "pkey_id", Name: "pkey_id",
Columns: []string{"id"}, Columns: []string{"id"},
}, },

View file

@ -1,20 +1,16 @@
// Package strmangle is used exclusively by the templates in sqlboiler. // Package strmangle is used exclusively by the templates in sqlboiler.
// There are many helper functions to deal with dbdrivers.* values as well // There are many helper functions to deal with bdb.* values as well
// as string manipulation. Because it is focused on pipelining inside templates // as string manipulation. Because it is focused on pipelining inside templates
// you will see some odd parameter ordering. // you will see some odd parameter ordering.
package strmangle package strmangle
import ( import (
"fmt" "fmt"
"regexp"
"strings" "strings"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"github.com/nullbio/sqlboiler/dbdrivers"
) )
var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`)
// Plural converts singular words to plural words (eg: person to people) // Plural converts singular words to plural words (eg: person to people)
func Plural(name string) string { func Plural(name string) string {
splits := strings.Split(name, "_") splits := strings.Split(name, "_")
@ -111,24 +107,6 @@ func PrefixStringSlice(str string, strs []string) []string {
return ret return ret
} }
// PrimaryKeyFuncSig generates the function signature parameters.
// example: id int64, thingName string
func PrimaryKeyFuncSig(cols []dbdrivers.Column, pkeyCols []string) string {
ret := make([]string, len(pkeyCols))
for i, pk := range pkeyCols {
for _, c := range cols {
if pk != c.Name {
continue
}
ret[i] = fmt.Sprintf("%s %s", CamelCase(pk), c.Type)
}
}
return strings.Join(ret, ", ")
}
// GenerateParamFlags generates the SQL statement parameter flags // GenerateParamFlags generates the SQL statement parameter flags
// For example, $1,$2,$3 etc. It will start counting at startAt. // For example, $1,$2,$3 etc. It will start counting at startAt.
func GenerateParamFlags(colCount int, startAt int) string { func GenerateParamFlags(colCount int, startAt int) string {
@ -141,11 +119,11 @@ func GenerateParamFlags(colCount int, startAt int) string {
return strings.Join(cols, ",") return strings.Join(cols, ",")
} }
// WherePrimaryKey returns the where clause using start as the $ flag index // WhereClause returns the where clause using start as the $ flag index
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
func WherePrimaryKey(pkeyCols []string, start int) string { func WhereClause(pkeyCols []string, start int) string {
if start == 0 { if start == 0 {
panic("0 is not a valid start number for wherePrimaryKey") panic("0 is not a valid start number for whereClause")
} }
cols := make([]string, len(pkeyCols)) cols := make([]string, len(pkeyCols))
@ -156,41 +134,6 @@ func WherePrimaryKey(pkeyCols []string, start int) string {
return strings.Join(cols, " AND ") return strings.Join(cols, " AND ")
} }
// AutoIncPrimaryKey returns the auto-increment primary key column name or an
// empty string.
func AutoIncPrimaryKey(cols []dbdrivers.Column, pkey *dbdrivers.PrimaryKey) string {
if pkey == nil {
return ""
}
for _, pkeyColumn := range pkey.Columns {
for _, c := range cols {
if c.Name != pkeyColumn {
continue
}
if !rgxAutoIncColumn.MatchString(c.Default) || c.IsNullable ||
!(strings.HasPrefix(c.Type, "int") || strings.HasPrefix(c.Type, "uint")) {
continue
}
return pkeyColumn
}
}
return ""
}
// ColumnNames of the columns.
func ColumnNames(cols []dbdrivers.Column) []string {
names := make([]string, len(cols))
for i, c := range cols {
names[i] = c.Name
}
return names
}
// DriverUsesLastInsertID returns whether the database driver supports the // DriverUsesLastInsertID returns whether the database driver supports the
// sql.Result interface. // sql.Result interface.
func DriverUsesLastInsertID(driverName string) bool { func DriverUsesLastInsertID(driverName string) bool {
@ -202,32 +145,6 @@ func DriverUsesLastInsertID(driverName string) bool {
} }
} }
// FilterColumnsByDefault generates the list of columns that have default values
func FilterColumnsByDefault(columns []dbdrivers.Column, defaults bool) string {
var cols []string
for _, c := range columns {
if (defaults && len(c.Default) != 0) || (!defaults && len(c.Default) == 0) {
cols = append(cols, fmt.Sprintf(`"%s"`, c.Name))
}
}
return strings.Join(cols, `,`)
}
// FilterColumnsByAutoIncrement generates the list of auto increment columns
func FilterColumnsByAutoIncrement(columns []dbdrivers.Column) string {
var cols []string
for _, c := range columns {
if rgxAutoIncColumn.MatchString(c.Default) {
cols = append(cols, fmt.Sprintf(`"%s"`, c.Name))
}
}
return strings.Join(cols, `,`)
}
// Substring returns a substring of str starting at index start and going // Substring returns a substring of str starting at index start and going
// to end-1. // to end-1.
func Substring(start, end int, str string) string { func Substring(start, end int, str string) string {

View file

@ -3,128 +3,8 @@ package strmangle
import ( import (
"strings" "strings"
"testing" "testing"
"github.com/nullbio/sqlboiler/dbdrivers"
) )
var testColumns = []dbdrivers.Column{
{Name: "friend_column", Type: "int", IsNullable: false},
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
}
func TestAutoIncPrimaryKey(t *testing.T) {
t.Parallel()
tests := map[string]struct {
Expect string
Pkey *dbdrivers.PrimaryKey
Columns []dbdrivers.Column
}{
"nillcase": {
Expect: "",
Pkey: nil,
Columns: nil,
},
"easycase": {
Expect: "one",
Pkey: &dbdrivers.PrimaryKey{
Name: "pkey",
Columns: []string{"one"},
},
Columns: []dbdrivers.Column{
dbdrivers.Column{
Name: "one",
Type: "int32",
IsNullable: false,
Default: `nextval('abc'::regclass)`,
},
},
},
"missingcase": {
Expect: "",
Pkey: &dbdrivers.PrimaryKey{
Name: "pkey",
Columns: []string{"two"},
},
Columns: []dbdrivers.Column{
dbdrivers.Column{
Name: "one",
Type: "int32",
IsNullable: false,
Default: `nextval('abc'::regclass)`,
},
},
},
"wrongtype": {
Expect: "",
Pkey: &dbdrivers.PrimaryKey{
Name: "pkey",
Columns: []string{"one"},
},
Columns: []dbdrivers.Column{
dbdrivers.Column{
Name: "one",
Type: "string",
IsNullable: false,
Default: `nextval('abc'::regclass)`,
},
},
},
"nodefault": {
Expect: "",
Pkey: &dbdrivers.PrimaryKey{
Name: "pkey",
Columns: []string{"one"},
},
Columns: []dbdrivers.Column{
dbdrivers.Column{
Name: "one",
Type: "string",
IsNullable: false,
Default: ``,
},
},
},
"nullable": {
Expect: "",
Pkey: &dbdrivers.PrimaryKey{
Name: "pkey",
Columns: []string{"one"},
},
Columns: []dbdrivers.Column{
dbdrivers.Column{
Name: "one",
Type: "string",
IsNullable: true,
Default: `nextval('abc'::regclass)`,
},
},
},
}
for testName, test := range tests {
primaryKey := AutoIncPrimaryKey(test.Columns, test.Pkey)
if primaryKey != test.Expect {
t.Errorf("%s) wrong primary key, want: %q, got %q", testName, test.Expect, primaryKey)
}
}
}
func TestColumnNames(t *testing.T) {
t.Parallel()
cols := []dbdrivers.Column{
dbdrivers.Column{Name: "one"},
dbdrivers.Column{Name: "two"},
dbdrivers.Column{Name: "three"},
}
out := strings.Join(ColumnNames(cols), " ")
if out != "one two three" {
t.Error("output was wrong:", out)
}
}
func TestDriverUsesLastInsertID(t *testing.T) { func TestDriverUsesLastInsertID(t *testing.T) {
t.Parallel() t.Parallel()
@ -260,50 +140,21 @@ func TestPrefixStringSlice(t *testing.T) {
} }
} }
func TestPrimaryKeyFuncSig(t *testing.T) { func TestWhereClause(t *testing.T) {
t.Parallel()
cols := []dbdrivers.Column{
{
Name: "one",
Type: "int64",
},
{
Name: "two",
Type: "string",
},
{
Name: "three",
Type: "string",
},
}
sig := PrimaryKeyFuncSig(cols, []string{"one"})
if sig != "one int64" {
t.Error("wrong signature:", sig)
}
sig = PrimaryKeyFuncSig(cols, []string{"one", "three"})
if sig != "one int64, three string" {
t.Error("wrong signature:", sig)
}
}
func TestWherePrimaryKey(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
Pkey dbdrivers.PrimaryKey Cols []string
Start int Start int
Should string Should string
}{ }{
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1"}}, Start: 2, Should: "col1=$2"}, {Cols: []string{"col1"}, Start: 2, Should: "col1=$2"},
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2"}}, Start: 4, Should: "col1=$4 AND col2=$5"}, {Cols: []string{"col1", "col2"}, Start: 4, Should: "col1=$4 AND col2=$5"},
{Pkey: dbdrivers.PrimaryKey{Columns: []string{"col1", "col2", "col3"}}, Start: 4, Should: "col1=$4 AND col2=$5 AND col3=$6"}, {Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: "col1=$4 AND col2=$5 AND col3=$6"},
} }
for i, test := range tests { for i, test := range tests {
r := WherePrimaryKey(test.Pkey.Columns, test.Start) r := WhereClause(test.Cols, test.Start)
if r != test.Should { if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
} }
@ -319,78 +170,7 @@ func TestWherePrimaryKeyPanic(t *testing.T) {
} }
}() }()
WherePrimaryKey(nil, 0) WhereClause(nil, 0)
}
func TestFilterColumnsByDefault(t *testing.T) {
t.Parallel()
cols := []dbdrivers.Column{
{
Name: "col1",
Default: "",
},
{
Name: "col2",
Default: "things",
},
{
Name: "col3",
Default: "",
},
{
Name: "col4",
Default: "things2",
},
}
res := FilterColumnsByDefault(cols, false)
if res != `"col1","col3"` {
t.Errorf("Invalid result: %s", res)
}
res = FilterColumnsByDefault(cols, true)
if res != `"col2","col4"` {
t.Errorf("Invalid result: %s", res)
}
res = FilterColumnsByDefault([]dbdrivers.Column{}, false)
if res != `` {
t.Errorf("Invalid result: %s", res)
}
}
func TestFilterColumnsByAutoIncrement(t *testing.T) {
t.Parallel()
cols := []dbdrivers.Column{
{
Name: "col1",
Default: `nextval("thing"::thing)`,
},
{
Name: "col2",
Default: "things",
},
{
Name: "col3",
Default: "",
},
{
Name: "col4",
Default: `nextval("thing"::thing)`,
},
}
res := FilterColumnsByAutoIncrement(cols)
if res != `"col1","col4"` {
t.Errorf("Invalid result: %s", res)
}
res = FilterColumnsByAutoIncrement([]dbdrivers.Column{})
if res != `` {
t.Errorf("Invalid result: %s", res)
}
} }
func TestSubstring(t *testing.T) { func TestSubstring(t *testing.T) {

View file

@ -8,14 +8,14 @@ import (
"strings" "strings"
"text/template" "text/template"
"github.com/nullbio/sqlboiler/dbdrivers" "github.com/nullbio/sqlboiler/bdb"
"github.com/nullbio/sqlboiler/strmangle" "github.com/nullbio/sqlboiler/strmangle"
) )
// templateData for sqlboiler templates // templateData for sqlboiler templates
type templateData struct { type templateData struct {
Tables []dbdrivers.Table Tables []bdb.Table
Table dbdrivers.Table Table bdb.Table
DriverName string DriverName string
PkgName string PkgName string