diff --git a/tests/client_tests/integration/test_transactions.py b/tests/client_tests/integration/test_transactions.py
index 572a2f000..93f49a0d7 100644
--- a/tests/client_tests/integration/test_transactions.py
+++ b/tests/client_tests/integration/test_transactions.py
@@ -10,10 +10,10 @@ class BasicTransactionTests(IntegrationTestCase):
 
     async def test_sending_and_receiving(self):
         account1, account2 = self.account, self.wallet.generate_account(self.ledger)
-        await self.ledger.update_account(account2)
+        await self.ledger.subscribe_account(account2)
 
-        self.assertEqual(await self.get_balance(account1), 0)
-        self.assertEqual(await self.get_balance(account2), 0)
+        await self.assertBalance(account1, '0.0')
+        await self.assertBalance(account2, '0.0')
 
         sendtxids = []
         for i in range(5):
@@ -26,8 +26,8 @@ class BasicTransactionTests(IntegrationTestCase):
             self.on_transaction_id(txid) for txid in sendtxids
         ])
-        self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5)
-        self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0)
+        await self.assertBalance(account1, '5.5')
+        await self.assertBalance(account2, '0.0')
 
         address2 = await account2.receiving.get_or_create_usable_address()
         hash2 = self.ledger.address_to_hash160(address2)
@@ -41,8 +41,8 @@ class BasicTransactionTests(IntegrationTestCase):
         await self.blockchain.generate(1)
         await self.ledger.wait(tx)  # confirmed
 
-        self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
-        self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
+        await self.assertBalance(account1, '3.499802')
+        await self.assertBalance(account2, '2.0')
 
         utxos = await self.account.get_utxos()
         tx = await self.ledger.transaction_class.create(
diff --git a/tests/client_tests/unit/test_account.py b/tests/client_tests/unit/test_account.py
index 8c50cea1c..b6a6d6895 100644
--- a/tests/client_tests/unit/test_account.py
+++ b/tests/client_tests/unit/test_account.py
@@ -44,7 +44,8 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
         self.assertEqual(len(addresses), 26)
 
     async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
-        await self.account.receiving.generate_keys(0, 200)
+        async with self.account.receiving.address_generator_lock:
+            await self.account.receiving._generate_keys(0, 200)
         records = await self.account.receiving.get_address_records()
         self.assertEqual(201, len(records))
 
@@ -53,9 +54,10 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
 
         self.assertIsInstance(account.receiving, HierarchicalDeterministic)
 
-        await account.receiving.generate_keys(4, 7)
-        await account.receiving.generate_keys(0, 3)
-        await account.receiving.generate_keys(8, 11)
+        async with account.receiving.address_generator_lock:
+            await account.receiving._generate_keys(4, 7)
+            await account.receiving._generate_keys(0, 3)
+            await account.receiving._generate_keys(8, 11)
         records = await account.receiving.get_address_records()
         self.assertEqual(
             [r['position'] for r in records],
diff --git a/tests/client_tests/unit/test_database.py b/tests/client_tests/unit/test_database.py
index dcd241539..374280d52 100644
--- a/tests/client_tests/unit/test_database.py
+++ b/tests/client_tests/unit/test_database.py
@@ -128,7 +128,8 @@ class TestQueries(AsyncioTestCase):
         tx = ledger_class.transaction_class(height=height, is_verified=True) \
             .add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
             .add_outputs([self.txo(1, to_hash)])
-        await self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
+        await self.ledger.db.insert_transaction(tx)
+        await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
         return tx
 
     async def create_tx_from_txo(self, txo, to_account, height):
@@ -139,8 +140,9 @@ class TestQueries(AsyncioTestCase):
         tx = ledger_class.transaction_class(height=height, is_verified=True) \
             .add_inputs([self.txi(txo)]) \
             .add_outputs([self.txo(1, to_hash)])
-        await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
-        await self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
+        await self.ledger.db.insert_transaction(tx)
+        await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
+        await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
         return tx
 
     async def create_tx_to_nowhere(self, txo, height):
@@ -150,7 +152,8 @@ class TestQueries(AsyncioTestCase):
         tx = ledger_class.transaction_class(height=height, is_verified=True) \
             .add_inputs([self.txi(txo)]) \
             .add_outputs([self.txo(1, to_hash)])
-        await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
+        await self.ledger.db.insert_transaction(tx)
+        await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
         return tx
 
     def txo(self, amount, address):
diff --git a/tests/client_tests/unit/test_ledger.py b/tests/client_tests/unit/test_ledger.py
index caa5b990c..0e077b441 100644
--- a/tests/client_tests/unit/test_ledger.py
+++ b/tests/client_tests/unit/test_ledger.py
@@ -16,6 +16,7 @@ class MockNetwork:
         self.address = None
         self.get_history_called = []
         self.get_transaction_called = []
+        self.is_connected = False
 
     async def get_history(self, address):
         self.get_history_called.append(address)
@@ -85,16 +86,21 @@ class TestSynchronization(LedgerTestCase):
             'abcd02': hexlify(get_transaction(get_output(2)).raw),
            'abcd03': hexlify(get_transaction(get_output(3)).raw),
         })
