claim search and resolve translated to ES queries

This commit is contained in:
Victor Shyba 2021-01-19 04:37:31 -03:00
parent 488785d013
commit 996686c1da
8 changed files with 335 additions and 57 deletions

View file

@ -11,6 +11,7 @@ import importlib
from binascii import hexlify
from typing import Type, Optional
import urllib.request
from uuid import uuid4
import lbry
from lbry.wallet.server.server import Server
@ -187,7 +188,8 @@ class SPVNode:
'SESSION_TIMEOUT': str(self.session_timeout),
'MAX_QUERY_WORKERS': '0',
'INDIVIDUAL_TAG_INDEXES': '',
'RPC_PORT': self.rpc_port
'RPC_PORT': self.rpc_port,
'ES_INDEX_PREFIX': uuid4().hex
}
if extraconf:
conf.update(extraconf)
@ -199,6 +201,7 @@ class SPVNode:
async def stop(self, cleanup=True):
try:
await self.server.db.search_index.delete_index()
await self.server.stop()
finally:
cleanup and self.cleanup()

View file

@ -6,6 +6,7 @@ from typing import Optional
from prometheus_client import Gauge, Histogram
import lbry
from lbry.schema.claim import Claim
from lbry.wallet.server.db.elastic_search import SearchIndex
from lbry.wallet.server.db.writer import SQLDB
from lbry.wallet.server.daemon import DaemonError
from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN
@ -215,6 +216,7 @@ class BlockProcessor:
if hprevs == chain:
start = time.perf_counter()
await self.run_in_thread_with_lock(self.advance_blocks, blocks)
await self.db.search_index.sync_queue(self.sql.claim_queue)
for cache in self.search_cache.values():
cache.clear()
self.history_cache.clear()
@ -651,7 +653,11 @@ class BlockProcessor:
self.reorg_count = 0
else:
blocks = self.prefetcher.get_prefetched_blocks()
await self.check_and_advance_blocks(blocks)
try:
await self.check_and_advance_blocks(blocks)
except Exception:
self.logger.exception("error while processing txs")
raise
async def _first_caught_up(self):
self.logger.info(f'caught up to height {self.height}')
@ -803,18 +809,3 @@ class LBRYBlockProcessor(BlockProcessor):
if (height % 10000 == 0 or not self.db.first_sync) and self.logger.isEnabledFor(10):
self.timer.show(height=height)
return undo
def _checksig(self, value, address):
try:
claim_dict = Claim.from_bytes(value)
cert_id = claim_dict.signing_channel_hash
if not self.should_validate_signatures:
return cert_id
if cert_id:
cert_claim = self.db.get_claim_info(cert_id)
if cert_claim:
certificate = Claim.from_bytes(cert_claim.value)
claim_dict.validate_signature(address, certificate)
return cert_id
except Exception:
pass

View file

@ -1,40 +1,161 @@
import asyncio
import struct
from binascii import hexlify
from multiprocessing.queues import Queue
from binascii import hexlify, unhexlify
from decimal import Decimal
from operator import itemgetter
from typing import Optional, List, Iterable
from elasticsearch import AsyncElasticsearch
from elasticsearch import AsyncElasticsearch, NotFoundError
from elasticsearch.helpers import async_bulk
from lbry.wallet.constants import CLAIM_TYPE_NAMES
from lbry.crypto.base58 import Base58
from lbry.schema.result import Outputs
from lbry.schema.tags import clean_tags
from lbry.schema.url import URL
from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES
async def indexer_task(claim_queue: Queue, index='claims'):
es = AsyncElasticsearch()
try:
await consume(es, claim_queue, index)
finally:
await es.close()
class SearchIndex:
def __init__(self, index_prefix: str):
self.client: Optional[AsyncElasticsearch] = None
self.index = index_prefix + 'claims'
async def start(self):
self.client = AsyncElasticsearch()
try:
if await self.client.indices.exists(self.index):
return
await self.client.indices.create(
self.index,
{"settings":
{"analysis":
{"analyzer": {"porter": {"tokenizer": "whitespace", "filter": ["lowercase", "porter_stem" ]}}}
}
}
)
except Exception as e:
raise
async def consume(es, claim_queue, index):
to_send = []
while True:
if not claim_queue.empty():
def stop(self):
asyncio.ensure_future(self.client.close())
self.client = None
def delete_index(self):
return self.client.indices.delete(self.index)
async def sync_queue(self, claim_queue):
if claim_queue.empty():
return
to_delete, to_update = [], []
while not claim_queue.empty():
operation, doc = claim_queue.get_nowait()
if operation == 'delete':
to_send.append({'_index': index, '_op_type': 'delete', '_id': hexlify(doc[::-1]).decode()})
continue
try:
to_send.append(extract_doc(doc, index))
except OSError as e:
print(e)
else:
if to_send:
print(await async_bulk(es, to_send, raise_on_error=False))
to_send.clear()
to_delete.append(doc)
else:
await asyncio.sleep(.1)
to_update.append(doc)
await self.delete(to_delete)
await self.update(to_update)
await self.client.indices.refresh(self.index)
async def update(self, claims):
if not claims:
return
actions = [extract_doc(claim, self.index) for claim in claims]
await async_bulk(self.client, actions)
async def delete(self, claim_ids):
if not claim_ids:
return
actions = [{'_index': self.index, '_op_type': 'delete', '_id': claim_id} for claim_id in claim_ids]
await async_bulk(self.client, actions)
update = expand_query(channel_id__in=claim_ids)
update['script'] = {
"source": "ctx._source.signature_valid=false",
"lang": "painless"
}
await self.client.update_by_query(self.index, body=update)
async def session_query(self, query_name, function, kwargs):
offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0
if query_name == 'resolve':
response = await self.resolve(*kwargs)
else:
response, offset, total = await self.search(**kwargs)
return Outputs.to_base64(response, await self._get_referenced_rows(response), offset, total)
async def resolve(self, *urls):
results = await asyncio.gather(*(self.resolve_url(url) for url in urls))
return results
async def search(self, **kwargs):
if 'channel' in kwargs:
result = await self.resolve_url(kwargs.pop('channel'))
if not result or not isinstance(result, Iterable):
return [], 0, 0
kwargs['channel_id'] = result['_id']
try:
result = await self.client.search(expand_query(**kwargs), self.index)
except NotFoundError:
# index has no docs, fixme: log something
return [], 0, 0
return expand_result(result['hits']['hits']), 0, result['hits']['total']['value']
async def resolve_url(self, raw_url):
try:
url = URL.parse(raw_url)
except ValueError as e:
return e
channel = None
if url.has_channel:
query = url.channel.to_dict()
if set(query) == {'name'}:
query['is_controlling'] = True
else:
query['order_by'] = ['^creation_height']
matches, _, _ = await self.search(**query, limit=1)
if matches:
channel = matches[0]
else:
return LookupError(f'Could not find channel in "{raw_url}".')
if url.has_stream:
query = url.stream.to_dict()
if channel is not None:
if set(query) == {'name'}:
# temporarily emulate is_controlling for claims in channel
query['order_by'] = ['effective_amount', '^height']
else:
query['order_by'] = ['^channel_join']
query['channel_hash'] = channel['claim_hash']
query['signature_valid'] = True
elif set(query) == {'name'}:
query['is_controlling'] = True
matches, _, _ = await self.search(**query, limit=1)
if matches:
return matches[0]
else:
return LookupError(f'Could not find claim at "{raw_url}".')
return channel
async def _get_referenced_rows(self, txo_rows: List[dict]):
txo_rows = [row for row in txo_rows if isinstance(row, dict)]
repost_hashes = set(filter(None, map(itemgetter('reposted_claim_hash'), txo_rows)))
channel_hashes = set(filter(None, (row['channel_hash'] for row in txo_rows)))
reposted_txos = []
if repost_hashes:
reposted_txos, _, _ = await self.search(**{'claim.claim_hash__in': repost_hashes})
channel_hashes |= set(filter(None, (row['channel_hash'] for row in reposted_txos)))
channel_txos = []
if channel_hashes:
channel_txos, _, _ = await self.search(**{'claim.claim_hash__in': channel_hashes})
# channels must come first for client side inflation to work properly
return channel_txos + reposted_txos
def extract_doc(doc, index):
@ -42,7 +163,7 @@ def extract_doc(doc, index):
if doc['reposted_claim_hash'] is not None:
doc['reposted_claim_id'] = hexlify(doc.pop('reposted_claim_hash')[::-1]).decode()
else:
doc['reposted_claim_hash'] = None
doc['reposted_claim_id'] = None
channel_hash = doc.pop('channel_hash')
doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
txo_hash = doc.pop('txo_hash')
@ -54,9 +175,152 @@ def extract_doc(doc, index):
doc['public_key_bytes'] = hexlify(doc.pop('public_key_bytes') or b'').decode() or None
doc['public_key_hash'] = hexlify(doc.pop('public_key_hash') or b'').decode() or None
doc['signature_valid'] = bool(doc['signature_valid'])
if doc['claim_type'] is None:
doc['claim_type'] = 'invalid'
else:
doc['claim_type'] = CLAIM_TYPE_NAMES[doc['claim_type']]
doc['claim_type'] = doc.get('claim_type', 0) or 0
doc['stream_type'] = int(doc.get('stream_type', 0) or 0)
return {'doc': doc, '_id': doc['claim_id'], '_index': index, '_op_type': 'update',
'doc_as_upsert': True}
FIELDS = ['is_controlling', 'last_take_over_height', 'claim_id', 'claim_name', 'normalized', 'tx_position', 'amount',
'timestamp', 'creation_timestamp', 'height', 'creation_height', 'activation_height', 'expiration_height',
'release_time', 'short_url', 'canonical_url', 'title', 'author', 'description', 'claim_type', 'reposted',
'stream_type', 'media_type', 'fee_amount', 'fee_currency', 'duration', 'reposted_claim_hash',
'claims_in_channel', 'channel_join', 'signature_valid', 'effective_amount', 'support_amount',
'trending_group', 'trending_mixed', 'trending_local', 'trending_global', 'channel_id', 'tx_id', 'tx_nout',
'signature', 'signature_digest', 'public_key_bytes', 'public_key_hash', 'public_key_id', '_id', 'tags']
TEXT_FIELDS = ['author', 'canonical_url', 'channel_id', 'claim_id', 'claim_name', 'description',
'media_type', 'normalized', 'public_key_bytes', 'public_key_hash', 'short_url', 'signature',
'signature_digest', 'stream_type', 'title', 'tx_id', 'fee_currency']
RANGE_FIELDS = ['height', 'fee_amount', 'duration']
REPLACEMENTS = {
'name': 'claim_name',
'txid': 'tx_id',
'claim_hash': '_id',
}
def expand_query(**kwargs):
query = {'must': [], 'must_not': []}
collapse = None
for key, value in kwargs.items():
key = key.replace('claim.', '')
many = key.endswith('__in')
if many:
key = key.replace('__in', '')
key = REPLACEMENTS.get(key, key)
if key in FIELDS:
if key == 'claim_type':
if isinstance(value, str):
value = CLAIM_TYPES[value]
else:
value = [CLAIM_TYPES[claim_type] for claim_type in value]
if key == '_id':
if isinstance(value, Iterable):
value = [hexlify(item[::-1]).decode() for item in value]
else:
value = hexlify(value[::-1]).decode()
if key == 'public_key_id':
key = 'public_key_hash'
value = hexlify(Base58.decode(value)[1:21]).decode()
if key == 'signature_valid':
continue # handled later
if key in TEXT_FIELDS:
key += '.keyword'
ops = {'<=': 'lte', '>=': 'gte', '<': 'lt', '>': 'gt'}
if key in RANGE_FIELDS and isinstance(value, str) and value[0] in ops:
operator_length = 2 if value[:2] in ops else 1
operator, value = value[:operator_length], value[operator_length:]
if key == 'fee_amount':
value = Decimal(value)*1000
query['must'].append({"range": {key: {ops[operator]: value}}})
elif many:
query['must'].append({"terms": {key: value}})
else:
if key == 'fee_amount':
value = Decimal(value)*1000
query['must'].append({"term": {key: {"value": value}}})
elif key == 'not_channel_ids':
for channel_id in value:
query['must_not'].append({"term": {'channel_id.keyword': channel_id}})
query['must_not'].append({"term": {'_id': channel_id}})
elif key == 'channel_ids':
query['must'].append({"terms": {'channel_id.keyword': value}})
elif key == 'media_types':
query['must'].append({"terms": {'media_type.keyword': value}})
elif key == 'stream_types':
query['must'].append({"terms": {'stream_type': [STREAM_TYPES[stype] for stype in value]}})
elif key == 'any_languages':
query['must'].append({"terms": {'languages': clean_tags(value)}})
elif key == 'any_languages':
query['must'].append({"terms": {'languages': value}})
elif key == 'all_languages':
query['must'].extend([{"term": {'languages': tag}} for tag in value])
elif key == 'any_tags':
query['must'].append({"terms": {'tags': clean_tags(value)}})
elif key == 'all_tags':
query['must'].extend([{"term": {'tags': tag}} for tag in clean_tags(value)])
elif key == 'not_tags':
query['must_not'].extend([{"term": {'tags': tag}} for tag in clean_tags(value)])
elif key == 'limit_claims_per_channel':
collapse = ('channel_id.keyword', value)
if kwargs.get('has_channel_signature'):
query['must'].append({"exists": {"field": "signature_digest"}})
if 'signature_valid' in kwargs:
query['must'].append({"term": {"signature_valid": bool(kwargs["signature_valid"])}})
elif 'signature_valid' in kwargs:
query.setdefault('should', [])
query["minimum_should_match"] = 1
query['should'].append({"bool": {"must_not": {"exists": {"field": "signature_digest"}}}})
query['should'].append({"term": {"signature_valid": bool(kwargs["signature_valid"])}})
if 'text' in kwargs:
return {"query":
{"query_string":
{"query": kwargs["text"], "fields": [
"claim_name", "channel_name", "title", "description", "author", "tags"
], "analyzer": "porter"}}}
query = {
'query': {'bool': query},
"sort": [],
}
if "limit" in kwargs:
query["size"] = kwargs["limit"]
if 'offset' in kwargs:
query["from"] = kwargs["offset"]
if 'order_by' in kwargs:
for value in kwargs['order_by']:
is_asc = value.startswith('^')
value = value[1:] if is_asc else value
value = REPLACEMENTS.get(value, value)
if value in TEXT_FIELDS:
value += '.keyword'
query['sort'].append({value: "asc" if is_asc else "desc"})
if collapse:
query["collapse"] = {
"field": collapse[0],
"inner_hits": {
"name": collapse[0],
"size": collapse[1],
"sort": query["sort"]
}
}
return query
def expand_result(results):
inner_hits = []
for result in results:
if result.get("inner_hits"):
for _, inner_hit in result["inner_hits"].items():
inner_hits.extend(inner_hit["hits"]["hits"])
continue
result.update(result.pop('_source'))
result['claim_hash'] = unhexlify(result['claim_id'])[::-1]
if result['reposted_claim_id']:
result['reposted_claim_hash'] = unhexlify(result['reposted_claim_id'])[::-1]
else:
result['reposted_claim_hash'] = None
result['channel_hash'] = unhexlify(result['channel_id'])[::-1] if result['channel_id'] else None
result['txo_hash'] = unhexlify(result['tx_id'])[::-1] + struct.pack('<I', result['tx_nout'])
if inner_hits:
return expand_result(inner_hits)
return results

