From d6695c79253e83e680d4e1307cccfd342268ece8 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Sun, 29 Jul 2018 17:32:14 -0400 Subject: [PATCH] better ssdp datagram parsing --- txupnp/gateway.py | 57 +++++- txupnp/scpd.py | 6 +- txupnp/soap.py | 6 +- txupnp/ssdp.py | 346 ++++++++++++++++++++++++------------ txupnp/tests/test_txupnp.py | 26 ++- txupnp/upnp.py | 8 +- 6 files changed, 310 insertions(+), 139 deletions(-) diff --git a/txupnp/gateway.py b/txupnp/gateway.py index 1597b73..bdac70d 100644 --- a/txupnp/gateway.py +++ b/txupnp/gateway.py @@ -1,4 +1,4 @@ -import binascii +import json import logging from twisted.internet import defer import treq @@ -19,6 +19,15 @@ class Service(object): self.subscribe_path = eventSubURL self.scpd_path = SCPDURL + def get_info(self): + return { + "service_type": self.service_type, + "service_id": self.service_id, + "control_path": self.control_path, + "subscribe_path": self.subscribe_path, + "scpd_path": self.scpd_path + } + class Device(object): def __init__(self, _root_device, deviceType=None, friendlyName=None, manufacturer=None, manufacturerURL=None, @@ -45,6 +54,17 @@ class Device(object): devices = [Device(self._root_device, **deviceList[k]) for k in deviceList] self._root_device.devices.extend(devices) + def get_info(self): + return { + 'device_type': self.device_type, + 'friendly_name': self.friendly_name, + 'manufacturers': self.manufacturer, + 'model_name': self.model_name, + 'model_number': self.model_number, + 'serial_number': self.serial_number, + 'udn': self.udn + } + class RootDevice(object): def __init__(self, xml_string): @@ -61,7 +81,7 @@ class RootDevice(object): if root: root_device = Device(self, **(root["device"])) self.devices.append(root_device) - log.info("finished setting up root device. %i devices and %i services", len(self.devices), len(self.services)) + log.info("finished setting up root gateway. %i devices and %i services", len(self.devices), len(self.services)) class Gateway(object): @@ -77,14 +97,41 @@ class Gateway(object): self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) self._device = None + def debug_device(self): + def default_byte(x): + if isinstance(x, bytes): + return x.decode() + return x + + devices = [] + for device in self._device.devices: + info = device.get_info() + devices.append(info) + services = [] + for service in self._device.services: + info = service.get_info() + services.append(info) + return json.dumps({ + 'root_url': self.base_address, + 'gateway_xml_url': self.location, + 'usn': self.usn, + 'devices': devices, + 'services': services + }, indent=2, default=default_byte) + @defer.inlineCallbacks def discover_services(self): log.info("querying %s", self.location) response = yield treq.get(self.location) response_xml = yield response.content() - self._device = RootDevice(response_xml) - if not self._device.devices or not self._device.services: - log.error("failed to parse device: \n%s", response_xml) + if not response_xml: + log.error("service sent an empty reply\n%s", self.debug_device()) + try: + self._device = RootDevice(response_xml) + except Exception as err: + log.error("error parsing gateway: %s\n%s\n\n%s", err, self.debug_device(), response_xml) + self._device = RootDevice("") + log.debug("finished setting up gateway:\n%s", self.debug_device()) @property def services(self): diff --git a/txupnp/scpd.py b/txupnp/scpd.py index d420071..b53cdfb 100644 --- a/txupnp/scpd.py +++ b/txupnp/scpd.py @@ -87,7 +87,7 @@ class _SCPDCommand(object): xml_response = yield response.content() response = self.extract_response(self.extract_body(xml_response)) if not response: - log.error("empty response to %s\n%s", self.method, xml_response) + log.debug("empty response to %s\n%s", self.method, xml_response) defer.returnValue(response) @staticmethod @@ -155,7 +155,7 @@ class SCPDCommandRunner(object): @staticmethod def _soap_function_info(action_dict): if not action_dict.get('argumentList'): - log.warning("don't know how to handle argument list: %s", action_dict) + log.debug("don't know how to handle argument list: %s", action_dict) return ( action_dict['name'], [], @@ -185,7 +185,7 @@ class SCPDCommandRunner(object): command._process_result = _return_types(*current._return_types)(command._process_result) setattr(command, "__doc__", current.__doc__) setattr(self, command.method, command) - log.info("registered %s %s", service_type, action_info['name']) + log.debug("registered %s %s", service_type, action_info['name']) def _register_command(self, action_info, service_type): try: diff --git a/txupnp/soap.py b/txupnp/soap.py index 44abb97..193700c 100644 --- a/txupnp/soap.py +++ b/txupnp/soap.py @@ -14,14 +14,14 @@ class SOAPServiceManager(object): def __init__(self, reactor): self._reactor = reactor self.iface_name, self.router_ip, self.lan_address = get_lan_info() - self.sspd_factory = SSDPFactory(self.lan_address, self._reactor) + self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip) self._command_runners = {} self._selected_runner = GATEWAY_SCHEMA @defer.inlineCallbacks - def discover_services(self, address=None, ttl=30, max_devices=2): + def discover_services(self, address=None, timeout=30, max_devices=1): server_infos = yield self.sspd_factory.m_search( - address or self.router_ip, ttl=ttl, max_devices=max_devices + address or self.router_ip, timeout=timeout, max_devices=max_devices ) locations = [] for server_info in server_infos: diff --git a/txupnp/ssdp.py b/txupnp/ssdp.py index 67acd31..386b713 100644 --- a/txupnp/ssdp.py +++ b/txupnp/ssdp.py @@ -1,67 +1,188 @@ import logging import binascii +import re from twisted.internet import defer from twisted.internet.protocol import DatagramProtocol from txupnp.fault import UPnPError -from txupnp.constants import GATEWAY_SCHEMA, M_SEARCH_TEMPLATE, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, SSDP_ALL +from txupnp.constants import GATEWAY_SCHEMA, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, SSDP_ALL log = logging.getLogger(__name__) -def parse_http_fields(content_lines): - def flatten(s, lower=True): - r = s.rstrip(":").rstrip(" ").lstrip(" ").replace("-", "_") - if lower: - return r.lower() - return r - - result = {} - for l in content_lines: - split = l.decode().split(":") - if split and split[0]: - k = split[0] - v = ":".join(split[1:]) - result[flatten(k)] = flatten(v, lower=False) - return result - - # - # return { - # (k.lower().rstrip(":".encode()).replace("-".encode(), "_".encode())).decode(): v.decode() - # for k, v in { - # l.split(": ".encode())[0]: "".encode().join(l.split(": ".encode())[1:]) - # - # }.items() if k - # } +SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT) +SSDP_BYEBYE = "ssdp:byebye" +SSDP_UPDATE = "ssdp:update" +SSDP_ROOT_DEVICE = "upnp:rootdevice" +line_separator = "\r\n" -def parse_ssdp_request(operation, port, protocol, content_lines): - if operation != "NOTIFY".encode(): - log.warning("unsupported operation: %s", operation) - raise UPnPError("unsupported operation: %s" % operation) - if port != "*".encode(): - log.warning("unexpected port: %s", port) - raise UPnPError("unexpected port: %s" % port) - return parse_http_fields(content_lines) +class SSDPDatagram(object): + _M_SEARCH = "M-SEARCH" + _NOTIFY = "NOTIFY" + _OK = "OK" + _start_lines = { + _M_SEARCH: "M-SEARCH * HTTP/1.1", + _NOTIFY: "NOTIFY * HTTP/1.1", + _OK: "HTTP/1.1 200 OK" + } -def parse_ssdp_response(code, response, content_lines): - try: - if int(code) != 200: - raise UPnPError("unexpected http response code: %i" % int(code)) - except ValueError: - log.error(response) - raise UPnPError("unexpected http response code: %s" % code) - if response != "OK".encode(): - raise UPnPError("unexpected response: %s" % response) - return parse_http_fields(content_lines) + _friendly_names = { + _M_SEARCH: "m-search", + _NOTIFY: "notify", + _OK: "m-search response" + } + + _vendor_field_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$") + + _patterns = { + 'host': (re.compile("^(?i)(host):(.*)$"), str), + 'st': (re.compile("^(?i)(st):(.*)$"), str), + 'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str), + 'mx': (re.compile("^(?i)(mx):(.*)$"), int), + 'nt': (re.compile("^(?i)(nt):(.*)$"), str), + 'nts': (re.compile("^(?i)(nts):(.*)$"), str), + 'usn': (re.compile("^(?i)(usn):(.*)$"), str), + 'location': (re.compile("^(?i)(location):(.*)$"), str), + 'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str), + 'server': (re.compile("^(?i)(server):(.*)$"), str), + } + + _required_fields = { + _M_SEARCH: [ + 'host', + 'st', + 'man', + 'mx', + ], + _OK: [ + 'cache_control', + # 'date', + # 'ext', + 'location', + 'server', + 'st', + 'usn' + ] + } + + _marshallers = { + 'mx': str, + 'man': lambda x: ("\"%s\"" % x) + } + + def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None, + cache_control=None, server=None, date=None, ext=None, **kwargs): + if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]: + raise UPnPError("unknown packet type: {}".format(packet_type)) + self._packet_type = packet_type + self.host = host + self.st = st + self.man = man + self.mx = mx + self.nt = nt + self.nts = nts + self.usn = usn + self.location = location + self.cache_control = cache_control + self.server = server + self.date = date + self.ext = ext + for k, v in kwargs.items(): + if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None: + setattr(self, k.lower(), v) + + def __getitem__(self, item): + for i in self._required_fields[self._packet_type]: + if i.lower() == item.lower(): + return getattr(self, i) + raise KeyError(item) + + def get_friendly_name(self): + return self._friendly_names[self._packet_type] + + def encode(self, trailing_newlines=2): + lines = [self._start_lines[self._packet_type]] + for attr_name in self._required_fields[self._packet_type]: + attr = getattr(self, attr_name) + if attr is None: + raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name)) + if attr_name in self._marshallers: + value = self._marshallers[attr_name](attr) + else: + value = attr + lines.append("{}: {}".format(attr_name.upper(), value)) + serialized = line_separator.join(lines) + for _ in range(trailing_newlines): + serialized += line_separator + return serialized + + def as_dict(self): + return self._lines_to_content_dict(self.encode().split(line_separator)) + + @classmethod + def decode(cls, datagram): + packet = cls._from_string(datagram.decode()) + for attr_name in packet._required_fields[packet._packet_type]: + attr = getattr(packet, attr_name) + if attr is None: + raise UPnPError( + "required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name) + ) + return packet + + @classmethod + def _lines_to_content_dict(cls, lines): + result = {} + for line in lines: + if not line: + continue + matched = False + for name, (pattern, field_type) in cls._patterns.items(): + if name not in result and pattern.findall(line): + match = pattern.findall(line)[-1][-1] + result[name] = field_type(match.lstrip(" ").rstrip(" ")) + matched = True + break + if not matched: + if cls._vendor_field_pattern.findall(line): + match = cls._vendor_field_pattern.findall(line)[-1] + vendor_key = match[0].lstrip(" ").rstrip(" ") + # vendor_domain = match[1].lstrip(" ").rstrip(" ") + value = match[2].lstrip(" ").rstrip(" ") + if vendor_key not in result: + result[vendor_key] = value + return result + + @classmethod + def _from_string(cls, datagram): + lines = [l for l in datagram.split(line_separator) if l] + if lines[0] == cls._start_lines[cls._M_SEARCH]: + return cls._from_request(lines[1:]) + if lines[0] == cls._start_lines[cls._NOTIFY]: + return cls._from_notify(lines[1:]) + if lines[0] == cls._start_lines[cls._OK]: + return cls._from_response(lines[1:]) + + @classmethod + def _from_response(cls, lines): + return cls(cls._OK, **cls._lines_to_content_dict(lines)) + + @classmethod + def _from_notify(cls, lines): + return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines)) + + @classmethod + def _from_request(cls, lines): + return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines)) class SSDPProtocol(DatagramProtocol): - def __init__(self, reactor, finished_deferred, iface, router, ssdp_address=SSDP_IP_ADDRESS, + def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS, ssdp_port=SSDP_PORT, ttl=1, max_devices=None): self._reactor = reactor self._sem = defer.DeferredSemaphore(1) - self.finished_deferred = finished_deferred + self.discover_callbacks = {} self.iface = iface self.router = router self.ssdp_address = ssdp_address @@ -72,93 +193,81 @@ class SSDPProtocol(DatagramProtocol): self.devices = [] def startProtocol(self): - return self._sem.run(self.do_start) - - def send_m_search(self): - data = M_SEARCH_TEMPLATE.format(self.ssdp_address, self.ssdp_port, GATEWAY_SCHEMA, SSDP_ALL, self.ttl) - try: - log.info("sending m-search (%i bytes) to %s:%i", len(data), self.ssdp_address, self.ssdp_port) - self.transport.write(data.encode(), (self.ssdp_address, self.ssdp_port)) - except Exception as err: - log.exception("failed to write %s to %s:%i", binascii.hexlify(data), self.ssdp_address, self.ssdp_port) - raise err - - def parse_ssdp_datagram(self, datagram): - lines = datagram.split("\r\n".encode()) - header_pieces = lines[0].split(" ".encode()) - protocols = { - "HTTP/1.1".encode() - } - operations = { - "M-SEARCH".encode(), - "NOTIFY".encode() - } - if header_pieces[0] in operations: - if header_pieces[2] not in protocols: - raise UPnPError("unknown protocol: %s" % header_pieces[2]) - return parse_ssdp_request(header_pieces[0], header_pieces[1], header_pieces[2], lines[1:]) - if header_pieces[0] in protocols: - parsed = parse_ssdp_response(header_pieces[1], header_pieces[2], lines[1:]) - log.info("received reply (%i bytes) to SSDP request (%f) (%s) %s", len(datagram), - self._reactor.seconds() - self._start, parsed['location'], parsed['server']) - return parsed - raise UPnPError("don't know how to decode datagram: %s" % binascii.hexlify(datagram)) - - def do_start(self): self._start = self._reactor.seconds() - self.finished_deferred.addTimeout(self.ttl, self._reactor) self.transport.setTTL(self.ttl) self.transport.joinGroup(self.ssdp_address, interface=self.iface) - self.send_m_search() + + for st in [SSDP_ALL, SSDP_ROOT_DEVICE, GATEWAY_SCHEMA, GATEWAY_SCHEMA.lower()]: + self.send_m_search(service=st) + + def send_m_search(self, service=GATEWAY_SCHEMA): + packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1) + log.debug("writing packet:\n%s", packet.encode()) + log.info("sending m-search (%i bytes) to %s:%i", len(packet.encode()), self.ssdp_address, self.ssdp_port) + try: + self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port)) + except Exception as err: + log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port) + raise err def leave_group(self): self.transport.leaveGroup(self.ssdp_address, interface=self.iface) - def datagramReceived(self, datagram, addr): - self._sem.run(self.handle_datagram, datagram, addr) - - def handle_datagram(self, datagram, address): - if address[0] == self.router: - try: - parsed = self.parse_ssdp_datagram(datagram) - self.devices.append(parsed) - log.info("found %i/%s so far", len(self.devices), self.max_devices) - if not self.finished_deferred.called: - if not self.max_devices or (self.max_devices and len(self.devices) >= self.max_devices): - self._sem.run(self.finished_deferred.callback, self.devices) - except UPnPError as err: - log.error("error decoding SSDP response from %s:%s (error: %s)\n%s", address[0], address[1], str(err), binascii.hexlify(datagram)) - raise err - elif address[0] != self.iface: - log.info("received %i bytes from %s:%i\n%s", len(datagram), address[0], address[1], binascii.hexlify(datagram)) + def datagramReceived(self, datagram, address): + if address[0] == self.iface: + return + try: + packet = SSDPDatagram.decode(datagram) + log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode()) + except Exception: + log.exception("failed to decode: %s", binascii.hexlify(datagram)) + return + if packet._packet_type == packet._OK: + log.info("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location) else: - pass # loopback + log.info("%s:%i notified us of a service type: %s", address[0], address[1], packet.st) + if packet.st not in map(lambda p: p['st'], self.devices): + self.devices.append(packet.as_dict()) + log.info("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s") + if address[0] in self.discover_callbacks: + self._sem.run(self.discover_callbacks[address[0]][0], packet) + + +def gather(finished_deferred, max_results): + results = [] + + def discover_cb(packet): + if not finished_deferred.called: + results.append(packet.as_dict()) + if len(results) >= max_results: + finished_deferred.callback(results) + + return discover_cb class SSDPFactory(object): - def __init__(self, lan_address, reactor): + def __init__(self, reactor, lan_address, router_address): self.lan_address = lan_address + self.router_address = router_address self._reactor = reactor - self.protocol = None + self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address) self.port = None - self.finished_deferred = defer.Deferred() - def stop(self): - try: + def disconnect(self): + if self.protocol: self.protocol.leave_group() - self.port.stopListening() - except: - pass + self.protocol = None + if not self.port: + return + self.port.stopListening() + self.port = None - def connect(self, address, ttl, max_devices=1): - self.protocol = SSDPProtocol(self._reactor, self.finished_deferred, self.lan_address, address, ttl=ttl, - max_devices=max_devices) + def connect(self): + self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect) self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True) - self._reactor.addSystemEventTrigger("before", "shutdown", self.stop) - return self.finished_deferred @defer.inlineCallbacks - def m_search(self, address, ttl=30, max_devices=2): + def m_search(self, address, timeout=30, max_devices=1): """ Perform a HTTP over UDP M-SEARCH query @@ -170,11 +279,18 @@ class SSDPFactory(object): 'usn': }, ...] """ - d = self.connect(address, ttl, max_devices=max_devices) + + self.connect() + + if address in self.protocol.discover_callbacks: + d = self.protocol.discover_callbacks[address][1] + else: + d = defer.Deferred() + d.addTimeout(timeout, self._reactor) + found_cb = gather(d, max_devices) + self.protocol.discover_callbacks[address] = found_cb, d try: server_infos = yield d except defer.TimeoutError: server_infos = self.protocol.devices - log.info("found %i devices", len(server_infos)) - self.stop() defer.returnValue(server_infos) diff --git a/txupnp/tests/test_txupnp.py b/txupnp/tests/test_txupnp.py index ff5a1c0..cdfd340 100644 --- a/txupnp/tests/test_txupnp.py +++ b/txupnp/tests/test_txupnp.py @@ -1,3 +1,4 @@ +import sys import logging from twisted.internet import reactor, defer from txupnp.upnp import UPnP @@ -7,10 +8,12 @@ log = logging.getLogger("txupnp") @defer.inlineCallbacks -def test(ext_port=4446, int_port=4446, proto='UDP'): +def test(ext_port=4446, int_port=4446, proto='UDP', timeout=1): u = UPnP(reactor) - found = yield u.discover() - assert found, "M-SEARCH failed to find gateway" + found = yield u.discover(timeout=timeout) + if not found: + print("failed to find gateway") + defer.returnValue(None) external_ip = yield u.get_external_ip() assert external_ip, "Failed to get the external IP" log.info(external_ip) @@ -45,17 +48,22 @@ def test(ext_port=4446, int_port=4446, proto='UDP'): @defer.inlineCallbacks -def run_tests(): - for p in ['UDP', 'TCP']: - yield test(proto=p) +def run_tests(timeout=1): + for p in ['UDP']: + yield test(proto=p, timeout=timeout) -def main(): - d = run_tests() +def main(timeout): + d = run_tests(timeout) d.addErrback(log.exception) d.addBoth(lambda _: reactor.callLater(0, reactor.stop)) reactor.run() if __name__ == "__main__": - main() + if len(sys.argv) > 1: + log.setLevel(logging.DEBUG) + timeout = int(sys.argv[1]) + else: + timeout = 1 + main(timeout) diff --git a/txupnp/upnp.py b/txupnp/upnp.py index faf85e5..7439f57 100644 --- a/txupnp/upnp.py +++ b/txupnp/upnp.py @@ -19,7 +19,7 @@ class UPnP(object): def commands(self): return self.soap_manager.get_runner() - def m_search(self, address, ttl=30, max_devices=2): + def m_search(self, address, timeout=30, max_devices=2): """ Perform a HTTP over UDP M-SEARCH query @@ -31,12 +31,12 @@ class UPnP(object): 'usn': }, ...] """ - return self.soap_manager.sspd_factory.m_search(address, ttl=ttl, max_devices=max_devices) + return self.soap_manager.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices) @defer.inlineCallbacks - def discover(self, ttl=30, max_devices=2): + def discover(self, timeout=1, max_devices=1): try: - yield self.soap_manager.discover_services(ttl=ttl, max_devices=max_devices) + yield self.soap_manager.discover_services(timeout=timeout, max_devices=max_devices) except defer.TimeoutError: log.warning("failed to find upnp gateway") defer.returnValue(False)