From fdf479241a1b4f1f9983b28be344a471f21d3aed Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Wed, 7 Nov 2018 14:42:17 -0500 Subject: [PATCH] added ledger.wait(tx) to support reliably waiting for TX to be in ledger --- .../integration/test_transactions.py | 8 ++--- torba/client/baseledger.py | 19 ++++++++++ torba/testcase.py | 36 +++++++------------ 3 files changed, 36 insertions(+), 27 deletions(-) diff --git a/tests/client_tests/integration/test_transactions.py b/tests/client_tests/integration/test_transactions.py index 30fb12d52..572a2f000 100644 --- a/tests/client_tests/integration/test_transactions.py +++ b/tests/client_tests/integration/test_transactions.py @@ -37,9 +37,9 @@ class BasicTransactionTests(IntegrationTestCase): [account1], account1 ) await self.broadcast(tx) - await self.on_transaction(tx) # mempool + await self.ledger.wait(tx) # mempool await self.blockchain.generate(1) - await self.on_transaction(tx) # confirmed + await self.ledger.wait(tx) # confirmed self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5) self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0) @@ -51,9 +51,9 @@ class BasicTransactionTests(IntegrationTestCase): [account1], account1 ) await self.broadcast(tx) - await self.on_transaction(tx) # mempool + await self.ledger.wait(tx) # mempool await self.blockchain.generate(1) - await self.on_transaction(tx) # confirmed + await self.ledger.wait(tx) # confirmed txs = await account1.get_transactions() tx = txs[1] diff --git a/torba/client/baseledger.py b/torba/client/baseledger.py index 79c458b33..11a603802 100644 --- a/torba/client/baseledger.py +++ b/torba/client/baseledger.py @@ -1,6 +1,7 @@ import os import asyncio import logging +from functools import partial from binascii import hexlify, unhexlify from io import StringIO @@ -384,3 +385,21 @@ class BaseLedger(metaclass=LedgerRegistry): def broadcast(self, tx): return self.network.broadcast(hexlify(tx.raw).decode()) + + async def wait(self, tx: basetransaction.BaseTransaction, height=0): + addresses = set() + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + addresses.add( + self.hash160_to_address(txi.txo_ref.txo.script.values['pubkey_hash']) + ) + for txo in tx.outputs: + addresses.add( + self.hash160_to_address(txo.script.values['pubkey_hash']) + ) + records = await self.db.get_addresses(cols=('address',), address__in=addresses) + await asyncio.wait([ + self.on_transaction.where(partial( + lambda a, e: a == e.address and e.tx.height >= height, address_record['address'] + )) for address_record in records + ]) diff --git a/torba/testcase.py b/torba/testcase.py index 988ab1eaf..2c8cc84d1 100644 --- a/torba/testcase.py +++ b/torba/testcase.py @@ -3,6 +3,12 @@ import logging import unittest from unittest.case import _Outcome from torba.orchstr8 import Conductor +from torba.orchstr8.node import BlockchainNode, WalletNode +from torba.client.baseledger import BaseLedger +from torba.client.baseaccount import BaseAccount +from torba.client.basemanager import BaseWalletManager +from torba.client.wallet import Wallet +from typing import Optional try: @@ -130,13 +136,13 @@ class IntegrationTestCase(AsyncioTestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.conductor = None - self.blockchain = None - self.wallet_node = None - self.manager = None - self.ledger = None - self.wallet = None - self.account = None + self.conductor: Optional[Conductor] = None + self.blockchain: Optional[BlockchainNode] = None + self.wallet_node: Optional[WalletNode] = None + self.manager: Optional[BaseWalletManager] = None + self.ledger: Optional[BaseLedger] = None + self.wallet: Optional[Wallet] = None + self.account: Optional[BaseAccount] = None async def asyncSetUp(self): self.conductor = Conductor( @@ -178,19 +184,3 @@ class IntegrationTestCase(AsyncioTestCase): return self.ledger.on_transaction.where( lambda e: e.tx.id == tx.id and e.address == address ) - - async def on_transaction(self, tx): - addresses = await self.get_tx_addresses(tx, self.ledger) - await asyncio.wait([ - self.ledger.on_transaction.where(lambda e: e.address == address) # pylint: disable=W0640 - for address in addresses - ]) - - async def get_tx_addresses(self, tx, ledger): - addresses = set() - for txo in tx.outputs: - address = ledger.hash160_to_address(txo.script.values['pubkey_hash']) - record = await ledger.db.get_address(address=address) - if record is not None: - addresses.add(address) - return list(addresses)