From e5be1ff157ff359bd9e521262788865fe50a449d Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Wed, 16 Oct 2019 01:18:39 -0400 Subject: [PATCH] bug fixes and encrypt-on-disk preference --- lbry/lbry/extras/daemon/Daemon.py | 25 +-- lbry/lbry/wallet/account.py | 6 +- .../tests/integration/test_wallet_commands.py | 142 +++++++++++++++--- torba/tests/client_tests/unit/test_account.py | 34 ++--- torba/tests/client_tests/unit/test_wallet.py | 6 + torba/torba/client/baseaccount.py | 48 +++--- torba/torba/client/wallet.py | 55 +++++-- 7 files changed, 217 insertions(+), 99 deletions(-) diff --git a/lbry/lbry/extras/daemon/Daemon.py b/lbry/lbry/extras/daemon/Daemon.py index 40eb38555..ca1c9fb26 100644 --- a/lbry/lbry/extras/daemon/Daemon.py +++ b/lbry/lbry/extras/daemon/Daemon.py @@ -1197,8 +1197,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: (bool) true if wallet is unlocked, otherwise false """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - return wallet.unlock(password) + return self.wallet_manager.get_wallet_or_default(wallet_id).unlock(password) @requires(WALLET_COMPONENT) def jsonrpc_wallet_lock(self, wallet_id=None): @@ -1214,8 +1213,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: (bool) true if wallet is locked, otherwise false """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - return wallet.lock() + return self.wallet_manager.get_wallet_or_default(wallet_id).lock() @requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED]) def jsonrpc_wallet_decrypt(self, wallet_id=None): @@ -1231,8 +1229,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: (bool) true if wallet is decrypted, otherwise false """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - return wallet.decrypt() + return self.wallet_manager.get_wallet_or_default(wallet_id).decrypt() @requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED]) def jsonrpc_wallet_encrypt(self, new_password, wallet_id=None): @@ -1250,8 +1247,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: (bool) true if wallet is decrypted, otherwise false """ - wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - return wallet.encrypt(new_password) + return self.wallet_manager.get_wallet_or_default(wallet_id).encrypt(new_password) @requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED]) async def jsonrpc_wallet_send( @@ -1604,7 +1600,7 @@ class Daemon(metaclass=JSONRPCServerType): Wallet synchronization. """ - @requires("wallet") + @requires("wallet", conditions=[WALLET_IS_UNLOCKED]) def jsonrpc_sync_hash(self, wallet_id=None): """ Deterministic hash of the wallet. @@ -1621,21 +1617,18 @@ class Daemon(metaclass=JSONRPCServerType): wallet = self.wallet_manager.get_wallet_or_default(wallet_id) return hexlify(wallet.hash).decode() - @requires("wallet") + @requires("wallet", conditions=[WALLET_IS_UNLOCKED]) async def jsonrpc_sync_apply(self, password, data=None, encrypt_password=None, wallet_id=None, blocking=False): """ Apply incoming synchronization data, if provided, and then produce a sync hash and an encrypted wallet. Usage: - sync_apply [--data=] [--encrypt-password=] - [--wallet_id=] [--blocking] + sync_apply [--data=] [--wallet_id=] [--blocking] Options: --password= : (str) password to decrypt incoming and encrypt outgoing data --data= : (str) incoming sync data, if any - --encrypt-password= : (str) password to encrypt outgoing data if different - from the decrypt password, used during password changes --wallet_id= : (str) wallet being sync'ed --blocking : (bool) wait until any new accounts have sync'ed @@ -1644,7 +1637,6 @@ class Daemon(metaclass=JSONRPCServerType): """ wallet = self.wallet_manager.get_wallet_or_default(wallet_id) - assert not wallet.is_locked, "Cannot sync apply on a locked wallet." if data is not None: added_accounts = wallet.merge(self.wallet_manager, password, data) if added_accounts and self.ledger.network.is_connected: @@ -1656,8 +1648,7 @@ class Daemon(metaclass=JSONRPCServerType): for new_account in added_accounts: asyncio.create_task(self.ledger.subscribe_account(new_account)) wallet.save() - wallet.unlock(password) - encrypted = wallet.pack(encrypt_password or password) + encrypted = wallet.pack(password) return { 'hash': self.jsonrpc_sync_hash(wallet_id), 'data': encrypted.decode() diff --git a/lbry/lbry/wallet/account.py b/lbry/lbry/wallet/account.py index c3e9b08dc..126d09f85 100644 --- a/lbry/lbry/wallet/account.py +++ b/lbry/lbry/wallet/account.py @@ -30,7 +30,7 @@ class Account(BaseAccount): @property def hash(self) -> bytes: - h = sha256(json.dumps(self.to_dict(False)).encode()) + h = sha256(json.dumps(self.to_dict(include_channel_keys=False)).encode()) for cert in sorted(self.channel_keys.keys()): h.update(cert.encode()) return h.digest() @@ -119,8 +119,8 @@ class Account(BaseAccount): account.channel_keys = d.get('certificates', {}) return account - def to_dict(self, include_channel_keys=True): - d = super().to_dict() + def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True): + d = super().to_dict(encrypt_password) if include_channel_keys: d['certificates'] = self.channel_keys return d diff --git a/lbry/tests/integration/test_wallet_commands.py b/lbry/tests/integration/test_wallet_commands.py index 161eef8bc..7a3d0062f 100644 --- a/lbry/tests/integration/test_wallet_commands.py +++ b/lbry/tests/integration/test_wallet_commands.py @@ -1,20 +1,33 @@ import asyncio +import json +from lbry import error from lbry.testcase import CommandTestCase -from binascii import unhexlify +from torba.client.wallet import ENCRYPT_ON_DISK -class WalletSynchronization(CommandTestCase): - SEED = "carbon smart garage balance margin twelve chest sword toast envelope bottom stomach absent" +class WalletEncryptionAndSynchronization(CommandTestCase): - async def test_sync(self): - daemon = self.daemon - daemon2 = await self.add_daemon( + SEED = ( + "carbon smart garage balance margin twelve chest " + "sword toast envelope bottom stomach absent" + ) + + async def asyncSetUp(self): + await super().asyncSetUp() + self.daemon2 = await self.add_daemon( seed="chest sword toast envelope bottom stomach absent " "carbon smart garage balance margin twelve" ) - address = (await daemon2.wallet_manager.default_account.receiving.get_addresses(limit=1, only_usable=True))[0] + address = (await self.daemon2.wallet_manager.default_account.receiving.get_addresses(limit=1, only_usable=True))[0] sendtxid = await self.blockchain.send_to_address(address, 1) - await self.confirm_tx(sendtxid, daemon2.ledger) + await self.confirm_tx(sendtxid, self.daemon2.ledger) + + def assertWalletEncrypted(self, wallet_path, encrypted): + wallet = json.load(open(wallet_path)) + self.assertEqual(wallet['accounts'][0]['private_key'][1:4] != 'prv', encrypted) + + async def test_sync(self): + daemon, daemon2 = self.daemon, self.daemon2 # Preferences self.assertFalse(daemon.jsonrpc_preference_get()) @@ -30,16 +43,12 @@ class WalletSynchronization(CommandTestCase): self.assertDictEqual(daemon.jsonrpc_preference_get(), { "one": "1", "conflict": "1", "fruit": ["peach", "apricot"] }) - self.assertDictEqual(daemon2.jsonrpc_preference_get(), {"two": "2", "conflict": "2"}) + self.assertDictEqual(daemon2.jsonrpc_preference_get(), { + "two": "2", "conflict": "2" + }) self.assertEqual(len((await daemon.jsonrpc_account_list())['lbc_regtest']), 1) - daemon2.jsonrpc_wallet_encrypt('password') - daemon2.jsonrpc_wallet_lock() - with self.assertRaises(AssertionError): - await daemon2.jsonrpc_sync_apply('password') - - daemon2.jsonrpc_wallet_unlock('password') data = await daemon2.jsonrpc_sync_apply('password') await daemon.jsonrpc_sync_apply('password', data=data['data'], blocking=True) @@ -52,9 +61,7 @@ class WalletSynchronization(CommandTestCase): # Channel Certificate channel = await daemon2.jsonrpc_channel_create('@foo', '0.1') - await daemon2.ledger.wait(channel) - await self.generate(1) - await daemon2.ledger.wait(channel) + await self.confirm_tx(channel.id, self.daemon2.ledger) # both daemons will have the channel but only one has the cert so far self.assertEqual(len(await daemon.jsonrpc_channel_list()), 1) @@ -70,3 +77,102 @@ class WalletSynchronization(CommandTestCase): daemon2.wallet_manager.default_account.channel_keys, daemon.wallet_manager.default_wallet.accounts[1].channel_keys ) + + async def test_encryption_and_locking(self): + daemon = self.daemon + wallet = daemon.wallet_manager.default_wallet + + self.assertEqual( + daemon.jsonrpc_wallet_status(), + {'is_locked': False, 'is_encrypted': False} + ) + self.assertIsNone(daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK)) + self.assertWalletEncrypted(wallet.storage.path, False) + + # can't lock an unencrypted account + with self.assertRaisesRegex(AssertionError, "Cannot lock an unencrypted wallet, encrypt first."): + daemon.jsonrpc_wallet_lock() + # safe to call unlock and decrypt, they are no-ops at this point + daemon.jsonrpc_wallet_unlock('password') # already unlocked + daemon.jsonrpc_wallet_decrypt() # already not encrypted + + daemon.jsonrpc_wallet_encrypt('password') + + self.assertEqual( + daemon.jsonrpc_wallet_status(), + {'is_locked': False, 'is_encrypted': True} + ) + self.assertEqual( + daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK), + {'encrypt-on-disk': True} + ) + self.assertWalletEncrypted(wallet.storage.path, True) + + daemon.jsonrpc_wallet_lock() + + self.assertEqual( + daemon.jsonrpc_wallet_status(), + {'is_locked': True, 'is_encrypted': True} + ) + + with self.assertRaises(error.ComponentStartConditionNotMet): + await daemon.jsonrpc_channel_create('@foo', '1.0') + + daemon.jsonrpc_wallet_unlock('password') + await daemon.jsonrpc_channel_create('@foo', '1.0') + + daemon.jsonrpc_wallet_decrypt() + self.assertEqual( + daemon.jsonrpc_wallet_status(), + {'is_locked': False, 'is_encrypted': False} + ) + self.assertEqual( + daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK), + {'encrypt-on-disk': False} + ) + self.assertWalletEncrypted(wallet.storage.path, False) + + async def test_sync_with_encryption_and_password_change(self): + daemon, daemon2 = self.daemon, self.daemon2 + wallet, wallet2 = daemon.wallet_manager.default_wallet, daemon2.wallet_manager.default_wallet + + daemon.jsonrpc_wallet_encrypt('password') + + self.assertEqual(daemon.jsonrpc_wallet_status(), {'is_locked': False, 'is_encrypted': True}) + self.assertEqual(daemon2.jsonrpc_wallet_status(), {'is_locked': False, 'is_encrypted': False}) + self.assertEqual(daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK), {'encrypt-on-disk': True}) + self.assertIsNone(daemon2.jsonrpc_preference_get(ENCRYPT_ON_DISK)) + self.assertWalletEncrypted(wallet.storage.path, True) + self.assertWalletEncrypted(wallet2.storage.path, False) + + data = await daemon2.jsonrpc_sync_apply('password2') + with self.assertRaises(ValueError): # wrong password + await daemon.jsonrpc_sync_apply('password', data=data['data'], blocking=True) + await daemon.jsonrpc_sync_apply('password2', data=data['data'], blocking=True) + + # encryption did not change from before sync_apply + self.assertEqual(daemon.jsonrpc_wallet_status(), {'is_locked': False, 'is_encrypted': True}) + self.assertEqual(daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK), {'encrypt-on-disk': True}) + self.assertWalletEncrypted(wallet.storage.path, True) + + # old password is still used + daemon.jsonrpc_wallet_lock() + self.assertFalse(daemon.jsonrpc_wallet_unlock('password2')) + self.assertTrue(daemon.jsonrpc_wallet_unlock('password')) + + # encrypt using new password + daemon.jsonrpc_wallet_encrypt('password2') + daemon.jsonrpc_wallet_lock() + self.assertFalse(daemon.jsonrpc_wallet_unlock('password')) + self.assertTrue(daemon.jsonrpc_wallet_unlock('password2')) + + data = await daemon.jsonrpc_sync_apply('password2') + await daemon2.jsonrpc_sync_apply('password2', data=data['data'], blocking=True) + + # wallet2 is now encrypted using new password + self.assertEqual(daemon2.jsonrpc_wallet_status(), {'is_locked': False, 'is_encrypted': True}) + self.assertEqual(daemon2.jsonrpc_preference_get(ENCRYPT_ON_DISK), {'encrypt-on-disk': True}) + self.assertWalletEncrypted(wallet.storage.path, True) + + daemon2.jsonrpc_wallet_lock() + self.assertTrue(daemon2.jsonrpc_wallet_unlock('password2')) diff --git a/torba/tests/client_tests/unit/test_account.py b/torba/tests/client_tests/unit/test_account.py index a9785ac2d..3afdfcaae 100644 --- a/torba/tests/client_tests/unit/test_account.py +++ b/torba/tests/client_tests/unit/test_account.py @@ -435,14 +435,14 @@ class AccountEncryptionTests(AsyncioTestCase): def test_encrypt_wallet(self): account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account) - account.private_key_encryption_init_vector = self.init_vector - account.seed_encryption_init_vector = self.init_vector + account.init_vectors = { + 'seed': self.init_vector, + 'private_key': self.init_vector + } - self.assertFalse(account.serialize_encrypted) self.assertFalse(account.encrypted) self.assertIsNotNone(account.private_key) account.encrypt(self.password) - self.assertFalse(account.serialize_encrypted) self.assertTrue(account.encrypted) self.assertEqual(account.seed, self.encrypted_account['seed']) self.assertEqual(account.private_key_string, self.encrypted_account['private_key']) @@ -451,42 +451,32 @@ class AccountEncryptionTests(AsyncioTestCase): self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) - account.serialize_encrypted = True account.decrypt(self.password) - self.assertEqual(account.private_key_encryption_init_vector, self.init_vector) - self.assertEqual(account.seed_encryption_init_vector, self.init_vector) + self.assertEqual(account.init_vectors['private_key'], self.init_vector) + self.assertEqual(account.init_vectors['seed'], self.init_vector) self.assertEqual(account.seed, self.unencrypted_account['seed']) self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key']) - self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) - self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) + self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed']) + self.assertEqual(account.to_dict(encrypt_password=self.password)['private_key'], self.encrypted_account['private_key']) self.assertFalse(account.encrypted) - self.assertTrue(account.serialize_encrypted) - - account.serialize_encrypted = False - self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed']) - self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key']) def test_decrypt_wallet(self): account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account) self.assertTrue(account.encrypted) - self.assertTrue(account.serialize_encrypted) account.decrypt(self.password) - self.assertEqual(account.private_key_encryption_init_vector, self.init_vector) - self.assertEqual(account.seed_encryption_init_vector, self.init_vector) + self.assertEqual(account.init_vectors['private_key'], self.init_vector) + self.assertEqual(account.init_vectors['seed'], self.init_vector) self.assertFalse(account.encrypted) - self.assertTrue(account.serialize_encrypted) self.assertEqual(account.seed, self.unencrypted_account['seed']) self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key']) - self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) - self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) - - account.serialize_encrypted = False + self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed']) + self.assertEqual(account.to_dict(encrypt_password=self.password)['private_key'], self.encrypted_account['private_key']) self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed']) self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key']) diff --git a/torba/tests/client_tests/unit/test_wallet.py b/torba/tests/client_tests/unit/test_wallet.py index a9424e2d4..7bf88c5ff 100644 --- a/torba/tests/client_tests/unit/test_wallet.py +++ b/torba/tests/client_tests/unit/test_wallet.py @@ -116,6 +116,12 @@ class TestWalletCreation(AsyncioTestCase): class TestTimestampedPreferences(TestCase): + def test_init(self): + p = TimestampedPreferences() + p['one'] = 1 + p2 = TimestampedPreferences(p.data) + self.assertEqual(p2['one'], 1) + def test_hash(self): p = TimestampedPreferences() self.assertEqual( diff --git a/torba/torba/client/baseaccount.py b/torba/torba/client/baseaccount.py index 21431c662..693a65d10 100644 --- a/torba/torba/client/baseaccount.py +++ b/torba/torba/client/baseaccount.py @@ -1,3 +1,4 @@ +import os import json import time import asyncio @@ -221,12 +222,8 @@ class BaseAccount: self.seed = seed self.modified_on = modified_on self.private_key_string = private_key_string - self.password: Optional[str] = None - self.private_key_encryption_init_vector: Optional[bytes] = None - self.seed_encryption_init_vector: Optional[bytes] = None - + self.init_vectors: Dict[str, bytes] = {} self.encrypted = encrypted - self.serialize_encrypted = encrypted self.private_key = private_key self.public_key = public_key generator_name = address_generator.get('name', HierarchicalDeterministic.name) @@ -236,6 +233,12 @@ class BaseAccount: ledger.add_account(self) wallet.add_account(self) + def get_init_vector(self, key) -> Optional[bytes]: + init_vector = self.init_vectors.get(key, None) + if init_vector is None: + init_vector = self.init_vectors[key] = os.urandom(16) + return init_vector + @classmethod def generate(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str = None, address_generator: dict = None): @@ -289,21 +292,20 @@ class BaseAccount: modified_on=d.get('modified_on', time.time()) ) - def to_dict(self): + def to_dict(self, encrypt_password: str = None): private_key_string, seed = self.private_key_string, self.seed if not self.encrypted and self.private_key: private_key_string = self.private_key.extended_key_string() - if not self.encrypted and self.serialize_encrypted: - assert None not in [self.seed_encryption_init_vector, self.private_key_encryption_init_vector] + if not self.encrypted and encrypt_password: private_key_string = aes_encrypt( - self.password, private_key_string, self.private_key_encryption_init_vector + encrypt_password, private_key_string, self.get_init_vector('private_key') ) - seed = aes_encrypt(self.password, self.seed, self.seed_encryption_init_vector) + seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed')) return { 'ledger': self.ledger.get_id(), 'name': self.name, 'seed': seed, - 'encrypted': self.serialize_encrypted, + 'encrypted': bool(self.encrypted or encrypt_password), 'private_key': private_key_string, 'public_key': self.public_key.extended_key_string(), 'address_generator': self.address_generator.to_dict(self.receiving, self.change), @@ -322,6 +324,7 @@ class BaseAccount: @property def hash(self) -> bytes: + assert not self.encrypted, "Cannot hash an encrypted account." return sha256(json.dumps(self.to_dict()).encode()) async def get_details(self, show_seed=False, **kwargs): @@ -339,42 +342,41 @@ class BaseAccount: details['seed'] = self.seed return details - def decrypt(self, password: str) -> None: + def decrypt(self, password: str) -> bool: assert self.encrypted, "Key is not encrypted." try: seed, seed_iv = aes_decrypt(password, self.seed) pk_string, pk_iv = aes_decrypt(password, self.private_key_string) except ValueError: # failed to remove padding, password is wrong - return + return False try: Mnemonic().mnemonic_decode(seed) except IndexError: # failed to decode the seed, this either means it decrypted and is invalid # or that we hit an edge case where an incorrect password gave valid padding - return + return False try: private_key = from_extended_key_string( self.ledger, pk_string ) except (TypeError, ValueError): - return + return False self.seed = seed - self.seed_encryption_init_vector = seed_iv self.private_key = private_key - self.private_key_encryption_init_vector = pk_iv - self.password = password + self.init_vectors['seed'] = seed_iv + self.init_vectors['private_key'] = pk_iv self.encrypted = False + return True - def encrypt(self, password: str) -> None: + def encrypt(self, password: str) -> bool: assert not self.encrypted, "Key is already encrypted." assert isinstance(self.private_key, PrivateKey) - - self.seed = aes_encrypt(password, self.seed, self.seed_encryption_init_vector) + self.seed = aes_encrypt(password, self.seed, self.get_init_vector('seed')) self.private_key_string = aes_encrypt( - password, self.private_key.extended_key_string(), self.private_key_encryption_init_vector + password, self.private_key.extended_key_string(), self.get_init_vector('private_key') ) self.private_key = None - self.password = None self.encrypted = True + return True async def ensure_address_gap(self): addresses = [] diff --git a/torba/torba/client/wallet.py b/torba/torba/client/wallet.py index d00026ab4..dc94d34d2 100644 --- a/torba/torba/client/wallet.py +++ b/torba/torba/client/wallet.py @@ -4,6 +4,7 @@ import stat import json import zlib import typing +import logging from typing import List, Sequence, MutableSequence, Optional from collections import UserDict from hashlib import sha256 @@ -14,8 +15,18 @@ if typing.TYPE_CHECKING: from torba.client import basemanager, baseaccount, baseledger +log = logging.getLogger(__name__) + +ENCRYPT_ON_DISK = 'encrypt-on-disk' + + class TimestampedPreferences(UserDict): + def __init__(self, d: dict = None): + super().__init__() + if d is not None: + self.data = d.copy() + def __getitem__(self, key): return self.data[key]['value'] @@ -52,6 +63,7 @@ class Wallet: """ preferences: TimestampedPreferences + encryption_password: Optional[str] def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None, storage: 'WalletStorage' = None, preferences: dict = None) -> None: @@ -59,6 +71,7 @@ class Wallet: self.accounts = accounts or [] self.storage = storage or WalletStorage() self.preferences = TimestampedPreferences(preferences or {}) + self.encryption_password = None @property def id(self): @@ -119,15 +132,24 @@ class Wallet: ledger.account_class.from_dict(ledger, wallet, account_dict) return wallet - def to_dict(self): + def to_dict(self, encrypt_password: str = None): return { 'version': WalletStorage.LATEST_VERSION, 'name': self.name, 'preferences': self.preferences.data, - 'accounts': [a.to_dict() for a in self.accounts] + 'accounts': [a.to_dict(encrypt_password) for a in self.accounts] } def save(self): + if self.preferences.get(ENCRYPT_ON_DISK, False): + if self.encryption_password: + self.storage.write(self.to_dict(encrypt_password=self.encryption_password)) + return + else: + log.warning( + "Disk encryption requested but no password available for encryption. " + "Saving wallet in an unencrypted state." + ) self.storage.write(self.to_dict()) @property @@ -139,6 +161,7 @@ class Wallet: return h.digest() def pack(self, password): + assert not self.is_locked, "Cannot pack a wallet with locked/encrypted accounts." new_data = json.dumps(self.to_dict()) new_data_compressed = zlib.compress(new_data.encode()) return better_aes_encrypt(password, new_data_compressed) @@ -151,9 +174,12 @@ class Wallet: def merge(self, manager: 'basemanager.BaseWalletManager', password: str, data: str) -> List['baseaccount.BaseAccount']: + assert not self.is_locked, "Cannot sync apply on a locked wallet." added_accounts = [] decrypted_data = self.unpack(password, data) self.preferences.merge(decrypted_data.get('preferences', {})) + if not self.encryption_password and self.preferences.get(ENCRYPT_ON_DISK, False): + self.encryption_password = password for account_dict in decrypted_data['accounts']: ledger = manager.get_or_create_ledger(account_dict['ledger']) _, _, pubkey = ledger.account_class.keys_from_dict(ledger, account_dict) @@ -178,38 +204,35 @@ class Wallet: return False def unlock(self, password): + self.encryption_password = password for account in self.accounts: if account.encrypted: - account.decrypt(password) + if not account.decrypt(password): + return False return True def lock(self): + assert self.encryption_password is not None, "Cannot lock an unencrypted wallet, encrypt first." for account in self.accounts: if not account.encrypted: - assert account.password is not None, "account was never encrypted" - account.encrypt(account.password) + account.encrypt(self.encryption_password) return True @property def is_encrypted(self) -> bool: - for account in self.accounts: - if account.serialize_encrypted: - return True - return False + return self.is_locked or self.preferences.get(ENCRYPT_ON_DISK, False) def decrypt(self): - for account in self.accounts: - account.serialize_encrypted = False + assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first." + self.preferences[ENCRYPT_ON_DISK] = False self.save() return True def encrypt(self, password): - for account in self.accounts: - if not account.encrypted: - account.encrypt(password) - account.serialize_encrypted = True + assert not self.is_locked, "Cannot re-encrypt a locked wallet, unlock first." + self.encryption_password = password + self.preferences[ENCRYPT_ON_DISK] = True self.save() - self.unlock(password) return True