lbry-sdk/lbry/wallet/server/db/elasticsearch/sync.py
Jack Robison 888d47f88b
fix test
2022-02-22 13:10:08 -05:00

324 lines
14 KiB
Python

import os
import time
import signal
import json
import typing
from collections import defaultdict
import asyncio
import logging
from elasticsearch import AsyncElasticsearch, NotFoundError
from elasticsearch.helpers import async_streaming_bulk
from lbry.schema.result import Censor
from lbry.wallet.server.db.elasticsearch.search import IndexVersionMismatch
from lbry.wallet.server.db.elasticsearch.constants import ALL_FIELDS, INDEX_DEFAULT_SETTINGS
from lbry.wallet.server.db.elasticsearch.common import expand_query
from lbry.wallet.server.db.elasticsearch.notifier import ElasticNotifierProtocol
from lbry.wallet.server.db.elasticsearch.fast_ar_trending import FAST_AR_TRENDING_SCRIPT
from lbry.wallet.server.chain_reader import BlockchainReader
from lbry.wallet.server.db.revertable import RevertableOp
from lbry.wallet.server.db import DB_PREFIXES
# Root logger: getLogger() is called without a name. The ElasticWriter methods in this
# file log through self.log rather than through this module-level name.
log = logging.getLogger()
# Blockchain reader that mirrors claim changes into an Elasticsearch index: it collects
# touched/deleted claims and trending notifications per block (advance/unwind), streams
# them to Elasticsearch (poll_for_changes) and notifies connected listeners of the new height.
class ElasticWriter(BlockchainReader):
# Index schema version: stored through set_index_version() when the index is created and
# compared against the stored value in start_index().
VERSION = 1
def __init__(self, env):
    """
    Set up configuration and the per-batch bookkeeping of the writer.

    No Elasticsearch client is created here: `start_index()` builds it, waits
    for the cluster, and creates or validates the index.

    :param env: server environment. `elastic_host`, `elastic_port` and `db_dir`
        are read from it directly; `es_index_prefix` is read through `self.env`
        (assumes the base initializer stores `env` there).
    """
    super().__init__(env, 'lbry-elastic-writer', thread_workers=1, thread_prefix='lbry-elastic-writer')
    self._task = None
    self.index = self.env.es_index_prefix + 'claims'
    self._elastic_host = env.elastic_host
    self._elastic_port = env.elastic_port
    # passed as the `timeout` of the Elasticsearch client built in start_index()
    self.sync_timeout = 1800
    # Deliberately left unset: start_index() returns early when a client already
    # exists, so building one here would skip the cluster health wait, the index
    # creation and the index version check on first start.
    self.sync_client = None
    # file where the last indexed height / block hash are persisted
    self._es_info_path = os.path.join(env.db_dir, 'es_info')
    self._last_wrote_height = 0
    self._last_wrote_block_hash = None

    # pending work accumulated by advance()/unwind() and flushed by poll_for_changes()
    self._touched_claims = set()
    self._deleted_claims = set()
    self._removed_during_undo = set()
    self._trending = defaultdict(list)

    self._advanced = True
    self.synchronized = asyncio.Event()
    # protocol instances to notify after each flush
    self._listeners: typing.List[ElasticNotifierProtocol] = []
async def run_es_notifier(self, synchronized: asyncio.Event):
    """
    Serve the height-notification protocol on 127.0.0.1 until cancelled.

    :param synchronized: set once the TCP server has been created, before it
        starts serving.
    """
    def protocol_factory():
        # every connection shares the writer's listener list
        return ElasticNotifierProtocol(self._listeners)

    loop = asyncio.get_event_loop()
    server = await loop.create_server(protocol_factory, '127.0.0.1', self.env.elastic_notifier_port)
    self.log.info("ES notifier server listening on TCP localhost:%i", self.env.elastic_notifier_port)
    synchronized.set()
    async with server:
        await server.serve_forever()
def notify_es_notification_listeners(self, height: int, block_hash: bytes):
    """
    Send the given height and block hash to every registered listener.

    One info line is logged per listener that was sent the height.
    """
    for listener in self._listeners:
        listener.send_height(height, block_hash)
        self.log.info("notify listener %i", height)
def _read_es_height(self):
with open(self._es_info_path, 'r') as f:
info = json.loads(f.read())
self._last_wrote_height = int(info.get('height', 0))
self._last_wrote_block_hash = info.get('block_hash', None)
async def read_es_height(self):
    """Run the blocking `_read_es_height` on the writer's executor and wait for it to finish."""
    loop = asyncio.get_event_loop()
    pending_read = loop.run_in_executor(self._executor, self._read_es_height)
    await pending_read
