working mempool

This commit is contained in:
Lex Berezhny 2020-09-25 10:22:06 -04:00
parent 9ce8910b42
commit 0a06810f36
4 changed files with 88 additions and 27 deletions

View file

@ -196,6 +196,7 @@ class Lbrycrd:
async def close_session(self): async def close_session(self):
if self._session is not None: if self._session is not None:
await self._session.close() await self._session.close()
self._session = None
async def start(self, *args): async def start(self, *args):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()

View file

@ -1,7 +1,6 @@
import asyncio
import logging import logging
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import Tuple from typing import Tuple, List
from sqlalchemy import table, text, func, union, between from sqlalchemy import table, text, func, union, between
from sqlalchemy.future import select from sqlalchemy.future import select
@ -186,22 +185,25 @@ def clear_mempool(p: ProgressContext):
@event_emitter("blockchain.sync.mempool.main", "txs") @event_emitter("blockchain.sync.mempool.main", "txs")
def sync_mempool(p: ProgressContext): def sync_mempool(p: ProgressContext) -> List[str]:
chain = get_or_initialize_lbrycrd(p.ctx) chain = get_or_initialize_lbrycrd(p.ctx)
mempool = chain.sync_run(chain.get_raw_mempool()) mempool = chain.sync_run(chain.get_raw_mempool())
current = [hexlify(r['tx_hash'][::-1]) for r in p.ctx.fetchall( current = [hexlify(r['tx_hash'][::-1]).decode() for r in p.ctx.fetchall(
select(TX.c.tx_hash).where(TX.c.height < 0) select(TX.c.tx_hash).where(TX.c.height < 0)
)] )]
loader = p.ctx.get_bulk_loader() loader = p.ctx.get_bulk_loader()
added = []
for txid in mempool: for txid in mempool:
if txid not in current: if txid not in current:
raw_tx = chain.sync_run(chain.get_raw_transaction(txid)) raw_tx = chain.sync_run(chain.get_raw_transaction(txid))
loader.add_transaction( loader.add_transaction(
None, Transaction(unhexlify(raw_tx), height=-1) None, Transaction(unhexlify(raw_tx), height=-1)
) )
added.append(txid)
if p.ctx.stop_event.is_set(): if p.ctx.stop_event.is_set():
return return
loader.flush(TX) loader.flush(TX)
return added
@event_emitter("blockchain.sync.filters.generate", "blocks", throttle=100) @event_emitter("blockchain.sync.filters.generate", "blocks", throttle=100)
@ -305,8 +307,8 @@ def rewind(height: int, p: ProgressContext):
def delete_all_the_things(height: int, p: ProgressContext): def delete_all_the_things(height: int, p: ProgressContext):
def constrain(col): def constrain(col):
if height >= 0: if height == -1:
return col >= height return col == -1
return col <= height return col <= height
deletes = [ deletes = [

View file

@ -49,6 +49,7 @@ class BlockchainSync(Sync):
self.tx_hash_event = asyncio.Event() self.tx_hash_event = asyncio.Event()
self._on_mempool_controller = EventController() self._on_mempool_controller = EventController()
self.on_mempool = self._on_mempool_controller.stream self.on_mempool = self._on_mempool_controller.stream
self.mempool = []
async def wait_for_chain_ready(self): async def wait_for_chain_ready(self):
while True: while True:
@ -352,10 +353,17 @@ class BlockchainSync(Sync):
await self._on_block_controller.add(BlockEvent(blocks_added[-1])) await self._on_block_controller.add(BlockEvent(blocks_added[-1]))
async def sync_mempool(self): async def sync_mempool(self):
await self.db.run(block_phase.sync_mempool) added = await self.db.run(block_phase.sync_mempool)
await self.sync_spends([-1]) await self.sync_spends([-1])
await self.db.run(claim_phase.claims_insert, [-2, 0], True, self.CLAIM_FLUSH_SIZE) await self.db.run(claim_phase.claims_insert, [-1, -1], True, self.CLAIM_FLUSH_SIZE)
await self.db.run(claim_phase.claims_update, [-1, -1])
await self.db.run(claim_phase.claims_vacuum) await self.db.run(claim_phase.claims_vacuum)
self.mempool.extend(added)
await self._on_mempool_controller.add(added)
async def clear_mempool(self):
self.mempool.clear()
await self.db.run(block_phase.clear_mempool)
async def advance_loop(self): async def advance_loop(self):
while True: while True:
@ -366,7 +374,7 @@ class BlockchainSync(Sync):
], return_when=asyncio.FIRST_COMPLETED) ], return_when=asyncio.FIRST_COMPLETED)
if self.block_hash_event.is_set(): if self.block_hash_event.is_set():
self.block_hash_event.clear() self.block_hash_event.clear()
await self.db.run(block_phase.clear_mempool) await self.clear_mempool()
await self.advance() await self.advance()
self.tx_hash_event.clear() self.tx_hash_event.clear()
await self.sync_mempool() await self.sync_mempool()

