diff --git a/lbry/conf.py b/lbry/conf.py index 249dce37f..e82e59cbc 100644 --- a/lbry/conf.py +++ b/lbry/conf.py @@ -1,8 +1,8 @@ import os import re import sys -import typing import logging +from typing import List, Tuple, Union, TypeVar, Generic, Optional from argparse import ArgumentParser from contextlib import contextmanager from appdirs import user_data_dir, user_config_dir @@ -15,7 +15,7 @@ log = logging.getLogger(__name__) NOT_SET = type('NOT_SET', (object,), {}) # pylint: disable=invalid-name -T = typing.TypeVar('T') +T = TypeVar('T') CURRENCIES = { 'BTC': {'type': 'crypto'}, @@ -24,11 +24,11 @@ CURRENCIES = { } -class Setting(typing.Generic[T]): +class Setting(Generic[T]): - def __init__(self, doc: str, default: typing.Optional[T] = None, - previous_names: typing.Optional[typing.List[str]] = None, - metavar: typing.Optional[str] = None): + def __init__(self, doc: str, default: Optional[T] = None, + previous_names: Optional[List[str]] = None, + metavar: Optional[str] = None): self.doc = doc self.default = default self.previous_names = previous_names or [] @@ -45,7 +45,7 @@ class Setting(typing.Generic[T]): def no_cli_name(self): return f"--no-{self.name.replace('_', '-')}" - def __get__(self, obj: typing.Optional['BaseConfig'], owner) -> T: + def __get__(self, obj: Optional['BaseConfig'], owner) -> T: if obj is None: return self for location in obj.search_order: @@ -53,7 +53,7 @@ class Setting(typing.Generic[T]): return location[self.name] return self.default - def __set__(self, obj: 'BaseConfig', val: typing.Union[T, NOT_SET]): + def __set__(self, obj: 'BaseConfig', val: Union[T, NOT_SET]): if val == NOT_SET: for location in obj.modify_order: if self.name in location: @@ -87,7 +87,7 @@ class String(Setting[str]): f"Setting '{self.name}' must be a string." # TODO: removes this after pylint starts to understand generics - def __get__(self, obj: typing.Optional['BaseConfig'], owner) -> str: # pylint: disable=useless-super-delegation + def __get__(self, obj: Optional['BaseConfig'], owner) -> str: # pylint: disable=useless-super-delegation return super().__get__(obj, owner) @@ -200,7 +200,7 @@ class MaxKeyFee(Setting[dict]): class StringChoice(String): - def __init__(self, doc: str, valid_values: typing.List[str], default: str, *args, **kwargs): + def __init__(self, doc: str, valid_values: List[str], default: str, *args, **kwargs): super().__init__(doc, default, *args, **kwargs) if not valid_values: raise ValueError("No valid values provided") @@ -275,30 +275,54 @@ class Strings(ListSetting): class KnownHubsList: - def __init__(self, config: 'Config', file_name: str = 'known_hubs.yml'): + def __init__(self, config: 'Config' = None, file_name: str = 'known_hubs.yml'): + self.config = config self.file_name = file_name - self.path = os.path.join(config.wallet_dir, self.file_name) - self.hubs = [] - if self.exists: - self.load() + self.hubs: List[Tuple[str, int]] = [] + + @property + def path(self): + if self.config: + return os.path.join(self.config.wallet_dir, self.file_name) @property def exists(self): return self.path and os.path.exists(self.path) + @property + def serialized(self) -> List[str]: + return [f"{host}:{port}" for host, port in self.hubs] + def load(self): - with open(self.path, 'r') as known_hubs_file: - raw = known_hubs_file.read() - self.hubs = yaml.safe_load(raw) or {} + if self.exists: + with open(self.path, 'r') as known_hubs_file: + raw = known_hubs_file.read() + for hub in yaml.safe_load(raw) or []: + self.append(hub) def save(self): - with open(self.path, 'w') as known_hubs_file: - known_hubs_file.write(yaml.safe_dump(self.hubs, default_flow_style=False)) + if self.path: + with open(self.path, 'w') as known_hubs_file: + known_hubs_file.write(yaml.safe_dump(self.serialized, default_flow_style=False)) def append(self, hub: str): - self.hubs.append(hub) + if hub and ':' in hub: + host, port = hub.split(':') + hub_parts = (host, int(port)) + if hub_parts not in self.hubs: + self.hubs.append(hub_parts) return hub + def extend(self, hubs: List[str]): + for hub in hubs: + self.append(hub) + + def __bool__(self): + return len(self) > 0 + + def __len__(self): + return self.hubs.__len__() + def __iter__(self): return iter(self.hubs) @@ -407,7 +431,7 @@ class ConfigFileAccess: del self.data[key] -TBC = typing.TypeVar('TBC', bound='BaseConfig') +TBC = TypeVar('TBC', bound='BaseConfig') class BaseConfig: @@ -707,7 +731,7 @@ class Config(CLIConfig): return os.path.join(self.data_dir, 'lbrynet.log') -def get_windows_directories() -> typing.Tuple[str, str, str]: +def get_windows_directories() -> Tuple[str, str, str]: from lbry.winpaths import get_path, FOLDERID, UserHandle, \ PathNotFoundException # pylint: disable=import-outside-toplevel @@ -729,14 +753,14 @@ def get_windows_directories() -> typing.Tuple[str, str, str]: return data_dir, lbryum_dir, download_dir -def get_darwin_directories() -> typing.Tuple[str, str, str]: +def get_darwin_directories() -> Tuple[str, str, str]: data_dir = user_data_dir('LBRY') lbryum_dir = os.path.expanduser('~/.lbryum') download_dir = os.path.expanduser('~/Downloads') return data_dir, lbryum_dir, download_dir -def get_linux_directories() -> typing.Tuple[str, str, str]: +def get_linux_directories() -> Tuple[str, str, str]: try: with open(os.path.join(user_config_dir(), 'user-dirs.dirs'), 'r') as xdg: down_dir = re.search(r'XDG_DOWNLOAD_DIR=(.+)', xdg.read()) diff --git a/lbry/testcase.py b/lbry/testcase.py index 5b5d5b6e4..33d265952 100644 --- a/lbry/testcase.py +++ b/lbry/testcase.py @@ -396,7 +396,7 @@ class CommandTestCase(IntegrationTestCase): conf.use_upnp = False conf.reflect_streams = True conf.blockchain_name = 'lbrycrd_regtest' - conf.lbryum_servers = [('127.0.0.1', 50001)] + conf.lbryum_servers = [(self.conductor.spv_node.hostname, self.conductor.spv_node.port)] conf.reflector_servers = [('127.0.0.1', 5566)] conf.fixed_peers = [('127.0.0.1', 5567)] conf.known_dht_nodes = [] @@ -409,6 +409,7 @@ class CommandTestCase(IntegrationTestCase): if self.skip_libtorrent: conf.components_to_skip.append(LIBTORRENT_COMPONENT) wallet_node.manager.config = conf + wallet_node.manager.ledger.config['known_hubs'] = conf.known_hubs def wallet_maker(component_manager): wallet_component = WalletComponent(component_manager) diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py index bff251a06..44c27e37c 100644 --- a/lbry/wallet/manager.py +++ b/lbry/wallet/manager.py @@ -183,6 +183,7 @@ class WalletManager: ledger_config = { 'auto_connect': True, 'default_servers': config.lbryum_servers, + 'known_hubs': config.known_hubs, 'data_path': config.wallet_dir, 'tx_cache_size': config.transaction_cache_size } @@ -226,6 +227,7 @@ class WalletManager: self.ledger.config = { 'auto_connect': True, 'default_servers': self.config.lbryum_servers, + 'known_hubs': self.config.known_hubs, 'data_path': self.config.wallet_dir, } await self.ledger.stop() diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py index 7c4a2f28b..07b3f7405 100644 --- a/lbry/wallet/network.py +++ b/lbry/wallet/network.py @@ -14,6 +14,7 @@ from lbry.error import IncompatibleWalletServerError from lbry.wallet.rpc import RPCSession as BaseClientSession, Connector, RPCError, ProtocolError from lbry.wallet.stream import StreamController from lbry.wallet.server.udp import SPVStatusClientProtocol, SPVPong +from lbry.conf import KnownHubsList log = logging.getLogger(__name__) @@ -179,6 +180,10 @@ class Network: def config(self): return self.ledger.config + @property + def known_hubs(self): + return self.config.get('known_hubs', KnownHubsList()) + def disconnect(self): if self._keepalive_task and not self._keepalive_task.done(): self._keepalive_task.cancel() @@ -216,7 +221,8 @@ class Network: log.exception("error looking up dns for spv server %s:%i", server, port) # accumulate the dns results - await asyncio.gather(*(resolve_spv(server, port) for (server, port) in self.config['default_servers'])) + hubs = self.known_hubs if self.known_hubs else self.config['default_servers'] + await asyncio.gather(*(resolve_spv(server, port) for (server, port) in hubs)) return hostname_to_ip, ip_to_hostnames async def get_n_fastest_spvs(self, timeout=3.0) -> Dict[Tuple[str, int], Optional[SPVPong]]: @@ -295,6 +301,12 @@ class Network: self.client, self.server_features = client, features log.debug("discover other hubs %s:%i", *client.server) peers = await client.send_request('server.peers.get', []) + if peers: + try: + self.known_hubs.extend(peers) + self.known_hubs.save() + except Exception: + log.exception("could not add hub peers: %s", peers) log.info("subscribe to headers %s:%i", *client.server) self._update_remote_height((await self.subscribe_headers(),)) self._on_connected_controller.add(True) diff --git a/lbry/wallet/server/env.py b/lbry/wallet/server/env.py index e62df648c..8917803ba 100644 --- a/lbry/wallet/server/env.py +++ b/lbry/wallet/server/env.py @@ -77,6 +77,7 @@ class Env: # Peer discovery self.peer_discovery = self.peer_discovery_enum() self.peer_announce = self.boolean('PEER_ANNOUNCE', True) + self.peer_hubs = self.extract_peer_hubs() self.force_proxy = self.boolean('FORCE_PROXY', False) self.tor_proxy_host = self.default('TOR_PROXY_HOST', 'localhost') self.tor_proxy_port = self.integer('TOR_PROXY_PORT', None) @@ -270,5 +271,5 @@ class Env: else: return self.PD_ON - def peer_hubs(self): + def extract_peer_hubs(self): return [hub.strip() for hub in self.default('PEER_HUBS', '').split(',')] diff --git a/lbry/wallet/server/session.py b/lbry/wallet/server/session.py index c5ecb11f9..15da6f7d2 100644 --- a/lbry/wallet/server/session.py +++ b/lbry/wallet/server/session.py @@ -1085,7 +1085,7 @@ class LBRYElectrumX(SessionBase): async def peers_get(self): """Return the server peers as a list of (ip, host, details) tuples.""" - return self.env.peer_hubs() + return self.env.peer_hubs async def peers_subscribe(self): """Return the server peers as a list of (ip, host, details) tuples.""" diff --git a/tests/integration/blockchain/test_wallet_server_sessions.py b/tests/integration/blockchain/test_wallet_server_sessions.py index b531e8934..2b24c0040 100644 --- a/tests/integration/blockchain/test_wallet_server_sessions.py +++ b/tests/integration/blockchain/test_wallet_server_sessions.py @@ -129,16 +129,21 @@ class TestHubDiscovery(CommandTestCase): self.addCleanup(relay_node.stop) self.assertEqual(list(self.daemon.conf.known_hubs), []) - self.daemon.conf.known_hubs.append(relay_node_host) - self.assertEqual( self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', 50002) ) + # connect to relay hub which will tell us about the final hub + self.daemon.jsonrpc_settings_set('lbryum_servers', [relay_node_host]) await self.daemon.jsonrpc_wallet_reconnect() - + self.assertEqual(list(self.daemon.conf.known_hubs), [(final_node.hostname, final_node.port)]) self.assertEqual( - self.daemon.ledger.network.client.server_address_and_port, - ('127.0.0.1', 50003) + self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', relay_node.port) + ) + + # use known_hubs to connect to final hub + await self.daemon.jsonrpc_wallet_reconnect() + self.assertEqual( + self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', final_node.port) ) diff --git a/tests/unit/test_conf.py b/tests/unit/test_conf.py index 7f58ba3b7..4710219cc 100644 --- a/tests/unit/test_conf.py +++ b/tests/unit/test_conf.py @@ -260,7 +260,7 @@ class ConfigurationTests(unittest.TestCase): with tempfile.TemporaryDirectory() as temp_dir: c1 = Config(config=os.path.join(temp_dir, 'settings.yml'), wallet_dir=temp_dir) self.assertEqual(list(c1.known_hubs), []) - c1.known_hubs.append('new.hub.io') + c1.known_hubs.append('new.hub.io:99') c1.known_hubs.save() c2 = Config(config=os.path.join(temp_dir, 'settings.yml'), wallet_dir=temp_dir) - self.assertEqual(list(c2.known_hubs), ['new.hub.io']) + self.assertEqual(list(c2.known_hubs), [('new.hub.io', 99)])