From 407c570f8b24a5a7890c425e6655c54d6d559d65 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Wed, 9 Mar 2022 17:07:16 -0300 Subject: [PATCH] tests: lower timeout, add test with bad and good mixed --- lbry/torrent/tracker.py | 4 ++-- tests/unit/torrent/test_tracker.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/lbry/torrent/tracker.py b/lbry/torrent/tracker.py index 0095a4fc6..f38cf17fc 100644 --- a/lbry/torrent/tracker.py +++ b/lbry/torrent/tracker.py @@ -124,8 +124,8 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol): class TrackerClient: EVENT_CONTROLLER = StreamController() - def __init__(self, node_id, announce_port, servers): - self.client = UDPTrackerClientProtocol() + def __init__(self, node_id, announce_port, servers, timeout=10.0): + self.client = UDPTrackerClientProtocol(timeout=timeout) self.transport = None self.node_id = node_id or random.getrandbits(160).to_bytes(20, "big", signed=False) self.announce_port = announce_port diff --git a/tests/unit/torrent/test_tracker.py b/tests/unit/torrent/test_tracker.py index 1c323fc15..a6afed06a 100644 --- a/tests/unit/torrent/test_tracker.py +++ b/tests/unit/torrent/test_tracker.py @@ -45,7 +45,7 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su class UDPTrackerClientTestCase(AsyncioTestCase): async def asyncSetUp(self): self.servers = {} - self.client = TrackerClient(b"\x00" * 48, 4444, []) + self.client = TrackerClient(b"\x00" * 48, 4444, [], timeout=0.1) await self.client.start() self.addCleanup(self.client.stop) await self.add_server() @@ -91,3 +91,12 @@ class UDPTrackerClientTestCase(AsyncioTestCase): for server in self.servers.values(): self.assertEqual(1, len(server.peers)) self.assertEqual(1, len(server.peers[info_hash])) + + async def test_multiple_with_bad_one(self): + await asyncio.gather(*[self.add_server() for _ in range(10)]) + self.client.servers.append(("127.0.0.2", 7070)) + info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) + await self.client.get_peer_list(info_hash) + for server in self.servers.values(): + self.assertEqual(1, len(server.peers)) + self.assertEqual(1, len(server.peers[info_hash]))