lbry-sdk/lbry/wallet/server/db/elastic_sync.py

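"""Bulk-load the wallet server's SQLite claims database into the
Elasticsearch 'claims' index.

Usage (a sketch inferred from the argument parser below):

    run_elastic_sync /path/to/claims.db -c 16

Each of the ``--clients`` worker processes syncs the shard of claims whose
``height % clients`` equals its shard number.
"""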
import argparse
import asyncio
import sys
from collections import namedtuple
from multiprocessing import Process

import apsw
from elasticsearch import AsyncElasticsearch
from elasticsearch.helpers import async_bulk

from lbry.wallet.server.db.elastic_search import extract_doc, SearchIndex

INDEX = 'claims'


async def get_all(db, shard_num, shards_total):
    def exec_factory(cursor, statement, bindings):
        # Return rows as namedtuples keyed by the statement's column names.
        tpl = namedtuple('row', (d[0] for d in cursor.getdescription()))
        cursor.setrowtrace(lambda cursor, row: tpl(*row))
        return True

    db.setexectrace(exec_factory)
    total = db.execute(f"select count(*) as total from claim where height % {shards_total} = {shard_num};").fetchone()[0]
    for num, claim in enumerate(db.execute(f"""
        SELECT claimtrie.claim_hash as is_controlling,
               claimtrie.last_take_over_height,
               (select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags,
               (select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages,
               claim.*
        FROM claim LEFT JOIN claimtrie USING (claim_hash)
        WHERE claim.height % {shards_total} = {shard_num}
    """)):
        claim = dict(claim._asdict())
        # No censorship information is available at this point, so every
        # document starts out uncensored.
        claim['censor_type'] = 0
        claim['censoring_channel_hash'] = None
        claim['tags'] = claim['tags'].split(',,') if claim['tags'] else []
        claim['languages'] = claim['languages'].split(' ') if claim['languages'] else []
        if num % 10_000 == 0:
            print(num, total)
        yield extract_doc(claim, INDEX)


async def consume(producer):
    # With no arguments this connects to the local Elasticsearch node.
    es = AsyncElasticsearch()
    try:
        await async_bulk(es, producer, request_timeout=120)
        await es.indices.refresh(index=INDEX)
    finally:
        await es.close()


async def make_es_index():
    es = AsyncElasticsearch()
    try:
        if await es.indices.exists(index=INDEX):
            print("already synced ES")
            return 1
        # Start and stop a SearchIndex so the index and its mappings get
        # created, then report success.
        index = SearchIndex('')
        await index.start()
        await index.stop()
        return 0
    finally:
        await es.close()


async def run(args, shard):
    db = apsw.Connection(args.db_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
    db.cursor().execute('pragma journal_mode=wal;')
    db.cursor().execute('pragma temp_store=memory;')
    index = SearchIndex('')
    await index.start()
    await index.stop()

    # A single producer generator is shared by up to 8 bulk consumers, which
    # split this shard's documents between them on a first-come basis.
    producer = get_all(db.cursor(), shard, args.clients)
    await asyncio.gather(*(consume(producer) for _ in range(min(8, args.clients))))


def __run(args, shard):
    asyncio.run(run(args, shard))


def __make_index():
    # Exit with make_es_index()'s result so the parent can read it from
    # Process.exitcode; a plain return value would be discarded and the
    # exit code would always be 0.
    sys.exit(asyncio.run(make_es_index()))


def run_elastic_sync():
    parser = argparse.ArgumentParser()
    parser.add_argument("db_path", type=str)
    parser.add_argument("-c", "--clients", type=int, default=16)
    args = parser.parse_args()
    processes = []

    init_proc = Process(target=__make_index, args=())
    init_proc.start()
    init_proc.join()
    exitcode = init_proc.exitcode
    init_proc.close()
    if exitcode:
        print("ES is already initialized")
        return
    print("bulk-loading ES")

    for i in range(args.clients):
        processes.append(Process(target=__run, args=(args, i)))
        processes[-1].start()
    for process in processes:
        process.join()
        process.close()
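

# A minimal sketch of invoking the sync directly; in the packaged SDK this
# function is assumed to be wired up as a console-script entry point instead.
if __name__ == '__main__':
    run_elastic_sync()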