make client server updatable from conf

This commit is contained in:
Victor Shyba 2022-03-12 02:15:45 -03:00
parent a7cea4082e
commit 1169a02c8b
3 changed files with 12 additions and 10 deletions

View file

@ -746,7 +746,7 @@ class TrackerAnnouncerComponent(Component):
node = self.component_manager.get_component(DHT_COMPONENT) \ node = self.component_manager.get_component(DHT_COMPONENT) \
if self.component_manager.has_component(DHT_COMPONENT) else None if self.component_manager.has_component(DHT_COMPONENT) else None
node_id = node.protocol.node_id if node else None node_id = node.protocol.node_id if node else None
self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, self.conf.tracker_servers) self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, lambda: self.conf.tracker_servers)
await self.tracker_client.start() await self.tracker_client.start()
self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT) self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT)
self.announce_task = asyncio.create_task(self.announce_forever()) self.announce_task = asyncio.create_task(self.announce_forever())

View file

@ -135,12 +135,12 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
class TrackerClient: class TrackerClient:
EVENT_CONTROLLER = StreamController() EVENT_CONTROLLER = StreamController()
def __init__(self, node_id, announce_port, servers, timeout=10.0): def __init__(self, node_id, announce_port, get_servers, timeout=10.0):
self.client = UDPTrackerClientProtocol(timeout=timeout) self.client = UDPTrackerClientProtocol(timeout=timeout)
self.transport = None self.transport = None
self.peer_id = make_peer_id(node_id.hex() if node_id else None) self.peer_id = make_peer_id(node_id.hex() if node_id else None)
self.announce_port = announce_port self.announce_port = announce_port
self.servers = servers self._get_servers = get_servers
self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires
self.tasks = {} self.tasks = {}
@ -167,7 +167,7 @@ class TrackerClient:
async def announce_many(self, *info_hashes, stopped=False): async def announce_many(self, *info_hashes, stopped=False):
await asyncio.gather( await asyncio.gather(
*[self._announce_many(server, info_hashes, stopped=stopped) for server in self.servers], *[self._announce_many(server, info_hashes, stopped=stopped) for server in self._get_servers()],
return_exceptions=True) return_exceptions=True)
async def _announce_many(self, server, info_hashes, stopped=False): async def _announce_many(self, server, info_hashes, stopped=False):
@ -186,7 +186,8 @@ class TrackerClient:
async def get_peer_list(self, info_hash, stopped=False, on_announcement=None): async def get_peer_list(self, info_hash, stopped=False, on_announcement=None):
found = [] found = []
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]): servers = self._get_servers()
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in servers]):
result = await done result = await done
if result is not None: if result is not None:
await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else [])) await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else []))

View file

@ -51,8 +51,9 @@ def encode_peer(ip_address: str, port: int):
class UDPTrackerClientTestCase(AsyncioTestCase): class UDPTrackerClientTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.client_servers_list = []
self.servers = {} self.servers = {}
self.client = TrackerClient(b"\x00" * 48, 4444, [], timeout=0.1) self.client = TrackerClient(b"\x00" * 48, 4444, lambda: self.client_servers_list, timeout=0.1)
await self.client.start() await self.client.start()
self.addCleanup(self.client.stop) self.addCleanup(self.client.stop)
await self.add_server() await self.add_server()
@ -65,7 +66,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
transport, _ = await self.loop.create_datagram_endpoint(lambda: server, local_addr=("127.0.0.1", port)) transport, _ = await self.loop.create_datagram_endpoint(lambda: server, local_addr=("127.0.0.1", port))
self.addCleanup(transport.close) self.addCleanup(transport.close)
if add_to_client: if add_to_client:
self.client.servers.append(("127.0.0.1", port)) self.client_servers_list.append(("127.0.0.1", port))
async def test_announce(self): async def test_announce(self):
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
@ -76,8 +77,8 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
async def test_announce_many_info_hashes_to_many_servers_with_bad_one_and_dns_error(self): async def test_announce_many_info_hashes_to_many_servers_with_bad_one_and_dns_error(self):
await asyncio.gather(*[self.add_server() for _ in range(3)]) await asyncio.gather(*[self.add_server() for _ in range(3)])
self.client.servers.append(("no.it.does.not.exist", 7070)) self.client_servers_list.append(("no.it.does.not.exist", 7070))
self.client.servers.append(("127.0.0.2", 7070)) self.client_servers_list.append(("127.0.0.2", 7070))
info_hashes = [random.getrandbits(160).to_bytes(20, "big", signed=False) for _ in range(5)] info_hashes = [random.getrandbits(160).to_bytes(20, "big", signed=False) for _ in range(5)]
await self.client.announce_many(*info_hashes) await self.client.announce_many(*info_hashes)
for server in self.servers.values(): for server in self.servers.values():
@ -112,7 +113,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
async def test_multiple_servers_with_bad_one(self): async def test_multiple_servers_with_bad_one(self):
await asyncio.gather(*[self.add_server() for _ in range(10)]) await asyncio.gather(*[self.add_server() for _ in range(10)])
self.client.servers.append(("127.0.0.2", 7070)) self.client_servers_list.append(("127.0.0.2", 7070))
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
await self.client.get_peer_list(info_hash) await self.client.get_peer_list(info_hash)
for server in self.servers.values(): for server in self.servers.values():