forked from LBRYCommunity/lbry-sdk
* use int to_bytes/from_bytes instead of struct * fix ping queue bug and dht functional tests * run functional tests on travis * re-add contact comparison unit test * dont need __ne__ if its just inverting __eq__ result
502 lines
23 KiB
502 lines
23 KiB
import logging
import errno
from binascii import hexlify
from collections import deque
from twisted.internet import protocol, defer
from .error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError, TransportNotConnected
from . import constants
from . import encoding
from . import msgtypes
from . import msgformat
log = logging.getLogger(__name__)
class PingQueue:
"""
Schedules a 15 minute delayed ping after a new node sends us a query. This is so the new node gets added to the
routing table after having been given enough time for a pinhole to expire.
"""
def __init__(self, node):
    """
    Create an empty ping queue bound to a node.

    @param node: The owning node. This constructor reads C{node.clock.seconds}
                 and calls C{node.get_looping_call}.
    """
    self._node = node
    # time source, taken from the node's clock
    self._get_time = self._node.clock.seconds
    # contacts waiting to be processed; _process pops from the left
    self._queue = deque()
    # contact -> time (per self._get_time) after which _process may handle it
    self._enqueued_contacts = {}
    self._semaphore = defer.DeferredSemaphore(1)
    self._ping_semaphore = defer.DeferredSemaphore(constants.alpha)
    # NOTE(review): the first argument of get_looping_call was missing in the reviewed text
    # ("get_looping_call(, self._process)"). self._semaphore.run is assumed here so that _process
    # runs under the same semaphore as enqueue_maybe_ping -- confirm against version control.
    self._process_lc = node.get_looping_call(self._semaphore.run, self._process)
def _add_contact(self, contact, delay=None):
    """
    Record when a contact becomes due, unless its address/port pair is already tracked.

    @param contact: The contact to track; its C{address} and C{port} identify it
    @param delay: How long from now the contact becomes due. A falsy value
                  (C{None} or C{0}) selects C{constants.checkRefreshInterval}
    @return: An already-fired deferred whose result is C{None}
    @rtype: twisted.internet.defer.Deferred
    """
    endpoint = (contact.address, contact.port)
    already_tracked = any(
        endpoint == (known.address, known.port) for known in self._enqueued_contacts
    )
    if not already_tracked:
        wait = delay or constants.checkRefreshInterval
        self._enqueued_contacts[contact] = self._get_time() + wait
        # NOTE(review): nothing visible here adds the contact to self._queue, which _process
        # pops from -- confirm whether an append was lost from this method.
    return defer.succeed(None)
def _process(self):
# Periodic worker for the ping queue; __init__ hands this method to node.get_looping_call.
# Visible steps: pop a contact from the left of self._queue, compare the current time with the
# due time stored for it in self._enqueued_contacts, build the `pinged` and `checked` lists,
# wait on a DeferredList of _ping(contact) for every contact in `pinged`, then delete from
# self._enqueued_contacts each checked contact that was also pinged.
# NOTE(review): this listing looks incomplete. Several `if`/`while` headers have no body, the
# `except` clauses have no matching `try`, the final `elif` has no suite, and nothing visible
# appends to `pinged` or `checked`. The method also yields, so it presumably needs
# @defer.inlineCallbacks. Restore the missing statements from version control before
# changing any logic here.
if not len(self._queue):
contact = self._queue.popleft()
now = self._get_time()
# if the oldest contact in the queue isn't old enough to be pinged, add it back to the queue and return
if now < self._enqueued_contacts[contact]:
pinged = []
checked = []
while now > self._enqueued_contacts[contact]:
if not contact.contact_is_good:
if not len(self._queue):
contact = self._queue.popleft()
if not now > self._enqueued_contacts[contact]:
def _ping(contact):
except TimeoutError:
except Exception as err:
log.warning("unexpected error: %s", err)
yield defer.DeferredList([_ping(contact) for contact in pinged])
for contact in checked:
if contact in self._enqueued_contacts and contact in pinged:
del self._enqueued_contacts[contact]
elif contact not in self._queue:
def start(self):
    """
    Start the looping call that runs the queue processing.

    @return: Whatever C{self._node.safe_start_looping_call} returns
    """
    starter = self._node.safe_start_looping_call
    # 60 is passed as the interval argument -- presumably seconds, confirm against the node helper
    return starter(self._process_lc, 60)
