refactored tx sync
This commit is contained in:
parent
e745b6c16e
commit
345d4f8ab1
13 changed files with 425 additions and 234 deletions
|
@ -10,10 +10,10 @@ class BasicTransactionTests(IntegrationTestCase):
|
|||
|
||||
async def test_sending_and_receiving(self):
|
||||
account1, account2 = self.account, self.wallet.generate_account(self.ledger)
|
||||
await self.ledger.update_account(account2)
|
||||
await self.ledger.subscribe_account(account2)
|
||||
|
||||
self.assertEqual(await self.get_balance(account1), 0)
|
||||
self.assertEqual(await self.get_balance(account2), 0)
|
||||
await self.assertBalance(account1, '0.0')
|
||||
await self.assertBalance(account2, '0.0')
|
||||
|
||||
sendtxids = []
|
||||
for i in range(5):
|
||||
|
@ -26,8 +26,8 @@ class BasicTransactionTests(IntegrationTestCase):
|
|||
self.on_transaction_id(txid) for txid in sendtxids
|
||||
])
|
||||
|
||||
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 5.5)
|
||||
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 0)
|
||||
await self.assertBalance(account1, '5.5')
|
||||
await self.assertBalance(account2, '0.0')
|
||||
|
||||
address2 = await account2.receiving.get_or_create_usable_address()
|
||||
hash2 = self.ledger.address_to_hash160(address2)
|
||||
|
@ -41,8 +41,8 @@ class BasicTransactionTests(IntegrationTestCase):
|
|||
await self.blockchain.generate(1)
|
||||
await self.ledger.wait(tx) # confirmed
|
||||
|
||||
self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
|
||||
self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)
|
||||
await self.assertBalance(account1, '3.499802')
|
||||
await self.assertBalance(account2, '2.0')
|
||||
|
||||
utxos = await self.account.get_utxos()
|
||||
tx = await self.ledger.transaction_class.create(
|
||||
|
|
|
@ -44,7 +44,8 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
|
|||
self.assertEqual(len(addresses), 26)
|
||||
|
||||
async def test_generate_keys_over_batch_threshold_saves_it_properly(self):
|
||||
await self.account.receiving.generate_keys(0, 200)
|
||||
async with self.account.receiving.address_generator_lock:
|
||||
await self.account.receiving._generate_keys(0, 200)
|
||||
records = await self.account.receiving.get_address_records()
|
||||
self.assertEqual(201, len(records))
|
||||
|
||||
|
@ -53,9 +54,10 @@ class TestHierarchicalDeterministicAccount(AsyncioTestCase):
|
|||
|
||||
self.assertIsInstance(account.receiving, HierarchicalDeterministic)
|
||||
|
||||
await account.receiving.generate_keys(4, 7)
|
||||
await account.receiving.generate_keys(0, 3)
|
||||
await account.receiving.generate_keys(8, 11)
|
||||
async with account.receiving.address_generator_lock:
|
||||
await account.receiving._generate_keys(4, 7)
|
||||
await account.receiving._generate_keys(0, 3)
|
||||
await account.receiving._generate_keys(8, 11)
|
||||
records = await account.receiving.get_address_records()
|
||||
self.assertEqual(
|
||||
[r['position'] for r in records],
|
||||
|
|
|
@ -128,7 +128,8 @@ class TestQueries(AsyncioTestCase):
|
|||
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||
.add_inputs([self.txi(self.txo(1, NULL_HASH))]) \
|
||||
.add_outputs([self.txo(1, to_hash)])
|
||||
await self.ledger.db.save_transaction_io('insert', tx, to_address, to_hash, '')
|
||||
await self.ledger.db.insert_transaction(tx)
|
||||
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
|
||||
return tx
|
||||
|
||||
async def create_tx_from_txo(self, txo, to_account, height):
|
||||
|
@ -139,8 +140,9 @@ class TestQueries(AsyncioTestCase):
|
|||
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||
.add_inputs([self.txi(txo)]) \
|
||||
.add_outputs([self.txo(1, to_hash)])
|
||||
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
||||
await self.ledger.db.save_transaction_io('', tx, to_address, to_hash, '')
|
||||
await self.ledger.db.insert_transaction(tx)
|
||||
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
|
||||
await self.ledger.db.save_transaction_io(tx, to_address, to_hash, '')
|
||||
return tx
|
||||
|
||||
async def create_tx_to_nowhere(self, txo, height):
|
||||
|
@ -150,7 +152,8 @@ class TestQueries(AsyncioTestCase):
|
|||
tx = ledger_class.transaction_class(height=height, is_verified=True) \
|
||||
.add_inputs([self.txi(txo)]) \
|
||||
.add_outputs([self.txo(1, to_hash)])
|
||||
await self.ledger.db.save_transaction_io('insert', tx, from_address, from_hash, '')
|
||||
await self.ledger.db.insert_transaction(tx)
|
||||
await self.ledger.db.save_transaction_io(tx, from_address, from_hash, '')
|
||||
return tx
|
||||
|
||||
def txo(self, amount, address):
|
||||
|
|
|
@ -16,6 +16,7 @@ class MockNetwork:
|
|||
self.address = None
|
||||
self.get_history_called = []
|
||||
self.get_transaction_called = []
|
||||
self.is_connected = False
|
||||
|
||||
async def get_history(self, address):
|
||||
self.get_history_called.append(address)
|
||||
|
@ -85,16 +86,21 @@ class TestSynchronization(LedgerTestCase):
|
|||
'abcd02': hexlify(get_transaction(get_output(2)).raw),
|
||||
'abcd03': hexlify(get_transaction(get_output(3)).raw),
|
||||
})
|
||||
await self.ledger.update_history(address)
|
||||
await self.ledger.update_history(address, '')
|
||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
|
||||
|
||||
address_details = await self.ledger.db.get_address(address=address)
|
||||
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
|
||||
self.assertEqual(
|
||||
address_details['history'],
|
||||
'252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
|
||||
'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
|
||||
'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
|
||||
)
|
||||
|
||||
self.ledger.network.get_history_called = []
|
||||
self.ledger.network.get_transaction_called = []
|
||||
await self.ledger.update_history(address)
|
||||
await self.ledger.update_history(address, '')
|
||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, [])
|
||||
|
||||
|
@ -102,11 +108,17 @@ class TestSynchronization(LedgerTestCase):
|
|||
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
|
||||
self.ledger.network.get_history_called = []
|
||||
self.ledger.network.get_transaction_called = []
|
||||
await self.ledger.update_history(address)
|
||||
await self.ledger.update_history(address, '')
|
||||
self.assertEqual(self.ledger.network.get_history_called, [address])
|
||||
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
|
||||
address_details = await self.ledger.db.get_address(address=address)
|
||||
self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
|
||||
self.assertEqual(
|
||||
address_details['history'],
|
||||
'252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:0:'
|
||||
'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:1:'
|
||||
'a2ae3d1db3c727e7d696122cab39ee20a7f81856dab7019056dd539f38c548a0:2:'
|
||||
'047cf1d53ef68f0fd586d46f90c09ff8e57a4180f67e7f4b8dd0135c3741e828:3:'
|
||||
)
|
||||
|
||||
|
||||
class MocHeaderNetwork:
|
||||
|
|
|
@ -261,14 +261,14 @@ class TransactionIOBalancing(AsyncioTestCase):
|
|||
.add_inputs([self.txi(self.txo(sum(amounts)+0.1))]) \
|
||||
.add_outputs(utxos)
|
||||
|
||||
save_tx = 'insert'
|
||||
await self.ledger.db.insert_transaction(self.funding_tx)
|
||||
|
||||
for utxo in utxos:
|
||||
await self.ledger.db.save_transaction_io(
|
||||
save_tx, self.funding_tx,
|
||||
self.funding_tx,
|
||||
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
|
||||
utxo.script.values['pubkey_hash'], ''
|
||||
)
|
||||
save_tx = 'update'
|
||||
|
||||
return utxos
|
||||
|
||||
|
|
|
@ -1,6 +1,41 @@
|
|||
import unittest
|
||||
|
||||
from torba.client.util import ArithUint256
|
||||
from torba.client.util import coins_to_satoshis as c2s, satoshis_to_coins as s2c
|
||||
|
||||
|
||||
class TestCoinValueParsing(unittest.TestCase):
|
||||
|
||||
def test_good_output(self):
|
||||
self.assertEqual(s2c(1), "0.00000001")
|
||||
self.assertEqual(s2c(10**7), "0.1")
|
||||
self.assertEqual(s2c(2*10**8), "2.0")
|
||||
self.assertEqual(s2c(2*10**17), "2000000000.0")
|
||||
|
||||
def test_good_input(self):
|
||||
self.assertEqual(c2s("0.00000001"), 1)
|
||||
self.assertEqual(c2s("0.1"), 10**7)
|
||||
self.assertEqual(c2s("1.0"), 10**8)
|
||||
self.assertEqual(c2s("2.00000000"), 2*10**8)
|
||||
self.assertEqual(c2s("2000000000.0"), 2*10**17)
|
||||
|
||||
def test_bad_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
c2s("1")
|
||||
with self.assertRaises(ValueError):
|
||||
c2s("-1.0")
|
||||
with self.assertRaises(ValueError):
|
||||
c2s("10000000000.0")
|
||||
with self.assertRaises(ValueError):
|
||||
c2s("1.000000000")
|
||||
with self.assertRaises(ValueError):
|
||||
c2s("-0")
|
||||
with self.assertRaises(ValueError):
|
||||
c2s("1")
|
||||
with self.assertRaises(ValueError):
|
||||
c2s(".1")
|
||||
with self.assertRaises(ValueError):
|
||||
c2s("1e-7")
|
||||
|
||||
|
||||
class TestArithUint256(unittest.TestCase):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import random
|
||||
import typing
|
||||
from typing import Dict, Tuple, Type, Optional, Any
|
||||
from typing import Dict, Tuple, Type, Optional, Any, List
|
||||
|
||||
from torba.client.mnemonic import Mnemonic
|
||||
from torba.client.bip32 import PrivateKey, PubKey, from_extended_key_string
|
||||
|
@ -15,12 +16,13 @@ class AddressManager:
|
|||
|
||||
name: str
|
||||
|
||||
__slots__ = 'account', 'public_key', 'chain_number'
|
||||
__slots__ = 'account', 'public_key', 'chain_number', 'address_generator_lock'
|
||||
|
||||
def __init__(self, account, public_key, chain_number):
|
||||
self.account = account
|
||||
self.public_key = public_key
|
||||
self.chain_number = chain_number
|
||||
self.address_generator_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, account: 'BaseAccount', d: dict) \
|
||||
|
@ -60,11 +62,11 @@ class AddressManager:
|
|||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_addresses(self, only_usable: bool = False, **constraints):
|
||||
async def get_addresses(self, only_usable: bool = False, **constraints) -> List[str]:
|
||||
records = await self.get_address_records(only_usable=only_usable, **constraints)
|
||||
return [r['address'] for r in records]
|
||||
|
||||
async def get_or_create_usable_address(self):
|
||||
async def get_or_create_usable_address(self) -> str:
|
||||
addresses = await self.get_addresses(only_usable=True, limit=10)
|
||||
if addresses:
|
||||
return random.choice(addresses)
|
||||
|
@ -87,8 +89,8 @@ class HierarchicalDeterministic(AddressManager):
|
|||
@classmethod
|
||||
def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]:
|
||||
return (
|
||||
cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 2})),
|
||||
cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 2}))
|
||||
cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 1})),
|
||||
cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 1}))
|
||||
)
|
||||
|
||||
def to_dict_instance(self):
|
||||
|
@ -97,19 +99,7 @@ class HierarchicalDeterministic(AddressManager):
|
|||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
return self.account.private_key.child(self.chain_number).child(index)
|
||||
|
||||
async def generate_keys(self, start: int, end: int):
|
||||
keys_batch, final_keys = [], []
|
||||
for index in range(start, end+1):
|
||||
keys_batch.append((index, self.public_key.child(index)))
|
||||
if index % 180 == 0 or index == end:
|
||||
await self.account.ledger.db.add_keys(
|
||||
self.account, self.chain_number, keys_batch
|
||||
)
|
||||
final_keys.extend(keys_batch)
|
||||
keys_batch.clear()
|
||||
return [key[1].address for key in final_keys]
|
||||
|
||||
async def get_max_gap(self):
|
||||
async def get_max_gap(self) -> int:
|
||||
addresses = await self._query_addresses(order_by="position ASC")
|
||||
max_gap = 0
|
||||
current_gap = 0
|
||||
|
@ -121,7 +111,8 @@ class HierarchicalDeterministic(AddressManager):
|
|||
current_gap = 0
|
||||
return max_gap
|
||||
|
||||
async def ensure_address_gap(self):
|
||||
async def ensure_address_gap(self) -> List[str]:
|
||||
async with self.address_generator_lock:
|
||||
addresses = await self._query_addresses(limit=self.gap, order_by="position DESC")
|
||||
|
||||
existing_gap = 0
|
||||
|
@ -136,12 +127,27 @@ class HierarchicalDeterministic(AddressManager):
|
|||
|
||||
start = addresses[0]['position']+1 if addresses else 0
|
||||
end = start + (self.gap - existing_gap)
|
||||
new_keys = await self.generate_keys(start, end-1)
|
||||
new_keys = await self._generate_keys(start, end-1)
|
||||
await self.account.ledger.subscribe_addresses(self, new_keys)
|
||||
return new_keys
|
||||
|
||||
async def _generate_keys(self, start: int, end: int) -> List[str]:
|
||||
if not self.address_generator_lock.locked():
|
||||
raise RuntimeError('Should not be called outside of address_generator_lock.')
|
||||
keys_batch, final_keys = [], []
|
||||
for index in range(start, end+1):
|
||||
keys_batch.append((index, self.public_key.child(index)))
|
||||
if index % 180 == 0 or index == end:
|
||||
await self.account.ledger.db.add_keys(
|
||||
self.account, self.chain_number, keys_batch
|
||||
)
|
||||
final_keys.extend(keys_batch)
|
||||
keys_batch.clear()
|
||||
return [key[1].address for key in final_keys]
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
if only_usable:
|
||||
constraints['used_times__lte'] = self.maximum_uses_per_address
|
||||
constraints['used_times__lt'] = self.maximum_uses_per_address
|
||||
return self._query_addresses(order_by="used_times ASC, position ASC", **constraints)
|
||||
|
||||
|
||||
|
@ -164,16 +170,19 @@ class SingleKey(AddressManager):
|
|||
def get_private_key(self, index: int) -> PrivateKey:
|
||||
return self.account.private_key
|
||||
|
||||
async def get_max_gap(self):
|
||||
async def get_max_gap(self) -> int:
|
||||
return 0
|
||||
|
||||
async def ensure_address_gap(self):
|
||||
async def ensure_address_gap(self) -> List[str]:
|
||||
async with self.address_generator_lock:
|
||||
exists = await self.get_address_records()
|
||||
if not exists:
|
||||
await self.account.ledger.db.add_keys(
|
||||
self.account, self.chain_number, [(0, self.public_key)]
|
||||
)
|
||||
return [self.public_key.address]
|
||||
new_keys = [self.public_key.address]
|
||||
await self.account.ledger.subscribe_addresses(self, new_keys)
|
||||
return new_keys
|
||||
return []
|
||||
|
||||
def get_address_records(self, only_usable: bool = False, **constraints):
|
||||
|
@ -211,7 +220,7 @@ class BaseAccount:
|
|||
generator_name = address_generator.get('name', HierarchicalDeterministic.name)
|
||||
self.address_generator = self.address_generators[generator_name]
|
||||
self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
|
||||
self.address_managers = {self.receiving, self.change}
|
||||
self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
|
||||
ledger.add_account(self)
|
||||
wallet.add_account(self)
|
||||
|
||||
|
@ -320,12 +329,12 @@ class BaseAccount:
|
|||
|
||||
async def ensure_address_gap(self):
|
||||
addresses = []
|
||||
for address_manager in self.address_managers:
|
||||
for address_manager in self.address_managers.values():
|
||||
new_addresses = await address_manager.ensure_address_gap()
|
||||
addresses.extend(new_addresses)
|
||||
return addresses
|
||||
|
||||
async def get_addresses(self, **constraints):
|
||||
async def get_addresses(self, **constraints) -> List[str]:
|
||||
rows = await self.ledger.db.select_addresses('address', account=self, **constraints)
|
||||
return [r[0] for r in rows]
|
||||
|
||||
|
@ -337,8 +346,7 @@ class BaseAccount:
|
|||
|
||||
def get_private_key(self, chain: int, index: int) -> PrivateKey:
|
||||
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
|
||||
address_manager = {0: self.receiving, 1: self.change}[chain]
|
||||
return address_manager.get_private_key(index)
|
||||
return self.address_managers[chain].get_private_key(index)
|
||||
|
||||
def get_balance(self, confirmations: int = 0, **constraints):
|
||||
if confirmations > 0:
|
||||
|
|
|
@ -169,13 +169,16 @@ class SQLiteMixin:
|
|||
await self.db.close()
|
||||
|
||||
@staticmethod
|
||||
def _insert_sql(table: str, data: dict) -> Tuple[str, List]:
|
||||
def _insert_sql(table: str, data: dict, ignore_duplicate: bool = False) -> Tuple[str, List]:
|
||||
columns, values = [], []
|
||||
for column, value in data.items():
|
||||
columns.append(column)
|
||||
values.append(value)
|
||||
sql = "INSERT INTO {} ({}) VALUES ({})".format(
|
||||
table, ', '.join(columns), ', '.join(['?'] * len(values))
|
||||
or_ignore = ""
|
||||
if ignore_duplicate:
|
||||
or_ignore = " OR IGNORE"
|
||||
sql = "INSERT{} INTO {} ({}) VALUES ({})".format(
|
||||
or_ignore, table, ', '.join(columns), ', '.join(['?'] * len(values))
|
||||
)
|
||||
return sql, values
|
||||
|
||||
|
@ -273,60 +276,49 @@ class BaseDatabase(SQLiteMixin):
|
|||
'script': sqlite3.Binary(txo.script.source)
|
||||
}
|
||||
|
||||
def save_transaction_io(self, save_tx, tx: BaseTransaction, address, txhash, history):
|
||||
|
||||
def _transaction(conn: sqlite3.Connection, save_tx, tx: BaseTransaction, address, txhash, history):
|
||||
if save_tx == 'insert':
|
||||
conn.execute(*self._insert_sql('tx', {
|
||||
async def insert_transaction(self, tx):
|
||||
await self.db.execute(*self._insert_sql('tx', {
|
||||
'txid': tx.id,
|
||||
'raw': sqlite3.Binary(tx.raw),
|
||||
'height': tx.height,
|
||||
'position': tx.position,
|
||||
'is_verified': tx.is_verified
|
||||
}))
|
||||
elif save_tx == 'update':
|
||||
conn.execute(*self._update_sql("tx", {
|
||||
|
||||
async def update_transaction(self, tx):
|
||||
await self.db.execute(*self._update_sql("tx", {
|
||||
'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
|
||||
}, 'txid = ?', (tx.id,)))
|
||||
|
||||
existing_txos = set(map(itemgetter(0), conn.execute(*query(
|
||||
"SELECT position FROM txo", txid=tx.id
|
||||
))))
|
||||
def save_transaction_io(self, tx: BaseTransaction, address, txhash, history):
|
||||
|
||||
def _transaction(conn: sqlite3.Connection, tx: BaseTransaction, address, txhash, history):
|
||||
|
||||
for txo in tx.outputs:
|
||||
if txo.position in existing_txos:
|
||||
continue
|
||||
if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
|
||||
conn.execute(*self._insert_sql("txo", self.txo_to_row(tx, address, txo)))
|
||||
conn.execute(*self._insert_sql(
|
||||
"txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
|
||||
))
|
||||
elif txo.script.is_pay_script_hash:
|
||||
# TODO: implement script hash payments
|
||||
log.warning('Database.save_transaction_io: pay script hash is not implemented!')
|
||||
|
||||
# lookup the address associated with each TXI (via its TXO)
|
||||
txoid_to_address = {r[0]: r[1] for r in conn.execute(*query(
|
||||
"SELECT txoid, address FROM txo", txoid__in=[txi.txo_ref.id for txi in tx.inputs]
|
||||
))}
|
||||
|
||||
# list of TXIs that have already been added
|
||||
existing_txis = {r[0] for r in conn.execute(*query(
|
||||
"SELECT txoid FROM txi", txid=tx.id
|
||||
))}
|
||||
|
||||
for txi in tx.inputs:
|
||||
txoid = txi.txo_ref.id
|
||||
new_txi = txoid not in existing_txis
|
||||
address_matches = txoid_to_address.get(txoid) == address
|
||||
if new_txi and address_matches:
|
||||
if txi.txo_ref.txo is not None:
|
||||
txo = txi.txo_ref.txo
|
||||
if txo.get_address(self.ledger) == address:
|
||||
conn.execute(*self._insert_sql("txi", {
|
||||
'txid': tx.id,
|
||||
'txoid': txoid,
|
||||
'txoid': txo.id,
|
||||
'address': address,
|
||||
}))
|
||||
}, ignore_duplicate=True))
|
||||
|
||||
conn.execute(
|
||||
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
|
||||
(history, history.count(':')//2, address)
|
||||
)
|
||||
return self.db.run(_transaction, save_tx, tx, address, txhash, history)
|
||||
|
||||
return self.db.run(_transaction, tx, address, txhash, history)
|
||||
|
||||
async def reserve_outputs(self, txos, is_reserved=True):
|
||||
txoids = [txo.id for txo in txos]
|
||||
|
|
|
@ -5,7 +5,7 @@ from functools import partial
|
|||
from binascii import hexlify, unhexlify
|
||||
from io import StringIO
|
||||
|
||||
from typing import Dict, Type, Iterable
|
||||
from typing import Dict, Type, Iterable, List, Optional
|
||||
from operator import itemgetter
|
||||
from collections import namedtuple
|
||||
|
||||
|
@ -48,6 +48,51 @@ class BlockHeightEvent(namedtuple('BlockHeightEvent', ('height', 'change'))):
|
|||
pass
|
||||
|
||||
|
||||
class TransactionCacheItem:
|
||||
__slots__ = '_tx', 'lock', 'has_tx'
|
||||
|
||||
def __init__(self,
|
||||
tx: Optional[basetransaction.BaseTransaction] = None,
|
||||
lock: Optional[asyncio.Lock] = None):
|
||||
self.has_tx = asyncio.Event()
|
||||
self.lock = lock or asyncio.Lock()
|
||||
self.tx = tx
|
||||
|
||||
@property
|
||||
def tx(self):
|
||||
return self._tx
|
||||
|
||||
@tx.setter
|
||||
def tx(self, tx):
|
||||
self._tx = tx
|
||||
if tx is not None:
|
||||
self.has_tx.set()
|
||||
|
||||
|
||||
class SynchronizationMonitor:
|
||||
|
||||
def __init__(self):
|
||||
self.done = asyncio.Event()
|
||||
self.tasks = []
|
||||
|
||||
def add(self, coro):
|
||||
len(self.tasks) < 1 and self.done.clear()
|
||||
asyncio.ensure_future(self._monitor(coro))
|
||||
|
||||
def cancel(self):
|
||||
for task in self.tasks:
|
||||
task.cancel()
|
||||
|
||||
async def _monitor(self, coro):
|
||||
task = asyncio.ensure_future(coro)
|
||||
self.tasks.append(task)
|
||||
try:
|
||||
await task
|
||||
finally:
|
||||
self.tasks.remove(task)
|
||||
len(self.tasks) < 1 and self.done.set()
|
||||
|
||||
|
||||
class BaseLedger(metaclass=LedgerRegistry):
|
||||
|
||||
name: str
|
||||
|
@ -79,7 +124,8 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
)
|
||||
self.network = self.config.get('network') or self.network_class(self)
|
||||
self.network.on_header.listen(self.receive_header)
|
||||
self.network.on_status.listen(self.receive_status)
|
||||
self.network.on_status.listen(self.process_status_update)
|
||||
|
||||
self.accounts = []
|
||||
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
|
||||
|
||||
|
@ -101,7 +147,8 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
)
|
||||
)
|
||||
|
||||
self._transaction_processing_locks = {}
|
||||
self._tx_cache = {}
|
||||
self.sync = SynchronizationMonitor()
|
||||
self._utxo_reservation_lock = asyncio.Lock()
|
||||
self._header_processing_lock = asyncio.Lock()
|
||||
|
||||
|
@ -166,16 +213,14 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
def release_outputs(self, txos):
|
||||
return self.db.release_outputs(txos)
|
||||
|
||||
async def get_local_status(self, address):
|
||||
address_details = await self.db.get_address(address=address)
|
||||
history = address_details['history']
|
||||
return hexlify(sha256(history.encode())).decode() if history else None
|
||||
|
||||
async def get_local_history(self, address):
|
||||
async def get_local_status_and_history(self, address):
|
||||
address_details = await self.db.get_address(address=address)
|
||||
history = address_details['history'] or ''
|
||||
parts = history.split(':')[:-1]
|
||||
return list(zip(parts[0::2], map(int, parts[1::2])))
|
||||
return (
|
||||
hexlify(sha256(history.encode())).decode() if history else None,
|
||||
list(zip(parts[0::2], map(int, parts[1::2])))
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
|
||||
|
@ -189,21 +234,13 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
working_branch = double_sha256(combined)
|
||||
return hexlify(working_branch[::-1])
|
||||
|
||||
async def validate_transaction_and_set_position(self, tx, height, merkle):
|
||||
if not height <= len(self.headers):
|
||||
return False
|
||||
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
|
||||
header = self.headers[height]
|
||||
tx.position = merkle['pos']
|
||||
tx.is_verified = merkle_root == header['merkle_root']
|
||||
|
||||
async def start(self):
|
||||
if not os.path.exists(self.path):
|
||||
os.mkdir(self.path)
|
||||
await asyncio.gather(
|
||||
await asyncio.wait([
|
||||
self.db.open(),
|
||||
self.headers.open()
|
||||
)
|
||||
])
|
||||
first_connection = self.network.on_connected.first
|
||||
asyncio.ensure_future(self.network.start())
|
||||
await first_connection
|
||||
|
@ -214,9 +251,15 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
log.info("Subscribing and updating accounts.")
|
||||
await self.update_headers()
|
||||
await self.network.subscribe_headers()
|
||||
await self.update_accounts()
|
||||
import time
|
||||
start = time.time()
|
||||
await self.subscribe_accounts()
|
||||
await self.sync.done.wait()
|
||||
log.info(f'elapsed: {time.time()-start}')
|
||||
|
||||
async def stop(self):
|
||||
self.sync.cancel()
|
||||
await self.sync.done.wait()
|
||||
await self.network.stop()
|
||||
await self.db.close()
|
||||
await self.headers.close()
|
||||
|
@ -299,89 +342,144 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
height=header['height'], headers=header['hex'], subscription_update=True
|
||||
)
|
||||
|
||||
async def update_accounts(self):
|
||||
return await asyncio.gather(*(
|
||||
self.update_account(a) for a in self.accounts
|
||||
async def subscribe_accounts(self):
|
||||
if self.network.is_connected and self.accounts:
|
||||
await asyncio.wait([
|
||||
self.subscribe_account(a) for a in self.accounts
|
||||
])
|
||||
|
||||
async def subscribe_account(self, account: baseaccount.BaseAccount):
|
||||
for address_manager in account.address_managers.values():
|
||||
await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
|
||||
await account.ensure_address_gap()
|
||||
|
||||
async def subscribe_addresses(self, address_manager: baseaccount.AddressManager, addresses: List[str]):
|
||||
if self.network.is_connected and addresses:
|
||||
await asyncio.wait([
|
||||
self.subscribe_address(address_manager, address) for address in addresses
|
||||
])
|
||||
|
||||
async def subscribe_address(self, address_manager: baseaccount.AddressManager, address: str):
|
||||
remote_status = await self.network.subscribe_address(address)
|
||||
self.sync.add(self.update_history(address, remote_status, address_manager))
|
||||
|
||||
def process_status_update(self, update):
|
||||
address, remote_status = update
|
||||
self.sync.add(self.update_history(address, remote_status))
|
||||
|
||||
async def update_history(self, address, remote_status,
|
||||
address_manager: baseaccount.AddressManager = None):
|
||||
local_status, local_history = await self.get_local_status_and_history(address)
|
||||
|
||||
if local_status == remote_status:
|
||||
return
|
||||
|
||||
remote_history = await self.network.get_history(address)
|
||||
|
||||
cache_tasks = []
|
||||
synced_history = StringIO()
|
||||
for i, (txid, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
|
||||
if i < len(local_history) and local_history[i] == (txid, remote_height):
|
||||
synced_history.write(f'{txid}:{remote_height}:')
|
||||
else:
|
||||
cache_tasks.append(asyncio.ensure_future(
|
||||
self.cache_transaction(txid, remote_height)
|
||||
))
|
||||
|
||||
async def update_account(self, account: baseaccount.BaseAccount):
|
||||
await account.ensure_address_gap()
|
||||
addresses = await account.get_addresses()
|
||||
while addresses:
|
||||
await asyncio.gather(*(self.subscribe_history(a) for a in addresses))
|
||||
addresses = await account.ensure_address_gap()
|
||||
for task in cache_tasks:
|
||||
tx = await task
|
||||
|
||||
def _prefetch_history(self, remote_history, local_history):
|
||||
proofs, network_txs = {}, {}
|
||||
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
|
||||
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
|
||||
check_db_for_txos = []
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
continue
|
||||
if remote_height > 0:
|
||||
proofs[hex_id] = asyncio.ensure_future(self.network.get_merkle(hex_id, remote_height))
|
||||
network_txs[hex_id] = asyncio.ensure_future(self.network.get_transaction(hex_id))
|
||||
return proofs, network_txs
|
||||
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
|
||||
if cache_item is not None:
|
||||
if cache_item.tx is None:
|
||||
await cache_item.has_tx.wait()
|
||||
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
|
||||
else:
|
||||
check_db_for_txos.append(txi.txo_ref.tx_ref.id)
|
||||
|
||||
async def update_history(self, address):
|
||||
remote_history = await self.network.get_history(address)
|
||||
local_history = await self.get_local_history(address)
|
||||
proofs, network_txs = self._prefetch_history(remote_history, local_history)
|
||||
referenced_txos = {
|
||||
txo.id: txo for txo in await self.db.get_txos(txoid__in=check_db_for_txos)
|
||||
}
|
||||
|
||||
synced_history = StringIO()
|
||||
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
|
||||
|
||||
synced_history.write('{}:{}:'.format(hex_id, remote_height))
|
||||
|
||||
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
|
||||
for txi in tx.inputs:
|
||||
if txi.txo_ref.txo is not None:
|
||||
continue
|
||||
referenced_txo = referenced_txos.get(txi.txo_ref.tx_ref.id)
|
||||
if referenced_txos:
|
||||
txi.txo_ref = referenced_txo.ref
|
||||
|
||||
lock = self._transaction_processing_locks.setdefault(hex_id, asyncio.Lock())
|
||||
|
||||
await lock.acquire()
|
||||
|
||||
try:
|
||||
|
||||
# see if we have a local copy of transaction, otherwise fetch it from server
|
||||
tx = await self.db.get_transaction(txid=hex_id)
|
||||
save_tx = None
|
||||
if tx is None:
|
||||
_raw = await network_txs[hex_id]
|
||||
tx = self.transaction_class(unhexlify(_raw))
|
||||
save_tx = 'insert'
|
||||
|
||||
tx.height = remote_height
|
||||
|
||||
if remote_height > 0 and (not tx.is_verified or tx.position == -1):
|
||||
await self.validate_transaction_and_set_position(tx, remote_height, await proofs[hex_id])
|
||||
if save_tx is None:
|
||||
save_tx = 'update'
|
||||
synced_history.write(f'{tx.id}:{tx.height}:')
|
||||
|
||||
await self.db.save_transaction_io(
|
||||
save_tx, tx, address, self.address_to_hash160(address), synced_history.getvalue()
|
||||
)
|
||||
|
||||
log.debug(
|
||||
"%s: sync'ed tx %s for address: %s, height: %s, verified: %s",
|
||||
self.get_id(), hex_id, address, tx.height, tx.is_verified
|
||||
tx, address, self.address_to_hash160(address), synced_history.getvalue()
|
||||
)
|
||||
|
||||
self._on_transaction_controller.add(TransactionEvent(address, tx))
|
||||
|
||||
if address_manager is None:
|
||||
address_manager = await self.get_address_manager_for_address(address)
|
||||
|
||||
await address_manager.ensure_address_gap()
|
||||
|
||||
async def cache_transaction(self, txid, remote_height):
    """Return the transaction for *txid*, consulting (in order) the in-memory
    cache, the local database, and finally the network.

    The fetched/refreshed tx is verified against *remote_height* when needed
    and persisted before being cached, so the cache never holds an unsaved tx.

    Raises:
        ValueError: if the transaction is neither in the database nor
            retrievable from the network.
    """
    cache_item = self._tx_cache.get(txid)
    if cache_item is None:
        cache_item = self._tx_cache[txid] = TransactionCacheItem()
    elif cache_item.tx is not None and \
            cache_item.tx.height >= remote_height and \
            (cache_item.tx.is_verified or remote_height < 1):
        return cache_item.tx  # cached tx is already up-to-date

    # serialize all work on the same txid behind the cache item's lock
    await cache_item.lock.acquire()

    try:
        tx = cache_item.tx

        if tx is None:
            # check local db
            tx = cache_item.tx = await self.db.get_transaction(txid=txid)

        if tx is None:
            # fetch from network
            _raw = await self.network.get_transaction(txid)
            if _raw:
                tx = self.transaction_class(unhexlify(_raw))
                await self.maybe_verify_transaction(tx, remote_height)
                await self.db.insert_transaction(tx)
                cache_item.tx = tx  # make sure it's saved before caching it
                return tx

        if tx is None:
            raise ValueError(f'Transaction {txid} was not in database and not on network.')

        if 0 < remote_height and not tx.is_verified:
            # tx from cache / db is not up-to-date
            await self.maybe_verify_transaction(tx, remote_height)
            await self.db.update_transaction(tx)

        return tx

    finally:
        # always release so other callers waiting on this txid don't deadlock
        cache_item.lock.release()
|
||||
|
||||
async def subscribe_history(self, address):
    """Subscribe to server-side status updates for *address*, then sync
    its history immediately if our local view is already stale."""
    status_from_server = await self.network.subscribe_address(address)
    status_local = await self.get_local_status(address)
    if status_local != status_from_server:
        await self.update_history(address)
|
||||
async def maybe_verify_transaction(self, tx, remote_height):
    """Record *remote_height* on *tx*; if the header for that height is
    already synced, fetch the merkle proof and mark the tx's position
    and verification status."""
    tx.height = remote_height
    if not (0 < remote_height <= len(self.headers)):
        return
    proof = await self.network.get_merkle(tx.id, remote_height)
    position = proof['pos']
    computed_root = self.get_root_of_merkle_tree(proof['merkle'], position, tx.hash)
    expected_root = self.headers[remote_height]['merkle_root']
    tx.position = position
    tx.is_verified = computed_root == expected_root
|
||||
|
||||
async def receive_status(self, response):
    """Handle a pushed address-status notification from the server;
    re-sync the address history when the remote status differs from ours."""
    address, status_from_server = response
    status_local = await self.get_local_status(address)
    if status_local != status_from_server:
        await self.update_history(address)
|
||||
async def get_address_manager_for_address(self, address) -> baseaccount.AddressManager:
    """Find the account that owns *address* and return the address manager
    for its chain. Returns None when no loaded account matches —
    callers should be prepared for that."""
    details = await self.db.get_address(address=address)
    owner = next((a for a in self.accounts if a.id == details['account']), None)
    if owner is not None:
        return owner.address_managers[details['chain']]
|
||||
|
||||
def broadcast(self, tx):
    """Serialize *tx* to hex and submit it to the network."""
    raw_hex = hexlify(tx.raw).decode()
    return self.network.broadcast(raw_hex)
|
||||
|
|
|
@ -1,5 +1,25 @@
|
|||
import re
|
||||
from binascii import unhexlify, hexlify
|
||||
from typing import TypeVar, Sequence, Optional
|
||||
from torba.client.constants import COIN
|
||||
|
||||
|
||||
def coins_to_satoshis(coins):
    """Convert a decimal coin string (e.g. '5.5') to an integer satoshi amount.

    *coins* must be a string with both a whole part (1-10 digits) and a
    fractional part (1-8 digits), e.g. '0.1' or '2.0'.

    Raises:
        ValueError: if *coins* is not a string or not a valid coin decimal.
    """
    if not isinstance(coins, str):
        # bug fix: original lacked the f-prefix, so the message printed
        # the literal text "{coins}" instead of the offending value
        raise ValueError(f"{coins} must be a string")
    result = re.search(r'^(\d{1,10})\.(\d{1,8})$', coins)
    if result is not None:
        whole, fractional = result.groups()
        # pad the fractional part to 8 digits: '5' -> '50000000' satoshis
        return int(whole + fractional.ljust(8, "0"))
    # bug fix: original interpolated a nonexistent name 'lbc' and lacked
    # the f-prefix; report the actual rejected input
    raise ValueError(f"'{coins}' is not a valid coin decimal")
|
||||
|
||||
|
||||
def satoshis_to_coins(satoshis):
    """Format an integer satoshi amount as a decimal coin string,
    trimming trailing zeros but keeping at least one fractional digit
    (e.g. 550000000 -> '5.5', 200000000 -> '2.0')."""
    trimmed = '{:.8f}'.format(satoshis / COIN).rstrip('0')
    return trimmed + '0' if trimmed.endswith('.') else trimmed
|
||||
|
||||
|
||||
T = TypeVar('T')  # NOTE(review): module-level generic type variable; its users are outside this view — confirm usage before renaming
|
||||
|
|
|
@ -112,7 +112,7 @@ class Conductor:
|
|||
class WalletNode:
|
||||
|
||||
def __init__(self, manager_class: Type[BaseWalletManager], ledger_class: Type[BaseLedger],
|
||||
verbose: bool = False) -> None:
|
||||
verbose: bool = False, api_port: int = 5279) -> None:
|
||||
self.manager_class = manager_class
|
||||
self.ledger_class = ledger_class
|
||||
self.verbose = verbose
|
||||
|
@ -121,8 +121,9 @@ class WalletNode:
|
|||
self.wallet: Optional[Wallet] = None
|
||||
self.account: Optional[BaseAccount] = None
|
||||
self.data_path: Optional[str] = None
|
||||
self.api_port = api_port
|
||||
|
||||
async def start(self):
|
||||
async def start(self, seed=None):
|
||||
self.data_path = tempfile.mkdtemp()
|
||||
wallet_file_name = os.path.join(self.data_path, 'my_wallet.json')
|
||||
with open(wallet_file_name, 'w') as wallet_file:
|
||||
|
@ -130,6 +131,7 @@ class WalletNode:
|
|||
self.manager = self.manager_class.from_config({
|
||||
'ledgers': {
|
||||
self.ledger_class.get_id(): {
|
||||
'api_port': self.api_port,
|
||||
'default_servers': [('localhost', 1984)],
|
||||
'data_path': self.data_path
|
||||
}
|
||||
|
@ -138,7 +140,12 @@ class WalletNode:
|
|||
})
|
||||
self.ledger = self.manager.ledgers[self.ledger_class]
|
||||
self.wallet = self.manager.default_wallet
|
||||
if seed is None:
|
||||
self.wallet.generate_account(self.ledger)
|
||||
else:
|
||||
self.ledger.account_class.from_dict(
|
||||
self.ledger, self.wallet, {'seed': seed}
|
||||
)
|
||||
self.account = self.wallet.default_account
|
||||
await self.manager.start()
|
||||
|
||||
|
|
|
@ -29,17 +29,17 @@ class BroadcastSubscription:
|
|||
|
||||
def _add(self, data):
|
||||
if self.can_fire and self._on_data is not None:
|
||||
maybe_coroutine = self._on_data(data)
|
||||
if asyncio.iscoroutine(maybe_coroutine):
|
||||
asyncio.ensure_future(maybe_coroutine)
|
||||
return self._on_data(data)
|
||||
|
||||
def _add_error(self, exception):
|
||||
if self.can_fire and self._on_error is not None:
|
||||
self._on_error(exception)
|
||||
return self._on_error(exception)
|
||||
|
||||
def _close(self):
|
||||
try:
|
||||
if self.can_fire and self._on_done is not None:
|
||||
self._on_done()
|
||||
return self._on_done()
|
||||
finally:
|
||||
self.is_closed = True
|
||||
|
||||
|
||||
|
@ -62,13 +62,28 @@ class StreamController:
|
|||
next_sub = next_sub._next
|
||||
yield subscription
|
||||
|
||||
def _notify_and_ensure_future(self, notify):
    """Apply *notify* to every live subscription, gathering any coroutines
    returned by async subscribers into a single scheduled task.

    Always returns a future: the wrapping task when at least one subscriber
    returned a coroutine, otherwise an already-completed future, so callers
    can uniformly await the result of a broadcast.
    """
    # fix: the diff residue interleaved the old per-subscription dispatch
    # with this new implementation; only the new one is kept
    pending = []
    for subscription in self._iterate_subscriptions:
        result = notify(subscription)
        if asyncio.iscoroutine(result):
            pending.append(result)
    if pending:
        return asyncio.ensure_future(asyncio.wait(pending))
    # nothing async to wait on: hand back a future that is already done
    done = asyncio.get_event_loop().create_future()
    done.set_result(None)
    return done

def add(self, event):
    """Broadcast *event* to all subscribers; returns a future that
    completes once any asynchronous handlers have finished."""
    return self._notify_and_ensure_future(
        lambda subscription: subscription._add(event)
    )

def add_error(self, exception):
    """Broadcast *exception* to all subscribers' error handlers;
    returns a future as described in _notify_and_ensure_future."""
    return self._notify_and_ensure_future(
        lambda subscription: subscription._add_error(exception)
    )
|
||||
|
||||
def close(self):
|
||||
for subscription in self._iterate_subscriptions:
|
||||
|
|
|
@ -9,6 +9,7 @@ from torba.client.baseledger import BaseLedger
|
|||
from torba.client.baseaccount import BaseAccount
|
||||
from torba.client.basemanager import BaseWalletManager
|
||||
from torba.client.wallet import Wallet
|
||||
from torba.client.util import satoshis_to_coins
|
||||
|
||||
|
||||
try:
|
||||
|
@ -159,15 +160,13 @@ class IntegrationTestCase(AsyncioTestCase):
|
|||
async def asyncTearDown(self):
    """Shut down the test conductor (and the stack it manages) after each test."""
    await self.conductor.stop()
|
||||
|
||||
async def assertBalance(self, account, expected_balance: str):
    """Assert that *account*'s balance, rendered as a coin-decimal string,
    equals *expected_balance* (e.g. '5.5')."""
    actual = satoshis_to_coins(await account.get_balance())
    self.assertEqual(actual, expected_balance)
|
||||
|
||||
def broadcast(self, tx):
    """Relay *tx* through this test case's ledger."""
    return self.ledger.broadcast(tx)
|
||||
|
||||
def get_balance(self, account=None, confirmations=0):
    """Balance of *account*, or of the whole wallet manager when *account*
    is None; only counts outputs with at least *confirmations*."""
    source = self.manager if account is None else account
    return source.get_balance(confirmations=confirmations)
|
||||
|
||||
async def on_header(self, height):
|
||||
if self.ledger.headers.height < height:
|
||||
await self.ledger.on_header.where(
|
||||
|
@ -175,8 +174,8 @@ class IntegrationTestCase(AsyncioTestCase):
|
|||
)
|
||||
return True
|
||||
|
||||
def on_transaction_id(self, txid, ledger=None):
    """Return a future resolving when a transaction event for *txid* is seen.

    *ledger* defaults to this test case's own ledger; pass another ledger
    to wait on its event stream instead.
    """
    # fix: the span contained both the old single-argument signature and
    # this replacement; the replacement (with the optional ledger) is kept
    return (ledger or self.ledger).on_transaction.where(
        lambda e: e.tx.id == txid
    )
|
||||
|
||||
|
|
Loading…
Reference in a new issue