Postgres works again after refactor

This commit is contained in:
Aaron L 2016-09-11 12:07:39 -07:00
parent d1ea925523
commit d183ec4bb5
5 changed files with 49 additions and 59 deletions

View file

@ -186,14 +186,22 @@ var defaultTestTemplateImports = imports{
} }
var defaultSingletonTestTemplateImports = map[string]imports{ var defaultSingletonTestTemplateImports = map[string]imports{
"boil_viper_test": { "boil_main_test": {
standard: importList{ standard: importList{
`"database/sql"`, `"database/sql"`,
`"flag"`,
`"fmt"`,
`"math/rand"`,
`"os"`, `"os"`,
`"path/filepath"`, `"path/filepath"`,
`"testing"`,
`"time"`,
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`, `"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
}, },
}, },
"boil_queries_test": { "boil_queries_test": {
@ -218,23 +226,17 @@ var defaultSingletonTestTemplateImports = map[string]imports{
var defaultTestMainImports = map[string]imports{ var defaultTestMainImports = map[string]imports{
"postgres": { "postgres": {
standard: importList{ standard: importList{
`"testing"`,
`"os"`, `"os"`,
`"os/exec"`, `"os/exec"`,
`"flag"`,
`"fmt"`, `"fmt"`,
`"io/ioutil"`, `"io/ioutil"`,
`"bytes"`, `"bytes"`,
`"database/sql"`, `"database/sql"`,
`"path/filepath"`, `"path/filepath"`,
`"time"`,
`"math/rand"`,
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`, `"github.com/pkg/errors"`,
`"github.com/spf13/viper"`, `"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`, `"github.com/vattle/sqlboiler/bdb/drivers"`,
`_ "github.com/lib/pq"`, `_ "github.com/lib/pq"`,
}, },
@ -254,7 +256,6 @@ var defaultTestMainImports = map[string]imports{
`"math/rand"`, `"math/rand"`,
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`, `"github.com/pkg/errors"`,
`"github.com/spf13/viper"`, `"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/boil"`,

View file

@ -2,7 +2,6 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -148,7 +147,7 @@ func preRun(cmd *cobra.Command, args []string) error {
} }
} }
if viper.IsSet("postgres.dbname") { if driverName == "postgres" {
cmdConfig.Postgres = PostgresConfig{ cmdConfig.Postgres = PostgresConfig{
User: viper.GetString("postgres.user"), User: viper.GetString("postgres.user"),
Pass: viper.GetString("postgres.pass"), Pass: viper.GetString("postgres.pass"),
@ -182,11 +181,9 @@ func preRun(cmd *cobra.Command, args []string) error {
if err != nil { if err != nil {
return commandFailure(err.Error()) return commandFailure(err.Error())
} }
} else if driverName == "postgres" {
return errors.New("postgres driver requires a postgres section in your config file")
} }
if viper.IsSet("mysql.dbname") { if driverName == "mysql" {
cmdConfig.MySQL = MySQLConfig{ cmdConfig.MySQL = MySQLConfig{
User: viper.GetString("mysql.user"), User: viper.GetString("mysql.user"),
Pass: viper.GetString("mysql.pass"), Pass: viper.GetString("mysql.pass"),
@ -223,8 +220,6 @@ func preRun(cmd *cobra.Command, args []string) error {
if err != nil { if err != nil {
return commandFailure(err.Error()) return commandFailure(err.Error())
} }
} else if driverName == "mysql" {
return errors.New("mysql driver requires a mysql section in your config file")
} }
cmdState, err = New(cmdConfig) cmdState, err = New(cmdConfig)

View file

@ -2,16 +2,18 @@ type mysqlTester struct {
dbConn *sql.DB dbConn *sql.DB
} }
dbMain = mysqlTester{} func init() {
dbMain = &mysqlTester{}
}
func (m mysqlTester) setup() error { func (m *mysqlTester) setup() error {
return nil return nil
} }
func (m mysqlTester) teardown() error { func (m *mysqlTester) teardown() error {
return nil return nil
} }
func (m mysqlTester) conn() *sql.DB { func (m *mysqlTester) conn() *sql.DB {
return m.dbConn return m.dbConn
} }

View file

@ -11,16 +11,18 @@ type pgTester struct {
testDBName string testDBName string
} }
dbMain = pgTester{} func init() {
dbMain = &pgTester{}
}
// disableTriggers is used to disable foreign key constraints for every table. // disableTriggers is used to disable foreign key constraints for every table.
// If this is not used we cannot test inserts due to foreign key constraint errors. // If this is not used we cannot test inserts due to foreign key constraint errors.
func (p pgTester) disableTriggers() error { func (p *pgTester) disableTriggers() error {
var stmts []string var stmts []string
{{range .Tables}} {{range .Tables -}}
stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`) stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`)
{{- end}} {{end -}}
if len(stmts) == 0 { if len(stmts) == 0 {
return nil return nil
@ -38,19 +40,18 @@ func (p pgTester) disableTriggers() error {
} }
// teardown executes cleanup tasks when the tests finish running // teardown executes cleanup tasks when the tests finish running
func (p pgTester) teardown() error { func (p *pgTester) teardown() error {
err := dropTestDB() return p.dropTestDB()
return err
} }
func (p pgTester) conn() *sql.DB { func (p *pgTester) conn() *sql.DB {
return p.dbConn return p.dbConn
} }
// dropTestDB switches its connection to the template1 database temporarily // dropTestDB switches its connection to the template1 database temporarily
// so that it can drop the test database without causing "in use" conflicts. // so that it can drop the test database without causing "in use" conflicts.
// The template1 database should be present on all default postgres installations. // The template1 database should be present on all default postgres installations.
func (p pgTester) dropTestDB() error { func (p *pgTester) dropTestDB() error {
var err error var err error
if p.dbConn != nil { if p.dbConn != nil {
if err = p.dbConn.Close(); err != nil { if err = p.dbConn.Close(); err != nil {
@ -58,12 +59,12 @@ func (p pgTester) dropTestDB() error {
} }
} }
p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) p.dbConn, err = DBConnect(p.user, p.pass, "template1", p.host, p.port, p.sslmode)
if err != nil { if err != nil {
return err return err
} }
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName)) _, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, p.testDBName))
if err != nil { if err != nil {
return err return err
} }
@ -81,7 +82,7 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.
// setup dumps the database schema and imports it into a temporary randomly // setup dumps the database schema and imports it into a temporary randomly
// generated test database so that tests can be run against it using the // generated test database so that tests can be run against it using the
// generated sqlboiler ORM package. // generated sqlboiler ORM package.
func (p pgTester) setup() error { func (p *pgTester) setup() error {
var err error var err error
p.dbName = viper.GetString("postgres.dbname") p.dbName = viper.GetString("postgres.dbname")
@ -89,11 +90,11 @@ func (p pgTester) setup() error {
p.user = viper.GetString("postgres.user") p.user = viper.GetString("postgres.user")
p.pass = viper.GetString("postgres.pass") p.pass = viper.GetString("postgres.pass")
p.port = viper.GetInt("postgres.port") p.port = viper.GetInt("postgres.port")
p.sslmode = viper.GetString("postgres.dbname") p.sslmode = viper.GetString("postgres.sslmode")
// Create a randomized db name. // Create a randomized db name.
p.testDBName = getDBNameHash(p.dbname) p.testDBName = getDBNameHash(p.dbName)
err = dropTestDB() err = p.dropTestDB()
if err != nil { if err != nil {
fmt.Printf("%#v\n", err) fmt.Printf("%#v\n", err)
return err return err
@ -112,7 +113,7 @@ func (p pgTester) setup() error {
defer os.RemoveAll(passDir) defer os.RemoveAll(passDir)
// Write the postgres user password to a tmp file for pg_dump // Write the postgres user password to a tmp file for pg_dump
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user)) pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbName, p.user))
if len(p.pass) > 0 { if len(p.pass) > 0 {
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass)) pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass))

View file

@ -17,27 +17,28 @@ func TestMain(m *testing.M) {
} }
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
var err error
// Load configuration // Load configuration
err = initViper() err = initViper()
if err != nil { if err != nil {
return errors.Wrap(err, "Unable to load config file") fmt.Println("unable to load config file")
os.Exit(-2)
} }
setConfigDefaults() setConfigDefaults()
if err := validateConfig({{.DriverName}}); err != nil { if err := validateConfig("{{.DriverName}}"); err != nil {
fmt.Println("failed to validate config", err) fmt.Println("failed to validate config", err)
os.Exit(-2) os.Exit(-3)
} }
// Set DebugMode so we can see generated sql statements // Set DebugMode so we can see generated sql statements
flag.Parse() flag.Parse()
boil.DebugMode = *flagDebugMode boil.DebugMode = *flagDebugMode
var err error
if err = dbMain.setup(); err != nil { if err = dbMain.setup(); err != nil {
fmt.Println("Unable to execute setup:", err) fmt.Println("Unable to execute setup:", err)
os.Exit(-3) os.Exit(-4)
} }
var code int var code int
@ -46,7 +47,7 @@ func TestMain(m *testing.M) {
if err = dbMain.teardown(); err != nil { if err = dbMain.teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err) fmt.Println("Unable to execute teardown:", err)
os.Exit(-4) os.Exit(-5)
} }
os.Exit(code) os.Exit(code)
@ -84,8 +85,8 @@ func initViper() error {
return nil return nil
} }
// setDefaults is only necessary because of bugs in viper, noted in main // setConfigDefaults is only necessary because of bugs in viper, noted in main
func setDefaults() { func setConfigDefaults() {
if viper.GetString("postgres.sslmode") == "" { if viper.GetString("postgres.sslmode") == "" {
viper.Set("postgres.sslmode", "require") viper.Set("postgres.sslmode", "require")
} }
@ -101,35 +102,25 @@ func setDefaults() {
} }
func validateConfig(driverName string) error { func validateConfig(driverName string) error {
if viper.IsSet("postgres.dbname") { if driverName == "postgres" {
err = vala.BeginValidation().Validate( return vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"), vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"), vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")), vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"), vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"), vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
).Check() ).Check()
if err != nil {
return err
}
} else if driverName == "postgres" {
return errors.New("postgres driver requires a postgres section in your config file")
} }
if viper.IsSet("mysql.dbname") { if driverName == "mysql" {
err = vala.BeginValidation().Validate( return vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"), vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"), vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")), vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"), vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"), vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
).Check() ).Check()
if err != nil {
return err
}
} else if driverName == "mysql" {
return errors.New("mysql driver requires a mysql section in your config file")
} }
return errors.New("not a valid driver name")
} }