cache results, save interval on tracker

This commit is contained in:
Victor Shyba 2022-03-08 00:58:18 -03:00
parent 43e50f7f04
commit 2d9c5742c7
4 changed files with 25 additions and 17 deletions

View file

@ -3,7 +3,6 @@ import os
import asyncio import asyncio
import logging import logging
import binascii import binascii
import time
import typing import typing
import base58 import base58
@ -730,14 +729,12 @@ class TrackerAnnouncerComponent(Component):
async def announce_forever(self):
    """
    Loop forever, handing the sd hash of every file that has a downloader
    to the tracker client.

    On each pass the files returned by ``self.file_manager.get_filtered()``
    are scanned; entries without a downloader are skipped, and for the rest
    the hex ``sd_hash`` is decoded to bytes and passed to
    ``self.tracker_client.on_hash``. The coroutine then sleeps before the
    next pass.
    """
    pause_seconds = 6  # fixed delay between two passes over the file list
    while True:
        with_downloader = (
            entry for entry in self.file_manager.get_filtered() if entry.downloader
        )
        for entry in with_downloader:
            self.tracker_client.on_hash(bytes.fromhex(entry.sd_hash))
        await asyncio.sleep(pause_seconds)
async def start(self): async def start(self):
node = self.component_manager.get_component(DHT_COMPONENT) \ node = self.component_manager.get_component(DHT_COMPONENT) \

View file

@ -41,7 +41,6 @@ class StreamDownloader:
self.added_fixed_peers = False self.added_fixed_peers = False
self.time_to_descriptor: typing.Optional[float] = None self.time_to_descriptor: typing.Optional[float] = None
self.time_to_first_bytes: typing.Optional[float] = None self.time_to_first_bytes: typing.Optional[float] = None
self.next_tracker_announce_time = None
async def cached_read_blob(blob_info: 'BlobInfo') -> bytes: async def cached_read_blob(blob_info: 'BlobInfo') -> bytes:
return await self.read_blob(blob_info, 2) return await self.read_blob(blob_info, 2)
@ -72,8 +71,6 @@ class StreamDownloader:
peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers] peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers]
peers = await get_kademlia_peers_from_hosts(peers) peers = await get_kademlia_peers_from_hosts(peers)
log.info("Found %d peers from tracker for %s", len(peers), self.sd_hash[:8]) log.info("Found %d peers from tracker for %s", len(peers), self.sd_hash[:8])
self.next_tracker_announce_time = min(time() + announcement.interval,
self.next_tracker_announce_time or 1 << 64)
self.peer_queue.put_nowait(peers) self.peer_queue.put_nowait(peers)
async def load_descriptor(self, connection_id: int = 0): async def load_descriptor(self, connection_id: int = 0):
@ -105,7 +102,8 @@ class StreamDownloader:
self.accumulate_task.cancel() self.accumulate_task.cancel()
_, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue) _, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
await self.add_fixed_peers() await self.add_fixed_peers()
subscribe_hash(self.sd_hash, self._process_announcement) subscribe_hash(
bytes.fromhex(self.sd_hash), lambda result: asyncio.ensure_future(self._process_announcement(result)))
# start searching for peers for the sd hash # start searching for peers for the sd hash
self.search_queue.put_nowait(self.sd_hash) self.search_queue.put_nowait(self.sd_hash)
log.info("searching for peers for stream %s", self.sd_hash) log.info("searching for peers for stream %s", self.sd_hash)

View file

