Struct generation complete, pipes to stdout

* Database driver config & flag complete
* Table flag complete, will use all tables if non specified
This commit is contained in:
Patrick O'brien 2016-02-23 18:27:32 +10:00
parent f72d646e66
commit b43f9c63d8
16 changed files with 419 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
config.toml

View file

@ -1,2 +1,11 @@
# sqlboiler # sqlboiler
SQL Boiler generates boilerplate structs and statements SQL Boiler generates boilerplate structs and statements
Create a file named config.toml in the root sqlboiler directory. An example config:
[postgres]
host="localhost"
port=5432
user="username"
pass="password"
dbname="database"

33
cmds/config.go Normal file
View file

@ -0,0 +1,33 @@
package cmds
import (
"fmt"
"os"
"github.com/BurntSushi/toml"
)
var cfg = struct {
Postgres struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
} `toml:"postgres"`
}{}
func init() {
_, err := toml.DecodeFile("config.toml", &cfg)
if err == nil {
return
}
if os.IsNotExist(err) {
return
}
if err != nil {
fmt.Println("Failed to decode toml configuration file:", err)
}
}

1
cmds/select.go Normal file
View file

@ -0,0 +1 @@
package cmds

1
cmds/select_test.go Normal file
View file

@ -0,0 +1 @@
package cmds

31
cmds/shared.go Normal file
View file

@ -0,0 +1,31 @@
package cmds
import (
"fmt"
"os"
"strings"
)
func errorQuit(err error) {
fmt.Println(fmt.Sprintf("Error: %s\n---\n", err))
structCmd.Help()
os.Exit(-1)
}
func makeGoColName(name string) string {
s := strings.Split(name, "_")
for i := 0; i < len(s); i++ {
if s[i] == "id" {
s[i] = "ID"
continue
}
s[i] = strings.Title(s[i])
}
return strings.Join(s, "")
}
func makeDBColName(tableName, colName string) string {
return tableName + "_" + colName
}

91
cmds/sqlboiler.go Normal file
View file

@ -0,0 +1,91 @@
package cmds
import (
"errors"
"strings"
"github.com/pobri19/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
)
type CmdData struct {
TablesInfo [][]dbdrivers.DBTable
TableNames []string
DBDriver dbdrivers.DBDriver
}
var cmdData *CmdData
func init() {
SQLBoiler.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml")
SQLBoiler.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
SQLBoiler.PersistentPreRun = sqlBoilerPreRun
}
var SQLBoiler = &cobra.Command{
Use: "sqlboiler",
Short: "SQL Boiler generates boilerplate structs and statements",
Long: "SQL Boiler generates boilerplate structs and statements.\n" +
`Complete documentation is available at http://github.com/pobri19/sqlboiler`,
}
func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
var err error
cmdData = &CmdData{}
// Retrieve driver flag
driverName := SQLBoiler.PersistentFlags().Lookup("driver").Value.String()
if driverName == "" {
errorQuit(errors.New("Must supply a driver flag."))
}
// Create a driver based off driver flag
switch driverName {
case "postgres":
cmdData.DBDriver = dbdrivers.NewPostgresDriver(
cfg.Postgres.User,
cfg.Postgres.Pass,
cfg.Postgres.DBName,
cfg.Postgres.Host,
cfg.Postgres.Port,
)
}
// Connect to the driver database
if err = cmdData.DBDriver.Open(); err != nil {
errorQuit(err)
}
// Retrieve the list of tables
tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String()
if len(tn) != 0 {
cmdData.TableNames = strings.Split(tn, ",")
for i, name := range cmdData.TableNames {
cmdData.TableNames[i] = strings.TrimSpace(name)
}
}
// If no table names are provided attempt to process all tables in database
if len(cmdData.TableNames) == 0 {
// get all table names
cmdData.TableNames, err = cmdData.DBDriver.GetAllTableNames()
if err != nil {
errorQuit(err)
}
if len(cmdData.TableNames) == 0 {
errorQuit(errors.New("No tables found in database, migrate some tables first"))
}
}
// loop over table Names and build TablesInfo
for i := 0; i < len(cmdData.TableNames); i++ {
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.TableNames[i])
if err != nil {
errorQuit(err)
}
cmdData.TablesInfo = append(cmdData.TablesInfo, tInfo)
}
}

67
cmds/structs.go Normal file
View file

@ -0,0 +1,67 @@
package cmds
import (
"bytes"
"go/format"
"os"
"text/template"
"github.com/spf13/cobra"
)
func init() {
SQLBoiler.AddCommand(structCmd)
structCmd.Run = structRun
}
var structCmd = &cobra.Command{
Use: "struct",
Short: "Generate structs from table definitions",
}
type tplData struct {
TableName string
TableData interface{}
}
func structRun(cmd *cobra.Command, args []string) {
out := generateStructs()
for _, v := range out {
os.Stdout.Write(v)
}
}
func generateStructs() [][]byte {
t, err := template.New("struct.tpl").Funcs(template.FuncMap{
"makeGoColName": makeGoColName,
"makeDBColName": makeDBColName,
}).ParseFiles("templates/struct.tpl")
if err != nil {
errorQuit(err)
}
var structOutputs [][]byte
for i := 0; i < len(cmdData.TablesInfo); i++ {
data := tplData{
TableName: cmdData.TableNames[i],
TableData: cmdData.TablesInfo[i],
}
var buf bytes.Buffer
if err = t.Execute(&buf, data); err != nil {
errorQuit(err)
}
out, err := format.Source(buf.Bytes())
if err != nil {
errorQuit(err)
}
structOutputs = append(structOutputs, out)
}
return structOutputs
}

