Merge pull request #3153 from lbryio/elasticsearch

hub: use Elasticsearch for `claim_search` and `resolve` calls
Jack Robison 2021-03-24 16:44:14 -04:00 committed by GitHub
commit 2cc7e5dfdc
30 changed files with 1006 additions and 590 deletions


@ -37,6 +37,17 @@ jobs:
- blockchain
- other
steps:
- name: Configure sysctl limits
run: |
sudo swapoff -a
sudo sysctl -w vm.swappiness=1
sudo sysctl -w fs.file-max=262144
sudo sysctl -w vm.max_map_count=262144
- name: Runs Elasticsearch
uses: elastic/elastic-github-actions/elasticsearch@master
with:
stack-version: 7.6.0
- uses: actions/checkout@v2
- uses: actions/setup-python@v1
with:

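Not part of the diff: Elasticsearch refuses to start in production mode when vm.max_map_count is too low, which is why the workflow raises the kernel limits before launching the 7.6.0 container. A minimal readiness check a CI step could run before the test suite, as a stdlib-only sketch (assumes Elasticsearch listens on localhost:9200):

import json
import time
import urllib.request

def wait_for_elasticsearch(url="http://localhost:9200/_cluster/health", timeout=60):
    # Poll the cluster health endpoint until Elasticsearch answers with at least yellow status.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(url) as response:
                health = json.loads(response.read())
            if health.get("status") in ("yellow", "green"):
                return health
        except OSError:
            pass  # not listening yet
        time.sleep(1)
    raise RuntimeError("Elasticsearch did not become reachable in time")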

@ -1,4 +1,4 @@
FROM ubuntu:20.04
FROM debian:10-slim
ARG user=lbry
ARG db_dir=/database
@ -13,7 +13,9 @@ RUN apt-get update && \
wget \
tar unzip \
build-essential \
python3 \
pkg-config \
libleveldb-dev \
python3.7 \
python3-dev \
python3-pip \
python3-wheel \


@ -1,36 +1,40 @@
version: "3"
volumes:
lbrycrd:
wallet_server:
es01:
services:
lbrycrd:
image: lbry/lbrycrd:${LBRYCRD_TAG:-latest-release}
restart: always
ports: # accessible from host
- "9246:9246" # rpc port
expose: # internal to docker network. also this doesn't do anything. it's for documentation only.
- "9245" # node-to-node comms port
volumes:
- "lbrycrd:/data/.lbrycrd"
environment:
- RUN_MODE=default
# Currently no snapshot is provided
#- SNAPSHOT_URL=${LBRYCRD_SNAPSHOT_URL-https://lbry.com/snapshot/blockchain}
- RPC_ALLOW_IP=0.0.0.0/0
wallet_server:
depends_on:
- es01
image: lbry/wallet-server:${WALLET_SERVER_TAG:-latest-release}
depends_on:
- lbrycrd
restart: always
network_mode: host
ports:
- "50001:50001" # rpc port
- "50005:50005" # websocket port
#- "2112:2112" # uncomment to enable prometheus
- "2112:2112" # uncomment to enable prometheus
volumes:
- "wallet_server:/database"
env_file: [/home/lbry/wallet-server-env]
environment:
# Currently no snapshot is provided
# - SNAPSHOT_URL=${WALLET_SERVER_SNAPSHOT_URL-https://lbry.com/snapshot/wallet}
- DAEMON_URL=http://lbry:lbry@lbrycrd:9245
- DAEMON_URL=http://lbry:lbry@127.0.0.1:9245
- TCP_PORT=50001
- PROMETHEUS_PORT=2112
es01:
image: docker.elastic.co/elasticsearch/elasticsearch:7.11.0
container_name: es01
environment:
- node.name=es01
- discovery.type=single-node
- indices.query.bool.max_clause_count=4096
- bootstrap.memory_lock=true
- "ES_JAVA_OPTS=-Xms8g -Xmx8g" # no more than 32, remember to disable swap
ulimits:
memlock:
soft: -1
hard: -1
volumes:
- es01:/usr/share/elasticsearch/data
ports:
- 127.0.0.1:9200:9200
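Not part of the compose file: with this topology Elasticsearch is only published on the loopback interface, and the wallet server (running with network_mode: host) reaches it at 127.0.0.1:9200, which is also the default endpoint of the elasticsearch Python client used by the hub. A quick health probe from the host, as a sketch:

import asyncio
from elasticsearch import AsyncElasticsearch

async def check_es01():
    # Defaults to localhost:9200, matching the "127.0.0.1:9200:9200" mapping above.
    es = AsyncElasticsearch()
    try:
        health = await es.cluster.health(wait_for_status='yellow')
        print(health['status'], health['number_of_nodes'])
    finally:
        await es.close()

asyncio.run(check_es01())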


@ -20,4 +20,6 @@ if [[ -n "$SNAPSHOT_URL" ]] && [[ ! -f /database/claims.db ]]; then
rm "$filename"
fi
/home/lbry/.local/bin/torba-server "$@"
/home/lbry/.local/bin/lbry-hub-elastic-sync /database/claims.db
echo 'starting server'
/home/lbry/.local/bin/lbry-hub "$@"


@ -13,57 +13,45 @@ NOT_FOUND = ErrorMessage.Code.Name(ErrorMessage.NOT_FOUND)
BLOCKED = ErrorMessage.Code.Name(ErrorMessage.BLOCKED)
def set_reference(reference, claim_hash, rows):
if claim_hash:
for txo in rows:
if claim_hash == txo['claim_hash']:
reference.tx_hash = txo['txo_hash'][:32]
reference.nout = struct.unpack('<I', txo['txo_hash'][32:])[0]
reference.height = txo['height']
return
def set_reference(reference, txo_row):
if txo_row:
reference.tx_hash = txo_row['txo_hash'][:32]
reference.nout = struct.unpack('<I', txo_row['txo_hash'][32:])[0]
reference.height = txo_row['height']
class Censor:
__slots__ = 'streams', 'channels', 'limit_claims_per_channel', 'censored', 'claims_in_channel', 'total'
NOT_CENSORED = 0
SEARCH = 1
RESOLVE = 2
def __init__(self, streams: dict = None, channels: dict = None, limit_claims_per_channel: int = None):
self.streams = streams or {}
self.channels = channels or {}
self.limit_claims_per_channel = limit_claims_per_channel # doesn't count as censored
__slots__ = 'censor_type', 'censored'
def __init__(self, censor_type):
self.censor_type = censor_type
self.censored = {}
self.claims_in_channel = {}
self.total = 0
def is_censored(self, row):
return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type
def apply(self, rows):
return [row for row in rows if not self.censor(row)]
def censor(self, row) -> bool:
was_censored = False
for claim_hash, lookup in (
(row['claim_hash'], self.streams),
(row['claim_hash'], self.channels),
(row['channel_hash'], self.channels),
(row['reposted_claim_hash'], self.streams),
(row['reposted_claim_hash'], self.channels)):
censoring_channel_hash = lookup.get(claim_hash)
if censoring_channel_hash:
was_censored = True
self.censored.setdefault(censoring_channel_hash, 0)
self.censored[censoring_channel_hash] += 1
break
if was_censored:
self.total += 1
if not was_censored and self.limit_claims_per_channel is not None and row['channel_hash']:
self.claims_in_channel.setdefault(row['channel_hash'], 0)
self.claims_in_channel[row['channel_hash']] += 1
if self.claims_in_channel[row['channel_hash']] > self.limit_claims_per_channel:
return True
return was_censored
if self.is_censored(row):
censoring_channel_hash = row['censoring_channel_hash']
self.censored.setdefault(censoring_channel_hash, set())
self.censored[censoring_channel_hash].add(row['tx_hash'])
return True
return False
def to_message(self, outputs: OutputsMessage, extra_txo_rows):
outputs.blocked_total = self.total
def to_message(self, outputs: OutputsMessage, extra_txo_rows: dict):
for censoring_channel_hash, count in self.censored.items():
blocked = outputs.blocked.add()
blocked.count = count
set_reference(blocked.channel, censoring_channel_hash, extra_txo_rows)
blocked.count = len(count)
set_reference(blocked.channel, extra_txo_rows.get(censoring_channel_hash))
outputs.blocked_total += len(count)
class Outputs:
@ -168,6 +156,7 @@ class Outputs:
@classmethod
def to_bytes(cls, txo_rows, extra_txo_rows, offset=0, total=None, blocked: Censor = None) -> bytes:
extra_txo_rows = {row['claim_hash']: row for row in extra_txo_rows}
page = OutputsMessage()
page.offset = offset
if total is not None:
@ -176,12 +165,12 @@ class Outputs:
blocked.to_message(page, extra_txo_rows)
for row in txo_rows:
cls.row_to_message(row, page.txos.add(), extra_txo_rows)
for row in extra_txo_rows:
for row in extra_txo_rows.values():
cls.row_to_message(row, page.extra_txos.add(), extra_txo_rows)
return page.SerializeToString()
@classmethod
def row_to_message(cls, txo, txo_message, extra_txo_rows):
def row_to_message(cls, txo, txo_message, extra_row_dict: dict):
if isinstance(txo, Exception):
txo_message.error.text = txo.args[0]
if isinstance(txo, ValueError):
@ -190,7 +179,7 @@ class Outputs:
txo_message.error.code = ErrorMessage.NOT_FOUND
elif isinstance(txo, ResolveCensoredError):
txo_message.error.code = ErrorMessage.BLOCKED
set_reference(txo_message.error.blocked.channel, txo.censor_hash, extra_txo_rows)
set_reference(txo_message.error.blocked.channel, extra_row_dict.get(txo.censor_hash))
return
txo_message.tx_hash = txo['txo_hash'][:32]
txo_message.nout, = struct.unpack('<I', txo['txo_hash'][32:])
@ -213,5 +202,5 @@ class Outputs:
txo_message.claim.trending_mixed = txo['trending_mixed']
txo_message.claim.trending_local = txo['trending_local']
txo_message.claim.trending_global = txo['trending_global']
set_reference(txo_message.claim.channel, txo['channel_hash'], extra_txo_rows)
set_reference(txo_message.claim.repost, txo['reposted_claim_hash'], extra_txo_rows)
set_reference(txo_message.claim.channel, extra_row_dict.get(txo['channel_hash']))
set_reference(txo_message.claim.repost, extra_row_dict.get(txo['reposted_claim_hash']))
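The reworked Censor no longer needs the stream/channel block dictionaries at response time: rows coming back from Elasticsearch already carry censor_type and censoring_channel_hash (stamped by apply_filters and enqueue_changes elsewhere in this PR), so censoring becomes a simple level comparison. An illustrative sketch of the new behaviour, with invented rows:

# Hypothetical rows shaped like search index results (values invented).
rows = [
    {'censor_type': 0, 'censoring_channel_hash': None, 'tx_hash': b'\x01' * 32},
    {'censor_type': 2, 'censoring_channel_hash': b'\xab' * 20, 'tx_hash': b'\x02' * 32},
]

search_censor = Censor(Censor.SEARCH)      # filters anything with censor_type >= 1
assert search_censor.apply(rows) == rows[:1]

resolve_censor = Censor(Censor.RESOLVE)    # blocks only censor_type >= 2
assert resolve_censor.censor(rows[1]) is True
assert resolve_censor.censored == {b'\xab' * 20: {b'\x02' * 32}}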


@ -55,6 +55,14 @@ class PathSegment(NamedTuple):
def normalized(self):
return normalize_name(self.name)
@property
def is_shortid(self):
return self.claim_id is not None and len(self.claim_id) < 40
@property
def is_fullid(self):
return self.claim_id is not None and len(self.claim_id) == 40
def to_dict(self):
q = {'name': self.name}
if self.claim_id is not None:

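The new is_shortid / is_fullid properties are what the Elasticsearch resolver uses to decide between a prefix search on claim_id and an exact lookup. For illustration (assuming lbry.schema.url.URL, as used elsewhere in this PR):

from lbry.schema.url import URL

url = URL.parse('lbry://@somechannel#ab/some-stream')
channel = url.channel
print(channel.claim_id)     # 'ab'
print(channel.is_shortid)   # True: a partial id, fewer than 40 hex characters
print(channel.is_fullid)    # False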

@ -417,9 +417,6 @@ class Network:
def get_server_features(self):
return self.rpc('server.features', (), restricted=True)
def get_claims_by_ids(self, claim_ids):
return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
def resolve(self, urls, session_override=None):
return self.rpc('blockchain.claimtrie.resolve', urls, False, session_override)


@ -11,6 +11,7 @@ import importlib
from binascii import hexlify
from typing import Type, Optional
import urllib.request
from uuid import uuid4
import lbry
from lbry.wallet.server.server import Server
@ -187,7 +188,9 @@ class SPVNode:
'SESSION_TIMEOUT': str(self.session_timeout),
'MAX_QUERY_WORKERS': '0',
'INDIVIDUAL_TAG_INDEXES': '',
'RPC_PORT': self.rpc_port
'RPC_PORT': self.rpc_port,
'ES_INDEX_PREFIX': uuid4().hex,
'ES_MODE': 'writer',
}
if extraconf:
conf.update(extraconf)
@ -199,6 +202,8 @@ class SPVNode:
async def stop(self, cleanup=True):
try:
await self.server.db.search_index.delete_index()
await self.server.db.search_index.stop()
await self.server.stop()
finally:
cleanup and self.cleanup()
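Giving every SPVNode a random ES_INDEX_PREFIX means each test run gets its own '<prefix>claims' index on the shared Elasticsearch node, and stop() can delete it without touching anyone else's data. A sketch of the same isolation pattern used directly:

from uuid import uuid4
from lbry.wallet.server.db.elasticsearch import SearchIndex

async def isolated_index():
    index = SearchIndex(index_prefix=uuid4().hex)   # index name becomes '<hex>claims'
    await index.start()
    try:
        ...  # run assertions against the fresh, empty index
    finally:
        await index.delete_index()   # same cleanup SPVNode.stop() performs
        await index.stop()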


@ -32,10 +32,13 @@ import inspect
# other_params: None means cannot be called with keyword arguments only
# any means any name is good
from functools import lru_cache
SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
'required_names other_names')
@lru_cache(256)
def signature_info(func):
params = inspect.signature(func).parameters
min_args = max_args = 0


@ -5,7 +5,6 @@ from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional
from prometheus_client import Gauge, Histogram
import lbry
from lbry.schema.claim import Claim
from lbry.wallet.server.db.writer import SQLDB
from lbry.wallet.server.daemon import DaemonError
from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN
@ -215,6 +214,8 @@ class BlockProcessor:
if hprevs == chain:
start = time.perf_counter()
await self.run_in_thread_with_lock(self.advance_blocks, blocks)
if self.sql:
await self.db.search_index.claim_consumer(self.sql.claim_producer())
for cache in self.search_cache.values():
cache.clear()
self.history_cache.clear()
@ -228,6 +229,9 @@ class BlockProcessor:
s = '' if len(blocks) == 1 else 's'
self.logger.info('processed {:,d} block{} in {:.1f}s'.format(len(blocks), s, processed_time))
if self._caught_up_event.is_set():
if self.sql:
await self.db.search_index.apply_filters(self.sql.blocked_streams, self.sql.blocked_channels,
self.sql.filtered_streams, self.sql.filtered_channels)
await self.notifications.on_block(self.touched, self.height)
self.touched = set()
elif hprevs[0] != chain[0]:
@ -282,7 +286,6 @@ class BlockProcessor:
await self.run_in_thread_with_lock(flush_backup)
last -= len(raw_blocks)
await self.run_in_thread_with_lock(self.db.sql.delete_claims_above_height, self.height)
await self.prefetcher.reset_height(self.height)
self.reorg_count_metric.inc()
except:
@ -651,7 +654,11 @@ class BlockProcessor:
self.reorg_count = 0
else:
blocks = self.prefetcher.get_prefetched_blocks()
await self.check_and_advance_blocks(blocks)
try:
await self.check_and_advance_blocks(blocks)
except Exception:
self.logger.exception("error while processing txs")
raise
async def _first_caught_up(self):
self.logger.info(f'caught up to height {self.height}')
@ -782,15 +789,17 @@ class LBRYBlockProcessor(BlockProcessor):
self.timer = Timer('BlockProcessor')
def advance_blocks(self, blocks):
self.sql.begin()
if self.sql:
self.sql.begin()
try:
self.timer.run(super().advance_blocks, blocks)
except:
self.logger.exception(f'Error while advancing transaction in new block.')
raise
finally:
self.sql.commit()
if self.db.first_sync and self.height == self.daemon.cached_height():
if self.sql:
self.sql.commit()
if self.sql and self.db.first_sync and self.height == self.daemon.cached_height():
self.timer.run(self.sql.execute, self.sql.SEARCH_INDEXES, timer_name='executing SEARCH_INDEXES')
if self.env.individual_tag_indexes:
self.timer.run(self.sql.execute, self.sql.TAG_INDEXES, timer_name='executing TAG_INDEXES')
@ -799,22 +808,8 @@ class LBRYBlockProcessor(BlockProcessor):
def advance_txs(self, height, txs, header, block_hash):
timer = self.timer.sub_timers['advance_blocks']
undo = timer.run(super().advance_txs, height, txs, header, block_hash, timer_name='super().advance_txs')
timer.run(self.sql.advance_txs, height, txs, header, self.daemon.cached_height(), forward_timer=True)
if self.sql:
timer.run(self.sql.advance_txs, height, txs, header, self.daemon.cached_height(), forward_timer=True)
if (height % 10000 == 0 or not self.db.first_sync) and self.logger.isEnabledFor(10):
self.timer.show(height=height)
return undo
def _checksig(self, value, address):
try:
claim_dict = Claim.from_bytes(value)
cert_id = claim_dict.signing_channel_hash
if not self.should_validate_signatures:
return cert_id
if cert_id:
cert_claim = self.db.get_claim_info(cert_id)
if cert_claim:
certificate = Claim.from_bytes(cert_claim.value)
claim_dict.validate_signature(address, certificate)
return cert_id
except Exception:
pass

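The net effect of these block processor hunks: a writer-mode hub streams every advanced batch into Elasticsearch right after the SQLite transaction, while a reader-mode hub (self.sql is None) skips both the SQL write and the ES sync. A schematic of the producer/consumer contract assumed above (the yielded shapes follow claim_producer and _consume_claim_producer in this PR):

async def index_block_changes(db, sql):
    # Reader-mode hubs have no SQLite writer and therefore nothing to push to ES.
    if sql is None:
        return
    # claim_producer() yields ('delete', <hex claim id>) for spent claims and
    # ('update', <claim row dict>) for rows captured by the changelog triggers;
    # claim_consumer() turns those into Elasticsearch bulk actions.
    await db.search_index.claim_consumer(sql.claim_producer())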

@ -8,7 +8,7 @@ from lbry.wallet.server.server import Server
def get_argument_parser():
parser = argparse.ArgumentParser(
prog="torba-server"
prog="lbry-hub"
)
parser.add_argument("spvserver", type=str, help="Python class path to SPV server implementation.",
nargs="?", default="lbry.wallet.server.coin.LBC")


@ -0,0 +1 @@
from .search import SearchIndex


@ -0,0 +1,61 @@
INDEX_DEFAULT_SETTINGS = {
"settings":
{"analysis":
{"analyzer": {
"default": {"tokenizer": "whitespace", "filter": ["lowercase", "porter_stem"]}}},
"index":
{"refresh_interval": -1,
"number_of_shards": 1,
"number_of_replicas": 0,
"sort": {
"field": ["trending_mixed", "release_time"],
"order": ["desc", "desc"]
}}
},
"mappings": {
"properties": {
"claim_id": {
"fields": {
"keyword": {
"ignore_above": 256,
"type": "keyword"
}
},
"type": "text",
"index_prefixes": {
"min_chars": 1,
"max_chars": 10
}
},
"height": {"type": "integer"},
"claim_type": {"type": "byte"},
"censor_type": {"type": "byte"},
"trending_mixed": {"type": "float"},
"release_time": {"type": "long"},
}
}
}
FIELDS = {'is_controlling', 'last_take_over_height', 'claim_id', 'claim_name', 'normalized', 'tx_position', 'amount',
'timestamp', 'creation_timestamp', 'height', 'creation_height', 'activation_height', 'expiration_height',
'release_time', 'short_url', 'canonical_url', 'title', 'author', 'description', 'claim_type', 'reposted',
'stream_type', 'media_type', 'fee_amount', 'fee_currency', 'duration', 'reposted_claim_hash', 'censor_type',
'claims_in_channel', 'channel_join', 'signature_valid', 'effective_amount', 'support_amount',
'trending_group', 'trending_mixed', 'trending_local', 'trending_global', 'channel_id', 'tx_id', 'tx_nout',
'signature', 'signature_digest', 'public_key_bytes', 'public_key_hash', 'public_key_id', '_id', 'tags',
'reposted_claim_id', 'has_source'}
TEXT_FIELDS = {'author', 'canonical_url', 'channel_id', 'claim_name', 'description', 'claim_id',
'media_type', 'normalized', 'public_key_bytes', 'public_key_hash', 'short_url', 'signature',
'signature_digest', 'stream_type', 'title', 'tx_id', 'fee_currency', 'reposted_claim_id', 'tags'}
RANGE_FIELDS = {
'height', 'creation_height', 'activation_height', 'expiration_height',
'timestamp', 'creation_timestamp', 'duration', 'release_time', 'fee_amount',
'tx_position', 'channel_join', 'reposted', 'limit_claims_per_channel',
'amount', 'effective_amount', 'support_amount',
'trending_group', 'trending_mixed', 'censor_type',
'trending_local', 'trending_global',
}
REPLACEMENTS = {
'name': 'normalized',
'txid': 'tx_id',
'claim_hash': '_id'
}
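These settings give the claims index a whitespace + porter_stem analyzer, a single shard, and an index-time sort on (trending_mixed, release_time), while FIELDS, TEXT_FIELDS and RANGE_FIELDS drive how expand_query classifies each claim_search argument. Creating the index by hand with the same settings, as a sketch (SearchIndex.start() below does the equivalent):

import asyncio
from elasticsearch import AsyncElasticsearch

from lbry.wallet.server.db.elasticsearch.constants import INDEX_DEFAULT_SETTINGS

async def create_claims_index(name='claims'):
    es = AsyncElasticsearch()
    try:
        # ignore=400 turns "index already exists" into a no-op instead of an error.
        return await es.indices.create(name, INDEX_DEFAULT_SETTINGS, ignore=400)
    finally:
        await es.close()

asyncio.run(create_claims_index())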


@ -0,0 +1,520 @@
import asyncio
import struct
from binascii import unhexlify
from decimal import Decimal
from operator import itemgetter
from typing import Optional, List, Iterable, Union
from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError
from elasticsearch.helpers import async_streaming_bulk
from lbry.crypto.base58 import Base58
from lbry.error import ResolveCensoredError, claim_id as parse_claim_id
from lbry.schema.result import Outputs, Censor
from lbry.schema.tags import clean_tags
from lbry.schema.url import URL, normalize_name
from lbry.utils import LRUCache
from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES
from lbry.wallet.server.db.elasticsearch.constants import INDEX_DEFAULT_SETTINGS, REPLACEMENTS, FIELDS, TEXT_FIELDS, \
RANGE_FIELDS
from lbry.wallet.server.util import class_logger
class ChannelResolution(str):
@classmethod
def lookup_error(cls, url):
return LookupError(f'Could not find channel in "{url}".')
class StreamResolution(str):
@classmethod
def lookup_error(cls, url):
return LookupError(f'Could not find claim at "{url}".')
class SearchIndex:
def __init__(self, index_prefix: str, search_timeout=3.0):
self.search_timeout = search_timeout
self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
self.search_client: Optional[AsyncElasticsearch] = None
self.sync_client: Optional[AsyncElasticsearch] = None
self.index = index_prefix + 'claims'
self.logger = class_logger(__name__, self.__class__.__name__)
self.claim_cache = LRUCache(2 ** 15)
self.short_id_cache = LRUCache(2 ** 17) # never invalidated, since short ids are forever
self.search_cache = LRUCache(2 ** 17)
self.resolution_cache = LRUCache(2 ** 17)
async def start(self):
if self.sync_client:
return
self.sync_client = AsyncElasticsearch(timeout=self.sync_timeout)
self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
while True:
try:
await self.sync_client.cluster.health(wait_for_status='yellow')
break
except ConnectionError:
self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
await asyncio.sleep(1)
res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
return res.get('acknowledged', False)
def stop(self):
clients = [self.sync_client, self.search_client]
self.sync_client, self.search_client = None, None
return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))
def delete_index(self):
return self.sync_client.indices.delete(self.index, ignore_unavailable=True)
async def _consume_claim_producer(self, claim_producer):
count = 0
for op, doc in claim_producer:
if op == 'delete':
yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
else:
yield extract_doc(doc, self.index)
count += 1
if count % 100 == 0:
self.logger.info("Indexing in progress, %d claims.", count)
self.logger.info("Indexing done for %d claims.", count)
async def claim_consumer(self, claim_producer):
touched = set()
async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
raise_on_error=False):
if not ok:
self.logger.warning("indexing failed for an item: %s", item)
else:
item = item.popitem()[1]
touched.add(item['_id'])
await self.sync_client.indices.refresh(self.index)
self.logger.info("Indexing done.")
def update_filter_query(self, censor_type, blockdict, channels=False):
blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
if channels:
update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
else:
update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
key = 'channel_id' if channels else 'claim_id'
update['script'] = {
"source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
"lang": "painless",
"params": blockdict
}
return update
async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
if filtered_streams:
await self.sync_client.update_by_query(
self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
await self.sync_client.indices.refresh(self.index)
if filtered_channels:
await self.sync_client.update_by_query(
self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
await self.sync_client.indices.refresh(self.index)
await self.sync_client.update_by_query(
self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
await self.sync_client.indices.refresh(self.index)
if blocked_streams:
await self.sync_client.update_by_query(
self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
await self.sync_client.indices.refresh(self.index)
if blocked_channels:
await self.sync_client.update_by_query(
self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
await self.sync_client.indices.refresh(self.index)
await self.sync_client.update_by_query(
self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
await self.sync_client.indices.refresh(self.index)
self.search_cache.clear()
self.claim_cache.clear()
self.resolution_cache.clear()
async def session_query(self, query_name, kwargs):
offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0
total_referenced = []
if query_name == 'resolve':
total_referenced, response, censor = await self.resolve(*kwargs)
else:
cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
if cache_item.result is not None:
return cache_item.result
async with cache_item.lock:
if cache_item.result:
return cache_item.result
censor = Censor(Censor.SEARCH)
if kwargs.get('no_totals'):
response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
else:
response, offset, total = await self.search(**kwargs)
censor.apply(response)
total_referenced.extend(response)
if censor.censored:
response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
total_referenced.extend(response)
result = Outputs.to_base64(
response, await self._get_referenced_rows(total_referenced), offset, total, censor
)
cache_item.result = result
return result
return Outputs.to_base64(response, await self._get_referenced_rows(total_referenced), offset, total, censor)
async def resolve(self, *urls):
censor = Censor(Censor.RESOLVE)
results = [await self.resolve_url(url) for url in urls]
# just heat the cache
await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]
censored = [
result if not isinstance(result, dict) or not censor.censor(result)
else ResolveCensoredError(url, result['censoring_channel_hash'])
for url, result in zip(urls, results)
]
return results, censored, censor
def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
cached = self.claim_cache.get(resolution)
return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))
async def get_many(self, *claim_ids):
await self.populate_claim_cache(*claim_ids)
return filter(None, map(self.claim_cache.get, claim_ids))
async def populate_claim_cache(self, *claim_ids):
missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
if missing:
results = await self.search_client.mget(
index=self.index, body={"ids": missing}
)
for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
self.claim_cache.set(result['claim_id'], result)
async def full_id_from_short_id(self, name, short_id, channel_id=None):
key = (channel_id or '') + name + short_id
if key not in self.short_id_cache:
query = {'name': name, 'claim_id': short_id}
if channel_id:
query['channel_id'] = channel_id
query['order_by'] = ['^channel_join']
query['signature_valid'] = True
else:
query['order_by'] = '^creation_height'
result, _, _ = await self.search(**query, limit=1)
if len(result) == 1:
result = result[0]['claim_id']
self.short_id_cache[key] = result
return self.short_id_cache.get(key, None)
async def search(self, **kwargs):
if 'channel' in kwargs:
kwargs['channel_id'] = await self.resolve_url(kwargs.pop('channel'))
if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str):
return [], 0, 0
try:
result = (await self.search_client.search(
expand_query(**kwargs), index=self.index, track_total_hits=False if kwargs.get('no_totals') else 10_000
))['hits']
except NotFoundError:
return [], 0, 0
return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)
async def resolve_url(self, raw_url):
if raw_url not in self.resolution_cache:
self.resolution_cache[raw_url] = await self._resolve_url(raw_url)
return self.resolution_cache[raw_url]
async def _resolve_url(self, raw_url):
try:
url = URL.parse(raw_url)
except ValueError as e:
return e
stream = LookupError(f'Could not find claim at "{raw_url}".')
channel_id = await self.resolve_channel_id(url)
if isinstance(channel_id, LookupError):
return channel_id
stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream
if url.has_stream:
return StreamResolution(stream)
else:
return ChannelResolution(channel_id)
async def resolve_channel_id(self, url: URL):
if not url.has_channel:
return
if url.channel.is_fullid:
return url.channel.claim_id
if url.channel.is_shortid:
channel_id = await self.full_id_from_short_id(url.channel.name, url.channel.claim_id)
if not channel_id:
return LookupError(f'Could not find channel in "{url}".')
return channel_id
query = url.channel.to_dict()
if set(query) == {'name'}:
query['is_controlling'] = True
else:
query['order_by'] = ['^creation_height']
matches, _, _ = await self.search(**query, limit=1)
if matches:
channel_id = matches[0]['claim_id']
else:
return LookupError(f'Could not find channel in "{url}".')
return channel_id
async def resolve_stream(self, url: URL, channel_id: str = None):
if not url.has_stream:
return None
if url.has_channel and channel_id is None:
return None
query = url.stream.to_dict()
if url.stream.claim_id is not None:
if url.stream.is_fullid:
claim_id = url.stream.claim_id
else:
claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id)
return claim_id
if channel_id is not None:
if set(query) == {'name'}:
# temporarily emulate is_controlling for claims in channel
query['order_by'] = ['effective_amount', '^height']
else:
query['order_by'] = ['^channel_join']
query['channel_id'] = channel_id
query['signature_valid'] = True
elif set(query) == {'name'}:
query['is_controlling'] = True
matches, _, _ = await self.search(**query, limit=1)
if matches:
return matches[0]['claim_id']
async def _get_referenced_rows(self, txo_rows: List[dict]):
txo_rows = [row for row in txo_rows if isinstance(row, dict)]
referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows)))
referenced_ids |= set(map(parse_claim_id, filter(None, (row['censoring_channel_hash'] for row in txo_rows))))
referenced_txos = []
if referenced_ids:
referenced_txos.extend(await self.get_many(*referenced_ids))
referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos)))
if referenced_ids:
referenced_txos.extend(await self.get_many(*referenced_ids))
return referenced_txos
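Taken together, SearchIndex replaces the old SQLite reader path for both RPCs: resolve() walks channel-then-stream per URL, and session_query() caches whole encoded responses keyed by the request kwargs. A hedged end-to-end sketch of using the index directly (assumes a local Elasticsearch and an index created under the empty prefix):

import asyncio
from lbry.wallet.server.db.elasticsearch import SearchIndex

async def demo():
    index = SearchIndex(index_prefix='', search_timeout=3.0)
    await index.start()     # waits for cluster health and creates '<prefix>claims' if missing
    try:
        rows, offset, total = await index.search(name='one', limit=5)
        for row in rows:
            print(row['claim_id'], row['claim_name'])
    finally:
        await index.stop()

asyncio.run(demo())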
def extract_doc(doc, index):
doc['claim_id'] = doc.pop('claim_hash')[::-1].hex()
if doc['reposted_claim_hash'] is not None:
doc['reposted_claim_id'] = doc.pop('reposted_claim_hash')[::-1].hex()
else:
doc['reposted_claim_id'] = None
channel_hash = doc.pop('channel_hash')
doc['channel_id'] = channel_hash[::-1].hex() if channel_hash else channel_hash
channel_hash = doc.pop('censoring_channel_hash')
doc['censoring_channel_hash'] = channel_hash[::-1].hex() if channel_hash else channel_hash
txo_hash = doc.pop('txo_hash')
doc['tx_id'] = txo_hash[:32][::-1].hex()
doc['tx_nout'] = struct.unpack('<I', txo_hash[32:])[0]
doc['is_controlling'] = bool(doc['is_controlling'])
doc['signature'] = (doc.pop('signature') or b'').hex() or None
doc['signature_digest'] = (doc.pop('signature_digest') or b'').hex() or None
doc['public_key_bytes'] = (doc.pop('public_key_bytes') or b'').hex() or None
doc['public_key_hash'] = (doc.pop('public_key_hash') or b'').hex() or None
doc['signature_valid'] = bool(doc['signature_valid'])
doc['claim_type'] = doc.get('claim_type', 0) or 0
doc['stream_type'] = int(doc.get('stream_type', 0) or 0)
doc['has_source'] = bool(doc['has_source'])
return {'doc': doc, '_id': doc['claim_id'], '_index': index, '_op_type': 'update', 'doc_as_upsert': True}
def expand_query(**kwargs):
if "amount_order" in kwargs:
kwargs["limit"] = 1
kwargs["order_by"] = "effective_amount"
kwargs["offset"] = int(kwargs["amount_order"]) - 1
if 'name' in kwargs:
kwargs['name'] = normalize_name(kwargs.pop('name'))
if kwargs.get('is_controlling') is False:
kwargs.pop('is_controlling')
query = {'must': [], 'must_not': []}
collapse = None
for key, value in kwargs.items():
key = key.replace('claim.', '')
many = key.endswith('__in') or isinstance(value, list)
if many:
key = key.replace('__in', '')
value = list(filter(None, value))
if value is None or isinstance(value, list) and len(value) == 0:
continue
key = REPLACEMENTS.get(key, key)
if key in FIELDS:
partial_id = False
if key == 'claim_type':
if isinstance(value, str):
value = CLAIM_TYPES[value]
else:
value = [CLAIM_TYPES[claim_type] for claim_type in value]
if key == '_id':
if isinstance(value, Iterable):
value = [item[::-1].hex() for item in value]
else:
value = value[::-1].hex()
if not many and key in ('_id', 'claim_id') and len(value) < 20:
partial_id = True
if key == 'public_key_id':
key = 'public_key_hash'
value = Base58.decode(value)[1:21].hex()
if key == 'signature_valid':
continue # handled later
if key in TEXT_FIELDS:
key += '.keyword'
ops = {'<=': 'lte', '>=': 'gte', '<': 'lt', '>': 'gt'}
if partial_id:
query['must'].append({"prefix": {"claim_id": value}})
elif key in RANGE_FIELDS and isinstance(value, str) and value[0] in ops:
operator_length = 2 if value[:2] in ops else 1
operator, value = value[:operator_length], value[operator_length:]
if key == 'fee_amount':
value = str(Decimal(value)*1000)
query['must'].append({"range": {key: {ops[operator]: value}}})
elif many:
query['must'].append({"terms": {key: value}})
else:
if key == 'fee_amount':
value = str(Decimal(value)*1000)
query['must'].append({"term": {key: {"value": value}}})
elif key == 'not_channel_ids':
for channel_id in value:
query['must_not'].append({"term": {'channel_id.keyword': channel_id}})
query['must_not'].append({"term": {'_id': channel_id}})
elif key == 'channel_ids':
query['must'].append({"terms": {'channel_id.keyword': value}})
elif key == 'claim_ids':
query['must'].append({"terms": {'claim_id.keyword': value}})
elif key == 'media_types':
query['must'].append({"terms": {'media_type.keyword': value}})
elif key == 'stream_types':
query['must'].append({"terms": {'stream_type': [STREAM_TYPES[stype] for stype in value]}})
elif key == 'any_languages':
query['must'].append({"terms": {'languages': clean_tags(value)}})
elif key == 'any_languages':
query['must'].append({"terms": {'languages': value}})
elif key == 'all_languages':
query['must'].extend([{"term": {'languages': tag}} for tag in value])
elif key == 'any_tags':
query['must'].append({"terms": {'tags.keyword': clean_tags(value)}})
elif key == 'all_tags':
query['must'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
elif key == 'not_tags':
query['must_not'].extend([{"term": {'tags.keyword': tag}} for tag in clean_tags(value)])
elif key == 'not_claim_id':
query['must_not'].extend([{"term": {'claim_id.keyword': cid}} for cid in value])
elif key == 'limit_claims_per_channel':
collapse = ('channel_id.keyword', value)
if kwargs.get('has_channel_signature'):
query['must'].append({"exists": {"field": "signature_digest"}})
if 'signature_valid' in kwargs:
query['must'].append({"term": {"signature_valid": bool(kwargs["signature_valid"])}})
elif 'signature_valid' in kwargs:
query.setdefault('should', [])
query["minimum_should_match"] = 1
query['should'].append({"bool": {"must_not": {"exists": {"field": "signature_digest"}}}})
query['should'].append({"term": {"signature_valid": bool(kwargs["signature_valid"])}})
if kwargs.get('text'):
query['must'].append(
{"simple_query_string":
{"query": kwargs["text"], "fields": [
"claim_name^4", "channel_name^8", "title^1", "description^.5", "author^1", "tags^.5"
]}})
query = {
"_source": {"excludes": ["description", "title"]},
'query': {'bool': query},
"sort": [],
}
if "limit" in kwargs:
query["size"] = kwargs["limit"]
if 'offset' in kwargs:
query["from"] = kwargs["offset"]
if 'order_by' in kwargs:
if isinstance(kwargs["order_by"], str):
kwargs["order_by"] = [kwargs["order_by"]]
for value in kwargs['order_by']:
if 'trending_group' in value:
# fixme: trending_mixed is 0 for all records on variable decay, making sort slow.
continue
is_asc = value.startswith('^')
value = value[1:] if is_asc else value
value = REPLACEMENTS.get(value, value)
if value in TEXT_FIELDS:
value += '.keyword'
query['sort'].append({value: "asc" if is_asc else "desc"})
if collapse:
query["collapse"] = {
"field": collapse[0],
"inner_hits": {
"name": collapse[0],
"size": collapse[1],
"sort": query["sort"]
}
}
return query
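expand_query is the translation layer from claim_search keyword arguments to the Elasticsearch query DSL: __in suffixes and lists become terms clauses, leading comparison operators on range fields become range clauses, text fields are matched on their .keyword sub-field, and limit/offset/order_by map to size/from/sort. A small illustration of the mapping (argument values invented):

from lbry.wallet.server.db.elasticsearch.search import expand_query

query = expand_query(
    claim_type='stream',      # mapped through CLAIM_TYPES to its numeric code
    channel_ids=['abc123'],   # becomes a terms clause on channel_id.keyword
    height='>=900000',        # parsed into {"range": {"height": {"gte": "900000"}}}
    order_by=['^name'],       # '^' means ascending; 'name' is rewritten to 'normalized'
    limit=20,
)
assert query['size'] == 20    # the dict is ready to pass to AsyncElasticsearch.search()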
def expand_result(results):
inner_hits = []
expanded = []
for result in results:
if result.get("inner_hits"):
for _, inner_hit in result["inner_hits"].items():
inner_hits.extend(inner_hit["hits"]["hits"])
continue
result = result['_source']
result['claim_hash'] = unhexlify(result['claim_id'])[::-1]
if result['reposted_claim_id']:
result['reposted_claim_hash'] = unhexlify(result['reposted_claim_id'])[::-1]
else:
result['reposted_claim_hash'] = None
result['channel_hash'] = unhexlify(result['channel_id'])[::-1] if result['channel_id'] else None
result['txo_hash'] = unhexlify(result['tx_id'])[::-1] + struct.pack('<I', result['tx_nout'])
result['tx_hash'] = unhexlify(result['tx_id'])[::-1]
if result['censoring_channel_hash']:
result['censoring_channel_hash'] = unhexlify(result['censoring_channel_hash'])[::-1]
expanded.append(result)
if inner_hits:
return expand_result(inner_hits)
return expanded
class ResultCacheItem:
__slots__ = '_result', 'lock', 'has_result'
def __init__(self):
self.has_result = asyncio.Event()
self.lock = asyncio.Lock()
self._result = None
@property
def result(self) -> str:
return self._result
@result.setter
def result(self, result: str):
self._result = result
if result is not None:
self.has_result.set()
@classmethod
def from_cache(cls, cache_key, cache):
cache_item = cache.get(cache_key)
if cache_item is None:
cache_item = cache[cache_key] = ResultCacheItem()
return cache_item


@ -0,0 +1,105 @@
import argparse
import asyncio
import logging
import os
from collections import namedtuple
from multiprocessing import Process
import apsw
from elasticsearch import AsyncElasticsearch
from elasticsearch.helpers import async_bulk
from .search import extract_doc, SearchIndex
INDEX = 'claims'
async def get_all(db, shard_num, shards_total, limit=0):
logging.info("shard %d starting", shard_num)
def exec_factory(cursor, statement, bindings):
tpl = namedtuple('row', (d[0] for d in cursor.getdescription()))
cursor.setrowtrace(lambda cursor, row: tpl(*row))
return True
db.setexectrace(exec_factory)
total = db.execute(f"select count(*) as total from claim where height % {shards_total} = {shard_num};").fetchone()[0]
for num, claim in enumerate(db.execute(f"""
SELECT claimtrie.claim_hash as is_controlling,
claimtrie.last_take_over_height,
(select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags,
(select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages,
claim.*
FROM claim LEFT JOIN claimtrie USING (claim_hash)
WHERE claim.height % {shards_total} = {shard_num}
ORDER BY claim.height desc
""")):
claim = dict(claim._asdict())
claim['censor_type'] = 0
claim['censoring_channel_hash'] = None
claim['tags'] = claim['tags'].split(',,') if claim['tags'] else []
claim['languages'] = claim['languages'].split(' ') if claim['languages'] else []
if num % 10_000 == 0:
logging.info("%d/%d", num, total)
yield extract_doc(claim, INDEX)
if 0 < limit <= num:
break
async def consume(producer):
es = AsyncElasticsearch()
try:
await async_bulk(es, producer, request_timeout=120)
await es.indices.refresh(index=INDEX)
finally:
await es.close()
async def make_es_index():
index = SearchIndex('')
try:
return await index.start()
finally:
index.stop()
async def run(args, shard):
def itsbusy(*_):
logging.info("shard %d: db is busy, retry", shard)
return True
db = apsw.Connection(args.db_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
db.setbusyhandler(itsbusy)
db.cursor().execute('pragma journal_mode=wal;')
db.cursor().execute('pragma temp_store=memory;')
producer = get_all(db.cursor(), shard, args.clients, limit=args.blocks)
await asyncio.gather(*(consume(producer) for _ in range(min(8, args.clients))))
def __run(args, shard):
asyncio.run(run(args, shard))
def run_elastic_sync():
logging.basicConfig(level=logging.INFO)
logging.info('lbry.server starting')
parser = argparse.ArgumentParser(prog="lbry-hub-elastic-sync")
parser.add_argument("db_path", type=str)
parser.add_argument("-c", "--clients", type=int, default=16)
parser.add_argument("-b", "--blocks", type=int, default=0)
parser.add_argument("-f", "--force", default=False, action='store_true')
args = parser.parse_args()
processes = []
if not args.force and not os.path.exists(args.db_path):
logging.info("DB path doesnt exist")
return
if not args.force and not asyncio.run(make_es_index()):
logging.info("ES is already initialized")
return
for i in range(args.clients):
processes.append(Process(target=__run, args=(args, i)))
processes[-1].start()
for process in processes:
process.join()
process.close()
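run_elastic_sync is the bulk back-fill used on first start (the docker entrypoint above runs it before lbry-hub): it shards the claim table by height modulo --clients, and each worker process streams its shard into Elasticsearch with async_bulk. A hedged invocation sketch, equivalent to the lbry-hub-elastic-sync console script:

import sys
from lbry.wallet.server.db.elasticsearch.sync import run_elastic_sync

# Path and worker count are illustrative; --clients defaults to 16.
sys.argv = ['lbry-hub-elastic-sync', '/database/claims.db', '--clients', '4']
run_elastic_sync()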


@ -1,52 +0,0 @@
from lbry.wallet.database import constraints_to_sql
CREATE_FULL_TEXT_SEARCH = """
create virtual table if not exists search using fts5(
claim_name, channel_name, title, description, author, tags,
content=claim, tokenize=porter
);
"""
FTS_ORDER_BY = "bm25(search, 4.0, 8.0, 1.0, 0.5, 1.0, 0.5)"
def fts_action_sql(claims=None, action='insert'):
select = {
'rowid': "claim.rowid",
'claim_name': "claim.normalized",
'channel_name': "channel.normalized",
'title': "claim.title",
'description': "claim.description",
'author': "claim.author",
'tags': "(select group_concat(tag, ' ') from tag where tag.claim_hash=claim.claim_hash)"
}
if action == 'delete':
select['search'] = '"delete"'
where, values = "", {}
if claims:
where, values = constraints_to_sql({'claim.claim_hash__in': claims})
where = 'WHERE '+where
return f"""
INSERT INTO search ({','.join(select.keys())})
SELECT {','.join(select.values())} FROM claim
LEFT JOIN claim as channel ON (claim.channel_hash=channel.claim_hash) {where}
""", values
def update_full_text_search(action, outputs, db, is_first_sync):
if is_first_sync:
return
if not outputs:
return
if action in ("before-delete", "before-update"):
db.execute(*fts_action_sql(outputs, 'delete'))
elif action in ("after-insert", "after-update"):
db.execute(*fts_action_sql(outputs, 'insert'))
else:
raise ValueError(f"Invalid action for updating full text search: '{action}'")
def first_sync_finished(db):
db.execute(*fts_action_sql())


@ -1,11 +1,12 @@
import os
import apsw
from typing import Union, Tuple, Set, List
from itertools import chain
from decimal import Decimal
from collections import namedtuple
from multiprocessing import Manager
from binascii import unhexlify
from binascii import unhexlify, hexlify
from lbry.wallet.server.leveldb import LevelDB
from lbry.wallet.server.util import class_logger
from lbry.wallet.database import query, constraints_to_sql
@ -15,11 +16,10 @@ from lbry.schema.mime_types import guess_stream_type
from lbry.wallet import Ledger, RegTestLedger
from lbry.wallet.transaction import Transaction, Output
from lbry.wallet.server.db.canonical import register_canonical_functions
from lbry.wallet.server.db.full_text_search import update_full_text_search, CREATE_FULL_TEXT_SEARCH, first_sync_finished
from lbry.wallet.server.db.trending import TRENDING_ALGORITHMS
from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS, INDEXED_LANGUAGES
from lbry.wallet.server.db.elasticsearch import SearchIndex
ATTRIBUTE_ARRAY_MAX_LENGTH = 100
@ -135,6 +135,22 @@ class SQLDB:
create index if not exists claimtrie_claim_hash_idx on claimtrie (claim_hash);
"""
CREATE_CHANGELOG_TRIGGER = """
create table if not exists changelog (
claim_hash bytes primary key
);
create index if not exists claimtrie_claim_hash_idx on claimtrie (claim_hash);
create trigger if not exists claim_changelog after update on claim
begin
insert or ignore into changelog (claim_hash) values (new.claim_hash);
end;
create trigger if not exists claimtrie_changelog after update on claimtrie
begin
insert or ignore into changelog (claim_hash) values (new.claim_hash);
insert or ignore into changelog (claim_hash) values (old.claim_hash);
end;
"""
SEARCH_INDEXES = """
-- used by any tag clouds
create index if not exists tag_tag_idx on tag (tag, claim_hash);
@ -190,10 +206,10 @@ class SQLDB:
CREATE_TABLES_QUERY = (
CREATE_CLAIM_TABLE +
CREATE_FULL_TEXT_SEARCH +
CREATE_SUPPORT_TABLE +
CREATE_CLAIMTRIE_TABLE +
CREATE_TAG_TABLE +
CREATE_CHANGELOG_TRIGGER +
CREATE_LANGUAGE_TABLE
)
@ -204,7 +220,6 @@ class SQLDB:
self.db = None
self.logger = class_logger(__name__, self.__class__.__name__)
self.ledger = Ledger if main.coin.NET == 'mainnet' else RegTestLedger
self._fts_synced = False
self.state_manager = None
self.blocked_streams = None
self.blocked_channels = None
@ -217,6 +232,7 @@ class SQLDB:
unhexlify(channel_id)[::-1] for channel_id in filtering_channels if channel_id
}
self.trending = trending
self.pending_deletes = set()
def open(self):
self.db = apsw.Connection(
@ -422,7 +438,7 @@ class SQLDB:
claims = self._upsertable_claims(txos, header)
if claims:
self.executemany("""
INSERT OR IGNORE INTO claim (
INSERT OR REPLACE INTO claim (
claim_hash, claim_id, claim_name, normalized, txo_hash, tx_position, amount,
claim_type, media_type, stream_type, timestamp, creation_timestamp, has_source,
fee_currency, fee_amount, title, description, author, duration, height, reposted_claim_hash,
@ -531,6 +547,7 @@ class SQLDB:
WHERE claim_hash = ?
""", targets
)
return set(target[0] for target in targets)
def validate_channel_signatures(self, height, new_claims, updated_claims, spent_claims, affected_channels, timer):
if not new_claims and not updated_claims and not spent_claims:
@ -804,11 +821,54 @@ class SQLDB:
f"SELECT claim_hash, normalized FROM claim WHERE expiration_height = {height}"
)
def enqueue_changes(self):
for claim in self.execute(f"""
SELECT claimtrie.claim_hash as is_controlling,
claimtrie.last_take_over_height,
(select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags,
(select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages,
claim.*
FROM claim LEFT JOIN claimtrie USING (claim_hash)
WHERE claim.claim_hash in (SELECT claim_hash FROM changelog)
"""):
claim = claim._asdict()
id_set = set(filter(None, (claim['claim_hash'], claim['channel_hash'], claim['reposted_claim_hash'])))
claim['censor_type'] = 0
claim['censoring_channel_hash'] = None
for reason_id in id_set:
if reason_id in self.blocked_streams:
claim['censor_type'] = 2
claim['censoring_channel_hash'] = self.blocked_streams.get(reason_id)
elif reason_id in self.blocked_channels:
claim['censor_type'] = 2
claim['censoring_channel_hash'] = self.blocked_channels.get(reason_id)
elif reason_id in self.filtered_streams:
claim['censor_type'] = 1
claim['censoring_channel_hash'] = self.filtered_streams.get(reason_id)
elif reason_id in self.filtered_channels:
claim['censor_type'] = 1
claim['censoring_channel_hash'] = self.filtered_channels.get(reason_id)
claim['tags'] = claim['tags'].split(',,') if claim['tags'] else []
claim['languages'] = claim['languages'].split(' ') if claim['languages'] else []
yield 'update', claim
def clear_changelog(self):
self.execute("delete from changelog;")
def claim_producer(self):
while self.pending_deletes:
claim_hash = self.pending_deletes.pop()
yield 'delete', hexlify(claim_hash[::-1]).decode()
for claim in self.enqueue_changes():
yield claim
self.clear_changelog()
def advance_txs(self, height, all_txs, header, daemon_height, timer):
insert_claims = []
update_claims = []
update_claim_hashes = set()
delete_claim_hashes = set()
delete_claim_hashes = self.pending_deletes
insert_supports = []
delete_support_txo_hashes = set()
recalculate_claim_hashes = set() # added/deleted supports, added/updated claim
@ -877,28 +937,17 @@ class SQLDB:
expire_timer.stop()
r = timer.run
r(update_full_text_search, 'before-delete',
delete_claim_hashes, self.db.cursor(), self.main.first_sync)
affected_channels = r(self.delete_claims, delete_claim_hashes)
r(self.delete_supports, delete_support_txo_hashes)
r(self.insert_claims, insert_claims, header)
r(self.calculate_reposts, insert_claims)
r(update_full_text_search, 'after-insert',
[txo.claim_hash for txo in insert_claims], self.db.cursor(), self.main.first_sync)
r(update_full_text_search, 'before-update',
[txo.claim_hash for txo in update_claims], self.db.cursor(), self.main.first_sync)
r(self.update_claims, update_claims, header)
r(update_full_text_search, 'after-update',
[txo.claim_hash for txo in update_claims], self.db.cursor(), self.main.first_sync)
r(self.validate_channel_signatures, height, insert_claims,
update_claims, delete_claim_hashes, affected_channels, forward_timer=True)
r(self.insert_supports, insert_supports)
r(self.update_claimtrie, height, recalculate_claim_hashes, deleted_claim_names, forward_timer=True)
for algorithm in self.trending:
r(algorithm.run, self.db.cursor(), height, daemon_height, recalculate_claim_hashes)
if not self._fts_synced and self.main.first_sync and height == daemon_height:
r(first_sync_finished, self.db.cursor())
self._fts_synced = True
class LBRYLevelDB(LevelDB):
@ -910,17 +959,28 @@ class LBRYLevelDB(LevelDB):
for algorithm_name in self.env.trending_algorithms:
if algorithm_name in TRENDING_ALGORITHMS:
trending.append(TRENDING_ALGORITHMS[algorithm_name])
self.sql = SQLDB(
self, path,
self.env.default('BLOCKING_CHANNEL_IDS', '').split(' '),
self.env.default('FILTERING_CHANNEL_IDS', '').split(' '),
trending
)
if self.env.es_mode == 'reader':
self.logger.info('Index mode: reader')
self.sql = None
else:
self.logger.info('Index mode: writer. Using SQLite db to sync ES')
self.sql = SQLDB(
self, path,
self.env.default('BLOCKING_CHANNEL_IDS', '').split(' '),
self.env.default('FILTERING_CHANNEL_IDS', '').split(' '),
trending
)
# Search index
self.search_index = SearchIndex(self.env.es_index_prefix, self.env.database_query_timeout)
def close(self):
super().close()
self.sql.close()
if self.sql:
self.sql.close()
async def _open_dbs(self, *args, **kwargs):
await self.search_index.start()
await super()._open_dbs(*args, **kwargs)
self.sql.open()
if self.sql:
self.sql.open()
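The changelog table and triggers are what keep incremental ES sync cheap: any update to claim or claimtrie records the affected claim_hash, enqueue_changes() re-reads only those rows (joining tags and languages and stamping censor_type from the in-memory block/filter sets), and claim_producer() drains pending deletes first and clears the changelog when done. A standalone sketch of the trigger mechanics with a deliberately simplified schema (not the real one):

import sqlite3

db = sqlite3.connect(':memory:')
db.executescript("""
    create table claim (claim_hash blob primary key, title text);
    create table changelog (claim_hash blob primary key);
    create trigger claim_changelog after update on claim begin
        insert or ignore into changelog (claim_hash) values (new.claim_hash);
    end;
""")
db.execute("insert into claim values (x'aa', 'first title')")
db.execute("update claim set title = 'second title' where claim_hash = x'aa'")
# Only the touched claim needs to be re-indexed in Elasticsearch.
print(db.execute("select count(*) from changelog").fetchone()[0])   # 1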


@ -53,6 +53,8 @@ class Env:
coin_name = self.required('COIN').strip()
network = self.default('NET', 'mainnet').strip()
self.coin = Coin.lookup_coin_class(coin_name, network)
self.es_index_prefix = self.default('ES_INDEX_PREFIX', '')
self.es_mode = self.default('ES_MODE', 'writer')
self.cache_MB = self.integer('CACHE_MB', 1200)
self.reorg_limit = self.integer('REORG_LIMIT', self.coin.REORG_LIMIT)
# Server stuff
@ -95,7 +97,7 @@ class Env:
self.identities = [identity
for identity in (clearnet_identity, tor_identity)
if identity is not None]
self.database_query_timeout = float(self.integer('QUERY_TIMEOUT_MS', 250)) / 1000.0
self.database_query_timeout = float(self.integer('QUERY_TIMEOUT_MS', 3000)) / 1000.0
@classmethod
def default(cls, envvar, default):


@ -12,6 +12,7 @@
import asyncio
import array
import ast
import base64
import os
import time
import zlib
@ -82,6 +83,7 @@ class LevelDB:
self.utxo_db = None
self.tx_counts = None
self.headers = None
self.encoded_headers = LRUCacheWithMetrics(1 << 21, metric_name='encoded_headers', namespace='wallet_server')
self.last_flush = time.time()
self.logger.info(f'using {self.env.db_engine} for DB backend')
@ -440,6 +442,16 @@ class LevelDB:
raise IndexError(f'height {height:,d} out of range')
return header
def encode_headers(self, start_height, count, headers):
key = (start_height, count)
if not self.encoded_headers.get(key):
compressobj = zlib.compressobj(wbits=-15, level=1, memLevel=9)
headers = base64.b64encode(compressobj.compress(headers) + compressobj.flush()).decode()
if start_height % 1000 != 0:
return headers
self.encoded_headers[key] = headers
return self.encoded_headers.get(key)
def read_headers(self, start_height, count) -> typing.Tuple[bytes, int]:
"""Requires start_height >= 0, count >= 0. Reads as many headers as
are available starting at start_height up to count. This

View file

@ -210,6 +210,15 @@ class MemPool:
return deferred, {prevout: utxo_map[prevout] for prevout in unspent}
async def _mempool_loop(self, synchronized_event):
try:
return await self._refresh_hashes(synchronized_event)
except asyncio.CancelledError:
raise
except Exception as e:
self.logger.exception("MEMPOOL DIED")
raise e
async def _refresh_hashes(self, synchronized_event):
"""Refresh our view of the daemon's mempool."""
while True:
@ -326,7 +335,7 @@ class MemPool:
async def keep_synchronized(self, synchronized_event):
"""Keep the mempool synchronized with the daemon."""
await asyncio.wait([
self._refresh_hashes(synchronized_event),
self._mempool_loop(synchronized_event),
# self._refresh_histogram(synchronized_event),
self._logging(synchronized_event)
])


@ -94,6 +94,7 @@ class Server:
self.session_mgr = env.coin.SESSION_MANAGER(
env, db, bp, daemon, mempool, self.shutdown_event
)
self._indexer_task = None
async def start(self):
env = self.env


@ -3,7 +3,6 @@ import ssl
import math
import time
import json
import zlib
import base64
import codecs
import typing
@ -16,8 +15,10 @@ from asyncio import Event, sleep
from collections import defaultdict
from functools import partial
from binascii import hexlify, unhexlify
from binascii import hexlify
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from elasticsearch import ConnectionTimeout
from prometheus_client import Counter, Info, Histogram, Gauge
import lbry
@ -25,7 +26,6 @@ from lbry.utils import LRUCacheWithMetrics
from lbry.build_info import BUILD, COMMIT_HASH, DOCKER_TAG
from lbry.wallet.server.block_processor import LBRYBlockProcessor
from lbry.wallet.server.db.writer import LBRYLevelDB
from lbry.wallet.server.db import reader
from lbry.wallet.server.websocket import AdminWebSocket
from lbry.wallet.server.metrics import ServerLoadData, APICallMetrics
from lbry.wallet.rpc.framing import NewlineFramer
@ -813,9 +813,6 @@ class LBRYSessionManager(SessionManager):
self.running = False
if self.env.websocket_host is not None and self.env.websocket_port is not None:
self.websocket = AdminWebSocket(self)
self.search_cache = self.bp.search_cache
self.search_cache['search'] = LRUCacheWithMetrics(2 ** 14, metric_name='search', namespace=NAMESPACE)
self.search_cache['resolve'] = LRUCacheWithMetrics(2 ** 16, metric_name='resolve', namespace=NAMESPACE)
async def process_metrics(self):
while self.running:
@ -829,22 +826,11 @@ class LBRYSessionManager(SessionManager):
async def start_other(self):
self.running = True
path = os.path.join(self.env.db_dir, 'claims.db')
args = dict(
initializer=reader.initializer,
initargs=(
self.logger, path, self.env.coin.NET, self.env.database_query_timeout,
self.env.track_metrics, (
self.db.sql.blocked_streams, self.db.sql.blocked_channels,
self.db.sql.filtered_streams, self.db.sql.filtered_channels
)
)
)
if self.env.max_query_workers is not None and self.env.max_query_workers == 0:
self.query_executor = ThreadPoolExecutor(max_workers=1, **args)
self.query_executor = ThreadPoolExecutor(max_workers=1)
else:
self.query_executor = ProcessPoolExecutor(
max_workers=self.env.max_query_workers or max(os.cpu_count(), 4), **args
max_workers=self.env.max_query_workers or max(os.cpu_count(), 4)
)
if self.websocket is not None:
await self.websocket.start()
@ -897,7 +883,6 @@ class LBRYElectrumX(SessionBase):
'blockchain.transaction.get_height': cls.transaction_get_height,
'blockchain.claimtrie.search': cls.claimtrie_search,
'blockchain.claimtrie.resolve': cls.claimtrie_resolve,
'blockchain.claimtrie.getclaimsbyids': cls.claimtrie_getclaimsbyids,
'blockchain.block.get_server_height': cls.get_server_height,
'mempool.get_fee_histogram': cls.mempool_compact_histogram,
'blockchain.block.headers': cls.block_headers,
@ -1002,16 +987,6 @@ class LBRYElectrumX(SessionBase):
)
except asyncio.CancelledError:
raise
except reader.SQLiteInterruptedError as error:
metrics = self.get_metrics_or_placeholder_for_api(query_name)
metrics.query_interrupt(start, error.metrics)
self.session_mgr.interrupt_count_metric.inc()
raise RPCError(JSONRPC.QUERY_TIMEOUT, 'sqlite query timed out')
except reader.SQLiteOperationalError as error:
metrics = self.get_metrics_or_placeholder_for_api(query_name)
metrics.query_error(start, error.metrics)
self.session_mgr.db_operational_error_metric.inc()
raise RPCError(JSONRPC.INTERNAL_ERROR, 'query failed to execute')
except Exception:
log.exception("dear devs, please handle this exception better")
metrics = self.get_metrics_or_placeholder_for_api(query_name)
@ -1028,40 +1003,33 @@ class LBRYElectrumX(SessionBase):
self.session_mgr.pending_query_metric.dec()
self.session_mgr.executor_time_metric.observe(time.perf_counter() - start)
async def run_and_cache_query(self, query_name, function, kwargs):
metrics = self.get_metrics_or_placeholder_for_api(query_name)
metrics.start()
cache = self.session_mgr.search_cache[query_name]
cache_key = str(kwargs)
cache_item = cache.get(cache_key)
if cache_item is None:
cache_item = cache[cache_key] = ResultCacheItem()
elif cache_item.result is not None:
metrics.cache_response()
return cache_item.result
async with cache_item.lock:
if cache_item.result is None:
cache_item.result = await self.run_in_executor(
query_name, function, kwargs
)
else:
metrics = self.get_metrics_or_placeholder_for_api(query_name)
metrics.cache_response()
return cache_item.result
async def run_and_cache_query(self, query_name, kwargs):
start = time.perf_counter()
if isinstance(kwargs, dict):
kwargs['release_time'] = format_release_time(kwargs.get('release_time'))
try:
self.session_mgr.pending_query_metric.inc()
return await self.db.search_index.session_query(query_name, kwargs)
except ConnectionTimeout:
self.session_mgr.interrupt_count_metric.inc()
raise RPCError(JSONRPC.QUERY_TIMEOUT, 'query timed out')
finally:
self.session_mgr.pending_query_metric.dec()
self.session_mgr.executor_time_metric.observe(time.perf_counter() - start)
async def mempool_compact_histogram(self):
return self.mempool.compact_fee_histogram()
async def claimtrie_search(self, **kwargs):
if kwargs:
return await self.run_and_cache_query('search', reader.search_to_bytes, kwargs)
return await self.run_and_cache_query('search', kwargs)
async def claimtrie_resolve(self, *urls):
if urls:
count = len(urls)
try:
self.session_mgr.urls_to_resolve_count_metric.inc(count)
return await self.run_and_cache_query('resolve', reader.resolve_to_bytes, urls)
return await self.run_and_cache_query('resolve', urls)
finally:
self.session_mgr.resolved_url_count_metric.inc(count)
@ -1078,67 +1046,6 @@ class LBRYElectrumX(SessionBase):
return -1
return None
async def claimtrie_getclaimsbyids(self, *claim_ids):
claims = await self.batched_formatted_claims_from_daemon(claim_ids)
return dict(zip(claim_ids, claims))
async def batched_formatted_claims_from_daemon(self, claim_ids):
claims = await self.daemon.getclaimsbyids(claim_ids)
result = []
for claim in claims:
if claim and claim.get('value'):
result.append(self.format_claim_from_daemon(claim))
return result
def format_claim_from_daemon(self, claim, name=None):
"""Changes the returned claim data to the format expected by lbry and adds missing fields."""
if not claim:
return {}
# this ISO-8859 nonsense stems from a nasty form of encoding extended characters in lbrycrd
# it will be fixed after the lbrycrd upstream merge to v17 is done
# it originated as a fear of terminals not supporting unicode. alas, they all do
if 'name' in claim:
name = claim['name'].encode('ISO-8859-1').decode()
info = self.db.sql.get_claims(claim_id=claim['claimId'])
if not info:
# raise RPCError("Lbrycrd has {} but not lbryumx, please submit a bug report.".format(claim_id))
return {}
address = info.address.decode()
# fixme: temporary
#supports = self.format_supports_from_daemon(claim.get('supports', []))
supports = []
amount = get_from_possible_keys(claim, 'amount', 'nAmount')
height = get_from_possible_keys(claim, 'height', 'nHeight')
effective_amount = get_from_possible_keys(claim, 'effective amount', 'nEffectiveAmount')
valid_at_height = get_from_possible_keys(claim, 'valid at height', 'nValidAtHeight')
result = {
"name": name,
"claim_id": claim['claimId'],
"txid": claim['txid'],
"nout": claim['n'],
"amount": amount,
"depth": self.db.db_height - height + 1,
"height": height,
"value": hexlify(claim['value'].encode('ISO-8859-1')).decode(),
"address": address, # from index
"supports": supports,
"effective_amount": effective_amount,
"valid_at_height": valid_at_height
}
if 'claim_sequence' in claim:
# TODO: ensure that lbrycrd #209 fills in this value
result['claim_sequence'] = claim['claim_sequence']
else:
result['claim_sequence'] = -1
if 'normalized_name' in claim:
result['normalized_name'] = claim['normalized_name'].encode('ISO-8859-1').decode()
return result
def assert_tx_hash(self, value):
'''Raise an RPCError if the value is not a valid transaction
hash.'''
@ -1149,16 +1056,6 @@ class LBRYElectrumX(SessionBase):
pass
raise RPCError(1, f'{value} should be a transaction hash')
def assert_claim_id(self, value):
'''Raise an RPCError if the value is not a valid claim id
hash.'''
try:
if len(util.hex_to_bytes(value)) == 20:
return
except Exception:
pass
raise RPCError(1, f'{value} should be a claim id hash')
async def subscribe_headers_result(self):
"""The result of a header subscription or notification."""
return self.session_mgr.hsub_results[self.subscribe_headers_raw]
@ -1363,8 +1260,7 @@ class LBRYElectrumX(SessionBase):
headers, count = self.db.read_headers(start_height, count)
if b64:
compressobj = zlib.compressobj(wbits=-15, level=1, memLevel=9)
headers = base64.b64encode(compressobj.compress(headers) + compressobj.flush()).decode()
headers = self.db.encode_headers(start_height, count, headers)
else:
headers = headers.hex()
result = {
@ -1614,26 +1510,20 @@ class LocalRPC(SessionBase):
return 'RPC'
class ResultCacheItem:
__slots__ = '_result', 'lock', 'has_result'
def __init__(self):
self.has_result = asyncio.Event()
self.lock = asyncio.Lock()
self._result = None
@property
def result(self) -> str:
return self._result
@result.setter
def result(self, result: str):
self._result = result
if result is not None:
self.has_result.set()
def get_from_possible_keys(dictionary, *keys):
for key in keys:
if key in dictionary:
return dictionary[key]
def format_release_time(release_time):
# round release time up to the next 360-second bucket so it caches better
# also set a default so we don't show claims in the future
def roundup_time(number, factor=360):
return int(1 + int(number / factor)) * factor
if isinstance(release_time, str) and len(release_time) > 0:
time_digits = ''.join(filter(str.isdigit, release_time))
time_prefix = release_time[:-len(time_digits)]
return time_prefix + str(roundup_time(int(time_digits)))
elif isinstance(release_time, int):
return roundup_time(release_time)
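Because release_time constraints are bucketed before the query reaches the cache, two requests whose constraints fall inside the same 360-second window produce the same cache key. A small worked example of the rounding above (timestamps chosen arbitrarily):

# mirrors roundup_time above with its default factor of 360
def roundup_time(number, factor=360):
    return int(1 + int(number / factor)) * factor

assert roundup_time(1614000000) == 1614000240
assert roundup_time(1614000100) == 1614000240  # same bucket, same cache key
# a string constraint keeps its comparison prefix:
# format_release_time('>1614000000') -> '>1614000240'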

View file

@ -1,177 +0,0 @@
import os
import time
import textwrap
import argparse
import asyncio
import logging
from concurrent.futures.process import ProcessPoolExecutor
from lbry.wallet.server.db.reader import search_to_bytes, initializer, _get_claims, interpolate
from lbry.wallet.ledger import MainNetLedger
log = logging.getLogger(__name__)
log.addHandler(logging.StreamHandler())
log.setLevel(logging.CRITICAL)
DEFAULT_ANY_TAGS = [
'blockchain',
'news',
'learning',
'technology',
'automotive',
'economics',
'food',
'science',
'art',
'nature'
]
COMMON_AND_RARE = [
'gaming',
'ufos'
]
COMMON_AND_RARE2 = [
'city fix',
'gaming'
]
RARE_ANY_TAGS = [
'city fix',
'ufos',
]
CITY_FIX = [
'city fix'
]
MATURE_TAGS = [
'porn',
'nsfw',
'mature',
'xxx'
]
ORDER_BY = [
[
"trending_global",
"trending_mixed",
],
[
"release_time"
],
[
"effective_amount"
]
]
def get_args(limit=20):
args = []
any_tags_combinations = [DEFAULT_ANY_TAGS, COMMON_AND_RARE, RARE_ANY_TAGS, COMMON_AND_RARE2, CITY_FIX, []]
not_tags_combinations = [MATURE_TAGS, []]
for no_fee in [False, True]:
for claim_type in [None, 'stream', 'channel']:
for no_totals in [True]:
for offset in [0, 100]:
for any_tags in any_tags_combinations:
for not_tags in not_tags_combinations:
for order_by in ORDER_BY:
kw = {
'order_by': order_by,
'offset': offset,
'limit': limit,
'no_totals': no_totals
}
if not_tags:
kw['not_tags'] = not_tags
if any_tags:
kw['any_tags'] = any_tags
if claim_type:
kw['claim_type'] = claim_type
if no_fee:
kw['fee_amount'] = 0
args.append(kw)
print(f"-- Trying {len(args)} argument combinations")
return args
def _search(kwargs):
start = time.perf_counter()
error = None
try:
search_to_bytes(kwargs)
except Exception as err:
error = str(err)
return time.perf_counter() - start, kwargs, error
async def search(executor, kwargs):
return await asyncio.get_running_loop().run_in_executor(
executor, _search, kwargs
)
async def main(db_path, max_query_time):
args = dict(initializer=initializer, initargs=(log, db_path, MainNetLedger, 0.25))
workers = max(os.cpu_count(), 4)
log.info(f"using {workers} reader processes")
query_executor = ProcessPoolExecutor(workers, **args)
tasks = [search(query_executor, constraints) for constraints in get_args()]
try:
results = await asyncio.gather(*tasks)
query_times = [
{
'sql': interpolate(*_get_claims("""
claimtrie.claim_hash as is_controlling,
claimtrie.last_take_over_height,
claim.claim_hash, claim.txo_hash,
claim.claims_in_channel,
claim.height, claim.creation_height,
claim.activation_height, claim.expiration_height,
claim.effective_amount, claim.support_amount,
claim.trending_group, claim.trending_mixed,
claim.trending_local, claim.trending_global,
claim.short_url, claim.canonical_url,
claim.channel_hash, channel.txo_hash AS channel_txo_hash,
channel.height AS channel_height, claim.signature_valid
""", **constraints)),
'duration': ts,
'error': error
}
for ts, constraints, error in results
]
errored = [query_info for query_info in query_times if query_info['error']]
errors = {str(query_info['error']): [] for query_info in errored}
for error in errored:
errors[str(error['error'])].append(error['sql'])
slow = [
query_info for query_info in query_times
if not query_info['error'] and query_info['duration'] > (max_query_time / 2.0)
]
fast = [
query_info for query_info in query_times
if not query_info['error'] and query_info['duration'] <= (max_query_time / 2.0)
]
print(f"-- {len(fast)} queries were fast")
slow.sort(key=lambda query_info: query_info['duration'], reverse=True)
print(f"-- Failing queries:")
for error in errors:
print(f"-- Failure: \"{error}\"")
for failing_query in errors[error]:
print(f"{textwrap.dedent(failing_query)};\n")
print()
print(f"-- Slow queries:")
for slow_query in slow:
print(f"-- Query took {slow_query['duration']}\n{textwrap.dedent(slow_query['sql'])};\n")
finally:
query_executor.shutdown()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--db_path', dest='db_path', default=os.path.expanduser('~/claims.db'), type=str)
parser.add_argument('--max_time', dest='max_time', default=0.25, type=float)
args = parser.parse_args()
db_path = args.db_path
max_query_time = args.max_time
asyncio.run(main(db_path, max_query_time))

View file

@ -1,62 +0,0 @@
import uvloop, asyncio, time, sys, logging
from concurrent.futures import ProcessPoolExecutor
from lbry.wallet.server.db import reader
from lbry.wallet.server.metrics import calculate_avg_percentiles
db_path = '../../../lbryconf/wallet-server/claims.db'
default_query_timout = 0.25
log = logging.getLogger(__name__)
log.addHandler(logging.StreamHandler())
async def run_times(executor, iterations, show=True):
start = time.perf_counter()
timings = await asyncio.gather(*(asyncio.get_running_loop().run_in_executor(
executor, reader.search_to_bytes, {
'no_totals': True,
'offset': 0,
'limit': 20,
'any_tags': [
'ufos', 'city fix'
],
'not_tags': [
'porn', 'mature', 'xxx', 'nsfw'
],
'order_by': [
'release_time'
]
}
) for _ in range(iterations)))
timings = [r[1]['execute_query'][0]['total'] for r in timings]
total = int((time.perf_counter() - start) * 100)
if show:
avg = sum(timings)/len(timings)
print(f"{iterations:4}: {total}ms total concurrent, {len(timings)*avg*1000:.3f}s total sequential (avg*runs)")
print(f" {total/len(timings):.1f}ms/query concurrent (total/runs)")
print(f" {avg:.1f}ms/query actual average (sum(queries)/runs)")
stats = calculate_avg_percentiles(timings)
print(f" min: {stats[1]}, 5%: {stats[2]}, 25%: {stats[3]}, 50%: {stats[4]}, 75%: {stats[5]}, 95%: {stats[6]}, max: {stats[7]}")
sys.stdout.write(' sample:')
for i, t in zip(range(10), timings[::-1]):
sys.stdout.write(f' {t}ms')
print(' ...\n' if len(timings) > 10 else '\n')
async def main():
executor = ProcessPoolExecutor(
4, initializer=reader.initializer, initargs=(log, db_path, 'mainnet', 1.0, True)
)
#await run_times(executor, 4, show=False)
#await run_times(executor, 1)
await run_times(executor, 2**3)
await run_times(executor, 2**5)
await run_times(executor, 2**7)
#await run_times(executor, 2**9)
#await run_times(executor, 2**11)
#await run_times(executor, 2**13)
executor.shutdown(True)
if __name__ == '__main__':
uvloop.install()
asyncio.run(main())

View file

@ -28,8 +28,9 @@ setup(
entry_points={
'console_scripts': [
'lbrynet=lbry.extras.cli:main',
'torba-server=lbry.wallet.server.cli:main',
'lbry-hub=lbry.wallet.server.cli:main',
'orchstr8=lbry.wallet.orchstr8.cli:main',
'lbry-hub-elastic-sync=lbry.wallet.server.db.elasticsearch.sync:run_elastic_sync'
],
},
install_requires=[
@ -53,7 +54,8 @@ setup(
'coincurve==11.0.0',
'pbkdf2==1.3',
'attrs==18.2.0',
'pylru==1.1.0'
'pylru==1.1.0',
'elasticsearch==7.10.1'
] + PLYVEL,
classifiers=[
'Framework :: AsyncIO',

View file

@ -114,15 +114,6 @@ class BlockchainReorganizationTests(CommandTestCase):
client_reorg_block_hash = (await self.ledger.headers.hash(208)).decode()
self.assertEqual(client_reorg_block_hash, reorg_block_hash)
# verify the dropped claim is no longer returned by claim search
txos, _, _, _ = await self.ledger.claim_search([], name='hovercraft')
self.assertListEqual(txos, [])
# verify the claim published a block earlier wasn't also reverted
txos, _, _, _ = await self.ledger.claim_search([], name='still-valid')
self.assertEqual(1, len(txos))
self.assertEqual(207, txos[0].tx_ref.height)
# broadcast the claim in a different block
new_txid = await self.blockchain.sendrawtransaction(hexlify(broadcast_tx.raw).decode())
self.assertEqual(broadcast_tx.id, new_txid)

View file

@ -3,6 +3,7 @@ import tempfile
import logging
import asyncio
from binascii import unhexlify
from unittest import skip
from urllib.request import urlopen
from lbry.error import InsufficientFundsError
@ -10,6 +11,7 @@ from lbry.extras.daemon.comment_client import verify
from lbry.extras.daemon.daemon import DEFAULT_PAGE_SIZE
from lbry.testcase import CommandTestCase
from lbry.wallet.orchstr8.node import SPVNode
from lbry.wallet.transaction import Transaction
from lbry.wallet.util import satoshis_to_coins as lbc
@ -72,9 +74,11 @@ class ClaimSearchCommand(ClaimTestCase):
for claim, result in zip(claims, results):
self.assertEqual(
(claim['txid'], self.get_claim_id(claim)),
(result['txid'], result['claim_id'])
(result['txid'], result['claim_id']),
f"{claim['outputs'][0]['name']} != {result['name']}"
)
@skip("doesnt happen on ES...?")
async def test_disconnect_on_memory_error(self):
claim_ids = [
'0000000000000000000000000000000000000000',
@ -94,6 +98,18 @@ class ClaimSearchCommand(ClaimTestCase):
with self.assertRaises(ConnectionResetError):
await self.claim_search(claim_ids=claim_ids)
async def test_claim_search_as_reader_server(self):
node2 = SPVNode(self.conductor.spv_module, node_number=2)
current_prefix = self.conductor.spv_node.server.bp.env.es_index_prefix
await node2.start(self.blockchain, extraconf={'ES_MODE': 'reader', 'ES_INDEX_PREFIX': current_prefix})
self.addCleanup(node2.stop)
self.ledger.network.config['default_servers'] = [(node2.hostname, node2.port)]
await self.ledger.stop()
await self.ledger.start()
channel2 = await self.channel_create('@abc', '0.1', allow_duplicate_name=True)
await asyncio.sleep(1) # fixme: find a way to block on the writer
await self.assertFindsClaims([channel2], name='@abc')
async def test_basic_claim_search(self):
await self.create_channel()
channel_txo = self.channel['outputs'][0]
@ -134,6 +150,7 @@ class ClaimSearchCommand(ClaimTestCase):
claims = [three, two, signed]
await self.assertFindsClaims(claims, channel_ids=[self.channel_id])
await self.assertFindsClaims(claims, channel=f"@abc#{self.channel_id}")
await self.assertFindsClaims([], channel=f"@inexistent")
await self.assertFindsClaims([three, two, signed2, signed], channel_ids=[channel_id2, self.channel_id])
await self.channel_abandon(claim_id=self.channel_id)
await self.assertFindsClaims([], channel=f"@abc#{self.channel_id}", valid_channel_signature=True)
@ -157,6 +174,10 @@ class ClaimSearchCommand(ClaimTestCase):
# abandoned stream won't show up for streams in channel search
await self.stream_abandon(txid=signed2['txid'], nout=0)
await self.assertFindsClaims([], channel_ids=[channel_id2])
# resolve by claim ids
await self.assertFindsClaims([three, two], claim_ids=[self.get_claim_id(three), self.get_claim_id(two)])
await self.assertFindsClaims([three], claim_id=self.get_claim_id(three))
await self.assertFindsClaims([three], claim_id=self.get_claim_id(three), text='*')
async def test_source_filter(self):
no_source = await self.stream_create('no_source', data=None)
@ -431,10 +452,11 @@ class ClaimSearchCommand(ClaimTestCase):
await self.assertFindsClaims([claim2], text='autobiography')
await self.assertFindsClaims([claim3], text='history')
await self.assertFindsClaims([claim4], text='conspiracy')
await self.assertFindsClaims([], text='conspiracy AND history')
await self.assertFindsClaims([claim4, claim3], text='conspiracy OR history')
await self.assertFindsClaims([claim1, claim4, claim2, claim3], text='documentary')
await self.assertFindsClaims([claim4, claim1, claim2, claim3], text='satoshi')
await self.assertFindsClaims([], text='conspiracy+history')
await self.assertFindsClaims([claim4, claim3], text='conspiracy|history')
await self.assertFindsClaims([claim1, claim4, claim2, claim3], text='documentary', order_by=[])
# todo: check why claim1 and claim2 order changed. used to be ...claim1, claim2...
await self.assertFindsClaims([claim4, claim2, claim1, claim3], text='satoshi', order_by=[])
claim2 = await self.stream_update(
self.get_claim_id(claim2), clear_tags=True, tags=['cloud'],
@ -1345,6 +1367,11 @@ class StreamCommands(ClaimTestCase):
self.assertEqual(1, blocked['channels'][0]['blocked'])
self.assertTrue(blocked['channels'][0]['channel']['short_url'].startswith('lbry://@filtering#'))
# same search, but details omitted by 'no_totals'
last_result = result
result = await self.out(self.daemon.jsonrpc_claim_search(name='bad_content', no_totals=True))
self.assertEqual(result['items'], last_result['items'])
# search inside channel containing filtered content
result = await self.out(self.daemon.jsonrpc_claim_search(channel='@some_channel'))
filtered = result['blocked']
@ -1354,6 +1381,11 @@ class StreamCommands(ClaimTestCase):
self.assertEqual(1, filtered['channels'][0]['blocked'])
self.assertTrue(filtered['channels'][0]['channel']['short_url'].startswith('lbry://@filtering#'))
# same search, but details omitted by 'no_totals'
last_result = result
result = await self.out(self.daemon.jsonrpc_claim_search(channel='@some_channel', no_totals=True))
self.assertEqual(result['items'], last_result['items'])
# content was filtered by not_tag before censoring
result = await self.out(self.daemon.jsonrpc_claim_search(channel='@some_channel', not_tags=["good", "bad"]))
self.assertEqual(0, len(result['items']))
@ -1407,6 +1439,13 @@ class StreamCommands(ClaimTestCase):
self.assertEqual(3, filtered['channels'][0]['blocked'])
self.assertTrue(filtered['channels'][0]['channel']['short_url'].startswith('lbry://@filtering#'))
# same search, but details omitted by 'no_totals'
last_result = result
result = await self.out(
self.daemon.jsonrpc_claim_search(any_tags=['bad-stuff'], order_by=['height'], no_totals=True)
)
self.assertEqual(result['items'], last_result['items'])
# filtered channel should still resolve
result = await self.resolve('lbry://@bad_channel')
self.assertEqual(bad_channel_id, result['claim_id'])

View file

@ -80,7 +80,6 @@ class ReconnectTests(IntegrationTestCase):
self.assertFalse(self.ledger.network.is_connected)
await self.ledger.resolve([], ['derp'])
self.assertEqual(50002, self.ledger.network.client.server[1])
await node2.stop(True)
async def test_direct_sync(self):
await self.ledger.stop()

View file

@ -18,8 +18,7 @@ from lbry.schema.tags import clean_tags
from lbry.schema.result import Outputs, Censor
from lbry.wallet import Ledger, RegTestLedger
from .common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS, INDEXED_LANGUAGES
from .full_text_search import FTS_ORDER_BY
from lbry.wallet.server.db.common import CLAIM_TYPES, STREAM_TYPES, COMMON_TAGS, INDEXED_LANGUAGES
class SQLiteOperationalError(apsw.Error):
@ -94,10 +93,10 @@ class ReaderState:
self.db.setprogresshandler(interruptor, 100)
def get_resolve_censor(self) -> Censor:
return Censor(self.blocked_streams, self.blocked_channels)
return Censor(Censor.RESOLVE)
def get_search_censor(self, limit_claims_per_channel: int) -> Censor:
return Censor(self.filtered_streams, self.filtered_channels, limit_claims_per_channel)
return Censor(Censor.SEARCH)
ctx: ContextVar[Optional[ReaderState]] = ContextVar('ctx')
@ -342,12 +341,7 @@ def claims_query(cols, for_count=False, **constraints) -> Tuple[str, Dict]:
_apply_constraints_for_array_attributes(constraints, 'language', lambda _: _, for_count)
_apply_constraints_for_array_attributes(constraints, 'location', lambda _: _, for_count)
if 'text' in constraints:
constraints["search"] = constraints.pop("text")
constraints["order_by"] = FTS_ORDER_BY
select = f"SELECT {cols} FROM search JOIN claim ON (search.rowid=claim.rowid)"
else:
select = f"SELECT {cols} FROM claim"
select = f"SELECT {cols} FROM claim"
if not for_count:
select += " LEFT JOIN claimtrie USING (claim_hash)"
return query(select, **constraints)
@ -372,7 +366,7 @@ def count_claims(**constraints) -> int:
constraints.pop('offset', None)
constraints.pop('limit', None)
constraints.pop('order_by', None)
count = select_claims(Censor(), 'count(*) as row_count', for_count=True, **constraints)
count = select_claims(Censor(Censor.SEARCH), 'count(*) as row_count', for_count=True, **constraints)
return count[0]['row_count']

View file

@ -8,12 +8,16 @@ from typing import List, Tuple
from lbry.wallet.constants import COIN, NULL_HASH32
from lbry.schema.claim import Claim
from lbry.schema.result import Censor
from lbry.wallet.server.db import reader, writer
from lbry.wallet.server.db import writer
from lbry.wallet.server.coin import LBCRegTest
from lbry.wallet.server.db.trending import zscore
from lbry.wallet.server.db.canonical import FindShortestID
from lbry.wallet.server.block_processor import Timer
from lbry.wallet.transaction import Transaction, Input, Output
try:
import reader
except ImportError:
from . import reader
def get_output(amount=COIN, pubkey_hash=NULL_HASH32):
@ -31,7 +35,7 @@ def get_tx():
def search(**constraints) -> List:
return reader.search_claims(Censor(), **constraints)
return reader.search_claims(Censor(Censor.SEARCH), **constraints)
def censored_search(**constraints) -> Tuple[List, Censor]:
@ -553,6 +557,7 @@ class TestTrending(TestSQLDB):
self.advance(zscore.TRENDING_WINDOW * 2, [self.get_support(problematic, 500000000)])
@unittest.skip("filtering/blocking is applied during ES sync, this needs to be ported to integration test")
class TestContentBlocking(TestSQLDB):
def test_blocking_and_filtering(self):