diff --git a/lbry/conf.py b/lbry/conf.py index 15fe5f8b6..814e9d744 100644 --- a/lbry/conf.py +++ b/lbry/conf.py @@ -681,6 +681,10 @@ class Config(CLIConfig): ('cdn.reflector.lbry.com', 5567) ]) + tracker_servers = Servers("BitTorrent-compatible (BEP15) UDP trackers for helping P2P discovery", [ + ('tracker.lbry.com', 1337) + ]) + lbryum_servers = Servers("SPV wallet servers", [ ('spv11.lbry.com', 50001), ('spv12.lbry.com', 50001), diff --git a/lbry/extras/daemon/components.py b/lbry/extras/daemon/components.py index 03bef1534..e061c4363 100644 --- a/lbry/extras/daemon/components.py +++ b/lbry/extras/daemon/components.py @@ -27,6 +27,8 @@ from lbry.extras.daemon.storage import SQLiteStorage from lbry.torrent.torrent_manager import TorrentManager from lbry.wallet import WalletManager from lbry.wallet.usage_payment import WalletServerPayer +from lbry.torrent.tracker import TrackerClient + try: from lbry.torrent.session import TorrentSession except ImportError: @@ -48,6 +50,7 @@ BACKGROUND_DOWNLOADER_COMPONENT = "background_downloader" PEER_PROTOCOL_SERVER_COMPONENT = "peer_protocol_server" UPNP_COMPONENT = "upnp" EXCHANGE_RATE_MANAGER_COMPONENT = "exchange_rate_manager" +TRACKER_ANNOUNCER_COMPONENT = "tracker_announcer_component" LIBTORRENT_COMPONENT = "libtorrent_component" @@ -708,3 +711,49 @@ class ExchangeRateManagerComponent(Component): async def stop(self): self.exchange_rate_manager.stop() + + +class TrackerAnnouncerComponent(Component): + component_name = TRACKER_ANNOUNCER_COMPONENT + depends_on = [FILE_MANAGER_COMPONENT] + + def __init__(self, component_manager): + super().__init__(component_manager) + self.file_manager = None + self.announce_task = None + self.tracker_client: typing.Optional[TrackerClient] = None + + @property + def component(self): + return self.tracker_client + + @property + def running(self): + return self._running and self.announce_task and not self.announce_task.done() + + async def announce_forever(self): + while True: + sleep_seconds = 60.0 + announce_sd_hashes = [] + for file in self.file_manager.get_filtered(): + if not file.downloader: + continue + announce_sd_hashes.append(bytes.fromhex(file.sd_hash)) + await self.tracker_client.announce_many(*announce_sd_hashes) + await asyncio.sleep(sleep_seconds) + + async def start(self): + node = self.component_manager.get_component(DHT_COMPONENT) \ + if self.component_manager.has_component(DHT_COMPONENT) else None + node_id = node.protocol.node_id if node else None + self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, lambda: self.conf.tracker_servers) + await self.tracker_client.start() + self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT) + self.announce_task = asyncio.create_task(self.announce_forever()) + + async def stop(self): + self.file_manager = None + if self.announce_task and not self.announce_task.done(): + self.announce_task.cancel() + self.announce_task = None + self.tracker_client.stop() diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index 6881889bc..c9983f756 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -44,7 +44,7 @@ from lbry.error import ( from lbry.extras import system_info from lbry.extras.daemon import analytics from lbry.extras.daemon.components import WALLET_COMPONENT, DATABASE_COMPONENT, DHT_COMPONENT, BLOB_COMPONENT -from lbry.extras.daemon.components import FILE_MANAGER_COMPONENT, DISK_SPACE_COMPONENT +from lbry.extras.daemon.components import FILE_MANAGER_COMPONENT, DISK_SPACE_COMPONENT, TRACKER_ANNOUNCER_COMPONENT from lbry.extras.daemon.components import EXCHANGE_RATE_MANAGER_COMPONENT, UPNP_COMPONENT from lbry.extras.daemon.componentmanager import RequiredCondition from lbry.extras.daemon.componentmanager import ComponentManager @@ -4949,7 +4949,6 @@ class Daemon(metaclass=JSONRPCServerType): DHT / Blob Exchange peer commands. """ - @requires(DHT_COMPONENT) async def jsonrpc_peer_list(self, blob_hash, page=None, page_size=None): """ Get peers for blob hash @@ -4971,21 +4970,29 @@ class Daemon(metaclass=JSONRPCServerType): if not is_valid_blobhash(blob_hash): # TODO: use error from lbry.error raise Exception("invalid blob hash") - peers = [] peer_q = asyncio.Queue(loop=self.component_manager.loop) - await self.dht_node._peers_for_value_producer(blob_hash, peer_q) + if self.component_manager.has_component(TRACKER_ANNOUNCER_COMPONENT): + tracker = self.component_manager.get_component(TRACKER_ANNOUNCER_COMPONENT) + tracker_peers = await tracker.get_kademlia_peer_list(bytes.fromhex(blob_hash)) + log.info("Found %d peers for %s from trackers.", len(tracker_peers), blob_hash[:8]) + peer_q.put_nowait(tracker_peers) + elif not self.component_manager.has_component(DHT_COMPONENT): + raise Exception("Peer list needs, at least, either a DHT component or a Tracker component for discovery.") + peers = [] + if self.component_manager.has_component(DHT_COMPONENT): + await self.dht_node._peers_for_value_producer(blob_hash, peer_q) while not peer_q.empty(): peers.extend(peer_q.get_nowait()) - results = [ - { - "node_id": hexlify(peer.node_id).decode(), + results = { + (peer.address, peer.tcp_port): { + "node_id": hexlify(peer.node_id).decode() if peer.node_id else None, "address": peer.address, "udp_port": peer.udp_port, "tcp_port": peer.tcp_port, } for peer in peers - ] - return paginate_list(results, page, page_size) + } + return paginate_list(list(results.values()), page, page_size) @requires(DATABASE_COMPONENT) async def jsonrpc_blob_announce(self, blob_hash=None, stream_hash=None, sd_hash=None): diff --git a/lbry/file/source.py b/lbry/file/source.py index b661eb594..0cded2f6c 100644 --- a/lbry/file/source.py +++ b/lbry/file/source.py @@ -45,6 +45,7 @@ class ManagedDownloadSource: self.purchase_receipt = None self._added_on = added_on self.analytics_manager = analytics_manager + self.downloader = None self.saving = asyncio.Event(loop=self.loop) self.finished_writing = asyncio.Event(loop=self.loop) diff --git a/lbry/stream/downloader.py b/lbry/stream/downloader.py index 0ef627248..1f78979b7 100644 --- a/lbry/stream/downloader.py +++ b/lbry/stream/downloader.py @@ -8,6 +8,8 @@ from lbry.error import DownloadSDTimeoutError from lbry.utils import lru_cache_concurrent from lbry.stream.descriptor import StreamDescriptor from lbry.blob_exchange.downloader import BlobDownloader +from lbry.torrent.tracker import enqueue_tracker_search + if typing.TYPE_CHECKING: from lbry.conf import Config from lbry.dht.node import Node @@ -91,6 +93,7 @@ class StreamDownloader: self.accumulate_task.cancel() _, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue) await self.add_fixed_peers() + enqueue_tracker_search(bytes.fromhex(self.sd_hash), self.peer_queue) # start searching for peers for the sd hash self.search_queue.put_nowait(self.sd_hash) log.info("searching for peers for stream %s", self.sd_hash) diff --git a/lbry/stream/managed_stream.py b/lbry/stream/managed_stream.py index 2a85da66e..5ceaff2ad 100644 --- a/lbry/stream/managed_stream.py +++ b/lbry/stream/managed_stream.py @@ -16,10 +16,8 @@ from lbry.file.source import ManagedDownloadSource if typing.TYPE_CHECKING: from lbry.conf import Config - from lbry.schema.claim import Claim from lbry.blob.blob_manager import BlobManager from lbry.blob.blob_info import BlobInfo - from lbry.dht.node import Node from lbry.extras.daemon.analytics import AnalyticsManager from lbry.wallet.transaction import Transaction diff --git a/lbry/torrent/tracker.py b/lbry/torrent/tracker.py new file mode 100644 index 000000000..82daa87f5 --- /dev/null +++ b/lbry/torrent/tracker.py @@ -0,0 +1,285 @@ +import random +import socket +import string +import struct +import asyncio +import logging +import time +import ipaddress +from collections import namedtuple +from functools import reduce +from typing import Optional + +from lbry.dht.node import get_kademlia_peers_from_hosts +from lbry.utils import resolve_host, async_timed_cache, cache_concurrent +from lbry.wallet.stream import StreamController +from lbry import version + +log = logging.getLogger(__name__) +CONNECTION_EXPIRES_AFTER_SECONDS = 50 +PREFIX = 'LB' # todo: PR BEP20 to add ourselves +DEFAULT_TIMEOUT_SECONDS = 10.0 +DEFAULT_CONCURRENCY_LIMIT = 100 +# see: http://bittorrent.org/beps/bep_0015.html and http://xbtt.sourceforge.net/udp_tracker_protocol.html +ConnectRequest = namedtuple("ConnectRequest", ["connection_id", "action", "transaction_id"]) +ConnectResponse = namedtuple("ConnectResponse", ["action", "transaction_id", "connection_id"]) +AnnounceRequest = namedtuple("AnnounceRequest", + ["connection_id", "action", "transaction_id", "info_hash", "peer_id", "downloaded", "left", + "uploaded", "event", "ip_addr", "key", "num_want", "port"]) +AnnounceResponse = namedtuple("AnnounceResponse", + ["action", "transaction_id", "interval", "leechers", "seeders", "peers"]) +CompactIPv4Peer = namedtuple("CompactPeer", ["address", "port"]) +ScrapeRequest = namedtuple("ScrapeRequest", ["connection_id", "action", "transaction_id", "infohashes"]) +ScrapeResponse = namedtuple("ScrapeResponse", ["action", "transaction_id", "items"]) +ScrapeResponseItem = namedtuple("ScrapeResponseItem", ["seeders", "completed", "leechers"]) +ErrorResponse = namedtuple("ErrorResponse", ["action", "transaction_id", "message"]) +structs = { + ConnectRequest: struct.Struct(">QII"), + ConnectResponse: struct.Struct(">IIQ"), + AnnounceRequest: struct.Struct(">QII20s20sQQQIIIiH"), + AnnounceResponse: struct.Struct(">IIIII"), + CompactIPv4Peer: struct.Struct(">IH"), + ScrapeRequest: struct.Struct(">QII"), + ScrapeResponse: struct.Struct(">II"), + ScrapeResponseItem: struct.Struct(">III"), + ErrorResponse: struct.Struct(">II") +} + + +def decode(cls, data, offset=0): + decoder = structs[cls] + if cls is AnnounceResponse: + return AnnounceResponse(*decoder.unpack_from(data, offset), + peers=[decode(CompactIPv4Peer, data, index) for index in range(20, len(data), 6)]) + elif cls is ScrapeResponse: + return ScrapeResponse(*decoder.unpack_from(data, offset), + items=[decode(ScrapeResponseItem, data, index) for index in range(8, len(data), 12)]) + elif cls is ErrorResponse: + return ErrorResponse(*decoder.unpack_from(data, offset), data[decoder.size:]) + return cls(*decoder.unpack_from(data, offset)) + + +def encode(obj): + if isinstance(obj, ScrapeRequest): + return structs[ScrapeRequest].pack(*obj[:-1]) + b''.join(obj.infohashes) + elif isinstance(obj, ErrorResponse): + return structs[ErrorResponse].pack(*obj[:-1]) + obj.message + elif isinstance(obj, AnnounceResponse): + return structs[AnnounceResponse].pack(*obj[:-1]) + b''.join([encode(peer) for peer in obj.peers]) + return structs[type(obj)].pack(*obj) + + +def make_peer_id(random_part: Optional[str] = None) -> bytes: + # see https://wiki.theory.org/BitTorrentSpecification#peer_id and https://www.bittorrent.org/beps/bep_0020.html + # not to confuse with node id; peer id identifies uniquely the software, version and instance + random_part = random_part or ''.join(random.choice(string.ascii_letters) for _ in range(20)) + return f"{PREFIX}-{'-'.join(map(str, version))}-{random_part}"[:20].encode() + + +class UDPTrackerClientProtocol(asyncio.DatagramProtocol): + def __init__(self, timeout: float = DEFAULT_TIMEOUT_SECONDS): + self.transport = None + self.data_queue = {} + self.timeout = timeout + self.semaphore = asyncio.Semaphore(DEFAULT_CONCURRENCY_LIMIT) + + def connection_made(self, transport: asyncio.DatagramTransport) -> None: + self.transport = transport + + async def request(self, obj, tracker_ip, tracker_port): + self.data_queue[obj.transaction_id] = asyncio.get_running_loop().create_future() + try: + async with self.semaphore: + self.transport.sendto(encode(obj), (tracker_ip, tracker_port)) + return await asyncio.wait_for(self.data_queue[obj.transaction_id], self.timeout) + finally: + self.data_queue.pop(obj.transaction_id, None) + + async def connect(self, tracker_ip, tracker_port): + transaction_id = random.getrandbits(32) + return decode(ConnectResponse, + await self.request(ConnectRequest(0x41727101980, 0, transaction_id), tracker_ip, tracker_port)) + + @cache_concurrent + @async_timed_cache(CONNECTION_EXPIRES_AFTER_SECONDS) + async def ensure_connection_id(self, peer_id, tracker_ip, tracker_port): + # peer_id is just to ensure cache coherency + return (await self.connect(tracker_ip, tracker_port)).connection_id + + async def announce(self, info_hash, peer_id, port, tracker_ip, tracker_port, stopped=False): + connection_id = await self.ensure_connection_id(peer_id, tracker_ip, tracker_port) + # this should make the key deterministic but unique per info hash + peer id + key = int.from_bytes(info_hash[:4], "big") ^ int.from_bytes(peer_id[:4], "big") ^ port + transaction_id = random.getrandbits(32) + req = AnnounceRequest( + connection_id, 1, transaction_id, info_hash, peer_id, 0, 0, 0, 3 if stopped else 1, 0, key, -1, port) + return decode(AnnounceResponse, await self.request(req, tracker_ip, tracker_port)) + + async def scrape(self, infohashes, tracker_ip, tracker_port, connection_id=None): + connection_id = await self.ensure_connection_id(None, tracker_ip, tracker_port) + transaction_id = random.getrandbits(32) + reply = await self.request( + ScrapeRequest(connection_id, 2, transaction_id, infohashes), tracker_ip, tracker_port) + return decode(ScrapeResponse, reply), connection_id + + def datagram_received(self, data: bytes, addr: (str, int)) -> None: + if len(data) < 8: + return + transaction_id = int.from_bytes(data[4:8], byteorder="big", signed=False) + if transaction_id in self.data_queue: + if not self.data_queue[transaction_id].done(): + if data[3] == 3: + return self.data_queue[transaction_id].set_exception(Exception(decode(ErrorResponse, data).message)) + return self.data_queue[transaction_id].set_result(data) + log.debug("unexpected packet (can be a response for a previously timed out request): %s", data.hex()) + + def connection_lost(self, exc: Exception = None) -> None: + self.transport = None + + +class TrackerClient: + event_controller = StreamController() + + def __init__(self, node_id, announce_port, get_servers, timeout=10.0): + self.client = UDPTrackerClientProtocol(timeout=timeout) + self.transport = None + self.peer_id = make_peer_id(node_id.hex() if node_id else None) + self.announce_port = announce_port + self._get_servers = get_servers + self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires + self.tasks = {} + + async def start(self): + self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint( + lambda: self.client, local_addr=("0.0.0.0", 0)) + self.event_controller.stream.listen( + lambda request: self.on_hash(request[1], request[2]) if request[0] == 'search' else None) + + def stop(self): + while self.tasks: + self.tasks.popitem()[1].cancel() + if self.transport is not None: + self.transport.close() + self.client = None + self.transport = None + self.event_controller.close() + + def on_hash(self, info_hash, on_announcement=None): + if info_hash not in self.tasks: + task = asyncio.create_task(self.get_peer_list(info_hash, on_announcement=on_announcement)) + task.add_done_callback(lambda *_: self.tasks.pop(info_hash, None)) + self.tasks[info_hash] = task + + async def announce_many(self, *info_hashes, stopped=False): + await asyncio.gather( + *[self._announce_many(server, info_hashes, stopped=stopped) for server in self._get_servers()], + return_exceptions=True) + + async def _announce_many(self, server, info_hashes, stopped=False): + tracker_ip = await resolve_host(*server, 'udp') + still_good_info_hashes = { + info_hash for (info_hash, (next_announcement, _)) in self.results.get(tracker_ip, {}).items() + if time.time() < next_announcement + } + results = await asyncio.gather( + *[self._probe_server(info_hash, tracker_ip, server[1], stopped=stopped) + for info_hash in info_hashes if info_hash not in still_good_info_hashes], + return_exceptions=True) + if results: + errors = sum([1 for result in results if result is None or isinstance(result, Exception)]) + log.info("Tracker: finished announcing %d files to %s:%d, %d errors", len(results), *server, errors) + + async def get_peer_list(self, info_hash, stopped=False, on_announcement=None, no_port=False): + found = [] + probes = [self._probe_server(info_hash, *server, stopped, no_port) for server in self._get_servers()] + for done in asyncio.as_completed(probes): + result = await done + if result is not None: + await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else [])) + found.append(result) + return found + + async def get_kademlia_peer_list(self, info_hash): + responses = await self.get_peer_list(info_hash, no_port=True) + return await announcement_to_kademlia_peers(*responses) + + async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False, no_port=False): + result = None + try: + tracker_host = await resolve_host(tracker_host, tracker_port, 'udp') + except socket.error: + log.warning("DNS failure while resolving tracker host: %s, skipping.", tracker_host) + return + self.results.setdefault(tracker_host, {}) + if info_hash in self.results[tracker_host]: + next_announcement, result = self.results[tracker_host][info_hash] + if time.time() < next_announcement: + return result + try: + result = await self.client.announce( + info_hash, self.peer_id, 0 if no_port else self.announce_port, tracker_host, tracker_port, stopped) + self.results[tracker_host][info_hash] = (time.time() + result.interval, result) + except asyncio.TimeoutError: # todo: this is UDP, timeout is common, we need a better metric for failures + self.results[tracker_host][info_hash] = (time.time() + 60.0, result) + log.debug("Tracker timed out: %s:%d", tracker_host, tracker_port) + return None + log.debug("Announced: %s found %d peers for %s", tracker_host, len(result.peers), info_hash.hex()[:8]) + return result + + +def enqueue_tracker_search(info_hash: bytes, peer_q: asyncio.Queue): + async def on_announcement(announcement: AnnounceResponse): + peers = await announcement_to_kademlia_peers(announcement) + log.info("Found %d peers from tracker for %s", len(peers), info_hash.hex()[:8]) + peer_q.put_nowait(peers) + TrackerClient.event_controller.add(('search', info_hash, on_announcement)) + + +def announcement_to_kademlia_peers(*announcements: AnnounceResponse): + peers = [ + (str(ipaddress.ip_address(peer.address)), peer.port) + for announcement in announcements for peer in announcement.peers if peer.port > 1024 # no privileged or 0 + ] + return get_kademlia_peers_from_hosts(peers) + + +class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not suitable for production + def __init__(self): + self.transport = None + self.known_conns = set() + self.peers = {} + + def connection_made(self, transport: asyncio.DatagramTransport) -> None: + self.transport = transport + + def add_peer(self, info_hash, ip_address: str, port: int): + self.peers.setdefault(info_hash, []) + self.peers[info_hash].append(encode_peer(ip_address, port)) + + def datagram_received(self, data: bytes, addr: (str, int)) -> None: + if len(data) < 16: + return + action = int.from_bytes(data[8:12], "big", signed=False) + if action == 0: + req = decode(ConnectRequest, data) + connection_id = random.getrandbits(32) + self.known_conns.add(connection_id) + return self.transport.sendto(encode(ConnectResponse(0, req.transaction_id, connection_id)), addr) + elif action == 1: + req = decode(AnnounceRequest, data) + if req.connection_id not in self.known_conns: + resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00')) + else: + compact_address = encode_peer(addr[0], req.port) + if req.event != 3: + self.add_peer(req.info_hash, addr[0], req.port) + elif compact_address in self.peers.get(req.info_hash, []): + self.peers[req.info_hash].remove(compact_address) + peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]] + resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers)) + return self.transport.sendto(resp, addr) + + +def encode_peer(ip_address: str, port: int): + compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), ip_address.split('.'), bytearray()) + return compact_ip + port.to_bytes(2, "big", signed=False) diff --git a/lbry/utils.py b/lbry/utils.py index dc3a6c06e..7a92ccc6a 100644 --- a/lbry/utils.py +++ b/lbry/utils.py @@ -131,21 +131,6 @@ def json_dumps_pretty(obj, **kwargs): return json.dumps(obj, sort_keys=True, indent=2, separators=(',', ': '), **kwargs) -def cancel_task(task: typing.Optional[asyncio.Task]): - if task and not task.done(): - task.cancel() - - -def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): - for task in tasks: - cancel_task(task) - - -def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): - while tasks: - cancel_task(tasks.pop()) - - def async_timed_cache(duration: int): def wrapper(func): cache: typing.Dict[typing.Tuple, diff --git a/tests/integration/datanetwork/test_file_commands.py b/tests/integration/datanetwork/test_file_commands.py index 08cf070c8..ffde6acc9 100644 --- a/tests/integration/datanetwork/test_file_commands.py +++ b/tests/integration/datanetwork/test_file_commands.py @@ -10,6 +10,7 @@ from lbry.stream.descriptor import StreamDescriptor from lbry.testcase import CommandTestCase from lbry.extras.daemon.components import TorrentSession, BACKGROUND_DOWNLOADER_COMPONENT from lbry.wallet import Transaction +from lbry.torrent.tracker import UDPTrackerServerProtocol class FileCommands(CommandTestCase): @@ -102,6 +103,32 @@ class FileCommands(CommandTestCase): await self.daemon.jsonrpc_get('lbry://foo') self.assertItemCount(await self.daemon.jsonrpc_file_list(), 1) + async def test_tracker_discovery(self): + port = 50990 + server = UDPTrackerServerProtocol() + transport, _ = await self.loop.create_datagram_endpoint(lambda: server, local_addr=("127.0.0.1", port)) + self.addCleanup(transport.close) + self.daemon.conf.fixed_peers = [] + self.daemon.conf.tracker_servers = [("127.0.0.1", port)] + tx = await self.stream_create('foo', '0.01') + sd_hash = tx['outputs'][0]['value']['source']['sd_hash'] + self.assertNotIn(bytes.fromhex(sd_hash)[:20], server.peers) + server.add_peer(bytes.fromhex(sd_hash)[:20], "127.0.0.1", 5567) + self.assertEqual(1, len(server.peers[bytes.fromhex(sd_hash)[:20]])) + self.assertTrue(await self.daemon.jsonrpc_file_delete(delete_all=True)) + stream = await self.daemon.jsonrpc_get('foo', save_file=True) + await self.wait_files_to_complete() + self.assertEqual(0, stream.blobs_remaining) + self.assertEqual(2, len(server.peers[bytes.fromhex(sd_hash)[:20]])) + self.assertEqual([{'address': '127.0.0.1', + 'node_id': None, + 'tcp_port': 5567, + 'udp_port': None}, + {'address': '127.0.0.1', + 'node_id': None, + 'tcp_port': 4444, + 'udp_port': None}], (await self.daemon.jsonrpc_peer_list(sd_hash))['items']) + async def test_announces(self): # announces on publish self.assertEqual(await self.daemon.storage.get_blobs_to_announce(), []) diff --git a/tests/integration/other/test_cli.py b/tests/integration/other/test_cli.py index ec28bbee3..924504a7e 100644 --- a/tests/integration/other/test_cli.py +++ b/tests/integration/other/test_cli.py @@ -10,7 +10,7 @@ from lbry.extras.daemon.components import ( DATABASE_COMPONENT, DISK_SPACE_COMPONENT, BLOB_COMPONENT, WALLET_COMPONENT, DHT_COMPONENT, HASH_ANNOUNCER_COMPONENT, FILE_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT, UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, WALLET_SERVER_PAYMENTS_COMPONENT, - LIBTORRENT_COMPONENT, BACKGROUND_DOWNLOADER_COMPONENT + LIBTORRENT_COMPONENT, BACKGROUND_DOWNLOADER_COMPONENT, TRACKER_ANNOUNCER_COMPONENT ) from lbry.extras.daemon.daemon import Daemon @@ -26,7 +26,7 @@ class CLIIntegrationTest(AsyncioTestCase): DATABASE_COMPONENT, DISK_SPACE_COMPONENT, BLOB_COMPONENT, WALLET_COMPONENT, DHT_COMPONENT, HASH_ANNOUNCER_COMPONENT, FILE_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT, UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, WALLET_SERVER_PAYMENTS_COMPONENT, - LIBTORRENT_COMPONENT, BACKGROUND_DOWNLOADER_COMPONENT + LIBTORRENT_COMPONENT, BACKGROUND_DOWNLOADER_COMPONENT, TRACKER_ANNOUNCER_COMPONENT ) Daemon.component_attributes = {} self.daemon = Daemon(conf) diff --git a/tests/unit/blob_exchange/test_transfer_blob.py b/tests/unit/blob_exchange/test_transfer_blob.py index a50a89e54..5943c3874 100644 --- a/tests/unit/blob_exchange/test_transfer_blob.py +++ b/tests/unit/blob_exchange/test_transfer_blob.py @@ -54,7 +54,8 @@ class BlobExchangeTestBase(AsyncioTestCase): download_dir=self.client_dir, wallet=self.client_wallet_dir, save_files=True, - fixed_peers=[] + fixed_peers=[], + tracker_servers=[] ) self.client_config.transaction_cache_size = 10000 self.client_storage = SQLiteStorage(self.client_config, os.path.join(self.client_dir, "lbrynet.sqlite")) diff --git a/tests/unit/components/test_component_manager.py b/tests/unit/components/test_component_manager.py index a237c1ac8..269591d8e 100644 --- a/tests/unit/components/test_component_manager.py +++ b/tests/unit/components/test_component_manager.py @@ -34,6 +34,7 @@ class TestComponentManager(AsyncioTestCase): ], [ components.BackgroundDownloaderComponent, + components.TrackerAnnouncerComponent ] ] self.component_manager = ComponentManager(Config()) @@ -150,6 +151,9 @@ class FakeDelayedFileManager(FakeComponent): async def start(self): await asyncio.sleep(1) + def get_filtered(self): + return [] + class TestComponentManagerProperStart(AdvanceTimeTestCase): diff --git a/tests/unit/lbrynet_daemon/test_allowed_origin.py b/tests/unit/lbrynet_daemon/test_allowed_origin.py index b70fe6d3f..d7e31ea4d 100644 --- a/tests/unit/lbrynet_daemon/test_allowed_origin.py +++ b/tests/unit/lbrynet_daemon/test_allowed_origin.py @@ -11,7 +11,7 @@ from lbry.extras.daemon.components import ( DATABASE_COMPONENT, DISK_SPACE_COMPONENT, BLOB_COMPONENT, WALLET_COMPONENT, DHT_COMPONENT, HASH_ANNOUNCER_COMPONENT, FILE_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT, UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, WALLET_SERVER_PAYMENTS_COMPONENT, - LIBTORRENT_COMPONENT, BACKGROUND_DOWNLOADER_COMPONENT + LIBTORRENT_COMPONENT, BACKGROUND_DOWNLOADER_COMPONENT, TRACKER_ANNOUNCER_COMPONENT ) from lbry.extras.daemon.daemon import Daemon @@ -72,7 +72,7 @@ class TestAccessHeaders(AsyncioTestCase): DATABASE_COMPONENT, DISK_SPACE_COMPONENT, BLOB_COMPONENT, WALLET_COMPONENT, DHT_COMPONENT, HASH_ANNOUNCER_COMPONENT, FILE_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT, UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, WALLET_SERVER_PAYMENTS_COMPONENT, - LIBTORRENT_COMPONENT, BACKGROUND_DOWNLOADER_COMPONENT + LIBTORRENT_COMPONENT, BACKGROUND_DOWNLOADER_COMPONENT, TRACKER_ANNOUNCER_COMPONENT ) Daemon.component_attributes = {} self.daemon = Daemon(conf) diff --git a/tests/unit/stream/test_managed_stream.py b/tests/unit/stream/test_managed_stream.py index 3542c60e4..953eb4c83 100644 --- a/tests/unit/stream/test_managed_stream.py +++ b/tests/unit/stream/test_managed_stream.py @@ -13,13 +13,6 @@ from lbry.stream.descriptor import StreamDescriptor from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase -def get_mock_node(loop): - mock_node = mock.Mock(spec=Node) - mock_node.joined = asyncio.Event(loop=loop) - mock_node.joined.set() - return mock_node - - class TestManagedStream(BlobExchangeTestBase): async def create_stream(self, blob_count: int = 10, file_name='test_file'): self.stream_bytes = b'' diff --git a/tests/unit/torrent/__init__.py b/tests/unit/torrent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/torrent/test_tracker.py b/tests/unit/torrent/test_tracker.py new file mode 100644 index 000000000..32e4846a1 --- /dev/null +++ b/tests/unit/torrent/test_tracker.py @@ -0,0 +1,92 @@ +import asyncio +import random + +from lbry.testcase import AsyncioTestCase +from lbry.dht.peer import KademliaPeer +from lbry.torrent.tracker import CompactIPv4Peer, TrackerClient, enqueue_tracker_search, UDPTrackerServerProtocol, encode_peer + + +class UDPTrackerClientTestCase(AsyncioTestCase): + async def asyncSetUp(self): + self.client_servers_list = [] + self.servers = {} + self.client = TrackerClient(b"\x00" * 48, 4444, lambda: self.client_servers_list, timeout=1) + await self.client.start() + self.addCleanup(self.client.stop) + await self.add_server() + + async def add_server(self, port=None, add_to_client=True): + port = port or len(self.servers) + 59990 + assert port not in self.servers + server = UDPTrackerServerProtocol() + self.servers[port] = server + transport, _ = await self.loop.create_datagram_endpoint(lambda: server, local_addr=("127.0.0.1", port)) + self.addCleanup(transport.close) + if add_to_client: + self.client_servers_list.append(("127.0.0.1", port)) + + async def test_announce(self): + info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) + announcement = (await self.client.get_peer_list(info_hash))[0] + self.assertEqual(announcement.seeders, 1) + self.assertEqual(announcement.peers, + [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)]) + + async def test_announce_many_info_hashes_to_many_servers_with_bad_one_and_dns_error(self): + await asyncio.gather(*[self.add_server() for _ in range(3)]) + self.client_servers_list.append(("no.it.does.not.exist", 7070)) + self.client_servers_list.append(("127.0.0.2", 7070)) + info_hashes = [random.getrandbits(160).to_bytes(20, "big", signed=False) for _ in range(5)] + await self.client.announce_many(*info_hashes) + for server in self.servers.values(): + self.assertDictEqual( + server.peers, { + info_hash: [encode_peer("127.0.0.1", self.client.announce_port)] for info_hash in info_hashes + }) + + async def test_announce_using_helper_function(self): + info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) + queue = asyncio.Queue() + enqueue_tracker_search(info_hash, queue) + peers = await queue.get() + self.assertEqual(peers, [KademliaPeer('127.0.0.1', None, None, 4444, allow_localhost=True)]) + + async def test_error(self): + info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) + await self.client.get_peer_list(info_hash) + list(self.servers.values())[0].known_conns.clear() + self.client.results.clear() + with self.assertRaises(Exception) as err: + await self.client.get_peer_list(info_hash) + self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00') + + async def test_multiple_servers(self): + await asyncio.gather(*[self.add_server() for _ in range(10)]) + info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) + await self.client.get_peer_list(info_hash) + for server in self.servers.values(): + self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]}) + + async def test_multiple_servers_with_bad_one(self): + await asyncio.gather(*[self.add_server() for _ in range(10)]) + self.client_servers_list.append(("127.0.0.2", 7070)) + info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) + await self.client.get_peer_list(info_hash) + for server in self.servers.values(): + self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]}) + + async def test_multiple_servers_with_different_peers_across_helper_function(self): + # this is how the downloader uses it + await asyncio.gather(*[self.add_server() for _ in range(10)]) + info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) + fake_peers = [] + for server in self.servers.values(): + for _ in range(10): + peer = (f"127.0.0.{random.randint(1, 255)}", random.randint(2000, 65500)) + fake_peers.append(peer) + server.add_peer(info_hash, *peer) + peer_q = asyncio.Queue() + enqueue_tracker_search(info_hash, peer_q) + await asyncio.sleep(0) + await asyncio.gather(*self.client.tasks.values()) + self.assertEqual(11, peer_q.qsize())