forked from LBRYCommunity/lbry-sdk
better way to batch announce + handle different intervals for different trackers
This commit is contained in:
parent
d4aca89a48
commit
99fc7178c1
3 changed files with 45 additions and 19 deletions
|
@ -734,10 +734,12 @@ class TrackerAnnouncerComponent(Component):
|
|||
async def announce_forever(self):
|
||||
while True:
|
||||
to_sleep = 60.0
|
||||
to_announce = []
|
||||
for file in self.file_manager.get_filtered():
|
||||
if not file.downloader:
|
||||
continue
|
||||
self.tracker_client.on_hash(bytes.fromhex(file.sd_hash))
|
||||
to_announce.append(bytes.fromhex(file.sd_hash))
|
||||
await self.tracker_client.announce_many(*to_announce)
|
||||
await asyncio.sleep(to_sleep)
|
||||
|
||||
async def start(self):
|
||||
|
|
|
@ -132,7 +132,6 @@ class TrackerClient:
|
|||
self.servers = servers
|
||||
self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires
|
||||
self.tasks = {}
|
||||
self.announced = 0
|
||||
|
||||
async def start(self):
|
||||
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||
|
@ -149,18 +148,31 @@ class TrackerClient:
|
|||
while self.tasks:
|
||||
self.tasks.popitem()[1].cancel()
|
||||
|
||||
def hash_done(self, info_hash):
|
||||
self.tasks.pop(info_hash, None)
|
||||
if len(self.tasks) == 0 and self.announced > 0:
|
||||
log.info("Tracker finished announcing %d files.", self.announced)
|
||||
self.announced = 0
|
||||
|
||||
def on_hash(self, info_hash, on_announcement=None):
|
||||
if info_hash not in self.tasks:
|
||||
task = asyncio.create_task(self.get_peer_list(info_hash, on_announcement=on_announcement))
|
||||
task.add_done_callback(lambda *_: self.hash_done(info_hash))
|
||||
task.add_done_callback(lambda *_: self.tasks.pop(info_hash, None))
|
||||
self.tasks[info_hash] = task
|
||||
|
||||
async def announce_many(self, *info_hashes, stopped=False):
|
||||
await asyncio.gather(
|
||||
*[self._announce_many(server, info_hashes, stopped=stopped) for server in self.servers],
|
||||
return_exceptions=True)
|
||||
|
||||
async def _announce_many(self, server, info_hashes, stopped=False):
|
||||
tracker_ip = await resolve_host(*server, 'udp')
|
||||
still_good_info_hashes = {
|
||||
info_hash for (info_hash, (next_announcement, _)) in self.results.get(tracker_ip, {}).items()
|
||||
if time.time() < next_announcement
|
||||
}
|
||||
results = await asyncio.gather(
|
||||
*[self._probe_server(info_hash, tracker_ip, server[1], stopped=stopped)
|
||||
for info_hash in info_hashes if info_hash not in still_good_info_hashes],
|
||||
return_exceptions=True)
|
||||
if results:
|
||||
errors = sum([1 for result in results if result is None or isinstance(result, Exception)])
|
||||
log.info("Tracker: finished announcing %d files to %s:%d, %d errors", len(results), *server, errors)
|
||||
|
||||
async def get_peer_list(self, info_hash, stopped=False, on_announcement=None):
|
||||
found = []
|
||||
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
|
||||
|
@ -172,18 +184,18 @@ class TrackerClient:
|
|||
|
||||
async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False):
|
||||
result = None
|
||||
if info_hash in self.results:
|
||||
next_announcement, result = self.results[info_hash]
|
||||
self.results.setdefault(tracker_host, {})
|
||||
tracker_host = await resolve_host(tracker_host, tracker_port, 'udp')
|
||||
if info_hash in self.results[tracker_host]:
|
||||
next_announcement, result = self.results[tracker_host][info_hash]
|
||||
if time.time() < next_announcement:
|
||||
return result
|
||||
try:
|
||||
tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp')
|
||||
result = await self.client.announce(
|
||||
info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped)
|
||||
self.results[info_hash] = (time.time() + result.interval, result)
|
||||
self.announced += 1
|
||||
info_hash, self.node_id, self.announce_port, tracker_host, tracker_port, stopped)
|
||||
self.results[tracker_host][info_hash] = (time.time() + result.interval, result)
|
||||
except asyncio.TimeoutError: # todo: this is UDP, timeout is common, we need a better metric for failures
|
||||
self.results[info_hash] = (time.time() + 60.0, result)
|
||||
self.results[tracker_host][info_hash] = (time.time() + 60.0, result)
|
||||
log.debug("Tracker timed out: %s:%d", tracker_host, tracker_port)
|
||||
return None
|
||||
log.debug("Announced: %s found %d peers for %s", tracker_host, len(result.peers), info_hash.hex()[:8])
|
||||
|
|
|
@ -74,6 +74,18 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
|||
self.assertEqual(announcement.peers,
|
||||
[CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
||||
|
||||
async def test_announce_many_info_hashes_to_many_servers_with_bad_one_and_dns_error(self):
|
||||
await asyncio.gather(*[self.add_server() for _ in range(3)])
|
||||
self.client.servers.append(("no.it.does.not.exist", 7070))
|
||||
self.client.servers.append(("127.0.0.2", 7070))
|
||||
info_hashes = [random.getrandbits(160).to_bytes(20, "big", signed=False) for _ in range(5)]
|
||||
await self.client.announce_many(*info_hashes)
|
||||
for server in self.servers.values():
|
||||
self.assertDictEqual(
|
||||
server.peers, {
|
||||
info_hash: [encode_peer("127.0.0.1", self.client.announce_port)] for info_hash in info_hashes
|
||||
})
|
||||
|
||||
async def test_announce_using_helper_function(self):
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
queue = asyncio.Queue()
|
||||
|
@ -91,14 +103,14 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
|||
await self.client.get_peer_list(info_hash)
|
||||
self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')
|
||||
|
||||
async def test_multiple(self):
|
||||
async def test_multiple_servers(self):
|
||||
await asyncio.gather(*[self.add_server() for _ in range(10)])
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
await self.client.get_peer_list(info_hash)
|
||||
for server in self.servers.values():
|
||||
self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
|
||||
|
||||
async def test_multiple_with_bad_one(self):
|
||||
async def test_multiple_servers_with_bad_one(self):
|
||||
await asyncio.gather(*[self.add_server() for _ in range(10)])
|
||||
self.client.servers.append(("127.0.0.2", 7070))
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
|
@ -106,7 +118,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
|||
for server in self.servers.values():
|
||||
self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
|
||||
|
||||
async def test_multiple_with_different_peers_across_helper_function(self):
|
||||
async def test_multiple_servers_with_different_peers_across_helper_function(self):
|
||||
# this is how the downloader uses it
|
||||
await asyncio.gather(*[self.add_server() for _ in range(10)])
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
|
|
Loading…
Reference in a new issue