simplified and cleaned up typing for encryp / decrypt

This commit is contained in:
Lex Berezhny 2018-09-24 23:12:46 -04:00
parent c45c792657
commit 58d2c04b9f
2 changed files with 24 additions and 22 deletions

View file

@ -371,11 +371,13 @@ class AccountEncryptionTests(unittest.TestCase):
self.assertFalse(account.serialize_encrypted) self.assertFalse(account.serialize_encrypted)
self.assertFalse(account.encrypted) self.assertFalse(account.encrypted)
self.assertIsNotNone(account.private_key)
account.encrypt(self.password) account.encrypt(self.password)
self.assertFalse(account.serialize_encrypted) self.assertFalse(account.serialize_encrypted)
self.assertTrue(account.encrypted) self.assertTrue(account.encrypted)
self.assertEqual(account.seed, self.encrypted_account['seed']) self.assertEqual(account.seed, self.encrypted_account['seed'])
self.assertEqual(account.private_key, self.encrypted_account['private_key']) self.assertEqual(account.private_key_string, self.encrypted_account['private_key'])
self.assertIsNone(account.private_key)
self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed']) self.assertEqual(account.to_dict()['seed'], self.encrypted_account['seed'])
self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key']) self.assertEqual(account.to_dict()['private_key'], self.encrypted_account['private_key'])

View file

@ -198,19 +198,20 @@ class BaseAccount:
} }
def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str, def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str,
seed: str, encrypted: bool, private_key: PrivateKey, public_key: PubKey, seed: str, private_key_string: str, encrypted: bool,
address_generator: dict, password: str = None) -> None: private_key: Optional[PrivateKey], public_key: PubKey,
address_generator: dict) -> None:
self.ledger = ledger self.ledger = ledger
self.wallet = wallet self.wallet = wallet
self.id = public_key.address self.id = public_key.address
self.name = name self.name = name
self.seed = seed self.seed = seed
self.password = password self.private_key_string = private_key_string
self.password: Optional[str] = None
self.encryption_init_vector = None self.encryption_init_vector = None
self.encrypted = encrypted self.encrypted = encrypted
self.serialize_encrypted = encrypted self.serialize_encrypted = encrypted
self.private_key = private_key
self.private_key: Union[PrivateKey, str] = private_key
self.public_key = public_key self.public_key = public_key
generator_name = address_generator.get('name', HierarchicalDeterministic.name) generator_name = address_generator.get('name', HierarchicalDeterministic.name)
self.address_generator = self.address_generators[generator_name] self.address_generator = self.address_generators[generator_name]
@ -237,7 +238,8 @@ class BaseAccount:
@classmethod @classmethod
def from_dict(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', d: dict): def from_dict(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', d: dict):
seed = d.get('seed', '') seed = d.get('seed', '')
private_key = d.get('private_key', '') private_key_string = d.get('private_key', '')
private_key = None
public_key = None public_key = None
encrypted = d.get('encrypted', False) encrypted = d.get('encrypted', False)
if not encrypted: if not encrypted:
@ -245,7 +247,7 @@ class BaseAccount:
private_key = cls.get_private_key_from_seed(ledger, seed, '') private_key = cls.get_private_key_from_seed(ledger, seed, '')
public_key = private_key.public_key public_key = private_key.public_key
elif private_key: elif private_key:
private_key = from_extended_key_string(ledger, private_key) private_key = from_extended_key_string(ledger, private_key_string)
public_key = private_key.public_key public_key = private_key.public_key
if public_key is None: if public_key is None:
public_key = from_extended_key_string(ledger, d['public_key']) public_key = from_extended_key_string(ledger, d['public_key'])
@ -257,6 +259,7 @@ class BaseAccount:
wallet=wallet, wallet=wallet,
name=name, name=name,
seed=seed, seed=seed,
private_key_string=private_key_string,
encrypted=encrypted, encrypted=encrypted,
private_key=private_key, private_key=private_key,
public_key=public_key, public_key=public_key,
@ -264,19 +267,18 @@ class BaseAccount:
) )
def to_dict(self): def to_dict(self):
private_key, seed = self.private_key, self.seed private_key_string, seed = self.private_key_string, self.seed
if not self.encrypted and self.private_key: if not self.encrypted and self.private_key:
private_key = self.private_key.extended_key_string() private_key_string = self.private_key.extended_key_string()
if not self.encrypted and self.serialize_encrypted: if not self.encrypted and self.serialize_encrypted:
private_key = aes_encrypt(self.password, private_key, init_vector=self.encryption_init_vector) private_key_string = aes_encrypt(self.password, private_key_string, self.encryption_init_vector)
seed = aes_encrypt(self.password, self.seed, init_vector=self.encryption_init_vector) seed = aes_encrypt(self.password, self.seed, self.encryption_init_vector)
return { return {
'ledger': self.ledger.get_id(), 'ledger': self.ledger.get_id(),
'name': self.name, 'name': self.name,
'seed': seed, 'seed': seed,
'encrypted': self.encrypted, 'encrypted': self.encrypted,
'private_key': private_key, 'private_key': private_key_string,
'public_key': self.public_key.extended_key_string(), 'public_key': self.public_key.extended_key_string(),
'address_generator': self.address_generator.to_dict(self.receiving, self.change) 'address_generator': self.address_generator.to_dict(self.receiving, self.change)
} }
@ -300,10 +302,8 @@ class BaseAccount:
def decrypt(self, password: str) -> None: def decrypt(self, password: str) -> None:
assert self.encrypted, "Key is not encrypted." assert self.encrypted, "Key is not encrypted."
self.seed = aes_decrypt(password, self.seed) self.seed = aes_decrypt(password, self.seed)
p_k: Union[PrivateKey, str] = self.private_key
assert isinstance(p_k, str)
self.private_key = from_extended_key_string( self.private_key = from_extended_key_string(
self.ledger, aes_decrypt(password, str(p_k)) self.ledger, aes_decrypt(password, self.private_key_string)
) )
self.password = password self.password = password
self.encrypted = False self.encrypted = False
@ -311,11 +311,11 @@ class BaseAccount:
def encrypt(self, password: str) -> None: def encrypt(self, password: str) -> None:
assert not self.encrypted, "Key is already encrypted." assert not self.encrypted, "Key is already encrypted."
assert isinstance(self.private_key, PrivateKey) assert isinstance(self.private_key, PrivateKey)
self.seed = aes_encrypt(password, self.seed, init_vector=self.encryption_init_vector) self.seed = aes_encrypt(password, self.seed, self.encryption_init_vector)
p_k: PrivateKey = self.private_key # this is because the type is changing from PrivateKey <-> str self.private_key_string = aes_encrypt(
extended: str = p_k.extended_key_string() password, self.private_key.extended_key_string(), self.encryption_init_vector
self.private_key = aes_encrypt(password, extended, )
init_vector=self.encryption_init_vector) self.private_key = None
self.password = None self.password = None
self.encrypted = True self.encrypted = True