test sync integration test and other fixes

This commit is contained in:
Lex Berezhny 2018-11-19 18:04:07 -05:00
parent 1732167af1
commit e5714dc1fc
5 changed files with 130 additions and 4 deletions

View file

@ -0,0 +1,97 @@
import asyncio
import logging
from torba.testcase import IntegrationTestCase, WalletNode
from torba.client.constants import CENT
class SyncTests(IntegrationTestCase):
VERBOSITY = logging.INFO
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_port = 5280
self.started_nodes = []
async def asyncTearDown(self):
for node in self.started_nodes:
try:
await node.stop(cleanup=True)
except Exception as e:
print(e)
await super().asyncTearDown()
async def make_wallet_node(self, seed=None):
self.api_port += 1
wallet_node = WalletNode(
self.wallet_node.manager_class,
self.wallet_node.ledger_class,
api_port=self.api_port
)
await wallet_node.start(seed)
self.started_nodes.append(wallet_node)
return wallet_node
async def test_nodes_with_same_account_stay_in_sync(self):
# destination node/account for receiving TXs
node0 = await self.make_wallet_node()
account0 = node0.account
# main node/account creating TXs
node1 = self.wallet_node
account1 = self.wallet_node.account
# mirror node/account, expected to reflect everything in main node as it happens
node2 = await self.make_wallet_node(account1.seed)
account2 = node2.account
self.assertNotEqual(account0.id, account1.id)
self.assertEqual(account1.id, account2.id)
await self.assertBalance(account0, '0.0')
await self.assertBalance(account1, '0.0')
await self.assertBalance(account2, '0.0')
self.assertEqual(await account0.get_address_count(chain=0), 20)
self.assertEqual(await account1.get_address_count(chain=0), 20)
self.assertEqual(await account2.get_address_count(chain=0), 20)
self.assertEqual(await account1.get_address_count(chain=1), 6)
self.assertEqual(await account2.get_address_count(chain=1), 6)
# check that main node and mirror node generate 5 address to fill gap
fifth_address = (await account1.receiving.get_addresses())[4]
await self.blockchain.send_to_address(fifth_address, 1.00)
await asyncio.wait([
account1.ledger.on_address.first,
account2.ledger.on_address.first
])
self.assertEqual(await account1.get_address_count(chain=0), 25)
self.assertEqual(await account2.get_address_count(chain=0), 25)
await self.assertBalance(account1, '1.0')
await self.assertBalance(account2, '1.0')
await self.blockchain.generate(1)
# pay 0.01 from main node to receiving node, would have increased change addresses
address0 = (await account0.receiving.get_addresses())[0]
hash0 = self.ledger.address_to_hash160(address0)
tx = await account1.ledger.transaction_class.create(
[],
[self.ledger.transaction_class.output_class.pay_pubkey_hash(CENT, hash0)],
[account1], account1
)
await self.broadcast(tx)
await asyncio.wait([
account0.ledger.wait(tx),
account1.ledger.wait(tx),
account2.ledger.wait(tx),
])
self.assertEqual(await account0.get_address_count(chain=0), 21)
self.assertGreater(await account1.get_address_count(chain=1), 6)
self.assertGreater(await account2.get_address_count(chain=1), 6)
await self.assertBalance(account0, '0.01')
await self.assertBalance(account1, '0.989876')
await self.assertBalance(account2, '0.989876')
await self.blockchain.generate(1)
# create a new mirror node and see if it syncs to same balance from scratch
node3 = await self.make_wallet_node(account1.seed)
account3 = node3.account
await self.assertBalance(account3, '0.989876')

View file

@ -238,3 +238,12 @@ class TestQueries(AsyncioTestCase):
tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2) tx = await self.ledger.db.get_transaction(txid=tx2.id, account=account2)
self.assertEqual(tx.inputs[0].is_my_account, False) self.assertEqual(tx.inputs[0].is_my_account, False)
self.assertEqual(tx.outputs[0].is_my_account, True) self.assertEqual(tx.outputs[0].is_my_account, True)
# height 0 sorted to the top with the rest in descending order
tx4 = await self.create_tx_from_nothing(account1, 0)
txos = await self.ledger.db.get_txos()
self.assertEqual([0, 2, 1], [txo.tx_ref.height for txo in txos])
self.assertEqual([tx4.id, tx2.id, tx1.id], [txo.tx_ref.id for txo in txos])
txs = await self.ledger.db.get_transactions()
self.assertEqual([0, 3, 2, 1], [tx.height for tx in txs])
self.assertEqual([tx4.id, tx3.id, tx2.id, tx1.id], [tx.id for tx in txs])

