diff --git a/lbry/wallet/__init__.py b/lbry/wallet/__init__.py index c8318408f..5f2fffa21 100644 --- a/lbry/wallet/__init__.py +++ b/lbry/wallet/__init__.py @@ -10,7 +10,7 @@ from .wallet import Wallet, WalletStorage, TimestampedPreferences, ENCRYPT_ON_DI from .manager import WalletManager from .network import Network from .ledger import Ledger, RegTestLedger, TestNetLedger, BlockHeightEvent -from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic +from .account import Account, AddressManager, SingleKey, HierarchicalDeterministic, DeterministicChannelKeyManager from .transaction import Transaction, Output, Input from .script import OutputScript, InputScript from .database import SQLiteMixin, Database diff --git a/lbry/wallet/account.py b/lbry/wallet/account.py index e577bbbcb..98abbddef 100644 --- a/lbry/wallet/account.py +++ b/lbry/wallet/account.py @@ -40,21 +40,30 @@ class DeterministicChannelKeyManager: self.account = account self.public_key = account.public_key.child(2) self.private_key = account.private_key.child(2) if account.private_key else None + self.last_known = 0 + self.cache = {} + + def maybe_generate_deterministic_key_for_channel(self, txo): + next_key = self.private_key.child(self.last_known) + if txo.claim.channel.public_key_bytes == next_key.public_key.pubkey_bytes: + self.cache[next_key.address()] = next_key + self.last_known += 1 + + async def ensure_cache_primed(self): + await self.generate_next_key() async def generate_next_key(self): db = self.account.ledger.db - i = 0 while True: - next_key = self.private_key.child(i) - if not await db.is_channel_key_used(self.account.wallet, next_key.address()): + next_key = self.private_key.child(self.last_known) + key_address = next_key.address() + self.cache[key_address] = next_key + if not await db.is_channel_key_used(self.account.wallet, key_address): return next_key - i += 1 + self.last_known += 1 def get_private_key_from_pubkey_hash(self, pubkey_hash): - for i in range(100): - next_key = self.private_key.child(i) - if next_key.address() == pubkey_hash: - return next_key + return self.cache.get(pubkey_hash) class AddressManager: diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index 9583c22a7..87bace8b5 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -550,6 +550,7 @@ class Ledger(metaclass=LedgerRegistry): ) remote_history_txids = {txid for txid, _ in remote_history} async for tx in self.request_synced_transactions(to_request, remote_history_txids, address): + self.maybe_has_channel_key(tx) pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:" if len(pending_synced_history) % 100 == 0: log.info("Syncing address %s: %d/%d", address, len(pending_synced_history), len(to_request)) @@ -617,6 +618,12 @@ class Ledger(metaclass=LedgerRegistry): tx.is_verified = merkle_root == header['merkle_root'] return tx + def maybe_has_channel_key(self, tx): + for txo in tx._outputs: + if txo.can_decode_claim and txo.claim.is_channel: + for account in self.accounts: + account.deterministic_channel_keys.maybe_generate_deterministic_key_for_channel(txo) + async def request_transactions(self, to_request: Tuple[Tuple[str, int], ...], cached=False): batches = [[]] remote_heights = {} diff --git a/tests/integration/blockchain/test_account_commands.py b/tests/integration/blockchain/test_account_commands.py index 7cf5e81ce..5e7abec04 100644 --- a/tests/integration/blockchain/test_account_commands.py +++ b/tests/integration/blockchain/test_account_commands.py @@ -177,10 +177,36 @@ class AccountManagement(CommandTestCase): async def test_deterministic_channel_keys(self): seed = self.account.seed - channel1 = await self.channel_create('@foo1') - channel2 = await self.channel_create('@foo2') + + # create two channels and make sure they have different keys + channel1a = await self.channel_create('@foo1') + channel2a = await self.channel_create('@foo2') self.assertNotEqual( - channel1['outputs'][0]['value']['public_key'], - channel2['outputs'][0]['value']['public_key'], + channel1a['outputs'][0]['value']['public_key'], + channel2a['outputs'][0]['value']['public_key'], ) + + # start another daemon from the same seed self.daemon2 = await self.add_daemon(seed=seed) + channel2b, channel1b = (await self.daemon2.jsonrpc_channel_list())['items'] + + # both daemons end up with the same channel signing keys automagically + self.assertTrue(channel1b.has_private_key) + self.assertEqual( + channel1a['outputs'][0]['value']['public_key_id'], + channel1b.private_key.public_key.address + ) + self.assertTrue(channel2b.has_private_key) + self.assertEqual( + channel2a['outputs'][0]['value']['public_key_id'], + channel2b.private_key.public_key.address + ) + + # create third channel while both daemons running, second daemon should pick it up + channel3a = await self.channel_create('@foo3') + channel3b, = (await self.daemon2.jsonrpc_channel_list(name='@foo3'))['items'] + self.assertTrue(channel3b.has_private_key) + self.assertEqual( + channel3a['outputs'][0]['value']['public_key_id'], + channel3b.private_key.public_key.address + ) diff --git a/tests/unit/wallet/test_account.py b/tests/unit/wallet/test_account.py index 894762c68..2a4c30f7b 100644 --- a/tests/unit/wallet/test_account.py +++ b/tests/unit/wallet/test_account.py @@ -1,7 +1,11 @@ import asyncio from binascii import hexlify from lbry.testcase import AsyncioTestCase -from lbry.wallet import Wallet, Ledger, Database, Headers, Account, SingleKey, HierarchicalDeterministic +from lbry.wallet import ( + Wallet, Ledger, Database, Headers, + Account, SingleKey, HierarchicalDeterministic, + DeterministicChannelKeyManager +) class TestAccount(AsyncioTestCase):