This commit is contained in:
Jack Robison 2018-07-28 22:08:24 -04:00
parent 01dc5d75d1
commit c1dad347ec
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
12 changed files with 451 additions and 164 deletions

View file

@ -1,5 +1,9 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
console_scripts = [
'test-txupnp = txupnp.tests.test_txupnp:main',
]
setup( setup(
name="txupnp", name="txupnp",
version="0.0.1", version="0.0.1",
@ -8,9 +12,12 @@ setup(
description="UPnP for twisted", description="UPnP for twisted",
license='MIT', license='MIT',
packages=find_packages(), packages=find_packages(),
entry_points={'console_scripts': console_scripts},
install_requires=[ install_requires=[
'Twisted', 'Twisted',
'treq', 'treq',
'netifaces' 'netifaces',
'pycryptodome',
'service-identity'
], ],
) )

View file

@ -1,5 +1,7 @@
import logging import logging
# from twisted.python import log
# observer = log.PythonLoggingObserver(loggerName=__name__)
# observer.start()
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s')) handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))

View file

@ -1,16 +1,26 @@
POST = "POST" POST = "POST"
ROOT = "root"
SPEC_VERSION = "specVersion"
XML_VERSION = "<?xml version=\"1.0\"?>" XML_VERSION = "<?xml version=\"1.0\"?>"
FAULT = "{http://schemas.xmlsoap.org/soap/envelope/}Fault" FAULT = "{http://schemas.xmlsoap.org/soap/envelope/}Fault"
ENVELOPE = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope" ENVELOPE = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope"
BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body" BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body"
SOAP_ENCODING = "http://schemas.xmlsoap.org/soap/encoding/" SOAP_ENCODING = "http://schemas.xmlsoap.org/soap/encoding/"
SOAP_ENVELOPE = "http://schemas.xmlsoap.org/soap/envelope" SOAP_ENVELOPE = "http://schemas.xmlsoap.org/soap/envelope"
CONTROL_KEY = 'urn:schemas-upnp-org:control-1-0' CONTROL = 'urn:schemas-upnp-org:control-1-0'
SERVICE_KEY = 'urn:schemas-upnp-org:service-1-0' SERVICE = 'urn:schemas-upnp-org:service-1-0'
DEVICE = 'urn:schemas-upnp-org:device-1-0'
GATEWAY_SCHEMA = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1' GATEWAY_SCHEMA = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1'
WAN_INTERFACE_KEY = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1' WAN_SCHEMA = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1'
LAYER_FORWARD_KEY = 'urn:schemas-upnp-org:service:Layer3Forwarding:1' LAYER_SCHEMA = 'urn:schemas-upnp-org:service:Layer3Forwarding:1'
WAN_IP_KEY = 'urn:schemas-upnp-org:service:WANIPConnection:1' IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1'
service_types = [
GATEWAY_SCHEMA,
WAN_SCHEMA,
LAYER_SCHEMA,
IP_SCHEMA,
]
SSDP_IP_ADDRESS = '239.255.255.250' SSDP_IP_ADDRESS = '239.255.255.250'
SSDP_PORT = 1900 SSDP_PORT = 1900

View file

@ -1,5 +1,5 @@
from txupnp.util import flatten_keys from txupnp.util import flatten_keys
from txupnp.constants import FAULT, CONTROL_KEY from txupnp.constants import FAULT, CONTROL
class UPnPError(Exception): class UPnPError(Exception):
@ -8,6 +8,6 @@ class UPnPError(Exception):
def handle_fault(response): def handle_fault(response):
if FAULT in response: if FAULT in response:
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL_KEY) fault = flatten_keys(response[FAULT], "{%s}" % CONTROL)
raise UPnPError(fault['detail']['UPnPError']['errorDescription']) raise UPnPError(fault['detail']['UPnPError']['errorDescription'])
return response return response

95
txupnp/gateway.py Normal file
View file

