parent
ee57af8e99
commit
e745b6c16e
4 changed files with 34 additions and 17 deletions
|
@ -369,7 +369,8 @@ class AccountEncryptionTests(AsyncioTestCase):
|
|||
|
||||
def test_encrypt_wallet(self):
|
||||
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.unencrypted_account)
|
||||
account.encryption_init_vector = self.init_vector
|
||||
account.private_key_encryption_init_vector = self.init_vector
|
||||
account.seed_encryption_init_vector = self.init_vector
|
||||
|
||||
self.assertFalse(account.serialize_encrypted)
|
||||
self.assertFalse(account.encrypted)
|
||||
|
@ -386,6 +387,8 @@ class AccountEncryptionTests(AsyncioTestCase):
|
|||
|
||||
account.serialize_encrypted = True
|
||||
account.decrypt(self.password)
|
||||
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
|
||||
self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
|
||||
|
||||
self.assertEqual(account.seed, self.unencrypted_account['seed'])
|
||||
self.assertEqual(account.private_key.extended_key_string(), self.unencrypted_account['private_key'])
|
||||
|
@ -393,20 +396,22 @@ class AccountEncryptionTests(AsyncioTestCase):
|
|||
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
|
||||
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
|
||||
|
||||
account.encryption_init_vector = None
|
||||
self.assertNotEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
|
||||
self.assertNotEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])
|
||||
|
||||
self.assertFalse(account.encrypted)
|
||||
self.assertTrue(account.serialize_encrypted)
|
||||
|
||||
account.serialize_encrypted = False
|
||||
self.assertEqual(account.to_dict()['seed'], self.unencrypted_account['seed'])
|
||||
self.assertEqual(account.to_dict()['private_key'], self.unencrypted_account['private_key'])
|
||||
|
||||
def test_decrypt_wallet(self):
|
||||
account = self.ledger.account_class.from_dict(self.ledger, Wallet(), self.encrypted_account)
|
||||
account.encryption_init_vector = self.init_vector
|
||||
|
||||
self.assertTrue(account.encrypted)
|
||||
self.assertTrue(account.serialize_encrypted)
|
||||
account.decrypt(self.password)
|
||||
self.assertEqual(account.private_key_encryption_init_vector, self.init_vector)
|
||||
self.assertEqual(account.seed_encryption_init_vector, self.init_vector)
|
||||
|
||||
self.assertFalse(account.encrypted)
|
||||
self.assertTrue(account.serialize_encrypted)
|
||||
|
||||
|
|
|
@ -23,9 +23,13 @@ class TestAESEncryptDecrypt(TestCase):
|
|||
'ZmZmZmZmZmZmZmZmZmZmZjlrKptoKD+MFwDxcg3XtCD9qz8UWhEhq/TVJT5+Mtp2a8sE'
|
||||
'CaO6WQj7fYsWGu2Hvbc0qYqxdN0HeTsiO+cZRo3eJISgr3F+rXFYi5oSBlD2'
|
||||
)
|
||||
self.assertEqual(
|
||||
aes_decrypt(self.password, self.expected),
|
||||
(self.message, b'f' * 16)
|
||||
)
|
||||
|
||||
def test_encrypt_decrypt(self):
|
||||
self.assertEqual(
|
||||
aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message)),
|
||||
aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message))[0],
|
||||
self.message
|
||||
)
|
||||
|
|
|
@ -201,7 +201,9 @@ class BaseAccount:
|
|||
self.seed = seed
|
||||
self.private_key_string = private_key_string
|
||||
self.password: Optional[str] = None
|
||||
self.encryption_init_vector = None
|
||||
self.private_key_encryption_init_vector: Optional[bytes] = None
|
||||
self.seed_encryption_init_vector: Optional[bytes] = None
|
||||
|
||||
self.encrypted = encrypted
|
||||
self.serialize_encrypted = encrypted
|
||||
self.private_key = private_key
|
||||
|
@ -264,13 +266,16 @@ class BaseAccount:
|
|||
if not self.encrypted and self.private_key:
|
||||
private_key_string = self.private_key.extended_key_string()
|
||||
if not self.encrypted and self.serialize_encrypted:
|
||||
private_key_string = aes_encrypt(self.password, private_key_string, self.encryption_init_vector)
|
||||
seed = aes_encrypt(self.password, self.seed, self.encryption_init_vector)
|
||||
assert None not in [self.seed_encryption_init_vector, self.private_key_encryption_init_vector]
|
||||
private_key_string = aes_encrypt(
|
||||
self.password, private_key_string, self.private_key_encryption_init_vector
|
||||
)
|
||||
seed = aes_encrypt(self.password, self.seed, self.seed_encryption_init_vector)
|
||||
return {
|
||||
'ledger': self.ledger.get_id(),
|
||||
'name': self.name,
|
||||
'seed': seed,
|
||||
'encrypted': self.encrypted,
|
||||
'encrypted': self.serialize_encrypted,
|
||||
'private_key': private_key_string,
|
||||
'public_key': self.public_key.extended_key_string(),
|
||||
'address_generator': self.address_generator.to_dict(self.receiving, self.change)
|
||||
|
@ -293,9 +298,10 @@ class BaseAccount:
|
|||
|
||||
def decrypt(self, password: str) -> None:
|
||||
assert self.encrypted, "Key is not encrypted."
|
||||
self.seed = aes_decrypt(password, self.seed)
|
||||
self.seed, self.seed_encryption_init_vector = aes_decrypt(password, self.seed)
|
||||
pk_string, self.private_key_encryption_init_vector = aes_decrypt(password, self.private_key_string)
|
||||
self.private_key = from_extended_key_string(
|
||||
self.ledger, aes_decrypt(password, self.private_key_string)
|
||||
self.ledger, pk_string
|
||||
)
|
||||
self.password = password
|
||||
self.encrypted = False
|
||||
|
@ -303,9 +309,10 @@ class BaseAccount:
|
|||
def encrypt(self, password: str) -> None:
|
||||
assert not self.encrypted, "Key is already encrypted."
|
||||
assert isinstance(self.private_key, PrivateKey)
|
||||
self.seed = aes_encrypt(password, self.seed, self.encryption_init_vector)
|
||||
|
||||
self.seed = aes_encrypt(password, self.seed, self.seed_encryption_init_vector)
|
||||
self.private_key_string = aes_encrypt(
|
||||
password, self.private_key.extended_key_string(), self.encryption_init_vector
|
||||
password, self.private_key.extended_key_string(), self.private_key_encryption_init_vector
|
||||
)
|
||||
self.private_key = None
|
||||
self.password = None
|
||||
|
|
|
@ -12,6 +12,7 @@ import os
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import typing
|
||||
from binascii import hexlify, unhexlify
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, modes
|
||||
from cryptography.hazmat.primitives.ciphers.algorithms import AES
|
||||
|
@ -133,14 +134,14 @@ def aes_encrypt(secret: str, value: str, init_vector: bytes = None) -> str:
|
|||
return base64.b64encode(init_vector + encrypted_data).decode()
|
||||
|
||||
|
||||
def aes_decrypt(secret: str, value: str) -> str:
|
||||
def aes_decrypt(secret: str, value: str) -> typing.Tuple[str, bytes]:
|
||||
data = base64.b64decode(value.encode())
|
||||
key = double_sha256(secret.encode())
|
||||
init_vector, data = data[:16], data[16:]
|
||||
decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor()
|
||||
unpadder = PKCS7(AES.block_size).unpadder()
|
||||
result = unpadder.update(decryptor.update(data)) + unpadder.finalize()
|
||||
return result.decode()
|
||||
return result.decode(), init_vector
|
||||
|
||||
|
||||
class Base58Error(Exception):
|
||||
|
|
Loading…
Reference in a new issue