lbry-sdk/lbry/wallet/wallet.py

554 lines
21 KiB
Python
Raw Normal View History

2019-03-11 17:04:06 +01:00
import os
import time
2018-05-25 08:03:25 +02:00
import stat
import json
2019-03-11 17:04:06 +01:00
import zlib
import typing
import logging
2020-05-01 15:33:10 +02:00
from typing import List, Sequence, MutableSequence, Optional, Iterable
2019-10-13 01:40:32 +02:00
from collections import UserDict
2019-03-11 17:12:26 +01:00
from hashlib import sha256
2019-03-11 17:04:06 +01:00
from operator import attrgetter
2020-05-01 15:33:10 +02:00
from decimal import Decimal
2020-05-02 05:25:07 +02:00
from lbry.db import Database, SPENDABLE_TYPE_CODES
2020-05-01 15:33:10 +02:00
from lbry.blockchain.ledger import Ledger
from lbry.constants import COIN, NULL_HASH32
from lbry.blockchain.transaction import Transaction, Input, Output
from lbry.blockchain.dewies import dewies_to_lbc
from lbry.crypto.crypt import better_aes_encrypt, better_aes_decrypt
2020-05-01 15:33:10 +02:00
from lbry.crypto.bip32 import PubKey, PrivateKey
from lbry.schema.claim import Claim
from lbry.schema.purchase import Purchase
from lbry.error import InsufficientFundsError, KeyFeeAboveMaxAllowedError
2020-01-03 04:18:49 +01:00
from .account import Account
2020-05-01 15:33:10 +02:00
from .coinselection import CoinSelector, OutputEffectiveAmountEstimator
2018-05-25 08:03:25 +02:00
if typing.TYPE_CHECKING:
2020-05-01 15:33:10 +02:00
from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager
2018-05-25 08:03:25 +02:00
log = logging.getLogger(__name__)
ENCRYPT_ON_DISK = 'encrypt-on-disk'
class TimestampedPreferences(UserDict):
def __init__(self, d: dict = None):
super().__init__()
if d is not None:
self.data = d.copy()
def __getitem__(self, key):
return self.data[key]['value']
def __setitem__(self, key, value):
self.data[key] = {
'value': value,
'ts': time.time()
}
def __repr__(self):
return repr(self.to_dict_without_ts())
def to_dict_without_ts(self):
return {
key: value['value'] for key, value in self.data.items()
}
@property
def hash(self):
return sha256(json.dumps(self.data).encode()).digest()
def merge(self, other: dict):
for key, value in other.items():
if key in self.data and value['ts'] < self.data[key]['ts']:
continue
self.data[key] = value
2018-05-25 08:03:25 +02:00
class Wallet:
""" The primary role of Wallet is to encapsulate a collection
of accounts (seed/private keys) and the spending rules / settings
for the coins attached to those accounts. Wallets are represented
by physical files on the filesystem.
"""
preferences: TimestampedPreferences
encryption_password: Optional[str]
2020-05-01 15:33:10 +02:00
def __init__(self, ledger: Ledger, db: Database,
name: str = 'Wallet', accounts: MutableSequence[Account] = None,
storage: 'WalletStorage' = None, preferences: dict = None) -> None:
2020-05-01 15:33:10 +02:00
self.ledger = ledger
self.db = db
2018-05-25 08:03:25 +02:00
self.name = name
self.accounts = accounts or []
2018-05-25 08:03:25 +02:00
self.storage = storage or WalletStorage()
self.preferences = TimestampedPreferences(preferences or {})
self.encryption_password = None
2020-01-27 19:17:59 +01:00
self.id = self.get_id()
2018-05-25 08:03:25 +02:00
2020-01-27 19:17:59 +01:00
def get_id(self):
return os.path.basename(self.storage.path) if self.storage.path else self.name
2020-05-01 15:33:10 +02:00
def generate_account(self, name: str = None, address_generator: dict = None) -> Account:
account = Account.generate(self.ledger, self.db, name, address_generator)
2018-05-25 08:03:25 +02:00
self.accounts.append(account)
2020-05-01 15:33:10 +02:00
return account
2020-05-01 15:33:10 +02:00
def add_account(self, account_dict) -> Account:
account = Account.from_dict(self.ledger, self.db, account_dict)
self.accounts.append(account)
return account
2018-05-25 08:03:25 +02:00
@property
2020-05-01 15:33:10 +02:00
def default_account(self) -> Optional[Account]:
for account in self.accounts:
return account
return None
2020-05-01 15:33:10 +02:00
def get_account_or_default(self, account_id: str) -> Optional[Account]:
if account_id is None:
return self.default_account
return self.get_account_or_error(account_id)
2020-05-01 15:33:10 +02:00
def get_account_or_error(self, account_id: str) -> Account:
for account in self.accounts:
if account.id == account_id:
return account
raise ValueError(f"Couldn't find account: {account_id}.")
2020-05-01 15:33:10 +02:00
def get_accounts_or_all(self, account_ids: List[str]) -> Sequence[Account]:
return [
self.get_account_or_error(account_id)
for account_id in account_ids
] if account_ids else self.accounts
async def get_detailed_accounts(self, **kwargs):
accounts = []
for i, account in enumerate(self.accounts):
details = await account.get_details(**kwargs)
details['is_default'] = i == 0
accounts.append(details)
return accounts
2020-05-01 15:33:10 +02:00
async def _get_account_and_address_info_for_address(self, address):
match = await self.db.get_address(accounts=self.accounts, address=address)
if match:
for account in self.accounts:
if match['account'] == account.public_key.address:
return account, match
async def get_private_key_for_address(self, address) -> Optional[PrivateKey]:
match = await self._get_account_and_address_info_for_address(address)
if match:
account, address_info = match
return account.get_private_key(address_info['chain'], address_info['pubkey'].n)
return None
async def get_public_key_for_address(self, address) -> Optional[PubKey]:
match = await self._get_account_and_address_info_for_address(address)
if match:
_, address_info = match
return address_info['pubkey']
return None
async def get_account_for_address(self, address):
match = await self._get_account_and_address_info_for_address(address)
if match:
return match[0]
async def save_max_gap(self):
gap_changed = False
for account in self.accounts:
if await account.save_max_gap():
gap_changed = True
if gap_changed:
self.save()
2018-05-25 08:03:25 +02:00
@classmethod
2020-05-01 15:33:10 +02:00
def from_storage(cls, ledger: Ledger, db: Database, storage: 'WalletStorage') -> 'Wallet':
2018-05-25 08:03:25 +02:00
json_dict = storage.read()
2020-05-01 15:33:10 +02:00
if 'ledger' in json_dict and json_dict['ledger'] != ledger.get_id():
raise ValueError(
f"Using ledger {ledger.get_id()} but wallet is {json_dict['ledger']}."
)
wallet = cls(
2020-05-01 15:33:10 +02:00
ledger, db,
2018-05-25 08:03:25 +02:00
name=json_dict.get('name', 'Wallet'),
preferences=json_dict.get('preferences', {}),
2018-05-25 08:03:25 +02:00
storage=storage
)
account_dicts: Sequence[dict] = json_dict.get('accounts', [])
for account_dict in account_dicts:
2020-05-01 15:33:10 +02:00
wallet.add_account(account_dict)
return wallet
2018-05-25 08:03:25 +02:00
def to_dict(self, encrypt_password: str = None):
2018-05-25 08:03:25 +02:00
return {
'version': WalletStorage.LATEST_VERSION,
2018-05-25 08:03:25 +02:00
'name': self.name,
2020-05-01 15:33:10 +02:00
'ledger': self.ledger.get_id(),
'preferences': self.preferences.data,
'accounts': [a.to_dict(encrypt_password) for a in self.accounts]
2018-05-25 08:03:25 +02:00
}
def save(self):
2019-10-18 19:38:51 +02:00
if self.preferences.get(ENCRYPT_ON_DISK, False):
if self.encryption_password is not None:
2019-10-18 19:38:51 +02:00
return self.storage.write(self.to_dict(encrypt_password=self.encryption_password))
elif not self.is_locked:
log.warning(
"Disk encryption requested but no password available for encryption. "
"Saving wallet in an unencrypted state."
)
2019-10-18 19:38:51 +02:00
return self.storage.write(self.to_dict())
2018-05-25 08:03:25 +02:00
2019-03-11 17:04:06 +01:00
@property
2019-03-11 17:30:32 +01:00
def hash(self) -> bytes:
2019-03-11 17:04:06 +01:00
h = sha256()
if self.preferences.get(ENCRYPT_ON_DISK, False):
assert self.encryption_password is not None, \
"Encryption is enabled but no password is available, cannot generate hash."
h.update(self.encryption_password.encode())
h.update(self.preferences.hash)
2019-03-11 17:04:06 +01:00
for account in sorted(self.accounts, key=attrgetter('id')):
h.update(account.hash)
return h.digest()
def pack(self, password):
assert not self.is_locked, "Cannot pack a wallet with locked/encrypted accounts."
2019-03-11 17:04:06 +01:00
new_data = json.dumps(self.to_dict())
new_data_compressed = zlib.compress(new_data.encode())
return better_aes_encrypt(password, new_data_compressed)
@classmethod
def unpack(cls, password, encrypted):
decrypted = better_aes_decrypt(password, encrypted)
decompressed = zlib.decompress(decrypted)
return json.loads(decompressed)
2020-05-01 15:33:10 +02:00
def merge(self, password: str, data: str) -> List[Account]:
assert not self.is_locked, "Cannot sync apply on a locked wallet."
added_accounts = []
decrypted_data = self.unpack(password, data)
self.preferences.merge(decrypted_data.get('preferences', {}))
for account_dict in decrypted_data['accounts']:
2020-05-01 15:33:10 +02:00
_, _, pubkey = Account.keys_from_dict(self.ledger, account_dict)
account_id = pubkey.address
local_match = None
for local_account in self.accounts:
if account_id == local_account.id:
local_match = local_account
break
if local_match is not None:
local_match.merge(account_dict)
else:
2020-05-01 15:33:10 +02:00
added_accounts.append(
self.add_account(account_dict)
)
return added_accounts
2019-10-14 05:43:06 +02:00
@property
def is_locked(self) -> bool:
for account in self.accounts:
if account.encrypted:
return True
return False
def unlock(self, password):
for account in self.accounts:
if account.encrypted:
if not account.decrypt(password):
return False
self.encryption_password = password
return True
2019-10-14 05:43:06 +02:00
def lock(self):
assert self.encryption_password is not None, "Cannot lock an unencrypted wallet, encrypt first."
2019-10-14 05:43:06 +02:00
for account in self.accounts:
if not account.encrypted:
account.encrypt(self.encryption_password)
return True
2019-10-14 05:43:06 +02:00
@property
def is_encrypted(self) -> bool:
return self.is_locked or self.preferences.get(ENCRYPT_ON_DISK, False)
2019-10-14 05:43:06 +02:00
def decrypt(self):
assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first."
self.preferences[ENCRYPT_ON_DISK] = False
2019-10-14 05:43:06 +02:00
self.save()
return True
2019-10-14 05:43:06 +02:00
def encrypt(self, password):
assert not self.is_locked, "Cannot re-encrypt a locked wallet, unlock first."
2019-10-18 18:43:28 +02:00
assert password, "Cannot encrypt with blank password."
self.encryption_password = password
self.preferences[ENCRYPT_ON_DISK] = True
2019-10-14 05:43:06 +02:00
self.save()
return True
2019-10-14 05:43:06 +02:00
2020-05-01 15:33:10 +02:00
async def get_effective_amount_estimators(self, funding_accounts: Iterable[Account]):
estimators = []
2020-05-02 05:25:07 +02:00
utxos = await self.db.get_utxos(
accounts=funding_accounts,
txo_type__in=SPENDABLE_TYPE_CODES
)
for utxo in utxos[0]:
2020-05-01 15:33:10 +02:00
estimators.append(OutputEffectiveAmountEstimator(self.ledger, utxo))
return estimators
async def get_spendable_utxos(self, amount: int, funding_accounts: Iterable[Account]):
txos = await self.get_effective_amount_estimators(funding_accounts)
fee = Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger)
selector = CoinSelector(amount, fee)
spendables = selector.select(txos, self.ledger.coin_selection_strategy)
if spendables:
await self.db.reserve_outputs(s.txo for s in spendables)
return spendables
async def create_transaction(self, inputs: Iterable[Input], outputs: Iterable[Output],
funding_accounts: Iterable[Account], change_account: Account,
sign: bool = True):
""" Find optimal set of inputs when only outputs are provided; add change
outputs if only inputs are provided or if inputs are greater than outputs. """
tx = Transaction() \
.add_inputs(inputs) \
.add_outputs(outputs)
# value of the outputs plus associated fees
cost = (
tx.get_base_fee(self.ledger) +
tx.get_total_output_sum(self.ledger)
)
# value of the inputs less the cost to spend those inputs
payment = tx.get_effective_input_sum(self.ledger)
try:
for _ in range(5):
if payment < cost:
deficit = cost - payment
spendables = await self.get_spendable_utxos(deficit, funding_accounts)
if not spendables:
raise InsufficientFundsError()
payment += sum(s.effective_amount for s in spendables)
tx.add_inputs(s.txi for s in spendables)
cost_of_change = (
tx.get_base_fee(self.ledger) +
Output.pay_pubkey_hash(COIN, NULL_HASH32).get_fee(self.ledger)
)
if payment > cost:
change = payment - cost
if change > cost_of_change:
change_address = await change_account.change.get_or_create_usable_address()
change_hash160 = change_account.ledger.address_to_hash160(change_address)
change_amount = change - cost_of_change
change_output = Output.pay_pubkey_hash(change_amount, change_hash160)
change_output.is_internal_transfer = True
tx.add_outputs([Output.pay_pubkey_hash(change_amount, change_hash160)])
if tx._outputs:
break
# this condition and the outer range(5) loop cover an edge case
# whereby a single input is just enough to cover the fee and
# has some change left over, but the change left over is less
# than the cost_of_change: thus the input is completely
# consumed and no output is added, which is an invalid tx.
# to be able to spend this input we must increase the cost
# of the TX and run through the balance algorithm a second time
# adding an extra input and change output, making tx valid.
# we do this 5 times in case the other UTXOs added are also
# less than the fee, after 5 attempts we give up and go home
cost += cost_of_change + 1
if sign:
await self.sign(tx)
except Exception as e:
log.exception('Failed to create transaction:')
await self.db.release_tx(tx)
raise e
return tx
async def sign(self, tx):
for i, txi in enumerate(tx._inputs):
assert txi.script is not None
assert txi.txo_ref.txo is not None
txo_script = txi.txo_ref.txo.script
if txo_script.is_pay_pubkey_hash:
address = self.ledger.hash160_to_address(txo_script.values['pubkey_hash'])
private_key = await self.get_private_key_for_address(address)
assert private_key is not None, 'Cannot find private key for signing output.'
serialized = tx._serialize_for_signature(i)
txi.script.values['signature'] = \
private_key.sign(serialized) + bytes((tx.signature_hash_type(1),))
txi.script.values['pubkey'] = private_key.public_key.pubkey_bytes
txi.script.generate()
else:
raise NotImplementedError("Don't know how to spend this output.")
tx._reset()
@classmethod
def pay(cls, amount: int, address: bytes, funding_accounts: List['Account'], change_account: 'Account'):
output = Output.pay_pubkey_hash(amount, ledger.address_to_hash160(address))
return cls.create([], [output], funding_accounts, change_account)
def claim_create(
self, name: str, claim: Claim, amount: int, holding_address: str,
funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None):
claim_output = Output.pay_claim_name_pubkey_hash(
amount, name, claim, self.ledger.address_to_hash160(holding_address)
)
if signing_channel is not None:
claim_output.sign(signing_channel, b'placeholder txid:nout')
return self.create_transaction(
[], [claim_output], funding_accounts, change_account, sign=False
)
@classmethod
def claim_update(
cls, previous_claim: Output, claim: Claim, amount: int, holding_address: str,
funding_accounts: List['Account'], change_account: 'Account', signing_channel: Output = None):
updated_claim = Output.pay_update_claim_pubkey_hash(
amount, previous_claim.claim_name, previous_claim.claim_id,
claim, ledger.address_to_hash160(holding_address)
)
if signing_channel is not None:
updated_claim.sign(signing_channel, b'placeholder txid:nout')
else:
updated_claim.clear_signature()
return cls.create(
[Input.spend(previous_claim)], [updated_claim], funding_accounts, change_account, sign=False
)
@classmethod
def support(cls, claim_name: str, claim_id: str, amount: int, holding_address: str,
funding_accounts: List['Account'], change_account: 'Account'):
support_output = Output.pay_support_pubkey_hash(
amount, claim_name, claim_id, ledger.address_to_hash160(holding_address)
)
return cls.create([], [support_output], funding_accounts, change_account)
def purchase(self, claim_id: str, amount: int, merchant_address: bytes,
funding_accounts: List['Account'], change_account: 'Account'):
payment = Output.pay_pubkey_hash(amount, self.ledger.address_to_hash160(merchant_address))
data = Output.add_purchase_data(Purchase(claim_id))
return self.create_transaction(
[], [payment, data], funding_accounts, change_account
)
async def create_purchase_transaction(
self, accounts: List[Account], txo: Output, exchange: 'ExchangeRateManager',
override_max_key_fee=False):
fee = txo.claim.stream.fee
fee_amount = exchange.to_dewies(fee.currency, fee.amount)
if not override_max_key_fee and self.ledger.conf.max_key_fee:
max_fee = self.ledger.conf.max_key_fee
max_fee_amount = exchange.to_dewies(max_fee['currency'], Decimal(max_fee['amount']))
if max_fee_amount and fee_amount > max_fee_amount:
error_fee = f"{dewies_to_lbc(fee_amount)} LBC"
if fee.currency != 'LBC':
error_fee += f" ({fee.amount} {fee.currency})"
error_max_fee = f"{dewies_to_lbc(max_fee_amount)} LBC"
if max_fee['currency'] != 'LBC':
error_max_fee += f" ({max_fee['amount']} {max_fee['currency']})"
raise KeyFeeAboveMaxAllowedError(
f"Purchase price of {error_fee} exceeds maximum "
f"configured price of {error_max_fee}."
)
fee_address = fee.address or txo.get_address(self.ledger)
return await self.purchase(
txo.claim_id, fee_amount, fee_address, accounts, accounts[0]
)
async def create_channel(
self, name, amount, account, funding_accounts,
claim_address, preview=False, **kwargs):
claim = Claim()
claim.channel.update(**kwargs)
tx = await self.claim_create(
name, claim, amount, claim_address, funding_accounts, funding_accounts[0]
)
txo = tx.outputs[0]
txo.generate_channel_private_key()
await self.sign(tx)
if not preview:
account.add_channel_private_key(txo.private_key)
self.save()
return tx
async def get_channels(self):
return await self.db.get_channels()
2018-05-25 08:03:25 +02:00
class WalletStorage:
2018-07-01 23:20:17 +02:00
LATEST_VERSION = 1
2018-05-25 08:03:25 +02:00
def __init__(self, path=None, default=None):
self.path = path
2018-07-12 04:37:15 +02:00
self._default = default or {
'version': self.LATEST_VERSION,
'name': 'My Wallet',
'preferences': {},
2018-07-12 04:37:15 +02:00
'accounts': []
}
2018-05-25 08:03:25 +02:00
def read(self):
if self.path and os.path.exists(self.path):
2018-07-12 04:37:15 +02:00
with open(self.path, 'r') as f:
2018-05-25 08:03:25 +02:00
json_data = f.read()
json_dict = json.loads(json_data)
if json_dict.get('version') == self.LATEST_VERSION and \
set(json_dict) == set(self._default):
return json_dict
else:
return self.upgrade(json_dict)
else:
2018-07-12 04:37:15 +02:00
return self._default.copy()
2018-05-25 08:03:25 +02:00
2018-07-12 04:37:15 +02:00
def upgrade(self, json_dict):
2018-05-25 08:03:25 +02:00
json_dict = json_dict.copy()
version = json_dict.pop('version', -1)
2018-07-12 04:37:15 +02:00
if version == -1:
2018-07-01 23:20:17 +02:00
pass
2018-07-12 04:37:15 +02:00
upgraded = self._default.copy()
2018-05-25 08:03:25 +02:00
upgraded.update(json_dict)
return json_dict
def write(self, json_dict):
json_data = json.dumps(json_dict, indent=4, sort_keys=True)
if self.path is None:
return json_data
temp_path = "{}.tmp.{}".format(self.path, os.getpid())
2018-05-25 08:03:25 +02:00
with open(temp_path, "w") as f:
f.write(json_data)
f.flush()
os.fsync(f.fileno())
if os.path.exists(self.path):
mode = os.stat(self.path).st_mode
else:
mode = stat.S_IREAD | stat.S_IWRITE
try:
os.rename(temp_path, self.path)
except Exception: # pylint: disable=broad-except
2018-05-25 08:03:25 +02:00
os.remove(self.path)
os.rename(temp_path, self.path)
os.chmod(self.path, mode)