* on_transaction now produces TransactionEvents with useful info and lots of other changes

This commit is contained in:
Lex Berezhny 2018-06-26 17:22:05 -04:00
parent aba3ed7ca0
commit 9a467f8840
6 changed files with 111 additions and 50 deletions

View file

@ -17,30 +17,35 @@ class BasicTransactionTests(IntegrationTestCase):
address = await account1.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop()) address = await account1.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop())
sendtxid = await self.blockchain.send_to_address(address.decode(), 5.5) sendtxid = await self.blockchain.send_to_address(address.decode(), 5.5)
await self.on_transaction(sendtxid) #mempool await self.on_transaction_id(sendtxid) #mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction(sendtxid) #confirmed await self.on_transaction_id(sendtxid) #confirmed
self.assertEqual(await self.get_balance(account1), int(5.5*COIN)) self.assertEqual(await self.get_balance(account1), int(5.5*COIN))
self.assertEqual(await self.get_balance(account2), 0) self.assertEqual(await self.get_balance(account2), 0)
address = await account2.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop()) address = await account2.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop())
hash1 = self.ledger.address_to_hash160(address)
tx = await self.ledger.transaction_class.pay( tx = await self.ledger.transaction_class.pay(
[self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, self.ledger.address_to_hash160(address))], [self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash1)],
[account1], account1 [account1], account1
).asFuture(asyncio.get_event_loop()) ).asFuture(asyncio.get_event_loop())
await self.broadcast(tx) await self.broadcast(tx)
await self.on_transaction(tx.hex_id.decode()) #mempool await self.on_transaction(tx) #mempool
tx2 = await self.ledger.transaction_class.pay( tx2 = await self.ledger.transaction_class.pay(
[self.ledger.transaction_class.output_class.pay_pubkey_hash(1*COIN, self.ledger.address_to_hash160(address))], [self.ledger.transaction_class.output_class.pay_pubkey_hash(1*COIN, hash1)],
[account1], account1 [account1], account1
).asFuture(asyncio.get_event_loop()) ).asFuture(asyncio.get_event_loop())
await self.broadcast(tx2) await self.broadcast(tx2)
await self.on_transaction(tx2.hex_id.decode()) #mempool await self.on_transaction(tx2) #mempool
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction(tx.hex_id.decode()) #confirmed await asyncio.wait([
self.on_header(202),
self.on_transaction(tx),
self.on_transaction(tx2),
])
#self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5) #self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
#self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0) #self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)

View file

@ -113,7 +113,7 @@ class BaseDatabase(SQLiteMixin):
txhash blob primary key, txhash blob primary key,
raw blob not null, raw blob not null,
height integer not null, height integer not null,
is_verified boolean not null default false is_verified boolean not null default 0
); );
""" """
@ -136,7 +136,8 @@ class BaseDatabase(SQLiteMixin):
address blob references pubkey_address, address blob references pubkey_address,
position integer not null, position integer not null,
amount integer not null, amount integer not null,
script blob not null script blob not null,
is_reserved boolean not null default 0
); );
""" """
@ -168,7 +169,7 @@ class BaseDatabase(SQLiteMixin):
elif save_tx == 'update': elif save_tx == 'update':
t.execute(*self._update_sql("tx", { t.execute(*self._update_sql("tx", {
'height': height, 'is_verified': is_verified 'height': height, 'is_verified': is_verified
}, 'WHERE txhash = ?', (sqlite3.Binary(tx.hash),) }, 'txhash = ?', (sqlite3.Binary(tx.hash),)
)) ))
existing_txos = list(map(itemgetter(0), t.execute( existing_txos = list(map(itemgetter(0), t.execute(
@ -209,18 +210,28 @@ class BaseDatabase(SQLiteMixin):
t.execute( t.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(sqlite3.Binary(history), history.count(b':')//2, sqlite3.Binary(address)) (history, history.count(':')//2, sqlite3.Binary(address))
) )
return self.db.runInteraction(_steps) return self.db.runInteraction(_steps)
def reserve_spent_outputs(self, txoids, is_reserved=True):
return self.db.runOperation(
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
', '.join(['?']*len(txoids))
), [is_reserved]+txoids
)
def release_reserved_outputs(self, txoids):
return self.reserve_spent_outputs(txoids, is_reserved=False)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_transaction(self, txhash): def get_transaction(self, txhash):
result = yield self.db.runQuery( result = yield self.db.runQuery(
"SELECT raw, height, is_verified FROM tx WHERE txhash = ?", (sqlite3.Binary(txhash),) "SELECT raw, height, is_verified FROM tx WHERE txhash = ?", (sqlite3.Binary(txhash),)
) )
if result: if result:
defer.returnValue(*result[0]) defer.returnValue(result[0])
else: else:
defer.returnValue((None, None, False)) defer.returnValue((None, None, False))
@ -244,9 +255,9 @@ class BaseDatabase(SQLiteMixin):
def get_utxos(self, account, output_class): def get_utxos(self, account, output_class):
utxos = yield self.db.runQuery( utxos = yield self.db.runQuery(
""" """
SELECT amount, script, txhash, txo.position SELECT amount, script, txhash, txo.position, txoid
FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address
WHERE account=:account AND txoid NOT IN (SELECT txoid FROM txi) WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi)
""", """,
{'account': sqlite3.Binary(account.public_key.address)} {'account': sqlite3.Binary(account.public_key.address)}
) )
@ -255,7 +266,8 @@ class BaseDatabase(SQLiteMixin):
values[0], values[0],
output_class.script_class(values[1]), output_class.script_class(values[1]),
values[2], values[2],
index=values[3] index=values[3],
txoid=values[4]
) for values in utxos ) for values in utxos
]) ])

