Added drivername, custom driver imports
* Split main test based off driver name
This commit is contained in:
parent
109ff789ef
commit
c74ed4e75f
7 changed files with 145 additions and 62 deletions
|
@ -18,9 +18,9 @@ func boilRun(cmd *cobra.Command, args []string) {
|
|||
// Prepend "struct" command to templateNames slice so it sits at top of sort
|
||||
commandNames = append([]string{"struct"}, commandNames...)
|
||||
|
||||
// Create a testCommandNames with "main" prepended to the front for the test templates
|
||||
// Create a testCommandNames with "driverName_main" prepended to the front for the test templates
|
||||
// the main template initializes all of the testing assets
|
||||
testCommandNames := append([]string{"main"}, commandNames...)
|
||||
testCommandNames := append([]string{cmdData.DriverName + "_main"}, commandNames...)
|
||||
|
||||
for _, table := range cmdData.Tables {
|
||||
data := &tplData{
|
||||
|
@ -54,6 +54,8 @@ func boilRun(cmd *cobra.Command, args []string) {
|
|||
testImps.standard = sqlBoilerDefaultTestImports.standard
|
||||
testImps.thirdparty = sqlBoilerDefaultTestImports.thirdparty
|
||||
|
||||
testImps = combineImports(testImps, sqlBoilerConditionalDriverTestImports[cmdData.DriverName])
|
||||
|
||||
// Loop through and generate every command test template (excluding skipTemplates)
|
||||
for _, command := range testCommandNames {
|
||||
testImps = combineImports(testImps, sqlBoilerCustomTestImports[command])
|
||||
|
|
|
@ -30,6 +30,15 @@ var sqlBoilerDefaultImports = imports{
|
|||
var sqlBoilerDefaultTestImports = imports{
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
`"os"`,
|
||||
`"os/exec"`,
|
||||
`"fmt"`,
|
||||
`"io/ioutil"`,
|
||||
`"bytes"`,
|
||||
`"errors"`,
|
||||
},
|
||||
thirdparty: importList{
|
||||
`"github.com/BurntSushi/toml"`,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -57,6 +66,15 @@ var sqlBoilerConditionalTypeImports = map[string]imports{
|
|||
},
|
||||
}
|
||||
|
||||
// sqlBoilerConditionalDriverTestImports defines the test template imports
|
||||
// for the particular database interfaces
|
||||
var sqlBoilerConditionalDriverTestImports = map[string]imports{
|
||||
"postgres": imports{
|
||||
standard: importList{`"database/sql"`},
|
||||
thirdparty: importList{`_ "github.com/lib/pq"`},
|
||||
},
|
||||
}
|
||||
|
||||
var sqlBoilerCustomImports map[string]imports
|
||||
var sqlBoilerCustomTestImports map[string]imports
|
||||
|
||||
|
|
|
@ -16,13 +16,7 @@ type PostgresCfg struct {
|
|||
}
|
||||
|
||||
type Config struct {
|
||||
Postgres PostgresCfg `toml:"postgres"`
|
||||
// The TestPostgres object holds the configuration pointing to a test database.
|
||||
// This test database is used to test all of the commands that have an accompanying
|
||||
// command.testtpl file. These template files generate the go test functions.
|
||||
//
|
||||
// Note: These test templates will only be generated for the boil command,
|
||||
// if an OutFolder is provided, and if all test config variables are present.
|
||||
Postgres PostgresCfg `toml:"postgres"`
|
||||
TestPostgres *PostgresCfg `toml:"postgres_test"`
|
||||
}
|
||||
|
||||
|
|
|
@ -17,10 +17,11 @@ type CobraRunFunc func(cmd *cobra.Command, args []string)
|
|||
// the database driver chosen by the driver flag at runtime, and a pointer to the
|
||||
// output file, if one is specified with a flag.
|
||||
type CmdData struct {
|
||||
Tables []dbdrivers.Table
|
||||
PkgName string
|
||||
OutFolder string
|
||||
Interface dbdrivers.Interface
|
||||
Tables []dbdrivers.Table
|
||||
PkgName string
|
||||
OutFolder string
|
||||
Interface dbdrivers.Interface
|
||||
DriverName string
|
||||
}
|
||||
|
||||
// tplData is used to pass data to the template
|
||||
|
|
|
@ -124,6 +124,8 @@ func initInterface() {
|
|||
if cmdData.Interface == nil {
|
||||
errorQuit(errors.New("An invalid driver name was provided"))
|
||||
}
|
||||
|
||||
cmdData.DriverName = driverName
|
||||
}
|
||||
|
||||
// initTables will create a string slice out of the passed in table flag value
|
||||
|
|
|
@ -1,49 +0,0 @@
|
|||
type PostgresCfg struct {
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Postgres PostgresCfg `toml:"postgres"`
|
||||
TestPostgres *PostgresCfg `toml:"postgres_test"`
|
||||
}
|
||||
|
||||
var cfg *Config
|
||||
|
||||
func LoadConfigFile(filename string) {
|
||||
_, err := toml.DecodeFile(filename, &cfg)
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
fmt.Fatalf("Failed to find the toml configuration file %s: %s", filename, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Fatalf("Failed to decode toml configuration file:", err)
|
||||
}
|
||||
|
||||
if cfg.TestPostgres != nil {
|
||||
if cfg.TestPostgres.User == "" || cfg.TestPostgres.Pass == "" ||
|
||||
cfg.TestPostgres.Host == "" || cfg.TestPostgres.Port == 0 ||
|
||||
cfg.TestPostgres.DBName == "" || cfg.Postgres.DBName == cfg.TestPostgres.DBName {
|
||||
cfg.TestPostgres = nil
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.TestPostgres == nil {
|
||||
fmt.Fatalf("Failed to load config.toml postgres_test config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
setup()
|
||||
code := m.Run()
|
||||
// shutdown
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func setup() {
|
||||
LoadConfigFile("../config.toml")
|
||||
}
|
115
cmds/templates_test/postgres_main.tpl
Normal file
115
cmds/templates_test/postgres_main.tpl
Normal file
|
@ -0,0 +1,115 @@
|
|||
type PostgresCfg struct {
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Postgres PostgresCfg `toml:"postgres"`
|
||||
TestPostgres *PostgresCfg `toml:"postgres_test"`
|
||||
}
|
||||
|
||||
var cfg *Config
|
||||
|
||||
var dbConn *sql.DB
|
||||
|
||||
func DBConnect(user, pass, dbname, host string, port int) error {
|
||||
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
|
||||
user, pass, dbname, host, port)
|
||||
|
||||
var err error
|
||||
dbConn, err = sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoadConfigFile(filename string) error {
|
||||
_, err := toml.DecodeFile(filename, &cfg)
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to decode toml configuration file:", err)
|
||||
}
|
||||
|
||||
if cfg.TestPostgres != nil {
|
||||
if cfg.TestPostgres.User == "" || cfg.TestPostgres.Pass == "" ||
|
||||
cfg.TestPostgres.Host == "" || cfg.TestPostgres.Port == 0 ||
|
||||
cfg.TestPostgres.DBName == "" || cfg.Postgres.DBName == cfg.TestPostgres.DBName {
|
||||
cfg.TestPostgres = nil
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.TestPostgres == nil {
|
||||
return errors.New("Failed to load config.toml postgres_test config")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
err := setup()
|
||||
if err != nil {
|
||||
os.Exit(-1)
|
||||
}
|
||||
code := m.Run()
|
||||
// shutdown
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func setup() error {
|
||||
err := LoadConfigFile("../config.toml")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to load config file: %s", err)
|
||||
}
|
||||
|
||||
fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to create sqlboiler schema tmp file: %s", err)
|
||||
}
|
||||
defer os.Remove(fhSchema.Name())
|
||||
|
||||
passDir, err := ioutil.TempDir(os.TempDir(), "sqlboiler")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to create sqlboiler tmp dir for postgres pw file: %s", err)
|
||||
}
|
||||
defer os.RemoveAll(passDir)
|
||||
|
||||
// Write the postgres user password to a tmp file for pg_dump
|
||||
pwBytes := []byte(fmt.Sprintf("*:*:*:*:%s", cfg.Postgres.Pass))
|
||||
passFilePath := passDir + "/pwfile"
|
||||
|
||||
err = ioutil.WriteFile(passFilePath, pwBytes, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to create pwfile in passDir: %s", err)
|
||||
}
|
||||
|
||||
params := []string{
|
||||
fmt.Sprintf(`--host=%s`, cfg.Postgres.Host),
|
||||
fmt.Sprintf(`--port=%d`, cfg.Postgres.Port),
|
||||
fmt.Sprintf(`--username=%s`, cfg.Postgres.User),
|
||||
"--schema-only",
|
||||
cfg.Postgres.DBName,
|
||||
}
|
||||
|
||||
errBuf := bytes.Buffer{}
|
||||
cmd := exec.Command("pg_dump", params...)
|
||||
cmd.Stderr = &errBuf
|
||||
cmd.Stdout = fhSchema
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, passFilePath))
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
fmt.Printf("pg_dump exec failed: %s\n\n%s\n", err, errBuf.String())
|
||||
}
|
||||
|
||||
err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, cfg.Postgres.DBName, cfg.Postgres.Host, cfg.Postgres.Port)
|
||||
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, cfg.TestPostgres.DBName))
|
||||
return nil
|
||||
}
|
Loading…
Reference in a new issue