diff --git a/tests/client_tests/unit/test_hash.py b/tests/client_tests/unit/test_hash.py index 4b3a8de5c..19ad79249 100644 --- a/tests/client_tests/unit/test_hash.py +++ b/tests/client_tests/unit/test_hash.py @@ -1,5 +1,5 @@ from unittest import TestCase, mock -from torba.client.hash import aes_decrypt, aes_encrypt +from torba.client.hash import aes_decrypt, aes_encrypt, better_aes_decrypt, better_aes_encrypt class TestAESEncryptDecrypt(TestCase): @@ -33,3 +33,10 @@ class TestAESEncryptDecrypt(TestCase): aes_decrypt('bubblegum', aes_encrypt('bubblegum', self.message))[0], self.message ) + + def test_better_encrypt_decrypt(self): + self.assertEqual( + b'valuable value', + better_aes_decrypt( + 'super secret', + better_aes_encrypt('super secret', b'valuable value'))) diff --git a/torba/client/hash.py b/torba/client/hash.py index 9b0e94f73..cb2a17696 100644 --- a/torba/client/hash.py +++ b/torba/client/hash.py @@ -14,6 +14,7 @@ import hashlib import hmac import typing from binascii import hexlify, unhexlify +from cryptography.hazmat.primitives.kdf.scrypt import Scrypt from cryptography.hazmat.primitives.ciphers import Cipher, modes from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.padding import PKCS7 @@ -146,23 +147,29 @@ def aes_decrypt(secret: str, value: str) -> typing.Tuple[str, bytes]: def better_aes_encrypt(secret: str, value: bytes) -> bytes: init_vector = os.urandom(16) - key = double_sha256(secret.encode()) + key = scrypt(secret.encode(), salt=init_vector) encryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).encryptor() padder = PKCS7(AES.block_size).padder() padded_data = padder.update(value) + padder.finalize() encrypted_data = encryptor.update(padded_data) + encryptor.finalize() - return base64.b64encode(init_vector + encrypted_data) + return base64.b64encode(b's:8192:16:1:' + init_vector + encrypted_data) def better_aes_decrypt(secret: str, value: bytes) -> bytes: data = base64.b64decode(value) - key = double_sha256(secret.encode()) + type, n, r, p, data = data.split(b':', maxsplit=4) init_vector, data = data[:16], data[16:] + key = scrypt(secret.encode(), salt=init_vector, n=int(n), r=int(r), p=int(p)) decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor() unpadder = PKCS7(AES.block_size).unpadder() return unpadder.update(decryptor.update(data)) + unpadder.finalize() +def scrypt(passphrase, salt, n=1<<13, r=16, p=1): + kdf = Scrypt(salt, length=32, n=n, r=r, p=p, backend=default_backend()) + return kdf.derive(passphrase) + + class Base58Error(Exception): """ Exception used for Base58 errors. """