test reflector connection breaking mid-transfer

This commit is contained in:
Jack Robison 2020-07-15 16:39:59 -04:00
parent 506d3f3cd9
commit b3b5e3d8f0
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 129 additions and 7 deletions

View file

@ -15,7 +15,10 @@ log = logging.getLogger(__name__)
class ReflectorServerProtocol(asyncio.Protocol): class ReflectorServerProtocol(asyncio.Protocol):
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000): def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
stop_event: typing.Optional[asyncio.Event] = None,
incoming_event: typing.Optional[asyncio.Event] = None,
not_incoming_event: typing.Optional[asyncio.Event] = None):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.server_task: asyncio.Task = None self.server_task: asyncio.Task = None
@ -27,11 +30,24 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.descriptor: typing.Optional['StreamDescriptor'] = None self.descriptor: typing.Optional['StreamDescriptor'] = None
self.sd_blob: typing.Optional['BlobFile'] = None self.sd_blob: typing.Optional['BlobFile'] = None
self.received = [] self.received = []
self.incoming = asyncio.Event(loop=self.loop) self.incoming = incoming_event or asyncio.Event(loop=self.loop)
self.not_incoming = not_incoming_event or asyncio.Event(loop=self.loop)
self.stop_event = stop_event or asyncio.Event(loop=self.loop)
self.chunk_size = response_chunk_size self.chunk_size = response_chunk_size
async def wait_for_stop(self):
await self.stop_event.wait()
if self.transport:
self.transport.close()
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
self.wait_for_stop_task = self.loop.create_task(self.wait_for_stop())
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
if self.wait_for_stop_task:
self.wait_for_stop_task.cancel()
self.wait_for_stop_task = None
def data_received(self, data: bytes): def data_received(self, data: bytes):
if self.incoming.is_set(): if self.incoming.is_set():
@ -73,6 +89,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
if not self.sd_blob.get_is_verified(): if not self.sd_blob.get_is_verified():
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.not_incoming.clear()
self.incoming.set() self.incoming.set()
self.send_response({"send_sd_blob": True}) self.send_response({"send_sd_blob": True})
try: try:
@ -86,6 +103,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.transport.close() self.transport.close()
finally: finally:
self.incoming.clear() self.incoming.clear()
self.not_incoming.set()
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
else: else:
@ -93,6 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.loop, self.blob_manager.blob_dir, self.sd_blob self.loop, self.blob_manager.blob_dir, self.sd_blob
) )
self.incoming.clear() self.incoming.clear()
self.not_incoming.set()
if self.writer: if self.writer:
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
@ -112,6 +131,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
if not blob.get_is_verified(): if not blob.get_is_verified():
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.not_incoming.clear()
self.incoming.set() self.incoming.set()
self.send_response({"send_blob": True}) self.send_response({"send_blob": True})
try: try:
@ -120,6 +140,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.send_response({"received_blob": False}) self.send_response({"received_blob": False})
self.incoming.clear() self.incoming.clear()
self.not_incoming.set()
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
else: else:
@ -130,12 +151,19 @@ class ReflectorServerProtocol(asyncio.Protocol):
class ReflectorServer: class ReflectorServer:
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000): def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
stop_event: typing.Optional[asyncio.Event] = None,
incoming_event: typing.Optional[asyncio.Event] = None,
not_incoming_event: typing.Optional[asyncio.Event] = None):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.server_task: typing.Optional[asyncio.Task] = None self.server_task: typing.Optional[asyncio.Task] = None
self.started_listening = asyncio.Event(loop=self.loop) self.started_listening = asyncio.Event(loop=self.loop)
self.stopped_listening = asyncio.Event(loop=self.loop)
self.incoming_event = incoming_event or asyncio.Event(loop=self.loop)
self.not_incoming_event = not_incoming_event or asyncio.Event(loop=self.loop)
self.response_chunk_size = response_chunk_size self.response_chunk_size = response_chunk_size
self.stop_event = stop_event
def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'): def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
if self.server_task is not None: if self.server_task is not None:
@ -143,13 +171,20 @@ class ReflectorServer:
async def _start_server(): async def _start_server():
server = await self.loop.create_server( server = await self.loop.create_server(
lambda: ReflectorServerProtocol(self.blob_manager, self.response_chunk_size), lambda: ReflectorServerProtocol(
self.blob_manager, self.response_chunk_size, self.stop_event, self.incoming_event,
self.not_incoming_event
),
interface, port interface, port
) )
self.started_listening.set() self.started_listening.set()
self.stopped_listening.clear()
log.info("Reflector server listening on TCP %s:%i", interface, port) log.info("Reflector server listening on TCP %s:%i", interface, port)
try:
async with server: async with server:
await server.serve_forever() await server.serve_forever()
finally:
self.stopped_listening.set()
self.server_task = self.loop.create_task(_start_server()) self.server_task = self.loop.create_task(_start_server())

