added ledger.wait(tx) to support reliably waiting for TX to be in ledger

This commit is contained in:
Lex Berezhny 2018-11-07 14:42:17 -05:00
parent 8fcbf48bdf
commit fdf479241a
3 changed files with 36 additions and 27 deletions

View file

@ -37,9 +37,9 @@ class BasicTransactionTests(IntegrationTestCase):
[account1], account1 [account1], account1
) )
await self.broadcast(tx) await self.broadcast(tx)
await self.on_transaction(tx) # mempool await self.ledger.wait(tx) # mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction(tx) # confirmed await self.ledger.wait(tx) # confirmed
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5) self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0) self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
@ -51,9 +51,9 @@ class BasicTransactionTests(IntegrationTestCase):
[account1], account1 [account1], account1
) )
await self.broadcast(tx) await self.broadcast(tx)
await self.on_transaction(tx) # mempool await self.ledger.wait(tx) # mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction(tx) # confirmed await self.ledger.wait(tx) # confirmed
txs = await account1.get_transactions() txs = await account1.get_transactions()
tx = txs[1] tx = txs[1]

View file

@ -1,6 +1,7 @@
import os import os
import asyncio import asyncio
import logging import logging
from functools import partial
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from io import StringIO from io import StringIO
@ -384,3 +385,21 @@ class BaseLedger(metaclass=LedgerRegistry):
def broadcast(self, tx): def broadcast(self, tx):
return self.network.broadcast(hexlify(tx.raw).decode()) return self.network.broadcast(hexlify(tx.raw).decode())
async def wait(self, tx: basetransaction.BaseTransaction, height=0):
addresses = set()
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
addresses.add(
self.hash160_to_address(txi.txo_ref.txo.script.values['pubkey_hash'])
)
for txo in tx.outputs:
addresses.add(
self.hash160_to_address(txo.script.values['pubkey_hash'])
)
records = await self.db.get_addresses(cols=('address',), address__in=addresses)
await asyncio.wait([
self.on_transaction.where(partial(
lambda a, e: a == e.address and e.tx.height >= height, address_record['address']
)) for address_record in records
])

View file

@ -3,6 +3,12 @@ import logging
import unittest import unittest
from unittest.case import _Outcome from unittest.case import _Outcome
from torba.orchstr8 import Conductor from torba.orchstr8 import Conductor
from torba.orchstr8.node import BlockchainNode, WalletNode
from torba.client.baseledger import BaseLedger
from torba.client.baseaccount import BaseAccount
from torba.client.basemanager import BaseWalletManager
from torba.client.wallet import Wallet
from typing import Optional
try: try:
@ -130,13 +136,13 @@ class IntegrationTestCase(AsyncioTestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.conductor = None self.conductor: Optional[Conductor] = None
self.blockchain = None self.blockchain: Optional[BlockchainNode] = None
self.wallet_node = None self.wallet_node: Optional[WalletNode] = None
self.manager = None self.manager: Optional[BaseWalletManager] = None
self.ledger = None self.ledger: Optional[BaseLedger] = None
self.wallet = None self.wallet: Optional[Wallet] = None
self.account = None self.account: Optional[BaseAccount] = None
async def asyncSetUp(self): async def asyncSetUp(self):
self.conductor = Conductor( self.conductor = Conductor(
@ -178,19 +184,3 @@ class IntegrationTestCase(AsyncioTestCase):
return self.ledger.on_transaction.where( return self.ledger.on_transaction.where(
lambda e: e.tx.id == tx.id and e.address == address lambda e: e.tx.id == tx.id and e.address == address
) )
async def on_transaction(self, tx):
addresses = await self.get_tx_addresses(tx, self.ledger)
await asyncio.wait([
self.ledger.on_transaction.where(lambda e: e.address == address) # pylint: disable=W0640
for address in addresses
])
async def get_tx_addresses(self, tx, ledger):
addresses = set()
for txo in tx.outputs:
address = ledger.hash160_to_address(txo.script.values['pubkey_hash'])
record = await ledger.db.get_address(address=address)
if record is not None:
addresses.add(address)
return list(addresses)