tests: add support for multiple trackers

This commit is contained in:
Victor Shyba 2022-03-09 16:53:45 -03:00
parent 0e4f1eae5b
commit cc4a578578

View file

@ -44,12 +44,20 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su
class UDPTrackerClientTestCase(AsyncioTestCase): class UDPTrackerClientTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.server = UDPTrackerServerProtocol() self.servers = {}
transport, _ = await self.loop.create_datagram_endpoint(lambda: self.server, local_addr=("127.0.0.1", 59900)) self.client = TrackerClient(b"\x00" * 48, 4444, [])
self.addCleanup(transport.close)
self.client = TrackerClient(b"\x00" * 48, 4444, [("127.0.0.1", 59900)])
await self.client.start() await self.client.start()
self.addCleanup(self.client.stop) self.addCleanup(self.client.stop)
await self.add_server()
async def add_server(self, port=None, add_to_client=True):
port = port or len(self.servers) + 59990
server = UDPTrackerServerProtocol()
transport, _ = await self.loop.create_datagram_endpoint(lambda: server, local_addr=("127.0.0.1", port))
self.addCleanup(transport.close)
self.servers[port] = server
if add_to_client:
self.client.servers.append(("127.0.0.1", port))
async def test_announce(self): async def test_announce(self):
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
@ -69,7 +77,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
async def test_error(self): async def test_error(self):
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
await self.client.get_peer_list(info_hash) await self.client.get_peer_list(info_hash)
self.server.known_conns.clear() list(self.servers.values())[0].known_conns.clear()
self.client.results.clear() self.client.results.clear()
with self.assertRaises(Exception) as err: with self.assertRaises(Exception) as err:
await self.client.get_peer_list(info_hash) await self.client.get_peer_list(info_hash)