refactored how transactions are created, fixed list addresses command

This commit is contained in:
Lex Berezhny 2018-08-03 12:31:50 -04:00 committed by Jack Robison
parent 9ad9eb083b
commit fcd46629c4
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
9 changed files with 52 additions and 70 deletions

View file

@ -6,7 +6,6 @@ from binascii import hexlify
from twisted.internet import defer from twisted.internet import defer
from lbrynet.file_manager.EncryptedFileCreator import create_lbry_file from lbrynet.file_manager.EncryptedFileCreator import create_lbry_file
from lbrynet.wallet.account import get_certificate_lookup
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -58,7 +57,7 @@ class Publisher:
log.info("Removed old stream for claim update: %s", lbry_file.stream_hash) log.info("Removed old stream for claim update: %s", lbry_file.stream_hash)
yield self.storage.save_content_claim( yield self.storage.save_content_claim(
self.lbry_file.stream_hash, get_certificate_lookup(tx, 0) self.lbry_file.stream_hash, tx.outputs[0].id
) )
defer.returnValue(tx) defer.returnValue(tx)
@ -70,7 +69,7 @@ class Publisher:
) )
if stream_hash: # the stream_hash returned from the db will be None if this isn't a stream we have if stream_hash: # the stream_hash returned from the db will be None if this isn't a stream we have
yield self.storage.save_content_claim( yield self.storage.save_content_claim(
stream_hash.decode(), get_certificate_lookup(tx, 0) stream_hash.decode(), tx.outputs[0].id
) )
self.lbry_file = [f for f in self.lbry_file_manager.lbry_files if f.stream_hash == stream_hash][0] self.lbry_file = [f for f in self.lbry_file_manager.lbry_files if f.stream_hash == stream_hash][0]
defer.returnValue(tx) defer.returnValue(tx)

View file

@ -4,11 +4,11 @@ from binascii import unhexlify
from twisted.internet import defer from twisted.internet import defer
from torba.baseaccount import BaseAccount from torba.baseaccount import BaseAccount
from torba.basetransaction import TXORef
from lbryschema.claim import ClaimDict from lbryschema.claim import ClaimDict
from lbryschema.signer import SECP256k1, get_signer from lbryschema.signer import SECP256k1, get_signer
from .transaction import Transaction
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -18,26 +18,18 @@ def generate_certificate():
return ClaimDict.generate_certificate(secp256k1_private_key, curve=SECP256k1), secp256k1_private_key return ClaimDict.generate_certificate(secp256k1_private_key, curve=SECP256k1), secp256k1_private_key
def get_certificate_lookup(tx_or_hash, nout):
if isinstance(tx_or_hash, Transaction):
return '{}:{}'.format(tx_or_hash.id, nout)
else:
return '{}:{}'.format(tx_or_hash, nout)
class Account(BaseAccount): class Account(BaseAccount):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.certificates = {} self.certificates = {}
def add_certificate_private_key(self, tx_or_hash, nout, private_key): def add_certificate_private_key(self, ref: TXORef, private_key):
lookup_key = get_certificate_lookup(tx_or_hash, nout) assert ref.id not in self.certificates, 'Trying to add a duplicate certificate.'
assert lookup_key not in self.certificates, 'Trying to add a duplicate certificate.' self.certificates[ref.id] = private_key
self.certificates[lookup_key] = private_key
def get_certificate_private_key(self, tx_or_hash, nout): def get_certificate_private_key(self, ref: TXORef):
return self.certificates.get(get_certificate_lookup(tx_or_hash, nout)) return self.certificates.get(ref.id)
@defer.inlineCallbacks @defer.inlineCallbacks
def maybe_migrate_certificates(self): def maybe_migrate_certificates(self):
@ -81,7 +73,7 @@ class Account(BaseAccount):
return super().get_unspent_outputs(**constraints) return super().get_unspent_outputs(**constraints)
@classmethod @classmethod
def from_dict(cls, ledger, d): # type: (torba.baseledger.BaseLedger, Dict) -> BaseAccount def from_dict(cls, ledger, d: dict) -> 'Account':
account = super().from_dict(ledger, d) account = super().from_dict(ledger, d)
account.certificates = d['certificates'] account.certificates = d['certificates']
return account return account

View file

@ -1,5 +1,5 @@
from collections import namedtuple from collections import namedtuple
class Certificate(namedtuple('Certificate', ('txhash', 'nout', 'claim_id', 'name', 'private_key'))): class Certificate(namedtuple('Certificate', ('txid', 'nout', 'claim_id', 'name', 'private_key'))):
pass pass

View file

