forked from LBRYCommunity/lbry-sdk
fixes from review
commit 7df4cc44c4 (parent d47cf40544)
3 changed files with 83 additions and 71 deletions
@@ -8,7 +8,7 @@ services:
   wallet_server:
     depends_on:
       - es01
-    image: lbry/wallet-server:${WALLET_SERVER_TAG:-development}
+    image: lbry/wallet-server:${WALLET_SERVER_TAG:-latest-release}
     restart: always
     network_mode: host
     ports:
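Note: with the `${WALLET_SERVER_TAG:-latest-release}` form, compose falls back to the latest-release image whenever WALLET_SERVER_TAG is unset or empty, while an explicitly exported tag (for example WALLET_SERVER_TAG=development) still selects that image.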
@@ -22,6 +22,7 @@ def set_reference(reference, txo_row):

 class Censor:

+    NOT_CENSORED = 0
     SEARCH = 1
     RESOLVE = 2

@@ -31,16 +32,19 @@ class Censor:
         self.censor_type = censor_type
         self.censored = {}

+    def is_censored(self, row):
+        return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type
+
     def apply(self, rows):
         return [row for row in rows if not self.censor(row)]

     def censor(self, row) -> bool:
-        was_censored = (row.get('censor_type') or 0) >= self.censor_type
-        if was_censored:
+        if self.is_censored(row):
             censoring_channel_hash = row['censoring_channel_hash']
             self.censored.setdefault(censoring_channel_hash, set())
             self.censored[censoring_channel_hash].add(row['tx_hash'])
-        return was_censored
+            return True
+        return False

     def to_message(self, outputs: OutputsMessage, extra_txo_rows: dict):
         for censoring_channel_hash, count in self.censored.items():
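For context, a minimal sketch of how the reworked Censor behaves on a hand-built row (the row shape and the import path are assumptions for illustration, not taken from this diff):

    from lbry.schema.result import Censor   # assumed import path

    censor = Censor(Censor.SEARCH)                 # censor_type == 1
    row = {
        'censor_type': Censor.RESOLVE,             # 2 >= 1, so the row counts as censored
        'censoring_channel_hash': b'\x01' * 20,    # made-up hashes
        'tx_hash': b'\xff' * 32,
    }
    assert censor.is_censored(row)                 # the comparison now lives in its own helper
    assert censor.apply([row]) == []               # censor() records the hit and drops the row
    assert censor.censored[b'\x01' * 20] == {b'\xff' * 32}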
@@ -1,9 +1,9 @@
 import asyncio
 import struct
-from binascii import hexlify, unhexlify
+from binascii import unhexlify
 from decimal import Decimal
 from operator import itemgetter
-from typing import Optional, List, Iterable
+from typing import Optional, List, Iterable, Union

 from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
 from elasticsearch.helpers import async_streaming_bulk
@@ -21,11 +21,15 @@ from lbry.wallet.server.util import class_logger


 class ChannelResolution(str):
-    pass
+    @classmethod
+    def lookup_error(cls, url):
+        return LookupError(f'Could not find channel in "{url}".')


 class StreamResolution(str):
-    pass
+    @classmethod
+    def lookup_error(cls, url):
+        return LookupError(f'Could not find claim at "{url}".')


 class SearchIndex:
@@ -33,7 +37,7 @@ class SearchIndex:
         self.search_timeout = search_timeout
         self.sync_timeout = 600  # wont hit that 99% of the time, but can hit on a fresh import
         self.search_client: Optional[AsyncElasticsearch] = None
-        self.client: Optional[AsyncElasticsearch] = None
+        self.sync_client: Optional[AsyncElasticsearch] = None
         self.index = index_prefix + 'claims'
         self.logger = class_logger(__name__, self.__class__.__name__)
         self.claim_cache = LRUCache(2 ** 15)
@@ -42,27 +46,27 @@ class SearchIndex:
         self.resolution_cache = LRUCache(2 ** 17)

     async def start(self):
-        if self.client:
+        if self.sync_client:
             return
-        self.client = AsyncElasticsearch(timeout=self.sync_timeout)
+        self.sync_client = AsyncElasticsearch(timeout=self.sync_timeout)
         self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
         while True:
             try:
-                await self.client.cluster.health(wait_for_status='yellow')
+                await self.sync_client.cluster.health(wait_for_status='yellow')
                 break
             except ConnectionError:
                 self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                 await asyncio.sleep(1)
-        res = await self.client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
+        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
         return res.get('acknowledged', False)

     def stop(self):
-        clients = [self.client, self.search_client]
-        self.client, self.search_client = None, None
+        clients = [self.sync_client, self.search_client]
+        self.sync_client, self.search_client = None, None
         return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

     def delete_index(self):
-        return self.client.indices.delete(self.index, ignore_unavailable=True)
+        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

     async def _consume_claim_producer(self, claim_producer):
         count = 0
@@ -77,49 +81,54 @@ class SearchIndex:
         self.logger.info("Indexing done for %d claims.", count)

     async def claim_consumer(self, claim_producer):
-        await self.client.indices.refresh(self.index)
         touched = set()
-        async for ok, item in async_streaming_bulk(self.client, self._consume_claim_producer(claim_producer),
+        async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
                                                    raise_on_error=False):
             if not ok:
                 self.logger.warning("indexing failed for an item: %s", item)
             else:
                 item = item.popitem()[1]
                 touched.add(item['_id'])
-        await self.client.indices.refresh(self.index)
+        await self.sync_client.indices.refresh(self.index)
         self.logger.info("Indexing done.")

+    def update_filter_query(self, censor_type, blockdict, channels=False):
+        blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
+        if channels:
+            update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
+        else:
+            update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
+        key = 'channel_id' if channels else 'claim_id'
+        update['script'] = {
+            "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
+            "lang": "painless",
+            "params": blockdict
+        }
+        return update
+
     async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
-        def make_query(censor_type, blockdict, channels=False):
-            blockdict = dict(
-                (hexlify(key[::-1]).decode(), hexlify(value[::-1]).decode()) for key, value in blockdict.items())
-            if channels:
-                update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
-            else:
-                update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
-            key = 'channel_id' if channels else 'claim_id'
-            update['script'] = {
-                "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
-                "lang": "painless",
-                "params": blockdict
-            }
-            return update
         if filtered_streams:
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_streams), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if filtered_channels:
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_channels), slices=4)
-            await self.client.indices.refresh(self.index)
-            await self.client.update_by_query(self.index, body=make_query(1, filtered_channels, True), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
+            await self.sync_client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if blocked_streams:
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_streams), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         if blocked_channels:
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_channels), slices=4)
-            await self.client.indices.refresh(self.index)
-            await self.client.update_by_query(self.index, body=make_query(2, blocked_channels, True), slices=4)
-            await self.client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
+            await self.sync_client.indices.refresh(self.index)
+            await self.sync_client.update_by_query(
+                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
+            await self.sync_client.indices.refresh(self.index)
         self.search_cache.clear()
         self.claim_cache.clear()
         self.resolution_cache.clear()
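A rough sketch of the body one filtering pass assembles (made-up hashes; the query part comes from expand_query and is only summarised here, since its exact shape is not shown in this diff):

    # Sketch of update_filter_query(Censor.SEARCH, blocked) for one blocked claim.
    blocked = {bytes(range(20)): bytes(range(20, 40))}   # claim_hash -> blocking channel_hash
    blockdict = {k[::-1].hex(): v[::-1].hex() for k, v in blocked.items()}
    update = {
        "query": "... built by expand_query(claim_id__in=list(blockdict), censor_type='<1') ...",
        "script": {
            # stamps the censor level and the blocking channel onto every matched document
            "source": "ctx._source.censor_type=1; ctx._source.censoring_channel_hash=params[ctx._source.claim_id]",
            "lang": "painless",
            "params": blockdict,
        },
    }

Channel filters make a second update_by_query pass with channels=True, so claims published inside a filtered channel get stamped as well, not just the channel claim itself.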
@@ -138,13 +147,13 @@ class SearchIndex:
             return cache_item.result
         censor = Censor(Censor.SEARCH)
         if kwargs.get('no_totals'):
-            response, offset, total = await self.search(**kwargs, censor_type=0)
+            response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
         else:
             response, offset, total = await self.search(**kwargs)
         censor.apply(response)
         total_referenced.extend(response)
         if censor.censored:
-            response, _, _ = await self.search(**kwargs, censor_type=0)
+            response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
             total_referenced.extend(response)
         result = Outputs.to_base64(
             response, await self._get_referenced_rows(total_referenced), offset, total, censor
@@ -157,16 +166,8 @@ class SearchIndex:
         censor = Censor(Censor.RESOLVE)
         results = [await self.resolve_url(url) for url in urls]
         # just heat the cache
-        await self.get_many(*filter(lambda x: isinstance(x, str), results))
-        for index in range(len(results)):
-            result = results[index]
-            url = urls[index]
-            if result in self.claim_cache:
-                results[index] = self.claim_cache[result]
-            elif isinstance(result, StreamResolution):
-                results[index] = LookupError(f'Could not find claim at "{url}".')
-            elif isinstance(result, ChannelResolution):
-                results[index] = LookupError(f'Could not find channel in "{url}".')
+        await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
+        results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]

         censored = [
             result if not isinstance(result, dict) or not censor.censor(result)
|
@ -175,15 +176,22 @@ class SearchIndex:
|
|||
]
|
||||
return results, censored, censor
|
||||
|
||||
def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
|
||||
cached = self.claim_cache.get(resolution)
|
||||
return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))
|
||||
|
||||
async def get_many(self, *claim_ids):
|
||||
missing = [claim_id for claim_id in claim_ids if claim_id not in self.claim_cache]
|
||||
await self.populate_claim_cache(*claim_ids)
|
||||
return filter(None, map(self.claim_cache.get, claim_ids))
|
||||
|
||||
async def populate_claim_cache(self, *claim_ids):
|
||||
missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
|
||||
if missing:
|
||||
results = await self.search_client.mget(
|
||||
index=self.index, body={"ids": missing}
|
||||
)
|
||||
for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
|
||||
self.claim_cache.set(result['claim_id'], result)
|
||||
return filter(None, map(self.claim_cache.get, claim_ids))
|
||||
|
||||
async def full_id_from_short_id(self, name, short_id, channel_id=None):
|
||||
key = (channel_id or '') + name + short_id
|
||||
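A self-contained sketch of the new resolve error path; StreamResolution and the cache are re-declared inline only to keep the snippet runnable and mirror the hunk above:

    class StreamResolution(str):
        @classmethod
        def lookup_error(cls, url):
            return LookupError(f'Could not find claim at "{url}".')

    claim_cache = {}                                 # stand-in for the LRUCache

    def get_from_cache_or_error(url, resolution):
        # mirrors SearchIndex._get_from_cache_or_error: return the cached claim,
        # or convert an unresolved marker into the type-specific LookupError
        cached = claim_cache.get(resolution)
        return cached or (resolution if isinstance(resolution, LookupError)
                          else resolution.lookup_error(url))

    err = get_from_cache_or_error('lbry://@chan/stream', StreamResolution('deadbeef'))
    assert isinstance(err, LookupError)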
|
@ -304,23 +312,23 @@ class SearchIndex:
|
|||
|
||||
|
||||
def extract_doc(doc, index):
|
||||
doc['claim_id'] = hexlify(doc.pop('claim_hash')[::-1]).decode()
|
||||
doc['claim_id'] = doc.pop('claim_hash')[::-1].hex()
|
||||
if doc['reposted_claim_hash'] is not None:
|
||||
doc['reposted_claim_id'] = hexlify(doc.pop('reposted_claim_hash')[::-1]).decode()
|
||||
doc['reposted_claim_id'] = doc.pop('reposted_claim_hash')[::-1].hex()
|
||||
else:
|
||||
doc['reposted_claim_id'] = None
|
||||
channel_hash = doc.pop('channel_hash')
|
||||
doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
|
||||
doc['channel_id'] = channel_hash[::-1].hex() if channel_hash else channel_hash
|
||||
channel_hash = doc.pop('censoring_channel_hash')
|
||||
doc['censoring_channel_hash'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
|
||||
doc['censoring_channel_hash'] = channel_hash[::-1].hex() if channel_hash else channel_hash
|
||||
txo_hash = doc.pop('txo_hash')
|
||||
doc['tx_id'] = hexlify(txo_hash[:32][::-1]).decode()
|
||||
doc['tx_id'] = txo_hash[:32][::-1].hex()
|
||||
doc['tx_nout'] = struct.unpack('<I', txo_hash[32:])[0]
|
||||
doc['is_controlling'] = bool(doc['is_controlling'])
|
||||
doc['signature'] = hexlify(doc.pop('signature') or b'').decode() or None
|
||||
doc['signature_digest'] = hexlify(doc.pop('signature_digest') or b'').decode() or None
|
||||
doc['public_key_bytes'] = hexlify(doc.pop('public_key_bytes') or b'').decode() or None
|
||||
doc['public_key_hash'] = hexlify(doc.pop('public_key_hash') or b'').decode() or None
|
||||
doc['signature'] = (doc.pop('signature') or b'').hex() or None
|
||||
doc['signature_digest'] = (doc.pop('signature_digest') or b'').hex() or None
|
||||
doc['public_key_bytes'] = (doc.pop('public_key_bytes') or b'').hex() or None
|
||||
doc['public_key_hash'] = (doc.pop('public_key_hash') or b'').hex() or None
|
||||
doc['signature_valid'] = bool(doc['signature_valid'])
|
||||
doc['claim_type'] = doc.get('claim_type', 0) or 0
|
||||
doc['stream_type'] = int(doc.get('stream_type', 0) or 0)
|
||||
|
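The hexlify(...).decode() to bytes.hex() swap above is purely cosmetic; a quick stand-alone check with a made-up hash:

    from binascii import hexlify

    claim_hash = bytes(range(20))                    # made-up 20-byte hash
    assert hexlify(claim_hash[::-1]).decode() == claim_hash[::-1].hex()
    # the "or None" guard behaves identically for empty signature fields:
    assert (hexlify(b'').decode() or None) is None
    assert (b''.hex() or None) is None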
@@ -357,14 +365,14 @@ def expand_query(**kwargs):
             value = [CLAIM_TYPES[claim_type] for claim_type in value]
         if key == '_id':
             if isinstance(value, Iterable):
-                value = [hexlify(item[::-1]).decode() for item in value]
+                value = [item[::-1].hex() for item in value]
             else:
-                value = hexlify(value[::-1]).decode()
+                value = value[::-1].hex()
         if not many and key in ('_id', 'claim_id') and len(value) < 20:
             partial_id = True
         if key == 'public_key_id':
             key = 'public_key_hash'
-            value = hexlify(Base58.decode(value)[1:21]).decode()
+            value = Base58.decode(value)[1:21].hex()
         if key == 'signature_valid':
             continue  # handled later
         if key in TEXT_FIELDS: