diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index 1693bb178..38ff07b5a 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -40,11 +40,12 @@ class ClientSession(BaseClientSession): return None return self.transport.get_extra_info('peername') - async def send_timed_server_version_request(self, args=()): + async def send_timed_server_version_request(self, args=(), timeout=None): + timeout = timeout or self.timeout log.debug("send version request to %s:%i", *self.server) start = perf_counter() result = await asyncio.wait_for( - super().send_request('server.version', args), timeout=self.timeout + super().send_request('server.version', args), timeout=timeout ) current_response_time = perf_counter() - start response_sum = (self.response_time or 0) * self._response_samples + current_response_time @@ -52,18 +53,28 @@ class ClientSession(BaseClientSession): self._response_samples += 1 return result - async def send_request(self, method, args=()): + async def send_request(self, method, args=(), timeout=None): + timeout = timeout or self.timeout self.pending_amount += 1 try: if method == 'server.version': - return await self.send_timed_server_version_request(args) + return await self.send_timed_server_version_request(args, timeout) return await asyncio.wait_for( - super().send_request(method, args), timeout=self.timeout + super().send_request(method, args), timeout=timeout ) except RPCError as e: log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s", *self.server, *e.args) raise e + except ConnectionError: + log.warning("connection to %s:%i lost", *self.server) + self.synchronous_close() + raise asyncio.CancelledError(f"connection to {self.server[0]}:{self.server[1]} lost") + except asyncio.TimeoutError: + raise + except asyncio.CancelledError: + self.synchronous_close() + raise finally: self.pending_amount -= 1 @@ -83,19 +94,16 @@ class ClientSession(BaseClientSession): except (asyncio.TimeoutError, OSError): await self.close() retry_delay = min(60, retry_delay * 2) - log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) + log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) try: await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) except asyncio.TimeoutError: pass - except asyncio.CancelledError: - self.synchronous_close() - raise finally: self.trigger_urgent_reconnect.clear() - def ensure_server_version(self, required='1.2'): - return self.send_request('server.version', [__version__, required]) + async def ensure_server_version(self, required='1.2', timeout=3): + return await self.send_request('server.version', [__version__, required], timeout) async def create_connection(self, timeout=6): connector = Connector(lambda: self, *self.server) @@ -120,7 +128,6 @@ class ClientSession(BaseClientSession): class BaseNetwork: def __init__(self, ledger): - self.switch_event = asyncio.Event() self.config = ledger.config self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) self.client: Optional[ClientSession] = None @@ -141,24 +148,41 @@ class BaseNetwork: 'blockchain.address.subscribe': self._on_status_controller, } + async def switch_to_fastest(self): + try: + client = await asyncio.wait_for(self.session_pool.wait_for_fastest_session(), 30) + except asyncio.TimeoutError: + if self.client: + await self.client.close() + self.client = None + for session in self.session_pool.sessions: + session.synchronous_close() + log.warning("not connected to any wallet servers") + return + if not self.client or client.server_address_and_port != self.client.server_address_and_port: + current_client = self.client + self.client = client + log.info("Switching to SPV wallet server: %s:%d", *self.client.server) + try: + self._update_remote_height((await self.subscribe_headers(),)) + log.info("Subscribed to headers: %s:%d", *self.client.server) + if current_client: + await current_client.close() + log.info("Closed connection to %s:%i", *current_client.server) + except asyncio.TimeoutError: + if self.client: + await self.client.close() + self.client = current_client + return + self._on_connected_controller.add(True) + await asyncio.sleep(30) + async def start(self): self.running = True self.session_pool.start(self.config['default_servers']) self.on_header.listen(self._update_remote_height) while self.running: - try: - self.client = await self.session_pool.wait_for_fastest_session() - self._update_remote_height((await self.subscribe_headers(),)) - log.info("Switching to SPV wallet server: %s:%d", *self.client.server) - self._on_connected_controller.add(True) - self.client.on_disconnected.listen(lambda _: self.switch_event.set()) - await self.switch_event.wait() - self.switch_event.clear() - except asyncio.CancelledError: - await self.stop() - raise - except asyncio.TimeoutError: - pass + await self.switch_to_fastest() async def stop(self): self.running = False