def write_es_height(self, height: int, block_hash: str):
    """
    Persist the indexed height and block hash to the es_info file as JSON,
    then mirror both values on the instance.

    :param height: height to record.
    :param block_hash: block hash to record, as a string.
    """
    serialized = json.dumps({'height': height, 'block_hash': block_hash}, indent=2)
    with open(self._es_info_path, 'w') as es_info:
        es_info.write(serialized)
    self._last_wrote_height, self._last_wrote_block_hash = height, block_hash
async def get_index_version(self) -> int:
    """
    Read the schema version stored in the index template named after the index.

    :return: the template's 'version' value, or 0 when the template lookup
        raises NotFoundError.
    """
    try:
        templates = await self.sync_client.indices.get_template(self.index)
    except NotFoundError:
        return 0
    return templates[self.index]['version']
async def set_index_version(self, version):
    """
    Store `version` in an index template named after the index.

    The template's only pattern is the literal 'ignored' (presumably so the
    template serves purely as version storage). The request is sent with
    `ignore=400`.
    """
    template_body = {'version': version, 'index_patterns': ['ignored']}
    await self.sync_client.indices.put_template(self.index, body=template_body, ignore=400)
async def start_index(self) -> bool:
    """
    Connect to Elasticsearch and make sure the index exists at the expected version.

    The cluster is polled once per second until it reports at least 'yellow'
    health. The index is then created; if that request is not acknowledged
    (the request is sent with `ignore=400`, so an index that already exists does
    not raise), the stored index version is checked against `VERSION` and the
    index is refreshed.

    :return: False without doing anything when a client already exists;
        otherwise the 'acknowledged' flag of the create-index response.
    :raises IndexVersionMismatch: when the existing index has a different version.
    """
    if self.sync_client:
        return False

    # elasticsearch-py reports connection failures with its own ConnectionError,
    # which is not a subclass of the builtin of the same name. Imported here
    # because the module itself only imports AsyncElasticsearch and NotFoundError.
    from elasticsearch import ConnectionError as ElasticConnectionError

    hosts = [{'host': self._elastic_host, 'port': self._elastic_port}]
    self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout)
    while True:
        try:
            await self.sync_client.cluster.health(wait_for_status='yellow')
            break
        except (ConnectionError, ElasticConnectionError):
            # the builtin is still caught so anything that was retried before still is
            self.log.warning("Failed to connect to Elasticsearch. Waiting for it!")
            await asyncio.sleep(1)

    res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
    acked = res.get('acknowledged', False)
    if acked:
        # freshly created index: stamp it with the current schema version
        await self.set_index_version(self.VERSION)
        return acked

    index_version = await self.get_index_version()
    if index_version != self.VERSION:
        self.log.error("es search index has an incompatible version: %s vs %s", index_version, self.VERSION)
        raise IndexVersionMismatch(index_version, self.VERSION)
    await self.sync_client.indices.refresh(self.index)
    return acked
async def stop_index(self):
    """Close the Elasticsearch client if there is one, then clear the reference."""
    client = self.sync_client
    if client:
        await client.close()
    self.sync_client = None
async def delete_index(self):
    """
    Delete the index through the current client.

    :return: the client's response, or None when there is no client.
    """
    if not self.sync_client:
        return None
    return await self.sync_client.indices.delete(self.index, ignore_unavailable=True)
def update_filter_query(self, censor_type, blockdict, channels=False):
    """
    Build the update-by-query body that marks censored documents.

    :param censor_type: value written to `censor_type` by the script; it is also
        passed to `expand_query` as the string "<{censor_type}" (presumably a
        less-than constraint — confirm against `expand_query`).
    :param blockdict: mapping of blocked hash -> blocking hash; both sides are
        hex-encoded and handed to the script as its params.
    :param channels: match documents on `channel_id` instead of `claim_id`.
    :return: the query dict from `expand_query` with a painless 'script' added.
    """
    censoring_by_id = {blocked.hex(): blocker.hex() for blocked, blocker in blockdict.items()}
    key = 'channel_id' if channels else 'claim_id'
    update = expand_query(**{f'{key}__in': list(censoring_by_id), 'censor_type': f"<{censor_type}"})
    # the script looks up the censoring channel by the matched document's own id
    script_source = (
        f"ctx._source.censor_type={censor_type}; "
        f"ctx._source.censoring_channel_id=params[ctx._source.{key}];"
    )
    update['script'] = {
        "source": script_source,
        "lang": "painless",
        "params": censoring_by_id
    }
    return update
