refactored tx sync

This commit is contained in:
Lex Berezhny 2018-11-18 22:54:00 -05:00
parent e745b6c16e
commit 345d4f8ab1
13 changed files with 425 additions and 234 deletions

View file

@ -10,10 +10,10 @@ class BasicTransactionTests(IntegrationTestCase):
async def test_sending_and_receiving(self):
account1, account2 = self.account, self.wallet.generate_account(self.ledger)
await self.ledger.update_account(account2)
await self.ledger.subscribe_account(account2)
self.assertEqual(await self.get_balance(account1), 0)
self.assertEqual(await self.get_balance(account2), 0)
await self.assertBalance(account1, '0.0')
await self.assertBalance(account2, '0.0')
sendtxids = []
for i in range(5):
@ -26,8 +26,8 @@ class BasicTransactionTests(IntegrationTestCase):
self.on_transaction_id(txid) for txid in sendtxids
])
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5)
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0)
await self.assertBalance(account1, '5.5')
await self.assertBalance(account2, '0.0')
address2 = await account2.receiving.get_or_create_usable_address()
hash2 = self.ledger.address_to_hash160(address2)
@ -41,8 +41,8 @@ class BasicTransactionTests(IntegrationTestCase):
await self.blockchain.generate(1)
await self.ledger.wait(tx) # confirmed
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
await self.assertBalance(account1, '3.499802')
await self.assertBalance(account2, '2.0')
utxos = await self.account.get_utxos()
tx = await self.ledger.transaction_class.create(

View file

@ -44,7 +44,8 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
self.assertEqual(len(addresses), 26)
async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
await self.account.receiving.generate_keys(0, 200)
async with self.account.receiving.address_generator_lock:
await self.account.receiving._generate_keys(0, 200)
records = await self.account.receiving.get_address_records()
self.assertEqual(201, len(records))
@ -53,9 +54,10 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
await account.receiving.generate_keys(4, 7)
await account.receiving.generate_keys(0, 3)
await account.receiving.generate_keys(8, 11)
async with account.receiving.address_generator_lock:
await account.receiving._generate_keys(4, 7)
await account.receiving._generate_keys(0, 3)
await account.receiving._generate_keys(8, 11)
records = await account.receiving.get_address_records()
self.assertEqual(
[r['position'] for r in records],

View file

@ -128,7 +128,8 @@ class TestQueries(AsyncioTestCase):
tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
.add_outputs([self.txo(1, to_hash)])
await self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
await self.ledger.db.insert_transaction(tx)
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
return tx
async def create_tx_from_txo(self, txo, to_account, height):
@ -139,8 +140,9 @@ class TestQueries(AsyncioTestCase):
tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)])
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
await self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
await self.ledger.db.insert_transaction(tx)
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
return tx
async def create_tx_to_nowhere(self, txo, height):
@ -150,7 +152,8 @@ class TestQueries(AsyncioTestCase):
tx = ledger_class.transaction_class(height=height, is_verified=True) \
.add_inputs([self.txi(txo)]) \
.add_outputs([self.txo(1, to_hash)])
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
await self.ledger.db.insert_transaction(tx)
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
return tx
def txo(self, amount, address):

View file

@ -16,6 +16,7 @@ class MockNetwork:
self.address = None
self.get_history_called = []
self.get_transaction_called = []
self.is_connected = False
async def get_history(self, address):
self.get_history_called.append(address)
@ -85,16 +86,21 @@ class TestSynchronization(LedgerTestCase):
'abcd02': hexlify(get_transaction(get_output(2)).raw),
'abcd03': hexlify(get_transaction(get_output(3)).raw),
})
await self.ledger.update_history(address)
await self.ledger.update_history(address, '')
self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
self.assertEqual(
address_details['history'],
'252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
)
self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = []
await self.ledger.update_history(address)
await self.ledger.update_history(address, '')
self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, [])
@ -102,11 +108,17 @@ class TestSynchronization(LedgerTestCase):
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = []
await self.ledger.update_history(address)
await self.ledger.update_history(address, '')
self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
self.assertEqual(
address_details['history'],
'252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
'047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828:3:'
)
class MocHeaderNetwork:

View file

