event subscription and publishing bug fixes

This commit is contained in:
Lex Berezhny 2020-06-29 18:10:26 -04:00
parent 434c1bc6b3
commit b45a222f98
2 changed files with 32 additions and 6 deletions

View file

@ -69,13 +69,13 @@ class EventController:
next_sub = self._first_subscription next_sub = self._first_subscription
while next_sub is not None: while next_sub is not None:
subscription = next_sub subscription = next_sub
next_sub = next_sub._next
yield subscription yield subscription
next_sub = next_sub._next
async def _notify(self, notify, *args): async def _notify(self, notify, *args):
try: try:
maybe_coroutine = notify(*args) maybe_coroutine = notify(*args)
if asyncio.iscoroutine(maybe_coroutine): if maybe_coroutine is not None and asyncio.iscoroutine(maybe_coroutine):
await maybe_coroutine await maybe_coroutine
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -111,7 +111,6 @@ class EventController:
self._last_subscription = previous self._last_subscription = previous
else: else:
next_sub._previous = previous next_sub._previous = previous
subscription._next = subscription._previous = subscription
def _listen(self, on_data, on_error, on_done): def _listen(self, on_data, on_error, on_done):
subscription = BroadcastSubscription(self, on_data, on_error, on_done) subscription = BroadcastSubscription(self, on_data, on_error, on_done)
@ -128,7 +127,7 @@ class EventController:
class EventStream: class EventStream:
def __init__(self, controller): def __init__(self, controller: EventController):
self._controller = controller self._controller = controller
def listen(self, on_data, on_error=None, on_done=None) -> BroadcastSubscription: def listen(self, on_data, on_error=None, on_done=None) -> BroadcastSubscription:

View file

@ -60,12 +60,39 @@ class StreamControllerTestCase(AsyncioTestCase):
await controller.close() await controller.close()
self.assertEqual("two", await last) self.assertEqual("two", await last)
async def test_race_condition_during_subscription_iteration(self):
controller = EventController()
sub1 = controller.stream.listen(print)
sub2 = controller.stream.listen(print)
sub3 = controller.stream.listen(print)
# normal iteration
i = iter(controller._iterate_subscriptions)
self.assertEqual(next(i, None), sub1)
self.assertEqual(next(i, None), sub2)
self.assertEqual(next(i, None), sub3)
self.assertEqual(next(i, None), None)
# subscription canceled immediately after it's iterated over
i = iter(controller._iterate_subscriptions)
self.assertEqual(next(i, None), sub1)
self.assertEqual(next(i, None), sub2)
sub2.cancel()
self.assertEqual(next(i, None), sub3)
self.assertEqual(next(i, None), None)
# subscription canceled immediately before it's iterated over
self.assertEqual(list(controller._iterate_subscriptions), [sub1, sub3]) # precondition
i = iter(controller._iterate_subscriptions)
self.assertEqual(next(i, None), sub1)
sub3.cancel()
self.assertEqual(next(i, None), None)
@skip('need to make this test more reliable')
class TestEventQueuePublisher(AsyncioTestCase): class TestEventQueuePublisher(AsyncioTestCase):
async def test_event_buffering_avoids_overloading_asyncio(self): async def test_event_buffering_avoids_overloading_asyncio(self):
threads = 3 threads = 4
generate_events = 3000 generate_events = 3000
expected_event_count = (threads * generate_events)-1 expected_event_count = (threads * generate_events)-1