async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
    """
    Apply the filtering and blocking lists to the index.

    Each non-empty mapping triggers an update-by-query followed by an index
    refresh. Channel mappings are applied twice: once matching claim ids and
    once matching channel ids. Filtered entries (Censor.SEARCH) are applied
    before blocked entries (Censor.RESOLVE).
    """
    async def censor(*query_args):
        # one pass: run the update, then refresh so the next pass sees its result
        body = self.update_filter_query(*query_args)
        await self.sync_client.update_by_query(self.index, body=body, slices=4)
        await self.sync_client.indices.refresh(self.index)

    if filtered_streams:
        await censor(Censor.SEARCH, filtered_streams)
    if filtered_channels:
        await censor(Censor.SEARCH, filtered_channels)
        await censor(Censor.SEARCH, filtered_channels, True)
    if blocked_streams:
        await censor(Censor.RESOLVE, blocked_streams)
    if blocked_channels:
        await censor(Censor.RESOLVE, blocked_channels)
        await censor(Censor.RESOLVE, blocked_channels, True)
async def _claim_producer(self):
for deleted in self._deleted_claims:
yield {
'_index': self.index,
'_op_type': 'delete',
'_id': deleted.hex()
}
for touched in self._touched_claims:
claim = self.db.claim_producer(touched)
if claim:
yield {
'doc': {key: value for key, value in claim.items() if key in ALL_FIELDS},
'_id': claim['claim_id'],
'_index': self.index,
'_op_type': 'update',
'doc_as_upsert': True
}
for claim_hash, notifications in self._trending.items():
yield {
'_id': claim_hash.hex(),
'_index': self.index,
'_op_type': 'update',
'script': {
'lang': 'painless',
'source': FAST_AR_TRENDING_SCRIPT,
'params': {'src': {
'changes': [
{
'height': notify_height,
'prev_amount': trending_v.previous_amount / 1E8,
'new_amount': trending_v.new_amount / 1E8,
} for (notify_height, trending_v) in notifications
]
}}
},
}
def advance(self, height: int):
    """
    Collect the index changes introduced by the block at `height`.

    Trending notifications stored for the height are queued per claim. If a
    touched_or_deleted record exists for the height, its claims are merged into
    the pending touched/deleted sets, and deleted claims are dropped from the
    queued trending updates.

    The summary log line reads from the touched_or_deleted record, so it is only
    emitted when that record exists; a height with no record can therefore never
    raise AttributeError here.

    :param height: height of the block being advanced to.
    """
    super().advance(height)

    touched_or_deleted = self.db.prefix_db.touched_or_deleted.get(height)
    for k, v in self.db.prefix_db.trending_notification.iterate((height,)):
        self._trending[k.claim_hash].append((k.height, v))

    if touched_or_deleted:
        touched = touched_or_deleted.touched_claims
        deleted = touched_or_deleted.deleted_claims
        # claims queued for deletion by unwind() that this block touches again must not be deleted
        readded_after_reorg = self._removed_during_undo.intersection(touched)
        self._deleted_claims.difference_update(readded_after_reorg)
        self._touched_claims.update(touched)
        self._deleted_claims.update(deleted)
        # a claim pending deletion is not upserted
        self._touched_claims.difference_update(self._deleted_claims)
        for claim_hash in deleted:
            self._trending.pop(claim_hash, None)
        self.log.info(
            "advanced to %i, %i touched %i to delete (%i %i)", height, len(touched), len(deleted),
            len(self._touched_claims), len(self._deleted_claims)
        )
    self._advanced = True
