import asyncio
import struct
from binascii import hexlify, unhexlify
from decimal import Decimal
from operator import itemgetter
from typing import Optional, List, Iterable

from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
from elasticsearch.helpers import async_streaming_bulk

from lbry.crypto.base58 import Base58
from lbry.error import ResolveCensoredError, claim_id as parse_claim_id
from lbry.schema.result import Outputs, Censor
from lbry.schema.tags import clean_tags
from lbry.schema.url import URL, normalize_name
from lbry.utils import LRUCache
from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES
from lbry.wallet.server.util import class_logger


class ChannelResolution(str):
    """String subclass tagging a resolution produced for a URL without a stream part.

    ``SearchIndex.resolve`` checks for this type to pick the "could not find channel" message.
    """


class StreamResolution(str):
    """String subclass tagging a resolution produced for a URL with a stream part.

    ``SearchIndex.resolve`` checks for this type to pick the "could not find claim" message.
    """


class SearchIndex:
    """Elasticsearch-backed claim index offering indexing, search and URL resolution."""

    def __init__(self, index_prefix: str, search_timeout=3.0):
        """
        :param index_prefix: prepended to ``'claims'`` to form the index name.
        :param search_timeout: timeout handed to the client used for searches.
        """
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # wont hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = class_logger(__name__, self.__class__.__name__)
        # Originally annotated "invalidated on touched".
        # NOTE(review): the only invalidation visible in this class is the clear() in apply_filters.
        self.claim_cache = LRUCache(2 ** 15)
        self.short_id_cache = LRUCache(2 ** 17)  # never invalidated, since short ids are forever
        self.search_cache = LRUCache(2 ** 17)  # fixme: dont let session manager replace it
        self.resolution_cache = LRUCache(2 ** 17)

    async def start(self):
        """Create both clients, wait for the cluster and create the index.

        Does nothing (returns ``None``) when a client already exists. Otherwise returns the
        ``acknowledged`` flag of the index-creation response, defaulting to ``False``.
        """
        if self.client:
            return
        self.client = AsyncElasticsearch(timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
        # Poll once per second until the cluster answers the health check.
        while True:
            try:
                await self.client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)
        res = await self.client.indices.create(
            self.index,
            {
                "settings": {
                    "analysis": {
                        "analyzer": {
                            "default": {"tokenizer": "whitespace", "filter": ["lowercase", "porter_stem"]}
                        }
                    },
                    "index": {"refresh_interval": -1, "number_of_shards": 1, "number_of_replicas": 0}
                },
                "mappings": {
                    "properties": {
                        "claim_id": {
                            "fields": {
                                "keyword": {
                                    "ignore_above": 256,
                                    "type": "keyword"
                                }
                            },
                            "type": "text",
                            "index_prefixes": {
                                "min_chars": 1,
                                "max_chars": 10
                            }
                        },
                        "height": {"type": "integer"},
                        "claim_type": {"type": "byte"},
                        "censor_type": {"type": "byte"},
                        "trending_mixed": {"type": "float"},
                    }
                }
            },
            # presumably tolerates the "index already exists" response -- TODO confirm
            ignore=400
        )
        return res.get('acknowledged', False)

    def stop(self):
        """Detach both clients and return a future that closes them.

        Clients that were never created (``start`` not called) are skipped instead of
        raising ``AttributeError`` on ``None.close()``.
        """
        clients = [client for client in (self.client, self.search_client) if client is not None]
        self.client, self.search_client = None, None
        return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

    def delete_index(self):
        """Return the awaitable that deletes the index (``ignore_unavailable=True``)."""
        return self.client.indices.delete(self.index, ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        """Turn ``(op, doc)`` pairs into bulk actions, logging progress every 100 items.

        For ``'delete'`` operations ``doc`` is used directly as the document ``_id``;
        everything else is passed through ``extract_doc``.
        """
        count = 0
        for op, doc in claim_producer:
            if op == 'delete':
                yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
            else:
                yield extract_doc(doc, self.index)
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        self.logger.info("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        """Bulk-index everything yielded by ``claim_producer``, then refresh and flush the index."""
        await self.client.indices.refresh(self.index)
        # NOTE(review): `touched` is collected but never read in the code visible here.
        touched = set()
        bulk = async_streaming_bulk(
            self.client, self._consume_claim_producer(claim_producer), raise_on_error=False
        )
        async for ok, item in bulk:
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
            else:
                item = item.popitem()[1]
                touched.add(item['_id'])
        await self.client.indices.refresh(self.index)
        await self.client.indices.flush(self.index)
        self.logger.info("Indexing done.")

    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
        """Stamp ``censor_type`` / ``censoring_channel_hash`` on matching documents.

        Each argument is a mapping whose keys and values are reversed and hex-encoded before
        use. Filtered entries are written with ``censor_type`` 1, blocked entries with 2.
        Channel mappings are applied twice: once matching ``claim_id`` and once matching
        ``channel_id``. The search, claim and resolution caches are cleared afterwards.
        """
        def make_query(censor_type, blockdict, channels=False):
            blockdict = {
                hexlify(key[::-1]).decode(): hexlify(value[::-1]).decode()
                for key, value in blockdict.items()
            }
            # Only touch documents whose current censor_type is lower than the one being applied.
            if channels:
                update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
            else:
                update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
            key = 'channel_id' if channels else 'claim_id'
            update['script'] = {
                "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
                "lang": "painless",
                "params": blockdict
            }
            return update

        async def update_and_refresh(body):
            await self.client.update_by_query(self.index, body=body, slices=4)
            await self.client.indices.refresh(self.index)

        # Same order as before: filtered (1) first, then blocked (2); streams before channels.
        passes = (
            (1, filtered_streams, filtered_channels),
            (2, blocked_streams, blocked_channels),
        )
        for censor_type, streams, channels in passes:
            if streams:
                await update_and_refresh(make_query(censor_type, streams))
            if channels:
                await update_and_refresh(make_query(censor_type, channels))
                await update_and_refresh(make_query(censor_type, channels, True))
        self.search_cache.clear()
        self.claim_cache.clear()
        self.resolution_cache.clear()

    async def delete_above_height(self, height):
        """Delete every document whose ``height`` is greater than ``height``, then refresh."""
        await self.client.delete_by_query(self.index, expand_query(height='>'+str(height)))
        await self.client.indices.refresh(self.index)

    async def session_query(self, query_name, kwargs):
        """Run a ``'resolve'`` or search request and return it encoded by ``Outputs.to_base64``.

        For ``'resolve'``, ``kwargs`` is unpacked positionally as the URLs. For anything else
        it is unpacked as search keyword arguments, and the encoded result is cached under
        ``str(kwargs)``.
        """
        offset = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0
        total = 0
        total_referenced = []
        if query_name == 'resolve':
            total_referenced, response, censor = await self.resolve(*kwargs)
        else:
            cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
            if cache_item.result is not None:
                return cache_item.result
            async with cache_item.lock:
                # Re-check: another task may have filled the item while this one waited for the lock.
                if cache_item.result:
                    return cache_item.result
                censor = Censor(Censor.SEARCH)
                response, offset, total = await self.search(**kwargs)
                censor.apply(response)
                total_referenced.extend(response)
                if censor.censored:
                    # Something was censored: search again restricted to censor_type=0.
                    response, _, _ = await self.search(**kwargs, censor_type=0)
                    total_referenced.extend(response)
                result = Outputs.to_base64(
                    response, await self._get_referenced_rows(total_referenced), offset, total, censor
                )
                cache_item.result = result
                return result
        return Outputs.to_base64(response, await self._get_referenced_rows(total_referenced), offset, total, censor)

    async def resolve(self, *urls):
        """Resolve ``urls`` and return ``(results, censored, censor)``.

        ``results`` holds, per URL, the claim document, a ``LookupError`` when the resolved
        id has no document, or whatever ``resolve_url`` returned otherwise. ``censored`` is
        the same list with censored documents replaced by ``ResolveCensoredError``.
        """
        censor = Censor(Censor.RESOLVE)
        results = [await self.resolve_url(url) for url in urls]
        missing = await self.get_many(*filter(lambda x: isinstance(x, str), results))
        for index, result in enumerate(results):
            url = urls[index]
            if missing.get(result):
                results[index] = missing[result]
            elif isinstance(result, StreamResolution):
                results[index] = LookupError(f'Could not find claim at "{url}".')
            elif isinstance(result, ChannelResolution):
                results[index] = LookupError(f'Could not find channel in "{url}".')
        censored = [
            result if not isinstance(result, dict) or not censor.censor(result)
            else ResolveCensoredError(url, result['censoring_channel_hash'])
            for url, result in zip(urls, results)
        ]
        return results, censored, censor

    async def get_many(self, *claim_ids):
        """Return ``{claim_id: document}`` for the given ids, fetching uncached ones via ``mget``.

        Ids that are not found are left out of the returned mapping.
        """
        missing = [claim_id for claim_id in claim_ids if claim_id not in self.claim_cache]
        if missing:
            results = await self.search_client.mget(
                index=self.index, body={"ids": missing},
                _source_excludes=['description', 'title']
            )
            results = expand_result(filter(lambda doc: doc['found'], results["docs"]))
            for result in results:
                self.claim_cache.set(result['claim_id'], result)
        return {claim_id: self.claim_cache[claim_id] for claim_id in claim_ids if claim_id in self.claim_cache}

    async def full_id_from_short_id(self, name, short_id, channel_id=None):
        """Return the full claim id for ``name`` + ``short_id``, or ``None`` when nothing matches.

        Only successful lookups are stored: ``short_id_cache`` is never invalidated, so a
        remembered miss would keep hiding a claim indexed later.
        """
        key = (channel_id or '') + name + short_id
        if key not in self.short_id_cache:
            query = {'name': name, 'claim_id': short_id}
            if channel_id:
                query['channel_id'] = channel_id
                query['order_by'] = ['^channel_join']
                query['signature_valid'] = True
            else:
                query['order_by'] = '^creation_height'
            result, _, _ = await self.search(**query, limit=1)
            if len(result) != 1:
                return None
            self.short_id_cache[key] = result[0]['claim_id']
        return self.short_id_cache.get(key, None)

    async def search(self, **kwargs):
        """Run a search and return ``(rows, 0, total_hits)``.

        A ``channel`` argument is resolved to ``channel_id`` first; if it does not resolve
        to a string the result is ``([], 0, 0)``. The same empty result is returned when
        Elasticsearch raises ``NotFoundError``.
        """
        if 'channel' in kwargs:
            channel_id = await self.resolve_url(kwargs.pop('channel'))
            if not channel_id or not isinstance(channel_id, str):
                return [], 0, 0
            kwargs['channel_id'] = channel_id
        try:
            result = await self.search_client.search(
                expand_query(**kwargs), index=self.index, track_total_hits=200
            )
        except NotFoundError:
            # index has no docs, fixme: log something
            return [], 0, 0
        return expand_result(result['hits']['hits']), 0, result['hits']['total']['value']

    async def resolve_url(self, raw_url):
        """Return the resolution for ``raw_url``, memoised in ``resolution_cache``."""
        if raw_url not in self.resolution_cache:
            self.resolution_cache[raw_url] = await self._resolve_url(raw_url)
        return self.resolution_cache[raw_url]

    async def _resolve_url(self, raw_url):
        """Resolve ``raw_url`` without consulting ``resolution_cache``.

        Returns the ``ValueError`` raised by ``URL.parse``, the ``LookupError`` from channel
        resolution, or a ``StreamResolution`` / ``ChannelResolution`` string.
        """
        try:
            url = URL.parse(raw_url)
        except ValueError as e:
            return e
        stream = LookupError(f'Could not find claim at "{raw_url}".')
        channel_id = await self.resolve_channel_id(url)
        if isinstance(channel_id, LookupError):
            return channel_id
        stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream
        if url.has_stream:
            return StreamResolution(stream)
        else:
            return ChannelResolution(channel_id)

    async def resolve_channel_id(self, url: URL):
        """Return the channel claim id for ``url``.

        Returns ``None`` when the URL has no channel part and a ``LookupError`` instance
        (returned, not raised) when the channel cannot be found.
        """
        if not url.has_channel:
            return
        if url.channel.is_fullid:
            return url.channel.claim_id
        if url.channel.is_shortid:
            channel_id = await self.full_id_from_short_id(url.channel.name, url.channel.claim_id)
            if not channel_id:
                return LookupError(f'Could not find channel in "{url}".')
            return channel_id

        query = url.channel.to_dict()
        if set(query) == {'name'}:
            query['is_controlling'] = True
        else:
            query['order_by'] = ['^creation_height']
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            channel_id = matches[0]['claim_id']
        else:
            return LookupError(f'Could not find channel in "{url}".')
        return channel_id

    async def resolve_stream(self, url: URL, channel_id: str = None):
        """Return the stream claim id for ``url`` (falsy when it cannot be determined).

        ``channel_id`` must be supplied when the URL has a channel part; without it the
        stream is not looked up.
        """
        if not url.has_stream:
            return None
        if url.has_channel and channel_id is None:
            return None
        query = url.stream.to_dict()
        if url.stream.claim_id is not None:
            if url.stream.is_fullid:
                claim_id = url.stream.claim_id
            else:
                claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id)
            return claim_id

        if channel_id is not None:
            if set(query) == {'name'}:
                # temporarily emulate is_controlling for claims in channel
                query['order_by'] = ['effective_amount', '^height']
            else:
                query['order_by'] = ['^channel_join']
            query['channel_id'] = channel_id
            query['signature_valid'] = True
        elif set(query) == {'name'}:
            query['is_controlling'] = True
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            return matches[0]['claim_id']

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        """Fetch the channel and reposted-claim documents referenced by ``txo_rows``.

        Entries of ``txo_rows`` that are not dicts are ignored.
        """
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        repost_hashes = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        channel_hashes = set(filter(None, (row['channel_id'] for row in txo_rows)))
        channel_hashes |= set(map(parse_claim_id, filter(None, (row['censoring_channel_hash'] for row in txo_rows))))

        reposted_txos = []
        if repost_hashes:
            reposted_txos = list((await self.get_many(*repost_hashes)).values())
            # Channels of the reposted claims are referenced as well.
            channel_hashes |= set(filter(None, (row['channel_id'] for row in reposted_txos)))

        channel_txos = []
        if channel_hashes:
            channel_txos = list((await self.get_many(*channel_hashes)).values())

        # channels must come first for client side inflation to work properly
        return channel_txos + reposted_txos


# NOTE(review): in the collapsed source the `def` keyword that opens `extract_doc(doc, index)` sat at the
# end of this span; it must be re-attached to the definition that follows when that text is repaired.
extract_doc(doc, index): doc['claim_id'] = hexlify(doc.pop('claim_hash')[::-1]).decode() if doc['reposted_claim_hash'] is not None: doc['reposted_claim_id'] = hexlify(doc.pop('reposted_claim_hash')[::-1]).decode() else: doc['reposted_claim_id'] = None channel_hash = doc.pop('channel_hash') doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash channel_hash = doc.pop('censoring_channel_hash') doc['censoring_channel_hash'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash txo_hash = doc.pop('txo_hash') doc['tx_id'] = hexlify(txo_hash[:32][::-1]).decode() doc['tx_nout'] = struct.unpack('=': 'gte', '<': 'lt', '>': 'gt'} if partial_id: query['must'].append({"prefix": {"claim_id": value}}) elif key in RANGE_FIELDS and isinstance(value, str) and value[0] in ops: operator_length = 2 if value[:2] in ops else 1 operator, value = value[:operator_length], value[operator_length:] if key == 'fee_amount': value = str(Decimal(value)*1000) query['must'].append({"range": {key: {ops[operator]: value}}}) elif many: query['must'].append({"terms": {key: value}}) else: if key == 'fee_amount': value = str(Decimal(value)*1000) query['must'].append({"term": {key: {"value": value}}}) elif key == 'not_channel_ids': for channel_id in value: query['must_not'].append({"term": {'channel_id.keyword': channel_id}}) query['must_not'].append({"term": {'_id': channel_id}}) elif key == 'channel_ids': query['must'].append({"terms": {'channel_id.keyword': value}}) elif key == 'claim_ids': query['must'].append({"terms": {'claim_id.keyword': value}}) elif key == 'media_types': query['must'].append({"terms": {'media_type.keyword': value}}) elif key == 'stream_types': query['must'].append({"terms": {'stream_type': [STREAM_TYPES[stype] for stype in value]}}) elif key == 'any_languages': query['must'].append({"terms": {'languages': clean_tags(value)}}) elif key == 'any_languages': query['must'].append({"terms": {'languages': value}}) elif key == 
'all_languages': query['must'].extend([{"term": {'languages': tag}} for tag in value]) elif key == 'any_tags': query['must'].append({"terms": {'tags.keyword': clean_tags(value)}}) elif key == 'all_tags': query['must'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)]) elif key == 'not_tags': query['must_not'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)]) elif key == 'not_claim_id': query['must_not'].extend([{"term": {'claim_id.keyword': cid}} for cid in value]) elif key == 'limit_claims_per_channel': collapse = ('channel_id.keyword', value) if kwargs.get('has_channel_signature'): query['must'].append({"exists": {"field": "signature_digest"}}) if 'signature_valid' in kwargs: query['must'].append({"term": {"signature_valid": bool(kwargs["signature_valid"])}}) elif 'signature_valid' in kwargs: query.setdefault('should', []) query["minimum_should_match"] = 1 query['should'].append({"bool": {"must_not": {"exists": {"field": "signature_digest"}}}}) query['should'].append({"term": {"signature_valid": bool(kwargs["signature_valid"])}}) if kwargs.get('text'): query['must'].append( {"simple_query_string": {"query": kwargs["text"], "fields": [ "claim_name^4", "channel_name^8", "title^1", "description^.5", "author^1", "tags^.5" ]}}) query = { "_source": {"excludes": ["description", "title"]}, 'query': {'bool': query}, "sort": [], } if "limit" in kwargs: query["size"] = kwargs["limit"] if 'offset' in kwargs: query["from"] = kwargs["offset"] if 'order_by' in kwargs: if isinstance(kwargs["order_by"], str): kwargs["order_by"] = [kwargs["order_by"]] for value in kwargs['order_by']: if 'trending_group' in value: # fixme: trending_mixed is 0 for all records on variable decay, making sort slow. 
continue is_asc = value.startswith('^') value = value[1:] if is_asc else value value = REPLACEMENTS.get(value, value) if value in TEXT_FIELDS: value += '.keyword' query['sort'].append({value: "asc" if is_asc else "desc"}) if collapse: query["collapse"] = { "field": collapse[0], "inner_hits": { "name": collapse[0], "size": collapse[1], "sort": query["sort"] } } return query def expand_result(results): inner_hits = [] expanded = [] for result in results: if result.get("inner_hits"): for _, inner_hit in result["inner_hits"].items(): inner_hits.extend(inner_hit["hits"]["hits"]) continue result = result['_source'] result['claim_hash'] = unhexlify(result['claim_id'])[::-1] if result['reposted_claim_id']: result['reposted_claim_hash'] = unhexlify(result['reposted_claim_id'])[::-1] else: result['reposted_claim_hash'] = None result['channel_hash'] = unhexlify(result['channel_id'])[::-1] if result['channel_id'] else None result['txo_hash'] = unhexlify(result['tx_id'])[::-1] + struct.pack(' str: return self._result @result.setter def result(self, result: str): self._result = result if result is not None: self.has_result.set() @classmethod def from_cache(cls, cache_key, cache): cache_item = cache.get(cache_key) if cache_item is None: cache_item = cache[cache_key] = ResultCacheItem() return cache_item