make sync parallel

This commit is contained in:
Victor Shyba 2021-01-27 01:43:06 -03:00
parent e2441ea3e7
commit 7295b7e329

View file

@ -1,6 +1,7 @@
import argparse import argparse
import asyncio import asyncio
from collections import namedtuple from collections import namedtuple
from multiprocessing import Process
import apsw import apsw
from elasticsearch import AsyncElasticsearch from elasticsearch import AsyncElasticsearch
@ -11,14 +12,14 @@ from lbry.wallet.server.db.elastic_search import extract_doc, SearchIndex
INDEX = 'claims' INDEX = 'claims'
async def get_all(db): async def get_all(db, shard_num, shards_total):
def exec_factory(cursor, statement, bindings): def exec_factory(cursor, statement, bindings):
tpl = namedtuple('row', (d[0] for d in cursor.getdescription())) tpl = namedtuple('row', (d[0] for d in cursor.getdescription()))
cursor.setrowtrace(lambda cursor, row: tpl(*row)) cursor.setrowtrace(lambda cursor, row: tpl(*row))
return True return True
db.setexectrace(exec_factory) db.setexectrace(exec_factory)
total = db.execute("select count(*) as total from claim;").fetchone()[0] total = db.execute(f"select count(*) as total from claim where rowid % {shards_total} = {shard_num};").fetchone()[0]
for num, claim in enumerate(db.execute(f""" for num, claim in enumerate(db.execute(f"""
SELECT claimtrie.claim_hash as is_controlling, SELECT claimtrie.claim_hash as is_controlling,
claimtrie.last_take_over_height, claimtrie.last_take_over_height,
@ -26,6 +27,7 @@ SELECT claimtrie.claim_hash as is_controlling,
(select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages, (select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages,
claim.* claim.*
FROM claim LEFT JOIN claimtrie USING (claim_hash) FROM claim LEFT JOIN claimtrie USING (claim_hash)
WHERE claim.rowid % {shards_total} = {shard_num}
""")): """)):
claim = dict(claim._asdict()) claim = dict(claim._asdict())
claim['censor_type'] = 0 claim['censor_type'] = 0
@ -43,18 +45,30 @@ async def consume(producer):
await es.close() await es.close()
async def main(): async def run(args, shard):
parser = argparse.ArgumentParser()
parser.add_argument("db_path", type=str)
parser.add_argument("-c", "--clients", type=int, default=16)
args = parser.parse_args()
db = apsw.Connection(args.db_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI) db = apsw.Connection(args.db_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
index = SearchIndex('') index = SearchIndex('')
await index.start() await index.start()
await index.stop() await index.stop()
producer = get_all(db.cursor()) await consume(get_all(db.cursor(), shard, args.clients))
await asyncio.gather(*(consume(producer) for _ in range(args.clients)))
def __run(args, shard):
    """Synchronous process entry point: drive the ``run`` coroutine for one shard.

    Creates a fresh asyncio event loop via ``asyncio.run`` and blocks until
    ``run(args, shard)`` has completed.

    Args:
        args: parsed command-line arguments, forwarded unchanged to ``run``.
        shard: index of the shard this invocation handles.
    """
    shard_job = run(args, shard)
    asyncio.run(shard_job)
def main():
    """Parse the command line and run the sync with one worker process per shard.

    Command line:
        db_path: path handed to the workers through ``args``.
        -c / --clients: number of worker processes to start (default 16);
            worker ``i`` is started as ``__run(args, i)``.

    All workers are started first and then joined in order.

    Raises:
        SystemExit: with an explanatory message (exit status 1) when at least
            one worker process ended with a non-zero exit code, so a partially
            failed sync is not reported as a success.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("db_path", type=str)
    parser.add_argument("-c", "--clients", type=int, default=16)
    args = parser.parse_args()

    processes = []
    for shard in range(args.clients):
        process = Process(target=__run, args=(args, shard))
        process.start()
        processes.append(process)

    failed = []
    for shard, process in enumerate(processes):
        process.join()
        # exitcode must be read before close(): a closed Process object
        # raises ValueError on attribute access.
        if process.exitcode != 0:
            failed.append((shard, process.exitcode))
        process.close()

    if failed:
        raise SystemExit(f"sync failed for (shard, exit code): {failed}")
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(main()) main()