add reactor arguments to Node

- adds the reactor (clock) and the reactor functions listenUDP, callLater, and resolve as arguments to Node.__init__
- sets the reactor clock on LoopingCalls to make them easily testable
- converts callLater-based manage loops to LoopingCalls
Jack Robison 2018-02-20 13:30:56 -05:00
parent efaa97216f
commit 6666468640
GPG key ID: DF25C68FE0239BB2
6 changed files with 62 additions and 41 deletions
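The thrust of the change is that time becomes injectable: each LoopingCall gets its clock attribute pointed at the node's clock, so a test can substitute a twisted.internet.task.Clock and drive timers with advance() instead of waiting on the real reactor. A minimal sketch of that pattern (the Pinger class here is hypothetical, not code from this commit):

    from twisted.internet import task

    class Pinger(object):
        """Hypothetical periodic task with an injectable clock."""

        def __init__(self, clock=None):
            if clock is None:
                from twisted.internet import reactor as clock
            self.pings = 0
            self.lc = task.LoopingCall(self._ping)
            self.lc.clock = clock  # LoopingCall schedules itself on this clock

        def _ping(self):
            self.pings += 1

        def start(self):
            return self.lc.start(10)  # fires immediately, then every 10 seconds

        def stop(self):
            if self.lc.running:
                return self.lc.stop()

    # In a unit test, a fake clock replaces the reactor:
    # clock = task.Clock()
    # pinger = Pinger(clock=clock)
    # pinger.start()        # first call fires immediately -> pings == 1
    # clock.advance(30)     # three more intervals elapse   -> pings == 4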

View file

@@ -45,8 +45,9 @@ class DHTHashAnnouncer(DummyHashAnnouncer):
         self.hash_queue = collections.deque()
         self._concurrent_announcers = 0
         self._manage_call_lc = task.LoopingCall(self.manage_lc)
+        self._manage_call_lc.clock = dht_node.clock
         self._lock = utils.DeferredLockContextManager(defer.DeferredLock())
-        self._last_checked = time.time(), self.CONCURRENT_ANNOUNCERS
+        self._last_checked = dht_node.clock.seconds(), self.CONCURRENT_ANNOUNCERS
         self._total = None

     def run_manage_loop(self):
@@ -58,7 +59,7 @@ class DHTHashAnnouncer(DummyHashAnnouncer):
         last_time, last_hashes = self._last_checked
         hashes = len(self.hash_queue)
         if hashes:
-            t, h = time.time() - last_time, last_hashes - hashes
+            t, h = self.dht_node.clock.seconds() - last_time, last_hashes - hashes
             blobs_per_second = float(h) / float(t)
             if blobs_per_second > 0:
                 estimated_time_remaining = int(float(hashes) / blobs_per_second)
@@ -108,7 +109,7 @@ class DHTHashAnnouncer(DummyHashAnnouncer):
             defer.returnValue(None)
         log.info('Announcing %s hashes', len(hashes))
         # TODO: add a timeit decorator
-        start = time.time()
+        start = self.dht_node.clock.seconds()
         ds = []

         with self._lock:
@@ -174,9 +175,9 @@ class DHTHashAnnouncer(DummyHashAnnouncer):
         for _, announced_to in announcer_results:
             stored_to.update(announced_to)

-        log.info('Took %s seconds to announce %s hashes', time.time() - start, len(hashes))
-        seconds_per_blob = (time.time() - start) / len(hashes)
-        self.supplier.set_single_hash_announce_duration(seconds_per_blob)
+        log.info('Took %s seconds to announce %s hashes', self.dht_node.clock.seconds() - start, len(hashes))
+        seconds_per_blob = (self.dht_node.clock.seconds() - start) / len(hashes)
+        self.set_single_hash_announce_duration(seconds_per_blob)
         defer.returnValue(stored_to)
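The announcer hunks above swap wall-clock reads (time.time()) for the injected clock (dht_node.clock.seconds()) when measuring announce rate and duration. A rough sketch of the same pattern with hypothetical names; with a task.Clock supplied, a test can advance() the clock around the yielded work and assert on the measured elapsed time without sleeping:

    from twisted.internet import defer

    class TimedWorker(object):
        """Hypothetical helper that times deferred work against an injectable clock."""

        def __init__(self, clock=None):
            if clock is None:
                from twisted.internet import reactor as clock
            self.clock = clock

        @defer.inlineCallbacks
        def run_timed(self, work, *args):
            # Read the injected clock rather than time.time(), mirroring the
            # DHTHashAnnouncer change above.
            start = self.clock.seconds()
            result = yield work(*args)
            elapsed = self.clock.seconds() - start
            defer.returnValue((result, elapsed))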

View file

@@ -1,24 +1,22 @@
 from collections import Counter
 import datetime
+from twisted.internet import task, threads


 class HashWatcher(object):
-    def __init__(self):
+    def __init__(self, clock=None):
+        if not clock:
+            from twisted.internet import reactor as clock
         self.ttl = 600
         self.hashes = []
-        self.next_tick = None
+        self.lc = task.LoopingCall(self._remove_old_hashes)
+        self.lc.clock = clock

-    def tick(self):
-        from twisted.internet import reactor
-        self._remove_old_hashes()
-        self.next_tick = reactor.callLater(10, self.tick)
+    def start(self):
+        return self.lc.start(10)

     def stop(self):
-        if self.next_tick is not None:
-            self.next_tick.cancel()
-            self.next_tick = None
+        return self.lc.stop()

     def add_requested_hash(self, hashsum, contact):
         from_ip = contact.compact_ip
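The HashWatcher rewrite above is the simplest instance of the commit message's third point: a loop that reschedules itself with reactor.callLater becomes a LoopingCall bound to an injectable clock, which also makes stop() a one-liner. The same before/after shape in generic form (class names here are hypothetical):

    from twisted.internet import task

    class CallLaterLoop(object):
        """Old shape: the loop reschedules itself and stop() must cancel
        the pending IDelayedCall by hand."""

        def __init__(self):
            self.next_tick = None

        def tick(self):
            from twisted.internet import reactor
            self._do_work()
            self.next_tick = reactor.callLater(10, self.tick)

        def stop(self):
            if self.next_tick is not None:
                self.next_tick.cancel()
                self.next_tick = None

        def _do_work(self):
            pass

    class LoopingCallLoop(object):
        """New shape: LoopingCall owns the rescheduling and runs off an
        injectable clock, so tests can drive it with task.Clock."""

        def __init__(self, clock=None):
            if clock is None:
                from twisted.internet import reactor as clock
            self.lc = task.LoopingCall(self._do_work)
            self.lc.clock = clock

        def start(self):
            return self.lc.start(10)

        def stop(self):
            if self.lc.running:
                return self.lc.stop()

        def _do_work(self):
            pass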

View file

@@ -11,7 +11,7 @@ import hashlib
 import operator
 import struct
 import time
-from twisted.internet import defer, error, reactor, task
+from twisted.internet import defer, error, task

 import constants
 import routingtable
@@ -52,9 +52,11 @@ class Node(object):
     application is performed via this class (or a subclass).
     """

-    def __init__(self, node_id=None, udpPort=4000, dataStore=None,
+    def __init__(self, hash_announcer=None, node_id=None, udpPort=4000, dataStore=None,
                  routingTableClass=None, networkProtocol=None,
-                 externalIP=None, peerPort=None):
+                 externalIP=None, peerPort=None, listenUDP=None,
+                 callLater=None, resolve=None, clock=None, peer_finder=None,
+                 peer_manager=None):
         """
         @param dataStore: The data store to use. This must be class inheriting
                           from the C{DataStore} interface (or providing the
@@ -79,6 +81,17 @@ class Node(object):
         @param externalIP: the IP at which this node can be contacted
         @param peerPort: the port at which this node announces it has a blob for
         """
+        if not listenUDP or not resolve or not callLater or not clock:
+            from twisted.internet import reactor
+            listenUDP = listenUDP or reactor.listenUDP
+            resolve = resolve or reactor.resolve
+            callLater = callLater or reactor.callLater
+            clock = clock or reactor
+        self.reactor_resolve = resolve
+        self.reactor_listenUDP = listenUDP
+        self.reactor_callLater = callLater
+        self.clock = clock
+
         self.node_id = node_id or self._generateID()
         self.port = udpPort
         self._listeningPort = None  # object implementing Twisted
@@ -89,12 +102,14 @@ class Node(object):
         # operations before the node has finished joining the network)
         self._joinDeferred = None
         self.change_token_lc = task.LoopingCall(self.change_token)
+        self.change_token_lc.clock = self.clock
         self.refresh_node_lc = task.LoopingCall(self._refreshNode)
+        self.refresh_node_lc.clock = self.clock
         # Create k-buckets (for storing contacts)
         if routingTableClass is None:
-            self._routingTable = routingtable.OptimizedTreeRoutingTable(self.node_id)
+            self._routingTable = routingtable.OptimizedTreeRoutingTable(self.node_id, self.clock.seconds)
         else:
-            self._routingTable = routingTableClass(self.node_id)
+            self._routingTable = routingTableClass(self.node_id, self.clock.seconds)

         # Initialize this node's network access mechanisms
         if networkProtocol is None:
@@ -118,7 +133,7 @@ class Node(object):
             self._routingTable.addContact(contact)
         self.externalIP = externalIP
         self.peerPort = peerPort
-        self.hash_watcher = HashWatcher()
+        self.hash_watcher = HashWatcher(self.clock)

         # will be used later
         self._can_store = True
@@ -136,15 +151,18 @@ class Node(object):
     def can_store(self):
         return self._can_store is True

+    @defer.inlineCallbacks
     def stop(self):
+        yield self.hash_announcer.stop()
         # stop LoopingCalls:
         if self.refresh_node_lc.running:
-            self.refresh_node_lc.stop()
+            yield self.refresh_node_lc.stop()
         if self.change_token_lc.running:
-            self.change_token_lc.stop()
+            yield self.change_token_lc.stop()
         if self._listeningPort is not None:
-            self._listeningPort.stopListening()
-        self.hash_watcher.stop()
+            yield self._listeningPort.stopListening()
+        if self.hash_watcher.lc.running:
+            yield self.hash_watcher.stop()

     @defer.inlineCallbacks
     def joinNetwork(self, known_node_addresses=None):
@@ -183,10 +201,10 @@ class Node(object):
         # Start refreshing k-buckets periodically, if necessary
         self.hash_watcher.tick()
         yield self._joinDeferred
+        self.hash_watcher.start()
         self.change_token_lc.start(constants.tokenSecretChangeInterval)
         self.refresh_node_lc.start(constants.checkRefreshInterval)
-        self.peer_finder.run_manage_loop()
         self.hash_announcer.run_manage_loop()

         #TODO: re-attempt joining the network if it fails
@@ -828,7 +846,7 @@ class _IterativeFindHelper(object):
         if self._should_lookup_active_calls():
             # Schedule the next iteration if there are any active
             # calls (Kademlia uses loose parallelism)
-            call = reactor.callLater(constants.iterativeLookupDelay, self.searchIteration)
+            call = self.node.reactor_callLater(constants.iterativeLookupDelay, self.searchIteration)
             self.pending_iteration_calls.append(call)
         # Check for a quick contact response that made an update to the shortList
         elif prevShortlistLength < len(self.shortlist):
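Taken together, the new keyword arguments let a test stand in for every reactor interaction the node makes. A hypothetical wiring sketch; the import path follows this repository's layout, and the fake listening port and resolver are assumptions for illustration, not part of the commit:

    from twisted.internet import defer, task

    from lbrynet.dht.node import Node  # import path assumed from the repo layout

    class FakeListeningPort(object):
        """Minimal stand-in for the object reactor.listenUDP would return."""

        def stopListening(self):
            return defer.succeed(None)

    def fake_listen_udp(port, protocol, interface='', maxPacketSize=8192):
        return FakeListeningPort()

    def fake_resolve(hostname, timeout=None):
        return defer.succeed('127.0.0.1')

    clock = task.Clock()  # provides callLater() and seconds(), driven by advance()

    node = Node(
        udpPort=4444,
        externalIP='127.0.0.1',
        listenUDP=fake_listen_udp,
        resolve=fake_resolve,
        callLater=clock.callLater,
        clock=clock,
    )

    # Once the node's LoopingCalls are started (joinNetwork does that), the
    # test can step time forward deterministically instead of waiting:
    clock.advance(60)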

View file

@@ -2,7 +2,7 @@ import binascii
 import logging
 from zope.interface import implements
-from twisted.internet import defer, reactor
+from twisted.internet import defer

 from lbrynet.interfaces import IPeerFinder
 from lbrynet.core.utils import short_hash
@@ -63,7 +63,7 @@ class DHTPeerFinder(DummyPeerFinder):
         finished_deferred = self.dht_node.getPeersForBlob(bin_hash)

         if timeout is not None:
-            reactor.callLater(timeout, _trigger_timeout)
+            self.dht_node.reactor_callLater(timeout, _trigger_timeout)

         try:
             peer_list = yield finished_deferred

View file

@@ -3,7 +3,7 @@ import time
 import socket
 import errno
-from twisted.internet import protocol, defer, error, reactor, task
+from twisted.internet import protocol, defer, error, task

 import constants
 import encoding
@@ -47,6 +47,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
         self._total_bytes_tx = 0
         self._total_bytes_rx = 0
         self._bandwidth_stats_update_lc = task.LoopingCall(self._update_bandwidth_stats)
+        self._bandwidth_stats_update_lc.clock = self._node.clock

     def _update_bandwidth_stats(self):
         recent_rx_history = {}
@@ -168,7 +169,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
             df._rpcRawResponse = True

         # Set the RPC timeout timer
-        timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
+        timeoutCall = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
         # Transmit the data
         self._send(encodedMsg, msg.id, (contact.address, contact.port))
         self._sentMessages[msg.id] = (contact.id, df, timeoutCall, method, args)
@@ -331,7 +332,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
         """Schedule the sending of the next UDP packet """
         delay = self._delay()
         key = object()
-        delayed_call = reactor.callLater(delay, self._write_and_remove, key, txData, address)
+        delayed_call = self._node.reactor_callLater(delay, self._write_and_remove, key, txData, address)
         self._call_later_list[key] = delayed_call

     def _write_and_remove(self, key, txData, address):
@@ -428,7 +429,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
         # See if any progress has been made; if not, kill the message
         if self._hasProgressBeenMade(messageID):
             # Reset the RPC timeout timer
-            timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, messageID)
+            timeoutCall = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID)
             self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args)
         else:
             # No progress has been made

View file

@@ -5,7 +5,6 @@
 # The docstrings in this module contain epytext markup; API documentation
 # may be created by processing this file with epydoc: http://epydoc.sf.net

-import time
 import random
 from zope.interface import implements
 import constants
@@ -34,7 +33,7 @@ class TreeRoutingTable(object):
     """
     implements(IRoutingTable)

-    def __init__(self, parentNodeID):
+    def __init__(self, parentNodeID, getTime=None):
         """
         @param parentNodeID: The n-bit node ID of the node to which this
                              routing table belongs
@@ -43,6 +42,9 @@ class TreeRoutingTable(object):
         # Create the initial (single) k-bucket covering the range of the entire n-bit ID space
         self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits)]
         self._parentNodeID = parentNodeID
+        if not getTime:
+            from time import time as getTime
+        self._getTime = getTime

     def addContact(self, contact):
         """ Add the given contact to the correct k-bucket; if it already
@@ -194,7 +196,7 @@ class TreeRoutingTable(object):
         bucketIndex = startIndex
         refreshIDs = []
         for bucket in self._buckets[startIndex:]:
-            if force or (int(time.time()) - bucket.lastAccessed >= constants.refreshTimeout):
+            if force or (int(self._getTime()) - bucket.lastAccessed >= constants.refreshTimeout):
                 searchID = self._randomIDInBucketRange(bucketIndex)
                 refreshIDs.append(searchID)
             bucketIndex += 1
@@ -221,7 +223,7 @@ class TreeRoutingTable(object):
         @type key: str
         """
         bucketIndex = self._kbucketIndex(key)
-        self._buckets[bucketIndex].lastAccessed = int(time.time())
+        self._buckets[bucketIndex].lastAccessed = int(self._getTime())

     def _kbucketIndex(self, key):
         """ Calculate the index of the k-bucket which is responsible for the
@@ -289,8 +291,8 @@ class OptimizedTreeRoutingTable(TreeRoutingTable):
     of the 13-page version of the Kademlia paper.
     """

-    def __init__(self, parentNodeID):
-        TreeRoutingTable.__init__(self, parentNodeID)
+    def __init__(self, parentNodeID, getTime=None):
+        TreeRoutingTable.__init__(self, parentNodeID, getTime)
         # Cache containing nodes eligible to replace stale k-bucket entries
         self._replacementCache = {}
@@ -301,6 +303,7 @@ class OptimizedTreeRoutingTable(TreeRoutingTable):
         @param contact: The contact to add to this node's k-buckets
         @type contact: kademlia.contact.Contact
         """
+
         if contact.id == self._parentNodeID:
             return
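The routing table gets the same treatment in miniature: a getTime callable (defaulting to time.time) replaces direct time.time() calls, so bucket lastAccessed timestamps and refresh-timeout checks can be pinned to a fake clock. A sketch under stated assumptions: the import path is taken from this repository's layout, and the 48-byte node ID width is assumed from constants.key_bits being 384:

    from twisted.internet import task

    from lbrynet.dht.routingtable import OptimizedTreeRoutingTable  # path assumed

    parent_node_id = '\x00' * 48  # 48-byte (384-bit) node ID, width assumed
    clock = task.Clock()

    # Production behaviour: no getTime given, the table falls back to time.time.
    production_table = OptimizedTreeRoutingTable(parent_node_id)

    # Test behaviour: lastAccessed values come from the fake clock, so
    # refresh-timeout logic can be exercised by advancing it.
    test_table = OptimizedTreeRoutingTable(parent_node_id, getTime=clock.seconds)
    clock.advance(3600)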