@ -0,0 +1,95 @@
import logging
from twisted.internet import defer
import treq
from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys
from txupnp.util import BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from txupnp.constants import DEVICE, ROOT
from txupnp.constants import SPEC_VERSION
log = logging.getLogger(__name__)
class Service(object):
def __init__(self, serviceType, serviceId, SCPDURL, eventSubURL, controlURL):
self.service_type = serviceType
self.service_id = serviceId
self.control_path = controlURL
self.subscribe_path = eventSubURL
self.scpd_path = SCPDURL
class Device(object):
def __init__(self, _root_device, deviceType=None, friendlyName=None, manufacturer=None, manufacturerURL=None,
modelDescription=None, modelName=None, modelNumber=None, modelURL=None, serialNumber=None,
UDN=None, serviceList=None, deviceList=None, **kwargs):
serviceList = serviceList or {}
deviceList = deviceList or {}
self._root_device = _root_device
self.device_type = deviceType
self.friendly_name = friendlyName
self.manufacturer = manufacturer
self.manufacturer_url = manufacturerURL
self.model_description = modelDescription
self.model_name = modelName
self.model_number = modelNumber
self.model_url = modelURL
self.serial_number = serialNumber
self.udn = UDN
services = serviceList["service"]
if isinstance(services, dict):
services = [services]
services = [Service(**service) for service in services]
self._root_device.services.extend(services)
devices = [Device(self._root_device, **deviceList[k]) for k in deviceList]
self._root_device.devices.extend(devices)
class RootDevice(object):
def __init__(self, xml_string):
root = flatten_keys(etree_to_dict(ElementTree.fromstring(xml_string)), "{%s}" % DEVICE)[ROOT]
self.spec_version = root.get(SPEC_VERSION)
self.url_base = root["URLBase"]
self.devices = []
self.services = []
root_device = Device(self, **(root["device"]))
self.devices.append(root_device)
log.info("finished setting up root device. %i devices and %i services", len(self.devices), len(self.services))
class Gateway(object):
def __init__(self, usn, ext, server, location, cache_control, date, st):
self.usn = usn.encode()
self.ext = ext.encode()
self.server = server.encode()
self.location = location.encode()
self.cache_control = cache_control.encode()
self.date = date.encode()
self.urn = st.encode()
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
self._device = None
@defer.inlineCallbacks
def discover_services(self):
log.info("querying %s", self.location)
response = yield treq.get(self.location)
response_xml = yield response.text()
self._device = RootDevice(response_xml)
@property
def services(self):
if not self._device:
return {}
return {service.service_type: service for service in self._device.services}
@property
def devices(self):
if not self._device:
return {}
return {device.udn: device for device in self._device.devices}
def get_service(self, service_type):
for service in self._device.services:
if service.service_type == service_type:
return service

View file