@ -261,14 +261,14 @@ class TransactionIOBalancing(AsyncioTestCase):
.add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \
.add_outputs(utxos)
save_tx = 'insert'
await self.ledger.db.insert_transaction(self.funding_tx)
for utxo in utxos:
await self.ledger.db.save_transaction_io(
save_tx, self.funding_tx,
self.funding_tx,
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
utxo.script.values['pubkey_hash'], ''
)
save_tx = 'update'
return utxos

View file

@ -1,6 +1,41 @@
import unittest
from torba.client.util import ArithUint256
from torba.client.util import coins_to_satoshis as c2s, satoshis_to_coins as s2c
class TestCoinValueParsing(unittest.TestCase):
def test_good_output(self):
self.assertEqual(s2c(1), "0.00000001")
self.assertEqual(s2c(10**7), "0.1")
self.assertEqual(s2c(2*10**8), "2.0")
self.assertEqual(s2c(2*10**17), "2000000000.0")
def test_good_input(self):
self.assertEqual(c2s("0.00000001"), 1)
self.assertEqual(c2s("0.1"), 10**7)
self.assertEqual(c2s("1.0"), 10**8)
self.assertEqual(c2s("2.00000000"), 2*10**8)
self.assertEqual(c2s("2000000000.0"), 2*10**17)
def test_bad_input(self):
with self.assertRaises(ValueError):
c2s("1")
with self.assertRaises(ValueError):
c2s("-1.0")
with self.assertRaises(ValueError):
c2s("10000000000.0")
with self.assertRaises(ValueError):
c2s("1.000000000")
with self.assertRaises(ValueError):
c2s("-0")
with self.assertRaises(ValueError):
c2s("1")
with self.assertRaises(ValueError):
c2s(".1")
with self.assertRaises(ValueError):
c2s("1e-7")
class TestArithUint256(unittest.TestCase):

View file