View file

@ -1,14 +1,16 @@
import os import os
import struct import struct
import logging
from binascii import unhexlify from binascii import unhexlify
from twisted.internet import threads, defer from twisted.internet import threads, defer
import torba
from torba.stream import StreamController, execute_serially from torba.stream import StreamController, execute_serially
from torba.util import int_to_hex, rev_hex, hash_encode from torba.util import int_to_hex, rev_hex, hash_encode
from torba.hash import double_sha256, pow_hash from torba.hash import double_sha256, pow_hash
log = logging.getLogger(__name__)
class BaseHeaders: class BaseHeaders:
@ -32,7 +34,7 @@ class BaseHeaders:
@property @property
def height(self): def height(self):
return len(self) - 1 return len(self)
def sync_read_length(self): def sync_read_length(self):
return os.path.getsize(self.path) // self.header_size return os.path.getsize(self.path) // self.header_size
@ -76,7 +78,9 @@ class BaseHeaders:
_old_size = self._size _old_size = self._size
self._size = self.sync_read_length() self._size = self.sync_read_length()
change = self._size - _old_size change = self._size - _old_size
#log.info('saved {} header blocks'.format(change)) log.info('{}: added {} header blocks, final height {}'.format(
self.ledger.get_id(), change, self.height)
)
self._on_change_controller.add(change) self._on_change_controller.add(change)
def _iterate_headers(self, height, headers): def _iterate_headers(self, height, headers):

View file

