diff --git a/lbry/lbry/dht/node.py b/lbry/lbry/dht/node.py index 2159cb86a..f9c44c72e 100644 --- a/lbry/lbry/dht/node.py +++ b/lbry/lbry/dht/node.py @@ -19,7 +19,8 @@ log = logging.getLogger(__name__) class Node: def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager', node_id: bytes, udp_port: int, internal_udp_port: int, peer_port: int, external_ip: str, rpc_timeout: float = constants.rpc_timeout, - split_buckets_under_index: int = constants.split_buckets_under_index): + split_buckets_under_index: int = constants.split_buckets_under_index, + storage: typing.Optional['SQLiteStorage'] = None): self.loop = loop self.internal_udp_port = internal_udp_port self.protocol = KademliaProtocol(loop, peer_manager, node_id, external_ip, udp_port, peer_port, rpc_timeout, @@ -28,6 +29,7 @@ class Node: self.joined = asyncio.Event(loop=self.loop) self._join_task: asyncio.Task = None self._refresh_task: asyncio.Task = None + self._storage = storage async def refresh_node(self, force_once=False): while True: @@ -50,6 +52,9 @@ class Node: node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i)) node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i)) + if self._storage: + await self._storage.update_peers(self.protocol.routing_table.get_peers()) + if self.protocol.routing_table.get_peers(): # if we have node ids to look up, perform the iterative search until we have k results while node_ids: @@ -111,7 +116,7 @@ class Node: self.listening_port = None log.info("Stopped DHT node") - async def start_listening(self, interface: str = '') -> None: + async def start_listening(self, interface: str = '0.0.0.0') -> None: if not self.listening_port: self.listening_port, _ = await self.loop.create_datagram_endpoint( lambda: self.protocol, (interface, self.internal_udp_port) @@ -121,56 +126,55 @@ class Node: else: log.warning("Already bound to port %s", self.listening_port) - async def join_network(self, interface: typing.Optional[str] = '', + async def join_network(self, interface: str = '0.0.0.0', known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None): + def peers_from_urls(urls: typing.Optional[typing.List[typing.Tuple[bytes, str, int, int]]]): + peer_addresses = [] + for node_id, address, udp_port, tcp_port in urls: + if (node_id, address, udp_port, tcp_port) not in peer_addresses and \ + (address, udp_port) != (self.protocol.external_ip, self.protocol.udp_port): + peer_addresses.append((node_id, address, udp_port, tcp_port)) + return [make_kademlia_peer(*peer_address) for peer_address in peer_addresses] + + def set_joined(): + self.joined.set() + log.info( + "joined dht, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()), + self.protocol.routing_table.buckets_with_contacts() + ) + if not self.listening_port: await self.start_listening(interface) self.protocol.ping_queue.start() self._refresh_task = self.loop.create_task(self.refresh_node()) - # resolve the known node urls - known_node_addresses = [] - url_to_addr = {} + restored_peers = peers_from_urls(await self._storage.get_peers()) if self._storage else [] - if known_node_urls: - for host, port in known_node_urls: - address = await resolve_host(host, port, proto='udp') - if (address, port) not in known_node_addresses and\ - (address, port) != (self.protocol.external_ip, self.protocol.udp_port): - known_node_addresses.append((address, port)) - url_to_addr[address] = host + fixed_peers = peers_from_urls([ + (None, await resolve_host(address, udp_port, 'udp'), udp_port, None) + for address, udp_port in known_node_urls or [] + ]) - if known_node_addresses: - peers = [ - make_kademlia_peer(None, address, port) - for (address, port) in known_node_addresses - ] - while True: - if not self.protocol.routing_table.get_peers(): - if self.joined.is_set(): - self.joined.clear() - self.protocol.peer_manager.reset() - self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0) - peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32)) - if self.protocol.routing_table.get_peers(): - self.joined.set() - log.info( - "Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()), - self.protocol.routing_table.buckets_with_contacts()) - else: - continue - await asyncio.sleep(1, loop=self.loop) + seed_peers = restored_peers or fixed_peers + fallback = False + while seed_peers: + if self.protocol.routing_table.get_peers(): + if not self.joined.is_set(): + set_joined() + else: + if self.joined.is_set(): + self.joined.clear() + seed_peers = fixed_peers if fallback else seed_peers + self.protocol.peer_manager.reset() + self.protocol.ping_queue.enqueue_maybe_ping(*seed_peers, delay=0.0) + seed_peers.extend(await self.peer_search(self.protocol.node_id, shortlist=seed_peers, count=32)) + fallback = not self.protocol.routing_table.get_peers() + await asyncio.sleep(1, loop=self.loop) - log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()), - self.protocol.routing_table.buckets_with_contacts()) - self.joined.set() + set_joined() - def start(self, interface: str, known_node_urls: typing.List[typing.Tuple[str, int]]): - self._join_task = self.loop.create_task( - self.join_network( - interface=interface, known_node_urls=known_node_urls - ) - ) + def start(self, interface: str, known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None): + self._join_task = self.loop.create_task(self.join_network(interface, known_node_urls)) def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List['KademliaPeer']] = None, bottom_out_limit: int = constants.bottom_out_limit, diff --git a/lbry/lbry/extras/daemon/Components.py b/lbry/lbry/extras/daemon/Components.py index 78a793381..98b1d5250 100644 --- a/lbry/lbry/extras/daemon/Components.py +++ b/lbry/lbry/extras/daemon/Components.py @@ -189,7 +189,7 @@ class BlobComponent(Component): class DHTComponent(Component): component_name = DHT_COMPONENT - depends_on = [UPNP_COMPONENT] + depends_on = [UPNP_COMPONENT, DATABASE_COMPONENT] def __init__(self, component_manager): super().__init__(component_manager) @@ -223,6 +223,7 @@ class DHTComponent(Component): self.external_peer_port = upnp_component.upnp_redirects.get("TCP", self.conf.tcp_port) self.external_udp_port = upnp_component.upnp_redirects.get("UDP", self.conf.udp_port) external_ip = upnp_component.external_ip + storage = self.component_manager.get_component(DATABASE_COMPONENT) if not external_ip: external_ip = await utils.get_external_ip() if not external_ip: @@ -237,11 +238,10 @@ class DHTComponent(Component): external_ip=external_ip, peer_port=self.external_peer_port, rpc_timeout=self.conf.node_rpc_timeout, - split_buckets_under_index=self.conf.split_buckets_under_index - ) - self.dht_node.start( - interface=self.conf.network_interface, known_node_urls=self.conf.known_dht_nodes + split_buckets_under_index=self.conf.split_buckets_under_index, + storage=storage ) + self.dht_node.start(self.conf.network_interface, self.conf.known_dht_nodes) log.info("Started the dht") async def stop(self): diff --git a/lbry/lbry/extras/daemon/storage.py b/lbry/lbry/extras/daemon/storage.py index b5859d6c1..2dfee2ce3 100644 --- a/lbry/lbry/extras/daemon/storage.py +++ b/lbry/lbry/extras/daemon/storage.py @@ -329,6 +329,14 @@ class SQLiteStorage(SQLiteMixin): timestamp integer, primary key (sd_hash, reflector_address) ); + + create table if not exists peer ( + address text not null, + udp_port integer not null, + tcp_port integer, + node_id char(96) unique not null, + primary key (address, udp_port) + ); """ def __init__(self, conf: Config, path, loop=None, time_getter: typing.Optional[typing.Callable[[], float]] = None): @@ -805,3 +813,20 @@ class SQLiteStorage(SQLiteMixin): "where r.timestamp is null or r.timestamp < ?", int(self.time_getter()) - 86400 ) + + # # # # # # # # # # dht functions # # # # # # # # # # # + async def get_peers(self): + query = 'select node_id, address, udp_port, tcp_port from peer' + return [(binascii.unhexlify(n), a, u, t) for n, a, u, t in await self.db.execute_fetchall(query)] + + async def update_peers(self, peers: typing.List['KademliaPeer']): + def _update_peers(transaction: sqlite3.Connection): + transaction.execute('delete from peer').fetchall() + transaction.executemany( + 'insert into peer(node_id, address, udp_port, tcp_port) values (?, ?, ?, ?)', ( + tuple( + [(binascii.hexlify(p.node_id), p.address, p.udp_port, p.tcp_port) for p in peers] + ) + ) + ).fetchall() + return await self.db.run(_update_peers) diff --git a/lbry/tests/integration/test_dht.py b/lbry/tests/integration/test_dht.py index 25443ee6c..080562451 100644 --- a/lbry/tests/integration/test_dht.py +++ b/lbry/tests/integration/test_dht.py @@ -1,6 +1,8 @@ import asyncio from binascii import hexlify +from lbry.extras.daemon.storage import SQLiteStorage +from lbry.conf import Config from lbry.dht import constants from lbry.dht.node import Node from lbry.dht import peer as dht_peer @@ -19,20 +21,29 @@ class DHTIntegrationTest(AsyncioTestCase): self.nodes = [] self.known_node_addresses = [] - async def setup_network(self, size: int, start_port=40000, seed_nodes=1): + async def create_node(self, node_id, port, external_ip='127.0.0.1'): + storage = SQLiteStorage(Config(), ":memory:", self.loop, self.loop.time) + await storage.open() + node = Node(self.loop, PeerManager(self.loop), node_id=node_id, + udp_port=port, internal_udp_port=port, + peer_port=3333, external_ip=external_ip, + storage=storage) + self.addCleanup(node.stop) + node.protocol.rpc_timeout = .5 + node.protocol.ping_queue._default_delay = .5 + node._peer_search_timeout = .5 + return node + + async def setup_network(self, size: int, start_port=40000, seed_nodes=1, external_ip='127.0.0.1'): for i in range(size): node_port = start_port + i - node = Node(self.loop, PeerManager(self.loop), node_id=constants.generate_id(i), - udp_port=node_port, internal_udp_port=node_port, - peer_port=3333, external_ip='127.0.0.1') + node_id = constants.generate_id(i) + node = await self.create_node(node_id, node_port) self.nodes.append(node) - self.known_node_addresses.append(('127.0.0.1', node_port)) - await node.start_listening('127.0.0.1') - self.addCleanup(node.stop) + self.known_node_addresses.append((external_ip, node_port)) + for node in self.nodes: - node.protocol.rpc_timeout = .5 - node.protocol.ping_queue._default_delay = .5 - node.start('127.0.0.1', self.known_node_addresses[:seed_nodes]) + node.start(external_ip, self.known_node_addresses[:seed_nodes]) await asyncio.gather(*[node.joined.wait() for node in self.nodes]) async def test_replace_bad_nodes(self): @@ -116,3 +127,88 @@ class DHTIntegrationTest(AsyncioTestCase): await node1.peer_search(node2.protocol.node_id) await asyncio.sleep(.3) # let pending events settle self.assertFalse(node1.protocol.routing_table.get_peers()) + + async def test_peer_persistance(self): + num_peers = 5 + start_port = 40000 + external_ip = '127.0.0.1' + + # Start a node + node1 = await self.create_node(constants.generate_id(num_peers), start_port+num_peers) + node1.start(external_ip) + + # Add peers + peer_args = [(n.protocol.nodeid, n.protocol.external_ip, n.protocol.udp_port) for n in self.nodes] + peers = [make_kademlia_peer(*args) for args in peer_args] + for peer in peers: + await node1.protocol._add_peer(peer) + + await asyncio.sleep(.3) + self.assertTrue(node1.joined.is_set()) + self.assertCountEqual(peers, node1.protocol.routing_table.get_peers()) + + # Refresh and assert that the peers were persisted + await node1.refresh_node(True) + self.assertCountEqual(peer_args, await node1._storage.get_peers()) + node1.stop() + + # Start a fresh node with the same node_id and storage + node2 = await self.create_node(constants.generate_id(num_peers), start_port+num_peers+1) + node2._storage = node1._storage + node2.start(external_ip) + + # The peers are restored + await asyncio.sleep(.3) + self.assertTrue(node2.joined.is_set()) + self.assertCountEqual(peers, node2.protocol.routing_table.get_peers()) + for bucket1, bucket2 in zip(node1.protocol.routing_table.buckets, node2.protocol.routing_table.buckets): + self.assertEqual((bucket1.range_min, bucket1.range_max), (bucket2.range_min, bucket2.range_max)) + + async def test_switch_to_known_seeds(self): + num_peers = 10 + start_port = 40000 + external_ip = '127.0.0.1' + + await self.setup_network(num_peers, seed_nodes=num_peers // 2, start_port=start_port) + peer_args = [ + (n.protocol.node_id, n.protocol.external_ip, n.protocol.udp_port) for n in self.nodes + ] + known_peers = [make_kademlia_peer(*args) for args in peer_args[:num_peers // 2]] + known_nodes = self.nodes[:num_peers // 2] + persisted_peers = [make_kademlia_peer(*args) for args in peer_args[num_peers // 2:]] + persisted_nodes = self.nodes[num_peers // 2:] + + # Create node with the persisted nodes in storage + node = await self.create_node(constants.generate_id(num_peers), start_port+num_peers) + await node._storage.update_peers(persisted_peers) + + # Stop known peers so they stop replying and won't be added + for n in known_nodes: + n.stop() + + node.start(external_ip, self.known_node_addresses[:num_peers // 2]) + await node.joined.wait() + self.assertTrue(node.joined.is_set()) + + # Only persisted ones are added to the routing table + self.assertCountEqual(persisted_peers, node.protocol.routing_table.get_peers()) + + # Start the known ones, stop the persisted + for n1, n2 in zip(known_nodes, persisted_nodes): + n1.start(external_ip) + n2.stop() + asyncio.gather(*[n.joined.wait() for n in known_nodes]) + await asyncio.sleep(3) + self.assertTrue(all(known.joined.is_set() for known in known_nodes)) + self.assertTrue(all(not persisted.joined.is_set() for persisted in persisted_nodes)) + + # Remove persisted from node's routing table, set them as bad + for peer in persisted_peers: + node.protocol.routing_table.remove_peer(peer) + node.protocol.peer_manager.report_failure(peer.address, peer.udp_port) + self.assertFalse(node.protocol.routing_table.get_peers()) + + # The known_peers replace the persisted ones + await node.joined.wait() + await asyncio.sleep(3) + self.assertCountEqual(known_peers, node.protocol.routing_table.get_peers()) diff --git a/lbry/tests/unit/database/test_SQLiteStorage.py b/lbry/tests/unit/database/test_SQLiteStorage.py index e0a11ad84..bd5dfd67f 100644 --- a/lbry/tests/unit/database/test_SQLiteStorage.py +++ b/lbry/tests/unit/database/test_SQLiteStorage.py @@ -3,6 +3,7 @@ import tempfile import unittest import asyncio import logging +import hashlib from torba.testcase import AsyncioTestCase from lbry.conf import Config from lbry.extras.daemon.storage import SQLiteStorage @@ -10,6 +11,7 @@ from lbry.blob.blob_info import BlobInfo from lbry.blob.blob_manager import BlobManager from lbry.stream.descriptor import StreamDescriptor from tests.test_utils import random_lbry_hash +from lbry.dht.peer import make_kademlia_peer log = logging.getLogger() @@ -247,3 +249,13 @@ class ContentClaimStorageTests(StorageTest): current_claim_info = await self.storage.get_content_claim(stream_hash) # this should still be the previous update self.assertDictEqual(current_claim_info, update_info) + + +class UpdatePeersTest(StorageTest): + async def test_update_get_peers(self): + node_id = hashlib.sha384("1234".encode()).digest() + args = (node_id, '73.186.148.72', 4444, None) + fake_peer = make_kademlia_peer(*args) + await self.storage.update_peers([fake_peer]) + peers = await self.storage.get_peers() + self.assertTupleEqual(args, peers[0])