lbry-sdk/lbry/wallet/server/db/elasticsearch/sync.py

117 lines
4.6 KiB
Python
Raw Normal View History

2021-01-20 05:41:54 +01:00
import argparse
import asyncio
import logging
2021-02-16 16:52:32 +01:00
import os
2021-01-20 05:41:54 +01:00
from collections import namedtuple
2021-01-27 05:43:06 +01:00
from multiprocessing import Process
2021-01-20 05:41:54 +01:00
import sqlite3
2021-01-20 05:41:54 +01:00
from elasticsearch import AsyncElasticsearch
from elasticsearch.helpers import async_bulk
from lbry.wallet.server.env import Env
from lbry.wallet.server.coin import LBC
from lbry.wallet.server.db.elasticsearch.search import extract_doc, SearchIndex, IndexVersionMismatch
2021-01-20 05:41:54 +01:00
2021-05-12 02:38:05 +02:00
async def get_all(db, shard_num, shards_total, limit=0, index_name='claims'):
    """Yield Elasticsearch documents for the claims in one height-based shard.

    A claim belongs to shard ``shard_num`` when
    ``claim.height % shards_total == shard_num``. Each row is joined with its
    claimtrie entry, the tags/languages of the claim and of its reposted claim,
    and the reposted claim's ``has_source``/``claim_type``. It is then converted
    with ``extract_doc``. Rows are produced in descending height order.

    :param db: open ``sqlite3.Connection`` to the claims database.
    :param shard_num: index of the shard to read.
    :param shards_total: number of shards the ``claim`` table is split into.
    :param limit: maximum number of documents to yield; ``0`` (or any value
        <= 0) means no limit.
    :param index_name: forwarded to ``extract_doc`` for every document.

    NOTE: the sqlite calls are blocking and this body never awaits, so each
    ``__anext__`` runs to the next ``yield`` without giving control back to
    the event loop.
    """
    logging.info("shard %d starting", shard_num)
    total = db.execute(
        "select count(*) as total from claim where height % ? = ?;", (shards_total, shard_num)
    ).fetchone()[0]

    # sqlite3.Row supports dict(row) keyed by column name without creating a new
    # namedtuple class per row; setting it on a private cursor also leaves the
    # caller's connection row_factory untouched.
    cursor = db.cursor()
    cursor.row_factory = sqlite3.Row
    cursor.execute("""
        SELECT claimtrie.claim_hash as is_controlling,
               claimtrie.last_take_over_height,
               (select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags,
               (select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages,
               (select cr.has_source from claim cr where cr.claim_hash = claim.reposted_claim_hash) as reposted_has_source,
               (select cr.claim_type from claim cr where cr.claim_hash = claim.reposted_claim_hash) as reposted_claim_type,
               claim.*
        FROM claim LEFT JOIN claimtrie USING (claim_hash)
        WHERE claim.height % ? = ?
        ORDER BY claim.height desc
        LIMIT ?
    """, (shards_total, shard_num, limit if limit > 0 else -1))  # negative LIMIT = unbounded in SQLite
    try:
        for num, row in enumerate(cursor):
            claim = dict(row)
            # A repost is treated as having a source if the reposted claim has one.
            claim['has_source'] = bool(claim.pop('reposted_has_source') or claim['has_source'])
            claim['censor_type'] = 0
            claim['censoring_channel_hash'] = None
            # Separators must match the group_concat separators in the query above.
            claim['tags'] = claim['tags'].split(',,') if claim['tags'] else []
            claim['languages'] = claim['languages'].split(' ') if claim['languages'] else []
            if num % 10_000 == 0:
                logging.info("%d/%d", num, total)
            yield extract_doc(claim, index_name)
    finally:
        # Runs on exhaustion and also when the generator is closed early.
        cursor.close()
2021-01-20 05:41:54 +01:00
2021-05-12 02:38:05 +02:00
async def consume(producer, index_name):
    """Bulk-load every action produced by ``producer`` into Elasticsearch.

    A dedicated ``AsyncElasticsearch`` client is created for this call, using the
    host/port from ``Env(LBC)``, and it is always closed before returning. The
    ``index_name`` index is refreshed only after ``async_bulk`` completes
    without raising.

    :param producer: (async) iterable of bulk actions, passed to ``async_bulk``.
    :param index_name: name of the index to refresh once loading finishes.
    """
    settings = Env(LBC)
    logging.info("ES sync host: %s:%i", settings.elastic_host, settings.elastic_port)
    hosts = [{'host': settings.elastic_host, 'port': settings.elastic_port}]
    client = AsyncElasticsearch(hosts)
    try:
        await async_bulk(client, producer, request_timeout=120)
        await client.indices.refresh(index=index_name)
    finally:
        await client.close()
