Improve blockchain.address.get_history performance #40

Merged
jackrobison merged 6 commits from improve-history-cache into master 2022-05-22 06:00:39 +02:00
5 changed files with 163 additions and 27 deletions

View file

@@ -1,4 +1,5 @@
 import struct
+import asyncio
 import hashlib
 import hmac
 import ipaddress
@@ -28,6 +29,9 @@ CLAIM_HASH_LEN = 20
 HISTOGRAM_BUCKETS = (
     .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf')
 )
+SIZE_BUCKETS = (
+    1, 10, 100, 500, 1000, 2000, 4000, 7500, 10000, 15000, 25000, 50000, 75000, 100000, 150000, 250000, float('inf')
+)
 
 CLAIM_TYPES = {
     'stream': 1,
@@ -763,3 +767,11 @@ def expand_result(results):
         if inner_hits:
             return expand_result(inner_hits)
     return expanded
+
+
+async def asyncify_for_loop(gen, ticks_per_sleep: int = 1000):
+    async_sleep = asyncio.sleep
+    for cnt, item in enumerate(gen):
+        yield item
+        if cnt % ticks_per_sleep == 0:
+            await async_sleep(0)
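
The asyncify_for_loop helper added above lets a long synchronous iteration cooperate with the event loop by awaiting a zero-second sleep every ticks_per_sleep items. A minimal usage sketch, with the helper copied verbatim from the hunk and a made-up squares generator standing in for a real workload:

    import asyncio

    async def asyncify_for_loop(gen, ticks_per_sleep: int = 1000):
        # yield items from a synchronous generator, handing control
        # back to the event loop every `ticks_per_sleep` items
        async_sleep = asyncio.sleep
        for cnt, item in enumerate(gen):
            yield item
            if cnt % ticks_per_sleep == 0:
                await async_sleep(0)

    async def main():
        total = 0
        # a large synchronous generator no longer starves other tasks
        async for square in asyncify_for_loop(i * i for i in range(1_000_000)):
            total += square
        print(total)

    asyncio.run(main())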

View file

@@ -959,11 +959,25 @@ class HubDB:
             return self.total_transactions[tx_num]
         return self.prefix_db.tx_hash.get(tx_num, deserialize_value=False)
 
-    def get_tx_hashes(self, tx_nums: List[int]) -> List[Optional[bytes]]:
+    def _get_tx_hashes(self, tx_nums: List[int]) -> List[Optional[bytes]]:
         if self._cache_all_tx_hashes:
             return [None if tx_num > self.db_tx_count else self.total_transactions[tx_num] for tx_num in tx_nums]
         return self.prefix_db.tx_hash.multi_get([(tx_num,) for tx_num in tx_nums], deserialize_value=False)
 
+    async def get_tx_hashes(self, tx_nums: List[int]) -> List[Optional[bytes]]:
+        if self._cache_all_tx_hashes:
+            result = []
+            append_result = result.append
+            for tx_num in tx_nums:
+                append_result(None if tx_num > self.db_tx_count else self.total_transactions[tx_num])
+                await asyncio.sleep(0)
+            return result
+
+        def _get_tx_hashes():
+            return self.prefix_db.tx_hash.multi_get([(tx_num,) for tx_num in tx_nums], deserialize_value=False)
+
+        return await asyncio.get_event_loop().run_in_executor(self._executor, _get_tx_hashes)
+
     def get_raw_mempool_tx(self, tx_hash: bytes) -> Optional[bytes]:
         return self.prefix_db.mempool_tx.get(tx_hash, deserialize_value=False)
@@ -1159,7 +1173,7 @@ class HubDB:
             raise DBError(f'only got {len(self.headers) - height:,d} headers starting at {height:,d}, not {count:,d}')
         return [self.coin.header_hash(header) for header in self.headers[height:height + count]]
 
-    def read_history(self, hashX: bytes, limit: int = 1000) -> List[int]:
+    def _read_history(self, hashX: bytes, limit: Optional[int] = 1000) -> List[int]:
         txs = []
         txs_extend = txs.extend
         for hist in self.prefix_db.hashX_history.iterate(prefix=(hashX,), include_key=False):
@@ -1168,6 +1182,9 @@ class HubDB:
                 break
         return txs
 
+    async def read_history(self, hashX: bytes, limit: Optional[int] = 1000) -> List[int]:
+        return await asyncio.get_event_loop().run_in_executor(self._executor, self._read_history, hashX, limit)
+
     async def limited_history(self, hashX, *, limit=1000):
         """Return an unpruned, sorted list of (tx_hash, height) tuples of
         confirmed transactions that touched the address, earliest in
@@ -1176,13 +1193,12 @@ class HubDB:
         limit to None to get them all.
         """
         run_in_executor = asyncio.get_event_loop().run_in_executor
-        tx_nums = await run_in_executor(self._executor, self.read_history, hashX, limit)
+        tx_nums = await run_in_executor(self._executor, self._read_history, hashX, limit)
         history = []
         append_history = history.append
         while tx_nums:
             batch, tx_nums = tx_nums[:100], tx_nums[100:]
-            batch_result = self.get_tx_hashes(batch) if self._cache_all_tx_hashes else await run_in_executor(self._executor, self.get_tx_hashes, batch)
-            for tx_num, tx_hash in zip(batch, batch_result):
+            for tx_num, tx_hash in zip(batch, await self.get_tx_hashes(batch)):
                 append_history((tx_hash, bisect_right(self.tx_counts, tx_num)))
             await asyncio.sleep(0)
         return history
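
The pattern in this file is consistent: each read keeps a private blocking implementation (_read_history, _get_tx_hashes) and gains a public async wrapper that either offloads the blocking call to the executor or, on the fully cached path, yields to the event loop between items. A self-contained sketch of the executor half, with a hypothetical slow_scan standing in for the prefix_db iteration:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor(max_workers=4)

    def slow_scan(prefix):
        # hypothetical stand-in for a blocking RocksDB prefix iteration
        return [f"{prefix}:{n}" for n in range(5)]

    async def read(prefix):
        # the event loop keeps serving other sessions while the blocking
        # scan runs on a worker thread, as HubDB.read_history does above
        return await asyncio.get_event_loop().run_in_executor(executor, slow_scan, prefix)

    print(asyncio.run(read("hashX")))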

View file

@@ -46,8 +46,29 @@ class HubServerService(BlockchainReaderService):
     def advance(self, height: int):
         super().advance(height)
         touched_hashXs = self.db.prefix_db.touched_hashX.get(height).touched_hashXs
+        self.session_manager.update_history_caches(touched_hashXs)
         self.notifications_to_send.append((set(touched_hashXs), height))
 
+    def unwind(self):
+        self.session_manager.hashX_raw_history_cache.clear()
+        self.session_manager.hashX_history_cache.clear()
+        prev_count = self.db.tx_counts.pop()
+        tx_count = self.db.tx_counts[-1]
+        self.db.headers.pop()
+        self.db.block_hashes.pop()
+        current_count = prev_count
+        for _ in range(prev_count - tx_count):
+            if current_count in self.session_manager.history_tx_info_cache:
+                self.session_manager.history_tx_info_cache.pop(current_count)
+            current_count -= 1
+        if self.db._cache_all_tx_hashes:
+            for _ in range(prev_count - tx_count):
+                tx_hash = self.db.total_transactions.pop()
+                self.db.tx_num_mapping.pop(tx_hash)
+                if tx_hash in self.db.tx_cache:
+                    self.db.tx_cache.pop(tx_hash)
+            assert len(self.db.total_transactions) == tx_count, f"{len(self.db.total_transactions)} vs {tx_count}"
+        self.db.merkle_cache.clear()
+
     def _detect_changes(self):
         super()._detect_changes()
         start = time.perf_counter()
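
The new unwind hook drops the two per-address caches wholesale on a reorg (their histories may now be wrong) but evicts history_tx_info_cache entries selectively, walking tx_nums down from the pre-reorg transaction count. A small model of that eviction arithmetic, using a plain dict in place of the LRU cache and invented counts:

    # tx_counts is cumulative per block, so the unwound block contributed
    # exactly prev_count - tx_count transactions; evict just those tx_nums
    cache = {n: f"txinfo:{n}" for n in range(990, 1006)}
    prev_count, tx_count = 1005, 1000
    current = prev_count
    for _ in range(prev_count - tx_count):
        cache.pop(current, None)
        current -= 1
    assert all(n not in cache for n in range(1001, 1006))
    assert all(n in cache for n in range(990, 1001))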

View file

@@ -1,5 +1,5 @@
 import os
-import ssl
+import sys
 import math
 import time
 import codecs
@@ -21,8 +21,9 @@ from hub import __version__, PROMETHEUS_NAMESPACE
 from hub.herald import PROTOCOL_MIN, PROTOCOL_MAX, HUB_PROTOCOL_VERSION
 from hub.build_info import BUILD, COMMIT_HASH, DOCKER_TAG
 from hub.herald.search import SearchIndex
-from hub.common import sha256, hash_to_hex_str, hex_str_to_hash, HASHX_LEN, version_string, formatted_time
+from hub.common import sha256, hash_to_hex_str, hex_str_to_hash, HASHX_LEN, version_string, formatted_time, SIZE_BUCKETS
 from hub.common import protocol_version, RPCError, DaemonError, TaskGroup, HISTOGRAM_BUCKETS
+from hub.common import LRUCacheWithMetrics
 from hub.herald.jsonrpc import JSONRPCAutoDetect, JSONRPCConnection, JSONRPCv2, JSONRPC
 from hub.herald.common import BatchRequest, ProtocolError, Request, Batch, Notification
 from hub.herald.framer import NewlineFramer
@@ -32,6 +33,8 @@ if typing.TYPE_CHECKING:
     from hub.scribe.daemon import LBCDaemon
     from hub.herald.mempool import HubMemPool
 
+PYTHON_VERSION = sys.version_info.major, sys.version_info.minor
+TypedDict = dict if PYTHON_VERSION < (3, 8) else typing.TypedDict
 
 BAD_REQUEST = 1
 DAEMON_ERROR = 2
@@ -42,6 +45,11 @@ SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
                                             'required_names other_names')
 
+
+class CachedAddressHistoryItem(TypedDict):
+    tx_hash: str
+    height: int
+
 
 def scripthash_to_hashX(scripthash: str) -> bytes:
     try:
         bin_hash = hex_str_to_hash(scripthash)
@@ -146,7 +154,6 @@ class SessionManager:
     pending_query_metric = Gauge(
         "pending_queries_count", "Number of pending and running sqlite queries", namespace=NAMESPACE
     )
-
     client_version_metric = Counter(
         "clients", "Number of connections received per client version",
         namespace=NAMESPACE, labelnames=("version",)
@@ -155,6 +162,14 @@ class SessionManager:
         "address_history", "Time to fetch an address history",
         namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS
     )
+    address_subscription_metric = Gauge(
+        "address_subscriptions", "Number of subscribed addresses",
+        namespace=NAMESPACE
+    )
+    address_history_size_metric = Histogram(
+        "history_size", "Sizes of histories for subscribed addresses",
+        namespace=NAMESPACE, buckets=SIZE_BUCKETS
+    )
     notifications_in_flight_metric = Gauge(
         "notifications_in_flight", "Count of notifications in flight",
         namespace=NAMESPACE
@@ -183,7 +198,6 @@ class SessionManager:
         self.cur_group = SessionGroup(0)
         self.txs_sent = 0
         self.start_time = time.time()
-        self.history_cache = {}
         self.resolve_outputs_cache = {}
         self.resolve_cache = {}
         self.notified_height: typing.Optional[int] = None
@@ -198,12 +212,50 @@ class SessionManager:
             elastic_host=env.elastic_host, elastic_port=env.elastic_port
         )
         self.running = False
+        # hashX: List[int]
+        self.hashX_raw_history_cache = LRUCacheWithMetrics(2 ** 16, metric_name='raw_history', namespace=NAMESPACE)
+        # hashX: List[CachedAddressHistoryItem]
+        self.hashX_history_cache = LRUCacheWithMetrics(2 ** 14, metric_name='full_history', namespace=NAMESPACE)
+        # tx_num: Tuple[txid, height]
+        self.history_tx_info_cache = LRUCacheWithMetrics(2 ** 19, metric_name='history_tx', namespace=NAMESPACE)
 
     def clear_caches(self):
-        self.history_cache.clear()
         self.resolve_outputs_cache.clear()
         self.resolve_cache.clear()
+
+    def update_history_caches(self, touched_hashXs: typing.List[bytes]):
+        update_history_cache = {}
+        for hashX in set(touched_hashXs):
+            history_tx_nums = None
+            # if the history is in the raw_history_cache, update it
+            # TODO: use a reversed iterator for this instead of rescanning it all
+            if hashX in self.hashX_raw_history_cache:
+                self.hashX_raw_history_cache[hashX] = history_tx_nums = self.db._read_history(hashX, None)
+            # if it's in hashX_history_cache, prepare to update it in a batch
+            if hashX in self.hashX_history_cache:
+                full_cached = self.hashX_history_cache[hashX]
+                if history_tx_nums is None:
+                    history_tx_nums = self.db._read_history(hashX, None)
+                new_txs = history_tx_nums[len(full_cached):]
+                update_history_cache[hashX] = full_cached, new_txs
+        if update_history_cache:
+            # get the set of new tx nums touched across all of the histories to be updated
+            total_tx_nums = set()
+            for _, new_txs in update_history_cache.values():
+                total_tx_nums.update(new_txs)
+            total_tx_nums = list(total_tx_nums)
+            # collect the tx infos for all of the new txs
+            referenced_new_txs = {
+                tx_num: CachedAddressHistoryItem(tx_hash=tx_hash[::-1].hex(), height=bisect_right(self.db.tx_counts, tx_num))
+                for tx_num, tx_hash in zip(total_tx_nums, self.db._get_tx_hashes(total_tx_nums))
+            }
+            # update the cached history lists
+            get_referenced = referenced_new_txs.__getitem__
+            for hashX, (full, new_txs) in update_history_cache.items():
+                append_to_full = full.append
+                for tx_num in new_txs:
+                    append_to_full(get_referenced(tx_num))
+
     async def _start_server(self, kind, *args, **kw_args):
         loop = asyncio.get_event_loop()
@@ -592,16 +644,48 @@ class SessionManager:
         self.txs_sent += 1
         return hex_hash
 
-    async def limited_history(self, hashX):
-        """A caching layer."""
-        if hashX not in self.history_cache:
-            # History DoS limit. Each element of history is about 99
-            # bytes when encoded as JSON. This limits resource usage
-            # on bloated history requests, and uses a smaller divisor
-            # so large requests are logged before refusing them.
-            limit = self.env.max_send // 97
-            self.history_cache[hashX] = await self.db.limited_history(hashX, limit=limit)
-        return self.history_cache[hashX]
+    async def _cached_raw_history(self, hashX: bytes, limit: typing.Optional[int] = None):
+        tx_nums = self.hashX_raw_history_cache.get(hashX)
+        if tx_nums is None:
+            self.hashX_raw_history_cache[hashX] = tx_nums = await self.db.read_history(hashX, limit)
+        return tx_nums
+
+    async def cached_confirmed_history(self, hashX: bytes,
+                                       limit: typing.Optional[int] = None) -> typing.List[CachedAddressHistoryItem]:
+        cached_full_history = self.hashX_history_cache.get(hashX)
+        # return the cached history
+        if cached_full_history is not None:
+            self.address_history_size_metric.observe(len(cached_full_history))
+            return cached_full_history
+        # return the history and update the caches
+        tx_nums = await self._cached_raw_history(hashX, limit)
+        needed_tx_infos = []
+        append_needed_tx_info = needed_tx_infos.append
+        tx_infos = {}
+        for cnt, tx_num in enumerate(tx_nums):  # determine which tx hashes are cached and which we need to look up
+            cached = self.history_tx_info_cache.get(tx_num)
+            if cached is not None:
+                tx_infos[tx_num] = cached
+            else:
+                append_needed_tx_info(tx_num)
+            if cnt % 1000 == 0:
+                await asyncio.sleep(0)
+        if needed_tx_infos:  # request all the needed tx hashes in one batch, cache the txids and heights
+            for cnt, (tx_num, tx_hash) in enumerate(zip(needed_tx_infos, await self.db.get_tx_hashes(needed_tx_infos))):
+                hist = CachedAddressHistoryItem(tx_hash=tx_hash[::-1].hex(), height=bisect_right(self.db.tx_counts, tx_num))
+                tx_infos[tx_num] = self.history_tx_info_cache[tx_num] = hist
+                if cnt % 1000 == 0:
+                    await asyncio.sleep(0)
+        # ensure the ordering of the txs
+        history = []
+        history_append = history.append
+        for cnt, tx_num in enumerate(tx_nums):
+            history_append(tx_infos[tx_num])
+            if cnt % 1000 == 0:
+                await asyncio.sleep(0)
+        self.hashX_history_cache[hashX] = history
+        self.address_history_size_metric.observe(len(history))
+        return history
 
     def _notify_peer(self, peer):
         notify_tasks = [
@@ -623,6 +707,7 @@ class SessionManager:
     def remove_session(self, session):
         """Remove a session from our sessions list if there."""
         session_id = id(session)
+        self.address_subscription_metric.dec(len(session.hashX_subs))
         for hashX in session.hashX_subs:
             sessions = self.hashx_subscriptions_by_session[hashX]
             sessions.remove(session_id)
@@ -1348,6 +1433,8 @@ class LBRYElectrumX(asyncio.Protocol):
             sessions.remove(id(self))
         except KeyError:
             pass
+        else:
+            self.session_manager.address_subscription_metric.dec()
         if not sessions:
             self.hashX_subs.pop(hashX, None)
@@ -1385,6 +1472,7 @@ class LBRYElectrumX(asyncio.Protocol):
         if len(addresses) > 1000:
             raise RPCError(BAD_REQUEST, f'too many addresses in subscription request: {len(addresses)}')
         results = []
+        self.session_manager.address_subscription_metric.inc(len(addresses))
         for address in addresses:
             results.append(await self.hashX_subscribe(self.address_to_hashX(address), address))
             await asyncio.sleep(0)
@@ -1418,10 +1506,8 @@ class LBRYElectrumX(asyncio.Protocol):
     async def confirmed_and_unconfirmed_history(self, hashX):
         # Note history is ordered but unconfirmed is unordered in e-s
-        history = await self.session_manager.limited_history(hashX)
-        conf = [{'tx_hash': hash_to_hex_str(tx_hash), 'height': height}
-                for tx_hash, height in history]
-        return conf + self.unconfirmed_history(hashX)
+        history = await self.session_manager.cached_confirmed_history(hashX)
+        return history + self.unconfirmed_history(hashX)
 
     async def scripthash_get_history(self, scripthash):
         """Return the confirmed and unconfirmed history of a scripthash."""
@@ -1443,6 +1529,7 @@ class LBRYElectrumX(asyncio.Protocol):
         scripthash: the SHA256 hash of the script to subscribe to"""
         hashX = scripthash_to_hashX(scripthash)
+        self.session_manager.address_subscription_metric.inc()
         return await self.hashX_subscribe(hashX, scripthash)
 
     async def scripthash_unsubscribe(self, scripthash: str):
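
Taken together, blockchain.address.get_history and blockchain.scripthash.get_history now resolve through three cache tiers: the full formatted history per hashX, the raw tx_num list per hashX, and the txid/height pair per tx_num. A stripped-down model of that lookup, using plain dicts in place of LRUCacheWithMetrics and a fake db exposing the read_history/get_tx_hashes/tx_counts interface from this PR:

    import asyncio
    from bisect import bisect_right

    full_history_cache = {}   # hashX -> [{'tx_hash': ..., 'height': ...}, ...]
    raw_history_cache = {}    # hashX -> [tx_num, ...]
    tx_info_cache = {}        # tx_num -> {'tx_hash': ..., 'height': ...}

    async def cached_history(db, hashX):
        full = full_history_cache.get(hashX)
        if full is not None:          # tier 1: whole formatted history
            return full
        tx_nums = raw_history_cache.get(hashX)
        if tx_nums is None:           # tier 2: raw tx_num list
            raw_history_cache[hashX] = tx_nums = await db.read_history(hashX, None)
        # only the tx_nums missing from tier 3 need a batched db lookup
        missing = [n for n in tx_nums if n not in tx_info_cache]
        for n, tx_hash in zip(missing, await db.get_tx_hashes(missing)):
            tx_info_cache[n] = {'tx_hash': tx_hash[::-1].hex(),
                                'height': bisect_right(db.tx_counts, n)}
        history = [tx_info_cache[n] for n in tx_nums]   # tier 3: per-tx info
        full_history_cache[hashX] = history
        return history

    class FakeDB:
        # minimal stand-in with the interface used above
        tx_counts = [0, 2, 5]
        async def read_history(self, hashX, limit):
            return [0, 1, 3]
        async def get_tx_hashes(self, tx_nums):
            return [bytes([n]) * 32 for n in tx_nums]

    print(asyncio.run(cached_history(FakeDB(), b'\x00' * 11)))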

View file

@@ -1245,7 +1245,7 @@ class BlockchainProcessorService(BlockchainService):
         if hashX in self.hashX_full_cache:
             return self.hashX_full_cache[hashX]
         if hashX not in self.hashX_history_cache:
-            self.hashX_history_cache[hashX] = tx_nums = self.db.read_history(hashX, limit=None)
+            self.hashX_history_cache[hashX] = tx_nums = self.db._read_history(hashX, limit=None)
         else:
             tx_nums = self.hashX_history_cache[hashX]
         needed_tx_infos = []
@@ -1257,7 +1257,7 @@ class BlockchainProcessorService(BlockchainService):
             else:
                 append_needed_tx_info(tx_num)
         if needed_tx_infos:
-            for tx_num, tx_hash in zip(needed_tx_infos, self.db.get_tx_hashes(needed_tx_infos)):
+            for tx_num, tx_hash in zip(needed_tx_infos, self.db._get_tx_hashes(needed_tx_infos)):
                 tx_infos[tx_num] = self.history_tx_info_cache[tx_num] = f'{tx_hash[::-1].hex()}:{bisect_right(self.db.tx_counts, tx_num):d}:'
         history = ''
@@ -1487,7 +1487,7 @@ class BlockchainProcessorService(BlockchainService):
             else:
                 append_needed_tx_info(tx_num)
         if needed_tx_infos:
-            for tx_num, tx_hash in zip(needed_tx_infos, self.db.get_tx_hashes(needed_tx_infos)):
+            for tx_num, tx_hash in zip(needed_tx_infos, self.db._get_tx_hashes(needed_tx_infos)):
                 tx_info = f'{tx_hash[::-1].hex()}:{bisect_right(self.db.tx_counts, tx_num):d}:'
                 tx_infos[tx_num] = tx_info
                 self.history_tx_info_cache[tx_num] = tx_info