def stop(self):
    """
    Stop the looping call that runs the queue processing.

    @return: Whatever C{self._node.safe_stop_looping_call} returns
    """
    stopper = self._node.safe_stop_looping_call
    return stopper(self._process_lc)
def enqueue_maybe_ping(self, contact, delay=None):
    """
    Queue a contact for a delayed ping, serialised through the queue's semaphore.

    @param contact: The contact to enqueue
    @param delay: Optional delay forwarded to C{_add_contact}; C{None} lets
                  C{_add_contact} pick its default
    @return: The deferred returned by C{self._semaphore.run}
    @rtype: twisted.internet.defer.Deferred
    """
    # NOTE(review): the callee was unreadable in the reviewed text ("return, contact, delay)").
    # It is reconstructed as DeferredSemaphore.run wrapping _add_contact, which matches the
    # trailing arguments and the semaphore created in __init__ -- confirm against version control.
    return self._semaphore.run(self._add_contact, contact, delay)
class KademliaProtocol(protocol.DatagramProtocol):
""" Implements all low-level network-related functions of a Kademlia node """
msgSizeLimit = constants.udpDatagramMaxSize - 26
def __init__(self, node):
    """
    Prepare the protocol state for a node; no network activity happens here.

    @param node: The node this protocol serves; it is also handed to the ping queue
    """
    self._node = node
    self._protocolVersion = constants.protocolVersion
    self.started_listening_time = 0

    self._translator = msgformat.DefaultFormat()

    # message id -> (contact, deferred, timeout call, timeout canceller, method, args)
    self._sentMessages = {}
    # message id -> {sequence number: payload} for multi-packet messages still being received
    self._partialMessages = {}
    self._partialMessagesProgress = {}

    self._listening = defer.Deferred(None)
    self._ping_queue = PingQueue(self._node)
def _migrate_incoming_rpc_args(self, contact, method, *args):
    """
    Convert the arguments of an incoming RPC to the current calling convention.

    A C{store} request from a protocol version 0 contact carries its options in
    a dict as the second argument; it is unpacked into the positional form
    C{(blob_hash, token, port, originalPublisherID, age)}. Every other request
    is passed through unchanged.

    @param contact: The sender; only C{contact.protocolVersion} is read
    @param method: The RPC method name
    @type method: bytes
    @param args: The positional arguments received from the remote node
    @return: A C{(args, kwargs)} pair; C{kwargs} is always an empty dict
    @rtype: tuple
    """
    is_legacy_store = method == b'store' and contact.protocolVersion == 0
    # The arguments come from the network, so a store request may be malformed and carry fewer
    # than two arguments. Such a request is passed through untouched instead of raising
    # IndexError here; the RPC handler then rejects it when it is invoked.
    if is_legacy_store and len(args) > 1 and isinstance(args[1], dict):
        blob_hash, options = args[0], args[1]
        # pop() is kept from the original implementation: the options dict is consumed
        token = options.pop(b'token', None)
        port = options.pop(b'port', -1)
        originalPublisherID = options.pop(b'lbryid', None)
        age = 0
        return (blob_hash, token, port, originalPublisherID, age), {}
    return args, {}
