diff --git a/tests/integration/test_transactions.py b/tests/integration/test_transactions.py index 883b39d45..939d87a19 100644 --- a/tests/integration/test_transactions.py +++ b/tests/integration/test_transactions.py @@ -17,30 +17,35 @@ class BasicTransactionTests(IntegrationTestCase): address = await account1.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop()) sendtxid = await self.blockchain.send_to_address(address.decode(), 5.5) - await self.on_transaction(sendtxid) #mempool + await self.on_transaction_id(sendtxid) #mempool await self.blockchain.generate(1) - await self.on_transaction(sendtxid) #confirmed + await self.on_transaction_id(sendtxid) #confirmed self.assertEqual(await self.get_balance(account1), int(5.5*COIN)) self.assertEqual(await self.get_balance(account2), 0) address = await account2.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop()) + hash1 = self.ledger.address_to_hash160(address) tx = await self.ledger.transaction_class.pay( - [self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, self.ledger.address_to_hash160(address))], + [self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash1)], [account1], account1 ).asFuture(asyncio.get_event_loop()) await self.broadcast(tx) - await self.on_transaction(tx.hex_id.decode()) #mempool + await self.on_transaction(tx) #mempool tx2 = await self.ledger.transaction_class.pay( - [self.ledger.transaction_class.output_class.pay_pubkey_hash(1*COIN, self.ledger.address_to_hash160(address))], + [self.ledger.transaction_class.output_class.pay_pubkey_hash(1*COIN, hash1)], [account1], account1 ).asFuture(asyncio.get_event_loop()) await self.broadcast(tx2) - await self.on_transaction(tx2.hex_id.decode()) #mempool + await self.on_transaction(tx2) #mempool await self.blockchain.generate(1) - await self.on_transaction(tx.hex_id.decode()) #confirmed + await asyncio.wait([ + self.on_header(202), + self.on_transaction(tx), + self.on_transaction(tx2), + ]) #self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5) #self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0) diff --git a/torba/basedatabase.py b/torba/basedatabase.py index b58845d74..6923c6e5e 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -113,7 +113,7 @@ class BaseDatabase(SQLiteMixin): txhash blob primary key, raw blob not null, height integer not null, - is_verified boolean not null default false + is_verified boolean not null default 0 ); """ @@ -136,7 +136,8 @@ class BaseDatabase(SQLiteMixin): address blob references pubkey_address, position integer not null, amount integer not null, - script blob not null + script blob not null, + is_reserved boolean not null default 0 ); """ @@ -168,7 +169,7 @@ class BaseDatabase(SQLiteMixin): elif save_tx == 'update': t.execute(*self._update_sql("tx", { 'height': height, 'is_verified': is_verified - }, 'WHERE txhash = ?', (sqlite3.Binary(tx.hash),) + }, 'txhash = ?', (sqlite3.Binary(tx.hash),) )) existing_txos = list(map(itemgetter(0), t.execute( @@ -209,18 +210,28 @@ class BaseDatabase(SQLiteMixin): t.execute( "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", - (sqlite3.Binary(history), history.count(b':')//2, sqlite3.Binary(address)) + (history, history.count(':')//2, sqlite3.Binary(address)) ) return self.db.runInteraction(_steps) + def reserve_spent_outputs(self, txoids, is_reserved=True): + return self.db.runOperation( + "UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format( + ', '.join(['?']*len(txoids)) + ), [is_reserved]+txoids + ) + + def release_reserved_outputs(self, txoids): + return self.reserve_spent_outputs(txoids, is_reserved=False) + @defer.inlineCallbacks def get_transaction(self, txhash): result = yield self.db.runQuery( "SELECT raw, height, is_verified FROM tx WHERE txhash = ?", (sqlite3.Binary(txhash),) ) if result: - defer.returnValue(*result[0]) + defer.returnValue(result[0]) else: defer.returnValue((None, None, False)) @@ -244,9 +255,9 @@ class BaseDatabase(SQLiteMixin): def get_utxos(self, account, output_class): utxos = yield self.db.runQuery( """ - SELECT amount, script, txhash, txo.position + SELECT amount, script, txhash, txo.position, txoid FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address - WHERE account=:account AND txoid NOT IN (SELECT txoid FROM txi) + WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi) """, {'account': sqlite3.Binary(account.public_key.address)} ) @@ -255,7 +266,8 @@ class BaseDatabase(SQLiteMixin): values[0], output_class.script_class(values[1]), values[2], - index=values[3] + index=values[3], + txoid=values[4] ) for values in utxos ]) diff --git a/torba/baseheader.py b/torba/baseheader.py index 506cab22c..db69b24ec 100644 --- a/torba/baseheader.py +++ b/torba/baseheader.py @@ -1,14 +1,16 @@ import os import struct +import logging from binascii import unhexlify from twisted.internet import threads, defer -import torba from torba.stream import StreamController, execute_serially from torba.util import int_to_hex, rev_hex, hash_encode from torba.hash import double_sha256, pow_hash +log = logging.getLogger(__name__) + class BaseHeaders: @@ -32,7 +34,7 @@ class BaseHeaders: @property def height(self): - return len(self) - 1 + return len(self) def sync_read_length(self): return os.path.getsize(self.path) // self.header_size @@ -76,7 +78,9 @@ class BaseHeaders: _old_size = self._size self._size = self.sync_read_length() change = self._size - _old_size - #log.info('saved {} header blocks'.format(change)) + log.info('{}: added {} header blocks, final height {}'.format( + self.ledger.get_id(), change, self.height) + ) self._on_change_controller.add(change) def _iterate_headers(self, height, headers): diff --git a/torba/baseledger.py b/torba/baseledger.py index 36d7f6a6e..745fa6fcb 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -1,9 +1,11 @@ import os import six import hashlib +import logging from binascii import hexlify, unhexlify from typing import Dict, Type, Iterable, Generator from operator import itemgetter +from collections import namedtuple from twisted.internet import defer @@ -15,6 +17,8 @@ from torba import basetransaction from torba.stream import StreamController, execute_serially from torba.hash import hash160, double_sha256, Base58 +log = logging.getLogger(__name__) + class LedgerRegistry(type): ledgers = {} # type: Dict[str, Type[BaseLedger]] @@ -33,6 +37,10 @@ class LedgerRegistry(type): return mcs.ledgers[ledger_id] +class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'height', 'is_verified'))): + pass + + class BaseLedger(six.with_metaclass(LedgerRegistry)): name = None @@ -67,6 +75,14 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): self._on_transaction_controller = StreamController() self.on_transaction = self._on_transaction_controller.stream + self.on_transaction.listen( + lambda e: log.info('({}) on_transaction: address={}, height={}, is_verified={}, tx.id={}'.format( + self.get_id(), e.address, e.height, e.is_verified, e.tx.hex_id) + ) + ) + + self._on_header_controller = StreamController() + self.on_header = self._on_header_controller.stream self._transaction_processing_locks = {} @@ -133,19 +149,15 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): @defer.inlineCallbacks def get_local_status(self, address): address_details = yield self.db.get_address(address) - history = address_details['history'] or b'' - if six.PY2: - history = str(history) - hash = hashlib.sha256(history).digest() + history = address_details['history'] or '' + hash = hashlib.sha256(history.encode()).digest() defer.returnValue(hexlify(hash)) @defer.inlineCallbacks def get_local_history(self, address): address_details = yield self.db.get_address(address) - history = address_details['history'] or b'' - if six.PY2: - history = str(history) - parts = history.split(b':')[:-1] + history = address_details['history'] or '' + parts = history.split(':')[:-1] defer.returnValue(list(zip(parts[0::2], map(int, parts[1::2])))) @staticmethod @@ -162,7 +174,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): @defer.inlineCallbacks def is_valid_transaction(self, tx, height): - len(self.headers) < height or defer.returnValue(False) + height <= len(self.headers) or defer.returnValue(False) merkle = yield self.network.get_merkle(tx.hex_id.decode(), height) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) header = self.headers[height] @@ -193,6 +205,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): if headers['count'] <= 0: break yield self.headers.connect(height_sought, unhexlify(headers['hex'])) + self._on_header_controller.add(height_sought) @defer.inlineCallbacks def process_header(self, response): @@ -202,6 +215,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): if header['height'] == len(self.headers): # New header from network directly connects after the last local header. yield self.headers.connect(len(self.headers), unhexlify(header['hex'])) + self._on_header_controller.add(len(self.headers)) elif header['height'] > len(self.headers): # New header is several heights ahead of local, do download instead. yield self.update_headers() @@ -239,45 +253,45 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): local_history = yield self.get_local_history(address) synced_history = [] - for i, (hash, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): + for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): - synced_history.append((hash, remote_height)) + synced_history.append((hex_id, remote_height)) - if i < len(local_history) and local_history[i] == (hash, remote_height): + if i < len(local_history) and local_history[i] == (hex_id, remote_height): continue - lock = self._transaction_processing_locks.setdefault(hash, defer.DeferredLock()) + lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock()) yield lock.acquire() try: # see if we have a local copy of transaction, otherwise fetch it from server - raw, local_height, is_verified = yield self.db.get_transaction(unhexlify(hash)) + raw, local_height, is_verified = yield self.db.get_transaction(unhexlify(hex_id)[::-1]) save_tx = None if raw is None: - _raw = yield self.network.get_transaction(hash) + _raw = yield self.network.get_transaction(hex_id) tx = self.transaction_class(unhexlify(_raw)) save_tx = 'insert' else: - tx = self.transaction_class(unhexlify(raw)) + tx = self.transaction_class(raw) if remote_height > 0 and not is_verified: is_verified = yield self.is_valid_transaction(tx, remote_height) + is_verified = 1 if is_verified else 0 if save_tx is None: save_tx = 'update' yield self.db.save_transaction_io( save_tx, tx, remote_height, is_verified, address, self.address_to_hash160(address), - ''.join('{}:{}:'.format(hash.decode(), height) for hash, height in synced_history).encode() + ''.join('{}:{}:'.format(tx_id.decode(), tx_height) for tx_id, tx_height in synced_history) ) - if save_tx is not None: - self._on_transaction_controller.add(tx) + self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified)) finally: lock.release() if not lock.locked: - del self._transaction_processing_locks[hash] + del self._transaction_processing_locks[hex_id] @defer.inlineCallbacks def subscribe_history(self, address): diff --git a/torba/basetransaction.py b/torba/basetransaction.py index 8914b9173..004d8f323 100644 --- a/torba/basetransaction.py +++ b/torba/basetransaction.py @@ -124,10 +124,11 @@ class BaseOutput(InputOutput): script_class = BaseOutputScript estimator_class = BaseOutputEffectiveAmountEstimator - def __init__(self, amount, script, txhash=None, index=None): + def __init__(self, amount, script, txhash=None, index=None, txoid=None): super(BaseOutput, self).__init__(txhash, index) self.amount = amount # type: int self.script = script # type: BaseOutputScript + self.txoid = txoid def get_estimator(self, ledger): return self.estimator_class(ledger, self) @@ -288,7 +289,7 @@ class BaseTransaction: @classmethod @defer.inlineCallbacks - def pay(cls, outputs, funding_accounts, change_account): + def pay(cls, outputs, funding_accounts, change_account, reserve_outputs=True): # type: (List[BaseOutput], List[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> defer.Deferred """ Efficiently spend utxos from funding_accounts to cover the new outputs. """ @@ -307,15 +308,26 @@ class BaseTransaction: if not spendables: raise ValueError('Not enough funds to cover this transaction.') - spent_sum = sum(s.effective_amount for s in spendables) - if spent_sum > amount: - change_address = yield change_account.change.get_or_create_usable_address() - change_hash160 = change_account.ledger.address_to_hash160(change_address) - change_amount = spent_sum - amount - tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) + reserved_outputs = [s.txo.txoid for s in spendables] + if reserve_outputs: + yield ledger.db.reserve_spent_outputs(reserved_outputs) + + try: + spent_sum = sum(s.effective_amount for s in spendables) + if spent_sum > amount: + change_address = yield change_account.change.get_or_create_usable_address() + change_hash160 = change_account.ledger.address_to_hash160(change_address) + change_amount = spent_sum - amount + tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) + + tx.add_inputs([s.txi for s in spendables]) + yield tx.sign(funding_accounts) + + except Exception: + if reserve_outputs: + yield ledger.db.release_reserved_outputs(reserved_outputs) + raise - tx.add_inputs([s.txi for s in spendables]) - yield tx.sign(funding_accounts) defer.returnValue(tx) @classmethod @@ -354,3 +366,13 @@ class BaseTransaction: @property def output_sum(self): return sum(o.amount for o in self.outputs) + + @defer.inlineCallbacks + def get_my_addresses(self, ledger): + addresses = set() + for txo in self.outputs: + address = ledger.hash160_to_address(txo.script.values['pubkey_hash']) + record = yield ledger.db.get_address(address) + if record is not None: + addresses.add(address) + defer.returnValue(list(addresses)) diff --git a/torba/util.py b/torba/util.py index a6d4b8a52..9683a8dfe 100644 --- a/torba/util.py +++ b/torba/util.py @@ -1,13 +1,17 @@ from binascii import unhexlify, hexlify from collections import Sequence +from typing import TypeVar, Generic -class ReadOnlyList(Sequence): +T = TypeVar('T') + + +class ReadOnlyList(Sequence, Generic[T]): def __init__(self, lst): self.lst = lst - def __getitem__(self, key): + def __getitem__(self, key): # type: (int) -> T return self.lst[key] def __len__(self):