added support to config for determining if value is set and implemented hub selection logic

This commit is contained in:
Lex Berezhny 2021-06-14 13:06:31 -04:00
parent 5f0426c840
commit 7d49b046d4
9 changed files with 154 additions and 39 deletions

View file

@ -63,6 +63,18 @@ class Setting(Generic[T]):
for location in obj.modify_order: for location in obj.modify_order:
location[self.name] = val location[self.name] = val
def is_set(self, obj: 'BaseConfig') -> bool:
for location in obj.search_order:
if self.name in location:
return True
return False
def is_set_to_default(self, obj: 'BaseConfig') -> bool:
for location in obj.search_order:
if self.name in location:
return location[self.name] == self.default
return False
def validate(self, value): def validate(self, value):
raise NotImplementedError() raise NotImplementedError()
@ -577,6 +589,9 @@ class CLIConfig(TranscodeConfig):
class Config(CLIConfig): class Config(CLIConfig):
jurisdiction = String("Limit interactions to wallet server in this jurisdiction.")
# directories # directories
data_dir = Path("Directory path to store blobs.", metavar='DIR') data_dir = Path("Directory path to store blobs.", metavar='DIR')
download_dir = Path( download_dir = Path(

View file

@ -138,7 +138,7 @@ class WalletComponent(Component):
'availability': session.available, 'availability': session.available,
} for session in sessions } for session in sessions
], ],
'known_servers': len(self.wallet_manager.ledger.network.config['default_servers']), 'known_servers': len(self.wallet_manager.ledger.network.known_hubs),
'available_servers': 1 if is_connected else 0 'available_servers': 1 if is_connected else 0
} }

View file

@ -378,6 +378,7 @@ class CommandTestCase(IntegrationTestCase):
await daemon.stop() await daemon.stop()
async def add_daemon(self, wallet_node=None, seed=None): async def add_daemon(self, wallet_node=None, seed=None):
start_wallet_node = False
if wallet_node is None: if wallet_node is None:
wallet_node = WalletNode( wallet_node = WalletNode(
self.wallet_node.manager_class, self.wallet_node.manager_class,
@ -385,8 +386,7 @@ class CommandTestCase(IntegrationTestCase):
port=self.extra_wallet_node_port port=self.extra_wallet_node_port
) )
self.extra_wallet_node_port += 1 self.extra_wallet_node_port += 1
await wallet_node.start(self.conductor.spv_node, seed=seed) start_wallet_node = True
self.extra_wallet_nodes.append(wallet_node)
upload_dir = os.path.join(wallet_node.data_path, 'uploads') upload_dir = os.path.join(wallet_node.data_path, 'uploads')
os.mkdir(upload_dir) os.mkdir(upload_dir)
@ -414,8 +414,13 @@ class CommandTestCase(IntegrationTestCase):
] ]
if self.skip_libtorrent: if self.skip_libtorrent:
conf.components_to_skip.append(LIBTORRENT_COMPONENT) conf.components_to_skip.append(LIBTORRENT_COMPONENT)
wallet_node.manager.config = conf
wallet_node.manager.ledger.config['known_hubs'] = conf.known_hubs if start_wallet_node:
await wallet_node.start(self.conductor.spv_node, seed=seed, config=conf)
self.extra_wallet_nodes.append(wallet_node)
else:
wallet_node.manager.config = conf
wallet_node.manager.ledger.config['known_hubs'] = conf.known_hubs
def wallet_maker(component_manager): def wallet_maker(component_manager):
wallet_component = WalletComponent(component_manager) wallet_component = WalletComponent(component_manager)

View file

@ -8,7 +8,7 @@ from decimal import Decimal
from typing import List, Type, MutableSequence, MutableMapping, Optional from typing import List, Type, MutableSequence, MutableMapping, Optional
from lbry.error import KeyFeeAboveMaxAllowedError from lbry.error import KeyFeeAboveMaxAllowedError
from lbry.conf import Config from lbry.conf import Config, NOT_SET
from .dewies import dewies_to_lbc from .dewies import dewies_to_lbc
from .account import Account from .account import Account
@ -184,6 +184,7 @@ class WalletManager:
'auto_connect': True, 'auto_connect': True,
'default_servers': config.lbryum_servers, 'default_servers': config.lbryum_servers,
'known_hubs': config.known_hubs, 'known_hubs': config.known_hubs,
'jurisdiction': config.jurisdiction,
'data_path': config.wallet_dir, 'data_path': config.wallet_dir,
'tx_cache_size': config.transaction_cache_size 'tx_cache_size': config.transaction_cache_size
} }
@ -196,6 +197,10 @@ class WalletManager:
os.path.join(wallets_directory, 'default_wallet') os.path.join(wallets_directory, 'default_wallet')
) )
if Config.lbryum_servers.is_set_to_default(config):
with config.update_config() as c:
c.lbryum_servers = NOT_SET
manager = cls.from_config({ manager = cls.from_config({
'ledgers': {ledger_id: ledger_config}, 'ledgers': {ledger_id: ledger_config},
'wallets': [ 'wallets': [
@ -226,10 +231,14 @@ class WalletManager:
async def reset(self): async def reset(self):
self.ledger.config = { self.ledger.config = {
'auto_connect': True, 'auto_connect': True,
'default_servers': self.config.lbryum_servers, 'explicit_servers': [],
'default_servers': Config.lbryum_servers.default,
'known_hubs': self.config.known_hubs, 'known_hubs': self.config.known_hubs,
'jurisdiction': self.config.jurisdiction,
'data_path': self.config.wallet_dir, 'data_path': self.config.wallet_dir,
} }
if Config.lbryum_servers.is_set(self.config):
self.ledger.config['explicit_servers'] = self.config.lbryum_servers
await self.ledger.stop() await self.ledger.stop()
await self.ledger.start() await self.ledger.start()

View file

@ -186,7 +186,13 @@ class Network:
@property @property
def known_hubs(self): def known_hubs(self):
return self.config.get('known_hubs', KnownHubsList()) if 'known_hubs' not in self.config:
return KnownHubsList()
return self.config['known_hubs']
@property
def jurisdiction(self):
return self.config.get("jurisdiction")
def disconnect(self): def disconnect(self):
if self._keepalive_task and not self._keepalive_task.done(): if self._keepalive_task and not self._keepalive_task.done():
@ -226,18 +232,23 @@ class Network:
log.exception("error looking up dns for spv server %s:%i", server, port) log.exception("error looking up dns for spv server %s:%i", server, port)
# accumulate the dns results # accumulate the dns results
hubs = self.known_hubs if self.known_hubs else self.config['default_servers'] if self.config['explicit_servers']:
hubs = self.config['explicit_servers']
elif self.known_hubs:
hubs = self.known_hubs
else:
hubs = self.config['default_servers']
await asyncio.gather(*(resolve_spv(server, port) for (server, port) in hubs)) await asyncio.gather(*(resolve_spv(server, port) for (server, port) in hubs))
return hostname_to_ip, ip_to_hostnames return hubs, hostname_to_ip, ip_to_hostnames
async def get_n_fastest_spvs(self, timeout=3.0) -> Dict[Tuple[str, int], Optional[SPVPong]]: async def get_n_fastest_spvs(self, timeout=3.0) -> Dict[Tuple[str, int], Optional[SPVPong]]:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
pong_responses = asyncio.Queue() pong_responses = asyncio.Queue()
connection = SPVStatusClientProtocol(pong_responses) connection = SPVStatusClientProtocol(pong_responses)
sent_ping_timestamps = {} sent_ping_timestamps = {}
_, ip_to_hostnames = await self.resolve_spv_dns() hubs, _, ip_to_hostnames = await self.resolve_spv_dns()
n = len(ip_to_hostnames) n = len(ip_to_hostnames)
log.info("%i possible spv servers to try (%i urls in config)", n, len(self.config['default_servers'])) log.info("%i possible spv servers to try (%i urls in config)", n, len(self.config['explicit_servers']))
pongs = {} pongs = {}
try: try:
await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0)) await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
@ -246,6 +257,7 @@ class Network:
for server in ip_to_hostnames: for server in ip_to_hostnames:
connection.ping(server) connection.ping(server)
sent_ping_timestamps[server] = perf_counter() sent_ping_timestamps[server] = perf_counter()
known_hubs = self.known_hubs
while len(pongs) < n: while len(pongs) < n:
(remote, ts), pong = await asyncio.wait_for(pong_responses.get(), timeout - (perf_counter() - start)) (remote, ts), pong = await asyncio.wait_for(pong_responses.get(), timeout - (perf_counter() - start))
latency = ts - start latency = ts - start
@ -253,6 +265,9 @@ class Network:
'/'.join(ip_to_hostnames[remote]), remote[1], round(latency * 1000, 2), '/'.join(ip_to_hostnames[remote]), remote[1], round(latency * 1000, 2),
pong.available, pong.height) pong.available, pong.height)
known_hubs.hubs.setdefault((ip_to_hostnames[remote][0], remote[1]), {}).update(
{"country": pong.country_name}
)
if pong.available: if pong.available:
pongs[(ip_to_hostnames[remote][0], remote[1])] = pong pongs[(ip_to_hostnames[remote][0], remote[1])] = pong
return pongs return pongs
@ -267,13 +282,17 @@ class Network:
random_server = random.choice(list(ip_to_hostnames.keys())) random_server = random.choice(list(ip_to_hostnames.keys()))
host, port = random_server host, port = random_server
log.warning("trying fallback to randomly selected spv: %s:%i", host, port) log.warning("trying fallback to randomly selected spv: %s:%i", host, port)
known_hubs.hubs.setdefault((host, port), {})
return {(host, port): None} return {(host, port): None}
finally: finally:
connection.close() connection.close()
async def connect_to_fastest(self) -> Optional[ClientSession]: async def connect_to_fastest(self) -> Optional[ClientSession]:
fastest_spvs = await self.get_n_fastest_spvs() fastest_spvs = await self.get_n_fastest_spvs()
for (host, port) in fastest_spvs: for (host, port), pong in fastest_spvs.items():
if (pong is not None and self.jurisdiction is not None) and \
(pong.country_name != self.jurisdiction):
continue
client = ClientSession(network=self, server=(host, port)) client = ClientSession(network=self, server=(host, port))
try: try:
await client.create_connection() await client.create_connection()
@ -305,7 +324,7 @@ class Network:
features = await client.send_request('server.features', []) features = await client.send_request('server.features', [])
self.client, self.server_features = client, features self.client, self.server_features = client, features
log.debug("discover other hubs %s:%i", *client.server) log.debug("discover other hubs %s:%i", *client.server)
self._update_hubs(await client.send_request('server.peers.subscribe', [])) await self._update_hubs(await client.send_request('server.peers.subscribe', []))
log.info("subscribe to headers %s:%i", *client.server) log.info("subscribe to headers %s:%i", *client.server)
self._update_remote_height((await self.subscribe_headers(),)) self._update_remote_height((await self.subscribe_headers(),))
self._on_connected_controller.add(True) self._on_connected_controller.add(True)
@ -375,8 +394,8 @@ class Network:
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"] self.remote_height = header_args[0]["height"]
def _update_hubs(self, hubs): async def _update_hubs(self, hubs):
if hubs: if hubs and hubs != ['']:
try: try:
if self.known_hubs.add_hubs(hubs): if self.known_hubs.add_hubs(hubs):
self.known_hubs.save() self.known_hubs.save()

View file

@ -17,6 +17,7 @@ import lbry
from lbry.wallet.server.server import Server from lbry.wallet.server.server import Server
from lbry.wallet.server.env import Env from lbry.wallet.server.env import Env
from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent
from lbry.conf import KnownHubsList, Config
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -107,7 +108,8 @@ class Conductor:
class WalletNode: class WalletNode:
def __init__(self, manager_class: Type[WalletManager], ledger_class: Type[Ledger], def __init__(self, manager_class: Type[WalletManager], ledger_class: Type[Ledger],
verbose: bool = False, port: int = 5280, default_seed: str = None) -> None: verbose: bool = False, port: int = 5280, default_seed: str = None,
data_path: str = None) -> None:
self.manager_class = manager_class self.manager_class = manager_class
self.ledger_class = ledger_class self.ledger_class = ledger_class
self.verbose = verbose self.verbose = verbose
@ -115,12 +117,12 @@ class WalletNode:
self.ledger: Optional[Ledger] = None self.ledger: Optional[Ledger] = None
self.wallet: Optional[Wallet] = None self.wallet: Optional[Wallet] = None
self.account: Optional[Account] = None self.account: Optional[Account] = None
self.data_path: Optional[str] = None self.data_path: str = data_path or tempfile.mkdtemp()
self.port = port self.port = port
self.default_seed = default_seed self.default_seed = default_seed
self.known_hubs = KnownHubsList()
async def start(self, spv_node: 'SPVNode', seed=None, connect=True): async def start(self, spv_node: 'SPVNode', seed=None, connect=True, config=None):
self.data_path = tempfile.mkdtemp()
wallets_dir = os.path.join(self.data_path, 'wallets') wallets_dir = os.path.join(self.data_path, 'wallets')
os.mkdir(wallets_dir) os.mkdir(wallets_dir)
wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json') wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json')
@ -130,12 +132,15 @@ class WalletNode:
'ledgers': { 'ledgers': {
self.ledger_class.get_id(): { self.ledger_class.get_id(): {
'api_port': self.port, 'api_port': self.port,
'default_servers': [(spv_node.hostname, spv_node.port)], 'explicit_servers': [(spv_node.hostname, spv_node.port)],
'data_path': self.data_path 'default_servers': Config.lbryum_servers.default,
'data_path': self.data_path,
'known_hubs': config.known_hubs if config else KnownHubsList()
} }
}, },
'wallets': [wallet_file_name] 'wallets': [wallet_file_name]
}) })
self.manager.config = config
self.ledger = self.manager.ledgers[self.ledger_class] self.ledger = self.manager.ledgers[self.ledger_class]
self.wallet = self.manager.default_wallet self.wallet = self.manager.default_wallet
if not self.wallet: if not self.wallet:
@ -172,6 +177,7 @@ class SPVNode:
self.udp_port = self.port self.udp_port = self.port
self.session_timeout = 600 self.session_timeout = 600
self.rpc_port = '0' # disabled by default self.rpc_port = '0' # disabled by default
self.stopped = False
async def start(self, blockchain_node: 'BlockchainNode', extraconf=None): async def start(self, blockchain_node: 'BlockchainNode', extraconf=None):
self.data_path = tempfile.mkdtemp() self.data_path = tempfile.mkdtemp()
@ -201,10 +207,13 @@ class SPVNode:
await self.server.start() await self.server.start()
async def stop(self, cleanup=True): async def stop(self, cleanup=True):
if self.stopped:
return
try: try:
await self.server.db.search_index.delete_index() await self.server.db.search_index.delete_index()
await self.server.db.search_index.stop() await self.server.db.search_index.stop()
await self.server.stop() await self.server.stop()
self.stopped = True
finally: finally:
cleanup and self.cleanup() cleanup and self.cleanup()

