92 lines
2.3 KiB
Go
92 lines
2.3 KiB
Go
|
package cmds
|
||
|
|
||
|
import (
	"errors"
	"fmt"
	"strings"

	"github.com/pobri19/sqlboiler/dbdrivers"
	"github.com/spf13/cobra"
)
|
||
|
|
||
|
type CmdData struct {
|
||
|
TablesInfo [][]dbdrivers.DBTable
|
||
|
TableNames []string
|
||
|
DBDriver dbdrivers.DBDriver
|
||
|
}
|
||
|
|
||
|
// cmdData is the package-wide state populated by sqlBoilerPreRun; it stays
// nil until that hook has executed.
var (
	cmdData *CmdData
)
|
||
|
|
||
|
func init() {
|
||
|
SQLBoiler.PersistentFlags().StringP("driver", "d", "", "The name of the driver in your config.toml")
|
||
|
SQLBoiler.PersistentFlags().StringP("table", "t", "", "A comma seperated list of table names")
|
||
|
SQLBoiler.PersistentPreRun = sqlBoilerPreRun
|
||
|
}
|
||
|
|
||
|
var SQLBoiler = &cobra.Command{
|
||
|
Use: "sqlboiler",
|
||
|
Short: "SQL Boiler generates boilerplate structs and statements",
|
||
|
Long: "SQL Boiler generates boilerplate structs and statements.\n" +
|
||
|
`Complete documentation is available at http://github.com/pobri19/sqlboiler`,
|
||
|
}
|
||
|
|
||
|
func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
|
||
|
var err error
|
||
|
cmdData = &CmdData{}
|
||
|
|
||
|
// Retrieve driver flag
|
||
|
driverName := SQLBoiler.PersistentFlags().Lookup("driver").Value.String()
|
||
|
if driverName == "" {
|
||
|
errorQuit(errors.New("Must supply a driver flag."))
|
||
|
}
|
||
|
|
||
|
// Create a driver based off driver flag
|
||
|
switch driverName {
|
||
|
case "postgres":
|
||
|
cmdData.DBDriver = dbdrivers.NewPostgresDriver(
|
||
|
cfg.Postgres.User,
|
||
|
cfg.Postgres.Pass,
|
||
|
cfg.Postgres.DBName,
|
||
|
cfg.Postgres.Host,
|
||
|
cfg.Postgres.Port,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
// Connect to the driver database
|
||
|
if err = cmdData.DBDriver.Open(); err != nil {
|
||
|
errorQuit(err)
|
||
|
}
|
||
|
|
||
|
// Retrieve the list of tables
|
||
|
tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String()
|
||
|
|
||
|
if len(tn) != 0 {
|
||
|
cmdData.TableNames = strings.Split(tn, ",")
|
||
|
for i, name := range cmdData.TableNames {
|
||
|
cmdData.TableNames[i] = strings.TrimSpace(name)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// If no table names are provided attempt to process all tables in database
|
||
|
if len(cmdData.TableNames) == 0 {
|
||
|
// get all table names
|
||
|
cmdData.TableNames, err = cmdData.DBDriver.GetAllTableNames()
|
||
|
if err != nil {
|
||
|
errorQuit(err)
|
||
|
}
|
||
|
|
||
|
if len(cmdData.TableNames) == 0 {
|
||
|
errorQuit(errors.New("No tables found in database, migrate some tables first"))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// loop over table Names and build TablesInfo
|
||
|
for i := 0; i < len(cmdData.TableNames); i++ {
|
||
|
tInfo, err := cmdData.DBDriver.GetTableInfo(cmdData.TableNames[i])
|
||
|
if err != nil {
|
||
|
errorQuit(err)
|
||
|
}
|
||
|
|
||
|
cmdData.TablesInfo = append(cmdData.TablesInfo, tInfo)
|
||
|
}
|
||
|
}
|