lbry-sdk/lbry/wallet/server/db/elastic_search.py
2021-03-19 19:58:13 -03:00

326 lines
14 KiB
Python

import asyncio
import struct
from binascii import hexlify, unhexlify
from decimal import Decimal
from operator import itemgetter
from typing import Optional, List, Iterable
from elasticsearch import AsyncElasticsearch, NotFoundError
from elasticsearch.helpers import async_bulk
from lbry.crypto.base58 import Base58
from lbry.schema.result import Outputs
from lbry.schema.tags import clean_tags
from lbry.schema.url import URL
from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES
class SearchIndex:
    """Claim search index backed by an Elasticsearch cluster.

    The index name is ``index_prefix + 'claims'``. The client is created by
    :meth:`start` and released by :meth:`stop`; every other method expects
    :meth:`start` to have been called first.
    """

    def __init__(self, index_prefix: str):
        self.client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'

    async def start(self):
        """Create the client and, if it does not exist yet, the claims index.

        The index is created with a ``porter`` analyzer (whitespace tokenizer,
        lowercase + porter_stem filters) that text queries refer to by name.
        Any error raised by the client propagates to the caller.
        """
        self.client = AsyncElasticsearch()
        if await self.client.indices.exists(self.index):
            return
        await self.client.indices.create(
            self.index,
            {"settings":
                {"analysis":
                    {"analyzer": {"porter": {"tokenizer": "whitespace", "filter": ["lowercase", "porter_stem"]}}}
                 }
             }
        )

    def stop(self):
        """Schedule the client to be closed and forget it.

        Safe to call when the index was never started or was already stopped.
        """
        if self.client is None:
            return
        asyncio.ensure_future(self.client.close())
        self.client = None

    def delete_index(self):
        """Return the awaitable that deletes the whole claims index."""
        return self.client.indices.delete(self.index)

    async def sync_queue(self, claim_queue):
        """Drain ``claim_queue`` and apply its operations to the index.

        Each queue item is an ``(operation, doc)`` pair. Items whose operation
        is ``'delete'`` are deleted, everything else is upserted. Deletes are
        applied before updates, then the index is refreshed.
        """
        if claim_queue.empty():
            return
        to_delete, to_update = [], []
        while not claim_queue.empty():
            operation, doc = claim_queue.get_nowait()
            if operation == 'delete':
                to_delete.append(doc)
            else:
                to_update.append(doc)
        await self.delete(to_delete)
        await self.update(to_update)
        await self.client.indices.refresh(self.index)

    async def update(self, claims):
        """Bulk-upsert ``claims`` (raw claim dicts accepted by ``extract_doc``)."""
        if not claims:
            return
        actions = [extract_doc(claim, self.index) for claim in claims]
        await async_bulk(self.client, actions)

    async def delete(self, claim_ids):
        """Bulk-delete the documents with the given ids.

        Afterwards every remaining document whose ``channel_id`` is one of the
        deleted ids gets ``signature_valid`` set to false.
        """
        if not claim_ids:
            return
        actions = [{'_index': self.index, '_op_type': 'delete', '_id': claim_id} for claim_id in claim_ids]
        await async_bulk(self.client, actions)
        update = expand_query(channel_id__in=claim_ids)
        update['script'] = {
            "source": "ctx._source.signature_valid=false",
            "lang": "painless"
        }
        await self.client.update_by_query(self.index, body=update)

    async def session_query(self, query_name, function, kwargs):
        """Run a ``resolve`` or search request and encode the result.

        :param query_name: ``'resolve'`` to resolve URLs; anything else searches.
        :param function: unused here; kept for interface compatibility.
        :param kwargs: an iterable of URLs for ``resolve``, otherwise the
            search constraints as a dict.
        :return: the value of ``Outputs.to_base64`` for the rows, the rows they
            reference, the requested offset and the total hit count.
        """
        # Echo back the offset the caller asked for; search() always reports 0.
        offset = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0
        total = 0
        if query_name == 'resolve':
            response = await self.resolve(*kwargs)
        else:
            response, _, total = await self.search(**kwargs)
        return Outputs.to_base64(response, await self._get_referenced_rows(response), offset, total)

    async def resolve(self, *urls):
        """Resolve all ``urls`` concurrently, returning one result per URL."""
        results = await asyncio.gather(*(self.resolve_url(url) for url in urls))
        return results

    async def search(self, **kwargs):
        """Search the index with the given constraints.

        A ``channel`` constraint is resolved as a URL first and replaced by a
        ``channel_id`` constraint; if it does not resolve the search is empty.

        :return: ``(rows, 0, total)`` where ``rows`` are the expanded hits and
            ``total`` is the hit count reported by Elasticsearch.
        """
        if 'channel' in kwargs:
            result = await self.resolve_url(kwargs.pop('channel'))
            # resolve_url reports failures as exception instances, which are not iterable
            if not result or not isinstance(result, Iterable):
                return [], 0, 0
            kwargs['channel_id'] = result['_id']
        try:
            result = await self.client.search(expand_query(**kwargs), self.index)
        except NotFoundError:
            # index has no docs, fixme: log something
            return [], 0, 0
        return expand_result(result['hits']['hits']), 0, result['hits']['total']['value']

    async def resolve_url(self, raw_url):
        """Resolve a single URL to a claim row.

        Errors are returned, not raised: the ``ValueError`` from ``URL.parse``
        for a malformed URL, or a ``LookupError`` when the channel or the
        stream cannot be found. On success the stream row is returned, or the
        channel row when the URL has no stream part.
        """
        try:
            url = URL.parse(raw_url)
        except ValueError as e:
            return e

        channel = None

        if url.has_channel:
            query = url.channel.to_dict()
            if set(query) == {'name'}:
                query['is_controlling'] = True
            else:
                query['order_by'] = ['^creation_height']
            matches, _, _ = await self.search(**query, limit=1)
            if matches:
                channel = matches[0]
            else:
                return LookupError(f'Could not find channel in "{raw_url}".')

        if url.has_stream:
            query = url.stream.to_dict()
            if channel is not None:
                if set(query) == {'name'}:
                    # temporarily emulate is_controlling for claims in channel
                    query['order_by'] = ['effective_amount', '^height']
                else:
                    query['order_by'] = ['^channel_join']
                # Filter on the indexed channel_id field: expand_query drops
                # unknown keys, so 'channel_hash' would not constrain anything.
                query['channel_id'] = channel['claim_id']
                query['signature_valid'] = True
            elif set(query) == {'name'}:
                query['is_controlling'] = True
            matches, _, _ = await self.search(**query, limit=1)
            if matches:
                return matches[0]
            else:
                return LookupError(f'Could not find claim at "{raw_url}".')

        return channel

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        """Fetch the channels and reposted claims referenced by ``txo_rows``.

        Non-dict entries (resolve errors) are ignored. Channels of the reposted
        claims are fetched as well.

        :return: channel rows followed by reposted rows.
        """
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows)))
        channel_hashes = set(filter(None, (row['channel_hash'] for row in txo_rows)))

        reposted_txos = []
        if repost_hashes:
            # An explicit limit is required: without "size" Elasticsearch
            # returns only its default page of hits.
            reposted_txos, _, _ = await self.search(
                limit=len(repost_hashes), **{'claim.claim_hash__in': repost_hashes}
            )
            channel_hashes |= set(filter(None, (row['channel_hash'] for row in reposted_txos)))

        channel_txos = []
        if channel_hashes:
            channel_txos, _, _ = await self.search(
                limit=len(channel_hashes), **{'claim.claim_hash__in': channel_hashes}
            )

        # channels must come first for client side inflation to work properly
        return channel_txos + reposted_txos
