From a667278c99f9ae312757df46a9c237bc6f3110ba Mon Sep 17 00:00:00 2001
From: Jack Robison <jackrobison@lbry.io>
Date: Wed, 21 Aug 2019 16:16:07 -0400
Subject: [PATCH] switch_to_fastest

---
 torba/torba/client/basenetwork.py | 74 ++++++++++++++++++++-----------
 1 file changed, 49 insertions(+), 25 deletions(-)

diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py
index 1693bb178..38ff07b5a 100644
--- a/torba/torba/client/basenetwork.py
+++ b/torba/torba/client/basenetwork.py
@@ -40,11 +40,12 @@ class ClientSession(BaseClientSession):
             return None
         return self.transport.get_extra_info('peername')
 
-    async def send_timed_server_version_request(self, args=()):
+    async def send_timed_server_version_request(self, args=(), timeout=None):
+        timeout = timeout or self.timeout
         log.debug("send version request to %s:%i", *self.server)
         start = perf_counter()
         result = await asyncio.wait_for(
-            super().send_request('server.version', args), timeout=self.timeout
+            super().send_request('server.version', args), timeout=timeout
         )
         current_response_time = perf_counter() - start
         response_sum = (self.response_time or 0) * self._response_samples + current_response_time
@@ -52,18 +53,28 @@ class ClientSession(BaseClientSession):
         self._response_samples += 1
         return result
 
-    async def send_request(self, method, args=()):
+    async def send_request(self, method, args=(), timeout=None):
+        timeout = timeout or self.timeout
         self.pending_amount += 1
         try:
             if method == 'server.version':
-                return await self.send_timed_server_version_request(args)
+                return await self.send_timed_server_version_request(args, timeout)
             return await asyncio.wait_for(
-                super().send_request(method, args), timeout=self.timeout
+                super().send_request(method, args), timeout=timeout
             )
         except RPCError as e:
             log.warning("Wallet server (%s:%i) returned an error. Code: %s Message: %s",
                         *self.server, *e.args)
             raise e
+        except ConnectionError:
+            log.warning("connection to %s:%i lost", *self.server)
+            self.synchronous_close()
+            raise asyncio.CancelledError(f"connection to {self.server[0]}:{self.server[1]} lost")
+        except asyncio.TimeoutError:
+            raise
+        except asyncio.CancelledError:
+            self.synchronous_close()
+            raise
         finally:
             self.pending_amount -= 1
 
@@ -83,19 +94,16 @@ class ClientSession(BaseClientSession):
             except (asyncio.TimeoutError, OSError):
                 await self.close()
                 retry_delay = min(60, retry_delay * 2)
-                log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
+                log.debug("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server)
             try:
                 await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay)
             except asyncio.TimeoutError:
                 pass
-            except asyncio.CancelledError:
-                self.synchronous_close()
-                raise
             finally:
                 self.trigger_urgent_reconnect.clear()
 
-    def ensure_server_version(self, required='1.2'):
-        return self.send_request('server.version', [__version__, required])
+    async def ensure_server_version(self, required='1.2', timeout=3):
+        return await self.send_request('server.version', [__version__, required], timeout)
 
     async def create_connection(self, timeout=6):
         connector = Connector(lambda: self, *self.server)
@@ -120,7 +128,6 @@ class ClientSession(BaseClientSession):
 class BaseNetwork:
 
     def __init__(self, ledger):
-        self.switch_event = asyncio.Event()
         self.config = ledger.config
         self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6))
         self.client: Optional[ClientSession] = None
@@ -141,24 +148,41 @@ class BaseNetwork:
             'blockchain.address.subscribe': self._on_status_controller,
         }
 
+    async def switch_to_fastest(self):
+        try:
+            client = await asyncio.wait_for(self.session_pool.wait_for_fastest_session(), 30)
+        except asyncio.TimeoutError:
+            if self.client:
+                await self.client.close()
+            self.client = None
+            for session in self.session_pool.sessions:
+                session.synchronous_close()
+            log.warning("not connected to any wallet servers")
+            return
+        if not self.client or client.server_address_and_port != self.client.server_address_and_port:
+            current_client = self.client
+            self.client = client
+            log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
+            try:
+                self._update_remote_height((await self.subscribe_headers(),))
+                log.info("Subscribed to headers: %s:%d", *self.client.server)
+                if current_client:
+                    await current_client.close()
+                    log.info("Closed connection to %s:%i", *current_client.server)
+            except asyncio.TimeoutError:
+                if self.client:
+                    await self.client.close()
+                    self.client = current_client
+                return
+            self._on_connected_controller.add(True)
+        await asyncio.sleep(30)
+
     async def start(self):
         self.running = True
         self.session_pool.start(self.config['default_servers'])
         self.on_header.listen(self._update_remote_height)
         while self.running:
-            try:
-                self.client = await self.session_pool.wait_for_fastest_session()
-                self._update_remote_height((await self.subscribe_headers(),))
-                log.info("Switching to SPV wallet server: %s:%d", *self.client.server)
-                self._on_connected_controller.add(True)
-                self.client.on_disconnected.listen(lambda _: self.switch_event.set())
-                await self.switch_event.wait()
-                self.switch_event.clear()
-            except asyncio.CancelledError:
-                await self.stop()
-                raise
-            except asyncio.TimeoutError:
-                pass
+            await self.switch_to_fastest()
 
     async def stop(self):
         self.running = False