From 9e9a64d989eebcb9394ae1d491c10982f0523af1 Mon Sep 17 00:00:00 2001
From: Victor Shyba
Date: Mon, 7 Mar 2022 23:35:12 -0300
Subject: [PATCH] evented system for tracker announcements

---
Review notes (not part of the commit message; git-am ignores the text between
the "---" line above and the diff):

* Summary: tracker announces move out of StreamDownloader.refresh_from_trackers
  into a shared TrackerClient. TrackerClient publishes every announce result on
  the class-level EVENT_CONTROLLER as (info_hash, result); subscribe_hash()
  requests a search for a hash and registers a callback for its results.
* The line structure of this patch was restored from a copy whose newlines had
  been collapsed; indentation of context and removed lines follows the Python
  block structure and should be checked against the target files.
* Changed relative to 9e9a64d: StreamDownloader._process_announcement calls
  time.time(). It called time(), but the code removed by the same hunk uses
  time.time(), i.e. "time" is the module in that file.
* Changed relative to 9e9a64d: subscribe_hash() converts its argument to bytes
  once and uses that value for both the search request and the result filter.
  Before, the filter compared the caller's value (a hex string coming from
  StreamDownloader) with the bytes key published by _probe_server, and the
  unit test passed bytes into bytes.fromhex().
* The new-side hunk count for lbry/torrent/tracker.py and the diffstat were
  adjusted for the added lines; the "index" blob ids and the commit id in the
  first line still refer to the original commit.
* NOTE(review): refresh_from_trackers truncated the hash to 20 bytes before
  announcing; the new path passes the full sd_hash - confirm the
  AnnounceRequest encoding copes with that.
* NOTE(review): TrackerClient.stop() closes the class-level EVENT_CONTROLLER,
  which every instance and subscriber shares - confirm this is intended.

 lbry/extras/daemon/components.py   | 12 +++--
 lbry/stream/downloader.py          | 38 ++++-----------
 lbry/torrent/tracker.py            | 78 ++++++++++++++++++++++--------
 tests/unit/torrent/test_tracker.py | 28 +++++------
 4 files changed, 93 insertions(+), 63 deletions(-)

diff --git a/lbry/extras/daemon/components.py b/lbry/extras/daemon/components.py
index 8535008d4..f290bedf2 100644
--- a/lbry/extras/daemon/components.py
+++ b/lbry/extras/daemon/components.py
@@ -28,6 +28,8 @@ from lbry.extras.daemon.storage import SQLiteStorage
 from lbry.torrent.torrent_manager import TorrentManager
 from lbry.wallet import WalletManager
 from lbry.wallet.usage_payment import WalletServerPayer
+from lbry.torrent.tracker import TrackerClient
+
 try:
     from lbry.torrent.session import TorrentSession
 except ImportError:
@@ -720,6 +722,7 @@ class TrackerAnnouncerComponent(Component):
         super().__init__(component_manager)
         self.file_manager = None
         self.announce_task = None
+        self.tracker_client: typing.Optional[TrackerClient] = None
 
     @property
     def component(self):
@@ -733,12 +736,15 @@ class TrackerAnnouncerComponent(Component):
                     continue
                 next_announce = file.downloader.next_tracker_announce_time
                 if next_announce is None or next_announce <= time.time():
-                    await file.downloader.refresh_from_trackers(False)
-                else:
-                    to_sleep = min(to_sleep, next_announce - time.time())
+                    self.tracker_client.on_hash(bytes.fromhex(file.sd_hash))
             await asyncio.sleep(to_sleep + 1)
 
     async def start(self):
+        node = self.component_manager.get_component(DHT_COMPONENT) \
+            if self.component_manager.has_component(DHT_COMPONENT) else None
+        node_id = node.protocol.node_id if node else None
+        self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, self.conf.tracker_servers)
+        await self.tracker_client.start()
         self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT)
         self.announce_task = asyncio.create_task(self.announce_forever())
 
