EventQueuePublisher uses a buffer to reduce number of tasks created
This commit is contained in:
parent
2af29b892b
commit
596ed08395
2 changed files with 81 additions and 9 deletions
|
@ -1,7 +1,9 @@
|
|||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
import multiprocessing
|
||||
import logging
|
||||
from queue import Empty
|
||||
from multiprocessing import Queue
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -86,6 +88,10 @@ class EventController:
|
|||
for subscription in self._iterate_subscriptions:
|
||||
await self._notify(subscription._add, event)
|
||||
|
||||
async def add_all(self, events):
|
||||
for event in events:
|
||||
await self.add(event)
|
||||
|
||||
async def add_error(self, exception):
|
||||
for subscription in self._iterate_subscriptions:
|
||||
await self._notify(subscription._add_error, exception)
|
||||
|
@ -183,7 +189,7 @@ class EventQueuePublisher(threading.Thread):
|
|||
|
||||
STOP = 'STOP'
|
||||
|
||||
def __init__(self, queue: multiprocessing.Queue, event_controller: EventController):
|
||||
def __init__(self, queue: Queue, event_controller: EventController):
|
||||
super().__init__()
|
||||
self.queue = queue
|
||||
self.event_controller = event_controller
|
||||
|
@ -197,13 +203,37 @@ class EventQueuePublisher(threading.Thread):
|
|||
super().start()
|
||||
|
||||
def run(self):
|
||||
queue_get_timeout = 0.2
|
||||
buffer_drain_size = 100
|
||||
buffer_drain_timeout = 0.1
|
||||
|
||||
buffer = []
|
||||
last_drained_ms_ago = time.perf_counter()
|
||||
while True:
|
||||
msg = self.queue.get()
|
||||
|
||||
try:
|
||||
msg = self.queue.get(timeout=queue_get_timeout)
|
||||
if msg != self.STOP:
|
||||
buffer.append(msg)
|
||||
except Empty:
|
||||
msg = None
|
||||
|
||||
drain = any((
|
||||
len(buffer) >= buffer_drain_size,
|
||||
(time.perf_counter() - last_drained_ms_ago) >= buffer_drain_timeout,
|
||||
msg == self.STOP
|
||||
))
|
||||
if drain and buffer:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.event_controller.add_all([
|
||||
self.message_to_event(msg) for msg in buffer
|
||||
]), self.loop
|
||||
)
|
||||
buffer.clear()
|
||||
last_drained_ms_ago = time.perf_counter()
|
||||
|
||||
if msg == self.STOP:
|
||||
return
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.event_controller.add(self.message_to_event(msg)), self.loop
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
self.queue.put(self.STOP)
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
from lbry.event import EventController
|
||||
from lbry.event import EventController, EventQueuePublisher
|
||||
from lbry.tasks import TaskGroup
|
||||
|
||||
|
||||
|
@ -22,7 +27,7 @@ class StreamControllerTestCase(AsyncioTestCase):
|
|||
self.assertListEqual(events, ["yo"])
|
||||
|
||||
async def test_sync_listener_errors(self):
|
||||
def bad_listener(e):
|
||||
def bad_listener(_):
|
||||
raise ValueError('bad')
|
||||
controller = EventController()
|
||||
controller.stream.listen(bad_listener)
|
||||
|
@ -30,7 +35,7 @@ class StreamControllerTestCase(AsyncioTestCase):
|
|||
await controller.add("yo")
|
||||
|
||||
async def test_async_listener_errors(self):
|
||||
async def bad_listener(e):
|
||||
async def bad_listener(_):
|
||||
raise ValueError('bad')
|
||||
controller = EventController()
|
||||
controller.stream.listen(bad_listener)
|
||||
|
@ -55,6 +60,43 @@ class StreamControllerTestCase(AsyncioTestCase):
|
|||
self.assertEqual("two", await last)
|
||||
|
||||
|
||||
class TestEventQueuePublisher(AsyncioTestCase):
|
||||
|
||||
async def test_event_buffering_avoids_overloading_asyncio(self):
|
||||
threads = 3
|
||||
generate_events = 3000
|
||||
expected_event_count = (threads * generate_events)-1
|
||||
|
||||
queue = mp.Queue()
|
||||
executor = ThreadPoolExecutor(max_workers=threads)
|
||||
controller = EventController()
|
||||
events = []
|
||||
|
||||
async def event_logger(e):
|
||||
await asyncio.sleep(0)
|
||||
events.append(e)
|
||||
|
||||
controller.stream.listen(event_logger)
|
||||
until_all_consumed = controller.stream.where(lambda _: len(events) == expected_event_count)
|
||||
|
||||
def event_producer(q, j):
|
||||
for i in range(generate_events):
|
||||
q.put(f'foo-{i}-{j}')
|
||||
|
||||
with EventQueuePublisher(queue, controller), self.assertLogs() as logs:
|
||||
# assertLogs() requires that at least one message is logged
|
||||
# this is that one message:
|
||||
logging.getLogger().info("placeholder")
|
||||
await asyncio.wait([
|
||||
self.loop.run_in_executor(executor, event_producer, queue, j)
|
||||
for j in range(threads)
|
||||
])
|
||||
await until_all_consumed
|
||||
# assert that there were no WARNINGs from asyncio about slow tasks
|
||||
# (should have exactly 1 log which is the placeholder above)
|
||||
self.assertEqual(['INFO:root:placeholder'], logs.output)
|
||||
|
||||
|
||||
class TaskGroupTestCase(AsyncioTestCase):
|
||||
|
||||
async def test_cancel_sets_it_done(self):
|
||||
|
|
Loading…
Add table
Reference in a new issue