add reactor arguments to Node

-adds reactor (clock) and reactor functions listenUDP, callLater, and resolve as arguments to Node.__init__
-set the reactor clock on LoopingCalls to make them easily testable
-convert callLater manage loops to LoopingCalls
This commit is contained in:
Jack Robison 2018-02-20 13:30:56 -05:00
parent efaa97216f
commit 6666468640
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
6 changed files with 62 additions and 41 deletions

View file

@ -45,8 +45,9 @@ class DHTHashAnnouncer(DummyHashAnnouncer):
self.hash_queue = collections.deque()
self._concurrent_announcers = 0
self._manage_call_lc = task.LoopingCall(self.manage_lc)
self._manage_call_lc.clock = dht_node.clock
self._lock = utils.DeferredLockContextManager(defer.DeferredLock())
self._last_checked = time.time(), self.CONCURRENT_ANNOUNCERS
self._last_checked = dht_node.clock.seconds(), self.CONCURRENT_ANNOUNCERS
self._total = None
def run_manage_loop(self):
@ -58,7 +59,7 @@ class DHTHashAnnouncer(DummyHashAnnouncer):
last_time, last_hashes = self._last_checked
hashes = len(self.hash_queue)
if hashes:
t, h = time.time() - last_time, last_hashes - hashes
t, h = self.dht_node.clock.seconds() - last_time, last_hashes - hashes
blobs_per_second = float(h) / float(t)
if blobs_per_second > 0:
estimated_time_remaining = int(float(hashes) / blobs_per_second)
@ -108,7 +109,7 @@ class DHTHashAnnouncer(DummyHashAnnouncer):
defer.returnValue(None)
log.info('Announcing %s hashes', len(hashes))
# TODO: add a timeit decorator
start = time.time()
start = self.dht_node.clock.seconds()
ds = []
with self._lock:
@ -174,9 +175,9 @@ class DHTHashAnnouncer(DummyHashAnnouncer):
for _, announced_to in announcer_results:
stored_to.update(announced_to)
log.info('Took %s seconds to announce %s hashes', time.time() - start, len(hashes))
seconds_per_blob = (time.time() - start) / len(hashes)
self.supplier.set_single_hash_announce_duration(seconds_per_blob)
log.info('Took %s seconds to announce %s hashes', self.dht_node.clock.seconds() - start, len(hashes))
seconds_per_blob = (self.dht_node.clock.seconds() - start) / len(hashes)
self.set_single_hash_announce_duration(seconds_per_blob)
defer.returnValue(stored_to)

View file

@ -1,24 +1,22 @@
from collections import Counter
import datetime
from twisted.internet import task, threads
class HashWatcher(object):
def __init__(self):
def __init__(self, clock=None):
if not clock:
from twisted.internet import reactor as clock
self.ttl = 600
self.hashes = []
self.next_tick = None
self.lc = task.LoopingCall(self._remove_old_hashes)
self.lc.clock = clock
def tick(self):
from twisted.internet import reactor
self._remove_old_hashes()
self.next_tick = reactor.callLater(10, self.tick)
def start(self):
return self.lc.start(10)
def stop(self):
if self.next_tick is not None:
self.next_tick.cancel()
self.next_tick = None
return self.lc.stop()
def add_requested_hash(self, hashsum, contact):
from_ip = contact.compact_ip

View file

