diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index bc6726449..47266fef6 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -3,10 +3,10 @@ from twisted.trial import unittest from twisted.internet import defer from torba.coin.bitcoinsegwit import MainNetLedger -from torba.baseaccount import KeyChain, SingleKey +from torba.baseaccount import HierarchicalDeterministic, SingleKey -class TestKeyChainAccount(unittest.TestCase): +class TestHierarchicalDeterministicAccount(unittest.TestCase): @defer.inlineCallbacks def setUp(self): @@ -42,7 +42,7 @@ class TestKeyChainAccount(unittest.TestCase): def test_ensure_address_gap(self): account = self.account - self.assertIsInstance(account.receiving, KeyChain) + self.assertIsInstance(account.receiving, HierarchicalDeterministic) yield account.receiving.generate_keys(4, 7) yield account.receiving.generate_keys(0, 3) @@ -95,7 +95,7 @@ class TestKeyChainAccount(unittest.TestCase): account = self.ledger.account_class.from_seed( self.ledger, "carbon smart garage balance margin twelve chest sword toast envelope bottom stomach ab" - "sent", "torba", receiving_gap=3, change_gap=2 + "sent", "torba", {'name': 'deterministic-chain', 'receiving_gap': 3, 'change_gap': 2} ) self.assertEqual( account.private_key.extended_key_string(), @@ -139,11 +139,11 @@ class TestKeyChainAccount(unittest.TestCase): 'public_key': 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' 'iW44g14WF52fYC5J483wqQ5ZP', - 'is_hd': True, - 'receiving_gap': 5, - 'receiving_maximum_uses_per_address': 2, - 'change_gap': 5, - 'change_maximum_uses_per_address': 2 + 'address_generator': { + 'name': 'deterministic-chain', + 'receiving': {'gap': 5, 'maximum_uses_per_address': 2}, + 'change': {'gap': 5, 'maximum_uses_per_address': 2} + } } account = self.ledger.account_class.from_dict(self.ledger, account_data) @@ -166,7 +166,7 @@ class TestSingleKeyAccount(unittest.TestCase): def setUp(self): self.ledger = MainNetLedger({'db': MainNetLedger.database_class(':memory:')}) yield self.ledger.db.start() - self.account = self.ledger.account_class.generate(self.ledger, u"torba", is_hd=False) + self.account = self.ledger.account_class.generate(self.ledger, u"torba", {'name': 'single-address'}) @defer.inlineCallbacks def test_generate_account(self): @@ -247,7 +247,7 @@ class TestSingleKeyAccount(unittest.TestCase): account = self.ledger.account_class.from_seed( self.ledger, "carbon smart garage balance margin twelve chest sword toast envelope bottom stomach ab" - "sent", "torba", is_hd=False + "sent", "torba", {'name': 'single-address'} ) self.assertEqual( account.private_key.extended_key_string(), @@ -291,7 +291,7 @@ class TestSingleKeyAccount(unittest.TestCase): 'public_key': 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' 'iW44g14WF52fYC5J483wqQ5ZP', - 'is_hd': False + 'address_generator': {'name': 'single-address'} } account = self.ledger.account_class.from_dict(self.ledger, account_data) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 5966cf374..47af97eb7 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -148,7 +148,7 @@ class TestTransactionSigning(unittest.TestCase): account = self.ledger.account_class.from_seed( self.ledger, u"carbon smart garage balance margin twelve chest sword toast envelope bottom stomach ab" - u"sent", u"torba" + u"sent", u"torba", {} ) yield account.ensure_address_gap() diff --git a/tests/unit/test_wallet.py b/tests/unit/test_wallet.py index b52ee4ebb..31a9cb8ab 100644 --- a/tests/unit/test_wallet.py +++ b/tests/unit/test_wallet.py @@ -44,11 +44,11 @@ class TestWalletCreation(unittest.TestCase): 'public_key': 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' 'iW44g14WF52fYC5J483wqQ5ZP', - 'is_hd': True, - 'receiving_gap': 10, - 'receiving_maximum_uses_per_address': 2, - 'change_gap': 10, - 'change_maximum_uses_per_address': 2, + 'address_generator': { + 'name': 'deterministic-chain', + 'receiving': {'gap': 17, 'maximum_uses_per_address': 3}, + 'change': {'gap': 10, 'maximum_uses_per_address': 3} + } } ] } diff --git a/torba/baseaccount.py b/torba/baseaccount.py index 532b364da..eb2835bbd 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -1,5 +1,5 @@ import typing -from typing import Sequence +from typing import Tuple, Type from twisted.internet import defer from torba.mnemonic import Mnemonic @@ -10,7 +10,9 @@ if typing.TYPE_CHECKING: from torba import baseledger -class KeyManager: +class AddressManager: + + name: str __slots__ = 'account', 'public_key', 'chain_number' @@ -19,6 +21,15 @@ class KeyManager: self.public_key = public_key self.chain_number = chain_number + @classmethod + def from_dict(cls, account: 'BaseAccount', d: dict) \ + -> Tuple['AddressManager', 'AddressManager']: + raise NotImplementedError + + @classmethod + def to_dict(cls, receiving: 'AddressManager', change: 'AddressManager') -> dict: + return {'name': cls.name} + @property def db(self): return self.account.ledger.db @@ -28,6 +39,9 @@ class KeyManager: self.account, self.chain_number, limit, max_used_times, order_by ) + def get_private_key(self, index: int) -> PrivateKey: + raise NotImplementedError + def get_max_gap(self) -> defer.Deferred: raise NotImplementedError @@ -51,17 +65,38 @@ class KeyManager: defer.returnValue(addresses[0]) -class KeyChain(KeyManager): +class HierarchicalDeterministic(AddressManager): """ Implements simple version of Bitcoin Hierarchical Deterministic key management. """ + name = "deterministic-chain" + __slots__ = 'gap', 'maximum_uses_per_address' - def __init__(self, account: 'BaseAccount', root_public_key: PubKey, - chain_number: int, gap: int, maximum_uses_per_address: int) -> None: - super().__init__(account, root_public_key.child(chain_number), chain_number) + def __init__(self, account: 'BaseAccount', chain: int, gap: int, maximum_uses_per_address: int) -> None: + super().__init__(account, account.public_key.child(chain), chain) self.gap = gap self.maximum_uses_per_address = maximum_uses_per_address + @classmethod + def from_dict(cls, account: 'BaseAccount', d: dict) -> Tuple[AddressManager, AddressManager]: + return ( + cls(account, 0, **d.get('receiving', {'gap': 20, 'maximum_uses_per_address': 2})), + cls(account, 1, **d.get('change', {'gap': 6, 'maximum_uses_per_address': 2})) + ) + + @classmethod + def to_dict(cls, receiving: 'HierarchicalDeterministic', change: 'HierarchicalDeterministic') -> dict: + d = super().to_dict(receiving, change) + d['receiving'] = receiving.to_dict_instance() + d['change'] = change.to_dict_instance() + return d + + def to_dict_instance(self): + return {'gap': self.gap, 'maximum_uses_per_address': self.maximum_uses_per_address} + + def get_private_key(self, index: int) -> PrivateKey: + return self.account.private_key.child(self.chain_number).child(index) + @defer.inlineCallbacks def generate_keys(self, start: int, end: int) -> defer.Deferred: new_keys = [] @@ -111,11 +146,22 @@ class KeyChain(KeyManager): ) -class SingleKey(KeyManager): - """ Single Key manager always returns the same address for all operations. """ +class SingleKey(AddressManager): + """ Single Key address manager always returns the same address for all operations. """ + + name = "single-address" __slots__ = () + @classmethod + def from_dict(cls, account: 'BaseAccount', d: dict)\ + -> Tuple[AddressManager, AddressManager]: + same_address_manager = cls(account, account.public_key, 0) + return same_address_manager, same_address_manager + + def get_private_key(self, index: int) -> PrivateKey: + return self.account.private_key + def get_max_gap(self) -> defer.Deferred: return defer.succeed(0) @@ -138,10 +184,13 @@ class BaseAccount: mnemonic_class = Mnemonic private_key_class = PrivateKey public_key_class = PubKey + address_generators = { + SingleKey.name: SingleKey, + HierarchicalDeterministic.name: HierarchicalDeterministic, + } - def __init__(self, ledger: 'baseledger.BaseLedger', name: str, seed: str, encrypted: bool, is_hd: bool, - private_key: PrivateKey, public_key: PubKey, receiving_gap: int = 20, change_gap: int = 6, - receiving_maximum_uses_per_address: int = 2, change_maximum_uses_per_address: int = 2 + def __init__(self, ledger: 'baseledger.BaseLedger', name: str, seed: str, encrypted: bool, + private_key: PrivateKey, public_key: PubKey, address_generator: dict ) -> None: self.ledger = ledger self.name = name @@ -149,34 +198,26 @@ class BaseAccount: self.encrypted = encrypted self.private_key = private_key self.public_key = public_key - if is_hd: - self.receiving: KeyManager = KeyChain( - self, public_key, 0, receiving_gap, receiving_maximum_uses_per_address - ) - self.change: KeyManager = KeyChain( - self, public_key, 1, change_gap, change_maximum_uses_per_address - ) - self.keychains: Sequence[KeyManager] = (self.receiving, self.change) - else: - self.change = self.receiving = SingleKey(self, public_key, 0) - self.keychains = (self.receiving,) + generator_name = address_generator.get('name', HierarchicalDeterministic.name) + self.address_generator: Type[AddressManager] = self.address_generators[generator_name] + self.receiving, self.change = self.address_generator.from_dict(self, address_generator) + self.address_managers = {self.receiving, self.change} ledger.add_account(self) @classmethod - def generate(cls, ledger: 'baseledger.BaseLedger', password: str, **kwargs): + def generate(cls, ledger: 'baseledger.BaseLedger', password: str, address_generator: dict = None): seed = cls.mnemonic_class().make_seed() - return cls.from_seed(ledger, seed, password, **kwargs) + return cls.from_seed(ledger, seed, password, address_generator or {}) @classmethod - def from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str, - is_hd: bool = True, **kwargs): + def from_seed(cls, ledger: 'baseledger.BaseLedger', seed: str, password: str, address_generator: dict): private_key = cls.get_private_key_from_seed(ledger, seed, password) return cls( ledger=ledger, name='Account #{}'.format(private_key.public_key.address), - seed=seed, encrypted=False, is_hd=is_hd, + seed=seed, encrypted=False, private_key=private_key, public_key=private_key.public_key, - **kwargs + address_generator=address_generator ) @classmethod @@ -193,54 +234,30 @@ class BaseAccount: else: private_key = d['private_key'] public_key = from_extended_key_string(ledger, d['public_key']) - - kwargs = dict( + return cls( ledger=ledger, name=d['name'], seed=d['seed'], encrypted=d['encrypted'], private_key=private_key, public_key=public_key, - is_hd=False + address_generator=d['address_generator'] ) - if d['is_hd']: - kwargs.update(dict( - receiving_gap=d['receiving_gap'], - change_gap=d['change_gap'], - receiving_maximum_uses_per_address=d['receiving_maximum_uses_per_address'], - change_maximum_uses_per_address=d['change_maximum_uses_per_address'], - is_hd=True - )) - - return cls(**kwargs) - def to_dict(self): private_key = self.private_key if not self.encrypted and self.private_key: private_key = self.private_key.extended_key_string() - - d = { + return { 'ledger': self.ledger.get_id(), 'name': self.name, 'seed': self.seed, 'encrypted': self.encrypted, 'private_key': private_key, 'public_key': self.public_key.extended_key_string(), - 'is_hd': False + 'address_generator': self.address_generator.to_dict(self.receiving, self.change) } - if isinstance(self.receiving, KeyChain) and isinstance(self.change, KeyChain): - d.update({ - 'receiving_gap': self.receiving.gap, - 'change_gap': self.change.gap, - 'receiving_maximum_uses_per_address': self.receiving.maximum_uses_per_address, - 'change_maximum_uses_per_address': self.change.maximum_uses_per_address, - 'is_hd': True - }) - - return d - def decrypt(self, password): assert self.encrypted, "Key is not encrypted." secret = double_sha256(password) @@ -258,8 +275,8 @@ class BaseAccount: @defer.inlineCallbacks def ensure_address_gap(self): addresses = [] - for keychain in self.keychains: - new_addresses = yield keychain.ensure_address_gap() + for address_manager in self.address_managers: + new_addresses = yield address_manager.ensure_address_gap() addresses.extend(new_addresses) defer.returnValue(addresses) @@ -273,9 +290,8 @@ class BaseAccount: def get_private_key(self, chain: int, index: int) -> PrivateKey: assert not self.encrypted, "Cannot get private key on encrypted wallet account." - if isinstance(self.receiving, SingleKey): - return self.private_key - return self.private_key.child(chain).child(index) + address_manager = {0: self.receiving, 1: self.change}[chain] + return address_manager.get_private_key(index) def get_balance(self, confirmations: int = 6, **constraints): if confirmations > 0: diff --git a/torba/basetransaction.py b/torba/basetransaction.py index 0f326dfa2..e7dc1b4a7 100644 --- a/torba/basetransaction.py +++ b/torba/basetransaction.py @@ -328,9 +328,8 @@ class BaseTransaction: self.locktime = stream.read_uint32() @classmethod - def ensure_all_have_same_ledger( - cls, funding_accounts: Iterable[BaseAccount], change_account: BaseAccount = None)\ - -> 'baseledger.BaseLedger': + def ensure_all_have_same_ledger(cls, funding_accounts: Iterable[BaseAccount], + change_account: BaseAccount = None) -> 'baseledger.BaseLedger': ledger = None for account in funding_accounts: if ledger is None: diff --git a/torba/wallet.py b/torba/wallet.py index ca6d48fab..4e5afbd8d 100644 --- a/torba/wallet.py +++ b/torba/wallet.py @@ -24,7 +24,7 @@ class Wallet: self.storage = storage or WalletStorage() def generate_account(self, ledger: 'baseledger.BaseLedger') -> 'baseaccount.BaseAccount': - account = ledger.account_class.generate(ledger, u'torba') + account = ledger.account_class.generate(ledger, u'torba', {}) self.accounts.append(account) return account