bug fixes and encrypt-on-disk preference

This commit is contained in:
Lex Berezhny 2019-10-16 01:18:39 -04:00
parent 398fe55083
commit e5be1ff157
7 changed files with 217 additions and 99 deletions

View file

@ -1197,8 +1197,7 @@ class Daemon(metaclass=JSONRPCServerType):
Returns: Returns:
(bool) true if wallet is unlocked, otherwise false (bool) true if wallet is unlocked, otherwise false
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) return self.wallet_manager.get_wallet_or_default(wallet_id).unlock(password)
return wallet.unlock(password)
@requires(WALLET_COMPONENT) @requires(WALLET_COMPONENT)
def jsonrpc_wallet_lock(self, wallet_id=None): def jsonrpc_wallet_lock(self, wallet_id=None):
@ -1214,8 +1213,7 @@ class Daemon(metaclass=JSONRPCServerType):
Returns: Returns:
(bool) true if wallet is locked, otherwise false (bool) true if wallet is locked, otherwise false
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) return self.wallet_manager.get_wallet_or_default(wallet_id).lock()
return wallet.lock()
@requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED]) @requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED])
def jsonrpc_wallet_decrypt(self, wallet_id=None): def jsonrpc_wallet_decrypt(self, wallet_id=None):
@ -1231,8 +1229,7 @@ class Daemon(metaclass=JSONRPCServerType):
Returns: Returns:
(bool) true if wallet is decrypted, otherwise false (bool) true if wallet is decrypted, otherwise false
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) return self.wallet_manager.get_wallet_or_default(wallet_id).decrypt()
return wallet.decrypt()
@requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED]) @requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED])
def jsonrpc_wallet_encrypt(self, new_password, wallet_id=None): def jsonrpc_wallet_encrypt(self, new_password, wallet_id=None):
@ -1250,8 +1247,7 @@ class Daemon(metaclass=JSONRPCServerType):
Returns: Returns:
(bool) true if wallet is decrypted, otherwise false (bool) true if wallet is decrypted, otherwise false
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) return self.wallet_manager.get_wallet_or_default(wallet_id).encrypt(new_password)
return wallet.encrypt(new_password)
@requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED]) @requires(WALLET_COMPONENT, conditions=[WALLET_IS_UNLOCKED])
async def jsonrpc_wallet_send( async def jsonrpc_wallet_send(
@ -1604,7 +1600,7 @@ class Daemon(metaclass=JSONRPCServerType):
Wallet synchronization. Wallet synchronization.
""" """
@requires("wallet") @requires("wallet", conditions=[WALLET_IS_UNLOCKED])
def jsonrpc_sync_hash(self, wallet_id=None): def jsonrpc_sync_hash(self, wallet_id=None):
""" """
Deterministic hash of the wallet. Deterministic hash of the wallet.
@ -1621,21 +1617,18 @@ class Daemon(metaclass=JSONRPCServerType):
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
return hexlify(wallet.hash).decode() return hexlify(wallet.hash).decode()
@requires("wallet") @requires("wallet", conditions=[WALLET_IS_UNLOCKED])
async def jsonrpc_sync_apply(self, password, data=None, encrypt_password=None, wallet_id=None, blocking=False): async def jsonrpc_sync_apply(self, password, data=None, encrypt_password=None, wallet_id=None, blocking=False):
""" """
Apply incoming synchronization data, if provided, and then produce a sync hash and Apply incoming synchronization data, if provided, and then produce a sync hash and
an encrypted wallet. an encrypted wallet.
Usage: Usage:
sync_apply <password> [--data=<data>] [--encrypt-password=<encrypt_password>] sync_apply <password> [--data=<data>] [--wallet_id=<wallet_id>] [--blocking]
[--wallet_id=<wallet_id>] [--blocking]
Options: Options:
--password=<password> : (str) password to decrypt incoming and encrypt outgoing data --password=<password> : (str) password to decrypt incoming and encrypt outgoing data
--data=<data> : (str) incoming sync data, if any --data=<data> : (str) incoming sync data, if any
--encrypt-password=<encrypt_password> : (str) password to encrypt outgoing data if different
from the decrypt password, used during password changes
--wallet_id=<wallet_id> : (str) wallet being sync'ed --wallet_id=<wallet_id> : (str) wallet being sync'ed
--blocking : (bool) wait until any new accounts have sync'ed --blocking : (bool) wait until any new accounts have sync'ed
@ -1644,7 +1637,6 @@ class Daemon(metaclass=JSONRPCServerType):
""" """
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
assert not wallet.is_locked, "Cannot sync apply on a locked wallet."
if data is not None: if data is not None:
added_accounts = wallet.merge(self.wallet_manager, password, data) added_accounts = wallet.merge(self.wallet_manager, password, data)
if added_accounts and self.ledger.network.is_connected: if added_accounts and self.ledger.network.is_connected:
@ -1656,8 +1648,7 @@ class Daemon(metaclass=JSONRPCServerType):
for new_account in added_accounts: for new_account in added_accounts:
asyncio.create_task(self.ledger.subscribe_account(new_account)) asyncio.create_task(self.ledger.subscribe_account(new_account))
wallet.save() wallet.save()
wallet.unlock(password) encrypted = wallet.pack(password)
encrypted = wallet.pack(encrypt_password or password)
return { return {
'hash': self.jsonrpc_sync_hash(wallet_id), 'hash': self.jsonrpc_sync_hash(wallet_id),
'data': encrypted.decode() 'data': encrypted.decode()

View file

@ -30,7 +30,7 @@ class Account(BaseAccount):
@property @property
def hash(self) -> bytes: def hash(self) -> bytes:
h = sha256(json.dumps(self.to_dict(False)).encode()) h = sha256(json.dumps(self.to_dict(include_channel_keys=False)).encode())
for cert in sorted(self.channel_keys.keys()): for cert in sorted(self.channel_keys.keys()):
h.update(cert.encode()) h.update(cert.encode())
return h.digest() return h.digest()
@ -119,8 +119,8 @@ class Account(BaseAccount):
account.channel_keys = d.get('certificates', {}) account.channel_keys = d.get('certificates', {})
return account return account
def to_dict(self, include_channel_keys=True): def to_dict(self, encrypt_password: str = None, include_channel_keys: bool = True):
d = super().to_dict() d = super().to_dict(encrypt_password)
if include_channel_keys: if include_channel_keys:
d['certificates'] = self.channel_keys d['certificates'] = self.channel_keys
return d return d

View file

@ -1,20 +1,33 @@
import asyncio import asyncio
import json
from lbry import error
from lbry.testcase import CommandTestCase from lbry.testcase import CommandTestCase
from binascii import unhexlify from torba.client.wallet import ENCRYPT_ON_DISK
class WalletSynchronization(CommandTestCase): class WalletEncryptionAndSynchronization(CommandTestCase):
SEED = "carbon smart garage balance margin twelve chest sword toast envelope bottom stomach absent"
async def test_sync(self): SEED = (
daemon = self.daemon "carbon smart garage balance margin twelve chest "
daemon2 = await self.add_daemon( "sword toast envelope bottom stomach absent"
)
async def asyncSetUp(self):
await super().asyncSetUp()
self.daemon2 = await self.add_daemon(
seed="chest sword toast envelope bottom stomach absent " seed="chest sword toast envelope bottom stomach absent "
"carbon smart garage balance margin twelve" "carbon smart garage balance margin twelve"
) )
address = (await daemon2.wallet_manager.default_account.receiving.get_addresses(limit=1, only_usable=True))[0] address = (await self.daemon2.wallet_manager.default_account.receiving.get_addresses(limit=1, only_usable=True))[0]
sendtxid = await self.blockchain.send_to_address(address, 1) sendtxid = await self.blockchain.send_to_address(address, 1)
await self.confirm_tx(sendtxid, daemon2.ledger) await self.confirm_tx(sendtxid, self.daemon2.ledger)
def assertWalletEncrypted(self, wallet_path, encrypted):
wallet = json.load(open(wallet_path))
self.assertEqual(wallet['accounts'][0]['private_key'][1:4] != 'prv', encrypted)
async def test_sync(self):
daemon, daemon2 = self.daemon, self.daemon2
# Preferences # Preferences
self.assertFalse(daemon.jsonrpc_preference_get()) self.assertFalse(daemon.jsonrpc_preference_get())
@ -30,16 +43,12 @@ class WalletSynchronization(CommandTestCase):
self.assertDictEqual(daemon.jsonrpc_preference_get(), { self.assertDictEqual(daemon.jsonrpc_preference_get(), {
"one": "1", "conflict": "1", "fruit": ["peach", "apricot"] "one": "1", "conflict": "1", "fruit": ["peach", "apricot"]
}) })
self.assertDictEqual(daemon2.jsonrpc_preference_get(), {"two": "2", "conflict": "2"}) self.assertDictEqual(daemon2.jsonrpc_preference_get(), {
"two": "2", "conflict": "2"
})
self.assertEqual(len((await daemon.jsonrpc_account_list())['lbc_regtest']), 1) self.assertEqual(len((await daemon.jsonrpc_account_list())['lbc_regtest']), 1)
daemon2.jsonrpc_wallet_encrypt('password')
daemon2.jsonrpc_wallet_lock()
with self.assertRaises(AssertionError):
await daemon2.jsonrpc_sync_apply('password')
daemon2.jsonrpc_wallet_unlock('password')
data = await daemon2.jsonrpc_sync_apply('password') data = await daemon2.jsonrpc_sync_apply('password')
await daemon.jsonrpc_sync_apply('password', data=data['data'], blocking=True) await daemon.jsonrpc_sync_apply('password', data=data['data'], blocking=True)
@ -52,9 +61,7 @@ class WalletSynchronization(CommandTestCase):
# Channel Certificate # Channel Certificate
channel = await daemon2.jsonrpc_channel_create('@foo', '0.1') channel = await daemon2.jsonrpc_channel_create('@foo', '0.1')
await daemon2.ledger.wait(channel) await self.confirm_tx(channel.id, self.daemon2.ledger)
await self.generate(1)
await daemon2.ledger.wait(channel)
# both daemons will have the channel but only one has the cert so far # both daemons will have the channel but only one has the cert so far
self.assertEqual(len(await daemon.jsonrpc_channel_list()), 1) self.assertEqual(len(await daemon.jsonrpc_channel_list()), 1)
@ -70,3 +77,102 @@ class WalletSynchronization(CommandTestCase):
daemon2.wallet_manager.default_account.channel_keys, daemon2.wallet_manager.default_account.channel_keys,
daemon.wallet_manager.default_wallet.accounts[1].channel_keys daemon.wallet_manager.default_wallet.accounts[1].channel_keys
) )
async def test_encryption_and_locking(self):
daemon = self.daemon
wallet = daemon.wallet_manager.default_wallet
self.assertEqual(
daemon.jsonrpc_wallet_status(),
{'is_locked': False, 'is_encrypted': False}
)
self.assertIsNone(daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK))
self.assertWalletEncrypted(wallet.storage.path, False)
# can't lock an unencrypted account
with self.assertRaisesRegex(AssertionError, "Cannot lock an unencrypted wallet, encrypt first."):
daemon.jsonrpc_wallet_lock()
# safe to call unlock and decrypt, they are no-ops at this point
daemon.jsonrpc_wallet_unlock('password') # already unlocked
daemon.jsonrpc_wallet_decrypt() # already not encrypted
daemon.jsonrpc_wallet_encrypt('password')
self.assertEqual(
daemon.jsonrpc_wallet_status(),
{'is_locked': False, 'is_encrypted': True}
)
self.assertEqual(
daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK),
{'encrypt-on-disk': True}
)
self.assertWalletEncrypted(wallet.storage.path, True)
daemon.jsonrpc_wallet_lock()
self.assertEqual(
daemon.jsonrpc_wallet_status(),
{'is_locked': True, 'is_encrypted': True}
)
with self.assertRaises(error.ComponentStartConditionNotMet):
await daemon.jsonrpc_channel_create('@foo', '1.0')
daemon.jsonrpc_wallet_unlock('password')
await daemon.jsonrpc_channel_create('@foo', '1.0')
daemon.jsonrpc_wallet_decrypt()
self.assertEqual(
daemon.jsonrpc_wallet_status(),
{'is_locked': False, 'is_encrypted': False}
)
self.assertEqual(
daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK),
{'encrypt-on-disk': False}
)
self.assertWalletEncrypted(wallet.storage.path, False)
async def test_sync_with_encryption_and_password_change(self):
daemon, daemon2 = self.daemon, self.daemon2
wallet, wallet2 = daemon.wallet_manager.default_wallet, daemon2.wallet_manager.default_wallet
daemon.jsonrpc_wallet_encrypt('password')
self.assertEqual(daemon.jsonrpc_wallet_status(), {'is_locked': False, 'is_encrypted': True})
self.assertEqual(daemon2.jsonrpc_wallet_status(), {'is_locked': False, 'is_encrypted': False})
self.assertEqual(daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK), {'encrypt-on-disk': True})
self.assertIsNone(daemon2.jsonrpc_preference_get(ENCRYPT_ON_DISK))
self.assertWalletEncrypted(wallet.storage.path, True)
self.assertWalletEncrypted(wallet2.storage.path, False)
data = await daemon2.jsonrpc_sync_apply('password2')
with self.assertRaises(ValueError): # wrong password
await daemon.jsonrpc_sync_apply('password', data=data['data'], blocking=True)
await daemon.jsonrpc_sync_apply('password2', data=data['data'], blocking=True)
# encryption did not change from before sync_apply
self.assertEqual(daemon.jsonrpc_wallet_status(), {'is_locked': False, 'is_encrypted': True})
self.assertEqual(daemon.jsonrpc_preference_get(ENCRYPT_ON_DISK), {'encrypt-on-disk': True})
self.assertWalletEncrypted(wallet.storage.path, True)
# old password is still used
daemon.jsonrpc_wallet_lock()
self.assertFalse(daemon.jsonrpc_wallet_unlock('password2'))
self.assertTrue(daemon.jsonrpc_wallet_unlock('password'))
# encrypt using new password
daemon.jsonrpc_wallet_encrypt('password2')
daemon.jsonrpc_wallet_lock()
self.assertFalse(daemon.jsonrpc_wallet_unlock('password'))
self.assertTrue(daemon.jsonrpc_wallet_unlock('password2'))
data = await daemon.jsonrpc_sync_apply('password2')
await daemon2.jsonrpc_sync_apply('password2', data=data['data'], blocking=True)
# wallet2 is now encrypted using new password
self.assertEqual(daemon2.jsonrpc_wallet_status(), {'is_locked': False, 'is_encrypted': True})
self.assertEqual(daemon2.jsonrpc_preference_get(ENCRYPT_ON_DISK), {'encrypt-on-disk': True})
self.assertWalletEncrypted(wallet.storage.path, True)
daemon2.jsonrpc_wallet_lock()
self.assertTrue(daemon2.jsonrpc_wallet_unlock('password2'))

View file

@ -435,14 +435,14 @@ class AccountEncryptionTests(AsyncioTestCase):
def test_encrypt_wallet(self): def test_encrypt_wallet(self):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account) account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account)
account.private_key_encryption_init_vector = self.init_vector account.init_vectors = {
account.seed_encryption_init_vector = self.init_vector 'seed': self.init_vector,
'private_key': self.init_vector
}
self.assertFalse(account.serialize_encrypted)
self.assertFalse(account.encrypted) self.assertFalse(account.encrypted)
self.assertIsNotNone(account.private_key) self.assertIsNotNone(account.private_key)
account.encrypt(self.password) account.encrypt(self.password)
self.assertFalse(account.serialize_encrypted)
self.assertTrue(account.encrypted) self.assertTrue(account.encrypted)
self.assertEqual(account.seed, self.encrypted_account['seed']) self.assertEqual(account.seed, self.encrypted_account['seed'])
self.assertEqual(account.private_key_string, self.encrypted_account['private_key']) self.assertEqual(account.private_key_string, self.encrypted_account['private_key'])
@ -451,42 +451,32 @@ class AccountEncryptionTests(AsyncioTestCase):
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
account.serialize_encrypted = True
account.decrypt(self.password) account.decrypt(self.password)
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector) self.assertEqual(account.init_vectors['private_key'], self.init_vector)
self.assertEqual(account.seed_encryption_init_vector, self.init_vector) self.assertEqual(account.init_vectors['seed'], self.init_vector)
self.assertEqual(account.seed, self.unencrypted_account['seed']) self.assertEqual(account.seed, self.unencrypted_account['seed'])
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key']) self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) self.assertEqual(account.to_dict(encrypt_password=self.password)['private_key'], self.encrypted_account['private_key'])
self.assertFalse(account.encrypted) self.assertFalse(account.encrypted)
self.assertTrue(account.serialize_encrypted)
account.serialize_encrypted = False
self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key'])
def test_decrypt_wallet(self): def test_decrypt_wallet(self):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account) account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account)
self.assertTrue(account.encrypted) self.assertTrue(account.encrypted)
self.assertTrue(account.serialize_encrypted)
account.decrypt(self.password) account.decrypt(self.password)
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector) self.assertEqual(account.init_vectors['private_key'], self.init_vector)
self.assertEqual(account.seed_encryption_init_vector, self.init_vector) self.assertEqual(account.init_vectors['seed'], self.init_vector)
self.assertFalse(account.encrypted) self.assertFalse(account.encrypted)
self.assertTrue(account.serialize_encrypted)
self.assertEqual(account.seed, self.unencrypted_account['seed']) self.assertEqual(account.seed, self.unencrypted_account['seed'])
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key']) self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) self.assertEqual(account.to_dict(encrypt_password=self.password)['seed'], self.encrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) self.assertEqual(account.to_dict(encrypt_password=self.password)['private_key'], self.encrypted_account['private_key'])
account.serialize_encrypted = False
self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed']) self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key']) self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key'])