@ -2,6 +2,7 @@ import random
import struct import struct
import asyncio import asyncio
import logging import logging
import time
from collections import namedtuple from collections import namedtuple
from lbry.utils import resolve_host, async_timed_cache from lbry.utils import resolve_host, async_timed_cache
@ -120,12 +121,14 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
class TrackerClient: class TrackerClient:
EVENT_CONTROLLER = StreamController() EVENT_CONTROLLER = StreamController()
def __init__(self, node_id, announce_port, servers):
    """
    Store the announce parameters and create the UDP protocol object.

    :param node_id: identifier passed to ``client.announce``; when falsy, a
        random 20-byte value is generated instead
    :param announce_port: port passed to ``client.announce``
    :param servers: tracker entries; each one is unpacked into the host and
        port arguments of ``_probe_server``
    """
    self.client = UDPTrackerClientProtocol()
    self.transport = None
    if not node_id:
        # 160 random bits rendered as a 20-byte big-endian identifier
        node_id = random.getrandbits(160).to_bytes(20, "big", signed=False)
    self.node_id = node_id
    self.announce_port = announce_port
    self.servers = servers
    # we can't probe the server before the interval, so we keep the result here until it expires
    self.results = {}
async def start(self): async def start(self):
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint( self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
@ -145,24 +148,33 @@ class TrackerClient:
async def get_peer_list(self, info_hash, stopped=False):
    """
    Probe every configured server for ``info_hash`` and collect the answers.

    One ``_probe_server`` call is started per entry in ``self.servers``;
    results are handled in completion order. Each result that is not
    ``None`` is published on ``EVENT_CONTROLLER`` as ``(info_hash, result)``
    and appended to the returned list.

    :param info_hash: hash forwarded to every probe
    :param stopped: forwarded unchanged to every probe
    :return: list of the non-``None`` probe results
    """
    probes = [self._probe_server(info_hash, *server, stopped) for server in self.servers]
    announcements = []
    for finished in asyncio.as_completed(probes):
        announcement = await finished
        if announcement is None:
            continue
        self.EVENT_CONTROLLER.add((info_hash, announcement))
        announcements.append(announcement)
    return announcements
async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False): async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False):
result = None
if info_hash in self.results:
next_announcement, result = self.results[info_hash]
if time.time() < next_announcement:
return result
try: try:
tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp') tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp')
result = await self.client.announce( result = await self.client.announce(
info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped) info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.info("Tracker timed out: %s:%d", tracker_host, tracker_port) log.info("Tracker timed out: %s:%d", tracker_host, tracker_port)
return [] return None
finally:
self.results[info_hash] = (time.time() + (result.interval if result else 60.0), result)
log.info("Announced to tracker. Found %d peers for %s on %s", log.info("Announced to tracker. Found %d peers for %s on %s",
len(result.peers), info_hash.hex()[:8], tracker_host) len(result.peers), info_hash.hex()[:8], tracker_host)
self.EVENT_CONTROLLER.add((info_hash, result))
return result return result
def subscribe_hash(hash: bytes, on_data):
    """
    Request a tracker search for ``hash`` and register ``on_data`` for the answer.

    A ``('search', hash)`` event is added to ``TrackerClient.EVENT_CONTROLLER``.
    Then ``stream.where`` is asked for an event whose first element equals
    ``hash``; when the returned object completes, ``on_data`` is called with
    the second element of that event.

    NOTE(review): the object returned by ``where`` is used like a future, so
    presumably ``on_data`` fires only for the first matching event - confirm.

    :param hash: raw hash bytes to search for (name kept for compatibility,
        although it shadows the builtin)
    :param on_data: callable receiving the payload of the matching event
    """
    controller = TrackerClient.EVENT_CONTROLLER
    controller.add(('search', hash))

    def is_for_hash(event):
        return event[0] == hash

    def forward(done):
        return on_data(done.result()[1])

    controller.stream.where(is_for_hash).add_done_callback(forward)

View file

@ -70,6 +70,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
await self.client.get_peer_list(info_hash) await self.client.get_peer_list(info_hash)
self.server.known_conns.clear() self.server.known_conns.clear()
self.client.results.clear()
with self.assertRaises(Exception) as err: with self.assertRaises(Exception) as err:
await self.client.get_peer_list(info_hash) await self.client.get_peer_list(info_hash)
self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00') self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')