fix more headers

This commit is contained in:
Jack Robison 2018-10-04 18:58:56 -04:00
parent 6a0d71e891
commit 2bd4b5e3e6
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 255 additions and 213 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -101,13 +101,17 @@ class TestSCPD(TestDevice):
self.assertEqual(result, expected) self.assertEqual(result, expected)
class TestDDWRT(unittest.TestCase):
manufacturer, model = "DD-WRT", "router"
class TestDDWRTSSDP(TestSSDP): class TestDDWRTSSDP(TestSSDP):
manufacturer, model = "DD-WRT", "router" manufacturer, model = "DD-WRT", "router"
class TestDDWRTSCPD(TestSCPD): class TestDDWRTSCPD(TestSCPD):
manufacturer, model = "DD-WRT", "router" manufacturer, model = "DD-WRT", "router"
class TestMiniUPnPMiniUPnPd(TestSSDP):
manufacturer, model = "MiniUPnP", "MiniUPnPd"
class TestMiniUPnPMiniUPnPdSCPD(TestSCPD):
manufacturer, model = "MiniUPnP", "MiniUPnPd"

View file

@ -22,6 +22,8 @@ XML_ROOT_SANITY_PATTERN = re.compile(
def parse_service_description(content: bytes): def parse_service_description(content: bytes):
if not content:
return []
element_dict = etree_to_dict(ElementTree.fromstring(content.decode())) element_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
service_info = flatten_keys(element_dict, "{%s}" % SERVICE) service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
if "scpd" not in service_info: if "scpd" not in service_info:
@ -97,11 +99,15 @@ class SCPDHTTPClientFactory(ClientFactory):
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % ( '<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (
XML_VERSION, command.method, command.service_id.decode(), XML_VERSION, command.method, command.service_id.decode(),
args, command.method)) args, command.method))
if "http://" in command.gateway_address.decode():
host = command.gateway_address.decode().split("http://")[1]
else:
host = command.gateway_address.decode()
data = ( data = (
( (
'POST %s HTTP/1.1\r\n' 'POST %s HTTP/1.1\r\n'
'Host: %s\r\n' 'Host: %s\r\n'
'User-Agent: Debian/buster/sid, UPnP/1.0, MiniUPnPc/1.9\r\n' 'User-Agent: python3/txupnp, UPnP/1.0, MiniUPnPc/1.9\r\n'
'Content-Length: %i\r\n' 'Content-Length: %i\r\n'
'Content-Type: text/xml\r\n' 'Content-Type: text/xml\r\n'
'SOAPAction: \"%s#%s\"\r\n' 'SOAPAction: \"%s#%s\"\r\n'
@ -112,7 +118,7 @@ class SCPDHTTPClientFactory(ClientFactory):
'\r\n' '\r\n'
) % ( ) % (
command.control_url.decode(), # could be just / even if it shouldn't be command.control_url.decode(), # could be just / even if it shouldn't be
command.gateway_address.decode().split("http://")[1], host,
len(soap_body), len(soap_body),
command.service_id.decode(), # maybe no quotes command.service_id.decode(), # maybe no quotes
command.method, command.method,
@ -123,13 +129,21 @@ class SCPDHTTPClientFactory(ClientFactory):
@classmethod @classmethod
def get(cls, reactor, control_url: str, address: str): def get(cls, reactor, control_url: str, address: str):
if "http://" in address:
host = address.split("http://")[1]
else:
host = address
if ":" in host:
host = host.split(":")[0]
if not control_url.startswith("/"):
control_url = "/" + control_url
data = ( data = (
( (
'GET %s HTTP/1.1\r\n' 'GET %s HTTP/1.1\r\n'
'Accept-Encoding: gzip\r\n' 'Accept-Encoding: gzip\r\n'
'Host: %s\r\n' 'Host: %s\r\n'
'\r\n' '\r\n'
) % (control_url, address.split("http://")[1]) ) % (control_url, host)
).encode() ).encode()
return cls(reactor, data) return cls(reactor, data)