import random
import typing
from typing import List, Dict, Tuple, Type, Optional, Any
from twisted.internet import defer

from torba.mnemonic import Mnemonic
from torba.bip32 import PrivateKey, PubKey, from_extended_key_string
from torba.hash import aes_encrypt, aes_decrypt
from torba.constants import COIN

if typing.TYPE_CHECKING:
    from torba import baseledger
    from torba import wallet as basewallet
    from torba import basetransaction


class AddressManager:
    """
    Base class for the address managers of an account.

    An instance is bound to one account, one public key and one chain number;
    subclasses decide how keys/addresses are produced and stored for it.
    """

    # Identifier under which a concrete manager is serialized and looked up
    # (see BaseAccount.address_generators).
    name: str

    __slots__ = 'account', 'public_key', 'chain_number'

    def __init__(self, account, public_key, chain_number):
        self.account = account
        self.public_key = public_key
        self.chain_number = chain_number

    @classmethod
    def from_dict(cls, account: 'BaseAccount', d: dict) \
            -> Tuple['AddressManager', 'AddressManager']:
        """Build the ``(receiving, change)`` pair of managers for `account` from `d`."""
        raise NotImplementedError

    @classmethod
    def to_dict(cls, receiving: 'AddressManager', change: 'AddressManager') -> Dict:
        """
        Serialize a receiving/change pair of managers.

        The result always carries ``'name'``; ``'receiving'`` and ``'change'``
        are only included when the respective ``to_dict_instance()`` returns
        something truthy.
        """
        d: Dict[str, Any] = {'name': cls.name}
        receiving_dict = receiving.to_dict_instance()
        if receiving_dict:
            d['receiving'] = receiving_dict
        change_dict = change.to_dict_instance()
        if change_dict:
            d['change'] = change_dict
        return d

    def to_dict_instance(self) -> Optional[dict]:
        """Return this manager's own settings as a dict, or None if it has none."""
        raise NotImplementedError

    @property
    def db(self):
        """Database object of the ledger the account belongs to."""
        return self.account.ledger.db

    def _query_addresses(self, **constraints):
        """Query address records of this manager's account and chain."""
        return self.db.get_addresses(
            account=self.account,
            chain=self.chain_number,
            **constraints
        )

    def get_private_key(self, index: int) -> PrivateKey:
        """Return the private key for the address at `index`."""
        raise NotImplementedError

    def get_max_gap(self) -> defer.Deferred:
        """Return a Deferred firing with the manager's maximum address gap."""
        raise NotImplementedError

    def ensure_address_gap(self) -> defer.Deferred:
        """Return a Deferred firing with the list of newly created addresses."""
        raise NotImplementedError

    def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred:
        """Return a Deferred firing with the address records of this manager."""
        raise NotImplementedError

    @defer.inlineCallbacks
    def get_addresses(self, only_usable: bool = False, **constraints) -> defer.Deferred:
        """Like get_address_records(), but fires with only the ``'address'`` values."""
        records = yield self.get_address_records(only_usable=only_usable, **constraints)
        return [r['address'] for r in records]

    @defer.inlineCallbacks
    def get_or_create_usable_address(self) -> defer.Deferred:
        """
        Fire with one usable address.

        A random one of (at most 10) usable addresses is picked; when there is
        none, ensure_address_gap() is called and its first new address is used.
        """
        addresses = yield self.get_addresses(only_usable=True, limit=10)
        if addresses:
            return random.choice(addresses)
        addresses = yield self.ensure_address_gap()
        return addresses[0]


class HierarchicalDeterministic(AddressManager):
    """
    Implements a simple version of Bitcoin Hierarchical Deterministic key
    management: addresses are children of the account key for a given chain.
    """

    name = "deterministic-chain"

    __slots__ = 'gap', 'maximum_uses_per_address'

    def __init__(self, account: 'BaseAccount', chain: int, gap: int,
                 maximum_uses_per_address: int) -> None:
        super().__init__(account, account.public_key.child(chain), chain)
        self.gap = gap
        self.maximum_uses_per_address = maximum_uses_per_address

    @classmethod
    def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]:
        """
        Create the receiving (chain 0) and change (chain 1) managers.

        Missing ``'receiving'`` / ``'change'`` entries fall back to a gap of
        20 / 6 respectively, both with ``maximum_uses_per_address`` of 2.
        """
        return (
            cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 2})),
            cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 2}))
        )

    def to_dict_instance(self):
        """Return the ``gap`` and ``maximum_uses_per_address`` settings."""
        return {'gap': self.gap, 'maximum_uses_per_address': self.maximum_uses_per_address}

    def get_private_key(self, index: int) -> PrivateKey:
        """Derive the private key at `index` on this manager's chain."""
        return self.account.private_key.child(self.chain_number).child(index)

    @defer.inlineCallbacks
    def generate_keys(self, start: int, end: int) -> defer.Deferred:
        """
        Derive and store public keys for positions `start`..`end`, both inclusive.

        Fires with the list of addresses of the new keys.
        """
        new_keys = [
            (index, self.public_key.child(index))
            for index in range(start, end + 1)
        ]
        yield self.db.add_keys(
            self.account, self.chain_number, new_keys
        )
        return [key[1].address for key in new_keys]

    @defer.inlineCallbacks
    def get_max_gap(self) -> defer.Deferred:
        """
        Fire with the longest run of unused addresses that is followed by a
        used one (walking records by ascending position).

        A trailing run of unused addresses is not counted, because the
        maximum is only updated when a used address is encountered.
        """
        addresses = yield self._query_addresses(order_by="position ASC")
        max_gap = 0
        current_gap = 0
        for address in addresses:
            if address['used_times'] == 0:
                current_gap += 1
            else:
                max_gap = max(max_gap, current_gap)
                current_gap = 0
        return max_gap

    @defer.inlineCallbacks
    def ensure_address_gap(self) -> defer.Deferred:
        """
        Top up the chain so that it ends with `self.gap` unused addresses.

        Fires with the list of newly generated addresses (empty when the gap
        was already complete).
        """
        addresses = yield self._query_addresses(limit=self.gap, order_by="position DESC")

        # Count unused addresses from the highest position downwards,
        # stopping at the first one that has been used.
        existing_gap = 0
        for address in addresses:
            if address['used_times'] == 0:
                existing_gap += 1
            else:
                break

        if existing_gap == self.gap:
            return []

        start = addresses[0]['position']+1 if addresses else 0
        end = start + (self.gap - existing_gap)
        # generate_keys() takes an inclusive end, hence the -1.
        new_keys = yield self.generate_keys(start, end-1)
        return new_keys

    def get_address_records(self, only_usable: bool = False, **constraints):
        """
        Query this chain's address records ordered by use count, then position.

        With `only_usable`, a ``used_times__lte`` constraint of
        `maximum_uses_per_address` is added.
        """
        if only_usable:
            # NOTE(review): `__lte` still admits addresses already used exactly
            # maximum_uses_per_address times -- confirm `__lt` is not intended.
            constraints['used_times__lte'] = self.maximum_uses_per_address
        return self._query_addresses(order_by="used_times ASC, position ASC", **constraints)


