use multi_get for fetching transactions #25

Merged
jackrobison merged 5 commits from multiget-transactions into master 2022-04-16 17:19:13 +02:00
6 changed files with 112 additions and 53 deletions

View file

@ -181,7 +181,7 @@ class BlockchainProcessorService(BlockchainService):
self.log.warning("failed to get a mempool tx, reorg underway?") self.log.warning("failed to get a mempool tx, reorg underway?")
return return
if current_mempool: if current_mempool:
if bytes.fromhex(await self.daemon.getbestblockhash())[::-1] != self.coin.header_hash(self.db.headers[-1]): if bytes.fromhex(await self.daemon.getbestblockhash())[::-1] != self.db.block_hashes[-1]:
return return
await self.run_in_thread( await self.run_in_thread(
update_mempool, self.db.prefix_db.unsafe_commit, self.db.prefix_db.mempool_tx, _to_put, current_mempool update_mempool, self.db.prefix_db.unsafe_commit, self.db.prefix_db.mempool_tx, _to_put, current_mempool
@ -1417,6 +1417,7 @@ class BlockchainProcessorService(BlockchainService):
self.height = height self.height = height
self.db.headers.append(block.header) self.db.headers.append(block.header)
self.db.block_hashes.append(self.env.coin.header_hash(block.header))
self.tip = self.coin.header_hash(block.header) self.tip = self.coin.header_hash(block.header)
self.db.fs_height = self.height self.db.fs_height = self.height
@ -1493,8 +1494,9 @@ class BlockchainProcessorService(BlockchainService):
# Check and update self.tip # Check and update self.tip
self.db.tx_counts.pop() self.db.tx_counts.pop()
reverted_block_hash = self.coin.header_hash(self.db.headers.pop()) self.db.headers.pop()
self.tip = self.coin.header_hash(self.db.headers[-1]) reverted_block_hash = self.db.block_hashes.pop()
self.tip = self.db.block_hashes[-1]
if self.env.cache_all_tx_hashes: if self.env.cache_all_tx_hashes:
while len(self.db.total_transactions) > self.db.tx_counts[-1]: while len(self.db.total_transactions) > self.db.tx_counts[-1]:
self.db.tx_num_mapping.pop(self.db.total_transactions.pop()) self.db.tx_num_mapping.pop(self.db.total_transactions.pop())

View file

@ -78,6 +78,7 @@ class HubDB:
self.tx_counts = None self.tx_counts = None
self.headers = None self.headers = None
self.block_hashes = None
self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace='wallet_server') self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace='wallet_server')
self.last_flush = time.time() self.last_flush = time.time()
@ -775,6 +776,18 @@ class HubDB:
assert len(headers) - 1 == self.db_height, f"{len(headers)} vs {self.db_height}" assert len(headers) - 1 == self.db_height, f"{len(headers)} vs {self.db_height}"
self.headers = headers self.headers = headers
async def _read_block_hashes(self):
def get_block_hashes():
return [
block_hash for block_hash in self.prefix_db.block_hash.iterate(
start=(0, ), stop=(self.db_height + 1, ), include_key=False, fill_cache=False, deserialize_value=False
)
]
block_hashes = await asyncio.get_event_loop().run_in_executor(self._executor, get_block_hashes)
assert len(block_hashes) == len(self.headers)
self.block_hashes = block_hashes
async def _read_tx_hashes(self): async def _read_tx_hashes(self):
def _read_tx_hashes(): def _read_tx_hashes():
return list(self.prefix_db.tx_hash.iterate(start=(0,), stop=(self.db_tx_count + 1,), include_key=False, fill_cache=False, deserialize_value=False)) return list(self.prefix_db.tx_hash.iterate(start=(0,), stop=(self.db_tx_count + 1,), include_key=False, fill_cache=False, deserialize_value=False))
@ -839,6 +852,7 @@ class HubDB:
async def initialize_caches(self): async def initialize_caches(self):
await self._read_tx_counts() await self._read_tx_counts()
await self._read_headers() await self._read_headers()
await self._read_block_hashes()
if self._cache_all_claim_txos: if self._cache_all_claim_txos:
await self._read_claim_txos() await self._read_claim_txos()
if self._cache_all_tx_hashes: if self._cache_all_tx_hashes:
@ -976,51 +990,74 @@ class HubDB:
async def get_transactions_and_merkles(self, tx_hashes: Iterable[str]): async def get_transactions_and_merkles(self, tx_hashes: Iterable[str]):
tx_infos = {} tx_infos = {}
for tx_hash in tx_hashes: needed = []
tx_infos[tx_hash] = await asyncio.get_event_loop().run_in_executor( needed_confirmed = []
self._executor, self._get_transaction_and_merkle, tx_hash needed_mempool = []
) run_in_executor = asyncio.get_event_loop().run_in_executor
await asyncio.sleep(0)
return tx_infos
def _get_transaction_and_merkle(self, tx_hash): for tx_hash in tx_hashes:
cached_tx = self._tx_and_merkle_cache.get(tx_hash) cached_tx = self._tx_and_merkle_cache.get(tx_hash)
if cached_tx: if cached_tx:
tx, merkle = cached_tx tx, merkle = cached_tx
tx_infos[tx_hash] = None if not tx else tx.hex(), merkle
else: else:
tx_hash_bytes = bytes.fromhex(tx_hash)[::-1] tx_hash_bytes = bytes.fromhex(tx_hash)[::-1]
tx_num = self.prefix_db.tx_num.get(tx_hash_bytes) if self._cache_all_tx_hashes and tx_hash_bytes in self.tx_num_mapping:
tx = None needed_confirmed.append((tx_hash_bytes, self.tx_num_mapping[tx_hash_bytes]))
tx_height = -1 else:
tx_num = None if not tx_num else tx_num.tx_num needed.append(tx_hash_bytes)
if needed:
for tx_hash_bytes, v in zip(needed, await run_in_executor(
self._executor, self.prefix_db.tx_num.multi_get, [(tx_hash,) for tx_hash in needed],
True, True)):
tx_num = None if v is None else v.tx_num
if tx_num is not None: if tx_num is not None:
if self._cache_all_claim_txos: needed_confirmed.append((tx_hash_bytes, tx_num))
fill_cache = tx_num in self.txo_to_claim and len(self.txo_to_claim[tx_num]) > 0
else: else:
fill_cache = True needed_mempool.append(tx_hash_bytes)
await asyncio.sleep(0)
if needed_confirmed:
needed_heights = set()
tx_heights_and_positions = defaultdict(list)
for (tx_hash_bytes, tx_num), tx in zip(needed_confirmed, await run_in_executor(
self._executor, self.prefix_db.tx.multi_get, [(tx_hash,) for tx_hash, _ in needed_confirmed],
True, False)):
tx_height = bisect_right(self.tx_counts, tx_num) tx_height = bisect_right(self.tx_counts, tx_num)
tx = self.prefix_db.tx.get(tx_hash_bytes, fill_cache=fill_cache, deserialize_value=False) needed_heights.add(tx_height)
if tx_height == -1:
merkle = {
'block_height': -1
}
tx = self.prefix_db.mempool_tx.get(tx_hash_bytes, deserialize_value=False)
else:
tx_pos = tx_num - self.tx_counts[tx_height - 1] tx_pos = tx_num - self.tx_counts[tx_height - 1]
branch, root = self.merkle.branch_and_root( tx_heights_and_positions[tx_height].append((tx_hash_bytes, tx, tx_num, tx_pos))
self.get_block_txs(tx_height), tx_pos
sorted_heights = list(sorted(needed_heights))
block_txs = await run_in_executor(
self._executor, self.prefix_db.block_txs.multi_get, [(height,) for height in sorted_heights]
) )
block_txs = {height: v.tx_hashes for height, v in zip(sorted_heights, block_txs)}
for tx_height, v in tx_heights_and_positions.items():
branches, root = self.merkle.branches_and_root(
block_txs[tx_height], [tx_pos for (tx_hash_bytes, tx, tx_num, tx_pos) in v]
)
for (tx_hash_bytes, tx, tx_num, tx_pos) in v:
merkle = { merkle = {
'block_height': tx_height, 'block_height': tx_height,
'merkle': [ 'merkle': [
hash_to_hex_str(_hash) hash_to_hex_str(_hash)
for _hash in branch for _hash in branches[tx_pos]
], ],
'pos': tx_pos 'pos': tx_pos
} }
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), merkle
if tx_height > 0 and tx_height + 10 < self.db_height: if tx_height > 0 and tx_height + 10 < self.db_height:
self._tx_and_merkle_cache[tx_hash] = tx, merkle self._tx_and_merkle_cache[tx_hash_bytes[::-1].hex()] = tx, merkle
return None if not tx else tx.hex(), merkle await asyncio.sleep(0)
if needed_mempool:
for tx_hash_bytes, tx in zip(needed_mempool, await run_in_executor(
self._executor, self.prefix_db.mempool_tx.multi_get, [(tx_hash,) for tx_hash in needed_mempool],
True, False)):
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {'block_height': -1}
await asyncio.sleep(0)
return tx_infos
async def fs_block_hashes(self, height, count): async def fs_block_hashes(self, height, count):
if height + count > len(self.headers): if height + count > len(self.headers):

