forked from LBRYCommunity/lbry-sdk
handle multiple results from multiple trackers
This commit is contained in:
parent
2918d8c7b4
commit
d4aca89a48
2 changed files with 37 additions and 18 deletions
|
@ -138,7 +138,7 @@ class TrackerClient:
|
||||||
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
|
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||||
lambda: self.client, local_addr=("0.0.0.0", 0))
|
lambda: self.client, local_addr=("0.0.0.0", 0))
|
||||||
self.EVENT_CONTROLLER.stream.listen(
|
self.EVENT_CONTROLLER.stream.listen(
|
||||||
lambda request: self.on_hash(request[1]) if request[0] == 'search' else None)
|
lambda request: self.on_hash(request[1], request[2]) if request[0] == 'search' else None)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if self.transport is not None:
|
if self.transport is not None:
|
||||||
|
@ -155,18 +155,18 @@ class TrackerClient:
|
||||||
log.info("Tracker finished announcing %d files.", self.announced)
|
log.info("Tracker finished announcing %d files.", self.announced)
|
||||||
self.announced = 0
|
self.announced = 0
|
||||||
|
|
||||||
def on_hash(self, info_hash):
|
def on_hash(self, info_hash, on_announcement=None):
|
||||||
if info_hash not in self.tasks:
|
if info_hash not in self.tasks:
|
||||||
task = asyncio.create_task(self.get_peer_list(info_hash))
|
task = asyncio.create_task(self.get_peer_list(info_hash, on_announcement=on_announcement))
|
||||||
task.add_done_callback(lambda *_: self.hash_done(info_hash))
|
task.add_done_callback(lambda *_: self.hash_done(info_hash))
|
||||||
self.tasks[info_hash] = task
|
self.tasks[info_hash] = task
|
||||||
|
|
||||||
async def get_peer_list(self, info_hash, stopped=False):
|
async def get_peer_list(self, info_hash, stopped=False, on_announcement=None):
|
||||||
found = []
|
found = []
|
||||||
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
|
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
|
||||||
result = await done
|
result = await done
|
||||||
if result is not None:
|
if result is not None:
|
||||||
self.EVENT_CONTROLLER.add((info_hash, result))
|
await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else []))
|
||||||
found.append(result)
|
found.append(result)
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
@ -191,6 +191,4 @@ class TrackerClient:
|
||||||
|
|
||||||
|
|
||||||
def subscribe_hash(info_hash: bytes, on_data):
|
def subscribe_hash(info_hash: bytes, on_data):
|
||||||
TrackerClient.EVENT_CONTROLLER.add(('search', info_hash))
|
TrackerClient.EVENT_CONTROLLER.add(('search', info_hash, on_data))
|
||||||
TrackerClient.EVENT_CONTROLLER.stream.where(lambda request: request[0] == info_hash).add_done_callback(
|
|
||||||
lambda request: on_data(request.result()[1]))
|
|
||||||
|
|
|
@ -16,6 +16,10 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su
|
||||||
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
|
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
|
||||||
|
def add_peer(self, info_hash, ip_address: str, port: int):
|
||||||
|
self.peers.setdefault(info_hash, [])
|
||||||
|
self.peers[info_hash].append(encode_peer(ip_address, port))
|
||||||
|
|
||||||
def datagram_received(self, data: bytes, address: (str, int)) -> None:
|
def datagram_received(self, data: bytes, address: (str, int)) -> None:
|
||||||
if len(data) < 16:
|
if len(data) < 16:
|
||||||
return
|
return
|
||||||
|
@ -30,18 +34,21 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su
|
||||||
if req.connection_id not in self.known_conns:
|
if req.connection_id not in self.known_conns:
|
||||||
resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00'))
|
resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00'))
|
||||||
else:
|
else:
|
||||||
self.peers.setdefault(req.info_hash, [])
|
compact_address = encode_peer(address[0], req.port)
|
||||||
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address[0].split('.'), bytearray())
|
|
||||||
compact_address = compact_ip + req.port.to_bytes(2, "big", signed=False)
|
|
||||||
if req.event != 3:
|
if req.event != 3:
|
||||||
self.peers[req.info_hash].append(compact_address)
|
self.add_peer(req.info_hash, address[0], req.port)
|
||||||
elif compact_address in self.peers[req.info_hash]:
|
elif compact_address in self.peers.get(req.info_hash, []):
|
||||||
self.peers[req.info_hash].remove(compact_address)
|
self.peers[req.info_hash].remove(compact_address)
|
||||||
peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]]
|
peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]]
|
||||||
resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers))
|
resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers))
|
||||||
return self.transport.sendto(resp, address)
|
return self.transport.sendto(resp, address)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_peer(ip_address: str, port: int):
|
||||||
|
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), ip_address.split('.'), bytearray())
|
||||||
|
return compact_ip + port.to_bytes(2, "big", signed=False)
|
||||||
|
|
||||||
|
|
||||||
class UDPTrackerClientTestCase(AsyncioTestCase):
|
class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
self.servers = {}
|
self.servers = {}
|
||||||
|
@ -70,7 +77,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||||
async def test_announce_using_helper_function(self):
|
async def test_announce_using_helper_function(self):
|
||||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
queue = asyncio.Queue()
|
queue = asyncio.Queue()
|
||||||
subscribe_hash(info_hash, queue.put_nowait)
|
subscribe_hash(info_hash, queue.put)
|
||||||
announcement = await queue.get()
|
announcement = await queue.get()
|
||||||
peers = announcement.peers
|
peers = announcement.peers
|
||||||
self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
||||||
|
@ -89,8 +96,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
await self.client.get_peer_list(info_hash)
|
await self.client.get_peer_list(info_hash)
|
||||||
for server in self.servers.values():
|
for server in self.servers.values():
|
||||||
self.assertEqual(1, len(server.peers))
|
self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
|
||||||
self.assertEqual(1, len(server.peers[info_hash]))
|
|
||||||
|
|
||||||
async def test_multiple_with_bad_one(self):
|
async def test_multiple_with_bad_one(self):
|
||||||
await asyncio.gather(*[self.add_server() for _ in range(10)])
|
await asyncio.gather(*[self.add_server() for _ in range(10)])
|
||||||
|
@ -98,5 +104,20 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
await self.client.get_peer_list(info_hash)
|
await self.client.get_peer_list(info_hash)
|
||||||
for server in self.servers.values():
|
for server in self.servers.values():
|
||||||
self.assertEqual(1, len(server.peers))
|
self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]})
|
||||||
self.assertEqual(1, len(server.peers[info_hash]))
|
|
||||||
|
async def test_multiple_with_different_peers_across_helper_function(self):
|
||||||
|
# this is how the downloader uses it
|
||||||
|
await asyncio.gather(*[self.add_server() for _ in range(10)])
|
||||||
|
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
|
fake_peers = []
|
||||||
|
for server in self.servers.values():
|
||||||
|
for _ in range(10):
|
||||||
|
peer = (f"127.0.0.{random.randint(1, 255)}", random.randint(2000, 65500))
|
||||||
|
fake_peers.append(peer)
|
||||||
|
server.add_peer(info_hash, *peer)
|
||||||
|
response = []
|
||||||
|
subscribe_hash(info_hash, response.append)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.gather(*self.client.tasks.values())
|
||||||
|
self.assertEqual(11, len(response))
|
||||||
|
|
Loading…
Reference in a new issue