Merge pull request #2997 from lbryio/fix-reflector-lost-connection

Fix uncaught reflector connection errors
This commit is contained in:
Lex Berezhny 2020-07-20 13:48:26 -04:00 committed by GitHub
commit 6ed1614db0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 166 additions and 12 deletions

View file

@ -5,9 +5,8 @@ import typing
import logging import logging
from argparse import ArgumentParser from argparse import ArgumentParser
from contextlib import contextmanager from contextlib import contextmanager
import yaml
from appdirs import user_data_dir, user_config_dir from appdirs import user_data_dir, user_config_dir
import yaml
from lbry.error import InvalidCurrencyError from lbry.error import InvalidCurrencyError
from lbry.dht import constants from lbry.dht import constants
from lbry.wallet.coinselection import STRATEGIES from lbry.wallet.coinselection import STRATEGIES
@ -334,7 +333,7 @@ class ConfigFileAccess:
cls = type(self.configuration) cls = type(self.configuration)
with open(self.path, 'r') as config_file: with open(self.path, 'r') as config_file:
raw = config_file.read() raw = config_file.read()
serialized = yaml.load(raw) or {} serialized = yaml.safe_load(raw) or {}
for key, value in serialized.items(): for key, value in serialized.items():
attr = getattr(cls, key, None) attr = getattr(cls, key, None)
if attr is None: if attr is None:

View file

@ -356,6 +356,9 @@ class ManagedStream(ManagedDownloadSource):
return sent return sent
except ConnectionRefusedError: except ConnectionRefusedError:
return sent return sent
except OSError:
# raised if a blob is deleted while it's being sent
return sent
finally: finally:
if protocol.transport: if protocol.transport:
protocol.transport.close() protocol.transport.close()

View file

@ -60,10 +60,16 @@ class StreamReflectorClient(asyncio.Protocol):
async def send_request(self, request_dict: typing.Dict, timeout: int = 180): async def send_request(self, request_dict: typing.Dict, timeout: int = 180):
msg = json.dumps(request_dict) msg = json.dumps(request_dict)
self.transport.write(msg.encode())
try: try:
self.transport.write(msg.encode())
self.pending_request = self.loop.create_task(asyncio.wait_for(self.response_queue.get(), timeout)) self.pending_request = self.loop.create_task(asyncio.wait_for(self.response_queue.get(), timeout))
return await self.pending_request return await self.pending_request
except (AttributeError, asyncio.CancelledError):
# attribute error happens when we transport.write after disconnect
# cancelled error happens when the pending_request task is cancelled by a disconnect
if self.transport:
self.transport.close()
raise asyncio.TimeoutError()
finally: finally:
self.pending_request = None self.pending_request = None

View file

