lbry-sdk/torba/basetransaction.py

462 lines
15 KiB
Python
Raw Normal View History

2018-05-25 08:03:25 +02:00
import logging
import typing
from typing import List, Iterable, Optional
2018-05-25 15:54:01 +02:00
from binascii import hexlify
2018-05-25 08:03:25 +02:00
2018-06-12 16:02:04 +02:00
from twisted.internet import defer
2018-05-25 08:03:25 +02:00
from torba.basescript import BaseInputScript, BaseOutputScript
from torba.baseaccount import BaseAccount
from torba.constants import COIN, NULL_HASH32
2018-05-25 08:03:25 +02:00
from torba.bcd_data_stream import BCDataStream
from torba.hash import sha256, TXRef, TXRefImmutable
2018-05-25 08:03:25 +02:00
from torba.util import ReadOnlyList
if typing.TYPE_CHECKING:
from torba import baseledger
2018-05-25 08:03:25 +02:00
# Module-level logger named after this module (not the root logger), so records
# from this file can be identified and filtered by name.
log = logging.getLogger(__name__)
class TXRefMutable(TXRef):
    """ Transaction reference backed by a live transaction object.

    `hash` and `id` are computed lazily from `tx.raw` and cached; call
    `reset()` after the transaction changes so they are recomputed.
    """

    __slots__ = ('tx',)

    def __init__(self, tx: 'BaseTransaction') -> None:
        super().__init__()
        # transaction whose serialized bytes back `hash` and `id`
        self.tx = tx

    @property
    def id(self):
        # hex string of the hash with its byte order reversed; cached
        cached = self._id
        if cached is None:
            reversed_hash = self.hash[::-1]
            cached = self._id = hexlify(reversed_hash).decode()
        return cached

    @property
    def hash(self):
        # sha256 applied twice to the raw transaction bytes; cached
        cached = self._hash
        if cached is None:
            cached = self._hash = sha256(sha256(self.tx.raw))
        return cached

    def reset(self):
        """ Forget the cached hash and id so the next access recomputes them. """
        self._hash = None
        self._id = None
class TXORef:
    """ Points at a transaction output via its transaction reference and its
    position (index) within that transaction. """

    __slots__ = 'tx_ref', 'position'

    def __init__(self, tx_ref: TXRef, position: int) -> None:
        self.tx_ref = tx_ref
        self.position = position

    @property
    def id(self):
        # formatted as "<transaction id>:<position>"
        tx_id, index = self.tx_ref.id, self.position
        return '{}:{}'.format(tx_id, index)

    @property
    def is_null(self):
        # delegates to the underlying transaction reference
        return self.tx_ref.is_null

    @property
    def txo(self) -> Optional['BaseOutput']:
        # a plain reference cannot resolve the output object itself
        return None
class TXORefResolvable(TXORef):
    """ Output reference that also holds the output object it points at, so
    `txo` returns that output instead of None. """

    __slots__ = ('_txo',)

    def __init__(self, txo: 'BaseOutput') -> None:
        # the output must already be attached to a transaction at a position
        tx_ref, position = txo.tx_ref, txo.position
        assert tx_ref is not None
        assert position is not None
        super().__init__(tx_ref, position)
        self._txo = txo

    @property
    def txo(self):
        return self._txo
class InputOutput:
    """ Shared base for inputs and outputs: records the owning transaction
    (`tx_ref`) and this item's index in it (`position`). """

    __slots__ = 'tx_ref', 'position'

    def __init__(self, tx_ref: TXRef = None, position: int = None) -> None:
        self.tx_ref = tx_ref
        self.position = position

    @property
    def size(self) -> int:
        """ Size of this input / output in bytes. """
        buffer = BCDataStream()
        self.serialize_to(buffer)
        serialized = buffer.get_bytes()
        return len(serialized)

    def get_fee(self, ledger):
        """ Byte size of this item multiplied by `ledger.fee_per_byte`. """
        byte_count = self.size
        return byte_count * ledger.fee_per_byte

    def serialize_to(self, stream, alternate_script=None):
        """ Write this item to `stream`; must be provided by subclasses. """
        raise NotImplementedError
