fix mempool race condition in hub db writer

This commit is contained in:
Jack Robison 2022-01-26 13:43:41 -05:00
parent 7c46cc0805
commit c17544d8ef
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
6 changed files with 60 additions and 64 deletions

View file

@ -100,6 +100,7 @@ class BlockProcessor:
self.ledger = TestNetLedger self.ledger = TestNetLedger
else: else:
self.ledger = RegTestLedger self.ledger = RegTestLedger
self.wait_for_blocks_duration = 0.1
self._caught_up_event: Optional[asyncio.Event] = None self._caught_up_event: Optional[asyncio.Event] = None
self.height = 0 self.height = 0
@ -200,29 +201,34 @@ class BlockProcessor:
return await asyncio.get_event_loop().run_in_executor(self._chain_executor, func, *args) return await asyncio.get_event_loop().run_in_executor(self._chain_executor, func, *args)
return await asyncio.shield(run_in_thread()) return await asyncio.shield(run_in_thread())
async def check_mempool(self): async def refresh_mempool(self):
def fetch_mempool(mempool_prefix): def fetch_mempool(mempool_prefix):
return { return {
k.tx_hash: v.raw_tx for (k, v) in mempool_prefix.iterate() k.tx_hash: v.raw_tx for (k, v) in mempool_prefix.iterate()
} }
def update_mempool(mempool_prefix, to_put, to_delete): def update_mempool(unsafe_commit, mempool_prefix, to_put, to_delete):
for tx_hash, raw_tx in to_put: for tx_hash, raw_tx in to_put:
mempool_prefix.stage_put((tx_hash,), (raw_tx,)) mempool_prefix.stage_put((tx_hash,), (raw_tx,))
for tx_hash, raw_tx in to_delete.items(): for tx_hash, raw_tx in to_delete.items():
mempool_prefix.stage_delete((tx_hash,), (raw_tx,)) mempool_prefix.stage_delete((tx_hash,), (raw_tx,))
unsafe_commit()
current_mempool = await self.run_in_thread_with_lock(fetch_mempool, self.db.prefix_db.mempool_tx) async with self.state_lock:
current_mempool = await self.run_in_thread(fetch_mempool, self.db.prefix_db.mempool_tx)
_to_put = [] _to_put = []
for hh in await self.daemon.mempool_hashes(): for hh in await self.daemon.mempool_hashes():
tx_hash = bytes.fromhex(hh)[::-1] tx_hash = bytes.fromhex(hh)[::-1]
if tx_hash in current_mempool: if tx_hash in current_mempool:
current_mempool.pop(tx_hash) current_mempool.pop(tx_hash)
else: else:
_to_put.append((tx_hash, bytes.fromhex(await self.daemon.getrawtransaction(hh)))) _to_put.append((tx_hash, bytes.fromhex(await self.daemon.getrawtransaction(hh))))
if current_mempool:
await self.run_in_thread_with_lock(update_mempool, self.db.prefix_db.mempool_tx, _to_put, current_mempool) if bytes.fromhex(await self.daemon.getbestblockhash())[::-1] != self.coin.header_hash(self.db.headers[-1]):
return
await self.run_in_thread(
update_mempool, self.db.prefix_db.unsafe_commit, self.db.prefix_db.mempool_tx, _to_put, current_mempool
)
async def check_and_advance_blocks(self, raw_blocks): async def check_and_advance_blocks(self, raw_blocks):
"""Process the list of raw blocks passed. Detects and handles """Process the list of raw blocks passed. Detects and handles
@ -1571,7 +1577,7 @@ class BlockProcessor:
await self._first_caught_up() await self._first_caught_up()
self._caught_up_event.set() self._caught_up_event.set()
try: try:
await asyncio.wait_for(self.blocks_event.wait(), 0.1) await asyncio.wait_for(self.blocks_event.wait(), self.wait_for_blocks_duration)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
self.blocks_event.clear() self.blocks_event.clear()
@ -1580,8 +1586,7 @@ class BlockProcessor:
break break
if not blocks: if not blocks:
try: try:
await self.check_mempool() await self.refresh_mempool()
await self.run_in_thread_with_lock(self.db.prefix_db.unsafe_commit)
except Exception: except Exception:
self.logger.exception("error while updating mempool txs") self.logger.exception("error while updating mempool txs")
raise raise
@ -1594,7 +1599,6 @@ class BlockProcessor:
finally: finally:
self._ready_to_stop.set() self._ready_to_stop.set()
async def _first_caught_up(self): async def _first_caught_up(self):
self.logger.info(f'caught up to height {self.height}') self.logger.info(f'caught up to height {self.height}')
# Flush everything but with first_sync->False state. # Flush everything but with first_sync->False state.

