From f0d8fb8f1a7c68f231b0d8f49c5ce5530b411b52 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Thu, 21 Jan 2021 16:08:33 -0500 Subject: [PATCH] add UDP based ping protocol for spv servers --- lbry/utils.py | 2 + lbry/wallet/server/block_processor.py | 7 +- lbry/wallet/server/server.py | 3 + lbry/wallet/server/session.py | 5 +- lbry/wallet/server/udp.py | 192 ++++++++++++++++++++++++++ 5 files changed, 207 insertions(+), 2 deletions(-) create mode 100644 lbry/wallet/server/udp.py diff --git a/lbry/utils.py b/lbry/utils.py index 7b2b72886..456eb0811 100644 --- a/lbry/utils.py +++ b/lbry/utils.py @@ -192,6 +192,8 @@ def cache_concurrent(async_fn): async def resolve_host(url: str, port: int, proto: str) -> str: if proto not in ['udp', 'tcp']: raise Exception("invalid protocol") + if url.lower() == 'localhost': + return '127.0.0.1' try: if ipaddress.ip_address(url): return url diff --git a/lbry/wallet/server/block_processor.py b/lbry/wallet/server/block_processor.py index df40d7ea6..caaa62a29 100644 --- a/lbry/wallet/server/block_processor.py +++ b/lbry/wallet/server/block_processor.py @@ -11,6 +11,7 @@ from lbry.wallet.server.daemon import DaemonError from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN from lbry.wallet.server.util import chunks, class_logger from lbry.wallet.server.leveldb import FlushData +from lbry.wallet.server.udp import StatusServer class Prefetcher: @@ -185,6 +186,7 @@ class BlockProcessor: self.search_cache = {} self.history_cache = {} + self.status_server = StatusServer() async def run_in_thread_with_lock(self, func, *args): # Run in a thread to prevent blocking. Shielded so that @@ -221,6 +223,7 @@ class BlockProcessor: processed_time = time.perf_counter() - start self.block_count_metric.set(self.height) self.block_update_time_metric.observe(processed_time) + self.status_server.set_height(self.db.fs_height, self.db.db_tip) if not self.db.first_sync: s = '' if len(blocks) == 1 else 's' self.logger.info('processed {:,d} block{} in {:.1f}s'.format(len(blocks), s, processed_time)) @@ -682,9 +685,11 @@ class BlockProcessor: disk before exiting, as otherwise a significant amount of work could be lost. """ + self._caught_up_event = caught_up_event try: await self._first_open_dbs() + self.status_server.set_height(self.db.fs_height, self.db.db_tip) await asyncio.wait([ self.prefetcher.main_loop(self.height), self._process_prefetched_blocks() @@ -695,6 +700,7 @@ class BlockProcessor: self.logger.exception("Block processing failed!") raise finally: + self.status_server.stop() # Shut down block processing self.logger.info('flushing to DB for a clean shutdown...') await self.flush(True) @@ -714,7 +720,6 @@ class BlockProcessor: class Timer: - def __init__(self, name): self.name = name self.total = 0 diff --git a/lbry/wallet/server/server.py b/lbry/wallet/server/server.py index dad07ac0c..6e997c645 100644 --- a/lbry/wallet/server/server.py +++ b/lbry/wallet/server/server.py @@ -111,7 +111,10 @@ class Server: return _flag.wait() await self.start_prometheus() + await self.bp.status_server.start(0, bytes.fromhex(self.bp.coin.GENESIS_HASH)[::-1] + , self.env.host, self.env.tcp_port) await _start_cancellable(self.bp.fetch_and_process_blocks) + await self.db.populate_header_merkle_cache() await _start_cancellable(self.mempool.keep_synchronized) await _start_cancellable(self.session_mgr.serve, self.notifications) diff --git a/lbry/wallet/server/session.py b/lbry/wallet/server/session.py index ef41b869b..41a535e1b 100644 --- a/lbry/wallet/server/session.py +++ b/lbry/wallet/server/session.py @@ -247,11 +247,12 @@ class SessionManager: async def _manage_servers(self): paused = False max_sessions = self.env.max_sessions - low_watermark = max_sessions * 19 // 20 + low_watermark = int(max_sessions * 0.95) while True: await self.session_event.wait() self.session_event.clear() if not paused and len(self.sessions) >= max_sessions: + self.bp.status_server.set_unavailable() self.logger.info(f'maximum sessions {max_sessions:,d} ' f'reached, stopping new connections until ' f'count drops to {low_watermark:,d}') @@ -260,6 +261,7 @@ class SessionManager: # Start listening for incoming connections if paused and # session count has fallen if paused and len(self.sessions) <= low_watermark: + self.bp.status_server.set_available() self.logger.info('resuming listening for incoming connections') await self._start_external_servers() paused = False @@ -572,6 +574,7 @@ class SessionManager: await self.start_other() await self._start_external_servers() server_listening_event.set() + self.bp.status_server.set_available() # Peer discovery should start after the external servers # because we connect to ourself await asyncio.wait([ diff --git a/lbry/wallet/server/udp.py b/lbry/wallet/server/udp.py new file mode 100644 index 000000000..eefda85a0 --- /dev/null +++ b/lbry/wallet/server/udp.py @@ -0,0 +1,192 @@ +import asyncio +import struct +from time import perf_counter +import logging +from typing import Optional, Tuple, NamedTuple +from lbry.utils import LRUCache +# from prometheus_client import Counter + + +log = logging.getLogger(__name__) +_MAGIC = 1446058291 # genesis blocktime (which is actually wrong) +# ping_count_metric = Counter("ping_count", "Number of pings received", namespace='wallet_server_status') +_PAD_BYTES = b'\x00' * 64 + + +class SPVPing(NamedTuple): + magic: int + protocol_version: int + pad_bytes: bytes + + def encode(self): + return struct.pack(b'!lB64s', *self) + + @staticmethod + def make(protocol_version=1) -> bytes: + return SPVPing(_MAGIC, protocol_version, _PAD_BYTES).encode() + + @classmethod + def decode(cls, packet: bytes): + decoded = cls(*struct.unpack(b'!lB64s', packet[:69])) + if decoded.magic != _MAGIC: + raise ValueError("invalid magic bytes") + return decoded + + +class SPVPong(NamedTuple): + protocol_version: int + flags: int + height: int + tip: bytes + source_address_raw: bytes + + def encode(self): + return struct.pack(b'!BBl32s4s', *self) + + @staticmethod + def make(height: int, tip: bytes, flags: int, protocol_version: int = 1) -> bytes: + # note: drops the last 4 bytes so the result can be cached and have addresses added to it as needed + return SPVPong(protocol_version, flags, height, tip, b'\x00\x00\x00\x00').encode()[:38] + + @classmethod + def decode(cls, packet: bytes): + return cls(*struct.unpack(b'!BBl32s4s', packet[:42])) + + @property + def available(self) -> bool: + return (self.flags & 0b00000001) > 0 + + @property + def ip_address(self) -> str: + return ".".join(map(str, self.source_address_raw)) + + def __repr__(self) -> str: + return f"SPVPong(external_ip={self.ip_address}, version={self.protocol_version}, " \ + f"available={'True' if self.flags & 1 > 0 else 'False'}," \ + f" height={self.height}, tip={self.tip[::-1].hex()})" + + +class SPVServerStatusProtocol(asyncio.DatagramProtocol): + PROTOCOL_VERSION = 1 + + def __init__(self, height: int, tip: bytes, throttle_cache_size: int = 1024, throttle_reqs_per_sec: int = 10): + super().__init__() + self.transport: Optional[asyncio.transports.DatagramTransport] = None + self._height = height + self._tip = tip + self._flags = 0 + self._cached_response = None + self.update_cached_response() + self._throttle = LRUCache(throttle_cache_size) + self._should_log = LRUCache(throttle_cache_size) + self._min_delay = 1 / throttle_reqs_per_sec + + def update_cached_response(self): + self._cached_response = SPVPong.make(self._height, self._tip, self._flags, self.PROTOCOL_VERSION) + + def set_unavailable(self): + self._flags &= 0b11111110 + self.update_cached_response() + + def set_available(self): + self._flags |= 0b00000001 + self.update_cached_response() + + def set_height(self, height: int, tip: bytes): + self._height, self._tip = height, tip + self.update_cached_response() + + def should_throttle(self, host: str): + now = perf_counter() + last_requested = self._throttle.get(host, default=0) + self._throttle[host] = now + if now - last_requested < self._min_delay: + log_cnt = self._should_log.get(host, default=0) + 1 + if log_cnt % 100 == 0: + log.warning("throttle spv status to %s", host) + self._should_log[host] = log_cnt + return True + return False + + def make_pong(self, host): + return self._cached_response + bytes(int(b) for b in host.split(".")) + + def datagram_received(self, data: bytes, addr: Tuple[str, int]): + if self.should_throttle(addr[0]): + return + try: + SPVPing.decode(data) + except (ValueError, struct.error, AttributeError, TypeError): + # log.exception("derp") + return + self.transport.sendto(self.make_pong(addr[0]), addr) + # ping_count_metric.inc() + + def connection_made(self, transport) -> None: + self.transport = transport + + def connection_lost(self, exc: Optional[Exception]) -> None: + self.transport = None + + def close(self): + if self.transport: + self.transport.close() + + +class StatusServer: + def __init__(self): + self._protocol: Optional[SPVServerStatusProtocol] = None + + async def start(self, height: int, tip: bytes, interface: str, port: int): + if self._protocol: + return + loop = asyncio.get_event_loop() + self._protocol = SPVServerStatusProtocol(height, tip) + interface = interface if interface.lower() != 'localhost' else '127.0.0.1' + await loop.create_datagram_endpoint(lambda: self._protocol, (interface, port)) + log.info("started udp status server on %s:%i", interface, port) + + def stop(self): + if self._protocol: + self._protocol.close() + self._protocol = None + + def set_unavailable(self): + self._protocol.set_unavailable() + + def set_available(self): + self._protocol.set_available() + + def set_height(self, height: int, tip: bytes): + self._protocol.set_height(height, tip) + + +class SPVStatusClientProtocol(asyncio.DatagramProtocol): + PROTOCOL_VERSION = 1 + + def __init__(self, responses: asyncio.Queue): + super().__init__() + self.transport: Optional[asyncio.transports.DatagramTransport] = None + self.responses = responses + self._ping_packet = SPVPing.make(self.PROTOCOL_VERSION) + + def datagram_received(self, data: bytes, addr: Tuple[str, int]): + try: + self.responses.put_nowait(((addr, perf_counter()), SPVPong.decode(data))) + except (ValueError, struct.error, AttributeError, TypeError, RuntimeError): + return + + def connection_made(self, transport) -> None: + self.transport = transport + + def connection_lost(self, exc: Optional[Exception]) -> None: + self.transport = None + log.info("closed udp spv server selection client") + + def ping(self, server: Tuple[str, int]): + self.transport.sendto(self._ping_packet, server) + + def close(self): + # log.info("close udp client") + if self.transport: + self.transport.close()