add cache_concurrent decorator
This commit is contained in:
parent
676f0015aa
commit
c663e5a3cf
2 changed files with 46 additions and 1 deletion
lbrynet
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import typing
|
import typing
|
||||||
import logging
|
import logging
|
||||||
from lbrynet.utils import drain_tasks
|
from lbrynet.utils import drain_tasks, cache_concurrent
|
||||||
from lbrynet.blob_exchange.client import request_blob
|
from lbrynet.blob_exchange.client import request_blob
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from lbrynet.conf import Config
|
from lbrynet.conf import Config
|
||||||
|
@ -90,6 +90,7 @@ class BlobDownloader:
|
||||||
for banned_peer in forgiven:
|
for banned_peer in forgiven:
|
||||||
self.ignored.pop(banned_peer)
|
self.ignored.pop(banned_peer)
|
||||||
|
|
||||||
|
@cache_concurrent
|
||||||
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
|
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
|
||||||
blob = self.blob_manager.get_blob(blob_hash, length)
|
blob = self.blob_manager.get_blob(blob_hash, length)
|
||||||
if blob.get_is_verified():
|
if blob.get_is_verified():
|
||||||
|
|
|
@ -166,6 +166,50 @@ def async_timed_cache(duration: int):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def cache_concurrent(async_fn):
    """
    When the decorated function has concurrent calls made to it with the same arguments, only run it once.

    The first call with a given ``(args, kwargs)`` key creates a future and runs the wrapped
    coroutine; any call arriving with the same key while it is in flight awaits that shared
    future instead of running the coroutine again. The cache entry is removed as soon as the
    call finishes, so this de-duplicates only *concurrent* calls — it is not a result cache.

    Note: the key uses kwargs in insertion order, so ``f(a=1, b=2)`` and ``f(b=2, a=1)``
    are treated as different calls.
    """
    # in-flight calls only: key -> asyncio.Future resolved with the call's result
    cache: typing.Dict = {}

    @functools.wraps(async_fn)
    async def wrapper(*args, **kwargs):
        key = (args, tuple(kwargs.items()))
        if key in cache:
            # a call with identical arguments is already running: share its outcome
            return await cache[key]
        # get_running_loop() (not get_event_loop()) so the decorator never
        # implicitly creates a loop; wrapper only runs inside one anyway
        fut = cache[key] = asyncio.get_running_loop().create_future()
        try:
            result = await async_fn(*args, **kwargs)
        except asyncio.CancelledError:
            # propagate cancellation to waiters too — otherwise they would hang
            # forever on a future that is popped but never resolved
            fut.cancel()
            raise
        except Exception as err:
            fut.set_exception(err)
            # mark the exception as retrieved so asyncio does not log a
            # "exception was never retrieved" warning when nobody is waiting
            fut.exception()
            raise
        else:
            fut.set_result(result)
            return result
        finally:
            # drop the in-flight entry whether we succeeded, failed or were cancelled
            cache.pop(key, None)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
@async_timed_cache(300)
|
@async_timed_cache(300)
|
||||||
async def resolve_host(url: str, port: int, proto: str) -> str:
|
async def resolve_host(url: str, port: int, proto: str) -> str:
|
||||||
if proto not in ['udp', 'tcp']:
|
if proto not in ['udp', 'tcp']:
|
||||||
|
|
Loading…
Add table
Reference in a new issue