View file

@ -107,6 +107,7 @@ class BlockchainReaderServer(BlockchainReader):
self.resolve_outputs_cache = {} self.resolve_outputs_cache = {}
self.resolve_cache = {} self.resolve_cache = {}
self.notifications_to_send = [] self.notifications_to_send = []
self.mempool_notifications = []
self.status_server = StatusServer() self.status_server = StatusServer()
self.daemon = env.coin.DAEMON(env.coin, env.daemon_url) # only needed for broadcasting txs self.daemon = env.coin.DAEMON(env.coin, env.daemon_url) # only needed for broadcasting txs
self.prometheus_server: typing.Optional[PrometheusServer] = None self.prometheus_server: typing.Optional[PrometheusServer] = None
@ -142,10 +143,7 @@ class BlockchainReaderServer(BlockchainReader):
def _detect_changes(self): def _detect_changes(self):
super()._detect_changes() super()._detect_changes()
self.mempool.raw_mempool.clear() self.mempool_notifications.append((self.db.fs_height, self.mempool.refresh()))
self.mempool.raw_mempool.update(
{k.tx_hash: v.raw_tx for k, v in self.db.prefix_db.mempool_tx.iterate()}
)
async def poll_for_changes(self): async def poll_for_changes(self):
await super().poll_for_changes() await super().poll_for_changes()
@ -158,7 +156,10 @@ class BlockchainReaderServer(BlockchainReader):
self.log.info("reader advanced to %i", height) self.log.info("reader advanced to %i", height)
if self._es_height == self.db.db_height: if self._es_height == self.db.db_height:
self.synchronized.set() self.synchronized.set()
await self.mempool.refresh_hashes(self.db.db_height) if self.mempool_notifications:
for (height, touched) in self.mempool_notifications:
await self.mempool.on_mempool(set(self.mempool.touched_hashXs), touched, height)
self.mempool_notifications.clear()
self.notifications_to_send.clear() self.notifications_to_send.clear()
async def receive_es_notifications(self, synchronized: asyncio.Event): async def receive_es_notifications(self, synchronized: asyncio.Event):

View file

@ -338,3 +338,8 @@ class LBCDaemon(Daemon):
async def getclaimsforname(self, name): async def getclaimsforname(self, name):
'''Given a name, retrieves all claims matching that name.''' '''Given a name, retrieves all claims matching that name.'''
return await self._send_single('getclaimsforname', (name,)) return await self._send_single('getclaimsforname', (name,))
@handles_errors
async def getbestblockhash(self):
'''Given a name, retrieves all claims matching that name.'''
return await self._send_single('getbestblockhash')

View file

