Fix unit tests.

This commit is contained in:
Jonathan Moody 2023-01-08 11:25:08 -06:00
parent eb2bbca100
commit 62db078080
4 changed files with 30 additions and 17 deletions

View file

@ -84,6 +84,7 @@ class AbstractBlob:
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None,
blob_manager: typing.Optional['BlobManager'] = None,
added_on: typing.Optional[int] = None, is_mine: bool = False,
error_fmt: str = "invalid blob directory '%s'",
):
if not is_valid_blobhash(blob_hash):
raise InvalidBlobHashError(blob_hash)
@ -104,7 +105,7 @@ class AbstractBlob:
self.is_mine = is_mine
if not self.blob_directory or not os.path.isdir(self.blob_directory):
raise OSError(f"cannot create blob in directory: '{self.blob_directory}'")
raise OSError(error_fmt%(self.blob_directory))
def __del__(self):
if self.writers or self.readers:
@ -207,7 +208,10 @@ class AbstractBlob:
blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted)
length = len(blob_bytes)
blob = cls(loop, blob_hash, length, blob_completed_callback, blob_manager, added_on, is_mine)
blob = cls(
loop, blob_hash, length, blob_completed_callback, blob_manager, added_on, is_mine,
error_fmt="cannot create blob in directory: '%s'",
)
writer = blob.get_blob_writer()
writer.write(blob_bytes)
await blob.verified.wait()
@ -314,9 +318,13 @@ class BlobFile(AbstractBlob):
self, loop: asyncio.AbstractEventLoop, blob_hash: str, length: typing.Optional[int] = None,
blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None,
blob_manager: typing.Optional['BlobManager'] = None,
added_on: typing.Optional[int] = None, is_mine: bool = False
added_on: typing.Optional[int] = None, is_mine: bool = False,
error_fmt: str = "invalid blob directory '%s'",
):
super().__init__(loop, blob_hash, length, blob_completed_callback, blob_manager, added_on, is_mine)
super().__init__(
loop, blob_hash, length, blob_completed_callback, blob_manager, added_on, is_mine,
error_fmt,
)
self.file_path = os.path.join(self.blob_directory, self.blob_hash)
if self.file_exists:
file_size = int(os.stat(self.file_path).st_size)

View file

@ -24,7 +24,7 @@ class TestManagedStream(BlobExchangeTestBase):
file_path = os.path.join(self.server_dir, file_name)
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager, file_path)
descriptor.suggested_file_name = file_name
descriptor.stream_hash = descriptor.get_stream_hash()
self.sd_hash = descriptor.sd_hash = descriptor.calculate_sd_hash()
@ -166,16 +166,16 @@ class TestManagedStream(BlobExchangeTestBase):
descriptor = await self.create_stream(blobs)
# copy blob files
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, self.sd_hash),
os.path.join(self.client_blob_manager.blob_dir, self.sd_hash))
shutil.copy(os.path.join(self.server_blob_manager._blob_dir(self.sd_hash)[0], self.sd_hash),
os.path.join(self.client_blob_manager._blob_dir(self.sd_hash)[0], self.sd_hash))
self.stream = ManagedStream(self.loop, self.client_config, self.client_blob_manager, self.sd_hash,
self.client_dir)
for blob_info in descriptor.blobs[:-1]:
shutil.copy(os.path.join(self.server_blob_manager.blob_dir, blob_info.blob_hash),
os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash))
shutil.copy(os.path.join(self.server_blob_manager._blob_dir(blob_info.blob_hash)[0], blob_info.blob_hash),
os.path.join(self.client_blob_manager._blob_dir(blob_info.blob_hash)[0], blob_info.blob_hash))
if corrupt and blob_info.length == MAX_BLOB_SIZE:
with open(os.path.join(self.client_blob_manager.blob_dir, blob_info.blob_hash), "rb+") as handle:
with open(os.path.join(self.client_blob_manager._blob_dir(blob_info.blob_hash)[0], blob_info.blob_hash), "rb+") as handle:
handle.truncate()
handle.flush()
await self.stream.save_file()

View file

@ -29,7 +29,7 @@ class TestStreamDescriptor(AsyncioTestCase):
with open(self.file_path, 'wb') as f:
f.write(self.cleartext)
self.descriptor = await StreamDescriptor.create_stream(self.loop, self.tmp_dir, self.file_path, key=self.key)
self.descriptor = await StreamDescriptor.create_stream(self.loop, self.blob_manager, self.file_path, key=self.key)
self.sd_hash = self.descriptor.calculate_sd_hash()
self.sd_dict = json.loads(self.descriptor.as_json())
@ -114,7 +114,7 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
writer.write(sd_bytes)
await blob.verified.wait()
descriptor = await StreamDescriptor.from_stream_descriptor_blob(
loop, blob_manager.blob_dir, blob
loop, blob_manager, blob
)
self.assertEqual(stream_hash, descriptor.get_stream_hash())
self.assertEqual(sd_hash, descriptor.calculate_old_sort_sd_hash())
@ -124,10 +124,15 @@ class TestRecoverOldStreamDescriptors(AsyncioTestCase):
loop = asyncio.get_event_loop()
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.conf = Config()
storage = SQLiteStorage(self.conf, ":memory:")
await storage.open()
blob_manager = BlobManager(loop, tmp_dir, storage, self.conf)
sd_hash = '9313d1807551186126acc3662e74d9de29cede78d4f133349ace846273ef116b9bb86be86c54509eb84840e4b032f6b2'
with open(os.path.join(tmp_dir, sd_hash), 'wb') as handle:
handle.write(b'doesnt work')
blob = BlobFile(loop, sd_hash, blob_directory=tmp_dir)
blob = BlobFile(loop, sd_hash, blob_manager=blob_manager)
self.assertTrue(blob.file_exists)
self.assertIsNotNone(blob.length)
with self.assertRaises(InvalidStreamDescriptorError):

View file

@ -136,7 +136,7 @@ class TestStreamManager(BlobExchangeTestBase):
with open(file_path, 'wb') as f:
f.write(os.urandom(20000000))
descriptor = await StreamDescriptor.create_stream(
self.loop, self.server_blob_manager.blob_dir, file_path, old_sort=old_sort
self.loop, self.server_blob_manager, file_path, old_sort=old_sort
)
self.sd_hash = descriptor.sd_hash
self.mock_wallet, self.uri = await get_mock_wallet(self.sd_hash, self.client_storage, self.client_wallet_dir,
@ -453,7 +453,7 @@ class TestStreamManager(BlobExchangeTestBase):
self.client_blob_manager.stop()
# partial removal, only sd blob is missing.
# in this case, we recover the sd blob while the other blobs are kept untouched as 'finished'
os.remove(os.path.join(self.client_blob_manager.blob_dir, stream.sd_hash))
os.remove(os.path.join(self.client_blob_manager._blob_dir(stream.sd_hash)[0], stream.sd_hash))
await self.client_blob_manager.setup()
await self.stream_manager.start()
self.assertEqual(1, len(self.stream_manager.streams))
@ -470,9 +470,9 @@ class TestStreamManager(BlobExchangeTestBase):
# full removal, check that status is preserved (except sd blob, which was written)
self.client_blob_manager.stop()
os.remove(os.path.join(self.client_blob_manager.blob_dir, stream.sd_hash))
os.remove(os.path.join(self.client_blob_manager._blob_dir(stream.sd_hash)[0], stream.sd_hash))
for blob in stream.descriptor.blobs[:-1]:
os.remove(os.path.join(self.client_blob_manager.blob_dir, blob.blob_hash))
os.remove(os.path.join(self.client_blob_manager._blob_dir(blob.blob_hash)[0], blob.blob_hash))
await self.client_blob_manager.setup()
await self.stream_manager.start()
for blob_hash in [b.blob_hash for b in stream.descriptor.blobs[:-1]]: