diff --git a/lbry/wallet/server/db/elasticsearch/sync.py b/lbry/wallet/server/db/elasticsearch/sync.py index 1552c6900..83eba3ee6 100644 --- a/lbry/wallet/server/db/elasticsearch/sync.py +++ b/lbry/wallet/server/db/elasticsearch/sync.py @@ -10,61 +10,19 @@ from elasticsearch import AsyncElasticsearch from elasticsearch.helpers import async_bulk from lbry.wallet.server.env import Env from lbry.wallet.server.coin import LBC +from lbry.wallet.server.leveldb import LevelDB +from lbry.wallet.server.db.prefixes import Prefixes from lbry.wallet.server.db.elasticsearch.search import extract_doc, SearchIndex, IndexVersionMismatch -async def get_all(db, shard_num, shards_total, limit=0, index_name='claims'): - logging.info("shard %d starting", shard_num) - - def namedtuple_factory(cursor, row): - Row = namedtuple('Row', (d[0] for d in cursor.description)) - return Row(*row) - db.row_factory = namedtuple_factory - total = db.execute(f"select count(*) as total from claim where height % {shards_total} = {shard_num};").fetchone()[0] - for num, claim in enumerate(db.execute(f""" - SELECT claimtrie.claim_hash as is_controlling, - claimtrie.last_take_over_height, - (select group_concat(tag, ',,') from tag where tag.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as tags, - (select group_concat(language, ' ') from language where language.claim_hash in (claim.claim_hash, claim.reposted_claim_hash)) as languages, - cr.has_source as reposted_has_source, - cr.claim_type as reposted_claim_type, - cr.stream_type as reposted_stream_type, - cr.media_type as reposted_media_type, - cr.duration as reposted_duration, - cr.fee_amount as reposted_fee_amount, - cr.fee_currency as reposted_fee_currency, - claim.* - FROM claim LEFT JOIN claimtrie USING (claim_hash) LEFT JOIN claim cr ON cr.claim_hash=claim.reposted_claim_hash - WHERE claim.height % {shards_total} = {shard_num} - ORDER BY claim.height desc -""")): - claim = dict(claim._asdict()) - claim['has_source'] = bool(claim.pop('reposted_has_source') or claim['has_source']) - claim['stream_type'] = claim.pop('reposted_stream_type') or claim['stream_type'] - claim['media_type'] = claim.pop('reposted_media_type') or claim['media_type'] - claim['fee_amount'] = claim.pop('reposted_fee_amount') or claim['fee_amount'] - claim['fee_currency'] = claim.pop('reposted_fee_currency') or claim['fee_currency'] - claim['duration'] = claim.pop('reposted_duration') or claim['duration'] - claim['censor_type'] = 0 - claim['censoring_channel_id'] = None - claim['tags'] = claim['tags'].split(',,') if claim['tags'] else [] - claim['languages'] = claim['languages'].split(' ') if claim['languages'] else [] - if num % 10_000 == 0: - logging.info("%d/%d", num, total) - yield extract_doc(claim, index_name) - if 0 < limit <= num: - break - - -async def consume(producer, index_name): +async def get_all_claims(index_name='claims', db=None): env = Env(LBC) - logging.info("ES sync host: %s:%i", env.elastic_host, env.elastic_port) - es = AsyncElasticsearch([{'host': env.elastic_host, 'port': env.elastic_port}]) - try: - await async_bulk(es, producer, request_timeout=120) - await es.indices.refresh(index=index_name) - finally: - await es.close() + need_open = db is None + db = db or LevelDB(env) + if need_open: + await db.open_dbs() + for claim in db.all_claims_producer(): + yield extract_doc(claim, index_name) async def make_es_index(index=None): @@ -85,16 +43,19 @@ async def make_es_index(index=None): index.stop() -async def run(db_path, clients, blocks, shard, index_name='claims'): - db = sqlite3.connect(db_path, isolation_level=None, check_same_thread=False, uri=True) - db.execute('pragma journal_mode=wal;') - db.execute('pragma temp_store=memory;') - producer = get_all(db, shard, clients, limit=blocks, index_name=index_name) - await asyncio.gather(*(consume(producer, index_name=index_name) for _ in range(min(8, clients)))) +async def run_sync(index_name='claims', db=None): + env = Env(LBC) + logging.info("ES sync host: %s:%i", env.elastic_host, env.elastic_port) + es = AsyncElasticsearch([{'host': env.elastic_host, 'port': env.elastic_port}]) + try: + await async_bulk(es, get_all_claims(index_name=index_name, db=db), request_timeout=120) + await es.indices.refresh(index=index_name) + finally: + await es.close() def __run(args, shard): - asyncio.run(run(args.db_path, args.clients, args.blocks, shard)) + asyncio.run(run_sync()) def run_elastic_sync(): @@ -104,23 +65,17 @@ def run_elastic_sync(): logging.info('lbry.server starting') parser = argparse.ArgumentParser(prog="lbry-hub-elastic-sync") - parser.add_argument("db_path", type=str) + # parser.add_argument("db_path", type=str) parser.add_argument("-c", "--clients", type=int, default=16) parser.add_argument("-b", "--blocks", type=int, default=0) parser.add_argument("-f", "--force", default=False, action='store_true') args = parser.parse_args() - processes = [] - if not args.force and not os.path.exists(args.db_path): - logging.info("DB path doesnt exist") - return + # if not args.force and not os.path.exists(args.db_path): + # logging.info("DB path doesnt exist") + # return if not args.force and not asyncio.run(make_es_index()): logging.info("ES is already initialized") return - for i in range(args.clients): - processes.append(Process(target=__run, args=(args, i))) - processes[-1].start() - for process in processes: - process.join() - process.close() + asyncio.run(run_sync()) diff --git a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py index 8cee0b4f1..541b6e72c 100644 --- a/tests/integration/blockchain/test_claim_commands.py +++ b/tests/integration/blockchain/test_claim_commands.py @@ -182,6 +182,9 @@ class ClaimSearchCommand(ClaimTestCase): claims = [three, two, signed] await self.assertFindsClaims(claims, channel_ids=[self.channel_id]) await self.assertFindsClaims(claims, channel=f"@abc#{self.channel_id}") + await self.assertFindsClaims(claims, channel=f"@abc#{self.channel_id}", valid_channel_signature=True) + await self.assertFindsClaims(claims, channel=f"@abc#{self.channel_id}", has_channel_signature=True, valid_channel_signature=True) + await self.assertFindsClaims([], channel=f"@abc#{self.channel_id}", has_channel_signature=True, invalid_channel_signature=True) # fixme await self.assertFindsClaims([], channel=f"@inexistent") await self.assertFindsClaims([three, two, signed2, signed], channel_ids=[channel_id2, self.channel_id]) await self.channel_abandon(claim_id=self.channel_id) diff --git a/tests/integration/blockchain/test_wallet_server_sessions.py b/tests/integration/blockchain/test_wallet_server_sessions.py index 0b079bbdc..31fa5273b 100644 --- a/tests/integration/blockchain/test_wallet_server_sessions.py +++ b/tests/integration/blockchain/test_wallet_server_sessions.py @@ -5,7 +5,7 @@ import lbry.wallet from lbry.error import ServerPaymentFeeAboveMaxAllowedError from lbry.wallet.network import ClientSession from lbry.wallet.rpc import RPCError -from lbry.wallet.server.db.elasticsearch.sync import run as run_sync, make_es_index +from lbry.wallet.server.db.elasticsearch.sync import run_sync, make_es_index from lbry.wallet.server.session import LBRYElectrumX from lbry.testcase import IntegrationTestCase, CommandTestCase from lbry.wallet.orchstr8.node import SPVNode @@ -104,8 +104,11 @@ class TestESSync(CommandTestCase): async def resync(): await db.search_index.start() db.search_index.clear_caches() - await run_sync(db.sql._db_path, 1, 0, 0, index_name=db.search_index.index) + await run_sync(index_name=db.search_index.index, db=db) self.assertEqual(10, len(await self.claim_search(order_by=['height']))) + + self.assertEqual(0, len(await self.claim_search(order_by=['height']))) + await resync() # this time we will test a migration from unversioned to v1