+ tx.position, + tx.net_account_balance, + txo.is_my_account

This commit is contained in:
Lex Berezhny 2018-09-25 18:02:50 -04:00
parent 5977f42a7e
commit c29b4c476d
4 changed files with 151 additions and 51 deletions

View file

@ -54,6 +54,50 @@ class TestSizeAndFeeEstimation(unittest.TestCase):
self.assertEqual(tx.get_base_fee(self.ledger), FEE_PER_BYTE * tx.base_size)
class TestAccountBalanceImpactFromTransaction(unittest.TestCase):
def test_is_my_account_not_set(self):
tx = get_transaction()
with self.assertRaisesRegex(ValueError, "Cannot access net_account_balance"):
_ = tx.net_account_balance
tx.inputs[0].is_my_account = True
with self.assertRaisesRegex(ValueError, "Cannot access net_account_balance"):
_ = tx.net_account_balance
tx.outputs[0].is_my_account = True
# all inputs/outputs are set now so it should work
_ = tx.net_account_balance
def test_paying_from_my_account_to_other_account(self):
tx = ledger_class.transaction_class() \
.add_inputs([get_input(300*CENT)]) \
.add_outputs([get_output(190*CENT, NULL_HASH),
get_output(100*CENT, NULL_HASH)])
tx.inputs[0].is_my_account = True
tx.outputs[0].is_my_account = False
tx.outputs[1].is_my_account = True
self.assertEqual(tx.net_account_balance, -200*CENT)
def test_paying_from_other_account_to_my_account(self):
tx = ledger_class.transaction_class() \
.add_inputs([get_input(300*CENT)]) \
.add_outputs([get_output(190*CENT, NULL_HASH),
get_output(100*CENT, NULL_HASH)])
tx.inputs[0].is_my_account = False
tx.outputs[0].is_my_account = True
tx.outputs[1].is_my_account = False
self.assertEqual(tx.net_account_balance, 190*CENT)
def test_paying_from_my_account_to_my_account(self):
tx = ledger_class.transaction_class() \
.add_inputs([get_input(300*CENT)]) \
.add_outputs([get_output(190*CENT, NULL_HASH),
get_output(100*CENT, NULL_HASH)])
tx.inputs[0].is_my_account = True
tx.outputs[0].is_my_account = True
tx.outputs[1].is_my_account = True
self.assertEqual(tx.net_account_balance, -10*CENT) # lost to fee
class TestTransactionSerialization(unittest.TestCase):
def test_genesis_transaction(self):
@ -217,7 +261,7 @@ class TransactionIOBalancing(unittest.TestCase):
save_tx = 'insert'
for utxo in utxos:
yield self.ledger.db.save_transaction_io(
save_tx, self.funding_tx, 1, True,
save_tx, self.funding_tx, True,
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
utxo.script.values['pubkey_hash'], ''
)

View file

@ -143,6 +143,7 @@ class BaseDatabase(SQLiteMixin):
txid text primary key,
raw blob not null,
height integer not null,
position integer not null,
is_verified boolean not null default 0
);
"""
@ -185,19 +186,20 @@ class BaseDatabase(SQLiteMixin):
'script': sqlite3.Binary(txo.script.source)
}
def save_transaction_io(self, save_tx, tx, height, is_verified, address, txhash, history):
def save_transaction_io(self, save_tx, tx, is_verified, address, txhash, history):
def _steps(t):
if save_tx == 'insert':
self.execute(t, *self._insert_sql('tx', {
'txid': tx.id,
'raw': sqlite3.Binary(tx.raw),
'height': height,
'height': tx.height,
'position': tx.position,
'is_verified': is_verified
}))
elif save_tx == 'update':
self.execute(t, *self._update_sql("tx", {
'height': height, 'is_verified': is_verified
'height': tx.height, 'position': tx.position, 'is_verified': is_verified
}, 'txid = ?', (tx.id,)))
existing_txos = [r[0] for r in self.execute(
@ -260,19 +262,19 @@ class BaseDatabase(SQLiteMixin):
@defer.inlineCallbacks
def get_transaction(self, txid):
result = yield self.run_query(
"SELECT raw, height, is_verified FROM tx WHERE txid = ?", (txid,)
"SELECT raw, height, position, is_verified FROM tx WHERE txid = ?", (txid,)
)
if result:
return result[0]
else:
return None, None, False
return None, None, None, False
@defer.inlineCallbacks
def get_transactions(self, account, offset=0, limit=100):
offset, limit = min(offset, 0), max(limit, 100)
account_id = account.public_key.address
tx_rows = yield self.run_query(
"""
SELECT txid, raw, height FROM tx WHERE txid IN (
SELECT txid, raw, height, position FROM tx WHERE txid IN (
SELECT txo.txid FROM txo
JOIN pubkey_address USING (address)
WHERE pubkey_address.account = :account
@ -281,47 +283,67 @@ class BaseDatabase(SQLiteMixin):
JOIN txo USING (txoid)
JOIN pubkey_address USING (address)
WHERE pubkey_address.account = :account
) ORDER BY height DESC LIMIT :offset, :limit
""", {'account': account.public_key.address, 'offset': offset, 'limit': limit}
) ORDER BY height DESC, position DESC LIMIT :offset, :limit
""", {
'account': account_id,
'offset': min(offset, 0),
'limit': max(limit, 100)
}
)
txids, txs = [], []
for row in tx_rows:
txids.append(row[0])
txs.append(account.ledger.transaction_class(raw=row[1], height=row[2]))
txs.append(account.ledger.transaction_class(
raw=row[1], height=row[2], position=row[3]
))
txo_rows = yield self.run_query(
"""
SELECT txoid, pubkey_address.chain
SELECT txoid, chain, account
FROM txo JOIN pubkey_address USING (address)
WHERE txid IN ({})
""".format(', '.join(['?']*len(txids))), txids
)
txos = dict(txo_rows)
txos = {}
for row in txo_rows:
txos[row[0]] = {
'is_change': row[1] == 1,
'is_my_account': row[2] == account_id
}
txi_rows = yield self.run_query(
referenced_txo_rows = yield self.run_query(
"""
SELECT txoid, txo.amount, txo.script, txo.txid, txo.position
FROM txi JOIN txo USING (txoid)
SELECT txoid, txo.amount, txo.script, txo.txid, txo.position, chain, account
FROM txi
JOIN txo USING (txoid)
JOIN pubkey_address USING (address)
WHERE txi.txid IN ({})
""".format(', '.join(['?']*len(txids))), txids
)
txis = {}
referenced_txos = {}
output_class = account.ledger.transaction_class.output_class
for row in txi_rows:
txis[row[0]] = output_class(
row[1],
output_class.script_class(row[2]),
TXRefImmutable.from_id(row[3]),
position=row[4]
for row in referenced_txo_rows:
referenced_txos[row[0]] = output_class(
amount=row[1],
script=output_class.script_class(row[2]),
tx_ref=TXRefImmutable.from_id(row[3]),
position=row[4],
is_change=row[5] == 1,
is_my_account=row[6] == account_id
)
for tx in txs:
for txi in tx.inputs:
if txi.txo_ref.id in txis:
txi.txo_ref = TXORefResolvable(txis[txi.txo_ref.id])
if txi.txo_ref.id in referenced_txos:
txi.txo_ref = TXORefResolvable(referenced_txos[txi.txo_ref.id])
for txo in tx.outputs:
if txo.id in txos:
txo.is_change = txos[txo.id] == 1
txo_meta = txos.get(txo.id)
if txo_meta is not None:
txo.is_change = txo_meta['is_change']
txo.is_my_account = txo_meta['is_my_account']
else:
txo.is_change = False
txo.is_my_account = False
return txs

View file

@ -40,7 +40,7 @@ class LedgerRegistry(type):
return mcs.ledgers[ledger_id]
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'height', 'is_verified'))):
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'is_verified'))):
pass
@ -87,7 +87,7 @@ class BaseLedger(metaclass=LedgerRegistry):
self.on_transaction.listen(
lambda e: log.info(
'(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
self.get_id(), e.address, e.height, e.is_verified, e.tx.id
self.get_id(), e.address, e.tx.height, e.is_verified, e.tx.id
)
)
@ -207,12 +207,13 @@ class BaseLedger(metaclass=LedgerRegistry):
return hexlify(working_branch[::-1])
@defer.inlineCallbacks
def is_valid_transaction(self, tx, height):
def validate_transaction_and_set_position(self, tx, height):
if not height <= len(self.headers):
return False
merkle = yield self.network.get_merkle(tx.id, height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[height]
tx.position = merkle['pos']
return merkle_root == header['merkle_root']
@defer.inlineCallbacks
@ -365,23 +366,23 @@ class BaseLedger(metaclass=LedgerRegistry):
try:
# see if we have a local copy of transaction, otherwise fetch it from server
raw, _, is_verified = yield self.db.get_transaction(hex_id)
raw, _, position, is_verified = yield self.db.get_transaction(hex_id)
save_tx = None
if raw is None:
_raw = yield self.network.get_transaction(hex_id)
tx = self.transaction_class(unhexlify(_raw))
tx = self.transaction_class(unhexlify(_raw), height=remote_height)
save_tx = 'insert'
else:
tx = self.transaction_class(raw)
tx = self.transaction_class(raw, height=remote_height)
if remote_height > 0 and not is_verified:
is_verified = yield self.is_valid_transaction(tx, remote_height)
if remote_height > 0 and (not is_verified or position is None):
is_verified = yield self.validate_transaction_and_set_position(tx, remote_height)
is_verified = 1 if is_verified else 0
if save_tx is None:
save_tx = 'update'
yield self.db.save_transaction_io(
save_tx, tx, remote_height, is_verified, address, self.address_to_hash160(address),
save_tx, tx, is_verified, address, self.address_to_hash160(address),
''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history)
)
@ -390,7 +391,7 @@ class BaseLedger(metaclass=LedgerRegistry):
self.get_id(), hex_id, address, remote_height, is_verified
)
self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified))
self._on_transaction_controller.add(TransactionEvent(address, tx, is_verified))
except Exception:
log.exception('Failed to synchronize transaction:')

View file

@ -137,6 +137,13 @@ class BaseInput(InputOutput):
raise ValueError('Cannot resolve output to get amount.')
return self.txo_ref.txo.amount
@property
def is_my_account(self) -> int:
""" True if the output this input spends is yours. """
if self.txo_ref.txo is None:
raise ValueError('Cannot resolve output to determine ownership.')
return self.txo_ref.txo.is_my_account
@classmethod
def deserialize_from(cls, stream):
tx_ref = TXRefImmutable.from_hash(stream.read(32))
@ -181,14 +188,17 @@ class BaseOutput(InputOutput):
script_class = BaseOutputScript
estimator_class = BaseOutputEffectiveAmountEstimator
__slots__ = 'amount', 'script', 'is_change'
__slots__ = 'amount', 'script', 'is_change', 'is_my_account'
def __init__(self, amount: int, script: BaseOutputScript,
tx_ref: TXRef = None, position: int = None) -> None:
tx_ref: TXRef = None, position: int = None,
is_change: Optional[bool] = None, is_my_account: Optional[bool] = None
) -> None:
super().__init__(tx_ref, position)
self.amount = amount
self.script = script
self.is_change = None
self.is_change = is_change
self.is_my_account = is_my_account
@property
def ref(self):
@ -227,14 +237,16 @@ class BaseTransaction:
input_class = BaseInput
output_class = BaseOutput
def __init__(self, raw=None, version=1, locktime=0, height=None) -> None:
def __init__(self, raw=None, version: int=1, locktime: int=0,
height: int=-1, position: int=-1) -> None:
self._raw = raw
self.ref = TXRefMutable(self)
self.version = version # type: int
self.locktime = locktime # type: int
self._inputs = [] # type: List[BaseInput]
self._outputs = [] # type: List[BaseOutput]
self.version = version
self.locktime = locktime
self._inputs: List[BaseInput] = []
self._outputs: List[BaseOutput] = []
self.height = height
self.position = position
if raw is not None:
self._deserialize()
@ -257,11 +269,11 @@ class BaseTransaction:
self.ref.reset()
@property
def inputs(self): # type: () -> ReadOnlyList[BaseInput]
def inputs(self) -> ReadOnlyList[BaseInput]:
return ReadOnlyList(self._inputs)
@property
def outputs(self): # type: () -> ReadOnlyList[BaseOutput]
def outputs(self) -> ReadOnlyList[BaseOutput]:
return ReadOnlyList(self._outputs)
def _add(self, new_ios: Iterable[InputOutput], existing_ios: List) -> 'BaseTransaction':
@ -301,18 +313,39 @@ class BaseTransaction:
return sum(o.amount for o in self.outputs)
@property
def fee(self):
def net_account_balance(self) -> int:
balance = 0
for txi in self.inputs:
if txi.is_my_account is None:
raise ValueError(
"Cannot access net_account_balance if inputs/outputs do not "
"have is_my_account set properly."
)
elif txi.is_my_account:
balance -= txi.amount
for txo in self.outputs:
if txo.is_my_account is None:
raise ValueError(
"Cannot access net_account_balance if inputs/outputs do not "
"have is_my_account set properly."
)
elif txo.is_my_account:
balance += txo.amount
return balance
@property
def fee(self) -> int:
return self.input_sum - self.output_sum
def get_base_fee(self, ledger):
def get_base_fee(self, ledger) -> int:
""" Fee for base tx excluding inputs and outputs. """
return self.base_size * ledger.fee_per_byte
def get_effective_input_sum(self, ledger):
def get_effective_input_sum(self, ledger) -> int:
""" Sum of input values *minus* the cost involved to spend them. """
return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs)
def get_total_output_sum(self, ledger):
def get_total_output_sum(self, ledger) -> int:
""" Sum of output values *plus* the cost involved to spend them. """
return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs)