diff --git a/txupnp/dirty_pool.py b/txupnp/dirty_pool.py new file mode 100644 index 0000000..f287e13 --- /dev/null +++ b/txupnp/dirty_pool.py @@ -0,0 +1,89 @@ +from twisted.web.client import HTTPConnectionPool, _HTTP11ClientFactory +from twisted.web._newclient import HTTPClientParser, BadResponseVersion, HTTP11ClientProtocol, RequestNotSent +from twisted.web._newclient import TransportProxyProducer, RequestGenerationFailed +from twisted.python.failure import Failure +from twisted.internet.defer import Deferred, fail, maybeDeferred +from twisted.internet.defer import CancelledError + + +class DirtyHTTPParser(HTTPClientParser): + def parseVersion(self, strversion): + """ + Parse version strings of the form Protocol '/' Major '.' Minor. E.g. + b'HTTP/1.1'. Returns (protocol, major, minor). Will raise ValueError + on bad syntax. + """ + try: + proto, strnumber = strversion.split(b'/') + major, minor = strnumber.split(b'.') + major, minor = int(major), int(minor) + except ValueError as e: + if b'HTTP1.1' in strversion: + return ("HTTP", 1, 1) + raise BadResponseVersion(str(e), strversion) + if major < 0 or minor < 0: + raise BadResponseVersion(u"version may not be negative", + strversion) + return (proto, major, minor) + + +class DirtyHTTPClientProtocol(HTTP11ClientProtocol): + def request(self, request): + if self._state != 'QUIESCENT': + return fail(RequestNotSent()) + + self._state = 'TRANSMITTING' + _requestDeferred = maybeDeferred(request.writeTo, self.transport) + + def cancelRequest(ign): + # Explicitly cancel the request's deferred if it's still trying to + # write when this request is cancelled. + if self._state in ( + 'TRANSMITTING', 'TRANSMITTING_AFTER_RECEIVING_RESPONSE'): + _requestDeferred.cancel() + else: + self.transport.abortConnection() + self._disconnectParser(Failure(CancelledError())) + + self._finishedRequest = Deferred(cancelRequest) + + # Keep track of the Request object in case we need to call stopWriting + # on it. + self._currentRequest = request + + self._transportProxy = TransportProxyProducer(self.transport) + self._parser = DirtyHTTPParser(request, self._finishResponse) + self._parser.makeConnection(self._transportProxy) + self._responseDeferred = self._parser._responseDeferred + + def cbRequestWritten(ignored): + if self._state == 'TRANSMITTING': + self._state = 'WAITING' + self._responseDeferred.chainDeferred(self._finishedRequest) + + def ebRequestWriting(err): + if self._state == 'TRANSMITTING': + self._state = 'GENERATION_FAILED' + self.transport.abortConnection() + self._finishedRequest.errback( + Failure(RequestGenerationFailed([err]))) + else: + self._log.failure( + u'Error writing request, but not in valid state ' + u'to finalize request: {state}', + failure=err, + state=self._state + ) + + _requestDeferred.addCallbacks(cbRequestWritten, ebRequestWriting) + + return self._finishedRequest + + +class DirtyHTTP11ClientFactory(_HTTP11ClientFactory): + def buildProtocol(self, addr): + return DirtyHTTPClientProtocol(self._quiescentCallback) + + +class DirtyPool(HTTPConnectionPool): + _factory = DirtyHTTP11ClientFactory diff --git a/txupnp/scpd.py b/txupnp/scpd.py index b79c458..35a02ea 100644 --- a/txupnp/scpd.py +++ b/txupnp/scpd.py @@ -1,14 +1,15 @@ import logging from collections import OrderedDict from twisted.internet import defer, threads -from twisted.web.client import Agent, HTTPConnectionPool +from twisted.web.client import Agent import treq from treq.client import HTTPClient from xml.etree import ElementTree from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none from txupnp.fault import handle_fault, UPnPError -from txupnp.constants import IP_SCHEMA, SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION +from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION from txupnp.constants import BODY, POST +from txupnp.dirty_pool import DirtyPool log = logging.getLogger(__name__) @@ -39,14 +40,9 @@ def get_soap_body(service_name, method, param_names, **kwargs): class _SCPDCommand(object): - def __init__(self, gateway_address, service_port, control_url, service_id, method, param_names, returns, - reactor=None, connection_pool=None, agent=None, http_client=None): - if not reactor: - from twisted.internet import reactor - self._reactor = reactor - self._pool = connection_pool or HTTPConnectionPool(reactor) - self.agent = agent or Agent(reactor, connectTimeout=1) - self._http_client = http_client or HTTPClient(self.agent, data_to_body_producer=StringProducer) + def __init__(self, http_client, gateway_address, service_port, control_url, service_id, method, param_names, + returns): + self._http_client = http_client self.gateway_address = gateway_address self.service_port = service_port self.control_url = control_url @@ -82,7 +78,7 @@ class _SCPDCommand(object): ('Content-Type', 'text/xml'), ('Content-Length', len(soap_body)) )) - log.debug("sending POST to %s\nheaders: %s\nbody:%s\n", self.control_url, headers, soap_body) + log.debug("send POST to %s\nheaders: %s\nbody:%s\n", self.control_url, headers, soap_body) try: response = yield self._http_client.request( POST, url=self.control_url, data=soap_body, headers=headers @@ -90,7 +86,7 @@ class _SCPDCommand(object): except Exception as err: log.error("error (%s) sending POST to %s\nheaders: %s\nbody:%s\n", err, self.control_url, headers, soap_body) - raise UPnPError().with_traceback(err.__traceback__) + raise UPnPError(err) xml_response = yield response.content() try: @@ -150,7 +146,7 @@ class SCPDCommandRunner(object): self._reactor = reactor self._agent = Agent(reactor, connectTimeout=1) self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer) - self._connection_pool = HTTPConnectionPool(reactor) + self._connection_pool = DirtyPool(reactor) @defer.inlineCallbacks def _discover_commands(self, service): @@ -195,10 +191,9 @@ class SCPDCommandRunner(object): def _patch_command(self, action_info, service_type): name, inputs, outputs = self._soap_function_info(action_info) - command = _SCPDCommand(self._gateway.base_address, self._gateway.port, + command = _SCPDCommand(self._http_client, self._gateway.base_address, self._gateway.port, self._gateway.base_address + self._gateway.get_service(service_type).controlURL.encode(), - self._gateway.get_service(service_type).serviceId.encode(), name, inputs, outputs, - self._reactor, self._connection_pool, self._agent, self._http_client) + self._gateway.get_service(service_type).serviceId.encode(), name, inputs, outputs) current = getattr(self, command.method) if hasattr(current, "_return_types"): command._process_result = _return_types(*current._return_types)(command._process_result)