fix mempool race condition in hub db writer

Jack Robison 2022-01-26 13:43:41 -05:00
parent 7c46cc0805
commit c17544d8ef
6 changed files with 60 additions and 64 deletions


@@ -100,6 +100,7 @@ class BlockProcessor:
             self.ledger = TestNetLedger
         else:
             self.ledger = RegTestLedger

+        self.wait_for_blocks_duration = 0.1
         self._caught_up_event: Optional[asyncio.Event] = None
         self.height = 0
@@ -200,29 +201,34 @@ class BlockProcessor:
            return await asyncio.get_event_loop().run_in_executor(self._chain_executor, func, *args)
        return await asyncio.shield(run_in_thread())

-    async def check_mempool(self):
+    async def refresh_mempool(self):
         def fetch_mempool(mempool_prefix):
             return {
                 k.tx_hash: v.raw_tx for (k, v) in mempool_prefix.iterate()
             }

-        def update_mempool(mempool_prefix, to_put, to_delete):
+        def update_mempool(unsafe_commit, mempool_prefix, to_put, to_delete):
             for tx_hash, raw_tx in to_put:
                 mempool_prefix.stage_put((tx_hash,), (raw_tx,))
             for tx_hash, raw_tx in to_delete.items():
                 mempool_prefix.stage_delete((tx_hash,), (raw_tx,))
+            unsafe_commit()

-        current_mempool = await self.run_in_thread_with_lock(fetch_mempool, self.db.prefix_db.mempool_tx)
-        _to_put = []
-        for hh in await self.daemon.mempool_hashes():
-            tx_hash = bytes.fromhex(hh)[::-1]
-            if tx_hash in current_mempool:
-                current_mempool.pop(tx_hash)
-            else:
-                _to_put.append((tx_hash, bytes.fromhex(await self.daemon.getrawtransaction(hh))))
-        await self.run_in_thread_with_lock(update_mempool, self.db.prefix_db.mempool_tx, _to_put, current_mempool)
+        async with self.state_lock:
+            current_mempool = await self.run_in_thread(fetch_mempool, self.db.prefix_db.mempool_tx)
+            _to_put = []
+            for hh in await self.daemon.mempool_hashes():
+                tx_hash = bytes.fromhex(hh)[::-1]
+                if tx_hash in current_mempool:
+                    current_mempool.pop(tx_hash)
+                else:
+                    _to_put.append((tx_hash, bytes.fromhex(await self.daemon.getrawtransaction(hh))))
+            if current_mempool:
+                if bytes.fromhex(await self.daemon.getbestblockhash())[::-1] != self.coin.header_hash(self.db.headers[-1]):
+                    return
+            await self.run_in_thread(
+                update_mempool, self.db.prefix_db.unsafe_commit, self.db.prefix_db.mempool_tx, _to_put, current_mempool
+            )

     async def check_and_advance_blocks(self, raw_blocks):
         """Process the list of raw blocks passed. Detects and handles
@@ -1571,7 +1577,7 @@ class BlockProcessor:
                 await self._first_caught_up()
                 self._caught_up_event.set()
             try:
-                await asyncio.wait_for(self.blocks_event.wait(), 0.1)
+                await asyncio.wait_for(self.blocks_event.wait(), self.wait_for_blocks_duration)
             except asyncio.TimeoutError:
                 pass
             self.blocks_event.clear()
@@ -1580,8 +1586,7 @@ class BlockProcessor:
                 break
             if not blocks:
                 try:
-                    await self.check_mempool()
-                    await self.run_in_thread_with_lock(self.db.prefix_db.unsafe_commit)
+                    await self.refresh_mempool()
                 except Exception:
                     self.logger.exception("error while updating mempool txs")
                     raise
@@ -1594,7 +1599,6 @@ class BlockProcessor:
        finally:
            self._ready_to_stop.set()

    async def _first_caught_up(self):
        self.logger.info(f'caught up to height {self.height}')
        # Flush everything but with first_sync->False state.
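Note on the race being fixed: check_mempool() used to stage the mempool puts and deletes in one run_in_thread_with_lock call, while the matching unsafe_commit ran later from the catch-up loop as a second, separate lock acquisition, so block processing could interleave between staging and commit. refresh_mempool() now holds state_lock across the fetch, the diff against the daemon's mempool_hashes, and the commit, passes unsafe_commit into the worker so staging and committing happen in a single executor call, and returns without committing when getbestblockhash shows the daemon tip has moved past the local tip (the leftover entries were confirmed into a block, not evicted). A minimal sketch of this bug class, with toy names (TinyStore, racy_refresh, and atomic_refresh are illustrative, not the hub's classes):

    import asyncio

    class TinyStore:
        """Toy staged-write store standing in for the hub's prefix db."""
        def __init__(self):
            self.committed = {}
            self._staged = []

        def stage_put(self, key, value):
            self._staged.append((key, value))

        def unsafe_commit(self):
            self.committed.update(dict(self._staged))
            self._staged.clear()

    async def racy_refresh(store: TinyStore, lock: asyncio.Lock, items):
        async with lock:  # before: stage under one lock acquisition ...
            for k, v in items:
                store.stage_put(k, v)
        # ... lock released: another coroutine (e.g. block processing) can
        # run here and commit, or stage on top of, the half-applied state.
        async with lock:
            store.unsafe_commit()

    async def atomic_refresh(store: TinyStore, lock: asyncio.Lock, items):
        async with lock:  # after: stage and commit in one critical section
            for k, v in items:
                store.stage_put(k, v)
            store.unsafe_commit()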


@@ -107,6 +107,7 @@ class BlockchainReaderServer(BlockchainReader):
         self.resolve_outputs_cache = {}
         self.resolve_cache = {}
         self.notifications_to_send = []
+        self.mempool_notifications = []
         self.status_server = StatusServer()
         self.daemon = env.coin.DAEMON(env.coin, env.daemon_url)  # only needed for broadcasting txs
         self.prometheus_server: typing.Optional[PrometheusServer] = None
@@ -142,10 +143,7 @@ class BlockchainReaderServer(BlockchainReader):
     def _detect_changes(self):
         super()._detect_changes()
-        self.mempool.raw_mempool.clear()
-        self.mempool.raw_mempool.update(
-            {k.tx_hash: v.raw_tx for k, v in self.db.prefix_db.mempool_tx.iterate()}
-        )
+        self.mempool_notifications.append((self.db.fs_height, self.mempool.refresh()))

     async def poll_for_changes(self):
         await super().poll_for_changes()
@@ -158,7 +156,10 @@ class BlockchainReaderServer(BlockchainReader):
             self.log.info("reader advanced to %i", height)
         if self._es_height == self.db.db_height:
             self.synchronized.set()
-        await self.mempool.refresh_hashes(self.db.db_height)
+        if self.mempool_notifications:
+            for (height, touched) in self.mempool_notifications:
+                await self.mempool.on_mempool(set(self.mempool.touched_hashXs), touched, height)
+        self.mempool_notifications.clear()
         self.notifications_to_send.clear()

     async def receive_es_notifications(self, synchronized: asyncio.Event):
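The reader-side half of the fix: _detect_changes no longer rebuilds raw_mempool inline, and poll_for_changes no longer notifies at whatever height the db happens to be at afterwards. Each detected refresh is recorded as a (height, touched hashXs) pair and replayed once the reader has finished advancing, so a notification stays paired with the height it was computed against. A sketch of the record-then-replay pattern (ReaderNotifier and its methods are hypothetical names, not the hub's API):

    from typing import Awaitable, Callable, List, Set, Tuple

    class ReaderNotifier:
        def __init__(self, notify: Callable[[Set[bytes], int], Awaitable[None]]):
            self._notify = notify
            self._pending: List[Tuple[int, Set[bytes]]] = []

        def on_change(self, height: int, touched: Set[bytes]):
            # runs synchronously inside the change-detection step
            self._pending.append((height, touched))

        async def flush(self):
            # drain in arrival order so a session never sees heights regress
            for height, touched in self._pending:
                await self._notify(touched, height)
            self._pending.clear()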


@@ -338,3 +338,8 @@ class LBCDaemon(Daemon):
     async def getclaimsforname(self, name):
         '''Given a name, retrieves all claims matching that name.'''
         return await self._send_single('getclaimsforname', (name,))
+
+    @handles_errors
+    async def getbestblockhash(self):
+        '''Retrieves the hash of the best (tip) block.'''
+        return await self._send_single('getbestblockhash')
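refresh_mempool uses this new RPC wrapper to notice when the daemon tip moved mid-refresh. One detail worth calling out: bitcoin-style RPCs return hashes as big-endian display-order hex, while the hub compares against little-endian raw header hashes, which is why every hex hash in this commit goes through bytes.fromhex(...)[::-1]. A tiny round-trip demo (rpc_hash_to_bytes is an illustrative helper, not part of the codebase):

    def rpc_hash_to_bytes(hex_hash: str) -> bytes:
        # RPC hex (big-endian display order) -> little-endian raw bytes
        return bytes.fromhex(hex_hash)[::-1]

    h = "ab" * 32  # dummy 32-byte hash
    assert rpc_hash_to_bytes(h)[::-1].hex() == h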


@@ -24,7 +24,7 @@ from lbry.wallet.server.merkle import Merkle, MerkleCache
 from lbry.wallet.server.db.common import ResolveResult, STREAM_TYPES, CLAIM_TYPES, ExpandedResolveResult, DBError, UTXO
 from lbry.wallet.server.db.prefixes import PendingActivationValue, ClaimTakeoverValue, ClaimToTXOValue, HubDB as Prefixes
 from lbry.wallet.server.db.prefixes import ACTIVATED_CLAIM_TXO_TYPE, ACTIVATED_SUPPORT_TXO_TYPE
-from lbry.wallet.server.db.prefixes import PendingActivationKey, TXOToClaimValue, DBStatePrefixRow
+from lbry.wallet.server.db.prefixes import PendingActivationKey, TXOToClaimValue, DBStatePrefixRow, MempoolTXPrefixRow
 from lbry.wallet.transaction import OutputScript
 from lbry.schema.claim import Claim, guess_stream_type
 from lbry.wallet.ledger import Ledger, RegTestLedger, TestNetLedger
@@ -815,7 +815,7 @@ class HubDB:
         self.prefix_db = Prefixes(
             db_path, cache_mb=self._cache_MB,
             reorg_limit=self._reorg_limit, max_open_files=self._db_max_open_files,
-            unsafe_prefixes={DBStatePrefixRow.prefix}, secondary_path=secondary_path
+            unsafe_prefixes={DBStatePrefixRow.prefix, MempoolTXPrefixRow.prefix}, secondary_path=secondary_path
         )
         if secondary_path != '':
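Judging by the name and by its existing member (DBStatePrefixRow.prefix), unsafe_prefixes appears to mark column prefixes whose writes may be committed outside the normal undo/reorg bookkeeping; that reading fits mempool rows, which refresh_mempool rebuilds from the daemon rather than rolling back. A toy model under that assumption (ToyPrefixDB is illustrative only):

    class ToyPrefixDB:
        def __init__(self, unsafe_prefixes=frozenset()):
            self.unsafe_prefixes = set(unsafe_prefixes)
            self.data = {}
            self.undo_log = []  # (prefix, key, previous_value) triples

        def put(self, prefix: bytes, key: bytes, value: bytes):
            if prefix not in self.unsafe_prefixes:
                # journaled write: keep the old value so a reorg can undo it
                self.undo_log.append((prefix, key, self.data.get((prefix, key))))
            self.data[(prefix, key)] = value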


@@ -60,14 +60,12 @@ class MemPool:
         self.mempool_process_time_metric = mempool_process_time_metric
         self.session_manager: typing.Optional['LBRYSessionManager'] = None

-    async def refresh_hashes(self, height: int):
-        start = time.perf_counter()
-        new_touched = await self._process_mempool()
-        await self.on_mempool(set(self.touched_hashXs), new_touched, height)
-        duration = time.perf_counter() - start
-        self.mempool_process_time_metric.observe(duration)
-
-    async def _process_mempool(self) -> typing.Set[bytes]:  # returns list of new touched hashXs
+    def refresh(self) -> typing.Set[bytes]:  # returns the set of new touched hashXs
+        prefix_db = self._db.prefix_db
+        new_mempool = {k.tx_hash: v.raw_tx for k, v in prefix_db.mempool_tx.iterate()}
+        self.raw_mempool.clear()
+        self.raw_mempool.update(new_mempool)
         # Re-sync with the new set of hashes
         # hashXs = self.hashXs  # hashX: [tx_hash, ...]
@@ -122,15 +120,15 @@ class MemPool:
             elif prev_hash in tx_map:  # this set of changes
                 utxo = tx_map[prev_hash].out_pairs[prev_index]
             else:  # get it from the db
-                prev_tx_num = self._db.prefix_db.tx_num.get(prev_hash)
+                prev_tx_num = prefix_db.tx_num.get(prev_hash)
                 if not prev_tx_num:
                     continue
                 prev_tx_num = prev_tx_num.tx_num
-                hashX_val = self._db.prefix_db.hashX_utxo.get(tx_hash[:4], prev_tx_num, prev_index)
+                hashX_val = prefix_db.hashX_utxo.get(tx_hash[:4], prev_tx_num, prev_index)
                 if not hashX_val:
                     continue
                 hashX = hashX_val.hashX
-                utxo_value = self._db.prefix_db.utxo.get(hashX, prev_tx_num, prev_index)
+                utxo_value = prefix_db.utxo.get(hashX, prev_tx_num, prev_index)
                 utxo = (hashX, utxo_value.amount)
             # if not prev_raw:
             #     print("derp", prev_hash[::-1].hex())


@@ -921,6 +921,20 @@ class LBRYElectrumX(SessionBase):
     def sub_count(self):
         return len(self.hashX_subs)

+    async def get_hashX_status(self, hashX: bytes):
+        mempool_history = self.mempool.transaction_summaries(hashX)
+        history = ''.join(f'{hash_to_hex_str(tx_hash)}:'
+                          f'{height:d}:'
+                          for tx_hash, height in await self.session_manager.limited_history(hashX))
+        history += ''.join(f'{hash_to_hex_str(tx.hash)}:'
+                           f'{-tx.has_unconfirmed_inputs:d}:'
+                           for tx in mempool_history)
+        if history:
+            status = sha256(history.encode()).hex()
+        else:
+            status = None
+        return history, status, len(mempool_history) > 0
+
     async def send_history_notifications(self, *hashXes: typing.Iterable[bytes]):
         notifications = []
         for hashX in hashXes:
@@ -930,20 +944,8 @@ class LBRYElectrumX(SessionBase):
             else:
                 method = 'blockchain.address.subscribe'
             start = time.perf_counter()
-            db_history = await self.session_manager.limited_history(hashX)
-            mempool = self.mempool.transaction_summaries(hashX)
-            status = ''.join(f'{hash_to_hex_str(tx_hash)}:'
-                             f'{height:d}:'
-                             for tx_hash, height in db_history)
-            status += ''.join(f'{hash_to_hex_str(tx.hash)}:'
-                              f'{-tx.has_unconfirmed_inputs:d}:'
-                              for tx in mempool)
-            if status:
-                status = sha256(status.encode()).hex()
-            else:
-                status = None
-            if mempool:
+            history, status, mempool_status = await self.get_hashX_status(hashX)
+            if mempool_status:
                 self.session_manager.mempool_statuses[hashX] = status
             else:
                 self.session_manager.mempool_statuses.pop(hashX, None)
@@ -1138,22 +1140,8 @@ class LBRYElectrumX(SessionBase):
         """
         # Note history is ordered and mempool unordered in electrum-server
         # For mempool, height is -1 if it has unconfirmed inputs, otherwise 0
-        db_history = await self.session_manager.limited_history(hashX)
-        mempool = self.mempool.transaction_summaries(hashX)
-        status = ''.join(f'{hash_to_hex_str(tx_hash)}:'
-                         f'{height:d}:'
-                         for tx_hash, height in db_history)
-        status += ''.join(f'{hash_to_hex_str(tx.hash)}:'
-                          f'{-tx.has_unconfirmed_inputs:d}:'
-                          for tx in mempool)
-        if status:
-            status = sha256(status.encode()).hex()
-        else:
-            status = None
-        if mempool:
+        _, status, has_mempool_history = await self.get_hashX_status(hashX)
+        if has_mempool_history:
             self.session_manager.mempool_statuses[hashX] = status
         else:
             self.session_manager.mempool_statuses.pop(hashX, None)
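get_hashX_status deduplicates the status computation that the last two hunks previously inlined in both the notification and subscription paths. The rule, per the comment above: concatenate txid:height: for each confirmed history entry, then each mempool entry at height -1 (has unconfirmed inputs) or 0, and sha256-hash the whole string; no history at all yields a null status. A standalone restatement with plain tuples (hashx_status is illustrative, not the session's method):

    from hashlib import sha256
    from typing import List, Optional, Tuple

    def hashx_status(confirmed: List[Tuple[str, int]],
                     mempool: List[Tuple[str, bool]]) -> Optional[str]:
        history = ''.join(f'{txid}:{height:d}:' for txid, height in confirmed)
        history += ''.join(f'{txid}:{-int(has_unconf):d}:'
                           for txid, has_unconf in mempool)
        return sha256(history.encode()).hex() if history else None

    assert hashx_status([], []) is None  # empty history -> null status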