@ -2,6 +2,7 @@ from binascii import hexlify
from twisted.internet import defer from twisted.internet import defer
from torba.basedatabase import BaseDatabase from torba.basedatabase import BaseDatabase
from torba.hash import TXRefImmutable from torba.hash import TXRefImmutable
from torba.basetransaction import TXORef
from .certificate import Certificate from .certificate import Certificate
@ -61,12 +62,12 @@ class WalletDatabase(BaseDatabase):
certificates = [] certificates = []
# Lookup private keys for each certificate. # Lookup private keys for each certificate.
if private_key_accounts is not None: if private_key_accounts is not None:
for txhash, nout, claim_id in txos: for txid, nout, claim_id in txos:
for account in private_key_accounts: for account in private_key_accounts:
private_key = account.get_certificate_private_key( private_key = account.get_certificate_private_key(
txhash, nout TXORef(TXRefImmutable.from_id(txid), nout)
) )
certificates.append(Certificate(txhash, nout, claim_id, name, private_key)) certificates.append(Certificate(txid, nout, claim_id, name, private_key))
if exclude_without_key: if exclude_without_key:
defer.returnValue([ defer.returnValue([

View file

@ -128,20 +128,6 @@ class MainNetLedger(BaseLedger):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char) self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char)
def get_transaction_base_fee(self, tx):
""" Fee for the transaction header and all outputs; without inputs. """
return max(
super().get_transaction_base_fee(tx),
self.get_transaction_claim_name_fee(tx)
)
def get_transaction_claim_name_fee(self, tx):
fee = 0
for output in tx.outputs:
if output.script.is_claim_name:
fee += len(output.script.values['claim_name']) * self.fee_per_name_char
return fee
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve(self, page, page_size, *uris): def resolve(self, page, page_size, *uris):
for uri in uris: for uri in uris:

View file

@ -122,6 +122,9 @@ class LbryWalletManager(BaseWalletManager):
def get_new_address(self): def get_new_address(self):
return self.get_unused_address() return self.get_unused_address()
def list_addresses(self):
return self.default_account.get_addresses()
def reserve_points(self, address, amount): def reserve_points(self, address, amount):
# TODO: check if we have enough to cover amount # TODO: check if we have enough to cover amount
return ReservedPoints(address, amount) return ReservedPoints(address, amount)
@ -210,7 +213,7 @@ class LbryWalletManager(BaseWalletManager):
cert, key = generate_certificate() cert, key = generate_certificate()
tx = yield Transaction.claim(channel_name.encode(), cert, amount, address, [account], account) tx = yield Transaction.claim(channel_name.encode(), cert, amount, address, [account], account)
yield account.ledger.broadcast(tx) yield account.ledger.broadcast(tx)
account.add_certificate_private_key(tx, 0, key.decode()) account.add_certificate_private_key(tx.outputs[0].ref, key.decode())
# TODO: release reserved tx outputs in case anything fails by this point # TODO: release reserved tx outputs in case anything fails by this point
defer.returnValue(tx) defer.returnValue(tx)

View file

@ -1,9 +1,9 @@
import struct import struct
from typing import List # pylint: disable=unused-import from typing import List, Iterable # pylint: disable=unused-import
from twisted.internet import defer # pylint: disable=unused-import from twisted.internet import defer # pylint: disable=unused-import
from torba.baseaccount import BaseAccount # pylint: disable=unused-import from .account import Account # pylint: disable=unused-import
from torba.basetransaction import BaseTransaction, BaseInput, BaseOutput from torba.basetransaction import BaseTransaction, BaseInput, BaseOutput
from torba.hash import hash160 from torba.hash import hash160
@ -22,6 +22,12 @@ class Input(BaseInput):
class Output(BaseOutput): class Output(BaseOutput):
script_class = OutputScript script_class = OutputScript
def get_fee(self, ledger):
name_fee = 0
if self.script.is_claim_name:
name_fee = len(self.script.values['claim_name']) * ledger.fee_per_name_char
return max(name_fee, super().get_fee(ledger))
@classmethod @classmethod
def pay_claim_name_pubkey_hash(cls, amount, claim_name, claim, pubkey_hash): def pay_claim_name_pubkey_hash(cls, amount, claim_name, claim, pubkey_hash):
script = cls.script_class.pay_claim_name_pubkey_hash(claim_name, claim, pubkey_hash) script = cls.script_class.pay_claim_name_pubkey_hash(claim_name, claim, pubkey_hash)
@ -40,14 +46,13 @@ class Transaction(BaseTransaction):
@classmethod @classmethod
def claim(cls, name, meta, amount, holding_address, funding_accounts, change_account): def claim(cls, name, meta, amount, holding_address, funding_accounts, change_account):
# type: (bytes, ClaimDict, int, bytes, List[BaseAccount], BaseAccount) -> defer.Deferred # type: (bytes, ClaimDict, int, bytes, List[Account], Account) -> defer.Deferred
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
claim_output = Output.pay_claim_name_pubkey_hash( claim_output = Output.pay_claim_name_pubkey_hash(
amount, name, meta.serialized, ledger.address_to_hash160(holding_address) amount, name, meta.serialized, ledger.address_to_hash160(holding_address)
) )
return cls.pay([claim_output], funding_accounts, change_account) return cls.create([], [claim_output], funding_accounts, change_account)
@classmethod @classmethod
def abandon(cls, utxo, funding_accounts, change_account): def abandon(cls, claims: Iterable[Output], funding_accounts: Iterable[Account], change_account: Account):
# type: (Output, List[BaseAccount], BaseAccount) -> defer.Deferred return cls.create([Input.spend(txo) for txo in claims], [], funding_accounts, change_account)
return cls.liquidate(utxo, funding_accounts, change_account)

View file

@ -42,43 +42,39 @@ class TestSizeAndFeeEstimation(unittest.TestCase):
def setUp(self): def setUp(self):
self.ledger = MainNetLedger({'db': WalletDatabase(':memory:')}) self.ledger = MainNetLedger({'db': WalletDatabase(':memory:')})
return self.ledger.db.start()
def io_fee(self, io):
return self.ledger.get_input_output_fee(io)
def test_output_size_and_fee(self): def test_output_size_and_fee(self):
txo = get_output() txo = get_output()
self.assertEqual(txo.size, 46) self.assertEqual(txo.size, 46)
self.assertEqual(self.io_fee(txo), 46 * FEE_PER_BYTE) self.assertEqual(txo.get_fee(self.ledger), 46 * FEE_PER_BYTE)
claim_name = b'verylongname'
tx = get_claim_transaction(claim_name, b'0'*4000)
base_size = tx.size - tx.inputs[0].size - tx.outputs[0].size
txo = tx.outputs[0]
self.assertEqual(tx.size, 4225)
self.assertEqual(tx.base_size, base_size)
self.assertEqual(txo.size, 4067)
self.assertEqual(txo.get_fee(self.ledger), len(claim_name) * FEE_PER_CHAR)
# fee based on total bytes is the larger fee
claim_name = b'a'
tx = get_claim_transaction(claim_name, b'0'*4000)
base_size = tx.size - tx.inputs[0].size - tx.outputs[0].size
txo = tx.outputs[0]
self.assertEqual(tx.size, 4214)
self.assertEqual(tx.base_size, base_size)
self.assertEqual(txo.size, 4056)
self.assertEqual(txo.get_fee(self.ledger), txo.size * FEE_PER_BYTE)
def test_input_size_and_fee(self): def test_input_size_and_fee(self):
txi = get_input() txi = get_input()
self.assertEqual(txi.size, 148) self.assertEqual(txi.size, 148)
self.assertEqual(self.io_fee(txi), 148 * FEE_PER_BYTE) self.assertEqual(txi.get_fee(self.ledger), 148 * FEE_PER_BYTE)
def test_transaction_size_and_fee(self): def test_transaction_size_and_fee(self):
tx = get_transaction() tx = get_transaction()
base_size = tx.size - 1 - tx.inputs[0].size
self.assertEqual(tx.size, 204) self.assertEqual(tx.size, 204)
self.assertEqual(tx.base_size, base_size) self.assertEqual(tx.base_size, tx.size - tx.inputs[0].size - tx.outputs[0].size)
self.assertEqual(self.ledger.get_transaction_base_fee(tx), FEE_PER_BYTE * base_size) self.assertEqual(tx.get_base_fee(self.ledger), FEE_PER_BYTE * tx.base_size)
def test_claim_name_transaction_size_and_fee(self):
# fee based on claim name is the larger fee
claim_name = b'verylongname'
tx = get_claim_transaction(claim_name, b'0'*4000)
base_size = tx.size - 1 - tx.inputs[0].size
self.assertEqual(tx.size, 4225)
self.assertEqual(tx.base_size, base_size)
self.assertEqual(self.ledger.get_transaction_base_fee(tx), len(claim_name) * FEE_PER_CHAR)
# fee based on total bytes is the larger fee
claim_name = b'a'
tx = get_claim_transaction(claim_name, b'0'*4000)
base_size = tx.size - 1 - tx.inputs[0].size
self.assertEqual(tx.size, 4214)
self.assertEqual(tx.base_size, base_size)
self.assertEqual(self.ledger.get_transaction_base_fee(tx), FEE_PER_BYTE * base_size)
class TestTransactionSerialization(unittest.TestCase): class TestTransactionSerialization(unittest.TestCase):

View file

@ -17,4 +17,4 @@ setenv =
commands = commands =
orchstr8 download orchstr8 download
coverage run -p --source={envsitepackagesdir}/lbrynet -m twisted.trial --reactor=asyncio integration.wallet.test_transactions.BasicTransactionTest coverage run -p --source={envsitepackagesdir}/lbrynet -m twisted.trial --reactor=asyncio integration.wallet.test_transactions.BasicTransactionTest
coverage run -p --source={envsitepackagesdir}/lbrynet -m twisted.trial --reactor=asyncio integration.wallet.test_commands.CommonWorkflowTests coverage run -p --source={envsitepackagesdir}/lbrynet -m twisted.trial --reactor=asyncio integration.wallet.test_commands.EpicAdventuresOfChris45