diff --git a/lbry/stream/downloader.py b/lbry/stream/downloader.py
index 608b33849..76f0bcdd2 100644
--- a/lbry/stream/downloader.py
+++ b/lbry/stream/downloader.py
@@ -10,9 +10,10 @@ from lbry.error import DownloadSDTimeoutError
 from lbry.utils import lru_cache_concurrent
 from lbry.stream.descriptor import StreamDescriptor
 from lbry.blob_exchange.downloader import BlobDownloader
-from lbry.torrent.tracker import get_peer_list
+from lbry.torrent.tracker import subscribe_hash
 
 if typing.TYPE_CHECKING:
+    from lbry.torrent.tracker import AnnounceResponse
     from lbry.conf import Config
     from lbry.dht.node import Node
     from lbry.blob.blob_manager import BlobManager
@@ -67,32 +68,13 @@ class StreamDownloader:
         fixed_peers = await get_kademlia_peers_from_hosts(self.config.fixed_peers)
         self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _add_fixed_peers, fixed_peers)
 
-    async def refresh_from_trackers(self, save_peers=True):
-        if not self.config.tracker_servers:
-            return
-        node_id = self.node.protocol.node_id if self.node else None
-        port = self.config.tcp_port
-        for server in self.config.tracker_servers:
-            try:
-                announcement = await get_peer_list(
-                    bytes.fromhex(self.sd_hash)[:20], node_id, port, server[0], server[1])
-                log.info("Announced %s to %s", self.sd_hash[:8], server)
-                self.next_tracker_announce_time = max(self.next_tracker_announce_time or 0,
-                                                      time.time() + announcement.interval)
-            except asyncio.CancelledError:
-                raise
-            except asyncio.TimeoutError:
-                log.warning("Tracker timed out: %s", server)
-                return
-            except Exception:
-                log.exception("Unexpected error querying tracker %s", server)
-                return
-            if not save_peers:
-                return
-            peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers]
-            peers = await get_kademlia_peers_from_hosts(peers)
-            log.info("Found %d peers from tracker %s for %s", len(peers), server, self.sd_hash[:8])
-            self.peer_queue.put_nowait(peers)
+    async def _process_announcement(self, announcement: 'AnnounceResponse'):
+        peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers]
+        peers = await get_kademlia_peers_from_hosts(peers)
+        log.info("Found %d peers from tracker for %s", len(peers), self.sd_hash[:8])
+        self.next_tracker_announce_time = min(time.time() + announcement.interval,
+                                              self.next_tracker_announce_time or 1 << 64)
+        self.peer_queue.put_nowait(peers)
 
     async def load_descriptor(self, connection_id: int = 0):
         # download or get the sd blob
@@ -123,7 +105,7 @@ class StreamDownloader:
                 self.accumulate_task.cancel()
             _, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
         await self.add_fixed_peers()
-        asyncio.ensure_future(self.refresh_from_trackers())
+        subscribe_hash(self.sd_hash, self._process_announcement)
         # start searching for peers for the sd hash
         self.search_queue.put_nowait(self.sd_hash)
         log.info("searching for peers for stream %s", self.sd_hash)
diff --git a/lbry/torrent/tracker.py b/lbry/torrent/tracker.py
index c46f63bbc..33cf330a4 100644
--- a/lbry/torrent/tracker.py
+++ b/lbry/torrent/tracker.py
@@ -4,9 +4,11 @@ import asyncio
 import logging
 from collections import namedtuple
 
-from lbry.utils import resolve_host
+from lbry.utils import resolve_host, async_timed_cache
+from lbry.wallet.stream import StreamController
 
 log = logging.getLogger(__name__)
+CONNECTION_EXPIRES_AFTER_SECONDS = 360
 # see: http://bittorrent.org/beps/bep_0015.html and http://xbtt.sourceforge.net/udp_tracker_protocol.html
 ConnectRequest = namedtuple("ConnectRequest", ["connection_id", "action", "transaction_id"])
 ConnectResponse = namedtuple("ConnectResponse", ["action", "transaction_id", "connection_id"])
@@ -77,17 +79,19 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
         return decode(ConnectResponse,
                       await self.request(ConnectRequest(0x41727101980, 0, transaction_id), tracker_ip, tracker_port))
 
-    async def announce(self, info_hash, peer_id, port, tracker_ip, tracker_port, connection_id=None, stopped=False):
-        if not connection_id:
-            reply = await self.connect(tracker_ip, tracker_port)
-            connection_id = reply.connection_id
+    @async_timed_cache(CONNECTION_EXPIRES_AFTER_SECONDS)
+    async def ensure_connection_id(self, peer_id, tracker_ip, tracker_port):
+        # peer_id is just to ensure cache coherency
+        return (await self.connect(tracker_ip, tracker_port)).connection_id
+
+    async def announce(self, info_hash, peer_id, port, tracker_ip, tracker_port, stopped=False):
+        connection_id = await self.ensure_connection_id(peer_id, tracker_ip, tracker_port)
         # this should make the key deterministic but unique per info hash + peer id
         key = int.from_bytes(info_hash[:4], "big") ^ int.from_bytes(peer_id[:4], "big") ^ port
         transaction_id = random.getrandbits(32)
         req = AnnounceRequest(
             connection_id, 1, transaction_id, info_hash, peer_id, 0, 0, 0, 3 if stopped else 1, 0, key, -1, port)
-        reply = await self.request(req, tracker_ip, tracker_port)
-        return decode(AnnounceResponse, reply), connection_id
+        return decode(AnnounceResponse, await self.request(req, tracker_ip, tracker_port))
 
     async def scrape(self, infohashes, tracker_ip, tracker_port, connection_id=None):
         if not connection_id:
@@ -107,20 +111,56 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
             if data[3] == 3:
                 return self.data_queue[transaction_id].set_exception(Exception(decode(ErrorResponse, data).message))
             return self.data_queue[transaction_id].set_result(data)
-        print("error", data.hex())
+        log.debug("unexpected packet (can be a response for a previously timed out request): %s", data.hex())
 
     def connection_lost(self, exc: Exception = None) -> None:
         self.transport = None
 
 
-async def get_peer_list(info_hash, node_id, port, tracker_ip, tracker_port, stopped=False):
-    node_id = node_id or random.getrandbits(160).to_bytes(20, "big", signed=False)
-    tracker_ip = await resolve_host(tracker_ip, tracker_port, 'udp')
-    proto = UDPTrackerClientProtocol()
-    transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(lambda: proto, local_addr=("0.0.0.0", 0))
-    try:
-        reply, _ = await proto.announce(info_hash, node_id, port, tracker_ip, tracker_port, stopped=stopped)
-        return reply
-    finally:
-        if not transport.is_closing():
-            transport.close()
+class TrackerClient:
+    EVENT_CONTROLLER = StreamController()
+    def __init__(self, node_id, announce_port, servers):
+        self.client = UDPTrackerClientProtocol()
+        self.transport = None
+        self.node_id = node_id or random.getrandbits(160).to_bytes(20, "big", signed=False)
+        self.announce_port = announce_port
+        self.servers = servers
+
+    async def start(self):
+        self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
+            lambda: self.client, local_addr=("0.0.0.0", 0))
+        self.EVENT_CONTROLLER.stream.listen(lambda request: self.on_hash(request[1]) if request[0] == 'search' else None)
+
+    def stop(self):
+        if self.transport is not None:
+            self.transport.close()
+        self.client = None
+        self.transport = None
+        self.EVENT_CONTROLLER.close()
+
+    def on_hash(self, info_hash):
+        asyncio.ensure_future(self.get_peer_list(info_hash))
+
+    async def get_peer_list(self, info_hash, stopped=False):
+        found = []
+        for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
+            found.append(await done)
+        return found
+
+    async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False):
+        tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp')
+        result = await self.client.announce(
+            info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped)
+        log.info("Announced to tracker. Found %d peers for %s on %s",
+                 len(result.peers), info_hash.hex()[:8], tracker_host)
+        self.EVENT_CONTROLLER.add((info_hash, result))
+        return result
+
+
+def subscribe_hash(hash, on_data):
+    # Announce results are published keyed by the raw info hash bytes, so accept
+    # either a hex string or bytes and use one bytes value for search and filter.
+    info_hash = bytes.fromhex(hash) if isinstance(hash, str) else bytes(hash)
+    TrackerClient.EVENT_CONTROLLER.add(('search', info_hash))
+    TrackerClient.EVENT_CONTROLLER.stream.listen(
+        lambda request: on_data(request[1]) if request[0] == info_hash else None)
diff --git a/tests/unit/torrent/test_tracker.py b/tests/unit/torrent/test_tracker.py
index a03493118..7d981cf8a 100644
--- a/tests/unit/torrent/test_tracker.py
+++ b/tests/unit/torrent/test_tracker.py
@@ -3,8 +3,8 @@ import random
 from functools import reduce
 
 from lbry.testcase import AsyncioTestCase
