diff --git a/lbry/file/source.py b/lbry/file/source.py index ba5bb311f..fa1a67cec 100644 --- a/lbry/file/source.py +++ b/lbry/file/source.py @@ -67,7 +67,7 @@ class ManagedDownloadSource: async def save_file(self, file_name: Optional[str] = None, download_directory: Optional[str] = None): raise NotImplementedError() - def stop_tasks(self): + async def stop_tasks(self): raise NotImplementedError() def set_claim(self, claim_info: typing.Dict, claim: 'Claim'): diff --git a/lbry/file/source_manager.py b/lbry/file/source_manager.py index 72c1709dd..f3e82532e 100644 --- a/lbry/file/source_manager.py +++ b/lbry/file/source_manager.py @@ -59,11 +59,11 @@ class SourceManager: def add(self, source: ManagedDownloadSource): self._sources[source.identifier] = source - def remove(self, source: ManagedDownloadSource): + async def remove(self, source: ManagedDownloadSource): if source.identifier not in self._sources: return self._sources.pop(source.identifier) - source.stop_tasks() + await source.stop_tasks() async def initialize_from_database(self): raise NotImplementedError() @@ -72,10 +72,10 @@ class SourceManager: await self.initialize_from_database() self.started.set() - def stop(self): + async def stop(self): while self._sources: _, source = self._sources.popitem() - source.stop_tasks() + await source.stop_tasks() self.started.clear() async def create(self, file_path: str, key: Optional[bytes] = None, @@ -83,7 +83,7 @@ class SourceManager: raise NotImplementedError() async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False): - self.remove(source) + await self.remove(source) if delete_file and source.output_file_exists: os.remove(source.full_path) diff --git a/lbry/stream/managed_stream.py b/lbry/stream/managed_stream.py index 6f12bb63f..7a4b69093 100644 --- a/lbry/stream/managed_stream.py +++ b/lbry/stream/managed_stream.py @@ -191,7 +191,7 @@ class ManagedStream(ManagedDownloadSource): Stop any running save/stream tasks as well as the downloader and update the status in the database """ - self.stop_tasks() + await self.stop_tasks() if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING: await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED) @@ -324,12 +324,13 @@ class ManagedStream(ManagedDownloadSource): await asyncio.wait_for(self.started_writing.wait(), self.config.download_timeout) except asyncio.TimeoutError: log.warning("timeout starting to write data for lbry://%s#%s", self.claim_name, self.claim_id) - self.stop_tasks() + await self.stop_tasks() await self.update_status(ManagedStream.STATUS_STOPPED) - def stop_tasks(self): + async def stop_tasks(self): if self.file_output_task and not self.file_output_task.done(): self.file_output_task.cancel() + await asyncio.gather(self.file_output_task, return_exceptions=True) self.file_output_task = None while self.streaming_responses: req, response = self.streaming_responses.pop() diff --git a/lbry/stream/stream_manager.py b/lbry/stream/stream_manager.py index bc37a1cc9..2379c1185 100644 --- a/lbry/stream/stream_manager.py +++ b/lbry/stream/stream_manager.py @@ -196,8 +196,8 @@ class StreamManager(SourceManager): await super().start() self.re_reflect_task = self.loop.create_task(self.reflect_streams()) - def stop(self): - super().stop() + async def stop(self): + await super().stop() if self.resume_saving_task and not self.resume_saving_task.done(): self.resume_saving_task.cancel() if self.re_reflect_task and not self.re_reflect_task.done(): @@ -260,7 +260,7 @@ class StreamManager(SourceManager): return if source.identifier in self.running_reflector_uploads: self.running_reflector_uploads[source.identifier].cancel() - source.stop_tasks() + await source.stop_tasks() if source.identifier in self.streams: del self.streams[source.identifier] blob_hashes = [source.identifier] + [b.blob_hash for b in source.descriptor.blobs[:-1]] diff --git a/lbry/torrent/torrent_manager.py b/lbry/torrent/torrent_manager.py index cf9106731..0e0fc8882 100644 --- a/lbry/torrent/torrent_manager.py +++ b/lbry/torrent/torrent_manager.py @@ -74,7 +74,7 @@ class TorrentSource(ManagedDownloadSource): def bt_infohash(self): return self.identifier - def stop_tasks(self): + async def stop_tasks(self): pass @property @@ -118,8 +118,8 @@ class TorrentManager(SourceManager): async def start(self): await super().start() - def stop(self): - super().stop() + async def stop(self): + await super().stop() log.info("finished stopping the torrent manager") async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False): diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py index ba6d8dbc8..98e3fdd26 100644 --- a/tests/unit/stream/test_stream_manager.py +++ b/tests/unit/stream/test_stream_manager.py @@ -424,7 +424,7 @@ class TestStreamManager(BlobExchangeTestBase): self.assertIsNone(stream.full_path) self.assertEqual(0, stream.written_bytes) - self.stream_manager.stop() + await self.stream_manager.stop() await self.stream_manager.start() self.assertEqual(1, len(self.stream_manager.streams)) stream = list(self.stream_manager.streams.values())[0] @@ -449,7 +449,7 @@ class TestStreamManager(BlobExchangeTestBase): stream = await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager) await stream.finished_writing.wait() await asyncio.sleep(0) - self.stream_manager.stop() + await self.stream_manager.stop() self.client_blob_manager.stop() # partial removal, only sd blob is missing. # in this case, we recover the sd blob while the other blobs are kept untouched as 'finished'