diff --git a/lbrynet/dht/node.py b/lbrynet/dht/node.py index 03f56ca9d..31f1b238a 100644 --- a/lbrynet/dht/node.py +++ b/lbrynet/dht/node.py @@ -41,7 +41,46 @@ def rpcmethod(func): return func -class Node(object): +class MockKademliaHelper(object): + def __init__(self, clock=None, callLater=None, resolve=None, listenUDP=None): + if not listenUDP or not resolve or not callLater or not clock: + from twisted.internet import reactor + listenUDP = listenUDP or reactor.listenUDP + resolve = resolve or reactor.resolve + callLater = callLater or reactor.callLater + clock = clock or reactor + + self.clock = clock + self.reactor_listenUDP = listenUDP + self.reactor_resolve = resolve + + CallLaterManager.setup(callLater) + self.reactor_callLater = CallLaterManager.call_later + self.reactor_callSoon = CallLaterManager.call_soon + + self._listeningPort = None # object implementing Twisted + # IListeningPort This will contain a deferred created when + # joining the network, to enable publishing/retrieving + # information from the DHT as soon as the node is part of the + # network (add callbacks to this deferred if scheduling such + # operations before the node has finished joining the network) + + def get_looping_call(self, fn, *args, **kwargs): + lc = task.LoopingCall(fn, *args, **kwargs) + lc.clock = self.clock + return lc + + def safe_stop_looping_call(self, lc): + if lc and lc.running: + return lc.stop() + return defer.succeed(None) + + def safe_start_looping_call(self, lc, t): + if lc and not lc.running: + lc.start(t) + + +class Node(MockKademliaHelper): """ Local node in the Kademlia network This class represents a single local node in a Kademlia network; in other @@ -54,7 +93,7 @@ class Node(object): def __init__(self, node_id=None, udpPort=4000, dataStore=None, routingTableClass=None, networkProtocol=None, - externalIP=None, peerPort=None, listenUDP=None, + externalIP=None, peerPort=3333, listenUDP=None, callLater=None, resolve=None, clock=None, peer_finder=None, peer_manager=None): """ @@ -82,31 +121,11 @@ class Node(object): @param peerPort: the port at which this node announces it has a blob for """ - if not listenUDP or not resolve or not callLater or not clock: - from twisted.internet import reactor - listenUDP = listenUDP or reactor.listenUDP - resolve = resolve or reactor.resolve - callLater = callLater or reactor.callLater - clock = clock or reactor - self.clock = clock - CallLaterManager.setup(callLater) - self.reactor_resolve = resolve - self.reactor_listenUDP = listenUDP - self.reactor_callLater = CallLaterManager.call_later - self.reactor_callSoon = CallLaterManager.call_soon + MockKademliaHelper.__init__(self, clock, callLater, resolve, listenUDP) self.node_id = node_id or self._generateID() self.port = udpPort - self._listeningPort = None # object implementing Twisted - # IListeningPort This will contain a deferred created when - # joining the network, to enable publishing/retrieving - # information from the DHT as soon as the node is part of the - # network (add callbacks to this deferred if scheduling such - # operations before the node has finished joining the network) - self._joinDeferred = defer.Deferred(None) - self.change_token_lc = task.LoopingCall(self.change_token) - self.change_token_lc.clock = self.clock - self.refresh_node_lc = task.LoopingCall(self._refreshNode) - self.refresh_node_lc.clock = self.clock + self._change_token_lc = self.get_looping_call(self.change_token) + self._refresh_node_lc = self.get_looping_call(self._refreshNode) # Create k-buckets (for storing contacts) if routingTableClass is None: @@ -127,6 +146,7 @@ class Node(object): self._dataStore = dataStore or datastore.DictDataStore() self.peer_manager = peer_manager or PeerManager() self.peer_finder = peer_finder or DHTPeerFinder(self, self.peer_manager) + self._join_deferred = None def __del__(self): log.warning("unclean shutdown of the dht node") @@ -136,10 +156,8 @@ class Node(object): @defer.inlineCallbacks def stop(self): # stop LoopingCalls: - if self.refresh_node_lc.running: - yield self.refresh_node_lc.stop() - if self.change_token_lc.running: - yield self.change_token_lc.stop() + yield self.safe_stop_looping_call(self._refresh_node_lc) + yield self.safe_stop_looping_call(self._change_token_lc) if self._listeningPort is not None: yield self._listeningPort.stopListening() @@ -204,11 +222,10 @@ class Node(object): self.start_listening() # #TODO: Refresh all k-buckets further away than this node's closest neighbour + self.safe_start_looping_call(self._change_token_lc, constants.tokenSecretChangeInterval) # Start refreshing k-buckets periodically, if necessary self.bootstrap_join(known_node_addresses or [], self._joinDeferred) - yield self._joinDeferred - self.change_token_lc.start(constants.tokenSecretChangeInterval) - self.refresh_node_lc.start(constants.checkRefreshInterval) + self.safe_start_looping_call(self._refresh_node_lc, constants.checkRefreshInterval) @property def contacts(self):