EventQueuePublisher uses a buffer to reduce number of tasks created

This commit is contained in:
Lex Berezhny 2020-05-22 18:40:21 -04:00
parent 2af29b892b
commit 596ed08395
2 changed files with 81 additions and 9 deletions

View file

@ -1,7 +1,9 @@
import time
import asyncio import asyncio
import threading import threading
import multiprocessing
import logging import logging
from queue import Empty
from multiprocessing import Queue
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -86,6 +88,10 @@ class EventController:
for subscription in self._iterate_subscriptions: for subscription in self._iterate_subscriptions:
await self._notify(subscription._add, event) await self._notify(subscription._add, event)
async def add_all(self, events):
    """Publish every event in ``events``, one after another.

    Each event is passed to ``self.add`` and awaited before the next one
    is started, so events are delivered in iteration order.

    Args:
        events: iterable of events to publish.
    """
    for event in events:
        await self.add(event)
async def add_error(self, exception): async def add_error(self, exception):
for subscription in self._iterate_subscriptions: for subscription in self._iterate_subscriptions:
await self._notify(subscription._add_error, exception) await self._notify(subscription._add_error, exception)
@ -183,7 +189,7 @@ class EventQueuePublisher(threading.Thread):
STOP = 'STOP' STOP = 'STOP'
def __init__(self, queue: multiprocessing.Queue, event_controller: EventController): def __init__(self, queue: Queue, event_controller: EventController):
super().__init__() super().__init__()
self.queue = queue self.queue = queue
self.event_controller = event_controller self.event_controller = event_controller
@ -197,13 +203,37 @@ class EventQueuePublisher(threading.Thread):
super().start() super().start()
def run(self):
    """Forward messages from ``self.queue`` to the event controller in batches.

    Messages are collected into a local buffer instead of being scheduled
    one coroutine per message. The buffer is handed to
    ``self.event_controller.add_all`` (scheduled on ``self.loop`` via
    ``asyncio.run_coroutine_threadsafe``) when any of these holds:

    * the buffer holds at least ``buffer_drain_size`` messages,
    * at least ``buffer_drain_timeout`` seconds passed since the last drain,
    * the ``STOP`` sentinel was received.

    The method returns once ``STOP`` has been read; anything still buffered
    at that point is scheduled first. The scheduled coroutine's result is
    not awaited here.
    """
    queue_get_timeout = 0.2      # max seconds to block on an idle queue
    buffer_drain_size = 100      # drain once this many messages are buffered
    buffer_drain_timeout = 0.1   # drain at least this often (seconds)
    buffer = []
    # perf_counter() timestamp (seconds) of the most recent drain.
    last_drained_at = time.perf_counter()
    while True:
        if buffer:
            # Never block past the drain deadline while messages are
            # waiting, otherwise a quiet queue would hold them for the
            # full queue_get_timeout instead of buffer_drain_timeout.
            until_deadline = buffer_drain_timeout - (time.perf_counter() - last_drained_at)
            get_timeout = min(queue_get_timeout, max(0.0, until_deadline))
        else:
            get_timeout = queue_get_timeout
        try:
            msg = self.queue.get(timeout=get_timeout)
        except Empty:
            msg = None
        else:
            if msg != self.STOP:
                buffer.append(msg)
        stopping = msg == self.STOP
        drain = (
            stopping
            or len(buffer) >= buffer_drain_size
            or (time.perf_counter() - last_drained_at) >= buffer_drain_timeout
        )
        if drain and buffer:
            # Build the event list now: the buffer is cleared right after
            # scheduling, before the coroutine gets a chance to run.
            asyncio.run_coroutine_threadsafe(
                self.event_controller.add_all(
                    [self.message_to_event(item) for item in buffer]
                ),
                self.loop,
            )
            buffer.clear()
            last_drained_at = time.perf_counter()
        if stopping:
            return
def stop(self): def stop(self):
self.queue.put(self.STOP) self.queue.put(self.STOP)

View file

@ -1,5 +1,10 @@
import asyncio
import logging
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.event import EventController from lbry.event import EventController, EventQueuePublisher
from lbry.tasks import TaskGroup from lbry.tasks import TaskGroup
@ -22,7 +27,7 @@ class StreamControllerTestCase(AsyncioTestCase):
self.assertListEqual(events, ["yo"]) self.assertListEqual(events, ["yo"])
async def test_sync_listener_errors(self): async def test_sync_listener_errors(self):
def bad_listener(e): def bad_listener(_):
raise ValueError('bad') raise ValueError('bad')
controller = EventController() controller = EventController()
controller.stream.listen(bad_listener) controller.stream.listen(bad_listener)
@ -30,7 +35,7 @@ class StreamControllerTestCase(AsyncioTestCase):
await controller.add("yo") await controller.add("yo")
async def test_async_listener_errors(self): async def test_async_listener_errors(self):
async def bad_listener(e): async def bad_listener(_):
raise ValueError('bad') raise ValueError('bad')
controller = EventController() controller = EventController()
controller.stream.listen(bad_listener) controller.stream.listen(bad_listener)
@ -55,6 +60,43 @@ class StreamControllerTestCase(AsyncioTestCase):
self.assertEqual("two", await last) self.assertEqual("two", await last)
class TestEventQueuePublisher(AsyncioTestCase):
    """Exercises EventQueuePublisher with a burst of events from several threads."""

    async def test_event_buffering_avoids_overloading_asyncio(self):
        """Flood the queue from multiple producer threads and check the logs.

        The test passes when the only record captured by ``assertLogs`` is
        the placeholder it emits itself, i.e. nothing else was logged while
        the events were being published.
        """
        threads = 3
        generate_events = 3000
        # NOTE(review): one less than the number of events produced;
        # presumably the `where` predicate is evaluated before the final
        # event has been appended by event_logger -- confirm.
        expected_event_count = (threads * generate_events)-1
        queue = mp.Queue()
        controller = EventController()
        events = []

        async def event_logger(e):
            # Yield to the loop once per event before recording it.
            await asyncio.sleep(0)
            events.append(e)

        controller.stream.listen(event_logger)
        until_all_consumed = controller.stream.where(lambda _: len(events) == expected_event_count)

        def event_producer(q, j):
            for i in range(generate_events):
                q.put(f'foo-{i}-{j}')

        # The executor is managed by the `with` statement so its worker
        # threads are shut down when the test finishes (or fails).
        with ThreadPoolExecutor(max_workers=threads) as executor, \
                EventQueuePublisher(queue, controller), self.assertLogs() as logs:
            # assertLogs() requires that at least one message is logged
            # this is that one message:
            logging.getLogger().info("placeholder")
            await asyncio.wait([
                self.loop.run_in_executor(executor, event_producer, queue, j)
                for j in range(threads)
            ])
            await until_all_consumed
            # assert that there were no WARNINGs from asyncio about slow tasks
            # (should have exactly 1 log which is the placeholder above)
            self.assertEqual(['INFO:root:placeholder'], logs.output)
class TaskGroupTestCase(AsyncioTestCase): class TaskGroupTestCase(AsyncioTestCase):
async def test_cancel_sets_it_done(self): async def test_cancel_sets_it_done(self):