359 lines
9.6 KiB
Go
359 lines
9.6 KiB
Go
|
package sqlmock
|
||
|
|
||
|
import (
|
||
|
"database/sql/driver"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
"regexp"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
)
|
||
|
|
||
|
// expectation is the contract shared by the expectation types in this
// file: a value that can be locked and unlocked, rendered as a string,
// and asked whether it has been fulfilled.
type expectation interface {
	sync.Locker  // Lock() and Unlock()
	fmt.Stringer // String() string
	fulfilled() bool
}
|
||
|
|
||
|
// commonExpectation holds the state shared by the expectation types in
// this file. The embedded mutex supplies Lock and Unlock; together with
// fulfilled and a String method provided by the embedding type this
// satisfies the expectation interface.
type commonExpectation struct {
	sync.Mutex

	triggered bool  // value reported by fulfilled
	err       error // error configured through the WillReturnError methods
}
|
||
|
|
||
|
func (e *commonExpectation) fulfilled() bool {
|
||
|
return e.triggered
|
||
|
}
|
||
|
|
||
|
// ExpectedClose is used to manage *sql.DB.Close expectation
|
||
|
// returned by *Sqlmock.ExpectClose.
|
||
|
type ExpectedClose struct {
|
||
|
commonExpectation
|
||
|
}
|
||
|
|
||
|
// WillReturnError allows to set an error for *sql.DB.Close action
|
||
|
func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose {
|
||
|
e.err = err
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// String returns string representation
|
||
|
func (e *ExpectedClose) String() string {
|
||
|
msg := "ExpectedClose => expecting database Close"
|
||
|
if e.err != nil {
|
||
|
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||
|
}
|
||
|
return msg
|
||
|
}
|
||
|
|
||
|
// ExpectedBegin is used to manage *sql.DB.Begin expectation
|
||
|
// returned by *Sqlmock.ExpectBegin.
|
||
|
type ExpectedBegin struct {
|
||
|
commonExpectation
|
||
|
}
|
||
|
|
||
|
// WillReturnError allows to set an error for *sql.DB.Begin action
|
||
|
func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
|
||
|
e.err = err
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// String returns string representation
|
||
|
func (e *ExpectedBegin) String() string {
|
||
|
msg := "ExpectedBegin => expecting database transaction Begin"
|
||
|
if e.err != nil {
|
||
|
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||
|
}
|
||
|
return msg
|
||
|
}
|
||
|
|
||
|
// ExpectedCommit is used to manage *sql.Tx.Commit expectation
|
||
|
// returned by *Sqlmock.ExpectCommit.
|
||
|
type ExpectedCommit struct {
|
||
|
commonExpectation
|
||
|
}
|
||
|
|
||
|
// WillReturnError allows to set an error for *sql.Tx.Close action
|
||
|
func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit {
|
||
|
e.err = err
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// String returns string representation
|
||
|
func (e *ExpectedCommit) String() string {
|
||
|
msg := "ExpectedCommit => expecting transaction Commit"
|
||
|
if e.err != nil {
|
||
|
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||
|
}
|
||
|
return msg
|
||
|
}
|
||
|
|
||
|
// ExpectedRollback is used to manage *sql.Tx.Rollback expectation
|
||
|
// returned by *Sqlmock.ExpectRollback.
|
||
|
type ExpectedRollback struct {
|
||
|
commonExpectation
|
||
|
}
|
||
|
|
||
|
// WillReturnError allows to set an error for *sql.Tx.Rollback action
|
||
|
func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback {
|
||
|
e.err = err
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// String returns string representation
|
||
|
func (e *ExpectedRollback) String() string {
|
||
|
msg := "ExpectedRollback => expecting transaction Rollback"
|
||
|
if e.err != nil {
|
||
|
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||
|
}
|
||
|
return msg
|
||
|
}
|
||
|
|
||
|
// ExpectedQuery is used to manage *sql.DB.Query, *dql.DB.QueryRow, *sql.Tx.Query,
|
||
|
// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations.
|
||
|
// Returned by *Sqlmock.ExpectQuery.
|
||
|
type ExpectedQuery struct {
|
||
|
queryBasedExpectation
|
||
|
rows driver.Rows
|
||
|
}
|
||
|
|
||
|
// WithArgs will match given expected args to actual database query arguments.
|
||
|
// if at least one argument does not match, it will return an error. For specific
|
||
|
// arguments an sqlmock.Argument interface can be used to match an argument.
|
||
|
func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery {
|
||
|
e.args = args
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// WillReturnError allows to set an error for expected database query
|
||
|
func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
|
||
|
e.err = err
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// WillReturnRows specifies the set of resulting rows that will be returned
|
||
|
// by the triggered query
|
||
|
func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery {
|
||
|
e.rows = rows
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// String returns string representation
|
||
|
func (e *ExpectedQuery) String() string {
|
||
|
msg := "ExpectedQuery => expecting Query or QueryRow which:"
|
||
|
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
|
||
|
|
||
|
if len(e.args) == 0 {
|
||
|
msg += "\n - is without arguments"
|
||
|
} else {
|
||
|
msg += "\n - is with arguments:\n"
|
||
|
for i, arg := range e.args {
|
||
|
msg += fmt.Sprintf(" %d - %+v\n", i, arg)
|
||
|
}
|
||
|
msg = strings.TrimSpace(msg)
|
||
|
}
|
||
|
|
||
|
if e.rows != nil {
|
||
|
msg += "\n - should return rows:\n"
|
||
|
rs, _ := e.rows.(*rows)
|
||
|
for i, row := range rs.rows {
|
||
|
msg += fmt.Sprintf(" %d - %+v\n", i, row)
|
||
|
}
|
||
|
msg = strings.TrimSpace(msg)
|
||
|
}
|
||
|
|
||
|
if e.err != nil {
|
||
|
msg += fmt.Sprintf("\n - should return error: %s", e.err)
|
||
|
}
|
||
|
|
||
|
return msg
|
||
|
}
|
||
|
|
||
|
// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations.
|
||
|
// Returned by *Sqlmock.ExpectExec.
|
||
|
type ExpectedExec struct {
|
||
|
queryBasedExpectation
|
||
|
result driver.Result
|
||
|
}
|
||
|
|
||
|
// WithArgs will match given expected args to actual database exec operation arguments.
|
||
|
// if at least one argument does not match, it will return an error. For specific
|
||
|
// arguments an sqlmock.Argument interface can be used to match an argument.
|
||
|
func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec {
|
||
|
e.args = args
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// WillReturnError allows to set an error for expected database exec action
|
||
|
func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec {
|
||
|
e.err = err
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// String returns string representation
|
||
|
func (e *ExpectedExec) String() string {
|
||
|
msg := "ExpectedExec => expecting Exec which:"
|
||
|
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
|
||
|
|
||
|
if len(e.args) == 0 {
|
||
|
msg += "\n - is without arguments"
|
||
|
} else {
|
||
|
msg += "\n - is with arguments:\n"
|
||
|
var margs []string
|
||
|
for i, arg := range e.args {
|
||
|
margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg))
|
||
|
}
|
||
|
msg += strings.Join(margs, "\n")
|
||
|
}
|
||
|
|
||
|
if e.result != nil {
|
||
|
res, _ := e.result.(*result)
|
||
|
msg += "\n - should return Result having:"
|
||
|
msg += fmt.Sprintf("\n LastInsertId: %d", res.insertID)
|
||
|
msg += fmt.Sprintf("\n RowsAffected: %d", res.rowsAffected)
|
||
|
if res.err != nil {
|
||
|
msg += fmt.Sprintf("\n Error: %s", res.err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if e.err != nil {
|
||
|
msg += fmt.Sprintf("\n - should return error: %s", e.err)
|
||
|
}
|
||
|
|
||
|
return msg
|
||
|
}
|
||
|
|
||
|
// WillReturnResult arranges for an expected Exec() to return a particular
|
||
|
// result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method
|
||
|
// to build a corresponding result. Or if actions needs to be tested against errors
|
||
|
// sqlmock.NewErrorResult(err error) to return a given error.
|
||
|
func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
|
||
|
e.result = result
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations.
|
||
|
// Returned by *Sqlmock.ExpectPrepare.
|
||
|
type ExpectedPrepare struct {
|
||
|
commonExpectation
|
||
|
mock *sqlmock
|
||
|
sqlRegex *regexp.Regexp
|
||
|
statement driver.Stmt
|
||
|
closeErr error
|
||
|
}
|
||
|
|
||
|
// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action.
|
||
|
func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare {
|
||
|
e.err = err
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// WillReturnCloseError allows to set an error for this prapared statement Close action
|
||
|
func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
|
||
|
e.closeErr = err
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement.
|
||
|
// this method is convenient in order to prevent duplicating sql query string matching.
|
||
|
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
|
||
|
eq := &ExpectedQuery{}
|
||
|
eq.sqlRegex = e.sqlRegex
|
||
|
e.mock.expected = append(e.mock.expected, eq)
|
||
|
return eq
|
||
|
}
|
||
|
|
||
|
// ExpectExec allows to expect Exec() on this prepared statement.
|
||
|
// this method is convenient in order to prevent duplicating sql query string matching.
|
||
|
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
|
||
|
eq := &ExpectedExec{}
|
||
|
eq.sqlRegex = e.sqlRegex
|
||
|
e.mock.expected = append(e.mock.expected, eq)
|
||
|
return eq
|
||
|
}
|
||
|
|
||
|
// String returns string representation
|
||
|
func (e *ExpectedPrepare) String() string {
|
||
|
msg := "ExpectedPrepare => expecting Prepare statement which:"
|
||
|
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
|
||
|
|
||
|
if e.err != nil {
|
||
|
msg += fmt.Sprintf("\n - should return error: %s", e.err)
|
||
|
}
|
||
|
|
||
|
if e.closeErr != nil {
|
||
|
msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr)
|
||
|
}
|
||
|
|
||
|
return msg
|
||
|
}
|
||
|
|
||
|
// query based expectation
|
||
|
// adds a query matching logic
|
||
|
type queryBasedExpectation struct {
|
||
|
commonExpectation
|
||
|
sqlRegex *regexp.Regexp
|
||
|
args []driver.Value
|
||
|
}
|
||
|
|
||
|
func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (err error) {
|
||
|
if !e.queryMatches(sql) {
|
||
|
return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String())
|
||
|
}
|
||
|
|
||
|
// catch panic
|
||
|
defer func() {
|
||
|
if e := recover(); e != nil {
|
||
|
_, ok := e.(error)
|
||
|
if !ok {
|
||
|
err = fmt.Errorf(e.(string))
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
err = e.argsMatches(args)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (e *queryBasedExpectation) queryMatches(sql string) bool {
|
||
|
return e.sqlRegex.MatchString(sql)
|
||
|
}
|
||
|
|
||
|
func (e *queryBasedExpectation) argsMatches(args []driver.Value) error {
|
||
|
if nil == e.args {
|
||
|
return nil
|
||
|
}
|
||
|
if len(args) != len(e.args) {
|
||
|
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
|
||
|
}
|
||
|
for k, v := range args {
|
||
|
// custom argument matcher
|
||
|
matcher, ok := e.args[k].(Argument)
|
||
|
if ok {
|
||
|
if !matcher.Match(v) {
|
||
|
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// convert to driver converter
|
||
|
darg, err := driver.DefaultParameterConverter.ConvertValue(e.args[k])
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
|
||
|
}
|
||
|
|
||
|
if !driver.IsValue(darg) {
|
||
|
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
|
||
|
}
|
||
|
|
||
|
if !reflect.DeepEqual(darg, args[k]) {
|
||
|
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, args[k], args[k])
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|