View file

@ -90,12 +90,9 @@ class PrefixRow(metaclass=PrefixRowType):
def multi_get(self, key_args: typing.List[typing.Tuple], fill_cache=True, deserialize_value=True): def multi_get(self, key_args: typing.List[typing.Tuple], fill_cache=True, deserialize_value=True):
packed_keys = {tuple(args): self.pack_key(*args) for args in key_args} packed_keys = {tuple(args): self.pack_key(*args) for args in key_args}
result = { db_result = self._db.multi_get([(self._column_family, packed_keys[tuple(args)]) for args in key_args],
k[-1]: v for k, v in ( fill_cache=fill_cache)
self._db.multi_get([(self._column_family, packed_keys[tuple(args)]) for args in key_args], result = {k[-1]: v for k, v in (db_result or {}).items()}
fill_cache=fill_cache) or {}
).items()
}
def handle_value(v): def handle_value(v):
return None if v is None else v if not deserialize_value else self.unpack_value(v) return None if v is None else v if not deserialize_value else self.unpack_value(v)

View file

@ -25,7 +25,7 @@
# and warranty status of this software. # and warranty status of this software.
"""Merkle trees, branches, proofs and roots.""" """Merkle trees, branches, proofs and roots."""
import typing
from asyncio import Event from asyncio import Event
from math import ceil, log from math import ceil, log
@ -87,6 +87,26 @@ class Merkle:
return branch, hashes[0] return branch, hashes[0]
@staticmethod
def branches_and_root(block_tx_hashes: typing.List[bytes], tx_positions: typing.List[int]):
block_tx_hashes = list(block_tx_hashes)
positions = list(tx_positions)
length = ceil(log(len(block_tx_hashes), 2))
branches = [[] for _ in range(len(tx_positions))]
for _ in range(length):
if len(block_tx_hashes) & 1:
h = block_tx_hashes[-1]
block_tx_hashes.append(h)
for idx, tx_position in enumerate(tx_positions):
h = block_tx_hashes[tx_position ^ 1]
branches[idx].append(h)
tx_positions[idx] >>= 1
block_tx_hashes = [
double_sha256(block_tx_hashes[n] + block_tx_hashes[n + 1]) for n in
range(0, len(block_tx_hashes), 2)
]
return {tx_position: branch for tx_position, branch in zip(positions, branches)}, block_tx_hashes[0]
@staticmethod @staticmethod
def root(hashes, length=None): def root(hashes, length=None):
"""Return the merkle root of a non-empty iterable of binary hashes.""" """Return the merkle root of a non-empty iterable of binary hashes."""