View file

@ -5,7 +5,7 @@ from itertools import chain
from decimal import Decimal
from collections import namedtuple
from multiprocessing import Manager, Queue
from binascii import unhexlify
from binascii import unhexlify, hexlify
from lbry.wallet.server.leveldb import LevelDB
from lbry.wallet.server.util import class_logger
from lbry.wallet.database import query, constraints_to_sql
@ -19,6 +19,7 @@ from lbry.wallet.server.db.full_text_search import update_full_text_search, CREA
from lbry.wallet.server.db.trending import TRENDING_ALGORITHMS
from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS, INDEXED_LANGUAGES
from .elastic_search import SearchIndex
ATTRIBUTE_ARRAY_MAX_LENGTH = 100
@ -807,9 +808,18 @@ class SQLDB:
def enqueue_changes(self, changed_claim_hashes, deleted_claims):
if not changed_claim_hashes and not deleted_claims:
return
for claim_hash in deleted_claims:
if not self.claim_queue.full():
self.claim_queue.put_nowait(('delete', claim_hash))
tags = {}
langs = {}
for claim_hash, tag in self.execute(
f"select claim_hash, tag from tag "
f"WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})", changed_claim_hashes):
tags.setdefault(claim_hash, [])
tags[claim_hash].append(tag)
for claim_hash, lang in self.execute(
f"select claim_hash, language from language "
f"WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})", changed_claim_hashes):
langs.setdefault(claim_hash, [])
langs[claim_hash].append(lang)
for claim in self.execute(f"""
SELECT claimtrie.claim_hash as is_controlling,
claimtrie.last_take_over_height,
@ -817,8 +827,14 @@ class SQLDB:
FROM claim LEFT JOIN claimtrie USING (claim_hash)
WHERE claim_hash IN ({','.join('?' for _ in changed_claim_hashes)})
""", changed_claim_hashes):
claim = dict(claim._asdict())
claim['tags'] = tags.get(claim['claim_hash'], [])
claim['languages'] = langs.get(claim['claim_hash'], [])
if not self.claim_queue.full():
self.claim_queue.put_nowait(('update', dict(claim._asdict())))
self.claim_queue.put_nowait(('update', claim))
for claim_hash in deleted_claims:
if not self.claim_queue.full():
self.claim_queue.put_nowait(('delete', hexlify(claim_hash[::-1]).decode()))
def advance_txs(self, height, all_txs, header, daemon_height, timer):
insert_claims = []
@ -915,7 +931,7 @@ class SQLDB:
if not self._fts_synced and self.main.first_sync and height == daemon_height:
r(first_sync_finished, self.db.cursor())
self._fts_synced = True
r(self.enqueue_changes, recalculate_claim_hashes, delete_claim_hashes)
r(self.enqueue_changes, recalculate_claim_hashes | affected_channels, delete_claim_hashes)
class LBRYLevelDB(LevelDB):
@ -934,10 +950,14 @@ class LBRYLevelDB(LevelDB):
trending
)
# Search index
self.search_index = SearchIndex(self.env.es_index_prefix)
def close(self):
super().close()
self.sql.close()
async def _open_dbs(self, *args, **kwargs):
await self.search_index.start()
await super()._open_dbs(*args, **kwargs)
self.sql.open()

