This commit is contained in:
Jack Robison 2020-11-17 12:09:35 -05:00
parent 73c40cef60
commit 0caec7e629
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 31 additions and 13 deletions

View file

@ -1303,8 +1303,9 @@ class Daemon(metaclass=JSONRPCServerType):
'name': SingleKey.name if single_key else HierarchicalDeterministic.name
}
)
await self.ledger.network.connect_wallets(wallet.id)
await self.ledger.subscribe_account(account)
await self.ledger._subscribe_accounts([account])
elif wallet.accounts:
await self.ledger._subscribe_accounts(wallet.accounts)
wallet.save()
if not skip_on_startup:
with self.conf.update_config() as c:
@ -1331,8 +1332,9 @@ class Daemon(metaclass=JSONRPCServerType):
if not os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' was not found.")
wallet = self.wallet_manager.import_wallet(wallet_path)
if not self.ledger.network.is_connected(wallet.id):
await self.ledger.network.connect_wallets(wallet.id)
await self.ledger._subscribe_accounts(wallet.accounts)
return wallet
@requires("wallet")

View file

@ -451,13 +451,17 @@ class Ledger(metaclass=LedgerRegistry):
)
async def subscribe_accounts(self):
wallet_ids = {account.wallet.id for account in self.accounts}
return await self._subscribe_accounts(self.accounts)
async def _subscribe_accounts(self, accounts):
wallet_ids = {account.wallet.id for account in accounts}
await self.network.connect_wallets(*wallet_ids)
# if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts))
await asyncio.wait([
self.subscribe_account(a) for a in self.accounts
])
if accounts:
log.info("Subscribe to %i accounts", len(accounts))
await asyncio.wait([
self.subscribe_account(a) for a in accounts
])
async def subscribe_account(self, account: Account):
for address_manager in account.address_managers.values():

View file

@ -114,7 +114,6 @@ class ClientSession(BaseClientSession):
async def ensure_session(self):
# Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
retry_delay = default_delay = 1.0
while True:
try:
@ -201,7 +200,6 @@ class Network:
self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream
self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller,
}
@ -212,7 +210,11 @@ class Network:
return self.clients[wallet.id]
async def close_wallet_session(self, wallet):
await self.clients.pop(wallet.id).close()
session = self.clients.pop(wallet.id)
ensure_connect_task = self.session_pool.wallet_session_tasks.pop(session)
if ensure_connect_task and not ensure_connect_task.done():
ensure_connect_task.cancel()
await session.close()
@property
def config(self):
@ -248,11 +250,11 @@ class Network:
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height)
await self.connect_wallets(*{a.wallet.id for a in self.ledger.accounts})
await self.ledger.subscribe_accounts()
async def connect_wallets(self, *wallet_ids):
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
if wallet_ids:
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
async def stop(self):
if self.running:
@ -509,16 +511,20 @@ class SessionPool:
async def connect_wallet(self, wallet_id: str):
fastest = await self.wait_for_fastest_session()
connected = asyncio.Event()
if not self.network.is_connected(wallet_id):
session = self.network.connect_wallet_client(wallet_id, fastest.server)
session._on_connect_cb = connected.set
# session._on_connect_cb = self._get_session_connect_callback(session)
else:
connected.set()
session = self.network.clients[wallet_id]
task = self.wallet_session_tasks.get(session, None)
if not task or task.done():
task = asyncio.create_task(session.ensure_session())
# task.add_done_callback(lambda _: self.ensure_connections())
self.wallet_session_tasks[session] = task
await connected.wait()
def start(self, default_servers):
for server in default_servers:

View file

@ -179,6 +179,12 @@ class ServerPickingTestCase(AsyncioTestCase):
],
'connect_timeout': 3
})
ledger.accounts = []
async def subscribe_accounts():
pass
ledger.subscribe_accounts = subscribe_accounts
network = Network(ledger)
self.addCleanup(network.stop)