fixes for writers (check inside lock and dont hold verified bytes)

This commit is contained in:
Victor Shyba 2019-02-01 16:02:27 -03:00
parent df0635103e
commit 25050fdeeb
2 changed files with 12 additions and 16 deletions

View file

@ -77,7 +77,7 @@ class BlobFile:
def writer_finished(self, writer: HashBlobWriter): def writer_finished(self, writer: HashBlobWriter):
def callback(finished: asyncio.Future): def callback(finished: asyncio.Future):
try: try:
error = finished.result() error = finished.exception()
except Exception as err: except Exception as err:
error = err error = err
if writer in self.writers: # remove this download attempt if writer in self.writers: # remove this download attempt
@ -86,7 +86,7 @@ class BlobFile:
while self.writers: while self.writers:
other = self.writers.pop() other = self.writers.pop()
other.finished.cancel() other.finished.cancel()
t = self.loop.create_task(self.save_verified_blob(writer)) t = self.loop.create_task(self.save_verified_blob(writer, finished.result()))
t.add_done_callback(lambda *_: self.finished_writing.set()) t.add_done_callback(lambda *_: self.finished_writing.set())
return return
if isinstance(error, (InvalidBlobHashError, InvalidDataError)): if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
@ -96,24 +96,24 @@ class BlobFile:
raise error raise error
return callback return callback
async def save_verified_blob(self, writer): async def save_verified_blob(self, writer, verified_bytes):
def _save_verified(): def _save_verified():
# log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}") # log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}")
if not self.saved_verified_blob and not os.path.isfile(self.file_path): if not self.saved_verified_blob and not os.path.isfile(self.file_path):
if self.get_length() == len(writer.verified_bytes): if self.get_length() == len(verified_bytes):
with open(self.file_path, 'wb') as write_handle: with open(self.file_path, 'wb') as write_handle:
write_handle.write(writer.verified_bytes) write_handle.write(verified_bytes)
self.saved_verified_blob = True self.saved_verified_blob = True
else: else:
raise Exception("length mismatch") raise Exception("length mismatch")
async with self.blob_write_lock:
if self.verified.is_set(): if self.verified.is_set():
return return
async with self.blob_write_lock:
await self.loop.run_in_executor(None, _save_verified) await self.loop.run_in_executor(None, _save_verified)
self.verified.set()
if self.blob_completed_callback: if self.blob_completed_callback:
await self.blob_completed_callback(self) await self.blob_completed_callback(self)
self.verified.set()
def open_for_writing(self) -> HashBlobWriter: def open_for_writing(self) -> HashBlobWriter:
if os.path.exists(self.file_path): if os.path.exists(self.file_path):

View file

@ -18,7 +18,6 @@ class HashBlobWriter:
self.finished.add_done_callback(lambda *_: self.close_handle()) self.finished.add_done_callback(lambda *_: self.close_handle())
self._hashsum = get_lbry_hash_obj() self._hashsum = get_lbry_hash_obj()
self.len_so_far = 0 self.len_so_far = 0
self.verified_bytes = b''
def __del__(self): def __del__(self):
if self.buffer is not None: if self.buffer is not None:
@ -46,7 +45,7 @@ class HashBlobWriter:
self.len_so_far += len(data) self.len_so_far += len(data)
if self.len_so_far > expected_length: if self.len_so_far > expected_length:
self.close_handle() self.close_handle()
self.finished.set_result(InvalidDataError( self.finished.set_exception(InvalidDataError(
f'Length so far is greater than the expected length. {self.len_so_far} to {expected_length}' f'Length so far is greater than the expected length. {self.len_so_far} to {expected_length}'
)) ))
return return
@ -55,15 +54,12 @@ class HashBlobWriter:
blob_hash = self.calculate_blob_hash() blob_hash = self.calculate_blob_hash()
if blob_hash != self.expected_blob_hash: if blob_hash != self.expected_blob_hash:
self.close_handle() self.close_handle()
self.finished.set_result(InvalidBlobHashError( self.finished.set_exception(InvalidBlobHashError(
f"blob hash is {blob_hash} vs expected {self.expected_blob_hash}" f"blob hash is {blob_hash} vs expected {self.expected_blob_hash}"
)) ))
return elif self.finished and not (self.finished.done() or self.finished.cancelled()):
self.buffer.seek(0) self.finished.set_result(self.buffer.getvalue())
self.verified_bytes = self.buffer.read()
self.close_handle() self.close_handle()
if self.finished and not (self.finished.done() or self.finished.cancelled()):
self.finished.set_result(None)
def close_handle(self): def close_handle(self):
if self.buffer is not None: if self.buffer is not None: