forked from LBRYCommunity/lbry-sdk
evented system for tracker announcements
This commit is contained in:
parent
7acaecaed2
commit
9e9a64d989
4 changed files with 89 additions and 63 deletions
|
@ -28,6 +28,8 @@ from lbry.extras.daemon.storage import SQLiteStorage
|
|||
from lbry.torrent.torrent_manager import TorrentManager
|
||||
from lbry.wallet import WalletManager
|
||||
from lbry.wallet.usage_payment import WalletServerPayer
|
||||
from lbry.torrent.tracker import TrackerClient
|
||||
|
||||
try:
|
||||
from lbry.torrent.session import TorrentSession
|
||||
except ImportError:
|
||||
|
@ -720,6 +722,7 @@ class TrackerAnnouncerComponent(Component):
|
|||
super().__init__(component_manager)
|
||||
self.file_manager = None
|
||||
self.announce_task = None
|
||||
self.tracker_client: typing.Optional[TrackerClient] = None
|
||||
|
||||
@property
|
||||
def component(self):
|
||||
|
@ -733,12 +736,15 @@ class TrackerAnnouncerComponent(Component):
|
|||
continue
|
||||
next_announce = file.downloader.next_tracker_announce_time
|
||||
if next_announce is None or next_announce <= time.time():
|
||||
await file.downloader.refresh_from_trackers(False)
|
||||
else:
|
||||
to_sleep = min(to_sleep, next_announce - time.time())
|
||||
self.tracker_client.on_hash(bytes.fromhex(file.sd_hash))
|
||||
await asyncio.sleep(to_sleep + 1)
|
||||
|
||||
async def start(self):
|
||||
node = self.component_manager.get_component(DHT_COMPONENT) \
|
||||
if self.component_manager.has_component(DHT_COMPONENT) else None
|
||||
node_id = node.protocol.node_id if node else None
|
||||
self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, self.conf.tracker_servers)
|
||||
await self.tracker_client.start()
|
||||
self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT)
|
||||
self.announce_task = asyncio.create_task(self.announce_forever())
|
||||
|
||||
|
|
|
@ -10,9 +10,10 @@ from lbry.error import DownloadSDTimeoutError
|
|||
from lbry.utils import lru_cache_concurrent
|
||||
from lbry.stream.descriptor import StreamDescriptor
|
||||
from lbry.blob_exchange.downloader import BlobDownloader
|
||||
from lbry.torrent.tracker import get_peer_list
|
||||
from lbry.torrent.tracker import subscribe_hash
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbry.torrent.tracker import AnnounceResponse
|
||||
from lbry.conf import Config
|
||||
from lbry.dht.node import Node
|
||||
from lbry.blob.blob_manager import BlobManager
|
||||
|
@ -67,32 +68,13 @@ class StreamDownloader:
|
|||
fixed_peers = await get_kademlia_peers_from_hosts(self.config.fixed_peers)
|
||||
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _add_fixed_peers, fixed_peers)
|
||||
|
||||
async def refresh_from_trackers(self, save_peers=True):
|
||||
if not self.config.tracker_servers:
|
||||
return
|
||||
node_id = self.node.protocol.node_id if self.node else None
|
||||
port = self.config.tcp_port
|
||||
for server in self.config.tracker_servers:
|
||||
try:
|
||||
announcement = await get_peer_list(
|
||||
bytes.fromhex(self.sd_hash)[:20], node_id, port, server[0], server[1])
|
||||
log.info("Announced %s to %s", self.sd_hash[:8], server)
|
||||
self.next_tracker_announce_time = max(self.next_tracker_announce_time or 0,
|
||||
time.time() + announcement.interval)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
log.warning("Tracker timed out: %s", server)
|
||||
return
|
||||
except Exception:
|
||||
log.exception("Unexpected error querying tracker %s", server)
|
||||
return
|
||||
if not save_peers:
|
||||
return
|
||||
peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers]
|
||||
peers = await get_kademlia_peers_from_hosts(peers)
|
||||
log.info("Found %d peers from tracker %s for %s", len(peers), server, self.sd_hash[:8])
|
||||
self.peer_queue.put_nowait(peers)
|
||||
async def _process_announcement(self, announcement: 'AnnounceResponse'):
|
||||
peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers]
|
||||
peers = await get_kademlia_peers_from_hosts(peers)
|
||||
log.info("Found %d peers from tracker for %s", len(peers), self.sd_hash[:8])
|
||||
self.next_tracker_announce_time = min(time() + announcement.interval,
|
||||
self.next_tracker_announce_time or 1 << 64)
|
||||
self.peer_queue.put_nowait(peers)
|
||||
|
||||
async def load_descriptor(self, connection_id: int = 0):
|
||||
# download or get the sd blob
|
||||
|
@ -123,7 +105,7 @@ class StreamDownloader:
|
|||
self.accumulate_task.cancel()
|
||||
_, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
|
||||
await self.add_fixed_peers()
|
||||
asyncio.ensure_future(self.refresh_from_trackers())
|
||||
subscribe_hash(self.sd_hash, self._process_announcement)
|
||||
# start searching for peers for the sd hash
|
||||
self.search_queue.put_nowait(self.sd_hash)
|
||||
log.info("searching for peers for stream %s", self.sd_hash)
|
||||
|
|
|
@ -4,9 +4,11 @@ import asyncio
|
|||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from lbry.utils import resolve_host
|
||||
from lbry.utils import resolve_host, async_timed_cache
|
||||
from lbry.wallet.stream import StreamController
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
CONNECTION_EXPIRES_AFTER_SECONDS = 360
|
||||
# see: http://bittorrent.org/beps/bep_0015.html and http://xbtt.sourceforge.net/udp_tracker_protocol.html
|
||||
ConnectRequest = namedtuple("ConnectRequest", ["connection_id", "action", "transaction_id"])
|
||||
ConnectResponse = namedtuple("ConnectResponse", ["action", "transaction_id", "connection_id"])
|
||||
|
@ -77,17 +79,19 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
|
|||
return decode(ConnectResponse,
|
||||
await self.request(ConnectRequest(0x41727101980, 0, transaction_id), tracker_ip, tracker_port))
|
||||
|
||||
async def announce(self, info_hash, peer_id, port, tracker_ip, tracker_port, connection_id=None, stopped=False):
|
||||
if not connection_id:
|
||||
reply = await self.connect(tracker_ip, tracker_port)
|
||||
connection_id = reply.connection_id
|
||||
@async_timed_cache(CONNECTION_EXPIRES_AFTER_SECONDS)
|
||||
async def ensure_connection_id(self, peer_id, tracker_ip, tracker_port):
|
||||
# peer_id is just to ensure cache coherency
|
||||
return (await self.connect(tracker_ip, tracker_port)).connection_id
|
||||
|
||||
async def announce(self, info_hash, peer_id, port, tracker_ip, tracker_port, stopped=False):
|
||||
connection_id = await self.ensure_connection_id(peer_id, tracker_ip, tracker_port)
|
||||
# this should make the key deterministic but unique per info hash + peer id
|
||||
key = int.from_bytes(info_hash[:4], "big") ^ int.from_bytes(peer_id[:4], "big") ^ port
|
||||
transaction_id = random.getrandbits(32)
|
||||
req = AnnounceRequest(
|
||||
connection_id, 1, transaction_id, info_hash, peer_id, 0, 0, 0, 3 if stopped else 1, 0, key, -1, port)
|
||||
reply = await self.request(req, tracker_ip, tracker_port)
|
||||
return decode(AnnounceResponse, reply), connection_id
|
||||
return decode(AnnounceResponse, await self.request(req, tracker_ip, tracker_port))
|
||||
|
||||
async def scrape(self, infohashes, tracker_ip, tracker_port, connection_id=None):
|
||||
if not connection_id:
|
||||
|
@ -107,20 +111,52 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
|
|||
if data[3] == 3:
|
||||
return self.data_queue[transaction_id].set_exception(Exception(decode(ErrorResponse, data).message))
|
||||
return self.data_queue[transaction_id].set_result(data)
|
||||
print("error", data.hex())
|
||||
log.debug("unexpected packet (can be a response for a previously timed out request): %s", data.hex())
|
||||
|
||||
def connection_lost(self, exc: Exception = None) -> None:
|
||||
self.transport = None
|
||||
|
||||
|
||||
async def get_peer_list(info_hash, node_id, port, tracker_ip, tracker_port, stopped=False):
|
||||
node_id = node_id or random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
tracker_ip = await resolve_host(tracker_ip, tracker_port, 'udp')
|
||||
proto = UDPTrackerClientProtocol()
|
||||
transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(lambda: proto, local_addr=("0.0.0.0", 0))
|
||||
try:
|
||||
reply, _ = await proto.announce(info_hash, node_id, port, tracker_ip, tracker_port, stopped=stopped)
|
||||
return reply
|
||||
finally:
|
||||
if not transport.is_closing():
|
||||
transport.close()
|
||||
class TrackerClient:
|
||||
EVENT_CONTROLLER = StreamController()
|
||||
def __init__(self, node_id, announce_port, servers):
|
||||
self.client = UDPTrackerClientProtocol()
|
||||
self.transport = None
|
||||
self.node_id = node_id or random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
self.announce_port = announce_port
|
||||
self.servers = servers
|
||||
|
||||
async def start(self):
|
||||
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||
lambda: self.client, local_addr=("0.0.0.0", 0))
|
||||
self.EVENT_CONTROLLER.stream.listen(lambda request: self.on_hash(request[1]) if request[0] == 'search' else None)
|
||||
|
||||
def stop(self):
|
||||
if self.transport is not None:
|
||||
self.transport.close()
|
||||
self.client = None
|
||||
self.transport = None
|
||||
self.EVENT_CONTROLLER.close()
|
||||
|
||||
def on_hash(self, info_hash):
|
||||
asyncio.ensure_future(self.get_peer_list(info_hash))
|
||||
|
||||
async def get_peer_list(self, info_hash, stopped=False):
|
||||
found = []
|
||||
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
|
||||
found.append(await done)
|
||||
return found
|
||||
|
||||
async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False):
|
||||
tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp')
|
||||
result = await self.client.announce(
|
||||
info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped)
|
||||
log.info("Announced to tracker. Found %d peers for %s on %s",
|
||||
len(result.peers), info_hash.hex()[:8], tracker_host)
|
||||
self.EVENT_CONTROLLER.add((info_hash, result))
|
||||
return result
|
||||
|
||||
|
||||
def subscribe_hash(hash, on_data):
|
||||
TrackerClient.EVENT_CONTROLLER.add(('search', bytes.fromhex(hash)))
|
||||
TrackerClient.EVENT_CONTROLLER.stream.listen(lambda request: on_data(request[1]) if request[0] == hash else None)
|
||||
|
|
|
@ -3,8 +3,8 @@ import random
|
|||
from functools import reduce
|
||||
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
from lbry.torrent.tracker import UDPTrackerClientProtocol, encode, decode, CompactIPv4Peer, ConnectRequest, \
|
||||
ConnectResponse, AnnounceRequest, ErrorResponse, AnnounceResponse, get_peer_list
|
||||
from lbry.torrent.tracker import encode, decode, CompactIPv4Peer, ConnectRequest, \
|
||||
ConnectResponse, AnnounceRequest, ErrorResponse, AnnounceResponse, TrackerClient, subscribe_hash
|
||||
|
||||
|
||||
class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not suitable for production
|
||||
|
@ -44,30 +44,32 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su
|
|||
|
||||
class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
transport, _ = await self.loop.create_datagram_endpoint(UDPTrackerServerProtocol, local_addr=("127.0.0.1", 59900))
|
||||
self.addCleanup(transport.close)
|
||||
self.client = UDPTrackerClientProtocol()
|
||||
transport, _ = await self.loop.create_datagram_endpoint(lambda: self.client, local_addr=("127.0.0.1", 0))
|
||||
self.server = UDPTrackerServerProtocol()
|
||||
transport, _ = await self.loop.create_datagram_endpoint(lambda: self.server, local_addr=("127.0.0.1", 59900))
|
||||
self.addCleanup(transport.close)
|
||||
self.client = TrackerClient(b"\x00" * 48, 4444, [("127.0.0.1", 59900)])
|
||||
await self.client.start()
|
||||
self.addCleanup(self.client.stop)
|
||||
|
||||
async def test_announce(self):
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
peer_id = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
announcement, _ = await self.client.announce(info_hash, peer_id, 4444, "127.0.0.1", 59900)
|
||||
announcement = (await self.client.get_peer_list(info_hash))[0]
|
||||
self.assertEqual(announcement.seeders, 1)
|
||||
self.assertEqual(announcement.peers,
|
||||
[CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
||||
|
||||
async def test_announce_using_helper_function(self):
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
announcemenet = await get_peer_list(info_hash, None, 4444, "127.0.0.1", 59900)
|
||||
peers = announcemenet.peers
|
||||
queue = asyncio.Queue()
|
||||
subscribe_hash(info_hash, queue.put_nowait)
|
||||
announcement = await queue.get()
|
||||
peers = announcement.peers
|
||||
self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
||||
self.assertEqual((await get_peer_list(info_hash, None, 4444, "127.0.0.1", 59900, stopped=True)).peers, [])
|
||||
|
||||
async def test_error(self):
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
peer_id = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
await self.client.get_peer_list(info_hash)
|
||||
self.server.known_conns.clear()
|
||||
with self.assertRaises(Exception) as err:
|
||||
announcement, _ = await self.client.announce(info_hash, peer_id, 4444, "127.0.0.1", 59900, connection_id=10)
|
||||
await self.client.get_peer_list(info_hash)
|
||||
self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')
|
||||
|
|
Loading…
Reference in a new issue