@ -1,6 +1,7 @@
import asyncio
import random
import typing
from typing import Dict, Tuple, Type, Optional, Any
from typing import Dict, Tuple, Type, Optional, Any, List
from torba.client.mnemonic import Mnemonic
from torba.client.bip32 import PrivateKey, PubKey, from_extended_key_string
@ -15,12 +16,13 @@ class AddressManager:
name: str
__slots__ = 'account', 'public_key', 'chain_number'
__slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'
def __init__(self, account, public_key, chain_number):
self.account = account
self.public_key = public_key
self.chain_number = chain_number
self.address_generator_lock = asyncio.Lock()
@classmethod
def from_dict(cls, account: 'BaseAccount', d: dict) \
@ -60,11 +62,11 @@ class AddressManager:
def get_address_records(self, only_usable: bool = False, **constraints):
raise NotImplementedError
async def get_addresses(self, only_usable: bool = False, **constraints):
async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]:
records = await self.get_address_records(only_usable=only_usable, **constraints)
return [r['address'] for r in records]
async def get_or_create_usable_address(self):
async def get_or_create_usable_address(self) -> str:
addresses = await self.get_addresses(only_usable=True, limit=10)
if addresses:
return random.choice(addresses)
@ -87,8 +89,8 @@ class HierarchicalDeterministic(AddressManager):
@classmethod
def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]:
return (
cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 2})),
cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 2}))
cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
)
def to_dict_instance(self):
@ -97,19 +99,7 @@ class HierarchicalDeterministic(AddressManager):
def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key.child(self.chain_number).child(index)
async def generate_keys(self, start: int, end: int):
keys_batch, final_keys = [], []
for index in range(start, end+1):
keys_batch.append((index, self.public_key.child(index)))
if index % 180 == 0 or index == end:
await self.account.ledger.db.add_keys(
self.account, self.chain_number, keys_batch
)
final_keys.extend(keys_batch)
keys_batch.clear()
return [key[1].address for key in final_keys]
async def get_max_gap(self):
async def get_max_gap(self) -> int:
addresses = await self._query_addresses(order_by="position ASC")
max_gap = 0
current_gap = 0
@ -121,7 +111,8 @@ class HierarchicalDeterministic(AddressManager):
current_gap = 0
return max_gap
async def ensure_address_gap(self):
async def ensure_address_gap(self) -> List[str]:
async with self.address_generator_lock:
addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
existing_gap = 0
@ -136,12 +127,27 @@ class HierarchicalDeterministic(AddressManager):
start = addresses[0]['position']+1 if addresses else 0
end = start + (self.gap - existing_gap)
new_keys = await self.generate_keys(start, end-1)
new_keys = await self._generate_keys(start, end-1)
await self.account.ledger.subscribe_addresses(self, new_keys)
return new_keys
async def _generate_keys(self, start: int, end: int) -> List[str]:
if not self.address_generator_lock.locked():
raise RuntimeError('Should not be called outside of address_generator_lock.')
keys_batch, final_keys = [], []
for index in range(start, end+1):
keys_batch.append((index, self.public_key.child(index)))
if index % 180 == 0 or index == end:
await self.account.ledger.db.add_keys(
self.account, self.chain_number, keys_batch
)
final_keys.extend(keys_batch)
keys_batch.clear()
return [key[1].address for key in final_keys]
def get_address_records(self, only_usable: bool = False, **constraints):
if only_usable:
constraints['used_times__lte'] = self.maximum_uses_per_address
constraints['used_times__lt'] = self.maximum_uses_per_address
return self._query_addresses(order_by="used_times ASC, position ASC", **constraints)
@ -164,16 +170,19 @@ class SingleKey(AddressManager):
def get_private_key(self, index: int) -> PrivateKey:
return self.account.private_key
async def get_max_gap(self):
async def get_max_gap(self) -> int:
return 0
async def ensure_address_gap(self):
async def ensure_address_gap(self) -> List[str]:
async with self.address_generator_lock:
exists = await self.get_address_records()
if not exists:
await self.account.ledger.db.add_keys(
self.account, self.chain_number, [(0, self.public_key)]
)
return [self.public_key.address]
new_keys = [self.public_key.address]
await self.account.ledger.subscribe_addresses(self, new_keys)
return new_keys
return []
def get_address_records(self, only_usable: bool = False, **constraints):
@ -211,7 +220,7 @@ class BaseAccount:
generator_name = address_generator.get('name', HierarchicalDeterministic.name)
self.address_generator = self.address_generators[generator_name]
self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
self.address_managers = {self.receiving, self.change}
self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
ledger.add_account(self)
wallet.add_account(self)
@ -320,12 +329,12 @@ class BaseAccount:
async def ensure_address_gap(self):
addresses = []
for address_manager in self.address_managers:
for address_manager in self.address_managers.values():
new_addresses = await address_manager.ensure_address_gap()
addresses.extend(new_addresses)
return addresses
async def get_addresses(self, **constraints):
async def get_addresses(self, **constraints) -> List[str]:
rows = await self.ledger.db.select_addresses('address', account=self, **constraints)
return [r[0] for r in rows]
@ -337,8 +346,7 @@ class BaseAccount:
def get_private_key(self, chain: int, index: int) -> PrivateKey:
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
address_manager = {0: self.receiving, 1: self.change}[chain]
return address_manager.get_private_key(index)
return self.address_managers[chain].get_private_key(index)
def get_balance(self, confirmations: int = 0, **constraints):
if confirmations > 0:

View file

