forked from LBRYCommunity/lbry-sdk
Merge pull request #1841 from lbryio/general_fixes_downloader
General fixes downloader
commit 3eac99c509
10 changed files with 36 additions and 42 deletions
@@ -77,7 +77,7 @@ class BlobFile:
     def writer_finished(self, writer: HashBlobWriter):
         def callback(finished: asyncio.Future):
             try:
-                error = finished.result()
+                error = finished.exception()
             except Exception as err:
                 error = err
             if writer in self.writers:  # remove this download attempt
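Note on the change above: Future.exception() hands back the stored error as a value, where Future.result() would re-raise it; the surrounding try/except is still needed because exception() itself raises CancelledError when the writer's future was cancelled. A minimal standalone sketch of that behaviour (plain asyncio, nothing from this repository):

    import asyncio

    async def main():
        loop = asyncio.get_event_loop()

        failed = loop.create_future()
        failed.set_exception(ValueError("bad blob"))
        print(failed.exception())   # returns the ValueError instead of raising it

        cancelled = loop.create_future()
        cancelled.cancel()
        try:
            cancelled.exception()   # a cancelled future still raises here
        except asyncio.CancelledError:
            print("writer attempt was cancelled")

    asyncio.run(main())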
@@ -86,7 +86,7 @@ class BlobFile:
                 while self.writers:
                     other = self.writers.pop()
                     other.finished.cancel()
-                t = self.loop.create_task(self.save_verified_blob(writer))
+                t = self.loop.create_task(self.save_verified_blob(writer, finished.result()))
                 t.add_done_callback(lambda *_: self.finished_writing.set())
                 return
             if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
@@ -96,24 +96,24 @@ class BlobFile:
                     raise error
         return callback

-    async def save_verified_blob(self, writer):
+    async def save_verified_blob(self, writer, verified_bytes: bytes):
         def _save_verified():
             # log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}")
             if not self.saved_verified_blob and not os.path.isfile(self.file_path):
-                if self.get_length() == len(writer.verified_bytes):
+                if self.get_length() == len(verified_bytes):
                     with open(self.file_path, 'wb') as write_handle:
-                        write_handle.write(writer.verified_bytes)
+                        write_handle.write(verified_bytes)
                         self.saved_verified_blob = True
                 else:
                     raise Exception("length mismatch")

-        if self.verified.is_set():
-            return
-        async with self.blob_write_lock:
+        async with self.blob_write_lock:
+            if self.verified.is_set():
+                return
             await self.loop.run_in_executor(None, _save_verified)
+            self.verified.set()
             if self.blob_completed_callback:
                 await self.blob_completed_callback(self)
-        self.verified.set()

     def open_for_writing(self) -> HashBlobWriter:
         if os.path.exists(self.file_path):
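Note on save_verified_blob above: with both the verified.is_set() check and verified.set() moved inside blob_write_lock, the check, the threaded disk write, and the flag update form one critical section, so two writers finishing at nearly the same time cannot both reach the executor. A rough standalone sketch of the pattern (hypothetical Saver class, not the project's API):

    import asyncio

    class Saver:
        def __init__(self, path: str):
            self.loop = asyncio.get_event_loop()
            self.path = path
            self.write_lock = asyncio.Lock()
            self.verified = asyncio.Event()

        async def save(self, data: bytes):
            def _write():
                # blocking file I/O stays off the event loop thread
                with open(self.path, 'wb') as f:
                    f.write(data)

            async with self.write_lock:
                if self.verified.is_set():   # a concurrent save already finished
                    return
                await self.loop.run_in_executor(None, _write)
                self.verified.set()          # flip the flag before any completion callback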
@@ -18,7 +18,6 @@ class HashBlobWriter:
         self.finished.add_done_callback(lambda *_: self.close_handle())
         self._hashsum = get_lbry_hash_obj()
         self.len_so_far = 0
-        self.verified_bytes = b''

     def __del__(self):
         if self.buffer is not None:
@@ -46,7 +45,7 @@ class HashBlobWriter:
         self.len_so_far += len(data)
         if self.len_so_far > expected_length:
             self.close_handle()
-            self.finished.set_result(InvalidDataError(
+            self.finished.set_exception(InvalidDataError(
                 f'Length so far is greater than the expected length. {self.len_so_far} to {expected_length}'
             ))
             return
@@ -55,15 +54,12 @@ class HashBlobWriter:
             blob_hash = self.calculate_blob_hash()
             if blob_hash != self.expected_blob_hash:
                 self.close_handle()
-                self.finished.set_result(InvalidBlobHashError(
+                self.finished.set_exception(InvalidBlobHashError(
                     f"blob hash is {blob_hash} vs expected {self.expected_blob_hash}"
                 ))
-                return
-            self.buffer.seek(0)
-            self.verified_bytes = self.buffer.read()
+            elif self.finished and not (self.finished.done() or self.finished.cancelled()):
+                self.finished.set_result(self.buffer.getvalue())
             self.close_handle()
-            if self.finished and not (self.finished.done() or self.finished.cancelled()):
-                self.finished.set_result(None)

     def close_handle(self):
         if self.buffer is not None:
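Note on the writer changes above: failures now go through set_exception() and success delivers the buffered bytes via set_result(self.buffer.getvalue()), so whoever awaits (or reads) the writer's future either gets the verified payload or has the error surface there, which is what lets the BlobFile callback pass finished.result() straight into save_verified_blob. A small standalone illustration (hypothetical values):

    import asyncio

    async def main():
        loop = asyncio.get_event_loop()

        ok = loop.create_future()
        ok.set_result(b"verified blob bytes")
        print(await ok)              # b'verified blob bytes'

        bad = loop.create_future()
        bad.set_exception(ValueError("hash mismatch"))
        try:
            await bad                # the stored exception is raised at the await
        except ValueError as err:
            print("failed:", err)

    asyncio.run(main())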
@@ -57,6 +57,7 @@ class BlobDownloader:
                 await asyncio.wait(tasks, loop=self.loop, return_when='FIRST_COMPLETED')
             except asyncio.CancelledError:
                 drain_tasks(tasks)
+                raise

     async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
         blob = self.blob_manager.get_blob(blob_hash, length)
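Note on the added raise: re-raising after the cleanup lets the CancelledError keep propagating to whoever cancelled the download instead of being swallowed by the except block. A sketch of the shape, assuming drain_tasks() simply cancels whatever is still pending (an assumption, not taken from this diff):

    import asyncio

    def drain_tasks(tasks):
        while tasks:                 # assumed behaviour: cancel any leftover task
            task = tasks.pop()
            if task and not task.done():
                task.cancel()

    async def race(tasks):
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except asyncio.CancelledError:
            drain_tasks(tasks)
            raise                    # keep the cancellation moving up the stack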
@@ -459,7 +459,7 @@ class StreamManagerComponent(Component):
         log.info('Done setting up file manager')

     async def stop(self):
-        await self.stream_manager.stop()
+        self.stream_manager.stop()


 class PeerProtocolServerComponent(Component):
@@ -1621,7 +1621,7 @@ class Daemon(metaclass=JSONRPCServerType):
             stream.downloader.download(self.dht_node)
             msg = "Resumed download"
         elif status == 'stop' and stream.running:
-            await stream.stop_download()
+            stream.stop_download()
             msg = "Stopped download"
         else:
             msg = (
@@ -43,13 +43,11 @@ class StreamAssembler:
         self.written_bytes: int = 0

     async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str):
-        offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1)
+        if not blob or self.stream_handle.closed:
+            return False

         def _decrypt_and_write():
-            if self.stream_handle.closed:
-                return False
-            if not blob:
-                return False
+            offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1)
             self.stream_handle.seek(offset)
             _decrypted = blob.decrypt(
                 binascii.unhexlify(key), binascii.unhexlify(blob_info.iv.encode())
@@ -57,12 +55,10 @@ class StreamAssembler:
             self.stream_handle.write(_decrypted)
             self.stream_handle.flush()
             self.written_bytes += len(_decrypted)
-            return True
-
-        decrypted = await self.loop.run_in_executor(None, _decrypt_and_write)
-        if decrypted:
             log.debug("decrypted %s", blob.blob_hash[:8])
-        return
+            self.wrote_bytes_event.set()

+        await self.loop.run_in_executor(None, _decrypt_and_write)
+
     async def setup(self):
         pass
@@ -106,8 +102,6 @@ class StreamAssembler:
                                                           self.descriptor.sd_hash)
                         continue

-            if not self.wrote_bytes_event.is_set():
-                self.wrote_bytes_event.set()
             self.stream_finished_event.set()
             await self.after_finished()
         finally:
@@ -52,11 +52,11 @@ class StreamDownloader(StreamAssembler):
         log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
         await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')

-    async def stop(self):
-        if self.accumulate_task and not self.accumulate_task.done():
+    def stop(self):
+        if self.accumulate_task:
             self.accumulate_task.cancel()
             self.accumulate_task = None
-        if self.assemble_task and not self.assemble_task.done():
+        if self.assemble_task:
             self.assemble_task.cancel()
             self.assemble_task = None
         if self.fixed_peers_handle:
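Note on stop() above (and the other await removals in this commit): Task.cancel() is synchronous and is a harmless no-op on a task that has already finished, which is why stop() can become a plain method and the .done() guards can be dropped. For instance:

    import asyncio

    async def main():
        finished = asyncio.ensure_future(asyncio.sleep(0))
        await finished
        print(finished.cancel())   # False: cancelling a completed task does nothing

        pending = asyncio.ensure_future(asyncio.sleep(10))
        print(pending.cancel())    # True: cancellation is requested without awaiting

    asyncio.run(main())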
@@ -80,7 +80,7 @@ class StreamDownloader(StreamAssembler):
                 and self.node
                 and len(self.node.protocol.routing_table.get_peers())
             ) else 0.0,
-            self.loop.create_task, _add_fixed_peers()
+            lambda: self.loop.create_task(_add_fixed_peers())
         )

     def download(self, node: typing.Optional['Node'] = None):
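Note on the lambda above: loop.call_later() takes a plain callback, so both forms work, but the old form builds the _add_fixed_peers() coroutine object immediately and only schedules it after the delay; if the handle is cancelled first, that coroutine is never awaited. Wrapping it in a lambda defers creating the coroutine until the callback actually fires. This reading is inferred from the change rather than stated in the commit; a tiny sketch with a stand-in coroutine:

    import asyncio

    async def _add_fixed_peers():
        print("contacting fixed peers")

    async def main():
        loop = asyncio.get_event_loop()
        # the coroutine is only created if and when the callback runs
        handle = loop.call_later(0.05, lambda: loop.create_task(_add_fixed_peers()))
        await asyncio.sleep(0.1)
        handle.cancel()            # no-op here, the callback has already fired

    asyncio.run(main())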
@@ -155,9 +155,9 @@ class ManagedStream:
         return cls(loop, blob_manager, descriptor, os.path.dirname(file_path), os.path.basename(file_path),
                    status=cls.STATUS_FINISHED)

-    async def stop_download(self):
+    def stop_download(self):
         if self.downloader:
-            await self.downloader.stop()
+            self.downloader.stop()
         if not self.finished:
             self.update_status(self.STATUS_STOPPED)

@@ -122,12 +122,12 @@ class StreamManager:
         await self.load_streams_from_database()
         self.resume_downloading_task = self.loop.create_task(self.resume())

-    async def stop(self):
+    def stop(self):
         if self.resume_downloading_task and not self.resume_downloading_task.done():
             self.resume_downloading_task.cancel()
         while self.streams:
             stream = self.streams.pop()
-            await stream.stop_download()
+            stream.stop_download()
         while self.update_stream_finished_futs:
             self.update_stream_finished_futs.pop().cancel()

@@ -142,7 +142,7 @@ class StreamManager:
         return stream

     async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
-        await stream.stop_download()
+        stream.stop_download()
         self.streams.remove(stream)
         await self.storage.delete_stream(stream.descriptor)

@@ -182,7 +182,7 @@ class StreamManager:
             log.info("got descriptor %s for %s", claim.source_hash.decode(), claim_info['name'])
         except (asyncio.TimeoutError, asyncio.CancelledError):
             log.info("stream timeout")
-            await downloader.stop()
+            downloader.stop()
             log.info("stopped stream")
             return
         if not await self.blob_manager.storage.stream_exists(downloader.sd_hash):
@@ -204,7 +204,8 @@ class StreamManager:
             self.wait_for_stream_finished(stream)
             return stream
         except asyncio.CancelledError:
-            await downloader.stop()
+            downloader.stop()
+            log.debug("stopped stream")

     async def download_stream_from_claim(self, node: 'Node', claim_info: typing.Dict,
                                          file_name: typing.Optional[str] = None,
@@ -40,10 +40,12 @@ class TestStreamDownloader(BlobExchangeTestBase):

         self.downloader.download(mock_node)
         await self.downloader.stream_finished_event.wait()
-        await self.downloader.stop()
+        self.downloader.stop()
         self.assertTrue(os.path.isfile(self.downloader.output_path))
         with open(self.downloader.output_path, 'rb') as f:
             self.assertEqual(f.read(), self.stream_bytes)
+        await asyncio.sleep(0.01)
+        self.assertTrue(self.downloader.stream_handle.closed)

     async def test_transfer_stream(self):
         await self._test_transfer_stream(10)
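Note on the two added assertions: they check that the output handle really gets closed, and the short sleep is likely needed because done-callbacks (and the effects of task cancellation) are dispatched with call_soon and only run on a later iteration of the event loop. A standalone illustration of that scheduling:

    import asyncio

    async def main():
        loop = asyncio.get_event_loop()
        closed = False

        def close_handle(*_):
            nonlocal closed
            closed = True

        fut = loop.create_future()
        fut.add_done_callback(close_handle)
        fut.set_result(None)
        print(closed)              # False: the callback is queued, not run inline
        await asyncio.sleep(0.01)  # give the loop a turn, as the test does
        print(closed)              # True

    asyncio.run(main())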