test incomplete streams, respect real length, delete on incomplete assembly
This commit is contained in:
parent
3508da4993
commit
3a15ff4bcd
4 changed files with 23 additions and 2 deletions
|
@ -67,7 +67,7 @@ class BlobFile:
|
||||||
self.finished_writing = asyncio.Event(loop=loop)
|
self.finished_writing = asyncio.Event(loop=loop)
|
||||||
self.blob_write_lock = asyncio.Lock(loop=loop)
|
self.blob_write_lock = asyncio.Lock(loop=loop)
|
||||||
if os.path.isfile(os.path.join(blob_dir, blob_hash)):
|
if os.path.isfile(os.path.join(blob_dir, blob_hash)):
|
||||||
length = length or int(os.stat(os.path.join(blob_dir, blob_hash)).st_size)
|
length = int(os.stat(os.path.join(blob_dir, blob_hash)).st_size)
|
||||||
self.length = length
|
self.length = length
|
||||||
self.verified.set()
|
self.verified.set()
|
||||||
self.finished_writing.set()
|
self.finished_writing.set()
|
||||||
|
|
|
@ -96,6 +96,10 @@ class StreamAssembler:
|
||||||
while self.stream_handle and not self.stream_handle.closed:
|
while self.stream_handle and not self.stream_handle.closed:
|
||||||
try:
|
try:
|
||||||
blob = await self.get_blob(blob_info.blob_hash, blob_info.length)
|
blob = await self.get_blob(blob_info.blob_hash, blob_info.length)
|
||||||
|
if blob and blob.length != blob_info.length:
|
||||||
|
log.warning("Found incomplete, deleting: %s", blob_info.blob_hash)
|
||||||
|
await self.blob_manager.delete_blobs([blob_info.blob_hash])
|
||||||
|
continue
|
||||||
if await self._decrypt_blob(blob, blob_info, self.descriptor.key):
|
if await self._decrypt_blob(blob, blob_info, self.descriptor.key):
|
||||||
await self.blob_manager.blob_completed(blob)
|
await self.blob_manager.blob_completed(blob)
|
||||||
written_blobs = i
|
written_blobs = i
|
||||||
|
|
|
@ -58,6 +58,10 @@ class StreamDescriptor:
|
||||||
self.stream_hash = stream_hash or self.get_stream_hash()
|
self.stream_hash = stream_hash or self.get_stream_hash()
|
||||||
self.sd_hash = sd_hash
|
self.sd_hash = sd_hash
|
||||||
|
|
||||||
|
@property
|
||||||
|
def length(self):
|
||||||
|
return len(self.as_json())
|
||||||
|
|
||||||
def get_stream_hash(self) -> str:
|
def get_stream_hash(self) -> str:
|
||||||
return self.calculate_stream_hash(
|
return self.calculate_stream_hash(
|
||||||
binascii.hexlify(self.stream_name.encode()), self.key.encode(),
|
binascii.hexlify(self.stream_name.encode()), self.key.encode(),
|
||||||
|
|
|
@ -19,7 +19,7 @@ class TestStreamAssembler(AsyncioTestCase):
|
||||||
self.key = b'deadbeef' * 4
|
self.key = b'deadbeef' * 4
|
||||||
self.cleartext = b'test'
|
self.cleartext = b'test'
|
||||||
|
|
||||||
async def test_create_and_decrypt_one_blob_stream(self):
|
async def test_create_and_decrypt_one_blob_stream(self, corrupt=False):
|
||||||
tmp_dir = tempfile.mkdtemp()
|
tmp_dir = tempfile.mkdtemp()
|
||||||
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
|
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
|
||||||
self.storage = SQLiteStorage(Config(), ":memory:")
|
self.storage = SQLiteStorage(Config(), ":memory:")
|
||||||
|
@ -42,6 +42,11 @@ class TestStreamAssembler(AsyncioTestCase):
|
||||||
for blob_info in sd.blobs:
|
for blob_info in sd.blobs:
|
||||||
if blob_info.blob_hash:
|
if blob_info.blob_hash:
|
||||||
shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash))
|
shutil.copy(os.path.join(tmp_dir, blob_info.blob_hash), os.path.join(download_dir, blob_info.blob_hash))
|
||||||
|
if corrupt and blob_info.length == MAX_BLOB_SIZE:
|
||||||
|
with open(os.path.join(download_dir, blob_info.blob_hash), "rb+") as handle:
|
||||||
|
handle.truncate()
|
||||||
|
handle.flush()
|
||||||
|
|
||||||
downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite"))
|
downloader_storage = SQLiteStorage(Config(), os.path.join(download_dir, "lbrynet.sqlite"))
|
||||||
await downloader_storage.open()
|
await downloader_storage.open()
|
||||||
|
|
||||||
|
@ -55,6 +60,8 @@ class TestStreamAssembler(AsyncioTestCase):
|
||||||
|
|
||||||
with open(os.path.join(download_dir, "test_file"), "rb") as f:
|
with open(os.path.join(download_dir, "test_file"), "rb") as f:
|
||||||
decrypted = f.read()
|
decrypted = f.read()
|
||||||
|
if corrupt:
|
||||||
|
return decrypted
|
||||||
self.assertEqual(decrypted, self.cleartext)
|
self.assertEqual(decrypted, self.cleartext)
|
||||||
self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified())
|
self.assertEqual(True, self.blob_manager.get_blob(sd_hash).get_is_verified())
|
||||||
self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified())
|
self.assertEqual(True, self.blob_manager.get_blob(descriptor.blobs[0].blob_hash).get_is_verified())
|
||||||
|
@ -103,3 +110,9 @@ class TestStreamAssembler(AsyncioTestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
|
[stream.sd_hash, stream.descriptor.blobs[0].blob_hash],
|
||||||
await storage.get_blobs_to_announce())
|
await storage.get_blobs_to_announce())
|
||||||
|
|
||||||
|
async def test_create_truncate_and_handle_stream(self):
|
||||||
|
self.cleartext = b'potato' * 1337 * 5279
|
||||||
|
decrypted = await self.test_create_and_decrypt_one_blob_stream(corrupt=True)
|
||||||
|
# The purpose of this test is just to make sure it can finish even if a blob is corrupt/truncated
|
||||||
|
self.assertFalse(decrypted)
|
||||||
|
|
Loading…
Reference in a new issue