This commit is contained in:
Jack Robison 2018-07-30 17:48:20 -04:00
parent 97a05b72dd
commit bd1c76237c
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
15 changed files with 455 additions and 345 deletions

15
LICENSE Normal file
View file

@ -0,0 +1,15 @@
The MIT License (MIT)
Copyright (c) 2015-2018 LBRY Inc
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the
following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

19
README.md Normal file
View file

@ -0,0 +1,19 @@
# UPnP for Twisted
`txupnp` is a python2/3 library to interact with UPnP gateways using Twisted
## Installation
`pip install txupnp`
## Usage
To run the test script, `test-txupnp`. This will attempt to find the gateway and will try to set up and tear down an external redirect.
## License
This project is MIT licensed. For the full license, see [LICENSE](LICENSE).
## Contact
The primary contact for this project is @jackrobison(jackrobison@lbry.io)

View file

@ -1,16 +1,24 @@
import os
from setuptools import setup, find_packages from setuptools import setup, find_packages
from txupnp import __version__, __name__, __email__, __author__, __license__
console_scripts = [ console_scripts = [
'test-txupnp = txupnp.tests.test_txupnp:main', 'txupnp-cli = txupnp.cli:main',
] ]
package_name = "txupnp"
base_dir = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(base_dir, 'README.md'), 'rb') as f:
long_description = f.read().decode('utf-8')
setup( setup(
name="txupnp", name=__name__,
version="0.0.1", version=__version__,
author="Jack Robison", author=__author__,
author_email="jackrobison@lbry.io", author_email=__email__,
description="UPnP for twisted", description="UPnP for twisted",
license='MIT', long_description=long_description,
license=__license__,
packages=find_packages(), packages=find_packages(),
entry_points={'console_scripts': console_scripts}, entry_points={'console_scripts': console_scripts},
install_requires=[ install_requires=[

View file

@ -1,7 +1,12 @@
__version__ = "0.0.1rc1"
__name__ = "txupnp"
__author__ = "Jack Robison"
__maintainer__ = "Jack Robison"
__license__ = "MIT"
__email__ = "jackrobison@lbry.io"
import logging import logging
# from twisted.python import log
# observer = log.PythonLoggingObserver(loggerName=__name__)
# observer.start()
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s')) handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))

49
txupnp/cli.py Normal file
View file

@ -0,0 +1,49 @@
import sys
import argparse
import logging
from twisted.internet import reactor, defer
from txupnp.upnp import UPnP
log = logging.getLogger("txupnp")
@defer.inlineCallbacks
def run_command(found, u, command):
if not found:
print("failed to find gateway")
reactor.callLater(0, reactor.stop)
return
if command == "debug":
external_ip = yield u.get_external_ip()
print(u.get_debug_info())
print("external ip: ", external_ip)
if command == "list_mappings":
redirects = yield u.get_redirects()
print("found {} redirects".format(len(redirects)))
for redirect in redirects:
print("\t", redirect)
def main():
parser = argparse.ArgumentParser(description="upnp command line utility")
parser.add_argument("--command", dest="command", type=str, help="debug | list_mappings", default="debug")
parser.add_argument("--debug", dest="debug", default=False, action="store_true")
args = parser.parse_args()
if args.debug:
from twisted.python import log as tx_log
observer = tx_log.PythonLoggingObserver(loggerName="txupnp")
observer.start()
log.setLevel(logging.DEBUG)
command = args.command
if command not in ['debug', 'list_mappings']:
return sys.exit(0)
u = UPnP(reactor)
d = u.discover()
d.addCallback(run_command, u, command)
d.addBoth(lambda _: reactor.callLater(0, reactor.stop))
reactor.run()
if __name__ == "__main__":
main()

View file

@ -5,8 +5,8 @@ XML_VERSION = "<?xml version=\"1.0\"?>"
FAULT = "{http://schemas.xmlsoap.org/soap/envelope/}Fault" FAULT = "{http://schemas.xmlsoap.org/soap/envelope/}Fault"
ENVELOPE = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope" ENVELOPE = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope"
BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body" BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body"
SOAP_ENCODING = "http://schemas.xmlsoap.org/soap/encoding/"
SOAP_ENVELOPE = "http://schemas.xmlsoap.org/soap/envelope"
CONTROL = 'urn:schemas-upnp-org:control-1-0' CONTROL = 'urn:schemas-upnp-org:control-1-0'
SERVICE = 'urn:schemas-upnp-org:service-1-0' SERVICE = 'urn:schemas-upnp-org:service-1-0'
DEVICE = 'urn:schemas-upnp-org:device-1-0' DEVICE = 'urn:schemas-upnp-org:device-1-0'
@ -20,17 +20,17 @@ service_types = [
WAN_SCHEMA, WAN_SCHEMA,
LAYER_SCHEMA, LAYER_SCHEMA,
IP_SCHEMA, IP_SCHEMA,
CONTROL,
SERVICE,
DEVICE,
] ]
SSDP_IP_ADDRESS = '239.255.255.250' SSDP_IP_ADDRESS = '239.255.255.250'
SSDP_PORT = 1900 SSDP_PORT = 1900
SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
SSDP_DISCOVER = "ssdp:discover" SSDP_DISCOVER = "ssdp:discover"
SSDP_ALL = "ssdp:all" SSDP_ALL = "ssdp:all"
SSDP_BYEBYE = "ssdp:byebye"
M_SEARCH_TEMPLATE = "\r\n".join([ SSDP_UPDATE = "ssdp:update"
"M-SEARCH * HTTP/1.1", SSDP_ROOT_DEVICE = "upnp:rootdevice"
"HOST: {}:{}", line_separator = "\r\n"
"ST: {}",
"MAN: \"{}\"",
"MX: {}\r\n\r\n",
])

View file

@ -1,4 +1,3 @@
import json
import logging import logging
from twisted.internet import defer from twisted.internet import defer
import treq import treq
@ -81,7 +80,7 @@ class RootDevice(object):
if root: if root:
root_device = Device(self, **(root["device"])) root_device = Device(self, **(root["device"]))
self.devices.append(root_device) self.devices.append(root_device)
log.info("finished setting up root gateway. %i devices and %i services", len(self.devices), len(self.services)) log.debug("finished setting up root gateway. %i devices and %i services", len(self.devices), len(self.services))
class Gateway(object): class Gateway(object):
@ -98,11 +97,6 @@ class Gateway(object):
self._device = None self._device = None
def debug_device(self): def debug_device(self):
def default_byte(x):
if isinstance(x, bytes):
return x.decode()
return x
devices = [] devices = []
for device in self._device.devices: for device in self._device.devices:
info = device.get_info() info = device.get_info()
@ -111,17 +105,17 @@ class Gateway(object):
for service in self._device.services: for service in self._device.services:
info = service.get_info() info = service.get_info()
services.append(info) services.append(info)
return json.dumps({ return {
'root_url': self.base_address, 'root_url': self.base_address,
'gateway_xml_url': self.location, 'gateway_xml_url': self.location,
'usn': self.usn, 'usn': self.usn,
'devices': devices, 'devices': devices,
'services': services 'services': services
}, indent=2, default=default_byte) }
@defer.inlineCallbacks @defer.inlineCallbacks
def discover_services(self): def discover_services(self):
log.info("querying %s", self.location) log.debug("querying %s", self.location)
response = yield treq.get(self.location) response = yield treq.get(self.location)
response_xml = yield response.content() response_xml = yield response.content()
if not response_xml: if not response_xml:

View file

@ -1,4 +1,3 @@
import json
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from twisted.internet import defer from twisted.internet import defer
@ -8,7 +7,8 @@ from treq.client import HTTPClient
from xml.etree import ElementTree from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none
from txupnp.fault import handle_fault, UPnPError from txupnp.fault import handle_fault, UPnPError
from txupnp.constants import POST, ENVELOPE, BODY, XML_VERSION, IP_SCHEMA, SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types from txupnp.constants import IP_SCHEMA, SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
from txupnp.constants import BODY, POST
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -40,13 +40,13 @@ def get_soap_body(service_name, method, param_names, **kwargs):
class _SCPDCommand(object): class _SCPDCommand(object):
def __init__(self, gateway_address, service_port, control_url, service_id, method, param_names, returns, def __init__(self, gateway_address, service_port, control_url, service_id, method, param_names, returns,
reactor=None): reactor=None, connection_pool=None, agent=None, http_client=None):
if not reactor: if not reactor:
from twisted.internet import reactor from twisted.internet import reactor
self._reactor = reactor self._reactor = reactor
self._pool = HTTPConnectionPool(reactor) self._pool = connection_pool or HTTPConnectionPool(reactor)
self.agent = Agent(reactor, connectTimeout=1) self.agent = agent or Agent(reactor, connectTimeout=1)
self._http_client = HTTPClient(self.agent, data_to_body_producer=StringProducer) self._http_client = http_client or HTTPClient(self.agent, data_to_body_producer=StringProducer)
self.gateway_address = gateway_address self.gateway_address = gateway_address
self.service_port = service_port self.service_port = service_port
self.control_url = control_url self.control_url = control_url
@ -128,10 +128,14 @@ class SCPDResponse(object):
class SCPDCommandRunner(object): class SCPDCommandRunner(object):
def __init__(self, gateway): def __init__(self, gateway, reactor):
self._gateway = gateway self._gateway = gateway
self._unsupported_actions = {} self._unsupported_actions = {}
self._registered_commands = {} self._registered_commands = {}
self._reactor = reactor
self._agent = Agent(reactor, connectTimeout=1)
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
self._connection_pool = HTTPConnectionPool(reactor)
@defer.inlineCallbacks @defer.inlineCallbacks
def _discover_commands(self, service): def _discover_commands(self, service):
@ -175,11 +179,12 @@ class SCPDCommandRunner(object):
[i['name'] for i in arg_dicts if i['direction'] == 'out'] [i['name'] for i in arg_dicts if i['direction'] == 'out']
) )
def __register_command(self, action_info, service_type): def _patch_command(self, action_info, service_type):
func_info = self._soap_function_info(action_info) name, inputs, outputs = self._soap_function_info(action_info)
command = _SCPDCommand(self._gateway.base_address, self._gateway.port, command = _SCPDCommand(self._gateway.base_address, self._gateway.port,
self._gateway.base_address + self._gateway.get_service(service_type).control_path.encode(), self._gateway.base_address + self._gateway.get_service(service_type).control_path.encode(),
self._gateway.get_service(service_type).service_id.encode(), *func_info) self._gateway.get_service(service_type).service_id.encode(), name, inputs, outputs,
self._reactor, self._connection_pool, self._agent, self._http_client)
current = getattr(self, command.method) current = getattr(self, command.method)
if hasattr(current, "_return_types"): if hasattr(current, "_return_types"):
command._process_result = _return_types(*current._return_types)(command._process_result) command._process_result = _return_types(*current._return_types)(command._process_result)
@ -191,7 +196,7 @@ class SCPDCommandRunner(object):
def _register_command(self, action_info, service_type): def _register_command(self, action_info, service_type):
try: try:
return self.__register_command(action_info, service_type) return self._patch_command(action_info, service_type)
except Exception as err: except Exception as err:
s = self._unsupported_actions.get(service_type, []) s = self._unsupported_actions.get(service_type, [])
s.append((action_info, err)) s.append((action_info, err))
@ -199,10 +204,10 @@ class SCPDCommandRunner(object):
log.error("failed to setup command for %s\n%s", service_type, action_info) log.error("failed to setup command for %s\n%s", service_type, action_info)
def debug_commands(self): def debug_commands(self):
return json.dumps({ return {
'available': self._registered_commands, 'available': self._registered_commands,
'failed': self._unsupported_actions 'failed': self._unsupported_actions
}, indent=2) }
@staticmethod @staticmethod
@return_types(none) @return_types(none)

View file

@ -29,7 +29,7 @@ class SOAPServiceManager(object):
locations.append(server_info['location']) locations.append(server_info['location'])
gateway = Gateway(**server_info) gateway = Gateway(**server_info)
yield gateway.discover_services() yield gateway.discover_services()
command_runner = SCPDCommandRunner(gateway) command_runner = SCPDCommandRunner(gateway, self._reactor)
yield command_runner.discover_commands() yield command_runner.discover_commands()
self._command_runners[gateway.urn.decode()] = command_runner self._command_runners[gateway.urn.decode()] = command_runner
elif 'st' not in server_info: elif 'st' not in server_info:
@ -51,3 +51,12 @@ class SOAPServiceManager(object):
def get_available_runners(self): def get_available_runners(self):
return self._command_runners.keys() return self._command_runners.keys()
def debug(self):
results = []
for runner in self._command_runners.values():
gateway = runner._gateway
info = gateway.debug_device()
info.update(runner.debug_commands())
results.append(info)
return results

View file

@ -1,182 +1,15 @@
import logging import logging
import binascii import binascii
import re
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.protocol import DatagramProtocol from twisted.internet.protocol import DatagramProtocol
from txupnp.constants import GATEWAY_SCHEMA, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, service_types
from txupnp.constants import SSDP_HOST
from txupnp.fault import UPnPError from txupnp.fault import UPnPError
from txupnp.constants import GATEWAY_SCHEMA, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, SSDP_ALL from txupnp.ssdp_datagram import SSDPDatagram
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
SSDP_BYEBYE = "ssdp:byebye"
SSDP_UPDATE = "ssdp:update"
SSDP_ROOT_DEVICE = "upnp:rootdevice"
line_separator = "\r\n"
class SSDPDatagram(object):
_M_SEARCH = "M-SEARCH"
_NOTIFY = "NOTIFY"
_OK = "OK"
_start_lines = {
_M_SEARCH: "M-SEARCH * HTTP/1.1",
_NOTIFY: "NOTIFY * HTTP/1.1",
_OK: "HTTP/1.1 200 OK"
}
_friendly_names = {
_M_SEARCH: "m-search",
_NOTIFY: "notify",
_OK: "m-search response"
}
_vendor_field_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
_patterns = {
'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str),
'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str),
'mx': (re.compile("^(?i)(mx):(.*)$"), int),
'nt': (re.compile("^(?i)(nt):(.*)$"), str),
'nts': (re.compile("^(?i)(nts):(.*)$"), str),
'usn': (re.compile("^(?i)(usn):(.*)$"), str),
'location': (re.compile("^(?i)(location):(.*)$"), str),
'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str),
'server': (re.compile("^(?i)(server):(.*)$"), str),
}
_required_fields = {
_M_SEARCH: [
'host',
'st',
'man',
'mx',
],
_OK: [
'cache_control',
# 'date',
# 'ext',
'location',
'server',
'st',
'usn'
]
}
_marshallers = {
'mx': str,
'man': lambda x: ("\"%s\"" % x)
}
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
cache_control=None, server=None, date=None, ext=None, **kwargs):
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type
self.host = host
self.st = st
self.man = man
self.mx = mx
self.nt = nt
self.nts = nts
self.usn = usn
self.location = location
self.cache_control = cache_control
self.server = server
self.date = date
self.ext = ext
for k, v in kwargs.items():
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
setattr(self, k.lower(), v)
def __getitem__(self, item):
for i in self._required_fields[self._packet_type]:
if i.lower() == item.lower():
return getattr(self, i)
raise KeyError(item)
def get_friendly_name(self):
return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines=2):
lines = [self._start_lines[self._packet_type]]
for attr_name in self._required_fields[self._packet_type]:
attr = getattr(self, attr_name)
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name in self._marshallers:
value = self._marshallers[attr_name](attr)
else:
value = attr
lines.append("{}: {}".format(attr_name.upper(), value))
serialized = line_separator.join(lines)
for _ in range(trailing_newlines):
serialized += line_separator
return serialized
def as_dict(self):
return self._lines_to_content_dict(self.encode().split(line_separator))
@classmethod
def decode(cls, datagram):
packet = cls._from_string(datagram.decode())
for attr_name in packet._required_fields[packet._packet_type]:
attr = getattr(packet, attr_name)
if attr is None:
raise UPnPError(
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
)
return packet
@classmethod
def _lines_to_content_dict(cls, lines):
result = {}
for line in lines:
if not line:
continue
matched = False
for name, (pattern, field_type) in cls._patterns.items():
if name not in result and pattern.findall(line):
match = pattern.findall(line)[-1][-1]
result[name] = field_type(match.lstrip(" ").rstrip(" "))
matched = True
break
if not matched:
if cls._vendor_field_pattern.findall(line):
match = cls._vendor_field_pattern.findall(line)[-1]
vendor_key = match[0].lstrip(" ").rstrip(" ")
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
value = match[2].lstrip(" ").rstrip(" ")
if vendor_key not in result:
result[vendor_key] = value
return result
@classmethod
def _from_string(cls, datagram):
lines = [l for l in datagram.split(line_separator) if l]
if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:])
if lines[0] == cls._start_lines[cls._NOTIFY]:
return cls._from_notify(lines[1:])
if lines[0] == cls._start_lines[cls._OK]:
return cls._from_response(lines[1:])
@classmethod
def _from_response(cls, lines):
return cls(cls._OK, **cls._lines_to_content_dict(lines))
@classmethod
def _from_notify(cls, lines):
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
@classmethod
def _from_request(cls, lines):
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))
class SSDPProtocol(DatagramProtocol): class SSDPProtocol(DatagramProtocol):
def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS, def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS,
ssdp_port=SSDP_PORT, ttl=1, max_devices=None): ssdp_port=SSDP_PORT, ttl=1, max_devices=None):
@ -192,26 +25,55 @@ class SSDPProtocol(DatagramProtocol):
self.max_devices = max_devices self.max_devices = max_devices
self.devices = [] self.devices = []
def startProtocol(self): def _send_m_search(self, service=GATEWAY_SCHEMA):
self._start = self._reactor.seconds()
self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface)
for st in [SSDP_ALL, SSDP_ROOT_DEVICE, GATEWAY_SCHEMA, GATEWAY_SCHEMA.lower()]:
self.send_m_search(service=st)
def send_m_search(self, service=GATEWAY_SCHEMA):
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1) packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
log.debug("writing packet:\n%s", packet.encode()) log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode())
log.info("sending m-search (%i bytes) to %s:%i", len(packet.encode()), self.ssdp_address, self.ssdp_port)
try: try:
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port)) self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
except Exception as err: except Exception as err:
log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port) log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port)
raise err raise err
def leave_group(self): @staticmethod
self.transport.leaveGroup(self.ssdp_address, interface=self.iface) def _gather(finished_deferred, max_results):
results = []
def discover_cb(packet):
if not finished_deferred.called:
results.append(packet.as_dict())
if len(results) >= max_results:
finished_deferred.callback(results)
return discover_cb
def m_search(self, address=None, timeout=1, max_devices=1):
address = address or self.iface
# return deferred for a pending call if we have one
if address in self.discover_callbacks:
d = self.protocol.discover_callbacks[address][1]
if not d.called: # the existing deferred has already fired, make a new one
return d
def _trap_timeout_and_return_results(err):
if err.check(defer.TimeoutError):
return self.devices
raise err
d = defer.Deferred()
d.addTimeout(timeout, self._reactor)
d.addErrback(_trap_timeout_and_return_results)
found_cb = self._gather(d, max_devices)
self.discover_callbacks[address] = found_cb, d
for st in service_types:
self._send_m_search(service=st)
return d
def startProtocol(self):
self._start = self._reactor.seconds()
self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface)
self.m_search()
def datagramReceived(self, datagram, address): def datagramReceived(self, datagram, address):
if address[0] == self.iface: if address[0] == self.iface:
@ -219,80 +81,65 @@ class SSDPProtocol(DatagramProtocol):
try: try:
packet = SSDPDatagram.decode(datagram) packet = SSDPDatagram.decode(datagram)
log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode()) log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode())
except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i: %s\npacket: %s", address[0], address[1], err,
binascii.hexlify(datagram))
return
except Exception: except Exception:
log.exception("failed to decode: %s", binascii.hexlify(datagram)) log.exception("failed to decode: %s", binascii.hexlify(datagram))
return return
if packet._packet_type == packet._OK: if packet._packet_type == packet._OK:
log.info("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location) log.debug("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
else: else:
log.info("%s:%i notified us of a service type: %s", address[0], address[1], packet.st) log.debug("%s:%i notified us of a service type: %s", address[0], address[1], packet.st)
if packet.st not in map(lambda p: p['st'], self.devices): if packet.st not in map(lambda p: p['st'], self.devices):
self.devices.append(packet.as_dict()) self.devices.append(packet.as_dict())
log.info("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s") log.debug("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s")
if address[0] in self.discover_callbacks: if address[0] in self.discover_callbacks:
self._sem.run(self.discover_callbacks[address[0]][0], packet) self._sem.run(self.discover_callbacks[address[0]][0], packet)
def gather(finished_deferred, max_results):
results = []
def discover_cb(packet):
if not finished_deferred.called:
results.append(packet.as_dict())
if len(results) >= max_results:
finished_deferred.callback(results)
return discover_cb
class SSDPFactory(object): class SSDPFactory(object):
def __init__(self, reactor, lan_address, router_address): def __init__(self, reactor, lan_address, router_address):
self.lan_address = lan_address self.lan_address = lan_address
self.router_address = router_address self.router_address = router_address
self._reactor = reactor self._reactor = reactor
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address) self.protocol = None
self.port = None self.port = None
def disconnect(self): def disconnect(self):
if self.protocol:
self.protocol.leave_group()
self.protocol = None
if not self.port: if not self.port:
return return
self.protocol.transport.leaveGroup(SSDP_IP_ADDRESS, interface=self.lan_address)
self.port.stopListening() self.port.stopListening()
self.port = None self.port = None
self.protocol = None
def connect(self): def connect(self):
self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect) if not self.protocol:
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True) self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address)
if not self.port:
self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect)
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def m_search(self, address, timeout=30, max_devices=1): def m_search(self, address, timeout=1, max_devices=1):
""" """
Perform a HTTP over UDP M-SEARCH query Perform a M-SEARCH (HTTP over UDP) and gather the results
returns (list) [{ :param address: (str) address to listen for responses from
'server: <gateway os and version string> :param timeout: (int) timeout for the query
'location': <upnp gateway url>, :param max_devices: (int) block until timeout or at least this many devices are found
'cache-control': <max age>, :param service_types: (list) M-SEARCH "ST" arguments to try, if None use the defaults
'date': <server time>, :return: (list) [ (dict) {
'usn': <usn> 'server: (str) gateway os and version
'location': (str) upnp gateway url,
'cache-control': (str) max age,
'date': (int) server time,
'usn': (str) usn
}, ...] }, ...]
""" """
self.connect() self.connect()
server_infos = yield self.protocol.m_search(address, timeout, max_devices)
if address in self.protocol.discover_callbacks:
d = self.protocol.discover_callbacks[address][1]
else:
d = defer.Deferred()
d.addTimeout(timeout, self._reactor)
found_cb = gather(d, max_devices)
self.protocol.discover_callbacks[address] = found_cb, d
for st in [SSDP_ALL, SSDP_ROOT_DEVICE, GATEWAY_SCHEMA, GATEWAY_SCHEMA.lower()]:
self.protocol.send_m_search(service=st)
try:
server_infos = yield d
except defer.TimeoutError:
server_infos = self.protocol.devices
defer.returnValue(server_infos) defer.returnValue(server_infos)

