diff --git a/lbry/schema/attrs.py b/lbry/schema/attrs.py index 69f81f451..4b0973ac3 100644 --- a/lbry/schema/attrs.py +++ b/lbry/schema/attrs.py @@ -32,6 +32,17 @@ def calculate_sha384_file_hash(file_path): return sha384.digest() +def country_int_to_str(country: int) -> str: + r = LocationMessage.Country.Name(country) + return r[1:] if r.startswith('R') else r + + +def country_str_to_int(country: str) -> int: + if len(country) == 3: + country = 'R' + country + return LocationMessage.Country.Value(country) + + class Dimmensional(Metadata): __slots__ = () @@ -423,14 +434,11 @@ class Language(Metadata): @property def region(self) -> str: if self.message.region: - r = LocationMessage.Country.Name(self.message.region) - return r[1:] if r.startswith('R') else r + return country_int_to_str(self.message.region) @region.setter def region(self, region: str): - if len(region) == 3: - region = 'R'+region - self.message.region = LocationMessage.Country.Value(region) + self.message.region = country_str_to_int(region) class LanguageList(BaseMessageList[Language]): diff --git a/lbry/wallet/server/env.py b/lbry/wallet/server/env.py index 8917803ba..8739800bb 100644 --- a/lbry/wallet/server/env.py +++ b/lbry/wallet/server/env.py @@ -74,6 +74,7 @@ class Env: self.anon_logs = self.boolean('ANON_LOGS', False) self.log_sessions = self.integer('LOG_SESSIONS', 3600) self.allow_lan_udp = self.boolean('ALLOW_LAN_UDP', False) + self.country = self.default('COUNTRY', 'US') # Peer discovery self.peer_discovery = self.peer_discovery_enum() self.peer_announce = self.boolean('PEER_ANNOUNCE', True) diff --git a/lbry/wallet/server/server.py b/lbry/wallet/server/server.py index 956f06ed2..cbec5c93b 100644 --- a/lbry/wallet/server/server.py +++ b/lbry/wallet/server/server.py @@ -114,7 +114,7 @@ class Server: await self.start_prometheus() if self.env.udp_port: await self.bp.status_server.start( - 0, bytes.fromhex(self.bp.coin.GENESIS_HASH)[::-1], + 0, bytes.fromhex(self.bp.coin.GENESIS_HASH)[::-1], self.env.country, self.env.host, self.env.udp_port, self.env.allow_lan_udp ) await _start_cancellable(self.bp.fetch_and_process_blocks) diff --git a/lbry/wallet/server/udp.py b/lbry/wallet/server/udp.py index 1357beb04..57e9177c1 100644 --- a/lbry/wallet/server/udp.py +++ b/lbry/wallet/server/udp.py @@ -4,6 +4,7 @@ from time import perf_counter import logging from typing import Optional, Tuple, NamedTuple from lbry.utils import LRUCache, is_valid_public_ipv4 +from lbry.schema.attrs import country_str_to_int, country_int_to_str # from prometheus_client import Counter @@ -13,6 +14,9 @@ _MAGIC = 1446058291 # genesis blocktime (which is actually wrong) _PAD_BYTES = b'\x00' * 64 +PROTOCOL_VERSION = 1 + + class SPVPing(NamedTuple): magic: int protocol_version: int @@ -22,8 +26,8 @@ class SPVPing(NamedTuple): return struct.pack(b'!lB64s', *self) @staticmethod - def make(protocol_version=1) -> bytes: - return SPVPing(_MAGIC, protocol_version, _PAD_BYTES).encode() + def make() -> bytes: + return SPVPing(_MAGIC, PROTOCOL_VERSION, _PAD_BYTES).encode() @classmethod def decode(cls, packet: bytes): @@ -39,14 +43,27 @@ class SPVPong(NamedTuple): height: int tip: bytes source_address_raw: bytes + country: int def encode(self): - return struct.pack(b'!BBl32s4s', *self) + return struct.pack(b'!BBL32s4sH', *self) @staticmethod - def make(height: int, tip: bytes, flags: int, protocol_version: int = 1) -> bytes: - # note: drops the last 4 bytes so the result can be cached and have addresses added to it as needed - return SPVPong(protocol_version, flags, height, tip, b'\x00\x00\x00\x00').encode()[:38] + def encode_address(address: str): + return bytes(int(b) for b in address.split(".")) + + @classmethod + def make(cls, flags: int, height: int, tip: bytes, source_address: str, country: str) -> bytes: + return SPVPong( + PROTOCOL_VERSION, flags, height, tip, + cls.encode_address(source_address), + country_str_to_int(country) + ).encode() + + @classmethod + def make_sans_source_address(cls, flags: int, height: int, tip: bytes, country: str) -> Tuple[bytes, bytes]: + pong = cls.make(flags, height, tip, '0.0.0.0', country) + return pong[:38], pong[42:] @classmethod def decode(cls, packet: bytes): @@ -60,23 +77,30 @@ class SPVPong(NamedTuple): def ip_address(self) -> str: return ".".join(map(str, self.source_address_raw)) + @property + def country_name(self): + return country_int_to_str(self.country) + def __repr__(self) -> str: return f"SPVPong(external_ip={self.ip_address}, version={self.protocol_version}, " \ f"available={'True' if self.flags & 1 > 0 else 'False'}," \ - f" height={self.height}, tip={self.tip[::-1].hex()})" + f" height={self.height}, tip={self.tip[::-1].hex()}, country={self.country_name})" class SPVServerStatusProtocol(asyncio.DatagramProtocol): - PROTOCOL_VERSION = 1 - def __init__(self, height: int, tip: bytes, throttle_cache_size: int = 1024, throttle_reqs_per_sec: int = 10, - allow_localhost: bool = False, allow_lan: bool = False): + def __init__( + self, height: int, tip: bytes, country: str, + throttle_cache_size: int = 1024, throttle_reqs_per_sec: int = 10, + allow_localhost: bool = False, allow_lan: bool = False + ): super().__init__() self.transport: Optional[asyncio.transports.DatagramTransport] = None self._height = height self._tip = tip self._flags = 0 - self._cached_response = None + self._country = country + self._left_cache = self._right_cache = None self.update_cached_response() self._throttle = LRUCache(throttle_cache_size) self._should_log = LRUCache(throttle_cache_size) @@ -85,7 +109,9 @@ class SPVServerStatusProtocol(asyncio.DatagramProtocol): self._allow_lan = allow_lan def update_cached_response(self): - self._cached_response = SPVPong.make(self._height, self._tip, self._flags, self.PROTOCOL_VERSION) + self._left_cache, self._right_cache = SPVPong.make_sans_source_address( + self._flags, max(0, self._height), self._tip, self._country + ) def set_unavailable(self): self._flags &= 0b11111110 @@ -112,7 +138,7 @@ class SPVServerStatusProtocol(asyncio.DatagramProtocol): return False def make_pong(self, host): - return self._cached_response + bytes(int(b) for b in host.split(".")) + return self._left_cache + SPVPong.encode_address(host) + self._right_cache def datagram_received(self, data: bytes, addr: Tuple[str, int]): if self.should_throttle(addr[0]): @@ -144,13 +170,13 @@ class StatusServer: def __init__(self): self._protocol: Optional[SPVServerStatusProtocol] = None - async def start(self, height: int, tip: bytes, interface: str, port: int, allow_lan: bool = False): + async def start(self, height: int, tip: bytes, country: str, interface: str, port: int, allow_lan: bool = False): if self.is_running: return loop = asyncio.get_event_loop() interface = interface if interface.lower() != 'localhost' else '127.0.0.1' self._protocol = SPVServerStatusProtocol( - height, tip, allow_localhost=interface == '127.0.0.1', allow_lan=allow_lan + height, tip, country, allow_localhost=interface == '127.0.0.1', allow_lan=allow_lan ) await loop.create_datagram_endpoint(lambda: self._protocol, (interface, port)) log.info("started udp status server on %s:%i", interface, port) @@ -178,13 +204,12 @@ class StatusServer: class SPVStatusClientProtocol(asyncio.DatagramProtocol): - PROTOCOL_VERSION = 1 def __init__(self, responses: asyncio.Queue): super().__init__() self.transport: Optional[asyncio.transports.DatagramTransport] = None self.responses = responses - self._ping_packet = SPVPing.make(self.PROTOCOL_VERSION) + self._ping_packet = SPVPing.make() def datagram_received(self, data: bytes, addr: Tuple[str, int]): try: