From 7df4cc44c4e69cf8572e027796bcc9291141a0b3 Mon Sep 17 00:00:00 2001
From: Victor Shyba
Date: Wed, 24 Mar 2021 05:35:31 -0300
Subject: [PATCH] fixes from review

---
 docker/docker-compose-wallet-server.yml       |   2 +-
 lbry/schema/result.py                         |  10 +-
 lbry/wallet/server/db/elasticsearch/search.py | 142 +++++++++---
 3 files changed, 83 insertions(+), 71 deletions(-)

diff --git a/docker/docker-compose-wallet-server.yml b/docker/docker-compose-wallet-server.yml
index 0ef9d4d6d..92a01e562 100644
--- a/docker/docker-compose-wallet-server.yml
+++ b/docker/docker-compose-wallet-server.yml
@@ -8,7 +8,7 @@ services:
   wallet_server:
     depends_on:
       - es01
-    image: lbry/wallet-server:${WALLET_SERVER_TAG:-development}
+    image: lbry/wallet-server:${WALLET_SERVER_TAG:-latest-release}
     restart: always
     network_mode: host
     ports:
diff --git a/lbry/schema/result.py b/lbry/schema/result.py
index 7b2f31a3f..7b4b30009 100644
--- a/lbry/schema/result.py
+++ b/lbry/schema/result.py
@@ -22,6 +22,7 @@ def set_reference(reference, txo_row):
 
 
 class Censor:
+    NOT_CENSORED = 0
     SEARCH = 1
     RESOLVE = 2
 
@@ -31,16 +32,19 @@ class Censor:
         self.censor_type = censor_type
         self.censored = {}
 
+    def is_censored(self, row):
+        return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type
+
     def apply(self, rows):
         return [row for row in rows if not self.censor(row)]
 
     def censor(self, row) -> bool:
-        was_censored = (row.get('censor_type') or 0) >= self.censor_type
-        if was_censored:
+        if self.is_censored(row):
             censoring_channel_hash = row['censoring_channel_hash']
             self.censored.setdefault(censoring_channel_hash, set())
             self.censored[censoring_channel_hash].add(row['tx_hash'])
-        return was_censored
+            return True
+        return False
 
     def to_message(self, outputs: OutputsMessage, extra_txo_rows: dict):
         for censoring_channel_hash, count in self.censored.items():
diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py
index ab8708d1d..362111489 100644
--- a/lbry/wallet/server/db/elasticsearch/search.py
+++ b/lbry/wallet/server/db/elasticsearch/search.py
@@ -1,9 +1,9 @@
 import asyncio
 import struct
-from binascii import hexlify, unhexlify
+from binascii import unhexlify
 from decimal import Decimal
 from operator import itemgetter
-from typing import Optional, List, Iterable
+from typing import Optional, List, Iterable, Union
 
 from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
 from elasticsearch.helpers import async_streaming_bulk
@@ -21,11 +21,15 @@ from lbry.wallet.server.util import class_logger
 
 
 class ChannelResolution(str):
-    pass
+    @classmethod
+    def lookup_error(cls, url):
+        return LookupError(f'Could not find channel in "{url}".')
 
 
 class StreamResolution(str):
-    pass
+    @classmethod
+    def lookup_error(cls, url):
+        return LookupError(f'Could not find claim at "{url}".')
 
 
 class SearchIndex:
@@ -33,7 +37,7 @@ class SearchIndex:
         self.search_timeout = search_timeout
         self.sync_timeout = 600  # wont hit that 99% of the time, but can hit on a fresh import
         self.search_client: Optional[AsyncElasticsearch] = None
-        self.client: Optional[AsyncElasticsearch] = None
+        self.sync_client: Optional[AsyncElasticsearch] = None
         self.index = index_prefix + 'claims'
         self.logger = class_logger(__name__, self.__class__.__name__)
         self.claim_cache = LRUCache(2 ** 15)
@@ -42,27 +46,27 @@
         self.resolution_cache = LRUCache(2 ** 17)
 
     async def start(self):
-        if self.client:
+        if self.sync_client:
             return
-        self.client = AsyncElasticsearch(timeout=self.sync_timeout)
+        self.sync_client = AsyncElasticsearch(timeout=self.sync_timeout)
         self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
         while True:
             try:
-                await self.client.cluster.health(wait_for_status='yellow')
+                await self.sync_client.cluster.health(wait_for_status='yellow')
                 break
             except ConnectionError:
                 self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                 await asyncio.sleep(1)
-        res = await self.client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
+        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
         return res.get('acknowledged', False)
 
     def stop(self):
-        clients = [self.client, self.search_client]
-        self.client, self.search_client = None, None
+        clients = [self.sync_client, self.search_client]
+        self.sync_client, self.search_client = None, None
         return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))
 
     def delete_index(self):
-        return self.client.indices.delete(self.index, ignore_unavailable=True)
+        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)
 
     async def _consume_claim_producer(self, claim_producer):
         count = 0
@@ -77,49 +81,54 @@
         self.logger.info("Indexing done for %d claims.", count)
 
     async def claim_consumer(self, claim_producer):
-        await self.client.indices.refresh(self.index)
         touched = set()
-        async for ok, item in async_streaming_bulk(self.client, self._consume_claim_producer(claim_producer),
+        async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
                                                    raise_on_error=False):
             if not ok:
                 self.logger.warning("indexing failed for an item: %s", item)
             else:
                 item = item.popitem()[1]
                 touched.add(item['_id'])
-        await self.client.indices.refresh(self.index)
+        await self.sync_client.indices.refresh(self.index)
         self.logger.info("Indexing done.")
 
+    def update_filter_query(self, censor_type, blockdict, channels=False):
+        blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
+        if channels:
+            update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
+        else:
+            update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
+        key = 'channel_id' if channels else 'claim_id'
+        update['script'] = {
+            "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
+            "lang": "painless",
+            "params": blockdict
+        }
+        return update
+
     async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
-        def make_query(censor_type, blockdict, channels=False):
-            blockdict = dict(
-                (hexlify(key[::-1]).decode(), hexlify(value[::-1]).decode()) for key, value in blockdict.items())
-            if channels:
-                update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
-            else:
-                update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
-            key = 'channel_id' if channels else 'claim_id'
-            update['script'] = {
-                "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
-                "lang": "painless",
-                "params": blockdict
-            }
-            return update
         if filtered_streams:
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_streams), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if filtered_channels:
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_channels), slices=4)
-            await self.client.indices.refresh(self.index)
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_channels, True), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
+            await self.sync_client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if blocked_streams:
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_streams), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if blocked_channels:
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_channels), slices=4)
-            await self.client.indices.refresh(self.index)
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_channels, True), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
+            await self.sync_client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         self.search_cache.clear()
         self.claim_cache.clear()
         self.resolution_cache.clear()
@@ -138,13 +147,13 @@
                 return cache_item.result
             censor = Censor(Censor.SEARCH)
             if kwargs.get('no_totals'):
-                response, offset, total = await self.search(**kwargs, censor_type=0)
+                response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
             else:
                 response, offset, total = await self.search(**kwargs)
             censor.apply(response)
             total_referenced.extend(response)
             if censor.censored:
-                response, _, _ = await self.search(**kwargs, censor_type=0)
+                response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                 total_referenced.extend(response)
             result = Outputs.to_base64(
                 response, await self._get_referenced_rows(total_referenced), offset, total, censor
@@ -157,16 +166,8 @@
         censor = Censor(Censor.RESOLVE)
         results = [await self.resolve_url(url) for url in urls]
         # just heat the cache
-        await self.get_many(*filter(lambda x: isinstance(x, str), results))
-        for index in range(len(results)):
-            result = results[index]
-            url = urls[index]
-            if result in self.claim_cache:
-                results[index] = self.claim_cache[result]
-            elif isinstance(result, StreamResolution):
-                results[index] = LookupError(f'Could not find claim at "{url}".')
-            elif isinstance(result, ChannelResolution):
-                results[index] = LookupError(f'Could not find channel in "{url}".')
+        await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
+        results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]
         censored = [
             result if not isinstance(result, dict) or not censor.censor(result)
             else ResolveCensoredError(url, result['censoring_channel_hash'])
@@ -175,15 +176,22 @@
         ]
         return results, censored, censor
 
+    def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
+        cached = self.claim_cache.get(resolution)
+        return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))
+
     async def get_many(self, *claim_ids):
-        missing = [claim_id for claim_id in claim_ids if claim_id not in self.claim_cache]
+        await self.populate_claim_cache(*claim_ids)
+        return filter(None, map(self.claim_cache.get, claim_ids))
+
+    async def populate_claim_cache(self, *claim_ids):
+        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
         if missing:
             results = await self.search_client.mget(
                 index=self.index, body={"ids": missing}
             )
             for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                 self.claim_cache.set(result['claim_id'], result)
-        return filter(None, map(self.claim_cache.get, claim_ids))
 
     async def full_id_from_short_id(self, name, short_id, channel_id=None):
         key = (channel_id or '') + name + short_id
@@ -304,23 +312,23 @@
 
 def extract_doc(doc, index):
-    doc['claim_id'] = hexlify(doc.pop('claim_hash')[::-1]).decode()
+    doc['claim_id'] = doc.pop('claim_hash')[::-1].hex()
     if doc['reposted_claim_hash'] is not None:
-        doc['reposted_claim_id'] = hexlify(doc.pop('reposted_claim_hash')[::-1]).decode()
+        doc['reposted_claim_id'] = doc.pop('reposted_claim_hash')[::-1].hex()
     else:
         doc['reposted_claim_id'] = None
     channel_hash = doc.pop('channel_hash')
-    doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
+    doc['channel_id'] = channel_hash[::-1].hex() if channel_hash else channel_hash
     channel_hash = doc.pop('censoring_channel_hash')
-    doc['censoring_channel_hash'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
+    doc['censoring_channel_hash'] = channel_hash[::-1].hex() if channel_hash else channel_hash
     txo_hash = doc.pop('txo_hash')
-    doc['tx_id'] = hexlify(txo_hash[:32][::-1]).decode()
+    doc['tx_id'] = txo_hash[:32][::-1].hex()
     doc['tx_nout'] = struct.unpack('<I', txo_hash[32:])[0]
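
A quick sketch (not part of the patch) of how the reworked Censor API behaves. The row dicts below are hypothetical stand-ins for the Elasticsearch hit dicts the server actually passes in:

    from lbry.schema.result import Censor

    censor = Censor(Censor.SEARCH)
    rows = [
        # censor_type RESOLVE (2) >= SEARCH (1): dropped, and recorded under its censoring channel
        {'censor_type': Censor.RESOLVE, 'censoring_channel_hash': b'\x01' * 32, 'tx_hash': b'\xaa' * 32},
        # censor_type NOT_CENSORED (0) < SEARCH (1): passes through untouched
        {'censor_type': Censor.NOT_CENSORED, 'censoring_channel_hash': None, 'tx_hash': b'\xbb' * 32},
    ]
    visible = censor.apply(rows)   # keeps only the second row
    assert len(visible) == 1
    assert censor.censored == {b'\x01' * 32: {b'\xaa' * 32}}

Note the asymmetry the named constants encode: a row whose censor_type is SEARCH (1) is hidden from search results but still resolves, since 1 < RESOLVE (2).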
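
The [::-1].hex() calls that replace hexlify(...).decode() in update_filter_query and extract_doc keep the existing byte-order convention: hashes appear to be stored in reversed (bitcoin-style little-endian) byte order and rendered big-endian as hex ids. A minimal equivalence check, using a made-up hash value:

    from binascii import hexlify

    claim_hash = bytes.fromhex('deadbeef' * 8)[::-1]  # 32-byte hash in storage (reversed) order
    # the new spelling produces exactly what the old one did
    assert claim_hash[::-1].hex() == hexlify(claim_hash[::-1]).decode()
    claim_id = claim_hash[::-1].hex()                 # 'deadbeef...' as displayed and queried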