test incomplete streams, respect real length, delete on incomplete assembly

This commit is contained in:
Victor Shyba 2019-02-06 15:40:16 -03:00
parent 3508da4993
commit 3a15ff4bcd
4 changed files with 23 additions and 2 deletions

View file

@ -67,7 +67,7 @@ class BlobFile:
self.finished_writing = asyncio.Event(loop=loop) self.finished_writing = asyncio.Event(loop=loop)
self.blob_write_lock = asyncio.Lock(loop=loop) self.blob_write_lock = asyncio.Lock(loop=loop)
if os.path.isfile(os.path.join(blob_dir, blob_hash)): if os.path.isfile(os.path.join(blob_dir, blob_hash)):
length = length or int(os.stat(os.path.join(blob_dir, blob_hash)).st_size) length = int(os.stat(os.path.join(blob_dir, blob_hash)).st_size)
self.length = length self.length = length
self.verified.set() self.verified.set()
self.finished_writing.set() self.finished_writing.set()

View file

@ -96,6 +96,10 @@ class StreamAssembler:
while self.stream_handle and not self.stream_handle.closed: while self.stream_handle and not self.stream_handle.closed:
try: try:
blob = await self.get_blob(blob_info.blob_hash, blob_info.length) blob = await self.get_blob(blob_info.blob_hash, blob_info.length)
if blob and blob.length != blob_info.length:
log.warning("Found incomplete, deleting: %s", blob_info.blob_hash)
await self.blob_manager.delete_blobs([blob_info.blob_hash])
continue
if await self._decrypt_blob(blob, blob_info, self.descriptor.key): if await self._decrypt_blob(blob, blob_info, self.descriptor.key):
await self.blob_manager.blob_completed(blob) await self.blob_manager.blob_completed(blob)
written_blobs = i written_blobs = i

View file

@ -58,6 +58,10 @@ class StreamDescriptor:
self.stream_hash = stream_hash or self.get_stream_hash() self.stream_hash = stream_hash or self.get_stream_hash()
self.sd_hash = sd_hash self.sd_hash = sd_hash
@property
def length(self):
return len(self.as_json())
def get_stream_hash(self) -> str: def get_stream_hash(self) -> str:
return self.calculate_stream_hash( return self.calculate_stream_hash(
binascii.hexlify(self.stream_name.encode()), self.key.encode(), binascii.hexlify(self.stream_name.encode()), self.key.encode(),

View file

@ -19,7 +19,7 @@ class TestStreamAssembler(AsyncioTestCase):
self.key = b'deadbeef' * 4 self.key = b'deadbeef' * 4
self.cleartext = b'test' self.cleartext = b'test'
async def test_create_and_decrypt_one_blob_stream(self): async def test_create_and_decrypt_one_blob_stream(self, corrupt=False):
tmp_dir = tempfile.mkdtemp() tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir)) self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.storage = SQLiteStorage(Config(), ":memory:") self.storage = SQLiteStorage(Config(), ":memory:")
@ -42,6 +42,11 @@ class TestStreamAssembler(AsyncioTestCase):
for blob_info in sd.blobs: for blob_info in sd.blobs:
if blob_info.blob_hash: if blob_info.blob_hash:
shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash)) shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(download_dir, blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite")) downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite"))
await downloader_storage.open() await downloader_storage.open()
@ -55,6 +60,8 @@ class TestStreamAssembler(AsyncioTestCase):
with open(os.path.join(download_dir, "test_file"), "rb") as f: with open(os.path.join(download_dir, "test_file"), "rb") as f:
decrypted = f.read() decrypted = f.read()
if corrupt:
return decrypted
self.assertEqual(decrypted, self.cleartext) self.assertEqual(decrypted, self.cleartext)
self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified()) self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified())
self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified()) self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified())
@ -103,3 +110,9 @@ class TestStreamAssembler(AsyncioTestCase):
self.assertEqual( self.assertEqual(
[stream.sd_hash, stream.descriptor.blobs[0].blob_hash], [stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
await storage.get_blobs_to_announce()) await storage.get_blobs_to_announce())
async def test_create_truncate_and_handle_stream(self):
self.cleartext = b'potato' * 1337 * 5279
decrypted = await self.test_create_and_decrypt_one_blob_stream(corrupt=True)
# The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
self.assertFalse(decrypted)