Shard the claim sync in scripts/sync.py across worker processes.

get_all() gains shard_num/shards_total and restricts both the count query and
the main SELECT to rows where rowid % shards_total = shard_num. main() becomes
synchronous: it parses the arguments, starts args.clients Process workers
(worker i runs run(args, i) under asyncio.run), then joins and closes each.

NOTE(review): leading whitespace inside the hunks was re-derived from Python
structure and the hunk header counts; if a context line fails to match, apply
with whitespace-insensitive matching and confirm against scripts/sync.py.

diff --git a/scripts/sync.py b/scripts/sync.py
index e075ff2c5..69124f7f4 100644
--- a/scripts/sync.py
+++ b/scripts/sync.py
@@ -1,6 +1,7 @@
 import argparse
 import asyncio
 from collections import namedtuple
+from multiprocessing import Process
 
 import apsw
 from elasticsearch import AsyncElasticsearch
@@ -11,14 +12,14 @@ from lbry.wallet.server.db.elastic_search import extract_doc, SearchIndex
 INDEX = 'claims'
 
 
-async def get_all(db):
+async def get_all(db, shard_num, shards_total):
     def exec_factory(cursor, statement, bindings):
         tpl = namedtuple('row', (d[0] for d in cursor.getdescription()))
         cursor.setrowtrace(lambda cursor, row: tpl(*row))
         return True
 
     db.setexectrace(exec_factory)
-    total = db.execute("select count(*) as total from claim;").fetchone()[0]
+    total = db.execute(f"select count(*) as total from claim where rowid % {shards_total} = {shard_num};").fetchone()[0]
     for num, claim in enumerate(db.execute(f"""
 SELECT claimtrie.claim_hash as is_controlling,
        claimtrie.last_take_over_height,
@@ -26,6 +27,7 @@ SELECT claimtrie.claim_hash as is_controlling,
        (select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages,
        claim.*
 FROM claim LEFT JOIN claimtrie USING (claim_hash)
+WHERE claim.rowid % {shards_total} = {shard_num}
 """)):
         claim = dict(claim._asdict())
         claim['censor_type'] = 0
@@ -43,18 +45,30 @@ async def consume(producer):
         await es.close()
 
 
-async def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("db_path", type=str)
-    parser.add_argument("-c", "--clients", type=int, default=16)
-    args = parser.parse_args()
+async def run(args, shard):
     db = apsw.Connection(args.db_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
     index = SearchIndex('')
     await index.start()
     await index.stop()
-    producer = get_all(db.cursor())
-    await asyncio.gather(*(consume(producer) for _ in range(args.clients)))
+    await consume(get_all(db.cursor(), shard, args.clients))
+
+def __run(args, shard):
+    asyncio.run(run(args, shard))
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("db_path", type=str)
+    parser.add_argument("-c", "--clients", type=int, default=16)
+    args = parser.parse_args()
+    processes = []
+    for i in range(args.clients):
+        processes.append(Process(target=__run, args=(args, i)))
+        processes[-1].start()
+    for process in processes:
+        process.join()
+        process.close()
 
 
 if __name__ == '__main__':
-    asyncio.run(main())
+    main()