View file

@ -53,6 +53,7 @@ class Env:
coin_name = self.required('COIN').strip()
network = self.default('NET', 'mainnet').strip()
self.coin = Coin.lookup_coin_class(coin_name, network)
self.es_index_prefix = self.default('ES_INDEX_PREFIX', '')
self.cache_MB = self.integer('CACHE_MB', 1200)
self.reorg_limit = self.integer('REORG_LIMIT', self.coin.REORG_LIMIT)
# Server stuff

View file

@ -5,7 +5,6 @@ from concurrent.futures.thread import ThreadPoolExecutor
import typing
import lbry
from lbry.wallet.server.db.elastic_search import indexer_task
from lbry.wallet.server.mempool import MemPool, MemPoolAPI
from lbry.prometheus import PrometheusServer

View file

@ -1041,9 +1041,7 @@ class LBRYElectrumX(SessionBase):
return cache_item.result
async with cache_item.lock:
if cache_item.result is None:
cache_item.result = await self.run_in_executor(
query_name, function, kwargs
)
cache_item.result = await self.db.search_index.session_query(query_name, function, kwargs)
else:
metrics = self.get_metrics_or_placeholder_for_api(query_name)
metrics.cache_response()

View file

@ -3,6 +3,7 @@ import tempfile
import logging
import asyncio
from binascii import unhexlify
from unittest import skip
from urllib.request import urlopen
from lbry.error import InsufficientFundsError
@ -75,6 +76,7 @@ class ClaimSearchCommand(ClaimTestCase):
(result['txid'], result['claim_id'])
)
@skip("doesnt happen on ES...?")
async def test_disconnect_on_memory_error(self):
claim_ids = [
'0000000000000000000000000000000000000000',