167
txupnp/ssdp_datagram.py Normal file
View file

@ -0,0 +1,167 @@
import re
import logging
from txupnp.fault import UPnPError
from txupnp.constants import line_separator
log = logging.getLogger(__name__)
class SSDPDatagram(object):
_M_SEARCH = "M-SEARCH"
_NOTIFY = "NOTIFY"
_OK = "OK"
_start_lines = {
_M_SEARCH: "M-SEARCH * HTTP/1.1",
_NOTIFY: "NOTIFY * HTTP/1.1",
_OK: "HTTP/1.1 200 OK"
}
_friendly_names = {
_M_SEARCH: "m-search",
_NOTIFY: "notify",
_OK: "m-search response"
}
_vendor_field_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
_patterns = {
'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str),
'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str),
'mx': (re.compile("^(?i)(mx):(.*)$"), int),
'nt': (re.compile("^(?i)(nt):(.*)$"), str),
'nts': (re.compile("^(?i)(nts):(.*)$"), str),
'usn': (re.compile("^(?i)(usn):(.*)$"), str),
'location': (re.compile("^(?i)(location):(.*)$"), str),
'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str),
'server': (re.compile("^(?i)(server):(.*)$"), str),
}
_required_fields = {
_M_SEARCH: [
'host',
'st',
'man',
'mx',
],
_OK: [
'cache_control',
# 'date',
# 'ext',
'location',
'server',
'st',
'usn'
]
}
_marshallers = {
'mx': str,
'man': lambda x: ("\"%s\"" % x)
}
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
cache_control=None, server=None, date=None, ext=None, **kwargs):
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type
self.host = host
self.st = st
self.man = man
self.mx = mx
self.nt = nt
self.nts = nts
self.usn = usn
self.location = location
self.cache_control = cache_control
self.server = server
self.date = date
self.ext = ext
for k, v in kwargs.items():
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
setattr(self, k.lower(), v)
def __getitem__(self, item):
for i in self._required_fields[self._packet_type]:
if i.lower() == item.lower():
return getattr(self, i)
raise KeyError(item)
def get_friendly_name(self):
return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines=2):
lines = [self._start_lines[self._packet_type]]
for attr_name in self._required_fields[self._packet_type]:
attr = getattr(self, attr_name)
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name in self._marshallers:
value = self._marshallers[attr_name](attr)
else:
value = attr
lines.append("{}: {}".format(attr_name.upper(), value))
serialized = line_separator.join(lines)
for _ in range(trailing_newlines):
serialized += line_separator
return serialized
def as_dict(self):
return self._lines_to_content_dict(self.encode().split(line_separator))
@classmethod
def decode(cls, datagram):
packet = cls._from_string(datagram.decode())
for attr_name in packet._required_fields[packet._packet_type]:
attr = getattr(packet, attr_name)
if attr is None:
raise UPnPError(
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
)
return packet
@classmethod
def _lines_to_content_dict(cls, lines):
result = {}
for line in lines:
if not line:
continue
matched = False
for name, (pattern, field_type) in cls._patterns.items():
if name not in result and pattern.findall(line):
match = pattern.findall(line)[-1][-1]
result[name] = field_type(match.lstrip(" ").rstrip(" "))
matched = True
break
if not matched:
if cls._vendor_field_pattern.findall(line):
match = cls._vendor_field_pattern.findall(line)[-1]
vendor_key = match[0].lstrip(" ").rstrip(" ")
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
value = match[2].lstrip(" ").rstrip(" ")
if vendor_key not in result:
result[vendor_key] = value
return result
@classmethod
def _from_string(cls, datagram):
lines = [l for l in datagram.split(line_separator) if l]
if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:])
if lines[0] == cls._start_lines[cls._NOTIFY]:
return cls._from_notify(lines[1:])
if lines[0] == cls._start_lines[cls._OK]:
return cls._from_response(lines[1:])
@classmethod
def _from_response(cls, lines):
return cls(cls._OK, **cls._lines_to_content_dict(lines))
@classmethod
def _from_notify(cls, lines):
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
@classmethod
def _from_request(cls, lines):
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))

