async lbrynet.dht
This commit is contained in:
parent
a5524d490c
commit
2fa5233796
49 changed files with 3086 additions and 3577 deletions
|
@ -1,4 +1,8 @@
|
|||
import hashlib
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
|
||||
backend = default_backend()
|
||||
|
||||
|
||||
def get_lbry_hash_obj():
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
Francois Aucamp <faucamp@csir.co.za>
|
||||
|
||||
Thanks goes to the following people for providing patches/suggestions/tests:
|
||||
|
||||
Neil Kleynhans <ntkleynhans@csir.co.za>
|
||||
Haiyang Ma <haiyang.ma@maidsafe.net>
|
||||
Bryan McAlister <bmcalister@csir.co.za>
|
|
@ -1,165 +0,0 @@
|
|||
GNU LESSER GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
|
||||
This version of the GNU Lesser General Public License incorporates
|
||||
the terms and conditions of version 3 of the GNU General Public
|
||||
License, supplemented by the additional permissions listed below.
|
||||
|
||||
0. Additional Definitions.
|
||||
|
||||
As used herein, "this License" refers to version 3 of the GNU Lesser
|
||||
General Public License, and the "GNU GPL" refers to version 3 of the GNU
|
||||
General Public License.
|
||||
|
||||
"The Library" refers to a covered work governed by this License,
|
||||
other than an Application or a Combined Work as defined below.
|
||||
|
||||
An "Application" is any work that makes use of an interface provided
|
||||
by the Library, but which is not otherwise based on the Library.
|
||||
Defining a subclass of a class defined by the Library is deemed a mode
|
||||
of using an interface provided by the Library.
|
||||
|
||||
A "Combined Work" is a work produced by combining or linking an
|
||||
Application with the Library. The particular version of the Library
|
||||
with which the Combined Work was made is also called the "Linked
|
||||
Version".
|
||||
|
||||
The "Minimal Corresponding Source" for a Combined Work means the
|
||||
Corresponding Source for the Combined Work, excluding any source code
|
||||
for portions of the Combined Work that, considered in isolation, are
|
||||
based on the Application, and not on the Linked Version.
|
||||
|
||||
The "Corresponding Application Code" for a Combined Work means the
|
||||
object code and/or source code for the Application, including any data
|
||||
and utility programs needed for reproducing the Combined Work from the
|
||||
Application, but excluding the System Libraries of the Combined Work.
|
||||
|
||||
1. Exception to Section 3 of the GNU GPL.
|
||||
|
||||
You may convey a covered work under sections 3 and 4 of this License
|
||||
without being bound by section 3 of the GNU GPL.
|
||||
|
||||
2. Conveying Modified Versions.
|
||||
|
||||
If you modify a copy of the Library, and, in your modifications, a
|
||||
facility refers to a function or data to be supplied by an Application
|
||||
that uses the facility (other than as an argument passed when the
|
||||
facility is invoked), then you may convey a copy of the modified
|
||||
version:
|
||||
|
||||
a) under this License, provided that you make a good faith effort to
|
||||
ensure that, in the event an Application does not supply the
|
||||
function or data, the facility still operates, and performs
|
||||
whatever part of its purpose remains meaningful, or
|
||||
|
||||
b) under the GNU GPL, with none of the additional permissions of
|
||||
this License applicable to that copy.
|
||||
|
||||
3. Object Code Incorporating Material from Library Header Files.
|
||||
|
||||
The object code form of an Application may incorporate material from
|
||||
a header file that is part of the Library. You may convey such object
|
||||
code under terms of your choice, provided that, if the incorporated
|
||||
material is not limited to numerical parameters, data structure
|
||||
layouts and accessors, or small macros, inline functions and templates
|
||||
(ten or fewer lines in length), you do both of the following:
|
||||
|
||||
a) Give prominent notice with each copy of the object code that the
|
||||
Library is used in it and that the Library and its use are
|
||||
covered by this License.
|
||||
|
||||
b) Accompany the object code with a copy of the GNU GPL and this license
|
||||
document.
|
||||
|
||||
4. Combined Works.
|
||||
|
||||
You may convey a Combined Work under terms of your choice that,
|
||||
taken together, effectively do not restrict modification of the
|
||||
portions of the Library contained in the Combined Work and reverse
|
||||
engineering for debugging such modifications, if you also do each of
|
||||
the following:
|
||||
|
||||
a) Give prominent notice with each copy of the Combined Work that
|
||||
the Library is used in it and that the Library and its use are
|
||||
covered by this License.
|
||||
|
||||
b) Accompany the Combined Work with a copy of the GNU GPL and this license
|
||||
document.
|
||||
|
||||
c) For a Combined Work that displays copyright notices during
|
||||
execution, include the copyright notice for the Library among
|
||||
these notices, as well as a reference directing the user to the
|
||||
copies of the GNU GPL and this license document.
|
||||
|
||||
d) Do one of the following:
|
||||
|
||||
0) Convey the Minimal Corresponding Source under the terms of this
|
||||
License, and the Corresponding Application Code in a form
|
||||
suitable for, and under terms that permit, the user to
|
||||
recombine or relink the Application with a modified version of
|
||||
the Linked Version to produce a modified Combined Work, in the
|
||||
manner specified by section 6 of the GNU GPL for conveying
|
||||
Corresponding Source.
|
||||
|
||||
1) Use a suitable shared library mechanism for linking with the
|
||||
Library. A suitable mechanism is one that (a) uses at run time
|
||||
a copy of the Library already present on the user's computer
|
||||
system, and (b) will operate properly with a modified version
|
||||
of the Library that is interface-compatible with the Linked
|
||||
Version.
|
||||
|
||||
e) Provide Installation Information, but only if you would otherwise
|
||||
be required to provide such information under section 6 of the
|
||||
GNU GPL, and only to the extent that such information is
|
||||
necessary to install and execute a modified version of the
|
||||
Combined Work produced by recombining or relinking the
|
||||
Application with a modified version of the Linked Version. (If
|
||||
you use option 4d0, the Installation Information must accompany
|
||||
the Minimal Corresponding Source and Corresponding Application
|
||||
Code. If you use option 4d1, you must provide the Installation
|
||||
Information in the manner specified by section 6 of the GNU GPL
|
||||
for conveying Corresponding Source.)
|
||||
|
||||
5. Combined Libraries.
|
||||
|
||||
You may place library facilities that are a work based on the
|
||||
Library side by side in a single library together with other library
|
||||
facilities that are not Applications and are not covered by this
|
||||
License, and convey such a combined library under terms of your
|
||||
choice, if you do both of the following:
|
||||
|
||||
a) Accompany the combined library with a copy of the same work based
|
||||
on the Library, uncombined with any other library facilities,
|
||||
conveyed under the terms of this License.
|
||||
|
||||
b) Give prominent notice with the combined library that part of it
|
||||
is a work based on the Library, and explaining where to find the
|
||||
accompanying uncombined form of the same work.
|
||||
|
||||
6. Revised Versions of the GNU Lesser General Public License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions
|
||||
of the GNU Lesser General Public License from time to time. Such new
|
||||
versions will be similar in spirit to the present version, but may
|
||||
differ in detail to address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Library as you received it specifies that a certain numbered version
|
||||
of the GNU Lesser General Public License "or any later version"
|
||||
applies to it, you have the option of following the terms and
|
||||
conditions either of that published version or of any later version
|
||||
published by the Free Software Foundation. If the Library as you
|
||||
received it does not specify a version number of the GNU Lesser
|
||||
General Public License, you may choose any version of the GNU Lesser
|
||||
General Public License ever published by the Free Software Foundation.
|
||||
|
||||
If the Library as you received it specifies that a proxy can decide
|
||||
whether future versions of the GNU Lesser General Public License shall
|
||||
apply, that proxy's public statement of acceptance of any version is
|
||||
permanent authorization for you to choose that version for the
|
||||
Library.
|
65
lbrynet/dht/blob_announcer.py
Normal file
65
lbrynet/dht/blob_announcer.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
import asyncio
|
||||
import typing
|
||||
import logging
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.dht.node import Node
|
||||
from lbrynet.extras.daemon.storage import SQLiteStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BlobAnnouncer:
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, node: 'Node', storage: 'SQLiteStorage'):
|
||||
self.loop = loop
|
||||
self.node = node
|
||||
self.storage = storage
|
||||
self.pending_call: asyncio.Handle = None
|
||||
self.announce_task: asyncio.Task = None
|
||||
self.running = False
|
||||
self.announce_queue: typing.List[str] = []
|
||||
|
||||
async def _announce(self, batch_size: typing.Optional[int] = 10):
|
||||
if not self.node.joined.is_set():
|
||||
await self.node.joined.wait()
|
||||
blob_hashes = await self.storage.get_blobs_to_announce()
|
||||
if blob_hashes:
|
||||
self.announce_queue.extend(blob_hashes)
|
||||
log.info("%i blobs to announce", len(blob_hashes))
|
||||
batch = []
|
||||
while len(self.announce_queue):
|
||||
cnt = 0
|
||||
announced = []
|
||||
while self.announce_queue and cnt < batch_size:
|
||||
blob_hash = self.announce_queue.pop()
|
||||
announced.append(blob_hash)
|
||||
batch.append(self.node.announce_blob(blob_hash))
|
||||
cnt += 1
|
||||
to_await = []
|
||||
while batch:
|
||||
to_await.append(batch.pop())
|
||||
if to_await:
|
||||
await asyncio.gather(*tuple(to_await), loop=self.loop)
|
||||
await self.storage.update_last_announced_blobs(announced, self.loop.time())
|
||||
log.info("announced %i blobs", len(announced))
|
||||
if self.running:
|
||||
self.pending_call = self.loop.call_later(60, self.announce, batch_size)
|
||||
|
||||
def announce(self, batch_size: typing.Optional[int] = 10):
|
||||
self.announce_task = self.loop.create_task(self._announce(batch_size))
|
||||
|
||||
def start(self, batch_size: typing.Optional[int] = 10):
|
||||
if self.running:
|
||||
raise Exception("already running")
|
||||
self.running = True
|
||||
self.announce(batch_size)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
if self.pending_call:
|
||||
if not self.pending_call.cancelled():
|
||||
self.pending_call.cancel()
|
||||
self.pending_call = None
|
||||
if self.announce_task:
|
||||
if not (self.announce_task.done() or self.announce_task.cancelled()):
|
||||
self.announce_task.cancel()
|
||||
self.announce_task = None
|
|
@ -1,81 +0,0 @@
|
|||
import logging
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
MIN_DELAY = 0.0
|
||||
MAX_DELAY = 0.01
|
||||
DELAY_INCREMENT = 0.0001
|
||||
QUEUE_SIZE_THRESHOLD = 100
|
||||
|
||||
|
||||
class CallLaterManager:
|
||||
def __init__(self, callLater):
|
||||
"""
|
||||
:param callLater: (IReactorTime.callLater)
|
||||
"""
|
||||
|
||||
self._callLater = callLater
|
||||
self._pendingCallLaters = []
|
||||
self._delay = MIN_DELAY
|
||||
|
||||
def get_min_delay(self):
|
||||
self._pendingCallLaters = [cl for cl in self._pendingCallLaters if cl.active()]
|
||||
queue_size = len(self._pendingCallLaters)
|
||||
if queue_size > QUEUE_SIZE_THRESHOLD:
|
||||
self._delay = min((self._delay + DELAY_INCREMENT), MAX_DELAY)
|
||||
else:
|
||||
self._delay = max((self._delay - 2.0 * DELAY_INCREMENT), MIN_DELAY)
|
||||
return self._delay
|
||||
|
||||
def _cancel(self, call_later):
|
||||
"""
|
||||
:param call_later: DelayedCall
|
||||
:return: (callable) canceller function
|
||||
"""
|
||||
|
||||
def cancel(reason=None):
|
||||
"""
|
||||
:param reason: reason for cancellation, this is returned after cancelling the DelayedCall
|
||||
:return: reason
|
||||
"""
|
||||
|
||||
if call_later.active():
|
||||
call_later.cancel()
|
||||
if call_later in self._pendingCallLaters:
|
||||
self._pendingCallLaters.remove(call_later)
|
||||
return reason
|
||||
return cancel
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Cancel any callLaters that are still running
|
||||
"""
|
||||
|
||||
from twisted.internet import defer
|
||||
while self._pendingCallLaters:
|
||||
canceller = self._cancel(self._pendingCallLaters[0])
|
||||
try:
|
||||
canceller()
|
||||
except (defer.CancelledError, defer.AlreadyCalledError, ValueError):
|
||||
pass
|
||||
|
||||
def call_later(self, when, what, *args, **kwargs):
|
||||
"""
|
||||
Schedule a call later and get a canceller callback function
|
||||
|
||||
:param when: (float) delay in seconds
|
||||
:param what: (callable)
|
||||
:param args: (*tuple) args to be passed to the callable
|
||||
:param kwargs: (**dict) kwargs to be passed to the callable
|
||||
|
||||
:return: (tuple) twisted.internet.base.DelayedCall object, canceller function
|
||||
"""
|
||||
|
||||
call_later = self._callLater(when, what, *args, **kwargs)
|
||||
canceller = self._cancel(call_later)
|
||||
self._pendingCallLaters.append(call_later)
|
||||
return call_later, canceller
|
||||
|
||||
def call_soon(self, what, *args, **kwargs):
|
||||
delay = self.get_min_delay()
|
||||
return self.call_later(delay, what, *args, **kwargs)
|
|
@ -1,61 +1,40 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# This library is free software, distributed under the terms of
|
||||
# the GNU Lesser General Public License Version 3, or any later version.
|
||||
# See the COPYING file included in this archive
|
||||
#
|
||||
# The docstrings in this module contain epytext markup; API documentation
|
||||
# may be created by processing this file with epydoc: http://epydoc.sf.net
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
""" This module defines the charaterizing constants of the Kademlia network
|
||||
|
||||
C{checkRefreshInterval} and C{udpDatagramMaxSize} are implementation-specific
|
||||
constants, and do not affect general Kademlia operation.
|
||||
"""
|
||||
|
||||
######### KADEMLIA CONSTANTS ###########
|
||||
|
||||
#: Small number Representing the degree of parallelism in network calls
|
||||
alpha = 3
|
||||
|
||||
#: Maximum number of contacts stored in a bucket; this should be an even number
|
||||
hash_class = hashlib.sha384
|
||||
hash_length = hash_class().digest_size
|
||||
hash_bits = hash_length * 8
|
||||
alpha = 5
|
||||
k = 8
|
||||
|
||||
#: Maximum number of contacts stored in the replacement cache
|
||||
replacementCacheSize = 8
|
||||
|
||||
#: Timeout for network operations (in seconds)
|
||||
rpcTimeout = 5
|
||||
|
||||
# number of rpc attempts to make before a timeout results in the node being removed as a contact
|
||||
rpcAttempts = 5
|
||||
# time window to count failures (in seconds)
|
||||
rpcAttemptsPruningTimeWindow = 600
|
||||
|
||||
# Delay between iterations of iterative node lookups (for loose parallelism) (in seconds)
|
||||
iterativeLookupDelay = rpcTimeout / 2
|
||||
|
||||
#: If a k-bucket has not been used for this amount of time, refresh it (in seconds)
|
||||
refreshTimeout = 3600 # 1 hour
|
||||
#: The interval at which nodes replicate (republish/refresh) data they are holding
|
||||
replicateInterval = refreshTimeout
|
||||
# The time it takes for data to expire in the network; the original publisher of the data
|
||||
# will also republish the data at this time if it is still valid
|
||||
dataExpireTimeout = 86400 # 24 hours
|
||||
|
||||
tokenSecretChangeInterval = 300 # 5 minutes
|
||||
|
||||
######## IMPLEMENTATION-SPECIFIC CONSTANTS ###########
|
||||
|
||||
#: The interval for the node to check whether any buckets need refreshing
|
||||
checkRefreshInterval = refreshTimeout / 5
|
||||
|
||||
#: Max size of a single UDP datagram, in bytes. If a message is larger than this, it will
|
||||
#: be spread across several UDP packets.
|
||||
udpDatagramMaxSize = 8192 # 8 KB
|
||||
|
||||
key_bits = 384
|
||||
|
||||
replacement_cache_size = 8
|
||||
rpc_timeout = 5.0
|
||||
rpc_attempts = 5
|
||||
rpc_attempts_pruning_window = 600
|
||||
iterative_lookup_delay = rpc_timeout / 2.0
|
||||
refresh_interval = 3600 # 1 hour
|
||||
replicate_interval = refresh_interval
|
||||
data_expiration = 86400 # 24 hours
|
||||
token_secret_refresh_interval = 300 # 5 minutes
|
||||
check_refresh_interval = refresh_interval / 5
|
||||
max_datagram_size = 8192 # 8 KB
|
||||
rpc_id_length = 20
|
||||
protocol_version = 1
|
||||
bottom_out_limit = 3
|
||||
msg_size_limit = max_datagram_size - 26
|
||||
|
||||
protocolVersion = 1
|
||||
|
||||
def digest(data: bytes) -> bytes:
|
||||
h = hash_class()
|
||||
h.update(data)
|
||||
return h.digest()
|
||||
|
||||
|
||||
def generate_id(num=None) -> bytes:
|
||||
if num is not None:
|
||||
return digest(str(num).encode())
|
||||
else:
|
||||
return digest(os.urandom(32))
|
||||
|
||||
|
||||
def generate_rpc_id(num=None) -> bytes:
|
||||
return generate_id(num)[:rpc_id_length]
|
||||
|
|
|
@ -1,189 +0,0 @@
|
|||
import ipaddress
|
||||
from binascii import hexlify
|
||||
from functools import reduce
|
||||
from lbrynet.dht import constants
|
||||
|
||||
|
||||
def is_valid_ipv4(address):
|
||||
try:
|
||||
ip = ipaddress.ip_address(address)
|
||||
return ip.version == 4
|
||||
except ipaddress.AddressValueError:
|
||||
return False
|
||||
|
||||
|
||||
class _Contact:
|
||||
""" Encapsulation for remote contact
|
||||
|
||||
This class contains information on a single remote contact, and also
|
||||
provides a direct RPC API to the remote node which it represents
|
||||
"""
|
||||
|
||||
def __init__(self, contactManager, id, ipAddress, udpPort, networkProtocol, firstComm):
|
||||
if id is not None:
|
||||
if not len(id) == constants.key_bits // 8:
|
||||
raise ValueError("invalid node id: {}".format(hexlify(id).decode()))
|
||||
if not 0 <= udpPort <= 65536:
|
||||
raise ValueError("invalid port")
|
||||
if not is_valid_ipv4(ipAddress):
|
||||
raise ValueError("invalid ip address")
|
||||
self._contactManager = contactManager
|
||||
self._id = id
|
||||
self.address = ipAddress
|
||||
self.port = udpPort
|
||||
self._networkProtocol = networkProtocol
|
||||
self.commTime = firstComm
|
||||
self.getTime = self._contactManager._get_time
|
||||
self.lastReplied = None
|
||||
self.lastRequested = None
|
||||
self.protocolVersion = 0
|
||||
self._token = (None, 0) # token, timestamp
|
||||
|
||||
def update_token(self, token):
|
||||
self._token = token, self.getTime() if token else 0
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
# expire the token 1 minute early to be safe
|
||||
return self._token[0] if self._token[1] + 240 > self.getTime() else None
|
||||
|
||||
@property
|
||||
def lastInteracted(self):
|
||||
return max(self.lastRequested or 0, self.lastReplied or 0, self.lastFailed or 0)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self._id
|
||||
|
||||
def log_id(self, short=True):
|
||||
if not self.id:
|
||||
return "not initialized"
|
||||
id_hex = hexlify(self.id)
|
||||
return id_hex if not short else id_hex[:8]
|
||||
|
||||
@property
|
||||
def failedRPCs(self):
|
||||
return len(self.failures)
|
||||
|
||||
@property
|
||||
def lastFailed(self):
|
||||
return (self.failures or [None])[-1]
|
||||
|
||||
@property
|
||||
def failures(self):
|
||||
return self._contactManager._rpc_failures.get((self.address, self.port), [])
|
||||
|
||||
@property
|
||||
def contact_is_good(self):
|
||||
"""
|
||||
:return: False if contact is bad, None if contact is unknown, or True if contact is good
|
||||
"""
|
||||
failures = self.failures
|
||||
now = self.getTime()
|
||||
delay = constants.checkRefreshInterval
|
||||
|
||||
if failures:
|
||||
if self.lastReplied and len(failures) >= 2 and self.lastReplied < failures[-2]:
|
||||
return False
|
||||
elif self.lastReplied and len(failures) >= 2 and self.lastReplied > failures[-2]:
|
||||
pass # handled below
|
||||
elif len(failures) >= 2:
|
||||
return False
|
||||
|
||||
if self.lastReplied and self.lastReplied > now - delay:
|
||||
return True
|
||||
if self.lastReplied and self.lastRequested and self.lastRequested > now - delay:
|
||||
return True
|
||||
return None
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, _Contact):
|
||||
raise TypeError("invalid type to compare with Contact: %s" % str(type(other)))
|
||||
return (self.id, self.address, self.port) == (other.id, other.address, other.port)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.id, self.address, self.port))
|
||||
|
||||
def compact_ip(self):
|
||||
compact_ip = reduce(
|
||||
lambda buff, x: buff + bytearray([int(x)]), self.address.split('.'), bytearray())
|
||||
return compact_ip
|
||||
|
||||
def set_id(self, id):
|
||||
if not self._id:
|
||||
self._id = id
|
||||
|
||||
def update_last_replied(self):
|
||||
self.lastReplied = int(self.getTime())
|
||||
|
||||
def update_last_requested(self):
|
||||
self.lastRequested = int(self.getTime())
|
||||
|
||||
def update_last_failed(self):
|
||||
failures = self._contactManager._rpc_failures.get((self.address, self.port), [])
|
||||
failures.append(self.getTime())
|
||||
self._contactManager._rpc_failures[(self.address, self.port)] = failures
|
||||
|
||||
def update_protocol_version(self, version):
|
||||
self.protocolVersion = version
|
||||
|
||||
def __str__(self):
|
||||
return '<%s.%s object; IP address: %s, UDP port: %d>' % (
|
||||
self.__module__, self.__class__.__name__, self.address, self.port)
|
||||
|
||||
def __getattr__(self, name):
|
||||
""" This override allows the host node to call a method of the remote
|
||||
node (i.e. this contact) as if it was a local function.
|
||||
|
||||
For instance, if C{remoteNode} is a instance of C{Contact}, the
|
||||
following will result in C{remoteNode}'s C{test()} method to be
|
||||
called with argument C{123}::
|
||||
remoteNode.test(123)
|
||||
|
||||
Such a RPC method call will return a Deferred, which will callback
|
||||
when the contact responds with the result (or an error occurs).
|
||||
This happens via this contact's C{_networkProtocol} object (i.e. the
|
||||
host Node's C{_protocol} object).
|
||||
"""
|
||||
|
||||
if name not in ['ping', 'findValue', 'findNode', 'store']:
|
||||
raise AttributeError("unknown command: %s" % name)
|
||||
|
||||
def _sendRPC(*args, **kwargs):
|
||||
return self._networkProtocol.sendRPC(self, name.encode(), args)
|
||||
|
||||
return _sendRPC
|
||||
|
||||
|
||||
class ContactManager:
|
||||
def __init__(self, get_time=None):
|
||||
if not get_time:
|
||||
from twisted.internet import reactor
|
||||
get_time = reactor.seconds
|
||||
self._get_time = get_time
|
||||
self._contacts = {}
|
||||
self._rpc_failures = {}
|
||||
|
||||
def get_contact(self, id, address, port):
|
||||
for contact in self._contacts.values():
|
||||
if contact.id == id and contact.address == address and contact.port == port:
|
||||
return contact
|
||||
|
||||
def make_contact(self, id, ipAddress, udpPort, networkProtocol, firstComm=0):
|
||||
contact = self.get_contact(id, ipAddress, udpPort)
|
||||
if contact:
|
||||
return contact
|
||||
contact = _Contact(self, id, ipAddress, udpPort, networkProtocol, firstComm or self._get_time())
|
||||
self._contacts[(id, ipAddress, udpPort)] = contact
|
||||
return contact
|
||||
|
||||
def is_ignored(self, origin_tuple):
|
||||
failed_rpc_count = len(self._prune_failures(origin_tuple))
|
||||
return failed_rpc_count > constants.rpcAttempts
|
||||
|
||||
def _prune_failures(self, origin_tuple):
|
||||
# Prunes recorded failures to the last time window of attempts
|
||||
pruning_limit = self._get_time() - constants.rpcAttemptsPruningTimeWindow
|
||||
pruned = list(filter(lambda t: t >= pruning_limit, self._rpc_failures.get(origin_tuple, [])))
|
||||
self._rpc_failures[origin_tuple] = pruned
|
||||
return pruned
|
|
@ -1,66 +0,0 @@
|
|||
from collections import UserDict
|
||||
from lbrynet.dht import constants
|
||||
|
||||
|
||||
class DictDataStore(UserDict):
|
||||
""" A datastore using an in-memory Python dictionary """
|
||||
#implements(IDataStore)
|
||||
|
||||
def __init__(self, getTime=None):
|
||||
# Dictionary format:
|
||||
# { <key>: (<contact>, <value>, <lastPublished>, <originallyPublished> <originalPublisherID>) }
|
||||
super().__init__()
|
||||
if not getTime:
|
||||
from twisted.internet import reactor
|
||||
getTime = reactor.seconds
|
||||
self._getTime = getTime
|
||||
self.completed_blobs = set()
|
||||
|
||||
def filter_bad_and_expired_peers(self, key):
|
||||
"""
|
||||
Returns only non-expired and unknown/good peers
|
||||
"""
|
||||
return filter(
|
||||
lambda peer:
|
||||
self._getTime() - peer[3] < constants.dataExpireTimeout and peer[0].contact_is_good is not False,
|
||||
self[key]
|
||||
)
|
||||
|
||||
def filter_expired_peers(self, key):
|
||||
"""
|
||||
Returns only non-expired peers
|
||||
"""
|
||||
return filter(lambda peer: self._getTime() - peer[3] < constants.dataExpireTimeout, self[key])
|
||||
|
||||
def removeExpiredPeers(self):
|
||||
expired_keys = []
|
||||
for key in self.keys():
|
||||
unexpired_peers = list(self.filter_expired_peers(key))
|
||||
if not unexpired_peers:
|
||||
expired_keys.append(key)
|
||||
else:
|
||||
self[key] = unexpired_peers
|
||||
for key in expired_keys:
|
||||
del self[key]
|
||||
|
||||
def hasPeersForBlob(self, key):
|
||||
return bool(key in self and len(tuple(self.filter_bad_and_expired_peers(key))))
|
||||
|
||||
def addPeerToBlob(self, contact, key, compact_address, lastPublished, originallyPublished, originalPublisherID):
|
||||
if key in self:
|
||||
if compact_address not in map(lambda store_tuple: store_tuple[1], self[key]):
|
||||
self[key].append(
|
||||
(contact, compact_address, lastPublished, originallyPublished, originalPublisherID)
|
||||
)
|
||||
else:
|
||||
self[key] = [(contact, compact_address, lastPublished, originallyPublished, originalPublisherID)]
|
||||
|
||||
def getPeersForBlob(self, key):
|
||||
return [] if key not in self else [val[1] for val in self.filter_bad_and_expired_peers(key)]
|
||||
|
||||
def getStoringContacts(self):
|
||||
contacts = set()
|
||||
for key in self:
|
||||
for values in self[key]:
|
||||
contacts.add(values[0])
|
||||
return list(contacts)
|
|
@ -1,43 +1,23 @@
|
|||
import binascii
|
||||
#import exceptions
|
||||
|
||||
# this is a dict of {"exceptions.<exception class name>": exception class} items used to raise
|
||||
# remote built-in exceptions locally
|
||||
BUILTIN_EXCEPTIONS = {
|
||||
# "exceptions.%s" % e: getattr(exceptions, e) for e in dir(exceptions) if not e.startswith("_")
|
||||
}
|
||||
class BaseKademliaException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DecodeError(Exception):
|
||||
class DecodeError(BaseKademliaException):
|
||||
"""
|
||||
Should be raised by an C{Encoding} implementation if decode operation
|
||||
fails
|
||||
"""
|
||||
|
||||
|
||||
class BucketFull(Exception):
|
||||
class BucketFull(BaseKademliaException):
|
||||
"""
|
||||
Raised when the bucket is full
|
||||
"""
|
||||
|
||||
|
||||
class UnknownRemoteException(Exception):
|
||||
class RemoteException(BaseKademliaException):
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
""" Raised when a RPC times out """
|
||||
|
||||
def __init__(self, remote_contact_id):
|
||||
# remote_contact_id is a binary blob so we need to convert it
|
||||
# into something more readable
|
||||
if remote_contact_id:
|
||||
msg = 'Timeout connecting to {}'.format(binascii.hexlify(remote_contact_id))
|
||||
else:
|
||||
msg = 'Timeout connecting to uninitialized node'
|
||||
super().__init__(msg)
|
||||
self.remote_contact_id = remote_contact_id
|
||||
|
||||
|
||||
class TransportNotConnected(Exception):
|
||||
class TransportNotConnected(BaseKademliaException):
|
||||
pass
|
||||
|
|
|
@ -1,116 +0,0 @@
|
|||
from zope.interface import Interface
|
||||
|
||||
|
||||
class IDataStore(Interface):
|
||||
""" Interface for classes implementing physical storage (for data
|
||||
published via the "STORE" RPC) for the Kademlia DHT
|
||||
|
||||
@note: This provides an interface for a dict-like object
|
||||
"""
|
||||
|
||||
def keys(self):
|
||||
""" Return a list of the keys in this data store """
|
||||
|
||||
def removeExpiredPeers(self):
|
||||
pass
|
||||
|
||||
def hasPeersForBlob(self, key):
|
||||
pass
|
||||
|
||||
def addPeerToBlob(self, key, value, lastPublished, originallyPublished, originalPublisherID):
|
||||
pass
|
||||
|
||||
def getPeersForBlob(self, key):
|
||||
pass
|
||||
|
||||
def removePeer(self, key):
|
||||
pass
|
||||
|
||||
|
||||
class IRoutingTable(Interface):
|
||||
""" Interface for RPC message translators/formatters
|
||||
|
||||
Classes inheriting from this should provide a suitable routing table for
|
||||
a parent Node object (i.e. the local entity in the Kademlia network)
|
||||
"""
|
||||
|
||||
def __init__(self, parentNodeID):
|
||||
"""
|
||||
@param parentNodeID: The n-bit node ID of the node to which this
|
||||
routing table belongs
|
||||
@type parentNodeID: str
|
||||
"""
|
||||
|
||||
def addContact(self, contact):
|
||||
""" Add the given contact to the correct k-bucket; if it already
|
||||
exists, its status will be updated
|
||||
|
||||
@param contact: The contact to add to this node's k-buckets
|
||||
@type contact: kademlia.contact.Contact
|
||||
"""
|
||||
|
||||
def findCloseNodes(self, key, count, _rpcNodeID=None):
|
||||
""" Finds a number of known nodes closest to the node/value with the
|
||||
specified key.
|
||||
|
||||
@param key: the n-bit key (i.e. the node or value ID) to search for
|
||||
@type key: str
|
||||
@param count: the amount of contacts to return
|
||||
@type count: int
|
||||
@param _rpcNodeID: Used during RPC, this is be the sender's Node ID
|
||||
Whatever ID is passed in the parameter will get
|
||||
excluded from the list of returned contacts.
|
||||
@type _rpcNodeID: str
|
||||
|
||||
@return: A list of node contacts (C{kademlia.contact.Contact instances})
|
||||
closest to the specified key.
|
||||
This method will return C{k} (or C{count}, if specified)
|
||||
contacts if at all possible; it will only return fewer if the
|
||||
node is returning all of the contacts that it knows of.
|
||||
@rtype: list
|
||||
"""
|
||||
|
||||
def getContact(self, contactID):
|
||||
""" Returns the (known) contact with the specified node ID
|
||||
|
||||
@raise ValueError: No contact with the specified contact ID is known
|
||||
by this node
|
||||
"""
|
||||
|
||||
def getRefreshList(self, startIndex=0, force=False):
|
||||
""" Finds all k-buckets that need refreshing, starting at the
|
||||
k-bucket with the specified index, and returns IDs to be searched for
|
||||
in order to refresh those k-buckets
|
||||
|
||||
@param startIndex: The index of the bucket to start refreshing at;
|
||||
this bucket and those further away from it will
|
||||
be refreshed. For example, when joining the
|
||||
network, this node will set this to the index of
|
||||
the bucket after the one containing it's closest
|
||||
neighbour.
|
||||
@type startIndex: index
|
||||
@param force: If this is C{True}, all buckets (in the specified range)
|
||||
will be refreshed, regardless of the time they were last
|
||||
accessed.
|
||||
@type force: bool
|
||||
|
||||
@return: A list of node ID's that the parent node should search for
|
||||
in order to refresh the routing Table
|
||||
@rtype: list
|
||||
"""
|
||||
|
||||
def removeContact(self, contactID):
|
||||
""" Remove the contact with the specified node ID from the routing
|
||||
table
|
||||
|
||||
@param contactID: The node ID of the contact to remove
|
||||
@type contactID: str
|
||||
"""
|
||||
|
||||
def touchKBucket(self, key):
|
||||
""" Update the "last accessed" timestamp of the k-bucket which covers
|
||||
the range containing the specified key in the key/ID space
|
||||
|
||||
@param key: A key in the range of the target k-bucket
|
||||
@type key: str
|
||||
"""
|
|
@ -1,230 +0,0 @@
|
|||
import logging
|
||||
from twisted.internet import defer
|
||||
from lbrynet.dht.distance import Distance
|
||||
from lbrynet.dht.error import TimeoutError
|
||||
from lbrynet.dht import constants
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_contact(contact_list, node_id, address, port):
|
||||
for contact in contact_list:
|
||||
if contact.id == node_id and contact.address == address and contact.port == port:
|
||||
return contact
|
||||
raise IndexError(node_id)
|
||||
|
||||
|
||||
def expand_peer(compact_peer_info):
|
||||
host = "{}.{}.{}.{}".format(*compact_peer_info[:4])
|
||||
port = int.from_bytes(compact_peer_info[4:6], 'big')
|
||||
peer_node_id = compact_peer_info[6:]
|
||||
return (peer_node_id, host, port)
|
||||
|
||||
|
||||
class _IterativeFind:
|
||||
# TODO: use polymorphism to search for a value or node
|
||||
# instead of using a find_value flag
|
||||
def __init__(self, node, shortlist, key, rpc, exclude=None):
|
||||
self.exclude = set(exclude or [])
|
||||
self.node = node
|
||||
self.finished_deferred = defer.Deferred()
|
||||
# all distance operations in this class only care about the distance
|
||||
# to self.key, so this makes it easier to calculate those
|
||||
self.distance = Distance(key)
|
||||
# The closest known and active node yet found
|
||||
self.closest_node = None if not shortlist else shortlist[0]
|
||||
self.prev_closest_node = None
|
||||
# Shortlist of contact objects (the k closest known contacts to the key from the routing table)
|
||||
self.shortlist = shortlist
|
||||
# The search key
|
||||
self.key = key
|
||||
# The rpc method name (findValue or findNode)
|
||||
self.rpc = rpc
|
||||
# List of active queries; len() indicates number of active probes
|
||||
self.active_probes = []
|
||||
# List of contact (address, port) tuples that have already been queried, includes contacts that didn't reply
|
||||
self.already_contacted = []
|
||||
# A list of found and known-to-be-active remote nodes (Contact objects)
|
||||
self.active_contacts = []
|
||||
# Ensure only one searchIteration call is running at a time
|
||||
self._search_iteration_semaphore = defer.DeferredSemaphore(1)
|
||||
self._iteration_count = 0
|
||||
self.find_value_result = {}
|
||||
self.pending_iteration_calls = []
|
||||
|
||||
@property
|
||||
def is_find_node_request(self):
|
||||
return self.rpc == "findNode"
|
||||
|
||||
@property
|
||||
def is_find_value_request(self):
|
||||
return self.rpc == "findValue"
|
||||
|
||||
def is_closer(self, contact):
|
||||
if not self.closest_node:
|
||||
return True
|
||||
return self.distance.is_closer(contact.id, self.closest_node.id)
|
||||
|
||||
def getContactTriples(self, result):
|
||||
if self.is_find_value_request:
|
||||
contact_triples = result[b'contacts']
|
||||
else:
|
||||
contact_triples = result
|
||||
for contact_tup in contact_triples:
|
||||
if not isinstance(contact_tup, (list, tuple)) or len(contact_tup) != 3:
|
||||
raise ValueError("invalid contact triple")
|
||||
contact_tup[1] = contact_tup[1].decode() # ips are strings
|
||||
return contact_triples
|
||||
|
||||
def sortByDistance(self, contact_list):
|
||||
"""Sort the list of contacts in order by distance from key"""
|
||||
contact_list.sort(key=lambda c: self.distance(c.id))
|
||||
|
||||
def extendShortlist(self, contact, result):
|
||||
# The "raw response" tuple contains the response message and the originating address info
|
||||
originAddress = (contact.address, contact.port)
|
||||
if self.finished_deferred.called:
|
||||
return contact.id
|
||||
if self.node.contact_manager.is_ignored(originAddress):
|
||||
raise ValueError("contact is ignored")
|
||||
if contact.id == self.node.node_id:
|
||||
return contact.id
|
||||
|
||||
if contact not in self.active_contacts:
|
||||
self.active_contacts.append(contact)
|
||||
if contact not in self.shortlist:
|
||||
self.shortlist.append(contact)
|
||||
|
||||
# Now grow extend the (unverified) shortlist with the returned contacts
|
||||
# TODO: some validation on the result (for guarding against attacks)
|
||||
# If we are looking for a value, first see if this result is the value
|
||||
# we are looking for before treating it as a list of contact triples
|
||||
if self.is_find_value_request and self.key in result:
|
||||
# We have found the value
|
||||
for peer in result[self.key]:
|
||||
node_id, host, port = expand_peer(peer)
|
||||
if (host, port) not in self.exclude:
|
||||
self.find_value_result.setdefault(self.key, []).append((node_id, host, port))
|
||||
if self.find_value_result:
|
||||
self.finished_deferred.callback(self.find_value_result)
|
||||
else:
|
||||
if self.is_find_value_request:
|
||||
# We are looking for a value, and the remote node didn't have it
|
||||
# - mark it as the closest "empty" node, if it is
|
||||
# TODO: store to this peer after finding the value as per the kademlia spec
|
||||
if b'closestNodeNoValue' in self.find_value_result:
|
||||
if self.is_closer(contact):
|
||||
self.find_value_result[b'closestNodeNoValue'] = contact
|
||||
else:
|
||||
self.find_value_result[b'closestNodeNoValue'] = contact
|
||||
contactTriples = self.getContactTriples(result)
|
||||
for contactTriple in contactTriples:
|
||||
if (contactTriple[1], contactTriple[2]) in ((c.address, c.port) for c in self.already_contacted):
|
||||
continue
|
||||
elif self.node.contact_manager.is_ignored((contactTriple[1], contactTriple[2])):
|
||||
continue
|
||||
else:
|
||||
found_contact = self.node.contact_manager.make_contact(contactTriple[0], contactTriple[1],
|
||||
contactTriple[2], self.node._protocol)
|
||||
if found_contact not in self.shortlist:
|
||||
self.shortlist.append(found_contact)
|
||||
|
||||
if not self.finished_deferred.called and self.should_stop():
|
||||
self.sortByDistance(self.active_contacts)
|
||||
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
|
||||
|
||||
return contact.id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def probeContact(self, contact):
|
||||
fn = getattr(contact, self.rpc)
|
||||
try:
|
||||
response = yield fn(self.key)
|
||||
result = self.extendShortlist(contact, response)
|
||||
defer.returnValue(result)
|
||||
except (TimeoutError, defer.CancelledError, ValueError, IndexError):
|
||||
defer.returnValue(contact.id)
|
||||
|
||||
def should_stop(self):
|
||||
if self.is_find_value_request:
|
||||
# search stops when it finds a value, let it run
|
||||
return False
|
||||
if self.prev_closest_node and self.closest_node and self.distance.is_closer(self.prev_closest_node.id,
|
||||
self.closest_node.id):
|
||||
# we're getting further away
|
||||
return True
|
||||
if len(self.active_contacts) >= constants.k:
|
||||
# we have enough results
|
||||
return True
|
||||
return False
|
||||
|
||||
# Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts
|
||||
def _searchIteration(self):
|
||||
# Sort the discovered active nodes from closest to furthest
|
||||
if len(self.active_contacts):
|
||||
self.sortByDistance(self.active_contacts)
|
||||
self.prev_closest_node = self.closest_node
|
||||
self.closest_node = self.active_contacts[0]
|
||||
|
||||
# Sort the current shortList before contacting other nodes
|
||||
self.sortByDistance(self.shortlist)
|
||||
probes = []
|
||||
already_contacted_addresses = {(c.address, c.port) for c in self.already_contacted}
|
||||
to_remove = []
|
||||
for contact in self.shortlist:
|
||||
if self.node.contact_manager.is_ignored((contact.address, contact.port)):
|
||||
to_remove.append(contact) # a contact became bad during iteration
|
||||
continue
|
||||
if (contact.address, contact.port) not in already_contacted_addresses:
|
||||
self.already_contacted.append(contact)
|
||||
to_remove.append(contact)
|
||||
probe = self.probeContact(contact)
|
||||
probes.append(probe)
|
||||
self.active_probes.append(probe)
|
||||
if len(probes) == constants.alpha:
|
||||
break
|
||||
for contact in to_remove: # these contacts will be re-added to the shortlist when they reply successfully
|
||||
self.shortlist.remove(contact)
|
||||
|
||||
# run the probes
|
||||
if probes:
|
||||
# Schedule the next iteration if there are any active
|
||||
# calls (Kademlia uses loose parallelism)
|
||||
self.searchIteration()
|
||||
|
||||
d = defer.DeferredList(probes, consumeErrors=True)
|
||||
|
||||
def _remove_probes(results):
|
||||
for probe in probes:
|
||||
self.active_probes.remove(probe)
|
||||
return results
|
||||
|
||||
d.addCallback(_remove_probes)
|
||||
|
||||
elif not self.finished_deferred.called and not self.active_probes or self.should_stop():
|
||||
# If no probes were sent, there will not be any improvement, so we're done
|
||||
if self.is_find_value_request:
|
||||
self.finished_deferred.callback(self.find_value_result)
|
||||
else:
|
||||
self.sortByDistance(self.active_contacts)
|
||||
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
|
||||
elif not self.finished_deferred.called:
|
||||
# Force the next iteration
|
||||
self.searchIteration()
|
||||
|
||||
def searchIteration(self, delay=constants.iterativeLookupDelay):
|
||||
def _cancel_pending_iterations(result):
|
||||
while self.pending_iteration_calls:
|
||||
canceller = self.pending_iteration_calls.pop()
|
||||
canceller()
|
||||
return result
|
||||
self.finished_deferred.addBoth(_cancel_pending_iterations)
|
||||
self._iteration_count += 1
|
||||
call, cancel = self.node.reactor_callLater(delay, self._search_iteration_semaphore.run, self._searchIteration)
|
||||
self.pending_iteration_calls.append(cancel)
|
||||
|
||||
|
||||
def iterativeFind(node, shortlist, key, rpc, exclude=None):
|
||||
helper = _IterativeFind(node, shortlist, key, rpc, exclude)
|
||||
helper.searchIteration(0)
|
||||
return helper.finished_deferred
|
|
@ -1,147 +0,0 @@
|
|||
import logging
|
||||
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.distance import Distance
|
||||
from lbrynet.dht.error import BucketFull
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KBucket:
|
||||
""" Description - later
|
||||
"""
|
||||
|
||||
def __init__(self, rangeMin, rangeMax, node_id):
|
||||
"""
|
||||
@param rangeMin: The lower boundary for the range in the n-bit ID
|
||||
space covered by this k-bucket
|
||||
@param rangeMax: The upper boundary for the range in the ID space
|
||||
covered by this k-bucket
|
||||
"""
|
||||
self.lastAccessed = 0
|
||||
self.rangeMin = rangeMin
|
||||
self.rangeMax = rangeMax
|
||||
self._contacts = list()
|
||||
self._node_id = node_id
|
||||
|
||||
def addContact(self, contact):
|
||||
""" Add contact to _contact list in the right order. This will move the
|
||||
contact to the end of the k-bucket if it is already present.
|
||||
|
||||
@raise kademlia.kbucket.BucketFull: Raised when the bucket is full and
|
||||
the contact isn't in the bucket
|
||||
already
|
||||
|
||||
@param contact: The contact to add
|
||||
@type contact: dht.contact._Contact
|
||||
"""
|
||||
if contact in self._contacts:
|
||||
# Move the existing contact to the end of the list
|
||||
# - using the new contact to allow add-on data
|
||||
# (e.g. optimization-specific stuff) to pe updated as well
|
||||
self._contacts.remove(contact)
|
||||
self._contacts.append(contact)
|
||||
elif len(self._contacts) < constants.k:
|
||||
self._contacts.append(contact)
|
||||
else:
|
||||
raise BucketFull("No space in bucket to insert contact")
|
||||
|
||||
def getContact(self, contactID):
|
||||
"""Get the contact specified node ID
|
||||
|
||||
@raise IndexError: raised if the contact is not in the bucket
|
||||
|
||||
@param contactID: the node id of the contact to retrieve
|
||||
@type contactID: str
|
||||
|
||||
@rtype: dht.contact._Contact
|
||||
"""
|
||||
for contact in self._contacts:
|
||||
if contact.id == contactID:
|
||||
return contact
|
||||
raise IndexError(contactID)
|
||||
|
||||
def getContacts(self, count=-1, excludeContact=None, sort_distance_to=None):
|
||||
""" Returns a list containing up to the first count number of contacts
|
||||
|
||||
@param count: The amount of contacts to return (if 0 or less, return
|
||||
all contacts)
|
||||
@type count: int
|
||||
@param excludeContact: A node id to exclude; if this contact is in
|
||||
the list of returned values, it will be
|
||||
discarded before returning. If a C{str} is
|
||||
passed as this argument, it must be the
|
||||
contact's ID.
|
||||
@type excludeContact: str
|
||||
|
||||
@param sort_distance_to: Sort distance to the id, defaulting to the parent node id. If False don't
|
||||
sort the contacts
|
||||
|
||||
@raise IndexError: If the number of requested contacts is too large
|
||||
|
||||
@return: Return up to the first count number of contacts in a list
|
||||
If no contacts are present an empty is returned
|
||||
@rtype: list
|
||||
"""
|
||||
contacts = [contact for contact in self._contacts if contact.id != excludeContact]
|
||||
|
||||
# Return all contacts in bucket
|
||||
if count <= 0:
|
||||
count = len(contacts)
|
||||
|
||||
# Get current contact number
|
||||
currentLen = len(contacts)
|
||||
|
||||
# If count greater than k - return only k contacts
|
||||
if count > constants.k:
|
||||
count = constants.k
|
||||
|
||||
if not currentLen:
|
||||
return contacts
|
||||
|
||||
if sort_distance_to is False:
|
||||
pass
|
||||
else:
|
||||
sort_distance_to = sort_distance_to or self._node_id
|
||||
contacts.sort(key=lambda c: Distance(sort_distance_to)(c.id))
|
||||
|
||||
return contacts[:min(currentLen, count)]
|
||||
|
||||
def getBadOrUnknownContacts(self):
|
||||
contacts = self.getContacts(sort_distance_to=False)
|
||||
results = [contact for contact in contacts if contact.contact_is_good is False]
|
||||
results.extend(contact for contact in contacts if contact.contact_is_good is None)
|
||||
return results
|
||||
|
||||
def removeContact(self, contact):
|
||||
""" Remove the contact from the bucket
|
||||
|
||||
@param contact: The contact to remove
|
||||
@type contact: dht.contact._Contact
|
||||
|
||||
@raise ValueError: The specified contact is not in this bucket
|
||||
"""
|
||||
self._contacts.remove(contact)
|
||||
|
||||
def keyInRange(self, key):
|
||||
""" Tests whether the specified key (i.e. node ID) is in the range
|
||||
of the n-bit ID space covered by this k-bucket (in otherwords, it
|
||||
returns whether or not the specified key should be placed in this
|
||||
k-bucket)
|
||||
|
||||
@param key: The key to test
|
||||
@type key: str or int
|
||||
|
||||
@return: C{True} if the key is in this k-bucket's range, or C{False}
|
||||
if not.
|
||||
@rtype: bool
|
||||
"""
|
||||
if isinstance(key, bytes):
|
||||
key = int.from_bytes(key, 'big')
|
||||
return self.rangeMin <= key < self.rangeMax
|
||||
|
||||
def __len__(self):
|
||||
return len(self._contacts)
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._contacts
|
|
@ -1,90 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# This library is free software, distributed under the terms of
|
||||
# the GNU Lesser General Public License Version 3, or any later version.
|
||||
# See the COPYING file included in this archive
|
||||
#
|
||||
# The docstrings in this module contain epytext markup; API documentation
|
||||
# may be created by processing this file with epydoc: http://epydoc.sf.net
|
||||
|
||||
from lbrynet.dht import msgtypes
|
||||
|
||||
|
||||
class MessageTranslator:
|
||||
""" Interface for RPC message translators/formatters
|
||||
|
||||
Classes inheriting from this should provide a translation services between
|
||||
the classes used internally by this Kademlia implementation and the actual
|
||||
data that is transmitted between nodes.
|
||||
"""
|
||||
|
||||
def fromPrimitive(self, msgPrimitive):
|
||||
""" Create an RPC Message from a message's string representation
|
||||
|
||||
@param msgPrimitive: The unencoded primitive representation of a message
|
||||
@type msgPrimitive: str, int, list or dict
|
||||
|
||||
@return: The translated message object
|
||||
@rtype: entangled.kademlia.msgtypes.Message
|
||||
"""
|
||||
|
||||
def toPrimitive(self, message):
|
||||
""" Create a string representation of a message
|
||||
|
||||
@param message: The message object
|
||||
@type message: msgtypes.Message
|
||||
|
||||
@return: The message's primitive representation in a particular
|
||||
messaging format
|
||||
@rtype: str, int, list or dict
|
||||
"""
|
||||
|
||||
|
||||
class DefaultFormat(MessageTranslator):
|
||||
""" The default on-the-wire message format for this library """
|
||||
typeRequest, typeResponse, typeError = range(3)
|
||||
headerType, headerMsgID, headerNodeID, headerPayload, headerArgs = range(5) # TODO: make str
|
||||
|
||||
@staticmethod
|
||||
def get(primitive, key):
|
||||
try:
|
||||
return primitive[key]
|
||||
except KeyError:
|
||||
return primitive[str(key)] # TODO: switch to int()
|
||||
|
||||
def fromPrimitive(self, msgPrimitive):
|
||||
msgType = self.get(msgPrimitive, self.headerType)
|
||||
if msgType == self.typeRequest:
|
||||
msg = msgtypes.RequestMessage(self.get(msgPrimitive, self.headerNodeID),
|
||||
self.get(msgPrimitive, self.headerPayload),
|
||||
self.get(msgPrimitive, self.headerArgs),
|
||||
self.get(msgPrimitive, self.headerMsgID))
|
||||
elif msgType == self.typeResponse:
|
||||
msg = msgtypes.ResponseMessage(self.get(msgPrimitive, self.headerMsgID),
|
||||
self.get(msgPrimitive, self.headerNodeID),
|
||||
self.get(msgPrimitive, self.headerPayload))
|
||||
elif msgType == self.typeError:
|
||||
msg = msgtypes.ErrorMessage(self.get(msgPrimitive, self.headerMsgID),
|
||||
self.get(msgPrimitive, self.headerNodeID),
|
||||
self.get(msgPrimitive, self.headerPayload),
|
||||
self.get(msgPrimitive, self.headerArgs))
|
||||
else:
|
||||
# Unknown message, no payload
|
||||
msg = msgtypes.Message(msgPrimitive[self.headerMsgID], msgPrimitive[self.headerNodeID])
|
||||
return msg
|
||||
|
||||
def toPrimitive(self, message):
|
||||
msg = {self.headerMsgID: message.id,
|
||||
self.headerNodeID: message.nodeID}
|
||||
if isinstance(message, msgtypes.RequestMessage):
|
||||
msg[self.headerType] = self.typeRequest
|
||||
msg[self.headerPayload] = message.request
|
||||
msg[self.headerArgs] = message.args
|
||||
elif isinstance(message, msgtypes.ErrorMessage):
|
||||
msg[self.headerType] = self.typeError
|
||||
msg[self.headerPayload] = message.exceptionType
|
||||
msg[self.headerArgs] = message.response
|
||||
elif isinstance(message, msgtypes.ResponseMessage):
|
||||
msg[self.headerType] = self.typeResponse
|
||||
msg[self.headerPayload] = message.response
|
||||
return msg
|
|
@ -1,52 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# This library is free software, distributed under the terms of
|
||||
# the GNU Lesser General Public License Version 3, or any later version.
|
||||
# See the COPYING file included in this archive
|
||||
#
|
||||
# The docstrings in this module contain epytext markup; API documentation
|
||||
# may be created by processing this file with epydoc: http://epydoc.sf.net
|
||||
|
||||
from lbrynet.utils import generate_id
|
||||
from lbrynet.dht import constants
|
||||
|
||||
|
||||
class Message:
|
||||
""" Base class for messages - all "unknown" messages use this class """
|
||||
|
||||
def __init__(self, rpcID, nodeID):
|
||||
if len(rpcID) != constants.rpc_id_length:
|
||||
raise ValueError("invalid rpc id: %i bytes (expected 20)" % len(rpcID))
|
||||
if len(nodeID) != constants.key_bits // 8:
|
||||
raise ValueError("invalid node id: %i bytes (expected 48)" % len(nodeID))
|
||||
self.id = rpcID
|
||||
self.nodeID = nodeID
|
||||
|
||||
|
||||
class RequestMessage(Message):
|
||||
""" Message containing an RPC request """
|
||||
|
||||
def __init__(self, nodeID, method, methodArgs, rpcID=None):
|
||||
if rpcID is None:
|
||||
rpcID = generate_id()[:constants.rpc_id_length]
|
||||
super().__init__(rpcID, nodeID)
|
||||
self.request = method
|
||||
self.args = methodArgs
|
||||
|
||||
|
||||
class ResponseMessage(Message):
|
||||
""" Message containing the result from a successful RPC request """
|
||||
|
||||
def __init__(self, rpcID, nodeID, response):
|
||||
super().__init__(rpcID, nodeID)
|
||||
self.response = response
|
||||
|
||||
|
||||
class ErrorMessage(ResponseMessage):
|
||||
""" Message containing the error from an unsuccessful RPC request """
|
||||
|
||||
def __init__(self, rpcID, nodeID, exceptionType, errorMessage):
|
||||
super().__init__(rpcID, nodeID, errorMessage)
|
||||
if isinstance(exceptionType, type):
|
||||
exceptionType = (f'{exceptionType.__module__}.{exceptionType.__name__}').encode()
|
||||
self.exceptionType = exceptionType
|
|
@ -1,649 +1,226 @@
|
|||
import binascii
|
||||
import hashlib
|
||||
import logging
|
||||
from functools import reduce
|
||||
import asyncio
|
||||
import typing
|
||||
import socket
|
||||
import binascii
|
||||
import contextlib
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.error import RemoteException
|
||||
from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction
|
||||
from lbrynet.dht.protocol.distance import Distance
|
||||
from lbrynet.dht.protocol.iterative_find import IterativeNodeFinder, IterativeValueFinder
|
||||
from lbrynet.dht.protocol.protocol import KademliaProtocol
|
||||
from lbrynet.dht.peer import KademliaPeer
|
||||
|
||||
from twisted.internet import defer, error, task
|
||||
|
||||
from lbrynet.utils import generate_id, DeferredDict
|
||||
from lbrynet.dht.call_later_manager import CallLaterManager
|
||||
from lbrynet.dht.error import TimeoutError
|
||||
from lbrynet.dht import constants, routingtable, datastore, protocol
|
||||
from lbrynet.dht.contact import ContactManager
|
||||
from lbrynet.dht.iterativefind import iterativeFind
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.dht.peer import PeerManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def rpcmethod(func):
|
||||
""" Decorator to expose Node methods as remote procedure calls
|
||||
class Node:
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', node_id: bytes, udp_port: int,
|
||||
internal_udp_port: int, peer_port: int, external_ip: str):
|
||||
self.loop = loop
|
||||
self.internal_udp_port = internal_udp_port
|
||||
self.protocol = KademliaProtocol(loop, peer_manager, node_id, external_ip, udp_port, peer_port)
|
||||
self.listening_port: asyncio.DatagramTransport = None
|
||||
self.joined = asyncio.Event(loop=self.loop)
|
||||
self._join_task: asyncio.Task = None
|
||||
self._refresh_task: asyncio.Task = None
|
||||
|
||||
Apply this decorator to methods in the Node class (or a subclass) in order
|
||||
to make them remotely callable via the DHT's RPC mechanism.
|
||||
"""
|
||||
func.rpcmethod = True
|
||||
return func
|
||||
async def refresh_node(self):
|
||||
while True:
|
||||
# remove peers with expired blob announcements from the datastore
|
||||
self.protocol.data_store.removed_expired_peers()
|
||||
|
||||
total_peers: typing.List['KademliaPeer'] = []
|
||||
# add all peers in the routing table
|
||||
total_peers.extend(self.protocol.routing_table.get_peers())
|
||||
# add all the peers who have announed blobs to us
|
||||
total_peers.extend(self.protocol.data_store.get_storing_contacts())
|
||||
|
||||
class MockKademliaHelper:
|
||||
def __init__(self, clock=None, callLater=None, resolve=None, listenUDP=None):
|
||||
if not listenUDP or not resolve or not callLater or not clock:
|
||||
from twisted.internet import reactor
|
||||
listenUDP = listenUDP or reactor.listenUDP
|
||||
resolve = resolve or reactor.resolve
|
||||
callLater = callLater or reactor.callLater
|
||||
clock = clock or reactor
|
||||
# get ids falling in the midpoint of each bucket that hasn't been recently updated
|
||||
node_ids = self.protocol.routing_table.get_refresh_list(0, True)
|
||||
# if we have 3 or fewer populated buckets get two random ids in the range of each to try and
|
||||
# populate/split the buckets further
|
||||
buckets_with_contacts = self.protocol.routing_table.buckets_with_contacts()
|
||||
if buckets_with_contacts <= 3:
|
||||
for i in range(buckets_with_contacts):
|
||||
node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i))
|
||||
node_ids.append(self.protocol.routing_table.random_id_in_bucket_range(i))
|
||||
|
||||
self.clock = clock
|
||||
self.contact_manager = ContactManager(self.clock.seconds)
|
||||
self.reactor_listenUDP = listenUDP
|
||||
self.reactor_resolve = resolve
|
||||
self.call_later_manager = CallLaterManager(callLater)
|
||||
self.reactor_callLater = self.call_later_manager.call_later
|
||||
self.reactor_callSoon = self.call_later_manager.call_soon
|
||||
|
||||
self._listeningPort = None # object implementing Twisted
|
||||
# IListeningPort This will contain a deferred created when
|
||||
# joining the network, to enable publishing/retrieving
|
||||
# information from the DHT as soon as the node is part of the
|
||||
# network (add callbacks to this deferred if scheduling such
|
||||
# operations before the node has finished joining the network)
|
||||
|
||||
def get_looping_call(self, fn, *args, **kwargs):
|
||||
lc = task.LoopingCall(fn, *args, **kwargs)
|
||||
lc.clock = self.clock
|
||||
return lc
|
||||
|
||||
def safe_stop_looping_call(self, lc):
|
||||
if lc and lc.running:
|
||||
return lc.stop()
|
||||
return defer.succeed(None)
|
||||
|
||||
def safe_start_looping_call(self, lc, t):
|
||||
if lc and not lc.running:
|
||||
lc.start(t)
|
||||
|
||||
|
||||
class Node(MockKademliaHelper):
|
||||
""" Local node in the Kademlia network
|
||||
|
||||
This class represents a single local node in a Kademlia network; in other
|
||||
words, this class encapsulates an Entangled-using application's "presence"
|
||||
in a Kademlia network.
|
||||
|
||||
In Entangled, all interactions with the Kademlia network by a client
|
||||
application is performed via this class (or a subclass).
|
||||
"""
|
||||
|
||||
def __init__(self, node_id=None, udpPort=4000, dataStore=None,
|
||||
routingTableClass=None, networkProtocol=None,
|
||||
externalIP=None, peerPort=3333, listenUDP=None,
|
||||
callLater=None, resolve=None, clock=None,
|
||||
interface='', externalUDPPort=None):
|
||||
"""
|
||||
@param dataStore: The data store to use. This must be class inheriting
|
||||
from the C{DataStore} interface (or providing the
|
||||
same API). How the data store manages its data
|
||||
internally is up to the implementation of that data
|
||||
store.
|
||||
@type dataStore: entangled.kademlia.datastore.DataStore
|
||||
@param routingTable: The routing table class to use. Since there exists
|
||||
some ambiguity as to how the routing table should be
|
||||
implemented in Kademlia, a different routing table
|
||||
may be used, as long as the appropriate API is
|
||||
exposed. This should be a class, not an object,
|
||||
in order to allow the Node to pass an
|
||||
auto-generated node ID to the routingtable object
|
||||
upon instantiation (if necessary).
|
||||
@type routingTable: entangled.kademlia.routingtable.RoutingTable
|
||||
@param networkProtocol: The network protocol to use. This can be
|
||||
overridden from the default to (for example)
|
||||
change the format of the physical RPC messages
|
||||
being transmitted.
|
||||
@type networkProtocol: entangled.kademlia.protocol.KademliaProtocol
|
||||
@param externalIP: the IP at which this node can be contacted
|
||||
@param peerPort: the port at which this node announces it has a blob for
|
||||
"""
|
||||
|
||||
super().__init__(clock, callLater, resolve, listenUDP)
|
||||
self.node_id = node_id or self._generateID()
|
||||
self.port = udpPort
|
||||
self._listen_interface = interface
|
||||
self._change_token_lc = self.get_looping_call(self.change_token)
|
||||
self._refresh_node_lc = self.get_looping_call(self._refreshNode)
|
||||
self._refresh_contacts_lc = self.get_looping_call(self._refreshContacts)
|
||||
|
||||
# Create k-buckets (for storing contacts)
|
||||
if routingTableClass is None:
|
||||
self._routingTable = routingtable.TreeRoutingTable(self.node_id, self.clock.seconds)
|
||||
if self.protocol.routing_table.get_peers():
|
||||
# if we have node ids to look up, perform the iterative search until we have k results
|
||||
while node_ids:
|
||||
peers = await self.peer_search(node_ids.pop())
|
||||
total_peers.extend(peers)
|
||||
else:
|
||||
self._routingTable = routingTableClass(self.node_id, self.clock.seconds)
|
||||
fut = asyncio.Future(loop=self.loop)
|
||||
self.loop.call_later(constants.refresh_interval // 4, fut.set_result, None)
|
||||
await fut
|
||||
continue
|
||||
|
||||
self._protocol = networkProtocol or protocol.KademliaProtocol(self)
|
||||
self.token_secret = self._generateID()
|
||||
self.old_token_secret = None
|
||||
self.externalIP = externalIP
|
||||
self.peerPort = peerPort
|
||||
self.externalUDPPort = externalUDPPort or self.port
|
||||
self._dataStore = dataStore or datastore.DictDataStore(self.clock.seconds)
|
||||
self._join_deferred = None
|
||||
# ping the set of peers; upon success/failure the routing able and last replied/failed time will be updated
|
||||
to_ping = [peer for peer in set(total_peers) if self.protocol.peer_manager.peer_is_good(peer) is not True]
|
||||
if to_ping:
|
||||
await self.protocol.ping_queue.enqueue_maybe_ping(*to_ping, delay=0)
|
||||
|
||||
#def __del__(self):
|
||||
# log.warning("unclean shutdown of the dht node")
|
||||
# if hasattr(self, "_listeningPort") and self._listeningPort is not None:
|
||||
# self._listeningPort.stopListening()
|
||||
fut = asyncio.Future(loop=self.loop)
|
||||
self.loop.call_later(constants.refresh_interval, fut.set_result, None)
|
||||
await fut
|
||||
|
||||
def __str__(self):
|
||||
return '<%s.%s object; ID: %s, IP address: %s, UDP port: %d>' % (
|
||||
self.__module__, self.__class__.__name__, binascii.hexlify(self.node_id), self.externalIP, self.port)
|
||||
async def announce_blob(self, blob_hash: str) -> typing.List[bytes]:
|
||||
announced_to_node_ids = []
|
||||
while not announced_to_node_ids:
|
||||
hash_value = binascii.unhexlify(blob_hash.encode())
|
||||
assert len(hash_value) == constants.hash_length
|
||||
peers = await self.peer_search(hash_value)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def stop(self):
|
||||
# stop LoopingCalls:
|
||||
yield self.safe_stop_looping_call(self._refresh_node_lc)
|
||||
yield self.safe_stop_looping_call(self._change_token_lc)
|
||||
yield self.safe_stop_looping_call(self._refresh_contacts_lc)
|
||||
if self._listeningPort is not None:
|
||||
yield self._listeningPort.stopListening()
|
||||
self._listeningPort = None
|
||||
if not self.protocol.external_ip:
|
||||
raise Exception("Cannot determine external IP")
|
||||
log.info("Store to %i peers", len(peers))
|
||||
log.info(peers)
|
||||
for peer in peers:
|
||||
log.info("store to %s %s %s", peer.address, peer.udp_port, peer.tcp_port)
|
||||
stored_to_tup = await asyncio.gather(
|
||||
*(self.protocol.store_to_peer(hash_value, peer) for peer in peers), loop=self.loop
|
||||
)
|
||||
announced_to_node_ids.extend([node_id for node_id, contacted in stored_to_tup if contacted])
|
||||
log.info("Stored %s to %i of %i attempted peers", binascii.hexlify(hash_value).decode()[:8],
|
||||
len(announced_to_node_ids), len(peers))
|
||||
return announced_to_node_ids
|
||||
|
||||
def start_listening(self):
|
||||
if not self._listeningPort:
|
||||
def stop(self) -> None:
|
||||
if self.joined.is_set():
|
||||
self.joined.clear()
|
||||
if self._join_task:
|
||||
self._join_task.cancel()
|
||||
if self._refresh_task and not (self._refresh_task.done() or self._refresh_task.cancelled()):
|
||||
self._refresh_task.cancel()
|
||||
if self.protocol and self.protocol.ping_queue.running:
|
||||
self.protocol.ping_queue.stop()
|
||||
if self.listening_port is not None:
|
||||
self.listening_port.close()
|
||||
self._join_task = None
|
||||
self.listening_port = None
|
||||
log.info("Stopped DHT node")
|
||||
|
||||
async def start_listening(self, interface: str = '') -> None:
|
||||
if not self.listening_port:
|
||||
self.listening_port, _ = await self.loop.create_datagram_endpoint(
|
||||
lambda: self.protocol, (interface, self.internal_udp_port)
|
||||
)
|
||||
log.info("DHT node listening on UDP %s:%i", interface, self.internal_udp_port)
|
||||
else:
|
||||
log.warning("Already bound to port %s", self.listening_port)
|
||||
|
||||
async def join_network(self, interface: typing.Optional[str] = '',
|
||||
known_node_urls: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
|
||||
known_node_addresses: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
|
||||
if not self.listening_port:
|
||||
await self.start_listening(interface)
|
||||
self.protocol.ping_queue.start()
|
||||
self._refresh_task = self.loop.create_task(self.refresh_node())
|
||||
|
||||
known_node_addresses = known_node_addresses or []
|
||||
if known_node_urls:
|
||||
for host, port in known_node_urls:
|
||||
info = await self.loop.getaddrinfo(
|
||||
host, 'https',
|
||||
proto=socket.IPPROTO_TCP,
|
||||
)
|
||||
if (info[0][4][0], port) not in known_node_addresses:
|
||||
known_node_addresses.append((info[0][4][0], port))
|
||||
futs = []
|
||||
for address, port in known_node_addresses:
|
||||
peer = self.protocol.get_rpc_peer(KademliaPeer(self.loop, address, udp_port=port))
|
||||
futs.append(peer.ping())
|
||||
if futs:
|
||||
await asyncio.wait(futs, loop=self.loop)
|
||||
|
||||
async with self.peer_search_junction(self.protocol.node_id, max_results=16) as junction:
|
||||
async for peers in junction:
|
||||
for peer in peers:
|
||||
try:
|
||||
self._listeningPort = self.reactor_listenUDP(self.port, self._protocol,
|
||||
interface=self._listen_interface)
|
||||
except error.CannotListenError as e:
|
||||
import traceback
|
||||
log.error("Couldn't bind to port %d. %s", self.port, traceback.format_exc())
|
||||
raise ValueError("%s lbrynet may already be running." % str(e))
|
||||
else:
|
||||
log.warning("Already bound to port %s", self._listeningPort)
|
||||
await self.protocol.get_rpc_peer(peer).ping()
|
||||
except (asyncio.TimeoutError, RemoteException):
|
||||
pass
|
||||
self.joined.set()
|
||||
log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
|
||||
self.protocol.routing_table.buckets_with_contacts())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def joinNetwork(self, known_node_addresses=(('jack.lbry.tech', 4455), )):
|
||||
"""
|
||||
Attempt to join the dht, retry every 30 seconds if unsuccessful
|
||||
:param known_node_addresses: [(str, int)] list of hostnames and ports for known dht seed nodes
|
||||
"""
|
||||
def start(self, interface: str, known_node_urls: typing.List[typing.Tuple[str, int]]):
|
||||
self._join_task = self.loop.create_task(
|
||||
self.join_network(
|
||||
interface=interface, known_node_urls=known_node_urls
|
||||
)
|
||||
)
|
||||
|
||||
self._join_deferred = defer.Deferred()
|
||||
known_node_resolution = {}
|
||||
def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List] = None,
|
||||
bottom_out_limit: int = constants.bottom_out_limit,
|
||||
max_results: int = constants.k) -> IterativeNodeFinder:
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _resolve_seeds():
|
||||
result = {}
|
||||
for host, port in known_node_addresses:
|
||||
node_address = yield self.reactor_resolve(host)
|
||||
result[(host, port)] = node_address
|
||||
defer.returnValue(result)
|
||||
return IterativeNodeFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
|
||||
key, bottom_out_limit, max_results, None, shortlist)
|
||||
|
||||
if not known_node_resolution:
|
||||
known_node_resolution = yield _resolve_seeds()
|
||||
# we are one of the seed nodes, don't add ourselves
|
||||
if (self.externalIP, self.port) in known_node_resolution.values():
|
||||
del known_node_resolution[(self.externalIP, self.port)]
|
||||
known_node_addresses.remove((self.externalIP, self.port))
|
||||
def get_iterative_value_finder(self, key: bytes, shortlist: typing.Optional[typing.List] = None,
|
||||
bottom_out_limit: int = 40,
|
||||
max_results: int = -1) -> IterativeValueFinder:
|
||||
|
||||
def _ping_contacts(contacts):
|
||||
d = DeferredDict({contact: contact.ping() for contact in contacts}, consumeErrors=True)
|
||||
d.addErrback(lambda err: err.trap(TimeoutError))
|
||||
return d
|
||||
return IterativeValueFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
|
||||
key, bottom_out_limit, max_results, None, shortlist)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _initialize_routing():
|
||||
bootstrap_contacts = []
|
||||
contact_addresses = {(c.address, c.port): c for c in self.contacts}
|
||||
for (host, port), ip_address in known_node_resolution.items():
|
||||
if (host, port) not in contact_addresses:
|
||||
# Create temporary contact information for the list of addresses of known nodes
|
||||
# The contact node id will be set with the responding node id when we initialize it to None
|
||||
contact = self.contact_manager.make_contact(None, ip_address, port, self._protocol)
|
||||
bootstrap_contacts.append(contact)
|
||||
else:
|
||||
for contact in self.contacts:
|
||||
if contact.address == ip_address and contact.port == port:
|
||||
if not contact.id:
|
||||
bootstrap_contacts.append(contact)
|
||||
@contextlib.asynccontextmanager
|
||||
async def stream_peer_search_junction(self, hash_queue: asyncio.Queue, bottom_out_limit=20,
|
||||
max_results=-1) -> AsyncGeneratorJunction:
|
||||
peer_generator = AsyncGeneratorJunction(self.loop)
|
||||
|
||||
async def _add_hashes_from_queue():
|
||||
while True:
|
||||
try:
|
||||
blob_hash = await hash_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
if not bootstrap_contacts:
|
||||
log.warning("no bootstrap contacts to ping")
|
||||
ping_result = yield _ping_contacts(bootstrap_contacts)
|
||||
shortlist = list(ping_result.keys())
|
||||
if not shortlist:
|
||||
log.warning("failed to ping %i bootstrap contacts", len(bootstrap_contacts))
|
||||
defer.returnValue(None)
|
||||
else:
|
||||
# find the closest peers to us
|
||||
closest = yield self._iterativeFind(self.node_id, shortlist if not self.contacts else None)
|
||||
yield _ping_contacts(closest)
|
||||
# query random hashes in our bucket key ranges to fill or split them
|
||||
random_ids_in_range = self._routingTable.getRefreshList()
|
||||
while random_ids_in_range:
|
||||
yield self.iterativeFindNode(random_ids_in_range.pop())
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _iterative_join(joined_d=None, last_buckets_with_contacts=None):
|
||||
log.info("Attempting to join the DHT network, %i contacts known so far", len(self.contacts))
|
||||
joined_d = joined_d or defer.Deferred()
|
||||
yield _initialize_routing()
|
||||
buckets_with_contacts = self.bucketsWithContacts()
|
||||
if last_buckets_with_contacts and last_buckets_with_contacts == buckets_with_contacts:
|
||||
if not joined_d.called:
|
||||
joined_d.callback(True)
|
||||
elif buckets_with_contacts < 4:
|
||||
self.reactor_callLater(0, _iterative_join, joined_d, buckets_with_contacts)
|
||||
elif not joined_d.called:
|
||||
joined_d.callback(None)
|
||||
yield joined_d
|
||||
if not self._join_deferred.called:
|
||||
self._join_deferred.callback(True)
|
||||
defer.returnValue(None)
|
||||
|
||||
yield _iterative_join()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start(self, known_node_addresses=None, block_on_join=False):
|
||||
""" Causes the Node to attempt to join the DHT network by contacting the
|
||||
known DHT nodes. This can be called multiple times if the previous attempt
|
||||
has failed or if the Node has lost all the contacts.
|
||||
|
||||
@param known_node_addresses: A sequence of tuples containing IP address
|
||||
information for existing nodes on the
|
||||
Kademlia network, in the format:
|
||||
C{(<ip address>, (udp port>)}
|
||||
@type known_node_addresses: list
|
||||
"""
|
||||
|
||||
self.start_listening()
|
||||
yield self._protocol._listening
|
||||
# TODO: Refresh all k-buckets further away than this node's closest neighbour
|
||||
d = self.joinNetwork(known_node_addresses or [])
|
||||
d.addCallback(lambda _: self.start_looping_calls())
|
||||
d.addCallback(lambda _: log.info("Joined the dht"))
|
||||
if block_on_join:
|
||||
yield d
|
||||
|
||||
def start_looping_calls(self):
|
||||
self.safe_start_looping_call(self._change_token_lc, constants.tokenSecretChangeInterval)
|
||||
# Start refreshing k-buckets periodically, if necessary
|
||||
self.safe_start_looping_call(self._refresh_node_lc, constants.checkRefreshInterval)
|
||||
self.safe_start_looping_call(self._refresh_contacts_lc, 60)
|
||||
|
||||
@property
|
||||
def contacts(self):
|
||||
def _inner():
|
||||
for i in range(len(self._routingTable._buckets)):
|
||||
for contact in self._routingTable._buckets[i]._contacts:
|
||||
yield contact
|
||||
return list(_inner())
|
||||
|
||||
def hasContacts(self):
|
||||
for bucket in self._routingTable._buckets:
|
||||
if bucket._contacts:
|
||||
return True
|
||||
return False
|
||||
|
||||
def bucketsWithContacts(self):
|
||||
return self._routingTable.bucketsWithContacts()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def storeToContact(self, blob_hash, contact):
|
||||
peer_generator.add_generator(
|
||||
self.get_iterative_value_finder(
|
||||
binascii.unhexlify(blob_hash.encode()), bottom_out_limit=bottom_out_limit,
|
||||
max_results=max_results
|
||||
)
|
||||
)
|
||||
add_hashes_task = self.loop.create_task(_add_hashes_from_queue())
|
||||
try:
|
||||
if not contact.token:
|
||||
yield contact.findValue(blob_hash)
|
||||
res = yield contact.store(blob_hash, contact.token, self.peerPort, self.node_id, 0)
|
||||
if res != b"OK":
|
||||
raise ValueError(res)
|
||||
log.debug("Stored %s to %s (%s)", binascii.hexlify(blob_hash), contact.log_id(), contact.address)
|
||||
return True
|
||||
except protocol.TimeoutError:
|
||||
log.debug("Timeout while storing blob_hash %s at %s",
|
||||
binascii.hexlify(blob_hash), contact.log_id())
|
||||
except ValueError as err:
|
||||
log.error("Unexpected response: %s" % err)
|
||||
except Exception as err:
|
||||
if 'Invalid token' in str(err):
|
||||
contact.update_token(None)
|
||||
log.error("Unexpected error while storing blob_hash %s at %s: %s",
|
||||
binascii.hexlify(blob_hash), contact, err)
|
||||
return False
|
||||
async with peer_generator as junction:
|
||||
yield junction
|
||||
await peer_generator.finished.wait()
|
||||
except asyncio.CancelledError:
|
||||
if add_hashes_task and not (add_hashes_task.done() or add_hashes_task.cancelled()):
|
||||
add_hashes_task.cancel()
|
||||
raise
|
||||
finally:
|
||||
if add_hashes_task and not (add_hashes_task.done() or add_hashes_task.cancelled()):
|
||||
add_hashes_task.cancel()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def announceHaveBlob(self, blob_hash):
|
||||
contacts = yield self.iterativeFindNode(blob_hash)
|
||||
def peer_search_junction(self, node_id: bytes, max_results=constants.k*2,
|
||||
bottom_out_limit=20) -> AsyncGeneratorJunction:
|
||||
peer_generator = AsyncGeneratorJunction(self.loop)
|
||||
peer_generator.add_generator(
|
||||
self.get_iterative_node_finder(
|
||||
node_id, bottom_out_limit=bottom_out_limit, max_results=max_results
|
||||
)
|
||||
)
|
||||
return peer_generator
|
||||
|
||||
if not self.externalIP:
|
||||
raise Exception("Cannot determine external IP: %s" % self.externalIP)
|
||||
stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts})
|
||||
contacted_node_ids = [binascii.hexlify(contact.id) for contact in stored_to.keys() if stored_to[contact]]
|
||||
log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash),
|
||||
len(contacted_node_ids), len(contacts))
|
||||
defer.returnValue(contacted_node_ids)
|
||||
|
||||
def change_token(self):
|
||||
self.old_token_secret = self.token_secret
|
||||
self.token_secret = self._generateID()
|
||||
|
||||
def make_token(self, compact_ip):
|
||||
h = hashlib.new('sha384')
|
||||
h.update(self.token_secret + compact_ip)
|
||||
return h.digest()
|
||||
|
||||
def verify_token(self, token, compact_ip):
|
||||
h = hashlib.new('sha384')
|
||||
h.update(self.token_secret + compact_ip)
|
||||
if self.old_token_secret and not token == h.digest(): # TODO: why should we be accepting the previous token?
|
||||
h = hashlib.new('sha384')
|
||||
h.update(self.old_token_secret + compact_ip)
|
||||
if not token == h.digest():
|
||||
return False
|
||||
return True
|
||||
|
||||
def iterativeFindNode(self, key):
|
||||
""" The basic Kademlia node lookup operation
|
||||
|
||||
Call this to find a remote node in the P2P overlay network.
|
||||
|
||||
@param key: the n-bit key (i.e. the node or value ID) to search for
|
||||
@type key: str
|
||||
|
||||
@return: This immediately returns a deferred object, which will return
|
||||
a list of k "closest" contacts (C{kademlia.contact.Contact}
|
||||
objects) to the specified key as soon as the operation is
|
||||
finished.
|
||||
@rtype: twisted.internet.defer.Deferred
|
||||
"""
|
||||
return self._iterativeFind(key)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def iterativeFindValue(self, key, exclude=None):
|
||||
""" The Kademlia search operation (deterministic)
|
||||
|
||||
Call this to retrieve data from the DHT.
|
||||
|
||||
@param key: the n-bit key (i.e. the value ID) to search for
|
||||
@type key: str
|
||||
|
||||
@return: This immediately returns a deferred object, which will return
|
||||
either one of two things:
|
||||
- If the value was found, it will return a Python
|
||||
dictionary containing the searched-for key (the C{key}
|
||||
parameter passed to this method), and its associated
|
||||
value, in the format:
|
||||
C{<str>key: <str>data_value}
|
||||
- If the value was not found, it will return a list of k
|
||||
"closest" contacts (C{kademlia.contact.Contact} objects)
|
||||
to the specified key
|
||||
@rtype: twisted.internet.defer.Deferred
|
||||
"""
|
||||
|
||||
if len(key) != constants.key_bits // 8:
|
||||
raise ValueError("invalid key length!")
|
||||
|
||||
# Execute the search
|
||||
find_result = yield self._iterativeFind(key, rpc='findValue', exclude=exclude)
|
||||
if isinstance(find_result, dict):
|
||||
# We have found the value; now see who was the closest contact without it...
|
||||
# ...and store the key/value pair
|
||||
pass
|
||||
else:
|
||||
# The value wasn't found, but a list of contacts was returned
|
||||
# Now, see if we have the value (it might seem wasteful to search on the network
|
||||
# first, but it ensures that all values are properly propagated through the
|
||||
# network
|
||||
if self._dataStore.hasPeersForBlob(key):
|
||||
# Ok, we have the value locally, so use that
|
||||
# Send this value to the closest node without it
|
||||
peers = self._dataStore.getPeersForBlob(key)
|
||||
find_result = {key: peers}
|
||||
else:
|
||||
pass
|
||||
|
||||
defer.returnValue(list(set(find_result.get(key, []) if find_result else [])))
|
||||
# TODO: get this working
|
||||
# if 'closestNodeNoValue' in find_result:
|
||||
# closest_node_without_value = find_result['closestNodeNoValue']
|
||||
# try:
|
||||
# response, address = yield closest_node_without_value.findValue(key, rawResponse=True)
|
||||
# yield closest_node_without_value.store(key, response.response['token'], self.peerPort)
|
||||
# except TimeoutError:
|
||||
# pass
|
||||
|
||||
def addContact(self, contact):
|
||||
""" Add/update the given contact; simple wrapper for the same method
|
||||
in this object's RoutingTable object
|
||||
|
||||
@param contact: The contact to add to this node's k-buckets
|
||||
@type contact: kademlia.contact.Contact
|
||||
"""
|
||||
return self._routingTable.addContact(contact)
|
||||
|
||||
def removeContact(self, contact):
|
||||
""" Remove the contact with the specified node ID from this node's
|
||||
table of known nodes. This is a simple wrapper for the same method
|
||||
in this object's RoutingTable object
|
||||
|
||||
@param contact: The Contact object to remove
|
||||
@type contact: _Contact
|
||||
"""
|
||||
self._routingTable.removeContact(contact)
|
||||
|
||||
def findContact(self, contactID):
|
||||
""" Find a entangled.kademlia.contact.Contact object for the specified
|
||||
cotact ID
|
||||
|
||||
@param contactID: The contact ID of the required Contact object
|
||||
@type contactID: str
|
||||
|
||||
@return: Contact object of remote node with the specified node ID,
|
||||
or None if the contact was not found
|
||||
@rtype: twisted.internet.defer.Deferred
|
||||
"""
|
||||
try:
|
||||
df = defer.succeed(self._routingTable.getContact(contactID))
|
||||
except (ValueError, IndexError):
|
||||
df = self.iterativeFindNode(contactID)
|
||||
df.addCallback(lambda nodes: ([node for node in nodes if node.id == contactID] or (None,))[0])
|
||||
return df
|
||||
|
||||
@rpcmethod
|
||||
def ping(self):
|
||||
""" Used to verify contact between two Kademlia nodes
|
||||
|
||||
@rtype: str
|
||||
"""
|
||||
return b'pong'
|
||||
|
||||
@rpcmethod
|
||||
def store(self, rpc_contact, blob_hash, token, port, originalPublisherID, age):
|
||||
""" Store the received data in this node's local datastore
|
||||
|
||||
@param blob_hash: The hash of the data
|
||||
@type blob_hash: str
|
||||
|
||||
@param token: The token we previously returned when this contact sent us a findValue
|
||||
@type token: str
|
||||
|
||||
@param port: The TCP port the contact is listening on for requests for this blob (the peerPort)
|
||||
@type port: int
|
||||
|
||||
@param originalPublisherID: The node ID of the node that is the publisher of the data
|
||||
@type originalPublisherID: str
|
||||
|
||||
@param age: The relative age of the data (time in seconds since it was
|
||||
originally published). Note that the original publish time
|
||||
isn't actually given, to compensate for clock skew between
|
||||
different nodes.
|
||||
@type age: int
|
||||
|
||||
@rtype: str
|
||||
"""
|
||||
|
||||
if originalPublisherID is None:
|
||||
originalPublisherID = rpc_contact.id
|
||||
compact_ip = rpc_contact.compact_ip()
|
||||
if self.clock.seconds() - self._protocol.started_listening_time < constants.tokenSecretChangeInterval:
|
||||
pass
|
||||
elif not self.verify_token(token, compact_ip):
|
||||
raise ValueError("Invalid token")
|
||||
if 0 <= port <= 65536:
|
||||
compact_port = port.to_bytes(2, 'big')
|
||||
else:
|
||||
raise TypeError(f'Invalid port: {port}')
|
||||
compact_address = compact_ip + compact_port + rpc_contact.id
|
||||
now = int(self.clock.seconds())
|
||||
originallyPublished = now - age
|
||||
self._dataStore.addPeerToBlob(rpc_contact, blob_hash, compact_address, now, originallyPublished,
|
||||
originalPublisherID)
|
||||
return b'OK'
|
||||
|
||||
@rpcmethod
|
||||
def findNode(self, rpc_contact, key):
|
||||
""" Finds a number of known nodes closest to the node/value with the
|
||||
specified key.
|
||||
|
||||
@param key: the n-bit key (i.e. the node or value ID) to search for
|
||||
@type key: str
|
||||
|
||||
@return: A list of contact triples closest to the specified key.
|
||||
This method will return C{k} (or C{count}, if specified)
|
||||
contacts if at all possible; it will only return fewer if the
|
||||
node is returning all of the contacts that it knows of.
|
||||
@rtype: list
|
||||
"""
|
||||
if len(key) != constants.key_bits // 8:
|
||||
raise ValueError("invalid contact id length: %i" % len(key))
|
||||
|
||||
contacts = self._routingTable.findCloseNodes(key, sender_node_id=rpc_contact.id)
|
||||
contact_triples = []
|
||||
for contact in contacts:
|
||||
contact_triples.append((contact.id, contact.address, contact.port))
|
||||
return contact_triples
|
||||
|
||||
@rpcmethod
|
||||
def findValue(self, rpc_contact, key):
|
||||
""" Return the value associated with the specified key if present in
|
||||
this node's data, otherwise execute FIND_NODE for the key
|
||||
|
||||
@param key: The hashtable key of the data to return
|
||||
@type key: str
|
||||
|
||||
@return: A dictionary containing the requested key/value pair,
|
||||
or a list of contact triples closest to the requested key.
|
||||
@rtype: dict or list
|
||||
"""
|
||||
|
||||
if len(key) != constants.key_bits // 8:
|
||||
raise ValueError("invalid blob hash length: %i" % len(key))
|
||||
|
||||
response = {
|
||||
b'token': self.make_token(rpc_contact.compact_ip()),
|
||||
}
|
||||
|
||||
if self._protocol._protocolVersion:
|
||||
response[b'protocolVersion'] = self._protocol._protocolVersion
|
||||
|
||||
# get peers we have stored for this blob
|
||||
has_other_peers = self._dataStore.hasPeersForBlob(key)
|
||||
peers = []
|
||||
if has_other_peers:
|
||||
peers.extend(self._dataStore.getPeersForBlob(key))
|
||||
|
||||
# if we don't have k storing peers to return and we have this hash locally, include our contact information
|
||||
if len(peers) < constants.k and key in self._dataStore.completed_blobs:
|
||||
compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), self.externalIP.split('.'), bytearray())
|
||||
compact_port = self.peerPort.to_bytes(2, 'big')
|
||||
compact_address = compact_ip + compact_port + self.node_id
|
||||
peers.append(compact_address)
|
||||
|
||||
if peers:
|
||||
response[key] = peers
|
||||
else:
|
||||
response[b'contacts'] = self.findNode(rpc_contact, key)
|
||||
return response
|
||||
|
||||
def _generateID(self):
|
||||
""" Generates an n-bit pseudo-random identifier
|
||||
|
||||
@return: A globally unique n-bit pseudo-random identifier
|
||||
@rtype: str
|
||||
"""
|
||||
return generate_id()
|
||||
|
||||
# from lbrynet.p2p.utils import profile_deferred
|
||||
# @profile_deferred()
|
||||
@defer.inlineCallbacks
|
||||
def _iterativeFind(self, key, startupShortlist=None, rpc='findNode', exclude=None):
|
||||
""" The basic Kademlia iterative lookup operation (for nodes/values)
|
||||
|
||||
This builds a list of k "closest" contacts through iterative use of
|
||||
the "FIND_NODE" RPC, or if C{findValue} is set to C{True}, using the
|
||||
"FIND_VALUE" RPC, in which case the value (if found) may be returned
|
||||
instead of a list of contacts
|
||||
|
||||
@param key: the n-bit key (i.e. the node or value ID) to search for
|
||||
@type key: str
|
||||
@param startupShortlist: A list of contacts to use as the starting
|
||||
shortlist for this search; this is normally
|
||||
only used when the node joins the network
|
||||
@type startupShortlist: list
|
||||
@param rpc: The name of the RPC to issue to remote nodes during the
|
||||
Kademlia lookup operation (e.g. this sets whether this
|
||||
algorithm should search for a data value (if
|
||||
rpc='findValue') or not. It can thus be used to perform
|
||||
other operations that piggy-back on the basic Kademlia
|
||||
lookup operation (Entangled's "delete" RPC, for instance).
|
||||
@type rpc: str
|
||||
|
||||
@return: If C{findValue} is C{True}, the algorithm will stop as soon
|
||||
as a data value for C{key} is found, and return a dictionary
|
||||
containing the key and the found value. Otherwise, it will
|
||||
return a list of the k closest nodes to the specified key
|
||||
@rtype: twisted.internet.defer.Deferred
|
||||
"""
|
||||
|
||||
if len(key) != constants.key_bits // 8:
|
||||
raise ValueError("invalid key length: %i" % len(key))
|
||||
|
||||
if startupShortlist is None:
|
||||
shortlist = self._routingTable.findCloseNodes(key)
|
||||
# if key != self.node_id:
|
||||
# # Update the "last accessed" timestamp for the appropriate k-bucket
|
||||
# self._routingTable.touchKBucket(key)
|
||||
if len(shortlist) == 0:
|
||||
log.warning("This node doesn't know any other nodes")
|
||||
# This node doesn't know of any other nodes
|
||||
fakeDf = defer.Deferred()
|
||||
fakeDf.callback([])
|
||||
result = yield fakeDf
|
||||
defer.returnValue(result)
|
||||
else:
|
||||
# This is used during the bootstrap process
|
||||
shortlist = startupShortlist
|
||||
|
||||
result = yield iterativeFind(self, shortlist, key, rpc, exclude=exclude)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _refreshNode(self):
|
||||
""" Periodically called to perform k-bucket refreshes and data
|
||||
replication/republishing as necessary """
|
||||
yield self._refreshRoutingTable()
|
||||
self._dataStore.removeExpiredPeers()
|
||||
self._refreshStoringPeers()
|
||||
defer.returnValue(None)
|
||||
|
||||
def _refreshContacts(self):
|
||||
self._protocol._ping_queue.enqueue_maybe_ping(*self.contacts, delay=0)
|
||||
|
||||
def _refreshStoringPeers(self):
|
||||
self._protocol._ping_queue.enqueue_maybe_ping(*self._dataStore.getStoringContacts(), delay=0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _refreshRoutingTable(self):
|
||||
nodeIDs = self._routingTable.getRefreshList(0, False)
|
||||
while nodeIDs:
|
||||
searchID = nodeIDs.pop()
|
||||
yield self.iterativeFindNode(searchID)
|
||||
defer.returnValue(None)
|
||||
async def peer_search(self, node_id: bytes, count=constants.k, max_results=constants.k*2,
|
||||
bottom_out_limit=20) -> typing.List['KademliaPeer']:
|
||||
accumulated: typing.List['KademliaPeer'] = []
|
||||
async with self.peer_search_junction(self.protocol.node_id, max_results=max_results,
|
||||
bottom_out_limit=bottom_out_limit) as junction:
|
||||
async for peers in junction:
|
||||
log.info("peer search: %s", peers)
|
||||
accumulated.extend(peers)
|
||||
log.info("junction done")
|
||||
log.info("context done")
|
||||
distance = Distance(node_id)
|
||||
accumulated.sort(key=lambda peer: distance(peer.node_id))
|
||||
return accumulated[:count]
|
||||
|
|
203
lbrynet/dht/peer.py
Normal file
203
lbrynet/dht/peer.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
import typing
|
||||
import asyncio
|
||||
import logging
|
||||
import ipaddress
|
||||
from binascii import hexlify
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.serialization.datagram import make_compact_address, make_compact_ip, decode_compact_address
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_valid_ipv4(address):
|
||||
try:
|
||||
ip = ipaddress.ip_address(address)
|
||||
return ip.version == 4
|
||||
except ipaddress.AddressValueError:
|
||||
return False
|
||||
|
||||
|
||||
class PeerManager:
|
||||
def __init__(self, loop: asyncio.BaseEventLoop):
|
||||
self._loop = loop
|
||||
self._rpc_failures: typing.Dict[
|
||||
typing.Tuple[str, int], typing.Tuple[typing.Optional[float], typing.Optional[float]]
|
||||
] = {}
|
||||
self._last_replied: typing.Dict[typing.Tuple[str, int], float] = {}
|
||||
self._last_sent: typing.Dict[typing.Tuple[str, int], float] = {}
|
||||
self._last_requested: typing.Dict[typing.Tuple[str, int], float] = {}
|
||||
self._node_id_mapping: typing.Dict[typing.Tuple[str, int], bytes] = {}
|
||||
self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
|
||||
self._node_tokens: typing.Dict[bytes, (float, bytes)] = {}
|
||||
self._kademlia_peers: typing.Dict[typing.Tuple[bytes, str, int], 'KademliaPeer']
|
||||
self._lock = asyncio.Lock(loop=loop)
|
||||
|
||||
async def report_failure(self, address: str, udp_port: int):
|
||||
now = self._loop.time()
|
||||
async with self._lock:
|
||||
_, previous = self._rpc_failures.pop((address, udp_port), (None, None))
|
||||
self._rpc_failures[(address, udp_port)] = (previous, now)
|
||||
|
||||
async def report_last_sent(self, address: str, udp_port: int):
|
||||
now = self._loop.time()
|
||||
async with self._lock:
|
||||
self._last_sent[(address, udp_port)] = now
|
||||
|
||||
async def report_last_replied(self, address: str, udp_port: int):
|
||||
now = self._loop.time()
|
||||
async with self._lock:
|
||||
self._last_replied[(address, udp_port)] = now
|
||||
|
||||
async def report_last_requested(self, address: str, udp_port: int):
|
||||
now = self._loop.time()
|
||||
async with self._lock:
|
||||
self._last_requested[(address, udp_port)] = now
|
||||
|
||||
async def clear_token(self, node_id: bytes):
|
||||
async with self._lock:
|
||||
self._node_tokens.pop(node_id, None)
|
||||
|
||||
async def update_token(self, node_id: bytes, token: bytes):
|
||||
now = self._loop.time()
|
||||
async with self._lock:
|
||||
self._node_tokens[node_id] = (now, token)
|
||||
|
||||
def get_node_token(self, node_id: bytes) -> typing.Optional[bytes]:
|
||||
ts, token = self._node_tokens.get(node_id, (None, None))
|
||||
if ts and ts > self._loop.time() - constants.token_secret_refresh_interval:
|
||||
return token
|
||||
|
||||
def get_last_replied(self, address: str, udp_port: int) -> typing.Optional[float]:
|
||||
return self._last_replied.get((address, udp_port))
|
||||
|
||||
def get_node_id(self, address: str, udp_port: int) -> typing.Optional[bytes]:
|
||||
return self._node_id_mapping.get((address, udp_port))
|
||||
|
||||
def get_node_address(self, node_id: bytes) -> typing.Optional[typing.Tuple[str, int]]:
|
||||
return self._node_id_reverse_mapping.get(node_id)
|
||||
|
||||
async def get_node_address_and_port(self, node_id: bytes) -> typing.Optional[typing.Tuple[str, int]]:
|
||||
async with self._lock:
|
||||
addr_tuple = self._node_id_reverse_mapping.get(node_id)
|
||||
if addr_tuple and addr_tuple in self._node_id_mapping:
|
||||
return addr_tuple
|
||||
|
||||
async def update_contact_triple(self, node_id: bytes, address: str, udp_port: int):
|
||||
"""
|
||||
Update the mapping of node_id -> address tuple and that of address tuple -> node_id
|
||||
This is to handle peers changing addresses and ids while assuring that the we only ever have
|
||||
one node id / address tuple mapped to each other
|
||||
"""
|
||||
async with self._lock:
|
||||
if (address, udp_port) in self._node_id_mapping:
|
||||
self._node_id_reverse_mapping.pop(self._node_id_mapping.pop((address, udp_port)))
|
||||
if node_id in self._node_id_reverse_mapping:
|
||||
self._node_id_mapping.pop(self._node_id_reverse_mapping.pop(node_id))
|
||||
self._node_id_mapping[(address, udp_port)] = node_id
|
||||
self._node_id_reverse_mapping[node_id] = (address, udp_port)
|
||||
|
||||
def get_kademlia_peer(self, node_id: bytes, address: str, udp_port: int) -> 'KademliaPeer':
|
||||
return KademliaPeer(self._loop, address, node_id, udp_port)
|
||||
|
||||
async def prune(self):
|
||||
now = self._loop.time()
|
||||
async with self._lock:
|
||||
to_pop = []
|
||||
for (address, udp_port), (_, last_failure) in self._rpc_failures.items():
|
||||
if last_failure and last_failure < now - constants.rpc_attempts_pruning_window:
|
||||
to_pop.append((address, udp_port))
|
||||
while to_pop:
|
||||
del self._rpc_failures[to_pop.pop()]
|
||||
to_pop = []
|
||||
for node_id, (age, token) in self._node_tokens.items():
|
||||
if age < now - constants.token_secret_refresh_interval:
|
||||
to_pop.append(node_id)
|
||||
while to_pop:
|
||||
del self._node_tokens[to_pop.pop()]
|
||||
|
||||
def contact_triple_is_good(self, node_id: bytes, address: str, udp_port: int):
|
||||
"""
|
||||
:return: False if peer is bad, None if peer is unknown, or True if peer is good
|
||||
"""
|
||||
|
||||
delay = self._loop.time() - constants.check_refresh_interval
|
||||
|
||||
if node_id not in self._node_id_reverse_mapping or (address, udp_port) not in self._node_id_mapping:
|
||||
return
|
||||
addr_tup = (address, udp_port)
|
||||
if self._node_id_reverse_mapping[node_id] != addr_tup or self._node_id_mapping[addr_tup] != node_id:
|
||||
return
|
||||
previous_failure, most_recent_failure = self._rpc_failures.get((address, udp_port), (None, None))
|
||||
last_requested = self._last_requested.get((address, udp_port))
|
||||
last_replied = self._last_replied.get((address, udp_port))
|
||||
if most_recent_failure and last_replied:
|
||||
if delay < last_replied > most_recent_failure:
|
||||
return True
|
||||
elif last_replied > most_recent_failure:
|
||||
return
|
||||
return False
|
||||
elif previous_failure and most_recent_failure and most_recent_failure > delay:
|
||||
return False
|
||||
elif last_replied and last_replied > delay:
|
||||
return True
|
||||
elif last_requested and last_requested > delay:
|
||||
return None
|
||||
return
|
||||
|
||||
def peer_is_good(self, peer: 'KademliaPeer'):
|
||||
return self.contact_triple_is_good(peer.node_id, peer.address, peer.udp_port)
|
||||
|
||||
def decode_tcp_peer_from_compact_address(self, compact_address: bytes) -> 'KademliaPeer':
|
||||
node_id, address, tcp_port = decode_compact_address(compact_address)
|
||||
return KademliaPeer(self._loop, address, node_id, tcp_port=tcp_port)
|
||||
|
||||
|
||||
class KademliaPeer:
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, address: str, node_id: typing.Optional[bytes] = None,
|
||||
udp_port: typing.Optional[int] = None, tcp_port: typing.Optional[int] = None):
|
||||
if node_id is not None:
|
||||
if not len(node_id) == constants.hash_length:
|
||||
raise ValueError("invalid node_id: {}".format(hexlify(node_id).decode()))
|
||||
if udp_port is not None and not 0 <= udp_port <= 65536:
|
||||
raise ValueError("invalid udp port")
|
||||
if tcp_port and not 0 <= tcp_port <= 65536:
|
||||
raise ValueError("invalid tcp port")
|
||||
if not is_valid_ipv4(address):
|
||||
raise ValueError("invalid ip address")
|
||||
self.loop = loop
|
||||
self._node_id = node_id
|
||||
self.address = address
|
||||
self.udp_port = udp_port
|
||||
self.tcp_port = tcp_port
|
||||
self.protocol_version = 1
|
||||
|
||||
def update_tcp_port(self, tcp_port: int):
|
||||
self.tcp_port = tcp_port
|
||||
|
||||
def update_udp_port(self, udp_port: int):
|
||||
self.udp_port = udp_port
|
||||
|
||||
def set_id(self, node_id):
|
||||
if not self._node_id:
|
||||
self._node_id = node_id
|
||||
|
||||
@property
|
||||
def node_id(self) -> bytes:
|
||||
return self._node_id
|
||||
|
||||
def compact_address_udp(self) -> bytearray:
|
||||
return make_compact_address(self.node_id, self.address, self.udp_port)
|
||||
|
||||
def compact_address_tcp(self) -> bytearray:
|
||||
return make_compact_address(self.node_id, self.address, self.tcp_port)
|
||||
|
||||
def compact_ip(self):
|
||||
return make_compact_ip(self.address)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, KademliaPeer):
|
||||
raise TypeError("invalid type to compare with Peer: %s" % str(type(other)))
|
||||
return (self.node_id, self.address, self.udp_port) == (other.node_id, other.address, other.udp_port)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.node_id, self.address, self.udp_port))
|
|
@ -1,493 +0,0 @@
|
|||
import logging
|
||||
import errno
|
||||
from binascii import hexlify
|
||||
|
||||
from twisted.internet import protocol, defer
|
||||
from lbrynet.dht import constants, encoding, msgformat, msgtypes
|
||||
from lbrynet.dht.error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError, TransportNotConnected
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PingQueue:
|
||||
"""
|
||||
Schedules a 15 minute delayed ping after a new node sends us a query. This is so the new node gets added to the
|
||||
routing table after having been given enough time for a pinhole to expire.
|
||||
"""
|
||||
|
||||
def __init__(self, node):
|
||||
self._node = node
|
||||
self._enqueued_contacts = {}
|
||||
self._pending_contacts = {}
|
||||
self._process_lc = node.get_looping_call(self._process)
|
||||
|
||||
def enqueue_maybe_ping(self, *contacts, **kwargs):
|
||||
delay = kwargs.get('delay', constants.checkRefreshInterval)
|
||||
no_op = (defer.succeed(None), lambda: None)
|
||||
for contact in contacts:
|
||||
if delay and contact not in self._enqueued_contacts:
|
||||
self._pending_contacts.setdefault(contact, self._node.clock.seconds() + delay)
|
||||
else:
|
||||
self._enqueued_contacts.setdefault(contact, no_op)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _ping(self, contact):
|
||||
if contact.contact_is_good:
|
||||
return
|
||||
try:
|
||||
yield contact.ping()
|
||||
except TimeoutError:
|
||||
pass
|
||||
except Exception as err:
|
||||
log.warning("unexpected error: %s", err)
|
||||
finally:
|
||||
if contact in self._enqueued_contacts:
|
||||
del self._enqueued_contacts[contact]
|
||||
|
||||
def _process(self):
|
||||
# move contacts that are scheduled to join the queue
|
||||
if self._pending_contacts:
|
||||
now = self._node.clock.seconds()
|
||||
for contact in [contact for contact, schedule in self._pending_contacts.items() if schedule <= now]:
|
||||
del self._pending_contacts[contact]
|
||||
self._enqueued_contacts.setdefault(contact, (defer.succeed(None), lambda: None))
|
||||
# spread pings across 60 seconds to avoid flood and/or false negatives
|
||||
step = 60.0/float(len(self._enqueued_contacts)) if self._enqueued_contacts else 0
|
||||
for index, (contact, (call, _)) in enumerate(self._enqueued_contacts.items()):
|
||||
if call.called and not contact.contact_is_good:
|
||||
self._enqueued_contacts[contact] = self._node.reactor_callLater(index*step, self._ping, contact)
|
||||
|
||||
def start(self):
|
||||
return self._node.safe_start_looping_call(self._process_lc, 60)
|
||||
|
||||
def stop(self):
|
||||
map(None, (cancel() for _, (call, cancel) in self._enqueued_contacts.items() if not call.called))
|
||||
return self._node.safe_stop_looping_call(self._process_lc)
|
||||
|
||||
|
||||
class KademliaProtocol(protocol.DatagramProtocol):
|
||||
""" Implements all low-level network-related functions of a Kademlia node """
|
||||
|
||||
msgSizeLimit = constants.udpDatagramMaxSize - 26
|
||||
|
||||
def __init__(self, node):
|
||||
self._node = node
|
||||
self._translator = msgformat.DefaultFormat()
|
||||
self._sentMessages = {}
|
||||
self._partialMessages = {}
|
||||
self._partialMessagesProgress = {}
|
||||
self._listening = defer.Deferred(None)
|
||||
self._ping_queue = PingQueue(self._node)
|
||||
self._protocolVersion = constants.protocolVersion
|
||||
self.started_listening_time = 0
|
||||
|
||||
def _migrate_incoming_rpc_args(self, contact, method, *args):
|
||||
if method == b'store' and contact.protocolVersion == 0:
|
||||
if isinstance(args[1], dict):
|
||||
blob_hash = args[0]
|
||||
token = args[1].pop(b'token', None)
|
||||
port = args[1].pop(b'port', -1)
|
||||
originalPublisherID = args[1].pop(b'lbryid', None)
|
||||
age = 0
|
||||
return (blob_hash, token, port, originalPublisherID, age), {}
|
||||
return args, {}
|
||||
|
||||
def _migrate_outgoing_rpc_args(self, contact, method, *args):
|
||||
"""
|
||||
This will reformat protocol version 0 arguments for the store function and will add the
|
||||
protocol version keyword argument to calls to contacts who will accept it
|
||||
"""
|
||||
if contact.protocolVersion == 0:
|
||||
if method == b'store':
|
||||
blob_hash, token, port, originalPublisherID, age = args
|
||||
args = (
|
||||
blob_hash, {
|
||||
b'token': token,
|
||||
b'port': port,
|
||||
b'lbryid': originalPublisherID
|
||||
}, originalPublisherID, False
|
||||
)
|
||||
return args
|
||||
return args
|
||||
if args and isinstance(args[-1], dict):
|
||||
args[-1][b'protocolVersion'] = self._protocolVersion
|
||||
return args
|
||||
return args + ({b'protocolVersion': self._protocolVersion},)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def sendRPC(self, contact, method, args):
|
||||
for _ in range(constants.rpcAttempts):
|
||||
try:
|
||||
response = yield self._sendRPC(contact, method, args)
|
||||
return response
|
||||
except TimeoutError:
|
||||
if contact.contact_is_good:
|
||||
log.debug("RETRY %s ON %s", method, contact)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
def _sendRPC(self, contact, method, args):
|
||||
"""
|
||||
Sends an RPC to the specified contact
|
||||
|
||||
@param contact: The contact (remote node) to send the RPC to
|
||||
@type contact: kademlia.contacts.Contact
|
||||
@param method: The name of remote method to invoke
|
||||
@type method: str
|
||||
@param args: A list of (non-keyword) arguments to pass to the remote
|
||||
method, in the correct order
|
||||
@type args: tuple
|
||||
|
||||
@return: This immediately returns a deferred object, which will return
|
||||
the result of the RPC call, or raise the relevant exception
|
||||
if the remote node raised one. If C{rawResponse} is set to
|
||||
C{True}, however, it will always return the actual response
|
||||
message (which may be a C{ResponseMessage} or an
|
||||
C{ErrorMessage}).
|
||||
@rtype: twisted.internet.defer.Deferred
|
||||
"""
|
||||
msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method,
|
||||
*args))
|
||||
msgPrimitive = self._translator.toPrimitive(msg)
|
||||
encodedMsg = encoding.bencode(msgPrimitive)
|
||||
|
||||
if args:
|
||||
log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
|
||||
hexlify(args[0]), contact.address, contact.port)
|
||||
else:
|
||||
log.debug("%s:%i SEND CALL %s TO %s:%i", self._node.externalIP, self._node.port, method,
|
||||
contact.address, contact.port)
|
||||
|
||||
df = defer.Deferred()
|
||||
|
||||
def _remove_contact(failure): # remove the contact from the routing table and track the failure
|
||||
contact.update_last_failed()
|
||||
try:
|
||||
if not contact.contact_is_good:
|
||||
self._node.removeContact(contact)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
return failure
|
||||
|
||||
def _update_contact(result): # refresh the contact in the routing table
|
||||
contact.update_last_replied()
|
||||
if method == b'findValue':
|
||||
if b'token' in result:
|
||||
contact.update_token(result[b'token'])
|
||||
if b'protocolVersion' not in result:
|
||||
contact.update_protocol_version(0)
|
||||
else:
|
||||
contact.update_protocol_version(result.pop(b'protocolVersion'))
|
||||
d = self._node.addContact(contact)
|
||||
d.addCallback(lambda _: result)
|
||||
return d
|
||||
|
||||
df.addCallbacks(_update_contact, _remove_contact)
|
||||
|
||||
# Set the RPC timeout timer
|
||||
timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
|
||||
|
||||
# Transmit the data
|
||||
self._send(encodedMsg, msg.id, (contact.address, contact.port))
|
||||
self._sentMessages[msg.id] = (contact, df, timeoutCall, cancelTimeout, method, args)
|
||||
|
||||
df.addErrback(cancelTimeout)
|
||||
return df
|
||||
|
||||
def startProtocol(self):
|
||||
log.info("DHT listening on UDP %i (ext port %i)", self._node.port, self._node.externalUDPPort)
|
||||
if self._listening.called:
|
||||
self._listening = defer.Deferred()
|
||||
self._listening.callback(True)
|
||||
self.started_listening_time = self._node.clock.seconds()
|
||||
return self._ping_queue.start()
|
||||
|
||||
def datagramReceived(self, datagram, address):
|
||||
""" Handles and parses incoming RPC messages (and responses)
|
||||
|
||||
@note: This is automatically called by Twisted when the protocol
|
||||
receives a UDP datagram
|
||||
"""
|
||||
if chr(datagram[0]) == '\x00' and chr(datagram[25]) == '\x00':
|
||||
totalPackets = (datagram[1] << 8) | datagram[2]
|
||||
msgID = datagram[5:25]
|
||||
seqNumber = (datagram[3] << 8) | datagram[4]
|
||||
if msgID not in self._partialMessages:
|
||||
self._partialMessages[msgID] = {}
|
||||
self._partialMessages[msgID][seqNumber] = datagram[26:]
|
||||
if len(self._partialMessages[msgID]) == totalPackets:
|
||||
keys = self._partialMessages[msgID].keys()
|
||||
keys.sort()
|
||||
data = b''
|
||||
for key in keys:
|
||||
data += self._partialMessages[msgID][key]
|
||||
datagram = data
|
||||
del self._partialMessages[msgID]
|
||||
else:
|
||||
return
|
||||
try:
|
||||
msgPrimitive = encoding.bdecode(datagram)
|
||||
message = self._translator.fromPrimitive(msgPrimitive)
|
||||
except (encoding.DecodeError, ValueError) as err:
|
||||
# We received some rubbish here
|
||||
log.warning("Error decoding datagram %s from %s:%i - %s", hexlify(datagram),
|
||||
address[0], address[1], err)
|
||||
return
|
||||
except (IndexError, KeyError):
|
||||
log.warning("Couldn't decode dht datagram from %s", address)
|
||||
return
|
||||
|
||||
if isinstance(message, msgtypes.RequestMessage):
|
||||
# This is an RPC method request
|
||||
remoteContact = self._node.contact_manager.make_contact(message.nodeID, address[0], address[1], self)
|
||||
remoteContact.update_last_requested()
|
||||
# only add a requesting contact to the routing table if it has replied to one of our requests
|
||||
if remoteContact.contact_is_good is True:
|
||||
df = self._node.addContact(remoteContact)
|
||||
else:
|
||||
df = defer.succeed(None)
|
||||
df.addCallback(lambda _: self._handleRPC(remoteContact, message.id, message.request, message.args))
|
||||
# if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
|
||||
# will be added to our routing table if successful
|
||||
if remoteContact.contact_is_good is None and remoteContact.lastReplied is None:
|
||||
df.addCallback(lambda _: self._ping_queue.enqueue_maybe_ping(remoteContact))
|
||||
elif isinstance(message, msgtypes.ErrorMessage):
|
||||
# The RPC request raised a remote exception; raise it locally
|
||||
if message.exceptionType in BUILTIN_EXCEPTIONS:
|
||||
exception_type = BUILTIN_EXCEPTIONS[message.exceptionType]
|
||||
else:
|
||||
exception_type = UnknownRemoteException
|
||||
remoteException = exception_type(message.response)
|
||||
log.error("DHT RECV REMOTE EXCEPTION FROM %s:%i: %s", address[0],
|
||||
address[1], remoteException)
|
||||
if message.id in self._sentMessages:
|
||||
# Cancel timeout timer for this RPC
|
||||
remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[message.id][0:5]
|
||||
timeoutCanceller()
|
||||
del self._sentMessages[message.id]
|
||||
|
||||
# reject replies coming from a different address than what we sent our request to
|
||||
if (remoteContact.address, remoteContact.port) != address:
|
||||
log.warning("Sent request to node %s at %s:%i, got reply from %s:%i",
|
||||
remoteContact.log_id(), remoteContact.address,
|
||||
remoteContact.port, address[0], address[1])
|
||||
df.errback(TimeoutError(remoteContact.id))
|
||||
return
|
||||
|
||||
# this error is returned by nodes that can be contacted but have an old
|
||||
# and broken version of the ping command, if they return it the node can
|
||||
# be contacted, so we'll treat it as a successful ping
|
||||
old_ping_error = "ping() got an unexpected keyword argument '_rpcNodeContact'"
|
||||
if isinstance(remoteException, TypeError) and \
|
||||
remoteException.message == old_ping_error:
|
||||
log.debug("old pong error")
|
||||
df.callback('pong')
|
||||
else:
|
||||
df.errback(remoteException)
|
||||
elif isinstance(message, msgtypes.ResponseMessage):
|
||||
# Find the message that triggered this response
|
||||
if message.id in self._sentMessages:
|
||||
# Cancel timeout timer for this RPC
|
||||
remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[message.id][0:5]
|
||||
timeoutCanceller()
|
||||
del self._sentMessages[message.id]
|
||||
log.debug("%s:%i RECV response to %s from %s:%i", self._node.externalIP, self._node.port,
|
||||
method, remoteContact.address, remoteContact.port)
|
||||
|
||||
# When joining the network we made Contact objects for the seed nodes with node ids set to None
|
||||
# Thus, the sent_to_id will also be None, and the contact objects need the ids to be manually set.
|
||||
# These replies have be distinguished from those where the node id in the datagram does not match
|
||||
# the node id of the node we sent a message to (these messages are treated as an error)
|
||||
if remoteContact.id and remoteContact.id != message.nodeID: # sent_to_id will be None for bootstrap
|
||||
log.debug("mismatch: (%s) %s:%i (%s vs %s)", method, remoteContact.address, remoteContact.port,
|
||||
remoteContact.log_id(False), hexlify(message.nodeID))
|
||||
df.errback(TimeoutError(remoteContact.id))
|
||||
return
|
||||
elif not remoteContact.id:
|
||||
remoteContact.set_id(message.nodeID)
|
||||
|
||||
# We got a result from the RPC
|
||||
df.callback(message.response)
|
||||
else:
|
||||
# If the original message isn't found, it must have timed out
|
||||
# TODO: we should probably do something with this...
|
||||
pass
|
||||
|
||||
def _send(self, data, rpcID, address):
|
||||
""" Transmit the specified data over UDP, breaking it up into several
|
||||
packets if necessary
|
||||
|
||||
If the data is spread over multiple UDP datagrams, the packets have the
|
||||
following structure::
|
||||
| | | | | |||||||||||| 0x00 |
|
||||
|Transmission|Total number|Sequence number| RPC ID |Header end|
|
||||
| type ID | of packets |of this packet | | indicator|
|
||||
| (1 byte) | (2 bytes) | (2 bytes) |(20 bytes)| (1 byte) |
|
||||
| | | | | |||||||||||| |
|
||||
|
||||
@note: The header used for breaking up large data segments will
|
||||
possibly be moved out of the KademliaProtocol class in the
|
||||
future, into something similar to a message translator/encoder
|
||||
class (see C{kademlia.msgformat} and C{kademlia.encoding}).
|
||||
"""
|
||||
|
||||
if len(data) > self.msgSizeLimit:
|
||||
# We have to spread the data over multiple UDP datagrams,
|
||||
# and provide sequencing information
|
||||
#
|
||||
# 1st byte is transmission type id, bytes 2 & 3 are the
|
||||
# total number of packets in this transmission, bytes 4 &
|
||||
# 5 are the sequence number for this specific packet
|
||||
totalPackets = len(data) // self.msgSizeLimit
|
||||
if len(data) % self.msgSizeLimit > 0:
|
||||
totalPackets += 1
|
||||
encTotalPackets = chr(totalPackets >> 8) + chr(totalPackets & 0xff)
|
||||
seqNumber = 0
|
||||
startPos = 0
|
||||
while seqNumber < totalPackets:
|
||||
packetData = data[startPos:startPos + self.msgSizeLimit]
|
||||
encSeqNumber = chr(seqNumber >> 8) + chr(seqNumber & 0xff)
|
||||
txData = f'\x00{encTotalPackets}{encSeqNumber}{rpcID}\x00{packetData}'
|
||||
self._scheduleSendNext(txData, address)
|
||||
|
||||
startPos += self.msgSizeLimit
|
||||
seqNumber += 1
|
||||
else:
|
||||
self._scheduleSendNext(data, address)
|
||||
|
||||
def _scheduleSendNext(self, txData, address):
|
||||
"""Schedule the sending of the next UDP packet """
|
||||
delayed_call, _ = self._node.reactor_callSoon(self._write, txData, address)
|
||||
|
||||
def _write(self, txData, address):
|
||||
if self.transport:
|
||||
try:
|
||||
self.transport.write(txData, address)
|
||||
except OSError as err:
|
||||
if err.errno == errno.EWOULDBLOCK:
|
||||
# i'm scared this may swallow important errors, but i get a million of these
|
||||
# on Linux and it doesn't seem to affect anything -grin
|
||||
log.warning("Can't send data to dht: EWOULDBLOCK")
|
||||
elif err.errno == errno.ENETUNREACH:
|
||||
# this should probably try to retransmit when the network connection is back
|
||||
log.error("Network is unreachable")
|
||||
else:
|
||||
log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)",
|
||||
len(txData), address[0], address[1], err.message, err.errno)
|
||||
raise err
|
||||
else:
|
||||
raise TransportNotConnected()
|
||||
|
||||
def _sendResponse(self, contact, rpcID, response):
|
||||
""" Send a RPC response to the specified contact
|
||||
"""
|
||||
msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response)
|
||||
msgPrimitive = self._translator.toPrimitive(msg)
|
||||
encodedMsg = encoding.bencode(msgPrimitive)
|
||||
self._send(encodedMsg, rpcID, (contact.address, contact.port))
|
||||
|
||||
def _sendError(self, contact, rpcID, exceptionType, exceptionMessage):
|
||||
""" Send an RPC error message to the specified contact
|
||||
"""
|
||||
exceptionMessage = exceptionMessage.encode()
|
||||
msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage)
|
||||
msgPrimitive = self._translator.toPrimitive(msg)
|
||||
encodedMsg = encoding.bencode(msgPrimitive)
|
||||
self._send(encodedMsg, rpcID, (contact.address, contact.port))
|
||||
|
||||
def _handleRPC(self, senderContact, rpcID, method, args):
|
||||
""" Executes a local function in response to an RPC request """
|
||||
|
||||
# Set up the deferred callchain
|
||||
def handleError(f):
|
||||
self._sendError(senderContact, rpcID, f.type, f.getErrorMessage())
|
||||
|
||||
def handleResult(result):
|
||||
self._sendResponse(senderContact, rpcID, result)
|
||||
|
||||
df = defer.Deferred()
|
||||
df.addCallback(handleResult)
|
||||
df.addErrback(handleError)
|
||||
|
||||
# Execute the RPC
|
||||
func = getattr(self._node, method.decode(), None)
|
||||
if callable(func) and hasattr(func, "rpcmethod"):
|
||||
# Call the exposed Node method and return the result to the deferred callback chain
|
||||
# if args:
|
||||
# log.debug("%s:%i RECV CALL %s(%s) %s:%i", self._node.externalIP, self._node.port, method,
|
||||
# args[0].encode('hex'), senderContact.address, senderContact.port)
|
||||
# else:
|
||||
log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method,
|
||||
senderContact.address, senderContact.port)
|
||||
if args and isinstance(args[-1], dict) and b'protocolVersion' in args[-1]: # args don't need reformatting
|
||||
senderContact.update_protocol_version(int(args[-1].pop(b'protocolVersion')))
|
||||
a, kw = tuple(args[:-1]), args[-1]
|
||||
else:
|
||||
senderContact.update_protocol_version(0)
|
||||
a, kw = self._migrate_incoming_rpc_args(senderContact, method, *args)
|
||||
try:
|
||||
if method != b'ping':
|
||||
result = func(senderContact, *a)
|
||||
else:
|
||||
result = func()
|
||||
except Exception as e:
|
||||
log.error("error handling request for %s:%i %s", senderContact.address, senderContact.port, method)
|
||||
df.errback(e)
|
||||
else:
|
||||
df.callback(result)
|
||||
else:
|
||||
# No such exposed method
|
||||
df.errback(AttributeError('Invalid method: %s' % method))
|
||||
return df
|
||||
|
||||
def _msgTimeout(self, messageID):
|
||||
""" Called when an RPC request message times out """
|
||||
# Find the message that timed out
|
||||
if messageID not in self._sentMessages:
|
||||
# This should never be reached
|
||||
log.error("deferred timed out, but is not present in sent messages list!")
|
||||
return
|
||||
remoteContact, df, timeout_call, timeout_canceller, method, args = self._sentMessages[messageID]
|
||||
if messageID in self._partialMessages:
|
||||
# We are still receiving this message
|
||||
self._msgTimeoutInProgress(messageID, timeout_canceller, remoteContact, df, method, args)
|
||||
return
|
||||
del self._sentMessages[messageID]
|
||||
# The message's destination node is now considered to be dead;
|
||||
# raise an (asynchronous) TimeoutError exception and update the host node
|
||||
df.errback(TimeoutError(remoteContact.id))
|
||||
|
||||
def _msgTimeoutInProgress(self, messageID, timeoutCanceller, remoteContact, df, method, args):
|
||||
# See if any progress has been made; if not, kill the message
|
||||
if self._hasProgressBeenMade(messageID):
|
||||
# Reset the RPC timeout timer
|
||||
timeoutCanceller()
|
||||
timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID)
|
||||
self._sentMessages[messageID] = (remoteContact, df, timeoutCall, cancelTimeout, method, args)
|
||||
else:
|
||||
# No progress has been made
|
||||
if messageID in self._partialMessagesProgress:
|
||||
del self._partialMessagesProgress[messageID]
|
||||
if messageID in self._partialMessages:
|
||||
del self._partialMessages[messageID]
|
||||
df.errback(TimeoutError(remoteContact.id))
|
||||
|
||||
def _hasProgressBeenMade(self, messageID):
|
||||
return (
|
||||
messageID in self._partialMessagesProgress and
|
||||
(
|
||||
len(self._partialMessagesProgress[messageID]) !=
|
||||
len(self._partialMessages[messageID])
|
||||
)
|
||||
)
|
||||
|
||||
def stopProtocol(self):
|
||||
""" Called when the transport is disconnected.
|
||||
|
||||
Will only be called once, after all ports are disconnected.
|
||||
"""
|
||||
log.info('Stopping DHT')
|
||||
self._ping_queue.stop()
|
||||
self._node.call_later_manager.stop()
|
||||
log.info('DHT stopped')
|
0
lbrynet/dht/protocol/__init__.py
Normal file
0
lbrynet/dht/protocol/__init__.py
Normal file
105
lbrynet/dht/protocol/async_generator_junction.py
Normal file
105
lbrynet/dht/protocol/async_generator_junction.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
import asyncio
|
||||
import typing
|
||||
import logging
|
||||
import traceback
|
||||
if typing.TYPE_CHECKING:
|
||||
from types import AsyncGeneratorType
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cancel_task(task: typing.Optional[asyncio.Task]):
|
||||
if task and not (task.done() or task.cancelled()):
|
||||
task.cancel()
|
||||
|
||||
|
||||
def cancel_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
|
||||
for task in tasks:
|
||||
cancel_task(task)
|
||||
|
||||
|
||||
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
|
||||
while tasks:
|
||||
cancel_task(tasks.pop())
|
||||
|
||||
|
||||
class AsyncGeneratorJunction:
|
||||
"""
|
||||
A helper to interleave the results from multiple async generators into one
|
||||
async generator.
|
||||
"""
|
||||
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None):
|
||||
self.loop = loop
|
||||
self.result_queue = queue or asyncio.Queue(loop=loop)
|
||||
self.tasks: typing.List[asyncio.Task] = []
|
||||
self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {}
|
||||
self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop)
|
||||
self.can_iterate = asyncio.Event(loop=self.loop)
|
||||
self.finished = asyncio.Event(loop=self.loop)
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return any(self.running_iterators.values())
|
||||
|
||||
async def wait_for_generators(self):
|
||||
async def iterate(iterator: typing.AsyncGenerator):
|
||||
try:
|
||||
async for item in iterator:
|
||||
self.result_queue.put_nowait(item)
|
||||
finally:
|
||||
self.running_iterators[iterator] = False
|
||||
|
||||
while True:
|
||||
async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get()
|
||||
self.running_iterators[async_gen] = True
|
||||
self.tasks.append(self.loop.create_task(iterate(async_gen)))
|
||||
if not self.can_iterate.is_set():
|
||||
self.can_iterate.set()
|
||||
|
||||
def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']):
|
||||
"""
|
||||
Add an async generator. This can be called during an iteration of the generator junction.
|
||||
"""
|
||||
self.generator_queue.put_nowait(async_gen)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self.can_iterate.is_set():
|
||||
await self.can_iterate.wait()
|
||||
if not self.running:
|
||||
raise StopAsyncIteration()
|
||||
try:
|
||||
return await self.result_queue.get()
|
||||
finally:
|
||||
self.awaiting = None
|
||||
|
||||
def aclose(self):
|
||||
async def _aclose():
|
||||
for iterator in list(self.running_iterators.keys()):
|
||||
result = iterator.aclose()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
self.running_iterators[iterator] = False
|
||||
drain_tasks(self.tasks)
|
||||
raise StopAsyncIteration()
|
||||
if not self.finished.is_set():
|
||||
self.finished.set()
|
||||
return self.loop.create_task(_aclose())
|
||||
|
||||
async def __aenter__(self):
|
||||
self.tasks.append(self.loop.create_task(self.wait_for_generators()))
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
try:
|
||||
await self.aclose()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
finally:
|
||||
if exc_type:
|
||||
if exc_type not in (asyncio.CancelledError, asyncio.TimeoutError, StopAsyncIteration):
|
||||
err = traceback.format_exception(exc_type, exc, tb)
|
||||
log.error(err)
|
76
lbrynet/dht/protocol/data_store.py
Normal file
76
lbrynet/dht/protocol/data_store.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
import asyncio
|
||||
import typing
|
||||
|
||||
from lbrynet.dht import constants
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.dht.peer import KademliaPeer, PeerManager
|
||||
|
||||
|
||||
class DictDataStore:
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager'):
|
||||
# Dictionary format:
|
||||
# { <key>: [<contact>, <value>, <lastPublished>, <originallyPublished> <original_publisher_id>] }
|
||||
self._data_store: typing.Dict[bytes,
|
||||
typing.List[typing.Tuple['KademliaPeer', bytes, float, float, bytes]]] = {}
|
||||
self._get_time = loop.time
|
||||
self._peer_manager = peer_manager
|
||||
self.completed_blobs: typing.Set[str] = set()
|
||||
|
||||
def filter_bad_and_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']:
|
||||
"""
|
||||
Returns only non-expired and unknown/good peers
|
||||
"""
|
||||
peers = []
|
||||
for peer in map(lambda p: p[0],
|
||||
filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration,
|
||||
self._data_store[key])):
|
||||
if self._peer_manager.peer_is_good(peer) is not False:
|
||||
peers.append(peer)
|
||||
return peers
|
||||
|
||||
def filter_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']:
|
||||
"""
|
||||
Returns only non-expired peers
|
||||
"""
|
||||
return list(
|
||||
map(
|
||||
lambda p: p[0],
|
||||
filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration, self._data_store[key])
|
||||
)
|
||||
)
|
||||
|
||||
def removed_expired_peers(self):
|
||||
expired_keys = []
|
||||
for key in self._data_store.keys():
|
||||
unexpired_peers = self.filter_expired_peers(key)
|
||||
if not unexpired_peers:
|
||||
expired_keys.append(key)
|
||||
else:
|
||||
self._data_store[key] = [x for x in self._data_store[key] if x[0] in unexpired_peers]
|
||||
for key in expired_keys:
|
||||
del self._data_store[key]
|
||||
|
||||
def has_peers_for_blob(self, key: bytes) -> bool:
|
||||
return key in self._data_store and len(self.filter_bad_and_expired_peers(key)) > 0
|
||||
|
||||
def add_peer_to_blob(self, contact: 'KademliaPeer', key: bytes, compact_address: bytes, last_published: int,
|
||||
originally_published: int, original_publisher_id: bytes) -> None:
|
||||
if key in self._data_store:
|
||||
if compact_address not in map(lambda store_tuple: store_tuple[1], self._data_store[key]):
|
||||
self._data_store[key].append(
|
||||
(contact, compact_address, last_published, originally_published, original_publisher_id)
|
||||
)
|
||||
else:
|
||||
self._data_store[key] = [(contact, compact_address, last_published, originally_published,
|
||||
original_publisher_id)]
|
||||
|
||||
def get_peers_for_blob(self, key: bytes) -> typing.List['KademliaPeer']:
|
||||
return [] if key not in self._data_store else [peer for peer in self.filter_bad_and_expired_peers(key)]
|
||||
|
||||
def get_storing_contacts(self) -> typing.List['KademliaPeer']:
|
||||
peers = set()
|
||||
for key in self._data_store:
|
||||
for values in self._data_store[key]:
|
||||
if values[0] not in peers:
|
||||
peers.add(values[0])
|
||||
return list(peers)
|
|
@ -8,20 +8,16 @@ class Distance:
|
|||
we pre-calculate the value of that point.
|
||||
"""
|
||||
|
||||
def __init__(self, key):
|
||||
if len(key) != constants.key_bits // 8:
|
||||
def __init__(self, key: bytes):
|
||||
if len(key) != constants.hash_length:
|
||||
raise ValueError("invalid key length: %i" % len(key))
|
||||
self.key = key
|
||||
self.val_key_one = int.from_bytes(key, 'big')
|
||||
|
||||
def __call__(self, key_two):
|
||||
def __call__(self, key_two: bytes) -> int:
|
||||
val_key_two = int.from_bytes(key_two, 'big')
|
||||
return self.val_key_one ^ val_key_two
|
||||
|
||||
def is_closer(self, a, b):
|
||||
def is_closer(self, a: bytes, b: bytes) -> bool:
|
||||
"""Returns true is `a` is closer to `key` than `b` is"""
|
||||
return self(a) < self(b)
|
||||
|
||||
def to_contact(self, contact):
|
||||
"""A convenience function for calculating the distance to a contact"""
|
||||
return self(contact.id)
|
381
lbrynet/dht/protocol/iterative_find.py
Normal file
381
lbrynet/dht/protocol/iterative_find.py
Normal file
|
@ -0,0 +1,381 @@
|
|||
import asyncio
|
||||
import typing
|
||||
import logging
|
||||
from lbrynet.utils import drain_tasks
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.error import RemoteException
|
||||
from lbrynet.dht.protocol.distance import Distance
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from lbrynet.dht.protocol.routing_table import TreeRoutingTable
|
||||
from lbrynet.dht.protocol.protocol import KademliaProtocol
|
||||
from lbrynet.dht.peer import PeerManager, KademliaPeer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindResponse:
|
||||
@property
|
||||
def found(self) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_close_triples(self) -> typing.List[typing.Tuple[bytes, str, int]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class FindNodeResponse(FindResponse):
|
||||
def __init__(self, key: bytes, close_triples: typing.List[typing.Tuple[bytes, str, int]]):
|
||||
self.key = key
|
||||
self.close_triples = close_triples
|
||||
|
||||
@property
|
||||
def found(self) -> bool:
|
||||
return self.key in [triple[0] for triple in self.close_triples]
|
||||
|
||||
def get_close_triples(self) -> typing.List[typing.Tuple[bytes, str, int]]:
|
||||
return self.close_triples
|
||||
|
||||
|
||||
class FindValueResponse(FindResponse):
|
||||
def __init__(self, key: bytes, result_dict: typing.Dict):
|
||||
self.key = key
|
||||
self.token = result_dict[b'token']
|
||||
self.close_triples: typing.List[typing.Tuple[bytes, bytes, int]] = result_dict.get(b'contacts', [])
|
||||
self.found_compact_addresses = result_dict.get(key, [])
|
||||
|
||||
@property
|
||||
def found(self) -> bool:
|
||||
return len(self.found_compact_addresses) > 0
|
||||
|
||||
def get_close_triples(self) -> typing.List[typing.Tuple[bytes, str, int]]:
|
||||
return [(node_id, address.decode(), port) for node_id, address, port in self.close_triples]
|
||||
|
||||
|
||||
def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
|
||||
shortlist: typing.Optional[typing.List['KademliaPeer']]) -> typing.List['KademliaPeer']:
|
||||
"""
|
||||
If not provided, initialize the shortlist of peers to probe to the (up to) k closest peers in the routing table
|
||||
|
||||
:param routing_table: a TreeRoutingTable
|
||||
:param key: a 48 byte hash
|
||||
:param shortlist: optional manually provided shortlist, this is done during bootstrapping when there are no
|
||||
peers in the routing table. During bootstrap the shortlist is set to be the seed nodes.
|
||||
"""
|
||||
if len(key) != constants.hash_length:
|
||||
raise ValueError("invalid key length: %i" % len(key))
|
||||
if not shortlist:
|
||||
shortlist = routing_table.find_close_peers(key)
|
||||
distance = Distance(key)
|
||||
shortlist.sort(key=lambda peer: distance(peer.node_id), reverse=True)
|
||||
return shortlist
|
||||
|
||||
|
||||
class IterativeFinder:
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager',
|
||||
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
|
||||
bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.k,
|
||||
exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
|
||||
shortlist: typing.Optional[typing.List['KademliaPeer']] = None):
|
||||
if len(key) != constants.hash_length:
|
||||
raise ValueError("invalid key length: %i" % len(key))
|
||||
self.loop = loop
|
||||
self.peer_manager = peer_manager
|
||||
self.routing_table = routing_table
|
||||
self.protocol = protocol
|
||||
|
||||
self.key = key
|
||||
self.bottom_out_limit = bottom_out_limit
|
||||
self.max_results = max_results
|
||||
self.exclude = exclude or []
|
||||
|
||||
self.shortlist: typing.List['KademliaPeer'] = get_shortlist(routing_table, key, shortlist)
|
||||
self.active: typing.List['KademliaPeer'] = []
|
||||
self.contacted: typing.List[typing.Tuple[str, int]] = []
|
||||
self.distance = Distance(key)
|
||||
|
||||
self.closest_peer: typing.Optional['KademliaPeer'] = None if not self.shortlist else self.shortlist[0]
|
||||
self.prev_closest_peer: typing.Optional['KademliaPeer'] = None
|
||||
|
||||
self.iteration_queue = asyncio.Queue(loop=self.loop)
|
||||
|
||||
self.running_probes: typing.List[asyncio.Task] = []
|
||||
self.lock = asyncio.Lock(loop=self.loop)
|
||||
self.iteration_count = 0
|
||||
self.bottom_out_count = 0
|
||||
self.running = False
|
||||
self.tasks: typing.List[asyncio.Task] = []
|
||||
self.delayed_calls: typing.List[asyncio.Handle] = []
|
||||
self.finished = asyncio.Event(loop=self.loop)
|
||||
|
||||
async def send_probe(self, peer: 'KademliaPeer') -> FindResponse:
|
||||
"""
|
||||
Send the rpc request to the peer and return an object with the FindResponse interface
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def check_result_ready(self, response: FindResponse):
|
||||
"""
|
||||
Called with a lock after adding peers from an rpc result to the shortlist.
|
||||
This method is responsible for putting a result for the generator into the Queue
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_initial_result(self) -> typing.List['KademliaPeer']:
|
||||
"""
|
||||
Get an initial or cached result to be put into the Queue. Used for findValue requests where the blob
|
||||
has peers in the local data store of blobs announced to us
|
||||
"""
|
||||
return []
|
||||
|
||||
def _is_closer(self, peer: 'KademliaPeer') -> bool:
|
||||
if not self.closest_peer:
|
||||
return True
|
||||
return self.distance.is_closer(peer.node_id, self.closest_peer.node_id)
|
||||
|
||||
def _update_closest(self):
|
||||
self.shortlist.sort(key=lambda peer: self.distance(peer.node_id), reverse=True)
|
||||
if self.closest_peer and self.closest_peer is not self.shortlist[-1]:
|
||||
if self._is_closer(self.shortlist[-1]):
|
||||
self.prev_closest_peer = self.closest_peer
|
||||
self.closest_peer = self.shortlist[-1]
|
||||
|
||||
async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse):
|
||||
async with self.lock:
|
||||
if peer not in self.shortlist:
|
||||
self.shortlist.append(peer)
|
||||
if peer not in self.active:
|
||||
self.active.append(peer)
|
||||
for contact_triple in response.get_close_triples():
|
||||
addr_tuple = (contact_triple[1], contact_triple[2])
|
||||
if addr_tuple not in self.contacted: # and not self.peer_manager.is_ignored(addr_tuple)
|
||||
found_peer = self.peer_manager.get_kademlia_peer(
|
||||
contact_triple[0], contact_triple[1], contact_triple[2]
|
||||
)
|
||||
if found_peer not in self.shortlist and self.peer_manager.peer_is_good(peer) is not False:
|
||||
self.shortlist.append(found_peer)
|
||||
self._update_closest()
|
||||
self.check_result_ready(response)
|
||||
|
||||
async def _send_probe(self, peer: 'KademliaPeer'):
|
||||
try:
|
||||
response = await self.send_probe(peer)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except asyncio.TimeoutError:
|
||||
if peer in self.active:
|
||||
self.active.remove(peer)
|
||||
return
|
||||
except ValueError as err:
|
||||
log.warning(str(err))
|
||||
if peer in self.active:
|
||||
self.active.remove(peer)
|
||||
return
|
||||
except RemoteException:
|
||||
return
|
||||
return await self._handle_probe_result(peer, response)
|
||||
|
||||
async def _search_round(self):
|
||||
"""
|
||||
Send up to constants.alpha (5) probes to the closest peers in the shortlist
|
||||
"""
|
||||
|
||||
added = 0
|
||||
async with self.lock:
|
||||
self.shortlist.sort(key=lambda p: self.distance(p.node_id), reverse=True)
|
||||
while self.running and len(self.shortlist) and added < constants.alpha:
|
||||
peer = self.shortlist.pop()
|
||||
origin_address = (peer.address, peer.udp_port)
|
||||
if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False:
|
||||
continue
|
||||
if peer.node_id == self.protocol.node_id:
|
||||
continue
|
||||
if (peer.address, peer.udp_port) == (self.protocol.external_ip, self.protocol.udp_port):
|
||||
continue
|
||||
if (peer.address, peer.udp_port) not in self.contacted:
|
||||
self.contacted.append((peer.address, peer.udp_port))
|
||||
|
||||
t = self.loop.create_task(self._send_probe(peer))
|
||||
|
||||
def callback(_):
|
||||
if t and t in self.running_probes:
|
||||
self.running_probes.remove(t)
|
||||
if not self.running_probes and self.shortlist:
|
||||
self.tasks.append(self.loop.create_task(self._search_task(0.0)))
|
||||
|
||||
t.add_done_callback(callback)
|
||||
self.running_probes.append(t)
|
||||
added += 1
|
||||
|
||||
async def _search_task(self, delay: typing.Optional[float] = constants.iterative_lookup_delay):
|
||||
try:
|
||||
if self.running:
|
||||
await self._search_round()
|
||||
if self.running:
|
||||
self.delayed_calls.append(self.loop.call_later(delay, self._search))
|
||||
except (asyncio.CancelledError, StopAsyncIteration):
|
||||
if self.running:
|
||||
drain_tasks(self.running_probes)
|
||||
self.running = False
|
||||
|
||||
def _search(self):
|
||||
self.tasks.append(self.loop.create_task(self._search_task()))
|
||||
|
||||
def search(self):
|
||||
if self.running:
|
||||
raise Exception("already running")
|
||||
self.running = True
|
||||
self._search()
|
||||
|
||||
async def next_queue_or_finished(self) -> typing.List['KademliaPeer']:
|
||||
peers = self.loop.create_task(self.iteration_queue.get())
|
||||
finished = self.loop.create_task(self.finished.wait())
|
||||
err = None
|
||||
try:
|
||||
await asyncio.wait([peers, finished], loop=self.loop, return_when='FIRST_COMPLETED')
|
||||
if peers.done():
|
||||
return peers.result()
|
||||
raise StopAsyncIteration()
|
||||
except asyncio.CancelledError as error:
|
||||
err = error
|
||||
finally:
|
||||
if not finished.done() and not finished.cancelled():
|
||||
finished.cancel()
|
||||
if not peers.done() and not peers.cancelled():
|
||||
peers.cancel()
|
||||
if err:
|
||||
raise err
|
||||
|
||||
def __aiter__(self):
|
||||
self.search()
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> typing.List['KademliaPeer']:
|
||||
try:
|
||||
if self.iteration_count == 0:
|
||||
initial_results = self.get_initial_result()
|
||||
if initial_results:
|
||||
self.iteration_queue.put_nowait(initial_results)
|
||||
result = await self.next_queue_or_finished()
|
||||
self.iteration_count += 1
|
||||
return result
|
||||
except (asyncio.CancelledError, StopAsyncIteration):
|
||||
await self.aclose()
|
||||
raise
|
||||
|
||||
def aclose(self):
|
||||
self.running = False
|
||||
|
||||
async def _aclose():
|
||||
async with self.lock:
|
||||
self.running = False
|
||||
if not self.finished.is_set():
|
||||
self.finished.set()
|
||||
drain_tasks(self.tasks)
|
||||
drain_tasks(self.running_probes)
|
||||
while self.delayed_calls:
|
||||
timer = self.delayed_calls.pop()
|
||||
if timer:
|
||||
timer.cancel()
|
||||
|
||||
return asyncio.ensure_future(_aclose(), loop=self.loop)
|
||||
|
||||
|
||||
class IterativeNodeFinder(IterativeFinder):
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager',
|
||||
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
|
||||
bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.k,
|
||||
exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
|
||||
shortlist: typing.Optional[typing.List['KademliaPeer']] = None):
|
||||
super().__init__(loop, peer_manager, routing_table, protocol, key, bottom_out_limit, max_results, exclude,
|
||||
shortlist)
|
||||
self.yielded_peers: typing.Set['KademliaPeer'] = set()
|
||||
|
||||
async def send_probe(self, peer: 'KademliaPeer') -> FindNodeResponse:
|
||||
response = await self.protocol.get_rpc_peer(peer).find_node(self.key)
|
||||
return FindNodeResponse(self.key, response)
|
||||
|
||||
def put_result(self, from_list: typing.List['KademliaPeer']):
|
||||
not_yet_yielded = [peer for peer in from_list if peer not in self.yielded_peers]
|
||||
not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id))
|
||||
to_yield = not_yet_yielded[:min(constants.k, len(not_yet_yielded))]
|
||||
if to_yield:
|
||||
for peer in to_yield:
|
||||
self.yielded_peers.add(peer)
|
||||
self.iteration_queue.put_nowait(to_yield)
|
||||
|
||||
def check_result_ready(self, response: FindNodeResponse):
|
||||
found = response.found and self.key != self.protocol.node_id
|
||||
|
||||
if found:
|
||||
log.info("found")
|
||||
self.put_result(self.shortlist)
|
||||
if not self.finished.is_set():
|
||||
self.finished.set()
|
||||
return
|
||||
if self.prev_closest_peer and self.closest_peer and not self._is_closer(self.prev_closest_peer):
|
||||
# log.info("improving, %i %i %i %i %i", len(self.shortlist), len(self.active), len(self.contacted),
|
||||
# self.bottom_out_count, self.iteration_count)
|
||||
self.bottom_out_count = 0
|
||||
elif self.prev_closest_peer and self.closest_peer:
|
||||
self.bottom_out_count += 1
|
||||
log.info("bottom out %i %i %i %i", len(self.active), len(self.contacted), len(self.shortlist),
|
||||
self.bottom_out_count)
|
||||
if self.bottom_out_count >= self.bottom_out_limit or self.iteration_count >= self.bottom_out_limit:
|
||||
log.info("limit hit")
|
||||
self.put_result(self.active)
|
||||
if not self.finished.is_set():
|
||||
self.finished.set()
|
||||
return
|
||||
if self.max_results and len(self.active) - len(self.yielded_peers) >= self.max_results:
|
||||
log.info("max results")
|
||||
self.put_result(self.active)
|
||||
if not self.finished.is_set():
|
||||
self.finished.set()
|
||||
return
|
||||
|
||||
|
||||
class IterativeValueFinder(IterativeFinder):
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager',
|
||||
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
|
||||
bottom_out_limit: typing.Optional[int] = 2, max_results: typing.Optional[int] = constants.k,
|
||||
exclude: typing.Optional[typing.List[typing.Tuple[str, int]]] = None,
|
||||
shortlist: typing.Optional[typing.List['KademliaPeer']] = None):
|
||||
super().__init__(loop, peer_manager, routing_table, protocol, key, bottom_out_limit, max_results, exclude,
|
||||
shortlist)
|
||||
self.blob_peers: typing.Set['KademliaPeer'] = set()
|
||||
|
||||
async def send_probe(self, peer: 'KademliaPeer') -> FindValueResponse:
|
||||
response = await self.protocol.get_rpc_peer(peer).find_value(self.key)
|
||||
return FindValueResponse(self.key, response)
|
||||
|
||||
def check_result_ready(self, response: FindValueResponse):
|
||||
if response.found:
|
||||
blob_peers = [self.peer_manager.decode_tcp_peer_from_compact_address(compact_addr)
|
||||
for compact_addr in response.found_compact_addresses]
|
||||
to_yield = []
|
||||
self.bottom_out_count = 0
|
||||
for blob_peer in blob_peers:
|
||||
if blob_peer not in self.blob_peers:
|
||||
self.blob_peers.add(blob_peer)
|
||||
to_yield.append(blob_peer)
|
||||
if to_yield:
|
||||
# log.info("found %i new peers for blob", len(to_yield))
|
||||
self.iteration_queue.put_nowait(to_yield)
|
||||
# if self.max_results and len(self.blob_peers) >= self.max_results:
|
||||
# log.info("enough blob peers found")
|
||||
# if not self.finished.is_set():
|
||||
# self.finished.set()
|
||||
return
|
||||
if self.prev_closest_peer and self.closest_peer:
|
||||
self.bottom_out_count += 1
|
||||
if self.bottom_out_count >= self.bottom_out_limit:
|
||||
log.info("blob peer search bottomed out")
|
||||
if not self.finished.is_set():
|
||||
self.finished.set()
|
||||
return
|
||||
|
||||
def get_initial_result(self) -> typing.List['KademliaPeer']:
|
||||
if self.protocol.data_store.has_peers_for_blob(self.key):
|
||||
return self.protocol.data_store.get_peers_for_blob(self.key)
|
||||
return []
|
634
lbrynet/dht/protocol/protocol.py
Normal file
634
lbrynet/dht/protocol/protocol.py
Normal file
|
@ -0,0 +1,634 @@
|
|||
import logging
|
||||
import socket
|
||||
import functools
|
||||
import hashlib
|
||||
import asyncio
|
||||
import typing
|
||||
import binascii
|
||||
from asyncio.protocols import DatagramProtocol
|
||||
from asyncio.transports import DatagramTransport
|
||||
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram
|
||||
from lbrynet.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE
|
||||
from lbrynet.dht.error import RemoteException, TransportNotConnected
|
||||
from lbrynet.dht.protocol.routing_table import TreeRoutingTable
|
||||
from lbrynet.dht.protocol.data_store import DictDataStore
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.dht.peer import PeerManager, KademliaPeer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
old_protocol_errors = {
|
||||
"findNode() takes exactly 2 arguments (5 given)": "0.19.1",
|
||||
"findValue() takes exactly 2 arguments (5 given)": "0.19.1"
|
||||
}
|
||||
|
||||
|
||||
class KademliaRPC:
|
||||
def __init__(self, protocol: 'KademliaProtocol', loop: asyncio.BaseEventLoop, peer_port: int = 3333):
|
||||
self.protocol = protocol
|
||||
self.loop = loop
|
||||
self.peer_port = peer_port
|
||||
self.old_token_secret: bytes = None
|
||||
self.token_secret = constants.generate_id()
|
||||
|
||||
def compact_address(self):
|
||||
compact_ip = functools.reduce(lambda buff, x: buff + bytearray([int(x)]),
|
||||
self.protocol.external_ip.split('.'), bytearray())
|
||||
compact_port = self.peer_port.to_bytes(2, 'big')
|
||||
return compact_ip + compact_port + self.protocol.node_id
|
||||
|
||||
@staticmethod
|
||||
def ping():
|
||||
return b'pong'
|
||||
|
||||
def store(self, rpc_contact: 'KademliaPeer', blob_hash: bytes, token: bytes, port: int,
|
||||
original_publisher_id: bytes, age: int) -> bytes:
|
||||
if original_publisher_id is None:
|
||||
original_publisher_id = rpc_contact.node_id
|
||||
rpc_contact.update_tcp_port(port)
|
||||
if self.loop.time() - self.protocol.started_listening_time < constants.token_secret_refresh_interval:
|
||||
pass
|
||||
elif not self.verify_token(token, rpc_contact.compact_ip()):
|
||||
raise ValueError("Invalid token")
|
||||
now = int(self.loop.time())
|
||||
originally_published = now - age
|
||||
self.protocol.data_store.add_peer_to_blob(
|
||||
rpc_contact, blob_hash, rpc_contact.compact_address_tcp(), now, originally_published, original_publisher_id
|
||||
)
|
||||
return b'OK'
|
||||
|
||||
def find_node(self, rpc_contact: 'KademliaPeer', key: bytes) -> typing.List[typing.Tuple[bytes, str, int]]:
|
||||
if len(key) != constants.hash_length:
|
||||
raise ValueError("invalid contact node_id length: %i" % len(key))
|
||||
|
||||
contacts = self.protocol.routing_table.find_close_peers(key, sender_node_id=rpc_contact.node_id)
|
||||
contact_triples = []
|
||||
for contact in contacts:
|
||||
contact_triples.append((contact.node_id, contact.address, contact.udp_port))
|
||||
return contact_triples
|
||||
|
||||
def find_value(self, rpc_contact: 'KademliaPeer', key: bytes):
|
||||
if len(key) != constants.hash_length:
|
||||
raise ValueError("invalid blob_exchange hash length: %i" % len(key))
|
||||
|
||||
response = {
|
||||
b'token': self.make_token(rpc_contact.compact_ip()),
|
||||
}
|
||||
|
||||
if self.protocol.protocol_version:
|
||||
response[b'protocolVersion'] = self.protocol.protocol_version
|
||||
|
||||
# get peers we have stored for this blob_exchange
|
||||
has_other_peers = self.protocol.data_store.has_peers_for_blob(key)
|
||||
peers = []
|
||||
if has_other_peers:
|
||||
peers.extend(self.protocol.data_store.get_peers_for_blob(key))
|
||||
|
||||
# if we don't have k storing peers to return and we have this hash locally, include our contact information
|
||||
if len(peers) < constants.k and binascii.hexlify(key).decode() in self.protocol.data_store.completed_blobs:
|
||||
peers.append(self.compact_address())
|
||||
if peers:
|
||||
response[key] = peers
|
||||
else:
|
||||
response[b'contacts'] = self.find_node(rpc_contact, key)
|
||||
return response
|
||||
|
||||
def refresh_token(self): # TODO: this needs to be called periodically
|
||||
self.old_token_secret = self.token_secret
|
||||
self.token_secret = constants.generate_id()
|
||||
|
||||
def make_token(self, compact_ip):
|
||||
h = hashlib.new('sha384')
|
||||
h.update(self.token_secret + compact_ip)
|
||||
return h.digest()
|
||||
|
||||
def verify_token(self, token, compact_ip):
|
||||
h = hashlib.new('sha384')
|
||||
h.update(self.token_secret + compact_ip)
|
||||
if self.old_token_secret and not token == h.digest(): # TODO: why should we be accepting the previous token?
|
||||
h = hashlib.new('sha384')
|
||||
h.update(self.old_token_secret + compact_ip)
|
||||
if not token == h.digest():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class RemoteKademliaRPC:
|
||||
"""
|
||||
Encapsulates RPC calls to remote Peers
|
||||
"""
|
||||
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, peer_tracker: 'PeerManager', protocol: 'KademliaProtocol',
|
||||
peer: 'KademliaPeer'):
|
||||
self.loop = loop
|
||||
self.peer_tracker = peer_tracker
|
||||
self.protocol = protocol
|
||||
self.peer = peer
|
||||
|
||||
async def ping(self) -> bytes:
|
||||
"""
|
||||
:return: b'pong'
|
||||
"""
|
||||
response = await self.protocol.send_request(
|
||||
self.peer, RequestDatagram.make_ping(self.protocol.node_id)
|
||||
)
|
||||
return response.response
|
||||
|
||||
async def store(self, blob_hash: bytes) -> bytes:
|
||||
"""
|
||||
:param blob_hash: blob hash as bytes
|
||||
:return: b'OK'
|
||||
"""
|
||||
if len(blob_hash) != constants.hash_bits // 8:
|
||||
raise ValueError(f"invalid length of blob hash: {len(blob_hash)}")
|
||||
if not self.protocol.peer_port or not 0 < self.protocol.peer_port < 65535:
|
||||
raise ValueError(f"invalid tcp port: {self.protocol.peer_port}")
|
||||
token = self.peer_tracker.get_node_token(self.peer.node_id)
|
||||
if not token:
|
||||
find_value_resp = await self.find_value(blob_hash)
|
||||
token = find_value_resp[b'token']
|
||||
response = await self.protocol.send_request(
|
||||
self.peer, RequestDatagram.make_store(self.protocol.node_id, blob_hash, token, self.protocol.peer_port)
|
||||
)
|
||||
return response.response
|
||||
|
||||
async def find_node(self, key: bytes) -> typing.List[typing.Tuple[bytes, str, int]]:
|
||||
"""
|
||||
:return: [(node_id, address, udp_port), ...]
|
||||
"""
|
||||
if len(key) != constants.hash_bits // 8:
|
||||
raise ValueError(f"invalid length of find node key: {len(key)}")
|
||||
response = await self.protocol.send_request(
|
||||
self.peer, RequestDatagram.make_find_node(self.protocol.node_id, key)
|
||||
)
|
||||
return [(node_id, address.decode(), udp_port) for node_id, address, udp_port in response.response]
|
||||
|
||||
async def find_value(self, key: bytes) -> typing.Union[typing.Dict]:
|
||||
"""
|
||||
:return: {
|
||||
b'token': <token bytes>,
|
||||
b'contacts': [(node_id, address, udp_port), ...]
|
||||
<key bytes>: [<blob_peer_compact_address, ...]
|
||||
}
|
||||
"""
|
||||
if len(key) != constants.hash_bits // 8:
|
||||
raise ValueError(f"invalid length of find value key: {len(key)}")
|
||||
response = await self.protocol.send_request(
|
||||
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key)
|
||||
)
|
||||
await self.peer_tracker.update_token(self.peer.node_id, response.response[b'token'])
|
||||
return response.response
|
||||
|
||||
|
||||
class PingQueue:
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, protocol: 'KademliaProtocol'):
|
||||
self._loop = loop
|
||||
self._protocol = protocol
|
||||
self._enqueued_contacts: typing.List['KademliaPeer'] = []
|
||||
self._pending_contacts: typing.Dict['KademliaPeer', float] = {}
|
||||
self._process_task: asyncio.Task = None
|
||||
self._next_task: asyncio.Future = None
|
||||
self._next_timer: asyncio.TimerHandle = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return self._running
|
||||
|
||||
async def enqueue_maybe_ping(self, *peers: 'KademliaPeer', delay: typing.Optional[float] = None):
|
||||
delay = constants.check_refresh_interval if delay is None else delay
|
||||
async with self._lock:
|
||||
for peer in peers:
|
||||
if delay and peer not in self._enqueued_contacts:
|
||||
self._pending_contacts[peer] = self._loop.time() + delay
|
||||
elif peer not in self._enqueued_contacts:
|
||||
self._enqueued_contacts.append(peer)
|
||||
if peer in self._pending_contacts:
|
||||
del self._pending_contacts[peer]
|
||||
|
||||
async def _process(self):
|
||||
async def _ping(p: 'KademliaPeer'):
|
||||
try:
|
||||
if self._protocol.peer_manager.peer_is_good(p):
|
||||
await self._protocol.add_peer(p)
|
||||
return
|
||||
await self._protocol.get_rpc_peer(p).ping()
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
tasks = []
|
||||
|
||||
async with self._lock:
|
||||
if self._enqueued_contacts or self._pending_contacts:
|
||||
now = self._loop.time()
|
||||
scheduled = [k for k, d in self._pending_contacts.items() if now >= d]
|
||||
for k in scheduled:
|
||||
del self._pending_contacts[k]
|
||||
if k not in self._enqueued_contacts:
|
||||
self._enqueued_contacts.append(k)
|
||||
while self._enqueued_contacts:
|
||||
peer = self._enqueued_contacts.pop()
|
||||
tasks.append(self._loop.create_task(_ping(peer)))
|
||||
if tasks:
|
||||
await asyncio.wait(tasks, loop=self._loop)
|
||||
|
||||
f = self._loop.create_future()
|
||||
self._loop.call_later(1.0, lambda: None if f.done() else f.set_result(None))
|
||||
await f
|
||||
|
||||
def start(self):
|
||||
assert not self._running
|
||||
self._running = True
|
||||
if not self._process_task:
|
||||
self._process_task = self._loop.create_task(self._process())
|
||||
|
||||
def stop(self):
|
||||
assert self._running
|
||||
self._running = False
|
||||
if self._process_task:
|
||||
self._process_task.cancel()
|
||||
self._process_task = None
|
||||
if self._next_task:
|
||||
self._next_task.cancel()
|
||||
self._next_task = None
|
||||
if self._next_timer:
|
||||
self._next_timer.cancel()
|
||||
self._next_timer = None
|
||||
|
||||
|
||||
class KademliaProtocol(DatagramProtocol):
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', node_id: bytes, external_ip: str,
|
||||
udp_port: int, peer_port: int, rpc_timeout: float = 5.0):
|
||||
self.peer_manager = peer_manager
|
||||
self.loop = loop
|
||||
self.node_id = node_id
|
||||
self.external_ip = external_ip
|
||||
self.udp_port = udp_port
|
||||
self.peer_port = peer_port
|
||||
self.is_seed_node = False
|
||||
self.partial_messages: typing.Dict[bytes, typing.Dict[bytes, bytes]] = {}
|
||||
self.sent_messages: typing.Dict[bytes, typing.Tuple['KademliaPeer', asyncio.Future, RequestDatagram]] = {}
|
||||
self.protocol_version = constants.protocol_version
|
||||
self.started_listening_time = 0
|
||||
self.transport: DatagramTransport = None
|
||||
self.old_token_secret = constants.generate_id()
|
||||
self.token_secret = constants.generate_id()
|
||||
self.routing_table = TreeRoutingTable(self.loop, self.peer_manager, self.node_id)
|
||||
self.data_store = DictDataStore(self.loop, self.peer_manager)
|
||||
self.ping_queue = PingQueue(self.loop, self)
|
||||
self.node_rpc = KademliaRPC(self, self.loop, self.peer_port)
|
||||
self.lock = asyncio.Lock(loop=self.loop)
|
||||
self.rpc_timeout = rpc_timeout
|
||||
self._split_lock = asyncio.Lock(loop=self.loop)
|
||||
|
||||
def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
|
||||
return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer)
|
||||
|
||||
def stop(self):
|
||||
if self.transport:
|
||||
self.disconnect()
|
||||
|
||||
def disconnect(self):
|
||||
self.transport.close()
|
||||
|
||||
def connection_made(self, transport: DatagramTransport):
|
||||
self.transport = transport
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.stop()
|
||||
|
||||
@staticmethod
|
||||
def _migrate_incoming_rpc_args(peer: 'KademliaPeer', method: bytes, *args) -> typing.Tuple[typing.Tuple,
|
||||
typing.Dict]:
|
||||
if method == b'store' and peer.protocol_version == 0:
|
||||
if isinstance(args[1], dict):
|
||||
blob_hash = args[0]
|
||||
token = args[1].pop(b'token', None)
|
||||
port = args[1].pop(b'port', -1)
|
||||
original_publisher_id = args[1].pop(b'lbryid', None)
|
||||
age = 0
|
||||
return (blob_hash, token, port, original_publisher_id, age), {}
|
||||
return args, {}
|
||||
|
||||
async def _add_peer(self, peer: 'KademliaPeer'):
|
||||
bucket_index = self.routing_table.kbucket_index(peer.node_id)
|
||||
if self.routing_table.buckets[bucket_index].add_peer(peer):
|
||||
return True
|
||||
# The bucket is full; see if it can be split (by checking if its range includes the host node's node_id)
|
||||
if self.routing_table.should_split(bucket_index, peer.node_id):
|
||||
self.routing_table.split_bucket(bucket_index)
|
||||
# Retry the insertion attempt
|
||||
result = await self._add_peer(peer)
|
||||
self.routing_table.join_buckets()
|
||||
return result
|
||||
else:
|
||||
# We can't split the k-bucket
|
||||
#
|
||||
# The 13 page kademlia paper specifies that the least recently contacted node in the bucket
|
||||
# shall be pinged. If it fails to reply it is replaced with the new contact. If the ping is successful
|
||||
# the new contact is ignored and not added to the bucket (sections 2.2 and 2.4).
|
||||
#
|
||||
# A reasonable extension to this is BEP 0005, which extends the above:
|
||||
#
|
||||
# Not all nodes that we learn about are equal. Some are "good" and some are not.
|
||||
# Many nodes using the DHT are able to send queries and receive responses,
|
||||
# but are not able to respond to queries from other nodes. It is important that
|
||||
# each node's routing table must contain only known good nodes. A good node is
|
||||
# a node has responded to one of our queries within the last 15 minutes. A node
|
||||
# is also good if it has ever responded to one of our queries and has sent us a
|
||||
# query within the last 15 minutes. After 15 minutes of inactivity, a node becomes
|
||||
# questionable. Nodes become bad when they fail to respond to multiple queries
|
||||
# in a row. Nodes that we know are good are given priority over nodes with unknown status.
|
||||
#
|
||||
# When there are bad or questionable nodes in the bucket, the least recent is selected for
|
||||
# potential replacement (BEP 0005). When all nodes in the bucket are fresh, the head (least recent)
|
||||
# contact is selected as described in section 2.2 of the kademlia paper. In both cases the new contact
|
||||
# is ignored if the pinged node replies.
|
||||
|
||||
not_good_contacts = self.routing_table.buckets[bucket_index].get_bad_or_unknown_peers()
|
||||
not_recently_replied = []
|
||||
for p in not_good_contacts:
|
||||
last_replied = self.peer_manager.get_last_replied(p.address, p.udp_port)
|
||||
if not last_replied or last_replied + 60 < self.loop.time():
|
||||
not_recently_replied.append(p)
|
||||
if not_recently_replied:
|
||||
to_replace = not_recently_replied[0]
|
||||
else:
|
||||
to_replace = self.routing_table.buckets[bucket_index].peers[0]
|
||||
last_replied = self.peer_manager.get_last_replied(to_replace.address, to_replace.udp_port)
|
||||
if last_replied and last_replied + 60 > self.loop.time():
|
||||
return False
|
||||
log.debug("pinging %s:%s", to_replace.address, to_replace.udp_port)
|
||||
try:
|
||||
to_replace_rpc = self.get_rpc_peer(to_replace)
|
||||
await to_replace_rpc.ping()
|
||||
return False
|
||||
except asyncio.TimeoutError:
|
||||
log.debug("Replacing dead contact in bucket %i: %s:%i with %s:%i ", bucket_index,
|
||||
to_replace.address, to_replace.udp_port, peer.address, peer.udp_port)
|
||||
if to_replace in self.routing_table.buckets[bucket_index]:
|
||||
self.routing_table.buckets[bucket_index].remove_peer(to_replace)
|
||||
return await self._add_peer(peer)
|
||||
|
||||
async def add_peer(self, peer: 'KademliaPeer') -> bool:
|
||||
if peer.node_id == self.node_id:
|
||||
return False
|
||||
async with self._split_lock:
|
||||
return await self._add_peer(peer)
|
||||
|
||||
async def _handle_rpc(self, sender_contact: 'KademliaPeer', message: RequestDatagram):
|
||||
assert sender_contact.node_id != self.node_id, (binascii.hexlify(sender_contact.node_id)[:8].decode(),
|
||||
binascii.hexlify(self.node_id)[:8].decode())
|
||||
method = message.method
|
||||
if method not in [b'ping', b'store', b'findNode', b'findValue']:
|
||||
raise AttributeError('Invalid method: %s' % message.method.decode())
|
||||
if message.args and isinstance(message.args[-1], dict) and b'protocolVersion' in message.args[-1]:
|
||||
# args don't need reformatting
|
||||
a, kw = tuple(message.args[:-1]), message.args[-1]
|
||||
else:
|
||||
a, kw = self._migrate_incoming_rpc_args(sender_contact, message.method, *message.args)
|
||||
log.debug("%s:%i RECV CALL %s %s:%i", self.external_ip, self.udp_port, message.method.decode(),
|
||||
sender_contact.address, sender_contact.udp_port)
|
||||
|
||||
if method == b'ping':
|
||||
result = self.node_rpc.ping()
|
||||
elif method == b'store':
|
||||
blob_hash, token, port, original_publisher_id, age = a
|
||||
result = self.node_rpc.store(sender_contact, blob_hash, token, port, original_publisher_id, age)
|
||||
elif method == b'findNode':
|
||||
key, = a
|
||||
result = self.node_rpc.find_node(sender_contact, key)
|
||||
else:
|
||||
assert method == b'findValue'
|
||||
key, = a
|
||||
result = self.node_rpc.find_value(sender_contact, key)
|
||||
|
||||
await self.send_response(
|
||||
sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result),
|
||||
)
|
||||
|
||||
async def handle_request_datagram(self, address, request_datagram: RequestDatagram):
|
||||
# This is an RPC method request
|
||||
await self.peer_manager.report_last_requested(address[0], address[1])
|
||||
await self.peer_manager.update_contact_triple(request_datagram.node_id, address[0], address[1])
|
||||
# only add a requesting contact to the routing table if it has replied to one of our requests
|
||||
peer = self.peer_manager.get_kademlia_peer(request_datagram.node_id, address[0], address[1])
|
||||
try:
|
||||
await self._handle_rpc(peer, request_datagram)
|
||||
# if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
|
||||
# will be added to our routing table if successful
|
||||
is_good = self.peer_manager.peer_is_good(peer)
|
||||
if is_good is None:
|
||||
await self.ping_queue.enqueue_maybe_ping(peer)
|
||||
elif is_good is True:
|
||||
await self.add_peer(peer)
|
||||
|
||||
except Exception as err:
|
||||
log.warning("error raised handling %s request from %s:%i - %s(%s)",
|
||||
request_datagram.method, peer.address, peer.udp_port, str(type(err)),
|
||||
str(err))
|
||||
await self.send_error(
|
||||
peer,
|
||||
ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(),
|
||||
str(err).encode())
|
||||
)
|
||||
|
||||
async def handle_response_datagram(self, address: typing.Tuple[str, int], response_datagram: ResponseDatagram):
|
||||
# Find the message that triggered this response
|
||||
if response_datagram.rpc_id in self.sent_messages:
|
||||
peer, df, request = self.sent_messages[response_datagram.rpc_id]
|
||||
if peer.address != address[0]:
|
||||
df.set_exception(RemoteException(
|
||||
f"response from {address[0]}:{address[1]}, "
|
||||
f"expected {peer.address}:{peer.udp_port}")
|
||||
)
|
||||
return
|
||||
peer.set_id(response_datagram.node_id)
|
||||
# We got a result from the RPC
|
||||
if peer.node_id == self.node_id:
|
||||
df.set_exception(RemoteException("node has our node id"))
|
||||
return
|
||||
elif response_datagram.node_id == self.node_id:
|
||||
df.set_exception(RemoteException("incoming message is from our node id"))
|
||||
return
|
||||
await self.peer_manager.report_last_replied(address[0], address[1])
|
||||
await self.peer_manager.update_contact_triple(peer.node_id, address[0], address[1])
|
||||
if not df.cancelled():
|
||||
df.set_result(response_datagram)
|
||||
await self.add_peer(peer)
|
||||
else:
|
||||
log.warning("%s:%i replied, but after we cancelled the request attempt",
|
||||
peer.address, peer.udp_port)
|
||||
else:
|
||||
# If the original message isn't found, it must have timed out
|
||||
# TODO: we should probably do something with this...
|
||||
pass
|
||||
|
||||
def handle_error_datagram(self, address, error_datagram: ErrorDatagram):
|
||||
# The RPC request raised a remote exception; raise it locally
|
||||
remote_exception = RemoteException(f"{error_datagram.exception_type}({error_datagram.response})")
|
||||
if error_datagram.rpc_id in self.sent_messages:
|
||||
peer, df, request = self.sent_messages.pop(error_datagram.rpc_id)
|
||||
|
||||
error_msg = f"" \
|
||||
f"Error sending '{request.method}' to {peer.address}:{peer.udp_port}\n" \
|
||||
f"Args: {request.args}\n" \
|
||||
f"Raised: {str(remote_exception)}"
|
||||
if error_datagram.response not in old_protocol_errors:
|
||||
log.warning(error_msg)
|
||||
else:
|
||||
log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)",
|
||||
peer.address, peer.udp_port, old_protocol_errors[error_datagram.response])
|
||||
|
||||
# reject replies coming from a different address than what we sent our request to
|
||||
if (peer.address, peer.udp_port) != address:
|
||||
log.error("node id mismatch in reply")
|
||||
remote_exception = TimeoutError(peer.node_id)
|
||||
df.set_exception(remote_exception)
|
||||
return
|
||||
else:
|
||||
if error_datagram.response not in old_protocol_errors:
|
||||
msg = f"Received error from {address[0]}:{address[1]}, but it isn't in response to a " \
|
||||
f"pending request: {str(remote_exception)}"
|
||||
log.warning(msg)
|
||||
else:
|
||||
log.warning("known dht protocol backwards compatibility error with %s:%i (lbrynet v%s)",
|
||||
address[0], address[1], old_protocol_errors[error_datagram.response])
|
||||
|
||||
def datagram_received(self, datagram: bytes, address: typing.Tuple[str, int]) -> None:
|
||||
try:
|
||||
message = decode_datagram(datagram)
|
||||
except (ValueError, TypeError):
|
||||
self.loop.create_task(self.peer_manager.report_failure(address[0], address[1]))
|
||||
log.warning("Couldn't decode dht datagram from %s: %s", address, binascii.hexlify(datagram).decode())
|
||||
return
|
||||
|
||||
if isinstance(message, RequestDatagram):
|
||||
self.loop.create_task(self.handle_request_datagram(address, message))
|
||||
elif isinstance(message, ErrorDatagram):
|
||||
self.handle_error_datagram(address, message)
|
||||
else:
|
||||
assert isinstance(message, ResponseDatagram), "sanity"
|
||||
self.loop.create_task(self.handle_response_datagram(address, message))
|
||||
|
||||
async def send_request(self, peer: 'KademliaPeer', request: RequestDatagram) -> ResponseDatagram:
|
||||
await self._send(peer, request)
|
||||
response_fut = self.sent_messages[request.rpc_id][1]
|
||||
try:
|
||||
response = await asyncio.wait_for(response_fut, self.rpc_timeout)
|
||||
await self.peer_manager.report_last_replied(peer.address, peer.udp_port)
|
||||
return response
|
||||
except (asyncio.TimeoutError, RemoteException):
|
||||
await self.peer_manager.report_failure(peer.address, peer.udp_port)
|
||||
if self.peer_manager.peer_is_good(peer) is False:
|
||||
self.routing_table.remove_peer(peer)
|
||||
raise
|
||||
|
||||
async def send_response(self, peer: 'KademliaPeer', response: ResponseDatagram):
|
||||
await self._send(peer, response)
|
||||
|
||||
async def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
|
||||
await self._send(peer, error)
|
||||
|
||||
async def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram,
|
||||
ErrorDatagram]):
|
||||
if not self.transport:
|
||||
raise TransportNotConnected()
|
||||
|
||||
data = message.bencode()
|
||||
if len(data) > constants.msg_size_limit:
|
||||
log.exception("unexpected: %i vs %i", len(data), constants.msg_size_limit)
|
||||
raise ValueError()
|
||||
if isinstance(message, (RequestDatagram, ResponseDatagram)):
|
||||
assert message.node_id == self.node_id, message
|
||||
if isinstance(message, RequestDatagram):
|
||||
assert self.node_id != peer.node_id
|
||||
|
||||
def pop_from_sent_messages(_):
|
||||
if message.rpc_id in self.sent_messages:
|
||||
self.sent_messages.pop(message.rpc_id)
|
||||
|
||||
async with self.lock:
|
||||
if isinstance(message, RequestDatagram):
|
||||
response_fut = self.loop.create_future()
|
||||
response_fut.add_done_callback(pop_from_sent_messages)
|
||||
self.sent_messages[message.rpc_id] = (peer, response_fut, message)
|
||||
try:
|
||||
self.transport.sendto(data, (peer.address, peer.udp_port))
|
||||
except OSError as err:
|
||||
# TODO: handle ENETUNREACH
|
||||
if err.errno == socket.EWOULDBLOCK:
|
||||
# i'm scared this may swallow important errors, but i get a million of these
|
||||
# on Linux and it doesn't seem to affect anything -grin
|
||||
log.warning("Can't send data to dht: EWOULDBLOCK")
|
||||
else:
|
||||
log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)",
|
||||
len(data), peer.address, peer.udp_port, str(err), err.errno)
|
||||
if isinstance(message, RequestDatagram):
|
||||
self.sent_messages[message.rpc_id][1].set_exception(err)
|
||||
else:
|
||||
raise err
|
||||
if isinstance(message, RequestDatagram):
|
||||
await self.peer_manager.report_last_sent(peer.address, peer.udp_port)
|
||||
elif isinstance(message, ErrorDatagram):
|
||||
await self.peer_manager.report_failure(peer.address, peer.udp_port)
|
||||
|
||||
def change_token(self):
|
||||
self.old_token_secret = self.token_secret
|
||||
self.token_secret = constants.generate_id()
|
||||
|
||||
def make_token(self, compact_ip):
|
||||
return constants.digest(self.token_secret + compact_ip)
|
||||
|
||||
def verify_token(self, token, compact_ip):
|
||||
h = constants.hash_class()
|
||||
h.update(self.token_secret + compact_ip)
|
||||
if self.old_token_secret and not token == h.digest(): # TODO: why should we be accepting the previous token?
|
||||
h = constants.hash_class()
|
||||
h.update(self.old_token_secret + compact_ip)
|
||||
if not token == h.digest():
|
||||
return False
|
||||
return True
|
||||
|
||||
async def store_to_peer(self, hash_value: bytes, peer: 'KademliaPeer') -> typing.Tuple[bytes, bool]:
|
||||
try:
|
||||
res = await self.get_rpc_peer(peer).store(hash_value)
|
||||
if res != b"OK":
|
||||
raise ValueError(res)
|
||||
log.info("Stored %s to %s", binascii.hexlify(hash_value).decode()[:8], peer)
|
||||
return peer.node_id, True
|
||||
except asyncio.TimeoutError:
|
||||
log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer)
|
||||
except ValueError as err:
|
||||
log.error("Unexpected response: %s" % err)
|
||||
except Exception as err:
|
||||
if 'Invalid token' in str(err):
|
||||
await self.peer_manager.clear_token(peer.node_id)
|
||||
else:
|
||||
log.exception("Unexpected error while storing blob_hash")
|
||||
return peer.node_id, False
|
||||
|
||||
def _write(self, data: bytes, address: typing.Tuple[str, int]):
|
||||
if self.transport:
|
||||
try:
|
||||
self.transport.sendto(data, address)
|
||||
except OSError as err:
|
||||
if err.errno == socket.EWOULDBLOCK:
|
||||
# i'm scared this may swallow important errors, but i get a million of these
|
||||
# on Linux and it doesn't seem to affect anything -grin
|
||||
log.warning("Can't send data to dht: EWOULDBLOCK")
|
||||
# elif err.errno == socket.ENETUNREACH:
|
||||
# # this should probably try to retransmit when the network connection is back
|
||||
# log.error("Network is unreachable")
|
||||
else:
|
||||
log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)",
|
||||
len(data), address[0], address[1], str(err), err.errno)
|
||||
raise err
|
||||
else:
|
||||
raise TransportNotConnected()
|
305
lbrynet/dht/protocol/routing_table.py
Normal file
305
lbrynet/dht/protocol/routing_table.py
Normal file
|
@ -0,0 +1,305 @@
|
|||
import asyncio
|
||||
import random
|
||||
import logging
|
||||
import typing
|
||||
import itertools
|
||||
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.protocol.distance import Distance
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.dht.peer import KademliaPeer, PeerManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KBucket:
|
||||
""" Description - later
|
||||
"""
|
||||
|
||||
def __init__(self, peer_manager: 'PeerManager', range_min: int, range_max: int, node_id: bytes):
|
||||
"""
|
||||
@param range_min: The lower boundary for the range in the n-bit ID
|
||||
space covered by this k-bucket
|
||||
@param range_max: The upper boundary for the range in the ID space
|
||||
covered by this k-bucket
|
||||
"""
|
||||
self._peer_manager = peer_manager
|
||||
self.last_accessed = 0
|
||||
self.range_min = range_min
|
||||
self.range_max = range_max
|
||||
self.peers: typing.List['KademliaPeer'] = []
|
||||
self._node_id = node_id
|
||||
|
||||
def add_peer(self, peer: 'KademliaPeer') -> bool:
|
||||
""" Add contact to _contact list in the right order. This will move the
|
||||
contact to the end of the k-bucket if it is already present.
|
||||
|
||||
@raise kademlia.kbucket.BucketFull: Raised when the bucket is full and
|
||||
the contact isn't in the bucket
|
||||
already
|
||||
|
||||
@param peer: The contact to add
|
||||
@type peer: dht.contact._Contact
|
||||
"""
|
||||
if peer in self.peers:
|
||||
# Move the existing contact to the end of the list
|
||||
# - using the new contact to allow add-on data
|
||||
# (e.g. optimization-specific stuff) to pe updated as well
|
||||
self.peers.remove(peer)
|
||||
self.peers.append(peer)
|
||||
return True
|
||||
elif len(self.peers) < constants.k:
|
||||
self.peers.append(peer)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
# raise BucketFull("No space in bucket to insert contact")
|
||||
|
||||
def get_peer(self, node_id: bytes) -> 'KademliaPeer':
|
||||
for peer in self.peers:
|
||||
if peer.node_id == node_id:
|
||||
return peer
|
||||
raise IndexError(node_id)
|
||||
|
||||
def get_peers(self, count=-1, exclude_contact=None, sort_distance_to=None) -> typing.List['KademliaPeer']:
|
||||
""" Returns a list containing up to the first count number of contacts
|
||||
|
||||
@param count: The amount of contacts to return (if 0 or less, return
|
||||
all contacts)
|
||||
@type count: int
|
||||
@param exclude_contact: A node node_id to exclude; if this contact is in
|
||||
the list of returned values, it will be
|
||||
discarded before returning. If a C{str} is
|
||||
passed as this argument, it must be the
|
||||
contact's ID.
|
||||
@type exclude_contact: str
|
||||
|
||||
@param sort_distance_to: Sort distance to the node_id, defaulting to the parent node node_id. If False don't
|
||||
sort the contacts
|
||||
|
||||
@raise IndexError: If the number of requested contacts is too large
|
||||
|
||||
@return: Return up to the first count number of contacts in a list
|
||||
If no contacts are present an empty is returned
|
||||
@rtype: list
|
||||
"""
|
||||
peers = [peer for peer in self.peers if peer.node_id != exclude_contact]
|
||||
|
||||
# Return all contacts in bucket
|
||||
if count <= 0:
|
||||
count = len(peers)
|
||||
|
||||
# Get current contact number
|
||||
current_len = len(peers)
|
||||
|
||||
# If count greater than k - return only k contacts
|
||||
if count > constants.k:
|
||||
count = constants.k
|
||||
|
||||
if not current_len:
|
||||
return peers
|
||||
|
||||
if sort_distance_to is False:
|
||||
pass
|
||||
else:
|
||||
sort_distance_to = sort_distance_to or self._node_id
|
||||
peers.sort(key=lambda c: Distance(sort_distance_to)(c.node_id))
|
||||
|
||||
return peers[:min(current_len, count)]
|
||||
|
||||
def get_bad_or_unknown_peers(self) -> typing.List['KademliaPeer']:
|
||||
peer = self.get_peers(sort_distance_to=False)
|
||||
return [
|
||||
peer for peer in peer
|
||||
if self._peer_manager.contact_triple_is_good(peer.node_id, peer.address, peer.udp_port) is not True
|
||||
]
|
||||
|
||||
def remove_peer(self, peer: 'KademliaPeer') -> None:
|
||||
self.peers.remove(peer)
|
||||
|
||||
def key_in_range(self, key: bytes) -> bool:
|
||||
""" Tests whether the specified key (i.e. node ID) is in the range
|
||||
of the n-bit ID space covered by this k-bucket (in otherwords, it
|
||||
returns whether or not the specified key should be placed in this
|
||||
k-bucket)
|
||||
|
||||
@param key: The key to test
|
||||
@type key: str or int
|
||||
|
||||
@return: C{True} if the key is in this k-bucket's range, or C{False}
|
||||
if not.
|
||||
@rtype: bool
|
||||
"""
|
||||
return self.range_min <= int.from_bytes(key, 'big') < self.range_max
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.peers)
|
||||
|
||||
def __contains__(self, item) -> bool:
|
||||
return item in self.peers
|
||||
|
||||
|
||||
class TreeRoutingTable:
|
||||
""" This class implements a routing table used by a Node class.
|
||||
|
||||
The Kademlia routing table is a binary tree whose leaves are k-buckets,
|
||||
where each k-bucket contains nodes with some common prefix of their IDs.
|
||||
This prefix is the k-bucket's position in the binary tree; it therefore
|
||||
covers some range of ID values, and together all of the k-buckets cover
|
||||
the entire n-bit ID (or key) space (with no overlap).
|
||||
|
||||
@note: In this implementation, nodes in the tree (the k-buckets) are
|
||||
added dynamically, as needed; this technique is described in the 13-page
|
||||
version of the Kademlia paper, in section 2.4. It does, however, use the
|
||||
ping RPC-based k-bucket eviction algorithm described in section 2.2 of
|
||||
that paper.
|
||||
"""
|
||||
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager', parent_node_id: bytes):
|
||||
self._loop = loop
|
||||
self._peer_manager = peer_manager
|
||||
self._parent_node_id = parent_node_id
|
||||
self.buckets: typing.List[KBucket] = [
|
||||
KBucket(
|
||||
self._peer_manager, range_min=0, range_max=2 ** constants.hash_bits, node_id=self._parent_node_id
|
||||
)
|
||||
]
|
||||
|
||||
def get_peers(self) -> typing.List['KademliaPeer']:
|
||||
return list(itertools.chain.from_iterable(map(lambda bucket: bucket.peers, self.buckets)))
|
||||
|
||||
def should_split(self, bucket_index: int, to_add: bytes) -> bool:
|
||||
# https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456
|
||||
if self.buckets[bucket_index].key_in_range(self._parent_node_id):
|
||||
return True
|
||||
contacts = self.get_peers()
|
||||
distance = Distance(self._parent_node_id)
|
||||
contacts.sort(key=lambda c: distance(c.node_id))
|
||||
kth_contact = contacts[-1] if len(contacts) < constants.k else contacts[constants.k - 1]
|
||||
return distance(to_add) < distance(kth_contact.node_id)
|
||||
|
||||
def find_close_peers(self, key: bytes, count: typing.Optional[int] = None,
|
||||
sender_node_id: typing.Optional[bytes] = None) -> typing.List['KademliaPeer']:
|
||||
exclude = [self._parent_node_id]
|
||||
if sender_node_id:
|
||||
exclude.append(sender_node_id)
|
||||
if key in exclude:
|
||||
exclude.remove(key)
|
||||
count = count or constants.k
|
||||
distance = Distance(key)
|
||||
contacts = self.get_peers()
|
||||
contacts = [c for c in contacts if c.node_id not in exclude]
|
||||
if contacts:
|
||||
contacts.sort(key=lambda c: distance(c.node_id))
|
||||
return contacts[:min(count, len(contacts))]
|
||||
return []
|
||||
|
||||
def get_peer(self, contact_id: bytes) -> 'KademliaPeer':
|
||||
"""
|
||||
@raise IndexError: No contact with the specified contact ID is known
|
||||
by this node
|
||||
"""
|
||||
return self.buckets[self.kbucket_index(contact_id)].get_peer(contact_id)
|
||||
|
||||
def get_refresh_list(self, start_index: int = 0, force: bool = False) -> typing.List[bytes]:
|
||||
bucket_index = start_index
|
||||
refresh_ids = []
|
||||
now = int(self._loop.time())
|
||||
for bucket in self.buckets[start_index:]:
|
||||
if force or now - bucket.last_accessed >= constants.refresh_interval:
|
||||
to_search = self.midpoint_id_in_bucket_range(bucket_index)
|
||||
refresh_ids.append(to_search)
|
||||
bucket_index += 1
|
||||
return refresh_ids
|
||||
|
||||
def remove_peer(self, peer: 'KademliaPeer') -> None:
|
||||
if not peer.node_id:
|
||||
return
|
||||
bucket_index = self.kbucket_index(peer.node_id)
|
||||
try:
|
||||
self.buckets[bucket_index].remove_peer(peer)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
def touch_kbucket(self, key: bytes) -> None:
|
||||
self.touch_kbucket_by_index(self.kbucket_index(key))
|
||||
|
||||
def touch_kbucket_by_index(self, bucket_index: int):
|
||||
self.buckets[bucket_index].last_accessed = int(self._loop.time())
|
||||
|
||||
def kbucket_index(self, key: bytes) -> int:
|
||||
i = 0
|
||||
for bucket in self.buckets:
|
||||
if bucket.key_in_range(key):
|
||||
return i
|
||||
else:
|
||||
i += 1
|
||||
return i
|
||||
|
||||
def random_id_in_bucket_range(self, bucket_index: int) -> bytes:
|
||||
random_id = int(random.randrange(self.buckets[bucket_index].range_min, self.buckets[bucket_index].range_max))
|
||||
return random_id.to_bytes(constants.hash_length, 'big')
|
||||
|
||||
def midpoint_id_in_bucket_range(self, bucket_index: int) -> bytes:
|
||||
half = int((self.buckets[bucket_index].range_max - self.buckets[bucket_index].range_min) // 2)
|
||||
return int(self.buckets[bucket_index].range_min + half).to_bytes(constants.hash_length, 'big')
|
||||
|
||||
def split_bucket(self, old_bucket_index: int) -> None:
|
||||
""" Splits the specified k-bucket into two new buckets which together
|
||||
cover the same range in the key/ID space
|
||||
|
||||
@param old_bucket_index: The index of k-bucket to split (in this table's
|
||||
list of k-buckets)
|
||||
@type old_bucket_index: int
|
||||
"""
|
||||
# Resize the range of the current (old) k-bucket
|
||||
old_bucket = self.buckets[old_bucket_index]
|
||||
split_point = old_bucket.range_max - (old_bucket.range_max - old_bucket.range_min) // 2
|
||||
# Create a new k-bucket to cover the range split off from the old bucket
|
||||
new_bucket = KBucket(self._peer_manager, split_point, old_bucket.range_max, self._parent_node_id)
|
||||
old_bucket.range_max = split_point
|
||||
# Now, add the new bucket into the routing table tree
|
||||
self.buckets.insert(old_bucket_index + 1, new_bucket)
|
||||
# Finally, copy all nodes that belong to the new k-bucket into it...
|
||||
for contact in old_bucket.peers:
|
||||
if new_bucket.key_in_range(contact.node_id):
|
||||
new_bucket.add_peer(contact)
|
||||
# ...and remove them from the old bucket
|
||||
for contact in new_bucket.peers:
|
||||
old_bucket.remove_peer(contact)
|
||||
|
||||
def join_buckets(self):
|
||||
to_pop = [i for i, bucket in enumerate(self.buckets) if not len(bucket)]
|
||||
if not to_pop:
|
||||
return
|
||||
log.info("join buckets %i", len(to_pop))
|
||||
bucket_index_to_pop = to_pop[0]
|
||||
assert len(self.buckets[bucket_index_to_pop]) == 0
|
||||
can_go_lower = bucket_index_to_pop - 1 >= 0
|
||||
can_go_higher = bucket_index_to_pop + 1 < len(self.buckets)
|
||||
assert can_go_higher or can_go_lower
|
||||
bucket = self.buckets[bucket_index_to_pop]
|
||||
if can_go_lower and can_go_higher:
|
||||
midpoint = ((bucket.range_max - bucket.range_min) // 2) + bucket.range_min
|
||||
self.buckets[bucket_index_to_pop - 1].range_max = midpoint - 1
|
||||
self.buckets[bucket_index_to_pop + 1].range_min = midpoint
|
||||
elif can_go_lower:
|
||||
self.buckets[bucket_index_to_pop - 1].range_max = bucket.range_max
|
||||
elif can_go_higher:
|
||||
self.buckets[bucket_index_to_pop + 1].range_min = bucket.range_min
|
||||
self.buckets.remove(bucket)
|
||||
return self.join_buckets()
|
||||
|
||||
def contact_in_routing_table(self, address_tuple: typing.Tuple[str, int]) -> bool:
|
||||
for bucket in self.buckets:
|
||||
for contact in bucket.get_peers(sort_distance_to=False):
|
||||
if address_tuple[0] == contact.address and address_tuple[1] == contact.udp_port:
|
||||
return True
|
||||
return False
|
||||
|
||||
def buckets_with_contacts(self) -> int:
|
||||
count = 0
|
||||
for bucket in self.buckets:
|
||||
if len(bucket):
|
||||
count += 1
|
||||
return count
|
|
@ -1,320 +0,0 @@
|
|||
# This library is free software, distributed under the terms of
|
||||
# the GNU Lesser General Public License Version 3, or any later version.
|
||||
# See the COPYING file included in this archive
|
||||
#
|
||||
# The docstrings in this module contain epytext markup; API documentation
|
||||
# may be created by processing this file with epydoc: http://epydoc.sf.net
|
||||
|
||||
import random
|
||||
import logging
|
||||
from twisted.internet import defer
|
||||
|
||||
from lbrynet.dht import constants, kbucket
|
||||
from lbrynet.dht.error import TimeoutError
|
||||
from lbrynet.dht.distance import Distance
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TreeRoutingTable:
|
||||
""" This class implements a routing table used by a Node class.
|
||||
|
||||
The Kademlia routing table is a binary tree whFose leaves are k-buckets,
|
||||
where each k-bucket contains nodes with some common prefix of their IDs.
|
||||
This prefix is the k-bucket's position in the binary tree; it therefore
|
||||
covers some range of ID values, and together all of the k-buckets cover
|
||||
the entire n-bit ID (or key) space (with no overlap).
|
||||
|
||||
@note: In this implementation, nodes in the tree (the k-buckets) are
|
||||
added dynamically, as needed; this technique is described in the 13-page
|
||||
version of the Kademlia paper, in section 2.4. It does, however, use the
|
||||
C{PING} RPC-based k-bucket eviction algorithm described in section 2.2 of
|
||||
that paper.
|
||||
"""
|
||||
#implements(IRoutingTable)
|
||||
|
||||
def __init__(self, parentNodeID, getTime=None):
|
||||
"""
|
||||
@param parentNodeID: The n-bit node ID of the node to which this
|
||||
routing table belongs
|
||||
@type parentNodeID: str
|
||||
"""
|
||||
# Create the initial (single) k-bucket covering the range of the entire n-bit ID space
|
||||
self._parentNodeID = parentNodeID
|
||||
self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits, node_id=self._parentNodeID)]
|
||||
if not getTime:
|
||||
from twisted.internet import reactor
|
||||
getTime = reactor.seconds
|
||||
self._getTime = getTime
|
||||
self._ongoing_replacements = set()
|
||||
|
||||
def get_contacts(self):
|
||||
contacts = []
|
||||
for i in range(len(self._buckets)):
|
||||
for contact in self._buckets[i]._contacts:
|
||||
contacts.append(contact)
|
||||
return contacts
|
||||
|
||||
def _shouldSplit(self, bucketIndex, toAdd):
|
||||
# https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456
|
||||
if self._buckets[bucketIndex].keyInRange(self._parentNodeID):
|
||||
return True
|
||||
contacts = self.get_contacts()
|
||||
distance = Distance(self._parentNodeID)
|
||||
contacts.sort(key=lambda c: distance(c.id))
|
||||
kth_contact = contacts[-1] if len(contacts) < constants.k else contacts[constants.k-1]
|
||||
return distance(toAdd) < distance(kth_contact.id)
|
||||
|
||||
def addContact(self, contact):
|
||||
""" Add the given contact to the correct k-bucket; if it already
|
||||
exists, its status will be updated
|
||||
|
||||
@param contact: The contact to add to this node's k-buckets
|
||||
@type contact: kademlia.contact.Contact
|
||||
|
||||
@rtype: defer.Deferred
|
||||
"""
|
||||
|
||||
if contact.id == self._parentNodeID:
|
||||
return defer.succeed(None)
|
||||
bucketIndex = self._kbucketIndex(contact.id)
|
||||
try:
|
||||
self._buckets[bucketIndex].addContact(contact)
|
||||
except kbucket.BucketFull:
|
||||
# The bucket is full; see if it can be split (by checking if its range includes the host node's id)
|
||||
if self._shouldSplit(bucketIndex, contact.id):
|
||||
self._splitBucket(bucketIndex)
|
||||
# Retry the insertion attempt
|
||||
return self.addContact(contact)
|
||||
else:
|
||||
# We can't split the k-bucket
|
||||
#
|
||||
# The 13 page kademlia paper specifies that the least recently contacted node in the bucket
|
||||
# shall be pinged. If it fails to reply it is replaced with the new contact. If the ping is successful
|
||||
# the new contact is ignored and not added to the bucket (sections 2.2 and 2.4).
|
||||
#
|
||||
# A reasonable extension to this is BEP 0005, which extends the above:
|
||||
#
|
||||
# Not all nodes that we learn about are equal. Some are "good" and some are not.
|
||||
# Many nodes using the DHT are able to send queries and receive responses,
|
||||
# but are not able to respond to queries from other nodes. It is important that
|
||||
# each node's routing table must contain only known good nodes. A good node is
|
||||
# a node has responded to one of our queries within the last 15 minutes. A node
|
||||
# is also good if it has ever responded to one of our queries and has sent us a
|
||||
# query within the last 15 minutes. After 15 minutes of inactivity, a node becomes
|
||||
# questionable. Nodes become bad when they fail to respond to multiple queries
|
||||
# in a row. Nodes that we know are good are given priority over nodes with unknown status.
|
||||
#
|
||||
# When there are bad or questionable nodes in the bucket, the least recent is selected for
|
||||
# potential replacement (BEP 0005). When all nodes in the bucket are fresh, the head (least recent)
|
||||
# contact is selected as described in section 2.2 of the kademlia paper. In both cases the new contact
|
||||
# is ignored if the pinged node replies.
|
||||
|
||||
def replaceContact(failure, deadContact):
|
||||
"""
|
||||
Callback for the deferred PING RPC to see if the node to be replaced in the k-bucket is still
|
||||
responding
|
||||
|
||||
@type failure: twisted.python.failure.Failure
|
||||
"""
|
||||
failure.trap(TimeoutError)
|
||||
log.debug("Replacing dead contact in bucket %i: %s:%i (%s) with %s:%i (%s)", bucketIndex,
|
||||
deadContact.address, deadContact.port, deadContact.log_id(), contact.address,
|
||||
contact.port, contact.log_id())
|
||||
try:
|
||||
self._buckets[bucketIndex].removeContact(deadContact)
|
||||
except ValueError:
|
||||
# The contact has already been removed (probably due to a timeout)
|
||||
pass
|
||||
return self.addContact(contact)
|
||||
|
||||
not_good_contacts = self._buckets[bucketIndex].getBadOrUnknownContacts()
|
||||
if not_good_contacts:
|
||||
to_replace = not_good_contacts[0]
|
||||
else:
|
||||
to_replace = self._buckets[bucketIndex]._contacts[0]
|
||||
if to_replace not in self._ongoing_replacements:
|
||||
log.debug("pinging %s:%s", to_replace.address, to_replace.port)
|
||||
self._ongoing_replacements.add(to_replace)
|
||||
df = to_replace.ping()
|
||||
df.addErrback(replaceContact, to_replace)
|
||||
df.addBoth(lambda _: self._ongoing_replacements.remove(to_replace))
|
||||
else:
|
||||
df = defer.succeed(None)
|
||||
return df
|
||||
else:
|
||||
self.touchKBucketByIndex(bucketIndex)
|
||||
return defer.succeed(None)
|
||||
|
||||
def findCloseNodes(self, key, count=None, sender_node_id=None):
|
||||
""" Finds a number of known nodes closest to the node/value with the
|
||||
specified key.
|
||||
|
||||
@param key: the n-bit key (i.e. the node or value ID) to search for
|
||||
@type key: str
|
||||
@param count: the amount of contacts to return, default of k (8)
|
||||
@type count: int
|
||||
@param sender_node_id: Used during RPC, this is be the sender's Node ID
|
||||
Whatever ID is passed in the parameter will get
|
||||
excluded from the list of returned contacts.
|
||||
@type sender_node_id: str
|
||||
|
||||
@return: A list of node contacts (C{kademlia.contact.Contact instances})
|
||||
closest to the specified key.
|
||||
This method will return C{k} (or C{count}, if specified)
|
||||
contacts if at all possible; it will only return fewer if the
|
||||
node is returning all of the contacts that it knows of.
|
||||
@rtype: list
|
||||
"""
|
||||
exclude = [self._parentNodeID]
|
||||
if sender_node_id:
|
||||
exclude.append(sender_node_id)
|
||||
if key in exclude:
|
||||
exclude.remove(key)
|
||||
count = count or constants.k
|
||||
distance = Distance(key)
|
||||
contacts = self.get_contacts()
|
||||
contacts = [c for c in contacts if c.id not in exclude]
|
||||
contacts.sort(key=lambda c: distance(c.id))
|
||||
return contacts[:min(count, len(contacts))]
|
||||
|
||||
def getContact(self, contactID):
|
||||
""" Returns the (known) contact with the specified node ID
|
||||
|
||||
@raise ValueError: No contact with the specified contact ID is known
|
||||
by this node
|
||||
"""
|
||||
bucketIndex = self._kbucketIndex(contactID)
|
||||
return self._buckets[bucketIndex].getContact(contactID)
|
||||
|
||||
def getRefreshList(self, startIndex=0, force=False):
|
||||
""" Finds all k-buckets that need refreshing, starting at the
|
||||
k-bucket with the specified index, and returns IDs to be searched for
|
||||
in order to refresh those k-buckets
|
||||
|
||||
@param startIndex: The index of the bucket to start refreshing at;
|
||||
this bucket and those further away from it will
|
||||
be refreshed. For example, when joining the
|
||||
network, this node will set this to the index of
|
||||
the bucket after the one containing it's closest
|
||||
neighbour.
|
||||
@type startIndex: index
|
||||
@param force: If this is C{True}, all buckets (in the specified range)
|
||||
will be refreshed, regardless of the time they were last
|
||||
accessed.
|
||||
@type force: bool
|
||||
|
||||
@return: A list of node ID's that the parent node should search for
|
||||
in order to refresh the routing Table
|
||||
@rtype: list
|
||||
"""
|
||||
bucketIndex = startIndex
|
||||
refreshIDs = []
|
||||
now = int(self._getTime())
|
||||
for bucket in self._buckets[startIndex:]:
|
||||
if force or now - bucket.lastAccessed >= constants.refreshTimeout:
|
||||
searchID = self.midpoint_id_in_bucket_range(bucketIndex)
|
||||
refreshIDs.append(searchID)
|
||||
bucketIndex += 1
|
||||
return refreshIDs
|
||||
|
||||
def removeContact(self, contact):
|
||||
"""
|
||||
Remove the contact from the routing table
|
||||
|
||||
@param contact: The contact to remove
|
||||
@type contact: dht.contact._Contact
|
||||
"""
|
||||
bucketIndex = self._kbucketIndex(contact.id)
|
||||
try:
|
||||
self._buckets[bucketIndex].removeContact(contact)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
def touchKBucket(self, key):
|
||||
""" Update the "last accessed" timestamp of the k-bucket which covers
|
||||
the range containing the specified key in the key/ID space
|
||||
|
||||
@param key: A key in the range of the target k-bucket
|
||||
@type key: str
|
||||
"""
|
||||
self.touchKBucketByIndex(self._kbucketIndex(key))
|
||||
|
||||
def touchKBucketByIndex(self, bucketIndex):
|
||||
self._buckets[bucketIndex].lastAccessed = int(self._getTime())
|
||||
|
||||
def _kbucketIndex(self, key):
|
||||
""" Calculate the index of the k-bucket which is responsible for the
|
||||
specified key (or ID)
|
||||
|
||||
@param key: The key for which to find the appropriate k-bucket index
|
||||
@type key: str
|
||||
|
||||
@return: The index of the k-bucket responsible for the specified key
|
||||
@rtype: int
|
||||
"""
|
||||
i = 0
|
||||
for bucket in self._buckets:
|
||||
if bucket.keyInRange(key):
|
||||
return i
|
||||
else:
|
||||
i += 1
|
||||
return i
|
||||
|
||||
def random_id_in_bucket_range(self, bucketIndex):
|
||||
""" Returns a random ID in the specified k-bucket's range
|
||||
|
||||
@param bucketIndex: The index of the k-bucket to use
|
||||
@type bucketIndex: int
|
||||
"""
|
||||
|
||||
random_id = int(random.randrange(self._buckets[bucketIndex].rangeMin, self._buckets[bucketIndex].rangeMax))
|
||||
return random_id.to_bytes(constants.key_bits // 8, 'big')
|
||||
|
||||
def midpoint_id_in_bucket_range(self, bucketIndex):
|
||||
""" Returns the middle ID in the specified k-bucket's range
|
||||
|
||||
@param bucketIndex: The index of the k-bucket to use
|
||||
@type bucketIndex: int
|
||||
"""
|
||||
|
||||
half = int((self._buckets[bucketIndex].rangeMax - self._buckets[bucketIndex].rangeMin) // 2)
|
||||
return int(self._buckets[bucketIndex].rangeMin + half).to_bytes(constants.key_bits // 8, 'big')
|
||||
|
||||
def _splitBucket(self, oldBucketIndex):
|
||||
""" Splits the specified k-bucket into two new buckets which together
|
||||
cover the same range in the key/ID space
|
||||
|
||||
@param oldBucketIndex: The index of k-bucket to split (in this table's
|
||||
list of k-buckets)
|
||||
@type oldBucketIndex: int
|
||||
"""
|
||||
# Resize the range of the current (old) k-bucket
|
||||
oldBucket = self._buckets[oldBucketIndex]
|
||||
splitPoint = oldBucket.rangeMax - (oldBucket.rangeMax - oldBucket.rangeMin) / 2
|
||||
# Create a new k-bucket to cover the range split off from the old bucket
|
||||
newBucket = kbucket.KBucket(splitPoint, oldBucket.rangeMax, self._parentNodeID)
|
||||
oldBucket.rangeMax = splitPoint
|
||||
# Now, add the new bucket into the routing table tree
|
||||
self._buckets.insert(oldBucketIndex + 1, newBucket)
|
||||
# Finally, copy all nodes that belong to the new k-bucket into it...
|
||||
for contact in oldBucket._contacts:
|
||||
if newBucket.keyInRange(contact.id):
|
||||
newBucket.addContact(contact)
|
||||
# ...and remove them from the old bucket
|
||||
for contact in newBucket._contacts:
|
||||
oldBucket.removeContact(contact)
|
||||
|
||||
def contactInRoutingTable(self, address_tuple):
|
||||
for bucket in self._buckets:
|
||||
for contact in bucket.getContacts(sort_distance_to=False):
|
||||
if address_tuple[0] == contact.address and address_tuple[1] == contact.port:
|
||||
return True
|
||||
return False
|
||||
|
||||
def bucketsWithContacts(self):
|
||||
count = 0
|
||||
for bucket in self._buckets:
|
||||
if len(bucket):
|
||||
count += 1
|
||||
return count
|
0
lbrynet/dht/serialization/__init__.py
Normal file
0
lbrynet/dht/serialization/__init__.py
Normal file
|
@ -1,8 +1,8 @@
|
|||
import typing
|
||||
from lbrynet.dht.error import DecodeError
|
||||
|
||||
|
||||
def bencode(data):
|
||||
""" Encoder implementation of the Bencode algorithm (Bittorrent). """
|
||||
def _bencode(data: typing.Union[int, bytes, bytearray, str, list, tuple, dict]) -> bytes:
|
||||
if isinstance(data, int):
|
||||
return b'i%de' % data
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
|
@ -12,31 +12,20 @@ def bencode(data):
|
|||
elif isinstance(data, (list, tuple)):
|
||||
encoded_list_items = b''
|
||||
for item in data:
|
||||
encoded_list_items += bencode(item)
|
||||
encoded_list_items += _bencode(item)
|
||||
return b'l%se' % encoded_list_items
|
||||
elif isinstance(data, dict):
|
||||
encoded_dict_items = b''
|
||||
keys = data.keys()
|
||||
for key in sorted(keys):
|
||||
encoded_dict_items += bencode(key)
|
||||
encoded_dict_items += bencode(data[key])
|
||||
encoded_dict_items += _bencode(key)
|
||||
encoded_dict_items += _bencode(data[key])
|
||||
return b'd%se' % encoded_dict_items
|
||||
else:
|
||||
raise TypeError("Cannot bencode '%s' object" % type(data))
|
||||
raise TypeError(f"Cannot bencode {type(data)}")
|
||||
|
||||
|
||||
def bdecode(data):
|
||||
""" Decoder implementation of the Bencode algorithm. """
|
||||
assert type(data) == bytes # fixme: _maybe_ remove this after porting
|
||||
if len(data) == 0:
|
||||
raise DecodeError('Cannot decode empty string')
|
||||
try:
|
||||
return _decode_recursive(data)[0]
|
||||
except ValueError as e:
|
||||
raise DecodeError(str(e))
|
||||
|
||||
|
||||
def _decode_recursive(data, start_index=0):
|
||||
def _bdecode(data: bytes, start_index: int = 0) -> typing.Tuple[typing.Union[int, bytes, list, tuple, dict], int]:
|
||||
if data[start_index] == ord('i'):
|
||||
end_pos = data[start_index:].find(b'e') + start_index
|
||||
return int(data[start_index + 1:end_pos]), end_pos + 1
|
||||
|
@ -44,32 +33,44 @@ def _decode_recursive(data, start_index=0):
|
|||
start_index += 1
|
||||
decoded_list = []
|
||||
while data[start_index] != ord('e'):
|
||||
list_data, start_index = _decode_recursive(data, start_index)
|
||||
list_data, start_index = _bdecode(data, start_index)
|
||||
decoded_list.append(list_data)
|
||||
return decoded_list, start_index + 1
|
||||
elif data[start_index] == ord('d'):
|
||||
start_index += 1
|
||||
decoded_dict = {}
|
||||
while data[start_index] != ord('e'):
|
||||
key, start_index = _decode_recursive(data, start_index)
|
||||
value, start_index = _decode_recursive(data, start_index)
|
||||
key, start_index = _bdecode(data, start_index)
|
||||
value, start_index = _bdecode(data, start_index)
|
||||
decoded_dict[key] = value
|
||||
return decoded_dict, start_index
|
||||
elif data[start_index] == ord('f'):
|
||||
# This (float data type) is a non-standard extension to the original Bencode algorithm
|
||||
end_pos = data[start_index:].find(b'e') + start_index
|
||||
return float(data[start_index + 1:end_pos]), end_pos + 1
|
||||
elif data[start_index] == ord('n'):
|
||||
# This (None/NULL data type) is a non-standard extension
|
||||
# to the original Bencode algorithm
|
||||
return None, start_index + 1
|
||||
else:
|
||||
split_pos = data[start_index:].find(b':') + start_index
|
||||
try:
|
||||
length = int(data[start_index:split_pos])
|
||||
except ValueError:
|
||||
raise DecodeError()
|
||||
except (ValueError, TypeError) as err:
|
||||
raise DecodeError(err)
|
||||
start_index = split_pos + 1
|
||||
end_pos = start_index + length
|
||||
b = data[start_index:end_pos]
|
||||
return b, end_pos
|
||||
|
||||
|
||||
def bencode(data: typing.Dict) -> bytes:
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError()
|
||||
return _bencode(data)
|
||||
|
||||
|
||||
def bdecode(data: bytes, allow_non_dict_return: typing.Optional[bool] = False) -> typing.Dict:
|
||||
assert type(data) == bytes, DecodeError(f"invalid data type: {str(type(data))}")
|
||||
|
||||
if len(data) == 0:
|
||||
raise DecodeError('Cannot decode empty string')
|
||||
try:
|
||||
result = _bdecode(data)[0]
|
||||
if not allow_non_dict_return and not isinstance(result, dict):
|
||||
raise ValueError(f'expected dict, got {type(result)}')
|
||||
return result
|
||||
except (ValueError, TypeError) as err:
|
||||
raise DecodeError(err)
|
181
lbrynet/dht/serialization/datagram.py
Normal file
181
lbrynet/dht/serialization/datagram.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
import typing
|
||||
from functools import reduce
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.serialization.bencoding import bencode, bdecode
|
||||
|
||||
REQUEST_TYPE = 0
|
||||
RESPONSE_TYPE = 1
|
||||
ERROR_TYPE = 2
|
||||
|
||||
|
||||
class KademliaDatagramBase:
|
||||
"""
|
||||
field names are used to unwrap/wrap the argument names to index integers that replace them in a datagram
|
||||
all packets have an argument dictionary when bdecoded starting with {0: <int>, 1: <bytes>, 2: <bytes>, ...}
|
||||
these correspond to the packet_type, rpc_id, and node_id args
|
||||
"""
|
||||
|
||||
fields = [
|
||||
'packet_type',
|
||||
'rpc_id',
|
||||
'node_id'
|
||||
]
|
||||
|
||||
expected_packet_type = -1
|
||||
|
||||
def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes):
|
||||
self.packet_type = packet_type
|
||||
if self.expected_packet_type != packet_type:
|
||||
raise ValueError(f"invalid packet type: {packet_type}, expected {self.expected_packet_type}")
|
||||
if len(rpc_id) != constants.rpc_id_length:
|
||||
raise ValueError(f"invalid rpc node_id: {len(rpc_id)} bytes (expected 20)")
|
||||
if not len(node_id) == constants.hash_length:
|
||||
raise ValueError(f"invalid node node_id: {len(node_id)} bytes (expected 48)")
|
||||
self.rpc_id = rpc_id
|
||||
self.node_id = node_id
|
||||
|
||||
def bencode(self) -> bytes:
|
||||
return bencode({
|
||||
i: getattr(self, k) for i, k in enumerate(self.fields)
|
||||
})
|
||||
|
||||
|
||||
class RequestDatagram(KademliaDatagramBase):
|
||||
fields = [
|
||||
'packet_type',
|
||||
'rpc_id',
|
||||
'node_id',
|
||||
'method',
|
||||
'args'
|
||||
]
|
||||
|
||||
expected_packet_type = REQUEST_TYPE
|
||||
|
||||
def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, method: bytes,
|
||||
args: typing.Optional[typing.List] = None):
|
||||
super().__init__(packet_type, rpc_id, node_id)
|
||||
self.method = method
|
||||
self.args = args or []
|
||||
if not self.args:
|
||||
self.args.append({})
|
||||
if isinstance(self.args[-1], dict):
|
||||
self.args[-1][b'protocolVersion'] = 1
|
||||
else:
|
||||
self.args.append({b'protocolVersion': 1})
|
||||
|
||||
@classmethod
|
||||
def make_ping(cls, from_node_id: bytes, rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
||||
if rpc_id and len(rpc_id) != constants.rpc_id_length:
|
||||
raise ValueError("invalid rpc id length")
|
||||
elif not rpc_id:
|
||||
rpc_id = constants.generate_id()[:constants.rpc_id_length]
|
||||
if len(from_node_id) != constants.hash_bits // 8:
|
||||
raise ValueError("invalid node id")
|
||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'ping')
|
||||
|
||||
@classmethod
|
||||
def make_store(cls, from_node_id: bytes, blob_hash: bytes, token: bytes, port: int,
|
||||
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
||||
if rpc_id and len(rpc_id) != constants.rpc_id_length:
|
||||
raise ValueError("invalid rpc id length")
|
||||
if not rpc_id:
|
||||
rpc_id = constants.generate_id()[:constants.rpc_id_length]
|
||||
if len(from_node_id) != constants.hash_bits // 8:
|
||||
raise ValueError("invalid node id")
|
||||
store_args = [blob_hash, token, port, from_node_id, 0]
|
||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'store', store_args)
|
||||
|
||||
@classmethod
|
||||
def make_find_node(cls, from_node_id: bytes, key: bytes,
|
||||
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
||||
if rpc_id and len(rpc_id) != constants.rpc_id_length:
|
||||
raise ValueError("invalid rpc id length")
|
||||
if not rpc_id:
|
||||
rpc_id = constants.generate_id()[:constants.rpc_id_length]
|
||||
if len(from_node_id) != constants.hash_bits // 8:
|
||||
raise ValueError("invalid node id")
|
||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findNode', [key])
|
||||
|
||||
@classmethod
|
||||
def make_find_value(cls, from_node_id: bytes, key: bytes,
|
||||
rpc_id: typing.Optional[bytes] = None) -> 'RequestDatagram':
|
||||
if rpc_id and len(rpc_id) != constants.rpc_id_length:
|
||||
raise ValueError("invalid rpc id length")
|
||||
if not rpc_id:
|
||||
rpc_id = constants.generate_id()[:constants.rpc_id_length]
|
||||
if len(from_node_id) != constants.hash_bits // 8:
|
||||
raise ValueError("invalid node id")
|
||||
return cls(REQUEST_TYPE, rpc_id, from_node_id, b'findValue', [key])
|
||||
|
||||
|
||||
class ResponseDatagram(KademliaDatagramBase):
|
||||
fields = [
|
||||
'packet_type',
|
||||
'rpc_id',
|
||||
'node_id',
|
||||
'response'
|
||||
]
|
||||
|
||||
expected_packet_type = RESPONSE_TYPE
|
||||
|
||||
def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, response):
|
||||
super().__init__(packet_type, rpc_id, node_id)
|
||||
self.response = response
|
||||
|
||||
|
||||
class ErrorDatagram(KademliaDatagramBase):
|
||||
fields = [
|
||||
'packet_type',
|
||||
'rpc_id',
|
||||
'node_id',
|
||||
'exception_type',
|
||||
'response',
|
||||
]
|
||||
|
||||
expected_packet_type = ERROR_TYPE
|
||||
|
||||
def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, exception_type: bytes, response: bytes):
|
||||
super().__init__(packet_type, rpc_id, node_id)
|
||||
self.exception_type = exception_type.decode()
|
||||
self.response = response.decode()
|
||||
|
||||
|
||||
def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
|
||||
msg_types = {
|
||||
REQUEST_TYPE: RequestDatagram,
|
||||
RESPONSE_TYPE: ResponseDatagram,
|
||||
ERROR_TYPE: ErrorDatagram
|
||||
}
|
||||
|
||||
primitive: typing.Dict = bdecode(datagram)
|
||||
if not isinstance(primitive, dict):
|
||||
raise ValueError("invalid datagram type")
|
||||
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
|
||||
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object
|
||||
else:
|
||||
raise ValueError("invalid datagram type")
|
||||
datagram_class = msg_types[datagram_type]
|
||||
return datagram_class(**{
|
||||
k: primitive[i] # pylint: disable=unsubscriptable-object
|
||||
for i, k in enumerate(datagram_class.fields)
|
||||
if i in primitive # pylint: disable=unsupported-membership-test
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_compact_ip(address: str):
|
||||
return reduce(lambda buff, x: buff + bytearray([int(x)]), address.split('.'), bytearray())
|
||||
|
||||
|
||||
def make_compact_address(node_id: bytes, address: str, port: int) -> bytearray:
|
||||
compact_ip = make_compact_ip(address)
|
||||
if not 0 <= port <= 65536:
|
||||
raise ValueError(f'Invalid port: {port}')
|
||||
return compact_ip + port.to_bytes(2, 'big') + node_id
|
||||
|
||||
|
||||
def decode_compact_address(compact_address: bytes) -> typing.Tuple[bytes, str, int]:
|
||||
address = "{}.{}.{}.{}".format(*compact_address[:4])
|
||||
port = int.from_bytes(compact_address[4:6], 'big')
|
||||
node_id = compact_address[6:]
|
||||
return node_id, address, port
|
|
@ -1,92 +0,0 @@
|
|||
import binascii
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer, task
|
||||
from lbrynet.extras.compat import f2d
|
||||
from lbrynet import utils
|
||||
from lbrynet.conf import Config
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DHTHashAnnouncer:
|
||||
def __init__(self, conf: Config, dht_node, storage, concurrent_announcers=None):
|
||||
self.conf = conf
|
||||
self.dht_node = dht_node
|
||||
self.storage = storage
|
||||
self.clock = dht_node.clock
|
||||
self.peer_port = dht_node.peerPort
|
||||
self.hash_queue = []
|
||||
if concurrent_announcers is None:
|
||||
self.concurrent_announcers = conf.concurrent_announcers
|
||||
else:
|
||||
self.concurrent_announcers = concurrent_announcers
|
||||
self._manage_lc = None
|
||||
if self.concurrent_announcers:
|
||||
self._manage_lc = task.LoopingCall(self.manage)
|
||||
self._manage_lc.clock = self.clock
|
||||
self.sem = defer.DeferredSemaphore(self.concurrent_announcers or conf.concurrent_announcers or 1)
|
||||
|
||||
def start(self):
|
||||
if self._manage_lc:
|
||||
self._manage_lc.start(30)
|
||||
|
||||
def stop(self):
|
||||
if self._manage_lc and self._manage_lc.running:
|
||||
self._manage_lc.stop()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_store(self, blob_hash):
|
||||
storing_node_ids = yield self.dht_node.announceHaveBlob(binascii.unhexlify(blob_hash))
|
||||
now = self.clock.seconds()
|
||||
if storing_node_ids:
|
||||
result = (now, storing_node_ids)
|
||||
yield f2d(self.storage.update_last_announced_blob(blob_hash, now))
|
||||
log.debug("Stored %s to %i peers", blob_hash[:16], len(storing_node_ids))
|
||||
else:
|
||||
result = (None, [])
|
||||
self.hash_queue.remove(blob_hash)
|
||||
defer.returnValue(result)
|
||||
|
||||
def hash_queue_size(self):
|
||||
return len(self.hash_queue)
|
||||
|
||||
def _show_announce_progress(self, size, start):
|
||||
queue_size = len(self.hash_queue)
|
||||
average_blobs_per_second = float(size - queue_size) / (self.clock.seconds() - start)
|
||||
log.info("Announced %i/%i blobs, %f blobs per second", size - queue_size, size, average_blobs_per_second)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def immediate_announce(self, blob_hashes):
|
||||
self.hash_queue.extend(b for b in blob_hashes if b not in self.hash_queue)
|
||||
log.info("Announcing %i blobs", len(self.hash_queue))
|
||||
start = self.clock.seconds()
|
||||
progress_lc = task.LoopingCall(self._show_announce_progress, len(self.hash_queue), start)
|
||||
progress_lc.clock = self.clock
|
||||
progress_lc.start(60, now=False)
|
||||
results = yield utils.DeferredDict(
|
||||
{blob_hash: self.sem.run(self.do_store, blob_hash) for blob_hash in blob_hashes}
|
||||
)
|
||||
now = self.clock.seconds()
|
||||
|
||||
progress_lc.stop()
|
||||
|
||||
announced_to = [blob_hash for blob_hash in results if results[blob_hash][0]]
|
||||
if len(announced_to) != len(results):
|
||||
log.debug("Failed to announce %i blobs", len(results) - len(announced_to))
|
||||
if announced_to:
|
||||
log.info('Took %s seconds to announce %i of %i attempted hashes (%f hashes per second)',
|
||||
now - start, len(announced_to), len(blob_hashes),
|
||||
int(float(len(blob_hashes)) / float(now - start)))
|
||||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def manage(self):
|
||||
if not self.dht_node.contacts:
|
||||
log.info("Not ready to start announcing hashes")
|
||||
return
|
||||
need_reannouncement = yield f2d(self.storage.get_blobs_to_announce())
|
||||
if need_reannouncement:
|
||||
yield self.immediate_announce(need_reannouncement)
|
||||
else:
|
||||
log.debug("Nothing to announce")
|
|
@ -1,70 +0,0 @@
|
|||
import binascii
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DummyPeerFinder:
|
||||
"""This class finds peers which have announced to the DHT that they have certain blobs"""
|
||||
|
||||
def find_peers_for_blob(self, blob_hash, timeout=None, filter_self=True):
|
||||
return defer.succeed([])
|
||||
|
||||
|
||||
class DHTPeerFinder(DummyPeerFinder):
|
||||
"""This class finds peers which have announced to the DHT that they have certain blobs"""
|
||||
#implements(IPeerFinder)
|
||||
|
||||
def __init__(self, component_manager):
|
||||
"""
|
||||
component_manager - an instance of ComponentManager
|
||||
"""
|
||||
self.component_manager = component_manager
|
||||
self.peer_manager = component_manager.peer_manager
|
||||
self.peers = {}
|
||||
self._ongoing_searchs = {}
|
||||
|
||||
def find_peers_for_blob(self, blob_hash, timeout=None, filter_self=True):
|
||||
"""
|
||||
Find peers for blob in the DHT
|
||||
blob_hash (str): blob hash to look for
|
||||
timeout (int): seconds to timeout after
|
||||
filter_self (bool): if True, and if a peer for a blob is itself, filter it
|
||||
from the result
|
||||
|
||||
Returns:
|
||||
list of peers for the blob
|
||||
"""
|
||||
if "dht" in self.component_manager.skip_components:
|
||||
return defer.succeed([])
|
||||
if not self.component_manager.all_components_running("dht"):
|
||||
return defer.succeed([])
|
||||
dht_node = self.component_manager.get_component("dht")
|
||||
|
||||
self.peers.setdefault(blob_hash, {(dht_node.externalIP, dht_node.peerPort,)})
|
||||
if not blob_hash in self._ongoing_searchs or self._ongoing_searchs[blob_hash].called:
|
||||
self._ongoing_searchs[blob_hash] = self._execute_peer_search(dht_node, blob_hash, timeout)
|
||||
|
||||
def _filter_self(blob_hash):
|
||||
my_host, my_port = dht_node.externalIP, dht_node.peerPort
|
||||
return {(host, port) for host, port in self.peers[blob_hash] if (host, port) != (my_host, my_port)}
|
||||
|
||||
peers = set(_filter_self(blob_hash) if filter_self else self.peers[blob_hash])
|
||||
return defer.succeed([self.peer_manager.get_peer(*peer) for peer in peers])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _execute_peer_search(self, dht_node, blob_hash, timeout):
|
||||
bin_hash = binascii.unhexlify(blob_hash)
|
||||
finished_deferred = dht_node.iterativeFindValue(bin_hash, exclude=self.peers[blob_hash])
|
||||
timeout = timeout or self.component_manager.conf.peer_search_timeout
|
||||
if timeout:
|
||||
finished_deferred.addTimeout(timeout, dht_node.clock)
|
||||
try:
|
||||
peer_list = yield finished_deferred
|
||||
self.peers[blob_hash].update({(host, port) for _, host, port in peer_list})
|
||||
except defer.TimeoutError:
|
||||
log.debug("DHT timed out while looking peers for blob %s after %s seconds", blob_hash, timeout)
|
||||
finally:
|
||||
del self._ongoing_searchs[blob_hash]
|
|
@ -1,14 +0,0 @@
|
|||
from lbrynet.p2p.Peer import Peer
|
||||
|
||||
|
||||
class PeerManager:
|
||||
def __init__(self):
|
||||
self.peers = []
|
||||
|
||||
def get_peer(self, host, port):
|
||||
for peer in self.peers:
|
||||
if peer.host == host and peer.port == port:
|
||||
return peer
|
||||
peer = Peer(host, port)
|
||||
self.peers.append(peer)
|
||||
return peer
|
|
@ -1,46 +0,0 @@
|
|||
import datetime
|
||||
from collections import defaultdict
|
||||
from lbrynet import utils
|
||||
|
||||
# Do not create this object except through PeerManager
|
||||
class Peer:
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
# If a peer is reported down, we wait till this time till reattempting connection
|
||||
self.attempt_connection_at = None
|
||||
# Number of times this Peer has been reported to be down, resets to 0 when it is up
|
||||
self.down_count = 0
|
||||
# Number of successful connections (with full protocol completion) to this peer
|
||||
self.success_count = 0
|
||||
self.score = 0
|
||||
self.stats = defaultdict(float) # {string stat_type, float count}
|
||||
|
||||
def is_available(self):
|
||||
if self.attempt_connection_at is None or utils.today() > self.attempt_connection_at:
|
||||
return True
|
||||
return False
|
||||
|
||||
def report_up(self):
|
||||
self.down_count = 0
|
||||
self.attempt_connection_at = None
|
||||
|
||||
def report_success(self):
|
||||
self.success_count += 1
|
||||
|
||||
def report_down(self):
|
||||
self.down_count += 1
|
||||
timeout_time = datetime.timedelta(seconds=60 * self.down_count)
|
||||
self.attempt_connection_at = utils.today() + timeout_time
|
||||
|
||||
def update_score(self, score_change):
|
||||
self.score += score_change
|
||||
|
||||
def update_stats(self, stat_type, count):
|
||||
self.stats[stat_type] += count
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.host}:{self.port}'
|
||||
|
||||
def __repr__(self):
|
||||
return f'Peer({self.host!r}, {self.port!r})'
|
83
tests/dht_mocks.py
Normal file
83
tests/dht_mocks.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
import typing
|
||||
import contextlib
|
||||
import socket
|
||||
import mock
|
||||
import functools
|
||||
import asyncio
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.dht.protocol.protocol import KademliaProtocol
|
||||
|
||||
|
||||
def get_time_accelerator(loop: asyncio.BaseEventLoop,
|
||||
now: typing.Optional[float] = None) -> typing.Callable[[float], typing.Awaitable[None]]:
|
||||
"""
|
||||
Returns an async advance() function
|
||||
|
||||
This provides a way to advance() the BaseEventLoop.time for the scheduled TimerHandles
|
||||
made by call_later, call_at, and call_soon.
|
||||
"""
|
||||
|
||||
_time = now or loop.time()
|
||||
loop.time = functools.wraps(loop.time)(lambda: _time)
|
||||
|
||||
async def accelerate_time(seconds: float) -> None:
|
||||
nonlocal _time
|
||||
if seconds < 0:
|
||||
raise ValueError('Cannot go back in time ({} seconds)'.format(seconds))
|
||||
_time += seconds
|
||||
await past_events()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def past_events() -> None:
|
||||
while loop._scheduled:
|
||||
timer: asyncio.TimerHandle = loop._scheduled[0]
|
||||
if timer not in loop._ready and timer._when <= _time:
|
||||
loop._scheduled.remove(timer)
|
||||
loop._ready.append(timer)
|
||||
if timer._when > _time:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def accelerator(seconds: float):
|
||||
steps = seconds * 10.0
|
||||
|
||||
for _ in range(max(int(steps), 1)):
|
||||
await accelerate_time(0.1)
|
||||
|
||||
return accelerator
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_network_loop(loop: asyncio.BaseEventLoop):
|
||||
dht_network: typing.Dict[typing.Tuple[str, int], 'KademliaProtocol'] = {}
|
||||
|
||||
async def create_datagram_endpoint(proto_lam: typing.Callable[[], 'KademliaProtocol'],
|
||||
from_addr: typing.Tuple[str, int]):
|
||||
def sendto(data, to_addr):
|
||||
rx = dht_network.get(to_addr)
|
||||
if rx and rx.external_ip:
|
||||
# print(f"{from_addr[0]}:{from_addr[1]} -{len(data)} bytes-> {rx.external_ip}:{rx.udp_port}")
|
||||
return rx.datagram_received(data, from_addr)
|
||||
|
||||
protocol = proto_lam()
|
||||
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
|
||||
transport.close = lambda: mock_sock.close()
|
||||
mock_sock.sendto = sendto
|
||||
transport.sendto = mock_sock.sendto
|
||||
protocol.connection_made(transport)
|
||||
dht_network[from_addr] = protocol
|
||||
return transport, protocol
|
||||
|
||||
with mock.patch('socket.socket') as mock_socket:
|
||||
mock_sock = mock.Mock(spec=socket.socket)
|
||||
mock_sock.setsockopt = lambda *_: None
|
||||
mock_sock.bind = lambda *_: None
|
||||
mock_sock.setblocking = lambda *_: None
|
||||
mock_sock.getsockname = lambda: "0.0.0.0"
|
||||
mock_sock.getpeername = lambda: ""
|
||||
mock_sock.close = lambda: None
|
||||
mock_sock.type = socket.SOCK_DGRAM
|
||||
mock_sock.fileno = lambda: 7
|
||||
mock_socket.return_value = mock_sock
|
||||
loop.create_datagram_endpoint = create_datagram_endpoint
|
||||
yield
|
0
tests/unit/dht/protocol/__init__.py
Normal file
0
tests/unit/dht/protocol/__init__.py
Normal file
91
tests/unit/dht/protocol/test_protocol.py
Normal file
91
tests/unit/dht/protocol/test_protocol.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
import asyncio
|
||||
from torba.testcase import AsyncioTestCase
|
||||
from tests import dht_mocks
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.protocol.protocol import KademliaProtocol
|
||||
from lbrynet.dht.peer import PeerManager
|
||||
|
||||
|
||||
class TestProtocol(AsyncioTestCase):
|
||||
async def test_ping(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
with dht_mocks.mock_network_loop(loop):
|
||||
node_id1 = constants.generate_id()
|
||||
peer1 = KademliaProtocol(
|
||||
loop, PeerManager(loop), node_id1, '1.2.3.4', 4444, 3333
|
||||
)
|
||||
peer2 = KademliaProtocol(
|
||||
loop, PeerManager(loop), constants.generate_id(), '1.2.3.5', 4444, 3333
|
||||
)
|
||||
await loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444))
|
||||
await loop.create_datagram_endpoint(lambda: peer2, ('1.2.3.5', 4444))
|
||||
|
||||
peer = peer2.peer_manager.get_kademlia_peer(node_id1, '1.2.3.4', udp_port=4444)
|
||||
result = await peer2.get_rpc_peer(peer).ping()
|
||||
self.assertEqual(result, b'pong')
|
||||
peer1.stop()
|
||||
peer2.stop()
|
||||
peer1.disconnect()
|
||||
peer2.disconnect()
|
||||
|
||||
async def test_update_token(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
with dht_mocks.mock_network_loop(loop):
|
||||
node_id1 = constants.generate_id()
|
||||
peer1 = KademliaProtocol(
|
||||
loop, PeerManager(loop), node_id1, '1.2.3.4', 4444, 3333
|
||||
)
|
||||
peer2 = KademliaProtocol(
|
||||
loop, PeerManager(loop), constants.generate_id(), '1.2.3.5', 4444, 3333
|
||||
)
|
||||
await loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444))
|
||||
await loop.create_datagram_endpoint(lambda: peer2, ('1.2.3.5', 4444))
|
||||
|
||||
peer = peer2.peer_manager.get_kademlia_peer(node_id1, '1.2.3.4', udp_port=4444)
|
||||
self.assertEqual(None, peer2.peer_manager.get_node_token(peer.node_id))
|
||||
await peer2.get_rpc_peer(peer).find_value(b'1' * 48)
|
||||
self.assertNotEqual(None, peer2.peer_manager.get_node_token(peer.node_id))
|
||||
peer1.stop()
|
||||
peer2.stop()
|
||||
peer1.disconnect()
|
||||
peer2.disconnect()
|
||||
|
||||
async def test_store_to_peer(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
with dht_mocks.mock_network_loop(loop):
|
||||
node_id1 = constants.generate_id()
|
||||
peer1 = KademliaProtocol(
|
||||
loop, PeerManager(loop), node_id1, '1.2.3.4', 4444, 3333
|
||||
)
|
||||
peer2 = KademliaProtocol(
|
||||
loop, PeerManager(loop), constants.generate_id(), '1.2.3.5', 4444, 3333
|
||||
)
|
||||
await loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444))
|
||||
await loop.create_datagram_endpoint(lambda: peer2, ('1.2.3.5', 4444))
|
||||
|
||||
peer = peer2.peer_manager.get_kademlia_peer(node_id1, '1.2.3.4', udp_port=4444)
|
||||
peer2_from_peer1 = peer1.peer_manager.get_kademlia_peer(
|
||||
peer2.node_id, peer2.external_ip, udp_port=peer2.udp_port
|
||||
)
|
||||
peer3 = peer1.peer_manager.get_kademlia_peer(
|
||||
constants.generate_id(), '1.2.3.6', udp_port=4444
|
||||
)
|
||||
store_result = await peer2.store_to_peer(b'2' * 48, peer)
|
||||
self.assertEqual(store_result[0], peer.node_id)
|
||||
self.assertEqual(True, store_result[1])
|
||||
self.assertEqual(True, peer1.data_store.has_peers_for_blob(b'2' * 48))
|
||||
self.assertEqual(False, peer1.data_store.has_peers_for_blob(b'3' * 48))
|
||||
self.assertListEqual([peer2_from_peer1], peer1.data_store.get_storing_contacts())
|
||||
|
||||
find_value_response = peer1.node_rpc.find_value(peer3, b'2' * 48)
|
||||
self.assertSetEqual(
|
||||
{b'2' * 48, b'token', b'protocolVersion'}, set(find_value_response.keys())
|
||||
)
|
||||
self.assertEqual(1, len(find_value_response[b'2' * 48]))
|
||||
self.assertEqual(find_value_response[b'2' * 48][0], peer2_from_peer1)
|
||||
# self.assertEqual(peer2_from_peer1.tcp_port, 3333)
|
||||
|
||||
peer1.stop()
|
||||
peer2.stop()
|
||||
peer1.disconnect()
|
||||
peer2.disconnect()
|
0
tests/unit/dht/routing/__init__.py
Normal file
0
tests/unit/dht/routing/__init__.py
Normal file
115
tests/unit/dht/routing/test_kbucket.py
Normal file
115
tests/unit/dht/routing/test_kbucket.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
import struct
|
||||
import asyncio
|
||||
from lbrynet.utils import generate_id
|
||||
from lbrynet.dht.protocol.routing_table import KBucket
|
||||
from lbrynet.dht.peer import PeerManager
|
||||
from lbrynet.dht import constants
|
||||
from torba.testcase import AsyncioTestCase
|
||||
|
||||
|
||||
def address_generator(address=(10, 42, 42, 1)):
|
||||
def increment(addr):
|
||||
value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]).encode())[0] + 1
|
||||
new_addr = []
|
||||
for i in range(4):
|
||||
new_addr.append(value % 256)
|
||||
value >>= 8
|
||||
return tuple(new_addr[::-1])
|
||||
|
||||
while True:
|
||||
yield "{}.{}.{}.{}".format(*address)
|
||||
address = increment(address)
|
||||
|
||||
|
||||
class TestKBucket(AsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.address_generator = address_generator()
|
||||
self.peer_manager = PeerManager(self.loop)
|
||||
self.kbucket = KBucket(self.peer_manager, 0, 2**constants.hash_bits, generate_id())
|
||||
|
||||
def test_add_peer(self):
|
||||
# Test if contacts can be added to empty list
|
||||
# Add k contacts to bucket
|
||||
for i in range(constants.k):
|
||||
peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444)
|
||||
self.assertTrue(self.kbucket.add_peer(peer))
|
||||
self.assertEqual(peer, self.kbucket.peers[i])
|
||||
|
||||
# Test if contact is not added to full list
|
||||
peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444)
|
||||
self.assertFalse(self.kbucket.add_peer(peer))
|
||||
|
||||
# Test if an existing contact is updated correctly if added again
|
||||
existing_peer = self.kbucket.peers[0]
|
||||
self.assertTrue(self.kbucket.add_peer(existing_peer))
|
||||
self.assertEqual(existing_peer, self.kbucket.peers[-1])
|
||||
|
||||
# def testGetContacts(self):
|
||||
# # try and get 2 contacts from empty list
|
||||
# result = self.kbucket.getContacts(2)
|
||||
# self.assertFalse(len(result) != 0, "Returned list should be empty; returned list length: %d" %
|
||||
# (len(result)))
|
||||
#
|
||||
# # Add k-2 contacts
|
||||
# node_ids = []
|
||||
# if constants.k >= 2:
|
||||
# for i in range(constants.k-2):
|
||||
# node_ids.append(generate_id())
|
||||
# tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
|
||||
# None)
|
||||
# self.kbucket.addContact(tmpContact)
|
||||
# else:
|
||||
# # add k contacts
|
||||
# for i in range(constants.k):
|
||||
# node_ids.append(generate_id())
|
||||
# tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
|
||||
# None)
|
||||
# self.kbucket.addContact(tmpContact)
|
||||
#
|
||||
# # try to get too many contacts
|
||||
# # requested count greater than bucket size; should return at most k contacts
|
||||
# contacts = self.kbucket.getContacts(constants.k+3)
|
||||
# self.assertTrue(len(contacts) <= constants.k,
|
||||
# 'Returned list should not have more than k entries!')
|
||||
#
|
||||
# # verify returned contacts in list
|
||||
# for node_id, i in zip(node_ids, range(constants.k-2)):
|
||||
# self.assertFalse(self.kbucket._contacts[i].id != node_id,
|
||||
# "Contact in position %s not same as added contact" % (str(i)))
|
||||
#
|
||||
# # try to get too many contacts
|
||||
# # requested count one greater than number of contacts
|
||||
# if constants.k >= 2:
|
||||
# result = self.kbucket.getContacts(constants.k-1)
|
||||
# self.assertFalse(len(result) != constants.k-2,
|
||||
# "Too many contacts in returned list %s - should be %s" %
|
||||
# (len(result), constants.k-2))
|
||||
# else:
|
||||
# result = self.kbucket.getContacts(constants.k-1)
|
||||
# # if the count is <= 0, it should return all of it's contats
|
||||
# self.assertFalse(len(result) != constants.k,
|
||||
# "Too many contacts in returned list %s - should be %s" %
|
||||
# (len(result), constants.k-2))
|
||||
# result = self.kbucket.getContacts(constants.k-3)
|
||||
# self.assertFalse(len(result) != constants.k-3,
|
||||
# "Too many contacts in returned list %s - should be %s" %
|
||||
# (len(result), constants.k-3))
|
||||
|
||||
def test_remove_peer(self):
|
||||
# try remove contact from empty list
|
||||
peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444)
|
||||
self.assertRaises(ValueError, self.kbucket.remove_peer, peer)
|
||||
|
||||
added = []
|
||||
# Add couple contacts
|
||||
for i in range(constants.k-2):
|
||||
peer = self.peer_manager.get_kademlia_peer(generate_id(), next(self.address_generator), 4444)
|
||||
self.assertTrue(self.kbucket.add_peer(peer))
|
||||
added.append(peer)
|
||||
|
||||
while added:
|
||||
peer = added.pop()
|
||||
self.assertIn(peer, self.kbucket.peers)
|
||||
self.kbucket.remove_peer(peer)
|
||||
self.assertNotIn(peer, self.kbucket.peers)
|
259
tests/unit/dht/routing/test_routing_table.py
Normal file
259
tests/unit/dht/routing/test_routing_table.py
Normal file
|
@ -0,0 +1,259 @@
|
|||
import asyncio
|
||||
from torba.testcase import AsyncioTestCase
|
||||
from tests import dht_mocks
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.node import Node
|
||||
from lbrynet.dht.peer import PeerManager
|
||||
|
||||
|
||||
class TestRouting(AsyncioTestCase):
|
||||
async def test_fill_one_bucket(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
peer_addresses = [
|
||||
(constants.generate_id(1), '1.2.3.1'),
|
||||
(constants.generate_id(2), '1.2.3.2'),
|
||||
(constants.generate_id(3), '1.2.3.3'),
|
||||
(constants.generate_id(4), '1.2.3.4'),
|
||||
(constants.generate_id(5), '1.2.3.5'),
|
||||
(constants.generate_id(6), '1.2.3.6'),
|
||||
(constants.generate_id(7), '1.2.3.7'),
|
||||
(constants.generate_id(8), '1.2.3.8'),
|
||||
(constants.generate_id(9), '1.2.3.9'),
|
||||
]
|
||||
with dht_mocks.mock_network_loop(loop):
|
||||
nodes = {
|
||||
i: Node(loop, PeerManager(loop), node_id, 4444, 4444, 3333, address)
|
||||
for i, (node_id, address) in enumerate(peer_addresses)
|
||||
}
|
||||
node_1 = nodes[0]
|
||||
contact_cnt = 0
|
||||
for i in range(1, len(peer_addresses)):
|
||||
self.assertEqual(len(node_1.protocol.routing_table.get_peers()), contact_cnt)
|
||||
node = nodes[i]
|
||||
peer = node_1.protocol.peer_manager.get_kademlia_peer(
|
||||
node.protocol.node_id, node.protocol.external_ip,
|
||||
udp_port=node.protocol.udp_port
|
||||
)
|
||||
added = await node_1.protocol.add_peer(peer)
|
||||
self.assertEqual(True, added)
|
||||
contact_cnt += 1
|
||||
|
||||
self.assertEqual(len(node_1.protocol.routing_table.get_peers()), 8)
|
||||
self.assertEqual(node_1.protocol.routing_table.buckets_with_contacts(), 1)
|
||||
for node in nodes.values():
|
||||
node.protocol.stop()
|
||||
|
||||
|
||||
# from binascii import hexlify, unhexlify
|
||||
#
|
||||
# from twisted.trial import unittest
|
||||
# from twisted.internet import defer
|
||||
# from lbrynet.dht import constants
|
||||
# from lbrynet.dht.routingtable import TreeRoutingTable
|
||||
# from lbrynet.dht.contact import ContactManager
|
||||
# from lbrynet.dht.distance import Distance
|
||||
# from lbrynet.utils import generate_id
|
||||
#
|
||||
#
|
||||
# class FakeRPCProtocol:
|
||||
# """ Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """
|
||||
# def sendRPC(self, *args, **kwargs):
|
||||
# return defer.succeed(None)
|
||||
#
|
||||
#
|
||||
# class TreeRoutingTableTest(unittest.TestCase):
|
||||
# """ Test case for the RoutingTable class """
|
||||
# def setUp(self):
|
||||
# self.contact_manager = ContactManager()
|
||||
# self.nodeID = generate_id(b'node1')
|
||||
# self.protocol = FakeRPCProtocol()
|
||||
# self.routingTable = TreeRoutingTable(self.nodeID)
|
||||
#
|
||||
# def test_distance(self):
|
||||
# """ Test to see if distance method returns correct result"""
|
||||
# d = Distance(bytes((170,) * 48))
|
||||
# result = d(bytes((85,) * 48))
|
||||
# expected = int(hexlify(bytes((255,) * 48)), 16)
|
||||
# self.assertEqual(result, expected)
|
||||
#
|
||||
# @defer.inlineCallbacks
|
||||
# def test_add_contact(self):
|
||||
# """ Tests if a contact can be added and retrieved correctly """
|
||||
# # Create the contact
|
||||
# contact_id = generate_id(b'node2')
|
||||
# contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
|
||||
# # Now add it...
|
||||
# yield self.routingTable.addContact(contact)
|
||||
# # ...and request the closest nodes to it (will retrieve it)
|
||||
# closest_nodes = self.routingTable.findCloseNodes(contact_id)
|
||||
# self.assertEqual(len(closest_nodes), 1)
|
||||
# self.assertIn(contact, closest_nodes)
|
||||
#
|
||||
# @defer.inlineCallbacks
|
||||
# def test_get_contact(self):
|
||||
# """ Tests if a specific existing contact can be retrieved correctly """
|
||||
# contact_id = generate_id(b'node2')
|
||||
# contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
|
||||
# # Now add it...
|
||||
# yield self.routingTable.addContact(contact)
|
||||
# # ...and get it again
|
||||
# same_contact = self.routingTable.getContact(contact_id)
|
||||
# self.assertEqual(contact, same_contact, 'getContact() should return the same contact')
|
||||
#
|
||||
# @defer.inlineCallbacks
|
||||
# def test_add_parent_node_as_contact(self):
|
||||
# """
|
||||
# Tests the routing table's behaviour when attempting to add its parent node as a contact
|
||||
# """
|
||||
# # Create a contact with the same ID as the local node's ID
|
||||
# contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol)
|
||||
# # Now try to add it
|
||||
# yield self.routingTable.addContact(contact)
|
||||
# # ...and request the closest nodes to it using FIND_NODE
|
||||
# closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
|
||||
# self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact')
|
||||
#
|
||||
# @defer.inlineCallbacks
|
||||
# def test_remove_contact(self):
|
||||
# """ Tests contact removal """
|
||||
# # Create the contact
|
||||
# contact_id = generate_id(b'node2')
|
||||
# contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
|
||||
# # Now add it...
|
||||
# yield self.routingTable.addContact(contact)
|
||||
# # Verify addition
|
||||
# self.assertEqual(len(self.routingTable._buckets[0]), 1, 'Contact not added properly')
|
||||
# # Now remove it
|
||||
# self.routingTable.removeContact(contact)
|
||||
# self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly')
|
||||
#
|
||||
# @defer.inlineCallbacks
|
||||
# def test_split_bucket(self):
|
||||
# """ Tests if the the routing table correctly dynamically splits k-buckets """
|
||||
# self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384,
|
||||
# 'Initial k-bucket range should be 0 <= range < 2**384')
|
||||
# # Add k contacts
|
||||
# for i in range(constants.k):
|
||||
# node_id = generate_id(b'remote node %d' % i)
|
||||
# contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
|
||||
# yield self.routingTable.addContact(contact)
|
||||
#
|
||||
# self.assertEqual(len(self.routingTable._buckets), 1,
|
||||
# 'Only k nodes have been added; the first k-bucket should now '
|
||||
# 'be full, but should not yet be split')
|
||||
# # Now add 1 more contact
|
||||
# node_id = generate_id(b'yet another remote node')
|
||||
# contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
|
||||
# yield self.routingTable.addContact(contact)
|
||||
# self.assertEqual(len(self.routingTable._buckets), 2,
|
||||
# 'k+1 nodes have been added; the first k-bucket should have been '
|
||||
# 'split into two new buckets')
|
||||
# self.assertNotEqual(self.routingTable._buckets[0].rangeMax, 2**384,
|
||||
# 'K-bucket was split, but its range was not properly adjusted')
|
||||
# self.assertEqual(self.routingTable._buckets[1].rangeMax, 2**384,
|
||||
# 'K-bucket was split, but the second (new) bucket\'s '
|
||||
# 'max range was not set properly')
|
||||
# self.assertEqual(self.routingTable._buckets[0].rangeMax,
|
||||
# self.routingTable._buckets[1].rangeMin,
|
||||
# 'K-bucket was split, but the min/max ranges were '
|
||||
# 'not divided properly')
|
||||
#
|
||||
# @defer.inlineCallbacks
|
||||
# def test_full_split(self):
|
||||
# """
|
||||
# Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
|
||||
# """
|
||||
#
|
||||
# self.routingTable._parentNodeID = bytes(48 * b'\xff')
|
||||
#
|
||||
# node_ids = [
|
||||
# b"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
# b"200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
# b"300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
# b"400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
# b"500000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
# b"600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
# b"700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
# b"800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
# b"ff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
# b"010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
# ]
|
||||
#
|
||||
# # Add k contacts
|
||||
# for nodeID in node_ids:
|
||||
# # self.assertEquals(nodeID, node_ids[i].decode('hex'))
|
||||
# contact = self.contact_manager.make_contact(unhexlify(nodeID), '127.0.0.1', 9182, self.protocol)
|
||||
# yield self.routingTable.addContact(contact)
|
||||
# self.assertEqual(len(self.routingTable._buckets), 2)
|
||||
# self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
|
||||
# self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)
|
||||
#
|
||||
# # try adding a contact who is further from us than the k'th known contact
|
||||
# nodeID = b'020000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'
|
||||
# nodeID = unhexlify(nodeID)
|
||||
# contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
|
||||
# self.assertFalse(self.routingTable._shouldSplit(self.routingTable._kbucketIndex(contact.id), contact.id))
|
||||
# yield self.routingTable.addContact(contact)
|
||||
# self.assertEqual(len(self.routingTable._buckets), 2)
|
||||
# self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
|
||||
# self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)
|
||||
# self.assertNotIn(contact, self.routingTable._buckets[0]._contacts)
|
||||
# self.assertNotIn(contact, self.routingTable._buckets[1]._contacts)
|
||||
#
|
||||
|
||||
# class KeyErrorFixedTest(unittest.TestCase):
|
||||
# """ Basic tests case for boolean operators on the Contact class """
|
||||
#
|
||||
# def setUp(self):
|
||||
# own_id = (2 ** constants.key_bits) - 1
|
||||
# # carefully chosen own_id. here's the logic
|
||||
# # we want a bunch of buckets (k+1, to be exact), and we want to make sure own_id
|
||||
# # is not in bucket 0. so we put own_id at the end so we can keep splitting by adding to the
|
||||
# # end
|
||||
#
|
||||
# self.table = lbrynet.dht.routingtable.OptimizedTreeRoutingTable(own_id)
|
||||
#
|
||||
# def fill_bucket(self, bucket_min):
|
||||
# bucket_size = lbrynet.dht.constants.k
|
||||
# for i in range(bucket_min, bucket_min + bucket_size):
|
||||
# self.table.addContact(lbrynet.dht.contact.Contact(long(i), '127.0.0.1', 9999, None))
|
||||
#
|
||||
# def overflow_bucket(self, bucket_min):
|
||||
# bucket_size = lbrynet.dht.constants.k
|
||||
# self.fill_bucket(bucket_min)
|
||||
# self.table.addContact(
|
||||
# lbrynet.dht.contact.Contact(long(bucket_min + bucket_size + 1),
|
||||
# '127.0.0.1', 9999, None))
|
||||
#
|
||||
# def testKeyError(self):
|
||||
#
|
||||
# # find middle, so we know where bucket will split
|
||||
# bucket_middle = self.table._buckets[0].rangeMax / 2
|
||||
#
|
||||
# # fill last bucket
|
||||
# self.fill_bucket(self.table._buckets[0].rangeMax - lbrynet.dht.constants.k - 1)
|
||||
# # -1 in previous line because own_id is in last bucket
|
||||
#
|
||||
# # fill/overflow 7 more buckets
|
||||
# bucket_start = 0
|
||||
# for i in range(0, lbrynet.dht.constants.k):
|
||||
# self.overflow_bucket(bucket_start)
|
||||
# bucket_start += bucket_middle / (2 ** i)
|
||||
#
|
||||
# # replacement cache now has k-1 entries.
|
||||
# # adding one more contact to bucket 0 used to cause a KeyError, but it should work
|
||||
# self.table.addContact(
|
||||
# lbrynet.dht.contact.Contact(long(lbrynet.dht.constants.k + 2), '127.0.0.1', 9999, None))
|
||||
#
|
||||
# # import math
|
||||
# # print ""
|
||||
# # for i, bucket in enumerate(self.table._buckets):
|
||||
# # print "Bucket " + str(i) + " (2 ** " + str(
|
||||
# # math.log(bucket.rangeMin, 2) if bucket.rangeMin > 0 else 0) + " <= x < 2 ** "+str(
|
||||
# # math.log(bucket.rangeMax, 2)) + ")"
|
||||
# # for c in bucket.getContacts():
|
||||
# # print " contact " + str(c.id)
|
||||
# # for key, bucket in self.table._replacementCache.items():
|
||||
# # print "Replacement Cache for Bucket " + str(key)
|
||||
# # for c in bucket:
|
||||
# # print " contact " + str(c.id)
|
0
tests/unit/dht/serialization/__init__.py
Normal file
0
tests/unit/dht/serialization/__init__.py
Normal file
64
tests/unit/dht/serialization/test_bencoding.py
Normal file
64
tests/unit/dht/serialization/test_bencoding.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import unittest
|
||||
from lbrynet.dht.serialization.bencoding import _bencode, bencode, bdecode, DecodeError
|
||||
|
||||
|
||||
class EncodeDecodeTest(unittest.TestCase):
|
||||
def test_fail_with_not_dict(self):
|
||||
with self.assertRaises(TypeError):
|
||||
bencode(1)
|
||||
with self.assertRaises(TypeError):
|
||||
bencode(b'derp')
|
||||
with self.assertRaises(TypeError):
|
||||
bencode('derp')
|
||||
with self.assertRaises(TypeError):
|
||||
bencode([b'derp'])
|
||||
with self.assertRaises(TypeError):
|
||||
bencode([object()])
|
||||
with self.assertRaises(TypeError):
|
||||
bencode({b'derp': object()})
|
||||
|
||||
def test_fail_bad_type(self):
|
||||
with self.assertRaises(DecodeError):
|
||||
bdecode(b'd4le', True)
|
||||
|
||||
def test_integer(self):
|
||||
self.assertEqual(_bencode(42), b'i42e')
|
||||
self.assertEqual(bdecode(b'i42e', True), 42)
|
||||
|
||||
def test_bytes(self):
|
||||
self.assertEqual(_bencode(b''), b'0:')
|
||||
self.assertEqual(_bencode(b'spam'), b'4:spam')
|
||||
self.assertEqual(_bencode(b'4:spam'), b'6:4:spam')
|
||||
self.assertEqual(_bencode(bytearray(b'spam')), b'4:spam')
|
||||
|
||||
self.assertEqual(bdecode(b'0:', True), b'')
|
||||
self.assertEqual(bdecode(b'4:spam', True), b'spam')
|
||||
self.assertEqual(bdecode(b'6:4:spam', True), b'4:spam')
|
||||
|
||||
def test_string(self):
|
||||
self.assertEqual(_bencode(''), b'0:')
|
||||
self.assertEqual(_bencode('spam'), b'4:spam')
|
||||
self.assertEqual(_bencode('4:spam'), b'6:4:spam')
|
||||
|
||||
def test_list(self):
|
||||
self.assertEqual(_bencode([b'spam', 42]), b'l4:spami42ee')
|
||||
self.assertEqual(bdecode(b'l4:spami42ee', True), [b'spam', 42])
|
||||
|
||||
def test_dict(self):
|
||||
self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee')
|
||||
self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'})
|
||||
|
||||
def test_mixed(self):
|
||||
self.assertEqual(_bencode(
|
||||
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]),
|
||||
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'
|
||||
)
|
||||
|
||||
self.assertEqual(bdecode(
|
||||
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee', True),
|
||||
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]
|
||||
)
|
||||
|
||||
def test_decode_error(self):
|
||||
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz', True)
|
||||
self.assertRaises(DecodeError, bdecode, b'', True)
|
60
tests/unit/dht/serialization/test_datagram.py
Normal file
60
tests/unit/dht/serialization/test_datagram.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
import unittest
|
||||
from lbrynet.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram
|
||||
from lbrynet.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE
|
||||
|
||||
|
||||
class TestDatagram(unittest.TestCase):
|
||||
def test_ping_request_datagram(self):
|
||||
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'ping', []).bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertEqual(decoded.method, b'ping')
|
||||
self.assertListEqual(decoded.args, [{b'protocolVersion': 1}])
|
||||
|
||||
def test_ping_response(self):
|
||||
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, b'pong').bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
|
||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertEqual(decoded.response, b'pong')
|
||||
|
||||
def test_find_node_request_datagram(self):
|
||||
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findNode', [b'2' * 48]).bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertEqual(decoded.method, b'findNode')
|
||||
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}])
|
||||
|
||||
def test_find_node_response(self):
|
||||
closest_response = [(b'3' * 48, '1.2.3.4', 1234)]
|
||||
expected = [[b'3' * 48, b'1.2.3.4', 1234]]
|
||||
|
||||
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, closest_response).bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
|
||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertEqual(decoded.response, expected)
|
||||
|
||||
def test_find_value_request(self):
|
||||
serialized = RequestDatagram(REQUEST_TYPE, b'1' * 20, b'1' * 48, b'findValue', [b'2' * 48]).bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
self.assertEqual(decoded.packet_type, REQUEST_TYPE)
|
||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertEqual(decoded.method, b'findValue')
|
||||
self.assertListEqual(decoded.args, [b'2' * 48, {b'protocolVersion': 1}])
|
||||
|
||||
def test_find_value_response(self):
|
||||
found_value_response = {b'2' * 48: [b'\x7f\x00\x00\x01']}
|
||||
serialized = ResponseDatagram(RESPONSE_TYPE, b'1' * 20, b'1' * 48, found_value_response).bencode()
|
||||
decoded = decode_datagram(serialized)
|
||||
self.assertEqual(decoded.packet_type, RESPONSE_TYPE)
|
||||
self.assertEqual(decoded.rpc_id, b'1' * 20)
|
||||
self.assertEqual(decoded.node_id, b'1' * 48)
|
||||
self.assertDictEqual(decoded.response, found_value_response)
|
88
tests/unit/dht/test_async_gen_junction.py
Normal file
88
tests/unit/dht/test_async_gen_junction.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import asyncio
|
||||
from torba.testcase import AsyncioTestCase
|
||||
from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction
|
||||
|
||||
|
||||
class MockAsyncGen:
|
||||
def __init__(self, loop, result, delay, stop_cnt=10):
|
||||
self.loop = loop
|
||||
self.result = result
|
||||
self.delay = delay
|
||||
self.count = 0
|
||||
self.stop_cnt = stop_cnt
|
||||
self.called_close = False
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.count > self.stop_cnt - 1:
|
||||
raise StopAsyncIteration()
|
||||
self.count += 1
|
||||
await asyncio.sleep(self.delay, loop=self.loop)
|
||||
return self.result
|
||||
|
||||
async def aclose(self):
|
||||
self.called_close = True
|
||||
|
||||
|
||||
class TestAsyncGeneratorJunction(AsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
async def _test_junction(self, expected, *generators):
|
||||
order = []
|
||||
async with AsyncGeneratorJunction(self.loop) as junction:
|
||||
for generator in generators:
|
||||
junction.add_generator(generator)
|
||||
async for item in junction:
|
||||
order.append(item)
|
||||
self.assertListEqual(order, expected)
|
||||
|
||||
async def test_yield_order(self):
|
||||
expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2]
|
||||
fast_gen = MockAsyncGen(self.loop, 1, 0.1)
|
||||
slow_gen = MockAsyncGen(self.loop, 2, 0.2)
|
||||
await self._test_junction(expected_order, fast_gen, slow_gen)
|
||||
self.assertEqual(fast_gen.called_close, True)
|
||||
self.assertEqual(slow_gen.called_close, True)
|
||||
|
||||
async def test_one_stopped_first(self):
|
||||
expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]
|
||||
fast_gen = MockAsyncGen(self.loop, 1, 0.1, 5)
|
||||
slow_gen = MockAsyncGen(self.loop, 2, 0.2)
|
||||
await self._test_junction(expected_order, fast_gen, slow_gen)
|
||||
self.assertEqual(fast_gen.called_close, True)
|
||||
self.assertEqual(slow_gen.called_close, True)
|
||||
|
||||
async def test_with_non_async_gen_class(self):
|
||||
expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]
|
||||
|
||||
async def fast_gen():
|
||||
for i in range(10):
|
||||
if i == 5:
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
yield 1
|
||||
|
||||
slow_gen = MockAsyncGen(self.loop, 2, 0.2)
|
||||
await self._test_junction(expected_order, fast_gen(), slow_gen)
|
||||
self.assertEqual(slow_gen.called_close, True)
|
||||
|
||||
async def test_stop_when_encapsulating_task_cancelled(self):
|
||||
fast_gen = MockAsyncGen(self.loop, 1, 0.1)
|
||||
slow_gen = MockAsyncGen(self.loop, 2, 0.2)
|
||||
|
||||
async def _task():
|
||||
async with AsyncGeneratorJunction(self.loop) as junction:
|
||||
junction.add_generator(fast_gen)
|
||||
junction.add_generator(slow_gen)
|
||||
async for _ in junction:
|
||||
pass
|
||||
|
||||
task = self.loop.create_task(_task())
|
||||
self.loop.call_later(0.5, task.cancel)
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await task
|
||||
self.assertEqual(fast_gen.called_close, True)
|
||||
self.assertEqual(slow_gen.called_close, True)
|
|
@ -1,50 +0,0 @@
|
|||
from twisted.trial import unittest
|
||||
from lbrynet.dht.encoding import bencode, bdecode, DecodeError
|
||||
|
||||
|
||||
class EncodeDecodeTest(unittest.TestCase):
|
||||
|
||||
def test_integer(self):
|
||||
self.assertEqual(bencode(42), b'i42e')
|
||||
|
||||
self.assertEqual(bdecode(b'i42e'), 42)
|
||||
|
||||
def test_bytes(self):
|
||||
self.assertEqual(bencode(b''), b'0:')
|
||||
self.assertEqual(bencode(b'spam'), b'4:spam')
|
||||
self.assertEqual(bencode(b'4:spam'), b'6:4:spam')
|
||||
self.assertEqual(bencode(bytearray(b'spam')), b'4:spam')
|
||||
|
||||
self.assertEqual(bdecode(b'0:'), b'')
|
||||
self.assertEqual(bdecode(b'4:spam'), b'spam')
|
||||
self.assertEqual(bdecode(b'6:4:spam'), b'4:spam')
|
||||
|
||||
def test_string(self):
|
||||
self.assertEqual(bencode(''), b'0:')
|
||||
self.assertEqual(bencode('spam'), b'4:spam')
|
||||
self.assertEqual(bencode('4:spam'), b'6:4:spam')
|
||||
|
||||
def test_list(self):
|
||||
self.assertEqual(bencode([b'spam', 42]), b'l4:spami42ee')
|
||||
|
||||
self.assertEqual(bdecode(b'l4:spami42ee'), [b'spam', 42])
|
||||
|
||||
def test_dict(self):
|
||||
self.assertEqual(bencode({b'foo': 42, b'bar': b'spam'}), b'd3:bar4:spam3:fooi42ee')
|
||||
|
||||
self.assertEqual(bdecode(b'd3:bar4:spam3:fooi42ee'), {b'foo': 42, b'bar': b'spam'})
|
||||
|
||||
def test_mixed(self):
|
||||
self.assertEqual(bencode(
|
||||
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]),
|
||||
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'
|
||||
)
|
||||
|
||||
self.assertEqual(bdecode(
|
||||
b'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'),
|
||||
[[b'abc', b'127.0.0.1', 1919], [b'def', b'127.0.0.1', 1921]]
|
||||
)
|
||||
|
||||
def test_decode_error(self):
|
||||
self.assertRaises(DecodeError, bdecode, b'abcdefghijklmnopqrstuvwxyz')
|
||||
self.assertRaises(DecodeError, bdecode, b'')
|
|
@ -1,59 +0,0 @@
|
|||
from twisted.trial import unittest
|
||||
from twisted.internet import defer, task
|
||||
from lbrynet import utils
|
||||
from lbrynet.conf import Config
|
||||
from lbrynet.extras.daemon.HashAnnouncer import DHTHashAnnouncer
|
||||
from tests.test_utils import random_lbry_hash
|
||||
|
||||
|
||||
class MocDHTNode:
|
||||
def __init__(self):
|
||||
self.blobs_announced = 0
|
||||
self.clock = task.Clock()
|
||||
self.peerPort = 3333
|
||||
|
||||
def announceHaveBlob(self, blob):
|
||||
self.blobs_announced += 1
|
||||
d = defer.Deferred()
|
||||
self.clock.callLater(1, d.callback, ['fake'])
|
||||
return d
|
||||
|
||||
|
||||
class MocStorage:
|
||||
def __init__(self, blobs_to_announce):
|
||||
self.blobs_to_announce = blobs_to_announce
|
||||
self.announced = False
|
||||
|
||||
def get_blobs_to_announce(self):
|
||||
if not self.announced:
|
||||
self.announced = True
|
||||
return defer.succeed(self.blobs_to_announce)
|
||||
else:
|
||||
return defer.succeed([])
|
||||
|
||||
def update_last_announced_blob(self, blob_hash, now):
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
class DHTHashAnnouncerTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
conf = Config()
|
||||
self.num_blobs = 10
|
||||
self.blobs_to_announce = []
|
||||
for i in range(0, self.num_blobs):
|
||||
self.blobs_to_announce.append(random_lbry_hash())
|
||||
self.dht_node = MocDHTNode()
|
||||
self.clock = self.dht_node.clock
|
||||
utils.call_later = self.clock.callLater
|
||||
self.storage = MocStorage(self.blobs_to_announce)
|
||||
self.announcer = DHTHashAnnouncer(conf, self.dht_node, self.storage)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_immediate_announce(self):
|
||||
announce_d = self.announcer.immediate_announce(self.blobs_to_announce)
|
||||
self.assertEqual(self.announcer.hash_queue_size(), self.num_blobs)
|
||||
self.clock.advance(1)
|
||||
yield announce_d
|
||||
self.assertEqual(self.dht_node.blobs_announced, self.num_blobs)
|
||||
self.assertEqual(self.announcer.hash_queue_size(), 0)
|
|
@ -1,125 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# This library is free software, distributed under the terms of
|
||||
# the GNU Lesser General Public License Version 3, or any later version.
|
||||
# See the COPYING file included in this archive
|
||||
|
||||
from twisted.trial import unittest
|
||||
import struct
|
||||
from lbrynet.utils import generate_id
|
||||
from lbrynet.dht import kbucket
|
||||
from lbrynet.dht.contact import ContactManager
|
||||
from lbrynet.dht import constants
|
||||
|
||||
|
||||
def address_generator(address=(10, 42, 42, 1)):
|
||||
def increment(addr):
|
||||
value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]).encode())[0] + 1
|
||||
new_addr = []
|
||||
for i in range(4):
|
||||
new_addr.append(value % 256)
|
||||
value >>= 8
|
||||
return tuple(new_addr[::-1])
|
||||
|
||||
while True:
|
||||
yield "{}.{}.{}.{}".format(*address)
|
||||
address = increment(address)
|
||||
|
||||
|
||||
class KBucketTest(unittest.TestCase):
|
||||
""" Test case for the KBucket class """
|
||||
def setUp(self):
|
||||
self.address_generator = address_generator()
|
||||
self.contact_manager = ContactManager()
|
||||
self.kbucket = kbucket.KBucket(0, 2**constants.key_bits, generate_id())
|
||||
|
||||
def testAddContact(self):
|
||||
""" Tests if the bucket handles contact additions/updates correctly """
|
||||
# Test if contacts can be added to empty list
|
||||
# Add k contacts to bucket
|
||||
for i in range(constants.k):
|
||||
tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
|
||||
self.kbucket.addContact(tmpContact)
|
||||
self.assertEqual(
|
||||
self.kbucket._contacts[i],
|
||||
tmpContact,
|
||||
"Contact in position %d not the same as the newly-added contact" % i)
|
||||
|
||||
# Test if contact is not added to full list
|
||||
tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
|
||||
self.assertRaises(kbucket.BucketFull, self.kbucket.addContact, tmpContact)
|
||||
|
||||
# Test if an existing contact is updated correctly if added again
|
||||
existingContact = self.kbucket._contacts[0]
|
||||
self.kbucket.addContact(existingContact)
|
||||
self.assertEqual(
|
||||
self.kbucket._contacts.index(existingContact),
|
||||
len(self.kbucket._contacts)-1,
|
||||
'Contact not correctly updated; it should be at the end of the list of contacts')
|
||||
|
||||
def testGetContacts(self):
|
||||
# try and get 2 contacts from empty list
|
||||
result = self.kbucket.getContacts(2)
|
||||
self.assertFalse(len(result) != 0, "Returned list should be empty; returned list length: %d" %
|
||||
(len(result)))
|
||||
|
||||
|
||||
# Add k-2 contacts
|
||||
node_ids = []
|
||||
if constants.k >= 2:
|
||||
for i in range(constants.k-2):
|
||||
node_ids.append(generate_id())
|
||||
tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
|
||||
None)
|
||||
self.kbucket.addContact(tmpContact)
|
||||
else:
|
||||
# add k contacts
|
||||
for i in range(constants.k):
|
||||
node_ids.append(generate_id())
|
||||
tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
|
||||
None)
|
||||
self.kbucket.addContact(tmpContact)
|
||||
|
||||
# try to get too many contacts
|
||||
# requested count greater than bucket size; should return at most k contacts
|
||||
contacts = self.kbucket.getContacts(constants.k+3)
|
||||
self.assertTrue(len(contacts) <= constants.k,
|
||||
'Returned list should not have more than k entries!')
|
||||
|
||||
# verify returned contacts in list
|
||||
for node_id, i in zip(node_ids, range(constants.k-2)):
|
||||
self.assertFalse(self.kbucket._contacts[i].id != node_id,
|
||||
"Contact in position %s not same as added contact" % (str(i)))
|
||||
|
||||
# try to get too many contacts
|
||||
# requested count one greater than number of contacts
|
||||
if constants.k >= 2:
|
||||
result = self.kbucket.getContacts(constants.k-1)
|
||||
self.assertFalse(len(result) != constants.k-2,
|
||||
"Too many contacts in returned list %s - should be %s" %
|
||||
(len(result), constants.k-2))
|
||||
else:
|
||||
result = self.kbucket.getContacts(constants.k-1)
|
||||
# if the count is <= 0, it should return all of it's contats
|
||||
self.assertFalse(len(result) != constants.k,
|
||||
"Too many contacts in returned list %s - should be %s" %
|
||||
(len(result), constants.k-2))
|
||||
result = self.kbucket.getContacts(constants.k-3)
|
||||
self.assertFalse(len(result) != constants.k-3,
|
||||
"Too many contacts in returned list %s - should be %s" %
|
||||
(len(result), constants.k-3))
|
||||
|
||||
def testRemoveContact(self):
|
||||
# try remove contact from empty list
|
||||
rmContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
|
||||
self.assertRaises(ValueError, self.kbucket.removeContact, rmContact)
|
||||
|
||||
# Add couple contacts
|
||||
for i in range(constants.k-2):
|
||||
tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
|
||||
self.kbucket.addContact(tmpContact)
|
||||
|
||||
# try remove contact from empty list
|
||||
self.kbucket.addContact(rmContact)
|
||||
result = self.kbucket.removeContact(rmContact)
|
||||
self.assertNotIn(rmContact, self.kbucket._contacts, "Could not remove contact from bucket")
|
|
@ -1,84 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# This library is free software, distributed under the terms of
|
||||
# the GNU Lesser General Public License Version 3, or any later version.
|
||||
# See the COPYING file included in this archive
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
from lbrynet.dht.msgtypes import RequestMessage, ResponseMessage, ErrorMessage
|
||||
from lbrynet.dht.msgformat import MessageTranslator, DefaultFormat
|
||||
|
||||
|
||||
class DefaultFormatTranslatorTest(unittest.TestCase):
|
||||
""" Test case for the default message translator """
|
||||
def setUp(self):
|
||||
self.cases = ((RequestMessage('1' * 48, 'rpcMethod',
|
||||
{'arg1': 'a string', 'arg2': 123}, '1' * 20),
|
||||
{DefaultFormat.headerType: DefaultFormat.typeRequest,
|
||||
DefaultFormat.headerNodeID: '1' * 48,
|
||||
DefaultFormat.headerMsgID: '1' * 20,
|
||||
DefaultFormat.headerPayload: 'rpcMethod',
|
||||
DefaultFormat.headerArgs: {'arg1': 'a string', 'arg2': 123}}),
|
||||
|
||||
(ResponseMessage('2' * 20, '2' * 48, 'response'),
|
||||
{DefaultFormat.headerType: DefaultFormat.typeResponse,
|
||||
DefaultFormat.headerNodeID: '2' * 48,
|
||||
DefaultFormat.headerMsgID: '2' * 20,
|
||||
DefaultFormat.headerPayload: 'response'}),
|
||||
|
||||
(ErrorMessage('3' * 20, '3' * 48,
|
||||
"<type 'exceptions.ValueError'>", 'this is a test exception'),
|
||||
{DefaultFormat.headerType: DefaultFormat.typeError,
|
||||
DefaultFormat.headerNodeID: '3' * 48,
|
||||
DefaultFormat.headerMsgID: '3' * 20,
|
||||
DefaultFormat.headerPayload: "<type 'exceptions.ValueError'>",
|
||||
DefaultFormat.headerArgs: 'this is a test exception'}),
|
||||
|
||||
(ResponseMessage(
|
||||
'4' * 20, '4' * 48,
|
||||
[('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82',
|
||||
'127.0.0.1', 1919),
|
||||
('\xae\x9ey\x93\xdd\xeb\xf1^\xff\xc5\x0f\xf8\xac!\x0e\x03\x9fY@{',
|
||||
'127.0.0.1', 1921)]),
|
||||
{DefaultFormat.headerType: DefaultFormat.typeResponse,
|
||||
DefaultFormat.headerNodeID: '4' * 48,
|
||||
DefaultFormat.headerMsgID: '4' * 20,
|
||||
DefaultFormat.headerPayload:
|
||||
[('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82',
|
||||
'127.0.0.1', 1919),
|
||||
('\xae\x9ey\x93\xdd\xeb\xf1^\xff\xc5\x0f\xf8\xac!\x0e\x03\x9fY@{',
|
||||
'127.0.0.1', 1921)]})
|
||||
)
|
||||
self.translator = DefaultFormat()
|
||||
self.assertTrue(
|
||||
isinstance(self.translator, MessageTranslator),
|
||||
'Translator class must inherit from entangled.kademlia.msgformat.MessageTranslator!')
|
||||
|
||||
def testToPrimitive(self):
|
||||
""" Tests translation from a Message object to a primitive """
|
||||
for msg, msgPrimitive in self.cases:
|
||||
translatedObj = self.translator.toPrimitive(msg)
|
||||
self.assertEqual(len(translatedObj), len(msgPrimitive),
|
||||
"Translated object does not match example object's size")
|
||||
for key in msgPrimitive:
|
||||
self.assertEqual(
|
||||
translatedObj[key], msgPrimitive[key],
|
||||
'Message object type %s not translated correctly into primitive on '
|
||||
'key "%s"; expected "%s", got "%s"' %
|
||||
(msg.__class__.__name__, key, msgPrimitive[key], translatedObj[key]))
|
||||
|
||||
def testFromPrimitive(self):
|
||||
""" Tests translation from a primitive to a Message object """
|
||||
for msg, msgPrimitive in self.cases:
|
||||
translatedObj = self.translator.fromPrimitive(msgPrimitive)
|
||||
self.assertEqual(
|
||||
type(translatedObj), type(msg),
|
||||
'Message type incorrectly translated; expected "%s", got "%s"' %
|
||||
(type(msg), type(translatedObj)))
|
||||
for key in msg.__dict__:
|
||||
self.assertEqual(
|
||||
msg.__dict__[key], translatedObj.__dict__[key],
|
||||
'Message instance variable "%s" not translated correctly; '
|
||||
'expected "%s", got "%s"' %
|
||||
(key, msg.__dict__[key], translatedObj.__dict__[key]))
|
|
@ -1,88 +1,85 @@
|
|||
import hashlib
|
||||
import struct
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import defer
|
||||
from lbrynet.dht.node import Node
|
||||
import asyncio
|
||||
import typing
|
||||
from torba.testcase import AsyncioTestCase
|
||||
from tests import dht_mocks
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.utils import generate_id
|
||||
from lbrynet.dht.node import Node
|
||||
from lbrynet.dht.peer import PeerManager
|
||||
|
||||
|
||||
class NodeIDTest(unittest.TestCase):
|
||||
class TestNodePingQueueDiscover(AsyncioTestCase):
|
||||
async def test_ping_queue_discover(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def setUp(self):
|
||||
self.node = Node()
|
||||
peer_addresses = [
|
||||
(constants.generate_id(1), '1.2.3.1'),
|
||||
(constants.generate_id(2), '1.2.3.2'),
|
||||
(constants.generate_id(3), '1.2.3.3'),
|
||||
(constants.generate_id(4), '1.2.3.4'),
|
||||
(constants.generate_id(5), '1.2.3.5'),
|
||||
(constants.generate_id(6), '1.2.3.6'),
|
||||
(constants.generate_id(7), '1.2.3.7'),
|
||||
(constants.generate_id(8), '1.2.3.8'),
|
||||
(constants.generate_id(9), '1.2.3.9'),
|
||||
]
|
||||
with dht_mocks.mock_network_loop(loop):
|
||||
advance = dht_mocks.get_time_accelerator(loop, loop.time())
|
||||
# start the nodes
|
||||
nodes: typing.Dict[int, Node] = {
|
||||
i: Node(loop, PeerManager(loop), node_id, 4444, 4444, 3333, address)
|
||||
for i, (node_id, address) in enumerate(peer_addresses)
|
||||
}
|
||||
for i, n in nodes.items():
|
||||
n.start(peer_addresses[i][1], [])
|
||||
|
||||
def test_new_node_has_auto_created_id(self):
|
||||
self.assertEqual(type(self.node.node_id), bytes)
|
||||
self.assertEqual(len(self.node.node_id), 48)
|
||||
await advance(1)
|
||||
|
||||
def test_uniqueness_and_length_of_generated_ids(self):
|
||||
previous_ids = []
|
||||
for i in range(100):
|
||||
new_id = self.node._generateID()
|
||||
self.assertNotIn(new_id, previous_ids, f'id at index {i} not unique')
|
||||
self.assertEqual(len(new_id), 48, 'id at index {} wrong length: {}'.format(i, len(new_id)))
|
||||
previous_ids.append(new_id)
|
||||
node_1 = nodes[0]
|
||||
|
||||
|
||||
class NodeDataTest(unittest.TestCase):
|
||||
""" Test case for the Node class's data-related functions """
|
||||
|
||||
def setUp(self):
|
||||
h = hashlib.sha384()
|
||||
h.update(b'test')
|
||||
self.node = Node()
|
||||
self.contact = self.node.contact_manager.make_contact(
|
||||
h.digest(), '127.0.0.1', 12345, self.node._protocol)
|
||||
self.token = self.node.make_token(self.contact.compact_ip())
|
||||
self.cases = []
|
||||
for i in range(5):
|
||||
h.update(str(i).encode())
|
||||
self.cases.append((h.digest(), 5000+2*i))
|
||||
self.cases.append((h.digest(), 5001+2*i))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_store(self):
|
||||
""" Tests if the node can store (and privately retrieve) some data """
|
||||
for key, port in self.cases:
|
||||
yield self.node.store(
|
||||
self.contact, key, self.token, port, self.contact.id, 0
|
||||
# ping 8 nodes from node_1, this will result in a delayed return ping
|
||||
futs = []
|
||||
for i in range(1, len(peer_addresses)):
|
||||
node = nodes[i]
|
||||
assert node.protocol.node_id != node_1.protocol.node_id
|
||||
peer = node_1.protocol.peer_manager.get_kademlia_peer(
|
||||
node.protocol.node_id, node.protocol.external_ip, udp_port=node.protocol.udp_port
|
||||
)
|
||||
for key, value in self.cases:
|
||||
expected_result = self.contact.compact_ip() + struct.pack('>H', value) + self.contact.id
|
||||
self.assertTrue(self.node._dataStore.hasPeersForBlob(key),
|
||||
"Stored key not found in node's DataStore: '%s'" % key)
|
||||
self.assertIn(expected_result, self.node._dataStore.getPeersForBlob(key),
|
||||
"Stored val not found in node's DataStore: key:'%s' port:'%s' %s"
|
||||
% (key, value, self.node._dataStore.getPeersForBlob(key)))
|
||||
futs.append(node_1.protocol.get_rpc_peer(peer).ping())
|
||||
await advance(3)
|
||||
replies = await asyncio.gather(*tuple(futs))
|
||||
self.assertTrue(all(map(lambda reply: reply == b"pong", replies)))
|
||||
|
||||
# run for long enough for the delayed pings to have been sent by node 1
|
||||
await advance(1000)
|
||||
|
||||
class NodeContactTest(unittest.TestCase):
|
||||
""" Test case for the Node class's contact management-related functions """
|
||||
def setUp(self):
|
||||
self.node = Node()
|
||||
# verify all of the previously pinged peers have node_1 in their routing tables
|
||||
for n in nodes.values():
|
||||
peers = n.protocol.routing_table.get_peers()
|
||||
if n is node_1:
|
||||
self.assertEqual(8, len(peers))
|
||||
else:
|
||||
self.assertEqual(1, len(peers))
|
||||
self.assertEqual((peers[0].node_id, peers[0].address, peers[0].udp_port),
|
||||
(node_1.protocol.node_id, node_1.protocol.external_ip, node_1.protocol.udp_port))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_contact(self):
|
||||
""" Tests if a contact can be added and retrieved correctly """
|
||||
# Create the contact
|
||||
contact_id = generate_id(b'node1')
|
||||
contact = self.node.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.node._protocol)
|
||||
# Now add it...
|
||||
yield self.node.addContact(contact)
|
||||
# ...and request the closest nodes to it using FIND_NODE
|
||||
closest_nodes = self.node._routingTable.findCloseNodes(contact_id, constants.k)
|
||||
self.assertEqual(len(closest_nodes), 1)
|
||||
self.assertIn(contact, closest_nodes)
|
||||
# run long enough for the refresh loop to run
|
||||
await advance(3600)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_self_as_contact(self):
|
||||
""" Tests the node's behaviour when attempting to add itself as a contact """
|
||||
# Create a contact with the same ID as the local node's ID
|
||||
contact = self.node.contact_manager.make_contact(self.node.node_id, '127.0.0.1', 9182, None)
|
||||
# Now try to add it
|
||||
yield self.node.addContact(contact)
|
||||
# ...and request the closest nodes to it using FIND_NODE
|
||||
closest_nodes = self.node._routingTable.findCloseNodes(self.node.node_id, constants.k)
|
||||
self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact.')
|
||||
# verify all the nodes know about each other
|
||||
for n in nodes.values():
|
||||
if n is node_1:
|
||||
continue
|
||||
peers = n.protocol.routing_table.get_peers()
|
||||
self.assertEqual(8, len(peers))
|
||||
self.assertSetEqual(
|
||||
{n_id[0] for n_id in peer_addresses if n_id[0] != n.protocol.node_id},
|
||||
{c.node_id for c in peers}
|
||||
)
|
||||
self.assertSetEqual(
|
||||
{n_addr[1] for n_addr in peer_addresses if n_addr[1] != n.protocol.external_ip},
|
||||
{c.address for c in peers}
|
||||
)
|
||||
|
||||
# teardown
|
||||
for n in nodes.values():
|
||||
n.stop()
|
||||
|
|
|
@ -1,64 +1,37 @@
|
|||
from binascii import hexlify
|
||||
from twisted.internet import task
|
||||
from twisted.trial import unittest
|
||||
import asyncio
|
||||
import unittest
|
||||
from lbrynet.utils import generate_id
|
||||
from lbrynet.dht.contact import ContactManager
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.peer import PeerManager
|
||||
from torba.testcase import AsyncioTestCase
|
||||
|
||||
|
||||
class ContactTest(unittest.TestCase):
|
||||
""" Basic tests case for boolean operators on the Contact class """
|
||||
class PeerTest(AsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.contact_manager = ContactManager()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.peer_manager = PeerManager(self.loop)
|
||||
self.node_ids = [generate_id(), generate_id(), generate_id()]
|
||||
make_contact = self.contact_manager.make_contact
|
||||
self.first_contact = make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1)
|
||||
self.second_contact = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
|
||||
self.second_contact_second_reference = make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
|
||||
self.first_contact_different_values = make_contact(self.node_ids[1], '192.168.1.20', 1000, None, 50)
|
||||
self.first_contact = self.peer_manager.get_kademlia_peer(self.node_ids[1], '127.0.0.1', udp_port=1000)
|
||||
self.second_contact = self.peer_manager.get_kademlia_peer(self.node_ids[0], '192.168.0.1', udp_port=1000)
|
||||
|
||||
def test_make_contact_error_cases(self):
|
||||
self.assertRaises(
|
||||
ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20', 100000, None)
|
||||
self.assertRaises(
|
||||
ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20.1', 1000, None)
|
||||
self.assertRaises(
|
||||
ValueError, self.contact_manager.make_contact, self.node_ids[1], 'this is not an ip', 1000, None)
|
||||
self.assertRaises(
|
||||
ValueError, self.contact_manager.make_contact, b'not valid node id', '192.168.1.20.1', 1000, None)
|
||||
|
||||
def test_no_duplicate_contact_objects(self):
|
||||
self.assertIs(self.second_contact, self.second_contact_second_reference)
|
||||
self.assertIsNot(self.first_contact, self.first_contact_different_values)
|
||||
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], '192.168.1.20', 100000)
|
||||
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], '192.168.1.20.1', 1000)
|
||||
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], 'this is not an ip', 1000)
|
||||
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, self.node_ids[1], '192.168.1.20', -1000)
|
||||
self.assertRaises(ValueError, self.peer_manager.get_kademlia_peer, b'not valid node id', '192.168.1.20', 1000)
|
||||
|
||||
def test_boolean(self):
|
||||
""" Test "equals" and "not equals" comparisons """
|
||||
self.assertNotEqual(
|
||||
self.first_contact, self.contact_manager.make_contact(
|
||||
self.first_contact.id, self.first_contact.address, self.first_contact.port + 1, None, 32
|
||||
self.assertNotEqual(self.first_contact, self.second_contact)
|
||||
self.assertEqual(
|
||||
self.second_contact, self.peer_manager.get_kademlia_peer(self.node_ids[0], '192.168.0.1', udp_port=1000)
|
||||
)
|
||||
)
|
||||
self.assertNotEqual(
|
||||
self.first_contact, self.contact_manager.make_contact(
|
||||
self.first_contact.id, '193.168.1.1', self.first_contact.port, None, 32
|
||||
)
|
||||
)
|
||||
self.assertNotEqual(
|
||||
self.first_contact, self.contact_manager.make_contact(
|
||||
generate_id(), self.first_contact.address, self.first_contact.port, None, 32
|
||||
)
|
||||
)
|
||||
self.assertEqual(self.second_contact, self.second_contact_second_reference)
|
||||
|
||||
def test_compact_ip(self):
|
||||
self.assertEqual(self.first_contact.compact_ip(), b'\x7f\x00\x00\x01')
|
||||
self.assertEqual(self.second_contact.compact_ip(), b'\xc0\xa8\x00\x01')
|
||||
|
||||
def test_id_log(self):
|
||||
self.assertEqual(self.first_contact.log_id(False), hexlify(self.node_ids[1]))
|
||||
self.assertEqual(self.first_contact.log_id(True), hexlify(self.node_ids[1])[:8])
|
||||
|
||||
|
||||
@unittest.SkipTest
|
||||
class TestContactLastReplied(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.clock = task.Clock()
|
||||
|
@ -129,6 +102,7 @@ class TestContactLastReplied(unittest.TestCase):
|
|||
self.assertIsNone(self.contact.contact_is_good)
|
||||
|
||||
|
||||
@unittest.SkipTest
|
||||
class TestContactLastRequested(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.clock = task.Clock()
|
|
@ -1,213 +0,0 @@
|
|||
from binascii import hexlify, unhexlify
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import defer
|
||||
from lbrynet.dht import constants
|
||||
from lbrynet.dht.routingtable import TreeRoutingTable
|
||||
from lbrynet.dht.contact import ContactManager
|
||||
from lbrynet.dht.distance import Distance
|
||||
from lbrynet.utils import generate_id
|
||||
|
||||
|
||||
class FakeRPCProtocol:
|
||||
""" Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """
|
||||
def sendRPC(self, *args, **kwargs):
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
class TreeRoutingTableTest(unittest.TestCase):
|
||||
""" Test case for the RoutingTable class """
|
||||
def setUp(self):
|
||||
self.contact_manager = ContactManager()
|
||||
self.nodeID = generate_id(b'node1')
|
||||
self.protocol = FakeRPCProtocol()
|
||||
self.routingTable = TreeRoutingTable(self.nodeID)
|
||||
|
||||
def test_distance(self):
|
||||
""" Test to see if distance method returns correct result"""
|
||||
d = Distance(bytes((170,) * 48))
|
||||
result = d(bytes((85,) * 48))
|
||||
expected = int(hexlify(bytes((255,) * 48)), 16)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_contact(self):
|
||||
""" Tests if a contact can be added and retrieved correctly """
|
||||
# Create the contact
|
||||
contact_id = generate_id(b'node2')
|
||||
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
|
||||
# Now add it...
|
||||
yield self.routingTable.addContact(contact)
|
||||
# ...and request the closest nodes to it (will retrieve it)
|
||||
closest_nodes = self.routingTable.findCloseNodes(contact_id)
|
||||
self.assertEqual(len(closest_nodes), 1)
|
||||
self.assertIn(contact, closest_nodes)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_contact(self):
|
||||
""" Tests if a specific existing contact can be retrieved correctly """
|
||||
contact_id = generate_id(b'node2')
|
||||
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
|
||||
# Now add it...
|
||||
yield self.routingTable.addContact(contact)
|
||||
# ...and get it again
|
||||
same_contact = self.routingTable.getContact(contact_id)
|
||||
self.assertEqual(contact, same_contact, 'getContact() should return the same contact')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_parent_node_as_contact(self):
|
||||
"""
|
||||
Tests the routing table's behaviour when attempting to add its parent node as a contact
|
||||
"""
|
||||
# Create a contact with the same ID as the local node's ID
|
||||
contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol)
|
||||
# Now try to add it
|
||||
yield self.routingTable.addContact(contact)
|
||||
# ...and request the closest nodes to it using FIND_NODE
|
||||
closest_nodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
|
||||
self.assertNotIn(contact, closest_nodes, 'Node added itself as a contact')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_remove_contact(self):
|
||||
""" Tests contact removal """
|
||||
# Create the contact
|
||||
contact_id = generate_id(b'node2')
|
||||
contact = self.contact_manager.make_contact(contact_id, '127.0.0.1', 9182, self.protocol)
|
||||
# Now add it...
|
||||
yield self.routingTable.addContact(contact)
|
||||
# Verify addition
|
||||
self.assertEqual(len(self.routingTable._buckets[0]), 1, 'Contact not added properly')
|
||||
# Now remove it
|
||||
self.routingTable.removeContact(contact)
|
||||
self.assertEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_split_bucket(self):
|
||||
""" Tests if the the routing table correctly dynamically splits k-buckets """
|
||||
self.assertEqual(self.routingTable._buckets[0].rangeMax, 2**384,
|
||||
'Initial k-bucket range should be 0 <= range < 2**384')
|
||||
# Add k contacts
|
||||
for i in range(constants.k):
|
||||
node_id = generate_id(b'remote node %d' % i)
|
||||
contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
|
||||
yield self.routingTable.addContact(contact)
|
||||
|
||||
self.assertEqual(len(self.routingTable._buckets), 1,
|
||||
'Only k nodes have been added; the first k-bucket should now '
|
||||
'be full, but should not yet be split')
|
||||
# Now add 1 more contact
|
||||
node_id = generate_id(b'yet another remote node')
|
||||
contact = self.contact_manager.make_contact(node_id, '127.0.0.1', 9182, self.protocol)
|
||||
yield self.routingTable.addContact(contact)
|
||||
self.assertEqual(len(self.routingTable._buckets), 2,
|
||||
'k+1 nodes have been added; the first k-bucket should have been '
|
||||
'split into two new buckets')
|
||||
self.assertNotEqual(self.routingTable._buckets[0].rangeMax, 2**384,
|
||||
'K-bucket was split, but its range was not properly adjusted')
|
||||
self.assertEqual(self.routingTable._buckets[1].rangeMax, 2**384,
|
||||
'K-bucket was split, but the second (new) bucket\'s '
|
||||
'max range was not set properly')
|
||||
self.assertEqual(self.routingTable._buckets[0].rangeMax,
|
||||
self.routingTable._buckets[1].rangeMin,
|
||||
'K-bucket was split, but the min/max ranges were '
|
||||
'not divided properly')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_full_split(self):
|
||||
"""
|
||||
Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
|
||||
"""
|
||||
|
||||
self.routingTable._parentNodeID = bytes(48 * b'\xff')
|
||||
|
||||
node_ids = [
|
||||
b"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
b"200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
b"300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
b"400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
b"500000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
b"600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
b"700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
b"800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
b"ff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
b"010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
]
|
||||
|
||||
# Add k contacts
|
||||
for nodeID in node_ids:
|
||||
# self.assertEquals(nodeID, node_ids[i].decode('hex'))
|
||||
contact = self.contact_manager.make_contact(unhexlify(nodeID), '127.0.0.1', 9182, self.protocol)
|
||||
yield self.routingTable.addContact(contact)
|
||||
self.assertEqual(len(self.routingTable._buckets), 2)
|
||||
self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
|
||||
self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)
|
||||
|
||||
# try adding a contact who is further from us than the k'th known contact
|
||||
nodeID = b'020000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'
|
||||
nodeID = unhexlify(nodeID)
|
||||
contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
|
||||
self.assertFalse(self.routingTable._shouldSplit(self.routingTable._kbucketIndex(contact.id), contact.id))
|
||||
yield self.routingTable.addContact(contact)
|
||||
self.assertEqual(len(self.routingTable._buckets), 2)
|
||||
self.assertEqual(len(self.routingTable._buckets[0]._contacts), 8)
|
||||
self.assertEqual(len(self.routingTable._buckets[1]._contacts), 2)
|
||||
self.assertNotIn(contact, self.routingTable._buckets[0]._contacts)
|
||||
self.assertNotIn(contact, self.routingTable._buckets[1]._contacts)
|
||||
|
||||
|
||||
# class KeyErrorFixedTest(unittest.TestCase):
|
||||
# """ Basic tests case for boolean operators on the Contact class """
|
||||
#
|
||||
# def setUp(self):
|
||||
# own_id = (2 ** constants.key_bits) - 1
|
||||
# # carefully chosen own_id. here's the logic
|
||||
# # we want a bunch of buckets (k+1, to be exact), and we want to make sure own_id
|
||||
# # is not in bucket 0. so we put own_id at the end so we can keep splitting by adding to the
|
||||
# # end
|
||||
#
|
||||
# self.table = lbrynet.dht.routingtable.OptimizedTreeRoutingTable(own_id)
|
||||
#
|
||||
# def fill_bucket(self, bucket_min):
|
||||
# bucket_size = lbrynet.dht.constants.k
|
||||
# for i in range(bucket_min, bucket_min + bucket_size):
|
||||
# self.table.addContact(lbrynet.dht.contact.Contact(long(i), '127.0.0.1', 9999, None))
|
||||
#
|
||||
# def overflow_bucket(self, bucket_min):
|
||||
# bucket_size = lbrynet.dht.constants.k
|
||||
# self.fill_bucket(bucket_min)
|
||||
# self.table.addContact(
|
||||
# lbrynet.dht.contact.Contact(long(bucket_min + bucket_size + 1),
|
||||
# '127.0.0.1', 9999, None))
|
||||
#
|
||||
# def testKeyError(self):
|
||||
#
|
||||
# # find middle, so we know where bucket will split
|
||||
# bucket_middle = self.table._buckets[0].rangeMax / 2
|
||||
#
|
||||
# # fill last bucket
|
||||
# self.fill_bucket(self.table._buckets[0].rangeMax - lbrynet.dht.constants.k - 1)
|
||||
# # -1 in previous line because own_id is in last bucket
|
||||
#
|
||||
# # fill/overflow 7 more buckets
|
||||
# bucket_start = 0
|
||||
# for i in range(0, lbrynet.dht.constants.k):
|
||||
# self.overflow_bucket(bucket_start)
|
||||
# bucket_start += bucket_middle / (2 ** i)
|
||||
#
|
||||
# # replacement cache now has k-1 entries.
|
||||
# # adding one more contact to bucket 0 used to cause a KeyError, but it should work
|
||||
# self.table.addContact(
|
||||
# lbrynet.dht.contact.Contact(long(lbrynet.dht.constants.k + 2), '127.0.0.1', 9999, None))
|
||||
#
|
||||
# # import math
|
||||
# # print ""
|
||||
# # for i, bucket in enumerate(self.table._buckets):
|
||||
# # print "Bucket " + str(i) + " (2 ** " + str(
|
||||
# # math.log(bucket.rangeMin, 2) if bucket.rangeMin > 0 else 0) + " <= x < 2 ** "+str(
|
||||
# # math.log(bucket.rangeMax, 2)) + ")"
|
||||
# # for c in bucket.getContacts():
|
||||
# # print " contact " + str(c.id)
|
||||
# # for key, bucket in self.table._replacementCache.items():
|
||||
# # print "Replacement Cache for Bucket " + str(key)
|
||||
# # for c in bucket:
|
||||
# # print " contact " + str(c.id)
|
Loading…
Add table
Reference in a new issue