added support to config for determining if value is set and implemented hub selection logic

This commit is contained in:
Lex Berezhny 2021-06-14 13:06:31 -04:00
parent 5f0426c840
commit 7d49b046d4
9 changed files with 154 additions and 39 deletions

View file

@ -63,6 +63,18 @@ class Setting(Generic[T]):
for location in obj.modify_order:
location[self.name] = val
def is_set(self, obj: 'BaseConfig') -> bool:
for location in obj.search_order:
if self.name in location:
return True
return False
def is_set_to_default(self, obj: 'BaseConfig') -> bool:
for location in obj.search_order:
if self.name in location:
return location[self.name] == self.default
return False
def validate(self, value):
raise NotImplementedError()
@ -577,6 +589,9 @@ class CLIConfig(TranscodeConfig):
class Config(CLIConfig):
jurisdiction = String("Limit interactions to wallet server in this jurisdiction.")
# directories
data_dir = Path("Directory path to store blobs.", metavar='DIR')
download_dir = Path(

View file

@ -138,7 +138,7 @@ class WalletComponent(Component):
'availability': session.available,
} for session in sessions
],
'known_servers': len(self.wallet_manager.ledger.network.config['default_servers']),
'known_servers': len(self.wallet_manager.ledger.network.known_hubs),
'available_servers': 1 if is_connected else 0
}

View file

@ -378,6 +378,7 @@ class CommandTestCase(IntegrationTestCase):
await daemon.stop()
async def add_daemon(self, wallet_node=None, seed=None):
start_wallet_node = False
if wallet_node is None:
wallet_node = WalletNode(
self.wallet_node.manager_class,
@ -385,8 +386,7 @@ class CommandTestCase(IntegrationTestCase):
port=self.extra_wallet_node_port
)
self.extra_wallet_node_port += 1
await wallet_node.start(self.conductor.spv_node, seed=seed)
self.extra_wallet_nodes.append(wallet_node)
start_wallet_node = True
upload_dir = os.path.join(wallet_node.data_path, 'uploads')
os.mkdir(upload_dir)
@ -414,8 +414,13 @@ class CommandTestCase(IntegrationTestCase):
]
if self.skip_libtorrent:
conf.components_to_skip.append(LIBTORRENT_COMPONENT)
wallet_node.manager.config = conf
wallet_node.manager.ledger.config['known_hubs'] = conf.known_hubs
if start_wallet_node:
await wallet_node.start(self.conductor.spv_node, seed=seed, config=conf)
self.extra_wallet_nodes.append(wallet_node)
else:
wallet_node.manager.config = conf
wallet_node.manager.ledger.config['known_hubs'] = conf.known_hubs
def wallet_maker(component_manager):
wallet_component = WalletComponent(component_manager)

View file

