diff --git a/scribe/blockchain/service.py b/scribe/blockchain/service.py index 8a4c410..10b2e7c 100644 --- a/scribe/blockchain/service.py +++ b/scribe/blockchain/service.py @@ -181,7 +181,7 @@ class BlockchainProcessorService(BlockchainService): self.log.warning("failed to get a mempool tx, reorg underway?") return if current_mempool: - if bytes.fromhex(await self.daemon.getbestblockhash())[::-1] != self.coin.header_hash(self.db.headers[-1]): + if bytes.fromhex(await self.daemon.getbestblockhash())[::-1] != self.db.block_hashes[-1]: return await self.run_in_thread( update_mempool, self.db.prefix_db.unsafe_commit, self.db.prefix_db.mempool_tx, _to_put, current_mempool @@ -1417,6 +1417,7 @@ class BlockchainProcessorService(BlockchainService): self.height = height self.db.headers.append(block.header) + self.db.block_hashes.append(self.env.coin.header_hash(block.header)) self.tip = self.coin.header_hash(block.header) self.db.fs_height = self.height @@ -1493,8 +1494,9 @@ class BlockchainProcessorService(BlockchainService): # Check and update self.tip self.db.tx_counts.pop() - reverted_block_hash = self.coin.header_hash(self.db.headers.pop()) - self.tip = self.coin.header_hash(self.db.headers[-1]) + self.db.headers.pop() + reverted_block_hash = self.db.block_hashes.pop() + self.tip = self.db.block_hashes[-1] if self.env.cache_all_tx_hashes: while len(self.db.total_transactions) > self.db.tx_counts[-1]: self.db.tx_num_mapping.pop(self.db.total_transactions.pop()) diff --git a/scribe/db/db.py b/scribe/db/db.py index 879e4ca..af5a24d 100644 --- a/scribe/db/db.py +++ b/scribe/db/db.py @@ -78,6 +78,7 @@ class HubDB: self.tx_counts = None self.headers = None + self.block_hashes = None self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace='wallet_server') self.last_flush = time.time() @@ -775,6 +776,18 @@ class HubDB: assert len(headers) - 1 == self.db_height, f"{len(headers)} vs {self.db_height}" self.headers = headers + async def _read_block_hashes(self): + def get_block_hashes(): + return [ + block_hash for block_hash in self.prefix_db.block_hash.iterate( + start=(0, ), stop=(self.db_height + 1, ), include_key=False, fill_cache=False, deserialize_value=False + ) + ] + + block_hashes = await asyncio.get_event_loop().run_in_executor(self._executor, get_block_hashes) + assert len(block_hashes) == len(self.headers) + self.block_hashes = block_hashes + async def _read_tx_hashes(self): def _read_tx_hashes(): return list(self.prefix_db.tx_hash.iterate(start=(0,), stop=(self.db_tx_count + 1,), include_key=False, fill_cache=False, deserialize_value=False)) @@ -839,6 +852,7 @@ class HubDB: async def initialize_caches(self): await self._read_tx_counts() await self._read_headers() + await self._read_block_hashes() if self._cache_all_claim_txos: await self._read_claim_txos() if self._cache_all_tx_hashes: @@ -976,51 +990,74 @@ class HubDB: async def get_transactions_and_merkles(self, tx_hashes: Iterable[str]): tx_infos = {} - for tx_hash in tx_hashes: - tx_infos[tx_hash] = await asyncio.get_event_loop().run_in_executor( - self._executor, self._get_transaction_and_merkle, tx_hash - ) - await asyncio.sleep(0) - return tx_infos + needed = [] + needed_confirmed = [] + needed_mempool = [] + run_in_executor = asyncio.get_event_loop().run_in_executor - def _get_transaction_and_merkle(self, tx_hash): - cached_tx = self._tx_and_merkle_cache.get(tx_hash) - if cached_tx: - tx, merkle = cached_tx - else: - tx_hash_bytes = bytes.fromhex(tx_hash)[::-1] - tx_num = self.prefix_db.tx_num.get(tx_hash_bytes) - tx = None - tx_height = -1 - tx_num = None if not tx_num else tx_num.tx_num - if tx_num is not None: - if self._cache_all_claim_txos: - fill_cache = tx_num in self.txo_to_claim and len(self.txo_to_claim[tx_num]) > 0 - else: - fill_cache = True - tx_height = bisect_right(self.tx_counts, tx_num) - tx = self.prefix_db.tx.get(tx_hash_bytes, fill_cache=fill_cache, deserialize_value=False) - if tx_height == -1: - merkle = { - 'block_height': -1 - } - tx = self.prefix_db.mempool_tx.get(tx_hash_bytes, deserialize_value=False) + for tx_hash in tx_hashes: + cached_tx = self._tx_and_merkle_cache.get(tx_hash) + if cached_tx: + tx, merkle = cached_tx + tx_infos[tx_hash] = None if not tx else tx.hex(), merkle else: + tx_hash_bytes = bytes.fromhex(tx_hash)[::-1] + if self._cache_all_tx_hashes and tx_hash_bytes in self.tx_num_mapping: + needed_confirmed.append((tx_hash_bytes, self.tx_num_mapping[tx_hash_bytes])) + else: + needed.append(tx_hash_bytes) + + if needed: + for tx_hash_bytes, v in zip(needed, await run_in_executor( + self._executor, self.prefix_db.tx_num.multi_get, [(tx_hash,) for tx_hash in needed], + True, True)): + tx_num = None if v is None else v.tx_num + if tx_num is not None: + needed_confirmed.append((tx_hash_bytes, tx_num)) + else: + needed_mempool.append(tx_hash_bytes) + await asyncio.sleep(0) + + if needed_confirmed: + needed_heights = set() + tx_heights_and_positions = defaultdict(list) + for (tx_hash_bytes, tx_num), tx in zip(needed_confirmed, await run_in_executor( + self._executor, self.prefix_db.tx.multi_get, [(tx_hash,) for tx_hash, _ in needed_confirmed], + True, False)): + tx_height = bisect_right(self.tx_counts, tx_num) + needed_heights.add(tx_height) tx_pos = tx_num - self.tx_counts[tx_height - 1] - branch, root = self.merkle.branch_and_root( - self.get_block_txs(tx_height), tx_pos + tx_heights_and_positions[tx_height].append((tx_hash_bytes, tx, tx_num, tx_pos)) + + sorted_heights = list(sorted(needed_heights)) + block_txs = await run_in_executor( + self._executor, self.prefix_db.block_txs.multi_get, [(height,) for height in sorted_heights] + ) + block_txs = {height: v.tx_hashes for height, v in zip(sorted_heights, block_txs)} + for tx_height, v in tx_heights_and_positions.items(): + branches, root = self.merkle.branches_and_root( + block_txs[tx_height], [tx_pos for (tx_hash_bytes, tx, tx_num, tx_pos) in v] ) - merkle = { - 'block_height': tx_height, - 'merkle': [ - hash_to_hex_str(_hash) - for _hash in branch - ], - 'pos': tx_pos - } - if tx_height > 0 and tx_height + 10 < self.db_height: - self._tx_and_merkle_cache[tx_hash] = tx, merkle - return None if not tx else tx.hex(), merkle + for (tx_hash_bytes, tx, tx_num, tx_pos) in v: + merkle = { + 'block_height': tx_height, + 'merkle': [ + hash_to_hex_str(_hash) + for _hash in branches[tx_pos] + ], + 'pos': tx_pos + } + tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), merkle + if tx_height > 0 and tx_height + 10 < self.db_height: + self._tx_and_merkle_cache[tx_hash_bytes[::-1].hex()] = tx, merkle + await asyncio.sleep(0) + if needed_mempool: + for tx_hash_bytes, tx in zip(needed_mempool, await run_in_executor( + self._executor, self.prefix_db.mempool_tx.multi_get, [(tx_hash,) for tx_hash in needed_mempool], + True, False)): + tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {'block_height': -1} + await asyncio.sleep(0) + return tx_infos async def fs_block_hashes(self, height, count): if height + count > len(self.headers): diff --git a/scribe/db/interface.py b/scribe/db/interface.py index 57278f2..b4f430b 100644 --- a/scribe/db/interface.py +++ b/scribe/db/interface.py @@ -90,12 +90,9 @@ class PrefixRow(metaclass=PrefixRowType): def multi_get(self, key_args: typing.List[typing.Tuple], fill_cache=True, deserialize_value=True): packed_keys = {tuple(args): self.pack_key(*args) for args in key_args} - result = { - k[-1]: v for k, v in ( - self._db.multi_get([(self._column_family, packed_keys[tuple(args)]) for args in key_args], - fill_cache=fill_cache) or {} - ).items() - } + db_result = self._db.multi_get([(self._column_family, packed_keys[tuple(args)]) for args in key_args], + fill_cache=fill_cache) + result = {k[-1]: v for k, v in (db_result or {}).items()} def handle_value(v): return None if v is None else v if not deserialize_value else self.unpack_value(v) diff --git a/scribe/db/merkle.py b/scribe/db/merkle.py index 5e9a6a6..072f810 100644 --- a/scribe/db/merkle.py +++ b/scribe/db/merkle.py @@ -25,7 +25,7 @@ # and warranty status of this software. """Merkle trees, branches, proofs and roots.""" - +import typing from asyncio import Event from math import ceil, log @@ -87,6 +87,26 @@ class Merkle: return branch, hashes[0] + @staticmethod + def branches_and_root(block_tx_hashes: typing.List[bytes], tx_positions: typing.List[int]): + block_tx_hashes = list(block_tx_hashes) + positions = list(tx_positions) + length = ceil(log(len(block_tx_hashes), 2)) + branches = [[] for _ in range(len(tx_positions))] + for _ in range(length): + if len(block_tx_hashes) & 1: + h = block_tx_hashes[-1] + block_tx_hashes.append(h) + for idx, tx_position in enumerate(tx_positions): + h = block_tx_hashes[tx_position ^ 1] + branches[idx].append(h) + tx_positions[idx] >>= 1 + block_tx_hashes = [ + double_sha256(block_tx_hashes[n] + block_tx_hashes[n + 1]) for n in + range(0, len(block_tx_hashes), 2) + ] + return {tx_position: branch for tx_position, branch in zip(positions, branches)}, block_tx_hashes[0] + @staticmethod def root(hashes, length=None): """Return the merkle root of a non-empty iterable of binary hashes.""" diff --git a/scribe/elasticsearch/service.py b/scribe/elasticsearch/service.py index 4f0739e..f22c195 100644 --- a/scribe/elasticsearch/service.py +++ b/scribe/elasticsearch/service.py @@ -231,7 +231,7 @@ class ElasticSyncService(BlockchainReaderService): self._advanced = True def unwind(self): - reverted_block_hash = self.db.coin.header_hash(self.db.headers[-1]) + reverted_block_hash = self.db.block_hashes[-1] super().unwind() packed = self.db.prefix_db.undo.get(len(self.db.tx_counts), reverted_block_hash) touched_or_deleted = None diff --git a/scribe/service.py b/scribe/service.py index a1d7fff..b43ea4f 100644 --- a/scribe/service.py +++ b/scribe/service.py @@ -157,7 +157,9 @@ class BlockchainReaderService(BlockchainService): self.db.total_transactions.append(tx_hash) self.db.tx_num_mapping[tx_hash] = tx_count assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}" - self.db.headers.append(self.db.prefix_db.header.get(height, deserialize_value=False)) + header = self.db.prefix_db.header.get(height, deserialize_value=False) + self.db.headers.append(header) + self.db.block_hashes.append(self.env.coin.header_hash(header)) def unwind(self): """ @@ -166,6 +168,7 @@ class BlockchainReaderService(BlockchainService): prev_count = self.db.tx_counts.pop() tx_count = self.db.tx_counts[-1] self.db.headers.pop() + self.db.block_hashes.pop() if self.db._cache_all_tx_hashes: for _ in range(prev_count - tx_count): self.db.tx_num_mapping.pop(self.db.total_transactions.pop())