lbry-sdk/lbrynet/dht/protocol.py

460 lines
19 KiB
Python
Raw Normal View History

import logging
2015-08-20 17:27:15 +02:00
import time
import socket
import errno
2015-08-20 17:27:15 +02:00
2017-05-25 20:01:39 +02:00
from twisted.internet import protocol, defer, error, reactor, task
2015-08-20 17:27:15 +02:00
import constants
import encoding
import msgtypes
import msgformat
from contact import Contact
2017-10-10 19:09:25 +02:00
from error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError
2017-10-10 19:08:22 +02:00
from delay import Delay
2015-08-20 17:27:15 +02:00
# Module-level logger; KademliaProtocol reports DHT traffic and errors through it
log = logging.getLogger(__name__)
2015-08-20 17:27:15 +02:00
class KademliaProtocol(protocol.DatagramProtocol):
    """ Implements all low-level network-related functions of a Kademlia node """

    # Largest payload carried by a single UDP datagram: the configured
    # datagram size minus the 26-byte multi-packet header written by _send()
    msgSizeLimit = constants.udpDatagramMaxSize - 26

    def __init__(self, node):
        """
        @param node: The local node. This class reads C{node_id} and C{port}
                     from it, calls C{addContact}/C{removeContact} on it and
                     dispatches incoming RPCs to its C{rpcmethod}-marked methods
        """
        self._node = node
        self._encoder = encoding.Bencode()
        self._translator = msgformat.DefaultFormat()

        # msg id -> (remote contact id, deferred, timeout DelayedCall, method, args)
        self._sentMessages = {}
        # rpc id -> {sequence number: payload} for multi-packet messages that
        # are still being reassembled
        self._partialMessages = {}
        # rpc id -> set of sequence numbers already received when the RPC
        # timeout last fired; used to detect a stalled reassembly
        self._partialMessagesProgress = {}
        self._delay = Delay()
        # keep track of outstanding writes so that they
        # can be cancelled on shutdown
        self._call_later_list = {}

        # keep track of bandwidth usage by peer
        self._history_rx = {}  # address -> [(datagram size, timestamp), ...]
        self._history_tx = {}
        self._bytes_rx = {}  # address -> total bytes
        self._bytes_tx = {}
        self._unique_contacts = []
        # set mirror of _unique_contacts so the per-datagram membership test
        # is O(1) instead of a scan of an ever-growing list
        self._unique_contacts_set = set()

        self._queries_rx_per_second = 0
        self._queries_tx_per_second = 0
        self._kbps_tx = 0
        self._kbps_rx = 0
        self._recent_contact_count = 0
        self._total_bytes_tx = 0
        self._total_bytes_rx = 0
        self._bandwidth_stats_update_lc = task.LoopingCall(self._update_bandwidth_stats)

    @staticmethod
    def _recent_activity(history, now):
        """ Summarise the entries of a traffic history from the last second

        @param history: address -> [(size, timestamp), ...]
        @param now: the reference time (seconds since the epoch)
        @return: (number of datagrams, number of bytes, set of active addresses)
        """
        count = 0
        total_bytes = 0
        active = set()
        for address, entries in history.iteritems():
            recent_sizes = [size for (size, timestamp) in entries if now - timestamp < 1.0]
            if recent_sizes:
                active.add(address)
                count += len(recent_sizes)
                total_bytes += sum(recent_sizes)
        return count, total_bytes, active

    def _update_bandwidth_stats(self):
        """ Recompute the bandwidth/query statistics over the last second

        Run every second by C{self._bandwidth_stats_update_lc} (started in
        startProtocol).
        """
        now = time.time()
        qps_rx, bytes_rx, active_rx = self._recent_activity(self._history_rx, now)
        qps_tx, bytes_tx, active_tx = self._recent_activity(self._history_tx, now)
        self._queries_rx_per_second = qps_rx
        self._queries_tx_per_second = qps_tx
        self._kbps_rx = round(float(bytes_rx) / 1024.0, 2)
        self._kbps_tx = round(float(bytes_tx) / 1024.0, 2)
        self._recent_contact_count = len(active_rx | active_tx)
        self._total_bytes_tx = sum(self._bytes_tx.itervalues())
        self._total_bytes_rx = sum(self._bytes_rx.itervalues())

    def _record_traffic(self, history, byte_counts, address, size):
        """ Record C{size} bytes exchanged with C{address} for the bandwidth stats

        @param history: C{self._history_rx} or C{self._history_tx}
        @param byte_counts: C{self._bytes_rx} or C{self._bytes_tx}
        """
        now = time.time()
        contact_history = history.get(address, [])
        if len(contact_history) > 1000:
            # prune entries that fell out of the 1 second stats window so the
            # per-address history stays bounded
            contact_history = [entry for entry in contact_history if now - entry[1] < 1.0]
        contact_history.append((size, now))
        history[address] = contact_history
        byte_counts[address] = byte_counts.get(address, 0) + size
        if address not in self._unique_contacts_set:
            self._unique_contacts_set.add(address)
            self._unique_contacts.append(address)

    @property
    def unique_contacts(self):
        return self._unique_contacts

    @property
    def queries_rx_per_second(self):
        return self._queries_rx_per_second

    @property
    def queries_tx_per_second(self):
        return self._queries_tx_per_second

    @property
    def kbps_tx(self):
        return self._kbps_tx

    @property
    def kbps_rx(self):
        return self._kbps_rx

    @property
    def recent_contact_count(self):
        return self._recent_contact_count

    @property
    def total_bytes_tx(self):
        return self._total_bytes_tx

    @property
    def total_bytes_rx(self):
        return self._total_bytes_rx

    @property
    def bandwidth_stats(self):
        response = {
            "kbps_received": self.kbps_rx,
            "kbps_sent": self.kbps_tx,
            "total_bytes_sent": self.total_bytes_tx,
            "total_bytes_received": self.total_bytes_rx,
            "queries_received": self.queries_rx_per_second,
            "queries_sent": self.queries_tx_per_second,
            "recent_contacts": self.recent_contact_count,
            "unique_contacts": len(self.unique_contacts)
        }
        return response

    @staticmethod
    def _describe_arg(arg):
        """ Render an RPC argument for debug logging

        String arguments are hex-encoded; anything else (possible for
        arguments decoded from remote peers) falls back to C{repr()} instead
        of raising.
        """
        if isinstance(arg, str):
            return arg.encode('hex')
        return repr(arg)

    def sendRPC(self, contact, method, args, rawResponse=False):
        """ Sends an RPC to the specified contact

        @param contact: The contact (remote node) to send the RPC to
        @type contact: kademlia.contacts.Contact
        @param method: The name of remote method to invoke
        @type method: str
        @param args: A list of (non-keyword) arguments to pass to the remote
                     method, in the correct order
        @type args: tuple
        @param rawResponse: If this is set to C{True}, the caller of this RPC
                            will receive a tuple containing the actual response
                            message object and the originating address tuple as
                            a result; in other words, it will not be
                            interpreted by this class. Unless something special
                            needs to be done with the metadata associated with
                            the message, this should remain C{False}.
        @type rawResponse: bool

        @return: This immediately returns a deferred object, which will return
                 the result of the RPC call, or raise the relevant exception
                 if the remote node raised one. If C{rawResponse} is set to
                 C{True}, however, it will always return the actual response
                 message (which may be a C{ResponseMessage} or an
                 C{ErrorMessage}).
        @rtype: twisted.internet.defer.Deferred
        """
        msg = msgtypes.RequestMessage(self._node.node_id, method, args)
        msgPrimitive = self._translator.toPrimitive(msg)
        encodedMsg = self._encoder.encode(msgPrimitive)
        if args:
            log.debug("DHT SEND CALL %s(%s)", method, self._describe_arg(args[0]))
        else:
            log.debug("DHT SEND CALL %s", method)

        df = defer.Deferred()
        if rawResponse:
            df._rpcRawResponse = True

        # Set the RPC timeout timer
        timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
        # Register the request before handing it to the transport so it is
        # always known by the time any response can be processed
        self._sentMessages[msg.id] = (contact.id, df, timeoutCall, method, args)
        # Transmit the data
        self._send(encodedMsg, msg.id, (contact.address, contact.port))
        return df

    def startProtocol(self):
        log.info("DHT listening on UDP %i", self._node.port)
        if not self._bandwidth_stats_update_lc.running:
            self._bandwidth_stats_update_lc.start(1)

    def datagramReceived(self, datagram, address):
        """ Handles and parses incoming RPC messages (and responses)

        @note: This is automatically called by Twisted when the protocol
               receives a UDP datagram
        """
        # Multi-packet transmissions start with a 26-byte header (see _send).
        # Check the length first: datagrams come from arbitrary peers, and a
        # short one starting with '\x00' would otherwise raise IndexError here.
        if len(datagram) >= 26 and datagram[0] == '\x00' and datagram[25] == '\x00':
            totalPackets = (ord(datagram[1]) << 8) | ord(datagram[2])
            seqNumber = (ord(datagram[3]) << 8) | ord(datagram[4])
            msgID = datagram[5:25]
            # NOTE(review): incomplete partial messages are only evicted by
            # _msgTimeoutInProgress, i.e. when they belong to an RPC we sent;
            # others stay in memory indefinitely
            packets = self._partialMessages.setdefault(msgID, {})
            packets[seqNumber] = datagram[26:]
            if len(packets) != totalPackets:
                # wait for the remaining packets of this message
                return
            datagram = ''.join(packets[seq] for seq in sorted(packets))
            del self._partialMessages[msgID]
            self._partialMessagesProgress.pop(msgID, None)

        try:
            msgPrimitive = self._encoder.decode(datagram)
        except encoding.DecodeError:
            # We received some rubbish here
            return
        except IndexError:
            log.warning("Couldn't decode dht datagram from %s", address)
            return

        message = self._translator.fromPrimitive(msgPrimitive)
        remoteContact = Contact(message.nodeID, address[0], address[1], self)

        self._record_traffic(self._history_rx, self._bytes_rx, address, len(datagram))

        # Refresh the remote node's details in the local node's k-buckets
        self._node.addContact(remoteContact)

        if isinstance(message, msgtypes.RequestMessage):
            # This is an RPC method request
            self._handleRPC(remoteContact, message.id, message.request, message.args)
        elif isinstance(message, msgtypes.ResponseMessage):
            if message.id not in self._sentMessages:
                # If the original message isn't found, it must have timed out
                # TODO: we should probably do something with this...
                return
            # Find the message that triggered this response and cancel its
            # RPC timeout timer
            df, timeoutCall = self._sentMessages.pop(message.id)[1:3]
            timeoutCall.cancel()
            if hasattr(df, '_rpcRawResponse'):
                # The RPC requested that the raw response message
                # and originating address be returned; do not
                # interpret it
                df.callback((message, address))
            elif isinstance(message, msgtypes.ErrorMessage):
                # The RPC request raised a remote exception; raise it locally
                if message.exceptionType in BUILTIN_EXCEPTIONS:
                    exception_type = BUILTIN_EXCEPTIONS[message.exceptionType]
                else:
                    exception_type = UnknownRemoteException
                remoteException = exception_type(message.response)
                log.error("Remote exception (%s): %s", address, remoteException)
                df.errback(remoteException)
            else:
                # We got a result from the RPC
                df.callback(message.response)

    def _send(self, data, rpcID, address):
        """ Transmit the specified data over UDP, breaking it up into several
        packets if necessary

        If the data is spread over multiple UDP datagrams, the packets have the
        following structure::
            |            |            |               |          |   0x00   |
            |Transmission|Total number|Sequence number|  RPC ID  |Header end|
            |  type ID   | of packets |of this packet |          | indicator|
            |  (1 byte)  | (2 bytes)  |   (2 bytes)   |(20 bytes)| (1 byte) |
            |   0x00     |            |               |          |          |

        @note: The header used for breaking up large data segments will
               possibly be moved out of the KademliaProtocol class in the
               future, into something similar to a message translator/encoder
               class (see C{kademlia.msgformat} and C{kademlia.encoding}).
        """
        self._record_traffic(self._history_tx, self._bytes_tx, address, len(data))

        if len(data) <= self.msgSizeLimit:
            self._scheduleSendNext(data, address)
            return

        # We have to spread the data over multiple UDP datagrams,
        # and provide sequencing information
        #
        # 1st byte is transmission type id, bytes 2 & 3 are the
        # total number of packets in this transmission, bytes 4 &
        # 5 are the sequence number for this specific packet
        totalPackets, remainder = divmod(len(data), self.msgSizeLimit)
        if remainder:
            totalPackets += 1
        encTotalPackets = chr(totalPackets >> 8) + chr(totalPackets & 0xff)
        for seqNumber in range(totalPackets):
            startPos = seqNumber * self.msgSizeLimit
            packetData = data[startPos:startPos + self.msgSizeLimit]
            encSeqNumber = chr(seqNumber >> 8) + chr(seqNumber & 0xff)
            txData = '\x00%s%s%s\x00%s' % (encTotalPackets, encSeqNumber, rpcID, packetData)
            self._scheduleSendNext(txData, address)

    def _scheduleSendNext(self, txData, address):
        """Schedule the sending of the next UDP packet """
        # the delay before each write comes from self._delay (a Delay instance)
        delay = self._delay()
        key = object()
        delayed_call = reactor.callLater(delay, self._write_and_remove, key, txData, address)
        self._call_later_list[key] = delayed_call

    def _write_and_remove(self, key, txData, address):
        """ Write one scheduled packet and forget its pending DelayedCall """
        del self._call_later_list[key]
        if not self.transport:
            return
        try:
            self.transport.write(txData, address)
        except socket.error as err:
            if err.errno == errno.EWOULDBLOCK:
                # i'm scared this may swallow important errors, but i get a million of these
                # on Linux and it doesnt seem to affect anything  -grin
                log.debug("Can't send data to dht: EWOULDBLOCK")
            elif err.errno == errno.ENETUNREACH:
                # this should probably try to retransmit when the network connection is back
                log.error("Network is unreachable")
            else:
                # err.message is empty for (errno, strerror) socket errors and
                # errno may be None, so log strerror and format errno with %s
                log.error("DHT socket error: %s (%s)", err.strerror or str(err), err.errno)
                # bare raise keeps the original traceback
                raise

    def _sendResponse(self, contact, rpcID, response):
        """ Send a RPC response to the specified contact
        """
        msg = msgtypes.ResponseMessage(rpcID, self._node.node_id, response)
        msgPrimitive = self._translator.toPrimitive(msg)
        encodedMsg = self._encoder.encode(msgPrimitive)
        self._send(encodedMsg, rpcID, (contact.address, contact.port))

    def _sendError(self, contact, rpcID, exceptionType, exceptionMessage):
        """ Send an RPC error message to the specified contact
        """
        msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage)
        msgPrimitive = self._translator.toPrimitive(msg)
        encodedMsg = self._encoder.encode(msgPrimitive)
        self._send(encodedMsg, rpcID, (contact.address, contact.port))

    def _handleRPC(self, senderContact, rpcID, method, args):
        """ Executes a local function in response to an RPC request

        The result (or the raised exception) is sent back to C{senderContact}
        via _sendResponse (or _sendError).
        """
        # Set up the deferred callchain
        def handleError(f):
            self._sendError(senderContact, rpcID, f.type, f.getErrorMessage())

        def handleResult(result):
            self._sendResponse(senderContact, rpcID, result)

        df = defer.Deferred()
        df.addCallback(handleResult)
        df.addErrback(handleError)

        # Execute the RPC. The method name comes from a remote peer: getattr()
        # raises TypeError for non-string names, so treat those as unknown.
        func = getattr(self._node, method, None) if isinstance(method, str) else None
        if not (callable(func) and hasattr(func, 'rpcmethod')):
            # No such exposed method
            df.errback(AttributeError('Invalid method: %s' % method))
            return

        # Call the exposed Node method and return the result to the deferred callback chain
        if args:
            log.debug("DHT RECV CALL %s(%s) %s:%i", method, self._describe_arg(args[0]),
                      senderContact.address, senderContact.port)
        else:
            log.debug("DHT RECV CALL %s %s:%i", method, senderContact.address,
                      senderContact.port)
        try:
            if method != 'ping':
                kwargs = {'_rpcNodeID': senderContact.id, '_rpcNodeContact': senderContact}
                result = func(*args, **kwargs)
            else:
                result = func()
        except Exception as e:
            log.exception("error handling request for %s: %s", senderContact.address, method)
            df.errback(e)
        else:
            df.callback(result)

    def _msgTimeout(self, messageID):
        """ Called when an RPC request message times out """
        # Find the message that timed out
        if messageID not in self._sentMessages:
            # This should never be reached
            log.error("deferred timed out, but is not present in sent messages list!")
            return
        remoteContactID, df, timeout_call, method, args = self._sentMessages[messageID]
        if messageID in self._partialMessages:
            # We are still receiving this message
            self._msgTimeoutInProgress(messageID, remoteContactID, df, method, args)
            return
        del self._sentMessages[messageID]
        # The message's destination node is now considered to be dead;
        # raise an (asynchronous) TimeoutError exception and update the host node
        self._node.removeContact(remoteContactID)
        df.errback(TimeoutError(remoteContactID))

    def _msgTimeoutInProgress(self, messageID, remoteContactID, df, method, args):
        """ Handle an RPC timeout while its multi-packet response is still arriving

        If packets arrived since the previous check, the timeout is re-armed;
        otherwise the partial response is discarded and C{df} fails with
        TimeoutError.
        """
        if self._hasProgressBeenMade(messageID):
            # Remember how far reassembly got so the next timeout can tell
            # whether any further packets arrived in the meantime
            self._partialMessagesProgress[messageID] = set(self._partialMessages[messageID])
            # Reset the RPC timeout timer
            timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, messageID)
            self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args)
        else:
            # No progress has been made: give up on this RPC
            self._partialMessagesProgress.pop(messageID, None)
            self._partialMessages.pop(messageID, None)
            # Forget the request too, so a late response does not try to
            # cancel the already-fired timeout call or fire df a second time
            self._sentMessages.pop(messageID, None)
            df.errback(TimeoutError(remoteContactID))

    def _hasProgressBeenMade(self, messageID):
        """ Return True if new packets of C{messageID} arrived since the last timeout check

        The first check for a message (no snapshot recorded yet) counts as
        progress: _msgTimeoutInProgress only calls this while a partial
        message exists, so at least one packet has been received.
        """
        if messageID not in self._partialMessagesProgress:
            return True
        return (len(self._partialMessagesProgress[messageID]) !=
                len(self._partialMessages[messageID]))

    def stopProtocol(self):
        """ Called when the transport is disconnected.

        Will only be called once, after all ports are disconnected.
        """
        log.info('Stopping DHT')
        if self._bandwidth_stats_update_lc.running:
            self._bandwidth_stats_update_lc.stop()
        for delayed_call in self._call_later_list.values():
            try:
                delayed_call.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                log.debug('Attempted to cancel a DelayedCall that was not active')
            except Exception:
                log.exception('Failed to cancel a DelayedCall')
        # cancelled writes will never run _write_and_remove, so drop them here
        self._call_later_list.clear()
        # not sure why this is needed, but taking this out sometimes causes
        # exceptions.AttributeError: 'Port' object has no attribute 'socket'
        # to happen on shutdown
        # reactor.iterate()
        log.info('DHT stopped')