@ -24,7 +24,7 @@ from lbry.wallet.server.merkle import Merkle, MerkleCache
from lbry.wallet.server.db.common import ResolveResult, STREAM_TYPES, CLAIM_TYPES, ExpandedResolveResult, DBError, UTXO from lbry.wallet.server.db.common import ResolveResult, STREAM_TYPES, CLAIM_TYPES, ExpandedResolveResult, DBError, UTXO
from lbry.wallet.server.db.prefixes import PendingActivationValue, ClaimTakeoverValue, ClaimToTXOValue, HubDB as Prefixes from lbry.wallet.server.db.prefixes import PendingActivationValue, ClaimTakeoverValue, ClaimToTXOValue, HubDB as Prefixes
from lbry.wallet.server.db.prefixes import ACTIVATED_CLAIM_TXO_TYPE, ACTIVATED_SUPPORT_TXO_TYPE from lbry.wallet.server.db.prefixes import ACTIVATED_CLAIM_TXO_TYPE, ACTIVATED_SUPPORT_TXO_TYPE
from lbry.wallet.server.db.prefixes import PendingActivationKey, TXOToClaimValue, DBStatePrefixRow from lbry.wallet.server.db.prefixes import PendingActivationKey, TXOToClaimValue, DBStatePrefixRow, MempoolTXPrefixRow
from lbry.wallet.transaction import OutputScript from lbry.wallet.transaction import OutputScript
from lbry.schema.claim import Claim, guess_stream_type from lbry.schema.claim import Claim, guess_stream_type
from lbry.wallet.ledger import Ledger, RegTestLedger, TestNetLedger from lbry.wallet.ledger import Ledger, RegTestLedger, TestNetLedger
@ -815,7 +815,7 @@ class HubDB:
self.prefix_db = Prefixes( self.prefix_db = Prefixes(
db_path, cache_mb=self._cache_MB, db_path, cache_mb=self._cache_MB,
reorg_limit=self._reorg_limit, max_open_files=self._db_max_open_files, reorg_limit=self._reorg_limit, max_open_files=self._db_max_open_files,
unsafe_prefixes={DBStatePrefixRow.prefix}, secondary_path=secondary_path unsafe_prefixes={DBStatePrefixRow.prefix, MempoolTXPrefixRow.prefix}, secondary_path=secondary_path
) )
if secondary_path != '': if secondary_path != '':

View file

@ -60,14 +60,12 @@ class MemPool:
self.mempool_process_time_metric = mempool_process_time_metric self.mempool_process_time_metric = mempool_process_time_metric
self.session_manager: typing.Optional['LBRYSessionManager'] = None self.session_manager: typing.Optional['LBRYSessionManager'] = None
async def refresh_hashes(self, height: int): def refresh(self) -> typing.Set[bytes]: # returns list of new touched hashXs
start = time.perf_counter() prefix_db = self._db.prefix_db
new_touched = await self._process_mempool() new_mempool = {k.tx_hash: v.raw_tx for k, v in prefix_db.mempool_tx.iterate()}
await self.on_mempool(set(self.touched_hashXs), new_touched, height) self.raw_mempool.clear()
duration = time.perf_counter() - start self.raw_mempool.update(new_mempool)
self.mempool_process_time_metric.observe(duration)
async def _process_mempool(self) -> typing.Set[bytes]: # returns list of new touched hashXs
# Re-sync with the new set of hashes # Re-sync with the new set of hashes
# hashXs = self.hashXs # hashX: [tx_hash, ...] # hashXs = self.hashXs # hashX: [tx_hash, ...]
@ -122,15 +120,15 @@ class MemPool:
elif prev_hash in tx_map: # this set of changes elif prev_hash in tx_map: # this set of changes
utxo = tx_map[prev_hash].out_pairs[prev_index] utxo = tx_map[prev_hash].out_pairs[prev_index]
else: # get it from the db else: # get it from the db
prev_tx_num = self._db.prefix_db.tx_num.get(prev_hash) prev_tx_num = prefix_db.tx_num.get(prev_hash)
if not prev_tx_num: if not prev_tx_num:
continue continue
prev_tx_num = prev_tx_num.tx_num prev_tx_num = prev_tx_num.tx_num
hashX_val = self._db.prefix_db.hashX_utxo.get(tx_hash[:4], prev_tx_num, prev_index) hashX_val = prefix_db.hashX_utxo.get(tx_hash[:4], prev_tx_num, prev_index)
if not hashX_val: if not hashX_val:
continue continue
hashX = hashX_val.hashX hashX = hashX_val.hashX
utxo_value = self._db.prefix_db.utxo.get(hashX, prev_tx_num, prev_index) utxo_value = prefix_db.utxo.get(hashX, prev_tx_num, prev_index)
utxo = (hashX, utxo_value.amount) utxo = (hashX, utxo_value.amount)
# if not prev_raw: # if not prev_raw:
# print("derp", prev_hash[::-1].hex()) # print("derp", prev_hash[::-1].hex())

