refactor iterativeFind, move to own file

This commit is contained in:
Jack Robison 2018-05-23 17:37:20 -04:00
parent e5703833cf
commit f1e0a784d9
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 228 additions and 245 deletions

View file

@ -0,0 +1,226 @@
import logging
from twisted.internet import defer
from distance import Distance
from error import TimeoutError
import constants
log = logging.getLogger(__name__)
def get_contact(contact_list, node_id, address, port):
for contact in contact_list:
if contact.id == node_id and contact.address == address and contact.port == port:
return contact
raise IndexError(node_id)
class _IterativeFind(object):
# TODO: use polymorphism to search for a value or node
# instead of using a find_value flag
def __init__(self, node, shortlist, key, rpc):
self.node = node
self.finished_deferred = defer.Deferred()
# all distance operations in this class only care about the distance
# to self.key, so this makes it easier to calculate those
self.distance = Distance(key)
# The closest known and active node yet found
self.closest_node = None if not shortlist else shortlist[0]
self.prev_closest_node = None
# Shortlist of contact objects (the k closest known contacts to the key from the routing table)
self.shortlist = shortlist
# The search key
self.key = str(key)
# The rpc method name (findValue or findNode)
self.rpc = rpc
# List of active queries; len() indicates number of active probes
self.active_probes = []
# List of contact (address, port) tuples that have already been queried, includes contacts that didn't reply
self.already_contacted = []
# A list of found and known-to-be-active remote nodes (Contact objects)
self.active_contacts = []
# Ensure only one searchIteration call is running at a time
self._search_iteration_semaphore = defer.DeferredSemaphore(1)
self._iteration_count = 0
self.find_value_result = {}
self.pending_iteration_calls = []
self._lock = defer.DeferredLock()
@property
def is_find_node_request(self):
return self.rpc == "findNode"
@property
def is_find_value_request(self):
return self.rpc == "findValue"
def is_closer(self, responseMsg):
if not self.closest_node:
return True
return self.distance.is_closer(responseMsg.nodeID, self.closest_node.id)
def getContactTriples(self, result):
if self.is_find_value_request:
contact_triples = result['contacts']
else:
contact_triples = result
for contact_tup in contact_triples:
if not isinstance(contact_tup, (list, tuple)) or len(contact_tup) != 3:
raise ValueError("invalid contact triple")
return contact_triples
def sortByDistance(self, contact_list):
"""Sort the list of contacts in order by distance from key"""
contact_list.sort(key=lambda c: self.distance(c.id))
@defer.inlineCallbacks
def extendShortlist(self, contact, responseTuple):
# The "raw response" tuple contains the response message and the originating address info
responseMsg = responseTuple[0]
originAddress = responseTuple[1] # tuple: (ip address, udp port)
if self.finished_deferred.called:
defer.returnValue(responseMsg.nodeID)
if self.node.contact_manager.is_ignored(originAddress):
raise ValueError("contact is ignored")
if responseMsg.nodeID == self.node.node_id:
defer.returnValue(responseMsg.nodeID)
yield self._lock.acquire()
if contact not in self.active_contacts:
self.active_contacts.append(contact)
if contact not in self.shortlist:
self.shortlist.append(contact)
# Now grow extend the (unverified) shortlist with the returned contacts
result = responseMsg.response
# TODO: some validation on the result (for guarding against attacks)
# If we are looking for a value, first see if this result is the value
# we are looking for before treating it as a list of contact triples
if self.is_find_value_request and self.key in result:
# We have found the value
self.find_value_result[self.key] = result[self.key]
self._lock.release()
self.finished_deferred.callback(self.find_value_result)
else:
if self.is_find_value_request:
# We are looking for a value, and the remote node didn't have it
# - mark it as the closest "empty" node, if it is
# TODO: store to this peer after finding the value as per the kademlia spec
if 'closestNodeNoValue' in self.find_value_result:
if self.is_closer(responseMsg):
self.find_value_result['closestNodeNoValue'] = contact
else:
self.find_value_result['closestNodeNoValue'] = contact
contactTriples = self.getContactTriples(result)
for contactTriple in contactTriples:
if (contactTriple[1], contactTriple[2]) in ((c.address, c.port) for c in self.already_contacted):
continue
elif self.node.contact_manager.is_ignored((contactTriple[1], contactTriple[2])):
raise ValueError("contact is ignored")
else:
found_contact = self.node.contact_manager.make_contact(contactTriple[0], contactTriple[1],
contactTriple[2], self.node._protocol)
if found_contact not in self.shortlist:
self.shortlist.append(found_contact)
self._lock.release()
if not self.finished_deferred.called:
if self.should_stop():
self.sortByDistance(self.active_contacts)
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
defer.returnValue(responseMsg.nodeID)
@defer.inlineCallbacks
def probeContact(self, contact):
fn = getattr(contact, self.rpc)
try:
response_tuple = yield fn(self.key, rawResponse=True)
result = yield self.extendShortlist(contact, response_tuple)
defer.returnValue(result)
except (TimeoutError, defer.CancelledError, ValueError, IndexError):
defer.returnValue(contact.id)
def should_stop(self):
active_contacts_len = len(self.active_contacts)
if active_contacts_len >= constants.k:
# log.info("there are enough results %s(%s)", self.rpc, self.key.encode('hex'))
return True
if self.prev_closest_node and self.closest_node and self.distance.is_closer(
self.prev_closest_node.id, self.closest_node.id):
# log.info("not getting any closer %s(%s)", self.rpc, self.key.encode('hex'))
return True
return False
# Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts
@defer.inlineCallbacks
def _searchIteration(self):
yield self._lock.acquire()
# Sort the discovered active nodes from closest to furthest
if len(self.active_contacts):
self.sortByDistance(self.active_contacts)
self.prev_closest_node = self.closest_node
self.closest_node = self.active_contacts[0]
# Sort and store the current shortList length before contacting other nodes
self.sortByDistance(self.shortlist)
probes = []
already_contacted_addresses = {(c.address, c.port) for c in self.already_contacted}
to_remove = []
for contact in self.shortlist:
if (contact.address, contact.port) not in already_contacted_addresses:
self.already_contacted.append(contact)
to_remove.append(contact)
probe = self.probeContact(contact)
probes.append(probe)
self.active_probes.append(probe)
if len(probes) == constants.alpha:
break
for contact in to_remove: # these contacts will be re-added to the shortlist when they reply successfully
self.shortlist.remove(contact)
# log.info("Active probes: %i, contacted %i/%i (%s)", len(self.active_probes),
# len(self.active_contacts), len(self.already_contacted), hex(id(self)))
# run the probes
if probes:
# Schedule the next iteration if there are any active
# calls (Kademlia uses loose parallelism)
self.searchIteration()
self._lock.release()
d = defer.gatherResults(probes)
@defer.inlineCallbacks
def _remove_probes(results):
yield self._lock.acquire()
for probe in probes:
self.active_probes.remove(probe)
self._lock.release()
defer.returnValue(results)
d.addCallback(_remove_probes)
elif not self.finished_deferred.called and not self.active_probes:
# If no probes were sent, there will not be any improvement, so we're done
self.sortByDistance(self.active_contacts)
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
def searchIteration(self):
def _cancel_pending_iterations(result):
while self.pending_iteration_calls:
canceller = self.pending_iteration_calls.pop()
canceller()
return result
self.finished_deferred.addBoth(_cancel_pending_iterations)
self._iteration_count += 1
# log.debug("iteration %i %s(%s...)", self._iteration_count, self.rpc, self.key.encode('hex')[:8])
call, cancel = self.node.reactor_callLater(1, self._search_iteration_semaphore.run, self._searchIteration)
self.pending_iteration_calls.append(cancel)
def iterativeFind(node, shortlist, key, rpc):
helper = _IterativeFind(node, shortlist, key, rpc)
helper.searchIteration()
return helper.finished_deferred

View file

@ -26,6 +26,7 @@ from error import TimeoutError
from peerfinder import DHTPeerFinder
from contact import ContactManager
from distance import Distance
from iterativefind import iterativeFind
log = logging.getLogger(__name__)
@ -599,12 +600,7 @@ class Node(MockKademliaHelper):
# This is used during the bootstrap process
shortlist = startupShortlist
outerDf = defer.Deferred()
helper = _IterativeFindHelper(self, outerDf, shortlist, key, findValue, rpc)
# Start the iterations
helper.searchIteration()
result = yield outerDf
result = yield iterativeFind(self, shortlist, key, rpc)
defer.returnValue(result)
@defer.inlineCallbacks
@ -623,242 +619,3 @@ class Node(MockKademliaHelper):
searchID = nodeIDs.pop()
yield self.iterativeFindNode(searchID)
defer.returnValue(None)
# This was originally a set of nested methods in _iterativeFind
# but they have been moved into this helper class in-order to
# have better scoping and readability
class _IterativeFindHelper(object):
# TODO: use polymorphism to search for a value or node
# instead of using a find_value flag
def __init__(self, node, outer_d, shortlist, key, find_value, rpc):
self.node = node
self.outer_d = outer_d
self.shortlist = shortlist
self.key = key
self.find_value = find_value
self.rpc = rpc
# all distance operations in this class only care about the distance
# to self.key, so this makes it easier to calculate those
self.distance = Distance(key)
# List of active queries; len() indicates number of active probes
#
# n.b: using lists for these variables, because Python doesn't
# allow binding a new value to a name in an enclosing
# (non-global) scope
self.active_probes = []
# List of contact IDs that have already been queried
self.already_contacted = []
# Probes that were active during the previous iteration
# A list of found and known-to-be-active remote nodes
self.active_contacts = []
# This should only contain one entry; the next scheduled iteration call
self.pending_iteration_calls = []
self.prev_closest_node = [None]
self.find_value_result = {}
self.slow_node_count = [0]
def extendShortlist(self, responseTuple):
""" @type responseMsg: kademlia.msgtypes.ResponseMessage """
# The "raw response" tuple contains the response message,
# and the originating address info
responseMsg = responseTuple[0]
originAddress = responseTuple[1] # tuple: (ip adress, udp port)
# Make sure the responding node is valid, and abort the operation if it isn't
if responseMsg.nodeID in self.active_contacts or responseMsg.nodeID == self.node.node_id:
return responseMsg.nodeID
# Mark this node as active
aContact = self._getActiveContact(responseMsg, originAddress)
self.active_contacts.append(aContact)
# This makes sure "bootstrap"-nodes with "fake" IDs don't get queried twice
if responseMsg.nodeID not in self.already_contacted:
self.already_contacted.append(responseMsg.nodeID)
# Now grow extend the (unverified) shortlist with the returned contacts
result = responseMsg.response
# TODO: some validation on the result (for guarding against attacks)
# If we are looking for a value, first see if this result is the value
# we are looking for before treating it as a list of contact triples
if self.find_value is True and self.key in result and not 'contacts' in result:
# We have found the value
self.find_value_result[self.key] = result[self.key]
else:
if self.find_value is True:
self._setClosestNodeValue(responseMsg, aContact)
self._keepSearching(result)
return responseMsg.nodeID
def _getActiveContact(self, responseMsg, originAddress):
if responseMsg.nodeID in self.shortlist:
# Get the contact information from the shortlist...
return self.shortlist[self.shortlist.index(responseMsg.nodeID)]
else:
# If it's not in the shortlist; we probably used a fake ID to reach it
# - reconstruct the contact, using the real node ID this time
return Contact(
responseMsg.nodeID, originAddress[0], originAddress[1], self.node._protocol)
def _keepSearching(self, result):
contactTriples = self._getContactTriples(result)
for contactTriple in contactTriples:
self._addIfValid(contactTriple)
def _getContactTriples(self, result):
if self.find_value is True:
return result['contacts']
else:
return result
def _setClosestNodeValue(self, responseMsg, aContact):
# We are looking for a value, and the remote node didn't have it
# - mark it as the closest "empty" node, if it is
if 'closestNodeNoValue' in self.find_value_result:
if self._is_closer(responseMsg):
self.find_value_result['closestNodeNoValue'] = aContact
else:
self.find_value_result['closestNodeNoValue'] = aContact
def _is_closer(self, responseMsg):
return self.distance.is_closer(responseMsg.nodeID, self.active_contacts[0].id)
def _addIfValid(self, contactTriple):
if isinstance(contactTriple, (list, tuple)) and len(contactTriple) == 3:
testContact = Contact(
contactTriple[0], contactTriple[1], contactTriple[2], self.node._protocol)
if testContact not in self.shortlist:
self.shortlist.append(testContact)
def removeFromShortlist(self, failure, deadContactID):
""" @type failure: twisted.python.failure.Failure """
failure.trap(TimeoutError, defer.CancelledError, TypeError)
if len(deadContactID) != constants.key_bits / 8:
raise ValueError("invalid lbry id")
if deadContactID in self.shortlist:
self.shortlist.remove(deadContactID)
return deadContactID
def cancelActiveProbe(self, contactID):
self.active_probes.pop()
if len(self.active_probes) <= constants.alpha / 2 and len(self.pending_iteration_calls):
# Force the iteration
self.pending_iteration_calls[0].cancel()
del self.pending_iteration_calls[0]
self.searchIteration()
def sortByDistance(self, contact_list):
"""Sort the list of contacts in order by distance from key"""
ExpensiveSort(contact_list, self.distance.to_contact).sort()
# Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts
def searchIteration(self):
self.slow_node_count[0] = len(self.active_probes)
# Sort the discovered active nodes from closest to furthest
self.sortByDistance(self.active_contacts)
# This makes sure a returning probe doesn't force calling this function by mistake
while len(self.pending_iteration_calls):
del self.pending_iteration_calls[0]
# See if should continue the search
if self.key in self.find_value_result:
self.outer_d.callback(self.find_value_result)
return
elif len(self.active_contacts) and self.find_value is False:
if self._is_all_done():
# TODO: Re-send the FIND_NODEs to all of the k closest nodes not already queried
#
# Ok, we're done; either we have accumulated k active
# contacts or no improvement in closestNode has been
# noted
self.outer_d.callback(self.active_contacts)
return
# The search continues...
if len(self.active_contacts):
self.prev_closest_node[0] = self.active_contacts[0]
contactedNow = 0
self.sortByDistance(self.shortlist)
# Store the current shortList length before contacting other nodes
prevShortlistLength = len(self.shortlist)
for contact in self.shortlist:
if contact.id not in self.already_contacted:
self._probeContact(contact)
contactedNow += 1
if contactedNow == constants.alpha:
break
if self._should_lookup_active_calls():
# Schedule the next iteration if there are any active
# calls (Kademlia uses loose parallelism)
call, _ = self.node.reactor_callLater(constants.iterativeLookupDelay, self.searchIteration)
self.pending_iteration_calls.append(call)
# Check for a quick contact response that made an update to the shortList
elif prevShortlistLength < len(self.shortlist):
# Ensure that the closest contacts are taken from the updated shortList
self.searchIteration()
else:
# If no probes were sent, there will not be any improvement, so we're done
self.outer_d.callback(self.active_contacts)
def _probeContact(self, contact):
self.active_probes.append(contact.id)
rpcMethod = getattr(contact, self.rpc)
df = rpcMethod(self.key, rawResponse=True)
df.addCallback(self.extendShortlist)
df.addErrback(self.removeFromShortlist, contact.id)
df.addCallback(self.cancelActiveProbe)
df.addErrback(lambda _: log.exception('Failed to contact %s', contact))
self.already_contacted.append(contact.id)
def _should_lookup_active_calls(self):
return (
len(self.active_probes) > self.slow_node_count[0] or
(
len(self.shortlist) < constants.k and
len(self.active_contacts) < len(self.shortlist) and
len(self.active_probes) > 0
)
)
def _is_all_done(self):
return (
len(self.active_contacts) >= constants.k or
(
self.active_contacts[0] == self.prev_closest_node[0] and
len(self.active_probes) == self.slow_node_count[0]
)
)
class ExpensiveSort(object):
"""Sort a list in place.
The result of `key(item)` is cached for each item in the `to_sort`
list as an optimization. This can be useful when `key` is
expensive.
Attributes:
to_sort: a list of items to sort
key: callable, like `key` in normal python sort
attr: the attribute name used to cache the value on each item.
"""
def __init__(self, to_sort, key, attr='__value'):
self.to_sort = to_sort
self.key = key
self.attr = attr
def sort(self):
self._cacheValues()
self._sortByValue()
self._removeValue()
def _cacheValues(self):
for item in self.to_sort:
setattr(item, self.attr, self.key(item))
def _sortByValue(self):
self.to_sort.sort(key=operator.attrgetter(self.attr))
def _removeValue(self):
for item in self.to_sort:
delattr(item, self.attr)