@ -11,7 +11,7 @@ import hashlib
import operator
import struct
import time
from twisted.internet import defer, error, reactor, task
from twisted.internet import defer, error, task
import constants
import routingtable
@ -52,9 +52,11 @@ class Node(object):
application is performed via this class (or a subclass).
"""
def __init__(self, node_id=None, udpPort=4000, dataStore=None,
def __init__(self, hash_announcer=None, node_id=None, udpPort=4000, dataStore=None,
routingTableClass=None, networkProtocol=None,
externalIP=None, peerPort=None):
externalIP=None, peerPort=None, listenUDP=None,
callLater=None, resolve=None, clock=None, peer_finder=None,
peer_manager=None):
"""
@param dataStore: The data store to use. This must be class inheriting
from the C{DataStore} interface (or providing the
@ -79,6 +81,17 @@ class Node(object):
@param externalIP: the IP at which this node can be contacted
@param peerPort: the port at which this node announces it has a blob for
"""
if not listenUDP or not resolve or not callLater or not clock:
from twisted.internet import reactor
listenUDP = listenUDP or reactor.listenUDP
resolve = resolve or reactor.resolve
callLater = callLater or reactor.callLater
clock = clock or reactor
self.reactor_resolve = resolve
self.reactor_listenUDP = listenUDP
self.reactor_callLater = callLater
self.clock = clock
self.node_id = node_id or self._generateID()
self.port = udpPort
self._listeningPort = None # object implementing Twisted
@ -89,12 +102,14 @@ class Node(object):
# operations before the node has finished joining the network)
self._joinDeferred = None
self.change_token_lc = task.LoopingCall(self.change_token)
self.change_token_lc.clock = self.clock
self.refresh_node_lc = task.LoopingCall(self._refreshNode)
self.refresh_node_lc.clock = self.clock
# Create k-buckets (for storing contacts)
if routingTableClass is None:
self._routingTable = routingtable.OptimizedTreeRoutingTable(self.node_id)
self._routingTable = routingtable.OptimizedTreeRoutingTable(self.node_id, self.clock.seconds)
else:
self._routingTable = routingTableClass(self.node_id)
self._routingTable = routingTableClass(self.node_id, self.clock.seconds)
# Initialize this node's network access mechanisms
if networkProtocol is None:
@ -118,7 +133,7 @@ class Node(object):
self._routingTable.addContact(contact)
self.externalIP = externalIP
self.peerPort = peerPort
self.hash_watcher = HashWatcher()
self.hash_watcher = HashWatcher(self.clock)
# will be used later
self._can_store = True
@ -136,15 +151,18 @@ class Node(object):
def can_store(self):
return self._can_store is True
@defer.inlineCallbacks
def stop(self):
yield self.hash_announcer.stop()
# stop LoopingCalls:
if self.refresh_node_lc.running:
self.refresh_node_lc.stop()
yield self.refresh_node_lc.stop()
if self.change_token_lc.running:
self.change_token_lc.stop()
yield self.change_token_lc.stop()
if self._listeningPort is not None:
self._listeningPort.stopListening()
self.hash_watcher.stop()
yield self._listeningPort.stopListening()
if self.hash_watcher.lc.running:
yield self.hash_watcher.stop()
@defer.inlineCallbacks
def joinNetwork(self, known_node_addresses=None):
@ -183,10 +201,10 @@ class Node(object):
# Start refreshing k-buckets periodically, if necessary
self.hash_watcher.tick()
yield self._joinDeferred
self.hash_watcher.start()
self.change_token_lc.start(constants.tokenSecretChangeInterval)
self.refresh_node_lc.start(constants.checkRefreshInterval)
self.peer_finder.run_manage_loop()
self.hash_announcer.run_manage_loop()
#TODO: re-attempt joining the network if it fails
@ -828,7 +846,7 @@ class _IterativeFindHelper(object):
if self._should_lookup_active_calls():
# Schedule the next iteration if there are any active
# calls (Kademlia uses loose parallelism)
call = reactor.callLater(constants.iterativeLookupDelay, self.searchIteration)
call = self.node.reactor_callLater(constants.iterativeLookupDelay, self.searchIteration)
self.pending_iteration_calls.append(call)
# Check for a quick contact response that made an update to the shortList
elif prevShortlistLength < len(self.shortlist):

View file

@ -2,7 +2,7 @@ import binascii
import logging
from zope.interface import implements
from twisted.internet import defer, reactor
from twisted.internet import defer
from lbrynet.interfaces import IPeerFinder
from lbrynet.core.utils import short_hash
@ -63,7 +63,7 @@ class DHTPeerFinder(DummyPeerFinder):
finished_deferred = self.dht_node.getPeersForBlob(bin_hash)
if timeout is not None:
reactor.callLater(timeout, _trigger_timeout)
self.dht_node.reactor_callLater(timeout, _trigger_timeout)
try:
peer_list = yield finished_deferred

View file

@ -3,7 +3,7 @@ import time
import socket
import errno
from twisted.internet import protocol, defer, error, reactor, task
from twisted.internet import protocol, defer, error, task
import constants
import encoding
@ -47,6 +47,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
self._total_bytes_tx = 0
self._total_bytes_rx = 0
self._bandwidth_stats_update_lc = task.LoopingCall(self._update_bandwidth_stats)
self._bandwidth_stats_update_lc.clock = self._node.clock
def _update_bandwidth_stats(self):
recent_rx_history = {}
@ -168,7 +169,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
df._rpcRawResponse = True
# Set the RPC timeout timer
timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
timeoutCall = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
# Transmit the data
self._send(encodedMsg, msg.id, (contact.address, contact.port))
self._sentMessages[msg.id] = (contact.id, df, timeoutCall, method, args)
@ -331,7 +332,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
"""Schedule the sending of the next UDP packet """
delay = self._delay()
key = object()
delayed_call = reactor.callLater(delay, self._write_and_remove, key, txData, address)
delayed_call = self._node.reactor_callLater(delay, self._write_and_remove, key, txData, address)
self._call_later_list[key] = delayed_call
def _write_and_remove(self, key, txData, address):
@ -428,7 +429,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
# See if any progress has been made; if not, kill the message
if self._hasProgressBeenMade(messageID):
# Reset the RPC timeout timer
timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, messageID)
timeoutCall = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID)
self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args)
else:
# No progress has been made

View file

@ -5,7 +5,6 @@
# The docstrings in this module contain epytext markup; API documentation
# may be created by processing this file with epydoc: http://epydoc.sf.net
import time
import random
from zope.interface import implements
import constants
@ -34,7 +33,7 @@ class TreeRoutingTable(object):
"""
implements(IRoutingTable)
def __init__(self, parentNodeID):
def __init__(self, parentNodeID, getTime=None):
"""
@param parentNodeID: The n-bit node ID of the node to which this
routing table belongs
@ -43,6 +42,9 @@ class TreeRoutingTable(object):
# Create the initial (single) k-bucket covering the range of the entire n-bit ID space
self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits)]
self._parentNodeID = parentNodeID
if not getTime:
from time import time as getTime
self._getTime = getTime
def addContact(self, contact):
""" Add the given contact to the correct k-bucket; if it already
@ -194,7 +196,7 @@ class TreeRoutingTable(object):
bucketIndex = startIndex
refreshIDs = []
for bucket in self._buckets[startIndex:]:
if force or (int(time.time()) - bucket.lastAccessed >= constants.refreshTimeout):
if force or (int(self._getTime()) - bucket.lastAccessed >= constants.refreshTimeout):
searchID = self._randomIDInBucketRange(bucketIndex)
refreshIDs.append(searchID)
bucketIndex += 1
@ -221,7 +223,7 @@ class TreeRoutingTable(object):
@type key: str
"""
bucketIndex = self._kbucketIndex(key)
self._buckets[bucketIndex].lastAccessed = int(time.time())
self._buckets[bucketIndex].lastAccessed = int(self._getTime())
def _kbucketIndex(self, key):
""" Calculate the index of the k-bucket which is responsible for the
@ -289,8 +291,8 @@ class OptimizedTreeRoutingTable(TreeRoutingTable):
of the 13-page version of the Kademlia paper.
"""
def __init__(self, parentNodeID):
TreeRoutingTable.__init__(self, parentNodeID)
def __init__(self, parentNodeID, getTime=None):
TreeRoutingTable.__init__(self, parentNodeID, getTime)
# Cache containing nodes eligible to replace stale k-bucket entries
self._replacementCache = {}
@ -301,6 +303,7 @@ class OptimizedTreeRoutingTable(TreeRoutingTable):
@param contact: The contact to add to this node's k-buckets
@type contact: kademlia.contact.Contact
"""
if contact.id == self._parentNodeID:
return