@ -1,13 +1,13 @@
import logging import logging
from collections import namedtuple, OrderedDict from collections import OrderedDict
from twisted.internet import defer from twisted.internet import defer
from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.client import Agent, HTTPConnectionPool
import treq import treq
from treq.client import HTTPClient from treq.client import HTTPClient
from xml.etree import ElementTree from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str, none
from txupnp.fault import handle_fault from txupnp.fault import handle_fault, UPnPError
from txupnp.constants import POST, ENVELOPE, BODY, XML_VERSION, WAN_IP_KEY, SERVICE_KEY, SSDP_IP_ADDRESS from txupnp.constants import POST, ENVELOPE, BODY, XML_VERSION, IP_SCHEMA, SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -54,7 +54,7 @@ class _SCPDCommand(object):
self.param_names = param_names self.param_names = param_names
self.returns = returns self.returns = returns
def extract_body(self, xml_response, service_key=WAN_IP_KEY): def extract_body(self, xml_response, service_key=IP_SCHEMA):
content_dict = etree_to_dict(ElementTree.fromstring(xml_response)) content_dict = etree_to_dict(ElementTree.fromstring(xml_response))
envelope = content_dict[ENVELOPE] envelope = content_dict[ENVELOPE]
return flatten_keys(envelope[BODY], "{%s}" % service_key) return flatten_keys(envelope[BODY], "{%s}" % service_key)
@ -64,9 +64,8 @@ class _SCPDCommand(object):
if '%sResponse' % self.method in body: if '%sResponse' % self.method in body:
response_key = '%sResponse' % self.method response_key = '%sResponse' % self.method
else: else:
1/0 log.error(body.keys())
return raise UPnPError("unknown response fields")
response = body[response_key] response = body[response_key]
extracted_response = tuple([response[n] for n in self.returns]) extracted_response = tuple([response[n] for n in self.returns])
if len(extracted_response) == 1: if len(extracted_response) == 1:
@ -81,8 +80,7 @@ class _SCPDCommand(object):
('Host', ('%s:%i' % (SSDP_IP_ADDRESS, self.service_port))), ('Host', ('%s:%i' % (SSDP_IP_ADDRESS, self.service_port))),
('Content-Type', 'text/xml'), ('Content-Type', 'text/xml'),
('Content-Length', len(soap_body)) ('Content-Length', len(soap_body))
) ))
)
response = yield self._http_client.request( response = yield self._http_client.request(
POST, url=self.control_url, data=soap_body, headers=headers POST, url=self.control_url, data=soap_body, headers=headers
) )
@ -107,20 +105,50 @@ class _SCPDCommand(object):
defer.returnValue(result) defer.returnValue(result)
class SCPDCommandManager(object): class SCPDResponse(object):
def __init__(self, upnp): def __init__(self, url, headers, content):
self._upnp = upnp self.url = url
self.headers = headers
self.content = content
def get_element_tree(self):
return ElementTree.fromstring(self.content)
def get_element_dict(self, service_key):
return flatten_keys(etree_to_dict(self.get_element_tree()), "{%s}" % service_key)
def get_action_list(self):
return self.get_element_dict(SERVICE)["scpd"]["actionList"]["action"]
def get_device_info(self):
return self.get_element_dict(DEVICE)[ROOT]
class SCPDCommandRunner(object):
def __init__(self, gateway):
self._gateway = gateway
self._unsupported_actions = []
self._scpd_responses = []
@defer.inlineCallbacks
def _discover_commands(self, service):
scpd_url = self._gateway.base_address + service.scpd_path.encode()
response = yield treq.get(scpd_url)
content = yield response.content()
scpd_response = SCPDResponse(scpd_url,
response.headers, content)
self._scpd_responses.append(scpd_response)
for action_dict in scpd_response.get_action_list():
self._register_command(action_dict, service.service_type)
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def discover_commands(self): def discover_commands(self):
response = yield treq.get(self._upnp.wan_ip.scpd_url) for service_type in service_types:
content = yield response.content() service = self._gateway.get_service(service_type)
tree = ElementTree.fromstring(content) if not service:
actions = flatten_keys(etree_to_dict(tree), "{%s}" % SERVICE_KEY)["scpd"]["actionList"]["action"] continue
for action_dict in actions: yield self._discover_commands(service)
self._register_command(action_dict)
log.info("registered %i commands", len(actions))
defer.returnValue(None)
@staticmethod @staticmethod
def _soap_function_info(action_dict): def _soap_function_info(action_dict):
@ -139,24 +167,31 @@ class SCPDCommandManager(object):
[i['name'] for i in arg_dicts if i['direction'] == 'out'] [i['name'] for i in arg_dicts if i['direction'] == 'out']
) )
def _register_command(self, action_info): def _register_command(self, action_info, service_type):
command = _SCPDCommand(self._upnp.gateway_ip, self._upnp.gateway_port, self._upnp.wan_ip.control_url, func_info = self._soap_function_info(action_info)
self._upnp.wan_ip.service_id, *self._soap_function_info(action_info)) command = _SCPDCommand(self._gateway.base_address, self._gateway.port,
self._gateway.base_address + self._gateway.get_service(service_type).control_path.encode(),
self._gateway.get_service(service_type).service_id.encode(), *func_info)
if not hasattr(self, command.method): if not hasattr(self, command.method):
raise NotImplementedError(command.method) self._unsupported_actions.append(action_info)
print(("# send this to jack!\n\n@staticmethod\ndef %s(" % func_info[0]) + ("" if not func_info[1] else ", ".join(func_info[1])) + ("):\n \"\"\"Returns (%s)\"\"\"\n raise NotImplementedError()\n\n" % ("None" if not func_info[2] else ", ".join(func_info[2]))))
return
current = getattr(self, command.method) current = getattr(self, command.method)
if hasattr(current, "_return_types"): if hasattr(current, "_return_types"):
command._process_result = _return_types(*current._return_types)(command._process_result) command._process_result = _return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__) setattr(command, "__doc__", current.__doc__)
setattr(self, command.method, command) setattr(self, command.method, command)
# log.info("registered %s::%s", service_type, action_info['name'])
@staticmethod @staticmethod
@return_types(none)
def AddPortMapping(NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, def AddPortMapping(NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient,
NewEnabled, NewPortMappingDescription, NewLeaseDuration): NewEnabled, NewPortMappingDescription, NewLeaseDuration):
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
@return_types(bool, bool)
def GetNATRSIPStatus(): def GetNATRSIPStatus():
"""Returns (NewRSIPAvailable, NewNATEnabled)""" """Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError() raise NotImplementedError()
@ -177,36 +212,83 @@ class SCPDCommandManager(object):
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
@return_types(none)
def SetConnectionType(NewConnectionType): def SetConnectionType(NewConnectionType):
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
@return_types(str)
def GetExternalIPAddress(): def GetExternalIPAddress():
"""Returns (NewExternalIPAddress)""" """Returns (NewExternalIPAddress)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
@return_types(str, str)
def GetConnectionTypeInfo(): def GetConnectionTypeInfo():
"""Returns (NewConnectionType, NewPossibleConnectionTypes)""" """Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
@return_types(str, str, int)
def GetStatusInfo(): def GetStatusInfo():
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)""" """Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
@return_types(none)
def ForceTermination(): def ForceTermination():
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
@return_types(none)
def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol): def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
@return_types(none)
def RequestConnection(): def RequestConnection():
"""Returns None""" """Returns None"""
raise NotImplementedError() raise NotImplementedError()
@staticmethod
def GetCommonLinkProperties():
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesSent():
"""Returns (NewTotalBytesSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesReceived():
"""Returns (NewTotalBytesReceived)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsSent():
"""Returns (NewTotalPacketsSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsReceived():
"""Returns (NewTotalPacketsReceived)"""
raise NotImplementedError()
@staticmethod
def X_GetICSStatistics():
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
@staticmethod
def GetDefaultConnectionService():
"""Returns (NewDefaultConnectionService)"""
raise NotImplementedError()
@staticmethod
def SetDefaultConnectionService(NewDefaultConnectionService):
"""Returns (None)"""
raise NotImplementedError()

47
txupnp/soap.py Normal file
View file

@ -0,0 +1,47 @@
import logging
from twisted.internet import defer
from txupnp.util import get_lan_info
from txupnp.ssdp import SSDPFactory
from txupnp.scpd import SCPDCommandRunner
from txupnp.gateway import Gateway
from txupnp.constants import GATEWAY_SCHEMA
log = logging.getLogger(__name__)
class SOAPServiceManager(object):
def __init__(self, reactor):
self._reactor = reactor
self.iface_name, self.router_ip, self.lan_address = get_lan_info()
self.sspd_factory = SSDPFactory(self.lan_address, self._reactor)
self._command_runners = {}
self._selected_runner = GATEWAY_SCHEMA
@defer.inlineCallbacks
def discover_services(self, address=None, ttl=30, max_devices=2):
server_infos = yield self.sspd_factory.m_search(
address or self.router_ip, ttl=ttl, max_devices=max_devices
)
locations = []
for server_info in server_infos:
if server_info['st'] not in self._command_runners:
locations.append(server_info['location'])
gateway = Gateway(**server_info)
yield gateway.discover_services()
command_runner = SCPDCommandRunner(gateway)
yield command_runner.discover_commands()
self._command_runners[gateway.urn.decode()] = command_runner
defer.returnValue(len(self._command_runners))
def set_runner(self, urn):
if urn not in self._command_runners:
raise IndexError(urn)
self._command_runners = urn
def get_runner(self):
if self._selected_runner and self._command_runners and self._selected_runner not in self._command_runners:
self._selected_runner = self._command_runners.keys()[0]
return self._command_runners[self._selected_runner]
def get_available_runners(self):
return self._command_runners.keys()

View file

@ -1,15 +1,48 @@
import logging import logging
import binascii
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.protocol import DatagramProtocol from twisted.internet.protocol import DatagramProtocol
from txupnp.util import get_lan_info from txupnp.fault import UPnPError
from txupnp.constants import GATEWAY_SCHEMA, M_SEARCH_TEMPLATE, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT from txupnp.constants import GATEWAY_SCHEMA, M_SEARCH_TEMPLATE, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def parse_http_fields(content_lines):
return {
(k.lower().rstrip(":".encode()).replace("-".encode(), "_".encode())).decode(): v.decode()
for k, v in {
l.split(": ".encode())[0]: "".encode().join(l.split(": ".encode())[1:])
for l in content_lines
}.items() if k
}
def parse_ssdp_request(operation, port, protocol, content_lines):
if operation != "NOTIFY".encode():
log.warning("unsupported operation: %s", operation)
raise UPnPError("unsupported operation: %s" % operation)
if port != "*".encode():
log.warning("unexpected port: %s", port)
raise UPnPError("unexpected port: %s" % port)
return parse_http_fields(content_lines)
def parse_ssdp_response(code, response, content_lines):
try:
if int(code) != 200:
raise UPnPError("unexpected http response code: %i" % int(code))
except ValueError:
log.error(response)
raise UPnPError("unexpected http response code: %s" % code)
if response != "OK".encode():
raise UPnPError("unexpected response: %s" % response)
return parse_http_fields(content_lines)
class SSDPProtocol(DatagramProtocol): class SSDPProtocol(DatagramProtocol):
def __init__(self, reactor, finished_deferred, iface, router, ssdp_address=SSDP_IP_ADDRESS, def __init__(self, reactor, finished_deferred, iface, router, ssdp_address=SSDP_IP_ADDRESS,
ssdp_port=SSDP_PORT, ttl=1): ssdp_port=SSDP_PORT, ttl=1, max_devices=None):
self._reactor = reactor self._reactor = reactor
self._sem = defer.DeferredSemaphore(1) self._sem = defer.DeferredSemaphore(1)
self.finished_deferred = finished_deferred self.finished_deferred = finished_deferred
@ -19,41 +52,51 @@ class SSDPProtocol(DatagramProtocol):
self.ssdp_port = ssdp_port self.ssdp_port = ssdp_port
self.ttl = ttl self.ttl = ttl
self._start = None self._start = None
self.max_devices = max_devices
@staticmethod self.devices = []
def parse_ssdp_response(datagram):
lines = datagram.split("\r\n".encode())
if not lines:
return
protocol, code, response = lines[0].split(" ".encode())
if int(code) != 200:
raise Exception("unexpected http response code")
if response != "OK".encode():
raise Exception("unexpected response")
fields = {
k.lower(): v
for k, v in {
l.split(": ".encode())[0]: "".encode().join(l.split(": ".encode())[1:])
for l in lines[1:]
}.items() if k
}
return fields
def startProtocol(self): def startProtocol(self):
return self._sem.run(self.do_start) return self._sem.run(self.do_start)
def send_m_search(self):
data = M_SEARCH_TEMPLATE.format(self.ssdp_address, self.ssdp_port, GATEWAY_SCHEMA, SSDP_DISCOVER, self.ttl)
try:
log.info("sending m-search (%i bytes) to %s:%i", len(data), self.ssdp_address, self.ssdp_port)
self.transport.write(data.encode(), (self.ssdp_address, self.ssdp_port))
except Exception as err:
log.exception("failed to write %s to %s:%i", binascii.hexlify(data), self.ssdp_address, self.ssdp_port)
raise err
def parse_ssdp_datagram(self, datagram):
lines = datagram.split("\r\n".encode())
header_pieces = lines[0].split(" ".encode())
protocols = {
"HTTP/1.1".encode()
}
operations = {
"M-SEARCH".encode(),
"NOTIFY".encode()
}
if header_pieces[0] in operations:
if header_pieces[2] not in protocols:
raise UPnPError("unknown protocol: %s" % header_pieces[2])
return parse_ssdp_request(header_pieces[0], header_pieces[1], header_pieces[2], lines[1:])
if header_pieces[0] in protocols:
parsed = parse_ssdp_response(header_pieces[1], header_pieces[2], lines[1:])
log.info("received reply (%i bytes) to SSDP request (%f) (%s) %s", len(datagram),
self._reactor.seconds() - self._start, parsed['location'], parsed['server'])
return parsed
raise UPnPError("don't know how to decode datagram: %s" % binascii.hexlify(datagram))
def do_start(self): def do_start(self):
self._start = self._reactor.seconds() self._start = self._reactor.seconds()
self.finished_deferred.addTimeout(self.ttl, self._reactor) self.finished_deferred.addTimeout(self.ttl, self._reactor)
self.transport.setTTL(self.ttl) self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface) self.transport.joinGroup(self.ssdp_address, interface=self.iface)
data = M_SEARCH_TEMPLATE.format(self.ssdp_address, self.ssdp_port, GATEWAY_SCHEMA, SSDP_DISCOVER, self.ttl) self.send_m_search()
self.transport.write(data.encode(), (self.ssdp_address, self.ssdp_port))
def do_stop(self, gateway_xml_location): def leave_group(self):
self.transport.leaveGroup(self.ssdp_address, interface=self.iface) self.transport.leaveGroup(self.ssdp_address, interface=self.iface)
if not self.finished_deferred.called:
self.finished_deferred.callback(gateway_xml_location)
def datagramReceived(self, datagram, addr): def datagramReceived(self, datagram, addr):
self._sem.run(self.handle_datagram, datagram, addr) self._sem.run(self.handle_datagram, datagram, addr)
@ -61,28 +104,61 @@ class SSDPProtocol(DatagramProtocol):
def handle_datagram(self, datagram, address): def handle_datagram(self, datagram, address):
if address[0] == self.router: if address[0] == self.router:
try: try:
server_info = self.parse_ssdp_response(datagram) parsed = self.parse_ssdp_datagram(datagram)
except: self.devices.append(parsed)
log.exception("error parsing response: %s", datagram.encode('hex')) log.info("found %i/%s so far", len(self.devices), self.max_devices)
raise if not self.finished_deferred.called:
if server_info: if not self.max_devices or (self.max_devices and len(self.devices) >= self.max_devices):
log.info("received reply (%i bytes) to SSDP request (%fs)", len(datagram), self._sem.run(self.finished_deferred.callback, self.devices)
self._reactor.seconds() - self._start) except UPnPError as err:
self._sem.run(self.do_stop, server_info) log.error("error decoding SSDP response from %s:%s (error: %s)\n%s", address[0], address[1], str(err), binascii.hexlify(datagram))
elif address[0] != get_lan_info()[2]: raise err
log.info("received %i bytes from %s:%i", len(datagram), address[0], address[1]) elif address[0] != self.iface:
log.info("received %i bytes from %s:%i\n%s", len(datagram), address[0], address[1], binascii.hexlify(datagram))
else:
pass # loopback
class SSDPFactory(object): class SSDPFactory(object):
def __init__(self, lan_address, reactor): def __init__(self, lan_address, reactor):
self.lan_address = lan_address self.lan_address = lan_address
self._reactor = reactor self._reactor = reactor
self.protocol = None
self.port = None
self.finished_deferred = defer.Deferred()
def stop(self):
try:
self.protocol.leave_group()
self.port.stopListening()
except:
pass
def connect(self, address, ttl, max_devices=1):
self.protocol = SSDPProtocol(self._reactor, self.finished_deferred, self.lan_address, address, ttl=ttl,
max_devices=max_devices)
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)
self._reactor.addSystemEventTrigger("before", "shutdown", self.stop)
return self.finished_deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def m_search(self, address): def m_search(self, address, ttl=30, max_devices=2):
finished_d = defer.Deferred() """
ssdp_protocol = SSDPProtocol(self._reactor, finished_d, self.lan_address, address, ttl=30) Perform a HTTP over UDP M-SEARCH query
port = ssdp_protocol._reactor.listenMulticast(ssdp_protocol.ssdp_port, ssdp_protocol, listenMultiple=True)
server_info = yield finished_d returns (list) [{
port.stopListening() 'server: <gateway os and version string>
defer.returnValue(server_info) 'location': <upnp gateway url>,
'cache-control': <max age>,
'date': <server time>,
'usn': <usn>
}, ...]
"""
d = self.connect(address, ttl, max_devices=max_devices)
try:
server_infos = yield d
except defer.TimeoutError:
server_infos = self.protocol.devices
log.info("found %i devices", len(server_infos))
self.stop()
defer.returnValue(server_infos)

