add cache_concurrent decorator

This commit is contained in:
Jack Robison 2019-03-30 21:05:46 -04:00
parent 676f0015aa
commit c663e5a3cf
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 46 additions and 1 deletions

View file

@ -1,7 +1,7 @@
import asyncio
import typing
import logging
from lbrynet.utils import drain_tasks
from lbrynet.utils import drain_tasks, cache_concurrent
from lbrynet.blob_exchange.client import request_blob
if typing.TYPE_CHECKING:
from lbrynet.conf import Config
@ -90,6 +90,7 @@ class BlobDownloader:
for banned_peer in forgiven:
self.ignored.pop(banned_peer)
@cache_concurrent
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
blob = self.blob_manager.get_blob(blob_hash, length)
if blob.get_is_verified():

View file

@ -166,6 +166,50 @@ def async_timed_cache(duration: int):
return wrapper
def cache_concurrent(async_fn):
"""
When the decorated function has concurrent calls made to it with the same arguments, only run it the once
"""
running: typing.Optional[asyncio.Event] = None
cache: typing.Dict = {}
def initialize_running_event_and_cache():
# this is to avoid automatically setting up an event loop by using the decorator
nonlocal running
if running is not None:
return
loop = asyncio.get_running_loop()
running = asyncio.Event(loop=loop)
@functools.wraps(async_fn)
async def wrapper(*args, **kwargs):
if running is None:
initialize_running_event_and_cache()
loop = asyncio.get_running_loop()
key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])])
if running.is_set() and key in cache:
return await cache[key]
running.set()
if key not in cache:
cache[key] = loop.create_future()
error = False
try:
result = await async_fn(*args, **kwargs)
cache[key].set_result(result)
except Exception as err:
cache[key].set_exception(err)
error = True
finally:
fut = cache.pop(key)
if not cache:
running.clear()
if error:
raise fut.exception()
return fut.result()
return wrapper
@async_timed_cache(300)
async def resolve_host(url: str, port: int, proto: str) -> str:
if proto not in ['udp', 'tcp']: