diff --git a/lbrynet/blob/blob_manager.py b/lbrynet/blob/blob_manager.py index 9ae0a0e0e..cc68d01b0 100644 --- a/lbrynet/blob/blob_manager.py +++ b/lbrynet/blob/blob_manager.py @@ -2,6 +2,7 @@ import os import typing import asyncio import logging +from lbrynet.utils import LRUCache from lbrynet.blob.blob_file import is_valid_blobhash, BlobFile, BlobBuffer, AbstractBlob from lbrynet.stream.descriptor import StreamDescriptor @@ -30,6 +31,8 @@ class BlobManager: else self._node_data_store.completed_blobs self.blobs: typing.Dict[str, AbstractBlob] = {} self.config = config + self.decrypted_blob_lru_cache = None if not self.config.blob_lru_cache_size else LRUCache( + self.config.blob_lru_cache_size) def _get_blob(self, blob_hash: str, length: typing.Optional[int] = None): if self.config.save_blobs: diff --git a/lbrynet/conf.py b/lbrynet/conf.py index 8b7434b50..0200f3f27 100644 --- a/lbrynet/conf.py +++ b/lbrynet/conf.py @@ -485,7 +485,10 @@ class Config(CLIConfig): # blob announcement and download save_blobs = Toggle("Save encrypted blob files for hosting, otherwise download blobs to memory only.", True) - + blob_lru_cache_size = Integer( + "LRU cache size for decrypted downloaded blobs used to minimize re-downloading the same blobs when " + "replying to a range request. Set to 0 to disable.", 32 + ) announce_head_and_sd_only = Toggle( "Announce only the descriptor and first (rather than all) data blob for a stream to the DHT", True, previous_names=['announce_head_blobs_only'] diff --git a/lbrynet/stream/downloader.py b/lbrynet/stream/downloader.py index 5b053082e..df10fe539 100644 --- a/lbrynet/stream/downloader.py +++ b/lbrynet/stream/downloader.py @@ -3,7 +3,7 @@ import typing import logging import binascii from lbrynet.error import DownloadSDTimeout -from lbrynet.utils import resolve_host +from lbrynet.utils import resolve_host, lru_cache_concurrent from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.blob_exchange.downloader import BlobDownloader from lbrynet.dht.peer import KademliaPeer @@ -36,6 +36,16 @@ class StreamDownloader: self.time_to_descriptor: typing.Optional[float] = None self.time_to_first_bytes: typing.Optional[float] = None + async def cached_read_blob(blob_info: 'BlobInfo') -> bytes: + return await self.read_blob(blob_info, 2) + + if self.blob_manager.decrypted_blob_lru_cache: + cached_read_blob = lru_cache_concurrent(override_lru_cache=self.blob_manager.decrypted_blob_lru_cache)( + cached_read_blob + ) + + self.cached_read_blob = cached_read_blob + async def add_fixed_peers(self): def _delayed_add_fixed_peers(): self.added_fixed_peers = True diff --git a/lbrynet/stream/managed_stream.py b/lbrynet/stream/managed_stream.py index 660c0619e..84d85dae3 100644 --- a/lbrynet/stream/managed_stream.py +++ b/lbrynet/stream/managed_stream.py @@ -42,6 +42,9 @@ class ManagedStream: STATUS_STOPPED = "stopped" STATUS_FINISHED = "finished" + SAVING_ID = 1 + STREAMING_ID = 2 + __slots__ = [ 'loop', 'config', @@ -304,7 +307,10 @@ class ManagedStream: raise IndexError(start_blob_num) for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]): assert i + start_blob_num == blob_info.blob_num - decrypted = await self.downloader.read_blob(blob_info, connection_id) + if connection_id == self.STREAMING_ID: + decrypted = await self.downloader.cached_read_blob(blob_info) + else: + decrypted = await self.downloader.read_blob(blob_info, connection_id) yield (blob_info, decrypted) async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse: @@ -354,7 +360,7 @@ class ManagedStream: self.started_writing.clear() try: with open(output_path, 'wb') as file_write_handle: - async for blob_info, decrypted in self._aiter_read_stream(connection_id=1): + async for blob_info, decrypted in self._aiter_read_stream(connection_id=self.SAVING_ID): log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1) await self.loop.run_in_executor(None, self._write_decrypted_blob, file_write_handle, decrypted) self.written_bytes += len(decrypted) diff --git a/lbrynet/testcase.py b/lbrynet/testcase.py index ccf0e9a7b..3ecbe5b12 100644 --- a/lbrynet/testcase.py +++ b/lbrynet/testcase.py @@ -62,6 +62,7 @@ class CommandTestCase(IntegrationTestCase): LEDGER = lbrynet.wallet MANAGER = LbryWalletManager VERBOSITY = logging.WARN + blob_lru_cache_size = 0 async def asyncSetUp(self): await super().asyncSetUp() @@ -81,6 +82,7 @@ class CommandTestCase(IntegrationTestCase): conf.lbryum_servers = [('127.0.0.1', 50001)] conf.reflector_servers = [('127.0.0.1', 5566)] conf.known_dht_nodes = [] + conf.blob_lru_cache_size = self.blob_lru_cache_size await self.account.ensure_address_gap() address = (await self.account.receiving.get_addresses(limit=1, only_usable=True))[0] diff --git a/lbrynet/utils.py b/lbrynet/utils.py index 0bde133e2..8bc8b7ae1 100644 --- a/lbrynet/utils.py +++ b/lbrynet/utils.py @@ -229,11 +229,12 @@ class LRUCache: return item in self.cache -def lru_cache_concurrent(cache_size: int): - if not cache_size > 0: +def lru_cache_concurrent(cache_size: typing.Optional[int] = None, + override_lru_cache: typing.Optional[LRUCache] = None): + if not cache_size and override_lru_cache is None: raise ValueError("invalid cache size") concurrent_cache = {} - lru_cache = LRUCache(cache_size) + lru_cache = override_lru_cache or LRUCache(cache_size) def wrapper(async_fn): diff --git a/tests/integration/test_streaming.py b/tests/integration/test_streaming.py index bbe77e9e5..29c2c7eca 100644 --- a/tests/integration/test_streaming.py +++ b/tests/integration/test_streaming.py @@ -2,6 +2,7 @@ import os import hashlib import aiohttp import aiohttp.web +import asyncio from lbrynet.utils import aiohttp_request from lbrynet.blob.blob_file import MAX_BLOB_SIZE @@ -373,3 +374,46 @@ class RangeRequests(CommandTestCase): await stream.finished_writing.wait() with open(stream.full_path, 'rb') as f: self.assertEqual(self.data, f.read()) + + +class RangeRequestsLRUCache(CommandTestCase): + blob_lru_cache_size = 32 + + async def _request_stream(self): + name = 'foo' + url = f'http://{self.daemon.conf.streaming_host}:{self.daemon.conf.streaming_port}/get/{name}' + + async with aiohttp_request('get', url) as req: + self.assertEqual(req.headers.get('Content-Type'), 'application/octet-stream') + content_range = req.headers.get('Content-Range') + content_length = int(req.headers.get('Content-Length')) + streamed_bytes = await req.content.read() + self.assertEqual(content_length, len(streamed_bytes)) + self.assertEqual(15, content_length) + self.assertEqual(b'hi\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', streamed_bytes) + self.assertEqual('bytes 0-14/15', content_range) + + async def test_range_requests_with_blob_lru_cache(self): + self.data = b'hi' + self.daemon.conf.save_blobs = False + self.daemon.conf.save_files = False + await self.stream_create('foo', '0.01', data=self.data, file_size=0) + await self.daemon.jsonrpc_file_list()[0].fully_reflected.wait() + await self.daemon.jsonrpc_file_delete(delete_from_download_dir=True, claim_name='foo') + self.assertEqual(0, len(os.listdir(self.daemon.blob_manager.blob_dir))) + + await self.daemon.streaming_runner.setup() + site = aiohttp.web.TCPSite(self.daemon.streaming_runner, self.daemon.conf.streaming_host, + self.daemon.conf.streaming_port) + await site.start() + self.assertListEqual(self.daemon.jsonrpc_file_list(), []) + + await self._request_stream() + self.assertEqual(1, len(self.daemon.jsonrpc_file_list())) + self.server.stop_server() + + # running with cache size 0 gets through without errors without + # this since the server doesnt stop immediately + await asyncio.sleep(1, loop=self.loop) + + await self._request_stream()