-        await self.ledger.update_history(address)
+        await self.ledger.update_history(address, '')
         self.assertEqual(self.ledger.network.get_history_called, [address])
         self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
 
         address_details = await self.ledger.db.get_address(address=address)
-        self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
+        self.assertEqual(
+            address_details['history'],
+            '252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
+            'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
+            'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
+        )
 
         self.ledger.network.get_history_called = []
         self.ledger.network.get_transaction_called = []
-        await self.ledger.update_history(address)
+        await self.ledger.update_history(address, '')
         self.assertEqual(self.ledger.network.get_history_called, [address])
         self.assertEqual(self.ledger.network.get_transaction_called, [])
 
@@ -102,11 +108,17 @@ class TestSynchronization(LedgerTestCase):
         self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
         self.ledger.network.get_history_called = []
         self.ledger.network.get_transaction_called = []
-        await self.ledger.update_history(address)
+        await self.ledger.update_history(address, '')
         self.assertEqual(self.ledger.network.get_history_called, [address])
         self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
         address_details = await self.ledger.db.get_address(address=address)
-        self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
+        self.assertEqual(
+            address_details['history'],
+            '252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
+            'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
+            'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
+            '047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828:3:'
+        )
 
 
 class MocHeaderNetwork:
diff --git a/tests/client_tests/unit/test_transaction.py b/tests/client_tests/unit/test_transaction.py
index 8532f95fc..800c05322 100644
--- a/tests/client_tests/unit/test_transaction.py
+++ b/tests/client_tests/unit/test_transaction.py
@@ -261,14 +261,14 @@ class TransactionIOBalancing(AsyncioTestCase):
             .add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \
             .add_outputs(utxos)
 
-        save_tx = 'insert'
+        await self.ledger.db.insert_transaction(self.funding_tx)
+
         for utxo in utxos:
             await self.ledger.db.save_transaction_io(
-                save_tx, self.funding_tx,
+                self.funding_tx,
                 self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
                 utxo.script.values['pubkey_hash'], ''
             )
-            save_tx = 'update'
 
         return utxos
 
diff --git a/tests/client_tests/unit/test_utils.py b/tests/client_tests/unit/test_utils.py
index 8bc368302..3d9e19d0a 100644
--- a/tests/client_tests/unit/test_utils.py
+++ b/tests/client_tests/unit/test_utils.py
@@ -1,6 +1,41 @@
 import unittest
 
 from torba.client.util import ArithUint256
+from torba.client.util import coins_to_satoshis as c2s, satoshis_to_coins as s2c
+
+
+class TestCoinValueParsing(unittest.TestCase):
+
+    def test_good_output(self):
+        self.assertEqual(s2c(1), "0.00000001")
+        self.assertEqual(s2c(10**7), "0.1")
+        self.assertEqual(s2c(2*10**8), "2.0")
+        self.assertEqual(s2c(2*10**17), "2000000000.0")
+
+    def test_good_input(self):
+        self.assertEqual(c2s("0.00000001"), 1)
+        self.assertEqual(c2s("0.1"), 10**7)
+        self.assertEqual(c2s("1.0"), 10**8)
+        self.assertEqual(c2s("2.00000000"), 2*10**8)
+        self.assertEqual(c2s("2000000000.0"), 2*10**17)
+
+    def test_bad_input(self):
+        with self.assertRaises(ValueError):
+            c2s("1")
+        with self.assertRaises(ValueError):
+            c2s("-1.0")
+        with self.assertRaises(ValueError):
+            c2s("10000000000.0")
+        with self.assertRaises(ValueError):
+            c2s("1.000000000")
+        with self.assertRaises(ValueError):
+            c2s("-0")
+        with self.assertRaises(ValueError):
+            c2s("1")
+        with self.assertRaises(ValueError):
+            c2s(".1")
+        with self.assertRaises(ValueError):
+            c2s("1e-7")
 
 
 class TestArithUint256(unittest.TestCase):
diff --git a/torba/client/baseaccount.py b/torba/client/baseaccount.py
index 9e9cf8d05..ad173f197 100644
--- a/torba/client/baseaccount.py
+++ b/torba/client/baseaccount.py
@@ -1,6 +1,7 @@
+import asyncio
 import random
 import typing
-from typing import Dict, Tuple, Type, Optional, Any
+from typing import Dict, Tuple, Type, Optional, Any, List
 
 from torba.client.mnemonic import Mnemonic
 from torba.client.bip32 import PrivateKey, PubKey, from_extended_key_string
@@ -15,12 +16,13 @@ class AddressManager:
 
     name: str
 
-    __slots__ = 'account', 'public_key', 'chain_number'
+    __slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'
 
     def __init__(self, account, public_key, chain_number):
         self.account = account
         self.public_key = public_key
         self.chain_number = chain_number
