fixes from review

Victor Shyba 2021-03-24 05:35:31 -03:00
parent d47cf40544
commit 7df4cc44c4
3 changed files with 83 additions and 71 deletions

View file

@@ -8,7 +8,7 @@ services:
   wallet_server:
     depends_on:
       - es01
-    image: lbry/wallet-server:${WALLET_SERVER_TAG:-development}
+    image: lbry/wallet-server:${WALLET_SERVER_TAG:-latest-release}
     restart: always
     network_mode: host
     ports:
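A note on the one-line change above: `${WALLET_SERVER_TAG:-latest-release}` is shell-style parameter expansion, so the compose file now defaults to the stable `latest-release` image instead of `development` unless the environment variable overrides it. A minimal Python sketch of the same fallback rule (the helper name is invented for illustration):

    import os

    def wallet_server_image(default_tag: str = 'latest-release') -> str:
        # docker-compose's ${VAR:-default} also falls back when VAR is set but
        # empty, which is why `or` is used rather than a plain .get() default.
        tag = os.environ.get('WALLET_SERVER_TAG') or default_tag
        return f'lbry/wallet-server:{tag}'

    print(wallet_server_image())  # lbry/wallet-server:latest-release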

View file

@@ -22,6 +22,7 @@ def set_reference(reference, txo_row):
 class Censor:

+    NOT_CENSORED = 0
     SEARCH = 1
     RESOLVE = 2
@@ -31,16 +32,19 @@ class Censor:
         self.censor_type = censor_type
         self.censored = {}

+    def is_censored(self, row):
+        return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type
+
     def apply(self, rows):
         return [row for row in rows if not self.censor(row)]

     def censor(self, row) -> bool:
-        was_censored = (row.get('censor_type') or 0) >= self.censor_type
-        if was_censored:
+        if self.is_censored(row):
             censoring_channel_hash = row['censoring_channel_hash']
             self.censored.setdefault(censoring_channel_hash, set())
             self.censored[censoring_channel_hash].add(row['tx_hash'])
-        return was_censored
+            return True
+        return False

     def to_message(self, outputs: OutputsMessage, extra_txo_rows: dict):
         for censoring_channel_hash, count in self.censored.items():
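The refactor pulls the threshold test into `is_censored` and makes `censor` record a hit only on the path that returns True. A minimal sketch of the resulting behaviour (the row dicts are invented):

    censor = Censor(Censor.SEARCH)  # censor_type == 1
    rows = [
        {'censor_type': Censor.NOT_CENSORED, 'censoring_channel_hash': None, 'tx_hash': b'aa'},
        {'censor_type': Censor.RESOLVE, 'censoring_channel_hash': b'chan', 'tx_hash': b'bb'},
    ]
    kept = censor.apply(rows)  # drops the second row
    assert [row['tx_hash'] for row in kept] == [b'aa']
    assert censor.censored == {b'chan': {b'bb'}}  # what was censored, and by whom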

View file

@@ -1,9 +1,9 @@
 import asyncio
 import struct
-from binascii import hexlify, unhexlify
+from binascii import unhexlify
 from decimal import Decimal
 from operator import itemgetter
-from typing import Optional, List, Iterable
+from typing import Optional, List, Iterable, Union

 from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
 from elasticsearch.helpers import async_streaming_bulk
@@ -21,11 +21,15 @@ from lbry.wallet.server.util import class_logger

 class ChannelResolution(str):
-    pass
+    @classmethod
+    def lookup_error(cls, url):
+        return LookupError(f'Could not find channel in "{url}".')


 class StreamResolution(str):
-    pass
+    @classmethod
+    def lookup_error(cls, url):
+        return LookupError(f'Could not find claim at "{url}".')


 class SearchIndex:
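Rather than bare `pass` bodies, the two `str` subclasses now own their failure messages. They act as typed markers: a resolution can be carried around as a plain claim-id string while its class remembers which kind of lookup produced it, so a later cache miss can be turned into the right `LookupError`. A small illustration (the id and URL are invented):

    res = StreamResolution('d5169241150022f996fa7cd6a9a1c421937276a3')
    assert isinstance(res, str)  # usable anywhere a claim id string is expected
    err = res.lookup_error('lbry://@some-channel/some-claim')
    assert str(err) == 'Could not find claim at "lbry://@some-channel/some-claim".'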
@@ -33,7 +37,7 @@ class SearchIndex:
         self.search_timeout = search_timeout
         self.sync_timeout = 600  # wont hit that 99% of the time, but can hit on a fresh import
         self.search_client: Optional[AsyncElasticsearch] = None
-        self.client: Optional[AsyncElasticsearch] = None
+        self.sync_client: Optional[AsyncElasticsearch] = None
         self.index = index_prefix + 'claims'
         self.logger = class_logger(__name__, self.__class__.__name__)
         self.claim_cache = LRUCache(2 ** 15)
@@ -42,27 +46,27 @@ class SearchIndex:
         self.resolution_cache = LRUCache(2 ** 17)

     async def start(self):
-        if self.client:
+        if self.sync_client:
             return
-        self.client = AsyncElasticsearch(timeout=self.sync_timeout)
+        self.sync_client = AsyncElasticsearch(timeout=self.sync_timeout)
         self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
         while True:
             try:
-                await self.client.cluster.health(wait_for_status='yellow')
+                await self.sync_client.cluster.health(wait_for_status='yellow')
                 break
             except ConnectionError:
                 self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                 await asyncio.sleep(1)
-        res = await self.client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
+        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
         return res.get('acknowledged', False)

     def stop(self):
-        clients = [self.client, self.search_client]
-        self.client, self.search_client = None, None
+        clients = [self.sync_client, self.search_client]
+        self.sync_client, self.search_client = None, None
         return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

     def delete_index(self):
-        return self.client.indices.delete(self.index, ignore_unavailable=True)
+        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

     async def _consume_claim_producer(self, claim_producer):
         count = 0
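The rename from `client` to `sync_client` makes the two-client split explicit: the sync client (600 s timeout) handles bulk indexing and admin calls, while `search_client` serves queries on the shorter `search_timeout`. A rough lifecycle sketch, assuming a constructor along the lines of `SearchIndex(index_prefix, search_timeout)`:

    index = SearchIndex(index_prefix='lbry_', search_timeout=3.0)  # signature assumed
    await index.start()    # blocks until ES reports 'yellow', then ensures the index exists
    try:
        ...                # claim_consumer(...) writes; search traffic reads
    finally:
        await index.stop()  # closes both AsyncElasticsearch clients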
@@ -77,22 +81,19 @@ class SearchIndex:
         self.logger.info("Indexing done for %d claims.", count)

     async def claim_consumer(self, claim_producer):
-        await self.client.indices.refresh(self.index)
         touched = set()
-        async for ok, item in async_streaming_bulk(self.client, self._consume_claim_producer(claim_producer),
+        async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
                                                    raise_on_error=False):
             if not ok:
                 self.logger.warning("indexing failed for an item: %s", item)
             else:
                 item = item.popitem()[1]
                 touched.add(item['_id'])
-        await self.client.indices.refresh(self.index)
+        await self.sync_client.indices.refresh(self.index)
         self.logger.info("Indexing done.")

-    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
-        def make_query(censor_type, blockdict, channels=False):
-            blockdict = dict(
-                (hexlify(key[::-1]).decode(), hexlify(value[::-1]).decode()) for key, value in blockdict.items())
+    def update_filter_query(self, censor_type, blockdict, channels=False):
+        blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
         if channels:
             update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
         else:
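`update_filter_query` also swaps the `hexlify` round-trip for a dict comprehension over `bytes.hex()`; the `[::-1]` reversal is unchanged and turns the stored little-endian hashes into the hex form used in queries. The two spellings are equivalent:

    from binascii import hexlify

    raw = bytes.fromhex('deadbeef')  # sample bytes, invented
    assert hexlify(raw[::-1]).decode() == raw[::-1].hex() == 'efbeadde'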
@@ -104,22 +105,30 @@ class SearchIndex:
             "params": blockdict
         }
         return update
+
+    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
         if filtered_streams:
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_streams), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if filtered_channels:
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_channels), slices=4)
-            await self.client.indices.refresh(self.index)
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_channels, True), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
+            await self.sync_client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if blocked_streams:
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_streams), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if blocked_channels:
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_channels), slices=4)
-            await self.client.indices.refresh(self.index)
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_channels, True), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
+            await self.sync_client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         self.search_cache.clear()
         self.claim_cache.clear()
         self.resolution_cache.clear()
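Replacing the magic numbers 1 and 2 with `Censor.SEARCH` and `Censor.RESOLVE` also documents the two tiers: filtered claims disappear from search but still resolve, while blocked claims fail both. How that plays out against `Censor.is_censored` (row dicts invented):

    filtered = {'censor_type': Censor.SEARCH}   # tagged via update_filter_query(Censor.SEARCH, ...)
    blocked = {'censor_type': Censor.RESOLVE}   # tagged via update_filter_query(Censor.RESOLVE, ...)

    assert Censor(Censor.SEARCH).is_censored(filtered)       # hidden from search results
    assert not Censor(Censor.RESOLVE).is_censored(filtered)  # but still resolvable
    assert Censor(Censor.RESOLVE).is_censored(blocked)       # blocked from resolve too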
@@ -138,13 +147,13 @@ class SearchIndex:
             return cache_item.result
         censor = Censor(Censor.SEARCH)
         if kwargs.get('no_totals'):
-            response, offset, total = await self.search(**kwargs, censor_type=0)
+            response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
         else:
             response, offset, total = await self.search(**kwargs)
         censor.apply(response)
         total_referenced.extend(response)
         if censor.censored:
-            response, _, _ = await self.search(**kwargs, censor_type=0)
+            response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
             total_referenced.extend(response)
         result = Outputs.to_base64(
             response, await self._get_referenced_rows(total_referenced), offset, total, censor
@@ -157,16 +166,8 @@ class SearchIndex:
         censor = Censor(Censor.RESOLVE)
         results = [await self.resolve_url(url) for url in urls]
         # just heat the cache
-        await self.get_many(*filter(lambda x: isinstance(x, str), results))
-        for index in range(len(results)):
-            result = results[index]
-            url = urls[index]
-            if result in self.claim_cache:
-                results[index] = self.claim_cache[result]
-            elif isinstance(result, StreamResolution):
-                results[index] = LookupError(f'Could not find claim at "{url}".')
-            elif isinstance(result, ChannelResolution):
-                results[index] = LookupError(f'Could not find channel in "{url}".')
+        await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
+        results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]

         censored = [
             result if not isinstance(result, dict) or not censor.censor(result)
@@ -175,15 +176,22 @@ class SearchIndex:
         ]
         return results, censored, censor

+    def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
+        cached = self.claim_cache.get(resolution)
+        return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))
+
     async def get_many(self, *claim_ids):
-        missing = [claim_id for claim_id in claim_ids if claim_id not in self.claim_cache]
+        await self.populate_claim_cache(*claim_ids)
+        return filter(None, map(self.claim_cache.get, claim_ids))
+
+    async def populate_claim_cache(self, *claim_ids):
+        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
         if missing:
             results = await self.search_client.mget(
                 index=self.index, body={"ids": missing}
             )
             for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                 self.claim_cache.set(result['claim_id'], result)
-        return filter(None, map(self.claim_cache.get, claim_ids))

     async def full_id_from_short_id(self, name, short_id, channel_id=None):
         key = (channel_id or '') + name + short_id
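Splitting `get_many` into `populate_claim_cache` plus a cached read lets callers that only need to warm the cache (as `resolve` now does) skip materialising results they would throw away. Expected call pattern (the claim ids are invented):

    await index.populate_claim_cache('d51692...', '8a7401...')     # mget only the uncached ids
    claims = list(await index.get_many('d51692...', '8a7401...'))  # now served from the LRU cache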
@@ -304,23 +312,23 @@ class SearchIndex:

 def extract_doc(doc, index):
-    doc['claim_id'] = hexlify(doc.pop('claim_hash')[::-1]).decode()
+    doc['claim_id'] = doc.pop('claim_hash')[::-1].hex()
     if doc['reposted_claim_hash'] is not None:
-        doc['reposted_claim_id'] = hexlify(doc.pop('reposted_claim_hash')[::-1]).decode()
+        doc['reposted_claim_id'] = doc.pop('reposted_claim_hash')[::-1].hex()
     else:
         doc['reposted_claim_id'] = None
     channel_hash = doc.pop('channel_hash')
-    doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
+    doc['channel_id'] = channel_hash[::-1].hex() if channel_hash else channel_hash
     channel_hash = doc.pop('censoring_channel_hash')
-    doc['censoring_channel_hash'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
+    doc['censoring_channel_hash'] = channel_hash[::-1].hex() if channel_hash else channel_hash
     txo_hash = doc.pop('txo_hash')
-    doc['tx_id'] = hexlify(txo_hash[:32][::-1]).decode()
+    doc['tx_id'] = txo_hash[:32][::-1].hex()
     doc['tx_nout'] = struct.unpack('<I', txo_hash[32:])[0]
     doc['is_controlling'] = bool(doc['is_controlling'])
-    doc['signature'] = hexlify(doc.pop('signature') or b'').decode() or None
-    doc['signature_digest'] = hexlify(doc.pop('signature_digest') or b'').decode() or None
-    doc['public_key_bytes'] = hexlify(doc.pop('public_key_bytes') or b'').decode() or None
-    doc['public_key_hash'] = hexlify(doc.pop('public_key_hash') or b'').decode() or None
+    doc['signature'] = (doc.pop('signature') or b'').hex() or None
+    doc['signature_digest'] = (doc.pop('signature_digest') or b'').hex() or None
+    doc['public_key_bytes'] = (doc.pop('public_key_bytes') or b'').hex() or None
+    doc['public_key_hash'] = (doc.pop('public_key_hash') or b'').hex() or None
     doc['signature_valid'] = bool(doc['signature_valid'])
     doc['claim_type'] = doc.get('claim_type', 0) or 0
     doc['stream_type'] = int(doc.get('stream_type', 0) or 0)
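The signature and key fields use the `(value or b'').hex() or None` idiom: `bytes.hex()` (available since Python 3.5) replaces the `hexlify(...).decode()` chain, and the trailing `or None` maps an empty or missing byte string back to `None` rather than `''`. For example:

    assert ((None or b'').hex() or None) is None        # missing field stays None
    assert ((b'\xab\xcd' or b'').hex() or None) == 'abcd'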
@@ -357,14 +365,14 @@ def expand_query(**kwargs):
             value = [CLAIM_TYPES[claim_type] for claim_type in value]
         if key == '_id':
             if isinstance(value, Iterable):
-                value = [hexlify(item[::-1]).decode() for item in value]
+                value = [item[::-1].hex() for item in value]
             else:
-                value = hexlify(value[::-1]).decode()
+                value = value[::-1].hex()
         if not many and key in ('_id', 'claim_id') and len(value) < 20:
             partial_id = True
         if key == 'public_key_id':
             key = 'public_key_hash'
-            value = hexlify(Base58.decode(value)[1:21]).decode()
+            value = Base58.decode(value)[1:21].hex()
         if key == 'signature_valid':
             continue  # handled later
         if key in TEXT_FIELDS:
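The `Base58.decode(value)[1:21]` slice in the `public_key_id` branch relies on the usual base58check address layout: one version byte, the 20-byte hash160 of the public key, then a 4-byte checksum. A sketch of just the slice, with invented bytes standing in for the decoded address:

    raw = bytes([0x55]) + b'\xab' * 20 + b'\x00\x11\x22\x33'  # version + hash160 + checksum
    assert raw[1:21].hex() == 'ab' * 20  # the public_key_hash value used by the query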