def _migrate_outgoing_rpc_args(self, contact, method, *args):
    """
    This will reformat protocol version 0 arguments for the store function and will add the
    protocol version keyword argument to calls to contacts who will accept it

    @param contact: The recipient; only C{contact.protocolVersion} is read
    @param method: The RPC method name
    @type method: bytes
    @param args: The positional arguments of the call. For a C{store} sent to a
                 version 0 contact exactly five are required:
                 C{(blob_hash, token, port, originalPublisherID, age)}
    @return: The arguments to put on the wire
    @rtype: tuple
    @raise ValueError: if a version 0 C{store} call does not have exactly five arguments
    """
    if contact.protocolVersion == 0:
        if method == b'store':
            blob_hash, token, port, originalPublisherID, age = args
            # version 0 layout: options dict second; age is not transmitted
            args = (
                blob_hash, {
                    b'token': token,
                    b'port': port,
                    b'lbryid': originalPublisherID
                }, originalPublisherID, False
            )
        # version 0 contacts get no protocolVersion entry
        return args
    if args and isinstance(args[-1], dict):
        # a trailing dict is extended in place, so the caller's dict gains the key as well
        args[-1][b'protocolVersion'] = self._protocolVersion
        return args
    return args + ({b'protocolVersion': self._protocolVersion},)
def sendRPC(self, contact, method, args):
# Summary of the visible flow: wrap the call in a msgtypes.RequestMessage, convert it to
# primitives and bencode it, create the Deferred `df` with _update_contact/_remove_contact
# attached, arm an RPC timeout through self._node.reactor_callLater, transmit the encoded bytes
# with self._send and remember the pending call in self._sentMessages before returning `df`.
# NOTE(review): this listing looks damaged and should be restored from version control before
# it is edited: the description below has lost its docstring quotes, the RequestMessage(...)
# and reactor_callLater(...) calls are cut off, the message id is missing from the
# self._send(...) call and from the self._sentMessages[...] key, both nested callbacks have
# lost statements, and the two log.debug calls were presumably separated by an `else`.
# NOTE(review): the description mentions C{rawResponse}, but this method has no such parameter.
Sends an RPC to the specified contact
@param contact: The contact (remote node) to send the RPC to
@type contact: kademlia.contacts.Contact
@param method: The name of remote method to invoke
@type method: str
@param args: A list of (non-keyword) arguments to pass to the remote
method, in the correct order
@type args: tuple
@return: This immediately returns a deferred object, which will return
the result of the RPC call, or raise the relevant exception
if the remote node raised one. If C{rawResponse} is set to
C{True}, however, it will always return the actual response
message (which may be a C{ResponseMessage} or an
@rtype: twisted.internet.defer.Deferred
msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method,
msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = encoding.bencode(msgPrimitive)
if args:
log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
hexlify(args[0]), contact.address, contact.port)
log.debug("%s:%i SEND CALL %s TO %s:%i", self._node.externalIP, self._node.port, method,
contact.address, contact.port)
df = defer.Deferred()
def _remove_contact(failure): # remove the contact from the routing table and track the failure
except (ValueError, IndexError):
return failure
def _update_contact(result): # refresh the contact in the routing table
if method == b'findValue':
if b'protocolVersion' not in result:
d = self._node.addContact(contact)
d.addCallback(lambda _: result)
return d
df.addCallbacks(_update_contact, _remove_contact)
# Set the RPC timeout timer
timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout,
# Transmit the data
self._send(encodedMsg,, (contact.address, contact.port))
# the stored tuple layout matches what _msgTimeout unpacks
self._sentMessages[] = (contact, df, timeoutCall, cancelTimeout, method, args)
return df
def startProtocol(self):
    """
    Record the start time and start the ping queue once the protocol is attached to a transport.

    @return: Whatever C{self._ping_queue.start()} returns
    """
    # NOTE(review): the logging call on this line was unreadable in the reviewed text (it began
    # with a stray "|"); log.info is assumed -- confirm the level against version control.
    log.info("DHT listening on UDP %i (ext port %i)", self._node.port, self._node.externalUDPPort)
    if self._listening.called:
        # the previous deferred has already fired, so a fresh one is needed
        self._listening = defer.Deferred()
    # NOTE(review): nothing visible fires self._listening -- confirm whether a callback() call
    # was lost from this method.
    self.started_listening_time = self._node.clock.seconds()
    return self._ping_queue.start()
def datagramReceived(self, datagram, address):
# Summary of the visible flow:
#   1. A datagram whose bytes 0 and 25 are zero is treated as one packet of a multi-packet
#      message: bytes 1-2 hold the total packet count, bytes 3-4 the sequence number, bytes 5-24
#      the message id and bytes 26 onward the payload. Packets are collected in
#      self._partialMessages[msgID] until all have arrived and are then concatenated.
#   2. The datagram is bdecoded and converted to a message object by self._translator.
#   3. The message is dispatched on its type: RequestMessage (run the RPC via _handleRPC and
#      possibly enqueue a ping), ErrorMessage (rebuild the remote exception) or ResponseMessage
#      (match it to the pending entry in self._sentMessages).
# NOTE(review): this listing looks damaged and should be restored from version control before
# it is edited. The `except` clauses have no `try`, several `if` branches have lost their body
# or their `else`, the message id is missing from `self._handleRPC(remoteContact,, ...)`, from
# every `self._sentMessages[]` lookup and from the bare `if in self._sentMessages:` tests, and
# the contact id is missing from `if and != message.nodeID:` and `elif not`.
# NOTE(review): `datagram.encode('hex')` in the decode-error log call cannot work if datagram
# is a bytes object, as the chr(datagram[0]) indexing suggests; the rest of the file uses
# hexlify for this.
# NOTE(review): `remoteException.message` relies on a `.message` attribute that built-in
# Python 3 exceptions do not have -- confirm the project's exception classes define it.
# NOTE(review): as shown, the packets are joined in the iteration order of `keys`, which is
# arrival order rather than sequence order -- confirm whether a sort was lost here.
""" Handles and parses incoming RPC messages (and responses)
@note: This is automatically called by Twisted when the protocol
receives a UDP datagram
if chr(datagram[0]) == '\x00' and chr(datagram[25]) == '\x00':
totalPackets = (datagram[1] << 8) | datagram[2]
msgID = datagram[5:25]
seqNumber = (datagram[3] << 8) | datagram[4]
if msgID not in self._partialMessages:
self._partialMessages[msgID] = {}
self._partialMessages[msgID][seqNumber] = datagram[26:]
if len(self._partialMessages[msgID]) == totalPackets:
keys = self._partialMessages[msgID].keys()
data = b''
for key in keys:
data += self._partialMessages[msgID][key]
datagram = data
del self._partialMessages[msgID]
msgPrimitive = encoding.bdecode(datagram)
message = self._translator.fromPrimitive(msgPrimitive)
except (encoding.DecodeError, ValueError) as err:
# We received some rubbish here
log.warning("Error decoding datagram %s from %s:%i - %s", datagram.encode('hex'),
address[0], address[1], err)
except (IndexError, KeyError):
log.warning("Couldn't decode dht datagram from %s", address)
if isinstance(message, msgtypes.RequestMessage):
# This is an RPC method request
remoteContact = self._node.contact_manager.make_contact(message.nodeID, address[0], address[1], self)
# only add a requesting contact to the routing table if it has replied to one of our requests
if remoteContact.contact_is_good is True:
df = self._node.addContact(remoteContact)
df = defer.succeed(None)
df.addCallback(lambda _: self._handleRPC(remoteContact,, message.request, message.args))
# if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
# will be added to our routing table if successful
if remoteContact.contact_is_good is None and remoteContact.lastReplied is None:
df.addCallback(lambda _: self._ping_queue.enqueue_maybe_ping(remoteContact))
elif isinstance(message, msgtypes.ErrorMessage):
# The RPC request raised a remote exception; raise it locally
if message.exceptionType in BUILTIN_EXCEPTIONS:
exception_type = BUILTIN_EXCEPTIONS[message.exceptionType]
exception_type = UnknownRemoteException
remoteException = exception_type(message.response)
log.error("DHT RECV REMOTE EXCEPTION FROM %s:%i: %s", address[0],
address[1], remoteException)
if in self._sentMessages:
# Cancel timeout timer for this RPC
remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[][0:5]
del self._sentMessages[]
# reject replies coming from a different address than what we sent our request to
if (remoteContact.address, remoteContact.port) != address:
log.warning("Sent request to node %s at %s:%i, got reply from %s:%i",
remoteContact.log_id(), remoteContact.address,
remoteContact.port, address[0], address[1])
# this error is returned by nodes that can be contacted but have an old
# and broken version of the ping command, if they return it the node can
# be contacted, so we'll treat it as a successful ping
old_ping_error = "ping() got an unexpected keyword argument '_rpcNodeContact'"
if isinstance(remoteException, TypeError) and \
remoteException.message == old_ping_error:
log.debug("old pong error")
elif isinstance(message, msgtypes.ResponseMessage):
# Find the message that triggered this response
if in self._sentMessages:
# Cancel timeout timer for this RPC
remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[][0:5]
del self._sentMessages[]
log.debug("%s:%i RECV response to %s from %s:%i", self._node.externalIP, self._node.port,
method, remoteContact.address, remoteContact.port)
# When joining the network we made Contact objects for the seed nodes with node ids set to None
# Thus, the sent_to_id will also be None, and the contact objects need the ids to be manually set.
# These replies have be distinguished from those where the node id in the datagram does not match
# the node id of the node we sent a message to (these messages are treated as an error)
if and != message.nodeID: # sent_to_id will be None for bootstrap
log.debug("mismatch: (%s) %s:%i (%s vs %s)", method, remoteContact.address, remoteContact.port,
remoteContact.log_id(False), hexlify(message.nodeID))
elif not
# We got a result from the RPC
# If the original message isn't found, it must have timed out
# TODO: we should probably do something with this...
def _send(self, data, rpcID, address):
    """
    Transmit the specified data over UDP, breaking it up into several
    packets if necessary

    If the data is spread over multiple UDP datagrams, each packet starts with
    this 26 byte header, followed by its share of the data::

        +--------------+--------------+-----------------+------------+------------+
        | Transmission | Total number | Sequence number |   RPC ID   | Header end |
        |   type ID    |  of packets  | of this packet  |            | indicator  |
        |   (1 byte)   |  (2 bytes)   |    (2 bytes)    | (20 bytes) |  (1 byte)  |
        |     0x00     |  big-endian  |   big-endian    |            |    0x00    |
        +--------------+--------------+-----------------+------------+------------+

    @note: The header used for breaking up large data segments will
           possibly be moved out of the KademliaProtocol class in the
           future, into something similar to a message translator/encoder
           class (see C{kademlia.msgformat} and C{kademlia.encoding}).

    @param data: The encoded message
    @type data: bytes
    @param rpcID: The id of the RPC this data belongs to; it is copied into the
                  header, whose layout assumes it is 20 bytes long
    @type rpcID: bytes
    @param address: The destination, passed through to C{_scheduleSendNext}
    @raise OverflowError: if more than 65535 packets would be needed
    """
    limit = self.msgSizeLimit
    if len(data) <= limit:
        # small enough for a single datagram: sent as-is, without a header
        self._scheduleSendNext(data, address)
        return

    # We have to spread the data over multiple UDP datagrams,
    # and provide sequencing information
    totalPackets = -(-len(data) // limit)  # ceiling division
    # The header is built as bytes. The receiving side reads these fields with
    # (datagram[1] << 8) | datagram[2], i.e. as big-endian 16 bit integers; building the
    # header with chr() and a str "%" format would embed the repr of the bytes arguments
    # instead of their content.
    encTotalPackets = totalPackets.to_bytes(2, 'big')
    for seqNumber in range(totalPackets):
        startPos = seqNumber * limit
        packetData = data[startPos:startPos + limit]
        txData = b''.join((
            b'\x00', encTotalPackets, seqNumber.to_bytes(2, 'big'), rpcID, b'\x00', packetData
        ))
        self._scheduleSendNext(txData, address)
def _scheduleSendNext(self, txData, address):
    """
    Queue one datagram for transmission.

    Hands C{self._write} and its two arguments to C{self._node.reactor_callSoon}.

    @param txData: The datagram payload
    @param address: The destination address
    """
    scheduled = self._node.reactor_callSoon(self._write, txData, address)
    # reactor_callSoon must return exactly two items; neither is kept
    _first, _second = scheduled
def _write(self, txData, address):
    """
    Write one datagram to the transport.

    C{EWOULDBLOCK} and C{ENETUNREACH} socket errors are logged and swallowed;
    any other C{OSError} is logged and re-raised.

    @param txData: The datagram payload
    @param address: The destination; C{address[0]} and C{address[1]} are used
                    as host and port when an error is logged
    @raise TransportNotConnected: if C{self.transport} is not set
    @raise OSError: for socket errors other than the two listed above
    """
    if not self.transport:
        raise TransportNotConnected()
    try:
        self.transport.write(txData, address)
    except OSError as err:
        if err.errno == errno.EWOULDBLOCK:
            # i'm scared this may swallow important errors, but i get a million of these
            # on Linux and it doesnt seem to affect anything -grin
            log.warning("Can't send data to dht: EWOULDBLOCK")
        elif err.errno == errno.ENETUNREACH:
            # this should probably try to retransmit when the network connection is back
            log.error("Network is unreachable")
        else:
            # OSError has no `.message` attribute on Python 3; `strerror` carries the text.
            # The code is formatted with %s because errno may be None.
            log.error("DHT socket error sending %i bytes to %s:%i - %s (code %s)",
                      len(txData), address[0], address[1], err.strerror, err.errno)
            raise
def _sendResponse(self, contact, rpcID, response):
    """
    Send a RPC response to the specified contact

    @param contact: The contact to reply to; only its C{address} and C{port} are read
    @param rpcID: The id of the request being answered; it is used both as the
                  message id and as the id passed to C{_send}
    @param response: The result to put in the response message
    """
    destination = (contact.address, contact.port)
    msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response)
    msgPrimitive = self._translator.toPrimitive(msg)
    encodedMsg = encoding.bencode(msgPrimitive)
    self._send(encodedMsg, rpcID, destination)
def _sendError(self, contact, rpcID, exceptionType, exceptionMessage):
    """
    Send an RPC error message to the specified contact

    @param contact: The contact to reply to; only its C{address} and C{port} are read
    @param rpcID: The id of the request that failed; it is used both as the
                  message id and as the id passed to C{_send}
    @param exceptionType: The exception type, passed unchanged to C{msgtypes.ErrorMessage}
    @param exceptionMessage: The error text; it is encoded to bytes before sending
    @type exceptionMessage: str
    """
    destination = (contact.address, contact.port)
    exceptionMessage = exceptionMessage.encode()
    msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage)
    msgPrimitive = self._translator.toPrimitive(msg)
    encodedMsg = encoding.bencode(msgPrimitive)
    self._send(encodedMsg, rpcID, destination)
def _handleRPC(self, senderContact, rpcID, method, args):
""" Executes a local function in response to an RPC request """
# Summary of the visible flow: look up `method` (bytes, decoded to a str attribute name) on
# self._node; only callables carrying an `rpcmethod` attribute are accepted. The arguments are
# split into positional and keyword parts, either directly when the trailing dict carries
# b'protocolVersion' or through _migrate_incoming_rpc_args, and the function is then called
# with the sender as first argument (b'ping' is called without arguments). Unknown methods
# fail `df` with AttributeError. Returns `df`.
# NOTE(review): this listing looks damaged and should be restored from version control before
# it is edited. Nothing visible attaches handleError/handleResult to `df` or fires `df` with
# the result, the `except` clause has no `try`, and the `else` lines separating the alternative
# branches (direct split vs. migration, ping vs. other methods, unknown method) are missing.
# NOTE(review): `kw` is assigned but never passed to `func` in the visible code.
# Set up the deferred callchain
def handleError(f):
self._sendError(senderContact, rpcID, f.type, f.getErrorMessage())
def handleResult(result):
self._sendResponse(senderContact, rpcID, result)
df = defer.Deferred()
# Execute the RPC
func = getattr(self._node, method.decode(), None)
if callable(func) and hasattr(func, "rpcmethod"):
# Call the exposed Node method and return the result to the deferred callback chain
# if args:
# log.debug("%s:%i RECV CALL %s(%s) %s:%i", self._node.externalIP, self._node.port, method,
# args[0].encode('hex'), senderContact.address, senderContact.port)
# else:
log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method,
senderContact.address, senderContact.port)
if args and isinstance(args[-1], dict) and b'protocolVersion' in args[-1]: # args don't need reformatting
a, kw = tuple(args[:-1]), args[-1]
a, kw = self._migrate_incoming_rpc_args(senderContact, method, *args)
if method != b'ping':
result = func(senderContact, *a)
result = func()
except Exception as e:
log.exception("error handling request for %s:%i %s", senderContact.address, senderContact.port, method)
# No such exposed method
df.errback(AttributeError('Invalid method: %s' % method))
return df
def _msgTimeout(self, messageID):
""" Called when an RPC request message times out """
# sendRPC and _msgTimeoutInProgress register this method with self._node.reactor_callLater.
# Visible flow: look up the pending entry for messageID in self._sentMessages; if a
# multi-packet response is still being received, defer the decision to _msgTimeoutInProgress,
# otherwise drop the pending entry.
# NOTE(review): this listing looks incomplete. As shown, execution falls through after the
# log.error call (the lookup below it would raise KeyError) and after _msgTimeoutInProgress
# (the `del` would remove the entry that call may just have re-registered), so early returns
# were presumably lost. The final comment also describes failing `df` with a TimeoutError,
# but no such statement is visible. Restore from version control before editing.
# Find the message that timed out
if messageID not in self._sentMessages:
# This should never be reached
log.error("deferred timed out, but is not present in sent messages list!")
remoteContact, df, timeout_call, timeout_canceller, method, args = self._sentMessages[messageID]
if messageID in self._partialMessages:
# We are still receiving this message
self._msgTimeoutInProgress(messageID, timeout_canceller, remoteContact, df, method, args)
del self._sentMessages[messageID]
# The message's destination node is now considered to be dead;
# raise an (asynchronous) TimeoutError exception and update the host node
def _msgTimeoutInProgress(self, messageID, timeoutCanceller, remoteContact, df, method, args):
# Called by _msgTimeout when the timed-out message still has packets in self._partialMessages.
# If _hasProgressBeenMade(messageID) is true, a new timeout is scheduled and the pending entry
# in self._sentMessages is rewritten with the new timer; the remaining statements discard the
# partial-message bookkeeping for messageID.
# NOTE(review): this listing looks incomplete. The "No progress has been made" statements were
# presumably under an `else`, `timeoutCanceller` is never used in the visible code, and nothing
# visible fails `df` when the message is abandoned. Restore from version control before editing.
# See if any progress has been made; if not, kill the message
if self._hasProgressBeenMade(messageID):
# Reset the RPC timeout timer
timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID)
self._sentMessages[messageID] = (remoteContact, df, timeoutCall, cancelTimeout, method, args)
# No progress has been made
if messageID in self._partialMessagesProgress:
del self._partialMessagesProgress[messageID]
if messageID in self._partialMessages:
del self._partialMessages[messageID]
def _hasProgressBeenMade(self, messageID):
# Used by _msgTimeoutInProgress to decide whether a partially received message is kept alive.
# The visible part of the expression requires messageID to be present in
# self._partialMessagesProgress and compares the length of that entry with another value.
# NOTE(review): the expression is cut off after `!=`; its right-hand operand and the closing
# parenthesis are missing from this listing. Restore from version control before editing.
return (
messageID in self._partialMessagesProgress and
len(self._partialMessagesProgress[messageID]) !=
def stopProtocol(self):
# NOTE(review): this listing looks damaged. The docstring below is not terminated, and the two
# lines starting with "|" are presumably logging calls ('Stopping DHT' / 'DHT stopped') whose
# callee was lost. Any shutdown statements that belong between them are not visible here.
# Restore from version control before editing.
""" Called when the transport is disconnected.
Will only be called once, after all ports are disconnected.
|'Stopping DHT')
|'DHT stopped')