forked from LBRYCommunity/lbry-sdk
calculate notifications for all subscriptions instead of per session
This commit is contained in:
parent
20b435732a
commit
eab3b65629
5 changed files with 77 additions and 72 deletions
|
@ -366,7 +366,7 @@ class MemPool:
|
|||
result.update(tx.prevouts)
|
||||
return result
|
||||
|
||||
async def transaction_summaries(self, hashX):
|
||||
def transaction_summaries(self, hashX):
|
||||
"""Return a list of MemPoolTxSummary objects for the hashX."""
|
||||
result = []
|
||||
for tx_hash in self.hashXs.get(hashX, ()):
|
||||
|
|
|
@ -190,7 +190,9 @@ class SessionManager:
|
|||
self.shutdown_event = shutdown_event
|
||||
self.logger = util.class_logger(__name__, self.__class__.__name__)
|
||||
self.servers: typing.Dict[str, asyncio.AbstractServer] = {}
|
||||
self.sessions: typing.Set['SessionBase'] = set()
|
||||
self.sessions: typing.Dict[int, 'SessionBase'] = {}
|
||||
self.hashx_subscriptions_by_session: typing.DefaultDict[str, typing.Set[int]] = defaultdict(set)
|
||||
self.mempool_statuses = {}
|
||||
self.cur_group = SessionGroup(0)
|
||||
self.txs_sent = 0
|
||||
self.start_time = time.time()
|
||||
|
@ -276,12 +278,12 @@ class SessionManager:
|
|||
|
||||
def _group_map(self):
|
||||
group_map = defaultdict(list)
|
||||
for session in self.sessions:
|
||||
for session in self.sessions.values():
|
||||
group_map[session.group].append(session)
|
||||
return group_map
|
||||
|
||||
def _sub_count(self) -> int:
|
||||
return sum(s.sub_count() for s in self.sessions)
|
||||
return sum(s.sub_count() for s in self.sessions.values())
|
||||
|
||||
def _lookup_session(self, session_id):
|
||||
try:
|
||||
|
@ -289,7 +291,7 @@ class SessionManager:
|
|||
except Exception:
|
||||
pass
|
||||
else:
|
||||
for session in self.sessions:
|
||||
for session in self.sessions.values():
|
||||
if session.session_id == session_id:
|
||||
return session
|
||||
return None
|
||||
|
@ -313,7 +315,7 @@ class SessionManager:
|
|||
while True:
|
||||
await sleep(session_timeout // 10)
|
||||
stale_cutoff = time.perf_counter() - session_timeout
|
||||
stale_sessions = [session for session in self.sessions
|
||||
stale_sessions = [session for session in self.sessions.values()
|
||||
if session.last_recv < stale_cutoff]
|
||||
if stale_sessions:
|
||||
text = ', '.join(str(session.session_id)
|
||||
|
@ -345,7 +347,7 @@ class SessionManager:
|
|||
pending_requests = 0
|
||||
closing = 0
|
||||
|
||||
for s in self.sessions:
|
||||
for s in self.sessions.values():
|
||||
error_count += s.errors
|
||||
if s.log_me:
|
||||
logged += 1
|
||||
|
@ -379,7 +381,7 @@ class SessionManager:
|
|||
def _session_data(self, for_log):
|
||||
"""Returned to the RPC 'sessions' call."""
|
||||
now = time.time()
|
||||
sessions = sorted(self.sessions, key=lambda s: s.start_time)
|
||||
sessions = sorted(self.sessions.values(), key=lambda s: s.start_time)
|
||||
return [(session.session_id,
|
||||
session.flags(),
|
||||
session.peer_address_str(for_log=for_log),
|
||||
|
@ -583,7 +585,7 @@ class SessionManager:
|
|||
await self._close_servers(list(self.servers.keys()))
|
||||
if self.sessions:
|
||||
await asyncio.wait([
|
||||
session.close(force_after=1) for session in self.sessions
|
||||
session.close(force_after=1) for session in self.sessions.values()
|
||||
])
|
||||
await self.stop_other()
|
||||
|
||||
|
@ -638,13 +640,37 @@ class SessionManager:
|
|||
height_changed = height != self.notified_height
|
||||
if height_changed:
|
||||
await self._refresh_hsub_results(height)
|
||||
if self.sessions:
|
||||
await asyncio.wait([
|
||||
session.notify(touched, height_changed) for session in self.sessions
|
||||
])
|
||||
if not self.sessions:
|
||||
return
|
||||
|
||||
if height_changed:
|
||||
header_tasks = [
|
||||
session.send_notification('blockchain.headers.subscribe', (self.hsub_results[session.subscribe_headers_raw], ))
|
||||
for session in self.sessions.values() if session.subscribe_headers
|
||||
]
|
||||
if header_tasks:
|
||||
await asyncio.wait(header_tasks)
|
||||
|
||||
touched = touched.intersection(self.hashx_subscriptions_by_session.keys())
|
||||
|
||||
if touched or (height_changed and self.mempool_statuses):
|
||||
mempool_hashxs = set(self.mempool_statuses.keys())
|
||||
notified = set()
|
||||
for hashX in touched:
|
||||
for session_id in self.hashx_subscriptions_by_session[hashX]:
|
||||
asyncio.create_task(self.sessions[session_id].send_history_notification(hashX))
|
||||
notified.add(hashX)
|
||||
for hashX in mempool_hashxs.difference(touched):
|
||||
for session_id in self.hashx_subscriptions_by_session[hashX]:
|
||||
asyncio.create_task(self.sessions[session_id].send_history_notification(hashX))
|
||||
notified.add(hashX)
|
||||
|
||||
if touched:
|
||||
es = '' if len(touched) == 1 else 'es'
|
||||
self.logger.info(f'notified {len(notified)} mempool/{len(touched):,d} touched address{es}')
|
||||
|
||||
def add_session(self, session):
|
||||
self.sessions.add(session)
|
||||
self.sessions[id(session)] = session
|
||||
self.session_event.set()
|
||||
gid = int(session.start_time - self.start_time) // 900
|
||||
if self.cur_group.gid != gid:
|
||||
|
@ -653,7 +679,13 @@ class SessionManager:
|
|||
|
||||
def remove_session(self, session):
|
||||
"""Remove a session from our sessions list if there."""
|
||||
self.sessions.remove(session)
|
||||
session_id = id(session)
|
||||
for hashX in session.hashX_subs:
|
||||
sessions = self.hashx_subscriptions_by_session[hashX]
|
||||
sessions.remove(session_id)
|
||||
if not sessions:
|
||||
self.hashx_subscriptions_by_session.pop(hashX)
|
||||
self.sessions.pop(session_id)
|
||||
self.session_event.set()
|
||||
|
||||
|
||||
|
@ -688,8 +720,6 @@ class SessionBase(RPCSession):
|
|||
self._receive_message_orig = self.connection.receive_message
|
||||
self.connection.receive_message = self.receive_message
|
||||
|
||||
async def notify(self, touched, height_changed):
|
||||
pass
|
||||
|
||||
def default_framer(self):
|
||||
return NewlineFramer(self.env.max_receive)
|
||||
|
@ -886,7 +916,6 @@ class LBRYElectrumX(SessionBase):
|
|||
self.connection.max_response_size = self.env.max_send
|
||||
self.hashX_subs = {}
|
||||
self.sv_seen = False
|
||||
self.mempool_statuses = {}
|
||||
self.protocol_tuple = self.PROTOCOL_MIN
|
||||
|
||||
self.daemon = self.session_mgr.daemon
|
||||
|
@ -931,48 +960,22 @@ class LBRYElectrumX(SessionBase):
|
|||
def sub_count(self):
|
||||
return len(self.hashX_subs)
|
||||
|
||||
async def notify(self, touched, height_changed):
|
||||
"""Notify the client about changes to touched addresses (from mempool
|
||||
updates or new blocks) and height.
|
||||
"""
|
||||
if height_changed and self.subscribe_headers:
|
||||
args = (await self.subscribe_headers_result(), )
|
||||
if not (await self.send_notification('blockchain.headers.subscribe', args)):
|
||||
return
|
||||
|
||||
async def send_history_notification(alias, hashX):
|
||||
async def send_history_notification(self, hashX):
|
||||
start = time.perf_counter()
|
||||
alias = self.hashX_subs[hashX]
|
||||
if len(alias) == 64:
|
||||
method = 'blockchain.scripthash.subscribe'
|
||||
else:
|
||||
method = 'blockchain.address.subscribe'
|
||||
try:
|
||||
self.session_mgr.notifications_in_flight_metric.inc()
|
||||
status = await self.address_status(hashX)
|
||||
self.session_mgr.address_history_metric.observe(time.perf_counter() - start)
|
||||
start = time.perf_counter()
|
||||
if len(alias) == 64:
|
||||
method = 'blockchain.scripthash.subscribe'
|
||||
else:
|
||||
method = 'blockchain.address.subscribe'
|
||||
try:
|
||||
self.session_mgr.notifications_in_flight_metric.inc()
|
||||
status = await self.address_status(hashX)
|
||||
self.session_mgr.address_history_metric.observe(time.perf_counter() - start)
|
||||
start = time.perf_counter()
|
||||
await self.send_notification(method, (alias, status))
|
||||
self.session_mgr.notifications_sent_metric.observe(time.perf_counter() - start)
|
||||
finally:
|
||||
self.session_mgr.notifications_in_flight_metric.dec()
|
||||
|
||||
touched = touched.intersection(self.hashX_subs)
|
||||
if touched or (height_changed and self.mempool_statuses):
|
||||
notified = set()
|
||||
mempool_addrs = tuple(self.mempool_statuses.keys())
|
||||
for hashX in touched:
|
||||
alias = self.hashX_subs[hashX]
|
||||
asyncio.create_task(send_history_notification(alias, hashX))
|
||||
notified.add(hashX)
|
||||
for hashX in mempool_addrs:
|
||||
if hashX not in notified:
|
||||
alias = self.hashX_subs[hashX]
|
||||
asyncio.create_task(send_history_notification(alias, hashX))
|
||||
notified.add(hashX)
|
||||
|
||||
if touched:
|
||||
es = '' if len(touched) == 1 else 'es'
|
||||
self.logger.info(f'notified {len(notified)} mempool/{len(touched):,d} touched address{es}')
|
||||
await self.send_notification(method, (alias, status))
|
||||
self.session_mgr.notifications_sent_metric.observe(time.perf_counter() - start)
|
||||
finally:
|
||||
self.session_mgr.notifications_in_flight_metric.dec()
|
||||
|
||||
def get_metrics_or_placeholder_for_api(self, query_name):
|
||||
""" Do not hold on to a reference to the metrics
|
||||
|
@ -1189,7 +1192,7 @@ class LBRYElectrumX(SessionBase):
|
|||
# For mempool, height is -1 if it has unconfirmed inputs, otherwise 0
|
||||
|
||||
db_history = await self.session_mgr.limited_history(hashX)
|
||||
mempool = await self.mempool.transaction_summaries(hashX)
|
||||
mempool = self.mempool.transaction_summaries(hashX)
|
||||
|
||||
status = ''.join(f'{hash_to_hex_str(tx_hash)}:'
|
||||
f'{height:d}:'
|
||||
|
@ -1203,9 +1206,9 @@ class LBRYElectrumX(SessionBase):
|
|||
status = None
|
||||
|
||||
if mempool:
|
||||
self.mempool_statuses[hashX] = status
|
||||
self.session_mgr.mempool_statuses[hashX] = status
|
||||
else:
|
||||
self.mempool_statuses.pop(hashX, None)
|
||||
self.session_mgr.mempool_statuses.pop(hashX, None)
|
||||
return status
|
||||
|
||||
async def hashX_listunspent(self, hashX):
|
||||
|
@ -1224,9 +1227,11 @@ class LBRYElectrumX(SessionBase):
|
|||
|
||||
async def hashX_subscribe(self, hashX, alias):
|
||||
self.hashX_subs[hashX] = alias
|
||||
self.session_mgr.hashx_subscriptions_by_session[hashX].add(id(self))
|
||||
return await self.address_status(hashX)
|
||||
|
||||
async def hashX_unsubscribe(self, hashX, alias):
|
||||
self.session_mgr.hashx_subscriptions_by_session[hashX].remove(id(self))
|
||||
self.hashX_subs.pop(hashX, None)
|
||||
|
||||
def address_to_hashX(self, address):
|
||||
|
@ -1249,7 +1254,7 @@ class LBRYElectrumX(SessionBase):
|
|||
async def address_get_mempool(self, address):
|
||||
"""Return the mempool transactions touching an address."""
|
||||
hashX = self.address_to_hashX(address)
|
||||
return await self.unconfirmed_history(hashX)
|
||||
return self.unconfirmed_history(hashX)
|
||||
|
||||
async def address_listunspent(self, address):
|
||||
"""Return the list of UTXOs of an address."""
|
||||
|
@ -1285,20 +1290,20 @@ class LBRYElectrumX(SessionBase):
|
|||
hashX = scripthash_to_hashX(scripthash)
|
||||
return await self.get_balance(hashX)
|
||||
|
||||
async def unconfirmed_history(self, hashX):
|
||||
def unconfirmed_history(self, hashX):
|
||||
# Note unconfirmed history is unordered in electrum-server
|
||||
# height is -1 if it has unconfirmed inputs, otherwise 0
|
||||
return [{'tx_hash': hash_to_hex_str(tx.hash),
|
||||
'height': -tx.has_unconfirmed_inputs,
|
||||
'fee': tx.fee}
|
||||
for tx in await self.mempool.transaction_summaries(hashX)]
|
||||
for tx in self.mempool.transaction_summaries(hashX)]
|
||||
|
||||
async def confirmed_and_unconfirmed_history(self, hashX):
|
||||
# Note history is ordered but unconfirmed is unordered in e-s
|
||||
history = await self.session_mgr.limited_history(hashX)
|
||||
conf = [{'tx_hash': hash_to_hex_str(tx_hash), 'height': height}
|
||||
for tx_hash, height in history]
|
||||
return conf + await self.unconfirmed_history(hashX)
|
||||
return conf + self.unconfirmed_history(hashX)
|
||||
|
||||
async def scripthash_get_history(self, scripthash):
|
||||
"""Return the confirmed and unconfirmed history of a scripthash."""
|
||||
|
@ -1308,7 +1313,7 @@ class LBRYElectrumX(SessionBase):
|
|||
async def scripthash_get_mempool(self, scripthash):
|
||||
"""Return the mempool transactions touching a scripthash."""
|
||||
hashX = scripthash_to_hashX(scripthash)
|
||||
return await self.unconfirmed_history(hashX)
|
||||
return self.unconfirmed_history(hashX)
|
||||
|
||||
async def scripthash_listunspent(self, scripthash):
|
||||
"""Return the list of UTXOs of a scripthash."""
|
||||
|
|
|
@ -71,7 +71,7 @@ class ReconnectTests(IntegrationTestCase):
|
|||
self.ledger.network.session_pool.new_connection_event.clear()
|
||||
await node2.start(self.blockchain)
|
||||
# this is only to speed up the test as retrying would take 4+ seconds
|
||||
for session in self.ledger.network.session_pool.sessions:
|
||||
for session in self.ledger.network.session_pool.sessions.values():
|
||||
session.trigger_urgent_reconnect.set()
|
||||
await asyncio.wait_for(self.ledger.network.session_pool.new_connection_event.wait(), timeout=1)
|
||||
self.assertEqual(2, len(list(self.ledger.network.session_pool.available_sessions)))
|
||||
|
@ -139,7 +139,7 @@ class ReconnectTests(IntegrationTestCase):
|
|||
async def test_online_but_still_unavailable(self):
|
||||
# Edge case. See issue #2445 for context
|
||||
self.assertIsNotNone(self.ledger.network.session_pool.fastest_session)
|
||||
for session in self.ledger.network.session_pool.sessions:
|
||||
for session in self.ledger.network.session_pool.sessions.values():
|
||||
session.response_time = None
|
||||
self.assertIsNone(self.ledger.network.session_pool.fastest_session)
|
||||
|
||||
|
|
|
@ -138,8 +138,8 @@ class BasicTransactionTests(IntegrationTestCase):
|
|||
# evil trick: mempool is unsorted on real life, but same order between python instances. reproduce it
|
||||
original_summary = self.conductor.spv_node.server.mempool.transaction_summaries
|
||||
|
||||
async def random_summary(*args, **kwargs):
|
||||
summary = await original_summary(*args, **kwargs)
|
||||
def random_summary(*args, **kwargs):
|
||||
summary = original_summary(*args, **kwargs)
|
||||
if summary and len(summary) > 2:
|
||||
ordered = summary.copy()
|
||||
while summary == ordered:
|
||||
|
|
|
@ -10,7 +10,7 @@ from lbry.wallet.dewies import dict_values_to_lbc
|
|||
class WalletCommands(CommandTestCase):
|
||||
|
||||
async def test_wallet_create_and_add_subscribe(self):
|
||||
session = next(iter(self.conductor.spv_node.server.session_mgr.sessions))
|
||||
session = next(iter(self.conductor.spv_node.server.session_mgr.sessions.values()))
|
||||
self.assertEqual(len(session.hashX_subs), 27)
|
||||
wallet = await self.daemon.jsonrpc_wallet_create('foo', create_account=True, single_key=True)
|
||||
self.assertEqual(len(session.hashX_subs), 28)
|
||||
|
|
Loading…
Add table
Reference in a new issue