Postgres works again after refactor
This commit is contained in:
parent
d1ea925523
commit
d183ec4bb5
5 changed files with 49 additions and 59 deletions
17
imports.go
17
imports.go
|
@ -186,14 +186,22 @@ var defaultTestTemplateImports = imports{
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultSingletonTestTemplateImports = map[string]imports{
|
var defaultSingletonTestTemplateImports = map[string]imports{
|
||||||
"boil_viper_test": {
|
"boil_main_test": {
|
||||||
standard: importList{
|
standard: importList{
|
||||||
`"database/sql"`,
|
`"database/sql"`,
|
||||||
|
`"flag"`,
|
||||||
|
`"fmt"`,
|
||||||
|
`"math/rand"`,
|
||||||
`"os"`,
|
`"os"`,
|
||||||
`"path/filepath"`,
|
`"path/filepath"`,
|
||||||
|
`"testing"`,
|
||||||
|
`"time"`,
|
||||||
},
|
},
|
||||||
thirdParty: importList{
|
thirdParty: importList{
|
||||||
|
`"github.com/kat-co/vala"`,
|
||||||
|
`"github.com/pkg/errors"`,
|
||||||
`"github.com/spf13/viper"`,
|
`"github.com/spf13/viper"`,
|
||||||
|
`"github.com/vattle/sqlboiler/boil"`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"boil_queries_test": {
|
"boil_queries_test": {
|
||||||
|
@ -218,23 +226,17 @@ var defaultSingletonTestTemplateImports = map[string]imports{
|
||||||
var defaultTestMainImports = map[string]imports{
|
var defaultTestMainImports = map[string]imports{
|
||||||
"postgres": {
|
"postgres": {
|
||||||
standard: importList{
|
standard: importList{
|
||||||
`"testing"`,
|
|
||||||
`"os"`,
|
`"os"`,
|
||||||
`"os/exec"`,
|
`"os/exec"`,
|
||||||
`"flag"`,
|
|
||||||
`"fmt"`,
|
`"fmt"`,
|
||||||
`"io/ioutil"`,
|
`"io/ioutil"`,
|
||||||
`"bytes"`,
|
`"bytes"`,
|
||||||
`"database/sql"`,
|
`"database/sql"`,
|
||||||
`"path/filepath"`,
|
`"path/filepath"`,
|
||||||
`"time"`,
|
|
||||||
`"math/rand"`,
|
|
||||||
},
|
},
|
||||||
thirdParty: importList{
|
thirdParty: importList{
|
||||||
`"github.com/kat-co/vala"`,
|
|
||||||
`"github.com/pkg/errors"`,
|
`"github.com/pkg/errors"`,
|
||||||
`"github.com/spf13/viper"`,
|
`"github.com/spf13/viper"`,
|
||||||
`"github.com/vattle/sqlboiler/boil"`,
|
|
||||||
`"github.com/vattle/sqlboiler/bdb/drivers"`,
|
`"github.com/vattle/sqlboiler/bdb/drivers"`,
|
||||||
`_ "github.com/lib/pq"`,
|
`_ "github.com/lib/pq"`,
|
||||||
},
|
},
|
||||||
|
@ -254,7 +256,6 @@ var defaultTestMainImports = map[string]imports{
|
||||||
`"math/rand"`,
|
`"math/rand"`,
|
||||||
},
|
},
|
||||||
thirdParty: importList{
|
thirdParty: importList{
|
||||||
`"github.com/kat-co/vala"`,
|
|
||||||
`"github.com/pkg/errors"`,
|
`"github.com/pkg/errors"`,
|
||||||
`"github.com/spf13/viper"`,
|
`"github.com/spf13/viper"`,
|
||||||
`"github.com/vattle/sqlboiler/boil"`,
|
`"github.com/vattle/sqlboiler/boil"`,
|
||||||
|
|
9
main.go
9
main.go
|
@ -2,7 +2,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -148,7 +147,7 @@ func preRun(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if viper.IsSet("postgres.dbname") {
|
if driverName == "postgres" {
|
||||||
cmdConfig.Postgres = PostgresConfig{
|
cmdConfig.Postgres = PostgresConfig{
|
||||||
User: viper.GetString("postgres.user"),
|
User: viper.GetString("postgres.user"),
|
||||||
Pass: viper.GetString("postgres.pass"),
|
Pass: viper.GetString("postgres.pass"),
|
||||||
|
@ -182,11 +181,9 @@ func preRun(cmd *cobra.Command, args []string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return commandFailure(err.Error())
|
return commandFailure(err.Error())
|
||||||
}
|
}
|
||||||
} else if driverName == "postgres" {
|
|
||||||
return errors.New("postgres driver requires a postgres section in your config file")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if viper.IsSet("mysql.dbname") {
|
if driverName == "mysql" {
|
||||||
cmdConfig.MySQL = MySQLConfig{
|
cmdConfig.MySQL = MySQLConfig{
|
||||||
User: viper.GetString("mysql.user"),
|
User: viper.GetString("mysql.user"),
|
||||||
Pass: viper.GetString("mysql.pass"),
|
Pass: viper.GetString("mysql.pass"),
|
||||||
|
@ -223,8 +220,6 @@ func preRun(cmd *cobra.Command, args []string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return commandFailure(err.Error())
|
return commandFailure(err.Error())
|
||||||
}
|
}
|
||||||
} else if driverName == "mysql" {
|
|
||||||
return errors.New("mysql driver requires a mysql section in your config file")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdState, err = New(cmdConfig)
|
cmdState, err = New(cmdConfig)
|
||||||
|
|
|
@ -2,16 +2,18 @@ type mysqlTester struct {
|
||||||
dbConn *sql.DB
|
dbConn *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
dbMain = mysqlTester{}
|
func init() {
|
||||||
|
dbMain = &mysqlTester{}
|
||||||
|
}
|
||||||
|
|
||||||
func (m mysqlTester) setup() error {
|
func (m *mysqlTester) setup() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m mysqlTester) teardown() error {
|
func (m *mysqlTester) teardown() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m mysqlTester) conn() *sql.DB {
|
func (m *mysqlTester) conn() *sql.DB {
|
||||||
return m.dbConn
|
return m.dbConn
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,16 +11,18 @@ type pgTester struct {
|
||||||
testDBName string
|
testDBName string
|
||||||
}
|
}
|
||||||
|
|
||||||
dbMain = pgTester{}
|
func init() {
|
||||||
|
dbMain = &pgTester{}
|
||||||
|
}
|
||||||
|
|
||||||
// disableTriggers is used to disable foreign key constraints for every table.
|
// disableTriggers is used to disable foreign key constraints for every table.
|
||||||
// If this is not used we cannot test inserts due to foreign key constraint errors.
|
// If this is not used we cannot test inserts due to foreign key constraint errors.
|
||||||
func (p pgTester) disableTriggers() error {
|
func (p *pgTester) disableTriggers() error {
|
||||||
var stmts []string
|
var stmts []string
|
||||||
|
|
||||||
{{range .Tables}}
|
{{range .Tables -}}
|
||||||
stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`)
|
stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`)
|
||||||
{{- end}}
|
{{end -}}
|
||||||
|
|
||||||
if len(stmts) == 0 {
|
if len(stmts) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
@ -38,19 +40,18 @@ func (p pgTester) disableTriggers() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// teardown executes cleanup tasks when the tests finish running
|
// teardown executes cleanup tasks when the tests finish running
|
||||||
func (p pgTester) teardown() error {
|
func (p *pgTester) teardown() error {
|
||||||
err := dropTestDB()
|
return p.dropTestDB()
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p pgTester) conn() *sql.DB {
|
func (p *pgTester) conn() *sql.DB {
|
||||||
return p.dbConn
|
return p.dbConn
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropTestDB switches its connection to the template1 database temporarily
|
// dropTestDB switches its connection to the template1 database temporarily
|
||||||
// so that it can drop the test database without causing "in use" conflicts.
|
// so that it can drop the test database without causing "in use" conflicts.
|
||||||
// The template1 database should be present on all default postgres installations.
|
// The template1 database should be present on all default postgres installations.
|
||||||
func (p pgTester) dropTestDB() error {
|
func (p *pgTester) dropTestDB() error {
|
||||||
var err error
|
var err error
|
||||||
if p.dbConn != nil {
|
if p.dbConn != nil {
|
||||||
if err = p.dbConn.Close(); err != nil {
|
if err = p.dbConn.Close(); err != nil {
|
||||||
|
@ -58,12 +59,12 @@ func (p pgTester) dropTestDB() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
p.dbConn, err = DBConnect(p.user, p.pass, "template1", p.host, p.port, p.sslmode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
|
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, p.testDBName))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -81,7 +82,7 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.
|
||||||
// setup dumps the database schema and imports it into a temporary randomly
|
// setup dumps the database schema and imports it into a temporary randomly
|
||||||
// generated test database so that tests can be run against it using the
|
// generated test database so that tests can be run against it using the
|
||||||
// generated sqlboiler ORM package.
|
// generated sqlboiler ORM package.
|
||||||
func (p pgTester) setup() error {
|
func (p *pgTester) setup() error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
p.dbName = viper.GetString("postgres.dbname")
|
p.dbName = viper.GetString("postgres.dbname")
|
||||||
|
@ -89,11 +90,11 @@ func (p pgTester) setup() error {
|
||||||
p.user = viper.GetString("postgres.user")
|
p.user = viper.GetString("postgres.user")
|
||||||
p.pass = viper.GetString("postgres.pass")
|
p.pass = viper.GetString("postgres.pass")
|
||||||
p.port = viper.GetInt("postgres.port")
|
p.port = viper.GetInt("postgres.port")
|
||||||
p.sslmode = viper.GetString("postgres.dbname")
|
p.sslmode = viper.GetString("postgres.sslmode")
|
||||||
// Create a randomized db name.
|
// Create a randomized db name.
|
||||||
p.testDBName = getDBNameHash(p.dbname)
|
p.testDBName = getDBNameHash(p.dbName)
|
||||||
|
|
||||||
err = dropTestDB()
|
err = p.dropTestDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("%#v\n", err)
|
fmt.Printf("%#v\n", err)
|
||||||
return err
|
return err
|
||||||
|
@ -112,7 +113,7 @@ func (p pgTester) setup() error {
|
||||||
defer os.RemoveAll(passDir)
|
defer os.RemoveAll(passDir)
|
||||||
|
|
||||||
// Write the postgres user password to a tmp file for pg_dump
|
// Write the postgres user password to a tmp file for pg_dump
|
||||||
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user))
|
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbName, p.user))
|
||||||
|
|
||||||
if len(p.pass) > 0 {
|
if len(p.pass) > 0 {
|
||||||
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass))
|
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass))
|
||||||
|
|
|
@ -17,27 +17,28 @@ func TestMain(m *testing.M) {
|
||||||
}
|
}
|
||||||
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
var err error
|
||||||
|
|
||||||
// Load configuration
|
// Load configuration
|
||||||
err = initViper()
|
err = initViper()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Unable to load config file")
|
fmt.Println("unable to load config file")
|
||||||
|
os.Exit(-2)
|
||||||
}
|
}
|
||||||
|
|
||||||
setConfigDefaults()
|
setConfigDefaults()
|
||||||
if err := validateConfig({{.DriverName}}); err != nil {
|
if err := validateConfig("{{.DriverName}}"); err != nil {
|
||||||
fmt.Println("failed to validate config", err)
|
fmt.Println("failed to validate config", err)
|
||||||
os.Exit(-2)
|
os.Exit(-3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set DebugMode so we can see generated sql statements
|
// Set DebugMode so we can see generated sql statements
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
boil.DebugMode = *flagDebugMode
|
boil.DebugMode = *flagDebugMode
|
||||||
|
|
||||||
var err error
|
|
||||||
if err = dbMain.setup(); err != nil {
|
if err = dbMain.setup(); err != nil {
|
||||||
fmt.Println("Unable to execute setup:", err)
|
fmt.Println("Unable to execute setup:", err)
|
||||||
os.Exit(-3)
|
os.Exit(-4)
|
||||||
}
|
}
|
||||||
|
|
||||||
var code int
|
var code int
|
||||||
|
@ -46,7 +47,7 @@ func TestMain(m *testing.M) {
|
||||||
|
|
||||||
if err = dbMain.teardown(); err != nil {
|
if err = dbMain.teardown(); err != nil {
|
||||||
fmt.Println("Unable to execute teardown:", err)
|
fmt.Println("Unable to execute teardown:", err)
|
||||||
os.Exit(-4)
|
os.Exit(-5)
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
|
@ -84,8 +85,8 @@ func initViper() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setDefaults is only necessary because of bugs in viper, noted in main
|
// setConfigDefaults is only necessary because of bugs in viper, noted in main
|
||||||
func setDefaults() {
|
func setConfigDefaults() {
|
||||||
if viper.GetString("postgres.sslmode") == "" {
|
if viper.GetString("postgres.sslmode") == "" {
|
||||||
viper.Set("postgres.sslmode", "require")
|
viper.Set("postgres.sslmode", "require")
|
||||||
}
|
}
|
||||||
|
@ -101,35 +102,25 @@ func setDefaults() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateConfig(driverName string) error {
|
func validateConfig(driverName string) error {
|
||||||
if viper.IsSet("postgres.dbname") {
|
if driverName == "postgres" {
|
||||||
err = vala.BeginValidation().Validate(
|
return vala.BeginValidation().Validate(
|
||||||
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
|
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
|
||||||
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
|
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
|
||||||
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
|
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
|
||||||
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
|
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
|
||||||
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
|
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
|
||||||
).Check()
|
).Check()
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if driverName == "postgres" {
|
|
||||||
return errors.New("postgres driver requires a postgres section in your config file")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if viper.IsSet("mysql.dbname") {
|
if driverName == "mysql" {
|
||||||
err = vala.BeginValidation().Validate(
|
return vala.BeginValidation().Validate(
|
||||||
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
|
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
|
||||||
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
|
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
|
||||||
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
|
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
|
||||||
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
|
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
|
||||||
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
|
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
|
||||||
).Check()
|
).Check()
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if driverName == "mysql" {
|
|
||||||
return errors.New("mysql driver requires a mysql section in your config file")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return errors.New("not a valid driver name")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue