From 1bef56f030ce78abad4c532d6c1df999ec7c3934 Mon Sep 17 00:00:00 2001
From: Jack Robison <jackrobison@lbry.io>
Date: Tue, 20 Aug 2019 14:55:24 -0400
Subject: [PATCH] disconnect from duplicate servers (same ip, different dns)

---
 torba/torba/client/basenetwork.py | 50 ++++++++++++++++++++++++-------
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py
index b6190b973..b7293ec37 100644
--- a/torba/torba/client/basenetwork.py
+++ b/torba/torba/client/basenetwork.py
@@ -239,14 +239,45 @@ class SessionPool:
             key=itemgetter(0)
         )[1]
 
+    def _get_session_connect_callback(self, session: ClientSession):
+        loop = asyncio.get_event_loop()
+
+        def callback():
+            duplicate_connections = [
+                s for s in self.sessions
+                if s is not session and s.server_address_and_port == session.server_address_and_port
+            ]
+            already_connected = None if not duplicate_connections else duplicate_connections[0]
+            if already_connected:
+                self.sessions.pop(session).cancel()
+                session.synchronous_close()
+                log.info("wallet server %s resolves to the same server as %s, rechecking in an hour",
+                          session.server[0], already_connected.server[0])
+                loop.call_later(3600, self._connect_session, session.server)
+                return
+            self.new_connection_event.set()
+            log.info("connected to %s:%i", *session.server)
+
+        return callback
+
+    def _connect_session(self, server: Tuple[str, int]):
+        session = None
+        for s in self.sessions:
+            if s.server == server:
+                session = s
+                break
+        if not session:
+            session = ClientSession(
+                network=self.network, server=server
+            )
+            session._on_connect_cb = self._get_session_connect_callback(session)
+        if session not in self.sessions or not self.sessions[session] or self.sessions[session].done():
+            self.sessions[session] = asyncio.create_task(session.ensure_session())
+            self.sessions[session].add_done_callback(lambda _: self.ensure_connections())
+
     def start(self, default_servers):
-        callback = self.new_connection_event.set
-        self.sessions = {
-            ClientSession(
-                network=self.network, server=server, on_connect_callback=callback
-            ): None for server in default_servers
-        }
-        self.ensure_connections()
+        for server in default_servers:
+            self._connect_session(server)
 
     def stop(self):
         for task in self.sessions.values():
@@ -255,10 +286,7 @@ class SessionPool:
 
     def ensure_connections(self):
         for session, task in list(self.sessions.items()):
-            if not task or task.done():
-                task = asyncio.create_task(session.ensure_session())
-                task.add_done_callback(lambda _: self.ensure_connections())
-                self.sessions[session] = task
+            self._connect_session(session.server)
 
     def trigger_nodelay_connect(self):
         # used when other parts of the system sees we might have internet back