fix test, add lru_cache_concurrent decorator
This commit is contained in:
parent
41a6e288aa
commit
51f301785f
2 changed files with 58 additions and 7 deletions
|
@ -15,6 +15,7 @@ import contextlib
|
|||
import certifi
|
||||
import aiohttp
|
||||
import functools
|
||||
import collections
|
||||
from lbrynet.schema.claim import Claim
|
||||
from lbrynet.cryptoutils import get_lbry_hash_obj
|
||||
|
||||
|
@ -201,6 +202,59 @@ async def resolve_host(url: str, port: int, proto: str) -> str:
|
|||
))[0][4][0]
|
||||
|
||||
|
||||
class LRUCache:
|
||||
__slots__ = [
|
||||
'capacity',
|
||||
'cache'
|
||||
]
|
||||
|
||||
def __init__(self, capacity):
|
||||
self.capacity = capacity
|
||||
self.cache = collections.OrderedDict()
|
||||
|
||||
def get(self, key):
|
||||
value = self.cache.pop(key)
|
||||
self.cache[key] = value
|
||||
return value
|
||||
|
||||
def set(self, key, value):
|
||||
try:
|
||||
self.cache.pop(key)
|
||||
except KeyError:
|
||||
if len(self.cache) >= self.capacity:
|
||||
self.cache.popitem(last=False)
|
||||
self.cache[key] = value
|
||||
|
||||
def __contains__(self, item) -> bool:
|
||||
return item in self.cache
|
||||
|
||||
|
||||
def lru_cache_concurrent(cache_size: int):
|
||||
if not cache_size > 0:
|
||||
raise ValueError("invalid cache size")
|
||||
concurrent_cache = {}
|
||||
lru_cache = LRUCache(cache_size)
|
||||
|
||||
def wrapper(async_fn):
|
||||
|
||||
@functools.wraps(async_fn)
|
||||
async def _inner(*args, **kwargs):
|
||||
key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])])
|
||||
if key in lru_cache:
|
||||
return lru_cache.get(key)
|
||||
|
||||
concurrent_cache[key] = concurrent_cache.get(key) or asyncio.create_task(async_fn(*args, **kwargs))
|
||||
|
||||
try:
|
||||
result = await concurrent_cache[key]
|
||||
lru_cache.set(key, result)
|
||||
return result
|
||||
finally:
|
||||
concurrent_cache.pop(key, None)
|
||||
return _inner
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_ssl_context() -> ssl.SSLContext:
|
||||
return ssl.create_default_context(
|
||||
purpose=ssl.Purpose.CLIENT_AUTH, capath=certifi.where()
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
import logging
|
||||
|
||||
import asyncio
|
||||
from functools import lru_cache
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from binascii import unhexlify, hexlify
|
||||
|
||||
from lbrynet.utils import lru_cache_concurrent
|
||||
from lbrynet.wallet.account import validate_claim_id
|
||||
from lbrynet.wallet.dewies import dewies_to_lbc
|
||||
from lbrynet.error import UnknownNameError, UnknownClaimID, UnknownURI, UnknownOutpoint
|
||||
|
@ -51,11 +50,9 @@ class Resolver:
|
|||
results = await asyncio.gather(*futs)
|
||||
return dict(list(map(lambda result: list(result.items())[0], results)))
|
||||
|
||||
@lru_cache(256)
|
||||
def _fetch_tx(self, txid):
|
||||
async def __fetch_parse(txid):
|
||||
@lru_cache_concurrent(256)
|
||||
async def _fetch_tx(self, txid):
|
||||
return self.transaction_class(unhexlify(await self.network.get_transaction(txid)))
|
||||
return asyncio.ensure_future(__fetch_parse(txid))
|
||||
|
||||
async def _handle_resolutions(self, resolutions, requested_uris, page, page_size, claim_trie_root):
|
||||
results = {}
|
||||
|
|
Loading…
Reference in a new issue