View file

@ -231,7 +231,7 @@ class ElasticSyncService(BlockchainReaderService):
self._advanced = True self._advanced = True
def unwind(self): def unwind(self):
reverted_block_hash = self.db.coin.header_hash(self.db.headers[-1]) reverted_block_hash = self.db.block_hashes[-1]
super().unwind() super().unwind()
packed = self.db.prefix_db.undo.get(len(self.db.tx_counts), reverted_block_hash) packed = self.db.prefix_db.undo.get(len(self.db.tx_counts), reverted_block_hash)
touched_or_deleted = None touched_or_deleted = None

View file

@ -157,7 +157,9 @@ class BlockchainReaderService(BlockchainService):
self.db.total_transactions.append(tx_hash) self.db.total_transactions.append(tx_hash)
self.db.tx_num_mapping[tx_hash] = tx_count self.db.tx_num_mapping[tx_hash] = tx_count
assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}" assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}"
self.db.headers.append(self.db.prefix_db.header.get(height, deserialize_value=False)) header = self.db.prefix_db.header.get(height, deserialize_value=False)
self.db.headers.append(header)
self.db.block_hashes.append(self.env.coin.header_hash(header))
def unwind(self): def unwind(self):
""" """
@ -166,6 +168,7 @@ class BlockchainReaderService(BlockchainService):
prev_count = self.db.tx_counts.pop() prev_count = self.db.tx_counts.pop()
tx_count = self.db.tx_counts[-1] tx_count = self.db.tx_counts[-1]
self.db.headers.pop() self.db.headers.pop()
self.db.block_hashes.pop()
if self.db._cache_all_tx_hashes: if self.db._cache_all_tx_hashes:
for _ in range(prev_count - tx_count): for _ in range(prev_count - tx_count):
self.db.tx_num_mapping.pop(self.db.total_transactions.pop()) self.db.tx_num_mapping.pop(self.db.total_transactions.pop())