Rework the mempool locking code.

It was previously possible for the unprotected iteration of the mempool
orphans to lead to undefined results.  This commit remedies that by
reworking the locking code a bit.  It also embeds the mutex directly into
the mempool struct rather than having a separate field for it so the
syntax is a slightly cleaner.
This commit is contained in:
Dave Collins 2013-10-21 18:20:31 -05:00
parent 27abb0eb3e
commit 102fc5f513

View file

@ -75,12 +75,12 @@ const (
// blocks and relayed to other peers. It is safe for concurrent access from // blocks and relayed to other peers. It is safe for concurrent access from
// multiple peers. // multiple peers.
type txMemPool struct { type txMemPool struct {
sync.RWMutex
server *server server *server
pool map[btcwire.ShaHash]*btcwire.MsgTx pool map[btcwire.ShaHash]*btcwire.MsgTx
orphans map[btcwire.ShaHash]*btcwire.MsgTx orphans map[btcwire.ShaHash]*btcwire.MsgTx
orphansByPrev map[btcwire.ShaHash]*list.List orphansByPrev map[btcwire.ShaHash]*list.List
outpoints map[btcwire.OutPoint]*btcwire.MsgTx outpoints map[btcwire.OutPoint]*btcwire.MsgTx
lock sync.RWMutex
} }
// isDust returns whether or not the passed transaction output amount is // isDust returns whether or not the passed transaction output amount is
@ -277,11 +277,9 @@ func checkInputsStandard(tx *btcwire.MsgTx) error {
// removeOrphan removes the passed orphan transaction from the orphan pool and // removeOrphan removes the passed orphan transaction from the orphan pool and
// previous orphan index. // previous orphan index.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *txMemPool) removeOrphan(txHash *btcwire.ShaHash) { func (mp *txMemPool) removeOrphan(txHash *btcwire.ShaHash) {
// Protect concurrent access.
mp.lock.Lock()
defer mp.lock.Unlock()
// Nothing to do if passed tx is not an orphan. // Nothing to do if passed tx is not an orphan.
tx, exists := mp.orphans[*txHash] tx, exists := mp.orphans[*txHash]
if !exists { if !exists {
@ -313,11 +311,9 @@ func (mp *txMemPool) removeOrphan(txHash *btcwire.ShaHash) {
// limitNumOrphans limits the number of orphan transactions by evicting a random // limitNumOrphans limits the number of orphan transactions by evicting a random
// orphan if adding a new one would cause it to overflow the max allowed. // orphan if adding a new one would cause it to overflow the max allowed.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *txMemPool) limitNumOrphans() error { func (mp *txMemPool) limitNumOrphans() error {
// Protect concurrent access.
mp.lock.Lock()
defer mp.lock.Unlock()
if len(mp.orphans)+1 > maxOrphanTransactions { if len(mp.orphans)+1 > maxOrphanTransactions {
// Generate a cryptographically random hash. // Generate a cryptographically random hash.
randHashBytes := make([]byte, btcwire.HashSize) randHashBytes := make([]byte, btcwire.HashSize)
@ -344,25 +340,20 @@ func (mp *txMemPool) limitNumOrphans() error {
} }
} }
// Need to unlock and relock since removeOrphan has its own
// locking.
mp.lock.Unlock()
mp.removeOrphan(foundHash) mp.removeOrphan(foundHash)
mp.lock.Lock()
} }
return nil return nil
} }
// addOrphan adds an orphan transaction to the orphan pool. // addOrphan adds an orphan transaction to the orphan pool.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *txMemPool) addOrphan(tx *btcwire.MsgTx, txHash *btcwire.ShaHash) { func (mp *txMemPool) addOrphan(tx *btcwire.MsgTx, txHash *btcwire.ShaHash) {
// Limit the number orphan transactions to prevent memory exhaustion. A // Limit the number orphan transactions to prevent memory exhaustion. A
// random orphan is evicted to make room if needed. // random orphan is evicted to make room if needed.
mp.limitNumOrphans() mp.limitNumOrphans()
mp.lock.Lock()
defer mp.lock.Unlock()
mp.orphans[*txHash] = tx mp.orphans[*txHash] = tx
for _, txIn := range tx.TxIn { for _, txIn := range tx.TxIn {
originTxHash := txIn.PreviousOutpoint.Hash originTxHash := txIn.PreviousOutpoint.Hash
@ -377,6 +368,8 @@ func (mp *txMemPool) addOrphan(tx *btcwire.MsgTx, txHash *btcwire.ShaHash) {
} }
// maybeAddOrphan potentially adds an orphan to the orphan pool. // maybeAddOrphan potentially adds an orphan to the orphan pool.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *txMemPool) maybeAddOrphan(tx *btcwire.MsgTx, txHash *btcwire.ShaHash) error { func (mp *txMemPool) maybeAddOrphan(tx *btcwire.MsgTx, txHash *btcwire.ShaHash) error {
// Ignore orphan transactions that are too large. This helps avoid // Ignore orphan transactions that are too large. This helps avoid
// a memory exhaustion attack based on sending a lot of really large // a memory exhaustion attack based on sending a lot of really large
@ -407,13 +400,36 @@ func (mp *txMemPool) maybeAddOrphan(tx *btcwire.MsgTx, txHash *btcwire.ShaHash)
return nil return nil
} }
// isTransactionInPool returns whether or not the passed transaction already
// exists in the main pool.
//
// This function MUST be called with the mempool lock held (for reads).
func (mp *txMemPool) isTransactionInPool(hash *btcwire.ShaHash) bool {
if _, exists := mp.pool[*hash]; exists {
return true
}
return false
}
// IsTransactionInPool returns whether or not the passed transaction already // IsTransactionInPool returns whether or not the passed transaction already
// exists in the main pool. // exists in the main pool.
//
// This function is safe for concurrent access.
func (mp *txMemPool) IsTransactionInPool(hash *btcwire.ShaHash) bool { func (mp *txMemPool) IsTransactionInPool(hash *btcwire.ShaHash) bool {
mp.lock.RLock() // Protect concurrent access.
defer mp.lock.RUnlock() mp.RLock()
defer mp.RUnlock()
if _, exists := mp.pool[*hash]; exists { return mp.isTransactionInPool(hash)
}
// isOrphanInPool returns whether or not the passed transaction already exists
// in the orphan pool.
//
// This function MUST be called with the mempool lock held (for reads).
func (mp *txMemPool) isOrphanInPool(hash *btcwire.ShaHash) bool {
if _, exists := mp.orphans[*hash]; exists {
return true return true
} }
@ -422,36 +438,46 @@ func (mp *txMemPool) IsTransactionInPool(hash *btcwire.ShaHash) bool {
// IsOrphanInPool returns whether or not the passed transaction already exists // IsOrphanInPool returns whether or not the passed transaction already exists
// in the orphan pool. // in the orphan pool.
//
// This function is safe for concurrent access.
func (mp *txMemPool) IsOrphanInPool(hash *btcwire.ShaHash) bool { func (mp *txMemPool) IsOrphanInPool(hash *btcwire.ShaHash) bool {
mp.lock.RLock() // Protect concurrent access.
defer mp.lock.RUnlock() mp.RLock()
defer mp.RUnlock()
if _, exists := mp.orphans[*hash]; exists { return mp.isOrphanInPool(hash)
return true }
}
return false // haveTransaction returns whether or not the passed transaction already exists
// in the main pool or in the orphan pool.
//
// This function MUST be called with the mempool lock held (for reads).
func (mp *txMemPool) haveTransaction(hash *btcwire.ShaHash) bool {
return mp.isTransactionInPool(hash) || mp.isOrphanInPool(hash)
} }
// HaveTransaction returns whether or not the passed transaction already exists // HaveTransaction returns whether or not the passed transaction already exists
// in the main pool or in the orphan pool. // in the main pool or in the orphan pool.
//
// This function is safe for concurrent access.
func (mp *txMemPool) HaveTransaction(hash *btcwire.ShaHash) bool { func (mp *txMemPool) HaveTransaction(hash *btcwire.ShaHash) bool {
return mp.IsTransactionInPool(hash) || mp.IsOrphanInPool(hash) // Protect concurrent access.
mp.RLock()
defer mp.RUnlock()
return mp.haveTransaction(hash)
} }
// removeTransaction removes the passed transaction from the memory pool. // removeTransaction removes the passed transaction from the memory pool.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *txMemPool) removeTransaction(tx *btcwire.MsgTx) { func (mp *txMemPool) removeTransaction(tx *btcwire.MsgTx) {
mp.lock.Lock()
defer mp.lock.Unlock()
// Remove any transactions which rely on this one. // Remove any transactions which rely on this one.
txHash, _ := tx.TxSha() txHash, _ := tx.TxSha()
for i := uint32(0); i < uint32(len(tx.TxOut)); i++ { for i := uint32(0); i < uint32(len(tx.TxOut)); i++ {
outpoint := btcwire.NewOutPoint(&txHash, i) outpoint := btcwire.NewOutPoint(&txHash, i)
if txRedeemer, exists := mp.outpoints[*outpoint]; exists { if txRedeemer, exists := mp.outpoints[*outpoint]; exists {
mp.lock.Unlock()
mp.removeTransaction(txRedeemer) mp.removeTransaction(txRedeemer)
mp.lock.Lock()
} }
} }
@ -463,16 +489,14 @@ func (mp *txMemPool) removeTransaction(tx *btcwire.MsgTx) {
} }
delete(mp.pool, txHash) delete(mp.pool, txHash)
} }
} }
// addTransaction adds the passed transaction to the memory pool. It should // addTransaction adds the passed transaction to the memory pool. It should
// not be called directly as it doesn't perform any validation. This is a // not be called directly as it doesn't perform any validation. This is a
// helper for maybeAcceptTransaction. // helper for maybeAcceptTransaction.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *txMemPool) addTransaction(tx *btcwire.MsgTx, txHash *btcwire.ShaHash) { func (mp *txMemPool) addTransaction(tx *btcwire.MsgTx, txHash *btcwire.ShaHash) {
mp.lock.Lock()
defer mp.lock.Unlock()
// Add the transaction to the pool and mark the referenced outpoints // Add the transaction to the pool and mark the referenced outpoints
// as spent by the pool. // as spent by the pool.
mp.pool[*txHash] = tx mp.pool[*txHash] = tx
@ -485,10 +509,9 @@ func (mp *txMemPool) addTransaction(tx *btcwire.MsgTx, txHash *btcwire.ShaHash)
// attempting to spend coins already spent by other transactions in the pool. // attempting to spend coins already spent by other transactions in the pool.
// Note it does not check for double spends against transactions already in the // Note it does not check for double spends against transactions already in the
// main chain. // main chain.
//
// This function MUST be called with the mempool lock held (for reads).
func (mp *txMemPool) checkPoolDoubleSpend(tx *btcwire.MsgTx) error { func (mp *txMemPool) checkPoolDoubleSpend(tx *btcwire.MsgTx) error {
mp.lock.RLock()
defer mp.lock.RUnlock()
for _, txIn := range tx.TxIn { for _, txIn := range tx.TxIn {
if txR, exists := mp.outpoints[txIn.PreviousOutpoint]; exists { if txR, exists := mp.outpoints[txIn.PreviousOutpoint]; exists {
hash, _ := txR.TxSha() hash, _ := txR.TxSha()
@ -504,10 +527,9 @@ func (mp *txMemPool) checkPoolDoubleSpend(tx *btcwire.MsgTx) error {
// fetchInputTransactions fetches the input transactions referenced by the // fetchInputTransactions fetches the input transactions referenced by the
// passed transaction. First, it fetches from the main chain, then it tries to // passed transaction. First, it fetches from the main chain, then it tries to
// fetch any missing inputs from the transaction pool. // fetch any missing inputs from the transaction pool.
//
// This function MUST be called with the mempool lock held (for reads).
func (mp *txMemPool) fetchInputTransactions(tx *btcwire.MsgTx) (btcchain.TxStore, error) { func (mp *txMemPool) fetchInputTransactions(tx *btcwire.MsgTx) (btcchain.TxStore, error) {
mp.lock.RLock()
defer mp.lock.RUnlock()
txStore, err := mp.server.blockManager.blockChain.FetchTransactionStore(tx) txStore, err := mp.server.blockManager.blockChain.FetchTransactionStore(tx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -531,9 +553,12 @@ func (mp *txMemPool) fetchInputTransactions(tx *btcwire.MsgTx) (btcchain.TxStore
// FetchTransaction returns the requested transaction from the transaction pool. // FetchTransaction returns the requested transaction from the transaction pool.
// This only fetches from the main transaction pool and does not include // This only fetches from the main transaction pool and does not include
// orphans. // orphans.
//
// This function is safe for concurrent access.
func (mp *txMemPool) FetchTransaction(txHash *btcwire.ShaHash) (*btcwire.MsgTx, error) { func (mp *txMemPool) FetchTransaction(txHash *btcwire.ShaHash) (*btcwire.MsgTx, error) {
mp.lock.RLock() // Protect concurrent access.
defer mp.lock.RUnlock() mp.RLock()
defer mp.RUnlock()
if tx, exists := mp.pool[*txHash]; exists { if tx, exists := mp.pool[*txHash]; exists {
return tx, nil return tx, nil
@ -546,6 +571,8 @@ func (mp *txMemPool) FetchTransaction(txHash *btcwire.ShaHash) (*btcwire.MsgTx,
// free-standing transactions into a memory pool. It includes functionality // free-standing transactions into a memory pool. It includes functionality
// such as rejecting duplicate transactions, ensuring transactions follow all // such as rejecting duplicate transactions, ensuring transactions follow all
// rules, orphan transaction handling, and insertion into the memory pool. // rules, orphan transaction handling, and insertion into the memory pool.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *txMemPool) maybeAcceptTransaction(tx *btcwire.MsgTx, isOrphan *bool) error { func (mp *txMemPool) maybeAcceptTransaction(tx *btcwire.MsgTx, isOrphan *bool) error {
*isOrphan = false *isOrphan = false
txHash, err := tx.TxSha() txHash, err := tx.TxSha()
@ -556,7 +583,7 @@ func (mp *txMemPool) maybeAcceptTransaction(tx *btcwire.MsgTx, isOrphan *bool) e
// Don't accept the transaction if it already exists in the pool. This // Don't accept the transaction if it already exists in the pool. This
// applies to orphan transactions as well. This check is intended to // applies to orphan transactions as well. This check is intended to
// be a quick check to weed out duplicates. // be a quick check to weed out duplicates.
if mp.HaveTransaction(&txHash) { if mp.haveTransaction(&txHash) {
str := fmt.Sprintf("already have transaction %v", txHash) str := fmt.Sprintf("already have transaction %v", txHash)
return TxRuleError(str) return TxRuleError(str)
} }
@ -689,10 +716,8 @@ func (mp *txMemPool) maybeAcceptTransaction(tx *btcwire.MsgTx, isOrphan *bool) e
// Add to transaction pool. // Add to transaction pool.
mp.addTransaction(tx, &txHash) mp.addTransaction(tx, &txHash)
mp.lock.RLock()
log.Debugf("TXMP: Accepted transaction %v (pool size: %v)", txHash, log.Debugf("TXMP: Accepted transaction %v (pool size: %v)", txHash,
len(mp.pool)) len(mp.pool))
mp.lock.RUnlock()
// TODO(davec): Notifications // TODO(davec): Notifications
@ -707,6 +732,8 @@ func (mp *txMemPool) maybeAcceptTransaction(tx *btcwire.MsgTx, isOrphan *bool) e
// transaction hash (they are no longer orphans if true) and potentially accepts // transaction hash (they are no longer orphans if true) and potentially accepts
// them. It repeats the process for the newly accepted transactions (to detect // them. It repeats the process for the newly accepted transactions (to detect
// further orphans which may no longer be orphans) until there are no more. // further orphans which may no longer be orphans) until there are no more.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *txMemPool) processOrphans(hash *btcwire.ShaHash) error { func (mp *txMemPool) processOrphans(hash *btcwire.ShaHash) error {
// Start with processing at least the passed hash. // Start with processing at least the passed hash.
processHashes := list.New() processHashes := list.New()
@ -764,7 +791,13 @@ func (mp *txMemPool) processOrphans(hash *btcwire.ShaHash) error {
// free-standing transactions into a memory pool. It includes functionality // free-standing transactions into a memory pool. It includes functionality
// such as rejecting duplicate transactions, ensuring transactions follow all // such as rejecting duplicate transactions, ensuring transactions follow all
// rules, orphan transaction handling, and insertion into the memory pool. // rules, orphan transaction handling, and insertion into the memory pool.
//
// This function is safe for concurrent access.
func (mp *txMemPool) ProcessTransaction(tx *btcwire.MsgTx) error { func (mp *txMemPool) ProcessTransaction(tx *btcwire.MsgTx) error {
// Protect concurrent access.
mp.Lock()
defer mp.Unlock()
txHash, err := tx.TxSha() txHash, err := tx.TxSha()
if err != nil { if err != nil {
return err return err
@ -800,9 +833,11 @@ func (mp *txMemPool) ProcessTransaction(tx *btcwire.MsgTx) error {
// TxShas returns a slice of hashes for all of the transactions in the memory // TxShas returns a slice of hashes for all of the transactions in the memory
// pool. // pool.
//
// This function is safe for concurrent access.
func (mp *txMemPool) TxShas() []*btcwire.ShaHash { func (mp *txMemPool) TxShas() []*btcwire.ShaHash {
mp.lock.Lock() mp.RLock()
defer mp.lock.Unlock() defer mp.RUnlock()
hashes := make([]*btcwire.ShaHash, len(mp.pool)) hashes := make([]*btcwire.ShaHash, len(mp.pool))
i := 0 i := 0