migrate key addresses on changed accounts after sync apply

This commit is contained in:
Lex Berezhny 2022-07-25 10:44:28 -04:00
parent 352e45b6b7
commit 656e299100
3 changed files with 8 additions and 6 deletions

View file

@ -9,6 +9,7 @@ import inspect
import typing import typing
import random import random
import tracemalloc import tracemalloc
from itertools import chain
from decimal import Decimal from decimal import Decimal
from urllib.parse import urlencode, quote from urllib.parse import urlencode, quote
from typing import Callable, Optional, List from typing import Callable, Optional, List
@ -1986,8 +1987,8 @@ class Daemon(metaclass=JSONRPCServerType):
wallet = self.wallet_manager.get_wallet_or_default(wallet_id) wallet = self.wallet_manager.get_wallet_or_default(wallet_id)
wallet_changed = False wallet_changed = False
if data is not None: if data is not None:
added_accounts = wallet.merge(self.wallet_manager, password, data) added_accounts, merged_accounts = wallet.merge(self.wallet_manager, password, data)
for new_account in added_accounts: for new_account in chain(added_accounts, merged_accounts):
await new_account.maybe_migrate_certificates() await new_account.maybe_migrate_certificates()
if added_accounts and self.ledger.network.is_connected: if added_accounts and self.ledger.network.is_connected:
if blocking: if blocking:

View file

@ -175,9 +175,9 @@ class Wallet:
return json.loads(decompressed) return json.loads(decompressed)
def merge(self, manager: 'WalletManager', def merge(self, manager: 'WalletManager',
password: str, data: str) -> List['Account']: password: str, data: str) -> (List['Account'], List['Account']):
assert not self.is_locked, "Cannot sync apply on a locked wallet." assert not self.is_locked, "Cannot sync apply on a locked wallet."
added_accounts = [] added_accounts, merged_accounts = [], []
decrypted_data = self.unpack(password, data) decrypted_data = self.unpack(password, data)
self.preferences.merge(decrypted_data.get('preferences', {})) self.preferences.merge(decrypted_data.get('preferences', {}))
for account_dict in decrypted_data['accounts']: for account_dict in decrypted_data['accounts']:
@ -191,10 +191,11 @@ class Wallet:
break break
if local_match is not None: if local_match is not None:
local_match.merge(account_dict) local_match.merge(account_dict)
merged_accounts.append(local_match)
else: else:
new_account = Account.from_dict(ledger, self, account_dict) new_account = Account.from_dict(ledger, self, account_dict)
added_accounts.append(new_account) added_accounts.append(new_account)
return added_accounts return added_accounts, merged_accounts
@property @property
def is_locked(self) -> bool: def is_locked(self) -> bool:

View file

@ -209,7 +209,7 @@ class TestWalletCreation(AsyncioTestCase):
self.assertEqual(len(wallet1.accounts), 1) self.assertEqual(len(wallet1.accounts), 1)
self.assertEqual(wallet1.preferences, {'one': 1, 'conflict': 1}) self.assertEqual(wallet1.preferences, {'one': 1, 'conflict': 1})
added = wallet1.merge(self.manager, 'password', wallet2.pack('password')) added, _ = wallet1.merge(self.manager, 'password', wallet2.pack('password'))
self.assertEqual(added[0].id, wallet2.default_account.id) self.assertEqual(added[0].id, wallet2.default_account.id)
self.assertEqual(len(wallet1.accounts), 2) self.assertEqual(len(wallet1.accounts), 2)
self.assertEqual(wallet1.accounts[1].id, wallet2.default_account.id) self.assertEqual(wallet1.accounts[1].id, wallet2.default_account.id)