make client server updatable from conf
This commit is contained in:
parent
a7cea4082e
commit
1169a02c8b
3 changed files with 12 additions and 10 deletions
|
@ -746,7 +746,7 @@ class TrackerAnnouncerComponent(Component):
|
|||
node = self.component_manager.get_component(DHT_COMPONENT) \
|
||||
if self.component_manager.has_component(DHT_COMPONENT) else None
|
||||
node_id = node.protocol.node_id if node else None
|
||||
self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, self.conf.tracker_servers)
|
||||
self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, lambda: self.conf.tracker_servers)
|
||||
await self.tracker_client.start()
|
||||
self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT)
|
||||
self.announce_task = asyncio.create_task(self.announce_forever())
|
||||
|
|
|
@ -135,12 +135,12 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
|
|||
class TrackerClient:
|
||||
EVENT_CONTROLLER = StreamController()
|
||||
|
||||
def __init__(self, node_id, announce_port, servers, timeout=10.0):
|
||||
def __init__(self, node_id, announce_port, get_servers, timeout=10.0):
|
||||
self.client = UDPTrackerClientProtocol(timeout=timeout)
|
||||
self.transport = None
|
||||
self.peer_id = make_peer_id(node_id.hex() if node_id else None)
|
||||
self.announce_port = announce_port
|
||||
self.servers = servers
|
||||
self._get_servers = get_servers
|
||||
self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires
|
||||
self.tasks = {}
|
||||
|
||||
|
@ -167,7 +167,7 @@ class TrackerClient:
|
|||
|
||||
async def announce_many(self, *info_hashes, stopped=False):
|
||||
await asyncio.gather(
|
||||
*[self._announce_many(server, info_hashes, stopped=stopped) for server in self.servers],
|
||||
*[self._announce_many(server, info_hashes, stopped=stopped) for server in self._get_servers()],
|
||||
return_exceptions=True)
|
||||
|
||||
async def _announce_many(self, server, info_hashes, stopped=False):
|
||||
|
@ -186,7 +186,8 @@ class TrackerClient:
|
|||
|
||||
async def get_peer_list(self, info_hash, stopped=False, on_announcement=None):
|
||||
found = []
|
||||
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
|
||||
servers = self._get_servers()
|
||||
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in servers]):
|
||||
result = await done
|
||||
if result is not None:
|
||||
await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else []))
|
||||
|
|
|
@ -51,8 +51,9 @@ def encode_peer(ip_address: str, port: int):
|
|||
|
||||
class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.client_servers_list = []
|
||||
self.servers = {}
|
||||
self.client = TrackerClient(b"\x00" * 48, 4444, [], timeout=0.1)
|
||||
self.client = TrackerClient(b"\x00" * 48, 4444, lambda: self.client_servers_list, timeout=0.1)
|
||||
await self.client.start()
|
||||
self.addCleanup(self.client.stop)
|
||||
await self.add_server()
|
||||
|
@ -65,7 +66,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
|||
transport, _ = await self.loop.create_datagram_endpoint(lambda: server, local_addr=("127.0.0.1", port))
|
||||
self.addCleanup(transport.close)
|
||||
if add_to_client:
|
||||
self.client.servers.append(("127.0.0.1", port))
|
||||
self.client_servers_list.append(("127.0.0.1", port))
|
||||
|
||||
async def test_announce(self):
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
|
@ -76,8 +77,8 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
|||
|
||||
async def test_announce_many_info_hashes_to_many_servers_with_bad_one_and_dns_error(self):
|
||||
await asyncio.gather(*[self.add_server() for _ in range(3)])
|
||||
self.client.servers.append(("no.it.does.not.exist", 7070))
|
||||
self.client.servers.append(("127.0.0.2", 7070))
|
||||
self.client_servers_list.append(("no.it.does.not.exist", 7070))
|
||||
self.client_servers_list.append(("127.0.0.2", 7070))
|
||||
info_hashes = [random.getrandbits(160).to_bytes(20, "big", signed=False) for _ in range(5)]
|
||||
await self.client.announce_many(*info_hashes)
|
||||
for server in self.servers.values():
|
||||
|
@ -112,7 +113,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
|||
|
||||
async def test_multiple_servers_with_bad_one(self):
|
||||
await asyncio.gather(*[self.add_server() for _ in range(10)])
|
||||
self.client.servers.append(("127.0.0.2", 7070))
|
||||
self.client_servers_list.append(("127.0.0.2", 7070))
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
await self.client.get_peer_list(info_hash)
|
||||
for server in self.servers.values():
|
||||
|
|
Loading…
Reference in a new issue