diff --git a/lbry/event.py b/lbry/event.py index deed7dea0..a3adc3789 100644 --- a/lbry/event.py +++ b/lbry/event.py @@ -70,9 +70,9 @@ class EventController: next_sub = next_sub._next yield subscription - async def _notify(self, notify, event): + async def _notify(self, notify, *args): try: - maybe_coroutine = notify(event) + maybe_coroutine = notify(*args) if asyncio.iscoroutine(maybe_coroutine): await maybe_coroutine except Exception as e: @@ -90,9 +90,9 @@ class EventController: for subscription in self._iterate_subscriptions: await self._notify(subscription._add_error, exception) - def close(self): + async def close(self): for subscription in self._iterate_subscriptions: - subscription._close() + await self._notify(subscription._close) def _cancel(self, subscription): previous = subscription._previous @@ -151,6 +151,23 @@ class EventStream: ) return future + @property + def last(self) -> asyncio.Future: + future = asyncio.get_event_loop().create_future() + value = None + + def update_value(v): + nonlocal value + value = v + + subscription = self.listen( + lambda v: update_value(v), + lambda exception: not future.done() and self._cancel_and_error(subscription, future, exception), + lambda: not future.done() and self._cancel_and_callback(subscription, future, value), + ) + + return future + @staticmethod def _cancel_and_callback(subscription: BroadcastSubscription, future: asyncio.Future, value): subscription.cancel() @@ -170,7 +187,14 @@ class EventQueuePublisher(threading.Thread): super().__init__() self.queue = queue self.event_controller = event_controller + self.loop = None + + def message_to_event(self, message): + return message + + def start(self): self.loop = asyncio.get_running_loop() + super().start() def run(self): while True: @@ -178,7 +202,7 @@ class EventQueuePublisher(threading.Thread): if msg == self.STOP: return asyncio.run_coroutine_threadsafe( - self.event_controller.add(msg), self.loop + self.event_controller.add(self.message_to_event(msg)), self.loop ) def stop(self): diff --git a/tests/unit/test_event_controller.py b/tests/unit/test_event_controller.py index 213899414..eb23d5a31 100644 --- a/tests/unit/test_event_controller.py +++ b/tests/unit/test_event_controller.py @@ -37,6 +37,23 @@ class StreamControllerTestCase(AsyncioTestCase): with self.assertRaises(ValueError): await controller.add("yo") + async def test_first_event(self): + controller = EventController() + first = controller.stream.first + await controller.add("one") + second = controller.stream.first + await controller.add("two") + self.assertEqual("one", await first) + self.assertEqual("two", await second) + + async def test_last_event(self): + controller = EventController() + last = controller.stream.last + await controller.add("one") + await controller.add("two") + await controller.close() + self.assertEqual("two", await last) + class TaskGroupTestCase(AsyncioTestCase):