Improve merkle proof and tx cache performance #29

Merged
jackrobison merged 4 commits from improve-tx-caches into master 2022-05-01 01:23:52 +02:00
3 changed files with 95 additions and 41 deletions

View file

@ -12,13 +12,14 @@ from functools import partial
from bisect import bisect_right from bisect import bisect_right
from collections import defaultdict from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from scribe import PROMETHEUS_NAMESPACE
from scribe.error import ResolveCensoredError from scribe.error import ResolveCensoredError
from scribe.schema.url import URL, normalize_name from scribe.schema.url import URL, normalize_name
from scribe.schema.claim import guess_stream_type from scribe.schema.claim import guess_stream_type
from scribe.schema.result import Censor from scribe.schema.result import Censor
from scribe.blockchain.transaction import TxInput from scribe.blockchain.transaction import TxInput
from scribe.common import hash_to_hex_str, hash160, LRUCacheWithMetrics from scribe.common import hash_to_hex_str, hash160, LRUCacheWithMetrics
from scribe.db.merkle import Merkle, MerkleCache from scribe.db.merkle import Merkle, MerkleCache, FastMerkleCacheItem
from scribe.db.common import ResolveResult, STREAM_TYPES, CLAIM_TYPES, ExpandedResolveResult, DBError, UTXO from scribe.db.common import ResolveResult, STREAM_TYPES, CLAIM_TYPES, ExpandedResolveResult, DBError, UTXO
from scribe.db.prefixes import PendingActivationValue, ClaimTakeoverValue, ClaimToTXOValue, PrefixDB from scribe.db.prefixes import PendingActivationValue, ClaimTakeoverValue, ClaimToTXOValue, PrefixDB
from scribe.db.prefixes import ACTIVATED_CLAIM_TXO_TYPE, ACTIVATED_SUPPORT_TXO_TYPE, EffectiveAmountKey from scribe.db.prefixes import ACTIVATED_CLAIM_TXO_TYPE, ACTIVATED_SUPPORT_TXO_TYPE, EffectiveAmountKey
@ -29,6 +30,7 @@ from scribe.db.prefixes import HashXMempoolStatusPrefixRow
TXO_STRUCT = struct.Struct(b'>LH') TXO_STRUCT = struct.Struct(b'>LH')
TXO_STRUCT_unpack = TXO_STRUCT.unpack TXO_STRUCT_unpack = TXO_STRUCT.unpack
TXO_STRUCT_pack = TXO_STRUCT.pack TXO_STRUCT_pack = TXO_STRUCT.pack
NAMESPACE = f"{PROMETHEUS_NAMESPACE}_db"
class HubDB: class HubDB:
@ -79,14 +81,17 @@ class HubDB:
self.tx_counts = None self.tx_counts = None
self.headers = None self.headers = None
self.block_hashes = None self.block_hashes = None
self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace='wallet_server') self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace=NAMESPACE)
self.last_flush = time.time() self.last_flush = time.time()
# Header merkle cache # Header merkle cache
self.merkle = Merkle() self.merkle = Merkle()
self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes) self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)
self._tx_and_merkle_cache = LRUCacheWithMetrics(2 ** 16, metric_name='tx_and_merkle', namespace="wallet_server") # lru cache of tx_hash: (tx_bytes, tx_num, position, tx_height)
self.tx_cache = LRUCacheWithMetrics(2 ** 14, metric_name='tx', namespace=NAMESPACE)
# lru cache of block heights to merkle trees of the block tx hashes
self.merkle_cache = LRUCacheWithMetrics(2 ** 13, metric_name='merkle', namespace=NAMESPACE)
# these are only used if the cache_all_tx_hashes setting is on # these are only used if the cache_all_tx_hashes setting is on
self.total_transactions: List[bytes] = [] self.total_transactions: List[bytes] = []
@ -95,7 +100,6 @@ class HubDB:
# these are only used if the cache_all_claim_txos setting is on # these are only used if the cache_all_claim_txos setting is on
self.claim_to_txo: Dict[bytes, ClaimToTXOValue] = {} self.claim_to_txo: Dict[bytes, ClaimToTXOValue] = {}
self.txo_to_claim: DefaultDict[int, Dict[int, bytes]] = defaultdict(dict) self.txo_to_claim: DefaultDict[int, Dict[int, bytes]] = defaultdict(dict)
self.genesis_bytes = bytes.fromhex(self.coin.GENESIS_HASH) self.genesis_bytes = bytes.fromhex(self.coin.GENESIS_HASH)
def get_claim_from_txo(self, tx_num: int, tx_idx: int) -> Optional[TXOToClaimValue]: def get_claim_from_txo(self, tx_num: int, tx_idx: int) -> Optional[TXOToClaimValue]:
@ -988,28 +992,36 @@ class HubDB:
def get_block_txs(self, height: int) -> List[bytes]: def get_block_txs(self, height: int) -> List[bytes]:
return self.prefix_db.block_txs.get(height).tx_hashes return self.prefix_db.block_txs.get(height).tx_hashes
async def get_transactions_and_merkles(self, tx_hashes: Iterable[str]): async def get_transactions_and_merkles(self, txids: List[str]):
tx_infos = {} tx_infos = {}
needed = [] needed_tx_nums = []
needed_confirmed = [] needed_confirmed = []
needed_mempool = [] needed_mempool = []
cached_mempool = []
needed_heights = set()
tx_heights_and_positions = defaultdict(list)
run_in_executor = asyncio.get_event_loop().run_in_executor run_in_executor = asyncio.get_event_loop().run_in_executor
for tx_hash in tx_hashes: for txid in txids:
cached_tx = self._tx_and_merkle_cache.get(tx_hash) tx_hash_bytes = bytes.fromhex(txid)[::-1]
cached_tx = self.tx_cache.get(tx_hash_bytes)
if cached_tx: if cached_tx:
tx, merkle = cached_tx tx, tx_num, tx_pos, tx_height = cached_tx
tx_infos[tx_hash] = None if not tx else tx.hex(), merkle if tx_height > 0:
needed_heights.add(tx_height)
tx_heights_and_positions[tx_height].append((tx_hash_bytes, tx, tx_num, tx_pos))
else:
cached_mempool.append((tx_hash_bytes, tx))
else: else:
tx_hash_bytes = bytes.fromhex(tx_hash)[::-1]
if self._cache_all_tx_hashes and tx_hash_bytes in self.tx_num_mapping: if self._cache_all_tx_hashes and tx_hash_bytes in self.tx_num_mapping:
needed_confirmed.append((tx_hash_bytes, self.tx_num_mapping[tx_hash_bytes])) needed_confirmed.append((tx_hash_bytes, self.tx_num_mapping[tx_hash_bytes]))
else: else:
needed.append(tx_hash_bytes) needed_tx_nums.append(tx_hash_bytes)
if needed: if needed_tx_nums:
for tx_hash_bytes, v in zip(needed, await run_in_executor( for tx_hash_bytes, v in zip(needed_tx_nums, await run_in_executor(
self._executor, self.prefix_db.tx_num.multi_get, [(tx_hash,) for tx_hash in needed], self._executor, self.prefix_db.tx_num.multi_get, [(tx_hash,) for tx_hash in needed_tx_nums],
True, True)): True, True)):
tx_num = None if v is None else v.tx_num tx_num = None if v is None else v.tx_num
if tx_num is not None: if tx_num is not None:
@ -1019,8 +1031,6 @@ class HubDB:
await asyncio.sleep(0) await asyncio.sleep(0)
if needed_confirmed: if needed_confirmed:
needed_heights = set()
tx_heights_and_positions = defaultdict(list)
for (tx_hash_bytes, tx_num), tx in zip(needed_confirmed, await run_in_executor( for (tx_hash_bytes, tx_num), tx in zip(needed_confirmed, await run_in_executor(
self._executor, self.prefix_db.tx.multi_get, [(tx_hash,) for tx_hash, _ in needed_confirmed], self._executor, self.prefix_db.tx.multi_get, [(tx_hash,) for tx_hash, _ in needed_confirmed],
True, False)): True, False)):
@ -1028,36 +1038,43 @@ class HubDB:
needed_heights.add(tx_height) needed_heights.add(tx_height)
tx_pos = tx_num - self.tx_counts[tx_height - 1] tx_pos = tx_num - self.tx_counts[tx_height - 1]
tx_heights_and_positions[tx_height].append((tx_hash_bytes, tx, tx_num, tx_pos)) tx_heights_and_positions[tx_height].append((tx_hash_bytes, tx, tx_num, tx_pos))
self.tx_cache[tx_hash_bytes] = tx, tx_num, tx_pos, tx_height
sorted_heights = list(sorted(needed_heights)) sorted_heights = list(sorted(needed_heights))
merkles: Dict[int, FastMerkleCacheItem] = {} # uses existing cached merkle trees when they're available
needed_for_merkle_cache = []
for height in sorted_heights:
merkle = self.merkle_cache.get(height)
if merkle:
merkles[height] = merkle
else:
needed_for_merkle_cache.append(height)
if needed_for_merkle_cache:
block_txs = await run_in_executor( block_txs = await run_in_executor(
self._executor, self.prefix_db.block_txs.multi_get, [(height,) for height in sorted_heights] self._executor, self.prefix_db.block_txs.multi_get,
[(height,) for height in needed_for_merkle_cache]
) )
block_txs = {height: v.tx_hashes for height, v in zip(sorted_heights, block_txs)} for height, v in zip(needed_for_merkle_cache, block_txs):
merkles[height] = self.merkle_cache[height] = FastMerkleCacheItem(v.tx_hashes)
await asyncio.sleep(0)
for tx_height, v in tx_heights_and_positions.items(): for tx_height, v in tx_heights_and_positions.items():
branches, root = self.merkle.branches_and_root( get_merkle_branch = merkles[tx_height].branch
block_txs[tx_height], [tx_pos for (tx_hash_bytes, tx, tx_num, tx_pos) in v]
)
for (tx_hash_bytes, tx, tx_num, tx_pos) in v: for (tx_hash_bytes, tx, tx_num, tx_pos) in v:
merkle = { tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {
'block_height': tx_height, 'block_height': tx_height,
'merkle': [ 'merkle': get_merkle_branch(tx_pos),
hash_to_hex_str(_hash)
for _hash in branches[tx_pos]
],
'pos': tx_pos 'pos': tx_pos
} }
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), merkle for tx_hash_bytes, tx in cached_mempool:
if tx_height > 0 and tx_height + 10 < self.db_height: tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {'block_height': -1}
self._tx_and_merkle_cache[tx_hash_bytes[::-1].hex()] = tx, merkle
await asyncio.sleep(0)
if needed_mempool: if needed_mempool:
for tx_hash_bytes, tx in zip(needed_mempool, await run_in_executor( for tx_hash_bytes, tx in zip(needed_mempool, await run_in_executor(
self._executor, self.prefix_db.mempool_tx.multi_get, [(tx_hash,) for tx_hash in needed_mempool], self._executor, self.prefix_db.mempool_tx.multi_get, [(tx_hash,) for tx_hash in needed_mempool],
True, False)): True, False)):
self.tx_cache[tx_hash_bytes] = tx, None, None, -1
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {'block_height': -1} tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {'block_height': -1}
await asyncio.sleep(0) await asyncio.sleep(0)
return tx_infos return {txid: tx_infos.get(txid) for txid in txids} # match ordering of the txs in the request
async def fs_block_hashes(self, height, count): async def fs_block_hashes(self, height, count):
if height + count > len(self.headers): if height + count > len(self.headers):

