This commit is contained in:
Jack Robison 2018-07-30 17:48:20 -04:00
parent 97a05b72dd
commit bd1c76237c
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
15 changed files with 455 additions and 345 deletions

15
LICENSE Normal file
View file

@ -0,0 +1,15 @@
The MIT License (MIT)
Copyright (c) 2015-2018 LBRY Inc
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the
following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

19
README.md Normal file
View file

@ -0,0 +1,19 @@
# UPnP for Twisted
`txupnp` is a python2/3 library to interact with UPnP gateways using Twisted
## Installation
`pip install txupnp`
## Usage
To run the test script, `test-txupnp`. This will attempt to find the gateway and will try to set up and tear down an external redirect.
## License
This project is MIT licensed. For the full license, see [LICENSE](LICENSE).
## Contact
The primary contact for this project is @jackrobison(jackrobison@lbry.io)

View file

@ -1,16 +1,24 @@
import os
from setuptools import setup, find_packages
from txupnp import __version__, __name__, __email__, __author__, __license__
console_scripts = [
'test-txupnp = txupnp.tests.test_txupnp:main',
'txupnp-cli = txupnp.cli:main',
]
package_name = "txupnp"
base_dir = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(base_dir, 'README.md'), 'rb') as f:
long_description = f.read().decode('utf-8')
setup(
name="txupnp",
version="0.0.1",
author="Jack Robison",
author_email="jackrobison@lbry.io",
name=__name__,
version=__version__,
author=__author__,
author_email=__email__,
description="UPnP for twisted",
license='MIT',
long_description=long_description,
license=__license__,
packages=find_packages(),
entry_points={'console_scripts': console_scripts},
install_requires=[

View file

@ -1,7 +1,12 @@
__version__ = "0.0.1rc1"
__name__ = "txupnp"
__author__ = "Jack Robison"
__maintainer__ = "Jack Robison"
__license__ = "MIT"
__email__ = "jackrobison@lbry.io"
import logging
# from twisted.python import log
# observer = log.PythonLoggingObserver(loggerName=__name__)
# observer.start()
log = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))

49
txupnp/cli.py Normal file
View file

@ -0,0 +1,49 @@
import sys
import argparse
import logging
from twisted.internet import reactor, defer
from txupnp.upnp import UPnP
log = logging.getLogger("txupnp")
@defer.inlineCallbacks
def run_command(found, u, command):
if not found:
print("failed to find gateway")
reactor.callLater(0, reactor.stop)
return
if command == "debug":
external_ip = yield u.get_external_ip()
print(u.get_debug_info())
print("external ip: ", external_ip)
if command == "list_mappings":
redirects = yield u.get_redirects()
print("found {} redirects".format(len(redirects)))
for redirect in redirects:
print("\t", redirect)
def main():
parser = argparse.ArgumentParser(description="upnp command line utility")
parser.add_argument("--command", dest="command", type=str, help="debug | list_mappings", default="debug")
parser.add_argument("--debug", dest="debug", default=False, action="store_true")
args = parser.parse_args()
if args.debug:
from twisted.python import log as tx_log
observer = tx_log.PythonLoggingObserver(loggerName="txupnp")
observer.start()
log.setLevel(logging.DEBUG)
command = args.command
if command not in ['debug', 'list_mappings']:
return sys.exit(0)
u = UPnP(reactor)
d = u.discover()
d.addCallback(run_command, u, command)
d.addBoth(lambda _: reactor.callLater(0, reactor.stop))
reactor.run()
if __name__ == "__main__":
main()

View file

@ -5,8 +5,8 @@ XML_VERSION = "<?xml version=\"1.0\"?>"
FAULT = "{http://schemas.xmlsoap.org/soap/envelope/}Fault"
ENVELOPE = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope"
BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body"
SOAP_ENCODING = "http://schemas.xmlsoap.org/soap/encoding/"
SOAP_ENVELOPE = "http://schemas.xmlsoap.org/soap/envelope"
CONTROL = 'urn:schemas-upnp-org:control-1-0'
SERVICE = 'urn:schemas-upnp-org:service-1-0'
DEVICE = 'urn:schemas-upnp-org:device-1-0'
@ -20,17 +20,17 @@ service_types = [
WAN_SCHEMA,
LAYER_SCHEMA,
IP_SCHEMA,
CONTROL,
SERVICE,
DEVICE,
]
SSDP_IP_ADDRESS = '239.255.255.250'
SSDP_PORT = 1900
SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
SSDP_DISCOVER = "ssdp:discover"
SSDP_ALL = "ssdp:all"
M_SEARCH_TEMPLATE = "\r\n".join([
"M-SEARCH * HTTP/1.1",
"HOST: {}:{}",
"ST: {}",
"MAN: \"{}\"",
"MX: {}\r\n\r\n",
])
SSDP_BYEBYE = "ssdp:byebye"
SSDP_UPDATE = "ssdp:update"
SSDP_ROOT_DEVICE = "upnp:rootdevice"
line_separator = "\r\n"

View file

@ -1,4 +1,3 @@
import json
import logging
from twisted.internet import defer
import treq
@ -81,7 +80,7 @@ class RootDevice(object):
if root:
root_device = Device(self, **(root["device"]))
self.devices.append(root_device)
log.info("finished setting up root gateway. %i devices and %i services", len(self.devices), len(self.services))
log.debug("finished setting up root gateway. %i devices and %i services", len(self.devices), len(self.services))
class Gateway(object):
@ -98,11 +97,6 @@ class Gateway(object):
self._device = None
def debug_device(self):
def default_byte(x):
if isinstance(x, bytes):
return x.decode()
return x
devices = []
for device in self._device.devices:
info = device.get_info()
@ -111,17 +105,17 @@ class Gateway(object):
for service in self._device.services:
info = service.get_info()
services.append(info)
return json.dumps({
return {
'root_url': self.base_address,
'gateway_xml_url': self.location,
'usn': self.usn,
'devices': devices,
'services': services
}, indent=2, default=default_byte)
}
@defer.inlineCallbacks
def discover_services(self):
log.info("querying %s", self.location)
log.debug("querying %s", self.location)
response = yield treq.get(self.location)
response_xml = yield response.content()
if not response_xml:

View file

@ -1,4 +1,3 @@
import json
import logging
from collections import OrderedDict
from twisted.internet import defer
@ -8,7 +7,8 @@ from treq.client import HTTPClient
from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none
from txupnp.fault import handle_fault, UPnPError
from txupnp.constants import POST, ENVELOPE, BODY, XML_VERSION, IP_SCHEMA, SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types
from txupnp.constants import IP_SCHEMA, SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
from txupnp.constants import BODY, POST
log = logging.getLogger(__name__)
@ -40,13 +40,13 @@ def get_soap_body(service_name, method, param_names, **kwargs):
class _SCPDCommand(object):
def __init__(self, gateway_address, service_port, control_url, service_id, method, param_names, returns,
reactor=None):
reactor=None, connection_pool=None, agent=None, http_client=None):
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
self._pool = HTTPConnectionPool(reactor)
self.agent = Agent(reactor, connectTimeout=1)
self._http_client = HTTPClient(self.agent, data_to_body_producer=StringProducer)
self._pool = connection_pool or HTTPConnectionPool(reactor)
self.agent = agent or Agent(reactor, connectTimeout=1)
self._http_client = http_client or HTTPClient(self.agent, data_to_body_producer=StringProducer)
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
@ -128,10 +128,14 @@ class SCPDResponse(object):
class SCPDCommandRunner(object):
def __init__(self, gateway):
def __init__(self, gateway, reactor):
self._gateway = gateway
self._unsupported_actions = {}
self._registered_commands = {}
self._reactor = reactor
self._agent = Agent(reactor, connectTimeout=1)
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
self._connection_pool = HTTPConnectionPool(reactor)
@defer.inlineCallbacks
def _discover_commands(self, service):
@ -175,11 +179,12 @@ class SCPDCommandRunner(object):
[i['name'] for i in arg_dicts if i['direction'] == 'out']
)
def __register_command(self, action_info, service_type):
func_info = self._soap_function_info(action_info)
def _patch_command(self, action_info, service_type):
name, inputs, outputs = self._soap_function_info(action_info)
command = _SCPDCommand(self._gateway.base_address, self._gateway.port,
self._gateway.base_address + self._gateway.get_service(service_type).control_path.encode(),
self._gateway.get_service(service_type).service_id.encode(), *func_info)
self._gateway.get_service(service_type).service_id.encode(), name, inputs, outputs,
self._reactor, self._connection_pool, self._agent, self._http_client)
current = getattr(self, command.method)
if hasattr(current, "_return_types"):
command._process_result = _return_types(*current._return_types)(command._process_result)
@ -191,7 +196,7 @@ class SCPDCommandRunner(object):
def _register_command(self, action_info, service_type):
try:
return self.__register_command(action_info, service_type)
return self._patch_command(action_info, service_type)
except Exception as err:
s = self._unsupported_actions.get(service_type, [])
s.append((action_info, err))
@ -199,10 +204,10 @@ class SCPDCommandRunner(object):
log.error("failed to setup command for %s\n%s", service_type, action_info)
def debug_commands(self):
return json.dumps({
return {
'available': self._registered_commands,
'failed': self._unsupported_actions
}, indent=2)
}
@staticmethod
@return_types(none)

View file

@ -29,7 +29,7 @@ class SOAPServiceManager(object):
locations.append(server_info['location'])
gateway = Gateway(**server_info)
yield gateway.discover_services()
command_runner = SCPDCommandRunner(gateway)
command_runner = SCPDCommandRunner(gateway, self._reactor)
yield command_runner.discover_commands()
self._command_runners[gateway.urn.decode()] = command_runner
elif 'st' not in server_info:
@ -51,3 +51,12 @@ class SOAPServiceManager(object):
def get_available_runners(self):
return self._command_runners.keys()
def debug(self):
results = []
for runner in self._command_runners.values():
gateway = runner._gateway
info = gateway.debug_device()
info.update(runner.debug_commands())
results.append(info)
return results

View file

