diff --git a/lbry/lbry/dht/node.py b/lbry/lbry/dht/node.py
index 2159cb86a..9f439e25d 100644
--- a/lbry/lbry/dht/node.py
+++ b/lbry/lbry/dht/node.py
@@ -2,6 +2,7 @@ import logging
 import asyncio
 import typing
 import binascii
+import socket
 from lbry.utils import resolve_host
 from lbry.dht import constants
 from lbry.dht.peer import make_kademlia_peer
@@ -19,7 +20,8 @@ log = logging.getLogger(__name__)
 class Node:
     def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', node_id: bytes, udp_port: int,
                  internal_udp_port: int, peer_port: int, external_ip: str, rpc_timeout: float = constants.rpc_timeout,
-                 split_buckets_under_index: int = constants.split_buckets_under_index):
+                 split_buckets_under_index: int = constants.split_buckets_under_index,
+                 storage: typing.Optional['SQLiteStorage'] = None):
         self.loop = loop
         self.internal_udp_port = internal_udp_port
         self.protocol = KademliaProtocol(loop, peer_manager, node_id, external_ip, udp_port, peer_port, rpc_timeout,
@@ -28,6 +30,7 @@ class Node:
         self.joined = asyncio.Event(loop=self.loop)
         self._join_task: asyncio.Task = None
         self._refresh_task: asyncio.Task = None
+        self._storage = storage

     async def refresh_node(self, force_once=False):
         while True:
@@ -67,6 +70,8 @@ class Node:
             to_ping = [peer for peer in set(total_peers) if self.protocol.peer_manager.peer_is_good(peer) is not True]
             if to_ping:
                 self.protocol.ping_queue.enqueue_maybe_ping(*to_ping, delay=0)
+            if self._storage:
+                await self._storage.save_kademlia_peers(self.protocol.routing_table.get_peers())
             if force_once:
                 break

@@ -111,7 +116,7 @@ class Node:
             self.listening_port = None
         log.info("Stopped DHT node")

-    async def start_listening(self, interface: str = '') -> None:
+    async def start_listening(self, interface: str = '0.0.0.0') -> None:
         if not self.listening_port:
             self.listening_port, _ = await self.loop.create_datagram_endpoint(
                 lambda: self.protocol, (interface, self.internal_udp_port)
@@ -121,56 +126,53 @@ class Node:
         else:
             log.warning("Already bound to port %s", self.listening_port)

-    async def join_network(self, interface: typing.Optional[str] = '',
+    async def join_network(self, interface: str = '0.0.0.0',
                            known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
+        def peers_from_urls(urls: typing.Optional[typing.List[typing.Tuple[bytes, str, int, int]]]):
+            peer_addresses = []
+            for node_id, address, udp_port, tcp_port in urls:
+                if (node_id, address, udp_port, tcp_port) not in peer_addresses and \
+                        (address, udp_port) != (self.protocol.external_ip, self.protocol.udp_port):
+                    peer_addresses.append((node_id, address, udp_port, tcp_port))
+            return [make_kademlia_peer(*peer_address) for peer_address in peer_addresses]
+
         if not self.listening_port:
             await self.start_listening(interface)
         self.protocol.ping_queue.start()
         self._refresh_task = self.loop.create_task(self.refresh_node())

-        # resolve the known node urls
-        known_node_addresses = []
-        url_to_addr = {}
-
-        if known_node_urls:
-            for host, port in known_node_urls:
-                address = await resolve_host(host, port, proto='udp')
-                if (address, port) not in known_node_addresses and\
-                        (address, port) != (self.protocol.external_ip, self.protocol.udp_port):
-                    known_node_addresses.append((address, port))
-                    url_to_addr[address] = host
-
-        if known_node_addresses:
-            peers = [
-                make_kademlia_peer(None, address, port)
-                for (address, port) in known_node_addresses
-            ]
-            while True:
-                if not self.protocol.routing_table.get_peers():
-                    if self.joined.is_set():
-                        self.joined.clear()
-                    self.protocol.peer_manager.reset()
-                self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
-                peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
-                if self.protocol.routing_table.get_peers():
-                    self.joined.set()
-                    log.info(
-                        "Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
-                        self.protocol.routing_table.buckets_with_contacts())
-                else:
+        while True:
+            if self.protocol.routing_table.get_peers():
+                if not self.joined.is_set():
+                    self.joined.set()
+                    log.info(
+                        "joined dht, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
+                        self.protocol.routing_table.buckets_with_contacts()
+                    )
+            else:
+                if self.joined.is_set():
+                    self.joined.clear()
+                seed_peers = peers_from_urls(
+                    await self._storage.get_persisted_kademlia_peers()
+                ) if self._storage else []
+                if not seed_peers:
+                    try:
+                        seed_peers.extend(peers_from_urls([
+                            (None, await resolve_host(address, udp_port, 'udp'), udp_port, None)
+                            for address, udp_port in known_node_urls or []
+                        ]))
+                    except socket.gaierror:
+                        await asyncio.sleep(30, loop=self.loop)
                         continue
-                await asyncio.sleep(1, loop=self.loop)
-
-        log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
-                 self.protocol.routing_table.buckets_with_contacts())
-        self.joined.set()
+
+                self.protocol.peer_manager.reset()
+                self.protocol.ping_queue.enqueue_maybe_ping(*seed_peers, delay=0.0)
+                await self.peer_search(self.protocol.node_id, shortlist=seed_peers, count=32)

-    def start(self, interface: str, known_node_urls: typing.List[typing.Tuple[str, int]]):
-        self._join_task = self.loop.create_task(
-            self.join_network(
-                interface=interface, known_node_urls=known_node_urls
-            )
-        )
+            await asyncio.sleep(1, loop=self.loop)
+
+    def start(self, interface: str, known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
+        self._join_task = self.loop.create_task(self.join_network(interface, known_node_urls))

     def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List['KademliaPeer']] = None,
                                   bottom_out_limit: int = constants.bottom_out_limit,
diff --git a/lbry/lbry/extras/daemon/Components.py b/lbry/lbry/extras/daemon/Components.py
index 78a793381..66d9acd9b 100644
--- a/lbry/lbry/extras/daemon/Components.py
+++ b/lbry/lbry/extras/daemon/Components.py
@@ -50,7 +50,7 @@ class DatabaseComponent(Component):

     @staticmethod
     def get_current_db_revision():
-        return 13
+        return 14

     @property
     def revision_filename(self):
@@ -189,7 +189,7 @@ class BlobComponent(Component):

 class DHTComponent(Component):
     component_name = DHT_COMPONENT
-    depends_on = [UPNP_COMPONENT]
+    depends_on = [UPNP_COMPONENT, DATABASE_COMPONENT]

     def __init__(self, component_manager):
         super().__init__(component_manager)
@@ -223,6 +223,7 @@ class DHTComponent(Component):
         self.external_peer_port = upnp_component.upnp_redirects.get("TCP", self.conf.tcp_port)
         self.external_udp_port = upnp_component.upnp_redirects.get("UDP", self.conf.udp_port)
         external_ip = upnp_component.external_ip
+        storage = self.component_manager.get_component(DATABASE_COMPONENT)
         if not external_ip:
             external_ip = await utils.get_external_ip()
         if not external_ip:
@@ -237,11 +238,10 @@ class DHTComponent(Component):
             external_ip=external_ip,
             peer_port=self.external_peer_port,
             rpc_timeout=self.conf.node_rpc_timeout,
-            split_buckets_under_index=self.conf.split_buckets_under_index
-        )
-        self.dht_node.start(
-            interface=self.conf.network_interface, known_node_urls=self.conf.known_dht_nodes
+            split_buckets_under_index=self.conf.split_buckets_under_index,
+            storage=storage
         )
+        self.dht_node.start(self.conf.network_interface, self.conf.known_dht_nodes)
         log.info("Started the dht")

     async def stop(self):
diff --git a/lbry/lbry/extras/daemon/migrator/dbmigrator.py b/lbry/lbry/extras/daemon/migrator/dbmigrator.py
index 99a1bb2b4..726cc1974 100644
--- a/lbry/lbry/extras/daemon/migrator/dbmigrator.py
+++ b/lbry/lbry/extras/daemon/migrator/dbmigrator.py
@@ -33,6 +33,8 @@ def migrate_db(conf, start, end):
         from .migrate11to12 import do_migration
     elif current == 12:
         from .migrate12to13 import do_migration
+    elif current == 13:
+        from .migrate13to14 import do_migration
     else:
         raise Exception(f"DB migration of version {current} to {current+1} is not available")
     try:
diff --git a/lbry/lbry/extras/daemon/migrator/migrate13to14.py b/lbry/lbry/extras/daemon/migrator/migrate13to14.py
new file mode 100644
index 000000000..5cbd6d3fa
--- /dev/null
+++ b/lbry/lbry/extras/daemon/migrator/migrate13to14.py
@@ -0,0 +1,21 @@
+import os
+import sqlite3
+
+
+def do_migration(conf):
+    db_path = os.path.join(conf.data_dir, "lbrynet.sqlite")
+    connection = sqlite3.connect(db_path)
+    cursor = connection.cursor()
+
+    cursor.executescript("""
+        create table if not exists peer (
+            node_id char(96) not null primary key,
+            address text not null,
+            udp_port integer not null,
+            tcp_port integer,
+            unique (address, udp_port)
+        );
+    """)
+
+    connection.commit()
+    connection.close()
diff --git a/lbry/lbry/extras/daemon/storage.py b/lbry/lbry/extras/daemon/storage.py
index b5859d6c1..0c8eeb02f 100644
--- a/lbry/lbry/extras/daemon/storage.py
+++ b/lbry/lbry/extras/daemon/storage.py
@@ -329,6 +329,14 @@ class SQLiteStorage(SQLiteMixin):
             timestamp integer,
             primary key (sd_hash, reflector_address)
         );
+
+        create table if not exists peer (
+            node_id char(96) not null primary key,
+            address text not null,
+            udp_port integer not null,
+            tcp_port integer,
+            unique (address, udp_port)
+        );
     """

     def __init__(self, conf: Config, path, loop=None, time_getter: typing.Optional[typing.Callable[[], float]] = None):
@@ -805,3 +813,17 @@ class SQLiteStorage(SQLiteMixin):
             "where r.timestamp is null or r.timestamp < ?", int(self.time_getter()) - 86400
         )
+
+    # # # # # # # # # # dht functions # # # # # # # # # # #
+    async def get_persisted_kademlia_peers(self) -> typing.List[typing.Tuple[bytes, str, int, int]]:
+        query = 'select node_id, address, udp_port, tcp_port from peer'
+        return [(binascii.unhexlify(n), a, u, t) for n, a, u, t in await self.db.execute_fetchall(query)]
+
+    async def save_kademlia_peers(self, peers: typing.List['KademliaPeer']):
+        def _save_kademlia_peers(transaction: sqlite3.Connection):
+            transaction.execute('delete from peer').fetchall()
+            transaction.executemany(
+                'insert into peer(node_id, address, udp_port, tcp_port) values (?, ?, ?, ?)',
+                tuple([(binascii.hexlify(p.node_id), p.address, p.udp_port, p.tcp_port) for p in peers])
+            ).fetchall()
+        return await self.db.run(_save_kademlia_peers)
diff --git a/lbry/tests/dht_mocks.py b/lbry/tests/dht_mocks.py
index cec9b8e80..2e01986c0 100644
--- a/lbry/tests/dht_mocks.py
+++ b/lbry/tests/dht_mocks.py
@@ -48,8 +48,9 @@ def get_time_accelerator(loop: asyncio.AbstractEventLoop,


 @contextlib.contextmanager
-def mock_network_loop(loop: asyncio.AbstractEventLoop):
-    dht_network: typing.Dict[typing.Tuple[str, int], 'KademliaProtocol'] = {}
+def mock_network_loop(loop: asyncio.AbstractEventLoop,
+                      dht_network: typing.Optional[typing.Dict[typing.Tuple[str, int], 'KademliaProtocol']] = None):
+    dht_network: typing.Dict[typing.Tuple[str, int], 'KademliaProtocol'] = dht_network if dht_network is not None else {}

     async def create_datagram_endpoint(proto_lam: typing.Callable[[], 'KademliaProtocol'],
                                        from_addr: typing.Tuple[str, int]):
diff --git a/lbry/tests/integration/test_dht.py b/lbry/tests/integration/test_dht.py
index 25443ee6c..761a3e315 100644
--- a/lbry/tests/integration/test_dht.py
+++ b/lbry/tests/integration/test_dht.py
@@ -1,6 +1,8 @@
 import asyncio
 from binascii import hexlify

+from lbry.extras.daemon.storage import SQLiteStorage
+from lbry.conf import Config
 from lbry.dht import constants
 from lbry.dht.node import Node
 from lbry.dht import peer as dht_peer
@@ -19,24 +21,32 @@ class DHTIntegrationTest(AsyncioTestCase):
         self.nodes = []
         self.known_node_addresses = []

-    async def setup_network(self, size: int, start_port=40000, seed_nodes=1):
+    async def create_node(self, node_id, port, external_ip='127.0.0.1'):
+        storage = SQLiteStorage(Config(), ":memory:", self.loop, self.loop.time)
+        await storage.open()
+        node = Node(self.loop, PeerManager(self.loop), node_id=node_id,
+                    udp_port=port, internal_udp_port=port,
+                    peer_port=3333, external_ip=external_ip,
+                    storage=storage)
+        self.addCleanup(node.stop)
+        node.protocol.rpc_timeout = .5
+        node.protocol.ping_queue._default_delay = .5
+        return node
+
+    async def setup_network(self, size: int, start_port=40000, seed_nodes=1, external_ip='127.0.0.1'):
         for i in range(size):
             node_port = start_port + i
-            node = Node(self.loop, PeerManager(self.loop), node_id=constants.generate_id(i),
-                        udp_port=node_port, internal_udp_port=node_port,
-                        peer_port=3333, external_ip='127.0.0.1')
+            node_id = constants.generate_id(i)
+            node = await self.create_node(node_id, node_port)
             self.nodes.append(node)
-            self.known_node_addresses.append(('127.0.0.1', node_port))
-            await node.start_listening('127.0.0.1')
-            self.addCleanup(node.stop)
+            self.known_node_addresses.append((external_ip, node_port))
+
         for node in self.nodes:
-            node.protocol.rpc_timeout = .5
-            node.protocol.ping_queue._default_delay = .5
-            node.start('127.0.0.1', self.known_node_addresses[:seed_nodes])
-        await asyncio.gather(*[node.joined.wait() for node in self.nodes])
+            node.start(external_ip, self.known_node_addresses[:seed_nodes])

     async def test_replace_bad_nodes(self):
         await self.setup_network(20)
+        await asyncio.gather(*[node.joined.wait() for node in self.nodes])
         self.assertEqual(len(self.nodes), 20)
         node = self.nodes[0]
         bad_peers = []
@@ -57,6 +67,7 @@ class DHTIntegrationTest(AsyncioTestCase):

     async def test_re_join(self):
         await self.setup_network(20, seed_nodes=10)
+        await asyncio.gather(*[node.joined.wait() for node in self.nodes])
         node = self.nodes[-1]
         self.assertTrue(node.joined.is_set())
         self.assertTrue(node.protocol.routing_table.get_peers())
@@ -84,6 +95,7 @@ class DHTIntegrationTest(AsyncioTestCase):

     async def test_get_token_on_announce(self):
         await self.setup_network(2, seed_nodes=2)
+        await asyncio.gather(*[node.joined.wait() for node in self.nodes])
         node1, node2 = self.nodes
         node1.protocol.peer_manager.clear_token(node2.protocol.node_id)
         blob_hash = hexlify(constants.generate_id(1337)).decode()
@@ -101,6 +113,7 @@ class DHTIntegrationTest(AsyncioTestCase):
         # imagine that you only got bad peers and refresh will happen in one hour
         # instead of failing for one hour we should be able to recover by scheduling pings to bad peers we find
         await self.setup_network(2, seed_nodes=2)
+        await asyncio.gather(*[node.joined.wait() for node in self.nodes])
         node1, node2 = self.nodes
         node2.stop()  # forcefully make it a bad peer but don't remove it from routing table
@@ -116,3 +129,37 @@ class DHTIntegrationTest(AsyncioTestCase):
         await node1.peer_search(node2.protocol.node_id)
         await asyncio.sleep(.3)  # let pending events settle
         self.assertFalse(node1.protocol.routing_table.get_peers())
+
+    async def test_peer_persistence(self):
+        num_nodes = 6
+        start_port = 40000
+        num_seeds = 2
+        external_ip = '127.0.0.1'
+
+        # Start the network
+        await self.setup_network(num_nodes, start_port=start_port, seed_nodes=num_seeds)
+        await asyncio.gather(*[node.joined.wait() for node in self.nodes])
+
+        node1 = self.nodes[-1]
+        peer_args = [(n.protocol.node_id, n.protocol.external_ip, n.protocol.udp_port, n.protocol.peer_port) for n in
+                     self.nodes[:num_seeds]]
+        peers = [make_kademlia_peer(*args) for args in peer_args]
+
+        # node1 is bootstrapped from the fixed seeds
+        self.assertCountEqual(peers, node1.protocol.routing_table.get_peers())
+
+        # Refresh and assert that the peers were persisted
+        await node1.refresh_node(True)
+        self.assertEqual(len(peer_args), len(await node1._storage.get_persisted_kademlia_peers()))
+        node1.stop()
+
+        # Start a fresh node with the same node_id and storage, but no known peers
+        node2 = await self.create_node(constants.generate_id(num_nodes-1), start_port+num_nodes-1)
+        node2._storage = node1._storage
+        node2.start(external_ip, [])
+        await node2.joined.wait()
+
+        # The peers are restored
+        self.assertEqual(num_seeds, len(node2.protocol.routing_table.get_peers()))
+        for bucket1, bucket2 in zip(node1.protocol.routing_table.buckets, node2.protocol.routing_table.buckets):
+            self.assertEqual((bucket1.range_min, bucket1.range_max), (bucket2.range_min, bucket2.range_max))
diff --git a/lbry/tests/unit/database/test_SQLiteStorage.py b/lbry/tests/unit/database/test_SQLiteStorage.py
index e0a11ad84..ab63512e7 100644
--- a/lbry/tests/unit/database/test_SQLiteStorage.py
+++ b/lbry/tests/unit/database/test_SQLiteStorage.py
@@ -3,6 +3,7 @@ import tempfile
 import unittest
 import asyncio
 import logging
+import hashlib
 from torba.testcase import AsyncioTestCase
 from lbry.conf import Config
 from lbry.extras.daemon.storage import SQLiteStorage
@@ -10,6 +11,7 @@ from lbry.blob.blob_info import BlobInfo
 from lbry.blob.blob_manager import BlobManager
 from lbry.stream.descriptor import StreamDescriptor
 from tests.test_utils import random_lbry_hash
+from lbry.dht.peer import make_kademlia_peer

 log = logging.getLogger()
@@ -247,3 +249,13 @@ class ContentClaimStorageTests(StorageTest):
         current_claim_info = await self.storage.get_content_claim(stream_hash)
         # this should still be the previous update
         self.assertDictEqual(current_claim_info, update_info)
+
+
+class UpdatePeersTest(StorageTest):
+    async def test_update_get_peers(self):
+        node_id = hashlib.sha384("1234".encode()).digest()
+        args = (node_id, '73.186.148.72', 4444, None)
+        fake_peer = make_kademlia_peer(*args)
+        await self.storage.save_kademlia_peers([fake_peer])
+        peers = await self.storage.get_persisted_kademlia_peers()
+        self.assertTupleEqual(args, peers[0])
diff --git a/lbry/tests/unit/dht/test_node.py b/lbry/tests/unit/dht/test_node.py
index 03902b553..9872666f8 100644
--- a/lbry/tests/unit/dht/test_node.py
+++ b/lbry/tests/unit/dht/test_node.py
@@ -2,9 +2,11 @@ import asyncio
 import typing
 from torba.testcase import AsyncioTestCase
 from tests import dht_mocks
+from lbry.conf import Config
 from lbry.dht import constants
 from lbry.dht.node import Node
 from lbry.dht.peer import PeerManager, make_kademlia_peer
+from lbry.extras.daemon.storage import SQLiteStorage


 class TestNodePingQueueDiscover(AsyncioTestCase):
@@ -84,3 +86,75 @@ class TestNodePingQueueDiscover(AsyncioTestCase):
         # teardown
         for n in nodes.values():
             n.stop()
+
+
+class TestTemporarilyLosingConnection(AsyncioTestCase):
+
+    async def test_losing_connection(self):
+        async def wait_for(check_ok, insist, timeout=20):
+            start = loop.time()
+            while loop.time() - start < timeout:
+                if check_ok():
+                    break
+                await asyncio.sleep(0)
+            else:
+                insist()
+
+        loop = self.loop
+        loop.set_debug(False)
+
+        peer_addresses = [
+            ('1.2.3.4', 40000+i) for i in range(10)
+        ]
+        node_ids = [constants.generate_id(i) for i in range(10)]
+
+        nodes = [
+            Node(
+                loop, PeerManager(loop), node_id, udp_port, udp_port, 3333, address,
+                storage=SQLiteStorage(Config(), ":memory:", self.loop, self.loop.time)
+            )
+            for node_id, (address, udp_port) in zip(node_ids, peer_addresses)
+        ]
+        dht_network = {peer_addresses[i]: node.protocol for i, node in enumerate(nodes)}
+        num_seeds = 3
+
+        with dht_mocks.mock_network_loop(loop, dht_network):
+            for i, n in enumerate(nodes):
+                await n._storage.open()
+                self.addCleanup(n.stop)
+                n.start(peer_addresses[i][0], peer_addresses[:num_seeds])
+            await asyncio.gather(*[n.joined.wait() for n in nodes])
+
+            node = nodes[-1]
+            advance = dht_mocks.get_time_accelerator(loop, loop.time())
+            await advance(500)
+
+            # Join the network, assert that at least the known peers are in RT
+            self.assertTrue(node.joined.is_set())
+            self.assertTrue(len(node.protocol.routing_table.get_peers()) >= num_seeds)
+
+            # Refresh, so that the peers are persisted
+            self.assertFalse(len(await node._storage.get_persisted_kademlia_peers()) > num_seeds)
+            await advance(4000)
+            self.assertTrue(len(await node._storage.get_persisted_kademlia_peers()) > num_seeds)
+
+            # We lost internet connection - all the peers stop responding
+            dht_network.pop((node.protocol.external_ip, node.protocol.udp_port))
+
+            # The peers are cleared on refresh from RT and storage
+            await advance(4000)
+            self.assertListEqual([], await node._storage.get_persisted_kademlia_peers())
+            await wait_for(
+                lambda: len(node.protocol.routing_table.get_peers()) == 0,
+                lambda: self.assertListEqual(node.protocol.routing_table.get_peers(), [])
+            )
+
+            # Reconnect
+            dht_network[(node.protocol.external_ip, node.protocol.udp_port)] = node.protocol
+
+            # Check that the node reconnects at least to the seeds
+            await advance(1000)
+            await wait_for(
+                lambda: len(node.protocol.routing_table.get_peers()) >= num_seeds,
+                lambda: self.assertTrue(len(node.protocol.routing_table.get_peers()) >= num_seeds)
+            )
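
Note: the persistence API this patch adds to SQLiteStorage can be exercised on its own. Below is a minimal sketch of the save/load round-trip; it relies only on what the patch introduces (save_kademlia_peers / get_persisted_kademlia_peers) plus the pre-existing Config, SQLiteStorage.open()/close() and make_kademlia_peer helpers, and the node id and address are made-up example values:

    import asyncio
    import hashlib

    from lbry.conf import Config
    from lbry.extras.daemon.storage import SQLiteStorage
    from lbry.dht.peer import make_kademlia_peer


    async def demo():
        storage = SQLiteStorage(Config(), ":memory:")  # in-memory db, as in the tests above
        await storage.open()
        # node ids are 48 raw bytes; save_kademlia_peers hexlifies them into the char(96) column
        node_id = hashlib.sha384(b"example").digest()
        peer = make_kademlia_peer(node_id, '10.0.0.1', 4444, None)
        await storage.save_kademlia_peers([peer])
        # returns [(node_id, address, udp_port, tcp_port)], node_id unhexlified back to bytes
        print(await storage.get_persisted_kademlia_peers())
        await storage.close()

    asyncio.run(demo())

As the unit test above checks, the (node_id, address, udp_port, tcp_port) tuple round-trips through the peer table unchanged; save_kademlia_peers replaces the whole table on each call, which is why refresh_node can both persist and clear the peer set.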
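Likewise, a sketch of how the new 13 -> 14 step is reached through the existing migrator (migrate_db's signature is taken from dbmigrator.py above; it assumes a Config whose data_dir holds a lbrynet.sqlite at revision 13):

    from lbry.conf import Config
    from lbry.extras.daemon.migrator.dbmigrator import migrate_db

    conf = Config()
    # steps the schema one revision at a time; 13 -> 14 dispatches migrate13to14.do_migration,
    # which creates the peer table if it does not already exist
    migrate_db(conf, 13, 14)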