diff --git a/db/db.go b/db/db.go index babe641..ec47f7f 100644 --- a/db/db.go +++ b/db/db.go @@ -698,34 +698,25 @@ type txFunc func(tx *sql.Tx) error // withTx wraps a function in an sql transaction. the transaction is committed if there's no error, or rolled back if there is one. // if dbOrTx is an sql.DB, a new transaction is started -func withTx(dbOrTx interface{}, f txFunc) (err error) { - var tx *sql.Tx - - switch t := dbOrTx.(type) { - case *sql.Tx: - tx = t - case *sql.DB: - tx, err = t.Begin() - if err != nil { - return err - } - defer func() { - if p := recover(); p != nil { - if rollBackError := tx.Rollback(); rollBackError != nil { - log.Error("failed to rollback tx on panic - ", rollBackError) - } - panic(p) - } else if err != nil { - if rollBackError := tx.Rollback(); rollBackError != nil { - log.Error("failed to rollback tx on panic - ", rollBackError) - } - } else { - err = errors.Err(tx.Commit()) - } - }() - default: - return errors.Err("db or tx required") +func (s *SQL) withTx(f txFunc) (err error) { + tx, err := s.conn.Begin() + if err != nil { + return err } + defer func() { + if p := recover(); p != nil { + if rollBackError := tx.Rollback(); rollBackError != nil { + log.Error("failed to rollback tx on panic - ", rollBackError) + } + panic(p) + } else if err != nil { + if rollBackError := tx.Rollback(); rollBackError != nil { + log.Error("failed to rollback tx on panic - ", rollBackError) + } + } else { + err = errors.Err(tx.Commit()) + } + }() return f(tx) }