fixes from review

Victor Shyba 2021-03-24 05:35:31 -03:00
parent d47cf40544
commit 7df4cc44c4
3 changed files with 83 additions and 71 deletions

View file

@@ -8,7 +8,7 @@ services:
   wallet_server:
     depends_on:
       - es01
-    image: lbry/wallet-server:${WALLET_SERVER_TAG:-development}
+    image: lbry/wallet-server:${WALLET_SERVER_TAG:-latest-release}
     restart: always
     network_mode: host
    ports:
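
Compose's ${VAR:-default} interpolation falls back to the default when the variable is unset or empty, so the image line above now pins latest-release unless a tag is exported. A minimal Python sketch of that resolution rule (variable and default taken from the file above):

import os

# Mimic ${WALLET_SERVER_TAG:-latest-release}: use the default when the
# variable is unset or set to an empty string.
tag = os.environ.get('WALLET_SERVER_TAG') or 'latest-release'
print(f'lbry/wallet-server:{tag}')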

View file

@@ -22,6 +22,7 @@ def set_reference(reference, txo_row):


 class Censor:

     NOT_CENSORED = 0
     SEARCH = 1
+    RESOLVE = 2
@@ -31,16 +32,19 @@ class Censor:
         self.censor_type = censor_type
         self.censored = {}

+    def is_censored(self, row):
+        return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type
+
     def apply(self, rows):
         return [row for row in rows if not self.censor(row)]

     def censor(self, row) -> bool:
-        was_censored = (row.get('censor_type') or 0) >= self.censor_type
-        if was_censored:
+        if self.is_censored(row):
             censoring_channel_hash = row['censoring_channel_hash']
             self.censored.setdefault(censoring_channel_hash, set())
             self.censored[censoring_channel_hash].add(row['tx_hash'])
-        return was_censored
+            return True
+        return False

     def to_message(self, outputs: OutputsMessage, extra_txo_rows: dict):
         for censoring_channel_hash, count in self.censored.items():
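
A minimal sketch of how the reworked methods interact, with the Censor skeleton restated from the hunk above and made-up row values; SEARCH-level censoring also catches RESOLVE-level rows because of the >= comparison:

class Censor:

    NOT_CENSORED = 0
    SEARCH = 1
    RESOLVE = 2

    def __init__(self, censor_type):
        self.censor_type = censor_type
        self.censored = {}

    def is_censored(self, row):
        return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type

    def apply(self, rows):
        return [row for row in rows if not self.censor(row)]

    def censor(self, row) -> bool:
        if self.is_censored(row):
            censoring_channel_hash = row['censoring_channel_hash']
            self.censored.setdefault(censoring_channel_hash, set())
            self.censored[censoring_channel_hash].add(row['tx_hash'])
            return True
        return False

censor = Censor(Censor.SEARCH)
rows = [
    {'censor_type': Censor.RESOLVE, 'censoring_channel_hash': b'\x01' * 20, 'tx_hash': b'\x02' * 32},
    {'censor_type': Censor.NOT_CENSORED},
]
assert censor.apply(rows) == [rows[1]]                   # the blocked row is filtered out
assert censor.censored[b'\x01' * 20] == {b'\x02' * 32}   # and recorded against its channel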

View file

@@ -1,9 +1,9 @@
 import asyncio
 import struct
-from binascii import hexlify, unhexlify
+from binascii import unhexlify
 from decimal import Decimal
 from operator import itemgetter
-from typing import Optional, List, Iterable
+from typing import Optional, List, Iterable, Union

 from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
 from elasticsearch.helpers import async_streaming_bulk
@@ -21,11 +21,15 @@ from lbry.wallet.server.util import class_logger


 class ChannelResolution(str):
-    pass
+    @classmethod
+    def lookup_error(cls, url):
+        return LookupError(f'Could not find channel in "{url}".')


 class StreamResolution(str):
-    pass
+    @classmethod
+    def lookup_error(cls, url):
+        return LookupError(f'Could not find claim at "{url}".')


 class SearchIndex:
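
These str subclasses let resolve_url return a claim id tagged with how it was resolved, and the new classmethods build the corresponding lookup error lazily. A quick illustration (StreamResolution restated from above):

class StreamResolution(str):
    @classmethod
    def lookup_error(cls, url):
        return LookupError(f'Could not find claim at "{url}".')

url = 'lbry://nonexistent-stream'
err = StreamResolution.lookup_error(url)
assert isinstance(err, LookupError)
assert str(err) == 'Could not find claim at "lbry://nonexistent-stream".'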
@@ -33,7 +37,7 @@ class SearchIndex:
         self.search_timeout = search_timeout
         self.sync_timeout = 600  # wont hit that 99% of the time, but can hit on a fresh import
         self.search_client: Optional[AsyncElasticsearch] = None
-        self.client: Optional[AsyncElasticsearch] = None
+        self.sync_client: Optional[AsyncElasticsearch] = None
         self.index = index_prefix + 'claims'
         self.logger = class_logger(__name__, self.__class__.__name__)
         self.claim_cache = LRUCache(2 ** 15)
@@ -42,27 +46,27 @@ class SearchIndex:
         self.resolution_cache = LRUCache(2 ** 17)

     async def start(self):
-        if self.client:
+        if self.sync_client:
             return
-        self.client = AsyncElasticsearch(timeout=self.sync_timeout)
+        self.sync_client = AsyncElasticsearch(timeout=self.sync_timeout)
         self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
         while True:
             try:
-                await self.client.cluster.health(wait_for_status='yellow')
+                await self.sync_client.cluster.health(wait_for_status='yellow')
                 break
             except ConnectionError:
                 self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                 await asyncio.sleep(1)
-        res = await self.client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
+        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
         return res.get('acknowledged', False)

     def stop(self):
-        clients = [self.client, self.search_client]
-        self.client, self.search_client = None, None
+        clients = [self.sync_client, self.search_client]
+        self.sync_client, self.search_client = None, None
         return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

     def delete_index(self):
-        return self.client.indices.delete(self.index, ignore_unavailable=True)
+        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

     async def _consume_claim_producer(self, claim_producer):
         count = 0
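
The rename makes the two-client split explicit: sync_client keeps the long sync_timeout for indexing and index management, while search_client keeps the caller-supplied search_timeout for queries. A minimal sketch of that pattern, assuming elasticsearch-py 7.x with async support and a local node:

import asyncio
from elasticsearch import AsyncElasticsearch

async def main():
    sync_client = AsyncElasticsearch(timeout=600)   # bulk indexing, index management
    search_client = AsyncElasticsearch(timeout=10)  # interactive queries
    try:
        print(await sync_client.cluster.health(wait_for_status='yellow'))
    finally:
        await asyncio.gather(sync_client.close(), search_client.close())

asyncio.run(main())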
@@ -77,49 +81,54 @@ class SearchIndex:
         self.logger.info("Indexing done for %d claims.", count)

     async def claim_consumer(self, claim_producer):
-        await self.client.indices.refresh(self.index)
         touched = set()
-        async for ok, item in async_streaming_bulk(self.client, self._consume_claim_producer(claim_producer),
+        async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
                                                    raise_on_error=False):
             if not ok:
                 self.logger.warning("indexing failed for an item: %s", item)
             else:
                 item = item.popitem()[1]
                 touched.add(item['_id'])
-        await self.client.indices.refresh(self.index)
+        await self.sync_client.indices.refresh(self.index)
         self.logger.info("Indexing done.")

+    def update_filter_query(self, censor_type, blockdict, channels=False):
+        blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
+        if channels:
+            update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
+        else:
+            update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
+        key = 'channel_id' if channels else 'claim_id'
+        update['script'] = {
+            "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
+            "lang": "painless",
+            "params": blockdict
+        }
+        return update
+
     async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
-        def make_query(censor_type, blockdict, channels=False):
-            blockdict = dict(
-                (hexlify(key[::-1]).decode(), hexlify(value[::-1]).decode()) for key, value in blockdict.items())
-            if channels:
-                update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
-            else:
-                update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
-            key = 'channel_id' if channels else 'claim_id'
-            update['script'] = {
-                "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
-                "lang": "painless",
-                "params": blockdict
-            }
-            return update
         if filtered_streams:
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_streams), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if filtered_channels:
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_channels), slices=4)
-            await self.client.indices.refresh(self.index)
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_channels, True), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
+            await self.sync_client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if blocked_streams:
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_streams), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if blocked_channels:
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_channels), slices=4)
-            await self.client.indices.refresh(self.index)
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_channels, True), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
+            await self.sync_client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         self.search_cache.clear()
         self.claim_cache.clear()
         self.resolution_cache.clear()
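
For reference, the body that update_filter_query hands to update_by_query pairs a query over the not-yet-censored ids (built by expand_query, elided here) with a painless script that stamps every match. A rough sketch with a made-up blocked claim:

from binascii import unhexlify

blockdict = {unhexlify('df' * 20): unhexlify('c0' * 20)}  # claim_hash -> censoring channel hash
censor_type = 2  # Censor.RESOLVE

# Hashes are byte-reversed and hexed, matching how ids are stored in ES.
params = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
script = {
    'source': f"ctx._source.censor_type={censor_type}; "
              f"ctx._source.censoring_channel_hash=params[ctx._source.claim_id]",
    'lang': 'painless',
    'params': params,
}
print(script['params'])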
@@ -138,13 +147,13 @@ class SearchIndex:
            return cache_item.result
         censor = Censor(Censor.SEARCH)
         if kwargs.get('no_totals'):
-            response, offset, total = await self.search(**kwargs, censor_type=0)
+            response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
         else:
             response, offset, total = await self.search(**kwargs)
         censor.apply(response)
         total_referenced.extend(response)
         if censor.censored:
-            response, _, _ = await self.search(**kwargs, censor_type=0)
+            response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
             total_referenced.extend(response)
         result = Outputs.to_base64(
             response, await self._get_referenced_rows(total_referenced), offset, total, censor
@@ -157,16 +166,8 @@ class SearchIndex:

         censor = Censor(Censor.RESOLVE)
         results = [await self.resolve_url(url) for url in urls]
         # just heat the cache
-        await self.get_many(*filter(lambda x: isinstance(x, str), results))
-        for index in range(len(results)):
-            result = results[index]
-            url = urls[index]
-            if result in self.claim_cache:
-                results[index] = self.claim_cache[result]
-            elif isinstance(result, StreamResolution):
-                results[index] = LookupError(f'Could not find claim at "{url}".')
-            elif isinstance(result, ChannelResolution):
-                results[index] = LookupError(f'Could not find channel in "{url}".')
+        await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
+        results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]
         censored = [
             result if not isinstance(result, dict) or not censor.censor(result)
@@ -175,15 +176,22 @@
         ]
         return results, censored, censor

+    def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
+        cached = self.claim_cache.get(resolution)
+        return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))
+
     async def get_many(self, *claim_ids):
-        missing = [claim_id for claim_id in claim_ids if claim_id not in self.claim_cache]
+        await self.populate_claim_cache(*claim_ids)
+        return filter(None, map(self.claim_cache.get, claim_ids))
+
+    async def populate_claim_cache(self, *claim_ids):
+        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
         if missing:
             results = await self.search_client.mget(
                 index=self.index, body={"ids": missing}
             )
             for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                 self.claim_cache.set(result['claim_id'], result)
-        return filter(None, map(self.claim_cache.get, claim_ids))

     async def full_id_from_short_id(self, name, short_id, channel_id=None):
         key = (channel_id or '') + name + short_id
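
Splitting get_many into populate_claim_cache plus a cache read lets resolve_urls warm the cache for a whole batch in one mget and then resolve each url synchronously. A toy illustration of the split, with a plain dict standing in for LRUCache and the mget call stubbed out:

import asyncio

claim_cache = {}

async def populate_claim_cache(*claim_ids):
    missing = [cid for cid in claim_ids if claim_cache.get(cid) is None]
    if missing:
        # stand-in for search_client.mget(index=..., body={'ids': missing})
        claim_cache.update({cid: {'claim_id': cid} for cid in missing})

async def get_many(*claim_ids):
    await populate_claim_cache(*claim_ids)
    return filter(None, map(claim_cache.get, claim_ids))

print(list(asyncio.run(get_many('abc123', 'def456'))))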
@@ -304,23 +312,23 @@ class SearchIndex:


 def extract_doc(doc, index):
-    doc['claim_id'] = hexlify(doc.pop('claim_hash')[::-1]).decode()
+    doc['claim_id'] = doc.pop('claim_hash')[::-1].hex()
     if doc['reposted_claim_hash'] is not None:
-        doc['reposted_claim_id'] = hexlify(doc.pop('reposted_claim_hash')[::-1]).decode()
+        doc['reposted_claim_id'] = doc.pop('reposted_claim_hash')[::-1].hex()
     else:
         doc['reposted_claim_id'] = None
     channel_hash = doc.pop('channel_hash')
-    doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
+    doc['channel_id'] = channel_hash[::-1].hex() if channel_hash else channel_hash
     channel_hash = doc.pop('censoring_channel_hash')
-    doc['censoring_channel_hash'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
+    doc['censoring_channel_hash'] = channel_hash[::-1].hex() if channel_hash else channel_hash
     txo_hash = doc.pop('txo_hash')
-    doc['tx_id'] = hexlify(txo_hash[:32][::-1]).decode()
+    doc['tx_id'] = txo_hash[:32][::-1].hex()
     doc['tx_nout'] = struct.unpack('<I', txo_hash[32:])[0]
     doc['is_controlling'] = bool(doc['is_controlling'])
-    doc['signature'] = hexlify(doc.pop('signature') or b'').decode() or None
-    doc['signature_digest'] = hexlify(doc.pop('signature_digest') or b'').decode() or None
-    doc['public_key_bytes'] = hexlify(doc.pop('public_key_bytes') or b'').decode() or None
-    doc['public_key_hash'] = hexlify(doc.pop('public_key_hash') or b'').decode() or None
+    doc['signature'] = (doc.pop('signature') or b'').hex() or None
+    doc['signature_digest'] = (doc.pop('signature_digest') or b'').hex() or None
+    doc['public_key_bytes'] = (doc.pop('public_key_bytes') or b'').hex() or None
+    doc['public_key_hash'] = (doc.pop('public_key_hash') or b'').hex() or None
     doc['signature_valid'] = bool(doc['signature_valid'])
     doc['claim_type'] = doc.get('claim_type', 0) or 0
     doc['stream_type'] = int(doc.get('stream_type', 0) or 0)
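
This hunk is mechanical: bytes.hex() (Python 3.5+) replaces the hexlify(...).decode() round trip, including for the byte-reversed hashes. A quick equivalence check:

from binascii import hexlify

claim_hash = bytes(range(20))
assert hexlify(claim_hash[::-1]).decode() == claim_hash[::-1].hex()
# The `or None` idiom survives the change because b''.hex() == '' is falsy:
assert ((None or b'').hex() or None) is None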
@@ -357,14 +365,14 @@ def expand_query(**kwargs):
             value = [CLAIM_TYPES[claim_type] for claim_type in value]
         if key == '_id':
             if isinstance(value, Iterable):
-                value = [hexlify(item[::-1]).decode() for item in value]
+                value = [item[::-1].hex() for item in value]
             else:
-                value = hexlify(value[::-1]).decode()
+                value = value[::-1].hex()
         if not many and key in ('_id', 'claim_id') and len(value) < 20:
             partial_id = True
         if key == 'public_key_id':
             key = 'public_key_hash'
-            value = hexlify(Base58.decode(value)[1:21]).decode()
+            value = Base58.decode(value)[1:21].hex()
         if key == 'signature_valid':
             continue  # handled later
         if key in TEXT_FIELDS:
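
Base58.decode here is lbry's own helper; the [1:21] slice keeps the 20-byte hash160, dropping the leading version byte and leaving any trailing checksum bytes behind. A self-contained base58 decode for illustration only, using a well-known example address:

B58_ALPHABET = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'

def base58_decode(s: str) -> bytes:
    n = 0
    for char in s:
        n = n * 58 + B58_ALPHABET.index(char)
    pad = len(s) - len(s.lstrip('1'))  # leading '1's encode zero bytes
    return b'\x00' * pad + n.to_bytes((n.bit_length() + 7) // 8, 'big')

# version(1) + hash160(20) + checksum(4) = 25 bytes for a base58check address
address = '1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2'
print(base58_decode(address)[1:21].hex())  # the hash160, as in the hunk above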