Refactored interface & added tests

This commit is contained in:
Patrick O'brien 2016-04-03 17:15:35 +10:00
parent a7263bde40
commit 0ffccde168
3 changed files with 117 additions and 70 deletions

View file

@ -1,14 +1,16 @@
package dbdrivers
import "fmt"
// Interface for a database driver. Functionality required to support a specific
// database type (eg, MySQL, Postgres etc.)
type Interface interface {
// Tables connects to the database and retrieves the table metadata for
// the given tables, or all tables if none are provided.
Tables(names ...string) ([]Table, error)
TableNames() ([]string, error)
Columns(tableName string) ([]Column, error)
PrimaryKeyInfo(tableName string) (*PrimaryKey, error)
ForeignKeyInfo(tableName string) ([]ForeignKey, error)
// TranslateColumnType takes a Database column type and returns a go column
// type.
// TranslateColumnType takes a Database column type and returns a go column type.
TranslateColumnType(Column) Column
// Open the database connection
@ -32,10 +34,9 @@ type Table struct {
// Column holds information about a database column.
// Types are Go types, converted by TranslateColumnType.
type Column struct {
Name string
Type string
IsPrimaryKey bool
IsNullable bool
Name string
Type string
IsNullable bool
}
// PrimaryKey represents a primary key constraint in a database
@ -52,3 +53,66 @@ type ForeignKey struct {
ForeignTable string
ForeignColumn string
}
// Tables returns the table metadata for the given tables, or all tables if
// no tables are provided.
func Tables(db Interface, names ...string) ([]Table, error) {
var err error
if len(names) == 0 {
if names, err = db.TableNames(); err != nil {
fmt.Println("Unable to get table names.")
return nil, err
}
}
var tables []Table
for _, name := range names {
t := Table{Name: name}
if t.Columns, err = db.Columns(name); err != nil {
return nil, err
fmt.Println("Unable to get columnss.")
}
for i, c := range t.Columns {
t.Columns[i] = db.TranslateColumnType(c)
}
if t.PKey, err = db.PrimaryKeyInfo(name); err != nil {
fmt.Println("Unable to get primary key info.")
return nil, err
}
if t.FKeys, err = db.ForeignKeyInfo(name); err != nil {
fmt.Println("Unable to get foreign key info.")
return nil, err
}
setIsJoinTable(&t)
tables = append(tables, t)
}
return tables, nil
}
func setIsJoinTable(t *Table) {
if t.PKey == nil {
return
}
for _, c := range t.PKey.Columns {
found := false
for _, f := range t.FKeys {
if c == f.Column {
found = true
break
}
}
if !found {
return
}
}
t.IsJoinTable = true
}

View file

@ -0,0 +1,31 @@
package dbdrivers
import "testing"
func TestTables(t *testing.T) {
}
func TestSetIsJoinTable(t *testing.T) {
tests := []struct {
Pkey []string
Fkey []string
Should bool
}{
{Pkey: []string{"one"}, Fkey: []string{"one"}, Should: true},
}
for i, test := range tests {
var table Table
table.PKey = &PrimaryKey{Columns: test.Pkey}
for _, k := range test.Fkey {
table.FKeys = append(table.FKeys, ForeignKey{Column: k})
}
setIsJoinTable(&table)
if is := table.IsJoinTable; is != test.Should {
t.Errorf("%d) want: %t, got: %t\nTest: %#v", i, test.Should, is, test)
}
}
}

View file

@ -44,47 +44,11 @@ func (p *PostgresDriver) Close() {
p.dbConn.Close()
}
// Tables returns the table metadata for the given tables, or all tables if
// no tables are provided.
func (p *PostgresDriver) Tables(names ...string) ([]Table, error) {
var err error
if len(names) == 0 {
if names, err = p.tableNames(); err != nil {
fmt.Println("Unable to get table names.")
return nil, err
}
}
var tables []Table
for _, name := range names {
t := Table{Name: name}
if t.Columns, err = p.columns(name); err != nil {
fmt.Println("Unable to get columnss.")
return nil, err
}
if t.PKey, err = p.primaryKeyInfo(name); err != nil {
fmt.Println("Unable to get primary key info.")
return nil, err
}
if t.FKeys, err = p.foreignKeyInfo(name); err != nil {
fmt.Println("Unable to get foreign key info.")
return nil, err
}
tables = append(tables, t)
}
return tables, nil
}
// tableNames connects to the postgres database and
// retrieves all table names from the information_schema where the
// table schema is public. It excludes common migration tool tables
// such as gorp_migrations
func (p *PostgresDriver) tableNames() ([]string, error) {
func (p *PostgresDriver) TableNames() ([]string, error) {
var names []string
rows, err := p.dbConn.Query(`select table_name from
@ -111,23 +75,12 @@ func (p *PostgresDriver) tableNames() ([]string, error) {
// from the database information_schema.columns. It retrieves the column names
// and column types and returns those as a []Column after TranslateColumnType()
// converts the SQL types to Go types, for example: "varchar" to "string"
func (p *PostgresDriver) columns(tableName string) ([]Column, error) {
func (p *PostgresDriver) Columns(tableName string) ([]Column, error) {
var columns []Column
rows, err := p.dbConn.Query(`
SELECT c.column_name, c.data_type, c.is_nullable,
CASE WHEN pk.column_name IS NOT NULL THEN 'PRIMARY KEY' ELSE '' END AS KeyType
FROM information_schema.columns c
LEFT JOIN (
SELECT ku.table_name, ku.column_name
FROM information_schema.table_constraints AS tc
INNER JOIN information_schema.key_column_usage AS ku
ON tc.constraint_type = 'PRIMARY KEY'
AND tc.constraint_name = ku.constraint_name
) pk
ON c.table_name = pk.table_name
AND c.column_name = pk.column_name
WHERE c.table_name=$1
SELECT column_name, data_type, is_nullable from
information_schema.columns WHERE table_name=$1
`, tableName)
if err != nil {
@ -136,16 +89,15 @@ func (p *PostgresDriver) columns(tableName string) ([]Column, error) {
defer rows.Close()
for rows.Next() {
var colName, colType, isNullable, isPrimary string
if err := rows.Scan(&colName, &colType, &isNullable, &isPrimary); err != nil {
var colName, colType, isNullable string
if err := rows.Scan(&colName, &colType, &isNullable); err != nil {
return nil, err
}
column := p.TranslateColumnType(Column{
Name: colName,
Type: colType,
IsNullable: isNullable == "YES",
IsPrimaryKey: isPrimary == "PRIMARY KEY",
})
column := Column{
Name: colName,
Type: colType,
IsNullable: isNullable == "YES",
}
columns = append(columns, column)
}
@ -153,7 +105,7 @@ func (p *PostgresDriver) columns(tableName string) ([]Column, error) {
}
// primaryKeyInfo looks up the primary key for a table.
func (p *PostgresDriver) primaryKeyInfo(tableName string) (*PrimaryKey, error) {
func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
pkey := &PrimaryKey{}
var err error
@ -199,7 +151,7 @@ func (p *PostgresDriver) primaryKeyInfo(tableName string) (*PrimaryKey, error) {
}
// foreignKeyInfo retrieves the foreign keys for a given table name.
func (p *PostgresDriver) foreignKeyInfo(tableName string) ([]ForeignKey, error) {
func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) {
var fkeys []ForeignKey
query := `