@ -15,7 +15,9 @@ log = logging.getLogger(__name__)
class ReflectorServerProtocol(asyncio.Protocol): class ReflectorServerProtocol(asyncio.Protocol):
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000): def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
stop_event: asyncio.Event = None, incoming_event: asyncio.Event = None,
not_incoming_event: asyncio.Event = None):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.server_task: asyncio.Task = None self.server_task: asyncio.Task = None
@ -27,11 +29,25 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.descriptor: typing.Optional['StreamDescriptor'] = None self.descriptor: typing.Optional['StreamDescriptor'] = None
self.sd_blob: typing.Optional['BlobFile'] = None self.sd_blob: typing.Optional['BlobFile'] = None
self.received = [] self.received = []
self.incoming = asyncio.Event(loop=self.loop) self.incoming = incoming_event or asyncio.Event(loop=self.loop)
self.not_incoming = not_incoming_event or asyncio.Event(loop=self.loop)
self.stop_event = stop_event or asyncio.Event(loop=self.loop)
self.chunk_size = response_chunk_size self.chunk_size = response_chunk_size
self.wait_for_stop_task: typing.Optional[asyncio.Task] = None
async def wait_for_stop(self):
await self.stop_event.wait()
if self.transport:
self.transport.close()
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
self.wait_for_stop_task = self.loop.create_task(self.wait_for_stop())
def connection_lost(self, exc):
if self.wait_for_stop_task:
self.wait_for_stop_task.cancel()
self.wait_for_stop_task = None
def data_received(self, data: bytes): def data_received(self, data: bytes):
if self.incoming.is_set(): if self.incoming.is_set():
@ -73,6 +89,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
if not self.sd_blob.get_is_verified(): if not self.sd_blob.get_is_verified():
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.not_incoming.clear()
self.incoming.set() self.incoming.set()
self.send_response({"send_sd_blob": True}) self.send_response({"send_sd_blob": True})
try: try:
@ -86,6 +103,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.transport.close() self.transport.close()
finally: finally:
self.incoming.clear() self.incoming.clear()
self.not_incoming.set()
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
else: else:
@ -93,6 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.loop, self.blob_manager.blob_dir, self.sd_blob self.loop, self.blob_manager.blob_dir, self.sd_blob
) )
self.incoming.clear() self.incoming.clear()
self.not_incoming.set()
if self.writer: if self.writer:
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
@ -112,6 +131,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
if not blob.get_is_verified(): if not blob.get_is_verified():
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.not_incoming.clear()
self.incoming.set() self.incoming.set()
self.send_response({"send_blob": True}) self.send_response({"send_blob": True})
try: try:
@ -120,6 +140,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.send_response({"received_blob": False}) self.send_response({"received_blob": False})
self.incoming.clear() self.incoming.clear()
self.not_incoming.set()
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
else: else:
@ -130,12 +151,18 @@ class ReflectorServerProtocol(asyncio.Protocol):
class ReflectorServer: class ReflectorServer:
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000): def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
stop_event: asyncio.Event = None, incoming_event: asyncio.Event = None,
not_incoming_event: asyncio.Event = None):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.server_task: typing.Optional[asyncio.Task] = None self.server_task: typing.Optional[asyncio.Task] = None
self.started_listening = asyncio.Event(loop=self.loop) self.started_listening = asyncio.Event(loop=self.loop)
self.stopped_listening = asyncio.Event(loop=self.loop)
self.incoming_event = incoming_event or asyncio.Event(loop=self.loop)
self.not_incoming_event = not_incoming_event or asyncio.Event(loop=self.loop)
self.response_chunk_size = response_chunk_size self.response_chunk_size = response_chunk_size
self.stop_event = stop_event
def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'): def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
if self.server_task is not None: if self.server_task is not None:
@ -143,13 +170,20 @@ class ReflectorServer:
async def _start_server(): async def _start_server():
server = await self.loop.create_server( server = await self.loop.create_server(
lambda: ReflectorServerProtocol(self.blob_manager, self.response_chunk_size), lambda: ReflectorServerProtocol(
self.blob_manager, self.response_chunk_size, self.stop_event, self.incoming_event,
self.not_incoming_event
),
interface, port interface, port
) )
self.started_listening.set() self.started_listening.set()
self.stopped_listening.clear()
log.info("Reflector server listening on TCP %s:%i", interface, port) log.info("Reflector server listening on TCP %s:%i", interface, port)
try:
async with server: async with server:
await server.serve_forever() await server.serve_forever()
finally:
self.stopped_listening.set()
self.server_task = self.loop.create_task(_start_server()) self.server_task = self.loop.create_task(_start_server())

View file

