forked from LBRYCommunity/lbry-sdk
Merge pull request #1841 from lbryio/general_fixes_downloader
General fixes downloader
This commit is contained in:
commit
3eac99c509
10 changed files with 36 additions and 42 deletions
|
@ -77,7 +77,7 @@ class BlobFile:
|
|||
def writer_finished(self, writer: HashBlobWriter):
|
||||
def callback(finished: asyncio.Future):
|
||||
try:
|
||||
error = finished.result()
|
||||
error = finished.exception()
|
||||
except Exception as err:
|
||||
error = err
|
||||
if writer in self.writers: # remove this download attempt
|
||||
|
@ -86,7 +86,7 @@ class BlobFile:
|
|||
while self.writers:
|
||||
other = self.writers.pop()
|
||||
other.finished.cancel()
|
||||
t = self.loop.create_task(self.save_verified_blob(writer))
|
||||
t = self.loop.create_task(self.save_verified_blob(writer, finished.result()))
|
||||
t.add_done_callback(lambda *_: self.finished_writing.set())
|
||||
return
|
||||
if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
|
||||
|
@ -96,24 +96,24 @@ class BlobFile:
|
|||
raise error
|
||||
return callback
|
||||
|
||||
async def save_verified_blob(self, writer):
|
||||
async def save_verified_blob(self, writer, verified_bytes: bytes):
|
||||
def _save_verified():
|
||||
# log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}")
|
||||
if not self.saved_verified_blob and not os.path.isfile(self.file_path):
|
||||
if self.get_length() == len(writer.verified_bytes):
|
||||
if self.get_length() == len(verified_bytes):
|
||||
with open(self.file_path, 'wb') as write_handle:
|
||||
write_handle.write(writer.verified_bytes)
|
||||
write_handle.write(verified_bytes)
|
||||
self.saved_verified_blob = True
|
||||
else:
|
||||
raise Exception("length mismatch")
|
||||
|
||||
async with self.blob_write_lock:
|
||||
if self.verified.is_set():
|
||||
return
|
||||
async with self.blob_write_lock:
|
||||
await self.loop.run_in_executor(None, _save_verified)
|
||||
self.verified.set()
|
||||
if self.blob_completed_callback:
|
||||
await self.blob_completed_callback(self)
|
||||
self.verified.set()
|
||||
|
||||
def open_for_writing(self) -> HashBlobWriter:
|
||||
if os.path.exists(self.file_path):
|
||||
|
|
|
@ -18,7 +18,6 @@ class HashBlobWriter:
|
|||
self.finished.add_done_callback(lambda *_: self.close_handle())
|
||||
self._hashsum = get_lbry_hash_obj()
|
||||
self.len_so_far = 0
|
||||
self.verified_bytes = b''
|
||||
|
||||
def __del__(self):
|
||||
if self.buffer is not None:
|
||||
|
@ -46,7 +45,7 @@ class HashBlobWriter:
|
|||
self.len_so_far += len(data)
|
||||
if self.len_so_far > expected_length:
|
||||
self.close_handle()
|
||||
self.finished.set_result(InvalidDataError(
|
||||
self.finished.set_exception(InvalidDataError(
|
||||
f'Length so far is greater than the expected length. {self.len_so_far} to {expected_length}'
|
||||
))
|
||||
return
|
||||
|
@ -55,15 +54,12 @@ class HashBlobWriter:
|
|||
blob_hash = self.calculate_blob_hash()
|
||||
if blob_hash != self.expected_blob_hash:
|
||||
self.close_handle()
|
||||
self.finished.set_result(InvalidBlobHashError(
|
||||
self.finished.set_exception(InvalidBlobHashError(
|
||||
f"blob hash is {blob_hash} vs expected {self.expected_blob_hash}"
|
||||
))
|
||||
return
|
||||
self.buffer.seek(0)
|
||||
self.verified_bytes = self.buffer.read()
|
||||
elif self.finished and not (self.finished.done() or self.finished.cancelled()):
|
||||
self.finished.set_result(self.buffer.getvalue())
|
||||
self.close_handle()
|
||||
if self.finished and not (self.finished.done() or self.finished.cancelled()):
|
||||
self.finished.set_result(None)
|
||||
|
||||
def close_handle(self):
|
||||
if self.buffer is not None:
|
||||
|
|
|
@ -57,6 +57,7 @@ class BlobDownloader:
|
|||
await asyncio.wait(tasks, loop=self.loop, return_when='FIRST_COMPLETED')
|
||||
except asyncio.CancelledError:
|
||||
drain_tasks(tasks)
|
||||
raise
|
||||
|
||||
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
|
||||
blob = self.blob_manager.get_blob(blob_hash, length)
|
||||
|
|
|
@ -459,7 +459,7 @@ class StreamManagerComponent(Component):
|
|||
log.info('Done setting up file manager')
|
||||
|
||||
async def stop(self):
|
||||
await self.stream_manager.stop()
|
||||
self.stream_manager.stop()
|
||||
|
||||
|
||||
class PeerProtocolServerComponent(Component):
|
||||
|
|
|
@ -1621,7 +1621,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
stream.downloader.download(self.dht_node)
|
||||
msg = "Resumed download"
|
||||
elif status == 'stop' and stream.running:
|
||||
await stream.stop_download()
|
||||
stream.stop_download()
|
||||
msg = "Stopped download"
|
||||
else:
|
||||
msg = (
|
||||
|
|
|
@ -43,13 +43,11 @@ class StreamAssembler:
|
|||
self.written_bytes: int = 0
|
||||
|
||||
async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str):
|
||||
offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1)
|
||||
if not blob or self.stream_handle.closed:
|
||||
return False
|
||||
|
||||
def _decrypt_and_write():
|
||||
if self.stream_handle.closed:
|
||||
return False
|
||||
if not blob:
|
||||
return False
|
||||
offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1)
|
||||
self.stream_handle.seek(offset)
|
||||
_decrypted = blob.decrypt(
|
||||
binascii.unhexlify(key), binascii.unhexlify(blob_info.iv.encode())
|
||||
|
@ -57,12 +55,10 @@ class StreamAssembler:
|
|||
self.stream_handle.write(_decrypted)
|
||||
self.stream_handle.flush()
|
||||
self.written_bytes += len(_decrypted)
|
||||
return True
|
||||
|
||||
decrypted = await self.loop.run_in_executor(None, _decrypt_and_write)
|
||||
if decrypted:
|
||||
log.debug("decrypted %s", blob.blob_hash[:8])
|
||||
return
|
||||
self.wrote_bytes_event.set()
|
||||
|
||||
await self.loop.run_in_executor(None, _decrypt_and_write)
|
||||
|
||||
async def setup(self):
|
||||
pass
|
||||
|
@ -106,8 +102,6 @@ class StreamAssembler:
|
|||
self.descriptor.sd_hash)
|
||||
continue
|
||||
|
||||
if not self.wrote_bytes_event.is_set():
|
||||
self.wrote_bytes_event.set()
|
||||
self.stream_finished_event.set()
|
||||
await self.after_finished()
|
||||
finally:
|
||||
|
|
|
@ -52,11 +52,11 @@ class StreamDownloader(StreamAssembler):
|
|||
log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
|
||||
await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')
|
||||
|
||||
async def stop(self):
|
||||
if self.accumulate_task and not self.accumulate_task.done():
|
||||
def stop(self):
|
||||
if self.accumulate_task:
|
||||
self.accumulate_task.cancel()
|
||||
self.accumulate_task = None
|
||||
if self.assemble_task and not self.assemble_task.done():
|
||||
if self.assemble_task:
|
||||
self.assemble_task.cancel()
|
||||
self.assemble_task = None
|
||||
if self.fixed_peers_handle:
|
||||
|
@ -80,7 +80,7 @@ class StreamDownloader(StreamAssembler):
|
|||
and self.node
|
||||
and len(self.node.protocol.routing_table.get_peers())
|
||||
) else 0.0,
|
||||
self.loop.create_task, _add_fixed_peers()
|
||||
lambda: self.loop.create_task(_add_fixed_peers())
|
||||
)
|
||||
|
||||
def download(self, node: typing.Optional['Node'] = None):
|
||||
|
|
|
@ -155,9 +155,9 @@ class ManagedStream:
|
|||
return cls(loop, blob_manager, descriptor, os.path.dirname(file_path), os.path.basename(file_path),
|
||||
status=cls.STATUS_FINISHED)
|
||||
|
||||
async def stop_download(self):
|
||||
def stop_download(self):
|
||||
if self.downloader:
|
||||
await self.downloader.stop()
|
||||
self.downloader.stop()
|
||||
if not self.finished:
|
||||
self.update_status(self.STATUS_STOPPED)
|
||||
|
||||
|
|
|
@ -122,12 +122,12 @@ class StreamManager:
|
|||
await self.load_streams_from_database()
|
||||
self.resume_downloading_task = self.loop.create_task(self.resume())
|
||||
|
||||
async def stop(self):
|
||||
def stop(self):
|
||||
if self.resume_downloading_task and not self.resume_downloading_task.done():
|
||||
self.resume_downloading_task.cancel()
|
||||
while self.streams:
|
||||
stream = self.streams.pop()
|
||||
await stream.stop_download()
|
||||
stream.stop_download()
|
||||
while self.update_stream_finished_futs:
|
||||
self.update_stream_finished_futs.pop().cancel()
|
||||
|
||||
|
@ -142,7 +142,7 @@ class StreamManager:
|
|||
return stream
|
||||
|
||||
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
|
||||
await stream.stop_download()
|
||||
stream.stop_download()
|
||||
self.streams.remove(stream)
|
||||
await self.storage.delete_stream(stream.descriptor)
|
||||
|
||||
|
@ -182,7 +182,7 @@ class StreamManager:
|
|||
log.info("got descriptor %s for %s", claim.source_hash.decode(), claim_info['name'])
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
log.info("stream timeout")
|
||||
await downloader.stop()
|
||||
downloader.stop()
|
||||
log.info("stopped stream")
|
||||
return
|
||||
if not await self.blob_manager.storage.stream_exists(downloader.sd_hash):
|
||||
|
@ -204,7 +204,8 @@ class StreamManager:
|
|||
self.wait_for_stream_finished(stream)
|
||||
return stream
|
||||
except asyncio.CancelledError:
|
||||
await downloader.stop()
|
||||
downloader.stop()
|
||||
log.debug("stopped stream")
|
||||
|
||||
async def download_stream_from_claim(self, node: 'Node', claim_info: typing.Dict,
|
||||
file_name: typing.Optional[str] = None,
|
||||
|
|
|
@ -40,10 +40,12 @@ class TestStreamDownloader(BlobExchangeTestBase):
|
|||
|
||||
self.downloader.download(mock_node)
|
||||
await self.downloader.stream_finished_event.wait()
|
||||
await self.downloader.stop()
|
||||
self.downloader.stop()
|
||||
self.assertTrue(os.path.isfile(self.downloader.output_path))
|
||||
with open(self.downloader.output_path, 'rb') as f:
|
||||
self.assertEqual(f.read(), self.stream_bytes)
|
||||
await asyncio.sleep(0.01)
|
||||
self.assertTrue(self.downloader.stream_handle.closed)
|
||||
|
||||
async def test_transfer_stream(self):
|
||||
await self._test_transfer_stream(10)
|
||||
|
|
Loading…
Reference in a new issue