refactor Contact class, DHT RPCs, and Contact addition/removal

-track contact failures, last replied, and last requested. use this to provide a 'contact_is_good' property on Contact objects
-ensure no duplicate contact objects are created
-remove confusing conflation of node id strings with Contact objects, update docstrings
-move RPC failure tracking to a callback/errback pair in sendRPC (so the contact is only updated once)
-handle seed nodes during the join sequence by setting their node ids after they initially reply to our ping
-name all of the kademlia RPC keyword args, remove confusing **kwargs and dictionary parsing
-add host ip/port to DHT send/receive logging to make the results comprehensible when running many nodes at once
This commit is contained in:
Jack Robison 2018-05-23 17:32:55 -04:00
parent ad2dcf0893
commit 23c202b5e4
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
6 changed files with 294 additions and 142 deletions

View file

@ -1,19 +1,78 @@
class Contact(object): from lbrynet.dht import constants
class _Contact(object):
""" Encapsulation for remote contact """ Encapsulation for remote contact
This class contains information on a single remote contact, and also This class contains information on a single remote contact, and also
provides a direct RPC API to the remote node which it represents provides a direct RPC API to the remote node which it represents
""" """
def __init__(self, id, ipAddress, udpPort, networkProtocol, firstComm=0): def __init__(self, contactManager, id, ipAddress, udpPort, networkProtocol, firstComm):
self.id = id self._contactManager = contactManager
self._id = id
if id is not None:
if not len(id) == constants.key_bits / 8:
raise ValueError("invalid node id: %s", id.encode('hex'))
self.address = ipAddress self.address = ipAddress
self.port = udpPort self.port = udpPort
self._networkProtocol = networkProtocol self._networkProtocol = networkProtocol
self.commTime = firstComm self.commTime = firstComm
self.getTime = self._contactManager._get_time
self.lastReplied = None
self.lastRequested = None
@property
def lastInteracted(self):
return max(self.lastRequested or 0, self.lastReplied or 0, self.lastFailed or 0)
@property
def id(self):
return self._id
def log_id(self, short=True):
if not self.id:
return "not initialized"
id_hex = self.id.encode('hex')
return id_hex if not short else id_hex[:8]
@property
def failedRPCs(self):
return len(self.failures)
@property
def lastFailed(self):
return self._contactManager._rpc_failures.get((self.address, self.port), [None])[-1]
@property
def failures(self):
return self._contactManager._rpc_failures.get((self.address, self.port), [])
@property
def contact_is_good(self):
"""
:return: False if contact is bad, None if contact is unknown, or True if contact is good
"""
failures = self.failures
now = self.getTime()
delay = constants.refreshTimeout / 4
if failures:
if self.lastReplied and len(failures) >= 2 and self.lastReplied < failures[-2]:
return False
elif self.lastReplied and len(failures) >= 2 and self.lastReplied > failures[-2]:
pass # handled below
elif len(failures) >= 2:
return False
if self.lastReplied and self.lastReplied > now - delay:
return True
if self.lastReplied and self.lastRequested and self.lastRequested > now - delay:
return True
return None
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Contact): if isinstance(other, _Contact):
return self.id == other.id return self.id == other.id
elif isinstance(other, str): elif isinstance(other, str):
return self.id == other return self.id == other
@ -21,7 +80,7 @@ class Contact(object):
return False return False
def __ne__(self, other): def __ne__(self, other):
if isinstance(other, Contact): if isinstance(other, _Contact):
return self.id != other.id return self.id != other.id
elif isinstance(other, str): elif isinstance(other, str):
return self.id != other return self.id != other
@ -33,6 +92,21 @@ class Contact(object):
lambda buff, x: buff + bytearray([int(x)]), self.address.split('.'), bytearray()) lambda buff, x: buff + bytearray([int(x)]), self.address.split('.'), bytearray())
return str(compact_ip) return str(compact_ip)
def set_id(self, id):
if not self._id:
self._id = id
def update_last_replied(self):
self.lastReplied = int(self.getTime())
def update_last_requested(self):
self.lastRequested = int(self.getTime())
def update_last_failed(self):
failures = self._contactManager._rpc_failures.get((self.address, self.port), [])
failures.append(self.getTime())
self._contactManager._rpc_failures[(self.address, self.port)] = failures
def __str__(self): def __str__(self):
return '<%s.%s object; IP address: %s, UDP port: %d>' % ( return '<%s.%s object; IP address: %s, UDP port: %d>' % (
self.__module__, self.__class__.__name__, self.address, self.port) self.__module__, self.__class__.__name__, self.address, self.port)
@ -56,3 +130,31 @@ class Contact(object):
return self._networkProtocol.sendRPC(self, name, args, **kwargs) return self._networkProtocol.sendRPC(self, name, args, **kwargs)
return _sendRPC return _sendRPC
class ContactManager(object):
def __init__(self, get_time=None):
if not get_time:
from twisted.internet import reactor
get_time = reactor.seconds
self._get_time = get_time
self._contacts = {}
self._rpc_failures = {}
def get_contact(self, id, address, port):
for contact in self._contacts.itervalues():
if contact.id == id and contact.address == address and contact.port == port:
return contact
def make_contact(self, id, ipAddress, udpPort, networkProtocol, firstComm=0):
ipAddress = str(ipAddress)
contact = self.get_contact(id, ipAddress, udpPort)
if contact:
return contact
contact = _Contact(self, id, ipAddress, udpPort, networkProtocol, firstComm or self._get_time())
self._contacts[(id, ipAddress, udpPort)] = contact
return contact
def is_ignored(self, origin_tuple):
failed_rpc_count = len(self._rpc_failures.get(origin_tuple, []))
return failed_rpc_count > constants.rpcAttempts

View file

@ -33,6 +33,9 @@ class TimeoutError(Exception):
def __init__(self, remote_contact_id): def __init__(self, remote_contact_id):
# remote_contact_id is a binary blob so we need to convert it # remote_contact_id is a binary blob so we need to convert it
# into something more readable # into something more readable
msg = 'Timeout connecting to {}'.format(binascii.hexlify(remote_contact_id)) if remote_contact_id:
msg = 'Timeout connecting to {}'.format(binascii.hexlify(remote_contact_id))
else:
msg = 'Timeout connecting to uninitialized node'
Exception.__init__(self, msg) Exception.__init__(self, msg)
self.remote_contact_id = remote_contact_id self.remote_contact_id = remote_contact_id

View file

@ -42,9 +42,19 @@ class KBucket(object):
raise BucketFull("No space in bucket to insert contact") raise BucketFull("No space in bucket to insert contact")
def getContact(self, contactID): def getContact(self, contactID):
""" Get the contact specified node ID""" """Get the contact specified node ID
index = self._contacts.index(contactID)
return self._contacts[index] @raise IndexError: raised if the contact is not in the bucket
@param contactID: the node id of the contact to retrieve
@type contactID: str
@rtype: dht.contact._Contact
"""
for contact in self._contacts:
if contact.id == contactID:
return contact
raise IndexError(contactID)
def getContacts(self, count=-1, excludeContact=None): def getContacts(self, count=-1, excludeContact=None):
""" Returns a list containing up to the first count number of contacts """ Returns a list containing up to the first count number of contacts
@ -92,14 +102,18 @@ class KBucket(object):
if excludeContact in contactList: if excludeContact in contactList:
contactList.remove(excludeContact) contactList.remove(excludeContact)
def getBadOrUnknownContacts(self):
contacts = self.getContacts(sort_distance_to=False)
results = [contact for contact in contacts if contact.contact_is_good is False]
results.extend(contact for contact in contacts if contact.contact_is_good is None)
return results
return contactList return contactList
def removeContact(self, contact): def removeContact(self, contact):
""" Remove given contact from list """ Remove the contact from the bucket
@param contact: The contact to remove, or a string containing the @param contact: The contact to remove
contact's node ID @type contact: dht.contact._Contact
@type contact: kademlia.contact.Contact or str
@raise ValueError: The specified contact is not in this bucket @raise ValueError: The specified contact is not in this bucket
""" """
@ -124,3 +138,6 @@ class KBucket(object):
def __len__(self): def __len__(self):
return len(self._contacts) return len(self._contacts)
def __contains__(self, item):
return item in self._contacts

View file

@ -24,7 +24,7 @@ import datastore
import protocol import protocol
from error import TimeoutError from error import TimeoutError
from peerfinder import DHTPeerFinder from peerfinder import DHTPeerFinder
from contact import Contact from contact import ContactManager
from distance import Distance from distance import Distance
@ -51,6 +51,7 @@ class MockKademliaHelper(object):
clock = clock or reactor clock = clock or reactor
self.clock = clock self.clock = clock
self.contact_manager = ContactManager(self.clock.seconds)
self.reactor_listenUDP = listenUDP self.reactor_listenUDP = listenUDP
self.reactor_resolve = resolve self.reactor_resolve = resolve
@ -276,8 +277,10 @@ class Node(MockKademliaHelper):
is_closer = Distance(blob_hash).is_closer(self.node_id, contacts[-1].id) is_closer = Distance(blob_hash).is_closer(self.node_id, contacts[-1].id)
if is_closer: if is_closer:
contacts.pop() contacts.pop()
yield self.store(blob_hash, value, originalPublisherID=self.node_id, self_contact = self.contact_manager.make_contact(self.node_id, self.externalIP,
self_store=True) self.port, self._protocol)
token = self.make_token(self_contact.compact_ip())
yield self.store(self_contact, blob_hash, token, self.peerPort)
elif self.externalIP is not None: elif self.externalIP is not None:
pass pass
else: else:
@ -403,17 +406,17 @@ class Node(MockKademliaHelper):
@param contact: The contact to add to this node's k-buckets @param contact: The contact to add to this node's k-buckets
@type contact: kademlia.contact.Contact @type contact: kademlia.contact.Contact
""" """
self._routingTable.addContact(contact) return self._routingTable.addContact(contact)
def removeContact(self, contactID): def removeContact(self, contact):
""" Remove the contact with the specified node ID from this node's """ Remove the contact with the specified node ID from this node's
table of known nodes. This is a simple wrapper for the same method table of known nodes. This is a simple wrapper for the same method
in this object's RoutingTable object in this object's RoutingTable object
@param contactID: The node ID of the contact to remove @param contact: The Contact object to remove
@type contactID: str @type contact: _Contact
""" """
self._routingTable.removeContact(contactID) self._routingTable.removeContact(contact)
def findContact(self, contactID): def findContact(self, contactID):
""" Find a entangled.kademlia.contact.Contact object for the specified """ Find a entangled.kademlia.contact.Contact object for the specified
@ -430,10 +433,11 @@ class Node(MockKademliaHelper):
contact = self._routingTable.getContact(contactID) contact = self._routingTable.getContact(contactID)
df = defer.Deferred() df = defer.Deferred()
df.callback(contact) df.callback(contact)
except ValueError: except (ValueError, IndexError):
def parseResults(nodes): def parseResults(nodes):
node_ids = [c.id for c in nodes]
if contactID in nodes: if contactID in nodes:
contact = nodes[nodes.index(contactID)] contact = nodes[node_ids.index(contactID)]
return contact return contact
else: else:
return None return None
@ -451,11 +455,11 @@ class Node(MockKademliaHelper):
return 'pong' return 'pong'
@rpcmethod @rpcmethod
def store(self, key, value, originalPublisherID=None, self_store=False, **kwargs): def store(self, rpc_contact, blob_hash, token, port, originalPublisherID=None, age=0):
""" Store the received data in this node's local hash table """ Store the received data in this node's local hash table
@param key: The hashtable key of the data @param blob_hash: The hashtable key of the data
@type key: str @type blob_hash: str
@param value: The actual data (the value associated with C{key}) @param value: The actual data (the value associated with C{key})
@type value: str @type value: str
@param originalPublisherID: The node ID of the node that is the @param originalPublisherID: The node ID of the node that is the
@ -473,54 +477,24 @@ class Node(MockKademliaHelper):
(which is the case currently) might not be a good idea... will have (which is the case currently) might not be a good idea... will have
to fix this (perhaps use a stream from the Protocol class?) to fix this (perhaps use a stream from the Protocol class?)
""" """
# Get the sender's ID (if any)
if originalPublisherID is None: if originalPublisherID is None:
if '_rpcNodeID' in kwargs: originalPublisherID = rpc_contact.id
originalPublisherID = kwargs['_rpcNodeID'] compact_ip = rpc_contact.compact_ip()
else: if not self.verify_token(token, compact_ip):
raise TypeError, 'No NodeID given. Therefore we can\'t store this node' raise ValueError("Invalid token")
if 0 <= port <= 65536:
if self_store is True and self.externalIP: compact_port = str(struct.pack('>H', port))
contact = Contact(self.node_id, self.externalIP, self.port, None, None)
compact_ip = contact.compact_ip()
elif '_rpcNodeContact' in kwargs:
contact = kwargs['_rpcNodeContact']
compact_ip = contact.compact_ip()
else: else:
raise TypeError, 'No contact info available' raise TypeError('Invalid port')
if not self_store:
if 'token' not in value:
raise ValueError("Missing token")
if not self.verify_token(value['token'], compact_ip):
raise ValueError("Invalid token")
if 'port' in value:
port = int(value['port'])
if 0 <= port <= 65536:
compact_port = str(struct.pack('>H', port))
else:
raise TypeError('Invalid port')
else:
raise TypeError('No port available')
if 'lbryid' in value:
if len(value['lbryid']) != constants.key_bits / 8:
raise ValueError('Invalid lbryid (%i bytes): %s' % (len(value['lbryid']),
value['lbryid'].encode('hex')))
else:
compact_address = compact_ip + compact_port + value['lbryid']
else:
raise TypeError('No lbryid given')
compact_address = compact_ip + compact_port + rpc_contact.id
now = int(time.time()) now = int(time.time())
originallyPublished = now # - age originallyPublished = now - age
self._dataStore.addPeerToBlob(key, compact_address, now, originallyPublished, self._dataStore.addPeerToBlob(blob_hash, compact_address, now, originallyPublished, originalPublisherID)
originalPublisherID)
return 'OK' return 'OK'
@rpcmethod @rpcmethod
def findNode(self, key, **kwargs): def findNode(self, rpc_contact, key):
""" Finds a number of known nodes closest to the node/value with the """ Finds a number of known nodes closest to the node/value with the
specified key. specified key.
@ -533,20 +507,17 @@ class Node(MockKademliaHelper):
node is returning all of the contacts that it knows of. node is returning all of the contacts that it knows of.
@rtype: list @rtype: list
""" """
if len(key) != constants.key_bits / 8:
raise ValueError("invalid contact id length: %i" % len(key))
# Get the sender's ID (if any) contacts = self._routingTable.findCloseNodes(key, constants.k, rpc_contact.id)
if '_rpcNodeID' in kwargs:
rpc_sender_id = kwargs['_rpcNodeID']
else:
rpc_sender_id = None
contacts = self._routingTable.findCloseNodes(key, constants.k, rpc_sender_id)
contact_triples = [] contact_triples = []
for contact in contacts: for contact in contacts:
contact_triples.append((contact.id, contact.address, contact.port)) contact_triples.append((contact.id, contact.address, contact.port))
return contact_triples return contact_triples
@rpcmethod @rpcmethod
def findValue(self, key, **kwargs): def findValue(self, rpc_contact, key):
""" Return the value associated with the specified key if present in """ Return the value associated with the specified key if present in
this node's data, otherwise execute FIND_NODE for the key this node's data, otherwise execute FIND_NODE for the key
@ -558,16 +529,18 @@ class Node(MockKademliaHelper):
@rtype: dict or list @rtype: dict or list
""" """
if len(key) != constants.key_bits / 8:
raise ValueError("invalid blob hash length: %i" % len(key))
response = {
'token': self.make_token(rpc_contact.compact_ip()),
}
if self._dataStore.hasPeersForBlob(key): if self._dataStore.hasPeersForBlob(key):
rval = {key: self._dataStore.getPeersForBlob(key)} response[key] = self._dataStore.getPeersForBlob(key)
else: else:
contact_triples = self.findNode(key, **kwargs) response['contacts'] = self.findNode(rpc_contact, key)
rval = {'contacts': contact_triples} return response
if '_rpcNodeContact' in kwargs:
contact = kwargs['_rpcNodeContact']
compact_ip = contact.compact_ip()
rval['token'] = self.make_token(compact_ip)
return rval
def _generateID(self): def _generateID(self):
""" Generates an n-bit pseudo-random identifier """ Generates an n-bit pseudo-random identifier
@ -606,13 +579,15 @@ class Node(MockKademliaHelper):
return a list of the k closest nodes to the specified key return a list of the k closest nodes to the specified key
@rtype: twisted.internet.defer.Deferred @rtype: twisted.internet.defer.Deferred
""" """
findValue = rpc != 'findNode'
if len(key) != constants.key_bits / 8:
raise ValueError("invalid key length: %i" % len(key))
if startupShortlist is None: if startupShortlist is None:
shortlist = self._routingTable.findCloseNodes(key, constants.k) shortlist = self._routingTable.findCloseNodes(key, constants.k)
if key != self.node_id: # if key != self.node_id:
# Update the "last accessed" timestamp for the appropriate k-bucket # # Update the "last accessed" timestamp for the appropriate k-bucket
self._routingTable.touchKBucket(key) # self._routingTable.touchKBucket(key)
if len(shortlist) == 0: if len(shortlist) == 0:
log.warning("This node doesnt know any other nodes") log.warning("This node doesnt know any other nodes")
# This node doesn't know of any other nodes # This node doesn't know of any other nodes
@ -621,7 +596,7 @@ class Node(MockKademliaHelper):
result = yield fakeDf result = yield fakeDf
defer.returnValue(result) defer.returnValue(result)
else: else:
# This is used during the bootstrap process; node ID's are most probably fake # This is used during the bootstrap process
shortlist = startupShortlist shortlist = startupShortlist
outerDf = defer.Deferred() outerDf = defer.Deferred()