15
cmds/structs_test.go Normal file
View file

@ -0,0 +1,15 @@
package cmds
import "testing"
func testVerifyStructArgs(t *testing.T) {
err := verifyStructArgs([]string{})
if err == nil {
t.Error("Expected an error")
}
err = verifyStructArgs([]string{"hello"})
if err != nil {
t.Errorf("Expected error nil, got: %s", err)
}
}

16
dbdrivers/db_driver.go Normal file
View file

@ -0,0 +1,16 @@
package dbdrivers
type DBDriver interface {
GetAllTableNames() ([]string, error)
GetTableInfo(tableName string) ([]DBTable, error)
ParseTableInfo(colName, colType string) DBTable
// Open the database connection
Open() error
// Close the database connection
Close()
}
type DBTable struct {
ColName string
ColType string
}

View file

@ -0,0 +1,105 @@
package dbdrivers
import (
"database/sql"
"fmt"
_ "github.com/lib/pq"
)
type PostgresDriver struct {
connStr string
dbConn *sql.DB
}
func NewPostgresDriver(user, pass, dbname, host string, port int) *PostgresDriver {
driver := PostgresDriver{
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
user, pass, dbname, host, port),
}
return &driver
}
func (d *PostgresDriver) Open() error {
var err error
d.dbConn, err = sql.Open("postgres", d.connStr)
if err != nil {
return err
}
return nil
}
func (d *PostgresDriver) Close() {
d.dbConn.Close()
}
func (d *PostgresDriver) GetAllTableNames() ([]string, error) {
var tableNames []string
rows, err := d.dbConn.Query(`select table_name from
information_schema.tables where table_schema='public'
and table_name <> 'gorp_migrations'`)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var tableName string
if err := rows.Scan(&tableName); err != nil {
return nil, err
}
tableNames = append(tableNames, tableName)
}
return tableNames, nil
}
func (d *PostgresDriver) GetTableInfo(tableName string) ([]DBTable, error) {
var tableInfo []DBTable
rows, err := d.dbConn.Query(`select column_name, data_type from
information_schema.columns where table_name=$1`, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var colName, colType string
if err := rows.Scan(&colName, &colType); err != nil {
return nil, err
}
tableInfo = append(tableInfo, d.ParseTableInfo(colName, colType))
}
return tableInfo, nil
}
func (d *PostgresDriver) ParseTableInfo(colName, colType string) DBTable {
t := DBTable{}
t.ColName = colName
switch colType {
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
t.ColType = "int64"
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
t.ColType = "string"
case "bytea":
t.ColType = "[]byte"
case "boolean":
t.ColType = "bool"
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
t.ColType = "time.Time"
case "double precision", "money", "numeric", "real":
t.ColType = "float64"
default:
t.ColType = "string"
}
return t
}

View file

@ -0,0 +1,9 @@
package dbdrivers
import "testing"
func TestFlow(t *testing.T) {
// driver := NewPostgresDriver("test", "pass", "dbname", "localhost", 3456)
// defer driver.Close()
// driver.GetTableInfo()
}

15
main.go Normal file
View file

@ -0,0 +1,15 @@
package main
import (
"fmt"
"os"
"github.com/pobri19/sqlboiler/cmds"
)
func main() {
if err := cmds.SQLBoiler.Execute(); err != nil {
fmt.Println(err)
os.Exit(-1)
}
}

BIN
sqlboiler Executable file

Binary file not shown.

19
templates/select.tpl Normal file
View file

@ -0,0 +1,19 @@
func Insert{{makeGoColName $tableName}}(o *{{makeGoColName $tableName}}, db *sqlx.DB) (int, error) {
if o == nil {
return 0, errors.New("No {{objName}} provided for insertion")
}
var rowID int
err := db.QueryRow(
`INSERT INTO {{tableName}}
({{makeGoInsertParamNames tableData}})
VALUES({{makeGoInsertParamFlags tableData}})
RETURNING id`
)
if err != nil {
return 0, fmt.Errorf("Unable to insert {{objName}}: %s", err)
}
return rowID, nil
}

6
templates/struct.tpl Normal file
View file

@ -0,0 +1,6 @@
{{- $tableName := .TableName -}}
type {{makeGoColName $tableName}} struct {
{{range $key, $value := .TableData -}}
{{makeGoColName $value.ColName}} {{$value.ColType}} `db:"{{makeDBColName $tableName $value.ColName}}",json:"{{$value.ColName}}"`
{{end -}}
}