+        self.address_generator_lock = asyncio.Lock()
 
     @classmethod
     def from_dict(cls, account: 'BaseAccount', d: dict) \
@@ -60,11 +62,11 @@ class AddressManager:
     def get_address_records(self, only_usable: bool = False, **constraints):
         raise NotImplementedError
 
-    async def get_addresses(self, only_usable: bool = False, **constraints):
+    async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]:
         records = await self.get_address_records(only_usable=only_usable, **constraints)
         return [r['address'] for r in records]
 
-    async def get_or_create_usable_address(self):
+    async def get_or_create_usable_address(self) -> str:
         addresses = await self.get_addresses(only_usable=True, limit=10)
         if addresses:
             return random.choice(addresses)
@@ -87,8 +89,8 @@ class HierarchicalDeterministic(AddressManager):
     @classmethod
     def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]:
         return (
-            cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 2})),
-            cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 2}))
+            cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
+            cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
         )
 
     def to_dict_instance(self):
@@ -97,19 +99,7 @@ class HierarchicalDeterministic(AddressManager):
     def get_private_key(self, index: int) -> PrivateKey:
         return self.account.private_key.child(self.chain_number).child(index)
 
-    async def generate_keys(self, start: int, end: int):
-        keys_batch, final_keys = [], []
-        for index in range(start, end+1):
-            keys_batch.append((index, self.public_key.child(index)))
-            if index % 180 == 0 or index == end:
-                await self.account.ledger.db.add_keys(
-                    self.account, self.chain_number, keys_batch
-                )
-                final_keys.extend(keys_batch)
-                keys_batch.clear()
-        return [key[1].address for key in final_keys]
-
-    async def get_max_gap(self):
+    async def get_max_gap(self) -> int:
         addresses = await self._query_addresses(order_by="position ASC")
         max_gap = 0
         current_gap = 0
@@ -121,27 +111,43 @@ class HierarchicalDeterministic(AddressManager):
                 current_gap = 0
         return max_gap
 
-    async def ensure_address_gap(self):
-        addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
+    async def ensure_address_gap(self) -> List[str]:
+        async with self.address_generator_lock:
+            addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
 
-        existing_gap = 0
-        for address in addresses:
-            if address['used_times'] == 0:
-                existing_gap += 1
-            else:
-                break
+            existing_gap = 0
+            for address in addresses:
+                if address['used_times'] == 0:
+                    existing_gap += 1
+                else:
+                    break
 
-        if existing_gap == self.gap:
-            return []
+            if existing_gap == self.gap:
+                return []
 
-        start = addresses[0]['position']+1 if addresses else 0
-        end = start + (self.gap - existing_gap)
-        new_keys = await self.generate_keys(start, end-1)
-        return new_keys
+            start = addresses[0]['position']+1 if addresses else 0
+            end = start + (self.gap - existing_gap)
+            new_keys = await self._generate_keys(start, end-1)
+            await self.account.ledger.subscribe_addresses(self, new_keys)
+            return new_keys
+
+    async def _generate_keys(self, start: int, end: int) -> List[str]:
+        if not self.address_generator_lock.locked():
+            raise RuntimeError('Should not be called outside of address_generator_lock.')
+        keys_batch, final_keys = [], []
+        for index in range(start, end+1):
+            keys_batch.append((index, self.public_key.child(index)))
+            if index % 180 == 0 or index == end:
+                await self.account.ledger.db.add_keys(
+                    self.account, self.chain_number, keys_batch
+                )
+                final_keys.extend(keys_batch)
+                keys_batch.clear()
+        return [key[1].address for key in final_keys]
 
     def get_address_records(self, only_usable: bool = False, **constraints):
         if only_usable:
-            constraints['used_times__lte'] = self.maximum_uses_per_address
+            constraints['used_times__lt'] = self.maximum_uses_per_address
         return self._query_addresses(order_by="used_times ASC, position ASC", **constraints)
 
 
@@ -164,17 +170,20 @@ class SingleKey(AddressManager):
     def get_private_key(self, index: int) -> PrivateKey:
         return self.account.private_key
 
-    async def get_max_gap(self):
+    async def get_max_gap(self) -> int:
         return 0
 
-    async def ensure_address_gap(self):
-        exists = await self.get_address_records()
-        if not exists:
-            await self.account.ledger.db.add_keys(
-                self.account, self.chain_number, [(0, self.public_key)]
-            )
-            return [self.public_key.address]
-        return []
+    async def ensure_address_gap(self) -> List[str]:
+        async with self.address_generator_lock:
+            exists = await self.get_address_records()
+            if not exists:
+                await self.account.ledger.db.add_keys(
+                    self.account, self.chain_number, [(0, self.public_key)]
+                )
+                new_keys = [self.public_key.address]
+                await self.account.ledger.subscribe_addresses(self, new_keys)
+                return new_keys
+            return []
 
     def get_address_records(self, only_usable: bool = False, **constraints):
         return self._query_addresses(**constraints)
@@ -211,7 +220,7 @@ class BaseAccount:
         generator_name = address_generator.get('name', HierarchicalDeterministic.name)
         self.address_generator = self.address_generators[generator_name]
         self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
-        self.address_managers = {self.receiving, self.change}
+        self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
         ledger.add_account(self)
         wallet.add_account(self)
 
@@ -320,12 +329,12 @@ class BaseAccount:
 
     async def ensure_address_gap(self):
         addresses = []
-        for address_manager in self.address_managers:
+        for address_manager in self.address_managers.values():
            new_addresses = await address_manager.ensure_address_gap()
            addresses.extend(new_addresses)
         return addresses
 
-    async def get_addresses(self, **constraints):
+    async def get_addresses(self, **constraints) -> List[str]:
        rows = await self.ledger.db.select_addresses('address', account=self, **constraints)
        return [r[0] for r in rows]
 
@@ -337,8 +346,7 @@ class BaseAccount:
 
     def get_private_key(self, chain: int, index: int) -> PrivateKey:
         assert not self.encrypted, "Cannot get private key on encrypted wallet account."
-        address_manager = {0: self.receiving, 1: self.change}[chain]
-        return address_manager.get_private_key(index)
+        return self.address_managers[chain].get_private_key(index)
 
     def get_balance(self, confirmations: int = 0, **constraints):
         if confirmations > 0:
diff --git a/torba/client/basedatabase.py b/torba/client/basedatabase.py
index c64fedff6..dc5a1d93f 100644
--- a/torba/client/basedatabase.py
+++ b/torba/client/basedatabase.py
@@ -169,13 +169,16 @@ class SQLiteMixin:
         await self.db.close()
 
     @staticmethod
-    def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
+    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]:
         columns, values = [], []
         for column, value in data.items():
             columns.append(column)
             values.append(value)
-        sql = "INSERT INTO {} ({}) VALUES ({})".format(
-            table, ', '.join(columns), ', '.join(['?'] * len(values))
+        or_ignore = ""
+        if ignore_duplicate:
+            or_ignore = " OR IGNORE"
+        sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
+            or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values))
         )
         return sql, values
 
@@ -273,60 +276,49 @@ class BaseDatabase(SQLiteMixin):
             'script': sqlite3.Binary(txo.script.source)
         }
 
-    def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
+    async def insert_transaction(self, tx):
+        await self.db.execute(*self._insert_sql('tx', {
+            'txid': tx.id,
+            'raw': sqlite3.Binary(tx.raw),
+            'height': tx.height,
+            'position': tx.position,
+            'is_verified': tx.is_verified
+        }))
 
-        def _transaction(conn: sqlite3.Connection, save_tx, tx: BaseTransaction, address, txhash, history):
-            if save_tx == 'insert':
-                conn.execute(*self._insert_sql('tx', {
-                    'txid': tx.id,
-                    'raw': sqlite3.Binary(tx.raw),
-                    'height': tx.height,
-                    'position': tx.position,
-                    'is_verified': tx.is_verified
-                }))
-            elif save_tx == 'update':
-                conn.execute(*self._update_sql("tx", {
-                    'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
-                }, 'txid = ?', (tx.id,)))
+    async def update_transaction(self, tx):
+        await self.db.execute(*self._update_sql("tx", {
+            'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
+        }, 'txid = ?', (tx.id,)))
 
-            existing_txos = set(map(itemgetter(0), conn.execute(*query(
-                "SELECT position FROM txo", txid=tx.id
-            ))))
+    def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
+
+        def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
 
             for txo in tx.outputs:
-                if txo.position in existing_txos:
-                    continue
                 if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
-                    conn.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
+                    conn.execute(*self._insert_sql(
+                        "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
+                    ))
                 elif txo.script.is_pay_script_hash:
                     # TODO: implement script hash payments
                     log.warning('Database.save_transaction_io: pay script hash is not implemented!')
 
-            # lookup the address associated with each TXI (via its TXO)
-            txoid_to_address = {r[0]: r[1] for r in conn.execute(*query(
-                "SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
-            ))}
-
-            # list of TXIs that have already been added
-            existing_txis = {r[0] for r in conn.execute(*query(
-                "SELECT txoid FROM txi", txid=tx.id
-            ))}
-
             for txi in tx.inputs:
-                txoid = txi.txo_ref.id
-                new_txi = txoid not in existing_txis
-                address_matches = txoid_to_address.get(txoid) == address
-                if new_txi and address_matches:
-                    conn.execute(*self._insert_sql("txi", {
-                        'txid': tx.id,
-                        'txoid': txoid,
-                        'address': address,
-                    }))
+                if txi.txo_ref.txo is not None:
+                    txo = txi.txo_ref.txo
+                    if txo.get_address(self.ledger) == address:
+                        conn.execute(*self._insert_sql("txi", {
+                            'txid': tx.id,
+                            'txoid': txo.id,
+                            'address': address,
+                        }, ignore_duplicate=True))
+
             conn.execute(
                 "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
                 (history, history.count(':')//2, address)
             )
-        return self.db.run(_transaction, save_tx, tx, address, txhash, history)
+
+        return self.db.run(_transaction, tx, address, txhash, history)
 
     async def reserve_outputs(self, txos, is_reserved=True):
         txoids = [txo.id for txo in txos]
diff --git a/torba/client/baseledger.py b/torba/client/baseledger.py
index 11a603802..6fd86b008 100644
--- a/torba/client/baseledger.py
+++ b/torba/client/baseledger.py
@@ -5,7 +5,7 @@ from functools import partial
 from binascii import hexlify, unhexlify
 from io import StringIO
-from typing import Dict, Type, Iterable
+from typing import Dict, Type, Iterable, List, Optional
 from operator import itemgetter
 from collections import namedtuple
 
@@ -48,6 +48,51 @@ class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
     pass
 
 
+class TransactionCacheItem:
+    __slots__ = '_tx', 'lock', 'has_tx'
+
+    def __init__(self,
+                 tx: Optional[basetransaction.BaseTransaction] = None,
+                 lock: Optional[asyncio.Lock] = None):
+        self.has_tx = asyncio.Event()
+        self.lock = lock or asyncio.Lock()
+        self.tx = tx
+
+    @property
+    def tx(self):
+        return self._tx
+
+    @tx.setter
+    def tx(self, tx):
+        self._tx = tx
+        if tx is not None:
+            self.has_tx.set()
+
+
+class SynchronizationMonitor:
+
+    def __init__(self):
+        self.done = asyncio.Event()
+        self.tasks = []
+
+    def add(self, coro):
+        len(self.tasks) < 1 and self.done.clear()
+        asyncio.ensure_future(self._monitor(coro))
+
+    def cancel(self):
+        for task in self.tasks:
+            task.cancel()
+
+    async def _monitor(self, coro):
+        task = asyncio.ensure_future(coro)
+        self.tasks.append(task)
+        try:
+            await task
+        finally:
+            self.tasks.remove(task)
+            len(self.tasks) < 1 and self.done.set()
+
+
 class BaseLedger(metaclass=LedgerRegistry):
 
     name: str
@@ -79,7 +124,8 @@ class BaseLedger(metaclass=LedgerRegistry):
         )
         self.network = self.config.get('network') or self.network_class(self)
         self.network.on_header.listen(self.receive_header)
-        self.network.on_status.listen(self.receive_status)
+        self.network.on_status.listen(self.process_status_update)
+
         self.accounts = []
         self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
 
@@ -101,7 +147,8 @@ class BaseLedger(metaclass=LedgerRegistry):
             )
         )
 
-        self._transaction_processing_locks = {}
+        self._tx_cache = {}
+        self.sync = SynchronizationMonitor()
         self._utxo_reservation_lock = asyncio.Lock()
         self._header_processing_lock = asyncio.Lock()
 
@@ -166,16 +213,14 @@ class BaseLedger(metaclass=LedgerRegistry):
     def release_outputs(self, txos):
         return self.db.release_outputs(txos)
 
-    async def get_local_status(self, address):
-        address_details = await self.db.get_address(address=address)
-        history = address_details['history']
-        return hexlify(sha256(history.encode())).decode() if history else None
-
-    async def get_local_history(self, address):
+    async def get_local_status_and_history(self, address):
         address_details = await self.db.get_address(address=address)
         history = address_details['history'] or ''
         parts = history.split(':')[:-1]
-        return list(zip(parts[0::2], map(int, parts[1::2])))
+        return (
+            hexlify(sha256(history.encode())).decode() if history else None,
+            list(zip(parts[0::2], map(int, parts[1::2])))
+        )
 
     @staticmethod
     def get_root_of_merkle_tree(branches, branch_positions, working_branch):
@@ -189,21 +234,13 @@ class BaseLedger(metaclass=LedgerRegistry):
             working_branch = double_sha256(combined)
         return hexlify(working_branch[::-1])
 
-    async def validate_transaction_and_set_position(self, tx, height, merkle):
-        if not height <= len(self.headers):
-            return False
-        merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
-        header = self.headers[height]
-        tx.position = merkle['pos']
-        tx.is_verified = merkle_root == header['merkle_root']
-
     async def start(self):
         if not os.path.exists(self.path):
             os.mkdir(self.path)
-        await asyncio.gather(
+        await asyncio.wait([
             self.db.open(),
             self.headers.open()
-        )
+        ])
         first_connection = self.network.on_connected.first
         asyncio.ensure_future(self.network.start())
         await first_connection
@@ -214,9 +251,15 @@ class BaseLedger(metaclass=LedgerRegistry):
         log.info("Subscribing and updating accounts.")
         await self.update_headers()
         await self.network.subscribe_headers()
-        await self.update_accounts()
+        import time
+        start = time.time()
+        await self.subscribe_accounts()
+        await self.sync.done.wait()
+        log.info(f'elapsed: {time.time()-start}')
 
     async def stop(self):
+        self.sync.cancel()
+        await self.sync.done.wait()
         await self.network.stop()
         await self.db.close()
         await self.headers.close()
@@ -299,89 +342,144 @@ class BaseLedger(metaclass=LedgerRegistry):
             height=header['height'], headers=header['hex'], subscription_update=True
         )
 
-    async def update_accounts(self):
-        return await asyncio.gather(*(
-            self.update_account(a) for a in self.accounts
-        ))
+    async def subscribe_accounts(self):
+        if self.network.is_connected and self.accounts:
+            await asyncio.wait([
+                self.subscribe_account(a) for a in self.accounts
+            ])
 
-    async def update_account(self, account: baseaccount.BaseAccount):
+    async def subscribe_account(self, account: baseaccount.BaseAccount):
+        for address_manager in account.address_managers.values():
+            await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
         await account.ensure_address_gap()
-        addresses = await account.get_addresses()
-        while addresses:
-            await asyncio.gather(*(self.subscribe_history(a) for a in addresses))
-            addresses = await account.ensure_address_gap()
 
-    def _prefetch_history(self, remote_history, local_history):
-        proofs, network_txs = {}, {}
-        for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
-            if i < len(local_history) and local_history[i] == (hex_id, remote_height):
-                continue
-            if remote_height > 0:
-                proofs[hex_id] = asyncio.ensure_future(self.network.get_merkle(hex_id, remote_height))
-            network_txs[hex_id] = asyncio.ensure_future(self.network.get_transaction(hex_id))
-        return proofs, network_txs
+    async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
+        if self.network.is_connected and addresses:
+            await asyncio.wait([
+                self.subscribe_address(address_manager, address) for address in addresses
+            ])
 
-    async def update_history(self, address):
-        remote_history = await self.network.get_history(address)
-        local_history = await self.get_local_history(address)
-        proofs, network_txs = self._prefetch_history(remote_history, local_history)
-
-        synced_history = StringIO()
-        for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
-
-            synced_history.write('{}:{}:'.format(hex_id, remote_height))
-
-            if i < len(local_history) and local_history[i] == (hex_id, remote_height):
-                continue
-
-            lock = self._transaction_processing_locks.setdefault(hex_id, asyncio.Lock())
-
-            await lock.acquire()
-
-            try:
-
-                # see if we have a local copy of transaction, otherwise fetch it from server
-                tx = await self.db.get_transaction(txid=hex_id)
-                save_tx = None
-                if tx is None:
-                    _raw = await network_txs[hex_id]
-                    tx = self.transaction_class(unhexlify(_raw))
-                    save_tx = 'insert'
-
-                tx.height = remote_height
-
-                if remote_height > 0 and (not tx.is_verified or tx.position == -1):
-                    await self.validate_transaction_and_set_position(tx, remote_height, await proofs[hex_id])
-                    if save_tx is None:
-                        save_tx = 'update'
-
-                await self.db.save_transaction_io(
-                    save_tx, tx, address, self.address_to_hash160(address), synced_history.getvalue()
-                )
-
-                log.debug(
-                    "%s: sync'ed tx %s for address: %s, height: %s, verified: %s",
-                    self.get_id(), hex_id, address, tx.height, tx.is_verified
-                )
-
-                self._on_transaction_controller.add(TransactionEvent(address, tx))
-
-            finally:
-                lock.release()
-                if not lock.locked() and hex_id in self._transaction_processing_locks:
-                    del self._transaction_processing_locks[hex_id]
-
-    async def subscribe_history(self, address):
+    async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
         remote_status = await self.network.subscribe_address(address)
-        local_status = await self.get_local_status(address)
-        if local_status != remote_status:
-            await self.update_history(address)
+        self.sync.add(self.update_history(address, remote_status, address_manager))
 
