diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index 1ad87af4a..c2dd5ba08 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -1303,8 +1303,9 @@ class Daemon(metaclass=JSONRPCServerType): 'name': SingleKey.name if single_key else HierarchicalDeterministic.name } ) - await self.ledger.network.connect_wallets(wallet.id) - await self.ledger.subscribe_account(account) + await self.ledger._subscribe_accounts([account]) + elif wallet.accounts: + await self.ledger._subscribe_accounts(wallet.accounts) wallet.save() if not skip_on_startup: with self.conf.update_config() as c: @@ -1331,8 +1332,9 @@ class Daemon(metaclass=JSONRPCServerType): if not os.path.exists(wallet_path): raise Exception(f"Wallet at path '{wallet_path}' was not found.") wallet = self.wallet_manager.import_wallet(wallet_path) + if not self.ledger.network.is_connected(wallet.id): - await self.ledger.network.connect_wallets(wallet.id) + await self.ledger._subscribe_accounts(wallet.accounts) return wallet @requires("wallet") diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index d2d1ff645..d844f1322 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -451,13 +451,17 @@ class Ledger(metaclass=LedgerRegistry): ) async def subscribe_accounts(self): - wallet_ids = {account.wallet.id for account in self.accounts} + return await self._subscribe_accounts(self.accounts) + + async def _subscribe_accounts(self, accounts): + wallet_ids = {account.wallet.id for account in accounts} await self.network.connect_wallets(*wallet_ids) # if self.network.is_connected and self.accounts: - log.info("Subscribe to %i accounts", len(self.accounts)) - await asyncio.wait([ - self.subscribe_account(a) for a in self.accounts - ]) + if accounts: + log.info("Subscribe to %i accounts", len(accounts)) + await asyncio.wait([ + self.subscribe_account(a) for a in accounts + ]) async def subscribe_account(self, account: Account): for address_manager in account.address_managers.values(): diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py index 28d7503fc..5483818be 100644 --- a/lbry/wallet/network.py +++ b/lbry/wallet/network.py @@ -114,7 +114,6 @@ class ClientSession(BaseClientSession): async def ensure_session(self): # Handles reconnecting and maintaining a session alive - # TODO: change to 'ping' on newer protocol (above 1.2) retry_delay = default_delay = 1.0 while True: try: @@ -201,7 +200,6 @@ class Network: self._on_header_controller = StreamController(merge_repeated_events=True) self.on_header = self._on_header_controller.stream - self.subscription_controllers = { 'blockchain.headers.subscribe': self._on_header_controller, } @@ -212,7 +210,11 @@ class Network: return self.clients[wallet.id] async def close_wallet_session(self, wallet): - await self.clients.pop(wallet.id).close() + session = self.clients.pop(wallet.id) + ensure_connect_task = self.session_pool.wallet_session_tasks.pop(session) + if ensure_connect_task and not ensure_connect_task.done(): + ensure_connect_task.cancel() + await session.close() @property def config(self): @@ -248,11 +250,11 @@ class Network: self._switch_task.add_done_callback(lambda _: log.info("Wallet client switching task stopped.")) self.session_pool.start(self.config['default_servers']) self.on_header.listen(self._update_remote_height) - await self.connect_wallets(*{a.wallet.id for a in self.ledger.accounts}) await self.ledger.subscribe_accounts() async def connect_wallets(self, *wallet_ids): - await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids]) + if wallet_ids: + await asyncio.wait([self.session_pool.connect_wallet(wallet_id) for wallet_id in wallet_ids]) async def stop(self): if self.running: @@ -509,16 +511,20 @@ class SessionPool: async def connect_wallet(self, wallet_id: str): fastest = await self.wait_for_fastest_session() + connected = asyncio.Event() if not self.network.is_connected(wallet_id): session = self.network.connect_wallet_client(wallet_id, fastest.server) + session._on_connect_cb = connected.set # session._on_connect_cb = self._get_session_connect_callback(session) else: + connected.set() session = self.network.clients[wallet_id] task = self.wallet_session_tasks.get(session, None) if not task or task.done(): task = asyncio.create_task(session.ensure_session()) # task.add_done_callback(lambda _: self.ensure_connections()) self.wallet_session_tasks[session] = task + await connected.wait() def start(self, default_servers): for server in default_servers: diff --git a/tests/integration/blockchain/test_network.py b/tests/integration/blockchain/test_network.py index f471eaab0..d4f06713f 100644 --- a/tests/integration/blockchain/test_network.py +++ b/tests/integration/blockchain/test_network.py @@ -179,6 +179,12 @@ class ServerPickingTestCase(AsyncioTestCase): ], 'connect_timeout': 3 }) + ledger.accounts = [] + + async def subscribe_accounts(): + pass + + ledger.subscribe_accounts = subscribe_accounts network = Network(ledger) self.addCleanup(network.stop)