commit 418c8c632b1d800b6040d5629ff3be96d7164f44 Author: Jack Robison Date: Thu Jul 26 19:49:33 2018 -0400 initial diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b10db10 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*.egg-info +*.pyc +_trial_temp/ \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..88df1bc --- /dev/null +++ b/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup, find_packages + +setup( + name="txupnp", + version="0.0.1", + author="Jack Robison", + author_email="jackrobison@lbry.io", + description="UPnP for twisted", + license='MIT', + packages=find_packages(), + install_requires=[ + 'Twisted', + 'treq', + 'netifaces' + ], +) diff --git a/tests/test_txupnp.py b/tests/test_txupnp.py new file mode 100644 index 0000000..9a4785d --- /dev/null +++ b/tests/test_txupnp.py @@ -0,0 +1,55 @@ +import logging +from twisted.internet import reactor, defer +from txupnp.upnp import UPnP +from txupnp.fault import UPnPError + +log = logging.getLogger("txupnp") + + +@defer.inlineCallbacks +def test(ext_port=4446, int_port=4445, proto='UDP'): + u = UPnP(reactor) + found = yield u.discover() + assert found, "M-SEARCH failed to find gateway" + external_ip = yield u.get_external_ip() + assert external_ip, "Failed to get the external IP" + log.info(external_ip) + try: + yield u.get_specific_port_mapping(ext_port, proto) + except UPnPError as err: + if 'NoSuchEntryInArray' in str(err): + pass + else: + log.error("there is already a redirect") + raise AssertionError() + yield u.add_port_mapping(ext_port, proto, int_port, u.lan_address, 'woah', 0) + redirects = yield u.get_redirects() + if (ext_port, u.lan_address, proto) in map(lambda x: (x[1], x[4], x[2]), redirects): + log.info("made redirect") + else: + log.error("failed to make redirect") + raise AssertionError() + yield u.delete_port_mapping(ext_port, proto) + redirects = yield u.get_redirects() + if (ext_port, u.lan_address, proto) not in map(lambda x: (x[1], x[4], x[2]), redirects): + log.info("tore down redirect") + else: + log.error("failed to tear down redirect") + raise AssertionError() + + +@defer.inlineCallbacks +def run_tests(): + for p in ['UDP', 'TCP']: + yield test(proto=p) + + +def main(): + d = run_tests() + d.addErrback(log.exception) + d.addBoth(lambda _: reactor.callLater(0, reactor.stop)) + reactor.run() + + +if __name__ == "__main__": + main() diff --git a/txupnp/__init__.py b/txupnp/__init__.py new file mode 100644 index 0000000..2bb3785 --- /dev/null +++ b/txupnp/__init__.py @@ -0,0 +1,7 @@ +import logging + +log = logging.getLogger(__name__) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s')) +log.addHandler(handler) +log.setLevel(logging.INFO) diff --git a/txupnp/constants.py b/txupnp/constants.py new file mode 100644 index 0000000..14a150f --- /dev/null +++ b/txupnp/constants.py @@ -0,0 +1,24 @@ +POST = "POST" +XML_VERSION = "" +FAULT = "{http://schemas.xmlsoap.org/soap/envelope/}Fault" +ENVELOPE = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope" +BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body" +SOAP_ENCODING = "http://schemas.xmlsoap.org/soap/encoding/" +SOAP_ENVELOPE = "http://schemas.xmlsoap.org/soap/envelope" +CONTROL_KEY = 'urn:schemas-upnp-org:control-1-0' +SERVICE_KEY = 'urn:schemas-upnp-org:service-1-0' +GATEWAY_SCHEMA = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1' +WAN_INTERFACE_KEY = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1' +LAYER_FORWARD_KEY = 'urn:schemas-upnp-org:service:Layer3Forwarding:1' +WAN_IP_KEY = 'urn:schemas-upnp-org:service:WANIPConnection:1' + +SSDP_IP_ADDRESS = '239.255.255.250' +SSDP_PORT = 1900 +SSDP_DISCOVER = "ssdp:discover" +M_SEARCH_TEMPLATE = "\r\n".join([ + "M-SEARCH * HTTP/1.1", + "HOST: {}:{}", + "ST: {}", + "MAN: \"{}\"", + "MX: {}\r\n\r\n", +]) diff --git a/txupnp/fault.py b/txupnp/fault.py new file mode 100644 index 0000000..3a6a94e --- /dev/null +++ b/txupnp/fault.py @@ -0,0 +1,13 @@ +from txupnp.util import flatten_keys +from txupnp.constants import FAULT, CONTROL_KEY + + +class UPnPError(Exception): + pass + + +def handle_fault(response): + if FAULT in response: + fault = flatten_keys(response[FAULT], "{%s}" % CONTROL_KEY) + raise UPnPError(fault['detail']['UPnPError']['errorDescription']) + return response diff --git a/txupnp/scpd.py b/txupnp/scpd.py new file mode 100644 index 0000000..6ad4522 --- /dev/null +++ b/txupnp/scpd.py @@ -0,0 +1,212 @@ +import logging +from collections import namedtuple, OrderedDict +from twisted.internet import defer +from twisted.web.client import Agent, HTTPConnectionPool +import treq +from treq.client import HTTPClient +from xml.etree import ElementTree +from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str +from txupnp.fault import handle_fault +from txupnp.constants import POST, ENVELOPE, BODY, XML_VERSION, WAN_IP_KEY, SERVICE_KEY, SSDP_IP_ADDRESS + +log = logging.getLogger(__name__) + + +class StringProducer(object): + def __init__(self, body): + self.body = body + self.length = len(body) + + def startProducing(self, consumer): + consumer.write(self.body) + return defer.succeed(None) + + def pauseProducing(self): + pass + + def stopProducing(self): + pass + + +def xml_arg(name, arg): + return "<%s>%s" % (name, arg, name) + + +def get_soap_body(service_name, method, param_names, **kwargs): + args = "".join(xml_arg(n, kwargs.get(n)) for n in param_names) + return '\n%s\n%s' % (XML_VERSION, method, service_name, args, method) + + +class _SCPDCommand(object): + def __init__(self, gateway_address, service_port, control_url, service_id, method, param_names, returns, + reactor=None): + if not reactor: + from twisted.internet import reactor + self._reactor = reactor + self._pool = HTTPConnectionPool(reactor) + self.agent = Agent(reactor, connectTimeout=1) + self._http_client = HTTPClient(self.agent, data_to_body_producer=StringProducer) + self.gateway_address = gateway_address + self.service_port = service_port + self.control_url = control_url + self.service_id = service_id + self.method = method + self.param_names = param_names + self.returns = returns + + def extract_body(self, xml_response, service_key=WAN_IP_KEY): + content_dict = etree_to_dict(ElementTree.fromstring(xml_response)) + envelope = content_dict[ENVELOPE] + return flatten_keys(envelope[BODY], "{%s}" % service_key) + + def extract_response(self, body): + body = handle_fault(body) # raises UPnPError if there is a fault + if '%sResponse' % self.method in body: + response_key = '%sResponse' % self.method + else: + 1/0 + return + + response = body[response_key] + extracted_response = tuple([response[n] for n in self.returns]) + if len(extracted_response) == 1: + return extracted_response[0] + return extracted_response + + @defer.inlineCallbacks + def send_upnp_soap(self, **kwargs): + soap_body = get_soap_body(self.service_id, self.method, self.param_names, **kwargs).encode() + headers = OrderedDict(( + ('SOAPAction', '%s#%s' % (self.service_id, self.method)), + ('Host', ('%s:%i' % (SSDP_IP_ADDRESS, self.service_port))), + ('Content-Type', 'text/xml'), + ('Content-Length', len(soap_body)) + ) + ) + response = yield self._http_client.request( + POST, url=self.control_url, data=soap_body, headers=headers + ) + xml_response = yield response.content() + response = self.extract_response(self.extract_body(xml_response)) + defer.returnValue(response) + + @staticmethod + def _process_result(results): + """ + this method gets decorated automatically with a function that maps result types to the types + defined in the @return_types decorator + """ + return results + + @defer.inlineCallbacks + def __call__(self, **kwargs): + if set(kwargs.keys()) != set(self.param_names): + raise Exception("argument mismatch") + response = yield self.send_upnp_soap(**kwargs) + result = self._process_result(response) + defer.returnValue(result) + + +class SCPDCommandManager(object): + def __init__(self, upnp): + self._upnp = upnp + + @defer.inlineCallbacks + def discover_commands(self): + response = yield treq.get(self._upnp.wan_ip.scpd_url) + content = yield response.content() + tree = ElementTree.fromstring(content) + actions = flatten_keys(etree_to_dict(tree), "{%s}" % SERVICE_KEY)["scpd"]["actionList"]["action"] + for action_dict in actions: + self._register_command(action_dict) + log.info("registered %i commands", len(actions)) + defer.returnValue(None) + + @staticmethod + def _soap_function_info(action_dict): + if not action_dict.get('argumentList'): + return ( + action_dict['name'], + [], + [] + ) + arg_dicts = action_dict['argumentList']['argument'] + if not isinstance(arg_dicts, list): # when there is one arg, ew + arg_dicts = [arg_dicts] + return ( + action_dict['name'], + [i['name'] for i in arg_dicts if i['direction'] == 'in'], + [i['name'] for i in arg_dicts if i['direction'] == 'out'] + ) + + def _register_command(self, action_info): + command = _SCPDCommand(self._upnp.gateway_ip, self._upnp.gateway_port, self._upnp.wan_ip.control_url, + self._upnp.wan_ip.service_id, *self._soap_function_info(action_info)) + if not hasattr(self, command.method): + raise NotImplementedError(command.method) + current = getattr(self, command.method) + if hasattr(current, "_return_types"): + command._process_result = _return_types(*current._return_types)(command._process_result) + setattr(command, "__doc__", current.__doc__) + setattr(self, command.method, command) + + @staticmethod + def AddPortMapping(NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, + NewEnabled, NewPortMappingDescription, NewLeaseDuration): + """Returns None""" + raise NotImplementedError() + + @staticmethod + def GetNATRSIPStatus(): + """Returns (NewRSIPAvailable, NewNATEnabled)""" + raise NotImplementedError() + + @staticmethod + @return_types(none_or_str, int, str, int, str, bool, str, int) + def GetGenericPortMappingEntry(NewPortMappingIndex): + """ + Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled, + NewPortMappingDescription, NewLeaseDuration) + """ + raise NotImplementedError() + + @staticmethod + @return_types(int, str, bool, str, int) + def GetSpecificPortMappingEntry(NewRemoteHost, NewExternalPort, NewProtocol): + """Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)""" + raise NotImplementedError() + + @staticmethod + def SetConnectionType(NewConnectionType): + """Returns None""" + raise NotImplementedError() + + @staticmethod + def GetExternalIPAddress(): + """Returns (NewExternalIPAddress)""" + raise NotImplementedError() + + @staticmethod + def GetConnectionTypeInfo(): + """Returns (NewConnectionType, NewPossibleConnectionTypes)""" + raise NotImplementedError() + + @staticmethod + def GetStatusInfo(): + """Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)""" + raise NotImplementedError() + + @staticmethod + def ForceTermination(): + """Returns None""" + raise NotImplementedError() + + @staticmethod + def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol): + """Returns None""" + raise NotImplementedError() + + @staticmethod + def RequestConnection(): + """Returns None""" + raise NotImplementedError() diff --git a/txupnp/ssdp.py b/txupnp/ssdp.py new file mode 100644 index 0000000..83cd686 --- /dev/null +++ b/txupnp/ssdp.py @@ -0,0 +1,84 @@ +import logging +from twisted.internet import defer +from twisted.internet.protocol import DatagramProtocol +from txupnp.util import get_lan_info +from txupnp.constants import GATEWAY_SCHEMA, M_SEARCH_TEMPLATE, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT + +log = logging.getLogger(__name__) + + +class SSDPProtocol(DatagramProtocol): + def __init__(self, reactor, finished_deferred, iface, router, ssdp_address=SSDP_IP_ADDRESS, + ssdp_port=SSDP_PORT, ttl=1): + self._reactor = reactor + self._sem = defer.DeferredSemaphore(1) + self.finished_deferred = finished_deferred + self.iface = iface + self.router = router + self.ssdp_address = ssdp_address + self.ssdp_port = ssdp_port + self.ttl = ttl + self._start = None + + @staticmethod + def parse_ssdp_response(datagram): + lines = datagram.split("\r\n".encode()) + if not lines: + return + protocol, code, response = lines[0].split(" ".encode()) + if int(code) != 200: + raise Exception("unexpected http response code") + if response != "OK".encode(): + raise Exception("unexpected response") + fields = { + k.lower(): v + for k, v in { + l.split(": ".encode())[0]: "".encode().join(l.split(": ".encode())[1:]) + for l in lines[1:] + }.items() if k + } + return fields + + def startProtocol(self): + return self._sem.run(self.do_start) + + def do_start(self): + self._start = self._reactor.seconds() + self.finished_deferred.addTimeout(self.ttl, self._reactor) + self.transport.setTTL(self.ttl) + self.transport.joinGroup(self.ssdp_address, interface=self.iface) + data = M_SEARCH_TEMPLATE.format(self.ssdp_address, self.ssdp_port, GATEWAY_SCHEMA, SSDP_DISCOVER, self.ttl) + self.transport.write(data.encode(), (self.ssdp_address, self.ssdp_port)) + + def do_stop(self, gateway_xml_location): + self.transport.leaveGroup(self.ssdp_address, interface=self.iface) + if not self.finished_deferred.called: + self.finished_deferred.callback(gateway_xml_location) + + def datagramReceived(self, datagram, addr): + self._sem.run(self.handle_datagram, datagram, addr) + + def handle_datagram(self, datagram, address): + if address[0] == self.router: + server_info = self.parse_ssdp_response(datagram) + if server_info: + log.info("received reply (%i bytes) to SSDP request (%fs)", len(datagram), + self._reactor.seconds() - self._start) + self._sem.run(self.do_stop, server_info) + elif address[0] != get_lan_info()[2]: + log.info("received %i bytes from %s:%i", len(datagram), address[0], address[1]) + + +class SSDPFactory(object): + def __init__(self, lan_address, reactor): + self.lan_address = lan_address + self._reactor = reactor + + @defer.inlineCallbacks + def m_search(self, address): + finished_d = defer.Deferred() + ssdp_protocol = SSDPProtocol(self._reactor, finished_d, self.lan_address, address, ttl=30) + port = ssdp_protocol._reactor.listenMulticast(ssdp_protocol.ssdp_port, ssdp_protocol, listenMultiple=True) + server_info = yield finished_d + port.stopListening() + defer.returnValue(server_info) diff --git a/txupnp/upnp.py b/txupnp/upnp.py new file mode 100644 index 0000000..78f8035 --- /dev/null +++ b/txupnp/upnp.py @@ -0,0 +1,148 @@ +import logging +from xml.etree import ElementTree +from twisted.internet import defer +import treq +from txupnp.fault import UPnPError +from txupnp.ssdp import SSDPFactory +from txupnp.scpd import SCPDCommandManager +from txupnp.util import get_lan_info, BASE_ADDRESS_REGEX, flatten_keys, etree_to_dict, DEVICE_ELEMENT_REGEX +from txupnp.util import find_inner_service_info, BASE_PORT_REGEX +from txupnp.constants import LAYER_FORWARD_KEY, WAN_INTERFACE_KEY, WAN_IP_KEY + +log = logging.getLogger(__name__) + + +class Service(object): + def __init__(self, base_address, serviceId=None, SCPDURL=None, eventSubURL=None, controlURL=None, **kwargs): + self.base_address = base_address + self.service_id = serviceId + self._control_path = controlURL + self._subscribe_path = eventSubURL + self._scpd_path = SCPDURL + + @property + def scpd_url(self): + return self.base_address.decode() + self._scpd_path + + @property + def control_url(self): + return self.base_address.decode() + self._control_path + + +class UPnP(object): + def __init__(self, reactor): + self._reactor = reactor + self.iface_name, self.gateway_ip, self.lan_address = get_lan_info() + self._m_search_factory = SSDPFactory(self.lan_address, self._reactor) + self.gateway_url = "" + self.gateway_base = "" + self.gateway_port = None + self.layer_3_forwarding = None + self.wan_ip = None + self.wan_interface = None + self.commands = SCPDCommandManager(self) + + def m_search(self, address): + """ + Perform a HTTP over UDP M-SEARCH query + + returns (dict) { + 'location': , + 'cache-control': , + 'date': , + 'usn': + } + """ + return self._m_search_factory.m_search(address) + + @defer.inlineCallbacks + def _discover_gateway(self): + server_info = yield self.m_search(self.gateway_ip) + if 'server'.encode() in server_info: + log.info("gateway version: %s", server_info['server'.encode()]) + else: + log.info("discovered gateway") + self.gateway_url = server_info['location'.encode()] + self.gateway_base = BASE_ADDRESS_REGEX.findall(self.gateway_url)[0] + self.gateway_port = int(BASE_PORT_REGEX.findall(self.gateway_url)[0]) # the tcp port + response = yield treq.get(self.gateway_url) + response_xml = yield response.text() + elements = ElementTree.fromstring(response_xml) + for element in elements: + if DEVICE_ELEMENT_REGEX.findall(element.tag): + tag = DEVICE_ELEMENT_REGEX.findall(element.tag)[0] + prefix = tag[:-6] + device_info = flatten_keys(etree_to_dict(elements.find(tag)), prefix) + self.layer_3_forwarding = Service(self.gateway_base, **find_inner_service_info( + device_info['device']['serviceList']['service'], LAYER_FORWARD_KEY + ) + ) + self.wan_interface = Service(self.gateway_base, **find_inner_service_info( + device_info['device']['deviceList']['device']['serviceList']['service'], WAN_INTERFACE_KEY + ) + ) + self.wan_ip = Service(self.gateway_base, **find_inner_service_info( + device_info['device']['deviceList']['device']['deviceList']['device']['serviceList']['service'], + WAN_IP_KEY + ) + ) + defer.returnValue(None) + + @defer.inlineCallbacks + def discover(self): + try: + yield self._discover_gateway() + except defer.TimeoutError: + log.warning("failed to find gateway") + defer.returnValue(False) + yield self.commands.discover_commands() + defer.returnValue(True) + + def get_external_ip(self): + return self.commands.GetExternalIPAddress() + # + # def GetStatusInfo(self): + # return self._commands['GetStatusInfo']() + # + # def GetConnectionTypeInfo(self): + # return self._commands['GetConnectionTypeInfo']() + + def add_port_mapping(self, external_port, protocol, internal_port, lan_address, description, lease_duration): + return self.commands.AddPortMapping( + NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol, + NewInternalPort=internal_port, NewInternalClient=lan_address, + NewEnabled=1, NewPortMappingDescription=description, + NewLeaseDuration=lease_duration + ) + + # def GetNATRSIPStatus(self): + # return self._commands['GetNATRSIPStatus']() + + def get_port_mapping_by_index(self, index): + return self.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) + + @defer.inlineCallbacks + def get_redirects(self): + redirects = [] + cnt = 0 + while True: + try: + redirect = yield self.get_port_mapping_by_index(cnt) + redirects.append(redirect) + cnt += 1 + except UPnPError: + break + defer.returnValue(redirects) + + def get_specific_port_mapping(self, external_port, protocol): + return self.commands.GetSpecificPortMappingEntry( + NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol + ) + + # def ForceTermination(self): + # return self._commands['ForceTermination']() + + def delete_port_mapping(self, external_port, protocol): + return self.commands.DeletePortMapping( + NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol + ) diff --git a/txupnp/util.py b/txupnp/util.py new file mode 100644 index 0000000..ea59637 --- /dev/null +++ b/txupnp/util.py @@ -0,0 +1,83 @@ +import re +import functools +from collections import defaultdict +import netifaces + +DEVICE_ELEMENT_REGEX = re.compile("^\{urn:schemas-upnp-org:device-\d-\d\}device$") +BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) +BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode()) + + +def etree_to_dict(t): + d = {t.tag: {} if t.attrib else None} + children = list(t) + if children: + dd = defaultdict(list) + for dc in map(etree_to_dict, children): + for k, v in dc.items(): + dd[k].append(v) + d = {t.tag: {k: v[0] if len(v) == 1 else v for k, v in dd.items()}} + if t.attrib: + d[t.tag].update(('@' + k, v) for k, v in t.attrib.items()) + if t.text: + text = t.text.strip() + if children or t.attrib: + if text: + d[t.tag]['#text'] = text + else: + d[t.tag] = text + return d + + +def flatten_keys(d, strip): + if not isinstance(d, (list, dict)): + return d + if isinstance(d, list): + return [flatten_keys(i, strip) for i in d] + t = {} + for k, v in d.items(): + if strip in k and strip != k: + t[k.split(strip)[1]] = flatten_keys(v, strip) + else: + t[k] = flatten_keys(v, strip) + return t + + +def get_lan_info(): + gateway_address, iface_name = netifaces.gateways()['default'][netifaces.AF_INET] + lan_addr = netifaces.ifaddresses(iface_name)[netifaces.AF_INET][0]['addr'] + return iface_name, gateway_address, lan_addr + + +def find_inner_service_info(service, name): + if isinstance(service, dict): + return service + for s in service: + if name == s['serviceType']: + return s + raise IndexError(name) + + +def _return_types(*types): + def _return_types_wrapper(fn): + @functools.wraps(fn) + def _inner(response): + if isinstance(response, (list, tuple)): + r = tuple(t(r) for t, r in zip(types, response)) + if len(r) == 1: + return fn(r[0]) + return fn(r) + return fn(types[0](response)) + return _inner + return _return_types_wrapper + + +def return_types(*types): + def return_types_wrapper(fn): + fn._return_types = types + return fn + + return return_types_wrapper + + +none_or_str = lambda x: None if not x or x == 'None' else str(x)