Merge pull request #3275 from lbryio/search_caching_issues

add caching to "search ahead" code and invalidate short_url cache on every block
Jack Robison 2021-04-28 14:14:29 -04:00 committed by GitHub
commit ad6281090d

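For orientation, here is a minimal hypothetical sketch (the names below are made up, not the project's code) of the behavior the commit message describes: "search ahead" results are cached per query, and every chain-dependent cache, including the short id cache, is cleared when a new block is processed so stale entries are never served.

# Hypothetical sketch (names are made up, not the project's code) of the idea:
# cache "search ahead" results per query, and clear every chain-dependent cache
# whenever a new block is processed so stale entries are never served.
class SketchSearchIndex:
    def __init__(self):
        self.search_cache = {}    # query key -> reordered hit list
        self.short_id_cache = {}  # (channel, name, short id) key -> full claim id

    async def search_ahead(self, key, run_query):
        # repeated identical queries within one block are served from the cache
        if key not in self.search_cache:
            self.search_cache[key] = await run_query()
        return self.search_cache[key]

    def on_new_block(self):
        # invalidate on every block: both caches may now hold outdated entries
        self.search_cache.clear()
        self.short_id_cache.clear()

The actual change wires this idea into the existing LRUCache instances, as the diff below shows.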

@@ -1,7 +1,7 @@
 import asyncio
 import struct
 from binascii import unhexlify
-from collections import Counter
+from collections import Counter, deque
 from decimal import Decimal
 from operator import itemgetter
 from typing import Optional, List, Iterable, Union
@@ -42,7 +42,7 @@ class SearchIndex:
         self.index = index_prefix + 'claims'
         self.logger = class_logger(__name__, self.__class__.__name__)
         self.claim_cache = LRUCache(2 ** 15)
-        self.short_id_cache = LRUCache(2 ** 17)  # never invalidated, since short ids are forever
+        self.short_id_cache = LRUCache(2 ** 17)
         self.search_cache = LRUCache(2 ** 17)
         self.resolution_cache = LRUCache(2 ** 17)
         self._elastic_host = elastic_host
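LRUCache is imported from lbry's utilities and its implementation is not part of this diff; a rough sketch of the interface the constructor above relies on (a capacity argument, set, membership test, clear), assuming a plain ordered-dict LRU, could look like the following. The real class may differ in details.

# Rough sketch of the LRUCache interface assumed above; the real implementation
# lives in lbry's utils and may differ. Capacity-bounded, least recently used
# entries are evicted first.
from collections import OrderedDict

class LRUCacheSketch:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self._data = OrderedDict()

    def __contains__(self, key):
        return key in self._data

    def get(self, key, default=None):
        if key in self._data:
            self._data.move_to_end(key)  # mark as most recently used
        return self._data.get(key, default)

    def set(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict least recently used

    def clear(self):
        self._data.clear()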
@@ -134,6 +134,7 @@ class SearchIndex:
             self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
         await self.sync_client.indices.refresh(self.index)
         self.search_cache.clear()
+        self.short_id_cache.clear()
         self.claim_cache.clear()
         self.resolution_cache.clear()
@@ -198,7 +199,7 @@ class SearchIndex:
             self.claim_cache.set(result['claim_id'], result)

     async def full_id_from_short_id(self, name, short_id, channel_id=None):
-        key = (channel_id or '') + name + short_id
+        key = '#'.join((channel_id or '', name, short_id))
         if key not in self.short_id_cache:
             query = {'name': name, 'claim_id': short_id}
             if channel_id:
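The old key concatenated channel id, name and short id directly, so different inputs could collapse into the same cache key; joining the parts with '#' keeps them unambiguous. A small illustration with made-up values:

# Illustration (made-up values) of why plain concatenation is an ambiguous cache key:
old_key_a = (None or '') + 'foo' + '1a'   # no channel, name 'foo',  short id '1a'
old_key_b = (None or '') + 'foo1' + 'a'   # no channel, name 'foo1', short id 'a'
assert old_key_a == old_key_b             # both are 'foo1a' -- a collision

# The new key keeps the parts separate, so the two lookups no longer collide:
new_key_a = '#'.join(('', 'foo', '1a'))   # '#foo#1a'
new_key_b = '#'.join(('', 'foo1', 'a'))   # '#foo1#a'
assert new_key_a != new_key_b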
@@ -233,40 +234,56 @@
     async def search_ahead(self, **kwargs):
         # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
         per_channel_per_page = kwargs.pop('limit_claims_per_channel')
-        limit = kwargs.pop('limit', 10)
+        page_size = kwargs.pop('limit', 10)
         offset = kwargs.pop('offset', 0)
         kwargs['limit'] = 1000
-        query = expand_query(**kwargs)
-        result = (await self.search_client.search(
-            query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
-        ))['hits']
-        to_inflate = []
-        channel_counter = Counter()
-        delayed = []
-        while result['hits'] or delayed:
-            if len(to_inflate) % limit == 0:
-                channel_counter.clear()
-            else:
-                break  # means last page was incomplete and we are left with bad replacements
-            new_delayed = []
-            for id, chann in delayed:
-                if channel_counter[chann] < per_channel_per_page:
-                    to_inflate.append((id, chann))
-                    channel_counter[hit_channel_id] += 1
-                else:
-                    new_delayed.append((id, chann))
-            delayed = new_delayed
-            while result['hits']:
-                hit = result['hits'].pop(0)
-                hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
-                if channel_counter[hit_channel_id] < per_channel_per_page:
-                    to_inflate.append((hit_id, hit_channel_id))
-                    channel_counter[hit_channel_id] += 1
-                    if len(to_inflate) % limit == 0:
-                        break
-                else:
-                    delayed.append((hit_id, hit_channel_id))
-        return list(await self.get_many(*(claim_id for claim_id, _ in to_inflate[offset:(offset + limit)])))
+        cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
+        if cache_item.result is not None:
+            reordered_hits = cache_item.result
+        else:
+            async with cache_item.lock:
+                if cache_item.result:
+                    reordered_hits = cache_item.result
+                else:
+                    query = expand_query(**kwargs)
+                    reordered_hits = await self.__search_ahead(query, page_size, per_channel_per_page)
+                    cache_item.result = reordered_hits
+        return list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
+
+    async def __search_ahead(self, query: dict, page_size: int, per_channel_per_page: int):
+        search_hits = deque((await self.search_client.search(
+            query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id']
+        ))['hits']['hits'])
+        reordered_hits = []
+        channel_counters = Counter()
+        next_page_hits_maybe_check_later = deque()
+        while search_hits or next_page_hits_maybe_check_later:
+            if reordered_hits and len(reordered_hits) % page_size == 0:
+                channel_counters.clear()
+            elif not reordered_hits:
+                pass
+            else:
+                break  # means last page was incomplete and we are left with bad replacements
+            for _ in range(len(next_page_hits_maybe_check_later)):
+                claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
+                if channel_counters[channel_id] < per_channel_per_page:
+                    reordered_hits.append((claim_id, channel_id))
+                    channel_counters[channel_id] += 1
+                else:
+                    next_page_hits_maybe_check_later.append((claim_id, channel_id))
+            while search_hits:
+                hit = search_hits.popleft()
+                hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
+                if hit_channel_id is None:
+                    reordered_hits.append((hit_id, hit_channel_id))
+                elif channel_counters[hit_channel_id] < per_channel_per_page:
+                    reordered_hits.append((hit_id, hit_channel_id))
+                    channel_counters[hit_channel_id] += 1
+                    if len(reordered_hits) % page_size == 0:
+                        break
+                else:
+                    next_page_hits_maybe_check_later.append((hit_id, hit_channel_id))
+        return reordered_hits

     async def resolve_url(self, raw_url):
         if raw_url not in self.resolution_cache:
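ResultCacheItem is defined elsewhere in this module and does not appear in the diff; judging only from the calls above (from_cache, .result, .lock), an assumed-equivalent sketch of what it provides would be:

# Assumed-equivalent sketch of ResultCacheItem as used in search_ahead above (the
# real class lives elsewhere in this module and may differ): one lock per cache key,
# so concurrent identical "search ahead" requests share a single Elasticsearch query
# instead of issuing duplicates.
import asyncio

class ResultCacheItemSketch:
    __slots__ = ('result', 'lock')

    def __init__(self):
        self.result = None            # filled in once the query has run
        self.lock = asyncio.Lock()    # serializes writers for this key

    @classmethod
    def from_cache(cls, key, cache):
        # return the existing item for this key, or create and cache a new one
        item = cache.get(key)
        if item is None:
            item = cls()
            cache.set(key, item)
        return item

With this shape, search_ahead checks .result both before and after acquiring the per-key lock, so requests that were waiting on the lock reuse the reordering computed by the request that held it rather than re-running the query.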