updated wallet to use asyncio

This commit is contained in:
Lex Berezhny 2018-10-15 17:16:43 -04:00
parent 330db61b36
commit 64e306801d
13 changed files with 198 additions and 260 deletions

View file

@ -1,8 +1,6 @@
import json import json
import logging import logging
from twisted.internet import defer
from torba.baseaccount import BaseAccount from torba.baseaccount import BaseAccount
from torba.basetransaction import TXORef from torba.basetransaction import TXORef
@ -31,8 +29,7 @@ class Account(BaseAccount):
def get_certificate_private_key(self, ref: TXORef): def get_certificate_private_key(self, ref: TXORef):
return self.certificates.get(ref.id) return self.certificates.get(ref.id)
@defer.inlineCallbacks async def maybe_migrate_certificates(self):
def maybe_migrate_certificates(self):
if not self.certificates: if not self.certificates:
return return
@ -49,7 +46,7 @@ class Account(BaseAccount):
for maybe_claim_id in list(self.certificates): for maybe_claim_id in list(self.certificates):
results['total'] += 1 results['total'] += 1
if ':' not in maybe_claim_id: if ':' not in maybe_claim_id:
claims = yield self.ledger.network.get_claims_by_ids(maybe_claim_id) claims = await self.ledger.network.get_claims_by_ids(maybe_claim_id)
if maybe_claim_id not in claims: if maybe_claim_id not in claims:
log.warning( log.warning(
"Failed to migrate claim '%s', server did not return any claim information.", "Failed to migrate claim '%s', server did not return any claim information.",
@ -60,7 +57,7 @@ class Account(BaseAccount):
claim = claims[maybe_claim_id] claim = claims[maybe_claim_id]
tx = None tx = None
if claim: if claim:
tx = yield self.ledger.db.get_transaction(txid=claim['txid']) tx = await self.ledger.db.get_transaction(txid=claim['txid'])
else: else:
log.warning(maybe_claim_id) log.warning(maybe_claim_id)
if tx is not None: if tx is not None:
@ -96,7 +93,7 @@ class Account(BaseAccount):
else: else:
try: try:
txid, nout = maybe_claim_id.split(':') txid, nout = maybe_claim_id.split(':')
tx = yield self.ledger.db.get_transaction(txid=txid) tx = await self.ledger.db.get_transaction(txid=txid)
if tx.outputs[int(nout)].script.is_claim_involved: if tx.outputs[int(nout)].script.is_claim_involved:
results['previous-success'] += 1 results['previous-success'] += 1
else: else:
@ -115,9 +112,8 @@ class Account(BaseAccount):
indent=2 indent=2
)) ))
@defer.inlineCallbacks async def save_max_gap(self):
def save_max_gap(self): gap = await self.get_max_gap()
gap = yield self.get_max_gap()
self.receiving.gap = max(20, gap['max_receiving_gap'] + 1) self.receiving.gap = max(20, gap['max_receiving_gap'] + 1)
self.change.gap = max(6, gap['max_change_gap'] + 1) self.change.gap = max(6, gap['max_change_gap'] + 1)
self.wallet.save() self.wallet.save()
@ -144,9 +140,8 @@ class Account(BaseAccount):
d['certificates'] = self.certificates d['certificates'] = self.certificates
return d return d
@defer.inlineCallbacks async def get_details(self, **kwargs):
def get_details(self, **kwargs): details = await super().get_details(**kwargs)
details = yield super().get_details(**kwargs)
details['certificates'] = len(self.certificates) details['certificates'] = len(self.certificates)
return details return details

View file

