More swaths of renaming awkward structs.
This commit is contained in:
parent
4f619d4efe
commit
8a591ecdd6
19 changed files with 106 additions and 68 deletions
|
@ -44,10 +44,10 @@ func boilRun(cmd *cobra.Command, args []string) {
|
|||
// Prepend "struct" command to templateNames slice so it sits at top of sort
|
||||
templateNames = append([]string{"struct"}, templateNames...)
|
||||
|
||||
for i := 0; i < len(cmdData.TablesInfo); i++ {
|
||||
for i := 0; i < len(cmdData.Columns); i++ {
|
||||
data := tplData{
|
||||
TableName: cmdData.TableNames[i],
|
||||
TableData: cmdData.TablesInfo[i],
|
||||
Table: cmdData.Tables[i],
|
||||
Columns: cmdData.Columns[i],
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
|
|
|
@ -69,6 +69,7 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
|
|||
"selectParamNames": selectParamNames,
|
||||
"insertParamNames": insertParamNames,
|
||||
"insertParamFlags": insertParamFlags,
|
||||
"scanParamNames": scanParamNames,
|
||||
}
|
||||
|
||||
var allCmd = &cobra.Command{
|
||||
|
|
|
@ -16,17 +16,17 @@ type CobraRunFunc func(cmd *cobra.Command, args []string)
|
|||
// the database driver chosen by the driver flag at runtime, and a pointer to the
|
||||
// output file, if one is specified with a flag.
|
||||
type CmdData struct {
|
||||
TablesInfo [][]dbdrivers.DBColumn
|
||||
TableNames []string
|
||||
PkgName string
|
||||
OutFolder string
|
||||
DBDriver dbdrivers.DBDriver
|
||||
Tables []string
|
||||
Columns [][]dbdrivers.DBColumn
|
||||
PkgName string
|
||||
OutFolder string
|
||||
DBDriver dbdrivers.DBDriver
|
||||
}
|
||||
|
||||
// tplData is used to pass data to the template
|
||||
type tplData struct {
|
||||
TableName string
|
||||
TableData []dbdrivers.DBColumn
|
||||
Table string
|
||||
Columns []dbdrivers.DBColumn
|
||||
}
|
||||
|
||||
// errorQuit displays an error message and then exits the application.
|
||||
|
@ -39,10 +39,10 @@ func errorQuit(err error) {
|
|||
// It will generate the specific commands template and send it to outHandler for output.
|
||||
func defaultRun(cmd *cobra.Command, args []string) {
|
||||
// Generate the template for every table
|
||||
for i := 0; i < len(cmdData.TablesInfo); i++ {
|
||||
for i := 0; i < len(cmdData.Columns); i++ {
|
||||
data := tplData{
|
||||
TableName: cmdData.TableNames[i],
|
||||
TableData: cmdData.TablesInfo[i],
|
||||
Table: cmdData.Tables[i],
|
||||
Columns: cmdData.Columns[i],
|
||||
}
|
||||
|
||||
// outHandler takes a slice of byte slices, so append the Template
|
||||
|
@ -72,7 +72,7 @@ func outHandler(output [][]byte, data *tplData) error {
|
|||
}
|
||||
}
|
||||
} else { // If not using stdout, attempt to create the model file.
|
||||
path := cmdData.OutFolder + "/" + data.TableName + ".go"
|
||||
path := cmdData.OutFolder + "/" + data.Table + ".go"
|
||||
out, err := os.Create(path)
|
||||
if err != nil {
|
||||
errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err))
|
||||
|
|
|
@ -13,6 +13,10 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
templatesDirectory = "/cmds/templates"
|
||||
)
|
||||
|
||||
// cmdData is used globally by all commands to access the table schema data,
|
||||
// the database driver and the output file. cmdData is initialized by
|
||||
// the root SQLBoiler cobra command at run time, before other commands execute.
|
||||
|
@ -68,11 +72,11 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
|
|||
errorQuit(fmt.Errorf("Unable to connect to the database: %s", err))
|
||||
}
|
||||
|
||||
// Initialize the cmdData.TableNames
|
||||
initTableNames()
|
||||
// Initialize the cmdData.Tables
|
||||
initTables()
|
||||
|
||||
// Initialize the cmdData.TablesInfo
|
||||
initTablesInfo()
|
||||
// Initialize the cmdData.Columns
|
||||
initColumns()
|
||||
|
||||
// Initialize the package name
|
||||
initPkgName()
|
||||
|
@ -81,7 +85,7 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
|
|||
initOutFolder()
|
||||
|
||||
// Initialize the templates
|
||||
templates, err = initTemplates()
|
||||
templates, err = initTemplates(templatesDirectory)
|
||||
if err != nil {
|
||||
errorQuit(fmt.Errorf("Unable to initialize templates: %s", err))
|
||||
}
|
||||
|
@ -113,46 +117,46 @@ func initDBDriver() {
|
|||
}
|
||||
}
|
||||
|
||||
// initTableNames will create a string slice out of the passed in table flag value
|
||||
// initTables will create a string slice out of the passed in table flag value
|
||||
// if one is provided. If no flag is provided, it will attempt to connect to the
|
||||
// database to retrieve all "public" table names, and build a slice out of that result.
|
||||
func initTableNames() {
|
||||
func initTables() {
|
||||
// Retrieve the list of tables
|
||||
tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String()
|
||||
|
||||
if len(tn) != 0 {
|
||||
cmdData.TableNames = strings.Split(tn, ",")
|
||||
for i, name := range cmdData.TableNames {
|
||||
cmdData.TableNames[i] = strings.TrimSpace(name)
|
||||
cmdData.Tables = strings.Split(tn, ",")
|
||||
for i, name := range cmdData.Tables {
|
||||
cmdData.Tables[i] = strings.TrimSpace(name)
|
||||
}
|
||||
}
|
||||
|
||||
// If no table names are provided attempt to process all tables in database
|
||||
if len(cmdData.TableNames) == 0 {
|
||||
if len(cmdData.Tables) == 0 {
|
||||
// get all table names
|
||||
var err error
|
||||
cmdData.TableNames, err = cmdData.DBDriver.GetAllTableNames()
|
||||
cmdData.Tables, err = cmdData.DBDriver.GetAllTables()
|
||||
if err != nil {
|
||||
errorQuit(fmt.Errorf("Unable to get all table names: %s", err))
|
||||
}
|
||||
|
||||
if len(cmdData.TableNames) == 0 {
|
||||
if len(cmdData.Tables) == 0 {
|
||||
errorQuit(errors.New("No tables found in database, migrate some tables first"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initTablesInfo builds a description of each table (column name, column type)
|
||||
// and assigns it to cmdData.TablesInfo, the slice of dbdrivers.DBColumn slices.
|
||||
func initTablesInfo() {
|
||||
// loop over table Names and build TablesInfo
|
||||
for i := 0; i < len(cmdData.TableNames); i++ {
|
||||
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.TableNames[i])
|
||||
// initColumns builds a description of each table (column name, column type)
|
||||
// and assigns it to cmdData.Columns, the slice of dbdrivers.DBColumn slices.
|
||||
func initColumns() {
|
||||
// loop over table Names and build Columns
|
||||
for i := 0; i < len(cmdData.Tables); i++ {
|
||||
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.Tables[i])
|
||||
if err != nil {
|
||||
errorQuit(fmt.Errorf("Unable to get the table info: %s", err))
|
||||
}
|
||||
|
||||
cmdData.TablesInfo = append(cmdData.TablesInfo, tInfo)
|
||||
cmdData.Columns = append(cmdData.Columns, tInfo)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -178,13 +182,13 @@ func initOutFolder() {
|
|||
|
||||
// initTemplates loads all of the template files in the /cmds/templates directory
|
||||
// and returns a slice of pointers to these templates.
|
||||
func initTemplates() ([]*template.Template, error) {
|
||||
func initTemplates(dir string) ([]*template.Template, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pattern := filepath.Join(wd, "/cmds/templates", "*.tpl")
|
||||
pattern := filepath.Join(wd, dir, "*.tpl")
|
||||
tpl, err := template.New("").Funcs(sqlBoilerTemplateFuncs).ParseGlob(pattern)
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -138,3 +138,15 @@ func selectParamNames(tableName string, columns []dbdrivers.DBColumn) string {
|
|||
|
||||
return strings.Join(selects, ", ")
|
||||
}
|
||||
|
||||
// scanParamNames takes a []DBColumn and returns a comma seperated
|
||||
// list of parameter names for use in a db.Scan() call.
|
||||
func scanParamNames(object string, columns []dbdrivers.DBColumn) string {
|
||||
scans := make([]string, 0, len(columns))
|
||||
for _, c := range columns {
|
||||
statement := fmt.Sprintf("&%s.%s", object, titleCase(c.Name))
|
||||
scans = append(scans, statement)
|
||||
}
|
||||
|
||||
return strings.Join(scans, ", ")
|
||||
}
|
||||
|
|
|
@ -71,3 +71,10 @@ func TestSelectParamFlags(t *testing.T) {
|
|||
t.Error("Wrong output:", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanParams(t *testing.T) {
|
||||
out := scanParamNames("object", testColumns)
|
||||
if out != "&object.FriendColumn, &object.EnemyColumnThing" {
|
||||
t.Error("Wrong output:", out)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,27 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
// {{titleCase $tableName}}All retrieves all records.
|
||||
func {{titleCase $tableName}}All(db boil.DB) ([]*{{titleCase $tableName}}, error) {
|
||||
{{$varName := camelCase $tableName -}}
|
||||
var {{$varName}} []*{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`)
|
||||
{{- $tableName := titleCase .Table -}}
|
||||
{{- $varName := camelCase $tableName -}}
|
||||
// {{$tableName}}All retrieves all records.
|
||||
func {{$tableName}}All(db boil.DB) ([]*{{$tableName}}, error) {
|
||||
var {{$varName}} []*{{$tableName}}
|
||||
|
||||
rows, err := db.Query(`SELECT {{selectParamNames .Table .Columns}}`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
return nil, fmt.Errorf("models: failed to query: %v", err)
|
||||
}
|
||||
|
||||
return {{$varName}}, nil
|
||||
for rows.Next() {
|
||||
{{$varName}}Tmp := {{$tableName}}{}
|
||||
|
||||
if err := rows.Scan({{scanParamNames $varName .Columns}}); err != nil {
|
||||
return nil, fmt.Errorf("models: failed to scan row: %v", err)
|
||||
}
|
||||
|
||||
{{$varName}} = append({{$varName}}, {{$varName}}Tmp)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("models: failed to read rows: %v", err)
|
||||
}
|
||||
|
||||
return {{$varName}}, nil
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}AllBy retrieves all records with the specified column values.
|
||||
func {{titleCase $tableName}}AllBy(db boil.DB, columns map[string]interface{}) ([]*{{titleCase $tableName}}, error) {
|
||||
{{$varName := camelCase $tableName -}}
|
||||
var {{$varName}} []*{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`)
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}Delete deletes a single record.
|
||||
func {{titleCase $tableName}}Delete(db boil.DB, id int) error {
|
||||
if id == nil {
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}FieldsAll retrieves the specified columns for all records.
|
||||
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
|
||||
// For example: friendName string `db:"friend_name"`
|
||||
func {{titleCase $tableName}}FieldsAll(db boil.DB, results interface{}) error {
|
||||
{{$varName := camelCase $tableName -}}
|
||||
var {{$varName}} []*{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`)
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}FieldsAllBy retrieves the specified columns
|
||||
// for all records with the specified column values.
|
||||
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
|
||||
|
@ -6,7 +6,7 @@
|
|||
func {{titleCase $tableName}}FieldsAllBy(db boil.DB, columns map[string]interface{}, results interface{}) error {
|
||||
{{$varName := camelCase $tableName -}}
|
||||
var {{$varName}} []*{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}}`)
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}FieldsFind retrieves the specified columns for a single record by ID.
|
||||
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
|
||||
// For example: friendName string `db:"friend_name"`
|
||||
|
@ -8,7 +8,7 @@ func {{titleCase $tableName}}FieldsFind(db boil.DB, id int, results interface{})
|
|||
}
|
||||
{{$varName := camelCase $tableName}}
|
||||
var {{$varName}} *{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id)
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}FieldsFindBy retrieves the specified columns
|
||||
// for a single record with the specified column values.
|
||||
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
|
||||
|
@ -9,7 +9,7 @@ func {{titleCase $tableName}}FieldsFindBy(db boil.DB, columns map[string]interfa
|
|||
}
|
||||
{{$varName := camelCase $tableName}}
|
||||
var {{$varName}} *{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id)
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}Find retrieves a single record by ID.
|
||||
func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}}, error) {
|
||||
if id == 0 {
|
||||
|
@ -6,7 +6,7 @@ func {{titleCase $tableName}}Find(db boil.DB, id int) (*{{titleCase $tableName}}
|
|||
}
|
||||
{{$varName := camelCase $tableName}}
|
||||
var {{$varName}} *{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .TableData}} WHERE id=$1`, id)
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}} WHERE id=$1`, id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}FindBy retrieves a single record with the specified column values.
|
||||
func {{titleCase $tableName}}FindBy(db boil.DB, columns map[string]interface{}) (*{{titleCase $tableName}}, error) {
|
||||
if id == 0 {
|
||||
|
@ -6,7 +6,7 @@ func {{titleCase $tableName}}FindBy(db boil.DB, columns map[string]interface{})
|
|||
}
|
||||
{{$varName := camelCase $tableName}}
|
||||
var {{$varName}} *{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, fmt.Sprintf(`SELECT {{selectParamNames $tableName .TableData}} WHERE %s=$1`, column), value)
|
||||
err := db.Select(&{{$varName}}, fmt.Sprintf(`SELECT {{selectParamNames $tableName .Columns}} WHERE %s=$1`, column), value)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}Insert inserts a single record.
|
||||
func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (int, error) {
|
||||
if o == nil {
|
||||
|
@ -6,7 +6,7 @@ func {{titleCase $tableName}}Insert(db boil.DB, o *{{titleCase $tableName}}) (in
|
|||
}
|
||||
|
||||
var rowID int
|
||||
err := db.QueryRow(`INSERT INTO {{$tableName}} ({{insertParamNames .TableData}}) VALUES({{insertParamFlags .TableData}}) RETURNING id`)
|
||||
err := db.QueryRow(`INSERT INTO {{$tableName}} ({{insertParamNames .Columns}}) VALUES({{insertParamFlags .Columns}}) RETURNING id`)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("model: unable to insert {{$tableName}}: %s", err)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{{- $tableName := .TableName -}}
|
||||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}} is an object representing the database table.
|
||||
type {{titleCase $tableName}} struct {
|
||||
{{range $key, $value := .TableData -}}
|
||||
{{range $key, $value := .Columns -}}
|
||||
{{titleCase $value.Name}} {{$value.Type}} `db:"{{makeDBName $tableName $value.Name}}" json:"{{$value.Name}}"`
|
||||
{{end -}}
|
||||
}
|
||||
|
|
|
@ -4,12 +4,12 @@ package dbdrivers
|
|||
// type of database connection. For example, queries to obtain schema data
|
||||
// will vary depending on what type of database software is in use.
|
||||
// The goal of the DBDriver is to retrieve all table names in a database
|
||||
// using GetAllTableNames() if no table names are provided via flags,
|
||||
// using GetAllTables() if no table names are provided via flags,
|
||||
// to handle the database connection using Open() and Close(), and to
|
||||
// build the table information using GetTableInfo() and ParseTableInfo()
|
||||
type DBDriver interface {
|
||||
// GetAllTableNames connects to the database and retrieves all "public" table names
|
||||
GetAllTableNames() ([]string, error)
|
||||
// GetAllTables connects to the database and retrieves all "public" table names
|
||||
GetAllTables() ([]string, error)
|
||||
|
||||
// GetTableInfo retrieves column information about the table.
|
||||
GetTableInfo(tableName string) ([]DBColumn, error)
|
||||
|
|
|
@ -44,11 +44,11 @@ func (d *PostgresDriver) Close() {
|
|||
d.dbConn.Close()
|
||||
}
|
||||
|
||||
// GetAllTableNames connects to the postgres database and
|
||||
// GetAllTables connects to the postgres database and
|
||||
// retrieves all table names from the information_schema where the
|
||||
// table schema is public. It excludes common migration tool tables
|
||||
// such as gorp_migrations
|
||||
func (d *PostgresDriver) GetAllTableNames() ([]string, error) {
|
||||
func (d *PostgresDriver) GetAllTables() ([]string, error) {
|
||||
var tableNames []string
|
||||
|
||||
rows, err := d.dbConn.Query(`select table_name from
|
||||
|
|
Loading…
Reference in a new issue