review feedback and pylint

This commit is contained in:
Jack Robison 2018-09-21 14:49:16 -04:00 committed by Lex Berezhny
parent aa47f10602
commit c0e0b4b745
2 changed files with 35 additions and 28 deletions

View file

@ -1,10 +1,10 @@
import typing import typing
from typing import Dict, Tuple, Type, Optional, Any from typing import Dict, Tuple, Type, Optional, Any, Union
from twisted.internet import defer from twisted.internet import defer
from torba.mnemonic import Mnemonic from torba.mnemonic import Mnemonic
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
from torba.hash import double_sha256, aes_encrypt, aes_decrypt from torba.hash import aes_encrypt, aes_decrypt
from torba.constants import COIN from torba.constants import COIN
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -199,16 +199,18 @@ class BaseAccount:
def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str, def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str,
seed: str, encrypted: bool, private_key: PrivateKey, public_key: PubKey, seed: str, encrypted: bool, private_key: PrivateKey, public_key: PubKey,
address_generator: dict) -> None: address_generator: dict, password: str = None) -> None:
self.ledger = ledger self.ledger = ledger
self.wallet = wallet self.wallet = wallet
self.id = public_key.address self.id = public_key.address
self.name = name self.name = name
self.seed = seed self.seed = seed
self.password = None self.password = password
self.encryption_init_vector = None
self.encrypted = encrypted self.encrypted = encrypted
self.encrypted_on_disk = encrypted self.serialize_encrypted = encrypted
self.private_key = private_key
self.private_key: Union[PrivateKey, str] = private_key
self.public_key = public_key self.public_key = public_key
generator_name = address_generator.get('name', HierarchicalDeterministic.name) generator_name = address_generator.get('name', HierarchicalDeterministic.name)
self.address_generator = self.address_generators[generator_name] self.address_generator = self.address_generators[generator_name]
@ -265,9 +267,9 @@ class BaseAccount:
private_key, seed = self.private_key, self.seed private_key, seed = self.private_key, self.seed
if not self.encrypted and self.private_key: if not self.encrypted and self.private_key:
private_key = self.private_key.extended_key_string() private_key = self.private_key.extended_key_string()
if not self.encrypted and self.encrypted_on_disk: if not self.encrypted and self.serialize_encrypted:
private_key = aes_encrypt(self.password, private_key) private_key = aes_encrypt(self.password, private_key, init_vector=self.encryption_init_vector)
seed = aes_encrypt(self.password, self.seed) seed = aes_encrypt(self.password, self.seed, init_vector=self.encryption_init_vector)
return { return {
'ledger': self.ledger.get_id(), 'ledger': self.ledger.get_id(),
@ -295,20 +297,25 @@ class BaseAccount:
details['seed'] = self.seed details['seed'] = self.seed
return details return details
def decrypt(self, password): def decrypt(self, password: str) -> None:
assert self.encrypted, "Key is not encrypted." assert self.encrypted, "Key is not encrypted."
self.seed = aes_decrypt(password, self.seed.encode()).decode() self.seed = aes_decrypt(password, self.seed)
p_k: Union[PrivateKey, str] = self.private_key
assert isinstance(p_k, str)
self.private_key = from_extended_key_string( self.private_key = from_extended_key_string(
self.ledger, aes_decrypt(password, self.private_key.encode()).decode() self.ledger, aes_decrypt(password, str(p_k))
) )
self.password = password self.password = password
self.encrypted = False self.encrypted = False
def encrypt(self, password): def encrypt(self, password: str) -> None:
assert not self.encrypted, "Key is already encrypted." assert not self.encrypted, "Key is already encrypted."
self.seed = aes_encrypt(password, self.seed.encode()).decode() assert isinstance(self.private_key, PrivateKey)
private_key: PrivateKey = self.private_key self.seed = aes_encrypt(password, self.seed, init_vector=self.encryption_init_vector)
self.private_key = aes_encrypt(password, private_key.extended_key_string().encode()).decode() p_k: PrivateKey = self.private_key # this is because the type is changing from PrivateKey <-> str
extended: str = p_k.extended_key_string()
self.private_key = aes_encrypt(password, extended,
init_vector=self.encryption_init_vector)
self.password = None self.password = None
self.encrypted = True self.encrypted = True

View file

@ -106,27 +106,27 @@ def hex_str_to_hash(x):
return reversed(unhexlify(x)) return reversed(unhexlify(x))
def aes_encrypt(secret, value, iv=None): def aes_encrypt(secret: str, value: str, init_vector: bytes = None) -> str:
if iv: if init_vector is not None:
assert len(iv) == 16 assert len(init_vector) == 16
else: else:
iv = os.urandom(16) init_vector = os.urandom(16)
key = double_sha256(secret) key = double_sha256(secret.encode())
encryptor = Cipher(AES(key), modes.CBC(iv), default_backend()).encryptor() encryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).encryptor()
padder = PKCS7(AES.block_size).padder() padder = PKCS7(AES.block_size).padder()
padded_data = padder.update(value) + padder.finalize() padded_data = padder.update(value.encode()) + padder.finalize()
encrypted_data = encryptor.update(padded_data) + encryptor.finalize() encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
return base64.b64encode(iv + encrypted_data) return base64.b64encode(init_vector + encrypted_data).decode()
def aes_decrypt(secret, value): def aes_decrypt(secret: str, value: str) -> str:
data = base64.b64decode(value) data = base64.b64decode(value.encode())
key = double_sha256(secret) key = double_sha256(secret.encode())
init_vector, data = data[:16], data[16:] init_vector, data = data[:16], data[16:]
decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor() decryptor = Cipher(AES(key), modes.CBC(init_vector), default_backend()).decryptor()
unpadder = PKCS7(AES.block_size).unpadder() unpadder = PKCS7(AES.block_size).unpadder()
result = unpadder.update(decryptor.update(data)) + unpadder.finalize() result = unpadder.update(decryptor.update(data)) + unpadder.finalize()
return result return result.decode()
class Base58Error(Exception): class Base58Error(Exception):