From f11872cbf46128e22e33c3138755c4aa32e80de7 Mon Sep 17 00:00:00 2001
From: Fabian Jahr <fjahr@protonmail.com>
Date: Tue, 31 Dec 2019 19:55:18 +0100
Subject: [PATCH] wallet: Reset reused transactions cache

If a destination is reused we mark the cache of the other transactions going to that destination dirty so they are not accidentally reported as trusted when the cache is hit.

Github-Pull: #17843
Rebased-From: 6fc554f591d8ea1681b8bb25aa12da8d4f023f66
---
 src/wallet/wallet.cpp                | 27 ++++++++++++++++++++---
 src/wallet/wallet.h                  |  8 ++++++-
 test/functional/wallet_avoidreuse.py | 33 +++++++++++++++++++++++++++-
 3 files changed, 63 insertions(+), 5 deletions(-)

diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp
index 4c3b633fe..fdaafd518 100644
--- a/src/wallet/wallet.cpp
+++ b/src/wallet/wallet.cpp
@@ -1055,7 +1055,7 @@ bool CWallet::MarkReplaced(const uint256& originalHash, const uint256& newHash)
     return success;
 }
 
-void CWallet::SetUsedDestinationState(const uint256& hash, unsigned int n, bool used)
+void CWallet::SetUsedDestinationState(const uint256& hash, unsigned int n, bool used, std::set<CTxDestination>& tx_destinations)
 {
     const CWalletTx* srctx = GetWalletTx(hash);
     if (!srctx) return;
@@ -1065,7 +1065,9 @@ void CWallet::SetUsedDestinationState(const uint256& hash, unsigned int n, bool
         if (::IsMine(*this, dst)) {
             LOCK(cs_wallet);
             if (used && !GetDestData(dst, "used", nullptr)) {
-                AddDestData(dst, "used", "p"); // p for "present", opposite of absent (null)
+                if (AddDestData(dst, "used", "p")) { // p for "present", opposite of absent (null)
+                    tx_destinations.insert(dst);
+                }
             } else if (!used && GetDestData(dst, "used", nullptr)) {
                 EraseDestData(dst, "used");
             }
@@ -1110,10 +1112,14 @@ bool CWallet::AddToWallet(const CWalletTx& wtxIn, bool fFlushOnClose)
 
     if (IsWalletFlagSet(WALLET_FLAG_AVOID_REUSE)) {
         // Mark used destinations
+        std::set<CTxDestination> tx_destinations;
+
         for (const CTxIn& txin : wtxIn.tx->vin) {
             const COutPoint& op = txin.prevout;
-            SetUsedDestinationState(op.hash, op.n, true);
+            SetUsedDestinationState(op.hash, op.n, true, tx_destinations);
         }
+
+        MarkDestinationsDirty(tx_destinations);
     }
 
     // Inserts only if not already there, returns tx inserted or tx found
@@ -3793,6 +3799,21 @@ int64_t CWallet::GetOldestKeyPoolTime()
     return oldestKey;
 }
 
+void CWallet::MarkDestinationsDirty(const std::set<CTxDestination>& destinations) {
+    for (auto& entry : mapWallet) {
+        CWalletTx& wtx = entry.second;
+
+        for (unsigned int i = 0; i < wtx.tx->vout.size(); i++) {
+            CTxDestination dst;
+
+            if (ExtractDestination(wtx.tx->vout[i].scriptPubKey, dst) && destinations.count(dst)) {
+                wtx.MarkDirty();
+                break;
+            }
+        }
+    }
+}
+
 std::map<CTxDestination, CAmount> CWallet::GetAddressBalances(interfaces::Chain::Lock& locked_chain)
 {
     std::map<CTxDestination, CAmount> balances;
diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h
index ee4ac9158..f4f0156b4 100644
--- a/src/wallet/wallet.h
+++ b/src/wallet/wallet.h
@@ -1006,7 +1006,7 @@ public:
 
     // Whether this or any known UTXO with the same single key has been spent.
     bool IsUsedDestination(const uint256& hash, unsigned int n) const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet);
-    void SetUsedDestinationState(const uint256& hash, unsigned int n, bool used);
+    void SetUsedDestinationState(const uint256& hash, unsigned int n, bool used, std::set<CTxDestination>& tx_destinations);
 
     std::vector<OutputGroup> GroupOutputs(const std::vector<COutput>& outputs, bool single_coin) const;
 
@@ -1216,6 +1216,12 @@ public:
 
     std::set<CTxDestination> GetLabelAddresses(const std::string& label) const;
 
+    /**
+     * Marks all outputs in each one of the destinations dirty, so their cache is
+     * reset and does not return outdated information.
+     */
+    void MarkDestinationsDirty(const std::set<CTxDestination>& destinations) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet);
+
     bool GetNewDestination(const OutputType type, const std::string label, CTxDestination& dest, std::string& error);
     bool GetNewChangeDestination(const OutputType type, CTxDestination& dest, std::string& error);
 
diff --git a/test/functional/wallet_avoidreuse.py b/test/functional/wallet_avoidreuse.py
index 55b30afde..d261fe634 100755
--- a/test/functional/wallet_avoidreuse.py
+++ b/test/functional/wallet_avoidreuse.py
@@ -88,7 +88,8 @@ class AvoidReuseTest(BitcoinTestFramework):
         self.test_fund_send_fund_send("p2sh-segwit")
         reset_balance(self.nodes[1], self.nodes[0].getnewaddress())
         self.test_fund_send_fund_send("bech32")
-
+        reset_balance(self.nodes[1], self.nodes[0].getnewaddress())
+        self.test_getbalances_used()
 
     def test_persistence(self):
         '''Test that wallet files persist the avoid_reuse flag.'''
@@ -248,5 +249,35 @@ class AvoidReuseTest(BitcoinTestFramework):
         assert_approx(self.nodes[1].getbalance(), 1, 0.001)
         assert_approx(self.nodes[1].getbalance(avoid_reuse=False), 11, 0.001)
 
+    def test_getbalances_used(self):
+        '''
+        getbalances and listunspent should pick up on reused addresses
+        immediately, even for address reusing outputs created before the first
+        transaction was spending from that address
+        '''
+        self.log.info("Test getbalances used category")
+
+        # node under test should be completely empty
+        assert_equal(self.nodes[1].getbalance(avoid_reuse=False), 0)
+
+        new_addr = self.nodes[1].getnewaddress()
+        ret_addr = self.nodes[0].getnewaddress()
+
+        # send multiple transactions, reusing one address
+        for _ in range(11):
+            self.nodes[0].sendtoaddress(new_addr, 1)
+
+        self.nodes[0].generate(1)
+        self.sync_all()
+
+        # send transaction that should not use all the available outputs
+        # per the current coin selection algorithm
+        self.nodes[1].sendtoaddress(ret_addr, 5)
+
+        # getbalances and listunspent should show the remaining outputs
+        # in the reused address as used/reused
+        assert_unspent(self.nodes[1], total_count=2, total_sum=6, reused_count=1, reused_sum=1)
+        assert_balances(self.nodes[1], mine={"used": 1, "trusted": 5})
+
 if __name__ == '__main__':
     AvoidReuseTest().main()