diff --git a/lbry/extras/daemon/components.py b/lbry/extras/daemon/components.py index 329f4ba12..b3be409e0 100644 --- a/lbry/extras/daemon/components.py +++ b/lbry/extras/daemon/components.py @@ -746,7 +746,7 @@ class TrackerAnnouncerComponent(Component): node = self.component_manager.get_component(DHT_COMPONENT) \ if self.component_manager.has_component(DHT_COMPONENT) else None node_id = node.protocol.node_id if node else None - self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, self.conf.tracker_servers) + self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, lambda: self.conf.tracker_servers) await self.tracker_client.start() self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT) self.announce_task = asyncio.create_task(self.announce_forever()) diff --git a/lbry/torrent/tracker.py b/lbry/torrent/tracker.py index 8ae778771..27010c7b8 100644 --- a/lbry/torrent/tracker.py +++ b/lbry/torrent/tracker.py @@ -135,12 +135,12 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol): class TrackerClient: EVENT_CONTROLLER = StreamController() - def __init__(self, node_id, announce_port, servers, timeout=10.0): + def __init__(self, node_id, announce_port, get_servers, timeout=10.0): self.client = UDPTrackerClientProtocol(timeout=timeout) self.transport = None self.peer_id = make_peer_id(node_id.hex() if node_id else None) self.announce_port = announce_port - self.servers = servers + self._get_servers = get_servers self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires self.tasks = {} @@ -167,7 +167,7 @@ class TrackerClient: async def announce_many(self, *info_hashes, stopped=False): await asyncio.gather( - *[self._announce_many(server, info_hashes, stopped=stopped) for server in self.servers], + *[self._announce_many(server, info_hashes, stopped=stopped) for server in self._get_servers()], return_exceptions=True) async def _announce_many(self, server, info_hashes, stopped=False): @@ -186,7 +186,8 @@ class TrackerClient: async def get_peer_list(self, info_hash, stopped=False, on_announcement=None): found = [] - for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]): + servers = self._get_servers() + for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in servers]): result = await done if result is not None: await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else [])) diff --git a/tests/unit/torrent/test_tracker.py b/tests/unit/torrent/test_tracker.py index 8712a479e..4bb733361 100644 --- a/tests/unit/torrent/test_tracker.py +++ b/tests/unit/torrent/test_tracker.py @@ -51,8 +51,9 @@ def encode_peer(ip_address: str, port: int): class UDPTrackerClientTestCase(AsyncioTestCase): async def asyncSetUp(self): + self.client_servers_list = [] self.servers = {} - self.client = TrackerClient(b"\x00" * 48, 4444, [], timeout=0.1) + self.client = TrackerClient(b"\x00" * 48, 4444, lambda: self.client_servers_list, timeout=0.1) await self.client.start() self.addCleanup(self.client.stop) await self.add_server() @@ -65,7 +66,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase): transport, _ = await self.loop.create_datagram_endpoint(lambda: server, local_addr=("127.0.0.1", port)) self.addCleanup(transport.close) if add_to_client: - self.client.servers.append(("127.0.0.1", port)) + self.client_servers_list.append(("127.0.0.1", port)) async def test_announce(self): info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) @@ -76,8 +77,8 @@ class UDPTrackerClientTestCase(AsyncioTestCase): async def test_announce_many_info_hashes_to_many_servers_with_bad_one_and_dns_error(self): await asyncio.gather(*[self.add_server() for _ in range(3)]) - self.client.servers.append(("no.it.does.not.exist", 7070)) - self.client.servers.append(("127.0.0.2", 7070)) + self.client_servers_list.append(("no.it.does.not.exist", 7070)) + self.client_servers_list.append(("127.0.0.2", 7070)) info_hashes = [random.getrandbits(160).to_bytes(20, "big", signed=False) for _ in range(5)] await self.client.announce_many(*info_hashes) for server in self.servers.values(): @@ -112,7 +113,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase): async def test_multiple_servers_with_bad_one(self): await asyncio.gather(*[self.add_server() for _ in range(10)]) - self.client.servers.append(("127.0.0.2", 7070)) + self.client_servers_list.append(("127.0.0.2", 7070)) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) await self.client.get_peer_list(info_hash) for server in self.servers.values():