@ -46,7 +46,7 @@ setup(
'msgpack==0.6.1', 'msgpack==0.6.1',
'prometheus_client==0.7.1', 'prometheus_client==0.7.1',
'ecdsa==0.13.3', 'ecdsa==0.13.3',
'pyyaml==4.2b1', 'pyyaml==5.3.1',
'docopt==0.6.2', 'docopt==0.6.2',
'hachoir', 'hachoir',
'multidict==4.6.1', 'multidict==4.6.1',

View file

@ -10,7 +10,7 @@ from lbry.stream.stream_manager import StreamManager
from lbry.stream.reflector.server import ReflectorServer from lbry.stream.reflector.server import ReflectorServer
class TestStreamAssembler(AsyncioTestCase): class TestReflector(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.key = b'deadbeef' * 4 self.key = b'deadbeef' * 4
@ -22,6 +22,7 @@ class TestStreamAssembler(AsyncioTestCase):
self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite")) self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite"))
await self.storage.open() await self.storage.open()
self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf) self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf)
self.addCleanup(self.blob_manager.stop)
self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None) self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None)
server_tmp_dir = tempfile.mkdtemp() server_tmp_dir = tempfile.mkdtemp()
@ -30,6 +31,7 @@ class TestStreamAssembler(AsyncioTestCase):
self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite")) self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite"))
await self.server_storage.open() await self.server_storage.open()
self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf) self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf)
self.addCleanup(self.server_blob_manager.stop)
download_dir = tempfile.mkdtemp() download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir)) self.addCleanup(lambda: shutil.rmtree(download_dir))
@ -54,6 +56,7 @@ class TestStreamAssembler(AsyncioTestCase):
set(map(lambda b: b.blob_hash, set(map(lambda b: b.blob_hash,
self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)])) self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)]))
) )
self.assertTrue(self.stream.is_fully_reflected)
server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash) server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash)
self.assertTrue(server_sd_blob.get_is_verified()) self.assertTrue(server_sd_blob.get_is_verified())
self.assertEqual(server_sd_blob.length, server_sd_blob.length) self.assertEqual(server_sd_blob.length, server_sd_blob.length)
@ -75,3 +78,112 @@ class TestStreamAssembler(AsyncioTestCase):
to_announce = await self.storage.get_blobs_to_announce() to_announce = await self.storage.get_blobs_to_announce()
self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce") self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce")
self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce") self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce")
async def test_result_from_disconnect_mid_sd_transfer(self):
stop = asyncio.Event()
incoming = asyncio.Event()
reflector = ReflectorServer(
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming
)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
self.assertEqual(0, self.stream.reflector_progress)
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
await incoming.wait()
stop.set()
# this used to raise (and then propagate) a CancelledError
self.assertListEqual(await reflect_task, [])
self.assertFalse(self.stream.is_fully_reflected)
async def test_result_from_disconnect_after_sd_transfer(self):
stop = asyncio.Event()
incoming = asyncio.Event()
not_incoming = asyncio.Event()
reflector = ReflectorServer(
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
not_incoming_event=not_incoming
)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
self.assertEqual(0, self.stream.reflector_progress)
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
await incoming.wait()
await not_incoming.wait()
stop.set()
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
self.assertFalse(self.stream.is_fully_reflected)
async def test_result_from_disconnect_after_data_transfer(self):
stop = asyncio.Event()
incoming = asyncio.Event()
not_incoming = asyncio.Event()
reflector = ReflectorServer(
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
not_incoming_event=not_incoming
)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
self.assertEqual(0, self.stream.reflector_progress)
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
await incoming.wait()
await not_incoming.wait()
await incoming.wait()
await not_incoming.wait()
stop.set()
self.assertListEqual(await reflect_task, [self.stream.sd_hash, self.stream.descriptor.blobs[0].blob_hash])
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
self.assertTrue(self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified())
self.assertFalse(self.stream.is_fully_reflected)
async def test_result_from_disconnect_mid_data_transfer(self):
stop = asyncio.Event()
incoming = asyncio.Event()
not_incoming = asyncio.Event()
reflector = ReflectorServer(
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
not_incoming_event=not_incoming
)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
self.assertEqual(0, self.stream.reflector_progress)
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
await incoming.wait()
await not_incoming.wait()
await incoming.wait()
stop.set()
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
self.assertFalse(
self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
)
self.assertFalse(self.stream.is_fully_reflected)
async def test_delete_file_during_reflector_upload(self):
stop = asyncio.Event()
incoming = asyncio.Event()
not_incoming = asyncio.Event()
reflector = ReflectorServer(
self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming,
not_incoming_event=not_incoming
)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
self.assertEqual(0, self.stream.reflector_progress)
reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566))
await incoming.wait()
await not_incoming.wait()
await incoming.wait()
await self.stream_manager.delete(self.stream, delete_file=True)
# this used to raise OSError when it can't read the deleted blob for the upload
self.assertListEqual(await reflect_task, [self.stream.sd_hash])
self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified())
self.assertFalse(
self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()
)
self.assertFalse(self.stream.is_fully_reflected)