@ -169,13 +169,16 @@ class SQLiteMixin:
await self.db.close()
@staticmethod
def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]:
columns, values = [], []
for column, value in data.items():
columns.append(column)
values.append(value)
sql = "INSERT INTO {} ({}) VALUES ({})".format(
table, ', '.join(columns), ', '.join(['?'] * len(values))
or_ignore = ""
if ignore_duplicate:
or_ignore = " OR IGNORE"
sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values))
)
return sql, values
@ -273,60 +276,49 @@ class BaseDatabase(SQLiteMixin):
'script': sqlite3.Binary(txo.script.source)
}
def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
def _transaction(conn: sqlite3.Connection, save_tx, tx: BaseTransaction, address, txhash, history):
if save_tx == 'insert':
conn.execute(*self._insert_sql('tx', {
async def insert_transaction(self, tx):
await self.db.execute(*self._insert_sql('tx', {
'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'height': tx.height,
'position': tx.position,
'is_verified': tx.is_verified
}))
elif save_tx == 'update':
conn.execute(*self._update_sql("tx", {
async def update_transaction(self, tx):
await self.db.execute(*self._update_sql("tx", {
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
}, 'txid = ?', (tx.id,)))
existing_txos = set(map(itemgetter(0), conn.execute(*query(
"SELECT position FROM txo", txid=tx.id
))))
def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
for txo in tx.outputs:
if txo.position in existing_txos:
continue
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
conn.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
conn.execute(*self._insert_sql(
"txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
))
elif txo.script.is_pay_script_hash:
# TODO: implement script hash payments
log.warning('Database.save_transaction_io: pay script hash is not implemented!')
# lookup the address associated with each TXI (via its TXO)
txoid_to_address = {r[0]: r[1] for r in conn.execute(*query(
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
))}
# list of TXIs that have already been added
existing_txis = {r[0] for r in conn.execute(*query(
"SELECT txoid FROM txi", txid=tx.id
))}
for txi in tx.inputs:
txoid = txi.txo_ref.id
new_txi = txoid not in existing_txis
address_matches = txoid_to_address.get(txoid) == address
if new_txi and address_matches:
if txi.txo_ref.txo is not None:
txo = txi.txo_ref.txo
if txo.get_address(self.ledger) == address:
conn.execute(*self._insert_sql("txi", {
'txid': tx.id,
'txoid': txoid,
'txoid': txo.id,
'address': address,
}))
}, ignore_duplicate=True))
conn.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(history, history.count(':')//2, address)
)
return self.db.run(_transaction, save_tx, tx, address, txhash, history)
return self.db.run(_transaction, tx, address, txhash, history)
async def reserve_outputs(self, txos, is_reserved=True):
txoids = [txo.id for txo in txos]

View file

@ -5,7 +5,7 @@ from functools import partial
from binascii import hexlify, unhexlify
from io import StringIO
from typing import Dict, Type, Iterable
from typing import Dict, Type, Iterable, List, Optional
from operator import itemgetter
from collections import namedtuple
@ -48,6 +48,51 @@ class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
pass
class TransactionCacheItem:
__slots__ = '_tx', 'lock', 'has_tx'
def __init__(self,
tx: Optional[basetransaction.BaseTransaction] = None,
lock: Optional[asyncio.Lock] = None):
self.has_tx = asyncio.Event()
self.lock = lock or asyncio.Lock()
self.tx = tx
@property
def tx(self):
return self._tx
@tx.setter
def tx(self, tx):
self._tx = tx
if tx is not None:
self.has_tx.set()
class SynchronizationMonitor:
def __init__(self):
self.done = asyncio.Event()
self.tasks = []
def add(self, coro):
len(self.tasks) < 1 and self.done.clear()
asyncio.ensure_future(self._monitor(coro))
def cancel(self):
for task in self.tasks:
task.cancel()
async def _monitor(self, coro):
task = asyncio.ensure_future(coro)
self.tasks.append(task)
try:
await task
finally:
self.tasks.remove(task)
len(self.tasks) < 1 and self.done.set()
class BaseLedger(metaclass=LedgerRegistry):
name: str
@ -79,7 +124,8 @@ class BaseLedger(metaclass=LedgerRegistry):
)
self.network = self.config.get('network') or self.network_class(self)
self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.receive_status)
self.network.on_status.listen(self.process_status_update)
self.accounts = []
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
@ -101,7 +147,8 @@ class BaseLedger(metaclass=LedgerRegistry):
)
)
self._transaction_processing_locks = {}
self._tx_cache = {}
self.sync = SynchronizationMonitor()
self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = asyncio.Lock()
@ -166,16 +213,14 @@ class BaseLedger(metaclass=LedgerRegistry):
def release_outputs(self, txos):
return self.db.release_outputs(txos)
async def get_local_status(self, address):
address_details = await self.db.get_address(address=address)
history = address_details['history']
return hexlify(sha256(history.encode())).decode() if history else None
async def get_local_history(self, address):
async def get_local_status_and_history(self, address):
address_details = await self.db.get_address(address=address)
history = address_details['history'] or ''
parts = history.split(':')[:-1]
return list(zip(parts[0::2], map(int, parts[1::2])))
return (
hexlify(sha256(history.encode())).decode() if history else None,
list(zip(parts[0::2], map(int, parts[1::2])))
)
@staticmethod
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
@ -189,21 +234,13 @@ class BaseLedger(metaclass=LedgerRegistry):
working_branch = double_sha256(combined)
return hexlify(working_branch[::-1])
async def validate_transaction_and_set_position(self, tx, height, merkle):
if not height <= len(self.headers):
return False
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[height]
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
async def start(self):
if not os.path.exists(self.path):
os.mkdir(self.path)
await asyncio.gather(
await asyncio.wait([
self.db.open(),
self.headers.open()
)
])
first_connection = self.network.on_connected.first
asyncio.ensure_future(self.network.start())
await first_connection
@ -214,9 +251,15 @@ class BaseLedger(metaclass=LedgerRegistry):
log.info("Subscribing and updating accounts.")
await self.update_headers()
await self.network.subscribe_headers()
await self.update_accounts()
import time
start = time.time()
await self.subscribe_accounts()
await self.sync.done.wait()
log.info(f'elapsed: {time.time()-start}')
async def stop(self):
self.sync.cancel()
await self.sync.done.wait()
await self.network.stop()
await self.db.close()
await self.headers.close()
@ -299,89 +342,144 @@ class BaseLedger(metaclass=LedgerRegistry):
height=header['height'], headers=header['hex'], subscription_update=True
)
async def update_accounts(self):
return await asyncio.gather(*(
self.update_account(a) for a in self.accounts
async def subscribe_accounts(self):
if self.network.is_connected and self.accounts:
await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
])
async def subscribe_account(self, account: baseaccount.BaseAccount):
for address_manager in account.address_managers.values():
await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
await account.ensure_address_gap()
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
if self.network.is_connected and addresses:
await asyncio.wait([
self.subscribe_address(address_manager, address) for address in addresses
])
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
remote_status = await self.network.subscribe_address(address)
self.sync.add(self.update_history(address, remote_status, address_manager))
def process_status_update(self, update):
address, remote_status = update
self.sync.add(self.update_history(address, remote_status))
async def update_history(self, address, remote_status,
address_manager: baseaccount.AddressManager = None):
local_status, local_history = await self.get_local_status_and_history(address)
if local_status == remote_status:
return
remote_history = await self.network.get_history(address)
cache_tasks = []
synced_history = StringIO()
for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
if i < len(local_history) and local_history[i] == (txid, remote_height):
synced_history.write(f'{txid}:{remote_height}:')
else:
cache_tasks.append(asyncio.ensure_future(
self.cache_transaction(txid, remote_height)
))
async def update_account(self, account: baseaccount.BaseAccount):
await account.ensure_address_gap()
addresses = await account.get_addresses()
while addresses:
await asyncio.gather(*(self.subscribe_history(a) for a in addresses))
addresses = await account.ensure_address_gap()
for task in cache_tasks:
tx = await task
def _prefetch_history(self, remote_history, local_history):
proofs, network_txs = {}, {}
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
check_db_for_txos = []
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
if remote_height > 0:
proofs[hex_id] = asyncio.ensure_future(self.network.get_merkle(hex_id, remote_height))
network_txs[hex_id] = asyncio.ensure_future(self.network.get_transaction(hex_id))
return proofs, network_txs
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
if cache_item is not None:
if cache_item.tx is None:
await cache_item.has_tx.wait()
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else:
check_db_for_txos.append(txi.txo_ref.tx_ref.id)
async def update_history(self, address):
remote_history = await self.network.get_history(address)
local_history = await self.get_local_history(address)
proofs, network_txs = self._prefetch_history(remote_history, local_history)
referenced_txos = {
txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos)
}
synced_history = StringIO()
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
synced_history.write('{}:{}:'.format(hex_id, remote_height))
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
referenced_txo = referenced_txos.get(txi.txo_ref.tx_ref.id)
if referenced_txos:
txi.txo_ref = referenced_txo.ref
lock = self._transaction_processing_locks.setdefault(hex_id, asyncio.Lock())
await lock.acquire()
try:
# see if we have a local copy of transaction, otherwise fetch it from server
tx = await self.db.get_transaction(txid=hex_id)
save_tx = None
if tx is None:
_raw = await network_txs[hex_id]
tx = self.transaction_class(unhexlify(_raw))
save_tx = 'insert'
tx.height = remote_height
if remote_height > 0 and (not tx.is_verified or tx.position == -1):
await self.validate_transaction_and_set_position(tx, remote_height, await proofs[hex_id])
if save_tx is None:
save_tx = 'update'
synced_history.write(f'{tx.id}:{tx.height}:')
await self.db.save_transaction_io(
save_tx, tx, address, self.address_to_hash160(address), synced_history.getvalue()
)
log.debug(
"%s: sync'ed tx %s for address: %s, height: %s, verified: %s",
self.get_id(), hex_id, address, tx.height, tx.is_verified
tx, address, self.address_to_hash160(address), synced_history.getvalue()
)
self._on_transaction_controller.add(TransactionEvent(address, tx))
if address_manager is None:
address_manager = await self.get_address_manager_for_address(address)
await address_manager.ensure_address_gap()
async def cache_transaction(self, txid, remote_height):
cache_item = self._tx_cache.get(txid)
if cache_item is None:
cache_item = self._tx_cache[txid] = TransactionCacheItem()
elif cache_item.tx is not None and \
cache_item.tx.height >= remote_height and \
(cache_item.tx.is_verified or remote_height < 1):
return cache_item.tx # cached tx is already up-to-date
await cache_item.lock.acquire()
try:
tx = cache_item.tx
if tx is None:
# check local db
tx = cache_item.tx = await self.db.get_transaction(txid=txid)
if tx is None:
# fetch from network
_raw = await self.network.get_transaction(txid)
if _raw:
tx = self.transaction_class(unhexlify(_raw))
await self.maybe_verify_transaction(tx, remote_height)
await self.db.insert_transaction(tx)
cache_item.tx = tx # make sure it's saved before caching it
return tx
if tx is None:
raise ValueError(f'Transaction {txid} was not in database and not on network.')
if 0 < remote_height and not tx.is_verified:
# tx from cache / db is not up-to-date
await self.maybe_verify_transaction(tx, remote_height)
await self.db.update_transaction(tx)
return tx
finally:
lock.release()
if not lock.locked() and hex_id in self._transaction_processing_locks:
del self._transaction_processing_locks[hex_id]
cache_item.lock.release()
async def subscribe_history(self, address):
remote_status = await self.network.subscribe_address(address)
local_status = await self.get_local_status(address)
if local_status != remote_status:
await self.update_history(address)
async def maybe_verify_transaction(self, tx, remote_height):
tx.height = remote_height
if 0 < remote_height <= len(self.headers):
merkle = await self.network.get_merkle(tx.id, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[remote_height]
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
async def receive_status(self, response):
address, remote_status = response
local_status = await self.get_local_status(address)
if local_status != remote_status:
await self.update_history(address)
async def get_address_manager_for_address(self, address) -> baseaccount.AddressManager:
details = await self.db.get_address(address=address)
for account in self.accounts:
if account.id == details['account']:
return account.address_managers[details['chain']]
def broadcast(self, tx):
return self.network.broadcast(hexlify(tx.raw).decode())

View file

@ -1,5 +1,25 @@
import re
from binascii import unhexlify, hexlify
from typing import TypeVar, Sequence, Optional
from torba.client.constants import COIN
def coins_to_satoshis(coins):
if not isinstance(coins, str):
raise ValueError("{coins} must be a string")
result = re.search(r'^(\d{1,10})\.(\d{1,8})$', coins)
if result is not None:
whole, fractional = result.groups()
return int(whole+fractional.ljust(8, "0"))
raise ValueError("'{lbc}' is not a valid coin decimal")
def satoshis_to_coins(satoshis):
coins = '{:.8f}'.format(satoshis / COIN).rstrip('0')
if coins.endswith('.'):
return coins+'0'
else:
return coins
T = TypeVar('T')

View file

@ -112,7 +112,7 @@ class Conductor:
class WalletNode:
def __init__(self, manager_class: Type[BaseWalletManager], ledger_class: Type[BaseLedger],
verbose: bool = False) -> None:
verbose: bool = False, api_port: int = 5279) -> None:
self.manager_class = manager_class
self.ledger_class = ledger_class
self.verbose = verbose
@ -121,8 +121,9 @@ class WalletNode:
self.wallet: Optional[Wallet] = None
self.account: Optional[BaseAccount] = None
self.data_path: Optional[str] = None
self.api_port = api_port
async def start(self):
async def start(self, seed=None):
self.data_path = tempfile.mkdtemp()
wallet_file_name = os.path.join(self.data_path, 'my_wallet.json')
with open(wallet_file_name, 'w') as wallet_file:
@ -130,6 +131,7 @@ class WalletNode:
self.manager = self.manager_class.from_config({
'ledgers': {
self.ledger_class.get_id(): {
'api_port': self.api_port,
'default_servers': [('localhost', 1984)],
'data_path': self.data_path
}
@ -138,7 +140,12 @@ class WalletNode:
})
self.ledger = self.manager.ledgers[self.ledger_class]
self.wallet = self.manager.default_wallet
if seed is None:
self.wallet.generate_account(self.ledger)
else:
self.ledger.account_class.from_dict(
self.ledger, self.wallet, {'seed': seed}
)
self.account = self.wallet.default_account
await self.manager.start()

View file

@ -29,17 +29,17 @@ class BroadcastSubscription:
def _add(self, data):
if self.can_fire and self._on_data is not None:
maybe_coroutine = self._on_data(data)
if asyncio.iscoroutine(maybe_coroutine):
asyncio.ensure_future(maybe_coroutine)
return self._on_data(data)
def _add_error(self, exception):
if self.can_fire and self._on_error is not None:
self._on_error(exception)
return self._on_error(exception)
def _close(self):
try:
if self.can_fire and self._on_done is not None:
self._on_done()
return self._on_done()
finally:
self.is_closed = True
@ -62,13 +62,28 @@ class StreamController:
next_sub = next_sub._next
yield subscription
def add(self, event):
def _notify_and_ensure_future(self, notify):
tasks = []
for subscription in self._iterate_subscriptions:
subscription._add(event)
maybe_coroutine = notify(subscription)
if asyncio.iscoroutine(maybe_coroutine):
tasks.append(maybe_coroutine)
if tasks:
return asyncio.ensure_future(asyncio.wait(tasks))
else:
f = asyncio.get_event_loop().create_future()
f.set_result(None)
return f
def add(self, event):
return self._notify_and_ensure_future(
lambda subscription: subscription._add(event)
)
def add_error(self, exception):
for subscription in self._iterate_subscriptions:
subscription._add_error(exception)
return self._notify_and_ensure_future(
lambda subscription: subscription._add_error(exception)
)
def close(self):
for subscription in self._iterate_subscriptions:

View file

@ -9,6 +9,7 @@ from torba.client.baseledger import BaseLedger
from torba.client.baseaccount import BaseAccount
from torba.client.basemanager import BaseWalletManager
from torba.client.wallet import Wallet
from torba.client.util import satoshis_to_coins
try:
@ -159,15 +160,13 @@ class IntegrationTestCase(AsyncioTestCase):
async def asyncTearDown(self):
await self.conductor.stop()
async def assertBalance(self, account, expected_balance: str):
balance = await account.get_balance()
self.assertEqual(satoshis_to_coins(balance), expected_balance)
def broadcast(self, tx):
return self.ledger.broadcast(tx)
def get_balance(self, account=None, confirmations=0):
if account is None:
return self.manager.get_balance(confirmations=confirmations)
else:
return account.get_balance(confirmations=confirmations)
async def on_header(self, height):
if self.ledger.headers.height < height:
await self.ledger.on_header.where(
@ -175,8 +174,8 @@ class IntegrationTestCase(AsyncioTestCase):
)
return True
def on_transaction_id(self, txid):
return self.ledger.on_transaction.where(
def on_transaction_id(self, txid, ledger=None):
return (ledger or self.ledger).on_transaction.where(
lambda e: e.tx.id == txid
)