0
txupnp/tests/__init__.py Normal file
View file

View file

@ -7,7 +7,7 @@ log = logging.getLogger("txupnp")
@defer.inlineCallbacks @defer.inlineCallbacks
def test(ext_port=4446, int_port=4445, proto='UDP'): def test(ext_port=4446, int_port=4446, proto='UDP'):
u = UPnP(reactor) u = UPnP(reactor)
found = yield u.discover() found = yield u.discover()
assert found, "M-SEARCH failed to find gateway" assert found, "M-SEARCH failed to find gateway"
@ -36,6 +36,12 @@ def test(ext_port=4446, int_port=4445, proto='UDP'):
else: else:
log.error("failed to tear down redirect") log.error("failed to tear down redirect")
raise AssertionError() raise AssertionError()
r = yield u.get_rsip_nat_status()
log.info(r)
r = yield u.get_status_info()
log.info(r)
r = yield u.get_connection_type_info()
log.info(r)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -1,111 +1,49 @@
import logging import logging
from xml.etree import ElementTree
from twisted.internet import defer from twisted.internet import defer
import treq
from txupnp.fault import UPnPError from txupnp.fault import UPnPError
from txupnp.ssdp import SSDPFactory from txupnp.soap import SOAPServiceManager
from txupnp.scpd import SCPDCommandManager
from txupnp.util import get_lan_info, BASE_ADDRESS_REGEX, flatten_keys, etree_to_dict, DEVICE_ELEMENT_REGEX
from txupnp.util import find_inner_service_info, BASE_PORT_REGEX
from txupnp.constants import LAYER_FORWARD_KEY, WAN_INTERFACE_KEY, WAN_IP_KEY
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class Service(object):
def __init__(self, base_address, serviceId=None, SCPDURL=None, eventSubURL=None, controlURL=None, **kwargs):
self.base_address = base_address
self.service_id = serviceId
self._control_path = controlURL
self._subscribe_path = eventSubURL
self._scpd_path = SCPDURL
@property
def scpd_url(self):
return self.base_address.decode() + self._scpd_path
@property
def control_url(self):
return self.base_address.decode() + self._control_path
class UPnP(object): class UPnP(object):
def __init__(self, reactor): def __init__(self, reactor):
self._reactor = reactor self._reactor = reactor
self.iface_name, self.gateway_ip, self.lan_address = get_lan_info() self.soap_manager = SOAPServiceManager(reactor)
self._m_search_factory = SSDPFactory(self.lan_address, self._reactor)
self.gateway_url = ""
self.gateway_base = ""
self.gateway_port = None
self.layer_3_forwarding = None
self.wan_ip = None
self.wan_interface = None
self.commands = SCPDCommandManager(self)
def m_search(self, address): @property
def lan_address(self):
return self.soap_manager.lan_address
@property
def commands(self):
return self.soap_manager.get_runner()
def m_search(self, address, ttl=30, max_devices=2):
""" """
Perform a HTTP over UDP M-SEARCH query Perform a HTTP over UDP M-SEARCH query
returns (dict) { returns (list) [{
'server: <gateway os and version string>
'location': <upnp gateway url>, 'location': <upnp gateway url>,
'cache-control': <max age>, 'cache-control': <max age>,
'date': <server time>, 'date': <server time>,
'usn': <usn> 'usn': <usn>
} }, ...]
""" """
return self._m_search_factory.m_search(address) return self.soap_manager.sspd_factory.m_search(address, ttl=ttl, max_devices=max_devices)
@defer.inlineCallbacks @defer.inlineCallbacks
def _discover_gateway(self): def discover(self, ttl=30, max_devices=2):
server_info = yield self.m_search(self.gateway_ip)
if 'server'.encode() in server_info:
log.info("gateway version: %s", server_info['server'.encode()])
else:
log.info("discovered gateway")
self.gateway_url = server_info['location'.encode()]
self.gateway_base = BASE_ADDRESS_REGEX.findall(self.gateway_url)[0]
self.gateway_port = int(BASE_PORT_REGEX.findall(self.gateway_url)[0]) # the tcp port
response = yield treq.get(self.gateway_url)
response_xml = yield response.text()
elements = ElementTree.fromstring(response_xml)
for element in elements:
if DEVICE_ELEMENT_REGEX.findall(element.tag):
tag = DEVICE_ELEMENT_REGEX.findall(element.tag)[0]
prefix = tag[:-6]
device_info = flatten_keys(etree_to_dict(elements.find(tag)), prefix)
self.layer_3_forwarding = Service(self.gateway_base, **find_inner_service_info(
device_info['device']['serviceList']['service'], LAYER_FORWARD_KEY
)
)
self.wan_interface = Service(self.gateway_base, **find_inner_service_info(
device_info['device']['deviceList']['device']['serviceList']['service'], WAN_INTERFACE_KEY
)
)
self.wan_ip = Service(self.gateway_base, **find_inner_service_info(
device_info['device']['deviceList']['device']['deviceList']['device']['serviceList']['service'],
WAN_IP_KEY
)
)
defer.returnValue(None)
@defer.inlineCallbacks
def discover(self):
try: try:
yield self._discover_gateway() yield self.soap_manager.discover_services(ttl=ttl, max_devices=max_devices)
except defer.TimeoutError: except defer.TimeoutError:
log.warning("failed to find gateway") log.warning("failed to find upnp gateway")
defer.returnValue(False) defer.returnValue(False)
yield self.commands.discover_commands()
defer.returnValue(True) defer.returnValue(True)
def get_external_ip(self): def get_external_ip(self):
return self.commands.GetExternalIPAddress() return self.commands.GetExternalIPAddress()
#
# def GetStatusInfo(self):
# return self._commands['GetStatusInfo']()
#
# def GetConnectionTypeInfo(self):
# return self._commands['GetConnectionTypeInfo']()
def add_port_mapping(self, external_port, protocol, internal_port, lan_address, description, lease_duration): def add_port_mapping(self, external_port, protocol, internal_port, lan_address, description, lease_duration):
return self.commands.AddPortMapping( return self.commands.AddPortMapping(
@ -115,9 +53,6 @@ class UPnP(object):
NewLeaseDuration=lease_duration NewLeaseDuration=lease_duration
) )
# def GetNATRSIPStatus(self):
# return self._commands['GetNATRSIPStatus']()
def get_port_mapping_by_index(self, index): def get_port_mapping_by_index(self, index):
return self.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) return self.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
@ -135,14 +70,39 @@ class UPnP(object):
defer.returnValue(redirects) defer.returnValue(redirects)
def get_specific_port_mapping(self, external_port, protocol): def get_specific_port_mapping(self, external_port, protocol):
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
"""
return self.commands.GetSpecificPortMappingEntry( return self.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
) )
# def ForceTermination(self):
# return self._commands['ForceTermination']()
def delete_port_mapping(self, external_port, protocol): def delete_port_mapping(self, external_port, protocol):
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: None
"""
return self.commands.DeletePortMapping( return self.commands.DeletePortMapping(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
) )
def get_rsip_nat_status(self):
"""
:return: (bool) NewRSIPAvailable, (bool) NewNATEnabled
"""
return self.commands.GetNATRSIPStatus()
def get_status_info(self):
"""
:return: (str) NewConnectionStatus, (str) NewLastConnectionError, (int) NewUptime
"""
return self.commands.GetStatusInfo()
def get_connection_type_info(self):
"""
:return: (str) NewConnectionType (str), NewPossibleConnectionTypes (str)
"""
return self.commands.GetConnectionTypeInfo()

View file

@ -81,3 +81,5 @@ def return_types(*types):
none_or_str = lambda x: None if not x or x == 'None' else str(x) none_or_str = lambda x: None if not x or x == 'None' else str(x)
none = lambda _: None