This commit is contained in:
Jack Robison 2020-11-17 12:09:35 -05:00
parent 730a67c8d6
commit d6cf0e2699
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 31 additions and 13 deletions

View file

@ -1303,8 +1303,9 @@ class Daemon(metaclass=JSONRPCServerType):
'name': SingleKey.name if single_key else HierarchicalDeterministic.name 'name': SingleKey.name if single_key else HierarchicalDeterministic.name
} }
) )
await self.ledger.network.connect_wallets(wallet.id) await self.ledger._subscribe_accounts([account])
await self.ledger.subscribe_account(account) elif wallet.accounts:
await self.ledger._subscribe_accounts(wallet.accounts)
wallet.save() wallet.save()
if not skip_on_startup: if not skip_on_startup:
with self.conf.update_config() as c: with self.conf.update_config() as c:
@ -1331,8 +1332,9 @@ class Daemon(metaclass=JSONRPCServerType):
if not os.path.exists(wallet_path): if not os.path.exists(wallet_path):
raise Exception(f"Wallet at path '{wallet_path}' was not found.") raise Exception(f"Wallet at path '{wallet_path}' was not found.")
wallet = self.wallet_manager.import_wallet(wallet_path) wallet = self.wallet_manager.import_wallet(wallet_path)
if not self.ledger.network.is_connected(wallet.id): if not self.ledger.network.is_connected(wallet.id):
await self.ledger.network.connect_wallets(wallet.id) await self.ledger._subscribe_accounts(wallet.accounts)
return wallet return wallet
@requires("wallet") @requires("wallet")

View file

@ -451,13 +451,17 @@ class Ledger(metaclass=LedgerRegistry):
) )
async def subscribe_accounts(self): async def subscribe_accounts(self):
wallet_ids = {account.wallet.id for account in self.accounts} return await self._subscribe_accounts(self.accounts)
async def _subscribe_accounts(self, accounts):
wallet_ids = {account.wallet.id for account in accounts}
await self.network.connect_wallets(*wallet_ids) await self.network.connect_wallets(*wallet_ids)
# if self.network.is_connected and self.accounts: # if self.network.is_connected and self.accounts:
log.info("Subscribe to %i accounts", len(self.accounts)) if accounts:
await asyncio.wait([ log.info("Subscribe to %i accounts", len(accounts))
self.subscribe_account(a) for a in self.accounts await asyncio.wait([
]) self.subscribe_account(a) for a in accounts
])
async def subscribe_account(self, account: Account): async def subscribe_account(self, account: Account):
for address_manager in account.address_managers.values(): for address_manager in account.address_managers.values():

View file

@ -114,7 +114,6 @@ class ClientSession(BaseClientSession):
async def ensure_session(self): async def ensure_session(self):
# Handles reconnecting and maintaining a session alive # Handles reconnecting and maintaining a session alive
# TODO: change to 'ping' on newer protocol (above 1.2)
retry_delay = default_delay = 1.0 retry_delay = default_delay = 1.0
while True: while True:
try: try:
@ -201,7 +200,6 @@ class Network:
self._on_header_controller = StreamController(merge_repeated_events=True) self._on_header_controller = StreamController(merge_repeated_events=True)
self.on_header = self._on_header_controller.stream self.on_header = self._on_header_controller.stream
self.subscription_controllers = { self.subscription_controllers = {
'blockchain.headers.subscribe': self._on_header_controller, 'blockchain.headers.subscribe': self._on_header_controller,
} }
@ -212,7 +210,11 @@ class Network:
return self.clients[wallet.id] return self.clients[wallet.id]
async def close_wallet_session(self, wallet): async def close_wallet_session(self, wallet):
await self.clients.pop(wallet.id).close() session = self.clients.pop(wallet.id)
ensure_connect_task = self.session_pool.wallet_session_tasks.pop(session)
if ensure_connect_task and not ensure_connect_task.done():
ensure_connect_task.cancel()
await session.close()
@property @property
def config(self): def config(self):
@ -248,11 +250,11 @@ class Network:
self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped."))
self.session_pool.start(self.config['default_servers']) self.session_pool.start(self.config['default_servers'])
self.on_header.listen(self._update_remote_height) self.on_header.listen(self._update_remote_height)
await self.connect_wallets(*{a.wallet.id for a in self.ledger.accounts})
await self.ledger.subscribe_accounts() await self.ledger.subscribe_accounts()
async def connect_wallets(self, *wallet_ids): async def connect_wallets(self, *wallet_ids):
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids]) if wallet_ids:
await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids])
async def stop(self): async def stop(self):
if self.running: if self.running:
@ -509,16 +511,20 @@ class SessionPool:
async def connect_wallet(self, wallet_id: str): async def connect_wallet(self, wallet_id: str):
fastest = await self.wait_for_fastest_session() fastest = await self.wait_for_fastest_session()
connected = asyncio.Event()
if not self.network.is_connected(wallet_id): if not self.network.is_connected(wallet_id):
session = self.network.connect_wallet_client(wallet_id, fastest.server) session = self.network.connect_wallet_client(wallet_id, fastest.server)
session._on_connect_cb = connected.set
# session._on_connect_cb = self._get_session_connect_callback(session) # session._on_connect_cb = self._get_session_connect_callback(session)
else: else:
connected.set()
session = self.network.clients[wallet_id] session = self.network.clients[wallet_id]
task = self.wallet_session_tasks.get(session, None) task = self.wallet_session_tasks.get(session, None)
if not task or task.done(): if not task or task.done():
task = asyncio.create_task(session.ensure_session()) task = asyncio.create_task(session.ensure_session())
# task.add_done_callback(lambda _: self.ensure_connections()) # task.add_done_callback(lambda _: self.ensure_connections())
self.wallet_session_tasks[session] = task self.wallet_session_tasks[session] = task
await connected.wait()
def start(self, default_servers): def start(self, default_servers):
for server in default_servers: for server in default_servers:

View file

@ -179,6 +179,12 @@ class ServerPickingTestCase(AsyncioTestCase):
], ],
'connect_timeout': 3 'connect_timeout': 3
}) })
ledger.accounts = []
async def subscribe_accounts():
pass
ledger.subscribe_accounts = subscribe_accounts
network = Network(ledger) network = Network(ledger)
self.addCleanup(network.stop) self.addCleanup(network.stop)