@ -8,7 +8,7 @@ from decimal import Decimal
from typing import List, Type, MutableSequence, MutableMapping, Optional
from lbry.error import KeyFeeAboveMaxAllowedError
from lbry.conf import Config
from lbry.conf import Config, NOT_SET
from .dewies import dewies_to_lbc
from .account import Account
@ -184,6 +184,7 @@ class WalletManager:
'auto_connect': True,
'default_servers': config.lbryum_servers,
'known_hubs': config.known_hubs,
'jurisdiction': config.jurisdiction,
'data_path': config.wallet_dir,
'tx_cache_size': config.transaction_cache_size
}
@ -196,6 +197,10 @@ class WalletManager:
os.path.join(wallets_directory, 'default_wallet')
)
if Config.lbryum_servers.is_set_to_default(config):
with config.update_config() as c:
c.lbryum_servers = NOT_SET
manager = cls.from_config({
'ledgers': {ledger_id: ledger_config},
'wallets': [
@ -226,10 +231,14 @@ class WalletManager:
async def reset(self):
self.ledger.config = {
'auto_connect': True,
'default_servers': self.config.lbryum_servers,
'explicit_servers': [],
'default_servers': Config.lbryum_servers.default,
'known_hubs': self.config.known_hubs,
'jurisdiction': self.config.jurisdiction,
'data_path': self.config.wallet_dir,
}
if Config.lbryum_servers.is_set(self.config):
self.ledger.config['explicit_servers'] = self.config.lbryum_servers
await self.ledger.stop()
await self.ledger.start()

View file

@ -186,7 +186,13 @@ class Network:
@property
def known_hubs(self):
return self.config.get('known_hubs', KnownHubsList())
if 'known_hubs' not in self.config:
return KnownHubsList()
return self.config['known_hubs']
@property
def jurisdiction(self):
return self.config.get("jurisdiction")
def disconnect(self):
if self._keepalive_task and not self._keepalive_task.done():
@ -226,18 +232,23 @@ class Network:
log.exception("error looking up dns for spv server %s:%i", server, port)
# accumulate the dns results
hubs = self.known_hubs if self.known_hubs else self.config['default_servers']
if self.config['explicit_servers']:
hubs = self.config['explicit_servers']
elif self.known_hubs:
hubs = self.known_hubs
else:
hubs = self.config['default_servers']
await asyncio.gather(*(resolve_spv(server, port) for (server, port) in hubs))
return hostname_to_ip, ip_to_hostnames
return hubs, hostname_to_ip, ip_to_hostnames
async def get_n_fastest_spvs(self, timeout=3.0) -> Dict[Tuple[str, int], Optional[SPVPong]]:
loop = asyncio.get_event_loop()
pong_responses = asyncio.Queue()
connection = SPVStatusClientProtocol(pong_responses)
sent_ping_timestamps = {}
_, ip_to_hostnames = await self.resolve_spv_dns()
hubs, _, ip_to_hostnames = await self.resolve_spv_dns()
n = len(ip_to_hostnames)
log.info("%i possible spv servers to try (%i urls in config)", n, len(self.config['default_servers']))
log.info("%i possible spv servers to try (%i urls in config)", n, len(self.config['explicit_servers']))
pongs = {}
try:
await loop.create_datagram_endpoint(lambda: connection, ('0.0.0.0', 0))
@ -246,6 +257,7 @@ class Network:
for server in ip_to_hostnames:
connection.ping(server)
sent_ping_timestamps[server] = perf_counter()
known_hubs = self.known_hubs
while len(pongs) < n:
(remote, ts), pong = await asyncio.wait_for(pong_responses.get(), timeout - (perf_counter() - start))
latency = ts - start
@ -253,6 +265,9 @@ class Network:
'/'.join(ip_to_hostnames[remote]), remote[1], round(latency * 1000, 2),
pong.available, pong.height)
known_hubs.hubs.setdefault((ip_to_hostnames[remote][0], remote[1]), {}).update(
{"country": pong.country_name}
)
if pong.available:
pongs[(ip_to_hostnames[remote][0], remote[1])] = pong
return pongs
@ -267,13 +282,17 @@ class Network:
random_server = random.choice(list(ip_to_hostnames.keys()))
host, port = random_server
log.warning("trying fallback to randomly selected spv: %s:%i", host, port)
known_hubs.hubs.setdefault((host, port), {})
return {(host, port): None}
finally:
connection.close()
async def connect_to_fastest(self) -> Optional[ClientSession]:
fastest_spvs = await self.get_n_fastest_spvs()
for (host, port) in fastest_spvs:
for (host, port), pong in fastest_spvs.items():
if (pong is not None and self.jurisdiction is not None) and \
(pong.country_name != self.jurisdiction):
continue
client = ClientSession(network=self, server=(host, port))
try:
await client.create_connection()
@ -305,7 +324,7 @@ class Network:
features = await client.send_request('server.features', [])
self.client, self.server_features = client, features
log.debug("discover other hubs %s:%i", *client.server)
self._update_hubs(await client.send_request('server.peers.subscribe', []))
await self._update_hubs(await client.send_request('server.peers.subscribe', []))
log.info("subscribe to headers %s:%i", *client.server)
self._update_remote_height((await self.subscribe_headers(),))
self._on_connected_controller.add(True)
@ -375,8 +394,8 @@ class Network:
def _update_remote_height(self, header_args):
self.remote_height = header_args[0]["height"]
def _update_hubs(self, hubs):
if hubs:
async def _update_hubs(self, hubs):
if hubs and hubs != ['']:
try:
if self.known_hubs.add_hubs(hubs):
self.known_hubs.save()

View file

@ -17,6 +17,7 @@ import lbry
from lbry.wallet.server.server import Server
from lbry.wallet.server.env import Env
from lbry.wallet import Wallet, Ledger, RegTestLedger, WalletManager, Account, BlockHeightEvent
from lbry.conf import KnownHubsList, Config
log = logging.getLogger(__name__)
@ -107,7 +108,8 @@ class Conductor:
class WalletNode:
def __init__(self, manager_class: Type[WalletManager], ledger_class: Type[Ledger],
verbose: bool = False, port: int = 5280, default_seed: str = None) -> None:
verbose: bool = False, port: int = 5280, default_seed: str = None,
data_path: str = None) -> None:
self.manager_class = manager_class
self.ledger_class = ledger_class
self.verbose = verbose
@ -115,12 +117,12 @@ class WalletNode:
self.ledger: Optional[Ledger] = None
self.wallet: Optional[Wallet] = None
self.account: Optional[Account] = None
self.data_path: Optional[str] = None
self.data_path: str = data_path or tempfile.mkdtemp()
self.port = port
self.default_seed = default_seed
self.known_hubs = KnownHubsList()
async def start(self, spv_node: 'SPVNode', seed=None, connect=True):
self.data_path = tempfile.mkdtemp()
async def start(self, spv_node: 'SPVNode', seed=None, connect=True, config=None):
wallets_dir = os.path.join(self.data_path, 'wallets')
os.mkdir(wallets_dir)
wallet_file_name = os.path.join(wallets_dir, 'my_wallet.json')
@ -130,12 +132,15 @@ class WalletNode:
'ledgers': {
self.ledger_class.get_id(): {
'api_port': self.port,
'default_servers': [(spv_node.hostname, spv_node.port)],
'data_path': self.data_path
'explicit_servers': [(spv_node.hostname, spv_node.port)],
'default_servers': Config.lbryum_servers.default,
'data_path': self.data_path,
'known_hubs': config.known_hubs if config else KnownHubsList()
}
},
'wallets': [wallet_file_name]
})
self.manager.config = config
self.ledger = self.manager.ledgers[self.ledger_class]
self.wallet = self.manager.default_wallet
if not self.wallet:
@ -172,6 +177,7 @@ class SPVNode:
self.udp_port = self.port
self.session_timeout = 600
self.rpc_port = '0' # disabled by default
self.stopped = False
async def start(self, blockchain_node: 'BlockchainNode', extraconf=None):
self.data_path = tempfile.mkdtemp()
@ -201,10 +207,13 @@ class SPVNode:
await self.server.start()
async def stop(self, cleanup=True):
if self.stopped:
return
try:
await self.server.db.search_index.delete_index()
await self.server.db.search_index.stop()
await self.server.stop()
self.stopped = True
finally:
cleanup and self.cleanup()

