sqlboiler/cmds/sqlboiler.go

111 lines
2.7 KiB
Go
Raw Normal View History

package cmds
import (
"errors"
"os"
"strings"
"github.com/pobri19/sqlboiler/dbdrivers"
"github.com/spf13/cobra"
)
// CmdData holds the state that sqlBoilerPreRun prepares before a command runs
// and that sqlBoilerPostRun releases afterwards.
type CmdData struct {
// TablesInfo has one entry per element of TableNames, in the same order;
// each entry is the value returned by DBDriver.GetTableInfo for that table.
TablesInfo [][]dbdrivers.DBTable
// TableNames is the list of tables to process: either the trimmed entries of
// the "table" flag or, when that flag is empty, every table reported by
// DBDriver.GetAllTableNames.
TableNames []string
// DBDriver is the database driver selected by the "driver" flag.
DBDriver dbdrivers.DBDriver
// OutFile is the file created from the "out" flag; it stays nil when the
// flag is not supplied.
OutFile *os.File
}
// cmdData is the package-level command state. It is (re)created by
// sqlBoilerPreRun and its resources are closed by sqlBoilerPostRun.
var cmdData *CmdData
// init wires up the root command: it installs the persistent pre/post run
// hooks and registers the persistent "driver", "table" and "out" flags.
func init() {
	SQLBoiler.PersistentPreRun = sqlBoilerPreRun
	SQLBoiler.PersistentPostRun = sqlBoilerPostRun

	// All three flags live on the same persistent flag set, so fetch it once.
	flags := SQLBoiler.PersistentFlags()
	flags.StringP("driver", "d", "", "The name of the driver in your config.toml")
	flags.StringP("table", "t", "", "A comma seperated list of table names")
	flags.StringP("out", "o", "", "The name of the output file")
}
// SQLBoiler is the root cobra command of the sqlboiler tool. Its persistent
// flags and pre/post run hooks are attached in init.
var SQLBoiler = &cobra.Command{
	Use:   "sqlboiler",
	Short: "SQL Boiler generates boilerplate structs and statements",
	Long: "SQL Boiler generates boilerplate structs and statements.\n" +
		"Complete documentation is available at http://github.com/pobri19/sqlboiler",
}
// sqlBoilerPostRun is installed as SQLBoiler's PersistentPostRun hook. It
// releases the resources that sqlBoilerPreRun acquired: the output file (if
// one was created) and then the database driver.
//
// OutFile is nil whenever the "out" flag was not supplied, so it is only closed
// when present. A failure to close the output file is reported through
// errorQuit instead of being discarded, because it can mean the generated
// output was not fully written. The report is deferred until after the driver
// has been closed so the driver is released either way.
func sqlBoilerPostRun(cmd *cobra.Command, args []string) {
	var closeErr error
	if cmdData.OutFile != nil {
		closeErr = cmdData.OutFile.Close()
	}

	if cmdData.DBDriver != nil {
		cmdData.DBDriver.Close()
	}

	if closeErr != nil {
		errorQuit(closeErr)
	}
}
// sqlBoilerPreRun is installed as SQLBoiler's PersistentPreRun hook. It builds
// the package-level cmdData from the persistent flags:
//
//   - "driver" selects the database driver, which is then opened;
//   - "table" is a comma separated list of table names; when it yields no
//     names, every table reported by the driver is used;
//   - "out" names the output file to create; when empty, OutFile stays nil.
//
// Every failure is handed to errorQuit (defined elsewhere in this package;
// presumably it reports the error and terminates the process -- confirm).
func sqlBoilerPreRun(cmd *cobra.Command, args []string) {
	var err error
	cmdData = &CmdData{}
	flags := SQLBoiler.PersistentFlags()

	// Retrieve driver flag
	driverName := flags.Lookup("driver").Value.String()
	if driverName == "" {
		errorQuit(errors.New("Must supply a driver flag."))
	}

	// Create a driver based off driver flag
	switch driverName {
	case "postgres":
		cmdData.DBDriver = dbdrivers.NewPostgresDriver(
			cfg.Postgres.User,
			cfg.Postgres.Pass,
			cfg.Postgres.DBName,
			cfg.Postgres.Host,
			cfg.Postgres.Port,
		)
	default:
		// An unrecognised name would otherwise leave DBDriver nil and the
		// Open call below would panic on the nil interface.
		errorQuit(errors.New("Unsupported driver: " + driverName))
		return
	}

	// Connect to the driver database
	if err = cmdData.DBDriver.Open(); err != nil {
		errorQuit(err)
	}

	// Retrieve the list of tables. Entries are trimmed and blank entries
	// (empty flag, trailing or doubled commas) are dropped so that an empty
	// table name is never sent to the driver.
	tn := flags.Lookup("table").Value.String()
	for _, name := range strings.Split(tn, ",") {
		if name = strings.TrimSpace(name); name != "" {
			cmdData.TableNames = append(cmdData.TableNames, name)
		}
	}

	// If no table names are provided attempt to process all tables in database
	if len(cmdData.TableNames) == 0 {
		cmdData.TableNames, err = cmdData.DBDriver.GetAllTableNames()
		if err != nil {
			errorQuit(err)
		}

		if len(cmdData.TableNames) == 0 {
			errorQuit(errors.New("No tables found in database, migrate some tables first"))
		}
	}

	// Build TablesInfo, one entry per table name and in the same order.
	cmdData.TablesInfo = make([][]dbdrivers.DBTable, 0, len(cmdData.TableNames))
	for _, name := range cmdData.TableNames {
		tInfo, err := cmdData.DBDriver.GetTableInfo(name)
		if err != nil {
			errorQuit(err)
		}

		cmdData.TablesInfo = append(cmdData.TablesInfo, tInfo)
	}

	// Open the out file handle; it is closed again in sqlBoilerPostRun.
	if outf := flags.Lookup("out").Value.String(); outf != "" {
		cmdData.OutFile, err = os.Create(outf)
		if err != nil {
			errorQuit(err)
		}
	}
}