wip lbry.wallet

This commit is contained in:
Lex Berezhny 2020-05-01 09:33:10 -04:00
parent fef09c1773
commit 533f31cc89
8 changed files with 800 additions and 746 deletions

View file

@ -1,17 +0,0 @@
__node_daemon__ = 'lbrycrdd'
__node_cli__ = 'lbrycrd-cli'
__node_bin__ = ''
__node_url__ = (
'https://github.com/lbryio/lbrycrd/releases/download/v0.17.4.4/lbrycrd-linux-1744.zip'
)
__spvserver__ = 'lbry.wallet.server.coin.LBCRegTest'
from .bip32 import PubKey
from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DISK
from .manager import WalletManager
from .network import Network
from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent
from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic
from .transaction import Transaction, Output, Input
from .script import OutputScript, InputScript
from .header import Headers

View file

@ -2,7 +2,6 @@ import os
import time
import json
import logging
import typing
import asyncio
import random
from functools import partial
@ -16,14 +15,15 @@ import ecdsa
from lbry.error import InvalidPasswordError
from lbry.crypto.crypt import aes_encrypt, aes_decrypt
from .bip32 import PrivateKey, PubKey, from_extended_key_string
from .mnemonic import Mnemonic
from .constants import COIN, CLAIM_TYPES, TXO_TYPES
from .transaction import Transaction, Input, Output
from lbry.crypto.bip32 import PrivateKey, PubKey, from_extended_key_string
from lbry.constants import COIN
from lbry.blockchain.transaction import Transaction, Input, Output
from lbry.blockchain.ledger import Ledger
from lbry.db import Database
from lbry.db.constants import CLAIM_TYPE_CODES, TXO_TYPES
from .mnemonic import Mnemonic
if typing.TYPE_CHECKING:
from .ledger import Ledger
from .wallet import Wallet
log = logging.getLogger(__name__)
@ -71,12 +71,12 @@ class AddressManager:
def to_dict_instance(self) -> Optional[dict]:
raise NotImplementedError
def _query_addresses(self, **constraints):
return self.account.ledger.db.get_addresses(
accounts=[self.account],
async def _query_addresses(self, **constraints):
return (await self.account.db.get_addresses(
account=self.account,
chain=self.chain_number,
**constraints
)
))[0]
def get_private_key(self, index: int) -> PrivateKey:
raise NotImplementedError
@ -166,22 +166,22 @@ class HierarchicalDeterministic(AddressManager):
start = addresses[0]['pubkey'].n+1 if addresses else 0
end = start + (self.gap - existing_gap)
new_keys = await self._generate_keys(start, end-1)
await self.account.ledger.announce_addresses(self, new_keys)
#await self.account.ledger.announce_addresses(self, new_keys)
return new_keys
async def _generate_keys(self, start: int, end: int) -> List[str]:
if not self.address_generator_lock.locked():
raise RuntimeError('Should not be called outside of address_generator_lock.')
keys = [self.public_key.child(index) for index in range(start, end+1)]
await self.account.ledger.db.add_keys(self.account, self.chain_number, keys)
await self.account.db.add_keys(self.account, self.chain_number, keys)
return [key.address for key in keys]
def get_address_records(self, only_usable: bool = False, **constraints):
async def get_address_records(self, only_usable: bool = False, **constraints):
if only_usable:
constraints['used_times__lt'] = self.maximum_uses_per_address
if 'order_by' not in constraints:
constraints['order_by'] = "used_times asc, n asc"
return self._query_addresses(**constraints)
return await self._query_addresses(**constraints)
class SingleKey(AddressManager):
@ -213,14 +213,14 @@ class SingleKey(AddressManager):
async with self.address_generator_lock:
exists = await self.get_address_records()
if not exists:
await self.account.ledger.db.add_keys(self.account, self.chain_number, [self.public_key])
await self.account.db.add_keys(self.account, self.chain_number, [self.public_key])
new_keys = [self.public_key.address]
await self.account.ledger.announce_addresses(self, new_keys)
#await self.account.ledger.announce_addresses(self, new_keys)
return new_keys
return []
def get_address_records(self, only_usable: bool = False, **constraints):
return self._query_addresses(**constraints)
async def get_address_records(self, only_usable: bool = False, **constraints):
return await self._query_addresses(**constraints)
class Account:
@ -233,12 +233,12 @@ class Account:
HierarchicalDeterministic.name: HierarchicalDeterministic,
}
def __init__(self, ledger: 'Ledger', wallet: 'Wallet', name: str,
def __init__(self, ledger: 'Ledger', db: 'Database', name: str,
seed: str, private_key_string: str, encrypted: bool,
private_key: Optional[PrivateKey], public_key: PubKey,
address_generator: dict, modified_on: float, channel_keys: dict) -> None:
self.ledger = ledger
self.wallet = wallet
self.db = db
self.id = public_key.address
self.name = name
self.seed = seed
@ -253,8 +253,6 @@ class Account:
self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
self.address_managers = {am.chain_number: am for am in {self.receiving, self.change}}
self.channel_keys = channel_keys
ledger.add_account(self)
wallet.add_account(self)
def get_init_vector(self, key) -> Optional[bytes]:
init_vector = self.init_vectors.get(key, None)
@ -263,9 +261,9 @@ class Account:
return init_vector
@classmethod
def generate(cls, ledger: 'Ledger', wallet: 'Wallet',
def generate(cls, ledger: 'Ledger', db: 'Database',
name: str = None, address_generator: dict = None):
return cls.from_dict(ledger, wallet, {
return cls.from_dict(ledger, db, {
'name': name,
'seed': cls.mnemonic_class().make_seed(),
'address_generator': address_generator or {}
@ -297,14 +295,14 @@ class Account:
return seed, private_key, public_key
@classmethod
def from_dict(cls, ledger: 'Ledger', wallet: 'Wallet', d: dict):
def from_dict(cls, ledger: 'Ledger', db: 'Database', d: dict):
seed, private_key, public_key = cls.keys_from_dict(ledger, d)
name = d.get('name')
if not name:
name = f'Account #{public_key.address}'
return cls(
ledger=ledger,
wallet=wallet,
db=db,
name=name,
seed=seed,
private_key_string=d.get('private_key', ''),
@ -328,7 +326,6 @@ class Account:
if seed:
seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed'))
d = {
'ledger': self.ledger.get_id(),
'name': self.name,
'seed': seed,
'encrypted': bool(self.encrypted or encrypt_password),
@ -365,7 +362,6 @@ class Account:
details = {
'id': self.id,
'name': self.name,
'ledger': self.ledger.get_id(),
'coins': round(satoshis/COIN, 2),
'satoshis': satoshis,
'encrypted': self.encrypted,
@ -436,15 +432,18 @@ class Account:
addresses.extend(new_addresses)
return addresses
async def get_address_records(self, **constraints):
return await self.db.get_addresses(account=self, **constraints)
async def get_addresses(self, **constraints) -> List[str]:
rows = await self.ledger.db.select_addresses([text('account_address.address')], accounts=[self], **constraints)
rows, _ = await self.get_address_records(cols=['account_address.address'], **constraints)
return [r['address'] for r in rows]
def get_address_records(self, **constraints):
return self.ledger.db.get_addresses(accounts=[self], **constraints)
def get_address_count(self, **constraints):
return self.ledger.db.get_address_count(accounts=[self], **constraints)
async def get_valid_receiving_address(self, default_address: str) -> str:
if default_address is None:
return await self.receiving.get_or_create_usable_address()
self.ledger.valid_address_or_error(default_address)
return default_address
def get_private_key(self, chain: int, index: int) -> PrivateKey:
assert not self.encrypted, "Cannot get private key on encrypted wallet account."
@ -459,7 +458,7 @@ class Account:
if confirmations > 0:
height = self.ledger.headers.height - (confirmations-1)
constraints.update({'height__lte': height, 'height__gt': 0})
return self.ledger.db.get_balance(accounts=[self], **constraints)
return self.db.get_balance(account=self, **constraints)
async def get_max_gap(self):
change_gap = await self.change.get_max_gap()
@ -469,24 +468,6 @@ class Account:
'max_receiving_gap': receiving_gap,
}
def get_txos(self, **constraints):
return self.ledger.get_txos(wallet=self.wallet, accounts=[self], **constraints)
def get_txo_count(self, **constraints):
return self.ledger.get_txo_count(wallet=self.wallet, accounts=[self], **constraints)
def get_utxos(self, **constraints):
return self.ledger.get_utxos(wallet=self.wallet, accounts=[self], **constraints)
def get_utxo_count(self, **constraints):
return self.ledger.get_utxo_count(wallet=self.wallet, accounts=[self], **constraints)
def get_transactions(self, **constraints):
return self.ledger.get_transactions(wallet=self.wallet, accounts=[self], **constraints)
def get_transaction_count(self, **constraints):
return self.ledger.get_transaction_count(wallet=self.wallet, accounts=[self], **constraints)
async def fund(self, to_account, amount=None, everything=False,
outputs=1, broadcast=False, **constraints):
assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
@ -551,9 +532,9 @@ class Account:
self.wallet.save()
async def save_max_gap(self):
gap_changed = False
if issubclass(self.address_generator, HierarchicalDeterministic):
gap = await self.get_max_gap()
gap_changed = False
new_receiving_gap = max(20, gap['max_receiving_gap'] + 1)
if self.receiving.gap != new_receiving_gap:
self.receiving.gap = new_receiving_gap
@ -562,8 +543,10 @@ class Account:
if self.change.gap != new_change_gap:
self.change.gap = new_change_gap
gap_changed = True
if gap_changed:
self.wallet.save()
return gap_changed
def get_support_summary(self):
return self.db.get_supports_summary(account=self)
async def get_detailed_balance(self, confirmations=0, reserved_subtotals=False):
tips_balance, supports_balance, claims_balance = 0, 0, 0
@ -571,7 +554,7 @@ class Account:
include_claims=True)
total = await get_total_balance()
if reserved_subtotals:
claims_balance = await get_total_balance(txo_type__in=CLAIM_TYPES)
claims_balance = await get_total_balance(txo_type__in=CLAIM_TYPE_CODES)
for txo in await self.get_support_summary():
if confirmations > 0 and not 0 < txo.tx_ref.height <= self.ledger.headers.height - (confirmations - 1):
continue
@ -594,49 +577,3 @@ class Account:
'tips': tips_balance
} if reserved_subtotals else None
}
def get_transaction_history(self, **constraints):
return self.ledger.get_transaction_history(
wallet=self.wallet, accounts=[self], **constraints
)
def get_transaction_history_count(self, **constraints):
return self.ledger.get_transaction_history_count(
wallet=self.wallet, accounts=[self], **constraints
)
def get_claims(self, **constraints):
return self.ledger.get_claims(wallet=self.wallet, accounts=[self], **constraints)
def get_claim_count(self, **constraints):
return self.ledger.get_claim_count(wallet=self.wallet, accounts=[self], **constraints)
def get_streams(self, **constraints):
return self.ledger.get_streams(wallet=self.wallet, accounts=[self], **constraints)
def get_stream_count(self, **constraints):
return self.ledger.get_stream_count(wallet=self.wallet, accounts=[self], **constraints)
def get_channels(self, **constraints):
return self.ledger.get_channels(wallet=self.wallet, accounts=[self], **constraints)
def get_channel_count(self, **constraints):
return self.ledger.get_channel_count(wallet=self.wallet, accounts=[self], **constraints)
def get_collections(self, **constraints):
return self.ledger.get_collections(wallet=self.wallet, accounts=[self], **constraints)
def get_collection_count(self, **constraints):
return self.ledger.get_collection_count(wallet=self.wallet, accounts=[self], **constraints)
def get_supports(self, **constraints):
return self.ledger.get_supports(wallet=self.wallet, accounts=[self], **constraints)
def get_support_count(self, **constraints):
return self.ledger.get_support_count(wallet=self.wallet, accounts=[self], **constraints)
def get_support_summary(self):
return self.ledger.db.get_supports_summary(wallet=self.wallet, accounts=[self])
async def release_all_outputs(self):
await self.ledger.db.release_all_outputs(self)

View file

@ -1,18 +1,32 @@
from random import Random
from typing import List
from lbry.wallet.transaction import OutputEffectiveAmountEstimator
from lbry.blockchain.transaction import Input, Output
MAXIMUM_TRIES = 100000
STRATEGIES = []
COIN_SELECTION_STRATEGIES = []
def strategy(method):
STRATEGIES.append(method.__name__)
COIN_SELECTION_STRATEGIES.append(method.__name__)
return method
class OutputEffectiveAmountEstimator:
__slots__ = 'txo', 'txi', 'fee', 'effective_amount'
def __init__(self, ledger, txo: Output) -> None:
self.txo = txo
self.txi = Input.spend(txo)
self.fee: int = self.txi.get_fee(ledger)
self.effective_amount: int = txo.amount - self.fee
def __lt__(self, other):
return self.effective_amount < other.effective_amount
class CoinSelector:
def __init__(self, target: int, cost_of_change: int, seed: str = None) -> None:

View file

@ -10,16 +10,14 @@ from typing import List, Type, MutableSequence, MutableMapping, Optional
from lbry.error import KeyFeeAboveMaxAllowedError
from lbry.conf import Config
from .dewies import dewies_to_lbc
from .account import Account
from .ledger import Ledger, LedgerRegistry
from .transaction import Transaction, Output
from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
from .rpc.jsonrpc import CodeMessageError
if typing.TYPE_CHECKING:
from lbry.blockchain.dewies import dewies_to_lbc
from lbry.blockchain.ledger import Ledger
from lbry.db import Database
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
from lbry.blockchain.ledger import Ledger
from lbry.blockchain.transaction import Transaction, Output
from .wallet import Wallet, WalletStorage, ENCRYPT_ON_DISK
log = logging.getLogger(__name__)
@ -27,35 +25,42 @@ log = logging.getLogger(__name__)
class WalletManager:
def __init__(self, wallets: MutableSequence[Wallet] = None,
def __init__(self, ledger: Ledger, db: Database,
wallets: MutableSequence[Wallet] = None,
ledgers: MutableMapping[Type[Ledger], Ledger] = None) -> None:
self.ledger = ledger
self.db = db
self.wallets = wallets or []
self.ledgers = ledgers or {}
self.running = False
self.config: Optional[Config] = None
@classmethod
def from_config(cls, config: dict) -> 'WalletManager':
manager = cls()
for ledger_id, ledger_config in config.get('ledgers', {}).items():
manager.get_or_create_ledger(ledger_id, ledger_config)
for wallet_path in config.get('wallets', []):
wallet_storage = WalletStorage(wallet_path)
wallet = Wallet.from_storage(wallet_storage, manager)
manager.wallets.append(wallet)
return manager
async def open(self):
conf = self.ledger.conf
def get_or_create_ledger(self, ledger_id, ledger_config=None):
ledger_class = LedgerRegistry.get_ledger_class(ledger_id)
ledger = self.ledgers.get(ledger_class)
if ledger is None:
ledger = ledger_class(ledger_config or {})
self.ledgers[ledger_class] = ledger
return ledger
wallets_directory = os.path.join(conf.wallet_dir, 'wallets')
if not os.path.exists(wallets_directory):
os.mkdir(wallets_directory)
for wallet_file in conf.wallets:
wallet_path = os.path.join(wallets_directory, wallet_file)
wallet_storage = WalletStorage(wallet_path)
wallet = Wallet.from_storage(self.ledger, self.db, wallet_storage)
self.wallets.append(wallet)
self.ledger.coin_selection_strategy = self.ledger.conf.coin_selection_strategy
default_wallet = self.default_wallet
if default_wallet.default_account is None:
log.info('Wallet at %s is empty, generating a default account.', default_wallet.id)
default_wallet.generate_account()
default_wallet.save()
if default_wallet.is_locked and default_wallet.preferences.get(ENCRYPT_ON_DISK) is None:
default_wallet.preferences[ENCRYPT_ON_DISK] = True
default_wallet.save()
def import_wallet(self, path):
storage = WalletStorage(path)
wallet = Wallet.from_storage(storage, self)
wallet = Wallet.from_storage(self.ledger, self.db, storage)
self.wallets.append(wallet)
return wallet
@ -104,123 +109,9 @@ class WalletManager:
return 0
return accounts[0].ledger.db.get_balance(wallet=wallet, accounts=accounts)
@property
def ledger(self) -> Ledger:
return self.default_account.ledger
@property
def db(self) -> 'Database':
return self.ledger.db
def check_locked(self):
return self.default_wallet.is_locked
@staticmethod
def migrate_lbryum_to_torba(path):
if not os.path.exists(path):
return None, None
with open(path, 'r') as f:
unmigrated_json = f.read()
unmigrated = json.loads(unmigrated_json)
# TODO: After several public releases of new torba based wallet, we can delete
# this lbryum->torba conversion code and require that users who still
# have old structured wallets install one of the earlier releases that
# still has the below conversion code.
if 'master_public_keys' not in unmigrated:
return None, None
total = unmigrated.get('addr_history')
receiving_addresses, change_addresses = set(), set()
for _, unmigrated_account in unmigrated.get('accounts', {}).items():
receiving_addresses.update(map(unhexlify, unmigrated_account.get('receiving', [])))
change_addresses.update(map(unhexlify, unmigrated_account.get('change', [])))
log.info("Wallet migrator found %s receiving addresses and %s change addresses. %s in total on history.",
len(receiving_addresses), len(change_addresses), len(total))
migrated_json = json.dumps({
'version': 1,
'name': 'My Wallet',
'accounts': [{
'version': 1,
'name': 'Main Account',
'ledger': 'lbc_mainnet',
'encrypted': unmigrated['use_encryption'],
'seed': unmigrated['seed'],
'seed_version': unmigrated['seed_version'],
'private_key': unmigrated['master_private_keys']['x/'],
'public_key': unmigrated['master_public_keys']['x/'],
'certificates': unmigrated.get('claim_certificates', {}),
'address_generator': {
'name': 'deterministic-chain',
'receiving': {'gap': 20, 'maximum_uses_per_address': 1},
'change': {'gap': 6, 'maximum_uses_per_address': 1}
}
}]
}, indent=4, sort_keys=True)
mode = os.stat(path).st_mode
i = 1
backup_path_template = os.path.join(os.path.dirname(path), "old_lbryum_wallet") + "_%i"
while os.path.isfile(backup_path_template % i):
i += 1
os.rename(path, backup_path_template % i)
temp_path = f"{path}.tmp.{os.getpid()}"
with open(temp_path, "w") as f:
f.write(migrated_json)
f.flush()
os.fsync(f.fileno())
os.rename(temp_path, path)
os.chmod(path, mode)
return receiving_addresses, change_addresses
@classmethod
async def from_lbrynet_config(cls, config: Config):
ledger_id = {
'lbrycrd_main': 'lbc_mainnet',
'lbrycrd_testnet': 'lbc_testnet',
'lbrycrd_regtest': 'lbc_regtest'
}[config.blockchain_name]
ledger_config = {
'auto_connect': True,
'default_servers': config.lbryum_servers,
'data_path': config.wallet_dir,
}
wallets_directory = os.path.join(config.wallet_dir, 'wallets')
if not os.path.exists(wallets_directory):
os.mkdir(wallets_directory)
receiving_addresses, change_addresses = cls.migrate_lbryum_to_torba(
os.path.join(wallets_directory, 'default_wallet')
)
manager = cls.from_config({
'ledgers': {ledger_id: ledger_config},
'wallets': [
os.path.join(wallets_directory, wallet_file) for wallet_file in config.wallets
]
})
manager.config = config
ledger = manager.get_or_create_ledger(ledger_id)
ledger.coin_selection_strategy = config.coin_selection_strategy
default_wallet = manager.default_wallet
if default_wallet.default_account is None:
log.info('Wallet at %s is empty, generating a default account.', default_wallet.id)
default_wallet.generate_account(ledger)
default_wallet.save()
if default_wallet.is_locked and default_wallet.preferences.get(ENCRYPT_ON_DISK) is None:
default_wallet.preferences[ENCRYPT_ON_DISK] = True
default_wallet.save()
if receiving_addresses or change_addresses:
if not os.path.exists(ledger.path):
os.mkdir(ledger.path)
await ledger.db.open()
try:
await manager._migrate_addresses(receiving_addresses, change_addresses)
finally:
await ledger.db.close()
return manager
async def reset(self):
self.ledger.config = {
'auto_connect': True,
@ -230,24 +121,6 @@ class WalletManager:
await self.ledger.stop()
await self.ledger.start()
async def _migrate_addresses(self, receiving_addresses: set, change_addresses: set):
async with self.default_account.receiving.address_generator_lock:
migrated_receiving = set(await self.default_account.receiving._generate_keys(0, len(receiving_addresses)))
async with self.default_account.change.address_generator_lock:
migrated_change = set(await self.default_account.change._generate_keys(0, len(change_addresses)))
receiving_addresses = set(map(self.default_account.ledger.public_key_to_address, receiving_addresses))
change_addresses = set(map(self.default_account.ledger.public_key_to_address, change_addresses))
if not any(change_addresses.difference(migrated_change)):
log.info("Successfully migrated %s change addresses.", len(change_addresses))
else:
log.warning("Failed to migrate %s change addresses!",
len(set(change_addresses).difference(set(migrated_change))))
if not any(receiving_addresses.difference(migrated_receiving)):
log.info("Successfully migrated %s receiving addresses.", len(receiving_addresses))
else:
log.warning("Failed to migrate %s receiving addresses!",
len(set(receiving_addresses).difference(set(migrated_receiving))))
async def get_best_blockhash(self):
if len(self.ledger.headers) <= 0:
return self.ledger.genesis_hash
@ -272,35 +145,3 @@ class WalletManager:
await self.ledger.maybe_verify_transaction(tx, height, merkle)
return tx
async def create_purchase_transaction(
self, accounts: List[Account], txo: Output, exchange: 'ExchangeRateManager',
override_max_key_fee=False):
fee = txo.claim.stream.fee
fee_amount = exchange.to_dewies(fee.currency, fee.amount)
if not override_max_key_fee and self.config.max_key_fee:
max_fee = self.config.max_key_fee
max_fee_amount = exchange.to_dewies(max_fee['currency'], Decimal(max_fee['amount']))
if max_fee_amount and fee_amount > max_fee_amount:
error_fee = f"{dewies_to_lbc(fee_amount)} LBC"
if fee.currency != 'LBC':
error_fee += f" ({fee.amount} {fee.currency})"
error_max_fee = f"{dewies_to_lbc(max_fee_amount)} LBC"
if max_fee['currency'] != 'LBC':
error_max_fee += f" ({max_fee['amount']} {max_fee['currency']})"
raise KeyFeeAboveMaxAllowedError(
f"Purchase price of {error_fee} exceeds maximum "
f"configured price of {error_max_fee}."
)
fee_address = fee.address or txo.get_address(self.ledger)
return await Transaction.purchase(
txo.claim_id, fee_amount, fee_address, accounts, accounts[0]
)
async def broadcast_or_release(self, tx, blocking=False):
try:
await self.ledger.broadcast(tx)
if blocking:
await self.ledger.wait(tx, timeout=None)
except:
await self.ledger.release_tx(tx)
raise

View file

@ -1,406 +0,0 @@
import logging
import asyncio
import json
from time import perf_counter
from operator import itemgetter
from typing import Dict, Optional, Tuple
from binascii import hexlify
from lbry import __version__
from lbry.error import IncompatibleWalletServerError
from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError
from lbry.wallet.stream import StreamController
log = logging.getLogger(__name__)
class ClientSession(BaseClientSession):
def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs):
self.network = network
self.server = server
super().__init__(*args, **kwargs)
self._on_disconnect_controller = StreamController()
self.on_disconnected = self._on_disconnect_controller.stream
self.framer.max_size = self.max_errors = 1 << 32
self.timeout = timeout
self.max_seconds_idle = timeout * 2
self.response_time: Optional[float] = None
self.connection_latency: Optional[float] = None
self._response_samples = 0
self.pending_amount = 0
self._on_connect_cb = on_connect_callback or (lambda: None)
self.trigger_urgent_reconnect = asyncio.Event()
@property
def available(self):
return not self.is_closing() and self.response_time is not None
@property
def server_address_and_port(self) -> Optional[Tuple[str, int]]:
if not self.transport:
return None
return self.transport.get_extra_info('peername')
async def send_timed_server_version_request(self, args=(), timeout=None):
timeout = timeout or self.timeout
log.debug("send version request to %s:%i", *self.server)
start = perf_counter()
result = await asyncio.wait_for(
super().send_request('server.version', args), timeout=timeout
)
current_response_time = perf_counter() - start
response_sum = (self.response_time or 0) * self._response_samples + current_response_time
self.response_time = response_sum / (self._response_samples + 1)
self._response_samples += 1
return result
async def send_request(self, method, args=()):
self.pending_amount += 1
log.debug("send %s%s to %s:%i", method, tuple(args), *self.server)
try:
if method == 'server.version':
return await self.send_timed_server_version_request(args, self.timeout)
request = asyncio.ensure_future(super().send_request(method, args))
while not request.done():
done, pending = await asyncio.wait([request], timeout=self.timeout)
if pending:
log.debug("Time since last packet: %s", perf_counter() - self.last_packet_received)
if (perf_counter() - self.last_packet_received) < self.timeout:
continue
log.info("timeout sending %s to %s:%i", method, *self.server)
raise asyncio.TimeoutError
if done:
try:
return request.result()
except ConnectionResetError:
log.error(
"wallet server (%s) reset connection upon our %s request, json of %i args is %i bytes",
self.server[0], method, len(args), len(json.dumps(args))
)
raise
except (RPCError, ProtocolError) as e:
log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
*self.server, *e.args)
raise e
except ConnectionError:
log.warning("connection to %s:%i lost", *self.server)
self.synchronous_close()
raise
except asyncio.CancelledError:
log.info("cancelled sending %s to %s:%i", method, *self.server)
self.synchronous_close()
raise
finally:
self.pending_amount -= 1
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
retry_delay = default_delay = 1.0
while True:
try:
if self.is_closing():
await self.create_connection(self.timeout)
await self.ensure_server_version()
self._on_connect_cb()
if (perf_counter() - self.last_send) > self.max_seconds_idle or self.response_time is None:
await self.ensure_server_version()
retry_delay = default_delay
except RPCError as e:
await self.close()
log.debug("Server error, ignoring for 1h: %s:%d -- %s", *self.server, e.message)
retry_delay = 60 * 60
except IncompatibleWalletServerError:
await self.close()
retry_delay = 60 * 60
log.debug("Wallet server has an incompatible version, retrying in 1h: %s:%d", *self.server)
except (asyncio.TimeoutError, OSError):
await self.close()
retry_delay = min(60, retry_delay * 2)
log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
try:
await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
except asyncio.TimeoutError:
pass
finally:
self.trigger_urgent_reconnect.clear()
async def ensure_server_version(self, required=None, timeout=3):
required = required or self.network.PROTOCOL_VERSION
response = await asyncio.wait_for(
self.send_request('server.version', [__version__, required]), timeout=timeout
)
if tuple(int(piece) for piece in response[0].split(".")) < self.network.MINIMUM_REQUIRED:
raise IncompatibleWalletServerError(*self.server)
return response
async def create_connection(self, timeout=6):
connector = Connector(lambda: self, *self.server)
start = perf_counter()
await asyncio.wait_for(connector.create_connection(), timeout=timeout)
self.connection_latency = perf_counter() - start
async def handle_request(self, request):
controller = self.network.subscription_controllers[request.method]
controller.add(request.args)
def connection_lost(self, exc):
log.debug("Connection lost: %s:%d", *self.server)
super().connection_lost(exc)
self.response_time = None
self.connection_latency = None
self._response_samples = 0
self.pending_amount = 0
self._on_disconnect_controller.add(True)
class Network:
PROTOCOL_VERSION = __version__
MINIMUM_REQUIRED = (0, 65, 0)
def __init__(self, ledger):
self.ledger = ledger
self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
self.client: Optional[ClientSession] = None
self.server_features = None
self._switch_task: Optional[asyncio.Task] = None
self.running = False
self.remote_height: int = 0
self._concurrency = asyncio.Semaphore(16)
self._on_connected_controller = StreamController()
self.on_connected = self._on_connected_controller.stream
self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream
self._on_status_controller = StreamController(merge_repeated_events=True)
self.on_status = self._on_status_controller.stream
self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller,
'blockchain.address.subscribe': self._on_status_controller,
}
@property
def config(self):
return self.ledger.config
async def switch_forever(self):
while self.running:
if self.is_connected:
await self.client.on_disconnected.first
self.server_features = None
self.client = None
continue
self.client = await self.session_pool.wait_for_fastest_session()
log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
try:
self.server_features = await self.get_server_features()
self._update_remote_height((await self.subscribe_headers(),))
self._on_connected_controller.add(True)
log.info("Subscribed to headers: %s:%d", *self.client.server)
except (asyncio.TimeoutError, ConnectionError):
log.info("Switching to %s:%d timed out, closing and retrying.", *self.client.server)
self.client.synchronous_close()
self.server_features = None
self.client = None
async def start(self):
self.running = True
self._switch_task = asyncio.ensure_future(self.switch_forever())
# this may become unnecessary when there are no more bugs found,
# but for now it helps understanding log reports
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height)
async def stop(self):
if self.running:
self.running = False
self._switch_task.cancel()
self.session_pool.stop()
@property
def is_connected(self):
return self.client and not self.client.is_closing()
def rpc(self, list_or_method, args, restricted=True):
session = self.client if restricted else self.session_pool.fastest_session
if session and not session.is_closing():
return session.send_request(list_or_method, args)
else:
self.session_pool.trigger_nodelay_connect()
raise ConnectionError("Attempting to send rpc request when connection is not available.")
async def retriable_call(self, function, *args, **kwargs):
async with self._concurrency:
while self.running:
if not self.is_connected:
log.warning("Wallet server unavailable, waiting for it to come back and retry.")
await self.on_connected.first
await self.session_pool.wait_for_fastest_session()
try:
return await function(*args, **kwargs)
except asyncio.TimeoutError:
log.warning("Wallet server call timed out, retrying.")
except ConnectionError:
pass
raise asyncio.CancelledError() # if we got here, we are shutting down
def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"]
def get_transaction(self, tx_hash, known_height=None):
# use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [hexlify(tx_hash[::-1]).decode()], restricted)
def get_transaction_and_merkle(self, tx_hash, known_height=None):
# use any server if its old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.info', [hexlify(tx_hash[::-1]).decode()], restricted)
def get_transaction_height(self, tx_hash, known_height=None):
restricted = not known_height or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_height', [hexlify(tx_hash[::-1]).decode()], restricted)
def get_merkle(self, tx_hash, height):
restricted = 0 > height > self.remote_height - 10
return self.rpc('blockchain.transaction.get_merkle', [hexlify(tx_hash[::-1]).decode(), height], restricted)
def get_headers(self, height, count=10000, b64=False):
restricted = height >= self.remote_height - 100
return self.rpc('blockchain.block.headers', [height, count, 0, b64], restricted)
# --- Subscribes, history and broadcasts are always aimed towards the master client directly
def get_history(self, address):
return self.rpc('blockchain.address.get_history', [address], True)
def broadcast(self, raw_transaction):
return self.rpc('blockchain.transaction.broadcast', [raw_transaction], True)
def subscribe_headers(self):
return self.rpc('blockchain.headers.subscribe', [True], True)
async def subscribe_address(self, address, *addresses):
addresses = list((address, ) + addresses)
try:
return await self.rpc('blockchain.address.subscribe', addresses, True)
except asyncio.TimeoutError:
log.warning(
"timed out subscribing to addresses from %s:%i",
*self.client.server_address_and_port
)
# abort and cancel, we can't lose a subscription, it will happen again on reconnect
if self.client:
self.client.abort()
raise asyncio.CancelledError()
def unsubscribe_address(self, address):
return self.rpc('blockchain.address.unsubscribe', [address], True)
def get_server_features(self):
return self.rpc('server.features', (), restricted=True)
def get_claims_by_ids(self, claim_ids):
return self.rpc('blockchain.claimtrie.getclaimsbyids', claim_ids)
def resolve(self, urls):
return self.rpc('blockchain.claimtrie.resolve', urls)
def claim_search(self, **kwargs):
return self.rpc('blockchain.claimtrie.search', kwargs)
class SessionPool:
def __init__(self, network: Network, timeout: float):
self.network = network
self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict()
self.timeout = timeout
self.new_connection_event = asyncio.Event()
@property
def online(self):
return any(not session.is_closing() for session in self.sessions)
@property
def available_sessions(self):
return (session for session in self.sessions if session.available)
@property
def fastest_session(self):
if not self.online:
return None
return min(
[((session.response_time + session.connection_latency) * (session.pending_amount + 1), session)
for session in self.available_sessions] or [(0, None)],
key=itemgetter(0)
)[1]
def _get_session_connect_callback(self, session: ClientSession):
loop = asyncio.get_event_loop()
def callback():
duplicate_connections = [
s for s in self.sessions
if s is not session and s.server_address_and_port == session.server_address_and_port
]
already_connected = None if not duplicate_connections else duplicate_connections[0]
if already_connected:
self.sessions.pop(session).cancel()
session.synchronous_close()
log.debug("wallet server %s resolves to the same server as %s, rechecking in an hour",
session.server[0], already_connected.server[0])
loop.call_later(3600, self._connect_session, session.server)
return
self.new_connection_event.set()
log.info("connected to %s:%i", *session.server)
return callback
def _connect_session(self, server: Tuple[str, int]):
session = None
for s in self.sessions:
if s.server == server:
session = s
break
if not session:
session = ClientSession(
network=self.network, server=server
)
session._on_connect_cb = self._get_session_connect_callback(session)
task = self.sessions.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
task.add_done_callback(lambda _: self.ensure_connections())
self.sessions[session] = task
def start(self, default_servers):
for server in default_servers:
self._connect_session(server)
def stop(self):
for session, task in self.sessions.items():
task.cancel()
session.synchronous_close()
self.sessions.clear()
def ensure_connections(self):
for session in self.sessions:
self._connect_session(session.server)
def trigger_nodelay_connect(self):
# used when other parts of the system sees we might have internet back
# bypasses the retry interval
for session in self.sessions:
session.trigger_urgent_reconnect.set()
async def wait_for_fastest_session(self):
while not self.fastest_session:
self.trigger_nodelay_connect()
self.new_connection_event.clear()
await self.new_connection_event.wait()
return self.fastest_session

431
lbry/wallet/sync.py Normal file
View file

@ -0,0 +1,431 @@
import asyncio
import logging
from io import StringIO
from functools import partial
from operator import itemgetter
from collections import defaultdict
from binascii import hexlify, unhexlify
from typing import List, Optional, DefaultDict, NamedTuple
import pylru
from lbry.crypto.hash import double_sha256, sha256
from lbry.service.api import Client
from lbry.tasks import TaskGroup
from lbry.blockchain.transaction import Transaction
from lbry.blockchain.ledger import Ledger
from lbry.blockchain.block import get_block_filter
from lbry.db import Database
from lbry.event import EventController
from lbry.service.base import Service, Sync
from .account import Account, AddressManager
class TransactionEvent(NamedTuple):
address: str
tx: Transaction
class AddressesGeneratedEvent(NamedTuple):
address_manager: AddressManager
addresses: List[str]
class TransactionCacheItem:
__slots__ = '_tx', 'lock', 'has_tx', 'pending_verifications'
def __init__(self, tx: Optional[Transaction] = None, lock: Optional[asyncio.Lock] = None):
self.has_tx = asyncio.Event()
self.lock = lock or asyncio.Lock()
self._tx = self.tx = tx
self.pending_verifications = 0
@property
def tx(self) -> Optional[Transaction]:
return self._tx
@tx.setter
def tx(self, tx: Transaction):
self._tx = tx
if tx is not None:
self.has_tx.set()
class SPVSync(Sync):
def __init__(self, service: Service):
super().__init__(service)
return
self.headers = headers
self.network: Network = self.config.get('network') or Network(self)
self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.process_status_update)
self.network.on_connected.listen(self.join_network)
self.accounts = []
self.on_address = self.ledger.on_address
self._on_header_controller = EventController()
self.on_header = self._on_header_controller.stream
self.on_header.listen(
lambda change: log.info(
'%s: added %s header blocks, final height %s',
self.ledger.get_id(), change, self.headers.height
)
)
self._download_height = 0
self._on_ready_controller = EventController()
self.on_ready = self._on_ready_controller.stream
self._tx_cache = pylru.lrucache(100000)
self._update_tasks = TaskGroup()
self._other_tasks = TaskGroup() # that we dont need to start
self._header_processing_lock = asyncio.Lock()
self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._known_addresses_out_of_sync = set()
async def advance(self):
address_array = [
bytearray(a['address'].encode())
for a in await self.service.db.get_all_addresses()
]
block_filters = await self.service.get_block_address_filters()
for block_hash, block_filter in block_filters.items():
bf = get_block_filter(block_filter)
if bf.MatchAny(address_array):
print(f'match: {block_hash} - {block_filter}')
tx_filters = await self.service.get_transaction_address_filters(block_hash=block_hash)
for txid, tx_filter in tx_filters.items():
tf = get_block_filter(tx_filter)
if tf.MatchAny(address_array):
print(f' match: {txid} - {tx_filter}')
txs = await self.service.search_transactions([txid])
tx = Transaction(unhexlify(txs[txid]))
await self.service.db.insert_transaction(tx)
async def get_local_status_and_history(self, address, history=None):
if not history:
address_details = await self.db.get_address(address=address)
history = (address_details['history'] if address_details else '') or ''
parts = history.split(':')[:-1]
return (
hexlify(sha256(history.encode())).decode() if history else None,
list(zip(parts[0::2], map(int, parts[1::2])))
)
@staticmethod
def get_root_of_merkle_tree(branches, branch_positions, working_branch):
for i, branch in enumerate(branches):
other_branch = unhexlify(branch)[::-1]
other_branch_on_left = bool((branch_positions >> i) & 1)
if other_branch_on_left:
combined = other_branch + working_branch
else:
combined = working_branch + other_branch
working_branch = double_sha256(combined)
return hexlify(working_branch[::-1])
async def start(self):
await self.headers.open()
fully_synced = self.on_ready.first
asyncio.create_task(self.network.start())
await self.network.on_connected.first
async with self._header_processing_lock:
await self._update_tasks.add(self.initial_headers_sync())
await fully_synced
async def join_network(self, *_):
log.info("Subscribing and updating accounts.")
await self._update_tasks.add(self.subscribe_accounts())
await self._update_tasks.done.wait()
self._on_ready_controller.add(True)
async def stop(self):
self._update_tasks.cancel()
self._other_tasks.cancel()
await self._update_tasks.done.wait()
await self._other_tasks.done.wait()
await self.network.stop()
await self.headers.close()
@property
def local_height_including_downloaded_height(self):
return max(self.headers.height, self._download_height)
async def initial_headers_sync(self):
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True)
self.headers.chunk_getter = get_chunk
async def doit():
for height in reversed(sorted(self.headers.known_missing_checkpointed_chunks)):
async with self._header_processing_lock:
await self.headers.ensure_chunk_at(height)
self._other_tasks.add(doit())
await self.update_headers()
async def update_headers(self, height=None, headers=None, subscription_update=False):
rewound = 0
while True:
if height is None or height > len(self.headers):
# sometimes header subscription updates are for a header in the future
# which can't be connected, so we do a normal header sync instead
height = len(self.headers)
headers = None
subscription_update = False
if not headers:
header_response = await self.network.retriable_call(self.network.get_headers, height, 2001)
headers = header_response['hex']
if not headers:
# Nothing to do, network thinks we're already at the latest height.
return
added = await self.headers.connect(height, unhexlify(headers))
if added > 0:
height += added
self._on_header_controller.add(
BlockHeightEvent(self.headers.height, added))
if rewound > 0:
# we started rewinding blocks and apparently found
# a new chain
rewound = 0
await self.db.rewind_blockchain(height)
if subscription_update:
# subscription updates are for latest header already
# so we don't need to check if there are newer / more
# on another loop of update_headers(), just return instead
return
elif added == 0:
# we had headers to connect but none got connected, probably a reorganization
height -= 1
rewound += 1
log.warning(
"Blockchain Reorganization: attempting rewind to height %s from starting height %s",
height, height+rewound
)
else:
raise IndexError(f"headers.connect() returned negative number ({added})")
if height < 0:
raise IndexError(
"Blockchain reorganization rewound all the way back to genesis hash. "
"Something is very wrong. Maybe you are on the wrong blockchain?"
)
if rewound >= 100:
raise IndexError(
"Blockchain reorganization dropped {} headers. This is highly unusual. "
"Will not continue to attempt reorganizing. Please, delete the ledger "
"synchronization directory inside your wallet directory (folder: '{}') and "
"restart the program to synchronize from scratch."
.format(rewound, self.ledger.get_id())
)
headers = None # ready to download some more headers
# if we made it this far and this was a subscription_update
# it means something went wrong and now we're doing a more
# robust sync, turn off subscription update shortcut
subscription_update = False
async def receive_header(self, response):
async with self._header_processing_lock:
header = response[0]
await self.update_headers(
height=header['height'], headers=header['hex'], subscription_update=True
)
async def subscribe_accounts(self):
if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts))
await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
])
async def subscribe_account(self, account: Account):
for address_manager in account.address_managers.values():
await self.subscribe_addresses(address_manager, await address_manager.get_addresses())
await account.ensure_address_gap()
async def unsubscribe_account(self, account: Account):
for address in await account.get_addresses():
await self.network.unsubscribe_address(address)
async def subscribe_addresses(self, address_manager: AddressManager, addresses: List[str], batch_size: int = 1000):
if self.network.is_connected and addresses:
addresses_remaining = list(addresses)
while addresses_remaining:
batch = addresses_remaining[:batch_size]
results = await self.network.subscribe_address(*batch)
for address, remote_status in zip(batch, results):
self._update_tasks.add(self.update_history(address, remote_status, address_manager))
addresses_remaining = addresses_remaining[batch_size:]
log.info("subscribed to %i/%i addresses on %s:%i", len(addresses) - len(addresses_remaining),
len(addresses), *self.network.client.server_address_and_port)
log.info(
"finished subscribing to %i addresses on %s:%i", len(addresses),
*self.network.client.server_address_and_port
)
def process_status_update(self, update):
address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status))
async def update_history(self, address, remote_status, address_manager: AddressManager = None):
async with self._address_update_locks[address]:
self._known_addresses_out_of_sync.discard(address)
local_status, local_history = await self.get_local_status_and_history(address)
if local_status == remote_status:
return True
remote_history = await self.network.retriable_call(self.network.get_history, address)
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history)
if not we_need:
return True
cache_tasks: List[asyncio.Task[Transaction]] = []
synced_history = StringIO()
loop = asyncio.get_running_loop()
for i, (txid, remote_height) in enumerate(remote_history):
if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
synced_history.write(f'{txid}:{remote_height}:')
else:
check_local = (txid, remote_height) not in we_need
cache_tasks.append(loop.create_task(
self.cache_transaction(unhexlify(txid)[::-1], remote_height, check_local=check_local)
))
synced_txs = []
for task in cache_tasks:
tx = await task
check_db_for_txos = []
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.hash)
if cache_item is not None:
if cache_item.tx is None:
await cache_item.has_tx.wait()
assert cache_item.tx is not None
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else:
check_db_for_txos.append(txi.txo_ref.hash)
referenced_txos = {} if not check_db_for_txos else {
txo.id: txo for txo in await self.db.get_txos(
txo_hash__in=check_db_for_txos, order_by='txo.txo_hash', no_tx=True
)
}
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
referenced_txo = referenced_txos.get(txi.txo_ref.id)
if referenced_txo is not None:
txi.txo_ref = referenced_txo.ref
synced_history.write(f'{tx.id}:{tx.height}:')
synced_txs.append(tx)
await self.db.save_transaction_io_batch(
synced_txs, address, self.ledger.address_to_hash160(address), synced_history.getvalue()
)
await asyncio.wait([
self.ledger._on_transaction_controller.add(TransactionEvent(address, tx))
for tx in synced_txs
])
if address_manager is None:
address_manager = await self.get_address_manager_for_address(address)
if address_manager is not None:
await address_manager.ensure_address_gap()
local_status, local_history = \
await self.get_local_status_and_history(address, synced_history.getvalue())
if local_status != remote_status:
if local_history == remote_history:
return True
log.warning(
"Wallet is out of sync after syncing. Remote: %s with %d items, local: %s with %d items",
remote_status, len(remote_history), local_status, len(local_history)
)
log.warning("local: %s", local_history)
log.warning("remote: %s", remote_history)
self._known_addresses_out_of_sync.add(address)
return False
else:
return True
async def cache_transaction(self, tx_hash, remote_height, check_local=True):
cache_item = self._tx_cache.get(tx_hash)
if cache_item is None:
cache_item = self._tx_cache[tx_hash] = TransactionCacheItem()
elif cache_item.tx is not None and \
cache_item.tx.height >= remote_height and \
(cache_item.tx.is_verified or remote_height < 1):
return cache_item.tx # cached tx is already up-to-date
try:
cache_item.pending_verifications += 1
return await self._update_cache_item(cache_item, tx_hash, remote_height, check_local)
finally:
cache_item.pending_verifications -= 1
async def _update_cache_item(self, cache_item, tx_hash, remote_height, check_local=True):
async with cache_item.lock:
tx = cache_item.tx
if tx is None and check_local:
# check local db
tx = cache_item.tx = await self.db.get_transaction(tx_hash=tx_hash)
merkle = None
if tx is None:
# fetch from network
_raw, merkle = await self.network.retriable_call(
self.network.get_transaction_and_merkle, tx_hash, remote_height
)
tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
cache_item.tx = tx # make sure it's saved before caching it
await self.maybe_verify_transaction(tx, remote_height, merkle)
return tx
async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
tx.height = remote_height
cached = self._tx_cache.get(tx.hash)
if not cached:
# cache txs looked up by transaction_show too
cached = TransactionCacheItem()
cached.tx = tx
self._tx_cache[tx.hash] = cached
if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
if not merkle:
merkle = await self.network.retriable_call(self.network.get_merkle, tx.hash, remote_height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = await self.headers.get(remote_height)
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
details = await self.db.get_address(address=address)
for account in self.accounts:
if account.id == details['account']:
return account.address_managers[details['chain']]
return None

View file

@ -6,9 +6,9 @@ from lbry.error import (
ServerPaymentInvalidAddressError,
ServerPaymentWalletLockedError
)
from lbry.wallet.dewies import lbc_to_dewies
from lbry.wallet.stream import StreamController
from lbry.wallet.transaction import Output, Transaction
from lbry.blockchain.dewies import lbc_to_dewies
from lbry.event import EventController
from lbry.blockchain.transaction import Output, Transaction
log = logging.getLogger(__name__)
@ -22,7 +22,7 @@ class WalletServerPayer:
self.payment_period = payment_period
self.analytics_manager = analytics_manager
self.max_fee = max_fee
self._on_payment_controller = StreamController()
self._on_payment_controller = EventController()
self.on_payment = self._on_payment_controller.stream
self.on_payment.listen(None, on_error=lambda e: logging.warning(e.args[0]))

View file

@ -5,16 +5,28 @@ import json
import zlib
import typing
import logging
from typing import List, Sequence, MutableSequence, Optional
from typing import List, Sequence, MutableSequence, Optional, Iterable
from collections import UserDict
from hashlib import sha256
from operator import attrgetter
from decimal import Decimal
from lbry.db import Database
from lbry.blockchain.ledger import Ledger
from lbry.constants import COIN, NULL_HASH32
from lbry.blockchain.transaction import Transaction, Input, Output
from lbry.blockchain.dewies import dewies_to_lbc
from lbry.crypto.crypt import better_aes_encrypt, better_aes_decrypt
from lbry.crypto.bip32 import PubKey, PrivateKey
from lbry.schema.claim import Claim
from lbry.schema.purchase import Purchase
from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError
from .account import Account
from .coinselection import CoinSelector, OutputEffectiveAmountEstimator
if typing.TYPE_CHECKING:
from lbry.wallet.manager import WalletManager
from lbry.wallet.ledger import Ledger
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
log = logging.getLogger(__name__)
@ -67,8 +79,11 @@ class Wallet:
preferences: TimestampedPreferences
encryption_password: Optional[str]
def __init__(self, name: str = 'Wallet', accounts: MutableSequence['Account'] = None,
def __init__(self, ledger: Ledger, db: Database,
name: str = 'Wallet', accounts: MutableSequence[Account] = None,
storage: 'WalletStorage' = None, preferences: dict = None) -> None:
self.ledger = ledger
self.db = db
self.name = name
self.accounts = accounts or []
self.storage = storage or WalletStorage()
@ -79,30 +94,34 @@ class Wallet:
def get_id(self):
return os.path.basename(self.storage.path) if self.storage.path else self.name
def add_account(self, account: 'Account'):
def generate_account(self, name: str = None, address_generator: dict = None) -> Account:
account = Account.generate(self.ledger, self.db, name, address_generator)
self.accounts.append(account)
return account
def generate_account(self, ledger: 'Ledger') -> 'Account':
return Account.generate(ledger, self)
def add_account(self, account_dict) -> Account:
account = Account.from_dict(self.ledger, self.db, account_dict)
self.accounts.append(account)
return account
@property
def default_account(self) -> Optional['Account']:
def default_account(self) -> Optional[Account]:
for account in self.accounts:
return account
return None
def get_account_or_default(self, account_id: str) -> Optional['Account']:
def get_account_or_default(self, account_id: str) -> Optional[Account]:
if account_id is None:
return self.default_account
return self.get_account_or_error(account_id)
def get_account_or_error(self, account_id: str) -> 'Account':
def get_account_or_error(self, account_id: str) -> Account:
for account in self.accounts:
if account.id == account_id:
return account
raise ValueError(f"Couldn't find account: {account_id}.")
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['Account']:
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence[Account]:
return [
self.get_account_or_error(account_id)
for account_id in account_ids
@ -116,24 +135,63 @@ class Wallet:
accounts.append(details)
return accounts
async def _get_account_and_address_info_for_address(self, address):
match = await self.db.get_address(accounts=self.accounts, address=address)
if match:
for account in self.accounts:
if match['account'] == account.public_key.address:
return account, match
async def get_private_key_for_address(self, address) -> Optional[PrivateKey]:
match = await self._get_account_and_address_info_for_address(address)
if match:
account, address_info = match
return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
return None
async def get_public_key_for_address(self, address) -> Optional[PubKey]:
match = await self._get_account_and_address_info_for_address(address)
if match:
_, address_info = match
return address_info['pubkey']
return None
async def get_account_for_address(self, address):
match = await self._get_account_and_address_info_for_address(address)
if match:
return match[0]
async def save_max_gap(self):
gap_changed = False
for account in self.accounts:
if await account.save_max_gap():
gap_changed = True
if gap_changed:
self.save()
@classmethod
def from_storage(cls, storage: 'WalletStorage', manager: 'WalletManager') -> 'Wallet':
def from_storage(cls, ledger: Ledger, db: Database, storage: 'WalletStorage') -> 'Wallet':
json_dict = storage.read()
if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id():
raise ValueError(
f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}."
)
wallet = cls(
ledger, db,
name=json_dict.get('name', 'Wallet'),
preferences=json_dict.get('preferences', {}),
storage=storage
)
account_dicts: Sequence[dict] = json_dict.get('accounts', [])
for account_dict in account_dicts:
ledger = manager.get_or_create_ledger(account_dict['ledger'])
Account.from_dict(ledger, wallet, account_dict)
wallet.add_account(account_dict)
return wallet
def to_dict(self, encrypt_password: str = None):
return {
'version': WalletStorage.LATEST_VERSION,
'name': self.name,
'ledger': self.ledger.get_id(),
'preferences': self.preferences.data,
'accounts': [a.to_dict(encrypt_password) for a in self.accounts]
}
@ -173,15 +231,13 @@ class Wallet:
decompressed = zlib.decompress(decrypted)
return json.loads(decompressed)
def merge(self, manager: 'WalletManager',
password: str, data: str) -> List['Account']:
def merge(self, password: str, data: str) -> List[Account]:
assert not self.is_locked, "Cannot sync apply on a locked wallet."
added_accounts = []
decrypted_data = self.unpack(password, data)
self.preferences.merge(decrypted_data.get('preferences', {}))
for account_dict in decrypted_data['accounts']:
ledger = manager.get_or_create_ledger(account_dict['ledger'])
_, _, pubkey = Account.keys_from_dict(ledger, account_dict)
_, _, pubkey = Account.keys_from_dict(self.ledger, account_dict)
account_id = pubkey.address
local_match = None
for local_account in self.accounts:
@ -191,8 +247,9 @@ class Wallet:
if local_match is not None:
local_match.merge(account_dict)
else:
new_account = Account.from_dict(ledger, self, account_dict)
added_accounts.append(new_account)
added_accounts.append(
self.add_account(account_dict)
)
return added_accounts
@property
@ -235,6 +292,203 @@ class Wallet:
self.save()
return True
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
estimators = []
for utxo in (await self.db.get_utxos(accounts=funding_accounts))[0]:
estimators.append(OutputEffectiveAmountEstimator(self.ledger, utxo))
return estimators
async def get_spendable_utxos(self, amount: int, funding_accounts: Iterable[Account]):
txos = await self.get_effective_amount_estimators(funding_accounts)
fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger)
selector = CoinSelector(amount, fee)
spendables = selector.select(txos, self.ledger.coin_selection_strategy)
if spendables:
await self.db.reserve_outputs(s.txo for s in spendables)
return spendables
async def create_transaction(self, inputs: Iterable[Input], outputs: Iterable[Output],
funding_accounts: Iterable[Account], change_account: Account,
sign: bool = True):
""" Find optimal set of inputs when only outputs are provided; add change
outputs if only inputs are provided or if inputs are greater than outputs. """
tx = Transaction() \
.add_inputs(inputs) \
.add_outputs(outputs)
# value of the outputs plus associated fees
cost = (
tx.get_base_fee(self.ledger) +
tx.get_total_output_sum(self.ledger)
)
# value of the inputs less the cost to spend those inputs
payment = tx.get_effective_input_sum(self.ledger)
try:
for _ in range(5):
if payment < cost:
deficit = cost - payment
spendables = await self.get_spendable_utxos(deficit, funding_accounts)
if not spendables:
raise InsufficientFundsError()
payment += sum(s.effective_amount for s in spendables)
tx.add_inputs(s.txi for s in spendables)
cost_of_change = (
tx.get_base_fee(self.ledger) +
Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger)
)
if payment > cost:
change = payment - cost
if change > cost_of_change:
change_address = await change_account.change.get_or_create_usable_address()
change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = change - cost_of_change
change_output = Output.pay_pubkey_hash(change_amount, change_hash160)
change_output.is_internal_transfer = True
tx.add_outputs([Output.pay_pubkey_hash(change_amount, change_hash160)])
if tx._outputs:
break
# this condition and the outer range(5) loop cover an edge case
# whereby a single input is just enough to cover the fee and
# has some change left over, but the change left over is less
# than the cost_of_change: thus the input is completely
# consumed and no output is added, which is an invalid tx.
# to be able to spend this input we must increase the cost
# of the TX and run through the balance algorithm a second time
# adding an extra input and change output, making tx valid.
# we do this 5 times in case the other UTXOs added are also
# less than the fee, after 5 attempts we give up and go home
cost += cost_of_change + 1
if sign:
await self.sign(tx)
except Exception as e:
log.exception('Failed to create transaction:')
await self.db.release_tx(tx)
raise e
return tx
async def sign(self, tx):
for i, txi in enumerate(tx._inputs):
assert txi.script is not None
assert txi.txo_ref.txo is not None
txo_script = txi.txo_ref.txo.script
if txo_script.is_pay_pubkey_hash:
address = self.ledger.hash160_to_address(txo_script.values['pubkey_hash'])
private_key = await self.get_private_key_for_address(address)
assert private_key is not None, 'Cannot find private key for signing output.'
serialized = tx._serialize_for_signature(i)
txi.script.values['signature'] = \
private_key.sign(serialized) + bytes((tx.signature_hash_type(1),))
txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
txi.script.generate()
else:
raise NotImplementedError("Don't know how to spend this output.")
tx._reset()
@classmethod
def pay(cls, amount: int, address: bytes, funding_accounts: List['Account'], change_account: 'Account'):
output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address))
return cls.create([], [output], funding_accounts, change_account)
def claim_create(
self, name: str, claim: Claim, amount: int, holding_address: str,
funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None):
claim_output = Output.pay_claim_name_pubkey_hash(
amount, name, claim, self.ledger.address_to_hash160(holding_address)
)
if signing_channel is not None:
claim_output.sign(signing_channel, b'placeholder txid:nout')
return self.create_transaction(
[], [claim_output], funding_accounts, change_account, sign=False
)
@classmethod
def claim_update(
cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None):
updated_claim = Output.pay_update_claim_pubkey_hash(
amount, previous_claim.claim_name, previous_claim.claim_id,
claim, ledger.address_to_hash160(holding_address)
)
if signing_channel is not None:
updated_claim.sign(signing_channel, b'placeholder txid:nout')
else:
updated_claim.clear_signature()
return cls.create(
[Input.spend(previous_claim)], [updated_claim], funding_accounts, change_account, sign=False
)
@classmethod
def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str,
funding_accounts: List['Account'], change_account: 'Account'):
support_output = Output.pay_support_pubkey_hash(
amount, claim_name, claim_id, ledger.address_to_hash160(holding_address)
)
return cls.create([], [support_output], funding_accounts, change_account)
def purchase(self, claim_id: str, amount: int, merchant_address: bytes,
funding_accounts: List['Account'], change_account: 'Account'):
payment = Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(merchant_address))
data = Output.add_purchase_data(Purchase(claim_id))
return self.create_transaction(
[], [payment, data], funding_accounts, change_account
)
async def create_purchase_transaction(
self, accounts: List[Account], txo: Output, exchange: 'ExchangeRateManager',
override_max_key_fee=False):
fee = txo.claim.stream.fee
fee_amount = exchange.to_dewies(fee.currency, fee.amount)
if not override_max_key_fee and self.ledger.conf.max_key_fee:
max_fee = self.ledger.conf.max_key_fee
max_fee_amount = exchange.to_dewies(max_fee['currency'], Decimal(max_fee['amount']))
if max_fee_amount and fee_amount > max_fee_amount:
error_fee = f"{dewies_to_lbc(fee_amount)} LBC"
if fee.currency != 'LBC':
error_fee += f" ({fee.amount} {fee.currency})"
error_max_fee = f"{dewies_to_lbc(max_fee_amount)} LBC"
if max_fee['currency'] != 'LBC':
error_max_fee += f" ({max_fee['amount']} {max_fee['currency']})"
raise KeyFeeAboveMaxAllowedError(
f"Purchase price of {error_fee} exceeds maximum "
f"configured price of {error_max_fee}."
)
fee_address = fee.address or txo.get_address(self.ledger)
return await self.purchase(
txo.claim_id, fee_amount, fee_address, accounts, accounts[0]
)
async def create_channel(
self, name, amount, account, funding_accounts,
claim_address, preview=False, **kwargs):
claim = Claim()
claim.channel.update(**kwargs)
tx = await self.claim_create(
name, claim, amount, claim_address, funding_accounts, funding_accounts[0]
)
txo = tx.outputs[0]
txo.generate_channel_private_key()
await self.sign(tx)
if not preview:
account.add_channel_private_key(txo.private_key)
self.save()
return tx
async def get_channels(self):
return await self.db.get_channels()
class WalletStorage: