Properly abstract LastInsertID

This commit is contained in:
Aaron L 2016-08-13 16:27:34 -07:00
parent ef03225024
commit 944303f2f5
9 changed files with 93 additions and 94 deletions

View file

@ -74,6 +74,11 @@ func (p *PostgresDriver) Close() {
p.dbConn.Close()
}
// UseLastInsertID returns false for postgres
func (p *PostgresDriver) UseLastInsertID() bool {
return false
}
// TableNames connects to the postgres database and
// retrieves all table names from the information_schema where the
// table schema is public. It excludes common migration tool tables

View file

@ -14,9 +14,12 @@ type Interface interface {
// TranslateColumnType takes a Database column type and returns a go column type.
TranslateColumnType(Column) Column
// UseLastInsertID should return true if the driver is capable of using
// the sql.Exec result's LastInsertId
UseLastInsertID() bool
// Open the database connection
Open() error
// Close the database connection
Close()
}
@ -109,14 +112,3 @@ func setForeignKeyConstraints(t *Table, tables []Table) {
func setRelationships(t *Table, tables []Table) {
t.ToManyRelationships = toManyRelationships(*t, tables)
}
// DriverUsesLastInsertID returns whether the database driver
// supports the sql.Result interface.
func DriverUsesLastInsertID(driverName string) bool {
switch driverName {
case "postgres":
return false
default:
return true
}
}

View file

@ -20,6 +20,10 @@ func (t testInterface) Columns(tableName string) ([]Column, error) {
return testCols, nil
}
func (t testInterface) UseLastInsertID() bool {
return false
}
var testPkey = &PrimaryKey{Name: "pkey1", Columns: []string{"col1", "col2"}}
func (t testInterface) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
@ -223,14 +227,3 @@ func TestSetRelationships(t *testing.T) {
t.Error("should not be a join table")
}
}
func TestDriverUsesLastInsertID(t *testing.T) {
t.Parallel()
if DriverUsesLastInsertID("postgres") {
t.Error("postgres does not support LastInsertId")
}
if !DriverUsesLastInsertID("mysql") {
t.Error("postgres does support LastInsertId")
}
}

64
fakedb_test.go Normal file
View file

@ -0,0 +1,64 @@
package main
import (
"github.com/vattle/sqlboiler/bdb"
"github.com/vattle/sqlboiler/strmangle"
)
type fakeDB int
func (fakeDB) TableNames() ([]string, error) {
return []string{"users", "videos", "contests", "notifications", "users_videos_tags"}, nil
}
func (fakeDB) Columns(tableName string) ([]bdb.Column, error) {
return map[string][]bdb.Column{
"users": {{Name: "id", Type: "int32"}},
"contests": {{Name: "id", Type: "int32", Nullable: true}},
"videos": {
{Name: "id", Type: "int32"},
{Name: "user_id", Type: "int32", Nullable: true, Unique: true},
{Name: "contest_id", Type: "int32"},
},
"notifications": {
{Name: "user_id", Type: "int32"},
{Name: "source_id", Type: "int32", Nullable: true},
},
"users_videos_tags": {
{Name: "user_id", Type: "int32"},
{Name: "video_id", Type: "int32"},
},
}[tableName], nil
}
func (fakeDB) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) {
return map[string][]bdb.ForeignKey{
"videos": {
{Name: "videos_user_id_fk", Column: "user_id", ForeignTable: "users", ForeignColumn: "id"},
{Name: "videos_contest_id_fk", Column: "contest_id", ForeignTable: "contests", ForeignColumn: "id"},
},
"notifications": {
{Name: "notifications_user_id_fk", Column: "user_id", ForeignTable: "users", ForeignColumn: "id"},
{Name: "notifications_source_id_fk", Column: "source_id", ForeignTable: "users", ForeignColumn: "id"},
},
"users_videos_tags": {
{Name: "user_id_fk", Column: "user_id", ForeignTable: "users", ForeignColumn: "id"},
{Name: "video_id_fk", Column: "video_id", ForeignTable: "videos", ForeignColumn: "id"},
},
}[tableName], nil
}
func (fakeDB) TranslateColumnType(c bdb.Column) bdb.Column {
if c.Nullable {
c.Type = "null." + strmangle.TitleCase(c.Type)
}
return c
}
func (fakeDB) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) {
return map[string]*bdb.PrimaryKey{
"users_videos_tags": {
Name: "user_video_id_pkey",
Columns: []string{"user_id", "video_id"},
},
}[tableName], nil
}
func (fakeDB) UseLastInsertID() bool { return false }
func (fakeDB) Open() error { return nil }
func (fakeDB) Close() {}

View file

@ -75,9 +75,10 @@ func New(config *Config) (*State, error) {
// state given.
func (s *State) Run(includeTests bool) error {
singletonData := &templateData{
Tables: s.Tables,
DriverName: s.Config.DriverName,
PkgName: s.Config.PkgName,
Tables: s.Tables,
DriverName: s.Config.DriverName,
UseLastInsertID: s.Driver.UseLastInsertID(),
PkgName: s.Config.PkgName,
StringFuncs: templateStringMappers,
}
@ -102,10 +103,11 @@ func (s *State) Run(includeTests bool) error {
}
data := &templateData{
Tables: s.Tables,
Table: table,
DriverName: s.Config.DriverName,
PkgName: s.Config.PkgName,
Tables: s.Tables,
Table: table,
DriverName: s.Config.DriverName,
UseLastInsertID: s.Driver.UseLastInsertID(),
PkgName: s.Config.PkgName,
StringFuncs: templateStringMappers,
}

View file

@ -20,6 +20,7 @@ var rgxHasSpaces = regexp.MustCompile(`^\s+`)
func init() {
state = &State{
Driver: fakeDB(0),
Tables: []bdb.Table{
{
Name: "patrick_table",

View file

@ -14,10 +14,11 @@ import (
// templateData for sqlboiler templates
type templateData struct {
Tables []bdb.Table
Table bdb.Table
DriverName string
PkgName string
Tables []bdb.Table
Table bdb.Table
DriverName string
UseLastInsertID bool
PkgName string
StringFuncs map[string]func(string) string
}
@ -168,7 +169,6 @@ var templateFunctions = template.FuncMap{
"textsFromRelationship": textsFromRelationship,
// dbdrivers ops
"driverUsesLastInsertID": bdb.DriverUsesLastInsertID,
"filterColumnsByDefault": bdb.FilterColumnsByDefault,
"filterColumnsBySimpleDefault": bdb.FilterColumnsBySimpleDefault,
"filterColumnsByAutoIncrement": bdb.FilterColumnsByAutoIncrement,

View file

@ -40,7 +40,7 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string
ins := fmt.Sprintf(`INSERT INTO {{.Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), strmangle.Placeholders(len(wl), 1, 1))
{{if driverUsesLastInsertID .DriverName}}
{{if .UseLastInsertID}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, ins)
fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, wl...))

View file

@ -6,66 +6,8 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/vattle/sqlboiler/bdb"
"github.com/vattle/sqlboiler/strmangle"
)
type fakeDB int
func (fakeDB) TableNames() ([]string, error) {
return []string{"users", "videos", "contests", "notifications", "users_videos_tags"}, nil
}
func (fakeDB) Columns(tableName string) ([]bdb.Column, error) {
return map[string][]bdb.Column{
"users": {{Name: "id", Type: "int32"}},
"contests": {{Name: "id", Type: "int32", Nullable: true}},
"videos": {
{Name: "id", Type: "int32"},
{Name: "user_id", Type: "int32", Nullable: true, Unique: true},
{Name: "contest_id", Type: "int32"},
},
"notifications": {
{Name: "user_id", Type: "int32"},
{Name: "source_id", Type: "int32", Nullable: true},
},
"users_videos_tags": {
{Name: "user_id", Type: "int32"},
{Name: "video_id", Type: "int32"},
},
}[tableName], nil
}
func (fakeDB) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) {
return map[string][]bdb.ForeignKey{
"videos": {
{Name: "videos_user_id_fk", Column: "user_id", ForeignTable: "users", ForeignColumn: "id"},
{Name: "videos_contest_id_fk", Column: "contest_id", ForeignTable: "contests", ForeignColumn: "id"},
},
"notifications": {
{Name: "notifications_user_id_fk", Column: "user_id", ForeignTable: "users", ForeignColumn: "id"},
{Name: "notifications_source_id_fk", Column: "source_id", ForeignTable: "users", ForeignColumn: "id"},
},
"users_videos_tags": {
{Name: "user_id_fk", Column: "user_id", ForeignTable: "users", ForeignColumn: "id"},
{Name: "video_id_fk", Column: "video_id", ForeignTable: "videos", ForeignColumn: "id"},
},
}[tableName], nil
}
func (fakeDB) TranslateColumnType(c bdb.Column) bdb.Column {
if c.Nullable {
c.Type = "null." + strmangle.TitleCase(c.Type)
}
return c
}
func (fakeDB) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) {
return map[string]*bdb.PrimaryKey{
"users_videos_tags": {
Name: "user_video_id_pkey",
Columns: []string{"user_id", "video_id"},
},
}[tableName], nil
}
func (fakeDB) Open() error { return nil }
func (fakeDB) Close() {}
func TestTextsFromForeignKey(t *testing.T) {
t.Parallel()