View file

@ -1,69 +0,0 @@
import sys
import logging
from twisted.internet import reactor, defer
from txupnp.upnp import UPnP
from txupnp.fault import UPnPError
log = logging.getLogger("txupnp")
@defer.inlineCallbacks
def test(ext_port=4446, int_port=4446, proto='UDP', timeout=1):
u = UPnP(reactor)
found = yield u.discover(timeout=timeout)
if not found:
print("failed to find gateway")
defer.returnValue(None)
external_ip = yield u.get_external_ip()
assert external_ip, "Failed to get the external IP"
log.info(external_ip)
try:
yield u.get_specific_port_mapping(ext_port, proto)
except UPnPError as err:
if 'NoSuchEntryInArray' in str(err):
pass
else:
log.error("there is already a redirect")
raise AssertionError()
yield u.add_port_mapping(ext_port, proto, int_port, u.lan_address, 'woah', 0)
redirects = yield u.get_redirects()
if (ext_port, u.lan_address, proto) in map(lambda x: (x[1], x[4], x[2]), redirects):
log.info("made redirect")
else:
log.error("failed to make redirect")
raise AssertionError()
yield u.delete_port_mapping(ext_port, proto)
redirects = yield u.get_redirects()
if (ext_port, u.lan_address, proto) not in map(lambda x: (x[1], x[4], x[2]), redirects):
log.info("tore down redirect")
else:
log.error("failed to tear down redirect")
raise AssertionError()
r = yield u.get_rsip_nat_status()
log.info(r)
r = yield u.get_status_info()
log.info(r)
r = yield u.get_connection_type_info()
log.info(r)
@defer.inlineCallbacks
def run_tests():
if len(sys.argv) > 1:
log.setLevel(logging.DEBUG)
timeout = int(sys.argv[1])
else:
timeout = 1
for p in ['UDP']:
yield test(proto=p, timeout=timeout)
def main():
d = run_tests()
d.addErrback(log.exception)
d.addBoth(lambda _: reactor.callLater(0, reactor.stop))
reactor.run()
if __name__ == "__main__":
main()

