0daf205cb0
-used to set references in Outputs, allows for faster serialization of resolve/claim_search responses -requires resyncing elasticsearch
849 lines
37 KiB
Python
849 lines
37 KiB
Python
import logging
|
|
import time
|
|
import asyncio
|
|
import struct
|
|
from binascii import unhexlify
|
|
from collections import Counter, deque
|
|
from decimal import Decimal
|
|
from operator import itemgetter
|
|
from typing import Optional, List, Iterable
|
|
|
|
from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
|
|
from elasticsearch.helpers import async_streaming_bulk
|
|
from scribe.schema.result import Outputs, Censor
|
|
from scribe.schema.tags import clean_tags
|
|
from scribe.schema.url import normalize_name
|
|
from scribe.error import TooManyClaimSearchParametersError
|
|
from scribe.common import LRUCache
|
|
from scribe.db.common import CLAIM_TYPES, STREAM_TYPES
|
|
from scribe.elasticsearch.constants import INDEX_DEFAULT_SETTINGS, REPLACEMENTS, FIELDS, TEXT_FIELDS, \
|
|
RANGE_FIELDS, ALL_FIELDS
|
|
from scribe.db.common import ResolveResult
|
|
|
|
|
|
def expand_query(**kwargs):
    """Translate claim-search keyword constraints into an Elasticsearch request body.

    NOTE(review): this module defines ``expand_query`` a second time further
    down; at import time the later definition replaces this one.

    Keys ending in ``__in`` or carrying a list value are treated as
    multi-valued. Keys are mapped through ``REPLACEMENTS`` before being
    matched against ``FIELDS`` or the special-cased constraint names.

    Returns:
        dict with ``_source``, ``query`` (a ``bool`` query) and ``sort``, plus
        ``size``, ``from`` and ``collapse`` when the matching constraints
        (``limit``, ``offset``, ``limit_claims_per_channel``) were given.

    Raises:
        TooManyClaimSearchParametersError: a multi-valued constraint has more
            than 2048 values.
    """
    if "amount_order" in kwargs:
        kwargs["limit"] = 1
        kwargs["order_by"] = "effective_amount"
        kwargs["offset"] = int(kwargs["amount_order"]) - 1
    if 'name' in kwargs:
        kwargs['name'] = normalize_name(kwargs.pop('name'))
    if kwargs.get('is_controlling') is False:
        kwargs.pop('is_controlling')
    if kwargs.get('fee_currency') is not None:
        kwargs['fee_currency'] = kwargs['fee_currency'].upper()

    # comparison prefixes accepted on range fields -> Elasticsearch range operators
    ops = {'<=': 'lte', '>=': 'gte', '<': 'lt', '>': 'gt'}
    query = {'must': [], 'must_not': []}
    collapse = None
    for key, value in kwargs.items():
        key = key.replace('claim.', '')
        many = key.endswith('__in') or isinstance(value, list)
        if many and len(value) > 2048:
            raise TooManyClaimSearchParametersError(key, 2048)
        if many:
            key = key.replace('__in', '')
            value = list(filter(None, value))
        if value is None or isinstance(value, list) and len(value) == 0:
            continue
        key = REPLACEMENTS.get(key, key)
        if key in FIELDS:
            partial_id = False
            if key == 'claim_type':
                if isinstance(value, str):
                    value = CLAIM_TYPES[value]
                else:
                    value = [CLAIM_TYPES[claim_type] for claim_type in value]
            elif key == 'stream_type':
                value = [STREAM_TYPES[value]] if isinstance(value, str) else list(map(STREAM_TYPES.get, value))
            if key == '_id':
                # bytes is itself iterable (yielding ints), so a single hash
                # must be recognised before falling back to per-item conversion.
                if isinstance(value, (bytes, bytearray)):
                    value = value[::-1].hex()
                else:
                    value = [item[::-1].hex() for item in value]
            if not many and key in ('_id', 'claim_id') and len(value) < 20:
                partial_id = True
            if key in ('signature_valid', 'has_source'):
                continue  # handled later
            if key in TEXT_FIELDS:
                key += '.keyword'
            if partial_id:
                query['must'].append({"prefix": {"claim_id": value}})
            elif key in RANGE_FIELDS and isinstance(value, str) and value[:1] in ops:
                operator_length = 2 if value[:2] in ops else 1
                operator, value = value[:operator_length], value[operator_length:]
                if key == 'fee_amount':
                    value = str(Decimal(value)*1000)
                query['must'].append({"range": {key: {ops[operator]: value}}})
            elif many or isinstance(value, list):
                # a single stream_type string was expanded to a list above;
                # list values need a "terms" query, "term" takes one value.
                query['must'].append({"terms": {key: value}})
            else:
                if key == 'fee_amount':
                    value = str(Decimal(value)*1000)
                query['must'].append({"term": {key: {"value": value}}})
        elif key == 'not_channel_ids':
            for channel_id in value:
                query['must_not'].append({"term": {'channel_id.keyword': channel_id}})
                query['must_not'].append({"term": {'_id': channel_id}})
        elif key == 'channel_ids':
            query['must'].append({"terms": {'channel_id.keyword': value}})
        elif key == 'claim_ids':
            query['must'].append({"terms": {'claim_id.keyword': value}})
        elif key == 'media_types':
            query['must'].append({"terms": {'media_type.keyword': value}})
        elif key == 'any_languages':
            query['must'].append({"terms": {'languages': clean_tags(value)}})
        elif key == 'all_languages':
            query['must'].extend([{"term": {'languages': tag}} for tag in value])
        elif key == 'any_tags':
            query['must'].append({"terms": {'tags.keyword': clean_tags(value)}})
        elif key == 'all_tags':
            query['must'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
        elif key == 'not_tags':
            query['must_not'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
        elif key == 'not_claim_id':
            query['must_not'].extend([{"term": {'claim_id.keyword': cid}} for cid in value])
        elif key == 'limit_claims_per_channel':
            collapse = ('channel_id.keyword', value)

    if kwargs.get('has_channel_signature'):
        query['must'].append({"exists": {"field": "signature"}})
        if 'signature_valid' in kwargs:
            query['must'].append({"term": {"is_signature_valid": bool(kwargs["signature_valid"])}})
    elif 'signature_valid' in kwargs:
        # without has_channel_signature, unsigned claims match as well
        query.setdefault('should', [])
        query["minimum_should_match"] = 1
        query['should'].append({"bool": {"must_not": {"exists": {"field": "signature"}}}})
        query['should'].append({"term": {"is_signature_valid": bool(kwargs["signature_valid"])}})
    if 'has_source' in kwargs:
        query.setdefault('should', [])
        query["minimum_should_match"] = 1
        is_stream_or_repost = {"terms": {"claim_type": [CLAIM_TYPES['stream'], CLAIM_TYPES['repost']]}}
        query['should'].append(
            {"bool": {"must": [{"match": {"has_source": kwargs['has_source']}}, is_stream_or_repost]}})
        query['should'].append({"bool": {"must_not": [is_stream_or_repost]}})
        query['should'].append({"bool": {"must": [{"term": {"reposted_claim_type": CLAIM_TYPES['channel']}}]}})
    if kwargs.get('text'):
        query['must'].append(
            {"simple_query_string":
                 {"query": kwargs["text"], "fields": [
                     "claim_name^4", "channel_name^8", "title^1", "description^.5", "author^1", "tags^.5"
                 ]}})

    query = {
        "_source": {"excludes": ["description", "title"]},
        'query': {'bool': query},
        "sort": [],
    }
    if "limit" in kwargs:
        query["size"] = kwargs["limit"]
    if 'offset' in kwargs:
        query["from"] = kwargs["offset"]
    if 'order_by' in kwargs:
        if isinstance(kwargs["order_by"], str):
            kwargs["order_by"] = [kwargs["order_by"]]
        for value in kwargs['order_by']:
            if 'trending_group' in value:
                # fixme: trending_mixed is 0 for all records on variable decay, making sort slow.
                continue
            # a leading '^' requests ascending order; the default is descending
            is_asc = value.startswith('^')
            value = value[1:] if is_asc else value
            value = REPLACEMENTS.get(value, value)
            if value in TEXT_FIELDS:
                value += '.keyword'
            query['sort'].append({value: "asc" if is_asc else "desc"})
    if collapse:
        query["collapse"] = {
            "field": collapse[0],
            "inner_hits": {
                "name": collapse[0],
                "size": collapse[1],
                "sort": query["sort"]
            }
        }
    return query
|
|
|
|
|
|
|
|
class ChannelResolution(str):
    """String subtype that can build the error for a channel lookup failure."""

    @classmethod
    def lookup_error(cls, url):
        """Return (not raise) a ``LookupError`` naming the given ``url``."""
        message = f'Could not find channel in "{url}".'
        return LookupError(message)
|
|
|
|
|
|
class StreamResolution(str):
    """String subtype that can build the error for a claim lookup failure."""

    @classmethod
    def lookup_error(cls, url):
        """Return (not raise) a ``LookupError`` naming the given ``url``."""
        message = f'Could not find claim at "{url}".'
        return LookupError(message)
|
|
|
|
|
|
class IndexVersionMismatch(Exception):
    """Signals that the search index version differs from the expected one.

    Attributes are ``got_version`` (the version that was found) and
    ``expected_version`` (the version the caller requires).
    """

    def __init__(self, got_version, expected_version):
        # Pass both values to the base class so ``args`` stays
        # ``(got_version, expected_version)`` and pickling keeps working.
        super().__init__(got_version, expected_version)
        self.got_version = got_version
        self.expected_version = expected_version

    def __str__(self):
        # Give tracebacks and logs a readable message instead of a bare tuple.
        return f"index version mismatch: got {self.got_version}, expected {self.expected_version}"
|
|
|
|
|
|
class SearchIndex:
    """Elasticsearch-backed claim index with cached search.

    Two clients are created by ``start()``: ``sync_client`` (built with
    ``sync_timeout``) is used for indexing and maintenance calls, and
    ``search_client`` (built with ``search_timeout``) for searches and
    document lookups.
    """
    VERSION = 1

    def __init__(self, index_prefix: str, search_timeout=3.0, elastic_host='localhost', elastic_port=9200):
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # wont hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = logging.getLogger(__name__)
        self.claim_cache = LRUCache(2 ** 15)
        self.search_cache = LRUCache(2 ** 17)
        self._elastic_host = elastic_host
        self._elastic_port = elastic_port

    async def get_index_version(self) -> int:
        """Return the version stored in the template named after the index, or 0 if there is none."""
        try:
            template = await self.sync_client.indices.get_template(self.index)
            return template[self.index]['version']
        except NotFoundError:
            return 0

    async def set_index_version(self, version):
        """Store ``version`` in a template named after the index."""
        await self.sync_client.indices.put_template(
            self.index, body={'version': version, 'index_patterns': ['ignored']}, ignore=400
        )

    async def start(self) -> bool:
        """Connect to Elasticsearch and make sure the index exists at the expected version.

        Returns:
            False when already started; otherwise the ``acknowledged`` flag of
            the index-creation response.

        Raises:
            IndexVersionMismatch: creation was not acknowledged and the stored
                version differs from ``VERSION``.
        """
        if self.sync_client:
            return False
        hosts = [{'host': self._elastic_host, 'port': self._elastic_port}]
        self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(hosts, timeout=self.search_timeout)
        # retry once a second until the cluster answers the health check
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)

        # ignore=400: presumably the response for an index that already exists -- TODO confirm
        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
        acked = res.get('acknowledged', False)
        if acked:
            await self.set_index_version(self.VERSION)
            return acked
        index_version = await self.get_index_version()
        if index_version != self.VERSION:
            self.logger.error("es search index has an incompatible version: %s vs %s", index_version, self.VERSION)
            raise IndexVersionMismatch(index_version, self.VERSION)
        await self.sync_client.indices.refresh(self.index)
        return acked

    async def stop(self):
        """Close whichever clients exist and reset both attributes to None."""
        clients = [c for c in (self.sync_client, self.search_client) if c is not None]
        self.sync_client, self.search_client = None, None
        if clients:
            await asyncio.gather(*(client.close() for client in clients))

    def delete_index(self):
        """Return the (un-awaited) result of the client's index delete call."""
        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        """Turn ``(op, doc)`` pairs into bulk actions: deletes, or upserting updates."""
        count = 0
        async for op, doc in claim_producer:
            if op == 'delete':
                yield {
                    '_index': self.index,
                    '_op_type': 'delete',
                    '_id': doc
                }
            else:
                yield {
                    # only fields listed in ALL_FIELDS are sent to the index
                    'doc': {key: value for key, value in doc.items() if key in ALL_FIELDS},
                    '_id': doc['claim_id'],
                    '_index': self.index,
                    '_op_type': 'update',
                    'doc_as_upsert': True
                }
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        if count:
            self.logger.info("Indexing done for %d claims.", count)
        else:
            self.logger.debug("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        """Bulk-apply the operations from ``claim_producer``, then refresh the index.

        Failed items are logged as warnings and do not abort the run.
        """
        actions = self._consume_claim_producer(claim_producer)
        async for ok, item in async_streaming_bulk(self.sync_client, actions, raise_on_error=False):
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
        await self.sync_client.indices.refresh(self.index)
        self.logger.debug("Indexing done.")

    def update_filter_query(self, censor_type, blockdict, channels=False):
        """Build an update-by-query body that marks blocked documents as censored.

        Args:
            censor_type: value written to ``censor_type``; only documents with
                a lower ``censor_type`` are selected.
            blockdict: mapping of blocked hash -> blocking channel hash; both
                sides are hex-encoded via ``.hex()``.
            channels: match on ``channel_id`` instead of ``claim_id``.
        """
        blockdict = {blocked.hex(): blocker.hex() for blocked, blocker in blockdict.items()}
        key = 'channel_id' if channels else 'claim_id'
        update = expand_query(**{f"{key}__in": list(blockdict.keys()), "censor_type": f"<{censor_type}"})
        update['script'] = {
            "source": f"ctx._source.censor_type={censor_type}; "
                      f"ctx._source.censoring_channel_id=params[ctx._source.{key}];",
            "lang": "painless",
            "params": blockdict
        }
        return update

    async def update_trending_score(self, params):
        """Apply amount changes to each claim's ``trending_score`` with a painless script.

        Args:
            params: mapping of claim id -> iterable of change records exposing
                ``height``, ``prev_amount`` and ``new_amount``; amounts are
                divided by 1E8 before being handed to the script.
        """
        if not params:
            return
        update_trending_score_script = """
        double softenLBC(double lbc) { return (Math.pow(lbc, 1.0 / 3.0)); }

        double logsumexp(double x, double y)
        {
            double top;
            if(x > y)
                top = x;
            else
                top = y;
            double result = top + Math.log(Math.exp(x-top) + Math.exp(y-top));
            return(result);
        }

        double logdiffexp(double big, double small)
        {
            return big + Math.log(1.0 - Math.exp(small - big));
        }

        double squash(double x)
        {
            if(x < 0.0)
                return -Math.log(1.0 - x);
            else
                return Math.log(x + 1.0);
        }

        double unsquash(double x)
        {
            if(x < 0.0)
                return 1.0 - Math.exp(-x);
            else
                return Math.exp(x) - 1.0;
        }

        double log_to_squash(double x)
        {
            return logsumexp(x, 0.0);
        }

        double squash_to_log(double x)
        {
            //assert x > 0.0;
            return logdiffexp(x, 0.0);
        }

        double squashed_add(double x, double y)
        {
            // squash(unsquash(x) + unsquash(y)) but avoiding overflow.
            // Cases where the signs are the same
            if (x < 0.0 && y < 0.0)
                return -logsumexp(-x, logdiffexp(-y, 0.0));
            if (x >= 0.0 && y >= 0.0)
                return logsumexp(x, logdiffexp(y, 0.0));
            // Where the signs differ
            if (x >= 0.0 && y < 0.0)
                if (Math.abs(x) >= Math.abs(y))
                    return logsumexp(0.0, logdiffexp(x, -y));
                else
                    return -logsumexp(0.0, logdiffexp(-y, x));
            if (x < 0.0 && y >= 0.0)
            {
                // Addition is commutative, hooray for new math
                return squashed_add(y, x);
            }
            return 0.0;
        }

        double squashed_multiply(double x, double y)
        {
            // squash(unsquash(x)*unsquash(y)) but avoiding overflow.
            int sign;
            if(x*y >= 0.0)
                sign = 1;
            else
                sign = -1;
            return sign*logsumexp(squash_to_log(Math.abs(x))
                            + squash_to_log(Math.abs(y)), 0.0);
        }

        // Squashed inflated units
        double inflateUnits(int height) {
            double timescale = 576.0; // Half life of 400 = e-folding time of a day
                                      // by coincidence, so may as well go with it
            return log_to_squash(height / timescale);
        }

        double spikePower(double newAmount) {
            if (newAmount < 50.0) {
                return(0.5);
            } else if (newAmount < 85.0) {
                return(newAmount / 100.0);
            } else {
                return(0.85);
            }
        }

        double spikeMass(double oldAmount, double newAmount) {
            double softenedChange = softenLBC(Math.abs(newAmount - oldAmount));
            double changeInSoftened = Math.abs(softenLBC(newAmount) - softenLBC(oldAmount));
            double power = spikePower(newAmount);
            if (oldAmount > newAmount) {
                -1.0 * Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power)
            } else {
                Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power)
            }
        }
        for (i in params.src.changes) {
            double units = inflateUnits(i.height);
            if (ctx._source.trending_score == null) {
                ctx._source.trending_score = 0.0;
            }
            double bigSpike = squashed_multiply(units, squash(spikeMass(i.prev_amount, i.new_amount)));
            ctx._source.trending_score = squashed_add(ctx._source.trending_score, bigSpike);
        }
        """
        start = time.perf_counter()

        def producer():
            # one scripted update action per claim, carrying all of its changes
            for claim_id, claim_updates in params.items():
                yield {
                    '_id': claim_id,
                    '_index': self.index,
                    '_op_type': 'update',
                    'script': {
                        'lang': 'painless',
                        'source': update_trending_score_script,
                        'params': {'src': {
                            'changes': [
                                {
                                    'height': p.height,
                                    'prev_amount': p.prev_amount / 1E8,
                                    'new_amount': p.new_amount / 1E8,
                                } for p in claim_updates
                            ]
                        }}
                    },
                }

        async for ok, item in async_streaming_bulk(self.sync_client, producer(), raise_on_error=False):
            if not ok:
                self.logger.warning("updating trending failed for an item: %s", item)
        await self.sync_client.indices.refresh(self.index)
        self.logger.info("updated trending scores in %ims", int((time.perf_counter() - start) * 1000))

    async def _update_by_filter(self, censor_type, blockdict, channels=False):
        """Run one censoring update-by-query and refresh the index afterwards."""
        await self.sync_client.update_by_query(
            self.index, body=self.update_filter_query(censor_type, blockdict, channels), slices=4)
        await self.sync_client.indices.refresh(self.index)

    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
        """Apply filter (``Censor.SEARCH``) and block (``Censor.RESOLVE``) lists, then clear the caches.

        Channel dictionaries are applied twice: first matching the channel's
        own claim id, then matching every claim whose ``channel_id`` is listed.
        """
        if filtered_streams:
            await self._update_by_filter(Censor.SEARCH, filtered_streams)
        if filtered_channels:
            await self._update_by_filter(Censor.SEARCH, filtered_channels)
            await self._update_by_filter(Censor.SEARCH, filtered_channels, True)
        if blocked_streams:
            await self._update_by_filter(Censor.RESOLVE, blocked_streams)
        if blocked_channels:
            await self._update_by_filter(Censor.RESOLVE, blocked_channels)
            await self._update_by_filter(Censor.RESOLVE, blocked_channels, True)
        self.clear_caches()

    def clear_caches(self):
        """Empty both the search cache and the claim cache."""
        self.search_cache.clear()
        self.claim_cache.clear()

    def _make_resolve_result(self, es_result):
        """Build a ``ResolveResult`` from an expanded search document."""
        return ResolveResult(
            name=es_result['claim_name'],
            normalized_name=es_result['normalized_name'],
            claim_hash=es_result['claim_hash'],
            tx_num=es_result['tx_num'],
            position=es_result['tx_nout'],
            tx_hash=es_result['tx_hash'],
            height=es_result['height'],
            amount=es_result['amount'],
            short_url=es_result['short_url'],
            is_controlling=es_result['is_controlling'],
            canonical_url=es_result['canonical_url'],
            creation_height=es_result['creation_height'],
            activation_height=es_result['activation_height'],
            expiration_height=es_result['expiration_height'],
            effective_amount=es_result['effective_amount'],
            support_amount=es_result['support_amount'],
            last_takeover_height=es_result['last_take_over_height'],
            claims_in_channel=es_result['claims_in_channel'],
            channel_hash=es_result['channel_hash'],
            reposted_claim_hash=es_result['reposted_claim_hash'],
            reposted=es_result['reposted'],
            signature_valid=es_result['signature_valid'],
            # tx ids are hex strings; the hash is the byte-reversed form, None when absent
            reposted_tx_hash=bytes.fromhex(es_result['reposted_tx_id'] or '')[::-1] or None,
            reposted_tx_position=es_result['reposted_tx_position'],
            reposted_height=es_result['reposted_height'],
            channel_tx_hash=bytes.fromhex(es_result['channel_tx_id'] or '')[::-1] or None,
            channel_tx_position=es_result['channel_tx_position'],
            channel_height=es_result['channel_height'],
        )

    async def cached_search(self, kwargs):
        """Run a search and return its ``Outputs.to_base64`` encoding, cached by ``str(kwargs)``."""
        cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
        if cache_item.result is not None:
            return cache_item.result
        async with cache_item.lock:
            # another task may have filled the item while we waited for the lock
            if cache_item.result is not None:
                return cache_item.result
            total_referenced = []
            censor = Censor(Censor.SEARCH)
            response, offset, total = await self.search(**kwargs)
            censor.apply(response)
            total_referenced.extend(response)

            if censor.censored:
                response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                total_referenced.extend(response)
            response = [self._make_resolve_result(r) for r in response]
            extra = [self._make_resolve_result(r) for r in await self._get_referenced_rows(total_referenced)]
            result = Outputs.to_base64(
                response, extra, offset, total, censor
            )
            cache_item.result = result
            return result

    async def get_many(self, *claim_ids):
        """Return a lazy iterator over the cached documents for ``claim_ids``, skipping unknown ids."""
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        """Fetch documents not yet in the claim cache with one ``mget`` and store them."""
        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
        if missing:
            results = await self.search_client.mget(
                index=self.index, body={"ids": missing}
            )
            for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def search(self, **kwargs):
        """Return ``(documents, offset, total)``; a ``NotFoundError`` yields ``([], 0, 0)``."""
        try:
            return await self.search_ahead(**kwargs)
        except NotFoundError:
            return [], 0, 0

    async def search_ahead(self, **kwargs):
        """Fetch up to 1000 hit ids, reorder them, and return the requested page.

        ``limit`` (default 10), ``offset``, ``remove_duplicates`` and
        ``limit_claims_per_channel`` are taken out of ``kwargs``; everything
        else goes to ``expand_query``. The reordered id list is cached, and
        only the requested slice is inflated through ``get_many``.

        Returns:
            ``(documents, 0, number_of_reordered_hits)``.
        """
        # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
        per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
        remove_duplicates = kwargs.pop('remove_duplicates', False)
        page_size = kwargs.pop('limit', 10)
        offset = kwargs.pop('offset', 0)
        kwargs['limit'] = 1000
        cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
        reordered_hits = cache_item.result
        if reordered_hits is None:
            async with cache_item.lock:
                # Compare against None: an empty hit list is a valid cached
                # result and must not trigger another query.
                reordered_hits = cache_item.result
                if reordered_hits is None:
                    query = expand_query(**kwargs)
                    search_hits = deque((await self.search_client.search(
                        query, index=self.index, track_total_hits=False,
                        _source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height']
                    ))['hits']['hits'])
                    if remove_duplicates:
                        search_hits = self.__remove_duplicates(search_hits)
                    if per_channel_per_page > 0:
                        reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page)
                    else:
                        reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
                    cache_item.result = reordered_hits
        result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
        return result, 0, len(reordered_hits)

    def __remove_duplicates(self, search_hits: deque) -> deque:
        """Keep one hit per underlying claim, preferring the lowest ``creation_height``.

        A repost and its original share the same identity here (the reposted
        claim id, or the hit's own id when it is not a repost).
        """
        known_ids = {}  # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
        dropped = set()
        for hit in search_hits:
            hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id']
            if hit_id not in known_ids:
                known_ids[hit_id] = (hit_height, hit['_id'])
            else:
                previous_height, previous_id = known_ids[hit_id]
                if hit_height < previous_height:
                    known_ids[hit_id] = (hit_height, hit['_id'])
                    dropped.add(previous_id)
                else:
                    dropped.add(hit['_id'])
        return deque(hit for hit in search_hits if hit['_id'] not in dropped)

    def __search_ahead(self, search_hits: deque, page_size: int, per_channel_per_page: int):
        """Reorder hits so each page holds at most ``per_channel_per_page`` hits per channel.

        Hits over a channel's quota are deferred and retried at the start of
        the next page. ``search_hits`` is consumed. Returns a list of
        ``(claim_id, channel_id)`` tuples.
        """
        reordered_hits = []
        channel_counters = Counter()
        next_page_hits_maybe_check_later = deque()
        while search_hits or next_page_hits_maybe_check_later:
            if reordered_hits and len(reordered_hits) % page_size == 0:
                channel_counters.clear()  # a page was completed: quotas start over
            elif not reordered_hits:
                pass
            else:
                break  # means last page was incomplete and we are left with bad replacements
            # give every deferred hit exactly one chance on the new page
            for _ in range(len(next_page_hits_maybe_check_later)):
                claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
                if per_channel_per_page > 0 and channel_counters[channel_id] < per_channel_per_page:
                    reordered_hits.append((claim_id, channel_id))
                    channel_counters[channel_id] += 1
                else:
                    next_page_hits_maybe_check_later.append((claim_id, channel_id))
            while search_hits:
                hit = search_hits.popleft()
                hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
                if hit_channel_id is None or per_channel_per_page <= 0:
                    reordered_hits.append((hit_id, hit_channel_id))
                elif channel_counters[hit_channel_id] < per_channel_per_page:
                    reordered_hits.append((hit_id, hit_channel_id))
                    channel_counters[hit_channel_id] += 1
                    if len(reordered_hits) % page_size == 0:
                        break
                else:
                    next_page_hits_maybe_check_later.append((hit_id, hit_channel_id))
        return reordered_hits

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        """Fetch the documents referenced by ``txo_rows``.

        References are the reposted claim, the channel and the censoring
        channel of each dict row, followed by the channels of those fetched
        documents. Rows that are not dicts are ignored.
        """
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(filter(None, (row['censoring_channel_id'] for row in txo_rows)))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            # second hop: channels of the documents that were just fetched
            referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos
|
|
|
|
|
|
def expand_query(**kwargs):
    """Translate claim-search keyword constraints into an Elasticsearch request body.

    Keys ending in ``__in`` or carrying a list value are treated as
    multi-valued. Keys are mapped through ``REPLACEMENTS`` before being
    matched against ``FIELDS`` or the special-cased constraint names.
    Range fields accept a comparison prefix (``<``, ``<=``, ``>``, ``>=``),
    either as one string or as a list of such strings combined into a single
    range clause.

    Returns:
        dict with ``_source``, ``query`` (a ``bool`` query) and ``sort``, plus
        ``size``, ``from`` and ``collapse`` when the matching constraints
        (``limit``, ``offset``, ``limit_claims_per_channel``) were given.

    Raises:
        TooManyClaimSearchParametersError: a multi-valued constraint has more
            than 2048 values.
    """
    if "amount_order" in kwargs:
        kwargs["limit"] = 1
        kwargs["order_by"] = "effective_amount"
        kwargs["offset"] = int(kwargs["amount_order"]) - 1
    if 'name' in kwargs:
        kwargs['name'] = normalize_name(kwargs.pop('name'))
    if kwargs.get('is_controlling') is False:
        kwargs.pop('is_controlling')
    if kwargs.get('fee_currency') is not None:
        kwargs['fee_currency'] = kwargs['fee_currency'].upper()

    # comparison prefixes accepted on range fields -> Elasticsearch range operators
    ops = {'<=': 'lte', '>=': 'gte', '<': 'lt', '>': 'gt'}
    query = {'must': [], 'must_not': []}
    collapse = None
    for key, value in kwargs.items():
        key = key.replace('claim.', '')
        many = key.endswith('__in') or isinstance(value, list)
        if many and len(value) > 2048:
            raise TooManyClaimSearchParametersError(key, 2048)
        if many:
            key = key.replace('__in', '')
            value = list(filter(None, value))
        if value is None or isinstance(value, list) and len(value) == 0:
            continue
        key = REPLACEMENTS.get(key, key)
        if key in FIELDS:
            partial_id = False
            if key == 'claim_type':
                if isinstance(value, str):
                    value = CLAIM_TYPES[value]
                else:
                    value = [CLAIM_TYPES[claim_type] for claim_type in value]
            elif key == 'stream_type':
                value = [STREAM_TYPES[value]] if isinstance(value, str) else list(map(STREAM_TYPES.get, value))
            if key == '_id':
                # bytes is itself iterable (yielding ints), so a single hash
                # must be recognised before falling back to per-item conversion.
                if isinstance(value, (bytes, bytearray)):
                    value = value[::-1].hex()
                else:
                    value = [item[::-1].hex() for item in value]
            if not many and key in ('_id', 'claim_id', 'sd_hash') and len(value) < 20:
                partial_id = True
            if key in ('signature_valid', 'has_source'):
                continue  # handled later
            if key in TEXT_FIELDS:
                key += '.keyword'
            if partial_id:
                query['must'].append({"prefix": {key: value}})
            elif key in RANGE_FIELDS and isinstance(value, str) and value[:1] in ops:
                operator_length = 2 if value[:2] in ops else 1
                operator, value = value[:operator_length], value[operator_length:]
                if key == 'fee_amount':
                    value = str(Decimal(value)*1000)
                query['must'].append({"range": {key: {ops[operator]: value}}})
            elif key in RANGE_FIELDS and isinstance(value, list) and all(
                    isinstance(v, str) and v[:1] in ops for v in value):
                # several comparisons on one field are merged into one range
                # clause; a repeated operator keeps its last operand.
                range_constraints = {}
                for v in value:
                    operator_length = 2 if v[:2] in ops else 1
                    operator, operand = v[:operator_length], v[operator_length:]
                    if key == 'fee_amount':
                        operand = str(Decimal(operand)*1000)
                    range_constraints[ops[operator]] = operand
                query['must'].append({"range": {key: range_constraints}})
            elif many or isinstance(value, list):
                # a single stream_type string was expanded to a list above;
                # list values need a "terms" query, "term" takes one value.
                query['must'].append({"terms": {key: value}})
            else:
                if key == 'fee_amount':
                    value = str(Decimal(value)*1000)
                query['must'].append({"term": {key: {"value": value}}})
        elif key == 'not_channel_ids':
            for channel_id in value:
                query['must_not'].append({"term": {'channel_id.keyword': channel_id}})
                query['must_not'].append({"term": {'_id': channel_id}})
        elif key == 'channel_ids':
            query['must'].append({"terms": {'channel_id.keyword': value}})
        elif key == 'claim_ids':
            query['must'].append({"terms": {'claim_id.keyword': value}})
        elif key == 'media_types':
            query['must'].append({"terms": {'media_type.keyword': value}})
        elif key == 'any_languages':
            query['must'].append({"terms": {'languages': clean_tags(value)}})
        elif key == 'all_languages':
            query['must'].extend([{"term": {'languages': tag}} for tag in value])
        elif key == 'any_tags':
            query['must'].append({"terms": {'tags.keyword': clean_tags(value)}})
        elif key == 'all_tags':
            query['must'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
        elif key == 'not_tags':
            query['must_not'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
        elif key == 'not_claim_id':
            query['must_not'].extend([{"term": {'claim_id.keyword': cid}} for cid in value])
        elif key == 'limit_claims_per_channel':
            collapse = ('channel_id.keyword', value)

    if kwargs.get('has_channel_signature'):
        query['must'].append({"exists": {"field": "signature"}})
        if 'signature_valid' in kwargs:
            query['must'].append({"term": {"is_signature_valid": bool(kwargs["signature_valid"])}})
    elif 'signature_valid' in kwargs:
        # without has_channel_signature, unsigned claims match as well
        query.setdefault('should', [])
        query["minimum_should_match"] = 1
        query['should'].append({"bool": {"must_not": {"exists": {"field": "signature"}}}})
        query['should'].append({"term": {"is_signature_valid": bool(kwargs["signature_valid"])}})
    if 'has_source' in kwargs:
        query.setdefault('should', [])
        query["minimum_should_match"] = 1
        is_stream_or_repost = {"terms": {"claim_type": [CLAIM_TYPES['stream'], CLAIM_TYPES['repost']]}}
        query['should'].append(
            {"bool": {"must": [{"match": {"has_source": kwargs['has_source']}}, is_stream_or_repost]}})
        query['should'].append({"bool": {"must_not": [is_stream_or_repost]}})
        query['should'].append({"bool": {"must": [{"term": {"reposted_claim_type": CLAIM_TYPES['channel']}}]}})
    if kwargs.get('text'):
        query['must'].append(
            {"simple_query_string":
                 {"query": kwargs["text"], "fields": [
                     "claim_name^4", "channel_name^8", "title^1", "description^.5", "author^1", "tags^.5"
                 ]}})

    query = {
        "_source": {"excludes": ["description", "title"]},
        'query': {'bool': query},
        "sort": [],
    }
    if "limit" in kwargs:
        query["size"] = kwargs["limit"]
    if 'offset' in kwargs:
        query["from"] = kwargs["offset"]
    if 'order_by' in kwargs:
        if isinstance(kwargs["order_by"], str):
            kwargs["order_by"] = [kwargs["order_by"]]
        for value in kwargs['order_by']:
            if 'trending_group' in value:
                # fixme: trending_mixed is 0 for all records on variable decay, making sort slow.
                continue
            # a leading '^' requests ascending order; the default is descending
            is_asc = value.startswith('^')
            value = value[1:] if is_asc else value
            value = REPLACEMENTS.get(value, value)
            if value in TEXT_FIELDS:
                value += '.keyword'
            query['sort'].append({value: "asc" if is_asc else "desc"})
    if collapse:
        query["collapse"] = {
            "field": collapse[0],
            "inner_hits": {
                "name": collapse[0],
                "size": collapse[1],
                "sort": query["sort"]
            }
        }
    return query
|
|
|
|
|
|
def expand_result(results):
    """Add binary hash fields to raw hit documents and return their ``_source`` dicts.

    Each ``_source`` dict is modified in place: hex ids are turned into
    byte-reversed hashes, ``repost_count`` is renamed to ``reposted`` and
    ``is_signature_valid`` to ``signature_valid``.

    When any hit carries ``inner_hits``, the inner hits of all such hits are
    expanded and returned instead, and plain hits from this call are dropped.
    """
    collapsed = []
    documents = []
    for hit in results:
        nested = hit.get("inner_hits")
        if nested:
            for group in nested.values():
                collapsed.extend(group["hits"]["hits"])
            continue
        doc = hit['_source']
        doc['claim_hash'] = unhexlify(doc['claim_id'])[::-1]
        reposted_id = doc['reposted_claim_id']
        doc['reposted_claim_hash'] = unhexlify(reposted_id)[::-1] if reposted_id else None
        channel_id = doc['channel_id']
        doc['channel_hash'] = unhexlify(channel_id)[::-1] if channel_id else None
        tx_hash = unhexlify(doc['tx_id'])[::-1]
        # txo hash = tx hash followed by the output index as little-endian uint32
        doc['txo_hash'] = tx_hash + struct.pack('<I', doc['tx_nout'])
        doc['tx_hash'] = tx_hash
        doc['reposted'] = doc.pop('repost_count')
        doc['signature_valid'] = doc.pop('is_signature_valid')
        documents.append(doc)
    return expand_result(collapsed) if collapsed else documents
|
|
|
|
|
|
class ResultCacheItem:
    """Cache slot holding one result, a lock for computing it, and a "ready" event.

    ``result`` starts as None; assigning a non-None value sets ``has_result``.
    ``lock`` is available to callers that want a single task to compute the
    value while others wait.
    """
    __slots__ = '_result', 'lock', 'has_result'

    def __init__(self):
        self.has_result = asyncio.Event()
        self.lock = asyncio.Lock()
        self._result = None

    @property
    def result(self):
        """The stored value, or None when nothing has been stored yet."""
        return self._result

    @result.setter
    def result(self, result):
        self._result = result
        if result is not None:
            self.has_result.set()

    @classmethod
    def from_cache(cls, cache_key, cache):
        """Return the item stored under ``cache_key``, creating and storing one if absent.

        ``cache`` must support ``.get(key)`` and item assignment.
        """
        cache_item = cache.get(cache_key)
        if cache_item is None:
            # use cls() so subclasses get instances of their own type
            cache_item = cache[cache_key] = cls()
        return cache_item
|