Improve merkle proof and tx cache performance #29
3 changed files with 95 additions and 41 deletions
|
@ -12,13 +12,14 @@ from functools import partial
|
|||
from bisect import bisect_right
|
||||
from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from scribe import PROMETHEUS_NAMESPACE
|
||||
from scribe.error import ResolveCensoredError
|
||||
from scribe.schema.url import URL, normalize_name
|
||||
from scribe.schema.claim import guess_stream_type
|
||||
from scribe.schema.result import Censor
|
||||
from scribe.blockchain.transaction import TxInput
|
||||
from scribe.common import hash_to_hex_str, hash160, LRUCacheWithMetrics
|
||||
from scribe.db.merkle import Merkle, MerkleCache
|
||||
from scribe.db.merkle import Merkle, MerkleCache, FastMerkleCacheItem
|
||||
from scribe.db.common import ResolveResult, STREAM_TYPES, CLAIM_TYPES, ExpandedResolveResult, DBError, UTXO
|
||||
from scribe.db.prefixes import PendingActivationValue, ClaimTakeoverValue, ClaimToTXOValue, PrefixDB
|
||||
from scribe.db.prefixes import ACTIVATED_CLAIM_TXO_TYPE, ACTIVATED_SUPPORT_TXO_TYPE, EffectiveAmountKey
|
||||
|
@ -29,6 +30,7 @@ from scribe.db.prefixes import HashXMempoolStatusPrefixRow
|
|||
TXO_STRUCT = struct.Struct(b'>LH')
|
||||
TXO_STRUCT_unpack = TXO_STRUCT.unpack
|
||||
TXO_STRUCT_pack = TXO_STRUCT.pack
|
||||
NAMESPACE = f"{PROMETHEUS_NAMESPACE}_db"
|
||||
|
||||
|
||||
class HubDB:
|
||||
|
@ -79,14 +81,17 @@ class HubDB:
|
|||
self.tx_counts = None
|
||||
self.headers = None
|
||||
self.block_hashes = None
|
||||
self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace='wallet_server')
|
||||
self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace=NAMESPACE)
|
||||
self.last_flush = time.time()
|
||||
|
||||
# Header merkle cache
|
||||
self.merkle = Merkle()
|
||||
self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)
|
||||
|
||||
self._tx_and_merkle_cache = LRUCacheWithMetrics(2 ** 16, metric_name='tx_and_merkle', namespace="wallet_server")
|
||||
# lru cache of tx_hash: (tx_bytes, tx_num, position, tx_height)
|
||||
self.tx_cache = LRUCacheWithMetrics(2 ** 14, metric_name='tx', namespace=NAMESPACE)
|
||||
# lru cache of block heights to merkle trees of the block tx hashes
|
||||
self.merkle_cache = LRUCacheWithMetrics(2 ** 13, metric_name='merkle', namespace=NAMESPACE)
|
||||
|
||||
# these are only used if the cache_all_tx_hashes setting is on
|
||||
self.total_transactions: List[bytes] = []
|
||||
|
@ -95,7 +100,6 @@ class HubDB:
|
|||
# these are only used if the cache_all_claim_txos setting is on
|
||||
self.claim_to_txo: Dict[bytes, ClaimToTXOValue] = {}
|
||||
self.txo_to_claim: DefaultDict[int, Dict[int, bytes]] = defaultdict(dict)
|
||||
|
||||
self.genesis_bytes = bytes.fromhex(self.coin.GENESIS_HASH)
|
||||
|
||||
def get_claim_from_txo(self, tx_num: int, tx_idx: int) -> Optional[TXOToClaimValue]:
|
||||
|
@ -988,28 +992,36 @@ class HubDB:
|
|||
def get_block_txs(self, height: int) -> List[bytes]:
|
||||
return self.prefix_db.block_txs.get(height).tx_hashes
|
||||
|
||||
async def get_transactions_and_merkles(self, tx_hashes: Iterable[str]):
|
||||
async def get_transactions_and_merkles(self, txids: List[str]):
|
||||
tx_infos = {}
|
||||
needed = []
|
||||
needed_tx_nums = []
|
||||
needed_confirmed = []
|
||||
needed_mempool = []
|
||||
cached_mempool = []
|
||||
needed_heights = set()
|
||||
tx_heights_and_positions = defaultdict(list)
|
||||
|
||||
run_in_executor = asyncio.get_event_loop().run_in_executor
|
||||
|
||||
for tx_hash in tx_hashes:
|
||||
cached_tx = self._tx_and_merkle_cache.get(tx_hash)
|
||||
for txid in txids:
|
||||
tx_hash_bytes = bytes.fromhex(txid)[::-1]
|
||||
cached_tx = self.tx_cache.get(tx_hash_bytes)
|
||||
if cached_tx:
|
||||
tx, merkle = cached_tx
|
||||
tx_infos[tx_hash] = None if not tx else tx.hex(), merkle
|
||||
tx, tx_num, tx_pos, tx_height = cached_tx
|
||||
if tx_height > 0:
|
||||
needed_heights.add(tx_height)
|
||||
tx_heights_and_positions[tx_height].append((tx_hash_bytes, tx, tx_num, tx_pos))
|
||||
else:
|
||||
cached_mempool.append((tx_hash_bytes, tx))
|
||||
else:
|
||||
tx_hash_bytes = bytes.fromhex(tx_hash)[::-1]
|
||||
if self._cache_all_tx_hashes and tx_hash_bytes in self.tx_num_mapping:
|
||||
needed_confirmed.append((tx_hash_bytes, self.tx_num_mapping[tx_hash_bytes]))
|
||||
else:
|
||||
needed.append(tx_hash_bytes)
|
||||
needed_tx_nums.append(tx_hash_bytes)
|
||||
|
||||
if needed:
|
||||
for tx_hash_bytes, v in zip(needed, await run_in_executor(
|
||||
self._executor, self.prefix_db.tx_num.multi_get, [(tx_hash,) for tx_hash in needed],
|
||||
if needed_tx_nums:
|
||||
for tx_hash_bytes, v in zip(needed_tx_nums, await run_in_executor(
|
||||
self._executor, self.prefix_db.tx_num.multi_get, [(tx_hash,) for tx_hash in needed_tx_nums],
|
||||
True, True)):
|
||||
tx_num = None if v is None else v.tx_num
|
||||
if tx_num is not None:
|
||||
|
@ -1019,8 +1031,6 @@ class HubDB:
|
|||
await asyncio.sleep(0)
|
||||
|
||||
if needed_confirmed:
|
||||
needed_heights = set()
|
||||
tx_heights_and_positions = defaultdict(list)
|
||||
for (tx_hash_bytes, tx_num), tx in zip(needed_confirmed, await run_in_executor(
|
||||
self._executor, self.prefix_db.tx.multi_get, [(tx_hash,) for tx_hash, _ in needed_confirmed],
|
||||
True, False)):
|
||||
|
@ -1028,36 +1038,43 @@ class HubDB:
|
|||
needed_heights.add(tx_height)
|
||||
tx_pos = tx_num - self.tx_counts[tx_height - 1]
|
||||
tx_heights_and_positions[tx_height].append((tx_hash_bytes, tx, tx_num, tx_pos))
|
||||
self.tx_cache[tx_hash_bytes] = tx, tx_num, tx_pos, tx_height
|
||||
|
||||
sorted_heights = list(sorted(needed_heights))
|
||||
merkles: Dict[int, FastMerkleCacheItem] = {} # uses existing cached merkle trees when they're available
|
||||
needed_for_merkle_cache = []
|
||||
for height in sorted_heights:
|
||||
merkle = self.merkle_cache.get(height)
|
||||
if merkle:
|
||||
merkles[height] = merkle
|
||||
else:
|
||||
needed_for_merkle_cache.append(height)
|
||||
if needed_for_merkle_cache:
|
||||
block_txs = await run_in_executor(
|
||||
self._executor, self.prefix_db.block_txs.multi_get, [(height,) for height in sorted_heights]
|
||||
self._executor, self.prefix_db.block_txs.multi_get,
|
||||
[(height,) for height in needed_for_merkle_cache]
|
||||
)
|
||||
block_txs = {height: v.tx_hashes for height, v in zip(sorted_heights, block_txs)}
|
||||
for height, v in zip(needed_for_merkle_cache, block_txs):
|
||||
merkles[height] = self.merkle_cache[height] = FastMerkleCacheItem(v.tx_hashes)
|
||||
await asyncio.sleep(0)
|
||||
for tx_height, v in tx_heights_and_positions.items():
|
||||
branches, root = self.merkle.branches_and_root(
|
||||
block_txs[tx_height], [tx_pos for (tx_hash_bytes, tx, tx_num, tx_pos) in v]
|
||||
)
|
||||
get_merkle_branch = merkles[tx_height].branch
|
||||
for (tx_hash_bytes, tx, tx_num, tx_pos) in v:
|
||||
merkle = {
|
||||
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {
|
||||
'block_height': tx_height,
|
||||
'merkle': [
|
||||
hash_to_hex_str(_hash)
|
||||
for _hash in branches[tx_pos]
|
||||
],
|
||||
'merkle': get_merkle_branch(tx_pos),
|
||||
'pos': tx_pos
|
||||
}
|
||||
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), merkle
|
||||
if tx_height > 0 and tx_height + 10 < self.db_height:
|
||||
self._tx_and_merkle_cache[tx_hash_bytes[::-1].hex()] = tx, merkle
|
||||
await asyncio.sleep(0)
|
||||
for tx_hash_bytes, tx in cached_mempool:
|
||||
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {'block_height': -1}
|
||||
if needed_mempool:
|
||||
for tx_hash_bytes, tx in zip(needed_mempool, await run_in_executor(
|
||||
self._executor, self.prefix_db.mempool_tx.multi_get, [(tx_hash,) for tx_hash in needed_mempool],
|
||||
True, False)):
|
||||
self.tx_cache[tx_hash_bytes] = tx, None, None, -1
|
||||
tx_infos[tx_hash_bytes[::-1].hex()] = None if not tx else tx.hex(), {'block_height': -1}
|
||||
await asyncio.sleep(0)
|
||||
return tx_infos
|
||||
return {txid: tx_infos.get(txid) for txid in txids} # match ordering of the txs in the request
|
||||
|
||||
async def fs_block_hashes(self, height, count):
|
||||
if height + count > len(self.headers):
|
||||
|
|
|
@ -276,3 +276,29 @@ class MerkleCache:
|
|||
level = await self._level_for(length)
|
||||
return self.merkle.branch_and_root_from_level(
|
||||
level, leaf_hashes, index, self.depth_higher)
|
||||
|
||||
|
||||
class FastMerkleCacheItem:
|
||||
__slots__ = ['tree', 'root_hash']
|
||||
|
||||
def __init__(self, tx_hashes: typing.List[bytes]):
|
||||
self.tree: typing.List[typing.List[bytes]] = []
|
||||
self.root_hash = self._walk_merkle(tx_hashes, self.tree.append)
|
||||
|
||||
@staticmethod
|
||||
def _walk_merkle(items: typing.List[bytes], append_layer) -> bytes:
|
||||
if len(items) == 1:
|
||||
return items[0]
|
||||
append_layer(items)
|
||||
layer = [
|
||||
double_sha256(items[index] + items[index])
|
||||
if index + 1 == len(items) else double_sha256(items[index] + items[index + 1])
|
||||
for index in range(0, len(items), 2)
|
||||
]
|
||||
return FastMerkleCacheItem._walk_merkle(layer, append_layer)
|
||||
|
||||
def branch(self, tx_position: int) -> typing.List[str]:
|
||||
return [
|
||||
(layer[-1] if (tx_position >> shift) ^ 1 == len(layer) else layer[(tx_position >> shift) ^ 1])[::-1].hex()
|
||||
for shift, layer in enumerate(self.tree)
|
||||
]
|
||||
|
|
|
@ -151,12 +151,20 @@ class BlockchainReaderService(BlockchainService):
|
|||
assert len(self.db.tx_counts) == height, f"{len(self.db.tx_counts)} != {height}"
|
||||
prev_count = self.db.tx_counts[-1]
|
||||
self.db.tx_counts.append(tx_count)
|
||||
|
||||
# precache all of the txs from this block
|
||||
block_tx_hashes = self.db.prefix_db.block_txs.get(height).tx_hashes
|
||||
block_txs = self.db.prefix_db.tx.multi_get(
|
||||
[(tx_hash,) for tx_hash in block_tx_hashes], deserialize_value=False
|
||||
)
|
||||
for tx_pos, (tx_num, tx_hash, tx) in enumerate(zip(range(prev_count, tx_count), block_tx_hashes, block_txs)):
|
||||
self.db.tx_cache[tx_hash] = tx, tx_num, tx_pos, height
|
||||
if self.db._cache_all_tx_hashes:
|
||||
for tx_num in range(prev_count, tx_count):
|
||||
tx_hash = self.db.prefix_db.tx_hash.get(tx_num).tx_hash
|
||||
self.db.total_transactions.append(tx_hash)
|
||||
self.db.tx_num_mapping[tx_hash] = tx_count
|
||||
if self.db._cache_all_tx_hashes:
|
||||
assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}"
|
||||
|
||||
header = self.db.prefix_db.header.get(height, deserialize_value=False)
|
||||
self.db.headers.append(header)
|
||||
self.db.block_hashes.append(self.env.coin.header_hash(header))
|
||||
|
@ -171,8 +179,11 @@ class BlockchainReaderService(BlockchainService):
|
|||
self.db.block_hashes.pop()
|
||||
if self.db._cache_all_tx_hashes:
|
||||
for _ in range(prev_count - tx_count):
|
||||
self.db.tx_num_mapping.pop(self.db.total_transactions.pop())
|
||||
tx_hash = self.db.tx_num_mapping.pop(self.db.total_transactions.pop())
|
||||
if tx_hash in self.db.tx_cache:
|
||||
self.db.tx_cache.pop(tx_hash)
|
||||
assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}"
|
||||
self.db.merkle_cache.clear()
|
||||
|
||||
def _detect_changes(self):
|
||||
try:
|
||||
|
|
Loading…
Reference in a new issue