refactored how transactions are created, fixed list addresses command

This commit is contained in:
Lex Berezhny 2018-08-03 12:31:50 -04:00 committed by Jack Robison
parent 9ad9eb083b
commit fcd46629c4
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
9 changed files with 52 additions and 70 deletions

View file

@ -6,7 +6,6 @@ from binascii import hexlify
from twisted.internet import defer
from lbrynet.file_manager.EncryptedFileCreator import create_lbry_file
from lbrynet.wallet.account import get_certificate_lookup
log = logging.getLogger(__name__)
@ -58,7 +57,7 @@ class Publisher:
log.info("Removed old stream for claim update: %s", lbry_file.stream_hash)
yield self.storage.save_content_claim(
self.lbry_file.stream_hash, get_certificate_lookup(tx, 0)
self.lbry_file.stream_hash, tx.outputs[0].id
)
defer.returnValue(tx)
@ -70,7 +69,7 @@ class Publisher:
)
if stream_hash: # the stream_hash returned from the db will be None if this isn't a stream we have
yield self.storage.save_content_claim(
stream_hash.decode(), get_certificate_lookup(tx, 0)
stream_hash.decode(), tx.outputs[0].id
)
self.lbry_file = [f for f in self.lbry_file_manager.lbry_files if f.stream_hash == stream_hash][0]
defer.returnValue(tx)

View file

@ -4,11 +4,11 @@ from binascii import unhexlify
from twisted.internet import defer
from torba.baseaccount import BaseAccount
from torba.basetransaction import TXORef
from lbryschema.claim import ClaimDict
from lbryschema.signer import SECP256k1, get_signer
from .transaction import Transaction
log = logging.getLogger(__name__)
@ -18,26 +18,18 @@ def generate_certificate():
return ClaimDict.generate_certificate(secp256k1_private_key, curve=SECP256k1), secp256k1_private_key
def get_certificate_lookup(tx_or_hash, nout):
if isinstance(tx_or_hash, Transaction):
return '{}:{}'.format(tx_or_hash.id, nout)
else:
return '{}:{}'.format(tx_or_hash, nout)
class Account(BaseAccount):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.certificates = {}
def add_certificate_private_key(self, tx_or_hash, nout, private_key):
lookup_key = get_certificate_lookup(tx_or_hash, nout)
assert lookup_key not in self.certificates, 'Trying to add a duplicate certificate.'
self.certificates[lookup_key] = private_key
def add_certificate_private_key(self, ref: TXORef, private_key):
assert ref.id not in self.certificates, 'Trying to add a duplicate certificate.'
self.certificates[ref.id] = private_key
def get_certificate_private_key(self, tx_or_hash, nout):
return self.certificates.get(get_certificate_lookup(tx_or_hash, nout))
def get_certificate_private_key(self, ref: TXORef):
return self.certificates.get(ref.id)
@defer.inlineCallbacks
def maybe_migrate_certificates(self):
@ -81,7 +73,7 @@ class Account(BaseAccount):
return super().get_unspent_outputs(**constraints)
@classmethod
def from_dict(cls, ledger, d): # type: (torba.baseledger.BaseLedger, Dict) -> BaseAccount
def from_dict(cls, ledger, d: dict) -> 'Account':
account = super().from_dict(ledger, d)
account.certificates = d['certificates']
return account

View file

@ -1,5 +1,5 @@
from collections import namedtuple
class Certificate(namedtuple('Certificate', ('txhash', 'nout', 'claim_id', 'name', 'private_key'))):
class Certificate(namedtuple('Certificate', ('txid', 'nout', 'claim_id', 'name', 'private_key'))):
pass

View file

@ -2,6 +2,7 @@ from binascii import hexlify
from twisted.internet import defer
from torba.basedatabase import BaseDatabase
from torba.hash import TXRefImmutable
from torba.basetransaction import TXORef
from .certificate import Certificate
@ -61,12 +62,12 @@ class WalletDatabase(BaseDatabase):
certificates = []
# Lookup private keys for each certificate.
if private_key_accounts is not None:
for txhash, nout, claim_id in txos:
for txid, nout, claim_id in txos:
for account in private_key_accounts:
private_key = account.get_certificate_private_key(
txhash, nout
TXORef(TXRefImmutable.from_id(txid), nout)
)
certificates.append(Certificate(txhash, nout, claim_id, name, private_key))
certificates.append(Certificate(txid, nout, claim_id, name, private_key))
if exclude_without_key:
defer.returnValue([

View file

@ -128,20 +128,6 @@ class MainNetLedger(BaseLedger):
super().__init__(*args, **kwargs)
self.fee_per_name_char = self.config.get('fee_per_name_char', self.default_fee_per_name_char)
def get_transaction_base_fee(self, tx):
""" Fee for the transaction header and all outputs; without inputs. """
return max(
super().get_transaction_base_fee(tx),
self.get_transaction_claim_name_fee(tx)
)
def get_transaction_claim_name_fee(self, tx):
fee = 0
for output in tx.outputs:
if output.script.is_claim_name:
fee += len(output.script.values['claim_name']) * self.fee_per_name_char
return fee
@defer.inlineCallbacks
def resolve(self, page, page_size, *uris):
for uri in uris:

View file

@ -122,6 +122,9 @@ class LbryWalletManager(BaseWalletManager):
def get_new_address(self):
return self.get_unused_address()
def list_addresses(self):
return self.default_account.get_addresses()
def reserve_points(self, address, amount):
# TODO: check if we have enough to cover amount
return ReservedPoints(address, amount)
@ -210,7 +213,7 @@ class LbryWalletManager(BaseWalletManager):
cert, key = generate_certificate()
tx = yield Transaction.claim(channel_name.encode(), cert, amount, address, [account], account)
yield account.ledger.broadcast(tx)
account.add_certificate_private_key(tx, 0, key.decode())
account.add_certificate_private_key(tx.outputs[0].ref, key.decode())
# TODO: release reserved tx outputs in case anything fails by this point
defer.returnValue(tx)

View file

@ -1,9 +1,9 @@
import struct
from typing import List # pylint: disable=unused-import
from typing import List, Iterable # pylint: disable=unused-import
from twisted.internet import defer # pylint: disable=unused-import
from torba.baseaccount import BaseAccount # pylint: disable=unused-import
from .account import Account # pylint: disable=unused-import
from torba.basetransaction import BaseTransaction, BaseInput, BaseOutput
from torba.hash import hash160
@ -22,6 +22,12 @@ class Input(BaseInput):
class Output(BaseOutput):
script_class = OutputScript
def get_fee(self, ledger):
name_fee = 0
if self.script.is_claim_name:
name_fee = len(self.script.values['claim_name']) * ledger.fee_per_name_char
return max(name_fee, super().get_fee(ledger))
@classmethod
def pay_claim_name_pubkey_hash(cls, amount, claim_name, claim, pubkey_hash):
script = cls.script_class.pay_claim_name_pubkey_hash(claim_name, claim, pubkey_hash)
@ -40,14 +46,13 @@ class Transaction(BaseTransaction):
@classmethod
def claim(cls, name, meta, amount, holding_address, funding_accounts, change_account):
# type: (bytes, ClaimDict, int, bytes, List[BaseAccount], BaseAccount) -> defer.Deferred
# type: (bytes, ClaimDict, int, bytes, List[Account], Account) -> defer.Deferred
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
claim_output = Output.pay_claim_name_pubkey_hash(
amount, name, meta.serialized, ledger.address_to_hash160(holding_address)
)
return cls.pay([claim_output], funding_accounts, change_account)
return cls.create([], [claim_output], funding_accounts, change_account)
@classmethod
def abandon(cls, utxo, funding_accounts, change_account):
# type: (Output, List[BaseAccount], BaseAccount) -> defer.Deferred
return cls.liquidate(utxo, funding_accounts, change_account)
def abandon(cls, claims: Iterable[Output], funding_accounts: Iterable[Account], change_account: Account):
return cls.create([Input.spend(txo) for txo in claims], [], funding_accounts, change_account)

View file

@ -42,43 +42,39 @@ class TestSizeAndFeeEstimation(unittest.TestCase):
def setUp(self):
self.ledger = MainNetLedger({'db': WalletDatabase(':memory:')})
return self.ledger.db.start()
def io_fee(self, io):
return self.ledger.get_input_output_fee(io)
def test_output_size_and_fee(self):
txo = get_output()
self.assertEqual(txo.size, 46)
self.assertEqual(self.io_fee(txo), 46 * FEE_PER_BYTE)
self.assertEqual(txo.get_fee(self.ledger), 46 * FEE_PER_BYTE)
claim_name = b'verylongname'
tx = get_claim_transaction(claim_name, b'0'*4000)
base_size = tx.size - tx.inputs[0].size - tx.outputs[0].size
txo = tx.outputs[0]
self.assertEqual(tx.size, 4225)
self.assertEqual(tx.base_size, base_size)
self.assertEqual(txo.size, 4067)
self.assertEqual(txo.get_fee(self.ledger), len(claim_name) * FEE_PER_CHAR)
# fee based on total bytes is the larger fee
claim_name = b'a'
tx = get_claim_transaction(claim_name, b'0'*4000)
base_size = tx.size - tx.inputs[0].size - tx.outputs[0].size
txo = tx.outputs[0]
self.assertEqual(tx.size, 4214)
self.assertEqual(tx.base_size, base_size)
self.assertEqual(txo.size, 4056)
self.assertEqual(txo.get_fee(self.ledger), txo.size * FEE_PER_BYTE)
def test_input_size_and_fee(self):
txi = get_input()
self.assertEqual(txi.size, 148)
self.assertEqual(self.io_fee(txi), 148 * FEE_PER_BYTE)
self.assertEqual(txi.get_fee(self.ledger), 148 * FEE_PER_BYTE)
def test_transaction_size_and_fee(self):
tx = get_transaction()
base_size = tx.size - 1 - tx.inputs[0].size
self.assertEqual(tx.size, 204)
self.assertEqual(tx.base_size, base_size)
self.assertEqual(self.ledger.get_transaction_base_fee(tx), FEE_PER_BYTE * base_size)
def test_claim_name_transaction_size_and_fee(self):
# fee based on claim name is the larger fee
claim_name = b'verylongname'
tx = get_claim_transaction(claim_name, b'0'*4000)
base_size = tx.size - 1 - tx.inputs[0].size
self.assertEqual(tx.size, 4225)
self.assertEqual(tx.base_size, base_size)
self.assertEqual(self.ledger.get_transaction_base_fee(tx), len(claim_name) * FEE_PER_CHAR)
# fee based on total bytes is the larger fee
claim_name = b'a'
tx = get_claim_transaction(claim_name, b'0'*4000)
base_size = tx.size - 1 - tx.inputs[0].size
self.assertEqual(tx.size, 4214)
self.assertEqual(tx.base_size, base_size)
self.assertEqual(self.ledger.get_transaction_base_fee(tx), FEE_PER_BYTE * base_size)
self.assertEqual(tx.base_size, tx.size - tx.inputs[0].size - tx.outputs[0].size)
self.assertEqual(tx.get_base_fee(self.ledger), FEE_PER_BYTE * tx.base_size)
class TestTransactionSerialization(unittest.TestCase):

View file

@ -17,4 +17,4 @@ setenv =
commands =
orchstr8 download
coverage run -p --source={envsitepackagesdir}/lbrynet -m twisted.trial --reactor=asyncio integration.wallet.test_transactions.BasicTransactionTest
coverage run -p --source={envsitepackagesdir}/lbrynet -m twisted.trial --reactor=asyncio integration.wallet.test_commands.CommonWorkflowTests
coverage run -p --source={envsitepackagesdir}/lbrynet -m twisted.trial --reactor=asyncio integration.wallet.test_commands.EpicAdventuresOfChris45