diff --git a/lbrynet/blob/blob_file.py b/lbrynet/blob/blob_file.py index a566652e2..fef7396fd 100644 --- a/lbrynet/blob/blob_file.py +++ b/lbrynet/blob/blob_file.py @@ -44,14 +44,13 @@ def encrypt_blob_bytes(key: bytes, iv: bytes, unencrypted: bytes) -> typing.Tupl return encrypted, digest.hexdigest() -def decrypt_blob_bytes(read_handle: typing.BinaryIO, length: int, key: bytes, iv: bytes) -> bytes: - buff = read_handle.read() - if len(buff) != length: +def decrypt_blob_bytes(data: bytes, length: int, key: bytes, iv: bytes) -> bytes: + if len(data) != length: raise ValueError("unexpected length") cipher = Cipher(AES(key), modes.CBC(iv), backend=backend) unpadder = PKCS7(AES.block_size).unpadder() decryptor = cipher.decryptor() - return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize() + return unpadder.update(decryptor.update(data) + decryptor.finalize()) + unpadder.finalize() class AbstractBlob: @@ -73,7 +72,7 @@ class AbstractBlob: ] def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, - blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None, blob_directory: typing.Optional[str] = None): self.loop = loop self.blob_hash = blob_hash @@ -99,12 +98,15 @@ class AbstractBlob: @contextlib.contextmanager def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: - try: - with self._reader_context() as reader: + if not self.is_readable(): + raise OSError("not readable") + with self._reader_context() as reader: + try: self.readers.append(reader) yield reader - finally: - self.readers = [reader for reader in self.readers if reader is not None] + finally: + if reader in self.readers: + self.readers.remove(reader) def _write_blob(self, blob_bytes: bytes): raise NotImplementedError() @@ -167,7 +169,7 @@ class AbstractBlob: """ with self.reader_context() as reader: - return decrypt_blob_bytes(reader, self.length, key, iv) + return decrypt_blob_bytes(reader.read(), self.length, key, iv) @classmethod async def create_from_unencrypted( @@ -191,9 +193,10 @@ class AbstractBlob: return if self.is_writeable(): self._write_blob(verified_bytes) - self.verified.set() if self.blob_completed_callback: - self.blob_completed_callback(self) + self.blob_completed_callback(self).add_done_callback(lambda _: self.verified.set()) + else: + self.verified.set() def get_blob_writer(self) -> HashBlobWriter: fut = asyncio.Future(loop=self.loop) @@ -228,7 +231,7 @@ class BlobBuffer(AbstractBlob): An in-memory only blob """ def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, - blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None, blob_directory: typing.Optional[str] = None): self._verified_bytes: typing.Optional[BytesIO] = None super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory) @@ -240,7 +243,8 @@ class BlobBuffer(AbstractBlob): try: yield self._verified_bytes finally: - self._verified_bytes.close() + if self._verified_bytes: + self._verified_bytes.close() self._verified_bytes = None self.verified.clear() @@ -266,7 +270,7 @@ class BlobFile(AbstractBlob): A blob existing on the local file system """ def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None, - blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None, + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], asyncio.Task]] = None, blob_directory: typing.Optional[str] = None): super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory) if not blob_directory or not os.path.isdir(blob_directory): @@ -314,7 +318,8 @@ class BlobFile(AbstractBlob): async def create_from_unencrypted( cls, loop: asyncio.BaseEventLoop, blob_dir: typing.Optional[str], key: bytes, iv: bytes, unencrypted: bytes, blob_num: int, - blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], None]] = None) -> BlobInfo: + blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], + asyncio.Task]] = None) -> BlobInfo: if not blob_dir or not os.path.isdir(blob_dir): raise OSError(f"cannot create blob in directory: '{blob_dir}'") return await super().create_from_unencrypted( diff --git a/lbrynet/blob/blob_manager.py b/lbrynet/blob/blob_manager.py index c86b552e1..24900a4d9 100644 --- a/lbrynet/blob/blob_manager.py +++ b/lbrynet/blob/blob_manager.py @@ -37,7 +37,7 @@ class BlobManager: self.loop, blob_hash, length, self.blob_completed, self.blob_dir ) else: - if length and is_valid_blobhash(blob_hash) and os.path.isfile(os.path.join(self.blob_dir, blob_hash)): + if is_valid_blobhash(blob_hash) and os.path.isfile(os.path.join(self.blob_dir, blob_hash)): return BlobFile( self.loop, blob_hash, length, self.blob_completed, self.blob_dir ) @@ -47,6 +47,14 @@ class BlobManager: def get_blob(self, blob_hash, length: typing.Optional[int] = None): if blob_hash in self.blobs: + if self.config.save_blobs and isinstance(self.blobs[blob_hash], BlobBuffer): + buffer = self.blobs.pop(blob_hash) + if blob_hash in self.completed_blob_hashes: + self.completed_blob_hashes.remove(blob_hash) + self.blobs[blob_hash] = self._get_blob(blob_hash, length) + if buffer.is_readable(): + with buffer.reader_context() as reader: + self.blobs[blob_hash].write_blob(reader.read()) if length and self.blobs[blob_hash].length is None: self.blobs[blob_hash].set_length(length) else: @@ -75,19 +83,17 @@ class BlobManager: def get_stream_descriptor(self, sd_hash): return StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_dir, self.get_blob(sd_hash)) - def blob_completed(self, blob: AbstractBlob): + def blob_completed(self, blob: AbstractBlob) -> asyncio.Task: if blob.blob_hash is None: raise Exception("Blob hash is None") if not blob.length: raise Exception("Blob has a length of 0") - if not blob.get_is_verified(): - raise Exception("Blob is not verified") if isinstance(blob, BlobFile): if blob.blob_hash not in self.completed_blob_hashes: self.completed_blob_hashes.add(blob.blob_hash) - self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=True)) + return self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=True)) else: - self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=False)) + return self.loop.create_task(self.storage.add_blobs((blob.blob_hash, blob.length), finished=False)) def check_completed_blobs(self, blob_hashes: typing.List[str]) -> typing.List[str]: """Returns of the blobhashes_to_check, which are valid""" diff --git a/lbrynet/extras/daemon/Daemon.py b/lbrynet/extras/daemon/Daemon.py index 68b37faad..209fb9327 100644 --- a/lbrynet/extras/daemon/Daemon.py +++ b/lbrynet/extras/daemon/Daemon.py @@ -48,6 +48,7 @@ if typing.TYPE_CHECKING: from lbrynet.wallet.manager import LbryWalletManager from lbrynet.wallet.ledger import MainNetLedger from lbrynet.stream.stream_manager import StreamManager + from lbrynet.stream.managed_stream import ManagedStream log = logging.getLogger(__name__) @@ -473,24 +474,19 @@ class Daemon(metaclass=JSONRPCServerType): name, claim_id = name_and_claim_id.split("/") uri = f"lbry://{name}#{claim_id}" stream = await self.jsonrpc_get(uri) + if isinstance(stream, dict): + raise web.HTTPServerError(text=stream['error']) raise web.HTTPFound(f"/stream/{stream.sd_hash}") - async def handle_stream_range_request(self, request: web.Request): - sd_hash = request.path.split("/stream/")[1] - if sd_hash not in self.stream_manager.streams: - return web.HTTPNotFound() - stream = self.stream_manager.streams[sd_hash] - - get_range = request.headers.get('range', 'bytes=0-') + @staticmethod + def prepare_range_response_headers(get_range: str, stream: 'ManagedStream') -> typing.Tuple[typing.Dict[str, str], + int, int]: if '=' in get_range: get_range = get_range.split('=')[1] start, end = get_range.split('-') size = 0 - await self.stream_manager.start_stream(stream) for blob in stream.descriptor.blobs[:-1]: - size += 2097152 - 1 if blob.length == 2097152 else blob.length - size -= 15 # last padding is unguessable - + size += blob.length - 1 start = int(start) end = int(end) if end else size - 1 skip_blobs = start // 2097150 @@ -504,18 +500,36 @@ class Daemon(metaclass=JSONRPCServerType): 'Content-Length': str(final_size), 'Content-Type': stream.mime_type } + return headers, size, skip_blobs + async def handle_stream_range_request(self, request: web.Request): + sd_hash = request.path.split("/stream/")[1] + if sd_hash not in self.stream_manager.streams: + return web.HTTPNotFound() + stream = self.stream_manager.streams[sd_hash] + if stream.status == 'stopped': + await self.stream_manager.start_stream(stream) if stream.delayed_stop: stream.delayed_stop.cancel() + headers, size, skip_blobs = self.prepare_range_response_headers( + request.headers.get('range', 'bytes=0-'), stream + ) response = web.StreamResponse( status=206, headers=headers ) await response.prepare(request) + wrote = 0 async for blob_info, decrypted in stream.aiter_read_stream(skip_blobs): - await response.write(decrypted) - log.info("sent browser blob %i/%i", blob_info.blob_num + 1, len(stream.descriptor.blobs) - 1) - await response.write_eof() + if (blob_info.blob_num == len(stream.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size): + decrypted += b'\x00' * (size - len(decrypted) - wrote) + await response.write_eof(decrypted) + break + else: + await response.write(decrypted) + wrote += len(decrypted) + log.info("streamed blob %i/%i", blob_info.blob_num + 1, len(stream.descriptor.blobs) - 1) + response.force_close() return response async def _process_rpc_call(self, data): diff --git a/lbrynet/stream/downloader.py b/lbrynet/stream/downloader.py index 529e49c80..8c0653d4d 100644 --- a/lbrynet/stream/downloader.py +++ b/lbrynet/stream/downloader.py @@ -107,20 +107,17 @@ class StreamDownloader: blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length) return blob - def _decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob'): + def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes: return blob.decrypt( binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode()) ) - async def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes: - return await self.loop.run_in_executor(None, self._decrypt_blob, blob_info, blob) - async def read_blob(self, blob_info: 'BlobInfo') -> bytes: start = None if self.time_to_first_bytes is None: start = self.loop.time() blob = await self.download_stream_blob(blob_info) - decrypted = await self.decrypt_blob(blob_info, blob) + decrypted = self.decrypt_blob(blob_info, blob) if start: self.time_to_first_bytes = self.loop.time() - start return decrypted diff --git a/lbrynet/stream/managed_stream.py b/lbrynet/stream/managed_stream.py index 19f4f1857..96f938eb7 100644 --- a/lbrynet/stream/managed_stream.py +++ b/lbrynet/stream/managed_stream.py @@ -226,10 +226,12 @@ class ManagedStream: self.rowid = self.blob_manager.storage.save_downloaded_file( self.stream_hash, None, None, 0.0 ) + self.download_directory = None + self._file_name = None self.update_status(ManagedStream.STATUS_RUNNING) await self.blob_manager.storage.change_file_status(self.stream_hash, ManagedStream.STATUS_RUNNING) self.update_delayed_stop() - else: + elif not os.path.isfile(self.full_path): await self.save_file(file_name, download_directory) await self.started_writing.wait() diff --git a/tests/integration/test_range_requests.py b/tests/integration/test_range_requests.py index 9e0b7ecc1..2462b8271 100644 --- a/tests/integration/test_range_requests.py +++ b/tests/integration/test_range_requests.py @@ -1,68 +1,313 @@ -import asyncio import aiohttp import aiohttp.web import os import hashlib -import logging from lbrynet.utils import aiohttp_request +from lbrynet.blob.blob_file import MAX_BLOB_SIZE from lbrynet.testcase import CommandTestCase -log = logging.getLogger(__name__) + +def get_random_bytes(n: int) -> bytes: + result = b''.join(hashlib.sha256(os.urandom(4)).digest() for _ in range(n // 16)) + if len(result) < n: + result += os.urandom(n - len(result)) + elif len(result) > n: + result = result[:-(len(result) - n)] + assert len(result) == n, (n, len(result)) + return result class RangeRequests(CommandTestCase): + async def _restart_stream_manager(self): + self.daemon.stream_manager.stop() + await self.daemon.stream_manager.start() + return - VERBOSITY = logging.WARN - - async def _test_range_requests(self, data: bytes, save_blobs: bool = True, streaming_only: bool = True): + async def _setup_stream(self, data: bytes, save_blobs: bool = True, streaming_only: bool = True): self.daemon.conf.save_blobs = save_blobs self.daemon.conf.streaming_only = streaming_only self.data = data await self.stream_create('foo', '0.01', data=self.data) + if save_blobs: + self.assertTrue(len(os.listdir(self.daemon.blob_manager.blob_dir)) > 1) + await self.daemon.jsonrpc_file_list()[0].fully_reflected.wait() await self.daemon.jsonrpc_file_delete(delete_from_download_dir=True, claim_name='foo') - - self.daemon.stream_manager.stop() - await self.daemon.stream_manager.start() - + self.assertEqual(0, len(os.listdir(self.daemon.blob_manager.blob_dir))) + # await self._restart_stream_manager() await self.daemon.runner.setup() site = aiohttp.web.TCPSite(self.daemon.runner, self.daemon.conf.api_host, self.daemon.conf.api_port) await site.start() self.assertListEqual(self.daemon.jsonrpc_file_list(), []) + + async def _test_range_requests(self): name = 'foo' url = f'http://{self.daemon.conf.api_host}:{self.daemon.conf.api_port}/get/{name}' - streamed_bytes = b'' async with aiohttp_request('get', url) as req: self.assertEqual(req.headers.get('Content-Type'), 'application/octet-stream') content_range = req.headers.get('Content-Range') - while True: - try: - data, eof = await asyncio.wait_for(req.content.readchunk(), 3, loop=self.loop) - except asyncio.TimeoutError: - data = b'' - eof = True - if data: - streamed_bytes += data - if not data or eof: - break - self.assertTrue((len(streamed_bytes) + 16 >= len(self.data)) - and (len(streamed_bytes) <= len(self.data))) - return streamed_bytes, content_range + content_length = int(req.headers.get('Content-Length')) + streamed_bytes = await req.content.read() + self.assertEqual(content_length, len(streamed_bytes)) + return streamed_bytes, content_range, content_length - async def test_range_requests_0_padded_bytes(self): - self.data = b''.join(hashlib.sha256(os.urandom(16)).digest() for _ in range(250000)) + b'0000000000000' - streamed, content_range = await self._test_range_requests(self.data) - self.assertEqual(streamed, self.data) - self.assertEqual(content_range, 'bytes 0-8000013/8000014') + async def test_range_requests_2_byte(self): + self.data = b'hi' + await self._setup_stream(self.data) + streamed, content_range, content_length = await self._test_range_requests() + self.assertEqual(15, content_length) + self.assertEqual(b'hi\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', streamed) + self.assertEqual('bytes 0-14/15', content_range) + + async def test_range_requests_15_byte(self): + self.data = b'123456789abcdef' + await self._setup_stream(self.data) + streamed, content_range, content_length = await self._test_range_requests() + self.assertEqual(15, content_length) + self.assertEqual(15, len(streamed)) + self.assertEqual(self.data, streamed) + self.assertEqual('bytes 0-14/15', content_range) + + async def test_range_requests_0_padded_bytes(self, size: int = (MAX_BLOB_SIZE - 1) * 4, + expected_range: str = 'bytes 0-8388603/8388604', padding=b''): + self.data = get_random_bytes(size) + await self._setup_stream(self.data) + streamed, content_range, content_length = await self._test_range_requests() + self.assertEqual(len(self.data + padding), content_length) + self.assertEqual(streamed, self.data + padding) + self.assertEqual(expected_range, content_range) async def test_range_requests_1_padded_bytes(self): - self.data = b''.join(hashlib.sha256(os.urandom(16)).digest() for _ in range(250000)) + b'00000000000001x' - streamed, content_range = await self._test_range_requests(self.data) - self.assertEqual(streamed, self.data[:-1]) - self.assertEqual(content_range, 'bytes 0-8000013/8000014') + await self.test_range_requests_0_padded_bytes( + ((MAX_BLOB_SIZE - 1) * 4) - 1, padding=b'\x00' + ) async def test_range_requests_2_padded_bytes(self): - self.data = b''.join(hashlib.sha256(os.urandom(16)).digest() for _ in range(250000)) - streamed, content_range = await self._test_range_requests(self.data) - self.assertEqual(streamed, self.data[:-2]) - self.assertEqual(content_range, 'bytes 0-7999997/7999998') + await self.test_range_requests_0_padded_bytes( + ((MAX_BLOB_SIZE - 1) * 4) - 2, padding=b'\x00' * 2 + ) + + async def test_range_requests_14_padded_bytes(self): + await self.test_range_requests_0_padded_bytes( + ((MAX_BLOB_SIZE - 1) * 4) - 14, padding=b'\x00' * 14 + ) + + async def test_range_requests_15_padded_bytes(self): + await self.test_range_requests_0_padded_bytes( + ((MAX_BLOB_SIZE - 1) * 4) - 15, padding=b'\x00' * 15 + ) + + async def test_range_requests_last_block_of_last_blob_padding(self): + self.data = get_random_bytes(((MAX_BLOB_SIZE - 1) * 4) - 16) + await self._setup_stream(self.data) + streamed, content_range, content_length = await self._test_range_requests() + self.assertEqual(len(self.data), content_length) + self.assertEqual(streamed, self.data) + self.assertEqual('bytes 0-8388587/8388588', content_range) + + async def test_streaming_only_with_blobs(self): + self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4) + await self._setup_stream(self.data) + + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + + # test that repeated range requests do not create duplicate files + for _ in range(3): + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + # test that a range request after restart does not create a duplicate file + await self._restart_stream_manager() + + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + async def test_streaming_only_without_blobs(self): + self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4) + await self._setup_stream(self.data, save_blobs=False) + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + + # test that repeated range requests do not create duplicate files + for _ in range(3): + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + # test that a range request after restart does not create a duplicate file + await self._restart_stream_manager() + + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertIsNone(stream.download_directory) + self.assertIsNone(stream.full_path) + current_files_in_download_dir = list(os.scandir(os.path.dirname(self.daemon.conf.data_dir))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + async def test_stream_and_save_with_blobs(self): + self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4) + await self._setup_stream(self.data, streaming_only=False) + + await self._test_range_requests() + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + full_path = stream.full_path + files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + + for _ in range(3): + await self._test_range_requests() + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + await self._restart_stream_manager() + + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + + await self._test_range_requests() + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isfile(self.daemon.blob_manager.get_blob(stream.sd_hash).file_path)) + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + with open(stream.full_path, 'rb') as f: + self.assertEqual(self.data, f.read()) + + async def test_stream_and_save_without_blobs(self): + self.data = get_random_bytes((MAX_BLOB_SIZE - 1) * 4) + await self._setup_stream(self.data, streaming_only=False) + self.daemon.conf.save_blobs = False + + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + full_path = stream.full_path + files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + + for _ in range(3): + await self._test_range_requests() + stream = self.daemon.jsonrpc_file_list()[0] + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + await self._restart_stream_manager() + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + + await self._test_range_requests() + streams = self.daemon.jsonrpc_file_list() + self.assertEqual(1, len(streams)) + stream = streams[0] + self.assertTrue(os.path.isdir(stream.download_directory)) + self.assertTrue(os.path.isfile(stream.full_path)) + current_files_in_download_dir = list(os.scandir(os.path.dirname(full_path))) + self.assertEqual( + len(files_in_download_dir), len(current_files_in_download_dir) + ) + + with open(stream.full_path, 'rb') as f: + self.assertEqual(self.data, f.read()) + + async def test_switch_save_blobs_while_running(self): + await self.test_streaming_only_without_blobs() + self.daemon.conf.save_blobs = True + blobs_in_stream = self.daemon.jsonrpc_file_list()[0].blobs_in_stream + sd_hash = self.daemon.jsonrpc_file_list()[0].sd_hash + start_file_count = len(os.listdir(self.daemon.blob_manager.blob_dir)) + await self._test_range_requests() + self.assertEqual(start_file_count + blobs_in_stream, len(os.listdir(self.daemon.blob_manager.blob_dir))) + self.assertEqual(0, self.daemon.jsonrpc_file_list()[0].blobs_remaining) + + # switch back + self.daemon.conf.save_blobs = False + await self._test_range_requests() + self.assertEqual(start_file_count + blobs_in_stream, len(os.listdir(self.daemon.blob_manager.blob_dir))) + self.assertEqual(0, self.daemon.jsonrpc_file_list()[0].blobs_remaining) + await self.daemon.jsonrpc_file_delete(delete_from_download_dir=True, sd_hash=sd_hash) + self.assertEqual(start_file_count, len(os.listdir(self.daemon.blob_manager.blob_dir))) + await self._test_range_requests() + self.assertEqual(start_file_count, len(os.listdir(self.daemon.blob_manager.blob_dir))) + self.assertEqual(blobs_in_stream, self.daemon.jsonrpc_file_list()[0].blobs_remaining) diff --git a/tests/unit/blob/test_blob_manager.py b/tests/unit/blob/test_blob_manager.py index 1f16b8480..6dd1885dd 100644 --- a/tests/unit/blob/test_blob_manager.py +++ b/tests/unit/blob/test_blob_manager.py @@ -32,8 +32,7 @@ class TestBlobManager(AsyncioTestCase): self.assertSetEqual(self.blob_manager.completed_blob_hashes, set()) # make sure we can add the blob - self.blob_manager.blob_completed(self.blob_manager.get_blob(blob_hash, len(blob_bytes))) - await self.blob_manager.storage.add_blobs((blob_hash, len(blob_bytes)), finished=True) + await self.blob_manager.blob_completed(self.blob_manager.get_blob(blob_hash, len(blob_bytes))) self.assertSetEqual(self.blob_manager.completed_blob_hashes, {blob_hash}) # stop the blob manager and restart it, make sure the blob is there