diff --git a/CHANGELOG.md b/CHANGELOG.md index 321855f63..aac90f7c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ at anytime. * ### Fixed - * + * Added timeout to ClientProtocol * * diff --git a/lbrynet/core/client/ClientProtocol.py b/lbrynet/core/client/ClientProtocol.py index c588f0f2a..a9dca8307 100644 --- a/lbrynet/core/client/ClientProtocol.py +++ b/lbrynet/core/client/ClientProtocol.py @@ -3,8 +3,10 @@ import logging from decimal import Decimal from twisted.internet import error, defer from twisted.internet.protocol import Protocol, ClientFactory +from twisted.protocols.policies import TimeoutMixin from twisted.python import failure from lbrynet import conf +from lbrynet.core import utils from lbrynet.core.Error import ConnectionClosedBeforeResponseError, NoResponseError from lbrynet.core.Error import DownloadCanceledError, MisbehavingPeerError from lbrynet.core.Error import RequestCanceledError @@ -21,9 +23,10 @@ def encode_decimal(obj): raise TypeError(repr(obj) + " is not JSON serializable") -class ClientProtocol(Protocol): +class ClientProtocol(Protocol, TimeoutMixin): implements(IRequestSender, IRateLimited) ######### Protocol ######### + PROTOCOL_TIMEOUT = 30 def connectionMade(self): log.debug("Connection made to %s", self.factory.peer) @@ -37,13 +40,15 @@ class ClientProtocol(Protocol): self._next_request = {} self.connection_closed = False self.connection_closing = False - + # This needs to be set for TimeoutMixin + self.callLater = utils.call_later self.peer.report_up() self._ask_for_request() def dataReceived(self, data): log.debug("Data receieved from %s", self.peer) + self.setTimeout(None) self._rate_limiter.report_dl_bytes(len(data)) if self._downloading_blob is True: self._blob_download_request.write(data) @@ -60,8 +65,14 @@ class ClientProtocol(Protocol): if self._downloading_blob is True and len(extra_data) != 0: self._blob_download_request.write(extra_data) + def timeoutConnection(self): + log.info("Connection timed out to %s", self.peer) + self.peer.report_down() + self.transport.abortConnection() + def connectionLost(self, reason): log.debug("Connection lost to %s: %s", self.peer, reason) + self.setTimeout(None) self.connection_closed = True if reason.check(error.ConnectionDone): err = failure.Failure(ConnectionClosedBeforeResponseError()) @@ -138,6 +149,7 @@ class ClientProtocol(Protocol): d.addErrback(self._handle_request_error) def _send_request_message(self, request_msg): + self.setTimeout(self.PROTOCOL_TIMEOUT) # TODO: compare this message to the last one. If they're the same, # TODO: incrementally delay this message. m = json.dumps(request_msg, default=encode_decimal) diff --git a/tests/unit/core/client/test_ConnectionManager.py b/tests/unit/core/client/test_ConnectionManager.py index 81d69429e..fd97abf22 100644 --- a/tests/unit/core/client/test_ConnectionManager.py +++ b/tests/unit/core/client/test_ConnectionManager.py @@ -3,9 +3,9 @@ import time import logging from lbrynet.core import log_support -#from lbrynet.core.client.ConnectionManager import ConnectionManager from lbrynet.core.client.ClientRequest import ClientRequest from lbrynet.core.server.ServerProtocol import ServerProtocol +from lbrynet.core.client.ClientProtocol import ClientProtocol from lbrynet.core.RateLimiter import RateLimiter from lbrynet.core.Peer import Peer from lbrynet.core.PeerManager import PeerManager @@ -71,8 +71,9 @@ class MocFunctionalQueryHandler(object): def handle_queries(self, queries): if self.query_identifiers[0] in queries: if self.is_delayed: - out = deferLater(self.clock, 10, lambda: {'moc_request':0}) - self.clock.advance(10) + delay = ClientProtocol.PROTOCOL_TIMEOUT+1 + out = deferLater(self.clock, delay, lambda: {'moc_request':0}) + self.clock.advance(delay) return out if self.is_good: return defer.succeed({'moc_request':0}) @@ -113,7 +114,6 @@ class MocServerProtocolFactory(ServerFactory): self.query_handler_factories = {} self.peer_manager = PeerManager() - class TestIntegrationConnectionManager(unittest.TestCase): def setUp(self): @@ -152,9 +152,7 @@ class TestIntegrationConnectionManager(unittest.TestCase): self.assertEqual(0, self.TEST_PEER.down_count) @defer.inlineCallbacks - def test_bad_server(self): - # test to see that if we setup a server that returns an improper reply - # we don't get a connection + def test_server_with_improper_reply(self): self.server = MocServerProtocolFactory(self.clock, is_good=False) self.server_port = reactor.listenTCP(PEER_PORT, self.server, interface=LOCAL_HOST) yield self.connection_manager.manage(schedule_next_call=False) @@ -213,4 +211,17 @@ class TestIntegrationConnectionManager(unittest.TestCase): self.assertEqual(0, self.connection_manager.num_peer_connections()) self.assertEqual(None, self.connection_manager._next_manage_call) + @defer.inlineCallbacks + def test_closed_connection_when_server_is_slow(self): + self.server = MocServerProtocolFactory(self.clock, has_moc_query_handler=True,is_delayed=True) + self.server_port = reactor.listenTCP(PEER_PORT, self.server, interface=LOCAL_HOST) + + yield self.connection_manager.manage(schedule_next_call=False) + self.assertEqual(1, self.connection_manager.num_peer_connections()) + connection_made = yield self.connection_manager._peer_connections[self.TEST_PEER].factory.connection_was_made_deferred + self.assertEqual(0, self.connection_manager.num_peer_connections()) + self.assertEqual(True, connection_made) + self.assertEqual(0, self.TEST_PEER.success_count) + self.assertEqual(1, self.TEST_PEER.down_count) +