From 1b7c5a1373a01d4d6936d92b736b1c83b7c1487f Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Wed, 5 Dec 2018 11:02:52 -0500 Subject: [PATCH] update_history serially handles unique addresses --- .../integration/test_transactions.py | 21 ++-- torba/client/baseledger.py | 107 +++++++++--------- torba/testcase.py | 14 ++- 3 files changed, 76 insertions(+), 66 deletions(-) diff --git a/tests/client_tests/integration/test_transactions.py b/tests/client_tests/integration/test_transactions.py index 5c946a8d5..611816be3 100644 --- a/tests/client_tests/integration/test_transactions.py +++ b/tests/client_tests/integration/test_transactions.py @@ -18,11 +18,16 @@ class BasicTransactionTests(IntegrationTestCase): address1 = await self.account.receiving.get_or_create_usable_address() hash1 = self.ledger.address_to_hash160(address1) - tasks = [] - for _ in range(10): - sendtxid = await self.blockchain.send_to_address(address1, 100) - tasks.append(self.on_transaction_id(sendtxid)) - await asyncio.wait(tasks) + txids = await asyncio.gather(*( + self.blockchain.send_to_address(address1, 100) + for _ in range(10) + )) + + await asyncio.wait([ + self.on_transaction_id(txid) + for txid in txids + ]) + await self.assertBalance(self.account, '1000.0') tasks = [] @@ -37,11 +42,7 @@ class BasicTransactionTests(IntegrationTestCase): await asyncio.wait(tasks) - #await asyncio.sleep(5) - - await self.assertBalance(self.account, '1000.0') - - await self.blockchain.generate(1) + await self.assertBalance(self.account, '999.99876') async def test_sending_and_receiving(self): account1, account2 = self.account, self.wallet.generate_account(self.ledger) diff --git a/torba/client/baseledger.py b/torba/client/baseledger.py index 35a792b62..f3bd6cf73 100644 --- a/torba/client/baseledger.py +++ b/torba/client/baseledger.py @@ -75,20 +75,21 @@ class TransactionCacheItem: class SynchronizationMonitor: - def __init__(self): + def __init__(self, loop=None): self.done = asyncio.Event() self.tasks = [] + self.loop = loop or asyncio.get_event_loop() def add(self, coro): len(self.tasks) < 1 and self.done.clear() - asyncio.ensure_future(self._monitor(coro)) + self.loop.create_task(self._monitor(coro)) def cancel(self): for task in self.tasks: task.cancel() async def _monitor(self, coro): - task = asyncio.ensure_future(coro) + task = self.loop.create_task(coro) self.tasks.append(task) try: await task @@ -161,6 +162,7 @@ class BaseLedger(metaclass=LedgerRegistry): self.sync = SynchronizationMonitor() self._utxo_reservation_lock = asyncio.Lock() self._header_processing_lock = asyncio.Lock() + self._address_update_locks: Dict[str, asyncio.Lock] = {} @classmethod def get_id(cls): @@ -382,63 +384,66 @@ class BaseLedger(metaclass=LedgerRegistry): async def update_history(self, address, remote_status, address_manager: baseaccount.AddressManager = None): - local_status, local_history = await self.get_local_status_and_history(address) - if local_status == remote_status: - return + async with self._address_update_locks.setdefault(address, asyncio.Lock()): - remote_history = await self.network.get_history(address) + local_status, local_history = await self.get_local_status_and_history(address) - cache_tasks = [] - synced_history = StringIO() - for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): - if i < len(local_history) and local_history[i] == (txid, remote_height): - synced_history.write(f'{txid}:{remote_height}:') - else: - cache_tasks.append(asyncio.ensure_future( - self.cache_transaction(txid, remote_height) - )) + if local_status == remote_status: + return - for task in cache_tasks: - tx = await task + remote_history = await self.network.get_history(address) - check_db_for_txos = [] - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - continue - cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id) - if cache_item is not None: - if cache_item.tx is None: - await cache_item.has_tx.wait() - assert cache_item.tx is not None - txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref + cache_tasks = [] + synced_history = StringIO() + for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): + if i < len(local_history) and local_history[i] == (txid, remote_height): + synced_history.write(f'{txid}:{remote_height}:') else: - check_db_for_txos.append(txi.txo_ref.tx_ref.id) + cache_tasks.append(asyncio.ensure_future( + self.cache_transaction(txid, remote_height) + )) - referenced_txos = { - txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos) - } + for task in cache_tasks: + tx = await task - for txi in tx.inputs: - if txi.txo_ref.txo is not None: - continue - referenced_txo = referenced_txos.get(txi.txo_ref.tx_ref.id) - if referenced_txo is not None: - txi.txo_ref = referenced_txo.ref + check_db_for_txos = [] + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + continue + cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id) + if cache_item is not None: + if cache_item.tx is None: + await cache_item.has_tx.wait() + assert cache_item.tx is not None + txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref + else: + check_db_for_txos.append(txi.txo_ref.tx_ref.id) - synced_history.write(f'{tx.id}:{tx.height}:') + referenced_txos = { + txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos) + } - await self.db.save_transaction_io( - tx, address, self.address_to_hash160(address), synced_history.getvalue() - ) + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + continue + referenced_txo = referenced_txos.get(txi.txo_ref.tx_ref.id) + if referenced_txo is not None: + txi.txo_ref = referenced_txo.ref - self._on_transaction_controller.add(TransactionEvent(address, tx)) + synced_history.write(f'{tx.id}:{tx.height}:') - if address_manager is None: - address_manager = await self.get_address_manager_for_address(address) + await self.db.save_transaction_io( + tx, address, self.address_to_hash160(address), synced_history.getvalue() + ) - if address_manager is not None: - await address_manager.ensure_address_gap() + await self._on_transaction_controller.add(TransactionEvent(address, tx)) + + if address_manager is None: + address_manager = await self.get_address_manager_for_address(address) + + if address_manager is not None: + await address_manager.ensure_address_gap() async def cache_transaction(self, txid, remote_height): cache_item = self._tx_cache.get(txid) @@ -449,9 +454,8 @@ class BaseLedger(metaclass=LedgerRegistry): (cache_item.tx.is_verified or remote_height < 1): return cache_item.tx # cached tx is already up-to-date - await cache_item.lock.acquire() + async with cache_item.lock: - try: tx = cache_item.tx if tx is None: @@ -478,9 +482,6 @@ class BaseLedger(metaclass=LedgerRegistry): return tx - finally: - cache_item.lock.release() - async def maybe_verify_transaction(self, tx, remote_height): tx.height = remote_height if 0 < remote_height <= len(self.headers): @@ -514,6 +515,6 @@ class BaseLedger(metaclass=LedgerRegistry): records = await self.db.get_addresses(cols=('address',), address__in=addresses) await asyncio.wait([ self.on_transaction.where(partial( - lambda a, e: a == e.address and e.tx.height >= height, address_record['address'] + lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id, address_record['address'] )) for address_record in records ]) diff --git a/torba/testcase.py b/torba/testcase.py index 7b808a920..8e61b7828 100644 --- a/torba/testcase.py +++ b/torba/testcase.py @@ -18,14 +18,22 @@ class ColorHandler(logging.StreamHandler): level_color = { logging.DEBUG: "black", - logging.INFO: "black", + logging.INFO: "light_gray", logging.WARNING: "yellow", logging.ERROR: "red" } color_code = dict( - black=30, red=31, green=32, yellow=33, - blue=34, magenta=35, cyan=36, white=37 + black=30, + red=31, + green=32, + yellow=33, + blue=34, + magenta=35, + cyan=36, + white=37, + light_gray='0;37', + dark_gray='1;30' ) def emit(self, record):