diff --git a/lbry/torrent/tracker.py b/lbry/torrent/tracker.py index fb1590f6a..14fd3c136 100644 --- a/lbry/torrent/tracker.py +++ b/lbry/torrent/tracker.py @@ -3,9 +3,9 @@ import struct import asyncio import logging from collections import namedtuple -from functools import reduce log = logging.getLogger(__name__) +# see: http://bittorrent.org/beps/bep_0015.html and http://xbtt.sourceforge.net/udp_tracker_protocol.html ConnectRequest = namedtuple("ConnectRequest", ["connection_id", "action", "transaction_id"]) ConnectResponse = namedtuple("ConnectResponse", ["action", "transaction_id", "connection_id"]) AnnounceRequest = namedtuple("AnnounceRequest", @@ -108,34 +108,3 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol): def connection_lost(self, exc: Exception = None) -> None: self.transport = None - - -class UDPTrackerServerProtocol(asyncio.DatagramProtocol): - def __init__(self): - self.transport = None - self.known_conns = set() - self.peers = {} - - def connection_made(self, transport: asyncio.DatagramTransport) -> None: - self.transport = transport - - def datagram_received(self, data: bytes, address: (str, int)) -> None: - if len(data) < 16: - return - action = int.from_bytes(data[8:12], "big", signed=False) - if action == 0: - req = decode(ConnectRequest, data) - connection_id = random.getrandbits(32) - self.known_conns.add(connection_id) - return self.transport.sendto(encode(ConnectResponse(0, req.transaction_id, connection_id)), address) - elif action == 1: - req = decode(AnnounceRequest, data) - if req.connection_id not in self.known_conns: - resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00')) - else: - self.peers.setdefault(req.info_hash, []) - compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address[0].split('.'), bytearray()) - self.peers[req.info_hash].append(compact_ip + req.port.to_bytes(2, "big", signed=False)) - peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]] - resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers)) - return self.transport.sendto(resp, address) diff --git a/tests/unit/torrent/test_tracker.py b/tests/unit/torrent/test_tracker.py index 3d56d0269..9f1ebf106 100644 --- a/tests/unit/torrent/test_tracker.py +++ b/tests/unit/torrent/test_tracker.py @@ -1,6 +1,41 @@ +import asyncio import random +from functools import reduce + from lbry.testcase import AsyncioTestCase -from lbry.torrent.tracker import UDPTrackerClientProtocol, UDPTrackerServerProtocol, CompactIPv4Peer +from lbry.torrent.tracker import UDPTrackerClientProtocol, encode, decode, CompactIPv4Peer, ConnectRequest, \ + ConnectResponse, AnnounceRequest, ErrorResponse, AnnounceResponse + + +class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not suitable for production + def __init__(self): + self.transport = None + self.known_conns = set() + self.peers = {} + + def connection_made(self, transport: asyncio.DatagramTransport) -> None: + self.transport = transport + + def datagram_received(self, data: bytes, address: (str, int)) -> None: + if len(data) < 16: + return + action = int.from_bytes(data[8:12], "big", signed=False) + if action == 0: + req = decode(ConnectRequest, data) + connection_id = random.getrandbits(32) + self.known_conns.add(connection_id) + return self.transport.sendto(encode(ConnectResponse(0, req.transaction_id, connection_id)), address) + elif action == 1: + req = decode(AnnounceRequest, data) + if req.connection_id not in self.known_conns: + resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00')) + else: + self.peers.setdefault(req.info_hash, []) + compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address[0].split('.'), bytearray()) + self.peers[req.info_hash].append(compact_ip + req.port.to_bytes(2, "big", signed=False)) + peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]] + resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers)) + return self.transport.sendto(resp, address) class UDPTrackerClientTestCase(AsyncioTestCase):