diff --git a/lbrynet/blob_exchange/downloader.py b/lbrynet/blob_exchange/downloader.py index 0d9381f50..38bf12ba3 100644 --- a/lbrynet/blob_exchange/downloader.py +++ b/lbrynet/blob_exchange/downloader.py @@ -1,7 +1,7 @@ import asyncio import typing import logging -from lbrynet.utils import drain_tasks +from lbrynet.utils import drain_tasks, cache_concurrent from lbrynet.blob_exchange.client import request_blob if typing.TYPE_CHECKING: from lbrynet.conf import Config @@ -90,6 +90,7 @@ class BlobDownloader: for banned_peer in forgiven: self.ignored.pop(banned_peer) + @cache_concurrent async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob': blob = self.blob_manager.get_blob(blob_hash, length) if blob.get_is_verified(): diff --git a/lbrynet/utils.py b/lbrynet/utils.py index 0e43ee313..057dc11fa 100644 --- a/lbrynet/utils.py +++ b/lbrynet/utils.py @@ -166,6 +166,50 @@ def async_timed_cache(duration: int): return wrapper +def cache_concurrent(async_fn): + """ + When the decorated function has concurrent calls made to it with the same arguments, only run it the once + """ + running: typing.Optional[asyncio.Event] = None + cache: typing.Dict = {} + + def initialize_running_event_and_cache(): + # this is to avoid automatically setting up an event loop by using the decorator + nonlocal running + if running is not None: + return + loop = asyncio.get_running_loop() + running = asyncio.Event(loop=loop) + + @functools.wraps(async_fn) + async def wrapper(*args, **kwargs): + if running is None: + initialize_running_event_and_cache() + loop = asyncio.get_running_loop() + key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])]) + if running.is_set() and key in cache: + return await cache[key] + running.set() + if key not in cache: + cache[key] = loop.create_future() + error = False + try: + result = await async_fn(*args, **kwargs) + cache[key].set_result(result) + except Exception as err: + cache[key].set_exception(err) + error = True + finally: + fut = cache.pop(key) + if not cache: + running.clear() + if error: + raise fut.exception() + return fut.result() + + return wrapper + + @async_timed_cache(300) async def resolve_host(url: str, port: int, proto: str) -> str: if proto not in ['udp', 'tcp']: