migrate addresses before starting the migrated wallet

This commit is contained in:
Victor Shyba 2018-10-04 20:50:19 -03:00 committed by Lex Berezhny
parent 3664c25d98
commit 38006f7d29
2 changed files with 43 additions and 5 deletions

View file

@ -341,7 +341,7 @@ class WalletComponent(Component):
log.info("Starting torba wallet")
storage = self.component_manager.get_component(DATABASE_COMPONENT)
lbryschema.BLOCKCHAIN_NAME = conf.settings['blockchain_name']
self.wallet_manager = LbryWalletManager.from_lbrynet_config(conf.settings, storage)
self.wallet_manager = yield LbryWalletManager.from_lbrynet_config(conf.settings, storage)
self.wallet_manager.old_db = storage
yield self.wallet_manager.start()

View file

@ -1,6 +1,8 @@
import os
import json
import logging
from binascii import unhexlify
from datetime import datetime
from typing import Optional
@ -100,7 +102,7 @@ class LbryWalletManager(BaseWalletManager):
@staticmethod
def migrate_lbryum_to_torba(path):
if not os.path.exists(path):
return
return None, None
with open(path, 'r') as f:
unmigrated_json = f.read()
unmigrated = json.loads(unmigrated_json)
@ -109,7 +111,15 @@ class LbryWalletManager(BaseWalletManager):
# have old structured wallets install one of the earlier releases that
# still has the below conversion code.
if 'master_public_keys' not in unmigrated:
return
return None, None
total = unmigrated.get('addr_history')
receiving_addresses, change_addresses = set(), set()
for _, unmigrated_account in unmigrated.get('accounts', {}).items():
receiving_addresses.update(map(unhexlify, unmigrated_account.get('receiving', [])))
change_addresses.update(map(unhexlify, unmigrated_account.get('change', [])))
log.info("Wallet migrator found %s receiving addresses and %s change addresses. %s in total on history.",
len(receiving_addresses), len(change_addresses), len(total))
migrated_json = json.dumps({
'version': 1,
'name': 'My Wallet',
@ -143,8 +153,10 @@ class LbryWalletManager(BaseWalletManager):
os.fsync(f.fileno())
os.rename(temp_path, path)
os.chmod(path, mode)
return receiving_addresses, change_addresses
@classmethod
@defer.inlineCallbacks
def from_lbrynet_config(cls, settings, db):
ledger_id = {
@ -167,7 +179,7 @@ class LbryWalletManager(BaseWalletManager):
wallet_file_path = os.path.join(wallets_directory, 'default_wallet')
cls.migrate_lbryum_to_torba(wallet_file_path)
receiving_addresses, change_addresses = cls.migrate_lbryum_to_torba(wallet_file_path)
manager = cls.from_config({
'ledgers': {ledger_id: ledger_config},
@ -178,7 +190,33 @@ class LbryWalletManager(BaseWalletManager):
log.info('Wallet at %s is empty, generating a default account.', wallet_file_path)
manager.default_wallet.generate_account(ledger)
manager.default_wallet.save()
return manager
if receiving_addresses or change_addresses:
ledger = manager.get_or_create_ledger(ledger_id)
if not os.path.exists(ledger.path):
os.mkdir(ledger.path)
yield ledger.db.open()
try:
yield manager._migrate_addresses(receiving_addresses, change_addresses)
finally:
yield ledger.db.close()
defer.returnValue(manager)
@defer.inlineCallbacks
def _migrate_addresses(self, receiving_addresses: set, change_addresses: set):
migrated_receiving = set((yield self.default_account.receiving.generate_keys(0, len(receiving_addresses))))
migrated_change = set((yield self.default_account.change.generate_keys(0, len(change_addresses))))
receiving_addresses = set(map(self.default_account.ledger.public_key_to_address, receiving_addresses))
change_addresses = set(map(self.default_account.ledger.public_key_to_address, change_addresses))
if not any(change_addresses.difference(migrated_change)):
log.info("Successfully migrated %s change addresses.", len(change_addresses))
else:
log.warning("Failed to migrate %s change addresses!",
len(set(change_addresses).difference(set(migrated_change))))
if not any(receiving_addresses.difference(migrated_receiving)):
log.info("Successfully migrated %s receiving addresses.", len(receiving_addresses))
else:
log.warning("Failed to migrate %s receiving addresses!",
len(set(receiving_addresses).difference(set(migrated_receiving))))
def get_best_blockhash(self):
return self.ledger.headers.hash(self.ledger.headers.height).decode()