diff --git a/lbry/conf.py b/lbry/conf.py index a77719d5d..96b4b4389 100644 --- a/lbry/conf.py +++ b/lbry/conf.py @@ -63,6 +63,18 @@ class Setting(Generic[T]): for location in obj.modify_order: location[self.name] = val + def is_set(self, obj: 'BaseConfig') -> bool: + for location in obj.search_order: + if self.name in location: + return True + return False + + def is_set_to_default(self, obj: 'BaseConfig') -> bool: + for location in obj.search_order: + if self.name in location: + return location[self.name] == self.default + return False + def validate(self, value): raise NotImplementedError() @@ -577,6 +589,9 @@ class CLIConfig(TranscodeConfig): class Config(CLIConfig): + + jurisdiction = String("Limit interactions to wallet server in this jurisdiction.") + # directories data_dir = Path("Directory path to store blobs.", metavar='DIR') download_dir = Path( diff --git a/lbry/extras/daemon/components.py b/lbry/extras/daemon/components.py index 37a495546..e55ec342d 100644 --- a/lbry/extras/daemon/components.py +++ b/lbry/extras/daemon/components.py @@ -138,7 +138,7 @@ class WalletComponent(Component): 'availability': session.available, } for session in sessions ], - 'known_servers': len(self.wallet_manager.ledger.network.config['default_servers']), + 'known_servers': len(self.wallet_manager.ledger.network.known_hubs), 'available_servers': 1 if is_connected else 0 } diff --git a/lbry/testcase.py b/lbry/testcase.py index f482fe2c6..9b65e2fe8 100644 --- a/lbry/testcase.py +++ b/lbry/testcase.py @@ -378,6 +378,7 @@ class CommandTestCase(IntegrationTestCase): await daemon.stop() async def add_daemon(self, wallet_node=None, seed=None): + start_wallet_node = False if wallet_node is None: wallet_node = WalletNode( self.wallet_node.manager_class, @@ -385,8 +386,7 @@ class CommandTestCase(IntegrationTestCase): port=self.extra_wallet_node_port ) self.extra_wallet_node_port += 1 - await wallet_node.start(self.conductor.spv_node, seed=seed) - self.extra_wallet_nodes.append(wallet_node) + start_wallet_node = True upload_dir = os.path.join(wallet_node.data_path, 'uploads') os.mkdir(upload_dir) @@ -414,8 +414,13 @@ class CommandTestCase(IntegrationTestCase): ] if self.skip_libtorrent: conf.components_to_skip.append(LIBTORRENT_COMPONENT) - wallet_node.manager.config = conf - wallet_node.manager.ledger.config['known_hubs'] = conf.known_hubs + + if start_wallet_node: + await wallet_node.start(self.conductor.spv_node, seed=seed, config=conf) + self.extra_wallet_nodes.append(wallet_node) + else: + wallet_node.manager.config = conf + wallet_node.manager.ledger.config['known_hubs'] = conf.known_hubs def wallet_maker(component_manager): wallet_component = WalletComponent(component_manager) diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py index 44c27e37c..90f78bffb 100644 --- a/lbry/wallet/manager.py +++ b/lbry/wallet/manager.py @@ -8,7 +8,7 @@ from decimal import Decimal from typing import List, Type, MutableSequence, MutableMapping, Optional from lbry.error import KeyFeeAboveMaxAllowedError -from lbry.conf import Config +from lbry.conf import Config, NOT_SET from .dewies import dewies_to_lbc from .account import Account @@ -184,6 +184,7 @@ class WalletManager: 'auto_connect': True, 'default_servers': config.lbryum_servers, 'known_hubs': config.known_hubs, + 'jurisdiction': config.jurisdiction, 'data_path': config.wallet_dir, 'tx_cache_size': config.transaction_cache_size } @@ -196,6 +197,10 @@ class WalletManager: os.path.join(wallets_directory, 'default_wallet') ) + if Config.lbryum_servers.is_set_to_default(config): + with config.update_config() as c: + c.lbryum_servers = NOT_SET + manager = cls.from_config({ 'ledgers': {ledger_id: ledger_config}, 'wallets': [ @@ -226,10 +231,14 @@ class WalletManager: async def reset(self): self.ledger.config = { 'auto_connect': True, - 'default_servers': self.config.lbryum_servers, + 'explicit_servers': [], + 'default_servers': Config.lbryum_servers.default, 'known_hubs': self.config.known_hubs, + 'jurisdiction': self.config.jurisdiction, 'data_path': self.config.wallet_dir, } + if Config.lbryum_servers.is_set(self.config): + self.ledger.config['explicit_servers'] = self.config.lbryum_servers await self.ledger.stop() await self.ledger.start() diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py index 32b1d70f9..2d94b3a7e 100644 --- a/lbry/wallet/network.py +++ b/lbry/wallet/network.py @@ -186,7 +186,13 @@ class Network: @property def known_hubs(self): - return self.config.get('known_hubs', KnownHubsList()) + if 'known_hubs' not in self.config: + return KnownHubsList() + return self.config['known_hubs'] + + @property + def jurisdiction(self): + return self.config.get("jurisdiction") def disconnect(self): if self._keepalive_task and not self._keepalive_task.done(): @@ -226,18 +232,23 @@ class Network: log.exception("error looking up dns for spv server %s:%i", server, port) # accumulate the dns results - hubs = self.known_hubs if self.known_hubs else self.config['default_servers'] + if self.config['explicit_servers']: + hubs = self.config['explicit_servers'] + elif self.known_hubs: + hubs = self.known_hubs + else: + hubs = self.config['default_servers'] await asyncio.gather(*(resolve_spv(server, port) for (server, port) in hubs)) - return hostname_to_ip, ip_to_hostnames + return hubs, hostname_to_ip, ip_to_hostnames async def get_n_fastest_spvs(self, timeout=3.0) -> Dict[Tuple[str, int], Optional[SPVPong]]: loop = asyncio.get_event_loop() pong_responses = asyncio.Queue() connection = SPVStatusClientProtocol(pong_responses) sent_ping_timestamps = {} - _, ip_to_hostnames = await self.resolve_spv_dns() + hubs, _, ip_to_hostnames = await self.resolve_spv_dns() n = len(ip_to_hostnames) - log.info("%i possible spv servers to try (%i urls in config)", n, len(self.config['default_servers'])) + log.info("%i possible spv servers to try (%i urls in config)", n, len(self.config['explicit_servers'])) pongs = {} try: await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0)) @@ -246,6 +257,7 @@ class Network: for server in ip_to_hostnames: connection.ping(server) sent_ping_timestamps[server] = perf_counter() + known_hubs = self.known_hubs while len(pongs) < n: (remote, ts), pong = await asyncio.wait_for(pong_responses.get(), timeout - (perf_counter() - start)) latency = ts - start @@ -253,6 +265,9 @@ class Network: '/'.join(ip_to_hostnames[remote]), remote[1], round(latency * 1000, 2), pong.available, pong.height) + known_hubs.hubs.setdefault((ip_to_hostnames[remote][0], remote[1]), {}).update( + {"country": pong.country_name} + ) if pong.available: pongs[(ip_to_hostnames[remote][0], remote[1])] = pong return pongs @@ -267,13 +282,17 @@ class Network: random_server = random.choice(list(ip_to_hostnames.keys())) host, port = random_server log.warning("trying fallback to randomly selected spv: %s:%i", host, port) + known_hubs.hubs.setdefault((host, port), {}) return {(host, port): None} finally: connection.close() async def connect_to_fastest(self) -> Optional[ClientSession]: fastest_spvs = await self.get_n_fastest_spvs() - for (host, port) in fastest_spvs: + for (host, port), pong in fastest_spvs.items(): + if (pong is not None and self.jurisdiction is not None) and \ + (pong.country_name != self.jurisdiction): + continue client = ClientSession(network=self, server=(host, port)) try: await client.create_connection() @@ -305,7 +324,7 @@ class Network: features = await client.send_request('server.features', []) self.client, self.server_features = client, features log.debug("discover other hubs %s:%i", *client.server) - self._update_hubs(await client.send_request('server.peers.subscribe', [])) + await self._update_hubs(await client.send_request('server.peers.subscribe', [])) log.info("subscribe to headers %s:%i", *client.server) self._update_remote_height((await self.subscribe_headers(),)) self._on_connected_controller.add(True) @@ -375,8 +394,8 @@ class Network: def _update_remote_height(self, header_args): self.remote_height = header_args[0]["height"] - def _update_hubs(self, hubs): - if hubs: + async def _update_hubs(self, hubs): + if hubs and hubs != ['']: try: if self.known_hubs.add_hubs(hubs): self.known_hubs.save() diff --git a/lbry/wallet/orchstr8/node.py b/lbry/wallet/orchstr8/node.py index 916edde6e..ed634628c 100644 --- a/lbry/wallet/orchstr8/node.py +++ b/lbry/wallet/orchstr8/node.py @@ -17,6 +17,7 @@ import lbry from lbry.wallet.server.server import Server from lbry.wallet.server.env import Env from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent +from lbry.conf import KnownHubsList, Config log = logging.getLogger(__name__) @@ -107,7 +108,8 @@ class Conductor: class WalletNode: def __init__(self, manager_class: Type[WalletManager], ledger_class: Type[Ledger], - verbose: bool = False, port: int = 5280, default_seed: str = None) -> None: + verbose: bool = False, port: int = 5280, default_seed: str = None, + data_path: str = None) -> None: self.manager_class = manager_class self.ledger_class = ledger_class self.verbose = verbose @@ -115,12 +117,12 @@ class WalletNode: self.ledger: Optional[Ledger] = None self.wallet: Optional[Wallet] = None self.account: Optional[Account] = None - self.data_path: Optional[str] = None + self.data_path: str = data_path or tempfile.mkdtemp() self.port = port self.default_seed = default_seed + self.known_hubs = KnownHubsList() - async def start(self, spv_node: 'SPVNode', seed=None, connect=True): - self.data_path = tempfile.mkdtemp() + async def start(self, spv_node: 'SPVNode', seed=None, connect=True, config=None): wallets_dir = os.path.join(self.data_path, 'wallets') os.mkdir(wallets_dir) wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json') @@ -130,12 +132,15 @@ class WalletNode: 'ledgers': { self.ledger_class.get_id(): { 'api_port': self.port, - 'default_servers': [(spv_node.hostname, spv_node.port)], - 'data_path': self.data_path + 'explicit_servers': [(spv_node.hostname, spv_node.port)], + 'default_servers': Config.lbryum_servers.default, + 'data_path': self.data_path, + 'known_hubs': config.known_hubs if config else KnownHubsList() } }, 'wallets': [wallet_file_name] }) + self.manager.config = config self.ledger = self.manager.ledgers[self.ledger_class] self.wallet = self.manager.default_wallet if not self.wallet: @@ -172,6 +177,7 @@ class SPVNode: self.udp_port = self.port self.session_timeout = 600 self.rpc_port = '0' # disabled by default + self.stopped = False async def start(self, blockchain_node: 'BlockchainNode', extraconf=None): self.data_path = tempfile.mkdtemp() @@ -201,10 +207,13 @@ class SPVNode: await self.server.start() async def stop(self, cleanup=True): + if self.stopped: + return try: await self.server.db.search_index.delete_index() await self.server.db.search_index.stop() await self.server.stop() + self.stopped = True finally: cleanup and self.cleanup() diff --git a/lbry/wallet/server/udp.py b/lbry/wallet/server/udp.py index 57e9177c1..c5520ac6b 100644 --- a/lbry/wallet/server/udp.py +++ b/lbry/wallet/server/udp.py @@ -37,6 +37,9 @@ class SPVPing(NamedTuple): return decoded +PONG_ENCODING = b'!BBL32s4sH' + + class SPVPong(NamedTuple): protocol_version: int flags: int @@ -46,7 +49,7 @@ class SPVPong(NamedTuple): country: int def encode(self): - return struct.pack(b'!BBL32s4sH', *self) + return struct.pack(PONG_ENCODING, *self) @staticmethod def encode_address(address: str): @@ -67,7 +70,7 @@ class SPVPong(NamedTuple): @classmethod def decode(cls, packet: bytes): - return cls(*struct.unpack(b'!BBl32s4s', packet[:42])) + return cls(*struct.unpack(PONG_ENCODING, packet[:44])) @property def available(self) -> bool: diff --git a/tests/integration/blockchain/test_wallet_server_sessions.py b/tests/integration/blockchain/test_wallet_server_sessions.py index d3f3e131f..67ff86a48 100644 --- a/tests/integration/blockchain/test_wallet_server_sessions.py +++ b/tests/integration/blockchain/test_wallet_server_sessions.py @@ -118,13 +118,21 @@ class TestESSync(CommandTestCase): class TestHubDiscovery(CommandTestCase): async def test_hub_discovery(self): - final_node = SPVNode(self.conductor.spv_module, node_number=2) - await final_node.start(self.blockchain) - self.addCleanup(final_node.stop) - final_node_host = f"{final_node.hostname}:{final_node.port}" + us_final_node = SPVNode(self.conductor.spv_module, node_number=2) + await us_final_node.start(self.blockchain, extraconf={"COUNTRY": "US"}) + self.addCleanup(us_final_node.stop) + final_node_host = f"{us_final_node.hostname}:{us_final_node.port}" - relay_node = SPVNode(self.conductor.spv_module, node_number=3) - await relay_node.start(self.blockchain, extraconf={"PEER_HUBS": final_node_host}) + kp_final_node = SPVNode(self.conductor.spv_module, node_number=3) + await kp_final_node.start(self.blockchain, extraconf={"COUNTRY": "KP"}) + self.addCleanup(kp_final_node.stop) + kp_final_node_host = f"{kp_final_node.hostname}:{kp_final_node.port}" + + relay_node = SPVNode(self.conductor.spv_module, node_number=4) + await relay_node.start(self.blockchain, extraconf={ + "COUNTRY": "FR", + "PEER_HUBS": ",".join([kp_final_node_host, final_node_host]) + }) relay_node_host = f"{relay_node.hostname}:{relay_node.port}" self.addCleanup(relay_node.stop) @@ -134,23 +142,50 @@ class TestHubDiscovery(CommandTestCase): ('127.0.0.1', 50002) ) - # connect to relay hub which will tell us about the final hub + # connect to relay hub which will tell us about the final hubs self.daemon.jsonrpc_settings_set('lbryum_servers', [relay_node_host]) await self.daemon.jsonrpc_wallet_reconnect() - self.assertEqual(list(self.daemon.conf.known_hubs), [(final_node.hostname, final_node.port)]) + self.assertEqual( + self.daemon.conf.known_hubs.filter(), { + (relay_node.hostname, relay_node.port): {"country": "FR"}, + (us_final_node.hostname, us_final_node.port): {}, # discovered from relay but not contacted yet + (kp_final_node.hostname, kp_final_node.port): {}, # discovered from relay but not contacted yet + } + ) self.assertEqual( self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', relay_node.port) ) - # use known_hubs to connect to final hub + # use known_hubs to connect to final US hub + self.daemon.jsonrpc_settings_clear('lbryum_servers') + self.daemon.conf.jurisdiction = "US" await self.daemon.jsonrpc_wallet_reconnect() self.assertEqual( - self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', final_node.port) + self.daemon.conf.known_hubs.filter(), { + (relay_node.hostname, relay_node.port): {"country": "FR"}, + (us_final_node.hostname, us_final_node.port): {"country": "US"}, + (kp_final_node.hostname, kp_final_node.port): {"country": "KP"}, + } + ) + self.assertEqual( + self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', us_final_node.port) ) - final_node.server.session_mgr._notify_peer('127.0.0.1:9988') - self.assertEqual(list(self.daemon.conf.known_hubs), [(final_node.hostname, final_node.port)]) + # connection to KP jurisdiction + self.daemon.conf.jurisdiction = "KP" + await self.daemon.jsonrpc_wallet_reconnect() + self.assertEqual( + self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', kp_final_node.port) + ) + + kp_final_node.server.session_mgr._notify_peer('127.0.0.1:9988') await self.daemon.ledger.network.on_hub.first - self.assertEqual(list(self.daemon.conf.known_hubs), [ - (final_node.hostname, final_node.port), ('127.0.0.1', 9988) - ]) + await asyncio.sleep(0.5) # wait for above event to be processed by other listeners + self.assertEqual( + self.daemon.conf.known_hubs.filter(), { + (relay_node.hostname, relay_node.port): {"country": "FR"}, + (us_final_node.hostname, us_final_node.port): {"country": "US"}, + (kp_final_node.hostname, kp_final_node.port): {"country": "KP"}, + ('127.0.0.1', 9988): {} + } + ) diff --git a/tests/unit/test_conf.py b/tests/unit/test_conf.py index 3650f8844..e241dea5f 100644 --- a/tests/unit/test_conf.py +++ b/tests/unit/test_conf.py @@ -47,6 +47,26 @@ class ConfigurationTests(unittest.TestCase): c.persisted = {} self.assertEqual(c.test_str, 'the default') + def test_is_set(self): + c = TestConfig() + self.assertEqual(c.test_str, 'the default') + self.assertFalse(TestConfig.test_str.is_set(c)) + c.test_str = 'new value' + self.assertEqual(c.test_str, 'new value') + self.assertTrue(TestConfig.test_str.is_set(c)) + + def test_is_set_to_default(self): + c = TestConfig() + self.assertEqual(TestConfig.test_str.default, 'the default') + self.assertFalse(TestConfig.test_str.is_set(c)) + self.assertFalse(TestConfig.test_str.is_set_to_default(c)) + c.test_str = 'new value' + self.assertTrue(TestConfig.test_str.is_set(c)) + self.assertFalse(TestConfig.test_str.is_set_to_default(c)) + c.test_str = 'the default' + self.assertTrue(TestConfig.test_str.is_set(c)) + self.assertTrue(TestConfig.test_str.is_set_to_default(c)) + def test_arguments(self): parser = argparse.ArgumentParser() TestConfig.contribute_to_argparse(parser)