Refactor all the bits.
- Make TestMain be driver-based - Move config to TestMain file - Make config a little more sane in pgmain
This commit is contained in:
parent
9bcaf51493
commit
9d29d2b946
5 changed files with 227 additions and 159 deletions
23
imports.go
23
imports.go
|
@ -239,6 +239,29 @@ var defaultTestMainImports = map[string]imports{
|
||||||
`_ "github.com/lib/pq"`,
|
`_ "github.com/lib/pq"`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"mysql": {
|
||||||
|
standard: importList{
|
||||||
|
`"testing"`,
|
||||||
|
`"os"`,
|
||||||
|
`"os/exec"`,
|
||||||
|
`"flag"`,
|
||||||
|
`"fmt"`,
|
||||||
|
`"io/ioutil"`,
|
||||||
|
`"bytes"`,
|
||||||
|
`"database/sql"`,
|
||||||
|
`"path/filepath"`,
|
||||||
|
`"time"`,
|
||||||
|
`"math/rand"`,
|
||||||
|
},
|
||||||
|
thirdParty: importList{
|
||||||
|
`"github.com/kat-co/vala"`,
|
||||||
|
`"github.com/pkg/errors"`,
|
||||||
|
`"github.com/spf13/viper"`,
|
||||||
|
`"github.com/vattle/sqlboiler/boil"`,
|
||||||
|
`"github.com/vattle/sqlboiler/bdb/drivers"`,
|
||||||
|
`_ "github.com/go-mysql-driver/mysql"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// importsBasedOnType imports are only included in the template output if the
|
// importsBasedOnType imports are only included in the template output if the
|
||||||
|
|
|
@ -1,2 +1,17 @@
|
||||||
func TestMain(m *testing.M) {
|
type mysqlTester struct {
|
||||||
|
dbConn *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
dbMain = mysqlTester{}
|
||||||
|
|
||||||
|
func (m mysqlTester) setup() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mysqlTester) teardown() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mysqlTester) conn() *sql.DB {
|
||||||
|
return m.dbConn
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,50 +1,21 @@
|
||||||
type PostgresCfg struct {
|
type pgTester struct {
|
||||||
User string `toml:"user"`
|
dbConn *sql.DB
|
||||||
Pass string `toml:"pass"`
|
|
||||||
Host string `toml:"host"`
|
dbName string
|
||||||
Port int `toml:"port"`
|
host string
|
||||||
DBName string `toml:"dbname"`
|
user string
|
||||||
SSLMode string `toml:"sslmode"`
|
pass string
|
||||||
|
sslmode string
|
||||||
|
port int
|
||||||
|
|
||||||
|
testDBName string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
dbMain = pgTester{}
|
||||||
Postgres PostgresCfg `toml:"postgres"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
|
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
|
|
||||||
// Set DebugMode so we can see generated sql statements
|
|
||||||
flag.Parse()
|
|
||||||
boil.DebugMode = *flagDebugMode
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if err = setup(); err != nil {
|
|
||||||
fmt.Println("Unable to execute setup:", err)
|
|
||||||
os.Exit(-2)
|
|
||||||
}
|
|
||||||
|
|
||||||
var code int
|
|
||||||
if err = disableTriggers(); err != nil {
|
|
||||||
fmt.Println("Unable to disable triggers:", err)
|
|
||||||
} else {
|
|
||||||
boil.SetDB(dbConn)
|
|
||||||
code = m.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = teardown(); err != nil {
|
|
||||||
fmt.Println("Unable to execute teardown:", err)
|
|
||||||
os.Exit(-3)
|
|
||||||
}
|
|
||||||
|
|
||||||
os.Exit(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// disableTriggers is used to disable foreign key constraints for every table.
|
// disableTriggers is used to disable foreign key constraints for every table.
|
||||||
// If this is not used we cannot test inserts due to foreign key constraint errors.
|
// If this is not used we cannot test inserts due to foreign key constraint errors.
|
||||||
func disableTriggers() error {
|
func (p pgTester) disableTriggers() error {
|
||||||
var stmts []string
|
var stmts []string
|
||||||
|
|
||||||
{{range .Tables}}
|
{{range .Tables}}
|
||||||
|
@ -57,7 +28,7 @@ func disableTriggers() error {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
for _, s := range stmts {
|
for _, s := range stmts {
|
||||||
_, err = dbConn.Exec(s)
|
_, err = p.dbConn.Exec(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -67,33 +38,37 @@ func disableTriggers() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// teardown executes cleanup tasks when the tests finish running
|
// teardown executes cleanup tasks when the tests finish running
|
||||||
func teardown() error {
|
func (p pgTester) teardown() error {
|
||||||
err := dropTestDB()
|
err := dropTestDB()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p pgTester) conn() *sql.DB {
|
||||||
|
return p.dbConn
|
||||||
|
}
|
||||||
|
|
||||||
// dropTestDB switches its connection to the template1 database temporarily
|
// dropTestDB switches its connection to the template1 database temporarily
|
||||||
// so that it can drop the test database without causing "in use" conflicts.
|
// so that it can drop the test database without causing "in use" conflicts.
|
||||||
// The template1 database should be present on all default postgres installations.
|
// The template1 database should be present on all default postgres installations.
|
||||||
func dropTestDB() error {
|
func (p pgTester) dropTestDB() error {
|
||||||
var err error
|
var err error
|
||||||
if dbConn != nil {
|
if p.dbConn != nil {
|
||||||
if err = dbConn.Close(); err != nil {
|
if err = p.dbConn.Close(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
|
_, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return dbConn.Close()
|
return p.dbConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DBConnect connects to a database and returns the handle.
|
// DBConnect connects to a database and returns the handle.
|
||||||
|
@ -106,43 +81,17 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.
|
||||||
// setup dumps the database schema and imports it into a temporary randomly
|
// setup dumps the database schema and imports it into a temporary randomly
|
||||||
// generated test database so that tests can be run against it using the
|
// generated test database so that tests can be run against it using the
|
||||||
// generated sqlboiler ORM package.
|
// generated sqlboiler ORM package.
|
||||||
func setup() error {
|
func (p pgTester) setup() error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Initialize Viper and load the config file
|
p.dbName = viper.GetString("postgres.dbname")
|
||||||
err = InitViper()
|
p.host = viper.GetString("postgres.host")
|
||||||
if err != nil {
|
p.user = viper.GetString("postgres.user")
|
||||||
return errors.Wrap(err, "Unable to load config file")
|
p.pass = viper.GetString("postgres.pass")
|
||||||
}
|
p.port = viper.GetInt("postgres.port")
|
||||||
|
p.sslmode = viper.GetString("postgres.dbname")
|
||||||
viper.SetDefault("postgres.sslmode", "require")
|
// Create a randomized db name.
|
||||||
viper.SetDefault("postgres.port", "5432")
|
p.testDBName = getDBNameHash(p.dbname)
|
||||||
|
|
||||||
// Create a randomized test configuration object.
|
|
||||||
testCfg.Postgres.Host = viper.GetString("postgres.host")
|
|
||||||
testCfg.Postgres.Port = viper.GetInt("postgres.port")
|
|
||||||
testCfg.Postgres.User = viper.GetString("postgres.user")
|
|
||||||
testCfg.Postgres.Pass = viper.GetString("postgres.pass")
|
|
||||||
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
|
|
||||||
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
|
|
||||||
|
|
||||||
// Set the default SSLMode value
|
|
||||||
if testCfg.Postgres.SSLMode == "" {
|
|
||||||
viper.Set("postgres.sslmode", "require")
|
|
||||||
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = vala.BeginValidation().Validate(
|
|
||||||
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
|
|
||||||
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
|
|
||||||
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
|
|
||||||
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
|
|
||||||
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
|
|
||||||
).Check()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Unable to load testCfg")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dropTestDB()
|
err = dropTestDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -163,15 +112,10 @@ func setup() error {
|
||||||
defer os.RemoveAll(passDir)
|
defer os.RemoveAll(passDir)
|
||||||
|
|
||||||
// Write the postgres user password to a tmp file for pg_dump
|
// Write the postgres user password to a tmp file for pg_dump
|
||||||
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
|
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user))
|
||||||
viper.GetString("postgres.host"),
|
|
||||||
viper.GetInt("postgres.port"),
|
|
||||||
viper.GetString("postgres.dbname"),
|
|
||||||
viper.GetString("postgres.user"),
|
|
||||||
))
|
|
||||||
|
|
||||||
if pw := viper.GetString("postgres.pass"); len(pw) > 0 {
|
if len(p.pass) > 0 {
|
||||||
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw))
|
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass))
|
||||||
}
|
}
|
||||||
|
|
||||||
passFilePath := filepath.Join(passDir, "pwfile")
|
passFilePath := filepath.Join(passDir, "pwfile")
|
||||||
|
@ -183,11 +127,11 @@ func setup() error {
|
||||||
|
|
||||||
// The params for the pg_dump command to dump the database schema
|
// The params for the pg_dump command to dump the database schema
|
||||||
params := []string{
|
params := []string{
|
||||||
fmt.Sprintf(`--host=%s`, viper.GetString("postgres.host")),
|
fmt.Sprintf(`--host=%s`, p.host),
|
||||||
fmt.Sprintf(`--port=%d`, viper.GetInt("postgres.port")),
|
fmt.Sprintf(`--port=%d`, p.port),
|
||||||
fmt.Sprintf(`--username=%s`, viper.GetString("postgres.user")),
|
fmt.Sprintf(`--username=%s`, p.user),
|
||||||
"--schema-only",
|
"--schema-only",
|
||||||
viper.GetString("postgres.dbname"),
|
p.dbName,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dump the database schema into the sqlboilerschema tmp file
|
// Dump the database schema into the sqlboilerschema tmp file
|
||||||
|
@ -202,45 +146,33 @@ func setup() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dbConn, err = DBConnect(
|
p.dbConn, err = DBConnect(p.user, p.pass, p.dbName, p.host, p.port, p.sslmode)
|
||||||
viper.GetString("postgres.user"),
|
|
||||||
viper.GetString("postgres.pass"),
|
|
||||||
viper.GetString("postgres.dbname"),
|
|
||||||
viper.GetString("postgres.host"),
|
|
||||||
viper.GetInt("postgres.port"),
|
|
||||||
viper.GetString("postgres.sslmode"),
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the randomly generated database
|
// Create the randomly generated database
|
||||||
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName))
|
_, err = p.dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, p.testDBName))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the old connection so we can reconnect to the test database
|
// Close the old connection so we can reconnect to the test database
|
||||||
if err = dbConn.Close(); err != nil {
|
if err = p.dbConn.Close(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to the generated test db
|
// Connect to the generated test db
|
||||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
p.dbConn, err = DBConnect(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write the test config credentials to a tmp file for pg_dump
|
// Write the test config credentials to a tmp file for pg_dump
|
||||||
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
|
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user))
|
||||||
testCfg.Postgres.Host,
|
|
||||||
testCfg.Postgres.Port,
|
|
||||||
testCfg.Postgres.DBName,
|
|
||||||
testCfg.Postgres.User,
|
|
||||||
))
|
|
||||||
|
|
||||||
if len(testCfg.Postgres.Pass) > 0 {
|
if len(p.pass) > 0 {
|
||||||
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass))
|
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, p.pass))
|
||||||
}
|
}
|
||||||
|
|
||||||
testPassFilePath := passDir + "/testpwfile"
|
testPassFilePath := passDir + "/testpwfile"
|
||||||
|
@ -252,10 +184,10 @@ func setup() error {
|
||||||
|
|
||||||
// The params for the psql schema import command
|
// The params for the psql schema import command
|
||||||
params = []string{
|
params = []string{
|
||||||
fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName),
|
fmt.Sprintf(`--dbname=%s`, p.testDBName),
|
||||||
fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host),
|
fmt.Sprintf(`--host=%s`, p.host),
|
||||||
fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port),
|
fmt.Sprintf(`--port=%d`, p.port),
|
||||||
fmt.Sprintf(`--username=%s`, testCfg.Postgres.User),
|
fmt.Sprintf(`--username=%s`, p.user),
|
||||||
fmt.Sprintf(`--file=%s`, fhSchema.Name()),
|
fmt.Sprintf(`--file=%s`, fhSchema.Name()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -271,5 +203,5 @@ func setup() error {
|
||||||
fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String())
|
fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return p.disableTriggers()
|
||||||
}
|
}
|
||||||
|
|
135
templates_test/singleton/boil_main_test.tpl
Normal file
135
templates_test/singleton/boil_main_test.tpl
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")
|
||||||
|
|
||||||
|
var (
|
||||||
|
dbMain tester
|
||||||
|
)
|
||||||
|
|
||||||
|
type tester interface {
|
||||||
|
setup() error
|
||||||
|
conn() *sql.DB
|
||||||
|
teardown() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
if dbMain == nil {
|
||||||
|
fmt.Println("no dbMain tester interface was ready")
|
||||||
|
os.Exit(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
|
||||||
|
// Load configuration
|
||||||
|
err = initViper()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Unable to load config file")
|
||||||
|
}
|
||||||
|
|
||||||
|
setConfigDefaults()
|
||||||
|
if err := validateConfig({{.DriverName}}); err != nil {
|
||||||
|
fmt.Println("failed to validate config", err)
|
||||||
|
os.Exit(-2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set DebugMode so we can see generated sql statements
|
||||||
|
flag.Parse()
|
||||||
|
boil.DebugMode = *flagDebugMode
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if err = dbMain.setup(); err != nil {
|
||||||
|
fmt.Println("Unable to execute setup:", err)
|
||||||
|
os.Exit(-3)
|
||||||
|
}
|
||||||
|
|
||||||
|
var code int
|
||||||
|
boil.SetDB(dbMain.conn())
|
||||||
|
code = m.Run()
|
||||||
|
|
||||||
|
if err = dbMain.teardown(); err != nil {
|
||||||
|
fmt.Println("Unable to execute teardown:", err)
|
||||||
|
os.Exit(-4)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func initViper() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
viper.SetConfigName("sqlboiler")
|
||||||
|
|
||||||
|
configHome := os.Getenv("XDG_CONFIG_HOME")
|
||||||
|
homePath := os.Getenv("HOME")
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
wd = "../"
|
||||||
|
} else {
|
||||||
|
wd = wd + "/.."
|
||||||
|
}
|
||||||
|
|
||||||
|
configPaths := []string{wd}
|
||||||
|
if len(configHome) > 0 {
|
||||||
|
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
|
||||||
|
} else {
|
||||||
|
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range configPaths {
|
||||||
|
viper.AddConfigPath(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore errors here, fall back to defaults and validation to provide errs
|
||||||
|
_ = viper.ReadInConfig()
|
||||||
|
viper.AutomaticEnv()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setDefaults is only necessary because of bugs in viper, noted in main
|
||||||
|
func setDefaults() {
|
||||||
|
if viper.GetString("postgres.sslmode") == "" {
|
||||||
|
viper.Set("postgres.sslmode", "require")
|
||||||
|
}
|
||||||
|
if viper.GetInt("postgres.port") == 0 {
|
||||||
|
viper.Set("postgres.port", 5432)
|
||||||
|
}
|
||||||
|
if viper.GetString("mysql.sslmode") == "" {
|
||||||
|
viper.Set("mysql.sslmode", "true")
|
||||||
|
}
|
||||||
|
if viper.GetInt("mysql.port") == 0 {
|
||||||
|
viper.Set("mysql.port", 3306)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateConfig(driverName string) error {
|
||||||
|
if viper.IsSet("postgres.dbname") {
|
||||||
|
err = vala.BeginValidation().Validate(
|
||||||
|
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
|
||||||
|
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
|
||||||
|
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
|
||||||
|
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
|
||||||
|
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
|
||||||
|
).Check()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if driverName == "postgres" {
|
||||||
|
return errors.New("postgres driver requires a postgres section in your config file")
|
||||||
|
}
|
||||||
|
|
||||||
|
if viper.IsSet("mysql.dbname") {
|
||||||
|
err = vala.BeginValidation().Validate(
|
||||||
|
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
|
||||||
|
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
|
||||||
|
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
|
||||||
|
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
|
||||||
|
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
|
||||||
|
).Check()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if driverName == "mysql" {
|
||||||
|
return errors.New("mysql driver requires a mysql section in your config file")
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,37 +0,0 @@
|
||||||
var (
|
|
||||||
testCfg *Config
|
|
||||||
dbConn *sql.DB
|
|
||||||
)
|
|
||||||
|
|
||||||
func InitViper() error {
|
|
||||||
var err error
|
|
||||||
testCfg = &Config{}
|
|
||||||
|
|
||||||
viper.SetConfigName("sqlboiler")
|
|
||||||
|
|
||||||
configHome := os.Getenv("XDG_CONFIG_HOME")
|
|
||||||
homePath := os.Getenv("HOME")
|
|
||||||
wd, err := os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
wd = "../"
|
|
||||||
} else {
|
|
||||||
wd = wd + "/.."
|
|
||||||
}
|
|
||||||
|
|
||||||
configPaths := []string{wd}
|
|
||||||
if len(configHome) > 0 {
|
|
||||||
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
|
|
||||||
} else {
|
|
||||||
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range configPaths {
|
|
||||||
viper.AddConfigPath(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ignore errors here, fall back to defaults and validation to provide errs
|
|
||||||
_ = viper.ReadInConfig()
|
|
||||||
viper.AutomaticEnv()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
Loading…
Reference in a new issue