# lbry-sdk/tests/unit/test_event_controller.py
# Snapshot: 2020-07-13 18:21:41 -04:00 (136 lines, 4.8 KiB, Python)

import sys
import asyncio
import logging
import unittest
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
from lbry.testcase import AsyncioTestCase
from lbry.event import EventController, EventQueuePublisher
from lbry.tasks import TaskGroup
class StreamControllerTestCase(AsyncioTestCase):
    """Unit tests for ``EventController`` and the stream it exposes."""

    async def test_non_unique_events(self):
        """By default, repeated identical events are all delivered to listeners."""
        events = []
        controller = EventController()
        controller.stream.listen(events.append)
        await controller.add("yo")
        await controller.add("yo")
        self.assertListEqual(events, ["yo", "yo"])

    async def test_unique_events(self):
        """With ``merge_repeated_events=True`` a consecutive duplicate is dropped."""
        events = []
        controller = EventController(merge_repeated_events=True)
        controller.stream.listen(events.append)
        await controller.add("yo")
        await controller.add("yo")
        self.assertListEqual(events, ["yo"])

    async def test_sync_listener_errors(self):
        """Expect a plain-function listener's error to be logged and re-raised by ``add``."""
        def bad_listener(_):
            raise ValueError('bad')
        controller = EventController()
        controller.stream.listen(bad_listener)
        # assertLogs must be the OUTER manager: if an exception passes through
        # assertLogs.__exit__ it skips its "something was logged" check, so
        # assertRaises has to swallow the ValueError first.
        with self.assertLogs(), self.assertRaises(ValueError):
            await controller.add("yo")

    async def test_async_listener_errors(self):
        """Expect a coroutine listener's error to be logged and re-raised by ``add``."""
        async def bad_listener(_):
            raise ValueError('bad')
        controller = EventController()
        controller.stream.listen(bad_listener)
        # See test_sync_listener_errors for why assertLogs is the outer manager.
        with self.assertLogs(), self.assertRaises(ValueError):
            await controller.add("yo")

    async def test_first_event(self):
        """``stream.first`` resolves to the first event added after it was requested."""
        controller = EventController()
        first = controller.stream.first
        await controller.add("one")
        second = controller.stream.first
        await controller.add("two")
        self.assertEqual("one", await first)
        self.assertEqual("two", await second)

    async def test_last_event(self):
        """``stream.last`` resolves to the final event once the controller is closed."""
        controller = EventController()
        last = controller.stream.last
        await controller.add("one")
        await controller.add("two")
        await controller.close()
        self.assertEqual("two", await last)

    async def test_race_condition_during_subscription_iteration(self):
        """Cancelling a subscription mid-iteration must not break ``_iterate_subscriptions``."""
        controller = EventController()
        sub1 = controller.stream.listen(print)
        sub2 = controller.stream.listen(print)
        sub3 = controller.stream.listen(print)
        # normal iteration
        i = iter(controller._iterate_subscriptions)
        self.assertEqual(next(i, None), sub1)
        self.assertEqual(next(i, None), sub2)
        self.assertEqual(next(i, None), sub3)
        self.assertEqual(next(i, None), None)
        # subscription canceled immediately after it's iterated over
        i = iter(controller._iterate_subscriptions)
        self.assertEqual(next(i, None), sub1)
        self.assertEqual(next(i, None), sub2)
        sub2.cancel()
        self.assertEqual(next(i, None), sub3)
        self.assertEqual(next(i, None), None)
        # subscription canceled immediately before it's iterated over
        self.assertEqual(list(controller._iterate_subscriptions), [sub1, sub3])  # precondition
        i = iter(controller._iterate_subscriptions)
        self.assertEqual(next(i, None), sub1)
        sub3.cancel()
        self.assertEqual(next(i, None), None)
class TestEventQueuePublisher(AsyncioTestCase):
    """Tests for ``EventQueuePublisher`` forwarding queued messages to an ``EventController``."""

    @unittest.skipIf("darwin" in sys.platform, "test is very unreliable on mac")
    async def test_event_buffering_avoids_overloading_asyncio(self):
        """Flood the queue from several threads; expect no asyncio slow-task warnings.

        ``threads`` producer threads each put ``generate_events`` messages on a
        shared ``mp.Queue`` while an ``EventQueuePublisher`` forwards them to
        ``controller``. The only record captured by ``assertLogs`` must be the
        placeholder logged by the test itself.
        """
        threads = 3
        generate_events = 3000
        # NOTE(review): this waits for one fewer event than are produced; the
        # reason for the -1 is not visible here -- confirm it is intentional.
        expected_event_count = (threads * generate_events) - 1
        queue = mp.Queue()
        executor = ThreadPoolExecutor(max_workers=threads)
        controller = EventController()
        events = []

        async def event_logger(e):
            await asyncio.sleep(0)
            events.append(e)

        controller.stream.listen(event_logger)
        until_all_consumed = controller.stream.where(
            lambda _: len(events) == expected_event_count
        )

        def event_producer(q, j):
            for i in range(generate_events):
                q.put(f'foo-{i}-{j}')

        try:
            with EventQueuePublisher(queue, controller), self.assertLogs() as logs:
                # assertLogs() requires that at least one message is logged
                # this is that one message:
                logging.getLogger().info("placeholder")
                # gather (unlike asyncio.wait) re-raises a producer's exception,
                # so a failing producer fails the test instead of hanging it
                # on until_all_consumed below.
                await asyncio.gather(*(
                    self.loop.run_in_executor(executor, event_producer, queue, j)
                    for j in range(threads)
                ))
                await until_all_consumed
        finally:
            # Release the producer worker threads; previously they leaked.
            executor.shutdown(wait=True)
        # assert that there were no WARNINGs from asyncio about slow tasks
        # (should have exactly 1 log which is the placeholder above)
        self.assertEqual(['INFO:root:placeholder'], logs.output)
class TaskGroupTestCase(AsyncioTestCase):
    """Tests for ``TaskGroup`` lifecycle state."""

    async def test_cancel_sets_it_done(self):
        """Cancelling a freshly created group leaves its ``done`` flag set."""
        task_group = TaskGroup()
        task_group.cancel()
        done_flag = task_group.done
        self.assertTrue(done_flag.is_set())