View file

@ -116,6 +116,12 @@ class TestWalletCreation(AsyncioTestCase):
class TestTimestampedPreferences(TestCase): class TestTimestampedPreferences(TestCase):
def test_init(self):
p = TimestampedPreferences()
p['one'] = 1
p2 = TimestampedPreferences(p.data)
self.assertEqual(p2['one'], 1)
def test_hash(self): def test_hash(self):
p = TimestampedPreferences() p = TimestampedPreferences()
self.assertEqual( self.assertEqual(

View file

@ -1,3 +1,4 @@
import os
import json import json
import time import time
import asyncio import asyncio
@ -221,12 +222,8 @@ class BaseAccount:
self.seed = seed self.seed = seed
self.modified_on = modified_on self.modified_on = modified_on
self.private_key_string = private_key_string self.private_key_string = private_key_string
self.password: Optional[str] = None self.init_vectors: Dict[str, bytes] = {}
self.private_key_encryption_init_vector: Optional[bytes] = None
self.seed_encryption_init_vector: Optional[bytes] = None
self.encrypted = encrypted self.encrypted = encrypted
self.serialize_encrypted = encrypted
self.private_key = private_key self.private_key = private_key
self.public_key = public_key self.public_key = public_key
generator_name = address_generator.get('name', HierarchicalDeterministic.name) generator_name = address_generator.get('name', HierarchicalDeterministic.name)
@ -236,6 +233,12 @@ class BaseAccount:
ledger.add_account(self) ledger.add_account(self)
wallet.add_account(self) wallet.add_account(self)
def get_init_vector(self, key) -> Optional[bytes]:
init_vector = self.init_vectors.get(key, None)
if init_vector is None:
init_vector = self.init_vectors[key] = os.urandom(16)
return init_vector
@classmethod @classmethod
def generate(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', def generate(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet',
name: str = None, address_generator: dict = None): name: str = None, address_generator: dict = None):
@ -289,21 +292,20 @@ class BaseAccount:
modified_on=d.get('modified_on', time.time()) modified_on=d.get('modified_on', time.time())
) )
def to_dict(self): def to_dict(self, encrypt_password: str = None):
private_key_string, seed = self.private_key_string, self.seed private_key_string, seed = self.private_key_string, self.seed
if not self.encrypted and self.private_key: if not self.encrypted and self.private_key:
private_key_string = self.private_key.extended_key_string() private_key_string = self.private_key.extended_key_string()
if not self.encrypted and self.serialize_encrypted: if not self.encrypted and encrypt_password:
assert None not in [self.seed_encryption_init_vector, self.private_key_encryption_init_vector]
private_key_string = aes_encrypt( private_key_string = aes_encrypt(
self.password, private_key_string, self.private_key_encryption_init_vector encrypt_password, private_key_string, self.get_init_vector('private_key')
) )
seed = aes_encrypt(self.password, self.seed, self.seed_encryption_init_vector) seed = aes_encrypt(encrypt_password, self.seed, self.get_init_vector('seed'))
return { return {
'ledger': self.ledger.get_id(), 'ledger': self.ledger.get_id(),
'name': self.name, 'name': self.name,
'seed': seed, 'seed': seed,
'encrypted': self.serialize_encrypted, 'encrypted': bool(self.encrypted or encrypt_password),
'private_key': private_key_string, 'private_key': private_key_string,
'public_key': self.public_key.extended_key_string(), 'public_key': self.public_key.extended_key_string(),
'address_generator': self.address_generator.to_dict(self.receiving, self.change), 'address_generator': self.address_generator.to_dict(self.receiving, self.change),
@ -322,6 +324,7 @@ class BaseAccount:
@property @property
def hash(self) -> bytes: def hash(self) -> bytes:
assert not self.encrypted, "Cannot hash an encrypted account."
return sha256(json.dumps(self.to_dict()).encode()) return sha256(json.dumps(self.to_dict()).encode())
async def get_details(self, show_seed=False, **kwargs): async def get_details(self, show_seed=False, **kwargs):
@ -339,42 +342,41 @@ class BaseAccount:
details['seed'] = self.seed details['seed'] = self.seed
return details return details
def decrypt(self, password: str) -> None: def decrypt(self, password: str) -> bool:
assert self.encrypted, "Key is not encrypted." assert self.encrypted, "Key is not encrypted."
try: try:
seed, seed_iv = aes_decrypt(password, self.seed) seed, seed_iv = aes_decrypt(password, self.seed)
pk_string, pk_iv = aes_decrypt(password, self.private_key_string) pk_string, pk_iv = aes_decrypt(password, self.private_key_string)
except ValueError: # failed to remove padding, password is wrong except ValueError: # failed to remove padding, password is wrong
return return False
try: try:
Mnemonic().mnemonic_decode(seed) Mnemonic().mnemonic_decode(seed)
except IndexError: # failed to decode the seed, this either means it decrypted and is invalid except IndexError: # failed to decode the seed, this either means it decrypted and is invalid
# or that we hit an edge case where an incorrect password gave valid padding # or that we hit an edge case where an incorrect password gave valid padding
return return False
try: try:
private_key = from_extended_key_string( private_key = from_extended_key_string(
self.ledger, pk_string self.ledger, pk_string
) )
except (TypeError, ValueError): except (TypeError, ValueError):
return return False
self.seed = seed self.seed = seed
self.seed_encryption_init_vector = seed_iv
self.private_key = private_key self.private_key = private_key
self.private_key_encryption_init_vector = pk_iv self.init_vectors['seed'] = seed_iv
self.password = password self.init_vectors['private_key'] = pk_iv
self.encrypted = False self.encrypted = False
return True
def encrypt(self, password: str) -> None: def encrypt(self, password: str) -> bool:
assert not self.encrypted, "Key is already encrypted." assert not self.encrypted, "Key is already encrypted."
assert isinstance(self.private_key, PrivateKey) assert isinstance(self.private_key, PrivateKey)
self.seed = aes_encrypt(password, self.seed, self.get_init_vector('seed'))
self.seed = aes_encrypt(password, self.seed, self.seed_encryption_init_vector)
self.private_key_string = aes_encrypt( self.private_key_string = aes_encrypt(
password, self.private_key.extended_key_string(), self.private_key_encryption_init_vector password, self.private_key.extended_key_string(), self.get_init_vector('private_key')
) )
self.private_key = None self.private_key = None
self.password = None
self.encrypted = True self.encrypted = True
return True
async def ensure_address_gap(self): async def ensure_address_gap(self):
addresses = [] addresses = []

View file

@ -4,6 +4,7 @@ import stat
import json import json
import zlib import zlib
import typing import typing
import logging
from typing import List, Sequence, MutableSequence, Optional from typing import List, Sequence, MutableSequence, Optional
from collections import UserDict from collections import UserDict
from hashlib import sha256 from hashlib import sha256
@ -14,8 +15,18 @@ if typing.TYPE_CHECKING:
from torba.client import basemanager, baseaccount, baseledger from torba.client import basemanager, baseaccount, baseledger
log = logging.getLogger(__name__)
ENCRYPT_ON_DISK = 'encrypt-on-disk'
class TimestampedPreferences(UserDict): class TimestampedPreferences(UserDict):
def __init__(self, d: dict = None):
super().__init__()
if d is not None:
self.data = d.copy()
def __getitem__(self, key): def __getitem__(self, key):
return self.data[key]['value'] return self.data[key]['value']
@ -52,6 +63,7 @@ class Wallet:
""" """
preferences: TimestampedPreferences preferences: TimestampedPreferences
encryption_password: Optional[str]
def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None, def __init__(self, name: str = 'Wallet', accounts: MutableSequence['baseaccount.BaseAccount'] = None,
storage: 'WalletStorage' = None, preferences: dict = None) -> None: storage: 'WalletStorage' = None, preferences: dict = None) -> None:
@ -59,6 +71,7 @@ class Wallet:
self.accounts = accounts or [] self.accounts = accounts or []
self.storage = storage or WalletStorage() self.storage = storage or WalletStorage()
self.preferences = TimestampedPreferences(preferences or {}) self.preferences = TimestampedPreferences(preferences or {})
self.encryption_password = None
@property @property
def id(self): def id(self):
@ -119,15 +132,24 @@ class Wallet:
ledger.account_class.from_dict(ledger, wallet, account_dict) ledger.account_class.from_dict(ledger, wallet, account_dict)
return wallet return wallet
def to_dict(self): def to_dict(self, encrypt_password: str = None):
return { return {
'version': WalletStorage.LATEST_VERSION, 'version': WalletStorage.LATEST_VERSION,
'name': self.name, 'name': self.name,
'preferences': self.preferences.data, 'preferences': self.preferences.data,
'accounts': [a.to_dict() for a in self.accounts] 'accounts': [a.to_dict(encrypt_password) for a in self.accounts]
} }
def save(self): def save(self):
if self.preferences.get(ENCRYPT_ON_DISK, False):
if self.encryption_password:
self.storage.write(self.to_dict(encrypt_password=self.encryption_password))
return
else:
log.warning(
"Disk encryption requested but no password available for encryption. "
"Saving wallet in an unencrypted state."
)
self.storage.write(self.to_dict()) self.storage.write(self.to_dict())
@property @property
@ -139,6 +161,7 @@ class Wallet:
return h.digest() return h.digest()
def pack(self, password): def pack(self, password):
assert not self.is_locked, "Cannot pack a wallet with locked/encrypted accounts."
new_data = json.dumps(self.to_dict()) new_data = json.dumps(self.to_dict())
new_data_compressed = zlib.compress(new_data.encode()) new_data_compressed = zlib.compress(new_data.encode())
return better_aes_encrypt(password, new_data_compressed) return better_aes_encrypt(password, new_data_compressed)
@ -151,9 +174,12 @@ class Wallet:
def merge(self, manager: 'basemanager.BaseWalletManager', def merge(self, manager: 'basemanager.BaseWalletManager',
password: str, data: str) -> List['baseaccount.BaseAccount']: password: str, data: str) -> List['baseaccount.BaseAccount']:
assert not self.is_locked, "Cannot sync apply on a locked wallet."
added_accounts = [] added_accounts = []
decrypted_data = self.unpack(password, data) decrypted_data = self.unpack(password, data)
self.preferences.merge(decrypted_data.get('preferences', {})) self.preferences.merge(decrypted_data.get('preferences', {}))
if not self.encryption_password and self.preferences.get(ENCRYPT_ON_DISK, False):
self.encryption_password = password
for account_dict in decrypted_data['accounts']: for account_dict in decrypted_data['accounts']:
ledger = manager.get_or_create_ledger(account_dict['ledger']) ledger = manager.get_or_create_ledger(account_dict['ledger'])
_, _, pubkey = ledger.account_class.keys_from_dict(ledger, account_dict) _, _, pubkey = ledger.account_class.keys_from_dict(ledger, account_dict)
@ -178,38 +204,35 @@ class Wallet:
return False return False
def unlock(self, password): def unlock(self, password):
self.encryption_password = password
for account in self.accounts: for account in self.accounts:
if account.encrypted: if account.encrypted:
account.decrypt(password) if not account.decrypt(password):
return False
return True return True
def lock(self): def lock(self):
assert self.encryption_password is not None, "Cannot lock an unencrypted wallet, encrypt first."
for account in self.accounts: for account in self.accounts:
if not account.encrypted: if not account.encrypted:
assert account.password is not None, "account was never encrypted" account.encrypt(self.encryption_password)
account.encrypt(account.password)
return True return True
@property @property
def is_encrypted(self) -> bool: def is_encrypted(self) -> bool:
for account in self.accounts: return self.is_locked or self.preferences.get(ENCRYPT_ON_DISK, False)
if account.serialize_encrypted:
return True
return False
def decrypt(self): def decrypt(self):
for account in self.accounts: assert not self.is_locked, "Cannot decrypt a locked wallet, unlock first."
account.serialize_encrypted = False self.preferences[ENCRYPT_ON_DISK] = False
self.save() self.save()
return True return True
def encrypt(self, password): def encrypt(self, password):
for account in self.accounts: assert not self.is_locked, "Cannot re-encrypt a locked wallet, unlock first."
if not account.encrypted: self.encryption_password = password
account.encrypt(password) self.preferences[ENCRYPT_ON_DISK] = True
account.serialize_encrypted = True
self.save() self.save()
self.unlock(password)
return True return True