class BaseInput(InputOutput):
    """ Transaction input spending the output referenced by `txo_ref`.

    When `txo_ref` is null the input is a coinbase: the script argument is
    stored in `coinbase` and `script` is None; otherwise the reverse.
    """

    script_class = BaseInputScript

    # zero-filled stand-ins used by `spend()` before real values are signed in
    NULL_SIGNATURE = b'\x00'*72
    NULL_PUBLIC_KEY = b'\x00'*33

    __slots__ = 'txo_ref', 'sequence', 'coinbase', 'script'

    def __init__(self, txo_ref: TXORef, script: BaseInputScript, sequence: int = 0xFFFFFFFF,
                 tx_ref: TXRef = None, position: int = None) -> None:
        super().__init__(tx_ref, position)
        self.txo_ref = txo_ref
        self.sequence = sequence
        if txo_ref.is_null:
            self.coinbase, self.script = script, None
        else:
            self.coinbase, self.script = None, script

    @property
    def is_coinbase(self):
        return self.coinbase is not None

    @classmethod
    def spend(cls, txo: 'BaseOutput') -> 'BaseInput':
        """ Create an input to spend the output."""
        assert txo.script.is_pay_pubkey_hash, 'Attempting to spend unsupported output.'
        placeholder = cls.script_class.redeem_pubkey_hash(cls.NULL_SIGNATURE, cls.NULL_PUBLIC_KEY)
        return cls(txo.ref, placeholder)

    @property
    def amount(self) -> int:
        """ Amount this input adds to the transaction. """
        resolved = self.txo_ref.txo
        if resolved is None:
            raise ValueError('Cannot resolve output to get amount.')
        return resolved.amount

    @classmethod
    def deserialize_from(cls, stream):
        # read order: 32-byte tx hash, uint32 position, script string, uint32 sequence
        tx_ref = TXRefImmutable.from_hash(stream.read(32))
        position = stream.read_uint32()
        raw_script = stream.read_string()
        sequence = stream.read_uint32()
        # coinbase scripts are kept as raw bytes rather than parsed
        if tx_ref.is_null:
            script = raw_script
        else:
            script = cls.script_class(raw_script)
        return cls(TXORef(tx_ref, position), script, sequence)

    def serialize_to(self, stream, alternate_script=None):
        stream.write(self.txo_ref.tx_ref.hash)
        stream.write_uint32(self.txo_ref.position)
        # an explicit alternate script overrides whatever this input holds
        if alternate_script is None:
            script_bytes = self.coinbase if self.is_coinbase else self.script.source
        else:
            script_bytes = alternate_script
        stream.write_string(script_bytes)
        stream.write_uint32(self.sequence)
class BaseOutputEffectiveAmountEstimator:
    """ Pairs an output with an input that would spend it and records the
    output's value net of the fee for that input. Instances compare (`<`) by
    `effective_amount`. """

    __slots__ = 'txo', 'txi', 'fee', 'effective_amount'

    def __init__(self, ledger: 'baseledger.BaseLedger', txo: 'BaseOutput') -> None:
        self.txo = txo
        input_class = ledger.transaction_class.input_class
        self.txi = input_class.spend(txo)
        spend_fee: int = self.txi.get_fee(ledger)
        self.fee = spend_fee
        self.effective_amount = txo.amount - spend_fee

    def __lt__(self, other):
        return self.effective_amount < other.effective_amount
