disconnect from duplicate servers (same ip, different dns)
This commit is contained in:
parent
9045737504
commit
1bef56f030
1 changed files with 39 additions and 11 deletions
|
@ -239,14 +239,45 @@ class SessionPool:
|
|||
key=itemgetter(0)
|
||||
)[1]
|
||||
|
||||
def _get_session_connect_callback(self, session: ClientSession):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def callback():
|
||||
duplicate_connections = [
|
||||
s for s in self.sessions
|
||||
if s is not session and s.server_address_and_port == session.server_address_and_port
|
||||
]
|
||||
already_connected = None if not duplicate_connections else duplicate_connections[0]
|
||||
if already_connected:
|
||||
self.sessions.pop(session).cancel()
|
||||
session.synchronous_close()
|
||||
log.info("wallet server %s resolves to the same server as %s, rechecking in an hour",
|
||||
session.server[0], already_connected.server[0])
|
||||
loop.call_later(3600, self._connect_session, session.server)
|
||||
return
|
||||
self.new_connection_event.set()
|
||||
log.info("connected to %s:%i", *session.server)
|
||||
|
||||
return callback
|
||||
|
||||
def _connect_session(self, server: Tuple[str, int]):
|
||||
session = None
|
||||
for s in self.sessions:
|
||||
if s.server == server:
|
||||
session = s
|
||||
break
|
||||
if not session:
|
||||
session = ClientSession(
|
||||
network=self.network, server=server
|
||||
)
|
||||
session._on_connect_cb = self._get_session_connect_callback(session)
|
||||
if session not in self.sessions or not self.sessions[session] or self.sessions[session].done():
|
||||
self.sessions[session] = asyncio.create_task(session.ensure_session())
|
||||
self.sessions[session].add_done_callback(lambda _: self.ensure_connections())
|
||||
|
||||
def start(self, default_servers):
|
||||
callback = self.new_connection_event.set
|
||||
self.sessions = {
|
||||
ClientSession(
|
||||
network=self.network, server=server, on_connect_callback=callback
|
||||
): None for server in default_servers
|
||||
}
|
||||
self.ensure_connections()
|
||||
for server in default_servers:
|
||||
self._connect_session(server)
|
||||
|
||||
def stop(self):
|
||||
for task in self.sessions.values():
|
||||
|
@ -255,10 +286,7 @@ class SessionPool:
|
|||
|
||||
def ensure_connections(self):
|
||||
for session, task in list(self.sessions.items()):
|
||||
if not task or task.done():
|
||||
task = asyncio.create_task(session.ensure_session())
|
||||
task.add_done_callback(lambda _: self.ensure_connections())
|
||||
self.sessions[session] = task
|
||||
self._connect_session(session.server)
|
||||
|
||||
def trigger_nodelay_connect(self):
|
||||
# used when other parts of the system sees we might have internet back
|
||||
|
|
Loading…
Reference in a new issue