diff --git a/lbry/extras/daemon/components.py b/lbry/extras/daemon/components.py index 754c4fe0f..329f4ba12 100644 --- a/lbry/extras/daemon/components.py +++ b/lbry/extras/daemon/components.py @@ -734,10 +734,12 @@ class TrackerAnnouncerComponent(Component): async def announce_forever(self): while True: to_sleep = 60.0 + to_announce = [] for file in self.file_manager.get_filtered(): if not file.downloader: continue - self.tracker_client.on_hash(bytes.fromhex(file.sd_hash)) + to_announce.append(bytes.fromhex(file.sd_hash)) + await self.tracker_client.announce_many(*to_announce) await asyncio.sleep(to_sleep) async def start(self): diff --git a/lbry/torrent/tracker.py b/lbry/torrent/tracker.py index e23a5ba84..07c4ab785 100644 --- a/lbry/torrent/tracker.py +++ b/lbry/torrent/tracker.py @@ -132,7 +132,6 @@ class TrackerClient: self.servers = servers self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires self.tasks = {} - self.announced = 0 async def start(self): self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint( @@ -149,18 +148,31 @@ class TrackerClient: while self.tasks: self.tasks.popitem()[1].cancel() - def hash_done(self, info_hash): - self.tasks.pop(info_hash, None) - if len(self.tasks) == 0 and self.announced > 0: - log.info("Tracker finished announcing %d files.", self.announced) - self.announced = 0 - def on_hash(self, info_hash, on_announcement=None): if info_hash not in self.tasks: task = asyncio.create_task(self.get_peer_list(info_hash, on_announcement=on_announcement)) - task.add_done_callback(lambda *_: self.hash_done(info_hash)) + task.add_done_callback(lambda *_: self.tasks.pop(info_hash, None)) self.tasks[info_hash] = task + async def announce_many(self, *info_hashes, stopped=False): + await asyncio.gather( + *[self._announce_many(server, info_hashes, stopped=stopped) for server in self.servers], + return_exceptions=True) + + async def _announce_many(self, server, info_hashes, stopped=False): + tracker_ip = await resolve_host(*server, 'udp') + still_good_info_hashes = { + info_hash for (info_hash, (next_announcement, _)) in self.results.get(tracker_ip, {}).items() + if time.time() < next_announcement + } + results = await asyncio.gather( + *[self._probe_server(info_hash, tracker_ip, server[1], stopped=stopped) + for info_hash in info_hashes if info_hash not in still_good_info_hashes], + return_exceptions=True) + if results: + errors = sum([1 for result in results if result is None or isinstance(result, Exception)]) + log.info("Tracker: finished announcing %d files to %s:%d, %d errors", len(results), *server, errors) + async def get_peer_list(self, info_hash, stopped=False, on_announcement=None): found = [] for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]): @@ -172,18 +184,18 @@ class TrackerClient: async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False): result = None - if info_hash in self.results: - next_announcement, result = self.results[info_hash] + self.results.setdefault(tracker_host, {}) + tracker_host = await resolve_host(tracker_host, tracker_port, 'udp') + if info_hash in self.results[tracker_host]: + next_announcement, result = self.results[tracker_host][info_hash] if time.time() < next_announcement: return result try: - tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp') result = await self.client.announce( - info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped) - self.results[info_hash] = (time.time() + result.interval, result) - self.announced += 1 + info_hash, self.node_id, self.announce_port, tracker_host, tracker_port, stopped) + self.results[tracker_host][info_hash] = (time.time() + result.interval, result) except asyncio.TimeoutError: # todo: this is UDP, timeout is common, we need a better metric for failures - self.results[info_hash] = (time.time() + 60.0, result) + self.results[tracker_host][info_hash] = (time.time() + 60.0, result) log.debug("Tracker timed out: %s:%d", tracker_host, tracker_port) return None log.debug("Announced: %s found %d peers for %s", tracker_host, len(result.peers), info_hash.hex()[:8]) diff --git a/tests/unit/torrent/test_tracker.py b/tests/unit/torrent/test_tracker.py index e4b3c2f1c..8712a479e 100644 --- a/tests/unit/torrent/test_tracker.py +++ b/tests/unit/torrent/test_tracker.py @@ -74,6 +74,18 @@ class UDPTrackerClientTestCase(AsyncioTestCase): self.assertEqual(announcement.peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)]) + async def test_announce_many_info_hashes_to_many_servers_with_bad_one_and_dns_error(self): + await asyncio.gather(*[self.add_server() for _ in range(3)]) + self.client.servers.append(("no.it.does.not.exist", 7070)) + self.client.servers.append(("127.0.0.2", 7070)) + info_hashes = [random.getrandbits(160).to_bytes(20, "big", signed=False) for _ in range(5)] + await self.client.announce_many(*info_hashes) + for server in self.servers.values(): + self.assertDictEqual( + server.peers, { + info_hash: [encode_peer("127.0.0.1", self.client.announce_port)] for info_hash in info_hashes + }) + async def test_announce_using_helper_function(self): info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) queue = asyncio.Queue() @@ -91,14 +103,14 @@ class UDPTrackerClientTestCase(AsyncioTestCase): await self.client.get_peer_list(info_hash) self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00') - async def test_multiple(self): + async def test_multiple_servers(self): await asyncio.gather(*[self.add_server() for _ in range(10)]) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) await self.client.get_peer_list(info_hash) for server in self.servers.values(): self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]}) - async def test_multiple_with_bad_one(self): + async def test_multiple_servers_with_bad_one(self): await asyncio.gather(*[self.add_server() for _ in range(10)]) self.client.servers.append(("127.0.0.2", 7070)) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) @@ -106,7 +118,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase): for server in self.servers.values(): self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]}) - async def test_multiple_with_different_peers_across_helper_function(self): + async def test_multiple_servers_with_different_peers_across_helper_function(self): # this is how the downloader uses it await asyncio.gather(*[self.add_server() for _ in range(10)]) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)