diff --git a/lbry/testcase.py b/lbry/testcase.py
index 6dc1e2eb9..1efefff17 100644
--- a/lbry/testcase.py
+++ b/lbry/testcase.py
@@ -340,6 +340,7 @@ class CommandTestCase(IntegrationTestCase):
         server_tmp_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, server_tmp_dir)
         self.server_config = Config()
+        self.server_config.transaction_cache_size = 10000
         self.server_storage = SQLiteStorage(self.server_config, ':memory:')
         await self.server_storage.open()
@@ -389,6 +390,7 @@ class CommandTestCase(IntegrationTestCase):
         conf.fixed_peers = [('127.0.0.1', 5567)]
         conf.known_dht_nodes = []
         conf.blob_lru_cache_size = self.blob_lru_cache_size
+        conf.transaction_cache_size = 10000
         conf.components_to_skip = [
             DHT_COMPONENT, UPNP_COMPONENT, HASH_ANNOUNCER_COMPONENT,
             PEER_PROTOCOL_SERVER_COMPONENT
diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py
index 42fc1a70f..aef2c6811 100644
--- a/lbry/wallet/database.py
+++ b/lbry/wallet/database.py
@@ -159,7 +159,7 @@ class AIOSQLite:
             await self.read_ready.wait()
             still_waiting = True
         if self._closing:
-            raise asyncio.CancelledError
+            raise asyncio.CancelledError()
         return await asyncio.get_event_loop().run_in_executor(
             self.reader_executor, read_only_fn, sql, parameters
         )
@@ -203,7 +203,7 @@ class AIOSQLite:
         try:
             async with self.write_lock:
                 if self._closing:
-                    raise asyncio.CancelledError
+                    raise asyncio.CancelledError()
                 return await asyncio.get_event_loop().run_in_executor(
                     self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
                 )
@@ -240,7 +240,7 @@ class AIOSQLite:
         try:
             async with self.write_lock:
                 if self._closing:
-                    raise asyncio.CancelledError
+                    raise asyncio.CancelledError()
                 return await asyncio.get_event_loop().run_in_executor(
                     self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs
                 )
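A side note on the `AIOSQLite` hunks above: `raise` instantiates an exception class automatically, so `raise asyncio.CancelledError` and `raise asyncio.CancelledError()` behave identically at runtime; the patch just standardizes on the explicit call. A minimal, self-contained illustration (not part of the patch):

    import asyncio

    for raiser in (asyncio.CancelledError, asyncio.CancelledError()):
        try:
            raise raiser  # `raise Cls` implicitly calls Cls()
        except asyncio.CancelledError as e:
            assert isinstance(e, asyncio.CancelledError)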
diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py
index d03862dd4..b9487aeba 100644
--- a/lbry/wallet/ledger.py
+++ b/lbry/wallet/ledger.py
@@ -156,7 +156,7 @@ class Ledger(metaclass=LedgerRegistry):
         self._on_ready_controller = StreamController()
         self.on_ready = self._on_ready_controller.stream

-        self._tx_cache = pylru.lrucache(self.config.get("tx_cache_size", 10_000_000))
+        self._tx_cache = pylru.lrucache(self.config.get("tx_cache_size", 10_000))
         self._update_tasks = TaskGroup()
         self._other_tasks = TaskGroup()  # that we dont need to start
         self._utxo_reservation_lock = asyncio.Lock()
@@ -578,7 +578,7 @@ class Ledger(metaclass=LedgerRegistry):
                     log.warning("history mismatch: %s vs %s", remote_history[remote_i], pending_synced_history[i])
                 synced_history += pending_synced_history[i]
-        cache_size = self.config.get("tx_cache_size", 10_000_000)
+        cache_size = self.config.get("tx_cache_size", 10_000)
         for txid, cache_item in updated_cached_items.items():
             cache_item.pending_verifications -= 1
             if cache_item.pending_verifications < 0:
@@ -654,6 +654,34 @@ class Ledger(metaclass=LedgerRegistry):
             tx.position = merkle['pos']
             tx.is_verified = merkle_root == header['merkle_root']

+    async def _single_batch(self, batch, remote_heights, header_cache, transactions):
+        batch_result = await self.network.retriable_call(
+            self.network.get_transaction_batch, batch
+        )
+        for txid, (raw, merkle) in batch_result.items():
+            remote_height = remote_heights[txid]
+            merkle_height = merkle['block_height']
+            cache_item = self._tx_cache.get(txid)
+            if cache_item is None:
+                cache_item = TransactionCacheItem()
+                self._tx_cache[txid] = cache_item
+            tx = cache_item.tx or Transaction(bytes.fromhex(raw.decode() if isinstance(raw, bytes) else raw),
+                                              height=remote_height)
+            tx.height = remote_height
+            cache_item.tx = tx
+            if 'merkle' in merkle and remote_heights[txid] > 0:
+                merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
+                try:
+                    header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height))
+                except IndexError:
+                    log.warning("failed to verify %s at height %i", tx.id, merkle_height)
+                else:
+                    header_cache[remote_heights[txid]] = header
+                    tx.position = merkle['pos']
+                    tx.is_verified = merkle_root == header['merkle_root']
+            transactions.append(tx)
+        return transactions
+
     async def request_transactions_for_inflate(self, to_request: Tuple[Tuple[str, int], ...]):
         header_cache = {}
         batches = [[]]
@@ -674,34 +702,8 @@ class Ledger(metaclass=LedgerRegistry):
         if not batches[-1]:
             batches.pop()

-        async def _single_batch(batch):
-            batch_result = await self.network.retriable_call(
-                self.network.get_transaction_batch, batch, restricted=False
-            )
-            for txid, (raw, merkle) in batch_result.items():
-                remote_height = remote_heights[txid]
-                merkle_height = merkle['block_height']
-                cache_item = self._tx_cache.get(txid)
-                if cache_item is None:
-                    cache_item = TransactionCacheItem()
-                    self._tx_cache[txid] = cache_item
-                tx = cache_item.tx or Transaction(bytes.fromhex(raw), height=remote_height)
-                tx.height = remote_height
-                cache_item.tx = tx
-                if 'merkle' in merkle and remote_heights[txid] > 0:
-                    merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
-                    try:
-                        header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height))
-                    except IndexError:
-                        log.warning("failed to verify %s at height %i", tx.id, merkle_height)
-                    else:
-                        header_cache[remote_heights[txid]] = header
-                        tx.position = merkle['pos']
-                        tx.is_verified = merkle_root == header['merkle_root']
-                transactions.append(tx)
-
         for batch in batches:
-            await _single_batch(batch)
+            await self._single_batch(batch, remote_heights, header_cache, transactions)
         return transactions

     async def _request_transaction_batch(self, to_request, remote_history_size, address):
@@ -728,28 +730,10 @@
         last_showed_synced_count = 0

         async def _single_batch(batch):
+            transactions = await self._single_batch(batch, remote_heights, header_cache, [])
             this_batch_synced = []
-            batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch)
-            for txid, (raw, merkle) in batch_result.items():
-                remote_height = remote_heights[txid]
-                merkle_height = merkle['block_height']
-                cache_item = self._tx_cache.get(txid)
-                if cache_item is None:
-                    cache_item = TransactionCacheItem()
-                    self._tx_cache[txid] = cache_item
-                tx = cache_item.tx or Transaction(unhexlify(raw), height=remote_height)
-                tx.height = remote_height
-                cache_item.tx = tx
-                if 'merkle' in merkle and remote_heights[txid] > 0:
-                    merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
-                    try:
-                        header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height))
-                    except IndexError:
-                        log.warning("failed to verify %s at height %i", tx.id, merkle_height)
-                    else:
-                        header_cache[remote_heights[txid]] = header
-                        tx.position = merkle['pos']
-                        tx.is_verified = merkle_root == header['merkle_root']
+
+            for tx in transactions:
                 check_db_for_txos = []
                 for txi in tx.inputs:
@@ -783,6 +767,7 @@
                     synced_txs.append(tx)
                     this_batch_synced.append(tx)
+
             await self.db.save_transaction_io_batch(
                 this_batch_synced, address, self.address_to_hash160(address), ""
             )
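The ledger hunks above cut the default `tx_cache_size` from 10,000,000 to 10,000 entries and deduplicate the twice-copied batch loop into one `_single_batch` method. `pylru.lrucache` is a bounded, dict-like LRU cache, so the smaller default caps wallet memory at the cost of refetching old transactions. A quick sketch of the eviction behavior (illustrative only, assuming the `pylru` package):

    import pylru

    cache = pylru.lrucache(10_000)           # the new default bound
    for i in range(10_001):
        cache[f"txid-{i}"] = object()        # stand-in for a TransactionCacheItem

    assert len(cache) == 10_000
    assert "txid-0" not in cache             # least-recently-used entry was evicted
    assert "txid-10000" in cache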
@@ -794,8 +779,10 @@
             if last_showed_synced_count + 100 < len(synced_txs):
                 log.info("synced %i/%i transactions for %s", len(synced_txs), remote_history_size, address)
                 last_showed_synced_count = len(synced_txs)
+
         for batch in batches:
             await _single_batch(batch)
+
         return synced_txs

     async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
diff --git a/lbry/wallet/server/mempool.py b/lbry/wallet/server/mempool.py
index e253dc7c5..4a0a80f30 100644
--- a/lbry/wallet/server/mempool.py
+++ b/lbry/wallet/server/mempool.py
@@ -366,7 +366,7 @@ class MemPool:
             result.update(tx.prevouts)
         return result

-    async def transaction_summaries(self, hashX):
+    def transaction_summaries(self, hashX):
        """Return a list of MemPoolTxSummary objects for the hashX."""
         result = []
         for tx_hash in self.hashXs.get(hashX, ()):
diff --git a/lbry/wallet/server/session.py b/lbry/wallet/server/session.py
index 3696f8bc8..cbbb2b30d 100644
--- a/lbry/wallet/server/session.py
+++ b/lbry/wallet/server/session.py
@@ -190,7 +190,9 @@ class SessionManager:
         self.shutdown_event = shutdown_event
         self.logger = util.class_logger(__name__, self.__class__.__name__)
         self.servers: typing.Dict[str, asyncio.AbstractServer] = {}
-        self.sessions: typing.Set['SessionBase'] = set()
+        self.sessions: typing.Dict[int, 'SessionBase'] = {}
+        self.hashx_subscriptions_by_session: typing.DefaultDict[str, typing.Set[int]] = defaultdict(set)
+        self.mempool_statuses = {}
        self.cur_group = SessionGroup(0)
         self.txs_sent = 0
         self.start_time = time.time()
@@ -276,12 +278,12 @@
     def _group_map(self):
         group_map = defaultdict(list)
-        for session in self.sessions:
+        for session in self.sessions.values():
             group_map[session.group].append(session)
         return group_map

     def _sub_count(self) -> int:
-        return sum(s.sub_count() for s in self.sessions)
+        return sum(s.sub_count() for s in self.sessions.values())

     def _lookup_session(self, session_id):
         try:
             session_id = int(session_id)
@@ -289,7 +291,7 @@
         except Exception:
             pass
         else:
-            for session in self.sessions:
+            for session in self.sessions.values():
                 if session.session_id == session_id:
                     return session
         return None
@@ -313,7 +315,7 @@
         while True:
             await sleep(session_timeout // 10)
             stale_cutoff = time.perf_counter() - session_timeout
-            stale_sessions = [session for session in self.sessions
+            stale_sessions = [session for session in self.sessions.values()
                               if session.last_recv < stale_cutoff]
             if stale_sessions:
                 text = ', '.join(str(session.session_id)
@@ -345,7 +347,7 @@
         pending_requests = 0
         closing = 0

-        for s in self.sessions:
+        for s in self.sessions.values():
             error_count += s.errors
             if s.log_me:
                 logged += 1
@@ -379,7 +381,7 @@
     def _session_data(self, for_log):
         """Returned to the RPC 'sessions' call."""
         now = time.time()
-        sessions = sorted(self.sessions, key=lambda s: s.start_time)
+        sessions = sorted(self.sessions.values(), key=lambda s: s.start_time)
         return [(session.session_id,
                  session.flags(), session.peer_address_str(for_log=for_log),
@@ -583,7 +585,7 @@
         await self._close_servers(list(self.servers.keys()))
         if self.sessions:
             await asyncio.wait([
-                session.close(force_after=1) for session in self.sessions
+                session.close(force_after=1) for session in self.sessions.values()
             ])
         await self.stop_other()
@@ -638,13 +640,37 @@
         height_changed = height != self.notified_height
         if height_changed:
             await self._refresh_hsub_results(height)
-        if self.sessions:
-            await asyncio.wait([
-                session.notify(touched, height_changed) for session in self.sessions
-            ])
+        if not self.sessions:
+            return
+
+        if height_changed:
+            header_tasks = [
+                session.send_notification('blockchain.headers.subscribe', (self.hsub_results[session.subscribe_headers_raw], ))
+                for session in self.sessions.values() if session.subscribe_headers
+            ]
+            if header_tasks:
+                await asyncio.wait(header_tasks)
+
+        touched = touched.intersection(self.hashx_subscriptions_by_session.keys())
+
+        if touched or (height_changed and self.mempool_statuses):
+            mempool_hashxs = set(self.mempool_statuses.keys())
+            notified = set()
+            for hashX in touched:
+                for session_id in self.hashx_subscriptions_by_session[hashX]:
+                    asyncio.create_task(self.sessions[session_id].send_history_notification(hashX))
+                notified.add(hashX)
+            for hashX in mempool_hashxs.difference(touched):
+                for session_id in self.hashx_subscriptions_by_session[hashX]:
+                    asyncio.create_task(self.sessions[session_id].send_history_notification(hashX))
+                notified.add(hashX)
+
+            if touched:
+                es = '' if len(touched) == 1 else 'es'
+                self.logger.info(f'notified {len(notified)} mempool/{len(touched):,d} touched address{es}')

     def add_session(self, session):
-        self.sessions.add(session)
+        self.sessions[id(session)] = session
         self.session_event.set()
         gid = int(session.start_time - self.start_time) // 900
         if self.cur_group.gid != gid:
@@ -653,7 +679,13 @@

     def remove_session(self, session):
         """Remove a session from our sessions list if there."""
-        self.sessions.remove(session)
+        session_id = id(session)
+        for hashX in session.hashX_subs:
+            sessions = self.hashx_subscriptions_by_session[hashX]
+            sessions.remove(session_id)
+            if not sessions:
+                self.hashx_subscriptions_by_session.pop(hashX)
+        self.sessions.pop(session_id)
         self.session_event.set()
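The `SessionManager` changes above replace the flat session set with two indexes: `sessions` keyed by `id(session)`, and a reverse map from each subscribed hashX to the ids of interested sessions, so a notification touches only subscribers instead of scanning every connection. A toy model of that bookkeeping (hypothetical `ToySession`/`subscribe`/`notify` helpers, not the server's API):

    from collections import defaultdict

    sessions = {}                                      # id(session) -> session
    hashx_subscriptions_by_session = defaultdict(set)  # hashX -> {session ids}

    class ToySession:
        def __init__(self):
            self.hashX_subs = {}

    def subscribe(session, hashX, alias):
        sessions[id(session)] = session
        session.hashX_subs[hashX] = alias
        hashx_subscriptions_by_session[hashX].add(id(session))

    def notify(touched):
        # only hashXs someone actually subscribed to are even looked at
        for hashX in touched & hashx_subscriptions_by_session.keys():
            for session_id in hashx_subscriptions_by_session[hashX]:
                print(f"notify session {session_id} about {hashX}")

    a, b = ToySession(), ToySession()
    subscribe(a, "aa" * 32, "alias-a")
    subscribe(b, "bb" * 32, "alias-b")
    notify({"aa" * 32, "cc" * 32})   # only session `a` is notified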
@@ -688,8 +720,6 @@ class SessionBase(RPCSession):
         self._receive_message_orig = self.connection.receive_message
         self.connection.receive_message = self.receive_message

-    async def notify(self, touched, height_changed):
-        pass

     def default_framer(self):
         return NewlineFramer(self.env.max_receive)
@@ -886,7 +916,6 @@ class LBRYElectrumX(SessionBase):
         self.connection.max_response_size = self.env.max_send
         self.hashX_subs = {}
         self.sv_seen = False
-        self.mempool_statuses = {}
         self.protocol_tuple = self.PROTOCOL_MIN

         self.daemon = self.session_mgr.daemon
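Both the old per-session `notify` and the new manager-level fanout dispatch history notifications with `asyncio.create_task` rather than awaiting each one, so block processing never stalls behind a slow client socket. A runnable sketch of that fire-and-forget pattern (toy names, not the server code):

    import asyncio

    async def send_history_notification(hashX):
        await asyncio.sleep(0.1)      # stand-in for address_status + send_notification
        print(f"notified for {hashX}")

    async def fan_out(touched):
        for hashX in touched:
            asyncio.create_task(send_history_notification(hashX))
        # the caller returns immediately; give the demo time to drain
        await asyncio.sleep(0.2)

    asyncio.run(fan_out({"aa" * 32, "bb" * 32}))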
- """ - if height_changed and self.subscribe_headers: - args = (await self.subscribe_headers_result(), ) - if not (await self.send_notification('blockchain.headers.subscribe', args)): - return - - async def send_history_notification(alias, hashX): + async def send_history_notification(self, hashX): + start = time.perf_counter() + alias = self.hashX_subs[hashX] + if len(alias) == 64: + method = 'blockchain.scripthash.subscribe' + else: + method = 'blockchain.address.subscribe' + try: + self.session_mgr.notifications_in_flight_metric.inc() + status = await self.address_status(hashX) + self.session_mgr.address_history_metric.observe(time.perf_counter() - start) start = time.perf_counter() - if len(alias) == 64: - method = 'blockchain.scripthash.subscribe' - else: - method = 'blockchain.address.subscribe' - try: - self.session_mgr.notifications_in_flight_metric.inc() - status = await self.address_status(hashX) - self.session_mgr.address_history_metric.observe(time.perf_counter() - start) - start = time.perf_counter() - await self.send_notification(method, (alias, status)) - self.session_mgr.notifications_sent_metric.observe(time.perf_counter() - start) - finally: - self.session_mgr.notifications_in_flight_metric.dec() - - touched = touched.intersection(self.hashX_subs) - if touched or (height_changed and self.mempool_statuses): - notified = set() - mempool_addrs = tuple(self.mempool_statuses.keys()) - for hashX in touched: - alias = self.hashX_subs[hashX] - asyncio.create_task(send_history_notification(alias, hashX)) - notified.add(hashX) - for hashX in mempool_addrs: - if hashX not in notified: - alias = self.hashX_subs[hashX] - asyncio.create_task(send_history_notification(alias, hashX)) - notified.add(hashX) - - if touched: - es = '' if len(touched) == 1 else 'es' - self.logger.info(f'notified {len(notified)} mempool/{len(touched):,d} touched address{es}') + await self.send_notification(method, (alias, status)) + self.session_mgr.notifications_sent_metric.observe(time.perf_counter() - start) + finally: + self.session_mgr.notifications_in_flight_metric.dec() def get_metrics_or_placeholder_for_api(self, query_name): """ Do not hold on to a reference to the metrics @@ -1189,7 +1192,7 @@ class LBRYElectrumX(SessionBase): # For mempool, height is -1 if it has unconfirmed inputs, otherwise 0 db_history = await self.session_mgr.limited_history(hashX) - mempool = await self.mempool.transaction_summaries(hashX) + mempool = self.mempool.transaction_summaries(hashX) status = ''.join(f'{hash_to_hex_str(tx_hash)}:' f'{height:d}:' @@ -1203,9 +1206,9 @@ class LBRYElectrumX(SessionBase): status = None if mempool: - self.mempool_statuses[hashX] = status + self.session_mgr.mempool_statuses[hashX] = status else: - self.mempool_statuses.pop(hashX, None) + self.session_mgr.mempool_statuses.pop(hashX, None) return status async def hashX_listunspent(self, hashX): @@ -1224,9 +1227,11 @@ class LBRYElectrumX(SessionBase): async def hashX_subscribe(self, hashX, alias): self.hashX_subs[hashX] = alias + self.session_mgr.hashx_subscriptions_by_session[hashX].add(id(self)) return await self.address_status(hashX) async def hashX_unsubscribe(self, hashX, alias): + self.session_mgr.hashx_subscriptions_by_session[hashX].remove(id(self)) self.hashX_subs.pop(hashX, None) def address_to_hashX(self, address): @@ -1249,7 +1254,7 @@ class LBRYElectrumX(SessionBase): async def address_get_mempool(self, address): """Return the mempool transactions touching an address.""" hashX = self.address_to_hashX(address) - return await 
@@ -1249,7 +1254,7 @@
     async def address_get_mempool(self, address):
         """Return the mempool transactions touching an address."""
         hashX = self.address_to_hashX(address)
-        return await self.unconfirmed_history(hashX)
+        return self.unconfirmed_history(hashX)

     async def address_listunspent(self, address):
         """Return the list of UTXOs of an address."""
@@ -1285,20 +1290,20 @@
         hashX = scripthash_to_hashX(scripthash)
         return await self.get_balance(hashX)

-    async def unconfirmed_history(self, hashX):
+    def unconfirmed_history(self, hashX):
         # Note unconfirmed history is unordered in electrum-server
         # height is -1 if it has unconfirmed inputs, otherwise 0
         return [{'tx_hash': hash_to_hex_str(tx.hash),
                  'height': -tx.has_unconfirmed_inputs,
                  'fee': tx.fee}
-                for tx in await self.mempool.transaction_summaries(hashX)]
+                for tx in self.mempool.transaction_summaries(hashX)]

     async def confirmed_and_unconfirmed_history(self, hashX):
         # Note history is ordered but unconfirmed is unordered in e-s
         history = await self.session_mgr.limited_history(hashX)
         conf = [{'tx_hash': hash_to_hex_str(tx_hash), 'height': height}
                 for tx_hash, height in history]
-        return conf + await self.unconfirmed_history(hashX)
+        return conf + self.unconfirmed_history(hashX)

     async def scripthash_get_history(self, scripthash):
         """Return the confirmed and unconfirmed history of a scripthash."""
@@ -1308,7 +1313,7 @@
     async def scripthash_get_mempool(self, scripthash):
         """Return the mempool transactions touching a scripthash."""
         hashX = scripthash_to_hashX(scripthash)
-        return await self.unconfirmed_history(hashX)
+        return self.unconfirmed_history(hashX)

     async def scripthash_listunspent(self, scripthash):
         """Return the list of UTXOs of a scripthash."""
diff --git a/tests/integration/blockchain/test_transactions.py b/tests/integration/blockchain/test_transactions.py
index b09926034..8690698a6 100644
--- a/tests/integration/blockchain/test_transactions.py
+++ b/tests/integration/blockchain/test_transactions.py
@@ -138,8 +138,8 @@ class BasicTransactionTests(IntegrationTestCase):
         # evil trick: mempool is unsorted on real life, but same order between python instances. reproduce it
         original_summary = self.conductor.spv_node.server.mempool.transaction_summaries

-        async def random_summary(*args, **kwargs):
-            summary = await original_summary(*args, **kwargs)
+        def random_summary(*args, **kwargs):
+            summary = original_summary(*args, **kwargs)
             if summary and len(summary) > 2:
                 ordered = summary.copy()
                 while summary == ordered:
diff --git a/tests/integration/blockchain/test_wallet_commands.py b/tests/integration/blockchain/test_wallet_commands.py
index 4f8f0d6ed..f556d86bb 100644
--- a/tests/integration/blockchain/test_wallet_commands.py
+++ b/tests/integration/blockchain/test_wallet_commands.py
@@ -10,7 +10,7 @@ from lbry.wallet.dewies import dict_values_to_lbc

 class WalletCommands(CommandTestCase):

     async def test_wallet_create_and_add_subscribe(self):
-        session = next(iter(self.conductor.spv_node.server.session_mgr.sessions))
+        session = next(iter(self.conductor.spv_node.server.session_mgr.sessions.values()))
         self.assertEqual(len(session.hashX_subs), 27)
         wallet = await self.daemon.jsonrpc_wallet_create('foo', create_account=True, single_key=True)
         self.assertEqual(len(session.hashX_subs), 28)
diff --git a/tests/unit/blob_exchange/test_transfer_blob.py b/tests/unit/blob_exchange/test_transfer_blob.py
index 8968116af..ed46dbeff 100644
--- a/tests/unit/blob_exchange/test_transfer_blob.py
+++ b/tests/unit/blob_exchange/test_transfer_blob.py
@@ -36,12 +36,14 @@ class BlobExchangeTestBase(AsyncioTestCase):
         self.addCleanup(shutil.rmtree, self.server_dir)
         self.server_config = Config(data_dir=self.server_dir, download_dir=self.server_dir, wallet=self.server_dir,
                                     fixed_peers=[])
+        self.server_config.transaction_cache_size = 10000
         self.server_storage = SQLiteStorage(self.server_config, os.path.join(self.server_dir, "lbrynet.sqlite"))
         self.server_blob_manager = BlobManager(self.loop, self.server_dir, self.server_storage, self.server_config)
         self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP')

         self.client_config = Config(data_dir=self.client_dir, download_dir=self.client_dir, wallet=self.client_wallet_dir,
                                     fixed_peers=[])
+        self.client_config.transaction_cache_size = 10000
         self.client_storage = SQLiteStorage(self.client_config, os.path.join(self.client_dir, "lbrynet.sqlite"))
         self.client_blob_manager = BlobManager(self.loop, self.client_dir, self.client_storage, self.client_config)
         self.client_peer_manager = PeerManager(self.loop)
diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py
index fa307d65f..904cc1574 100644
--- a/tests/unit/stream/test_stream_manager.py
+++ b/tests/unit/stream/test_stream_manager.py
@@ -98,12 +98,14 @@ async def get_mock_wallet(sd_hash, storage, wallet_dir, balance=10.0, fee=None):
     wallet = Wallet()
     ledger = Ledger({
         'db': Database(os.path.join(wallet_dir, 'blockchain.db')),
-        'headers': FakeHeaders(514082)
+        'headers': FakeHeaders(514082),
+        'tx_cache_size': 10000
     })
     await ledger.db.open()
     wallet.generate_account(ledger)
     manager = WalletManager()
     manager.config = Config()
+    manager.config.transaction_cache_size = 10000
     manager.wallets.append(wallet)
     manager.ledgers[Ledger] = ledger
     manager.ledger.network.client = ClientSession(
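One last detail worth flagging: the knob is spelled two ways. The daemon-level `Config` exposes `transaction_cache_size` (what the test fixtures above pin to 10000), while `Ledger` reads a `tx_cache_size` key from its own config dict, as `test_stream_manager.py` shows. A sketch of pinning both in a test, mirroring the fixtures (the `lbry.conf` import path is an assumption; the hunks only show bare `Config()` calls):

    from lbry.conf import Config   # assumed import path

    config = Config()
    config.transaction_cache_size = 10000   # daemon/test-case setting

    ledger_config = {
        'tx_cache_size': 10000,             # what Ledger reads via config.get()
    }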