Postgres works again after refactor
This commit is contained in:
parent
d1ea925523
commit
d183ec4bb5
5 changed files with 49 additions and 59 deletions
17
imports.go
17
imports.go
|
@ -186,14 +186,22 @@ var defaultTestTemplateImports = imports{
|
|||
}
|
||||
|
||||
var defaultSingletonTestTemplateImports = map[string]imports{
|
||||
"boil_viper_test": {
|
||||
"boil_main_test": {
|
||||
standard: importList{
|
||||
`"database/sql"`,
|
||||
`"flag"`,
|
||||
`"fmt"`,
|
||||
`"math/rand"`,
|
||||
`"os"`,
|
||||
`"path/filepath"`,
|
||||
`"testing"`,
|
||||
`"time"`,
|
||||
},
|
||||
thirdParty: importList{
|
||||
`"github.com/kat-co/vala"`,
|
||||
`"github.com/pkg/errors"`,
|
||||
`"github.com/spf13/viper"`,
|
||||
`"github.com/vattle/sqlboiler/boil"`,
|
||||
},
|
||||
},
|
||||
"boil_queries_test": {
|
||||
|
@ -218,23 +226,17 @@ var defaultSingletonTestTemplateImports = map[string]imports{
|
|||
var defaultTestMainImports = map[string]imports{
|
||||
"postgres": {
|
||||
standard: importList{
|
||||
`"testing"`,
|
||||
`"os"`,
|
||||
`"os/exec"`,
|
||||
`"flag"`,
|
||||
`"fmt"`,
|
||||
`"io/ioutil"`,
|
||||
`"bytes"`,
|
||||
`"database/sql"`,
|
||||
`"path/filepath"`,
|
||||
`"time"`,
|
||||
`"math/rand"`,
|
||||
},
|
||||
thirdParty: importList{
|
||||
`"github.com/kat-co/vala"`,
|
||||
`"github.com/pkg/errors"`,
|
||||
`"github.com/spf13/viper"`,
|
||||
`"github.com/vattle/sqlboiler/boil"`,
|
||||
`"github.com/vattle/sqlboiler/bdb/drivers"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
},
|
||||
|
@ -254,7 +256,6 @@ var defaultTestMainImports = map[string]imports{
|
|||
`"math/rand"`,
|
||||
},
|
||||
thirdParty: importList{
|
||||
`"github.com/kat-co/vala"`,
|
||||
`"github.com/pkg/errors"`,
|
||||
`"github.com/spf13/viper"`,
|
||||
`"github.com/vattle/sqlboiler/boil"`,
|
||||
|
|
9
main.go
9
main.go
|
@ -2,7 +2,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -148,7 +147,7 @@ func preRun(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
}
|
||||
|
||||
if viper.IsSet("postgres.dbname") {
|
||||
if driverName == "postgres" {
|
||||
cmdConfig.Postgres = PostgresConfig{
|
||||
User: viper.GetString("postgres.user"),
|
||||
Pass: viper.GetString("postgres.pass"),
|
||||
|
@ -182,11 +181,9 @@ func preRun(cmd *cobra.Command, args []string) error {
|
|||
if err != nil {
|
||||
return commandFailure(err.Error())
|
||||
}
|
||||
} else if driverName == "postgres" {
|
||||
return errors.New("postgres driver requires a postgres section in your config file")
|
||||
}
|
||||
|
||||
if viper.IsSet("mysql.dbname") {
|
||||
if driverName == "mysql" {
|
||||
cmdConfig.MySQL = MySQLConfig{
|
||||
User: viper.GetString("mysql.user"),
|
||||
Pass: viper.GetString("mysql.pass"),
|
||||
|
@ -223,8 +220,6 @@ func preRun(cmd *cobra.Command, args []string) error {
|
|||
if err != nil {
|
||||
return commandFailure(err.Error())
|
||||
}
|
||||
} else if driverName == "mysql" {
|
||||
return errors.New("mysql driver requires a mysql section in your config file")
|
||||
}
|
||||
|
||||
cmdState, err = New(cmdConfig)
|
||||
|
|
|
@ -2,16 +2,18 @@ type mysqlTester struct {
|
|||
dbConn *sql.DB
|
||||
}
|
||||
|
||||
dbMain = mysqlTester{}
|
||||
func init() {
|
||||
dbMain = &mysqlTester{}
|
||||
}
|
||||
|
||||
func (m mysqlTester) setup() error {
|
||||
func (m *mysqlTester) setup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mysqlTester) teardown() error {
|
||||
func (m *mysqlTester) teardown() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mysqlTester) conn() *sql.DB {
|
||||
func (m *mysqlTester) conn() *sql.DB {
|
||||
return m.dbConn
|
||||
}
|
||||
|
|
|
@ -11,16 +11,18 @@ type pgTester struct {
|
|||
testDBName string
|
||||
}
|
||||
|
||||
dbMain = pgTester{}
|
||||
func init() {
|
||||
dbMain = &pgTester{}
|
||||
}
|
||||
|
||||
// disableTriggers is used to disable foreign key constraints for every table.
|
||||
// If this is not used we cannot test inserts due to foreign key constraint errors.
|
||||
func (p pgTester) disableTriggers() error {
|
||||
func (p *pgTester) disableTriggers() error {
|
||||
var stmts []string
|
||||
|
||||
{{range .Tables}}
|
||||
{{range .Tables -}}
|
||||
stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`)
|
||||
{{- end}}
|
||||
{{end -}}
|
||||
|
||||
if len(stmts) == 0 {
|
||||
return nil
|
||||
|
@ -38,19 +40,18 @@ func (p pgTester) disableTriggers() error {
|
|||
}
|
||||
|
||||
// teardown executes cleanup tasks when the tests finish running
|
||||
func (p pgTester) teardown() error {
|
||||
err := dropTestDB()
|
||||
return err
|
||||
func (p *pgTester) teardown() error {
|
||||
return p.dropTestDB()
|
||||
}
|
||||
|
||||
func (p pgTester) conn() *sql.DB {
|
||||
func (p *pgTester) conn() *sql.DB {
|
||||
return p.dbConn
|
||||
}
|
||||
|
||||
// dropTestDB switches its connection to the template1 database temporarily
|
||||
// so that it can drop the test database without causing "in use" conflicts.
|
||||
// The template1 database should be present on all default postgres installations.
|
||||
func (p pgTester) dropTestDB() error {
|
||||
func (p *pgTester) dropTestDB() error {
|
||||
var err error
|
||||
if p.dbConn != nil {
|
||||
if err = p.dbConn.Close(); err != nil {
|
||||
|
@ -58,12 +59,12 @@ func (p pgTester) dropTestDB() error {
|
|||
}
|
||||
}
|
||||
|
||||
p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
||||
p.dbConn, err = DBConnect(p.user, p.pass, "template1", p.host, p.port, p.sslmode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
|
||||
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, p.testDBName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -81,7 +82,7 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.
|
|||
// setup dumps the database schema and imports it into a temporary randomly
|
||||
// generated test database so that tests can be run against it using the
|
||||
// generated sqlboiler ORM package.
|
||||
func (p pgTester) setup() error {
|
||||
func (p *pgTester) setup() error {
|
||||
var err error
|
||||
|
||||
p.dbName = viper.GetString("postgres.dbname")
|
||||
|
@ -89,11 +90,11 @@ func (p pgTester) setup() error {
|
|||
p.user = viper.GetString("postgres.user")
|
||||
p.pass = viper.GetString("postgres.pass")
|
||||
p.port = viper.GetInt("postgres.port")
|
||||
p.sslmode = viper.GetString("postgres.dbname")
|
||||
p.sslmode = viper.GetString("postgres.sslmode")
|
||||
// Create a randomized db name.
|
||||
p.testDBName = getDBNameHash(p.dbname)
|
||||
p.testDBName = getDBNameHash(p.dbName)
|
||||
|
||||
err = dropTestDB()
|
||||
err = p.dropTestDB()
|
||||
if err != nil {
|
||||
fmt.Printf("%#v\n", err)
|
||||
return err
|
||||
|
@ -112,7 +113,7 @@ func (p pgTester) setup() error {
|
|||
defer os.RemoveAll(passDir)
|
||||
|
||||
// Write the postgres user password to a tmp file for pg_dump
|
||||
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user))
|
||||
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbName, p.user))
|
||||
|
||||
if len(p.pass) > 0 {
|
||||
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass))
|
||||
|
|
|
@ -17,27 +17,28 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
var err error
|
||||
|
||||
// Load configuration
|
||||
err = initViper()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to load config file")
|
||||
fmt.Println("unable to load config file")
|
||||
os.Exit(-2)
|
||||
}
|
||||
|
||||
setConfigDefaults()
|
||||
if err := validateConfig({{.DriverName}}); err != nil {
|
||||
if err := validateConfig("{{.DriverName}}"); err != nil {
|
||||
fmt.Println("failed to validate config", err)
|
||||
os.Exit(-2)
|
||||
os.Exit(-3)
|
||||
}
|
||||
|
||||
// Set DebugMode so we can see generated sql statements
|
||||
flag.Parse()
|
||||
boil.DebugMode = *flagDebugMode
|
||||
|
||||
var err error
|
||||
if err = dbMain.setup(); err != nil {
|
||||
fmt.Println("Unable to execute setup:", err)
|
||||
os.Exit(-3)
|
||||
os.Exit(-4)
|
||||
}
|
||||
|
||||
var code int
|
||||
|
@ -46,7 +47,7 @@ func TestMain(m *testing.M) {
|
|||
|
||||
if err = dbMain.teardown(); err != nil {
|
||||
fmt.Println("Unable to execute teardown:", err)
|
||||
os.Exit(-4)
|
||||
os.Exit(-5)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
|
@ -84,8 +85,8 @@ func initViper() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// setDefaults is only necessary because of bugs in viper, noted in main
|
||||
func setDefaults() {
|
||||
// setConfigDefaults is only necessary because of bugs in viper, noted in main
|
||||
func setConfigDefaults() {
|
||||
if viper.GetString("postgres.sslmode") == "" {
|
||||
viper.Set("postgres.sslmode", "require")
|
||||
}
|
||||
|
@ -101,35 +102,25 @@ func setDefaults() {
|
|||
}
|
||||
|
||||
func validateConfig(driverName string) error {
|
||||
if viper.IsSet("postgres.dbname") {
|
||||
err = vala.BeginValidation().Validate(
|
||||
if driverName == "postgres" {
|
||||
return vala.BeginValidation().Validate(
|
||||
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
|
||||
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
|
||||
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
|
||||
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
|
||||
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
|
||||
).Check()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if driverName == "postgres" {
|
||||
return errors.New("postgres driver requires a postgres section in your config file")
|
||||
}
|
||||
|
||||
if viper.IsSet("mysql.dbname") {
|
||||
err = vala.BeginValidation().Validate(
|
||||
if driverName == "mysql" {
|
||||
return vala.BeginValidation().Validate(
|
||||
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
|
||||
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
|
||||
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
|
||||
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
|
||||
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
|
||||
).Check()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if driverName == "mysql" {
|
||||
return errors.New("mysql driver requires a mysql section in your config file")
|
||||
}
|
||||
|
||||
return errors.New("not a valid driver name")
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue