From 1bef56f030ce78abad4c532d6c1df999ec7c3934 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Tue, 20 Aug 2019 14:55:24 -0400 Subject: [PATCH] disconnect from duplicate servers (same ip, different dns) --- torba/torba/client/basenetwork.py | 50 ++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index b6190b973..b7293ec37 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -239,14 +239,45 @@ class SessionPool: key=itemgetter(0) )[1] + def _get_session_connect_callback(self, session: ClientSession): + loop = asyncio.get_event_loop() + + def callback(): + duplicate_connections = [ + s for s in self.sessions + if s is not session and s.server_address_and_port == session.server_address_and_port + ] + already_connected = None if not duplicate_connections else duplicate_connections[0] + if already_connected: + self.sessions.pop(session).cancel() + session.synchronous_close() + log.info("wallet server %s resolves to the same server as %s, rechecking in an hour", + session.server[0], already_connected.server[0]) + loop.call_later(3600, self._connect_session, session.server) + return + self.new_connection_event.set() + log.info("connected to %s:%i", *session.server) + + return callback + + def _connect_session(self, server: Tuple[str, int]): + session = None + for s in self.sessions: + if s.server == server: + session = s + break + if not session: + session = ClientSession( + network=self.network, server=server + ) + session._on_connect_cb = self._get_session_connect_callback(session) + if session not in self.sessions or not self.sessions[session] or self.sessions[session].done(): + self.sessions[session] = asyncio.create_task(session.ensure_session()) + self.sessions[session].add_done_callback(lambda _: self.ensure_connections()) + def start(self, default_servers): - callback = self.new_connection_event.set - self.sessions = { - ClientSession( - network=self.network, server=server, on_connect_callback=callback - ): None for server in default_servers - } - self.ensure_connections() + for server in default_servers: + self._connect_session(server) def stop(self): for task in self.sessions.values(): @@ -255,10 +286,7 @@ class SessionPool: def ensure_connections(self): for session, task in list(self.sessions.items()): - if not task or task.done(): - task = asyncio.create_task(session.ensure_session()) - task.add_done_callback(lambda _: self.ensure_connections()) - self.sessions[session] = task + self._connect_session(session.server) def trigger_nodelay_connect(self): # used when other parts of the system sees we might have internet back