diff --git a/tests/client_tests/unit/test_account.py b/tests/client_tests/unit/test_account.py index 29d9690b8..8c50cea1c 100644 --- a/tests/client_tests/unit/test_account.py +++ b/tests/client_tests/unit/test_account.py @@ -369,7 +369,8 @@ class AccountEncryptionTests(AsyncioTestCase): def test_encrypt_wallet(self): account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account) - account.encryption_init_vector = self.init_vector + account.private_key_encryption_init_vector = self.init_vector + account.seed_encryption_init_vector = self.init_vector self.assertFalse(account.serialize_encrypted) self.assertFalse(account.encrypted) @@ -386,6 +387,8 @@ class AccountEncryptionTests(AsyncioTestCase): account.serialize_encrypted = True account.decrypt(self.password) + self.assertEqual(account.private_key_encryption_init_vector, self.init_vector) + self.assertEqual(account.seed_encryption_init_vector, self.init_vector) self.assertEqual(account.seed, self.unencrypted_account['seed']) self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key']) @@ -393,20 +396,22 @@ class AccountEncryptionTests(AsyncioTestCase): self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) - account.encryption_init_vector = None - self.assertNotEqual(account.to_dict()['seed'], self.encrypted_account['seed']) - self.assertNotEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) - self.assertFalse(account.encrypted) self.assertTrue(account.serialize_encrypted) + account.serialize_encrypted = False + self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed']) + self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key']) + def test_decrypt_wallet(self): account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account) - account.encryption_init_vector = self.init_vector self.assertTrue(account.encrypted) self.assertTrue(account.serialize_encrypted) account.decrypt(self.password) + self.assertEqual(account.private_key_encryption_init_vector, self.init_vector) + self.assertEqual(account.seed_encryption_init_vector, self.init_vector) + self.assertFalse(account.encrypted) self.assertTrue(account.serialize_encrypted) diff --git a/tests/client_tests/unit/test_hash.py b/tests/client_tests/unit/test_hash.py index 7c2fba63e..4b3a8de5c 100644 --- a/tests/client_tests/unit/test_hash.py +++ b/tests/client_tests/unit/test_hash.py @@ -23,9 +23,13 @@ class TestAESEncryptDecrypt(TestCase): 'ZmZmZmZmZmZmZmZmZmZmZjlrKptoKD+MFwDxcg3XtCD9qz8UWhEhq/TVJT5+Mtp2a8sE' 'CaO6WQj7fYsWGu2Hvbc0qYqxdN0HeTsiO+cZRo3eJISgr3F+rXFYi5oSBlD2' ) + self.assertEqual( + aes_decrypt(self.password, self.expected), + (self.message, b'f' * 16) + ) def test_encrypt_decrypt(self): self.assertEqual( - aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message)), + aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message))[0], self.message ) diff --git a/torba/client/baseaccount.py b/torba/client/baseaccount.py index 066d3bc9d..9e9cf8d05 100644 --- a/torba/client/baseaccount.py +++ b/torba/client/baseaccount.py @@ -201,7 +201,9 @@ class BaseAccount: self.seed = seed self.private_key_string = private_key_string self.password: Optional[str] = None - self.encryption_init_vector = None + self.private_key_encryption_init_vector: Optional[bytes] = None + self.seed_encryption_init_vector: Optional[bytes] = None + self.encrypted = encrypted self.serialize_encrypted = encrypted self.private_key = private_key @@ -264,13 +266,16 @@ class BaseAccount: if not self.encrypted and self.private_key: private_key_string = self.private_key.extended_key_string() if not self.encrypted and self.serialize_encrypted: - private_key_string = aes_encrypt(self.password, private_key_string, self.encryption_init_vector) - seed = aes_encrypt(self.password, self.seed, self.encryption_init_vector) + assert None not in [self.seed_encryption_init_vector, self.private_key_encryption_init_vector] + private_key_string = aes_encrypt( + self.password, private_key_string, self.private_key_encryption_init_vector + ) + seed = aes_encrypt(self.password, self.seed, self.seed_encryption_init_vector) return { 'ledger': self.ledger.get_id(), 'name': self.name, 'seed': seed, - 'encrypted': self.encrypted, + 'encrypted': self.serialize_encrypted, 'private_key': private_key_string, 'public_key': self.public_key.extended_key_string(), 'address_generator': self.address_generator.to_dict(self.receiving, self.change) @@ -293,9 +298,10 @@ class BaseAccount: def decrypt(self, password: str) -> None: assert self.encrypted, "Key is not encrypted." - self.seed = aes_decrypt(password, self.seed) + self.seed, self.seed_encryption_init_vector = aes_decrypt(password, self.seed) + pk_string, self.private_key_encryption_init_vector = aes_decrypt(password, self.private_key_string) self.private_key = from_extended_key_string( - self.ledger, aes_decrypt(password, self.private_key_string) + self.ledger, pk_string ) self.password = password self.encrypted = False @@ -303,9 +309,10 @@ class BaseAccount: def encrypt(self, password: str) -> None: assert not self.encrypted, "Key is already encrypted." assert isinstance(self.private_key, PrivateKey) - self.seed = aes_encrypt(password, self.seed, self.encryption_init_vector) + + self.seed = aes_encrypt(password, self.seed, self.seed_encryption_init_vector) self.private_key_string = aes_encrypt( - password, self.private_key.extended_key_string(), self.encryption_init_vector + password, self.private_key.extended_key_string(), self.private_key_encryption_init_vector ) self.private_key = None self.password = None diff --git a/torba/client/hash.py b/torba/client/hash.py index bee4ba8c7..dcd61a080 100644 --- a/torba/client/hash.py +++ b/torba/client/hash.py @@ -12,6 +12,7 @@ import os import base64 import hashlib import hmac +import typing from binascii import hexlify, unhexlify from cryptography.hazmat.primitives.ciphers import Cipher, modes from cryptography.hazmat.primitives.ciphers.algorithms import AES @@ -133,14 +134,14 @@ def aes_encrypt(secret: str, value: str, init_vector: bytes = None) -> str: return base64.b64encode(init_vector + encrypted_data).decode() -def aes_decrypt(secret: str, value: str) -> str: +def aes_decrypt(secret: str, value: str) -> typing.Tuple[str, bytes]: data = base64.b64decode(value.encode()) key = double_sha256(secret.encode()) init_vector, data = data[:16], data[16:] decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor() unpadder = PKCS7(AES.block_size).unpadder() result = unpadder.update(decryptor.update(data)) + unpadder.finalize() - return result.decode() + return result.decode(), init_vector class Base58Error(Exception):