Merge remote-tracking branch 'origin/1439'

This commit is contained in:
Jack Robison 2018-10-08 09:15:46 -04:00
commit 315a557019
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
3 changed files with 39 additions and 65 deletions

View file

@ -640,19 +640,14 @@ class Node(MockKademliaHelper):
replication/republishing as necessary """ replication/republishing as necessary """
yield self._refreshRoutingTable() yield self._refreshRoutingTable()
self._dataStore.removeExpiredPeers() self._dataStore.removeExpiredPeers()
yield self._refreshStoringPeers() self._refreshStoringPeers()
defer.returnValue(None) defer.returnValue(None)
def _refreshContacts(self): def _refreshContacts(self):
return defer.DeferredList( self._protocol._ping_queue.enqueue_maybe_ping(*self.contacts, delay=0)
[self._protocol._ping_queue.enqueue_maybe_ping(contact, delay=0) for contact in self.contacts]
)
def _refreshStoringPeers(self): def _refreshStoringPeers(self):
storing_contacts = self._dataStore.getStoringContacts() self._protocol._ping_queue.enqueue_maybe_ping(*self._dataStore.getStoringContacts(), delay=0)
return defer.DeferredList(
[self._protocol._ping_queue.enqueue_maybe_ping(contact, delay=0) for contact in storing_contacts]
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _refreshRoutingTable(self): def _refreshRoutingTable(self):

View file

@ -1,7 +1,6 @@
import logging import logging
import errno import errno
from binascii import hexlify from binascii import hexlify
from collections import deque
from twisted.internet import protocol, defer from twisted.internet import protocol, defer
from .error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError, TransportNotConnected from .error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError, TransportNotConnected
@ -22,73 +21,53 @@ class PingQueue:
def __init__(self, node): def __init__(self, node):
self._node = node self._node = node
self._get_time = self._node.clock.seconds
self._queue = deque()
self._enqueued_contacts = {} self._enqueued_contacts = {}
self._semaphore = defer.DeferredSemaphore(1) self._pending_contacts = {}
self._ping_semaphore = defer.DeferredSemaphore(constants.alpha) self._process_lc = node.get_looping_call(self._process)
self._process_lc = node.get_looping_call(self._semaphore.run, self._process)
def _add_contact(self, contact, delay=None): def enqueue_maybe_ping(self, *contacts, **kwargs):
if (contact.address, contact.port) in [(c.address, c.port) for c in self._enqueued_contacts]: delay = kwargs.get('delay', constants.checkRefreshInterval)
return defer.succeed(None) no_op = (defer.succeed(None), lambda: None)
delay = delay or constants.checkRefreshInterval for contact in contacts:
self._enqueued_contacts[contact] = self._get_time() + delay if delay and contact not in self._enqueued_contacts:
self._queue.append(contact) self._pending_contacts.setdefault(contact, self._node.clock.seconds() + delay)
return defer.succeed(None) else:
self._enqueued_contacts.setdefault(contact, no_op)
@defer.inlineCallbacks @defer.inlineCallbacks
def _process(self): def _ping(self, contact):
if not len(self._queue): if contact.contact_is_good:
defer.returnValue(None) return
contact = self._queue.popleft() try:
now = self._get_time() yield contact.ping()
except TimeoutError:
# if the oldest contact in the queue isn't old enough to be pinged, add it back to the queue and return pass
if now < self._enqueued_contacts[contact]: except Exception as err:
self._queue.appendleft(contact) log.warning("unexpected error: %s", err)
defer.returnValue(None) finally:
if contact in self._enqueued_contacts:
pinged = []
checked = []
while now > self._enqueued_contacts[contact]:
checked.append(contact)
if not contact.contact_is_good:
pinged.append(contact)
if not len(self._queue):
break
contact = self._queue.popleft()
if not now > self._enqueued_contacts[contact]:
checked.append(contact)
@defer.inlineCallbacks
def _ping(contact):
try:
yield contact.ping()
except TimeoutError:
pass
except Exception as err:
log.warning("unexpected error: %s", err)
yield defer.DeferredList([_ping(contact) for contact in pinged])
for contact in checked:
if contact in self._enqueued_contacts and contact in pinged:
del self._enqueued_contacts[contact] del self._enqueued_contacts[contact]
elif contact not in self._queue:
self._queue.appendleft(contact)
defer.returnValue(None) def _process(self):
# move contacts that are scheduled to join the queue
if self._pending_contacts:
now = self._node.clock.seconds()
for contact in [contact for contact, schedule in self._pending_contacts.items() if schedule <= now]:
del self._pending_contacts[contact]
self._enqueued_contacts.setdefault(contact, (defer.succeed(None), lambda: None))
# spread pings across 60 seconds to avoid flood and/or false negatives
step = 60.0/float(len(self._enqueued_contacts)) if self._enqueued_contacts else 0
for index, (contact, (call, _)) in enumerate(self._enqueued_contacts.items()):
if call.called and not contact.contact_is_good:
self._enqueued_contacts[contact] = self._node.reactor_callLater(index*step, self._ping, contact)
def start(self): def start(self):
return self._node.safe_start_looping_call(self._process_lc, 60) return self._node.safe_start_looping_call(self._process_lc, 60)
def stop(self): def stop(self):
map(None, (cancel() for _, (call, cancel) in self._enqueued_contacts.items() if not call.called))
return self._node.safe_stop_looping_call(self._process_lc) return self._node.safe_stop_looping_call(self._process_lc)
def enqueue_maybe_ping(self, contact, delay=None):
return self._semaphore.run(self._add_contact, contact, delay)
class KademliaProtocol(protocol.DatagramProtocol): class KademliaProtocol(protocol.DatagramProtocol):
""" Implements all low-level network-related functions of a Kademlia node """ """ Implements all low-level network-related functions of a Kademlia node """

View file

@ -47,8 +47,8 @@ class TestKademliaBase(unittest.TestCase):
:param step: reactor tick rate (in seconds) :param step: reactor tick rate (in seconds)
""" """
advanced = 0.0 advanced = 0.0
self.clock._sortCalls()
while advanced < n: while advanced < n:
self.clock._sortCalls()
if step: if step:
next_step = step next_step = step
elif self.clock.getDelayedCalls(): elif self.clock.getDelayedCalls():
@ -120,7 +120,7 @@ class TestKademliaBase(unittest.TestCase):
seed_dl = [] seed_dl = []
seeds = sorted(list(self.seed_dns.keys())) seeds = sorted(list(self.seed_dns.keys()))
known_addresses = [(seed_name, 4444) for seed_name in seeds] known_addresses = [(seed_name, 4444) for seed_name in seeds]
for seed_dns in seeds: for _ in range(len(seeds)):
self._add_next_node() self._add_next_node()
seed = self.nodes.pop() seed = self.nodes.pop()
self._seeds.append(seed) self._seeds.append(seed)