From 37d3d83ed3058a971e33e34309bc4deeff9d00c7 Mon Sep 17 00:00:00 2001
From: Dave Collins <davec@conformal.com>
Date: Fri, 15 Nov 2013 16:12:23 -0600
Subject: [PATCH] Improve mempool handling.

- Lock the mempool when removing transactions during a notification as
  intended
- When generating the inventory vectors to serve on a mempool request,
  recheck the memory pool for each hash since it's possible another thread
  could have removed an entry after the initial query for available
  hashes
- When a block is connected, remove any transactions which are now double
  spends as a result of the newly connected transactions
---
 blockmanager.go | 11 ++++++++---
 mempool.go      | 36 +++++++++++++++++++++++++++++++++++-
 peer.go         |  6 ++++++
 3 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/blockmanager.go b/blockmanager.go
index a6c7a330..e4dd8177 100644
--- a/blockmanager.go
+++ b/blockmanager.go
@@ -645,9 +645,14 @@ func (b *blockManager) handleNotifyMsg(notification *btcchain.Notification) {
 		}
 
 		// Remove all of the transactions (except the coinbase) in the
-		// connected block from the transaction pool.
+		// connected block from the transaction pool.  Also, remove any
+		// transactions which are now double spends as a result of these
+		// new transactions.  Note that removing a transaction from
+		// pool also removes any transactions which depend on it,
+		// recursively.
 		for _, tx := range block.Transactions()[1:] {
-			b.server.txMemPool.removeTransaction(tx)
+			b.server.txMemPool.RemoveTransaction(tx)
+			b.server.txMemPool.RemoveDoubleSpends(tx)
 		}
 
 		// Notify frontends
@@ -674,7 +679,7 @@ func (b *blockManager) handleNotifyMsg(notification *btcchain.Notification) {
 				// Remove the transaction and all transactions
 				// that depend on it if it wasn't accepted into
 				// the transaction pool.
-				b.server.txMemPool.removeTransaction(tx)
+				b.server.txMemPool.RemoveTransaction(tx)
 			}
 		}
 
diff --git a/mempool.go b/mempool.go
index fac768ad..c6e5c03a 100644
--- a/mempool.go
+++ b/mempool.go
@@ -565,7 +565,8 @@ func (mp *txMemPool) HaveTransaction(hash *btcwire.ShaHash) bool {
 	return mp.haveTransaction(hash)
 }
 
-// removeTransaction removes the passed transaction from the memory pool.
+// removeTransaction is the internal function which implements the public
+// RemoveTransaction.  See the comment for RemoveTransaction for more details.
 //
 // This function MUST be called with the mempool lock held (for writes).
 func (mp *txMemPool) removeTransaction(tx *btcutil.Tx) {
@@ -588,6 +589,39 @@ func (mp *txMemPool) removeTransaction(tx *btcutil.Tx) {
 	}
 }
 
+// RemoveTransaction removes the passed transaction and any transactions which
+// depend on it from the memory pool.
+//
+// This function is safe for concurrent access.
+func (mp *txMemPool) RemoveTransaction(tx *btcutil.Tx) {
+	// Protect concurrent access.
+	mp.Lock()
+	defer mp.Unlock()
+
+	mp.removeTransaction(tx)
+}
+
+// RemoveDoubleSpends removes all transactions which spend outputs spent by the
+// passed transaction from the memory pool.  Removing those transactions then
+// leads to removing all transactions which rely on them, recursively.  This is
+// necessary when a block is connected to the main chain because the block may
+// contain transactions which were previously unknown to the memory pool
+//
+// This function is safe for concurrent access.
+func (mp *txMemPool) RemoveDoubleSpends(tx *btcutil.Tx) {
+	// Protect concurrent access.
+	mp.Lock()
+	defer mp.Unlock()
+
+	for _, txIn := range tx.MsgTx().TxIn {
+		if txRedeemer, ok := mp.outpoints[txIn.PreviousOutpoint]; ok {
+			if !txRedeemer.Sha().IsEqual(tx.Sha()) {
+				mp.removeTransaction(txRedeemer)
+			}
+		}
+	}
+}
+
 // addTransaction adds the passed transaction to the memory pool.  It should
 // not be called directly as it doesn't perform any validation.  This is a
 // helper for maybeAcceptTransaction.
diff --git a/peer.go b/peer.go
index 167cc24c..bdecc4b2 100644
--- a/peer.go
+++ b/peer.go
@@ -440,6 +440,12 @@ func (p *peer) handleMemPoolMsg(msg *btcwire.MsgMemPool) {
 	invMsg := btcwire.NewMsgInv()
 	hashes := p.server.txMemPool.TxShas()
 	for i, hash := range hashes {
+		// Another thread might have removed the transaction from the
+		// pool since the initial query.
+		if !p.server.txMemPool.IsTransactionInPool(hash) {
+			continue
+		}
+
 		iv := btcwire.NewInvVect(btcwire.InvTypeTx, hash)
 		invMsg.AddInvVect(iv)
 		if i+1 >= btcwire.MaxInvPerMsg {