lbry-sdk/tests/unit/test_event_controller.py

135 lines
4.7 KiB
Python
Raw Normal View History

import asyncio
import logging
import multiprocessing as mp
from unittest import skip
from concurrent.futures import ThreadPoolExecutor
2019-12-31 21:30:13 +01:00
from lbry.testcase import AsyncioTestCase
from lbry.event import EventController, EventQueuePublisher
2020-05-01 15:34:34 +02:00
from lbry.tasks import TaskGroup
2019-08-07 07:48:40 +02:00
2019-08-07 16:27:25 +02:00
class StreamControllerTestCase(AsyncioTestCase):
2020-05-01 15:34:34 +02:00
async def test_non_unique_events(self):
2019-08-07 07:48:40 +02:00
events = []
2020-05-01 15:34:34 +02:00
controller = EventController()
controller.stream.listen(events.append)
await controller.add("yo")
await controller.add("yo")
self.assertListEqual(events, ["yo", "yo"])
2019-08-07 07:48:40 +02:00
2020-05-01 15:34:34 +02:00
async def test_unique_events(self):
2019-08-07 07:48:40 +02:00
events = []
2020-05-01 15:34:34 +02:00
controller = EventController(merge_repeated_events=True)
controller.stream.listen(events.append)
await controller.add("yo")
await controller.add("yo")
self.assertListEqual(events, ["yo"])
2019-12-11 23:31:45 +01:00
2020-05-01 15:34:34 +02:00
async def test_sync_listener_errors(self):
def bad_listener(_):
2020-05-01 15:34:34 +02:00
raise ValueError('bad')
controller = EventController()
controller.stream.listen(bad_listener)
2020-06-10 05:27:45 +02:00
with self.assertRaises(ValueError), self.assertLogs():
2020-05-01 15:34:34 +02:00
await controller.add("yo")
async def test_async_listener_errors(self):
async def bad_listener(_):
2020-05-01 15:34:34 +02:00
raise ValueError('bad')
controller = EventController()
controller.stream.listen(bad_listener)
2020-06-10 05:27:45 +02:00
with self.assertRaises(ValueError), self.assertLogs():
2020-05-01 15:34:34 +02:00
await controller.add("yo")
async def test_first_event(self):
controller = EventController()
first = controller.stream.first
await controller.add("one")
second = controller.stream.first
await controller.add("two")
self.assertEqual("one", await first)
self.assertEqual("two", await second)
async def test_last_event(self):
controller = EventController()
last = controller.stream.last
await controller.add("one")
await controller.add("two")
await controller.close()
self.assertEqual("two", await last)
async def test_race_condition_during_subscription_iteration(self):
controller = EventController()
sub1 = controller.stream.listen(print)
sub2 = controller.stream.listen(print)
sub3 = controller.stream.listen(print)
# normal iteration
i = iter(controller._iterate_subscriptions)
self.assertEqual(next(i, None), sub1)
self.assertEqual(next(i, None), sub2)
self.assertEqual(next(i, None), sub3)
self.assertEqual(next(i, None), None)
# subscription canceled immediately after it's iterated over
i = iter(controller._iterate_subscriptions)
self.assertEqual(next(i, None), sub1)
self.assertEqual(next(i, None), sub2)
sub2.cancel()
self.assertEqual(next(i, None), sub3)
self.assertEqual(next(i, None), None)
# subscription canceled immediately before it's iterated over
self.assertEqual(list(controller._iterate_subscriptions), [sub1, sub3]) # precondition
i = iter(controller._iterate_subscriptions)
self.assertEqual(next(i, None), sub1)
sub3.cancel()
self.assertEqual(next(i, None), None)
2019-12-11 23:31:45 +01:00
class TestEventQueuePublisher(AsyncioTestCase):
async def test_event_buffering_avoids_overloading_asyncio(self):
threads = 3
generate_events = 2000
expected_event_count = (threads * generate_events)-1
queue = mp.Queue()
executor = ThreadPoolExecutor(max_workers=threads)
controller = EventController()
events = []
async def event_logger(e):
await asyncio.sleep(0)
events.append(e)
controller.stream.listen(event_logger)
until_all_consumed = controller.stream.where(lambda _: len(events) == expected_event_count)
def event_producer(q, j):
for i in range(generate_events):
q.put(f'foo-{i}-{j}')
with EventQueuePublisher(queue, controller), self.assertLogs() as logs:
# assertLogs() requires that at least one message is logged
# this is that one message:
logging.getLogger().info("placeholder")
await asyncio.wait([
self.loop.run_in_executor(executor, event_producer, queue, j)
for j in range(threads)
])
await until_all_consumed
# assert that there were no WARNINGs from asyncio about slow tasks
# (should have exactly 1 log which is the placeholder above)
self.assertEqual(['INFO:root:placeholder'], logs.output)
class TaskGroupTestCase(AsyncioTestCase):
async def test_cancel_sets_it_done(self):
2019-12-11 23:31:45 +01:00
group = TaskGroup()
group.cancel()
2019-12-12 03:33:34 +01:00
self.assertTrue(group.done.is_set())