forked from LBRYCommunity/lbry-sdk
EventQueuePublisher uses a buffer to reduce number of tasks created
This commit is contained in:
parent
2af29b892b
commit
596ed08395
2 changed files with 81 additions and 9 deletions
|
@ -1,7 +1,9 @@
|
||||||
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
import multiprocessing
|
|
||||||
import logging
|
import logging
|
||||||
|
from queue import Empty
|
||||||
|
from multiprocessing import Queue
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -86,6 +88,10 @@ class EventController:
|
||||||
for subscription in self._iterate_subscriptions:
|
for subscription in self._iterate_subscriptions:
|
||||||
await self._notify(subscription._add, event)
|
await self._notify(subscription._add, event)
|
||||||
|
|
||||||
|
async def add_all(self, events):
|
||||||
|
for event in events:
|
||||||
|
await self.add(event)
|
||||||
|
|
||||||
async def add_error(self, exception):
|
async def add_error(self, exception):
|
||||||
for subscription in self._iterate_subscriptions:
|
for subscription in self._iterate_subscriptions:
|
||||||
await self._notify(subscription._add_error, exception)
|
await self._notify(subscription._add_error, exception)
|
||||||
|
@ -183,7 +189,7 @@ class EventQueuePublisher(threading.Thread):
|
||||||
|
|
||||||
STOP = 'STOP'
|
STOP = 'STOP'
|
||||||
|
|
||||||
def __init__(self, queue: multiprocessing.Queue, event_controller: EventController):
|
def __init__(self, queue: Queue, event_controller: EventController):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.event_controller = event_controller
|
self.event_controller = event_controller
|
||||||
|
@ -197,13 +203,37 @@ class EventQueuePublisher(threading.Thread):
|
||||||
super().start()
|
super().start()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
queue_get_timeout = 0.2
|
||||||
|
buffer_drain_size = 100
|
||||||
|
buffer_drain_timeout = 0.1
|
||||||
|
|
||||||
|
buffer = []
|
||||||
|
last_drained_ms_ago = time.perf_counter()
|
||||||
while True:
|
while True:
|
||||||
msg = self.queue.get()
|
|
||||||
|
try:
|
||||||
|
msg = self.queue.get(timeout=queue_get_timeout)
|
||||||
|
if msg != self.STOP:
|
||||||
|
buffer.append(msg)
|
||||||
|
except Empty:
|
||||||
|
msg = None
|
||||||
|
|
||||||
|
drain = any((
|
||||||
|
len(buffer) >= buffer_drain_size,
|
||||||
|
(time.perf_counter() - last_drained_ms_ago) >= buffer_drain_timeout,
|
||||||
|
msg == self.STOP
|
||||||
|
))
|
||||||
|
if drain and buffer:
|
||||||
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
self.event_controller.add_all([
|
||||||
|
self.message_to_event(msg) for msg in buffer
|
||||||
|
]), self.loop
|
||||||
|
)
|
||||||
|
buffer.clear()
|
||||||
|
last_drained_ms_ago = time.perf_counter()
|
||||||
|
|
||||||
if msg == self.STOP:
|
if msg == self.STOP:
|
||||||
return
|
return
|
||||||
asyncio.run_coroutine_threadsafe(
|
|
||||||
self.event_controller.add(self.message_to_event(msg)), self.loop
|
|
||||||
)
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.queue.put(self.STOP)
|
self.queue.put(self.STOP)
|
||||||
|
|
|
@ -1,5 +1,10 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
from lbry.testcase import AsyncioTestCase
|
from lbry.testcase import AsyncioTestCase
|
||||||
from lbry.event import EventController
|
from lbry.event import EventController, EventQueuePublisher
|
||||||
from lbry.tasks import TaskGroup
|
from lbry.tasks import TaskGroup
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,7 +27,7 @@ class StreamControllerTestCase(AsyncioTestCase):
|
||||||
self.assertListEqual(events, ["yo"])
|
self.assertListEqual(events, ["yo"])
|
||||||
|
|
||||||
async def test_sync_listener_errors(self):
|
async def test_sync_listener_errors(self):
|
||||||
def bad_listener(e):
|
def bad_listener(_):
|
||||||
raise ValueError('bad')
|
raise ValueError('bad')
|
||||||
controller = EventController()
|
controller = EventController()
|
||||||
controller.stream.listen(bad_listener)
|
controller.stream.listen(bad_listener)
|
||||||
|
@ -30,7 +35,7 @@ class StreamControllerTestCase(AsyncioTestCase):
|
||||||
await controller.add("yo")
|
await controller.add("yo")
|
||||||
|
|
||||||
async def test_async_listener_errors(self):
|
async def test_async_listener_errors(self):
|
||||||
async def bad_listener(e):
|
async def bad_listener(_):
|
||||||
raise ValueError('bad')
|
raise ValueError('bad')
|
||||||
controller = EventController()
|
controller = EventController()
|
||||||
controller.stream.listen(bad_listener)
|
controller.stream.listen(bad_listener)
|
||||||
|
@ -55,6 +60,43 @@ class StreamControllerTestCase(AsyncioTestCase):
|
||||||
self.assertEqual("two", await last)
|
self.assertEqual("two", await last)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventQueuePublisher(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def test_event_buffering_avoids_overloading_asyncio(self):
|
||||||
|
threads = 3
|
||||||
|
generate_events = 3000
|
||||||
|
expected_event_count = (threads * generate_events)-1
|
||||||
|
|
||||||
|
queue = mp.Queue()
|
||||||
|
executor = ThreadPoolExecutor(max_workers=threads)
|
||||||
|
controller = EventController()
|
||||||
|
events = []
|
||||||
|
|
||||||
|
async def event_logger(e):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
events.append(e)
|
||||||
|
|
||||||
|
controller.stream.listen(event_logger)
|
||||||
|
until_all_consumed = controller.stream.where(lambda _: len(events) == expected_event_count)
|
||||||
|
|
||||||
|
def event_producer(q, j):
|
||||||
|
for i in range(generate_events):
|
||||||
|
q.put(f'foo-{i}-{j}')
|
||||||
|
|
||||||
|
with EventQueuePublisher(queue, controller), self.assertLogs() as logs:
|
||||||
|
# assertLogs() requires that at least one message is logged
|
||||||
|
# this is that one message:
|
||||||
|
logging.getLogger().info("placeholder")
|
||||||
|
await asyncio.wait([
|
||||||
|
self.loop.run_in_executor(executor, event_producer, queue, j)
|
||||||
|
for j in range(threads)
|
||||||
|
])
|
||||||
|
await until_all_consumed
|
||||||
|
# assert that there were no WARNINGs from asyncio about slow tasks
|
||||||
|
# (should have exactly 1 log which is the placeholder above)
|
||||||
|
self.assertEqual(['INFO:root:placeholder'], logs.output)
|
||||||
|
|
||||||
|
|
||||||
class TaskGroupTestCase(AsyncioTestCase):
|
class TaskGroupTestCase(AsyncioTestCase):
|
||||||
|
|
||||||
async def test_cancel_sets_it_done(self):
|
async def test_cancel_sets_it_done(self):
|
||||||
|
|
Loading…
Reference in a new issue