import logging
import asyncio
import struct
from bisect import bisect_right
from collections import Counter, deque
from decimal import Decimal
from operator import itemgetter
from typing import Optional, List, Iterable, TYPE_CHECKING

from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
from scribe.schema.result import Censor, Outputs
from scribe.schema.tags import clean_tags
from scribe.schema.url import normalize_name
from scribe.error import TooManyClaimSearchParametersError
from scribe.common import LRUCache
from scribe.db.common import CLAIM_TYPES, STREAM_TYPES
from scribe.elasticsearch.constants import INDEX_DEFAULT_SETTINGS, REPLACEMENTS, FIELDS, TEXT_FIELDS, RANGE_FIELDS
from scribe.db.common import ResolveResult

if TYPE_CHECKING:
    from scribe.db import HubDB


class ChannelResolution(str):
    @classmethod
    def lookup_error(cls, url):
        return LookupError(f'Could not find channel in "{url}".')


class StreamResolution(str):
    @classmethod
    def lookup_error(cls, url):
        return LookupError(f'Could not find claim at "{url}".')


class IndexVersionMismatch(Exception):
    def __init__(self, got_version, expected_version):
        self.got_version = got_version
        self.expected_version = expected_version


class SearchIndex:
    VERSION = 1

    def __init__(self, hub_db: 'HubDB', index_prefix: str, search_timeout=3.0,
                 elastic_host='localhost', elastic_port=9200):
        self.hub_db = hub_db
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = logging.getLogger(__name__)
        self.claim_cache = LRUCache(2 ** 15)
        self.search_cache = LRUCache(2 ** 17)
        self._elastic_host = elastic_host
        self._elastic_port = elastic_port

    async def get_index_version(self) -> int:
        try:
            template = await self.sync_client.indices.get_template(self.index)
            return template[self.index]['version']
        except NotFoundError:
            return 0

    async def set_index_version(self, version):
        await self.sync_client.indices.put_template(
            self.index, body={'version': version, 'index_patterns': ['ignored']}, ignore=400
        )

    async def start(self) -> bool:
        if self.sync_client:
            return False
        hosts = [{'host': self._elastic_host, 'port': self._elastic_port}]
        self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(hosts, timeout=self.search_timeout)
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
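                # Elasticsearch isn't reachable yet; retry the health check once per second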
Waiting for it!") await asyncio.sleep(1) res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400) acked = res.get('acknowledged', False) if acked: await self.set_index_version(self.VERSION) return acked index_version = await self.get_index_version() if index_version != self.VERSION: self.logger.error("es search index has an incompatible version: %s vs %s", index_version, self.VERSION) raise IndexVersionMismatch(index_version, self.VERSION) await self.sync_client.indices.refresh(self.index) return acked async def stop(self): clients = [c for c in (self.sync_client, self.search_client) if c is not None] self.sync_client, self.search_client = None, None if clients: await asyncio.gather(*(client.close() for client in clients)) def clear_caches(self): self.search_cache.clear() self.claim_cache.clear() def _make_resolve_result(self, es_result): channel_hash = es_result['channel_hash'] reposted_claim_hash = es_result['reposted_claim_hash'] channel_tx_hash = None channel_tx_position = None channel_height = None reposted_tx_hash = None reposted_tx_position = None reposted_height = None if channel_hash: # FIXME: do this inside ES in a script channel_txo = self.hub_db.get_cached_claim_txo(channel_hash[::-1]) if channel_txo: channel_tx_hash = self.hub_db.get_tx_hash(channel_txo.tx_num) channel_tx_position = channel_txo.position channel_height = bisect_right(self.hub_db.tx_counts, channel_txo.tx_num) if reposted_claim_hash: repost_txo = self.hub_db.get_cached_claim_txo(reposted_claim_hash[::-1]) if repost_txo: reposted_tx_hash = self.hub_db.get_tx_hash(repost_txo.tx_num) reposted_tx_position = repost_txo.position reposted_height = bisect_right(self.hub_db.tx_counts, repost_txo.tx_num) return ResolveResult( name=es_result['claim_name'], normalized_name=es_result['normalized_name'], claim_hash=es_result['claim_hash'], tx_num=es_result['tx_num'], position=es_result['tx_nout'], tx_hash=es_result['tx_hash'], height=es_result['height'], amount=es_result['amount'], short_url=es_result['short_url'], is_controlling=es_result['is_controlling'], canonical_url=es_result['canonical_url'], creation_height=es_result['creation_height'], activation_height=es_result['activation_height'], expiration_height=es_result['expiration_height'], effective_amount=es_result['effective_amount'], support_amount=es_result['support_amount'], last_takeover_height=es_result['last_take_over_height'], claims_in_channel=es_result['claims_in_channel'], channel_hash=channel_hash, reposted_claim_hash=reposted_claim_hash, reposted=es_result['reposted'], signature_valid=es_result['signature_valid'], reposted_tx_hash=reposted_tx_hash, reposted_tx_position=reposted_tx_position, reposted_height=reposted_height, channel_tx_hash=channel_tx_hash, channel_tx_position=channel_tx_position, channel_height=channel_height, ) async def cached_search(self, kwargs): total_referenced = [] cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache) if cache_item.result is not None: return cache_item.result async with cache_item.lock: if cache_item.result: return cache_item.result response, offset, total = await self.search(**kwargs) censored = {} for row in response: if (row.get('censor_type') or 0) >= Censor.SEARCH: censoring_channel_hash = bytes.fromhex(row['censoring_channel_id'])[::-1] censored.setdefault(censoring_channel_hash, set()) censored[censoring_channel_hash].add(row['tx_hash']) total_referenced.extend(response) if censored: response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED) 
                total_referenced.extend(response)
            response = [self._make_resolve_result(r) for r in response]
            extra = [self._make_resolve_result(r) for r in await self._get_referenced_rows(total_referenced)]
            result = Outputs.to_base64(response, extra, offset, total, censored)
            cache_item.result = result
            return result

    async def get_many(self, *claim_ids):
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
        if missing:
            results = await self.search_client.mget(
                index=self.index, body={"ids": missing}
            )
            for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def search(self, **kwargs):
        try:
            return await self.search_ahead(**kwargs)
        except NotFoundError:
            return [], 0, 0
        # return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)

    async def search_ahead(self, **kwargs):
        # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
        per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
        remove_duplicates = kwargs.pop('remove_duplicates', False)
        page_size = kwargs.pop('limit', 10)
        offset = kwargs.pop('offset', 0)
        kwargs['limit'] = 1000
        cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
        if cache_item.result is not None:
            reordered_hits = cache_item.result
        else:
            async with cache_item.lock:
                if cache_item.result:
                    reordered_hits = cache_item.result
                else:
                    query = expand_query(**kwargs)
                    search_hits = deque((await self.search_client.search(
                        query, index=self.index, track_total_hits=False,
                        _source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height']
                    ))['hits']['hits'])
                    if remove_duplicates:
                        search_hits = self.__remove_duplicates(search_hits)
                    if per_channel_per_page > 0:
                        reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page)
                    else:
                        reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
                    cache_item.result = reordered_hits
        result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
        return result, 0, len(reordered_hits)

    def __remove_duplicates(self, search_hits: deque) -> deque:
        known_ids = {}  # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
        dropped = set()
        for hit in search_hits:
            hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id']
            if hit_id not in known_ids:
                known_ids[hit_id] = (hit_height, hit['_id'])
            else:
                previous_height, previous_id = known_ids[hit_id]
                if hit_height < previous_height:
                    known_ids[hit_id] = (hit_height, hit['_id'])
                    dropped.add(previous_id)
                else:
                    dropped.add(hit['_id'])
        return deque(hit for hit in search_hits if hit['_id'] not in dropped)

    def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
        reordered_hits = []
        channel_counters = Counter()
        next_page_hits_maybe_check_later = deque()
        while search_hits or next_page_hits_maybe_check_later:
            if reordered_hits and len(reordered_hits) % page_size == 0:
                channel_counters.clear()
            elif not reordered_hits:
                pass
            else:
                break  # means last page was incomplete and we are left with bad replacements
            for _ in range(len(next_page_hits_maybe_check_later)):
                claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
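                # a hit deferred from an earlier page can be emitted now if its channel still has quota on this page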
                if per_channel_per_page > 0 and channel_counters[channel_id] < per_channel_per_page:
                    reordered_hits.append((claim_id, channel_id))
                    channel_counters[channel_id] += 1
                else:
                    next_page_hits_maybe_check_later.append((claim_id, channel_id))
            while search_hits:
                hit = search_hits.popleft()
                hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
                if hit_channel_id is None or per_channel_per_page <= 0:
                    reordered_hits.append((hit_id, hit_channel_id))
                elif channel_counters[hit_channel_id] < per_channel_per_page:
                    reordered_hits.append((hit_id, hit_channel_id))
                    channel_counters[hit_channel_id] += 1
                    if len(reordered_hits) % page_size == 0:
                        break
                else:
                    next_page_hits_maybe_check_later.append((hit_id, hit_channel_id))
        return reordered_hits

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(filter(None, (row['censoring_channel_id'] for row in txo_rows)))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos


def expand_query(**kwargs):
    if "amount_order" in kwargs:
        kwargs["limit"] = 1
        kwargs["order_by"] = "effective_amount"
        kwargs["offset"] = int(kwargs["amount_order"]) - 1
    if 'name' in kwargs:
        kwargs['name'] = normalize_name(kwargs.pop('name'))
    if kwargs.get('is_controlling') is False:
        kwargs.pop('is_controlling')
    query = {'must': [], 'must_not': []}
    collapse = None
    if 'fee_currency' in kwargs and kwargs['fee_currency'] is not None:
        kwargs['fee_currency'] = kwargs['fee_currency'].upper()
    for key, value in kwargs.items():
        key = key.replace('claim.', '')
        many = key.endswith('__in') or isinstance(value, list)
        if many and len(value) > 2048:
            raise TooManyClaimSearchParametersError(key, 2048)
        if many:
            key = key.replace('__in', '')
            value = list(filter(None, value))
        if value is None or isinstance(value, list) and len(value) == 0:
            continue
        key = REPLACEMENTS.get(key, key)
        if key in FIELDS:
            partial_id = False
            if key == 'claim_type':
                if isinstance(value, str):
                    value = CLAIM_TYPES[value]
                else:
                    value = [CLAIM_TYPES[claim_type] for claim_type in value]
            elif key == 'stream_type':
                value = [STREAM_TYPES[value]] if isinstance(value, str) else list(map(STREAM_TYPES.get, value))
            if key == '_id':
                if isinstance(value, Iterable):
                    value = [item[::-1].hex() for item in value]
                else:
                    value = value[::-1].hex()
            if not many and key in ('_id', 'claim_id', 'sd_hash') and len(value) < 20:
                partial_id = True
            if key in ('signature_valid', 'has_source'):
                continue  # handled later
            if key in TEXT_FIELDS:
                key += '.keyword'
            ops = {'<=': 'lte', '>=': 'gte', '<': 'lt', '>': 'gt'}
            if partial_id:
                query['must'].append({"prefix": {key: value}})
            elif key in RANGE_FIELDS and isinstance(value, str) and value[0] in ops:
                operator_length = 2 if value[:2] in ops else 1
                operator, value = value[:operator_length], value[operator_length:]
                if key == 'fee_amount':
                    value = str(Decimal(value)*1000)
                query['must'].append({"range": {key: {ops[operator]: value}}})
            elif key in RANGE_FIELDS and isinstance(value, list) and all(v[0] in ops for v in value):
                range_constraints = []
                release_times = []
                for v in value:
                    operator_length = 2 if v[:2] in ops else 1
                    operator, stripped_op_v = v[:operator_length], v[operator_length:]
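                    # fee_amount constraints get scaled by 1000 (the index appears to store fee_amount scaled)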
                    if key == 'fee_amount':
                        stripped_op_v = str(Decimal(stripped_op_v)*1000)
                    if key == 'release_time':
                        release_times.append((operator, stripped_op_v))
                    else:
                        range_constraints.append((operator, stripped_op_v))
                if key != 'release_time':
                    query['must'].append(
                        {"range": {key: {ops[operator]: v for operator, v in range_constraints}}})
                else:
                    query['must'].append(
                        {"bool": {"should": [
                            {"bool": {"must_not": {"exists": {"field": "release_time"}}}},
                            {"bool": {"must": [
                                {"exists": {"field": "release_time"}},
                                {'range': {key: {ops[operator]: v for operator, v in release_times}}},
                            ]}},
                        ]}}
                    )
            elif many:
                query['must'].append({"terms": {key: value}})
            else:
                if key == 'fee_amount':
                    value = str(Decimal(value)*1000)
                query['must'].append({"term": {key: {"value": value}}})
        elif key == 'not_channel_ids':
            for channel_id in value:
                query['must_not'].append({"term": {'channel_id.keyword': channel_id}})
                query['must_not'].append({"term": {'_id': channel_id}})
        elif key == 'channel_ids':
            query['must'].append({"terms": {'channel_id.keyword': value}})
        elif key == 'claim_ids':
            query['must'].append({"terms": {'claim_id.keyword': value}})
        elif key == 'media_types':
            query['must'].append({"terms": {'media_type.keyword': value}})
        elif key == 'any_languages':
            query['must'].append({"terms": {'languages': clean_tags(value)}})
        elif key == 'all_languages':
            query['must'].extend([{"term": {'languages': tag}} for tag in value])
        elif key == 'any_tags':
            query['must'].append({"terms": {'tags.keyword': clean_tags(value)}})
        elif key == 'all_tags':
            query['must'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
        elif key == 'not_tags':
            query['must_not'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
        elif key == 'not_claim_id':
            query['must_not'].extend([{"term": {'claim_id.keyword': cid}} for cid in value])
        elif key == 'limit_claims_per_channel':
            collapse = ('channel_id.keyword', value)
    if kwargs.get('has_channel_signature'):
        query['must'].append({"exists": {"field": "signature"}})
        if 'signature_valid' in kwargs:
            query['must'].append({"term": {"is_signature_valid": bool(kwargs["signature_valid"])}})
    elif 'signature_valid' in kwargs:
        query['must'].append(
            {"bool": {"should": [
                {"bool": {"must_not": {"exists": {"field": "signature"}}}},
                {"bool": {"must": {"term": {"is_signature_valid": bool(kwargs["signature_valid"])}}}}
            ]}}
        )
    if 'has_source' in kwargs:
        is_stream_or_repost_terms = {"terms": {"claim_type": [CLAIM_TYPES['stream'], CLAIM_TYPES['repost']]}}
        query['must'].append(
            {"bool": {"should": [
                {"bool":  # when is_stream_or_repost AND has_source
                    {"must": [
                        {"match": {"has_source": kwargs['has_source']}},
                        is_stream_or_repost_terms,
                    ]}
                 },
                {"bool":  # when not is_stream_or_repost
                    {"must_not": is_stream_or_repost_terms}
                 },
                {"bool":  # when reposted_claim_type wouldn't have source
                    {"must_not": [
                        {"term": {"reposted_claim_type": CLAIM_TYPES['stream']}}
                    ],
                     "must": [
                        {"term": {"claim_type": CLAIM_TYPES['repost']}}
                    ]}
                 },
            ]}}
        )
    if kwargs.get('text'):
        query['must'].append(
            {"simple_query_string": {
                "query": kwargs["text"], "fields": [
                    "claim_name^4", "channel_name^8", "title^1", "description^.5", "author^1", "tags^.5"
                ]}})
    query = {
        "_source": {"excludes": ["description", "title"]},
        'query': {'bool': query},
        "sort": [],
    }
    if "limit" in kwargs:
        query["size"] = kwargs["limit"]
    if 'offset' in kwargs:
        query["from"] = kwargs["offset"]
    if 'order_by' in kwargs:
        if isinstance(kwargs["order_by"], str):
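            # a single sort field may be passed as a bare string; normalize it to a list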
[kwargs["order_by"]] for value in kwargs['order_by']: if 'trending_group' in value: # fixme: trending_mixed is 0 for all records on variable decay, making sort slow. continue is_asc = value.startswith('^') value = value[1:] if is_asc else value value = REPLACEMENTS.get(value, value) if value in TEXT_FIELDS: value += '.keyword' query['sort'].append({value: "asc" if is_asc else "desc"}) if collapse: query["collapse"] = { "field": collapse[0], "inner_hits": { "name": collapse[0], "size": collapse[1], "sort": query["sort"] } } return query def expand_result(results): inner_hits = [] expanded = [] for result in results: if result.get("inner_hits"): for _, inner_hit in result["inner_hits"].items(): inner_hits.extend(inner_hit["hits"]["hits"]) continue result = result['_source'] result['claim_hash'] = bytes.fromhex(result['claim_id'])[::-1] if result['reposted_claim_id']: result['reposted_claim_hash'] = bytes.fromhex(result['reposted_claim_id'])[::-1] else: result['reposted_claim_hash'] = None result['channel_hash'] = bytes.fromhex(result['channel_id'])[::-1] if result['channel_id'] else None result['txo_hash'] = bytes.fromhex(result['tx_id'])[::-1] + struct.pack(' str: return self._result @result.setter def result(self, result: str): self._result = result if result is not None: self.has_result.set() @classmethod def from_cache(cls, cache_key, cache): cache_item = cache.get(cache_key) if cache_item is None: cache_item = cache[cache_key] = ResultCacheItem() return cache_item