From 64e306801dd21a1800c6d0d861a7c5393e520d68 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Mon, 15 Oct 2018 17:16:43 -0400 Subject: [PATCH] updated wallet to use asyncio --- lbrynet/wallet/account.py | 21 ++-- lbrynet/wallet/database.py | 13 +- lbrynet/wallet/ledger.py | 70 +++++------ lbrynet/wallet/manager.py | 116 ++++++++---------- lbrynet/wallet/resolve.py | 58 ++++----- tests/integration/wallet/test_transactions.py | 24 ++-- tests/unit/wallet/test_account.py | 36 +++--- tests/unit/wallet/test_claim_proofs.py | 2 +- tests/unit/wallet/test_dewies.py | 3 +- tests/unit/wallet/test_headers.py | 19 ++- tests/unit/wallet/test_ledger.py | 57 ++++----- tests/unit/wallet/test_script.py | 2 +- tests/unit/wallet/test_transaction.py | 37 +++--- 13 files changed, 198 insertions(+), 260 deletions(-) diff --git a/lbrynet/wallet/account.py b/lbrynet/wallet/account.py index 297d55682..edf3f2637 100644 --- a/lbrynet/wallet/account.py +++ b/lbrynet/wallet/account.py @@ -1,8 +1,6 @@ import json import logging -from twisted.internet import defer - from torba.baseaccount import BaseAccount from torba.basetransaction import TXORef @@ -31,8 +29,7 @@ class Account(BaseAccount): def get_certificate_private_key(self, ref: TXORef): return self.certificates.get(ref.id) - @defer.inlineCallbacks - def maybe_migrate_certificates(self): + async def maybe_migrate_certificates(self): if not self.certificates: return @@ -49,7 +46,7 @@ class Account(BaseAccount): for maybe_claim_id in list(self.certificates): results['total'] += 1 if ':' not in maybe_claim_id: - claims = yield self.ledger.network.get_claims_by_ids(maybe_claim_id) + claims = await self.ledger.network.get_claims_by_ids(maybe_claim_id) if maybe_claim_id not in claims: log.warning( "Failed to migrate claim '%s', server did not return any claim information.", @@ -60,7 +57,7 @@ class Account(BaseAccount): claim = claims[maybe_claim_id] tx = None if claim: - tx = yield self.ledger.db.get_transaction(txid=claim['txid']) + tx = await self.ledger.db.get_transaction(txid=claim['txid']) else: log.warning(maybe_claim_id) if tx is not None: @@ -96,7 +93,7 @@ class Account(BaseAccount): else: try: txid, nout = maybe_claim_id.split(':') - tx = yield self.ledger.db.get_transaction(txid=txid) + tx = await self.ledger.db.get_transaction(txid=txid) if tx.outputs[int(nout)].script.is_claim_involved: results['previous-success'] += 1 else: @@ -115,9 +112,8 @@ class Account(BaseAccount): indent=2 )) - @defer.inlineCallbacks - def save_max_gap(self): - gap = yield self.get_max_gap() + async def save_max_gap(self): + gap = await self.get_max_gap() self.receiving.gap = max(20, gap['max_receiving_gap'] + 1) self.change.gap = max(6, gap['max_change_gap'] + 1) self.wallet.save() @@ -144,9 +140,8 @@ class Account(BaseAccount): d['certificates'] = self.certificates return d - @defer.inlineCallbacks - def get_details(self, **kwargs): - details = yield super().get_details(**kwargs) + async def get_details(self, **kwargs): + details = await super().get_details(**kwargs) details['certificates'] = len(self.certificates) return details diff --git a/lbrynet/wallet/database.py b/lbrynet/wallet/database.py index 323cd4a07..acca6adb7 100644 --- a/lbrynet/wallet/database.py +++ b/lbrynet/wallet/database.py @@ -1,4 +1,3 @@ -from twisted.internet import defer from torba.basedatabase import BaseDatabase @@ -49,11 +48,10 @@ class WalletDatabase(BaseDatabase): row['claim_name'] = txo.claim_name return row - @defer.inlineCallbacks - def get_txos(self, **constraints): + async def get_txos(self, **constraints): my_account = constraints.get('my_account', constraints.get('account', None)) - txos = yield super().get_txos(**constraints) + txos = await super().get_txos(**constraints) channel_ids = set() for txo in txos: @@ -66,7 +64,7 @@ class WalletDatabase(BaseDatabase): if channel_ids: channels = { txo.claim_id: txo for txo in - (yield super().get_utxos( + (await super().get_utxos( my_account=my_account, claim_id__in=channel_ids )) @@ -103,9 +101,8 @@ class WalletDatabase(BaseDatabase): self.constrain_channels(constraints) return self.get_claim_count(**constraints) - @defer.inlineCallbacks - def get_certificates(self, private_key_accounts, exclude_without_key=False, **constraints): - channels = yield self.get_channels(**constraints) + async def get_certificates(self, private_key_accounts, exclude_without_key=False, **constraints): + channels = await self.get_channels(**constraints) certificates = [] if private_key_accounts is not None: for channel in channels: diff --git a/lbrynet/wallet/ledger.py b/lbrynet/wallet/ledger.py index 2d08f9ea4..992a43e7e 100644 --- a/lbrynet/wallet/ledger.py +++ b/lbrynet/wallet/ledger.py @@ -1,15 +1,12 @@ +import asyncio import logging - -from six import int2byte from binascii import unhexlify -from twisted.internet import defer - -from .resolve import Resolver from lbryschema.error import URIParseError from lbryschema.uri import parse_lbry_uri from torba.baseledger import BaseLedger +from .resolve import Resolver from .account import Account from .network import Network from .database import WalletDatabase @@ -25,15 +22,17 @@ class MainNetLedger(BaseLedger): symbol = 'LBC' network_name = 'mainnet' + headers: Headers + account_class = Account database_class = WalletDatabase headers_class = Headers network_class = Network transaction_class = Transaction - secret_prefix = int2byte(0x1c) - pubkey_address_prefix = int2byte(0x55) - script_address_prefix = int2byte(0x7a) + secret_prefix = bytes((0x1c,)) + pubkey_address_prefix = bytes((0x55,)) + script_address_prefix = bytes((0x7a,)) extended_public_key_prefix = unhexlify('0488b21e') extended_private_key_prefix = unhexlify('0488ade4') @@ -54,45 +53,38 @@ class MainNetLedger(BaseLedger): return Resolver(self.headers.claim_trie_root, self.headers.height, self.transaction_class, hash160_to_address=self.hash160_to_address, network=self.network) - @defer.inlineCallbacks - def resolve(self, page, page_size, *uris): + async def resolve(self, page, page_size, *uris): for uri in uris: try: parse_lbry_uri(uri) except URIParseError as err: - defer.returnValue({'error': err.message}) - resolutions = yield self.network.get_values_for_uris(self.headers.hash().decode(), *uris) - return (yield self.resolver._handle_resolutions(resolutions, uris, page, page_size)) + return {'error': err.args[0]} + resolutions = await self.network.get_values_for_uris(self.headers.hash().decode(), *uris) + return await self.resolver._handle_resolutions(resolutions, uris, page, page_size) - @defer.inlineCallbacks - def get_claim_by_claim_id(self, claim_id): - result = (yield self.network.get_claims_by_ids(claim_id)).pop(claim_id, {}) - return (yield self.resolver.get_certificate_and_validate_result(result)) + async def get_claim_by_claim_id(self, claim_id): + result = (await self.network.get_claims_by_ids(claim_id)).pop(claim_id, {}) + return await self.resolver.get_certificate_and_validate_result(result) - @defer.inlineCallbacks - def get_claim_by_outpoint(self, txid, nout): - claims = (yield self.network.get_claims_in_tx(txid)) or [] + async def get_claim_by_outpoint(self, txid, nout): + claims = (await self.network.get_claims_in_tx(txid)) or [] for claim in claims: if claim['nout'] == nout: - return (yield self.resolver.get_certificate_and_validate_result(claim)) + return await self.resolver.get_certificate_and_validate_result(claim) return 'claim not found' - @defer.inlineCallbacks - def start(self): - yield super().start() - yield defer.DeferredList([ - a.maybe_migrate_certificates() for a in self.accounts - ]) - yield defer.DeferredList([a.save_max_gap() for a in self.accounts]) - yield self._report_state() + async def start(self): + await super().start() + await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) + await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) + await self._report_state() - @defer.inlineCallbacks - def _report_state(self): + async def _report_state(self): for account in self.accounts: - total_receiving = len((yield account.receiving.get_addresses())) - total_change = len((yield account.change.get_addresses())) - channel_count = yield account.get_channel_count() - claim_count = yield account.get_claim_count() + total_receiving = len((await account.receiving.get_addresses())) + total_change = len((await account.change.get_addresses())) + channel_count = await account.get_channel_count() + claim_count = await account.get_claim_count() log.info("Loaded account %s with %d receiving addresses (gap: %d), " "%d change addresses (gap: %d), %d channels, %d certificates and %d claims.", account.id, total_receiving, account.receiving.gap, total_change, account.change.gap, @@ -101,8 +93,8 @@ class MainNetLedger(BaseLedger): class TestNetLedger(MainNetLedger): network_name = 'testnet' - pubkey_address_prefix = int2byte(111) - script_address_prefix = int2byte(196) + pubkey_address_prefix = bytes((111,)) + script_address_prefix = bytes((196,)) extended_public_key_prefix = unhexlify('043587cf') extended_private_key_prefix = unhexlify('04358394') @@ -110,8 +102,8 @@ class TestNetLedger(MainNetLedger): class RegTestLedger(MainNetLedger): network_name = 'regtest' headers_class = UnvalidatedHeaders - pubkey_address_prefix = int2byte(111) - script_address_prefix = int2byte(196) + pubkey_address_prefix = bytes((111,)) + script_address_prefix = bytes((196,)) extended_public_key_prefix = unhexlify('043587cf') extended_private_key_prefix = unhexlify('04358394') diff --git a/lbrynet/wallet/manager.py b/lbrynet/wallet/manager.py index 85e8f9dfa..3f4fa6d8b 100644 --- a/lbrynet/wallet/manager.py +++ b/lbrynet/wallet/manager.py @@ -6,8 +6,6 @@ from binascii import unhexlify from datetime import datetime from typing import Optional -from twisted.internet import defer - from lbryschema.schema import SECP256k1 from torba.basemanager import BaseWalletManager @@ -73,7 +71,7 @@ class LbryWalletManager(BaseWalletManager): return not self.default_account.encrypted def check_locked(self): - return defer.succeed(self.default_account.encrypted) + return self.default_account.encrypted def decrypt_account(self, account): assert account.password is not None, "account is not unlocked" @@ -157,8 +155,7 @@ class LbryWalletManager(BaseWalletManager): return receiving_addresses, change_addresses @classmethod - @defer.inlineCallbacks - def from_lbrynet_config(cls, settings, db): + async def from_lbrynet_config(cls, settings, db): ledger_id = { 'lbrycrd_main': 'lbc_mainnet', @@ -194,17 +191,16 @@ class LbryWalletManager(BaseWalletManager): if receiving_addresses or change_addresses: if not os.path.exists(ledger.path): os.mkdir(ledger.path) - yield ledger.db.open() + await ledger.db.open() try: - yield manager._migrate_addresses(receiving_addresses, change_addresses) + await manager._migrate_addresses(receiving_addresses, change_addresses) finally: - yield ledger.db.close() - defer.returnValue(manager) + await ledger.db.close() + return manager - @defer.inlineCallbacks - def _migrate_addresses(self, receiving_addresses: set, change_addresses: set): - migrated_receiving = set((yield self.default_account.receiving.generate_keys(0, len(receiving_addresses)))) - migrated_change = set((yield self.default_account.change.generate_keys(0, len(change_addresses)))) + async def _migrate_addresses(self, receiving_addresses: set, change_addresses: set): + migrated_receiving = set((await self.default_account.receiving.generate_keys(0, len(receiving_addresses)))) + migrated_change = set((await self.default_account.change.generate_keys(0, len(change_addresses)))) receiving_addresses = set(map(self.default_account.ledger.public_key_to_address, receiving_addresses)) change_addresses = set(map(self.default_account.ledger.public_key_to_address, change_addresses)) if not any(change_addresses.difference(migrated_change)): @@ -231,25 +227,23 @@ class LbryWalletManager(BaseWalletManager): # TODO: check if we have enough to cover amount return ReservedPoints(address, amount) - @defer.inlineCallbacks - def send_amount_to_address(self, amount: int, destination_address: bytes, account=None): + async def send_amount_to_address(self, amount: int, destination_address: bytes, account=None): account = account or self.default_account - tx = yield Transaction.pay(amount, destination_address, [account], account) - yield account.ledger.broadcast(tx) + tx = await Transaction.pay(amount, destination_address, [account], account) + await account.ledger.broadcast(tx) return tx - @defer.inlineCallbacks - def send_claim_to_address(self, claim_id: str, destination_address: str, amount: Optional[int], + async def send_claim_to_address(self, claim_id: str, destination_address: str, amount: Optional[int], account=None): account = account or self.default_account - claims = account.ledger.db.get_utxos(claim_id=claim_id) + claims = await account.ledger.db.get_utxos(claim_id=claim_id) if not claims: raise NameError("Claim not found: {}".format(claim_id)) - tx = yield Transaction.update( + tx = await Transaction.update( claims[0], ClaimDict.deserialize(claims[0].script.value['claim']), amount, destination_address.encode(), [account], account ) - yield self.ledger.broadcast(tx) + await self.ledger.broadcast(tx) return tx def send_points_to_address(self, reserved: ReservedPoints, amount: int, account=None): @@ -262,23 +256,21 @@ class LbryWalletManager(BaseWalletManager): def get_info_exchanger(self): return LBRYcrdAddressRequester(self) - @defer.inlineCallbacks - def resolve(self, *uris, **kwargs): + async def resolve(self, *uris, **kwargs): page = kwargs.get('page', 0) page_size = kwargs.get('page_size', 10) check_cache = kwargs.get('check_cache', False) # TODO: put caching back (was force_refresh parameter) - ledger = self.default_account.ledger # type: MainNetLedger - results = yield ledger.resolve(page, page_size, *uris) - yield self.old_db.save_claims_for_resolve( + ledger: MainNetLedger = self.default_account.ledger + results = await ledger.resolve(page, page_size, *uris) + await self.old_db.save_claims_for_resolve( (value for value in results.values() if 'error' not in value)) - defer.returnValue(results) + return results def get_claims_for_name(self, name: str): return self.ledger.network.get_claims_for_name(name) - @defer.inlineCallbacks - def address_is_mine(self, unknown_address, account): - match = yield self.ledger.db.get_address(address=unknown_address, account=account) + async def address_is_mine(self, unknown_address, account): + match = await self.ledger.db.get_address(address=unknown_address, account=account) if match is not None: return True return False @@ -287,10 +279,9 @@ class LbryWalletManager(BaseWalletManager): return self.default_account.ledger.get_transaction(txid) @staticmethod - @defer.inlineCallbacks - def get_history(account: BaseAccount, **constraints): + async def get_history(account: BaseAccount, **constraints): headers = account.ledger.headers - txs = (yield account.get_transactions(**constraints)) + txs = await account.get_transactions(**constraints) history = [] for tx in txs: ts = headers[tx.height]['timestamp'] @@ -346,29 +337,28 @@ class LbryWalletManager(BaseWalletManager): def get_utxos(account: BaseAccount): return account.get_utxos() - @defer.inlineCallbacks - def claim_name(self, name, amount, claim_dict, certificate=None, claim_address=None): + async def claim_name(self, name, amount, claim_dict, certificate=None, claim_address=None): account = self.default_account claim = ClaimDict.load_dict(claim_dict) if not claim_address: - claim_address = yield account.receiving.get_or_create_usable_address() + claim_address = await account.receiving.get_or_create_usable_address() if certificate: claim = claim.sign( certificate.private_key, claim_address, certificate.claim_id, curve=SECP256k1 ) - existing_claims = yield account.get_claims(claim_name=name) + existing_claims = await account.get_claims(claim_name=name) if len(existing_claims) == 0: - tx = yield Transaction.claim( + tx = await Transaction.claim( name, claim, amount, claim_address, [account], account ) elif len(existing_claims) == 1: - tx = yield Transaction.update( + tx = await Transaction.update( existing_claims[0], claim, amount, claim_address, [account], account ) else: raise NameError("More than one other claim exists with the name '{}'.".format(name)) - yield account.ledger.broadcast(tx) - yield self.old_db.save_claims([self._old_get_temp_claim_info( + await account.ledger.broadcast(tx) + await self.old_db.save_claims([self._old_get_temp_claim_info( tx, tx.outputs[0], claim_address, claim_dict, name, amount )]) # TODO: release reserved tx outputs in case anything fails by this point @@ -387,43 +377,39 @@ class LbryWalletManager(BaseWalletManager): "claim_sequence": -1, } - @defer.inlineCallbacks - def support_claim(self, claim_name, claim_id, amount, account): - holding_address = yield account.receiving.get_or_create_usable_address() - tx = yield Transaction.support(claim_name, claim_id, amount, holding_address, [account], account) - yield account.ledger.broadcast(tx) + async def support_claim(self, claim_name, claim_id, amount, account): + holding_address = await account.receiving.get_or_create_usable_address() + tx = await Transaction.support(claim_name, claim_id, amount, holding_address, [account], account) + await account.ledger.broadcast(tx) return tx - @defer.inlineCallbacks - def tip_claim(self, amount, claim_id, account): - claim_to_tip = yield self.get_claim_by_claim_id(claim_id) - tx = yield Transaction.support( + async def tip_claim(self, amount, claim_id, account): + claim_to_tip = await self.get_claim_by_claim_id(claim_id) + tx = await Transaction.support( claim_to_tip['name'], claim_id, amount, claim_to_tip['address'], [account], account ) - yield account.ledger.broadcast(tx) + await account.ledger.broadcast(tx) return tx - @defer.inlineCallbacks - def abandon_claim(self, claim_id, txid, nout, account): - claim = yield account.get_claim(claim_id=claim_id, txid=txid, nout=nout) + async def abandon_claim(self, claim_id, txid, nout, account): + claim = await account.get_claim(claim_id=claim_id, txid=txid, nout=nout) if not claim: raise Exception('No claim found for the specified claim_id or txid:nout') - tx = yield Transaction.abandon(claim, [account], account) - yield account.ledger.broadcast(tx) + tx = await Transaction.abandon(claim, [account], account) + await account.ledger.broadcast(tx) # TODO: release reserved tx outputs in case anything fails by this point - defer.returnValue(tx) + return tx - @defer.inlineCallbacks - def claim_new_channel(self, channel_name, amount): + async def claim_new_channel(self, channel_name, amount): account = self.default_account - address = yield account.receiving.get_or_create_usable_address() + address = await account.receiving.get_or_create_usable_address() cert, key = generate_certificate() - tx = yield Transaction.claim(channel_name, cert, amount, address, [account], account) - yield account.ledger.broadcast(tx) + tx = await Transaction.claim(channel_name, cert, amount, address, [account], account) + await account.ledger.broadcast(tx) account.add_certificate_private_key(tx.outputs[0].ref, key.decode()) # TODO: release reserved tx outputs in case anything fails by this point - defer.returnValue(tx) + return tx def get_certificates(self, private_key_accounts, exclude_without_key=True, **constraints): return self.db.get_certificates( @@ -443,7 +429,7 @@ class LbryWalletManager(BaseWalletManager): pass # TODO: Data payments is disabled def send_points(self, reserved_points, amount): - defer.succeed(True) # TODO: Data payments is disabled + return True # TODO: Data payments is disabled def cancel_point_reservation(self, reserved_points): pass # fixme: disabled for now. diff --git a/lbrynet/wallet/resolve.py b/lbrynet/wallet/resolve.py index 7da7d52d2..d56a7c1f0 100644 --- a/lbrynet/wallet/resolve.py +++ b/lbrynet/wallet/resolve.py @@ -3,8 +3,6 @@ import logging from ecdsa import BadSignatureError from binascii import unhexlify, hexlify -from twisted.internet import defer - from lbrynet.core.Error import UnknownNameError, UnknownClaimID, UnknownURI, UnknownOutpoint from lbryschema.address import is_address from lbryschema.claim import ClaimDict @@ -25,24 +23,22 @@ class Resolver: self.hash160_to_address = hash160_to_address self.network = network - @defer.inlineCallbacks - def _handle_resolutions(self, resolutions, requested_uris, page, page_size): + async def _handle_resolutions(self, resolutions, requested_uris, page, page_size): results = {} for uri in requested_uris: resolution = (resolutions or {}).get(uri, {}) if resolution: try: results[uri] = _handle_claim_result( - (yield self._handle_resolve_uri_response(uri, resolution, page, page_size)) + await self._handle_resolve_uri_response(uri, resolution, page, page_size) ) except (UnknownNameError, UnknownClaimID, UnknownURI) as err: results[uri] = {'error': str(err)} else: results[uri] = {'error': "URI lbry://{} cannot be resolved".format(uri.replace("lbry://", ""))} - defer.returnValue(results) + return results - @defer.inlineCallbacks - def _handle_resolve_uri_response(self, uri, resolution, page=0, page_size=10, raw=False): + async def _handle_resolve_uri_response(self, uri, resolution, page=0, page_size=10, raw=False): result = {} claim_trie_root = self.claim_trie_root parsed_uri = parse_lbry_uri(uri) @@ -120,21 +116,21 @@ class Resolver: elif 'unverified_claims_for_name' in resolution and 'certificate' in result: unverified_claims_for_name = resolution['unverified_claims_for_name'] - channel_info = yield self.get_channel_claims_page(unverified_claims_for_name, + channel_info = await self.get_channel_claims_page(unverified_claims_for_name, result['certificate'], page=1) claims_in_channel, upper_bound = channel_info - if len(claims_in_channel) > 1: - log.error("Multiple signed claims for the same name") - elif not claims_in_channel: + if not claims_in_channel: log.error("No valid claims for this name for this channel") + elif len(claims_in_channel) > 1: + log.error("Multiple signed claims for the same name") else: result['claim'] = claims_in_channel[0] # parse and validate claims in a channel iteratively into pages of results elif 'unverified_claims_in_channel' in resolution and 'certificate' in result: ids_to_check = resolution['unverified_claims_in_channel'] - channel_info = yield self.get_channel_claims_page(ids_to_check, result['certificate'], + channel_info = await self.get_channel_claims_page(ids_to_check, result['certificate'], page=page, page_size=page_size) claims_in_channel, upper_bound = channel_info @@ -145,16 +141,15 @@ class Resolver: result['success'] = False result['uri'] = str(parsed_uri) - defer.returnValue(result) + return result - @defer.inlineCallbacks - def get_certificate_and_validate_result(self, claim_result): + async def get_certificate_and_validate_result(self, claim_result): if not claim_result or 'value' not in claim_result: return claim_result certificate = None certificate_id = smart_decode(claim_result['value']).certificate_id if certificate_id: - certificate = yield self.network.get_claims_by_ids(certificate_id.decode()) + certificate = await self.network.get_claims_by_ids(certificate_id.decode()) certificate = certificate.pop(certificate_id.decode()) if certificate else None return self.parse_and_validate_claim_result(claim_result, certificate=certificate) @@ -227,8 +222,7 @@ class Resolver: abs_position += 1 return queries, names, absolute_position_index - @defer.inlineCallbacks - def iter_channel_claims_pages(self, queries, claim_positions, claim_names, certificate, + async def iter_channel_claims_pages(self, queries, claim_positions, claim_names, certificate, page_size=10): # lbryum server returns a dict of {claim_id: (name, claim_height)} # first, sort the claims by block height (and by claim id int value within a block). @@ -243,11 +237,10 @@ class Resolver: # processed them. # TODO: fix ^ in lbryschema - @defer.inlineCallbacks - def iter_validate_channel_claims(): + async def iter_validate_channel_claims(): formatted_claims = [] for claim_ids in queries: - batch_result = yield self.network.get_claims_by_ids(*claim_ids) + batch_result = await self.network.get_claims_by_ids(*claim_ids) for claim_id in claim_ids: claim = batch_result[claim_id] if claim['name'] == claim_names[claim_id]: @@ -258,25 +251,20 @@ class Resolver: else: log.warning("ignoring claim with name mismatch %s %s", claim['name'], claim['claim_id']) - defer.returnValue(formatted_claims) + return formatted_claims - yielded_page = False results = [] - for claim in (yield iter_validate_channel_claims()): + for claim in (await iter_validate_channel_claims()): results.append(claim) # if there is a full page of results, yield it if len(results) and len(results) % page_size == 0: - defer.returnValue(results[-page_size:]) - yielded_page = True + return results[-page_size:] - # if we didn't get a full page of results, yield what results we did get - if not yielded_page: - defer.returnValue(results) + return results - @defer.inlineCallbacks - def get_channel_claims_page(self, channel_claim_infos, certificate, page, page_size=10): + async def get_channel_claims_page(self, channel_claim_infos, certificate, page, page_size=10): page = page or 0 page_size = max(page_size, 1) if page_size > 500: @@ -284,14 +272,14 @@ class Resolver: start_position = (page - 1) * page_size queries, names, claim_positions = self.prepare_claim_queries(start_position, page_size, channel_claim_infos) - page_generator = yield self.iter_channel_claims_pages(queries, claim_positions, names, + page_generator = await self.iter_channel_claims_pages(queries, claim_positions, names, certificate, page_size=page_size) upper_bound = len(claim_positions) if not page: - defer.returnValue((None, upper_bound)) + return None, upper_bound if start_position > upper_bound: raise IndexError("claim %i greater than max %i" % (start_position, upper_bound)) - defer.returnValue((page_generator, upper_bound)) + return page_generator, upper_bound # Format amount to be decimal encoded string diff --git a/tests/integration/wallet/test_transactions.py b/tests/integration/wallet/test_transactions.py index 5e5047937..864d22cce 100644 --- a/tests/integration/wallet/test_transactions.py +++ b/tests/integration/wallet/test_transactions.py @@ -1,6 +1,6 @@ import asyncio -from orchstr8.testcase import IntegrationTestCase, d2f +from orchstr8.testcase import IntegrationTestCase from lbryschema.claim import ClaimDict from lbrynet.wallet.transaction import Transaction from lbrynet.wallet.account import generate_certificate @@ -43,9 +43,9 @@ class BasicTransactionTest(IntegrationTestCase): async def test_creating_updating_and_abandoning_claim_with_channel(self): - await d2f(self.account.ensure_address_gap()) + await self.account.ensure_address_gap() - address1, address2 = await d2f(self.account.receiving.get_addresses(limit=2, only_usable=True)) + address1, address2 = await self.account.receiving.get_addresses(limit=2, only_usable=True) sendtxid1 = await self.blockchain.send_to_address(address1, 5) sendtxid2 = await self.blockchain.send_to_address(address2, 5) await self.blockchain.generate(1) @@ -54,13 +54,13 @@ class BasicTransactionTest(IntegrationTestCase): self.on_transaction_id(sendtxid2), ]) - self.assertEqual(d2l(await d2f(self.account.get_balance())), '10.0') + self.assertEqual(d2l(await self.account.get_balance()), '10.0') cert, key = generate_certificate() - cert_tx = await d2f(Transaction.claim('@bar', cert, l2d('1.0'), address1, [self.account], self.account)) + cert_tx = await Transaction.claim('@bar', cert, l2d('1.0'), address1, [self.account], self.account) claim = ClaimDict.load_dict(example_claim_dict) claim = claim.sign(key, address1, cert_tx.outputs[0].claim_id) - claim_tx = await d2f(Transaction.claim('foo', claim, l2d('1.0'), address1, [self.account], self.account)) + claim_tx = await Transaction.claim('foo', claim, l2d('1.0'), address1, [self.account], self.account) await self.broadcast(cert_tx) await self.broadcast(claim_tx) @@ -74,23 +74,23 @@ class BasicTransactionTest(IntegrationTestCase): self.on_transaction(cert_tx), ]) - self.assertEqual(d2l(await d2f(self.account.get_balance(confirmations=1))), '7.985786') - self.assertEqual(d2l(await d2f(self.account.get_balance(include_claims=True))), '9.985786') + self.assertEqual(d2l(await self.account.get_balance(confirmations=1)), '7.985786') + self.assertEqual(d2l(await self.account.get_balance(include_claims=True)), '9.985786') - response = await d2f(self.ledger.resolve(0, 10, 'lbry://@bar/foo')) + response = await self.ledger.resolve(0, 10, 'lbry://@bar/foo') self.assertIn('lbry://@bar/foo', response) self.assertIn('claim', response['lbry://@bar/foo']) - abandon_tx = await d2f(Transaction.abandon([claim_tx.outputs[0]], [self.account], self.account)) + abandon_tx = await Transaction.abandon([claim_tx.outputs[0]], [self.account], self.account) await self.broadcast(abandon_tx) await self.on_transaction(abandon_tx) await self.blockchain.generate(1) await self.on_transaction(abandon_tx) - response = await d2f(self.ledger.resolve(0, 10, 'lbry://@bar/foo')) + response = await self.ledger.resolve(0, 10, 'lbry://@bar/foo') self.assertNotIn('claim', response['lbry://@bar/foo']) # checks for expected format in inexistent URIs - response = await d2f(self.ledger.resolve(0, 10, 'lbry://404', 'lbry://@404')) + response = await self.ledger.resolve(0, 10, 'lbry://404', 'lbry://@404') self.assertEqual('URI lbry://404 cannot be resolved', response['lbry://404']['error']) self.assertEqual('URI lbry://@404 cannot be resolved', response['lbry://@404']['error']) diff --git a/tests/unit/wallet/test_account.py b/tests/unit/wallet/test_account.py index af05828ca..30bf7a35b 100644 --- a/tests/unit/wallet/test_account.py +++ b/tests/unit/wallet/test_account.py @@ -1,23 +1,24 @@ -from twisted.trial import unittest -from twisted.internet import defer +from orchstr8.testcase import AsyncioTestCase +from torba.wallet import Wallet from lbrynet.wallet.ledger import MainNetLedger, WalletDatabase from lbrynet.wallet.header import Headers from lbrynet.wallet.account import Account -from torba.wallet import Wallet -class TestAccount(unittest.TestCase): +class TestAccount(AsyncioTestCase): - def setUp(self): + async def asyncSetUp(self): self.ledger = MainNetLedger({ 'db': WalletDatabase(':memory:'), 'headers': Headers(':memory:') }) - return self.ledger.db.open() + await self.ledger.db.open() - @defer.inlineCallbacks - def test_generate_account(self): + async def asyncTearDown(self): + await self.ledger.db.close() + + async def test_generate_account(self): account = Account.generate(self.ledger, Wallet(), 'lbryum') self.assertEqual(account.ledger, self.ledger) self.assertIsNotNone(account.seed) @@ -27,20 +28,19 @@ class TestAccount(unittest.TestCase): self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.private_key.public_key, account.public_key) - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 0) - addresses = yield account.change.get_addresses() + addresses = await account.change.get_addresses() self.assertEqual(len(addresses), 0) - yield account.ensure_address_gap() + await account.ensure_address_gap() - addresses = yield account.receiving.get_addresses() + addresses = await account.receiving.get_addresses() self.assertEqual(len(addresses), 20) - addresses = yield account.change.get_addresses() + addresses = await account.change.get_addresses() self.assertEqual(len(addresses), 6) - @defer.inlineCallbacks - def test_generate_account_from_seed(self): + async def test_generate_account_from_seed(self): account = Account.from_dict( self.ledger, Wallet(), { "seed": @@ -58,16 +58,16 @@ class TestAccount(unittest.TestCase): 'xpub661MyMwAqRbcGWtPvbWh9sc2BCfw2cTeVDYF23o3N1t6UZ5wv3EMmDgp66FxH' 'uDtWdft3B5eL5xQtyzAtkdmhhC95gjRjLzSTdkho95asu9' ) - address = yield account.receiving.ensure_address_gap() + address = await account.receiving.ensure_address_gap() self.assertEqual(address[0], 'bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx') - private_key = yield self.ledger.get_private_key_for_address('bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx') + private_key = await self.ledger.get_private_key_for_address('bCqJrLHdoiRqEZ1whFZ3WHNb33bP34SuGx') self.assertEqual( private_key.extended_key_string(), 'xprv9vwXVierUTT4hmoe3dtTeBfbNv1ph2mm8RWXARU6HsZjBaAoFaS2FRQu4fptR' 'AyJWhJW42dmsEaC1nKnVKKTMhq3TVEHsNj1ca3ciZMKktT' ) - private_key = yield self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') + private_key = await self.ledger.get_private_key_for_address('BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') self.assertIsNone(private_key) def test_load_and_save_account(self): diff --git a/tests/unit/wallet/test_claim_proofs.py b/tests/unit/wallet/test_claim_proofs.py index 4b288c799..aa547b3f4 100644 --- a/tests/unit/wallet/test_claim_proofs.py +++ b/tests/unit/wallet/test_claim_proofs.py @@ -1,5 +1,5 @@ -from binascii import hexlify, unhexlify import unittest +from binascii import hexlify, unhexlify from lbrynet.wallet.claim_proofs import get_hash_for_outpoint, verify_proof from lbryschema.hashing import double_sha256 diff --git a/tests/unit/wallet/test_dewies.py b/tests/unit/wallet/test_dewies.py index ca0177e13..a29a77680 100644 --- a/tests/unit/wallet/test_dewies.py +++ b/tests/unit/wallet/test_dewies.py @@ -1,4 +1,5 @@ -from twisted.trial import unittest +import unittest + from lbrynet.wallet.dewies import lbc_to_dewies as l2d, dewies_to_lbc as d2l diff --git a/tests/unit/wallet/test_headers.py b/tests/unit/wallet/test_headers.py index 45da17f72..872c5bac8 100644 --- a/tests/unit/wallet/test_headers.py +++ b/tests/unit/wallet/test_headers.py @@ -1,15 +1,12 @@ -from io import BytesIO from binascii import unhexlify -from twisted.trial import unittest -from twisted.internet import defer +from orchstr8.testcase import AsyncioTestCase +from torba.util import ArithUint256 from lbrynet.wallet.ledger import Headers -from torba.util import ArithUint256 - -class TestHeaders(unittest.TestCase): +class TestHeaders(AsyncioTestCase): def test_deserialize(self): self.maxDiff = None @@ -36,19 +33,17 @@ class TestHeaders(unittest.TestCase): 'version': 536870912 }) - @defer.inlineCallbacks - def test_connect_from_genesis(self): + async def test_connect_from_genesis(self): headers = Headers(':memory:') self.assertEqual(headers.height, -1) - yield headers.connect(0, HEADERS) + await headers.connect(0, HEADERS) self.assertEqual(headers.height, 19) - @defer.inlineCallbacks - def test_connect_from_middle(self): + async def test_connect_from_middle(self): h = Headers(':memory:') h.io.write(HEADERS[:10*Headers.header_size]) self.assertEqual(h.height, 9) - yield h.connect(len(h), HEADERS[10*Headers.header_size:20*Headers.header_size]) + await h.connect(len(h), HEADERS[10*Headers.header_size:20*Headers.header_size]) self.assertEqual(h.height, 19) def test_target_calculation(self): diff --git a/tests/unit/wallet/test_ledger.py b/tests/unit/wallet/test_ledger.py index 4a587130d..5cd097696 100644 --- a/tests/unit/wallet/test_ledger.py +++ b/tests/unit/wallet/test_ledger.py @@ -1,79 +1,68 @@ -from twisted.internet import defer -from twisted.trial import unittest +from orchstr8.testcase import AsyncioTestCase +from torba.wallet import Wallet + from lbrynet.wallet.account import Account from lbrynet.wallet.transaction import Transaction, Output, Input from lbrynet.wallet.ledger import MainNetLedger -from torba.wallet import Wallet -class LedgerTestCase(unittest.TestCase): +class LedgerTestCase(AsyncioTestCase): - def setUp(self): - super().setUp() + async def asyncSetUp(self): self.ledger = MainNetLedger({ 'db': MainNetLedger.database_class(':memory:'), 'headers': MainNetLedger.headers_class(':memory:') }) self.account = Account.generate(self.ledger, Wallet(), "lbryum") - return self.ledger.db.open() + await self.ledger.db.open() - def tearDown(self): - super().tearDown() - return self.ledger.db.close() + async def asyncTearDown(self): + await self.ledger.db.close() class BasicAccountingTests(LedgerTestCase): - @defer.inlineCallbacks - def test_empty_state(self): - balance = yield self.account.get_balance() - self.assertEqual(balance, 0) + async def test_empty_state(self): + self.assertEqual(await self.account.get_balance(), 0) - @defer.inlineCallbacks - def test_balance(self): - address = yield self.account.receiving.get_or_create_usable_address() + async def test_balance(self): + address = await self.account.receiving.get_or_create_usable_address() hash160 = self.ledger.address_to_hash160(address) tx = Transaction(is_verified=True)\ .add_outputs([Output.pay_pubkey_hash(100, hash160)]) - yield self.ledger.db.save_transaction_io( + await self.ledger.db.save_transaction_io( 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1) ) - balance = yield self.account.get_balance(0) - self.assertEqual(balance, 100) + self.assertEqual(await self.account.get_balance(), 100) tx = Transaction(is_verified=True)\ .add_outputs([Output.pay_claim_name_pubkey_hash(100, 'foo', b'', hash160)]) - yield self.ledger.db.save_transaction_io( + await self.ledger.db.save_transaction_io( 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1) ) - balance = yield self.account.get_balance(0) - self.assertEqual(balance, 100) # claim names don't count towards balance - balance = yield self.account.get_balance(0, include_claims=True) - self.assertEqual(balance, 200) + self.assertEqual(await self.account.get_balance(), 100) # claim names don't count towards balance + self.assertEqual(await self.account.get_balance(include_claims=True), 200) - @defer.inlineCallbacks - def test_get_utxo(self): + async def test_get_utxo(self): address = yield self.account.receiving.get_or_create_usable_address() hash160 = self.ledger.address_to_hash160(address) tx = Transaction(is_verified=True)\ .add_outputs([Output.pay_pubkey_hash(100, hash160)]) - yield self.ledger.db.save_transaction_io( + await self.ledger.db.save_transaction_io( 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1) ) - utxos = yield self.account.get_utxos() + utxos = await self.account.get_utxos() self.assertEqual(len(utxos), 1) tx = Transaction(is_verified=True)\ .add_inputs([Input.spend(utxos[0])]) - yield self.ledger.db.save_transaction_io( + await self.ledger.db.save_transaction_io( 'insert', tx, address, hash160, '{}:{}:'.format(tx.id, 1) ) - balance = yield self.account.get_balance(0, include_claims=True) - self.assertEqual(balance, 0) + self.assertEqual(await self.account.get_balance(include_claims=True), 0) - utxos = yield self.account.get_utxos() + utxos = await self.account.get_utxos() self.assertEqual(len(utxos), 0) - diff --git a/tests/unit/wallet/test_script.py b/tests/unit/wallet/test_script.py index 3e2dde27b..5954d61e6 100644 --- a/tests/unit/wallet/test_script.py +++ b/tests/unit/wallet/test_script.py @@ -1,5 +1,5 @@ +import unittest from binascii import hexlify, unhexlify -from twisted.trial import unittest from lbrynet.wallet.script import OutputScript diff --git a/tests/unit/wallet/test_transaction.py b/tests/unit/wallet/test_transaction.py index 3cab170cd..f3359ac22 100644 --- a/tests/unit/wallet/test_transaction.py +++ b/tests/unit/wallet/test_transaction.py @@ -1,7 +1,7 @@ +import unittest from binascii import hexlify, unhexlify -from twisted.trial import unittest -from twisted.internet import defer +from orchstr8.testcase import AsyncioTestCase from torba.constants import CENT, COIN, NULL_HASH32 from torba.wallet import Wallet @@ -35,19 +35,17 @@ def get_claim_transaction(claim_name, claim=b''): ) -class TestSizeAndFeeEstimation(unittest.TestCase): +class TestSizeAndFeeEstimation(AsyncioTestCase): - def setUp(self): - super().setUp() + async def asyncSetUp(self): self.ledger = MainNetLedger({ 'db': MainNetLedger.database_class(':memory:'), 'headers': MainNetLedger.headers_class(':memory:') }) - return self.ledger.db.open() + await self.ledger.db.open() - def tearDown(self): - super().tearDown() - return self.ledger.db.close() + async def asyncTearDown(self): + await self.ledger.db.close() def test_output_size_and_fee(self): txo = get_output() @@ -219,22 +217,19 @@ class TestTransactionSerialization(unittest.TestCase): self.assertEqual(tx.raw, raw) -class TestTransactionSigning(unittest.TestCase): +class TestTransactionSigning(AsyncioTestCase): - def setUp(self): - super().setUp() + async def asyncSetUp(self): self.ledger = MainNetLedger({ 'db': MainNetLedger.database_class(':memory:'), 'headers': MainNetLedger.headers_class(':memory:') }) - return self.ledger.db.open() + await self.ledger.db.open() - def tearDown(self): - super().tearDown() - return self.ledger.db.close() + async def asyncTearDown(self): + await self.ledger.db.close() - @defer.inlineCallbacks - def test_sign(self): + async def test_sign(self): account = self.ledger.account_class.from_dict( self.ledger, Wallet(), { "seed": @@ -243,8 +238,8 @@ class TestTransactionSigning(unittest.TestCase): } ) - yield account.ensure_address_gap() - address1, address2 = yield account.receiving.get_addresses(limit=2) + await account.ensure_address_gap() + address1, address2 = await account.receiving.get_addresses(limit=2) pubkey_hash1 = self.ledger.address_to_hash160(address1) pubkey_hash2 = self.ledger.address_to_hash160(address2) @@ -252,7 +247,7 @@ class TestTransactionSigning(unittest.TestCase): .add_inputs([Input.spend(get_output(int(2*COIN), pubkey_hash1))]) \ .add_outputs([Output.pay_pubkey_hash(int(1.9*COIN), pubkey_hash2)]) - yield tx.sign([account]) + await tx.sign([account]) self.assertEqual( hexlify(tx.inputs[0].script.values['signature']),