diff --git a/imports.go b/imports.go index 5fc0896..bf08208 100644 --- a/imports.go +++ b/imports.go @@ -239,6 +239,29 @@ var defaultTestMainImports = map[string]imports{ `_ "github.com/lib/pq"`, }, }, + "mysql": { + standard: importList{ + `"testing"`, + `"os"`, + `"os/exec"`, + `"flag"`, + `"fmt"`, + `"io/ioutil"`, + `"bytes"`, + `"database/sql"`, + `"path/filepath"`, + `"time"`, + `"math/rand"`, + }, + thirdParty: importList{ + `"github.com/kat-co/vala"`, + `"github.com/pkg/errors"`, + `"github.com/spf13/viper"`, + `"github.com/vattle/sqlboiler/boil"`, + `"github.com/vattle/sqlboiler/bdb/drivers"`, + `_ "github.com/go-mysql-driver/mysql"`, + }, + }, } // importsBasedOnType imports are only included in the template output if the diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl index 96dcbfc..5643e0e 100644 --- a/templates_test/main_test/mysql_main.tpl +++ b/templates_test/main_test/mysql_main.tpl @@ -1,2 +1,17 @@ -func TestMain(m *testing.M) { +type mysqlTester struct { + dbConn *sql.DB +} + +dbMain = mysqlTester{} + +func (m mysqlTester) setup() error { + return nil +} + +func (m mysqlTester) teardown() error { + return nil +} + +func (m mysqlTester) conn() *sql.DB { + return m.dbConn } diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index 2750f60..2bc819e 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -1,50 +1,21 @@ -type PostgresCfg struct { - User string `toml:"user"` - Pass string `toml:"pass"` - Host string `toml:"host"` - Port int `toml:"port"` - DBName string `toml:"dbname"` - SSLMode string `toml:"sslmode"` +type pgTester struct { + dbConn *sql.DB + + dbName string + host string + user string + pass string + sslmode string + port int + + testDBName string } -type Config struct { - Postgres PostgresCfg `toml:"postgres"` -} - -var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements") - -func TestMain(m *testing.M) { - rand.Seed(time.Now().UnixNano()) - - // Set DebugMode so we can see generated sql statements - flag.Parse() - boil.DebugMode = *flagDebugMode - - var err error - if err = setup(); err != nil { - fmt.Println("Unable to execute setup:", err) - os.Exit(-2) - } - - var code int - if err = disableTriggers(); err != nil { - fmt.Println("Unable to disable triggers:", err) - } else { - boil.SetDB(dbConn) - code = m.Run() - } - - if err = teardown(); err != nil { - fmt.Println("Unable to execute teardown:", err) - os.Exit(-3) - } - - os.Exit(code) -} +dbMain = pgTester{} // disableTriggers is used to disable foreign key constraints for every table. // If this is not used we cannot test inserts due to foreign key constraint errors. -func disableTriggers() error { +func (p pgTester) disableTriggers() error { var stmts []string {{range .Tables}} @@ -57,7 +28,7 @@ func disableTriggers() error { var err error for _, s := range stmts { - _, err = dbConn.Exec(s) + _, err = p.dbConn.Exec(s) if err != nil { return err } @@ -67,33 +38,37 @@ func disableTriggers() error { } // teardown executes cleanup tasks when the tests finish running -func teardown() error { +func (p pgTester) teardown() error { err := dropTestDB() return err } +func (p pgTester) conn() *sql.DB { + return p.dbConn +} + // dropTestDB switches its connection to the template1 database temporarily // so that it can drop the test database without causing "in use" conflicts. // The template1 database should be present on all default postgres installations. -func dropTestDB() error { +func (p pgTester) dropTestDB() error { var err error - if dbConn != nil { - if err = dbConn.Close(); err != nil { + if p.dbConn != nil { + if err = p.dbConn.Close(); err != nil { return err } } - dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) + p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) if err != nil { return err } - _, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName)) + _, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName)) if err != nil { return err } - return dbConn.Close() + return p.dbConn.Close() } // DBConnect connects to a database and returns the handle. @@ -106,43 +81,17 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql. // setup dumps the database schema and imports it into a temporary randomly // generated test database so that tests can be run against it using the // generated sqlboiler ORM package. -func setup() error { +func (p pgTester) setup() error { var err error - // Initialize Viper and load the config file - err = InitViper() - if err != nil { - return errors.Wrap(err, "Unable to load config file") - } - - viper.SetDefault("postgres.sslmode", "require") - viper.SetDefault("postgres.port", "5432") - - // Create a randomized test configuration object. - testCfg.Postgres.Host = viper.GetString("postgres.host") - testCfg.Postgres.Port = viper.GetInt("postgres.port") - testCfg.Postgres.User = viper.GetString("postgres.user") - testCfg.Postgres.Pass = viper.GetString("postgres.pass") - testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname")) - testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode") - - // Set the default SSLMode value - if testCfg.Postgres.SSLMode == "" { - viper.Set("postgres.sslmode", "require") - testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode") - } - - err = vala.BeginValidation().Validate( - vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"), - vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"), - vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")), - vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"), - vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"), - ).Check() - - if err != nil { - return errors.Wrap(err, "Unable to load testCfg") - } + p.dbName = viper.GetString("postgres.dbname") + p.host = viper.GetString("postgres.host") + p.user = viper.GetString("postgres.user") + p.pass = viper.GetString("postgres.pass") + p.port = viper.GetInt("postgres.port") + p.sslmode = viper.GetString("postgres.dbname") + // Create a randomized db name. + p.testDBName = getDBNameHash(p.dbname) err = dropTestDB() if err != nil { @@ -163,15 +112,10 @@ func setup() error { defer os.RemoveAll(passDir) // Write the postgres user password to a tmp file for pg_dump - pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", - viper.GetString("postgres.host"), - viper.GetInt("postgres.port"), - viper.GetString("postgres.dbname"), - viper.GetString("postgres.user"), - )) + pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user)) - if pw := viper.GetString("postgres.pass"); len(pw) > 0 { - pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw)) + if len(p.pass) > 0 { + pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass)) } passFilePath := filepath.Join(passDir, "pwfile") @@ -183,11 +127,11 @@ func setup() error { // The params for the pg_dump command to dump the database schema params := []string{ - fmt.Sprintf(`--host=%s`, viper.GetString("postgres.host")), - fmt.Sprintf(`--port=%d`, viper.GetInt("postgres.port")), - fmt.Sprintf(`--username=%s`, viper.GetString("postgres.user")), + fmt.Sprintf(`--host=%s`, p.host), + fmt.Sprintf(`--port=%d`, p.port), + fmt.Sprintf(`--username=%s`, p.user), "--schema-only", - viper.GetString("postgres.dbname"), + p.dbName, } // Dump the database schema into the sqlboilerschema tmp file @@ -202,45 +146,33 @@ func setup() error { return err } - dbConn, err = DBConnect( - viper.GetString("postgres.user"), - viper.GetString("postgres.pass"), - viper.GetString("postgres.dbname"), - viper.GetString("postgres.host"), - viper.GetInt("postgres.port"), - viper.GetString("postgres.sslmode"), - ) + p.dbConn, err = DBConnect(p.user, p.pass, p.dbName, p.host, p.port, p.sslmode) if err != nil { return err } // Create the randomly generated database - _, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName)) + _, err = p.dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, p.testDBName)) if err != nil { return err } // Close the old connection so we can reconnect to the test database - if err = dbConn.Close(); err != nil { + if err = p.dbConn.Close(); err != nil { return err } // Connect to the generated test db - dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) + p.dbConn, err = DBConnect(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode) if err != nil { return err } // Write the test config credentials to a tmp file for pg_dump - testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", - testCfg.Postgres.Host, - testCfg.Postgres.Port, - testCfg.Postgres.DBName, - testCfg.Postgres.User, - )) + testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user)) - if len(testCfg.Postgres.Pass) > 0 { - testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass)) + if len(p.pass) > 0 { + testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, p.pass)) } testPassFilePath := passDir + "/testpwfile" @@ -252,10 +184,10 @@ func setup() error { // The params for the psql schema import command params = []string{ - fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName), - fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host), - fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port), - fmt.Sprintf(`--username=%s`, testCfg.Postgres.User), + fmt.Sprintf(`--dbname=%s`, p.testDBName), + fmt.Sprintf(`--host=%s`, p.host), + fmt.Sprintf(`--port=%d`, p.port), + fmt.Sprintf(`--username=%s`, p.user), fmt.Sprintf(`--file=%s`, fhSchema.Name()), } @@ -271,5 +203,5 @@ func setup() error { fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String()) } - return nil + return p.disableTriggers() } diff --git a/templates_test/singleton/boil_main_test.tpl b/templates_test/singleton/boil_main_test.tpl new file mode 100644 index 0000000..77fe62d --- /dev/null +++ b/templates_test/singleton/boil_main_test.tpl @@ -0,0 +1,135 @@ +var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements") + +var ( + dbMain tester +) + +type tester interface { + setup() error + conn() *sql.DB + teardown() error +} + +func TestMain(m *testing.M) { + if dbMain == nil { + fmt.Println("no dbMain tester interface was ready") + os.Exit(-1) + } + + rand.Seed(time.Now().UnixNano()) + + // Load configuration + err = initViper() + if err != nil { + return errors.Wrap(err, "Unable to load config file") + } + + setConfigDefaults() + if err := validateConfig({{.DriverName}}); err != nil { + fmt.Println("failed to validate config", err) + os.Exit(-2) + } + + // Set DebugMode so we can see generated sql statements + flag.Parse() + boil.DebugMode = *flagDebugMode + + var err error + if err = dbMain.setup(); err != nil { + fmt.Println("Unable to execute setup:", err) + os.Exit(-3) + } + + var code int + boil.SetDB(dbMain.conn()) + code = m.Run() + + if err = dbMain.teardown(); err != nil { + fmt.Println("Unable to execute teardown:", err) + os.Exit(-4) + } + + os.Exit(code) +} + +func initViper() error { + var err error + + viper.SetConfigName("sqlboiler") + + configHome := os.Getenv("XDG_CONFIG_HOME") + homePath := os.Getenv("HOME") + wd, err := os.Getwd() + if err != nil { + wd = "../" + } else { + wd = wd + "/.." + } + + configPaths := []string{wd} + if len(configHome) > 0 { + configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler")) + } else { + configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler")) + } + + for _, p := range configPaths { + viper.AddConfigPath(p) + } + + // Ignore errors here, fall back to defaults and validation to provide errs + _ = viper.ReadInConfig() + viper.AutomaticEnv() + + return nil +} + +// setDefaults is only necessary because of bugs in viper, noted in main +func setDefaults() { + if viper.GetString("postgres.sslmode") == "" { + viper.Set("postgres.sslmode", "require") + } + if viper.GetInt("postgres.port") == 0 { + viper.Set("postgres.port", 5432) + } + if viper.GetString("mysql.sslmode") == "" { + viper.Set("mysql.sslmode", "true") + } + if viper.GetInt("mysql.port") == 0 { + viper.Set("mysql.port", 3306) + } +} + +func validateConfig(driverName string) error { + if viper.IsSet("postgres.dbname") { + err = vala.BeginValidation().Validate( + vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"), + vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"), + vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")), + vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"), + vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"), + ).Check() + + if err != nil { + return err + } + } else if driverName == "postgres" { + return errors.New("postgres driver requires a postgres section in your config file") + } + + if viper.IsSet("mysql.dbname") { + err = vala.BeginValidation().Validate( + vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"), + vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"), + vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")), + vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"), + vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"), + ).Check() + + if err != nil { + return err + } + } else if driverName == "mysql" { + return errors.New("mysql driver requires a mysql section in your config file") + } +} diff --git a/templates_test/singleton/boil_viper_test.tpl b/templates_test/singleton/boil_viper_test.tpl deleted file mode 100644 index d05a20a..0000000 --- a/templates_test/singleton/boil_viper_test.tpl +++ /dev/null @@ -1,37 +0,0 @@ -var ( - testCfg *Config - dbConn *sql.DB -) - -func InitViper() error { - var err error - testCfg = &Config{} - - viper.SetConfigName("sqlboiler") - - configHome := os.Getenv("XDG_CONFIG_HOME") - homePath := os.Getenv("HOME") - wd, err := os.Getwd() - if err != nil { - wd = "../" - } else { - wd = wd + "/.." - } - - configPaths := []string{wd} - if len(configHome) > 0 { - configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler")) - } else { - configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler")) - } - - for _, p := range configPaths { - viper.AddConfigPath(p) - } - - // Ignore errors here, fall back to defaults and validation to provide errs - _ = viper.ReadInConfig() - viper.AutomaticEnv() - - return nil -}