Merge pull request #2995 from lbryio/batch-sync

Batched sync for wallet transactions
Lex Berezhny 2020-07-14 23:33:13 -04:00 committed by GitHub
commit 516a8c5ee5
6 changed files with 253 additions and 102 deletions
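
In short: instead of fetching each missing transaction with its own blockchain.transaction.get round trip, update_history now collects the txids it is missing, takes the per-transaction cache locks up front, and pulls everything through a new blockchain.transaction.get_batch RPC, saving and announcing each batch as it arrives. A rough sketch of that flow (illustrative only, not code from this diff; the network object and the chunk size are stand-ins):

    async def fetch_missing(network, txids):
        # one get_batch round trip per chunk instead of one get per txid
        results = {}
        for chunk in (txids[i:i + 100] for i in range(0, len(txids), 100)):
            results.update(await network.get_transaction_batch(chunk))
        return results  # {txid: (raw_tx_hex, merkle)}, merkle carries 'block_height'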

@@ -3,7 +3,6 @@ import copy
 import time
 import asyncio
 import logging
-from io import StringIO
 from datetime import datetime
 from functools import partial
 from operator import itemgetter
@@ -164,6 +163,7 @@ class Ledger(metaclass=LedgerRegistry):
         self._utxo_reservation_lock = asyncio.Lock()
         self._header_processing_lock = asyncio.Lock()
         self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
+        self._history_lock = asyncio.Lock()
         self.coin_selection_strategy = None
         self._known_addresses_out_of_sync = set()
@@ -489,10 +489,10 @@ class Ledger(metaclass=LedgerRegistry):
             address, remote_status = update
             self._update_tasks.add(self.update_history(address, remote_status))

-    async def update_history(self, address, remote_status, address_manager: AddressManager = None):
+    async def update_history(self, address, remote_status, address_manager: AddressManager = None,
+                             reattempt_update: bool = True):
         async with self._address_update_locks[address]:
             self._known_addresses_out_of_sync.discard(address)

             local_status, local_history = await self.get_local_status_and_history(address)
             if local_status == remote_status:
@@ -502,60 +502,94 @@
             remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
             we_need = set(remote_history) - set(local_history)
             if not we_need:
+                remote_missing = set(local_history) - set(remote_history)
+                if remote_missing:
+                    log.warning(
+                        "%i transactions we have for %s are not in the remote address history",
+                        len(remote_missing), address
+                    )
                 return True

-            cache_tasks: List[asyncio.Task[Transaction]] = []
-            synced_history = StringIO()
-            loop = asyncio.get_running_loop()
-            for i, (txid, remote_height) in enumerate(remote_history):
-                if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
-                    synced_history.write(f'{txid}:{remote_height}:')
-                else:
-                    check_local = (txid, remote_height) not in we_need
-                    cache_tasks.append(loop.create_task(
-                        self.cache_transaction(txid, remote_height, check_local=check_local)
-                    ))
-
+            acquire_lock_tasks = []
             synced_txs = []
-            for task in cache_tasks:
-                tx = await task
-
-                check_db_for_txos = []
-                for txi in tx.inputs:
-                    if txi.txo_ref.txo is not None:
-                        continue
-                    cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
-                    if cache_item is not None:
-                        if cache_item.tx is None:
-                            await cache_item.has_tx.wait()
-                        assert cache_item.tx is not None
-                        txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
-                    else:
-                        check_db_for_txos.append(txi.txo_ref.id)
-
-                referenced_txos = {} if not check_db_for_txos else {
-                    txo.id: txo for txo in await self.db.get_txos(
-                        txoid__in=check_db_for_txos, order_by='txo.txoid', no_tx=True
-                    )
-                }
-
-                for txi in tx.inputs:
-                    if txi.txo_ref.txo is not None:
-                        continue
-                    referenced_txo = referenced_txos.get(txi.txo_ref.id)
-                    if referenced_txo is not None:
-                        txi.txo_ref = referenced_txo.ref
-
-                synced_history.write(f'{tx.id}:{tx.height}:')
+            to_request = {}
+            pending_synced_history = {}
+            updated_cached_items = {}
+            already_synced = set()
+
+            already_synced_offset = 0
+            for i, (txid, remote_height) in enumerate(remote_history):
+                if i == already_synced_offset and i < len(local_history) and local_history[i] == (txid, remote_height):
+                    pending_synced_history[i] = f'{txid}:{remote_height}:'
+                    already_synced.add((txid, remote_height))
+                    already_synced_offset += 1
+                    continue
+                cache_item = self._tx_cache.get(txid)
+                if cache_item is None:
+                    cache_item = TransactionCacheItem()
+                    self._tx_cache[txid] = cache_item
+
+            for txid, remote_height in remote_history[already_synced_offset:]:
+                cache_item = self._tx_cache[txid]
+                acquire_lock_tasks.append(asyncio.create_task(cache_item.lock.acquire()))
+
+            if acquire_lock_tasks:
+                await asyncio.wait(acquire_lock_tasks)
+
+            tx_indexes = {}
+
+            for i, (txid, remote_height) in enumerate(remote_history):
+                tx_indexes[txid] = i
+                if (txid, remote_height) in already_synced:
+                    continue
+                cache_item = self._tx_cache.get(txid)
+                cache_item.pending_verifications += 1
+                updated_cached_items[txid] = cache_item
+                assert cache_item is not None, 'cache item is none'
+                assert cache_item.lock.locked(), 'cache lock is not held?'
+                # tx = cache_item.tx
+                # if cache_item.tx is not None and \
+                #         cache_item.tx.height >= remote_height and \
+                #         (cache_item.tx.is_verified or remote_height < 1):
+                #     synced_txs.append(cache_item.tx)  # cached tx is already up-to-date
+                #     pending_synced_history[i] = f'{tx.id}:{tx.height}:'
+                #     continue
+                to_request[i] = (txid, remote_height)
+
+            log.debug(
+                "request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs),
+                len(remote_history), address
+            )
+            requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address)
+            for tx in requested_txes:
+                pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:"
                 synced_txs.append(tx)

+            assert len(pending_synced_history) == len(remote_history), \
+                f"{len(pending_synced_history)} vs {len(remote_history)}"
+            synced_history = ""
+            for remote_i, i in zip(range(len(remote_history)), sorted(pending_synced_history.keys())):
+                assert i == remote_i, f"{i} vs {remote_i}"
+                txid, height = remote_history[remote_i]
+                if f"{txid}:{height}:" != pending_synced_history[i]:
+                    log.warning("history mismatch: %s vs %s", remote_history[remote_i], pending_synced_history[i])
+                synced_history += pending_synced_history[i]
+
+            cache_size = self.config.get("tx_cache_size", 100_000)
+            for txid, cache_item in updated_cached_items.items():
+                cache_item.pending_verifications -= 1
+                if cache_item.pending_verifications < 0:
+                    log.warning("config value tx cache size %i needs to be increased", cache_size)
+                    cache_item.pending_verifications = 0
+                try:
+                    cache_item.lock.release()
+                except RuntimeError:
+                    log.warning("lock was already released?")
+
             await self.db.save_transaction_io_batch(
-                synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
+                [], address, self.address_to_hash160(address), synced_history
             )
-            await asyncio.wait([
-                self._on_transaction_controller.add(TransactionEvent(address, tx))
-                for tx in synced_txs
-            ])

             if address_manager is None:
                 address_manager = await self.get_address_manager_for_address(address)
@@ -563,8 +597,16 @@ class Ledger(metaclass=LedgerRegistry):
             if address_manager is not None:
                 await address_manager.ensure_address_gap()

+            for txid, cache_item in updated_cached_items.items():
+                if self._tx_cache.get(txid) is not cache_item:
+                    log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update)
+                    if reattempt_update:
+                        return await self.update_history(address, remote_status, address_manager, False)
+                    return False
+
             local_status, local_history = \
-                await self.get_local_status_and_history(address, synced_history.getvalue())
+                await self.get_local_status_and_history(address, synced_history)
             if local_status != remote_status:
                 if local_history == remote_history:
                     log.warning(
@@ -590,6 +632,7 @@ class Ledger(metaclass=LedgerRegistry):
                 self._known_addresses_out_of_sync.add(address)
                 return False
             else:
+                log.debug("finished syncing transaction history for %s, %i known txs", address, len(local_history))
                 return True

     async def cache_transaction(self, txid, remote_height, check_local=True):
@@ -601,32 +644,30 @@
                 (cache_item.tx.is_verified or remote_height < 1):
             return cache_item.tx  # cached tx is already up-to-date
-        try:
-            cache_item.pending_verifications += 1
-            return await self._update_cache_item(cache_item, txid, remote_height, check_local)
-        finally:
-            cache_item.pending_verifications -= 1
-
-    async def _update_cache_item(self, cache_item, txid, remote_height, check_local=True):
-        async with cache_item.lock:
-            tx = cache_item.tx
-            if tx is None and check_local:
-                # check local db
-                tx = cache_item.tx = await self.db.get_transaction(txid=txid)
-            merkle = None
-            if tx is None:
-                # fetch from network
-                _raw, merkle = await self.network.retriable_call(
-                    self.network.get_transaction_and_merkle, txid, remote_height
-                )
-                tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
-                cache_item.tx = tx  # make sure it's saved before caching it
-            await self.maybe_verify_transaction(tx, remote_height, merkle)
-            return tx
+        cache_item.pending_verifications += 1
+        try:
+            async with cache_item.lock:
+                tx = cache_item.tx
+                if tx is None and check_local:
+                    # check local db
+                    tx = cache_item.tx = await self.db.get_transaction(txid=txid)
+                merkle = None
+                if tx is None:
+                    # fetch from network
+                    _raw, merkle = await self.network.retriable_call(
+                        self.network.get_transaction_and_merkle, txid, remote_height
+                    )
+                    tx = Transaction(unhexlify(_raw), height=merkle['block_height'])
+                    cache_item.tx = tx  # make sure it's saved before caching it
+                tx.height = remote_height
+                if merkle and 0 < remote_height < len(self.headers):
+                    merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
+                    header = await self.headers.get(remote_height)
+                    tx.position = merkle['pos']
+                    tx.is_verified = merkle_root == header['merkle_root']
+                return tx
+        finally:
+            cache_item.pending_verifications -= 1

     async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
         tx.height = remote_height
@@ -634,8 +675,8 @@ class Ledger(metaclass=LedgerRegistry):
         if not cached:
             # cache txs looked up by transaction_show too
             cached = TransactionCacheItem()
-            cached.tx = tx
             self._tx_cache[tx.id] = cached
+        cached.tx = tx
         if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
             # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
             if not merkle:
@@ -645,6 +686,100 @@
             tx.position = merkle['pos']
             tx.is_verified = merkle_root == header['merkle_root']

+    async def _request_transaction_batch(self, to_request, remote_history_size, address):
+        header_cache = {}
+        batches = [[]]
+        remote_heights = {}
+        synced_txs = []
+        heights_in_batch = 0
+        last_height = 0
+        for idx in sorted(to_request):
+            txid = to_request[idx][0]
+            height = to_request[idx][1]
+            remote_heights[txid] = height
+            if height != last_height:
+                heights_in_batch += 1
+                last_height = height
+            if len(batches[-1]) == 100 or heights_in_batch == 20:
+                batches.append([])
+                heights_in_batch = 1
+            batches[-1].append(txid)
+        if not batches[-1]:
+            batches.pop()
+
+        last_showed_synced_count = 0
+
+        async def _single_batch(batch):
+            this_batch_synced = []
+            batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch)
+            for txid, (raw, merkle) in batch_result.items():
+                remote_height = remote_heights[txid]
+                merkle_height = merkle['block_height']
+                cache_item = self._tx_cache.get(txid)
+                if cache_item is None:
+                    cache_item = TransactionCacheItem()
+                    self._tx_cache[txid] = cache_item
+                tx = cache_item.tx or Transaction(unhexlify(raw), height=remote_height)
+                tx.height = remote_height
+                cache_item.tx = tx
+                if 'merkle' in merkle and remote_heights[txid] > 0:
+                    merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
+                    try:
+                        header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height))
+                    except IndexError:
+                        log.warning("failed to verify %s at height %i", tx.id, merkle_height)
+                    else:
+                        header_cache[remote_heights[txid]] = header
+                        tx.position = merkle['pos']
+                        tx.is_verified = merkle_root == header['merkle_root']
+                check_db_for_txos = []
+                for txi in tx.inputs:
+                    if txi.txo_ref.txo is not None:
+                        continue
+                    cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
+                    if cache_item is not None:
+                        if cache_item.tx is not None:
+                            txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
+                    else:
+                        check_db_for_txos.append(txi.txo_ref.id)
+                referenced_txos = {} if not check_db_for_txos else {
+                    txo.id: txo for txo in await self.db.get_txos(
+                        txoid__in=check_db_for_txos, order_by='txo.txoid', no_tx=True
+                    )
+                }
+                for txi in tx.inputs:
+                    if txi.txo_ref.txo is not None:
+                        continue
+                    referenced_txo = referenced_txos.get(txi.txo_ref.id)
+                    if referenced_txo is not None:
+                        txi.txo_ref = referenced_txo.ref
+                        continue
+                    cache_item = self._tx_cache.get(txi.txo_ref.id)
+                    if cache_item is None:
+                        cache_item = self._tx_cache[txi.txo_ref.id] = TransactionCacheItem()
+                    if cache_item.tx is not None:
+                        txi.txo_ref = cache_item.tx.ref
+                synced_txs.append(tx)
+                this_batch_synced.append(tx)
+
+            await self.db.save_transaction_io_batch(
+                this_batch_synced, address, self.address_to_hash160(address), ""
+            )
+            await asyncio.wait([
+                self._on_transaction_controller.add(TransactionEvent(address, tx))
+                for tx in this_batch_synced
+            ])
+
+            nonlocal last_showed_synced_count
+            if last_showed_synced_count + 100 < len(synced_txs):
+                log.info("synced %i/%i transactions for %s", len(synced_txs), remote_history_size, address)
+                last_showed_synced_count = len(synced_txs)
+
+        for batch in batches:
+            await _single_batch(batch)
+        return synced_txs
+
     async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
         details = await self.db.get_address(address=address)
         for account in self.accounts:
@@ -697,7 +832,7 @@ class Ledger(metaclass=LedgerRegistry):
                     local_height, height
                 )
                 return False
-        log.debug(
+        log.warning(
             "local history does not contain %s, requested height %i", tx.id, height
         )
         return False
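
The batching rule in the new _request_transaction_batch is easy to miss inside the loop: a batch is closed once it holds 100 txids, or once it spans 20 distinct block heights (each additional height is another block for the server to deserialise). The same partitioning logic, distilled into a standalone sketch (hypothetical input: a list of (txid, height) pairs in history order):

    def partition(history):
        # mirrors the batch-building loop above: cap at 100 txids or 20 distinct heights
        batches, heights_in_batch, last_height = [[]], 0, 0
        for txid, height in history:
            if height != last_height:
                heights_in_batch += 1
                last_height = height
            if len(batches[-1]) == 100 or heights_in_batch == 20:
                batches.append([])
                heights_in_batch = 1
            batches[-1].append(txid)
        if not batches[-1]:
            batches.pop()
        return batches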

@@ -255,6 +255,10 @@ class Network:
         restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
         return self.rpc('blockchain.transaction.get', [tx_hash], restricted)

+    def get_transaction_batch(self, txids):
+        # use any server if its old, otherwise restrict to who gave us the history
+        return self.rpc('blockchain.transaction.get_batch', txids, True)
+
     def get_transaction_and_merkle(self, tx_hash, known_height=None):
         # use any server if its old, otherwise restrict to who gave us the history
         restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
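
A hypothetical call site for the new wrapper, assuming the result shape produced by the server-side handler further down ({txid: [raw_tx_hex, merkle]}, with 'block_height' always present in merkle):

    async def print_batch_heights(network, txids):
        batch = await network.get_transaction_batch(txids)
        for txid, (raw_tx, merkle) in batch.items():
            # block_height is -1 while the transaction sits in the mempool
            print(txid, merkle['block_height'])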

@@ -389,3 +389,17 @@ class MemPool:
                 if hX == hashX:
                     utxos.append(UTXO(-1, pos, tx_hash, 0, value))
         return utxos
+
+    def get_mempool_height(self, tx_hash):
+        # Height Progression
+        #   -2: not broadcast
+        #   -1: in mempool but has unconfirmed inputs
+        #    0: in mempool and all inputs confirmed
+        # +num: confirmed in a specific block (height)
+        if tx_hash not in self.txs:
+            return -2
+        tx = self.txs[tx_hash]
+        unspent_inputs = sum(1 if hash in self.txs else 0 for hash, idx in tx.prevouts)
+        if unspent_inputs:
+            return -1
+        return 0
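
Two things worth noting about get_mempool_height: a return of -2 covers any hash absent from self.txs, so it means "not in the mempool" (never broadcast, or already confirmed and evicted), and unspent_inputs actually counts inputs whose funding transaction is itself still unconfirmed. A minimal illustration (hypothetical mempool and tx_hash):

    height = mempool.get_mempool_height(tx_hash)
    if height == -1:
        print("a parent tx is still unconfirmed")
    elif height == 0:
        print("in mempool, all inputs confirmed")
    elif height == -2:
        print("unknown to the mempool")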

@@ -1529,7 +1529,7 @@ class LBRYElectrumX(SessionBase):
         block_hash = tx_info.get('blockhash')
         if not block_hash:
             return raw_tx, {'block_height': -1}
-        merkle_height = (await self.daemon_request('deserialised_block', block_hash))['height']
+        merkle_height = (await self.daemon.deserialised_block(block_hash))['height']
         merkle = await self.transaction_merkle(tx_hash, merkle_height)
         return raw_tx, merkle
@@ -1539,34 +1539,24 @@
         for tx_hash in tx_hashes:
             assert_tx_hash(tx_hash)
         batch_result = {}
-        height = None
-        block_hash = None
-        block = None
         for tx_hash in tx_hashes:
             tx_info = await self.daemon_request('getrawtransaction', tx_hash, True)
             raw_tx = tx_info['hex']
-            if height is None:
-                if 'blockhash' in tx_info:
-                    block_hash = tx_info['blockhash']
-                    block = await self.daemon_request('deserialised_block', block_hash)
-                    height = block['height']
-                else:
-                    height = -1
-            if block_hash != tx_info.get('blockhash'):
-                raise RPCError(BAD_REQUEST, f'request contains a mix of transaction heights')
-            else:
-                if not block_hash:
-                    merkle = {'block_height': -1}
-                else:
-                    try:
-                        pos = block['tx'].index(tx_hash)
-                    except ValueError:
-                        raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} not in '
-                                                    f'block {block_hash} at height {height:,d}')
-                    merkle = {
-                        "merkle": self._get_merkle_branch(block['tx'], pos),
-                        "pos": pos
-                    }
+            block_hash = tx_info.get('blockhash')
+            merkle = {}
+            if block_hash:
+                block = await self.daemon.deserialised_block(block_hash)
+                height = block['height']
+                try:
+                    pos = block['tx'].index(tx_hash)
+                except ValueError:
+                    raise RPCError(BAD_REQUEST, f'tx hash {tx_hash} not in '
+                                                f'block {block_hash} at height {height:,d}')
+                merkle["merkle"] = self._get_merkle_branch(block['tx'], pos)
+                merkle["pos"] = pos
+            else:
+                height = -1
+            merkle['block_height'] = height
             batch_result[tx_hash] = [raw_tx, merkle]
         return batch_result
@@ -1592,7 +1582,7 @@ class LBRYElectrumX(SessionBase):
         height = non_negative_integer(height)
         hex_hashes = await self.daemon_request('block_hex_hashes', height, 1)
         block_hash = hex_hashes[0]
-        block = await self.daemon_request('deserialised_block', block_hash)
+        block = await self.daemon.deserialised_block(block_hash)
         return block_hash, block['tx']

     def _get_merkle_branch(self, tx_hashes, tx_pos):
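
After this rework, transaction_get_batch resolves height and merkle proof per transaction instead of rejecting batches that mix heights. The response it builds ends up shaped roughly like this (hypothetical txids and values):

    batch_result = {
        "aa..11": [raw_tx_hex, {"merkle": [...], "pos": 3, "block_height": 814021}],  # confirmed
        "bb..22": [raw_tx_hex, {"block_height": -1}],  # still in the mempool
    }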

@@ -35,6 +35,7 @@ disable=
     too-many-statements,
     too-many-nested-blocks,
     too-many-public-methods,
+    too-many-return-statements,
     too-many-instance-attributes,
     protected-access,
     unused-argument

@@ -35,11 +35,17 @@ class MockNetwork:
     async def get_transaction_and_merkle(self, tx_hash, known_height=None):
         tx = await self.get_transaction(tx_hash)
-        merkle = {}
+        merkle = {'block_height': -1}
         if known_height:
             merkle = await self.get_merkle(tx_hash, known_height)
         return tx, merkle

+    async def get_transaction_batch(self, txids):
+        return {
+            txid: await self.get_transaction_and_merkle(txid)
+            for txid in txids
+        }
+

 class LedgerTestCase(AsyncioTestCase):
@@ -120,8 +126,9 @@ class TestSynchronization(LedgerTestCase):
         self.ledger.network.get_history_called = []
         self.ledger.network.get_transaction_called = []
-        for cache_item in self.ledger._tx_cache.values():
-            cache_item.tx.is_verified = True
+        self.assertFalse(self.ledger._tx_cache[txid1].tx.is_verified)
+        self.assertFalse(self.ledger._tx_cache[txid2].tx.is_verified)
+        self.assertFalse(self.ledger._tx_cache[txid3].tx.is_verified)
         await self.ledger.update_history(address, '')
         self.assertListEqual(self.ledger.network.get_history_called, [address])
         self.assertListEqual(self.ledger.network.get_transaction_called, [])