forked from LBRYCommunity/lbry-sdk
fixes from review
This commit is contained in:
parent
d47cf40544
commit
7df4cc44c4
3 changed files with 83 additions and 71 deletions
|
@ -8,7 +8,7 @@ services:
|
||||||
wallet_server:
|
wallet_server:
|
||||||
depends_on:
|
depends_on:
|
||||||
- es01
|
- es01
|
||||||
image: lbry/wallet-server:${WALLET_SERVER_TAG:-development}
|
image: lbry/wallet-server:${WALLET_SERVER_TAG:-latest-release}
|
||||||
restart: always
|
restart: always
|
||||||
network_mode: host
|
network_mode: host
|
||||||
ports:
|
ports:
|
||||||
|
|
|
@ -22,6 +22,7 @@ def set_reference(reference, txo_row):
|
||||||
|
|
||||||
class Censor:
|
class Censor:
|
||||||
|
|
||||||
|
NOT_CENSORED = 0
|
||||||
SEARCH = 1
|
SEARCH = 1
|
||||||
RESOLVE = 2
|
RESOLVE = 2
|
||||||
|
|
||||||
|
@ -31,16 +32,19 @@ class Censor:
|
||||||
self.censor_type = censor_type
|
self.censor_type = censor_type
|
||||||
self.censored = {}
|
self.censored = {}
|
||||||
|
|
||||||
|
def is_censored(self, row):
|
||||||
|
return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type
|
||||||
|
|
||||||
def apply(self, rows):
|
def apply(self, rows):
|
||||||
return [row for row in rows if not self.censor(row)]
|
return [row for row in rows if not self.censor(row)]
|
||||||
|
|
||||||
def censor(self, row) -> bool:
|
def censor(self, row) -> bool:
|
||||||
was_censored = (row.get('censor_type') or 0) >= self.censor_type
|
if self.is_censored(row):
|
||||||
if was_censored:
|
|
||||||
censoring_channel_hash = row['censoring_channel_hash']
|
censoring_channel_hash = row['censoring_channel_hash']
|
||||||
self.censored.setdefault(censoring_channel_hash, set())
|
self.censored.setdefault(censoring_channel_hash, set())
|
||||||
self.censored[censoring_channel_hash].add(row['tx_hash'])
|
self.censored[censoring_channel_hash].add(row['tx_hash'])
|
||||||
return was_censored
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def to_message(self, outputs: OutputsMessage, extra_txo_rows: dict):
|
def to_message(self, outputs: OutputsMessage, extra_txo_rows: dict):
|
||||||
for censoring_channel_hash, count in self.censored.items():
|
for censoring_channel_hash, count in self.censored.items():
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import struct
|
import struct
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import unhexlify
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Optional, List, Iterable
|
from typing import Optional, List, Iterable, Union
|
||||||
|
|
||||||
from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
|
from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
|
||||||
from elasticsearch.helpers import async_streaming_bulk
|
from elasticsearch.helpers import async_streaming_bulk
|
||||||
|
@ -21,11 +21,15 @@ from lbry.wallet.server.util import class_logger
|
||||||
|
|
||||||
|
|
||||||
class ChannelResolution(str):
|
class ChannelResolution(str):
|
||||||
pass
|
@classmethod
|
||||||
|
def lookup_error(cls, url):
|
||||||
|
return LookupError(f'Could not find channel in "{url}".')
|
||||||
|
|
||||||
|
|
||||||
class StreamResolution(str):
|
class StreamResolution(str):
|
||||||
pass
|
@classmethod
|
||||||
|
def lookup_error(cls, url):
|
||||||
|
return LookupError(f'Could not find claim at "{url}".')
|
||||||
|
|
||||||
|
|
||||||
class SearchIndex:
|
class SearchIndex:
|
||||||
|
@ -33,7 +37,7 @@ class SearchIndex:
|
||||||
self.search_timeout = search_timeout
|
self.search_timeout = search_timeout
|
||||||
self.sync_timeout = 600 # wont hit that 99% of the time, but can hit on a fresh import
|
self.sync_timeout = 600 # wont hit that 99% of the time, but can hit on a fresh import
|
||||||
self.search_client: Optional[AsyncElasticsearch] = None
|
self.search_client: Optional[AsyncElasticsearch] = None
|
||||||
self.client: Optional[AsyncElasticsearch] = None
|
self.sync_client: Optional[AsyncElasticsearch] = None
|
||||||
self.index = index_prefix + 'claims'
|
self.index = index_prefix + 'claims'
|
||||||
self.logger = class_logger(__name__, self.__class__.__name__)
|
self.logger = class_logger(__name__, self.__class__.__name__)
|
||||||
self.claim_cache = LRUCache(2 ** 15)
|
self.claim_cache = LRUCache(2 ** 15)
|
||||||
|
@ -42,27 +46,27 @@ class SearchIndex:
|
||||||
self.resolution_cache = LRUCache(2 ** 17)
|
self.resolution_cache = LRUCache(2 ** 17)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
if self.client:
|
if self.sync_client:
|
||||||
return
|
return
|
||||||
self.client = AsyncElasticsearch(timeout=self.sync_timeout)
|
self.sync_client = AsyncElasticsearch(timeout=self.sync_timeout)
|
||||||
self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
|
self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await self.client.cluster.health(wait_for_status='yellow')
|
await self.sync_client.cluster.health(wait_for_status='yellow')
|
||||||
break
|
break
|
||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
|
self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
res = await self.client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
|
res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
|
||||||
return res.get('acknowledged', False)
|
return res.get('acknowledged', False)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
clients = [self.client, self.search_client]
|
clients = [self.sync_client, self.search_client]
|
||||||
self.client, self.search_client = None, None
|
self.sync_client, self.search_client = None, None
|
||||||
return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))
|
return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))
|
||||||
|
|
||||||
def delete_index(self):
|
def delete_index(self):
|
||||||
return self.client.indices.delete(self.index, ignore_unavailable=True)
|
return self.sync_client.indices.delete(self.index, ignore_unavailable=True)
|
||||||
|
|
||||||
async def _consume_claim_producer(self, claim_producer):
|
async def _consume_claim_producer(self, claim_producer):
|
||||||
count = 0
|
count = 0
|
||||||
|
@ -77,22 +81,19 @@ class SearchIndex:
|
||||||
self.logger.info("Indexing done for %d claims.", count)
|
self.logger.info("Indexing done for %d claims.", count)
|
||||||
|
|
||||||
async def claim_consumer(self, claim_producer):
|
async def claim_consumer(self, claim_producer):
|
||||||
await self.client.indices.refresh(self.index)
|
|
||||||
touched = set()
|
touched = set()
|
||||||
async for ok, item in async_streaming_bulk(self.client, self._consume_claim_producer(claim_producer),
|
async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
|
||||||
raise_on_error=False):
|
raise_on_error=False):
|
||||||
if not ok:
|
if not ok:
|
||||||
self.logger.warning("indexing failed for an item: %s", item)
|
self.logger.warning("indexing failed for an item: %s", item)
|
||||||
else:
|
else:
|
||||||
item = item.popitem()[1]
|
item = item.popitem()[1]
|
||||||
touched.add(item['_id'])
|
touched.add(item['_id'])
|
||||||
await self.client.indices.refresh(self.index)
|
await self.sync_client.indices.refresh(self.index)
|
||||||
self.logger.info("Indexing done.")
|
self.logger.info("Indexing done.")
|
||||||
|
|
||||||
async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
|
def update_filter_query(self, censor_type, blockdict, channels=False):
|
||||||
def make_query(censor_type, blockdict, channels=False):
|
blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
|
||||||
blockdict = dict(
|
|
||||||
(hexlify(key[::-1]).decode(), hexlify(value[::-1]).decode()) for key, value in blockdict.items())
|
|
||||||
if channels:
|
if channels:
|
||||||
update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
|
update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
|
||||||
else:
|
else:
|
||||||
|
@ -104,22 +105,30 @@ class SearchIndex:
|
||||||
"params": blockdict
|
"params": blockdict
|
||||||
}
|
}
|
||||||
return update
|
return update
|
||||||
|
|
||||||
|
async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
|
||||||
if filtered_streams:
|
if filtered_streams:
|
||||||
await self.client.update_by_query(self.index, body=make_query(1, filtered_streams), slices=4)
|
await self.sync_client.update_by_query(
|
||||||
await self.client.indices.refresh(self.index)
|
self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
|
||||||
|
await self.sync_client.indices.refresh(self.index)
|
||||||
if filtered_channels:
|
if filtered_channels:
|
||||||
await self.client.update_by_query(self.index, body=make_query(1, filtered_channels), slices=4)
|
await self.sync_client.update_by_query(
|
||||||
await self.client.indices.refresh(self.index)
|
self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
|
||||||
await self.client.update_by_query(self.index, body=make_query(1, filtered_channels, True), slices=4)
|
await self.sync_client.indices.refresh(self.index)
|
||||||
await self.client.indices.refresh(self.index)
|
await self.sync_client.update_by_query(
|
||||||
|
self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
|
||||||
|
await self.sync_client.indices.refresh(self.index)
|
||||||
if blocked_streams:
|
if blocked_streams:
|
||||||
await self.client.update_by_query(self.index, body=make_query(2, blocked_streams), slices=4)
|
await self.sync_client.update_by_query(
|
||||||
await self.client.indices.refresh(self.index)
|
self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
|
||||||
|
await self.sync_client.indices.refresh(self.index)
|
||||||
if blocked_channels:
|
if blocked_channels:
|
||||||
await self.client.update_by_query(self.index, body=make_query(2, blocked_channels), slices=4)
|
await self.sync_client.update_by_query(
|
||||||
await self.client.indices.refresh(self.index)
|
self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
|
||||||
await self.client.update_by_query(self.index, body=make_query(2, blocked_channels, True), slices=4)
|
await self.sync_client.indices.refresh(self.index)
|
||||||
await self.client.indices.refresh(self.index)
|
await self.sync_client.update_by_query(
|
||||||
|
self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
|
||||||
|
await self.sync_client.indices.refresh(self.index)
|
||||||
self.search_cache.clear()
|
self.search_cache.clear()
|
||||||
self.claim_cache.clear()
|
self.claim_cache.clear()
|
||||||
self.resolution_cache.clear()
|
self.resolution_cache.clear()
|
||||||
|
@ -138,13 +147,13 @@ class SearchIndex:
|
||||||
return cache_item.result
|
return cache_item.result
|
||||||
censor = Censor(Censor.SEARCH)
|
censor = Censor(Censor.SEARCH)
|
||||||
if kwargs.get('no_totals'):
|
if kwargs.get('no_totals'):
|
||||||
response, offset, total = await self.search(**kwargs, censor_type=0)
|
response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
|
||||||
else:
|
else:
|
||||||
response, offset, total = await self.search(**kwargs)
|
response, offset, total = await self.search(**kwargs)
|
||||||
censor.apply(response)
|
censor.apply(response)
|
||||||
total_referenced.extend(response)
|
total_referenced.extend(response)
|
||||||
if censor.censored:
|
if censor.censored:
|
||||||
response, _, _ = await self.search(**kwargs, censor_type=0)
|
response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
|
||||||
total_referenced.extend(response)
|
total_referenced.extend(response)
|
||||||
result = Outputs.to_base64(
|
result = Outputs.to_base64(
|
||||||
response, await self._get_referenced_rows(total_referenced), offset, total, censor
|
response, await self._get_referenced_rows(total_referenced), offset, total, censor
|
||||||
|
@ -157,16 +166,8 @@ class SearchIndex:
|
||||||
censor = Censor(Censor.RESOLVE)
|
censor = Censor(Censor.RESOLVE)
|
||||||
results = [await self.resolve_url(url) for url in urls]
|
results = [await self.resolve_url(url) for url in urls]
|
||||||
# just heat the cache
|
# just heat the cache
|
||||||
await self.get_many(*filter(lambda x: isinstance(x, str), results))
|
await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
|
||||||
for index in range(len(results)):
|
results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]
|
||||||
result = results[index]
|
|
||||||
url = urls[index]
|
|
||||||
if result in self.claim_cache:
|
|
||||||
results[index] = self.claim_cache[result]
|
|
||||||
elif isinstance(result, StreamResolution):
|
|
||||||
results[index] = LookupError(f'Could not find claim at "{url}".')
|
|
||||||
elif isinstance(result, ChannelResolution):
|
|
||||||
results[index] = LookupError(f'Could not find channel in "{url}".')
|
|
||||||
|
|
||||||
censored = [
|
censored = [
|
||||||
result if not isinstance(result, dict) or not censor.censor(result)
|
result if not isinstance(result, dict) or not censor.censor(result)
|
||||||
|
@ -175,15 +176,22 @@ class SearchIndex:
|
||||||
]
|
]
|
||||||
return results, censored, censor
|
return results, censored, censor
|
||||||
|
|
||||||
|
def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
|
||||||
|
cached = self.claim_cache.get(resolution)
|
||||||
|
return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))
|
||||||
|
|
||||||
async def get_many(self, *claim_ids):
|
async def get_many(self, *claim_ids):
|
||||||
missing = [claim_id for claim_id in claim_ids if claim_id not in self.claim_cache]
|
await self.populate_claim_cache(*claim_ids)
|
||||||
|
return filter(None, map(self.claim_cache.get, claim_ids))
|
||||||
|
|
||||||
|
async def populate_claim_cache(self, *claim_ids):
|
||||||
|
missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
|
||||||
if missing:
|
if missing:
|
||||||
results = await self.search_client.mget(
|
results = await self.search_client.mget(
|
||||||
index=self.index, body={"ids": missing}
|
index=self.index, body={"ids": missing}
|
||||||
)
|
)
|
||||||
for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
|
for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
|
||||||
self.claim_cache.set(result['claim_id'], result)
|
self.claim_cache.set(result['claim_id'], result)
|
||||||
return filter(None, map(self.claim_cache.get, claim_ids))
|
|
||||||
|
|
||||||
async def full_id_from_short_id(self, name, short_id, channel_id=None):
|
async def full_id_from_short_id(self, name, short_id, channel_id=None):
|
||||||
key = (channel_id or '') + name + short_id
|
key = (channel_id or '') + name + short_id
|
||||||
|
@ -304,23 +312,23 @@ class SearchIndex:
|
||||||
|
|
||||||
|
|
||||||
def extract_doc(doc, index):
|
def extract_doc(doc, index):
|
||||||
doc['claim_id'] = hexlify(doc.pop('claim_hash')[::-1]).decode()
|
doc['claim_id'] = doc.pop('claim_hash')[::-1].hex()
|
||||||
if doc['reposted_claim_hash'] is not None:
|
if doc['reposted_claim_hash'] is not None:
|
||||||
doc['reposted_claim_id'] = hexlify(doc.pop('reposted_claim_hash')[::-1]).decode()
|
doc['reposted_claim_id'] = doc.pop('reposted_claim_hash')[::-1].hex()
|
||||||
else:
|
else:
|
||||||
doc['reposted_claim_id'] = None
|
doc['reposted_claim_id'] = None
|
||||||
channel_hash = doc.pop('channel_hash')
|
channel_hash = doc.pop('channel_hash')
|
||||||
doc['channel_id'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
|
doc['channel_id'] = channel_hash[::-1].hex() if channel_hash else channel_hash
|
||||||
channel_hash = doc.pop('censoring_channel_hash')
|
channel_hash = doc.pop('censoring_channel_hash')
|
||||||
doc['censoring_channel_hash'] = hexlify(channel_hash[::-1]).decode() if channel_hash else channel_hash
|
doc['censoring_channel_hash'] = channel_hash[::-1].hex() if channel_hash else channel_hash
|
||||||
txo_hash = doc.pop('txo_hash')
|
txo_hash = doc.pop('txo_hash')
|
||||||
doc['tx_id'] = hexlify(txo_hash[:32][::-1]).decode()
|
doc['tx_id'] = txo_hash[:32][::-1].hex()
|
||||||
doc['tx_nout'] = struct.unpack('<I', txo_hash[32:])[0]
|
doc['tx_nout'] = struct.unpack('<I', txo_hash[32:])[0]
|
||||||
doc['is_controlling'] = bool(doc['is_controlling'])
|
doc['is_controlling'] = bool(doc['is_controlling'])
|
||||||
doc['signature'] = hexlify(doc.pop('signature') or b'').decode() or None
|
doc['signature'] = (doc.pop('signature') or b'').hex() or None
|
||||||
doc['signature_digest'] = hexlify(doc.pop('signature_digest') or b'').decode() or None
|
doc['signature_digest'] = (doc.pop('signature_digest') or b'').hex() or None
|
||||||
doc['public_key_bytes'] = hexlify(doc.pop('public_key_bytes') or b'').decode() or None
|
doc['public_key_bytes'] = (doc.pop('public_key_bytes') or b'').hex() or None
|
||||||
doc['public_key_hash'] = hexlify(doc.pop('public_key_hash') or b'').decode() or None
|
doc['public_key_hash'] = (doc.pop('public_key_hash') or b'').hex() or None
|
||||||
doc['signature_valid'] = bool(doc['signature_valid'])
|
doc['signature_valid'] = bool(doc['signature_valid'])
|
||||||
doc['claim_type'] = doc.get('claim_type', 0) or 0
|
doc['claim_type'] = doc.get('claim_type', 0) or 0
|
||||||
doc['stream_type'] = int(doc.get('stream_type', 0) or 0)
|
doc['stream_type'] = int(doc.get('stream_type', 0) or 0)
|
||||||
|
@ -357,14 +365,14 @@ def expand_query(**kwargs):
|
||||||
value = [CLAIM_TYPES[claim_type] for claim_type in value]
|
value = [CLAIM_TYPES[claim_type] for claim_type in value]
|
||||||
if key == '_id':
|
if key == '_id':
|
||||||
if isinstance(value, Iterable):
|
if isinstance(value, Iterable):
|
||||||
value = [hexlify(item[::-1]).decode() for item in value]
|
value = [item[::-1].hex() for item in value]
|
||||||
else:
|
else:
|
||||||
value = hexlify(value[::-1]).decode()
|
value = value[::-1].hex()
|
||||||
if not many and key in ('_id', 'claim_id') and len(value) < 20:
|
if not many and key in ('_id', 'claim_id') and len(value) < 20:
|
||||||
partial_id = True
|
partial_id = True
|
||||||
if key == 'public_key_id':
|
if key == 'public_key_id':
|
||||||
key = 'public_key_hash'
|
key = 'public_key_hash'
|
||||||
value = hexlify(Base58.decode(value)[1:21]).decode()
|
value = Base58.decode(value)[1:21].hex()
|
||||||
if key == 'signature_valid':
|
if key == 'signature_valid':
|
||||||
continue # handled later
|
continue # handled later
|
||||||
if key in TEXT_FIELDS:
|
if key in TEXT_FIELDS:
|
||||||
|
|
Loading…
Reference in a new issue