@ -1,9 +1,11 @@
import os import os
import six import six
import hashlib import hashlib
import logging
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import Dict, Type, Iterable, Generator from typing import Dict, Type, Iterable, Generator
from operator import itemgetter from operator import itemgetter
from collections import namedtuple
from twisted.internet import defer from twisted.internet import defer
@ -15,6 +17,8 @@ from torba import basetransaction
from torba.stream import StreamController, execute_serially from torba.stream import StreamController, execute_serially
from torba.hash import hash160, double_sha256, Base58 from torba.hash import hash160, double_sha256, Base58
log = logging.getLogger(__name__)
class LedgerRegistry(type): class LedgerRegistry(type):
ledgers = {} # type: Dict[str, Type[BaseLedger]] ledgers = {} # type: Dict[str, Type[BaseLedger]]
@ -33,6 +37,10 @@ class LedgerRegistry(type):
return mcs.ledgers[ledger_id] return mcs.ledgers[ledger_id]
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'height', 'is_verified'))):
pass
class BaseLedger(six.with_metaclass(LedgerRegistry)): class BaseLedger(six.with_metaclass(LedgerRegistry)):
name = None name = None
@ -67,6 +75,14 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self._on_transaction_controller = StreamController() self._on_transaction_controller = StreamController()
self.on_transaction = self._on_transaction_controller.stream self.on_transaction = self._on_transaction_controller.stream
self.on_transaction.listen(
lambda e: log.info('({}) on_transaction: address={}, height={}, is_verified={}, tx.id={}'.format(
self.get_id(), e.address, e.height, e.is_verified, e.tx.hex_id)
)
)
self._on_header_controller = StreamController()
self.on_header = self._on_header_controller.stream
self._transaction_processing_locks = {} self._transaction_processing_locks = {}
@ -133,19 +149,15 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_local_status(self, address): def get_local_status(self, address):
address_details = yield self.db.get_address(address) address_details = yield self.db.get_address(address)
history = address_details['history'] or b'' history = address_details['history'] or ''
if six.PY2: hash = hashlib.sha256(history.encode()).digest()
history = str(history)
hash = hashlib.sha256(history).digest()
defer.returnValue(hexlify(hash)) defer.returnValue(hexlify(hash))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_local_history(self, address): def get_local_history(self, address):
address_details = yield self.db.get_address(address) address_details = yield self.db.get_address(address)
history = address_details['history'] or b'' history = address_details['history'] or ''
if six.PY2: parts = history.split(':')[:-1]
history = str(history)
parts = history.split(b':')[:-1]
defer.returnValue(list(zip(parts[0::2], map(int, parts[1::2])))) defer.returnValue(list(zip(parts[0::2], map(int, parts[1::2]))))
@staticmethod @staticmethod
@ -162,7 +174,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks @defer.inlineCallbacks
def is_valid_transaction(self, tx, height): def is_valid_transaction(self, tx, height):
len(self.headers) < height or defer.returnValue(False) height <= len(self.headers) or defer.returnValue(False)
merkle = yield self.network.get_merkle(tx.hex_id.decode(), height) merkle = yield self.network.get_merkle(tx.hex_id.decode(), height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[height] header = self.headers[height]
@ -193,6 +205,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
if headers['count'] <= 0: if headers['count'] <= 0:
break break
yield self.headers.connect(height_sought, unhexlify(headers['hex'])) yield self.headers.connect(height_sought, unhexlify(headers['hex']))
self._on_header_controller.add(height_sought)
@defer.inlineCallbacks @defer.inlineCallbacks
def process_header(self, response): def process_header(self, response):
@ -202,6 +215,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
if header['height'] == len(self.headers): if header['height'] == len(self.headers):
# New header from network directly connects after the last local header. # New header from network directly connects after the last local header.
yield self.headers.connect(len(self.headers), unhexlify(header['hex'])) yield self.headers.connect(len(self.headers), unhexlify(header['hex']))
self._on_header_controller.add(len(self.headers))
elif header['height'] > len(self.headers): elif header['height'] > len(self.headers):
# New header is several heights ahead of local, do download instead. # New header is several heights ahead of local, do download instead.
yield self.update_headers() yield self.update_headers()
@ -239,45 +253,45 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
local_history = yield self.get_local_history(address) local_history = yield self.get_local_history(address)
synced_history = [] synced_history = []
for i, (hash, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
synced_history.append((hash, remote_height)) synced_history.append((hex_id, remote_height))
if i < len(local_history) and local_history[i] == (hash, remote_height): if i < len(local_history) and local_history[i] == (hex_id, remote_height):
continue continue
lock = self._transaction_processing_locks.setdefault(hash, defer.DeferredLock()) lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock())
yield lock.acquire() yield lock.acquire()
try: try:
# see if we have a local copy of transaction, otherwise fetch it from server # see if we have a local copy of transaction, otherwise fetch it from server
raw, local_height, is_verified = yield self.db.get_transaction(unhexlify(hash)) raw, local_height, is_verified = yield self.db.get_transaction(unhexlify(hex_id)[::-1])
save_tx = None save_tx = None
if raw is None: if raw is None:
_raw = yield self.network.get_transaction(hash) _raw = yield self.network.get_transaction(hex_id)
tx = self.transaction_class(unhexlify(_raw)) tx = self.transaction_class(unhexlify(_raw))
save_tx = 'insert' save_tx = 'insert'
else: else:
tx = self.transaction_class(unhexlify(raw)) tx = self.transaction_class(raw)
if remote_height > 0 and not is_verified: if remote_height > 0 and not is_verified:
is_verified = yield self.is_valid_transaction(tx, remote_height) is_verified = yield self.is_valid_transaction(tx, remote_height)
is_verified = 1 if is_verified else 0
if save_tx is None: if save_tx is None:
save_tx = 'update' save_tx = 'update'
yield self.db.save_transaction_io( yield self.db.save_transaction_io(
save_tx, tx, remote_height, is_verified, address, self.address_to_hash160(address), save_tx, tx, remote_height, is_verified, address, self.address_to_hash160(address),
''.join('{}:{}:'.format(hash.decode(), height) for hash, height in synced_history).encode() ''.join('{}:{}:'.format(tx_id.decode(), tx_height) for tx_id, tx_height in synced_history)
) )
if save_tx is not None: self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified))
self._on_transaction_controller.add(tx)
finally: finally:
lock.release() lock.release()
if not lock.locked: if not lock.locked:
del self._transaction_processing_locks[hash] del self._transaction_processing_locks[hex_id]
@defer.inlineCallbacks @defer.inlineCallbacks
def subscribe_history(self, address): def subscribe_history(self, address):

View file

@ -124,10 +124,11 @@ class BaseOutput(InputOutput):
script_class = BaseOutputScript script_class = BaseOutputScript
estimator_class = BaseOutputEffectiveAmountEstimator estimator_class = BaseOutputEffectiveAmountEstimator
def __init__(self, amount, script, txhash=None, index=None): def __init__(self, amount, script, txhash=None, index=None, txoid=None):
super(BaseOutput, self).__init__(txhash, index) super(BaseOutput, self).__init__(txhash, index)
self.amount = amount # type: int self.amount = amount # type: int
self.script = script # type: BaseOutputScript self.script = script # type: BaseOutputScript
self.txoid = txoid
def get_estimator(self, ledger): def get_estimator(self, ledger):
return self.estimator_class(ledger, self) return self.estimator_class(ledger, self)
@ -288,7 +289,7 @@ class BaseTransaction:
@classmethod @classmethod
@defer.inlineCallbacks @defer.inlineCallbacks
def pay(cls, outputs, funding_accounts, change_account): def pay(cls, outputs, funding_accounts, change_account, reserve_outputs=True):
# type: (List[BaseOutput], List[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> defer.Deferred # type: (List[BaseOutput], List[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> defer.Deferred
""" Efficiently spend utxos from funding_accounts to cover the new outputs. """ """ Efficiently spend utxos from funding_accounts to cover the new outputs. """
@ -307,15 +308,26 @@ class BaseTransaction:
if not spendables: if not spendables:
raise ValueError('Not enough funds to cover this transaction.') raise ValueError('Not enough funds to cover this transaction.')
spent_sum = sum(s.effective_amount for s in spendables) reserved_outputs = [s.txo.txoid for s in spendables]
if spent_sum > amount: if reserve_outputs:
change_address = yield change_account.change.get_or_create_usable_address() yield ledger.db.reserve_spent_outputs(reserved_outputs)
change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = spent_sum - amount try:
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) spent_sum = sum(s.effective_amount for s in spendables)
if spent_sum > amount:
change_address = yield change_account.change.get_or_create_usable_address()
change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = spent_sum - amount
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
tx.add_inputs([s.txi for s in spendables])
yield tx.sign(funding_accounts)
except Exception:
if reserve_outputs:
yield ledger.db.release_reserved_outputs(reserved_outputs)
raise
tx.add_inputs([s.txi for s in spendables])
yield tx.sign(funding_accounts)
defer.returnValue(tx) defer.returnValue(tx)
@classmethod @classmethod
@ -354,3 +366,13 @@ class BaseTransaction:
@property @property
def output_sum(self): def output_sum(self):
return sum(o.amount for o in self.outputs) return sum(o.amount for o in self.outputs)
@defer.inlineCallbacks
def get_my_addresses(self, ledger):
addresses = set()
for txo in self.outputs:
address = ledger.hash160_to_address(txo.script.values['pubkey_hash'])
record = yield ledger.db.get_address(address)
if record is not None:
addresses.add(address)
defer.returnValue(list(addresses))

View file

@ -1,13 +1,17 @@
from binascii import unhexlify, hexlify from binascii import unhexlify, hexlify
from collections import Sequence from collections import Sequence
from typing import TypeVar, Generic
class ReadOnlyList(Sequence): T = TypeVar('T')
class ReadOnlyList(Sequence, Generic[T]):
def __init__(self, lst): def __init__(self, lst):
self.lst = lst self.lst = lst
def __getitem__(self, key): def __getitem__(self, key): # type: (int) -> T
return self.lst[key] return self.lst[key]
def __len__(self): def __len__(self):