better way to batch announce + handle different intervals for different trackers

This commit is contained in:
Victor Shyba 2022-03-09 19:59:30 -03:00
parent d4aca89a48
commit 99fc7178c1
3 changed files with 45 additions and 19 deletions

View file

@ -734,10 +734,12 @@ class TrackerAnnouncerComponent(Component):
async def announce_forever(self):
while True:
to_sleep = 60.0
to_announce = []
for file in self.file_manager.get_filtered():
if not file.downloader:
continue
self.tracker_client.on_hash(bytes.fromhex(file.sd_hash))
to_announce.append(bytes.fromhex(file.sd_hash))
await self.tracker_client.announce_many(*to_announce)
await asyncio.sleep(to_sleep)
async def start(self):

View file

@ -132,7 +132,6 @@ class TrackerClient:
self.servers = servers
self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires
self.tasks = {}
self.announced = 0
async def start(self):
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
@ -149,18 +148,31 @@ class TrackerClient:
while self.tasks:
self.tasks.popitem()[1].cancel()
def hash_done(self, info_hash):
self.tasks.pop(info_hash, None)
if len(self.tasks) == 0 and self.announced > 0:
log.info("Tracker finished announcing %d files.", self.announced)
self.announced = 0
def on_hash(self, info_hash, on_announcement=None):
if info_hash not in self.tasks:
task = asyncio.create_task(self.get_peer_list(info_hash, on_announcement=on_announcement))
task.add_done_callback(lambda *_: self.hash_done(info_hash))
task.add_done_callback(lambda *_: self.tasks.pop(info_hash, None))
self.tasks[info_hash] = task
async def announce_many(self, *info_hashes, stopped=False):
await asyncio.gather(
*[self._announce_many(server, info_hashes, stopped=stopped) for server in self.servers],
return_exceptions=True)
async def _announce_many(self, server, info_hashes, stopped=False):
tracker_ip = await resolve_host(*server, 'udp')
still_good_info_hashes = {
info_hash for (info_hash, (next_announcement, _)) in self.results.get(tracker_ip, {}).items()
if time.time() < next_announcement
}
results = await asyncio.gather(
*[self._probe_server(info_hash, tracker_ip, server[1], stopped=stopped)
for info_hash in info_hashes if info_hash not in still_good_info_hashes],
return_exceptions=True)
if results:
errors = sum([1 for result in results if result is None or isinstance(result, Exception)])
log.info("Tracker: finished announcing %d files to %s:%d, %d errors", len(results), *server, errors)
async def get_peer_list(self, info_hash, stopped=False, on_announcement=None):
found = []
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
@ -172,18 +184,18 @@ class TrackerClient:
async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False):
result = None
if info_hash in self.results:
next_announcement, result = self.results[info_hash]
self.results.setdefault(tracker_host, {})
tracker_host = await resolve_host(tracker_host, tracker_port, 'udp')
if info_hash in self.results[tracker_host]:
next_announcement, result = self.results[tracker_host][info_hash]
if time.time() < next_announcement:
return result
try:
tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp')
result = await self.client.announce(
info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped)
self.results[info_hash] = (time.time() + result.interval, result)
self.announced += 1
info_hash, self.node_id, self.announce_port, tracker_host, tracker_port, stopped)
self.results[tracker_host][info_hash] = (time.time() + result.interval, result)
except asyncio.TimeoutError: # todo: this is UDP, timeout is common, we need a better metric for failures
self.results[info_hash] = (time.time() + 60.0, result)
self.results[tracker_host][info_hash] = (time.time() + 60.0, result)
log.debug("Tracker timed out: %s:%d", tracker_host, tracker_port)
return None
log.debug("Announced: %s found %d peers for %s", tracker_host, len(result.peers), info_hash.hex()[:8])

View file

@ -74,6 +74,18 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
self.assertEqual(announcement.peers,
[CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
async def test_announce_many_info_hashes_to_many_servers_with_bad_one_and_dns_error(self):
await asyncio.gather(*[self.add_server() for _ in range(3)])
self.client.servers.append(("no.it.does.not.exist", 7070))
self.client.servers.append(("127.0.0.2", 7070))
info_hashes = [random.getrandbits(160).to_bytes(20, "big", signed=False) for _ in range(5)]
await self.client.announce_many(*info_hashes)
for server in self.servers.values():
self.assertDictEqual(
server.peers, {
info_hash: [encode_peer("127.0.0.1", self.client.announce_port)] for info_hash in info_hashes
})
async def test_announce_using_helper_function(self):
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
queue = asyncio.Queue()
@ -91,14 +103,14 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
await self.client.get_peer_list(info_hash)
self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')
async def test_multiple(self):
async def test_multiple_servers(self):
await asyncio.gather(*[self.add_server() for _ in range(10)])
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
await self.client.get_peer_list(info_hash)
for server in self.servers.values():
self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
async def test_multiple_with_bad_one(self):
async def test_multiple_servers_with_bad_one(self):
await asyncio.gather(*[self.add_server() for _ in range(10)])
self.client.servers.append(("127.0.0.2", 7070))
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
@ -106,7 +118,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
for server in self.servers.values():
self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
async def test_multiple_with_different_peers_across_helper_function(self):
async def test_multiple_servers_with_different_peers_across_helper_function(self):
# this is how the downloader uses it
await asyncio.gather(*[self.add_server() for _ in range(10)])
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)