From ff8c08b289fde493d73d6cecdafc11f4793048a4 Mon Sep 17 00:00:00 2001
From: Jack Robison <jackrobison@lbry.io>
Date: Sun, 16 Jan 2022 14:00:48 -0500
Subject: [PATCH] executors

---
 lbry/testcase.py                      |  3 ++-
 lbry/wallet/server/block_processor.py |  4 ++--
 lbry/wallet/server/chain_reader.py    |  5 ++---
 lbry/wallet/server/db/db.py           | 26 ++++++++++++++------------
 4 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/lbry/testcase.py b/lbry/testcase.py
index e761a2602..fbd3fb2e1 100644
--- a/lbry/testcase.py
+++ b/lbry/testcase.py
@@ -464,7 +464,8 @@ class CommandTestCase(IntegrationTestCase):
     async def confirm_tx(self, txid, ledger=None):
         """ Wait for tx to be in mempool, then generate a block, wait for tx to be in a block. """
         await self.on_transaction_id(txid, ledger)
-        await asyncio.wait([self.generate(1), self.on_transaction_id(txid, ledger)], timeout=5)
+        on_tx = self.on_transaction_id(txid, ledger)
+        await asyncio.wait([self.generate(1), on_tx], timeout=5)
         return txid
 
     async def on_transaction_dict(self, tx):
diff --git a/lbry/wallet/server/block_processor.py b/lbry/wallet/server/block_processor.py
index 643c71245..bd11e3d4e 100644
--- a/lbry/wallet/server/block_processor.py
+++ b/lbry/wallet/server/block_processor.py
@@ -86,11 +86,11 @@ class BlockProcessor:
         self.env = env
         self.state_lock = asyncio.Lock()
         self.daemon = env.coin.DAEMON(env.coin, env.daemon_url)
+        self._chain_executor = ThreadPoolExecutor(1, thread_name_prefix='block-processor')
         self.db = HubDB(
             env.coin, env.db_dir, env.cache_MB, env.reorg_limit, env.cache_all_claim_txos, env.cache_all_tx_hashes,
-            max_open_files=env.db_max_open_files
+            max_open_files=env.db_max_open_files, executor=self._chain_executor
         )
-        self._chain_executor = ThreadPoolExecutor(1, thread_name_prefix='block-processor')
         self.shutdown_event = asyncio.Event()
         self.coin = env.coin
         if env.coin.NET == 'mainnet':
diff --git a/lbry/wallet/server/chain_reader.py b/lbry/wallet/server/chain_reader.py
index cc7ff1b4c..27535b391 100644
--- a/lbry/wallet/server/chain_reader.py
+++ b/lbry/wallet/server/chain_reader.py
@@ -20,15 +20,15 @@ class BlockchainReader:
         self.log = logging.getLogger(__name__).getChild(self.__class__.__name__)
         self.shutdown_event = asyncio.Event()
         self.cancellable_tasks = []
+        self._executor = ThreadPoolExecutor(thread_workers, thread_name_prefix=thread_prefix)
 
         self.db = HubDB(
             env.coin, env.db_dir, env.cache_MB, env.reorg_limit, env.cache_all_claim_txos, env.cache_all_tx_hashes,
-            secondary_name=secondary_name, max_open_files=-1
+            secondary_name=secondary_name, max_open_files=-1, executor=self._executor
         )
         self.last_state: typing.Optional[DBState] = None
         self._refresh_interval = 0.1
         self._lock = asyncio.Lock()
-        self._executor = ThreadPoolExecutor(thread_workers, thread_name_prefix=thread_prefix)
 
     def _detect_changes(self):
         try:
@@ -241,7 +241,6 @@ class BlockchainReaderServer(BlockchainReader):
             pass
         finally:
             loop.run_until_complete(self.stop())
-            executor.shutdown(True)
 
     async def start_prometheus(self):
         if not self.prometheus_server and self.env.prometheus_port:
diff --git a/lbry/wallet/server/db/db.py b/lbry/wallet/server/db/db.py
index 679f9763e..188e90b01 100644
--- a/lbry/wallet/server/db/db.py
+++ b/lbry/wallet/server/db/db.py
@@ -11,6 +11,7 @@ from functools import partial
 from asyncio import sleep
 from bisect import bisect_right
 from collections import defaultdict
+from concurrent.futures.thread import ThreadPoolExecutor
 
 from lbry.error import ResolveCensoredError
 from lbry.schema.result import Censor
@@ -43,9 +44,10 @@ class HubDB:
 
     def __init__(self, coin, db_dir: str, cache_MB: int = 512, reorg_limit: int = 200,
                  cache_all_claim_txos: bool = False, cache_all_tx_hashes: bool = False,
-                 secondary_name: str = '', max_open_files: int = 256):
+                 secondary_name: str = '', max_open_files: int = 256, executor: ThreadPoolExecutor = None):
         self.logger = util.class_logger(__name__, self.__class__.__name__)
         self.coin = coin
+        self._executor = executor
         self._db_dir = db_dir
 
         self._cache_MB = cache_MB
@@ -332,7 +334,7 @@ class HubDB:
         return ExpandedResolveResult(resolved_stream, resolved_channel, repost, reposted_channel)
 
     async def resolve(self, url) -> ExpandedResolveResult:
-         return await asyncio.get_event_loop().run_in_executor(None, self._resolve, url)
+         return await asyncio.get_event_loop().run_in_executor(self._executor, self._resolve, url)
 
     def _fs_get_claim_by_hash(self, claim_hash):
         claim = self.get_cached_claim_txo(claim_hash)
@@ -417,7 +419,7 @@ class HubDB:
             self.filtered_streams, self.filtered_channels = self.get_streams_and_channels_reposted_by_channel_hashes(
                 self.filtering_channel_hashes
             )
-        await asyncio.get_event_loop().run_in_executor(None, reload)
+        await asyncio.get_event_loop().run_in_executor(self._executor, reload)
 
     def get_streams_and_channels_reposted_by_channel_hashes(self, reposter_channel_hashes: Set[bytes]):
         streams, channels = {}, {}
@@ -741,7 +743,7 @@ class HubDB:
                 v.tx_count for v in self.prefix_db.tx_count.iterate(include_key=False, fill_cache=False)
             ]
 
-        tx_counts = await asyncio.get_event_loop().run_in_executor(None, get_counts)
+        tx_counts = await asyncio.get_event_loop().run_in_executor(self._executor, get_counts)
         assert len(tx_counts) == self.db_height + 1, f"{len(tx_counts)} vs {self.db_height + 1}"
         self.tx_counts = array.array('I', tx_counts)
 
@@ -762,7 +764,7 @@ class HubDB:
         self.txo_to_claim.clear()
         start = time.perf_counter()
         self.logger.info("loading claims")
-        await asyncio.get_event_loop().run_in_executor(None, read_claim_txos)
+        await asyncio.get_event_loop().run_in_executor(self._executor, read_claim_txos)
         ts = time.perf_counter() - start
         self.logger.info("loaded %i claim txos in %ss", len(self.claim_to_txo), round(ts, 4))
 
@@ -777,7 +779,7 @@ class HubDB:
                 )
             ]
 
-        headers = await asyncio.get_event_loop().run_in_executor(None, get_headers)
+        headers = await asyncio.get_event_loop().run_in_executor(self._executor, get_headers)
         assert len(headers) - 1 == self.db_height, f"{len(headers)} vs {self.db_height}"
         self.headers = headers
 
@@ -789,7 +791,7 @@ class HubDB:
         self.total_transactions.clear()
         self.tx_num_mapping.clear()
         start = time.perf_counter()
-        self.total_transactions.extend(await asyncio.get_event_loop().run_in_executor(None, _read_tx_hashes))
+        self.total_transactions.extend(await asyncio.get_event_loop().run_in_executor(self._executor, _read_tx_hashes))
         self.tx_num_mapping = {
             tx_hash: tx_num for tx_num, tx_hash in enumerate(self.total_transactions)
         }
@@ -936,7 +938,7 @@ class HubDB:
             return x
 
         if disk_count:
-            return await asyncio.get_event_loop().run_in_executor(None, read_headers), disk_count
+            return await asyncio.get_event_loop().run_in_executor(self._executor, read_headers), disk_count
         return b'', 0
 
     def fs_tx_hash(self, tx_num):
@@ -1029,7 +1031,7 @@ class HubDB:
         transactions.  By default returns at most 1000 entries.  Set
         limit to None to get them all.
         """
-        return await asyncio.get_event_loop().run_in_executor(None, self.read_history, hashX, limit)
+        return await asyncio.get_event_loop().run_in_executor(self._executor, self.read_history, hashX, limit)
 
     # -- Undo information
 
@@ -1119,7 +1121,7 @@ class HubDB:
             return utxos
 
         while True:
-            utxos = await asyncio.get_event_loop().run_in_executor(None, read_utxos)
+            utxos = await asyncio.get_event_loop().run_in_executor(self._executor, read_utxos)
             if all(utxo.tx_hash is not None for utxo in utxos):
                 return utxos
             self.logger.warning(f'all_utxos: tx hash not '
@@ -1144,11 +1146,11 @@ class HubDB:
                 if utxo_value:
                     utxo_append((hashX, utxo_value.amount))
             return utxos
-        return await asyncio.get_event_loop().run_in_executor(None, lookup_utxos)
+        return await asyncio.get_event_loop().run_in_executor(self._executor, lookup_utxos)
 
     async def get_trending_notifications(self, height: int):
         def read_trending():
             return {
                 k.claim_hash: v for k, v in self.prefix_db.trending_notification.iterate((height,))
             }
-        return await asyncio.get_event_loop().run_in_executor(None, read_trending)
+        return await asyncio.get_event_loop().run_in_executor(self._executor, read_trending)