diff --git a/lbrynet/dht/protocol.py b/lbrynet/dht/protocol.py index 52c99475c..0ab200b13 100644 --- a/lbrynet/dht/protocol.py +++ b/lbrynet/dht/protocol.py @@ -1,6 +1,7 @@ import logging import socket import errno +from collections import deque from twisted.internet import protocol, defer from lbrynet.core.call_later_manager import CallLaterManager @@ -14,6 +15,76 @@ import msgformat log = logging.getLogger(__name__) +class PingQueue(object): + """ + Schedules a 15 minute delayed ping after a new node sends us a query. This is so the new node gets added to the + routing table after having been given enough time for a pinhole to expire. + """ + + def __init__(self, node): + self._node = node + self._get_time = self._node.clock.seconds + self._queue = deque() + self._enqueued_contacts = {} + self._semaphore = defer.DeferredSemaphore(1) + self._ping_semaphore = defer.DeferredSemaphore(constants.alpha) + self._process_lc = node.get_looping_call(self._semaphore.run, self._process) + self._delay = 300 + + def _add_contact(self, contact): + if contact in self._enqueued_contacts: + return defer.succeed(None) + self._enqueued_contacts[contact] = self._get_time() + self._delay + self._queue.append(contact) + return defer.succeed(None) + + @defer.inlineCallbacks + def _process(self): + if not len(self._queue): + defer.returnValue(None) + contact = self._queue.popleft() + now = self._get_time() + + # if the oldest contact in the queue isn't old enough to be pinged, add it back to the queue and return + if now < self._enqueued_contacts[contact]: + self._queue.appendleft(contact) + defer.returnValue(None) + + def _ping(contact): + d = contact.ping() + d.addErrback(lambda err: err.trap(TimeoutError)) + return d + + pinged = [] + checked = [] + while now > self._enqueued_contacts[contact]: + checked.append(contact) + if contact.contact_is_good is None: + pinged.append(contact) + if not len(self._queue): + break + contact = self._queue.popleft() + if not now > self._enqueued_contacts[contact]: + checked.append(contact) + # log.info("ping %i/%i peers", len(pinged), len(checked)) + + yield defer.DeferredList([self._ping_semaphore.run(_ping, contact) for contact in pinged]) + + for contact in checked: + if contact in self._enqueued_contacts: + del self._enqueued_contacts[contact] + + defer.returnValue(None) + + def start(self): + return self._node.safe_start_looping_call(self._process_lc, 60) + + def stop(self): + return self._node.safe_stop_looping_call(self._process_lc) + + def enqueue_maybe_ping(self, contact): + return self._semaphore.run(self._add_contact, contact) + class KademliaProtocol(protocol.DatagramProtocol): """ Implements all low-level network-related functions of a Kademlia node """ @@ -27,6 +98,7 @@ class KademliaProtocol(protocol.DatagramProtocol): self._partialMessages = {} self._partialMessagesProgress = {} self._listening = defer.Deferred(None) + self._ping_queue = PingQueue(self._node) def sendRPC(self, contact, method, args, rawResponse=False): """ @@ -100,6 +172,7 @@ class KademliaProtocol(protocol.DatagramProtocol): def startProtocol(self): log.info("DHT listening on UDP %s:%i", self._node.externalIP, self._node.port) self._listening.callback(True) + return self._ping_queue.start() def datagramReceived(self, datagram, address): """ Handles and parses incoming RPC messages (and responses) @@ -386,5 +459,6 @@ class KademliaProtocol(protocol.DatagramProtocol): Will only be called once, after all ports are disconnected. """ log.info('Stopping DHT') + self._ping_queue.stop() CallLaterManager.stop() log.info('DHT stopped')