refactored tx sync

Lex Berezhny 2018-11-18 22:54:00 -05:00
parent e745b6c16e
commit 345d4f8ab1
13 changed files with 425 additions and 234 deletions


@@ -10,10 +10,10 @@ class BasicTransactionTests(IntegrationTestCase):
     async def test_sending_and_receiving(self):
         account1, account2 = self.account, self.wallet.generate_account(self.ledger)
-        await self.ledger.update_account(account2)
-        self.assertEqual(await self.get_balance(account1), 0)
-        self.assertEqual(await self.get_balance(account2), 0)
+        await self.ledger.subscribe_account(account2)
+        await self.assertBalance(account1, '0.0')
+        await self.assertBalance(account2, '0.0')

         sendtxids = []
         for i in range(5):
@@ -26,8 +26,8 @@ class BasicTransactionTests(IntegrationTestCase):
             self.on_transaction_id(txid) for txid in sendtxids
         ])
-        self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5)
-        self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0)
+        await self.assertBalance(account1, '5.5')
+        await self.assertBalance(account2, '0.0')

         address2 = await account2.receiving.get_or_create_usable_address()
         hash2 = self.ledger.address_to_hash160(address2)
@@ -41,8 +41,8 @@ class BasicTransactionTests(IntegrationTestCase):
         await self.blockchain.generate(1)
         await self.ledger.wait(tx)  # confirmed
-        self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
-        self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
+        await self.assertBalance(account1, '3.499802')
+        await self.assertBalance(account2, '2.0')

         utxos = await self.account.get_utxos()
         tx = await self.ledger.transaction_class.create(


@@ -44,7 +44,8 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
         self.assertEqual(len(addresses), 26)

     async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
-        await self.account.receiving.generate_keys(0, 200)
+        async with self.account.receiving.address_generator_lock:
+            await self.account.receiving._generate_keys(0, 200)
         records = await self.account.receiving.get_address_records()
         self.assertEqual(201, len(records))
@@ -53,9 +54,10 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
         self.assertIsInstance(account.receiving, HierarchicalDeterministic)

-        await account.receiving.generate_keys(4, 7)
-        await account.receiving.generate_keys(0, 3)
-        await account.receiving.generate_keys(8, 11)
+        async with account.receiving.address_generator_lock:
+            await account.receiving._generate_keys(4, 7)
+            await account.receiving._generate_keys(0, 3)
+            await account.receiving._generate_keys(8, 11)
         records = await account.receiving.get_address_records()
         self.assertEqual(
             [r['position'] for r in records],


@@ -128,7 +128,8 @@ class TestQueries(AsyncioTestCase):
         tx = ledger_class.transaction_class(height=height, is_verified=True) \
             .add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
             .add_outputs([self.txo(1, to_hash)])
-        await self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
+        await self.ledger.db.insert_transaction(tx)
+        await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
         return tx

     async def create_tx_from_txo(self, txo, to_account, height):
@@ -139,8 +140,9 @@ class TestQueries(AsyncioTestCase):
         tx = ledger_class.transaction_class(height=height, is_verified=True) \
             .add_inputs([self.txi(txo)]) \
             .add_outputs([self.txo(1, to_hash)])
-        await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
-        await self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
+        await self.ledger.db.insert_transaction(tx)
+        await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
+        await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
         return tx

     async def create_tx_to_nowhere(self, txo, height):
@@ -150,7 +152,8 @@ class TestQueries(AsyncioTestCase):
         tx = ledger_class.transaction_class(height=height, is_verified=True) \
             .add_inputs([self.txi(txo)]) \
             .add_outputs([self.txo(1, to_hash)])
-        await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
+        await self.ledger.db.insert_transaction(tx)
+        await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
         return tx

     def txo(self, amount, address):


@@ -16,6 +16,7 @@ class MockNetwork:
         self.address = None
         self.get_history_called = []
         self.get_transaction_called = []
+        self.is_connected = False

     async def get_history(self, address):
         self.get_history_called.append(address)
@@ -85,16 +86,21 @@ class TestSynchronization(LedgerTestCase):
             'abcd02': hexlify(get_transaction(get_output(2)).raw),
             'abcd03': hexlify(get_transaction(get_output(3)).raw),
         })
-        await self.ledger.update_history(address)
+        await self.ledger.update_history(address, '')
         self.assertEqual(self.ledger.network.get_history_called, [address])
         self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])

         address_details = await self.ledger.db.get_address(address=address)
-        self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
+        self.assertEqual(
+            address_details['history'],
+            '252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
+            'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
+            'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
+        )

         self.ledger.network.get_history_called = []
         self.ledger.network.get_transaction_called = []
-        await self.ledger.update_history(address)
+        await self.ledger.update_history(address, '')
         self.assertEqual(self.ledger.network.get_history_called, [address])
         self.assertEqual(self.ledger.network.get_transaction_called, [])
@@ -102,11 +108,17 @@ class TestSynchronization(LedgerTestCase):
         self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
         self.ledger.network.get_history_called = []
         self.ledger.network.get_transaction_called = []
-        await self.ledger.update_history(address)
+        await self.ledger.update_history(address, '')
         self.assertEqual(self.ledger.network.get_history_called, [address])
         self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
         address_details = await self.ledger.db.get_address(address=address)
-        self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
+        self.assertEqual(
+            address_details['history'],
+            '252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
+            'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
+            'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
+            '047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828:3:'
+        )


 class MocHeaderNetwork:


@@ -261,14 +261,14 @@ class TransactionIOBalancing(AsyncioTestCase):
             .add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \
             .add_outputs(utxos)

-        save_tx = 'insert'
+        await self.ledger.db.insert_transaction(self.funding_tx)
+
         for utxo in utxos:
             await self.ledger.db.save_transaction_io(
-                save_tx, self.funding_tx,
+                self.funding_tx,
                 self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
                 utxo.script.values['pubkey_hash'], ''
             )
-            save_tx = 'update'

         return utxos


@@ -1,6 +1,41 @@
 import unittest
 from torba.client.util import ArithUint256
+from torba.client.util import coins_to_satoshis as c2s, satoshis_to_coins as s2c
+
+
+class TestCoinValueParsing(unittest.TestCase):
+
+    def test_good_output(self):
+        self.assertEqual(s2c(1), "0.00000001")
+        self.assertEqual(s2c(10**7), "0.1")
+        self.assertEqual(s2c(2*10**8), "2.0")
+        self.assertEqual(s2c(2*10**17), "2000000000.0")
+
+    def test_good_input(self):
+        self.assertEqual(c2s("0.00000001"), 1)
+        self.assertEqual(c2s("0.1"), 10**7)
+        self.assertEqual(c2s("1.0"), 10**8)
+        self.assertEqual(c2s("2.00000000"), 2*10**8)
+        self.assertEqual(c2s("2000000000.0"), 2*10**17)
+
+    def test_bad_input(self):
+        with self.assertRaises(ValueError):
+            c2s("1")
+        with self.assertRaises(ValueError):
+            c2s("-1.0")
+        with self.assertRaises(ValueError):
+            c2s("10000000000.0")
+        with self.assertRaises(ValueError):
+            c2s("1.000000000")
+        with self.assertRaises(ValueError):
+            c2s("-0")
+        with self.assertRaises(ValueError):
+            c2s("1")
+        with self.assertRaises(ValueError):
+            c2s(".1")
+        with self.assertRaises(ValueError):
+            c2s("1e-7")


 class TestArithUint256(unittest.TestCase):


@@ -1,6 +1,7 @@
+import asyncio
 import random
 import typing
-from typing import Dict, Tuple, Type, Optional, Any
+from typing import Dict, Tuple, Type, Optional, Any, List

 from torba.client.mnemonic import Mnemonic
 from torba.client.bip32 import PrivateKey, PubKey, from_extended_key_string
@@ -15,12 +16,13 @@ class AddressManager:

     name: str

-    __slots__ = 'account', 'public_key', 'chain_number'
+    __slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'

     def __init__(self, account, public_key, chain_number):
         self.account = account
         self.public_key = public_key
         self.chain_number = chain_number
+        self.address_generator_lock = asyncio.Lock()

     @classmethod
     def from_dict(cls, account: 'BaseAccount', d: dict) \
@@ -60,11 +62,11 @@ class AddressManager:
     def get_address_records(self, only_usable: bool = False, **constraints):
         raise NotImplementedError

-    async def get_addresses(self, only_usable: bool = False, **constraints):
+    async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]:
         records = await self.get_address_records(only_usable=only_usable, **constraints)
         return [r['address'] for r in records]

-    async def get_or_create_usable_address(self):
+    async def get_or_create_usable_address(self) -> str:
         addresses = await self.get_addresses(only_usable=True, limit=10)
         if addresses:
             return random.choice(addresses)
@@ -87,8 +89,8 @@ class HierarchicalDeterministic(AddressManager):
     @classmethod
     def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]:
         return (
-            cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 2})),
-            cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 2}))
+            cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
+            cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
         )

     def to_dict_instance(self):
@@ -97,19 +99,7 @@ class HierarchicalDeterministic(AddressManager):
     def get_private_key(self, index: int) -> PrivateKey:
         return self.account.private_key.child(self.chain_number).child(index)

-    async def generate_keys(self, start: int, end: int):
-        keys_batch, final_keys = [], []
-        for index in range(start, end+1):
-            keys_batch.append((index, self.public_key.child(index)))
-            if index % 180 == 0 or index == end:
-                await self.account.ledger.db.add_keys(
-                    self.account, self.chain_number, keys_batch
-                )
-                final_keys.extend(keys_batch)
-                keys_batch.clear()
-        return [key[1].address for key in final_keys]
-
-    async def get_max_gap(self):
+    async def get_max_gap(self) -> int:
         addresses = await self._query_addresses(order_by="position ASC")
         max_gap = 0
         current_gap = 0
@@ -121,7 +111,8 @@ class HierarchicalDeterministic(AddressManager):
                 current_gap = 0
         return max_gap

-    async def ensure_address_gap(self):
+    async def ensure_address_gap(self) -> List[str]:
+        async with self.address_generator_lock:
             addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")

             existing_gap = 0
@@ -136,12 +127,27 @@ class HierarchicalDeterministic(AddressManager):

             start = addresses[0]['position']+1 if addresses else 0
             end = start + (self.gap - existing_gap)
-            new_keys = await self.generate_keys(start, end-1)
+            new_keys = await self._generate_keys(start, end-1)
+            await self.account.ledger.subscribe_addresses(self, new_keys)
             return new_keys

+    async def _generate_keys(self, start: int, end: int) -> List[str]:
+        if not self.address_generator_lock.locked():
+            raise RuntimeError('Should not be called outside of address_generator_lock.')
+        keys_batch, final_keys = [], []
+        for index in range(start, end+1):
+            keys_batch.append((index, self.public_key.child(index)))
+            if index % 180 == 0 or index == end:
+                await self.account.ledger.db.add_keys(
+                    self.account, self.chain_number, keys_batch
+                )
+                final_keys.extend(keys_batch)
+                keys_batch.clear()
+        return [key[1].address for key in final_keys]
+
     def get_address_records(self, only_usable: bool = False, **constraints):
         if only_usable:
-            constraints['used_times__lte'] = self.maximum_uses_per_address
+            constraints['used_times__lt'] = self.maximum_uses_per_address
         return self._query_addresses(order_by="used_times ASC, position ASC", **constraints)
@@ -164,16 +170,19 @@ class SingleKey(AddressManager):
     def get_private_key(self, index: int) -> PrivateKey:
         return self.account.private_key

-    async def get_max_gap(self):
+    async def get_max_gap(self) -> int:
         return 0

-    async def ensure_address_gap(self):
+    async def ensure_address_gap(self) -> List[str]:
+        async with self.address_generator_lock:
             exists = await self.get_address_records()
             if not exists:
                 await self.account.ledger.db.add_keys(
                     self.account, self.chain_number, [(0, self.public_key)]
                 )
-                return [self.public_key.address]
+                new_keys = [self.public_key.address]
+                await self.account.ledger.subscribe_addresses(self, new_keys)
+                return new_keys
             return []

     def get_address_records(self, only_usable: bool = False, **constraints):
@@ -211,7 +220,7 @@ class BaseAccount:
         generator_name = address_generator.get('name', HierarchicalDeterministic.name)
         self.address_generator = self.address_generators[generator_name]
         self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
-        self.address_managers = {self.receiving, self.change}
+        self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
         ledger.add_account(self)
         wallet.add_account(self)
@@ -320,12 +329,12 @@ class BaseAccount:
     async def ensure_address_gap(self):
         addresses = []
-        for address_manager in self.address_managers:
+        for address_manager in self.address_managers.values():
             new_addresses = await address_manager.ensure_address_gap()
             addresses.extend(new_addresses)
         return addresses

-    async def get_addresses(self, **constraints):
+    async def get_addresses(self, **constraints) -> List[str]:
         rows = await self.ledger.db.select_addresses('address', account=self, **constraints)
         return [r[0] for r in rows]
@@ -337,8 +346,7 @@ class BaseAccount:
     def get_private_key(self, chain: int, index: int) -> PrivateKey:
         assert not self.encrypted, "Cannot get private key on encrypted wallet account."
-        address_manager = {0: self.receiving, 1: self.change}[chain]
-        return address_manager.get_private_key(index)
+        return self.address_managers[chain].get_private_key(index)

     def get_balance(self, confirmations: int = 0, **constraints):
         if confirmations > 0:

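To make the new concurrency contract in the hunks above easier to follow, here is a minimal standalone sketch of the pattern: a public method takes address_generator_lock and a private _generate_keys helper refuses to run unless that lock is already held, so concurrent gap-filling calls are serialized instead of generating duplicate keys. The AddressChain class and demo() function below are purely illustrative, not part of torba.

    import asyncio

    class AddressChain:
        """Simplified sketch of the locking pattern used by AddressManager above."""

        def __init__(self):
            self.address_generator_lock = asyncio.Lock()
            self.generated = []

        async def ensure_address_gap(self):
            # public entry point: takes the lock so concurrent callers serialize
            async with self.address_generator_lock:
                return await self._generate_keys(0, 9)

        async def _generate_keys(self, start, end):
            # private helper: refuses to run unless the lock is already held
            if not self.address_generator_lock.locked():
                raise RuntimeError('Should not be called outside of address_generator_lock.')
            new = ['address-{}'.format(i) for i in range(start, end + 1)]
            self.generated.extend(new)
            return new

    async def demo():
        chain = AddressChain()
        # concurrent calls run one at a time instead of racing
        await asyncio.gather(*(chain.ensure_address_gap() for _ in range(3)))
        print(len(chain.generated))  # 30: every call ran, but serially

    asyncio.run(demo())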

@@ -169,13 +169,16 @@ class SQLiteMixin:
         await self.db.close()

     @staticmethod
-    def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
+    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]:
         columns, values = [], []
         for column, value in data.items():
             columns.append(column)
             values.append(value)
-        sql = "INSERT INTO {} ({}) VALUES ({})".format(
-            table, ', '.join(columns), ', '.join(['?'] * len(values))
+        or_ignore = ""
+        if ignore_duplicate:
+            or_ignore = " OR IGNORE"
+        sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
+            or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values))
         )
         return sql, values
@@ -273,60 +276,49 @@ class BaseDatabase(SQLiteMixin):
             'script': sqlite3.Binary(txo.script.source)
         }

-    def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
-
-        def _transaction(conn: sqlite3.Connection, save_tx, tx: BaseTransaction, address, txhash, history):
-
-            if save_tx == 'insert':
-                conn.execute(*self._insert_sql('tx', {
-                    'txid': tx.id,
-                    'raw': sqlite3.Binary(tx.raw),
-                    'height': tx.height,
-                    'position': tx.position,
-                    'is_verified': tx.is_verified
-                }))
-            elif save_tx == 'update':
-                conn.execute(*self._update_sql("tx", {
-                    'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
-                }, 'txid = ?', (tx.id,)))
-
-            existing_txos = set(map(itemgetter(0), conn.execute(*query(
-                "SELECT position FROM txo", txid=tx.id
-            ))))
-
-            for txo in tx.outputs:
-                if txo.position in existing_txos:
-                    continue
-                if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
-                    conn.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
-                elif txo.script.is_pay_script_hash:
-                    # TODO: implement script hash payments
-                    log.warning('Database.save_transaction_io: pay script hash is not implemented!')
-
-            # lookup the address associated with each TXI (via its TXO)
-            txoid_to_address = {r[0]: r[1] for r in conn.execute(*query(
-                "SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
-            ))}
-
-            # list of TXIs that have already been added
-            existing_txis = {r[0] for r in conn.execute(*query(
-                "SELECT txoid FROM txi", txid=tx.id
-            ))}
-
-            for txi in tx.inputs:
-                txoid = txi.txo_ref.id
-                new_txi = txoid not in existing_txis
-                address_matches = txoid_to_address.get(txoid) == address
-                if new_txi and address_matches:
-                    conn.execute(*self._insert_sql("txi", {
-                        'txid': tx.id,
-                        'txoid': txoid,
-                        'address': address,
-                    }))
+    async def insert_transaction(self, tx):
+        await self.db.execute(*self._insert_sql('tx', {
+            'txid': tx.id,
+            'raw': sqlite3.Binary(tx.raw),
+            'height': tx.height,
+            'position': tx.position,
+            'is_verified': tx.is_verified
+        }))
+
+    async def update_transaction(self, tx):
+        await self.db.execute(*self._update_sql("tx", {
+            'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
+        }, 'txid = ?', (tx.id,)))
+
+    def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
+
+        def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
+
+            for txo in tx.outputs:
+                if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
+                    conn.execute(*self._insert_sql(
+                        "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
+                    ))
+                elif txo.script.is_pay_script_hash:
+                    # TODO: implement script hash payments
+                    log.warning('Database.save_transaction_io: pay script hash is not implemented!')
+
+            for txi in tx.inputs:
+                if txi.txo_ref.txo is not None:
+                    txo = txi.txo_ref.txo
+                    if txo.get_address(self.ledger) == address:
+                        conn.execute(*self._insert_sql("txi", {
+                            'txid': tx.id,
+                            'txoid': txo.id,
+                            'address': address,
+                        }, ignore_duplicate=True))

             conn.execute(
                 "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
                 (history, history.count(':')//2, address)
             )

-        return self.db.run(_transaction, save_tx, tx, address, txhash, history)
+        return self.db.run(_transaction, tx, address, txhash, history)

     async def reserve_outputs(self, txos, is_reserved=True):
         txoids = [txo.id for txo in txos]

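As a quick illustration of the new ignore_duplicate flag, a standalone copy of the helper above shows the SQL it produces: SQLite's INSERT OR IGNORE turns re-saving an already-recorded txo or txi into a no-op rather than a constraint error. The table name and values in this sketch are made up for the example.

    from typing import Tuple, List

    def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]:
        # same shape as the helper above: one "?" placeholder per column
        columns, values = list(data.keys()), list(data.values())
        or_ignore = " OR IGNORE" if ignore_duplicate else ""
        sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
            or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values))
        )
        return sql, values

    print(_insert_sql('txi', {'txid': 'abc', 'txoid': 'def'}, ignore_duplicate=True)[0])
    # INSERT OR IGNORE INTO txi (txid, txoid) VALUES (?, ?)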

@@ -5,7 +5,7 @@ from functools import partial
 from binascii import hexlify, unhexlify
 from io import StringIO
-from typing import Dict, Type, Iterable
+from typing import Dict, Type, Iterable, List, Optional
 from operator import itemgetter
 from collections import namedtuple
@@ -48,6 +48,51 @@ class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
     pass


+class TransactionCacheItem:
+    __slots__ = '_tx', 'lock', 'has_tx'
+
+    def __init__(self,
+                 tx: Optional[basetransaction.BaseTransaction] = None,
+                 lock: Optional[asyncio.Lock] = None):
+        self.has_tx = asyncio.Event()
+        self.lock = lock or asyncio.Lock()
+        self.tx = tx
+
+    @property
+    def tx(self):
+        return self._tx
+
+    @tx.setter
+    def tx(self, tx):
+        self._tx = tx
+        if tx is not None:
+            self.has_tx.set()
+
+
+class SynchronizationMonitor:
+
+    def __init__(self):
+        self.done = asyncio.Event()
+        self.tasks = []
+
+    def add(self, coro):
+        len(self.tasks) < 1 and self.done.clear()
+        asyncio.ensure_future(self._monitor(coro))
+
+    def cancel(self):
+        for task in self.tasks:
+            task.cancel()
+
+    async def _monitor(self, coro):
+        task = asyncio.ensure_future(coro)
+        self.tasks.append(task)
+        try:
+            await task
+        finally:
+            self.tasks.remove(task)
+            len(self.tasks) < 1 and self.done.set()
+
+
 class BaseLedger(metaclass=LedgerRegistry):

     name: str
@@ -79,7 +124,8 @@ class BaseLedger(metaclass=LedgerRegistry):
         )

         self.network = self.config.get('network') or self.network_class(self)
         self.network.on_header.listen(self.receive_header)
-        self.network.on_status.listen(self.receive_status)
+        self.network.on_status.listen(self.process_status_update)

         self.accounts = []
         self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
@@ -101,7 +147,8 @@ class BaseLedger(metaclass=LedgerRegistry):
             )
         )

-        self._transaction_processing_locks = {}
+        self._tx_cache = {}
+        self.sync = SynchronizationMonitor()
         self._utxo_reservation_lock = asyncio.Lock()
         self._header_processing_lock = asyncio.Lock()
@@ -166,16 +213,14 @@ class BaseLedger(metaclass=LedgerRegistry):
     def release_outputs(self, txos):
         return self.db.release_outputs(txos)

-    async def get_local_status(self, address):
-        address_details = await self.db.get_address(address=address)
-        history = address_details['history']
-        return hexlify(sha256(history.encode())).decode() if history else None
-
-    async def get_local_history(self, address):
+    async def get_local_status_and_history(self, address):
         address_details = await self.db.get_address(address=address)
         history = address_details['history'] or ''
         parts = history.split(':')[:-1]
-        return list(zip(parts[0::2], map(int, parts[1::2])))
+        return (
+            hexlify(sha256(history.encode())).decode() if history else None,
+            list(zip(parts[0::2], map(int, parts[1::2])))
+        )

     @staticmethod
     def get_root_of_merkle_tree(branches, branch_positions, working_branch):
@@ -189,21 +234,13 @@ class BaseLedger(metaclass=LedgerRegistry):
             working_branch = double_sha256(combined)
         return hexlify(working_branch[::-1])

-    async def validate_transaction_and_set_position(self, tx, height, merkle):
-        if not height <= len(self.headers):
-            return False
-        merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
-        header = self.headers[height]
-        tx.position = merkle['pos']
-        tx.is_verified = merkle_root == header['merkle_root']
-
     async def start(self):
         if not os.path.exists(self.path):
             os.mkdir(self.path)
-        await asyncio.gather(
+        await asyncio.wait([
             self.db.open(),
             self.headers.open()
-        )
+        ])
         first_connection = self.network.on_connected.first
         asyncio.ensure_future(self.network.start())
         await first_connection
@@ -214,9 +251,15 @@ class BaseLedger(metaclass=LedgerRegistry):
         log.info("Subscribing and updating accounts.")
         await self.update_headers()
         await self.network.subscribe_headers()
-        await self.update_accounts()
+        import time
+        start = time.time()
+        await self.subscribe_accounts()
+        await self.sync.done.wait()
+        log.info(f'elapsed: {time.time()-start}')

     async def stop(self):
+        self.sync.cancel()
+        await self.sync.done.wait()
         await self.network.stop()
         await self.db.close()
         await self.headers.close()
@@ -299,89 +342,144 @@ class BaseLedger(metaclass=LedgerRegistry):
             height=header['height'], headers=header['hex'], subscription_update=True
         )

-    async def update_accounts(self):
-        return await asyncio.gather(*(
-            self.update_account(a) for a in self.accounts
-        ))
-
-    async def update_account(self, account: baseaccount.BaseAccount):
-        await account.ensure_address_gap()
-        addresses = await account.get_addresses()
-        while addresses:
-            await asyncio.gather(*(self.subscribe_history(a) for a in addresses))
-            addresses = await account.ensure_address_gap()
-
-    def _prefetch_history(self, remote_history, local_history):
-        proofs, network_txs = {}, {}
-        for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
-            if i < len(local_history) and local_history[i] == (hex_id, remote_height):
-                continue
-            if remote_height > 0:
-                proofs[hex_id] = asyncio.ensure_future(self.network.get_merkle(hex_id, remote_height))
-            network_txs[hex_id] = asyncio.ensure_future(self.network.get_transaction(hex_id))
-        return proofs, network_txs
-
-    async def update_history(self, address):
-        remote_history = await self.network.get_history(address)
-        local_history = await self.get_local_history(address)
-        proofs, network_txs = self._prefetch_history(remote_history, local_history)
-
-        synced_history = StringIO()
-        for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
-            synced_history.write('{}:{}:'.format(hex_id, remote_height))
-            if i < len(local_history) and local_history[i] == (hex_id, remote_height):
-                continue
-
-            lock = self._transaction_processing_locks.setdefault(hex_id, asyncio.Lock())
-            await lock.acquire()
-            try:
-                # see if we have a local copy of transaction, otherwise fetch it from server
-                tx = await self.db.get_transaction(txid=hex_id)
-                save_tx = None
-                if tx is None:
-                    _raw = await network_txs[hex_id]
-                    tx = self.transaction_class(unhexlify(_raw))
-                    save_tx = 'insert'
-                tx.height = remote_height
-                if remote_height > 0 and (not tx.is_verified or tx.position == -1):
-                    await self.validate_transaction_and_set_position(tx, remote_height, await proofs[hex_id])
-                    if save_tx is None:
-                        save_tx = 'update'
-                await self.db.save_transaction_io(
-                    save_tx, tx, address, self.address_to_hash160(address), synced_history.getvalue()
-                )
-                log.debug(
-                    "%s: sync'ed tx %s for address: %s, height: %s, verified: %s",
-                    self.get_id(), hex_id, address, tx.height, tx.is_verified
-                )
-                self._on_transaction_controller.add(TransactionEvent(address, tx))
-            finally:
-                lock.release()
-                if not lock.locked() and hex_id in self._transaction_processing_locks:
-                    del self._transaction_processing_locks[hex_id]
-
-    async def subscribe_history(self, address):
-        remote_status = await self.network.subscribe_address(address)
-        local_status = await self.get_local_status(address)
-        if local_status != remote_status:
-            await self.update_history(address)
-
-    async def receive_status(self, response):
-        address, remote_status = response
-        local_status = await self.get_local_status(address)
-        if local_status != remote_status:
-            await self.update_history(address)
+    async def subscribe_accounts(self):
+        if self.network.is_connected and self.accounts:
+            await asyncio.wait([
+                self.subscribe_account(a) for a in self.accounts
+            ])
+
+    async def subscribe_account(self, account: baseaccount.BaseAccount):
+        for address_manager in account.address_managers.values():
+            await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
+        await account.ensure_address_gap()
+
+    async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
+        if self.network.is_connected and addresses:
+            await asyncio.wait([
+                self.subscribe_address(address_manager, address) for address in addresses
+            ])
+
+    async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
+        remote_status = await self.network.subscribe_address(address)
+        self.sync.add(self.update_history(address, remote_status, address_manager))
+
+    def process_status_update(self, update):
+        address, remote_status = update
+        self.sync.add(self.update_history(address, remote_status))
+
+    async def update_history(self, address, remote_status,
+                             address_manager: baseaccount.AddressManager = None):
+        local_status, local_history = await self.get_local_status_and_history(address)
+        if local_status == remote_status:
+            return
+
+        remote_history = await self.network.get_history(address)
+
+        cache_tasks = []
+        synced_history = StringIO()
+        for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
+            if i < len(local_history) and local_history[i] == (txid, remote_height):
+                synced_history.write(f'{txid}:{remote_height}:')
+            else:
+                cache_tasks.append(asyncio.ensure_future(
+                    self.cache_transaction(txid, remote_height)
+                ))
+
+        for task in cache_tasks:
+            tx = await task
+
+            check_db_for_txos = []
+            for txi in tx.inputs:
+                if txi.txo_ref.txo is not None:
+                    continue
+                cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
+                if cache_item is not None:
+                    if cache_item.tx is None:
+                        await cache_item.has_tx.wait()
+                    txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
+                else:
+                    check_db_for_txos.append(txi.txo_ref.tx_ref.id)
+
+            referenced_txos = {
+                txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos)
+            }
+
+            for txi in tx.inputs:
+                if txi.txo_ref.txo is not None:
+                    continue
+                referenced_txo = referenced_txos.get(txi.txo_ref.tx_ref.id)
+                if referenced_txos:
+                    txi.txo_ref = referenced_txo.ref
+
+            synced_history.write(f'{tx.id}:{tx.height}:')
+
+            await self.db.save_transaction_io(
+                tx, address, self.address_to_hash160(address), synced_history.getvalue()
+            )
+
+            self._on_transaction_controller.add(TransactionEvent(address, tx))
+
+        if address_manager is None:
+            address_manager = await self.get_address_manager_for_address(address)
+
+        await address_manager.ensure_address_gap()
+
+    async def cache_transaction(self, txid, remote_height):
+        cache_item = self._tx_cache.get(txid)
+        if cache_item is None:
+            cache_item = self._tx_cache[txid] = TransactionCacheItem()
+        elif cache_item.tx is not None and \
+                cache_item.tx.height >= remote_height and \
+                (cache_item.tx.is_verified or remote_height < 1):
+            return cache_item.tx  # cached tx is already up-to-date
+
+        await cache_item.lock.acquire()
+        try:
+            tx = cache_item.tx
+            if tx is None:
+                # check local db
+                tx = cache_item.tx = await self.db.get_transaction(txid=txid)
+            if tx is None:
+                # fetch from network
+                _raw = await self.network.get_transaction(txid)
+                if _raw:
+                    tx = self.transaction_class(unhexlify(_raw))
+                    await self.maybe_verify_transaction(tx, remote_height)
+                    await self.db.insert_transaction(tx)
+                    cache_item.tx = tx  # make sure it's saved before caching it
+                    return tx
+            if tx is None:
+                raise ValueError(f'Transaction {txid} was not in database and not on network.')
+            if 0 < remote_height and not tx.is_verified:
+                # tx from cache / db is not up-to-date
+                await self.maybe_verify_transaction(tx, remote_height)
+                await self.db.update_transaction(tx)
+            return tx
+        finally:
+            cache_item.lock.release()
+
+    async def maybe_verify_transaction(self, tx, remote_height):
+        tx.height = remote_height
+        if 0 < remote_height <= len(self.headers):
+            merkle = await self.network.get_merkle(tx.id, remote_height)
+            merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
+            header = self.headers[remote_height]
+            tx.position = merkle['pos']
+            tx.is_verified = merkle_root == header['merkle_root']
+
+    async def get_address_manager_for_address(self, address) -> baseaccount.AddressManager:
+        details = await self.db.get_address(address=address)
+        for account in self.accounts:
+            if account.id == details['account']:
+                return account.address_managers[details['chain']]

     def broadcast(self, tx):
         return self.network.broadcast(hexlify(tx.raw).decode())

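The new SynchronizationMonitor is what lets start() and stop() above wait for in-flight history updates instead of racing them. A minimal usage sketch, assuming the class is importable from torba.client.baseledger as added in this commit; fake_update_history() and main() are illustrative stand-ins for the real update_history coroutines the ledger hands to sync.add().

    import asyncio

    from torba.client.baseledger import SynchronizationMonitor  # added by this commit

    async def fake_update_history(address):
        # stand-in for the real network + database work in update_history()
        await asyncio.sleep(0.01)

    async def main():
        sync = SynchronizationMonitor()
        # like subscribe_address()/process_status_update(): fire-and-track each update
        for address in ('addr1', 'addr2', 'addr3'):
            sync.add(fake_update_history(address))
        # like start(): block until every pending update task has finished
        await sync.done.wait()
        # like stop(): cancel whatever might still be in flight
        sync.cancel()

    asyncio.run(main())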

@@ -1,5 +1,25 @@
+import re
 from binascii import unhexlify, hexlify
 from typing import TypeVar, Sequence, Optional
+from torba.client.constants import COIN
+
+
+def coins_to_satoshis(coins):
+    if not isinstance(coins, str):
+        raise ValueError("{coins} must be a string")
+    result = re.search(r'^(\d{1,10})\.(\d{1,8})$', coins)
+    if result is not None:
+        whole, fractional = result.groups()
+        return int(whole+fractional.ljust(8, "0"))
+    raise ValueError("'{lbc}' is not a valid coin decimal")
+
+
+def satoshis_to_coins(satoshis):
+    coins = '{:.8f}'.format(satoshis / COIN).rstrip('0')
+    if coins.endswith('.'):
+        return coins+'0'
+    else:
+        return coins


 T = TypeVar('T')

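For readers skimming the diff, the new helpers behave like this; the values are taken from the unit tests added above, and running the snippet assumes torba is importable on the current path.

    from torba.client.util import coins_to_satoshis, satoshis_to_coins

    # round-trips shown by the new unit tests above
    assert coins_to_satoshis('2.0') == 200000000
    assert coins_to_satoshis('0.00000001') == 1
    assert satoshis_to_coins(350000000) == '3.5'
    assert satoshis_to_coins(1) == '0.00000001'

    # strings without a decimal point (and non-strings) are rejected
    try:
        coins_to_satoshis('1')
    except ValueError:
        pass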

@@ -112,7 +112,7 @@ class Conductor:

 class WalletNode:

     def __init__(self, manager_class: Type[BaseWalletManager], ledger_class: Type[BaseLedger],
-                 verbose: bool = False) -> None:
+                 verbose: bool = False, api_port: int = 5279) -> None:
         self.manager_class = manager_class
         self.ledger_class = ledger_class
         self.verbose = verbose
@@ -121,8 +121,9 @@ class WalletNode:
         self.wallet: Optional[Wallet] = None
         self.account: Optional[BaseAccount] = None
         self.data_path: Optional[str] = None
+        self.api_port = api_port

-    async def start(self):
+    async def start(self, seed=None):
         self.data_path = tempfile.mkdtemp()
         wallet_file_name = os.path.join(self.data_path, 'my_wallet.json')
         with open(wallet_file_name, 'w') as wallet_file:
@@ -130,6 +131,7 @@ class WalletNode:
         self.manager = self.manager_class.from_config({
             'ledgers': {
                 self.ledger_class.get_id(): {
+                    'api_port': self.api_port,
                     'default_servers': [('localhost', 1984)],
                     'data_path': self.data_path
                 }
@@ -138,7 +140,12 @@ class WalletNode:
         })
         self.ledger = self.manager.ledgers[self.ledger_class]
         self.wallet = self.manager.default_wallet
-        self.wallet.generate_account(self.ledger)
+        if seed is None:
+            self.wallet.generate_account(self.ledger)
+        else:
+            self.ledger.account_class.from_dict(
+                self.ledger, self.wallet, {'seed': seed}
+            )
         self.account = self.wallet.default_account
         await self.manager.start()


@@ -29,17 +29,17 @@ class BroadcastSubscription:

     def _add(self, data):
         if self.can_fire and self._on_data is not None:
-            maybe_coroutine = self._on_data(data)
-            if asyncio.iscoroutine(maybe_coroutine):
-                asyncio.ensure_future(maybe_coroutine)
+            return self._on_data(data)

     def _add_error(self, exception):
         if self.can_fire and self._on_error is not None:
-            self._on_error(exception)
+            return self._on_error(exception)

     def _close(self):
-        if self.can_fire and self._on_done is not None:
-            self._on_done()
-        self.is_closed = True
+        try:
+            if self.can_fire and self._on_done is not None:
+                return self._on_done()
+        finally:
+            self.is_closed = True
@@ -62,13 +62,28 @@ class StreamController:
             next_sub = next_sub._next
             yield subscription

-    def add(self, event):
-        for subscription in self._iterate_subscriptions:
-            subscription._add(event)
+    def _notify_and_ensure_future(self, notify):
+        tasks = []
+        for subscription in self._iterate_subscriptions:
+            maybe_coroutine = notify(subscription)
+            if asyncio.iscoroutine(maybe_coroutine):
+                tasks.append(maybe_coroutine)
+        if tasks:
+            return asyncio.ensure_future(asyncio.wait(tasks))
+        else:
+            f = asyncio.get_event_loop().create_future()
+            f.set_result(None)
+            return f
+
+    def add(self, event):
+        return self._notify_and_ensure_future(
+            lambda subscription: subscription._add(event)
+        )

     def add_error(self, exception):
-        for subscription in self._iterate_subscriptions:
-            subscription._add_error(exception)
+        return self._notify_and_ensure_future(
+            lambda subscription: subscription._add_error(exception)
+        )

     def close(self):
         for subscription in self._iterate_subscriptions:

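The practical effect of the StreamController change is that add() and add_error() now hand back a future which resolves once every coroutine listener has run, so event producers can await delivery instead of fire-and-forget. A rough usage sketch; the module path and the stream/listen() API are assumptions based on how the rest of the codebase in this commit uses these controllers.

    import asyncio

    from torba.client.stream import StreamController  # module path assumed, adjust to your checkout

    async def main():
        controller = StreamController()

        async def on_transaction(event):
            await asyncio.sleep(0.01)              # stand-in for async work done by a listener
            print('processed', event)

        controller.stream.listen(on_transaction)   # stream/listen API assumed from existing torba usage

        # add() now returns a future that completes when all coroutine listeners finish
        await controller.add('tx-arrived')

    asyncio.run(main())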

@@ -9,6 +9,7 @@ from torba.client.baseledger import BaseLedger
 from torba.client.baseaccount import BaseAccount
 from torba.client.basemanager import BaseWalletManager
 from torba.client.wallet import Wallet
+from torba.client.util import satoshis_to_coins

 try:
@@ -159,15 +160,13 @@ class IntegrationTestCase(AsyncioTestCase):
     async def asyncTearDown(self):
         await self.conductor.stop()

+    async def assertBalance(self, account, expected_balance: str):
+        balance = await account.get_balance()
+        self.assertEqual(satoshis_to_coins(balance), expected_balance)
+
     def broadcast(self, tx):
         return self.ledger.broadcast(tx)

-    def get_balance(self, account=None, confirmations=0):
-        if account is None:
-            return self.manager.get_balance(confirmations=confirmations)
-        else:
-            return account.get_balance(confirmations=confirmations)
-
     async def on_header(self, height):
         if self.ledger.headers.height < height:
             await self.ledger.on_header.where(
@@ -175,8 +174,8 @@ class IntegrationTestCase(AsyncioTestCase):
             )
         return True

-    def on_transaction_id(self, txid):
-        return self.ledger.on_transaction.where(
+    def on_transaction_id(self, txid, ledger=None):
+        return (ledger or self.ledger).on_transaction.where(
             lambda e: e.tx.id == txid
         )