fix wallet encryption

fixes https://github.com/lbryio/lbry/issues/1623
This commit is contained in:
Jack Robison 2018-11-19 13:51:25 -05:00 committed by Lex Berezhny
parent ee57af8e99
commit e745b6c16e
4 changed files with 34 additions and 17 deletions

View file

@ -369,7 +369,8 @@ class AccountEncryptionTests(AsyncioTestCase):
def test_encrypt_wallet(self):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account)
account.encryption_init_vector = self.init_vector
account.private_key_encryption_init_vector = self.init_vector
account.seed_encryption_init_vector = self.init_vector
self.assertFalse(account.serialize_encrypted)
self.assertFalse(account.encrypted)
@ -386,6 +387,8 @@ class AccountEncryptionTests(AsyncioTestCase):
account.serialize_encrypted = True
account.decrypt(self.password)
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
self.assertEqual(account.seed, self.unencrypted_account['seed'])
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
@ -393,20 +396,22 @@ class AccountEncryptionTests(AsyncioTestCase):
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
account.encryption_init_vector = None
self.assertNotEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
self.assertNotEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
self.assertFalse(account.encrypted)
self.assertTrue(account.serialize_encrypted)
account.serialize_encrypted = False
self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key'])
def test_decrypt_wallet(self):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account)
account.encryption_init_vector = self.init_vector
self.assertTrue(account.encrypted)
self.assertTrue(account.serialize_encrypted)
account.decrypt(self.password)
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
self.assertFalse(account.encrypted)
self.assertTrue(account.serialize_encrypted)

View file

@ -23,9 +23,13 @@ class TestAESEncryptDecrypt(TestCase):
'ZmZmZmZmZmZmZmZmZmZmZjlrKptoKD+MFwDxcg3XtCD9qz8UWhEhq/TVJT5+Mtp2a8sE'
'CaO6WQj7fYsWGu2Hvbc0qYqxdN0HeTsiO+cZRo3eJISgr3F+rXFYi5oSBlD2'
)
self.assertEqual(
aes_decrypt(self.password, self.expected),
(self.message, b'f' * 16)
)
def test_encrypt_decrypt(self):
self.assertEqual(
aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message)),
aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message))[0],
self.message
)

View file

@ -201,7 +201,9 @@ class BaseAccount:
self.seed = seed
self.private_key_string = private_key_string
self.password: Optional[str] = None
self.encryption_init_vector = None
self.private_key_encryption_init_vector: Optional[bytes] = None
self.seed_encryption_init_vector: Optional[bytes] = None
self.encrypted = encrypted
self.serialize_encrypted = encrypted
self.private_key = private_key
@ -264,13 +266,16 @@ class BaseAccount:
if not self.encrypted and self.private_key:
private_key_string = self.private_key.extended_key_string()
if not self.encrypted and self.serialize_encrypted:
private_key_string = aes_encrypt(self.password, private_key_string, self.encryption_init_vector)
seed = aes_encrypt(self.password, self.seed, self.encryption_init_vector)
assert None not in [self.seed_encryption_init_vector, self.private_key_encryption_init_vector]
private_key_string = aes_encrypt(
self.password, private_key_string, self.private_key_encryption_init_vector
)
seed = aes_encrypt(self.password, self.seed, self.seed_encryption_init_vector)
return {
'ledger': self.ledger.get_id(),
'name': self.name,
'seed': seed,
'encrypted': self.encrypted,
'encrypted': self.serialize_encrypted,
'private_key': private_key_string,
'public_key': self.public_key.extended_key_string(),
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
@ -293,9 +298,10 @@ class BaseAccount:
def decrypt(self, password: str) -> None:
assert self.encrypted, "Key is not encrypted."
self.seed = aes_decrypt(password, self.seed)
self.seed, self.seed_encryption_init_vector = aes_decrypt(password, self.seed)
pk_string, self.private_key_encryption_init_vector = aes_decrypt(password, self.private_key_string)
self.private_key = from_extended_key_string(
self.ledger, aes_decrypt(password, self.private_key_string)
self.ledger, pk_string
)
self.password = password
self.encrypted = False
@ -303,9 +309,10 @@ class BaseAccount:
def encrypt(self, password: str) -> None:
assert not self.encrypted, "Key is already encrypted."
assert isinstance(self.private_key, PrivateKey)
self.seed = aes_encrypt(password, self.seed, self.encryption_init_vector)
self.seed = aes_encrypt(password, self.seed, self.seed_encryption_init_vector)
self.private_key_string = aes_encrypt(
password, self.private_key.extended_key_string(), self.encryption_init_vector
password, self.private_key.extended_key_string(), self.private_key_encryption_init_vector
)
self.private_key = None
self.password = None

View file

@ -12,6 +12,7 @@ import os
import base64
import hashlib
import hmac
import typing
from binascii import hexlify, unhexlify
from cryptography.hazmat.primitives.ciphers import Cipher, modes
from cryptography.hazmat.primitives.ciphers.algorithms import AES
@ -133,14 +134,14 @@ def aes_encrypt(secret: str, value: str, init_vector: bytes = None) -> str:
return base64.b64encode(init_vector + encrypted_data).decode()
def aes_decrypt(secret: str, value: str) -> str:
def aes_decrypt(secret: str, value: str) -> typing.Tuple[str, bytes]:
data = base64.b64decode(value.encode())
key = double_sha256(secret.encode())
init_vector, data = data[:16], data[16:]
decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor()
unpadder = PKCS7(AES.block_size).unpadder()
result = unpadder.update(decryptor.update(data)) + unpadder.finalize()
return result.decode()
return result.decode(), init_vector
class Base58Error(Exception):