Refactored interface & added tests
This commit is contained in:
parent
a7263bde40
commit
0ffccde168
3 changed files with 117 additions and 70 deletions
|
@ -1,14 +1,16 @@
|
|||
package dbdrivers
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Interface for a database driver. Functionality required to support a specific
|
||||
// database type (eg, MySQL, Postgres etc.)
|
||||
type Interface interface {
|
||||
// Tables connects to the database and retrieves the table metadata for
|
||||
// the given tables, or all tables if none are provided.
|
||||
Tables(names ...string) ([]Table, error)
|
||||
TableNames() ([]string, error)
|
||||
Columns(tableName string) ([]Column, error)
|
||||
PrimaryKeyInfo(tableName string) (*PrimaryKey, error)
|
||||
ForeignKeyInfo(tableName string) ([]ForeignKey, error)
|
||||
|
||||
// TranslateColumnType takes a Database column type and returns a go column
|
||||
// type.
|
||||
// TranslateColumnType takes a Database column type and returns a go column type.
|
||||
TranslateColumnType(Column) Column
|
||||
|
||||
// Open the database connection
|
||||
|
@ -32,10 +34,9 @@ type Table struct {
|
|||
// Column holds information about a database column.
|
||||
// Types are Go types, converted by TranslateColumnType.
|
||||
type Column struct {
|
||||
Name string
|
||||
Type string
|
||||
IsPrimaryKey bool
|
||||
IsNullable bool
|
||||
Name string
|
||||
Type string
|
||||
IsNullable bool
|
||||
}
|
||||
|
||||
// PrimaryKey represents a primary key constraint in a database
|
||||
|
@ -52,3 +53,66 @@ type ForeignKey struct {
|
|||
ForeignTable string
|
||||
ForeignColumn string
|
||||
}
|
||||
|
||||
// Tables returns the table metadata for the given tables, or all tables if
|
||||
// no tables are provided.
|
||||
func Tables(db Interface, names ...string) ([]Table, error) {
|
||||
var err error
|
||||
if len(names) == 0 {
|
||||
if names, err = db.TableNames(); err != nil {
|
||||
fmt.Println("Unable to get table names.")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var tables []Table
|
||||
for _, name := range names {
|
||||
t := Table{Name: name}
|
||||
|
||||
if t.Columns, err = db.Columns(name); err != nil {
|
||||
return nil, err
|
||||
fmt.Println("Unable to get columnss.")
|
||||
}
|
||||
|
||||
for i, c := range t.Columns {
|
||||
t.Columns[i] = db.TranslateColumnType(c)
|
||||
}
|
||||
|
||||
if t.PKey, err = db.PrimaryKeyInfo(name); err != nil {
|
||||
fmt.Println("Unable to get primary key info.")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t.FKeys, err = db.ForeignKeyInfo(name); err != nil {
|
||||
fmt.Println("Unable to get foreign key info.")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setIsJoinTable(&t)
|
||||
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func setIsJoinTable(t *Table) {
|
||||
if t.PKey == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, c := range t.PKey.Columns {
|
||||
found := false
|
||||
for _, f := range t.FKeys {
|
||||
if c == f.Column {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.IsJoinTable = true
|
||||
}
|
||||
|
|
31
dbdrivers/interface_test.go
Normal file
31
dbdrivers/interface_test.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package dbdrivers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTables(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestSetIsJoinTable(t *testing.T) {
|
||||
tests := []struct {
|
||||
Pkey []string
|
||||
Fkey []string
|
||||
Should bool
|
||||
}{
|
||||
{Pkey: []string{"one"}, Fkey: []string{"one"}, Should: true},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
var table Table
|
||||
|
||||
table.PKey = &PrimaryKey{Columns: test.Pkey}
|
||||
for _, k := range test.Fkey {
|
||||
table.FKeys = append(table.FKeys, ForeignKey{Column: k})
|
||||
}
|
||||
|
||||
setIsJoinTable(&table)
|
||||
if is := table.IsJoinTable; is != test.Should {
|
||||
t.Errorf("%d) want: %t, got: %t\nTest: %#v", i, test.Should, is, test)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -44,47 +44,11 @@ func (p *PostgresDriver) Close() {
|
|||
p.dbConn.Close()
|
||||
}
|
||||
|
||||
// Tables returns the table metadata for the given tables, or all tables if
|
||||
// no tables are provided.
|
||||
func (p *PostgresDriver) Tables(names ...string) ([]Table, error) {
|
||||
var err error
|
||||
if len(names) == 0 {
|
||||
if names, err = p.tableNames(); err != nil {
|
||||
fmt.Println("Unable to get table names.")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var tables []Table
|
||||
for _, name := range names {
|
||||
t := Table{Name: name}
|
||||
|
||||
if t.Columns, err = p.columns(name); err != nil {
|
||||
fmt.Println("Unable to get columnss.")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t.PKey, err = p.primaryKeyInfo(name); err != nil {
|
||||
fmt.Println("Unable to get primary key info.")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t.FKeys, err = p.foreignKeyInfo(name); err != nil {
|
||||
fmt.Println("Unable to get foreign key info.")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// tableNames connects to the postgres database and
|
||||
// retrieves all table names from the information_schema where the
|
||||
// table schema is public. It excludes common migration tool tables
|
||||
// such as gorp_migrations
|
||||
func (p *PostgresDriver) tableNames() ([]string, error) {
|
||||
func (p *PostgresDriver) TableNames() ([]string, error) {
|
||||
var names []string
|
||||
|
||||
rows, err := p.dbConn.Query(`select table_name from
|
||||
|
@ -111,23 +75,12 @@ func (p *PostgresDriver) tableNames() ([]string, error) {
|
|||
// from the database information_schema.columns. It retrieves the column names
|
||||
// and column types and returns those as a []Column after TranslateColumnType()
|
||||
// converts the SQL types to Go types, for example: "varchar" to "string"
|
||||
func (p *PostgresDriver) columns(tableName string) ([]Column, error) {
|
||||
func (p *PostgresDriver) Columns(tableName string) ([]Column, error) {
|
||||
var columns []Column
|
||||
|
||||
rows, err := p.dbConn.Query(`
|
||||
SELECT c.column_name, c.data_type, c.is_nullable,
|
||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRIMARY KEY' ELSE '' END AS KeyType
|
||||
FROM information_schema.columns c
|
||||
LEFT JOIN (
|
||||
SELECT ku.table_name, ku.column_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
INNER JOIN information_schema.key_column_usage AS ku
|
||||
ON tc.constraint_type = 'PRIMARY KEY'
|
||||
AND tc.constraint_name = ku.constraint_name
|
||||
) pk
|
||||
ON c.table_name = pk.table_name
|
||||
AND c.column_name = pk.column_name
|
||||
WHERE c.table_name=$1
|
||||
SELECT column_name, data_type, is_nullable from
|
||||
information_schema.columns WHERE table_name=$1
|
||||
`, tableName)
|
||||
|
||||
if err != nil {
|
||||
|
@ -136,16 +89,15 @@ func (p *PostgresDriver) columns(tableName string) ([]Column, error) {
|
|||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var colName, colType, isNullable, isPrimary string
|
||||
if err := rows.Scan(&colName, &colType, &isNullable, &isPrimary); err != nil {
|
||||
var colName, colType, isNullable string
|
||||
if err := rows.Scan(&colName, &colType, &isNullable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
column := p.TranslateColumnType(Column{
|
||||
Name: colName,
|
||||
Type: colType,
|
||||
IsNullable: isNullable == "YES",
|
||||
IsPrimaryKey: isPrimary == "PRIMARY KEY",
|
||||
})
|
||||
column := Column{
|
||||
Name: colName,
|
||||
Type: colType,
|
||||
IsNullable: isNullable == "YES",
|
||||
}
|
||||
columns = append(columns, column)
|
||||
}
|
||||
|
||||
|
@ -153,7 +105,7 @@ func (p *PostgresDriver) columns(tableName string) ([]Column, error) {
|
|||
}
|
||||
|
||||
// primaryKeyInfo looks up the primary key for a table.
|
||||
func (p *PostgresDriver) primaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
||||
func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
||||
pkey := &PrimaryKey{}
|
||||
var err error
|
||||
|
||||
|
@ -199,7 +151,7 @@ func (p *PostgresDriver) primaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
|||
}
|
||||
|
||||
// foreignKeyInfo retrieves the foreign keys for a given table name.
|
||||
func (p *PostgresDriver) foreignKeyInfo(tableName string) ([]ForeignKey, error) {
|
||||
func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) {
|
||||
var fkeys []ForeignKey
|
||||
|
||||
query := `
|
||||
|
|
Loading…
Reference in a new issue