test reflector connection breaking mid-transfer
This commit is contained in:
parent
506d3f3cd9
commit
b3b5e3d8f0
2 changed files with 129 additions and 7 deletions
|
@ -15,7 +15,10 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ReflectorServerProtocol(asyncio.Protocol):
|
||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000):
|
||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
|
||||
stop_event: typing.Optional[asyncio.Event] = None,
|
||||
incoming_event: typing.Optional[asyncio.Event] = None,
|
||||
not_incoming_event: typing.Optional[asyncio.Event] = None):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.blob_manager = blob_manager
|
||||
self.server_task: asyncio.Task = None
|
||||
|
@ -27,11 +30,24 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
self.descriptor: typing.Optional['StreamDescriptor'] = None
|
||||
self.sd_blob: typing.Optional['BlobFile'] = None
|
||||
self.received = []
|
||||
self.incoming = asyncio.Event(loop=self.loop)
|
||||
self.incoming = incoming_event or asyncio.Event(loop=self.loop)
|
||||
self.not_incoming = not_incoming_event or asyncio.Event(loop=self.loop)
|
||||
self.stop_event = stop_event or asyncio.Event(loop=self.loop)
|
||||
self.chunk_size = response_chunk_size
|
||||
|
||||
async def wait_for_stop(self):
|
||||
await self.stop_event.wait()
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
self.wait_for_stop_task = self.loop.create_task(self.wait_for_stop())
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
if self.wait_for_stop_task:
|
||||
self.wait_for_stop_task.cancel()
|
||||
self.wait_for_stop_task = None
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
if self.incoming.is_set():
|
||||
|
@ -73,6 +89,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
||||
if not self.sd_blob.get_is_verified():
|
||||
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||
self.not_incoming.clear()
|
||||
self.incoming.set()
|
||||
self.send_response({"send_sd_blob": True})
|
||||
try:
|
||||
|
@ -86,6 +103,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
self.transport.close()
|
||||
finally:
|
||||
self.incoming.clear()
|
||||
self.not_incoming.set()
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
else:
|
||||
|
@ -93,6 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
||||
)
|
||||
self.incoming.clear()
|
||||
self.not_incoming.set()
|
||||
if self.writer:
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
|
@ -112,6 +131,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
||||
if not blob.get_is_verified():
|
||||
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||
self.not_incoming.clear()
|
||||
self.incoming.set()
|
||||
self.send_response({"send_blob": True})
|
||||
try:
|
||||
|
@ -120,6 +140,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
except asyncio.TimeoutError:
|
||||
self.send_response({"received_blob": False})
|
||||
self.incoming.clear()
|
||||
self.not_incoming.set()
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
else:
|
||||
|
@ -130,12 +151,19 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
|
||||
|
||||
class ReflectorServer:
|
||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000):
|
||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
|
||||
stop_event: typing.Optional[asyncio.Event] = None,
|
||||
incoming_event: typing.Optional[asyncio.Event] = None,
|
||||
not_incoming_event: typing.Optional[asyncio.Event] = None):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.blob_manager = blob_manager
|
||||
self.server_task: typing.Optional[asyncio.Task] = None
|
||||
self.started_listening = asyncio.Event(loop=self.loop)
|
||||
self.stopped_listening = asyncio.Event(loop=self.loop)
|
||||
self.incoming_event = incoming_event or asyncio.Event(loop=self.loop)
|
||||
self.not_incoming_event = not_incoming_event or asyncio.Event(loop=self.loop)
|
||||
self.response_chunk_size = response_chunk_size
|
||||
self.stop_event = stop_event
|
||||
|
||||
def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
|
||||
if self.server_task is not None:
|
||||
|
@ -143,13 +171,20 @@ class ReflectorServer:
|
|||
|
||||
async def _start_server():
|
||||
server = await self.loop.create_server(
|
||||
lambda: ReflectorServerProtocol(self.blob_manager, self.response_chunk_size),
|
||||
lambda: ReflectorServerProtocol(
|
||||
self.blob_manager, self.response_chunk_size, self.stop_event, self.incoming_event,
|
||||
self.not_incoming_event
|
||||
),
|
||||
interface, port
|
||||
)
|
||||
self.started_listening.set()
|
||||
self.stopped_listening.clear()
|
||||
log.info("Reflector server listening on TCP %s:%i", interface, port)
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
try:
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
finally:
|
||||
self.stopped_listening.set()
|
||||
|
||||
self.server_task = self.loop.create_task(_start_server())
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from lbry.stream.stream_manager import StreamManager
|
|||
from lbry.stream.reflector.server import ReflectorServer
|
||||
|
||||
|
||||
class TestStreamAssembler(AsyncioTestCase):
|
||||
class TestReflector(AsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.key = b'deadbeef' * 4
|
||||
|
@ -22,6 +22,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
|||
self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite"))
|
||||
await self.storage.open()
|
||||
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf)
|
||||
self.addCleanup(self.blob_manager.stop)
|
||||
self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None)
|
||||
|
||||
server_tmp_dir = tempfile.mkdtemp()
|
||||
|
@ -30,6 +31,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
|||
self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite"))
|
||||
await self.server_storage.open()
|
||||
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf)
|
||||
self.addCleanup(self.server_blob_manager.stop)
|
||||
|
||||
download_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(download_dir))
|
||||
|
@ -54,6 +56,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
|||
set(map(lambda b: b.blob_hash,
|
||||
self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)]))
|
||||
)
|
||||
self.assertTrue(self.stream.is_fully_reflected)
|
||||
server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash)
|
||||
self.assertTrue(server_sd_blob.get_is_verified())
|
||||
self.assertEqual(server_sd_blob.length, server_sd_blob.length)
|
||||
|
@ -75,3 +78,87 @@ class TestStreamAssembler(AsyncioTestCase):
|
|||
to_announce = await self.storage.get_blobs_to_announce()
|
||||
self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce")
|
||||
self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce")
|
||||
|
||||
async def test_result_from_disconnect_mid_sd_transfer(self):
|
||||
stop = asyncio.Event()
|
||||
incoming = asyncio.Event()
|
||||
reflector = ReflectorServer(
|
||||
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming
|
||||
)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
self.assertEqual(0, self.stream.reflector_progress)
|
||||
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||
await incoming.wait()
|
||||
stop.set()
|
||||
# this used to raise (and then propagate) a CancelledError
|
||||
self.assertListEqual(await reflect_task, [])
|
||||
self.assertFalse(self.stream.is_fully_reflected)
|
||||
|
||||
async def test_result_from_disconnect_after_sd_transfer(self):
|
||||
stop = asyncio.Event()
|
||||
incoming = asyncio.Event()
|
||||
not_incoming = asyncio.Event()
|
||||
reflector = ReflectorServer(
|
||||
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||
not_incoming_event=not_incoming
|
||||
)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
self.assertEqual(0, self.stream.reflector_progress)
|
||||
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||
await incoming.wait()
|
||||
await not_incoming.wait()
|
||||
stop.set()
|
||||
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
|
||||
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||
self.assertFalse(self.stream.is_fully_reflected)
|
||||
|
||||
async def test_result_from_disconnect_after_data_transfer(self):
|
||||
stop = asyncio.Event()
|
||||
incoming = asyncio.Event()
|
||||
not_incoming = asyncio.Event()
|
||||
reflector = ReflectorServer(
|
||||
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||
not_incoming_event=not_incoming
|
||||
)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
self.assertEqual(0, self.stream.reflector_progress)
|
||||
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||
await incoming.wait()
|
||||
await not_incoming.wait()
|
||||
await incoming.wait()
|
||||
await not_incoming.wait()
|
||||
stop.set()
|
||||
self.assertListEqual(await reflect_task, [self.stream.sd_hash, self.stream.descriptor.blobs[0].blob_hash])
|
||||
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||
self.assertTrue(self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified())
|
||||
self.assertFalse(self.stream.is_fully_reflected)
|
||||
|
||||
async def test_result_from_disconnect_mid_data_transfer(self):
|
||||
stop = asyncio.Event()
|
||||
incoming = asyncio.Event()
|
||||
not_incoming = asyncio.Event()
|
||||
reflector = ReflectorServer(
|
||||
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||
not_incoming_event=not_incoming
|
||||
)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
self.assertEqual(0, self.stream.reflector_progress)
|
||||
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||
await incoming.wait()
|
||||
await not_incoming.wait()
|
||||
await incoming.wait()
|
||||
stop.set()
|
||||
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
|
||||
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||
self.assertFalse(
|
||||
self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
|
||||
)
|
||||
self.assertFalse(self.stream.is_fully_reflected)
|
||||
|
|
Loading…
Reference in a new issue