class BaseOutput(InputOutput):
    """ Transaction output: an `amount` locked by an output `script`. """

    script_class = BaseOutputScript
    estimator_class = BaseOutputEffectiveAmountEstimator

    __slots__ = 'amount', 'script'

    def __init__(self, amount: int, script: BaseOutputScript,
                 tx_ref: TXRef = None, position: int = None) -> None:
        super().__init__(tx_ref, position)
        self.amount = amount
        self.script = script

    @property
    def ref(self):
        # a new resolvable reference to this output on every access
        return TXORefResolvable(self)

    @property
    def id(self):
        return self.ref.id

    def get_address(self, ledger):
        # reads the script's 'pubkey_hash' value; KeyError if the script has none
        pubkey_hash = self.script.values['pubkey_hash']
        return ledger.hash160_to_address(pubkey_hash)

    def get_estimator(self, ledger):
        return self.estimator_class(ledger, self)

    @classmethod
    def pay_pubkey_hash(cls, amount, pubkey_hash):
        locking_script = cls.script_class.pay_pubkey_hash(pubkey_hash)
        return cls(amount, locking_script)

    @classmethod
    def deserialize_from(cls, stream):
        # read order: uint64 amount, then the script string
        amount = stream.read_uint64()
        script = cls.script_class(stream.read_string())
        return cls(amount=amount, script=script)

    def serialize_to(self, stream, alternate_script=None):
        # alternate_script is accepted for signature parity with BaseInput; unused here
        stream.write_uint64(self.amount)
        stream.write_string(self.script.source)