# Revert the most recent block during a reorg and queue the index changes that follow from it.
#
# The last entries of db.tx_counts and db.headers are popped, the undo data stored for the
# reverted block is decoded op by op, and the pending touched/deleted sets are updated from it.
# NOTE(review): indentation was lost in this copy of the file, so whether the two update()
# calls after the `_touched_claims.update(...)` statement sit under `if touched_or_deleted:`
# cannot be read off here - confirm against the repository before restructuring.
def unwind(self):
self.db.tx_counts.pop()
reverted_block_hash = self.db.coin.header_hash(self.db.headers.pop())
# tx_counts has already been popped, so its length is the key used for the reverted block's undo data
packed = self.db.prefix_db.undo.get(len(self.db.tx_counts), reverted_block_hash)
touched_or_deleted = None
claims_to_delete = []
# find and apply the touched_or_deleted items in the undos for the reverted blocks
assert packed, f'missing undo information for block {len(self.db.tx_counts)}'
# each unpack() call returns one op plus the remaining packed bytes
while packed:
op, packed = RevertableOp.unpack(packed)
if op.is_delete and op.key.startswith(DB_PREFIXES.touched_or_deleted.value):
assert touched_or_deleted is None, 'only should have one match'
touched_or_deleted = self.db.prefix_db.touched_or_deleted.unpack_value(op.value)
elif op.is_delete and op.key.startswith(DB_PREFIXES.claim_to_txo.value):
v = self.db.prefix_db.claim_to_txo.unpack_value(op.value)
# presumably selects claims first created in the reverted block (root tx is the claim's own
# tx and lies beyond the remaining tx count) - verify against the db schema
if v.root_tx_num == v.tx_num and v.root_tx_num > self.db.tx_counts[-1]:
claims_to_delete.append(self.db.prefix_db.claim_to_txo.unpack_key(op.key).claim_hash)
if touched_or_deleted:
# re-index what the reverted block deleted, plus what it touched minus the claims queued for deletion
self._touched_claims.update(set(touched_or_deleted.deleted_claims).union(
touched_or_deleted.touched_claims.difference(set(claims_to_delete))))
self._deleted_claims.update(claims_to_delete)
# remembered so advance() can cancel the deletion if a later block touches the claim again
self._removed_during_undo.update(claims_to_delete)
self._advanced = True
self.log.warning("delete %i claim and upsert %i from reorg", len(self._deleted_claims), len(self._touched_claims))
# Flush pending claim changes to Elasticsearch and record the newly indexed tip.
#
# After the base class poll, the work is gated on `_advanced` and on there being pending
# touched/deleted/trending entries: the actions from `_claim_producer()` are streamed through
# `async_streaming_bulk`, the index is refreshed, the blocking/filtering lists are reloaded and
# re-applied, the es_info file is rewritten, the pending collections are cleared, `synchronized`
# is set and listeners are notified.
# NOTE(review): indentation was lost in this copy of the file, so exactly which statements after
# the bulk loop are nested under each `if` cannot be read off here - confirm against the
# repository before restructuring.
async def poll_for_changes(self):
await super().poll_for_changes()
# counters for the "Indexing block" log line: actions streamed / actions reported ok
cnt = 0
success = 0
if self._advanced:
if self._touched_claims or self._deleted_claims or self._trending:
# raise_on_error=False: failed actions are yielded with ok=False and logged below
async for ok, item in async_streaming_bulk(
self.sync_client, self._claim_producer(),
raise_on_error=False):
cnt += 1
if not ok:
self.log.warning("indexing failed for an item: %s", item)
else:
success += 1
await self.sync_client.indices.refresh(self.index)
await self.db.reload_blocking_filtering_streams()
await self.apply_filters(
self.db.blocked_streams, self.db.blocked_channels, self.db.filtered_streams,
self.db.filtered_channels
)
# the tip hash is byte-reversed before being hex-encoded for the es_info file
self.write_es_height(self.db.db_height, self.db.db_tip[::-1].hex())
self.log.info("Indexing block %i done. %i/%i successful", self._last_wrote_height, success, cnt)
self._touched_claims.clear()
self._deleted_claims.clear()
self._removed_during_undo.clear()
self._trending.clear()
self._advanced = False
self.synchronized.set()
# listeners receive the tip hash as stored in db.db_tip (not reversed)
self.notify_es_notification_listeners(self._last_wrote_height, self.db.db_tip)
@property
def last_synced_height(self) -> int:
    """Height last recorded by `write_es_height` or loaded by `_read_es_height` (0 initially)."""
    synced_height = self._last_wrote_height
    return synced_height
async def start(self):
    """
    Start the writer: run the base class start, open the database and its
    caches, start the index, read the db state, then launch the notifier server
    and the block refresh loop as cancellable background tasks.

    Each background task is handed an event and this method waits for that
    event before moving on.
    """
    await super().start()

    def launch(coro_fn, *args):
        # the coroutine receives the event as its last argument and sets it when ready
        started = asyncio.Event()
        task = asyncio.ensure_future(coro_fn(*args, started))
        self.cancellable_tasks.append(task)
        return started.wait()

    self.db.open_db()
    await self.db.initialize_caches()
    await self.start_index()
    self.last_state = self.db.read_db_state()

    await launch(self.run_es_notifier)
    await launch(self.refresh_blocks_forever)
async def stop(self, delete_index=False):
    """
    Shut the writer down.

    While holding the writer's lock, every unfinished background task is
    cancelled. Then the index is optionally deleted, the Elasticsearch client
    is closed and the executor is shut down.

    :param delete_index: delete the Elasticsearch index before closing the client.
    """
    async with self._lock:
        while self.cancellable_tasks:
            task = self.cancellable_tasks.pop()
            if task.done():
                continue
            task.cancel()
    if delete_index:
        await self.delete_index()
    await self.stop_index()
    self._executor.shutdown(wait=True)
    self._executor = None
def run(self):
    """
    Blocking entry point: start the writer on the event loop, wait for the
    shutdown event, and always run `stop()` afterwards.

    SIGINT and SIGTERM are turned into SystemExit, which (like
    KeyboardInterrupt) ends the wait without propagating.
    """
    loop = asyncio.get_event_loop()

    def raise_system_exit():
        raise SystemExit()

    try:
        for signum in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(signum, raise_system_exit)
        loop.run_until_complete(self.start())
        loop.run_until_complete(self.shutdown_event.wait())
    except (SystemExit, KeyboardInterrupt):
        pass
    finally:
        loop.run_until_complete(self.stop())