View file

@ -128,7 +128,7 @@ class HierarchicalDeterministic(AddressManager):
start = addresses[0]['position']+1 if addresses else 0 start = addresses[0]['position']+1 if addresses else 0
end = start + (self.gap - existing_gap) end = start + (self.gap - existing_gap)
new_keys = await self._generate_keys(start, end-1) new_keys = await self._generate_keys(start, end-1)
await self.account.ledger.subscribe_addresses(self, new_keys) await self.account.ledger.announce_addresses(self, new_keys)
return new_keys return new_keys
async def _generate_keys(self, start: int, end: int) -> List[str]: async def _generate_keys(self, start: int, end: int) -> List[str]:
@ -148,7 +148,9 @@ class HierarchicalDeterministic(AddressManager):
def get_address_records(self, only_usable: bool = False, **constraints): def get_address_records(self, only_usable: bool = False, **constraints):
if only_usable: if only_usable:
constraints['used_times__lt'] = self.maximum_uses_per_address constraints['used_times__lt'] = self.maximum_uses_per_address
return self._query_addresses(order_by="used_times ASC, position ASC", **constraints) if 'order_by' not in constraints:
constraints['order_by'] = "used_times ASC, position ASC"
return self._query_addresses(**constraints)
class SingleKey(AddressManager): class SingleKey(AddressManager):
@ -181,7 +183,7 @@ class SingleKey(AddressManager):
self.account, self.chain_number, [(0, self.public_key)] self.account, self.chain_number, [(0, self.public_key)]
) )
new_keys = [self.public_key.address] new_keys = [self.public_key.address]
await self.account.ledger.subscribe_addresses(self, new_keys) await self.account.ledger.announce_addresses(self, new_keys)
return new_keys return new_keys
return [] return []

View file

@ -355,7 +355,7 @@ class BaseDatabase(SQLiteMixin):
tx_rows = await self.select_transactions( tx_rows = await self.select_transactions(
'txid, raw, height, position, is_verified', 'txid, raw, height, position, is_verified',
order_by=["height DESC", "position DESC"], order_by=["height=0 DESC", "height DESC", "position DESC"],
**constraints **constraints
) )
@ -422,6 +422,8 @@ class BaseDatabase(SQLiteMixin):
my_account = my_account or constraints.get('account', None) my_account = my_account or constraints.get('account', None)
if isinstance(my_account, BaseAccount): if isinstance(my_account, BaseAccount):
my_account = my_account.public_key.address my_account = my_account.public_key.address
if 'order_by' not in constraints:
constraints['order_by'] = ["tx.height=0 DESC", "tx.height DESC", "tx.position DESC"]
rows = await self.select_txos( rows = await self.select_txos(
"amount, script, txid, tx.height, txo.position, chain, account", **constraints "amount, script, txid, tx.height, txo.position, chain, account", **constraints
) )

View file

@ -44,6 +44,10 @@ class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx'))):
pass pass
class AddressesGeneratedEvent(namedtuple('AddressesGeneratedEvent', ('address_manager', 'addresses'))):
pass
class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))): class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
pass pass
@ -138,6 +142,12 @@ class BaseLedger(metaclass=LedgerRegistry):
) )
) )
self._on_address_controller = StreamController()
self.on_address = self._on_address_controller.stream
self.on_address.listen(
lambda e: log.info('(%s) on_address: %s', self.get_id(), e.addresses)
)
self._on_header_controller = StreamController() self._on_header_controller = StreamController()
self.on_header = self._on_header_controller.stream self.on_header = self._on_header_controller.stream
self.on_header.listen( self.on_header.listen(
@ -350,6 +360,12 @@ class BaseLedger(metaclass=LedgerRegistry):
await self.subscribe_addresses(address_manager, await address_manager.get_addresses()) await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
await account.ensure_address_gap() await account.ensure_address_gap()
async def announce_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
await self.subscribe_addresses(address_manager, addresses)
await self._on_address_controller.add(
AddressesGeneratedEvent(address_manager, addresses)
)
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]): async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
if self.network.is_connected and addresses: if self.network.is_connected and addresses:
await asyncio.wait([ await asyncio.wait([