Add distance optimization

This commit is contained in:
Job Evers-Meltzer 2016-12-13 20:53:24 -06:00
parent 740fad5cbe
commit 0084d4684f

View file

@ -6,10 +6,16 @@
# #
# The docstrings in this module contain epytext markup; API documentation # The docstrings in this module contain epytext markup; API documentation
# may be created by processing this file with epydoc: http://epydoc.sf.net # may be created by processing this file with epydoc: http://epydoc.sf.net
import hashlib, random, struct, time, binascii
import argparse import argparse
import binascii
import hashlib
import operator
import random
import struct
import time
from twisted.internet import defer, error from twisted.internet import defer, error
import constants import constants
import routingtable import routingtable
import datastore import datastore
@ -645,6 +651,9 @@ class _IterativeFindHelper(object):
self.key = key self.key = key
self.find_value = find_value self.find_value = find_value
self.rpc = rpc self.rpc = rpc
# all distance operations in this class only care about the distance
# to self.key, so this makes it easier to calculate those
self.distance = Distance(key)
# List of active queries; len() indicates number of active probes # List of active queries; len() indicates number of active probes
# #
# n.b: using lists for these variables, because Python doesn't # n.b: using lists for these variables, because Python doesn't
@ -725,10 +734,7 @@ class _IterativeFindHelper(object):
self.find_value_result['closestNodeNoValue'] = aContact self.find_value_result['closestNodeNoValue'] = aContact
def _is_closer(self, responseMsg): def _is_closer(self, responseMsg):
return ( return self.distance.is_closer(responseMsg.nodeID, self.active_contacts[0].id)
self.node._routingTable.distance(self.key, responseMsg.nodeID) <
self.node._routingTable.distance(self.key, self.active_contacts[0].id)
)
def _addIfValid(self, contactTriple): def _addIfValid(self, contactTriple):
if isinstance(contactTriple, (list, tuple)) and len(contactTriple) == 3: if isinstance(contactTriple, (list, tuple)) and len(contactTriple) == 3:
@ -753,17 +759,15 @@ class _IterativeFindHelper(object):
del self.pending_iteration_calls[0] del self.pending_iteration_calls[0]
self.searchIteration() self.searchIteration()
def sortByDistance(self, contact_list):
"""Sort the list of contacts in order by distance from key"""
ExpensiveSort(contact_list, self.distance.to_contact).sort()
# Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts # Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts
def searchIteration(self): def searchIteration(self):
self.slow_node_count[0] = len(self.active_probes) self.slow_node_count[0] = len(self.active_probes)
# TODO: move sort_key to be a method on the class
def sort_key(firstContact, secondContact, targetKey=self.key):
return cmp(
self.node._routingTable.distance(firstContact.id, targetKey),
self.node._routingTable.distance(secondContact.id, targetKey)
)
# Sort the discovered active nodes from closest to furthest # Sort the discovered active nodes from closest to furthest
self.active_contacts.sort(sort_key) self.sortByDistance(self.active_contacts)
# This makes sure a returning probe doesn't force calling this function by mistake # This makes sure a returning probe doesn't force calling this function by mistake
while len(self.pending_iteration_calls): while len(self.pending_iteration_calls):
del self.pending_iteration_calls[0] del self.pending_iteration_calls[0]
@ -784,7 +788,7 @@ class _IterativeFindHelper(object):
if len(self.active_contacts): if len(self.active_contacts):
self.prev_closest_node[0] = self.active_contacts[0] self.prev_closest_node[0] = self.active_contacts[0]
contactedNow = 0 contactedNow = 0
self.shortlist.sort(sort_key) self.sortByDistance(self.shortlist)
# Store the current shortList length before contacting other nodes # Store the current shortList length before contacting other nodes
prevShortlistLength = len(self.shortlist) prevShortlistLength = len(self.shortlist)
for contact in self.shortlist: for contact in self.shortlist:
@ -837,6 +841,62 @@ class _IterativeFindHelper(object):
) )
class Distance(object):
"""Calculate the XOR result between two string variables.
Frequently we re-use one of the points so as an optimization
we pre-calculate the long value of that point.
"""
def __init__(self, key):
self.key = key
self.val_key_one = long(key.encode('hex'), 16)
def __call__(self, key_two):
val_key_two = long(key_two.encode('hex'), 16)
return self.val_key_one ^ val_key_two
def is_closer(self, a, b):
"""Returns true is `a` is closer to `key` than `b` is"""
return self(a) < self(b)
def to_contact(self, contact):
"""A convenience function for calculating the distance to a contact"""
return self(contact.id)
class ExpensiveSort(object):
"""Sort a list in place.
The result of `key(item)` is cached for each item in the `to_sort`
list as an optimization. This can be useful when `key` is
expensive.
Attributes:
to_sort: a list of items to sort
key: callable, like `key` in normal python sort
attr: the attribute name used to cache the value on each item.
"""
def __init__(self, to_sort, key, attr='__value'):
self.to_sort = to_sort
self.key = key
self.attr = attr
def sort(self):
self._cacheValues()
self._sortByValue()
self._removeValue()
def _cacheValues(self):
for item in self.to_sort:
setattr(item, self.attr, self.key(item))
def _sortByValue(self):
self.to_sort.sort(key=operator.attrgetter(self.attr))
def _removeValue(self):
for item in self.to_sort:
delattr(item, self.attr)
def main(): def main():
parser = argparse.ArgumentParser(description="Launch a dht node") parser = argparse.ArgumentParser(description="Launch a dht node")