async lbrynet.dht

Jack Robison 2019-01-22 12:49:43 -05:00 committed by Lex Berezhny
parent a5524d490c
commit 2fa5233796
49 changed files with 3086 additions and 3577 deletions


@ -1,4 +1,8 @@
import hashlib
from cryptography.hazmat.backends import default_backend
backend = default_backend()
def get_lbry_hash_obj():


@ -1,7 +0,0 @@
Francois Aucamp <faucamp@csir.co.za>
Thanks goes to the following people for providing patches/suggestions/tests:
Neil Kleynhans <ntkleynhans@csir.co.za>
Haiyang Ma <haiyang.ma@maidsafe.net>
Bryan McAlister <bmcalister@csir.co.za>


@ -1,165 +0,0 @@
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.


@ -0,0 +1,65 @@
import asyncio
import typing
import logging
if typing.TYPE_CHECKING:
    from lbrynet.dht.node import Node
    from lbrynet.extras.daemon.storage import SQLiteStorage

log = logging.getLogger(__name__)


class BlobAnnouncer:
    def __init__(self, loop: asyncio.BaseEventLoop, node: 'Node', storage: 'SQLiteStorage'):
        self.loop = loop
        self.node = node
        self.storage = storage
        self.pending_call: asyncio.Handle = None
        self.announce_task: asyncio.Task = None
        self.running = False
        self.announce_queue: typing.List[str] = []

    async def _announce(self, batch_size: typing.Optional[int] = 10):
        if not self.node.joined.is_set():
            await self.node.joined.wait()
        blob_hashes = await self.storage.get_blobs_to_announce()
        if blob_hashes:
            self.announce_queue.extend(blob_hashes)
            log.info("%i blobs to announce", len(blob_hashes))
            batch = []
            while len(self.announce_queue):
                cnt = 0
                announced = []
                while self.announce_queue and cnt < batch_size:
                    blob_hash = self.announce_queue.pop()
                    announced.append(blob_hash)
                    batch.append(self.node.announce_blob(blob_hash))
                    cnt += 1
                to_await = []
                while batch:
                    to_await.append(batch.pop())
                if to_await:
                    await asyncio.gather(*tuple(to_await), loop=self.loop)
                    await self.storage.update_last_announced_blobs(announced, self.loop.time())
                    log.info("announced %i blobs", len(announced))
        if self.running:
            self.pending_call = self.loop.call_later(60, self.announce, batch_size)

    def announce(self, batch_size: typing.Optional[int] = 10):
        self.announce_task = self.loop.create_task(self._announce(batch_size))

    def start(self, batch_size: typing.Optional[int] = 10):
        if self.running:
            raise Exception("already running")
        self.running = True
        self.announce(batch_size)

    def stop(self):
        self.running = False
        if self.pending_call:
            if not self.pending_call.cancelled():
                self.pending_call.cancel()
            self.pending_call = None
        if self.announce_task:
            if not (self.announce_task.done() or self.announce_task.cancelled()):
                self.announce_task.cancel()
            self.announce_task = None
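
Review note: the batching loop in `_announce` can be exercised on its own. The following is a minimal sketch, not the code from this commit. It assumes a hypothetical `fake_announce` coroutine in place of `node.announce_blob` and leaves out the storage update and the 60-second rescheduling.

```python
import asyncio
import typing


async def fake_announce(blob_hash: str) -> str:
    # Hypothetical stand-in for node.announce_blob(); it only yields to the event loop.
    await asyncio.sleep(0)
    return blob_hash


async def announce_in_batches(queue: typing.List[str],
                              batch_size: int = 10) -> typing.List[typing.List[str]]:
    # Drain the queue from the end, awaiting at most batch_size announcements at a time.
    batches = []
    while queue:
        batch = [queue.pop() for _ in range(min(batch_size, len(queue)))]
        await asyncio.gather(*(fake_announce(blob_hash) for blob_hash in batch))
        batches.append(batch)
    return batches


def run_demo() -> typing.List[typing.List[str]]:
    # Five hashes with a batch size of two are announced as three batches.
    return asyncio.run(announce_in_batches(["a", "b", "c", "d", "e"], batch_size=2))
```

Because the queue is consumed with `pop()`, hashes are announced in reverse insertion order, which matches the behaviour of the loop in the diff.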


@ -1,81 +0,0 @@
import logging

log = logging.getLogger()

MIN_DELAY = 0.0
MAX_DELAY = 0.01
DELAY_INCREMENT = 0.0001
QUEUE_SIZE_THRESHOLD = 100


class CallLaterManager:
    def __init__(self, callLater):
        """
        :param callLater: (IReactorTime.callLater)
        """
        self._callLater = callLater
        self._pendingCallLaters = []
        self._delay = MIN_DELAY

    def get_min_delay(self):
        self._pendingCallLaters = [cl for cl in self._pendingCallLaters if cl.active()]
        queue_size = len(self._pendingCallLaters)
        if queue_size > QUEUE_SIZE_THRESHOLD:
            self._delay = min((self._delay + DELAY_INCREMENT), MAX_DELAY)
        else:
            self._delay = max((self._delay - 2.0 * DELAY_INCREMENT), MIN_DELAY)
        return self._delay

    def _cancel(self, call_later):
        """
        :param call_later: DelayedCall
        :return: (callable) canceller function
        """

        def cancel(reason=None):
            """
            :param reason: reason for cancellation, this is returned after cancelling the DelayedCall
            :return: reason
            """
            if call_later.active():
                call_later.cancel()
            if call_later in self._pendingCallLaters:
                self._pendingCallLaters.remove(call_later)
            return reason
        return cancel

    def stop(self):
        """
        Cancel any callLaters that are still running
        """
        from twisted.internet import defer
        while self._pendingCallLaters:
            canceller = self._cancel(self._pendingCallLaters[0])
            try:
                canceller()
            except (defer.CancelledError, defer.AlreadyCalledError, ValueError):
                pass

    def call_later(self, when, what, *args, **kwargs):
        """
        Schedule a call later and get a canceller callback function

        :param when: (float) delay in seconds
        :param what: (callable)
        :param args: (*tuple) args to be passed to the callable
        :param kwargs: (**dict) kwargs to be passed to the callable
        :return: (tuple) twisted.internet.base.DelayedCall object, canceller function
        """
        call_later = self._callLater(when, what, *args, **kwargs)
        canceller = self._cancel(call_later)
        self._pendingCallLaters.append(call_later)
        return call_later, canceller

    def call_soon(self, what, *args, **kwargs):
        delay = self.get_min_delay()
        return self.call_later(delay, what, *args, **kwargs)
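
Review note: the back-off rule that `get_min_delay` applied in the removed manager is easier to read as a pure function. This is a sketch for illustration only. The name `next_delay` is hypothetical, and the constants are copied from the removed module.

```python
MIN_DELAY = 0.0
MAX_DELAY = 0.01
DELAY_INCREMENT = 0.0001
QUEUE_SIZE_THRESHOLD = 100


def next_delay(current: float, queue_size: int) -> float:
    # Grow the delay by one increment while the queue is over the threshold;
    # otherwise shrink it by two increments, clamped to [MIN_DELAY, MAX_DELAY].
    if queue_size > QUEUE_SIZE_THRESHOLD:
        return min(current + DELAY_INCREMENT, MAX_DELAY)
    return max(current - 2.0 * DELAY_INCREMENT, MIN_DELAY)
```

The delay recovers twice as fast as it grows, so a brief burst of pending calls only slows scheduling for a short time.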


@ -1,61 +1,40 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
#
# The docstrings in this module contain epytext markup; API documentation
# may be created by processing this file with epydoc: http://epydoc.sf.net
import hashlib
import os
""" This module defines the characterizing constants of the Kademlia network
C{checkRefreshInterval} and C{udpDatagramMaxSize} are implementation-specific
constants, and do not affect general Kademlia operation.
"""
######### KADEMLIA CONSTANTS ###########
#: Small number Representing the degree of parallelism in network calls
alpha = 3
#: Maximum number of contacts stored in a bucket; this should be an even number
hash_class = hashlib.sha384
hash_length = hash_class().digest_size
hash_bits = hash_length * 8
alpha = 5
k = 8
#: Maximum number of contacts stored in the replacement cache
replacementCacheSize = 8
#: Timeout for network operations (in seconds)
rpcTimeout = 5
# number of rpc attempts to make before a timeout results in the node being removed as a contact
rpcAttempts = 5
# time window to count failures (in seconds)
rpcAttemptsPruningTimeWindow = 600
# Delay between iterations of iterative node lookups (for loose parallelism) (in seconds)
iterativeLookupDelay = rpcTimeout / 2
#: If a k-bucket has not been used for this amount of time, refresh it (in seconds)
refreshTimeout = 3600 # 1 hour
#: The interval at which nodes replicate (republish/refresh) data they are holding
replicateInterval = refreshTimeout
# The time it takes for data to expire in the network; the original publisher of the data
# will also republish the data at this time if it is still valid
dataExpireTimeout = 86400 # 24 hours
tokenSecretChangeInterval = 300 # 5 minutes
######## IMPLEMENTATION-SPECIFIC CONSTANTS ###########
#: The interval for the node to check whether any buckets need refreshing
checkRefreshInterval = refreshTimeout / 5
#: Max size of a single UDP datagram, in bytes. If a message is larger than this, it will
#: be spread across several UDP packets.
udpDatagramMaxSize = 8192 # 8 KB
key_bits = 384
replacement_cache_size = 8
rpc_timeout = 5.0
rpc_attempts = 5
rpc_attempts_pruning_window = 600
iterative_lookup_delay = rpc_timeout / 2.0
refresh_interval = 3600 # 1 hour
replicate_interval = refresh_interval
data_expiration = 86400 # 24 hours
token_secret_refresh_interval = 300 # 5 minutes
check_refresh_interval = refresh_interval / 5
max_datagram_size = 8192 # 8 KB
rpc_id_length = 20
protocol_version = 1
bottom_out_limit = 3
msg_size_limit = max_datagram_size - 26
protocolVersion = 1
def digest(data: bytes) -> bytes:
h = hash_class()
h.update(data)
return h.digest()
def generate_id(num=None) -> bytes:
if num is not None:
return digest(str(num).encode())
else:
return digest(os.urandom(32))
def generate_rpc_id(num=None) -> bytes:
return generate_id(num)[:rpc_id_length]
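
Review note: a quick check of the sizes the new helpers produce. This sketch re-declares the helpers locally so that it runs without `lbrynet` installed. It assumes only what the hunk above shows: SHA-384 as `hash_class` and an `rpc_id_length` of 20.

```python
import hashlib
import os

hash_class = hashlib.sha384
rpc_id_length = 20


def digest(data: bytes) -> bytes:
    return hash_class(data).digest()


def generate_id(num=None) -> bytes:
    # A fixed num gives a reproducible id; otherwise 32 random bytes are hashed.
    return digest(str(num).encode() if num is not None else os.urandom(32))


node_id = generate_id(1)                  # 48 bytes, i.e. 384 bits
rpc_id = generate_id(1)[:rpc_id_length]   # truncated to 20 bytes
```

Passing an integer is convenient in tests because the resulting id is deterministic, while the default path yields a fresh random id on each call.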


@ -1,189 +0,0 @@
import ipaddress
from binascii import hexlify
from functools import reduce
from lbrynet.dht import constants
def is_valid_ipv4(address):
try:
ip = ipaddress.ip_address(address)
return ip.version == 4
except ValueError:
return False
class _Contact:
""" Encapsulation for remote contact
This class contains information on a single remote contact, and also
provides a direct RPC API to the remote node which it represents
"""
def __init__(self, contactManager, id, ipAddress, udpPort, networkProtocol, firstComm):
if id is not None:
if not len(id) == constants.key_bits // 8:
raise ValueError("invalid node id: {}".format(hexlify(id).decode()))
if not 0 <= udpPort <= 65535:
raise ValueError("invalid port")
if not is_valid_ipv4(ipAddress):
raise ValueError("invalid ip address")
self._contactManager = contactManager
self._id = id
self.address = ipAddress
self.port = udpPort
self._networkProtocol = networkProtocol
self.commTime = firstComm
self.getTime = self._contactManager._get_time
self.lastReplied = None
self.lastRequested = None
self.protocolVersion = 0
self._token = (None, 0) # token, timestamp
def update_token(self, token):
self._token = token, self.getTime() if token else 0
@property
def token(self):
# expire the token 1 minute early to be safe
return self._token[0] if self._token[1] + 240 > self.getTime() else None
@property
def lastInteracted(self):
return max(self.lastRequested or 0, self.lastReplied or 0, self.lastFailed or 0)
@property
def id(self):
return self._id
def log_id(self, short=True):
if not self.id:
return "not initialized"
id_hex = hexlify(self.id)
return id_hex if not short else id_hex[:8]
@property
def failedRPCs(self):
return len(self.failures)
@property
def lastFailed(self):
return (self.failures or [None])[-1]
@property
def failures(self):
return self._contactManager._rpc_failures.get((self.address, self.port), [])
@property
def contact_is_good(self):
"""
:return: False if contact is bad, None if contact is unknown, or True if contact is good
"""
failures = self.failures
now = self.getTime()
delay = constants.checkRefreshInterval
if failures:
if self.lastReplied and len(failures) >= 2 and self.lastReplied < failures[-2]:
return False
elif self.lastReplied and len(failures) >= 2 and self.lastReplied > failures[-2]:
pass # handled below
elif len(failures) >= 2:
return False
if self.lastReplied and self.lastReplied > now - delay:
return True
if self.lastReplied and self.lastRequested and self.lastRequested > now - delay:
return True
return None
def __eq__(self, other):
if not isinstance(other, _Contact):
raise TypeError("invalid type to compare with Contact: %s" % str(type(other)))
return (self.id, self.address, self.port) == (other.id, other.address, other.port)
def __hash__(self):
return hash((self.id, self.address, self.port))
def compact_ip(self):
compact_ip = reduce(
lambda buff, x: buff + bytearray([int(x)]), self.address.split('.'), bytearray())
return compact_ip
def set_id(self, id):
if not self._id:
self._id = id
def update_last_replied(self):
self.lastReplied = int(self.getTime())
def update_last_requested(self):
self.lastRequested = int(self.getTime())
def update_last_failed(self):
failures = self._contactManager._rpc_failures.get((self.address, self.port), [])
failures.append(self.getTime())
self._contactManager._rpc_failures[(self.address, self.port)] = failures
def update_protocol_version(self, version):
self.protocolVersion = version
def __str__(self):
return '<%s.%s object; IP address: %s, UDP port: %d>' % (
self.__module__, self.__class__.__name__, self.address, self.port)
def __getattr__(self, name):
""" This override allows the host node to call a method of the remote
node (i.e. this contact) as if it was a local function.
For instance, if C{remoteNode} is an instance of C{Contact}, the
following will call C{remoteNode}'s C{test()} method with
argument C{123}::
remoteNode.test(123)
Such a RPC method call will return a Deferred, which will callback
when the contact responds with the result (or an error occurs).
This happens via this contact's C{_networkProtocol} object (i.e. the
host Node's C{_protocol} object).
"""
if name not in ['ping', 'findValue', 'findNode', 'store']:
raise AttributeError("unknown command: %s" % name)
def _sendRPC(*args, **kwargs):
return self._networkProtocol.sendRPC(self, name.encode(), args)
return _sendRPC
class ContactManager:
def __init__(self, get_time=None):
if not get_time:
from twisted.internet import reactor
get_time = reactor.seconds
self._get_time = get_time
self._contacts = {}
self._rpc_failures = {}
def get_contact(self, id, address, port):
for contact in self._contacts.values():
if contact.id == id and contact.address == address and contact.port == port:
return contact
def make_contact(self, id, ipAddress, udpPort, networkProtocol, firstComm=0):
contact = self.get_contact(id, ipAddress, udpPort)
if contact:
return contact
contact = _Contact(self, id, ipAddress, udpPort, networkProtocol, firstComm or self._get_time())
self._contacts[(id, ipAddress, udpPort)] = contact
return contact
def is_ignored(self, origin_tuple):
failed_rpc_count = len(self._prune_failures(origin_tuple))
return failed_rpc_count > constants.rpcAttempts
def _prune_failures(self, origin_tuple):
# Prunes recorded failures to the last time window of attempts
pruning_limit = self._get_time() - constants.rpcAttemptsPruningTimeWindow
pruned = list(filter(lambda t: t >= pruning_limit, self._rpc_failures.get(origin_tuple, [])))
self._rpc_failures[origin_tuple] = pruned
return pruned
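
Review note: `compact_ip` above and `expand_peer` in the removed iterative-find module are two halves of the same compact peer encoding. The sketch below round-trips one peer. `compact_peer` is a hypothetical helper, and its layout (4-byte IPv4 address, 2-byte big-endian port, then the node id) is inferred from `expand_peer`, not confirmed elsewhere in this diff.

```python
import typing


def compact_ip(address: str) -> bytes:
    # Each dotted-quad octet becomes one byte.
    return bytes(int(octet) for octet in address.split('.'))


def compact_peer(node_id: bytes, address: str, port: int) -> bytes:
    # Hypothetical encoder; layout inferred from expand_peer.
    return compact_ip(address) + port.to_bytes(2, 'big') + node_id


def expand_peer(compact_peer_info: bytes) -> typing.Tuple[bytes, str, int]:
    host = "{}.{}.{}.{}".format(*compact_peer_info[:4])
    port = int.from_bytes(compact_peer_info[4:6], 'big')
    return compact_peer_info[6:], host, port


peer = compact_peer(b'\x01' * 48, "10.0.0.1", 4444)
```

With a 48-byte node id the compact form is 54 bytes long.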


@ -1,66 +0,0 @@
from collections import UserDict
from lbrynet.dht import constants


class DictDataStore(UserDict):
    """ A datastore using an in-memory Python dictionary """
    #implements(IDataStore)

    def __init__(self, getTime=None):
        # Dictionary format:
        # { <key>: (<contact>, <value>, <lastPublished>, <originallyPublished>, <originalPublisherID>) }
        super().__init__()
        if not getTime:
            from twisted.internet import reactor
            getTime = reactor.seconds
        self._getTime = getTime
        self.completed_blobs = set()

    def filter_bad_and_expired_peers(self, key):
        """
        Returns only non-expired and unknown/good peers
        """
        return filter(
            lambda peer:
            self._getTime() - peer[3] < constants.dataExpireTimeout and peer[0].contact_is_good is not False,
            self[key]
        )

    def filter_expired_peers(self, key):
        """
        Returns only non-expired peers
        """
        return filter(lambda peer: self._getTime() - peer[3] < constants.dataExpireTimeout, self[key])

    def removeExpiredPeers(self):
        expired_keys = []
        for key in self.keys():
            unexpired_peers = list(self.filter_expired_peers(key))
            if not unexpired_peers:
                expired_keys.append(key)
            else:
                self[key] = unexpired_peers
        for key in expired_keys:
            del self[key]

    def hasPeersForBlob(self, key):
        return bool(key in self and len(tuple(self.filter_bad_and_expired_peers(key))))

    def addPeerToBlob(self, contact, key, compact_address, lastPublished, originallyPublished, originalPublisherID):
        if key in self:
            if compact_address not in map(lambda store_tuple: store_tuple[1], self[key]):
                self[key].append(
                    (contact, compact_address, lastPublished, originallyPublished, originalPublisherID)
                )
        else:
            self[key] = [(contact, compact_address, lastPublished, originallyPublished, originalPublisherID)]

    def getPeersForBlob(self, key):
        return [] if key not in self else [val[1] for val in self.filter_bad_and_expired_peers(key)]

    def getStoringContacts(self):
        contacts = set()
        for key in self:
            for values in self[key]:
                contacts.add(values[0])
        return list(contacts)
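
Review note: the expiry test used by both filters in the removed data store compares the current time against index 3 of each stored tuple, the original publish time. A minimal sketch of that check follows. The names `unexpired` and `DATA_EXPIRATION` are hypothetical, and the 86400-second window is taken from the constants module.

```python
import typing

DATA_EXPIRATION = 86400  # seconds (24 hours), as in the constants module

# (contact, compact_address, last_published, originally_published, original_publisher_id)
StoreTuple = typing.Tuple[typing.Any, bytes, float, float, bytes]


def unexpired(peers: typing.List[StoreTuple], now: float) -> typing.List[StoreTuple]:
    # Keep a peer only while less than DATA_EXPIRATION seconds have passed
    # since its originally_published timestamp (index 3).
    return [peer for peer in peers if now - peer[3] < DATA_EXPIRATION]


peers = [(None, b'a', 0.0, 0.0, b'x'), (None, b'b', 0.0, 50000.0, b'y')]
```

Passing `now` explicitly, rather than reading a clock inside the function, keeps the check easy to test.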


@ -1,43 +1,23 @@
import binascii
#import exceptions
# this is a dict of {"exceptions.<exception class name>": exception class} items used to raise
# remote built-in exceptions locally
BUILTIN_EXCEPTIONS = {
# "exceptions.%s" % e: getattr(exceptions, e) for e in dir(exceptions) if not e.startswith("_")
}
class BaseKademliaException(Exception):
pass
class DecodeError(Exception):
class DecodeError(BaseKademliaException):
"""
Should be raised by an C{Encoding} implementation if decode operation
fails
"""
class BucketFull(Exception):
class BucketFull(BaseKademliaException):
"""
Raised when the bucket is full
"""
class UnknownRemoteException(Exception):
class RemoteException(BaseKademliaException):
pass
class TimeoutError(Exception):
""" Raised when a RPC times out """
def __init__(self, remote_contact_id):
# remote_contact_id is a binary blob so we need to convert it
# into something more readable
if remote_contact_id:
msg = 'Timeout connecting to {}'.format(binascii.hexlify(remote_contact_id))
else:
msg = 'Timeout connecting to uninitialized node'
super().__init__(msg)
self.remote_contact_id = remote_contact_id
class TransportNotConnected(Exception):
class TransportNotConnected(BaseKademliaException):
pass
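
Review note: giving the DHT errors a common base class lets callers handle all of them with a single `except` clause. The sketch below re-declares the hierarchy from this hunk so that it runs standalone. `classify` is a hypothetical helper used only to demonstrate the dispatch.

```python
class BaseKademliaException(Exception):
    pass


class DecodeError(BaseKademliaException):
    pass


class BucketFull(BaseKademliaException):
    pass


class RemoteException(BaseKademliaException):
    pass


class TransportNotConnected(BaseKademliaException):
    pass


def classify(exc: Exception) -> str:
    # The first clause catches every DHT-specific error; anything else falls through.
    try:
        raise exc
    except BaseKademliaException as err:
        return "dht:" + type(err).__name__
    except Exception as err:
        return "other:" + type(err).__name__
```

For example, `classify(BucketFull())` takes the first branch, while a plain `ValueError` takes the second.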


@ -1,116 +0,0 @@
from zope.interface import Interface
class IDataStore(Interface):
""" Interface for classes implementing physical storage (for data
published via the "STORE" RPC) for the Kademlia DHT
@note: This provides an interface for a dict-like object
"""
def keys(self):
""" Return a list of the keys in this data store """
def removeExpiredPeers(self):
pass
def hasPeersForBlob(self, key):
pass
def addPeerToBlob(self, key, value, lastPublished, originallyPublished, originalPublisherID):
pass
def getPeersForBlob(self, key):
pass
def removePeer(self, key):
pass
class IRoutingTable(Interface):
""" Interface for routing table implementations
Classes inheriting from this should provide a suitable routing table for
a parent Node object (i.e. the local entity in the Kademlia network)
"""
def __init__(self, parentNodeID):
"""
@param parentNodeID: The n-bit node ID of the node to which this
routing table belongs
@type parentNodeID: str
"""
def addContact(self, contact):
""" Add the given contact to the correct k-bucket; if it already
exists, its status will be updated
@param contact: The contact to add to this node's k-buckets
@type contact: kademlia.contact.Contact
"""
def findCloseNodes(self, key, count, _rpcNodeID=None):
""" Finds a number of known nodes closest to the node/value with the
specified key.
@param key: the n-bit key (i.e. the node or value ID) to search for
@type key: str
@param count: the amount of contacts to return
@type count: int
@param _rpcNodeID: Used during RPC, this is the sender's Node ID
Whatever ID is passed in the parameter will get
excluded from the list of returned contacts.
@type _rpcNodeID: str
@return: A list of node contacts (C{kademlia.contact.Contact} instances)
closest to the specified key.
This method will return C{k} (or C{count}, if specified)
contacts if at all possible; it will only return fewer if the
node is returning all of the contacts that it knows of.
@rtype: list
"""
def getContact(self, contactID):
""" Returns the (known) contact with the specified node ID
@raise ValueError: No contact with the specified contact ID is known
by this node
"""
def getRefreshList(self, startIndex=0, force=False):
""" Finds all k-buckets that need refreshing, starting at the
k-bucket with the specified index, and returns IDs to be searched for
in order to refresh those k-buckets
@param startIndex: The index of the bucket to start refreshing at;
this bucket and those further away from it will
be refreshed. For example, when joining the
network, this node will set this to the index of
the bucket after the one containing its closest
neighbour.
@type startIndex: int
@param force: If this is C{True}, all buckets (in the specified range)
will be refreshed, regardless of the time they were last
accessed.
@type force: bool
@return: A list of node IDs that the parent node should search for
in order to refresh the routing table
@rtype: list
"""
def removeContact(self, contactID):
""" Remove the contact with the specified node ID from the routing
table
@param contactID: The node ID of the contact to remove
@type contactID: str
"""
def touchKBucket(self, key):
""" Update the "last accessed" timestamp of the k-bucket which covers
the range containing the specified key in the key/ID space
@param key: A key in the range of the target k-bucket
@type key: str
"""


@ -1,230 +0,0 @@
import logging
from twisted.internet import defer
from lbrynet.dht.distance import Distance
from lbrynet.dht.error import TimeoutError
from lbrynet.dht import constants
log = logging.getLogger(__name__)
def get_contact(contact_list, node_id, address, port):
for contact in contact_list:
if contact.id == node_id and contact.address == address and contact.port == port:
return contact
raise IndexError(node_id)
def expand_peer(compact_peer_info):
host = "{}.{}.{}.{}".format(*compact_peer_info[:4])
port = int.from_bytes(compact_peer_info[4:6], 'big')
peer_node_id = compact_peer_info[6:]
return (peer_node_id, host, port)
class _IterativeFind:
# TODO: use polymorphism to search for a value or node
# instead of using a find_value flag
def __init__(self, node, shortlist, key, rpc, exclude=None):
self.exclude = set(exclude or [])
self.node = node
self.finished_deferred = defer.Deferred()
# all distance operations in this class only care about the distance
# to self.key, so this makes it easier to calculate those
self.distance = Distance(key)
# The closest known and active node yet found
self.closest_node = None if not shortlist else shortlist[0]
self.prev_closest_node = None
# Shortlist of contact objects (the k closest known contacts to the key from the routing table)
self.shortlist = shortlist
# The search key
self.key = key
# The rpc method name (findValue or findNode)
self.rpc = rpc
# List of active queries; len() indicates number of active probes
self.active_probes = []
# List of contact (address, port) tuples that have already been queried, includes contacts that didn't reply
self.already_contacted = []
# A list of found and known-to-be-active remote nodes (Contact objects)
self.active_contacts = []
# Ensure only one searchIteration call is running at a time
self._search_iteration_semaphore = defer.DeferredSemaphore(1)
self._iteration_count = 0
self.find_value_result = {}
self.pending_iteration_calls = []
@property
def is_find_node_request(self):
return self.rpc == "findNode"
@property
def is_find_value_request(self):
return self.rpc == "findValue"
def is_closer(self, contact):
if not self.closest_node:
return True
return self.distance.is_closer(contact.id, self.closest_node.id)
def getContactTriples(self, result):
if self.is_find_value_request:
contact_triples = result[b'contacts']
else:
contact_triples = result
for contact_tup in contact_triples:
if not isinstance(contact_tup, (list, tuple)) or len(contact_tup) != 3:
raise ValueError("invalid contact triple")
contact_tup[1] = contact_tup[1].decode() # ips are strings
return contact_triples
def sortByDistance(self, contact_list):
"""Sort the list of contacts in order by distance from key"""
contact_list.sort(key=lambda c: self.distance(c.id))
def extendShortlist(self, contact, result):
# The "raw response" tuple contains the response message and the originating address info
originAddress = (contact.address, contact.port)
if self.finished_deferred.called:
return contact.id
if self.node.contact_manager.is_ignored(originAddress):
raise ValueError("contact is ignored")
if contact.id == self.node.node_id:
return contact.id
if contact not in self.active_contacts:
self.active_contacts.append(contact)
if contact not in self.shortlist:
self.shortlist.append(contact)
# Now extend the (unverified) shortlist with the returned contacts
# TODO: some validation on the result (for guarding against attacks)
# If we are looking for a value, first see if this result is the value
# we are looking for before treating it as a list of contact triples
if self.is_find_value_request and self.key in result:
# We have found the value
for peer in result[self.key]:
node_id, host, port = expand_peer(peer)
if (host, port) not in self.exclude:
self.find_value_result.setdefault(self.key, []).append((node_id, host, port))
if self.find_value_result:
self.finished_deferred.callback(self.find_value_result)
else:
if self.is_find_value_request:
# We are looking for a value, and the remote node didn't have it
# - mark it as the closest "empty" node, if it is
# TODO: store to this peer after finding the value as per the kademlia spec
if b'closestNodeNoValue' in self.find_value_result:
if self.is_closer(contact):
self.find_value_result[b'closestNodeNoValue'] = contact
else:
self.find_value_result[b'closestNodeNoValue'] = contact
contactTriples = self.getContactTriples(result)
for contactTriple in contactTriples:
if (contactTriple[1], contactTriple[2]) in ((c.address, c.port) for c in self.already_contacted):
continue
elif self.node.contact_manager.is_ignored((contactTriple[1], contactTriple[2])):
continue
else:
found_contact = self.node.contact_manager.make_contact(contactTriple[0], contactTriple[1],
contactTriple[2], self.node._protocol)
if found_contact not in self.shortlist:
self.shortlist.append(found_contact)
if not self.finished_deferred.called and self.should_stop():
self.sortByDistance(self.active_contacts)
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
return contact.id
@defer.inlineCallbacks
def probeContact(self, contact):
fn = getattr(contact, self.rpc)
try:
response = yield fn(self.key)
result = self.extendShortlist(contact, response)
defer.returnValue(result)
except (TimeoutError, defer.CancelledError, ValueError, IndexError):
defer.returnValue(contact.id)
def should_stop(self):
if self.is_find_value_request:
# search stops when it finds a value, let it run
return False
if self.prev_closest_node and self.closest_node and self.distance.is_closer(self.prev_closest_node.id,
self.closest_node.id):
# we're getting further away
return True
if len(self.active_contacts) >= constants.k:
# we have enough results
return True
return False
# Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts
def _searchIteration(self):
# Sort the discovered active nodes from closest to furthest
if len(self.active_contacts):
self.sortByDistance(self.active_contacts)
self.prev_closest_node = self.closest_node
self.closest_node = self.active_contacts[0]
# Sort the current shortList before contacting other nodes
self.sortByDistance(self.shortlist)
probes = []
already_contacted_addresses = {(c.address, c.port) for c in self.already_contacted}
to_remove = []
for contact in self.shortlist:
if self.node.contact_manager.is_ignored((contact.address, contact.port)):
to_remove.append(contact) # a contact became bad during iteration
continue
if (contact.address, contact.port) not in already_contacted_addresses:
self.already_contacted.append(contact)
to_remove.append(contact)
probe = self.probeContact(contact)
probes.append(probe)
self.active_probes.append(probe)
if len(probes) == constants.alpha:
break
for contact in to_remove: # these contacts will be re-added to the shortlist when they reply successfully
self.shortlist.remove(contact)
# run the probes
if probes:
# Schedule the next iteration if there are any active
# calls (Kademlia uses loose parallelism)
self.searchIteration()
d = defer.DeferredList(probes, consumeErrors=True)
def _remove_probes(results):
for probe in probes:
self.active_probes.remove(probe)
return results
d.addCallback(_remove_probes)
elif not self.finished_deferred.called and not self.active_probes or self.should_stop():
# If no probes were sent, there will not be any improvement, so we're done
if self.is_find_value_request:
self.finished_deferred.callback(self.find_value_result)
else:
self.sortByDistance(self.active_contacts)
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
elif not self.finished_deferred.called:
# Force the next iteration
self.searchIteration()
def searchIteration(self, delay=constants.iterativeLookupDelay):
def _cancel_pending_iterations(result):
while self.pending_iteration_calls:
canceller = self.pending_iteration_calls.pop()
canceller()
return result
self.finished_deferred.addBoth(_cancel_pending_iterations)
self._iteration_count += 1
call, cancel = self.node.reactor_callLater(delay, self._search_iteration_semaphore.run, self._searchIteration)
self.pending_iteration_calls.append(cancel)
def iterativeFind(node, shortlist, key, rpc, exclude=None):
helper = _IterativeFind(node, shortlist, key, rpc, exclude)
helper.searchIteration(0)
return helper.finished_deferred
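
The selection step in `_searchIteration` above (sort the shortlist by distance to the key, skip contacts already queried, stop at `constants.alpha` probes) can be illustrated in isolation. This is a minimal sketch rather than the code from this commit: `xor_distance`, `next_probes`, and the `ALPHA` value are hypothetical stand-ins, and contacts are reduced to bare id bytes.

```python
ALPHA = 3  # assumed parallelism factor; the original reads constants.alpha


def xor_distance(a: bytes, b: bytes) -> int:
    """Kademlia distance: the XOR of two equal-length ids, read as an integer."""
    return int.from_bytes(a, 'big') ^ int.from_bytes(b, 'big')


def next_probes(shortlist, already_contacted, key, alpha=ALPHA):
    """Pick up to `alpha` not-yet-contacted ids, closest to `key` first."""
    candidates = sorted(shortlist, key=lambda node_id: xor_distance(node_id, key))
    picked = []
    for node_id in candidates:
        if node_id in already_contacted:
            continue  # never probe the same contact twice
        picked.append(node_id)
        if len(picked) == alpha:
            break
    return picked


key = bytes([0b0000])
shortlist = [bytes([0b1000]), bytes([0b0001]), bytes([0b0100]), bytes([0b0010])]
# The closest id (0b0001) was already contacted, so the next two are chosen.
probes = next_probes(shortlist, {bytes([0b0001])}, key, alpha=2)
```

With a one-byte id space the ordering is easy to check by hand: the distances to the key are 8, 1, 4 and 2, so the two uncontacted ids picked are `0b0010` and `0b0100`.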

@ -1,147 +0,0 @@
import logging
from lbrynet.dht import constants
from lbrynet.dht.distance import Distance
from lbrynet.dht.error import BucketFull
log = logging.getLogger(__name__)
class KBucket:
""" Description - later
"""
def __init__(self, rangeMin, rangeMax, node_id):
"""
@param rangeMin: The lower boundary for the range in the n-bit ID
space covered by this k-bucket
@param rangeMax: The upper boundary for the range in the ID space
covered by this k-bucket
"""
self.lastAccessed = 0
self.rangeMin = rangeMin
self.rangeMax = rangeMax
self._contacts = list()
self._node_id = node_id
def addContact(self, contact):
""" Add contact to _contact list in the right order. This will move the
contact to the end of the k-bucket if it is already present.
@raise kademlia.kbucket.BucketFull: Raised when the bucket is full and
the contact isn't in the bucket
already
@param contact: The contact to add
@type contact: dht.contact._Contact
"""
if contact in self._contacts:
# Move the existing contact to the end of the list
# - using the new contact to allow add-on data
# (e.g. optimization-specific stuff) to be updated as well
self._contacts.remove(contact)
self._contacts.append(contact)
elif len(self._contacts) < constants.k:
self._contacts.append(contact)
else:
raise BucketFull("No space in bucket to insert contact")
def getContact(self, contactID):
"""Get the contact with the specified node ID
@raise IndexError: raised if the contact is not in the bucket
@param contactID: the node id of the contact to retrieve
@type contactID: str
@rtype: dht.contact._Contact
"""
for contact in self._contacts:
if contact.id == contactID:
return contact
raise IndexError(contactID)
def getContacts(self, count=-1, excludeContact=None, sort_distance_to=None):
""" Returns a list containing up to the first count number of contacts
@param count: The amount of contacts to return (if 0 or less, return
all contacts)
@type count: int
@param excludeContact: A node id to exclude; if this contact is in
the list of returned values, it will be
discarded before returning. If a C{str} is
passed as this argument, it must be the
contact's ID.
@type excludeContact: str
@param sort_distance_to: Sort distance to the id, defaulting to the parent node id. If False don't
sort the contacts
@raise IndexError: If the number of requested contacts is too large
@return: Return up to the first count number of contacts in a list
If no contacts are present an empty list is returned
@rtype: list
"""
contacts = [contact for contact in self._contacts if contact.id != excludeContact]
# Return all contacts in bucket
if count <= 0:
count = len(contacts)
# Get current contact number
currentLen = len(contacts)
# If count greater than k - return only k contacts
if count > constants.k:
count = constants.k
if not currentLen:
return contacts
if sort_distance_to is False:
pass
else:
sort_distance_to = sort_distance_to or self._node_id
contacts.sort(key=lambda c: Distance(sort_distance_to)(c.id))
return contacts[:min(currentLen, count)]
def getBadOrUnknownContacts(self):
contacts = self.getContacts(sort_distance_to=False)
results = [contact for contact in contacts if contact.contact_is_good is False]
results.extend(contact for contact in contacts if contact.contact_is_good is None)
return results
def removeContact(self, contact):
""" Remove the contact from the bucket
@param contact: The contact to remove
@type contact: dht.contact._Contact
@raise ValueError: The specified contact is not in this bucket
"""
self._contacts.remove(contact)
def keyInRange(self, key):
""" Tests whether the specified key (i.e. node ID) is in the range
of the n-bit ID space covered by this k-bucket (in other words, it
returns whether or not the specified key should be placed in this
k-bucket)
@param key: The key to test
@type key: str or int
@return: C{True} if the key is in this k-bucket's range, or C{False}
if not.
@rtype: bool
"""
if isinstance(key, bytes):
key = int.from_bytes(key, 'big')
return self.rangeMin <= key < self.rangeMax
def __len__(self):
return len(self._contacts)
def __contains__(self, item):
return item in self._contacts
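
The two behaviours of the removed `KBucket` that matter most to callers are the move-to-end rule in `addContact` and the half-open range test in `keyInRange`. The following is a reduced sketch under stated assumptions: `SketchBucket` and `K` are hypothetical names, contacts are plain bytes, and the capacity of 8 is only illustrative (the original takes it from `constants.k`).

```python
K = 8  # assumed bucket capacity; the original reads constants.k


class BucketFull(Exception):
    """Raised when a new contact does not fit in the bucket."""


class SketchBucket:
    def __init__(self, range_min: int, range_max: int):
        self.range_min = range_min
        self.range_max = range_max
        self.contacts = []  # least recently seen first, most recent last

    def key_in_range(self, key) -> bool:
        # ids arrive as bytes; compare them as big-endian integers
        if isinstance(key, bytes):
            key = int.from_bytes(key, 'big')
        return self.range_min <= key < self.range_max

    def add(self, contact) -> None:
        if contact in self.contacts:
            # a known contact is moved to the end of the list
            self.contacts.remove(contact)
            self.contacts.append(contact)
        elif len(self.contacts) < K:
            self.contacts.append(contact)
        else:
            raise BucketFull("No space in bucket to insert contact")


bucket = SketchBucket(0, 256)
for contact in (b'\x01', b'\x02', b'\x01'):
    bucket.add(contact)
# b'\x01' was seen again, so it now sits behind b'\x02'
```

Keeping the most recently seen contact at the end means the head of the list is always the best candidate to ping or evict when the bucket fills up.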

@ -1,90 +0,0 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
#
# The docstrings in this module contain epytext markup; API documentation
# may be created by processing this file with epydoc: http://epydoc.sf.net
from lbrynet.dht import msgtypes
class MessageTranslator:
""" Interface for RPC message translators/formatters
Classes inheriting from this should provide translation services between
the classes used internally by this Kademlia implementation and the actual
data that is transmitted between nodes.
"""
def fromPrimitive(self, msgPrimitive):
""" Create an RPC Message from a message's string representation
@param msgPrimitive: The unencoded primitive representation of a message
@type msgPrimitive: str, int, list or dict
@return: The translated message object
@rtype: entangled.kademlia.msgtypes.Message
"""
def toPrimitive(self, message):
""" Create a string representation of a message
@param message: The message object
@type message: msgtypes.Message
@return: The message's primitive representation in a particular
messaging format
@rtype: str, int, list or dict
"""
class DefaultFormat(MessageTranslator):
""" The default on-the-wire message format for this library """
typeRequest, typeResponse, typeError = range(3)
headerType, headerMsgID, headerNodeID, headerPayload, headerArgs = range(5) # TODO: make str
@staticmethod
def get(primitive, key):
try:
return primitive[key]
except KeyError:
return primitive[str(key)] # TODO: switch to int()
def fromPrimitive(self, msgPrimitive):
msgType = self.get(msgPrimitive, self.headerType)
if msgType == self.typeRequest:
msg = msgtypes.RequestMessage(self.get(msgPrimitive, self.headerNodeID),
self.get(msgPrimitive, self.headerPayload),
self.get(msgPrimitive, self.headerArgs),
self.get(msgPrimitive, self.headerMsgID))
elif msgType == self.typeResponse:
msg = msgtypes.ResponseMessage(self.get(msgPrimitive, self.headerMsgID),
self.get(msgPrimitive, self.headerNodeID),
self.get(msgPrimitive, self.headerPayload))
elif msgType == self.typeError:
msg = msgtypes.ErrorMessage(self.get(msgPrimitive, self.headerMsgID),
self.get(msgPrimitive, self.headerNodeID),
self.get(msgPrimitive, self.headerPayload),
self.get(msgPrimitive, self.headerArgs))
else:
# Unknown message, no payload
msg = msgtypes.Message(msgPrimitive[self.headerMsgID], msgPrimitive[self.headerNodeID])
return msg
def toPrimitive(self, message):
msg = {self.headerMsgID: message.id,
self.headerNodeID: message.nodeID}
if isinstance(message, msgtypes.RequestMessage):
msg[self.headerType] = self.typeRequest
msg[self.headerPayload] = message.request
msg[self.headerArgs] = message.args
elif isinstance(message, msgtypes.ErrorMessage):
msg[self.headerType] = self.typeError
msg[self.headerPayload] = message.exceptionType
msg[self.headerArgs] = message.response
elif isinstance(message, msgtypes.ResponseMessage):
msg[self.headerType] = self.typeResponse
msg[self.headerPayload] = message.response
return msg
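
The removed `DefaultFormat` keys every message dict by small integers and, in `get`, retries with the string form of the key when the integer lookup fails. A minimal sketch of that lookup follows; `request_to_primitive` is a hypothetical helper written for illustration and the sample ids are placeholders, not values from the commit.

```python
# Header positions and message types mirror DefaultFormat in the diff above.
HEADER_TYPE, HEADER_MSG_ID, HEADER_NODE_ID, HEADER_PAYLOAD, HEADER_ARGS = range(5)
TYPE_REQUEST, TYPE_RESPONSE, TYPE_ERROR = range(3)


def get(primitive: dict, key: int):
    """Look a header up by its integer key, falling back to the string form."""
    try:
        return primitive[key]
    except KeyError:
        return primitive[str(key)]


def request_to_primitive(rpc_id: bytes, node_id: bytes, method: bytes, args: list) -> dict:
    """Hypothetical helper: the dict shape toPrimitive builds for a request."""
    return {
        HEADER_TYPE: TYPE_REQUEST,
        HEADER_MSG_ID: rpc_id,
        HEADER_NODE_ID: node_id,
        HEADER_PAYLOAD: method,
        HEADER_ARGS: args,
    }


primitive = request_to_primitive(b'r' * 20, b'n' * 48, b'ping', [])
# The same message as a peer might send it with string keys.
stringly = {str(k): v for k, v in primitive.items()}
```

Both `get(primitive, HEADER_PAYLOAD)` and `get(stringly, HEADER_PAYLOAD)` resolve to the same payload, which is what lets `fromPrimitive` accept either key style.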

@ -1,52 +0,0 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
#
# The docstrings in this module contain epytext markup; API documentation
# may be created by processing this file with epydoc: http://epydoc.sf.net
from lbrynet.utils import generate_id
from lbrynet.dht import constants
class Message:
""" Base class for messages - all "unknown" messages use this class """
def __init__(self, rpcID, nodeID):
if len(rpcID) != constants.rpc_id_length:
raise ValueError("invalid rpc id: %i bytes (expected 20)" % len(rpcID))
if len(nodeID) != constants.key_bits // 8:
raise ValueError("invalid node id: %i bytes (expected 48)" % len(nodeID))
self.id = rpcID
self.nodeID = nodeID
class RequestMessage(Message):
""" Message containing an RPC request """
def __init__(self, nodeID, method, methodArgs, rpcID=None):
if rpcID is None:
rpcID = generate_id()[:constants.rpc_id_length]
super().__init__(rpcID, nodeID)
self.request = method
self.args = methodArgs
class ResponseMessage(Message):
""" Message containing the result from a successful RPC request """
def __init__(self, rpcID, nodeID, response):
super().__init__(rpcID, nodeID)
self.response = response
class ErrorMessage(ResponseMessage):
""" Message containing the error from an unsuccessful RPC request """
def __init__(self, rpcID, nodeID, exceptionType, errorMessage):
super().__init__(rpcID, nodeID, errorMessage)
if isinstance(exceptionType, type):
exceptionType = (f'{exceptionType.__module__}.{exceptionType.__name__}').encode()
self.exceptionType = exceptionType

@ -1,649 +1,226 @@
import binascii
import hashlib
import logging
from functools import reduce
import asyncio
import typing
import socket
import binascii
import contextlib
from lbrynet.dht import constants
from lbrynet.dht.error import RemoteException
from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction
from lbrynet.dht.protocol.distance import Distance
from lbrynet.dht.protocol.iterative_find import IterativeNodeFinder, IterativeValueFinder
from lbrynet.dht.protocol.protocol import KademliaProtocol
from lbrynet.dht.peer import KademliaPeer
from twisted.internet import defer, error, task
from lbrynet.utils import generate_id, DeferredDict
from lbrynet.dht.call_later_manager import CallLaterManager
from lbrynet.dht.error import TimeoutError
from lbrynet.dht import constants, routingtable, datastore, protocol
from lbrynet.dht.contact import ContactManager
from lbrynet.dht.iterativefind import iterativeFind
if typing.TYPE_CHECKING:
from lbrynet.dht.peer import PeerManager
log = logging.getLogger(__name__)
def rpcmethod(func):
""" Decorator to expose Node methods as remote procedure calls
class Node:
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', node_id: bytes, udp_port: int,
internal_udp_port: int, peer_port: int, external_ip: str):
self.loop = loop
self.internal_udp_port = internal_udp_port
self.protocol = KademliaProtocol(loop, peer_manager, node_id, external_ip, udp_port, peer_port)
self.listening_port: asyncio.DatagramTransport = None
self.joined = asyncio.Event(loop=self.loop)
self._join_task: asyncio.Task = None
self._refresh_task: asyncio.Task = None
Apply this decorator to methods in the Node class (or a subclass) in order
to make them remotely callable via the DHT's RPC mechanism.
"""
func.rpcmethod = True
return func
async def refresh_node(self):
while True:
# remove peers with expired blob announcements from the datastore
self.protocol.data_store.removed_expired_peers()
total_peers: typing.List['KademliaPeer'] = []
# add all peers in the routing table
total_peers.extend(self.protocol.routing_table.get_peers())
# add all the peers who have announced blobs to us
total_peers.extend(self.protocol.data_store.get_storing_contacts())
class MockKademliaHelper:
def __init__(self, clock=None, callLater=None, resolve=None, listenUDP=None):
if not listenUDP or not resolve or not callLater or not clock:
from twisted.internet import reactor
listenUDP = listenUDP or reactor.listenUDP
resolve = resolve or reactor.resolve
callLater = callLater or reactor.callLater
clock = clock or reactor
# get ids falling in the midpoint of each bucket that hasn't been recently updated
node_ids = self.protocol.routing_table.get_refresh_list(0, True)
# if we have 3 or fewer populated buckets get two random ids in the range of each to try and
# populate/split the buckets further
buckets_with_contacts = self.protocol.routing_table.buckets_with_contacts()
if buckets_with_contacts <= 3:
for i in range(buckets_with_contacts):
node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i))
node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i))
self.clock = clock
self.contact_manager = ContactManager(self.clock.seconds)
self.reactor_listenUDP = listenUDP
self.reactor_resolve = resolve
self.call_later_manager = CallLaterManager(callLater)
self.reactor_callLater = self.call_later_manager.call_later
self.reactor_callSoon = self.call_later_manager.call_soon
self._listeningPort = None # object implementing Twisted
# IListeningPort This will contain a deferred created when
# joining the network, to enable publishing/retrieving
# information from the DHT as soon as the node is part of the
# network (add callbacks to this deferred if scheduling such
# operations before the node has finished joining the network)
def get_looping_call(self, fn, *args, **kwargs):
lc = task.LoopingCall(fn, *args, **kwargs)
lc.clock = self.clock
return lc
def safe_stop_looping_call(self, lc):
if lc and lc.running:
return lc.stop()
return defer.succeed(None)
def safe_start_looping_call(self, lc, t):
if lc and not lc.running:
lc.start(t)
class Node(MockKademliaHelper):
""" Local node in the Kademlia network
This class represents a single local node in a Kademlia network; in other
words, this class encapsulates an Entangled-using application's "presence"
in a Kademlia network.
In Entangled, all interactions with the Kademlia network by a client
application are performed via this class (or a subclass).
"""
def __init__(self, node_id=None, udpPort=4000, dataStore=None,
routingTableClass=None, networkProtocol=None,
externalIP=None, peerPort=3333, listenUDP=None,
callLater=None, resolve=None, clock=None,
interface='', externalUDPPort=None):
"""
@param dataStore: The data store to use. This must be a class inheriting
from the C{DataStore} interface (or providing the
same API). How the data store manages its data
internally is up to the implementation of that data
store.
@type dataStore: entangled.kademlia.datastore.DataStore
@param routingTableClass: The routing table class to use. Since there exists
some ambiguity as to how the routing table should be
implemented in Kademlia, a different routing table
may be used, as long as the appropriate API is
exposed. This should be a class, not an object,
in order to allow the Node to pass an
auto-generated node ID to the routingtable object
upon instantiation (if necessary).
@type routingTableClass: entangled.kademlia.routingtable.RoutingTable
@param networkProtocol: The network protocol to use. This can be
overridden from the default to (for example)
change the format of the physical RPC messages
being transmitted.
@type networkProtocol: entangled.kademlia.protocol.KademliaProtocol
@param externalIP: the IP at which this node can be contacted
@param peerPort: the port this node announces when advertising that it has a blob
"""
super().__init__(clock, callLater, resolve, listenUDP)
self.node_id = node_id or self._generateID()
self.port = udpPort
self._listen_interface = interface
self._change_token_lc = self.get_looping_call(self.change_token)
self._refresh_node_lc = self.get_looping_call(self._refreshNode)
self._refresh_contacts_lc = self.get_looping_call(self._refreshContacts)
# Create k-buckets (for storing contacts)
if routingTableClass is None:
self._routingTable = routingtable.TreeRoutingTable(self.node_id, self.clock.seconds)
else:
self._routingTable = routingTableClass(self.node_id, self.clock.seconds)
self._protocol = networkProtocol or protocol.KademliaProtocol(self)
self.token_secret = self._generateID()
self.old_token_secret = None
self.externalIP = externalIP
self.peerPort = peerPort
self.externalUDPPort = externalUDPPort or self.port
self._dataStore = dataStore or datastore.DictDataStore(self.clock.seconds)
self._join_deferred = None
#def __del__(self):
# log.warning("unclean shutdown of the dht node")
# if hasattr(self, "_listeningPort") and self._listeningPort is not None:
# self._listeningPort.stopListening()
def __str__(self):
return '<%s.%s object; ID: %s, IP address: %s, UDP port: %d>' % (
self.__module__, self.__class__.__name__, binascii.hexlify(self.node_id), self.externalIP, self.port)
@defer.inlineCallbacks
def stop(self):
# stop LoopingCalls:
yield self.safe_stop_looping_call(self._refresh_node_lc)
yield self.safe_stop_looping_call(self._change_token_lc)
yield self.safe_stop_looping_call(self._refresh_contacts_lc)
if self._listeningPort is not None:
yield self._listeningPort.stopListening()
self._listeningPort = None
def start_listening(self):
if not self._listeningPort:
try:
self._listeningPort = self.reactor_listenUDP(self.port, self._protocol,
interface=self._listen_interface)
except error.CannotListenError as e:
import traceback
log.error("Couldn't bind to port %d. %s", self.port, traceback.format_exc())
raise ValueError("%s lbrynet may already be running." % str(e))
else:
log.warning("Already bound to port %s", self._listeningPort)
@defer.inlineCallbacks
def joinNetwork(self, known_node_addresses=(('jack.lbry.tech', 4455), )):
"""
Attempt to join the dht, retry every 30 seconds if unsuccessful
:param known_node_addresses: [(str, int)] list of hostnames and ports for known dht seed nodes
"""
self._join_deferred = defer.Deferred()
known_node_resolution = {}
@defer.inlineCallbacks
def _resolve_seeds():
result = {}
for host, port in known_node_addresses:
node_address = yield self.reactor_resolve(host)
result[(host, port)] = node_address
defer.returnValue(result)
if not known_node_resolution:
known_node_resolution = yield _resolve_seeds()
# we are one of the seed nodes, don't add ourselves
if (self.externalIP, self.port) in known_node_resolution.values():
del known_node_resolution[(self.externalIP, self.port)]
known_node_addresses.remove((self.externalIP, self.port))
def _ping_contacts(contacts):
d = DeferredDict({contact: contact.ping() for contact in contacts}, consumeErrors=True)
d.addErrback(lambda err: err.trap(TimeoutError))
return d
@defer.inlineCallbacks
def _initialize_routing():
bootstrap_contacts = []
contact_addresses = {(c.address, c.port): c for c in self.contacts}
for (host, port), ip_address in known_node_resolution.items():
if (host, port) not in contact_addresses:
# Create temporary contact information for the list of addresses of known nodes
# The contact's node id is initialized to None and is set from the responding node's id
contact = self.contact_manager.make_contact(None, ip_address, port, self._protocol)
bootstrap_contacts.append(contact)
else:
for contact in self.contacts:
if contact.address == ip_address and contact.port == port:
if not contact.id:
bootstrap_contacts.append(contact)
break
if not bootstrap_contacts:
log.warning("no bootstrap contacts to ping")
ping_result = yield _ping_contacts(bootstrap_contacts)
shortlist = list(ping_result.keys())
if not shortlist:
log.warning("failed to ping %i bootstrap contacts", len(bootstrap_contacts))
defer.returnValue(None)
if self.protocol.routing_table.get_peers():
# if we have node ids to look up, perform the iterative search until we have k results
while node_ids:
peers = await self.peer_search(node_ids.pop())
total_peers.extend(peers)
else:
# find the closest peers to us
closest = yield self._iterativeFind(self.node_id, shortlist if not self.contacts else None)
yield _ping_contacts(closest)
# query random hashes in our bucket key ranges to fill or split them
random_ids_in_range = self._routingTable.getRefreshList()
while random_ids_in_range:
yield self.iterativeFindNode(random_ids_in_range.pop())
defer.returnValue(None)
fut = asyncio.Future(loop=self.loop)
self.loop.call_later(constants.refresh_interval // 4, fut.set_result, None)
await fut
continue
@defer.inlineCallbacks
def _iterative_join(joined_d=None, last_buckets_with_contacts=None):
log.info("Attempting to join the DHT network, %i contacts known so far", len(self.contacts))
joined_d = joined_d or defer.Deferred()
yield _initialize_routing()
buckets_with_contacts = self.bucketsWithContacts()
if last_buckets_with_contacts and last_buckets_with_contacts == buckets_with_contacts:
if not joined_d.called:
joined_d.callback(True)
elif buckets_with_contacts < 4:
self.reactor_callLater(0, _iterative_join, joined_d, buckets_with_contacts)
elif not joined_d.called:
joined_d.callback(None)
yield joined_d
if not self._join_deferred.called:
self._join_deferred.callback(True)
defer.returnValue(None)
# ping the set of peers; upon success/failure the routing table and last replied/failed time will be updated
to_ping = [peer for peer in set(total_peers) if self.protocol.peer_manager.peer_is_good(peer) is not True]
if to_ping:
await self.protocol.ping_queue.enqueue_maybe_ping(*to_ping, delay=0)
yield _iterative_join()
fut = asyncio.Future(loop=self.loop)
self.loop.call_later(constants.refresh_interval, fut.set_result, None)
await fut
@defer.inlineCallbacks
def start(self, known_node_addresses=None, block_on_join=False):
""" Causes the Node to attempt to join the DHT network by contacting the
known DHT nodes. This can be called multiple times if the previous attempt
has failed or if the Node has lost all the contacts.
async def announce_blob(self, blob_hash: str) -> typing.List[bytes]:
announced_to_node_ids = []
while not announced_to_node_ids:
hash_value = binascii.unhexlify(blob_hash.encode())
assert len(hash_value) == constants.hash_length
peers = await self.peer_search(hash_value)
@param known_node_addresses: A sequence of tuples containing IP address
information for existing nodes on the
Kademlia network, in the format:
C{(<ip address>, <udp port>)}
@type known_node_addresses: list
"""
if not self.protocol.external_ip:
raise Exception("Cannot determine external IP")
log.info("Store to %i peers", len(peers))
log.info(peers)
for peer in peers:
log.info("store to %s %s %s", peer.address, peer.udp_port, peer.tcp_port)
stored_to_tup = await asyncio.gather(
*(self.protocol.store_to_peer(hash_value, peer) for peer in peers), loop=self.loop
)
announced_to_node_ids.extend([node_id for node_id, contacted in stored_to_tup if contacted])
log.info("Stored %s to %i of %i attempted peers", binascii.hexlify(hash_value).decode()[:8],
len(announced_to_node_ids), len(peers))
return announced_to_node_ids
self.start_listening()
yield self._protocol._listening
# TODO: Refresh all k-buckets further away than this node's closest neighbour
d = self.joinNetwork(known_node_addresses or [])
d.addCallback(lambda _: self.start_looping_calls())
d.addCallback(lambda _: log.info("Joined the dht"))
if block_on_join:
yield d
def stop(self) -> None:
if self.joined.is_set():
self.joined.clear()
if self._join_task:
self._join_task.cancel()
if self._refresh_task and not (self._refresh_task.done() or self._refresh_task.cancelled()):
self._refresh_task.cancel()
if self.protocol and self.protocol.ping_queue.running:
self.protocol.ping_queue.stop()
if self.listening_port is not None:
self.listening_port.close()
self._join_task = None
self.listening_port = None
log.info("Stopped DHT node")
def start_looping_calls(self):
self.safe_start_looping_call(self._change_token_lc, constants.tokenSecretChangeInterval)
# Start refreshing k-buckets periodically, if necessary
self.safe_start_looping_call(self._refresh_node_lc, constants.checkRefreshInterval)
self.safe_start_looping_call(self._refresh_contacts_lc, 60)
async def start_listening(self, interface: str = '') -> None:
if not self.listening_port:
self.listening_port, _ = await self.loop.create_datagram_endpoint(
lambda: self.protocol, (interface, self.internal_udp_port)
)
log.info("DHT node listening on UDP %s:%i", interface, self.internal_udp_port)
else:
log.warning("Already bound to port %s", self.listening_port)
@property
def contacts(self):
def _inner():
for i in range(len(self._routingTable._buckets)):
for contact in self._routingTable._buckets[i]._contacts:
yield contact
return list(_inner())
async def join_network(self, interface: typing.Optional[str] = '',
known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
known_node_addresses: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
if not self.listening_port:
await self.start_listening(interface)
self.protocol.ping_queue.start()
self._refresh_task = self.loop.create_task(self.refresh_node())
def hasContacts(self):
for bucket in self._routingTable._buckets:
if bucket._contacts:
return True
return False
known_node_addresses = known_node_addresses or []
if known_node_urls:
for host, port in known_node_urls:
info = await self.loop.getaddrinfo(
host, 'https',
proto=socket.IPPROTO_TCP,
)
if (info[0][4][0], port) not in known_node_addresses:
known_node_addresses.append((info[0][4][0], port))
futs = []
for address, port in known_node_addresses:
peer = self.protocol.get_rpc_peer(KademliaPeer(self.loop, address, udp_port=port))
futs.append(peer.ping())
if futs:
await asyncio.wait(futs, loop=self.loop)
def bucketsWithContacts(self):
return self._routingTable.bucketsWithContacts()
async with self.peer_search_junction(self.protocol.node_id, max_results=16) as junction:
async for peers in junction:
for peer in peers:
try:
await self.protocol.get_rpc_peer(peer).ping()
except (asyncio.TimeoutError, RemoteException):
pass
self.joined.set()
log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
self.protocol.routing_table.buckets_with_contacts())
@defer.inlineCallbacks
def storeToContact(self, blob_hash, contact):
def start(self, interface: str, known_node_urls: typing.List[typing.Tuple[str, int]]):
self._join_task = self.loop.create_task(
self.join_network(
interface=interface, known_node_urls=known_node_urls
)
)
def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List] = None,
bottom_out_limit: int = constants.bottom_out_limit,
max_results: int = constants.k) -> IterativeNodeFinder:
return IterativeNodeFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
key, bottom_out_limit, max_results, None, shortlist)
def get_iterative_value_finder(self, key: bytes, shortlist: typing.Optional[typing.List] = None,
bottom_out_limit: int = 40,
max_results: int = -1) -> IterativeValueFinder:
return IterativeValueFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
key, bottom_out_limit, max_results, None, shortlist)
@contextlib.asynccontextmanager
async def stream_peer_search_junction(self, hash_queue: asyncio.Queue, bottom_out_limit=20,
max_results=-1) -> AsyncGeneratorJunction:
peer_generator = AsyncGeneratorJunction(self.loop)
async def _add_hashes_from_queue():
while True:
try:
blob_hash = await hash_queue.get()
except asyncio.CancelledError:
break
peer_generator.add_generator(
self.get_iterative_value_finder(
binascii.unhexlify(blob_hash.encode()), bottom_out_limit=bottom_out_limit,
max_results=max_results
)
)
add_hashes_task = self.loop.create_task(_add_hashes_from_queue())
try:
if not contact.token:
yield contact.findValue(blob_hash)
res = yield contact.store(blob_hash, contact.token, self.peerPort, self.node_id, 0)
if res != b"OK":
raise ValueError(res)
log.debug("Stored %s to %s (%s)", binascii.hexlify(blob_hash), contact.log_id(), contact.address)
return True
except protocol.TimeoutError:
log.debug("Timeout while storing blob_hash %s at %s",
binascii.hexlify(blob_hash), contact.log_id())
except ValueError as err:
log.error("Unexpected response: %s" % err)
except Exception as err:
if 'Invalid token' in str(err):
contact.update_token(None)
log.error("Unexpected error while storing blob_hash %s at %s: %s",
binascii.hexlify(blob_hash), contact, err)
return False
async with peer_generator as junction:
yield junction
await peer_generator.finished.wait()
except asyncio.CancelledError:
if add_hashes_task and not (add_hashes_task.done() or add_hashes_task.cancelled()):
add_hashes_task.cancel()
raise
finally:
if add_hashes_task and not (add_hashes_task.done() or add_hashes_task.cancelled()):
add_hashes_task.cancel()
@defer.inlineCallbacks
def announceHaveBlob(self, blob_hash):
contacts = yield self.iterativeFindNode(blob_hash)
def peer_search_junction(self, node_id: bytes, max_results=constants.k*2,
bottom_out_limit=20) -> AsyncGeneratorJunction:
peer_generator = AsyncGeneratorJunction(self.loop)
peer_generator.add_generator(
self.get_iterative_node_finder(
node_id, bottom_out_limit=bottom_out_limit, max_results=max_results
)
)
return peer_generator
if not self.externalIP:
raise Exception("Cannot determine external IP: %s" % self.externalIP)
stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts})
contacted_node_ids = [binascii.hexlify(contact.id) for contact in stored_to.keys() if stored_to[contact]]
log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash),
len(contacted_node_ids), len(contacts))
defer.returnValue(contacted_node_ids)
def change_token(self):
self.old_token_secret = self.token_secret
self.token_secret = self._generateID()
def make_token(self, compact_ip):
h = hashlib.new('sha384')
h.update(self.token_secret + compact_ip)
return h.digest()
def verify_token(self, token, compact_ip):
h = hashlib.new('sha384')
h.update(self.token_secret + compact_ip)
if self.old_token_secret and not token == h.digest(): # TODO: why should we be accepting the previous token?
h = hashlib.new('sha384')
h.update(self.old_token_secret + compact_ip)
if not token == h.digest():
return False
return True
def iterativeFindNode(self, key):
""" The basic Kademlia node lookup operation
Call this to find a remote node in the P2P overlay network.
@param key: the n-bit key (i.e. the node or value ID) to search for
@type key: str
@return: This immediately returns a deferred object, which will return
a list of k "closest" contacts (C{kademlia.contact.Contact}
objects) to the specified key as soon as the operation is
finished.
@rtype: twisted.internet.defer.Deferred
"""
return self._iterativeFind(key)
@defer.inlineCallbacks
def iterativeFindValue(self, key, exclude=None):
""" The Kademlia search operation (deterministic)
Call this to retrieve data from the DHT.
@param key: the n-bit key (i.e. the value ID) to search for
@type key: str
@return: This immediately returns a deferred object, which will return
either one of two things:
- If the value was found, it will return a Python
dictionary containing the searched-for key (the C{key}
parameter passed to this method), and its associated
value, in the format:
C{<str>key: <str>data_value}
- If the value was not found, it will return a list of k
"closest" contacts (C{kademlia.contact.Contact} objects)
to the specified key
@rtype: twisted.internet.defer.Deferred
"""
if len(key) != constants.key_bits // 8:
raise ValueError("invalid key length!")
# Execute the search
find_result = yield self._iterativeFind(key, rpc='findValue', exclude=exclude)
if isinstance(find_result, dict):
# We have found the value; now see who was the closest contact without it...
# ...and store the key/value pair
pass
else:
# The value wasn't found, but a list of contacts was returned
# Now, see if we have the value (it might seem wasteful to search on the network
# first, but it ensures that all values are properly propagated through the
# network
if self._dataStore.hasPeersForBlob(key):
# Ok, we have the value locally, so use that
# Send this value to the closest node without it
peers = self._dataStore.getPeersForBlob(key)
find_result = {key: peers}
else:
pass
defer.returnValue(list(set(find_result.get(key, []) if find_result else [])))
# TODO: get this working
# if 'closestNodeNoValue' in find_result:
# closest_node_without_value = find_result['closestNodeNoValue']
# try:
# response, address = yield closest_node_without_value.findValue(key, rawResponse=True)
# yield closest_node_without_value.store(key, response.response['token'], self.peerPort)
# except TimeoutError:
# pass
def addContact(self, contact):
""" Add/update the given contact; simple wrapper for the same method
in this object's RoutingTable object
@param contact: The contact to add to this node's k-buckets
@type contact: kademlia.contact.Contact
"""
return self._routingTable.addContact(contact)
def removeContact(self, contact):
""" Remove the contact with the specified node ID from this node's
table of known nodes. This is a simple wrapper for the same method
in this object's RoutingTable object
@param contact: The Contact object to remove
@type contact: _Contact
"""
self._routingTable.removeContact(contact)
def findContact(self, contactID):
""" Find an entangled.kademlia.contact.Contact object for the specified
contact ID
@param contactID: The contact ID of the required Contact object
@type contactID: str
@return: Contact object of remote node with the specified node ID,
or None if the contact was not found
@rtype: twisted.internet.defer.Deferred
"""
try:
df = defer.succeed(self._routingTable.getContact(contactID))
except (ValueError, IndexError):
df = self.iterativeFindNode(contactID)
df.addCallback(lambda nodes: ([node for node in nodes if node.id == contactID] or (None,))[0])
return df
@rpcmethod
def ping(self):
""" Used to verify contact between two Kademlia nodes
@rtype: str
"""
return b'pong'
@rpcmethod
def store(self, rpc_contact, blob_hash, token, port, originalPublisherID, age):
""" Store the received data in this node's local datastore
@param blob_hash: The hash of the data
@type blob_hash: str
@param token: The token we previously returned when this contact sent us a findValue
@type token: str
@param port: The TCP port the contact is listening on for requests for this blob (the peerPort)
@type port: int
@param originalPublisherID: The node ID of the node that is the publisher of the data
@type originalPublisherID: str
@param age: The relative age of the data (time in seconds since it was
originally published). Note that the original publish time
isn't actually given, to compensate for clock skew between
different nodes.
@type age: int
@rtype: str
"""
if originalPublisherID is None:
originalPublisherID = rpc_contact.id
compact_ip = rpc_contact.compact_ip()
if self.clock.seconds() - self._protocol.started_listening_time < constants.tokenSecretChangeInterval:
pass
elif not self.verify_token(token, compact_ip):
raise ValueError("Invalid token")
if 0 <= port <= 65535:  # 65536 does not fit in the 2 bytes written below
compact_port = port.to_bytes(2, 'big')
else:
raise TypeError(f'Invalid port: {port}')
compact_address = compact_ip + compact_port + rpc_contact.id
now = int(self.clock.seconds())
originallyPublished = now - age
self._dataStore.addPeerToBlob(rpc_contact, blob_hash, compact_address, now, originallyPublished,
originalPublisherID)
return b'OK'
@rpcmethod
def findNode(self, rpc_contact, key):
""" Finds a number of known nodes closest to the node/value with the
specified key.
@param key: the n-bit key (i.e. the node or value ID) to search for
@type key: str
@return: A list of contact triples closest to the specified key.
This method will return C{k} (or C{count}, if specified)
contacts if at all possible; it will only return fewer if the
node is returning all of the contacts that it knows of.
@rtype: list
"""
if len(key) != constants.key_bits // 8:
raise ValueError("invalid contact id length: %i" % len(key))
contacts = self._routingTable.findCloseNodes(key, sender_node_id=rpc_contact.id)
contact_triples = []
for contact in contacts:
contact_triples.append((contact.id, contact.address, contact.port))
return contact_triples
@rpcmethod
def findValue(self, rpc_contact, key):
""" Return the value associated with the specified key if present in
this node's data, otherwise execute FIND_NODE for the key
@param key: The hashtable key of the data to return
@type key: str
@return: A dictionary containing the requested key/value pair,
or a list of contact triples closest to the requested key.
@rtype: dict or list
"""
if len(key) != constants.key_bits // 8:
raise ValueError("invalid blob hash length: %i" % len(key))
response = {
b'token': self.make_token(rpc_contact.compact_ip()),
}
if self._protocol._protocolVersion:
response[b'protocolVersion'] = self._protocol._protocolVersion
# get peers we have stored for this blob
has_other_peers = self._dataStore.hasPeersForBlob(key)
peers = []
if has_other_peers:
peers.extend(self._dataStore.getPeersForBlob(key))
# if we don't have k storing peers to return and we have this hash locally, include our contact information
if len(peers) < constants.k and key in self._dataStore.completed_blobs:
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), self.externalIP.split('.'), bytearray())
compact_port = self.peerPort.to_bytes(2, 'big')
compact_address = compact_ip + compact_port + self.node_id
peers.append(compact_address)
if peers:
response[key] = peers
else:
response[b'contacts'] = self.findNode(rpc_contact, key)
return response
def _generateID(self):
""" Generates an n-bit pseudo-random identifier
@return: A globally unique n-bit pseudo-random identifier
@rtype: str
"""
return generate_id()
# from lbrynet.p2p.utils import profile_deferred
# @profile_deferred()
@defer.inlineCallbacks
def _iterativeFind(self, key, startupShortlist=None, rpc='findNode', exclude=None):
""" The basic Kademlia iterative lookup operation (for nodes/values)
This builds a list of k "closest" contacts through iterative use of
the "FIND_NODE" RPC, or if C{findValue} is set to C{True}, using the
"FIND_VALUE" RPC, in which case the value (if found) may be returned
instead of a list of contacts
@param key: the n-bit key (i.e. the node or value ID) to search for
@type key: str
@param startupShortlist: A list of contacts to use as the starting
shortlist for this search; this is normally
only used when the node joins the network
@type startupShortlist: list
@param rpc: The name of the RPC to issue to remote nodes during the
Kademlia lookup operation (e.g. this sets whether this
algorithm should search for a data value (if
rpc='findValue') or not. It can thus be used to perform
other operations that piggy-back on the basic Kademlia
lookup operation (Entangled's "delete" RPC, for instance).
@type rpc: str
@return: If C{findValue} is C{True}, the algorithm will stop as soon
as a data value for C{key} is found, and return a dictionary
containing the key and the found value. Otherwise, it will
return a list of the k closest nodes to the specified key
@rtype: twisted.internet.defer.Deferred
"""
if len(key) != constants.key_bits // 8:
raise ValueError("invalid key length: %i" % len(key))
if startupShortlist is None:
shortlist = self._routingTable.findCloseNodes(key)
# if key != self.node_id:
# # Update the "last accessed" timestamp for the appropriate k-bucket
# self._routingTable.touchKBucket(key)
if len(shortlist) == 0:
log.warning("This node doesn't know any other nodes")
# This node doesn't know of any other nodes
fakeDf = defer.Deferred()
fakeDf.callback([])
result = yield fakeDf
defer.returnValue(result)
else:
# This is used during the bootstrap process
shortlist = startupShortlist
result = yield iterativeFind(self, shortlist, key, rpc, exclude=exclude)
defer.returnValue(result)
@defer.inlineCallbacks
def _refreshNode(self):
""" Periodically called to perform k-bucket refreshes and data
replication/republishing as necessary """
yield self._refreshRoutingTable()
self._dataStore.removeExpiredPeers()
self._refreshStoringPeers()
defer.returnValue(None)
def _refreshContacts(self):
self._protocol._ping_queue.enqueue_maybe_ping(*self.contacts, delay=0)
def _refreshStoringPeers(self):
self._protocol._ping_queue.enqueue_maybe_ping(*self._dataStore.getStoringContacts(), delay=0)
@defer.inlineCallbacks
def _refreshRoutingTable(self):
nodeIDs = self._routingTable.getRefreshList(0, False)
while nodeIDs:
searchID = nodeIDs.pop()
yield self.iterativeFindNode(searchID)
defer.returnValue(None)
async def peer_search(self, node_id: bytes, count=constants.k, max_results=constants.k*2,
bottom_out_limit=20) -> typing.List['KademliaPeer']:
accumulated: typing.List['KademliaPeer'] = []
async with self.peer_search_junction(self.protocol.node_id, max_results=max_results,
bottom_out_limit=bottom_out_limit) as junction:
async for peers in junction:
log.info("peer search: %s", peers)
accumulated.extend(peers)
log.info("junction done")
log.info("context done")
distance = Distance(node_id)
accumulated.sort(key=lambda peer: distance(peer.node_id))
return accumulated[:count]
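
The last step of peer_search above sorts the accumulated peers by their distance to the target id and keeps the closest `count`. A minimal sketch of that step, assuming `Distance` is the usual Kademlia XOR metric over equal-length ids (its implementation is not shown in this diff); `xor_distance` and `closest_ids` are hypothetical names used only for illustration:

```python
import typing


def xor_distance(a: bytes, b: bytes) -> int:
    """Kademlia-style XOR metric between two equal-length ids, as an integer."""
    if len(a) != len(b):
        raise ValueError("ids must have the same length")
    return int.from_bytes(a, 'big') ^ int.from_bytes(b, 'big')


def closest_ids(target: bytes, candidates: typing.Iterable[bytes], count: int) -> typing.List[bytes]:
    # Sort ascending by distance to the target, then keep the first `count`,
    # mirroring accumulated.sort(...) followed by accumulated[:count].
    return sorted(candidates, key=lambda node_id: xor_distance(target, node_id))[:count]


if __name__ == "__main__":
    target = b'\x00\x05'
    print(closest_ids(target, [b'\x00\x04', b'\x00\x01', b'\x00\x07'], 2))
```

Sorting after the search has finished keeps the ordering independent of the order in which remote peers happened to reply.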

lbrynet/dht/peer.py Normal file

@ -0,0 +1,203 @@
import typing
import asyncio
import logging
import ipaddress
from binascii import hexlify
from lbrynet.dht import constants
from lbrynet.dht.serialization.datagram import make_compact_address, make_compact_ip, decode_compact_address
log = logging.getLogger(__name__)
def is_valid_ipv4(address):
try:
ip = ipaddress.ip_address(address)
return ip.version == 4
except ValueError:  # ip_address() raises a plain ValueError for malformed input
return False
class PeerManager:
def __init__(self, loop: asyncio.BaseEventLoop):
self._loop = loop
self._rpc_failures: typing.Dict[
typing.Tuple[str, int], typing.Tuple[typing.Optional[float], typing.Optional[float]]
] = {}
self._last_replied: typing.Dict[typing.Tuple[str, int], float] = {}
self._last_sent: typing.Dict[typing.Tuple[str, int], float] = {}
self._last_requested: typing.Dict[typing.Tuple[str, int], float] = {}
self._node_id_mapping: typing.Dict[typing.Tuple[str, int], bytes] = {}
self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
self._node_tokens: typing.Dict[bytes, typing.Tuple[float, bytes]] = {}
self._kademlia_peers: typing.Dict[typing.Tuple[bytes, str, int], 'KademliaPeer'] = {}
self._lock = asyncio.Lock(loop=loop)
async def report_failure(self, address: str, udp_port: int):
now = self._loop.time()
async with self._lock:
_, previous = self._rpc_failures.pop((address, udp_port), (None, None))
self._rpc_failures[(address, udp_port)] = (previous, now)
async def report_last_sent(self, address: str, udp_port: int):
now = self._loop.time()
async with self._lock:
self._last_sent[(address, udp_port)] = now
async def report_last_replied(self, address: str, udp_port: int):
now = self._loop.time()
async with self._lock:
self._last_replied[(address, udp_port)] = now
async def report_last_requested(self, address: str, udp_port: int):
now = self._loop.time()
async with self._lock:
self._last_requested[(address, udp_port)] = now
async def clear_token(self, node_id: bytes):
async with self._lock:
self._node_tokens.pop(node_id, None)
async def update_token(self, node_id: bytes, token: bytes):
now = self._loop.time()
async with self._lock:
self._node_tokens[node_id] = (now, token)
def get_node_token(self, node_id: bytes) -> typing.Optional[bytes]:
ts, token = self._node_tokens.get(node_id, (None, None))
if ts and ts > self._loop.time() - constants.token_secret_refresh_interval:
return token
def get_last_replied(self, address: str, udp_port: int) -> typing.Optional[float]:
return self._last_replied.get((address, udp_port))
def get_node_id(self, address: str, udp_port: int) -> typing.Optional[bytes]:
return self._node_id_mapping.get((address, udp_port))
def get_node_address(self, node_id: bytes) -> typing.Optional[typing.Tuple[str, int]]:
return self._node_id_reverse_mapping.get(node_id)
async def get_node_address_and_port(self, node_id: bytes) -> typing.Optional[typing.Tuple[str, int]]:
async with self._lock:
addr_tuple = self._node_id_reverse_mapping.get(node_id)
if addr_tuple and addr_tuple in self._node_id_mapping:
return addr_tuple
async def update_contact_triple(self, node_id: bytes, address: str, udp_port: int):
"""
Update the mapping of node_id -> address tuple and that of address tuple -> node_id
This is to handle peers changing addresses and ids while ensuring that we only ever have
one node id / address tuple mapped to each other
"""
async with self._lock:
if (address, udp_port) in self._node_id_mapping:
self._node_id_reverse_mapping.pop(self._node_id_mapping.pop((address, udp_port)))
if node_id in self._node_id_reverse_mapping:
self._node_id_mapping.pop(self._node_id_reverse_mapping.pop(node_id))
self._node_id_mapping[(address, udp_port)] = node_id
self._node_id_reverse_mapping[node_id] = (address, udp_port)
def get_kademlia_peer(self, node_id: bytes, address: str, udp_port: int) -> 'KademliaPeer':
return KademliaPeer(self._loop, address, node_id, udp_port)
async def prune(self):
now = self._loop.time()
async with self._lock:
to_pop = []
for (address, udp_port), (_, last_failure) in self._rpc_failures.items():
if last_failure and last_failure < now - constants.rpc_attempts_pruning_window:
to_pop.append((address, udp_port))
while to_pop:
del self._rpc_failures[to_pop.pop()]
to_pop = []
for node_id, (age, token) in self._node_tokens.items():
if age < now - constants.token_secret_refresh_interval:
to_pop.append(node_id)
while to_pop:
del self._node_tokens[to_pop.pop()]
def contact_triple_is_good(self, node_id: bytes, address: str, udp_port: int):
"""
:return: False if peer is bad, None if peer is unknown, or True if peer is good
"""
delay = self._loop.time() - constants.check_refresh_interval
if node_id not in self._node_id_reverse_mapping or (address, udp_port) not in self._node_id_mapping:
return
addr_tup = (address, udp_port)
if self._node_id_reverse_mapping[node_id] != addr_tup or self._node_id_mapping[addr_tup] != node_id:
return
previous_failure, most_recent_failure = self._rpc_failures.get((address, udp_port), (None, None))
last_requested = self._last_requested.get((address, udp_port))
last_replied = self._last_replied.get((address, udp_port))
if most_recent_failure and last_replied:
if delay < last_replied > most_recent_failure:
return True
elif last_replied > most_recent_failure:
return
return False
elif previous_failure and most_recent_failure and most_recent_failure > delay:
return False
elif last_replied and last_replied > delay:
return True
elif last_requested and last_requested > delay:
return None
return
def peer_is_good(self, peer: 'KademliaPeer'):
return self.contact_triple_is_good(peer.node_id, peer.address, peer.udp_port)
def decode_tcp_peer_from_compact_address(self, compact_address: bytes) -> 'KademliaPeer':
node_id, address, tcp_port = decode_compact_address(compact_address)
return KademliaPeer(self._loop, address, node_id, tcp_port=tcp_port)
class KademliaPeer:
def __init__(self, loop: asyncio.BaseEventLoop, address: str, node_id: typing.Optional[bytes] = None,
udp_port: typing.Optional[int] = None, tcp_port: typing.Optional[int] = None):
if node_id is not None:
if not len(node_id) == constants.hash_length:
raise ValueError("invalid node_id: {}".format(hexlify(node_id).decode()))
if udp_port is not None and not 0 <= udp_port <= 65535:
raise ValueError("invalid udp port")
if tcp_port and not 0 <= tcp_port <= 65535:
raise ValueError("invalid tcp port")
if not is_valid_ipv4(address):
raise ValueError("invalid ip address")
self.loop = loop
self._node_id = node_id
self.address = address
self.udp_port = udp_port
self.tcp_port = tcp_port
self.protocol_version = 1
def update_tcp_port(self, tcp_port: int):
self.tcp_port = tcp_port
def update_udp_port(self, udp_port: int):
self.udp_port = udp_port
def set_id(self, node_id):
if not self._node_id:
self._node_id = node_id
@property
def node_id(self) -> bytes:
return self._node_id
def compact_address_udp(self) -> bytearray:
return make_compact_address(self.node_id, self.address, self.udp_port)
def compact_address_tcp(self) -> bytearray:
return make_compact_address(self.node_id, self.address, self.tcp_port)
def compact_ip(self):
return make_compact_ip(self.address)
def __eq__(self, other):
if not isinstance(other, KademliaPeer):
raise TypeError("invalid type to compare with Peer: %s" % str(type(other)))
return (self.node_id, self.address, self.udp_port) == (other.node_id, other.address, other.udp_port)
def __hash__(self):
return hash((self.node_id, self.address, self.udp_port))
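
The compact addresses used throughout this file can be illustrated with a short sketch. It assumes the layout built by the removed `store` RPC earlier in this diff (4-byte IPv4 address, then a 2-byte big-endian port, then the node id) and a 48-byte node id (384 bits, an assumption since `constants.hash_length` is not shown here). `pack_compact_address` and `unpack_compact_address` are hypothetical stand-ins, not the `make_compact_address` / `decode_compact_address` helpers imported above:

```python
import ipaddress
import typing

NODE_ID_LENGTH = 48  # assumption: 384-bit node ids


def pack_compact_address(node_id: bytes, address: str, port: int) -> bytes:
    """Encode a peer as 4 bytes of IPv4 + 2 bytes of port + the node id."""
    if len(node_id) != NODE_ID_LENGTH:
        raise ValueError("invalid node id length")
    if not 0 <= port <= 65535:
        raise ValueError("invalid port")
    # IPv4Address raises a ValueError subclass for anything that is not IPv4.
    return ipaddress.IPv4Address(address).packed + port.to_bytes(2, 'big') + node_id


def unpack_compact_address(compact: bytes) -> typing.Tuple[bytes, str, int]:
    """Inverse of pack_compact_address; returns (node_id, address, port)."""
    if len(compact) != 6 + NODE_ID_LENGTH:
        raise ValueError("invalid compact address length")
    address = str(ipaddress.IPv4Address(compact[:4]))
    port = int.from_bytes(compact[4:6], 'big')
    return compact[6:], address, port


if __name__ == "__main__":
    packed = pack_compact_address(bytes(NODE_ID_LENGTH), "10.0.0.1", 3333)
    print(unpack_compact_address(packed))
```

A fixed-width encoding like this lets a receiver split a list of peers without any delimiters, which is why the port has to fit in exactly two bytes.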


@ -1,493 +0,0 @@
import logging
import errno
from binascii import hexlify
from twisted.internet import protocol, defer
from lbrynet.dht import constants, encoding, msgformat, msgtypes
from lbrynet.dht.error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError, TransportNotConnected
log = logging.getLogger(__name__)
class PingQueue:
"""
Schedules a 15 minute delayed ping after a new node sends us a query. This is so the new node gets added to the
routing table after having been given enough time for a pinhole to expire.
"""
def __init__(self, node):
self._node = node
self._enqueued_contacts = {}
self._pending_contacts = {}
self._process_lc = node.get_looping_call(self._process)
def enqueue_maybe_ping(self, *contacts, **kwargs):
delay = kwargs.get('delay', constants.checkRefreshInterval)
no_op = (defer.succeed(None), lambda: None)
for contact in contacts:
if delay and contact not in self._enqueued_contacts:
self._pending_contacts.setdefault(contact, self._node.clock.seconds() + delay)
else:
self._enqueued_contacts.setdefault(contact, no_op)
@defer.inlineCallbacks
def _ping(self, contact):
if contact.contact_is_good:
return
try:
yield contact.ping()
except TimeoutError:
pass
except Exception as err:
log.warning("unexpected error: %s", err)
finally:
if contact in self._enqueued_contacts:
del self._enqueued_contacts[contact]
def _process(self):
# move contacts that are scheduled to join the queue
if self._pending_contacts:
now = self._node.clock.seconds()
for contact in [contact for contact, schedule in self._pending_contacts.items() if schedule <= now]:
del self._pending_contacts[contact]
self._enqueued_contacts.setdefault(contact, (defer.succeed(None), lambda: None))
# spread pings across 60 seconds to avoid flood and/or false negatives
step = 60.0/float(len(self._enqueued_contacts)) if self._enqueued_contacts else 0
for index, (contact, (call, _)) in enumerate(self._enqueued_contacts.items()):
if call.called and not contact.contact_is_good:
self._enqueued_contacts[contact] = self._node.reactor_callLater(index*step, self._ping, contact)
def start(self):
return self._node.safe_start_looping_call(self._process_lc, 60)
def stop(self):
map(None, (cancel() for _, (call, cancel) in self._enqueued_contacts.items() if not call.called))
return self._node.safe_stop_looping_call(self._process_lc)
class KademliaProtocol(protocol.DatagramProtocol):
""" Implements all low-level network-related functions of a Kademlia node """
msgSizeLimit = constants.udpDatagramMaxSize - 26
def __init__(self, node):
self._node = node
self._translator = msgformat.DefaultFormat()
self._sentMessages = {}
self._partialMessages = {}
self._partialMessagesProgress = {}
self._listening = defer.Deferred(None)
self._ping_queue = PingQueue(self._node)
self._protocolVersion = constants.protocolVersion
self.started_listening_time = 0
def _migrate_incoming_rpc_args(self, contact, method, *args):
if method == b'store' and contact.protocolVersion == 0:
if isinstance(args[1], dict):
blob_hash = args[0]
token = args[1].pop(b'token', None)
port = args[1].pop(b'port', -1)
originalPublisherID = args[1].pop(b'lbryid', None)
age = 0
return (blob_hash, token, port, originalPublisherID, age), {}
return args, {}
def _migrate_outgoing_rpc_args(self, contact, method, *args):
"""
This will reformat protocol version 0 arguments for the store function and will add the
protocol version keyword argument to calls to contacts who will accept it
"""
if contact.protocolVersion == 0:
if method == b'store':
blob_hash, token, port, originalPublisherID, age = args
args = (
blob_hash, {
b'token': token,
b'port': port,
b'lbryid': originalPublisherID
}, originalPublisherID, False
)
return args
return args
if args and isinstance(args[-1], dict):
args[-1][b'protocolVersion'] = self._protocolVersion
return args
return args + ({b'protocolVersion': self._protocolVersion},)
@defer.inlineCallbacks
def sendRPC(self, contact, method, args):
for _ in range(constants.rpcAttempts):
try:
response = yield self._sendRPC(contact, method, args)
return response
except TimeoutError:
if contact.contact_is_good:
log.debug("RETRY %s ON %s", method, contact)
continue
else:
raise
def _sendRPC(self, contact, method, args):
"""
Sends an RPC to the specified contact
@param contact: The contact (remote node) to send the RPC to
@type contact: kademlia.contacts.Contact
@param method: The name of remote method to invoke
@type method: str
@param args: A list of (non-keyword) arguments to pass to the remote
method, in the correct order
@type args: tuple
@return: This immediately returns a deferred object, which will return
the result of the RPC call, or raise the relevant exception
if the remote node raised one. If C{rawResponse} is set to
C{True}, however, it will always return the actual response
message (which may be a C{ResponseMessage} or an
C{ErrorMessage}).
@rtype: twisted.internet.defer.Deferred
"""
msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method,
*args))
msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = encoding.bencode(msgPrimitive)
if args:
log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
hexlify(args[0]), contact.address, contact.port)
else:
log.debug("%s:%i SEND CALL %s TO %s:%i", self._node.externalIP, self._node.port, method,
contact.address, contact.port)
df = defer.Deferred()
def _remove_contact(failure): # remove the contact from the routing table and track the failure
contact.update_last_failed()
try:
if not contact.contact_is_good:
self._node.removeContact(contact)
except (ValueError, IndexError):
pass
return failure
def _update_contact(result): # refresh the contact in the routing table
contact.update_last_replied()
if method == b'findValue':
if b'token' in result:
contact.update_token(result[b'token'])
if b'protocolVersion' not in result:
contact.update_protocol_version(0)
else:
contact.update_protocol_version(result.pop(b'protocolVersion'))
d = self._node.addContact(contact)
d.addCallback(lambda _: result)
return d
df.addCallbacks(_update_contact, _remove_contact)
# Set the RPC timeout timer
timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
# Transmit the data
self._send(encodedMsg, msg.id, (contact.address, contact.port))
self._sentMessages[msg.id] = (contact, df, timeoutCall, cancelTimeout, method, args)
df.addErrback(cancelTimeout)
return df
def startProtocol(self):
log.info("DHT listening on UDP %i (ext port %i)", self._node.port, self._node.externalUDPPort)
if self._listening.called:
self._listening = defer.Deferred()
self._listening.callback(True)
self.started_listening_time = self._node.clock.seconds()
return self._ping_queue.start()
def datagramReceived(self, datagram, address):
""" Handles and parses incoming RPC messages (and responses)
@note: This is automatically called by Twisted when the protocol
receives a UDP datagram
"""
if chr(datagram[0]) == '\x00' and chr(datagram[25]) == '\x00':
totalPackets = (datagram[1] << 8) | datagram[2]
msgID = datagram[5:25]
seqNumber = (datagram[3] << 8) | datagram[4]
if msgID not in self._partialMessages:
self._partialMessages[msgID] = {}
self._partialMessages[msgID][seqNumber] = datagram[26:]
if len(self._partialMessages[msgID]) == totalPackets:
keys = self._partialMessages[msgID].keys()
keys.sort()
data = b''
for key in keys:
data += self._partialMessages[msgID][key]
datagram = data
del self._partialMessages[msgID]
else:
return
try:
msgPrimitive = encoding.bdecode(datagram)
message = self._translator.fromPrimitive(msgPrimitive)
except (encoding.DecodeError, ValueError) as err:
# We received some rubbish here
log.warning("Error decoding datagram %s from %s:%i - %s", hexlify(datagram),
address[0], address[1], err)
return
except (IndexError, KeyError):
log.warning("Couldn't decode dht datagram from %s", address)
return
if isinstance(message, msgtypes.RequestMessage):
# This is an RPC method request
remoteContact = self._node.contact_manager.make_contact(message.nodeID, address[0], address[1], self)
remoteContact.update_last_requested()
# only add a requesting contact to the routing table if it has replied to one of our requests
if remoteContact.contact_is_good is True:
df = self._node.addContact(remoteContact)
else:
df = defer.succeed(None)
df.addCallback(lambda _: self._handleRPC(remoteContact, message.id, message.request, message.args))
# if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
# will be added to our routing table if successful
if remoteContact.contact_is_good is None and remoteContact.lastReplied is None:
df.addCallback(lambda _: self._ping_queue.enqueue_maybe_ping(remoteContact))
elif isinstance(message, msgtypes.ErrorMessage):
# The RPC request raised a remote exception; raise it locally
if message.exceptionType in BUILTIN_EXCEPTIONS:
exception_type = BUILTIN_EXCEPTIONS[message.exceptionType]
else:
exception_type = UnknownRemoteException
remoteException = exception_type(message.response)
log.error("DHT RECV REMOTE EXCEPTION FROM %s:%i: %s", address[0],
address[1], remoteException)
if message.id in self._sentMessages:
# Cancel timeout timer for this RPC
remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[message.id][0:5]
timeoutCanceller()
del self._sentMessages[message.id]
# reject replies coming from a different address than what we sent our request to
if (remoteContact.address, remoteContact.port) != address:
log.warning("Sent request to node %s at %s:%i, got reply from %s:%i",
remoteContact.log_id(), remoteContact.address,
remoteContact.port, address[0], address[1])
df.errback(TimeoutError(remoteContact.id))
return
# this error is returned by nodes that can be contacted but have an old
# and broken version of the ping command, if they return it the node can
# be contacted, so we'll treat it as a successful ping
old_ping_error = "ping() got an unexpected keyword argument '_rpcNodeContact'"
if isinstance(remoteException, TypeError) and \
remoteException.message == old_ping_error:
log.debug("old pong error")
df.callback('pong')
else:
df.errback(remoteException)
elif isinstance(message, msgtypes.ResponseMessage):
# Find the message that triggered this response
if message.id in self._sentMessages:
# Cancel timeout timer for this RPC
remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[message.id][0:5]
timeoutCanceller()
del self._sentMessages[message.id]
log.debug("%s:%i RECV response to %s from %s:%i", self._node.externalIP, self._node.port,
method, remoteContact.address, remoteContact.port)
# When joining the network we made Contact objects for the seed nodes with node ids set to None
# Thus, the sent_to_id will also be None, and the contact objects need the ids to be manually set.
# These replies have to be distinguished from those where the node id in the datagram does not match
# the node id of the node we sent a message to (these messages are treated as an error)
if remoteContact.id and remoteContact.id != message.nodeID: # sent_to_id will be None for bootstrap
log.debug("mismatch: (%s) %s:%i (%s vs %s)", method, remoteContact.address, remoteContact.port,
remoteContact.log_id(False), hexlify(message.nodeID))
df.errback(TimeoutError(remoteContact.id))
return
elif not remoteContact.id:
remoteContact.set_id(message.nodeID)
# We got a result from the RPC
df.callback(message.response)
else:
# If the original message isn't found, it must have timed out
# TODO: we should probably do something with this...
pass
def _send(self, data, rpcID, address):
""" Transmit the specified data over UDP, breaking it up into several
packets if necessary
If the data is spread over multiple UDP datagrams, the packets have the
following structure::
| | | | | |||||||||||| 0x00 |
|Transmission|Total number|Sequence number| RPC ID |Header end|
| type ID | of packets |of this packet | | indicator|
| (1 byte) | (2 bytes) | (2 bytes) |(20 bytes)| (1 byte) |
| | | | | |||||||||||| |
@note: The header used for breaking up large data segments will
possibly be moved out of the KademliaProtocol class in the
future, into something similar to a message translator/encoder
class (see C{kademlia.msgformat} and C{kademlia.encoding}).
"""
if len(data) > self.msgSizeLimit:
# We have to spread the data over multiple UDP datagrams,
# and provide sequencing information
#
# 1st byte is transmission type id, bytes 2 & 3 are the
# total number of packets in this transmission, bytes 4 &
# 5 are the sequence number for this specific packet
totalPackets = len(data) // self.msgSizeLimit
if len(data) % self.msgSizeLimit > 0:
totalPackets += 1
encTotalPackets = chr(totalPackets >> 8) + chr(totalPackets & 0xff)
seqNumber = 0
startPos = 0
while seqNumber < totalPackets:
packetData = data[startPos:startPos + self.msgSizeLimit]
encSeqNumber = chr(seqNumber >> 8) + chr(seqNumber & 0xff)
txData = f'\x00{encTotalPackets}{encSeqNumber}{rpcID}\x00{packetData}'
self._scheduleSendNext(txData, address)
startPos += self.msgSizeLimit
seqNumber += 1
else:
self._scheduleSendNext(data, address)
def _scheduleSendNext(self, txData, address):
"""Schedule the sending of the next UDP packet """
delayed_call, _ = self._node.reactor_callSoon(self._write, txData, address)
def _write(self, txData, address):
if self.transport:
try:
self.transport.write(txData, address)
except OSError as err:
if err.errno == errno.EWOULDBLOCK:
# i'm scared this may swallow important errors, but i get a million of these
# on Linux and it doesn't seem to affect anything -grin
log.warning("Can't send data to dht: EWOULDBLOCK")
elif err.errno == errno.ENETUNREACH:
# this should probably try to retransmit when the network connection is back
log.error("Network is unreachable")
else:
log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)",
len(txData), address[0], address[1], err.strerror, err.errno)
raise err
else:
raise TransportNotConnected()
def _sendResponse(self, contact, rpcID, response):
""" Send an RPC response to the specified contact
"""
msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response)
msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = encoding.bencode(msgPrimitive)
self._send(encodedMsg, rpcID, (contact.address, contact.port))
def _sendError(self, contact, rpcID, exceptionType, exceptionMessage):
""" Send an RPC error message to the specified contact
"""
exceptionMessage = exceptionMessage.encode()
msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage)
msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = encoding.bencode(msgPrimitive)
self._send(encodedMsg, rpcID, (contact.address, contact.port))
def _handleRPC(self, senderContact, rpcID, method, args):
""" Executes a local function in response to an RPC request """
# Set up the deferred callchain
def handleError(f):
self._sendError(senderContact, rpcID, f.type, f.getErrorMessage())
def handleResult(result):
self._sendResponse(senderContact, rpcID, result)
df = defer.Deferred()
df.addCallback(handleResult)
df.addErrback(handleError)
# Execute the RPC
func = getattr(self._node, method.decode(), None)
if callable(func) and hasattr(func, "rpcmethod"):
# Call the exposed Node method and return the result to the deferred callback chain
# if args:
# log.debug("%s:%i RECV CALL %s(%s) %s:%i", self._node.externalIP, self._node.port, method,
# args[0].encode('hex'), senderContact.address, senderContact.port)
# else:
log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method,
senderContact.address, senderContact.port)
if args and isinstance(args[-1], dict) and b'protocolVersion' in args[-1]: # args don't need reformatting
senderContact.update_protocol_version(int(args[-1].pop(b'protocolVersion')))
a, kw = tuple(args[:-1]), args[-1]
else:
senderContact.update_protocol_version(0)
a, kw = self._migrate_incoming_rpc_args(senderContact, method, *args)
try:
if method != b'ping':
result = func(senderContact, *a)
else:
result = func()
except Exception as e:
log.error("error handling request for %s:%i %s", senderContact.address, senderContact.port, method)
df.errback(e)
else:
df.callback(result)
else:
# No such exposed method
df.errback(AttributeError('Invalid method: %s' % method))
return df
def _msgTimeout(self, messageID):
""" Called when an RPC request message times out """
# Find the message that timed out
if messageID not in self._sentMessages:
# This should never be reached
log.error("deferred timed out, but is not present in sent messages list!")
return
remoteContact, df, timeout_call, timeout_canceller, method, args = self._sentMessages[messageID]
if messageID in self._partialMessages:
# We are still receiving this message
self._msgTimeoutInProgress(messageID, timeout_canceller, remoteContact, df, method, args)
return
del self._sentMessages[messageID]
# The message's destination node is now considered to be dead;
# raise an (asynchronous) TimeoutError exception and update the host node
df.errback(TimeoutError(remoteContact.id))
def _msgTimeoutInProgress(self, messageID, timeoutCanceller, remoteContact, df, method, args):
# See if any progress has been made; if not, kill the message
if self._hasProgressBeenMade(messageID):
# Reset the RPC timeout timer
timeoutCanceller()
timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID)
self._sentMessages[messageID] = (remoteContact, df, timeoutCall, cancelTimeout, method, args)
else:
# No progress has been made
if messageID in self._partialMessagesProgress:
del self._partialMessagesProgress[messageID]
if messageID in self._partialMessages:
del self._partialMessages[messageID]
df.errback(TimeoutError(remoteContact.id))
def _hasProgressBeenMade(self, messageID):
return (
messageID in self._partialMessagesProgress and
(
len(self._partialMessagesProgress[messageID]) !=
len(self._partialMessages[messageID])
)
)
def stopProtocol(self):
""" Called when the transport is disconnected.
Will only be called once, after all ports are disconnected.
"""
log.info('Stopping DHT')
self._ping_queue.stop()
self._node.call_later_manager.stop()
log.info('DHT stopped')
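
The packet layout described in the `_send` docstring above can be sketched with `struct`. This is a minimal illustration, not the project's implementation: the function name `split_into_packets`, the `HEADER_FORMAT` constant and the `size_limit` parameter are hypothetical, and the RPC id is assumed to be a 20-byte `bytes` value.

```python
import struct

# type id (1 byte), total packets (2 bytes), sequence number (2 bytes),
# RPC id (20 bytes), header end indicator (1 byte); big-endian, no padding
HEADER_FORMAT = ">BHH20sB"


def split_into_packets(data: bytes, rpc_id: bytes, size_limit: int) -> list:
    """Split `data` into chunks of at most `size_limit` bytes, each prefixed with a header."""
    if len(rpc_id) != 20:
        raise ValueError("rpc_id must be 20 bytes")
    if len(data) <= size_limit:
        return [data]  # small payloads are sent as-is, without a header
    total_packets = (len(data) + size_limit - 1) // size_limit  # ceiling division
    packets = []
    for seq_number in range(total_packets):
        chunk = data[seq_number * size_limit:(seq_number + 1) * size_limit]
        header = struct.pack(HEADER_FORMAT, 0x00, total_packets, seq_number, rpc_id, 0x00)
        packets.append(header + chunk)
    return packets
```

Building the header from `bytes` keeps the payload binary-safe; each multi-packet datagram carries a fixed 26-byte header in front of its chunk.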

View file


@ -0,0 +1,105 @@
import asyncio
import typing
import logging
import traceback
if typing.TYPE_CHECKING:
from types import AsyncGeneratorType
log = logging.getLogger(__name__)
def cancel_task(task: typing.Optional[asyncio.Task]):
if task and not (task.done() or task.cancelled()):
task.cancel()
def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
for task in tasks:
cancel_task(task)
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
while tasks:
cancel_task(tasks.pop())
class AsyncGeneratorJunction:
"""
A helper to interleave the results from multiple async generators into one
async generator.
"""
def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None):
self.loop = loop
self.result_queue = queue or asyncio.Queue(loop=loop)
self.tasks: typing.List[asyncio.Task] = []
self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {}
self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop)
self.can_iterate = asyncio.Event(loop=self.loop)
self.finished = asyncio.Event(loop=self.loop)
@property
def running(self):
return any(self.running_iterators.values())
async def wait_for_generators(self):
async def iterate(iterator: typing.AsyncGenerator):
try:
async for item in iterator:
self.result_queue.put_nowait(item)
finally:
self.running_iterators[iterator] = False
while True:
async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get()
self.running_iterators[async_gen] = True
self.tasks.append(self.loop.create_task(iterate(async_gen)))
if not self.can_iterate.is_set():
self.can_iterate.set()
def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']):
"""
Add an async generator. This can be called during an iteration of the generator junction.
"""
self.generator_queue.put_nowait(async_gen)
def __aiter__(self):
return self
async def __anext__(self):
if not self.can_iterate.is_set():
await self.can_iterate.wait()
if not self.running:
raise StopAsyncIteration()
try:
return await self.result_queue.get()
finally:
self.awaiting = None
def aclose(self):
async def _aclose():
for iterator in list(self.running_iterators.keys()):
result = iterator.aclose()
if asyncio.iscoroutine(result):
await result
self.running_iterators[iterator] = False
drain_tasks(self.tasks)
raise StopAsyncIteration()
if not self.finished.is_set():
self.finished.set()
return self.loop.create_task(_aclose())
async def __aenter__(self):
self.tasks.append(self.loop.create_task(self.wait_for_generators()))
return self
async def __aexit__(self, exc_type, exc, tb):
try:
await self.aclose()
except StopAsyncIteration:
pass
finally:
if exc_type:
if exc_type not in (asyncio.CancelledError, asyncio.TimeoutError, StopAsyncIteration):
err = traceback.format_exception(exc_type, exc, tb)
log.error(err)
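
The idea behind `AsyncGeneratorJunction`, feeding several async generators into one shared queue and yielding from that queue, can be sketched independently of the class. The helper names below (`merge_async_generators`, `count`, `collect`) are hypothetical, and the sketch uses a sentinel per producer instead of the `running_iterators` bookkeeping in the listing.

```python
import asyncio


async def merge_async_generators(*generators):
    """Yield items from several async generators as soon as any of them produces one."""
    queue = asyncio.Queue()
    done_marker = object()  # sentinel a producer pushes when its generator is exhausted

    async def produce(generator):
        try:
            async for item in generator:
                queue.put_nowait(item)
        finally:
            queue.put_nowait(done_marker)

    tasks = [asyncio.ensure_future(produce(generator)) for generator in generators]
    remaining = len(tasks)
    try:
        while remaining:
            item = await queue.get()
            if item is done_marker:
                remaining -= 1
            else:
                yield item
    finally:
        # Cancel producers that are still running if the consumer stops early
        for task in tasks:
            task.cancel()


async def count(start, stop, delay):
    for value in range(start, stop):
        await asyncio.sleep(delay)
        yield value


async def collect():
    return [item async for item in merge_async_generators(count(0, 3, 0.01), count(10, 13, 0.01))]
```

Running `asyncio.run(collect())` returns the six values from both generators; the interleaving order depends on scheduling, so callers should not rely on it.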

View file

@ -0,0 +1,76 @@
import asyncio
import typing
from lbrynet.dht import constants
if typing.TYPE_CHECKING:
from lbrynet.dht.peer import KademliaPeer, PeerManager
class DictDataStore:
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager'):
# Dictionary format:
# { <key>: [(<contact>, <value>, <lastPublished>, <originallyPublished>, <original_publisher_id>), ...] }
self._data_store: typing.Dict[bytes,
typing.List[typing.Tuple['KademliaPeer', bytes, float, float, bytes]]] = {}
self._get_time = loop.time
self._peer_manager = peer_manager
self.completed_blobs: typing.Set[str] = set()
def filter_bad_and_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']:
"""
Returns only non-expired and unknown/good peers
"""
peers = []
for peer in map(lambda p: p[0],
filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration,
self._data_store[key])):
if self._peer_manager.peer_is_good(peer) is not False:
peers.append(peer)
return peers
def filter_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']:
"""
Returns only non-expired peers
"""
return list(
map(
lambda p: p[0],
filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration, self._data_store[key])
)
)
def removed_expired_peers(self):
expired_keys = []
for key in self._data_store.keys():
unexpired_peers = self.filter_expired_peers(key)
if not unexpired_peers:
expired_keys.append(key)
else:
self._data_store[key] = [x for x in self._data_store[key] if x[0] in unexpired_peers]
for key in expired_keys:
del self._data_store[key]
def has_peers_for_blob(self, key: bytes) -> bool:
return key in self._data_store and len(self.filter_bad_and_expired_peers(key)) > 0
def add_peer_to_blob(self, contact: 'KademliaPeer', key: bytes, compact_address: bytes, last_published: int,
originally_published: int, original_publisher_id: bytes) -> None:
if key in self._data_store:
if compact_address not in map(lambda store_tuple: store_tuple[1], self._data_store[key]):
self._data_store[key].append(
(contact, compact_address, last_published, originally_published, original_publisher_id)
)
else:
self._data_store[key] = [(contact, compact_address, last_published, originally_published,
original_publisher_id)]
def get_peers_for_blob(self, key: bytes) -> typing.List['KademliaPeer']:
return [] if key not in self._data_store else [peer for peer in self.filter_bad_and_expired_peers(key)]
def get_storing_contacts(self) -> typing.List['KademliaPeer']:
peers = set()
for key in self._data_store:
for values in self._data_store[key]:
if values[0] not in peers:
peers.add(values[0])
return list(peers)
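
The store-and-expire behaviour of `DictDataStore` can be illustrated with plain tuples. This is a simplified sketch: `add_entry` and `unexpired_peers` are hypothetical helpers, peers are represented by strings, and the `DATA_EXPIRATION` value is an assumption standing in for `constants.data_expiration`.

```python
DATA_EXPIRATION = 86400  # assumed expiration window in seconds (stand-in for constants.data_expiration)


def add_entry(store, key, entry):
    """Append `entry` under `key` unless an entry with the same compact address is already stored.

    `entry` is (peer, compact_address, last_published, originally_published, original_publisher_id).
    """
    entries = store.setdefault(key, [])
    if entry[1] not in (existing[1] for existing in entries):
        entries.append(entry)


def unexpired_peers(store, key, now, expiration=DATA_EXPIRATION):
    """Return the peers for `key` whose original publish time is within the expiration window."""
    return [entry[0] for entry in store.get(key, []) if now - entry[3] < expiration]
```

As in the listing, expiry is judged on the original publish time (index 3), not on the time of the most recent announcement.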

View file

@ -8,20 +8,16 @@ class Distance:
we pre-calculate the value of that point.
"""
def __init__(self, key):
if len(key) != constants.key_bits // 8:
def __init__(self, key: bytes):
if len(key) != constants.hash_length:
raise ValueError("invalid key length: %i" % len(key))
self.key = key
self.val_key_one = int.from_bytes(key, 'big')
def __call__(self, key_two):
def __call__(self, key_two: bytes) -> int:
val_key_two = int.from_bytes(key_two, 'big')
return self.val_key_one ^ val_key_two
def is_closer(self, a, b):
def is_closer(self, a: bytes, b: bytes) -> bool:
"""Returns true if `a` is closer to `key` than `b` is"""
return self(a) < self(b)
def to_contact(self, contact):
"""A convenience function for calculating the distance to a contact"""
return self(contact.id)
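
The XOR metric computed by `Distance` can be shown on short keys. In this sketch the two-byte keys are an illustrative assumption (the real class requires keys of `constants.hash_length` bytes), and `xor_distance` is a hypothetical free function rather than part of the module.

```python
def xor_distance(key_one: bytes, key_two: bytes) -> int:
    """Interpret both keys as big-endian integers and XOR them."""
    return int.from_bytes(key_one, 'big') ^ int.from_bytes(key_two, 'big')


target = bytes.fromhex('00ff')
candidates = [bytes.fromhex('0f00'), bytes.fromhex('00f0'), bytes.fromhex('80ff')]

# Sort candidates so that the key sharing the longest prefix with the target comes first
closest_first = sorted(candidates, key=lambda key: xor_distance(target, key))
```

Because XOR is symmetric and a key's distance to itself is zero, differences in high-order bits dominate the ordering.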

View file

@ -0,0 +1,381 @@
import asyncio
import typing
import logging
from lbrynet.utils import drain_tasks
from lbrynet.dht import constants
from lbrynet.dht.error import RemoteException
from lbrynet.dht.protocol.distance import Distance
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from lbrynet.dht.protocol.routing_table import TreeRoutingTable
from lbrynet.dht.protocol.protocol import KademliaProtocol
from lbrynet.dht.peer import PeerManager, KademliaPeer
log = logging.getLogger(__name__)
class FindResponse:
@property
def found(self) -> bool:
raise NotImplementedError()
def get_close_triples(self) -> typing.List[typing.Tuple[bytes, str, int]]:
raise NotImplementedError()
class FindNodeResponse(FindResponse):
def __init__(self, key: bytes, close_triples: typing.List[typing.Tuple[bytes, str, int]]):
self.key = key
self.close_triples = close_triples
@property
def found(self) -> bool:
return self.key in [triple[0] for triple in self.close_triples]
def get_close_triples(self) -> typing.List[typing.Tuple[bytes, str, int]]:
return self.close_triples
class FindValueResponse(FindResponse):
def __init__(self, key: bytes, result_dict: typing.Dict):
self.key = key
self.token = result_dict[b'token']
self.close_triples: typing.List[typing.Tuple[bytes, bytes, int]] = result_dict.get(b'contacts', [])
self.found_compact_addresses = result_dict.get(key, [])
@property
def found(self) -> bool:
return len(self.found_compact_addresses) > 0
def get_close_triples(self) -> typing.List[typing.Tuple[bytes, str, int]]:
return [(node_id, address.decode(), port) for node_id, address, port in self.close_triples]
def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
shortlist: typing.Optional[typing.List['KademliaPeer']]) -> typing.List['KademliaPeer']:
"""
If not provided, initialize the shortlist of peers to probe to the (up to) k closest peers in the routing table
:param routing_table: a TreeRoutingTable
:param key: a 48 byte hash
:param shortlist: optional manually provided shortlist, this is done during bootstrapping when there are no
peers in the routing table. During bootstrap the shortlist is set to be the seed nodes.
"""
if len(key) != constants.hash_length:
raise ValueError("invalid key length: %i" % len(key))
if not shortlist:
shortlist = routing_table.find_close_peers(key)
distance = Distance(key)
shortlist.sort(key=lambda peer: distance(peer.node_id), reverse=True)
return shortlist
class IterativeFinder:
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.k,
exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
shortlist: typing.Optional[typing.List['KademliaPeer']] = None):
if len(key) != constants.hash_length:
raise ValueError("invalid key length: %i" % len(key))
self.loop = loop
self.peer_manager = peer_manager
self.routing_table = routing_table
self.protocol = protocol
self.key = key
self.bottom_out_limit = bottom_out_limit
self.max_results = max_results
self.exclude = exclude or []
self.shortlist: typing.List['KademliaPeer'] = get_shortlist(routing_table, key, shortlist)
self.active: typing.List['KademliaPeer'] = []
self.contacted: typing.List[typing.Tuple[str, int]] = []
self.distance = Distance(key)
self.closest_peer: typing.Optional['KademliaPeer'] = None if not self.shortlist else self.shortlist[0]
self.prev_closest_peer: typing.Optional['KademliaPeer'] = None
self.iteration_queue = asyncio.Queue(loop=self.loop)
self.running_probes: typing.List[asyncio.Task] = []
self.lock = asyncio.Lock(loop=self.loop)
self.iteration_count = 0
self.bottom_out_count = 0
self.running = False
self.tasks: typing.List[asyncio.Task] = []
self.delayed_calls: typing.List[asyncio.Handle] = []
self.finished = asyncio.Event(loop=self.loop)
async def send_probe(self, peer: 'KademliaPeer') -> FindResponse:
"""
Send the rpc request to the peer and return an object with the FindResponse interface
"""
raise NotImplementedError()
def check_result_ready(self, response: FindResponse):
"""
Called with a lock after adding peers from an rpc result to the shortlist.
This method is responsible for putting a result for the generator into the Queue
"""
raise NotImplementedError()
def get_initial_result(self) -> typing.List['KademliaPeer']:
"""
Get an initial or cached result to be put into the Queue. Used for findValue requests where the blob
has peers in the local data store of blobs announced to us
"""
return []
def _is_closer(self, peer: 'KademliaPeer') -> bool:
if not self.closest_peer:
return True
return self.distance.is_closer(peer.node_id, self.closest_peer.node_id)
def _update_closest(self):
self.shortlist.sort(key=lambda peer: self.distance(peer.node_id), reverse=True)
if self.closest_peer and self.closest_peer is not self.shortlist[-1]:
if self._is_closer(self.shortlist[-1]):
self.prev_closest_peer = self.closest_peer
self.closest_peer = self.shortlist[-1]
async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse):
async with self.lock:
if peer not in self.shortlist:
self.shortlist.append(peer)
if peer not in self.active:
self.active.append(peer)
for contact_triple in response.get_close_triples():
addr_tuple = (contact_triple[1], contact_triple[2])
if addr_tuple not in self.contacted: # and not self.peer_manager.is_ignored(addr_tuple)
found_peer = self.peer_manager.get_kademlia_peer(
contact_triple[0], contact_triple[1], contact_triple[2]
)
if found_peer not in self.shortlist and self.peer_manager.peer_is_good(found_peer) is not False:
self.shortlist.append(found_peer)
self._update_closest()
self.check_result_ready(response)
async def _send_probe(self, peer: 'KademliaPeer'):
try:
response = await self.send_probe(peer)
except asyncio.CancelledError:
return
except asyncio.TimeoutError:
if peer in self.active:
self.active.remove(peer)
return
except ValueError as err:
log.warning(str(err))
if peer in self.active:
self.active.remove(peer)
return
except RemoteException:
return
return await self._handle_probe_result(peer, response)
async def _search_round(self):
"""
Send up to constants.alpha (5) probes to the closest peers in the shortlist
"""
added = 0
async with self.lock:
self.shortlist.sort(key=lambda p: self.distance(p.node_id), reverse=True)
while self.running and len(self.shortlist) and added < constants.alpha:
peer = self.shortlist.pop()
origin_address = (peer.address, peer.udp_port)
if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False:
continue
if peer.node_id == self.protocol.node_id:
continue
if (peer.address, peer.udp_port) == (self.protocol.external_ip, self.protocol.udp_port):
continue
if (peer.address, peer.udp_port) not in self.contacted:
self.contacted.append((peer.address, peer.udp_port))
t = self.loop.create_task(self._send_probe(peer))
def callback(_):
if t and t in self.running_probes:
self.running_probes.remove(t)
if not self.running_probes and self.shortlist:
self.tasks.append(self.loop.create_task(self._search_task(0.0)))
t.add_done_callback(callback)
self.running_probes.append(t)
added += 1
async def _search_task(self, delay: typing.Optional[float] = constants.iterative_lookup_delay):
try:
if self.running:
await self._search_round()
if self.running:
self.delayed_calls.append(self.loop.call_later(delay, self._search))
except (asyncio.CancelledError, StopAsyncIteration):
if self.running:
drain_tasks(self.running_probes)
self.running = False
def _search(self):
self.tasks.append(self.loop.create_task(self._search_task()))
def search(self):
if self.running:
raise Exception("already running")
self.running = True
self._search()
async def next_queue_or_finished(self) -> typing.List['KademliaPeer']:
peers = self.loop.create_task(self.iteration_queue.get())
finished = self.loop.create_task(self.finished.wait())
err = None
try:
await asyncio.wait([peers, finished], loop=self.loop, return_when='FIRST_COMPLETED')
if peers.done():
return peers.result()
raise StopAsyncIteration()
except asyncio.CancelledError as error:
err = error
finally:
if not finished.done() and not finished.cancelled():
finished.cancel()
if not peers.done() and not peers.cancelled():
peers.cancel()
if err:
raise err
def __aiter__(self):
self.search()
return self
async def __anext__(self) -> typing.List['KademliaPeer']:
try:
if self.iteration_count == 0:
initial_results = self.get_initial_result()
if initial_results:
self.iteration_queue.put_nowait(initial_results)
result = await self.next_queue_or_finished()
self.iteration_count += 1
return result
except (asyncio.CancelledError, StopAsyncIteration):
await self.aclose()
raise
def aclose(self):
self.running = False
async def _aclose():
async with self.lock:
self.running = False
if not self.finished.is_set():
self.finished.set()
drain_tasks(self.tasks)
drain_tasks(self.running_probes)
while self.delayed_calls:
timer = self.delayed_calls.pop()
if timer:
timer.cancel()
return asyncio.ensure_future(_aclose(), loop=self.loop)
class IterativeNodeFinder(IterativeFinder):
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.k,
exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
shortlist: typing.Optional[typing.List['KademliaPeer']] = None):
super().__init__(loop, peer_manager, routing_table, protocol, key, bottom_out_limit, max_results, exclude,
shortlist)
self.yielded_peers: typing.Set['KademliaPeer'] = set()
async def send_probe(self, peer: 'KademliaPeer') -> FindNodeResponse:
response = await self.protocol.get_rpc_peer(peer).find_node(self.key)
return FindNodeResponse(self.key, response)
def put_result(self, from_list: typing.List['KademliaPeer']):
not_yet_yielded = [peer for peer in from_list if peer not in self.yielded_peers]
not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id))
to_yield = not_yet_yielded[:min(constants.k, len(not_yet_yielded))]
if to_yield:
for peer in to_yield:
self.yielded_peers.add(peer)
self.iteration_queue.put_nowait(to_yield)
def check_result_ready(self, response: FindNodeResponse):
found = response.found and self.key != self.protocol.node_id
if found:
log.info("found")
self.put_result(self.shortlist)
if not self.finished.is_set():
self.finished.set()
return
if self.prev_closest_peer and self.closest_peer and not self._is_closer(self.prev_closest_peer):
# log.info("improving, %i %i %i %i %i", len(self.shortlist), len(self.active), len(self.contacted),
# self.bottom_out_count, self.iteration_count)
self.bottom_out_count = 0
elif self.prev_closest_peer and self.closest_peer:
self.bottom_out_count += 1
log.info("bottom out %i %i %i %i", len(self.active), len(self.contacted), len(self.shortlist),
self.bottom_out_count)
if self.bottom_out_count >= self.bottom_out_limit or self.iteration_count >= self.bottom_out_limit:
log.info("limit hit")
self.put_result(self.active)
if not self.finished.is_set():
self.finished.set()
return
if self.max_results and len(self.active) - len(self.yielded_peers) >= self.max_results:
log.info("max results")
self.put_result(self.active)
if not self.finished.is_set():
self.finished.set()
return
class IterativeValueFinder(IterativeFinder):
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.k,
exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
shortlist: typing.Optional[typing.List['KademliaPeer']] = None):
super().__init__(loop, peer_manager, routing_table, protocol, key, bottom_out_limit, max_results, exclude,
shortlist)
self.blob_peers: typing.Set['KademliaPeer'] = set()
async def send_probe(self, peer: 'KademliaPeer') -> FindValueResponse:
response = await self.protocol.get_rpc_peer(peer).find_value(self.key)
return FindValueResponse(self.key, response)
def check_result_ready(self, response: FindValueResponse):
if response.found:
blob_peers = [self.peer_manager.decode_tcp_peer_from_compact_address(compact_addr)
for compact_addr in response.found_compact_addresses]
to_yield = []
self.bottom_out_count = 0
for blob_peer in blob_peers:
if blob_peer not in self.blob_peers:
self.blob_peers.add(blob_peer)
to_yield.append(blob_peer)
if to_yield:
# log.info("found %i new peers for blob", len(to_yield))
self.iteration_queue.put_nowait(to_yield)
# if self.max_results and len(self.blob_peers) >= self.max_results:
# log.info("enough blob peers found")
# if not self.finished.is_set():
# self.finished.set()
return
if self.prev_closest_peer and self.closest_peer:
self.bottom_out_count += 1
if self.bottom_out_count >= self.bottom_out_limit:
log.info("blob peer search bottomed out")
if not self.finished.is_set():
self.finished.set()
return
def get_initial_result(self) -> typing.List['KademliaPeer']:
if self.protocol.data_store.has_peers_for_blob(self.key):
return self.protocol.data_store.get_peers_for_blob(self.key)
return []
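
The way `_search_round` picks peers to probe (sort the shortlist farthest-first, then pop from the end so the closest candidates are taken first) can be sketched as follows. The names `next_probes` and `xor_distance` are hypothetical, node ids stand in for peer objects, and the default `alpha` is only a placeholder for `constants.alpha`.

```python
def xor_distance(key_one: bytes, key_two: bytes) -> int:
    return int.from_bytes(key_one, 'big') ^ int.from_bytes(key_two, 'big')


def next_probes(shortlist, target, contacted, alpha=3):
    """Pop up to `alpha` not-yet-contacted node ids from `shortlist`, closest to `target` first."""
    # Farthest first, so list.pop() removes the closest remaining candidate in O(1)
    shortlist.sort(key=lambda node_id: xor_distance(target, node_id), reverse=True)
    probes = []
    while shortlist and len(probes) < alpha:
        node_id = shortlist.pop()
        if node_id in contacted:
            continue  # already probed; it is dropped from the shortlist
        contacted.add(node_id)
        probes.append(node_id)
    return probes
```

Candidates that are skipped are still removed from the shortlist, which mirrors how the loop in the listing consumes it.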

View file

@ -0,0 +1,634 @@
import logging
import socket
import functools
import hashlib
import asyncio
import typing
import binascii
from asyncio.protocols import DatagramProtocol
from asyncio.transports import DatagramTransport
from lbrynet.dht import constants
from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram
from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE
from lbrynet.dht.error import RemoteException, TransportNotConnected
from lbrynet.dht.protocol.routing_table import TreeRoutingTable
from lbrynet.dht.protocol.data_store import DictDataStore
if typing.TYPE_CHECKING:
from lbrynet.dht.peer import PeerManager, KademliaPeer
log = logging.getLogger(__name__)
old_protocol_errors = {
"findNode() takes exactly 2 arguments (5 given)": "0.19.1",
"findValue() takes exactly 2 arguments (5 given)": "0.19.1"
}
class KademliaRPC:
def __init__(self, protocol: 'KademliaProtocol', loop: asyncio.BaseEventLoop, peer_port: int = 3333):
self.protocol = protocol
self.loop = loop
self.peer_port = peer_port
self.old_token_secret: bytes = None
self.token_secret = constants.generate_id()
def compact_address(self):
compact_ip = functools.reduce(lambda buff, x: buff + bytearray([int(x)]),
self.protocol.external_ip.split('.'), bytearray())
compact_port = self.peer_port.to_bytes(2, 'big')
return compact_ip + compact_port + self.protocol.node_id
@staticmethod
def ping():
return b'pong'
def store(self, rpc_contact: 'KademliaPeer', blob_hash: bytes, token: bytes, port: int,
original_publisher_id: bytes, age: int) -> bytes:
if original_publisher_id is None:
original_publisher_id = rpc_contact.node_id
rpc_contact.update_tcp_port(port)
if self.loop.time() - self.protocol.started_listening_time < constants.token_secret_refresh_interval:
pass
elif not self.verify_token(token, rpc_contact.compact_ip()):
raise ValueError("Invalid token")
now = int(self.loop.time())
originally_published = now - age
self.protocol.data_store.add_peer_to_blob(
rpc_contact, blob_hash, rpc_contact.compact_address_tcp(), now, originally_published, original_publisher_id
)
return b'OK'
def find_node(self, rpc_contact: 'KademliaPeer', key: bytes) -> typing.List[typing.Tuple[bytes, str, int]]:
if len(key) != constants.hash_length:
raise ValueError("invalid contact node_id length: %i" % len(key))
contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id)
contact_triples = []
for contact in contacts:
contact_triples.append((contact.node_id, contact.address, contact.udp_port))
return contact_triples
def find_value(self, rpc_contact: 'KademliaPeer', key: bytes):
if len(key) != constants.hash_length:
raise ValueError("invalid blob hash length: %i" % len(key))
response = {
b'token': self.make_token(rpc_contact.compact_ip()),
}
if self.protocol.protocol_version:
response[b'protocolVersion'] = self.protocol.protocol_version
# get peers we have stored for this blob
has_other_peers = self.protocol.data_store.has_peers_for_blob(key)
peers = []
if has_other_peers:
peers.extend(self.protocol.data_store.get_peers_for_blob(key))
# if we don't have k storing peers to return and we have this hash locally, include our contact information
if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs:
peers.append(self.compact_address())
if peers:
response[key] = peers
else:
response[b'contacts'] = self.find_node(rpc_contact, key)
return response
def refresh_token(self): # TODO: this needs to be called periodically
self.old_token_secret = self.token_secret
self.token_secret = constants.generate_id()
def make_token(self, compact_ip):
h = hashlib.new('sha384')
h.update(self.token_secret + compact_ip)
return h.digest()
def verify_token(self, token, compact_ip):
h = hashlib.new('sha384')
h.update(self.token_secret + compact_ip)
if token == h.digest():
return True
if self.old_token_secret: # TODO: why should we be accepting the previous token?
h = hashlib.new('sha384')
h.update(self.old_token_secret + compact_ip)
return token == h.digest()
# the token matches neither the current nor the previous secret
return False
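
The token scheme used by `make_token`, `verify_token` and `refresh_token` (a SHA-384 digest of a rotating secret plus the requester's compact IP) can be sketched in isolation. The `TokenIssuer` class is hypothetical, `os.urandom(48)` stands in for `constants.generate_id()`, and rejecting a token that matches neither secret is an assumption about the intended behaviour.

```python
import hashlib
import hmac
import os


class TokenIssuer:
    def __init__(self):
        self.token_secret = os.urandom(48)  # stand-in for constants.generate_id()
        self.old_token_secret = None

    def refresh(self):
        """Rotate secrets; tokens issued under the previous secret stay valid for one more rotation."""
        self.old_token_secret = self.token_secret
        self.token_secret = os.urandom(48)

    @staticmethod
    def _digest(secret: bytes, compact_ip: bytes) -> bytes:
        return hashlib.sha384(secret + compact_ip).digest()

    def make_token(self, compact_ip: bytes) -> bytes:
        return self._digest(self.token_secret, compact_ip)

    def verify_token(self, token: bytes, compact_ip: bytes) -> bool:
        # compare_digest avoids leaking how many leading bytes matched
        if hmac.compare_digest(token, self._digest(self.token_secret, compact_ip)):
            return True
        if self.old_token_secret is not None:
            return hmac.compare_digest(token, self._digest(self.old_token_secret, compact_ip))
        return False
```

Binding the token to the compact IP means a token obtained through `find_value` is only usable for a later `store` from the same address.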
class RemoteKademliaRPC:
"""
Encapsulates RPC calls to remote Peers
"""
def __init__(self, loop: asyncio.BaseEventLoop, peer_tracker: 'PeerManager', protocol: 'KademliaProtocol',
peer: 'KademliaPeer'):
self.loop = loop
self.peer_tracker = peer_tracker
self.protocol = protocol
self.peer = peer
async def ping(self) -> bytes:
"""
:return: b'pong'
"""
response = await self.protocol.send_request(
self.peer, RequestDatagram.make_ping(self.protocol.node_id)
)
return response.response
async def store(self, blob_hash: bytes) -> bytes:
"""
:param blob_hash: blob hash as bytes
:return: b'OK'
"""
if len(blob_hash) != constants.hash_bits // 8:
raise ValueError(f"invalid length of blob hash: {len(blob_hash)}")
if not self.protocol.peer_port or not 0 < self.protocol.peer_port < 65535:
raise ValueError(f"invalid tcp port: {self.protocol.peer_port}")
token = self.peer_tracker.get_node_token(self.peer.node_id)
if not token:
find_value_resp = await self.find_value(blob_hash)
token = find_value_resp[b'token']
response = await self.protocol.send_request(
self.peer, RequestDatagram.make_store(self.protocol.node_id, blob_hash, token, self.protocol.peer_port)
)
return response.response
async def find_node(self, key: bytes) -> typing.List[typing.Tuple[bytes, str, int]]:
"""
:return: [(node_id, address, udp_port), ...]
"""
if len(key) != constants.hash_bits // 8:
raise ValueError(f"invalid length of find node key: {len(key)}")
response = await self.protocol.send_request(
self.peer, RequestDatagram.make_find_node(self.protocol.node_id, key)
)
return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response]
async def find_value(self, key: bytes) -> typing.Union[typing.Dict]:
"""
:return: {
b'token': <token bytes>,
b'contacts': [(node_id, address, udp_port), ...]
<key bytes>: [<blob_peer_compact_address>, ...]
}
"""
if len(key) != constants.hash_bits // 8:
raise ValueError(f"invalid length of find value key: {len(key)}")
response = await self.protocol.send_request(
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key)
)
await self.peer_tracker.update_token(self.peer.node_id, response.response[b'token'])
return response.response
class PingQueue:
def __init__(self, loop: asyncio.BaseEventLoop, protocol: 'KademliaProtocol'):
self._loop = loop
self._protocol = protocol
self._enqueued_contacts: typing.List['KademliaPeer'] = []
self._pending_contacts: typing.Dict['KademliaPeer', float] = {}
self._process_task: asyncio.Task = None
self._next_task: asyncio.Future = None
self._next_timer: asyncio.TimerHandle = None
self._lock = asyncio.Lock()
self._running = False
@property
def running(self):
return self._running
async def enqueue_maybe_ping(self, *peers: 'KademliaPeer', delay: typing.Optional[float] = None):
delay = constants.check_refresh_interval if delay is None else delay
async with self._lock:
for peer in peers:
if delay and peer not in self._enqueued_contacts:
self._pending_contacts[peer] = self._loop.time() + delay
elif peer not in self._enqueued_contacts:
self._enqueued_contacts.append(peer)
if peer in self._pending_contacts:
del self._pending_contacts[peer]
async def _process(self):
async def _ping(p: 'KademliaPeer'):
try:
if self._protocol.peer_manager.peer_is_good(p):
await self._protocol.add_peer(p)
return
await self._protocol.get_rpc_peer(p).ping()
except asyncio.TimeoutError:
pass
while True:
tasks = []
async with self._lock:
if self._enqueued_contacts or self._pending_contacts:
now = self._loop.time()
scheduled = [k for k, d in self._pending_contacts.items() if now >= d]
for k in scheduled:
del self._pending_contacts[k]
if k not in self._enqueued_contacts:
self._enqueued_contacts.append(k)
while self._enqueued_contacts:
peer = self._enqueued_contacts.pop()
tasks.append(self._loop.create_task(_ping(peer)))
if tasks:
await asyncio.wait(tasks, loop=self._loop)
f = self._loop.create_future()
self._loop.call_later(1.0, lambda: None if f.done() else f.set_result(None))
await f
def start(self):
assert not self._running
self._running = True
if not self._process_task:
self._process_task = self._loop.create_task(self._process())
def stop(self):
assert self._running
self._running = False
if self._process_task:
self._process_task.cancel()
self._process_task = None
if self._next_task:
self._next_task.cancel()
self._next_task = None
if self._next_timer:
self._next_timer.cancel()
self._next_timer = None
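The promotion step inside `PingQueue._process` (moving peers whose delay has elapsed from the pending map into the ping list) can be sketched in isolation. This is a simplified, synchronous illustration: `promote_due` is a hypothetical helper, and peers are plain strings rather than `KademliaPeer` objects.

```python
import typing


def promote_due(pending: typing.Dict[str, float], enqueued: typing.List[str], now: float) -> None:
    """Move every peer whose deadline has passed from `pending` to `enqueued`."""
    # Collect first, then mutate, so the dict is not changed while iterating over it.
    due = [peer for peer, deadline in pending.items() if now >= deadline]
    for peer in due:
        del pending[peer]
        if peer not in enqueued:
            enqueued.append(peer)


# Hypothetical example data: deadlines are event loop timestamps.
pending = {"peer-a": 5.0, "peer-b": 10.0}
enqueued = ["peer-c"]
promote_due(pending, enqueued, now=7.0)
```

After the call only `peer-a` has been promoted; `peer-b` stays pending until the clock passes its deadline.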
class KademliaProtocol(DatagramProtocol):
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', node_id: bytes, external_ip: str,
udp_port: int, peer_port: int, rpc_timeout: float = 5.0):
self.peer_manager = peer_manager
self.loop = loop
self.node_id = node_id
self.external_ip = external_ip
self.udp_port = udp_port
self.peer_port = peer_port
self.is_seed_node = False
self.partial_messages: typing.Dict[bytes, typing.Dict[bytes, bytes]] = {}
self.sent_messages: typing.Dict[bytes, typing.Tuple['KademliaPeer', asyncio.Future, RequestDatagram]] = {}
self.protocol_version = constants.protocol_version
self.started_listening_time = 0
self.transport: DatagramTransport = None
self.old_token_secret = constants.generate_id()
self.token_secret = constants.generate_id()
self.routing_table = TreeRoutingTable(self.loop, self.peer_manager, self.node_id)
self.data_store = DictDataStore(self.loop, self.peer_manager)
self.ping_queue = PingQueue(self.loop, self)
self.node_rpc = KademliaRPC(self, self.loop, self.peer_port)
self.lock = asyncio.Lock(loop=self.loop)
self.rpc_timeout = rpc_timeout
self._split_lock = asyncio.Lock(loop=self.loop)
def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer)
def stop(self):
if self.transport:
self.disconnect()
def disconnect(self):
self.transport.close()
def connection_made(self, transport: DatagramTransport):
self.transport = transport
def connection_lost(self, exc):
self.stop()
@staticmethod
def _migrate_incoming_rpc_args(peer: 'KademliaPeer', method: bytes, *args) -> typing.Tuple[typing.Tuple,
typing.Dict]:
if method == b'store' and peer.protocol_version == 0:
if isinstance(args[1], dict):
blob_hash = args[0]
token = args[1].pop(b'token', None)
port = args[1].pop(b'port', -1)
original_publisher_id = args[1].pop(b'lbryid', None)
age = 0
return (blob_hash, token, port, original_publisher_id, age), {}
return args, {}
async def _add_peer(self, peer: 'KademliaPeer'):
bucket_index = self.routing_table.kbucket_index(peer.node_id)
if self.routing_table.buckets[bucket_index].add_peer(peer):
return True
# The bucket is full; see if it can be split (by checking if its range includes the host node's node_id)
if self.routing_table.should_split(bucket_index, peer.node_id):
self.routing_table.split_bucket(bucket_index)
# Retry the insertion attempt
result = await self._add_peer(peer)
self.routing_table.join_buckets()
return result
else:
# We can't split the k-bucket
#
# The 13 page kademlia paper specifies that the least recently contacted node in the bucket
# shall be pinged. If it fails to reply it is replaced with the new contact. If the ping is successful
# the new contact is ignored and not added to the bucket (sections 2.2 and 2.4).
#
# A reasonable extension to this is BEP 0005, which extends the above:
#
# Not all nodes that we learn about are equal. Some are "good" and some are not.
# Many nodes using the DHT are able to send queries and receive responses,
# but are not able to respond to queries from other nodes. It is important that
# each node's routing table must contain only known good nodes. A good node is
# a node has responded to one of our queries within the last 15 minutes. A node
# is also good if it has ever responded to one of our queries and has sent us a
# query within the last 15 minutes. After 15 minutes of inactivity, a node becomes
# questionable. Nodes become bad when they fail to respond to multiple queries
# in a row. Nodes that we know are good are given priority over nodes with unknown status.
#
# When there are bad or questionable nodes in the bucket, the least recent is selected for
# potential replacement (BEP 0005). When all nodes in the bucket are fresh, the head (least recent)
# contact is selected as described in section 2.2 of the kademlia paper. In both cases the new contact
# is ignored if the pinged node replies.
not_good_contacts = self.routing_table.buckets[bucket_index].get_bad_or_unknown_peers()
not_recently_replied = []
for p in not_good_contacts:
last_replied = self.peer_manager.get_last_replied(p.address, p.udp_port)
if not last_replied or last_replied + 60 < self.loop.time():
not_recently_replied.append(p)
if not_recently_replied:
to_replace = not_recently_replied[0]
else:
to_replace = self.routing_table.buckets[bucket_index].peers[0]
last_replied = self.peer_manager.get_last_replied(to_replace.address, to_replace.udp_port)
if last_replied and last_replied + 60 > self.loop.time():
return False
log.debug("pinging %s:%s", to_replace.address, to_replace.udp_port)
try:
to_replace_rpc = self.get_rpc_peer(to_replace)
await to_replace_rpc.ping()
return False
except asyncio.TimeoutError:
log.debug("Replacing dead contact in bucket %i: %s:%i with %s:%i ", bucket_index,
to_replace.address, to_replace.udp_port, peer.address, peer.udp_port)
if to_replace in self.routing_table.buckets[bucket_index]:
self.routing_table.buckets[bucket_index].remove_peer(to_replace)
return await self._add_peer(peer)
async def add_peer(self, peer: 'KademliaPeer') -> bool:
if peer.node_id == self.node_id:
return False
async with self._split_lock:
return await self._add_peer(peer)
async def _handle_rpc(self, sender_contact: 'KademliaPeer', message: RequestDatagram):
assert sender_contact.node_id != self.node_id, (binascii.hexlify(sender_contact.node_id)[:8].decode(),
binascii.hexlify(self.node_id)[:8].decode())
method = message.method
if method not in [b'ping', b'store', b'findNode', b'findValue']:
raise AttributeError('Invalid method: %s' % message.method.decode())
if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]:
# args don't need reformatting
a, kw = tuple(message.args[:-1]), message.args[-1]
else:
a, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(),
sender_contact.address, sender_contact.udp_port)
if method == b'ping':
result = self.node_rpc.ping()
elif method == b'store':
blob_hash, token, port, original_publisher_id, age = a
result = self.node_rpc.store(sender_contact, blob_hash, token, port, original_publisher_id, age)
elif method == b'findNode':
key, = a
result = self.node_rpc.find_node(sender_contact, key)
else:
assert method == b'findValue'
key, = a
result = self.node_rpc.find_value(sender_contact, key)
await self.send_response(
sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result),
)
async def handle_request_datagram(self, address, request_datagram: RequestDatagram):
# This is an RPC method request
await self.peer_manager.report_last_requested(address[0], address[1])
await self.peer_manager.update_contact_triple(request_datagram.node_id, address[0], address[1])
# only add a requesting contact to the routing table if it has replied to one of our requests
peer = self.peer_manager.get_kademlia_peer(request_datagram.node_id, address[0], address[1])
try:
await self._handle_rpc(peer, request_datagram)
# if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
# will be added to our routing table if successful
is_good = self.peer_manager.peer_is_good(peer)
if is_good is None:
await self.ping_queue.enqueue_maybe_ping(peer)
elif is_good is True:
await self.add_peer(peer)
except Exception as err:
log.warning("error raised handling %s request from %s:%i - %s(%s)",
request_datagram.method, peer.address, peer.udp_port, str(type(err)),
str(err))
await self.send_error(
peer,
ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(),
str(err).encode())
)
async def handle_response_datagram(self, address: typing.Tuple[str, int], response_datagram: ResponseDatagram):
# Find the message that triggered this response
if response_datagram.rpc_id in self.sent_messages:
peer, df, request = self.sent_messages[response_datagram.rpc_id]
if peer.address != address[0]:
df.set_exception(RemoteException(
f"response from {address[0]}:{address[1]}, "
f"expected {peer.address}:{peer.udp_port}")
)
return
peer.set_id(response_datagram.node_id)
# We got a result from the RPC
if peer.node_id == self.node_id:
df.set_exception(RemoteException("node has our node id"))
return
elif response_datagram.node_id == self.node_id:
df.set_exception(RemoteException("incoming message is from our node id"))
return
await self.peer_manager.report_last_replied(address[0], address[1])
await self.peer_manager.update_contact_triple(peer.node_id, address[0], address[1])
if not df.cancelled():
df.set_result(response_datagram)
await self.add_peer(peer)
else:
log.warning("%s:%i replied, but after we cancelled the request attempt",
peer.address, peer.udp_port)
else:
# If the original message isn't found, it must have timed out
# TODO: we should probably do something with this...
pass
def handle_error_datagram(self, address, error_datagram: ErrorDatagram):
# The RPC request raised a remote exception; raise it locally
remote_exception = RemoteException(f"{error_datagram.exception_type}({error_datagram.response})")
if error_datagram.rpc_id in self.sent_messages:
peer, df, request = self.sent_messages.pop(error_datagram.rpc_id)
error_msg = f"" \
f"Error sending '{request.method}' to {peer.address}:{peer.udp_port}\n" \
f"Args: {request.args}\n" \
f"Raised: {str(remote_exception)}"
if error_datagram.response not in old_protocol_errors:
log.warning(error_msg)
else:
log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)",
peer.address, peer.udp_port, old_protocol_errors[error_datagram.response])
# reject replies coming from a different address than what we sent our request to
if (peer.address, peer.udp_port) != address:
log.error("address mismatch in reply")
remote_exception = TimeoutError(peer.node_id)
df.set_exception(remote_exception)
return
else:
if error_datagram.response not in old_protocol_errors:
msg = f"Received error from {address[0]}:{address[1]}, but it isn't in response to a " \
f"pending request: {str(remote_exception)}"
log.warning(msg)
else:
log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)",
address[0], address[1], old_protocol_errors[error_datagram.response])
def datagram_received(self, datagram: bytes, address: typing.Tuple[str, int]) -> None:
try:
message = decode_datagram(datagram)
except (ValueError, TypeError):
self.loop.create_task(self.peer_manager.report_failure(address[0], address[1]))
log.warning("Couldn't decode dht datagram from %s: %s", address, binascii.hexlify(datagram).decode())
return
if isinstance(message, RequestDatagram):
self.loop.create_task(self.handle_request_datagram(address, message))
elif isinstance(message, ErrorDatagram):
self.handle_error_datagram(address, message)
else:
assert isinstance(message, ResponseDatagram), "sanity"
self.loop.create_task(self.handle_response_datagram(address, message))
async def send_request(self, peer: 'KademliaPeer', request: RequestDatagram) -> ResponseDatagram:
await self._send(peer, request)
response_fut = self.sent_messages[request.rpc_id][1]
try:
response = await asyncio.wait_for(response_fut, self.rpc_timeout)
await self.peer_manager.report_last_replied(peer.address, peer.udp_port)
return response
except (asyncio.TimeoutError, RemoteException):
await self.peer_manager.report_failure(peer.address, peer.udp_port)
if self.peer_manager.peer_is_good(peer) is False:
self.routing_table.remove_peer(peer)
raise
async def send_response(self, peer: 'KademliaPeer', response: ResponseDatagram):
await self._send(peer, response)
async def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
await self._send(peer, error)
async def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram,
ErrorDatagram]):
if not self.transport:
raise TransportNotConnected()
data = message.bencode()
if len(data) > constants.msg_size_limit:
log.error("unexpected: %i vs %i", len(data), constants.msg_size_limit)
raise ValueError()
if isinstance(message, (RequestDatagram, ResponseDatagram)):
assert message.node_id == self.node_id, message
if isinstance(message, RequestDatagram):
assert self.node_id != peer.node_id
def pop_from_sent_messages(_):
if message.rpc_id in self.sent_messages:
self.sent_messages.pop(message.rpc_id)
async with self.lock:
if isinstance(message, RequestDatagram):
response_fut = self.loop.create_future()
response_fut.add_done_callback(pop_from_sent_messages)
self.sent_messages[message.rpc_id] = (peer, response_fut, message)
try:
self.transport.sendto(data, (peer.address, peer.udp_port))
except OSError as err:
# TODO: handle ENETUNREACH
if err.errno == socket.EWOULDBLOCK:
# i'm scared this may swallow important errors, but i get a million of these
# on Linux and it doesn't seem to affect anything -grin
log.warning("Can't send data to dht: EWOULDBLOCK")
else:
log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)",
len(data), peer.address, peer.udp_port, str(err), err.errno)
if isinstance(message, RequestDatagram):
self.sent_messages[message.rpc_id][1].set_exception(err)
else:
raise err
if isinstance(message, RequestDatagram):
await self.peer_manager.report_last_sent(peer.address, peer.udp_port)
elif isinstance(message, ErrorDatagram):
await self.peer_manager.report_failure(peer.address, peer.udp_port)
def change_token(self):
self.old_token_secret = self.token_secret
self.token_secret = constants.generate_id()
def make_token(self, compact_ip):
return constants.digest(self.token_secret + compact_ip)
def verify_token(self, token, compact_ip):
h = constants.hash_class()
h.update(self.token_secret + compact_ip)
if self.old_token_secret and not token == h.digest(): # TODO: why should we be accepting the previous token?
h = constants.hash_class()
h.update(self.old_token_secret + compact_ip)
if not token == h.digest():
return False
return True
async def store_to_peer(self, hash_value: bytes, peer: 'KademliaPeer') -> typing.Tuple[bytes, bool]:
try:
res = await self.get_rpc_peer(peer).store(hash_value)
if res != b"OK":
raise ValueError(res)
log.info("Stored %s to %s", binascii.hexlify(hash_value).decode()[:8], peer)
return peer.node_id, True
except asyncio.TimeoutError:
log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer)
except ValueError as err:
log.error("Unexpected response: %s", err)
except Exception as err:
if 'Invalid token' in str(err):
await self.peer_manager.clear_token(peer.node_id)
else:
log.exception("Unexpected error while storing blob_hash")
return peer.node_id, False
def _write(self, data: bytes, address: typing.Tuple[str, int]):
if self.transport:
try:
self.transport.sendto(data, address)
except OSError as err:
if err.errno == socket.EWOULDBLOCK:
# i'm scared this may swallow important errors, but i get a million of these
# on Linux and it doesn't seem to affect anything -grin
log.warning("Can't send data to dht: EWOULDBLOCK")
# elif err.errno == socket.ENETUNREACH:
# # this should probably try to retransmit when the network connection is back
# log.error("Network is unreachable")
else:
log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)",
len(data), address[0], address[1], str(err), err.errno)
raise err
else:
raise TransportNotConnected()
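The token handling in `change_token`, `make_token` and `verify_token` above can be illustrated with a standalone sketch. `TokenIssuer` is a hypothetical stand-in for the protocol class, and SHA-384 with 48-byte random secrets is an assumption here, since `constants.hash_class` and `constants.generate_id` are not shown in this diff.

```python
import hashlib
import os


class TokenIssuer:
    """Hypothetical stand-in for the token methods on KademliaProtocol."""

    def __init__(self):
        # Assumption: 48-byte random secrets (the real code calls constants.generate_id()).
        self.old_token_secret = os.urandom(48)
        self.token_secret = os.urandom(48)

    def change_token(self) -> None:
        # Keep the previous secret so tokens issued just before a rotation still verify.
        self.old_token_secret = self.token_secret
        self.token_secret = os.urandom(48)

    def make_token(self, compact_ip: bytes) -> bytes:
        # Assumption: SHA-384 stands in for constants.digest.
        return hashlib.sha384(self.token_secret + compact_ip).digest()

    def verify_token(self, token: bytes, compact_ip: bytes) -> bool:
        current = hashlib.sha384(self.token_secret + compact_ip).digest()
        previous = hashlib.sha384(self.old_token_secret + compact_ip).digest()
        return token in (current, previous)


issuer = TokenIssuer()
compact_ip = b"\x7f\x00\x00\x01"
token = issuer.make_token(compact_ip)
```

A token survives exactly one call to `change_token`; after a second rotation neither secret matches and verification fails.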

View file

@ -0,0 +1,305 @@
import asyncio
import random
import logging
import typing
import itertools
from lbrynet.dht import constants
from lbrynet.dht.protocol.distance import Distance
if typing.TYPE_CHECKING:
from lbrynet.dht.peer import KademliaPeer, PeerManager
log = logging.getLogger(__name__)
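class KBucket:
""" A k-bucket: an ordered list of at most constants.k peers whose node ids
fall in the range [range_min, range_max) of the ID space.
"""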
def __init__(self, peer_manager: 'PeerManager', range_min: int, range_max: int, node_id: bytes):
"""
@param range_min: The lower boundary for the range in the n-bit ID
space covered by this k-bucket
@param range_max: The upper boundary for the range in the ID space
covered by this k-bucket
"""
self._peer_manager = peer_manager
self.last_accessed = 0
self.range_min = range_min
self.range_max = range_max
self.peers: typing.List['KademliaPeer'] = []
self._node_id = node_id
def add_peer(self, peer: 'KademliaPeer') -> bool:
""" Add a peer to the bucket's peer list in the right order. This will move the
peer to the end of the k-bucket if it is already present.
@param peer: The peer to add
@type peer: KademliaPeer
@return: C{True} if the peer was added or moved to the end of the bucket,
C{False} if the bucket is full and the peer isn't already in it
@rtype: bool
"""
if peer in self.peers:
# Move the existing contact to the end of the list
# - using the new contact to allow add-on data
# (e.g. optimization-specific stuff) to be updated as well
self.peers.remove(peer)
self.peers.append(peer)
return True
elif len(self.peers) < constants.k:
self.peers.append(peer)
return True
else:
return False
# raise BucketFull("No space in bucket to insert contact")
def get_peer(self, node_id: bytes) -> 'KademliaPeer':
for peer in self.peers:
if peer.node_id == node_id:
return peer
raise IndexError(node_id)
def get_peers(self, count=-1, exclude_contact=None, sort_distance_to=None) -> typing.List['KademliaPeer']:
""" Returns a list containing up to the first count number of contacts
@param count: The amount of contacts to return (if 0 or less, return
all contacts)
@type count: int
@param exclude_contact: A node_id to exclude; if the contact with this
node_id is in the list of returned values, it
will be discarded before returning.
@type exclude_contact: bytes
@param sort_distance_to: Sort distance to the node_id, defaulting to the parent node node_id. If False don't
sort the contacts
@return: Return up to the first count number of contacts in a list
If no contacts are present an empty list is returned
@rtype: list
"""
peers = [peer for peer in self.peers if peer.node_id != exclude_contact]
# Return all contacts in bucket
if count <= 0:
count = len(peers)
# Get current contact number
current_len = len(peers)
# If count greater than k - return only k contacts
if count > constants.k:
count = constants.k
if not current_len:
return peers
if sort_distance_to is False:
pass
else:
sort_distance_to = sort_distance_to or self._node_id
peers.sort(key=lambda c: Distance(sort_distance_to)(c.node_id))
return peers[:min(current_len, count)]
def get_bad_or_unknown_peers(self) -> typing.List['KademliaPeer']:
peers = self.get_peers(sort_distance_to=False)
return [
peer for peer in peers
if self._peer_manager.contact_triple_is_good(peer.node_id, peer.address, peer.udp_port) is not True
]
def remove_peer(self, peer: 'KademliaPeer') -> None:
self.peers.remove(peer)
def key_in_range(self, key: bytes) -> bool:
""" Tests whether the specified key (i.e. node ID) is in the range
of the n-bit ID space covered by this k-bucket (in other words, it
returns whether or not the specified key should be placed in this
k-bucket)
@param key: The key to test
@type key: bytes
@return: C{True} if the key is in this k-bucket's range, or C{False}
if not.
@rtype: bool
"""
return self.range_min <= int.from_bytes(key, 'big') < self.range_max
def __len__(self) -> int:
return len(self.peers)
def __contains__(self, item) -> bool:
return item in self.peers
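`KBucket.get_peers` orders peers with `Distance`, which is imported from another module and not shown in this diff. A minimal sketch of that ordering, assuming `Distance` implements the standard Kademlia XOR metric (`xor_distance` and `closest` are hypothetical helper names, and the one-byte ids are toy values):

```python
import typing


def xor_distance(a: bytes, b: bytes) -> int:
    """Kademlia distance: XOR of the two ids, read as a big-endian integer."""
    return int.from_bytes(a, 'big') ^ int.from_bytes(b, 'big')


def closest(node_ids: typing.List[bytes], key: bytes, count: int) -> typing.List[bytes]:
    """Return up to `count` ids, ordered from nearest to farthest from `key`."""
    return sorted(node_ids, key=lambda node_id: xor_distance(key, node_id))[:count]


# Toy one-byte ids; real node ids are constants.hash_bits // 8 bytes long.
node_ids = [b'\x01', b'\x04', b'\x07', b'\x0c']
nearest = closest(node_ids, b'\x05', 2)
```

With key `0x05` the distances are 4, 1, 2 and 9 respectively, so the two nearest ids are `0x04` and `0x07`.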
class TreeRoutingTable:
""" This class implements a routing table used by a Node class.
The Kademlia routing table is a binary tree whose leaves are k-buckets,
where each k-bucket contains nodes with some common prefix of their IDs.
This prefix is the k-bucket's position in the binary tree; it therefore
covers some range of ID values, and together all of the k-buckets cover
the entire n-bit ID (or key) space (with no overlap).
@note: In this implementation, nodes in the tree (the k-buckets) are
added dynamically, as needed; this technique is described in the 13-page
version of the Kademlia paper, in section 2.4. It does, however, use the
ping RPC-based k-bucket eviction algorithm described in section 2.2 of
that paper.
"""
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', parent_node_id: bytes):
self._loop = loop
self._peer_manager = peer_manager
self._parent_node_id = parent_node_id
self.buckets: typing.List[KBucket] = [
KBucket(
self._peer_manager, range_min=0, range_max=2 ** constants.hash_bits, node_id=self._parent_node_id
)
]
def get_peers(self) -> typing.List['KademliaPeer']:
return list(itertools.chain.from_iterable(map(lambda bucket: bucket.peers, self.buckets)))
def should_split(self, bucket_index: int, to_add: bytes) -> bool:
# https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456
if self.buckets[bucket_index].key_in_range(self._parent_node_id):
return True
contacts = self.get_peers()
distance = Distance(self._parent_node_id)
contacts.sort(key=lambda c: distance(c.node_id))
kth_contact = contacts[-1] if len(contacts) < constants.k else contacts[constants.k - 1]
return distance(to_add) < distance(kth_contact.node_id)
def find_close_peers(self, key: bytes, count: typing.Optional[int] = None,
sender_node_id: typing.Optional[bytes] = None) -> typing.List['KademliaPeer']:
exclude = [self._parent_node_id]
if sender_node_id:
exclude.append(sender_node_id)
if key in exclude:
exclude.remove(key)
count = count or constants.k
distance = Distance(key)
contacts = self.get_peers()
contacts = [c for c in contacts if c.node_id not in exclude]
if contacts:
contacts.sort(key=lambda c: distance(c.node_id))
return contacts[:min(count, len(contacts))]
return []
def get_peer(self, contact_id: bytes) -> 'KademliaPeer':
"""
@raise IndexError: No contact with the specified contact ID is known
by this node
"""
return self.buckets[self.kbucket_index(contact_id)].get_peer(contact_id)
def get_refresh_list(self, start_index: int = 0, force: bool = False) -> typing.List[bytes]:
bucket_index = start_index
refresh_ids = []
now = int(self._loop.time())
for bucket in self.buckets[start_index:]:
if force or now - bucket.last_accessed >= constants.refresh_interval:
to_search = self.midpoint_id_in_bucket_range(bucket_index)
refresh_ids.append(to_search)
bucket_index += 1
return refresh_ids
def remove_peer(self, peer: 'KademliaPeer') -> None:
if not peer.node_id:
return
bucket_index = self.kbucket_index(peer.node_id)
try:
self.buckets[bucket_index].remove_peer(peer)
except ValueError:
return
def touch_kbucket(self, key: bytes) -> None:
self.touch_kbucket_by_index(self.kbucket_index(key))
def touch_kbucket_by_index(self, bucket_index: int):
self.buckets[bucket_index].last_accessed = int(self._loop.time())
def kbucket_index(self, key: bytes) -> int:
i = 0
for bucket in self.buckets:
if bucket.key_in_range(key):
return i
else:
i += 1
return i
def random_id_in_bucket_range(self, bucket_index: int) -> bytes:
random_id = int(random.randrange(self.buckets[bucket_index].range_min, self.buckets[bucket_index].range_max))
return random_id.to_bytes(constants.hash_length, 'big')
def midpoint_id_in_bucket_range(self, bucket_index: int) -> bytes:
half = int((self.buckets[bucket_index].range_max - self.buckets[bucket_index].range_min) // 2)
return int(self.buckets[bucket_index].range_min + half).to_bytes(constants.hash_length, 'big')
def split_bucket(self, old_bucket_index: int) -> None:
""" Splits the specified k-bucket into two new buckets which together
cover the same range in the key/ID space
@param old_bucket_index: The index of k-bucket to split (in this table's
list of k-buckets)
@type old_bucket_index: int
"""
# Resize the range of the current (old) k-bucket
old_bucket = self.buckets[old_bucket_index]
split_point = old_bucket.range_max - (old_bucket.range_max - old_bucket.range_min) // 2
# Create a new k-bucket to cover the range split off from the old bucket
new_bucket = KBucket(self._peer_manager, split_point, old_bucket.range_max, self._parent_node_id)
old_bucket.range_max = split_point
# Now, add the new bucket into the routing table tree
self.buckets.insert(old_bucket_index + 1, new_bucket)
# Finally, copy all nodes that belong to the new k-bucket into it...
for contact in old_bucket.peers:
if new_bucket.key_in_range(contact.node_id):
new_bucket.add_peer(contact)
# ...and remove them from the old bucket
for contact in new_bucket.peers:
old_bucket.remove_peer(contact)
def join_buckets(self):
to_pop = [i for i, bucket in enumerate(self.buckets) if not len(bucket)]
if not to_pop:
return
log.info("join buckets %i", len(to_pop))
bucket_index_to_pop = to_pop[0]
assert len(self.buckets[bucket_index_to_pop]) == 0
can_go_lower = bucket_index_to_pop - 1 >= 0
can_go_higher = bucket_index_to_pop + 1 < len(self.buckets)
assert can_go_higher or can_go_lower
bucket = self.buckets[bucket_index_to_pop]
if can_go_lower and can_go_higher:
midpoint = ((bucket.range_max - bucket.range_min) // 2) + bucket.range_min
self.buckets[bucket_index_to_pop - 1].range_max = midpoint
self.buckets[bucket_index_to_pop + 1].range_min = midpoint
elif can_go_lower:
self.buckets[bucket_index_to_pop - 1].range_max = bucket.range_max
elif can_go_higher:
self.buckets[bucket_index_to_pop + 1].range_min = bucket.range_min
self.buckets.remove(bucket)
return self.join_buckets()
def contact_in_routing_table(self, address_tuple: typing.Tuple[str, int]) -> bool:
for bucket in self.buckets:
for contact in bucket.get_peers(sort_distance_to=False):
if address_tuple[0] == contact.address and address_tuple[1] == contact.udp_port:
return True
return False
def buckets_with_contacts(self) -> int:
count = 0
for bucket in self.buckets:
if len(bucket):
count += 1
return count
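The range arithmetic used by `split_bucket` and `kbucket_index` can be tried out on a toy 8-bit ID space. This is only a sketch: `split_range` and `bucket_index` are hypothetical helpers that model buckets as half-open `(range_min, range_max)` tuples, and unlike `kbucket_index` the lookup raises when no range covers the key.

```python
import typing

Range = typing.Tuple[int, int]


def split_range(range_min: int, range_max: int) -> typing.Tuple[Range, Range]:
    """Split [range_min, range_max) at the same point split_bucket uses."""
    split_point = range_max - (range_max - range_min) // 2
    return (range_min, split_point), (split_point, range_max)


def bucket_index(ranges: typing.List[Range], key: bytes) -> int:
    """Return the index of the half-open range that contains the key."""
    value = int.from_bytes(key, 'big')
    for i, (low, high) in enumerate(ranges):
        if low <= value < high:
            return i
    raise ValueError("key is not covered by any bucket range")


# Start with a single bucket covering the whole 8-bit space, then split it.
ranges = [(0, 256)]
ranges[0:1] = split_range(*ranges[0])
```

Because every range is half-open, the upper bound of one bucket equals the lower bound of the next, and each key falls into exactly one bucket.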

View file

@ -1,320 +0,0 @@
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
#
# The docstrings in this module contain epytext markup; API documentation
# may be created by processing this file with epydoc: http://epydoc.sf.net
import random
import logging
from twisted.internet import defer
from lbrynet.dht import constants, kbucket
from lbrynet.dht.error import TimeoutError
from lbrynet.dht.distance import Distance
log = logging.getLogger(__name__)
class TreeRoutingTable:
""" This class implements a routing table used by a Node class.
The Kademlia routing table is a binary tree whose leaves are k-buckets,
where each k-bucket contains nodes with some common prefix of their IDs.
This prefix is the k-bucket's position in the binary tree; it therefore
covers some range of ID values, and together all of the k-buckets cover
the entire n-bit ID (or key) space (with no overlap).
@note: In this implementation, nodes in the tree (the k-buckets) are
added dynamically, as needed; this technique is described in the 13-page
version of the Kademlia paper, in section 2.4. It does, however, use the
C{PING} RPC-based k-bucket eviction algorithm described in section 2.2 of
that paper.
"""
#implements(IRoutingTable)
def __init__(self, parentNodeID, getTime=None):
"""
@param parentNodeID: The n-bit node ID of the node to which this
routing table belongs
@type parentNodeID: str
"""
# Create the initial (single) k-bucket covering the range of the entire n-bit ID space
self._parentNodeID = parentNodeID
self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits, node_id=self._parentNodeID)]
if not getTime:
from twisted.internet import reactor
getTime = reactor.seconds
self._getTime = getTime
self._ongoing_replacements = set()
def get_contacts(self):
contacts = []
for i in range(len(self._buckets)):
for contact in self._buckets[i]._contacts:
contacts.append(contact)
return contacts
def _shouldSplit(self, bucketIndex, toAdd):
# https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456
if self._buckets[bucketIndex].keyInRange(self._parentNodeID):
return True
contacts = self.get_contacts()
distance = Distance(self._parentNodeID)
contacts.sort(key=lambda c: distance(c.id))
kth_contact = contacts[-1] if len(contacts) < constants.k else contacts[constants.k-1]
return distance(toAdd) < distance(kth_contact.id)
def addContact(self, contact):
""" Add the given contact to the correct k-bucket; if it already
exists, its status will be updated
@param contact: The contact to add to this node's k-buckets
@type contact: kademlia.contact.Contact
@rtype: defer.Deferred
"""
if contact.id == self._parentNodeID:
return defer.succeed(None)
bucketIndex = self._kbucketIndex(contact.id)
try:
self._buckets[bucketIndex].addContact(contact)
except kbucket.BucketFull:
# The bucket is full; see if it can be split (by checking if its range includes the host node's id)
if self._shouldSplit(bucketIndex, contact.id):
self._splitBucket(bucketIndex)
# Retry the insertion attempt
return self.addContact(contact)
else:
# We can't split the k-bucket
#
# The 13 page kademlia paper specifies that the least recently contacted node in the bucket
# shall be pinged. If it fails to reply it is replaced with the new contact. If the ping is successful
# the new contact is ignored and not added to the bucket (sections 2.2 and 2.4).
#
# A reasonable extension to this is BEP 0005, which extends the above:
#
# Not all nodes that we learn about are equal. Some are "good" and some are not.
# Many nodes using the DHT are able to send queries and receive responses,
# but are not able to respond to queries from other nodes. It is important that
# each node's routing table must contain only known good nodes. A good node is
# a node has responded to one of our queries within the last 15 minutes. A node
# is also good if it has ever responded to one of our queries and has sent us a
# query within the last 15 minutes. After 15 minutes of inactivity, a node becomes
# questionable. Nodes become bad when they fail to respond to multiple queries
# in a row. Nodes that we know are good are given priority over nodes with unknown status.
#
# When there are bad or questionable nodes in the bucket, the least recent is selected for
# potential replacement (BEP 0005). When all nodes in the bucket are fresh, the head (least recent)
# contact is selected as described in section 2.2 of the kademlia paper. In both cases the new contact
# is ignored if the pinged node replies.
def replaceContact(failure, deadContact):
"""
Callback for the deferred PING RPC to see if the node to be replaced in the k-bucket is still
responding
@type failure: twisted.python.failure.Failure
"""
failure.trap(TimeoutError)
log.debug("Replacing dead contact in bucket %i: %s:%i (%s) with %s:%i (%s)", bucketIndex,
deadContact.address, deadContact.port, deadContact.log_id(), contact.address,
contact.port, contact.log_id())
try:
self._buckets[bucketIndex].removeContact(deadContact)
except ValueError:
# The contact has already been removed (probably due to a timeout)
pass
return self.addContact(contact)
not_good_contacts = self._buckets[bucketIndex].getBadOrUnknownContacts()
if not_good_contacts:
to_replace = not_good_contacts[0]
else:
to_replace = self._buckets[bucketIndex]._contacts[0]
if to_replace not in self._ongoing_replacements:
log.debug("pinging %s:%s", to_replace.address, to_replace.port)
self._ongoing_replacements.add(to_replace)
df = to_replace.ping()
df.addErrback(replaceContact, to_replace)
df.addBoth(lambda _: self._ongoing_replacements.remove(to_replace))
else:
df = defer.succeed(None)
return df
else:
self.touchKBucketByIndex(bucketIndex)
return defer.succeed(None)
def findCloseNodes(self, key, count=None, sender_node_id=None):
""" Finds a number of known nodes closest to the node/value with the
specified key.
@param key: the n-bit key (i.e. the node or value ID) to search for
@type key: str
@param count: the number of contacts to return, default of k (8)
@type count: int
@param sender_node_id: Used during RPC, this is the sender's node ID.
Whatever ID is passed in the parameter will get
excluded from the list of returned contacts.
@type sender_node_id: str
@return: A list of node contacts (C{kademlia.contact.Contact} instances)
closest to the specified key.
This method will return C{k} (or C{count}, if specified)
contacts if at all possible; it will only return fewer if the
node is returning all of the contacts that it knows of.
@rtype: list
"""
exclude = [self._parentNodeID]
if sender_node_id:
exclude.append(sender_node_id)
if key in exclude:
exclude.remove(key)
count = count or constants.k
distance = Distance(key)
contacts = self.get_contacts()
contacts = [c for c in contacts if c.id not in exclude]
contacts.sort(key=lambda c: distance(c.id))
return contacts[:min(count, len(contacts))]
def getContact(self, contactID):
""" Returns the (known) contact with the specified node ID
@raise ValueError: No contact with the specified contact ID is known
by this node
"""
bucketIndex = self._kbucketIndex(contactID)
return self._buckets[bucketIndex].getContact(contactID)
def getRefreshList(self, startIndex=0, force=False):
""" Finds all k-buckets that need refreshing, starting at the
k-bucket with the specified index, and returns IDs to be searched for
in order to refresh those k-buckets
@param startIndex: The index of the bucket to start refreshing at;
this bucket and those further away from it will
be refreshed. For example, when joining the
network, this node will set this to the index of
the bucket after the one containing its closest
neighbour.
@type startIndex: int
@param force: If this is C{True}, all buckets (in the specified range)
will be refreshed, regardless of the time they were last
accessed.
@type force: bool
@return: A list of node IDs that the parent node should search for
in order to refresh the routing table
@rtype: list
"""
bucketIndex = startIndex
refreshIDs = []
now = int(self._getTime())
for bucket in self._buckets[startIndex:]:
if force or now - bucket.lastAccessed >= constants.refreshTimeout:
searchID = self.midpoint_id_in_bucket_range(bucketIndex)
refreshIDs.append(searchID)
bucketIndex += 1
return refreshIDs
def removeContact(self, contact):
"""
Remove the contact from the routing table
@param contact: The contact to remove
@type contact: dht.contact._Contact
"""
bucketIndex = self._kbucketIndex(contact.id)
try:
self._buckets[bucketIndex].removeContact(contact)
except ValueError:
return
def touchKBucket(self, key):
""" Update the "last accessed" timestamp of the k-bucket which covers
the range containing the specified key in the key/ID space
@param key: A key in the range of the target k-bucket
@type key: str
"""
self.touchKBucketByIndex(self._kbucketIndex(key))
def touchKBucketByIndex(self, bucketIndex):
self._buckets[bucketIndex].lastAccessed = int(self._getTime())
def _kbucketIndex(self, key):
""" Calculate the index of the k-bucket which is responsible for the
specified key (or ID)
@param key: The key for which to find the appropriate k-bucket index
@type key: str
@return: The index of the k-bucket responsible for the specified key
@rtype: int
"""
i = 0
for bucket in self._buckets:
if bucket.keyInRange(key):
return i
else:
i += 1
return i
def random_id_in_bucket_range(self, bucketIndex):
""" Returns a random ID in the specified k-bucket's range
@param bucketIndex: The index of the k-bucket to use
@type bucketIndex: int
"""
random_id = int(random.randrange(self._buckets[bucketIndex].rangeMin, self._buckets[bucketIndex].rangeMax))
return random_id.to_bytes(constants.key_bits // 8, 'big')
def midpoint_id_in_bucket_range(self, bucketIndex):
""" Returns the middle ID in the specified k-bucket's range
@param bucketIndex: The index of the k-bucket to use
@type bucketIndex: int
"""
half = int((self._buckets[bucketIndex].rangeMax - self._buckets[bucketIndex].rangeMin) // 2)
return int(self._buckets[bucketIndex].rangeMin + half).to_bytes(constants.key_bits // 8, 'big')
def _splitBucket(self, oldBucketIndex):
""" Splits the specified k-bucket into two new buckets which together
cover the same range in the key/ID space
@param oldBucketIndex: The index of k-bucket to split (in this table's
list of k-buckets)
@type oldBucketIndex: int
"""
# Resize the range of the current (old) k-bucket
oldBucket = self._buckets[oldBucketIndex]
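# Integer division keeps the split point an exact int; `/` would yield a lossy float for large ID ranges
splitPoint = oldBucket.rangeMax - (oldBucket.rangeMax - oldBucket.rangeMin) // 2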
# Create a new k-bucket to cover the range split off from the old bucket
newBucket = kbucket.KBucket(splitPoint, oldBucket.rangeMax, self._parentNodeID)
oldBucket.rangeMax = splitPoint
# Now, add the new bucket into the routing table tree
self._buckets.insert(oldBucketIndex + 1, newBucket)
# Finally, copy all nodes that belong to the new k-bucket into it...
for contact in oldBucket._contacts:
if newBucket.keyInRange(contact.id):
newBucket.addContact(contact)
# ...and remove them from the old bucket
for contact in newBucket._contacts:
oldBucket.removeContact(contact)
def contactInRoutingTable(self, address_tuple):
for bucket in self._buckets:
for contact in bucket.getContacts(sort_distance_to=False):
if address_tuple[0] == contact.address and address_tuple[1] == contact.port:
return True
return False
def bucketsWithContacts(self):
count = 0
for bucket in self._buckets:
if len(bucket):
count += 1
return count
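
For reference, the closest-node lookup in findCloseNodes above comes down to sorting the known IDs by XOR distance to the key. The sketch below is a standalone illustration, not the lbrynet implementation: `xor_distance` and `closest_ids` are hypothetical names, IDs are assumed to be equal-length byte strings, and the default count of 8 mirrors the k mentioned in the docstring.

```python
import typing


def xor_distance(a: bytes, b: bytes) -> int:
    """Kademlia distance: the XOR of two IDs, read as big-endian integers."""
    if len(a) != len(b):
        raise ValueError("ids must have the same length")
    return int.from_bytes(a, 'big') ^ int.from_bytes(b, 'big')


def closest_ids(key: bytes, known_ids: typing.Iterable[bytes], count: int = 8,
                exclude: typing.Collection[bytes] = ()) -> typing.List[bytes]:
    """Return up to `count` IDs from `known_ids`, nearest to `key` first."""
    candidates = [node_id for node_id in known_ids if node_id not in exclude]
    candidates.sort(key=lambda node_id: xor_distance(key, node_id))
    # Slicing already copes with fewer than `count` candidates
    return candidates[:count]
```

Because slicing past the end of a list is harmless in Python, the `min(count, len(contacts))` guard used above is not needed in this sketch.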

View file

View file

@ -1,8 +1,8 @@
import typing
from lbrynet.dht.error import DecodeError
def bencode(data):
""" Encoder implementation of the Bencode algorithm (Bittorrent). """
def _bencode(data: typing.Union[int, bytes, bytearray, str, list, tuple, dict]) -> bytes:
if isinstance(data, int):
return b'i%de' % data
elif isinstance(data, (bytes, bytearray)):
@ -12,31 +12,20 @@ def bencode(data):
elif isinstance(data, (list, tuple)):
encoded_list_items = b''
for item in data:
encoded_list_items += bencode(item)
encoded_list_items += _bencode(item)
return b'l%se' % encoded_list_items
elif isinstance(data, dict):
encoded_dict_items = b''
keys = data.keys()
for key in sorted(keys):
encoded_dict_items += bencode(key)
encoded_dict_items += bencode(data[key])
encoded_dict_items += _bencode(key)
encoded_dict_items += _bencode(data[key])
return b'd%se' % encoded_dict_items
else:
raise TypeError("Cannot bencode '%s' object" % type(data))
raise TypeError(f"Cannot bencode {type(data)}")
def bdecode(data):
""" Decoder implementation of the Bencode algorithm. """
assert type(data) == bytes # fixme: _maybe_ remove this after porting
if len(data) == 0:
raise DecodeError('Cannot decode empty string')
try:
return _decode_recursive(data)[0]
except ValueError as e:
raise DecodeError(str(e))
def _decode_recursive(data, start_index=0):
def _bdecode(data: bytes, start_index: int = 0) -> typing.Tuple[typing.Union[int, bytes, list, tuple, dict], int]:
if data[start_index] == ord('i'):
end_pos = data[start_index:].find(b'e') + start_index
return int(data[start_index + 1:end_pos]), end_pos + 1
@ -44,32 +33,44 @@ def _decode_recursive(data, start_index=0):
start_index += 1
decoded_list = []
while data[start_index] != ord('e'):
list_data, start_index = _decode_recursive(data, start_index)
list_data, start_index = _bdecode(data, start_index)
decoded_list.append(list_data)
return decoded_list, start_index + 1
elif data[start_index] == ord('d'):
start_index += 1
decoded_dict = {}
while data[start_index] != ord('e'):
key, start_index = _decode_recursive(data, start_index)
value, start_index = _decode_recursive(data, start_index)
key, start_index = _bdecode(data, start_index)
value, start_index = _bdecode(data, start_index)
decoded_dict[key] = value
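# Skip the closing b'e', as the list branch does, so dicts nested inside lists decode correctly
return decoded_dict, start_index + 1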
elif data[start_index] == ord('f'):
# This (float data type) is a non-standard extension to the original Bencode algorithm
end_pos = data[start_index:].find(b'e') + start_index
return float(data[start_index + 1:end_pos]), end_pos + 1
elif data[start_index] == ord('n'):
# This (None/NULL data type) is a non-standard extension
# to the original Bencode algorithm
return None, start_index + 1
else:
split_pos = data[start_index:].find(b':') + start_index
try:
length = int(data[start_index:split_pos])
except ValueError:
raise DecodeError()
except (ValueError, TypeError) as err:
raise DecodeError(err)
start_index = split_pos + 1
end_pos = start_index + length
b = data[start_index:end_pos]
return b, end_pos
def bencode(data: typing.Dict) -> bytes:
if not isinstance(data, dict):
raise TypeError()
return _bencode(data)
def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict:
assert type(data) == bytes, DecodeError(f"invalid data type: {str(type(data))}")
if len(data) == 0:
raise DecodeError('Cannot decode empty string')
try:
result = _bdecode(data)[0]
if not allow_non_dict_return and not isinstance(result, dict):
raise ValueError(f'expected dict, got {type(result)}')
return result
except (ValueError, TypeError) as err:
raise DecodeError(err)
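
The framing used by the encoder and decoder above can be tried out in isolation. What follows is a minimal sketch under stated assumptions, not the lbrynet module: `encode` and `decode` are hypothetical names, only ints, bytes, lists and dicts are handled (no str, float or None extensions), and dict keys are assumed to be mutually sortable.

```python
import typing


def encode(data) -> bytes:
    """Bencode ints, bytes, lists/tuples and dicts (dict keys are emitted in sorted order)."""
    if isinstance(data, bool):
        raise TypeError("Cannot bencode bool")
    if isinstance(data, int):
        return b'i%de' % data
    if isinstance(data, (bytes, bytearray)):
        return b'%d:%s' % (len(data), bytes(data))
    if isinstance(data, (list, tuple)):
        return b'l' + b''.join(encode(item) for item in data) + b'e'
    if isinstance(data, dict):
        return b'd' + b''.join(encode(k) + encode(data[k]) for k in sorted(data)) + b'e'
    raise TypeError(f"Cannot bencode {type(data)}")


def decode(data: bytes, pos: int = 0) -> typing.Tuple[typing.Any, int]:
    """Return (value, index of the first byte after the value)."""
    marker = data[pos:pos + 1]
    if marker == b'i':
        end = data.index(b'e', pos)
        return int(data[pos + 1:end]), end + 1
    if marker == b'l':
        pos += 1
        items = []
        while data[pos:pos + 1] != b'e':
            item, pos = decode(data, pos)
            items.append(item)
        return items, pos + 1  # step over the closing b'e'
    if marker == b'd':
        pos += 1
        result = {}
        while data[pos:pos + 1] != b'e':
            key, pos = decode(data, pos)
            value, pos = decode(data, pos)
            result[key] = value
        return result, pos + 1  # step over the closing b'e'
    # Anything else is a length-prefixed byte string: b'<length>:<payload>'
    colon = data.index(b':', pos)
    length = int(data[pos:colon])
    start = colon + 1
    return data[start:start + length], start + length
```

Every branch returns the index just past the value it consumed, which is what lets containers nest: a dict inside a list hands back a position after its own terminator, so the enclosing list keeps reading its remaining items.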

View file

@ -0,0 +1,181 @@
import typing
from functools import reduce
from lbrynet.dht import constants
from lbrynet.dht.serialization.bencoding import bencode, bdecode
REQUEST_TYPE = 0
RESPONSE_TYPE = 1
ERROR_TYPE = 2
class KademliaDatagramBase:
"""
Field names map argument names to the integer indexes that replace them in a datagram.
When bdecoded, every packet is an argument dictionary starting with {0: <int>, 1: <bytes>, 2: <bytes>, ...};
these keys correspond to the packet_type, rpc_id, and node_id args.
"""
fields = [
'packet_type',
'rpc_id',
'node_id'
]
expected_packet_type = -1
def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes):
self.packet_type = packet_type
if self.expected_packet_type != packet_type:
raise ValueError(f"invalid packet type: {packet_type}, expected {self.expected_packet_type}")
if len(rpc_id) != constants.rpc_id_length:
raise ValueError(f"invalid rpc id: {len(rpc_id)} bytes (expected {constants.rpc_id_length})")
if len(node_id) != constants.hash_length:
raise ValueError(f"invalid node id: {len(node_id)} bytes (expected {constants.hash_length})")
self.rpc_id = rpc_id
self.node_id = node_id
def bencode(self) -> bytes:
return bencode({
i: getattr(self, k) for i, k in enumerate(self.fields)
})
class RequestDatagram(KademliaDatagramBase):
fields = [
'packet_type',
'rpc_id',
'node_id',
'method',
'args'
]
expected_packet_type = REQUEST_TYPE
def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, method: bytes,
args: typing.Optional[typing.List] = None):
super().__init__(packet_type, rpc_id, node_id)
self.method = method
self.args = args or []
if not self.args:
self.args.append({})
if isinstance(self.args[-1], dict):
self.args[-1][b'protocolVersion'] = 1
else:
self.args.append({b'protocolVersion': 1})
@classmethod
def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length:
raise ValueError("invalid rpc id length")
elif not rpc_id:
rpc_id = constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping')
@classmethod
def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length:
raise ValueError("invalid rpc id length")
if not rpc_id:
rpc_id = constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
store_args = [blob_hash, token, port, from_node_id, 0]
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args)
@classmethod
def make_find_node(cls, from_node_id: bytes, key: bytes,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length:
raise ValueError("invalid rpc id length")
if not rpc_id:
rpc_id = constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key])
@classmethod
def make_find_value(cls, from_node_id: bytes, key: bytes,
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
if rpc_id and len(rpc_id) != constants.rpc_id_length:
raise ValueError("invalid rpc id length")
if not rpc_id:
rpc_id = constants.generate_id()[:constants.rpc_id_length]
if len(from_node_id) != constants.hash_bits // 8:
raise ValueError("invalid node id")
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
class ResponseDatagram(KademliaDatagramBase):
fields = [
'packet_type',
'rpc_id',
'node_id',
'response'
]
expected_packet_type = RESPONSE_TYPE
def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, response):
super().__init__(packet_type, rpc_id, node_id)
self.response = response
class ErrorDatagram(KademliaDatagramBase):
fields = [
'packet_type',
'rpc_id',
'node_id',
'exception_type',
'response',
]
expected_packet_type = ERROR_TYPE
def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, exception_type: bytes, response: bytes):
super().__init__(packet_type, rpc_id, node_id)
self.exception_type = exception_type.decode()
self.response = response.decode()
def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
msg_types = {
REQUEST_TYPE: RequestDatagram,
RESPONSE_TYPE: ResponseDatagram,
ERROR_TYPE: ErrorDatagram
}
primitive: typing.Dict = bdecode(datagram)
if not isinstance(primitive, dict):
raise ValueError("invalid datagram type")
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
else:
raise ValueError("invalid datagram type")
datagram_class = msg_types[datagram_type]
return datagram_class(**{
k: primitive[i] # pylint: disable=unsubscriptable-object
for i, k in enumerate(datagram_class.fields)
if i in primitive # pylint: disable=unsupported-membership-test
}
)
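
The index-for-name substitution that the datagram classes and decode_datagram rely on can be shown on plain dicts. This is only an illustrative sketch: `wrap` and `unwrap` are hypothetical helpers, and the field list is copied from RequestDatagram above.

```python
import typing

REQUEST_FIELDS = ['packet_type', 'rpc_id', 'node_id', 'method', 'args']


def wrap(named: typing.Dict[str, typing.Any],
         fields: typing.List[str] = REQUEST_FIELDS) -> typing.Dict[int, typing.Any]:
    """Replace each field name with its position in `fields`."""
    return {i: named[k] for i, k in enumerate(fields) if k in named}


def unwrap(indexed: typing.Dict[int, typing.Any],
           fields: typing.List[str] = REQUEST_FIELDS) -> typing.Dict[str, typing.Any]:
    """Inverse of wrap(): turn integer keys back into keyword names."""
    return {k: indexed[i] for i, k in enumerate(fields) if i in indexed}
```

Skipping absent keys in both directions is what allows an optional trailing field, such as `args`, to be left out of a packet.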
def make_compact_ip(address: str):
return reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
compact_ip = make_compact_ip(address)
if not 0 <= port <= 65535:
raise ValueError(f'Invalid port: {port}')
return compact_ip + port.to_bytes(2, 'big') + node_id
def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, int]:
address = "{}.{}.{}.{}".format(*compact_address[:4])
port = int.from_bytes(compact_address[4:6], 'big')
node_id = compact_address[6:]
return node_id, address, port
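
The compact address layout above is 4 bytes of IPv4 address, 2 bytes of big-endian port, then the node ID. A hedged standalone sketch of the same layout follows; `pack_address` and `unpack_address` are hypothetical names, and using the standard `ipaddress` module instead of splitting on dots is a substitution made here for stricter input validation.

```python
import ipaddress
import typing


def pack_address(node_id: bytes, address: str, port: int) -> bytes:
    """Serialize as <4-byte IPv4><2-byte big-endian port><node id>."""
    if not 0 <= port <= 65535:
        raise ValueError(f'Invalid port: {port}')
    # IPv4Address raises a ValueError subclass for malformed addresses
    return ipaddress.IPv4Address(address).packed + port.to_bytes(2, 'big') + node_id


def unpack_address(compact: bytes) -> typing.Tuple[bytes, str, int]:
    """Inverse of pack_address(); returns (node_id, address, port)."""
    if len(compact) < 6:
        raise ValueError('compact address too short')
    address = str(ipaddress.IPv4Address(compact[:4]))
    port = int.from_bytes(compact[4:6], 'big')
    return compact[6:], address, port
```

A port of 65536 cannot be represented in two bytes, so the upper bound in the range check has to be 65535.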

View file

@ -1,92 +0,0 @@
import binascii
import logging
from twisted.internet import defer, task
from lbrynet.extras.compat import f2d
from lbrynet import utils
from lbrynet.conf import Config
log = logging.getLogger(__name__)
class DHTHashAnnouncer:
def __init__(self, conf: Config, dht_node, storage, concurrent_announcers=None):
self.conf = conf
self.dht_node = dht_node
self.storage = storage
self.clock = dht_node.clock
self.peer_port = dht_node.peerPort
self.hash_queue = []
if concurrent_announcers is None:
self.concurrent_announcers = conf.concurrent_announcers
else:
self.concurrent_announcers = concurrent_announcers
self._manage_lc = None
if self.concurrent_announcers:
self._manage_lc = task.LoopingCall(self.manage)
self._manage_lc.clock = self.clock
self.sem = defer.DeferredSemaphore(self.concurrent_announcers or conf.concurrent_announcers or 1)
def start(self):
if self._manage_lc:
self._manage_lc.start(30)
def stop(self):
if self._manage_lc and self._manage_lc.running:
self._manage_lc.stop()
@defer.inlineCallbacks
def do_store(self, blob_hash):
storing_node_ids = yield self.dht_node.announceHaveBlob(binascii.unhexlify(blob_hash))
now = self.clock.seconds()
if storing_node_ids:
result = (now, storing_node_ids)
yield f2d(self.storage.update_last_announced_blob(blob_hash, now))
log.debug("Stored %s to %i peers", blob_hash[:16], len(storing_node_ids))
else:
result = (None, [])
self.hash_queue.remove(blob_hash)
defer.returnValue(result)
def hash_queue_size(self):
return len(self.hash_queue)
def _show_announce_progress(self, size, start):
queue_size = len(self.hash_queue)
average_blobs_per_second = float(size - queue_size) / (self.clock.seconds() - start)
log.info("Announced %i/%i blobs, %f blobs per second", size - queue_size, size, average_blobs_per_second)
@defer.inlineCallbacks
def immediate_announce(self, blob_hashes):
self.hash_queue.extend(b for b in blob_hashes if b not in self.hash_queue)
log.info("Announcing %i blobs", len(self.hash_queue))
start = self.clock.seconds()
progress_lc = task.LoopingCall(self._show_announce_progress, len(self.hash_queue), start)
progress_lc.clock = self.clock
progress_lc.start(60, now=False)
results = yield utils.DeferredDict(
{blob_hash: self.sem.run(self.do_store, blob_hash) for blob_hash in blob_hashes}
)
now = self.clock.seconds()
progress_lc.stop()
announced_to = [blob_hash for blob_hash in results if results[blob_hash][0]]
if len(announced_to) != len(results):
log.debug("Failed to announce %i blobs", len(results) - len(announced_to))
if announced_to:
log.info('Took %s seconds to announce %i of %i attempted hashes (%f hashes per second)',
now - start, len(announced_to), len(blob_hashes),
int(float(len(blob_hashes)) / float(now - start)))
defer.returnValue(results)
@defer.inlineCallbacks
def manage(self):
if not self.dht_node.contacts:
log.info("Not ready to start announcing hashes")
return
need_reannouncement = yield f2d(self.storage.get_blobs_to_announce())
if need_reannouncement:
yield self.immediate_announce(need_reannouncement)
else:
log.debug("Nothing to announce")

View file

@ -1,70 +0,0 @@
import binascii
import logging
from twisted.internet import defer
log = logging.getLogger(__name__)
class DummyPeerFinder:
"""This class finds peers which have announced to the DHT that they have certain blobs"""
def find_peers_for_blob(self, blob_hash, timeout=None, filter_self=True):
return defer.succeed([])
class DHTPeerFinder(DummyPeerFinder):
"""This class finds peers which have announced to the DHT that they have certain blobs"""
#implements(IPeerFinder)
def __init__(self, component_manager):
"""
component_manager - an instance of ComponentManager
"""
self.component_manager = component_manager
self.peer_manager = component_manager.peer_manager
self.peers = {}
self._ongoing_searchs = {}
def find_peers_for_blob(self, blob_hash, timeout=None, filter_self=True):
"""
Find peers for blob in the DHT
blob_hash (str): blob hash to look for
timeout (int): seconds to timeout after
filter_self (bool): if True, and if a peer for a blob is itself, filter it
from the result
Returns:
list of peers for the blob
"""
if "dht" in self.component_manager.skip_components:
return defer.succeed([])
if not self.component_manager.all_components_running("dht"):
return defer.succeed([])
dht_node = self.component_manager.get_component("dht")
self.peers.setdefault(blob_hash, {(dht_node.externalIP, dht_node.peerPort,)})
if blob_hash not in self._ongoing_searchs or self._ongoing_searchs[blob_hash].called:
self._ongoing_searchs[blob_hash] = self._execute_peer_search(dht_node, blob_hash, timeout)
def _filter_self(blob_hash):
my_host, my_port = dht_node.externalIP, dht_node.peerPort
return {(host, port) for host, port in self.peers[blob_hash] if (host, port) != (my_host, my_port)}
peers = set(_filter_self(blob_hash) if filter_self else self.peers[blob_hash])
return defer.succeed([self.peer_manager.get_peer(*peer) for peer in peers])
@defer.inlineCallbacks
def _execute_peer_search(self, dht_node, blob_hash, timeout):
bin_hash = binascii.unhexlify(blob_hash)
finished_deferred = dht_node.iterativeFindValue(bin_hash, exclude=self.peers[blob_hash])
timeout = timeout or self.component_manager.conf.peer_search_timeout
if timeout:
finished_deferred.addTimeout(timeout, dht_node.clock)
try:
peer_list = yield finished_deferred
self.peers[blob_hash].update({(host, port) for _, host, port in peer_list})
except defer.TimeoutError:
log.debug("DHT timed out while looking for peers for blob %s after %s seconds", blob_hash, timeout)
finally:
del self._ongoing_searchs[blob_hash]

View file

@ -1,14 +0,0 @@
from lbrynet.p2p.Peer import Peer
class PeerManager:
def __init__(self):
self.peers = []
def get_peer(self, host, port):
for peer in self.peers:
if peer.host == host and peer.port == port:
return peer
peer = Peer(host, port)
self.peers.append(peer)
return peer

View file

@ -1,46 +0,0 @@
import datetime
from collections import defaultdict
from lbrynet import utils
# Do not create this object except through PeerManager
class Peer:
def __init__(self, host, port):
self.host = host
self.port = port
# If a peer is reported down, we wait until this time before reattempting a connection
self.attempt_connection_at = None
# Number of times this Peer has been reported to be down, resets to 0 when it is up
self.down_count = 0
# Number of successful connections (with full protocol completion) to this peer
self.success_count = 0
self.score = 0
self.stats = defaultdict(float) # {string stat_type, float count}
def is_available(self):
if self.attempt_connection_at is None or utils.today() > self.attempt_connection_at:
return True
return False
def report_up(self):
self.down_count = 0
self.attempt_connection_at = None
def report_success(self):
self.success_count += 1
def report_down(self):
self.down_count += 1
timeout_time = datetime.timedelta(seconds=60 * self.down_count)
self.attempt_connection_at = utils.today() + timeout_time
def update_score(self, score_change):
self.score += score_change
def update_stats(self, stat_type, count):
self.stats[stat_type] += count
def __str__(self):
return f'{self.host}:{self.port}'
def __repr__(self):
return f'Peer({self.host!r}, {self.port!r})'
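
The retry policy in the removed Peer class is a linear back-off: each consecutive failure pushes the next attempt another 60 seconds out. A small sketch of that policy is given below, assuming an injectable float clock in place of `utils.today()`; the `RetryBackoff` name and its `step` and `clock` parameters are hypothetical.

```python
import time
import typing


class RetryBackoff:
    """Linear back-off: the n-th consecutive failure waits n * step seconds."""

    def __init__(self, step: float = 60.0, clock: typing.Callable[[], float] = time.monotonic):
        self._step = step
        self._clock = clock  # injectable so tests can control time
        self.down_count = 0
        self.retry_at: typing.Optional[float] = None

    def is_available(self) -> bool:
        """True when no back-off is pending or the wait has elapsed."""
        return self.retry_at is None or self._clock() > self.retry_at

    def report_down(self) -> None:
        """Record a failure and schedule the next allowed attempt."""
        self.down_count += 1
        self.retry_at = self._clock() + self._step * self.down_count

    def report_up(self) -> None:
        """A successful contact clears the failure streak."""
        self.down_count = 0
        self.retry_at = None
```

Passing the clock in as a callable keeps the policy testable without patching module-level time functions.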

tests/dht_mocks.py Normal file
View file

@ -0,0 +1,83 @@
import typing
import contextlib
import socket
import mock
import functools
import asyncio
if typing.TYPE_CHECKING:
from lbrynet.dht.protocol.protocol import KademliaProtocol
def get_time_accelerator(loop: asyncio.BaseEventLoop,
now: typing.Optional[float] = None) -> typing.Callable[[float], typing.Awaitable[None]]:
"""
Returns an async advance() function
This provides a way to advance() the BaseEventLoop.time for the scheduled TimerHandles
made by call_later and call_at (call_soon callbacks are not timers; they run on the next loop iteration).
"""
_time = now or loop.time()
loop.time = functools.wraps(loop.time)(lambda: _time)
async def accelerate_time(seconds: float) -> None:
nonlocal _time
if seconds < 0:
raise ValueError('Cannot go back in time ({} seconds)'.format(seconds))
_time += seconds
await past_events()
await asyncio.sleep(0)
async def past_events() -> None:
while loop._scheduled:
timer: asyncio.TimerHandle = loop._scheduled[0]
if timer not in loop._ready and timer._when <= _time:
loop._scheduled.remove(timer)
loop._ready.append(timer)
if timer._when > _time:
break
await asyncio.sleep(0)
async def accelerator(seconds: float):
steps = seconds * 10.0
for _ in range(max(int(steps), 1)):
await accelerate_time(0.1)
return accelerator
@contextlib.contextmanager
def mock_network_loop(loop: asyncio.BaseEventLoop):
dht_network: typing.Dict[typing.Tuple[str, int], 'KademliaProtocol'] = {}
async def create_datagram_endpoint(proto_lam: typing.Callable[[], 'KademliaProtocol'],
from_addr: typing.Tuple[str, int]):
def sendto(data, to_addr):
rx = dht_network.get(to_addr)
if rx and rx.external_ip:
# print(f"{from_addr[0]}:{from_addr[1]} -{len(data)} bytes-> {rx.external_ip}:{rx.udp_port}")
return rx.datagram_received(data, from_addr)
protocol = proto_lam()
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
transport.close = lambda: mock_sock.close()
mock_sock.sendto = sendto
transport.sendto = mock_sock.sendto
protocol.connection_made(transport)
dht_network[from_addr] = protocol
return transport, protocol
with mock.patch('socket.socket') as mock_socket:
mock_sock = mock.Mock(spec=socket.socket)
mock_sock.setsockopt = lambda *_: None
mock_sock.bind = lambda *_: None
mock_sock.setblocking = lambda *_: None
mock_sock.getsockname = lambda: "0.0.0.0"
mock_sock.getpeername = lambda: ""
mock_sock.close = lambda: None
mock_sock.type = socket.SOCK_DGRAM
mock_sock.fileno = lambda: 7
mock_socket.return_value = mock_sock
loop.create_datagram_endpoint = create_datagram_endpoint
yield
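
At its core, mock_network_loop above replaces the socket with a dict keyed by (host, port) and hands each datagram straight to the receiving protocol. A stripped-down sketch of that routing idea, independent of asyncio and `mock`, is shown here; `InMemoryNetwork` and its methods are hypothetical names.

```python
import typing

Address = typing.Tuple[str, int]
Handler = typing.Callable[[bytes, Address], None]


class InMemoryNetwork:
    """Route datagrams between registered endpoints without touching a socket."""

    def __init__(self) -> None:
        self._endpoints: typing.Dict[Address, Handler] = {}

    def register(self, addr: Address, handler: Handler) -> None:
        """Bind a receive callback to an address."""
        self._endpoints[addr] = handler

    def sendto(self, data: bytes, from_addr: Address, to_addr: Address) -> bool:
        """Deliver synchronously; return False if nobody listens on `to_addr`."""
        handler = self._endpoints.get(to_addr)
        if handler is None:
            return False  # silently dropped, like UDP to an unreachable host
        handler(data, from_addr)
        return True
```

Synchronous delivery keeps tests deterministic, at the cost of not modelling latency, reordering or packet loss.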

View file

View file

@ -0,0 +1,91 @@
import asyncio
from torba.testcase import AsyncioTestCase
from tests import dht_mocks
from lbrynet.dht import constants
from lbrynet.dht.protocol.protocol import KademliaProtocol
from lbrynet.dht.peer import PeerManager
class TestProtocol(AsyncioTestCase):
async def test_ping(self):
loop = asyncio.get_event_loop()
with dht_mocks.mock_network_loop(loop):
node_id1 = constants.generate_id()
peer1 = KademliaProtocol(
loop, PeerManager(loop), node_id1, '1.2.3.4', 4444, 3333
)
peer2 = KademliaProtocol(
loop, PeerManager(loop), constants.generate_id(), '1.2.3.5', 4444, 3333
)
await loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444))
await loop.create_datagram_endpoint(lambda: peer2, ('1.2.3.5', 4444))
peer = peer2.peer_manager.get_kademlia_peer(node_id1, '1.2.3.4', udp_port=4444)
result = await peer2.get_rpc_peer(peer).ping()
self.assertEqual(result, b'pong')
peer1.stop()
peer2.stop()
peer1.disconnect()
peer2.disconnect()
async def test_update_token(self):
loop = asyncio.get_event_loop()
with dht_mocks.mock_network_loop(loop):
node_id1 = constants.generate_id()
peer1 = KademliaProtocol(
loop, PeerManager(loop), node_id1, '1.2.3.4', 4444, 3333
)
peer2 = KademliaProtocol(
loop, PeerManager(loop), constants.generate_id(), '1.2.3.5', 4444, 3333
)
await loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444))
await loop.create_datagram_endpoint(lambda: peer2, ('1.2.3.5', 4444))
peer = peer2.peer_manager.get_kademlia_peer(node_id1, '1.2.3.4', udp_port=4444)
self.assertEqual(None, peer2.peer_manager.get_node_token(peer.node_id))
await peer2.get_rpc_peer(peer).find_value(b'1' * 48)
self.assertNotEqual(None, peer2.peer_manager.get_node_token(peer.node_id))
peer1.stop()
peer2.stop()
peer1.disconnect()
peer2.disconnect()
async def test_store_to_peer(self):
loop = asyncio.get_event_loop()
with dht_mocks.mock_network_loop(loop):
node_id1 = constants.generate_id()
peer1 = KademliaProtocol(
loop, PeerManager(loop), node_id1, '1.2.3.4', 4444, 3333
)
peer2 = KademliaProtocol(
loop, PeerManager(loop), constants.generate_id(), '1.2.3.5', 4444, 3333
)
await loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444))
await loop.create_datagram_endpoint(lambda: peer2, ('1.2.3.5', 4444))
peer = peer2.peer_manager.get_kademlia_peer(node_id1, '1.2.3.4', udp_port=4444)
peer2_from_peer1 = peer1.peer_manager.get_kademlia_peer(
peer2.node_id, peer2.external_ip, udp_port=peer2.udp_port
)
peer3 = peer1.peer_manager.get_kademlia_peer(
constants.generate_id(), '1.2.3.6', udp_port=4444
)
store_result = await peer2.store_to_peer(b'2' * 48, peer)
self.assertEqual(store_result[0], peer.node_id)
self.assertEqual(True, store_result[1])
self.assertEqual(True, peer1.data_store.has_peers_for_blob(b'2' * 48))
self.assertEqual(False, peer1.data_store.has_peers_for_blob(b'3' * 48))
self.assertListEqual([peer2_from_peer1], peer1.data_store.get_storing_contacts())
find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48)
self.assertSetEqual(
{b'2' * 48, b'token', b'protocolVersion'}, set(find_value_response.keys())
)
self.assertEqual(1, len(find_value_response[b'2' * 48]))
self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1)
# self.assertEqual(peer2_from_peer1.tcp_port, 3333)
peer1.stop()
peer2.stop()
peer1.disconnect()
peer2.disconnect()

View file

View file

@ -0,0 +1,115 @@
import struct
import asyncio
from lbrynet.utils import generate_id
from lbrynet.dht.protocol.routing_table import KBucket
from lbrynet.dht.peer import PeerManager
from lbrynet.dht import constants
from torba.testcase import AsyncioTestCase
def address_generator(address=(10, 42, 42, 1)):
def increment(addr):
# Unpack the four octets directly; encoding chr() values >= 128 as UTF-8 would yield extra bytes
value = struct.unpack(">I", bytes(addr))[0] + 1
new_addr = []
for i in range(4):
new_addr.append(value % 256)
value >>= 8
return tuple(new_addr[::-1])
while True:
yield "{}.{}.{}.{}".format(*address)
address = increment(address)
class TestKBucket(AsyncioTestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
self.address_generator = address_generator()
self.peer_manager = PeerManager(self.loop)
self.kbucket = KBucket(self.peer_manager, 0, 2**constants.hash_bits, generate_id())
def test_add_peer(self):
# Test if contacts can be added to empty list
# Add k contacts to bucket
for i in range(constants.k):
peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444)
self.assertTrue(self.kbucket.add_peer(peer))
self.assertEqual(peer, self.kbucket.peers[i])
# Test if contact is not added to full list
peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444)
self.assertFalse(self.kbucket.add_peer(peer))
# Test if an existing contact is updated correctly if added again
existing_peer = self.kbucket.peers[0]
self.assertTrue(self.kbucket.add_peer(existing_peer))
self.assertEqual(existing_peer, self.kbucket.peers[-1])
# def testGetContacts(self):
# # try and get 2 contacts from empty list
# result = self.kbucket.getContacts(2)
# self.assertFalse(len(result) != 0, "Returned list should be empty; returned list length: %d" %
# (len(result)))
#
# # Add k-2 contacts
# node_ids = []
# if constants.k >= 2:
# for i in range(constants.k-2):
# node_ids.append(generate_id())
# tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
# None)
# self.kbucket.addContact(tmpContact)
# else:
# # add k contacts
# for i in range(constants.k):
# node_ids.append(generate_id())
# tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
# None)
# self.kbucket.addContact(tmpContact)
#
# # try to get too many contacts
# # requested count greater than bucket size; should return at most k contacts
# contacts = self.kbucket.getContacts(constants.k+3)
# self.assertTrue(len(contacts) <= constants.k,
# 'Returned list should not have more than k entries!')
#
# # verify returned contacts in list
# for node_id, i in zip(node_ids, range(constants.k-2)):
# self.assertFalse(self.kbucket._contacts[i].id != node_id,
# "Contact in position %s not same as added contact" % (str(i)))
#
# # try to get too many contacts
# # requested count one greater than number of contacts
# if constants.k >= 2:
# result = self.kbucket.getContacts(constants.k-1)
# self.assertFalse(len(result) != constants.k-2,
# "Too many contacts in returned list %s - should be %s" %
# (len(result), constants.k-2))
# else:
# result = self.kbucket.getContacts(constants.k-1)
# # if the count is <= 0, it should return all of its contacts
# self.assertFalse(len(result) != constants.k,
# "Too many contacts in returned list %s - should be %s" %
# (len(result), constants.k-2))
# result = self.kbucket.getContacts(constants.k-3)
# self.assertFalse(len(result) != constants.k-3,
# "Too many contacts in returned list %s - should be %s" %
# (len(result), constants.k-3))
def test_remove_peer(self):
# try to remove a peer from an empty bucket
peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444)
self.assertRaises(ValueError, self.kbucket.remove_peer, peer)
added = []
# add k-2 peers, remembering each one so it can be removed again below
for i in range(constants.k-2):
peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444)
self.assertTrue(self.kbucket.add_peer(peer))
added.append(peer)
while added:
peer = added.pop()
self.assertIn(peer, self.kbucket.peers)
self.kbucket.remove_peer(peer)
self.assertNotIn(peer, self.kbucket.peers)
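
The bucket behaviour these tests exercise (append while there is room, move a re-added peer to the tail, raise ValueError when removing an unknown peer) can be sketched as follows. This is a minimal illustration, not the KBucket class from this commit: `SketchBucket`, the use of plain strings as peers, the bucket size of 8, and the `False` return value for a full bucket are all assumptions made for the example.

```python
from typing import List


class SketchBucket:
    """Hypothetical k-bucket: least recently seen peer first, most recent last."""

    def __init__(self, k: int = 8):  # k=8 is an assumed bucket size
        self.k = k
        self.peers: List[str] = []

    def add_peer(self, peer: str) -> bool:
        if peer in self.peers:
            # A peer we already know is refreshed by moving it to the tail.
            self.peers.remove(peer)
            self.peers.append(peer)
            return True
        if len(self.peers) < self.k:
            self.peers.append(peer)
            return True
        # Assumption: a full bucket rejects the newcomer instead of evicting.
        return False

    def remove_peer(self, peer: str) -> None:
        # list.remove raises ValueError when the peer is not present.
        self.peers.remove(peer)
```

Keeping the most recently seen peer at the tail means the head of the list is always the best candidate to ping or evict, which is the usual Kademlia ordering.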

View file

@ -0,0 +1,259 @@
import asyncio
from torba.testcase import AsyncioTestCase
from tests import dht_mocks
from lbrynet.dht import constants
from lbrynet.dht.node import Node
from lbrynet.dht.peer import PeerManager


class TestRouting(AsyncioTestCase):
    async def test_fill_one_bucket(self):
        loop = asyncio.get_event_loop()
        peer_addresses = [
            (constants.generate_id(1), '1.2.3.1'),
            (constants.generate_id(2), '1.2.3.2'),
            (constants.generate_id(3), '1.2.3.3'),
            (constants.generate_id(4), '1.2.3.4'),
            (constants.generate_id(5), '1.2.3.5'),
            (constants.generate_id(6), '1.2.3.6'),
            (constants.generate_id(7), '1.2.3.7'),
            (constants.generate_id(8), '1.2.3.8'),
            (constants.generate_id(9), '1.2.3.9'),
        ]
        with dht_mocks.mock_network_loop(loop):
            nodes = {
                i: Node(loop, PeerManager(loop), node_id, 4444, 4444, 3333, address)
                for i, (node_id, address) in enumerate(peer_addresses)
            }
            node_1 = nodes[0]
            contact_cnt = 0
            for i in range(1, len(peer_addresses)):
                self.assertEqual(len(node_1.protocol.routing_table.get_peers()), contact_cnt)
                node = nodes[i]
                peer = node_1.protocol.peer_manager.get_kademlia_peer(
                    node.protocol.node_id, node.protocol.external_ip,
                    udp_port=node.protocol.udp_port
                )
                added = await node_1.protocol.add_peer(peer)
                self.assertEqual(True, added)
                contact_cnt += 1

            self.assertEqual(len(node_1.protocol.routing_table.get_peers()), 8)
            self.assertEqual(node_1.protocol.routing_table.buckets_with_contacts(), 1)
            for node in nodes.values():
                node.protocol.stop()
# from binascii import hexlify, unhexlify
#
# from twisted.trial import unittest
# from twisted.internet import defer
# from lbrynet.dht import constants
# from lbrynet.dht.routingtable import TreeRoutingTable
# from lbrynet.dht.contact import ContactManager
# from lbrynet.dht.distance import Distance
# from lbrynet.utils import generate_id
#
#
# class FakeRPCProtocol:
# """ Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """
# def sendRPC(self, *args, **kwargs):
# return defer.succeed(None)
#
#
# class TreeRoutingTableTest(unittest.TestCase):
# """ Test case for the RoutingTable class """
# def setUp(self):
# self.contact_manager = ContactManager()
# self.nodeID = generate_id(b'node1')
# self.protocol = FakeRPCProtocol()
# self.routingTable = TreeRoutingTable(self.nodeID)
#
# def test_distance(self):
# """ Test to see if distance method returns correct result"""
# d = Distance(bytes((170,) * 48))
# result = d(bytes((85,) * 48))
# expected = int(hexlify(bytes((255,) * 48)), 16)
# self.assertEqual(result, expected)
#
# @defer.inlineCallbacks
# def test_add_contact(self):
# """ Tests if a contact can be added and retrieved correctly """
# # Create the contact
# contact_id = generate_id(b'node2')
# contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
# # Now add it...
# yield self.routingTable.addContact(contact)
# # ...and request the closest nodes to it (will retrieve it)
# closest_nodes = self.routingTable.findCloseNodes(contact_id)
# self.assertEqual(len(closest_nodes), 1)
# self.assertIn(contact, closest_nodes)
#
# @defer.inlineCallbacks
# def test_get_contact(self):
# """ Tests if a specific existing contact can be retrieved correctly """
# contact_id = generate_id(b'node2')
# contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
# # Now add it...
# yield self.routingTable.addContact(contact)
# # ...and get it again
# same_contact = self.routingTable.getContact(contact_id)
# self.assertEqual(contact, same_contact, 'getContact() should return the same contact')
#
# @defer.inlineCallbacks
# def test_add_parent_node_as_contact(self):
# """
# Tests the routing table's behaviour when attempting to add its parent node as a contact
# """
# # Create a contact with the same ID as the local node's ID
# contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol)
# # Now try to add it
# yield self.routingTable.addContact(contact)
# # ...and request the closest nodes to it using FIND_NODE
# closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
# self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact')
#
# @defer.inlineCallbacks
# def test_remove_contact(self):
# """ Tests contact removal """
# # Create the contact
# contact_id = generate_id(b'node2')
# contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
# # Now add it...
# yield self.routingTable.addContact(contact)
# # Verify addition
# self.assertEqual(len(self.routingTable._buckets[0]), 1, 'Contact not added properly')
# # Now remove it
# self.routingTable.removeContact(contact)
# self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly')
#
# @defer.inlineCallbacks
# def test_split_bucket(self):
# """ Tests if the routing table correctly dynamically splits k-buckets """
# self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384,
# 'Initial k-bucket range should be 0 <= range < 2**384')
# # Add k contacts
# for i in range(constants.k):
# node_id = generate_id(b'remote node %d' % i)
# contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
# yield self.routingTable.addContact(contact)
#
# self.assertEqual(len(self.routingTable._buckets), 1,
# 'Only k nodes have been added; the first k-bucket should now '
# 'be full, but should not yet be split')
# # Now add 1 more contact
# node_id = generate_id(b'yet another remote node')
# contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
# yield self.routingTable.addContact(contact)
# self.assertEqual(len(self.routingTable._buckets), 2,
# 'k+1 nodes have been added; the first k-bucket should have been '
# 'split into two new buckets')
# self.assertNotEqual(self.routingTable._buckets[0].rangeMax, 2**384,
# 'K-bucket was split, but its range was not properly adjusted')
# self.assertEqual(self.routingTable._buckets[1].rangeMax, 2**384,
# 'K-bucket was split, but the second (new) bucket\'s '
# 'max range was not set properly')
# self.assertEqual(self.routingTable._buckets[0].rangeMax,
# self.routingTable._buckets[1].rangeMin,
# 'K-bucket was split, but the min/max ranges were '
# 'not divided properly')
#
# @defer.inlineCallbacks
# def test_full_split(self):
# """
# Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
# """
#
# self.routingTable._parentNodeID = bytes(48 * b'\xff')
#
# node_ids = [
# b"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
# b"200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
# b"300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
# b"400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
# b"500000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
# b"600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
# b"700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
# b"800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
# b"ff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
# b"010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
# ]
#
# # Add k contacts
# for nodeID in node_ids:
# # self.assertEquals(nodeID, node_ids[i].decode('hex'))
# contact = self.contact_manager.make_contact(unhexlify(nodeID), '127.0.0.1', 9182, self.protocol)
# yield self.routingTable.addContact(contact)
# self.assertEqual(len(self.routingTable._buckets), 2)
# self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
# self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)
#
# # try adding a contact who is further from us than the k'th known contact
# nodeID = b'020000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'
# nodeID = unhexlify(nodeID)
# contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
# self.assertFalse(self.routingTable._shouldSplit(self.routingTable._kbucketIndex(contact.id), contact.id))
# yield self.routingTable.addContact(contact)
# self.assertEqual(len(self.routingTable._buckets), 2)
# self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
# self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)
# self.assertNotIn(contact, self.routingTable._buckets[0]._contacts)
# self.assertNotIn(contact, self.routingTable._buckets[1]._contacts)
#
# class KeyErrorFixedTest(unittest.TestCase):
# """ Basic test case for boolean operators on the Contact class """
#
# def setUp(self):
# own_id = (2 ** constants.key_bits) - 1
# # carefully chosen own_id. here's the logic
# # we want a bunch of buckets (k+1, to be exact), and we want to make sure own_id
# # is not in bucket 0. so we put own_id at the end so we can keep splitting by adding to the
# # end
#
# self.table = lbrynet.dht.routingtable.OptimizedTreeRoutingTable(own_id)
#
# def fill_bucket(self, bucket_min):
# bucket_size = lbrynet.dht.constants.k
# for i in range(bucket_min, bucket_min + bucket_size):
# self.table.addContact(lbrynet.dht.contact.Contact(long(i), '127.0.0.1', 9999, None))
#
# def overflow_bucket(self, bucket_min):
# bucket_size = lbrynet.dht.constants.k
# self.fill_bucket(bucket_min)
# self.table.addContact(
# lbrynet.dht.contact.Contact(long(bucket_min + bucket_size + 1),
# '127.0.0.1', 9999, None))
#
# def testKeyError(self):
#
# # find middle, so we know where bucket will split
# bucket_middle = self.table._buckets[0].rangeMax / 2
#
# # fill last bucket
# self.fill_bucket(self.table._buckets[0].rangeMax - lbrynet.dht.constants.k - 1)
# # -1 in previous line because own_id is in last bucket
#
# # fill/overflow 7 more buckets
# bucket_start = 0
# for i in range(0, lbrynet.dht.constants.k):
# self.overflow_bucket(bucket_start)
# bucket_start += bucket_middle / (2 ** i)
#
# # replacement cache now has k-1 entries.
# # adding one more contact to bucket 0 used to cause a KeyError, but it should work
# self.table.addContact(
# lbrynet.dht.contact.Contact(long(lbrynet.dht.constants.k + 2), '127.0.0.1', 9999, None))
#
# # import math
# # print ""
# # for i, bucket in enumerate(self.table._buckets):
# # print "Bucket " + str(i) + " (2 ** " + str(
# # math.log(bucket.rangeMin, 2) if bucket.rangeMin > 0 else 0) + " <= x < 2 ** "+str(
# # math.log(bucket.rangeMax, 2)) + ")"
# # for c in bucket.getContacts():
# # print " contact " + str(c.id)
# # for key, bucket in self.table._replacementCache.items():
# # print "Replacement Cache for Bucket " + str(key)
# # for c in bucket:
# # print " contact " + str(c.id)
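
The commented-out test_distance above checks that the distance between an ID made of 0xAA bytes and one made of 0x55 bytes is the all-ones 384-bit integer. That is the Kademlia XOR metric, which can be sketched as below. `xor_distance` is a hypothetical helper written for this illustration; it is not the Distance class referenced in the old test.

```python
def xor_distance(a: bytes, b: bytes) -> int:
    """Kademlia distance: the two IDs XORed together, read as a big-endian integer."""
    if len(a) != len(b):
        raise ValueError("node ids must have the same length")
    return int.from_bytes(a, "big") ^ int.from_bytes(b, "big")


# Same inputs as the commented-out test: 48-byte (384-bit) ids.
all_ones = xor_distance(bytes((170,) * 48), bytes((85,) * 48))
```

Because 0xAA and 0x55 have no bits in common, every bit of the result is set, so `all_ones` equals `2**384 - 1`.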

View file

@ -0,0 +1,64 @@
import unittest
from lbrynet.dht.serialization.bencoding import _bencode, bencode, bdecode, DecodeError


class EncodeDecodeTest(unittest.TestCase):
    def test_fail_with_not_dict(self):
        with self.assertRaises(TypeError):
            bencode(1)
        with self.assertRaises(TypeError):
            bencode(b'derp')
        with self.assertRaises(TypeError):
            bencode('derp')
        with self.assertRaises(TypeError):
            bencode([b'derp'])
        with self.assertRaises(TypeError):
            bencode([object()])
        with self.assertRaises(TypeError):
            bencode({b'derp': object()})

    def test_fail_bad_type(self):
        with self.assertRaises(DecodeError):
            bdecode(b'd4le', True)

    def test_integer(self):
        self.assertEqual(_bencode(42), b'i42e')
        self.assertEqual(bdecode(b'i42e', True), 42)

    def test_bytes(self):
        self.assertEqual(_bencode(b''), b'0:')
        self.assertEqual(_bencode(b'spam'), b'4:spam')
        self.assertEqual(_bencode(b'4:spam'), b'6:4:spam')
        self.assertEqual(_bencode(bytearray(b'spam')), b'4:spam')

        self.assertEqual(bdecode(b'0:', True), b'')
        self.assertEqual(bdecode(b'4:spam', True), b'spam')
        self.assertEqual(bdecode(b'6:4:spam', True), b'4:spam')

    def test_string(self):
        self.assertEqual(_bencode(''), b'0:')
        self.assertEqual(_bencode('spam'), b'4:spam')
        self.assertEqual(_bencode('4:spam'), b'6:4:spam')

    def test_list(self):
        self.assertEqual(_bencode([b'spam', 42]), b'l4:spami42ee')
        self.assertEqual(bdecode(b'l4:spami42ee', True), [b'spam', 42])

    def test_dict(self):
        self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee')
        self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'})

    def test_mixed(self):
        self.assertEqual(_bencode(
            [[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]),
            b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'
        )

        self.assertEqual(bdecode(
            b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee', True),
            [[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]
        )

    def test_decode_error(self):
        self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True)
        self.assertRaises(DecodeError, bdecode, b'', True)
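
For readers unfamiliar with the wire format these assertions spell out, the encoding side can be sketched in a few lines. `bencode_sketch` is a hypothetical stand-in written for this illustration and is not the code in lbrynet.dht.serialization.bencoding; it corresponds roughly to the private `_bencode` helper, which accepts any supported type, whereas the public `bencode` in the tests above only accepts a dict at the top level. It assumes dictionary keys are all bytes (or all str) so that they can be sorted.

```python
def bencode_sketch(obj) -> bytes:
    """Encode ints, bytes, str, lists and dicts using bencoding rules."""
    if isinstance(obj, int):
        return b'i%de' % obj                      # integers: i<digits>e
    if isinstance(obj, str):
        obj = obj.encode()                        # text is encoded like bytes
    if isinstance(obj, (bytes, bytearray)):
        return b'%d:%s' % (len(obj), bytes(obj))  # byte strings: <length>:<data>
    if isinstance(obj, (list, tuple)):
        return b'l' + b''.join(bencode_sketch(item) for item in obj) + b'e'
    if isinstance(obj, dict):
        # Keys are emitted in sorted order so the encoding is deterministic.
        body = b''.join(
            bencode_sketch(key) + bencode_sketch(obj[key]) for key in sorted(obj)
        )
        return b'd' + body + b'e'
    raise TypeError(f"cannot bencode {type(obj).__name__}")
```

Sorting the keys is what makes `{b'foo': 42, b'bar': b'spam'}` encode with `bar` before `foo`, as the dictionary test expects.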

View file

@ -0,0 +1,60 @@
import unittest
from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram
from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE


class TestDatagram(unittest.TestCase):
    def test_ping_request_datagram(self):
        serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'ping', []).bencode()
        decoded = decode_datagram(serialized)
        self.assertEqual(decoded.packet_type, REQUEST_TYPE)
        self.assertEqual(decoded.rpc_id, b'1' * 20)
        self.assertEqual(decoded.node_id, b'1' * 48)
        self.assertEqual(decoded.method, b'ping')
        self.assertListEqual(decoded.args, [{b'protocolVersion': 1}])

    def test_ping_response(self):
        serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode()
        decoded = decode_datagram(serialized)
        self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
        self.assertEqual(decoded.rpc_id, b'1' * 20)
        self.assertEqual(decoded.node_id, b'1' * 48)
        self.assertEqual(decoded.response, b'pong')

    def test_find_node_request_datagram(self):
        serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findNode', [b'2' * 48]).bencode()
        decoded = decode_datagram(serialized)
        self.assertEqual(decoded.packet_type, REQUEST_TYPE)
        self.assertEqual(decoded.rpc_id, b'1' * 20)
        self.assertEqual(decoded.node_id, b'1' * 48)
        self.assertEqual(decoded.method, b'findNode')
        self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}])

    def test_find_node_response(self):
        closest_response = [(b'3' * 48, '1.2.3.4', 1234)]
        expected = [[b'3' * 48, b'1.2.3.4', 1234]]

        serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, closest_response).bencode()
        decoded = decode_datagram(serialized)
        self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
        self.assertEqual(decoded.rpc_id, b'1' * 20)
        self.assertEqual(decoded.node_id, b'1' * 48)
        self.assertEqual(decoded.response, expected)

    def test_find_value_request(self):
        serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findValue', [b'2' * 48]).bencode()
        decoded = decode_datagram(serialized)
        self.assertEqual(decoded.packet_type, REQUEST_TYPE)
        self.assertEqual(decoded.rpc_id, b'1' * 20)
        self.assertEqual(decoded.node_id, b'1' * 48)
        self.assertEqual(decoded.method, b'findValue')
        self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}])

    def test_find_value_response(self):
        found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']}
        serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
        decoded = decode_datagram(serialized)
        self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
        self.assertEqual(decoded.rpc_id, b'1' * 20)
        self.assertEqual(decoded.node_id, b'1' * 48)
        self.assertDictEqual(decoded.response, found_value_response)

View file

@ -0,0 +1,88 @@
import asyncio
from torba.testcase import AsyncioTestCase
from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction


class MockAsyncGen:
    def __init__(self, loop, result, delay, stop_cnt=10):
        self.loop = loop
        self.result = result
        self.delay = delay
        self.count = 0
        self.stop_cnt = stop_cnt
        self.called_close = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.count > self.stop_cnt - 1:
            raise StopAsyncIteration()
        self.count += 1
        await asyncio.sleep(self.delay, loop=self.loop)
        return self.result

    async def aclose(self):
        self.called_close = True


class TestAsyncGeneratorJunction(AsyncioTestCase):
    def setUp(self):
        self.loop = asyncio.get_event_loop()

    async def _test_junction(self, expected, *generators):
        order = []
        async with AsyncGeneratorJunction(self.loop) as junction:
            for generator in generators:
                junction.add_generator(generator)
            async for item in junction:
                order.append(item)
        self.assertListEqual(order, expected)

    async def test_yield_order(self):
        expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2]
        fast_gen = MockAsyncGen(self.loop, 1, 0.1)
        slow_gen = MockAsyncGen(self.loop, 2, 0.2)
        await self._test_junction(expected_order, fast_gen, slow_gen)
        self.assertEqual(fast_gen.called_close, True)
        self.assertEqual(slow_gen.called_close, True)

    async def test_one_stopped_first(self):
        expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]
        fast_gen = MockAsyncGen(self.loop, 1, 0.1, 5)
        slow_gen = MockAsyncGen(self.loop, 2, 0.2)
        await self._test_junction(expected_order, fast_gen, slow_gen)
        self.assertEqual(fast_gen.called_close, True)
        self.assertEqual(slow_gen.called_close, True)

    async def test_with_non_async_gen_class(self):
        expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]

        async def fast_gen():
            for i in range(10):
                if i == 5:
                    return
                await asyncio.sleep(0.1)
                yield 1

        slow_gen = MockAsyncGen(self.loop, 2, 0.2)
        await self._test_junction(expected_order, fast_gen(), slow_gen)
        self.assertEqual(slow_gen.called_close, True)

    async def test_stop_when_encapsulating_task_cancelled(self):
        fast_gen = MockAsyncGen(self.loop, 1, 0.1)
        slow_gen = MockAsyncGen(self.loop, 2, 0.2)

        async def _task():
            async with AsyncGeneratorJunction(self.loop) as junction:
                junction.add_generator(fast_gen)
                junction.add_generator(slow_gen)
                async for _ in junction:
                    pass

        task = self.loop.create_task(_task())
        self.loop.call_later(0.5, task.cancel)
        with self.assertRaises(asyncio.CancelledError):
            await task
        self.assertEqual(fast_gen.called_close, True)
        self.assertEqual(slow_gen.called_close, True)
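
The tests above treat the junction as a merge point: several async generators run concurrently and their items come out of one stream in the order they become ready. A rough sketch of that idea, using one draining task per generator and a shared queue, is shown below. It is not the AsyncGeneratorJunction implementation from this commit; `merge`, `ticker`, `collect` and the `_DONE` sentinel are hypothetical names, and the sketch neither calls `aclose()` on its inputs nor re-raises errors from them.

```python
import asyncio

_DONE = object()  # sentinel marking that one source generator has finished


async def merge(*generators):
    """Yield items from several async generators as soon as each is ready."""
    queue = asyncio.Queue()

    async def drain(gen):
        try:
            async for item in gen:
                queue.put_nowait(item)
        finally:
            queue.put_nowait(_DONE)

    tasks = [asyncio.ensure_future(drain(gen)) for gen in generators]
    remaining = len(tasks)
    try:
        while remaining:
            item = await queue.get()
            if item is _DONE:
                remaining -= 1
            else:
                yield item
    finally:
        # Stop any source that is still running if the consumer leaves early.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


async def ticker(value, delay, count):
    """Illustrative source: yield `value` `count` times, `delay` seconds apart."""
    for _ in range(count):
        await asyncio.sleep(delay)
        yield value


async def collect():
    return [item async for item in merge(ticker(1, 0.01, 3), ticker(2, 0.02, 3))]
```

Running `asyncio.run(collect())` returns three 1s and three 2s interleaved by arrival time. The exact interleaving depends on timer scheduling, which is presumably why the real tests run under a controlled test loop.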

View file

@ -1,50 +0,0 @@
from twisted.trial import unittest
from lbrynet.dht.encoding import bencode, bdecode, DecodeError
class EncodeDecodeTest(unittest.TestCase):
def test_integer(self):
self.assertEqual(bencode(42), b'i42e')
self.assertEqual(bdecode(b'i42e'), 42)
def test_bytes(self):
self.assertEqual(bencode(b''), b'0:')
self.assertEqual(bencode(b'spam'), b'4:spam')
self.assertEqual(bencode(b'4:spam'), b'6:4:spam')
self.assertEqual(bencode(bytearray(b'spam')), b'4:spam')
self.assertEqual(bdecode(b'0:'), b'')
self.assertEqual(bdecode(b'4:spam'), b'spam')
self.assertEqual(bdecode(b'6:4:spam'), b'4:spam')
def test_string(self):
self.assertEqual(bencode(''), b'0:')
self.assertEqual(bencode('spam'), b'4:spam')
self.assertEqual(bencode('4:spam'), b'6:4:spam')
def test_list(self):
self.assertEqual(bencode([b'spam', 42]), b'l4:spami42ee')
self.assertEqual(bdecode(b'l4:spami42ee'), [b'spam', 42])
def test_dict(self):
self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee')
self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'})
def test_mixed(self):
self.assertEqual(bencode(
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]),
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'
)
self.assertEqual(bdecode(
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'),
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]
)
def test_decode_error(self):
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz')
self.assertRaises(DecodeError, bdecode, b'')

View file

@ -1,59 +0,0 @@
from twisted.trial import unittest
from twisted.internet import defer, task
from lbrynet import utils
from lbrynet.conf import Config
from lbrynet.extras.daemon.HashAnnouncer import DHTHashAnnouncer
from tests.test_utils import random_lbry_hash
class MocDHTNode:
def __init__(self):
self.blobs_announced = 0
self.clock = task.Clock()
self.peerPort = 3333
def announceHaveBlob(self, blob):
self.blobs_announced += 1
d = defer.Deferred()
self.clock.callLater(1, d.callback, ['fake'])
return d
class MocStorage:
def __init__(self, blobs_to_announce):
self.blobs_to_announce = blobs_to_announce
self.announced = False
def get_blobs_to_announce(self):
if not self.announced:
self.announced = True
return defer.succeed(self.blobs_to_announce)
else:
return defer.succeed([])
def update_last_announced_blob(self, blob_hash, now):
return defer.succeed(None)
class DHTHashAnnouncerTest(unittest.TestCase):
def setUp(self):
conf = Config()
self.num_blobs = 10
self.blobs_to_announce = []
for i in range(0, self.num_blobs):
self.blobs_to_announce.append(random_lbry_hash())
self.dht_node = MocDHTNode()
self.clock = self.dht_node.clock
utils.call_later = self.clock.callLater
self.storage = MocStorage(self.blobs_to_announce)
self.announcer = DHTHashAnnouncer(conf, self.dht_node, self.storage)
@defer.inlineCallbacks
def test_immediate_announce(self):
announce_d = self.announcer.immediate_announce(self.blobs_to_announce)
self.assertEqual(self.announcer.hash_queue_size(), self.num_blobs)
self.clock.advance(1)
yield announce_d
self.assertEqual(self.dht_node.blobs_announced, self.num_blobs)
self.assertEqual(self.announcer.hash_queue_size(), 0)

View file

@ -1,125 +0,0 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
from twisted.trial import unittest
import struct
from lbrynet.utils import generate_id
from lbrynet.dht import kbucket
from lbrynet.dht.contact import ContactManager
from lbrynet.dht import constants
def address_generator(address=(10, 42, 42, 1)):
def increment(addr):
value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]).encode())[0] + 1
new_addr = []
for i in range(4):
new_addr.append(value % 256)
value >>= 8
return tuple(new_addr[::-1])
while True:
yield "{}.{}.{}.{}".format(*address)
address = increment(address)
class KBucketTest(unittest.TestCase):
""" Test case for the KBucket class """
def setUp(self):
self.address_generator = address_generator()
self.contact_manager = ContactManager()
self.kbucket = kbucket.KBucket(0, 2**constants.key_bits, generate_id())
def testAddContact(self):
""" Tests if the bucket handles contact additions/updates correctly """
# Test if contacts can be added to empty list
# Add k contacts to bucket
for i in range(constants.k):
tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
self.kbucket.addContact(tmpContact)
self.assertEqual(
self.kbucket._contacts[i],
tmpContact,
"Contact in position %d not the same as the newly-added contact" % i)
# Test if contact is not added to full list
tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
self.assertRaises(kbucket.BucketFull, self.kbucket.addContact, tmpContact)
# Test if an existing contact is updated correctly if added again
existingContact = self.kbucket._contacts[0]
self.kbucket.addContact(existingContact)
self.assertEqual(
self.kbucket._contacts.index(existingContact),
len(self.kbucket._contacts)-1,
'Contact not correctly updated; it should be at the end of the list of contacts')
def testGetContacts(self):
# try and get 2 contacts from empty list
result = self.kbucket.getContacts(2)
self.assertFalse(len(result) != 0, "Returned list should be empty; returned list length: %d" %
(len(result)))
# Add k-2 contacts
node_ids = []
if constants.k >= 2:
for i in range(constants.k-2):
node_ids.append(generate_id())
tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
None)
self.kbucket.addContact(tmpContact)
else:
# add k contacts
for i in range(constants.k):
node_ids.append(generate_id())
tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
None)
self.kbucket.addContact(tmpContact)
# try to get too many contacts
# requested count greater than bucket size; should return at most k contacts
contacts = self.kbucket.getContacts(constants.k+3)
self.assertTrue(len(contacts) <= constants.k,
'Returned list should not have more than k entries!')
# verify returned contacts in list
for node_id, i in zip(node_ids, range(constants.k-2)):
self.assertFalse(self.kbucket._contacts[i].id != node_id,
"Contact in position %s not same as added contact" % (str(i)))
# try to get too many contacts
# requested count one greater than number of contacts
if constants.k >= 2:
result = self.kbucket.getContacts(constants.k-1)
self.assertFalse(len(result) != constants.k-2,
"Too many contacts in returned list %s - should be %s" %
(len(result), constants.k-2))
else:
result = self.kbucket.getContacts(constants.k-1)
# if the count is <= 0, it should return all of its contacts
self.assertFalse(len(result) != constants.k,
"Too many contacts in returned list %s - should be %s" %
(len(result), constants.k-2))
result = self.kbucket.getContacts(constants.k-3)
self.assertFalse(len(result) != constants.k-3,
"Too many contacts in returned list %s - should be %s" %
(len(result), constants.k-3))
def testRemoveContact(self):
# try to remove a contact from an empty bucket
rmContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
self.assertRaises(ValueError, self.kbucket.removeContact, rmContact)
# add k-2 contacts
for i in range(constants.k-2):
tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
self.kbucket.addContact(tmpContact)
# add the contact, then remove it from the now non-empty bucket
self.kbucket.addContact(rmContact)
result = self.kbucket.removeContact(rmContact)
self.assertNotIn(rmContact, self.kbucket._contacts, "Could not remove contact from bucket")

View file

@ -1,84 +0,0 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
from twisted.trial import unittest
from lbrynet.dht.msgtypes import RequestMessage, ResponseMessage, ErrorMessage
from lbrynet.dht.msgformat import MessageTranslator, DefaultFormat
class DefaultFormatTranslatorTest(unittest.TestCase):
""" Test case for the default message translator """
def setUp(self):
self.cases = ((RequestMessage('1' * 48, 'rpcMethod',
{'arg1': 'a string', 'arg2': 123}, '1' * 20),
{DefaultFormat.headerType: DefaultFormat.typeRequest,
DefaultFormat.headerNodeID: '1' * 48,
DefaultFormat.headerMsgID: '1' * 20,
DefaultFormat.headerPayload: 'rpcMethod',
DefaultFormat.headerArgs: {'arg1': 'a string', 'arg2': 123}}),
(ResponseMessage('2' * 20, '2' * 48, 'response'),
{DefaultFormat.headerType: DefaultFormat.typeResponse,
DefaultFormat.headerNodeID: '2' * 48,
DefaultFormat.headerMsgID: '2' * 20,
DefaultFormat.headerPayload: 'response'}),
(ErrorMessage('3' * 20, '3' * 48,
"<type 'exceptions.ValueError'>", 'this is a test exception'),
{DefaultFormat.headerType: DefaultFormat.typeError,
DefaultFormat.headerNodeID: '3' * 48,
DefaultFormat.headerMsgID: '3' * 20,
DefaultFormat.headerPayload: "<type 'exceptions.ValueError'>",
DefaultFormat.headerArgs: 'this is a test exception'}),
(ResponseMessage(
'4' * 20, '4' * 48,
[('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82',
'127.0.0.1', 1919),
('\xae\x9ey\x93\xdd\xeb\xf1^\xff\xc5\x0f\xf8\xac!\x0e\x03\x9fY@{',
'127.0.0.1', 1921)]),
{DefaultFormat.headerType: DefaultFormat.typeResponse,
DefaultFormat.headerNodeID: '4' * 48,
DefaultFormat.headerMsgID: '4' * 20,
DefaultFormat.headerPayload:
[('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82',
'127.0.0.1', 1919),
('\xae\x9ey\x93\xdd\xeb\xf1^\xff\xc5\x0f\xf8\xac!\x0e\x03\x9fY@{',
'127.0.0.1', 1921)]})
)
self.translator = DefaultFormat()
self.assertTrue(
isinstance(self.translator, MessageTranslator),
'Translator class must inherit from entangled.kademlia.msgformat.MessageTranslator!')
def testToPrimitive(self):
""" Tests translation from a Message object to a primitive """
for msg, msgPrimitive in self.cases:
translatedObj = self.translator.toPrimitive(msg)
self.assertEqual(len(translatedObj), len(msgPrimitive),
"Translated object does not match example object's size")
for key in msgPrimitive:
self.assertEqual(
translatedObj[key], msgPrimitive[key],
'Message object type %s not translated correctly into primitive on '
'key "%s"; expected "%s", got "%s"' %
(msg.__class__.__name__, key, msgPrimitive[key], translatedObj[key]))
def testFromPrimitive(self):
""" Tests translation from a primitive to a Message object """
for msg, msgPrimitive in self.cases:
translatedObj = self.translator.fromPrimitive(msgPrimitive)
self.assertEqual(
type(translatedObj), type(msg),
'Message type incorrectly translated; expected "%s", got "%s"' %
(type(msg), type(translatedObj)))
for key in msg.__dict__:
self.assertEqual(
msg.__dict__[key], translatedObj.__dict__[key],
'Message instance variable "%s" not translated correctly; '
'expected "%s", got "%s"' %
(key, msg.__dict__[key], translatedObj.__dict__[key]))

View file

@ -1,88 +1,85 @@
import hashlib
import struct
from twisted.trial import unittest
from twisted.internet import defer
from lbrynet.dht.node import Node
import asyncio
import typing
from torba.testcase import AsyncioTestCase
from tests import dht_mocks
from lbrynet.dht import constants
from lbrynet.utils import generate_id
from lbrynet.dht.node import Node
from lbrynet.dht.peer import PeerManager
class NodeIDTest(unittest.TestCase):
class TestNodePingQueueDiscover(AsyncioTestCase):
async def test_ping_queue_discover(self):
loop = asyncio.get_event_loop()
def setUp(self):
self.node = Node()
peer_addresses = [
(constants.generate_id(1), '1.2.3.1'),
(constants.generate_id(2), '1.2.3.2'),
(constants.generate_id(3), '1.2.3.3'),
(constants.generate_id(4), '1.2.3.4'),
(constants.generate_id(5), '1.2.3.5'),
(constants.generate_id(6), '1.2.3.6'),
(constants.generate_id(7), '1.2.3.7'),
(constants.generate_id(8), '1.2.3.8'),
(constants.generate_id(9), '1.2.3.9'),
]
with dht_mocks.mock_network_loop(loop):
advance = dht_mocks.get_time_accelerator(loop, loop.time())
# start the nodes
nodes: typing.Dict[int, Node] = {
i: Node(loop, PeerManager(loop), node_id, 4444, 4444, 3333, address)
for i, (node_id, address) in enumerate(peer_addresses)
}
for i, n in nodes.items():
n.start(peer_addresses[i][1], [])
def test_new_node_has_auto_created_id(self):
self.assertEqual(type(self.node.node_id), bytes)
self.assertEqual(len(self.node.node_id), 48)
await advance(1)
def test_uniqueness_and_length_of_generated_ids(self):
previous_ids = []
for i in range(100):
new_id = self.node._generateID()
self.assertNotIn(new_id, previous_ids, f'id at index {i} not unique')
self.assertEqual(len(new_id), 48, 'id at index {} wrong length: {}'.format(i, len(new_id)))
previous_ids.append(new_id)
node_1 = nodes[0]
# ping 8 nodes from node_1, this will result in a delayed return ping
futs = []
for i in range(1, len(peer_addresses)):
node = nodes[i]
assert node.protocol.node_id != node_1.protocol.node_id
peer = node_1.protocol.peer_manager.get_kademlia_peer(
node.protocol.node_id, node.protocol.external_ip, udp_port=node.protocol.udp_port
)
futs.append(node_1.protocol.get_rpc_peer(peer).ping())
await advance(3)
replies = await asyncio.gather(*tuple(futs))
self.assertTrue(all(map(lambda reply: reply == b"pong", replies)))
class NodeDataTest(unittest.TestCase):
""" Test case for the Node class's data-related functions """
# run for long enough for the delayed pings to have been sent by node 1
await advance(1000)
def setUp(self):
h = hashlib.sha384()
h.update(b'test')
self.node = Node()
self.contact = self.node.contact_manager.make_contact(
h.digest(), '127.0.0.1', 12345, self.node._protocol)
self.token = self.node.make_token(self.contact.compact_ip())
self.cases = []
for i in range(5):
h.update(str(i).encode())
self.cases.append((h.digest(), 5000+2*i))
self.cases.append((h.digest(), 5001+2*i))
# verify all of the previously pinged peers have node_1 in their routing tables
for n in nodes.values():
peers = n.protocol.routing_table.get_peers()
if n is node_1:
self.assertEqual(8, len(peers))
else:
self.assertEqual(1, len(peers))
self.assertEqual((peers[0].node_id, peers[0].address, peers[0].udp_port),
(node_1.protocol.node_id, node_1.protocol.external_ip, node_1.protocol.udp_port))
@defer.inlineCallbacks
def test_store(self):
""" Tests if the node can store (and privately retrieve) some data """
for key, port in self.cases:
yield self.node.store(
self.contact, key, self.token, port, self.contact.id, 0
)
for key, value in self.cases:
expected_result = self.contact.compact_ip() + struct.pack('>H', value) + self.contact.id
self.assertTrue(self.node._dataStore.hasPeersForBlob(key),
"Stored key not found in node's DataStore: '%s'" % key)
self.assertIn(expected_result, self.node._dataStore.getPeersForBlob(key),
"Stored val not found in node's DataStore: key:'%s' port:'%s' %s"
% (key, value, self.node._dataStore.getPeersForBlob(key)))
# run long enough for the refresh loop to run
await advance(3600)
# verify all the nodes know about each other
for n in nodes.values():
if n is node_1:
continue
peers = n.protocol.routing_table.get_peers()
self.assertEqual(8, len(peers))
self.assertSetEqual(
{n_id[0] for n_id in peer_addresses if n_id[0] != n.protocol.node_id},
{c.node_id for c in peers}
)
self.assertSetEqual(
{n_addr[1] for n_addr in peer_addresses if n_addr[1] != n.protocol.external_ip},
{c.address for c in peers}
)
class NodeContactTest(unittest.TestCase):
""" Test case for the Node class's contact management-related functions """
def setUp(self):
self.node = Node()
@defer.inlineCallbacks
def test_add_contact(self):
""" Tests if a contact can be added and retrieved correctly """
# Create the contact
contact_id = generate_id(b'node1')
contact = self.node.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.node._protocol)
# Now add it...
yield self.node.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE
closest_nodes = self.node._routingTable.findCloseNodes(contact_id, constants.k)
self.assertEqual(len(closest_nodes), 1)
self.assertIn(contact, closest_nodes)
@defer.inlineCallbacks
def test_add_self_as_contact(self):
""" Tests the node's behaviour when attempting to add itself as a contact """
# Create a contact with the same ID as the local node's ID
contact = self.node.contact_manager.make_contact(self.node.node_id, '127.0.0.1', 9182, None)
# Now try to add it
yield self.node.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE
closest_nodes = self.node._routingTable.findCloseNodes(self.node.node_id, constants.k)
self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact.')
# teardown
for n in nodes.values():
n.stop()
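
The ping phase of the new test builds one awaitable per peer and resolves them together with `asyncio.gather`. A minimal, self-contained sketch of that pattern follows. `fake_ping` and `ping_all` are hypothetical stand-ins written for illustration; they are not part of `lbrynet.dht`, and no mocked network loop or time accelerator is involved.

```python
import asyncio
import typing


async def fake_ping(address: str, delay: float = 0.01) -> bytes:
    # Hypothetical stand-in for an RPC ping: wait briefly, then reply.
    await asyncio.sleep(delay)
    return b"pong"


async def ping_all(addresses: typing.List[str]) -> typing.List[bytes]:
    # Create one coroutine per peer, then await all of them concurrently.
    # gather() returns the results in the same order as its arguments.
    futs = [fake_ping(address) for address in addresses]
    return await asyncio.gather(*futs)


addresses = [f"1.2.3.{i}" for i in range(2, 10)]  # eight peers, as in the test
replies = asyncio.run(ping_all(addresses))
```

Because `gather` preserves argument order, `replies[i]` corresponds to `addresses[i]`, so a check such as `all(reply == b"pong" for reply in replies)` covers every peer.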


@ -1,64 +1,37 @@
from binascii import hexlify
from twisted.internet import task
from twisted.trial import unittest
import asyncio
import unittest
from lbrynet.utils import generate_id
from lbrynet.dht.contact import ContactManager
from lbrynet.dht import constants
from lbrynet.dht.peer import PeerManager
from torba.testcase import AsyncioTestCase
class ContactTest(unittest.TestCase):
""" Basic test cases for boolean operators on the Contact class """
class PeerTest(AsyncioTestCase):
def setUp(self):
self.contact_manager = ContactManager()
self.loop = asyncio.get_event_loop()
self.peer_manager = PeerManager(self.loop)
self.node_ids = [generate_id(), generate_id(), generate_id()]
make_contact = self.contact_manager.make_contact
self.first_contact = make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1)
self.second_contact = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
self.second_contact_second_reference = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
self.first_contact_different_values = make_contact(self.node_ids[1], '192.168.1.20', 1000, None, 50)
self.first_contact = self.peer_manager.get_kademlia_peer(self.node_ids[1], '127.0.0.1', udp_port=1000)
self.second_contact = self.peer_manager.get_kademlia_peer(self.node_ids[0], '192.168.0.1', udp_port=1000)
def test_make_contact_error_cases(self):
self.assertRaises(
ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20', 100000, None)
self.assertRaises(
ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20.1', 1000, None)
self.assertRaises(
ValueError, self.contact_manager.make_contact, self.node_ids[1], 'this is not an ip', 1000, None)
self.assertRaises(
ValueError, self.contact_manager.make_contact, b'not valid node id', '192.168.1.20.1', 1000, None)
def test_no_duplicate_contact_objects(self):
self.assertIs(self.second_contact, self.second_contact_second_reference)
self.assertIsNot(self.first_contact, self.first_contact_different_values)
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], '192.168.1.20', 100000)
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], '192.168.1.20.1', 1000)
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], 'this is not an ip', 1000)
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], '192.168.1.20', -1000)
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, b'not valid node id', '192.168.1.20', 1000)
def test_boolean(self):
""" Test "equals" and "not equals" comparisons """
self.assertNotEqual(
self.first_contact, self.contact_manager.make_contact(
self.first_contact.id, self.first_contact.address, self.first_contact.port + 1, None, 32
)
self.assertNotEqual(self.first_contact, self.second_contact)
self.assertEqual(
self.second_contact, self.peer_manager.get_kademlia_peer(self.node_ids[0], '192.168.0.1', udp_port=1000)
)
self.assertNotEqual(
self.first_contact, self.contact_manager.make_contact(
self.first_contact.id, '193.168.1.1', self.first_contact.port, None, 32
)
)
self.assertNotEqual(
self.first_contact, self.contact_manager.make_contact(
generate_id(), self.first_contact.address, self.first_contact.port, None, 32
)
)
self.assertEqual(self.second_contact, self.second_contact_second_reference)
def test_compact_ip(self):
self.assertEqual(self.first_contact.compact_ip(), b'\x7f\x00\x00\x01')
self.assertEqual(self.second_contact.compact_ip(), b'\xc0\xa8\x00\x01')
def test_id_log(self):
self.assertEqual(self.first_contact.log_id(False), hexlify(self.node_ids[1]))
self.assertEqual(self.first_contact.log_id(True), hexlify(self.node_ids[1])[:8])
@unittest.SkipTest
class TestContactLastReplied(unittest.TestCase):
def setUp(self):
self.clock = task.Clock()
@ -129,6 +102,7 @@ class TestContactLastReplied(unittest.TestCase):
self.assertIsNone(self.contact.contact_is_good)
@unittest.SkipTest
class TestContactLastRequested(unittest.TestCase):
def setUp(self):
self.clock = task.Clock()
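
The removed `test_compact_ip` expects a contact's IPv4 address as four packed bytes, and the removed `test_store` expects a stored value of the form compact IP, then a big-endian 16-bit port, then the node id. The sketch below reproduces that byte layout with the standard library only. `compact_ip` and `compact_address` here are hypothetical free functions, assumed for illustration; they are not the methods of the removed `Contact` class.

```python
import socket
import struct


def compact_ip(address: str) -> bytes:
    # Pack a dotted-quad IPv4 address into its 4-byte network-order form.
    return socket.inet_aton(address)


def compact_address(address: str, port: int, node_id: bytes) -> bytes:
    # Hypothetical helper: IP bytes + big-endian unsigned 16-bit port + node id,
    # mirroring the expected_result expression in the removed test_store.
    if not 0 < port < 65536:
        raise ValueError(f"invalid port: {port}")
    return compact_ip(address) + struct.pack('>H', port) + node_id
```

With a 48-byte node id the result is 54 bytes long: 4 for the address, 2 for the port, and 48 for the id.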


@ -1,213 +0,0 @@
from binascii import hexlify, unhexlify
from twisted.trial import unittest
from twisted.internet import defer
from lbrynet.dht import constants
from lbrynet.dht.routingtable import TreeRoutingTable
from lbrynet.dht.contact import ContactManager
from lbrynet.dht.distance import Distance
from lbrynet.utils import generate_id
class FakeRPCProtocol:
""" Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """
def sendRPC(self, *args, **kwargs):
return defer.succeed(None)
class TreeRoutingTableTest(unittest.TestCase):
""" Test case for the RoutingTable class """
def setUp(self):
self.contact_manager = ContactManager()
self.nodeID = generate_id(b'node1')
self.protocol = FakeRPCProtocol()
self.routingTable = TreeRoutingTable(self.nodeID)
def test_distance(self):
""" Test to see if distance method returns correct result"""
d = Distance(bytes((170,) * 48))
result = d(bytes((85,) * 48))
expected = int(hexlify(bytes((255,) * 48)), 16)
self.assertEqual(result, expected)
@defer.inlineCallbacks
def test_add_contact(self):
""" Tests if a contact can be added and retrieved correctly """
# Create the contact
contact_id = generate_id(b'node2')
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
# Now add it...
yield self.routingTable.addContact(contact)
# ...and request the closest nodes to it (will retrieve it)
closest_nodes = self.routingTable.findCloseNodes(contact_id)
self.assertEqual(len(closest_nodes), 1)
self.assertIn(contact, closest_nodes)
@defer.inlineCallbacks
def test_get_contact(self):
""" Tests if a specific existing contact can be retrieved correctly """
contact_id = generate_id(b'node2')
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
# Now add it...
yield self.routingTable.addContact(contact)
# ...and get it again
same_contact = self.routingTable.getContact(contact_id)
self.assertEqual(contact, same_contact, 'getContact() should return the same contact')
@defer.inlineCallbacks
def test_add_parent_node_as_contact(self):
"""
Tests the routing table's behaviour when attempting to add its parent node as a contact
"""
# Create a contact with the same ID as the local node's ID
contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol)
# Now try to add it
yield self.routingTable.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE
closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact')
@defer.inlineCallbacks
def test_remove_contact(self):
""" Tests contact removal """
# Create the contact
contact_id = generate_id(b'node2')
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
# Now add it...
yield self.routingTable.addContact(contact)
# Verify addition
self.assertEqual(len(self.routingTable._buckets[0]), 1, 'Contact not added properly')
# Now remove it
self.routingTable.removeContact(contact)
self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly')
@defer.inlineCallbacks
def test_split_bucket(self):
""" Tests if the routing table correctly dynamically splits k-buckets """
self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384,
'Initial k-bucket range should be 0 <= range < 2**384')
# Add k contacts
for i in range(constants.k):
node_id = generate_id(b'remote node %d' % i)
contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
yield self.routingTable.addContact(contact)
self.assertEqual(len(self.routingTable._buckets), 1,
'Only k nodes have been added; the first k-bucket should now '
'be full, but should not yet be split')
# Now add 1 more contact
node_id = generate_id(b'yet another remote node')
contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
yield self.routingTable.addContact(contact)
self.assertEqual(len(self.routingTable._buckets), 2,
'k+1 nodes have been added; the first k-bucket should have been '
'split into two new buckets')
self.assertNotEqual(self.routingTable._buckets[0].rangeMax, 2**384,
'K-bucket was split, but its range was not properly adjusted')
self.assertEqual(self.routingTable._buckets[1].rangeMax, 2**384,
'K-bucket was split, but the second (new) bucket\'s '
'max range was not set properly')
self.assertEqual(self.routingTable._buckets[0].rangeMax,
self.routingTable._buckets[1].rangeMin,
'K-bucket was split, but the min/max ranges were '
'not divided properly')
@defer.inlineCallbacks
def test_full_split(self):
"""
Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
"""
self.routingTable._parentNodeID = bytes(48 * b'\xff')
node_ids = [
b"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
b"200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
b"300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
b"400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
b"500000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
b"600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
b"700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
b"800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
b"ff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
b"010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
]
# Add k contacts
for nodeID in node_ids:
# self.assertEquals(nodeID, node_ids[i].decode('hex'))
contact = self.contact_manager.make_contact(unhexlify(nodeID), '127.0.0.1', 9182, self.protocol)
yield self.routingTable.addContact(contact)
self.assertEqual(len(self.routingTable._buckets), 2)
self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)
# try adding a contact who is further from us than the k'th known contact
nodeID = b'020000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'
nodeID = unhexlify(nodeID)
contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
self.assertFalse(self.routingTable._shouldSplit(self.routingTable._kbucketIndex(contact.id), contact.id))
yield self.routingTable.addContact(contact)
self.assertEqual(len(self.routingTable._buckets), 2)
self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)
self.assertNotIn(contact, self.routingTable._buckets[0]._contacts)
self.assertNotIn(contact, self.routingTable._buckets[1]._contacts)
# class KeyErrorFixedTest(unittest.TestCase):
# """ Basic test cases for boolean operators on the Contact class """
#
# def setUp(self):
# own_id = (2 ** constants.key_bits) - 1
# # carefully chosen own_id. here's the logic
# # we want a bunch of buckets (k+1, to be exact), and we want to make sure own_id
# # is not in bucket 0. so we put own_id at the end so we can keep splitting by adding to the
# # end
#
# self.table = lbrynet.dht.routingtable.OptimizedTreeRoutingTable(own_id)
#
# def fill_bucket(self, bucket_min):
# bucket_size = lbrynet.dht.constants.k
# for i in range(bucket_min, bucket_min + bucket_size):
# self.table.addContact(lbrynet.dht.contact.Contact(long(i), '127.0.0.1', 9999, None))
#
# def overflow_bucket(self, bucket_min):
# bucket_size = lbrynet.dht.constants.k
# self.fill_bucket(bucket_min)
# self.table.addContact(
# lbrynet.dht.contact.Contact(long(bucket_min + bucket_size + 1),
# '127.0.0.1', 9999, None))
#
# def testKeyError(self):
#
# # find middle, so we know where bucket will split
# bucket_middle = self.table._buckets[0].rangeMax / 2
#
# # fill last bucket
# self.fill_bucket(self.table._buckets[0].rangeMax - lbrynet.dht.constants.k - 1)
# # -1 in previous line because own_id is in last bucket
#
# # fill/overflow 7 more buckets
# bucket_start = 0
# for i in range(0, lbrynet.dht.constants.k):
# self.overflow_bucket(bucket_start)
# bucket_start += bucket_middle / (2 ** i)
#
# # replacement cache now has k-1 entries.
# # adding one more contact to bucket 0 used to cause a KeyError, but it should work
# self.table.addContact(
# lbrynet.dht.contact.Contact(long(lbrynet.dht.constants.k + 2), '127.0.0.1', 9999, None))
#
# # import math
# # print ""
# # for i, bucket in enumerate(self.table._buckets):
# # print "Bucket " + str(i) + " (2 ** " + str(
# # math.log(bucket.rangeMin, 2) if bucket.rangeMin > 0 else 0) + " <= x < 2 ** "+str(
# # math.log(bucket.rangeMax, 2)) + ")"
# # for c in bucket.getContacts():
# # print " contact " + str(c.id)
# # for key, bucket in self.table._replacementCache.items():
# # print "Replacement Cache for Bucket " + str(key)
# # for c in bucket:
# # print " contact " + str(c.id)
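
The removed `test_distance` asserts that the distance between an id of all `0xAA` bytes and an id of all `0x55` bytes is the integer with all 384 bits set, which is what an XOR metric produces. The following is a minimal sketch of that metric under this assumption. `xor_distance` is a hypothetical function, not the `Distance` class from `lbrynet.dht.distance`.

```python
def xor_distance(a: bytes, b: bytes) -> int:
    # Kademlia-style distance: read both ids as big-endian integers and XOR them.
    if len(a) != len(b):
        raise ValueError("ids must have the same length")
    return int.from_bytes(a, "big") ^ int.from_bytes(b, "big")


# Same inputs as the removed test_distance: 0xAA ^ 0x55 == 0xFF in every byte.
result = xor_distance(bytes((170,) * 48), bytes((85,) * 48))
expected = int.from_bytes(bytes((255,) * 48), "big")
```

The largest possible distance for 48-byte ids is `2**384 - 1`, which sits just below the initial k-bucket `rangeMax` of `2**384` asserted in the removed `test_split_bucket`.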