View file

@ -10,7 +10,7 @@ from lbry.stream.stream_manager import StreamManager
from lbry.stream.reflector.server import ReflectorServer from lbry.stream.reflector.server import ReflectorServer
class TestStreamAssembler(AsyncioTestCase): class TestReflector(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.key = b'deadbeef' * 4 self.key = b'deadbeef' * 4
@ -22,6 +22,7 @@ class TestStreamAssembler(AsyncioTestCase):
self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite")) self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite"))
await self.storage.open() await self.storage.open()
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf) self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf)
self.addCleanup(self.blob_manager.stop)
self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None) self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None)
server_tmp_dir = tempfile.mkdtemp() server_tmp_dir = tempfile.mkdtemp()
@ -30,6 +31,7 @@ class TestStreamAssembler(AsyncioTestCase):
self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite")) self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite"))
await self.server_storage.open() await self.server_storage.open()
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf) self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf)
self.addCleanup(self.server_blob_manager.stop)
download_dir = tempfile.mkdtemp() download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir)) self.addCleanup(lambda: shutil.rmtree(download_dir))
@ -54,6 +56,7 @@ class TestStreamAssembler(AsyncioTestCase):
set(map(lambda b: b.blob_hash, set(map(lambda b: b.blob_hash,
self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)])) self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)]))
) )
self.assertTrue(self.stream.is_fully_reflected)
server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash) server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash)
self.assertTrue(server_sd_blob.get_is_verified()) self.assertTrue(server_sd_blob.get_is_verified())
self.assertEqual(server_sd_blob.length, server_sd_blob.length) self.assertEqual(server_sd_blob.length, server_sd_blob.length)
@ -75,3 +78,87 @@ class TestStreamAssembler(AsyncioTestCase):
to_announce = await self.storage.get_blobs_to_announce() to_announce = await self.storage.get_blobs_to_announce()
self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce") self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce")
self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce") self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce")
async def test_result_from_disconnect_mid_sd_transfer(self):
stop = asyncio.Event()
incoming = asyncio.Event()
reflector = ReflectorServer(
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming
)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
self.assertEqual(0, self.stream.reflector_progress)
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
await incoming.wait()
stop.set()
# this used to raise (and then propagate) a CancelledError
self.assertListEqual(await reflect_task, [])
self.assertFalse(self.stream.is_fully_reflected)
async def test_result_from_disconnect_after_sd_transfer(self):
stop = asyncio.Event()
incoming = asyncio.Event()
not_incoming = asyncio.Event()
reflector = ReflectorServer(
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
not_incoming_event=not_incoming
)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
self.assertEqual(0, self.stream.reflector_progress)
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
await incoming.wait()
await not_incoming.wait()
stop.set()
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
self.assertFalse(self.stream.is_fully_reflected)
async def test_result_from_disconnect_after_data_transfer(self):
stop = asyncio.Event()
incoming = asyncio.Event()
not_incoming = asyncio.Event()
reflector = ReflectorServer(
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
not_incoming_event=not_incoming
)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
self.assertEqual(0, self.stream.reflector_progress)
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
await incoming.wait()
await not_incoming.wait()
await incoming.wait()
await not_incoming.wait()
stop.set()
self.assertListEqual(await reflect_task, [self.stream.sd_hash, self.stream.descriptor.blobs[0].blob_hash])
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
self.assertTrue(self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified())
self.assertFalse(self.stream.is_fully_reflected)
async def test_result_from_disconnect_mid_data_transfer(self):
stop = asyncio.Event()
incoming = asyncio.Event()
not_incoming = asyncio.Event()
reflector = ReflectorServer(
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
not_incoming_event=not_incoming
)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
self.assertEqual(0, self.stream.reflector_progress)
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
await incoming.wait()
await not_incoming.wait()
await incoming.wait()
stop.set()
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
self.assertFalse(
self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
)
self.assertFalse(self.stream.is_fully_reflected)