diff --git a/lbry/event.py b/lbry/event.py index a3adc3789..d584eed5e 100644 --- a/lbry/event.py +++ b/lbry/event.py @@ -1,7 +1,9 @@ +import time import asyncio import threading -import multiprocessing import logging +from queue import Empty +from multiprocessing import Queue log = logging.getLogger(__name__) @@ -86,6 +88,10 @@ class EventController: for subscription in self._iterate_subscriptions: await self._notify(subscription._add, event) + async def add_all(self, events): + for event in events: + await self.add(event) + async def add_error(self, exception): for subscription in self._iterate_subscriptions: await self._notify(subscription._add_error, exception) @@ -183,7 +189,7 @@ class EventQueuePublisher(threading.Thread): STOP = 'STOP' - def __init__(self, queue: multiprocessing.Queue, event_controller: EventController): + def __init__(self, queue: Queue, event_controller: EventController): super().__init__() self.queue = queue self.event_controller = event_controller @@ -197,13 +203,37 @@ class EventQueuePublisher(threading.Thread): super().start() def run(self): + queue_get_timeout = 0.2 + buffer_drain_size = 100 + buffer_drain_timeout = 0.1 + + buffer = [] + last_drained_ms_ago = time.perf_counter() while True: - msg = self.queue.get() + + try: + msg = self.queue.get(timeout=queue_get_timeout) + if msg != self.STOP: + buffer.append(msg) + except Empty: + msg = None + + drain = any(( + len(buffer) >= buffer_drain_size, + (time.perf_counter() - last_drained_ms_ago) >= buffer_drain_timeout, + msg == self.STOP + )) + if drain and buffer: + asyncio.run_coroutine_threadsafe( + self.event_controller.add_all([ + self.message_to_event(msg) for msg in buffer + ]), self.loop + ) + buffer.clear() + last_drained_ms_ago = time.perf_counter() + if msg == self.STOP: return - asyncio.run_coroutine_threadsafe( - self.event_controller.add(self.message_to_event(msg)), self.loop - ) def stop(self): self.queue.put(self.STOP) diff --git a/tests/unit/test_event_controller.py b/tests/unit/test_event_controller.py index eb23d5a31..160ad2ee8 100644 --- a/tests/unit/test_event_controller.py +++ b/tests/unit/test_event_controller.py @@ -1,5 +1,10 @@ +import asyncio +import logging +import multiprocessing as mp +from concurrent.futures import ThreadPoolExecutor + from lbry.testcase import AsyncioTestCase -from lbry.event import EventController +from lbry.event import EventController, EventQueuePublisher from lbry.tasks import TaskGroup @@ -22,7 +27,7 @@ class StreamControllerTestCase(AsyncioTestCase): self.assertListEqual(events, ["yo"]) async def test_sync_listener_errors(self): - def bad_listener(e): + def bad_listener(_): raise ValueError('bad') controller = EventController() controller.stream.listen(bad_listener) @@ -30,7 +35,7 @@ class StreamControllerTestCase(AsyncioTestCase): await controller.add("yo") async def test_async_listener_errors(self): - async def bad_listener(e): + async def bad_listener(_): raise ValueError('bad') controller = EventController() controller.stream.listen(bad_listener) @@ -55,6 +60,43 @@ class StreamControllerTestCase(AsyncioTestCase): self.assertEqual("two", await last) +class TestEventQueuePublisher(AsyncioTestCase): + + async def test_event_buffering_avoids_overloading_asyncio(self): + threads = 3 + generate_events = 3000 + expected_event_count = (threads * generate_events)-1 + + queue = mp.Queue() + executor = ThreadPoolExecutor(max_workers=threads) + controller = EventController() + events = [] + + async def event_logger(e): + await asyncio.sleep(0) + events.append(e) + + controller.stream.listen(event_logger) + until_all_consumed = controller.stream.where(lambda _: len(events) == expected_event_count) + + def event_producer(q, j): + for i in range(generate_events): + q.put(f'foo-{i}-{j}') + + with EventQueuePublisher(queue, controller), self.assertLogs() as logs: + # assertLogs() requires that at least one message is logged + # this is that one message: + logging.getLogger().info("placeholder") + await asyncio.wait([ + self.loop.run_in_executor(executor, event_producer, queue, j) + for j in range(threads) + ]) + await until_all_consumed + # assert that there were no WARNINGs from asyncio about slow tasks + # (should have exactly 1 log which is the placeholder above) + self.assertEqual(['INFO:root:placeholder'], logs.output) + + class TaskGroupTestCase(AsyncioTestCase): async def test_cancel_sets_it_done(self):