From 6666468640a9bb2798c86490a3ff3294feed2191 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Tue, 20 Feb 2018 13:30:56 -0500 Subject: [PATCH] add reactor arguments to Node -adds reactor (clock) and reactor functions listenUDP, callLater, and resolve as arguments to Node.__init__ -set the reactor clock on LoopingCalls to make them easily testable -convert callLater manage loops to LoopingCalls --- lbrynet/dht/hashannouncer.py | 13 +++++------ lbrynet/dht/hashwatcher.py | 20 ++++++++--------- lbrynet/dht/node.py | 42 +++++++++++++++++++++++++----------- lbrynet/dht/peerfinder.py | 4 ++-- lbrynet/dht/protocol.py | 9 ++++---- lbrynet/dht/routingtable.py | 15 +++++++------ 6 files changed, 62 insertions(+), 41 deletions(-) diff --git a/lbrynet/dht/hashannouncer.py b/lbrynet/dht/hashannouncer.py index baf7a0667..5338a9e7c 100644 --- a/lbrynet/dht/hashannouncer.py +++ b/lbrynet/dht/hashannouncer.py @@ -45,8 +45,9 @@ class DHTHashAnnouncer(DummyHashAnnouncer): self.hash_queue = collections.deque() self._concurrent_announcers = 0 self._manage_call_lc = task.LoopingCall(self.manage_lc) + self._manage_call_lc.clock = dht_node.clock self._lock = utils.DeferredLockContextManager(defer.DeferredLock()) - self._last_checked = time.time(), self.CONCURRENT_ANNOUNCERS + self._last_checked = dht_node.clock.seconds(), self.CONCURRENT_ANNOUNCERS self._total = None def run_manage_loop(self): @@ -58,7 +59,7 @@ class DHTHashAnnouncer(DummyHashAnnouncer): last_time, last_hashes = self._last_checked hashes = len(self.hash_queue) if hashes: - t, h = time.time() - last_time, last_hashes - hashes + t, h = self.dht_node.clock.seconds() - last_time, last_hashes - hashes blobs_per_second = float(h) / float(t) if blobs_per_second > 0: estimated_time_remaining = int(float(hashes) / blobs_per_second) @@ -108,7 +109,7 @@ class DHTHashAnnouncer(DummyHashAnnouncer): defer.returnValue(None) log.info('Announcing %s hashes', len(hashes)) # TODO: add a timeit decorator - start = time.time() + start = self.dht_node.clock.seconds() ds = [] with self._lock: @@ -174,9 +175,9 @@ class DHTHashAnnouncer(DummyHashAnnouncer): for _, announced_to in announcer_results: stored_to.update(announced_to) - log.info('Took %s seconds to announce %s hashes', time.time() - start, len(hashes)) - seconds_per_blob = (time.time() - start) / len(hashes) - self.supplier.set_single_hash_announce_duration(seconds_per_blob) + log.info('Took %s seconds to announce %s hashes', self.dht_node.clock.seconds() - start, len(hashes)) + seconds_per_blob = (self.dht_node.clock.seconds() - start) / len(hashes) + self.set_single_hash_announce_duration(seconds_per_blob) defer.returnValue(stored_to) diff --git a/lbrynet/dht/hashwatcher.py b/lbrynet/dht/hashwatcher.py index 3f9699de2..80aa30b6a 100644 --- a/lbrynet/dht/hashwatcher.py +++ b/lbrynet/dht/hashwatcher.py @@ -1,24 +1,22 @@ from collections import Counter import datetime +from twisted.internet import task, threads class HashWatcher(object): - def __init__(self): + def __init__(self, clock=None): + if not clock: + from twisted.internet import reactor as clock self.ttl = 600 self.hashes = [] - self.next_tick = None + self.lc = task.LoopingCall(self._remove_old_hashes) + self.lc.clock = clock - def tick(self): - - from twisted.internet import reactor - - self._remove_old_hashes() - self.next_tick = reactor.callLater(10, self.tick) + def start(self): + return self.lc.start(10) def stop(self): - if self.next_tick is not None: - self.next_tick.cancel() - self.next_tick = None + return self.lc.stop() def add_requested_hash(self, hashsum, contact): from_ip = contact.compact_ip diff --git a/lbrynet/dht/node.py b/lbrynet/dht/node.py index 489a16957..c5cc27678 100644 --- a/lbrynet/dht/node.py +++ b/lbrynet/dht/node.py @@ -11,7 +11,7 @@ import hashlib import operator import struct import time -from twisted.internet import defer, error, reactor, task +from twisted.internet import defer, error, task import constants import routingtable @@ -52,9 +52,11 @@ class Node(object): application is performed via this class (or a subclass). """ - def __init__(self, node_id=None, udpPort=4000, dataStore=None, + def __init__(self, hash_announcer=None, node_id=None, udpPort=4000, dataStore=None, routingTableClass=None, networkProtocol=None, - externalIP=None, peerPort=None): + externalIP=None, peerPort=None, listenUDP=None, + callLater=None, resolve=None, clock=None, peer_finder=None, + peer_manager=None): """ @param dataStore: The data store to use. This must be class inheriting from the C{DataStore} interface (or providing the @@ -79,6 +81,17 @@ class Node(object): @param externalIP: the IP at which this node can be contacted @param peerPort: the port at which this node announces it has a blob for """ + + if not listenUDP or not resolve or not callLater or not clock: + from twisted.internet import reactor + listenUDP = listenUDP or reactor.listenUDP + resolve = resolve or reactor.resolve + callLater = callLater or reactor.callLater + clock = clock or reactor + self.reactor_resolve = resolve + self.reactor_listenUDP = listenUDP + self.reactor_callLater = callLater + self.clock = clock self.node_id = node_id or self._generateID() self.port = udpPort self._listeningPort = None # object implementing Twisted @@ -89,12 +102,14 @@ class Node(object): # operations before the node has finished joining the network) self._joinDeferred = None self.change_token_lc = task.LoopingCall(self.change_token) + self.change_token_lc.clock = self.clock self.refresh_node_lc = task.LoopingCall(self._refreshNode) + self.refresh_node_lc.clock = self.clock # Create k-buckets (for storing contacts) if routingTableClass is None: - self._routingTable = routingtable.OptimizedTreeRoutingTable(self.node_id) + self._routingTable = routingtable.OptimizedTreeRoutingTable(self.node_id, self.clock.seconds) else: - self._routingTable = routingTableClass(self.node_id) + self._routingTable = routingTableClass(self.node_id, self.clock.seconds) # Initialize this node's network access mechanisms if networkProtocol is None: @@ -118,7 +133,7 @@ class Node(object): self._routingTable.addContact(contact) self.externalIP = externalIP self.peerPort = peerPort - self.hash_watcher = HashWatcher() + self.hash_watcher = HashWatcher(self.clock) # will be used later self._can_store = True @@ -136,15 +151,18 @@ class Node(object): def can_store(self): return self._can_store is True + @defer.inlineCallbacks def stop(self): + yield self.hash_announcer.stop() # stop LoopingCalls: if self.refresh_node_lc.running: - self.refresh_node_lc.stop() + yield self.refresh_node_lc.stop() if self.change_token_lc.running: - self.change_token_lc.stop() + yield self.change_token_lc.stop() if self._listeningPort is not None: - self._listeningPort.stopListening() - self.hash_watcher.stop() + yield self._listeningPort.stopListening() + if self.hash_watcher.lc.running: + yield self.hash_watcher.stop() @defer.inlineCallbacks def joinNetwork(self, known_node_addresses=None): @@ -183,10 +201,10 @@ class Node(object): # Start refreshing k-buckets periodically, if necessary self.hash_watcher.tick() yield self._joinDeferred + self.hash_watcher.start() self.change_token_lc.start(constants.tokenSecretChangeInterval) self.refresh_node_lc.start(constants.checkRefreshInterval) - self.peer_finder.run_manage_loop() self.hash_announcer.run_manage_loop() #TODO: re-attempt joining the network if it fails @@ -828,7 +846,7 @@ class _IterativeFindHelper(object): if self._should_lookup_active_calls(): # Schedule the next iteration if there are any active # calls (Kademlia uses loose parallelism) - call = reactor.callLater(constants.iterativeLookupDelay, self.searchIteration) + call = self.node.reactor_callLater(constants.iterativeLookupDelay, self.searchIteration) self.pending_iteration_calls.append(call) # Check for a quick contact response that made an update to the shortList elif prevShortlistLength < len(self.shortlist): diff --git a/lbrynet/dht/peerfinder.py b/lbrynet/dht/peerfinder.py index 9e4dd167d..afbbddd6b 100644 --- a/lbrynet/dht/peerfinder.py +++ b/lbrynet/dht/peerfinder.py @@ -2,7 +2,7 @@ import binascii import logging from zope.interface import implements -from twisted.internet import defer, reactor +from twisted.internet import defer from lbrynet.interfaces import IPeerFinder from lbrynet.core.utils import short_hash @@ -63,7 +63,7 @@ class DHTPeerFinder(DummyPeerFinder): finished_deferred = self.dht_node.getPeersForBlob(bin_hash) if timeout is not None: - reactor.callLater(timeout, _trigger_timeout) + self.dht_node.reactor_callLater(timeout, _trigger_timeout) try: peer_list = yield finished_deferred diff --git a/lbrynet/dht/protocol.py b/lbrynet/dht/protocol.py index 5656f29bb..d22f9dfa9 100644 --- a/lbrynet/dht/protocol.py +++ b/lbrynet/dht/protocol.py @@ -3,7 +3,7 @@ import time import socket import errno -from twisted.internet import protocol, defer, error, reactor, task +from twisted.internet import protocol, defer, error, task import constants import encoding @@ -47,6 +47,7 @@ class KademliaProtocol(protocol.DatagramProtocol): self._total_bytes_tx = 0 self._total_bytes_rx = 0 self._bandwidth_stats_update_lc = task.LoopingCall(self._update_bandwidth_stats) + self._bandwidth_stats_update_lc.clock = self._node.clock def _update_bandwidth_stats(self): recent_rx_history = {} @@ -168,7 +169,7 @@ class KademliaProtocol(protocol.DatagramProtocol): df._rpcRawResponse = True # Set the RPC timeout timer - timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, msg.id) + timeoutCall = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id) # Transmit the data self._send(encodedMsg, msg.id, (contact.address, contact.port)) self._sentMessages[msg.id] = (contact.id, df, timeoutCall, method, args) @@ -331,7 +332,7 @@ class KademliaProtocol(protocol.DatagramProtocol): """Schedule the sending of the next UDP packet """ delay = self._delay() key = object() - delayed_call = reactor.callLater(delay, self._write_and_remove, key, txData, address) + delayed_call = self._node.reactor_callLater(delay, self._write_and_remove, key, txData, address) self._call_later_list[key] = delayed_call def _write_and_remove(self, key, txData, address): @@ -428,7 +429,7 @@ class KademliaProtocol(protocol.DatagramProtocol): # See if any progress has been made; if not, kill the message if self._hasProgressBeenMade(messageID): # Reset the RPC timeout timer - timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, messageID) + timeoutCall = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID) self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args) else: # No progress has been made diff --git a/lbrynet/dht/routingtable.py b/lbrynet/dht/routingtable.py index 863f37770..02f8e9686 100644 --- a/lbrynet/dht/routingtable.py +++ b/lbrynet/dht/routingtable.py @@ -5,7 +5,6 @@ # The docstrings in this module contain epytext markup; API documentation # may be created by processing this file with epydoc: http://epydoc.sf.net -import time import random from zope.interface import implements import constants @@ -34,7 +33,7 @@ class TreeRoutingTable(object): """ implements(IRoutingTable) - def __init__(self, parentNodeID): + def __init__(self, parentNodeID, getTime=None): """ @param parentNodeID: The n-bit node ID of the node to which this routing table belongs @@ -43,6 +42,9 @@ class TreeRoutingTable(object): # Create the initial (single) k-bucket covering the range of the entire n-bit ID space self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits)] self._parentNodeID = parentNodeID + if not getTime: + from time import time as getTime + self._getTime = getTime def addContact(self, contact): """ Add the given contact to the correct k-bucket; if it already @@ -194,7 +196,7 @@ class TreeRoutingTable(object): bucketIndex = startIndex refreshIDs = [] for bucket in self._buckets[startIndex:]: - if force or (int(time.time()) - bucket.lastAccessed >= constants.refreshTimeout): + if force or (int(self._getTime()) - bucket.lastAccessed >= constants.refreshTimeout): searchID = self._randomIDInBucketRange(bucketIndex) refreshIDs.append(searchID) bucketIndex += 1 @@ -221,7 +223,7 @@ class TreeRoutingTable(object): @type key: str """ bucketIndex = self._kbucketIndex(key) - self._buckets[bucketIndex].lastAccessed = int(time.time()) + self._buckets[bucketIndex].lastAccessed = int(self._getTime()) def _kbucketIndex(self, key): """ Calculate the index of the k-bucket which is responsible for the @@ -289,8 +291,8 @@ class OptimizedTreeRoutingTable(TreeRoutingTable): of the 13-page version of the Kademlia paper. """ - def __init__(self, parentNodeID): - TreeRoutingTable.__init__(self, parentNodeID) + def __init__(self, parentNodeID, getTime=None): + TreeRoutingTable.__init__(self, parentNodeID, getTime) # Cache containing nodes eligible to replace stale k-bucket entries self._replacementCache = {} @@ -301,6 +303,7 @@ class OptimizedTreeRoutingTable(TreeRoutingTable): @param contact: The contact to add to this node's k-buckets @type contact: kademlia.contact.Contact """ + if contact.id == self._parentNodeID: return