forked from LBRYCommunity/lbry-sdk
Merge pull request #2997 from lbryio/fix-reflector-lost-connection
Fix uncaught reflector connection errors
This commit is contained in:
commit
6ed1614db0
6 changed files with 166 additions and 12 deletions
|
@ -5,9 +5,8 @@ import typing
|
|||
import logging
|
||||
from argparse import ArgumentParser
|
||||
from contextlib import contextmanager
|
||||
|
||||
import yaml
|
||||
from appdirs import user_data_dir, user_config_dir
|
||||
import yaml
|
||||
from lbry.error import InvalidCurrencyError
|
||||
from lbry.dht import constants
|
||||
from lbry.wallet.coinselection import STRATEGIES
|
||||
|
@ -334,7 +333,7 @@ class ConfigFileAccess:
|
|||
cls = type(self.configuration)
|
||||
with open(self.path, 'r') as config_file:
|
||||
raw = config_file.read()
|
||||
serialized = yaml.load(raw) or {}
|
||||
serialized = yaml.safe_load(raw) or {}
|
||||
for key, value in serialized.items():
|
||||
attr = getattr(cls, key, None)
|
||||
if attr is None:
|
||||
|
|
|
@ -356,6 +356,9 @@ class ManagedStream(ManagedDownloadSource):
|
|||
return sent
|
||||
except ConnectionRefusedError:
|
||||
return sent
|
||||
except OSError:
|
||||
# raised if a blob is deleted while it's being sent
|
||||
return sent
|
||||
finally:
|
||||
if protocol.transport:
|
||||
protocol.transport.close()
|
||||
|
|
|
@ -60,10 +60,16 @@ class StreamReflectorClient(asyncio.Protocol):
|
|||
|
||||
async def send_request(self, request_dict: typing.Dict, timeout: int = 180):
|
||||
msg = json.dumps(request_dict)
|
||||
self.transport.write(msg.encode())
|
||||
try:
|
||||
self.transport.write(msg.encode())
|
||||
self.pending_request = self.loop.create_task(asyncio.wait_for(self.response_queue.get(), timeout))
|
||||
return await self.pending_request
|
||||
except (AttributeError, asyncio.CancelledError):
|
||||
# attribute error happens when we transport.write after disconnect
|
||||
# cancelled error happens when the pending_request task is cancelled by a disconnect
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
raise asyncio.TimeoutError()
|
||||
finally:
|
||||
self.pending_request = None
|
||||
|
||||
|
|
|
@ -15,7 +15,9 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ReflectorServerProtocol(asyncio.Protocol):
|
||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000):
|
||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
|
||||
stop_event: asyncio.Event = None, incoming_event: asyncio.Event = None,
|
||||
not_incoming_event: asyncio.Event = None):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.blob_manager = blob_manager
|
||||
self.server_task: asyncio.Task = None
|
||||
|
@ -27,11 +29,25 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
self.descriptor: typing.Optional['StreamDescriptor'] = None
|
||||
self.sd_blob: typing.Optional['BlobFile'] = None
|
||||
self.received = []
|
||||
self.incoming = asyncio.Event(loop=self.loop)
|
||||
self.incoming = incoming_event or asyncio.Event(loop=self.loop)
|
||||
self.not_incoming = not_incoming_event or asyncio.Event(loop=self.loop)
|
||||
self.stop_event = stop_event or asyncio.Event(loop=self.loop)
|
||||
self.chunk_size = response_chunk_size
|
||||
self.wait_for_stop_task: typing.Optional[asyncio.Task] = None
|
||||
|
||||
async def wait_for_stop(self):
|
||||
await self.stop_event.wait()
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
self.wait_for_stop_task = self.loop.create_task(self.wait_for_stop())
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if self.wait_for_stop_task:
|
||||
self.wait_for_stop_task.cancel()
|
||||
self.wait_for_stop_task = None
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
if self.incoming.is_set():
|
||||
|
@ -73,6 +89,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
||||
if not self.sd_blob.get_is_verified():
|
||||
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||
self.not_incoming.clear()
|
||||
self.incoming.set()
|
||||
self.send_response({"send_sd_blob": True})
|
||||
try:
|
||||
|
@ -86,6 +103,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
self.transport.close()
|
||||
finally:
|
||||
self.incoming.clear()
|
||||
self.not_incoming.set()
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
else:
|
||||
|
@ -93,6 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
||||
)
|
||||
self.incoming.clear()
|
||||
self.not_incoming.set()
|
||||
if self.writer:
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
|
@ -112,6 +131,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
||||
if not blob.get_is_verified():
|
||||
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||
self.not_incoming.clear()
|
||||
self.incoming.set()
|
||||
self.send_response({"send_blob": True})
|
||||
try:
|
||||
|
@ -120,6 +140,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
except asyncio.TimeoutError:
|
||||
self.send_response({"received_blob": False})
|
||||
self.incoming.clear()
|
||||
self.not_incoming.set()
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
else:
|
||||
|
@ -130,12 +151,18 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
|
||||
|
||||
class ReflectorServer:
|
||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000):
|
||||
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
|
||||
stop_event: asyncio.Event = None, incoming_event: asyncio.Event = None,
|
||||
not_incoming_event: asyncio.Event = None):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.blob_manager = blob_manager
|
||||
self.server_task: typing.Optional[asyncio.Task] = None
|
||||
self.started_listening = asyncio.Event(loop=self.loop)
|
||||
self.stopped_listening = asyncio.Event(loop=self.loop)
|
||||
self.incoming_event = incoming_event or asyncio.Event(loop=self.loop)
|
||||
self.not_incoming_event = not_incoming_event or asyncio.Event(loop=self.loop)
|
||||
self.response_chunk_size = response_chunk_size
|
||||
self.stop_event = stop_event
|
||||
|
||||
def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
|
||||
if self.server_task is not None:
|
||||
|
@ -143,13 +170,20 @@ class ReflectorServer:
|
|||
|
||||
async def _start_server():
|
||||
server = await self.loop.create_server(
|
||||
lambda: ReflectorServerProtocol(self.blob_manager, self.response_chunk_size),
|
||||
lambda: ReflectorServerProtocol(
|
||||
self.blob_manager, self.response_chunk_size, self.stop_event, self.incoming_event,
|
||||
self.not_incoming_event
|
||||
),
|
||||
interface, port
|
||||
)
|
||||
self.started_listening.set()
|
||||
self.stopped_listening.clear()
|
||||
log.info("Reflector server listening on TCP %s:%i", interface, port)
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
try:
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
finally:
|
||||
self.stopped_listening.set()
|
||||
|
||||
self.server_task = self.loop.create_task(_start_server())
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -46,7 +46,7 @@ setup(
|
|||
'msgpack==0.6.1',
|
||||
'prometheus_client==0.7.1',
|
||||
'ecdsa==0.13.3',
|
||||
'pyyaml==4.2b1',
|
||||
'pyyaml==5.3.1',
|
||||
'docopt==0.6.2',
|
||||
'hachoir',
|
||||
'multidict==4.6.1',
|
||||
|
|
|
@ -10,7 +10,7 @@ from lbry.stream.stream_manager import StreamManager
|
|||
from lbry.stream.reflector.server import ReflectorServer
|
||||
|
||||
|
||||
class TestStreamAssembler(AsyncioTestCase):
|
||||
class TestReflector(AsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.key = b'deadbeef' * 4
|
||||
|
@ -22,6 +22,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
|||
self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite"))
|
||||
await self.storage.open()
|
||||
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf)
|
||||
self.addCleanup(self.blob_manager.stop)
|
||||
self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None)
|
||||
|
||||
server_tmp_dir = tempfile.mkdtemp()
|
||||
|
@ -30,6 +31,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
|||
self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite"))
|
||||
await self.server_storage.open()
|
||||
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf)
|
||||
self.addCleanup(self.server_blob_manager.stop)
|
||||
|
||||
download_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(download_dir))
|
||||
|
@ -54,6 +56,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
|||
set(map(lambda b: b.blob_hash,
|
||||
self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)]))
|
||||
)
|
||||
self.assertTrue(self.stream.is_fully_reflected)
|
||||
server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash)
|
||||
self.assertTrue(server_sd_blob.get_is_verified())
|
||||
self.assertEqual(server_sd_blob.length, server_sd_blob.length)
|
||||
|
@ -75,3 +78,112 @@ class TestStreamAssembler(AsyncioTestCase):
|
|||
to_announce = await self.storage.get_blobs_to_announce()
|
||||
self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce")
|
||||
self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce")
|
||||
|
||||
async def test_result_from_disconnect_mid_sd_transfer(self):
|
||||
stop = asyncio.Event()
|
||||
incoming = asyncio.Event()
|
||||
reflector = ReflectorServer(
|
||||
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming
|
||||
)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
self.assertEqual(0, self.stream.reflector_progress)
|
||||
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||
await incoming.wait()
|
||||
stop.set()
|
||||
# this used to raise (and then propagate) a CancelledError
|
||||
self.assertListEqual(await reflect_task, [])
|
||||
self.assertFalse(self.stream.is_fully_reflected)
|
||||
|
||||
async def test_result_from_disconnect_after_sd_transfer(self):
|
||||
stop = asyncio.Event()
|
||||
incoming = asyncio.Event()
|
||||
not_incoming = asyncio.Event()
|
||||
reflector = ReflectorServer(
|
||||
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||
not_incoming_event=not_incoming
|
||||
)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
self.assertEqual(0, self.stream.reflector_progress)
|
||||
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||
await incoming.wait()
|
||||
await not_incoming.wait()
|
||||
stop.set()
|
||||
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
|
||||
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||
self.assertFalse(self.stream.is_fully_reflected)
|
||||
|
||||
async def test_result_from_disconnect_after_data_transfer(self):
|
||||
stop = asyncio.Event()
|
||||
incoming = asyncio.Event()
|
||||
not_incoming = asyncio.Event()
|
||||
reflector = ReflectorServer(
|
||||
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||
not_incoming_event=not_incoming
|
||||
)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
self.assertEqual(0, self.stream.reflector_progress)
|
||||
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||
await incoming.wait()
|
||||
await not_incoming.wait()
|
||||
await incoming.wait()
|
||||
await not_incoming.wait()
|
||||
stop.set()
|
||||
self.assertListEqual(await reflect_task, [self.stream.sd_hash, self.stream.descriptor.blobs[0].blob_hash])
|
||||
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||
self.assertTrue(self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified())
|
||||
self.assertFalse(self.stream.is_fully_reflected)
|
||||
|
||||
async def test_result_from_disconnect_mid_data_transfer(self):
|
||||
stop = asyncio.Event()
|
||||
incoming = asyncio.Event()
|
||||
not_incoming = asyncio.Event()
|
||||
reflector = ReflectorServer(
|
||||
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||
not_incoming_event=not_incoming
|
||||
)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
self.assertEqual(0, self.stream.reflector_progress)
|
||||
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||
await incoming.wait()
|
||||
await not_incoming.wait()
|
||||
await incoming.wait()
|
||||
stop.set()
|
||||
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
|
||||
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||
self.assertFalse(
|
||||
self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
|
||||
)
|
||||
self.assertFalse(self.stream.is_fully_reflected)
|
||||
|
||||
async def test_delete_file_during_reflector_upload(self):
|
||||
stop = asyncio.Event()
|
||||
incoming = asyncio.Event()
|
||||
not_incoming = asyncio.Event()
|
||||
reflector = ReflectorServer(
|
||||
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
|
||||
not_incoming_event=not_incoming
|
||||
)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
self.assertEqual(0, self.stream.reflector_progress)
|
||||
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
|
||||
await incoming.wait()
|
||||
await not_incoming.wait()
|
||||
await incoming.wait()
|
||||
await self.stream_manager.delete(self.stream, delete_file=True)
|
||||
# this used to raise OSError when it can't read the deleted blob for the upload
|
||||
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
|
||||
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
|
||||
self.assertFalse(
|
||||
self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
|
||||
)
|
||||
self.assertFalse(self.stream.is_fully_reflected)
|
||||
|
|
Loading…
Reference in a new issue