forked from LBRYCommunity/lbry-sdk
Merge pull request #2148 from lbryio/lru-cache-concurrent
Fix resolve internals caching and persisting errors
This commit is contained in:
commit
eca677a720
3 changed files with 112 additions and 7 deletions
|
@ -15,6 +15,7 @@ import contextlib
|
||||||
import certifi
|
import certifi
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import functools
|
import functools
|
||||||
|
import collections
|
||||||
from lbrynet.schema.claim import Claim
|
from lbrynet.schema.claim import Claim
|
||||||
from lbrynet.cryptoutils import get_lbry_hash_obj
|
from lbrynet.cryptoutils import get_lbry_hash_obj
|
||||||
|
|
||||||
|
@ -201,6 +202,59 @@ async def resolve_host(url: str, port: int, proto: str) -> str:
|
||||||
))[0][4][0]
|
))[0][4][0]
|
||||||
|
|
||||||
|
|
||||||
|
class LRUCache:
|
||||||
|
__slots__ = [
|
||||||
|
'capacity',
|
||||||
|
'cache'
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, capacity):
|
||||||
|
self.capacity = capacity
|
||||||
|
self.cache = collections.OrderedDict()
|
||||||
|
|
||||||
|
def get(self, key):
|
||||||
|
value = self.cache.pop(key)
|
||||||
|
self.cache[key] = value
|
||||||
|
return value
|
||||||
|
|
||||||
|
def set(self, key, value):
|
||||||
|
try:
|
||||||
|
self.cache.pop(key)
|
||||||
|
except KeyError:
|
||||||
|
if len(self.cache) >= self.capacity:
|
||||||
|
self.cache.popitem(last=False)
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
def __contains__(self, item) -> bool:
|
||||||
|
return item in self.cache
|
||||||
|
|
||||||
|
|
||||||
|
def lru_cache_concurrent(cache_size: int):
|
||||||
|
if not cache_size > 0:
|
||||||
|
raise ValueError("invalid cache size")
|
||||||
|
concurrent_cache = {}
|
||||||
|
lru_cache = LRUCache(cache_size)
|
||||||
|
|
||||||
|
def wrapper(async_fn):
|
||||||
|
|
||||||
|
@functools.wraps(async_fn)
|
||||||
|
async def _inner(*args, **kwargs):
|
||||||
|
key = tuple([args, tuple([tuple([k, kwargs[k]]) for k in kwargs])])
|
||||||
|
if key in lru_cache:
|
||||||
|
return lru_cache.get(key)
|
||||||
|
|
||||||
|
concurrent_cache[key] = concurrent_cache.get(key) or asyncio.create_task(async_fn(*args, **kwargs))
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await concurrent_cache[key]
|
||||||
|
lru_cache.set(key, result)
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
concurrent_cache.pop(key, None)
|
||||||
|
return _inner
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def get_ssl_context() -> ssl.SSLContext:
|
def get_ssl_context() -> ssl.SSLContext:
|
||||||
return ssl.create_default_context(
|
return ssl.create_default_context(
|
||||||
purpose=ssl.Purpose.CLIENT_AUTH, capath=certifi.where()
|
purpose=ssl.Purpose.CLIENT_AUTH, capath=certifi.where()
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
from cryptography.exceptions import InvalidSignature
|
from cryptography.exceptions import InvalidSignature
|
||||||
from binascii import unhexlify, hexlify
|
from binascii import unhexlify, hexlify
|
||||||
|
from lbrynet.utils import lru_cache_concurrent
|
||||||
from lbrynet.wallet.account import validate_claim_id
|
from lbrynet.wallet.account import validate_claim_id
|
||||||
from lbrynet.wallet.dewies import dewies_to_lbc
|
from lbrynet.wallet.dewies import dewies_to_lbc
|
||||||
from lbrynet.error import UnknownNameError, UnknownClaimID, UnknownURI, UnknownOutpoint
|
from lbrynet.error import UnknownNameError, UnknownClaimID, UnknownURI, UnknownOutpoint
|
||||||
|
@ -51,11 +50,9 @@ class Resolver:
|
||||||
results = await asyncio.gather(*futs)
|
results = await asyncio.gather(*futs)
|
||||||
return dict(list(map(lambda result: list(result.items())[0], results)))
|
return dict(list(map(lambda result: list(result.items())[0], results)))
|
||||||
|
|
||||||
@lru_cache(256)
|
@lru_cache_concurrent(256)
|
||||||
def _fetch_tx(self, txid):
|
async def _fetch_tx(self, txid):
|
||||||
async def __fetch_parse(txid):
|
|
||||||
return self.transaction_class(unhexlify(await self.network.get_transaction(txid)))
|
return self.transaction_class(unhexlify(await self.network.get_transaction(txid)))
|
||||||
return asyncio.ensure_future(__fetch_parse(txid))
|
|
||||||
|
|
||||||
async def _handle_resolutions(self, resolutions, requested_uris, page, page_size, claim_trie_root):
|
async def _handle_resolutions(self, resolutions, requested_uris, page, page_size, claim_trie_root):
|
||||||
results = {}
|
results = {}
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
from lbrynet.testcase import CommandTestCase
|
from lbrynet.testcase import CommandTestCase
|
||||||
|
@ -88,3 +89,56 @@ class ResolveCommand(CommandTestCase):
|
||||||
self.assertFalse(claim['decoded_claim'])
|
self.assertFalse(claim['decoded_claim'])
|
||||||
self.assertEqual(claim['txid'], txid)
|
self.assertEqual(claim['txid'], txid)
|
||||||
self.assertEqual(claim['effective_amount'], "0.1")
|
self.assertEqual(claim['effective_amount'], "0.1")
|
||||||
|
|
||||||
|
async def _test_resolve_abc_foo(self):
|
||||||
|
response = await self.resolve('lbry://@abc/foo')
|
||||||
|
claim = response['lbry://@abc/foo']
|
||||||
|
self.assertIn('certificate', claim)
|
||||||
|
self.assertIn('claim', claim)
|
||||||
|
self.assertEqual(claim['claim']['name'], 'foo')
|
||||||
|
self.assertEqual(claim['claim']['channel_name'], '@abc')
|
||||||
|
self.assertEqual(claim['certificate']['name'], '@abc')
|
||||||
|
self.assertEqual(claim['claims_in_channel'], 0)
|
||||||
|
self.assertEqual(
|
||||||
|
claim['claim']['timestamp'],
|
||||||
|
self.ledger.headers[claim['claim']['height']]['timestamp']
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
claim['certificate']['timestamp'],
|
||||||
|
self.ledger.headers[claim['certificate']['height']]['timestamp']
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_resolve_lru_cache_doesnt_persist_errors(self):
|
||||||
|
original_get_transaction = self.daemon.wallet_manager.ledger.network.get_transaction
|
||||||
|
|
||||||
|
async def timeout_get_transaction(txid):
|
||||||
|
fut = self.loop.create_future()
|
||||||
|
|
||||||
|
def delayed_raise_cancelled_error():
|
||||||
|
fut.set_exception(asyncio.CancelledError())
|
||||||
|
|
||||||
|
self.loop.call_soon(delayed_raise_cancelled_error)
|
||||||
|
return await fut
|
||||||
|
|
||||||
|
tx = await self.channel_create('@abc', '0.01')
|
||||||
|
channel_id = tx['outputs'][0]['claim_id']
|
||||||
|
await self.stream_create('foo', '0.01', channel_id=channel_id)
|
||||||
|
|
||||||
|
# raise a cancelled error from get_transaction
|
||||||
|
self.daemon.wallet_manager.ledger.network.get_transaction = timeout_get_transaction
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
await self._test_resolve_abc_foo()
|
||||||
|
|
||||||
|
# restore the real get_transaction that doesn't cancel, it should be called and the result cached
|
||||||
|
self.daemon.wallet_manager.ledger.network.get_transaction = original_get_transaction
|
||||||
|
await self._test_resolve_abc_foo()
|
||||||
|
called_again = asyncio.Event(loop=self.loop)
|
||||||
|
|
||||||
|
def check_result_cached(txid):
|
||||||
|
called_again.set()
|
||||||
|
return original_get_transaction(txid)
|
||||||
|
|
||||||
|
# check that the result was cached
|
||||||
|
self.daemon.wallet_manager.ledger.network.get_transaction = check_result_cached
|
||||||
|
await self._test_resolve_abc_foo()
|
||||||
|
self.assertFalse(called_again.is_set())
|
||||||
|
|
Loading…
Reference in a new issue