move server implementation to tracker module

This commit is contained in:
Victor Shyba 2022-03-12 02:17:37 -03:00
parent 2e85e29ef1
commit c276053301
2 changed files with 44 additions and 45 deletions

View file

@ -6,6 +6,7 @@ import asyncio
import logging import logging
import time import time
from collections import namedtuple from collections import namedtuple
from functools import reduce
from typing import Optional from typing import Optional
from lbry.utils import resolve_host, async_timed_cache, cache_concurrent from lbry.utils import resolve_host, async_timed_cache, cache_concurrent
@ -220,3 +221,45 @@ class TrackerClient:
def subscribe_hash(info_hash: bytes, on_data): def subscribe_hash(info_hash: bytes, on_data):
TrackerClient.EVENT_CONTROLLER.add(('search', info_hash, on_data)) TrackerClient.EVENT_CONTROLLER.add(('search', info_hash, on_data))
class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not suitable for production
def __init__(self):
self.transport = None
self.known_conns = set()
self.peers = {}
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
self.transport = transport
def add_peer(self, info_hash, ip_address: str, port: int):
self.peers.setdefault(info_hash, [])
self.peers[info_hash].append(encode_peer(ip_address, port))
def datagram_received(self, data: bytes, address: (str, int)) -> None:
if len(data) < 16:
return
action = int.from_bytes(data[8:12], "big", signed=False)
if action == 0:
req = decode(ConnectRequest, data)
connection_id = random.getrandbits(32)
self.known_conns.add(connection_id)
return self.transport.sendto(encode(ConnectResponse(0, req.transaction_id, connection_id)), address)
elif action == 1:
req = decode(AnnounceRequest, data)
if req.connection_id not in self.known_conns:
resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00'))
else:
compact_address = encode_peer(address[0], req.port)
if req.event != 3:
self.add_peer(req.info_hash, address[0], req.port)
elif compact_address in self.peers.get(req.info_hash, []):
self.peers[req.info_hash].remove(compact_address)
peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]]
resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers))
return self.transport.sendto(resp, address)
def encode_peer(ip_address: str, port: int):
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), ip_address.split('.'), bytearray())
return compact_ip + port.to_bytes(2, "big", signed=False)

View file

@ -1,52 +1,8 @@
import asyncio import asyncio
import random import random
from functools import reduce
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.torrent.tracker import encode, decode, CompactIPv4Peer, ConnectRequest, \ from lbry.torrent.tracker import CompactIPv4Peer, TrackerClient, subscribe_hash, UDPTrackerServerProtocol, encode_peer
ConnectResponse, AnnounceRequest, ErrorResponse, AnnounceResponse, TrackerClient, subscribe_hash
class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not suitable for production
def __init__(self):
self.transport = None
self.known_conns = set()
self.peers = {}
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
self.transport = transport
def add_peer(self, info_hash, ip_address: str, port: int):
self.peers.setdefault(info_hash, [])
self.peers[info_hash].append(encode_peer(ip_address, port))
def datagram_received(self, data: bytes, address: (str, int)) -> None:
if len(data) < 16:
return
action = int.from_bytes(data[8:12], "big", signed=False)
if action == 0:
req = decode(ConnectRequest, data)
connection_id = random.getrandbits(32)
self.known_conns.add(connection_id)
return self.transport.sendto(encode(ConnectResponse(0, req.transaction_id, connection_id)), address)
elif action == 1:
req = decode(AnnounceRequest, data)
if req.connection_id not in self.known_conns:
resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00'))
else:
compact_address = encode_peer(address[0], req.port)
if req.event != 3:
self.add_peer(req.info_hash, address[0], req.port)
elif compact_address in self.peers.get(req.info_hash, []):
self.peers[req.info_hash].remove(compact_address)
peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]]
resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers))
return self.transport.sendto(resp, address)
def encode_peer(ip_address: str, port: int):
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), ip_address.split('.'), bytearray())
return compact_ip + port.to_bytes(2, "big", signed=False)
class UDPTrackerClientTestCase(AsyncioTestCase): class UDPTrackerClientTestCase(AsyncioTestCase):