@ -1,182 +1,15 @@
import logging
import binascii
import re
from twisted.internet import defer
from twisted.internet.protocol import DatagramProtocol
from txupnp.constants import GATEWAY_SCHEMA, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, service_types
from txupnp.constants import SSDP_HOST
from txupnp.fault import UPnPError
from txupnp.constants import GATEWAY_SCHEMA, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, SSDP_ALL
from txupnp.ssdp_datagram import SSDPDatagram
log = logging.getLogger(__name__)
SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
SSDP_BYEBYE = "ssdp:byebye"
SSDP_UPDATE = "ssdp:update"
SSDP_ROOT_DEVICE = "upnp:rootdevice"
line_separator = "\r\n"
class SSDPDatagram(object):
_M_SEARCH = "M-SEARCH"
_NOTIFY = "NOTIFY"
_OK = "OK"
_start_lines = {
_M_SEARCH: "M-SEARCH * HTTP/1.1",
_NOTIFY: "NOTIFY * HTTP/1.1",
_OK: "HTTP/1.1 200 OK"
}
_friendly_names = {
_M_SEARCH: "m-search",
_NOTIFY: "notify",
_OK: "m-search response"
}
_vendor_field_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
_patterns = {
'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str),
'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str),
'mx': (re.compile("^(?i)(mx):(.*)$"), int),
'nt': (re.compile("^(?i)(nt):(.*)$"), str),
'nts': (re.compile("^(?i)(nts):(.*)$"), str),
'usn': (re.compile("^(?i)(usn):(.*)$"), str),
'location': (re.compile("^(?i)(location):(.*)$"), str),
'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str),
'server': (re.compile("^(?i)(server):(.*)$"), str),
}
_required_fields = {
_M_SEARCH: [
'host',
'st',
'man',
'mx',
],
_OK: [
'cache_control',
# 'date',
# 'ext',
'location',
'server',
'st',
'usn'
]
}
_marshallers = {
'mx': str,
'man': lambda x: ("\"%s\"" % x)
}
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
cache_control=None, server=None, date=None, ext=None, **kwargs):
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type
self.host = host
self.st = st
self.man = man
self.mx = mx
self.nt = nt
self.nts = nts
self.usn = usn
self.location = location
self.cache_control = cache_control
self.server = server
self.date = date
self.ext = ext
for k, v in kwargs.items():
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
setattr(self, k.lower(), v)
def __getitem__(self, item):
for i in self._required_fields[self._packet_type]:
if i.lower() == item.lower():
return getattr(self, i)
raise KeyError(item)
def get_friendly_name(self):
return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines=2):
lines = [self._start_lines[self._packet_type]]
for attr_name in self._required_fields[self._packet_type]:
attr = getattr(self, attr_name)
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name in self._marshallers:
value = self._marshallers[attr_name](attr)
else:
value = attr
lines.append("{}: {}".format(attr_name.upper(), value))
serialized = line_separator.join(lines)
for _ in range(trailing_newlines):
serialized += line_separator
return serialized
def as_dict(self):
return self._lines_to_content_dict(self.encode().split(line_separator))
@classmethod
def decode(cls, datagram):
packet = cls._from_string(datagram.decode())
for attr_name in packet._required_fields[packet._packet_type]:
attr = getattr(packet, attr_name)
if attr is None:
raise UPnPError(
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
)
return packet
@classmethod
def _lines_to_content_dict(cls, lines):
result = {}
for line in lines:
if not line:
continue
matched = False
for name, (pattern, field_type) in cls._patterns.items():
if name not in result and pattern.findall(line):
match = pattern.findall(line)[-1][-1]
result[name] = field_type(match.lstrip(" ").rstrip(" "))
matched = True
break
if not matched:
if cls._vendor_field_pattern.findall(line):
match = cls._vendor_field_pattern.findall(line)[-1]
vendor_key = match[0].lstrip(" ").rstrip(" ")
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
value = match[2].lstrip(" ").rstrip(" ")
if vendor_key not in result:
result[vendor_key] = value
return result
@classmethod
def _from_string(cls, datagram):
lines = [l for l in datagram.split(line_separator) if l]
if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:])
if lines[0] == cls._start_lines[cls._NOTIFY]:
return cls._from_notify(lines[1:])
if lines[0] == cls._start_lines[cls._OK]:
return cls._from_response(lines[1:])
@classmethod
def _from_response(cls, lines):
return cls(cls._OK, **cls._lines_to_content_dict(lines))
@classmethod
def _from_notify(cls, lines):
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
@classmethod
def _from_request(cls, lines):
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))
class SSDPProtocol(DatagramProtocol):
def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS,
ssdp_port=SSDP_PORT, ttl=1, max_devices=None):
@ -192,48 +25,17 @@ class SSDPProtocol(DatagramProtocol):
self.max_devices = max_devices
self.devices = []
def startProtocol(self):
self._start = self._reactor.seconds()
self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface)
for st in [SSDP_ALL, SSDP_ROOT_DEVICE, GATEWAY_SCHEMA, GATEWAY_SCHEMA.lower()]:
self.send_m_search(service=st)
def send_m_search(self, service=GATEWAY_SCHEMA):
def _send_m_search(self, service=GATEWAY_SCHEMA):
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
log.debug("writing packet:\n%s", packet.encode())
log.info("sending m-search (%i bytes) to %s:%i", len(packet.encode()), self.ssdp_address, self.ssdp_port)
log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode())
try:
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port))
except Exception as err:
log.exception("failed to write %s to %s:%i", binascii.hexlify(packet.encode()), self.ssdp_address, self.ssdp_port)
raise err
def leave_group(self):
self.transport.leaveGroup(self.ssdp_address, interface=self.iface)
def datagramReceived(self, datagram, address):
if address[0] == self.iface:
return
try:
packet = SSDPDatagram.decode(datagram)
log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode())
except Exception:
log.exception("failed to decode: %s", binascii.hexlify(datagram))
return
if packet._packet_type == packet._OK:
log.info("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
else:
log.info("%s:%i notified us of a service type: %s", address[0], address[1], packet.st)
if packet.st not in map(lambda p: p['st'], self.devices):
self.devices.append(packet.as_dict())
log.info("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s")
if address[0] in self.discover_callbacks:
self._sem.run(self.discover_callbacks[address[0]][0], packet)
def gather(finished_deferred, max_results):
@staticmethod
def _gather(finished_deferred, max_results):
results = []
def discover_cb(packet):
@ -244,55 +46,100 @@ def gather(finished_deferred, max_results):
return discover_cb
def m_search(self, address=None, timeout=1, max_devices=1):
address = address or self.iface
# return deferred for a pending call if we have one
if address in self.discover_callbacks:
d = self.protocol.discover_callbacks[address][1]
if not d.called: # the existing deferred has already fired, make a new one
return d
def _trap_timeout_and_return_results(err):
if err.check(defer.TimeoutError):
return self.devices
raise err
d = defer.Deferred()
d.addTimeout(timeout, self._reactor)
d.addErrback(_trap_timeout_and_return_results)
found_cb = self._gather(d, max_devices)
self.discover_callbacks[address] = found_cb, d
for st in service_types:
self._send_m_search(service=st)
return d
def startProtocol(self):
self._start = self._reactor.seconds()
self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface)
self.m_search()
def datagramReceived(self, datagram, address):
if address[0] == self.iface:
return
try:
packet = SSDPDatagram.decode(datagram)
log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode())
except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i: %s\npacket: %s", address[0], address[1], err,
binascii.hexlify(datagram))
return
except Exception:
log.exception("failed to decode: %s", binascii.hexlify(datagram))
return
if packet._packet_type == packet._OK:
log.debug("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
else:
log.debug("%s:%i notified us of a service type: %s", address[0], address[1], packet.st)
if packet.st not in map(lambda p: p['st'], self.devices):
self.devices.append(packet.as_dict())
log.debug("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s")
if address[0] in self.discover_callbacks:
self._sem.run(self.discover_callbacks[address[0]][0], packet)
class SSDPFactory(object):
def __init__(self, reactor, lan_address, router_address):
self.lan_address = lan_address
self.router_address = router_address
self._reactor = reactor
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address)
self.protocol = None
self.port = None
def disconnect(self):
if self.protocol:
self.protocol.leave_group()
self.protocol = None
if not self.port:
return
self.protocol.transport.leaveGroup(SSDP_IP_ADDRESS, interface=self.lan_address)
self.port.stopListening()
self.port = None
self.protocol = None
def connect(self):
if not self.protocol:
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address)
if not self.port:
self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect)
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)
@defer.inlineCallbacks
def m_search(self, address, timeout=30, max_devices=1):
def m_search(self, address, timeout=1, max_devices=1):
"""
Perform a HTTP over UDP M-SEARCH query
Perform a M-SEARCH (HTTP over UDP) and gather the results
returns (list) [{
'server: <gateway os and version string>
'location': <upnp gateway url>,
'cache-control': <max age>,
'date': <server time>,
'usn': <usn>
:param address: (str) address to listen for responses from
:param timeout: (int) timeout for the query
:param max_devices: (int) block until timeout or at least this many devices are found
:param service_types: (list) M-SEARCH "ST" arguments to try, if None use the defaults
:return: (list) [ (dict) {
'server: (str) gateway os and version
'location': (str) upnp gateway url,
'cache-control': (str) max age,
'date': (int) server time,
'usn': (str) usn
}, ...]
"""
self.connect()
if address in self.protocol.discover_callbacks:
d = self.protocol.discover_callbacks[address][1]
else:
d = defer.Deferred()
d.addTimeout(timeout, self._reactor)
found_cb = gather(d, max_devices)
self.protocol.discover_callbacks[address] = found_cb, d
for st in [SSDP_ALL, SSDP_ROOT_DEVICE, GATEWAY_SCHEMA, GATEWAY_SCHEMA.lower()]:
self.protocol.send_m_search(service=st)
try:
server_infos = yield d
except defer.TimeoutError:
server_infos = self.protocol.devices
server_infos = yield self.protocol.m_search(address, timeout, max_devices)
defer.returnValue(server_infos)

167
txupnp/ssdp_datagram.py Normal file
View file

@ -0,0 +1,167 @@
import re
import logging
from txupnp.fault import UPnPError
from txupnp.constants import line_separator
log = logging.getLogger(__name__)
class SSDPDatagram(object):
_M_SEARCH = "M-SEARCH"
_NOTIFY = "NOTIFY"
_OK = "OK"
_start_lines = {
_M_SEARCH: "M-SEARCH * HTTP/1.1",
_NOTIFY: "NOTIFY * HTTP/1.1",
_OK: "HTTP/1.1 200 OK"
}
_friendly_names = {
_M_SEARCH: "m-search",
_NOTIFY: "notify",
_OK: "m-search response"
}
_vendor_field_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
_patterns = {
'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str),
'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str),
'mx': (re.compile("^(?i)(mx):(.*)$"), int),
'nt': (re.compile("^(?i)(nt):(.*)$"), str),
'nts': (re.compile("^(?i)(nts):(.*)$"), str),
'usn': (re.compile("^(?i)(usn):(.*)$"), str),
'location': (re.compile("^(?i)(location):(.*)$"), str),
'cache_control': (re.compile("^(?i)(cache-control):(.*)$"), str),
'server': (re.compile("^(?i)(server):(.*)$"), str),
}
_required_fields = {
_M_SEARCH: [
'host',
'st',
'man',
'mx',
],
_OK: [
'cache_control',
# 'date',
# 'ext',
'location',
'server',
'st',
'usn'
]
}
_marshallers = {
'mx': str,
'man': lambda x: ("\"%s\"" % x)
}
def __init__(self, packet_type, host=None, st=None, man=None, mx=None, nt=None, nts=None, usn=None, location=None,
cache_control=None, server=None, date=None, ext=None, **kwargs):
if packet_type not in [self._M_SEARCH, self._NOTIFY, self._OK]:
raise UPnPError("unknown packet type: {}".format(packet_type))
self._packet_type = packet_type
self.host = host
self.st = st
self.man = man
self.mx = mx
self.nt = nt
self.nts = nts
self.usn = usn
self.location = location
self.cache_control = cache_control
self.server = server
self.date = date
self.ext = ext
for k, v in kwargs.items():
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
setattr(self, k.lower(), v)
def __getitem__(self, item):
for i in self._required_fields[self._packet_type]:
if i.lower() == item.lower():
return getattr(self, i)
raise KeyError(item)
def get_friendly_name(self):
return self._friendly_names[self._packet_type]
def encode(self, trailing_newlines=2):
lines = [self._start_lines[self._packet_type]]
for attr_name in self._required_fields[self._packet_type]:
attr = getattr(self, attr_name)
if attr is None:
raise UPnPError("required field for {} is missing: {}".format(self._packet_type, attr_name))
if attr_name in self._marshallers:
value = self._marshallers[attr_name](attr)
else:
value = attr
lines.append("{}: {}".format(attr_name.upper(), value))
serialized = line_separator.join(lines)
for _ in range(trailing_newlines):
serialized += line_separator
return serialized
def as_dict(self):
return self._lines_to_content_dict(self.encode().split(line_separator))
@classmethod
def decode(cls, datagram):
packet = cls._from_string(datagram.decode())
for attr_name in packet._required_fields[packet._packet_type]:
attr = getattr(packet, attr_name)
if attr is None:
raise UPnPError(
"required field for {} is missing from m-search response: {}".format(packet._packet_type, attr_name)
)
return packet
@classmethod
def _lines_to_content_dict(cls, lines):
result = {}
for line in lines:
if not line:
continue
matched = False
for name, (pattern, field_type) in cls._patterns.items():
if name not in result and pattern.findall(line):
match = pattern.findall(line)[-1][-1]
result[name] = field_type(match.lstrip(" ").rstrip(" "))
matched = True
break
if not matched:
if cls._vendor_field_pattern.findall(line):
match = cls._vendor_field_pattern.findall(line)[-1]
vendor_key = match[0].lstrip(" ").rstrip(" ")
# vendor_domain = match[1].lstrip(" ").rstrip(" ")
value = match[2].lstrip(" ").rstrip(" ")
if vendor_key not in result:
result[vendor_key] = value
return result
@classmethod
def _from_string(cls, datagram):
lines = [l for l in datagram.split(line_separator) if l]
if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:])
if lines[0] == cls._start_lines[cls._NOTIFY]:
return cls._from_notify(lines[1:])
if lines[0] == cls._start_lines[cls._OK]:
return cls._from_response(lines[1:])
@classmethod
def _from_response(cls, lines):
return cls(cls._OK, **cls._lines_to_content_dict(lines))
@classmethod
def _from_notify(cls, lines):
return cls(cls._NOTIFY, **cls._lines_to_content_dict(lines))
@classmethod
def _from_request(cls, lines):
return cls(cls._M_SEARCH, **cls._lines_to_content_dict(lines))