def extract_doc(doc, index):
    """Turn a raw claim dict into a bulk ``update`` action (upsert) for ``index``.

    ``doc`` is modified in place: binary hashes are replaced by hex string
    ids (byte-reversed for claim, repost, channel and transaction hashes),
    ``txo_hash`` is split into ``tx_id`` and ``tx_nout``, the signature and
    public key byte fields become hex strings (``None`` when empty), and the
    flag / type fields are normalized to ``bool`` / ``int``.

    :param doc: claim dict; must contain the hash, signature and flag keys
        read below.
    :param index: name of the target index.
    :return: action dict whose ``_id`` is the hex claim id.
    """
    def reversed_hex(raw):
        # hex encoding of the bytes in reverse order
        return hexlify(raw[::-1]).decode()

    doc['claim_id'] = reversed_hex(doc.pop('claim_hash'))

    # A None repost hash is left in the doc; only a real one is replaced.
    if doc['reposted_claim_hash'] is None:
        doc['reposted_claim_id'] = None
    else:
        doc['reposted_claim_id'] = reversed_hex(doc.pop('reposted_claim_hash'))

    channel_hash = doc.pop('channel_hash')
    if channel_hash:
        doc['channel_id'] = reversed_hex(channel_hash)
    else:
        doc['channel_id'] = channel_hash

    # txo_hash = 32-byte transaction hash followed by a little-endian uint32
    txo_hash = doc.pop('txo_hash')
    tx_hash, packed_nout = txo_hash[:32], txo_hash[32:]
    doc['tx_id'] = reversed_hex(tx_hash)
    (doc['tx_nout'],) = struct.unpack('<I', packed_nout)

    doc['is_controlling'] = bool(doc['is_controlling'])

    for field in ('signature', 'signature_digest', 'public_key_bytes', 'public_key_hash'):
        raw = doc.pop(field) or b''
        doc[field] = hexlify(raw).decode() or None

    doc['signature_valid'] = bool(doc['signature_valid'])
    doc['claim_type'] = doc.get('claim_type', 0) or 0
    doc['stream_type'] = int(doc.get('stream_type', 0) or 0)

    return {
        'doc': doc,
        '_id': doc['claim_id'],
        '_index': index,
        '_op_type': 'update',
        'doc_as_upsert': True,
    }
# Constraint names that expand_query translates directly into term / terms / range clauses.
FIELDS = ['is_controlling', 'last_take_over_height', 'claim_id', 'claim_name', 'normalized', 'tx_position', 'amount',
          'timestamp', 'creation_timestamp', 'height', 'creation_height', 'activation_height', 'expiration_height',
          'release_time', 'short_url', 'canonical_url', 'title', 'author', 'description', 'claim_type', 'reposted',
          'stream_type', 'media_type', 'fee_amount', 'fee_currency', 'duration', 'reposted_claim_hash',
          'claims_in_channel', 'channel_join', 'signature_valid', 'effective_amount', 'support_amount',
          'trending_group', 'trending_mixed', 'trending_local', 'trending_global', 'channel_id', 'tx_id', 'tx_nout',
          'signature', 'signature_digest', 'public_key_bytes', 'public_key_hash', 'public_key_id', '_id', 'tags']
# Fields that are matched and sorted through their ``.keyword`` sub-field.
TEXT_FIELDS = ['author', 'canonical_url', 'channel_id', 'claim_id', 'claim_name', 'description',
               'media_type', 'normalized', 'public_key_bytes', 'public_key_hash', 'short_url', 'signature',
               'signature_digest', 'stream_type', 'title', 'tx_id', 'fee_currency']
# Fields whose string values may start with a comparison operator (<, <=, >, >=).
RANGE_FIELDS = ['height', 'fee_amount', 'duration']
# Caller-facing constraint names mapped to the field names used in the index.
REPLACEMENTS = {
    'name': 'claim_name',
    'txid': 'tx_id',
    'claim_hash': '_id',
}