View file

@ -921,6 +921,20 @@ class LBRYElectrumX(SessionBase):
def sub_count(self): def sub_count(self):
return len(self.hashX_subs) return len(self.hashX_subs)
async def get_hashX_status(self, hashX: bytes):
mempool_history = self.mempool.transaction_summaries(hashX)
history = ''.join(f'{hash_to_hex_str(tx_hash)}:'
f'{height:d}:'
for tx_hash, height in await self.session_manager.limited_history(hashX))
history += ''.join(f'{hash_to_hex_str(tx.hash)}:'
f'{-tx.has_unconfirmed_inputs:d}:'
for tx in mempool_history)
if history:
status = sha256(history.encode()).hex()
else:
status = None
return history, status, len(mempool_history) > 0
async def send_history_notifications(self, *hashXes: typing.Iterable[bytes]): async def send_history_notifications(self, *hashXes: typing.Iterable[bytes]):
notifications = [] notifications = []
for hashX in hashXes: for hashX in hashXes:
@ -930,20 +944,8 @@ class LBRYElectrumX(SessionBase):
else: else:
method = 'blockchain.address.subscribe' method = 'blockchain.address.subscribe'
start = time.perf_counter() start = time.perf_counter()
db_history = await self.session_manager.limited_history(hashX) history, status, mempool_status = await self.get_hashX_status(hashX)
mempool = self.mempool.transaction_summaries(hashX) if mempool_status:
status = ''.join(f'{hash_to_hex_str(tx_hash)}:'
f'{height:d}:'
for tx_hash, height in db_history)
status += ''.join(f'{hash_to_hex_str(tx.hash)}:'
f'{-tx.has_unconfirmed_inputs:d}:'
for tx in mempool)
if status:
status = sha256(status.encode()).hex()
else:
status = None
if mempool:
self.session_manager.mempool_statuses[hashX] = status self.session_manager.mempool_statuses[hashX] = status
else: else:
self.session_manager.mempool_statuses.pop(hashX, None) self.session_manager.mempool_statuses.pop(hashX, None)
@ -1138,22 +1140,8 @@ class LBRYElectrumX(SessionBase):
""" """
# Note history is ordered and mempool unordered in electrum-server # Note history is ordered and mempool unordered in electrum-server
# For mempool, height is -1 if it has unconfirmed inputs, otherwise 0 # For mempool, height is -1 if it has unconfirmed inputs, otherwise 0
_, status, has_mempool_history = await self.get_hashX_status(hashX)
db_history = await self.session_manager.limited_history(hashX) if has_mempool_history:
mempool = self.mempool.transaction_summaries(hashX)
status = ''.join(f'{hash_to_hex_str(tx_hash)}:'
f'{height:d}:'
for tx_hash, height in db_history)
status += ''.join(f'{hash_to_hex_str(tx.hash)}:'
f'{-tx.has_unconfirmed_inputs:d}:'
for tx in mempool)
if status:
status = sha256(status.encode()).hex()
else:
status = None
if mempool:
self.session_manager.mempool_statuses[hashX] = status self.session_manager.mempool_statuses[hashX] = status
else: else:
self.session_manager.mempool_statuses.pop(hashX, None) self.session_manager.mempool_statuses.pop(hashX, None)