fix
This commit is contained in:
parent
730a67c8d6
commit
d6cf0e2699
4 changed files with 31 additions and 13 deletions
|
@ -1303,8 +1303,9 @@ class Daemon(metaclass=JSONRPCServerType):
|
||||||
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
|
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
await self.ledger.network.connect_wallets(wallet.id)
|
await self.ledger._subscribe_accounts([account])
|
||||||
await self.ledger.subscribe_account(account)
|
elif wallet.accounts:
|
||||||
|
await self.ledger._subscribe_accounts(wallet.accounts)
|
||||||
wallet.save()
|
wallet.save()
|
||||||
if not skip_on_startup:
|
if not skip_on_startup:
|
||||||
with self.conf.update_config() as c:
|
with self.conf.update_config() as c:
|
||||||
|
@ -1331,8 +1332,9 @@ class Daemon(metaclass=JSONRPCServerType):
|
||||||
if not os.path.exists(wallet_path):
|
if not os.path.exists(wallet_path):
|
||||||
raise Exception(f"Wallet at path '{wallet_path}' was not found.")
|
raise Exception(f"Wallet at path '{wallet_path}' was not found.")
|
||||||
wallet = self.wallet_manager.import_wallet(wallet_path)
|
wallet = self.wallet_manager.import_wallet(wallet_path)
|
||||||
|
|
||||||
if not self.ledger.network.is_connected(wallet.id):
|
if not self.ledger.network.is_connected(wallet.id):
|
||||||
await self.ledger.network.connect_wallets(wallet.id)
|
await self.ledger._subscribe_accounts(wallet.accounts)
|
||||||
return wallet
|
return wallet
|
||||||
|
|
||||||
@requires("wallet")
|
@requires("wallet")
|
||||||
|
|
|
@ -451,13 +451,17 @@ class Ledger(metaclass=LedgerRegistry):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def subscribe_accounts(self):
|
async def subscribe_accounts(self):
|
||||||
wallet_ids = {account.wallet.id for account in self.accounts}
|
return await self._subscribe_accounts(self.accounts)
|
||||||
|
|
||||||
|
async def _subscribe_accounts(self, accounts):
|
||||||
|
wallet_ids = {account.wallet.id for account in accounts}
|
||||||
await self.network.connect_wallets(*wallet_ids)
|
await self.network.connect_wallets(*wallet_ids)
|
||||||
# if self.network.is_connected and self.accounts:
|
# if self.network.is_connected and self.accounts:
|
||||||
log.info("Subscribe to %i accounts", len(self.accounts))
|
if accounts:
|
||||||
await asyncio.wait([
|
log.info("Subscribe to %i accounts", len(accounts))
|
||||||
self.subscribe_account(a) for a in self.accounts
|
await asyncio.wait([
|
||||||
])
|
self.subscribe_account(a) for a in accounts
|
||||||
|
])
|
||||||
|
|
||||||
async def subscribe_account(self, account: Account):
|
async def subscribe_account(self, account: Account):
|
||||||
for address_manager in account.address_managers.values():
|
for address_manager in account.address_managers.values():
|
||||||
|
|
|
@ -114,7 +114,6 @@ class ClientSession(BaseClientSession):
|
||||||
|
|
||||||
async def ensure_session(self):
|
async def ensure_session(self):
|
||||||
# Handles reconnecting and maintaining a session alive
|
# Handles reconnecting and maintaining a session alive
|
||||||
# TODO: change to 'ping' on newer protocol (above 1.2)
|
|
||||||
retry_delay = default_delay = 1.0
|
retry_delay = default_delay = 1.0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -201,7 +200,6 @@ class Network:
|
||||||
self._on_header_controller = StreamController(merge_repeated_events=True)
|
self._on_header_controller = StreamController(merge_repeated_events=True)
|
||||||
self.on_header = self._on_header_controller.stream
|
self.on_header = self._on_header_controller.stream
|
||||||
|
|
||||||
|
|
||||||
self.subscription_controllers = {
|
self.subscription_controllers = {
|
||||||
'blockchain.headers.subscribe': self._on_header_controller,
|
'blockchain.headers.subscribe': self._on_header_controller,
|
||||||
}
|
}
|
||||||
|
@ -212,7 +210,11 @@ class Network:
|
||||||
return self.clients[wallet.id]
|
return self.clients[wallet.id]
|
||||||
|
|
||||||
async def close_wallet_session(self, wallet):
|
async def close_wallet_session(self, wallet):
|
||||||
await self.clients.pop(wallet.id).close()
|
session = self.clients.pop(wallet.id)
|
||||||
|
ensure_connect_task = self.session_pool.wallet_session_tasks.pop(session)
|
||||||
|
if ensure_connect_task and not ensure_connect_task.done():
|
||||||
|
ensure_connect_task.cancel()
|
||||||
|
await session.close()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config(self):
|
def config(self):
|
||||||
|
@ -248,11 +250,11 @@ class Network:
|
||||||
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
|
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
|
||||||
self.session_pool.start(self.config['default_servers'])
|
self.session_pool.start(self.config['default_servers'])
|
||||||
self.on_header.listen(self._update_remote_height)
|
self.on_header.listen(self._update_remote_height)
|
||||||
await self.connect_wallets(*{a.wallet.id for a in self.ledger.accounts})
|
|
||||||
await self.ledger.subscribe_accounts()
|
await self.ledger.subscribe_accounts()
|
||||||
|
|
||||||
async def connect_wallets(self, *wallet_ids):
|
async def connect_wallets(self, *wallet_ids):
|
||||||
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
|
if wallet_ids:
|
||||||
|
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
if self.running:
|
if self.running:
|
||||||
|
@ -509,16 +511,20 @@ class SessionPool:
|
||||||
|
|
||||||
async def connect_wallet(self, wallet_id: str):
|
async def connect_wallet(self, wallet_id: str):
|
||||||
fastest = await self.wait_for_fastest_session()
|
fastest = await self.wait_for_fastest_session()
|
||||||
|
connected = asyncio.Event()
|
||||||
if not self.network.is_connected(wallet_id):
|
if not self.network.is_connected(wallet_id):
|
||||||
session = self.network.connect_wallet_client(wallet_id, fastest.server)
|
session = self.network.connect_wallet_client(wallet_id, fastest.server)
|
||||||
|
session._on_connect_cb = connected.set
|
||||||
# session._on_connect_cb = self._get_session_connect_callback(session)
|
# session._on_connect_cb = self._get_session_connect_callback(session)
|
||||||
else:
|
else:
|
||||||
|
connected.set()
|
||||||
session = self.network.clients[wallet_id]
|
session = self.network.clients[wallet_id]
|
||||||
task = self.wallet_session_tasks.get(session, None)
|
task = self.wallet_session_tasks.get(session, None)
|
||||||
if not task or task.done():
|
if not task or task.done():
|
||||||
task = asyncio.create_task(session.ensure_session())
|
task = asyncio.create_task(session.ensure_session())
|
||||||
# task.add_done_callback(lambda _: self.ensure_connections())
|
# task.add_done_callback(lambda _: self.ensure_connections())
|
||||||
self.wallet_session_tasks[session] = task
|
self.wallet_session_tasks[session] = task
|
||||||
|
await connected.wait()
|
||||||
|
|
||||||
def start(self, default_servers):
|
def start(self, default_servers):
|
||||||
for server in default_servers:
|
for server in default_servers:
|
||||||
|
|
|
@ -179,6 +179,12 @@ class ServerPickingTestCase(AsyncioTestCase):
|
||||||
],
|
],
|
||||||
'connect_timeout': 3
|
'connect_timeout': 3
|
||||||
})
|
})
|
||||||
|
ledger.accounts = []
|
||||||
|
|
||||||
|
async def subscribe_accounts():
|
||||||
|
pass
|
||||||
|
|
||||||
|
ledger.subscribe_accounts = subscribe_accounts
|
||||||
|
|
||||||
network = Network(ledger)
|
network = Network(ledger)
|
||||||
self.addCleanup(network.stop)
|
self.addCleanup(network.stop)
|
||||||
|
|
Loading…
Add table
Reference in a new issue