forked from LBRYCommunity/lbry-sdk
Make stop(), stop_tasks() consistently async routines, and have stop_tasks()
wait for file_output_task completion. This fixes a problem with test_download_stop_resume_delete.
This commit is contained in:
parent
ea4fba39a6
commit
585962d930
6 changed files with 18 additions and 17 deletions
|
@ -67,7 +67,7 @@ class ManagedDownloadSource:
|
||||||
async def save_file(self, file_name: Optional[str] = None, download_directory: Optional[str] = None):
|
async def save_file(self, file_name: Optional[str] = None, download_directory: Optional[str] = None):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def stop_tasks(self):
|
async def stop_tasks(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def set_claim(self, claim_info: typing.Dict, claim: 'Claim'):
|
def set_claim(self, claim_info: typing.Dict, claim: 'Claim'):
|
||||||
|
|
|
@ -59,11 +59,11 @@ class SourceManager:
|
||||||
def add(self, source: ManagedDownloadSource):
|
def add(self, source: ManagedDownloadSource):
|
||||||
self._sources[source.identifier] = source
|
self._sources[source.identifier] = source
|
||||||
|
|
||||||
def remove(self, source: ManagedDownloadSource):
|
async def remove(self, source: ManagedDownloadSource):
|
||||||
if source.identifier not in self._sources:
|
if source.identifier not in self._sources:
|
||||||
return
|
return
|
||||||
self._sources.pop(source.identifier)
|
self._sources.pop(source.identifier)
|
||||||
source.stop_tasks()
|
await source.stop_tasks()
|
||||||
|
|
||||||
async def initialize_from_database(self):
|
async def initialize_from_database(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@ -72,10 +72,10 @@ class SourceManager:
|
||||||
await self.initialize_from_database()
|
await self.initialize_from_database()
|
||||||
self.started.set()
|
self.started.set()
|
||||||
|
|
||||||
def stop(self):
|
async def stop(self):
|
||||||
while self._sources:
|
while self._sources:
|
||||||
_, source = self._sources.popitem()
|
_, source = self._sources.popitem()
|
||||||
source.stop_tasks()
|
await source.stop_tasks()
|
||||||
self.started.clear()
|
self.started.clear()
|
||||||
|
|
||||||
async def create(self, file_path: str, key: Optional[bytes] = None,
|
async def create(self, file_path: str, key: Optional[bytes] = None,
|
||||||
|
@ -83,7 +83,7 @@ class SourceManager:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):
|
async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):
|
||||||
self.remove(source)
|
await self.remove(source)
|
||||||
if delete_file and source.output_file_exists:
|
if delete_file and source.output_file_exists:
|
||||||
os.remove(source.full_path)
|
os.remove(source.full_path)
|
||||||
|
|
||||||
|
|
|
@ -191,7 +191,7 @@ class ManagedStream(ManagedDownloadSource):
|
||||||
Stop any running save/stream tasks as well as the downloader and update the status in the database
|
Stop any running save/stream tasks as well as the downloader and update the status in the database
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.stop_tasks()
|
await self.stop_tasks()
|
||||||
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
|
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
|
||||||
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
|
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
|
||||||
|
|
||||||
|
@ -324,12 +324,13 @@ class ManagedStream(ManagedDownloadSource):
|
||||||
await asyncio.wait_for(self.started_writing.wait(), self.config.download_timeout)
|
await asyncio.wait_for(self.started_writing.wait(), self.config.download_timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
log.warning("timeout starting to write data for lbry://%s#%s", self.claim_name, self.claim_id)
|
log.warning("timeout starting to write data for lbry://%s#%s", self.claim_name, self.claim_id)
|
||||||
self.stop_tasks()
|
await self.stop_tasks()
|
||||||
await self.update_status(ManagedStream.STATUS_STOPPED)
|
await self.update_status(ManagedStream.STATUS_STOPPED)
|
||||||
|
|
||||||
def stop_tasks(self):
|
async def stop_tasks(self):
|
||||||
if self.file_output_task and not self.file_output_task.done():
|
if self.file_output_task and not self.file_output_task.done():
|
||||||
self.file_output_task.cancel()
|
self.file_output_task.cancel()
|
||||||
|
await asyncio.gather(self.file_output_task, return_exceptions=True)
|
||||||
self.file_output_task = None
|
self.file_output_task = None
|
||||||
while self.streaming_responses:
|
while self.streaming_responses:
|
||||||
req, response = self.streaming_responses.pop()
|
req, response = self.streaming_responses.pop()
|
||||||
|
|
|
@ -196,8 +196,8 @@ class StreamManager(SourceManager):
|
||||||
await super().start()
|
await super().start()
|
||||||
self.re_reflect_task = self.loop.create_task(self.reflect_streams())
|
self.re_reflect_task = self.loop.create_task(self.reflect_streams())
|
||||||
|
|
||||||
def stop(self):
|
async def stop(self):
|
||||||
super().stop()
|
await super().stop()
|
||||||
if self.resume_saving_task and not self.resume_saving_task.done():
|
if self.resume_saving_task and not self.resume_saving_task.done():
|
||||||
self.resume_saving_task.cancel()
|
self.resume_saving_task.cancel()
|
||||||
if self.re_reflect_task and not self.re_reflect_task.done():
|
if self.re_reflect_task and not self.re_reflect_task.done():
|
||||||
|
@ -260,7 +260,7 @@ class StreamManager(SourceManager):
|
||||||
return
|
return
|
||||||
if source.identifier in self.running_reflector_uploads:
|
if source.identifier in self.running_reflector_uploads:
|
||||||
self.running_reflector_uploads[source.identifier].cancel()
|
self.running_reflector_uploads[source.identifier].cancel()
|
||||||
source.stop_tasks()
|
await source.stop_tasks()
|
||||||
if source.identifier in self.streams:
|
if source.identifier in self.streams:
|
||||||
del self.streams[source.identifier]
|
del self.streams[source.identifier]
|
||||||
blob_hashes = [source.identifier] + [b.blob_hash for b in source.descriptor.blobs[:-1]]
|
blob_hashes = [source.identifier] + [b.blob_hash for b in source.descriptor.blobs[:-1]]
|
||||||
|
|
|
@ -74,7 +74,7 @@ class TorrentSource(ManagedDownloadSource):
|
||||||
def bt_infohash(self):
|
def bt_infohash(self):
|
||||||
return self.identifier
|
return self.identifier
|
||||||
|
|
||||||
def stop_tasks(self):
|
async def stop_tasks(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -118,8 +118,8 @@ class TorrentManager(SourceManager):
|
||||||
async def start(self):
|
async def start(self):
|
||||||
await super().start()
|
await super().start()
|
||||||
|
|
||||||
def stop(self):
|
async def stop(self):
|
||||||
super().stop()
|
await super().stop()
|
||||||
log.info("finished stopping the torrent manager")
|
log.info("finished stopping the torrent manager")
|
||||||
|
|
||||||
async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):
|
async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):
|
||||||
|
|
|
@ -424,7 +424,7 @@ class TestStreamManager(BlobExchangeTestBase):
|
||||||
self.assertIsNone(stream.full_path)
|
self.assertIsNone(stream.full_path)
|
||||||
self.assertEqual(0, stream.written_bytes)
|
self.assertEqual(0, stream.written_bytes)
|
||||||
|
|
||||||
self.stream_manager.stop()
|
await self.stream_manager.stop()
|
||||||
await self.stream_manager.start()
|
await self.stream_manager.start()
|
||||||
self.assertEqual(1, len(self.stream_manager.streams))
|
self.assertEqual(1, len(self.stream_manager.streams))
|
||||||
stream = list(self.stream_manager.streams.values())[0]
|
stream = list(self.stream_manager.streams.values())[0]
|
||||||
|
@ -449,7 +449,7 @@ class TestStreamManager(BlobExchangeTestBase):
|
||||||
stream = await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager)
|
stream = await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager)
|
||||||
await stream.finished_writing.wait()
|
await stream.finished_writing.wait()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
self.stream_manager.stop()
|
await self.stream_manager.stop()
|
||||||
self.client_blob_manager.stop()
|
self.client_blob_manager.stop()
|
||||||
# partial removal, only sd blob is missing.
|
# partial removal, only sd blob is missing.
|
||||||
# in this case, we recover the sd blob while the other blobs are kept untouched as 'finished'
|
# in this case, we recover the sd blob while the other blobs are kept untouched as 'finished'
|
||||||
|
|
Loading…
Reference in a new issue