View file

@ -276,3 +276,29 @@ class MerkleCache:
level = await self._level_for(length) level = await self._level_for(length)
return self.merkle.branch_and_root_from_level( return self.merkle.branch_and_root_from_level(
level, leaf_hashes, index, self.depth_higher) level, leaf_hashes, index, self.depth_higher)
class FastMerkleCacheItem:
__slots__ = ['tree', 'root_hash']
def __init__(self, tx_hashes: typing.List[bytes]):
self.tree: typing.List[typing.List[bytes]] = []
self.root_hash = self._walk_merkle(tx_hashes, self.tree.append)
@staticmethod
def _walk_merkle(items: typing.List[bytes], append_layer) -> bytes:
if len(items) == 1:
return items[0]
append_layer(items)
layer = [
double_sha256(items[index] + items[index])
if index + 1 == len(items) else double_sha256(items[index] + items[index + 1])
for index in range(0, len(items), 2)
]
return FastMerkleCacheItem._walk_merkle(layer, append_layer)
def branch(self, tx_position: int) -> typing.List[str]:
return [
(layer[-1] if (tx_position >> shift) ^ 1 == len(layer) else layer[(tx_position >> shift) ^ 1])[::-1].hex()
for shift, layer in enumerate(self.tree)
]

View file

@ -151,12 +151,20 @@ class BlockchainReaderService(BlockchainService):
assert len(self.db.tx_counts) == height, f"{len(self.db.tx_counts)} != {height}" assert len(self.db.tx_counts) == height, f"{len(self.db.tx_counts)} != {height}"
prev_count = self.db.tx_counts[-1] prev_count = self.db.tx_counts[-1]
self.db.tx_counts.append(tx_count) self.db.tx_counts.append(tx_count)
# precache all of the txs from this block
block_tx_hashes = self.db.prefix_db.block_txs.get(height).tx_hashes
block_txs = self.db.prefix_db.tx.multi_get(
[(tx_hash,) for tx_hash in block_tx_hashes], deserialize_value=False
)
for tx_pos, (tx_num, tx_hash, tx) in enumerate(zip(range(prev_count, tx_count), block_tx_hashes, block_txs)):
self.db.tx_cache[tx_hash] = tx, tx_num, tx_pos, height
if self.db._cache_all_tx_hashes: if self.db._cache_all_tx_hashes:
for tx_num in range(prev_count, tx_count):
tx_hash = self.db.prefix_db.tx_hash.get(tx_num).tx_hash
self.db.total_transactions.append(tx_hash) self.db.total_transactions.append(tx_hash)
self.db.tx_num_mapping[tx_hash] = tx_count self.db.tx_num_mapping[tx_hash] = tx_count
if self.db._cache_all_tx_hashes:
assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}" assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}"
header = self.db.prefix_db.header.get(height, deserialize_value=False) header = self.db.prefix_db.header.get(height, deserialize_value=False)
self.db.headers.append(header) self.db.headers.append(header)
self.db.block_hashes.append(self.env.coin.header_hash(header)) self.db.block_hashes.append(self.env.coin.header_hash(header))
@ -171,8 +179,11 @@ class BlockchainReaderService(BlockchainService):
self.db.block_hashes.pop() self.db.block_hashes.pop()
if self.db._cache_all_tx_hashes: if self.db._cache_all_tx_hashes:
for _ in range(prev_count - tx_count): for _ in range(prev_count - tx_count):
self.db.tx_num_mapping.pop(self.db.total_transactions.pop()) tx_hash = self.db.tx_num_mapping.pop(self.db.total_transactions.pop())
if tx_hash in self.db.tx_cache:
self.db.tx_cache.pop(tx_hash)
assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}" assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}"
self.db.merkle_cache.clear()
def _detect_changes(self): def _detect_changes(self):
try: try: