diff --git a/lbry/wallet/__init__.py b/lbry/wallet/__init__.py index 31c2ba60c..e69de29bb 100644 --- a/lbry/wallet/__init__.py +++ b/lbry/wallet/__init__.py @@ -1,17 +0,0 @@ -__node_daemon__ = 'lbrycrdd' -__node_cli__ = 'lbrycrd-cli' -__node_bin__ = '' -__node_url__ = ( - 'https://github.com/lbryio/lbrycrd/releases/download/v0.17.4.4/lbrycrd-linux-1744.zip' -) -__spvserver__ = 'lbry.wallet.server.coin.LBCRegTest' - -from .bip32 import PubKey -from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK -from .manager import WalletManager -from .network import Network -from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent -from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic -from .transaction import Transaction, Output, Input -from .script import OutputScript, InputScript -from .header import Headers diff --git a/lbry/wallet/account.py b/lbry/wallet/account.py index 2c9c32b8a..31fa614de 100644 --- a/lbry/wallet/account.py +++ b/lbry/wallet/account.py @@ -2,7 +2,6 @@ import os import time import json import logging -import typing import asyncio import random from functools import partial @@ -16,14 +15,15 @@ import ecdsa from lbry.error import InvalidPasswordError from lbry.crypto.crypt import aes_encrypt, aes_decrypt -from .bip32 import PrivateKey, PubKey, from_extended_key_string -from .mnemonic import Mnemonic -from .constants import COIN, CLAIM_TYPES, TXO_TYPES -from .transaction import Transaction, Input, Output +from lbry.crypto.bip32 import PrivateKey, PubKey, from_extended_key_string +from lbry.constants import COIN +from lbry.blockchain.transaction import Transaction, Input, Output +from lbry.blockchain.ledger import Ledger +from lbry.db import Database +from lbry.db.constants import CLAIM_TYPE_CODES, TXO_TYPES + +from .mnemonic import Mnemonic -if typing.TYPE_CHECKING: - from .ledger import Ledger - from .wallet import Wallet log = logging.getLogger(__name__) @@ -71,12 +71,12 @@ class AddressManager: def to_dict_instance(self) -> Optional[dict]: raise NotImplementedError - def _query_addresses(self, **constraints): - return self.account.ledger.db.get_addresses( - accounts=[self.account], + async def _query_addresses(self, **constraints): + return (await self.account.db.get_addresses( + account=self.account, chain=self.chain_number, **constraints - ) + ))[0] def get_private_key(self, index: int) -> PrivateKey: raise NotImplementedError @@ -166,22 +166,22 @@ class HierarchicalDeterministic(AddressManager): start = addresses[0]['pubkey'].n+1 if addresses else 0 end = start + (self.gap - existing_gap) new_keys = await self._generate_keys(start, end-1) - await self.account.ledger.announce_addresses(self, new_keys) + #await self.account.ledger.announce_addresses(self, new_keys) return new_keys async def _generate_keys(self, start: int, end: int) -> List[str]: if not self.address_generator_lock.locked(): raise RuntimeError('Should not be called outside of address_generator_lock.') keys = [self.public_key.child(index) for index in range(start, end+1)] - await self.account.ledger.db.add_keys(self.account, self.chain_number, keys) + await self.account.db.add_keys(self.account, self.chain_number, keys) return [key.address for key in keys] - def get_address_records(self, only_usable: bool = False, **constraints): + async def get_address_records(self, only_usable: bool = False, **constraints): if only_usable: constraints['used_times__lt'] = self.maximum_uses_per_address if 'order_by' not in constraints: constraints['order_by'] = "used_times asc, n asc" - return self._query_addresses(**constraints) + return await self._query_addresses(**constraints) class SingleKey(AddressManager): @@ -213,14 +213,14 @@ class SingleKey(AddressManager): async with self.address_generator_lock: exists = await self.get_address_records() if not exists: - await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key]) + await self.account.db.add_keys(self.account, self.chain_number, [self.public_key]) new_keys = [self.public_key.address] - await self.account.ledger.announce_addresses(self, new_keys) + #await self.account.ledger.announce_addresses(self, new_keys) return new_keys return [] - def get_address_records(self, only_usable: bool = False, **constraints): - return self._query_addresses(**constraints) + async def get_address_records(self, only_usable: bool = False, **constraints): + return await self._query_addresses(**constraints) class Account: @@ -233,12 +233,12 @@ class Account: HierarchicalDeterministic.name: HierarchicalDeterministic, } - def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str, + def __init__(self, ledger: 'Ledger', db: 'Database', name: str, seed: str, private_key_string: str, encrypted: bool, private_key: Optional[PrivateKey], public_key: PubKey, address_generator: dict, modified_on: float, channel_keys: dict) -> None: self.ledger = ledger - self.wallet = wallet + self.db = db self.id = public_key.address self.name = name self.seed = seed @@ -253,8 +253,6 @@ class Account: self.receiving, self.change = self.address_generator.from_dict(self, address_generator) self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}} self.channel_keys = channel_keys - ledger.add_account(self) - wallet.add_account(self) def get_init_vector(self, key) -> Optional[bytes]: init_vector = self.init_vectors.get(key, None) @@ -263,9 +261,9 @@ class Account: return init_vector @classmethod - def generate(cls, ledger: 'Ledger', wallet: 'Wallet', + def generate(cls, ledger: 'Ledger', db: 'Database', name: str = None, address_generator: dict = None): - return cls.from_dict(ledger, wallet, { + return cls.from_dict(ledger, db, { 'name': name, 'seed': cls.mnemonic_class().make_seed(), 'address_generator': address_generator or {} @@ -297,14 +295,14 @@ class Account: return seed, private_key, public_key @classmethod - def from_dict(cls, ledger: 'Ledger', wallet: 'Wallet', d: dict): + def from_dict(cls, ledger: 'Ledger', db: 'Database', d: dict): seed, private_key, public_key = cls.keys_from_dict(ledger, d) name = d.get('name') if not name: name = f'Account #{public_key.address}' return cls( ledger=ledger, - wallet=wallet, + db=db, name=name, seed=seed, private_key_string=d.get('private_key', ''), @@ -328,7 +326,6 @@ class Account: if seed: seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed')) d = { - 'ledger': self.ledger.get_id(), 'name': self.name, 'seed': seed, 'encrypted': bool(self.encrypted or encrypt_password), @@ -365,7 +362,6 @@ class Account: details = { 'id': self.id, 'name': self.name, - 'ledger': self.ledger.get_id(), 'coins': round(satoshis/COIN, 2), 'satoshis': satoshis, 'encrypted': self.encrypted, @@ -436,15 +432,18 @@ class Account: addresses.extend(new_addresses) return addresses + async def get_address_records(self, **constraints): + return await self.db.get_addresses(account=self, **constraints) + async def get_addresses(self, **constraints) -> List[str]: - rows = await self.ledger.db.select_addresses([text('account_address.address')], accounts=[self], **constraints) + rows, _ = await self.get_address_records(cols=['account_address.address'], **constraints) return [r['address'] for r in rows] - def get_address_records(self, **constraints): - return self.ledger.db.get_addresses(accounts=[self], **constraints) - - def get_address_count(self, **constraints): - return self.ledger.db.get_address_count(accounts=[self], **constraints) + async def get_valid_receiving_address(self, default_address: str) -> str: + if default_address is None: + return await self.receiving.get_or_create_usable_address() + self.ledger.valid_address_or_error(default_address) + return default_address def get_private_key(self, chain: int, index: int) -> PrivateKey: assert not self.encrypted, "Cannot get private key on encrypted wallet account." @@ -459,7 +458,7 @@ class Account: if confirmations > 0: height = self.ledger.headers.height - (confirmations-1) constraints.update({'height__lte': height, 'height__gt': 0}) - return self.ledger.db.get_balance(accounts=[self], **constraints) + return self.db.get_balance(account=self, **constraints) async def get_max_gap(self): change_gap = await self.change.get_max_gap() @@ -469,24 +468,6 @@ class Account: 'max_receiving_gap': receiving_gap, } - def get_txos(self, **constraints): - return self.ledger.get_txos(wallet=self.wallet, accounts=[self], **constraints) - - def get_txo_count(self, **constraints): - return self.ledger.get_txo_count(wallet=self.wallet, accounts=[self], **constraints) - - def get_utxos(self, **constraints): - return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints) - - def get_utxo_count(self, **constraints): - return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints) - - def get_transactions(self, **constraints): - return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints) - - def get_transaction_count(self, **constraints): - return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints) - async def fund(self, to_account, amount=None, everything=False, outputs=1, broadcast=False, **constraints): assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.' @@ -551,9 +532,9 @@ class Account: self.wallet.save() async def save_max_gap(self): + gap_changed = False if issubclass(self.address_generator, HierarchicalDeterministic): gap = await self.get_max_gap() - gap_changed = False new_receiving_gap = max(20, gap['max_receiving_gap'] + 1) if self.receiving.gap != new_receiving_gap: self.receiving.gap = new_receiving_gap @@ -562,8 +543,10 @@ class Account: if self.change.gap != new_change_gap: self.change.gap = new_change_gap gap_changed = True - if gap_changed: - self.wallet.save() + return gap_changed + + def get_support_summary(self): + return self.db.get_supports_summary(account=self) async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False): tips_balance, supports_balance, claims_balance = 0, 0, 0 @@ -571,7 +554,7 @@ class Account: include_claims=True) total = await get_total_balance() if reserved_subtotals: - claims_balance = await get_total_balance(txo_type__in=CLAIM_TYPES) + claims_balance = await get_total_balance(txo_type__in=CLAIM_TYPE_CODES) for txo in await self.get_support_summary(): if confirmations > 0 and not 0 < txo.tx_ref.height <= self.ledger.headers.height - (confirmations - 1): continue @@ -594,49 +577,3 @@ class Account: 'tips': tips_balance } if reserved_subtotals else None } - - def get_transaction_history(self, **constraints): - return self.ledger.get_transaction_history( - wallet=self.wallet, accounts=[self], **constraints - ) - - def get_transaction_history_count(self, **constraints): - return self.ledger.get_transaction_history_count( - wallet=self.wallet, accounts=[self], **constraints - ) - - def get_claims(self, **constraints): - return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints) - - def get_claim_count(self, **constraints): - return self.ledger.get_claim_count(wallet=self.wallet, accounts=[self], **constraints) - - def get_streams(self, **constraints): - return self.ledger.get_streams(wallet=self.wallet, accounts=[self], **constraints) - - def get_stream_count(self, **constraints): - return self.ledger.get_stream_count(wallet=self.wallet, accounts=[self], **constraints) - - def get_channels(self, **constraints): - return self.ledger.get_channels(wallet=self.wallet, accounts=[self], **constraints) - - def get_channel_count(self, **constraints): - return self.ledger.get_channel_count(wallet=self.wallet, accounts=[self], **constraints) - - def get_collections(self, **constraints): - return self.ledger.get_collections(wallet=self.wallet, accounts=[self], **constraints) - - def get_collection_count(self, **constraints): - return self.ledger.get_collection_count(wallet=self.wallet, accounts=[self], **constraints) - - def get_supports(self, **constraints): - return self.ledger.get_supports(wallet=self.wallet, accounts=[self], **constraints) - - def get_support_count(self, **constraints): - return self.ledger.get_support_count(wallet=self.wallet, accounts=[self], **constraints) - - def get_support_summary(self): - return self.ledger.db.get_supports_summary(wallet=self.wallet, accounts=[self]) - - async def release_all_outputs(self): - await self.ledger.db.release_all_outputs(self) diff --git a/lbry/wallet/coinselection.py b/lbry/wallet/coinselection.py index 63bbb6977..2323d1179 100644 --- a/lbry/wallet/coinselection.py +++ b/lbry/wallet/coinselection.py @@ -1,18 +1,32 @@ from random import Random from typing import List -from lbry.wallet.transaction import OutputEffectiveAmountEstimator +from lbry.blockchain.transaction import Input, Output MAXIMUM_TRIES = 100000 -STRATEGIES = [] +COIN_SELECTION_STRATEGIES = [] def strategy(method): - STRATEGIES.append(method.__name__) + COIN_SELECTION_STRATEGIES.append(method.__name__) return method +class OutputEffectiveAmountEstimator: + + __slots__ = 'txo', 'txi', 'fee', 'effective_amount' + + def __init__(self, ledger, txo: Output) -> None: + self.txo = txo + self.txi = Input.spend(txo) + self.fee: int = self.txi.get_fee(ledger) + self.effective_amount: int = txo.amount - self.fee + + def __lt__(self, other): + return self.effective_amount < other.effective_amount + + class CoinSelector: def __init__(self, target: int, cost_of_change: int, seed: str = None) -> None: diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py index 40f14b762..57a2b2d62 100644 --- a/lbry/wallet/manager.py +++ b/lbry/wallet/manager.py @@ -10,16 +10,14 @@ from typing import List, Type, MutableSequence, MutableMapping, Optional from lbry.error import KeyFeeAboveMaxAllowedError from lbry.conf import Config -from .dewies import dewies_to_lbc from .account import Account -from .ledger import Ledger, LedgerRegistry -from .transaction import Transaction, Output -from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK -from .rpc.jsonrpc import CodeMessageError +from lbry.blockchain.dewies import dewies_to_lbc +from lbry.blockchain.ledger import Ledger +from lbry.db import Database +from lbry.blockchain.ledger import Ledger +from lbry.blockchain.transaction import Transaction, Output -if typing.TYPE_CHECKING: - from lbry.db import Database - from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager +from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK log = logging.getLogger(__name__) @@ -27,35 +25,42 @@ log = logging.getLogger(__name__) class WalletManager: - def __init__(self, wallets: MutableSequence[Wallet] = None, + def __init__(self, ledger: Ledger, db: Database, + wallets: MutableSequence[Wallet] = None, ledgers: MutableMapping[Type[Ledger], Ledger] = None) -> None: + self.ledger = ledger + self.db = db self.wallets = wallets or [] self.ledgers = ledgers or {} self.running = False self.config: Optional[Config] = None - @classmethod - def from_config(cls, config: dict) -> 'WalletManager': - manager = cls() - for ledger_id, ledger_config in config.get('ledgers', {}).items(): - manager.get_or_create_ledger(ledger_id, ledger_config) - for wallet_path in config.get('wallets', []): - wallet_storage = WalletStorage(wallet_path) - wallet = Wallet.from_storage(wallet_storage, manager) - manager.wallets.append(wallet) - return manager + async def open(self): + conf = self.ledger.conf - def get_or_create_ledger(self, ledger_id, ledger_config=None): - ledger_class = LedgerRegistry.get_ledger_class(ledger_id) - ledger = self.ledgers.get(ledger_class) - if ledger is None: - ledger = ledger_class(ledger_config or {}) - self.ledgers[ledger_class] = ledger - return ledger + wallets_directory = os.path.join(conf.wallet_dir, 'wallets') + if not os.path.exists(wallets_directory): + os.mkdir(wallets_directory) + + for wallet_file in conf.wallets: + wallet_path = os.path.join(wallets_directory, wallet_file) + wallet_storage = WalletStorage(wallet_path) + wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage) + self.wallets.append(wallet) + + self.ledger.coin_selection_strategy = self.ledger.conf.coin_selection_strategy + default_wallet = self.default_wallet + if default_wallet.default_account is None: + log.info('Wallet at %s is empty, generating a default account.', default_wallet.id) + default_wallet.generate_account() + default_wallet.save() + if default_wallet.is_locked and default_wallet.preferences.get(ENCRYPT_ON_DISK) is None: + default_wallet.preferences[ENCRYPT_ON_DISK] = True + default_wallet.save() def import_wallet(self, path): storage = WalletStorage(path) - wallet = Wallet.from_storage(storage, self) + wallet = Wallet.from_storage(self.ledger, self.db, storage) self.wallets.append(wallet) return wallet @@ -104,123 +109,9 @@ class WalletManager: return 0 return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts) - @property - def ledger(self) -> Ledger: - return self.default_account.ledger - - @property - def db(self) -> 'Database': - return self.ledger.db - def check_locked(self): return self.default_wallet.is_locked - @staticmethod - def migrate_lbryum_to_torba(path): - if not os.path.exists(path): - return None, None - with open(path, 'r') as f: - unmigrated_json = f.read() - unmigrated = json.loads(unmigrated_json) - # TODO: After several public releases of new torba based wallet, we can delete - # this lbryum->torba conversion code and require that users who still - # have old structured wallets install one of the earlier releases that - # still has the below conversion code. - if 'master_public_keys' not in unmigrated: - return None, None - total = unmigrated.get('addr_history') - receiving_addresses, change_addresses = set(), set() - for _, unmigrated_account in unmigrated.get('accounts', {}).items(): - receiving_addresses.update(map(unhexlify, unmigrated_account.get('receiving', []))) - change_addresses.update(map(unhexlify, unmigrated_account.get('change', []))) - log.info("Wallet migrator found %s receiving addresses and %s change addresses. %s in total on history.", - len(receiving_addresses), len(change_addresses), len(total)) - - migrated_json = json.dumps({ - 'version': 1, - 'name': 'My Wallet', - 'accounts': [{ - 'version': 1, - 'name': 'Main Account', - 'ledger': 'lbc_mainnet', - 'encrypted': unmigrated['use_encryption'], - 'seed': unmigrated['seed'], - 'seed_version': unmigrated['seed_version'], - 'private_key': unmigrated['master_private_keys']['x/'], - 'public_key': unmigrated['master_public_keys']['x/'], - 'certificates': unmigrated.get('claim_certificates', {}), - 'address_generator': { - 'name': 'deterministic-chain', - 'receiving': {'gap': 20, 'maximum_uses_per_address': 1}, - 'change': {'gap': 6, 'maximum_uses_per_address': 1} - } - }] - }, indent=4, sort_keys=True) - mode = os.stat(path).st_mode - i = 1 - backup_path_template = os.path.join(os.path.dirname(path), "old_lbryum_wallet") + "_%i" - while os.path.isfile(backup_path_template % i): - i += 1 - os.rename(path, backup_path_template % i) - temp_path = f"{path}.tmp.{os.getpid()}" - with open(temp_path, "w") as f: - f.write(migrated_json) - f.flush() - os.fsync(f.fileno()) - os.rename(temp_path, path) - os.chmod(path, mode) - return receiving_addresses, change_addresses - - @classmethod - async def from_lbrynet_config(cls, config: Config): - - ledger_id = { - 'lbrycrd_main': 'lbc_mainnet', - 'lbrycrd_testnet': 'lbc_testnet', - 'lbrycrd_regtest': 'lbc_regtest' - }[config.blockchain_name] - - ledger_config = { - 'auto_connect': True, - 'default_servers': config.lbryum_servers, - 'data_path': config.wallet_dir, - } - - wallets_directory = os.path.join(config.wallet_dir, 'wallets') - if not os.path.exists(wallets_directory): - os.mkdir(wallets_directory) - - receiving_addresses, change_addresses = cls.migrate_lbryum_to_torba( - os.path.join(wallets_directory, 'default_wallet') - ) - - manager = cls.from_config({ - 'ledgers': {ledger_id: ledger_config}, - 'wallets': [ - os.path.join(wallets_directory, wallet_file) for wallet_file in config.wallets - ] - }) - manager.config = config - ledger = manager.get_or_create_ledger(ledger_id) - ledger.coin_selection_strategy = config.coin_selection_strategy - default_wallet = manager.default_wallet - if default_wallet.default_account is None: - log.info('Wallet at %s is empty, generating a default account.', default_wallet.id) - default_wallet.generate_account(ledger) - default_wallet.save() - if default_wallet.is_locked and default_wallet.preferences.get(ENCRYPT_ON_DISK) is None: - default_wallet.preferences[ENCRYPT_ON_DISK] = True - default_wallet.save() - if receiving_addresses or change_addresses: - if not os.path.exists(ledger.path): - os.mkdir(ledger.path) - await ledger.db.open() - try: - await manager._migrate_addresses(receiving_addresses, change_addresses) - finally: - await ledger.db.close() - return manager - async def reset(self): self.ledger.config = { 'auto_connect': True, @@ -230,24 +121,6 @@ class WalletManager: await self.ledger.stop() await self.ledger.start() - async def _migrate_addresses(self, receiving_addresses: set, change_addresses: set): - async with self.default_account.receiving.address_generator_lock: - migrated_receiving = set(await self.default_account.receiving._generate_keys(0, len(receiving_addresses))) - async with self.default_account.change.address_generator_lock: - migrated_change = set(await self.default_account.change._generate_keys(0, len(change_addresses))) - receiving_addresses = set(map(self.default_account.ledger.public_key_to_address, receiving_addresses)) - change_addresses = set(map(self.default_account.ledger.public_key_to_address, change_addresses)) - if not any(change_addresses.difference(migrated_change)): - log.info("Successfully migrated %s change addresses.", len(change_addresses)) - else: - log.warning("Failed to migrate %s change addresses!", - len(set(change_addresses).difference(set(migrated_change)))) - if not any(receiving_addresses.difference(migrated_receiving)): - log.info("Successfully migrated %s receiving addresses.", len(receiving_addresses)) - else: - log.warning("Failed to migrate %s receiving addresses!", - len(set(receiving_addresses).difference(set(migrated_receiving)))) - async def get_best_blockhash(self): if len(self.ledger.headers) <= 0: return self.ledger.genesis_hash @@ -272,35 +145,3 @@ class WalletManager: await self.ledger.maybe_verify_transaction(tx, height, merkle) return tx - async def create_purchase_transaction( - self, accounts: List[Account], txo: Output, exchange: 'ExchangeRateManager', - override_max_key_fee=False): - fee = txo.claim.stream.fee - fee_amount = exchange.to_dewies(fee.currency, fee.amount) - if not override_max_key_fee and self.config.max_key_fee: - max_fee = self.config.max_key_fee - max_fee_amount = exchange.to_dewies(max_fee['currency'], Decimal(max_fee['amount'])) - if max_fee_amount and fee_amount > max_fee_amount: - error_fee = f"{dewies_to_lbc(fee_amount)} LBC" - if fee.currency != 'LBC': - error_fee += f" ({fee.amount} {fee.currency})" - error_max_fee = f"{dewies_to_lbc(max_fee_amount)} LBC" - if max_fee['currency'] != 'LBC': - error_max_fee += f" ({max_fee['amount']} {max_fee['currency']})" - raise KeyFeeAboveMaxAllowedError( - f"Purchase price of {error_fee} exceeds maximum " - f"configured price of {error_max_fee}." - ) - fee_address = fee.address or txo.get_address(self.ledger) - return await Transaction.purchase( - txo.claim_id, fee_amount, fee_address, accounts, accounts[0] - ) - - async def broadcast_or_release(self, tx, blocking=False): - try: - await self.ledger.broadcast(tx) - if blocking: - await self.ledger.wait(tx, timeout=None) - except: - await self.ledger.release_tx(tx) - raise diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py deleted file mode 100644 index e99e56567..000000000 --- a/lbry/wallet/network.py +++ /dev/null @@ -1,406 +0,0 @@ -import logging -import asyncio -import json -from time import perf_counter -from operator import itemgetter -from typing import Dict, Optional, Tuple -from binascii import hexlify - -from lbry import __version__ -from lbry.error import IncompatibleWalletServerError -from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError -from lbry.wallet.stream import StreamController - -log = logging.getLogger(__name__) - - -class ClientSession(BaseClientSession): - def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): - self.network = network - self.server = server - super().__init__(*args, **kwargs) - self._on_disconnect_controller = StreamController() - self.on_disconnected = self._on_disconnect_controller.stream - self.framer.max_size = self.max_errors = 1 << 32 - self.timeout = timeout - self.max_seconds_idle = timeout * 2 - self.response_time: Optional[float] = None - self.connection_latency: Optional[float] = None - self._response_samples = 0 - self.pending_amount = 0 - self._on_connect_cb = on_connect_callback or (lambda: None) - self.trigger_urgent_reconnect = asyncio.Event() - - @property - def available(self): - return not self.is_closing() and self.response_time is not None - - @property - def server_address_and_port(self) -> Optional[Tuple[str, int]]: - if not self.transport: - return None - return self.transport.get_extra_info('peername') - - async def send_timed_server_version_request(self, args=(), timeout=None): - timeout = timeout or self.timeout - log.debug("send version request to %s:%i", *self.server) - start = perf_counter() - result = await asyncio.wait_for( - super().send_request('server.version', args), timeout=timeout - ) - current_response_time = perf_counter() - start - response_sum = (self.response_time or 0) * self._response_samples + current_response_time - self.response_time = response_sum / (self._response_samples + 1) - self._response_samples += 1 - return result - - async def send_request(self, method, args=()): - self.pending_amount += 1 - log.debug("send %s%s to %s:%i", method, tuple(args), *self.server) - try: - if method == 'server.version': - return await self.send_timed_server_version_request(args, self.timeout) - request = asyncio.ensure_future(super().send_request(method, args)) - while not request.done(): - done, pending = await asyncio.wait([request], timeout=self.timeout) - if pending: - log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received) - if (perf_counter() - self.last_packet_received) < self.timeout: - continue - log.info("timeout sending %s to %s:%i", method, *self.server) - raise asyncio.TimeoutError - if done: - try: - return request.result() - except ConnectionResetError: - log.error( - "wallet server (%s) reset connection upon our %s request, json of %i args is %i bytes", - self.server[0], method, len(args), len(json.dumps(args)) - ) - raise - except (RPCError, ProtocolError) as e: - log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", - *self.server, *e.args) - raise e - except ConnectionError: - log.warning("connection to %s:%i lost", *self.server) - self.synchronous_close() - raise - except asyncio.CancelledError: - log.info("cancelled sending %s to %s:%i", method, *self.server) - self.synchronous_close() - raise - finally: - self.pending_amount -= 1 - - async def ensure_session(self): - # Handles reconnecting and maintaining a session alive - # TODO: change to 'ping' on newer protocol (above 1.2) - retry_delay = default_delay = 1.0 - while True: - try: - if self.is_closing(): - await self.create_connection(self.timeout) - await self.ensure_server_version() - self._on_connect_cb() - if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None: - await self.ensure_server_version() - retry_delay = default_delay - except RPCError as e: - await self.close() - log.debug("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message) - retry_delay = 60 * 60 - except IncompatibleWalletServerError: - await self.close() - retry_delay = 60 * 60 - log.debug("Wallet server has an incompatible version, retrying in 1h: %s:%d", *self.server) - except (asyncio.TimeoutError, OSError): - await self.close() - retry_delay = min(60, retry_delay * 2) - log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) - try: - await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) - except asyncio.TimeoutError: - pass - finally: - self.trigger_urgent_reconnect.clear() - - async def ensure_server_version(self, required=None, timeout=3): - required = required or self.network.PROTOCOL_VERSION - response = await asyncio.wait_for( - self.send_request('server.version', [__version__, required]), timeout=timeout - ) - if tuple(int(piece) for piece in response[0].split(".")) < self.network.MINIMUM_REQUIRED: - raise IncompatibleWalletServerError(*self.server) - return response - - async def create_connection(self, timeout=6): - connector = Connector(lambda: self, *self.server) - start = perf_counter() - await asyncio.wait_for(connector.create_connection(), timeout=timeout) - self.connection_latency = perf_counter() - start - - async def handle_request(self, request): - controller = self.network.subscription_controllers[request.method] - controller.add(request.args) - - def connection_lost(self, exc): - log.debug("Connection lost: %s:%d", *self.server) - super().connection_lost(exc) - self.response_time = None - self.connection_latency = None - self._response_samples = 0 - self.pending_amount = 0 - self._on_disconnect_controller.add(True) - - -class Network: - - PROTOCOL_VERSION = __version__ - MINIMUM_REQUIRED = (0, 65, 0) - - def __init__(self, ledger): - self.ledger = ledger - self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) - self.client: Optional[ClientSession] = None - self.server_features = None - self._switch_task: Optional[asyncio.Task] = None - self.running = False - self.remote_height: int = 0 - self._concurrency = asyncio.Semaphore(16) - - self._on_connected_controller = StreamController() - self.on_connected = self._on_connected_controller.stream - - self._on_header_controller = StreamController(merge_repeated_events=True) - self.on_header = self._on_header_controller.stream - - self._on_status_controller = StreamController(merge_repeated_events=True) - self.on_status = self._on_status_controller.stream - - self.subscription_controllers = { - 'blockchain.headers.subscribe': self._on_header_controller, - 'blockchain.address.subscribe': self._on_status_controller, - } - - @property - def config(self): - return self.ledger.config - - async def switch_forever(self): - while self.running: - if self.is_connected: - await self.client.on_disconnected.first - self.server_features = None - self.client = None - continue - self.client = await self.session_pool.wait_for_fastest_session() - log.info("Switching to SPV wallet server: %s:%d", *self.client.server) - try: - self.server_features = await self.get_server_features() - self._update_remote_height((await self.subscribe_headers(),)) - self._on_connected_controller.add(True) - log.info("Subscribed to headers: %s:%d", *self.client.server) - except (asyncio.TimeoutError, ConnectionError): - log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server) - self.client.synchronous_close() - self.server_features = None - self.client = None - - async def start(self): - self.running = True - self._switch_task = asyncio.ensure_future(self.switch_forever()) - # this may become unnecessary when there are no more bugs found, - # but for now it helps understanding log reports - self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) - self.session_pool.start(self.config['default_servers']) - self.on_header.listen(self._update_remote_height) - - async def stop(self): - if self.running: - self.running = False - self._switch_task.cancel() - self.session_pool.stop() - - @property - def is_connected(self): - return self.client and not self.client.is_closing() - - def rpc(self, list_or_method, args, restricted=True): - session = self.client if restricted else self.session_pool.fastest_session - if session and not session.is_closing(): - return session.send_request(list_or_method, args) - else: - self.session_pool.trigger_nodelay_connect() - raise ConnectionError("Attempting to send rpc request when connection is not available.") - - async def retriable_call(self, function, *args, **kwargs): - async with self._concurrency: - while self.running: - if not self.is_connected: - log.warning("Wallet server unavailable, waiting for it to come back and retry.") - await self.on_connected.first - await self.session_pool.wait_for_fastest_session() - try: - return await function(*args, **kwargs) - except asyncio.TimeoutError: - log.warning("Wallet server call timed out, retrying.") - except ConnectionError: - pass - raise asyncio.CancelledError() # if we got here, we are shutting down - - def _update_remote_height(self, header_args): - self.remote_height = header_args[0]["height"] - - def get_transaction(self, tx_hash, known_height=None): - # use any server if its old, otherwise restrict to who gave us the history - restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 - return self.rpc('blockchain.transaction.get', [hexlify(tx_hash[::-1]).decode()], restricted) - - def get_transaction_and_merkle(self, tx_hash, known_height=None): - # use any server if its old, otherwise restrict to who gave us the history - restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10 - return self.rpc('blockchain.transaction.info', [hexlify(tx_hash[::-1]).decode()], restricted) - - def get_transaction_height(self, tx_hash, known_height=None): - restricted = not known_height or 0 > known_height > self.remote_height - 10 - return self.rpc('blockchain.transaction.get_height', [hexlify(tx_hash[::-1]).decode()], restricted) - - def get_merkle(self, tx_hash, height): - restricted = 0 > height > self.remote_height - 10 - return self.rpc('blockchain.transaction.get_merkle', [hexlify(tx_hash[::-1]).decode(), height], restricted) - - def get_headers(self, height, count=10000, b64=False): - restricted = height >= self.remote_height - 100 - return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted) - - # --- Subscribes, history and broadcasts are always aimed towards the master client directly - def get_history(self, address): - return self.rpc('blockchain.address.get_history', [address], True) - - def broadcast(self, raw_transaction): - return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True) - - def subscribe_headers(self): - return self.rpc('blockchain.headers.subscribe', [True], True) - - async def subscribe_address(self, address, *addresses): - addresses = list((address, ) + addresses) - try: - return await self.rpc('blockchain.address.subscribe', addresses, True) - except asyncio.TimeoutError: - log.warning( - "timed out subscribing to addresses from %s:%i", - *self.client.server_address_and_port - ) - # abort and cancel, we can't lose a subscription, it will happen again on reconnect - if self.client: - self.client.abort() - raise asyncio.CancelledError() - - def unsubscribe_address(self, address): - return self.rpc('blockchain.address.unsubscribe', [address], True) - - def get_server_features(self): - return self.rpc('server.features', (), restricted=True) - - def get_claims_by_ids(self, claim_ids): - return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids) - - def resolve(self, urls): - return self.rpc('blockchain.claimtrie.resolve', urls) - - def claim_search(self, **kwargs): - return self.rpc('blockchain.claimtrie.search', kwargs) - - -class SessionPool: - - def __init__(self, network: Network, timeout: float): - self.network = network - self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() - self.timeout = timeout - self.new_connection_event = asyncio.Event() - - @property - def online(self): - return any(not session.is_closing() for session in self.sessions) - - @property - def available_sessions(self): - return (session for session in self.sessions if session.available) - - @property - def fastest_session(self): - if not self.online: - return None - return min( - [((session.response_time + session.connection_latency) * (session.pending_amount + 1), session) - for session in self.available_sessions] or [(0, None)], - key=itemgetter(0) - )[1] - - def _get_session_connect_callback(self, session: ClientSession): - loop = asyncio.get_event_loop() - - def callback(): - duplicate_connections = [ - s for s in self.sessions - if s is not session and s.server_address_and_port == session.server_address_and_port - ] - already_connected = None if not duplicate_connections else duplicate_connections[0] - if already_connected: - self.sessions.pop(session).cancel() - session.synchronous_close() - log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour", - session.server[0], already_connected.server[0]) - loop.call_later(3600, self._connect_session, session.server) - return - self.new_connection_event.set() - log.info("connected to %s:%i", *session.server) - - return callback - - def _connect_session(self, server: Tuple[str, int]): - session = None - for s in self.sessions: - if s.server == server: - session = s - break - if not session: - session = ClientSession( - network=self.network, server=server - ) - session._on_connect_cb = self._get_session_connect_callback(session) - task = self.sessions.get(session, None) - if not task or task.done(): - task = asyncio.create_task(session.ensure_session()) - task.add_done_callback(lambda _: self.ensure_connections()) - self.sessions[session] = task - - def start(self, default_servers): - for server in default_servers: - self._connect_session(server) - - def stop(self): - for session, task in self.sessions.items(): - task.cancel() - session.synchronous_close() - self.sessions.clear() - - def ensure_connections(self): - for session in self.sessions: - self._connect_session(session.server) - - def trigger_nodelay_connect(self): - # used when other parts of the system sees we might have internet back - # bypasses the retry interval - for session in self.sessions: - session.trigger_urgent_reconnect.set() - - async def wait_for_fastest_session(self): - while not self.fastest_session: - self.trigger_nodelay_connect() - self.new_connection_event.clear() - await self.new_connection_event.wait() - return self.fastest_session diff --git a/lbry/wallet/sync.py b/lbry/wallet/sync.py new file mode 100644 index 000000000..7f3059303 --- /dev/null +++ b/lbry/wallet/sync.py @@ -0,0 +1,431 @@ +import asyncio +import logging +from io import StringIO +from functools import partial +from operator import itemgetter +from collections import defaultdict +from binascii import hexlify, unhexlify +from typing import List, Optional, DefaultDict, NamedTuple + +import pylru +from lbry.crypto.hash import double_sha256, sha256 + +from lbry.service.api import Client +from lbry.tasks import TaskGroup +from lbry.blockchain.transaction import Transaction +from lbry.blockchain.ledger import Ledger +from lbry.blockchain.block import get_block_filter +from lbry.db import Database +from lbry.event import EventController +from lbry.service.base import Service, Sync + +from .account import Account, AddressManager + + +class TransactionEvent(NamedTuple): + address: str + tx: Transaction + + +class AddressesGeneratedEvent(NamedTuple): + address_manager: AddressManager + addresses: List[str] + + +class TransactionCacheItem: + __slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications' + + def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None): + self.has_tx = asyncio.Event() + self.lock = lock or asyncio.Lock() + self._tx = self.tx = tx + self.pending_verifications = 0 + + @property + def tx(self) -> Optional[Transaction]: + return self._tx + + @tx.setter + def tx(self, tx: Transaction): + self._tx = tx + if tx is not None: + self.has_tx.set() + + +class SPVSync(Sync): + + def __init__(self, service: Service): + super().__init__(service) + return + self.headers = headers + self.network: Network = self.config.get('network') or Network(self) + self.network.on_header.listen(self.receive_header) + self.network.on_status.listen(self.process_status_update) + self.network.on_connected.listen(self.join_network) + + self.accounts = [] + + self.on_address = self.ledger.on_address + + self._on_header_controller = EventController() + self.on_header = self._on_header_controller.stream + self.on_header.listen( + lambda change: log.info( + '%s: added %s header blocks, final height %s', + self.ledger.get_id(), change, self.headers.height + ) + ) + self._download_height = 0 + + self._on_ready_controller = EventController() + self.on_ready = self._on_ready_controller.stream + + self._tx_cache = pylru.lrucache(100000) + self._update_tasks = TaskGroup() + self._other_tasks = TaskGroup() # that we dont need to start + self._header_processing_lock = asyncio.Lock() + self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._known_addresses_out_of_sync = set() + + async def advance(self): + address_array = [ + bytearray(a['address'].encode()) + for a in await self.service.db.get_all_addresses() + ] + block_filters = await self.service.get_block_address_filters() + for block_hash, block_filter in block_filters.items(): + bf = get_block_filter(block_filter) + if bf.MatchAny(address_array): + print(f'match: {block_hash} - {block_filter}') + tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash) + for txid, tx_filter in tx_filters.items(): + tf = get_block_filter(tx_filter) + if tf.MatchAny(address_array): + print(f' match: {txid} - {tx_filter}') + txs = await self.service.search_transactions([txid]) + tx = Transaction(unhexlify(txs[txid])) + await self.service.db.insert_transaction(tx) + + async def get_local_status_and_history(self, address, history=None): + if not history: + address_details = await self.db.get_address(address=address) + history = (address_details['history'] if address_details else '') or '' + parts = history.split(':')[:-1] + return ( + hexlify(sha256(history.encode())).decode() if history else None, + list(zip(parts[0::2], map(int, parts[1::2]))) + ) + + @staticmethod + def get_root_of_merkle_tree(branches, branch_positions, working_branch): + for i, branch in enumerate(branches): + other_branch = unhexlify(branch)[::-1] + other_branch_on_left = bool((branch_positions >> i) & 1) + if other_branch_on_left: + combined = other_branch + working_branch + else: + combined = working_branch + other_branch + working_branch = double_sha256(combined) + return hexlify(working_branch[::-1]) + + async def start(self): + await self.headers.open() + fully_synced = self.on_ready.first + asyncio.create_task(self.network.start()) + await self.network.on_connected.first + async with self._header_processing_lock: + await self._update_tasks.add(self.initial_headers_sync()) + await fully_synced + + async def join_network(self, *_): + log.info("Subscribing and updating accounts.") + await self._update_tasks.add(self.subscribe_accounts()) + await self._update_tasks.done.wait() + self._on_ready_controller.add(True) + + async def stop(self): + self._update_tasks.cancel() + self._other_tasks.cancel() + await self._update_tasks.done.wait() + await self._other_tasks.done.wait() + await self.network.stop() + await self.headers.close() + + @property + def local_height_including_downloaded_height(self): + return max(self.headers.height, self._download_height) + + async def initial_headers_sync(self): + get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True) + self.headers.chunk_getter = get_chunk + + async def doit(): + for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)): + async with self._header_processing_lock: + await self.headers.ensure_chunk_at(height) + self._other_tasks.add(doit()) + await self.update_headers() + + async def update_headers(self, height=None, headers=None, subscription_update=False): + rewound = 0 + while True: + + if height is None or height > len(self.headers): + # sometimes header subscription updates are for a header in the future + # which can't be connected, so we do a normal header sync instead + height = len(self.headers) + headers = None + subscription_update = False + + if not headers: + header_response = await self.network.retriable_call(self.network.get_headers, height, 2001) + headers = header_response['hex'] + + if not headers: + # Nothing to do, network thinks we're already at the latest height. + return + + added = await self.headers.connect(height, unhexlify(headers)) + if added > 0: + height += added + self._on_header_controller.add( + BlockHeightEvent(self.headers.height, added)) + + if rewound > 0: + # we started rewinding blocks and apparently found + # a new chain + rewound = 0 + await self.db.rewind_blockchain(height) + + if subscription_update: + # subscription updates are for latest header already + # so we don't need to check if there are newer / more + # on another loop of update_headers(), just return instead + return + + elif added == 0: + # we had headers to connect but none got connected, probably a reorganization + height -= 1 + rewound += 1 + log.warning( + "Blockchain Reorganization: attempting rewind to height %s from starting height %s", + height, height+rewound + ) + + else: + raise IndexError(f"headers.connect() returned negative number ({added})") + + if height < 0: + raise IndexError( + "Blockchain reorganization rewound all the way back to genesis hash. " + "Something is very wrong. Maybe you are on the wrong blockchain?" + ) + + if rewound >= 100: + raise IndexError( + "Blockchain reorganization dropped {} headers. This is highly unusual. " + "Will not continue to attempt reorganizing. Please, delete the ledger " + "synchronization directory inside your wallet directory (folder: '{}') and " + "restart the program to synchronize from scratch." + .format(rewound, self.ledger.get_id()) + ) + + headers = None # ready to download some more headers + + # if we made it this far and this was a subscription_update + # it means something went wrong and now we're doing a more + # robust sync, turn off subscription update shortcut + subscription_update = False + + async def receive_header(self, response): + async with self._header_processing_lock: + header = response[0] + await self.update_headers( + height=header['height'], headers=header['hex'], subscription_update=True + ) + + async def subscribe_accounts(self): + if self.network.is_connected and self.accounts: + log.info("Subscribe to %i accounts", len(self.accounts)) + await asyncio.wait([ + self.subscribe_account(a) for a in self.accounts + ]) + + async def subscribe_account(self, account: Account): + for address_manager in account.address_managers.values(): + await self.subscribe_addresses(address_manager, await address_manager.get_addresses()) + await account.ensure_address_gap() + + async def unsubscribe_account(self, account: Account): + for address in await account.get_addresses(): + await self.network.unsubscribe_address(address) + + async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000): + if self.network.is_connected and addresses: + addresses_remaining = list(addresses) + while addresses_remaining: + batch = addresses_remaining[:batch_size] + results = await self.network.subscribe_address(*batch) + for address, remote_status in zip(batch, results): + self._update_tasks.add(self.update_history(address, remote_status, address_manager)) + addresses_remaining = addresses_remaining[batch_size:] + log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining), + len(addresses), *self.network.client.server_address_and_port) + log.info( + "finished subscribing to %i addresses on %s:%i", len(addresses), + *self.network.client.server_address_and_port + ) + + def process_status_update(self, update): + address, remote_status = update + self._update_tasks.add(self.update_history(address, remote_status)) + + async def update_history(self, address, remote_status, address_manager: AddressManager = None): + async with self._address_update_locks[address]: + self._known_addresses_out_of_sync.discard(address) + + local_status, local_history = await self.get_local_status_and_history(address) + + if local_status == remote_status: + return True + + remote_history = await self.network.retriable_call(self.network.get_history, address) + remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history)) + we_need = set(remote_history) - set(local_history) + if not we_need: + return True + + cache_tasks: List[asyncio.Task[Transaction]] = [] + synced_history = StringIO() + loop = asyncio.get_running_loop() + for i, (txid, remote_height) in enumerate(remote_history): + if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks: + synced_history.write(f'{txid}:{remote_height}:') + else: + check_local = (txid, remote_height) not in we_need + cache_tasks.append(loop.create_task( + self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local) + )) + + synced_txs = [] + for task in cache_tasks: + tx = await task + + check_db_for_txos = [] + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + continue + cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash) + if cache_item is not None: + if cache_item.tx is None: + await cache_item.has_tx.wait() + assert cache_item.tx is not None + txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref + else: + check_db_for_txos.append(txi.txo_ref.hash) + + referenced_txos = {} if not check_db_for_txos else { + txo.id: txo for txo in await self.db.get_txos( + txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True + ) + } + + for txi in tx.inputs: + if txi.txo_ref.txo is not None: + continue + referenced_txo = referenced_txos.get(txi.txo_ref.id) + if referenced_txo is not None: + txi.txo_ref = referenced_txo.ref + + synced_history.write(f'{tx.id}:{tx.height}:') + synced_txs.append(tx) + + await self.db.save_transaction_io_batch( + synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue() + ) + await asyncio.wait([ + self.ledger._on_transaction_controller.add(TransactionEvent(address, tx)) + for tx in synced_txs + ]) + + if address_manager is None: + address_manager = await self.get_address_manager_for_address(address) + + if address_manager is not None: + await address_manager.ensure_address_gap() + + local_status, local_history = \ + await self.get_local_status_and_history(address, synced_history.getvalue()) + if local_status != remote_status: + if local_history == remote_history: + return True + log.warning( + "Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items", + remote_status, len(remote_history), local_status, len(local_history) + ) + log.warning("local: %s", local_history) + log.warning("remote: %s", remote_history) + self._known_addresses_out_of_sync.add(address) + return False + else: + return True + + async def cache_transaction(self, tx_hash, remote_height, check_local=True): + cache_item = self._tx_cache.get(tx_hash) + if cache_item is None: + cache_item = self._tx_cache[tx_hash] = TransactionCacheItem() + elif cache_item.tx is not None and \ + cache_item.tx.height >= remote_height and \ + (cache_item.tx.is_verified or remote_height < 1): + return cache_item.tx # cached tx is already up-to-date + + try: + cache_item.pending_verifications += 1 + return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local) + finally: + cache_item.pending_verifications -= 1 + + async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True): + + async with cache_item.lock: + + tx = cache_item.tx + + if tx is None and check_local: + # check local db + tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash) + + merkle = None + if tx is None: + # fetch from network + _raw, merkle = await self.network.retriable_call( + self.network.get_transaction_and_merkle, tx_hash, remote_height + ) + tx = Transaction(unhexlify(_raw), height=merkle.get('block_height')) + cache_item.tx = tx # make sure it's saved before caching it + await self.maybe_verify_transaction(tx, remote_height, merkle) + return tx + + async def maybe_verify_transaction(self, tx, remote_height, merkle=None): + tx.height = remote_height + cached = self._tx_cache.get(tx.hash) + if not cached: + # cache txs looked up by transaction_show too + cached = TransactionCacheItem() + cached.tx = tx + self._tx_cache[tx.hash] = cached + if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1: + # can't be tx.pending_verifications == 1 because we have to handle the transaction_show case + if not merkle: + merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height) + merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) + header = await self.headers.get(remote_height) + tx.position = merkle['pos'] + tx.is_verified = merkle_root == header['merkle_root'] + + async def get_address_manager_for_address(self, address) -> Optional[AddressManager]: + details = await self.db.get_address(address=address) + for account in self.accounts: + if account.id == details['account']: + return account.address_managers[details['chain']] + return None diff --git a/lbry/wallet/usage_payment.py b/lbry/wallet/usage_payment.py index 67c740260..b1b9447a3 100644 --- a/lbry/wallet/usage_payment.py +++ b/lbry/wallet/usage_payment.py @@ -6,9 +6,9 @@ from lbry.error import ( ServerPaymentInvalidAddressError, ServerPaymentWalletLockedError ) -from lbry.wallet.dewies import lbc_to_dewies -from lbry.wallet.stream import StreamController -from lbry.wallet.transaction import Output, Transaction +from lbry.blockchain.dewies import lbc_to_dewies +from lbry.event import EventController +from lbry.blockchain.transaction import Output, Transaction log = logging.getLogger(__name__) @@ -22,7 +22,7 @@ class WalletServerPayer: self.payment_period = payment_period self.analytics_manager = analytics_manager self.max_fee = max_fee - self._on_payment_controller = StreamController() + self._on_payment_controller = EventController() self.on_payment = self._on_payment_controller.stream self.on_payment.listen(None, on_error=lambda e: logging.warning(e.args[0])) diff --git a/lbry/wallet/wallet.py b/lbry/wallet/wallet.py index 8f3cecc57..5bfd8e457 100644 --- a/lbry/wallet/wallet.py +++ b/lbry/wallet/wallet.py @@ -5,16 +5,28 @@ import json import zlib import typing import logging -from typing import List, Sequence, MutableSequence, Optional +from typing import List, Sequence, MutableSequence, Optional, Iterable from collections import UserDict from hashlib import sha256 from operator import attrgetter +from decimal import Decimal + +from lbry.db import Database +from lbry.blockchain.ledger import Ledger +from lbry.constants import COIN, NULL_HASH32 +from lbry.blockchain.transaction import Transaction, Input, Output +from lbry.blockchain.dewies import dewies_to_lbc from lbry.crypto.crypt import better_aes_encrypt, better_aes_decrypt +from lbry.crypto.bip32 import PubKey, PrivateKey +from lbry.schema.claim import Claim +from lbry.schema.purchase import Purchase +from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError + from .account import Account +from .coinselection import CoinSelector, OutputEffectiveAmountEstimator if typing.TYPE_CHECKING: - from lbry.wallet.manager import WalletManager - from lbry.wallet.ledger import Ledger + from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager log = logging.getLogger(__name__) @@ -67,8 +79,11 @@ class Wallet: preferences: TimestampedPreferences encryption_password: Optional[str] - def __init__(self, name: str = 'Wallet', accounts: MutableSequence['Account'] = None, + def __init__(self, ledger: Ledger, db: Database, + name: str = 'Wallet', accounts: MutableSequence[Account] = None, storage: 'WalletStorage' = None, preferences: dict = None) -> None: + self.ledger = ledger + self.db = db self.name = name self.accounts = accounts or [] self.storage = storage or WalletStorage() @@ -79,30 +94,34 @@ class Wallet: def get_id(self): return os.path.basename(self.storage.path) if self.storage.path else self.name - def add_account(self, account: 'Account'): + def generate_account(self, name: str = None, address_generator: dict = None) -> Account: + account = Account.generate(self.ledger, self.db, name, address_generator) self.accounts.append(account) + return account - def generate_account(self, ledger: 'Ledger') -> 'Account': - return Account.generate(ledger, self) + def add_account(self, account_dict) -> Account: + account = Account.from_dict(self.ledger, self.db, account_dict) + self.accounts.append(account) + return account @property - def default_account(self) -> Optional['Account']: + def default_account(self) -> Optional[Account]: for account in self.accounts: return account return None - def get_account_or_default(self, account_id: str) -> Optional['Account']: + def get_account_or_default(self, account_id: str) -> Optional[Account]: if account_id is None: return self.default_account return self.get_account_or_error(account_id) - def get_account_or_error(self, account_id: str) -> 'Account': + def get_account_or_error(self, account_id: str) -> Account: for account in self.accounts: if account.id == account_id: return account raise ValueError(f"Couldn't find account: {account_id}.") - def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['Account']: + def get_accounts_or_all(self, account_ids: List[str]) -> Sequence[Account]: return [ self.get_account_or_error(account_id) for account_id in account_ids @@ -116,24 +135,63 @@ class Wallet: accounts.append(details) return accounts + async def _get_account_and_address_info_for_address(self, address): + match = await self.db.get_address(accounts=self.accounts, address=address) + if match: + for account in self.accounts: + if match['account'] == account.public_key.address: + return account, match + + async def get_private_key_for_address(self, address) -> Optional[PrivateKey]: + match = await self._get_account_and_address_info_for_address(address) + if match: + account, address_info = match + return account.get_private_key(address_info['chain'], address_info['pubkey'].n) + return None + + async def get_public_key_for_address(self, address) -> Optional[PubKey]: + match = await self._get_account_and_address_info_for_address(address) + if match: + _, address_info = match + return address_info['pubkey'] + return None + + async def get_account_for_address(self, address): + match = await self._get_account_and_address_info_for_address(address) + if match: + return match[0] + + async def save_max_gap(self): + gap_changed = False + for account in self.accounts: + if await account.save_max_gap(): + gap_changed = True + if gap_changed: + self.save() + @classmethod - def from_storage(cls, storage: 'WalletStorage', manager: 'WalletManager') -> 'Wallet': + def from_storage(cls, ledger: Ledger, db: Database, storage: 'WalletStorage') -> 'Wallet': json_dict = storage.read() + if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id(): + raise ValueError( + f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}." + ) wallet = cls( + ledger, db, name=json_dict.get('name', 'Wallet'), preferences=json_dict.get('preferences', {}), storage=storage ) account_dicts: Sequence[dict] = json_dict.get('accounts', []) for account_dict in account_dicts: - ledger = manager.get_or_create_ledger(account_dict['ledger']) - Account.from_dict(ledger, wallet, account_dict) + wallet.add_account(account_dict) return wallet def to_dict(self, encrypt_password: str = None): return { 'version': WalletStorage.LATEST_VERSION, 'name': self.name, + 'ledger': self.ledger.get_id(), 'preferences': self.preferences.data, 'accounts': [a.to_dict(encrypt_password) for a in self.accounts] } @@ -173,15 +231,13 @@ class Wallet: decompressed = zlib.decompress(decrypted) return json.loads(decompressed) - def merge(self, manager: 'WalletManager', - password: str, data: str) -> List['Account']: + def merge(self, password: str, data: str) -> List[Account]: assert not self.is_locked, "Cannot sync apply on a locked wallet." added_accounts = [] decrypted_data = self.unpack(password, data) self.preferences.merge(decrypted_data.get('preferences', {})) for account_dict in decrypted_data['accounts']: - ledger = manager.get_or_create_ledger(account_dict['ledger']) - _, _, pubkey = Account.keys_from_dict(ledger, account_dict) + _, _, pubkey = Account.keys_from_dict(self.ledger, account_dict) account_id = pubkey.address local_match = None for local_account in self.accounts: @@ -191,8 +247,9 @@ class Wallet: if local_match is not None: local_match.merge(account_dict) else: - new_account = Account.from_dict(ledger, self, account_dict) - added_accounts.append(new_account) + added_accounts.append( + self.add_account(account_dict) + ) return added_accounts @property @@ -235,6 +292,203 @@ class Wallet: self.save() return True + async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]): + estimators = [] + for utxo in (await self.db.get_utxos(accounts=funding_accounts))[0]: + estimators.append(OutputEffectiveAmountEstimator(self.ledger, utxo)) + return estimators + + async def get_spendable_utxos(self, amount: int, funding_accounts: Iterable[Account]): + txos = await self.get_effective_amount_estimators(funding_accounts) + fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger) + selector = CoinSelector(amount, fee) + spendables = selector.select(txos, self.ledger.coin_selection_strategy) + if spendables: + await self.db.reserve_outputs(s.txo for s in spendables) + return spendables + + async def create_transaction(self, inputs: Iterable[Input], outputs: Iterable[Output], + funding_accounts: Iterable[Account], change_account: Account, + sign: bool = True): + """ Find optimal set of inputs when only outputs are provided; add change + outputs if only inputs are provided or if inputs are greater than outputs. """ + + tx = Transaction() \ + .add_inputs(inputs) \ + .add_outputs(outputs) + + # value of the outputs plus associated fees + cost = ( + tx.get_base_fee(self.ledger) + + tx.get_total_output_sum(self.ledger) + ) + # value of the inputs less the cost to spend those inputs + payment = tx.get_effective_input_sum(self.ledger) + + try: + + for _ in range(5): + + if payment < cost: + deficit = cost - payment + spendables = await self.get_spendable_utxos(deficit, funding_accounts) + if not spendables: + raise InsufficientFundsError() + payment += sum(s.effective_amount for s in spendables) + tx.add_inputs(s.txi for s in spendables) + + cost_of_change = ( + tx.get_base_fee(self.ledger) + + Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger) + ) + if payment > cost: + change = payment - cost + if change > cost_of_change: + change_address = await change_account.change.get_or_create_usable_address() + change_hash160 = change_account.ledger.address_to_hash160(change_address) + change_amount = change - cost_of_change + change_output = Output.pay_pubkey_hash(change_amount, change_hash160) + change_output.is_internal_transfer = True + tx.add_outputs([Output.pay_pubkey_hash(change_amount, change_hash160)]) + + if tx._outputs: + break + # this condition and the outer range(5) loop cover an edge case + # whereby a single input is just enough to cover the fee and + # has some change left over, but the change left over is less + # than the cost_of_change: thus the input is completely + # consumed and no output is added, which is an invalid tx. + # to be able to spend this input we must increase the cost + # of the TX and run through the balance algorithm a second time + # adding an extra input and change output, making tx valid. + # we do this 5 times in case the other UTXOs added are also + # less than the fee, after 5 attempts we give up and go home + cost += cost_of_change + 1 + + if sign: + await self.sign(tx) + + except Exception as e: + log.exception('Failed to create transaction:') + await self.db.release_tx(tx) + raise e + + return tx + + async def sign(self, tx): + for i, txi in enumerate(tx._inputs): + assert txi.script is not None + assert txi.txo_ref.txo is not None + txo_script = txi.txo_ref.txo.script + if txo_script.is_pay_pubkey_hash: + address = self.ledger.hash160_to_address(txo_script.values['pubkey_hash']) + private_key = await self.get_private_key_for_address(address) + assert private_key is not None, 'Cannot find private key for signing output.' + serialized = tx._serialize_for_signature(i) + txi.script.values['signature'] = \ + private_key.sign(serialized) + bytes((tx.signature_hash_type(1),)) + txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes + txi.script.generate() + else: + raise NotImplementedError("Don't know how to spend this output.") + tx._reset() + + @classmethod + def pay(cls, amount: int, address: bytes, funding_accounts: List['Account'], change_account: 'Account'): + output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address)) + return cls.create([], [output], funding_accounts, change_account) + + def claim_create( + self, name: str, claim: Claim, amount: int, holding_address: str, + funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None): + claim_output = Output.pay_claim_name_pubkey_hash( + amount, name, claim, self.ledger.address_to_hash160(holding_address) + ) + if signing_channel is not None: + claim_output.sign(signing_channel, b'placeholder txid:nout') + return self.create_transaction( + [], [claim_output], funding_accounts, change_account, sign=False + ) + + @classmethod + def claim_update( + cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str, + funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None): + updated_claim = Output.pay_update_claim_pubkey_hash( + amount, previous_claim.claim_name, previous_claim.claim_id, + claim, ledger.address_to_hash160(holding_address) + ) + if signing_channel is not None: + updated_claim.sign(signing_channel, b'placeholder txid:nout') + else: + updated_claim.clear_signature() + return cls.create( + [Input.spend(previous_claim)], [updated_claim], funding_accounts, change_account, sign=False + ) + + @classmethod + def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str, + funding_accounts: List['Account'], change_account: 'Account'): + support_output = Output.pay_support_pubkey_hash( + amount, claim_name, claim_id, ledger.address_to_hash160(holding_address) + ) + return cls.create([], [support_output], funding_accounts, change_account) + + def purchase(self, claim_id: str, amount: int, merchant_address: bytes, + funding_accounts: List['Account'], change_account: 'Account'): + payment = Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(merchant_address)) + data = Output.add_purchase_data(Purchase(claim_id)) + return self.create_transaction( + [], [payment, data], funding_accounts, change_account + ) + + async def create_purchase_transaction( + self, accounts: List[Account], txo: Output, exchange: 'ExchangeRateManager', + override_max_key_fee=False): + fee = txo.claim.stream.fee + fee_amount = exchange.to_dewies(fee.currency, fee.amount) + if not override_max_key_fee and self.ledger.conf.max_key_fee: + max_fee = self.ledger.conf.max_key_fee + max_fee_amount = exchange.to_dewies(max_fee['currency'], Decimal(max_fee['amount'])) + if max_fee_amount and fee_amount > max_fee_amount: + error_fee = f"{dewies_to_lbc(fee_amount)} LBC" + if fee.currency != 'LBC': + error_fee += f" ({fee.amount} {fee.currency})" + error_max_fee = f"{dewies_to_lbc(max_fee_amount)} LBC" + if max_fee['currency'] != 'LBC': + error_max_fee += f" ({max_fee['amount']} {max_fee['currency']})" + raise KeyFeeAboveMaxAllowedError( + f"Purchase price of {error_fee} exceeds maximum " + f"configured price of {error_max_fee}." + ) + fee_address = fee.address or txo.get_address(self.ledger) + return await self.purchase( + txo.claim_id, fee_amount, fee_address, accounts, accounts[0] + ) + + async def create_channel( + self, name, amount, account, funding_accounts, + claim_address, preview=False, **kwargs): + + claim = Claim() + claim.channel.update(**kwargs) + tx = await self.claim_create( + name, claim, amount, claim_address, funding_accounts, funding_accounts[0] + ) + txo = tx.outputs[0] + txo.generate_channel_private_key() + + await self.sign(tx) + + if not preview: + account.add_channel_private_key(txo.private_key) + self.save() + + return tx + + async def get_channels(self): + return await self.db.get_channels() + class WalletStorage: