This commit is contained in:
Jack Robison 2018-07-26 19:49:33 -04:00
commit 418c8c632b
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
10 changed files with 645 additions and 0 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
*.egg-info
*.pyc
_trial_temp/

16
setup.py Normal file
View file

@ -0,0 +1,16 @@
from setuptools import setup, find_packages
setup(
name="txupnp",
version="0.0.1",
author="Jack Robison",
author_email="jackrobison@lbry.io",
description="UPnP for twisted",
license='MIT',
packages=find_packages(),
install_requires=[
'Twisted',
'treq',
'netifaces'
],
)

55
tests/test_txupnp.py Normal file
View file

@ -0,0 +1,55 @@
import logging
from twisted.internet import reactor, defer
from txupnp.upnp import UPnP
from txupnp.fault import UPnPError
log = logging.getLogger("txupnp")
@defer.inlineCallbacks
def test(ext_port=4446, int_port=4445, proto='UDP'):
u = UPnP(reactor)
found = yield u.discover()
assert found, "M-SEARCH failed to find gateway"
external_ip = yield u.get_external_ip()
assert external_ip, "Failed to get the external IP"
log.info(external_ip)
try:
yield u.get_specific_port_mapping(ext_port, proto)
except UPnPError as err:
if 'NoSuchEntryInArray' in str(err):
pass
else:
log.error("there is already a redirect")
raise AssertionError()
yield u.add_port_mapping(ext_port, proto, int_port, u.lan_address, 'woah', 0)
redirects = yield u.get_redirects()
if (ext_port, u.lan_address, proto) in map(lambda x: (x[1], x[4], x[2]), redirects):
log.info("made redirect")
else:
log.error("failed to make redirect")
raise AssertionError()
yield u.delete_port_mapping(ext_port, proto)
redirects = yield u.get_redirects()
if (ext_port, u.lan_address, proto) not in map(lambda x: (x[1], x[4], x[2]), redirects):
log.info("tore down redirect")
else:
log.error("failed to tear down redirect")
raise AssertionError()
@defer.inlineCallbacks
def run_tests():
for p in ['UDP', 'TCP']:
yield test(proto=p)
def main():
d = run_tests()
d.addErrback(log.exception)
d.addBoth(lambda _: reactor.callLater(0, reactor.stop))
reactor.run()
if __name__ == "__main__":
main()

7
txupnp/__init__.py Normal file
View file

@ -0,0 +1,7 @@
import logging
log = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))
log.addHandler(handler)
log.setLevel(logging.INFO)

24
txupnp/constants.py Normal file
View file

@ -0,0 +1,24 @@
POST = "POST"
XML_VERSION = "<?xml version=\"1.0\"?>"
FAULT = "{http://schemas.xmlsoap.org/soap/envelope/}Fault"
ENVELOPE = "{http://schemas.xmlsoap.org/soap/envelope/}Envelope"
BODY = "{http://schemas.xmlsoap.org/soap/envelope/}Body"
SOAP_ENCODING = "http://schemas.xmlsoap.org/soap/encoding/"
SOAP_ENVELOPE = "http://schemas.xmlsoap.org/soap/envelope"
CONTROL_KEY = 'urn:schemas-upnp-org:control-1-0'
SERVICE_KEY = 'urn:schemas-upnp-org:service-1-0'
GATEWAY_SCHEMA = 'urn:schemas-upnp-org:device:InternetGatewayDevice:1'
WAN_INTERFACE_KEY = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1'
LAYER_FORWARD_KEY = 'urn:schemas-upnp-org:service:Layer3Forwarding:1'
WAN_IP_KEY = 'urn:schemas-upnp-org:service:WANIPConnection:1'
SSDP_IP_ADDRESS = '239.255.255.250'
SSDP_PORT = 1900
SSDP_DISCOVER = "ssdp:discover"
M_SEARCH_TEMPLATE = "\r\n".join([
"M-SEARCH * HTTP/1.1",
"HOST: {}:{}",
"ST: {}",
"MAN: \"{}\"",
"MX: {}\r\n\r\n",
])

13
txupnp/fault.py Normal file
View file

@ -0,0 +1,13 @@
from txupnp.util import flatten_keys
from txupnp.constants import FAULT, CONTROL_KEY
class UPnPError(Exception):
pass
def handle_fault(response):
if FAULT in response:
fault = flatten_keys(response[FAULT], "{%s}" % CONTROL_KEY)
raise UPnPError(fault['detail']['UPnPError']['errorDescription'])
return response

212
txupnp/scpd.py Normal file
View file

@ -0,0 +1,212 @@
import logging
from collections import namedtuple, OrderedDict
from twisted.internet import defer
from twisted.web.client import Agent, HTTPConnectionPool
import treq
from treq.client import HTTPClient
from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, return_types, _return_types, none_or_str
from txupnp.fault import handle_fault
from txupnp.constants import POST, ENVELOPE, BODY, XML_VERSION, WAN_IP_KEY, SERVICE_KEY, SSDP_IP_ADDRESS
log = logging.getLogger(__name__)
class StringProducer(object):
def __init__(self, body):
self.body = body
self.length = len(body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass
def xml_arg(name, arg):
return "<%s>%s</%s>" % (name, arg, name)
def get_soap_body(service_name, method, param_names, **kwargs):
args = "".join(xml_arg(n, kwargs.get(n)) for n in param_names)
return '\n%s\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body><u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (XML_VERSION, method, service_name, args, method)
class _SCPDCommand(object):
def __init__(self, gateway_address, service_port, control_url, service_id, method, param_names, returns,
reactor=None):
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
self._pool = HTTPConnectionPool(reactor)
self.agent = Agent(reactor, connectTimeout=1)
self._http_client = HTTPClient(self.agent, data_to_body_producer=StringProducer)
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
self.service_id = service_id
self.method = method
self.param_names = param_names
self.returns = returns
def extract_body(self, xml_response, service_key=WAN_IP_KEY):
content_dict = etree_to_dict(ElementTree.fromstring(xml_response))
envelope = content_dict[ENVELOPE]
return flatten_keys(envelope[BODY], "{%s}" % service_key)
def extract_response(self, body):
body = handle_fault(body) # raises UPnPError if there is a fault
if '%sResponse' % self.method in body:
response_key = '%sResponse' % self.method
else:
1/0
return
response = body[response_key]
extracted_response = tuple([response[n] for n in self.returns])
if len(extracted_response) == 1:
return extracted_response[0]
return extracted_response
@defer.inlineCallbacks
def send_upnp_soap(self, **kwargs):
soap_body = get_soap_body(self.service_id, self.method, self.param_names, **kwargs).encode()
headers = OrderedDict((
('SOAPAction', '%s#%s' % (self.service_id, self.method)),
('Host', ('%s:%i' % (SSDP_IP_ADDRESS, self.service_port))),
('Content-Type', 'text/xml'),
('Content-Length', len(soap_body))
)
)
response = yield self._http_client.request(
POST, url=self.control_url, data=soap_body, headers=headers
)
xml_response = yield response.content()
response = self.extract_response(self.extract_body(xml_response))
defer.returnValue(response)
@staticmethod
def _process_result(results):
"""
this method gets decorated automatically with a function that maps result types to the types
defined in the @return_types decorator
"""
return results
@defer.inlineCallbacks
def __call__(self, **kwargs):
if set(kwargs.keys()) != set(self.param_names):
raise Exception("argument mismatch")
response = yield self.send_upnp_soap(**kwargs)
result = self._process_result(response)
defer.returnValue(result)
class SCPDCommandManager(object):
def __init__(self, upnp):
self._upnp = upnp
@defer.inlineCallbacks
def discover_commands(self):
response = yield treq.get(self._upnp.wan_ip.scpd_url)
content = yield response.content()
tree = ElementTree.fromstring(content)
actions = flatten_keys(etree_to_dict(tree), "{%s}" % SERVICE_KEY)["scpd"]["actionList"]["action"]
for action_dict in actions:
self._register_command(action_dict)
log.info("registered %i commands", len(actions))
defer.returnValue(None)
@staticmethod
def _soap_function_info(action_dict):
if not action_dict.get('argumentList'):
return (
action_dict['name'],
[],
[]
)
arg_dicts = action_dict['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg, ew
arg_dicts = [arg_dicts]
return (
action_dict['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
)
def _register_command(self, action_info):
command = _SCPDCommand(self._upnp.gateway_ip, self._upnp.gateway_port, self._upnp.wan_ip.control_url,
self._upnp.wan_ip.service_id, *self._soap_function_info(action_info))
if not hasattr(self, command.method):
raise NotImplementedError(command.method)
current = getattr(self, command.method)
if hasattr(current, "_return_types"):
command._process_result = _return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__)
setattr(self, command.method, command)
@staticmethod
def AddPortMapping(NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient,
NewEnabled, NewPortMappingDescription, NewLeaseDuration):
"""Returns None"""
raise NotImplementedError()
@staticmethod
def GetNATRSIPStatus():
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@staticmethod
@return_types(none_or_str, int, str, int, str, bool, str, int)
def GetGenericPortMappingEntry(NewPortMappingIndex):
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
"""
raise NotImplementedError()
@staticmethod
@return_types(int, str, bool, str, int)
def GetSpecificPortMappingEntry(NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError()
@staticmethod
def SetConnectionType(NewConnectionType):
"""Returns None"""
raise NotImplementedError()
@staticmethod
def GetExternalIPAddress():
"""Returns (NewExternalIPAddress)"""
raise NotImplementedError()
@staticmethod
def GetConnectionTypeInfo():
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@staticmethod
def GetStatusInfo():
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError()
@staticmethod
def ForceTermination():
"""Returns None"""
raise NotImplementedError()
@staticmethod
def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns None"""
raise NotImplementedError()
@staticmethod
def RequestConnection():
"""Returns None"""
raise NotImplementedError()

84
txupnp/ssdp.py Normal file
View file

@ -0,0 +1,84 @@
import logging
from twisted.internet import defer
from twisted.internet.protocol import DatagramProtocol
from txupnp.util import get_lan_info
from txupnp.constants import GATEWAY_SCHEMA, M_SEARCH_TEMPLATE, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT
log = logging.getLogger(__name__)
class SSDPProtocol(DatagramProtocol):
def __init__(self, reactor, finished_deferred, iface, router, ssdp_address=SSDP_IP_ADDRESS,
ssdp_port=SSDP_PORT, ttl=1):
self._reactor = reactor
self._sem = defer.DeferredSemaphore(1)
self.finished_deferred = finished_deferred
self.iface = iface
self.router = router
self.ssdp_address = ssdp_address
self.ssdp_port = ssdp_port
self.ttl = ttl
self._start = None
@staticmethod
def parse_ssdp_response(datagram):
lines = datagram.split("\r\n".encode())
if not lines:
return
protocol, code, response = lines[0].split(" ".encode())
if int(code) != 200:
raise Exception("unexpected http response code")
if response != "OK".encode():
raise Exception("unexpected response")
fields = {
k.lower(): v
for k, v in {
l.split(": ".encode())[0]: "".encode().join(l.split(": ".encode())[1:])
for l in lines[1:]
}.items() if k
}
return fields
def startProtocol(self):
return self._sem.run(self.do_start)
def do_start(self):
self._start = self._reactor.seconds()
self.finished_deferred.addTimeout(self.ttl, self._reactor)
self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface)
data = M_SEARCH_TEMPLATE.format(self.ssdp_address, self.ssdp_port, GATEWAY_SCHEMA, SSDP_DISCOVER, self.ttl)
self.transport.write(data.encode(), (self.ssdp_address, self.ssdp_port))
def do_stop(self, gateway_xml_location):
self.transport.leaveGroup(self.ssdp_address, interface=self.iface)
if not self.finished_deferred.called:
self.finished_deferred.callback(gateway_xml_location)
def datagramReceived(self, datagram, addr):
self._sem.run(self.handle_datagram, datagram, addr)
def handle_datagram(self, datagram, address):
if address[0] == self.router:
server_info = self.parse_ssdp_response(datagram)
if server_info:
log.info("received reply (%i bytes) to SSDP request (%fs)", len(datagram),
self._reactor.seconds() - self._start)
self._sem.run(self.do_stop, server_info)
elif address[0] != get_lan_info()[2]:
log.info("received %i bytes from %s:%i", len(datagram), address[0], address[1])
class SSDPFactory(object):
def __init__(self, lan_address, reactor):
self.lan_address = lan_address
self._reactor = reactor
@defer.inlineCallbacks
def m_search(self, address):
finished_d = defer.Deferred()
ssdp_protocol = SSDPProtocol(self._reactor, finished_d, self.lan_address, address, ttl=30)
port = ssdp_protocol._reactor.listenMulticast(ssdp_protocol.ssdp_port, ssdp_protocol, listenMultiple=True)
server_info = yield finished_d
port.stopListening()
defer.returnValue(server_info)

148
txupnp/upnp.py Normal file
View file

@ -0,0 +1,148 @@
import logging
from xml.etree import ElementTree
from twisted.internet import defer
import treq
from txupnp.fault import UPnPError
from txupnp.ssdp import SSDPFactory
from txupnp.scpd import SCPDCommandManager
from txupnp.util import get_lan_info, BASE_ADDRESS_REGEX, flatten_keys, etree_to_dict, DEVICE_ELEMENT_REGEX
from txupnp.util import find_inner_service_info, BASE_PORT_REGEX
from txupnp.constants import LAYER_FORWARD_KEY, WAN_INTERFACE_KEY, WAN_IP_KEY
log = logging.getLogger(__name__)
class Service(object):
def __init__(self, base_address, serviceId=None, SCPDURL=None, eventSubURL=None, controlURL=None, **kwargs):
self.base_address = base_address
self.service_id = serviceId
self._control_path = controlURL
self._subscribe_path = eventSubURL
self._scpd_path = SCPDURL
@property
def scpd_url(self):
return self.base_address.decode() + self._scpd_path
@property
def control_url(self):
return self.base_address.decode() + self._control_path
class UPnP(object):
def __init__(self, reactor):
self._reactor = reactor
self.iface_name, self.gateway_ip, self.lan_address = get_lan_info()
self._m_search_factory = SSDPFactory(self.lan_address, self._reactor)
self.gateway_url = ""
self.gateway_base = ""
self.gateway_port = None
self.layer_3_forwarding = None
self.wan_ip = None
self.wan_interface = None
self.commands = SCPDCommandManager(self)
def m_search(self, address):
"""
Perform a HTTP over UDP M-SEARCH query
returns (dict) {
'location': <upnp gateway url>,
'cache-control': <max age>,
'date': <server time>,
'usn': <usn>
}
"""
return self._m_search_factory.m_search(address)
@defer.inlineCallbacks
def _discover_gateway(self):
server_info = yield self.m_search(self.gateway_ip)
if 'server'.encode() in server_info:
log.info("gateway version: %s", server_info['server'.encode()])
else:
log.info("discovered gateway")
self.gateway_url = server_info['location'.encode()]
self.gateway_base = BASE_ADDRESS_REGEX.findall(self.gateway_url)[0]
self.gateway_port = int(BASE_PORT_REGEX.findall(self.gateway_url)[0]) # the tcp port
response = yield treq.get(self.gateway_url)
response_xml = yield response.text()
elements = ElementTree.fromstring(response_xml)
for element in elements:
if DEVICE_ELEMENT_REGEX.findall(element.tag):
tag = DEVICE_ELEMENT_REGEX.findall(element.tag)[0]
prefix = tag[:-6]
device_info = flatten_keys(etree_to_dict(elements.find(tag)), prefix)
self.layer_3_forwarding = Service(self.gateway_base, **find_inner_service_info(
device_info['device']['serviceList']['service'], LAYER_FORWARD_KEY
)
)
self.wan_interface = Service(self.gateway_base, **find_inner_service_info(
device_info['device']['deviceList']['device']['serviceList']['service'], WAN_INTERFACE_KEY
)
)
self.wan_ip = Service(self.gateway_base, **find_inner_service_info(
device_info['device']['deviceList']['device']['deviceList']['device']['serviceList']['service'],
WAN_IP_KEY
)
)
defer.returnValue(None)
@defer.inlineCallbacks
def discover(self):
try:
yield self._discover_gateway()
except defer.TimeoutError:
log.warning("failed to find gateway")
defer.returnValue(False)
yield self.commands.discover_commands()
defer.returnValue(True)
def get_external_ip(self):
return self.commands.GetExternalIPAddress()
#
# def GetStatusInfo(self):
# return self._commands['GetStatusInfo']()
#
# def GetConnectionTypeInfo(self):
# return self._commands['GetConnectionTypeInfo']()
def add_port_mapping(self, external_port, protocol, internal_port, lan_address, description, lease_duration):
return self.commands.AddPortMapping(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description,
NewLeaseDuration=lease_duration
)
# def GetNATRSIPStatus(self):
# return self._commands['GetNATRSIPStatus']()
def get_port_mapping_by_index(self, index):
return self.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
@defer.inlineCallbacks
def get_redirects(self):
redirects = []
cnt = 0
while True:
try:
redirect = yield self.get_port_mapping_by_index(cnt)
redirects.append(redirect)
cnt += 1
except UPnPError:
break
defer.returnValue(redirects)
def get_specific_port_mapping(self, external_port, protocol):
return self.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
)
# def ForceTermination(self):
# return self._commands['ForceTermination']()
def delete_port_mapping(self, external_port, protocol):
return self.commands.DeletePortMapping(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
)

83
txupnp/util.py Normal file
View file

@ -0,0 +1,83 @@
import re
import functools
from collections import defaultdict
import netifaces
DEVICE_ELEMENT_REGEX = re.compile("^\{urn:schemas-upnp-org:device-\d-\d\}device$")
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
def etree_to_dict(t):
d = {t.tag: {} if t.attrib else None}
children = list(t)
if children:
dd = defaultdict(list)
for dc in map(etree_to_dict, children):
for k, v in dc.items():
dd[k].append(v)
d = {t.tag: {k: v[0] if len(v) == 1 else v for k, v in dd.items()}}
if t.attrib:
d[t.tag].update(('@' + k, v) for k, v in t.attrib.items())
if t.text:
text = t.text.strip()
if children or t.attrib:
if text:
d[t.tag]['#text'] = text
else:
d[t.tag] = text
return d
def flatten_keys(d, strip):
if not isinstance(d, (list, dict)):
return d
if isinstance(d, list):
return [flatten_keys(i, strip) for i in d]
t = {}
for k, v in d.items():
if strip in k and strip != k:
t[k.split(strip)[1]] = flatten_keys(v, strip)
else:
t[k] = flatten_keys(v, strip)
return t
def get_lan_info():
gateway_address, iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
lan_addr = netifaces.ifaddresses(iface_name)[netifaces.AF_INET][0]['addr']
return iface_name, gateway_address, lan_addr
def find_inner_service_info(service, name):
if isinstance(service, dict):
return service
for s in service:
if name == s['serviceType']:
return s
raise IndexError(name)
def _return_types(*types):
def _return_types_wrapper(fn):
@functools.wraps(fn)
def _inner(response):
if isinstance(response, (list, tuple)):
r = tuple(t(r) for t, r in zip(types, response))
if len(r) == 1:
return fn(r[0])
return fn(r)
return fn(types[0](response))
return _inner
return _return_types_wrapper
def return_types(*types):
def return_types_wrapper(fn):
fn._return_types = types
return fn
return return_types_wrapper
none_or_str = lambda x: None if not x or x == 'None' else str(x)