-    async def receive_status(self, response):
-        address, remote_status = response
-        local_status = await self.get_local_status(address)
-        if local_status != remote_status:
-            await self.update_history(address)
+    def process_status_update(self, update):
+        address, remote_status = update
+        self.sync.add(self.update_history(address, remote_status))
+
+    async def update_history(self, address, remote_status,
+                             address_manager: baseaccount.AddressManager = None):
+        local_status, local_history = await self.get_local_status_and_history(address)
+
+        if local_status == remote_status:
+            return
+
+        remote_history = await self.network.get_history(address)
+
+        cache_tasks = []
+        synced_history = StringIO()
+        for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
+            if i < len(local_history) and local_history[i] == (txid, remote_height):
+                synced_history.write(f'{txid}:{remote_height}:')
+            else:
+                cache_tasks.append(asyncio.ensure_future(
+                    self.cache_transaction(txid, remote_height)
+                ))
+
+        for task in cache_tasks:
+            tx = await task
+
+            check_db_for_txos = []
+            for txi in tx.inputs:
+                if txi.txo_ref.txo is not None:
+                    continue
+                cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
+                if cache_item is not None:
+                    if cache_item.tx is None:
+                        await cache_item.has_tx.wait()
+                    txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
+                else:
+                    check_db_for_txos.append(txi.txo_ref.tx_ref.id)
+
+            referenced_txos = {
+                txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos)
+            }
+
+            for txi in tx.inputs:
+                if txi.txo_ref.txo is not None:
+                    continue
+                referenced_txo = referenced_txos.get(txi.txo_ref.tx_ref.id)
+                if referenced_txos:
+                    txi.txo_ref = referenced_txo.ref
+
+            synced_history.write(f'{tx.id}:{tx.height}:')
+
+            await self.db.save_transaction_io(
+                tx, address, self.address_to_hash160(address), synced_history.getvalue()
+            )
+
+            self._on_transaction_controller.add(TransactionEvent(address, tx))
+
+        if address_manager is None:
+            address_manager = await self.get_address_manager_for_address(address)
+
+        await address_manager.ensure_address_gap()
+
+    async def cache_transaction(self, txid, remote_height):
+        cache_item = self._tx_cache.get(txid)
+        if cache_item is None:
+            cache_item = self._tx_cache[txid] = TransactionCacheItem()
+        elif cache_item.tx is not None and \
+                cache_item.tx.height >= remote_height and \
+                (cache_item.tx.is_verified or remote_height < 1):
+            return cache_item.tx  # cached tx is already up-to-date
+
+        await cache_item.lock.acquire()
+
+        try:
+            tx = cache_item.tx
+
+            if tx is None:
+                # check local db
+                tx = cache_item.tx = await self.db.get_transaction(txid=txid)
+
+            if tx is None:
+                # fetch from network
+                _raw = await self.network.get_transaction(txid)
+                if _raw:
+                    tx = self.transaction_class(unhexlify(_raw))
+                    await self.maybe_verify_transaction(tx, remote_height)
+                    await self.db.insert_transaction(tx)
+                    cache_item.tx = tx  # make sure it's saved before caching it
+                    return tx
+
+            if tx is None:
+                raise ValueError(f'Transaction {txid} was not in database and not on network.')
+
+            if 0 < remote_height and not tx.is_verified:
+                # tx from cache / db is not up-to-date
+                await self.maybe_verify_transaction(tx, remote_height)
+                await self.db.update_transaction(tx)
+
+            return tx
+
+        finally:
+            cache_item.lock.release()
+
+    async def maybe_verify_transaction(self, tx, remote_height):
+        tx.height = remote_height
+        if 0 < remote_height <= len(self.headers):
+            merkle = await self.network.get_merkle(tx.id, remote_height)
+            merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
+            header = self.headers[remote_height]
+            tx.position = merkle['pos']
+            tx.is_verified = merkle_root == header['merkle_root']
+
+    async def get_address_manager_for_address(self, address) -> baseaccount.AddressManager:
+        details = await self.db.get_address(address=address)
+        for account in self.accounts:
+            if account.id == details['account']:
+                return account.address_managers[details['chain']]
 
     def broadcast(self, tx):
         return self.network.broadcast(hexlify(tx.raw).decode())
diff --git a/torba/client/util.py b/torba/client/util.py
index 8ae838027..f57963a89 100644
--- a/torba/client/util.py
+++ b/torba/client/util.py
@@ -1,5 +1,25 @@
+import re
 from binascii import unhexlify, hexlify
 from typing import TypeVar, Sequence, Optional
+from torba.client.constants import COIN
+
+
+def coins_to_satoshis(coins):
+    if not isinstance(coins, str):
+        raise ValueError("{coins} must be a string")
+    result = re.search(r'^(\d{1,10})\.(\d{1,8})$', coins)
+    if result is not None:
+        whole, fractional = result.groups()
+        return int(whole+fractional.ljust(8, "0"))
+    raise ValueError("'{lbc}' is not a valid coin decimal")
+
+
+def satoshis_to_coins(satoshis):
+    coins = '{:.8f}'.format(satoshis / COIN).rstrip('0')
+    if coins.endswith('.'):
+        return coins+'0'
+    else:
+        return coins
 
 
 T = TypeVar('T')
diff --git a/torba/orchstr8/node.py b/torba/orchstr8/node.py
index 7ebd2f879..46c9b9bb1 100644
--- a/torba/orchstr8/node.py
+++ b/torba/orchstr8/node.py
@@ -112,7 +112,7 @@ class Conductor:
 
 class WalletNode:
     def __init__(self, manager_class: Type[BaseWalletManager], ledger_class: Type[BaseLedger],
-                 verbose: bool = False) -> None:
+                 verbose: bool = False, api_port: int = 5279) -> None:
         self.manager_class = manager_class
         self.ledger_class = ledger_class
         self.verbose = verbose