View file

@ -37,6 +37,9 @@ class SPVPing(NamedTuple):
return decoded return decoded
PONG_ENCODING = b'!BBL32s4sH'
class SPVPong(NamedTuple): class SPVPong(NamedTuple):
protocol_version: int protocol_version: int
flags: int flags: int
@ -46,7 +49,7 @@ class SPVPong(NamedTuple):
country: int country: int
def encode(self): def encode(self):
return struct.pack(b'!BBL32s4sH', *self) return struct.pack(PONG_ENCODING, *self)
@staticmethod @staticmethod
def encode_address(address: str): def encode_address(address: str):
@ -67,7 +70,7 @@ class SPVPong(NamedTuple):
@classmethod @classmethod
def decode(cls, packet: bytes): def decode(cls, packet: bytes):
return cls(*struct.unpack(b'!BBl32s4s', packet[:42])) return cls(*struct.unpack(PONG_ENCODING, packet[:44]))
@property @property
def available(self) -> bool: def available(self) -> bool:

View file

@ -118,13 +118,21 @@ class TestESSync(CommandTestCase):
class TestHubDiscovery(CommandTestCase): class TestHubDiscovery(CommandTestCase):
async def test_hub_discovery(self): async def test_hub_discovery(self):
final_node = SPVNode(self.conductor.spv_module, node_number=2) us_final_node = SPVNode(self.conductor.spv_module, node_number=2)
await final_node.start(self.blockchain) await us_final_node.start(self.blockchain, extraconf={"COUNTRY": "US"})
self.addCleanup(final_node.stop) self.addCleanup(us_final_node.stop)
final_node_host = f"{final_node.hostname}:{final_node.port}" final_node_host = f"{us_final_node.hostname}:{us_final_node.port}"
relay_node = SPVNode(self.conductor.spv_module, node_number=3) kp_final_node = SPVNode(self.conductor.spv_module, node_number=3)
await relay_node.start(self.blockchain, extraconf={"PEER_HUBS": final_node_host}) await kp_final_node.start(self.blockchain, extraconf={"COUNTRY": "KP"})
self.addCleanup(kp_final_node.stop)
kp_final_node_host = f"{kp_final_node.hostname}:{kp_final_node.port}"
relay_node = SPVNode(self.conductor.spv_module, node_number=4)
await relay_node.start(self.blockchain, extraconf={
"COUNTRY": "FR",
"PEER_HUBS": ",".join([kp_final_node_host, final_node_host])
})
relay_node_host = f"{relay_node.hostname}:{relay_node.port}" relay_node_host = f"{relay_node.hostname}:{relay_node.port}"
self.addCleanup(relay_node.stop) self.addCleanup(relay_node.stop)
@ -134,23 +142,50 @@ class TestHubDiscovery(CommandTestCase):
('127.0.0.1', 50002) ('127.0.0.1', 50002)
) )
# connect to relay hub which will tell us about the final hub # connect to relay hub which will tell us about the final hubs
self.daemon.jsonrpc_settings_set('lbryum_servers', [relay_node_host]) self.daemon.jsonrpc_settings_set('lbryum_servers', [relay_node_host])
await self.daemon.jsonrpc_wallet_reconnect() await self.daemon.jsonrpc_wallet_reconnect()
self.assertEqual(list(self.daemon.conf.known_hubs), [(final_node.hostname, final_node.port)]) self.assertEqual(
self.daemon.conf.known_hubs.filter(), {
(relay_node.hostname, relay_node.port): {"country": "FR"},
(us_final_node.hostname, us_final_node.port): {}, # discovered from relay but not contacted yet
(kp_final_node.hostname, kp_final_node.port): {}, # discovered from relay but not contacted yet
}
)
self.assertEqual( self.assertEqual(
self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', relay_node.port) self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', relay_node.port)
) )
# use known_hubs to connect to final hub # use known_hubs to connect to final US hub
self.daemon.jsonrpc_settings_clear('lbryum_servers')
self.daemon.conf.jurisdiction = "US"
await self.daemon.jsonrpc_wallet_reconnect() await self.daemon.jsonrpc_wallet_reconnect()
self.assertEqual( self.assertEqual(
self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', final_node.port) self.daemon.conf.known_hubs.filter(), {
(relay_node.hostname, relay_node.port): {"country": "FR"},
(us_final_node.hostname, us_final_node.port): {"country": "US"},
(kp_final_node.hostname, kp_final_node.port): {"country": "KP"},
}
)
self.assertEqual(
self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', us_final_node.port)
) )
final_node.server.session_mgr._notify_peer('127.0.0.1:9988') # connection to KP jurisdiction
self.assertEqual(list(self.daemon.conf.known_hubs), [(final_node.hostname, final_node.port)]) self.daemon.conf.jurisdiction = "KP"
await self.daemon.jsonrpc_wallet_reconnect()
self.assertEqual(
self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', kp_final_node.port)
)
kp_final_node.server.session_mgr._notify_peer('127.0.0.1:9988')
await self.daemon.ledger.network.on_hub.first await self.daemon.ledger.network.on_hub.first
self.assertEqual(list(self.daemon.conf.known_hubs), [ await asyncio.sleep(0.5) # wait for above event to be processed by other listeners
(final_node.hostname, final_node.port), ('127.0.0.1', 9988) self.assertEqual(
]) self.daemon.conf.known_hubs.filter(), {
(relay_node.hostname, relay_node.port): {"country": "FR"},
(us_final_node.hostname, us_final_node.port): {"country": "US"},
(kp_final_node.hostname, kp_final_node.port): {"country": "KP"},
('127.0.0.1', 9988): {}
}
)

View file

@ -47,6 +47,26 @@ class ConfigurationTests(unittest.TestCase):
c.persisted = {} c.persisted = {}
self.assertEqual(c.test_str, 'the default') self.assertEqual(c.test_str, 'the default')
def test_is_set(self):
c = TestConfig()
self.assertEqual(c.test_str, 'the default')
self.assertFalse(TestConfig.test_str.is_set(c))
c.test_str = 'new value'
self.assertEqual(c.test_str, 'new value')
self.assertTrue(TestConfig.test_str.is_set(c))
def test_is_set_to_default(self):
c = TestConfig()
self.assertEqual(TestConfig.test_str.default, 'the default')
self.assertFalse(TestConfig.test_str.is_set(c))
self.assertFalse(TestConfig.test_str.is_set_to_default(c))
c.test_str = 'new value'
self.assertTrue(TestConfig.test_str.is_set(c))
self.assertFalse(TestConfig.test_str.is_set_to_default(c))
c.test_str = 'the default'
self.assertTrue(TestConfig.test_str.is_set(c))
self.assertTrue(TestConfig.test_str.is_set_to_default(c))
def test_arguments(self): def test_arguments(self):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
TestConfig.contribute_to_argparse(parser) TestConfig.contribute_to_argparse(parser)