class SingleKey(AddressManager):
    """
    Single Key address manager always returns the same address for all
    operations.
    """

    name = "single-address"

    __slots__ = ()

    @classmethod
    def from_dict(cls, account: 'BaseAccount', d: dict)\
            -> Tuple[AddressManager, AddressManager]:
        """Return one manager (account public key, chain 0) as both receiving and change."""
        same_address_manager = cls(account, account.public_key, 0)
        return same_address_manager, same_address_manager

    def to_dict_instance(self):
        """This manager has no settings to serialize."""
        return None

    def get_private_key(self, index: int) -> PrivateKey:
        """Return the account's private key; `index` is ignored."""
        return self.account.private_key

    def get_max_gap(self) -> defer.Deferred:
        """Always fires with 0."""
        return defer.succeed(0)

    @defer.inlineCallbacks
    def ensure_address_gap(self) -> defer.Deferred:
        """
        Store the single key at position 0 if no address record exists yet.

        Fires with ``[address]`` when the key was just added, else ``[]``.
        """
        exists = yield self.get_address_records()
        if not exists:
            yield self.db.add_keys(
                self.account, self.chain_number, [(0, self.public_key)]
            )
            return [self.public_key.address]
        return []

    def get_address_records(self, only_usable: bool = False, **constraints) -> defer.Deferred:
        """Query the address records; `only_usable` is accepted but not applied."""
        return self._query_addresses(**constraints)


