lbry-sdk/lbry/wallet/wallet.py

317 lines
11 KiB
Python
Raw Permalink Normal View History

2019-03-11 17:04:06 +01:00
import os
import time
2018-05-25 08:03:25 +02:00
import stat
import json
2019-03-11 17:04:06 +01:00
import zlib
import typing
import logging
from typing import List, Sequence, MutableSequence, Optional
2019-10-13 01:40:32 +02:00
from collections import UserDict
2019-03-11 17:12:26 +01:00
from hashlib import sha256
2019-03-11 17:04:06 +01:00
from operator import attrgetter
from lbry.crypto.crypt import better_aes_encrypt, better_aes_decrypt
from lbry.error import InvalidPasswordError
2020-01-03 04:18:49 +01:00
from .account import Account
2018-05-25 08:03:25 +02:00
if typing.TYPE_CHECKING:
2020-01-03 04:18:49 +01:00
from lbry.wallet.manager import WalletManager
from lbry.wallet.ledger import Ledger
2018-05-25 08:03:25 +02:00
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Preference key read by Wallet.save() and Wallet.is_encrypted; when its value is
# truthy and a password is known, the wallet is serialized with that password.
ENCRYPT_ON_DISK = 'encrypt-on-disk'
class TimestampedPreferences(UserDict):
    """Preference mapping whose entries remember when they were last set.

    Internally every key maps to ``{'value': ..., 'ts': ...}``. Item reads
    return only the value; item writes stamp the entry with the current
    integer Unix time. ``merge`` uses the stamps to keep the newest entry.
    """

    def __init__(self, d: dict = None):
        """Start empty, or from a shallow copy of an already-stamped dict ``d``."""
        super().__init__()
        if d is not None:
            self.data = d.copy()

    def __getitem__(self, key):
        """Return the stored value for ``key`` (without its timestamp)."""
        entry = self.data[key]
        return entry['value']

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, stamped with the current time in seconds."""
        self.data[key] = dict(value=value, ts=int(time.time()))

    def __repr__(self):
        """Show the plain ``key -> value`` view, timestamps omitted."""
        return repr(self.to_dict_without_ts())

    def to_dict_without_ts(self):
        """Return a new plain dict mapping each key to its bare value."""
        plain = {}
        for key, entry in self.data.items():
            plain[key] = entry['value']
        return plain

    @property
    def hash(self):
        """SHA-256 digest (bytes) of the JSON dump of the raw stamped data."""
        serialized = json.dumps(self.data)
        return sha256(serialized.encode()).digest()

    def merge(self, other: dict):
        """Fold stamped entries from ``other`` into this mapping.

        An incoming entry is skipped only when the key already exists here
        with a strictly newer timestamp; on equal timestamps the incoming
        entry wins. Unknown keys are always added.
        """
        for key, incoming in other.items():
            is_stale = key in self.data and incoming['ts'] < self.data[key]['ts']
            if not is_stale:
                self.data[key] = incoming
2018-05-25 08:03:25 +02:00
class Wallet:
    """Groups a collection of accounts (seed/private keys) together with the
    preferences that apply to them, and persists both through a
    ``WalletStorage`` (a JSON file on disk when the storage has a path).
    """

    preferences: TimestampedPreferences
    encryption_password: Optional[str]

    def __init__(self, name: str = 'Wallet', accounts: MutableSequence['Account'] = None,
                 storage: 'WalletStorage' = None, preferences: dict = None) -> None:
        """Build a wallet; omitted arguments fall back to an empty account
        list, a path-less ``WalletStorage`` and empty preferences."""
        self.name = name
        self.accounts = accounts or []
        self.storage = storage or WalletStorage()
        self.preferences = TimestampedPreferences(preferences or {})
        # Only set once unlock() or encrypt() has been given a password.
        self.encryption_password = None
        self.id = self.get_id()

    def get_id(self):
        """Return the storage file's base name, or the wallet name when the
        storage has no path."""
        path = self.storage.path
        if path:
            return os.path.basename(path)
        return self.name

    def add_account(self, account: 'Account'):
        """Append ``account`` to this wallet's account list."""
        self.accounts.append(account)

    def generate_account(self, ledger: 'Ledger') -> 'Account':
        """Delegate to ``Account.generate`` for ``ledger`` and this wallet."""
        return Account.generate(ledger, self)

    @property
    def default_account(self) -> Optional['Account']:
        """The first account in the wallet, or ``None`` when there are none."""
        return next(iter(self.accounts), None)

    def get_account_or_default(self, account_id: str) -> Optional['Account']:
        """Look up ``account_id``; a ``None`` id selects the default account.

        Raises ``ValueError`` when a non-``None`` id is unknown.
        """
        if account_id is not None:
            return self.get_account_or_error(account_id)
        return self.default_account

    def get_account_or_error(self, account_id: str) -> 'Account':
        """Return the account whose ``id`` equals ``account_id``.

        Raises ``ValueError`` when no account matches.
        """
        for candidate in self.accounts:
            if candidate.id == account_id:
                return candidate
        raise ValueError(f"Couldn't find account: {account_id}.")

    def get_accounts_or_all(self, account_ids: List[str]) -> Sequence['Account']:
        """Resolve each id in ``account_ids``; an empty or ``None`` list
        returns the wallet's own account list."""
        if not account_ids:
            return self.accounts
        return [self.get_account_or_error(account_id) for account_id in account_ids]

    async def get_detailed_accounts(self, **kwargs):
        """Collect ``account.get_details(**kwargs)`` for every account, adding
        an ``is_default`` flag that is true only for the first account."""
        detailed = []
        for position, account in enumerate(self.accounts):
            info = await account.get_details(**kwargs)
            info['is_default'] = position == 0
            detailed.append(info)
        return detailed

    @classmethod
    def from_storage(cls, storage: 'WalletStorage', manager: 'WalletManager') -> 'Wallet':
        """Create a wallet from the dict returned by ``storage.read()``.

        Each stored account dict is passed to ``Account.from_dict`` together
        with a ledger obtained from ``manager``.
        """
        stored = storage.read()
        wallet = cls(
            name=stored.get('name', 'Wallet'),
            preferences=stored.get('preferences', {}),
            storage=storage
        )
        stored_accounts: Sequence[dict] = stored.get('accounts', [])
        for account_dict in stored_accounts:
            ledger = manager.get_or_create_ledger(account_dict['ledger'])
            # The result is not appended here; presumably Account.from_dict
            # registers the account with the wallet itself -- confirm.
            Account.from_dict(ledger, wallet, account_dict)
        return wallet

    def to_dict(self, encrypt_password: str = None):
        """Return the serializable form of the wallet; ``encrypt_password``
        is forwarded to every account's ``to_dict``."""
        serialized_accounts = [a.to_dict(encrypt_password) for a in self.accounts]
        return {
            'version': WalletStorage.LATEST_VERSION,
            'name': self.name,
            'preferences': self.preferences.data,
            'accounts': serialized_accounts
        }

    def to_json(self):
        """Return ``to_dict()`` as a JSON string; the wallet must not be locked."""
        assert not self.is_locked, "Cannot serialize a wallet with locked/encrypted accounts."
        return json.dumps(self.to_dict())

    def save(self):
        """Write the wallet through its storage and return the storage's result.

        With the encrypt-on-disk preference set and a password known, the
        wallet is serialized with that password. If the preference is set
        but there is no password and the wallet is not locked, the
        preference is reset (with a warning) before a plain save.
        """
        if self.preferences.get(ENCRYPT_ON_DISK, False):
            if self.encryption_password is not None:
                return self.storage.write(self.to_dict(encrypt_password=self.encryption_password))
            if not self.is_locked:
                log.warning(
                    "Disk encryption requested but no password available for encryption. "
                    "Resetting encryption preferences and saving wallet in an unencrypted state."
                )
                self.preferences[ENCRYPT_ON_DISK] = False
        return self.storage.write(self.to_dict())

    @property
    def hash(self) -> bytes:
        """SHA-256 digest over the password (only when encrypted), the
        preferences hash and each account's hash in ``id`` order."""
        digest = sha256()
        if self.is_encrypted:
            assert self.encryption_password is not None, \
                "Encryption is enabled but no password is available, cannot generate hash."
            digest.update(self.encryption_password.encode())
        digest.update(self.preferences.hash)
        # Sorting by id makes the digest independent of account list order.
        for account in sorted(self.accounts, key=attrgetter('id')):
            digest.update(account.hash)
        return digest.digest()

    def pack(self, password):
        """Return the wallet JSON, zlib-compressed and then encrypted with
        ``password``; the wallet must not be locked."""
        assert not self.is_locked, "Cannot pack a wallet with locked/encrypted accounts."
        compressed = zlib.compress(self.to_json().encode())
        return better_aes_encrypt(password, compressed)

    @classmethod
    def unpack(cls, password, encrypted):
        """Reverse ``pack``: decrypt, decompress and parse the JSON payload.

        Raises ``InvalidPasswordError`` when decompression fails with an
        "incorrect header check" or "unknown compression method" error;
        any other ``zlib.error`` is re-raised unchanged.
        """
        decrypted = better_aes_decrypt(password, encrypted)
        try:
            decompressed = zlib.decompress(decrypted)
        except zlib.error as e:
            reason = e.args[0].lower()
            if "incorrect header check" in reason or "unknown compression method" in reason:
                raise InvalidPasswordError()
            raise
        return json.loads(decompressed)

    def merge(self, manager: 'WalletManager',
              password: str, data: str) -> (List['Account'], List['Account']):
        """Apply externally supplied wallet data to this (unlocked) wallet.

        ``data`` is parsed as JSON when ``password`` is ``None`` and passed
        to ``unpack`` otherwise. Preferences are merged by timestamp. An
        incoming account whose public-key address equals a local account id
        is merged into that account; any other is created via
        ``Account.from_dict``.

        Returns ``(added_accounts, merged_accounts)``.
        """
        assert not self.is_locked, "Cannot sync apply on a locked wallet."
        added_accounts, merged_accounts = [], []
        if password is None:
            decrypted_data = json.loads(data)
        else:
            decrypted_data = self.unpack(password, data)
        self.preferences.merge(decrypted_data.get('preferences', {}))
        for account_dict in decrypted_data['accounts']:
            ledger = manager.get_or_create_ledger(account_dict['ledger'])
            _, _, pubkey = Account.keys_from_dict(ledger, account_dict)
            account_id = pubkey.address
            local_match = next(
                (local for local in self.accounts if account_id == local.id), None
            )
            if local_match is None:
                added_accounts.append(Account.from_dict(ledger, self, account_dict))
            else:
                local_match.merge(account_dict)
                merged_accounts.append(local_match)
        return added_accounts, merged_accounts

    @property
    def is_locked(self) -> bool:
        """True when at least one account reports itself as encrypted."""
        return any(account.encrypted for account in self.accounts)

    async def unlock(self, password):
        """Decrypt every encrypted account with ``password``.

        Returns ``False`` as soon as one account rejects the password; in
        that case the password is not remembered and accounts decrypted
        earlier in the loop are left decrypted. Returns ``True`` otherwise.
        """
        for account in self.accounts:
            if account.encrypted:
                if not account.decrypt(password):
                    return False
                await account.deterministic_channel_keys.ensure_cache_primed()
        self.encryption_password = password
        return True

    def lock(self):
        """Encrypt every not-yet-encrypted account with the remembered
        password; a password must already be set."""
        assert self.encryption_password is not None, "Cannot lock an unencrypted wallet, encrypt first."
        for account in self.accounts:
            if not account.encrypted:
                account.encrypt(self.encryption_password)
        return True

    @property
    def is_encrypted(self) -> bool:
        """Whether the wallet counts as encrypted.

        That is the case when it is locked, or when the encrypt-on-disk
        preference is set and a password has been supplied. The preference
        alone, with no password given so far, does not count.
        """
        if self.is_locked:
            return True
        return self.preferences.get(ENCRYPT_ON_DISK, False) and self.encryption_password is not None

    def decrypt(self):
        """Turn off the encrypt-on-disk preference and save; the wallet must
        be unlocked."""
        assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first."
        self.preferences[ENCRYPT_ON_DISK] = False
        self.save()
        return True

    def encrypt(self, password):
        """Remember ``password``, turn on the encrypt-on-disk preference and
        save; the wallet must be unlocked and the password non-blank."""
        assert not self.is_locked, "Cannot re-encrypt a locked wallet, unlock first."
        assert password, "Cannot encrypt with blank password."
        self.encryption_password = password
        self.preferences[ENCRYPT_ON_DISK] = True
        self.save()
        return True