@ -1,4 +1,3 @@
from twisted.internet import defer
from torba.basedatabase import BaseDatabase from torba.basedatabase import BaseDatabase
@ -49,11 +48,10 @@ class WalletDatabase(BaseDatabase):
row['claim_name'] = txo.claim_name row['claim_name'] = txo.claim_name
return row return row
@defer.inlineCallbacks async def get_txos(self, **constraints):
def get_txos(self, **constraints):
my_account = constraints.get('my_account', constraints.get('account', None)) my_account = constraints.get('my_account', constraints.get('account', None))
txos = yield super().get_txos(**constraints) txos = await super().get_txos(**constraints)
channel_ids = set() channel_ids = set()
for txo in txos: for txo in txos:
@ -66,7 +64,7 @@ class WalletDatabase(BaseDatabase):
if channel_ids: if channel_ids:
channels = { channels = {
txo.claim_id: txo for txo in txo.claim_id: txo for txo in
(yield super().get_utxos( (await super().get_utxos(
my_account=my_account, my_account=my_account,
claim_id__in=channel_ids claim_id__in=channel_ids
)) ))
@ -103,9 +101,8 @@ class WalletDatabase(BaseDatabase):
self.constrain_channels(constraints) self.constrain_channels(constraints)
return self.get_claim_count(**constraints) return self.get_claim_count(**constraints)
@defer.inlineCallbacks async def get_certificates(self, private_key_accounts, exclude_without_key=False, **constraints):
def get_certificates(self, private_key_accounts, exclude_without_key=False, **constraints): channels = await self.get_channels(**constraints)
channels = yield self.get_channels(**constraints)
certificates = [] certificates = []
if private_key_accounts is not None: if private_key_accounts is not None:
for channel in channels: for channel in channels:

View file

@ -1,15 +1,12 @@
import asyncio
import logging import logging
from six import int2byte
from binascii import unhexlify from binascii import unhexlify
from twisted.internet import defer
from .resolve import Resolver
from lbryschema.error import URIParseError from lbryschema.error import URIParseError
from lbryschema.uri import parse_lbry_uri from lbryschema.uri import parse_lbry_uri
from torba.baseledger import BaseLedger from torba.baseledger import BaseLedger
from .resolve import Resolver
from .account import Account from .account import Account
from .network import Network from .network import Network
from .database import WalletDatabase from .database import WalletDatabase
@ -25,15 +22,17 @@ class MainNetLedger(BaseLedger):
symbol = 'LBC' symbol = 'LBC'
network_name = 'mainnet' network_name = 'mainnet'
headers: Headers
account_class = Account account_class = Account
database_class = WalletDatabase database_class = WalletDatabase
headers_class = Headers headers_class = Headers
network_class = Network network_class = Network
transaction_class = Transaction transaction_class = Transaction
secret_prefix = int2byte(0x1c) secret_prefix = bytes((0x1c,))
pubkey_address_prefix = int2byte(0x55) pubkey_address_prefix = bytes((0x55,))
script_address_prefix = int2byte(0x7a) script_address_prefix = bytes((0x7a,))
extended_public_key_prefix = unhexlify('0488b21e') extended_public_key_prefix = unhexlify('0488b21e')
extended_private_key_prefix = unhexlify('0488ade4') extended_private_key_prefix = unhexlify('0488ade4')
@ -54,45 +53,38 @@ class MainNetLedger(BaseLedger):
return Resolver(self.headers.claim_trie_root, self.headers.height, self.transaction_class, return Resolver(self.headers.claim_trie_root, self.headers.height, self.transaction_class,
hash160_to_address=self.hash160_to_address, network=self.network) hash160_to_address=self.hash160_to_address, network=self.network)
@defer.inlineCallbacks async def resolve(self, page, page_size, *uris):
def resolve(self, page, page_size, *uris):
for uri in uris: for uri in uris:
try: try:
parse_lbry_uri(uri) parse_lbry_uri(uri)
except URIParseError as err: except URIParseError as err:
defer.returnValue({'error': err.message}) return {'error': err.args[0]}
resolutions = yield self.network.get_values_for_uris(self.headers.hash().decode(), *uris) resolutions = await self.network.get_values_for_uris(self.headers.hash().decode(), *uris)
return (yield self.resolver._handle_resolutions(resolutions, uris, page, page_size)) return await self.resolver._handle_resolutions(resolutions, uris, page, page_size)
@defer.inlineCallbacks async def get_claim_by_claim_id(self, claim_id):
def get_claim_by_claim_id(self, claim_id): result = (await self.network.get_claims_by_ids(claim_id)).pop(claim_id, {})
result = (yield self.network.get_claims_by_ids(claim_id)).pop(claim_id, {}) return await self.resolver.get_certificate_and_validate_result(result)
return (yield self.resolver.get_certificate_and_validate_result(result))
@defer.inlineCallbacks async def get_claim_by_outpoint(self, txid, nout):
def get_claim_by_outpoint(self, txid, nout): claims = (await self.network.get_claims_in_tx(txid)) or []
claims = (yield self.network.get_claims_in_tx(txid)) or []
for claim in claims: for claim in claims:
if claim['nout'] == nout: if claim['nout'] == nout:
return (yield self.resolver.get_certificate_and_validate_result(claim)) return await self.resolver.get_certificate_and_validate_result(claim)
return 'claim not found' return 'claim not found'
@defer.inlineCallbacks async def start(self):
def start(self): await super().start()
yield super().start() await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts))
yield defer.DeferredList([ await asyncio.gather(*(a.save_max_gap() for a in self.accounts))
a.maybe_migrate_certificates() for a in self.accounts await self._report_state()
])
yield defer.DeferredList([a.save_max_gap() for a in self.accounts])
yield self._report_state()
@defer.inlineCallbacks async def _report_state(self):
def _report_state(self):
for account in self.accounts: for account in self.accounts:
total_receiving = len((yield account.receiving.get_addresses())) total_receiving = len((await account.receiving.get_addresses()))
total_change = len((yield account.change.get_addresses())) total_change = len((await account.change.get_addresses()))
channel_count = yield account.get_channel_count() channel_count = await account.get_channel_count()
claim_count = yield account.get_claim_count() claim_count = await account.get_claim_count()
log.info("Loaded account %s with %d receiving addresses (gap: %d), " log.info("Loaded account %s with %d receiving addresses (gap: %d), "
"%d change addresses (gap: %d), %d channels, %d certificates and %d claims.", "%d change addresses (gap: %d), %d channels, %d certificates and %d claims.",
account.id, total_receiving, account.receiving.gap, total_change, account.change.gap, account.id, total_receiving, account.receiving.gap, total_change, account.change.gap,
@ -101,8 +93,8 @@ class MainNetLedger(BaseLedger):
class TestNetLedger(MainNetLedger): class TestNetLedger(MainNetLedger):
network_name = 'testnet' network_name = 'testnet'
pubkey_address_prefix = int2byte(111) pubkey_address_prefix = bytes((111,))
script_address_prefix = int2byte(196) script_address_prefix = bytes((196,))
extended_public_key_prefix = unhexlify('043587cf') extended_public_key_prefix = unhexlify('043587cf')
extended_private_key_prefix = unhexlify('04358394') extended_private_key_prefix = unhexlify('04358394')
@ -110,8 +102,8 @@ class TestNetLedger(MainNetLedger):
class RegTestLedger(MainNetLedger): class RegTestLedger(MainNetLedger):
network_name = 'regtest' network_name = 'regtest'
headers_class = UnvalidatedHeaders headers_class = UnvalidatedHeaders
pubkey_address_prefix = int2byte(111) pubkey_address_prefix = bytes((111,))
script_address_prefix = int2byte(196) script_address_prefix = bytes((196,))
extended_public_key_prefix = unhexlify('043587cf') extended_public_key_prefix = unhexlify('043587cf')
extended_private_key_prefix = unhexlify('04358394') extended_private_key_prefix = unhexlify('04358394')

View file

@ -6,8 +6,6 @@ from binascii import unhexlify
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from twisted.internet import defer
from lbryschema.schema import SECP256k1 from lbryschema.schema import SECP256k1
from torba.basemanager import BaseWalletManager from torba.basemanager import BaseWalletManager
@ -73,7 +71,7 @@ class LbryWalletManager(BaseWalletManager):
return not self.default_account.encrypted return not self.default_account.encrypted
def check_locked(self): def check_locked(self):
return defer.succeed(self.default_account.encrypted) return self.default_account.encrypted
def decrypt_account(self, account): def decrypt_account(self, account):
assert account.password is not None, "account is not unlocked" assert account.password is not None, "account is not unlocked"
@ -157,8 +155,7 @@ class LbryWalletManager(BaseWalletManager):
return receiving_addresses, change_addresses return receiving_addresses, change_addresses
@classmethod @classmethod
@defer.inlineCallbacks async def from_lbrynet_config(cls, settings, db):
def from_lbrynet_config(cls, settings, db):
ledger_id = { ledger_id = {
'lbrycrd_main': 'lbc_mainnet', 'lbrycrd_main': 'lbc_mainnet',
@ -194,17 +191,16 @@ class LbryWalletManager(BaseWalletManager):
if receiving_addresses or change_addresses: if receiving_addresses or change_addresses:
if not os.path.exists(ledger.path): if not os.path.exists(ledger.path):
os.mkdir(ledger.path) os.mkdir(ledger.path)
yield ledger.db.open() await ledger.db.open()
try: try:
yield manager._migrate_addresses(receiving_addresses, change_addresses) await manager._migrate_addresses(receiving_addresses, change_addresses)
finally: finally:
yield ledger.db.close() await ledger.db.close()
defer.returnValue(manager) return manager
@defer.inlineCallbacks async def _migrate_addresses(self, receiving_addresses: set, change_addresses: set):
def _migrate_addresses(self, receiving_addresses: set, change_addresses: set): migrated_receiving = set((await self.default_account.receiving.generate_keys(0, len(receiving_addresses))))
migrated_receiving = set((yield self.default_account.receiving.generate_keys(0, len(receiving_addresses)))) migrated_change = set((await self.default_account.change.generate_keys(0, len(change_addresses))))
migrated_change = set((yield self.default_account.change.generate_keys(0, len(change_addresses))))
receiving_addresses = set(map(self.default_account.ledger.public_key_to_address, receiving_addresses)) receiving_addresses = set(map(self.default_account.ledger.public_key_to_address, receiving_addresses))
change_addresses = set(map(self.default_account.ledger.public_key_to_address, change_addresses)) change_addresses = set(map(self.default_account.ledger.public_key_to_address, change_addresses))
if not any(change_addresses.difference(migrated_change)): if not any(change_addresses.difference(migrated_change)):
@ -231,25 +227,23 @@ class LbryWalletManager(BaseWalletManager):
# TODO: check if we have enough to cover amount # TODO: check if we have enough to cover amount
return ReservedPoints(address, amount) return ReservedPoints(address, amount)
@defer.inlineCallbacks async def send_amount_to_address(self, amount: int, destination_address: bytes, account=None):
def send_amount_to_address(self, amount: int, destination_address: bytes, account=None):
account = account or self.default_account account = account or self.default_account
tx = yield Transaction.pay(amount, destination_address, [account], account) tx = await Transaction.pay(amount, destination_address, [account], account)
yield account.ledger.broadcast(tx) await account.ledger.broadcast(tx)
return tx return tx
@defer.inlineCallbacks async def send_claim_to_address(self, claim_id: str, destination_address: str, amount: Optional[int],
def send_claim_to_address(self, claim_id: str, destination_address: str, amount: Optional[int],
account=None): account=None):
account = account or self.default_account account = account or self.default_account
claims = account.ledger.db.get_utxos(claim_id=claim_id) claims = await account.ledger.db.get_utxos(claim_id=claim_id)
if not claims: if not claims:
raise NameError("Claim not found: {}".format(claim_id)) raise NameError("Claim not found: {}".format(claim_id))
tx = yield Transaction.update( tx = await Transaction.update(
claims[0], ClaimDict.deserialize(claims[0].script.value['claim']), amount, claims[0], ClaimDict.deserialize(claims[0].script.value['claim']), amount,
destination_address.encode(), [account], account destination_address.encode(), [account], account
) )
yield self.ledger.broadcast(tx) await self.ledger.broadcast(tx)
return tx return tx
def send_points_to_address(self, reserved: ReservedPoints, amount: int, account=None): def send_points_to_address(self, reserved: ReservedPoints, amount: int, account=None):
@ -262,23 +256,21 @@ class LbryWalletManager(BaseWalletManager):
def get_info_exchanger(self): def get_info_exchanger(self):
return LBRYcrdAddressRequester(self) return LBRYcrdAddressRequester(self)
@defer.inlineCallbacks async def resolve(self, *uris, **kwargs):
def resolve(self, *uris, **kwargs):
page = kwargs.get('page', 0) page = kwargs.get('page', 0)
page_size = kwargs.get('page_size', 10) page_size = kwargs.get('page_size', 10)
check_cache = kwargs.get('check_cache', False) # TODO: put caching back (was force_refresh parameter) check_cache = kwargs.get('check_cache', False) # TODO: put caching back (was force_refresh parameter)
ledger = self.default_account.ledger # type: MainNetLedger ledger: MainNetLedger = self.default_account.ledger
results = yield ledger.resolve(page, page_size, *uris) results = await ledger.resolve(page, page_size, *uris)
yield self.old_db.save_claims_for_resolve( await self.old_db.save_claims_for_resolve(
(value for value in results.values() if 'error' not in value)) (value for value in results.values() if 'error' not in value))
defer.returnValue(results) return results
def get_claims_for_name(self, name: str): def get_claims_for_name(self, name: str):
return self.ledger.network.get_claims_for_name(name) return self.ledger.network.get_claims_for_name(name)
@defer.inlineCallbacks async def address_is_mine(self, unknown_address, account):
def address_is_mine(self, unknown_address, account): match = await self.ledger.db.get_address(address=unknown_address, account=account)
match = yield self.ledger.db.get_address(address=unknown_address, account=account)
if match is not None: if match is not None:
return True return True
return False return False
@ -287,10 +279,9 @@ class LbryWalletManager(BaseWalletManager):
return self.default_account.ledger.get_transaction(txid) return self.default_account.ledger.get_transaction(txid)
@staticmethod @staticmethod
@defer.inlineCallbacks async def get_history(account: BaseAccount, **constraints):
def get_history(account: BaseAccount, **constraints):
headers = account.ledger.headers headers = account.ledger.headers
txs = (yield account.get_transactions(**constraints)) txs = await account.get_transactions(**constraints)
history = [] history = []
for tx in txs: for tx in txs:
ts = headers[tx.height]['timestamp'] ts = headers[tx.height]['timestamp']
@ -346,29 +337,28 @@ class LbryWalletManager(BaseWalletManager):
def get_utxos(account: BaseAccount): def get_utxos(account: BaseAccount):
return account.get_utxos() return account.get_utxos()
@defer.inlineCallbacks async def claim_name(self, name, amount, claim_dict, certificate=None, claim_address=None):
def claim_name(self, name, amount, claim_dict, certificate=None, claim_address=None):
account = self.default_account account = self.default_account
claim = ClaimDict.load_dict(claim_dict) claim = ClaimDict.load_dict(claim_dict)
if not claim_address: if not claim_address:
claim_address = yield account.receiving.get_or_create_usable_address() claim_address = await account.receiving.get_or_create_usable_address()
if certificate: if certificate:
claim = claim.sign( claim = claim.sign(
certificate.private_key, claim_address, certificate.claim_id, curve=SECP256k1 certificate.private_key, claim_address, certificate.claim_id, curve=SECP256k1
) )
existing_claims = yield account.get_claims(claim_name=name) existing_claims = await account.get_claims(claim_name=name)
if len(existing_claims) == 0: if len(existing_claims) == 0:
tx = yield Transaction.claim( tx = await Transaction.claim(
name, claim, amount, claim_address, [account], account name, claim, amount, claim_address, [account], account
) )
elif len(existing_claims) == 1: elif len(existing_claims) == 1:
tx = yield Transaction.update( tx = await Transaction.update(
existing_claims[0], claim, amount, claim_address, [account], account existing_claims[0], claim, amount, claim_address, [account], account
) )
else: else:
raise NameError("More than one other claim exists with the name '{}'.".format(name)) raise NameError("More than one other claim exists with the name '{}'.".format(name))
yield account.ledger.broadcast(tx) await account.ledger.broadcast(tx)
yield self.old_db.save_claims([self._old_get_temp_claim_info( await self.old_db.save_claims([self._old_get_temp_claim_info(
tx, tx.outputs[0], claim_address, claim_dict, name, amount tx, tx.outputs[0], claim_address, claim_dict, name, amount
)]) )])
# TODO: release reserved tx outputs in case anything fails by this point # TODO: release reserved tx outputs in case anything fails by this point
@ -387,43 +377,39 @@ class LbryWalletManager(BaseWalletManager):
"claim_sequence": -1, "claim_sequence": -1,
} }
@defer.inlineCallbacks async def support_claim(self, claim_name, claim_id, amount, account):
def support_claim(self, claim_name, claim_id, amount, account): holding_address = await account.receiving.get_or_create_usable_address()
holding_address = yield account.receiving.get_or_create_usable_address() tx = await Transaction.support(claim_name, claim_id, amount, holding_address, [account], account)
tx = yield Transaction.support(claim_name, claim_id, amount, holding_address, [account], account) await account.ledger.broadcast(tx)
yield account.ledger.broadcast(tx)
return tx return tx
@defer.inlineCallbacks async def tip_claim(self, amount, claim_id, account):
def tip_claim(self, amount, claim_id, account): claim_to_tip = await self.get_claim_by_claim_id(claim_id)
claim_to_tip = yield self.get_claim_by_claim_id(claim_id) tx = await Transaction.support(
tx = yield Transaction.support(
claim_to_tip['name'], claim_id, amount, claim_to_tip['address'], [account], account claim_to_tip['name'], claim_id, amount, claim_to_tip['address'], [account], account
) )
yield account.ledger.broadcast(tx) await account.ledger.broadcast(tx)
return tx return tx
@defer.inlineCallbacks async def abandon_claim(self, claim_id, txid, nout, account):
def abandon_claim(self, claim_id, txid, nout, account): claim = await account.get_claim(claim_id=claim_id, txid=txid, nout=nout)
claim = yield account.get_claim(claim_id=claim_id, txid=txid, nout=nout)
if not claim: if not claim:
raise Exception('No claim found for the specified claim_id or txid:nout') raise Exception('No claim found for the specified claim_id or txid:nout')
tx = yield Transaction.abandon(claim, [account], account) tx = await Transaction.abandon(claim, [account], account)
yield account.ledger.broadcast(tx) await account.ledger.broadcast(tx)
# TODO: release reserved tx outputs in case anything fails by this point # TODO: release reserved tx outputs in case anything fails by this point
defer.returnValue(tx) return tx
@defer.inlineCallbacks async def claim_new_channel(self, channel_name, amount):
def claim_new_channel(self, channel_name, amount):
account = self.default_account account = self.default_account
address = yield account.receiving.get_or_create_usable_address() address = await account.receiving.get_or_create_usable_address()
cert, key = generate_certificate() cert, key = generate_certificate()
tx = yield Transaction.claim(channel_name, cert, amount, address, [account], account) tx = await Transaction.claim(channel_name, cert, amount, address, [account], account)
yield account.ledger.broadcast(tx) await account.ledger.broadcast(tx)
account.add_certificate_private_key(tx.outputs[0].ref, key.decode()) account.add_certificate_private_key(tx.outputs[0].ref, key.decode())
# TODO: release reserved tx outputs in case anything fails by this point # TODO: release reserved tx outputs in case anything fails by this point
defer.returnValue(tx) return tx
def get_certificates(self, private_key_accounts, exclude_without_key=True, **constraints): def get_certificates(self, private_key_accounts, exclude_without_key=True, **constraints):
return self.db.get_certificates( return self.db.get_certificates(
@ -443,7 +429,7 @@ class LbryWalletManager(BaseWalletManager):
pass # TODO: Data payments is disabled pass # TODO: Data payments is disabled
def send_points(self, reserved_points, amount): def send_points(self, reserved_points, amount):
defer.succeed(True) # TODO: Data payments is disabled return True # TODO: Data payments is disabled
def cancel_point_reservation(self, reserved_points): def cancel_point_reservation(self, reserved_points):
pass # fixme: disabled for now. pass # fixme: disabled for now.

View file

@ -3,8 +3,6 @@ import logging
from ecdsa import BadSignatureError from ecdsa import BadSignatureError
from binascii import unhexlify, hexlify from binascii import unhexlify, hexlify
from twisted.internet import defer
from lbrynet.core.Error import UnknownNameError, UnknownClaimID, UnknownURI, UnknownOutpoint from lbrynet.core.Error import UnknownNameError, UnknownClaimID, UnknownURI, UnknownOutpoint
from lbryschema.address import is_address from lbryschema.address import is_address
from lbryschema.claim import ClaimDict from lbryschema.claim import ClaimDict
@ -25,24 +23,22 @@ class Resolver:
self.hash160_to_address = hash160_to_address self.hash160_to_address = hash160_to_address
self.network = network self.network = network
@defer.inlineCallbacks async def _handle_resolutions(self, resolutions, requested_uris, page, page_size):
def _handle_resolutions(self, resolutions, requested_uris, page, page_size):
results = {} results = {}
for uri in requested_uris: for uri in requested_uris:
resolution = (resolutions or {}).get(uri, {}) resolution = (resolutions or {}).get(uri, {})
if resolution: if resolution:
try: try:
results[uri] = _handle_claim_result( results[uri] = _handle_claim_result(
(yield self._handle_resolve_uri_response(uri, resolution, page, page_size)) await self._handle_resolve_uri_response(uri, resolution, page, page_size)
) )
except (UnknownNameError, UnknownClaimID, UnknownURI) as err: except (UnknownNameError, UnknownClaimID, UnknownURI) as err:
results[uri] = {'error': str(err)} results[uri] = {'error': str(err)}
else: else:
results[uri] = {'error': "URI lbry://{} cannot be resolved".format(uri.replace("lbry://", ""))} results[uri] = {'error': "URI lbry://{} cannot be resolved".format(uri.replace("lbry://", ""))}
defer.returnValue(results) return results
@defer.inlineCallbacks async def _handle_resolve_uri_response(self, uri, resolution, page=0, page_size=10, raw=False):
def _handle_resolve_uri_response(self, uri, resolution, page=0, page_size=10, raw=False):
result = {} result = {}
claim_trie_root = self.claim_trie_root claim_trie_root = self.claim_trie_root
parsed_uri = parse_lbry_uri(uri) parsed_uri = parse_lbry_uri(uri)
@ -120,21 +116,21 @@ class Resolver:
elif 'unverified_claims_for_name' in resolution and 'certificate' in result: elif 'unverified_claims_for_name' in resolution and 'certificate' in result:
unverified_claims_for_name = resolution['unverified_claims_for_name'] unverified_claims_for_name = resolution['unverified_claims_for_name']
channel_info = yield self.get_channel_claims_page(unverified_claims_for_name, channel_info = await self.get_channel_claims_page(unverified_claims_for_name,
result['certificate'], page=1) result['certificate'], page=1)
claims_in_channel, upper_bound = channel_info claims_in_channel, upper_bound = channel_info
if len(claims_in_channel) > 1: if not claims_in_channel:
log.error("Multiple signed claims for the same name")
elif not claims_in_channel:
log.error("No valid claims for this name for this channel") log.error("No valid claims for this name for this channel")
elif len(claims_in_channel) > 1:
log.error("Multiple signed claims for the same name")
else: else:
result['claim'] = claims_in_channel[0] result['claim'] = claims_in_channel[0]
# parse and validate claims in a channel iteratively into pages of results # parse and validate claims in a channel iteratively into pages of results
elif 'unverified_claims_in_channel' in resolution and 'certificate' in result: elif 'unverified_claims_in_channel' in resolution and 'certificate' in result:
ids_to_check = resolution['unverified_claims_in_channel'] ids_to_check = resolution['unverified_claims_in_channel']
channel_info = yield self.get_channel_claims_page(ids_to_check, result['certificate'], channel_info = await self.get_channel_claims_page(ids_to_check, result['certificate'],
page=page, page_size=page_size) page=page, page_size=page_size)
claims_in_channel, upper_bound = channel_info claims_in_channel, upper_bound = channel_info
@ -145,16 +141,15 @@ class Resolver:
result['success'] = False result['success'] = False
result['uri'] = str(parsed_uri) result['uri'] = str(parsed_uri)
defer.returnValue(result) return result
@defer.inlineCallbacks async def get_certificate_and_validate_result(self, claim_result):
def get_certificate_and_validate_result(self, claim_result):
if not claim_result or 'value' not in claim_result: if not claim_result or 'value' not in claim_result:
return claim_result return claim_result
certificate = None certificate = None
certificate_id = smart_decode(claim_result['value']).certificate_id certificate_id = smart_decode(claim_result['value']).certificate_id
if certificate_id: if certificate_id:
certificate = yield self.network.get_claims_by_ids(certificate_id.decode()) certificate = await self.network.get_claims_by_ids(certificate_id.decode())
certificate = certificate.pop(certificate_id.decode()) if certificate else None certificate = certificate.pop(certificate_id.decode()) if certificate else None
return self.parse_and_validate_claim_result(claim_result, certificate=certificate) return self.parse_and_validate_claim_result(claim_result, certificate=certificate)
@ -227,8 +222,7 @@ class Resolver:
abs_position += 1 abs_position += 1
return queries, names, absolute_position_index return queries, names, absolute_position_index
@defer.inlineCallbacks async def iter_channel_claims_pages(self, queries, claim_positions, claim_names, certificate,
def iter_channel_claims_pages(self, queries, claim_positions, claim_names, certificate,
page_size=10): page_size=10):
# lbryum server returns a dict of {claim_id: (name, claim_height)} # lbryum server returns a dict of {claim_id: (name, claim_height)}
# first, sort the claims by block height (and by claim id int value within a block). # first, sort the claims by block height (and by claim id int value within a block).
@ -243,11 +237,10 @@ class Resolver:
# processed them. # processed them.
# TODO: fix ^ in lbryschema # TODO: fix ^ in lbryschema
@defer.inlineCallbacks async def iter_validate_channel_claims():
def iter_validate_channel_claims():
formatted_claims = [] formatted_claims = []
for claim_ids in queries: for claim_ids in queries:
batch_result = yield self.network.get_claims_by_ids(*claim_ids) batch_result = await self.network.get_claims_by_ids(*claim_ids)
for claim_id in claim_ids: for claim_id in claim_ids:
claim = batch_result[claim_id] claim = batch_result[claim_id]
if claim['name'] == claim_names[claim_id]: if claim['name'] == claim_names[claim_id]:
@ -258,25 +251,20 @@ class Resolver:
else: else:
log.warning("ignoring claim with name mismatch %s %s", claim['name'], log.warning("ignoring claim with name mismatch %s %s", claim['name'],
claim['claim_id']) claim['claim_id'])
defer.returnValue(formatted_claims) return formatted_claims
yielded_page = False
results = [] results = []
for claim in (yield iter_validate_channel_claims()): for claim in (await iter_validate_channel_claims()):
results.append(claim) results.append(claim)
# if there is a full page of results, yield it # if there is a full page of results, yield it
if len(results) and len(results) % page_size == 0: if len(results) and len(results) % page_size == 0:
defer.returnValue(results[-page_size:]) return results[-page_size:]
yielded_page = True
# if we didn't get a full page of results, yield what results we did get return results
if not yielded_page:
defer.returnValue(results)
@defer.inlineCallbacks async def get_channel_claims_page(self, channel_claim_infos, certificate, page, page_size=10):
def get_channel_claims_page(self, channel_claim_infos, certificate, page, page_size=10):
page = page or 0 page = page or 0
page_size = max(page_size, 1) page_size = max(page_size, 1)
if page_size > 500: if page_size > 500:
@ -284,14 +272,14 @@ class Resolver:
start_position = (page - 1) * page_size start_position = (page - 1) * page_size
queries, names, claim_positions = self.prepare_claim_queries(start_position, page_size, queries, names, claim_positions = self.prepare_claim_queries(start_position, page_size,
channel_claim_infos) channel_claim_infos)
page_generator = yield self.iter_channel_claims_pages(queries, claim_positions, names, page_generator = await self.iter_channel_claims_pages(queries, claim_positions, names,
certificate, page_size=page_size) certificate, page_size=page_size)
upper_bound = len(claim_positions) upper_bound = len(claim_positions)
if not page: if not page:
defer.returnValue((None, upper_bound)) return None, upper_bound
if start_position > upper_bound: if start_position > upper_bound:
raise IndexError("claim %i greater than max %i" % (start_position, upper_bound)) raise IndexError("claim %i greater than max %i" % (start_position, upper_bound))
defer.returnValue((page_generator, upper_bound)) return page_generator, upper_bound
# Format amount to be decimal encoded string # Format amount to be decimal encoded string

View file

@ -1,6 +1,6 @@
import asyncio import asyncio
from orchstr8.testcase import IntegrationTestCase, d2f from orchstr8.testcase import IntegrationTestCase
from lbryschema.claim import ClaimDict from lbryschema.claim import ClaimDict
from lbrynet.wallet.transaction import Transaction from lbrynet.wallet.transaction import Transaction
from lbrynet.wallet.account import generate_certificate from lbrynet.wallet.account import generate_certificate
@ -43,9 +43,9 @@ class BasicTransactionTest(IntegrationTestCase):
async def test_creating_updating_and_abandoning_claim_with_channel(self): async def test_creating_updating_and_abandoning_claim_with_channel(self):
await d2f(self.account.ensure_address_gap()) await self.account.ensure_address_gap()
address1, address2 = await d2f(self.account.receiving.get_addresses(limit=2, only_usable=True)) address1, address2 = await self.account.receiving.get_addresses(limit=2, only_usable=True)
sendtxid1 = await self.blockchain.send_to_address(address1, 5) sendtxid1 = await self.blockchain.send_to_address(address1, 5)
sendtxid2 = await self.blockchain.send_to_address(address2, 5) sendtxid2 = await self.blockchain.send_to_address(address2, 5)
await self.blockchain.generate(1) await self.blockchain.generate(1)
@ -54,13 +54,13 @@ class BasicTransactionTest(IntegrationTestCase):
self.on_transaction_id(sendtxid2), self.on_transaction_id(sendtxid2),
]) ])
self.assertEqual(d2l(await d2f(self.account.get_balance())), '10.0') self.assertEqual(d2l(await self.account.get_balance()), '10.0')
cert, key = generate_certificate() cert, key = generate_certificate()
cert_tx = await d2f(Transaction.claim('@bar', cert, l2d('1.0'), address1, [self.account], self.account)) cert_tx = await Transaction.claim('@bar', cert, l2d('1.0'), address1, [self.account], self.account)
claim = ClaimDict.load_dict(example_claim_dict) claim = ClaimDict.load_dict(example_claim_dict)
claim = claim.sign(key, address1, cert_tx.outputs[0].claim_id) claim = claim.sign(key, address1, cert_tx.outputs[0].claim_id)
claim_tx = await d2f(Transaction.claim('foo', claim, l2d('1.0'), address1, [self.account], self.account)) claim_tx = await Transaction.claim('foo', claim, l2d('1.0'), address1, [self.account], self.account)
await self.broadcast(cert_tx) await self.broadcast(cert_tx)
await self.broadcast(claim_tx) await self.broadcast(claim_tx)
@ -74,23 +74,23 @@ class BasicTransactionTest(IntegrationTestCase):
self.on_transaction(cert_tx), self.on_transaction(cert_tx),
]) ])
self.assertEqual(d2l(await d2f(self.account.get_balance(confirmations=1))), '7.985786') self.assertEqual(d2l(await self.account.get_balance(confirmations=1)), '7.985786')
self.assertEqual(d2l(await d2f(self.account.get_balance(include_claims=True))), '9.985786') self.assertEqual(d2l(await self.account.get_balance(include_claims=True)), '9.985786')
response = await d2f(self.ledger.resolve(0, 10, 'lbry://@bar/foo')) response = await self.ledger.resolve(0, 10, 'lbry://@bar/foo')
self.assertIn('lbry://@bar/foo', response) self.assertIn('lbry://@bar/foo', response)
self.assertIn('claim', response['lbry://@bar/foo']) self.assertIn('claim', response['lbry://@bar/foo'])
abandon_tx = await d2f(Transaction.abandon([claim_tx.outputs[0]], [self.account], self.account)) abandon_tx = await Transaction.abandon([claim_tx.outputs[0]], [self.account], self.account)
await self.broadcast(abandon_tx) await self.broadcast(abandon_tx)
await self.on_transaction(abandon_tx) await self.on_transaction(abandon_tx)
await self.blockchain.generate(1) await self.blockchain.generate(1)
await self.on_transaction(abandon_tx) await self.on_transaction(abandon_tx)
response = await d2f(self.ledger.resolve(0, 10, 'lbry://@bar/foo')) response = await self.ledger.resolve(0, 10, 'lbry://@bar/foo')
self.assertNotIn('claim', response['lbry://@bar/foo']) self.assertNotIn('claim', response['lbry://@bar/foo'])
# checks for expected format in inexistent URIs # checks for expected format in inexistent URIs
response = await d2f(self.ledger.resolve(0, 10, 'lbry://404', 'lbry://@404')) response = await self.ledger.resolve(0, 10, 'lbry://404', 'lbry://@404')
self.assertEqual('URI lbry://404 cannot be resolved', response['lbry://404']['error']) self.assertEqual('URI lbry://404 cannot be resolved', response['lbry://404']['error'])
self.assertEqual('URI lbry://@404 cannot be resolved', response['lbry://@404']['error']) self.assertEqual('URI lbry://@404 cannot be resolved', response['lbry://@404']['error'])

View file

@ -1,23 +1,24 @@
from twisted.trial import unittest from orchstr8.testcase import AsyncioTestCase
from twisted.internet import defer from torba.wallet import Wallet
from lbrynet.wallet.ledger import MainNetLedger, WalletDatabase from lbrynet.wallet.ledger import MainNetLedger, WalletDatabase
from lbrynet.wallet.header import Headers from lbrynet.wallet.header import Headers
from lbrynet.wallet.account import Account from lbrynet.wallet.account import Account
from torba.wallet import Wallet
class TestAccount(unittest.TestCase): class TestAccount(AsyncioTestCase):
def setUp(self): async def asyncSetUp(self):
self.ledger = MainNetLedger({ self.ledger = MainNetLedger({
'db': WalletDatabase(':memory:'), 'db': WalletDatabase(':memory:'),
'headers': Headers(':memory:') 'headers': Headers(':memory:')
}) })
return self.ledger.db.open() await self.ledger.db.open()
@defer.inlineCallbacks async def asyncTearDown(self):
def test_generate_account(self): await self.ledger.db.close()
async def test_generate_account(self):
account = Account.generate(self.ledger, Wallet(), 'lbryum') account = Account.generate(self.ledger, Wallet(), 'lbryum')
self.assertEqual(account.ledger, self.ledger) self.assertEqual(account.ledger, self.ledger)
self.assertIsNotNone(account.seed) self.assertIsNotNone(account.seed)
@ -27,20 +28,19 @@ class TestAccount(unittest.TestCase):
self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.public_key.ledger, self.ledger)
self.assertEqual(account.private_key.public_key, account.public_key) self.assertEqual(account.private_key.public_key, account.public_key)
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 0) self.assertEqual(len(addresses), 0)
addresses = yield account.change.get_addresses() addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 0) self.assertEqual(len(addresses), 0)
yield account.ensure_address_gap() await account.ensure_address_gap()
addresses = yield account.receiving.get_addresses() addresses = await account.receiving.get_addresses()
self.assertEqual(len(addresses), 20) self.assertEqual(len(addresses), 20)
addresses = yield account.change.get_addresses() addresses = await account.change.get_addresses()
self.assertEqual(len(addresses), 6) self.assertEqual(len(addresses), 6)
@defer.inlineCallbacks async def test_generate_account_from_seed(self):
def test_generate_account_from_seed(self):
account = Account.from_dict( account = Account.from_dict(
self.ledger, Wallet(), { self.ledger, Wallet(), {
"seed": "seed":
@ -58,16 +58,16 @@ class TestAccount(unittest.TestCase):
'xpub661MyMwAqRbcGWtPvbWh9sc2BCfw2cTeVDYF23o3N1t6UZ5wv3EMmDgp66FxH' 'xpub661MyMwAqRbcGWtPvbWh9sc2BCfw2cTeVDYF23o3N1t6UZ5wv3EMmDgp66FxH'
'uDtWdft3B5eL5xQtyzAtkdmhhC95gjRjLzSTdkho95asu9' 'uDtWdft3B5eL5xQtyzAtkdmhhC95gjRjLzSTdkho95asu9'
) )
address = yield account.receiving.ensure_address_gap() address = await account.receiving.ensure_address_gap()
self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx') self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx')
private_key = yield self.ledger.get_private_key_for_address('bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx') private_key = await self.ledger.get_private_key_for_address('bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx')
self.assertEqual( self.assertEqual(
private_key.extended_key_string(), private_key.extended_key_string(),
'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR' 'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR'
'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT' 'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT'
) )
private_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') private_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX')
self.assertIsNone(private_key) self.assertIsNone(private_key)
def test_load_and_save_account(self): def test_load_and_save_account(self):

