diff --git a/imports.go b/imports.go index c30a2ec..ee27d00 100644 --- a/imports.go +++ b/imports.go @@ -214,11 +214,13 @@ var defaultTestMainImports = map[string]imports{ `"io/ioutil"`, `"bytes"`, `"database/sql"`, + `"path/filepath"`, `"time"`, `"math/rand"`, }, thirdParty: importList{ `"github.com/nullbio/sqlboiler/boil"`, + `"github.com/nullbio/sqlboiler/bdb/drivers"`, `_ "github.com/lib/pq"`, `"github.com/spf13/viper"`, `"github.com/kat-co/vala"`, diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index 811103b..e41c0be 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -97,8 +97,7 @@ func dropTestDB() error { // DBConnect connects to a database and returns the handle. func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) { - connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s", - user, pass, dbname, host, port, sslmode) + connStr := drivers.BuildQueryString(user, pass, dbname, host, port, sslmode) return sql.Open("postgres", connStr) } @@ -133,7 +132,6 @@ func setup() error { err = vala.BeginValidation().Validate( vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"), - vala.StringNotEmpty(testCfg.Postgres.Pass, "postgres.pass"), vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"), vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")), vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"), @@ -163,15 +161,18 @@ func setup() error { defer os.RemoveAll(passDir) // Write the postgres user password to a tmp file for pg_dump - pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s", + pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", viper.GetString("postgres.host"), viper.GetInt("postgres.port"), viper.GetString("postgres.dbname"), viper.GetString("postgres.user"), - viper.GetString("postgres.pass"), )) - passFilePath := passDir + "/pwfile" + if pw := viper.GetString("postgres.pass"); len(pw) > 0 { + pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw)) + } + + passFilePath := filepath.Join(passDir, "pwfile") err = ioutil.WriteFile(passFilePath, pwBytes, 0600) if err != nil { @@ -229,14 +230,17 @@ func setup() error { } // Write the test config credentials to a tmp file for pg_dump - testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s", + testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.DBName, testCfg.Postgres.User, - testCfg.Postgres.Pass, )) + if len(testCfg.Postgres.Pass) > 0 { + testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass)) + } + testPassFilePath := passDir + "/testpwfile" err = ioutil.WriteFile(testPassFilePath, testPwBytes, 0600)