From 58d2c04b9ffadad1fcd0b0707e8a1d559c89f5df Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Mon, 24 Sep 2018 23:12:46 -0400 Subject: [PATCH] simplified and cleaned up typing for encryp / decrypt --- tests/unit/test_account.py | 4 +++- torba/baseaccount.py | 42 +++++++++++++++++++------------------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 7bbd4afcf..f23f29c0d 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -371,11 +371,13 @@ class AccountEncryptionTests(unittest.TestCase): self.assertFalse(account.serialize_encrypted) self.assertFalse(account.encrypted) + self.assertIsNotNone(account.private_key) account.encrypt(self.password) self.assertFalse(account.serialize_encrypted) self.assertTrue(account.encrypted) self.assertEqual(account.seed, self.encrypted_account['seed']) - self.assertEqual(account.private_key, self.encrypted_account['private_key']) + self.assertEqual(account.private_key_string, self.encrypted_account['private_key']) + self.assertIsNone(account.private_key) self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) diff --git a/torba/baseaccount.py b/torba/baseaccount.py index eb649e411..0fa266cc1 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -198,19 +198,20 @@ class BaseAccount: } def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str, - seed: str, encrypted: bool, private_key: PrivateKey, public_key: PubKey, - address_generator: dict, password: str = None) -> None: + seed: str, private_key_string: str, encrypted: bool, + private_key: Optional[PrivateKey], public_key: PubKey, + address_generator: dict) -> None: self.ledger = ledger self.wallet = wallet self.id = public_key.address self.name = name self.seed = seed - self.password = password + self.private_key_string = private_key_string + self.password: Optional[str] = None self.encryption_init_vector = None self.encrypted = encrypted self.serialize_encrypted = encrypted - - self.private_key: Union[PrivateKey, str] = private_key + self.private_key = private_key self.public_key = public_key generator_name = address_generator.get('name', HierarchicalDeterministic.name) self.address_generator = self.address_generators[generator_name] @@ -237,7 +238,8 @@ class BaseAccount: @classmethod def from_dict(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', d: dict): seed = d.get('seed', '') - private_key = d.get('private_key', '') + private_key_string = d.get('private_key', '') + private_key = None public_key = None encrypted = d.get('encrypted', False) if not encrypted: @@ -245,7 +247,7 @@ class BaseAccount: private_key = cls.get_private_key_from_seed(ledger, seed, '') public_key = private_key.public_key elif private_key: - private_key = from_extended_key_string(ledger, private_key) + private_key = from_extended_key_string(ledger, private_key_string) public_key = private_key.public_key if public_key is None: public_key = from_extended_key_string(ledger, d['public_key']) @@ -257,6 +259,7 @@ class BaseAccount: wallet=wallet, name=name, seed=seed, + private_key_string=private_key_string, encrypted=encrypted, private_key=private_key, public_key=public_key, @@ -264,19 +267,18 @@ class BaseAccount: ) def to_dict(self): - private_key, seed = self.private_key, self.seed + private_key_string, seed = self.private_key_string, self.seed if not self.encrypted and self.private_key: - private_key = self.private_key.extended_key_string() + private_key_string = self.private_key.extended_key_string() if not self.encrypted and self.serialize_encrypted: - private_key = aes_encrypt(self.password, private_key, init_vector=self.encryption_init_vector) - seed = aes_encrypt(self.password, self.seed, init_vector=self.encryption_init_vector) - + private_key_string = aes_encrypt(self.password, private_key_string, self.encryption_init_vector) + seed = aes_encrypt(self.password, self.seed, self.encryption_init_vector) return { 'ledger': self.ledger.get_id(), 'name': self.name, 'seed': seed, 'encrypted': self.encrypted, - 'private_key': private_key, + 'private_key': private_key_string, 'public_key': self.public_key.extended_key_string(), 'address_generator': self.address_generator.to_dict(self.receiving, self.change) } @@ -300,10 +302,8 @@ class BaseAccount: def decrypt(self, password: str) -> None: assert self.encrypted, "Key is not encrypted." self.seed = aes_decrypt(password, self.seed) - p_k: Union[PrivateKey, str] = self.private_key - assert isinstance(p_k, str) self.private_key = from_extended_key_string( - self.ledger, aes_decrypt(password, str(p_k)) + self.ledger, aes_decrypt(password, self.private_key_string) ) self.password = password self.encrypted = False @@ -311,11 +311,11 @@ class BaseAccount: def encrypt(self, password: str) -> None: assert not self.encrypted, "Key is already encrypted." assert isinstance(self.private_key, PrivateKey) - self.seed = aes_encrypt(password, self.seed, init_vector=self.encryption_init_vector) - p_k: PrivateKey = self.private_key # this is because the type is changing from PrivateKey <-> str - extended: str = p_k.extended_key_string() - self.private_key = aes_encrypt(password, extended, - init_vector=self.encryption_init_vector) + self.seed = aes_encrypt(password, self.seed, self.encryption_init_vector) + self.private_key_string = aes_encrypt( + password, self.private_key.extended_key_string(), self.encryption_init_vector + ) + self.private_key = None self.password = None self.encrypted = True