sqlboiler/templates_test/singleton/boil_main_test.tpl

136 lines
3.3 KiB
Smarty
Raw Normal View History

// Package-level state shared by the generated tests.
var (
	// flagDebugMode is registered as -test.sqldebug. TestMain copies its value
	// into boil.DebugMode so generated SQL statements can be inspected.
	flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")

	// dbMain drives the lifecycle of the test database. TestMain aborts when
	// it is still nil. NOTE(review): presumably assigned by a driver-specific
	// generated file before TestMain runs — confirm.
	dbMain tester
)

// tester abstracts the driver-specific handling of the test database.
type tester interface {
	// setup is called once by TestMain before any test runs.
	setup() error
	// conn returns the database handle that TestMain hands to boil.SetDB.
	conn() *sql.DB
	// teardown is called once by TestMain after m.Run returns.
	teardown() error
}
func TestMain(m *testing.M) {
if dbMain == nil {
fmt.Println("no dbMain tester interface was ready")
os.Exit(-1)
}
rand.Seed(time.Now().UnixNano())
// Load configuration
err = initViper()
if err != nil {
return errors.Wrap(err, "Unable to load config file")
}
setConfigDefaults()
if err := validateConfig({{.DriverName}}); err != nil {
fmt.Println("failed to validate config", err)
os.Exit(-2)
}
// Set DebugMode so we can see generated sql statements
flag.Parse()
boil.DebugMode = *flagDebugMode
var err error
if err = dbMain.setup(); err != nil {
fmt.Println("Unable to execute setup:", err)
os.Exit(-3)
}
var code int
boil.SetDB(dbMain.conn())
code = m.Run()
if err = dbMain.teardown(); err != nil {
fmt.Println("Unable to execute teardown:", err)
os.Exit(-4)
}
os.Exit(code)
}
// initViper configures viper to look for a config file named "sqlboiler" and
// to read overrides from environment variables.
//
// Config search paths, in order:
//   - the parent of the current working directory (or ".." when the working
//     directory cannot be determined);
//   - $XDG_CONFIG_HOME/sqlboiler when XDG_CONFIG_HOME is set, otherwise
//     $HOME/.config/sqlboiler.
//
// A missing or unreadable config file is deliberately not treated as an
// error: defaults and validation are relied on to report anything that is
// actually missing. The returned error is currently always nil. It is kept so
// that callers need not change if loading ever becomes fallible.
func initViper() error {
	viper.SetConfigName("sqlboiler")

	// NOTE(review): presumably the generated test package lives one level
	// below the project root, which is why the parent directory is searched.
	parentDir := ".."
	if wd, err := os.Getwd(); err == nil {
		parentDir = filepath.Join(wd, "..")
	}
	configPaths := []string{parentDir}

	if configHome := os.Getenv("XDG_CONFIG_HOME"); configHome != "" {
		configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
	} else {
		configPaths = append(configPaths, filepath.Join(os.Getenv("HOME"), ".config", "sqlboiler"))
	}

	for _, p := range configPaths {
		viper.AddConfigPath(p)
	}

	// Ignore errors here, fall back to defaults and validation to provide errs
	_ = viper.ReadInConfig()
	viper.AutomaticEnv()

	return nil
}
// setDefaults fills in connection settings that are still empty (strings) or
// zero (ints) after the config has been loaded:
//
//	postgres.sslmode = "require"   postgres.port = 5432
//	mysql.sslmode    = "true"      mysql.port    = 3306
//
// Keys that already hold a non-empty / non-zero value are left untouched.
// The original note says this is only necessary because of bugs in viper,
// noted in main. NOTE(review): that note is not visible here — confirm.
func setDefaults() {
	stringDefaults := []struct{ key, value string }{
		{"postgres.sslmode", "require"},
		{"mysql.sslmode", "true"},
	}
	for _, d := range stringDefaults {
		if viper.GetString(d.key) == "" {
			viper.Set(d.key, d.value)
		}
	}

	intDefaults := []struct {
		key   string
		value int
	}{
		{"postgres.port", 5432},
		{"mysql.port", 3306},
	}
	for _, d := range intDefaults {
		if viper.GetInt(d.key) == 0 {
			viper.Set(d.key, d.value)
		}
	}
}
// validateConfig checks the database sections of the loaded viper config.
//
// For each known section ("postgres", then "mysql") whose dbname key is set,
// the section's user, host, dbname and sslmode must be non-empty and its port
// must be non-zero. A section without a dbname is skipped, unless it is the
// section for driverName; in that case an error reports the missing section.
//
// The first failure encountered is returned; nil means the config is usable.
func validateConfig(driverName string) error {
	for _, section := range []string{"postgres", "mysql"} {
		if !viper.IsSet(section + ".dbname") {
			// An absent section only matters for the driver under test.
			if driverName == section {
				return errors.New(section + " driver requires a " + section + " section in your config file")
			}
			continue
		}

		key := func(field string) string { return section + "." + field }
		err := vala.BeginValidation().Validate(
			vala.StringNotEmpty(viper.GetString(key("user")), key("user")),
			vala.StringNotEmpty(viper.GetString(key("host")), key("host")),
			vala.Not(vala.Equals(viper.GetInt(key("port")), 0, key("port"))),
			vala.StringNotEmpty(viper.GetString(key("dbname")), key("dbname")),
			vala.StringNotEmpty(viper.GetString(key("sslmode")), key("sslmode")),
		).Check()
		if err != nil {
			return err
		}
	}

	return nil
}