class BaseAccount:
    """
    An account: a seed / key pair bound to a ledger and a wallet, plus the
    receiving and change address managers operating on it.
    """

    mnemonic_class = Mnemonic
    private_key_class = PrivateKey
    public_key_class = PubKey
    # Maps a serialized generator name to the AddressManager implementing it.
    address_generators: Dict[str, Type[AddressManager]] = {
        SingleKey.name: SingleKey,
        HierarchicalDeterministic.name: HierarchicalDeterministic,
    }

    def __init__(self, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', name: str,
                 seed: str, private_key_string: str, encrypted: bool,
                 private_key: Optional[PrivateKey], public_key: PubKey,
                 address_generator: dict) -> None:
        """
        Set up the account and register it with both `ledger` and `wallet`.

        `address_generator` selects the address manager by its ``'name'``
        (HierarchicalDeterministic when absent) and is handed to that
        manager's ``from_dict`` to build the receiving/change pair.
        """
        self.ledger = ledger
        self.wallet = wallet
        self.id = public_key.address
        self.name = name
        self.seed = seed
        self.private_key_string = private_key_string
        self.password: Optional[str] = None
        self.encryption_init_vector = None
        self.encrypted = encrypted
        self.serialize_encrypted = encrypted
        self.private_key = private_key
        self.public_key = public_key
        generator_name = address_generator.get('name', HierarchicalDeterministic.name)
        self.address_generator = self.address_generators[generator_name]
        self.receiving, self.change = self.address_generator.from_dict(self, address_generator)
        # A set, so that a manager shared by both roles (SingleKey) is held once.
        self.address_managers = {self.receiving, self.change}
        ledger.add_account(self)
        wallet.add_account(self)

    @classmethod
    def generate(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet',
                 name: Optional[str] = None, address_generator: Optional[dict] = None):
        """Create a new account from a freshly made seed."""
        return cls.from_dict(ledger, wallet, {
            'name': name,
            'seed': cls.mnemonic_class().make_seed(),
            'address_generator': address_generator or {}
        })

    @classmethod
    def get_private_key_from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str):
        """Derive the account private key from a mnemonic `seed` and `password`."""
        return cls.private_key_class.from_seed(
            ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password)
        )

    @classmethod
    def from_dict(cls, ledger: 'baseledger.BaseLedger', wallet: 'basewallet.Wallet', d: dict):
        """
        Build an account from its serialized form `d`.

        For an unencrypted account the private key is taken from the seed if
        present, otherwise from the extended private key string. When neither
        yields a key (or the account is encrypted) the public key is read
        from ``d['public_key']``.

        Raises:
            KeyError: if no public key could be derived and `d` has no
                ``'public_key'`` entry.
        """
        seed = d.get('seed', '')
        private_key_string = d.get('private_key', '')
        private_key = None
        public_key = None
        encrypted = d.get('encrypted', False)
        if not encrypted:
            if seed:
                private_key = cls.get_private_key_from_seed(ledger, seed, '')
                public_key = private_key.public_key
            elif private_key_string:
                # Test the serialized string: `private_key` itself is still
                # None at this point, so checking it would never restore a
                # key-only (seedless) account.
                private_key = from_extended_key_string(ledger, private_key_string)
                public_key = private_key.public_key
        if public_key is None:
            public_key = from_extended_key_string(ledger, d['public_key'])
        name = d.get('name')
        if not name:
            name = 'Account #{}'.format(public_key.address)
        return cls(
            ledger=ledger,
            wallet=wallet,
            name=name,
            seed=seed,
            private_key_string=private_key_string,
            encrypted=encrypted,
            private_key=private_key,
            public_key=public_key,
            address_generator=d.get('address_generator', {})
        )

    def to_dict(self):
        """
        Serialize the account.

        When the account is currently decrypted but `serialize_encrypted` is
        set, seed and private key are encrypted with the stored password
        before being written out. The ``'encrypted'`` flag describes the
        emitted data, so it is true whenever the written seed/key are
        ciphertext.
        """
        private_key_string, seed = self.private_key_string, self.seed
        if not self.encrypted and self.private_key:
            private_key_string = self.private_key.extended_key_string()
        if not self.encrypted and self.serialize_encrypted:
            private_key_string = aes_encrypt(self.password, private_key_string, self.encryption_init_vector)
            seed = aes_encrypt(self.password, self.seed, self.encryption_init_vector)
        return {
            'ledger': self.ledger.get_id(),
            'name': self.name,
            'seed': seed,
            # Must match what was written above; reporting self.encrypted
            # alone would label re-encrypted data as plain text.
            'encrypted': self.encrypted or self.serialize_encrypted,
            'private_key': private_key_string,
            'public_key': self.public_key.extended_key_string(),
            'address_generator': self.address_generator.to_dict(self.receiving, self.change)
        }

    @defer.inlineCallbacks
    def get_details(self, show_seed=False, **kwargs):
        """
        Fire with a summary dict of the account (id, name, balance, keys info).

        `kwargs` are passed to get_balance(); the seed is only included when
        `show_seed` is true.
        """
        satoshis = yield self.get_balance(**kwargs)
        details = {
            'id': self.id,
            'name': self.name,
            'coins': round(satoshis/COIN, 2),
            'satoshis': satoshis,
            'encrypted': self.encrypted,
            'public_key': self.public_key.extended_key_string(),
            'address_generator': self.address_generator.to_dict(self.receiving, self.change)
        }
        if show_seed:
            details['seed'] = self.seed
        return details

    def decrypt(self, password: str) -> None:
        """
        Decrypt seed and private key in place and remember `password`.

        Asserts that the account is currently encrypted.
        """
        assert self.encrypted, "Key is not encrypted."
        self.seed = aes_decrypt(password, self.seed)
        self.private_key = from_extended_key_string(
            self.ledger, aes_decrypt(password, self.private_key_string)
        )
        self.password = password
        self.encrypted = False

    def encrypt(self, password: str) -> None:
        """
        Encrypt seed and private key in place and drop the in-memory key
        and password.

        Asserts that the account is not already encrypted and holds a
        private key.
        """
        assert not self.encrypted, "Key is already encrypted."
        assert isinstance(self.private_key, PrivateKey)

        self.seed = aes_encrypt(password, self.seed, self.encryption_init_vector)
        self.private_key_string = aes_encrypt(
            password, self.private_key.extended_key_string(), self.encryption_init_vector
        )
        self.private_key = None
        self.password = None
        self.encrypted = True

    @defer.inlineCallbacks
    def ensure_address_gap(self):
        """Run ensure_address_gap() on every address manager; fire with all new addresses."""
        addresses = []
        for address_manager in self.address_managers:
            new_addresses = yield address_manager.ensure_address_gap()
            addresses.extend(new_addresses)
        return addresses

    @defer.inlineCallbacks
    def get_addresses(self, **constraints) -> defer.Deferred:
        """Fire with the ``address`` column of the rows matching `constraints`."""
        # NOTE(review): unlike get_address_records(), this query is not scoped
        # with account=self -- confirm that is intended.
        rows = yield self.ledger.db.select_addresses('address', **constraints)
        return [r[0] for r in rows]

    def get_address_records(self, **constraints) -> defer.Deferred:
        """Query the ledger database for this account's address records."""
        return self.ledger.db.get_addresses(account=self, **constraints)

    def get_private_key(self, chain: int, index: int) -> PrivateKey:
        """
        Return the private key at `index` on `chain` (0: receiving, 1: change).

        Asserts that the account is not encrypted; any other `chain` value
        raises KeyError.
        """
        assert not self.encrypted, "Cannot get private key on encrypted wallet account."
        address_manager = {0: self.receiving, 1: self.change}[chain]
        return address_manager.get_private_key(index)

    def get_balance(self, confirmations: int = 0, **constraints):
        """
        Query this account's balance from the ledger database.

        With `confirmations` > 0 the query is limited to heights above 0 and
        not above ``headers.height - (confirmations - 1)``.
        """
        if confirmations > 0:
            height = self.ledger.headers.height - (confirmations-1)
            constraints.update({'height__lte': height, 'height__gt': 0})
        return self.ledger.db.get_balance(account=self, **constraints)

    @defer.inlineCallbacks
    def get_max_gap(self):
        """Fire with the max gap of the change and receiving managers."""
        change_gap = yield self.change.get_max_gap()
        receiving_gap = yield self.receiving.get_max_gap()
        return {
            'max_change_gap': change_gap,
            'max_receiving_gap': receiving_gap,
        }

    def get_utxos(self, **constraints):
        """Query the ledger database for this account's unspent outputs."""
        return self.ledger.db.get_utxos(account=self, **constraints)

    def get_transactions(self, **constraints) -> List['basetransaction.BaseTransaction']:
        """Query the ledger database for this account's transactions."""
        return self.ledger.db.get_transactions(account=self, **constraints)

    @defer.inlineCallbacks
    def fund(self, to_account, amount=None, everything=False,
             outputs=1, broadcast=False, **constraints):
        """
        Create a transaction moving funds from this account to `to_account`.

        With `everything`, all UTXOs matching `constraints` are reserved and
        spent, with `to_account` as the change account. Otherwise `amount`
        is split evenly (integer division) over `outputs` outputs paying a
        usable change address of `to_account`.

        The transaction is broadcast when `broadcast` is true; otherwise the
        outputs spent by its inputs are released again. Fires with the
        transaction.

        Raises:
            ValueError: if `everything` is false and `amount` is missing or
                not positive.
        """
        assert self.ledger == to_account.ledger, 'Can only transfer between accounts of the same ledger.'
        tx_class = self.ledger.transaction_class
        if everything:
            utxos = yield self.get_utxos(**constraints)
            yield self.ledger.reserve_outputs(utxos)
            tx = yield tx_class.create(
                inputs=[tx_class.input_class.spend(txo) for txo in utxos],
                outputs=[],
                funding_accounts=[self],
                change_account=to_account
            )
        elif amount is not None and amount > 0:
            # The None guard keeps a missing amount on the ValueError path
            # below instead of failing the comparison with a TypeError.
            to_address = yield to_account.change.get_or_create_usable_address()
            to_hash160 = to_account.ledger.address_to_hash160(to_address)
            tx = yield tx_class.create(
                inputs=[],
                outputs=[
                    tx_class.output_class.pay_pubkey_hash(amount//outputs, to_hash160)
                    for _ in range(outputs)
                ],
                funding_accounts=[self],
                change_account=self
            )
        else:
            raise ValueError('An amount is required.')

        if broadcast:
            yield self.ledger.broadcast(tx)
        else:
            yield self.ledger.release_outputs(
                [txi.txo_ref.txo for txi in tx.inputs]
            )

        return tx