@@ -121,8 +121,9 @@ class WalletNode:
         self.wallet: Optional[Wallet] = None
         self.account: Optional[BaseAccount] = None
         self.data_path: Optional[str] = None
+        self.api_port = api_port
 
-    async def start(self):
+    async def start(self, seed=None):
         self.data_path = tempfile.mkdtemp()
         wallet_file_name = os.path.join(self.data_path, 'my_wallet.json')
         with open(wallet_file_name, 'w') as wallet_file:
@@ -130,6 +131,7 @@ class WalletNode:
         self.manager = self.manager_class.from_config({
             'ledgers': {
                 self.ledger_class.get_id(): {
+                    'api_port': self.api_port,
                    'default_servers': [('localhost', 1984)],
                    'data_path': self.data_path
                }
@@ -138,7 +140,12 @@ class WalletNode:
         })
         self.ledger = self.manager.ledgers[self.ledger_class]
         self.wallet = self.manager.default_wallet
-        self.wallet.generate_account(self.ledger)
+        if seed is None:
+            self.wallet.generate_account(self.ledger)
+        else:
+            self.ledger.account_class.from_dict(
+                self.ledger, self.wallet, {'seed': seed}
+            )
         self.account = self.wallet.default_account
         await self.manager.start()
 
diff --git a/torba/stream.py b/torba/stream.py
index 64c1f5e10..40589ade0 100644
--- a/torba/stream.py
+++ b/torba/stream.py
@@ -29,18 +29,18 @@ class BroadcastSubscription:
 
     def _add(self, data):
         if self.can_fire and self._on_data is not None:
-            maybe_coroutine = self._on_data(data)
-            if asyncio.iscoroutine(maybe_coroutine):
-                asyncio.ensure_future(maybe_coroutine)
+            return self._on_data(data)
 
     def _add_error(self, exception):
         if self.can_fire and self._on_error is not None:
-            self._on_error(exception)
+            return self._on_error(exception)
 
     def _close(self):
-        if self.can_fire and self._on_done is not None:
-            self._on_done()
-        self.is_closed = True
+        try:
+            if self.can_fire and self._on_done is not None:
+                return self._on_done()
+        finally:
+            self.is_closed = True
 
 
 class StreamController:
@@ -62,13 +62,28 @@ class StreamController:
             next_sub = next_sub._next
             yield subscription
 
-    def add(self, event):
+    def _notify_and_ensure_future(self, notify):
+        tasks = []
         for subscription in self._iterate_subscriptions:
-            subscription._add(event)
+            maybe_coroutine = notify(subscription)
+            if asyncio.iscoroutine(maybe_coroutine):
+                tasks.append(maybe_coroutine)
+        if tasks:
+            return asyncio.ensure_future(asyncio.wait(tasks))
+        else:
+            f = asyncio.get_event_loop().create_future()
+            f.set_result(None)
+            return f
+
+    def add(self, event):
+        return self._notify_and_ensure_future(
+            lambda subscription: subscription._add(event)
+        )
 
     def add_error(self, exception):
-        for subscription in self._iterate_subscriptions:
-            subscription._add_error(exception)
+        return self._notify_and_ensure_future(
+            lambda subscription: subscription._add_error(exception)
+        )
 
     def close(self):
         for subscription in self._iterate_subscriptions:
diff --git a/torba/testcase.py b/torba/testcase.py
index d8f1f9ef0..30c2f3526 100644
--- a/torba/testcase.py
+++ b/torba/testcase.py
@@ -9,6 +9,7 @@ from torba.client.baseledger import BaseLedger
 from torba.client.baseaccount import BaseAccount
 from torba.client.basemanager import BaseWalletManager
 from torba.client.wallet import Wallet
+from torba.client.util import satoshis_to_coins
 
 
 try:
@@ -159,15 +160,13 @@ class IntegrationTestCase(AsyncioTestCase):
     async def asyncTearDown(self):
         await self.conductor.stop()
 
+    async def assertBalance(self, account, expected_balance: str):
+        balance = await account.get_balance()
+        self.assertEqual(satoshis_to_coins(balance), expected_balance)
+
     def broadcast(self, tx):
         return self.ledger.broadcast(tx)
 
-    def get_balance(self, account=None, confirmations=0):
-        if account is None:
-            return self.manager.get_balance(confirmations=confirmations)
-        else:
-            return account.get_balance(confirmations=confirmations)
-
     async def on_header(self, height):
         if self.ledger.headers.height < height:
             await self.ledger.on_header.where(
@@ -175,8 +174,8 @@ class IntegrationTestCase(AsyncioTestCase):
             )
             return True
 
-    def on_transaction_id(self, txid):
-        return self.ledger.on_transaction.where(
+    def on_transaction_id(self, txid, ledger=None):
+        return (ledger or self.ledger).on_transaction.where(
             lambda e: e.tx.id == txid
         )