evented system for tracker announcements
This commit is contained in:
parent
7acaecaed2
commit
9e9a64d989
4 changed files with 89 additions and 63 deletions
|
@ -28,6 +28,8 @@ from lbry.extras.daemon.storage import SQLiteStorage
|
||||||
from lbry.torrent.torrent_manager import TorrentManager
|
from lbry.torrent.torrent_manager import TorrentManager
|
||||||
from lbry.wallet import WalletManager
|
from lbry.wallet import WalletManager
|
||||||
from lbry.wallet.usage_payment import WalletServerPayer
|
from lbry.wallet.usage_payment import WalletServerPayer
|
||||||
|
from lbry.torrent.tracker import TrackerClient
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from lbry.torrent.session import TorrentSession
|
from lbry.torrent.session import TorrentSession
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -720,6 +722,7 @@ class TrackerAnnouncerComponent(Component):
|
||||||
super().__init__(component_manager)
|
super().__init__(component_manager)
|
||||||
self.file_manager = None
|
self.file_manager = None
|
||||||
self.announce_task = None
|
self.announce_task = None
|
||||||
|
self.tracker_client: typing.Optional[TrackerClient] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def component(self):
|
def component(self):
|
||||||
|
@ -733,12 +736,15 @@ class TrackerAnnouncerComponent(Component):
|
||||||
continue
|
continue
|
||||||
next_announce = file.downloader.next_tracker_announce_time
|
next_announce = file.downloader.next_tracker_announce_time
|
||||||
if next_announce is None or next_announce <= time.time():
|
if next_announce is None or next_announce <= time.time():
|
||||||
await file.downloader.refresh_from_trackers(False)
|
self.tracker_client.on_hash(bytes.fromhex(file.sd_hash))
|
||||||
else:
|
|
||||||
to_sleep = min(to_sleep, next_announce - time.time())
|
|
||||||
await asyncio.sleep(to_sleep + 1)
|
await asyncio.sleep(to_sleep + 1)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
node = self.component_manager.get_component(DHT_COMPONENT) \
|
||||||
|
if self.component_manager.has_component(DHT_COMPONENT) else None
|
||||||
|
node_id = node.protocol.node_id if node else None
|
||||||
|
self.tracker_client = TrackerClient(node_id, self.conf.tcp_port, self.conf.tracker_servers)
|
||||||
|
await self.tracker_client.start()
|
||||||
self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT)
|
self.file_manager = self.component_manager.get_component(FILE_MANAGER_COMPONENT)
|
||||||
self.announce_task = asyncio.create_task(self.announce_forever())
|
self.announce_task = asyncio.create_task(self.announce_forever())
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,10 @@ from lbry.error import DownloadSDTimeoutError
|
||||||
from lbry.utils import lru_cache_concurrent
|
from lbry.utils import lru_cache_concurrent
|
||||||
from lbry.stream.descriptor import StreamDescriptor
|
from lbry.stream.descriptor import StreamDescriptor
|
||||||
from lbry.blob_exchange.downloader import BlobDownloader
|
from lbry.blob_exchange.downloader import BlobDownloader
|
||||||
from lbry.torrent.tracker import get_peer_list
|
from lbry.torrent.tracker import subscribe_hash
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
|
from lbry.torrent.tracker import AnnounceResponse
|
||||||
from lbry.conf import Config
|
from lbry.conf import Config
|
||||||
from lbry.dht.node import Node
|
from lbry.dht.node import Node
|
||||||
from lbry.blob.blob_manager import BlobManager
|
from lbry.blob.blob_manager import BlobManager
|
||||||
|
@ -67,32 +68,13 @@ class StreamDownloader:
|
||||||
fixed_peers = await get_kademlia_peers_from_hosts(self.config.fixed_peers)
|
fixed_peers = await get_kademlia_peers_from_hosts(self.config.fixed_peers)
|
||||||
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _add_fixed_peers, fixed_peers)
|
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _add_fixed_peers, fixed_peers)
|
||||||
|
|
||||||
async def refresh_from_trackers(self, save_peers=True):
|
async def _process_announcement(self, announcement: 'AnnounceResponse'):
|
||||||
if not self.config.tracker_servers:
|
peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers]
|
||||||
return
|
peers = await get_kademlia_peers_from_hosts(peers)
|
||||||
node_id = self.node.protocol.node_id if self.node else None
|
log.info("Found %d peers from tracker for %s", len(peers), self.sd_hash[:8])
|
||||||
port = self.config.tcp_port
|
self.next_tracker_announce_time = min(time() + announcement.interval,
|
||||||
for server in self.config.tracker_servers:
|
self.next_tracker_announce_time or 1 << 64)
|
||||||
try:
|
self.peer_queue.put_nowait(peers)
|
||||||
announcement = await get_peer_list(
|
|
||||||
bytes.fromhex(self.sd_hash)[:20], node_id, port, server[0], server[1])
|
|
||||||
log.info("Announced %s to %s", self.sd_hash[:8], server)
|
|
||||||
self.next_tracker_announce_time = max(self.next_tracker_announce_time or 0,
|
|
||||||
time.time() + announcement.interval)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
raise
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
log.warning("Tracker timed out: %s", server)
|
|
||||||
return
|
|
||||||
except Exception:
|
|
||||||
log.exception("Unexpected error querying tracker %s", server)
|
|
||||||
return
|
|
||||||
if not save_peers:
|
|
||||||
return
|
|
||||||
peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers]
|
|
||||||
peers = await get_kademlia_peers_from_hosts(peers)
|
|
||||||
log.info("Found %d peers from tracker %s for %s", len(peers), server, self.sd_hash[:8])
|
|
||||||
self.peer_queue.put_nowait(peers)
|
|
||||||
|
|
||||||
async def load_descriptor(self, connection_id: int = 0):
|
async def load_descriptor(self, connection_id: int = 0):
|
||||||
# download or get the sd blob
|
# download or get the sd blob
|
||||||
|
@ -123,7 +105,7 @@ class StreamDownloader:
|
||||||
self.accumulate_task.cancel()
|
self.accumulate_task.cancel()
|
||||||
_, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
|
_, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
|
||||||
await self.add_fixed_peers()
|
await self.add_fixed_peers()
|
||||||
asyncio.ensure_future(self.refresh_from_trackers())
|
subscribe_hash(self.sd_hash, self._process_announcement)
|
||||||
# start searching for peers for the sd hash
|
# start searching for peers for the sd hash
|
||||||
self.search_queue.put_nowait(self.sd_hash)
|
self.search_queue.put_nowait(self.sd_hash)
|
||||||
log.info("searching for peers for stream %s", self.sd_hash)
|
log.info("searching for peers for stream %s", self.sd_hash)
|
||||||
|
|
|
@ -4,9 +4,11 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from lbry.utils import resolve_host
|
from lbry.utils import resolve_host, async_timed_cache
|
||||||
|
from lbry.wallet.stream import StreamController
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
CONNECTION_EXPIRES_AFTER_SECONDS = 360
|
||||||
# see: http://bittorrent.org/beps/bep_0015.html and http://xbtt.sourceforge.net/udp_tracker_protocol.html
|
# see: http://bittorrent.org/beps/bep_0015.html and http://xbtt.sourceforge.net/udp_tracker_protocol.html
|
||||||
ConnectRequest = namedtuple("ConnectRequest", ["connection_id", "action", "transaction_id"])
|
ConnectRequest = namedtuple("ConnectRequest", ["connection_id", "action", "transaction_id"])
|
||||||
ConnectResponse = namedtuple("ConnectResponse", ["action", "transaction_id", "connection_id"])
|
ConnectResponse = namedtuple("ConnectResponse", ["action", "transaction_id", "connection_id"])
|
||||||
|
@ -77,17 +79,19 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
|
||||||
return decode(ConnectResponse,
|
return decode(ConnectResponse,
|
||||||
await self.request(ConnectRequest(0x41727101980, 0, transaction_id), tracker_ip, tracker_port))
|
await self.request(ConnectRequest(0x41727101980, 0, transaction_id), tracker_ip, tracker_port))
|
||||||
|
|
||||||
async def announce(self, info_hash, peer_id, port, tracker_ip, tracker_port, connection_id=None, stopped=False):
|
@async_timed_cache(CONNECTION_EXPIRES_AFTER_SECONDS)
|
||||||
if not connection_id:
|
async def ensure_connection_id(self, peer_id, tracker_ip, tracker_port):
|
||||||
reply = await self.connect(tracker_ip, tracker_port)
|
# peer_id is just to ensure cache coherency
|
||||||
connection_id = reply.connection_id
|
return (await self.connect(tracker_ip, tracker_port)).connection_id
|
||||||
|
|
||||||
|
async def announce(self, info_hash, peer_id, port, tracker_ip, tracker_port, stopped=False):
|
||||||
|
connection_id = await self.ensure_connection_id(peer_id, tracker_ip, tracker_port)
|
||||||
# this should make the key deterministic but unique per info hash + peer id
|
# this should make the key deterministic but unique per info hash + peer id
|
||||||
key = int.from_bytes(info_hash[:4], "big") ^ int.from_bytes(peer_id[:4], "big") ^ port
|
key = int.from_bytes(info_hash[:4], "big") ^ int.from_bytes(peer_id[:4], "big") ^ port
|
||||||
transaction_id = random.getrandbits(32)
|
transaction_id = random.getrandbits(32)
|
||||||
req = AnnounceRequest(
|
req = AnnounceRequest(
|
||||||
connection_id, 1, transaction_id, info_hash, peer_id, 0, 0, 0, 3 if stopped else 1, 0, key, -1, port)
|
connection_id, 1, transaction_id, info_hash, peer_id, 0, 0, 0, 3 if stopped else 1, 0, key, -1, port)
|
||||||
reply = await self.request(req, tracker_ip, tracker_port)
|
return decode(AnnounceResponse, await self.request(req, tracker_ip, tracker_port))
|
||||||
return decode(AnnounceResponse, reply), connection_id
|
|
||||||
|
|
||||||
async def scrape(self, infohashes, tracker_ip, tracker_port, connection_id=None):
|
async def scrape(self, infohashes, tracker_ip, tracker_port, connection_id=None):
|
||||||
if not connection_id:
|
if not connection_id:
|
||||||
|
@ -107,20 +111,52 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
|
||||||
if data[3] == 3:
|
if data[3] == 3:
|
||||||
return self.data_queue[transaction_id].set_exception(Exception(decode(ErrorResponse, data).message))
|
return self.data_queue[transaction_id].set_exception(Exception(decode(ErrorResponse, data).message))
|
||||||
return self.data_queue[transaction_id].set_result(data)
|
return self.data_queue[transaction_id].set_result(data)
|
||||||
print("error", data.hex())
|
log.debug("unexpected packet (can be a response for a previously timed out request): %s", data.hex())
|
||||||
|
|
||||||
def connection_lost(self, exc: Exception = None) -> None:
|
def connection_lost(self, exc: Exception = None) -> None:
|
||||||
self.transport = None
|
self.transport = None
|
||||||
|
|
||||||
|
|
||||||
async def get_peer_list(info_hash, node_id, port, tracker_ip, tracker_port, stopped=False):
|
class TrackerClient:
|
||||||
node_id = node_id or random.getrandbits(160).to_bytes(20, "big", signed=False)
|
EVENT_CONTROLLER = StreamController()
|
||||||
tracker_ip = await resolve_host(tracker_ip, tracker_port, 'udp')
|
def __init__(self, node_id, announce_port, servers):
|
||||||
proto = UDPTrackerClientProtocol()
|
self.client = UDPTrackerClientProtocol()
|
||||||
transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(lambda: proto, local_addr=("0.0.0.0", 0))
|
self.transport = None
|
||||||
try:
|
self.node_id = node_id or random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
reply, _ = await proto.announce(info_hash, node_id, port, tracker_ip, tracker_port, stopped=stopped)
|
self.announce_port = announce_port
|
||||||
return reply
|
self.servers = servers
|
||||||
finally:
|
|
||||||
if not transport.is_closing():
|
async def start(self):
|
||||||
transport.close()
|
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||||
|
lambda: self.client, local_addr=("0.0.0.0", 0))
|
||||||
|
self.EVENT_CONTROLLER.stream.listen(lambda request: self.on_hash(request[1]) if request[0] == 'search' else None)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
if self.transport is not None:
|
||||||
|
self.transport.close()
|
||||||
|
self.client = None
|
||||||
|
self.transport = None
|
||||||
|
self.EVENT_CONTROLLER.close()
|
||||||
|
|
||||||
|
def on_hash(self, info_hash):
|
||||||
|
asyncio.ensure_future(self.get_peer_list(info_hash))
|
||||||
|
|
||||||
|
async def get_peer_list(self, info_hash, stopped=False):
|
||||||
|
found = []
|
||||||
|
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
|
||||||
|
found.append(await done)
|
||||||
|
return found
|
||||||
|
|
||||||
|
async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False):
|
||||||
|
tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp')
|
||||||
|
result = await self.client.announce(
|
||||||
|
info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped)
|
||||||
|
log.info("Announced to tracker. Found %d peers for %s on %s",
|
||||||
|
len(result.peers), info_hash.hex()[:8], tracker_host)
|
||||||
|
self.EVENT_CONTROLLER.add((info_hash, result))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def subscribe_hash(hash, on_data):
|
||||||
|
TrackerClient.EVENT_CONTROLLER.add(('search', bytes.fromhex(hash)))
|
||||||
|
TrackerClient.EVENT_CONTROLLER.stream.listen(lambda request: on_data(request[1]) if request[0] == hash else None)
|
||||||
|
|
|
@ -3,8 +3,8 @@ import random
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
from lbry.testcase import AsyncioTestCase
|
from lbry.testcase import AsyncioTestCase
|
||||||
from lbry.torrent.tracker import UDPTrackerClientProtocol, encode, decode, CompactIPv4Peer, ConnectRequest, \
|
from lbry.torrent.tracker import encode, decode, CompactIPv4Peer, ConnectRequest, \
|
||||||
ConnectResponse, AnnounceRequest, ErrorResponse, AnnounceResponse, get_peer_list
|
ConnectResponse, AnnounceRequest, ErrorResponse, AnnounceResponse, TrackerClient, subscribe_hash
|
||||||
|
|
||||||
|
|
||||||
class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not suitable for production
|
class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not suitable for production
|
||||||
|
@ -44,30 +44,32 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su
|
||||||
|
|
||||||
class UDPTrackerClientTestCase(AsyncioTestCase):
|
class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
transport, _ = await self.loop.create_datagram_endpoint(UDPTrackerServerProtocol, local_addr=("127.0.0.1", 59900))
|
self.server = UDPTrackerServerProtocol()
|
||||||
self.addCleanup(transport.close)
|
transport, _ = await self.loop.create_datagram_endpoint(lambda: self.server, local_addr=("127.0.0.1", 59900))
|
||||||
self.client = UDPTrackerClientProtocol()
|
|
||||||
transport, _ = await self.loop.create_datagram_endpoint(lambda: self.client, local_addr=("127.0.0.1", 0))
|
|
||||||
self.addCleanup(transport.close)
|
self.addCleanup(transport.close)
|
||||||
|
self.client = TrackerClient(b"\x00" * 48, 4444, [("127.0.0.1", 59900)])
|
||||||
|
await self.client.start()
|
||||||
|
self.addCleanup(self.client.stop)
|
||||||
|
|
||||||
async def test_announce(self):
|
async def test_announce(self):
|
||||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
peer_id = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
announcement = (await self.client.get_peer_list(info_hash))[0]
|
||||||
announcement, _ = await self.client.announce(info_hash, peer_id, 4444, "127.0.0.1", 59900)
|
|
||||||
self.assertEqual(announcement.seeders, 1)
|
self.assertEqual(announcement.seeders, 1)
|
||||||
self.assertEqual(announcement.peers,
|
self.assertEqual(announcement.peers,
|
||||||
[CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
[CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
||||||
|
|
||||||
async def test_announce_using_helper_function(self):
|
async def test_announce_using_helper_function(self):
|
||||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
announcemenet = await get_peer_list(info_hash, None, 4444, "127.0.0.1", 59900)
|
queue = asyncio.Queue()
|
||||||
peers = announcemenet.peers
|
subscribe_hash(info_hash, queue.put_nowait)
|
||||||
|
announcement = await queue.get()
|
||||||
|
peers = announcement.peers
|
||||||
self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
||||||
self.assertEqual((await get_peer_list(info_hash, None, 4444, "127.0.0.1", 59900, stopped=True)).peers, [])
|
|
||||||
|
|
||||||
async def test_error(self):
|
async def test_error(self):
|
||||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
peer_id = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
await self.client.get_peer_list(info_hash)
|
||||||
|
self.server.known_conns.clear()
|
||||||
with self.assertRaises(Exception) as err:
|
with self.assertRaises(Exception) as err:
|
||||||
announcement, _ = await self.client.announce(info_hash, peer_id, 4444, "127.0.0.1", 59900, connection_id=10)
|
await self.client.get_peer_list(info_hash)
|
||||||
self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')
|
self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')
|
||||||
|
|
Loading…
Add table
Reference in a new issue