View file

@ -1,7 +1,9 @@
import logging import logging
import json
from twisted.internet import defer from twisted.internet import defer
from txupnp.fault import UPnPError from txupnp.fault import UPnPError
from txupnp.soap import SOAPServiceManager from txupnp.soap import SOAPServiceManager
from txupnp.util import DeferredDict
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -17,7 +19,10 @@ class UPnP(object):
@property @property
def commands(self): def commands(self):
return self.soap_manager.get_runner() try:
return self.soap_manager.get_runner()
except UPnPError as err:
log.warning("upnp is not available: %s", err)
def m_search(self, address, timeout=30, max_devices=2): def m_search(self, address, timeout=30, max_devices=2):
""" """
@ -34,13 +39,17 @@ class UPnP(object):
return self.soap_manager.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices) return self.soap_manager.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
@defer.inlineCallbacks @defer.inlineCallbacks
def discover(self, timeout=1, max_devices=1): def discover(self, timeout=1, max_devices=1, keep_listening=False):
try: try:
yield self.soap_manager.discover_services(timeout=timeout, max_devices=max_devices) yield self.soap_manager.discover_services(timeout=timeout, max_devices=max_devices)
found = True
except defer.TimeoutError: except defer.TimeoutError:
log.warning("failed to find upnp gateway") log.warning("failed to find upnp gateway")
defer.returnValue(False) found = False
defer.returnValue(True) finally:
if not keep_listening:
self.soap_manager.sspd_factory.disconnect()
defer.returnValue(found)
def get_external_ip(self): def get_external_ip(self):
return self.commands.GetExternalIPAddress() return self.commands.GetExternalIPAddress()
@ -69,15 +78,23 @@ class UPnP(object):
break break
defer.returnValue(redirects) defer.returnValue(redirects)
@defer.inlineCallbacks
def get_specific_port_mapping(self, external_port, protocol): def get_specific_port_mapping(self, external_port, protocol):
""" """
:param external_port: (int) external port to listen on :param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP' :param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time> :return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
""" """
return self.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol try:
) result = yield self.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
)
defer.returnValue(result)
except UPnPError as err:
if 'NoSuchEntryInArray' in str(err):
defer.returnValue(None)
raise err
def delete_port_mapping(self, external_port, protocol): def delete_port_mapping(self, external_port, protocol):
""" """
@ -106,3 +123,31 @@ class UPnP(object):
:return: (str) NewConnectionType (str), NewPossibleConnectionTypes (str) :return: (str) NewConnectionType (str), NewPossibleConnectionTypes (str)
""" """
return self.commands.GetConnectionTypeInfo() return self.commands.GetConnectionTypeInfo()
@defer.inlineCallbacks
def get_next_mapping(self, port, protocol, description):
if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol))
mappings = yield DeferredDict({p: self.get_specific_port_mapping(port, p)
for p in ["UDP", "TCP"]})
if not any((m is not None for m in mappings.values())): # there are no redirects for this port
yield self.add_port_mapping( # set one up
port, protocol, port, self.lan_address, description, 0
)
defer.returnValue(port)
if mappings[protocol]:
mapped_port = mappings[protocol][0]
mapped_address = mappings[protocol][1]
if mapped_port == port and mapped_address == self.lan_address: # reuse redirect to us
defer.returnValue(port)
port = yield self.get_next_mapping( # try the next port
port + 1, protocol, description
)
defer.returnValue(port)
def get_debug_info(self):
def default_byte(x):
if isinstance(x, bytes):
return x.decode()
return x
return json.dumps(self.soap_manager.debug(), indent=2, default=default_byte)

View file

@ -2,6 +2,7 @@ import re
import functools import functools
from collections import defaultdict from collections import defaultdict
import netifaces import netifaces
from twisted.internet import defer
DEVICE_ELEMENT_REGEX = re.compile("^\{urn:schemas-upnp-org:device-\d-\d\}device$") DEVICE_ELEMENT_REGEX = re.compile("^\{urn:schemas-upnp-org:device-\d-\d\}device$")
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
@ -74,3 +75,18 @@ def return_types(*types):
none_or_str = lambda x: None if not x or x == 'None' else str(x) none_or_str = lambda x: None if not x or x == 'None' else str(x)
none = lambda _: None none = lambda _: None
@defer.inlineCallbacks
def DeferredDict(d, consumeErrors=False):
keys = []
dl = []
response = {}
for k, v in d.items():
keys.append(k)
dl.append(v)
results = yield defer.DeferredList(dl, consumeErrors=consumeErrors)
for k, (success, result) in zip(keys, results):
if success:
response[k] = result
defer.returnValue(response)