Merge pull request #1841 from lbryio/general_fixes_downloader

General fixes downloader
This commit is contained in:
Jack Robison 2019-02-01 15:10:36 -05:00 committed by GitHub
commit 3eac99c509
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 36 additions and 42 deletions

View file

@ -77,7 +77,7 @@ class BlobFile:
def writer_finished(self, writer: HashBlobWriter): def writer_finished(self, writer: HashBlobWriter):
def callback(finished: asyncio.Future): def callback(finished: asyncio.Future):
try: try:
error = finished.result() error = finished.exception()
except Exception as err: except Exception as err:
error = err error = err
if writer in self.writers: # remove this download attempt if writer in self.writers: # remove this download attempt
@ -86,7 +86,7 @@ class BlobFile:
while self.writers: while self.writers:
other = self.writers.pop() other = self.writers.pop()
other.finished.cancel() other.finished.cancel()
t = self.loop.create_task(self.save_verified_blob(writer)) t = self.loop.create_task(self.save_verified_blob(writer, finished.result()))
t.add_done_callback(lambda *_: self.finished_writing.set()) t.add_done_callback(lambda *_: self.finished_writing.set())
return return
if isinstance(error, (InvalidBlobHashError, InvalidDataError)): if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
@ -96,24 +96,24 @@ class BlobFile:
raise error raise error
return callback return callback
async def save_verified_blob(self, writer): async def save_verified_blob(self, writer, verified_bytes: bytes):
def _save_verified(): def _save_verified():
# log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}") # log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}")
if not self.saved_verified_blob and not os.path.isfile(self.file_path): if not self.saved_verified_blob and not os.path.isfile(self.file_path):
if self.get_length() == len(writer.verified_bytes): if self.get_length() == len(verified_bytes):
with open(self.file_path, 'wb') as write_handle: with open(self.file_path, 'wb') as write_handle:
write_handle.write(writer.verified_bytes) write_handle.write(verified_bytes)
self.saved_verified_blob = True self.saved_verified_blob = True
else: else:
raise Exception("length mismatch") raise Exception("length mismatch")
async with self.blob_write_lock:
if self.verified.is_set(): if self.verified.is_set():
return return
async with self.blob_write_lock:
await self.loop.run_in_executor(None, _save_verified) await self.loop.run_in_executor(None, _save_verified)
self.verified.set()
if self.blob_completed_callback: if self.blob_completed_callback:
await self.blob_completed_callback(self) await self.blob_completed_callback(self)
self.verified.set()
def open_for_writing(self) -> HashBlobWriter: def open_for_writing(self) -> HashBlobWriter:
if os.path.exists(self.file_path): if os.path.exists(self.file_path):

View file

@ -18,7 +18,6 @@ class HashBlobWriter:
self.finished.add_done_callback(lambda *_: self.close_handle()) self.finished.add_done_callback(lambda *_: self.close_handle())
self._hashsum = get_lbry_hash_obj() self._hashsum = get_lbry_hash_obj()
self.len_so_far = 0 self.len_so_far = 0
self.verified_bytes = b''
def __del__(self): def __del__(self):
if self.buffer is not None: if self.buffer is not None:
@ -46,7 +45,7 @@ class HashBlobWriter:
self.len_so_far += len(data) self.len_so_far += len(data)
if self.len_so_far > expected_length: if self.len_so_far > expected_length:
self.close_handle() self.close_handle()
self.finished.set_result(InvalidDataError( self.finished.set_exception(InvalidDataError(
f'Length so far is greater than the expected length. {self.len_so_far} to {expected_length}' f'Length so far is greater than the expected length. {self.len_so_far} to {expected_length}'
)) ))
return return
@ -55,15 +54,12 @@ class HashBlobWriter:
blob_hash = self.calculate_blob_hash() blob_hash = self.calculate_blob_hash()
if blob_hash != self.expected_blob_hash: if blob_hash != self.expected_blob_hash:
self.close_handle() self.close_handle()
self.finished.set_result(InvalidBlobHashError( self.finished.set_exception(InvalidBlobHashError(
f"blob hash is {blob_hash} vs expected {self.expected_blob_hash}" f"blob hash is {blob_hash} vs expected {self.expected_blob_hash}"
)) ))
return elif self.finished and not (self.finished.done() or self.finished.cancelled()):
self.buffer.seek(0) self.finished.set_result(self.buffer.getvalue())
self.verified_bytes = self.buffer.read()
self.close_handle() self.close_handle()
if self.finished and not (self.finished.done() or self.finished.cancelled()):
self.finished.set_result(None)
def close_handle(self): def close_handle(self):
if self.buffer is not None: if self.buffer is not None:

View file

@ -57,6 +57,7 @@ class BlobDownloader:
await asyncio.wait(tasks, loop=self.loop, return_when='FIRST_COMPLETED') await asyncio.wait(tasks, loop=self.loop, return_when='FIRST_COMPLETED')
except asyncio.CancelledError: except asyncio.CancelledError:
drain_tasks(tasks) drain_tasks(tasks)
raise
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
blob = self.blob_manager.get_blob(blob_hash, length) blob = self.blob_manager.get_blob(blob_hash, length)

View file

@ -459,7 +459,7 @@ class StreamManagerComponent(Component):
log.info('Done setting up file manager') log.info('Done setting up file manager')
async def stop(self): async def stop(self):
await self.stream_manager.stop() self.stream_manager.stop()
class PeerProtocolServerComponent(Component): class PeerProtocolServerComponent(Component):

View file

@ -1621,7 +1621,7 @@ class Daemon(metaclass=JSONRPCServerType):
stream.downloader.download(self.dht_node) stream.downloader.download(self.dht_node)
msg = "Resumed download" msg = "Resumed download"
elif status == 'stop' and stream.running: elif status == 'stop' and stream.running:
await stream.stop_download() stream.stop_download()
msg = "Stopped download" msg = "Stopped download"
else: else:
msg = ( msg = (

View file

@ -43,13 +43,11 @@ class StreamAssembler:
self.written_bytes: int = 0 self.written_bytes: int = 0
async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str): async def _decrypt_blob(self, blob: 'BlobFile', blob_info: 'BlobInfo', key: str):
offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1) if not blob or self.stream_handle.closed:
return False
def _decrypt_and_write(): def _decrypt_and_write():
if self.stream_handle.closed: offset = blob_info.blob_num * (MAX_BLOB_SIZE - 1)
return False
if not blob:
return False
self.stream_handle.seek(offset) self.stream_handle.seek(offset)
_decrypted = blob.decrypt( _decrypted = blob.decrypt(
binascii.unhexlify(key), binascii.unhexlify(blob_info.iv.encode()) binascii.unhexlify(key), binascii.unhexlify(blob_info.iv.encode())
@ -57,12 +55,10 @@ class StreamAssembler:
self.stream_handle.write(_decrypted) self.stream_handle.write(_decrypted)
self.stream_handle.flush() self.stream_handle.flush()
self.written_bytes += len(_decrypted) self.written_bytes += len(_decrypted)
return True
decrypted = await self.loop.run_in_executor(None, _decrypt_and_write)
if decrypted:
log.debug("decrypted %s", blob.blob_hash[:8]) log.debug("decrypted %s", blob.blob_hash[:8])
return self.wrote_bytes_event.set()
await self.loop.run_in_executor(None, _decrypt_and_write)
async def setup(self): async def setup(self):
pass pass
@ -106,8 +102,6 @@ class StreamAssembler:
self.descriptor.sd_hash) self.descriptor.sd_hash)
continue continue
if not self.wrote_bytes_event.is_set():
self.wrote_bytes_event.set()
self.stream_finished_event.set() self.stream_finished_event.set()
await self.after_finished() await self.after_finished()
finally: finally:

View file

@ -52,11 +52,11 @@ class StreamDownloader(StreamAssembler):
log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path) log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished') await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')
async def stop(self): def stop(self):
if self.accumulate_task and not self.accumulate_task.done(): if self.accumulate_task:
self.accumulate_task.cancel() self.accumulate_task.cancel()
self.accumulate_task = None self.accumulate_task = None
if self.assemble_task and not self.assemble_task.done(): if self.assemble_task:
self.assemble_task.cancel() self.assemble_task.cancel()
self.assemble_task = None self.assemble_task = None
if self.fixed_peers_handle: if self.fixed_peers_handle:
@ -80,7 +80,7 @@ class StreamDownloader(StreamAssembler):
and self.node and self.node
and len(self.node.protocol.routing_table.get_peers()) and len(self.node.protocol.routing_table.get_peers())
) else 0.0, ) else 0.0,
self.loop.create_task, _add_fixed_peers() lambda: self.loop.create_task(_add_fixed_peers())
) )
def download(self, node: typing.Optional['Node'] = None): def download(self, node: typing.Optional['Node'] = None):

