diff --git a/lbry/event.py b/lbry/event.py index 82928f1d9..233925d3c 100644 --- a/lbry/event.py +++ b/lbry/event.py @@ -69,13 +69,13 @@ class EventController: next_sub = self._first_subscription while next_sub is not None: subscription = next_sub - next_sub = next_sub._next yield subscription + next_sub = next_sub._next async def _notify(self, notify, *args): try: maybe_coroutine = notify(*args) - if asyncio.iscoroutine(maybe_coroutine): + if maybe_coroutine is not None and asyncio.iscoroutine(maybe_coroutine): await maybe_coroutine except Exception as e: log.exception(e) @@ -111,7 +111,6 @@ class EventController: self._last_subscription = previous else: next_sub._previous = previous - subscription._next = subscription._previous = subscription def _listen(self, on_data, on_error, on_done): subscription = BroadcastSubscription(self, on_data, on_error, on_done) @@ -128,7 +127,7 @@ class EventController: class EventStream: - def __init__(self, controller): + def __init__(self, controller: EventController): self._controller = controller def listen(self, on_data, on_error=None, on_done=None) -> BroadcastSubscription: diff --git a/tests/unit/test_event_controller.py b/tests/unit/test_event_controller.py index 0db07d3bd..00b8ec3dd 100644 --- a/tests/unit/test_event_controller.py +++ b/tests/unit/test_event_controller.py @@ -60,12 +60,39 @@ class StreamControllerTestCase(AsyncioTestCase): await controller.close() self.assertEqual("two", await last) + async def test_race_condition_during_subscription_iteration(self): + controller = EventController() + sub1 = controller.stream.listen(print) + sub2 = controller.stream.listen(print) + sub3 = controller.stream.listen(print) + + # normal iteration + i = iter(controller._iterate_subscriptions) + self.assertEqual(next(i, None), sub1) + self.assertEqual(next(i, None), sub2) + self.assertEqual(next(i, None), sub3) + self.assertEqual(next(i, None), None) + + # subscription canceled immediately after it's iterated over + i = iter(controller._iterate_subscriptions) + self.assertEqual(next(i, None), sub1) + self.assertEqual(next(i, None), sub2) + sub2.cancel() + self.assertEqual(next(i, None), sub3) + self.assertEqual(next(i, None), None) + + # subscription canceled immediately before it's iterated over + self.assertEqual(list(controller._iterate_subscriptions), [sub1, sub3]) # precondition + i = iter(controller._iterate_subscriptions) + self.assertEqual(next(i, None), sub1) + sub3.cancel() + self.assertEqual(next(i, None), None) + -@skip('need to make this test more reliable') class TestEventQueuePublisher(AsyncioTestCase): async def test_event_buffering_avoids_overloading_asyncio(self): - threads = 3 + threads = 4 generate_events = 3000 expected_event_count = (threads * generate_events)-1