diff --git a/lbrynet/stream/reflector/server.py b/lbrynet/stream/reflector/server.py index fd43f0c3a..0a01f397e 100644 --- a/lbrynet/stream/reflector/server.py +++ b/lbrynet/stream/reflector/server.py @@ -15,7 +15,7 @@ log = logging.getLogger(__name__) class ReflectorServerProtocol(asyncio.Protocol): - def __init__(self, blob_manager: 'BlobManager'): + def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000): self.loop = asyncio.get_event_loop() self.blob_manager = blob_manager self.server_task: asyncio.Task = None @@ -28,6 +28,7 @@ class ReflectorServerProtocol(asyncio.Protocol): self.sd_blob: typing.Optional['BlobFile'] = None self.received = [] self.incoming = asyncio.Event(loop=self.loop) + self.chunk_size = response_chunk_size def connection_made(self, transport): self.transport = transport @@ -47,7 +48,15 @@ class ReflectorServerProtocol(asyncio.Protocol): self.loop.create_task(self.handle_request(request)) def send_response(self, response: typing.Dict): - self.transport.write(json.dumps(response).encode()) + def chunk_response(remaining: bytes): + f = self.loop.create_future() + f.add_done_callback(lambda _: self.transport.write(remaining[:self.chunk_size])) + if len(remaining) > self.chunk_size: + f.add_done_callback(lambda _: self.loop.call_soon(chunk_response, remaining[self.chunk_size:])) + self.loop.call_soon(f.set_result, None) + + response_bytes = json.dumps(response).encode() + chunk_response(response_bytes) async def handle_request(self, request: typing.Dict): if self.client_version is None: @@ -121,11 +130,12 @@ class ReflectorServerProtocol(asyncio.Protocol): class ReflectorServer: - def __init__(self, blob_manager: 'BlobManager'): + def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000): self.loop = asyncio.get_event_loop() self.blob_manager = blob_manager - self.server_task: asyncio.Task = None + self.server_task: typing.Optional[asyncio.Task] = None self.started_listening = asyncio.Event(loop=self.loop) + self.response_chunk_size = response_chunk_size def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'): if self.server_task is not None: @@ -133,7 +143,7 @@ class ReflectorServer: async def _start_server(): server = await self.loop.create_server( - lambda: ReflectorServerProtocol(self.blob_manager), + lambda: ReflectorServerProtocol(self.blob_manager, self.response_chunk_size), interface, port ) self.started_listening.set() diff --git a/tests/unit/stream/test_reflector.py b/tests/unit/stream/test_reflector.py index 2034b5af9..1d820f009 100644 --- a/tests/unit/stream/test_reflector.py +++ b/tests/unit/stream/test_reflector.py @@ -41,8 +41,8 @@ class TestStreamAssembler(AsyncioTestCase): self.stream = await self.stream_manager.create_stream(file_path) - async def test_reflect_stream(self): - reflector = ReflectorServer(self.server_blob_manager) + async def _test_reflect_stream(self, response_chunk_size): + reflector = ReflectorServer(self.server_blob_manager, response_chunk_size=response_chunk_size) reflector.start_server(5566, '127.0.0.1') await reflector.started_listening.wait() self.addCleanup(reflector.stop_server) @@ -63,6 +63,12 @@ class TestStreamAssembler(AsyncioTestCase): sent = await self.stream.upload_to_reflector('127.0.0.1', 5566) self.assertListEqual(sent, []) + async def test_reflect_stream(self): + return await asyncio.wait_for(self._test_reflect_stream(response_chunk_size=50), 3, loop=self.loop) + + async def test_reflect_stream_small_response_chunks(self): + return await asyncio.wait_for(self._test_reflect_stream(response_chunk_size=30), 3, loop=self.loop) + async def test_announces(self): to_announce = await self.storage.get_blobs_to_announce() self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce")