sqlboiler/templates_test/main_test/postgres_main.tpl

173 lines
4.1 KiB
Smarty
Raw Permalink Normal View History

// pgTester holds the state needed to run the generated tests against a
// temporary PostgreSQL database.
type pgTester struct {
	// dbConn is the cached connection to the test database; it is opened
	// on the first call to conn and closed by teardown.
	dbConn *sql.DB

	// Connection settings, read from the "postgres.*" viper keys in setup.
	dbName  string
	host    string
	user    string
	pass    string
	sslmode string
	port    int

	// pgPassFile is the path of the temporary password file written by
	// makePGPassFile and handed to the PostgreSQL command line tools
	// through the PGPASSFILE environment variable.
	pgPassFile string

	// testDBName is the name of the temporary database derived from
	// dbName in setup.
	testDBName string
}
2016-09-11 21:07:39 +02:00
func init() {
2016-09-14 06:57:34 +02:00
dbMain = &pgTester{}
2016-09-11 21:07:39 +02:00
}
2016-09-14 06:46:58 +02:00
// setup dumps the database schema and imports it into a temporary randomly
// generated test database so that tests can be run against it using the
// generated sqlboiler ORM package.
func (p *pgTester) setup() error {
2016-09-14 06:57:34 +02:00
var err error
p.dbName = viper.GetString("postgres.dbname")
p.host = viper.GetString("postgres.host")
p.user = viper.GetString("postgres.user")
p.pass = viper.GetString("postgres.pass")
p.port = viper.GetInt("postgres.port")
p.sslmode = viper.GetString("postgres.sslmode")
// Create a randomized db name.
p.testDBName = randomize.StableDBName(p.dbName)
2016-09-14 06:46:58 +02:00
if err = p.makePGPassFile(); err != nil {
2018-02-07 15:35:46 +01:00
return errors.Err(err)
2016-09-14 06:46:58 +02:00
}
if err = p.dropTestDB(); err != nil {
2018-02-07 15:35:46 +01:00
return errors.Err(err)
2016-09-14 06:46:58 +02:00
}
if err = p.createTestDB(); err != nil {
2018-02-07 15:35:46 +01:00
return errors.Err(err)
2016-09-14 06:46:58 +02:00
}
dumpCmd := exec.Command("pg_dump", "--schema-only", p.dbName)
dumpCmd.Env = append(os.Environ(), p.pgEnv()...)
createCmd := exec.Command("psql", p.testDBName)
createCmd.Env = append(os.Environ(), p.pgEnv()...)
r, w := io.Pipe()
dumpCmd.Stdout = w
2016-09-14 17:45:28 +02:00
createCmd.Stdin = newFKeyDestroyer(rgxPGFkey, r)
2016-09-14 06:46:58 +02:00
if err = dumpCmd.Start(); err != nil {
2018-02-07 15:35:46 +01:00
return errors.Prefix("failed to start pg_dump command", err)
2016-09-14 06:46:58 +02:00
}
if err = createCmd.Start(); err != nil {
2018-02-07 15:35:46 +01:00
return errors.Prefix("failed to start psql command", err)
2016-09-14 06:46:58 +02:00
}
if err = dumpCmd.Wait(); err != nil {
fmt.Println(err)
2018-02-07 15:35:46 +01:00
return errors.Prefix("failed to wait for pg_dump command", err)
2016-09-14 06:46:58 +02:00
}
w.Close() // After dumpCmd is done, close the write end of the pipe
if err = createCmd.Wait(); err != nil {
fmt.Println(err)
2018-02-07 15:35:46 +01:00
return errors.Prefix("failed to wait for psql command", err)
2016-09-14 06:46:58 +02:00
}
return nil
}
2016-09-14 06:46:58 +02:00
func (p *pgTester) runCmd(stdin, command string, args ...string) error {
cmd := exec.Command(command, args...)
cmd.Env = append(os.Environ(), p.pgEnv()...)
if len(stdin) != 0 {
cmd.Stdin = strings.NewReader(stdin)
}
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
cmd.Stdout = stdout
cmd.Stderr = stderr
if err := cmd.Run(); err != nil {
fmt.Println("failed running:", command, args)
fmt.Println(stdout.String())
fmt.Println(stderr.String())
2018-02-07 15:35:46 +01:00
return errors.Err(err)
2016-09-14 06:46:58 +02:00
}
return nil
}
2016-09-14 06:46:58 +02:00
func (p *pgTester) pgEnv() []string {
return []string{
fmt.Sprintf("PGHOST=%s", p.host),
fmt.Sprintf("PGPORT=%d", p.port),
fmt.Sprintf("PGUSER=%s", p.user),
2016-09-24 08:16:44 +02:00
fmt.Sprintf("PGPASSFILE=%s", p.pgPassFile),
2016-09-14 06:46:58 +02:00
}
}
2016-09-14 06:46:58 +02:00
func (p *pgTester) makePGPassFile() error {
tmp, err := ioutil.TempFile("", "pgpass")
if err != nil {
2018-02-07 15:35:46 +01:00
return errors.Prefix("failed to create option file", err)
2016-09-14 06:46:58 +02:00
}
2016-09-24 08:16:44 +02:00
fmt.Fprintf(tmp, "%s:%d:postgres:%s", p.host, p.port, p.user)
if len(p.pass) != 0 {
fmt.Fprintf(tmp, ":%s", p.pass)
}
fmt.Fprintln(tmp)
2016-09-14 06:46:58 +02:00
fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.dbName, p.user)
if len(p.pass) != 0 {
fmt.Fprintf(tmp, ":%s", p.pass)
}
fmt.Fprintln(tmp)
fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user)
if len(p.pass) != 0 {
fmt.Fprintf(tmp, ":%s", p.pass)
}
fmt.Fprintln(tmp)
p.pgPassFile = tmp.Name()
return tmp.Close()
}
// createTestDB creates the temporary test database by running the createdb
// command with no standard input.
func (p *pgTester) createTestDB() error {
	const noStdin = ""
	return p.runCmd(noStdin, "createdb", p.testDBName)
}
2016-09-11 21:07:39 +02:00
func (p *pgTester) dropTestDB() error {
2016-09-14 06:46:58 +02:00
return p.runCmd("", "dropdb", "--if-exists", p.testDBName)
2016-04-09 01:45:44 +02:00
}
2016-09-14 06:46:58 +02:00
// teardown executes cleanup tasks when the tests finish running
func (p *pgTester) teardown() error {
var err error
if err = p.dbConn.Close(); err != nil {
2018-02-07 15:35:46 +01:00
return errors.Err(err)
2016-09-14 06:46:58 +02:00
}
p.dbConn = nil
if err = p.dropTestDB(); err != nil {
2018-02-07 15:35:46 +01:00
return errors.Err(err)
2016-09-14 06:46:58 +02:00
}
2016-04-09 01:45:44 +02:00
2016-09-14 06:46:58 +02:00
return os.Remove(p.pgPassFile)
}
2016-09-14 06:46:58 +02:00
func (p *pgTester) conn() (*sql.DB, error) {
if p.dbConn != nil {
return p.dbConn, nil
}
2016-09-14 06:46:58 +02:00
var err error
p.dbConn, err = sql.Open("postgres", drivers.PostgresBuildQueryString(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode))
if err != nil {
return nil, err
}
2016-09-14 06:46:58 +02:00
return p.dbConn, nil
}
2016-09-14 06:46:58 +02:00