View file

@ -155,9 +155,9 @@ class ManagedStream:
return cls(loop, blob_manager, descriptor, os.path.dirname(file_path), os.path.basename(file_path), return cls(loop, blob_manager, descriptor, os.path.dirname(file_path), os.path.basename(file_path),
status=cls.STATUS_FINISHED) status=cls.STATUS_FINISHED)
async def stop_download(self): def stop_download(self):
if self.downloader: if self.downloader:
await self.downloader.stop() self.downloader.stop()
if not self.finished: if not self.finished:
self.update_status(self.STATUS_STOPPED) self.update_status(self.STATUS_STOPPED)

View file

@ -122,12 +122,12 @@ class StreamManager:
await self.load_streams_from_database() await self.load_streams_from_database()
self.resume_downloading_task = self.loop.create_task(self.resume()) self.resume_downloading_task = self.loop.create_task(self.resume())
async def stop(self): def stop(self):
if self.resume_downloading_task and not self.resume_downloading_task.done(): if self.resume_downloading_task and not self.resume_downloading_task.done():
self.resume_downloading_task.cancel() self.resume_downloading_task.cancel()
while self.streams: while self.streams:
stream = self.streams.pop() stream = self.streams.pop()
await stream.stop_download() stream.stop_download()
while self.update_stream_finished_futs: while self.update_stream_finished_futs:
self.update_stream_finished_futs.pop().cancel() self.update_stream_finished_futs.pop().cancel()
@ -142,7 +142,7 @@ class StreamManager:
return stream return stream
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False): async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
await stream.stop_download() stream.stop_download()
self.streams.remove(stream) self.streams.remove(stream)
await self.storage.delete_stream(stream.descriptor) await self.storage.delete_stream(stream.descriptor)
@ -182,7 +182,7 @@ class StreamManager:
log.info("got descriptor %s for %s", claim.source_hash.decode(), claim_info['name']) log.info("got descriptor %s for %s", claim.source_hash.decode(), claim_info['name'])
except (asyncio.TimeoutError, asyncio.CancelledError): except (asyncio.TimeoutError, asyncio.CancelledError):
log.info("stream timeout") log.info("stream timeout")
await downloader.stop() downloader.stop()
log.info("stopped stream") log.info("stopped stream")
return return
if not await self.blob_manager.storage.stream_exists(downloader.sd_hash): if not await self.blob_manager.storage.stream_exists(downloader.sd_hash):
@ -204,7 +204,8 @@ class StreamManager:
self.wait_for_stream_finished(stream) self.wait_for_stream_finished(stream)
return stream return stream
except asyncio.CancelledError: except asyncio.CancelledError:
await downloader.stop() downloader.stop()
log.debug("stopped stream")
async def download_stream_from_claim(self, node: 'Node', claim_info: typing.Dict, async def download_stream_from_claim(self, node: 'Node', claim_info: typing.Dict,
file_name: typing.Optional[str] = None, file_name: typing.Optional[str] = None,

View file

@ -40,10 +40,12 @@ class TestStreamDownloader(BlobExchangeTestBase):
self.downloader.download(mock_node) self.downloader.download(mock_node)
await self.downloader.stream_finished_event.wait() await self.downloader.stream_finished_event.wait()
await self.downloader.stop() self.downloader.stop()
self.assertTrue(os.path.isfile(self.downloader.output_path)) self.assertTrue(os.path.isfile(self.downloader.output_path))
with open(self.downloader.output_path, 'rb') as f: with open(self.downloader.output_path, 'rb') as f:
self.assertEqual(f.read(), self.stream_bytes) self.assertEqual(f.read(), self.stream_bytes)
await asyncio.sleep(0.01)
self.assertTrue(self.downloader.stream_handle.closed)
async def test_transfer_stream(self): async def test_transfer_stream(self):
await self._test_transfer_stream(10) await self._test_transfer_stream(10)