Struct generation complete, pipes to stdout
* Database driver config & flag complete * Table flag complete, will use all tables if non specified
This commit is contained in:
parent
f72d646e66
commit
b43f9c63d8
16 changed files with 419 additions and 0 deletions
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
config.toml
|
|
@ -1,2 +1,11 @@
|
|||
# sqlboiler
|
||||
SQL Boiler generates boilerplate structs and statements
|
||||
|
||||
Create a file named config.toml in the root sqlboiler directory. An example config:
|
||||
|
||||
[postgres]
|
||||
host="localhost"
|
||||
port=5432
|
||||
user="username"
|
||||
pass="password"
|
||||
dbname="database"
|
||||
|
|
33
cmds/config.go
Normal file
33
cmds/config.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
var cfg = struct {
|
||||
Postgres struct {
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
} `toml:"postgres"`
|
||||
}{}
|
||||
|
||||
func init() {
|
||||
_, err := toml.DecodeFile("config.toml", &cfg)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("Failed to decode toml configuration file:", err)
|
||||
}
|
||||
}
|
1
cmds/select.go
Normal file
1
cmds/select.go
Normal file
|
@ -0,0 +1 @@
|
|||
package cmds
|
1
cmds/select_test.go
Normal file
1
cmds/select_test.go
Normal file
|
@ -0,0 +1 @@
|
|||
package cmds
|
31
cmds/shared.go
Normal file
31
cmds/shared.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func errorQuit(err error) {
|
||||
fmt.Println(fmt.Sprintf("Error: %s\n---\n", err))
|
||||
structCmd.Help()
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
func makeGoColName(name string) string {
|
||||
s := strings.Split(name, "_")
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == "id" {
|
||||
s[i] = "ID"
|
||||
continue
|
||||
}
|
||||
s[i] = strings.Title(s[i])
|
||||
}
|
||||
|
||||
return strings.Join(s, "")
|
||||
}
|
||||
|
||||
func makeDBColName(tableName, colName string) string {
|
||||
return tableName + "_" + colName
|
||||
}
|
91
cmds/sqlboiler.go
Normal file
91
cmds/sqlboiler.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type CmdData struct {
|
||||
TablesInfo [][]dbdrivers.DBTable
|
||||
TableNames []string
|
||||
DBDriver dbdrivers.DBDriver
|
||||
}
|
||||
|
||||
var cmdData *CmdData
|
||||
|
||||
func init() {
|
||||
SQLBoiler.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml")
|
||||
SQLBoiler.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
|
||||
SQLBoiler.PersistentPreRun = sqlBoilerPreRun
|
||||
}
|
||||
|
||||
var SQLBoiler = &cobra.Command{
|
||||
Use: "sqlboiler",
|
||||
Short: "SQL Boiler generates boilerplate structs and statements",
|
||||
Long: "SQL Boiler generates boilerplate structs and statements.\n" +
|
||||
`Complete documentation is available at http://github.com/pobri19/sqlboiler`,
|
||||
}
|
||||
|
||||
func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
cmdData = &CmdData{}
|
||||
|
||||
// Retrieve driver flag
|
||||
driverName := SQLBoiler.PersistentFlags().Lookup("driver").Value.String()
|
||||
if driverName == "" {
|
||||
errorQuit(errors.New("Must supply a driver flag."))
|
||||
}
|
||||
|
||||
// Create a driver based off driver flag
|
||||
switch driverName {
|
||||
case "postgres":
|
||||
cmdData.DBDriver = dbdrivers.NewPostgresDriver(
|
||||
cfg.Postgres.User,
|
||||
cfg.Postgres.Pass,
|
||||
cfg.Postgres.DBName,
|
||||
cfg.Postgres.Host,
|
||||
cfg.Postgres.Port,
|
||||
)
|
||||
}
|
||||
|
||||
// Connect to the driver database
|
||||
if err = cmdData.DBDriver.Open(); err != nil {
|
||||
errorQuit(err)
|
||||
}
|
||||
|
||||
// Retrieve the list of tables
|
||||
tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String()
|
||||
|
||||
if len(tn) != 0 {
|
||||
cmdData.TableNames = strings.Split(tn, ",")
|
||||
for i, name := range cmdData.TableNames {
|
||||
cmdData.TableNames[i] = strings.TrimSpace(name)
|
||||
}
|
||||
}
|
||||
|
||||
// If no table names are provided attempt to process all tables in database
|
||||
if len(cmdData.TableNames) == 0 {
|
||||
// get all table names
|
||||
cmdData.TableNames, err = cmdData.DBDriver.GetAllTableNames()
|
||||
if err != nil {
|
||||
errorQuit(err)
|
||||
}
|
||||
|
||||
if len(cmdData.TableNames) == 0 {
|
||||
errorQuit(errors.New("No tables found in database, migrate some tables first"))
|
||||
}
|
||||
}
|
||||
|
||||
// loop over table Names and build TablesInfo
|
||||
for i := 0; i < len(cmdData.TableNames); i++ {
|
||||
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.TableNames[i])
|
||||
if err != nil {
|
||||
errorQuit(err)
|
||||
}
|
||||
|
||||
cmdData.TablesInfo = append(cmdData.TablesInfo, tInfo)
|
||||
}
|
||||
}
|
67
cmds/structs.go
Normal file
67
cmds/structs.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package cmds
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"go/format"
|
||||
"os"
|
||||
"text/template"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func init() {
|
||||
SQLBoiler.AddCommand(structCmd)
|
||||
structCmd.Run = structRun
|
||||
}
|
||||
|
||||
var structCmd = &cobra.Command{
|
||||
Use: "struct",
|
||||
Short: "Generate structs from table definitions",
|
||||
}
|
||||
|
||||
type tplData struct {
|
||||
TableName string
|
||||
TableData interface{}
|
||||
}
|
||||
|
||||
func structRun(cmd *cobra.Command, args []string) {
|
||||
out := generateStructs()
|
||||
|
||||
for _, v := range out {
|
||||
os.Stdout.Write(v)
|
||||
}
|
||||
}
|
||||
|
||||
func generateStructs() [][]byte {
|
||||
t, err := template.New("struct.tpl").Funcs(template.FuncMap{
|
||||
"makeGoColName": makeGoColName,
|
||||
"makeDBColName": makeDBColName,
|
||||
}).ParseFiles("templates/struct.tpl")
|
||||
|
||||
if err != nil {
|
||||
errorQuit(err)
|
||||
}
|
||||
|
||||
var structOutputs [][]byte
|
||||
|
||||
for i := 0; i < len(cmdData.TablesInfo); i++ {
|
||||
data := tplData{
|
||||
TableName: cmdData.TableNames[i],
|
||||
TableData: cmdData.TablesInfo[i],
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err = t.Execute(&buf, data); err != nil {
|
||||
errorQuit(err)
|
||||
}
|
||||
|
||||
out, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
errorQuit(err)
|
||||
}
|
||||
|
||||
structOutputs = append(structOutputs, out)
|
||||
}
|
||||
|
||||
return structOutputs
|
||||
}
|
15
cmds/structs_test.go
Normal file
15
cmds/structs_test.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package cmds
|
||||
|
||||
import "testing"
|
||||
|
||||
func testVerifyStructArgs(t *testing.T) {
|
||||
err := verifyStructArgs([]string{})
|
||||
if err == nil {
|
||||
t.Error("Expected an error")
|
||||
}
|
||||
|
||||
err = verifyStructArgs([]string{"hello"})
|
||||
if err != nil {
|
||||
t.Errorf("Expected error nil, got: %s", err)
|
||||
}
|
||||
}
|
16
dbdrivers/db_driver.go
Normal file
16
dbdrivers/db_driver.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package dbdrivers
|
||||
|
||||
type DBDriver interface {
|
||||
GetAllTableNames() ([]string, error)
|
||||
GetTableInfo(tableName string) ([]DBTable, error)
|
||||
ParseTableInfo(colName, colType string) DBTable
|
||||
// Open the database connection
|
||||
Open() error
|
||||
// Close the database connection
|
||||
Close()
|
||||
}
|
||||
|
||||
type DBTable struct {
|
||||
ColName string
|
||||
ColType string
|
||||
}
|
105
dbdrivers/postgres_driver.go
Normal file
105
dbdrivers/postgres_driver.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package dbdrivers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type PostgresDriver struct {
|
||||
connStr string
|
||||
dbConn *sql.DB
|
||||
}
|
||||
|
||||
func NewPostgresDriver(user, pass, dbname, host string, port int) *PostgresDriver {
|
||||
driver := PostgresDriver{
|
||||
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
|
||||
user, pass, dbname, host, port),
|
||||
}
|
||||
|
||||
return &driver
|
||||
}
|
||||
|
||||
func (d *PostgresDriver) Open() error {
|
||||
var err error
|
||||
d.dbConn, err = sql.Open("postgres", d.connStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *PostgresDriver) Close() {
|
||||
d.dbConn.Close()
|
||||
}
|
||||
|
||||
func (d *PostgresDriver) GetAllTableNames() ([]string, error) {
|
||||
var tableNames []string
|
||||
|
||||
rows, err := d.dbConn.Query(`select table_name from
|
||||
information_schema.tables where table_schema='public'
|
||||
and table_name <> 'gorp_migrations'`)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var tableName string
|
||||
if err := rows.Scan(&tableName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
|
||||
return tableNames, nil
|
||||
}
|
||||
|
||||
func (d *PostgresDriver) GetTableInfo(tableName string) ([]DBTable, error) {
|
||||
var tableInfo []DBTable
|
||||
|
||||
rows, err := d.dbConn.Query(`select column_name, data_type from
|
||||
information_schema.columns where table_name=$1`, tableName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var colName, colType string
|
||||
if err := rows.Scan(&colName, &colType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tableInfo = append(tableInfo, d.ParseTableInfo(colName, colType))
|
||||
}
|
||||
|
||||
return tableInfo, nil
|
||||
}
|
||||
|
||||
func (d *PostgresDriver) ParseTableInfo(colName, colType string) DBTable {
|
||||
t := DBTable{}
|
||||
|
||||
t.ColName = colName
|
||||
switch colType {
|
||||
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
|
||||
t.ColType = "int64"
|
||||
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
|
||||
t.ColType = "string"
|
||||
case "bytea":
|
||||
t.ColType = "[]byte"
|
||||
case "boolean":
|
||||
t.ColType = "bool"
|
||||
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
|
||||
t.ColType = "time.Time"
|
||||
case "double precision", "money", "numeric", "real":
|
||||
t.ColType = "float64"
|
||||
default:
|
||||
t.ColType = "string"
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
9
dbdrivers/postgres_driver_test.go
Normal file
9
dbdrivers/postgres_driver_test.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package dbdrivers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFlow(t *testing.T) {
|
||||
// driver := NewPostgresDriver("test", "pass", "dbname", "localhost", 3456)
|
||||
// defer driver.Close()
|
||||
// driver.GetTableInfo()
|
||||
}
|
15
main.go
Normal file
15
main.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/pobri19/sqlboiler/cmds"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := cmds.SQLBoiler.Execute(); err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
}
|
BIN
sqlboiler
Executable file
BIN
sqlboiler
Executable file
Binary file not shown.
19
templates/select.tpl
Normal file
19
templates/select.tpl
Normal file
|
@ -0,0 +1,19 @@
|
|||
func Insert{{makeGoColName $tableName}}(o *{{makeGoColName $tableName}}, db *sqlx.DB) (int, error) {
|
||||
if o == nil {
|
||||
return 0, errors.New("No {{objName}} provided for insertion")
|
||||
}
|
||||
|
||||
var rowID int
|
||||
err := db.QueryRow(
|
||||
`INSERT INTO {{tableName}}
|
||||
({{makeGoInsertParamNames tableData}})
|
||||
VALUES({{makeGoInsertParamFlags tableData}})
|
||||
RETURNING id`
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Unable to insert {{objName}}: %s", err)
|
||||
}
|
||||
|
||||
return rowID, nil
|
||||
}
|
6
templates/struct.tpl
Normal file
6
templates/struct.tpl
Normal file
|
@ -0,0 +1,6 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
type {{makeGoColName $tableName}} struct {
|
||||
{{range $key, $value := .TableData -}}
|
||||
{{makeGoColName $value.ColName}} {{$value.ColType}} `db:"{{makeDBColName $tableName $value.ColName}}",json:"{{$value.ColName}}"`
|
||||
{{end -}}
|
||||
}
|
Loading…
Reference in a new issue