Refactor all the bits.
- Make TestMain be driver-based - Move config to TestMain file - Make config a little more sane in pgmain
This commit is contained in:
parent
9bcaf51493
commit
9d29d2b946
5 changed files with 227 additions and 159 deletions
23
imports.go
23
imports.go
|
@ -239,6 +239,29 @@ var defaultTestMainImports = map[string]imports{
|
|||
`_ "github.com/lib/pq"`,
|
||||
},
|
||||
},
|
||||
"mysql": {
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
`"os"`,
|
||||
`"os/exec"`,
|
||||
`"flag"`,
|
||||
`"fmt"`,
|
||||
`"io/ioutil"`,
|
||||
`"bytes"`,
|
||||
`"database/sql"`,
|
||||
`"path/filepath"`,
|
||||
`"time"`,
|
||||
`"math/rand"`,
|
||||
},
|
||||
thirdParty: importList{
|
||||
`"github.com/kat-co/vala"`,
|
||||
`"github.com/pkg/errors"`,
|
||||
`"github.com/spf13/viper"`,
|
||||
`"github.com/vattle/sqlboiler/boil"`,
|
||||
`"github.com/vattle/sqlboiler/bdb/drivers"`,
|
||||
`_ "github.com/go-mysql-driver/mysql"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// importsBasedOnType imports are only included in the template output if the
|
||||
|
|
|
@ -1,2 +1,17 @@
|
|||
func TestMain(m *testing.M) {
|
||||
type mysqlTester struct {
|
||||
dbConn *sql.DB
|
||||
}
|
||||
|
||||
dbMain = mysqlTester{}
|
||||
|
||||
func (m mysqlTester) setup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mysqlTester) teardown() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mysqlTester) conn() *sql.DB {
|
||||
return m.dbConn
|
||||
}
|
||||
|
|
|
@ -1,50 +1,21 @@
|
|||
type PostgresCfg struct {
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
SSLMode string `toml:"sslmode"`
|
||||
type pgTester struct {
|
||||
dbConn *sql.DB
|
||||
|
||||
dbName string
|
||||
host string
|
||||
user string
|
||||
pass string
|
||||
sslmode string
|
||||
port int
|
||||
|
||||
testDBName string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Postgres PostgresCfg `toml:"postgres"`
|
||||
}
|
||||
|
||||
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
// Set DebugMode so we can see generated sql statements
|
||||
flag.Parse()
|
||||
boil.DebugMode = *flagDebugMode
|
||||
|
||||
var err error
|
||||
if err = setup(); err != nil {
|
||||
fmt.Println("Unable to execute setup:", err)
|
||||
os.Exit(-2)
|
||||
}
|
||||
|
||||
var code int
|
||||
if err = disableTriggers(); err != nil {
|
||||
fmt.Println("Unable to disable triggers:", err)
|
||||
} else {
|
||||
boil.SetDB(dbConn)
|
||||
code = m.Run()
|
||||
}
|
||||
|
||||
if err = teardown(); err != nil {
|
||||
fmt.Println("Unable to execute teardown:", err)
|
||||
os.Exit(-3)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
dbMain = pgTester{}
|
||||
|
||||
// disableTriggers is used to disable foreign key constraints for every table.
|
||||
// If this is not used we cannot test inserts due to foreign key constraint errors.
|
||||
func disableTriggers() error {
|
||||
func (p pgTester) disableTriggers() error {
|
||||
var stmts []string
|
||||
|
||||
{{range .Tables}}
|
||||
|
@ -57,7 +28,7 @@ func disableTriggers() error {
|
|||
|
||||
var err error
|
||||
for _, s := range stmts {
|
||||
_, err = dbConn.Exec(s)
|
||||
_, err = p.dbConn.Exec(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -67,33 +38,37 @@ func disableTriggers() error {
|
|||
}
|
||||
|
||||
// teardown executes cleanup tasks when the tests finish running
|
||||
func teardown() error {
|
||||
func (p pgTester) teardown() error {
|
||||
err := dropTestDB()
|
||||
return err
|
||||
}
|
||||
|
||||
func (p pgTester) conn() *sql.DB {
|
||||
return p.dbConn
|
||||
}
|
||||
|
||||
// dropTestDB switches its connection to the template1 database temporarily
|
||||
// so that it can drop the test database without causing "in use" conflicts.
|
||||
// The template1 database should be present on all default postgres installations.
|
||||
func dropTestDB() error {
|
||||
func (p pgTester) dropTestDB() error {
|
||||
var err error
|
||||
if dbConn != nil {
|
||||
if err = dbConn.Close(); err != nil {
|
||||
if p.dbConn != nil {
|
||||
if err = p.dbConn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
||||
p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
|
||||
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return dbConn.Close()
|
||||
return p.dbConn.Close()
|
||||
}
|
||||
|
||||
// DBConnect connects to a database and returns the handle.
|
||||
|
@ -106,43 +81,17 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.
|
|||
// setup dumps the database schema and imports it into a temporary randomly
|
||||
// generated test database so that tests can be run against it using the
|
||||
// generated sqlboiler ORM package.
|
||||
func setup() error {
|
||||
func (p pgTester) setup() error {
|
||||
var err error
|
||||
|
||||
// Initialize Viper and load the config file
|
||||
err = InitViper()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to load config file")
|
||||
}
|
||||
|
||||
viper.SetDefault("postgres.sslmode", "require")
|
||||
viper.SetDefault("postgres.port", "5432")
|
||||
|
||||
// Create a randomized test configuration object.
|
||||
testCfg.Postgres.Host = viper.GetString("postgres.host")
|
||||
testCfg.Postgres.Port = viper.GetInt("postgres.port")
|
||||
testCfg.Postgres.User = viper.GetString("postgres.user")
|
||||
testCfg.Postgres.Pass = viper.GetString("postgres.pass")
|
||||
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
|
||||
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
|
||||
|
||||
// Set the default SSLMode value
|
||||
if testCfg.Postgres.SSLMode == "" {
|
||||
viper.Set("postgres.sslmode", "require")
|
||||
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
|
||||
}
|
||||
|
||||
err = vala.BeginValidation().Validate(
|
||||
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
|
||||
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
|
||||
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
|
||||
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
|
||||
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
|
||||
).Check()
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to load testCfg")
|
||||
}
|
||||
p.dbName = viper.GetString("postgres.dbname")
|
||||
p.host = viper.GetString("postgres.host")
|
||||
p.user = viper.GetString("postgres.user")
|
||||
p.pass = viper.GetString("postgres.pass")
|
||||
p.port = viper.GetInt("postgres.port")
|
||||
p.sslmode = viper.GetString("postgres.dbname")
|
||||
// Create a randomized db name.
|
||||
p.testDBName = getDBNameHash(p.dbname)
|
||||
|
||||
err = dropTestDB()
|
||||
if err != nil {
|
||||
|
@ -163,15 +112,10 @@ func setup() error {
|
|||
defer os.RemoveAll(passDir)
|
||||
|
||||
// Write the postgres user password to a tmp file for pg_dump
|
||||
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
|
||||
viper.GetString("postgres.host"),
|
||||
viper.GetInt("postgres.port"),
|
||||
viper.GetString("postgres.dbname"),
|
||||
viper.GetString("postgres.user"),
|
||||
))
|
||||
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user))
|
||||
|
||||
if pw := viper.GetString("postgres.pass"); len(pw) > 0 {
|
||||
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw))
|
||||
if len(p.pass) > 0 {
|
||||
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass))
|
||||
}
|
||||
|
||||
passFilePath := filepath.Join(passDir, "pwfile")
|
||||
|
@ -183,11 +127,11 @@ func setup() error {
|
|||
|
||||
// The params for the pg_dump command to dump the database schema
|
||||
params := []string{
|
||||
fmt.Sprintf(`--host=%s`, viper.GetString("postgres.host")),
|
||||
fmt.Sprintf(`--port=%d`, viper.GetInt("postgres.port")),
|
||||
fmt.Sprintf(`--username=%s`, viper.GetString("postgres.user")),
|
||||
fmt.Sprintf(`--host=%s`, p.host),
|
||||
fmt.Sprintf(`--port=%d`, p.port),
|
||||
fmt.Sprintf(`--username=%s`, p.user),
|
||||
"--schema-only",
|
||||
viper.GetString("postgres.dbname"),
|
||||
p.dbName,
|
||||
}
|
||||
|
||||
// Dump the database schema into the sqlboilerschema tmp file
|
||||
|
@ -202,45 +146,33 @@ func setup() error {
|
|||
return err
|
||||
}
|
||||
|
||||
dbConn, err = DBConnect(
|
||||
viper.GetString("postgres.user"),
|
||||
viper.GetString("postgres.pass"),
|
||||
viper.GetString("postgres.dbname"),
|
||||
viper.GetString("postgres.host"),
|
||||
viper.GetInt("postgres.port"),
|
||||
viper.GetString("postgres.sslmode"),
|
||||
)
|
||||
p.dbConn, err = DBConnect(p.user, p.pass, p.dbName, p.host, p.port, p.sslmode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create the randomly generated database
|
||||
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName))
|
||||
_, err = p.dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, p.testDBName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Close the old connection so we can reconnect to the test database
|
||||
if err = dbConn.Close(); err != nil {
|
||||
if err = p.dbConn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Connect to the generated test db
|
||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
||||
p.dbConn, err = DBConnect(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the test config credentials to a tmp file for pg_dump
|
||||
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
|
||||
testCfg.Postgres.Host,
|
||||
testCfg.Postgres.Port,
|
||||
testCfg.Postgres.DBName,
|
||||
testCfg.Postgres.User,
|
||||
))
|
||||
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user))
|
||||
|
||||
if len(testCfg.Postgres.Pass) > 0 {
|
||||
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass))
|
||||
if len(p.pass) > 0 {
|
||||
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, p.pass))
|
||||
}
|
||||
|
||||
testPassFilePath := passDir + "/testpwfile"
|
||||
|
@ -252,10 +184,10 @@ func setup() error {
|
|||
|
||||
// The params for the psql schema import command
|
||||
params = []string{
|
||||
fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName),
|
||||
fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host),
|
||||
fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port),
|
||||
fmt.Sprintf(`--username=%s`, testCfg.Postgres.User),
|
||||
fmt.Sprintf(`--dbname=%s`, p.testDBName),
|
||||
fmt.Sprintf(`--host=%s`, p.host),
|
||||
fmt.Sprintf(`--port=%d`, p.port),
|
||||
fmt.Sprintf(`--username=%s`, p.user),
|
||||
fmt.Sprintf(`--file=%s`, fhSchema.Name()),
|
||||
}
|
||||
|
||||
|
@ -271,5 +203,5 @@ func setup() error {
|
|||
fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
return p.disableTriggers()
|
||||
}
|
||||
|
|
135
templates_test/singleton/boil_main_test.tpl
Normal file
135
templates_test/singleton/boil_main_test.tpl
Normal file
|
@ -0,0 +1,135 @@
|
|||
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
|
||||
|
||||
var (
|
||||
dbMain tester
|
||||
)
|
||||
|
||||
type tester interface {
|
||||
setup() error
|
||||
conn() *sql.DB
|
||||
teardown() error
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if dbMain == nil {
|
||||
fmt.Println("no dbMain tester interface was ready")
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
// Load configuration
|
||||
err = initViper()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to load config file")
|
||||
}
|
||||
|
||||
setConfigDefaults()
|
||||
if err := validateConfig({{.DriverName}}); err != nil {
|
||||
fmt.Println("failed to validate config", err)
|
||||
os.Exit(-2)
|
||||
}
|
||||
|
||||
// Set DebugMode so we can see generated sql statements
|
||||
flag.Parse()
|
||||
boil.DebugMode = *flagDebugMode
|
||||
|
||||
var err error
|
||||
if err = dbMain.setup(); err != nil {
|
||||
fmt.Println("Unable to execute setup:", err)
|
||||
os.Exit(-3)
|
||||
}
|
||||
|
||||
var code int
|
||||
boil.SetDB(dbMain.conn())
|
||||
code = m.Run()
|
||||
|
||||
if err = dbMain.teardown(); err != nil {
|
||||
fmt.Println("Unable to execute teardown:", err)
|
||||
os.Exit(-4)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func initViper() error {
|
||||
var err error
|
||||
|
||||
viper.SetConfigName("sqlboiler")
|
||||
|
||||
configHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
homePath := os.Getenv("HOME")
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
wd = "../"
|
||||
} else {
|
||||
wd = wd + "/.."
|
||||
}
|
||||
|
||||
configPaths := []string{wd}
|
||||
if len(configHome) > 0 {
|
||||
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
|
||||
} else {
|
||||
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
|
||||
}
|
||||
|
||||
for _, p := range configPaths {
|
||||
viper.AddConfigPath(p)
|
||||
}
|
||||
|
||||
// Ignore errors here, fall back to defaults and validation to provide errs
|
||||
_ = viper.ReadInConfig()
|
||||
viper.AutomaticEnv()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDefaults is only necessary because of bugs in viper, noted in main
|
||||
func setDefaults() {
|
||||
if viper.GetString("postgres.sslmode") == "" {
|
||||
viper.Set("postgres.sslmode", "require")
|
||||
}
|
||||
if viper.GetInt("postgres.port") == 0 {
|
||||
viper.Set("postgres.port", 5432)
|
||||
}
|
||||
if viper.GetString("mysql.sslmode") == "" {
|
||||
viper.Set("mysql.sslmode", "true")
|
||||
}
|
||||
if viper.GetInt("mysql.port") == 0 {
|
||||
viper.Set("mysql.port", 3306)
|
||||
}
|
||||
}
|
||||
|
||||
func validateConfig(driverName string) error {
|
||||
if viper.IsSet("postgres.dbname") {
|
||||
err = vala.BeginValidation().Validate(
|
||||
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
|
||||
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
|
||||
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
|
||||
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
|
||||
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
|
||||
).Check()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if driverName == "postgres" {
|
||||
return errors.New("postgres driver requires a postgres section in your config file")
|
||||
}
|
||||
|
||||
if viper.IsSet("mysql.dbname") {
|
||||
err = vala.BeginValidation().Validate(
|
||||
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
|
||||
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
|
||||
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
|
||||
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
|
||||
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
|
||||
).Check()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if driverName == "mysql" {
|
||||
return errors.New("mysql driver requires a mysql section in your config file")
|
||||
}
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
var (
|
||||
testCfg *Config
|
||||
dbConn *sql.DB
|
||||
)
|
||||
|
||||
func InitViper() error {
|
||||
var err error
|
||||
testCfg = &Config{}
|
||||
|
||||
viper.SetConfigName("sqlboiler")
|
||||
|
||||
configHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
homePath := os.Getenv("HOME")
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
wd = "../"
|
||||
} else {
|
||||
wd = wd + "/.."
|
||||
}
|
||||
|
||||
configPaths := []string{wd}
|
||||
if len(configHome) > 0 {
|
||||
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
|
||||
} else {
|
||||
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
|
||||
}
|
||||
|
||||
for _, p := range configPaths {
|
||||
viper.AddConfigPath(p)
|
||||
}
|
||||
|
||||
// Ignore errors here, fall back to defaults and validation to provide errs
|
||||
_ = viper.ReadInConfig()
|
||||
viper.AutomaticEnv()
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Reference in a new issue