diff --git a/lbry/wallet/server/db/elasticsearch/search.py b/lbry/wallet/server/db/elasticsearch/search.py index 6e79cb39e..302512082 100644 --- a/lbry/wallet/server/db/elasticsearch/search.py +++ b/lbry/wallet/server/db/elasticsearch/search.py @@ -162,6 +162,9 @@ class SearchIndex: await self.sync_client.update_by_query( self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4) await self.sync_client.indices.refresh(self.index) + self.clear_caches() + + def clear_caches(self): self.search_cache.clear() self.short_id_cache.clear() self.claim_cache.clear() diff --git a/lbry/wallet/server/db/elasticsearch/sync.py b/lbry/wallet/server/db/elasticsearch/sync.py index 2ca7644f6..8e8134e09 100644 --- a/lbry/wallet/server/db/elasticsearch/sync.py +++ b/lbry/wallet/server/db/elasticsearch/sync.py @@ -12,10 +12,8 @@ from lbry.wallet.server.env import Env from lbry.wallet.server.coin import LBC from lbry.wallet.server.db.elasticsearch.search import extract_doc, SearchIndex, IndexVersionMismatch -INDEX = 'claims' - -async def get_all(db, shard_num, shards_total, limit=0): +async def get_all(db, shard_num, shards_total, limit=0, index_name='claims'): logging.info("shard %d starting", shard_num) def exec_factory(cursor, statement, bindings): tpl = namedtuple('row', (d[0] for d in cursor.getdescription())) @@ -44,25 +42,26 @@ ORDER BY claim.height desc claim['languages'] = claim['languages'].split(' ') if claim['languages'] else [] if num % 10_000 == 0: logging.info("%d/%d", num, total) - yield extract_doc(claim, INDEX) + yield extract_doc(claim, index_name) if 0 < limit <= num: break -async def consume(producer): +async def consume(producer, index_name): env = Env(LBC) logging.info("ES sync host: %s:%i", env.elastic_host, env.elastic_port) es = AsyncElasticsearch([{'host': env.elastic_host, 'port': env.elastic_port}]) try: await async_bulk(es, producer, request_timeout=120) - await es.indices.refresh(index=INDEX) + print(await es.indices.refresh(index=index_name)) finally: await es.close() -async def make_es_index(): +async def make_es_index(index=None): env = Env(LBC) - index = SearchIndex('', elastic_host=env.elastic_host, elastic_port=env.elastic_port) + if index is None: + index = SearchIndex('', elastic_host=env.elastic_host, elastic_port=env.elastic_port) try: return await index.start() @@ -76,21 +75,21 @@ async def make_es_index(): index.stop() -async def run(args, shard): +async def run(db_path, clients, blocks, shard, index_name='claims'): def itsbusy(*_): logging.info("shard %d: db is busy, retry", shard) return True - db = apsw.Connection(args.db_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI) + db = apsw.Connection(db_path, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI) db.setbusyhandler(itsbusy) db.cursor().execute('pragma journal_mode=wal;') db.cursor().execute('pragma temp_store=memory;') - producer = get_all(db.cursor(), shard, args.clients, limit=args.blocks) - await asyncio.gather(*(consume(producer) for _ in range(min(8, args.clients)))) + producer = get_all(db.cursor(), shard, clients, limit=blocks, index_name=index_name) + await asyncio.gather(*(consume(producer, index_name=index_name) for _ in range(min(8, clients)))) def __run(args, shard): - asyncio.run(run(args, shard)) + asyncio.run(run(args.db_path, args.clients, args.blocks, shard)) def run_elastic_sync(): diff --git a/tests/integration/blockchain/test_wallet_server_sessions.py b/tests/integration/blockchain/test_wallet_server_sessions.py index b0a770558..f4b3db185 100644 --- a/tests/integration/blockchain/test_wallet_server_sessions.py +++ b/tests/integration/blockchain/test_wallet_server_sessions.py @@ -4,6 +4,7 @@ import lbry import lbry.wallet from lbry.error import ServerPaymentFeeAboveMaxAllowedError from lbry.wallet.network import ClientSession +from lbry.wallet.server.db.elasticsearch.sync import run as run_sync, make_es_index from lbry.wallet.server.session import LBRYElectrumX from lbry.testcase import IntegrationTestCase, CommandTestCase from lbry.wallet.orchstr8.node import SPVNode @@ -13,9 +14,6 @@ class TestSessions(IntegrationTestCase): """ Tests that server cleans up stale connections after session timeout and client times out too. """ - - LEDGER = lbry.wallet - async def test_session_bloat_from_socket_timeout(self): await self.conductor.stop_spv() await self.ledger.stop() @@ -87,3 +85,22 @@ class TestUsagePayment(CommandTestCase): self.assertIsNotNone(await self.blockchain.get_raw_transaction(tx.id)) # verify its broadcasted self.assertEqual(tx.outputs[0].amount, 100000000) self.assertEqual(tx.outputs[0].get_address(self.ledger), address) + + +class TestESSync(CommandTestCase): + VERBOSITY = 'DEBUG' + async def test_es_sync_utility(self): + for i in range(10): + await self.stream_create(f"stream{i}", bid='0.001') + await self.generate(1) + self.assertEqual(10, len(await self.claim_search(order_by=['height']))) + db = self.conductor.spv_node.server.db + await db.search_index.delete_index() + db.search_index.clear_caches() + self.assertEqual(0, len(await self.claim_search(order_by=['height']))) + await db.search_index.stop() + self.assertTrue(await make_es_index(db.search_index)) + await db.search_index.start() + db.search_index.clear_caches() + await run_sync(db.sql._db_path, 1, 0, 0, index_name=db.search_index.index) + self.assertEqual(10, len(await self.claim_search(order_by=['height'])))