From e745b6c16e7b33f28b69ebf869c80ffa6b605984 Mon Sep 17 00:00:00 2001
From: Jack Robison <jackrobison@lbry.io>
Date: Mon, 19 Nov 2018 13:51:25 -0500
Subject: [PATCH] fix wallet encryption

fixes https://github.com/lbryio/lbry/issues/1623
---
 tests/client_tests/unit/test_account.py | 17 +++++++++++------
 tests/client_tests/unit/test_hash.py    |  6 +++++-
 torba/client/baseaccount.py             | 23 +++++++++++++++--------
 torba/client/hash.py                    |  5 +++--
 4 files changed, 34 insertions(+), 17 deletions(-)

diff --git a/tests/client_tests/unit/test_account.py b/tests/client_tests/unit/test_account.py
index 29d9690b8..8c50cea1c 100644
--- a/tests/client_tests/unit/test_account.py
+++ b/tests/client_tests/unit/test_account.py
@@ -369,7 +369,8 @@ class AccountEncryptionTests(AsyncioTestCase):
 
     def test_encrypt_wallet(self):
         account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account)
-        account.encryption_init_vector = self.init_vector
+        account.private_key_encryption_init_vector = self.init_vector
+        account.seed_encryption_init_vector = self.init_vector
 
         self.assertFalse(account.serialize_encrypted)
         self.assertFalse(account.encrypted)
@@ -386,6 +387,8 @@ class AccountEncryptionTests(AsyncioTestCase):
 
         account.serialize_encrypted = True
         account.decrypt(self.password)
+        self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
+        self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
 
         self.assertEqual(account.seed, self.unencrypted_account['seed'])
         self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
@@ -393,20 +396,22 @@ class AccountEncryptionTests(AsyncioTestCase):
         self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
         self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
 
-        account.encryption_init_vector = None
-        self.assertNotEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
-        self.assertNotEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
-
         self.assertFalse(account.encrypted)
         self.assertTrue(account.serialize_encrypted)
 
+        account.serialize_encrypted = False
+        self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed'])
+        self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key'])
+
     def test_decrypt_wallet(self):
         account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account)
-        account.encryption_init_vector = self.init_vector
 
         self.assertTrue(account.encrypted)
         self.assertTrue(account.serialize_encrypted)
         account.decrypt(self.password)
+        self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
+        self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
+
         self.assertFalse(account.encrypted)
         self.assertTrue(account.serialize_encrypted)
 
diff --git a/tests/client_tests/unit/test_hash.py b/tests/client_tests/unit/test_hash.py
index 7c2fba63e..4b3a8de5c 100644
--- a/tests/client_tests/unit/test_hash.py
+++ b/tests/client_tests/unit/test_hash.py
@@ -23,9 +23,13 @@ class TestAESEncryptDecrypt(TestCase):
            'ZmZmZmZmZmZmZmZmZmZmZjlrKptoKD+MFwDxcg3XtCD9qz8UWhEhq/TVJT5+Mtp2a8sE'
            'CaO6WQj7fYsWGu2Hvbc0qYqxdN0HeTsiO+cZRo3eJISgr3F+rXFYi5oSBlD2'
         )
+        self.assertEqual(
+            aes_decrypt(self.password, self.expected),
+            (self.message, b'f' * 16)
+        )
 
     def test_encrypt_decrypt(self):
         self.assertEqual(
-            aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message)),
+            aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message))[0],
             self.message
         )
diff --git a/torba/client/baseaccount.py b/torba/client/baseaccount.py
index 066d3bc9d..9e9cf8d05 100644
--- a/torba/client/baseaccount.py
+++ b/torba/client/baseaccount.py
@@ -201,7 +201,9 @@ class BaseAccount:
         self.seed = seed
         self.private_key_string = private_key_string
         self.password: Optional[str] = None
-        self.encryption_init_vector = None
+        self.private_key_encryption_init_vector: Optional[bytes] = None
+        self.seed_encryption_init_vector: Optional[bytes] = None
+
         self.encrypted = encrypted
         self.serialize_encrypted = encrypted
         self.private_key = private_key
@@ -264,13 +266,16 @@ class BaseAccount:
         if not self.encrypted and self.private_key:
             private_key_string = self.private_key.extended_key_string()
         if not self.encrypted and self.serialize_encrypted:
-            private_key_string = aes_encrypt(self.password, private_key_string, self.encryption_init_vector)
-            seed = aes_encrypt(self.password, self.seed, self.encryption_init_vector)
+            assert None not in [self.seed_encryption_init_vector, self.private_key_encryption_init_vector]
+            private_key_string = aes_encrypt(
+                self.password, private_key_string, self.private_key_encryption_init_vector
+            )
+            seed = aes_encrypt(self.password, self.seed, self.seed_encryption_init_vector)
         return {
             'ledger': self.ledger.get_id(),
             'name': self.name,
             'seed': seed,
-            'encrypted': self.encrypted,
+            'encrypted': self.serialize_encrypted,
             'private_key': private_key_string,
             'public_key': self.public_key.extended_key_string(),
             'address_generator': self.address_generator.to_dict(self.receiving, self.change)
@@ -293,9 +298,10 @@ class BaseAccount:
 
     def decrypt(self, password: str) -> None:
         assert self.encrypted, "Key is not encrypted."
-        self.seed = aes_decrypt(password, self.seed)
+        self.seed, self.seed_encryption_init_vector = aes_decrypt(password, self.seed)
+        pk_string, self.private_key_encryption_init_vector = aes_decrypt(password, self.private_key_string)
         self.private_key = from_extended_key_string(
-            self.ledger, aes_decrypt(password, self.private_key_string)
+            self.ledger, pk_string
         )
         self.password = password
         self.encrypted = False
@@ -303,9 +309,10 @@ class BaseAccount:
     def encrypt(self, password: str) -> None:
         assert not self.encrypted, "Key is already encrypted."
         assert isinstance(self.private_key, PrivateKey)
-        self.seed = aes_encrypt(password, self.seed, self.encryption_init_vector)
+
+        self.seed = aes_encrypt(password, self.seed, self.seed_encryption_init_vector)
         self.private_key_string = aes_encrypt(
-            password, self.private_key.extended_key_string(), self.encryption_init_vector
+            password, self.private_key.extended_key_string(), self.private_key_encryption_init_vector
         )
         self.private_key = None
         self.password = None
diff --git a/torba/client/hash.py b/torba/client/hash.py
index bee4ba8c7..dcd61a080 100644
--- a/torba/client/hash.py
+++ b/torba/client/hash.py
@@ -12,6 +12,7 @@ import os
 import base64
 import hashlib
 import hmac
+import typing
 from binascii import hexlify, unhexlify
 from cryptography.hazmat.primitives.ciphers import Cipher, modes
 from cryptography.hazmat.primitives.ciphers.algorithms import AES
@@ -133,14 +134,14 @@ def aes_encrypt(secret: str, value: str, init_vector: bytes = None) -> str:
     return base64.b64encode(init_vector + encrypted_data).decode()
 
 
-def aes_decrypt(secret: str, value: str) -> str:
+def aes_decrypt(secret: str, value: str) -> typing.Tuple[str, bytes]:
     data = base64.b64decode(value.encode())
     key = double_sha256(secret.encode())
     init_vector, data = data[:16], data[16:]
     decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor()
     unpadder = PKCS7(AES.block_size).unpadder()
     result = unpadder.update(decryptor.update(data)) + unpadder.finalize()
-    return result.decode()
+    return result.decode(), init_vector
 
 
 class Base58Error(Exception):