set verified event earlier and remove stop awaits
This commit is contained in:
parent
832be0758b
commit
ca4a55ef28
7 changed files with 23 additions and 26 deletions
|
@ -459,7 +459,7 @@ class StreamManagerComponent(Component):
|
|||
log.info('Done setting up file manager')
|
||||
|
||||
async def stop(self):
|
||||
await self.stream_manager.stop()
|
||||
self.stream_manager.stop()
|
||||
|
||||
|
||||
class PeerProtocolServerComponent(Component):
|
||||
|
|
|
@ -1621,7 +1621,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
stream.downloader.download(self.dht_node)
|
||||
msg = "Resumed download"
|
||||
elif status == 'stop' and stream.running:
|
||||
await stream.stop_download()
|
||||
stream.stop_download()
|
||||
msg = "Stopped download"
|
||||
else:
|
||||
msg = (
|
||||
|
|
|
@ -43,13 +43,11 @@ class StreamAssembler:
|
|||
self.written_bytes: int = 0
|
||||
|
||||
async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str):
|
||||
offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1)
|
||||
if not blob or self.stream_handle.closed:
|
||||
return False
|
||||
|
||||
def _decrypt_and_write():
|
||||
if self.stream_handle.closed:
|
||||
return False
|
||||
if not blob:
|
||||
return False
|
||||
offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1)
|
||||
self.stream_handle.seek(offset)
|
||||
_decrypted = blob.decrypt(
|
||||
binascii.unhexlify(key), binascii.unhexlify(blob_info.iv.encode())
|
||||
|
@ -57,12 +55,10 @@ class StreamAssembler:
|
|||
self.stream_handle.write(_decrypted)
|
||||
self.stream_handle.flush()
|
||||
self.written_bytes += len(_decrypted)
|
||||
return True
|
||||
|
||||
decrypted = await self.loop.run_in_executor(None, _decrypt_and_write)
|
||||
if decrypted:
|
||||
log.debug("decrypted %s", blob.blob_hash[:8])
|
||||
return
|
||||
self.wrote_bytes_event.set()
|
||||
|
||||
await self.loop.run_in_executor(None, _decrypt_and_write)
|
||||
|
||||
async def setup(self):
|
||||
pass
|
||||
|
@ -106,8 +102,6 @@ class StreamAssembler:
|
|||
self.descriptor.sd_hash)
|
||||
continue
|
||||
|
||||
if not self.wrote_bytes_event.is_set():
|
||||
self.wrote_bytes_event.set()
|
||||
self.stream_finished_event.set()
|
||||
await self.after_finished()
|
||||
finally:
|
||||
|
|
|
@ -52,11 +52,11 @@ class StreamDownloader(StreamAssembler):
|
|||
log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
|
||||
await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')
|
||||
|
||||
async def stop(self):
|
||||
if self.accumulate_task and not self.accumulate_task.done():
|
||||
def stop(self):
|
||||
if self.accumulate_task:
|
||||
self.accumulate_task.cancel()
|
||||
self.accumulate_task = None
|
||||
if self.assemble_task and not self.assemble_task.done():
|
||||
if self.assemble_task:
|
||||
self.assemble_task.cancel()
|
||||
self.assemble_task = None
|
||||
if self.fixed_peers_handle:
|
||||
|
@ -80,7 +80,7 @@ class StreamDownloader(StreamAssembler):
|
|||
and self.node
|
||||
and len(self.node.protocol.routing_table.get_peers())
|
||||
) else 0.0,
|
||||
self.loop.create_task, _add_fixed_peers()
|
||||
lambda: self.loop.create_task(_add_fixed_peers())
|
||||
)
|
||||
|
||||
def download(self, node: typing.Optional['Node'] = None):
|
||||
|
|
|
@ -155,9 +155,9 @@ class ManagedStream:
|
|||
return cls(loop, blob_manager, descriptor, os.path.dirname(file_path), os.path.basename(file_path),
|
||||
status=cls.STATUS_FINISHED)
|
||||
|
||||
async def stop_download(self):
|
||||
def stop_download(self):
|
||||
if self.downloader:
|
||||
await self.downloader.stop()
|
||||
self.downloader.stop()
|
||||
if not self.finished:
|
||||
self.update_status(self.STATUS_STOPPED)
|
||||
|
||||
|
|
|
@ -122,12 +122,12 @@ class StreamManager:
|
|||
await self.load_streams_from_database()
|
||||
self.resume_downloading_task = self.loop.create_task(self.resume())
|
||||
|
||||
async def stop(self):
|
||||
def stop(self):
|
||||
if self.resume_downloading_task and not self.resume_downloading_task.done():
|
||||
self.resume_downloading_task.cancel()
|
||||
while self.streams:
|
||||
stream = self.streams.pop()
|
||||
await stream.stop_download()
|
||||
stream.stop_download()
|
||||
while self.update_stream_finished_futs:
|
||||
self.update_stream_finished_futs.pop().cancel()
|
||||
|
||||
|
@ -142,7 +142,7 @@ class StreamManager:
|
|||
return stream
|
||||
|
||||
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
|
||||
await stream.stop_download()
|
||||
stream.stop_download()
|
||||
self.streams.remove(stream)
|
||||
await self.storage.delete_stream(stream.descriptor)
|
||||
|
||||
|
@ -182,7 +182,7 @@ class StreamManager:
|
|||
log.info("got descriptor %s for %s", claim.source_hash.decode(), claim_info['name'])
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
log.info("stream timeout")
|
||||
await downloader.stop()
|
||||
downloader.stop()
|
||||
log.info("stopped stream")
|
||||
return
|
||||
if not await self.blob_manager.storage.stream_exists(downloader.sd_hash):
|
||||
|
@ -204,7 +204,8 @@ class StreamManager:
|
|||
self.wait_for_stream_finished(stream)
|
||||
return stream
|
||||
except asyncio.CancelledError:
|
||||
await downloader.stop()
|
||||
downloader.stop()
|
||||
log.info("stopped stream")
|
||||
|
||||
async def download_stream_from_claim(self, node: 'Node', claim_info: typing.Dict,
|
||||
file_name: typing.Optional[str] = None,
|
||||
|
|
|
@ -40,10 +40,12 @@ class TestStreamDownloader(BlobExchangeTestBase):
|
|||
|
||||
self.downloader.download(mock_node)
|
||||
await self.downloader.stream_finished_event.wait()
|
||||
await self.downloader.stop()
|
||||
self.downloader.stop()
|
||||
self.assertTrue(os.path.isfile(self.downloader.output_path))
|
||||
with open(self.downloader.output_path, 'rb') as f:
|
||||
self.assertEqual(f.read(), self.stream_bytes)
|
||||
await asyncio.sleep(0.01)
|
||||
self.assertTrue(self.downloader.stream_handle.closed)
|
||||
|
||||
async def test_transfer_stream(self):
|
||||
await self._test_transfer_stream(10)
|
||||
|
|
Loading…
Reference in a new issue