diff --git a/lbrynet/cryptoutils.py b/lbrynet/cryptoutils.py index 6f71c9a69..3c90ee598 100644 --- a/lbrynet/cryptoutils.py +++ b/lbrynet/cryptoutils.py @@ -1,4 +1,8 @@ import hashlib +from cryptography.hazmat.backends import default_backend + + +backend = default_backend() def get_lbry_hash_obj(): diff --git a/lbrynet/dht/AUTHORS b/lbrynet/dht/AUTHORS deleted file mode 100644 index 9a8063643..000000000 --- a/lbrynet/dht/AUTHORS +++ /dev/null @@ -1,7 +0,0 @@ -Francois Aucamp - -Thanks goes to the following people for providing patches/suggestions/tests: - -Neil Kleynhans -Haiyang Ma -Bryan McAlister diff --git a/lbrynet/dht/COPYING b/lbrynet/dht/COPYING deleted file mode 100644 index cca7fc278..000000000 --- a/lbrynet/dht/COPYING +++ /dev/null @@ -1,165 +0,0 @@ - GNU LESSER GENERAL PUBLIC LICENSE - Version 3, 29 June 2007 - - Copyright (C) 2007 Free Software Foundation, Inc. - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - - This version of the GNU Lesser General Public License incorporates -the terms and conditions of version 3 of the GNU General Public -License, supplemented by the additional permissions listed below. - - 0. Additional Definitions. - - As used herein, "this License" refers to version 3 of the GNU Lesser -General Public License, and the "GNU GPL" refers to version 3 of the GNU -General Public License. - - "The Library" refers to a covered work governed by this License, -other than an Application or a Combined Work as defined below. - - An "Application" is any work that makes use of an interface provided -by the Library, but which is not otherwise based on the Library. -Defining a subclass of a class defined by the Library is deemed a mode -of using an interface provided by the Library. - - A "Combined Work" is a work produced by combining or linking an -Application with the Library. The particular version of the Library -with which the Combined Work was made is also called the "Linked -Version". - - The "Minimal Corresponding Source" for a Combined Work means the -Corresponding Source for the Combined Work, excluding any source code -for portions of the Combined Work that, considered in isolation, are -based on the Application, and not on the Linked Version. - - The "Corresponding Application Code" for a Combined Work means the -object code and/or source code for the Application, including any data -and utility programs needed for reproducing the Combined Work from the -Application, but excluding the System Libraries of the Combined Work. - - 1. Exception to Section 3 of the GNU GPL. - - You may convey a covered work under sections 3 and 4 of this License -without being bound by section 3 of the GNU GPL. - - 2. Conveying Modified Versions. - - If you modify a copy of the Library, and, in your modifications, a -facility refers to a function or data to be supplied by an Application -that uses the facility (other than as an argument passed when the -facility is invoked), then you may convey a copy of the modified -version: - - a) under this License, provided that you make a good faith effort to - ensure that, in the event an Application does not supply the - function or data, the facility still operates, and performs - whatever part of its purpose remains meaningful, or - - b) under the GNU GPL, with none of the additional permissions of - this License applicable to that copy. - - 3. Object Code Incorporating Material from Library Header Files. 
- - The object code form of an Application may incorporate material from -a header file that is part of the Library. You may convey such object -code under terms of your choice, provided that, if the incorporated -material is not limited to numerical parameters, data structure -layouts and accessors, or small macros, inline functions and templates -(ten or fewer lines in length), you do both of the following: - - a) Give prominent notice with each copy of the object code that the - Library is used in it and that the Library and its use are - covered by this License. - - b) Accompany the object code with a copy of the GNU GPL and this license - document. - - 4. Combined Works. - - You may convey a Combined Work under terms of your choice that, -taken together, effectively do not restrict modification of the -portions of the Library contained in the Combined Work and reverse -engineering for debugging such modifications, if you also do each of -the following: - - a) Give prominent notice with each copy of the Combined Work that - the Library is used in it and that the Library and its use are - covered by this License. - - b) Accompany the Combined Work with a copy of the GNU GPL and this license - document. - - c) For a Combined Work that displays copyright notices during - execution, include the copyright notice for the Library among - these notices, as well as a reference directing the user to the - copies of the GNU GPL and this license document. - - d) Do one of the following: - - 0) Convey the Minimal Corresponding Source under the terms of this - License, and the Corresponding Application Code in a form - suitable for, and under terms that permit, the user to - recombine or relink the Application with a modified version of - the Linked Version to produce a modified Combined Work, in the - manner specified by section 6 of the GNU GPL for conveying - Corresponding Source. - - 1) Use a suitable shared library mechanism for linking with the - Library. A suitable mechanism is one that (a) uses at run time - a copy of the Library already present on the user's computer - system, and (b) will operate properly with a modified version - of the Library that is interface-compatible with the Linked - Version. - - e) Provide Installation Information, but only if you would otherwise - be required to provide such information under section 6 of the - GNU GPL, and only to the extent that such information is - necessary to install and execute a modified version of the - Combined Work produced by recombining or relinking the - Application with a modified version of the Linked Version. (If - you use option 4d0, the Installation Information must accompany - the Minimal Corresponding Source and Corresponding Application - Code. If you use option 4d1, you must provide the Installation - Information in the manner specified by section 6 of the GNU GPL - for conveying Corresponding Source.) - - 5. Combined Libraries. - - You may place library facilities that are a work based on the -Library side by side in a single library together with other library -facilities that are not Applications and are not covered by this -License, and convey such a combined library under terms of your -choice, if you do both of the following: - - a) Accompany the combined library with a copy of the same work based - on the Library, uncombined with any other library facilities, - conveyed under the terms of this License. 
- - b) Give prominent notice with the combined library that part of it - is a work based on the Library, and explaining where to find the - accompanying uncombined form of the same work. - - 6. Revised Versions of the GNU Lesser General Public License. - - The Free Software Foundation may publish revised and/or new versions -of the GNU Lesser General Public License from time to time. Such new -versions will be similar in spirit to the present version, but may -differ in detail to address new problems or concerns. - - Each version is given a distinguishing version number. If the -Library as you received it specifies that a certain numbered version -of the GNU Lesser General Public License "or any later version" -applies to it, you have the option of following the terms and -conditions either of that published version or of any later version -published by the Free Software Foundation. If the Library as you -received it does not specify a version number of the GNU Lesser -General Public License, you may choose any version of the GNU Lesser -General Public License ever published by the Free Software Foundation. - - If the Library as you received it specifies that a proxy can decide -whether future versions of the GNU Lesser General Public License shall -apply, that proxy's public statement of acceptance of any version is -permanent authorization for you to choose that version for the -Library. diff --git a/lbrynet/dht/blob_announcer.py b/lbrynet/dht/blob_announcer.py new file mode 100644 index 000000000..2a209e65e --- /dev/null +++ b/lbrynet/dht/blob_announcer.py @@ -0,0 +1,65 @@ +import asyncio +import typing +import logging +if typing.TYPE_CHECKING: + from lbrynet.dht.node import Node + from lbrynet.extras.daemon.storage import SQLiteStorage + +log = logging.getLogger(__name__) + + +class BlobAnnouncer: + def __init__(self, loop: asyncio.BaseEventLoop, node: 'Node', storage: 'SQLiteStorage'): + self.loop = loop + self.node = node + self.storage = storage + self.pending_call: asyncio.Handle = None + self.announce_task: asyncio.Task = None + self.running = False + self.announce_queue: typing.List[str] = [] + + async def _announce(self, batch_size: typing.Optional[int] = 10): + if not self.node.joined.is_set(): + await self.node.joined.wait() + blob_hashes = await self.storage.get_blobs_to_announce() + if blob_hashes: + self.announce_queue.extend(blob_hashes) + log.info("%i blobs to announce", len(blob_hashes)) + batch = [] + while len(self.announce_queue): + cnt = 0 + announced = [] + while self.announce_queue and cnt < batch_size: + blob_hash = self.announce_queue.pop() + announced.append(blob_hash) + batch.append(self.node.announce_blob(blob_hash)) + cnt += 1 + to_await = [] + while batch: + to_await.append(batch.pop()) + if to_await: + await asyncio.gather(*tuple(to_await), loop=self.loop) + await self.storage.update_last_announced_blobs(announced, self.loop.time()) + log.info("announced %i blobs", len(announced)) + if self.running: + self.pending_call = self.loop.call_later(60, self.announce, batch_size) + + def announce(self, batch_size: typing.Optional[int] = 10): + self.announce_task = self.loop.create_task(self._announce(batch_size)) + + def start(self, batch_size: typing.Optional[int] = 10): + if self.running: + raise Exception("already running") + self.running = True + self.announce(batch_size) + + def stop(self): + self.running = False + if self.pending_call: + if not self.pending_call.cancelled(): + self.pending_call.cancel() + self.pending_call = None + if self.announce_task: + if 
not (self.announce_task.done() or self.announce_task.cancelled()): + self.announce_task.cancel() + self.announce_task = None diff --git a/lbrynet/dht/call_later_manager.py b/lbrynet/dht/call_later_manager.py deleted file mode 100644 index eba08450e..000000000 --- a/lbrynet/dht/call_later_manager.py +++ /dev/null @@ -1,81 +0,0 @@ -import logging - -log = logging.getLogger() - -MIN_DELAY = 0.0 -MAX_DELAY = 0.01 -DELAY_INCREMENT = 0.0001 -QUEUE_SIZE_THRESHOLD = 100 - - -class CallLaterManager: - def __init__(self, callLater): - """ - :param callLater: (IReactorTime.callLater) - """ - - self._callLater = callLater - self._pendingCallLaters = [] - self._delay = MIN_DELAY - - def get_min_delay(self): - self._pendingCallLaters = [cl for cl in self._pendingCallLaters if cl.active()] - queue_size = len(self._pendingCallLaters) - if queue_size > QUEUE_SIZE_THRESHOLD: - self._delay = min((self._delay + DELAY_INCREMENT), MAX_DELAY) - else: - self._delay = max((self._delay - 2.0 * DELAY_INCREMENT), MIN_DELAY) - return self._delay - - def _cancel(self, call_later): - """ - :param call_later: DelayedCall - :return: (callable) canceller function - """ - - def cancel(reason=None): - """ - :param reason: reason for cancellation, this is returned after cancelling the DelayedCall - :return: reason - """ - - if call_later.active(): - call_later.cancel() - if call_later in self._pendingCallLaters: - self._pendingCallLaters.remove(call_later) - return reason - return cancel - - def stop(self): - """ - Cancel any callLaters that are still running - """ - - from twisted.internet import defer - while self._pendingCallLaters: - canceller = self._cancel(self._pendingCallLaters[0]) - try: - canceller() - except (defer.CancelledError, defer.AlreadyCalledError, ValueError): - pass - - def call_later(self, when, what, *args, **kwargs): - """ - Schedule a call later and get a canceller callback function - - :param when: (float) delay in seconds - :param what: (callable) - :param args: (*tuple) args to be passed to the callable - :param kwargs: (**dict) kwargs to be passed to the callable - - :return: (tuple) twisted.internet.base.DelayedCall object, canceller function - """ - - call_later = self._callLater(when, what, *args, **kwargs) - canceller = self._cancel(call_later) - self._pendingCallLaters.append(call_later) - return call_later, canceller - - def call_soon(self, what, *args, **kwargs): - delay = self.get_min_delay() - return self.call_later(delay, what, *args, **kwargs) diff --git a/lbrynet/dht/constants.py b/lbrynet/dht/constants.py index 28b17e74d..0dd3bd746 100644 --- a/lbrynet/dht/constants.py +++ b/lbrynet/dht/constants.py @@ -1,61 +1,40 @@ -#!/usr/bin/env python -# -# This library is free software, distributed under the terms of -# the GNU Lesser General Public License Version 3, or any later version. -# See the COPYING file included in this archive -# -# The docstrings in this module contain epytext markup; API documentation -# may be created by processing this file with epydoc: http://epydoc.sf.net +import hashlib +import os -""" This module defines the charaterizing constants of the Kademlia network - -C{checkRefreshInterval} and C{udpDatagramMaxSize} are implementation-specific -constants, and do not affect general Kademlia operation. 
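A minimal usage sketch for the BlobAnnouncer added above, assuming node is an already-started Node and storage an open SQLiteStorage (both placeholders here): start() schedules _announce(), which re-queues itself via call_later every 60 seconds, and stop() cancels the pending call and any in-flight task.

import asyncio
from lbrynet.dht.blob_announcer import BlobAnnouncer

async def run_announcer(node, storage):
    loop = asyncio.get_event_loop()
    announcer = BlobAnnouncer(loop, node, storage)
    announcer.start(batch_size=10)  # announce in batches of 10 gathered coroutines
    try:
        await asyncio.sleep(3600)   # let it announce in the background for an hour
    finally:
        announcer.stop()            # cancel the pending call_later and any running task

Gathering the announce coroutines in small batches keeps a large queue from flooding the UDP protocol with thousands of concurrent store RPCs.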
-""" - -######### KADEMLIA CONSTANTS ########### - -#: Small number Representing the degree of parallelism in network calls -alpha = 3 - -#: Maximum number of contacts stored in a bucket; this should be an even number +hash_class = hashlib.sha384 +hash_length = hash_class().digest_size +hash_bits = hash_length * 8 +alpha = 5 k = 8 - -#: Maximum number of contacts stored in the replacement cache -replacementCacheSize = 8 - -#: Timeout for network operations (in seconds) -rpcTimeout = 5 - -# number of rpc attempts to make before a timeout results in the node being removed as a contact -rpcAttempts = 5 -# time window to count failures (in seconds) -rpcAttemptsPruningTimeWindow = 600 - -# Delay between iterations of iterative node lookups (for loose parallelism) (in seconds) -iterativeLookupDelay = rpcTimeout / 2 - -#: If a k-bucket has not been used for this amount of time, refresh it (in seconds) -refreshTimeout = 3600 # 1 hour -#: The interval at which nodes replicate (republish/refresh) data they are holding -replicateInterval = refreshTimeout -# The time it takes for data to expire in the network; the original publisher of the data -# will also republish the data at this time if it is still valid -dataExpireTimeout = 86400 # 24 hours - -tokenSecretChangeInterval = 300 # 5 minutes - -######## IMPLEMENTATION-SPECIFIC CONSTANTS ########### - -#: The interval for the node to check whether any buckets need refreshing -checkRefreshInterval = refreshTimeout / 5 - -#: Max size of a single UDP datagram, in bytes. If a message is larger than this, it will -#: be spread across several UDP packets. -udpDatagramMaxSize = 8192 # 8 KB - -key_bits = 384 - +replacement_cache_size = 8 +rpc_timeout = 5.0 +rpc_attempts = 5 +rpc_attempts_pruning_window = 600 +iterative_lookup_delay = rpc_timeout / 2.0 +refresh_interval = 3600 # 1 hour +replicate_interval = refresh_interval +data_expiration = 86400 # 24 hours +token_secret_refresh_interval = 300 # 5 minutes +check_refresh_interval = refresh_interval / 5 +max_datagram_size = 8192 # 8 KB rpc_id_length = 20 +protocol_version = 1 +bottom_out_limit = 3 +msg_size_limit = max_datagram_size - 26 -protocolVersion = 1 + +def digest(data: bytes) -> bytes: + h = hash_class() + h.update(data) + return h.digest() + + +def generate_id(num=None) -> bytes: + if num is not None: + return digest(str(num).encode()) + else: + return digest(os.urandom(32)) + + +def generate_rpc_id(num=None) -> bytes: + return generate_id(num)[:rpc_id_length] diff --git a/lbrynet/dht/contact.py b/lbrynet/dht/contact.py deleted file mode 100644 index 99c84532b..000000000 --- a/lbrynet/dht/contact.py +++ /dev/null @@ -1,189 +0,0 @@ -import ipaddress -from binascii import hexlify -from functools import reduce -from lbrynet.dht import constants - - -def is_valid_ipv4(address): - try: - ip = ipaddress.ip_address(address) - return ip.version == 4 - except ipaddress.AddressValueError: - return False - - -class _Contact: - """ Encapsulation for remote contact - - This class contains information on a single remote contact, and also - provides a direct RPC API to the remote node which it represents - """ - - def __init__(self, contactManager, id, ipAddress, udpPort, networkProtocol, firstComm): - if id is not None: - if not len(id) == constants.key_bits // 8: - raise ValueError("invalid node id: {}".format(hexlify(id).decode())) - if not 0 <= udpPort <= 65536: - raise ValueError("invalid port") - if not is_valid_ipv4(ipAddress): - raise ValueError("invalid ip address") - self._contactManager = contactManager 
- self._id = id - self.address = ipAddress - self.port = udpPort - self._networkProtocol = networkProtocol - self.commTime = firstComm - self.getTime = self._contactManager._get_time - self.lastReplied = None - self.lastRequested = None - self.protocolVersion = 0 - self._token = (None, 0) # token, timestamp - - def update_token(self, token): - self._token = token, self.getTime() if token else 0 - - @property - def token(self): - # expire the token 1 minute early to be safe - return self._token[0] if self._token[1] + 240 > self.getTime() else None - - @property - def lastInteracted(self): - return max(self.lastRequested or 0, self.lastReplied or 0, self.lastFailed or 0) - - @property - def id(self): - return self._id - - def log_id(self, short=True): - if not self.id: - return "not initialized" - id_hex = hexlify(self.id) - return id_hex if not short else id_hex[:8] - - @property - def failedRPCs(self): - return len(self.failures) - - @property - def lastFailed(self): - return (self.failures or [None])[-1] - - @property - def failures(self): - return self._contactManager._rpc_failures.get((self.address, self.port), []) - - @property - def contact_is_good(self): - """ - :return: False if contact is bad, None if contact is unknown, or True if contact is good - """ - failures = self.failures - now = self.getTime() - delay = constants.checkRefreshInterval - - if failures: - if self.lastReplied and len(failures) >= 2 and self.lastReplied < failures[-2]: - return False - elif self.lastReplied and len(failures) >= 2 and self.lastReplied > failures[-2]: - pass # handled below - elif len(failures) >= 2: - return False - - if self.lastReplied and self.lastReplied > now - delay: - return True - if self.lastReplied and self.lastRequested and self.lastRequested > now - delay: - return True - return None - - def __eq__(self, other): - if not isinstance(other, _Contact): - raise TypeError("invalid type to compare with Contact: %s" % str(type(other))) - return (self.id, self.address, self.port) == (other.id, other.address, other.port) - - def __hash__(self): - return hash((self.id, self.address, self.port)) - - def compact_ip(self): - compact_ip = reduce( - lambda buff, x: buff + bytearray([int(x)]), self.address.split('.'), bytearray()) - return compact_ip - - def set_id(self, id): - if not self._id: - self._id = id - - def update_last_replied(self): - self.lastReplied = int(self.getTime()) - - def update_last_requested(self): - self.lastRequested = int(self.getTime()) - - def update_last_failed(self): - failures = self._contactManager._rpc_failures.get((self.address, self.port), []) - failures.append(self.getTime()) - self._contactManager._rpc_failures[(self.address, self.port)] = failures - - def update_protocol_version(self, version): - self.protocolVersion = version - - def __str__(self): - return '<%s.%s object; IP address: %s, UDP port: %d>' % ( - self.__module__, self.__class__.__name__, self.address, self.port) - - def __getattr__(self, name): - """ This override allows the host node to call a method of the remote - node (i.e. this contact) as if it was a local function. - - For instance, if C{remoteNode} is a instance of C{Contact}, the - following will result in C{remoteNode}'s C{test()} method to be - called with argument C{123}:: - remoteNode.test(123) - - Such a RPC method call will return a Deferred, which will callback - when the contact responds with the result (or an error occurs). - This happens via this contact's C{_networkProtocol} object (i.e. the - host Node's C{_protocol} object). 
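The contact_is_good property deleted above is the subtlest part of the old contact bookkeeping; restated as a standalone function for reference (a paraphrase of the removed logic, with delay standing in for constants.checkRefreshInterval):

def contact_is_good(last_replied, last_requested, failures, now, delay):
    # returns False (bad), True (good) or None (unknown)
    if failures:
        if last_replied and len(failures) >= 2 and last_replied < failures[-2]:
            return False  # two or more failures since the last reply
        elif last_replied and len(failures) >= 2 and last_replied > failures[-2]:
            pass          # replied since the second-to-last failure; check recency below
        elif len(failures) >= 2:
            return False  # failed twice or more and never replied
    if last_replied and last_replied > now - delay:
        return True       # replied recently
    if last_replied and last_requested and last_requested > now - delay:
        return True       # replied at some point and was queried recently
    return None           # not enough information either way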
- """ - - if name not in ['ping', 'findValue', 'findNode', 'store']: - raise AttributeError("unknown command: %s" % name) - - def _sendRPC(*args, **kwargs): - return self._networkProtocol.sendRPC(self, name.encode(), args) - - return _sendRPC - - -class ContactManager: - def __init__(self, get_time=None): - if not get_time: - from twisted.internet import reactor - get_time = reactor.seconds - self._get_time = get_time - self._contacts = {} - self._rpc_failures = {} - - def get_contact(self, id, address, port): - for contact in self._contacts.values(): - if contact.id == id and contact.address == address and contact.port == port: - return contact - - def make_contact(self, id, ipAddress, udpPort, networkProtocol, firstComm=0): - contact = self.get_contact(id, ipAddress, udpPort) - if contact: - return contact - contact = _Contact(self, id, ipAddress, udpPort, networkProtocol, firstComm or self._get_time()) - self._contacts[(id, ipAddress, udpPort)] = contact - return contact - - def is_ignored(self, origin_tuple): - failed_rpc_count = len(self._prune_failures(origin_tuple)) - return failed_rpc_count > constants.rpcAttempts - - def _prune_failures(self, origin_tuple): - # Prunes recorded failures to the last time window of attempts - pruning_limit = self._get_time() - constants.rpcAttemptsPruningTimeWindow - pruned = list(filter(lambda t: t >= pruning_limit, self._rpc_failures.get(origin_tuple, []))) - self._rpc_failures[origin_tuple] = pruned - return pruned diff --git a/lbrynet/dht/datastore.py b/lbrynet/dht/datastore.py deleted file mode 100644 index 19fa0f84a..000000000 --- a/lbrynet/dht/datastore.py +++ /dev/null @@ -1,66 +0,0 @@ -from collections import UserDict -from lbrynet.dht import constants - - -class DictDataStore(UserDict): - """ A datastore using an in-memory Python dictionary """ - #implements(IDataStore) - - def __init__(self, getTime=None): - # Dictionary format: - # { : (, , , ) } - super().__init__() - if not getTime: - from twisted.internet import reactor - getTime = reactor.seconds - self._getTime = getTime - self.completed_blobs = set() - - def filter_bad_and_expired_peers(self, key): - """ - Returns only non-expired and unknown/good peers - """ - return filter( - lambda peer: - self._getTime() - peer[3] < constants.dataExpireTimeout and peer[0].contact_is_good is not False, - self[key] - ) - - def filter_expired_peers(self, key): - """ - Returns only non-expired peers - """ - return filter(lambda peer: self._getTime() - peer[3] < constants.dataExpireTimeout, self[key]) - - def removeExpiredPeers(self): - expired_keys = [] - for key in self.keys(): - unexpired_peers = list(self.filter_expired_peers(key)) - if not unexpired_peers: - expired_keys.append(key) - else: - self[key] = unexpired_peers - for key in expired_keys: - del self[key] - - def hasPeersForBlob(self, key): - return bool(key in self and len(tuple(self.filter_bad_and_expired_peers(key)))) - - def addPeerToBlob(self, contact, key, compact_address, lastPublished, originallyPublished, originalPublisherID): - if key in self: - if compact_address not in map(lambda store_tuple: store_tuple[1], self[key]): - self[key].append( - (contact, compact_address, lastPublished, originallyPublished, originalPublisherID) - ) - else: - self[key] = [(contact, compact_address, lastPublished, originallyPublished, originalPublisherID)] - - def getPeersForBlob(self, key): - return [] if key not in self else [val[1] for val in self.filter_bad_and_expired_peers(key)] - - def getStoringContacts(self): - contacts = set() - for key in 
self: - for values in self[key]: - contacts.add(values[0]) - return list(contacts) diff --git a/lbrynet/dht/error.py b/lbrynet/dht/error.py index 44b5a194d..662b98b6c 100644 --- a/lbrynet/dht/error.py +++ b/lbrynet/dht/error.py @@ -1,43 +1,23 @@ -import binascii -#import exceptions - -# this is a dict of {"exceptions.": exception class} items used to raise -# remote built-in exceptions locally -BUILTIN_EXCEPTIONS = { -# "exceptions.%s" % e: getattr(exceptions, e) for e in dir(exceptions) if not e.startswith("_") -} +class BaseKademliaException(Exception): + pass -class DecodeError(Exception): +class DecodeError(BaseKademliaException): """ Should be raised by an C{Encoding} implementation if decode operation fails """ -class BucketFull(Exception): +class BucketFull(BaseKademliaException): """ Raised when the bucket is full """ -class UnknownRemoteException(Exception): +class RemoteException(BaseKademliaException): pass -class TimeoutError(Exception): - """ Raised when a RPC times out """ - - def __init__(self, remote_contact_id): - # remote_contact_id is a binary blob so we need to convert it - # into something more readable - if remote_contact_id: - msg = 'Timeout connecting to {}'.format(binascii.hexlify(remote_contact_id)) - else: - msg = 'Timeout connecting to uninitialized node' - super().__init__(msg) - self.remote_contact_id = remote_contact_id - - -class TransportNotConnected(Exception): +class TransportNotConnected(BaseKademliaException): pass diff --git a/lbrynet/dht/interface.py b/lbrynet/dht/interface.py deleted file mode 100644 index 6353dbbfd..000000000 --- a/lbrynet/dht/interface.py +++ /dev/null @@ -1,116 +0,0 @@ -from zope.interface import Interface - - -class IDataStore(Interface): - """ Interface for classes implementing physical storage (for data - published via the "STORE" RPC) for the Kademlia DHT - - @note: This provides an interface for a dict-like object - """ - - def keys(self): - """ Return a list of the keys in this data store """ - - def removeExpiredPeers(self): - pass - - def hasPeersForBlob(self, key): - pass - - def addPeerToBlob(self, key, value, lastPublished, originallyPublished, originalPublisherID): - pass - - def getPeersForBlob(self, key): - pass - - def removePeer(self, key): - pass - - -class IRoutingTable(Interface): - """ Interface for RPC message translators/formatters - - Classes inheriting from this should provide a suitable routing table for - a parent Node object (i.e. the local entity in the Kademlia network) - """ - - def __init__(self, parentNodeID): - """ - @param parentNodeID: The n-bit node ID of the node to which this - routing table belongs - @type parentNodeID: str - """ - - def addContact(self, contact): - """ Add the given contact to the correct k-bucket; if it already - exists, its status will be updated - - @param contact: The contact to add to this node's k-buckets - @type contact: kademlia.contact.Contact - """ - - def findCloseNodes(self, key, count, _rpcNodeID=None): - """ Finds a number of known nodes closest to the node/value with the - specified key. - - @param key: the n-bit key (i.e. the node or value ID) to search for - @type key: str - @param count: the amount of contacts to return - @type count: int - @param _rpcNodeID: Used during RPC, this is be the sender's Node ID - Whatever ID is passed in the parameter will get - excluded from the list of returned contacts. - @type _rpcNodeID: str - - @return: A list of node contacts (C{kademlia.contact.Contact instances}) - closest to the specified key. 
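With the common BaseKademliaException introduced above, callers can catch all DHT-specific failures in one place instead of enumerating unrelated exception types; a small hypothetical helper (safe_rpc is not part of this changeset):

from lbrynet.dht.error import BaseKademliaException, RemoteException, TransportNotConnected

async def safe_rpc(coro):
    try:
        return await coro
    except (RemoteException, TransportNotConnected):
        return None  # the remote node errored, or our transport went away
    except BaseKademliaException:
        return None  # any other DHT failure, e.g. DecodeError or BucketFull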
- This method will return C{k} (or C{count}, if specified) - contacts if at all possible; it will only return fewer if the - node is returning all of the contacts that it knows of. - @rtype: list - """ - - def getContact(self, contactID): - """ Returns the (known) contact with the specified node ID - - @raise ValueError: No contact with the specified contact ID is known - by this node - """ - - def getRefreshList(self, startIndex=0, force=False): - """ Finds all k-buckets that need refreshing, starting at the - k-bucket with the specified index, and returns IDs to be searched for - in order to refresh those k-buckets - - @param startIndex: The index of the bucket to start refreshing at; - this bucket and those further away from it will - be refreshed. For example, when joining the - network, this node will set this to the index of - the bucket after the one containing it's closest - neighbour. - @type startIndex: index - @param force: If this is C{True}, all buckets (in the specified range) - will be refreshed, regardless of the time they were last - accessed. - @type force: bool - - @return: A list of node ID's that the parent node should search for - in order to refresh the routing Table - @rtype: list - """ - - def removeContact(self, contactID): - """ Remove the contact with the specified node ID from the routing - table - - @param contactID: The node ID of the contact to remove - @type contactID: str - """ - - def touchKBucket(self, key): - """ Update the "last accessed" timestamp of the k-bucket which covers - the range containing the specified key in the key/ID space - - @param key: A key in the range of the target k-bucket - @type key: str - """ diff --git a/lbrynet/dht/iterativefind.py b/lbrynet/dht/iterativefind.py deleted file mode 100644 index 076bd1a3d..000000000 --- a/lbrynet/dht/iterativefind.py +++ /dev/null @@ -1,230 +0,0 @@ -import logging -from twisted.internet import defer -from lbrynet.dht.distance import Distance -from lbrynet.dht.error import TimeoutError -from lbrynet.dht import constants - -log = logging.getLogger(__name__) - - -def get_contact(contact_list, node_id, address, port): - for contact in contact_list: - if contact.id == node_id and contact.address == address and contact.port == port: - return contact - raise IndexError(node_id) - - -def expand_peer(compact_peer_info): - host = "{}.{}.{}.{}".format(*compact_peer_info[:4]) - port = int.from_bytes(compact_peer_info[4:6], 'big') - peer_node_id = compact_peer_info[6:] - return (peer_node_id, host, port) - - -class _IterativeFind: - # TODO: use polymorphism to search for a value or node - # instead of using a find_value flag - def __init__(self, node, shortlist, key, rpc, exclude=None): - self.exclude = set(exclude or []) - self.node = node - self.finished_deferred = defer.Deferred() - # all distance operations in this class only care about the distance - # to self.key, so this makes it easier to calculate those - self.distance = Distance(key) - # The closest known and active node yet found - self.closest_node = None if not shortlist else shortlist[0] - self.prev_closest_node = None - # Shortlist of contact objects (the k closest known contacts to the key from the routing table) - self.shortlist = shortlist - # The search key - self.key = key - # The rpc method name (findValue or findNode) - self.rpc = rpc - # List of active queries; len() indicates number of active probes - self.active_probes = [] - # List of contact (address, port) tuples that have already been queried, includes contacts that didn't reply - 
self.already_contacted = [] - # A list of found and known-to-be-active remote nodes (Contact objects) - self.active_contacts = [] - # Ensure only one searchIteration call is running at a time - self._search_iteration_semaphore = defer.DeferredSemaphore(1) - self._iteration_count = 0 - self.find_value_result = {} - self.pending_iteration_calls = [] - - @property - def is_find_node_request(self): - return self.rpc == "findNode" - - @property - def is_find_value_request(self): - return self.rpc == "findValue" - - def is_closer(self, contact): - if not self.closest_node: - return True - return self.distance.is_closer(contact.id, self.closest_node.id) - - def getContactTriples(self, result): - if self.is_find_value_request: - contact_triples = result[b'contacts'] - else: - contact_triples = result - for contact_tup in contact_triples: - if not isinstance(contact_tup, (list, tuple)) or len(contact_tup) != 3: - raise ValueError("invalid contact triple") - contact_tup[1] = contact_tup[1].decode() # ips are strings - return contact_triples - - def sortByDistance(self, contact_list): - """Sort the list of contacts in order by distance from key""" - contact_list.sort(key=lambda c: self.distance(c.id)) - - def extendShortlist(self, contact, result): - # The "raw response" tuple contains the response message and the originating address info - originAddress = (contact.address, contact.port) - if self.finished_deferred.called: - return contact.id - if self.node.contact_manager.is_ignored(originAddress): - raise ValueError("contact is ignored") - if contact.id == self.node.node_id: - return contact.id - - if contact not in self.active_contacts: - self.active_contacts.append(contact) - if contact not in self.shortlist: - self.shortlist.append(contact) - - # Now grow extend the (unverified) shortlist with the returned contacts - # TODO: some validation on the result (for guarding against attacks) - # If we are looking for a value, first see if this result is the value - # we are looking for before treating it as a list of contact triples - if self.is_find_value_request and self.key in result: - # We have found the value - for peer in result[self.key]: - node_id, host, port = expand_peer(peer) - if (host, port) not in self.exclude: - self.find_value_result.setdefault(self.key, []).append((node_id, host, port)) - if self.find_value_result: - self.finished_deferred.callback(self.find_value_result) - else: - if self.is_find_value_request: - # We are looking for a value, and the remote node didn't have it - # - mark it as the closest "empty" node, if it is - # TODO: store to this peer after finding the value as per the kademlia spec - if b'closestNodeNoValue' in self.find_value_result: - if self.is_closer(contact): - self.find_value_result[b'closestNodeNoValue'] = contact - else: - self.find_value_result[b'closestNodeNoValue'] = contact - contactTriples = self.getContactTriples(result) - for contactTriple in contactTriples: - if (contactTriple[1], contactTriple[2]) in ((c.address, c.port) for c in self.already_contacted): - continue - elif self.node.contact_manager.is_ignored((contactTriple[1], contactTriple[2])): - continue - else: - found_contact = self.node.contact_manager.make_contact(contactTriple[0], contactTriple[1], - contactTriple[2], self.node._protocol) - if found_contact not in self.shortlist: - self.shortlist.append(found_contact) - - if not self.finished_deferred.called and self.should_stop(): - self.sortByDistance(self.active_contacts) - 
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))]) - - return contact.id - - @defer.inlineCallbacks - def probeContact(self, contact): - fn = getattr(contact, self.rpc) - try: - response = yield fn(self.key) - result = self.extendShortlist(contact, response) - defer.returnValue(result) - except (TimeoutError, defer.CancelledError, ValueError, IndexError): - defer.returnValue(contact.id) - - def should_stop(self): - if self.is_find_value_request: - # search stops when it finds a value, let it run - return False - if self.prev_closest_node and self.closest_node and self.distance.is_closer(self.prev_closest_node.id, - self.closest_node.id): - # we're getting further away - return True - if len(self.active_contacts) >= constants.k: - # we have enough results - return True - return False - - # Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts - def _searchIteration(self): - # Sort the discovered active nodes from closest to furthest - if len(self.active_contacts): - self.sortByDistance(self.active_contacts) - self.prev_closest_node = self.closest_node - self.closest_node = self.active_contacts[0] - - # Sort the current shortList before contacting other nodes - self.sortByDistance(self.shortlist) - probes = [] - already_contacted_addresses = {(c.address, c.port) for c in self.already_contacted} - to_remove = [] - for contact in self.shortlist: - if self.node.contact_manager.is_ignored((contact.address, contact.port)): - to_remove.append(contact) # a contact became bad during iteration - continue - if (contact.address, contact.port) not in already_contacted_addresses: - self.already_contacted.append(contact) - to_remove.append(contact) - probe = self.probeContact(contact) - probes.append(probe) - self.active_probes.append(probe) - if len(probes) == constants.alpha: - break - for contact in to_remove: # these contacts will be re-added to the shortlist when they reply successfully - self.shortlist.remove(contact) - - # run the probes - if probes: - # Schedule the next iteration if there are any active - # calls (Kademlia uses loose parallelism) - self.searchIteration() - - d = defer.DeferredList(probes, consumeErrors=True) - - def _remove_probes(results): - for probe in probes: - self.active_probes.remove(probe) - return results - - d.addCallback(_remove_probes) - - elif not self.finished_deferred.called and not self.active_probes or self.should_stop(): - # If no probes were sent, there will not be any improvement, so we're done - if self.is_find_value_request: - self.finished_deferred.callback(self.find_value_result) - else: - self.sortByDistance(self.active_contacts) - self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))]) - elif not self.finished_deferred.called: - # Force the next iteration - self.searchIteration() - - def searchIteration(self, delay=constants.iterativeLookupDelay): - def _cancel_pending_iterations(result): - while self.pending_iteration_calls: - canceller = self.pending_iteration_calls.pop() - canceller() - return result - self.finished_deferred.addBoth(_cancel_pending_iterations) - self._iteration_count += 1 - call, cancel = self.node.reactor_callLater(delay, self._search_iteration_semaphore.run, self._searchIteration) - self.pending_iteration_calls.append(cancel) - - -def iterativeFind(node, shortlist, key, rpc, exclude=None): - helper = _IterativeFind(node, shortlist, key, rpc, exclude) - helper.searchIteration(0) - return helper.finished_deferred diff --git 
a/lbrynet/dht/kbucket.py b/lbrynet/dht/kbucket.py deleted file mode 100644 index fa3f19d2e..000000000 --- a/lbrynet/dht/kbucket.py +++ /dev/null @@ -1,147 +0,0 @@ -import logging - -from lbrynet.dht import constants -from lbrynet.dht.distance import Distance -from lbrynet.dht.error import BucketFull - -log = logging.getLogger(__name__) - - -class KBucket: - """ Description - later - """ - - def __init__(self, rangeMin, rangeMax, node_id): - """ - @param rangeMin: The lower boundary for the range in the n-bit ID - space covered by this k-bucket - @param rangeMax: The upper boundary for the range in the ID space - covered by this k-bucket - """ - self.lastAccessed = 0 - self.rangeMin = rangeMin - self.rangeMax = rangeMax - self._contacts = list() - self._node_id = node_id - - def addContact(self, contact): - """ Add contact to _contact list in the right order. This will move the - contact to the end of the k-bucket if it is already present. - - @raise kademlia.kbucket.BucketFull: Raised when the bucket is full and - the contact isn't in the bucket - already - - @param contact: The contact to add - @type contact: dht.contact._Contact - """ - if contact in self._contacts: - # Move the existing contact to the end of the list - # - using the new contact to allow add-on data - # (e.g. optimization-specific stuff) to pe updated as well - self._contacts.remove(contact) - self._contacts.append(contact) - elif len(self._contacts) < constants.k: - self._contacts.append(contact) - else: - raise BucketFull("No space in bucket to insert contact") - - def getContact(self, contactID): - """Get the contact specified node ID - - @raise IndexError: raised if the contact is not in the bucket - - @param contactID: the node id of the contact to retrieve - @type contactID: str - - @rtype: dht.contact._Contact - """ - for contact in self._contacts: - if contact.id == contactID: - return contact - raise IndexError(contactID) - - def getContacts(self, count=-1, excludeContact=None, sort_distance_to=None): - """ Returns a list containing up to the first count number of contacts - - @param count: The amount of contacts to return (if 0 or less, return - all contacts) - @type count: int - @param excludeContact: A node id to exclude; if this contact is in - the list of returned values, it will be - discarded before returning. If a C{str} is - passed as this argument, it must be the - contact's ID. - @type excludeContact: str - - @param sort_distance_to: Sort distance to the id, defaulting to the parent node id. 
If False don't - sort the contacts - - @raise IndexError: If the number of requested contacts is too large - - @return: Return up to the first count number of contacts in a list - If no contacts are present an empty is returned - @rtype: list - """ - contacts = [contact for contact in self._contacts if contact.id != excludeContact] - - # Return all contacts in bucket - if count <= 0: - count = len(contacts) - - # Get current contact number - currentLen = len(contacts) - - # If count greater than k - return only k contacts - if count > constants.k: - count = constants.k - - if not currentLen: - return contacts - - if sort_distance_to is False: - pass - else: - sort_distance_to = sort_distance_to or self._node_id - contacts.sort(key=lambda c: Distance(sort_distance_to)(c.id)) - - return contacts[:min(currentLen, count)] - - def getBadOrUnknownContacts(self): - contacts = self.getContacts(sort_distance_to=False) - results = [contact for contact in contacts if contact.contact_is_good is False] - results.extend(contact for contact in contacts if contact.contact_is_good is None) - return results - - def removeContact(self, contact): - """ Remove the contact from the bucket - - @param contact: The contact to remove - @type contact: dht.contact._Contact - - @raise ValueError: The specified contact is not in this bucket - """ - self._contacts.remove(contact) - - def keyInRange(self, key): - """ Tests whether the specified key (i.e. node ID) is in the range - of the n-bit ID space covered by this k-bucket (in otherwords, it - returns whether or not the specified key should be placed in this - k-bucket) - - @param key: The key to test - @type key: str or int - - @return: C{True} if the key is in this k-bucket's range, or C{False} - if not. - @rtype: bool - """ - if isinstance(key, bytes): - key = int.from_bytes(key, 'big') - return self.rangeMin <= key < self.rangeMax - - def __len__(self): - return len(self._contacts) - - def __contains__(self, item): - return item in self._contacts diff --git a/lbrynet/dht/msgformat.py b/lbrynet/dht/msgformat.py deleted file mode 100644 index 5e4c59d79..000000000 --- a/lbrynet/dht/msgformat.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python -# -# This library is free software, distributed under the terms of -# the GNU Lesser General Public License Version 3, or any later version. -# See the COPYING file included in this archive -# -# The docstrings in this module contain epytext markup; API documentation -# may be created by processing this file with epydoc: http://epydoc.sf.net - -from lbrynet.dht import msgtypes - - -class MessageTranslator: - """ Interface for RPC message translators/formatters - - Classes inheriting from this should provide a translation services between - the classes used internally by this Kademlia implementation and the actual - data that is transmitted between nodes. 
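Both the removed kbucket/iterativefind modules and their replacements under lbrynet.dht.protocol order contacts by the Kademlia XOR metric, invoked above as Distance(key)(contact.id); a minimal sketch of that metric, assuming it matches the removed lbrynet.dht.distance implementation:

class Distance:
    def __init__(self, key: bytes):
        self.key_int = int.from_bytes(key, 'big')

    def __call__(self, other: bytes) -> int:
        # XOR of the two ids, interpreted as big-endian integers
        return self.key_int ^ int.from_bytes(other, 'big')

    def is_closer(self, key_a: bytes, key_b: bytes) -> bool:
        return self(key_a) < self(key_b)

# e.g. how getContacts() sorts a bucket's contacts toward a target id:
# contacts.sort(key=lambda c: Distance(sort_distance_to)(c.id))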
- """ - - def fromPrimitive(self, msgPrimitive): - """ Create an RPC Message from a message's string representation - - @param msgPrimitive: The unencoded primitive representation of a message - @type msgPrimitive: str, int, list or dict - - @return: The translated message object - @rtype: entangled.kademlia.msgtypes.Message - """ - - def toPrimitive(self, message): - """ Create a string representation of a message - - @param message: The message object - @type message: msgtypes.Message - - @return: The message's primitive representation in a particular - messaging format - @rtype: str, int, list or dict - """ - - -class DefaultFormat(MessageTranslator): - """ The default on-the-wire message format for this library """ - typeRequest, typeResponse, typeError = range(3) - headerType, headerMsgID, headerNodeID, headerPayload, headerArgs = range(5) # TODO: make str - - @staticmethod - def get(primitive, key): - try: - return primitive[key] - except KeyError: - return primitive[str(key)] # TODO: switch to int() - - def fromPrimitive(self, msgPrimitive): - msgType = self.get(msgPrimitive, self.headerType) - if msgType == self.typeRequest: - msg = msgtypes.RequestMessage(self.get(msgPrimitive, self.headerNodeID), - self.get(msgPrimitive, self.headerPayload), - self.get(msgPrimitive, self.headerArgs), - self.get(msgPrimitive, self.headerMsgID)) - elif msgType == self.typeResponse: - msg = msgtypes.ResponseMessage(self.get(msgPrimitive, self.headerMsgID), - self.get(msgPrimitive, self.headerNodeID), - self.get(msgPrimitive, self.headerPayload)) - elif msgType == self.typeError: - msg = msgtypes.ErrorMessage(self.get(msgPrimitive, self.headerMsgID), - self.get(msgPrimitive, self.headerNodeID), - self.get(msgPrimitive, self.headerPayload), - self.get(msgPrimitive, self.headerArgs)) - else: - # Unknown message, no payload - msg = msgtypes.Message(msgPrimitive[self.headerMsgID], msgPrimitive[self.headerNodeID]) - return msg - - def toPrimitive(self, message): - msg = {self.headerMsgID: message.id, - self.headerNodeID: message.nodeID} - if isinstance(message, msgtypes.RequestMessage): - msg[self.headerType] = self.typeRequest - msg[self.headerPayload] = message.request - msg[self.headerArgs] = message.args - elif isinstance(message, msgtypes.ErrorMessage): - msg[self.headerType] = self.typeError - msg[self.headerPayload] = message.exceptionType - msg[self.headerArgs] = message.response - elif isinstance(message, msgtypes.ResponseMessage): - msg[self.headerType] = self.typeResponse - msg[self.headerPayload] = message.response - return msg diff --git a/lbrynet/dht/msgtypes.py b/lbrynet/dht/msgtypes.py deleted file mode 100644 index 907c1a8e7..000000000 --- a/lbrynet/dht/msgtypes.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python -# -# This library is free software, distributed under the terms of -# the GNU Lesser General Public License Version 3, or any later version. 
-# See the COPYING file included in this archive -# -# The docstrings in this module contain epytext markup; API documentation -# may be created by processing this file with epydoc: http://epydoc.sf.net - -from lbrynet.utils import generate_id -from lbrynet.dht import constants - - -class Message: - """ Base class for messages - all "unknown" messages use this class """ - - def __init__(self, rpcID, nodeID): - if len(rpcID) != constants.rpc_id_length: - raise ValueError("invalid rpc id: %i bytes (expected 20)" % len(rpcID)) - if len(nodeID) != constants.key_bits // 8: - raise ValueError("invalid node id: %i bytes (expected 48)" % len(nodeID)) - self.id = rpcID - self.nodeID = nodeID - - -class RequestMessage(Message): - """ Message containing an RPC request """ - - def __init__(self, nodeID, method, methodArgs, rpcID=None): - if rpcID is None: - rpcID = generate_id()[:constants.rpc_id_length] - super().__init__(rpcID, nodeID) - self.request = method - self.args = methodArgs - - -class ResponseMessage(Message): - """ Message containing the result from a successful RPC request """ - - def __init__(self, rpcID, nodeID, response): - super().__init__(rpcID, nodeID) - self.response = response - - -class ErrorMessage(ResponseMessage): - """ Message containing the error from an unsuccessful RPC request """ - - def __init__(self, rpcID, nodeID, exceptionType, errorMessage): - super().__init__(rpcID, nodeID, errorMessage) - if isinstance(exceptionType, type): - exceptionType = (f'{exceptionType.__module__}.{exceptionType.__name__}').encode() - self.exceptionType = exceptionType diff --git a/lbrynet/dht/node.py b/lbrynet/dht/node.py index 8fc590cb5..c8524617d 100644 --- a/lbrynet/dht/node.py +++ b/lbrynet/dht/node.py @@ -1,649 +1,226 @@ -import binascii -import hashlib import logging -from functools import reduce +import asyncio +import typing +import socket +import binascii +import contextlib +from lbrynet.dht import constants +from lbrynet.dht.error import RemoteException +from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction +from lbrynet.dht.protocol.distance import Distance +from lbrynet.dht.protocol.iterative_find import IterativeNodeFinder, IterativeValueFinder +from lbrynet.dht.protocol.protocol import KademliaProtocol +from lbrynet.dht.peer import KademliaPeer -from twisted.internet import defer, error, task - -from lbrynet.utils import generate_id, DeferredDict -from lbrynet.dht.call_later_manager import CallLaterManager -from lbrynet.dht.error import TimeoutError -from lbrynet.dht import constants, routingtable, datastore, protocol -from lbrynet.dht.contact import ContactManager -from lbrynet.dht.iterativefind import iterativeFind +if typing.TYPE_CHECKING: + from lbrynet.dht.peer import PeerManager log = logging.getLogger(__name__) -def rpcmethod(func): - """ Decorator to expose Node methods as remote procedure calls +class Node: + def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', node_id: bytes, udp_port: int, + internal_udp_port: int, peer_port: int, external_ip: str): + self.loop = loop + self.internal_udp_port = internal_udp_port + self.protocol = KademliaProtocol(loop, peer_manager, node_id, external_ip, udp_port, peer_port) + self.listening_port: asyncio.DatagramTransport = None + self.joined = asyncio.Event(loop=self.loop) + self._join_task: asyncio.Task = None + self._refresh_task: asyncio.Task = None - Apply this decorator to methods in the Node class (or a subclass) in order - to make them remotely callable via the DHT's RPC 
mechanism. - """ - func.rpcmethod = True - return func + async def refresh_node(self): + while True: + # remove peers with expired blob announcements from the datastore + self.protocol.data_store.removed_expired_peers() + total_peers: typing.List['KademliaPeer'] = [] + # add all peers in the routing table + total_peers.extend(self.protocol.routing_table.get_peers()) + # add all the peers who have announced blobs to us + total_peers.extend(self.protocol.data_store.get_storing_contacts()) -class MockKademliaHelper: - def __init__(self, clock=None, callLater=None, resolve=None, listenUDP=None): - if not listenUDP or not resolve or not callLater or not clock: - from twisted.internet import reactor - listenUDP = listenUDP or reactor.listenUDP - resolve = resolve or reactor.resolve - callLater = callLater or reactor.callLater - clock = clock or reactor + # get ids falling in the midpoint of each bucket that hasn't been recently updated + node_ids = self.protocol.routing_table.get_refresh_list(0, True) + # if we have 3 or fewer populated buckets get two random ids in the range of each to try and + # populate/split the buckets further + buckets_with_contacts = self.protocol.routing_table.buckets_with_contacts() + if buckets_with_contacts <= 3: + for i in range(buckets_with_contacts): + node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i)) + node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i)) - self.clock = clock - self.contact_manager = ContactManager(self.clock.seconds) - self.reactor_listenUDP = listenUDP - self.reactor_resolve = resolve - self.call_later_manager = CallLaterManager(callLater) - self.reactor_callLater = self.call_later_manager.call_later - self.reactor_callSoon = self.call_later_manager.call_soon - - self._listeningPort = None # object implementing Twisted - # IListeningPort This will contain a deferred created when - # joining the network, to enable publishing/retrieving - # information from the DHT as soon as the node is part of the - # network (add callbacks to this deferred if scheduling such - # operations before the node has finished joining the network) - - def get_looping_call(self, fn, *args, **kwargs): - lc = task.LoopingCall(fn, *args, **kwargs) - lc.clock = self.clock - return lc - - def safe_stop_looping_call(self, lc): - if lc and lc.running: - return lc.stop() - return defer.succeed(None) - - def safe_start_looping_call(self, lc, t): - if lc and not lc.running: - lc.start(t) - - -class Node(MockKademliaHelper): - """ Local node in the Kademlia network - - This class represents a single local node in a Kademlia network; in other - words, this class encapsulates an Entangled-using application's "presence" - in a Kademlia network. - - In Entangled, all interactions with the Kademlia network by a client - application is performed via this class (or a subclass). - """ - - def __init__(self, node_id=None, udpPort=4000, dataStore=None, - routingTableClass=None, networkProtocol=None, - externalIP=None, peerPort=3333, listenUDP=None, - callLater=None, resolve=None, clock=None, - interface='', externalUDPPort=None): - """ - @param dataStore: The data store to use. This must be class inheriting - from the C{DataStore} interface (or providing the - same API). How the data store manages its data - internally is up to the implementation of that data - store. - @type dataStore: entangled.kademlia.datastore.DataStore - @param routingTable: The routing table class to use.
Since there exists - some ambiguity as to how the routing table should be - implemented in Kademlia, a different routing table - may be used, as long as the appropriate API is - exposed. This should be a class, not an object, - in order to allow the Node to pass an - auto-generated node ID to the routingtable object - upon instantiation (if necessary). - @type routingTable: entangled.kademlia.routingtable.RoutingTable - @param networkProtocol: The network protocol to use. This can be - overridden from the default to (for example) - change the format of the physical RPC messages - being transmitted. - @type networkProtocol: entangled.kademlia.protocol.KademliaProtocol - @param externalIP: the IP at which this node can be contacted - @param peerPort: the port at which this node announces it has a blob for - """ - - super().__init__(clock, callLater, resolve, listenUDP) - self.node_id = node_id or self._generateID() - self.port = udpPort - self._listen_interface = interface - self._change_token_lc = self.get_looping_call(self.change_token) - self._refresh_node_lc = self.get_looping_call(self._refreshNode) - self._refresh_contacts_lc = self.get_looping_call(self._refreshContacts) - - # Create k-buckets (for storing contacts) - if routingTableClass is None: - self._routingTable = routingtable.TreeRoutingTable(self.node_id, self.clock.seconds) - else: - self._routingTable = routingTableClass(self.node_id, self.clock.seconds) - - self._protocol = networkProtocol or protocol.KademliaProtocol(self) - self.token_secret = self._generateID() - self.old_token_secret = None - self.externalIP = externalIP - self.peerPort = peerPort - self.externalUDPPort = externalUDPPort or self.port - self._dataStore = dataStore or datastore.DictDataStore(self.clock.seconds) - self._join_deferred = None - - #def __del__(self): - # log.warning("unclean shutdown of the dht node") - # if hasattr(self, "_listeningPort") and self._listeningPort is not None: - # self._listeningPort.stopListening() - - def __str__(self): - return '<%s.%s object; ID: %s, IP address: %s, UDP port: %d>' % ( - self.__module__, self.__class__.__name__, binascii.hexlify(self.node_id), self.externalIP, self.port) - - @defer.inlineCallbacks - def stop(self): - # stop LoopingCalls: - yield self.safe_stop_looping_call(self._refresh_node_lc) - yield self.safe_stop_looping_call(self._change_token_lc) - yield self.safe_stop_looping_call(self._refresh_contacts_lc) - if self._listeningPort is not None: - yield self._listeningPort.stopListening() - self._listeningPort = None - - def start_listening(self): - if not self._listeningPort: - try: - self._listeningPort = self.reactor_listenUDP(self.port, self._protocol, - interface=self._listen_interface) - except error.CannotListenError as e: - import traceback - log.error("Couldn't bind to port %d. %s", self.port, traceback.format_exc()) - raise ValueError("%s lbrynet may already be running." 
% str(e)) - else: - log.warning("Already bound to port %s", self._listeningPort) - - @defer.inlineCallbacks - def joinNetwork(self, known_node_addresses=(('jack.lbry.tech', 4455), )): - """ - Attempt to join the dht, retry every 30 seconds if unsuccessful - :param known_node_addresses: [(str, int)] list of hostnames and ports for known dht seed nodes - """ - - self._join_deferred = defer.Deferred() - known_node_resolution = {} - - @defer.inlineCallbacks - def _resolve_seeds(): - result = {} - for host, port in known_node_addresses: - node_address = yield self.reactor_resolve(host) - result[(host, port)] = node_address - defer.returnValue(result) - - if not known_node_resolution: - known_node_resolution = yield _resolve_seeds() - # we are one of the seed nodes, don't add ourselves - if (self.externalIP, self.port) in known_node_resolution.values(): - del known_node_resolution[(self.externalIP, self.port)] - known_node_addresses.remove((self.externalIP, self.port)) - - def _ping_contacts(contacts): - d = DeferredDict({contact: contact.ping() for contact in contacts}, consumeErrors=True) - d.addErrback(lambda err: err.trap(TimeoutError)) - return d - - @defer.inlineCallbacks - def _initialize_routing(): - bootstrap_contacts = [] - contact_addresses = {(c.address, c.port): c for c in self.contacts} - for (host, port), ip_address in known_node_resolution.items(): - if (host, port) not in contact_addresses: - # Create temporary contact information for the list of addresses of known nodes - # The contact node id will be set with the responding node id when we initialize it to None - contact = self.contact_manager.make_contact(None, ip_address, port, self._protocol) - bootstrap_contacts.append(contact) - else: - for contact in self.contacts: - if contact.address == ip_address and contact.port == port: - if not contact.id: - bootstrap_contacts.append(contact) - break - if not bootstrap_contacts: - log.warning("no bootstrap contacts to ping") - ping_result = yield _ping_contacts(bootstrap_contacts) - shortlist = list(ping_result.keys()) - if not shortlist: - log.warning("failed to ping %i bootstrap contacts", len(bootstrap_contacts)) - defer.returnValue(None) + if self.protocol.routing_table.get_peers(): + # if we have node ids to look up, perform the iterative search until we have k results + while node_ids: + peers = await self.peer_search(node_ids.pop()) + total_peers.extend(peers) else: - # find the closest peers to us - closest = yield self._iterativeFind(self.node_id, shortlist if not self.contacts else None) - yield _ping_contacts(closest) - # query random hashes in our bucket key ranges to fill or split them - random_ids_in_range = self._routingTable.getRefreshList() - while random_ids_in_range: - yield self.iterativeFindNode(random_ids_in_range.pop()) - defer.returnValue(None) + fut = asyncio.Future(loop=self.loop) + self.loop.call_later(constants.refresh_interval // 4, fut.set_result, None) + await fut + continue - @defer.inlineCallbacks - def _iterative_join(joined_d=None, last_buckets_with_contacts=None): - log.info("Attempting to join the DHT network, %i contacts known so far", len(self.contacts)) - joined_d = joined_d or defer.Deferred() - yield _initialize_routing() - buckets_with_contacts = self.bucketsWithContacts() - if last_buckets_with_contacts and last_buckets_with_contacts == buckets_with_contacts: - if not joined_d.called: - joined_d.callback(True) - elif buckets_with_contacts < 4: - self.reactor_callLater(0, _iterative_join, joined_d, buckets_with_contacts) - elif not 
-        @defer.inlineCallbacks
-        def _iterative_join(joined_d=None, last_buckets_with_contacts=None):
-            log.info("Attempting to join the DHT network, %i contacts known so far", len(self.contacts))
-            joined_d = joined_d or defer.Deferred()
-            yield _initialize_routing()
-            buckets_with_contacts = self.bucketsWithContacts()
-            if last_buckets_with_contacts and last_buckets_with_contacts == buckets_with_contacts:
-                if not joined_d.called:
-                    joined_d.callback(True)
-            elif buckets_with_contacts < 4:
-                self.reactor_callLater(0, _iterative_join, joined_d, buckets_with_contacts)
-            elif not joined_d.called:
-                joined_d.callback(None)
-            yield joined_d
-            if not self._join_deferred.called:
-                self._join_deferred.callback(True)
-            defer.returnValue(None)
+            # ping the set of peers; upon success/failure the routing table and last replied/failed time will be updated
+            to_ping = [peer for peer in set(total_peers) if self.protocol.peer_manager.peer_is_good(peer) is not True]
+            if to_ping:
+                await self.protocol.ping_queue.enqueue_maybe_ping(*to_ping, delay=0)
-        yield _iterative_join()
+            fut = asyncio.Future(loop=self.loop)
+            self.loop.call_later(constants.refresh_interval, fut.set_result, None)
+            await fut
-    @defer.inlineCallbacks
-    def start(self, known_node_addresses=None, block_on_join=False):
-        """ Causes the Node to attempt to join the DHT network by contacting the
-        known DHT nodes. This can be called multiple times if the previous attempt
-        has failed or if the Node has lost all the contacts.
+    async def announce_blob(self, blob_hash: str) -> typing.List[bytes]:
+        announced_to_node_ids = []
+        while not announced_to_node_ids:
+            hash_value = binascii.unhexlify(blob_hash.encode())
+            assert len(hash_value) == constants.hash_length
+            peers = await self.peer_search(hash_value)
-        @param known_node_addresses: A sequence of tuples containing IP address
-                                     information for existing nodes on the
-                                     Kademlia network, in the format:
-                                     C{(<ip address>, <udp port>)}
-        @type known_node_addresses: list
-        """
+            if not self.protocol.external_ip:
+                raise Exception("Cannot determine external IP")
+            log.info("Store to %i peers", len(peers))
+            log.info(peers)
+            for peer in peers:
+                log.info("store to %s %s %s", peer.address, peer.udp_port, peer.tcp_port)
+            stored_to_tup = await asyncio.gather(
+                *(self.protocol.store_to_peer(hash_value, peer) for peer in peers), loop=self.loop
+            )
+            announced_to_node_ids.extend([node_id for node_id, contacted in stored_to_tup if contacted])
+            log.info("Stored %s to %i of %i attempted peers", binascii.hexlify(hash_value).decode()[:8],
+                     len(announced_to_node_ids), len(peers))
+        return announced_to_node_ids
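announce_blob retries the full search-and-store round until at least one peer accepts: peer_search finds the closest peers to the blob hash, store_to_peer is sent to all of them concurrently via asyncio.gather, and the ids of the peers that acknowledged are returned. A hedged usage sketch, assuming a started node from this changeset and a hex-encoded 48 byte blob hash (announce_example and the hash value are placeholders):

import binascii

async def announce_example(node) -> None:
    blob_hash = binascii.hexlify(b'\x2a' * 48).decode()  # placeholder 96-character hex hash
    stored_to = await node.announce_blob(blob_hash)  # returns once >= 1 peer accepted the store
    print("announced to %i peers" % len(stored_to))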
-        self.start_listening()
-        yield self._protocol._listening
-        # TODO: Refresh all k-buckets further away than this node's closest neighbour
-        d = self.joinNetwork(known_node_addresses or [])
-        d.addCallback(lambda _: self.start_looping_calls())
-        d.addCallback(lambda _: log.info("Joined the dht"))
-        if block_on_join:
-            yield d
+    def stop(self) -> None:
+        if self.joined.is_set():
+            self.joined.clear()
+        if self._join_task:
+            self._join_task.cancel()
+        if self._refresh_task and not (self._refresh_task.done() or self._refresh_task.cancelled()):
+            self._refresh_task.cancel()
+        if self.protocol and self.protocol.ping_queue.running:
+            self.protocol.ping_queue.stop()
+        if self.listening_port is not None:
+            self.listening_port.close()
+        self._join_task = None
+        self.listening_port = None
+        log.info("Stopped DHT node")
-    def start_looping_calls(self):
-        self.safe_start_looping_call(self._change_token_lc, constants.tokenSecretChangeInterval)
-        # Start refreshing k-buckets periodically, if necessary
-        self.safe_start_looping_call(self._refresh_node_lc, constants.checkRefreshInterval)
-        self.safe_start_looping_call(self._refresh_contacts_lc, 60)
+    async def start_listening(self, interface: str = '') -> None:
+        if not self.listening_port:
+            self.listening_port, _ = await self.loop.create_datagram_endpoint(
+                lambda: self.protocol, (interface, self.internal_udp_port)
+            )
+            log.info("DHT node listening on UDP %s:%i", interface, self.internal_udp_port)
+        else:
+            log.warning("Already bound to port %s", self.listening_port)
-    @property
-    def contacts(self):
-        def _inner():
-            for i in range(len(self._routingTable._buckets)):
-                for contact in self._routingTable._buckets[i]._contacts:
-                    yield contact
-        return list(_inner())
+    async def join_network(self, interface: typing.Optional[str] = '',
+                           known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
+                           known_node_addresses: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
+        if not self.listening_port:
+            await self.start_listening(interface)
+        self.protocol.ping_queue.start()
+        self._refresh_task = self.loop.create_task(self.refresh_node())
-    def hasContacts(self):
-        for bucket in self._routingTable._buckets:
-            if bucket._contacts:
-                return True
-        return False
+        known_node_addresses = known_node_addresses or []
+        if known_node_urls:
+            for host, port in known_node_urls:
+                info = await self.loop.getaddrinfo(
+                    host, 'https',
+                    proto=socket.IPPROTO_TCP,
+                )
+                if (info[0][4][0], port) not in known_node_addresses:
+                    known_node_addresses.append((info[0][4][0], port))
+        futs = []
+        for address, port in known_node_addresses:
+            peer = self.protocol.get_rpc_peer(KademliaPeer(self.loop, address, udp_port=port))
+            futs.append(peer.ping())
+        if futs:
+            await asyncio.wait(futs, loop=self.loop)
-    def bucketsWithContacts(self):
-        return self._routingTable.bucketsWithContacts()
+        async with self.peer_search_junction(self.protocol.node_id, max_results=16) as junction:
+            async for peers in junction:
+                for peer in peers:
+                    try:
+                        await self.protocol.get_rpc_peer(peer).ping()
+                    except (asyncio.TimeoutError, RemoteException):
+                        pass
+        self.joined.set()
+        log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
+                 self.protocol.routing_table.buckets_with_contacts())
-    @defer.inlineCallbacks
-    def storeToContact(self, blob_hash, contact):
+    def start(self, interface: str, known_node_urls: typing.List[typing.Tuple[str, int]]):
+        self._join_task = self.loop.create_task(
+            self.join_network(
+                interface=interface, known_node_urls=known_node_urls
+            )
+        )
+
+    def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List] = None,
+                                  bottom_out_limit: int = constants.bottom_out_limit,
+                                  max_results: int = constants.k) -> IterativeNodeFinder:
+
+        return IterativeNodeFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
+                                   key, bottom_out_limit, max_results, None, shortlist)
+
+    def get_iterative_value_finder(self, key: bytes, shortlist: typing.Optional[typing.List] = None,
+                                   bottom_out_limit: int = 40,
+                                   max_results: int = -1) -> IterativeValueFinder:
+
+        return IterativeValueFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
+                                    key, bottom_out_limit, max_results, None, shortlist)
+
+    @contextlib.asynccontextmanager
+    async def stream_peer_search_junction(self, hash_queue: asyncio.Queue, bottom_out_limit=20,
+                                          max_results=-1) -> AsyncGeneratorJunction:
+        peer_generator = AsyncGeneratorJunction(self.loop)
+
+        async def _add_hashes_from_queue():
+            while True:
+                try:
+                    blob_hash = await hash_queue.get()
+                except asyncio.CancelledError:
+                    break
+                peer_generator.add_generator(
+                    self.get_iterative_value_finder(
+                        binascii.unhexlify(blob_hash.encode()), bottom_out_limit=bottom_out_limit,
+                        max_results=max_results
+                    )
+                )
+        add_hashes_task = self.loop.create_task(_add_hashes_from_queue())
         try:
-            if not contact.token:
-                yield contact.findValue(blob_hash)
-            res = yield contact.store(blob_hash, contact.token,
self.peerPort, self.node_id, 0) - if res != b"OK": - raise ValueError(res) - log.debug("Stored %s to %s (%s)", binascii.hexlify(blob_hash), contact.log_id(), contact.address) - return True - except protocol.TimeoutError: - log.debug("Timeout while storing blob_hash %s at %s", - binascii.hexlify(blob_hash), contact.log_id()) - except ValueError as err: - log.error("Unexpected response: %s" % err) - except Exception as err: - if 'Invalid token' in str(err): - contact.update_token(None) - log.error("Unexpected error while storing blob_hash %s at %s: %s", - binascii.hexlify(blob_hash), contact, err) - return False + async with peer_generator as junction: + yield junction + await peer_generator.finished.wait() + except asyncio.CancelledError: + if add_hashes_task and not (add_hashes_task.done() or add_hashes_task.cancelled()): + add_hashes_task.cancel() + raise + finally: + if add_hashes_task and not (add_hashes_task.done() or add_hashes_task.cancelled()): + add_hashes_task.cancel() - @defer.inlineCallbacks - def announceHaveBlob(self, blob_hash): - contacts = yield self.iterativeFindNode(blob_hash) + def peer_search_junction(self, node_id: bytes, max_results=constants.k*2, + bottom_out_limit=20) -> AsyncGeneratorJunction: + peer_generator = AsyncGeneratorJunction(self.loop) + peer_generator.add_generator( + self.get_iterative_node_finder( + node_id, bottom_out_limit=bottom_out_limit, max_results=max_results + ) + ) + return peer_generator - if not self.externalIP: - raise Exception("Cannot determine external IP: %s" % self.externalIP) - stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts}) - contacted_node_ids = [binascii.hexlify(contact.id) for contact in stored_to.keys() if stored_to[contact]] - log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash), - len(contacted_node_ids), len(contacts)) - defer.returnValue(contacted_node_ids) - - def change_token(self): - self.old_token_secret = self.token_secret - self.token_secret = self._generateID() - - def make_token(self, compact_ip): - h = hashlib.new('sha384') - h.update(self.token_secret + compact_ip) - return h.digest() - - def verify_token(self, token, compact_ip): - h = hashlib.new('sha384') - h.update(self.token_secret + compact_ip) - if self.old_token_secret and not token == h.digest(): # TODO: why should we be accepting the previous token? - h = hashlib.new('sha384') - h.update(self.old_token_secret + compact_ip) - if not token == h.digest(): - return False - return True - - def iterativeFindNode(self, key): - """ The basic Kademlia node lookup operation - - Call this to find a remote node in the P2P overlay network. - - @param key: the n-bit key (i.e. the node or value ID) to search for - @type key: str - - @return: This immediately returns a deferred object, which will return - a list of k "closest" contacts (C{kademlia.contact.Contact} - objects) to the specified key as soon as the operation is - finished. - @rtype: twisted.internet.defer.Deferred - """ - return self._iterativeFind(key) - - @defer.inlineCallbacks - def iterativeFindValue(self, key, exclude=None): - """ The Kademlia search operation (deterministic) - - Call this to retrieve data from the DHT. - - @param key: the n-bit key (i.e. 
the value ID) to search for - @type key: str - - @return: This immediately returns a deferred object, which will return - either one of two things: - - If the value was found, it will return a Python - dictionary containing the searched-for key (the C{key} - parameter passed to this method), and its associated - value, in the format: - C{key: data_value} - - If the value was not found, it will return a list of k - "closest" contacts (C{kademlia.contact.Contact} objects) - to the specified key - @rtype: twisted.internet.defer.Deferred - """ - - if len(key) != constants.key_bits // 8: - raise ValueError("invalid key length!") - - # Execute the search - find_result = yield self._iterativeFind(key, rpc='findValue', exclude=exclude) - if isinstance(find_result, dict): - # We have found the value; now see who was the closest contact without it... - # ...and store the key/value pair - pass - else: - # The value wasn't found, but a list of contacts was returned - # Now, see if we have the value (it might seem wasteful to search on the network - # first, but it ensures that all values are properly propagated through the - # network - if self._dataStore.hasPeersForBlob(key): - # Ok, we have the value locally, so use that - # Send this value to the closest node without it - peers = self._dataStore.getPeersForBlob(key) - find_result = {key: peers} - else: - pass - - defer.returnValue(list(set(find_result.get(key, []) if find_result else []))) - # TODO: get this working - # if 'closestNodeNoValue' in find_result: - # closest_node_without_value = find_result['closestNodeNoValue'] - # try: - # response, address = yield closest_node_without_value.findValue(key, rawResponse=True) - # yield closest_node_without_value.store(key, response.response['token'], self.peerPort) - # except TimeoutError: - # pass - - def addContact(self, contact): - """ Add/update the given contact; simple wrapper for the same method - in this object's RoutingTable object - - @param contact: The contact to add to this node's k-buckets - @type contact: kademlia.contact.Contact - """ - return self._routingTable.addContact(contact) - - def removeContact(self, contact): - """ Remove the contact with the specified node ID from this node's - table of known nodes. 
This is a simple wrapper for the same method - in this object's RoutingTable object - - @param contact: The Contact object to remove - @type contact: _Contact - """ - self._routingTable.removeContact(contact) - - def findContact(self, contactID): - """ Find a entangled.kademlia.contact.Contact object for the specified - cotact ID - - @param contactID: The contact ID of the required Contact object - @type contactID: str - - @return: Contact object of remote node with the specified node ID, - or None if the contact was not found - @rtype: twisted.internet.defer.Deferred - """ - try: - df = defer.succeed(self._routingTable.getContact(contactID)) - except (ValueError, IndexError): - df = self.iterativeFindNode(contactID) - df.addCallback(lambda nodes: ([node for node in nodes if node.id == contactID] or (None,))[0]) - return df - - @rpcmethod - def ping(self): - """ Used to verify contact between two Kademlia nodes - - @rtype: str - """ - return b'pong' - - @rpcmethod - def store(self, rpc_contact, blob_hash, token, port, originalPublisherID, age): - """ Store the received data in this node's local datastore - - @param blob_hash: The hash of the data - @type blob_hash: str - - @param token: The token we previously returned when this contact sent us a findValue - @type token: str - - @param port: The TCP port the contact is listening on for requests for this blob (the peerPort) - @type port: int - - @param originalPublisherID: The node ID of the node that is the publisher of the data - @type originalPublisherID: str - - @param age: The relative age of the data (time in seconds since it was - originally published). Note that the original publish time - isn't actually given, to compensate for clock skew between - different nodes. - @type age: int - - @rtype: str - """ - - if originalPublisherID is None: - originalPublisherID = rpc_contact.id - compact_ip = rpc_contact.compact_ip() - if self.clock.seconds() - self._protocol.started_listening_time < constants.tokenSecretChangeInterval: - pass - elif not self.verify_token(token, compact_ip): - raise ValueError("Invalid token") - if 0 <= port <= 65536: - compact_port = port.to_bytes(2, 'big') - else: - raise TypeError(f'Invalid port: {port}') - compact_address = compact_ip + compact_port + rpc_contact.id - now = int(self.clock.seconds()) - originallyPublished = now - age - self._dataStore.addPeerToBlob(rpc_contact, blob_hash, compact_address, now, originallyPublished, - originalPublisherID) - return b'OK' - - @rpcmethod - def findNode(self, rpc_contact, key): - """ Finds a number of known nodes closest to the node/value with the - specified key. - - @param key: the n-bit key (i.e. the node or value ID) to search for - @type key: str - - @return: A list of contact triples closest to the specified key. - This method will return C{k} (or C{count}, if specified) - contacts if at all possible; it will only return fewer if the - node is returning all of the contacts that it knows of. 
- @rtype: list - """ - if len(key) != constants.key_bits // 8: - raise ValueError("invalid contact id length: %i" % len(key)) - - contacts = self._routingTable.findCloseNodes(key, sender_node_id=rpc_contact.id) - contact_triples = [] - for contact in contacts: - contact_triples.append((contact.id, contact.address, contact.port)) - return contact_triples - - @rpcmethod - def findValue(self, rpc_contact, key): - """ Return the value associated with the specified key if present in - this node's data, otherwise execute FIND_NODE for the key - - @param key: The hashtable key of the data to return - @type key: str - - @return: A dictionary containing the requested key/value pair, - or a list of contact triples closest to the requested key. - @rtype: dict or list - """ - - if len(key) != constants.key_bits // 8: - raise ValueError("invalid blob hash length: %i" % len(key)) - - response = { - b'token': self.make_token(rpc_contact.compact_ip()), - } - - if self._protocol._protocolVersion: - response[b'protocolVersion'] = self._protocol._protocolVersion - - # get peers we have stored for this blob - has_other_peers = self._dataStore.hasPeersForBlob(key) - peers = [] - if has_other_peers: - peers.extend(self._dataStore.getPeersForBlob(key)) - - # if we don't have k storing peers to return and we have this hash locally, include our contact information - if len(peers) < constants.k and key in self._dataStore.completed_blobs: - compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), self.externalIP.split('.'), bytearray()) - compact_port = self.peerPort.to_bytes(2, 'big') - compact_address = compact_ip + compact_port + self.node_id - peers.append(compact_address) - - if peers: - response[key] = peers - else: - response[b'contacts'] = self.findNode(rpc_contact, key) - return response - - def _generateID(self): - """ Generates an n-bit pseudo-random identifier - - @return: A globally unique n-bit pseudo-random identifier - @rtype: str - """ - return generate_id() - - # from lbrynet.p2p.utils import profile_deferred - # @profile_deferred() - @defer.inlineCallbacks - def _iterativeFind(self, key, startupShortlist=None, rpc='findNode', exclude=None): - """ The basic Kademlia iterative lookup operation (for nodes/values) - - This builds a list of k "closest" contacts through iterative use of - the "FIND_NODE" RPC, or if C{findValue} is set to C{True}, using the - "FIND_VALUE" RPC, in which case the value (if found) may be returned - instead of a list of contacts - - @param key: the n-bit key (i.e. the node or value ID) to search for - @type key: str - @param startupShortlist: A list of contacts to use as the starting - shortlist for this search; this is normally - only used when the node joins the network - @type startupShortlist: list - @param rpc: The name of the RPC to issue to remote nodes during the - Kademlia lookup operation (e.g. this sets whether this - algorithm should search for a data value (if - rpc='findValue') or not. It can thus be used to perform - other operations that piggy-back on the basic Kademlia - lookup operation (Entangled's "delete" RPC, for instance). - @type rpc: str - - @return: If C{findValue} is C{True}, the algorithm will stop as soon - as a data value for C{key} is found, and return a dictionary - containing the key and the found value. 
Otherwise, it will
-                 return a list of the k closest nodes to the specified key
-        @rtype: twisted.internet.defer.Deferred
-        """
-
-        if len(key) != constants.key_bits // 8:
-            raise ValueError("invalid key length: %i" % len(key))
-
-        if startupShortlist is None:
-            shortlist = self._routingTable.findCloseNodes(key)
-            # if key != self.node_id:
-            #     # Update the "last accessed" timestamp for the appropriate k-bucket
-            #     self._routingTable.touchKBucket(key)
-            if len(shortlist) == 0:
-                log.warning("This node doesn't know any other nodes")
-                # This node doesn't know of any other nodes
-                fakeDf = defer.Deferred()
-                fakeDf.callback([])
-                result = yield fakeDf
-                defer.returnValue(result)
-        else:
-            # This is used during the bootstrap process
-            shortlist = startupShortlist
-
-        result = yield iterativeFind(self, shortlist, key, rpc, exclude=exclude)
-        defer.returnValue(result)
-
-    @defer.inlineCallbacks
-    def _refreshNode(self):
-        """ Periodically called to perform k-bucket refreshes and data
-        replication/republishing as necessary """
-        yield self._refreshRoutingTable()
-        self._dataStore.removeExpiredPeers()
-        self._refreshStoringPeers()
-        defer.returnValue(None)
-
-    def _refreshContacts(self):
-        self._protocol._ping_queue.enqueue_maybe_ping(*self.contacts, delay=0)
-
-    def _refreshStoringPeers(self):
-        self._protocol._ping_queue.enqueue_maybe_ping(*self._dataStore.getStoringContacts(), delay=0)
-
-    @defer.inlineCallbacks
-    def _refreshRoutingTable(self):
-        nodeIDs = self._routingTable.getRefreshList(0, False)
-        while nodeIDs:
-            searchID = nodeIDs.pop()
-            yield self.iterativeFindNode(searchID)
-        defer.returnValue(None)
+    async def peer_search(self, node_id: bytes, count=constants.k, max_results=constants.k*2,
+                          bottom_out_limit=20) -> typing.List['KademliaPeer']:
+        accumulated: typing.List['KademliaPeer'] = []
+        async with self.peer_search_junction(self.protocol.node_id, max_results=max_results,
+                                             bottom_out_limit=bottom_out_limit) as junction:
+            async for peers in junction:
+                log.info("peer search: %s", peers)
+                accumulated.extend(peers)
+            log.info("junction done")
+        log.info("context done")
+        distance = Distance(node_id)
+        accumulated.sort(key=lambda peer: distance(peer.node_id))
+        return accumulated[:count]
diff --git a/lbrynet/dht/peer.py b/lbrynet/dht/peer.py
new file mode 100644
index 000000000..319ca9b5a
--- /dev/null
+++ b/lbrynet/dht/peer.py
@@ -0,0 +1,203 @@
+import typing
+import asyncio
+import logging
+import ipaddress
+from binascii import hexlify
+from lbrynet.dht import constants
+from lbrynet.dht.serialization.datagram import make_compact_address, make_compact_ip, decode_compact_address
+
+log = logging.getLogger(__name__)
+
+
+def is_valid_ipv4(address):
+    try:
+        ip = ipaddress.ip_address(address)
+        return ip.version == 4
+    except ValueError:  # raised by ip_address() for invalid input
+        return False
+
+
+class PeerManager:
+    def __init__(self, loop: asyncio.BaseEventLoop):
+        self._loop = loop
+        self._rpc_failures: typing.Dict[
+            typing.Tuple[str, int], typing.Tuple[typing.Optional[float], typing.Optional[float]]
+        ] = {}
+        self._last_replied: typing.Dict[typing.Tuple[str, int], float] = {}
+        self._last_sent: typing.Dict[typing.Tuple[str, int], float] = {}
+        self._last_requested: typing.Dict[typing.Tuple[str, int], float] = {}
+        self._node_id_mapping: typing.Dict[typing.Tuple[str, int], bytes] = {}
+        self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
+        self._node_tokens: typing.Dict[bytes, typing.Tuple[float, bytes]] = {}
+        self._kademlia_peers: typing.Dict[typing.Tuple[bytes, str, int], 'KademliaPeer'] = {}
+        self._lock = asyncio.Lock(loop=loop)
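PeerManager is pure bookkeeping: per-(address, port) timestamps for sends, requests, replies and the last two RPC failures, plus a bidirectional node_id <-> address mapping, all guarded by one lock. contact_triple_is_good below folds these into a tri-state answer. A sketch of that life cycle, assuming the 48 byte node ids used in this changeset; demo and the peer's address are hypothetical:

import asyncio
from lbrynet.dht.peer import PeerManager

async def demo(loop: asyncio.AbstractEventLoop) -> None:
    manager = PeerManager(loop)
    node_id, address, udp_port = b'\x01' * 48, '1.2.3.4', 4444  # hypothetical peer
    await manager.update_contact_triple(node_id, address, udp_port)
    peer = manager.get_kademlia_peer(node_id, address, udp_port)
    assert manager.peer_is_good(peer) is None    # known triple, never replied: unknown
    await manager.report_last_replied(address, udp_port)
    assert manager.peer_is_good(peer) is True    # recent reply, no failures: good
    await manager.report_failure(address, udp_port)
    assert manager.peer_is_good(peer) is False   # failed after its last reply: bad

loop = asyncio.get_event_loop()
loop.run_until_complete(demo(loop))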
+
+    async def report_failure(self, address: str, udp_port: int):
+        now = self._loop.time()
+        async with self._lock:
+            _, previous = self._rpc_failures.pop((address, udp_port), (None, None))
+            self._rpc_failures[(address, udp_port)] = (previous, now)
+
+    async def report_last_sent(self, address: str, udp_port: int):
+        now = self._loop.time()
+        async with self._lock:
+            self._last_sent[(address, udp_port)] = now
+
+    async def report_last_replied(self, address: str, udp_port: int):
+        now = self._loop.time()
+        async with self._lock:
+            self._last_replied[(address, udp_port)] = now
+
+    async def report_last_requested(self, address: str, udp_port: int):
+        now = self._loop.time()
+        async with self._lock:
+            self._last_requested[(address, udp_port)] = now
+
+    async def clear_token(self, node_id: bytes):
+        async with self._lock:
+            self._node_tokens.pop(node_id, None)
+
+    async def update_token(self, node_id: bytes, token: bytes):
+        now = self._loop.time()
+        async with self._lock:
+            self._node_tokens[node_id] = (now, token)
+
+    def get_node_token(self, node_id: bytes) -> typing.Optional[bytes]:
+        ts, token = self._node_tokens.get(node_id, (None, None))
+        if ts and ts > self._loop.time() - constants.token_secret_refresh_interval:
+            return token
+
+    def get_last_replied(self, address: str, udp_port: int) -> typing.Optional[float]:
+        return self._last_replied.get((address, udp_port))
+
+    def get_node_id(self, address: str, udp_port: int) -> typing.Optional[bytes]:
+        return self._node_id_mapping.get((address, udp_port))
+
+    def get_node_address(self, node_id: bytes) -> typing.Optional[typing.Tuple[str, int]]:
+        return self._node_id_reverse_mapping.get(node_id)
+
+    async def get_node_address_and_port(self, node_id: bytes) -> typing.Optional[typing.Tuple[str, int]]:
+        async with self._lock:
+            addr_tuple = self._node_id_reverse_mapping.get(node_id)
+            if addr_tuple and addr_tuple in self._node_id_mapping:
+                return addr_tuple
+
+    async def update_contact_triple(self, node_id: bytes, address: str, udp_port: int):
+        """
+        Update the mapping of node_id -> address tuple and that of address tuple -> node_id
+        This is to handle peers changing addresses and ids while ensuring that we only ever have
+        one node id / address tuple mapped to each other
+        """
+        async with self._lock:
+            if (address, udp_port) in self._node_id_mapping:
+                self._node_id_reverse_mapping.pop(self._node_id_mapping.pop((address, udp_port)))
+            if node_id in self._node_id_reverse_mapping:
+                self._node_id_mapping.pop(self._node_id_reverse_mapping.pop(node_id))
+            self._node_id_mapping[(address, udp_port)] = node_id
+            self._node_id_reverse_mapping[node_id] = (address, udp_port)
+
+    def get_kademlia_peer(self, node_id: bytes, address: str, udp_port: int) -> 'KademliaPeer':
+        return KademliaPeer(self._loop, address, node_id, udp_port)
+
+    async def prune(self):
+        now = self._loop.time()
+        async with self._lock:
+            to_pop = []
+            for (address, udp_port), (_, last_failure) in self._rpc_failures.items():
+                if last_failure and last_failure < now - constants.rpc_attempts_pruning_window:
+                    to_pop.append((address, udp_port))
+            while to_pop:
+                del self._rpc_failures[to_pop.pop()]
+            to_pop = []
+            for node_id, (age, token) in self._node_tokens.items():
+                if age < now - constants.token_secret_refresh_interval:
+                    to_pop.append(node_id)
+            while to_pop:
+                del self._node_tokens[to_pop.pop()]
+
+    def contact_triple_is_good(self, node_id: bytes, address: str, udp_port: int):
+        """
+        :return: False if peer is bad,
None if peer is unknown, or True if peer is good + """ + + delay = self._loop.time() - constants.check_refresh_interval + + if node_id not in self._node_id_reverse_mapping or (address, udp_port) not in self._node_id_mapping: + return + addr_tup = (address, udp_port) + if self._node_id_reverse_mapping[node_id] != addr_tup or self._node_id_mapping[addr_tup] != node_id: + return + previous_failure, most_recent_failure = self._rpc_failures.get((address, udp_port), (None, None)) + last_requested = self._last_requested.get((address, udp_port)) + last_replied = self._last_replied.get((address, udp_port)) + if most_recent_failure and last_replied: + if delay < last_replied > most_recent_failure: + return True + elif last_replied > most_recent_failure: + return + return False + elif previous_failure and most_recent_failure and most_recent_failure > delay: + return False + elif last_replied and last_replied > delay: + return True + elif last_requested and last_requested > delay: + return None + return + + def peer_is_good(self, peer: 'KademliaPeer'): + return self.contact_triple_is_good(peer.node_id, peer.address, peer.udp_port) + + def decode_tcp_peer_from_compact_address(self, compact_address: bytes) -> 'KademliaPeer': + node_id, address, tcp_port = decode_compact_address(compact_address) + return KademliaPeer(self._loop, address, node_id, tcp_port=tcp_port) + + +class KademliaPeer: + def __init__(self, loop: asyncio.BaseEventLoop, address: str, node_id: typing.Optional[bytes] = None, + udp_port: typing.Optional[int] = None, tcp_port: typing.Optional[int] = None): + if node_id is not None: + if not len(node_id) == constants.hash_length: + raise ValueError("invalid node_id: {}".format(hexlify(node_id).decode())) + if udp_port is not None and not 0 <= udp_port <= 65536: + raise ValueError("invalid udp port") + if tcp_port and not 0 <= tcp_port <= 65536: + raise ValueError("invalid tcp port") + if not is_valid_ipv4(address): + raise ValueError("invalid ip address") + self.loop = loop + self._node_id = node_id + self.address = address + self.udp_port = udp_port + self.tcp_port = tcp_port + self.protocol_version = 1 + + def update_tcp_port(self, tcp_port: int): + self.tcp_port = tcp_port + + def update_udp_port(self, udp_port: int): + self.udp_port = udp_port + + def set_id(self, node_id): + if not self._node_id: + self._node_id = node_id + + @property + def node_id(self) -> bytes: + return self._node_id + + def compact_address_udp(self) -> bytearray: + return make_compact_address(self.node_id, self.address, self.udp_port) + + def compact_address_tcp(self) -> bytearray: + return make_compact_address(self.node_id, self.address, self.tcp_port) + + def compact_ip(self): + return make_compact_ip(self.address) + + def __eq__(self, other): + if not isinstance(other, KademliaPeer): + raise TypeError("invalid type to compare with Peer: %s" % str(type(other))) + return (self.node_id, self.address, self.udp_port) == (other.node_id, other.address, other.udp_port) + + def __hash__(self): + return hash((self.node_id, self.address, self.udp_port)) diff --git a/lbrynet/dht/protocol.py b/lbrynet/dht/protocol.py deleted file mode 100644 index 0a00b01f1..000000000 --- a/lbrynet/dht/protocol.py +++ /dev/null @@ -1,493 +0,0 @@ -import logging -import errno -from binascii import hexlify - -from twisted.internet import protocol, defer -from lbrynet.dht import constants, encoding, msgformat, msgtypes -from lbrynet.dht.error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError, TransportNotConnected - -log = 
logging.getLogger(__name__) - - -class PingQueue: - """ - Schedules a 15 minute delayed ping after a new node sends us a query. This is so the new node gets added to the - routing table after having been given enough time for a pinhole to expire. - """ - - def __init__(self, node): - self._node = node - self._enqueued_contacts = {} - self._pending_contacts = {} - self._process_lc = node.get_looping_call(self._process) - - def enqueue_maybe_ping(self, *contacts, **kwargs): - delay = kwargs.get('delay', constants.checkRefreshInterval) - no_op = (defer.succeed(None), lambda: None) - for contact in contacts: - if delay and contact not in self._enqueued_contacts: - self._pending_contacts.setdefault(contact, self._node.clock.seconds() + delay) - else: - self._enqueued_contacts.setdefault(contact, no_op) - - @defer.inlineCallbacks - def _ping(self, contact): - if contact.contact_is_good: - return - try: - yield contact.ping() - except TimeoutError: - pass - except Exception as err: - log.warning("unexpected error: %s", err) - finally: - if contact in self._enqueued_contacts: - del self._enqueued_contacts[contact] - - def _process(self): - # move contacts that are scheduled to join the queue - if self._pending_contacts: - now = self._node.clock.seconds() - for contact in [contact for contact, schedule in self._pending_contacts.items() if schedule <= now]: - del self._pending_contacts[contact] - self._enqueued_contacts.setdefault(contact, (defer.succeed(None), lambda: None)) - # spread pings across 60 seconds to avoid flood and/or false negatives - step = 60.0/float(len(self._enqueued_contacts)) if self._enqueued_contacts else 0 - for index, (contact, (call, _)) in enumerate(self._enqueued_contacts.items()): - if call.called and not contact.contact_is_good: - self._enqueued_contacts[contact] = self._node.reactor_callLater(index*step, self._ping, contact) - - def start(self): - return self._node.safe_start_looping_call(self._process_lc, 60) - - def stop(self): - map(None, (cancel() for _, (call, cancel) in self._enqueued_contacts.items() if not call.called)) - return self._node.safe_stop_looping_call(self._process_lc) - - -class KademliaProtocol(protocol.DatagramProtocol): - """ Implements all low-level network-related functions of a Kademlia node """ - - msgSizeLimit = constants.udpDatagramMaxSize - 26 - - def __init__(self, node): - self._node = node - self._translator = msgformat.DefaultFormat() - self._sentMessages = {} - self._partialMessages = {} - self._partialMessagesProgress = {} - self._listening = defer.Deferred(None) - self._ping_queue = PingQueue(self._node) - self._protocolVersion = constants.protocolVersion - self.started_listening_time = 0 - - def _migrate_incoming_rpc_args(self, contact, method, *args): - if method == b'store' and contact.protocolVersion == 0: - if isinstance(args[1], dict): - blob_hash = args[0] - token = args[1].pop(b'token', None) - port = args[1].pop(b'port', -1) - originalPublisherID = args[1].pop(b'lbryid', None) - age = 0 - return (blob_hash, token, port, originalPublisherID, age), {} - return args, {} - - def _migrate_outgoing_rpc_args(self, contact, method, *args): - """ - This will reformat protocol version 0 arguments for the store function and will add the - protocol version keyword argument to calls to contacts who will accept it - """ - if contact.protocolVersion == 0: - if method == b'store': - blob_hash, token, port, originalPublisherID, age = args - args = ( - blob_hash, { - b'token': token, - b'port': port, - b'lbryid': originalPublisherID - }, 
originalPublisherID, False - ) - return args - return args - if args and isinstance(args[-1], dict): - args[-1][b'protocolVersion'] = self._protocolVersion - return args - return args + ({b'protocolVersion': self._protocolVersion},) - - @defer.inlineCallbacks - def sendRPC(self, contact, method, args): - for _ in range(constants.rpcAttempts): - try: - response = yield self._sendRPC(contact, method, args) - return response - except TimeoutError: - if contact.contact_is_good: - log.debug("RETRY %s ON %s", method, contact) - continue - else: - raise - - def _sendRPC(self, contact, method, args): - """ - Sends an RPC to the specified contact - - @param contact: The contact (remote node) to send the RPC to - @type contact: kademlia.contacts.Contact - @param method: The name of remote method to invoke - @type method: str - @param args: A list of (non-keyword) arguments to pass to the remote - method, in the correct order - @type args: tuple - - @return: This immediately returns a deferred object, which will return - the result of the RPC call, or raise the relevant exception - if the remote node raised one. If C{rawResponse} is set to - C{True}, however, it will always return the actual response - message (which may be a C{ResponseMessage} or an - C{ErrorMessage}). - @rtype: twisted.internet.defer.Deferred - """ - msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method, - *args)) - msgPrimitive = self._translator.toPrimitive(msg) - encodedMsg = encoding.bencode(msgPrimitive) - - if args: - log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method, - hexlify(args[0]), contact.address, contact.port) - else: - log.debug("%s:%i SEND CALL %s TO %s:%i", self._node.externalIP, self._node.port, method, - contact.address, contact.port) - - df = defer.Deferred() - - def _remove_contact(failure): # remove the contact from the routing table and track the failure - contact.update_last_failed() - try: - if not contact.contact_is_good: - self._node.removeContact(contact) - except (ValueError, IndexError): - pass - return failure - - def _update_contact(result): # refresh the contact in the routing table - contact.update_last_replied() - if method == b'findValue': - if b'token' in result: - contact.update_token(result[b'token']) - if b'protocolVersion' not in result: - contact.update_protocol_version(0) - else: - contact.update_protocol_version(result.pop(b'protocolVersion')) - d = self._node.addContact(contact) - d.addCallback(lambda _: result) - return d - - df.addCallbacks(_update_contact, _remove_contact) - - # Set the RPC timeout timer - timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id) - - # Transmit the data - self._send(encodedMsg, msg.id, (contact.address, contact.port)) - self._sentMessages[msg.id] = (contact, df, timeoutCall, cancelTimeout, method, args) - - df.addErrback(cancelTimeout) - return df - - def startProtocol(self): - log.info("DHT listening on UDP %i (ext port %i)", self._node.port, self._node.externalUDPPort) - if self._listening.called: - self._listening = defer.Deferred() - self._listening.callback(True) - self.started_listening_time = self._node.clock.seconds() - return self._ping_queue.start() - - def datagramReceived(self, datagram, address): - """ Handles and parses incoming RPC messages (and responses) - - @note: This is automatically called by Twisted when the protocol - receives a UDP datagram - """ - if chr(datagram[0]) == '\x00' and 
chr(datagram[25]) == '\x00': - totalPackets = (datagram[1] << 8) | datagram[2] - msgID = datagram[5:25] - seqNumber = (datagram[3] << 8) | datagram[4] - if msgID not in self._partialMessages: - self._partialMessages[msgID] = {} - self._partialMessages[msgID][seqNumber] = datagram[26:] - if len(self._partialMessages[msgID]) == totalPackets: - keys = self._partialMessages[msgID].keys() - keys.sort() - data = b'' - for key in keys: - data += self._partialMessages[msgID][key] - datagram = data - del self._partialMessages[msgID] - else: - return - try: - msgPrimitive = encoding.bdecode(datagram) - message = self._translator.fromPrimitive(msgPrimitive) - except (encoding.DecodeError, ValueError) as err: - # We received some rubbish here - log.warning("Error decoding datagram %s from %s:%i - %s", hexlify(datagram), - address[0], address[1], err) - return - except (IndexError, KeyError): - log.warning("Couldn't decode dht datagram from %s", address) - return - - if isinstance(message, msgtypes.RequestMessage): - # This is an RPC method request - remoteContact = self._node.contact_manager.make_contact(message.nodeID, address[0], address[1], self) - remoteContact.update_last_requested() - # only add a requesting contact to the routing table if it has replied to one of our requests - if remoteContact.contact_is_good is True: - df = self._node.addContact(remoteContact) - else: - df = defer.succeed(None) - df.addCallback(lambda _: self._handleRPC(remoteContact, message.id, message.request, message.args)) - # if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it - # will be added to our routing table if successful - if remoteContact.contact_is_good is None and remoteContact.lastReplied is None: - df.addCallback(lambda _: self._ping_queue.enqueue_maybe_ping(remoteContact)) - elif isinstance(message, msgtypes.ErrorMessage): - # The RPC request raised a remote exception; raise it locally - if message.exceptionType in BUILTIN_EXCEPTIONS: - exception_type = BUILTIN_EXCEPTIONS[message.exceptionType] - else: - exception_type = UnknownRemoteException - remoteException = exception_type(message.response) - log.error("DHT RECV REMOTE EXCEPTION FROM %s:%i: %s", address[0], - address[1], remoteException) - if message.id in self._sentMessages: - # Cancel timeout timer for this RPC - remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[message.id][0:5] - timeoutCanceller() - del self._sentMessages[message.id] - - # reject replies coming from a different address than what we sent our request to - if (remoteContact.address, remoteContact.port) != address: - log.warning("Sent request to node %s at %s:%i, got reply from %s:%i", - remoteContact.log_id(), remoteContact.address, - remoteContact.port, address[0], address[1]) - df.errback(TimeoutError(remoteContact.id)) - return - - # this error is returned by nodes that can be contacted but have an old - # and broken version of the ping command, if they return it the node can - # be contacted, so we'll treat it as a successful ping - old_ping_error = "ping() got an unexpected keyword argument '_rpcNodeContact'" - if isinstance(remoteException, TypeError) and \ - remoteException.message == old_ping_error: - log.debug("old pong error") - df.callback('pong') - else: - df.errback(remoteException) - elif isinstance(message, msgtypes.ResponseMessage): - # Find the message that triggered this response - if message.id in self._sentMessages: - # Cancel timeout timer for this RPC - remoteContact, df, timeoutCall, 
timeoutCanceller, method = self._sentMessages[message.id][0:5] - timeoutCanceller() - del self._sentMessages[message.id] - log.debug("%s:%i RECV response to %s from %s:%i", self._node.externalIP, self._node.port, - method, remoteContact.address, remoteContact.port) - - # When joining the network we made Contact objects for the seed nodes with node ids set to None - # Thus, the sent_to_id will also be None, and the contact objects need the ids to be manually set. - # These replies have be distinguished from those where the node id in the datagram does not match - # the node id of the node we sent a message to (these messages are treated as an error) - if remoteContact.id and remoteContact.id != message.nodeID: # sent_to_id will be None for bootstrap - log.debug("mismatch: (%s) %s:%i (%s vs %s)", method, remoteContact.address, remoteContact.port, - remoteContact.log_id(False), hexlify(message.nodeID)) - df.errback(TimeoutError(remoteContact.id)) - return - elif not remoteContact.id: - remoteContact.set_id(message.nodeID) - - # We got a result from the RPC - df.callback(message.response) - else: - # If the original message isn't found, it must have timed out - # TODO: we should probably do something with this... - pass - - def _send(self, data, rpcID, address): - """ Transmit the specified data over UDP, breaking it up into several - packets if necessary - - If the data is spread over multiple UDP datagrams, the packets have the - following structure:: - | | | | | |||||||||||| 0x00 | - |Transmission|Total number|Sequence number| RPC ID |Header end| - | type ID | of packets |of this packet | | indicator| - | (1 byte) | (2 bytes) | (2 bytes) |(20 bytes)| (1 byte) | - | | | | | |||||||||||| | - - @note: The header used for breaking up large data segments will - possibly be moved out of the KademliaProtocol class in the - future, into something similar to a message translator/encoder - class (see C{kademlia.msgformat} and C{kademlia.encoding}). 
- """ - - if len(data) > self.msgSizeLimit: - # We have to spread the data over multiple UDP datagrams, - # and provide sequencing information - # - # 1st byte is transmission type id, bytes 2 & 3 are the - # total number of packets in this transmission, bytes 4 & - # 5 are the sequence number for this specific packet - totalPackets = len(data) // self.msgSizeLimit - if len(data) % self.msgSizeLimit > 0: - totalPackets += 1 - encTotalPackets = chr(totalPackets >> 8) + chr(totalPackets & 0xff) - seqNumber = 0 - startPos = 0 - while seqNumber < totalPackets: - packetData = data[startPos:startPos + self.msgSizeLimit] - encSeqNumber = chr(seqNumber >> 8) + chr(seqNumber & 0xff) - txData = f'\x00{encTotalPackets}{encSeqNumber}{rpcID}\x00{packetData}' - self._scheduleSendNext(txData, address) - - startPos += self.msgSizeLimit - seqNumber += 1 - else: - self._scheduleSendNext(data, address) - - def _scheduleSendNext(self, txData, address): - """Schedule the sending of the next UDP packet """ - delayed_call, _ = self._node.reactor_callSoon(self._write, txData, address) - - def _write(self, txData, address): - if self.transport: - try: - self.transport.write(txData, address) - except OSError as err: - if err.errno == errno.EWOULDBLOCK: - # i'm scared this may swallow important errors, but i get a million of these - # on Linux and it doesn't seem to affect anything -grin - log.warning("Can't send data to dht: EWOULDBLOCK") - elif err.errno == errno.ENETUNREACH: - # this should probably try to retransmit when the network connection is back - log.error("Network is unreachable") - else: - log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)", - len(txData), address[0], address[1], err.message, err.errno) - raise err - else: - raise TransportNotConnected() - - def _sendResponse(self, contact, rpcID, response): - """ Send a RPC response to the specified contact - """ - msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response) - msgPrimitive = self._translator.toPrimitive(msg) - encodedMsg = encoding.bencode(msgPrimitive) - self._send(encodedMsg, rpcID, (contact.address, contact.port)) - - def _sendError(self, contact, rpcID, exceptionType, exceptionMessage): - """ Send an RPC error message to the specified contact - """ - exceptionMessage = exceptionMessage.encode() - msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage) - msgPrimitive = self._translator.toPrimitive(msg) - encodedMsg = encoding.bencode(msgPrimitive) - self._send(encodedMsg, rpcID, (contact.address, contact.port)) - - def _handleRPC(self, senderContact, rpcID, method, args): - """ Executes a local function in response to an RPC request """ - - # Set up the deferred callchain - def handleError(f): - self._sendError(senderContact, rpcID, f.type, f.getErrorMessage()) - - def handleResult(result): - self._sendResponse(senderContact, rpcID, result) - - df = defer.Deferred() - df.addCallback(handleResult) - df.addErrback(handleError) - - # Execute the RPC - func = getattr(self._node, method.decode(), None) - if callable(func) and hasattr(func, "rpcmethod"): - # Call the exposed Node method and return the result to the deferred callback chain - # if args: - # log.debug("%s:%i RECV CALL %s(%s) %s:%i", self._node.externalIP, self._node.port, method, - # args[0].encode('hex'), senderContact.address, senderContact.port) - # else: - log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method, - senderContact.address, senderContact.port) - if args and 
isinstance(args[-1], dict) and b'protocolVersion' in args[-1]: # args don't need reformatting - senderContact.update_protocol_version(int(args[-1].pop(b'protocolVersion'))) - a, kw = tuple(args[:-1]), args[-1] - else: - senderContact.update_protocol_version(0) - a, kw = self._migrate_incoming_rpc_args(senderContact, method, *args) - try: - if method != b'ping': - result = func(senderContact, *a) - else: - result = func() - except Exception as e: - log.error("error handling request for %s:%i %s", senderContact.address, senderContact.port, method) - df.errback(e) - else: - df.callback(result) - else: - # No such exposed method - df.errback(AttributeError('Invalid method: %s' % method)) - return df - - def _msgTimeout(self, messageID): - """ Called when an RPC request message times out """ - # Find the message that timed out - if messageID not in self._sentMessages: - # This should never be reached - log.error("deferred timed out, but is not present in sent messages list!") - return - remoteContact, df, timeout_call, timeout_canceller, method, args = self._sentMessages[messageID] - if messageID in self._partialMessages: - # We are still receiving this message - self._msgTimeoutInProgress(messageID, timeout_canceller, remoteContact, df, method, args) - return - del self._sentMessages[messageID] - # The message's destination node is now considered to be dead; - # raise an (asynchronous) TimeoutError exception and update the host node - df.errback(TimeoutError(remoteContact.id)) - - def _msgTimeoutInProgress(self, messageID, timeoutCanceller, remoteContact, df, method, args): - # See if any progress has been made; if not, kill the message - if self._hasProgressBeenMade(messageID): - # Reset the RPC timeout timer - timeoutCanceller() - timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID) - self._sentMessages[messageID] = (remoteContact, df, timeoutCall, cancelTimeout, method, args) - else: - # No progress has been made - if messageID in self._partialMessagesProgress: - del self._partialMessagesProgress[messageID] - if messageID in self._partialMessages: - del self._partialMessages[messageID] - df.errback(TimeoutError(remoteContact.id)) - - def _hasProgressBeenMade(self, messageID): - return ( - messageID in self._partialMessagesProgress and - ( - len(self._partialMessagesProgress[messageID]) != - len(self._partialMessages[messageID]) - ) - ) - - def stopProtocol(self): - """ Called when the transport is disconnected. - - Will only be called once, after all ports are disconnected. 
- """ - log.info('Stopping DHT') - self._ping_queue.stop() - self._node.call_later_manager.stop() - log.info('DHT stopped') diff --git a/lbrynet/dht/protocol/__init__.py b/lbrynet/dht/protocol/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lbrynet/dht/protocol/async_generator_junction.py b/lbrynet/dht/protocol/async_generator_junction.py new file mode 100644 index 000000000..3993f4b2b --- /dev/null +++ b/lbrynet/dht/protocol/async_generator_junction.py @@ -0,0 +1,105 @@ +import asyncio +import typing +import logging +import traceback +if typing.TYPE_CHECKING: + from types import AsyncGeneratorType + +log = logging.getLogger(__name__) + + +def cancel_task(task: typing.Optional[asyncio.Task]): + if task and not (task.done() or task.cancelled()): + task.cancel() + + +def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): + for task in tasks: + cancel_task(task) + + +def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]): + while tasks: + cancel_task(tasks.pop()) + + +class AsyncGeneratorJunction: + """ + A helper to interleave the results from multiple async generators into one + async generator. + """ + + def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None): + self.loop = loop + self.result_queue = queue or asyncio.Queue(loop=loop) + self.tasks: typing.List[asyncio.Task] = [] + self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {} + self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop) + self.can_iterate = asyncio.Event(loop=self.loop) + self.finished = asyncio.Event(loop=self.loop) + + @property + def running(self): + return any(self.running_iterators.values()) + + async def wait_for_generators(self): + async def iterate(iterator: typing.AsyncGenerator): + try: + async for item in iterator: + self.result_queue.put_nowait(item) + finally: + self.running_iterators[iterator] = False + + while True: + async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get() + self.running_iterators[async_gen] = True + self.tasks.append(self.loop.create_task(iterate(async_gen))) + if not self.can_iterate.is_set(): + self.can_iterate.set() + + def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']): + """ + Add an async generator. This can be called during an iteration of the generator junction. 
+ """ + self.generator_queue.put_nowait(async_gen) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self.can_iterate.is_set(): + await self.can_iterate.wait() + if not self.running: + raise StopAsyncIteration() + try: + return await self.result_queue.get() + finally: + self.awaiting = None + + def aclose(self): + async def _aclose(): + for iterator in list(self.running_iterators.keys()): + result = iterator.aclose() + if asyncio.iscoroutine(result): + await result + self.running_iterators[iterator] = False + drain_tasks(self.tasks) + raise StopAsyncIteration() + if not self.finished.is_set(): + self.finished.set() + return self.loop.create_task(_aclose()) + + async def __aenter__(self): + self.tasks.append(self.loop.create_task(self.wait_for_generators())) + return self + + async def __aexit__(self, exc_type, exc, tb): + try: + await self.aclose() + except StopAsyncIteration: + pass + finally: + if exc_type: + if exc_type not in (asyncio.CancelledError, asyncio.TimeoutError, StopAsyncIteration): + err = traceback.format_exception(exc_type, exc, tb) + log.error(err) diff --git a/lbrynet/dht/protocol/data_store.py b/lbrynet/dht/protocol/data_store.py new file mode 100644 index 000000000..f422c21d5 --- /dev/null +++ b/lbrynet/dht/protocol/data_store.py @@ -0,0 +1,76 @@ +import asyncio +import typing + +from lbrynet.dht import constants +if typing.TYPE_CHECKING: + from lbrynet.dht.peer import KademliaPeer, PeerManager + + +class DictDataStore: + def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager'): + # Dictionary format: + # { : [, , , ] } + self._data_store: typing.Dict[bytes, + typing.List[typing.Tuple['KademliaPeer', bytes, float, float, bytes]]] = {} + self._get_time = loop.time + self._peer_manager = peer_manager + self.completed_blobs: typing.Set[str] = set() + + def filter_bad_and_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']: + """ + Returns only non-expired and unknown/good peers + """ + peers = [] + for peer in map(lambda p: p[0], + filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration, + self._data_store[key])): + if self._peer_manager.peer_is_good(peer) is not False: + peers.append(peer) + return peers + + def filter_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']: + """ + Returns only non-expired peers + """ + return list( + map( + lambda p: p[0], + filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration, self._data_store[key]) + ) + ) + + def removed_expired_peers(self): + expired_keys = [] + for key in self._data_store.keys(): + unexpired_peers = self.filter_expired_peers(key) + if not unexpired_peers: + expired_keys.append(key) + else: + self._data_store[key] = [x for x in self._data_store[key] if x[0] in unexpired_peers] + for key in expired_keys: + del self._data_store[key] + + def has_peers_for_blob(self, key: bytes) -> bool: + return key in self._data_store and len(self.filter_bad_and_expired_peers(key)) > 0 + + def add_peer_to_blob(self, contact: 'KademliaPeer', key: bytes, compact_address: bytes, last_published: int, + originally_published: int, original_publisher_id: bytes) -> None: + if key in self._data_store: + if compact_address not in map(lambda store_tuple: store_tuple[1], self._data_store[key]): + self._data_store[key].append( + (contact, compact_address, last_published, originally_published, original_publisher_id) + ) + else: + self._data_store[key] = [(contact, compact_address, last_published, originally_published, + 
+                                      original_publisher_id)]
+
+    def get_peers_for_blob(self, key: bytes) -> typing.List['KademliaPeer']:
+        return [] if key not in self._data_store else [peer for peer in self.filter_bad_and_expired_peers(key)]
+
+    def get_storing_contacts(self) -> typing.List['KademliaPeer']:
+        peers = set()
+        for key in self._data_store:
+            for values in self._data_store[key]:
+                if values[0] not in peers:
+                    peers.add(values[0])
+        return list(peers)
diff --git a/lbrynet/dht/distance.py b/lbrynet/dht/protocol/distance.py
similarity index 66%
rename from lbrynet/dht/distance.py
rename to lbrynet/dht/protocol/distance.py
index 2c1099535..516dfabc4 100644
--- a/lbrynet/dht/distance.py
+++ b/lbrynet/dht/protocol/distance.py
@@ -8,20 +8,16 @@ class Distance:
     we pre-calculate the value of that point.
     """
 
-    def __init__(self, key):
-        if len(key) != constants.key_bits // 8:
+    def __init__(self, key: bytes):
+        if len(key) != constants.hash_length:
             raise ValueError("invalid key length: %i" % len(key))
         self.key = key
         self.val_key_one = int.from_bytes(key, 'big')
 
-    def __call__(self, key_two):
+    def __call__(self, key_two: bytes) -> int:
         val_key_two = int.from_bytes(key_two, 'big')
         return self.val_key_one ^ val_key_two
 
-    def is_closer(self, a, b):
+    def is_closer(self, a: bytes, b: bytes) -> bool:
         """Returns true if `a` is closer to `key` than `b` is"""
         return self(a) < self(b)
-
-    def to_contact(self, contact):
-        """A convenience function for calculating the distance to a contact"""
-        return self(contact.id)
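Distance implements Kademlia's XOR metric: two ids are compared by XORing each with a reference key, so shared high-order bits cancel out and a numerically smaller result means a longer common prefix. A toy worked example with 1-byte keys (the real keys above are 48 bytes):

key = 0x04          # 0b0100, reference point
a, b = 0x05, 0x06   # 0b0101 and 0b0110

assert key ^ a == 0b0001        # a differs from key only in the lowest bit
assert key ^ b == 0b0010        # b differs in the second bit
assert (key ^ a) < (key ^ b)    # so a is closer to key than b is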
typing.List['KademliaPeer']: + """ + If not provided, initialize the shortlist of peers to probe to the (up to) k closest peers in the routing table + + :param routing_table: a TreeRoutingTable + :param key: a 48 byte hash + :param shortlist: optional manually provided shortlist, this is done during bootstrapping when there are no + peers in the routing table. During bootstrap the shortlist is set to be the seed nodes. + """ + if len(key) != constants.hash_length: + raise ValueError("invalid key length: %i" % len(key)) + if not shortlist: + shortlist = routing_table.find_close_peers(key) + distance = Distance(key) + shortlist.sort(key=lambda peer: distance(peer.node_id), reverse=True) + return shortlist + + +class IterativeFinder: + def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', + routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes, + bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.k, + exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None, + shortlist: typing.Optional[typing.List['KademliaPeer']] = None): + if len(key) != constants.hash_length: + raise ValueError("invalid key length: %i" % len(key)) + self.loop = loop + self.peer_manager = peer_manager + self.routing_table = routing_table + self.protocol = protocol + + self.key = key + self.bottom_out_limit = bottom_out_limit + self.max_results = max_results + self.exclude = exclude or [] + + self.shortlist: typing.List['KademliaPeer'] = get_shortlist(routing_table, key, shortlist) + self.active: typing.List['KademliaPeer'] = [] + self.contacted: typing.List[typing.Tuple[str, int]] = [] + self.distance = Distance(key) + + self.closest_peer: typing.Optional['KademliaPeer'] = None if not self.shortlist else self.shortlist[0] + self.prev_closest_peer: typing.Optional['KademliaPeer'] = None + + self.iteration_queue = asyncio.Queue(loop=self.loop) + + self.running_probes: typing.List[asyncio.Task] = [] + self.lock = asyncio.Lock(loop=self.loop) + self.iteration_count = 0 + self.bottom_out_count = 0 + self.running = False + self.tasks: typing.List[asyncio.Task] = [] + self.delayed_calls: typing.List[asyncio.Handle] = [] + self.finished = asyncio.Event(loop=self.loop) + + async def send_probe(self, peer: 'KademliaPeer') -> FindResponse: + """ + Send the rpc request to the peer and return an object with the FindResponse interface + """ + raise NotImplementedError() + + def check_result_ready(self, response: FindResponse): + """ + Called with a lock after adding peers from an rpc result to the shortlist. + This method is responsible for putting a result for the generator into the Queue + """ + raise NotImplementedError() + + def get_initial_result(self) -> typing.List['KademliaPeer']: + """ + Get an initial or cached result to be put into the Queue. 
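+        (note: the base implementation returns an empty list; IterativeValueFinder
+        overrides this below so that locally stored peers can be yielded before any
+        probe responses arrive)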
Used for findValue requests where the blob + has peers in the local data store of blobs announced to us + """ + return [] + + def _is_closer(self, peer: 'KademliaPeer') -> bool: + if not self.closest_peer: + return True + return self.distance.is_closer(peer.node_id, self.closest_peer.node_id) + + def _update_closest(self): + self.shortlist.sort(key=lambda peer: self.distance(peer.node_id), reverse=True) + if self.closest_peer and self.closest_peer is not self.shortlist[-1]: + if self._is_closer(self.shortlist[-1]): + self.prev_closest_peer = self.closest_peer + self.closest_peer = self.shortlist[-1] + + async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse): + async with self.lock: + if peer not in self.shortlist: + self.shortlist.append(peer) + if peer not in self.active: + self.active.append(peer) + for contact_triple in response.get_close_triples(): + addr_tuple = (contact_triple[1], contact_triple[2]) + if addr_tuple not in self.contacted: # and not self.peer_manager.is_ignored(addr_tuple) + found_peer = self.peer_manager.get_kademlia_peer( + contact_triple[0], contact_triple[1], contact_triple[2] + ) + if found_peer not in self.shortlist and self.peer_manager.peer_is_good(peer) is not False: + self.shortlist.append(found_peer) + self._update_closest() + self.check_result_ready(response) + + async def _send_probe(self, peer: 'KademliaPeer'): + try: + response = await self.send_probe(peer) + except asyncio.CancelledError: + return + except asyncio.TimeoutError: + if peer in self.active: + self.active.remove(peer) + return + except ValueError as err: + log.warning(str(err)) + if peer in self.active: + self.active.remove(peer) + return + except RemoteException: + return + return await self._handle_probe_result(peer, response) + + async def _search_round(self): + """ + Send up to constants.alpha (5) probes to the closest peers in the shortlist + """ + + added = 0 + async with self.lock: + self.shortlist.sort(key=lambda p: self.distance(p.node_id), reverse=True) + while self.running and len(self.shortlist) and added < constants.alpha: + peer = self.shortlist.pop() + origin_address = (peer.address, peer.udp_port) + if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False: + continue + if peer.node_id == self.protocol.node_id: + continue + if (peer.address, peer.udp_port) == (self.protocol.external_ip, self.protocol.udp_port): + continue + if (peer.address, peer.udp_port) not in self.contacted: + self.contacted.append((peer.address, peer.udp_port)) + + t = self.loop.create_task(self._send_probe(peer)) + + def callback(_): + if t and t in self.running_probes: + self.running_probes.remove(t) + if not self.running_probes and self.shortlist: + self.tasks.append(self.loop.create_task(self._search_task(0.0))) + + t.add_done_callback(callback) + self.running_probes.append(t) + added += 1 + + async def _search_task(self, delay: typing.Optional[float] = constants.iterative_lookup_delay): + try: + if self.running: + await self._search_round() + if self.running: + self.delayed_calls.append(self.loop.call_later(delay, self._search)) + except (asyncio.CancelledError, StopAsyncIteration): + if self.running: + drain_tasks(self.running_probes) + self.running = False + + def _search(self): + self.tasks.append(self.loop.create_task(self._search_task())) + + def search(self): + if self.running: + raise Exception("already running") + self.running = True + self._search() + + async def next_queue_or_finished(self) -> typing.List['KademliaPeer']: + peers = 
self.loop.create_task(self.iteration_queue.get()) + finished = self.loop.create_task(self.finished.wait()) + err = None + try: + await asyncio.wait([peers, finished], loop=self.loop, return_when='FIRST_COMPLETED') + if peers.done(): + return peers.result() + raise StopAsyncIteration() + except asyncio.CancelledError as error: + err = error + finally: + if not finished.done() and not finished.cancelled(): + finished.cancel() + if not peers.done() and not peers.cancelled(): + peers.cancel() + if err: + raise err + + def __aiter__(self): + self.search() + return self + + async def __anext__(self) -> typing.List['KademliaPeer']: + try: + if self.iteration_count == 0: + initial_results = self.get_initial_result() + if initial_results: + self.iteration_queue.put_nowait(initial_results) + result = await self.next_queue_or_finished() + self.iteration_count += 1 + return result + except (asyncio.CancelledError, StopAsyncIteration): + await self.aclose() + raise + + def aclose(self): + self.running = False + + async def _aclose(): + async with self.lock: + self.running = False + if not self.finished.is_set(): + self.finished.set() + drain_tasks(self.tasks) + drain_tasks(self.running_probes) + while self.delayed_calls: + timer = self.delayed_calls.pop() + if timer: + timer.cancel() + + return asyncio.ensure_future(_aclose(), loop=self.loop) + + +class IterativeNodeFinder(IterativeFinder): + def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', + routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes, + bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.k, + exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None, + shortlist: typing.Optional[typing.List['KademliaPeer']] = None): + super().__init__(loop, peer_manager, routing_table, protocol, key, bottom_out_limit, max_results, exclude, + shortlist) + self.yielded_peers: typing.Set['KademliaPeer'] = set() + + async def send_probe(self, peer: 'KademliaPeer') -> FindNodeResponse: + response = await self.protocol.get_rpc_peer(peer).find_node(self.key) + return FindNodeResponse(self.key, response) + + def put_result(self, from_list: typing.List['KademliaPeer']): + not_yet_yielded = [peer for peer in from_list if peer not in self.yielded_peers] + not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id)) + to_yield = not_yet_yielded[:min(constants.k, len(not_yet_yielded))] + if to_yield: + for peer in to_yield: + self.yielded_peers.add(peer) + self.iteration_queue.put_nowait(to_yield) + + def check_result_ready(self, response: FindNodeResponse): + found = response.found and self.key != self.protocol.node_id + + if found: + log.info("found") + self.put_result(self.shortlist) + if not self.finished.is_set(): + self.finished.set() + return + if self.prev_closest_peer and self.closest_peer and not self._is_closer(self.prev_closest_peer): + # log.info("improving, %i %i %i %i %i", len(self.shortlist), len(self.active), len(self.contacted), + # self.bottom_out_count, self.iteration_count) + self.bottom_out_count = 0 + elif self.prev_closest_peer and self.closest_peer: + self.bottom_out_count += 1 + log.info("bottom out %i %i %i %i", len(self.active), len(self.contacted), len(self.shortlist), + self.bottom_out_count) + if self.bottom_out_count >= self.bottom_out_limit or self.iteration_count >= self.bottom_out_limit: + log.info("limit hit") + self.put_result(self.active) + if not self.finished.is_set(): + self.finished.set() + return + if self.max_results and 
len(self.active) - len(self.yielded_peers) >= self.max_results: + log.info("max results") + self.put_result(self.active) + if not self.finished.is_set(): + self.finished.set() + return + + +class IterativeValueFinder(IterativeFinder): + def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', + routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes, + bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.k, + exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None, + shortlist: typing.Optional[typing.List['KademliaPeer']] = None): + super().__init__(loop, peer_manager, routing_table, protocol, key, bottom_out_limit, max_results, exclude, + shortlist) + self.blob_peers: typing.Set['KademliaPeer'] = set() + + async def send_probe(self, peer: 'KademliaPeer') -> FindValueResponse: + response = await self.protocol.get_rpc_peer(peer).find_value(self.key) + return FindValueResponse(self.key, response) + + def check_result_ready(self, response: FindValueResponse): + if response.found: + blob_peers = [self.peer_manager.decode_tcp_peer_from_compact_address(compact_addr) + for compact_addr in response.found_compact_addresses] + to_yield = [] + self.bottom_out_count = 0 + for blob_peer in blob_peers: + if blob_peer not in self.blob_peers: + self.blob_peers.add(blob_peer) + to_yield.append(blob_peer) + if to_yield: + # log.info("found %i new peers for blob", len(to_yield)) + self.iteration_queue.put_nowait(to_yield) + # if self.max_results and len(self.blob_peers) >= self.max_results: + # log.info("enough blob peers found") + # if not self.finished.is_set(): + # self.finished.set() + return + if self.prev_closest_peer and self.closest_peer: + self.bottom_out_count += 1 + if self.bottom_out_count >= self.bottom_out_limit: + log.info("blob peer search bottomed out") + if not self.finished.is_set(): + self.finished.set() + return + + def get_initial_result(self) -> typing.List['KademliaPeer']: + if self.protocol.data_store.has_peers_for_blob(self.key): + return self.protocol.data_store.get_peers_for_blob(self.key) + return [] diff --git a/lbrynet/dht/protocol/protocol.py b/lbrynet/dht/protocol/protocol.py new file mode 100644 index 000000000..205c96371 --- /dev/null +++ b/lbrynet/dht/protocol/protocol.py @@ -0,0 +1,634 @@ +import logging +import socket +import functools +import hashlib +import asyncio +import typing +import binascii +from asyncio.protocols import DatagramProtocol +from asyncio.transports import DatagramTransport + +from lbrynet.dht import constants +from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram +from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE +from lbrynet.dht.error import RemoteException, TransportNotConnected +from lbrynet.dht.protocol.routing_table import TreeRoutingTable +from lbrynet.dht.protocol.data_store import DictDataStore + +if typing.TYPE_CHECKING: + from lbrynet.dht.peer import PeerManager, KademliaPeer + +log = logging.getLogger(__name__) + + +old_protocol_errors = { + "findNode() takes exactly 2 arguments (5 given)": "0.19.1", + "findValue() takes exactly 2 arguments (5 given)": "0.19.1" +} + + +class KademliaRPC: + def __init__(self, protocol: 'KademliaProtocol', loop: asyncio.BaseEventLoop, peer_port: int = 3333): + self.protocol = protocol + self.loop = loop + self.peer_port = peer_port + self.old_token_secret: bytes = None + self.token_secret = constants.generate_id() + + def 
compact_address(self): + compact_ip = functools.reduce(lambda buff, x: buff + bytearray([int(x)]), + self.protocol.external_ip.split('.'), bytearray()) + compact_port = self.peer_port.to_bytes(2, 'big') + return compact_ip + compact_port + self.protocol.node_id + + @staticmethod + def ping(): + return b'pong' + + def store(self, rpc_contact: 'KademliaPeer', blob_hash: bytes, token: bytes, port: int, + original_publisher_id: bytes, age: int) -> bytes: + if original_publisher_id is None: + original_publisher_id = rpc_contact.node_id + rpc_contact.update_tcp_port(port) + if self.loop.time() - self.protocol.started_listening_time < constants.token_secret_refresh_interval: + pass + elif not self.verify_token(token, rpc_contact.compact_ip()): + raise ValueError("Invalid token") + now = int(self.loop.time()) + originally_published = now - age + self.protocol.data_store.add_peer_to_blob( + rpc_contact, blob_hash, rpc_contact.compact_address_tcp(), now, originally_published, original_publisher_id + ) + return b'OK' + + def find_node(self, rpc_contact: 'KademliaPeer', key: bytes) -> typing.List[typing.Tuple[bytes, str, int]]: + if len(key) != constants.hash_length: + raise ValueError("invalid contact node_id length: %i" % len(key)) + + contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id) + contact_triples = [] + for contact in contacts: + contact_triples.append((contact.node_id, contact.address, contact.udp_port)) + return contact_triples + + def find_value(self, rpc_contact: 'KademliaPeer', key: bytes): + if len(key) != constants.hash_length: + raise ValueError("invalid blob_exchange hash length: %i" % len(key)) + + response = { + b'token': self.make_token(rpc_contact.compact_ip()), + } + + if self.protocol.protocol_version: + response[b'protocolVersion'] = self.protocol.protocol_version + + # get peers we have stored for this blob_exchange + has_other_peers = self.protocol.data_store.has_peers_for_blob(key) + peers = [] + if has_other_peers: + peers.extend(self.protocol.data_store.get_peers_for_blob(key)) + + # if we don't have k storing peers to return and we have this hash locally, include our contact information + if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs: + peers.append(self.compact_address()) + if peers: + response[key] = peers + else: + response[b'contacts'] = self.find_node(rpc_contact, key) + return response + + def refresh_token(self): # TODO: this needs to be called periodically + self.old_token_secret = self.token_secret + self.token_secret = constants.generate_id() + + def make_token(self, compact_ip): + h = hashlib.new('sha384') + h.update(self.token_secret + compact_ip) + return h.digest() + + def verify_token(self, token, compact_ip): + h = hashlib.new('sha384') + h.update(self.token_secret + compact_ip) + if self.old_token_secret and not token == h.digest(): # TODO: why should we be accepting the previous token? 
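+            # a possible rationale (an assumption, not documented upstream): the secret
+            # rotates periodically via refresh_token(), so a token handed out just
+            # before a rotation would otherwise be rejected; falling back to the old
+            # secret keeps issued tokens valid for one extra refresh interval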
+            h = hashlib.new('sha384')
+            h.update(self.old_token_secret + compact_ip)
+            if not token == h.digest():
+                return False
+        return True
+
+
+class RemoteKademliaRPC:
+    """
+    Encapsulates RPC calls to remote Peers
+    """
+
+    def __init__(self, loop: asyncio.BaseEventLoop, peer_tracker: 'PeerManager', protocol: 'KademliaProtocol',
+                 peer: 'KademliaPeer'):
+        self.loop = loop
+        self.peer_tracker = peer_tracker
+        self.protocol = protocol
+        self.peer = peer
+
+    async def ping(self) -> bytes:
+        """
+        :return: b'pong'
+        """
+        response = await self.protocol.send_request(
+            self.peer, RequestDatagram.make_ping(self.protocol.node_id)
+        )
+        return response.response
+
+    async def store(self, blob_hash: bytes) -> bytes:
+        """
+        :param blob_hash: blob hash as bytes
+        :return: b'OK'
+        """
+        if len(blob_hash) != constants.hash_bits // 8:
+            raise ValueError(f"invalid length of blob hash: {len(blob_hash)}")
+        if not self.protocol.peer_port or not 0 < self.protocol.peer_port < 65535:
+            raise ValueError(f"invalid tcp port: {self.protocol.peer_port}")
+        token = self.peer_tracker.get_node_token(self.peer.node_id)
+        if not token:
+            find_value_resp = await self.find_value(blob_hash)
+            token = find_value_resp[b'token']
+        response = await self.protocol.send_request(
+            self.peer, RequestDatagram.make_store(self.protocol.node_id, blob_hash, token, self.protocol.peer_port)
+        )
+        return response.response
+
+    async def find_node(self, key: bytes) -> typing.List[typing.Tuple[bytes, str, int]]:
+        """
+        :return: [(node_id, address, udp_port), ...]
+        """
+        if len(key) != constants.hash_bits // 8:
+            raise ValueError(f"invalid length of find node key: {len(key)}")
+        response = await self.protocol.send_request(
+            self.peer, RequestDatagram.make_find_node(self.protocol.node_id, key)
+        )
+        return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response]
+
+    async def find_value(self, key: bytes) -> typing.Dict:
+        """
+        :return: {
+            b'token': <token bytes>,
+            b'contacts': [(node_id, address, udp_port), ...]
+            <key bytes>: [<compact peer address bytes>, ...]
+        }
+        """
+        response = await self.protocol.send_request(
+            self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key)
+        )
+        return response.response
+
+                scheduled = [k for k, d in self._pending_contacts.items() if now >= d]
+                for k in scheduled:
+                    del self._pending_contacts[k]
+                    if k not in self._enqueued_contacts:
+                        self._enqueued_contacts.append(k)
+                while self._enqueued_contacts:
+                    peer = self._enqueued_contacts.pop()
+                    tasks.append(self._loop.create_task(_ping(peer)))
+            if tasks:
+                await asyncio.wait(tasks, loop=self._loop)
+
+            f = self._loop.create_future()
+            self._loop.call_later(1.0, lambda: None if f.done() else f.set_result(None))
+            await f
+
+    def start(self):
+        assert not self._running
+        self._running = True
+        if not self._process_task:
+            self._process_task = self._loop.create_task(self._process())
+
+    def stop(self):
+        assert self._running
+        self._running = False
+        if self._process_task:
+            self._process_task.cancel()
+            self._process_task = None
+        if self._next_task:
+            self._next_task.cancel()
+            self._next_task = None
+        if self._next_timer:
+            self._next_timer.cancel()
+            self._next_timer = None
+
+
+class KademliaProtocol(DatagramProtocol):
+    def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', node_id: bytes, external_ip: str,
+                 udp_port: int, peer_port: int, rpc_timeout: float = 5.0):
+        self.peer_manager = peer_manager
+        self.loop = loop
+        self.node_id = node_id
+        self.external_ip = external_ip
+        self.udp_port = udp_port
+        self.peer_port = peer_port
+        self.is_seed_node = False
+        self.partial_messages: typing.Dict[bytes, typing.Dict[bytes, bytes]] = {}
+        self.sent_messages: typing.Dict[bytes, typing.Tuple['KademliaPeer', asyncio.Future, RequestDatagram]] = {}
+        self.protocol_version = constants.protocol_version
+        self.started_listening_time = 0
+        self.transport: DatagramTransport = None
+        self.old_token_secret = constants.generate_id()
+        self.token_secret = constants.generate_id()
+        self.routing_table = TreeRoutingTable(self.loop, self.peer_manager, self.node_id)
+        self.data_store = DictDataStore(self.loop, self.peer_manager)
+        self.ping_queue = PingQueue(self.loop, self)
+        self.node_rpc = KademliaRPC(self, self.loop, self.peer_port)
+        self.lock = asyncio.Lock(loop=self.loop)
+        self.rpc_timeout = rpc_timeout
+        self._split_lock = asyncio.Lock(loop=self.loop)
+
+    def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
+        return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer)
+
+    def stop(self):
+        if self.transport:
+            self.disconnect()
+
+    def disconnect(self):
+        self.transport.close()
+
+    def connection_made(self, transport: DatagramTransport):
+        self.transport = transport
+
+    def connection_lost(self, exc):
+        self.stop()
+
+    @staticmethod
+    def _migrate_incoming_rpc_args(peer: 'KademliaPeer', method: bytes, *args) -> typing.Tuple[typing.Tuple,
+                                                                                               typing.Dict]:
+        if method == b'store' and peer.protocol_version == 0:
+            if isinstance(args[1], dict):
+                blob_hash = args[0]
+                token = args[1].pop(b'token', None)
+                port = args[1].pop(b'port', -1)
+                original_publisher_id = args[1].pop(b'lbryid', None)
+                age = 0
+                return (blob_hash, token, port, original_publisher_id, age), {}
+        return args, {}
+
+    async def _add_peer(self, peer: 'KademliaPeer'):
+        bucket_index = self.routing_table.kbucket_index(peer.node_id)
+        if self.routing_table.buckets[bucket_index].add_peer(peer):
+            return True
+        # The bucket is full; see if it can be split (by checking if its range includes the host node's node_id)
+        if self.routing_table.should_split(bucket_index, peer.node_id):
+            self.routing_table.split_bucket(bucket_index)
+            # Retry the insertion attempt
+            result = await self._add_peer(peer)
+            self.routing_table.join_buckets()
+            return result
+        else:
+            # We can't split the k-bucket
+            #
+            # The 13-page Kademlia paper specifies that the least recently contacted node in the bucket
+            # shall be pinged. If it fails to reply it is replaced with the new contact. If the ping is successful
+            # the new contact is ignored and not added to the bucket (sections 2.2 and 2.4).
+            #
+            # A reasonable extension to this is BEP 0005, which extends the above:
+            #
+            # Not all nodes that we learn about are equal. Some are "good" and some are not.
+            # Many nodes using the DHT are able to send queries and receive responses,
+            # but are not able to respond to queries from other nodes. It is important that
+            # each node's routing table contain only known good nodes. A good node is
+            # a node that has responded to one of our queries within the last 15 minutes. A node
+            # is also good if it has ever responded to one of our queries and has sent us a
+            # query within the last 15 minutes. After 15 minutes of inactivity, a node becomes
+            # questionable. Nodes become bad when they fail to respond to multiple queries
+            # in a row. Nodes that we know are good are given priority over nodes with unknown status.
+            #
+            # When there are bad or questionable nodes in the bucket, the least recent is selected for
+            # potential replacement (BEP 0005). When all nodes in the bucket are fresh, the head (least recent)
+            # contact is selected as described in section 2.2 of the Kademlia paper. In both cases the new contact
+            # is ignored if the pinged node replies.
+
+            not_good_contacts = self.routing_table.buckets[bucket_index].get_bad_or_unknown_peers()
+            not_recently_replied = []
+            for p in not_good_contacts:
+                last_replied = self.peer_manager.get_last_replied(p.address, p.udp_port)
+                if not last_replied or last_replied + 60 < self.loop.time():
+                    not_recently_replied.append(p)
+            if not_recently_replied:
+                to_replace = not_recently_replied[0]
+            else:
+                to_replace = self.routing_table.buckets[bucket_index].peers[0]
+                last_replied = self.peer_manager.get_last_replied(to_replace.address, to_replace.udp_port)
+                if last_replied and last_replied + 60 > self.loop.time():
+                    return False
+            log.debug("pinging %s:%s", to_replace.address, to_replace.udp_port)
+            try:
+                to_replace_rpc = self.get_rpc_peer(to_replace)
+                await to_replace_rpc.ping()
+                return False
+            except asyncio.TimeoutError:
+                log.debug("Replacing dead contact in bucket %i: %s:%i with %s:%i ", bucket_index,
+                          to_replace.address, to_replace.udp_port, peer.address, peer.udp_port)
+                if to_replace in self.routing_table.buckets[bucket_index]:
+                    self.routing_table.buckets[bucket_index].remove_peer(to_replace)
+                return await self._add_peer(peer)
+
+    async def add_peer(self, peer: 'KademliaPeer') -> bool:
+        if peer.node_id == self.node_id:
+            return False
+        async with self._split_lock:
+            return await self._add_peer(peer)
+
+    async def _handle_rpc(self, sender_contact: 'KademliaPeer', message: RequestDatagram):
+        assert sender_contact.node_id != self.node_id, (binascii.hexlify(sender_contact.node_id)[:8].decode(),
+                                                        binascii.hexlify(self.node_id)[:8].decode())
+        method = message.method
+        if method not in [b'ping', b'store', b'findNode', b'findValue']:
+            raise AttributeError('Invalid method: %s' % message.method.decode())
+        if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]:
+            # args don't need reformatting
+            a, kw = tuple(message.args[:-1]), message.args[-1]
+        else:
+            a, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
+        log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip,
self.udp_port, message.method.decode(), + sender_contact.address, sender_contact.udp_port) + + if method == b'ping': + result = self.node_rpc.ping() + elif method == b'store': + blob_hash, token, port, original_publisher_id, age = a + result = self.node_rpc.store(sender_contact, blob_hash, token, port, original_publisher_id, age) + elif method == b'findNode': + key, = a + result = self.node_rpc.find_node(sender_contact, key) + else: + assert method == b'findValue' + key, = a + result = self.node_rpc.find_value(sender_contact, key) + + await self.send_response( + sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result), + ) + + async def handle_request_datagram(self, address, request_datagram: RequestDatagram): + # This is an RPC method request + await self.peer_manager.report_last_requested(address[0], address[1]) + await self.peer_manager.update_contact_triple(request_datagram.node_id, address[0], address[1]) + # only add a requesting contact to the routing table if it has replied to one of our requests + peer = self.peer_manager.get_kademlia_peer(request_datagram.node_id, address[0], address[1]) + try: + await self._handle_rpc(peer, request_datagram) + # if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it + # will be added to our routing table if successful + is_good = self.peer_manager.peer_is_good(peer) + if is_good is None: + await self.ping_queue.enqueue_maybe_ping(peer) + elif is_good is True: + await self.add_peer(peer) + + except Exception as err: + log.warning("error raised handling %s request from %s:%i - %s(%s)", + request_datagram.method, peer.address, peer.udp_port, str(type(err)), + str(err)) + await self.send_error( + peer, + ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(), + str(err).encode()) + ) + + async def handle_response_datagram(self, address: typing.Tuple[str, int], response_datagram: ResponseDatagram): + # Find the message that triggered this response + if response_datagram.rpc_id in self.sent_messages: + peer, df, request = self.sent_messages[response_datagram.rpc_id] + if peer.address != address[0]: + df.set_exception(RemoteException( + f"response from {address[0]}:{address[1]}, " + f"expected {peer.address}:{peer.udp_port}") + ) + return + peer.set_id(response_datagram.node_id) + # We got a result from the RPC + if peer.node_id == self.node_id: + df.set_exception(RemoteException("node has our node id")) + return + elif response_datagram.node_id == self.node_id: + df.set_exception(RemoteException("incoming message is from our node id")) + return + await self.peer_manager.report_last_replied(address[0], address[1]) + await self.peer_manager.update_contact_triple(peer.node_id, address[0], address[1]) + if not df.cancelled(): + df.set_result(response_datagram) + await self.add_peer(peer) + else: + log.warning("%s:%i replied, but after we cancelled the request attempt", + peer.address, peer.udp_port) + else: + # If the original message isn't found, it must have timed out + # TODO: we should probably do something with this... 
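+                # one option (a sketch, not implemented here): still record the reply via
+                # peer_manager.report_last_replied, since a response that arrives after its
+                # future was popped from self.sent_messages on timeout still shows the
+                # peer is alive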
+ pass + + def handle_error_datagram(self, address, error_datagram: ErrorDatagram): + # The RPC request raised a remote exception; raise it locally + remote_exception = RemoteException(f"{error_datagram.exception_type}({error_datagram.response})") + if error_datagram.rpc_id in self.sent_messages: + peer, df, request = self.sent_messages.pop(error_datagram.rpc_id) + + error_msg = f"" \ + f"Error sending '{request.method}' to {peer.address}:{peer.udp_port}\n" \ + f"Args: {request.args}\n" \ + f"Raised: {str(remote_exception)}" + if error_datagram.response not in old_protocol_errors: + log.warning(error_msg) + else: + log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)", + peer.address, peer.udp_port, old_protocol_errors[error_datagram.response]) + + # reject replies coming from a different address than what we sent our request to + if (peer.address, peer.udp_port) != address: + log.error("node id mismatch in reply") + remote_exception = TimeoutError(peer.node_id) + df.set_exception(remote_exception) + return + else: + if error_datagram.response not in old_protocol_errors: + msg = f"Received error from {address[0]}:{address[1]}, but it isn't in response to a " \ + f"pending request: {str(remote_exception)}" + log.warning(msg) + else: + log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)", + address[0], address[1], old_protocol_errors[error_datagram.response]) + + def datagram_received(self, datagram: bytes, address: typing.Tuple[str, int]) -> None: + try: + message = decode_datagram(datagram) + except (ValueError, TypeError): + self.loop.create_task(self.peer_manager.report_failure(address[0], address[1])) + log.warning("Couldn't decode dht datagram from %s: %s", address, binascii.hexlify(datagram).decode()) + return + + if isinstance(message, RequestDatagram): + self.loop.create_task(self.handle_request_datagram(address, message)) + elif isinstance(message, ErrorDatagram): + self.handle_error_datagram(address, message) + else: + assert isinstance(message, ResponseDatagram), "sanity" + self.loop.create_task(self.handle_response_datagram(address, message)) + + async def send_request(self, peer: 'KademliaPeer', request: RequestDatagram) -> ResponseDatagram: + await self._send(peer, request) + response_fut = self.sent_messages[request.rpc_id][1] + try: + response = await asyncio.wait_for(response_fut, self.rpc_timeout) + await self.peer_manager.report_last_replied(peer.address, peer.udp_port) + return response + except (asyncio.TimeoutError, RemoteException): + await self.peer_manager.report_failure(peer.address, peer.udp_port) + if self.peer_manager.peer_is_good(peer) is False: + self.routing_table.remove_peer(peer) + raise + + async def send_response(self, peer: 'KademliaPeer', response: ResponseDatagram): + await self._send(peer, response) + + async def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram): + await self._send(peer, error) + + async def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, + ErrorDatagram]): + if not self.transport: + raise TransportNotConnected() + + data = message.bencode() + if len(data) > constants.msg_size_limit: + log.exception("unexpected: %i vs %i", len(data), constants.msg_size_limit) + raise ValueError() + if isinstance(message, (RequestDatagram, ResponseDatagram)): + assert message.node_id == self.node_id, message + if isinstance(message, RequestDatagram): + assert self.node_id != peer.node_id + + def pop_from_sent_messages(_): + if 
message.rpc_id in self.sent_messages: + self.sent_messages.pop(message.rpc_id) + + async with self.lock: + if isinstance(message, RequestDatagram): + response_fut = self.loop.create_future() + response_fut.add_done_callback(pop_from_sent_messages) + self.sent_messages[message.rpc_id] = (peer, response_fut, message) + try: + self.transport.sendto(data, (peer.address, peer.udp_port)) + except OSError as err: + # TODO: handle ENETUNREACH + if err.errno == socket.EWOULDBLOCK: + # i'm scared this may swallow important errors, but i get a million of these + # on Linux and it doesn't seem to affect anything -grin + log.warning("Can't send data to dht: EWOULDBLOCK") + else: + log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)", + len(data), peer.address, peer.udp_port, str(err), err.errno) + if isinstance(message, RequestDatagram): + self.sent_messages[message.rpc_id][1].set_exception(err) + else: + raise err + if isinstance(message, RequestDatagram): + await self.peer_manager.report_last_sent(peer.address, peer.udp_port) + elif isinstance(message, ErrorDatagram): + await self.peer_manager.report_failure(peer.address, peer.udp_port) + + def change_token(self): + self.old_token_secret = self.token_secret + self.token_secret = constants.generate_id() + + def make_token(self, compact_ip): + return constants.digest(self.token_secret + compact_ip) + + def verify_token(self, token, compact_ip): + h = constants.hash_class() + h.update(self.token_secret + compact_ip) + if self.old_token_secret and not token == h.digest(): # TODO: why should we be accepting the previous token? + h = constants.hash_class() + h.update(self.old_token_secret + compact_ip) + if not token == h.digest(): + return False + return True + + async def store_to_peer(self, hash_value: bytes, peer: 'KademliaPeer') -> typing.Tuple[bytes, bool]: + try: + res = await self.get_rpc_peer(peer).store(hash_value) + if res != b"OK": + raise ValueError(res) + log.info("Stored %s to %s", binascii.hexlify(hash_value).decode()[:8], peer) + return peer.node_id, True + except asyncio.TimeoutError: + log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer) + except ValueError as err: + log.error("Unexpected response: %s" % err) + except Exception as err: + if 'Invalid token' in str(err): + await self.peer_manager.clear_token(peer.node_id) + else: + log.exception("Unexpected error while storing blob_hash") + return peer.node_id, False + + def _write(self, data: bytes, address: typing.Tuple[str, int]): + if self.transport: + try: + self.transport.sendto(data, address) + except OSError as err: + if err.errno == socket.EWOULDBLOCK: + # i'm scared this may swallow important errors, but i get a million of these + # on Linux and it doesn't seem to affect anything -grin + log.warning("Can't send data to dht: EWOULDBLOCK") + # elif err.errno == socket.ENETUNREACH: + # # this should probably try to retransmit when the network connection is back + # log.error("Network is unreachable") + else: + log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)", + len(data), address[0], address[1], str(err), err.errno) + raise err + else: + raise TransportNotConnected() diff --git a/lbrynet/dht/protocol/routing_table.py b/lbrynet/dht/protocol/routing_table.py new file mode 100644 index 000000000..d57d2561b --- /dev/null +++ b/lbrynet/dht/protocol/routing_table.py @@ -0,0 +1,305 @@ +import asyncio +import random +import logging +import typing +import itertools + +from lbrynet.dht import constants 
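+# illustrative note: Distance (imported below) is the plain Kademlia XOR metric, e.g.
+#   Distance(key_a)(key_b) == int.from_bytes(key_a, 'big') ^ int.from_bytes(key_b, 'big')
+# and Distance(key).is_closer(a, b) is equivalent to Distance(key)(a) < Distance(key)(b)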
+from lbrynet.dht.protocol.distance import Distance
+if typing.TYPE_CHECKING:
+    from lbrynet.dht.peer import KademliaPeer, PeerManager
+
+log = logging.getLogger(__name__)
+
+
+class KBucket:
+    """ A k-bucket: an ordered list of up to constants.k peers covering a
+    contiguous range of the node ID space.
+    """
+
+    def __init__(self, peer_manager: 'PeerManager', range_min: int, range_max: int, node_id: bytes):
+        """
+        @param range_min: The lower boundary for the range in the n-bit ID
+                          space covered by this k-bucket
+        @param range_max: The upper boundary for the range in the ID space
+                          covered by this k-bucket
+        """
+        self._peer_manager = peer_manager
+        self.last_accessed = 0
+        self.range_min = range_min
+        self.range_max = range_max
+        self.peers: typing.List['KademliaPeer'] = []
+        self._node_id = node_id
+
+    def add_peer(self, peer: 'KademliaPeer') -> bool:
+        """ Add a contact to the contact list in the right order. This will move the
+        contact to the end of the k-bucket if it is already present.
+
+        @param peer: The contact to add
+        @type peer: 'KademliaPeer'
+        @return: C{True} if the contact was added or refreshed, C{False} if the
+                 bucket is full and the contact isn't already in it
+        """
+        if peer in self.peers:
+            # Move the existing contact to the end of the list
+            # - using the new contact to allow add-on data
+            #   (e.g. optimization-specific stuff) to be updated as well
+            self.peers.remove(peer)
+            self.peers.append(peer)
+            return True
+        elif len(self.peers) < constants.k:
+            self.peers.append(peer)
+            return True
+        else:
+            return False
+            # raise BucketFull("No space in bucket to insert contact")
+
+    def get_peer(self, node_id: bytes) -> 'KademliaPeer':
+        for peer in self.peers:
+            if peer.node_id == node_id:
+                return peer
+        raise IndexError(node_id)
+
+    def get_peers(self, count=-1, exclude_contact=None, sort_distance_to=None) -> typing.List['KademliaPeer']:
+        """ Returns a list containing up to the first count number of contacts
+
+        @param count: The amount of contacts to return (if 0 or less, return
+                      all contacts)
+        @type count: int
+        @param exclude_contact: A node_id to exclude; if a contact with this
+                                node_id is in the list of returned values, it
+                                will be discarded before returning
+        @type exclude_contact: str
+        @param sort_distance_to: Sort by distance to this node_id, defaulting to
+                                 the parent node's node_id. If C{False}, don't
+                                 sort the contacts.
+
+        @return: Up to the first count contacts, as a list. If no contacts are
+                 present, an empty list is returned.
+        @rtype: list
+        """
+        peers = [peer for peer in self.peers if peer.node_id != exclude_contact]
+
+        # Return all contacts in bucket
+        if count <= 0:
+            count = len(peers)
+
+        # Get current contact number
+        current_len = len(peers)
+
+        # If count greater than k - return only k contacts
+        if count > constants.k:
+            count = constants.k
+
+        if not current_len:
+            return peers
+
+        if sort_distance_to is False:
+            pass
+        else:
+            sort_distance_to = sort_distance_to or self._node_id
+            peers.sort(key=lambda c: Distance(sort_distance_to)(c.node_id))
+
+        return peers[:min(current_len, count)]
+
+    def get_bad_or_unknown_peers(self) -> typing.List['KademliaPeer']:
+        peers = self.get_peers(sort_distance_to=False)
+        return [
+            peer for peer in peers
+            if self._peer_manager.contact_triple_is_good(peer.node_id, peer.address, peer.udp_port) is not True
+        ]
+
+    def remove_peer(self, peer: 'KademliaPeer') -> None:
+        self.peers.remove(peer)
+
+    def key_in_range(self, key: bytes) -> bool:
+        """ Tests whether the specified key (i.e. node ID) is in the range
+        of the n-bit ID space covered by this k-bucket (in other words, it
+        returns whether or not the specified key should be placed in this
+        k-bucket)
+
+        @param key: The key to test
+        @type key: str or int
+
+        @return: C{True} if the key is in this k-bucket's range, or C{False}
+                 if not.
+        @rtype: bool
+        """
+        return self.range_min <= int.from_bytes(key, 'big') < self.range_max
+
+    def __len__(self) -> int:
+        return len(self.peers)
+
+    def __contains__(self, item) -> bool:
+        return item in self.peers
+
+
+class TreeRoutingTable:
+    """ This class implements a routing table used by a Node class.
+
+    The Kademlia routing table is a binary tree whose leaves are k-buckets,
+    where each k-bucket contains nodes with some common prefix of their IDs.
+    This prefix is the k-bucket's position in the binary tree; it therefore
+    covers some range of ID values, and together all of the k-buckets cover
+    the entire n-bit ID (or key) space (with no overlap).
+
+    @note: In this implementation, nodes in the tree (the k-buckets) are
+    added dynamically, as needed; this technique is described in the 13-page
+    version of the Kademlia paper, in section 2.4. It does, however, use the
+    ping RPC-based k-bucket eviction algorithm described in section 2.2 of
+    that paper.
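+
+    A full bucket is split only when it covers the range containing this node's
+    own ID or, as an extension for otherwise unbalanced tables, when the new
+    contact would be closer to this node than the k-th closest contact already
+    known (see should_split() below).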
+ """ + + def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', parent_node_id: bytes): + self._loop = loop + self._peer_manager = peer_manager + self._parent_node_id = parent_node_id + self.buckets: typing.List[KBucket] = [ + KBucket( + self._peer_manager, range_min=0, range_max=2 ** constants.hash_bits, node_id=self._parent_node_id + ) + ] + + def get_peers(self) -> typing.List['KademliaPeer']: + return list(itertools.chain.from_iterable(map(lambda bucket: bucket.peers, self.buckets))) + + def should_split(self, bucket_index: int, to_add: bytes) -> bool: + # https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456 + if self.buckets[bucket_index].key_in_range(self._parent_node_id): + return True + contacts = self.get_peers() + distance = Distance(self._parent_node_id) + contacts.sort(key=lambda c: distance(c.node_id)) + kth_contact = contacts[-1] if len(contacts) < constants.k else contacts[constants.k - 1] + return distance(to_add) < distance(kth_contact.node_id) + + def find_close_peers(self, key: bytes, count: typing.Optional[int] = None, + sender_node_id: typing.Optional[bytes] = None) -> typing.List['KademliaPeer']: + exclude = [self._parent_node_id] + if sender_node_id: + exclude.append(sender_node_id) + if key in exclude: + exclude.remove(key) + count = count or constants.k + distance = Distance(key) + contacts = self.get_peers() + contacts = [c for c in contacts if c.node_id not in exclude] + if contacts: + contacts.sort(key=lambda c: distance(c.node_id)) + return contacts[:min(count, len(contacts))] + return [] + + def get_peer(self, contact_id: bytes) -> 'KademliaPeer': + """ + @raise IndexError: No contact with the specified contact ID is known + by this node + """ + return self.buckets[self.kbucket_index(contact_id)].get_peer(contact_id) + + def get_refresh_list(self, start_index: int = 0, force: bool = False) -> typing.List[bytes]: + bucket_index = start_index + refresh_ids = [] + now = int(self._loop.time()) + for bucket in self.buckets[start_index:]: + if force or now - bucket.last_accessed >= constants.refresh_interval: + to_search = self.midpoint_id_in_bucket_range(bucket_index) + refresh_ids.append(to_search) + bucket_index += 1 + return refresh_ids + + def remove_peer(self, peer: 'KademliaPeer') -> None: + if not peer.node_id: + return + bucket_index = self.kbucket_index(peer.node_id) + try: + self.buckets[bucket_index].remove_peer(peer) + except ValueError: + return + + def touch_kbucket(self, key: bytes) -> None: + self.touch_kbucket_by_index(self.kbucket_index(key)) + + def touch_kbucket_by_index(self, bucket_index: int): + self.buckets[bucket_index].last_accessed = int(self._loop.time()) + + def kbucket_index(self, key: bytes) -> int: + i = 0 + for bucket in self.buckets: + if bucket.key_in_range(key): + return i + else: + i += 1 + return i + + def random_id_in_bucket_range(self, bucket_index: int) -> bytes: + random_id = int(random.randrange(self.buckets[bucket_index].range_min, self.buckets[bucket_index].range_max)) + return random_id.to_bytes(constants.hash_length, 'big') + + def midpoint_id_in_bucket_range(self, bucket_index: int) -> bytes: + half = int((self.buckets[bucket_index].range_max - self.buckets[bucket_index].range_min) // 2) + return int(self.buckets[bucket_index].range_min + half).to_bytes(constants.hash_length, 'big') + + def split_bucket(self, old_bucket_index: int) -> None: + """ Splits the specified k-bucket into two new buckets which together + cover the same range in the 
key/ID space + + @param old_bucket_index: The index of k-bucket to split (in this table's + list of k-buckets) + @type old_bucket_index: int + """ + # Resize the range of the current (old) k-bucket + old_bucket = self.buckets[old_bucket_index] + split_point = old_bucket.range_max - (old_bucket.range_max - old_bucket.range_min) // 2 + # Create a new k-bucket to cover the range split off from the old bucket + new_bucket = KBucket(self._peer_manager, split_point, old_bucket.range_max, self._parent_node_id) + old_bucket.range_max = split_point + # Now, add the new bucket into the routing table tree + self.buckets.insert(old_bucket_index + 1, new_bucket) + # Finally, copy all nodes that belong to the new k-bucket into it... + for contact in old_bucket.peers: + if new_bucket.key_in_range(contact.node_id): + new_bucket.add_peer(contact) + # ...and remove them from the old bucket + for contact in new_bucket.peers: + old_bucket.remove_peer(contact) + + def join_buckets(self): + to_pop = [i for i, bucket in enumerate(self.buckets) if not len(bucket)] + if not to_pop: + return + log.info("join buckets %i", len(to_pop)) + bucket_index_to_pop = to_pop[0] + assert len(self.buckets[bucket_index_to_pop]) == 0 + can_go_lower = bucket_index_to_pop - 1 >= 0 + can_go_higher = bucket_index_to_pop + 1 < len(self.buckets) + assert can_go_higher or can_go_lower + bucket = self.buckets[bucket_index_to_pop] + if can_go_lower and can_go_higher: + midpoint = ((bucket.range_max - bucket.range_min) // 2) + bucket.range_min + self.buckets[bucket_index_to_pop - 1].range_max = midpoint - 1 + self.buckets[bucket_index_to_pop + 1].range_min = midpoint + elif can_go_lower: + self.buckets[bucket_index_to_pop - 1].range_max = bucket.range_max + elif can_go_higher: + self.buckets[bucket_index_to_pop + 1].range_min = bucket.range_min + self.buckets.remove(bucket) + return self.join_buckets() + + def contact_in_routing_table(self, address_tuple: typing.Tuple[str, int]) -> bool: + for bucket in self.buckets: + for contact in bucket.get_peers(sort_distance_to=False): + if address_tuple[0] == contact.address and address_tuple[1] == contact.udp_port: + return True + return False + + def buckets_with_contacts(self) -> int: + count = 0 + for bucket in self.buckets: + if len(bucket): + count += 1 + return count diff --git a/lbrynet/dht/routingtable.py b/lbrynet/dht/routingtable.py deleted file mode 100644 index c99003d33..000000000 --- a/lbrynet/dht/routingtable.py +++ /dev/null @@ -1,320 +0,0 @@ -# This library is free software, distributed under the terms of -# the GNU Lesser General Public License Version 3, or any later version. -# See the COPYING file included in this archive -# -# The docstrings in this module contain epytext markup; API documentation -# may be created by processing this file with epydoc: http://epydoc.sf.net - -import random -import logging -from twisted.internet import defer - -from lbrynet.dht import constants, kbucket -from lbrynet.dht.error import TimeoutError -from lbrynet.dht.distance import Distance - -log = logging.getLogger(__name__) - - -class TreeRoutingTable: - """ This class implements a routing table used by a Node class. - - The Kademlia routing table is a binary tree whFose leaves are k-buckets, - where each k-bucket contains nodes with some common prefix of their IDs. - This prefix is the k-bucket's position in the binary tree; it therefore - covers some range of ID values, and together all of the k-buckets cover - the entire n-bit ID (or key) space (with no overlap). 
- - @note: In this implementation, nodes in the tree (the k-buckets) are - added dynamically, as needed; this technique is described in the 13-page - version of the Kademlia paper, in section 2.4. It does, however, use the - C{PING} RPC-based k-bucket eviction algorithm described in section 2.2 of - that paper. - """ - #implements(IRoutingTable) - - def __init__(self, parentNodeID, getTime=None): - """ - @param parentNodeID: The n-bit node ID of the node to which this - routing table belongs - @type parentNodeID: str - """ - # Create the initial (single) k-bucket covering the range of the entire n-bit ID space - self._parentNodeID = parentNodeID - self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits, node_id=self._parentNodeID)] - if not getTime: - from twisted.internet import reactor - getTime = reactor.seconds - self._getTime = getTime - self._ongoing_replacements = set() - - def get_contacts(self): - contacts = [] - for i in range(len(self._buckets)): - for contact in self._buckets[i]._contacts: - contacts.append(contact) - return contacts - - def _shouldSplit(self, bucketIndex, toAdd): - # https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456 - if self._buckets[bucketIndex].keyInRange(self._parentNodeID): - return True - contacts = self.get_contacts() - distance = Distance(self._parentNodeID) - contacts.sort(key=lambda c: distance(c.id)) - kth_contact = contacts[-1] if len(contacts) < constants.k else contacts[constants.k-1] - return distance(toAdd) < distance(kth_contact.id) - - def addContact(self, contact): - """ Add the given contact to the correct k-bucket; if it already - exists, its status will be updated - - @param contact: The contact to add to this node's k-buckets - @type contact: kademlia.contact.Contact - - @rtype: defer.Deferred - """ - - if contact.id == self._parentNodeID: - return defer.succeed(None) - bucketIndex = self._kbucketIndex(contact.id) - try: - self._buckets[bucketIndex].addContact(contact) - except kbucket.BucketFull: - # The bucket is full; see if it can be split (by checking if its range includes the host node's id) - if self._shouldSplit(bucketIndex, contact.id): - self._splitBucket(bucketIndex) - # Retry the insertion attempt - return self.addContact(contact) - else: - # We can't split the k-bucket - # - # The 13 page kademlia paper specifies that the least recently contacted node in the bucket - # shall be pinged. If it fails to reply it is replaced with the new contact. If the ping is successful - # the new contact is ignored and not added to the bucket (sections 2.2 and 2.4). - # - # A reasonable extension to this is BEP 0005, which extends the above: - # - # Not all nodes that we learn about are equal. Some are "good" and some are not. - # Many nodes using the DHT are able to send queries and receive responses, - # but are not able to respond to queries from other nodes. It is important that - # each node's routing table must contain only known good nodes. A good node is - # a node has responded to one of our queries within the last 15 minutes. A node - # is also good if it has ever responded to one of our queries and has sent us a - # query within the last 15 minutes. After 15 minutes of inactivity, a node becomes - # questionable. Nodes become bad when they fail to respond to multiple queries - # in a row. Nodes that we know are good are given priority over nodes with unknown status. 
- # - # When there are bad or questionable nodes in the bucket, the least recent is selected for - # potential replacement (BEP 0005). When all nodes in the bucket are fresh, the head (least recent) - # contact is selected as described in section 2.2 of the kademlia paper. In both cases the new contact - # is ignored if the pinged node replies. - - def replaceContact(failure, deadContact): - """ - Callback for the deferred PING RPC to see if the node to be replaced in the k-bucket is still - responding - - @type failure: twisted.python.failure.Failure - """ - failure.trap(TimeoutError) - log.debug("Replacing dead contact in bucket %i: %s:%i (%s) with %s:%i (%s)", bucketIndex, - deadContact.address, deadContact.port, deadContact.log_id(), contact.address, - contact.port, contact.log_id()) - try: - self._buckets[bucketIndex].removeContact(deadContact) - except ValueError: - # The contact has already been removed (probably due to a timeout) - pass - return self.addContact(contact) - - not_good_contacts = self._buckets[bucketIndex].getBadOrUnknownContacts() - if not_good_contacts: - to_replace = not_good_contacts[0] - else: - to_replace = self._buckets[bucketIndex]._contacts[0] - if to_replace not in self._ongoing_replacements: - log.debug("pinging %s:%s", to_replace.address, to_replace.port) - self._ongoing_replacements.add(to_replace) - df = to_replace.ping() - df.addErrback(replaceContact, to_replace) - df.addBoth(lambda _: self._ongoing_replacements.remove(to_replace)) - else: - df = defer.succeed(None) - return df - else: - self.touchKBucketByIndex(bucketIndex) - return defer.succeed(None) - - def findCloseNodes(self, key, count=None, sender_node_id=None): - """ Finds a number of known nodes closest to the node/value with the - specified key. - - @param key: the n-bit key (i.e. the node or value ID) to search for - @type key: str - @param count: the amount of contacts to return, default of k (8) - @type count: int - @param sender_node_id: Used during RPC, this is be the sender's Node ID - Whatever ID is passed in the parameter will get - excluded from the list of returned contacts. - @type sender_node_id: str - - @return: A list of node contacts (C{kademlia.contact.Contact instances}) - closest to the specified key. - This method will return C{k} (or C{count}, if specified) - contacts if at all possible; it will only return fewer if the - node is returning all of the contacts that it knows of. - @rtype: list - """ - exclude = [self._parentNodeID] - if sender_node_id: - exclude.append(sender_node_id) - if key in exclude: - exclude.remove(key) - count = count or constants.k - distance = Distance(key) - contacts = self.get_contacts() - contacts = [c for c in contacts if c.id not in exclude] - contacts.sort(key=lambda c: distance(c.id)) - return contacts[:min(count, len(contacts))] - - def getContact(self, contactID): - """ Returns the (known) contact with the specified node ID - - @raise ValueError: No contact with the specified contact ID is known - by this node - """ - bucketIndex = self._kbucketIndex(contactID) - return self._buckets[bucketIndex].getContact(contactID) - - def getRefreshList(self, startIndex=0, force=False): - """ Finds all k-buckets that need refreshing, starting at the - k-bucket with the specified index, and returns IDs to be searched for - in order to refresh those k-buckets - - @param startIndex: The index of the bucket to start refreshing at; - this bucket and those further away from it will - be refreshed. 
For example, when joining the - network, this node will set this to the index of - the bucket after the one containing it's closest - neighbour. - @type startIndex: index - @param force: If this is C{True}, all buckets (in the specified range) - will be refreshed, regardless of the time they were last - accessed. - @type force: bool - - @return: A list of node ID's that the parent node should search for - in order to refresh the routing Table - @rtype: list - """ - bucketIndex = startIndex - refreshIDs = [] - now = int(self._getTime()) - for bucket in self._buckets[startIndex:]: - if force or now - bucket.lastAccessed >= constants.refreshTimeout: - searchID = self.midpoint_id_in_bucket_range(bucketIndex) - refreshIDs.append(searchID) - bucketIndex += 1 - return refreshIDs - - def removeContact(self, contact): - """ - Remove the contact from the routing table - - @param contact: The contact to remove - @type contact: dht.contact._Contact - """ - bucketIndex = self._kbucketIndex(contact.id) - try: - self._buckets[bucketIndex].removeContact(contact) - except ValueError: - return - - def touchKBucket(self, key): - """ Update the "last accessed" timestamp of the k-bucket which covers - the range containing the specified key in the key/ID space - - @param key: A key in the range of the target k-bucket - @type key: str - """ - self.touchKBucketByIndex(self._kbucketIndex(key)) - - def touchKBucketByIndex(self, bucketIndex): - self._buckets[bucketIndex].lastAccessed = int(self._getTime()) - - def _kbucketIndex(self, key): - """ Calculate the index of the k-bucket which is responsible for the - specified key (or ID) - - @param key: The key for which to find the appropriate k-bucket index - @type key: str - - @return: The index of the k-bucket responsible for the specified key - @rtype: int - """ - i = 0 - for bucket in self._buckets: - if bucket.keyInRange(key): - return i - else: - i += 1 - return i - - def random_id_in_bucket_range(self, bucketIndex): - """ Returns a random ID in the specified k-bucket's range - - @param bucketIndex: The index of the k-bucket to use - @type bucketIndex: int - """ - - random_id = int(random.randrange(self._buckets[bucketIndex].rangeMin, self._buckets[bucketIndex].rangeMax)) - return random_id.to_bytes(constants.key_bits // 8, 'big') - - def midpoint_id_in_bucket_range(self, bucketIndex): - """ Returns the middle ID in the specified k-bucket's range - - @param bucketIndex: The index of the k-bucket to use - @type bucketIndex: int - """ - - half = int((self._buckets[bucketIndex].rangeMax - self._buckets[bucketIndex].rangeMin) // 2) - return int(self._buckets[bucketIndex].rangeMin + half).to_bytes(constants.key_bits // 8, 'big') - - def _splitBucket(self, oldBucketIndex): - """ Splits the specified k-bucket into two new buckets which together - cover the same range in the key/ID space - - @param oldBucketIndex: The index of k-bucket to split (in this table's - list of k-buckets) - @type oldBucketIndex: int - """ - # Resize the range of the current (old) k-bucket - oldBucket = self._buckets[oldBucketIndex] - splitPoint = oldBucket.rangeMax - (oldBucket.rangeMax - oldBucket.rangeMin) / 2 - # Create a new k-bucket to cover the range split off from the old bucket - newBucket = kbucket.KBucket(splitPoint, oldBucket.rangeMax, self._parentNodeID) - oldBucket.rangeMax = splitPoint - # Now, add the new bucket into the routing table tree - self._buckets.insert(oldBucketIndex + 1, newBucket) - # Finally, copy all nodes that belong to the new k-bucket into it... 
-
-    def contactInRoutingTable(self, address_tuple):
-        for bucket in self._buckets:
-            for contact in bucket.getContacts(sort_distance_to=False):
-                if address_tuple[0] == contact.address and address_tuple[1] == contact.port:
-                    return True
-        return False
-
-    def bucketsWithContacts(self):
-        count = 0
-        for bucket in self._buckets:
-            if len(bucket):
-                count += 1
-        return count
diff --git a/lbrynet/dht/serialization/__init__.py b/lbrynet/dht/serialization/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lbrynet/dht/encoding.py b/lbrynet/dht/serialization/bencoding.py
similarity index 54%
rename from lbrynet/dht/encoding.py
rename to lbrynet/dht/serialization/bencoding.py
index 4eb4d0764..5efdf969d 100644
--- a/lbrynet/dht/encoding.py
+++ b/lbrynet/dht/serialization/bencoding.py
@@ -1,8 +1,8 @@
+import typing
 from lbrynet.dht.error import DecodeError
 
 
-def bencode(data):
-    """ Encoder implementation of the Bencode algorithm (Bittorrent). """
+def _bencode(data: typing.Union[int, bytes, bytearray, str, list, tuple, dict]) -> bytes:
     if isinstance(data, int):
         return b'i%de' % data
     elif isinstance(data, (bytes, bytearray)):
@@ -12,31 +12,20 @@ def bencode(data):
     elif isinstance(data, (list, tuple)):
         encoded_list_items = b''
         for item in data:
-            encoded_list_items += bencode(item)
+            encoded_list_items += _bencode(item)
         return b'l%se' % encoded_list_items
     elif isinstance(data, dict):
         encoded_dict_items = b''
         keys = data.keys()
         for key in sorted(keys):
-            encoded_dict_items += bencode(key)
-            encoded_dict_items += bencode(data[key])
+            encoded_dict_items += _bencode(key)
+            encoded_dict_items += _bencode(data[key])
         return b'd%se' % encoded_dict_items
     else:
-        raise TypeError("Cannot bencode '%s' object" % type(data))
+        raise TypeError(f"Cannot bencode {type(data)}")
 
 
-def bdecode(data):
-    """ Decoder implementation of the Bencode algorithm. """
""" - assert type(data) == bytes # fixme: _maybe_ remove this after porting - if len(data) == 0: - raise DecodeError('Cannot decode empty string') - try: - return _decode_recursive(data)[0] - except ValueError as e: - raise DecodeError(str(e)) - - -def _decode_recursive(data, start_index=0): +def _bdecode(data: bytes, start_index: int = 0) -> typing.Tuple[typing.Union[int, bytes, list, tuple, dict], int]: if data[start_index] == ord('i'): end_pos = data[start_index:].find(b'e') + start_index return int(data[start_index + 1:end_pos]), end_pos + 1 @@ -44,32 +33,44 @@ def _decode_recursive(data, start_index=0): start_index += 1 decoded_list = [] while data[start_index] != ord('e'): - list_data, start_index = _decode_recursive(data, start_index) + list_data, start_index = _bdecode(data, start_index) decoded_list.append(list_data) return decoded_list, start_index + 1 elif data[start_index] == ord('d'): start_index += 1 decoded_dict = {} while data[start_index] != ord('e'): - key, start_index = _decode_recursive(data, start_index) - value, start_index = _decode_recursive(data, start_index) + key, start_index = _bdecode(data, start_index) + value, start_index = _bdecode(data, start_index) decoded_dict[key] = value return decoded_dict, start_index - elif data[start_index] == ord('f'): - # This (float data type) is a non-standard extension to the original Bencode algorithm - end_pos = data[start_index:].find(b'e') + start_index - return float(data[start_index + 1:end_pos]), end_pos + 1 - elif data[start_index] == ord('n'): - # This (None/NULL data type) is a non-standard extension - # to the original Bencode algorithm - return None, start_index + 1 else: split_pos = data[start_index:].find(b':') + start_index try: length = int(data[start_index:split_pos]) - except ValueError: - raise DecodeError() + except (ValueError, TypeError) as err: + raise DecodeError(err) start_index = split_pos + 1 end_pos = start_index + length b = data[start_index:end_pos] return b, end_pos + + +def bencode(data: typing.Dict) -> bytes: + if not isinstance(data, dict): + raise TypeError() + return _bencode(data) + + +def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict: + assert type(data) == bytes, DecodeError(f"invalid data type: {str(type(data))}") + + if len(data) == 0: + raise DecodeError('Cannot decode empty string') + try: + result = _bdecode(data)[0] + if not allow_non_dict_return and not isinstance(result, dict): + raise ValueError(f'expected dict, got {type(result)}') + return result + except (ValueError, TypeError) as err: + raise DecodeError(err) diff --git a/lbrynet/dht/serialization/datagram.py b/lbrynet/dht/serialization/datagram.py new file mode 100644 index 000000000..81aa1b132 --- /dev/null +++ b/lbrynet/dht/serialization/datagram.py @@ -0,0 +1,181 @@ +import typing +from functools import reduce +from lbrynet.dht import constants +from lbrynet.dht.serialization.bencoding import bencode, bdecode + +REQUEST_TYPE = 0 +RESPONSE_TYPE = 1 +ERROR_TYPE = 2 + + +class KademliaDatagramBase: + """ + field names are used to unwrap/wrap the argument names to index integers that replace them in a datagram + all packets have an argument dictionary when bdecoded starting with {0: , 1: , 2: , ...} + these correspond to the packet_type, rpc_id, and node_id args + """ + + fields = [ + 'packet_type', + 'rpc_id', + 'node_id' + ] + + expected_packet_type = -1 + + def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes): + self.packet_type = packet_type + if 
diff --git a/lbrynet/dht/serialization/datagram.py b/lbrynet/dht/serialization/datagram.py
new file mode 100644
index 000000000..81aa1b132
--- /dev/null
+++ b/lbrynet/dht/serialization/datagram.py
@@ -0,0 +1,181 @@
+import typing
+from functools import reduce
+from lbrynet.dht import constants
+from lbrynet.dht.serialization.bencoding import bencode, bdecode
+
+REQUEST_TYPE = 0
+RESPONSE_TYPE = 1
+ERROR_TYPE = 2
+
+
+class KademliaDatagramBase:
+    """
+    Field names map constructor argument names to the integer keys that replace them in an
+    encoded datagram: every packet, when bdecoded, is an argument dictionary of the form
+    {0: , 1: , 2: , ...}, where the first three indexes correspond to the packet_type,
+    rpc_id, and node_id args.
+    """
+
+    fields = [
+        'packet_type',
+        'rpc_id',
+        'node_id'
+    ]
+
+    expected_packet_type = -1
+
+    def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes):
+        self.packet_type = packet_type
+        if self.expected_packet_type != packet_type:
+            raise ValueError(f"invalid packet type: {packet_type}, expected {self.expected_packet_type}")
+        if len(rpc_id) != constants.rpc_id_length:
+            raise ValueError(f"invalid rpc id: {len(rpc_id)} bytes (expected {constants.rpc_id_length})")
+        if len(node_id) != constants.hash_length:
+            raise ValueError(f"invalid node id: {len(node_id)} bytes (expected {constants.hash_length})")
+        self.rpc_id = rpc_id
+        self.node_id = node_id
+
+    def bencode(self) -> bytes:
+        return bencode({
+            i: getattr(self, k) for i, k in enumerate(self.fields)
+        })
+
+
+class RequestDatagram(KademliaDatagramBase):
+    fields = [
+        'packet_type',
+        'rpc_id',
+        'node_id',
+        'method',
+        'args'
+    ]
+
+    expected_packet_type = REQUEST_TYPE
+
+    def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, method: bytes,
+                 args: typing.Optional[typing.List] = None):
+        super().__init__(packet_type, rpc_id, node_id)
+        self.method = method
+        self.args = args or []
+        if not self.args:
+            self.args.append({})
+        if isinstance(self.args[-1], dict):
+            self.args[-1][b'protocolVersion'] = 1
+        else:
+            self.args.append({b'protocolVersion': 1})
+
+    @classmethod
+    def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
+        if rpc_id and len(rpc_id) != constants.rpc_id_length:
+            raise ValueError("invalid rpc id length")
+        elif not rpc_id:
+            rpc_id = constants.generate_id()[:constants.rpc_id_length]
+        if len(from_node_id) != constants.hash_bits // 8:
+            raise ValueError("invalid node id")
+        return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping')
+
+    @classmethod
+    def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int,
+                   rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
+        if rpc_id and len(rpc_id) != constants.rpc_id_length:
+            raise ValueError("invalid rpc id length")
+        if not rpc_id:
+            rpc_id = constants.generate_id()[:constants.rpc_id_length]
+        if len(from_node_id) != constants.hash_bits // 8:
+            raise ValueError("invalid node id")
+        store_args = [blob_hash, token, port, from_node_id, 0]
+        return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args)
+
+    @classmethod
+    def make_find_node(cls, from_node_id: bytes, key: bytes,
+                       rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
+        if rpc_id and len(rpc_id) != constants.rpc_id_length:
+            raise ValueError("invalid rpc id length")
+        if not rpc_id:
+            rpc_id = constants.generate_id()[:constants.rpc_id_length]
+        if len(from_node_id) != constants.hash_bits // 8:
+            raise ValueError("invalid node id")
+        return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key])
+
+    @classmethod
+    def make_find_value(cls, from_node_id: bytes, key: bytes,
+                        rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
+        if rpc_id and len(rpc_id) != constants.rpc_id_length:
+            raise ValueError("invalid rpc id length")
+        if not rpc_id:
+            rpc_id = constants.generate_id()[:constants.rpc_id_length]
+        if len(from_node_id) != constants.hash_bits // 8:
+            raise ValueError("invalid node id")
+        return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
+
+
+class ResponseDatagram(KademliaDatagramBase):
+    fields = [
+        'packet_type',
+        'rpc_id',
+        'node_id',
+        'response'
+    ]
+
+    expected_packet_type = RESPONSE_TYPE
+
+    def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, response):
+        super().__init__(packet_type, rpc_id, node_id)
+        self.response = response
+
+
+class ErrorDatagram(KademliaDatagramBase):
+    fields = [
+        'packet_type',
+        'rpc_id',
+        'node_id',
+        'exception_type',
+        'response',
+    ]
+
+    expected_packet_type = ERROR_TYPE
+
+    def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, exception_type: bytes, response: bytes):
+        super().__init__(packet_type, rpc_id, node_id)
+        self.exception_type = exception_type.decode()
+        self.response = response.decode()
+
+
+def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
+    msg_types = {
+        REQUEST_TYPE: RequestDatagram,
+        RESPONSE_TYPE: ResponseDatagram,
+        ERROR_TYPE: ErrorDatagram
+    }
+
+    primitive: typing.Dict = bdecode(datagram)
+    if not isinstance(primitive, dict):
+        raise ValueError("invalid datagram type")
+    if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]:  # pylint: disable=unsubscriptable-object
+        datagram_type = primitive[0]  # pylint: disable=unsubscriptable-object
+    else:
+        raise ValueError("invalid datagram type")
+    datagram_class = msg_types[datagram_type]
+    return datagram_class(**{
+            k: primitive[i]  # pylint: disable=unsubscriptable-object
+            for i, k in enumerate(datagram_class.fields)
+            if i in primitive  # pylint: disable=unsupported-membership-test
+        }
+    )
+
+
+def make_compact_ip(address: str) -> bytearray:
+    return reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
+
+
+def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
+    compact_ip = make_compact_ip(address)
+    if not 0 <= port <= 65535:
+        raise ValueError(f'Invalid port: {port}')
+    return compact_ip + port.to_bytes(2, 'big') + node_id
+
+
+def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, int]:
+    address = "{}.{}.{}.{}".format(*compact_address[:4])
+    port = int.from_bytes(compact_address[4:6], 'big')
+    node_id = compact_address[6:]
+    return node_id, address, port
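To make the index-for-name scheme and the compact address helpers concrete, here is a round-trip sketch. It assumes constants.rpc_id_length == 20 and constants.hash_length == 48, matching the length checks above; the IDs and port are illustrative:

```python
# Round-trip sketch for the datagram wrappers above. Assumes 20-byte rpc
# ids and 48-byte node ids, per the checks in KademliaDatagramBase.
from lbrynet.dht.serialization.datagram import (
    RequestDatagram, decode_datagram, make_compact_address, decode_compact_address
)

ping = RequestDatagram.make_ping(from_node_id=b'1' * 48, rpc_id=b'2' * 20)
decoded = decode_datagram(ping.bencode())
assert (decoded.method, decoded.rpc_id) == (b'ping', b'2' * 20)
# the constructor always stamps a protocolVersion dict onto the args
assert decoded.args == [{b'protocolVersion': 1}]

# Compact addresses pack IPv4 + port + node id into 4 + 2 + 48 bytes
compact = make_compact_address(b'3' * 48, '127.0.0.1', 4444)
assert len(compact) == 54
assert decode_compact_address(bytes(compact)) == (b'3' * 48, '127.0.0.1', 4444)
```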
diff --git a/lbrynet/extras/daemon/HashAnnouncer.py b/lbrynet/extras/daemon/HashAnnouncer.py
deleted file mode 100644
index e1352aa78..000000000
--- a/lbrynet/extras/daemon/HashAnnouncer.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import binascii
-import logging
-
-from twisted.internet import defer, task
-from lbrynet.extras.compat import f2d
-from lbrynet import utils
-from lbrynet.conf import Config
-
-log = logging.getLogger(__name__)
-
-
-class DHTHashAnnouncer:
-    def __init__(self, conf: Config, dht_node, storage, concurrent_announcers=None):
-        self.conf = conf
-        self.dht_node = dht_node
-        self.storage = storage
-        self.clock = dht_node.clock
-        self.peer_port = dht_node.peerPort
-        self.hash_queue = []
-        if concurrent_announcers is None:
-            self.concurrent_announcers = conf.concurrent_announcers
-        else:
-            self.concurrent_announcers = concurrent_announcers
-        self._manage_lc = None
-        if self.concurrent_announcers:
-            self._manage_lc = task.LoopingCall(self.manage)
-            self._manage_lc.clock = self.clock
-        self.sem = defer.DeferredSemaphore(self.concurrent_announcers or conf.concurrent_announcers or 1)
-
-    def start(self):
-        if self._manage_lc:
-            self._manage_lc.start(30)
-
-    def stop(self):
-        if self._manage_lc and self._manage_lc.running:
-            self._manage_lc.stop()
-
-    @defer.inlineCallbacks
-    def do_store(self, blob_hash):
-        storing_node_ids = yield self.dht_node.announceHaveBlob(binascii.unhexlify(blob_hash))
-        now = self.clock.seconds()
-        if storing_node_ids:
-            result = (now, storing_node_ids)
-            yield f2d(self.storage.update_last_announced_blob(blob_hash, now))
-            log.debug("Stored %s to %i peers", blob_hash[:16], len(storing_node_ids))
-        else:
-            result = (None, [])
-        self.hash_queue.remove(blob_hash)
-        defer.returnValue(result)
-
-    def hash_queue_size(self):
-        return
len(self.hash_queue) - - def _show_announce_progress(self, size, start): - queue_size = len(self.hash_queue) - average_blobs_per_second = float(size - queue_size) / (self.clock.seconds() - start) - log.info("Announced %i/%i blobs, %f blobs per second", size - queue_size, size, average_blobs_per_second) - - @defer.inlineCallbacks - def immediate_announce(self, blob_hashes): - self.hash_queue.extend(b for b in blob_hashes if b not in self.hash_queue) - log.info("Announcing %i blobs", len(self.hash_queue)) - start = self.clock.seconds() - progress_lc = task.LoopingCall(self._show_announce_progress, len(self.hash_queue), start) - progress_lc.clock = self.clock - progress_lc.start(60, now=False) - results = yield utils.DeferredDict( - {blob_hash: self.sem.run(self.do_store, blob_hash) for blob_hash in blob_hashes} - ) - now = self.clock.seconds() - - progress_lc.stop() - - announced_to = [blob_hash for blob_hash in results if results[blob_hash][0]] - if len(announced_to) != len(results): - log.debug("Failed to announce %i blobs", len(results) - len(announced_to)) - if announced_to: - log.info('Took %s seconds to announce %i of %i attempted hashes (%f hashes per second)', - now - start, len(announced_to), len(blob_hashes), - int(float(len(blob_hashes)) / float(now - start))) - defer.returnValue(results) - - @defer.inlineCallbacks - def manage(self): - if not self.dht_node.contacts: - log.info("Not ready to start announcing hashes") - return - need_reannouncement = yield f2d(self.storage.get_blobs_to_announce()) - if need_reannouncement: - yield self.immediate_announce(need_reannouncement) - else: - log.debug("Nothing to announce") diff --git a/lbrynet/extras/daemon/PeerFinder.py b/lbrynet/extras/daemon/PeerFinder.py deleted file mode 100644 index e0997807e..000000000 --- a/lbrynet/extras/daemon/PeerFinder.py +++ /dev/null @@ -1,70 +0,0 @@ -import binascii -import logging - -from twisted.internet import defer - -log = logging.getLogger(__name__) - - -class DummyPeerFinder: - """This class finds peers which have announced to the DHT that they have certain blobs""" - - def find_peers_for_blob(self, blob_hash, timeout=None, filter_self=True): - return defer.succeed([]) - - -class DHTPeerFinder(DummyPeerFinder): - """This class finds peers which have announced to the DHT that they have certain blobs""" - #implements(IPeerFinder) - - def __init__(self, component_manager): - """ - component_manager - an instance of ComponentManager - """ - self.component_manager = component_manager - self.peer_manager = component_manager.peer_manager - self.peers = {} - self._ongoing_searchs = {} - - def find_peers_for_blob(self, blob_hash, timeout=None, filter_self=True): - """ - Find peers for blob in the DHT - blob_hash (str): blob hash to look for - timeout (int): seconds to timeout after - filter_self (bool): if True, and if a peer for a blob is itself, filter it - from the result - - Returns: - list of peers for the blob - """ - if "dht" in self.component_manager.skip_components: - return defer.succeed([]) - if not self.component_manager.all_components_running("dht"): - return defer.succeed([]) - dht_node = self.component_manager.get_component("dht") - - self.peers.setdefault(blob_hash, {(dht_node.externalIP, dht_node.peerPort,)}) - if not blob_hash in self._ongoing_searchs or self._ongoing_searchs[blob_hash].called: - self._ongoing_searchs[blob_hash] = self._execute_peer_search(dht_node, blob_hash, timeout) - - def _filter_self(blob_hash): - my_host, my_port = dht_node.externalIP, dht_node.peerPort - return 
{(host, port) for host, port in self.peers[blob_hash] if (host, port) != (my_host, my_port)} - - peers = set(_filter_self(blob_hash) if filter_self else self.peers[blob_hash]) - return defer.succeed([self.peer_manager.get_peer(*peer) for peer in peers]) - - @defer.inlineCallbacks - def _execute_peer_search(self, dht_node, blob_hash, timeout): - bin_hash = binascii.unhexlify(blob_hash) - finished_deferred = dht_node.iterativeFindValue(bin_hash, exclude=self.peers[blob_hash]) - timeout = timeout or self.component_manager.conf.peer_search_timeout - if timeout: - finished_deferred.addTimeout(timeout, dht_node.clock) - try: - peer_list = yield finished_deferred - self.peers[blob_hash].update({(host, port) for _, host, port in peer_list}) - except defer.TimeoutError: - log.debug("DHT timed out while looking peers for blob %s after %s seconds", blob_hash, timeout) - finally: - del self._ongoing_searchs[blob_hash] diff --git a/lbrynet/extras/daemon/PeerManager.py b/lbrynet/extras/daemon/PeerManager.py deleted file mode 100644 index 4cb1425ba..000000000 --- a/lbrynet/extras/daemon/PeerManager.py +++ /dev/null @@ -1,14 +0,0 @@ -from lbrynet.p2p.Peer import Peer - - -class PeerManager: - def __init__(self): - self.peers = [] - - def get_peer(self, host, port): - for peer in self.peers: - if peer.host == host and peer.port == port: - return peer - peer = Peer(host, port) - self.peers.append(peer) - return peer diff --git a/lbrynet/p2p/Peer.py b/lbrynet/p2p/Peer.py deleted file mode 100644 index bd52e0da3..000000000 --- a/lbrynet/p2p/Peer.py +++ /dev/null @@ -1,46 +0,0 @@ -import datetime -from collections import defaultdict -from lbrynet import utils - -# Do not create this object except through PeerManager -class Peer: - def __init__(self, host, port): - self.host = host - self.port = port - # If a peer is reported down, we wait till this time till reattempting connection - self.attempt_connection_at = None - # Number of times this Peer has been reported to be down, resets to 0 when it is up - self.down_count = 0 - # Number of successful connections (with full protocol completion) to this peer - self.success_count = 0 - self.score = 0 - self.stats = defaultdict(float) # {string stat_type, float count} - - def is_available(self): - if self.attempt_connection_at is None or utils.today() > self.attempt_connection_at: - return True - return False - - def report_up(self): - self.down_count = 0 - self.attempt_connection_at = None - - def report_success(self): - self.success_count += 1 - - def report_down(self): - self.down_count += 1 - timeout_time = datetime.timedelta(seconds=60 * self.down_count) - self.attempt_connection_at = utils.today() + timeout_time - - def update_score(self, score_change): - self.score += score_change - - def update_stats(self, stat_type, count): - self.stats[stat_type] += count - - def __str__(self): - return f'{self.host}:{self.port}' - - def __repr__(self): - return f'Peer({self.host!r}, {self.port!r})' diff --git a/tests/dht_mocks.py b/tests/dht_mocks.py new file mode 100644 index 000000000..46dc82668 --- /dev/null +++ b/tests/dht_mocks.py @@ -0,0 +1,83 @@ +import typing +import contextlib +import socket +import mock +import functools +import asyncio +if typing.TYPE_CHECKING: + from lbrynet.dht.protocol.protocol import KademliaProtocol + + +def get_time_accelerator(loop: asyncio.BaseEventLoop, + now: typing.Optional[float] = None) -> typing.Callable[[float], typing.Awaitable[None]]: + """ + Returns an async advance() function + + This provides a way to advance() the 
BaseEventLoop.time for the scheduled TimerHandles + made by call_later, call_at, and call_soon. + """ + + _time = now or loop.time() + loop.time = functools.wraps(loop.time)(lambda: _time) + + async def accelerate_time(seconds: float) -> None: + nonlocal _time + if seconds < 0: + raise ValueError('Cannot go back in time ({} seconds)'.format(seconds)) + _time += seconds + await past_events() + await asyncio.sleep(0) + + async def past_events() -> None: + while loop._scheduled: + timer: asyncio.TimerHandle = loop._scheduled[0] + if timer not in loop._ready and timer._when <= _time: + loop._scheduled.remove(timer) + loop._ready.append(timer) + if timer._when > _time: + break + await asyncio.sleep(0) + + async def accelerator(seconds: float): + steps = seconds * 10.0 + + for _ in range(max(int(steps), 1)): + await accelerate_time(0.1) + + return accelerator + + +@contextlib.contextmanager +def mock_network_loop(loop: asyncio.BaseEventLoop): + dht_network: typing.Dict[typing.Tuple[str, int], 'KademliaProtocol'] = {} + + async def create_datagram_endpoint(proto_lam: typing.Callable[[], 'KademliaProtocol'], + from_addr: typing.Tuple[str, int]): + def sendto(data, to_addr): + rx = dht_network.get(to_addr) + if rx and rx.external_ip: + # print(f"{from_addr[0]}:{from_addr[1]} -{len(data)} bytes-> {rx.external_ip}:{rx.udp_port}") + return rx.datagram_received(data, from_addr) + + protocol = proto_lam() + transport = asyncio.DatagramTransport(extra={'socket': mock_sock}) + transport.close = lambda: mock_sock.close() + mock_sock.sendto = sendto + transport.sendto = mock_sock.sendto + protocol.connection_made(transport) + dht_network[from_addr] = protocol + return transport, protocol + + with mock.patch('socket.socket') as mock_socket: + mock_sock = mock.Mock(spec=socket.socket) + mock_sock.setsockopt = lambda *_: None + mock_sock.bind = lambda *_: None + mock_sock.setblocking = lambda *_: None + mock_sock.getsockname = lambda: "0.0.0.0" + mock_sock.getpeername = lambda: "" + mock_sock.close = lambda: None + mock_sock.type = socket.SOCK_DGRAM + mock_sock.fileno = lambda: 7 + mock_socket.return_value = mock_sock + loop.create_datagram_endpoint = create_datagram_endpoint + yield diff --git a/tests/unit/dht/protocol/__init__.py b/tests/unit/dht/protocol/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/dht/protocol/test_protocol.py b/tests/unit/dht/protocol/test_protocol.py new file mode 100644 index 000000000..5a98e8837 --- /dev/null +++ b/tests/unit/dht/protocol/test_protocol.py @@ -0,0 +1,91 @@ +import asyncio +from torba.testcase import AsyncioTestCase +from tests import dht_mocks +from lbrynet.dht import constants +from lbrynet.dht.protocol.protocol import KademliaProtocol +from lbrynet.dht.peer import PeerManager + + +class TestProtocol(AsyncioTestCase): + async def test_ping(self): + loop = asyncio.get_event_loop() + with dht_mocks.mock_network_loop(loop): + node_id1 = constants.generate_id() + peer1 = KademliaProtocol( + loop, PeerManager(loop), node_id1, '1.2.3.4', 4444, 3333 + ) + peer2 = KademliaProtocol( + loop, PeerManager(loop), constants.generate_id(), '1.2.3.5', 4444, 3333 + ) + await loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444)) + await loop.create_datagram_endpoint(lambda: peer2, ('1.2.3.5', 4444)) + + peer = peer2.peer_manager.get_kademlia_peer(node_id1, '1.2.3.4', udp_port=4444) + result = await peer2.get_rpc_peer(peer).ping() + self.assertEqual(result, b'pong') + peer1.stop() + peer2.stop() + peer1.disconnect() + peer2.disconnect() + 
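test_ping drives two KademliaProtocol instances over the mocked datagram endpoints from tests/dht_mocks.py. The same module's get_time_accelerator is how timeout-heavy tests can run instantly; a usage sketch, where main and the 120-second delay are illustrative:

```python
# Usage sketch for the get_time_accelerator helper defined above:
# schedule a callback two minutes out, then fire it immediately by
# advancing the patched loop clock instead of waiting in real time.
import asyncio
from tests.dht_mocks import get_time_accelerator

async def main():
    loop = asyncio.get_event_loop()
    advance = get_time_accelerator(loop)  # freezes and takes over loop.time()
    fired = asyncio.Event()
    loop.call_later(120, fired.set)       # would normally wait two minutes
    await advance(121)                    # drains the due TimerHandles
    assert fired.is_set()

asyncio.get_event_loop().run_until_complete(main())
```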
+ async def test_update_token(self): + loop = asyncio.get_event_loop() + with dht_mocks.mock_network_loop(loop): + node_id1 = constants.generate_id() + peer1 = KademliaProtocol( + loop, PeerManager(loop), node_id1, '1.2.3.4', 4444, 3333 + ) + peer2 = KademliaProtocol( + loop, PeerManager(loop), constants.generate_id(), '1.2.3.5', 4444, 3333 + ) + await loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444)) + await loop.create_datagram_endpoint(lambda: peer2, ('1.2.3.5', 4444)) + + peer = peer2.peer_manager.get_kademlia_peer(node_id1, '1.2.3.4', udp_port=4444) + self.assertEqual(None, peer2.peer_manager.get_node_token(peer.node_id)) + await peer2.get_rpc_peer(peer).find_value(b'1' * 48) + self.assertNotEqual(None, peer2.peer_manager.get_node_token(peer.node_id)) + peer1.stop() + peer2.stop() + peer1.disconnect() + peer2.disconnect() + + async def test_store_to_peer(self): + loop = asyncio.get_event_loop() + with dht_mocks.mock_network_loop(loop): + node_id1 = constants.generate_id() + peer1 = KademliaProtocol( + loop, PeerManager(loop), node_id1, '1.2.3.4', 4444, 3333 + ) + peer2 = KademliaProtocol( + loop, PeerManager(loop), constants.generate_id(), '1.2.3.5', 4444, 3333 + ) + await loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444)) + await loop.create_datagram_endpoint(lambda: peer2, ('1.2.3.5', 4444)) + + peer = peer2.peer_manager.get_kademlia_peer(node_id1, '1.2.3.4', udp_port=4444) + peer2_from_peer1 = peer1.peer_manager.get_kademlia_peer( + peer2.node_id, peer2.external_ip, udp_port=peer2.udp_port + ) + peer3 = peer1.peer_manager.get_kademlia_peer( + constants.generate_id(), '1.2.3.6', udp_port=4444 + ) + store_result = await peer2.store_to_peer(b'2' * 48, peer) + self.assertEqual(store_result[0], peer.node_id) + self.assertEqual(True, store_result[1]) + self.assertEqual(True, peer1.data_store.has_peers_for_blob(b'2' * 48)) + self.assertEqual(False, peer1.data_store.has_peers_for_blob(b'3' * 48)) + self.assertListEqual([peer2_from_peer1], peer1.data_store.get_storing_contacts()) + + find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48) + self.assertSetEqual( + {b'2' * 48, b'token', b'protocolVersion'}, set(find_value_response.keys()) + ) + self.assertEqual(1, len(find_value_response[b'2' * 48])) + self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1) + # self.assertEqual(peer2_from_peer1.tcp_port, 3333) + + peer1.stop() + peer2.stop() + peer1.disconnect() + peer2.disconnect() diff --git a/tests/unit/dht/routing/__init__.py b/tests/unit/dht/routing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/dht/routing/test_kbucket.py b/tests/unit/dht/routing/test_kbucket.py new file mode 100644 index 000000000..88644132c --- /dev/null +++ b/tests/unit/dht/routing/test_kbucket.py @@ -0,0 +1,115 @@ +import struct +import asyncio +from lbrynet.utils import generate_id +from lbrynet.dht.protocol.routing_table import KBucket +from lbrynet.dht.peer import PeerManager +from lbrynet.dht import constants +from torba.testcase import AsyncioTestCase + + +def address_generator(address=(10, 42, 42, 1)): + def increment(addr): + value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]).encode())[0] + 1 + new_addr = [] + for i in range(4): + new_addr.append(value % 256) + value >>= 8 + return tuple(new_addr[::-1]) + + while True: + yield "{}.{}.{}.{}".format(*address) + address = increment(address) + + +class TestKBucket(AsyncioTestCase): + def setUp(self): + self.loop = asyncio.get_event_loop() + 
self.address_generator = address_generator() + self.peer_manager = PeerManager(self.loop) + self.kbucket = KBucket(self.peer_manager, 0, 2**constants.hash_bits, generate_id()) + + def test_add_peer(self): + # Test if contacts can be added to empty list + # Add k contacts to bucket + for i in range(constants.k): + peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444) + self.assertTrue(self.kbucket.add_peer(peer)) + self.assertEqual(peer, self.kbucket.peers[i]) + + # Test if contact is not added to full list + peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444) + self.assertFalse(self.kbucket.add_peer(peer)) + + # Test if an existing contact is updated correctly if added again + existing_peer = self.kbucket.peers[0] + self.assertTrue(self.kbucket.add_peer(existing_peer)) + self.assertEqual(existing_peer, self.kbucket.peers[-1]) + + # def testGetContacts(self): + # # try and get 2 contacts from empty list + # result = self.kbucket.getContacts(2) + # self.assertFalse(len(result) != 0, "Returned list should be empty; returned list length: %d" % + # (len(result))) + # + # # Add k-2 contacts + # node_ids = [] + # if constants.k >= 2: + # for i in range(constants.k-2): + # node_ids.append(generate_id()) + # tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0, + # None) + # self.kbucket.addContact(tmpContact) + # else: + # # add k contacts + # for i in range(constants.k): + # node_ids.append(generate_id()) + # tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0, + # None) + # self.kbucket.addContact(tmpContact) + # + # # try to get too many contacts + # # requested count greater than bucket size; should return at most k contacts + # contacts = self.kbucket.getContacts(constants.k+3) + # self.assertTrue(len(contacts) <= constants.k, + # 'Returned list should not have more than k entries!') + # + # # verify returned contacts in list + # for node_id, i in zip(node_ids, range(constants.k-2)): + # self.assertFalse(self.kbucket._contacts[i].id != node_id, + # "Contact in position %s not same as added contact" % (str(i))) + # + # # try to get too many contacts + # # requested count one greater than number of contacts + # if constants.k >= 2: + # result = self.kbucket.getContacts(constants.k-1) + # self.assertFalse(len(result) != constants.k-2, + # "Too many contacts in returned list %s - should be %s" % + # (len(result), constants.k-2)) + # else: + # result = self.kbucket.getContacts(constants.k-1) + # # if the count is <= 0, it should return all of it's contats + # self.assertFalse(len(result) != constants.k, + # "Too many contacts in returned list %s - should be %s" % + # (len(result), constants.k-2)) + # result = self.kbucket.getContacts(constants.k-3) + # self.assertFalse(len(result) != constants.k-3, + # "Too many contacts in returned list %s - should be %s" % + # (len(result), constants.k-3)) + + def test_remove_peer(self): + # try remove contact from empty list + peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444) + self.assertRaises(ValueError, self.kbucket.remove_peer, peer) + + added = [] + # Add couple contacts + for i in range(constants.k-2): + peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444) + self.assertTrue(self.kbucket.add_peer(peer)) + added.append(peer) + + while added: + peer = added.pop() + self.assertIn(peer, 
self.kbucket.peers) + self.kbucket.remove_peer(peer) + self.assertNotIn(peer, self.kbucket.peers) diff --git a/tests/unit/dht/routing/test_routing_table.py b/tests/unit/dht/routing/test_routing_table.py new file mode 100644 index 000000000..183e04165 --- /dev/null +++ b/tests/unit/dht/routing/test_routing_table.py @@ -0,0 +1,259 @@ +import asyncio +from torba.testcase import AsyncioTestCase +from tests import dht_mocks +from lbrynet.dht import constants +from lbrynet.dht.node import Node +from lbrynet.dht.peer import PeerManager + + +class TestRouting(AsyncioTestCase): + async def test_fill_one_bucket(self): + loop = asyncio.get_event_loop() + peer_addresses = [ + (constants.generate_id(1), '1.2.3.1'), + (constants.generate_id(2), '1.2.3.2'), + (constants.generate_id(3), '1.2.3.3'), + (constants.generate_id(4), '1.2.3.4'), + (constants.generate_id(5), '1.2.3.5'), + (constants.generate_id(6), '1.2.3.6'), + (constants.generate_id(7), '1.2.3.7'), + (constants.generate_id(8), '1.2.3.8'), + (constants.generate_id(9), '1.2.3.9'), + ] + with dht_mocks.mock_network_loop(loop): + nodes = { + i: Node(loop, PeerManager(loop), node_id, 4444, 4444, 3333, address) + for i, (node_id, address) in enumerate(peer_addresses) + } + node_1 = nodes[0] + contact_cnt = 0 + for i in range(1, len(peer_addresses)): + self.assertEqual(len(node_1.protocol.routing_table.get_peers()), contact_cnt) + node = nodes[i] + peer = node_1.protocol.peer_manager.get_kademlia_peer( + node.protocol.node_id, node.protocol.external_ip, + udp_port=node.protocol.udp_port + ) + added = await node_1.protocol.add_peer(peer) + self.assertEqual(True, added) + contact_cnt += 1 + + self.assertEqual(len(node_1.protocol.routing_table.get_peers()), 8) + self.assertEqual(node_1.protocol.routing_table.buckets_with_contacts(), 1) + for node in nodes.values(): + node.protocol.stop() + + +# from binascii import hexlify, unhexlify +# +# from twisted.trial import unittest +# from twisted.internet import defer +# from lbrynet.dht import constants +# from lbrynet.dht.routingtable import TreeRoutingTable +# from lbrynet.dht.contact import ContactManager +# from lbrynet.dht.distance import Distance +# from lbrynet.utils import generate_id +# +# +# class FakeRPCProtocol: +# """ Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """ +# def sendRPC(self, *args, **kwargs): +# return defer.succeed(None) +# +# +# class TreeRoutingTableTest(unittest.TestCase): +# """ Test case for the RoutingTable class """ +# def setUp(self): +# self.contact_manager = ContactManager() +# self.nodeID = generate_id(b'node1') +# self.protocol = FakeRPCProtocol() +# self.routingTable = TreeRoutingTable(self.nodeID) +# +# def test_distance(self): +# """ Test to see if distance method returns correct result""" +# d = Distance(bytes((170,) * 48)) +# result = d(bytes((85,) * 48)) +# expected = int(hexlify(bytes((255,) * 48)), 16) +# self.assertEqual(result, expected) +# +# @defer.inlineCallbacks +# def test_add_contact(self): +# """ Tests if a contact can be added and retrieved correctly """ +# # Create the contact +# contact_id = generate_id(b'node2') +# contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol) +# # Now add it... 
+# yield self.routingTable.addContact(contact) +# # ...and request the closest nodes to it (will retrieve it) +# closest_nodes = self.routingTable.findCloseNodes(contact_id) +# self.assertEqual(len(closest_nodes), 1) +# self.assertIn(contact, closest_nodes) +# +# @defer.inlineCallbacks +# def test_get_contact(self): +# """ Tests if a specific existing contact can be retrieved correctly """ +# contact_id = generate_id(b'node2') +# contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol) +# # Now add it... +# yield self.routingTable.addContact(contact) +# # ...and get it again +# same_contact = self.routingTable.getContact(contact_id) +# self.assertEqual(contact, same_contact, 'getContact() should return the same contact') +# +# @defer.inlineCallbacks +# def test_add_parent_node_as_contact(self): +# """ +# Tests the routing table's behaviour when attempting to add its parent node as a contact +# """ +# # Create a contact with the same ID as the local node's ID +# contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol) +# # Now try to add it +# yield self.routingTable.addContact(contact) +# # ...and request the closest nodes to it using FIND_NODE +# closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k) +# self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact') +# +# @defer.inlineCallbacks +# def test_remove_contact(self): +# """ Tests contact removal """ +# # Create the contact +# contact_id = generate_id(b'node2') +# contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol) +# # Now add it... +# yield self.routingTable.addContact(contact) +# # Verify addition +# self.assertEqual(len(self.routingTable._buckets[0]), 1, 'Contact not added properly') +# # Now remove it +# self.routingTable.removeContact(contact) +# self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly') +# +# @defer.inlineCallbacks +# def test_split_bucket(self): +# """ Tests if the the routing table correctly dynamically splits k-buckets """ +# self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384, +# 'Initial k-bucket range should be 0 <= range < 2**384') +# # Add k contacts +# for i in range(constants.k): +# node_id = generate_id(b'remote node %d' % i) +# contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol) +# yield self.routingTable.addContact(contact) +# +# self.assertEqual(len(self.routingTable._buckets), 1, +# 'Only k nodes have been added; the first k-bucket should now ' +# 'be full, but should not yet be split') +# # Now add 1 more contact +# node_id = generate_id(b'yet another remote node') +# contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol) +# yield self.routingTable.addContact(contact) +# self.assertEqual(len(self.routingTable._buckets), 2, +# 'k+1 nodes have been added; the first k-bucket should have been ' +# 'split into two new buckets') +# self.assertNotEqual(self.routingTable._buckets[0].rangeMax, 2**384, +# 'K-bucket was split, but its range was not properly adjusted') +# self.assertEqual(self.routingTable._buckets[1].rangeMax, 2**384, +# 'K-bucket was split, but the second (new) bucket\'s ' +# 'max range was not set properly') +# self.assertEqual(self.routingTable._buckets[0].rangeMax, +# self.routingTable._buckets[1].rangeMin, +# 'K-bucket was split, but the min/max ranges were ' +# 'not divided properly') +# +# @defer.inlineCallbacks +# def 
test_full_split(self): +# """ +# Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact +# """ +# +# self.routingTable._parentNodeID = bytes(48 * b'\xff') +# +# node_ids = [ +# b"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", +# b"200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", +# b"300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", +# b"400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", +# b"500000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", +# b"600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", +# b"700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", +# b"800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", +# b"ff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", +# b"010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" +# ] +# +# # Add k contacts +# for nodeID in node_ids: +# # self.assertEquals(nodeID, node_ids[i].decode('hex')) +# contact = self.contact_manager.make_contact(unhexlify(nodeID), '127.0.0.1', 9182, self.protocol) +# yield self.routingTable.addContact(contact) +# self.assertEqual(len(self.routingTable._buckets), 2) +# self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8) +# self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2) +# +# # try adding a contact who is further from us than the k'th known contact +# nodeID = b'020000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000' +# nodeID = unhexlify(nodeID) +# contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol) +# self.assertFalse(self.routingTable._shouldSplit(self.routingTable._kbucketIndex(contact.id), contact.id)) +# yield self.routingTable.addContact(contact) +# self.assertEqual(len(self.routingTable._buckets), 2) +# self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8) +# self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2) +# self.assertNotIn(contact, self.routingTable._buckets[0]._contacts) +# self.assertNotIn(contact, self.routingTable._buckets[1]._contacts) +# + +# class KeyErrorFixedTest(unittest.TestCase): +# """ Basic tests case for boolean operators on the Contact class """ +# +# def setUp(self): +# own_id = (2 ** constants.key_bits) - 1 +# # carefully chosen own_id. here's the logic +# # we want a bunch of buckets (k+1, to be exact), and we want to make sure own_id +# # is not in bucket 0. 
so we put own_id at the end so we can keep splitting by adding to the +# # end +# +# self.table = lbrynet.dht.routingtable.OptimizedTreeRoutingTable(own_id) +# +# def fill_bucket(self, bucket_min): +# bucket_size = lbrynet.dht.constants.k +# for i in range(bucket_min, bucket_min + bucket_size): +# self.table.addContact(lbrynet.dht.contact.Contact(long(i), '127.0.0.1', 9999, None)) +# +# def overflow_bucket(self, bucket_min): +# bucket_size = lbrynet.dht.constants.k +# self.fill_bucket(bucket_min) +# self.table.addContact( +# lbrynet.dht.contact.Contact(long(bucket_min + bucket_size + 1), +# '127.0.0.1', 9999, None)) +# +# def testKeyError(self): +# +# # find middle, so we know where bucket will split +# bucket_middle = self.table._buckets[0].rangeMax / 2 +# +# # fill last bucket +# self.fill_bucket(self.table._buckets[0].rangeMax - lbrynet.dht.constants.k - 1) +# # -1 in previous line because own_id is in last bucket +# +# # fill/overflow 7 more buckets +# bucket_start = 0 +# for i in range(0, lbrynet.dht.constants.k): +# self.overflow_bucket(bucket_start) +# bucket_start += bucket_middle / (2 ** i) +# +# # replacement cache now has k-1 entries. +# # adding one more contact to bucket 0 used to cause a KeyError, but it should work +# self.table.addContact( +# lbrynet.dht.contact.Contact(long(lbrynet.dht.constants.k + 2), '127.0.0.1', 9999, None)) +# +# # import math +# # print "" +# # for i, bucket in enumerate(self.table._buckets): +# # print "Bucket " + str(i) + " (2 ** " + str( +# # math.log(bucket.rangeMin, 2) if bucket.rangeMin > 0 else 0) + " <= x < 2 ** "+str( +# # math.log(bucket.rangeMax, 2)) + ")" +# # for c in bucket.getContacts(): +# # print " contact " + str(c.id) +# # for key, bucket in self.table._replacementCache.items(): +# # print "Replacement Cache for Bucket " + str(key) +# # for c in bucket: +# # print " contact " + str(c.id) diff --git a/tests/unit/dht/serialization/__init__.py b/tests/unit/dht/serialization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/dht/serialization/test_bencoding.py b/tests/unit/dht/serialization/test_bencoding.py new file mode 100644 index 000000000..b516d2e46 --- /dev/null +++ b/tests/unit/dht/serialization/test_bencoding.py @@ -0,0 +1,64 @@ +import unittest +from lbrynet.dht.serialization.bencoding import _bencode, bencode, bdecode, DecodeError + + +class EncodeDecodeTest(unittest.TestCase): + def test_fail_with_not_dict(self): + with self.assertRaises(TypeError): + bencode(1) + with self.assertRaises(TypeError): + bencode(b'derp') + with self.assertRaises(TypeError): + bencode('derp') + with self.assertRaises(TypeError): + bencode([b'derp']) + with self.assertRaises(TypeError): + bencode([object()]) + with self.assertRaises(TypeError): + bencode({b'derp': object()}) + + def test_fail_bad_type(self): + with self.assertRaises(DecodeError): + bdecode(b'd4le', True) + + def test_integer(self): + self.assertEqual(_bencode(42), b'i42e') + self.assertEqual(bdecode(b'i42e', True), 42) + + def test_bytes(self): + self.assertEqual(_bencode(b''), b'0:') + self.assertEqual(_bencode(b'spam'), b'4:spam') + self.assertEqual(_bencode(b'4:spam'), b'6:4:spam') + self.assertEqual(_bencode(bytearray(b'spam')), b'4:spam') + + self.assertEqual(bdecode(b'0:', True), b'') + self.assertEqual(bdecode(b'4:spam', True), b'spam') + self.assertEqual(bdecode(b'6:4:spam', True), b'4:spam') + + def test_string(self): + self.assertEqual(_bencode(''), b'0:') + self.assertEqual(_bencode('spam'), b'4:spam') + 
self.assertEqual(_bencode('4:spam'), b'6:4:spam') + + def test_list(self): + self.assertEqual(_bencode([b'spam', 42]), b'l4:spami42ee') + self.assertEqual(bdecode(b'l4:spami42ee', True), [b'spam', 42]) + + def test_dict(self): + self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee') + self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'}) + + def test_mixed(self): + self.assertEqual(_bencode( + [[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]), + b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee' + ) + + self.assertEqual(bdecode( + b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee', True), + [[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]] + ) + + def test_decode_error(self): + self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True) + self.assertRaises(DecodeError, bdecode, b'', True) diff --git a/tests/unit/dht/serialization/test_datagram.py b/tests/unit/dht/serialization/test_datagram.py new file mode 100644 index 000000000..738ebaa6c --- /dev/null +++ b/tests/unit/dht/serialization/test_datagram.py @@ -0,0 +1,60 @@ +import unittest +from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram +from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE + + +class TestDatagram(unittest.TestCase): + def test_ping_request_datagram(self): + serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'ping', []).bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, REQUEST_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.method, b'ping') + self.assertListEqual(decoded.args, [{b'protocolVersion': 1}]) + + def test_ping_response(self): + serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, RESPONSE_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.response, b'pong') + + def test_find_node_request_datagram(self): + serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findNode', [b'2' * 48]).bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, REQUEST_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.method, b'findNode') + self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}]) + + def test_find_node_response(self): + closest_response = [(b'3' * 48, '1.2.3.4', 1234)] + expected = [[b'3' * 48, b'1.2.3.4', 1234]] + + serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, closest_response).bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, RESPONSE_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.response, expected) + + def test_find_value_request(self): + serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findValue', [b'2' * 48]).bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, REQUEST_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertEqual(decoded.method, b'findValue') + self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}]) + + def 
test_find_value_response(self): + found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']} + serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode() + decoded = decode_datagram(serialized) + self.assertEqual(decoded.packet_type, RESPONSE_TYPE) + self.assertEqual(decoded.rpc_id, b'1' * 20) + self.assertEqual(decoded.node_id, b'1' * 48) + self.assertDictEqual(decoded.response, found_value_response) diff --git a/tests/unit/dht/test_async_gen_junction.py b/tests/unit/dht/test_async_gen_junction.py new file mode 100644 index 000000000..16cd6168c --- /dev/null +++ b/tests/unit/dht/test_async_gen_junction.py @@ -0,0 +1,88 @@ +import asyncio +from torba.testcase import AsyncioTestCase +from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction + + +class MockAsyncGen: + def __init__(self, loop, result, delay, stop_cnt=10): + self.loop = loop + self.result = result + self.delay = delay + self.count = 0 + self.stop_cnt = stop_cnt + self.called_close = False + + def __aiter__(self): + return self + + async def __anext__(self): + if self.count > self.stop_cnt - 1: + raise StopAsyncIteration() + self.count += 1 + await asyncio.sleep(self.delay, loop=self.loop) + return self.result + + async def aclose(self): + self.called_close = True + + +class TestAsyncGeneratorJunction(AsyncioTestCase): + def setUp(self): + self.loop = asyncio.get_event_loop() + + async def _test_junction(self, expected, *generators): + order = [] + async with AsyncGeneratorJunction(self.loop) as junction: + for generator in generators: + junction.add_generator(generator) + async for item in junction: + order.append(item) + self.assertListEqual(order, expected) + + async def test_yield_order(self): + expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2] + fast_gen = MockAsyncGen(self.loop, 1, 0.1) + slow_gen = MockAsyncGen(self.loop, 2, 0.2) + await self._test_junction(expected_order, fast_gen, slow_gen) + self.assertEqual(fast_gen.called_close, True) + self.assertEqual(slow_gen.called_close, True) + + async def test_one_stopped_first(self): + expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2] + fast_gen = MockAsyncGen(self.loop, 1, 0.1, 5) + slow_gen = MockAsyncGen(self.loop, 2, 0.2) + await self._test_junction(expected_order, fast_gen, slow_gen) + self.assertEqual(fast_gen.called_close, True) + self.assertEqual(slow_gen.called_close, True) + + async def test_with_non_async_gen_class(self): + expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2] + + async def fast_gen(): + for i in range(10): + if i == 5: + return + await asyncio.sleep(0.1) + yield 1 + + slow_gen = MockAsyncGen(self.loop, 2, 0.2) + await self._test_junction(expected_order, fast_gen(), slow_gen) + self.assertEqual(slow_gen.called_close, True) + + async def test_stop_when_encapsulating_task_cancelled(self): + fast_gen = MockAsyncGen(self.loop, 1, 0.1) + slow_gen = MockAsyncGen(self.loop, 2, 0.2) + + async def _task(): + async with AsyncGeneratorJunction(self.loop) as junction: + junction.add_generator(fast_gen) + junction.add_generator(slow_gen) + async for _ in junction: + pass + + task = self.loop.create_task(_task()) + self.loop.call_later(0.5, task.cancel) + with self.assertRaises(asyncio.CancelledError): + await task + self.assertEqual(fast_gen.called_close, True) + self.assertEqual(slow_gen.called_close, True) diff --git a/tests/unit/dht/test_encoding.py b/tests/unit/dht/test_encoding.py deleted file mode 100644 index da29c67b1..000000000 --- 
a/tests/unit/dht/test_encoding.py +++ /dev/null @@ -1,50 +0,0 @@ -from twisted.trial import unittest -from lbrynet.dht.encoding import bencode, bdecode, DecodeError - - -class EncodeDecodeTest(unittest.TestCase): - - def test_integer(self): - self.assertEqual(bencode(42), b'i42e') - - self.assertEqual(bdecode(b'i42e'), 42) - - def test_bytes(self): - self.assertEqual(bencode(b''), b'0:') - self.assertEqual(bencode(b'spam'), b'4:spam') - self.assertEqual(bencode(b'4:spam'), b'6:4:spam') - self.assertEqual(bencode(bytearray(b'spam')), b'4:spam') - - self.assertEqual(bdecode(b'0:'), b'') - self.assertEqual(bdecode(b'4:spam'), b'spam') - self.assertEqual(bdecode(b'6:4:spam'), b'4:spam') - - def test_string(self): - self.assertEqual(bencode(''), b'0:') - self.assertEqual(bencode('spam'), b'4:spam') - self.assertEqual(bencode('4:spam'), b'6:4:spam') - - def test_list(self): - self.assertEqual(bencode([b'spam', 42]), b'l4:spami42ee') - - self.assertEqual(bdecode(b'l4:spami42ee'), [b'spam', 42]) - - def test_dict(self): - self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee') - - self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'}) - - def test_mixed(self): - self.assertEqual(bencode( - [[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]), - b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee' - ) - - self.assertEqual(bdecode( - b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'), - [[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]] - ) - - def test_decode_error(self): - self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz') - self.assertRaises(DecodeError, bdecode, b'') diff --git a/tests/unit/dht/test_hash_announcer.py b/tests/unit/dht/test_hash_announcer.py deleted file mode 100644 index 5733c8c72..000000000 --- a/tests/unit/dht/test_hash_announcer.py +++ /dev/null @@ -1,59 +0,0 @@ -from twisted.trial import unittest -from twisted.internet import defer, task -from lbrynet import utils -from lbrynet.conf import Config -from lbrynet.extras.daemon.HashAnnouncer import DHTHashAnnouncer -from tests.test_utils import random_lbry_hash - - -class MocDHTNode: - def __init__(self): - self.blobs_announced = 0 - self.clock = task.Clock() - self.peerPort = 3333 - - def announceHaveBlob(self, blob): - self.blobs_announced += 1 - d = defer.Deferred() - self.clock.callLater(1, d.callback, ['fake']) - return d - - -class MocStorage: - def __init__(self, blobs_to_announce): - self.blobs_to_announce = blobs_to_announce - self.announced = False - - def get_blobs_to_announce(self): - if not self.announced: - self.announced = True - return defer.succeed(self.blobs_to_announce) - else: - return defer.succeed([]) - - def update_last_announced_blob(self, blob_hash, now): - return defer.succeed(None) - - -class DHTHashAnnouncerTest(unittest.TestCase): - - def setUp(self): - conf = Config() - self.num_blobs = 10 - self.blobs_to_announce = [] - for i in range(0, self.num_blobs): - self.blobs_to_announce.append(random_lbry_hash()) - self.dht_node = MocDHTNode() - self.clock = self.dht_node.clock - utils.call_later = self.clock.callLater - self.storage = MocStorage(self.blobs_to_announce) - self.announcer = DHTHashAnnouncer(conf, self.dht_node, self.storage) - - @defer.inlineCallbacks - def test_immediate_announce(self): - announce_d = self.announcer.immediate_announce(self.blobs_to_announce) - self.assertEqual(self.announcer.hash_queue_size(), self.num_blobs) - self.clock.advance(1) - yield announce_d - 
self.assertEqual(self.dht_node.blobs_announced, self.num_blobs) - self.assertEqual(self.announcer.hash_queue_size(), 0) diff --git a/tests/unit/dht/test_kbucket.py b/tests/unit/dht/test_kbucket.py deleted file mode 100644 index 360dee32d..000000000 --- a/tests/unit/dht/test_kbucket.py +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env python -# -# This library is free software, distributed under the terms of -# the GNU Lesser General Public License Version 3, or any later version. -# See the COPYING file included in this archive - -from twisted.trial import unittest -import struct -from lbrynet.utils import generate_id -from lbrynet.dht import kbucket -from lbrynet.dht.contact import ContactManager -from lbrynet.dht import constants - - -def address_generator(address=(10, 42, 42, 1)): - def increment(addr): - value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]).encode())[0] + 1 - new_addr = [] - for i in range(4): - new_addr.append(value % 256) - value >>= 8 - return tuple(new_addr[::-1]) - - while True: - yield "{}.{}.{}.{}".format(*address) - address = increment(address) - - -class KBucketTest(unittest.TestCase): - """ Test case for the KBucket class """ - def setUp(self): - self.address_generator = address_generator() - self.contact_manager = ContactManager() - self.kbucket = kbucket.KBucket(0, 2**constants.key_bits, generate_id()) - - def testAddContact(self): - """ Tests if the bucket handles contact additions/updates correctly """ - # Test if contacts can be added to empty list - # Add k contacts to bucket - for i in range(constants.k): - tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None) - self.kbucket.addContact(tmpContact) - self.assertEqual( - self.kbucket._contacts[i], - tmpContact, - "Contact in position %d not the same as the newly-added contact" % i) - - # Test if contact is not added to full list - tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None) - self.assertRaises(kbucket.BucketFull, self.kbucket.addContact, tmpContact) - - # Test if an existing contact is updated correctly if added again - existingContact = self.kbucket._contacts[0] - self.kbucket.addContact(existingContact) - self.assertEqual( - self.kbucket._contacts.index(existingContact), - len(self.kbucket._contacts)-1, - 'Contact not correctly updated; it should be at the end of the list of contacts') - - def testGetContacts(self): - # try and get 2 contacts from empty list - result = self.kbucket.getContacts(2) - self.assertFalse(len(result) != 0, "Returned list should be empty; returned list length: %d" % - (len(result))) - - - # Add k-2 contacts - node_ids = [] - if constants.k >= 2: - for i in range(constants.k-2): - node_ids.append(generate_id()) - tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0, - None) - self.kbucket.addContact(tmpContact) - else: - # add k contacts - for i in range(constants.k): - node_ids.append(generate_id()) - tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0, - None) - self.kbucket.addContact(tmpContact) - - # try to get too many contacts - # requested count greater than bucket size; should return at most k contacts - contacts = self.kbucket.getContacts(constants.k+3) - self.assertTrue(len(contacts) <= constants.k, - 'Returned list should not have more than k entries!') - - # verify returned contacts in list - for node_id, i in zip(node_ids, 
range(constants.k-2)): - self.assertFalse(self.kbucket._contacts[i].id != node_id, - "Contact in position %s not the same as the added contact" % (str(i))) - - # try to get too many contacts - # requested count one greater than number of contacts - if constants.k >= 2: - result = self.kbucket.getContacts(constants.k-1) - self.assertFalse(len(result) != constants.k-2, - "Too many contacts in returned list %s - should be %s" % - (len(result), constants.k-2)) - else: - result = self.kbucket.getContacts(constants.k-1) - # if the count is <= 0, it should return all of its contacts - self.assertFalse(len(result) != constants.k, - "Too many contacts in returned list %s - should be %s" % - (len(result), constants.k)) - result = self.kbucket.getContacts(constants.k-3) - self.assertFalse(len(result) != constants.k-3, - "Too many contacts in returned list %s - should be %s" % - (len(result), constants.k-3)) - - def testRemoveContact(self): - # try to remove a contact from the empty bucket - rmContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None) - self.assertRaises(ValueError, self.kbucket.removeContact, rmContact) - - # Add a couple of contacts - for i in range(constants.k-2): - tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None) - self.kbucket.addContact(tmpContact) - - # add the contact back, then remove it from the now non-empty bucket - self.kbucket.addContact(rmContact) - result = self.kbucket.removeContact(rmContact) - self.assertNotIn(rmContact, self.kbucket._contacts, "Could not remove contact from bucket") diff --git a/tests/unit/dht/test_messages.py b/tests/unit/dht/test_messages.py deleted file mode 100644 index f49e2ed44..000000000 --- a/tests/unit/dht/test_messages.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python -# -# This library is free software, distributed under the terms of -# the GNU Lesser General Public License Version 3, or any later version.
-# See the COPYING file included in this archive - -from twisted.trial import unittest - -from lbrynet.dht.msgtypes import RequestMessage, ResponseMessage, ErrorMessage -from lbrynet.dht.msgformat import MessageTranslator, DefaultFormat - - -class DefaultFormatTranslatorTest(unittest.TestCase): - """ Test case for the default message translator """ - def setUp(self): - self.cases = ((RequestMessage('1' * 48, 'rpcMethod', - {'arg1': 'a string', 'arg2': 123}, '1' * 20), - {DefaultFormat.headerType: DefaultFormat.typeRequest, - DefaultFormat.headerNodeID: '1' * 48, - DefaultFormat.headerMsgID: '1' * 20, - DefaultFormat.headerPayload: 'rpcMethod', - DefaultFormat.headerArgs: {'arg1': 'a string', 'arg2': 123}}), - - (ResponseMessage('2' * 20, '2' * 48, 'response'), - {DefaultFormat.headerType: DefaultFormat.typeResponse, - DefaultFormat.headerNodeID: '2' * 48, - DefaultFormat.headerMsgID: '2' * 20, - DefaultFormat.headerPayload: 'response'}), - - (ErrorMessage('3' * 20, '3' * 48, - "", 'this is a test exception'), - {DefaultFormat.headerType: DefaultFormat.typeError, - DefaultFormat.headerNodeID: '3' * 48, - DefaultFormat.headerMsgID: '3' * 20, - DefaultFormat.headerPayload: "", - DefaultFormat.headerArgs: 'this is a test exception'}), - - (ResponseMessage( - '4' * 20, '4' * 48, - [('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82', - '127.0.0.1', 1919), - ('\xae\x9ey\x93\xdd\xeb\xf1^\xff\xc5\x0f\xf8\xac!\x0e\x03\x9fY@{', - '127.0.0.1', 1921)]), - {DefaultFormat.headerType: DefaultFormat.typeResponse, - DefaultFormat.headerNodeID: '4' * 48, - DefaultFormat.headerMsgID: '4' * 20, - DefaultFormat.headerPayload: - [('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82', - '127.0.0.1', 1919), - ('\xae\x9ey\x93\xdd\xeb\xf1^\xff\xc5\x0f\xf8\xac!\x0e\x03\x9fY@{', - '127.0.0.1', 1921)]}) - ) - self.translator = DefaultFormat() - self.assertTrue( - isinstance(self.translator, MessageTranslator), - 'Translator class must inherit from entangled.kademlia.msgformat.MessageTranslator!') - - def testToPrimitive(self): - """ Tests translation from a Message object to a primitive """ - for msg, msgPrimitive in self.cases: - translatedObj = self.translator.toPrimitive(msg) - self.assertEqual(len(translatedObj), len(msgPrimitive), - "Translated object does not match example object's size") - for key in msgPrimitive: - self.assertEqual( - translatedObj[key], msgPrimitive[key], - 'Message object type %s not translated correctly into primitive on ' - 'key "%s"; expected "%s", got "%s"' % - (msg.__class__.__name__, key, msgPrimitive[key], translatedObj[key])) - - def testFromPrimitive(self): - """ Tests translation from a primitive to a Message object """ - for msg, msgPrimitive in self.cases: - translatedObj = self.translator.fromPrimitive(msgPrimitive) - self.assertEqual( - type(translatedObj), type(msg), - 'Message type incorrectly translated; expected "%s", got "%s"' % - (type(msg), type(translatedObj))) - for key in msg.__dict__: - self.assertEqual( - msg.__dict__[key], translatedObj.__dict__[key], - 'Message instance variable "%s" not translated correctly; ' - 'expected "%s", got "%s"' % - (key, msg.__dict__[key], translatedObj.__dict__[key])) diff --git a/tests/unit/dht/test_node.py b/tests/unit/dht/test_node.py index 983975235..2c2926917 100644 --- a/tests/unit/dht/test_node.py +++ b/tests/unit/dht/test_node.py @@ -1,88 +1,85 @@ -import hashlib -import struct - -from twisted.trial import unittest -from twisted.internet import defer -from lbrynet.dht.node import Node +import asyncio 
+import typing +from torba.testcase import AsyncioTestCase +from tests import dht_mocks from lbrynet.dht import constants -from lbrynet.utils import generate_id +from lbrynet.dht.node import Node +from lbrynet.dht.peer import PeerManager -class NodeIDTest(unittest.TestCase): +class TestNodePingQueueDiscover(AsyncioTestCase): + async def test_ping_queue_discover(self): + loop = asyncio.get_event_loop() - def setUp(self): - self.node = Node() + peer_addresses = [ + (constants.generate_id(1), '1.2.3.1'), + (constants.generate_id(2), '1.2.3.2'), + (constants.generate_id(3), '1.2.3.3'), + (constants.generate_id(4), '1.2.3.4'), + (constants.generate_id(5), '1.2.3.5'), + (constants.generate_id(6), '1.2.3.6'), + (constants.generate_id(7), '1.2.3.7'), + (constants.generate_id(8), '1.2.3.8'), + (constants.generate_id(9), '1.2.3.9'), + ] + with dht_mocks.mock_network_loop(loop): + advance = dht_mocks.get_time_accelerator(loop, loop.time()) + # start the nodes + nodes: typing.Dict[int, Node] = { + i: Node(loop, PeerManager(loop), node_id, 4444, 4444, 3333, address) + for i, (node_id, address) in enumerate(peer_addresses) + } + for i, n in nodes.items(): + n.start(peer_addresses[i][1], []) - def test_new_node_has_auto_created_id(self): - self.assertEqual(type(self.node.node_id), bytes) - self.assertEqual(len(self.node.node_id), 48) + await advance(1) - def test_uniqueness_and_length_of_generated_ids(self): - previous_ids = [] - for i in range(100): - new_id = self.node._generateID() - self.assertNotIn(new_id, previous_ids, f'id at index {i} not unique') - self.assertEqual(len(new_id), 48, 'id at index {} wrong length: {}'.format(i, len(new_id))) - previous_ids.append(new_id) + node_1 = nodes[0] + # ping 8 nodes from node_1, this will result in a delayed return ping + futs = [] + for i in range(1, len(peer_addresses)): + node = nodes[i] + assert node.protocol.node_id != node_1.protocol.node_id + peer = node_1.protocol.peer_manager.get_kademlia_peer( + node.protocol.node_id, node.protocol.external_ip, udp_port=node.protocol.udp_port + ) + futs.append(node_1.protocol.get_rpc_peer(peer).ping()) + await advance(3) + replies = await asyncio.gather(*tuple(futs)) + self.assertTrue(all(map(lambda reply: reply == b"pong", replies))) -class NodeDataTest(unittest.TestCase): - """ Test case for the Node class's data-related functions """ + # run for long enough for the delayed pings to have been sent by node 1 + await advance(1000) - def setUp(self): - h = hashlib.sha384() - h.update(b'test') - self.node = Node() - self.contact = self.node.contact_manager.make_contact( - h.digest(), '127.0.0.1', 12345, self.node._protocol) - self.token = self.node.make_token(self.contact.compact_ip()) - self.cases = [] - for i in range(5): - h.update(str(i).encode()) - self.cases.append((h.digest(), 5000+2*i)) - self.cases.append((h.digest(), 5001+2*i)) + # verify all of the previously pinged peers have node_1 in their routing tables + for n in nodes.values(): + peers = n.protocol.routing_table.get_peers() + if n is node_1: + self.assertEqual(8, len(peers)) + else: + self.assertEqual(1, len(peers)) + self.assertEqual((peers[0].node_id, peers[0].address, peers[0].udp_port), + (node_1.protocol.node_id, node_1.protocol.external_ip, node_1.protocol.udp_port)) - @defer.inlineCallbacks - def test_store(self): - """ Tests if the node can store (and privately retrieve) some data """ - for key, port in self.cases: - yield self.node.store( - self.contact, key, self.token, port, self.contact.id, 0 - ) - for key, value in self.cases: - 
expected_result = self.contact.compact_ip() + struct.pack('>H', value) + self.contact.id - self.assertTrue(self.node._dataStore.hasPeersForBlob(key), - "Stored key not found in node's DataStore: '%s'" % key) - self.assertIn(expected_result, self.node._dataStore.getPeersForBlob(key), - "Stored val not found in node's DataStore: key:'%s' port:'%s' %s" - % (key, value, self.node._dataStore.getPeersForBlob(key))) + # run long enough for the refresh loop to run + await advance(3600) + # verify all the nodes know about each other + for n in nodes.values(): + if n is node_1: + continue + peers = n.protocol.routing_table.get_peers() + self.assertEqual(8, len(peers)) + self.assertSetEqual( + {n_id[0] for n_id in peer_addresses if n_id[0] != n.protocol.node_id}, + {c.node_id for c in peers} + ) + self.assertSetEqual( + {n_addr[1] for n_addr in peer_addresses if n_addr[1] != n.protocol.external_ip}, + {c.address for c in peers} + ) -class NodeContactTest(unittest.TestCase): - """ Test case for the Node class's contact management-related functions """ - def setUp(self): - self.node = Node() - - @defer.inlineCallbacks - def test_add_contact(self): - """ Tests if a contact can be added and retrieved correctly """ - # Create the contact - contact_id = generate_id(b'node1') - contact = self.node.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.node._protocol) - # Now add it... - yield self.node.addContact(contact) - # ...and request the closest nodes to it using FIND_NODE - closest_nodes = self.node._routingTable.findCloseNodes(contact_id, constants.k) - self.assertEqual(len(closest_nodes), 1) - self.assertIn(contact, closest_nodes) - - @defer.inlineCallbacks - def test_add_self_as_contact(self): - """ Tests the node's behaviour when attempting to add itself as a contact """ - # Create a contact with the same ID as the local node's ID - contact = self.node.contact_manager.make_contact(self.node.node_id, '127.0.0.1', 9182, None) - # Now try to add it - yield self.node.addContact(contact) - # ...and request the closest nodes to it using FIND_NODE - closest_nodes = self.node._routingTable.findCloseNodes(self.node.node_id, constants.k) - self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact.') + # teardown + for n in nodes.values(): + n.stop() diff --git a/tests/unit/dht/test_contact.py b/tests/unit/dht/test_peer.py similarity index 65% rename from tests/unit/dht/test_contact.py rename to tests/unit/dht/test_peer.py index 2721a8fe7..e71d4d4da 100644 --- a/tests/unit/dht/test_contact.py +++ b/tests/unit/dht/test_peer.py @@ -1,64 +1,37 @@ -from binascii import hexlify -from twisted.internet import task -from twisted.trial import unittest +import asyncio +import unittest from lbrynet.utils import generate_id -from lbrynet.dht.contact import ContactManager -from lbrynet.dht import constants +from lbrynet.dht.peer import PeerManager +from torba.testcase import AsyncioTestCase -class ContactTest(unittest.TestCase): - """ Basic tests case for boolean operators on the Contact class """ +class PeerTest(AsyncioTestCase): def setUp(self): - self.contact_manager = ContactManager() + self.loop = asyncio.get_event_loop() + self.peer_manager = PeerManager(self.loop) self.node_ids = [generate_id(), generate_id(), generate_id()] - make_contact = self.contact_manager.make_contact - self.first_contact = make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1) - self.second_contact = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) - self.second_contact_second_reference = 
make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32) - self.first_contact_different_values = make_contact(self.node_ids[1], '192.168.1.20', 1000, None, 50) + self.first_contact = self.peer_manager.get_kademlia_peer(self.node_ids[1], '127.0.0.1', udp_port=1000) + self.second_contact = self.peer_manager.get_kademlia_peer(self.node_ids[0], '192.168.0.1', udp_port=1000) def test_make_contact_error_cases(self): - self.assertRaises( - ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20', 100000, None) - self.assertRaises( - ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20.1', 1000, None) - self.assertRaises( - ValueError, self.contact_manager.make_contact, self.node_ids[1], 'this is not an ip', 1000, None) - self.assertRaises( - ValueError, self.contact_manager.make_contact, b'not valid node id', '192.168.1.20.1', 1000, None) - - def test_no_duplicate_contact_objects(self): - self.assertIs(self.second_contact, self.second_contact_second_reference) - self.assertIsNot(self.first_contact, self.first_contact_different_values) + self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], '192.168.1.20', 100000) + self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], '192.168.1.20.1', 1000) + self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], 'this is not an ip', 1000) + self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], '192.168.1.20', -1000) + self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, b'not valid node id', '192.168.1.20', 1000) def test_boolean(self): - """ Test "equals" and "not equals" comparisons """ - self.assertNotEqual( - self.first_contact, self.contact_manager.make_contact( - self.first_contact.id, self.first_contact.address, self.first_contact.port + 1, None, 32 - ) + self.assertNotEqual(self.first_contact, self.second_contact) + self.assertEqual( + self.second_contact, self.peer_manager.get_kademlia_peer(self.node_ids[0], '192.168.0.1', udp_port=1000) ) - self.assertNotEqual( - self.first_contact, self.contact_manager.make_contact( - self.first_contact.id, '193.168.1.1', self.first_contact.port, None, 32 - ) - ) - self.assertNotEqual( - self.first_contact, self.contact_manager.make_contact( - generate_id(), self.first_contact.address, self.first_contact.port, None, 32 - ) - ) - self.assertEqual(self.second_contact, self.second_contact_second_reference) def test_compact_ip(self): self.assertEqual(self.first_contact.compact_ip(), b'\x7f\x00\x00\x01') self.assertEqual(self.second_contact.compact_ip(), b'\xc0\xa8\x00\x01') - def test_id_log(self): - self.assertEqual(self.first_contact.log_id(False), hexlify(self.node_ids[1])) - self.assertEqual(self.first_contact.log_id(True), hexlify(self.node_ids[1])[:8]) - +@unittest.skip("not yet ported to asyncio") class TestContactLastReplied(unittest.TestCase): def setUp(self): self.clock = task.Clock() @@ -129,6 +102,7 @@ class TestContactLastReplied(unittest.TestCase): self.assertIsNone(self.contact.contact_is_good) +@unittest.skip("not yet ported to asyncio") class TestContactLastRequested(unittest.TestCase): def setUp(self): self.clock = task.Clock() diff --git a/tests/unit/dht/test_routingtable.py b/tests/unit/dht/test_routingtable.py deleted file mode 100644 index 1c7985cc8..000000000 --- a/tests/unit/dht/test_routingtable.py +++ /dev/null @@ -1,213 +0,0 @@ -from binascii import hexlify, unhexlify - -from twisted.trial import unittest -from twisted.internet import defer -from 
lbrynet.dht import constants -from lbrynet.dht.routingtable import TreeRoutingTable -from lbrynet.dht.contact import ContactManager -from lbrynet.dht.distance import Distance -from lbrynet.utils import generate_id - - -class FakeRPCProtocol: - """ Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """ - def sendRPC(self, *args, **kwargs): - return defer.succeed(None) - - -class TreeRoutingTableTest(unittest.TestCase): - """ Test case for the RoutingTable class """ - def setUp(self): - self.contact_manager = ContactManager() - self.nodeID = generate_id(b'node1') - self.protocol = FakeRPCProtocol() - self.routingTable = TreeRoutingTable(self.nodeID) - - def test_distance(self): - """ Test to see if distance method returns correct result""" - d = Distance(bytes((170,) * 48)) - result = d(bytes((85,) * 48)) - expected = int(hexlify(bytes((255,) * 48)), 16) - self.assertEqual(result, expected) - - @defer.inlineCallbacks - def test_add_contact(self): - """ Tests if a contact can be added and retrieved correctly """ - # Create the contact - contact_id = generate_id(b'node2') - contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol) - # Now add it... - yield self.routingTable.addContact(contact) - # ...and request the closest nodes to it (will retrieve it) - closest_nodes = self.routingTable.findCloseNodes(contact_id) - self.assertEqual(len(closest_nodes), 1) - self.assertIn(contact, closest_nodes) - - @defer.inlineCallbacks - def test_get_contact(self): - """ Tests if a specific existing contact can be retrieved correctly """ - contact_id = generate_id(b'node2') - contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol) - # Now add it... - yield self.routingTable.addContact(contact) - # ...and get it again - same_contact = self.routingTable.getContact(contact_id) - self.assertEqual(contact, same_contact, 'getContact() should return the same contact') - - @defer.inlineCallbacks - def test_add_parent_node_as_contact(self): - """ - Tests the routing table's behaviour when attempting to add its parent node as a contact - """ - # Create a contact with the same ID as the local node's ID - contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol) - # Now try to add it - yield self.routingTable.addContact(contact) - # ...and request the closest nodes to it using FIND_NODE - closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k) - self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact') - - @defer.inlineCallbacks - def test_remove_contact(self): - """ Tests contact removal """ - # Create the contact - contact_id = generate_id(b'node2') - contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol) - # Now add it... 
- yield self.routingTable.addContact(contact) - # Verify addition - self.assertEqual(len(self.routingTable._buckets[0]), 1, 'Contact not added properly') - # Now remove it - self.routingTable.removeContact(contact) - self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly') - - @defer.inlineCallbacks - def test_split_bucket(self): - """ Tests if the routing table correctly dynamically splits k-buckets """ - self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384, - 'Initial k-bucket range should be 0 <= range < 2**384') - # Add k contacts - for i in range(constants.k): - node_id = generate_id(b'remote node %d' % i) - contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol) - yield self.routingTable.addContact(contact) - - self.assertEqual(len(self.routingTable._buckets), 1, - 'Only k nodes have been added; the first k-bucket should now ' - 'be full, but should not yet be split') - # Now add 1 more contact - node_id = generate_id(b'yet another remote node') - contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol) - yield self.routingTable.addContact(contact) - self.assertEqual(len(self.routingTable._buckets), 2, - 'k+1 nodes have been added; the first k-bucket should have been ' - 'split into two new buckets') - self.assertNotEqual(self.routingTable._buckets[0].rangeMax, 2**384, - 'K-bucket was split, but its range was not properly adjusted') - self.assertEqual(self.routingTable._buckets[1].rangeMax, 2**384, - 'K-bucket was split, but the second (new) bucket\'s ' - 'max range was not set properly') - self.assertEqual(self.routingTable._buckets[0].rangeMax, - self.routingTable._buckets[1].rangeMin, - 'K-bucket was split, but the min/max ranges were ' - 'not divided properly') - - @defer.inlineCallbacks - def test_full_split(self): - """ - Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact - """ - - self.routingTable._parentNodeID = bytes(48 * b'\xff') - - node_ids = [ - b"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - b"200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - b"300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - b"400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - b"500000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - b"600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - b"700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - b"800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - b"ff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - b"010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - ] - - # Add k+2 contacts - for nodeID in node_ids: - # self.assertEquals(nodeID, node_ids[i].decode('hex')) - contact = self.contact_manager.make_contact(unhexlify(nodeID), '127.0.0.1', 9182, self.protocol) - yield self.routingTable.addContact(contact) - self.assertEqual(len(self.routingTable._buckets), 2) - self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8) - self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2) - - # try adding a
contact who is further from us than the k'th known contact - nodeID = b'020000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000' - nodeID = unhexlify(nodeID) - contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol) - self.assertFalse(self.routingTable._shouldSplit(self.routingTable._kbucketIndex(contact.id), contact.id)) - yield self.routingTable.addContact(contact) - self.assertEqual(len(self.routingTable._buckets), 2) - self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8) - self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2) - self.assertNotIn(contact, self.routingTable._buckets[0]._contacts) - self.assertNotIn(contact, self.routingTable._buckets[1]._contacts) - - -# class KeyErrorFixedTest(unittest.TestCase): -# """ Basic tests case for boolean operators on the Contact class """ -# -# def setUp(self): -# own_id = (2 ** constants.key_bits) - 1 -# # carefully chosen own_id. here's the logic -# # we want a bunch of buckets (k+1, to be exact), and we want to make sure own_id -# # is not in bucket 0. so we put own_id at the end so we can keep splitting by adding to the -# # end -# -# self.table = lbrynet.dht.routingtable.OptimizedTreeRoutingTable(own_id) -# -# def fill_bucket(self, bucket_min): -# bucket_size = lbrynet.dht.constants.k -# for i in range(bucket_min, bucket_min + bucket_size): -# self.table.addContact(lbrynet.dht.contact.Contact(long(i), '127.0.0.1', 9999, None)) -# -# def overflow_bucket(self, bucket_min): -# bucket_size = lbrynet.dht.constants.k -# self.fill_bucket(bucket_min) -# self.table.addContact( -# lbrynet.dht.contact.Contact(long(bucket_min + bucket_size + 1), -# '127.0.0.1', 9999, None)) -# -# def testKeyError(self): -# -# # find middle, so we know where bucket will split -# bucket_middle = self.table._buckets[0].rangeMax / 2 -# -# # fill last bucket -# self.fill_bucket(self.table._buckets[0].rangeMax - lbrynet.dht.constants.k - 1) -# # -1 in previous line because own_id is in last bucket -# -# # fill/overflow 7 more buckets -# bucket_start = 0 -# for i in range(0, lbrynet.dht.constants.k): -# self.overflow_bucket(bucket_start) -# bucket_start += bucket_middle / (2 ** i) -# -# # replacement cache now has k-1 entries. -# # adding one more contact to bucket 0 used to cause a KeyError, but it should work -# self.table.addContact( -# lbrynet.dht.contact.Contact(long(lbrynet.dht.constants.k + 2), '127.0.0.1', 9999, None)) -# -# # import math -# # print "" -# # for i, bucket in enumerate(self.table._buckets): -# # print "Bucket " + str(i) + " (2 ** " + str( -# # math.log(bucket.rangeMin, 2) if bucket.rangeMin > 0 else 0) + " <= x < 2 ** "+str( -# # math.log(bucket.rangeMax, 2)) + ")" -# # for c in bucket.getContacts(): -# # print " contact " + str(c.id) -# # for key, bucket in self.table._replacementCache.items(): -# # print "Replacement Cache for Bucket " + str(key) -# # for c in bucket: -# # print " contact " + str(c.id)
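For reference, the deleted test_encoding.py pins down the bencoding rules the DHT wire format relies on: integers become b'i42e', byte strings are length-prefixed (b'4:spam'), and lists and dicts nest inside l...e and d...e. Below is a minimal round-trip sketch that satisfies those assertions; it is not the shipped lbrynet.dht.encoding, which raises its own DecodeError where this sketch raises ValueError.

def bencode(value):
    # integers: i<decimal>e
    if isinstance(value, int):
        return b'i%de' % value
    # strings are encoded to bytes first, matching test_string above
    if isinstance(value, str):
        value = value.encode()
    # byte strings: <length>:<raw bytes>
    if isinstance(value, (bytes, bytearray)):
        return b'%d:%s' % (len(value), bytes(value))
    # lists: l<items>e
    if isinstance(value, list):
        return b'l' + b''.join(bencode(v) for v in value) + b'e'
    # dicts: d<key><value>...e with keys in sorted order
    if isinstance(value, dict):
        return b'd' + b''.join(bencode(k) + bencode(v) for k, v in sorted(value.items())) + b'e'
    raise TypeError(f"cannot bencode {type(value).__name__}")

def bdecode(data):
    def decode_at(i):
        if data[i:i+1] == b'i':
            end = data.index(b'e', i)
            return int(data[i+1:end]), end + 1
        if data[i:i+1] == b'l':
            values, i = [], i + 1
            while data[i:i+1] != b'e':
                v, i = decode_at(i)
                values.append(v)
            return values, i + 1
        if data[i:i+1] == b'd':
            d, i = {}, i + 1
            while data[i:i+1] != b'e':
                k, i = decode_at(i)
                v, i = decode_at(i)
                d[k] = v
            return d, i + 1
        if data[i:i+1].isdigit():
            colon = data.index(b':', i)
            length = int(data[i:colon])
            start = colon + 1
            return data[start:start + length], start + length
        raise ValueError(f"invalid bencoded data at byte {i}")
    if not data:
        raise ValueError("cannot decode empty input")
    value, _ = decode_at(0)
    return value

assert bencode([b'spam', 42]) == b'l4:spami42ee'
assert bdecode(b'd3:bar4:spam3:fooi42ee') == {b'foo': 42, b'bar': b'spam'}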
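The deleted test_kbucket.py encodes the least-recently-seen discipline of a Kademlia k-bucket: re-adding a known contact moves it to the tail (most recently seen), a full bucket refuses new contacts with BucketFull, getContacts never returns more than k entries, and a non-positive count returns everything. A compact sketch of that behaviour follows; the names are modeled on the tests, not on the shipped lbrynet.dht.kbucket.

class BucketFull(Exception):
    pass

K = 8  # stands in for constants.k

class KBucket:
    def __init__(self, range_min: int, range_max: int):
        self.range_min, self.range_max = range_min, range_max
        self._contacts = []

    def add_contact(self, contact):
        if contact in self._contacts:
            self._contacts.remove(contact)  # re-add moves contact to the tail
        elif len(self._contacts) >= K:
            raise BucketFull("no space in bucket to insert contact")
        self._contacts.append(contact)

    def get_contacts(self, count):
        if count <= 0:
            return list(self._contacts)     # non-positive count returns all contacts
        return self._contacts[:count]       # never more than k entries

bucket = KBucket(0, 2 ** 384)
for peer in range(K):
    bucket.add_contact(peer)
bucket.add_contact(0)  # bucket is full, but 0 is known: moved to tail, no BucketFull
assert bucket.get_contacts(-1) == [1, 2, 3, 4, 5, 6, 7, 0]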
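The rewritten test_node.py drives hours of protocol time (await advance(3600)) in milliseconds of wall-clock time through tests.dht_mocks.get_time_accelerator. One way such an accelerator can work, offered strictly as a rough sketch (the real helper in tests/dht_mocks.py may be implemented differently): point loop.time at a fake clock so call_later deadlines resolve against it, then step the fake clock forward while yielding to the loop so due timers fire.

import asyncio

def get_time_accelerator(loop: asyncio.AbstractEventLoop, start_time: float):
    # Timers registered via loop.call_later compare their deadlines against
    # loop.time(); shadowing it on the (pure-Python) loop instance lets the
    # test move time by hand. This may not work on C-implemented loops.
    fake_time = start_time
    loop.time = lambda: fake_time

    async def advance(seconds: float, step: float = 0.1):
        nonlocal fake_time
        target = fake_time + seconds
        while fake_time < target:
            fake_time = min(fake_time + step, target)
            # yield to the loop so timers that just became due are moved to
            # the ready queue and executed before time moves again
            for _ in range(2):
                await asyncio.sleep(0)
    return advance

# usage mirroring the test: advance = get_time_accelerator(loop, loop.time());
# await advance(3600) runs an hour of scheduled callbacks almost instantly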
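The renamed test_peer.py expects PeerManager.get_kademlia_peer to raise ValueError for an out-of-range port, a malformed address, or a wrong-length node id, and compact_ip() to pack the address into four bytes. A sketch of the checks those assertions imply; the helper names here are assumptions for illustration, not the library's API.

import ipaddress
import socket

def validate_peer_args(node_id: bytes, address: str, udp_port: int) -> None:
    if len(node_id) != 48:  # lbrynet node ids are 384-bit (48-byte) hashes
        raise ValueError(f"invalid node id length: {len(node_id)}")
    # AddressValueError subclasses ValueError, so '192.168.1.20.1' and
    # 'this is not an ip' both fail the way the test expects
    ipaddress.IPv4Address(address)
    if not 0 < udp_port < 65536:  # rejects both 100000 and -1000
        raise ValueError(f"invalid udp port: {udp_port}")

def compact_ip(address: str) -> bytes:
    # '127.0.0.1' -> b'\x7f\x00\x00\x01', as asserted in test_compact_ip
    return socket.inet_aton(address)

assert compact_ip('192.168.0.1') == b'\xc0\xa8\x00\x01'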
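Finally, the deleted test_distance fixes the Kademlia metric: the distance between two node ids is their bytewise XOR read as one big integer, so 0xaa against 0x55 yields 0xff in every byte position, which is the 48-byte all-ones value the test expects. A minimal sketch of that metric:

class Distance:
    def __init__(self, key: bytes):
        # cache the integer form of the reference key, as the real class does
        self.val_key_one = int.from_bytes(key, 'big')

    def __call__(self, key_two: bytes) -> int:
        # XOR of the two ids, interpreted as an unsigned big-endian integer
        return self.val_key_one ^ int.from_bytes(key_two, 'big')

d = Distance(bytes((170,) * 48))   # 0xaa repeated
assert d(bytes((85,) * 48)) == int.from_bytes(bytes((255,) * 48), 'big')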