View file

@ -1,69 +0,0 @@
import sys
import logging
from twisted.internet import reactor, defer
from txupnp.upnp import UPnP
from txupnp.fault import UPnPError
log = logging.getLogger("txupnp")
@defer.inlineCallbacks
def test(ext_port=4446, int_port=4446, proto='UDP', timeout=1):
u = UPnP(reactor)
found = yield u.discover(timeout=timeout)
if not found:
print("failed to find gateway")
defer.returnValue(None)
external_ip = yield u.get_external_ip()
assert external_ip, "Failed to get the external IP"
log.info(external_ip)
try:
yield u.get_specific_port_mapping(ext_port, proto)
except UPnPError as err:
if 'NoSuchEntryInArray' in str(err):
pass
else:
log.error("there is already a redirect")
raise AssertionError()
yield u.add_port_mapping(ext_port, proto, int_port, u.lan_address, 'woah', 0)
redirects = yield u.get_redirects()
if (ext_port, u.lan_address, proto) in map(lambda x: (x[1], x[4], x[2]), redirects):
log.info("made redirect")
else:
log.error("failed to make redirect")
raise AssertionError()
yield u.delete_port_mapping(ext_port, proto)
redirects = yield u.get_redirects()
if (ext_port, u.lan_address, proto) not in map(lambda x: (x[1], x[4], x[2]), redirects):
log.info("tore down redirect")
else:
log.error("failed to tear down redirect")
raise AssertionError()
r = yield u.get_rsip_nat_status()
log.info(r)
r = yield u.get_status_info()
log.info(r)
r = yield u.get_connection_type_info()
log.info(r)
@defer.inlineCallbacks
def run_tests():
if len(sys.argv) > 1:
log.setLevel(logging.DEBUG)
timeout = int(sys.argv[1])
else:
timeout = 1
for p in ['UDP']:
yield test(proto=p, timeout=timeout)
def main():
d = run_tests()
d.addErrback(log.exception)
d.addBoth(lambda _: reactor.callLater(0, reactor.stop))
reactor.run()
if __name__ == "__main__":
main()