View file

@ -9,7 +9,6 @@ import constants
import encoding import encoding
import msgtypes import msgtypes
import msgformat import msgformat
from contact import Contact
from error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError from error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -29,7 +28,8 @@ class KademliaProtocol(protocol.DatagramProtocol):
self._partialMessagesProgress = {} self._partialMessagesProgress = {}
def sendRPC(self, contact, method, args, rawResponse=False): def sendRPC(self, contact, method, args, rawResponse=False):
""" Sends an RPC to the specified contact """
Sends an RPC to the specified contact
@param contact: The contact (remote node) to send the RPC to @param contact: The contact (remote node) to send the RPC to
@type contact: kademlia.contacts.Contact @type contact: kademlia.contacts.Contact
@ -60,19 +60,39 @@ class KademliaProtocol(protocol.DatagramProtocol):
encodedMsg = self._encoder.encode(msgPrimitive) encodedMsg = self._encoder.encode(msgPrimitive)
if args: if args:
log.debug("DHT SEND CALL %s(%s)", method, args[0].encode('hex')) log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
args[0].encode('hex'), contact.address, contact.port)
else: else:
log.debug("DHT SEND CALL %s", method) log.debug("%s:%i SEND CALL %s TO %s:%i", self._node.externalIP, self._node.port, method,
contact.address, contact.port)
df = defer.Deferred() df = defer.Deferred()
if rawResponse: if rawResponse:
df._rpcRawResponse = True df._rpcRawResponse = True
def _remove_contact(failure): # remove the contact from the routing table and track the failure
try:
self._node.removeContact(contact)
except (ValueError, IndexError):
pass
contact.update_last_failed()
return failure
def _update_contact(result): # refresh the contact in the routing table
contact.update_last_replied()
d = self._node.addContact(contact)
d.addCallback(lambda _: result)
return d
df.addCallbacks(_update_contact, _remove_contact)
# Set the RPC timeout timer # Set the RPC timeout timer
timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id) timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
# Transmit the data # Transmit the data
self._send(encodedMsg, msg.id, (contact.address, contact.port)) self._send(encodedMsg, msg.id, (contact.address, contact.port))
self._sentMessages[msg.id] = (contact.id, df, timeoutCall, method, args) self._sentMessages[msg.id] = (contact, df, timeoutCall, cancelTimeout, method, args)
df.addErrback(cancelTimeout) df.addErrback(cancelTimeout)
return df return df
@ -115,46 +135,80 @@ class KademliaProtocol(protocol.DatagramProtocol):
log.warning("Couldn't decode dht datagram from %s", address) log.warning("Couldn't decode dht datagram from %s", address)
return return
remoteContact = Contact(message.nodeID, address[0], address[1], self)
# Refresh the remote node's details in the local node's k-buckets
self._node.addContact(remoteContact)
if isinstance(message, msgtypes.RequestMessage): if isinstance(message, msgtypes.RequestMessage):
# This is an RPC method request # This is an RPC method request
self._handleRPC(remoteContact, message.id, message.request, message.args) remoteContact = self._node.contact_manager.make_contact(message.nodeID, address[0], address[1], self)
remoteContact.update_last_requested()
# only add a requesting contact to the routing table if it has replied to one of our requests
if remoteContact.contact_is_good is True:
df = self._node.addContact(remoteContact)
else:
df = defer.succeed(None)
df.addCallback(lambda _: self._handleRPC(remoteContact, message.id, message.request, message.args))
# if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
# will be added to our routing table if successful
if remoteContact.contact_is_good is None and remoteContact.lastReplied is None:
df.addCallback(lambda _: self._ping_queue.enqueue_maybe_ping(remoteContact))
elif isinstance(message, msgtypes.ErrorMessage):
# The RPC request raised a remote exception; raise it locally
if message.exceptionType in BUILTIN_EXCEPTIONS:
exception_type = BUILTIN_EXCEPTIONS[message.exceptionType]
else:
exception_type = UnknownRemoteException
remoteException = exception_type(message.response)
log.error("DHT RECV REMOTE EXCEPTION FROM %s:%i: %s", address[0],
address[1], remoteException)
if message.id in self._sentMessages:
# Cancel timeout timer for this RPC
remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[message.id][0:5]
timeoutCanceller()
del self._sentMessages[message.id]
# reject replies coming from a different address than what we sent our request to
if (remoteContact.address, remoteContact.port) != address:
log.warning("Sent request to node %s at %s:%i, got reply from %s:%i",
remoteContact.log_id(), remoteContact.address,
remoteContact.port, address[0], address[1])
df.errback(TimeoutError(remoteContact.id))
return
# this error is returned by nodes that can be contacted but have an old
# and broken version of the ping command, if they return it the node can
# be contacted, so we'll treat it as a successful ping
old_ping_error = "ping() got an unexpected keyword argument '_rpcNodeContact'"
if isinstance(remoteException, TypeError) and \
remoteException.message == old_ping_error:
log.debug("old pong error")
df.callback('pong')
else:
df.errback(remoteException)
elif isinstance(message, msgtypes.ResponseMessage): elif isinstance(message, msgtypes.ResponseMessage):
# Find the message that triggered this response # Find the message that triggered this response
if message.id in self._sentMessages: if message.id in self._sentMessages:
# Cancel timeout timer for this RPC # Cancel timeout timer for this RPC
df, timeoutCall = self._sentMessages[message.id][1:3] remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[message.id][0:5]
timeoutCall.cancel() timeoutCanceller()
del self._sentMessages[message.id] del self._sentMessages[message.id]
log.debug("%s:%i RECV response to %s from %s:%i", self._node.externalIP, self._node.port,
method, remoteContact.address, remoteContact.port)
# When joining the network we made Contact objects for the seed nodes with node ids set to None
# Thus, the sent_to_id will also be None, and the contact objects need the ids to be manually set.
# These replies have be distinguished from those where the node id in the datagram does not match
# the node id of the node we sent a message to (these messages are treated as an error)
if remoteContact.id and remoteContact.id != message.nodeID: # sent_to_id will be None for bootstrap
log.debug("mismatch: (%s) %s:%i (%s vs %s)", method, remoteContact.address, remoteContact.port,
remoteContact.log_id(False), message.nodeID.encode('hex'))
df.errback(TimeoutError(remoteContact.id))
return
elif not remoteContact.id:
remoteContact.set_id(message.nodeID)
if hasattr(df, '_rpcRawResponse'): if hasattr(df, '_rpcRawResponse'):
# The RPC requested that the raw response message # The RPC requested that the raw response message
# and originating address be returned; do not # and originating address be returned; do not
# interpret it # interpret it
df.callback((message, address)) df.callback((message, address))
elif isinstance(message, msgtypes.ErrorMessage):
# The RPC request raised a remote exception; raise it locally
if message.exceptionType in BUILTIN_EXCEPTIONS:
exception_type = BUILTIN_EXCEPTIONS[message.exceptionType]
else:
exception_type = UnknownRemoteException
remoteException = exception_type(message.response)
# this error is returned by nodes that can be contacted but have an old
# and broken version of the ping command, if they return it the node can
# be contacted, so we'll treat it as a successful ping
old_ping_error = "ping() got an unexpected keyword argument '_rpcNodeContact'"
if isinstance(remoteException, TypeError) and \
remoteException.message == old_ping_error:
log.debug("old pong error")
df.callback('pong')
else:
log.error("DHT RECV REMOTE EXCEPTION FROM %s:%i: %s", address[0],
address[1], remoteException)
df.errback(remoteException)
else: else:
# We got a result from the RPC # We got a result from the RPC
df.callback(message.response) df.callback(message.response)
@ -259,28 +313,29 @@ class KademliaProtocol(protocol.DatagramProtocol):
# Execute the RPC # Execute the RPC
func = getattr(self._node, method, None) func = getattr(self._node, method, None)
if callable(func) and hasattr(func, 'rpcmethod'): if callable(func) and hasattr(func, "rpcmethod"):
# Call the exposed Node method and return the result to the deferred callback chain # Call the exposed Node method and return the result to the deferred callback chain
if args: if args:
log.debug("DHT RECV CALL %s(%s) %s:%i", method, args[0].encode('hex'), log.debug("%s:%i RECV CALL %s(%s) %s:%i", self._node.externalIP, self._node.port, method,
senderContact.address, senderContact.port) args[0].encode('hex'), senderContact.address, senderContact.port)
else: else:
log.debug("DHT RECV CALL %s %s:%i", method, senderContact.address, log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method,
senderContact.port) senderContact.address, senderContact.port)
try: try:
if method != 'ping': if method != 'ping':
kwargs = {'_rpcNodeID': senderContact.id, '_rpcNodeContact': senderContact} result = func(senderContact, *args)
result = func(*args, **kwargs)
else: else:
result = func() result = func()
except Exception, e: except Exception, e:
log.exception("error handling request for %s: %s", senderContact.address, method) log.exception("error handling request for %s:%i %s", senderContact.address,
senderContact.port, method)
df.errback(e) df.errback(e)
else: else:
df.callback(result) df.callback(result)
else: else:
# No such exposed method # No such exposed method
df.errback(AttributeError('Invalid method: %s' % method)) df.errback(AttributeError('Invalid method: %s' % method))
return df
def _msgTimeout(self, messageID): def _msgTimeout(self, messageID):
""" Called when an RPC request message times out """ """ Called when an RPC request message times out """
@ -289,30 +344,30 @@ class KademliaProtocol(protocol.DatagramProtocol):
# This should never be reached # This should never be reached
log.error("deferred timed out, but is not present in sent messages list!") log.error("deferred timed out, but is not present in sent messages list!")
return return
remoteContactID, df, timeout_call, method, args = self._sentMessages[messageID] remoteContact, df, timeout_call, timeout_canceller, method, args = self._sentMessages[messageID]
if self._partialMessages.has_key(messageID): if self._partialMessages.has_key(messageID):
# We are still receiving this message # We are still receiving this message
self._msgTimeoutInProgress(messageID, remoteContactID, df, method, args) self._msgTimeoutInProgress(messageID, timeout_canceller, remoteContact, df, method, args)
return return
del self._sentMessages[messageID] del self._sentMessages[messageID]
# The message's destination node is now considered to be dead; # The message's destination node is now considered to be dead;
# raise an (asynchronous) TimeoutError exception and update the host node # raise an (asynchronous) TimeoutError exception and update the host node
self._node.removeContact(remoteContactID) df.errback(TimeoutError(remoteContact.id))
df.errback(TimeoutError(remoteContactID))
def _msgTimeoutInProgress(self, messageID, remoteContactID, df, method, args): def _msgTimeoutInProgress(self, messageID, timeoutCanceller, remoteContact, df, method, args):
# See if any progress has been made; if not, kill the message # See if any progress has been made; if not, kill the message
if self._hasProgressBeenMade(messageID): if self._hasProgressBeenMade(messageID):
# Reset the RPC timeout timer # Reset the RPC timeout timer
timeoutCall, _ = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID) timeoutCanceller()
self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args) timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID)
self._sentMessages[messageID] = (remoteContact, df, timeoutCall, cancelTimeout, method, args)
else: else:
# No progress has been made # No progress has been made
if messageID in self._partialMessagesProgress: if messageID in self._partialMessagesProgress:
del self._partialMessagesProgress[messageID] del self._partialMessagesProgress[messageID]
if messageID in self._partialMessages: if messageID in self._partialMessages:
del self._partialMessages[messageID] del self._partialMessages[messageID]
df.errback(TimeoutError(remoteContactID)) df.errback(TimeoutError(remoteContact.id))
def _hasProgressBeenMade(self, messageID): def _hasProgressBeenMade(self, messageID):
return ( return (

View file

@ -202,16 +202,16 @@ class TreeRoutingTable(object):
bucketIndex += 1 bucketIndex += 1
return refreshIDs return refreshIDs
def removeContact(self, contactID): def removeContact(self, contact):
""" Remove the contact with the specified node ID from the routing
table
@param contactID: The node ID of the contact to remove
@type contactID: str
""" """
bucketIndex = self._kbucketIndex(contactID) Remove the contact from the routing table
@param contact: The contact to remove
@type contact: dht.contact._Contact
"""
bucketIndex = self._kbucketIndex(contact.id)
try: try:
self._buckets[bucketIndex].removeContact(contactID) self._buckets[bucketIndex].removeContact(contact)
except ValueError: except ValueError:
return return