fix
This commit is contained in:
parent
730a67c8d6
commit
d6cf0e2699
4 changed files with 31 additions and 13 deletions
|
@ -1303,8 +1303,9 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
|
||||
}
|
||||
)
|
||||
await self.ledger.network.connect_wallets(wallet.id)
|
||||
await self.ledger.subscribe_account(account)
|
||||
await self.ledger._subscribe_accounts([account])
|
||||
elif wallet.accounts:
|
||||
await self.ledger._subscribe_accounts(wallet.accounts)
|
||||
wallet.save()
|
||||
if not skip_on_startup:
|
||||
with self.conf.update_config() as c:
|
||||
|
@ -1331,8 +1332,9 @@ class Daemon(metaclass=JSONRPCServerType):
|
|||
if not os.path.exists(wallet_path):
|
||||
raise Exception(f"Wallet at path '{wallet_path}' was not found.")
|
||||
wallet = self.wallet_manager.import_wallet(wallet_path)
|
||||
|
||||
if not self.ledger.network.is_connected(wallet.id):
|
||||
await self.ledger.network.connect_wallets(wallet.id)
|
||||
await self.ledger._subscribe_accounts(wallet.accounts)
|
||||
return wallet
|
||||
|
||||
@requires("wallet")
|
||||
|
|
|
@ -451,13 +451,17 @@ class Ledger(metaclass=LedgerRegistry):
|
|||
)
|
||||
|
||||
async def subscribe_accounts(self):
|
||||
wallet_ids = {account.wallet.id for account in self.accounts}
|
||||
return await self._subscribe_accounts(self.accounts)
|
||||
|
||||
async def _subscribe_accounts(self, accounts):
|
||||
wallet_ids = {account.wallet.id for account in accounts}
|
||||
await self.network.connect_wallets(*wallet_ids)
|
||||
# if self.network.is_connected and self.accounts:
|
||||
log.info("Subscribe to %i accounts", len(self.accounts))
|
||||
await asyncio.wait([
|
||||
self.subscribe_account(a) for a in self.accounts
|
||||
])
|
||||
if accounts:
|
||||
log.info("Subscribe to %i accounts", len(accounts))
|
||||
await asyncio.wait([
|
||||
self.subscribe_account(a) for a in accounts
|
||||
])
|
||||
|
||||
async def subscribe_account(self, account: Account):
|
||||
for address_manager in account.address_managers.values():
|
||||
|
|
|
@ -114,7 +114,6 @@ class ClientSession(BaseClientSession):
|
|||
|
||||
async def ensure_session(self):
|
||||
# Handles reconnecting and maintaining a session alive
|
||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
||||
retry_delay = default_delay = 1.0
|
||||
while True:
|
||||
try:
|
||||
|
@ -201,7 +200,6 @@ class Network:
|
|||
self._on_header_controller = StreamController(merge_repeated_events=True)
|
||||
self.on_header = self._on_header_controller.stream
|
||||
|
||||
|
||||
self.subscription_controllers = {
|
||||
'blockchain.headers.subscribe': self._on_header_controller,
|
||||
}
|
||||
|
@ -212,7 +210,11 @@ class Network:
|
|||
return self.clients[wallet.id]
|
||||
|
||||
async def close_wallet_session(self, wallet):
|
||||
await self.clients.pop(wallet.id).close()
|
||||
session = self.clients.pop(wallet.id)
|
||||
ensure_connect_task = self.session_pool.wallet_session_tasks.pop(session)
|
||||
if ensure_connect_task and not ensure_connect_task.done():
|
||||
ensure_connect_task.cancel()
|
||||
await session.close()
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
|
@ -248,11 +250,11 @@ class Network:
|
|||
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
|
||||
self.session_pool.start(self.config['default_servers'])
|
||||
self.on_header.listen(self._update_remote_height)
|
||||
await self.connect_wallets(*{a.wallet.id for a in self.ledger.accounts})
|
||||
await self.ledger.subscribe_accounts()
|
||||
|
||||
async def connect_wallets(self, *wallet_ids):
|
||||
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
|
||||
if wallet_ids:
|
||||
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
|
||||
|
||||
async def stop(self):
|
||||
if self.running:
|
||||
|
@ -509,16 +511,20 @@ class SessionPool:
|
|||
|
||||
async def connect_wallet(self, wallet_id: str):
|
||||
fastest = await self.wait_for_fastest_session()
|
||||
connected = asyncio.Event()
|
||||
if not self.network.is_connected(wallet_id):
|
||||
session = self.network.connect_wallet_client(wallet_id, fastest.server)
|
||||
session._on_connect_cb = connected.set
|
||||
# session._on_connect_cb = self._get_session_connect_callback(session)
|
||||
else:
|
||||
connected.set()
|
||||
session = self.network.clients[wallet_id]
|
||||
task = self.wallet_session_tasks.get(session, None)
|
||||
if not task or task.done():
|
||||
task = asyncio.create_task(session.ensure_session())
|
||||
# task.add_done_callback(lambda _: self.ensure_connections())
|
||||
self.wallet_session_tasks[session] = task
|
||||
await connected.wait()
|
||||
|
||||
def start(self, default_servers):
|
||||
for server in default_servers:
|
||||
|
|
|
@ -179,6 +179,12 @@ class ServerPickingTestCase(AsyncioTestCase):
|
|||
],
|
||||
'connect_timeout': 3
|
||||
})
|
||||
ledger.accounts = []
|
||||
|
||||
async def subscribe_accounts():
|
||||
pass
|
||||
|
||||
ledger.subscribe_accounts = subscribe_accounts
|
||||
|
||||
network = Network(ledger)
|
||||
self.addCleanup(network.stop)
|
||||
|
|
Loading…
Add table
Reference in a new issue