improve caching for blockchain.address.get_history

Jack Robison 2022-05-19 12:53:49 -04:00
parent 9a6f2a6d96
commit e5713dc63c
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 79 additions and 22 deletions

@@ -959,11 +959,25 @@ class HubDB:
             return self.total_transactions[tx_num]
         return self.prefix_db.tx_hash.get(tx_num, deserialize_value=False)
 
-    def get_tx_hashes(self, tx_nums: List[int]) -> List[Optional[bytes]]:
+    def _get_tx_hashes(self, tx_nums: List[int]) -> List[Optional[bytes]]:
         if self._cache_all_tx_hashes:
             return [None if tx_num > self.db_tx_count else self.total_transactions[tx_num] for tx_num in tx_nums]
         return self.prefix_db.tx_hash.multi_get([(tx_num,) for tx_num in tx_nums], deserialize_value=False)
 
+    async def get_tx_hashes(self, tx_nums: List[int]) -> List[Optional[bytes]]:
+        if self._cache_all_tx_hashes:
+            result = []
+            append_result = result.append
+            for tx_num in tx_nums:
+                append_result(None if tx_num > self.db_tx_count else self.total_transactions[tx_num])
+                await asyncio.sleep(0)
+            return result
+
+        def _get_tx_hashes():
+            return self.prefix_db.tx_hash.multi_get([(tx_num,) for tx_num in tx_nums], deserialize_value=False)
+
+        return await asyncio.get_event_loop().run_in_executor(self._executor, _get_tx_hashes)
+
     def get_raw_mempool_tx(self, tx_hash: bytes) -> Optional[bytes]:
         return self.prefix_db.mempool_tx.get(tx_hash, deserialize_value=False)
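The new coroutine shows the pattern this commit applies throughout: serve in-memory cache hits inline while yielding to the event loop each iteration, and push the blocking RocksDB `multi_get` onto the executor so a large batch cannot stall other sessions. A minimal self-contained sketch of that pattern (the module-level cache, executor, and `slow_multi_get` are hypothetical stand-ins, not hub APIs):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

_cache: Dict[int, bytes] = {}          # stand-in for total_transactions
_executor = ThreadPoolExecutor(1)      # stand-in for HubDB._executor

def slow_multi_get(keys: List[int]) -> List[Optional[bytes]]:
    # stand-in for prefix_db.tx_hash.multi_get, which blocks on RocksDB
    return [_cache.get(k) for k in keys]

async def get_many(keys: List[int], use_cache: bool) -> List[Optional[bytes]]:
    if use_cache:
        result = []
        for k in keys:
            result.append(_cache.get(k))
            await asyncio.sleep(0)     # yield so other coroutines can run
        return result
    # blocking path: run the batched read off the event loop
    return await asyncio.get_event_loop().run_in_executor(_executor, slow_multi_get, keys)
```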
@@ -1159,7 +1173,7 @@ class HubDB:
             raise DBError(f'only got {len(self.headers) - height:,d} headers starting at {height:,d}, not {count:,d}')
         return [self.coin.header_hash(header) for header in self.headers[height:height + count]]
 
-    def read_history(self, hashX: bytes, limit: int = 1000) -> List[int]:
+    def _read_history(self, hashX: bytes, limit: int = 1000) -> List[int]:
         txs = []
         txs_extend = txs.extend
         for hist in self.prefix_db.hashX_history.iterate(prefix=(hashX,), include_key=False):
@@ -1168,6 +1182,9 @@ class HubDB:
                 break
         return txs
 
+    async def read_history(self, hashX: bytes, limit: int = 1000) -> List[int]:
+        return await asyncio.get_event_loop().run_in_executor(self._executor, self._read_history, hashX, limit)
+
     async def limited_history(self, hashX, *, limit=1000):
         """Return an unpruned, sorted list of (tx_hash, height) tuples of
         confirmed transactions that touched the address, earliest in
@@ -1176,13 +1193,12 @@ class HubDB:
         limit to None to get them all.
         """
         run_in_executor = asyncio.get_event_loop().run_in_executor
-        tx_nums = await run_in_executor(self._executor, self.read_history, hashX, limit)
+        tx_nums = await run_in_executor(self._executor, self._read_history, hashX, limit)
         history = []
         append_history = history.append
         while tx_nums:
             batch, tx_nums = tx_nums[:100], tx_nums[100:]
-            batch_result = self.get_tx_hashes(batch) if self._cache_all_tx_hashes else await run_in_executor(self._executor, self.get_tx_hashes, batch)
-            for tx_num, tx_hash in zip(batch, batch_result):
+            for tx_num, tx_hash in zip(batch, await self.get_tx_hashes(batch)):
                 append_history((tx_hash, bisect_right(self.tx_counts, tx_num)))
             await asyncio.sleep(0)
         return history

@@ -48,6 +48,24 @@ class HubServerService(BlockchainReaderService):
         touched_hashXs = self.db.prefix_db.touched_hashX.get(height).touched_hashXs
         self.notifications_to_send.append((set(touched_hashXs), height))
 
+    def unwind(self):
+        prev_count = self.db.tx_counts.pop()
+        tx_count = self.db.tx_counts[-1]
+        self.db.headers.pop()
+        self.db.block_hashes.pop()
+        current_count = prev_count
+        for _ in range(prev_count - tx_count):
+            if current_count in self.session_manager.history_tx_info_cache:
+                self.session_manager.history_tx_info_cache.pop(current_count)
+            current_count -= 1
+        if self.db._cache_all_tx_hashes:
+            for _ in range(prev_count - tx_count):
+                tx_hash = self.db.tx_num_mapping.pop(self.db.total_transactions.pop())
+                if tx_hash in self.db.tx_cache:
+                    self.db.tx_cache.pop(tx_hash)
+            assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}"
+        self.db.merkle_cache.clear()
+
     def _detect_changes(self):
         super()._detect_changes()
         start = time.perf_counter()
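`history_tx_info_cache` is keyed by global `tx_num`, and a reorg reassigns every tx_num above the new tip's transaction count, so `unwind` must drop exactly the range `(tx_count, prev_count]` before those slots are refilled by the replacement block. A small sketch of that invariant, with hypothetical counts:

```python
# Hypothetical: the unwound block contained 5 transactions.
prev_count, tx_count = 1000, 995                 # global tx counts before/after pop
cache = {n: f'info-{n}' for n in range(990, 1001)}

for tx_num in range(prev_count, tx_count, -1):   # 1000, 999, ..., 996
    cache.pop(tx_num, None)                      # evict entries that no longer exist

assert max(cache) == tx_count                    # only still-valid tx_nums remain
```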

@@ -22,7 +22,7 @@ from hub.herald import PROTOCOL_MIN, PROTOCOL_MAX, HUB_PROTOCOL_VERSION
 from hub.build_info import BUILD, COMMIT_HASH, DOCKER_TAG
 from hub.herald.search import SearchIndex
 from hub.common import sha256, hash_to_hex_str, hex_str_to_hash, HASHX_LEN, version_string, formatted_time
-from hub.common import protocol_version, RPCError, DaemonError, TaskGroup, HISTOGRAM_BUCKETS
+from hub.common import protocol_version, RPCError, DaemonError, TaskGroup, HISTOGRAM_BUCKETS, LRUCache
 from hub.herald.jsonrpc import JSONRPCAutoDetect, JSONRPCConnection, JSONRPCv2, JSONRPC
 from hub.herald.common import BatchRequest, ProtocolError, Request, Batch, Notification
 from hub.herald.framer import NewlineFramer
@@ -183,7 +183,6 @@ class SessionManager:
         self.cur_group = SessionGroup(0)
         self.txs_sent = 0
         self.start_time = time.time()
-        self.history_cache = {}
         self.resolve_outputs_cache = {}
         self.resolve_cache = {}
         self.notified_height: typing.Optional[int] = None
@@ -198,9 +197,13 @@ class SessionManager:
             elastic_host=env.elastic_host, elastic_port=env.elastic_port
         )
         self.running = False
+        self.hashX_history_cache = LRUCache(2 ** 14)
+        self.hashX_full_cache = LRUCache(2 ** 12)
+        self.history_tx_info_cache = LRUCache(2 ** 17)
 
     def clear_caches(self):
-        self.history_cache.clear()
+        self.hashX_history_cache.clear()
+        self.hashX_full_cache.clear()
         self.resolve_outputs_cache.clear()
         self.resolve_cache.clear()
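Bounded LRU caches replace the old unbounded `history_cache` dict: roughly 16K raw tx_num histories, 4K fully formatted histories, and 128K per-transaction info entries. Note that `clear_caches` does not clear `history_tx_info_cache`: a confirmed transaction's `(txid, height)` never changes except on reorg, which `unwind` handles by evicting the affected tx_nums. The sketch below shows the dict-like interface this code assumes of `hub.common.LRUCache`; it is an illustration under that assumption, not that class's actual implementation:

```python
from collections import OrderedDict

class LRUCache:
    """Bounded dict-like cache evicting the least recently used entry."""
    def __init__(self, capacity: int):
        self.capacity = capacity
        self._entries = OrderedDict()

    def __contains__(self, key) -> bool:
        return key in self._entries

    def __getitem__(self, key):
        self._entries.move_to_end(key)           # mark as recently used
        return self._entries[key]

    def __setitem__(self, key, value):
        self._entries[key] = value
        self._entries.move_to_end(key)
        if len(self._entries) > self.capacity:
            self._entries.popitem(last=False)    # evict the oldest entry

    def pop(self, key, default=None):
        return self._entries.pop(key, default)

    def clear(self):
        self._entries.clear()
```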
@@ -592,16 +595,36 @@ class SessionManager:
         self.txs_sent += 1
         return hex_hash
 
-    async def limited_history(self, hashX):
-        """A caching layer."""
-        if hashX not in self.history_cache:
-            # History DoS limit.  Each element of history is about 99
-            # bytes when encoded as JSON.  This limits resource usage
-            # on bloated history requests, and uses a smaller divisor
-            # so large requests are logged before refusing them.
-            limit = self.env.max_send // 97
-            self.history_cache[hashX] = await self.db.limited_history(hashX, limit=limit)
-        return self.history_cache[hashX]
+    async def limited_history(self, hashX: bytes) -> typing.List[typing.Tuple[str, int]]:
+        if hashX in self.hashX_full_cache:
+            return self.hashX_full_cache[hashX]
+        if hashX not in self.hashX_history_cache:
+            limit = self.env.max_send // 97
+            self.hashX_history_cache[hashX] = tx_nums = await self.db.read_history(hashX, limit)
+        else:
+            tx_nums = self.hashX_history_cache[hashX]
+        needed_tx_infos = []
+        append_needed_tx_info = needed_tx_infos.append
+        tx_infos = {}
+        for tx_num in tx_nums:
+            if tx_num in self.history_tx_info_cache:
+                tx_infos[tx_num] = self.history_tx_info_cache[tx_num]
+            else:
+                append_needed_tx_info(tx_num)
+            await asyncio.sleep(0)
+        if needed_tx_infos:
+            for tx_num, tx_hash in zip(needed_tx_infos, await self.db.get_tx_hashes(needed_tx_infos)):
+                hist = tx_hash[::-1].hex(), bisect_right(self.db.tx_counts, tx_num)
+                tx_infos[tx_num] = self.history_tx_info_cache[tx_num] = hist
+            await asyncio.sleep(0)
+        history = []
+        history_append = history.append
+        for tx_num in tx_nums:
+            history_append(tx_infos[tx_num])
+            await asyncio.sleep(0)
+        self.hashX_full_cache[hashX] = history
+        return history
 
     def _notify_peer(self, peer):
         notify_tasks = [
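The rewritten `limited_history` reads through three tiers: the full-result cache, the raw `tx_num` history cache, and the shared per-transaction `(txid, height)` cache, so only tx_nums never seen before trigger a batched hash fetch. A hedged sketch of that miss-only batching step (`fetch_hashes` is a hypothetical async callable standing in for `db.get_tx_hashes`):

```python
from bisect import bisect_right
from typing import Dict, List, Tuple

async def resolve_tx_infos(tx_nums: List[int], tx_counts: List[int],
                           tx_info_cache: Dict[int, Tuple[str, int]],
                           fetch_hashes) -> List[Tuple[str, int]]:
    # fetch hashes only for tx_nums missing from the shared cache
    misses = [n for n in tx_nums if n not in tx_info_cache]
    if misses:
        for n, tx_hash in zip(misses, await fetch_hashes(misses)):
            # txid is the byte-reversed hash; height found by bisecting
            # the cumulative per-block transaction counts
            tx_info_cache[n] = (tx_hash[::-1].hex(), bisect_right(tx_counts, n))
    return [tx_info_cache[n] for n in tx_nums]
```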
@@ -1419,8 +1442,8 @@ class LBRYElectrumX(asyncio.Protocol):
     async def confirmed_and_unconfirmed_history(self, hashX):
         # Note history is ordered but unconfirmed is unordered in e-s
         history = await self.session_manager.limited_history(hashX)
-        conf = [{'tx_hash': hash_to_hex_str(tx_hash), 'height': height}
-                for tx_hash, height in history]
+        conf = [{'tx_hash': txid, 'height': height}
+                for txid, height in history]
         return conf + self.unconfirmed_history(hashX)
 
     async def scripthash_get_history(self, scripthash):
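Because the cached history now already holds hex-encoded txids, the handler no longer re-encodes each hash on every request; it only wraps the cached tuples in dicts. The electrum-style result shape, with illustrative values:

```python
# blockchain.address.get_history: confirmed entries (ordered) followed by
# unconfirmed ones (unordered); hashes and heights here are illustrative
[
    {'tx_hash': 'de1f...aa07', 'height': 1150327},  # confirmed
    {'tx_hash': '7b22...c4d9', 'height': 0},        # in mempool
]
```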

@@ -1245,7 +1245,7 @@ class BlockchainProcessorService(BlockchainService):
         if hashX in self.hashX_full_cache:
             return self.hashX_full_cache[hashX]
         if hashX not in self.hashX_history_cache:
-            self.hashX_history_cache[hashX] = tx_nums = self.db.read_history(hashX, limit=None)
+            self.hashX_history_cache[hashX] = tx_nums = self.db._read_history(hashX, limit=None)
         else:
             tx_nums = self.hashX_history_cache[hashX]
         needed_tx_infos = []
@@ -1257,7 +1257,7 @@ class BlockchainProcessorService(BlockchainService):
         else:
             append_needed_tx_info(tx_num)
         if needed_tx_infos:
-            for tx_num, tx_hash in zip(needed_tx_infos, self.db.get_tx_hashes(needed_tx_infos)):
+            for tx_num, tx_hash in zip(needed_tx_infos, self.db._get_tx_hashes(needed_tx_infos)):
                 tx_infos[tx_num] = self.history_tx_info_cache[tx_num] = f'{tx_hash[::-1].hex()}:{bisect_right(self.db.tx_counts, tx_num):d}:'
         history = ''
@@ -1487,7 +1487,7 @@ class BlockchainProcessorService(BlockchainService):
         else:
             append_needed_tx_info(tx_num)
         if needed_tx_infos:
-            for tx_num, tx_hash in zip(needed_tx_infos, self.db.get_tx_hashes(needed_tx_infos)):
+            for tx_num, tx_hash in zip(needed_tx_infos, self.db._get_tx_hashes(needed_tx_infos)):
                 tx_info = f'{tx_hash[::-1].hex()}:{bisect_right(self.db.tx_counts, tx_num):d}:'
                 tx_infos[tx_num] = tx_info
                 self.history_tx_info_cache[tx_num] = tx_info
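The block processor keeps the synchronous `_read_history`/`_get_tx_hashes` entry points, since it already runs this work off the event loop, and its `history_tx_info_cache` stores each transaction as a compact `txid:height:` string fragment; concatenated fragments form the preimage hashed into an address's electrum status. Building one such fragment, with illustrative values:

```python
# A 32-byte tx hash as stored (little-endian); reversing yields the txid.
tx_hash = bytes.fromhex('aa55' * 16)
height = 1234
fragment = f'{tx_hash[::-1].hex()}:{height:d}:'
assert fragment == '55aa' * 16 + ':1234:'
```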