Finished postgres TestMain template
This commit is contained in:
parent
4955218373
commit
d949f68ed0
2 changed files with 165 additions and 36 deletions
|
@ -58,8 +58,9 @@ var sqlBoilerTestMainImports = map[string]imports{
|
||||||
`"fmt"`,
|
`"fmt"`,
|
||||||
`"io/ioutil"`,
|
`"io/ioutil"`,
|
||||||
`"bytes"`,
|
`"bytes"`,
|
||||||
`"errors"`,
|
|
||||||
`"database/sql"`,
|
`"database/sql"`,
|
||||||
|
`"time"`,
|
||||||
|
`"math/rand"`,
|
||||||
},
|
},
|
||||||
thirdparty: importList{
|
thirdparty: importList{
|
||||||
`"github.com/BurntSushi/toml"`,
|
`"github.com/BurntSushi/toml"`,
|
||||||
|
|
|
@ -8,24 +8,75 @@ type PostgresCfg struct {
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Postgres PostgresCfg `toml:"postgres"`
|
Postgres PostgresCfg `toml:"postgres"`
|
||||||
TestPostgres *PostgresCfg `toml:"postgres_test"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg *Config
|
var cfg *Config
|
||||||
|
var testCfg *Config
|
||||||
|
|
||||||
var dbConn *sql.DB
|
var dbConn *sql.DB
|
||||||
|
|
||||||
func DBConnect(user, pass, dbname, host string, port int) error {
|
func TestMain(m *testing.M) {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
|
||||||
|
err := setup()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := m.Run()
|
||||||
|
|
||||||
|
err = teardown()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// teardown switches its connection to the template1 database temporarily
|
||||||
|
// so that it can drop the test database and the test user.
|
||||||
|
func teardown() error {
|
||||||
|
err := dbConn.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dbConn, err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, "template1", cfg.Postgres.Host, cfg.Postgres.Port)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE %s;`, testCfg.Postgres.DBName))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dbConn.Exec(fmt.Sprintf(`DROP USER %s;`, testCfg.Postgres.User))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dbConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBConnect connects to a database and returns the handle.
|
||||||
|
func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) {
|
||||||
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
|
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
|
||||||
user, pass, dbname, host, port)
|
user, pass, dbname, host, port)
|
||||||
|
|
||||||
var err error
|
return sql.Open("postgres", connStr)
|
||||||
dbConn, err = sql.Open("postgres", connStr)
|
}
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
func randSeq(n int) string {
|
||||||
|
var letters = []rune("abcdefghijklmnopqrstuvwxyz")
|
||||||
|
|
||||||
|
randStr := make([]rune, n)
|
||||||
|
for i := range randStr {
|
||||||
|
randStr[i] = letters[rand.Intn(len(letters))]
|
||||||
|
}
|
||||||
|
return string(randStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfigFile(filename string) error {
|
func LoadConfigFile(filename string) error {
|
||||||
|
@ -36,41 +87,30 @@ func LoadConfigFile(filename string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to decode toml configuration file:", err)
|
return fmt.Errorf("Failed to decode toml configuration file: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.TestPostgres != nil {
|
|
||||||
if cfg.TestPostgres.User == "" || cfg.TestPostgres.Pass == "" ||
|
|
||||||
cfg.TestPostgres.Host == "" || cfg.TestPostgres.Port == 0 ||
|
|
||||||
cfg.TestPostgres.DBName == "" || cfg.Postgres.DBName == cfg.TestPostgres.DBName {
|
|
||||||
cfg.TestPostgres = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.TestPostgres == nil {
|
|
||||||
return errors.New("Failed to load config.toml postgres_test config")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
// setup dumps the database schema and imports it into a temporary randomly
|
||||||
err := setup()
|
// generated test database so that tests can be run against it using the
|
||||||
if err != nil {
|
// generated sqlboiler ORM package.
|
||||||
fmt.Println(err)
|
|
||||||
os.Exit(-1)
|
|
||||||
}
|
|
||||||
code := m.Run()
|
|
||||||
// shutdown
|
|
||||||
os.Exit(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setup() error {
|
func setup() error {
|
||||||
|
// Load the config file in the parent directory.
|
||||||
err := LoadConfigFile("../config.toml")
|
err := LoadConfigFile("../config.toml")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Unable to load config file: %s", err)
|
return fmt.Errorf("Unable to load config file: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a randomized test configuration object.
|
||||||
|
testCfg = &Config{}
|
||||||
|
testCfg.Postgres.Host = cfg.Postgres.Host
|
||||||
|
testCfg.Postgres.Port = cfg.Postgres.Port
|
||||||
|
testCfg.Postgres.User = randSeq(20)
|
||||||
|
testCfg.Postgres.Pass = randSeq(20)
|
||||||
|
testCfg.Postgres.DBName = cfg.Postgres.DBName + "_" + randSeq(10)
|
||||||
|
|
||||||
fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema")
|
fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Unable to create sqlboiler schema tmp file: %s", err)
|
return fmt.Errorf("Unable to create sqlboiler schema tmp file: %s", err)
|
||||||
|
@ -84,7 +124,14 @@ func setup() error {
|
||||||
defer os.RemoveAll(passDir)
|
defer os.RemoveAll(passDir)
|
||||||
|
|
||||||
// Write the postgres user password to a tmp file for pg_dump
|
// Write the postgres user password to a tmp file for pg_dump
|
||||||
pwBytes := []byte(fmt.Sprintf("*:*:*:*:%s", cfg.Postgres.Pass))
|
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s",
|
||||||
|
cfg.Postgres.Host,
|
||||||
|
cfg.Postgres.Port,
|
||||||
|
cfg.Postgres.DBName,
|
||||||
|
cfg.Postgres.User,
|
||||||
|
cfg.Postgres.Pass,
|
||||||
|
))
|
||||||
|
|
||||||
passFilePath := passDir + "/pwfile"
|
passFilePath := passDir + "/pwfile"
|
||||||
|
|
||||||
err = ioutil.WriteFile(passFilePath, pwBytes, 0600)
|
err = ioutil.WriteFile(passFilePath, pwBytes, 0600)
|
||||||
|
@ -92,6 +139,7 @@ func setup() error {
|
||||||
return fmt.Errorf("Unable to create pwfile in passDir: %s", err)
|
return fmt.Errorf("Unable to create pwfile in passDir: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The params for the pg_dump command to dump the database schema
|
||||||
params := []string{
|
params := []string{
|
||||||
fmt.Sprintf(`--host=%s`, cfg.Postgres.Host),
|
fmt.Sprintf(`--host=%s`, cfg.Postgres.Host),
|
||||||
fmt.Sprintf(`--port=%d`, cfg.Postgres.Port),
|
fmt.Sprintf(`--port=%d`, cfg.Postgres.Port),
|
||||||
|
@ -100,6 +148,7 @@ func setup() error {
|
||||||
cfg.Postgres.DBName,
|
cfg.Postgres.DBName,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dump the database schema into the sqlboilerschema tmp file
|
||||||
errBuf := bytes.Buffer{}
|
errBuf := bytes.Buffer{}
|
||||||
cmd := exec.Command("pg_dump", params...)
|
cmd := exec.Command("pg_dump", params...)
|
||||||
cmd.Stderr = &errBuf
|
cmd.Stderr = &errBuf
|
||||||
|
@ -110,11 +159,90 @@ func setup() error {
|
||||||
fmt.Printf("pg_dump exec failed: %s\n\n%s\n", err, errBuf.String())
|
fmt.Printf("pg_dump exec failed: %s\n\n%s\n", err, errBuf.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, cfg.Postgres.DBName, cfg.Postgres.Host, cfg.Postgres.Port)
|
dbConn, err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, cfg.Postgres.DBName, cfg.Postgres.Host, cfg.Postgres.Port)
|
||||||
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, cfg.TestPostgres.DBName))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create the randomly generated database test user
|
||||||
|
if err = createTestUser(dbConn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the randomly generated database
|
||||||
|
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign the randomly generated db test user to the generated test db
|
||||||
|
_, err = dbConn.Exec(fmt.Sprintf(`ALTER DATABASE %s OWNER TO %s;`, testCfg.Postgres.DBName, testCfg.Postgres.User))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the old connection so we can reconnect with the restricted access generated user
|
||||||
|
if err = dbConn.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to the generated test db with the restricted privilege generated user
|
||||||
|
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the generated user password to a tmp file for pg_dump
|
||||||
|
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s",
|
||||||
|
testCfg.Postgres.Host,
|
||||||
|
testCfg.Postgres.Port,
|
||||||
|
testCfg.Postgres.DBName,
|
||||||
|
testCfg.Postgres.User,
|
||||||
|
testCfg.Postgres.Pass,
|
||||||
|
))
|
||||||
|
|
||||||
|
testPassFilePath := passDir + "/testpwfile"
|
||||||
|
|
||||||
|
err = ioutil.WriteFile(testPassFilePath, testPwBytes, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to create testpwfile in passDir: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The params for the psql schema import command
|
||||||
|
params = []string{
|
||||||
|
fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName),
|
||||||
|
fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host),
|
||||||
|
fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port),
|
||||||
|
fmt.Sprintf(`--username=%s`, testCfg.Postgres.User),
|
||||||
|
fmt.Sprintf(`--file=%s`, fhSchema.Name()),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import the database schema into the generated database.
|
||||||
|
// It is now ready to be used by the generated ORM package for testing.
|
||||||
|
outBuf := bytes.Buffer{}
|
||||||
|
cmd = exec.Command("psql", params...)
|
||||||
|
cmd.Stderr = &errBuf
|
||||||
|
cmd.Stdout = &outBuf
|
||||||
|
cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, testPassFilePath))
|
||||||
|
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
|
fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String())
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createTestUser creates a temporary database user with restricted privileges
|
||||||
|
func createTestUser(db *sql.DB) error {
|
||||||
|
now := time.Now().Add(time.Hour * 24 * 2)
|
||||||
|
valid := now.Format("2006-1-2")
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`CREATE USER %s WITH PASSWORD '%s' VALID UNTIL '%s';`,
|
||||||
|
testCfg.Postgres.User,
|
||||||
|
testCfg.Postgres.Pass,
|
||||||
|
valid,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := dbConn.Exec(query)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue