* on_transaction now produces TransactionEvents with useful info and lots of other changes

This commit is contained in:
Lex Berezhny 2018-06-26 17:22:05 -04:00
parent aba3ed7ca0
commit 9a467f8840
6 changed files with 111 additions and 50 deletions

View file

@ -17,30 +17,35 @@ class BasicTransactionTests(IntegrationTestCase):
address = await account1.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop())
sendtxid = await self.blockchain.send_to_address(address.decode(), 5.5)
await self.on_transaction(sendtxid) #mempool
await self.on_transaction_id(sendtxid) #mempool
await self.blockchain.generate(1)
await self.on_transaction(sendtxid) #confirmed
await self.on_transaction_id(sendtxid) #confirmed
self.assertEqual(await self.get_balance(account1), int(5.5*COIN))
self.assertEqual(await self.get_balance(account2), 0)
address = await account2.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop())
hash1 = self.ledger.address_to_hash160(address)
tx = await self.ledger.transaction_class.pay(
[self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, self.ledger.address_to_hash160(address))],
[self.ledger.transaction_class.output_class.pay_pubkey_hash(2*COIN, hash1)],
[account1], account1
).asFuture(asyncio.get_event_loop())
await self.broadcast(tx)
await self.on_transaction(tx.hex_id.decode()) #mempool
await self.on_transaction(tx) #mempool
tx2 = await self.ledger.transaction_class.pay(
[self.ledger.transaction_class.output_class.pay_pubkey_hash(1*COIN, self.ledger.address_to_hash160(address))],
[self.ledger.transaction_class.output_class.pay_pubkey_hash(1*COIN, hash1)],
[account1], account1
).asFuture(asyncio.get_event_loop())
await self.broadcast(tx2)
await self.on_transaction(tx2.hex_id.decode()) #mempool
await self.on_transaction(tx2) #mempool
await self.blockchain.generate(1)
await self.on_transaction(tx.hex_id.decode()) #confirmed
await asyncio.wait([
self.on_header(202),
self.on_transaction(tx),
self.on_transaction(tx2),
])
#self.assertEqual(round(await self.get_balance(account1)/COIN, 1), 3.5)
#self.assertEqual(round(await self.get_balance(account2)/COIN, 1), 2.0)

View file

@ -113,7 +113,7 @@ class BaseDatabase(SQLiteMixin):
txhash blob primary key,
raw blob not null,
height integer not null,
is_verified boolean not null default false
is_verified boolean not null default 0
);
"""
@ -136,7 +136,8 @@ class BaseDatabase(SQLiteMixin):
address blob references pubkey_address,
position integer not null,
amount integer not null,
script blob not null
script blob not null,
is_reserved boolean not null default 0
);
"""
@ -168,7 +169,7 @@ class BaseDatabase(SQLiteMixin):
elif save_tx == 'update':
t.execute(*self._update_sql("tx", {
'height': height, 'is_verified': is_verified
}, 'WHERE txhash = ?', (sqlite3.Binary(tx.hash),)
}, 'txhash = ?', (sqlite3.Binary(tx.hash),)
))
existing_txos = list(map(itemgetter(0), t.execute(
@ -209,18 +210,28 @@ class BaseDatabase(SQLiteMixin):
t.execute(
"UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
(sqlite3.Binary(history), history.count(b':')//2, sqlite3.Binary(address))
(history, history.count(':')//2, sqlite3.Binary(address))
)
return self.db.runInteraction(_steps)
def reserve_spent_outputs(self, txoids, is_reserved=True):
return self.db.runOperation(
"UPDATE txo SET is_reserved = ? WHERE txoid IN ({})".format(
', '.join(['?']*len(txoids))
), [is_reserved]+txoids
)
def release_reserved_outputs(self, txoids):
return self.reserve_spent_outputs(txoids, is_reserved=False)
@defer.inlineCallbacks
def get_transaction(self, txhash):
result = yield self.db.runQuery(
"SELECT raw, height, is_verified FROM tx WHERE txhash = ?", (sqlite3.Binary(txhash),)
)
if result:
defer.returnValue(*result[0])
defer.returnValue(result[0])
else:
defer.returnValue((None, None, False))
@ -244,9 +255,9 @@ class BaseDatabase(SQLiteMixin):
def get_utxos(self, account, output_class):
utxos = yield self.db.runQuery(
"""
SELECT amount, script, txhash, txo.position
SELECT amount, script, txhash, txo.position, txoid
FROM txo JOIN pubkey_address ON pubkey_address.address=txo.address
WHERE account=:account AND txoid NOT IN (SELECT txoid FROM txi)
WHERE account=:account AND txo.is_reserved=0 AND txoid NOT IN (SELECT txoid FROM txi)
""",
{'account': sqlite3.Binary(account.public_key.address)}
)
@ -255,7 +266,8 @@ class BaseDatabase(SQLiteMixin):
values[0],
output_class.script_class(values[1]),
values[2],
index=values[3]
index=values[3],
txoid=values[4]
) for values in utxos
])

View file

@ -1,14 +1,16 @@
import os
import struct
import logging
from binascii import unhexlify
from twisted.internet import threads, defer
import torba
from torba.stream import StreamController, execute_serially
from torba.util import int_to_hex, rev_hex, hash_encode
from torba.hash import double_sha256, pow_hash
log = logging.getLogger(__name__)
class BaseHeaders:
@ -32,7 +34,7 @@ class BaseHeaders:
@property
def height(self):
return len(self) - 1
return len(self)
def sync_read_length(self):
return os.path.getsize(self.path) // self.header_size
@ -76,7 +78,9 @@ class BaseHeaders:
_old_size = self._size
self._size = self.sync_read_length()
change = self._size - _old_size
#log.info('saved {} header blocks'.format(change))
log.info('{}: added {} header blocks, final height {}'.format(
self.ledger.get_id(), change, self.height)
)
self._on_change_controller.add(change)
def _iterate_headers(self, height, headers):

View file

@ -1,9 +1,11 @@
import os
import six
import hashlib
import logging
from binascii import hexlify, unhexlify
from typing import Dict, Type, Iterable, Generator
from operator import itemgetter
from collections import namedtuple
from twisted.internet import defer
@ -15,6 +17,8 @@ from torba import basetransaction
from torba.stream import StreamController, execute_serially
from torba.hash import hash160, double_sha256, Base58
log = logging.getLogger(__name__)
class LedgerRegistry(type):
ledgers = {} # type: Dict[str, Type[BaseLedger]]
@ -33,6 +37,10 @@ class LedgerRegistry(type):
return mcs.ledgers[ledger_id]
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'height', 'is_verified'))):
pass
class BaseLedger(six.with_metaclass(LedgerRegistry)):
name = None
@ -67,6 +75,14 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
self._on_transaction_controller = StreamController()
self.on_transaction = self._on_transaction_controller.stream
self.on_transaction.listen(
lambda e: log.info('({}) on_transaction: address={}, height={}, is_verified={}, tx.id={}'.format(
self.get_id(), e.address, e.height, e.is_verified, e.tx.hex_id)
)
)
self._on_header_controller = StreamController()
self.on_header = self._on_header_controller.stream
self._transaction_processing_locks = {}
@ -133,19 +149,15 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks
def get_local_status(self, address):
address_details = yield self.db.get_address(address)
history = address_details['history'] or b''
if six.PY2:
history = str(history)
hash = hashlib.sha256(history).digest()
history = address_details['history'] or ''
hash = hashlib.sha256(history.encode()).digest()
defer.returnValue(hexlify(hash))
@defer.inlineCallbacks
def get_local_history(self, address):
address_details = yield self.db.get_address(address)
history = address_details['history'] or b''
if six.PY2:
history = str(history)
parts = history.split(b':')[:-1]
history = address_details['history'] or ''
parts = history.split(':')[:-1]
defer.returnValue(list(zip(parts[0::2], map(int, parts[1::2]))))
@staticmethod
@ -162,7 +174,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
@defer.inlineCallbacks
def is_valid_transaction(self, tx, height):
len(self.headers) < height or defer.returnValue(False)
height <= len(self.headers) or defer.returnValue(False)
merkle = yield self.network.get_merkle(tx.hex_id.decode(), height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[height]
@ -193,6 +205,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
if headers['count'] <= 0:
break
yield self.headers.connect(height_sought, unhexlify(headers['hex']))
self._on_header_controller.add(height_sought)
@defer.inlineCallbacks
def process_header(self, response):
@ -202,6 +215,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
if header['height'] == len(self.headers):
# New header from network directly connects after the last local header.
yield self.headers.connect(len(self.headers), unhexlify(header['hex']))
self._on_header_controller.add(len(self.headers))
elif header['height'] > len(self.headers):
# New header is several heights ahead of local, do download instead.
yield self.update_headers()
@ -239,45 +253,45 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
local_history = yield self.get_local_history(address)
synced_history = []
for i, (hash, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
for i, (hex_id, remote_height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)):
synced_history.append((hash, remote_height))
synced_history.append((hex_id, remote_height))
if i < len(local_history) and local_history[i] == (hash, remote_height):
if i < len(local_history) and local_history[i] == (hex_id, remote_height):
continue
lock = self._transaction_processing_locks.setdefault(hash, defer.DeferredLock())
lock = self._transaction_processing_locks.setdefault(hex_id, defer.DeferredLock())
yield lock.acquire()
try:
# see if we have a local copy of transaction, otherwise fetch it from server
raw, local_height, is_verified = yield self.db.get_transaction(unhexlify(hash))
raw, local_height, is_verified = yield self.db.get_transaction(unhexlify(hex_id)[::-1])
save_tx = None
if raw is None:
_raw = yield self.network.get_transaction(hash)
_raw = yield self.network.get_transaction(hex_id)
tx = self.transaction_class(unhexlify(_raw))
save_tx = 'insert'
else:
tx = self.transaction_class(unhexlify(raw))
tx = self.transaction_class(raw)
if remote_height > 0 and not is_verified:
is_verified = yield self.is_valid_transaction(tx, remote_height)
is_verified = 1 if is_verified else 0
if save_tx is None:
save_tx = 'update'
yield self.db.save_transaction_io(
save_tx, tx, remote_height, is_verified, address, self.address_to_hash160(address),
''.join('{}:{}:'.format(hash.decode(), height) for hash, height in synced_history).encode()
''.join('{}:{}:'.format(tx_id.decode(), tx_height) for tx_id, tx_height in synced_history)
)
if save_tx is not None:
self._on_transaction_controller.add(tx)
self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified))
finally:
lock.release()
if not lock.locked:
del self._transaction_processing_locks[hash]
del self._transaction_processing_locks[hex_id]
@defer.inlineCallbacks
def subscribe_history(self, address):

View file

@ -124,10 +124,11 @@ class BaseOutput(InputOutput):
script_class = BaseOutputScript
estimator_class = BaseOutputEffectiveAmountEstimator
def __init__(self, amount, script, txhash=None, index=None):
def __init__(self, amount, script, txhash=None, index=None, txoid=None):
super(BaseOutput, self).__init__(txhash, index)
self.amount = amount # type: int
self.script = script # type: BaseOutputScript
self.txoid = txoid
def get_estimator(self, ledger):
return self.estimator_class(ledger, self)
@ -288,7 +289,7 @@ class BaseTransaction:
@classmethod
@defer.inlineCallbacks
def pay(cls, outputs, funding_accounts, change_account):
def pay(cls, outputs, funding_accounts, change_account, reserve_outputs=True):
# type: (List[BaseOutput], List[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> defer.Deferred
""" Efficiently spend utxos from funding_accounts to cover the new outputs. """
@ -307,15 +308,26 @@ class BaseTransaction:
if not spendables:
raise ValueError('Not enough funds to cover this transaction.')
spent_sum = sum(s.effective_amount for s in spendables)
if spent_sum > amount:
change_address = yield change_account.change.get_or_create_usable_address()
change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = spent_sum - amount
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
reserved_outputs = [s.txo.txoid for s in spendables]
if reserve_outputs:
yield ledger.db.reserve_spent_outputs(reserved_outputs)
try:
spent_sum = sum(s.effective_amount for s in spendables)
if spent_sum > amount:
change_address = yield change_account.change.get_or_create_usable_address()
change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = spent_sum - amount
tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])
tx.add_inputs([s.txi for s in spendables])
yield tx.sign(funding_accounts)
except Exception:
if reserve_outputs:
yield ledger.db.release_reserved_outputs(reserved_outputs)
raise
tx.add_inputs([s.txi for s in spendables])
yield tx.sign(funding_accounts)
defer.returnValue(tx)
@classmethod
@ -354,3 +366,13 @@ class BaseTransaction:
@property
def output_sum(self):
return sum(o.amount for o in self.outputs)
@defer.inlineCallbacks
def get_my_addresses(self, ledger):
addresses = set()
for txo in self.outputs:
address = ledger.hash160_to_address(txo.script.values['pubkey_hash'])
record = yield ledger.db.get_address(address)
if record is not None:
addresses.add(address)
defer.returnValue(list(addresses))

View file

@ -1,13 +1,17 @@
from binascii import unhexlify, hexlify
from collections import Sequence
from typing import TypeVar, Generic
class ReadOnlyList(Sequence):
T = TypeVar('T')
class ReadOnlyList(Sequence, Generic[T]):
def __init__(self, lst):
self.lst = lst
def __getitem__(self, key):
def __getitem__(self, key): # type: (int) -> T
return self.lst[key]
def __len__(self):