diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 4e26c68e5..f937a12d5 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -3,77 +3,21 @@ from twisted.trial import unittest from twisted.internet import defer from torba.coin.bitcoinsegwit import MainNetLedger +from torba.baseaccount import KeyChain, SingleKey -class TestKeyChain(unittest.TestCase): - - def setUp(self): - self.ledger = MainNetLedger({'db': MainNetLedger.database_class(':memory:')}) - return self.ledger.db.start() +class TestKeyChainAccount(unittest.TestCase): @defer.inlineCallbacks - def test_address_gap_algorithm(self): - account = self.ledger.account_class.generate(self.ledger, u"torba") - - # save records out of order to make sure we're really testing ORDER BY - # and not coincidentally getting records in the correct order - yield account.receiving.generate_keys(4, 7) - yield account.receiving.generate_keys(0, 3) - yield account.receiving.generate_keys(8, 11) - keys = yield account.receiving.get_addresses(None, True) - self.assertEqual( - [key['position'] for key in keys], - [11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0] - ) - - # we have 12, but default gap is 20 - new_keys = yield account.receiving.ensure_address_gap() - self.assertEqual(len(new_keys), 8) - keys = yield account.receiving.get_addresses(None, True) - self.assertEqual( - [key['position'] for key in keys], - [19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0] - ) - - # case #1: no new addresses needed - empty = yield account.receiving.ensure_address_gap() - self.assertEqual(len(empty), 0) - - # case #2: only one new addressed needed - keys = yield account.receiving.get_addresses(None, True) - yield self.ledger.db.set_address_history(keys[19]['address'], 'a:1:') - new_keys = yield account.receiving.ensure_address_gap() - self.assertEqual(len(new_keys), 1) - - # case #3: 20 addresses needed - keys = yield account.receiving.get_addresses(None, True) - yield self.ledger.db.set_address_history(keys[0]['address'], 'a:1:') - new_keys = yield account.receiving.ensure_address_gap() - self.assertEqual(len(new_keys), 20) - - @defer.inlineCallbacks - def test_create_usable_address(self): - account = self.ledger.account_class.generate(self.ledger, u"torba") - - keys = yield account.receiving.get_addresses(None, True) - self.assertEqual(len(keys), 0) - - address = yield account.receiving.get_or_create_usable_address() - self.assertIsNotNone(address) - - keys = yield account.receiving.get_addresses(None, True) - self.assertEqual(len(keys), 20) - - -class TestAccount(unittest.TestCase): - def setUp(self): self.ledger = MainNetLedger({'db': MainNetLedger.database_class(':memory:')}) - return self.ledger.db.start() + yield self.ledger.db.start() + self.account = self.ledger.account_class.generate(self.ledger, u"torba") @defer.inlineCallbacks def test_generate_account(self): - account = self.ledger.account_class.generate(self.ledger, u"torba") + account = self.account + self.assertEqual(account.ledger, self.ledger) self.assertIsNotNone(account.seed) self.assertEqual(account.public_key.ledger, self.ledger) @@ -91,13 +35,68 @@ class TestAccount(unittest.TestCase): addresses = yield account.change.get_addresses() self.assertEqual(len(addresses), 6) + addresses = yield account.get_addresses() + self.assertEqual(len(addresses), 26) + + @defer.inlineCallbacks + def test_ensure_address_gap(self): + account = self.account + + self.assertIsInstance(account.receiving, KeyChain) + + yield account.receiving.generate_keys(4, 7) + yield account.receiving.generate_keys(0, 3) + yield account.receiving.generate_keys(8, 11) + records = yield account.receiving.get_address_records() + self.assertEqual( + [r['position'] for r in records], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + ) + + # we have 12, but default gap is 20 + new_keys = yield account.receiving.ensure_address_gap() + self.assertEqual(len(new_keys), 8) + records = yield account.receiving.get_address_records() + self.assertEqual( + [r['position'] for r in records], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + ) + + # case #1: no new addresses needed + empty = yield account.receiving.ensure_address_gap() + self.assertEqual(len(empty), 0) + + # case #2: only one new addressed needed + records = yield account.receiving.get_address_records() + yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:') + new_keys = yield account.receiving.ensure_address_gap() + self.assertEqual(len(new_keys), 1) + + # case #3: 20 addresses needed + yield self.ledger.db.set_address_history(new_keys[0], 'a:1:') + new_keys = yield account.receiving.ensure_address_gap() + self.assertEqual(len(new_keys), 20) + + @defer.inlineCallbacks + def test_get_or_create_usable_address(self): + account = self.account + + keys = yield account.receiving.get_addresses() + self.assertEqual(len(keys), 0) + + address = yield account.receiving.get_or_create_usable_address() + self.assertIsNotNone(address) + + keys = yield account.receiving.get_addresses() + self.assertEqual(len(keys), 20) + @defer.inlineCallbacks def test_generate_account_from_seed(self): account = self.ledger.account_class.from_seed( self.ledger, u"carbon smart garage balance margin twelve chest sword toast envelope bottom stomach ab" u"sent", - u"torba" + u"torba", receiving_gap=3, change_gap=2 ) self.assertEqual( account.private_key.extended_key_string(), @@ -112,7 +111,6 @@ class TestAccount(unittest.TestCase): address = yield account.receiving.ensure_address_gap() self.assertEqual(address[0], b'1PmX9T3sCiDysNtWszJa44SkKcpGc2NaXP') - self.maxDiff = None private_key = yield self.ledger.get_private_key_for_address(b'1PmX9T3sCiDysNtWszJa44SkKcpGc2NaXP') self.assertEqual( private_key.extended_key_string(), @@ -131,6 +129,7 @@ class TestAccount(unittest.TestCase): @defer.inlineCallbacks def test_load_and_save_account(self): account_data = { + 'name': 'My Account', 'seed': "carbon smart garage balance margin twelve chest sword toast envelope bottom stomac" "h absent", @@ -141,10 +140,11 @@ class TestAccount(unittest.TestCase): 'public_key': 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' 'iW44g14WF52fYC5J483wqQ5ZP', - 'receiving_gap': 10, - 'receiving_maximum_use_per_address': 2, - 'change_gap': 10, - 'change_maximum_use_per_address': 2 + 'is_hd': True, + 'receiving_gap': 5, + 'receiving_maximum_uses_per_address': 2, + 'change_gap': 5, + 'change_maximum_uses_per_address': 2 } account = self.ledger.account_class.from_dict(self.ledger, account_data) @@ -152,9 +152,159 @@ class TestAccount(unittest.TestCase): yield account.ensure_address_gap() addresses = yield account.receiving.get_addresses() - self.assertEqual(len(addresses), 10) + self.assertEqual(len(addresses), 5) addresses = yield account.change.get_addresses() - self.assertEqual(len(addresses), 10) + self.assertEqual(len(addresses), 5) + + self.maxDiff = None + account_data['ledger'] = 'btc_mainnet' + self.assertDictEqual(account_data, account.to_dict()) + + +class TestSingleKeyAccount(unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + self.ledger = MainNetLedger({'db': MainNetLedger.database_class(':memory:')}) + yield self.ledger.db.start() + self.account = self.ledger.account_class.generate(self.ledger, u"torba", is_hd=False) + + @defer.inlineCallbacks + def test_generate_account(self): + account = self.account + + self.assertEqual(account.ledger, self.ledger) + self.assertIsNotNone(account.seed) + self.assertEqual(account.public_key.ledger, self.ledger) + self.assertEqual(account.private_key.public_key, account.public_key) + + addresses = yield account.receiving.get_addresses() + self.assertEqual(len(addresses), 0) + addresses = yield account.change.get_addresses() + self.assertEqual(len(addresses), 0) + + yield account.ensure_address_gap() + + addresses = yield account.receiving.get_addresses() + self.assertEqual(len(addresses), 1) + self.assertEqual(addresses[0], account.public_key.address) + addresses = yield account.change.get_addresses() + self.assertEqual(len(addresses), 1) + self.assertEqual(addresses[0], account.public_key.address) + + addresses = yield account.get_addresses() + self.assertEqual(len(addresses), 1) + self.assertEqual(addresses[0], account.public_key.address) + + @defer.inlineCallbacks + def test_ensure_address_gap(self): + account = self.account + + self.assertIsInstance(account.receiving, SingleKey) + addresses = yield account.receiving.get_addresses() + self.assertEqual(addresses, []) + + # we have 12, but default gap is 20 + new_keys = yield account.receiving.ensure_address_gap() + self.assertEqual(len(new_keys), 1) + self.assertEqual(new_keys[0], account.public_key.address) + records = yield account.receiving.get_address_records() + self.assertEqual(records, [{ + 'position': 0, 'address': account.public_key.address, 'used_times': 0 + }]) + + # case #1: no new addresses needed + empty = yield account.receiving.ensure_address_gap() + self.assertEqual(len(empty), 0) + + # case #2: after use, still no new address needed + records = yield account.receiving.get_address_records() + yield self.ledger.db.set_address_history(records[0]['address'], 'a:1:') + empty = yield account.receiving.ensure_address_gap() + self.assertEqual(len(empty), 0) + + @defer.inlineCallbacks + def test_get_or_create_usable_address(self): + account = self.account + + addresses = yield account.receiving.get_addresses() + self.assertEqual(len(addresses), 0) + + address1 = yield account.receiving.get_or_create_usable_address() + self.assertIsNotNone(address1) + + yield self.ledger.db.set_address_history(address1, 'a:1:b:2:c:3:') + records = yield account.receiving.get_address_records() + self.assertEqual(records[0]['used_times'], 3) + + address2 = yield account.receiving.get_or_create_usable_address() + self.assertEqual(address1, address2) + + keys = yield account.receiving.get_addresses() + self.assertEqual(len(keys), 1) + + @defer.inlineCallbacks + def test_generate_account_from_seed(self): + account = self.ledger.account_class.from_seed( + self.ledger, + u"carbon smart garage balance margin twelve chest sword toast envelope bottom stomach ab" + u"sent", + u"torba", + is_hd=False + ) + self.assertEqual( + account.private_key.extended_key_string(), + b'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' + b'6yz3jMbycrLrRMpeAJxR8qDg8' + ) + self.assertEqual( + account.public_key.extended_key_string(), + b'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' + b'iW44g14WF52fYC5J483wqQ5ZP' + ) + address = yield account.receiving.ensure_address_gap() + self.assertEqual(address[0], account.public_key.address) + + private_key = yield self.ledger.get_private_key_for_address(address[0]) + self.assertEqual( + private_key.extended_key_string(), + b'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' + b'6yz3jMbycrLrRMpeAJxR8qDg8' + ) + + invalid_key = yield self.ledger.get_private_key_for_address(b'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') + self.assertIsNone(invalid_key) + + self.assertEqual( + hexlify(private_key.wif()), + b'1c2423f3dc6087d9683f73a684935abc0ccd8bc26370588f56653128c6a6f0bf7c01' + ) + + @defer.inlineCallbacks + def test_load_and_save_account(self): + account_data = { + 'name': 'My Account', + 'seed': + "carbon smart garage balance margin twelve chest sword toast envelope bottom stomac" + "h absent", + 'encrypted': False, + 'private_key': + 'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' + '6yz3jMbycrLrRMpeAJxR8qDg8', + 'public_key': + 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' + 'iW44g14WF52fYC5J483wqQ5ZP', + 'is_hd': False + } + + account = self.ledger.account_class.from_dict(self.ledger, account_data) + + yield account.ensure_address_gap() + + addresses = yield account.receiving.get_addresses() + self.assertEqual(len(addresses), 1) + addresses = yield account.change.get_addresses() + self.assertEqual(len(addresses), 1) self.maxDiff = None account_data['ledger'] = 'btc_mainnet' diff --git a/torba/baseaccount.py b/torba/baseaccount.py index f8d46dd03..ec29a184a 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -7,30 +7,60 @@ from torba.bip32 import PrivateKey, PubKey, from_extended_key_string from torba.hash import double_sha256, aes_encrypt, aes_decrypt -class KeyChain: +class KeyManager(object): - def __init__(self, account, parent_key, chain_number, gap, maximum_use_per_address): - # type: ('BaseAccount', PubKey, int, int, int) -> None + __slots__ = 'account', 'public_key', 'chain_number' + + def __init__(self, account, public_key, chain_number): self.account = account - self.db = account.ledger.db - self.main_key = parent_key.child(chain_number) + self.public_key = public_key self.chain_number = chain_number - self.gap = gap - self.maximum_use_per_address = maximum_use_per_address - def get_addresses(self, limit=None, details=False): - return self.db.get_addresses(self.account, self.chain_number, limit, details) + @property + def db(self): + return self.account.ledger.db - def get_usable_addresses(self, limit=None): - return self.db.get_usable_addresses( - self.account, self.chain_number, self.maximum_use_per_address, limit + def _query_addresses(self, limit=None, max_used_times=None, order_by=None): + return self.db.get_addresses( + self.account, self.chain_number, limit, max_used_times, order_by ) + def ensure_address_gap(self): # type: () -> defer.Deferred + raise NotImplementedError + + def get_address_records(self, limit=None, only_usable=False): # type: (int, bool) -> defer.Deferred + raise NotImplementedError + + @defer.inlineCallbacks + def get_addresses(self, limit=None, only_usable=False): # type: (int, bool) -> defer.Deferred + records = yield self.get_address_records(limit=limit, only_usable=only_usable) + defer.returnValue([r['address'] for r in records]) + + @defer.inlineCallbacks + def get_or_create_usable_address(self): # type: () -> defer.Deferred + addresses = yield self.get_addresses(limit=1, only_usable=True) + if addresses: + defer.returnValue(addresses[0]) + addresses = yield self.ensure_address_gap() + defer.returnValue(addresses[0]) + + +class KeyChain(KeyManager): + """ Implements simple version of Bitcoin Hierarchical Deterministic key management. """ + + __slots__ = 'gap', 'maximum_uses_per_address' + + def __init__(self, account, root_public_key, chain_number, gap, maximum_uses_per_address): + # type: ('BaseAccount', PubKey, int, int, int) -> None + super(KeyChain, self).__init__(account, root_public_key.child(chain_number), chain_number) + self.gap = gap + self.maximum_uses_per_address = maximum_uses_per_address + @defer.inlineCallbacks def generate_keys(self, start, end): new_keys = [] for index in range(start, end+1): - new_keys.append((index, self.main_key.child(index))) + new_keys.append((index, self.public_key.child(index))) yield self.db.add_keys( self.account, self.chain_number, new_keys ) @@ -38,7 +68,7 @@ class KeyChain: @defer.inlineCallbacks def ensure_address_gap(self): - addresses = yield self.get_addresses(self.gap, True) + addresses = yield self._query_addresses(self.gap, None, "position DESC") existing_gap = 0 for address in addresses: @@ -55,13 +85,34 @@ class KeyChain: new_keys = yield self.generate_keys(start, end-1) defer.returnValue(new_keys) + def get_address_records(self, limit=None, only_usable=False): + return self._query_addresses( + limit, self.maximum_uses_per_address if only_usable else None, + "used_times ASC, position ASC" + ) + + +class SingleKey(KeyManager): + """ Single Key manager always returns the same address for all operations. """ + + __slots__ = () + + def __init__(self, account, root_public_key, chain_number): + # type: ('BaseAccount', PubKey) -> None + super(SingleKey, self).__init__(account, root_public_key, chain_number) + @defer.inlineCallbacks - def get_or_create_usable_address(self): - addresses = yield self.get_usable_addresses(1) - if addresses: - defer.returnValue(addresses[0]) - addresses = yield self.ensure_address_gap() - defer.returnValue(addresses[0]) + def ensure_address_gap(self): + exists = yield self.get_address_records() + if not exists: + yield self.db.add_keys( + self.account, self.chain_number, [(0, self.public_key)] + ) + defer.returnValue([self.public_key.address]) + defer.returnValue([]) + + def get_address_records(self, **kwargs): + return self._query_addresses() class BaseAccount(object): @@ -70,34 +121,43 @@ class BaseAccount(object): private_key_class = PrivateKey public_key_class = PubKey - def __init__(self, ledger, seed, encrypted, private_key, + def __init__(self, ledger, name, seed, encrypted, is_hd, private_key, public_key, receiving_gap=20, change_gap=6, - receiving_maximum_use_per_address=2, change_maximum_use_per_address=2): - # type: (torba.baseledger.BaseLedger, str, bool, PrivateKey, PubKey, int, int, int, int) -> None + receiving_maximum_uses_per_address=2, change_maximum_uses_per_address=2): + # type: (torba.baseledger.BaseLedger, str, str, bool, bool, PrivateKey, PubKey, int, int, int, int) -> None self.ledger = ledger + self.name = name self.seed = seed self.encrypted = encrypted self.private_key = private_key self.public_key = public_key - self.receiving, self.change = self.keychains = ( - KeyChain(self, public_key, 0, receiving_gap, receiving_maximum_use_per_address), - KeyChain(self, public_key, 1, change_gap, change_maximum_use_per_address) - ) + if is_hd: + receiving, change = self.keychains = ( + KeyChain(self, public_key, 0, receiving_gap, receiving_maximum_uses_per_address), + KeyChain(self, public_key, 1, change_gap, change_maximum_uses_per_address) + ) + else: + self.keychains = SingleKey(self, public_key, 0), + receiving = change = self.keychains[0] + self.receiving = receiving # type: KeyManager + self.change = change # type: KeyManager ledger.add_account(self) @classmethod - def generate(cls, ledger, password): # type: (torba.baseledger.BaseLedger, str) -> BaseAccount + def generate(cls, ledger, password, **kwargs): # type: (torba.baseledger.BaseLedger, str) -> BaseAccount seed = cls.mnemonic_class().make_seed() - return cls.from_seed(ledger, seed, password) + return cls.from_seed(ledger, seed, password, **kwargs) @classmethod - def from_seed(cls, ledger, seed, password): + def from_seed(cls, ledger, seed, password, is_hd=True, **kwargs): # type: (torba.baseledger.BaseLedger, str, str) -> BaseAccount private_key = cls.get_private_key_from_seed(ledger, seed, password) return cls( - ledger=ledger, seed=seed, encrypted=False, + ledger=ledger, name='Account #{}'.format(private_key.public_key.address), + seed=seed, encrypted=False, is_hd=is_hd, private_key=private_key, - public_key=private_key.public_key + public_key=private_key.public_key, + **kwargs ) @classmethod @@ -109,38 +169,60 @@ class BaseAccount(object): @classmethod def from_dict(cls, ledger, d): # type: (torba.baseledger.BaseLedger, Dict) -> BaseAccount - if not d['encrypted']: + if not d['encrypted'] and d['private_key']: private_key = from_extended_key_string(ledger, d['private_key']) public_key = private_key.public_key else: private_key = d['private_key'] public_key = from_extended_key_string(ledger, d['public_key']) - return cls( + + kwargs = dict( ledger=ledger, + name=d['name'], seed=d['seed'], encrypted=d['encrypted'], private_key=private_key, public_key=public_key, - receiving_gap=d['receiving_gap'], - change_gap=d['change_gap'], - receiving_maximum_use_per_address=d['receiving_maximum_use_per_address'], - change_maximum_use_per_address=d['change_maximum_use_per_address'] + is_hd=False ) + if d['is_hd']: + kwargs.update(dict( + receiving_gap=d['receiving_gap'], + change_gap=d['change_gap'], + receiving_maximum_uses_per_address=d['receiving_maximum_uses_per_address'], + change_maximum_uses_per_address=d['change_maximum_uses_per_address'], + is_hd=True + )) + + return cls(**kwargs) + def to_dict(self): - return { + private_key = self.private_key + if not self.encrypted and self.private_key: + private_key = self.private_key.extended_key_string().decode() + + d = { 'ledger': self.ledger.get_id(), + 'name': self.name, 'seed': self.seed, 'encrypted': self.encrypted, - 'private_key': self.private_key if self.encrypted else - self.private_key.extended_key_string().decode(), + 'private_key': private_key, 'public_key': self.public_key.extended_key_string().decode(), - 'receiving_gap': self.receiving.gap, - 'change_gap': self.change.gap, - 'receiving_maximum_use_per_address': self.receiving.maximum_use_per_address, - 'change_maximum_use_per_address': self.change.maximum_use_per_address + 'is_hd': False } + if isinstance(self.receiving, KeyChain) and isinstance(self.change, KeyChain): + d.update({ + 'receiving_gap': self.receiving.gap, + 'change_gap': self.change.gap, + 'receiving_maximum_uses_per_address': self.receiving.maximum_uses_per_address, + 'change_maximum_uses_per_address': self.change.maximum_uses_per_address, + 'is_hd': True + }) + + return d + def decrypt(self, password): assert self.encrypted, "Key is not encrypted." secret = double_sha256(password) @@ -163,18 +245,29 @@ class BaseAccount(object): addresses.extend(new_addresses) defer.returnValue(addresses) - def get_addresses(self, limit=None, details=False): - return self.ledger.db.get_addresses(self, None, limit, details) + @defer.inlineCallbacks + def get_addresses(self, limit=None, max_used_times=None): # type: (int, int) -> defer.Deferred + records = yield self.get_address_records(limit, max_used_times) + defer.returnValue([r['address'] for r in records]) - def get_unused_addresses(self): - return self.ledger.db.get_unused_addresses(self, None) + def get_address_records(self, limit=None, max_used_times=None): # type: (int, int) -> defer.Deferred + return self.ledger.db.get_addresses(self, None, limit, max_used_times) def get_private_key(self, chain, index): assert not self.encrypted, "Cannot get private key on encrypted wallet account." - return self.private_key.child(chain).child(index) + if isinstance(self.receiving, SingleKey): + return self.private_key + else: + return self.private_key.child(chain).child(index) - def get_balance(self, **constraints): - return self.ledger.db.get_balance_for_account(self, **constraints) + def get_balance(self, confirmations, **constraints): + if confirmations == 0: + return self.ledger.db.get_balance_for_account(self, **constraints) + else: + height = self.ledger.headers.height - (confirmations-1) + return self.ledger.db.get_balance_for_account( + self, height__lte=height, height__not=-1, **constraints + ) def get_unspent_outputs(self, **constraints): return self.ledger.db.get_utxos_for_account(self, **constraints) diff --git a/torba/baseledger.py b/torba/baseledger.py index 339b111ef..fba46c48b 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -238,7 +238,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): # need to update anyways. Continue to get history and create more addresses until # all missing addresses are created and history for them is fully restored. yield account.ensure_address_gap() - addresses = yield account.get_unused_addresses() + addresses = yield account.get_addresses(max_used_times=0) while addresses: yield defer.DeferredList([ self.update_history(a) for a in addresses