2016-03-26 07:54:55 +01:00
|
|
|
// PostgresCfg holds the settings needed to reach a Postgres server. It is
// decoded from the `postgres` table of the TOML configuration file.
type PostgresCfg struct {
	// User is the role to connect as.
	User string `toml:"user"`
	// Pass is the password for User.
	Pass string `toml:"pass"`
	// Host is the server host name or address.
	Host string `toml:"host"`
	// Port is the server TCP port.
	Port int `toml:"port"`
	// DBName is the database whose schema the tests are generated from.
	DBName string `toml:"dbname"`
}
|
|
|
|
|
|
|
|
type Config struct {
|
|
|
|
Postgres PostgresCfg `toml:"postgres"`
|
|
|
|
}
|
|
|
|
|
|
|
|
var cfg *Config
|
2016-04-09 01:45:44 +02:00
|
|
|
var testCfg *Config
|
2016-03-26 07:54:55 +01:00
|
|
|
|
|
|
|
var dbConn *sql.DB
|
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
func TestMain(m *testing.M) {
|
2016-06-02 23:07:51 +02:00
|
|
|
// Set the DebugMode to true so we can see generated sql statements
|
|
|
|
boil.DebugMode = true
|
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
rand.Seed(time.Now().UnixNano())
|
2016-06-02 23:07:51 +02:00
|
|
|
var err error
|
2016-04-09 01:45:44 +02:00
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
err = setup()
|
2016-04-09 01:45:44 +02:00
|
|
|
if err != nil {
|
2016-06-02 23:07:51 +02:00
|
|
|
fmt.Printf("Unable to execute setup: %s", err)
|
2016-04-09 01:45:44 +02:00
|
|
|
os.Exit(-1)
|
|
|
|
}
|
2016-03-26 07:54:55 +01:00
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
err = disableTriggers()
|
|
|
|
if err != nil {
|
|
|
|
fmt.Printf("Unable to disable triggers: %s", err)
|
|
|
|
}
|
2016-05-17 12:00:56 +02:00
|
|
|
boil.SetDB(dbConn)
|
2016-04-09 01:45:44 +02:00
|
|
|
code := m.Run()
|
2016-03-26 07:54:55 +01:00
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
err = teardown()
|
|
|
|
if err != nil {
|
2016-06-02 23:07:51 +02:00
|
|
|
fmt.Printf("Unable to execute teardown: %s", err)
|
2016-04-09 01:45:44 +02:00
|
|
|
os.Exit(-1)
|
|
|
|
}
|
|
|
|
|
|
|
|
os.Exit(code)
|
2016-03-26 07:54:55 +01:00
|
|
|
}
|
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
// disableTriggers is used to disable foreign key constraints for every table.
|
|
|
|
// If this is not used we cannot test inserts due to foreign key constraint errors.
|
|
|
|
func disableTriggers() error {
|
|
|
|
var stmts []string
|
|
|
|
|
|
|
|
{{range .Tables}}
|
2016-06-14 14:58:46 +02:00
|
|
|
stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`)
|
2016-06-02 23:07:51 +02:00
|
|
|
{{- end}}
|
|
|
|
|
|
|
|
if len(stmts) == 0 {
|
|
|
|
return nil
|
2016-04-09 01:45:44 +02:00
|
|
|
}
|
2016-03-26 07:54:55 +01:00
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
var err error
|
|
|
|
for _, s := range stmts {
|
|
|
|
_, err = dbConn.Exec(s)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2016-03-26 07:54:55 +01:00
|
|
|
}
|
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// teardown executes cleanup tasks when the tests finish running
|
|
|
|
func teardown() error {
|
|
|
|
err := dropTestDB()
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// dropTestDB switches its connection to the template1 database temporarily
|
|
|
|
// so that it can drop the test database without causing "in use" conflicts.
|
|
|
|
// The template1 database should be present on all default postgres installations.
|
|
|
|
func dropTestDB() error {
|
|
|
|
var err error
|
|
|
|
if dbConn != nil {
|
|
|
|
if err = dbConn.Close(); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port)
|
2016-03-26 07:54:55 +01:00
|
|
|
if err != nil {
|
2016-04-09 01:45:44 +02:00
|
|
|
return err
|
2016-03-26 07:54:55 +01:00
|
|
|
}
|
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
|
2016-04-09 01:45:44 +02:00
|
|
|
if err != nil {
|
|
|
|
return err
|
2016-03-26 07:54:55 +01:00
|
|
|
}
|
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
return dbConn.Close()
|
|
|
|
}
|
2016-03-26 07:54:55 +01:00
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
// DBConnect connects to a database and returns the handle.
//
// The handle comes from sql.Open with the "postgres" driver; sql.Open may
// defer actually connecting until the handle is first used.
func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) {
	return sql.Open("postgres", postgresConnStr(user, pass, dbname, host, port))
}

// postgresConnStr builds a libpq-style keyword=value connection string.
//
// Every string value is single-quoted, with backslashes and single quotes
// escaped by a backslash. Without quoting, an empty value (for example a blank
// password) or a value containing spaces would cause the parser to consume the
// following "key=value" pair as part of that value.
func postgresConnStr(user, pass, dbname, host string, port int) string {
	quote := func(v string) string {
		b := make([]byte, 0, len(v)+2)
		b = append(b, '\'')
		for i := 0; i < len(v); i++ {
			// Byte-wise scan is safe for UTF-8: '\\' and '\'' are ASCII and
			// never occur inside a multi-byte sequence.
			if c := v[i]; c == '\\' || c == '\'' {
				b = append(b, '\\')
			}
			b = append(b, v[i])
		}
		return string(append(b, '\''))
	}

	return fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
		quote(user), quote(pass), quote(dbname), quote(host), port)
}
|
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
func LoadConfigFile(filename string) error {
|
|
|
|
_, err := toml.DecodeFile(filename, &cfg)
|
|
|
|
|
|
|
|
if os.IsNotExist(err) {
|
|
|
|
return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err)
|
|
|
|
}
|
|
|
|
|
2016-03-26 07:54:55 +01:00
|
|
|
if err != nil {
|
2016-04-09 01:45:44 +02:00
|
|
|
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
|
2016-03-26 07:54:55 +01:00
|
|
|
}
|
2016-04-09 01:45:44 +02:00
|
|
|
|
|
|
|
return nil
|
2016-03-26 07:54:55 +01:00
|
|
|
}
|
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
// setup dumps the database schema and imports it into a temporary randomly
|
|
|
|
// generated test database so that tests can be run against it using the
|
|
|
|
// generated sqlboiler ORM package.
|
2016-03-26 07:54:55 +01:00
|
|
|
func setup() error {
|
2016-04-09 01:45:44 +02:00
|
|
|
// Load the config file in the parent directory.
|
2016-06-14 14:58:46 +02:00
|
|
|
err := LoadConfigFile("../sqlboiler.toml")
|
2016-03-26 07:54:55 +01:00
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("Unable to load config file: %s", err)
|
|
|
|
}
|
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
testDBName := getDBNameHash(cfg.Postgres.DBName)
|
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
// Create a randomized test configuration object.
|
|
|
|
testCfg = &Config{}
|
|
|
|
testCfg.Postgres.Host = cfg.Postgres.Host
|
|
|
|
testCfg.Postgres.Port = cfg.Postgres.Port
|
2016-06-02 23:07:51 +02:00
|
|
|
testCfg.Postgres.User = cfg.Postgres.User
|
|
|
|
testCfg.Postgres.Pass = cfg.Postgres.Pass
|
|
|
|
testCfg.Postgres.DBName = testDBName
|
|
|
|
|
|
|
|
err = dropTestDB()
|
|
|
|
if err != nil {
|
|
|
|
fmt.Printf("%#v\n", err)
|
|
|
|
return err
|
|
|
|
}
|
2016-04-09 01:45:44 +02:00
|
|
|
|
2016-03-26 07:54:55 +01:00
|
|
|
fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema")
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("Unable to create sqlboiler schema tmp file: %s", err)
|
|
|
|
}
|
|
|
|
defer os.Remove(fhSchema.Name())
|
|
|
|
|
|
|
|
passDir, err := ioutil.TempDir(os.TempDir(), "sqlboiler")
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("Unable to create sqlboiler tmp dir for postgres pw file: %s", err)
|
|
|
|
}
|
|
|
|
defer os.RemoveAll(passDir)
|
|
|
|
|
|
|
|
// Write the postgres user password to a tmp file for pg_dump
|
2016-04-09 01:45:44 +02:00
|
|
|
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s",
|
|
|
|
cfg.Postgres.Host,
|
|
|
|
cfg.Postgres.Port,
|
|
|
|
cfg.Postgres.DBName,
|
|
|
|
cfg.Postgres.User,
|
|
|
|
cfg.Postgres.Pass,
|
|
|
|
))
|
|
|
|
|
2016-03-26 07:54:55 +01:00
|
|
|
passFilePath := passDir + "/pwfile"
|
|
|
|
|
|
|
|
err = ioutil.WriteFile(passFilePath, pwBytes, 0600)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("Unable to create pwfile in passDir: %s", err)
|
|
|
|
}
|
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
// The params for the pg_dump command to dump the database schema
|
2016-03-26 07:54:55 +01:00
|
|
|
params := []string{
|
|
|
|
fmt.Sprintf(`--host=%s`, cfg.Postgres.Host),
|
|
|
|
fmt.Sprintf(`--port=%d`, cfg.Postgres.Port),
|
|
|
|
fmt.Sprintf(`--username=%s`, cfg.Postgres.User),
|
|
|
|
"--schema-only",
|
|
|
|
cfg.Postgres.DBName,
|
|
|
|
}
|
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
// Dump the database schema into the sqlboilerschema tmp file
|
2016-03-26 07:54:55 +01:00
|
|
|
errBuf := bytes.Buffer{}
|
|
|
|
cmd := exec.Command("pg_dump", params...)
|
|
|
|
cmd.Stderr = &errBuf
|
|
|
|
cmd.Stdout = fhSchema
|
|
|
|
cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, passFilePath))
|
|
|
|
|
|
|
|
if err := cmd.Run(); err != nil {
|
|
|
|
fmt.Printf("pg_dump exec failed: %s\n\n%s\n", err, errBuf.String())
|
|
|
|
}
|
|
|
|
|
2016-04-09 01:45:44 +02:00
|
|
|
dbConn, err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, cfg.Postgres.DBName, cfg.Postgres.Host, cfg.Postgres.Port)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create the randomly generated database
|
|
|
|
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName))
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
// Close the old connection so we can reconnect to the test database
|
2016-04-09 01:45:44 +02:00
|
|
|
if err = dbConn.Close(); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
// Connect to the generated test db
|
2016-04-09 01:45:44 +02:00
|
|
|
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2016-06-02 23:07:51 +02:00
|
|
|
// Write the test config credentials to a tmp file for pg_dump
|
2016-04-09 01:45:44 +02:00
|
|
|
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s",
|
|
|
|
testCfg.Postgres.Host,
|
|
|
|
testCfg.Postgres.Port,
|
|
|
|
testCfg.Postgres.DBName,
|
|
|
|
testCfg.Postgres.User,
|
|
|
|
testCfg.Postgres.Pass,
|
|
|
|
))
|
|
|
|
|
|
|
|
testPassFilePath := passDir + "/testpwfile"
|
|
|
|
|
|
|
|
err = ioutil.WriteFile(testPassFilePath, testPwBytes, 0600)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("Unable to create testpwfile in passDir: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// The params for the psql schema import command
|
|
|
|
params = []string{
|
|
|
|
fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName),
|
|
|
|
fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host),
|
|
|
|
fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port),
|
|
|
|
fmt.Sprintf(`--username=%s`, testCfg.Postgres.User),
|
|
|
|
fmt.Sprintf(`--file=%s`, fhSchema.Name()),
|
|
|
|
}
|
|
|
|
|
|
|
|
// Import the database schema into the generated database.
|
|
|
|
// It is now ready to be used by the generated ORM package for testing.
|
|
|
|
outBuf := bytes.Buffer{}
|
|
|
|
cmd = exec.Command("psql", params...)
|
|
|
|
cmd.Stderr = &errBuf
|
|
|
|
cmd.Stdout = &outBuf
|
|
|
|
cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, testPassFilePath))
|
|
|
|
|
|
|
|
if err = cmd.Run(); err != nil {
|
|
|
|
fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String())
|
|
|
|
}
|
|
|
|
|
2016-03-26 07:54:55 +01:00
|
|
|
return nil
|
|
|
|
}
|