routers are terrible

(fixes SOAP responses with invalid HTTP versions)
This commit is contained in:
Jack Robison 2018-08-01 20:52:40 -04:00
parent dad07199a5
commit 2b2d3e0eb1
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 100 additions and 16 deletions

89
txupnp/dirty_pool.py Normal file
View file

@ -0,0 +1,89 @@
from twisted.web.client import HTTPConnectionPool, _HTTP11ClientFactory
from twisted.web._newclient import HTTPClientParser, BadResponseVersion, HTTP11ClientProtocol, RequestNotSent
from twisted.web._newclient import TransportProxyProducer, RequestGenerationFailed
from twisted.python.failure import Failure
from twisted.internet.defer import Deferred, fail, maybeDeferred
from twisted.internet.defer import CancelledError
class DirtyHTTPParser(HTTPClientParser):
def parseVersion(self, strversion):
"""
Parse version strings of the form Protocol '/' Major '.' Minor. E.g.
b'HTTP/1.1'. Returns (protocol, major, minor). Will raise ValueError
on bad syntax.
"""
try:
proto, strnumber = strversion.split(b'/')
major, minor = strnumber.split(b'.')
major, minor = int(major), int(minor)
except ValueError as e:
if b'HTTP1.1' in strversion:
return ("HTTP", 1, 1)
raise BadResponseVersion(str(e), strversion)
if major < 0 or minor < 0:
raise BadResponseVersion(u"version may not be negative",
strversion)
return (proto, major, minor)
class DirtyHTTPClientProtocol(HTTP11ClientProtocol):
def request(self, request):
if self._state != 'QUIESCENT':
return fail(RequestNotSent())
self._state = 'TRANSMITTING'
_requestDeferred = maybeDeferred(request.writeTo, self.transport)
def cancelRequest(ign):
# Explicitly cancel the request's deferred if it's still trying to
# write when this request is cancelled.
if self._state in (
'TRANSMITTING', 'TRANSMITTING_AFTER_RECEIVING_RESPONSE'):
_requestDeferred.cancel()
else:
self.transport.abortConnection()
self._disconnectParser(Failure(CancelledError()))
self._finishedRequest = Deferred(cancelRequest)
# Keep track of the Request object in case we need to call stopWriting
# on it.
self._currentRequest = request
self._transportProxy = TransportProxyProducer(self.transport)
self._parser = DirtyHTTPParser(request, self._finishResponse)
self._parser.makeConnection(self._transportProxy)
self._responseDeferred = self._parser._responseDeferred
def cbRequestWritten(ignored):
if self._state == 'TRANSMITTING':
self._state = 'WAITING'
self._responseDeferred.chainDeferred(self._finishedRequest)
def ebRequestWriting(err):
if self._state == 'TRANSMITTING':
self._state = 'GENERATION_FAILED'
self.transport.abortConnection()
self._finishedRequest.errback(
Failure(RequestGenerationFailed([err])))
else:
self._log.failure(
u'Error writing request, but not in valid state '
u'to finalize request: {state}',
failure=err,
state=self._state
)
_requestDeferred.addCallbacks(cbRequestWritten, ebRequestWriting)
return self._finishedRequest
class DirtyHTTP11ClientFactory(_HTTP11ClientFactory):
def buildProtocol(self, addr):
return DirtyHTTPClientProtocol(self._quiescentCallback)
class DirtyPool(HTTPConnectionPool):
_factory = DirtyHTTP11ClientFactory

View file

@ -1,14 +1,15 @@
import logging
from collections import OrderedDict
from twisted.internet import defer, threads
from twisted.web.client import Agent, HTTPConnectionPool
from twisted.web.client import Agent
import treq
from treq.client import HTTPClient
from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none
from txupnp.fault import handle_fault, UPnPError
from txupnp.constants import IP_SCHEMA, SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
from txupnp.constants import BODY, POST
from txupnp.dirty_pool import DirtyPool
log = logging.getLogger(__name__)
@ -39,14 +40,9 @@ def get_soap_body(service_name, method, param_names, **kwargs):
class _SCPDCommand(object):
def __init__(self, gateway_address, service_port, control_url, service_id, method, param_names, returns,
reactor=None, connection_pool=None, agent=None, http_client=None):
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
self._pool = connection_pool or HTTPConnectionPool(reactor)
self.agent = agent or Agent(reactor, connectTimeout=1)
self._http_client = http_client or HTTPClient(self.agent, data_to_body_producer=StringProducer)
def __init__(self, http_client, gateway_address, service_port, control_url, service_id, method, param_names,
returns):
self._http_client = http_client
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
@ -82,7 +78,7 @@ class _SCPDCommand(object):
('Content-Type', 'text/xml'),
('Content-Length', len(soap_body))
))
log.debug("sending POST to %s\nheaders: %s\nbody:%s\n", self.control_url, headers, soap_body)
log.debug("send POST to %s\nheaders: %s\nbody:%s\n", self.control_url, headers, soap_body)
try:
response = yield self._http_client.request(
POST, url=self.control_url, data=soap_body, headers=headers
@ -90,7 +86,7 @@ class _SCPDCommand(object):
except Exception as err:
log.error("error (%s) sending POST to %s\nheaders: %s\nbody:%s\n", err, self.control_url, headers,
soap_body)
raise UPnPError().with_traceback(err.__traceback__)
raise UPnPError(err)
xml_response = yield response.content()
try:
@ -150,7 +146,7 @@ class SCPDCommandRunner(object):
self._reactor = reactor
self._agent = Agent(reactor, connectTimeout=1)
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
self._connection_pool = HTTPConnectionPool(reactor)
self._connection_pool = DirtyPool(reactor)
@defer.inlineCallbacks
def _discover_commands(self, service):
@ -195,10 +191,9 @@ class SCPDCommandRunner(object):
def _patch_command(self, action_info, service_type):
name, inputs, outputs = self._soap_function_info(action_info)
command = _SCPDCommand(self._gateway.base_address, self._gateway.port,
command = _SCPDCommand(self._http_client, self._gateway.base_address, self._gateway.port,
self._gateway.base_address + self._gateway.get_service(service_type).controlURL.encode(),
self._gateway.get_service(service_type).serviceId.encode(), name, inputs, outputs,
self._reactor, self._connection_pool, self._agent, self._http_client)
self._gateway.get_service(service_type).serviceId.encode(), name, inputs, outputs)
current = getattr(self, command.method)
if hasattr(current, "_return_types"):
command._process_result = _return_types(*current._return_types)(command._process_result)