View file

@ -1,7 +1,9 @@
import logging
import json
from twisted.internet import defer
from txupnp.fault import UPnPError
from txupnp.soap import SOAPServiceManager
from txupnp.util import DeferredDict
log = logging.getLogger(__name__)
@ -17,7 +19,10 @@ class UPnP(object):
@property
def commands(self):
try:
return self.soap_manager.get_runner()
except UPnPError as err:
log.warning("upnp is not available: %s", err)
def m_search(self, address, timeout=30, max_devices=2):
"""
@ -34,13 +39,17 @@ class UPnP(object):
return self.soap_manager.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
@defer.inlineCallbacks
def discover(self, timeout=1, max_devices=1):
def discover(self, timeout=1, max_devices=1, keep_listening=False):
try:
yield self.soap_manager.discover_services(timeout=timeout, max_devices=max_devices)
found = True
except defer.TimeoutError:
log.warning("failed to find upnp gateway")
defer.returnValue(False)
defer.returnValue(True)
found = False
finally:
if not keep_listening:
self.soap_manager.sspd_factory.disconnect()
defer.returnValue(found)
def get_external_ip(self):
return self.commands.GetExternalIPAddress()
@ -69,15 +78,23 @@ class UPnP(object):
break
defer.returnValue(redirects)
@defer.inlineCallbacks
def get_specific_port_mapping(self, external_port, protocol):
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
"""
return self.commands.GetSpecificPortMappingEntry(
try:
result = yield self.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
)
defer.returnValue(result)
except UPnPError as err:
if 'NoSuchEntryInArray' in str(err):
defer.returnValue(None)
raise err
def delete_port_mapping(self, external_port, protocol):
"""
@ -106,3 +123,31 @@ class UPnP(object):
:return: (str) NewConnectionType (str), NewPossibleConnectionTypes (str)
"""
return self.commands.GetConnectionTypeInfo()
@defer.inlineCallbacks
def get_next_mapping(self, port, protocol, description):
if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol))
mappings = yield DeferredDict({p: self.get_specific_port_mapping(port, p)
for p in ["UDP", "TCP"]})
if not any((m is not None for m in mappings.values())): # there are no redirects for this port
yield self.add_port_mapping( # set one up
port, protocol, port, self.lan_address, description, 0
)
defer.returnValue(port)
if mappings[protocol]:
mapped_port = mappings[protocol][0]
mapped_address = mappings[protocol][1]
if mapped_port == port and mapped_address == self.lan_address: # reuse redirect to us
defer.returnValue(port)
port = yield self.get_next_mapping( # try the next port
port + 1, protocol, description
)
defer.returnValue(port)
def get_debug_info(self):
def default_byte(x):
if isinstance(x, bytes):
return x.decode()
return x
return json.dumps(self.soap_manager.debug(), indent=2, default=default_byte)

View file

@ -2,6 +2,7 @@ import re
import functools
from collections import defaultdict
import netifaces
from twisted.internet import defer
DEVICE_ELEMENT_REGEX = re.compile("^\{urn:schemas-upnp-org:device-\d-\d\}device$")
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
@ -74,3 +75,18 @@ def return_types(*types):
none_or_str = lambda x: None if not x or x == 'None' else str(x)
none = lambda _: None
@defer.inlineCallbacks
def DeferredDict(d, consumeErrors=False):
keys = []
dl = []
response = {}
for k, v in d.items():
keys.append(k)
dl.append(v)
results = yield defer.DeferredList(dl, consumeErrors=consumeErrors)
for k, (success, result) in zip(keys, results):
if success:
response[k] = result
defer.returnValue(response)