forked from LBRYCommunity/lbry-sdk
test reflector connection breaking mid-transfer
This commit is contained in:
parent
506d3f3cd9
commit
b3b5e3d8f0
2 changed files with 129 additions and 7 deletions
|
@ -15,7 +15,10 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ReflectorServerProtocol(asyncio.Protocol):
|
class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000):
|
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
|
||||||
|
stop_event: typing.Optional[asyncio.Event] = None,
|
||||||
|
incoming_event: typing.Optional[asyncio.Event] = None,
|
||||||
|
not_incoming_event: typing.Optional[asyncio.Event] = None):
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
self.blob_manager = blob_manager
|
self.blob_manager = blob_manager
|
||||||
self.server_task: asyncio.Task = None
|
self.server_task: asyncio.Task = None
|
||||||
|
@ -27,11 +30,24 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
self.descriptor: typing.Optional['StreamDescriptor'] = None
|
self.descriptor: typing.Optional['StreamDescriptor'] = None
|
||||||
self.sd_blob: typing.Optional['BlobFile'] = None
|
self.sd_blob: typing.Optional['BlobFile'] = None
|
||||||
self.received = []
|
self.received = []
|
||||||
self.incoming = asyncio.Event(loop=self.loop)
|
self.incoming = incoming_event or asyncio.Event(loop=self.loop)
|
||||||
|
self.not_incoming = not_incoming_event or asyncio.Event(loop=self.loop)
|
||||||
|
self.stop_event = stop_event or asyncio.Event(loop=self.loop)
|
||||||
self.chunk_size = response_chunk_size
|
self.chunk_size = response_chunk_size
|
||||||
|
|
||||||
|
async def wait_for_stop(self):
|
||||||
|
await self.stop_event.wait()
|
||||||
|
if self.transport:
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
def connection_made(self, transport):
|
def connection_made(self, transport):
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
self.wait_for_stop_task = self.loop.create_task(self.wait_for_stop())
|
||||||
|
|
||||||
|
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||||
|
if self.wait_for_stop_task:
|
||||||
|
self.wait_for_stop_task.cancel()
|
||||||
|
self.wait_for_stop_task = None
|
||||||
|
|
||||||
def data_received(self, data: bytes):
|
def data_received(self, data: bytes):
|
||||||
if self.incoming.is_set():
|
if self.incoming.is_set():
|
||||||
|
@ -73,6 +89,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
||||||
if not self.sd_blob.get_is_verified():
|
if not self.sd_blob.get_is_verified():
|
||||||
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||||
|
self.not_incoming.clear()
|
||||||
self.incoming.set()
|
self.incoming.set()
|
||||||
self.send_response({"send_sd_blob": True})
|
self.send_response({"send_sd_blob": True})
|
||||||
try:
|
try:
|
||||||
|
@ -86,6 +103,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
finally:
|
finally:
|
||||||
self.incoming.clear()
|
self.incoming.clear()
|
||||||
|
self.not_incoming.set()
|
||||||
self.writer.close_handle()
|
self.writer.close_handle()
|
||||||
self.writer = None
|
self.writer = None
|
||||||
else:
|
else:
|
||||||
|
@ -93,6 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
||||||
)
|
)
|
||||||
self.incoming.clear()
|
self.incoming.clear()
|
||||||
|
self.not_incoming.set()
|
||||||
if self.writer:
|
if self.writer:
|
||||||
self.writer.close_handle()
|
self.writer.close_handle()
|
||||||
self.writer = None
|
self.writer = None
|
||||||
|
@ -112,6 +131,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
||||||
if not blob.get_is_verified():
|
if not blob.get_is_verified():
|
||||||
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||||
|
self.not_incoming.clear()
|
||||||
self.incoming.set()
|
self.incoming.set()
|
||||||
self.send_response({"send_blob": True})
|
self.send_response({"send_blob": True})
|
||||||
try:
|
try:
|
||||||
|
@ -120,6 +140,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self.send_response({"received_blob": False})
|
self.send_response({"received_blob": False})
|
||||||
self.incoming.clear()
|
self.incoming.clear()
|
||||||
|
self.not_incoming.set()
|
||||||
self.writer.close_handle()
|
self.writer.close_handle()
|
||||||
self.writer = None
|
self.writer = None
|
||||||
else:
|
else:
|
||||||
|
@ -130,12 +151,19 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
|
|
||||||
|
|
||||||
class ReflectorServer:
|
class ReflectorServer:
|
||||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000):
|
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
|
||||||
|
stop_event: typing.Optional[asyncio.Event] = None,
|
||||||
|
incoming_event: typing.Optional[asyncio.Event] = None,
|
||||||
|
not_incoming_event: typing.Optional[asyncio.Event] = None):
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
self.blob_manager = blob_manager
|
self.blob_manager = blob_manager
|
||||||
self.server_task: typing.Optional[asyncio.Task] = None
|
self.server_task: typing.Optional[asyncio.Task] = None
|
||||||
self.started_listening = asyncio.Event(loop=self.loop)
|
self.started_listening = asyncio.Event(loop=self.loop)
|
||||||
|
self.stopped_listening = asyncio.Event(loop=self.loop)
|
||||||
|
self.incoming_event = incoming_event or asyncio.Event(loop=self.loop)
|
||||||
|
self.not_incoming_event = not_incoming_event or asyncio.Event(loop=self.loop)
|
||||||
self.response_chunk_size = response_chunk_size
|
self.response_chunk_size = response_chunk_size
|
||||||
|
self.stop_event = stop_event
|
||||||
|
|
||||||
def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
|
def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
|
||||||
if self.server_task is not None:
|
if self.server_task is not None:
|
||||||
|
@ -143,13 +171,20 @@ class ReflectorServer:
|
||||||
|
|
||||||
async def _start_server():
|
async def _start_server():
|
||||||
server = await self.loop.create_server(
|
server = await self.loop.create_server(
|
||||||
lambda: ReflectorServerProtocol(self.blob_manager, self.response_chunk_size),
|
lambda: ReflectorServerProtocol(
|
||||||
|
self.blob_manager, self.response_chunk_size, self.stop_event, self.incoming_event,
|
||||||
|
self.not_incoming_event
|
||||||
|
),
|
||||||
interface, port
|
interface, port
|
||||||
)
|
)
|
||||||
self.started_listening.set()
|
self.started_listening.set()
|
||||||
|
self.stopped_listening.clear()
|
||||||
log.info("Reflector server listening on TCP %s:%i", interface, port)
|
log.info("Reflector server listening on TCP %s:%i", interface, port)
|
||||||
|
try:
|
||||||
async with server:
|
async with server:
|
||||||
await server.serve_forever()
|
await server.serve_forever()
|
||||||
|
finally:
|
||||||
|
self.stopped_listening.set()
|
||||||
|
|
||||||
self.server_task = self.loop.create_task(_start_server())
|
self.server_task = self.loop.create_task(_start_server())
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from lbry.stream.stream_manager import StreamManager
|
||||||
from lbry.stream.reflector.server import ReflectorServer
|
from lbry.stream.reflector.server import ReflectorServer
|
||||||
|
|
||||||
|
|
||||||
class TestStreamAssembler(AsyncioTestCase):
|
class TestReflector(AsyncioTestCase):
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
self.key = b'deadbeef' * 4
|
self.key = b'deadbeef' * 4
|
||||||
|
@ -22,6 +22,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
||||||
self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite"))
|
self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite"))
|
||||||
await self.storage.open()
|
await self.storage.open()
|
||||||
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf)
|
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf)
|
||||||
|
self.addCleanup(self.blob_manager.stop)
|
||||||
self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None)
|
self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None)
|
||||||
|
|
||||||
server_tmp_dir = tempfile.mkdtemp()
|
server_tmp_dir = tempfile.mkdtemp()
|
||||||
|
@ -30,6 +31,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
||||||
self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite"))
|
self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite"))
|
||||||
await self.server_storage.open()
|
await self.server_storage.open()
|
||||||
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf)
|
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf)
|
||||||
|
self.addCleanup(self.server_blob_manager.stop)
|
||||||
|
|
||||||
download_dir = tempfile.mkdtemp()
|
download_dir = tempfile.mkdtemp()
|
||||||
self.addCleanup(lambda: shutil.rmtree(download_dir))
|
self.addCleanup(lambda: shutil.rmtree(download_dir))
|
||||||
|
@ -54,6 +56,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
||||||
set(map(lambda b: b.blob_hash,
|
set(map(lambda b: b.blob_hash,
|
||||||
self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)]))
|
self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)]))
|
||||||
)
|
)
|
||||||
|
self.assertTrue(self.stream.is_fully_reflected)
|
||||||
server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash)
|
server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash)
|
||||||
self.assertTrue(server_sd_blob.get_is_verified())
|
self.assertTrue(server_sd_blob.get_is_verified())
|
||||||
self.assertEqual(server_sd_blob.length, server_sd_blob.length)
|
self.assertEqual(server_sd_blob.length, server_sd_blob.length)
|
||||||
|
@ -75,3 +78,87 @@ class TestStreamAssembler(AsyncioTestCase):
|
||||||
to_announce = await self.storage.get_blobs_to_announce()
|
to_announce = await self.storage.get_blobs_to_announce()
|
||||||
self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce")
|
self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce")
|
||||||
self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce")
|
self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce")
|
||||||
|
|
||||||
|
async def test_result_from_disconnect_mid_sd_transfer(self):
|
||||||
|
stop = asyncio.Event()
|
||||||
|
incoming = asyncio.Event()
|
||||||
|
reflector = ReflectorServer(
|
||||||
|
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming
|
||||||
|
)
|
||||||
|
reflector.start_server(5566, '127.0.0.1')
|
||||||
|
await reflector.started_listening.wait()
|
||||||
|
self.addCleanup(reflector.stop_server)
|
||||||
|
self.assertEqual(0, self.stream.reflector_progress)
|
||||||
|
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||||
|
await incoming.wait()
|
||||||
|
stop.set()
|
||||||
|
# this used to raise (and then propagate) a CancelledError
|
||||||
|
self.assertListEqual(await reflect_task, [])
|
||||||
|
self.assertFalse(self.stream.is_fully_reflected)
|
||||||
|
|
||||||
|
async def test_result_from_disconnect_after_sd_transfer(self):
|
||||||
|
stop = asyncio.Event()
|
||||||
|
incoming = asyncio.Event()
|
||||||
|
not_incoming = asyncio.Event()
|
||||||
|
reflector = ReflectorServer(
|
||||||
|
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||||
|
not_incoming_event=not_incoming
|
||||||
|
)
|
||||||
|
reflector.start_server(5566, '127.0.0.1')
|
||||||
|
await reflector.started_listening.wait()
|
||||||
|
self.addCleanup(reflector.stop_server)
|
||||||
|
self.assertEqual(0, self.stream.reflector_progress)
|
||||||
|
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||||
|
await incoming.wait()
|
||||||
|
await not_incoming.wait()
|
||||||
|
stop.set()
|
||||||
|
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
|
||||||
|
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||||
|
self.assertFalse(self.stream.is_fully_reflected)
|
||||||
|
|
||||||
|
async def test_result_from_disconnect_after_data_transfer(self):
|
||||||
|
stop = asyncio.Event()
|
||||||
|
incoming = asyncio.Event()
|
||||||
|
not_incoming = asyncio.Event()
|
||||||
|
reflector = ReflectorServer(
|
||||||
|
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||||
|
not_incoming_event=not_incoming
|
||||||
|
)
|
||||||
|
reflector.start_server(5566, '127.0.0.1')
|
||||||
|
await reflector.started_listening.wait()
|
||||||
|
self.addCleanup(reflector.stop_server)
|
||||||
|
self.assertEqual(0, self.stream.reflector_progress)
|
||||||
|
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||||
|
await incoming.wait()
|
||||||
|
await not_incoming.wait()
|
||||||
|
await incoming.wait()
|
||||||
|
await not_incoming.wait()
|
||||||
|
stop.set()
|
||||||
|
self.assertListEqual(await reflect_task, [self.stream.sd_hash, self.stream.descriptor.blobs[0].blob_hash])
|
||||||
|
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||||
|
self.assertTrue(self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified())
|
||||||
|
self.assertFalse(self.stream.is_fully_reflected)
|
||||||
|
|
||||||
|
async def test_result_from_disconnect_mid_data_transfer(self):
|
||||||
|
stop = asyncio.Event()
|
||||||
|
incoming = asyncio.Event()
|
||||||
|
not_incoming = asyncio.Event()
|
||||||
|
reflector = ReflectorServer(
|
||||||
|
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||||
|
not_incoming_event=not_incoming
|
||||||
|
)
|
||||||
|
reflector.start_server(5566, '127.0.0.1')
|
||||||
|
await reflector.started_listening.wait()
|
||||||
|
self.addCleanup(reflector.stop_server)
|
||||||
|
self.assertEqual(0, self.stream.reflector_progress)
|
||||||
|
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||||
|
await incoming.wait()
|
||||||
|
await not_incoming.wait()
|
||||||
|
await incoming.wait()
|
||||||
|
stop.set()
|
||||||
|
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
|
||||||
|
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||||
|
self.assertFalse(
|
||||||
|
self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
|
||||||
|
)
|
||||||
|
self.assertFalse(self.stream.is_fully_reflected)
|
||||||
|
|
Loading…
Reference in a new issue