Refactor all the bits.

- Make TestMain be driver-based
- Move config to TestMain file
- Make config a little more sane in pgmain
This commit is contained in:
Aaron L 2016-09-11 09:17:08 -07:00
parent 9bcaf51493
commit 9d29d2b946
5 changed files with 227 additions and 159 deletions

View file

@ -239,6 +239,29 @@ var defaultTestMainImports = map[string]imports{
`_ "github.com/lib/pq"`, `_ "github.com/lib/pq"`,
}, },
}, },
"mysql": {
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"flag"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"database/sql"`,
`"path/filepath"`,
`"time"`,
`"math/rand"`,
},
thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
`_ "github.com/go-mysql-driver/mysql"`,
},
},
} }
// importsBasedOnType imports are only included in the template output if the // importsBasedOnType imports are only included in the template output if the

View file

@ -1,2 +1,17 @@
func TestMain(m *testing.M) { type mysqlTester struct {
dbConn *sql.DB
}
dbMain = mysqlTester{}
func (m mysqlTester) setup() error {
return nil
}
func (m mysqlTester) teardown() error {
return nil
}
func (m mysqlTester) conn() *sql.DB {
return m.dbConn
} }

View file

@ -1,50 +1,21 @@
type PostgresCfg struct { type pgTester struct {
User string `toml:"user"` dbConn *sql.DB
Pass string `toml:"pass"`
Host string `toml:"host"` dbName string
Port int `toml:"port"` host string
DBName string `toml:"dbname"` user string
SSLMode string `toml:"sslmode"` pass string
sslmode string
port int
testDBName string
} }
type Config struct { dbMain = pgTester{}
Postgres PostgresCfg `toml:"postgres"`
}
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
var err error
if err = setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-2)
}
var code int
if err = disableTriggers(); err != nil {
fmt.Println("Unable to disable triggers:", err)
} else {
boil.SetDB(dbConn)
code = m.Run()
}
if err = teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-3)
}
os.Exit(code)
}
// disableTriggers is used to disable foreign key constraints for every table. // disableTriggers is used to disable foreign key constraints for every table.
// If this is not used we cannot test inserts due to foreign key constraint errors. // If this is not used we cannot test inserts due to foreign key constraint errors.
func disableTriggers() error { func (p pgTester) disableTriggers() error {
var stmts []string var stmts []string
{{range .Tables}} {{range .Tables}}
@ -57,7 +28,7 @@ func disableTriggers() error {
var err error var err error
for _, s := range stmts { for _, s := range stmts {
_, err = dbConn.Exec(s) _, err = p.dbConn.Exec(s)
if err != nil { if err != nil {
return err return err
} }
@ -67,33 +38,37 @@ func disableTriggers() error {
} }
// teardown executes cleanup tasks when the tests finish running // teardown executes cleanup tasks when the tests finish running
func teardown() error { func (p pgTester) teardown() error {
err := dropTestDB() err := dropTestDB()
return err return err
} }
func (p pgTester) conn() *sql.DB {
return p.dbConn
}
// dropTestDB switches its connection to the template1 database temporarily // dropTestDB switches its connection to the template1 database temporarily
// so that it can drop the test database without causing "in use" conflicts. // so that it can drop the test database without causing "in use" conflicts.
// The template1 database should be present on all default postgres installations. // The template1 database should be present on all default postgres installations.
func dropTestDB() error { func (p pgTester) dropTestDB() error {
var err error var err error
if dbConn != nil { if p.dbConn != nil {
if err = dbConn.Close(); err != nil { if err = p.dbConn.Close(); err != nil {
return err return err
} }
} }
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil { if err != nil {
return err return err
} }
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName)) _, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
if err != nil { if err != nil {
return err return err
} }
return dbConn.Close() return p.dbConn.Close()
} }
// DBConnect connects to a database and returns the handle. // DBConnect connects to a database and returns the handle.
@ -106,43 +81,17 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.
// setup dumps the database schema and imports it into a temporary randomly // setup dumps the database schema and imports it into a temporary randomly
// generated test database so that tests can be run against it using the // generated test database so that tests can be run against it using the
// generated sqlboiler ORM package. // generated sqlboiler ORM package.
func setup() error { func (p pgTester) setup() error {
var err error var err error
// Initialize Viper and load the config file p.dbName = viper.GetString("postgres.dbname")
err = InitViper() p.host = viper.GetString("postgres.host")
if err != nil { p.user = viper.GetString("postgres.user")
return errors.Wrap(err, "Unable to load config file") p.pass = viper.GetString("postgres.pass")
} p.port = viper.GetInt("postgres.port")
p.sslmode = viper.GetString("postgres.dbname")
viper.SetDefault("postgres.sslmode", "require") // Create a randomized db name.
viper.SetDefault("postgres.port", "5432") p.testDBName = getDBNameHash(p.dbname)
// Create a randomized test configuration object.
testCfg.Postgres.Host = viper.GetString("postgres.host")
testCfg.Postgres.Port = viper.GetInt("postgres.port")
testCfg.Postgres.User = viper.GetString("postgres.user")
testCfg.Postgres.Pass = viper.GetString("postgres.pass")
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
// Set the default SSLMode value
if testCfg.Postgres.SSLMode == "" {
viper.Set("postgres.sslmode", "require")
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
}
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
).Check()
if err != nil {
return errors.Wrap(err, "Unable to load testCfg")
}
err = dropTestDB() err = dropTestDB()
if err != nil { if err != nil {
@ -163,15 +112,10 @@ func setup() error {
defer os.RemoveAll(passDir) defer os.RemoveAll(passDir)
// Write the postgres user password to a tmp file for pg_dump // Write the postgres user password to a tmp file for pg_dump
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user))
viper.GetString("postgres.host"),
viper.GetInt("postgres.port"),
viper.GetString("postgres.dbname"),
viper.GetString("postgres.user"),
))
if pw := viper.GetString("postgres.pass"); len(pw) > 0 { if len(p.pass) > 0 {
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw)) pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass))
} }
passFilePath := filepath.Join(passDir, "pwfile") passFilePath := filepath.Join(passDir, "pwfile")
@ -183,11 +127,11 @@ func setup() error {
// The params for the pg_dump command to dump the database schema // The params for the pg_dump command to dump the database schema
params := []string{ params := []string{
fmt.Sprintf(`--host=%s`, viper.GetString("postgres.host")), fmt.Sprintf(`--host=%s`, p.host),
fmt.Sprintf(`--port=%d`, viper.GetInt("postgres.port")), fmt.Sprintf(`--port=%d`, p.port),
fmt.Sprintf(`--username=%s`, viper.GetString("postgres.user")), fmt.Sprintf(`--username=%s`, p.user),
"--schema-only", "--schema-only",
viper.GetString("postgres.dbname"), p.dbName,
} }
// Dump the database schema into the sqlboilerschema tmp file // Dump the database schema into the sqlboilerschema tmp file
@ -202,45 +146,33 @@ func setup() error {
return err return err
} }
dbConn, err = DBConnect( p.dbConn, err = DBConnect(p.user, p.pass, p.dbName, p.host, p.port, p.sslmode)
viper.GetString("postgres.user"),
viper.GetString("postgres.pass"),
viper.GetString("postgres.dbname"),
viper.GetString("postgres.host"),
viper.GetInt("postgres.port"),
viper.GetString("postgres.sslmode"),
)
if err != nil { if err != nil {
return err return err
} }
// Create the randomly generated database // Create the randomly generated database
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName)) _, err = p.dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, p.testDBName))
if err != nil { if err != nil {
return err return err
} }
// Close the old connection so we can reconnect to the test database // Close the old connection so we can reconnect to the test database
if err = dbConn.Close(); err != nil { if err = p.dbConn.Close(); err != nil {
return err return err
} }
// Connect to the generated test db // Connect to the generated test db
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) p.dbConn, err = DBConnect(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode)
if err != nil { if err != nil {
return err return err
} }
// Write the test config credentials to a tmp file for pg_dump // Write the test config credentials to a tmp file for pg_dump
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user))
testCfg.Postgres.Host,
testCfg.Postgres.Port,
testCfg.Postgres.DBName,
testCfg.Postgres.User,
))
if len(testCfg.Postgres.Pass) > 0 { if len(p.pass) > 0 {
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass)) testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, p.pass))
} }
testPassFilePath := passDir + "/testpwfile" testPassFilePath := passDir + "/testpwfile"
@ -252,10 +184,10 @@ func setup() error {
// The params for the psql schema import command // The params for the psql schema import command
params = []string{ params = []string{
fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName), fmt.Sprintf(`--dbname=%s`, p.testDBName),
fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host), fmt.Sprintf(`--host=%s`, p.host),
fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port), fmt.Sprintf(`--port=%d`, p.port),
fmt.Sprintf(`--username=%s`, testCfg.Postgres.User), fmt.Sprintf(`--username=%s`, p.user),
fmt.Sprintf(`--file=%s`, fhSchema.Name()), fmt.Sprintf(`--file=%s`, fhSchema.Name()),
} }
@ -271,5 +203,5 @@ func setup() error {
fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String()) fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String())
} }
return nil return p.disableTriggers()
} }

View file

@ -0,0 +1,135 @@
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
var (
dbMain tester
)
type tester interface {
setup() error
conn() *sql.DB
teardown() error
}
func TestMain(m *testing.M) {
if dbMain == nil {
fmt.Println("no dbMain tester interface was ready")
os.Exit(-1)
}
rand.Seed(time.Now().UnixNano())
// Load configuration
err = initViper()
if err != nil {
return errors.Wrap(err, "Unable to load config file")
}
setConfigDefaults()
if err := validateConfig({{.DriverName}}); err != nil {
fmt.Println("failed to validate config", err)
os.Exit(-2)
}
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
var err error
if err = dbMain.setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-3)
}
var code int
boil.SetDB(dbMain.conn())
code = m.Run()
if err = dbMain.teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-4)
}
os.Exit(code)
}
func initViper() error {
var err error
viper.SetConfigName("sqlboiler")
configHome := os.Getenv("XDG_CONFIG_HOME")
homePath := os.Getenv("HOME")
wd, err := os.Getwd()
if err != nil {
wd = "../"
} else {
wd = wd + "/.."
}
configPaths := []string{wd}
if len(configHome) > 0 {
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
} else {
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
}
for _, p := range configPaths {
viper.AddConfigPath(p)
}
// Ignore errors here, fall back to defaults and validation to provide errs
_ = viper.ReadInConfig()
viper.AutomaticEnv()
return nil
}
// setDefaults is only necessary because of bugs in viper, noted in main
func setDefaults() {
if viper.GetString("postgres.sslmode") == "" {
viper.Set("postgres.sslmode", "require")
}
if viper.GetInt("postgres.port") == 0 {
viper.Set("postgres.port", 5432)
}
if viper.GetString("mysql.sslmode") == "" {
viper.Set("mysql.sslmode", "true")
}
if viper.GetInt("mysql.port") == 0 {
viper.Set("mysql.port", 3306)
}
}
func validateConfig(driverName string) error {
if viper.IsSet("postgres.dbname") {
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
).Check()
if err != nil {
return err
}
} else if driverName == "postgres" {
return errors.New("postgres driver requires a postgres section in your config file")
}
if viper.IsSet("mysql.dbname") {
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
).Check()
if err != nil {
return err
}
} else if driverName == "mysql" {
return errors.New("mysql driver requires a mysql section in your config file")
}
}

View file

@ -1,37 +0,0 @@
var (
testCfg *Config
dbConn *sql.DB
)
func InitViper() error {
var err error
testCfg = &Config{}
viper.SetConfigName("sqlboiler")
configHome := os.Getenv("XDG_CONFIG_HOME")
homePath := os.Getenv("HOME")
wd, err := os.Getwd()
if err != nil {
wd = "../"
} else {
wd = wd + "/.."
}
configPaths := []string{wd}
if len(configHome) > 0 {
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
} else {
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
}
for _, p := range configPaths {
viper.AddConfigPath(p)
}
// Ignore errors here, fall back to defaults and validation to provide errs
_ = viper.ReadInConfig()
viper.AutomaticEnv()
return nil
}