Struct generation complete, pipes to stdout
* Database driver config & flag complete * Table flag complete, will use all tables if non specified
This commit is contained in:
parent
f72d646e66
commit
b43f9c63d8
16 changed files with 419 additions and 0 deletions
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
||||||
|
config.toml
|
|
@ -1,2 +1,11 @@
|
||||||
# sqlboiler
|
# sqlboiler
|
||||||
SQL Boiler generates boilerplate structs and statements
|
SQL Boiler generates boilerplate structs and statements
|
||||||
|
|
||||||
|
Create a file named config.toml in the root sqlboiler directory. An example config:
|
||||||
|
|
||||||
|
[postgres]
|
||||||
|
host="localhost"
|
||||||
|
port=5432
|
||||||
|
user="username"
|
||||||
|
pass="password"
|
||||||
|
dbname="database"
|
||||||
|
|
33
cmds/config.go
Normal file
33
cmds/config.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/BurntSushi/toml"
|
||||||
|
)
|
||||||
|
|
||||||
|
var cfg = struct {
|
||||||
|
Postgres struct {
|
||||||
|
User string `toml:"user"`
|
||||||
|
Pass string `toml:"pass"`
|
||||||
|
Host string `toml:"host"`
|
||||||
|
Port int `toml:"port"`
|
||||||
|
DBName string `toml:"dbname"`
|
||||||
|
} `toml:"postgres"`
|
||||||
|
}{}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
_, err := toml.DecodeFile("config.toml", &cfg)
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Failed to decode toml configuration file:", err)
|
||||||
|
}
|
||||||
|
}
|
1
cmds/select.go
Normal file
1
cmds/select.go
Normal file
|
@ -0,0 +1 @@
|
||||||
|
package cmds
|
1
cmds/select_test.go
Normal file
1
cmds/select_test.go
Normal file
|
@ -0,0 +1 @@
|
||||||
|
package cmds
|
31
cmds/shared.go
Normal file
31
cmds/shared.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func errorQuit(err error) {
|
||||||
|
fmt.Println(fmt.Sprintf("Error: %s\n---\n", err))
|
||||||
|
structCmd.Help()
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeGoColName(name string) string {
|
||||||
|
s := strings.Split(name, "_")
|
||||||
|
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if s[i] == "id" {
|
||||||
|
s[i] = "ID"
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s[i] = strings.Title(s[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(s, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDBColName(tableName, colName string) string {
|
||||||
|
return tableName + "_" + colName
|
||||||
|
}
|
91
cmds/sqlboiler.go
Normal file
91
cmds/sqlboiler.go
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CmdData struct {
|
||||||
|
TablesInfo [][]dbdrivers.DBTable
|
||||||
|
TableNames []string
|
||||||
|
DBDriver dbdrivers.DBDriver
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmdData *CmdData
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
SQLBoiler.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml")
|
||||||
|
SQLBoiler.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
|
||||||
|
SQLBoiler.PersistentPreRun = sqlBoilerPreRun
|
||||||
|
}
|
||||||
|
|
||||||
|
var SQLBoiler = &cobra.Command{
|
||||||
|
Use: "sqlboiler",
|
||||||
|
Short: "SQL Boiler generates boilerplate structs and statements",
|
||||||
|
Long: "SQL Boiler generates boilerplate structs and statements.\n" +
|
||||||
|
`Complete documentation is available at http://github.com/pobri19/sqlboiler`,
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
|
||||||
|
var err error
|
||||||
|
cmdData = &CmdData{}
|
||||||
|
|
||||||
|
// Retrieve driver flag
|
||||||
|
driverName := SQLBoiler.PersistentFlags().Lookup("driver").Value.String()
|
||||||
|
if driverName == "" {
|
||||||
|
errorQuit(errors.New("Must supply a driver flag."))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a driver based off driver flag
|
||||||
|
switch driverName {
|
||||||
|
case "postgres":
|
||||||
|
cmdData.DBDriver = dbdrivers.NewPostgresDriver(
|
||||||
|
cfg.Postgres.User,
|
||||||
|
cfg.Postgres.Pass,
|
||||||
|
cfg.Postgres.DBName,
|
||||||
|
cfg.Postgres.Host,
|
||||||
|
cfg.Postgres.Port,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to the driver database
|
||||||
|
if err = cmdData.DBDriver.Open(); err != nil {
|
||||||
|
errorQuit(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve the list of tables
|
||||||
|
tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String()
|
||||||
|
|
||||||
|
if len(tn) != 0 {
|
||||||
|
cmdData.TableNames = strings.Split(tn, ",")
|
||||||
|
for i, name := range cmdData.TableNames {
|
||||||
|
cmdData.TableNames[i] = strings.TrimSpace(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no table names are provided attempt to process all tables in database
|
||||||
|
if len(cmdData.TableNames) == 0 {
|
||||||
|
// get all table names
|
||||||
|
cmdData.TableNames, err = cmdData.DBDriver.GetAllTableNames()
|
||||||
|
if err != nil {
|
||||||
|
errorQuit(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cmdData.TableNames) == 0 {
|
||||||
|
errorQuit(errors.New("No tables found in database, migrate some tables first"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop over table Names and build TablesInfo
|
||||||
|
for i := 0; i < len(cmdData.TableNames); i++ {
|
||||||
|
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.TableNames[i])
|
||||||
|
if err != nil {
|
||||||
|
errorQuit(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdData.TablesInfo = append(cmdData.TablesInfo, tInfo)
|
||||||
|
}
|
||||||
|
}
|
67
cmds/structs.go
Normal file
67
cmds/structs.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package cmds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"go/format"
|
||||||
|
"os"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
SQLBoiler.AddCommand(structCmd)
|
||||||
|
structCmd.Run = structRun
|
||||||
|
}
|
||||||
|
|
||||||
|
var structCmd = &cobra.Command{
|
||||||
|
Use: "struct",
|
||||||
|
Short: "Generate structs from table definitions",
|
||||||
|
}
|
||||||
|
|
||||||
|
type tplData struct {
|
||||||
|
TableName string
|
||||||
|
TableData interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func structRun(cmd *cobra.Command, args []string) {
|
||||||
|
out := generateStructs()
|
||||||
|
|
||||||
|
for _, v := range out {
|
||||||
|
os.Stdout.Write(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateStructs() [][]byte {
|
||||||
|
t, err := template.New("struct.tpl").Funcs(template.FuncMap{
|
||||||
|
"makeGoColName": makeGoColName,
|
||||||
|
"makeDBColName": makeDBColName,
|
||||||
|
}).ParseFiles("templates/struct.tpl")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
errorQuit(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var structOutputs [][]byte
|
||||||
|
|
||||||
|
for i := 0; i < len(cmdData.TablesInfo); i++ {
|
||||||
|
data := tplData{
|
||||||
|
TableName: cmdData.TableNames[i],
|
||||||
|
TableData: cmdData.TablesInfo[i],
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err = t.Execute(&buf, data); err != nil {
|
||||||
|
errorQuit(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := format.Source(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
errorQuit(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
structOutputs = append(structOutputs, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return structOutputs
|
||||||
|
}
|
15
cmds/structs_test.go
Normal file
15
cmds/structs_test.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
package cmds
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func testVerifyStructArgs(t *testing.T) {
|
||||||
|
err := verifyStructArgs([]string{})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected an error")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = verifyStructArgs([]string{"hello"})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected error nil, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
16
dbdrivers/db_driver.go
Normal file
16
dbdrivers/db_driver.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package dbdrivers
|
||||||
|
|
||||||
|
type DBDriver interface {
|
||||||
|
GetAllTableNames() ([]string, error)
|
||||||
|
GetTableInfo(tableName string) ([]DBTable, error)
|
||||||
|
ParseTableInfo(colName, colType string) DBTable
|
||||||
|
// Open the database connection
|
||||||
|
Open() error
|
||||||
|
// Close the database connection
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type DBTable struct {
|
||||||
|
ColName string
|
||||||
|
ColType string
|
||||||
|
}
|
105
dbdrivers/postgres_driver.go
Normal file
105
dbdrivers/postgres_driver.go
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
package dbdrivers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PostgresDriver struct {
|
||||||
|
connStr string
|
||||||
|
dbConn *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresDriver(user, pass, dbname, host string, port int) *PostgresDriver {
|
||||||
|
driver := PostgresDriver{
|
||||||
|
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
|
||||||
|
user, pass, dbname, host, port),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &driver
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgresDriver) Open() error {
|
||||||
|
var err error
|
||||||
|
d.dbConn, err = sql.Open("postgres", d.connStr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgresDriver) Close() {
|
||||||
|
d.dbConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgresDriver) GetAllTableNames() ([]string, error) {
|
||||||
|
var tableNames []string
|
||||||
|
|
||||||
|
rows, err := d.dbConn.Query(`select table_name from
|
||||||
|
information_schema.tables where table_schema='public'
|
||||||
|
and table_name <> 'gorp_migrations'`)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var tableName string
|
||||||
|
if err := rows.Scan(&tableName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tableNames = append(tableNames, tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tableNames, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgresDriver) GetTableInfo(tableName string) ([]DBTable, error) {
|
||||||
|
var tableInfo []DBTable
|
||||||
|
|
||||||
|
rows, err := d.dbConn.Query(`select column_name, data_type from
|
||||||
|
information_schema.columns where table_name=$1`, tableName)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var colName, colType string
|
||||||
|
if err := rows.Scan(&colName, &colType); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tableInfo = append(tableInfo, d.ParseTableInfo(colName, colType))
|
||||||
|
}
|
||||||
|
|
||||||
|
return tableInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgresDriver) ParseTableInfo(colName, colType string) DBTable {
|
||||||
|
t := DBTable{}
|
||||||
|
|
||||||
|
t.ColName = colName
|
||||||
|
switch colType {
|
||||||
|
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
|
||||||
|
t.ColType = "int64"
|
||||||
|
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
|
||||||
|
t.ColType = "string"
|
||||||
|
case "bytea":
|
||||||
|
t.ColType = "[]byte"
|
||||||
|
case "boolean":
|
||||||
|
t.ColType = "bool"
|
||||||
|
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
|
||||||
|
t.ColType = "time.Time"
|
||||||
|
case "double precision", "money", "numeric", "real":
|
||||||
|
t.ColType = "float64"
|
||||||
|
default:
|
||||||
|
t.ColType = "string"
|
||||||
|
}
|
||||||
|
|
||||||
|
return t
|
||||||
|
}
|
9
dbdrivers/postgres_driver_test.go
Normal file
9
dbdrivers/postgres_driver_test.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package dbdrivers
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestFlow(t *testing.T) {
|
||||||
|
// driver := NewPostgresDriver("test", "pass", "dbname", "localhost", 3456)
|
||||||
|
// defer driver.Close()
|
||||||
|
// driver.GetTableInfo()
|
||||||
|
}
|
15
main.go
Normal file
15
main.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/pobri19/sqlboiler/cmds"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := cmds.SQLBoiler.Execute(); err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
}
|
BIN
sqlboiler
Executable file
BIN
sqlboiler
Executable file
Binary file not shown.
19
templates/select.tpl
Normal file
19
templates/select.tpl
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
func Insert{{makeGoColName $tableName}}(o *{{makeGoColName $tableName}}, db *sqlx.DB) (int, error) {
|
||||||
|
if o == nil {
|
||||||
|
return 0, errors.New("No {{objName}} provided for insertion")
|
||||||
|
}
|
||||||
|
|
||||||
|
var rowID int
|
||||||
|
err := db.QueryRow(
|
||||||
|
`INSERT INTO {{tableName}}
|
||||||
|
({{makeGoInsertParamNames tableData}})
|
||||||
|
VALUES({{makeGoInsertParamFlags tableData}})
|
||||||
|
RETURNING id`
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("Unable to insert {{objName}}: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rowID, nil
|
||||||
|
}
|
6
templates/struct.tpl
Normal file
6
templates/struct.tpl
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
{{- $tableName := .TableName -}}
|
||||||
|
type {{makeGoColName $tableName}} struct {
|
||||||
|
{{range $key, $value := .TableData -}}
|
||||||
|
{{makeGoColName $value.ColName}} {{$value.ColType}} `db:"{{makeDBColName $tableName $value.ColName}}",json:"{{$value.ColName}}"`
|
||||||
|
{{end -}}
|
||||||
|
}
|
Loading…
Reference in a new issue