-from lbry.torrent.tracker import UDPTrackerClientProtocol, encode, decode, CompactIPv4Peer, ConnectRequest, \
-    ConnectResponse, AnnounceRequest, ErrorResponse, AnnounceResponse, get_peer_list
+from lbry.torrent.tracker import encode, decode, CompactIPv4Peer, ConnectRequest, \
+    ConnectResponse, AnnounceRequest, ErrorResponse, AnnounceResponse, TrackerClient, subscribe_hash
 
 
 class UDPTrackerServerProtocol(asyncio.DatagramProtocol):  # for testing. Not suitable for production
@@ -44,30 +44,32 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol):  # for testing. Not su
 
 class UDPTrackerClientTestCase(AsyncioTestCase):
     async def asyncSetUp(self):
-        transport, _ = await self.loop.create_datagram_endpoint(UDPTrackerServerProtocol, local_addr=("127.0.0.1", 59900))
-        self.addCleanup(transport.close)
-        self.client = UDPTrackerClientProtocol()
-        transport, _ = await self.loop.create_datagram_endpoint(lambda: self.client, local_addr=("127.0.0.1", 0))
+        self.server = UDPTrackerServerProtocol()
+        transport, _ = await self.loop.create_datagram_endpoint(lambda: self.server, local_addr=("127.0.0.1", 59900))
         self.addCleanup(transport.close)
+        self.client = TrackerClient(b"\x00" * 48, 4444, [("127.0.0.1", 59900)])
+        await self.client.start()
+        self.addCleanup(self.client.stop)
 
     async def test_announce(self):
         info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
-        peer_id = random.getrandbits(160).to_bytes(20, "big", signed=False)
-        announcement, _ = await self.client.announce(info_hash, peer_id, 4444, "127.0.0.1", 59900)
+        announcement = (await self.client.get_peer_list(info_hash))[0]
         self.assertEqual(announcement.seeders, 1)
         self.assertEqual(announcement.peers,
                          [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
 
     async def test_announce_using_helper_function(self):
         info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
-        announcemenet = await get_peer_list(info_hash, None, 4444, "127.0.0.1", 59900)
-        peers = announcemenet.peers
+        queue = asyncio.Queue()
+        subscribe_hash(info_hash, queue.put_nowait)
+        announcement = await queue.get()
+        peers = announcement.peers
         self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
-        self.assertEqual((await get_peer_list(info_hash, None, 4444, "127.0.0.1", 59900, stopped=True)).peers, [])
 
     async def test_error(self):
         info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
-        peer_id = random.getrandbits(160).to_bytes(20, "big", signed=False)
+        await self.client.get_peer_list(info_hash)
+        self.server.known_conns.clear()
         with self.assertRaises(Exception) as err:
-            announcement, _ = await self.client.announce(info_hash, peer_id, 4444, "127.0.0.1", 59900, connection_id=10)
+            await self.client.get_peer_list(info_hash)
         self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')