View file

@ -1,5 +1,5 @@
from binascii import hexlify, unhexlify
import unittest import unittest
from binascii import hexlify, unhexlify
from lbrynet.wallet.claim_proofs import get_hash_for_outpoint, verify_proof from lbrynet.wallet.claim_proofs import get_hash_for_outpoint, verify_proof
from lbryschema.hashing import double_sha256 from lbryschema.hashing import double_sha256

View file

@ -1,4 +1,5 @@
from twisted.trial import unittest import unittest
from lbrynet.wallet.dewies import lbc_to_dewies as l2d, dewies_to_lbc as d2l from lbrynet.wallet.dewies import lbc_to_dewies as l2d, dewies_to_lbc as d2l

View file

@ -1,15 +1,12 @@
from io import BytesIO
from binascii import unhexlify from binascii import unhexlify
from twisted.trial import unittest from orchstr8.testcase import AsyncioTestCase
from twisted.internet import defer from torba.util import ArithUint256
from lbrynet.wallet.ledger import Headers from lbrynet.wallet.ledger import Headers
from torba.util import ArithUint256
class TestHeaders(AsyncioTestCase):
class TestHeaders(unittest.TestCase):
def test_deserialize(self): def test_deserialize(self):
self.maxDiff = None self.maxDiff = None
@ -36,19 +33,17 @@ class TestHeaders(unittest.TestCase):
'version': 536870912 'version': 536870912
}) })
@defer.inlineCallbacks async def test_connect_from_genesis(self):
def test_connect_from_genesis(self):
headers = Headers(':memory:') headers = Headers(':memory:')
self.assertEqual(headers.height, -1) self.assertEqual(headers.height, -1)
yield headers.connect(0, HEADERS) await headers.connect(0, HEADERS)
self.assertEqual(headers.height, 19) self.assertEqual(headers.height, 19)
@defer.inlineCallbacks async def test_connect_from_middle(self):
def test_connect_from_middle(self):
h = Headers(':memory:') h = Headers(':memory:')
h.io.write(HEADERS[:10*Headers.header_size]) h.io.write(HEADERS[:10*Headers.header_size])
self.assertEqual(h.height, 9) self.assertEqual(h.height, 9)
yield h.connect(len(h), HEADERS[10*Headers.header_size:20*Headers.header_size]) await h.connect(len(h), HEADERS[10*Headers.header_size:20*Headers.header_size])
self.assertEqual(h.height, 19) self.assertEqual(h.height, 19)
def test_target_calculation(self): def test_target_calculation(self):

View file

@ -1,79 +1,68 @@
from twisted.internet import defer from orchstr8.testcase import AsyncioTestCase
from twisted.trial import unittest from torba.wallet import Wallet
from lbrynet.wallet.account import Account from lbrynet.wallet.account import Account
from lbrynet.wallet.transaction import Transaction, Output, Input from lbrynet.wallet.transaction import Transaction, Output, Input
from lbrynet.wallet.ledger import MainNetLedger from lbrynet.wallet.ledger import MainNetLedger
from torba.wallet import Wallet
class LedgerTestCase(unittest.TestCase): class LedgerTestCase(AsyncioTestCase):
def setUp(self): async def asyncSetUp(self):
super().setUp()
self.ledger = MainNetLedger({ self.ledger = MainNetLedger({
'db': MainNetLedger.database_class(':memory:'), 'db': MainNetLedger.database_class(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': MainNetLedger.headers_class(':memory:')
}) })
self.account = Account.generate(self.ledger, Wallet(), "lbryum") self.account = Account.generate(self.ledger, Wallet(), "lbryum")
return self.ledger.db.open() await self.ledger.db.open()
def tearDown(self): async def asyncTearDown(self):
super().tearDown() await self.ledger.db.close()
return self.ledger.db.close()
class BasicAccountingTests(LedgerTestCase): class BasicAccountingTests(LedgerTestCase):
@defer.inlineCallbacks async def test_empty_state(self):
def test_empty_state(self): self.assertEqual(await self.account.get_balance(), 0)
balance = yield self.account.get_balance()
self.assertEqual(balance, 0)
@defer.inlineCallbacks async def test_balance(self):
def test_balance(self): address = await self.account.receiving.get_or_create_usable_address()
address = yield self.account.receiving.get_or_create_usable_address()
hash160 = self.ledger.address_to_hash160(address) hash160 = self.ledger.address_to_hash160(address)
tx = Transaction(is_verified=True)\ tx = Transaction(is_verified=True)\
.add_outputs([Output.pay_pubkey_hash(100, hash160)]) .add_outputs([Output.pay_pubkey_hash(100, hash160)])
yield self.ledger.db.save_transaction_io( await self.ledger.db.save_transaction_io(
'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1) 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1)
) )
balance = yield self.account.get_balance(0) self.assertEqual(await self.account.get_balance(), 100)
self.assertEqual(balance, 100)
tx = Transaction(is_verified=True)\ tx = Transaction(is_verified=True)\
.add_outputs([Output.pay_claim_name_pubkey_hash(100, 'foo', b'', hash160)]) .add_outputs([Output.pay_claim_name_pubkey_hash(100, 'foo', b'', hash160)])
yield self.ledger.db.save_transaction_io( await self.ledger.db.save_transaction_io(
'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1) 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1)
) )
balance = yield self.account.get_balance(0) self.assertEqual(await self.account.get_balance(), 100) # claim names don't count towards balance
self.assertEqual(balance, 100) # claim names don't count towards balance self.assertEqual(await self.account.get_balance(include_claims=True), 200)
balance = yield self.account.get_balance(0, include_claims=True)
self.assertEqual(balance, 200)
@defer.inlineCallbacks async def test_get_utxo(self):
def test_get_utxo(self):
address = yield self.account.receiving.get_or_create_usable_address() address = yield self.account.receiving.get_or_create_usable_address()
hash160 = self.ledger.address_to_hash160(address) hash160 = self.ledger.address_to_hash160(address)
tx = Transaction(is_verified=True)\ tx = Transaction(is_verified=True)\
.add_outputs([Output.pay_pubkey_hash(100, hash160)]) .add_outputs([Output.pay_pubkey_hash(100, hash160)])
yield self.ledger.db.save_transaction_io( await self.ledger.db.save_transaction_io(
'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1) 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1)
) )
utxos = yield self.account.get_utxos() utxos = await self.account.get_utxos()
self.assertEqual(len(utxos), 1) self.assertEqual(len(utxos), 1)
tx = Transaction(is_verified=True)\ tx = Transaction(is_verified=True)\
.add_inputs([Input.spend(utxos[0])]) .add_inputs([Input.spend(utxos[0])])
yield self.ledger.db.save_transaction_io( await self.ledger.db.save_transaction_io(
'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1) 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1)
) )
balance = yield self.account.get_balance(0, include_claims=True) self.assertEqual(await self.account.get_balance(include_claims=True), 0)
self.assertEqual(balance, 0)
utxos = yield self.account.get_utxos() utxos = await self.account.get_utxos()
self.assertEqual(len(utxos), 0) self.assertEqual(len(utxos), 0)

View file

@ -1,5 +1,5 @@
import unittest
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from twisted.trial import unittest
from lbrynet.wallet.script import OutputScript from lbrynet.wallet.script import OutputScript

View file

@ -1,7 +1,7 @@
import unittest
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from twisted.trial import unittest
from twisted.internet import defer
from orchstr8.testcase import AsyncioTestCase
from torba.constants import CENT, COIN, NULL_HASH32 from torba.constants import CENT, COIN, NULL_HASH32
from torba.wallet import Wallet from torba.wallet import Wallet
@ -35,19 +35,17 @@ def get_claim_transaction(claim_name, claim=b''):
) )
class TestSizeAndFeeEstimation(unittest.TestCase): class TestSizeAndFeeEstimation(AsyncioTestCase):
def setUp(self): async def asyncSetUp(self):
super().setUp()
self.ledger = MainNetLedger({ self.ledger = MainNetLedger({
'db': MainNetLedger.database_class(':memory:'), 'db': MainNetLedger.database_class(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': MainNetLedger.headers_class(':memory:')
}) })
return self.ledger.db.open() await self.ledger.db.open()
def tearDown(self): async def asyncTearDown(self):
super().tearDown() await self.ledger.db.close()
return self.ledger.db.close()
def test_output_size_and_fee(self): def test_output_size_and_fee(self):
txo = get_output() txo = get_output()
@ -219,22 +217,19 @@ class TestTransactionSerialization(unittest.TestCase):
self.assertEqual(tx.raw, raw) self.assertEqual(tx.raw, raw)
class TestTransactionSigning(unittest.TestCase): class TestTransactionSigning(AsyncioTestCase):
def setUp(self): async def asyncSetUp(self):
super().setUp()
self.ledger = MainNetLedger({ self.ledger = MainNetLedger({
'db': MainNetLedger.database_class(':memory:'), 'db': MainNetLedger.database_class(':memory:'),
'headers': MainNetLedger.headers_class(':memory:') 'headers': MainNetLedger.headers_class(':memory:')
}) })
return self.ledger.db.open() await self.ledger.db.open()
def tearDown(self): async def asyncTearDown(self):
super().tearDown() await self.ledger.db.close()
return self.ledger.db.close()
@defer.inlineCallbacks async def test_sign(self):
def test_sign(self):
account = self.ledger.account_class.from_dict( account = self.ledger.account_class.from_dict(
self.ledger, Wallet(), { self.ledger, Wallet(), {
"seed": "seed":
@ -243,8 +238,8 @@ class TestTransactionSigning(unittest.TestCase):
} }
) )
yield account.ensure_address_gap() await account.ensure_address_gap()
address1, address2 = yield account.receiving.get_addresses(limit=2) address1, address2 = await account.receiving.get_addresses(limit=2)
pubkey_hash1 = self.ledger.address_to_hash160(address1) pubkey_hash1 = self.ledger.address_to_hash160(address1)
pubkey_hash2 = self.ledger.address_to_hash160(address2) pubkey_hash2 = self.ledger.address_to_hash160(address2)
@ -252,7 +247,7 @@ class TestTransactionSigning(unittest.TestCase):
.add_inputs([Input.spend(get_output(int(2*COIN), pubkey_hash1))]) \ .add_inputs([Input.spend(get_output(int(2*COIN), pubkey_hash1))]) \
.add_outputs([Output.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) .add_outputs([Output.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)])
yield tx.sign([account]) await tx.sign([account])
self.assertEqual( self.assertEqual(
hexlify(tx.inputs[0].script.values['signature']), hexlify(tx.inputs[0].script.values['signature']),