forked from LBRYCommunity/lbry-sdk
Make stop(), stop_tasks() consistently async routines, and have stop_tasks()
wait for file_output_task completion. This fixes a problem with test_download_stop_resume_delete.
This commit is contained in:
parent
ea4fba39a6
commit
585962d930
6 changed files with 18 additions and 17 deletions
|
@ -67,7 +67,7 @@ class ManagedDownloadSource:
|
|||
async def save_file(self, file_name: Optional[str] = None, download_directory: Optional[str] = None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def stop_tasks(self):
|
||||
async def stop_tasks(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def set_claim(self, claim_info: typing.Dict, claim: 'Claim'):
|
||||
|
|
|
@ -59,11 +59,11 @@ class SourceManager:
|
|||
def add(self, source: ManagedDownloadSource):
|
||||
self._sources[source.identifier] = source
|
||||
|
||||
def remove(self, source: ManagedDownloadSource):
|
||||
async def remove(self, source: ManagedDownloadSource):
|
||||
if source.identifier not in self._sources:
|
||||
return
|
||||
self._sources.pop(source.identifier)
|
||||
source.stop_tasks()
|
||||
await source.stop_tasks()
|
||||
|
||||
async def initialize_from_database(self):
|
||||
raise NotImplementedError()
|
||||
|
@ -72,10 +72,10 @@ class SourceManager:
|
|||
await self.initialize_from_database()
|
||||
self.started.set()
|
||||
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
while self._sources:
|
||||
_, source = self._sources.popitem()
|
||||
source.stop_tasks()
|
||||
await source.stop_tasks()
|
||||
self.started.clear()
|
||||
|
||||
async def create(self, file_path: str, key: Optional[bytes] = None,
|
||||
|
@ -83,7 +83,7 @@ class SourceManager:
|
|||
raise NotImplementedError()
|
||||
|
||||
async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):
|
||||
self.remove(source)
|
||||
await self.remove(source)
|
||||
if delete_file and source.output_file_exists:
|
||||
os.remove(source.full_path)
|
||||
|
||||
|
|
|
@ -191,7 +191,7 @@ class ManagedStream(ManagedDownloadSource):
|
|||
Stop any running save/stream tasks as well as the downloader and update the status in the database
|
||||
"""
|
||||
|
||||
self.stop_tasks()
|
||||
await self.stop_tasks()
|
||||
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
|
||||
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
|
||||
|
||||
|
@ -324,12 +324,13 @@ class ManagedStream(ManagedDownloadSource):
|
|||
await asyncio.wait_for(self.started_writing.wait(), self.config.download_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("timeout starting to write data for lbry://%s#%s", self.claim_name, self.claim_id)
|
||||
self.stop_tasks()
|
||||
await self.stop_tasks()
|
||||
await self.update_status(ManagedStream.STATUS_STOPPED)
|
||||
|
||||
def stop_tasks(self):
|
||||
async def stop_tasks(self):
|
||||
if self.file_output_task and not self.file_output_task.done():
|
||||
self.file_output_task.cancel()
|
||||
await asyncio.gather(self.file_output_task, return_exceptions=True)
|
||||
self.file_output_task = None
|
||||
while self.streaming_responses:
|
||||
req, response = self.streaming_responses.pop()
|
||||
|
|
|
@ -196,8 +196,8 @@ class StreamManager(SourceManager):
|
|||
await super().start()
|
||||
self.re_reflect_task = self.loop.create_task(self.reflect_streams())
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
if self.resume_saving_task and not self.resume_saving_task.done():
|
||||
self.resume_saving_task.cancel()
|
||||
if self.re_reflect_task and not self.re_reflect_task.done():
|
||||
|
@ -260,7 +260,7 @@ class StreamManager(SourceManager):
|
|||
return
|
||||
if source.identifier in self.running_reflector_uploads:
|
||||
self.running_reflector_uploads[source.identifier].cancel()
|
||||
source.stop_tasks()
|
||||
await source.stop_tasks()
|
||||
if source.identifier in self.streams:
|
||||
del self.streams[source.identifier]
|
||||
blob_hashes = [source.identifier] + [b.blob_hash for b in source.descriptor.blobs[:-1]]
|
||||
|
|
|
@ -74,7 +74,7 @@ class TorrentSource(ManagedDownloadSource):
|
|||
def bt_infohash(self):
|
||||
return self.identifier
|
||||
|
||||
def stop_tasks(self):
|
||||
async def stop_tasks(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
|
@ -118,8 +118,8 @@ class TorrentManager(SourceManager):
|
|||
async def start(self):
|
||||
await super().start()
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
log.info("finished stopping the torrent manager")
|
||||
|
||||
async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):
|
||||
|
|
|
@ -424,7 +424,7 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
self.assertIsNone(stream.full_path)
|
||||
self.assertEqual(0, stream.written_bytes)
|
||||
|
||||
self.stream_manager.stop()
|
||||
await self.stream_manager.stop()
|
||||
await self.stream_manager.start()
|
||||
self.assertEqual(1, len(self.stream_manager.streams))
|
||||
stream = list(self.stream_manager.streams.values())[0]
|
||||
|
@ -449,7 +449,7 @@ class TestStreamManager(BlobExchangeTestBase):
|
|||
stream = await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager)
|
||||
await stream.finished_writing.wait()
|
||||
await asyncio.sleep(0)
|
||||
self.stream_manager.stop()
|
||||
await self.stream_manager.stop()
|
||||
self.client_blob_manager.stop()
|
||||
# partial removal, only sd blob is missing.
|
||||
# in this case, we recover the sd blob while the other blobs are kept untouched as 'finished'
|
||||
|
|
Loading…
Reference in a new issue