View file

@ -37,6 +37,9 @@ class SPVPing(NamedTuple):
return decoded
PONG_ENCODING = b'!BBL32s4sH'
class SPVPong(NamedTuple):
protocol_version: int
flags: int
@ -46,7 +49,7 @@ class SPVPong(NamedTuple):
country: int
def encode(self):
return struct.pack(b'!BBL32s4sH', *self)
return struct.pack(PONG_ENCODING, *self)
@staticmethod
def encode_address(address: str):
@ -67,7 +70,7 @@ class SPVPong(NamedTuple):
@classmethod
def decode(cls, packet: bytes):
return cls(*struct.unpack(b'!BBl32s4s', packet[:42]))
return cls(*struct.unpack(PONG_ENCODING, packet[:44]))
@property
def available(self) -> bool:

View file

@ -118,13 +118,21 @@ class TestESSync(CommandTestCase):
class TestHubDiscovery(CommandTestCase):
async def test_hub_discovery(self):
final_node = SPVNode(self.conductor.spv_module, node_number=2)
await final_node.start(self.blockchain)
self.addCleanup(final_node.stop)
final_node_host = f"{final_node.hostname}:{final_node.port}"
us_final_node = SPVNode(self.conductor.spv_module, node_number=2)
await us_final_node.start(self.blockchain, extraconf={"COUNTRY": "US"})
self.addCleanup(us_final_node.stop)
final_node_host = f"{us_final_node.hostname}:{us_final_node.port}"
relay_node = SPVNode(self.conductor.spv_module, node_number=3)
await relay_node.start(self.blockchain, extraconf={"PEER_HUBS": final_node_host})
kp_final_node = SPVNode(self.conductor.spv_module, node_number=3)
await kp_final_node.start(self.blockchain, extraconf={"COUNTRY": "KP"})
self.addCleanup(kp_final_node.stop)
kp_final_node_host = f"{kp_final_node.hostname}:{kp_final_node.port}"
relay_node = SPVNode(self.conductor.spv_module, node_number=4)
await relay_node.start(self.blockchain, extraconf={
"COUNTRY": "FR",
"PEER_HUBS": ",".join([kp_final_node_host, final_node_host])
})
relay_node_host = f"{relay_node.hostname}:{relay_node.port}"
self.addCleanup(relay_node.stop)
@ -134,23 +142,50 @@ class TestHubDiscovery(CommandTestCase):
('127.0.0.1', 50002)
)
# connect to relay hub which will tell us about the final hub
# connect to relay hub which will tell us about the final hubs
self.daemon.jsonrpc_settings_set('lbryum_servers', [relay_node_host])
await self.daemon.jsonrpc_wallet_reconnect()
self.assertEqual(list(self.daemon.conf.known_hubs), [(final_node.hostname, final_node.port)])
self.assertEqual(
self.daemon.conf.known_hubs.filter(), {
(relay_node.hostname, relay_node.port): {"country": "FR"},
(us_final_node.hostname, us_final_node.port): {}, # discovered from relay but not contacted yet
(kp_final_node.hostname, kp_final_node.port): {}, # discovered from relay but not contacted yet
}
)
self.assertEqual(
self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', relay_node.port)
)
# use known_hubs to connect to final hub
# use known_hubs to connect to final US hub
self.daemon.jsonrpc_settings_clear('lbryum_servers')
self.daemon.conf.jurisdiction = "US"
await self.daemon.jsonrpc_wallet_reconnect()
self.assertEqual(
self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', final_node.port)
self.daemon.conf.known_hubs.filter(), {
(relay_node.hostname, relay_node.port): {"country": "FR"},
(us_final_node.hostname, us_final_node.port): {"country": "US"},
(kp_final_node.hostname, kp_final_node.port): {"country": "KP"},
}
)
self.assertEqual(
self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', us_final_node.port)
)
final_node.server.session_mgr._notify_peer('127.0.0.1:9988')
self.assertEqual(list(self.daemon.conf.known_hubs), [(final_node.hostname, final_node.port)])
# connection to KP jurisdiction
self.daemon.conf.jurisdiction = "KP"
await self.daemon.jsonrpc_wallet_reconnect()
self.assertEqual(
self.daemon.ledger.network.client.server_address_and_port, ('127.0.0.1', kp_final_node.port)
)
kp_final_node.server.session_mgr._notify_peer('127.0.0.1:9988')
await self.daemon.ledger.network.on_hub.first
self.assertEqual(list(self.daemon.conf.known_hubs), [
(final_node.hostname, final_node.port), ('127.0.0.1', 9988)
])
await asyncio.sleep(0.5) # wait for above event to be processed by other listeners
self.assertEqual(
self.daemon.conf.known_hubs.filter(), {
(relay_node.hostname, relay_node.port): {"country": "FR"},
(us_final_node.hostname, us_final_node.port): {"country": "US"},
(kp_final_node.hostname, kp_final_node.port): {"country": "KP"},
('127.0.0.1', 9988): {}
}
)

View file

@ -47,6 +47,26 @@ class ConfigurationTests(unittest.TestCase):
c.persisted = {}
self.assertEqual(c.test_str, 'the default')
def test_is_set(self):
c = TestConfig()
self.assertEqual(c.test_str, 'the default')
self.assertFalse(TestConfig.test_str.is_set(c))
c.test_str = 'new value'
self.assertEqual(c.test_str, 'new value')
self.assertTrue(TestConfig.test_str.is_set(c))
def test_is_set_to_default(self):
c = TestConfig()
self.assertEqual(TestConfig.test_str.default, 'the default')
self.assertFalse(TestConfig.test_str.is_set(c))
self.assertFalse(TestConfig.test_str.is_set_to_default(c))
c.test_str = 'new value'
self.assertTrue(TestConfig.test_str.is_set(c))
self.assertFalse(TestConfig.test_str.is_set_to_default(c))
c.test_str = 'the default'
self.assertTrue(TestConfig.test_str.is_set(c))
self.assertTrue(TestConfig.test_str.is_set_to_default(c))
def test_arguments(self):
parser = argparse.ArgumentParser()
TestConfig.contribute_to_argparse(parser)