use session TaskGroup for notification tasks

This commit is contained in:
Jack Robison 2022-09-13 15:02:28 -04:00
parent 9b3618f73e
commit 6155700a68
2 changed files with 20 additions and 13 deletions

View file

@ -237,12 +237,13 @@ class HubMemPool:
session = self.session_manager.sessions.get(session_id) session = self.session_manager.sessions.get(session_id)
if session: if session:
if session.subscribe_headers and height_changed: if session.subscribe_headers and height_changed:
asyncio.create_task( session.send_notification(
session.send_notification('blockchain.headers.subscribe', 'blockchain.headers.subscribe',
(self.session_manager.hsub_results[session.subscribe_headers_raw],)) (self.session_manager.hsub_results[session.subscribe_headers_raw],)
) )
if hashXes: if hashXes:
asyncio.create_task(session.send_history_notifications(hashXes)) session.send_history_notifications(hashXes)
async def _notify_sessions(self, height, touched, new_touched): async def _notify_sessions(self, height, touched, new_touched):
"""Notify sessions about height changes and touched addresses.""" """Notify sessions about height changes and touched addresses."""

View file

@ -690,13 +690,13 @@ class SessionManager:
return history return history
def _notify_peer(self, peer): def _notify_peer(self, peer):
notify_tasks = [ notify_count = 0
session.send_notification('blockchain.peers.subscribe', [peer]) for session in self.sessions.values():
for session in self.sessions.values() if session.subscribe_peers if session.subscribe_peers:
] notify_count += 1
if notify_tasks: session.send_notification('blockchain.peers.subscribe', [peer])
self.logger.info(f'notify {len(notify_tasks)} sessions of new peers') if notify_count:
asyncio.create_task(asyncio.wait(notify_tasks)) self.logger.info(f'notify {notify_count} sessions of new peers')
def add_session(self, session): def add_session(self, session):
self.sessions[id(session)] = session self.sessions[id(session)] = session
@ -1094,7 +1094,7 @@ class LBRYElectrumX(asyncio.Protocol):
raise result raise result
return result return result
async def send_notification(self, method, args=()) -> bool: async def _send_notification(self, method, args=()) -> bool:
"""Send an RPC notification over the network.""" """Send an RPC notification over the network."""
message = self.connection.send_notification(Notification(method, args)) message = self.connection.send_notification(Notification(method, args))
self.NOTIFICATION_COUNT.labels(method=method).inc() self.NOTIFICATION_COUNT.labels(method=method).inc()
@ -1106,6 +1106,9 @@ class LBRYElectrumX(asyncio.Protocol):
self.abort() self.abort()
return False return False
def send_notification(self, method, args=()):
self._task_group.add(self._send_notification(method, args))
async def send_notifications(self, notifications) -> bool: async def send_notifications(self, notifications) -> bool:
"""Send an RPC notification over the network.""" """Send an RPC notification over the network."""
message, _ = self.connection.send_batch(notifications) message, _ = self.connection.send_batch(notifications)
@ -1188,7 +1191,7 @@ class LBRYElectrumX(asyncio.Protocol):
return await self.db.get_hashX_statuses(hashXes) return await self.db.get_hashX_statuses(hashXes)
return [await self.get_hashX_status(hashX) for hashX in hashXes] return [await self.get_hashX_status(hashX) for hashX in hashXes]
async def send_history_notifications(self, hashXes: typing.List[bytes]): async def _send_history_notifications(self, hashXes: typing.List[bytes]):
notifications = [] notifications = []
start = time.perf_counter() start = time.perf_counter()
statuses = await self.get_hashX_statuses(hashXes) statuses = await self.get_hashX_statuses(hashXes)
@ -1217,6 +1220,9 @@ class LBRYElectrumX(asyncio.Protocol):
finally: finally:
self.session_manager.notifications_in_flight_metric.dec(len(notifications)) self.session_manager.notifications_in_flight_metric.dec(len(notifications))
def send_history_notifications(self, hashXes: typing.List[bytes]):
self._task_group.add(self._send_history_notifications(hashXes))
# def get_metrics_or_placeholder_for_api(self, query_name): # def get_metrics_or_placeholder_for_api(self, query_name):
# """ Do not hold on to a reference to the metrics # """ Do not hold on to a reference to the metrics
# returned by this method past an `await` or # returned by this method past an `await` or