2019-10-14 05:43:06 +02:00
2018-05-25 08:03:25 +02:00
class WalletStorage:
    """Reads and writes a wallet's JSON representation.

    With ``path`` set the data lives in that file; without a path nothing
    is persisted: ``read`` returns the defaults and ``write`` returns the
    JSON text instead of storing it.
    """

    LATEST_VERSION = 1

    def __init__(self, path=None, default=None):
        """
        :param path: wallet file location, or ``None`` for in-memory use.
        :param default: dict used for a missing file and to fill in missing
            keys during ``upgrade``; a falsy value selects the built-in
            defaults below.
        """
        self.path = path
        self._default = default or {
            'version': self.LATEST_VERSION,
            'name': 'My Wallet',
            'preferences': {},
            'accounts': []
        }

    def read(self):
        """Return the stored wallet dict.

        A file that is already at ``LATEST_VERSION`` and has exactly the
        default key set is returned as parsed; any other file goes through
        ``upgrade``. When there is no path, or the file does not exist, a
        shallow copy of the defaults is returned.
        """
        if self.path and os.path.exists(self.path):
            with open(self.path, 'r') as f:
                json_data = f.read()
            json_dict = json.loads(json_data)
            if json_dict.get('version') == self.LATEST_VERSION and \
                    set(json_dict) == set(self._default):
                return json_dict
            return self.upgrade(json_dict)
        return self._default.copy()

    def upgrade(self, json_dict):
        """Return a new dict: the defaults overlaid with ``json_dict``,
        stamped with ``LATEST_VERSION``. ``json_dict`` is not modified.

        There are no version-specific migrations yet, so the stored version
        is simply discarded.
        """
        stored = json_dict.copy()
        stored.pop('version', None)
        upgraded = self._default.copy()
        upgraded.update(stored)
        upgraded['version'] = self.LATEST_VERSION
        # The original returned ``json_dict`` here, throwing away the merge.
        return upgraded

    def write(self, json_dict):
        """Serialize ``json_dict`` as indented, key-sorted JSON.

        Without a path the JSON string is returned. With a path the data is
        written to a temporary sibling file, fsynced, and moved over the
        target in one step; the target keeps its previous permission bits,
        or gets owner read/write only when newly created. Returns ``None``
        in that case. If anything fails the temporary file is removed and
        the existing wallet file is left untouched.
        """
        json_data = json.dumps(json_dict, indent=4, sort_keys=True)
        if self.path is None:
            return json_data

        owner_only = stat.S_IREAD | stat.S_IWRITE
        if os.path.exists(self.path):
            mode = stat.S_IMODE(os.stat(self.path).st_mode)
        else:
            mode = owner_only

        temp_path = "{}.tmp.{}".format(self.path, os.getpid())
        try:
            # Create the temp file owner-only so key material is not readable
            # by other users while it is being written.
            fd = os.open(temp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_only)
            with os.fdopen(fd, "w") as f:
                f.write(json_data)
                f.flush()
                os.fsync(f.fileno())
            # os.replace overwrites an existing target on every platform, so
            # the old wallet never has to be deleted first.
            os.replace(temp_path, self.path)
        except Exception:
            # Best-effort cleanup: do not leave a stray copy of the wallet.
            try:
                os.remove(temp_path)
            except OSError:
                pass
            raise
        os.chmod(self.path, mode)