forked from LBRYCommunity/lbry-sdk
refactor iterativeFind, move to own file
This commit is contained in:
parent
e5703833cf
commit
f1e0a784d9
2 changed files with 228 additions and 245 deletions
226
lbrynet/dht/iterativefind.py
Normal file
226
lbrynet/dht/iterativefind.py
Normal file
|
@ -0,0 +1,226 @@
|
|||
import logging
|
||||
from twisted.internet import defer
|
||||
from distance import Distance
|
||||
from error import TimeoutError
|
||||
import constants
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_contact(contact_list, node_id, address, port):
    """Return the first contact in contact_list matching id, address and port.

    :param contact_list: iterable of objects exposing ``id``, ``address`` and ``port``
    :param node_id: node id the contact must have
    :param address: address the contact must have
    :param port: port the contact must have
    :raises IndexError: (carrying ``node_id``) when no contact matches all three
    """
    not_found = object()
    match = next(
        (
            candidate for candidate in contact_list
            if candidate.id == node_id and candidate.address == address and candidate.port == port
        ),
        not_found
    )
    if match is not_found:
        raise IndexError(node_id)
    return match
|
||||
|
||||
|
||||
class _IterativeFind(object):
    """Drive one iterative lookup (``findNode`` or ``findValue``) towards ``key``.

    Probes are sent to un-contacted shortlist entries, at most
    ``constants.alpha`` per iteration, and every reply may extend the shortlist
    with further contacts.  ``finished_deferred`` fires exactly once:

    * with ``find_value_result`` (a dict) when a findValue reply contains the key;
    * otherwise with the (at most ``constants.k``) active contacts, sorted by
      distance to the key, once ``should_stop()`` is true or nothing is left
      to probe.

    ``self._lock`` keeps replies from mutating the shortlist while an iteration
    is walking it.  It is always released in a ``finally`` block: a rejected
    reply (ValueError) must not leave it held, or every later reply and
    iteration would wait on it forever.
    """
    # TODO: use polymorphism to search for a value or node
    #       instead of using a find_value flag

    def __init__(self, node, shortlist, key, rpc):
        """
        :param node: local node; this class uses its ``contact_manager``,
            ``node_id``, ``_protocol`` and ``reactor_callLater``
        :param shortlist: list of contact objects to start from (mutated in place)
        :param key: the search key
        :param rpc: name of the contact method to call ("findNode" or "findValue")
        """
        self.node = node
        self.finished_deferred = defer.Deferred()
        # all distance operations in this class only care about the distance
        # to self.key, so this makes it easier to calculate those
        self.distance = Distance(key)
        # The closest known and active node yet found
        self.closest_node = None if not shortlist else shortlist[0]
        self.prev_closest_node = None
        # Shortlist of contact objects (the k closest known contacts to the key from the routing table)
        self.shortlist = shortlist
        # The search key
        self.key = str(key)
        # The rpc method name (findValue or findNode)
        self.rpc = rpc
        # List of active queries; len() indicates number of active probes
        self.active_probes = []
        # List of contacts that have already been queried, includes contacts that didn't reply
        self.already_contacted = []
        # A list of found and known-to-be-active remote nodes (Contact objects)
        self.active_contacts = []
        # Ensure only one searchIteration call is running at a time
        self._search_iteration_semaphore = defer.DeferredSemaphore(1)
        self._iteration_count = 0
        self.find_value_result = {}
        # Cancel functions of the scheduled iteration calls
        self.pending_iteration_calls = []
        self._lock = defer.DeferredLock()
        # Registered once here (instead of once per iteration) so the callback
        # chain of finished_deferred does not grow with every iteration.
        self.finished_deferred.addBoth(self._cancel_pending_iterations)

    @property
    def is_find_node_request(self):
        return self.rpc == "findNode"

    @property
    def is_find_value_request(self):
        return self.rpc == "findValue"

    def is_closer(self, responseMsg):
        """True if the responder is closer to the key than the current closest node.

        Also true when no closest node is known yet.
        """
        if not self.closest_node:
            return True
        return self.distance.is_closer(responseMsg.nodeID, self.closest_node.id)

    def getContactTriples(self, result):
        """Extract and validate the contact triples carried by a reply.

        For findValue the triples are read from ``result['contacts']``; for
        findNode ``result`` itself is the sequence of triples.

        :raises ValueError: if a findValue reply has no usable 'contacts' entry
            or any entry is not a 3-item list/tuple
        """
        if self.is_find_value_request:
            try:
                contact_triples = result['contacts']
            except (KeyError, TypeError):
                # A malformed reply is rejected like any other invalid reply
                # (probeContact handles ValueError) instead of escaping as an
                # unexpected error.
                raise ValueError("findValue response has no contacts")
        else:
            contact_triples = result
        for contact_tup in contact_triples:
            if not isinstance(contact_tup, (list, tuple)) or len(contact_tup) != 3:
                raise ValueError("invalid contact triple")
        return contact_triples

    def sortByDistance(self, contact_list):
        """Sort the list of contacts in order by distance from key"""
        contact_list.sort(key=lambda c: self.distance(c.id))

    def _process_response(self, contact, responseMsg):
        """Record a reply from ``contact``; must be called with ``self._lock`` held.

        Marks the contact active, stores the value if the reply carries it and
        otherwise adds the returned, not yet contacted, contacts to the shortlist.

        :return: True if the searched value was found, False otherwise
        :raises ValueError: if the reply is malformed or lists an ignored contact
        """
        if contact not in self.active_contacts:
            self.active_contacts.append(contact)
        if contact not in self.shortlist:
            self.shortlist.append(contact)

        # Now extend the (unverified) shortlist with the returned contacts
        result = responseMsg.response
        # TODO: some validation on the result (for guarding against attacks)
        # If we are looking for a value, first see if this result is the value
        # we are looking for before treating it as a list of contact triples
        if self.is_find_value_request and self.key in result:
            # We have found the value
            self.find_value_result[self.key] = result[self.key]
            return True

        if self.is_find_value_request:
            # We are looking for a value, and the remote node didn't have it
            # - mark it as the closest "empty" node, if it is
            # TODO: store to this peer after finding the value as per the kademlia spec
            if 'closestNodeNoValue' not in self.find_value_result or self.is_closer(responseMsg):
                self.find_value_result['closestNodeNoValue'] = contact

        # Loop invariant: computed once instead of once per returned triple
        contacted_addresses = [(c.address, c.port) for c in self.already_contacted]
        for contactTriple in self.getContactTriples(result):
            triple_address = (contactTriple[1], contactTriple[2])
            if triple_address in contacted_addresses:
                continue
            if self.node.contact_manager.is_ignored(triple_address):
                raise ValueError("contact is ignored")
            found_contact = self.node.contact_manager.make_contact(
                contactTriple[0], contactTriple[1], contactTriple[2], self.node._protocol
            )
            if found_contact not in self.shortlist:
                self.shortlist.append(found_contact)
        return False

    @defer.inlineCallbacks
    def extendShortlist(self, contact, responseTuple):
        """Handle the raw reply of a probe sent to ``contact``.

        :param contact: the contact that was probed
        :param responseTuple: (response message, (ip address, udp port))
        :return: deferred firing with the node id of the responder
        :raises ValueError: if the responder or a returned contact is ignored,
            or the reply is malformed
        """
        # The "raw response" tuple contains the response message and the originating address info
        responseMsg = responseTuple[0]
        originAddress = responseTuple[1]  # tuple: (ip address, udp port)
        if self.finished_deferred.called:
            defer.returnValue(responseMsg.nodeID)
        if self.node.contact_manager.is_ignored(originAddress):
            raise ValueError("contact is ignored")
        if responseMsg.nodeID == self.node.node_id:
            defer.returnValue(responseMsg.nodeID)

        yield self._lock.acquire()
        try:
            found_value = self._process_response(contact, responseMsg)
        finally:
            # Released on every path, including a rejected reply.
            self._lock.release()

        # Another reply may have finished the lookup while this one was waiting
        # for the lock (or when the lock was handed over on release), so
        # re-check before firing to avoid an AlreadyCalledError.
        if not self.finished_deferred.called:
            if found_value:
                self.finished_deferred.callback(self.find_value_result)
            elif self.should_stop():
                self.sortByDistance(self.active_contacts)
                self.finished_deferred.callback(self.active_contacts[:constants.k])

        defer.returnValue(responseMsg.nodeID)

    @defer.inlineCallbacks
    def probeContact(self, contact):
        """Send the rpc to ``contact`` and feed its reply to extendShortlist.

        Timeouts, cancellations and rejected replies are not errors for the
        lookup: the deferred then fires with ``contact.id``.
        """
        fn = getattr(contact, self.rpc)
        try:
            response_tuple = yield fn(self.key, rawResponse=True)
            result = yield self.extendShortlist(contact, response_tuple)
            defer.returnValue(result)
        except (TimeoutError, defer.CancelledError, ValueError, IndexError):
            defer.returnValue(contact.id)

    def should_stop(self):
        """True once k contacts are active or the closest node stopped improving."""
        active_contacts_len = len(self.active_contacts)
        if active_contacts_len >= constants.k:
            # log.info("there are enough results %s(%s)", self.rpc, self.key.encode('hex'))
            return True
        if self.prev_closest_node and self.closest_node and self.distance.is_closer(
                self.prev_closest_node.id, self.closest_node.id):
            # log.info("not getting any closer %s(%s)", self.rpc, self.key.encode('hex'))
            return True
        return False

    def _probe_next_contacts(self):
        """Start probes for the closest un-contacted shortlist entries.

        Must be called with ``self._lock`` held.  Updates the closest-node
        bookkeeping, probes at most ``constants.alpha`` contacts and removes the
        probed contacts from the shortlist.

        :return: list of the probe deferreds started by this call
        """
        # Sort the discovered active nodes from closest to furthest
        if self.active_contacts:
            self.sortByDistance(self.active_contacts)
            self.prev_closest_node = self.closest_node
            self.closest_node = self.active_contacts[0]

        self.sortByDistance(self.shortlist)
        probes = []
        already_contacted_addresses = {(c.address, c.port) for c in self.already_contacted}
        to_remove = []
        for contact in self.shortlist:
            if (contact.address, contact.port) not in already_contacted_addresses:
                self.already_contacted.append(contact)
                to_remove.append(contact)
                probe = self.probeContact(contact)
                probes.append(probe)
                self.active_probes.append(probe)
            if len(probes) == constants.alpha:
                break
        for contact in to_remove:  # these contacts will be re-added to the shortlist when they reply successfully
            self.shortlist.remove(contact)
        return probes

    @defer.inlineCallbacks
    def _remove_probes(self, results, probes):
        """Drop a finished batch of probes from ``active_probes``; passes ``results`` through."""
        yield self._lock.acquire()
        try:
            for probe in probes:
                if probe in self.active_probes:
                    self.active_probes.remove(probe)
        finally:
            self._lock.release()
        defer.returnValue(results)

    # Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts
    @defer.inlineCallbacks
    def _searchIteration(self):
        """Run one iteration: probe new contacts, then reschedule or finish."""
        if self.finished_deferred.called:
            # The result was already delivered; do not send further probes.
            return

        yield self._lock.acquire()
        try:
            probes = self._probe_next_contacts()
        finally:
            # Released on every path; previously only the "probes were sent"
            # branch released it, blocking all later replies.
            self._lock.release()

        # log.info("Active probes: %i, contacted %i/%i (%s)", len(self.active_probes),
        #          len(self.active_contacts), len(self.already_contacted), hex(id(self)))

        if probes:
            # Schedule the next iteration if there are any active
            # calls (Kademlia uses loose parallelism)
            self.searchIteration()
            d = defer.gatherResults(probes)
            # addBoth: the batch is also removed from active_probes when a probe
            # fails unexpectedly; the failure itself is passed through unchanged.
            d.addBoth(self._remove_probes, probes)
        elif not self.finished_deferred.called:
            if self.active_probes:
                # Nothing new to probe, but earlier probes are still in flight
                # and may extend the shortlist: look again later instead of
                # leaving the lookup without a scheduled iteration.
                self.searchIteration()
            else:
                # If no probes were sent, there will not be any improvement, so we're done
                self.sortByDistance(self.active_contacts)
                self.finished_deferred.callback(self.active_contacts[:constants.k])

    def _cancel_pending_iterations(self, result):
        """Cancel every scheduled iteration call; passes ``result`` through."""
        while self.pending_iteration_calls:
            canceller = self.pending_iteration_calls.pop()
            canceller()
        return result

    def searchIteration(self):
        """Schedule the next iteration one second from now (no-op once finished)."""
        if self.finished_deferred.called:
            return
        self._iteration_count += 1
        # log.debug("iteration %i %s(%s...)", self._iteration_count, self.rpc, self.key.encode('hex')[:8])
        _, cancel = self.node.reactor_callLater(1, self._search_iteration_semaphore.run, self._searchIteration)
        self.pending_iteration_calls.append(cancel)
