forked from LBRYCommunity/lbry-sdk
add torrent udp tracker client, server and tests
This commit is contained in:
parent
d0e715feb9
commit
4a0bf8a702
3 changed files with 168 additions and 0 deletions
141
lbry/torrent/tracker.py
Normal file
141
lbry/torrent/tracker.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
import random
|
||||
import struct
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
ConnectRequest = namedtuple("ConnectRequest", ["connection_id", "action", "transaction_id"])
|
||||
ConnectResponse = namedtuple("ConnectResponse", ["action", "transaction_id", "connection_id"])
|
||||
AnnounceRequest = namedtuple("AnnounceRequest",
|
||||
["connection_id", "action", "transaction_id", "info_hash", "peer_id", "downloaded", "left",
|
||||
"uploaded", "event", "ip_addr", "key", "num_want", "port"])
|
||||
AnnounceResponse = namedtuple("AnnounceResponse",
|
||||
["action", "transaction_id", "interval", "leechers", "seeders", "peers"])
|
||||
CompactIPv4Peer = namedtuple("CompactPeer", ["address", "port"])
|
||||
ScrapeRequest = namedtuple("ScrapeRequest", ["connection_id", "action", "transaction_id", "infohashes"])
|
||||
ScrapeResponse = namedtuple("ScrapeResponse", ["action", "transaction_id", "items"])
|
||||
ScrapeResponseItem = namedtuple("ScrapeResponseItem", ["seeders", "completed", "leechers"])
|
||||
ErrorResponse = namedtuple("ErrorResponse", ["action", "transaction_id", "message"])
|
||||
STRUCTS = {
|
||||
ConnectRequest: struct.Struct(">QII"),
|
||||
ConnectResponse: struct.Struct(">IIQ"),
|
||||
AnnounceRequest: struct.Struct(">QII20s20sQQQIIIiH"),
|
||||
AnnounceResponse: struct.Struct(">IIIII"),
|
||||
CompactIPv4Peer: struct.Struct(">IH"),
|
||||
ScrapeRequest: struct.Struct(">QII"),
|
||||
ScrapeResponse: struct.Struct(">II"),
|
||||
ScrapeResponseItem: struct.Struct(">III"),
|
||||
ErrorResponse: struct.Struct(">II")
|
||||
}
|
||||
|
||||
|
||||
def decode(cls, data, offset=0):
|
||||
decoder = STRUCTS[cls]
|
||||
if cls == AnnounceResponse:
|
||||
return AnnounceResponse(*decoder.unpack_from(data, offset),
|
||||
peers=[decode(CompactIPv4Peer, data, index) for index in range(20, len(data), 6)])
|
||||
elif cls == ScrapeResponse:
|
||||
return ScrapeResponse(*decoder.unpack_from(data, offset),
|
||||
items=[decode(ScrapeResponseItem, data, index) for index in range(8, len(data), 12)])
|
||||
elif cls == ErrorResponse:
|
||||
return ErrorResponse(*decoder.unpack_from(data, offset), data[decoder.size:])
|
||||
return cls(*decoder.unpack_from(data, offset))
|
||||
|
||||
|
||||
def encode(obj):
|
||||
if isinstance(obj, ScrapeRequest):
|
||||
return STRUCTS[ScrapeRequest].pack(*obj[:-1]) + b''.join(obj.infohashes)
|
||||
elif isinstance(obj, ErrorResponse):
|
||||
return STRUCTS[ErrorResponse].pack(*obj[:-1]) + obj.message
|
||||
elif isinstance(obj, AnnounceResponse):
|
||||
return STRUCTS[AnnounceResponse].pack(*obj[:-1]) + b''.join([encode(peer) for peer in obj.peers])
|
||||
return STRUCTS[type(obj)].pack(*obj)
|
||||
|
||||
|
||||
class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(self):
|
||||
self.transport = None
|
||||
self.data_queue = {}
|
||||
|
||||
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
|
||||
self.transport = transport
|
||||
|
||||
async def request(self, obj, tracker_ip, tracker_port):
|
||||
self.data_queue[obj.transaction_id] = asyncio.get_running_loop().create_future()
|
||||
self.transport.sendto(encode(obj), (tracker_ip, tracker_port))
|
||||
try:
|
||||
return await asyncio.wait_for(self.data_queue[obj.transaction_id], 3.0)
|
||||
finally:
|
||||
self.data_queue.pop(obj.transaction_id, None)
|
||||
|
||||
async def connect(self, tracker_ip, tracker_port):
|
||||
transaction_id = random.getrandbits(32)
|
||||
return decode(ConnectResponse,
|
||||
await self.request(ConnectRequest(0x41727101980, 0, transaction_id), tracker_ip, tracker_port))
|
||||
|
||||
async def announce(self, info_hash, peer_id, port, tracker_ip, tracker_port, connection_id=None):
|
||||
if not connection_id:
|
||||
reply = await self.connect(tracker_ip, tracker_port)
|
||||
connection_id = reply.connection_id
|
||||
# this should make the key deterministic but unique per info hash + peer id
|
||||
key = int.from_bytes(info_hash[:4], "big") ^ int.from_bytes(peer_id[:4], "big") ^ port
|
||||
transaction_id = random.getrandbits(32)
|
||||
req = AnnounceRequest(connection_id, 1, transaction_id, info_hash, peer_id, 0, 0, 0, 1, 0, key, -1, port)
|
||||
reply = await self.request(req, tracker_ip, tracker_port)
|
||||
return decode(AnnounceResponse, reply), connection_id
|
||||
|
||||
async def scrape(self, infohashes, tracker_ip, tracker_port, connection_id=None):
|
||||
if not connection_id:
|
||||
reply = await self.connect(tracker_ip, tracker_port)
|
||||
connection_id = reply.connection_id
|
||||
transaction_id = random.getrandbits(32)
|
||||
reply = await self.request(
|
||||
ScrapeRequest(connection_id, 2, transaction_id, infohashes), tracker_ip, tracker_port)
|
||||
return decode(ScrapeResponse, reply), connection_id
|
||||
|
||||
def datagram_received(self, data: bytes, addr: (str, int)) -> None:
|
||||
if len(data) < 8:
|
||||
return
|
||||
transaction_id = int.from_bytes(data[4:8], byteorder="big", signed=False)
|
||||
if transaction_id in self.data_queue:
|
||||
if not self.data_queue[transaction_id].done():
|
||||
if data[3] == 3:
|
||||
return self.data_queue[transaction_id].set_exception(Exception(decode(ErrorResponse, data).message))
|
||||
return self.data_queue[transaction_id].set_result(data)
|
||||
print("error", data.hex())
|
||||
|
||||
def connection_lost(self, exc: Exception = None) -> None:
|
||||
self.transport = None
|
||||
|
||||
|
||||
class UDPTrackerServerProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(self):
|
||||
self.transport = None
|
||||
self.known_conns = set()
|
||||
self.peers = {}
|
||||
|
||||
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
|
||||
self.transport = transport
|
||||
|
||||
def datagram_received(self, data: bytes, address: (str, int)) -> None:
|
||||
if len(data) < 16:
|
||||
return
|
||||
action = int.from_bytes(data[8:12], "big", signed=False)
|
||||
if action == 0:
|
||||
req = decode(ConnectRequest, data)
|
||||
connection_id = random.getrandbits(32)
|
||||
self.known_conns.add(connection_id)
|
||||
return self.transport.sendto(encode(ConnectResponse(0, req.transaction_id, connection_id)), address)
|
||||
elif action == 1:
|
||||
req = decode(AnnounceRequest, data)
|
||||
if req.connection_id not in self.known_conns:
|
||||
resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00'))
|
||||
else:
|
||||
self.peers.setdefault(req.info_hash, [])
|
||||
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address[0].split('.'), bytearray())
|
||||
self.peers[req.info_hash].append(compact_ip + req.port.to_bytes(2, "big", signed=False))
|
||||
peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]]
|
||||
resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers))
|
||||
return self.transport.sendto(resp, address)
|
0
tests/unit/torrent/__init__.py
Normal file
0
tests/unit/torrent/__init__.py
Normal file
27
tests/unit/torrent/test_tracker.py
Normal file
27
tests/unit/torrent/test_tracker.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
import random
|
||||
from lbry.testcase import AsyncioTestCase
|
||||
from lbry.torrent.tracker import UDPTrackerClientProtocol, UDPTrackerServerProtocol, CompactIPv4Peer
|
||||
|
||||
|
||||
class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
transport, _ = await self.loop.create_datagram_endpoint(UDPTrackerServerProtocol, local_addr=("127.0.0.1", 59900))
|
||||
self.addCleanup(transport.close)
|
||||
self.client = UDPTrackerClientProtocol()
|
||||
transport, _ = await self.loop.create_datagram_endpoint(lambda: self.client, local_addr=("127.0.0.1", 0))
|
||||
self.addCleanup(transport.close)
|
||||
|
||||
async def test_announce(self):
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
peer_id = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
announcement, _ = await self.client.announce(info_hash, peer_id, 4444, "127.0.0.1", 59900)
|
||||
self.assertEqual(announcement.seeders, 1)
|
||||
self.assertEqual(announcement.peers,
|
||||
[CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)])
|
||||
|
||||
async def test_error(self):
|
||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
peer_id = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||
with self.assertRaises(Exception) as err:
|
||||
announcement, _ = await self.client.announce(info_hash, peer_id, 4444, "127.0.0.1", 59900, connection_id=10)
|
||||
self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')
|
Loading…
Reference in a new issue