From 6155700a68990e7bbce914374cbb058c4fa97641 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Tue, 13 Sep 2022 15:02:28 -0400 Subject: [PATCH] use session TaskGroup for notification tasks --- hub/herald/mempool.py | 9 +++++---- hub/herald/session.py | 24 +++++++++++++++--------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/hub/herald/mempool.py b/hub/herald/mempool.py index 67400d1..5f3c1ba 100644 --- a/hub/herald/mempool.py +++ b/hub/herald/mempool.py @@ -237,12 +237,13 @@ class HubMemPool: session = self.session_manager.sessions.get(session_id) if session: if session.subscribe_headers and height_changed: - asyncio.create_task( - session.send_notification('blockchain.headers.subscribe', - (self.session_manager.hsub_results[session.subscribe_headers_raw],)) + session.send_notification( + 'blockchain.headers.subscribe', + (self.session_manager.hsub_results[session.subscribe_headers_raw],) ) + if hashXes: - asyncio.create_task(session.send_history_notifications(hashXes)) + session.send_history_notifications(hashXes) async def _notify_sessions(self, height, touched, new_touched): """Notify sessions about height changes and touched addresses.""" diff --git a/hub/herald/session.py b/hub/herald/session.py index 2868fc8..501a68b 100644 --- a/hub/herald/session.py +++ b/hub/herald/session.py @@ -690,13 +690,13 @@ class SessionManager: return history def _notify_peer(self, peer): - notify_tasks = [ - session.send_notification('blockchain.peers.subscribe', [peer]) - for session in self.sessions.values() if session.subscribe_peers - ] - if notify_tasks: - self.logger.info(f'notify {len(notify_tasks)} sessions of new peers') - asyncio.create_task(asyncio.wait(notify_tasks)) + notify_count = 0 + for session in self.sessions.values(): + if session.subscribe_peers: + notify_count += 1 + session.send_notification('blockchain.peers.subscribe', [peer]) + if notify_count: + self.logger.info(f'notify {notify_count} sessions of new peers') def add_session(self, session): self.sessions[id(session)] = session @@ -1094,7 +1094,7 @@ class LBRYElectrumX(asyncio.Protocol): raise result return result - async def send_notification(self, method, args=()) -> bool: + async def _send_notification(self, method, args=()) -> bool: """Send an RPC notification over the network.""" message = self.connection.send_notification(Notification(method, args)) self.NOTIFICATION_COUNT.labels(method=method).inc() @@ -1106,6 +1106,9 @@ class LBRYElectrumX(asyncio.Protocol): self.abort() return False + def send_notification(self, method, args=()): + self._task_group.add(self._send_notification(method, args)) + async def send_notifications(self, notifications) -> bool: """Send an RPC notification over the network.""" message, _ = self.connection.send_batch(notifications) @@ -1188,7 +1191,7 @@ class LBRYElectrumX(asyncio.Protocol): return await self.db.get_hashX_statuses(hashXes) return [await self.get_hashX_status(hashX) for hashX in hashXes] - async def send_history_notifications(self, hashXes: typing.List[bytes]): + async def _send_history_notifications(self, hashXes: typing.List[bytes]): notifications = [] start = time.perf_counter() statuses = await self.get_hashX_statuses(hashXes) @@ -1217,6 +1220,9 @@ class LBRYElectrumX(asyncio.Protocol): finally: self.session_manager.notifications_in_flight_metric.dec(len(notifications)) + def send_history_notifications(self, hashXes: typing.List[bytes]): + self._task_group.add(self._send_history_notifications(hashXes)) + # def get_metrics_or_placeholder_for_api(self, query_name): # """ Do not hold on to a reference to the metrics # returned by this method past an `await` or