2021-05-12 02:38:05 +02:00
async def make_es_index(index=None):
    """Start the search index, recreating it once on a version mismatch.

    :param index: ``SearchIndex``-like object to use. When ``None``, one is
        built from the elastic host/port in ``Env(LBC)``.
    :return: the value returned by ``index.start()``. ``run_elastic_sync``
        treats a falsy result as "ES is already initialized".
    :raises: exceptions from ``index.start()`` other than
        ``IndexVersionMismatch`` propagate, as does a mismatch on the retry.

    On ``IndexVersionMismatch`` the existing index is deleted, the index is
    stopped, and ``start()`` is retried once. The index is always stopped
    before this coroutine finishes.
    """
    if index is None:
        # Only read the environment when we actually need to build an index.
        env = Env(LBC)
        index = SearchIndex('', elastic_host=env.elastic_host, elastic_port=env.elastic_port)
    try:
        return await index.start()
    except IndexVersionMismatch as err:
        logging.info(
            "dropping ES search index (version %s) for upgrade to version %s", err.got_version, err.expected_version
        )
        await index.delete_index()
        await index.stop()
        return await index.start()
    finally:
        # stop() is awaitable (it is awaited in the except branch above). Without
        # an await here the stop either never runs or is cut short when the event
        # loop shuts down.
        await index.stop()
2021-05-12 02:38:05 +02:00
async def run(db_path, clients, blocks, shard, index_name='claims'):
    """Index one shard of the SQLite claims database into Elasticsearch.

    :param db_path: SQLite database path (opened with ``uri=True``, so a
        ``file:`` URI is also accepted).
    :param clients: total number of shards, passed to ``get_all`` as
        ``shards_total``. It also caps the number of concurrent bulk consumers,
        which is at most 8.
    :param blocks: forwarded to ``get_all`` as ``limit``; 0 disables the limit.
    :param shard: shard index handled by this call.
    :param index_name: index the documents are built for and refreshed at the end.
    """
    db = sqlite3.connect(db_path, isolation_level=None, check_same_thread=False, uri=True)
    try:
        db.execute('pragma journal_mode=wal;')
        db.execute('pragma temp_store=memory;')
        producer = get_all(db, shard, clients, limit=blocks, index_name=index_name)
        # All consumers pull from the same async generator. get_all does not await
        # between rows, so concurrent __anext__ calls never overlap.
        await asyncio.gather(*(consume(producer, index_name=index_name) for _ in range(min(8, clients))))
    finally:
        # The connection was previously left open for the life of the process.
        db.close()
2021-01-27 05:43:06 +01:00
2021-02-12 05:10:30 +01:00
2021-01-27 05:43:06 +01:00
def __run(args, shard):
    """Child-process entry point: sync a single shard in its own event loop.

    :param args: parsed CLI namespace; ``db_path``, ``clients`` and ``blocks``
        are read from it.
    :param shard: shard index this process is responsible for.
    """
    shard_sync = run(args.db_path, args.clients, args.blocks, shard)
    asyncio.run(shard_sync)
2021-01-27 05:43:06 +01:00
2021-02-12 05:10:30 +01:00
def run_elastic_sync():
    """CLI entry point (``lbry-hub-elastic-sync``): sync SQLite claims into Elasticsearch.

    Starts ``--clients`` processes. Each one runs ``__run`` for its own shard
    index, and the parent waits for all of them. Unless ``--force`` is given,
    it returns early when ``db_path`` does not exist on disk or when
    ``make_es_index`` returns a falsy value. Shard processes that exit with a
    non-zero code are logged as errors.
    """
    logging.basicConfig(level=logging.INFO)
    logging.getLogger('aiohttp').setLevel(logging.WARNING)
    logging.getLogger('elasticsearch').setLevel(logging.WARNING)
    logging.info('lbry.server starting')
    parser = argparse.ArgumentParser(prog="lbry-hub-elastic-sync")
    parser.add_argument("db_path", type=str)
    parser.add_argument("-c", "--clients", type=int, default=16)
    parser.add_argument("-b", "--blocks", type=int, default=0)
    parser.add_argument("-f", "--force", default=False, action='store_true')
    args = parser.parse_args()
    processes = []

    if not args.force and not os.path.exists(args.db_path):
        logging.info("DB path doesnt exist")
        return
    if not args.force and not asyncio.run(make_es_index()):
        logging.info("ES is already initialized")
        return
    for i in range(args.clients):
        processes.append(Process(target=__run, args=(args, i)))
        processes[-1].start()
    for shard, process in enumerate(processes):
        process.join()
        # exitcode must be read before close(), after which the Process object
        # raises ValueError. Without this check a crashed shard went unnoticed.
        if process.exitcode != 0:
            logging.error("shard %d sync process exited with code %s", shard, process.exitcode)
        process.close()