Refactored interface & added tests
This commit is contained in:
parent
a7263bde40
commit
0ffccde168
3 changed files with 117 additions and 70 deletions
|
@ -1,14 +1,16 @@
|
||||||
package dbdrivers
|
package dbdrivers
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
// Interface for a database driver. Functionality required to support a specific
|
// Interface for a database driver. Functionality required to support a specific
|
||||||
// database type (eg, MySQL, Postgres etc.)
|
// database type (eg, MySQL, Postgres etc.)
|
||||||
type Interface interface {
|
type Interface interface {
|
||||||
// Tables connects to the database and retrieves the table metadata for
|
TableNames() ([]string, error)
|
||||||
// the given tables, or all tables if none are provided.
|
Columns(tableName string) ([]Column, error)
|
||||||
Tables(names ...string) ([]Table, error)
|
PrimaryKeyInfo(tableName string) (*PrimaryKey, error)
|
||||||
|
ForeignKeyInfo(tableName string) ([]ForeignKey, error)
|
||||||
|
|
||||||
// TranslateColumnType takes a Database column type and returns a go column
|
// TranslateColumnType takes a Database column type and returns a go column type.
|
||||||
// type.
|
|
||||||
TranslateColumnType(Column) Column
|
TranslateColumnType(Column) Column
|
||||||
|
|
||||||
// Open the database connection
|
// Open the database connection
|
||||||
|
@ -34,7 +36,6 @@ type Table struct {
|
||||||
type Column struct {
|
type Column struct {
|
||||||
Name string
|
Name string
|
||||||
Type string
|
Type string
|
||||||
IsPrimaryKey bool
|
|
||||||
IsNullable bool
|
IsNullable bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,3 +53,66 @@ type ForeignKey struct {
|
||||||
ForeignTable string
|
ForeignTable string
|
||||||
ForeignColumn string
|
ForeignColumn string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tables returns the table metadata for the given tables, or all tables if
|
||||||
|
// no tables are provided.
|
||||||
|
func Tables(db Interface, names ...string) ([]Table, error) {
|
||||||
|
var err error
|
||||||
|
if len(names) == 0 {
|
||||||
|
if names, err = db.TableNames(); err != nil {
|
||||||
|
fmt.Println("Unable to get table names.")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tables []Table
|
||||||
|
for _, name := range names {
|
||||||
|
t := Table{Name: name}
|
||||||
|
|
||||||
|
if t.Columns, err = db.Columns(name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
fmt.Println("Unable to get columnss.")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, c := range t.Columns {
|
||||||
|
t.Columns[i] = db.TranslateColumnType(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.PKey, err = db.PrimaryKeyInfo(name); err != nil {
|
||||||
|
fmt.Println("Unable to get primary key info.")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.FKeys, err = db.ForeignKeyInfo(name); err != nil {
|
||||||
|
fmt.Println("Unable to get foreign key info.")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
setIsJoinTable(&t)
|
||||||
|
|
||||||
|
tables = append(tables, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tables, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIsJoinTable(t *Table) {
|
||||||
|
if t.PKey == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range t.PKey.Columns {
|
||||||
|
found := false
|
||||||
|
for _, f := range t.FKeys {
|
||||||
|
if c == f.Column {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.IsJoinTable = true
|
||||||
|
}
|
||||||
|
|
31
dbdrivers/interface_test.go
Normal file
31
dbdrivers/interface_test.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package dbdrivers
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestTables(t *testing.T) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetIsJoinTable(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
Pkey []string
|
||||||
|
Fkey []string
|
||||||
|
Should bool
|
||||||
|
}{
|
||||||
|
{Pkey: []string{"one"}, Fkey: []string{"one"}, Should: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
var table Table
|
||||||
|
|
||||||
|
table.PKey = &PrimaryKey{Columns: test.Pkey}
|
||||||
|
for _, k := range test.Fkey {
|
||||||
|
table.FKeys = append(table.FKeys, ForeignKey{Column: k})
|
||||||
|
}
|
||||||
|
|
||||||
|
setIsJoinTable(&table)
|
||||||
|
if is := table.IsJoinTable; is != test.Should {
|
||||||
|
t.Errorf("%d) want: %t, got: %t\nTest: %#v", i, test.Should, is, test)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -44,47 +44,11 @@ func (p *PostgresDriver) Close() {
|
||||||
p.dbConn.Close()
|
p.dbConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tables returns the table metadata for the given tables, or all tables if
|
|
||||||
// no tables are provided.
|
|
||||||
func (p *PostgresDriver) Tables(names ...string) ([]Table, error) {
|
|
||||||
var err error
|
|
||||||
if len(names) == 0 {
|
|
||||||
if names, err = p.tableNames(); err != nil {
|
|
||||||
fmt.Println("Unable to get table names.")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var tables []Table
|
|
||||||
for _, name := range names {
|
|
||||||
t := Table{Name: name}
|
|
||||||
|
|
||||||
if t.Columns, err = p.columns(name); err != nil {
|
|
||||||
fmt.Println("Unable to get columnss.")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.PKey, err = p.primaryKeyInfo(name); err != nil {
|
|
||||||
fmt.Println("Unable to get primary key info.")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.FKeys, err = p.foreignKeyInfo(name); err != nil {
|
|
||||||
fmt.Println("Unable to get foreign key info.")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tables = append(tables, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tables, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// tableNames connects to the postgres database and
|
// tableNames connects to the postgres database and
|
||||||
// retrieves all table names from the information_schema where the
|
// retrieves all table names from the information_schema where the
|
||||||
// table schema is public. It excludes common migration tool tables
|
// table schema is public. It excludes common migration tool tables
|
||||||
// such as gorp_migrations
|
// such as gorp_migrations
|
||||||
func (p *PostgresDriver) tableNames() ([]string, error) {
|
func (p *PostgresDriver) TableNames() ([]string, error) {
|
||||||
var names []string
|
var names []string
|
||||||
|
|
||||||
rows, err := p.dbConn.Query(`select table_name from
|
rows, err := p.dbConn.Query(`select table_name from
|
||||||
|
@ -111,23 +75,12 @@ func (p *PostgresDriver) tableNames() ([]string, error) {
|
||||||
// from the database information_schema.columns. It retrieves the column names
|
// from the database information_schema.columns. It retrieves the column names
|
||||||
// and column types and returns those as a []Column after TranslateColumnType()
|
// and column types and returns those as a []Column after TranslateColumnType()
|
||||||
// converts the SQL types to Go types, for example: "varchar" to "string"
|
// converts the SQL types to Go types, for example: "varchar" to "string"
|
||||||
func (p *PostgresDriver) columns(tableName string) ([]Column, error) {
|
func (p *PostgresDriver) Columns(tableName string) ([]Column, error) {
|
||||||
var columns []Column
|
var columns []Column
|
||||||
|
|
||||||
rows, err := p.dbConn.Query(`
|
rows, err := p.dbConn.Query(`
|
||||||
SELECT c.column_name, c.data_type, c.is_nullable,
|
SELECT column_name, data_type, is_nullable from
|
||||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRIMARY KEY' ELSE '' END AS KeyType
|
information_schema.columns WHERE table_name=$1
|
||||||
FROM information_schema.columns c
|
|
||||||
LEFT JOIN (
|
|
||||||
SELECT ku.table_name, ku.column_name
|
|
||||||
FROM information_schema.table_constraints AS tc
|
|
||||||
INNER JOIN information_schema.key_column_usage AS ku
|
|
||||||
ON tc.constraint_type = 'PRIMARY KEY'
|
|
||||||
AND tc.constraint_name = ku.constraint_name
|
|
||||||
) pk
|
|
||||||
ON c.table_name = pk.table_name
|
|
||||||
AND c.column_name = pk.column_name
|
|
||||||
WHERE c.table_name=$1
|
|
||||||
`, tableName)
|
`, tableName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -136,16 +89,15 @@ func (p *PostgresDriver) columns(tableName string) ([]Column, error) {
|
||||||
|
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var colName, colType, isNullable, isPrimary string
|
var colName, colType, isNullable string
|
||||||
if err := rows.Scan(&colName, &colType, &isNullable, &isPrimary); err != nil {
|
if err := rows.Scan(&colName, &colType, &isNullable); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
column := p.TranslateColumnType(Column{
|
column := Column{
|
||||||
Name: colName,
|
Name: colName,
|
||||||
Type: colType,
|
Type: colType,
|
||||||
IsNullable: isNullable == "YES",
|
IsNullable: isNullable == "YES",
|
||||||
IsPrimaryKey: isPrimary == "PRIMARY KEY",
|
}
|
||||||
})
|
|
||||||
columns = append(columns, column)
|
columns = append(columns, column)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,7 +105,7 @@ func (p *PostgresDriver) columns(tableName string) ([]Column, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// primaryKeyInfo looks up the primary key for a table.
|
// primaryKeyInfo looks up the primary key for a table.
|
||||||
func (p *PostgresDriver) primaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
||||||
pkey := &PrimaryKey{}
|
pkey := &PrimaryKey{}
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
@ -199,7 +151,7 @@ func (p *PostgresDriver) primaryKeyInfo(tableName string) (*PrimaryKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// foreignKeyInfo retrieves the foreign keys for a given table name.
|
// foreignKeyInfo retrieves the foreign keys for a given table name.
|
||||||
func (p *PostgresDriver) foreignKeyInfo(tableName string) ([]ForeignKey, error) {
|
func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) {
|
||||||
var fkeys []ForeignKey
|
var fkeys []ForeignKey
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
|
|
Loading…
Add table
Reference in a new issue