diff --git a/scribe/db/db.py b/scribe/db/db.py index 879e4ca..479325c 100644 --- a/scribe/db/db.py +++ b/scribe/db/db.py @@ -976,41 +976,44 @@ class HubDB: async def get_transactions_and_merkles(self, tx_hashes: Iterable[str]): tx_infos = {} - for tx_hash in tx_hashes: - tx_infos[tx_hash] = await asyncio.get_event_loop().run_in_executor( - self._executor, self._get_transaction_and_merkle, tx_hash - ) - await asyncio.sleep(0) - return tx_infos + needed = [] + needed_confirmed = [] + needed_mempool = [] + run_in_executor = asyncio.get_event_loop().run_in_executor - def _get_transaction_and_merkle(self, tx_hash): - cached_tx = self._tx_and_merkle_cache.get(tx_hash) - if cached_tx: - tx, merkle = cached_tx - else: - tx_hash_bytes = bytes.fromhex(tx_hash)[::-1] - tx_num = self.prefix_db.tx_num.get(tx_hash_bytes) - tx = None - tx_height = -1 - tx_num = None if not tx_num else tx_num.tx_num - if tx_num is not None: - if self._cache_all_claim_txos: - fill_cache = tx_num in self.txo_to_claim and len(self.txo_to_claim[tx_num]) > 0 - else: - fill_cache = True - tx_height = bisect_right(self.tx_counts, tx_num) - tx = self.prefix_db.tx.get(tx_hash_bytes, fill_cache=fill_cache, deserialize_value=False) - if tx_height == -1: - merkle = { - 'block_height': -1 - } - tx = self.prefix_db.mempool_tx.get(tx_hash_bytes, deserialize_value=False) + for tx_hash in tx_hashes: + cached_tx = self._tx_and_merkle_cache.get(tx_hash) + if cached_tx: + tx, merkle = cached_tx + tx_infos[tx_hash] = None if not tx else tx.hex(), merkle else: + tx_hash_bytes = bytes.fromhex(tx_hash)[::-1] + if self._cache_all_tx_hashes and tx_hash_bytes in self.tx_num_mapping: + needed_confirmed.append((tx_hash_bytes, self.tx_num_mapping[tx_hash_bytes])) + else: + needed.append(tx_hash_bytes) + + if needed: + for tx_hash_bytes, v in zip(needed, await run_in_executor( + self._executor, self.prefix_db.tx_num.multi_get, [(tx_hash,) for tx_hash in needed], + True, True)): + tx_num = None if v is None else v.tx_num + if tx_num is not None: + needed_confirmed.append((tx_hash_bytes, tx_num)) + else: + needed_mempool.append(tx_hash_bytes) + await asyncio.sleep(0) + + if needed_confirmed: + for (tx_hash_bytes, tx_num), tx in zip(needed_confirmed, await run_in_executor( + self._executor, self.prefix_db.tx.multi_get, [(tx_hash,) for tx_hash, _ in needed_confirmed], + True, False)): + tx_height = bisect_right(self.tx_counts, tx_num) tx_pos = tx_num - self.tx_counts[tx_height - 1] branch, root = self.merkle.branch_and_root( self.get_block_txs(tx_height), tx_pos ) - merkle = { + tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), { 'block_height': tx_height, 'merkle': [ hash_to_hex_str(_hash) @@ -1018,9 +1021,14 @@ class HubDB: ], 'pos': tx_pos } - if tx_height > 0 and tx_height + 10 < self.db_height: - self._tx_and_merkle_cache[tx_hash] = tx, merkle - return None if not tx else tx.hex(), merkle + await asyncio.sleep(0) + if needed_mempool: + for tx_hash_bytes, tx in zip(needed_mempool, await run_in_executor( + self._executor, self.prefix_db.mempool_tx.multi_get, [(tx_hash,) for tx_hash in needed_mempool], + True, False)): + tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {'block_height': -1} + await asyncio.sleep(0) + return tx_infos async def fs_block_hashes(self, height, count): if height + count > len(self.headers): diff --git a/scribe/db/interface.py b/scribe/db/interface.py index 57278f2..b4f430b 100644 --- a/scribe/db/interface.py +++ b/scribe/db/interface.py @@ -90,12 +90,9 @@ class PrefixRow(metaclass=PrefixRowType): def multi_get(self, key_args: typing.List[typing.Tuple], fill_cache=True, deserialize_value=True): packed_keys = {tuple(args): self.pack_key(*args) for args in key_args} - result = { - k[-1]: v for k, v in ( - self._db.multi_get([(self._column_family, packed_keys[tuple(args)]) for args in key_args], - fill_cache=fill_cache) or {} - ).items() - } + db_result = self._db.multi_get([(self._column_family, packed_keys[tuple(args)]) for args in key_args], + fill_cache=fill_cache) + result = {k[-1]: v for k, v in (db_result or {}).items()} def handle_value(v): return None if v is None else v if not deserialize_value else self.unpack_value(v)