add UDP based ping protocol for spv servers
This commit is contained in:
parent
f7a380e9b7
commit
f0d8fb8f1a
5 changed files with 207 additions and 2 deletions
|
@ -192,6 +192,8 @@ def cache_concurrent(async_fn):
|
||||||
async def resolve_host(url: str, port: int, proto: str) -> str:
|
async def resolve_host(url: str, port: int, proto: str) -> str:
|
||||||
if proto not in ['udp', 'tcp']:
|
if proto not in ['udp', 'tcp']:
|
||||||
raise Exception("invalid protocol")
|
raise Exception("invalid protocol")
|
||||||
|
if url.lower() == 'localhost':
|
||||||
|
return '127.0.0.1'
|
||||||
try:
|
try:
|
||||||
if ipaddress.ip_address(url):
|
if ipaddress.ip_address(url):
|
||||||
return url
|
return url
|
||||||
|
|
|
@ -11,6 +11,7 @@ from lbry.wallet.server.daemon import DaemonError
|
||||||
from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN
|
from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN
|
||||||
from lbry.wallet.server.util import chunks, class_logger
|
from lbry.wallet.server.util import chunks, class_logger
|
||||||
from lbry.wallet.server.leveldb import FlushData
|
from lbry.wallet.server.leveldb import FlushData
|
||||||
|
from lbry.wallet.server.udp import StatusServer
|
||||||
|
|
||||||
|
|
||||||
class Prefetcher:
|
class Prefetcher:
|
||||||
|
@ -185,6 +186,7 @@ class BlockProcessor:
|
||||||
|
|
||||||
self.search_cache = {}
|
self.search_cache = {}
|
||||||
self.history_cache = {}
|
self.history_cache = {}
|
||||||
|
self.status_server = StatusServer()
|
||||||
|
|
||||||
async def run_in_thread_with_lock(self, func, *args):
|
async def run_in_thread_with_lock(self, func, *args):
|
||||||
# Run in a thread to prevent blocking. Shielded so that
|
# Run in a thread to prevent blocking. Shielded so that
|
||||||
|
@ -221,6 +223,7 @@ class BlockProcessor:
|
||||||
processed_time = time.perf_counter() - start
|
processed_time = time.perf_counter() - start
|
||||||
self.block_count_metric.set(self.height)
|
self.block_count_metric.set(self.height)
|
||||||
self.block_update_time_metric.observe(processed_time)
|
self.block_update_time_metric.observe(processed_time)
|
||||||
|
self.status_server.set_height(self.db.fs_height, self.db.db_tip)
|
||||||
if not self.db.first_sync:
|
if not self.db.first_sync:
|
||||||
s = '' if len(blocks) == 1 else 's'
|
s = '' if len(blocks) == 1 else 's'
|
||||||
self.logger.info('processed {:,d} block{} in {:.1f}s'.format(len(blocks), s, processed_time))
|
self.logger.info('processed {:,d} block{} in {:.1f}s'.format(len(blocks), s, processed_time))
|
||||||
|
@ -682,9 +685,11 @@ class BlockProcessor:
|
||||||
disk before exiting, as otherwise a significant amount of work
|
disk before exiting, as otherwise a significant amount of work
|
||||||
could be lost.
|
could be lost.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._caught_up_event = caught_up_event
|
self._caught_up_event = caught_up_event
|
||||||
try:
|
try:
|
||||||
await self._first_open_dbs()
|
await self._first_open_dbs()
|
||||||
|
self.status_server.set_height(self.db.fs_height, self.db.db_tip)
|
||||||
await asyncio.wait([
|
await asyncio.wait([
|
||||||
self.prefetcher.main_loop(self.height),
|
self.prefetcher.main_loop(self.height),
|
||||||
self._process_prefetched_blocks()
|
self._process_prefetched_blocks()
|
||||||
|
@ -695,6 +700,7 @@ class BlockProcessor:
|
||||||
self.logger.exception("Block processing failed!")
|
self.logger.exception("Block processing failed!")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
self.status_server.stop()
|
||||||
# Shut down block processing
|
# Shut down block processing
|
||||||
self.logger.info('flushing to DB for a clean shutdown...')
|
self.logger.info('flushing to DB for a clean shutdown...')
|
||||||
await self.flush(True)
|
await self.flush(True)
|
||||||
|
@ -714,7 +720,6 @@ class BlockProcessor:
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.total = 0
|
self.total = 0
|
||||||
|
|
|
@ -111,7 +111,10 @@ class Server:
|
||||||
return _flag.wait()
|
return _flag.wait()
|
||||||
|
|
||||||
await self.start_prometheus()
|
await self.start_prometheus()
|
||||||
|
await self.bp.status_server.start(0, bytes.fromhex(self.bp.coin.GENESIS_HASH)[::-1]
|
||||||
|
, self.env.host, self.env.tcp_port)
|
||||||
await _start_cancellable(self.bp.fetch_and_process_blocks)
|
await _start_cancellable(self.bp.fetch_and_process_blocks)
|
||||||
|
|
||||||
await self.db.populate_header_merkle_cache()
|
await self.db.populate_header_merkle_cache()
|
||||||
await _start_cancellable(self.mempool.keep_synchronized)
|
await _start_cancellable(self.mempool.keep_synchronized)
|
||||||
await _start_cancellable(self.session_mgr.serve, self.notifications)
|
await _start_cancellable(self.session_mgr.serve, self.notifications)
|
||||||
|
|
|
@ -247,11 +247,12 @@ class SessionManager:
|
||||||
async def _manage_servers(self):
|
async def _manage_servers(self):
|
||||||
paused = False
|
paused = False
|
||||||
max_sessions = self.env.max_sessions
|
max_sessions = self.env.max_sessions
|
||||||
low_watermark = max_sessions * 19 // 20
|
low_watermark = int(max_sessions * 0.95)
|
||||||
while True:
|
while True:
|
||||||
await self.session_event.wait()
|
await self.session_event.wait()
|
||||||
self.session_event.clear()
|
self.session_event.clear()
|
||||||
if not paused and len(self.sessions) >= max_sessions:
|
if not paused and len(self.sessions) >= max_sessions:
|
||||||
|
self.bp.status_server.set_unavailable()
|
||||||
self.logger.info(f'maximum sessions {max_sessions:,d} '
|
self.logger.info(f'maximum sessions {max_sessions:,d} '
|
||||||
f'reached, stopping new connections until '
|
f'reached, stopping new connections until '
|
||||||
f'count drops to {low_watermark:,d}')
|
f'count drops to {low_watermark:,d}')
|
||||||
|
@ -260,6 +261,7 @@ class SessionManager:
|
||||||
# Start listening for incoming connections if paused and
|
# Start listening for incoming connections if paused and
|
||||||
# session count has fallen
|
# session count has fallen
|
||||||
if paused and len(self.sessions) <= low_watermark:
|
if paused and len(self.sessions) <= low_watermark:
|
||||||
|
self.bp.status_server.set_available()
|
||||||
self.logger.info('resuming listening for incoming connections')
|
self.logger.info('resuming listening for incoming connections')
|
||||||
await self._start_external_servers()
|
await self._start_external_servers()
|
||||||
paused = False
|
paused = False
|
||||||
|
@ -572,6 +574,7 @@ class SessionManager:
|
||||||
await self.start_other()
|
await self.start_other()
|
||||||
await self._start_external_servers()
|
await self._start_external_servers()
|
||||||
server_listening_event.set()
|
server_listening_event.set()
|
||||||
|
self.bp.status_server.set_available()
|
||||||
# Peer discovery should start after the external servers
|
# Peer discovery should start after the external servers
|
||||||
# because we connect to ourself
|
# because we connect to ourself
|
||||||
await asyncio.wait([
|
await asyncio.wait([
|
||||||
|
|
192
lbry/wallet/server/udp.py
Normal file
192
lbry/wallet/server/udp.py
Normal file
|
@ -0,0 +1,192 @@
|
||||||
|
import asyncio
|
||||||
|
import struct
|
||||||
|
from time import perf_counter
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple, NamedTuple
|
||||||
|
from lbry.utils import LRUCache
|
||||||
|
# from prometheus_client import Counter
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
_MAGIC = 1446058291 # genesis blocktime (which is actually wrong)
|
||||||
|
# ping_count_metric = Counter("ping_count", "Number of pings received", namespace='wallet_server_status')
|
||||||
|
_PAD_BYTES = b'\x00' * 64
|
||||||
|
|
||||||
|
|
||||||
|
class SPVPing(NamedTuple):
|
||||||
|
magic: int
|
||||||
|
protocol_version: int
|
||||||
|
pad_bytes: bytes
|
||||||
|
|
||||||
|
def encode(self):
|
||||||
|
return struct.pack(b'!lB64s', *self)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make(protocol_version=1) -> bytes:
|
||||||
|
return SPVPing(_MAGIC, protocol_version, _PAD_BYTES).encode()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode(cls, packet: bytes):
|
||||||
|
decoded = cls(*struct.unpack(b'!lB64s', packet[:69]))
|
||||||
|
if decoded.magic != _MAGIC:
|
||||||
|
raise ValueError("invalid magic bytes")
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
|
||||||
|
class SPVPong(NamedTuple):
|
||||||
|
protocol_version: int
|
||||||
|
flags: int
|
||||||
|
height: int
|
||||||
|
tip: bytes
|
||||||
|
source_address_raw: bytes
|
||||||
|
|
||||||
|
def encode(self):
|
||||||
|
return struct.pack(b'!BBl32s4s', *self)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make(height: int, tip: bytes, flags: int, protocol_version: int = 1) -> bytes:
|
||||||
|
# note: drops the last 4 bytes so the result can be cached and have addresses added to it as needed
|
||||||
|
return SPVPong(protocol_version, flags, height, tip, b'\x00\x00\x00\x00').encode()[:38]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode(cls, packet: bytes):
|
||||||
|
return cls(*struct.unpack(b'!BBl32s4s', packet[:42]))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available(self) -> bool:
|
||||||
|
return (self.flags & 0b00000001) > 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ip_address(self) -> str:
|
||||||
|
return ".".join(map(str, self.source_address_raw))
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"SPVPong(external_ip={self.ip_address}, version={self.protocol_version}, " \
|
||||||
|
f"available={'True' if self.flags & 1 > 0 else 'False'}," \
|
||||||
|
f" height={self.height}, tip={self.tip[::-1].hex()})"
|
||||||
|
|
||||||
|
|
||||||
|
class SPVServerStatusProtocol(asyncio.DatagramProtocol):
|
||||||
|
PROTOCOL_VERSION = 1
|
||||||
|
|
||||||
|
def __init__(self, height: int, tip: bytes, throttle_cache_size: int = 1024, throttle_reqs_per_sec: int = 10):
|
||||||
|
super().__init__()
|
||||||
|
self.transport: Optional[asyncio.transports.DatagramTransport] = None
|
||||||
|
self._height = height
|
||||||
|
self._tip = tip
|
||||||
|
self._flags = 0
|
||||||
|
self._cached_response = None
|
||||||
|
self.update_cached_response()
|
||||||
|
self._throttle = LRUCache(throttle_cache_size)
|
||||||
|
self._should_log = LRUCache(throttle_cache_size)
|
||||||
|
self._min_delay = 1 / throttle_reqs_per_sec
|
||||||
|
|
||||||
|
def update_cached_response(self):
|
||||||
|
self._cached_response = SPVPong.make(self._height, self._tip, self._flags, self.PROTOCOL_VERSION)
|
||||||
|
|
||||||
|
def set_unavailable(self):
|
||||||
|
self._flags &= 0b11111110
|
||||||
|
self.update_cached_response()
|
||||||
|
|
||||||
|
def set_available(self):
|
||||||
|
self._flags |= 0b00000001
|
||||||
|
self.update_cached_response()
|
||||||
|
|
||||||
|
def set_height(self, height: int, tip: bytes):
|
||||||
|
self._height, self._tip = height, tip
|
||||||
|
self.update_cached_response()
|
||||||
|
|
||||||
|
def should_throttle(self, host: str):
|
||||||
|
now = perf_counter()
|
||||||
|
last_requested = self._throttle.get(host, default=0)
|
||||||
|
self._throttle[host] = now
|
||||||
|
if now - last_requested < self._min_delay:
|
||||||
|
log_cnt = self._should_log.get(host, default=0) + 1
|
||||||
|
if log_cnt % 100 == 0:
|
||||||
|
log.warning("throttle spv status to %s", host)
|
||||||
|
self._should_log[host] = log_cnt
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def make_pong(self, host):
|
||||||
|
return self._cached_response + bytes(int(b) for b in host.split("."))
|
||||||
|
|
||||||
|
def datagram_received(self, data: bytes, addr: Tuple[str, int]):
|
||||||
|
if self.should_throttle(addr[0]):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
SPVPing.decode(data)
|
||||||
|
except (ValueError, struct.error, AttributeError, TypeError):
|
||||||
|
# log.exception("derp")
|
||||||
|
return
|
||||||
|
self.transport.sendto(self.make_pong(addr[0]), addr)
|
||||||
|
# ping_count_metric.inc()
|
||||||
|
|
||||||
|
def connection_made(self, transport) -> None:
|
||||||
|
self.transport = transport
|
||||||
|
|
||||||
|
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||||
|
self.transport = None
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.transport:
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
|
||||||
|
class StatusServer:
|
||||||
|
def __init__(self):
|
||||||
|
self._protocol: Optional[SPVServerStatusProtocol] = None
|
||||||
|
|
||||||
|
async def start(self, height: int, tip: bytes, interface: str, port: int):
|
||||||
|
if self._protocol:
|
||||||
|
return
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
self._protocol = SPVServerStatusProtocol(height, tip)
|
||||||
|
interface = interface if interface.lower() != 'localhost' else '127.0.0.1'
|
||||||
|
await loop.create_datagram_endpoint(lambda: self._protocol, (interface, port))
|
||||||
|
log.info("started udp status server on %s:%i", interface, port)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
if self._protocol:
|
||||||
|
self._protocol.close()
|
||||||
|
self._protocol = None
|
||||||
|
|
||||||
|
def set_unavailable(self):
|
||||||
|
self._protocol.set_unavailable()
|
||||||
|
|
||||||
|
def set_available(self):
|
||||||
|
self._protocol.set_available()
|
||||||
|
|
||||||
|
def set_height(self, height: int, tip: bytes):
|
||||||
|
self._protocol.set_height(height, tip)
|
||||||
|
|
||||||
|
|
||||||
|
class SPVStatusClientProtocol(asyncio.DatagramProtocol):
|
||||||
|
PROTOCOL_VERSION = 1
|
||||||
|
|
||||||
|
def __init__(self, responses: asyncio.Queue):
|
||||||
|
super().__init__()
|
||||||
|
self.transport: Optional[asyncio.transports.DatagramTransport] = None
|
||||||
|
self.responses = responses
|
||||||
|
self._ping_packet = SPVPing.make(self.PROTOCOL_VERSION)
|
||||||
|
|
||||||
|
def datagram_received(self, data: bytes, addr: Tuple[str, int]):
|
||||||
|
try:
|
||||||
|
self.responses.put_nowait(((addr, perf_counter()), SPVPong.decode(data)))
|
||||||
|
except (ValueError, struct.error, AttributeError, TypeError, RuntimeError):
|
||||||
|
return
|
||||||
|
|
||||||
|
def connection_made(self, transport) -> None:
|
||||||
|
self.transport = transport
|
||||||
|
|
||||||
|
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||||
|
self.transport = None
|
||||||
|
log.info("closed udp spv server selection client")
|
||||||
|
|
||||||
|
def ping(self, server: Tuple[str, int]):
|
||||||
|
self.transport.sendto(self._ping_packet, server)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
# log.info("close udp client")
|
||||||
|
if self.transport:
|
||||||
|
self.transport.close()
|
Loading…
Reference in a new issue