forked from LBRYCommunity/lbry-sdk
cache results, save interval on tracker
This commit is contained in:
parent
43e50f7f04
commit
2d9c5742c7
4 changed files with 25 additions and 17 deletions
|
@ -3,7 +3,6 @@ import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import binascii
|
import binascii
|
||||||
import time
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import base58
|
import base58
|
||||||
|
@ -730,14 +729,12 @@ class TrackerAnnouncerComponent(Component):
|
||||||
|
|
||||||
async def announce_forever(self):
|
async def announce_forever(self):
|
||||||
while True:
|
while True:
|
||||||
to_sleep = 60 * 1
|
to_sleep = 6
|
||||||
for file in self.file_manager.get_filtered():
|
for file in self.file_manager.get_filtered():
|
||||||
if not file.downloader:
|
if not file.downloader:
|
||||||
continue
|
continue
|
||||||
next_announce = file.downloader.next_tracker_announce_time
|
self.tracker_client.on_hash(bytes.fromhex(file.sd_hash))
|
||||||
if next_announce is None or next_announce <= time.time():
|
await asyncio.sleep(to_sleep)
|
||||||
self.tracker_client.on_hash(bytes.fromhex(file.sd_hash))
|
|
||||||
await asyncio.sleep(to_sleep + 1)
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
node = self.component_manager.get_component(DHT_COMPONENT) \
|
node = self.component_manager.get_component(DHT_COMPONENT) \
|
||||||
|
|
|
@ -41,7 +41,6 @@ class StreamDownloader:
|
||||||
self.added_fixed_peers = False
|
self.added_fixed_peers = False
|
||||||
self.time_to_descriptor: typing.Optional[float] = None
|
self.time_to_descriptor: typing.Optional[float] = None
|
||||||
self.time_to_first_bytes: typing.Optional[float] = None
|
self.time_to_first_bytes: typing.Optional[float] = None
|
||||||
self.next_tracker_announce_time = None
|
|
||||||
|
|
||||||
async def cached_read_blob(blob_info: 'BlobInfo') -> bytes:
|
async def cached_read_blob(blob_info: 'BlobInfo') -> bytes:
|
||||||
return await self.read_blob(blob_info, 2)
|
return await self.read_blob(blob_info, 2)
|
||||||
|
@ -72,8 +71,6 @@ class StreamDownloader:
|
||||||
peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers]
|
peers = [(str(ipaddress.ip_address(peer.address)), peer.port) for peer in announcement.peers]
|
||||||
peers = await get_kademlia_peers_from_hosts(peers)
|
peers = await get_kademlia_peers_from_hosts(peers)
|
||||||
log.info("Found %d peers from tracker for %s", len(peers), self.sd_hash[:8])
|
log.info("Found %d peers from tracker for %s", len(peers), self.sd_hash[:8])
|
||||||
self.next_tracker_announce_time = min(time() + announcement.interval,
|
|
||||||
self.next_tracker_announce_time or 1 << 64)
|
|
||||||
self.peer_queue.put_nowait(peers)
|
self.peer_queue.put_nowait(peers)
|
||||||
|
|
||||||
async def load_descriptor(self, connection_id: int = 0):
|
async def load_descriptor(self, connection_id: int = 0):
|
||||||
|
@ -105,7 +102,8 @@ class StreamDownloader:
|
||||||
self.accumulate_task.cancel()
|
self.accumulate_task.cancel()
|
||||||
_, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
|
_, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
|
||||||
await self.add_fixed_peers()
|
await self.add_fixed_peers()
|
||||||
subscribe_hash(self.sd_hash, self._process_announcement)
|
subscribe_hash(
|
||||||
|
bytes.fromhex(self.sd_hash), lambda result: asyncio.ensure_future(self._process_announcement(result)))
|
||||||
# start searching for peers for the sd hash
|
# start searching for peers for the sd hash
|
||||||
self.search_queue.put_nowait(self.sd_hash)
|
self.search_queue.put_nowait(self.sd_hash)
|
||||||
log.info("searching for peers for stream %s", self.sd_hash)
|
log.info("searching for peers for stream %s", self.sd_hash)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import random
|
||||||
import struct
|
import struct
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from lbry.utils import resolve_host, async_timed_cache
|
from lbry.utils import resolve_host, async_timed_cache
|
||||||
|
@ -120,12 +121,14 @@ class UDPTrackerClientProtocol(asyncio.DatagramProtocol):
|
||||||
|
|
||||||
class TrackerClient:
|
class TrackerClient:
|
||||||
EVENT_CONTROLLER = StreamController()
|
EVENT_CONTROLLER = StreamController()
|
||||||
|
|
||||||
def __init__(self, node_id, announce_port, servers):
|
def __init__(self, node_id, announce_port, servers):
|
||||||
self.client = UDPTrackerClientProtocol()
|
self.client = UDPTrackerClientProtocol()
|
||||||
self.transport = None
|
self.transport = None
|
||||||
self.node_id = node_id or random.getrandbits(160).to_bytes(20, "big", signed=False)
|
self.node_id = node_id or random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
self.announce_port = announce_port
|
self.announce_port = announce_port
|
||||||
self.servers = servers
|
self.servers = servers
|
||||||
|
self.results = {} # we can't probe the server before the interval, so we keep the result here until it expires
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
|
self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||||
|
@ -145,24 +148,33 @@ class TrackerClient:
|
||||||
async def get_peer_list(self, info_hash, stopped=False):
|
async def get_peer_list(self, info_hash, stopped=False):
|
||||||
found = []
|
found = []
|
||||||
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
|
for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]):
|
||||||
found.extend(await done)
|
result = await done
|
||||||
|
if result is not None:
|
||||||
|
self.EVENT_CONTROLLER.add((info_hash, result))
|
||||||
|
found.append(result)
|
||||||
return found
|
return found
|
||||||
|
|
||||||
async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False):
|
async def _probe_server(self, info_hash, tracker_host, tracker_port, stopped=False):
|
||||||
|
result = None
|
||||||
|
if info_hash in self.results:
|
||||||
|
next_announcement, result = self.results[info_hash]
|
||||||
|
if time.time() < next_announcement:
|
||||||
|
return result
|
||||||
try:
|
try:
|
||||||
tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp')
|
tracker_ip = await resolve_host(tracker_host, tracker_port, 'udp')
|
||||||
result = await self.client.announce(
|
result = await self.client.announce(
|
||||||
info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped)
|
info_hash, self.node_id, self.announce_port, tracker_ip, tracker_port, stopped)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
log.info("Tracker timed out: %s:%d", tracker_host, tracker_port)
|
log.info("Tracker timed out: %s:%d", tracker_host, tracker_port)
|
||||||
return []
|
return None
|
||||||
|
finally:
|
||||||
|
self.results[info_hash] = (time.time() + (result.interval if result else 60.0), result)
|
||||||
log.info("Announced to tracker. Found %d peers for %s on %s",
|
log.info("Announced to tracker. Found %d peers for %s on %s",
|
||||||
len(result.peers), info_hash.hex()[:8], tracker_host)
|
len(result.peers), info_hash.hex()[:8], tracker_host)
|
||||||
self.EVENT_CONTROLLER.add((info_hash, result))
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def subscribe_hash(hash, on_data):
|
def subscribe_hash(hash: bytes, on_data):
|
||||||
TrackerClient.EVENT_CONTROLLER.add(('search', bytes.fromhex(hash)))
|
TrackerClient.EVENT_CONTROLLER.add(('search', hash))
|
||||||
TrackerClient.EVENT_CONTROLLER.stream.listen(
|
TrackerClient.EVENT_CONTROLLER.stream.where(lambda request: request[0] == hash).add_done_callback(
|
||||||
lambda request: on_data(request[1]) if request[0].hex() == hash else None)
|
lambda request: on_data(request.result()[1]))
|
||||||
|
|
|
@ -70,6 +70,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase):
|
||||||
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False)
|
||||||
await self.client.get_peer_list(info_hash)
|
await self.client.get_peer_list(info_hash)
|
||||||
self.server.known_conns.clear()
|
self.server.known_conns.clear()
|
||||||
|
self.client.results.clear()
|
||||||
with self.assertRaises(Exception) as err:
|
with self.assertRaises(Exception) as err:
|
||||||
await self.client.get_peer_list(info_hash)
|
await self.client.get_peer_list(info_hash)
|
||||||
self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')
|
self.assertEqual(err.exception.args[0], b'Connection ID missmatch.\x00')
|
||||||
|
|
Loading…
Add table
Reference in a new issue