def expand_query(**kwargs):
    """Translate search constraints into an Elasticsearch request body.

    Key handling:

    * a ``claim.`` prefix is dropped and a ``__in`` suffix selects a
      ``terms`` (match any of) clause instead of ``term``;
    * names are mapped through ``REPLACEMENTS``; keys in ``FIELDS`` become
      term / terms / range clauses, a handful of plural keys
      (``channel_ids``, ``any_tags``, ``not_tags``, ...) have dedicated
      handling, and any other key is ignored by the filter loop;
    * ``claim_hash`` values are binary hashes (a single bytes object or an
      iterable of them) and are converted to byte-reversed hex ids;
    * ``limit``, ``offset`` and ``order_by`` set ``size``, ``from`` and
      ``sort``; an ``order_by`` entry prefixed with ``^`` sorts ascending;
    * ``limit_claims_per_channel`` collapses results on the channel id.

    When ``text`` is given, only a ``query_string`` query over the text
    fields is returned and every other constraint is disregarded.

    :return: dict suitable as the body of a search request.
    """
    query = {'must': [], 'must_not': []}
    collapse = None
    ops = {'<=': 'lte', '>=': 'gte', '<': 'lt', '>': 'gt'}
    for key, value in kwargs.items():
        key = key.replace('claim.', '')
        many = key.endswith('__in')
        if many:
            # strip only the suffix
            key = key[:-len('__in')]
        key = REPLACEMENTS.get(key, key)
        if key in FIELDS:
            if key == 'claim_type':
                if isinstance(value, str):
                    value = CLAIM_TYPES[value]
                else:
                    value = [CLAIM_TYPES[claim_type] for claim_type in value]
            if key == '_id':
                # bytes is itself iterable, so a single hash has to be
                # recognised before falling back to "iterable of hashes"
                if isinstance(value, (bytes, bytearray)):
                    value = hexlify(value[::-1]).decode()
                else:
                    value = [hexlify(item[::-1]).decode() for item in value]
            if key == 'public_key_id':
                key = 'public_key_hash'
                value = hexlify(Base58.decode(value)[1:21]).decode()
            if key == 'signature_valid':
                continue  # handled later
            if key in TEXT_FIELDS:
                key += '.keyword'
            if key in RANGE_FIELDS and isinstance(value, str) and value[0] in ops:
                # two-character operators (<=, >=) take precedence over < and >
                operator_length = 2 if value[:2] in ops else 1
                operator, value = value[:operator_length], value[operator_length:]
                if key == 'fee_amount':
                    value = Decimal(value)*1000
                query['must'].append({"range": {key: {ops[operator]: value}}})
            elif many:
                query['must'].append({"terms": {key: value}})
            else:
                if key == 'fee_amount':
                    value = Decimal(value)*1000
                query['must'].append({"term": {key: {"value": value}}})
        elif key == 'not_channel_ids':
            # exclude both the claims in those channels and the channels themselves
            for channel_id in value:
                query['must_not'].append({"term": {'channel_id.keyword': channel_id}})
                query['must_not'].append({"term": {'_id': channel_id}})
        elif key == 'channel_ids':
            query['must'].append({"terms": {'channel_id.keyword': value}})
        elif key == 'media_types':
            query['must'].append({"terms": {'media_type.keyword': value}})
        elif key == 'stream_types':
            query['must'].append({"terms": {'stream_type': [STREAM_TYPES[stype] for stype in value]}})
        elif key == 'any_languages':
            query['must'].append({"terms": {'languages': clean_tags(value)}})
        elif key == 'all_languages':
            query['must'].extend([{"term": {'languages': tag}} for tag in value])
        elif key == 'any_tags':
            query['must'].append({"terms": {'tags': clean_tags(value)}})
        elif key == 'all_tags':
            query['must'].extend([{"term": {'tags': tag}} for tag in clean_tags(value)])
        elif key == 'not_tags':
            query['must_not'].extend([{"term": {'tags': tag}} for tag in clean_tags(value)])
        elif key == 'limit_claims_per_channel':
            collapse = ('channel_id.keyword', value)
    if kwargs.get('has_channel_signature'):
        # only signed claims, optionally restricted by signature validity
        query['must'].append({"exists": {"field": "signature_digest"}})
        if 'signature_valid' in kwargs:
            query['must'].append({"term": {"signature_valid": bool(kwargs["signature_valid"])}})
    elif 'signature_valid' in kwargs:
        # unsigned claims pass; signed ones must match the requested validity
        query.setdefault('should', [])
        query["minimum_should_match"] = 1
        query['should'].append({"bool": {"must_not": {"exists": {"field": "signature_digest"}}}})
        query['should'].append({"term": {"signature_valid": bool(kwargs["signature_valid"])}})
    if 'text' in kwargs:
        return {"query":
                {"query_string":
                 {"query": kwargs["text"], "fields": [
                     "claim_name", "channel_name", "title", "description", "author", "tags"
                 ], "analyzer": "porter"}}}
    query = {
        'query': {'bool': query},
        "sort": [],
    }
    if "limit" in kwargs:
        query["size"] = kwargs["limit"]
    if 'offset' in kwargs:
        query["from"] = kwargs["offset"]
    if 'order_by' in kwargs:
        order_by = kwargs['order_by']
        if isinstance(order_by, str):
            # a bare string is one sort key, not a sequence of characters
            order_by = [order_by]
        for value in order_by:
            is_asc = value.startswith('^')
            value = value[1:] if is_asc else value
            value = REPLACEMENTS.get(value, value)
            if value in TEXT_FIELDS:
                value += '.keyword'
            query['sort'].append({value: "asc" if is_asc else "desc"})
    if collapse:
        query["collapse"] = {
            "field": collapse[0],
            "inner_hits": {
                "name": collapse[0],
                "size": collapse[1],
                "sort": query["sort"]
            }
        }
    return query
def expand_result(results):
    """Flatten search hits into claim rows carrying binary hashes.

    Each plain hit is updated in place: its ``_source`` fields are merged
    into the hit and ``claim_hash``, ``reposted_claim_hash``,
    ``channel_hash`` and ``txo_hash`` are rebuilt from the hex ids.

    Hits that carry ``inner_hits`` (collapsed results) are not expanded
    themselves; if any were seen, the collected inner hits are expanded and
    returned instead of ``results``.

    :param results: list of hit dicts from a search response.
    :return: ``results`` itself, or the expanded inner hits.
    """
    collapsed = []
    for hit in results:
        groups = hit.get("inner_hits")
        if groups:
            for group in groups.values():
                collapsed.extend(group["hits"]["hits"])
        else:
            source = hit.pop('_source')
            hit.update(source)
            hit['claim_hash'] = unhexlify(hit['claim_id'])[::-1]
            reposted_id = hit['reposted_claim_id']
            hit['reposted_claim_hash'] = unhexlify(reposted_id)[::-1] if reposted_id else None
            channel_id = hit['channel_id']
            hit['channel_hash'] = unhexlify(channel_id)[::-1] if channel_id else None
            # txo_hash = reversed tx hash followed by the little-endian output index
            packed_nout = struct.pack('<I', hit['tx_nout'])
            hit['txo_hash'] = unhexlify(hit['tx_id'])[::-1] + packed_nout
    return expand_result(collapsed) if collapsed else results