2017-03-13 10:55:26 +01:00
|
|
|
// mssqlTester holds everything the MS SQL test lifecycle needs: the settings
// read in setup, the name of the throwaway database the tests run against,
// and a lazily opened connection pool to that database.
type mssqlTester struct {
	// dbConn is the pool handed out by conn; nil until conn first opens it.
	dbConn *sql.DB

	// Connection settings read from the mssql.* viper keys.
	dbName, host, user, pass, sslmode string

	// port is only consumed when building the database/sql DSN in conn.
	port int

	// testDBName is the randomized database name derived from dbName.
	testDBName string
}
|
|
|
|
|
|
|
|
func init() {
|
|
|
|
dbMain = &mssqlTester{}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *mssqlTester) setup() error {
|
|
|
|
var err error
|
|
|
|
m.dbName = viper.GetString("mssql.dbname")
|
|
|
|
m.host = viper.GetString("mssql.host")
|
|
|
|
m.user = viper.GetString("mssql.user")
|
|
|
|
m.pass = viper.GetString("mssql.pass")
|
|
|
|
m.port = viper.GetInt("mssql.port")
|
|
|
|
m.sslmode = viper.GetString("mssql.sslmode")
|
|
|
|
// Create a randomized db name.
|
|
|
|
m.testDBName = randomize.StableDBName(m.dbName)
|
|
|
|
|
|
|
|
if err = m.dropTestDB(); err != nil {
|
2018-02-07 15:35:46 +01:00
|
|
|
return errors.Err(err)
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
if err = m.createTestDB(); err != nil {
|
2018-02-07 15:35:46 +01:00
|
|
|
return errors.Err(err)
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
2017-03-13 17:11:34 +01:00
|
|
|
createCmd := exec.Command("sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass, "-d", m.testDBName)
|
2017-03-13 10:55:26 +01:00
|
|
|
|
2017-03-14 14:51:56 +01:00
|
|
|
f, err := os.Open("tables_schema.sql")
|
|
|
|
if err != nil {
|
2018-02-07 15:35:46 +01:00
|
|
|
return errors.Prefix("failed to open tables_schema.sql file", err)
|
2017-03-14 14:51:56 +01:00
|
|
|
}
|
|
|
|
|
2017-03-13 17:11:34 +01:00
|
|
|
defer f.Close()
|
2017-03-13 10:55:26 +01:00
|
|
|
|
2017-03-13 17:11:34 +01:00
|
|
|
createCmd.Stdin = newFKeyDestroyer(rgxMSSQLkey, f)
|
2017-03-13 10:55:26 +01:00
|
|
|
|
2017-03-13 17:11:34 +01:00
|
|
|
if err = createCmd.Start(); err != nil {
|
2018-02-07 15:35:46 +01:00
|
|
|
return errors.Prefix("failed to start sqlcmd command", err)
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
if err = createCmd.Wait(); err != nil {
|
|
|
|
fmt.Println(err)
|
2018-02-07 15:35:46 +01:00
|
|
|
return errors.Prefix("failed to wait for sqlcmd command", err)
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *mssqlTester) sslMode(mode string) string {
|
|
|
|
switch mode {
|
|
|
|
case "true":
|
|
|
|
return "true"
|
|
|
|
case "false":
|
|
|
|
return "false"
|
|
|
|
default:
|
|
|
|
return "disable"
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *mssqlTester) createTestDB() error {
|
2017-03-18 22:02:14 +01:00
|
|
|
sql := fmt.Sprintf(`
|
|
|
|
CREATE DATABASE %s;
|
2017-04-28 22:58:17 +02:00
|
|
|
GO
|
|
|
|
ALTER DATABASE %[1]s
|
|
|
|
SET READ_COMMITTED_SNAPSHOT ON;
|
2017-03-18 22:02:14 +01:00
|
|
|
GO`, m.testDBName)
|
2017-03-13 17:11:34 +01:00
|
|
|
return m.runCmd(sql, "sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass)
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (m *mssqlTester) dropTestDB() error {
|
2017-03-13 17:11:34 +01:00
|
|
|
// Since MS SQL 2016 it can be done with
|
|
|
|
// DROP DATABASE [ IF EXISTS ] { database_name | database_snapshot_name } [ ,...n ] [;]
|
|
|
|
sql := fmt.Sprintf(`
|
|
|
|
IF EXISTS(SELECT name FROM sys.databases
|
|
|
|
WHERE name = '%s')
|
|
|
|
DROP DATABASE %s
|
|
|
|
GO`, m.testDBName, m.testDBName)
|
|
|
|
return m.runCmd(sql, "sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass)
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (m *mssqlTester) teardown() error {
|
|
|
|
if m.dbConn != nil {
|
|
|
|
m.dbConn.Close()
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := m.dropTestDB(); err != nil {
|
2018-02-07 15:35:46 +01:00
|
|
|
return errors.Err(err)
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
2017-03-13 17:11:34 +01:00
|
|
|
return nil
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (m *mssqlTester) runCmd(stdin, command string, args ...string) error {
|
|
|
|
cmd := exec.Command(command, args...)
|
|
|
|
cmd.Stdin = strings.NewReader(stdin)
|
|
|
|
|
|
|
|
stdout := &bytes.Buffer{}
|
|
|
|
stderr := &bytes.Buffer{}
|
|
|
|
cmd.Stdout = stdout
|
|
|
|
cmd.Stderr = stderr
|
|
|
|
if err := cmd.Run(); err != nil {
|
2017-03-13 17:11:34 +01:00
|
|
|
fmt.Println("failed running:", command, args)
|
|
|
|
fmt.Println(stdout.String())
|
|
|
|
fmt.Println(stderr.String())
|
2018-02-07 15:35:46 +01:00
|
|
|
return errors.Err(err)
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *mssqlTester) conn() (*sql.DB, error) {
|
|
|
|
if m.dbConn != nil {
|
2017-03-13 17:11:34 +01:00
|
|
|
return m.dbConn, nil
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
var err error
|
|
|
|
m.dbConn, err = sql.Open("mssql", drivers.MSSQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode))
|
|
|
|
if err != nil {
|
2017-03-13 17:11:34 +01:00
|
|
|
return nil, err
|
2017-03-13 10:55:26 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
return m.dbConn, nil
|
2017-03-13 17:11:34 +01:00
|
|
|
}
|