View file

@ -41,15 +41,15 @@ class BasicBlockchainTestCase(AsyncioTestCase):
async def make_db(self, chain): async def make_db(self, chain):
db_driver = os.environ.get('TEST_DB', 'sqlite') db_driver = os.environ.get('TEST_DB', 'sqlite')
if db_driver == 'sqlite': if db_driver == 'sqlite':
db = Database.temp_sqlite_regtest(chain.ledger.conf.lbrycrd_dir) db = Database.temp_sqlite_regtest(chain.ledger.conf)
elif db_driver.startswith('postgres') or db_driver.startswith('psycopg'): elif db_driver.startswith('postgres'):
db_driver = 'postgresql' db_driver = 'postgresql'
db_name = f'lbry_test_chain' db_name = f'lbry_test_chain'
db_connection = 'postgres:postgres@localhost:5432' db_connection = 'postgres:postgres@localhost:5432'
meta_db = Database.from_url(f'postgresql://{db_connection}/postgres') meta_db = Database.from_url(f'postgresql://{db_connection}/postgres')
await meta_db.drop(db_name) await meta_db.drop(db_name)
await meta_db.create(db_name) await meta_db.create(db_name)
db = Database.temp_from_url_regtest(f'postgresql://{db_connection}/{db_name}', chain.ledger.conf.lbrycrd_dir) db = Database.temp_from_url_regtest(f'postgresql://{db_connection}/{db_name}', chain.ledger.conf)
else: else:
raise RuntimeError(f"Unsupported database driver: {db_driver}") raise RuntimeError(f"Unsupported database driver: {db_driver}")
self.addCleanup(remove_tree, db.ledger.conf.data_dir) self.addCleanup(remove_tree, db.ledger.conf.data_dir)
@ -185,7 +185,6 @@ class SyncingBlockchainTestCase(BasicBlockchainTestCase):
self.find_claim_txo(tx).sign(sign) self.find_claim_txo(tx).sign(sign)
tx._reset() tx._reset()
signed = await self.chain.sign_raw_transaction_with_wallet(hexlify(tx.raw).decode()) signed = await self.chain.sign_raw_transaction_with_wallet(hexlify(tx.raw).decode())
tx = Transaction(unhexlify(signed['hex']))
return await self.chain.send_raw_transaction(signed['hex']) return await self.chain.send_raw_transaction(signed['hex'])
async def abandon_claim(self, txid: str) -> str: async def abandon_claim(self, txid: str) -> str:
@ -276,7 +275,8 @@ class SyncingBlockchainTestCase(BasicBlockchainTestCase):
accepted = [] accepted = []
for txo in await self.db.search_claims( for txo in await self.db.search_claims(
activation_height__gt=self.current_height, activation_height__gt=self.current_height,
expiration_height__gt=self.current_height): expiration_height__gt=self.current_height,
order_by=['^activation_height']):
accepted.append(( accepted.append((
txo.claim.stream.title, dewies_to_lbc(txo.amount), txo.claim.stream.title, dewies_to_lbc(txo.amount),
dewies_to_lbc(txo.meta['staked_amount']), txo.meta['activation_height'] dewies_to_lbc(txo.meta['staked_amount']), txo.meta['activation_height']
@ -303,7 +303,7 @@ class TestLbrycrdAPIs(AsyncioTestCase):
async def test_zmq(self): async def test_zmq(self):
chain = Lbrycrd.temp_regtest() chain = Lbrycrd.temp_regtest()
chain.ledger.conf.set(lbrycrd_zmq_blocks='') chain.ledger.conf.set(lbrycrd_zmq='')
await chain.ensure() await chain.ensure()
self.addCleanup(chain.stop) self.addCleanup(chain.stop)
@ -313,20 +313,20 @@ class TestLbrycrdAPIs(AsyncioTestCase):
await chain.ensure_subscribable() await chain.ensure_subscribable()
await chain.stop() await chain.stop()
# lbrycrdr started with zmq, ensure_subscribable updates lbrycrd_zmq_blocks config # lbrycrdr started with zmq, ensure_subscribable updates lbrycrd_zmq config
await chain.start('-zmqpubhashblock=tcp://127.0.0.1:29005') await chain.start('-zmqpubhashblock=tcp://127.0.0.1:29005')
self.assertEqual(chain.ledger.conf.lbrycrd_zmq_blocks, '') self.assertEqual(chain.ledger.conf.lbrycrd_zmq, '')
await chain.ensure_subscribable() await chain.ensure_subscribable()
self.assertEqual(chain.ledger.conf.lbrycrd_zmq_blocks, 'tcp://127.0.0.1:29005') self.assertEqual(chain.ledger.conf.lbrycrd_zmq, 'tcp://127.0.0.1:29005')
await chain.stop() await chain.stop()
# lbrycrdr started with zmq, ensure_subscribable does not override lbrycrd_zmq_blocks config # lbrycrdr started with zmq, ensure_subscribable does not override lbrycrd_zmq config
chain.ledger.conf.set(lbrycrd_zmq_blocks='') chain.ledger.conf.set(lbrycrd_zmq='')
await chain.start('-zmqpubhashblock=tcp://127.0.0.1:29005') await chain.start('-zmqpubhashblock=tcp://127.0.0.1:29005')
self.assertEqual(chain.ledger.conf.lbrycrd_zmq_blocks, '') self.assertEqual(chain.ledger.conf.lbrycrd_zmq, '')
chain.ledger.conf.set(lbrycrd_zmq_blocks='tcp://external-ip:29005') chain.ledger.conf.set(lbrycrd_zmq='tcp://external-ip:29005')
await chain.ensure_subscribable() await chain.ensure_subscribable()
self.assertEqual(chain.ledger.conf.lbrycrd_zmq_blocks, 'tcp://external-ip:29005') self.assertEqual(chain.ledger.conf.lbrycrd_zmq, 'tcp://external-ip:29005')
async def test_block_event(self): async def test_block_event(self):
chain = Lbrycrd.temp_regtest() chain = Lbrycrd.temp_regtest()
@ -337,9 +337,9 @@ class TestLbrycrdAPIs(AsyncioTestCase):
msgs = [] msgs = []
await chain.subscribe() await chain.subscribe()
chain.on_block.listen(lambda e: msgs.append(e['msg'])) chain.on_block_hash.listen(lambda e: msgs.append(e['msg']))
res = await chain.generate(5) res = await chain.generate(5)
await chain.on_block.where(lambda e: e['msg'] == 4) await chain.on_block_hash.where(lambda e: e['msg'] == 4)
self.assertEqual([0, 1, 2, 3, 4], msgs) self.assertEqual([0, 1, 2, 3, 4], msgs)
self.assertEqual(5, len(res)) self.assertEqual(5, len(res))
@ -350,7 +350,7 @@ class TestLbrycrdAPIs(AsyncioTestCase):
await chain.subscribe() await chain.subscribe()
res = await chain.generate(3) res = await chain.generate(3)
await chain.on_block.where(lambda e: e['msg'] == 9) await chain.on_block_hash.where(lambda e: e['msg'] == 9)
self.assertEqual(3, len(res)) self.assertEqual(3, len(res))
self.assertEqual([ self.assertEqual([
0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
@ -613,12 +613,13 @@ class TestMultiBlockFileSyncing(BasicBlockchainTestCase):
class TestGeneralBlockchainSync(SyncingBlockchainTestCase): class TestGeneralBlockchainSync(SyncingBlockchainTestCase):
async def test_sync_waits_for_lbrycrd_to_start_but_exits_if_zmq_misconfigured(self): async def test_sync_waits_for_lbrycrd_to_start_but_exits_if_zmq_misconfigured(self):
await self.sync.stop() await self.sync.stop()
await self.chain.stop() await self.chain.stop()
sync_start = asyncio.create_task(self.sync.start()) sync_start = asyncio.create_task(self.sync.start())
await asyncio.sleep(0) await asyncio.sleep(0)
self.chain.ledger.conf.set(lbrycrd_zmq_blocks='') self.chain.ledger.conf.set(lbrycrd_zmq='')
await self.chain.start() await self.chain.start()
with self.assertRaises(LbrycrdEventSubscriptionError): with self.assertRaises(LbrycrdEventSubscriptionError):
await asyncio.wait_for(sync_start, timeout=10) await asyncio.wait_for(sync_start, timeout=10)
@ -643,6 +644,55 @@ class TestGeneralBlockchainSync(SyncingBlockchainTestCase):
self.assertEqual([110], [b.height for b in blocks]) self.assertEqual([110], [b.height for b in blocks])
self.assertEqual(110, self.current_height) self.assertEqual(110, self.current_height)
async def test_mempool(self):
search = self.db.search_claims
# send claim to mempool
self.assertEqual(self.sync.mempool, [])
txid = await self.create_claim('foo', '0.01')
await self.sync.on_mempool.where(lambda e: e == [txid])
self.assertEqual(self.sync.mempool, [txid])
# mempool claims are searchable
claims = await search()
self.assertEqual(1, len(claims))
self.assertEqual(claims[0].amount, lbc_to_dewies('0.01'))
self.assertEqual(claims[0].tx_ref.id, txid)
self.assertEqual(claims[0].tx_ref.height, -1)
# move claim into a block
await self.generate(1)
# still searchable with updated values
claims = await search()
self.assertEqual(1, len(claims))
self.assertEqual(claims[0].amount, lbc_to_dewies('0.01'))
self.assertEqual(claims[0].tx_ref.id, txid)
self.assertEqual(claims[0].tx_ref.height, 102)
# send claim update to mempool
self.assertEqual(self.sync.mempool, [])
txid = await self.update_claim(claims[0], '0.02')
await self.sync.on_mempool.where(lambda e: e == [txid])
self.assertEqual(self.sync.mempool, [txid])
# update takes affect from mempool
claims = await search()
self.assertEqual(1, len(claims))
self.assertEqual(claims[0].amount, lbc_to_dewies('0.02'))
self.assertEqual(claims[0].tx_ref.id, txid)
self.assertEqual(claims[0].tx_ref.height, -1)
# move claim into a block
await self.generate(1)
# update makes it into a block
claims = await search()
self.assertEqual(1, len(claims))
self.assertEqual(claims[0].amount, lbc_to_dewies('0.02'))
self.assertEqual(claims[0].tx_ref.id, txid)
self.assertEqual(claims[0].tx_ref.height, 103)
async def test_claim_create_update_and_delete(self): async def test_claim_create_update_and_delete(self):
search = self.db.search_claims search = self.db.search_claims
await self.create_claim('foo', '0.01') await self.create_claim('foo', '0.01')