diff --git a/torba/client/baseaccount.py b/torba/client/baseaccount.py index 6b0c60cf6..391973650 100644 --- a/torba/client/baseaccount.py +++ b/torba/client/baseaccount.py @@ -58,6 +58,9 @@ class AddressManager: def get_private_key(self, index: int) -> PrivateKey: raise NotImplementedError + def get_public_key(self, index: int) -> PubKey: + raise NotImplementedError + async def get_max_gap(self): raise NotImplementedError @@ -108,6 +111,9 @@ class HierarchicalDeterministic(AddressManager): def get_private_key(self, index: int) -> PrivateKey: return self.account.private_key.child(self.chain_number).child(index) + def get_public_key(self, index: int) -> PubKey: + return self.account.public_key.child(self.chain_number).child(index) + async def get_max_gap(self) -> int: addresses = await self._query_addresses(order_by="position ASC") max_gap = 0 @@ -174,6 +180,9 @@ class SingleKey(AddressManager): def get_private_key(self, index: int) -> PrivateKey: return self.account.private_key + def get_public_key(self, index: int) -> PubKey: + return self.account.public_key + async def get_max_gap(self) -> int: return 0 @@ -390,6 +399,9 @@ class BaseAccount: assert not self.encrypted, "Cannot get private key on encrypted wallet account." return self.address_managers[chain].get_private_key(index) + def get_public_key(self, chain: int, index: int) -> PubKey: + return self.address_managers[chain].get_public_key(index) + def get_balance(self, confirmations: int = 0, **constraints): if confirmations > 0: height = self.ledger.headers.height - (confirmations-1) diff --git a/torba/client/baseledger.py b/torba/client/baseledger.py index 740f91377..2fb831588 100644 --- a/torba/client/baseledger.py +++ b/torba/client/baseledger.py @@ -173,12 +173,29 @@ class BaseLedger(metaclass=LedgerRegistry): def add_account(self, account: baseaccount.BaseAccount): self.accounts.append(account) - async def get_private_key_for_address(self, address): + async def _get_account_and_address_info_for_address(self, address): match = await self.db.get_address(address=address) if match: for account in self.accounts: if match['account'] == account.public_key.address: - return account.get_private_key(match['chain'], match['position']) + return account, match + + async def get_private_key_for_address(self, address): + match = await self._get_account_and_address_info_for_address(address) + if match: + account, address_info = match + return account.get_private_key(address_info['chain'], address_info['position']) + + async def get_public_key_for_address(self, address): + match = await self._get_account_and_address_info_for_address(address) + if match: + account, address_info = match + return account.get_public_key(address_info['chain'], address_info['position']) + + async def get_account_for_address(self, address): + match = await self._get_account_and_address_info_for_address(address) + if match: + return match[0] async def get_effective_amount_estimators(self, funding_accounts: Iterable[baseaccount.BaseAccount]): estimators = []