"""Tests for lbry.event (EventController, EventQueuePublisher) and lbry.tasks (TaskGroup)."""
import sys
import asyncio
import logging
import unittest
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor

from lbry.testcase import AsyncioTestCase
from lbry.event import EventController, EventQueuePublisher
from lbry.tasks import TaskGroup


class StreamControllerTestCase(AsyncioTestCase):
    """Behavioural tests for ``EventController`` and its ``stream``."""

    async def test_non_unique_events(self):
        """By default, two identical consecutive events are both delivered."""
        events = []
        controller = EventController()
        controller.stream.listen(events.append)
        await controller.add("yo")
        await controller.add("yo")
        self.assertListEqual(events, ["yo", "yo"])

    async def test_unique_events(self):
        """With ``merge_repeated_events=True`` a repeated identical event is delivered once."""
        events = []
        controller = EventController(merge_repeated_events=True)
        controller.stream.listen(events.append)
        await controller.add("yo")
        await controller.add("yo")
        self.assertListEqual(events, ["yo"])

    async def test_sync_listener_errors(self):
        """A plain (sync) listener's exception propagates out of ``add`` and something is logged."""
        def bad_listener(_):
            raise ValueError('bad')
        controller = EventController()
        controller.stream.listen(bad_listener)
        with self.assertRaises(ValueError), self.assertLogs():
            await controller.add("yo")

    async def test_async_listener_errors(self):
        """A coroutine listener's exception propagates out of ``add`` and something is logged."""
        async def bad_listener(_):
            raise ValueError('bad')
        controller = EventController()
        controller.stream.listen(bad_listener)
        with self.assertRaises(ValueError), self.assertLogs():
            await controller.add("yo")

    async def test_first_event(self):
        """``stream.first`` resolves to the first event added after it was requested."""
        controller = EventController()
        first = controller.stream.first
        await controller.add("one")
        second = controller.stream.first
        await controller.add("two")
        self.assertEqual("one", await first)
        self.assertEqual("two", await second)

    async def test_last_event(self):
        """``stream.last`` resolves to the final event once the controller is closed."""
        controller = EventController()
        last = controller.stream.last
        await controller.add("one")
        await controller.add("two")
        await controller.close()
        self.assertEqual("two", await last)

    async def test_race_condition_during_subscription_iteration(self):
        """Cancelling subscriptions mid-iteration must not break ``_iterate_subscriptions``."""
        controller = EventController()
        sub1 = controller.stream.listen(print)
        sub2 = controller.stream.listen(print)
        sub3 = controller.stream.listen(print)

        # normal iteration
        i = iter(controller._iterate_subscriptions)
        self.assertEqual(next(i, None), sub1)
        self.assertEqual(next(i, None), sub2)
        self.assertEqual(next(i, None), sub3)
        self.assertEqual(next(i, None), None)

        # subscription canceled immediately after it's iterated over
        i = iter(controller._iterate_subscriptions)
        self.assertEqual(next(i, None), sub1)
        self.assertEqual(next(i, None), sub2)
        sub2.cancel()
        self.assertEqual(next(i, None), sub3)
        self.assertEqual(next(i, None), None)

        # subscription canceled immediately before it's iterated over
        self.assertEqual(list(controller._iterate_subscriptions), [sub1, sub3])  # precondition
        i = iter(controller._iterate_subscriptions)
        self.assertEqual(next(i, None), sub1)
        sub3.cancel()
        self.assertEqual(next(i, None), None)


class TestEventQueuePublisher(AsyncioTestCase):
    """Tests for ``EventQueuePublisher``."""

    @unittest.skipIf("linux" not in sys.platform, "unreliable everywhere except linux")
    async def test_event_buffering_avoids_overloading_asyncio(self):
        """Flood the publisher from several threads and check asyncio logs no warnings.

        Producer threads push events into a multiprocessing queue, which
        ``EventQueuePublisher(queue, controller)`` is expected to relay to
        ``controller``. The test passes only if the sole log record captured
        is the placeholder emitted below.
        """
        threads = 3
        generate_events = 3000
        # NOTE(review): this waits for one event fewer than the producers emit
        # in total; confirm whether the -1 is intentional.
        expected_event_count = (threads * generate_events) - 1
        queue = mp.Queue()
        controller = EventController()
        events = []

        async def event_logger(e):
            # yields to the event loop before recording each event
            await asyncio.sleep(0)
            events.append(e)

        controller.stream.listen(event_logger)
        until_all_consumed = controller.stream.where(
            lambda _: len(events) == expected_event_count
        )

        def event_producer(q, j):
            for i in range(generate_events):
                q.put(f'foo-{i}-{j}')

        # Using the executor as a context manager shuts its worker threads
        # down when the test ends instead of leaking them.
        with ThreadPoolExecutor(max_workers=threads) as executor:
            with EventQueuePublisher(queue, controller), self.assertLogs() as logs:
                # assertLogs() requires that at least one message is logged
                # this is that one message:
                logging.getLogger().info("placeholder")
                # gather() (unlike asyncio.wait()) re-raises an exception from
                # any producer here instead of silently discarding it.
                await asyncio.gather(*(
                    self.loop.run_in_executor(executor, event_producer, queue, j)
                    for j in range(threads)
                ))
                await until_all_consumed

        # assert that there were no WARNINGs from asyncio about slow tasks
        # (should have exactly 1 log which is the placeholder above)
        self.assertEqual(['INFO:root:placeholder'], logs.output)


class TaskGroupTestCase(AsyncioTestCase):
    """Tests for ``TaskGroup``."""

    async def test_cancel_sets_it_done(self):
        """Cancelling a group with no tasks sets its ``done`` flag."""
        group = TaskGroup()
        group.cancel()
        self.assertTrue(group.done.is_set())