use session TaskGroup for notification tasks

This commit is contained in:
Jack Robison 2022-09-13 15:02:28 -04:00
parent 9b3618f73e
commit 6155700a68
2 changed files with 20 additions and 13 deletions

View file

@ -237,12 +237,13 @@ class HubMemPool:
session = self.session_manager.sessions.get(session_id)
if session:
if session.subscribe_headers and height_changed:
asyncio.create_task(
session.send_notification('blockchain.headers.subscribe',
(self.session_manager.hsub_results[session.subscribe_headers_raw],))
session.send_notification(
'blockchain.headers.subscribe',
(self.session_manager.hsub_results[session.subscribe_headers_raw],)
)
if hashXes:
asyncio.create_task(session.send_history_notifications(hashXes))
session.send_history_notifications(hashXes)
async def _notify_sessions(self, height, touched, new_touched):
"""Notify sessions about height changes and touched addresses."""

View file

@ -690,13 +690,13 @@ class SessionManager:
return history
def _notify_peer(self, peer):
notify_tasks = [
notify_count = 0
for session in self.sessions.values():
if session.subscribe_peers:
notify_count += 1
session.send_notification('blockchain.peers.subscribe', [peer])
for session in self.sessions.values() if session.subscribe_peers
]
if notify_tasks:
self.logger.info(f'notify {len(notify_tasks)} sessions of new peers')
asyncio.create_task(asyncio.wait(notify_tasks))
if notify_count:
self.logger.info(f'notify {notify_count} sessions of new peers')
def add_session(self, session):
self.sessions[id(session)] = session
@ -1094,7 +1094,7 @@ class LBRYElectrumX(asyncio.Protocol):
raise result
return result
async def send_notification(self, method, args=()) -> bool:
async def _send_notification(self, method, args=()) -> bool:
"""Send an RPC notification over the network."""
message = self.connection.send_notification(Notification(method, args))
self.NOTIFICATION_COUNT.labels(method=method).inc()
@ -1106,6 +1106,9 @@ class LBRYElectrumX(asyncio.Protocol):
self.abort()
return False
def send_notification(self, method, args=()):
self._task_group.add(self._send_notification(method, args))
async def send_notifications(self, notifications) -> bool:
"""Send an RPC notification over the network."""
message, _ = self.connection.send_batch(notifications)
@ -1188,7 +1191,7 @@ class LBRYElectrumX(asyncio.Protocol):
return await self.db.get_hashX_statuses(hashXes)
return [await self.get_hashX_status(hashX) for hashX in hashXes]
async def send_history_notifications(self, hashXes: typing.List[bytes]):
async def _send_history_notifications(self, hashXes: typing.List[bytes]):
notifications = []
start = time.perf_counter()
statuses = await self.get_hashX_statuses(hashXes)
@ -1217,6 +1220,9 @@ class LBRYElectrumX(asyncio.Protocol):
finally:
self.session_manager.notifications_in_flight_metric.dec(len(notifications))
def send_history_notifications(self, hashXes: typing.List[bytes]):
self._task_group.add(self._send_history_notifications(hashXes))
# def get_metrics_or_placeholder_for_api(self, query_name):
# """ Do not hold on to a reference to the metrics
# returned by this method past an `await` or