handle multiple results from multiple trackers

This commit is contained in:
Victor Shyba 2022-03-09 17:47:23 -03:00
parent 2918d8c7b4
commit d4aca89a48
2 changed files with 37 additions and 18 deletions

View file

@ -138,7 +138,7 @@ class TrackerClient:
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint( self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
lambda: self.client, local_addr=("0.0.0.0", 0)) lambda: self.client, local_addr=("0.0.0.0", 0))
self.EVENT_CONTROLLER.stream.listen( self.EVENT_CONTROLLER.stream.listen(
lambda request: self.on_hash(request[1]) if request[0] == 'search' else None) lambda request: self.on_hash(request[1], request[2]) if request[0] == 'search' else None)
def stop(self): def stop(self):
if self.transport is not None: if self.transport is not None:
@ -155,18 +155,18 @@ class TrackerClient:
log.info("Tracker finished announcing %d files.", self.announced) log.info("Tracker finished announcing %d files.", self.announced)
self.announced = 0 self.announced = 0
def on_hash(self, info_hash): def on_hash(self, info_hash, on_announcement=None):
if info_hash not in self.tasks: if info_hash not in self.tasks:
task = asyncio.create_task(self.get_peer_list(info_hash)) task = asyncio.create_task(self.get_peer_list(info_hash, on_announcement=on_announcement))
task.add_done_callback(lambda *_: self.hash_done(info_hash)) task.add_done_callback(lambda *_: self.hash_done(info_hash))
self.tasks[info_hash] = task self.tasks[info_hash] = task
async def get_peer_list(self, info_hash, stopped=False): async def get_peer_list(self, info_hash, stopped=False, on_announcement=None):
found = [] found = []
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]): for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
result = await done result = await done
if result is not None: if result is not None:
self.EVENT_CONTROLLER.add((info_hash, result)) await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else []))
found.append(result) found.append(result)
return found return found
@ -191,6 +191,4 @@ class TrackerClient:
def subscribe_hash(info_hash: bytes, on_data): def subscribe_hash(info_hash: bytes, on_data):
TrackerClient.EVENT_CONTROLLER.add(('search', info_hash)) TrackerClient.EVENT_CONTROLLER.add(('search', info_hash, on_data))
TrackerClient.EVENT_CONTROLLER.stream.where(lambda request: request[0] == info_hash).add_done_callback(
lambda request: on_data(request.result()[1]))

View file

@ -16,6 +16,10 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su
def connection_made(self, transport: asyncio.DatagramTransport) -> None: def connection_made(self, transport: asyncio.DatagramTransport) -> None:
self.transport = transport self.transport = transport
def add_peer(self, info_hash, ip_address: str, port: int):
self.peers.setdefault(info_hash, [])
self.peers[info_hash].append(encode_peer(ip_address, port))
def datagram_received(self, data: bytes, address: (str, int)) -> None: def datagram_received(self, data: bytes, address: (str, int)) -> None:
if len(data) < 16: if len(data) < 16:
return return
@ -30,18 +34,21 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su
if req.connection_id not in self.known_conns: if req.connection_id not in self.known_conns:
resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00')) resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00'))
else: else:
self.peers.setdefault(req.info_hash, []) compact_address = encode_peer(address[0], req.port)
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address[0].split('.'), bytearray())
compact_address = compact_ip + req.port.to_bytes(2, "big", signed=False)
if req.event != 3: if req.event != 3:
self.peers[req.info_hash].append(compact_address) self.add_peer(req.info_hash, address[0], req.port)
elif compact_address in self.peers[req.info_hash]: elif compact_address in self.peers.get(req.info_hash, []):
self.peers[req.info_hash].remove(compact_address) self.peers[req.info_hash].remove(compact_address)
peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]] peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]]
resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers)) resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers))
return self.transport.sendto(resp, address) return self.transport.sendto(resp, address)
def encode_peer(ip_address: str, port: int):
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), ip_address.split('.'), bytearray())
return compact_ip + port.to_bytes(2, "big", signed=False)
class UDPTrackerClientTestCase(AsyncioTestCase): class UDPTrackerClientTestCase(AsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.servers = {} self.servers = {}
@ -70,7 +77,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
async def test_announce_using_helper_function(self): async def test_announce_using_helper_function(self):
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
queue = asyncio.Queue() queue = asyncio.Queue()
subscribe_hash(info_hash, queue.put_nowait) subscribe_hash(info_hash, queue.put)
announcement = await queue.get() announcement = await queue.get()
peers = announcement.peers peers = announcement.peers
self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)]) self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
@ -89,8 +96,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
await self.client.get_peer_list(info_hash) await self.client.get_peer_list(info_hash)
for server in self.servers.values(): for server in self.servers.values():
self.assertEqual(1, len(server.peers)) self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
self.assertEqual(1, len(server.peers[info_hash]))
async def test_multiple_with_bad_one(self): async def test_multiple_with_bad_one(self):
await asyncio.gather(*[self.add_server() for _ in range(10)]) await asyncio.gather(*[self.add_server() for _ in range(10)])
@ -98,5 +104,20 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
await self.client.get_peer_list(info_hash) await self.client.get_peer_list(info_hash)
for server in self.servers.values(): for server in self.servers.values():
self.assertEqual(1, len(server.peers)) self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
self.assertEqual(1, len(server.peers[info_hash]))
async def test_multiple_with_different_peers_across_helper_function(self):
# this is how the downloader uses it
await asyncio.gather(*[self.add_server() for _ in range(10)])
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
fake_peers = []
for server in self.servers.values():
for _ in range(10):
peer = (f"127.0.0.{random.randint(1, 255)}", random.randint(2000, 65500))
fake_peers.append(peer)
server.add_peer(info_hash, *peer)
response = []
subscribe_hash(info_hash, response.append)
await asyncio.sleep(0)
await asyncio.gather(*self.client.tasks.values())
self.assertEqual(11, len(response))