Refactor all the bits.

- Make TestMain be driver-based
- Move config to TestMain file
- Make config a little more sane in pgmain
This commit is contained in:
Aaron L 2016-09-11 09:17:08 -07:00
parent 9bcaf51493
commit 9d29d2b946
5 changed files with 227 additions and 159 deletions

View file

@ -239,6 +239,29 @@ var defaultTestMainImports = map[string]imports{
`_ "github.com/lib/pq"`,
},
},
"mysql": {
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"flag"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"database/sql"`,
`"path/filepath"`,
`"time"`,
`"math/rand"`,
},
thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
`_ "github.com/go-mysql-driver/mysql"`,
},
},
}
// importsBasedOnType imports are only included in the template output if the

View file

@ -1,2 +1,17 @@
func TestMain(m *testing.M) {
type mysqlTester struct {
dbConn *sql.DB
}
dbMain = mysqlTester{}
func (m mysqlTester) setup() error {
return nil
}
func (m mysqlTester) teardown() error {
return nil
}
func (m mysqlTester) conn() *sql.DB {
return m.dbConn
}

View file

@ -1,50 +1,21 @@
type PostgresCfg struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
SSLMode string `toml:"sslmode"`
type pgTester struct {
dbConn *sql.DB
dbName string
host string
user string
pass string
sslmode string
port int
testDBName string
}
type Config struct {
Postgres PostgresCfg `toml:"postgres"`
}
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
var err error
if err = setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-2)
}
var code int
if err = disableTriggers(); err != nil {
fmt.Println("Unable to disable triggers:", err)
} else {
boil.SetDB(dbConn)
code = m.Run()
}
if err = teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-3)
}
os.Exit(code)
}
dbMain = pgTester{}
// disableTriggers is used to disable foreign key constraints for every table.
// If this is not used we cannot test inserts due to foreign key constraint errors.
func disableTriggers() error {
func (p pgTester) disableTriggers() error {
var stmts []string
{{range .Tables}}
@ -57,7 +28,7 @@ func disableTriggers() error {
var err error
for _, s := range stmts {
_, err = dbConn.Exec(s)
_, err = p.dbConn.Exec(s)
if err != nil {
return err
}
@ -67,33 +38,37 @@ func disableTriggers() error {
}
// teardown executes cleanup tasks when the tests finish running
func teardown() error {
func (p pgTester) teardown() error {
err := dropTestDB()
return err
}
func (p pgTester) conn() *sql.DB {
return p.dbConn
}
// dropTestDB switches its connection to the template1 database temporarily
// so that it can drop the test database without causing "in use" conflicts.
// The template1 database should be present on all default postgres installations.
func dropTestDB() error {
func (p pgTester) dropTestDB() error {
var err error
if dbConn != nil {
if err = dbConn.Close(); err != nil {
if p.dbConn != nil {
if err = p.dbConn.Close(); err != nil {
return err
}
}
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil {
return err
}
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
if err != nil {
return err
}
return dbConn.Close()
return p.dbConn.Close()
}
// DBConnect connects to a database and returns the handle.
@ -106,43 +81,17 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.
// setup dumps the database schema and imports it into a temporary randomly
// generated test database so that tests can be run against it using the
// generated sqlboiler ORM package.
func setup() error {
func (p pgTester) setup() error {
var err error
// Initialize Viper and load the config file
err = InitViper()
if err != nil {
return errors.Wrap(err, "Unable to load config file")
}
viper.SetDefault("postgres.sslmode", "require")
viper.SetDefault("postgres.port", "5432")
// Create a randomized test configuration object.
testCfg.Postgres.Host = viper.GetString("postgres.host")
testCfg.Postgres.Port = viper.GetInt("postgres.port")
testCfg.Postgres.User = viper.GetString("postgres.user")
testCfg.Postgres.Pass = viper.GetString("postgres.pass")
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
// Set the default SSLMode value
if testCfg.Postgres.SSLMode == "" {
viper.Set("postgres.sslmode", "require")
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
}
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
).Check()
if err != nil {
return errors.Wrap(err, "Unable to load testCfg")
}
p.dbName = viper.GetString("postgres.dbname")
p.host = viper.GetString("postgres.host")
p.user = viper.GetString("postgres.user")
p.pass = viper.GetString("postgres.pass")
p.port = viper.GetInt("postgres.port")
p.sslmode = viper.GetString("postgres.dbname")
// Create a randomized db name.
p.testDBName = getDBNameHash(p.dbname)
err = dropTestDB()
if err != nil {
@ -163,15 +112,10 @@ func setup() error {
defer os.RemoveAll(passDir)
// Write the postgres user password to a tmp file for pg_dump
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
viper.GetString("postgres.host"),
viper.GetInt("postgres.port"),
viper.GetString("postgres.dbname"),
viper.GetString("postgres.user"),
))
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user))
if pw := viper.GetString("postgres.pass"); len(pw) > 0 {
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw))
if len(p.pass) > 0 {
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass))
}
passFilePath := filepath.Join(passDir, "pwfile")
@ -183,11 +127,11 @@ func setup() error {
// The params for the pg_dump command to dump the database schema
params := []string{
fmt.Sprintf(`--host=%s`, viper.GetString("postgres.host")),
fmt.Sprintf(`--port=%d`, viper.GetInt("postgres.port")),
fmt.Sprintf(`--username=%s`, viper.GetString("postgres.user")),
fmt.Sprintf(`--host=%s`, p.host),
fmt.Sprintf(`--port=%d`, p.port),
fmt.Sprintf(`--username=%s`, p.user),
"--schema-only",
viper.GetString("postgres.dbname"),
p.dbName,
}
// Dump the database schema into the sqlboilerschema tmp file
@ -202,45 +146,33 @@ func setup() error {
return err
}
dbConn, err = DBConnect(
viper.GetString("postgres.user"),
viper.GetString("postgres.pass"),
viper.GetString("postgres.dbname"),
viper.GetString("postgres.host"),
viper.GetInt("postgres.port"),
viper.GetString("postgres.sslmode"),
)
p.dbConn, err = DBConnect(p.user, p.pass, p.dbName, p.host, p.port, p.sslmode)
if err != nil {
return err
}
// Create the randomly generated database
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName))
_, err = p.dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, p.testDBName))
if err != nil {
return err
}
// Close the old connection so we can reconnect to the test database
if err = dbConn.Close(); err != nil {
if err = p.dbConn.Close(); err != nil {
return err
}
// Connect to the generated test db
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
p.dbConn, err = DBConnect(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode)
if err != nil {
return err
}
// Write the test config credentials to a tmp file for pg_dump
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
testCfg.Postgres.Host,
testCfg.Postgres.Port,
testCfg.Postgres.DBName,
testCfg.Postgres.User,
))
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user))
if len(testCfg.Postgres.Pass) > 0 {
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass))
if len(p.pass) > 0 {
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, p.pass))
}
testPassFilePath := passDir + "/testpwfile"
@ -252,10 +184,10 @@ func setup() error {
// The params for the psql schema import command
params = []string{
fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName),
fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host),
fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port),
fmt.Sprintf(`--username=%s`, testCfg.Postgres.User),
fmt.Sprintf(`--dbname=%s`, p.testDBName),
fmt.Sprintf(`--host=%s`, p.host),
fmt.Sprintf(`--port=%d`, p.port),
fmt.Sprintf(`--username=%s`, p.user),
fmt.Sprintf(`--file=%s`, fhSchema.Name()),
}
@ -271,5 +203,5 @@ func setup() error {
fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String())
}
return nil
return p.disableTriggers()
}

View file

@ -0,0 +1,135 @@
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
var (
dbMain tester
)
type tester interface {
setup() error
conn() *sql.DB
teardown() error
}
func TestMain(m *testing.M) {
if dbMain == nil {
fmt.Println("no dbMain tester interface was ready")
os.Exit(-1)
}
rand.Seed(time.Now().UnixNano())
// Load configuration
err = initViper()
if err != nil {
return errors.Wrap(err, "Unable to load config file")
}
setConfigDefaults()
if err := validateConfig({{.DriverName}}); err != nil {
fmt.Println("failed to validate config", err)
os.Exit(-2)
}
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
var err error
if err = dbMain.setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-3)
}
var code int
boil.SetDB(dbMain.conn())
code = m.Run()
if err = dbMain.teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-4)
}
os.Exit(code)
}
func initViper() error {
var err error
viper.SetConfigName("sqlboiler")
configHome := os.Getenv("XDG_CONFIG_HOME")
homePath := os.Getenv("HOME")
wd, err := os.Getwd()
if err != nil {
wd = "../"
} else {
wd = wd + "/.."
}
configPaths := []string{wd}
if len(configHome) > 0 {
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
} else {
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
}
for _, p := range configPaths {
viper.AddConfigPath(p)
}
// Ignore errors here, fall back to defaults and validation to provide errs
_ = viper.ReadInConfig()
viper.AutomaticEnv()
return nil
}
// setDefaults is only necessary because of bugs in viper, noted in main
func setDefaults() {
if viper.GetString("postgres.sslmode") == "" {
viper.Set("postgres.sslmode", "require")
}
if viper.GetInt("postgres.port") == 0 {
viper.Set("postgres.port", 5432)
}
if viper.GetString("mysql.sslmode") == "" {
viper.Set("mysql.sslmode", "true")
}
if viper.GetInt("mysql.port") == 0 {
viper.Set("mysql.port", 3306)
}
}
func validateConfig(driverName string) error {
if viper.IsSet("postgres.dbname") {
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
).Check()
if err != nil {
return err
}
} else if driverName == "postgres" {
return errors.New("postgres driver requires a postgres section in your config file")
}
if viper.IsSet("mysql.dbname") {
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
).Check()
if err != nil {
return err
}
} else if driverName == "mysql" {
return errors.New("mysql driver requires a mysql section in your config file")
}
}

View file

@ -1,37 +0,0 @@
var (
testCfg *Config
dbConn *sql.DB
)
func InitViper() error {
var err error
testCfg = &Config{}
viper.SetConfigName("sqlboiler")
configHome := os.Getenv("XDG_CONFIG_HOME")
homePath := os.Getenv("HOME")
wd, err := os.Getwd()
if err != nil {
wd = "../"
} else {
wd = wd + "/.."
}
configPaths := []string{wd}
if len(configHome) > 0 {
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
} else {
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
}
for _, p := range configPaths {
viper.AddConfigPath(p)
}
// Ignore errors here, fall back to defaults and validation to provide errs
_ = viper.ReadInConfig()
viper.AutomaticEnv()
return nil
}