|
||||
|
||||
|
||||
def iterativeFind(node, shortlist, key, rpc):
    """Start an iterative lookup and return the deferred carrying its result.

    Builds an ``_IterativeFind`` for the given arguments, schedules its first
    search iteration and hands back its ``finished_deferred``.

    :param node: the local node performing the lookup
    :param shortlist: initial list of contacts to query
    :param key: the search key
    :param rpc: rpc method name ("findNode" or "findValue")
    """
    finder = _IterativeFind(node, shortlist, key, rpc)
    finder.searchIteration()
    result_deferred = finder.finished_deferred
    return result_deferred
|
|
@ -26,6 +26,7 @@ from error import TimeoutError
|
|||
from peerfinder import DHTPeerFinder
|
||||
from contact import ContactManager
|
||||
from distance import Distance
|
||||
from iterativefind import iterativeFind
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -599,12 +600,7 @@ class Node(MockKademliaHelper):
|
|||
# This is used during the bootstrap process
|
||||
shortlist = startupShortlist
|
||||
|
||||
outerDf = defer.Deferred()
|
||||
|
||||
helper = _IterativeFindHelper(self, outerDf, shortlist, key, findValue, rpc)
|
||||
# Start the iterations
|
||||
helper.searchIteration()
|
||||
result = yield outerDf
|
||||
result = yield iterativeFind(self, shortlist, key, rpc)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -623,242 +619,3 @@ class Node(MockKademliaHelper):
|
|||
searchID = nodeIDs.pop()
|
||||
yield self.iterativeFindNode(searchID)
|
||||
defer.returnValue(None)
|
||||
|
||||
|
||||
# This was originally a set of nested methods in _iterativeFind
|
||||
# but they have been moved into this helper class in-order to
|
||||
# have better scoping and readability
|
||||
class _IterativeFindHelper(object):
    """Callback-based helper that performs one iterative find for ``key``.

    The result is delivered through ``outer_d``: either ``find_value_result``
    (when the key was found in a reply) or the list of active contacts.
    """
    # TODO: use polymorphism to search for a value or node
    #       instead of using a find_value flag
    def __init__(self, node, outer_d, shortlist, key, find_value, rpc):
        """
        :param node: local node; its ``node_id``, ``_protocol`` and
            ``reactor_callLater`` are used by this class
        :param outer_d: deferred that is called back with the final result
        :param shortlist: list of contacts to start from (extended in place)
        :param key: the search key
        :param find_value: True when searching for a value, False for a node
        :param rpc: name of the contact method used to probe a contact
        """
        self.node = node
        self.outer_d = outer_d
        self.shortlist = shortlist
        self.key = key
        self.find_value = find_value
        self.rpc = rpc
        # all distance operations in this class only care about the distance
        # to self.key, so this makes it easier to calculate those
        self.distance = Distance(key)
        # List of active queries; len() indicates number of active probes
        #
        # n.b: using lists for these variables, because Python doesn't
        #      allow binding a new value to a name in an enclosing
        #      (non-global) scope
        self.active_probes = []
        # List of contact IDs that have already been queried
        self.already_contacted = []
        # Probes that were active during the previous iteration
        # A list of found and known-to-be-active remote nodes
        self.active_contacts = []
        # This should only contain one entry; the next scheduled iteration call
        self.pending_iteration_calls = []
        # Single-item list holding the closest active contact of the previous iteration
        self.prev_closest_node = [None]
        self.find_value_result = {}
        # Single-item list holding len(active_probes) at the start of the current iteration
        self.slow_node_count = [0]

    def extendShortlist(self, responseTuple):
        """Process the raw reply of a probe.

        :param responseTuple: (response message, (ip address, udp port))
        :return: the node id of the responder
        """
        # The "raw response" tuple contains the response message,
        # and the originating address info
        responseMsg = responseTuple[0]
        originAddress = responseTuple[1]  # tuple: (ip address, udp port)
        # Make sure the responding node is valid, and abort the operation if it isn't
        # NOTE(review): the membership test compares a node id against Contact
        # objects; presumably Contact equality accepts ids - confirm.
        if responseMsg.nodeID in self.active_contacts or responseMsg.nodeID == self.node.node_id:
            return responseMsg.nodeID

        # Mark this node as active
        aContact = self._getActiveContact(responseMsg, originAddress)
        self.active_contacts.append(aContact)

        # This makes sure "bootstrap"-nodes with "fake" IDs don't get queried twice
        if responseMsg.nodeID not in self.already_contacted:
            self.already_contacted.append(responseMsg.nodeID)

        # Now extend the (unverified) shortlist with the returned contacts
        result = responseMsg.response
        # TODO: some validation on the result (for guarding against attacks)
        # If we are looking for a value, first see if this result is the value
        # we are looking for before treating it as a list of contact triples
        if self.find_value is True and self.key in result and not 'contacts' in result:
            # We have found the value
            self.find_value_result[self.key] = result[self.key]
        else:
            if self.find_value is True:
                self._setClosestNodeValue(responseMsg, aContact)
            self._keepSearching(result)
        return responseMsg.nodeID

    def _getActiveContact(self, responseMsg, originAddress):
        """Return the contact for the responder, from the shortlist or newly built."""
        if responseMsg.nodeID in self.shortlist:
            # Get the contact information from the shortlist...
            return self.shortlist[self.shortlist.index(responseMsg.nodeID)]
        else:
            # If it's not in the shortlist; we probably used a fake ID to reach it
            # - reconstruct the contact, using the real node ID this time
            # NOTE(review): no import of Contact is visible in this view - confirm it is in scope.
            return Contact(
                responseMsg.nodeID, originAddress[0], originAddress[1], self.node._protocol)

    def _keepSearching(self, result):
        """Add every valid contact triple found in ``result`` to the shortlist."""
        contactTriples = self._getContactTriples(result)
        for contactTriple in contactTriples:
            self._addIfValid(contactTriple)

    def _getContactTriples(self, result):
        """Return ``result['contacts']`` for a value search, else ``result`` itself."""
        if self.find_value is True:
            return result['contacts']
        else:
            return result

    def _setClosestNodeValue(self, responseMsg, aContact):
        """Record ``aContact`` as 'closestNodeNoValue' if none is set or it is closer."""
        # We are looking for a value, and the remote node didn't have it
        # - mark it as the closest "empty" node, if it is
        if 'closestNodeNoValue' in self.find_value_result:
            if self._is_closer(responseMsg):
                self.find_value_result['closestNodeNoValue'] = aContact
        else:
            self.find_value_result['closestNodeNoValue'] = aContact

    def _is_closer(self, responseMsg):
        """Compare the responder with ``active_contacts[0]`` via ``distance.is_closer``."""
        return self.distance.is_closer(responseMsg.nodeID, self.active_contacts[0].id)

    def _addIfValid(self, contactTriple):
        """Append a contact built from a 3-item list/tuple unless already in the shortlist.

        Anything that is not a 3-item list or tuple is silently skipped.
        """
        if isinstance(contactTriple, (list, tuple)) and len(contactTriple) == 3:
            testContact = Contact(
                contactTriple[0], contactTriple[1], contactTriple[2], self.node._protocol)
            if testContact not in self.shortlist:
                self.shortlist.append(testContact)

    def removeFromShortlist(self, failure, deadContactID):
        """Errback: drop a contact whose probe failed from the shortlist.

        :param failure: twisted.python.failure.Failure; only TimeoutError,
            CancelledError and TypeError are trapped
        :param deadContactID: id of the contact that was probed
        :return: ``deadContactID``
        :raises ValueError: if the id does not have ``constants.key_bits / 8`` items
        """
        failure.trap(TimeoutError, defer.CancelledError, TypeError)
        if len(deadContactID) != constants.key_bits / 8:
            raise ValueError("invalid lbry id")
        if deadContactID in self.shortlist:
            self.shortlist.remove(deadContactID)
        return deadContactID

    def cancelActiveProbe(self, contactID):
        """Callback run when a probe completes: account for it and maybe iterate early.

        ``contactID`` is not used; the last entry of ``active_probes`` is removed,
        which only keeps the count of active probes correct.
        """
        self.active_probes.pop()
        if len(self.active_probes) <= constants.alpha / 2 and len(self.pending_iteration_calls):
            # Force the iteration
            self.pending_iteration_calls[0].cancel()
            del self.pending_iteration_calls[0]
            self.searchIteration()

    def sortByDistance(self, contact_list):
        """Sort the list of contacts in order by distance from key"""
        ExpensiveSort(contact_list, self.distance.to_contact).sort()

    # Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts
    def searchIteration(self):
        """Run one iteration: finish via ``outer_d`` or probe up to alpha new contacts."""
        self.slow_node_count[0] = len(self.active_probes)
        # Sort the discovered active nodes from closest to furthest
        self.sortByDistance(self.active_contacts)
        # This makes sure a returning probe doesn't force calling this function by mistake
        while len(self.pending_iteration_calls):
            del self.pending_iteration_calls[0]
        # See if should continue the search
        if self.key in self.find_value_result:
            self.outer_d.callback(self.find_value_result)
            return
        elif len(self.active_contacts) and self.find_value is False:
            if self._is_all_done():
                # TODO: Re-send the FIND_NODEs to all of the k closest nodes not already queried
                #
                # Ok, we're done; either we have accumulated k active
                # contacts or no improvement in closestNode has been
                # noted
                self.outer_d.callback(self.active_contacts)
                return

        # The search continues...
        if len(self.active_contacts):
            self.prev_closest_node[0] = self.active_contacts[0]
        contactedNow = 0
        self.sortByDistance(self.shortlist)
        # Store the current shortList length before contacting other nodes
        prevShortlistLength = len(self.shortlist)
        for contact in self.shortlist:
            if contact.id not in self.already_contacted:
                self._probeContact(contact)
                contactedNow += 1
            if contactedNow == constants.alpha:
                break
        if self._should_lookup_active_calls():
            # Schedule the next iteration if there are any active
            # calls (Kademlia uses loose parallelism)
            call, _ = self.node.reactor_callLater(constants.iterativeLookupDelay, self.searchIteration)
            self.pending_iteration_calls.append(call)
        # Check for a quick contact response that made an update to the shortList
        elif prevShortlistLength < len(self.shortlist):
            # Ensure that the closest contacts are taken from the updated shortList
            self.searchIteration()
        else:
            # If no probes were sent, there will not be any improvement, so we're done
            self.outer_d.callback(self.active_contacts)

    def _probeContact(self, contact):
        """Send the rpc to ``contact`` and wire up the reply handling chain."""
        self.active_probes.append(contact.id)
        rpcMethod = getattr(contact, self.rpc)
        df = rpcMethod(self.key, rawResponse=True)
        # success -> extendShortlist; trapped failure -> removeFromShortlist;
        # either outcome then reaches cancelActiveProbe; anything still failing is logged
        df.addCallback(self.extendShortlist)
        df.addErrback(self.removeFromShortlist, contact.id)
        df.addCallback(self.cancelActiveProbe)
        df.addErrback(lambda _: log.exception('Failed to contact %s', contact))
        self.already_contacted.append(contact.id)

    def _should_lookup_active_calls(self):
        """True if new probes were started this iteration, or probes are still
        active while the shortlist holds fewer than k contacts and not all of
        them are active yet."""
        return (
            len(self.active_probes) > self.slow_node_count[0] or
            (
                len(self.shortlist) < constants.k and
                len(self.active_contacts) < len(self.shortlist) and
                len(self.active_probes) > 0
            )
        )

    def _is_all_done(self):
        """True if k contacts are active, or the closest contact is unchanged and
        no probes were started since the iteration began.

        Reads ``active_contacts[0]``; the caller checks the list is non-empty first.
        """
        return (
            len(self.active_contacts) >= constants.k or
            (
                self.active_contacts[0] == self.prev_closest_node[0] and
                len(self.active_probes) == self.slow_node_count[0]
            )
        )
|
||||
|
||||
|
||||
class ExpensiveSort(object):
    """Sort a list in place, evaluating `key` once per item.

    The sort is delegated to ``list.sort(key=...)``, which already computes
    the key of every item exactly once, so an expensive `key` is not called
    repeatedly.  Unlike caching the key as an attribute on each item, this
    works for items that cannot take new attributes (ints, tuples, objects
    with ``__slots__``) and leaves the items untouched even if `key` raises.

    Attributes:
        to_sort: a list of items to sort
        key: callable, like `key` in normal python sort
        attr: unused; kept so existing callers passing it keep working
    """

    def __init__(self, to_sort, key, attr='__value'):
        self.to_sort = to_sort
        self.key = key
        self.attr = attr

    def sort(self):
        """Sort ``to_sort`` in place by ``key`` (stable, ascending). Returns None."""
        self.to_sort.sort(key=self.key)
|
||||
|
|
Loading…
Reference in a new issue