diff --git a/lbrynet/stream/downloader.py b/lbrynet/stream/downloader.py index b152f6e57..a6ffc569b 100644 --- a/lbrynet/stream/downloader.py +++ b/lbrynet/stream/downloader.py @@ -104,7 +104,10 @@ class StreamDownloader: async def download_stream_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> 'AbstractBlob': if not filter(lambda blob: blob.blob_hash == blob_info.blob_hash, self.descriptor.blobs[:-1]): raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}") - blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length, connection_id) + blob = await asyncio.wait_for( + self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length, connection_id), + self.config.blob_download_timeout * 10, loop=self.loop + ) return blob def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes: diff --git a/lbrynet/stream/managed_stream.py b/lbrynet/stream/managed_stream.py index 42b8f29ce..d99d4ddeb 100644 --- a/lbrynet/stream/managed_stream.py +++ b/lbrynet/stream/managed_stream.py @@ -66,7 +66,7 @@ class ManagedStream: 'saving', 'finished_writing', 'started_writing', - + 'finished_write_attempt' ] def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', @@ -100,6 +100,7 @@ class ManagedStream: self.saving = asyncio.Event(loop=self.loop) self.finished_writing = asyncio.Event(loop=self.loop) self.started_writing = asyncio.Event(loop=self.loop) + self.finished_write_attempt = asyncio.Event(loop=self.loop) @property def descriptor(self) -> StreamDescriptor: @@ -347,6 +348,7 @@ class ManagedStream: log.info("save file for lbry://%s#%s (sd hash %s...) -> %s", self.claim_name, self.claim_id, self.sd_hash[:6], output_path) self.saving.set() + self.finished_write_attempt.clear() self.finished_writing.clear() self.started_writing.clear() try: @@ -370,11 +372,22 @@ class ManagedStream: if os.path.isfile(output_path): log.warning("removing incomplete download %s for %s", output_path, self.sd_hash) os.remove(output_path) - if not isinstance(err, asyncio.CancelledError): + self.written_bytes = 0 + if isinstance(err, asyncio.TimeoutError): + self.downloader.stop() + await self.blob_manager.storage.change_file_download_dir_and_file_name( + self.stream_hash, None, None + ) + self._file_name, self.download_directory = None, None + await self.blob_manager.storage.clear_saved_file(self.stream_hash) + await self.update_status(self.STATUS_STOPPED) + return + elif not isinstance(err, asyncio.CancelledError): log.exception("unexpected error encountered writing file for stream %s", self.sd_hash) raise err finally: self.saving.clear() + self.finished_write_attempt.set() async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None, node: typing.Optional['Node'] = None): diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py index b6df1d237..6f839b931 100644 --- a/tests/unit/stream/test_stream_manager.py +++ b/tests/unit/stream/test_stream_manager.py @@ -308,6 +308,26 @@ class TestStreamManager(BlobExchangeTestBase): self.server_blob_manager.delete_blob(head_blob_hash) await self._test_download_error_analytics_on_start(DownloadDataTimeout, timeout=1) + async def test_non_head_data_timeout(self): + await self.setup_stream_manager() + with open(os.path.join(self.server_dir, self.sd_hash), 'r') as sdf: + head_blob_hash = json.loads(sdf.read())['blobs'][-2]['blob_hash'] + self.server_blob_manager.delete_blob(head_blob_hash) + self.client_config.blob_download_timeout = 0.1 + stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) + await stream.finished_write_attempt.wait() + self.assertEqual('stopped', stream.status) + self.assertIsNone(stream.full_path) + self.assertEqual(0, stream.written_bytes) + + self.stream_manager.stop() + await self.stream_manager.start() + self.assertEqual(1, len(self.stream_manager.streams)) + stream = list(self.stream_manager.streams.values())[0] + self.assertEqual('stopped', stream.status) + self.assertIsNone(stream.full_path) + self.assertEqual(0, stream.written_bytes) + async def test_download_then_recover_stream_on_startup(self, old_sort=False): expected_analytics_events = [ 'Time To First Bytes',