Rename everything.
This commit is contained in:
parent
48a9ba8d29
commit
1394489a63
11 changed files with 99 additions and 97 deletions
|
@ -9,7 +9,7 @@ SQLBoiler is a tool to generate Go boilerplate code for database interactions. S
|
||||||
# Supports?
|
# Supports?
|
||||||
* Postgres
|
* Postgres
|
||||||
|
|
||||||
If anyone wants to create a driver for their own database it's easy very to do. All you need to do is create a driver in the ````/dbdrivers```` package that implements the ````DBDriver```` interface (you can use ````postgres_driver.go```` as an example), and add your driver to the switch statement in the ````initDBDriver()```` function in ````sqlboiler.go````. That's it!
|
If anyone wants to create a driver for their own database it's easy very to do. All you need to do is create a driver in the ````/dbdrivers```` package that implements the ````Interface```` interface (you can use ````postgres_driver.go```` as an example), and add your driver to the switch statement in the ````initInterface()```` function in ````sqlboiler.go````. That's it!
|
||||||
|
|
||||||
I've included templates for struct definitions and select, delete and insert statement helpers. Editing the output of the existing templates is as easy as modifying the template file, but to add new types of statements you'll need to add a new command and a new template. This is also very easy to do, and you can use any of the existing command files as an example.
|
I've included templates for struct definitions and select, delete and insert statement helpers. Editing the output of the existing templates is as easy as modifying the template file, but to add new types of statements you'll need to add a new command and a new template. This is also very easy to do, and you can use any of the existing command files as an example.
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ var sqlBoilerDefaultTestImports = imports{
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlBoilerConditionalTypeImports imports are only included in the template output
|
// sqlBoilerConditionalTypeImports imports are only included in the template output
|
||||||
// if the database requires one of the following special types. Check ParseTableInfo
|
// if the database requires one of the following special types. Check TranslateColumn
|
||||||
// to see the type assignments.
|
// to see the type assignments.
|
||||||
var sqlBoilerConditionalTypeImports = map[string]imports{
|
var sqlBoilerConditionalTypeImports = map[string]imports{
|
||||||
"null.Int": imports{
|
"null.Int": imports{
|
||||||
|
|
|
@ -40,7 +40,7 @@ func combineImports(a, b imports) imports {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func combineConditionalTypeImports(a imports, b map[string]imports, columns []dbdrivers.DBColumn) imports {
|
func combineConditionalTypeImports(a imports, b map[string]imports, columns []dbdrivers.Column) imports {
|
||||||
tmpImp := imports{
|
tmpImp := imports{
|
||||||
standard: make(importList, len(a.standard)),
|
standard: make(importList, len(a.standard)),
|
||||||
thirdparty: make(importList, len(a.thirdparty)),
|
thirdparty: make(importList, len(a.thirdparty)),
|
||||||
|
|
|
@ -30,17 +30,17 @@ func TestCombineConditionalTypeImports(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
cols := []dbdrivers.DBColumn{
|
cols := []dbdrivers.Column{
|
||||||
dbdrivers.DBColumn{
|
dbdrivers.Column{
|
||||||
Type: "null.Time",
|
Type: "null.Time",
|
||||||
},
|
},
|
||||||
dbdrivers.DBColumn{
|
dbdrivers.Column{
|
||||||
Type: "null.Time",
|
Type: "null.Time",
|
||||||
},
|
},
|
||||||
dbdrivers.DBColumn{
|
dbdrivers.Column{
|
||||||
Type: "time.Time",
|
Type: "time.Time",
|
||||||
},
|
},
|
||||||
dbdrivers.DBColumn{
|
dbdrivers.Column{
|
||||||
Type: "null.Float",
|
Type: "null.Float",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,16 +18,16 @@ type CobraRunFunc func(cmd *cobra.Command, args []string)
|
||||||
// output file, if one is specified with a flag.
|
// output file, if one is specified with a flag.
|
||||||
type CmdData struct {
|
type CmdData struct {
|
||||||
Tables []string
|
Tables []string
|
||||||
Columns [][]dbdrivers.DBColumn
|
Columns [][]dbdrivers.Column
|
||||||
PkgName string
|
PkgName string
|
||||||
OutFolder string
|
OutFolder string
|
||||||
DBDriver dbdrivers.DBDriver
|
Interface dbdrivers.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
// tplData is used to pass data to the template
|
// tplData is used to pass data to the template
|
||||||
type tplData struct {
|
type tplData struct {
|
||||||
Table string
|
Table string
|
||||||
Columns []dbdrivers.DBColumn
|
Columns []dbdrivers.Column
|
||||||
PkgName string
|
PkgName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ func init() {
|
||||||
// sqlBoilerPostRun cleans up the output file and database connection once
|
// sqlBoilerPostRun cleans up the output file and database connection once
|
||||||
// all commands are finished running.
|
// all commands are finished running.
|
||||||
func sqlBoilerPostRun(cmd *cobra.Command, args []string) {
|
func sqlBoilerPostRun(cmd *cobra.Command, args []string) {
|
||||||
cmdData.DBDriver.Close()
|
cmdData.Interface.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlBoilerPreRun executes before all commands start running. Its job is to
|
// sqlBoilerPreRun executes before all commands start running. Its job is to
|
||||||
|
@ -68,11 +68,11 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
|
||||||
var err error
|
var err error
|
||||||
cmdData = &CmdData{}
|
cmdData = &CmdData{}
|
||||||
|
|
||||||
// Initialize the cmdData.DBDriver
|
// Initialize the cmdData.Interface
|
||||||
initDBDriver()
|
initInterface()
|
||||||
|
|
||||||
// Connect to the driver database
|
// Connect to the driver database
|
||||||
if err = cmdData.DBDriver.Open(); err != nil {
|
if err = cmdData.Interface.Open(); err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to connect to the database: %s", err))
|
errorQuit(fmt.Errorf("Unable to connect to the database: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,9 +103,9 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initDBDriver attempts to set the cmdData DBDriver based off the passed in
|
// initInterface attempts to set the cmdData Interface based off the passed in
|
||||||
// driver flag value. If an invalid flag string is provided the program will exit.
|
// driver flag value. If an invalid flag string is provided the program will exit.
|
||||||
func initDBDriver() {
|
func initInterface() {
|
||||||
// Retrieve driver flag
|
// Retrieve driver flag
|
||||||
driverName := SQLBoiler.PersistentFlags().Lookup("driver").Value.String()
|
driverName := SQLBoiler.PersistentFlags().Lookup("driver").Value.String()
|
||||||
if driverName == "" {
|
if driverName == "" {
|
||||||
|
@ -115,7 +115,7 @@ func initDBDriver() {
|
||||||
// Create a driver based off driver flag
|
// Create a driver based off driver flag
|
||||||
switch driverName {
|
switch driverName {
|
||||||
case "postgres":
|
case "postgres":
|
||||||
cmdData.DBDriver = dbdrivers.NewPostgresDriver(
|
cmdData.Interface = dbdrivers.NewPostgresDriver(
|
||||||
cfg.Postgres.User,
|
cfg.Postgres.User,
|
||||||
cfg.Postgres.Pass,
|
cfg.Postgres.Pass,
|
||||||
cfg.Postgres.DBName,
|
cfg.Postgres.DBName,
|
||||||
|
@ -124,7 +124,7 @@ func initDBDriver() {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmdData.DBDriver == nil {
|
if cmdData.Interface == nil {
|
||||||
errorQuit(errors.New("An invalid driver name was provided"))
|
errorQuit(errors.New("An invalid driver name was provided"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -147,7 +147,7 @@ func initTables() {
|
||||||
if len(cmdData.Tables) == 0 {
|
if len(cmdData.Tables) == 0 {
|
||||||
// get all table names
|
// get all table names
|
||||||
var err error
|
var err error
|
||||||
cmdData.Tables, err = cmdData.DBDriver.GetAllTables()
|
cmdData.Tables, err = cmdData.Interface.AllTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to get all table names: %s", err))
|
errorQuit(fmt.Errorf("Unable to get all table names: %s", err))
|
||||||
}
|
}
|
||||||
|
@ -159,11 +159,11 @@ func initTables() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// initColumns builds a description of each table (column name, column type)
|
// initColumns builds a description of each table (column name, column type)
|
||||||
// and assigns it to cmdData.Columns, the slice of dbdrivers.DBColumn slices.
|
// and assigns it to cmdData.Columns, the slice of dbdrivers.Column slices.
|
||||||
func initColumns() {
|
func initColumns() {
|
||||||
// loop over table Names and build Columns
|
// loop over table Names and build Columns
|
||||||
for i := 0; i < len(cmdData.Tables); i++ {
|
for i := 0; i < len(cmdData.Tables); i++ {
|
||||||
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.Tables[i])
|
tInfo, err := cmdData.Interface.Columns(cmdData.Tables[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorQuit(fmt.Errorf("Unable to get the table info: %s", err))
|
errorQuit(fmt.Errorf("Unable to get the table info: %s", err))
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,8 +14,8 @@ import (
|
||||||
func init() {
|
func init() {
|
||||||
cmdData = &CmdData{
|
cmdData = &CmdData{
|
||||||
Tables: []string{"patrick_table", "spiderman"},
|
Tables: []string{"patrick_table", "spiderman"},
|
||||||
Columns: [][]dbdrivers.DBColumn{
|
Columns: [][]dbdrivers.Column{
|
||||||
[]dbdrivers.DBColumn{
|
[]dbdrivers.Column{
|
||||||
{Name: "patrick_column", Type: "string", IsNullable: false},
|
{Name: "patrick_column", Type: "string", IsNullable: false},
|
||||||
{Name: "aaron_column", Type: "null.String", IsNullable: true},
|
{Name: "aaron_column", Type: "null.String", IsNullable: true},
|
||||||
{Name: "id", Type: "null.Int", IsNullable: true},
|
{Name: "id", Type: "null.Int", IsNullable: true},
|
||||||
|
@ -24,13 +24,13 @@ func init() {
|
||||||
{Name: "fun_time", Type: "time.Time", IsNullable: false},
|
{Name: "fun_time", Type: "time.Time", IsNullable: false},
|
||||||
{Name: "cool_stuff_forever", Type: "[]byte", IsNullable: false},
|
{Name: "cool_stuff_forever", Type: "[]byte", IsNullable: false},
|
||||||
},
|
},
|
||||||
[]dbdrivers.DBColumn{
|
[]dbdrivers.Column{
|
||||||
{Name: "patrick", Type: "string", IsNullable: false},
|
{Name: "patrick", Type: "string", IsNullable: false},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
PkgName: "patrick",
|
PkgName: "patrick",
|
||||||
OutFolder: "",
|
OutFolder: "",
|
||||||
DBDriver: nil,
|
Interface: nil,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -183,11 +183,11 @@ func makeDBName(tableName, colName string) string {
|
||||||
return tableName + "_" + colName
|
return tableName + "_" + colName
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateParamNames takes a []DBColumn and returns a comma seperated
|
// updateParamNames takes a []Column and returns a comma seperated
|
||||||
// list of parameter names for the update statement template SET clause.
|
// list of parameter names for the update statement template SET clause.
|
||||||
// eg: col1=$1,col2=$2,col3=$3
|
// eg: col1=$1,col2=$2,col3=$3
|
||||||
// Note: updateParamNames will exclude the PRIMARY KEY column.
|
// Note: updateParamNames will exclude the PRIMARY KEY column.
|
||||||
func updateParamNames(columns []dbdrivers.DBColumn) string {
|
func updateParamNames(columns []dbdrivers.Column) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
counter := 0
|
counter := 0
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
|
@ -200,11 +200,11 @@ func updateParamNames(columns []dbdrivers.DBColumn) string {
|
||||||
return strings.Join(names, ",")
|
return strings.Join(names, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateParamVariables takes a prefix and a []DBColumns and returns a
|
// updateParamVariables takes a prefix and a []Columns and returns a
|
||||||
// comma seperated list of parameter variable names for the update statement.
|
// comma seperated list of parameter variable names for the update statement.
|
||||||
// eg: prefix("o."), column("name_id") -> "o.NameID, ..."
|
// eg: prefix("o."), column("name_id") -> "o.NameID, ..."
|
||||||
// Note: updateParamVariables will exclude the PRIMARY KEY column.
|
// Note: updateParamVariables will exclude the PRIMARY KEY column.
|
||||||
func updateParamVariables(prefix string, columns []dbdrivers.DBColumn) string {
|
func updateParamVariables(prefix string, columns []dbdrivers.Column) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
|
@ -218,9 +218,9 @@ func updateParamVariables(prefix string, columns []dbdrivers.DBColumn) string {
|
||||||
return strings.Join(names, ", ")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertParamNames takes a []DBColumn and returns a comma seperated
|
// insertParamNames takes a []Column and returns a comma seperated
|
||||||
// list of parameter names for the insert statement template.
|
// list of parameter names for the insert statement template.
|
||||||
func insertParamNames(columns []dbdrivers.DBColumn) string {
|
func insertParamNames(columns []dbdrivers.Column) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
names = append(names, c.Name)
|
names = append(names, c.Name)
|
||||||
|
@ -228,9 +228,9 @@ func insertParamNames(columns []dbdrivers.DBColumn) string {
|
||||||
return strings.Join(names, ", ")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertParamFlags takes a []DBColumn and returns a comma seperated
|
// insertParamFlags takes a []Column and returns a comma seperated
|
||||||
// list of parameter flags for the insert statement template.
|
// list of parameter flags for the insert statement template.
|
||||||
func insertParamFlags(columns []dbdrivers.DBColumn) string {
|
func insertParamFlags(columns []dbdrivers.Column) string {
|
||||||
params := make([]string, 0, len(columns))
|
params := make([]string, 0, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
params = append(params, fmt.Sprintf("$%d", i+1))
|
params = append(params, fmt.Sprintf("$%d", i+1))
|
||||||
|
@ -238,10 +238,10 @@ func insertParamFlags(columns []dbdrivers.DBColumn) string {
|
||||||
return strings.Join(params, ", ")
|
return strings.Join(params, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertParamVariables takes a prefix and a []DBColumns and returns a
|
// insertParamVariables takes a prefix and a []Columns and returns a
|
||||||
// comma seperated list of parameter variable names for the insert statement.
|
// comma seperated list of parameter variable names for the insert statement.
|
||||||
// For example: prefix("o."), column("name_id") -> "o.NameID, ..."
|
// For example: prefix("o."), column("name_id") -> "o.NameID, ..."
|
||||||
func insertParamVariables(prefix string, columns []dbdrivers.DBColumn) string {
|
func insertParamVariables(prefix string, columns []dbdrivers.Column) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
|
@ -252,11 +252,11 @@ func insertParamVariables(prefix string, columns []dbdrivers.DBColumn) string {
|
||||||
return strings.Join(names, ", ")
|
return strings.Join(names, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectParamNames takes a []DBColumn and returns a comma seperated
|
// selectParamNames takes a []Column and returns a comma seperated
|
||||||
// list of parameter names with for the select statement template.
|
// list of parameter names with for the select statement template.
|
||||||
// It also uses the table name to generate the "AS" part of the statement, for
|
// It also uses the table name to generate the "AS" part of the statement, for
|
||||||
// example: var_name AS table_name_var_name, ...
|
// example: var_name AS table_name_var_name, ...
|
||||||
func selectParamNames(tableName string, columns []dbdrivers.DBColumn) string {
|
func selectParamNames(tableName string, columns []dbdrivers.Column) string {
|
||||||
selects := make([]string, 0, len(columns))
|
selects := make([]string, 0, len(columns))
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
statement := fmt.Sprintf("%s AS %s", c.Name, makeDBName(tableName, c.Name))
|
statement := fmt.Sprintf("%s AS %s", c.Name, makeDBName(tableName, c.Name))
|
||||||
|
@ -266,9 +266,9 @@ func selectParamNames(tableName string, columns []dbdrivers.DBColumn) string {
|
||||||
return strings.Join(selects, ", ")
|
return strings.Join(selects, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// scanParamNames takes a []DBColumn and returns a comma seperated
|
// scanParamNames takes a []Column and returns a comma seperated
|
||||||
// list of parameter names for use in a db.Scan() call.
|
// list of parameter names for use in a db.Scan() call.
|
||||||
func scanParamNames(object string, columns []dbdrivers.DBColumn) string {
|
func scanParamNames(object string, columns []dbdrivers.Column) string {
|
||||||
scans := make([]string, 0, len(columns))
|
scans := make([]string, 0, len(columns))
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
statement := fmt.Sprintf("&%s.%s", object, titleCase(c.Name))
|
statement := fmt.Sprintf("&%s.%s", object, titleCase(c.Name))
|
||||||
|
@ -279,7 +279,7 @@ func scanParamNames(object string, columns []dbdrivers.DBColumn) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasPrimaryKey returns true if one of the columns passed in is a primary key
|
// hasPrimaryKey returns true if one of the columns passed in is a primary key
|
||||||
func hasPrimaryKey(columns []dbdrivers.DBColumn) bool {
|
func hasPrimaryKey(columns []dbdrivers.Column) bool {
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
if c.IsPrimaryKey {
|
if c.IsPrimaryKey {
|
||||||
return true
|
return true
|
||||||
|
@ -290,7 +290,7 @@ func hasPrimaryKey(columns []dbdrivers.DBColumn) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPrimaryKey returns the primary key column name if one is present
|
// getPrimaryKey returns the primary key column name if one is present
|
||||||
func getPrimaryKey(columns []dbdrivers.DBColumn) string {
|
func getPrimaryKey(columns []dbdrivers.Column) string {
|
||||||
for _, c := range columns {
|
for _, c := range columns {
|
||||||
if c.IsPrimaryKey {
|
if c.IsPrimaryKey {
|
||||||
return c.Name
|
return c.Name
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testColumns = []dbdrivers.DBColumn{
|
var testColumns = []dbdrivers.Column{
|
||||||
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
||||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
||||||
}
|
}
|
||||||
|
@ -98,7 +98,7 @@ func TestMakeDBName(t *testing.T) {
|
||||||
func TestUpdateParamNames(t *testing.T) {
|
func TestUpdateParamNames(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var testCols = []dbdrivers.DBColumn{
|
var testCols = []dbdrivers.Column{
|
||||||
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
|
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
|
||||||
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
||||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
||||||
|
@ -113,7 +113,7 @@ func TestUpdateParamNames(t *testing.T) {
|
||||||
func TestUpdateParamVariables(t *testing.T) {
|
func TestUpdateParamVariables(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var testCols = []dbdrivers.DBColumn{
|
var testCols = []dbdrivers.Column{
|
||||||
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
|
{Name: "id", Type: "int", IsNullable: false, IsPrimaryKey: true},
|
||||||
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
{Name: "friend_column", Type: "int", IsNullable: false, IsPrimaryKey: false},
|
||||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
{Name: "enemy_column_thing", Type: "string", IsNullable: true, IsPrimaryKey: false},
|
||||||
|
|
|
@ -1,23 +1,18 @@
|
||||||
package dbdrivers
|
package dbdrivers
|
||||||
|
|
||||||
// DBDriver is an interface that handles the operation uniqueness of each
|
// Interface for a database driver. Functionality required to support a specific
|
||||||
// type of database connection. For example, queries to obtain schema data
|
// database type (eg, MySQL, Postgres etc.)
|
||||||
// will vary depending on what type of database software is in use.
|
type Interface interface {
|
||||||
// The goal of the DBDriver is to retrieve all table names in a database
|
// AllTables connects to the database and retrieves all "public" table names
|
||||||
// using GetAllTables() if no table names are provided via flags,
|
AllTables() ([]string, error)
|
||||||
// to handle the database connection using Open() and Close(), and to
|
|
||||||
// build the table information using GetTableInfo() and ParseTableInfo()
|
|
||||||
type DBDriver interface {
|
|
||||||
// GetAllTables connects to the database and retrieves all "public" table names
|
|
||||||
GetAllTables() ([]string, error)
|
|
||||||
|
|
||||||
// GetTableInfo retrieves column information about the table.
|
// Columns retrieves column information about the table.
|
||||||
GetTableInfo(tableName string) ([]DBColumn, error)
|
Columns(tableName string) ([]Column, error)
|
||||||
|
|
||||||
// ParseTableInfo builds a DBColumn out of a column name and column type.
|
// TranslateColumn builds a Column out of a column metadata.
|
||||||
// Its main responsibility is to convert database types to Go types, for example
|
// Its main responsibility is to convert database types to Go types, for
|
||||||
// "varchar" to "string".
|
// example "varchar" to "string".
|
||||||
ParseTableInfo(name, colType string, isNullable bool, isPrimary bool) DBColumn
|
TranslateColumn(Column) Column
|
||||||
|
|
||||||
// Open the database connection
|
// Open the database connection
|
||||||
Open() error
|
Open() error
|
||||||
|
@ -26,9 +21,17 @@ type DBDriver interface {
|
||||||
Close()
|
Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DBColumn holds information about a database column name.
|
// Table metadata from the database schema.
|
||||||
// Column types are Go types, converted by ParseTableInfo.
|
type Table struct {
|
||||||
type DBColumn struct {
|
Name string
|
||||||
|
Columns []Column
|
||||||
|
|
||||||
|
IsJoinTable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column holds information about a database column.
|
||||||
|
// Types are Go types, converted by TranslateColumn.
|
||||||
|
type Column struct {
|
||||||
Name string
|
Name string
|
||||||
Type string
|
Type string
|
||||||
IsPrimaryKey bool
|
IsPrimaryKey bool
|
||||||
|
|
|
@ -44,11 +44,11 @@ func (d *PostgresDriver) Close() {
|
||||||
d.dbConn.Close()
|
d.dbConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllTables connects to the postgres database and
|
// AllTables connects to the postgres database and
|
||||||
// retrieves all table names from the information_schema where the
|
// retrieves all table names from the information_schema where the
|
||||||
// table schema is public. It excludes common migration tool tables
|
// table schema is public. It excludes common migration tool tables
|
||||||
// such as gorp_migrations
|
// such as gorp_migrations
|
||||||
func (d *PostgresDriver) GetAllTables() ([]string, error) {
|
func (d *PostgresDriver) AllTables() ([]string, error) {
|
||||||
var tableNames []string
|
var tableNames []string
|
||||||
|
|
||||||
rows, err := d.dbConn.Query(`select table_name from
|
rows, err := d.dbConn.Query(`select table_name from
|
||||||
|
@ -71,12 +71,12 @@ func (d *PostgresDriver) GetAllTables() ([]string, error) {
|
||||||
return tableNames, nil
|
return tableNames, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTableInfo takes a table name and attempts to retrieve the table information
|
// Columns takes a table name and attempts to retrieve the table information
|
||||||
// from the database information_schema.columns. It retrieves the column names
|
// from the database information_schema.columns. It retrieves the column names
|
||||||
// and column types and returns those as a []DBColumn after ParseTableInfo()
|
// and column types and returns those as a []Column after TranslateColumn()
|
||||||
// converts the SQL types to Go types, for example: "varchar" to "string"
|
// converts the SQL types to Go types, for example: "varchar" to "string"
|
||||||
func (d *PostgresDriver) GetTableInfo(tableName string) ([]DBColumn, error) {
|
func (d *PostgresDriver) Columns(tableName string) ([]Column, error) {
|
||||||
var table []DBColumn
|
var table []Column
|
||||||
|
|
||||||
rows, err := d.dbConn.Query(`
|
rows, err := d.dbConn.Query(`
|
||||||
SELECT c.column_name, c.data_type, c.is_nullable,
|
SELECT c.column_name, c.data_type, c.is_nullable,
|
||||||
|
@ -104,56 +104,55 @@ func (d *PostgresDriver) GetTableInfo(tableName string) ([]DBColumn, error) {
|
||||||
if err := rows.Scan(&colName, &colType, &isNullable, &isPrimary); err != nil {
|
if err := rows.Scan(&colName, &colType, &isNullable, &isPrimary); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t := d.ParseTableInfo(colName, colType, isNullable == "YES", isPrimary == "PRIMARY KEY")
|
t := d.TranslateColumn(Column{
|
||||||
|
Name: colName,
|
||||||
|
Type: colType,
|
||||||
|
IsNullable: isNullable == "YES",
|
||||||
|
IsPrimaryKey: isPrimary == "PRIMARY KEY",
|
||||||
|
})
|
||||||
table = append(table, t)
|
table = append(table, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
return table, nil
|
return table, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseTableInfo converts postgres database types to Go types, for example
|
// TranslateColumn converts postgres database types to Go types, for example
|
||||||
// "varchar" to "string" and "bigint" to "int64". It returns this parsed data
|
// "varchar" to "string" and "bigint" to "int64". It returns this parsed data
|
||||||
// as a DBColumn object.
|
// as a Column object.
|
||||||
func (d *PostgresDriver) ParseTableInfo(colName, colType string, isNullable bool, isPrimary bool) DBColumn {
|
func (d *PostgresDriver) TranslateColumn(c Column) Column {
|
||||||
var t DBColumn
|
if c.IsNullable {
|
||||||
|
switch c.Type {
|
||||||
t.Name = colName
|
|
||||||
t.IsPrimaryKey = isPrimary
|
|
||||||
t.IsNullable = isNullable
|
|
||||||
|
|
||||||
if isNullable {
|
|
||||||
switch colType {
|
|
||||||
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
|
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
|
||||||
t.Type = "null.Int"
|
c.Type = "null.Int"
|
||||||
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
|
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
|
||||||
t.Type = "null.String"
|
c.Type = "null.String"
|
||||||
case "boolean":
|
case "boolean":
|
||||||
t.Type = "null.Bool"
|
c.Type = "null.Bool"
|
||||||
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
|
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
|
||||||
t.Type = "null.Time"
|
c.Type = "null.Time"
|
||||||
case "double precision", "money", "numeric", "real":
|
case "double precision", "money", "numeric", "real":
|
||||||
t.Type = "null.Float"
|
c.Type = "null.Float"
|
||||||
default:
|
default:
|
||||||
t.Type = "null.String"
|
c.Type = "null.String"
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
switch colType {
|
switch c.Type {
|
||||||
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
|
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
|
||||||
t.Type = "int64"
|
c.Type = "int64"
|
||||||
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
|
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
|
||||||
t.Type = "string"
|
c.Type = "string"
|
||||||
case "bytea":
|
case "bytea":
|
||||||
t.Type = "[]byte"
|
c.Type = "[]byte"
|
||||||
case "boolean":
|
case "boolean":
|
||||||
t.Type = "bool"
|
c.Type = "bool"
|
||||||
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
|
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
|
||||||
t.Type = "time.Time"
|
c.Type = "time.Time"
|
||||||
case "double precision", "money", "numeric", "real":
|
case "double precision", "money", "numeric", "real":
|
||||||
t.Type = "float64"
|
c.Type = "float64"
|
||||||
default:
|
default:
|
||||||
t.Type = "string"
|
c.Type = "string"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return t
|
return c
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue