use multi_get for fetching transactions #25

Merged
jackrobison merged 5 commits from multiget-transactions into master 2022-04-16 17:19:13 +02:00
6 changed files with 112 additions and 53 deletions

View file

@ -181,7 +181,7 @@ class BlockchainProcessorService(BlockchainService):
self.log.warning("failed to get a mempool tx, reorg underway?")
return
if current_mempool:
if bytes.fromhex(await self.daemon.getbestblockhash())[::-1] != self.coin.header_hash(self.db.headers[-1]):
if bytes.fromhex(await self.daemon.getbestblockhash())[::-1] != self.db.block_hashes[-1]:
return
await self.run_in_thread(
update_mempool, self.db.prefix_db.unsafe_commit, self.db.prefix_db.mempool_tx, _to_put, current_mempool
@ -1417,6 +1417,7 @@ class BlockchainProcessorService(BlockchainService):
self.height = height
self.db.headers.append(block.header)
self.db.block_hashes.append(self.env.coin.header_hash(block.header))
self.tip = self.coin.header_hash(block.header)
self.db.fs_height = self.height
@ -1493,8 +1494,9 @@ class BlockchainProcessorService(BlockchainService):
# Check and update self.tip
self.db.tx_counts.pop()
reverted_block_hash = self.coin.header_hash(self.db.headers.pop())
self.tip = self.coin.header_hash(self.db.headers[-1])
self.db.headers.pop()
reverted_block_hash = self.db.block_hashes.pop()
self.tip = self.db.block_hashes[-1]
if self.env.cache_all_tx_hashes:
while len(self.db.total_transactions) > self.db.tx_counts[-1]:
self.db.tx_num_mapping.pop(self.db.total_transactions.pop())

View file

@ -78,6 +78,7 @@ class HubDB:
self.tx_counts = None
self.headers = None
self.block_hashes = None
self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace='wallet_server')
self.last_flush = time.time()
@ -775,6 +776,18 @@ class HubDB:
assert len(headers) - 1 == self.db_height, f"{len(headers)} vs {self.db_height}"
self.headers = headers
async def _read_block_hashes(self):
def get_block_hashes():
return [
block_hash for block_hash in self.prefix_db.block_hash.iterate(
start=(0, ), stop=(self.db_height + 1, ), include_key=False, fill_cache=False, deserialize_value=False
)
]
block_hashes = await asyncio.get_event_loop().run_in_executor(self._executor, get_block_hashes)
assert len(block_hashes) == len(self.headers)
self.block_hashes = block_hashes
async def _read_tx_hashes(self):
def _read_tx_hashes():
return list(self.prefix_db.tx_hash.iterate(start=(0,), stop=(self.db_tx_count + 1,), include_key=False, fill_cache=False, deserialize_value=False))
@ -839,6 +852,7 @@ class HubDB:
async def initialize_caches(self):
await self._read_tx_counts()
await self._read_headers()
await self._read_block_hashes()
if self._cache_all_claim_txos:
await self._read_claim_txos()
if self._cache_all_tx_hashes:
@ -976,51 +990,74 @@ class HubDB:
async def get_transactions_and_merkles(self, tx_hashes: Iterable[str]):
tx_infos = {}
for tx_hash in tx_hashes:
tx_infos[tx_hash] = await asyncio.get_event_loop().run_in_executor(
self._executor, self._get_transaction_and_merkle, tx_hash
)
await asyncio.sleep(0)
return tx_infos
needed = []
needed_confirmed = []
needed_mempool = []
run_in_executor = asyncio.get_event_loop().run_in_executor
def _get_transaction_and_merkle(self, tx_hash):
for tx_hash in tx_hashes:
cached_tx = self._tx_and_merkle_cache.get(tx_hash)
if cached_tx:
tx, merkle = cached_tx
tx_infos[tx_hash] = None if not tx else tx.hex(), merkle
else:
tx_hash_bytes = bytes.fromhex(tx_hash)[::-1]
tx_num = self.prefix_db.tx_num.get(tx_hash_bytes)
tx = None
tx_height = -1
tx_num = None if not tx_num else tx_num.tx_num
if self._cache_all_tx_hashes and tx_hash_bytes in self.tx_num_mapping:
needed_confirmed.append((tx_hash_bytes, self.tx_num_mapping[tx_hash_bytes]))
else:
needed.append(tx_hash_bytes)
if needed:
for tx_hash_bytes, v in zip(needed, await run_in_executor(
self._executor, self.prefix_db.tx_num.multi_get, [(tx_hash,) for tx_hash in needed],
True, True)):
tx_num = None if v is None else v.tx_num
if tx_num is not None:
if self._cache_all_claim_txos:
fill_cache = tx_num in self.txo_to_claim and len(self.txo_to_claim[tx_num]) > 0
needed_confirmed.append((tx_hash_bytes, tx_num))
else:
fill_cache = True
needed_mempool.append(tx_hash_bytes)
await asyncio.sleep(0)
if needed_confirmed:
needed_heights = set()
tx_heights_and_positions = defaultdict(list)
for (tx_hash_bytes, tx_num), tx in zip(needed_confirmed, await run_in_executor(
self._executor, self.prefix_db.tx.multi_get, [(tx_hash,) for tx_hash, _ in needed_confirmed],
True, False)):
tx_height = bisect_right(self.tx_counts, tx_num)
tx = self.prefix_db.tx.get(tx_hash_bytes, fill_cache=fill_cache, deserialize_value=False)
if tx_height == -1:
merkle = {
'block_height': -1
}
tx = self.prefix_db.mempool_tx.get(tx_hash_bytes, deserialize_value=False)
else:
needed_heights.add(tx_height)
tx_pos = tx_num - self.tx_counts[tx_height - 1]
branch, root = self.merkle.branch_and_root(
self.get_block_txs(tx_height), tx_pos
tx_heights_and_positions[tx_height].append((tx_hash_bytes, tx, tx_num, tx_pos))
sorted_heights = list(sorted(needed_heights))
block_txs = await run_in_executor(
self._executor, self.prefix_db.block_txs.multi_get, [(height,) for height in sorted_heights]
)
block_txs = {height: v.tx_hashes for height, v in zip(sorted_heights, block_txs)}
for tx_height, v in tx_heights_and_positions.items():
branches, root = self.merkle.branches_and_root(
block_txs[tx_height], [tx_pos for (tx_hash_bytes, tx, tx_num, tx_pos) in v]
)
for (tx_hash_bytes, tx, tx_num, tx_pos) in v:
merkle = {
'block_height': tx_height,
'merkle': [
hash_to_hex_str(_hash)
for _hash in branch
for _hash in branches[tx_pos]
],
'pos': tx_pos
}
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), merkle
if tx_height > 0 and tx_height + 10 < self.db_height:
self._tx_and_merkle_cache[tx_hash] = tx, merkle
return None if not tx else tx.hex(), merkle
self._tx_and_merkle_cache[tx_hash_bytes[::-1].hex()] = tx, merkle
await asyncio.sleep(0)
if needed_mempool:
for tx_hash_bytes, tx in zip(needed_mempool, await run_in_executor(
self._executor, self.prefix_db.mempool_tx.multi_get, [(tx_hash,) for tx_hash in needed_mempool],
True, False)):
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {'block_height': -1}
await asyncio.sleep(0)
return tx_infos
async def fs_block_hashes(self, height, count):
if height + count > len(self.headers):

View file

@ -90,12 +90,9 @@ class PrefixRow(metaclass=PrefixRowType):
def multi_get(self, key_args: typing.List[typing.Tuple], fill_cache=True, deserialize_value=True):
packed_keys = {tuple(args): self.pack_key(*args) for args in key_args}
result = {
k[-1]: v for k, v in (
self._db.multi_get([(self._column_family, packed_keys[tuple(args)]) for args in key_args],
fill_cache=fill_cache) or {}
).items()
}
db_result = self._db.multi_get([(self._column_family, packed_keys[tuple(args)]) for args in key_args],
fill_cache=fill_cache)
result = {k[-1]: v for k, v in (db_result or {}).items()}
def handle_value(v):
return None if v is None else v if not deserialize_value else self.unpack_value(v)

View file

@ -25,7 +25,7 @@
# and warranty status of this software.
"""Merkle trees, branches, proofs and roots."""
import typing
from asyncio import Event
from math import ceil, log
@ -87,6 +87,26 @@ class Merkle:
return branch, hashes[0]
@staticmethod
def branches_and_root(block_tx_hashes: typing.List[bytes], tx_positions: typing.List[int]):
block_tx_hashes = list(block_tx_hashes)
positions = list(tx_positions)
length = ceil(log(len(block_tx_hashes), 2))
branches = [[] for _ in range(len(tx_positions))]
for _ in range(length):
if len(block_tx_hashes) & 1:
h = block_tx_hashes[-1]
block_tx_hashes.append(h)
for idx, tx_position in enumerate(tx_positions):
h = block_tx_hashes[tx_position ^ 1]
branches[idx].append(h)
tx_positions[idx] >>= 1
block_tx_hashes = [
double_sha256(block_tx_hashes[n] + block_tx_hashes[n + 1]) for n in
range(0, len(block_tx_hashes), 2)
]
return {tx_position: branch for tx_position, branch in zip(positions, branches)}, block_tx_hashes[0]
@staticmethod
def root(hashes, length=None):
"""Return the merkle root of a non-empty iterable of binary hashes."""

View file

@ -231,7 +231,7 @@ class ElasticSyncService(BlockchainReaderService):
self._advanced = True
def unwind(self):
reverted_block_hash = self.db.coin.header_hash(self.db.headers[-1])
reverted_block_hash = self.db.block_hashes[-1]
super().unwind()
packed = self.db.prefix_db.undo.get(len(self.db.tx_counts), reverted_block_hash)
touched_or_deleted = None

View file

@ -157,7 +157,9 @@ class BlockchainReaderService(BlockchainService):
self.db.total_transactions.append(tx_hash)
self.db.tx_num_mapping[tx_hash] = tx_count
assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}"
self.db.headers.append(self.db.prefix_db.header.get(height, deserialize_value=False))
header = self.db.prefix_db.header.get(height, deserialize_value=False)
self.db.headers.append(header)
self.db.block_hashes.append(self.env.coin.header_hash(header))
def unwind(self):
"""
@ -166,6 +168,7 @@ class BlockchainReaderService(BlockchainService):
prev_count = self.db.tx_counts.pop()
tx_count = self.db.tx_counts[-1]
self.db.headers.pop()
self.db.block_hashes.pop()
if self.db._cache_all_tx_hashes:
for _ in range(prev_count - tx_count):
self.db.tx_num_mapping.pop(self.db.total_transactions.pop())