handle multiple results from multiple trackers
This commit is contained in:
parent
2918d8c7b4
commit
d4aca89a48
2 changed files with 37 additions and 18 deletions
|
@ -138,7 +138,7 @@ class TrackerClient:
|
|||
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||
lambda: self.client, local_addr=("0.0.0.0", 0))
|
||||
self.EVENT_CONTROLLER.stream.listen(
|
||||
lambda request: self.on_hash(request[1]) if request[0] == 'search' else None)
|
||||
lambda request: self.on_hash(request[1], request[2]) if request[0] == 'search' else None)
|
||||
|
||||
def stop(self):
|
||||
if self.transport is not None:
|
||||
|
@ -155,18 +155,18 @@ class TrackerClient:
|
|||
log.info("Tracker finished announcing %d files.", self.announced)
|
||||
self.announced = 0
|
||||
|
||||
def on_hash(self, info_hash):
|
||||
def on_hash(self, info_hash, on_announcement=None):
|
||||
if info_hash not in self.tasks:
|
||||
task = asyncio.create_task(self.get_peer_list(info_hash))
|
||||
task = asyncio.create_task(self.get_peer_list(info_hash, on_announcement=on_announcement))
|
||||
task.add_done_callback(lambda *_: self.hash_done(info_hash))
|
||||
self.tasks[info_hash] = task
|
||||
|
||||
async def get_peer_list(self, info_hash, stopped=False):
|
||||
async def get_peer_list(self, info_hash, stopped=False, on_announcement=None):
|
||||
found = []
|
||||
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
|
||||
result = await done
|
||||
if result is not None:
|
||||
self.EVENT_CONTROLLER.add((info_hash, result))
|
||||
await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else []))
|
||||
found.append(result)
|
||||
return found
|
||||
|
||||
|
@ -191,6 +191,4 @@ class TrackerClient:
|
|||
|
||||
|
||||
def subscribe_hash(info_hash: bytes, on_data):
|
||||
TrackerClient.EVENT_CONTROLLER.add(('search', info_hash))
|
||||
TrackerClient.EVENT_CONTROLLER.stream.where(lambda request: request[0] == info_hash).add_done_callback(
|
||||
lambda request: on_data(request.result()[1]))
|
||||
TrackerClient.EVENT_CONTROLLER.add(('search', info_hash, on_data))
|
||||
|
|
|
@ -16,6 +16,10 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su
|
|||
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
|
||||
self.transport = transport
|
||||
|
||||
def add_peer(self, info_hash, ip_address: str, port: int):
|
||||
self.peers.setdefault(info_hash, [])
|
||||
self.peers[info_hash].append(encode_peer(ip_address, port))
|
||||
|
||||
def datagram_received(self, data: bytes, address: (str, int)) -> None:
|
||||
if len(data) < 16:
|
||||
return
|
||||
|
@ -30,18 +34,21 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su
|
|||
if req.connection_id not in self.known_conns:
|
||||
resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00'))
|
||||
else:
|
||||
self.peers.setdefault(req.info_hash, [])
|
||||
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address[0].split('.'), bytearray())
|
||||
compact_address = compact_ip + req.port.to_bytes(2, "big", signed=False)
|
||||
compact_address = encode_peer(address[0], req.port)
|
||||
if req.event != 3:
|
||||
self.peers[req.info_hash].append(compact_address)
|
||||
elif compact_address in self.peers[req.info_hash]:
|
||||
self.add_peer(req.info_hash, address[0], req.port)
|
||||
elif compact_address in self.peers.get(req.info_hash, []):
|
||||
self.peers[req.info_hash].remove(compact_address)
|
||||
peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]]
|
||||
resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers))
|
||||
return self.transport.sendto(resp, address)
|
||||
|
||||
|
||||
def encode_peer(ip_address: str, port: int):
|
||||
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), ip_address.split('.'), bytearray())
|
||||
return compact_ip + port.to_bytes(2, "big", signed=False)
|
||||
|
||||
|
||||
class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.servers = {}
|
||||
|
@ -70,7 +77,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
|||
async def test_announce_using_helper_function(self):
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
queue = asyncio.Queue()
|
||||
subscribe_hash(info_hash, queue.put_nowait)
|
||||
subscribe_hash(info_hash, queue.put)
|
||||
announcement = await queue.get()
|
||||
peers = announcement.peers
|
||||
self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
||||
|
@ -89,8 +96,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
|||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
await self.client.get_peer_list(info_hash)
|
||||
for server in self.servers.values():
|
||||
self.assertEqual(1, len(server.peers))
|
||||
self.assertEqual(1, len(server.peers[info_hash]))
|
||||
self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
|
||||
|
||||
async def test_multiple_with_bad_one(self):
|
||||
await asyncio.gather(*[self.add_server() for _ in range(10)])
|
||||
|
@ -98,5 +104,20 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
|||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
await self.client.get_peer_list(info_hash)
|
||||
for server in self.servers.values():
|
||||
self.assertEqual(1, len(server.peers))
|
||||
self.assertEqual(1, len(server.peers[info_hash]))
|
||||
self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
|
||||
|
||||
async def test_multiple_with_different_peers_across_helper_function(self):
|
||||
# this is how the downloader uses it
|
||||
await asyncio.gather(*[self.add_server() for _ in range(10)])
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
fake_peers = []
|
||||
for server in self.servers.values():
|
||||
for _ in range(10):
|
||||
peer = (f"127.0.0.{random.randint(1, 255)}", random.randint(2000, 65500))
|
||||
fake_peers.append(peer)
|
||||
server.add_peer(info_hash, *peer)
|
||||
response = []
|
||||
subscribe_hash(info_hash, response.append)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.gather(*self.client.tasks.values())
|
||||
self.assertEqual(11, len(response))
|
||||
|
|
Loading…
Reference in a new issue