fix wallet encryption

fixes https://github.com/lbryio/lbry/issues/1623
This commit is contained in:
Jack Robison 2018-11-19 13:51:25 -05:00 committed by Lex Berezhny
parent ee57af8e99
commit e745b6c16e
4 changed files with 34 additions and 17 deletions

View file

@ -369,7 +369,8 @@ class AccountEncryptionTests(AsyncioTestCase):
def test_encrypt_wallet(self): def test_encrypt_wallet(self):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account) account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account)
account.encryption_init_vector = self.init_vector account.private_key_encryption_init_vector = self.init_vector
account.seed_encryption_init_vector = self.init_vector
self.assertFalse(account.serialize_encrypted) self.assertFalse(account.serialize_encrypted)
self.assertFalse(account.encrypted) self.assertFalse(account.encrypted)
@ -386,6 +387,8 @@ class AccountEncryptionTests(AsyncioTestCase):
account.serialize_encrypted = True account.serialize_encrypted = True
account.decrypt(self.password) account.decrypt(self.password)
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
self.assertEqual(account.seed, self.unencrypted_account['seed']) self.assertEqual(account.seed, self.unencrypted_account['seed'])
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key']) self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
@ -393,20 +396,22 @@ class AccountEncryptionTests(AsyncioTestCase):
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
account.encryption_init_vector = None
self.assertNotEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
self.assertNotEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
self.assertFalse(account.encrypted) self.assertFalse(account.encrypted)
self.assertTrue(account.serialize_encrypted) self.assertTrue(account.serialize_encrypted)
account.serialize_encrypted = False
self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key'])
def test_decrypt_wallet(self): def test_decrypt_wallet(self):
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account) account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account)
account.encryption_init_vector = self.init_vector
self.assertTrue(account.encrypted) self.assertTrue(account.encrypted)
self.assertTrue(account.serialize_encrypted) self.assertTrue(account.serialize_encrypted)
account.decrypt(self.password) account.decrypt(self.password)
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
self.assertFalse(account.encrypted) self.assertFalse(account.encrypted)
self.assertTrue(account.serialize_encrypted) self.assertTrue(account.serialize_encrypted)

View file

@ -23,9 +23,13 @@ class TestAESEncryptDecrypt(TestCase):
'ZmZmZmZmZmZmZmZmZmZmZjlrKptoKD+MFwDxcg3XtCD9qz8UWhEhq/TVJT5+Mtp2a8sE' 'ZmZmZmZmZmZmZmZmZmZmZjlrKptoKD+MFwDxcg3XtCD9qz8UWhEhq/TVJT5+Mtp2a8sE'
'CaO6WQj7fYsWGu2Hvbc0qYqxdN0HeTsiO+cZRo3eJISgr3F+rXFYi5oSBlD2' 'CaO6WQj7fYsWGu2Hvbc0qYqxdN0HeTsiO+cZRo3eJISgr3F+rXFYi5oSBlD2'
) )
self.assertEqual(
aes_decrypt(self.password, self.expected),
(self.message, b'f' * 16)
)
def test_encrypt_decrypt(self): def test_encrypt_decrypt(self):
self.assertEqual( self.assertEqual(
aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message)), aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message))[0],
self.message self.message
) )

View file

@ -201,7 +201,9 @@ class BaseAccount:
self.seed = seed self.seed = seed
self.private_key_string = private_key_string self.private_key_string = private_key_string
self.password: Optional[str] = None self.password: Optional[str] = None
self.encryption_init_vector = None self.private_key_encryption_init_vector: Optional[bytes] = None
self.seed_encryption_init_vector: Optional[bytes] = None
self.encrypted = encrypted self.encrypted = encrypted
self.serialize_encrypted = encrypted self.serialize_encrypted = encrypted
self.private_key = private_key self.private_key = private_key
@ -264,13 +266,16 @@ class BaseAccount:
if not self.encrypted and self.private_key: if not self.encrypted and self.private_key:
private_key_string = self.private_key.extended_key_string() private_key_string = self.private_key.extended_key_string()
if not self.encrypted and self.serialize_encrypted: if not self.encrypted and self.serialize_encrypted:
private_key_string = aes_encrypt(self.password, private_key_string, self.encryption_init_vector) assert None not in [self.seed_encryption_init_vector, self.private_key_encryption_init_vector]
seed = aes_encrypt(self.password, self.seed, self.encryption_init_vector) private_key_string = aes_encrypt(
self.password, private_key_string, self.private_key_encryption_init_vector
)
seed = aes_encrypt(self.password, self.seed, self.seed_encryption_init_vector)
return { return {
'ledger': self.ledger.get_id(), 'ledger': self.ledger.get_id(),
'name': self.name, 'name': self.name,
'seed': seed, 'seed': seed,
'encrypted': self.encrypted, 'encrypted': self.serialize_encrypted,
'private_key': private_key_string, 'private_key': private_key_string,
'public_key': self.public_key.extended_key_string(), 'public_key': self.public_key.extended_key_string(),
'address_generator': self.address_generator.to_dict(self.receiving, self.change) 'address_generator': self.address_generator.to_dict(self.receiving, self.change)
@ -293,9 +298,10 @@ class BaseAccount:
def decrypt(self, password: str) -> None: def decrypt(self, password: str) -> None:
assert self.encrypted, "Key is not encrypted." assert self.encrypted, "Key is not encrypted."
self.seed = aes_decrypt(password, self.seed) self.seed, self.seed_encryption_init_vector = aes_decrypt(password, self.seed)
pk_string, self.private_key_encryption_init_vector = aes_decrypt(password, self.private_key_string)
self.private_key = from_extended_key_string( self.private_key = from_extended_key_string(
self.ledger, aes_decrypt(password, self.private_key_string) self.ledger, pk_string
) )
self.password = password self.password = password
self.encrypted = False self.encrypted = False
@ -303,9 +309,10 @@ class BaseAccount:
def encrypt(self, password: str) -> None: def encrypt(self, password: str) -> None:
assert not self.encrypted, "Key is already encrypted." assert not self.encrypted, "Key is already encrypted."
assert isinstance(self.private_key, PrivateKey) assert isinstance(self.private_key, PrivateKey)
self.seed = aes_encrypt(password, self.seed, self.encryption_init_vector)
self.seed = aes_encrypt(password, self.seed, self.seed_encryption_init_vector)
self.private_key_string = aes_encrypt( self.private_key_string = aes_encrypt(
password, self.private_key.extended_key_string(), self.encryption_init_vector password, self.private_key.extended_key_string(), self.private_key_encryption_init_vector
) )
self.private_key = None self.private_key = None
self.password = None self.password = None

View file

@ -12,6 +12,7 @@ import os
import base64 import base64
import hashlib import hashlib
import hmac import hmac
import typing
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from cryptography.hazmat.primitives.ciphers import Cipher, modes from cryptography.hazmat.primitives.ciphers import Cipher, modes
from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.ciphers.algorithms import AES
@ -133,14 +134,14 @@ def aes_encrypt(secret: str, value: str, init_vector: bytes = None) -> str:
return base64.b64encode(init_vector + encrypted_data).decode() return base64.b64encode(init_vector + encrypted_data).decode()
def aes_decrypt(secret: str, value: str) -> str: def aes_decrypt(secret: str, value: str) -> typing.Tuple[str, bytes]:
data = base64.b64decode(value.encode()) data = base64.b64decode(value.encode())
key = double_sha256(secret.encode()) key = double_sha256(secret.encode())
init_vector, data = data[:16], data[16:] init_vector, data = data[:16], data[16:]
decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor() decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor()
unpadder = PKCS7(AES.block_size).unpadder() unpadder = PKCS7(AES.block_size).unpadder()
result = unpadder.update(decryptor.update(data)) + unpadder.finalize() result = unpadder.update(decryptor.update(data)) + unpadder.finalize()
return result.decode() return result.decode(), init_vector
class Base58Error(Exception): class Base58Error(Exception):