refactor ping queue

This commit is contained in:
Victor Shyba 2018-09-29 23:53:40 -03:00
parent 4890fdfb50
commit ea6b2b98fb
3 changed files with 26 additions and 49 deletions

View file

@ -644,10 +644,10 @@ class Node(MockKademliaHelper):
defer.returnValue(None) defer.returnValue(None)
def _refreshContacts(self): def _refreshContacts(self):
self._protocol._ping_queue.enqueue_maybe_ping(*self.contacts, delay=0) self._protocol._ping_queue.enqueue_maybe_ping(*self.contacts)
def _refreshStoringPeers(self): def _refreshStoringPeers(self):
self._protocol._ping_queue.enqueue_maybe_ping(*self._dataStore.getStoringContacts(), delay=0) self._protocol._ping_queue.enqueue_maybe_ping(*self._dataStore.getStoringContacts())
@defer.inlineCallbacks @defer.inlineCallbacks
def _refreshRoutingTable(self): def _refreshRoutingTable(self):

View file

@ -1,7 +1,6 @@
import logging import logging
import errno import errno
from binascii import hexlify from binascii import hexlify
from collections import deque
from twisted.internet import protocol, defer from twisted.internet import protocol, defer
from .error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError, TransportNotConnected from .error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError, TransportNotConnected
@ -22,58 +21,36 @@ class PingQueue:
def __init__(self, node): def __init__(self, node):
self._node = node self._node = node
self._get_time = self._node.clock.seconds
self._queue = deque()
self._enqueued_contacts = {} self._enqueued_contacts = {}
self._process_lc = node.get_looping_call(self._process) self._process_lc = node.get_looping_call(self._process)
def enqueue_maybe_ping(self, *contacts, **kwargs): def enqueue_maybe_ping(self, *contacts, **kwargs):
schedule = self._get_time() + (kwargs.get('delay', constants.checkRefreshInterval)) no_op = (defer.succeed(None), lambda: None)
for contact in contacts: for contact in contacts:
if contact not in self._enqueued_contacts: self._enqueued_contacts.setdefault(contact, no_op)
self._enqueued_contacts[contact] = schedule
self._queue.append(contact)
def _process(self):
if not len(self._queue):
return
contact = self._queue.popleft()
now = self._get_time()
# if the oldest contact in the queue isn't old enough to be pinged, add it back to the queue and return
if now < self._enqueued_contacts[contact]:
self._queue.appendleft(contact)
return
pinged = []
checked = []
while now > self._enqueued_contacts[contact]:
checked.append(contact)
if not contact.contact_is_good:
pinged.append(contact)
if not len(self._queue):
break
contact = self._queue.popleft()
if not now > self._enqueued_contacts[contact]:
checked.append(contact)
@defer.inlineCallbacks @defer.inlineCallbacks
def _ping(contact): def _ping(self, contact):
if contact.contact_is_good:
return
try: try:
yield contact.ping() yield contact.ping()
except TimeoutError: except TimeoutError:
pass pass
except Exception as err: except Exception as err:
log.warning("unexpected error: %s", err) log.warning("unexpected error: %s", err)
finally:
d = defer.DeferredList([_ping(contact) for contact in pinged]) if contact in self._enqueued_contacts:
for contact in checked:
if contact in self._enqueued_contacts and contact in pinged:
del self._enqueued_contacts[contact] del self._enqueued_contacts[contact]
elif contact not in self._queue:
self._queue.appendleft(contact) def _process(self):
return d if not self._enqueued_contacts:
return
# spread pings across 60 seconds to avoid flood and/or false negatives
step = 60.0/float(len(self._enqueued_contacts))
for index, (contact, (call, _)) in enumerate(self._enqueued_contacts.items()):
if call.called and not contact.contact_is_good:
self._enqueued_contacts[contact] = self._node.reactor_callLater(index*step, self._ping, contact)
def start(self): def start(self):
return self._node.safe_start_looping_call(self._process_lc, 60) return self._node.safe_start_looping_call(self._process_lc, 60)

View file

@ -47,8 +47,8 @@ class TestKademliaBase(unittest.TestCase):
:param step: reactor tick rate (in seconds) :param step: reactor tick rate (in seconds)
""" """
advanced = 0.0 advanced = 0.0
while advanced < n:
self.clock._sortCalls() self.clock._sortCalls()
while advanced < n:
if step: if step:
next_step = step next_step = step
elif self.clock.getDelayedCalls(): elif self.clock.getDelayedCalls():
@ -120,7 +120,7 @@ class TestKademliaBase(unittest.TestCase):
seed_dl = [] seed_dl = []
seeds = sorted(list(self.seed_dns.keys())) seeds = sorted(list(self.seed_dns.keys()))
known_addresses = [(seed_name, 4444) for seed_name in seeds] known_addresses = [(seed_name, 4444) for seed_name in seeds]
for seed_dns in seeds: for _ in range(len(seeds)):
self._add_next_node() self._add_next_node()
seed = self.nodes.pop() seed = self.nodes.pop()
self._seeds.append(seed) self._seeds.append(seed)