147 lines
3.8 KiB
Smarty
147 lines
3.8 KiB
Smarty
// tester abstracts the driver-specific lifecycle of the test database.
// TestMain calls setup before the tests run, uses conn to obtain the
// *sql.DB handed to boil.SetDB, and calls teardown after m.Run returns.
type tester interface {
	setup() error
	conn() (*sql.DB, error)
	teardown() error
}

var (
	// flagDebugMode backs the -test.sqldebug flag; TestMain copies it into
	// boil.DebugMode so generated SQL statements can be inspected.
	flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements")

	// dbMain is nil in this file; presumably a driver-specific test file
	// assigns it (TestMain exits early if it is still nil).
	dbMain tester
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
if dbMain == nil {
|
|
fmt.Println("no dbMain tester interface was ready")
|
|
os.Exit(-1)
|
|
}
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
var err error
|
|
|
|
// Load configuration
|
|
err = initViper()
|
|
if err != nil {
|
|
fmt.Println("unable to load config file")
|
|
os.Exit(-2)
|
|
}
|
|
|
|
setConfigDefaults()
|
|
if err := validateConfig("{{.DriverName}}"); err != nil {
|
|
fmt.Println("failed to validate config", err)
|
|
os.Exit(-3)
|
|
}
|
|
|
|
// Set DebugMode so we can see generated sql statements
|
|
flag.Parse()
|
|
boil.DebugMode = *flagDebugMode
|
|
|
|
if err = dbMain.setup(); err != nil {
|
|
fmt.Println("Unable to execute setup:", err)
|
|
os.Exit(-4)
|
|
}
|
|
|
|
conn, err := dbMain.conn()
|
|
if err != nil {
|
|
fmt.Println("failed to get connection:", err)
|
|
}
|
|
|
|
var code int
|
|
boil.SetDB(conn)
|
|
code = m.Run()
|
|
|
|
if err = dbMain.teardown(); err != nil {
|
|
fmt.Println("Unable to execute teardown:", err)
|
|
os.Exit(-5)
|
|
}
|
|
|
|
os.Exit(code)
|
|
}
|
|
|
|
func initViper() error {
|
|
var err error
|
|
|
|
viper.SetConfigName("sqlboiler")
|
|
|
|
configHome := os.Getenv("XDG_CONFIG_HOME")
|
|
homePath := os.Getenv("HOME")
|
|
wd, err := os.Getwd()
|
|
if err != nil {
|
|
wd = "../"
|
|
} else {
|
|
wd = wd + "/.."
|
|
}
|
|
|
|
configPaths := []string{wd}
|
|
if len(configHome) > 0 {
|
|
configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
|
|
} else {
|
|
configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
|
|
}
|
|
|
|
for _, p := range configPaths {
|
|
viper.AddConfigPath(p)
|
|
}
|
|
|
|
// Ignore errors here, fall back to defaults and validation to provide errs
|
|
_ = viper.ReadInConfig()
|
|
viper.AutomaticEnv()
|
|
|
|
return nil
|
|
}
|
|
|
|
// setConfigDefaults is only necessary because of bugs in viper, noted in main
|
|
func setConfigDefaults() {
|
|
if viper.GetString("postgres.sslmode") == "" {
|
|
viper.Set("postgres.sslmode", "require")
|
|
}
|
|
if viper.GetInt("postgres.port") == 0 {
|
|
viper.Set("postgres.port", 5432)
|
|
}
|
|
if viper.GetString("mysql.sslmode") == "" {
|
|
viper.Set("mysql.sslmode", "true")
|
|
}
|
|
if viper.GetInt("mysql.port") == 0 {
|
|
viper.Set("mysql.port", 3306)
|
|
}
|
|
if viper.GetString("mssql.sslmode") == "" {
|
|
viper.Set("mssql.sslmode", "true")
|
|
}
|
|
if viper.GetInt("mssql.port") == 0 {
|
|
viper.Set("mssql.port", 1433)
|
|
}
|
|
}
|
|
|
|
func validateConfig(driverName string) error {
|
|
if driverName == "postgres" {
|
|
return vala.BeginValidation().Validate(
|
|
vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"),
|
|
vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"),
|
|
vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")),
|
|
vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"),
|
|
vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"),
|
|
).Check()
|
|
}
|
|
|
|
if driverName == "mysql" {
|
|
return vala.BeginValidation().Validate(
|
|
vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"),
|
|
vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"),
|
|
vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")),
|
|
vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"),
|
|
vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"),
|
|
).Check()
|
|
}
|
|
|
|
if driverName == "mssql" {
|
|
return vala.BeginValidation().Validate(
|
|
vala.StringNotEmpty(viper.GetString("mssql.user"), "mssql.user"),
|
|
vala.StringNotEmpty(viper.GetString("mssql.host"), "mssql.host"),
|
|
vala.Not(vala.Equals(viper.GetInt("mssql.port"), 0, "mssql.port")),
|
|
vala.StringNotEmpty(viper.GetString("mssql.dbname"), "mssql.dbname"),
|
|
vala.StringNotEmpty(viper.GetString("mssql.sslmode"), "mssql.sslmode"),
|
|
).Check()
|
|
}
|
|
|
|
return errors.Err("not a valid driver name")
|
|
}
|