class BaseTransaction:
    """ A transaction: version, inputs, outputs and locktime.

    The serialized bytes are cached in `_raw` and the id/hash derived from them
    are cached on `ref` (a TXRefMutable). `add_inputs()`, `add_outputs()` and
    `sign()` drop both caches so they are recomputed on next access.
    """

    input_class = BaseInput
    output_class = BaseOutput

    def __init__(self, raw=None, version=1, locktime=0) -> None:
        self._raw = raw
        self.ref = TXRefMutable(self)
        self.version = version  # type: int
        self.locktime = locktime  # type: int
        self._inputs = []  # type: List[BaseInput]
        self._outputs = []  # type: List[BaseOutput]
        if raw is not None:
            # parsed version/locktime overwrite the keyword arguments
            self._deserialize()

    @property
    def id(self):
        return self.ref.id

    @property
    def hash(self):
        return self.ref.hash

    @property
    def raw(self):
        """ Serialized transaction bytes, produced on demand and cached. """
        if self._raw is None:
            self._raw = self._serialize()
        return self._raw

    def _reset(self):
        """ Drop cached raw bytes and the id/hash derived from them. """
        self._raw = None
        self.ref.reset()

    @property
    def inputs(self):  # type: () -> ReadOnlyList[BaseInput]
        return ReadOnlyList(self._inputs)

    @property
    def outputs(self):  # type: () -> ReadOnlyList[BaseOutput]
        return ReadOnlyList(self._outputs)

    def _add(self, new_ios: Iterable[InputOutput], existing_ios: List,
             reset: bool = True) -> 'BaseTransaction':
        """ Append inputs/outputs to `existing_ios`, binding each to this
        transaction's ref and to its index in that list.

        :param reset: clear cached raw bytes / id / hash afterwards. Pass False
            only when the items were parsed from `_raw` itself, so the cached
            bytes still describe the transaction exactly.
        :returns: self, to allow chaining.
        """
        for txio in new_ios:
            txio.tx_ref = self.ref
            txio.position = len(existing_ios)
            existing_ios.append(txio)
        if reset:
            self._reset()
        return self

    def add_inputs(self, inputs: Iterable[BaseInput]) -> 'BaseTransaction':
        return self._add(inputs, self._inputs)

    def add_outputs(self, outputs: Iterable[BaseOutput]) -> 'BaseTransaction':
        return self._add(outputs, self._outputs)

    @property
    def size(self) -> int:
        """ Size in bytes of the entire transaction. """
        return len(self.raw)

    @property
    def base_size(self) -> int:
        """ Size of transaction without inputs or outputs in bytes. """
        return (
            self.size
            - sum(txi.size for txi in self._inputs)
            - sum(txo.size for txo in self._outputs)
        )

    @property
    def input_sum(self):
        return sum(i.amount for i in self.inputs)

    @property
    def output_sum(self):
        return sum(o.amount for o in self.outputs)

    @property
    def fee(self):
        return self.input_sum - self.output_sum

    def get_base_fee(self, ledger):
        """ Fee for base tx excluding inputs and outputs. """
        return self.base_size * ledger.fee_per_byte

    def get_effective_input_sum(self, ledger):
        """ Sum of input values *minus* the cost involved to spend them. """
        return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs)

    def get_total_output_sum(self, ledger):
        """ Sum of output values *plus* the cost involved to spend them. """
        return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs)

    def _serialize(self, with_inputs: bool = True) -> bytes:
        """ Serialize version, (optionally) inputs, outputs and locktime. """
        stream = BCDataStream()
        stream.write_uint32(self.version)
        if with_inputs:
            stream.write_compact_size(len(self._inputs))
            for txin in self._inputs:
                txin.serialize_to(stream)
        stream.write_compact_size(len(self._outputs))
        for txout in self._outputs:
            txout.serialize_to(stream)
        stream.write_uint32(self.locktime)
        return stream.get_bytes()

    def _serialize_for_signature(self, signing_input: int) -> bytes:
        """ Serialization that gets signed for input number `signing_input`.

        The signing input carries the script of the output it spends, every
        other input carries an empty script, and the hash type is appended.
        """
        stream = BCDataStream()
        stream.write_uint32(self.version)
        stream.write_compact_size(len(self._inputs))
        for i, txin in enumerate(self._inputs):
            if signing_input == i:
                assert txin.txo_ref.txo is not None
                txin.serialize_to(stream, txin.txo_ref.txo.script.source)
            else:
                txin.serialize_to(stream, b'')
        stream.write_compact_size(len(self._outputs))
        for txout in self._outputs:
            txout.serialize_to(stream)
        stream.write_uint32(self.locktime)
        stream.write_uint32(self.signature_hash_type(1))  # signature hash type: SIGHASH_ALL
        return stream.get_bytes()

    def _deserialize(self):
        """ Populate version, inputs, outputs and locktime from `_raw`.

        Items are attached with reset=False so the bytes this transaction was
        built from stay cached: `raw`, `hash` and `id` keep reflecting those
        original bytes rather than a re-serialization of the parsed objects.
        """
        if self._raw is not None:
            stream = BCDataStream(self._raw)
            self.version = stream.read_uint32()
            input_count = stream.read_compact_size()
            self._add([
                self.input_class.deserialize_from(stream) for _ in range(input_count)
            ], self._inputs, reset=False)
            output_count = stream.read_compact_size()
            self._add([
                self.output_class.deserialize_from(stream) for _ in range(output_count)
            ], self._outputs, reset=False)
            self.locktime = stream.read_uint32()

    @classmethod
    def ensure_all_have_same_ledger(cls, funding_accounts: Iterable[BaseAccount],
                                    change_account: BaseAccount = None) -> 'baseledger.BaseLedger':
        """ Return the ledger shared by all funding accounts (and the change
        account, if given).

        :raises ValueError: if the accounts span ledgers or none were given.
        """
        ledger = None
        for account in funding_accounts:
            if ledger is None:
                ledger = account.ledger
            if ledger != account.ledger:
                raise ValueError(
                    'All funding accounts used to create a transaction must be on the same ledger.'
                )
        if change_account is not None and change_account.ledger != ledger:
            raise ValueError('Change account must use same ledger as funding accounts.')
        if ledger is None:
            raise ValueError('No ledger found.')
        return ledger

    @classmethod
    @defer.inlineCallbacks
    def create(cls, inputs: Iterable[BaseInput], outputs: Iterable[BaseOutput],
               funding_accounts: Iterable[BaseAccount], change_account: BaseAccount):
        """ Find optimal set of inputs when only outputs are provided; add change
        outputs if only inputs are provided or if inputs are greater than outputs.

        Returns a Deferred firing with the signed transaction. If funding,
        change creation or signing fails, the outputs spent by the transaction's
        inputs are passed to `ledger.release_outputs()` and the error is re-raised.
        """
        # iterated several times below (ledger check, utxo lookup, signing),
        # so a one-shot iterator must be materialized first
        funding_accounts = list(funding_accounts)

        tx = cls() \
            .add_inputs(inputs) \
            .add_outputs(outputs)

        ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)

        # value of the outputs plus associated fees
        cost = (
            tx.get_base_fee(ledger) +
            tx.get_total_output_sum(ledger)
        )

        # value of the inputs less the cost to spend those inputs
        payment = tx.get_effective_input_sum(ledger)

        try:

            if payment < cost:
                deficit = cost - payment
                spendables = yield ledger.get_spendable_utxos(deficit, funding_accounts)
                if not spendables:
                    raise ValueError('Not enough funds to cover this transaction.')
                payment += sum(s.effective_amount for s in spendables)
                tx.add_inputs(s.txi for s in spendables)

            if payment > cost:
                cost_of_change = (
                    tx.get_base_fee(ledger) +
                    cls.output_class.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(ledger)
                )
                change = payment - cost
                if change > cost_of_change:
                    change_address = yield change_account.change.get_or_create_usable_address()
                    change_hash160 = change_account.ledger.address_to_hash160(change_address)
                    change_amount = change - cost_of_change
                    tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)])

            yield tx.sign(funding_accounts)

        except Exception as e:
            log.exception('Failed to create transaction:')
            # Release the outputs this transaction would have spent (including any
            # obtained from get_spendable_utxos above). The transaction's own new
            # outputs were never handed out by the ledger, so releasing them (as
            # this used to do) leaves the spent outputs held.
            # NOTE(review): assumes get_spendable_utxos reserves what it returns.
            yield ledger.release_outputs(
                [txi.txo_ref.txo for txi in tx.inputs if txi.txo_ref.txo is not None]
            )
            raise e

        defer.returnValue(tx)

    @staticmethod
    def signature_hash_type(hash_type):
        # identity here; presumably an override point for subclasses that need
        # to alter the hash type value — confirm against subclasses
        return hash_type

    @defer.inlineCallbacks
    def sign(self, funding_accounts: Iterable[BaseAccount]) -> defer.Deferred:
        """ Sign every input in place.

        Each input must spend a resolvable pay-to-pubkey-hash output. The key
        for that output's address is fetched from the ledger, the per-input
        signing serialization is signed, and the signature (with the hash type
        byte appended) and public key are written into the input's script.

        :raises NotImplementedError: if an input spends any other output type.
        """
        ledger = self.ensure_all_have_same_ledger(funding_accounts)

        for i, txi in enumerate(self._inputs):
            assert txi.script is not None
            assert txi.txo_ref.txo is not None
            txo_script = txi.txo_ref.txo.script
            if txo_script.is_pay_pubkey_hash:
                address = ledger.hash160_to_address(txo_script.values['pubkey_hash'])
                private_key = yield ledger.get_private_key_for_address(address)
                tx = self._serialize_for_signature(i)
                txi.script.values['signature'] = \
                    private_key.sign(tx) + bytes((self.signature_hash_type(1),))
                txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
                txi.script.generate()
            else:
                raise NotImplementedError("Don't know how to spend this output.")

        self._reset()

    @defer.inlineCallbacks
    def get_my_addresses(self, ledger):
        """ Deferred list of output addresses that `ledger.db` has a record for. """
        addresses = set()
        for txo in self.outputs:
            address = ledger.hash160_to_address(txo.script.values['pubkey_hash'])
            record = yield ledger.db.get_address(address)
            if record is not None:
                addresses.add(address)
        defer.returnValue(list(addresses))