working resolve/search caching

This commit is contained in:
Lex Berezhny 2019-07-16 00:16:35 -04:00
parent 71dc882d5d
commit e02a4e44c0
2 changed files with 51 additions and 23 deletions

View file

@ -7,6 +7,7 @@ from binascii import hexlify
from weakref import WeakSet from weakref import WeakSet
from pylru import lrucache from pylru import lrucache
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Optional
from aiohttp.web import Application, AppRunner, WebSocketResponse, TCPSite from aiohttp.web import Application, AppRunner, WebSocketResponse, TCPSite
from aiohttp.http_websocket import WSMsgType, WSCloseCode from aiohttp.http_websocket import WSMsgType, WSCloseCode
@ -45,11 +46,9 @@ class AdminWebSocket:
async def start(self): async def start(self):
await self.runner.setup() await self.runner.setup()
await TCPSite(self.runner, self.manager.env.websocket_host, self.manager.env.websocket_port).start() await TCPSite(self.runner, self.manager.env.websocket_host, self.manager.env.websocket_port).start()
print('started websocket')
async def stop(self): async def stop(self):
await self.runner.cleanup() await self.runner.cleanup()
print('stopped websocket')
async def on_connect(self, request): async def on_connect(self, request):
web_socket = WebSocketResponse() web_socket = WebSocketResponse()
@ -58,7 +57,6 @@ class AdminWebSocket:
try: try:
async for msg in web_socket: async for msg in web_socket:
if msg.type == WSMsgType.TEXT: if msg.type == WSMsgType.TEXT:
print(msg.data)
await self.on_status(None) await self.on_status(None)
elif msg.type == WSMsgType.ERROR: elif msg.type == WSMsgType.ERROR:
print('web socket connection closed with exception %s' % print('web socket connection closed with exception %s' %
@ -69,11 +67,29 @@ class AdminWebSocket:
@staticmethod @staticmethod
async def on_shutdown(app): async def on_shutdown(app):
print('disconnecting websockets')
for web_socket in set(app['websockets']): for web_socket in set(app['websockets']):
await web_socket.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown') await web_socket.close(code=WSCloseCode.GOING_AWAY, message='Server shutdown')
class ResultCacheItem:
__slots__ = '_result', 'lock', 'has_result'
def __init__(self):
self.has_result = asyncio.Event()
self.lock = asyncio.Lock()
self._result = None
@property
def result(self) -> str:
return self._result
@result.setter
def result(self, result: str):
self._result = result
if result is not None:
self.has_result.set()
class LBRYSessionManager(SessionManager): class LBRYSessionManager(SessionManager):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -87,12 +103,13 @@ class LBRYSessionManager(SessionManager):
if self.env.websocket_host is not None and self.env.websocket_port is not None: if self.env.websocket_host is not None and self.env.websocket_port is not None:
self.websocket = AdminWebSocket(self) self.websocket = AdminWebSocket(self)
self.search_cache = self.bp.search_cache self.search_cache = self.bp.search_cache
self.search_cache['search'] = lrucache(1000) self.search_cache['search'] = lrucache(10000)
self.search_cache['resolve'] = lrucache(1000) self.search_cache['resolve'] = lrucache(10000)
def get_command_tracking_info(self, command): def get_command_tracking_info(self, command):
if command not in self.command_metrics: if command not in self.command_metrics:
self.command_metrics[command] = { self.command_metrics[command] = {
'cache_hit': 0,
'started': 0, 'started': 0,
'finished': 0, 'finished': 0,
'total_time': 0, 'total_time': 0,
@ -102,6 +119,11 @@ class LBRYSessionManager(SessionManager):
} }
return self.command_metrics[command] return self.command_metrics[command]
def cache_hit(self, command_name):
if self.env.track_metrics:
command = self.get_command_tracking_info(command_name)
command['cache_hit'] += 1
def start_command_tracking(self, command_name): def start_command_tracking(self, command_name):
if self.env.track_metrics: if self.env.track_metrics:
command = self.get_command_tracking_info(command_name) command = self.get_command_tracking_info(command_name)
@ -195,26 +217,33 @@ class LBRYElectrumX(ElectrumX):
elapsed = int((time.perf_counter() - start) * 1000) elapsed = int((time.perf_counter() - start) * 1000)
(result, metrics) = result (result, metrics) = result
self.session_mgr.finish_command_tracking(name, elapsed, metrics) self.session_mgr.finish_command_tracking(name, elapsed, metrics)
result = base64.b64encode(result) return base64.b64encode(result).decode()
if self.env.cache_search:
self.session_mgr.search_cache[name][str(kwargs)] = result async def run_and_cache_query(self, query_name, function, kwargs):
cache = self.session_mgr.search_cache[query_name]
cache_key = str(kwargs)
cache_item = cache.get(cache_key)
if cache_item is None:
cache_item = cache[cache_key] = ResultCacheItem()
elif cache_item.result is not None:
self.session_mgr.cache_hit(query_name)
return cache_item.result
async with cache_item.lock:
result = cache_item.result
if result is None:
result = cache_item.result = await self.run_in_executor(
query_name, function, kwargs
)
else:
self.session_mgr.cache_hit(query_name)
return result return result
async def claimtrie_search(self, **kwargs): async def claimtrie_search(self, **kwargs):
if 'claim_id' in kwargs: return await self.run_and_cache_query('search', reader.search_to_bytes, kwargs)
self.assert_claim_id(kwargs['claim_id'])
if self.env.cache_search:
key = str(kwargs)
if key in self.session_mgr.search_cache['search']:
return self.session_mgr.search_cache['search'][key]
return await self.run_in_executor('search', reader.search_to_bytes, kwargs)
async def claimtrie_resolve(self, *urls): async def claimtrie_resolve(self, *urls):
if self.env.cache_search: return await self.run_and_cache_query('resolve', reader.resolve_to_bytes, urls)
key = str(urls)
if key in self.session_mgr.search_cache['resolve']:
return self.session_mgr.search_cache['resolve'][key]
return await self.run_in_executor('resolve', reader.resolve_to_bytes, urls)
async def get_server_height(self): async def get_server_height(self):
return self.bp.height return self.bp.height

View file

@ -38,7 +38,6 @@ class Env:
self.db_dir = self.required('DB_DIRECTORY') self.db_dir = self.required('DB_DIRECTORY')
self.db_engine = self.default('DB_ENGINE', 'leveldb') self.db_engine = self.default('DB_ENGINE', 'leveldb')
self.max_query_workers = self.integer('MAX_QUERY_WORKERS', None) self.max_query_workers = self.integer('MAX_QUERY_WORKERS', None)
self.cache_search = self.boolean('CACHE_SEARCH', False)
self.track_metrics = self.boolean('TRACK_METRICS', False) self.track_metrics = self.boolean('TRACK_METRICS', False)
self.websocket_host = self.default('WEBSOCKET_HOST', None) self.websocket_host = self.default('WEBSOCKET_HOST', None)
self.websocket_port = self.integer('WEBSOCKET_PORT', None) self.websocket_port = self.integer('WEBSOCKET_PORT', None)