Postgres works again after refactor

This commit is contained in:
Aaron L 2016-09-11 12:07:39 -07:00
parent d1ea925523
commit d183ec4bb5
5 changed files with 49 additions and 59 deletions

View file

@ -186,14 +186,22 @@ var defaultTestTemplateImports = imports{
}
var defaultSingletonTestTemplateImports = map[string]imports{
"boil_viper_test": {
"boil_main_test": {
standard: importList{
`"database/sql"`,
`"flag"`,
`"fmt"`,
`"math/rand"`,
`"os"`,
`"path/filepath"`,
`"testing"`,
`"time"`,
},
thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
},
},
"boil_queries_test": {
@ -218,23 +226,17 @@ var defaultSingletonTestTemplateImports = map[string]imports{
var defaultTestMainImports = map[string]imports{
"postgres": {
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"flag"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"database/sql"`,
`"path/filepath"`,
`"time"`,
`"math/rand"`,
},
thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
`_ "github.com/lib/pq"`,
},
@ -254,7 +256,6 @@ var defaultTestMainImports = map[string]imports{
`"math/rand"`,
},
thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,

View file

@ -2,7 +2,6 @@
package main
import (
"errors"
"fmt"
"os"
"path/filepath"
@ -148,7 +147,7 @@ func preRun(cmd *cobra.Command, args []string) error {
}
}
if viper.IsSet("postgres.dbname") {
if driverName == "postgres" {
cmdConfig.Postgres = PostgresConfig{
User: viper.GetString("postgres.user"),
Pass: viper.GetString("postgres.pass"),
@ -182,11 +181,9 @@ func preRun(cmd *cobra.Command, args []string) error {
if err != nil {
return commandFailure(err.Error())
}
} else if driverName == "postgres" {
return errors.New("postgres driver requires a postgres section in your config file")
}
if viper.IsSet("mysql.dbname") {
if driverName == "mysql" {
cmdConfig.MySQL = MySQLConfig{
User: viper.GetString("mysql.user"),
Pass: viper.GetString("mysql.pass"),
@ -223,8 +220,6 @@ func preRun(cmd *cobra.Command, args []string) error {
if err != nil {
return commandFailure(err.Error())
}
} else if driverName == "mysql" {
return errors.New("mysql driver requires a mysql section in your config file")
}
cmdState, err = New(cmdConfig)

View file

@ -2,16 +2,18 @@ type mysqlTester struct {
dbConn *sql.DB
}
dbMain = mysqlTester{}
func init() {
dbMain = &mysqlTester{}
}
func (m mysqlTester) setup() error {
func (m *mysqlTester) setup() error {
return nil
}
func (m mysqlTester) teardown() error {
func (m *mysqlTester) teardown() error {
return nil
}
func (m mysqlTester) conn() *sql.DB {
func (m *mysqlTester) conn() *sql.DB {
return m.dbConn
}

View file

@ -11,16 +11,18 @@ type pgTester struct {
testDBName string
}
dbMain = pgTester{}
func init() {
dbMain = &pgTester{}
}
// disableTriggers is used to disable foreign key constraints for every table.
// If this is not used we cannot test inserts due to foreign key constraint errors.
func (p pgTester) disableTriggers() error {
func (p *pgTester) disableTriggers() error {
var stmts []string
{{range .Tables}}
{{range .Tables -}}
stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`)
{{- end}}
{{end -}}
if len(stmts) == 0 {
return nil
@ -38,19 +40,18 @@ func (p pgTester) disableTriggers() error {
}
// teardown executes cleanup tasks when the tests finish running
func (p pgTester) teardown() error {
err := dropTestDB()
return err
func (p *pgTester) teardown() error {
return p.dropTestDB()
}
func (p pgTester) conn() *sql.DB {
func (p *pgTester) conn() *sql.DB {
return p.dbConn
}
// dropTestDB switches its connection to the template1 database temporarily
// so that it can drop the test database without causing "in use" conflicts.
// The template1 database should be present on all default postgres installations.
func (p pgTester) dropTestDB() error {
func (p *pgTester) dropTestDB() error {
var err error
if p.dbConn != nil {
if err = p.dbConn.Close(); err != nil {
@ -58,12 +59,12 @@ func (p pgTester) dropTestDB() error {
}
}
p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
p.dbConn, err = DBConnect(p.user, p.pass, "template1", p.host, p.port, p.sslmode)
if err != nil {
return err
}
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, p.testDBName))
if err != nil {
return err
}
@ -81,7 +82,7 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.
// setup dumps the database schema and imports it into a temporary randomly
// generated test database so that tests can be run against it using the
// generated sqlboiler ORM package.
func (p pgTester) setup() error {
func (p *pgTester) setup() error {
var err error
p.dbName = viper.GetString("postgres.dbname")
@ -89,11 +90,11 @@ func (p pgTester) setup() error {
p.user = viper.GetString("postgres.user")
p.pass = viper.GetString("postgres.pass")
p.port = viper.GetInt("postgres.port")
p.sslmode = viper.GetString("postgres.dbname")
p.sslmode = viper.GetString("postgres.sslmode")
// Create a randomized db name.
p.testDBName = getDBNameHash(p.dbname)
p.testDBName = getDBNameHash(p.dbName)
err = dropTestDB()
err = p.dropTestDB()
if err != nil {
fmt.Printf("%#v\n", err)
return err
@ -112,7 +113,7 @@ func (p pgTester) setup() error {
defer os.RemoveAll(passDir)
// Write the postgres user password to a tmp file for pg_dump
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user))
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbName, p.user))
if len(p.pass) > 0 {
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass))

View file

@ -17,27 +17,28 @@ func TestMain(m *testing.M) {
}
rand.Seed(time.Now().UnixNano())
var err error
// Load configuration
err = initViper()
if err != nil {
return errors.Wrap(err, "Unable to load config file")
fmt.Println("unable to load config file")
os.Exit(-2)
}
setConfigDefaults()
if err := validateConfig({{.DriverName}}); err != nil {
if err := validateConfig("{{.DriverName}}"); err != nil {
fmt.Println("failed to validate config", err)
os.Exit(-2)
os.Exit(-3)
}
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
var err error
if err = dbMain.setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-3)
os.Exit(-4)
}
var code int
@ -46,7 +47,7 @@ func TestMain(m *testing.M) {
if err = dbMain.teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-4)
os.Exit(-5)
}
os.Exit(code)
@ -84,8 +85,8 @@ func initViper() error {
return nil
}
// setDefaults is only necessary because of bugs in viper, noted in main
func setDefaults() {
// setConfigDefaults is only necessary because of bugs in viper, noted in main
func setConfigDefaults() {
if viper.GetString("postgres.sslmode") == "" {
viper.Set("postgres.sslmode", "require")
}
@ -101,35 +102,25 @@ func setDefaults() {
}
func validateConfig(driverName string) error {
if viper.IsSet("postgres.dbname") {
err = vala.BeginValidation().Validate(
if driverName == "postgres" {
return vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
).Check()
if err != nil {
return err
}
} else if driverName == "postgres" {
return errors.New("postgres driver requires a postgres section in your config file")
}
if viper.IsSet("mysql.dbname") {
err = vala.BeginValidation().Validate(
if driverName == "mysql" {
return vala.BeginValidation().Validate(
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
).Check()
if err != nil {
return err
}
} else if driverName == "mysql" {
return errors.New("mysql driver requires a mysql section in your config file")
}
return errors.New("not a valid driver name")
}