This commit is contained in:
Jack Robison 2018-10-01 09:51:31 -04:00
parent e7f72149c2
commit 69f35aa54b
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
18 changed files with 1017 additions and 948 deletions

1
.gitignore vendored
View file

@ -3,3 +3,4 @@
_trial_temp/ _trial_temp/
build/ build/
dist/ dist/
.coverage

View file

@ -1,6 +1,6 @@
# UPnP for Twisted # UPnP for Twisted
`txupnp` is a python2/3 library to interact with UPnP gateways using Twisted `txupnp` is a python 3 library to interact with UPnP gateways using `twisted`
## Installation ## Installation

View file

@ -21,7 +21,7 @@ setup(
long_description=long_description, long_description=long_description,
url="https://github.com/lbryio/txupnp", url="https://github.com/lbryio/txupnp",
license=__license__, license=__license__,
packages=find_packages(), packages=find_packages(exclude=['tests']),
entry_points={'console_scripts': console_scripts}, entry_points={'console_scripts': console_scripts},
install_requires=[ install_requires=[
'twisted[tls]', 'twisted[tls]',

View file

@ -0,0 +1,7 @@
import logging
log = logging.getLogger("txupnp")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))
log.addHandler(handler)
log.setLevel(logging.INFO)

File diff suppressed because one or more lines are too long

105
tests/test_devices.py Normal file
View file

@ -0,0 +1,105 @@
from twisted.internet import reactor, defer
from twisted.trial import unittest
from txupnp.constants import SSDP_PORT, SSDP_IP_ADDRESS
from txupnp.upnp import UPnP
from txupnp.scpd import SCPDCommand
from txupnp.gateway import Service
from txupnp.fault import UPnPError
from txupnp.mocks import MockReactor, MockSSDPServiceGatewayProtocol, get_device_test_case
from txupnp.util import verify_return_types
class TestDevice(unittest.TestCase):
manufacturer, model = "Cisco", "CGA4131COM"
device = get_device_test_case(manufacturer, model)
router_address = device.device_dict['router_address']
client_address = device.device_dict['client_address']
expected_devices = device.device_dict['expected_devices']
packets_rx = device.device_dict['ssdp']['received']
packets_tx = device.device_dict['ssdp']['sent']
expected_available_commands = device.device_dict['commands']['available']
scdp_packets = device.device_dict['scpd']
def setUp(self):
fake_reactor = MockReactor(self.client_address, self.scdp_packets)
reactor.listenMulticast = fake_reactor.listenMulticast
self.reactor = reactor
server_protocol = MockSSDPServiceGatewayProtocol(
self.client_address, self.router_address, self.packets_rx, self.packets_tx
)
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
self.upnp = UPnP(
self.reactor, debug_ssdp=True, router_ip=self.router_address,
lan_ip=self.client_address, iface_name='mock'
)
def tearDown(self):
self.upnp.sspd_factory.disconnect()
self.server_port.stopListening()
class TestSSDP(TestDevice):
@defer.inlineCallbacks
def test_discover_device(self):
result = yield self.upnp.m_search(self.router_address, timeout=1)
self.assertEqual(len(self.expected_devices), len(result))
self.assertEqual(len(result), 1)
self.assertDictEqual(self.expected_devices[0], result[0])
class TestSCPD(TestDevice):
@defer.inlineCallbacks
def setUp(self):
fake_reactor = MockReactor(self.client_address, self.scdp_packets)
reactor.listenMulticast = fake_reactor.listenMulticast
reactor.connectTCP = fake_reactor.connectTCP
self.reactor = reactor
server_protocol = MockSSDPServiceGatewayProtocol(
self.client_address, self.router_address, self.packets_rx, self.packets_tx
)
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
self.upnp = UPnP(
self.reactor, debug_ssdp=True, router_ip=self.router_address,
lan_ip=self.client_address, iface_name='mock'
)
yield self.upnp.discover()
def test_parse_available_commands(self):
self.assertDictEqual(self.expected_available_commands, self.upnp.gateway.debug_commands()['available'])
def test_parse_gateway(self):
self.assertDictEqual(self.device.device_dict['gateway_dict'], self.upnp.gateway.as_dict())
@defer.inlineCallbacks
def test_commands(self):
method, args, expected = self.device.device_dict['soap'][0]
command1 = getattr(self.upnp, method)
result = yield command1(*tuple(args))
self.assertEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][1]
command2 = getattr(self.upnp, method)
result = yield command2(*tuple(args))
result = [[i for i in r] for r in result]
self.assertListEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][2]
command3 = getattr(self.upnp, method)
result = yield command3(*tuple(args))
self.assertEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][3]
command4 = getattr(self.upnp, method)
result = yield command4(*tuple(args))
result = [r for r in result]
self.assertEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][4]
command5 = getattr(self.upnp, method)
result = yield command5(*tuple(args))
self.assertEqual(result, expected)

View file

@ -1,45 +0,0 @@
from twisted.internet import reactor, defer
from twisted.trial import unittest
from txupnp.constants import SSDP_PORT
from txupnp import mocks
from txupnp.ssdp import SSDPFactory
class TestDiscoverGateway(unittest.TestCase):
router_address = '10.0.0.1'
client_address = '10.0.0.10'
service_name = 'WANCommonInterfaceConfig:1'
st = 'urn:schemas-upnp-org:service:%s' % service_name
port = 49152
location = 'InternetGatewayDevice.xml'
usn = 'uuid:00000000-0000-0000-0000-000000000000::%s' % st
version = 'Linux, UPnP/1.0, DIR-890L Ver 1.20'
expected_devices = [
{
'cache_control': 'max-age=1800',
'location': 'http://%s:%i/%s' % (router_address, port, location),
'server': version,
'st': st,
'usn': usn
}
]
def setUp(self):
fake_reactor = mocks.MockReactor()
reactor.listenMulticast = fake_reactor.listenMulticast
self.reactor = reactor
server_protocol = mocks.MockSSDPServiceGatewayProtocol(
self.router_address, self.service_name, self.st, self.port, self.location, self.usn, self.version
)
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol)
def tearDown(self):
self.server_port.stopListening()
@defer.inlineCallbacks
def test_discover(self):
client_factory = SSDPFactory(self.reactor, self.client_address, self.router_address)
result = yield client_factory.m_search(self.router_address)
self.assertListEqual(self.expected_devices, result)
client_factory.disconnect()

View file

@ -1,3 +1,5 @@
import os
import json
import argparse import argparse
import logging import logging
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
@ -6,11 +8,6 @@ from txupnp.upnp import UPnP
log = logging.getLogger("txupnp") log = logging.getLogger("txupnp")
def debug_device(u, include_gateway_xml=False, *_):
print(u.get_debug_info(include_gateway_xml=include_gateway_xml))
return defer.succeed(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_external_ip(u, *_): def get_external_ip(u, *_):
ip = yield u.get_external_ip() ip = yield u.get_external_ip()
@ -30,7 +27,7 @@ def list_mappings(u, *_):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_mapping(u, *_): def add_mapping(u, *_):
port = 4567 port = 51413
protocol = "UDP" protocol = "UDP"
description = "txupnp test mapping" description = "txupnp test mapping"
ext_port = yield u.get_next_mapping(port, protocol, description) ext_port = yield u.get_next_mapping(port, protocol, description)
@ -50,12 +47,64 @@ def delete_mapping(u, *_):
print("removed mapping") print("removed mapping")
def _encode(x):
if isinstance(x, bytes):
return x.decode()
elif isinstance(x, Exception):
return str(x)
return x
@defer.inlineCallbacks
def generate_test_data(u, *_):
external_ip = yield u.get_external_ip()
redirects = yield u.get_redirects()
ext_port = yield u.get_next_mapping(4567, "UDP", "txupnp test mapping")
delete = yield u.delete_port_mapping(ext_port, "UDP")
after_delete = yield u.get_specific_port_mapping(ext_port, "UDP")
commands_test_case = (
("get_external_ip", (), "1.2.3.4"),
("get_redirects", (), redirects),
("get_next_mapping", (4567, "UDP", "txupnp test mapping"), ext_port),
("delete_port_mapping", (ext_port, "UDP"), delete),
("get_specific_port_mapping", (ext_port, "UDP"), after_delete),
)
gateway = u.gateway
device = list(gateway.devices.values())[0]
assert device.manufacturer and device.modelName
device_path = os.path.join(os.getcwd(), "%s %s" % (device.manufacturer, device.modelName))
commands = gateway.debug_commands()
with open(device_path, "w") as f:
f.write(json.dumps({
"router_address": u.router_ip,
"client_address": u.lan_address,
"port": gateway.port,
"gateway_dict": gateway.as_dict(),
'expected_devices': [
{
'cache_control': 'max-age=1800',
'location': gateway.location,
'server': gateway.server,
'st': gateway.urn,
'usn': gateway.usn
}
],
'commands': commands,
'ssdp': u.sspd_factory.get_ssdp_packet_replay(),
'scpd': gateway.requester.dump_packets(),
'soap': commands_test_case
}, default=_encode, indent=2).replace(external_ip, "1.2.3.4"))
print("Generated test data! -> %s" % device_path)
cli_commands = { cli_commands = {
"debug_device": debug_device,
"get_external_ip": get_external_ip, "get_external_ip": get_external_ip,
"list_mappings": list_mappings, "list_mappings": list_mappings,
"add_mapping": add_mapping, "add_mapping": add_mapping,
"delete_mapping": delete_mapping "delete_mapping": delete_mapping,
"generate_test_data": generate_test_data,
} }
@ -85,9 +134,9 @@ def main():
parser.add_argument("--include_igd_xml", dest="include_igd_xml", default=False, action="store_true") parser.add_argument("--include_igd_xml", dest="include_igd_xml", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
if args.debug_logging: if args.debug_logging:
from twisted.python import log as tx_log # from twisted.python import log as tx_log
observer = tx_log.PythonLoggingObserver(loggerName="txupnp") # observer = tx_log.PythonLoggingObserver(loggerName="txupnp")
observer.start() # observer.start()
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
command = args.command command = args.command
command = command.replace("-", "_") command = command.replace("-", "_")
@ -98,7 +147,7 @@ def main():
def show(err): def show(err):
print("error: {}".format(err)) print("error: {}".format(err))
u = UPnP(reactor) u = UPnP(reactor, debug_ssdp=(command == "generate_test_data"))
d = u.discover() d = u.discover()
d.addCallback(run_command, u, command, args.include_igd_xml) d.addCallback(run_command, u, command, args.include_igd_xml)
d.addErrback(show) d.addErrback(show)

137
txupnp/commands.py Normal file
View file

@ -0,0 +1,137 @@
from txupnp.util import return_types, none_or_str, none
class SCPDCommands: # TODO use type annotations
def debug_commands(self) -> dict:
raise NotImplementedError()
@staticmethod
@return_types(none)
def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
NewInternalClient: str, NewEnabled: bool, NewPortMappingDescription: str,
NewLeaseDuration: str = '') -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(bool, bool)
def GetNATRSIPStatus() -> (bool, bool):
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@staticmethod
@return_types(none_or_str, int, str, int, str, bool, str, int)
def GetGenericPortMappingEntry(NewPortMappingIndex) -> (none_or_str, int, str, int, str, bool, str, int):
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
"""
raise NotImplementedError()
@staticmethod
@return_types(int, str, bool, str, int)
def GetSpecificPortMappingEntry(NewRemoteHost, NewExternalPort, NewProtocol) -> (int, str, bool, str, int):
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetConnectionType(NewConnectionType) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(str)
def GetExternalIPAddress() -> str:
"""Returns (NewExternalIPAddress)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetConnectionTypeInfo() -> (str, str):
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str, int)
def GetStatusInfo() -> (str, str, int):
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def ForceTermination() -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def RequestConnection() -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
def GetCommonLinkProperties():
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesSent():
"""Returns (NewTotalBytesSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesReceived():
"""Returns (NewTotalBytesReceived)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsSent():
"""Returns (NewTotalPacketsSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsReceived():
"""Returns (NewTotalPacketsReceived)"""
raise NotImplementedError()
@staticmethod
def X_GetICSStatistics():
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
@staticmethod
def GetDefaultConnectionService():
"""Returns (NewDefaultConnectionService)"""
raise NotImplementedError()
@staticmethod
def SetDefaultConnectionService(NewDefaultConnectionService) -> None:
"""Returns (None)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetEnabledForInternet(NewEnabledForInternet) -> None:
raise NotImplementedError()
@staticmethod
@return_types(bool)
def GetEnabledForInternet() -> bool:
raise NotImplementedError()
@staticmethod
def GetMaximumActiveConnections(NewActiveConnectionIndex):
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetActiveConnections() -> (str, str):
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError()

View file

@ -20,15 +20,7 @@ IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1'
service_types = [ service_types = [
UPNP_ORG_IGD, UPNP_ORG_IGD,
WIFI_ALLIANCE_ORG_IGD, # WIFI_ALLIANCE_ORG_IGD,
WAN_SCHEMA,
LAYER_SCHEMA,
IP_SCHEMA,
CONTROL,
SERVICE,
DEVICE,
] ]
SSDP_IP_ADDRESS = '239.255.255.250' SSDP_IP_ADDRESS = '239.255.255.250'

View file

@ -1,94 +0,0 @@
import logging
from twisted.web.client import HTTPConnectionPool, _HTTP11ClientFactory
from twisted.web._newclient import HTTPClientParser, BadResponseVersion, HTTP11ClientProtocol, RequestNotSent
from twisted.web._newclient import TransportProxyProducer, RequestGenerationFailed
from twisted.python.failure import Failure
from twisted.internet.defer import Deferred, fail, maybeDeferred
from twisted.internet.defer import CancelledError
log = logging.getLogger()
class DirtyHTTPParser(HTTPClientParser):
def parseVersion(self, strversion):
"""
Parse version strings of the form Protocol '/' Major '.' Minor. E.g.
b'HTTP/1.1'. Returns (protocol, major, minor). Will raise ValueError
on bad syntax.
"""
try:
proto, strnumber = strversion.split(b'/')
major, minor = strnumber.split(b'.')
major, minor = int(major), int(minor)
except ValueError as e:
log.exception("got a bad http version: %s", strversion)
if b'HTTP1.1' in strversion:
return ("HTTP", 1, 1)
raise BadResponseVersion(str(e), strversion)
if major < 0 or minor < 0:
raise BadResponseVersion(u"version may not be negative",
strversion)
return (proto, major, minor)
class DirtyHTTPClientProtocol(HTTP11ClientProtocol):
def request(self, request):
if self._state != 'QUIESCENT':
return fail(RequestNotSent())
self._state = 'TRANSMITTING'
_requestDeferred = maybeDeferred(request.writeTo, self.transport)
def cancelRequest(ign):
# Explicitly cancel the request's deferred if it's still trying to
# write when this request is cancelled.
if self._state in (
'TRANSMITTING', 'TRANSMITTING_AFTER_RECEIVING_RESPONSE'):
_requestDeferred.cancel()
else:
self.transport.abortConnection()
self._disconnectParser(Failure(CancelledError()))
self._finishedRequest = Deferred(cancelRequest)
# Keep track of the Request object in case we need to call stopWriting
# on it.
self._currentRequest = request
self._transportProxy = TransportProxyProducer(self.transport)
self._parser = DirtyHTTPParser(request, self._finishResponse)
self._parser.makeConnection(self._transportProxy)
self._responseDeferred = self._parser._responseDeferred
def cbRequestWritten(ignored):
if self._state == 'TRANSMITTING':
self._state = 'WAITING'
self._responseDeferred.chainDeferred(self._finishedRequest)
def ebRequestWriting(err):
if self._state == 'TRANSMITTING':
self._state = 'GENERATION_FAILED'
self.transport.abortConnection()
self._finishedRequest.errback(
Failure(RequestGenerationFailed([err])))
else:
self._log.failure(
u'Error writing request, but not in valid state '
u'to finalize request: {state}',
failure=err,
state=self._state
)
_requestDeferred.addCallbacks(cbRequestWritten, ebRequestWriting)
return self._finishedRequest
class DirtyHTTP11ClientFactory(_HTTP11ClientFactory):
def buildProtocol(self, addr):
return DirtyHTTPClientProtocol(self._quiescentCallback)
class DirtyPool(HTTPConnectionPool):
_factory = DirtyHTTP11ClientFactory

View file

@ -1,23 +1,12 @@
import logging import logging
from twisted.internet import defer from twisted.internet import defer
import treq from txupnp.scpd import SCPDCommand, SCPDRequester
import re from txupnp.util import get_dict_val_case_insensitive, verify_return_types, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, get_dict_val_case_insensitive
from txupnp.util import BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from txupnp.constants import DEVICE, ROOT
from txupnp.constants import SPEC_VERSION from txupnp.constants import SPEC_VERSION
from txupnp.commands import SCPDCommands
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
service_type_pattern = re.compile(
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\})"
)
xml_root_sanity_pattern = re.compile(
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
)
class CaseInsensitive: class CaseInsensitive:
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -112,7 +101,7 @@ class Device(CaseInsensitive):
class Gateway: class Gateway:
def __init__(self, **kwargs): def __init__(self, reactor, **kwargs):
flattened = { flattened = {
k.lower(): v for k, v in kwargs.items() k.lower(): v for k, v in kwargs.items()
} }
@ -133,9 +122,10 @@ class Gateway:
self.date = date.encode() self.date = date.encode()
self.urn = st.encode() self.urn = st.encode()
self._xml_response = ""
self._service_descriptors = {}
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0] self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
self.xml_response = None
self.spec_version = None self.spec_version = None
self.url_base = None self.url_base = None
@ -143,55 +133,65 @@ class Gateway:
self._devices = [] self._devices = []
self._services = [] self._services = []
def debug_device(self, include_xml: bool = False, include_services: bool = True) -> dict: self._reactor = reactor
r = { self._unsupported_actions = {}
'server': self.server, self._registered_commands = {}
'urlBase': self.url_base, self.commands = SCPDCommands()
'location': self.location, self.requester = SCPDRequester(self._reactor)
"specVersion": self.spec_version,
'usn': self.usn,
'urn': self.urn,
'devices': [device.as_dict() for device in self._devices]
}
if include_xml:
r['xml_response'] = self.xml_response
if include_services:
r['services'] = [service.as_dict() for service in self._services]
def as_dict(self) -> dict:
r = {
'server': self.server.decode(),
'urlBase': self.url_base,
'location': self.location.decode(),
"specVersion": self.spec_version,
'usn': self.usn.decode(),
'urn': self.urn.decode(),
}
return r return r
@defer.inlineCallbacks @defer.inlineCallbacks
def discover_services(self): def discover_commands(self):
log.debug("querying %s", self.location) response = yield self.requester.scpd_get(self.location.decode().split(self.base_address.decode())[1], self.base_address.decode(), self.port)
response = yield treq.get(self.location) self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.xml_response = yield response.content() self.url_base = get_dict_val_case_insensitive(response, "urlbase")
if not self.xml_response: if not self.url_base:
log.warning("service sent an empty reply\n%s", self.debug_device()) self.url_base = self.base_address.decode()
xml_dict = etree_to_dict(ElementTree.fromstring(self.xml_response)) if response:
schema_key = DEVICE
root = ROOT
if len(xml_dict) > 1:
log.warning(xml_dict.keys())
for k in xml_dict.keys():
m = xml_root_sanity_pattern.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0]
root = m[2][5]
break
flattened_xml = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
self.spec_version = get_dict_val_case_insensitive(flattened_xml, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(flattened_xml, "urlbase")
if flattened_xml:
self._device = Device( self._device = Device(
self._devices, self._services, **get_dict_val_case_insensitive(flattened_xml, "device") self._devices, self._services, **get_dict_val_case_insensitive(response, "device")
) )
log.debug("finished setting up root gateway. %i devices and %i services", len(self.devices),
len(self.services))
else: else:
self._device = Device(self._devices, self._services) self._device = Device(self._devices, self._services)
log.debug("finished setting up gateway:\n%s", self.debug_device()) for service_type in self.services.keys():
service = self.services[service_type]
yield self.register_commands(service)
@defer.inlineCallbacks
def register_commands(self, service: Service):
try:
action_list = yield self.requester.scpd_get_supported_actions(service, self.base_address.decode(), self.port)
except Exception as err:
log.exception("failed to register service %s: %s", service.serviceType, str(err))
return
for name, inputs, outputs in action_list:
try:
command = SCPDCommand(self.requester, self.base_address, self.port,
service.controlURL.encode(),
service.serviceType.encode(), name, inputs, outputs)
current = getattr(self.commands, command.method)
if hasattr(current, "_return_types"):
command._process_result = verify_return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__)
setattr(self.commands, command.method, command)
self._registered_commands[command.method] = service.serviceType
log.debug("registered %s::%s", service.serviceType, command.method)
except AttributeError:
s = self._unsupported_actions.get(service.serviceType, [])
s.append(name)
self._unsupported_actions[service.serviceType] = s
log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
service.serviceType, name, inputs, outputs)
@property @property
def services(self) -> dict: def services(self) -> dict:
@ -205,7 +205,13 @@ class Gateway:
return {} return {}
return {device.udn: device for device in self._devices} return {device.udn: device for device in self._devices}
def get_service(self, service_type) -> Service: def get_service(self, service_type: str) -> Service:
for service in self._services: for service in self._services:
if service.serviceType.lower() == service_type.lower(): if service.serviceType.lower() == service_type.lower():
return service return service
def debug_commands(self):
return {
'available': self._registered_commands,
'failed': self._unsupported_actions
}

View file

@ -1,44 +1,118 @@
import os
import json
import logging import logging
import binascii from twisted.internet import task, defer
from twisted.internet import task from twisted.internet.error import ConnectionDone
from txupnp.constants import SSDP_IP_ADDRESS
from txupnp.fault import UPnPError
from txupnp.ssdp_datagram import SSDPDatagram
from twisted.internet.protocol import DatagramProtocol from twisted.internet.protocol import DatagramProtocol
from twisted.python.failure import Failure
from twisted.test.proto_helpers import _FakePort
from txupnp.ssdp_datagram import SSDPDatagram
log = logging.getLogger() log = logging.getLogger()
class MockResponse:
def __init__(self, content):
self._content = content
self.headers = {}
def content(self):
return defer.succeed(self._content)
class MockDevice:
def __init__(self, manufacturer, model):
self.manufacturer = manufacturer
self.model = model
device_path = os.path.join(
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"), "{} {}".format(manufacturer, model)
)
assert os.path.isfile(device_path)
with open(device_path, "r") as f:
self.device_dict = json.loads(f.read())
def __repr__(self):
return "MockDevice(manufacturer={}, model={})".format(self.manufacturer, self.model)
def get_mock_devices():
return [
MockDevice(path.split(" ")[0], path.split(" ")[1])
for path in os.listdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), "devices"))
if ".py" not in path and "pycache" not in path
]
def get_device_test_case(manufacturer: str, model: str) -> MockDevice:
r = [
MockDevice(path.split(" ")[0], path.split(" ")[1])
for path in os.listdir(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"))
if ".py" not in path and "pycache" not in path and path.split(" ") == [manufacturer, model]
]
return r[0]
class MockMulticastTransport: class MockMulticastTransport:
def __init__(self, address, port, max_packet_size, network): def __init__(self, address, port, max_packet_size, network, protocol):
self.address = address self.address = address
self.port = port self.port = port
self.max_packet_size = max_packet_size self.max_packet_size = max_packet_size
self._network = network self._network = network
self._protocol = protocol
def write(self, data, address): def write(self, data, address):
if address in self._network.peers: if address[0] in self._network.group:
for dest in self._network.peers[address]: destinations = self._network.group[address[0]]
dest.datagramReceived(data, (self.address, self.port)) else:
else: # the node is sending to an address that doesnt currently exist, act like it never arrived destinations = address[0]
pass for address, dest in self._network.peers.items():
if address[0] in destinations and dest.address != self.address:
dest._protocol.datagramReceived(data, (self.address, self.port))
def setTTL(self, ttl): def setTTL(self, ttl):
pass pass
def joinGroup(self, address, interface=None): def joinGroup(self, address, interface=None):
pass group = self._network.group.get(address, [])
group.append(interface)
self._network.group[address] = group
def leaveGroup(self, address, interface=None): def leaveGroup(self, address, interface=None):
pass group = self._network.group.get(address, [])
if interface in group:
group.remove(interface)
self._network.group[address] = group
class MockMulticastPort(object): class MockTCPTransport(_FakePort):
def __init__(self, protocol, remover): def __init__(self, address, port, callback, mock_requests):
super().__init__((address, port))
self._callback = callback
self._mock_requests = mock_requests
def write(self, data):
if data.startswith(b"POST"):
for url, packets in self._mock_requests['POST'].items():
for request_response in packets:
if data.decode() == request_response['request']:
self._callback(request_response['response'].encode())
return
elif data.startswith(b"GET"):
for url, packets in self._mock_requests['GET'].items():
if data.decode() == packets['request']:
self._callback(packets['response'].encode())
return
class MockMulticastPort(_FakePort):
def __init__(self, protocol, remover, address, transport):
super().__init__((address, 1900))
self.protocol = protocol self.protocol = protocol
self._remover = remover self._remover = remover
self.transport = transport
def startListening(self, reason=None): def startListening(self, reason=None):
self.protocol.transport = self.transport
return self.protocol.startProtocol() return self.protocol.startProtocol()
def stopListening(self, reason=None): def stopListening(self, reason=None):
@ -49,57 +123,55 @@ class MockMulticastPort(object):
class MockNetwork: class MockNetwork:
def __init__(self): def __init__(self):
self.peers = {} # (interface, port): (protocol, max_packet_size) self.peers = {}
self.group = {}
def add_peer(self, port, protocol, interface, maxPacketSize): def add_peer(self, port, protocol, interface, maxPacketSize):
protocol.transport = MockMulticastTransport(interface, port, maxPacketSize, self) transport = MockMulticastTransport(interface, port, maxPacketSize, self, protocol)
peers = self.peers.get((interface, port), []) self.peers[(interface, port)] = transport
peers.append(protocol)
self.peers[(interface, port)] = peers
def remove_peer(): def remove_peer():
if self.peers.get((interface, port)): if self.peers.get((interface, port)):
self.peers[(interface, port)].remove(protocol)
if not self.peers.get((interface, port)):
del self.peers[(interface, port)] del self.peers[(interface, port)]
del protocol.transport return transport, remove_peer
return remove_peer
class MockReactor(task.Clock): class MockReactor(task.Clock):
def __init__(self): def __init__(self, client_addr, mock_scpd_requests):
super().__init__() super().__init__()
self.client_addr = client_addr
self._mock_scpd_requests = mock_scpd_requests
self.network = MockNetwork() self.network = MockNetwork()
def listenMulticast(self, port, protocol, interface=SSDP_IP_ADDRESS, maxPacketSize=8192, listenMultiple=True): def listenMulticast(self, port, protocol, interface=None, maxPacketSize=8192, listenMultiple=True):
remover = self.network.add_peer(port, protocol, interface, maxPacketSize) interface = interface or self.client_addr
port = MockMulticastPort(protocol, remover) transport, remover = self.network.add_peer(port, protocol, interface, maxPacketSize)
port = MockMulticastPort(protocol, remover, interface, transport)
port.startListening() port.startListening()
return port return port
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
protocol = factory.buildProtocol(host)
def _write_and_close(data):
protocol.dataReceived(data)
protocol.connectionLost(Failure(ConnectionDone()))
protocol.transport = MockTCPTransport(host, port, _write_and_close, self._mock_scpd_requests)
protocol.connectionMade()
class MockSSDPServiceGatewayProtocol(DatagramProtocol): class MockSSDPServiceGatewayProtocol(DatagramProtocol):
def __init__(self, iface, service_name, st, port, location, usn, version): def __init__(self, client_addr: int, iface: str, packets_rx: list, packets_tx: list):
self.client_addr = client_addr
self.iface = iface self.iface = iface
self.service_name = service_name self.packets_tx = [SSDPDatagram.decode(packet.encode()) for packet in packets_tx] # sent by client
self.gateway_st = st self.packets_rx = [((addr, port), SSDPDatagram.decode(packet.encode())) for (addr, port), packet in packets_rx] # rx by client
self.gateway_location = location
self.gateway_usn = usn
self.gateway_version = version
self.gateway_port = port
def datagramReceived(self, datagram, address): def datagramReceived(self, datagram, address):
try:
packet = SSDPDatagram.decode(datagram) packet = SSDPDatagram.decode(datagram)
except UPnPError as err: if packet.st in map(lambda p: p[1].st, self.packets_rx): # this contains one of the service types the server replied to
log.error("failed to decode SSDP packet from %s:%i: %s\npacket: %s", address[0], address[1], err, reply = list(filter(lambda p: p[1].st == packet.st, self.packets_rx))[0][1]
binascii.hexlify(datagram)) self.transport.write(reply.encode().encode(), (self.client_addr, 1900))
return else:
pass
if packet._packet_type == packet._M_SEARCH:
if packet.man == "ssdp:discover" and packet.st == self.gateway_st:
location = 'http://{}:{}/{}'.format(self.iface, self.gateway_port, self.gateway_location)
response = SSDPDatagram(SSDPDatagram._OK, st=self.gateway_st, cache_control='max-age=1800',
location=location, server=self.gateway_version, usn=self.gateway_usn)
self.transport.write(response.encode().encode(), address)

View file

@ -1,48 +1,224 @@
import re
import logging import logging
from collections import OrderedDict
from twisted.internet import defer, threads
from twisted.web.client import Agent
import treq
from treq.client import HTTPClient
from xml.etree import ElementTree from xml.etree import ElementTree
from txupnp.util import etree_to_dict, flatten_keys, return_types, verify_return_types, none_or_str, none from twisted.internet.protocol import Protocol, ClientFactory
from twisted.internet import defer, error
from txupnp.constants import XML_VERSION, DEVICE, ROOT, SERVICE, ENVELOPE, BODY
from txupnp.util import etree_to_dict, flatten_keys
from txupnp.fault import handle_fault, UPnPError from txupnp.fault import handle_fault, UPnPError
from txupnp.constants import SERVICE, SSDP_IP_ADDRESS, DEVICE, ROOT, service_types, ENVELOPE, XML_VERSION
from txupnp.constants import BODY, POST
from txupnp.dirty_pool import DirtyPool
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
CONTENT_PATTERN = re.compile(
"(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)".encode()
)
CONTENT_NO_XML_VERSION_PATTERN = re.compile(
"(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode()
)
class StringProducer: XML_ROOT_SANITY_PATTERN = re.compile(
def __init__(self, body): "(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
self.body = body )
self.length = len(body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass
def xml_arg(name, arg): def parse_service_description(content: bytes):
return "<%s>%s</%s>" % (name, arg, name) element_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
if "scpd" not in service_info:
return []
action_list = service_info["scpd"]["actionList"]
if not len(action_list): # it could be an empty string
return []
result = []
if isinstance(action_list["action"], dict):
arg_dicts = action_list["action"]['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg, ew
arg_dicts = [arg_dicts]
return [[
action_list["action"]['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
]]
for action in action_list["action"]:
if not action.get('argumentList'):
result.append((action['name'], [], []))
else:
arg_dicts = action['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg, ew
arg_dicts = [arg_dicts]
result.append((
action['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
))
return result
def get_soap_body(service_name, method, param_names, **kwargs): class SCPDHTTPClientProtocol(Protocol):
args = "".join(xml_arg(n, kwargs.get(n)) for n in param_names) def connectionMade(self):
return '\n%s\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body><u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (XML_VERSION, method, service_name, args, method) self.response_buff = b""
self.factory.reactor.callLater(0, self.transport.write, self.factory.packet)
def dataReceived(self, data):
self.response_buff += data
def connectionLost(self, reason):
if reason.trap(error.ConnectionDone):
if XML_VERSION.encode() in self.response_buff:
parsed = CONTENT_PATTERN.findall(self.response_buff)
result = b'' if not parsed else parsed[0][0]
self.factory.finished_deferred.callback(result)
else:
parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(self.response_buff)
result = b'' if not parsed else XML_VERSION.encode() + b'\r\n' + parsed[0][0]
self.factory.finished_deferred.callback(result)
class _SCPDCommand(object): class SCPDHTTPClientFactory(ClientFactory):
def __init__(self, http_client, gateway_address, service_port, control_url, service_id, method, param_names, protocol = SCPDHTTPClientProtocol
def __init__(self, reactor, packet):
self.reactor = reactor
self.finished_deferred = defer.Deferred()
self.packet = packet
def buildProtocol(self, addr):
p = self.protocol()
p.factory = self
return p
@classmethod
def post(cls, reactor, command, **kwargs):
args = "".join("<%s>%s</%s>" % (n, kwargs.get(n), n) for n in command.param_names)
soap_body = ('\r\n%s\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (
XML_VERSION, command.method, command.service_id.decode(),
args, command.method))
data = (
(
'POST %s HTTP/1.1\r\n'
'Host: %s\r\n'
'User-Agent: Debian/buster/sid, UPnP/1.0, MiniUPnPc/1.9\r\n'
'Content-Length: %i\r\n'
'Content-Type: text/xml\r\n'
'SOAPAction: \"%s#%s\"\r\n'
'Connection: Close\r\n'
'Cache-Control: no-cache\r\n'
'Pragma: no-cache\r\n'
'%s'
'\r\n'
) % (
command.control_url.decode(), # could be just / even if it shouldn't be
command.gateway_address.decode().split("http://")[1],
len(soap_body),
command.service_id.decode(), # maybe no quotes
command.method,
soap_body
)
).encode()
return cls(reactor, data)
@classmethod
def get(cls, reactor, control_url: str, address: str):
data = (
(
'GET %s HTTP/1.1\r\n'
'Accept-Encoding: gzip\r\n'
'Host: %s\r\n'
'\r\n'
) % (control_url, address)
).encode()
return cls(reactor, data)
class SCPDRequester:
client_factory = SCPDHTTPClientFactory
def __init__(self, reactor):
self._reactor = reactor
self._get_requests = {}
self._post_requests = {}
def _save_get(self, request: bytes, response: bytes, destination: str) -> None:
self._get_requests[destination.lstrip("/")] = {
'request': request,
'response': response
}
def _save_post(self, request: bytes, response: bytes, destination: str) -> None:
p = self._post_requests.get(destination.lstrip("/"), [])
p.append({
'request': request,
'response': response,
})
self._post_requests[destination.lstrip("/")] = p
@defer.inlineCallbacks
def _scpd_get_soap_xml(self, control_url: str, address: str, service_port: int) -> bytes:
factory = self.client_factory.get(self._reactor, control_url, address)
url = address.split("http://")[1].split(":")[0]
self._reactor.connectTCP(url, service_port, factory)
xml_response_bytes = yield factory.finished_deferred
self._save_get(factory.packet, xml_response_bytes, control_url)
return xml_response_bytes
@defer.inlineCallbacks
def scpd_post_soap(self, command, **kwargs) -> tuple:
factory = self.client_factory.post(self._reactor, command, **kwargs)
url = command.gateway_address.split(b"http://")[1].split(b":")[0]
self._reactor.connectTCP(url.decode(), command.service_port, factory)
xml_response_bytes = yield factory.finished_deferred
self._save_post(
factory.packet, xml_response_bytes, command.gateway_address.decode() + command.control_url.decode()
)
content_dict = etree_to_dict(ElementTree.fromstring(xml_response_bytes.decode()))
envelope = content_dict[ENVELOPE]
response_body = flatten_keys(envelope[BODY], "{%s}" % command.service_id)
body = handle_fault(response_body) # raises UPnPError if there is a fault
response_key = None
for key in body:
if command.method in key:
response_key = key
break
if not response_key:
raise UPnPError("unknown response fields for %s")
response = body[response_key]
extracted_response = tuple([response[n] for n in command.returns])
return extracted_response
@defer.inlineCallbacks
def scpd_get_supported_actions(self, service, address: str, port: int) -> list:
xml_bytes = yield self._scpd_get_soap_xml(service.SCPDURL, address, port)
return parse_service_description(xml_bytes)
@defer.inlineCallbacks
def scpd_get(self, control_url: str, service_address: str, service_port: int) -> dict:
xml_bytes = yield self._scpd_get_soap_xml(control_url, service_address, service_port)
xml_dict = etree_to_dict(ElementTree.fromstring(xml_bytes.decode()))
schema_key = DEVICE
root = ROOT
for k in xml_dict.keys():
m = XML_ROOT_SANITY_PATTERN.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0]
root = m[2][5]
break
flattened_xml = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
return flattened_xml
def dump_packets(self) -> dict:
return {
'GET': self._get_requests,
'POST': self._post_requests
}
class SCPDCommand:
def __init__(self, scpd_requester: SCPDRequester, gateway_address, service_port, control_url, service_id, method,
param_names,
returns): returns):
self._http_client = http_client self.scpd_requester = scpd_requester
self.gateway_address = gateway_address self.gateway_address = gateway_address
self.service_port = service_port self.service_port = service_port
self.control_url = control_url self.control_url = control_url
@ -51,59 +227,8 @@ class _SCPDCommand(object):
self.param_names = param_names self.param_names = param_names
self.returns = returns self.returns = returns
def extract_body(self, xml_response):
content_dict = etree_to_dict(ElementTree.fromstring(xml_response))
envelope = content_dict[ENVELOPE]
return flatten_keys(envelope[BODY], "{%s}" % self.service_id)
def extract_response(self, body):
body = handle_fault(body) # raises UPnPError if there is a fault
response_key = None
for key in body:
if self.method in key:
response_key = key
break
if not response_key:
raise UPnPError("unknown response fields for %s")
response = body[response_key]
extracted_response = tuple([response[n] for n in self.returns])
if len(extracted_response) == 1:
return extracted_response[0]
return extracted_response
@defer.inlineCallbacks
def send_upnp_soap(self, **kwargs):
soap_body = get_soap_body(self.service_id, self.method, self.param_names, **kwargs).encode()
headers = OrderedDict((
('SOAPAction', '%s#%s' % (self.service_id, self.method)),
('Host', ('%s:%i' % (SSDP_IP_ADDRESS, self.service_port))),
('Content-Type', 'text/xml'),
('Content-Length', len(soap_body))
))
log.debug("send POST to %s\nheaders: %s\nbody:%s\n", self.control_url, headers, soap_body)
try:
response = yield self._http_client.request(
POST, url=self.control_url, data=soap_body, headers=headers
)
except Exception as err:
log.error("error (%s) sending POST to %s\nheaders: %s\nbody:%s\n", err, self.control_url, headers,
soap_body)
raise UPnPError(err)
xml_response = yield response.content()
try:
response = self.extract_response(self.extract_body(xml_response))
except UPnPError:
raise
except Exception as err:
log.debug("error extracting response (%s) to %s:\n%s", err, self.method, xml_response)
raise err
if not response:
log.debug("empty response to %s\n%s", self.method, xml_response)
defer.returnValue(response)
@staticmethod @staticmethod
def _process_result(results): def _process_result(*results):
""" """
this method gets decorated automatically with a function that maps result types to the types this method gets decorated automatically with a function that maps result types to the types
defined in the @return_types decorator defined in the @return_types decorator
@ -114,367 +239,10 @@ class _SCPDCommand(object):
def __call__(self, **kwargs): def __call__(self, **kwargs):
if set(kwargs.keys()) != set(self.param_names): if set(kwargs.keys()) != set(self.param_names):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_names)) raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_names))
response = yield self.send_upnp_soap(**kwargs) response = yield self.scpd_requester.scpd_post_soap(self, **kwargs)
try: try:
result = self._process_result(response) result = self._process_result(*response)
except Exception as err: except Exception as err:
log.error("error formatting response (%s):\n%s", err, response) log.error("error formatting response (%s):\n%s", err, response)
raise err raise err
defer.returnValue(result) defer.returnValue(result)
class SCPDResponse(object):
def __init__(self, url, headers, content):
self.url = url
self.headers = headers
self.content = content
def get_element_tree(self):
return ElementTree.fromstring(self.content)
def get_element_dict(self, service_key):
return flatten_keys(etree_to_dict(self.get_element_tree()), "{%s}" % service_key)
def get_action_list(self):
return self.get_element_dict(SERVICE)["scpd"]["actionList"]["action"]
def get_device_info(self):
return self.get_element_dict(DEVICE)[ROOT]
class SCPDCommandRunner(object):
def __init__(self, gateway, reactor, treq_get=None):
self._gateway = gateway
self._unsupported_actions = {}
self._registered_commands = {}
self._reactor = reactor
self._connection_pool = DirtyPool(reactor)
self._agent = Agent(reactor, connectTimeout=1, pool=self._connection_pool)
self._http_client = HTTPClient(self._agent, data_to_body_producer=StringProducer)
self._treq_get = treq_get or treq.get
@defer.inlineCallbacks
def _discover_commands(self, service):
scpd_url = self._gateway.base_address + service.SCPDURL.encode()
response = yield self._treq_get(scpd_url)
content = yield response.content()
try:
scpd_response = SCPDResponse(scpd_url, response.headers, content)
for action_dict in scpd_response.get_action_list():
self._register_command(action_dict, service.serviceType)
except Exception as err:
log.exception("failed to parse scpd response (%s) from %s\nheaders:\n%s\ncontent\n%s",
err, scpd_url, response.headers, content)
defer.returnValue(None)
@defer.inlineCallbacks
def discover_commands(self):
for service_type in service_types:
service = self._gateway.get_service(service_type)
if not service:
continue
yield self._discover_commands(service)
log.debug(self.debug_commands())
@staticmethod
def _soap_function_info(action_dict):
if not action_dict.get('argumentList'):
return (
action_dict['name'],
[],
[]
)
arg_dicts = action_dict['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg, ew
arg_dicts = [arg_dicts]
return (
action_dict['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
)
def _patch_command(self, action_info, service_type):
name, inputs, outputs = self._soap_function_info(action_info)
command = _SCPDCommand(self._http_client, self._gateway.base_address, self._gateway.port,
self._gateway.base_address + self._gateway.get_service(service_type).controlURL.encode(),
self._gateway.get_service(service_type).serviceType.encode(), name, inputs, outputs)
current = getattr(self, command.method)
if hasattr(current, "_return_types"):
command._process_result = verify_return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__)
setattr(self, command.method, command)
self._registered_commands[command.method] = service_type
log.debug("registered %s %s", service_type, action_info['name'])
return True
def _register_command(self, action_info, service_type):
try:
return self._patch_command(action_info, service_type)
except Exception as err:
s = self._unsupported_actions.get(service_type, [])
s.append((action_info, err))
self._unsupported_actions[service_type] = s
log.error("available command for %s does not have a wrapper implemented: %s", service_type, action_info)
def debug_commands(self):
return {
'available': self._registered_commands,
'failed': self._unsupported_actions
}
@staticmethod
@return_types(none)
def AddPortMapping(NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient,
NewEnabled, NewPortMappingDescription, NewLeaseDuration=''):
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(bool, bool)
def GetNATRSIPStatus():
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@staticmethod
@return_types(none_or_str, int, str, int, str, bool, str, int)
def GetGenericPortMappingEntry(NewPortMappingIndex):
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
"""
raise NotImplementedError()
@staticmethod
@return_types(int, str, bool, str, int)
def GetSpecificPortMappingEntry(NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetConnectionType(NewConnectionType):
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(str)
def GetExternalIPAddress():
"""Returns (NewExternalIPAddress)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetConnectionTypeInfo():
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str, int)
def GetStatusInfo():
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def ForceTermination():
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def RequestConnection():
"""Returns None"""
raise NotImplementedError()
@staticmethod
def GetCommonLinkProperties():
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesSent():
"""Returns (NewTotalBytesSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesReceived():
"""Returns (NewTotalBytesReceived)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsSent():
"""Returns (NewTotalPacketsSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsReceived():
"""Returns (NewTotalPacketsReceived)"""
raise NotImplementedError()
@staticmethod
def X_GetICSStatistics():
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
@staticmethod
def GetDefaultConnectionService():
"""Returns (NewDefaultConnectionService)"""
raise NotImplementedError()
@staticmethod
def SetDefaultConnectionService(NewDefaultConnectionService):
"""Returns (None)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetEnabledForInternet(NewEnabledForInternet):
raise NotImplementedError()
@staticmethod
@return_types(bool)
def GetEnabledForInternet():
raise NotImplementedError()
@staticmethod
def GetMaximumActiveConnections(NewActiveConnectionIndex):
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetActiveConnections():
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError()
class UPnPFallback(object):
def __init__(self):
try:
import miniupnpc
self._upnp = miniupnpc.UPnP()
self.available = True
except ImportError:
self._upnp = None
self.available = False
@defer.inlineCallbacks
def discover(self):
if not self.available:
raise NotImplementedError()
devices = yield threads.deferToThread(self._upnp.discover)
if devices:
self.device_url = yield threads.deferToThread(self._upnp.selectigd)
else:
self.device_url = None
defer.returnValue(devices > 0)
@return_types(none)
def AddPortMapping(self, NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient,
NewEnabled, NewPortMappingDescription, NewLeaseDuration=''):
"""Returns None"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.addportmapping, NewExternalPort, NewProtocol, NewInternalClient,
NewInternalPort, NewPortMappingDescription, NewLeaseDuration)
def GetNATRSIPStatus(self):
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@return_types(none_or_str, int, str, int, str, bool, str, int)
@defer.inlineCallbacks
def GetGenericPortMappingEntry(self, NewPortMappingIndex):
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
"""
if not self.available:
raise NotImplementedError()
result = yield threads.deferToThread(self._upnp.getgenericportmapping, NewPortMappingIndex)
if not result:
raise UPnPError()
ext_port, protocol, (int_host, int_port), desc, enabled, remote_host, lease = result
defer.returnValue((remote_host, ext_port, protocol, int_port, int_host, enabled, desc, lease))
@return_types(int, str, bool, str, int)
def GetSpecificPortMappingEntry(self, NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.getspecificportmapping, NewExternalPort, NewProtocol)
def SetConnectionType(self, NewConnectionType):
"""Returns None"""
raise NotImplementedError()
@return_types(str)
def GetExternalIPAddress(self):
"""Returns (NewExternalIPAddress)"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.externalipaddress)
def GetConnectionTypeInfo(self):
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@return_types(str, str, int)
def GetStatusInfo(self):
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.statusinfo)
def ForceTermination(self):
"""Returns None"""
raise NotImplementedError()
@return_types(none)
def DeletePortMapping(self, NewRemoteHost, NewExternalPort, NewProtocol):
"""Returns None"""
if not self.available:
raise NotImplementedError()
return threads.deferToThread(self._upnp.deleteportmapping, NewExternalPort, NewProtocol)
def RequestConnection(self):
"""Returns None"""
raise NotImplementedError()
def GetCommonLinkProperties(self):
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError()
def GetTotalBytesSent(self):
"""Returns (NewTotalBytesSent)"""
raise NotImplementedError()
def GetTotalBytesReceived(self):
"""Returns (NewTotalBytesReceived)"""
raise NotImplementedError()
def GetTotalPacketsSent(self):
"""Returns (NewTotalPacketsSent)"""
raise NotImplementedError()
def GetTotalPacketsReceived(self):
"""Returns (NewTotalPacketsReceived)"""
raise NotImplementedError()
def X_GetICSStatistics(self):
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
def GetDefaultConnectionService(self):
"""Returns (NewDefaultConnectionService)"""
raise NotImplementedError()
def SetDefaultConnectionService(self, NewDefaultConnectionService):
"""Returns (None)"""
raise NotImplementedError()

View file

@ -1,79 +0,0 @@
import logging
import netifaces
from twisted.internet import defer
from txupnp.ssdp import SSDPFactory
from txupnp.scpd import SCPDCommandRunner
from txupnp.gateway import Gateway
from txupnp.fault import UPnPError
from txupnp.constants import UPNP_ORG_IGD
log = logging.getLogger(__name__)
class SOAPServiceManager(object):
def __init__(self, reactor, treq_get=None):
self._reactor = reactor
self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip)
self._command_runners = {}
self._selected_runner = UPNP_ORG_IGD
self._treq_get = treq_get
@defer.inlineCallbacks
def discover_services(self, address=None, timeout=30, max_devices=1):
server_infos = yield self.sspd_factory.m_search(
address or self.router_ip, timeout=timeout, max_devices=max_devices
)
locations = []
for server_info in server_infos:
if 'st' in server_info and server_info['st'] not in self._command_runners:
locations.append(server_info['location'])
gateway = Gateway(**server_info)
yield gateway.discover_services()
command_runner = SCPDCommandRunner(gateway, self._reactor, self._treq_get)
yield command_runner.discover_commands()
self._command_runners[gateway.urn.decode()] = command_runner
elif 'st' not in server_info:
log.error("don't know how to handle gateway: %s", server_info)
continue
defer.returnValue(len(self._command_runners) > 0)
def set_runner(self, urn):
if urn not in self._command_runners:
raise IndexError(urn)
self._command_runners = urn
def get_runner(self):
if self._command_runners and not self._selected_runner in self._command_runners:
self._selected_runner = list(self._command_runners.keys())[0]
if not self._command_runners:
raise UPnPError("no devices found")
return self._command_runners[self._selected_runner]
def get_available_runners(self):
return self._command_runners.keys()
def debug(self, include_gateway_xml=False):
results = []
for runner in self._command_runners.values():
gateway = runner._gateway
info = gateway.debug_device(include_xml=include_gateway_xml)
commands = runner.debug_commands()
service_result = []
for service in info['services']:
service_commands = []
unavailable = []
for command, service_type in commands['available'].items():
if service['serviceType'] == service_type:
service_commands.append(command)
for command, service_type in commands['failed'].items():
if service['serviceType'] == service_type:
unavailable.append(command)
services_with_commands = dict(service)
services_with_commands['available_commands'] = service_commands
services_with_commands['unavailable_commands'] = unavailable
service_result.append(services_with_commands)
info['services'] = service_result
results.append(info)
return results

View file

@ -12,7 +12,8 @@ log = logging.getLogger(__name__)
class SSDPProtocol(DatagramProtocol): class SSDPProtocol(DatagramProtocol):
def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS, def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS,
ssdp_port=SSDP_PORT, ttl=1, max_devices=None): ssdp_port=SSDP_PORT, ttl=1, max_devices=None, debug_packets=False,
debug_sent=None, debug_received=None):
self._reactor = reactor self._reactor = reactor
self._sem = defer.DeferredSemaphore(1) self._sem = defer.DeferredSemaphore(1)
self.discover_callbacks = {} self.discover_callbacks = {}
@ -24,34 +25,36 @@ class SSDPProtocol(DatagramProtocol):
self._start = None self._start = None
self.max_devices = max_devices self.max_devices = max_devices
self.devices = [] self.devices = []
self.debug_packets = debug_packets
self.debug_sent = debug_sent if debug_sent is not None else []
self.debug_received = debug_received if debug_sent is not None else []
def _send_m_search(self, service=UPNP_ORG_IGD): def _send_m_search(self, service=UPNP_ORG_IGD):
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1) packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode()) log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode())
try: try:
self.transport.write(packet.encode().encode(), (self.ssdp_address, self.ssdp_port)) msg_bytes = packet.encode().encode()
if self.debug_packets:
self.debug_sent.append(msg_bytes)
self.transport.write(msg_bytes, (self.ssdp_address, self.ssdp_port))
except Exception as err: except Exception as err:
log.exception("failed to write %s to %s:%i", packet.encode(), self.ssdp_address, self.ssdp_port) log.exception("failed to write %s to %s:%i", packet.encode(), self.ssdp_address, self.ssdp_port)
raise err raise err
@staticmethod @staticmethod
def _gather(finished_deferred, max_results): def _gather(finished_deferred, max_results, results: list):
results = []
def discover_cb(packet): def discover_cb(packet):
if not finished_deferred.called: if not finished_deferred.called and packet.st in service_types:
results.append(packet.as_dict()) results.append(packet.as_dict())
if len(results) >= max_results: if len(results) >= max_results:
finished_deferred.callback(results) finished_deferred.callback(results)
return discover_cb return discover_cb
def m_search(self, address=None, timeout=1, max_devices=1): def m_search(self, address, timeout, max_devices):
address = address or self.iface
# return deferred for a pending call if we have one # return deferred for a pending call if we have one
if address in self.discover_callbacks: if address in self.discover_callbacks:
d = self.protocol.discover_callbacks[address][1] d = self.discover_callbacks[address][1]
if not d.called: # the existing deferred has already fired, make a new one if not d.called: # the existing deferred has already fired, make a new one
return d return d
@ -63,7 +66,7 @@ class SSDPProtocol(DatagramProtocol):
d = defer.Deferred() d = defer.Deferred()
d.addTimeout(timeout, self._reactor) d.addTimeout(timeout, self._reactor)
d.addErrback(_trap_timeout_and_return_results) d.addErrback(_trap_timeout_and_return_results)
found_cb = self._gather(d, max_devices) found_cb = self._gather(d, max_devices, self.devices)
self.discover_callbacks[address] = found_cb, d self.discover_callbacks[address] = found_cb, d
for st in service_types: for st in service_types:
self._send_m_search(service=st) self._send_m_search(service=st)
@ -73,11 +76,12 @@ class SSDPProtocol(DatagramProtocol):
self._start = self._reactor.seconds() self._start = self._reactor.seconds()
self.transport.setTTL(self.ttl) self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface) self.transport.joinGroup(self.ssdp_address, interface=self.iface)
self.m_search()
def datagramReceived(self, datagram, address): def datagramReceived(self, datagram, address):
if address[0] == self.iface: if address[0] == self.iface:
return return
if self.debug_packets:
self.debug_received.append((address, datagram))
try: try:
packet = SSDPDatagram.decode(datagram) packet = SSDPDatagram.decode(datagram)
log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode()) log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode())
@ -90,23 +94,30 @@ class SSDPProtocol(DatagramProtocol):
return return
if packet._packet_type == packet._OK: if packet._packet_type == packet._OK:
log.debug("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location) log.debug("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
if packet.st not in map(lambda p: p['st'], self.devices): # if address[0] in self.discover_callbacks and packet.location not in map(lambda p: p['location'], self.devices):
if packet.location not in map(lambda p: p['location'], self.devices):
if address[0] not in self.discover_callbacks:
self.devices.append(packet.as_dict()) self.devices.append(packet.as_dict())
log.debug("%i device%s so far", len(self.devices), "" if len(self.devices) < 2 else "s") else:
if address[0] in self.discover_callbacks:
self._sem.run(self.discover_callbacks[address[0]][0], packet) self._sem.run(self.discover_callbacks[address[0]][0], packet)
else:
log.info("ignored packet from %s:%s (%s) %s", address[0], address[1], packet._packet_type, packet.location)
elif packet._packet_type == packet._NOTIFY: elif packet._packet_type == packet._NOTIFY:
log.debug("%s:%i sent us a notification (type: %s), url: %s", address[0], address[1], packet.nts, log.debug("%s:%i sent us a notification (type: %s), url: %s", address[0], address[1], packet.nts,
packet.location) packet.location)
class SSDPFactory(object): class SSDPFactory:
def __init__(self, reactor, lan_address, router_address): def __init__(self, reactor, lan_address, router_address, debug_packets=False):
self.lan_address = lan_address self.lan_address = lan_address
self.router_address = router_address self.router_address = router_address
self._reactor = reactor self._reactor = reactor
self.protocol = None self.protocol = None
self.port = None self.port = None
self.debug_packets = debug_packets
self.debug_sent = []
self.debug_received = []
self.server_infos = []
def disconnect(self): def disconnect(self):
if not self.port: if not self.port:
@ -118,13 +129,15 @@ class SSDPFactory(object):
def connect(self): def connect(self):
if not self.protocol: if not self.protocol:
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address) self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address,
debug_packets=self.debug_packets, debug_sent=self.debug_sent,
debug_received=self.debug_received)
if not self.port: if not self.port:
self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect) self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect)
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True) self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def m_search(self, address, timeout=1, max_devices=1): def m_search(self, address, timeout, max_devices):
""" """
Perform a M-SEARCH (HTTP over UDP) and gather the results Perform a M-SEARCH (HTTP over UDP) and gather the results
@ -140,7 +153,16 @@ class SSDPFactory(object):
'usn': (str) usn 'usn': (str) usn
}, ...] }, ...]
""" """
self.connect() self.connect()
server_infos = yield self.protocol.m_search(address, timeout, max_devices) server_infos = yield self.protocol.m_search(address, timeout, max_devices)
for server_info in server_infos:
self.server_infos.append(server_info)
defer.returnValue(server_infos) defer.returnValue(server_infos)
def get_ssdp_packet_replay(self) -> dict:
return {
'lan_address': self.lan_address,
'router_address': self.router_address,
'sent': self.debug_sent,
'received': self.debug_received,
}

View file

@ -1,45 +1,29 @@
import netifaces
import logging import logging
import json
from twisted.internet import defer from twisted.internet import defer
from txupnp.fault import UPnPError from txupnp.fault import UPnPError
from txupnp.soap import SOAPServiceManager from txupnp.ssdp import SSDPFactory
from txupnp.scpd import UPnPFallback from txupnp.gateway import Gateway
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class UPnP(object): class UPnP:
def __init__(self, reactor, try_miniupnpc_fallback=True, treq_get=None): def __init__(self, reactor, try_miniupnpc_fallback=False, debug_ssdp=False, router_ip=None,
lan_ip=None, iface_name=None):
self._reactor = reactor self._reactor = reactor
if router_ip and lan_ip and iface_name:
self.router_ip, self.lan_address, self.iface_name = router_ip, lan_ip, iface_name
else:
self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip, debug_packets=debug_ssdp)
self.try_miniupnpc_fallback = try_miniupnpc_fallback self.try_miniupnpc_fallback = try_miniupnpc_fallback
self.soap_manager = SOAPServiceManager(reactor, treq_get=treq_get)
self.miniupnpc_runner = None self.miniupnpc_runner = None
self.miniupnpc_igd_url = None self.miniupnpc_igd_url = None
self.gateway = None
@property def m_search(self, address, timeout=1, max_devices=1):
def lan_address(self):
return self.soap_manager.lan_address
@property
def commands(self):
try:
runner = self.soap_manager.get_runner()
required_commands = [
"GetExternalIPAddress",
"AddPortMapping",
"GetSpecificPortMappingEntry",
"GetGenericPortMappingEntry",
"DeletePortMapping"
]
if all((command in runner._registered_commands for command in required_commands)):
return runner
raise UPnPError("required commands not found")
except UPnPError as err:
if self.try_miniupnpc_fallback and self.miniupnpc_runner:
return self.miniupnpc_runner
log.warning("upnp is not available: %s", err)
def m_search(self, address, timeout=30, max_devices=2):
""" """
Perform a HTTP over UDP M-SEARCH query Perform a HTTP over UDP M-SEARCH query
@ -51,68 +35,52 @@ class UPnP(object):
'usn': <usn> 'usn': <usn>
}, ...] }, ...]
""" """
return self.soap_manager.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices) return self.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
@defer.inlineCallbacks @defer.inlineCallbacks
def discover(self, timeout=1, max_devices=1, keep_listening=False, try_txupnp=True): def _discover(self, timeout=1, max_devices=1):
found = False server_infos = yield self.sspd_factory.m_search(
if not try_txupnp and not self.try_miniupnpc_fallback: self.router_ip, timeout=timeout, max_devices=max_devices
log.warning("nothing left to try") )
if try_txupnp: server_info = server_infos[0]
if 'st' in server_info:
gateway = Gateway(reactor=self._reactor, **server_info)
yield gateway.discover_commands()
self.gateway = gateway
defer.returnValue(True)
elif 'st' not in server_info:
log.error("don't know how to handle gateway: %s", server_info)
defer.returnValue(False)
@defer.inlineCallbacks
def discover(self, timeout=1, max_devices=1):
try: try:
found = yield self.soap_manager.discover_services(timeout=timeout, max_devices=max_devices) found = yield self._discover(timeout=timeout, max_devices=max_devices)
except defer.TimeoutError: except defer.TimeoutError:
found = False found = False
finally: finally:
if not keep_listening: self.sspd_factory.disconnect()
self.soap_manager.sspd_factory.disconnect()
if found: if found:
try: log.debug("found upnp device")
runner = self.soap_manager.get_runner() else:
required_commands = [ log.debug("failed to find upnp device")
"GetExternalIPAddress",
"AddPortMapping",
"GetSpecificPortMappingEntry",
"GetGenericPortMappingEntry",
"DeletePortMapping"
]
found = all((command in runner._registered_commands for command in required_commands))
except UPnPError:
found = False
if not found and self.try_miniupnpc_fallback:
found = yield self.start_miniupnpc_fallback()
defer.returnValue(found) defer.returnValue(found)
@defer.inlineCallbacks def get_external_ip(self) -> str:
def start_miniupnpc_fallback(self): return self.gateway.commands.GetExternalIPAddress()
found = False
if not self.miniupnpc_runner:
log.debug("trying miniupnpc fallback")
fallback = UPnPFallback()
success = yield fallback.discover()
self.miniupnpc_igd_url = fallback.device_url
if success:
log.info("successfully started miniupnpc fallback")
self.miniupnpc_runner = fallback
found = True
if not found:
log.warning("failed to find upnp gateway using miniupnpc fallback")
defer.returnValue(found)
def get_external_ip(self): def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str,
return self.commands.GetExternalIPAddress() description: str) -> None:
return self.gateway.commands.AddPortMapping(
def add_port_mapping(self, external_port, protocol, internal_port, lan_address, description):
return self.commands.AddPortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol, NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address, NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration="" NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration=""
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_port_mapping_by_index(self, index): def get_port_mapping_by_index(self, index: int) -> (str, int, str, int, str, bool, str, int):
try: try:
redirect = yield self.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) redirect = yield self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
defer.returnValue(redirect) defer.returnValue(redirect)
except UPnPError: except UPnPError:
defer.returnValue(None) defer.returnValue(None)
@ -129,7 +97,7 @@ class UPnP(object):
defer.returnValue(redirects) defer.returnValue(redirects)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_specific_port_mapping(self, external_port, protocol): def get_specific_port_mapping(self, external_port: int, protocol: str) -> (int, str, bool, str, int):
""" """
:param external_port: (int) external port to listen on :param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP' :param protocol: (str) 'UDP' | 'TCP'
@ -137,41 +105,23 @@ class UPnP(object):
""" """
try: try:
result = yield self.commands.GetSpecificPortMappingEntry( result = yield self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
) )
defer.returnValue(result) defer.returnValue(result)
except UPnPError: except UPnPError:
defer.returnValue(None) defer.returnValue(None)
def delete_port_mapping(self, external_port, protocol, new_remote_host=""): def delete_port_mapping(self, external_port: int, protocol: str) -> None:
""" """
:param external_port: (int) external port to listen on :param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP' :param protocol: (str) 'UDP' | 'TCP'
:return: None :return: None
""" """
return self.commands.DeletePortMapping( return self.gateway.commands.DeletePortMapping(
NewRemoteHost=new_remote_host, NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
) )
def get_rsip_nat_status(self):
"""
:return: (bool) NewRSIPAvailable, (bool) NewNATEnabled
"""
return self.commands.GetNATRSIPStatus()
def get_status_info(self):
"""
:return: (str) NewConnectionStatus, (str) NewLastConnectionError, (int) NewUptime
"""
return self.commands.GetStatusInfo()
def get_connection_type_info(self):
"""
:return: (str) NewConnectionType (str), NewPossibleConnectionTypes (str)
"""
return self.commands.GetConnectionTypeInfo()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_next_mapping(self, port, protocol, description, internal_port=None): def get_next_mapping(self, port, protocol, description, internal_port=None):
if protocol not in ["UDP", "TCP"]: if protocol not in ["UDP", "TCP"]:
@ -192,15 +142,3 @@ class UPnP(object):
port, protocol, internal_port, self.lan_address, description port, protocol, internal_port, self.lan_address, description
) )
defer.returnValue(port) defer.returnValue(port)
def get_debug_info(self, include_gateway_xml=False):
def default_byte(x):
if isinstance(x, bytes):
return x.decode()
return x
return json.dumps({
'txupnp': self.soap_manager.debug(include_gateway_xml=include_gateway_xml),
'miniupnpc_igd_url': self.miniupnpc_igd_url
},
indent=2, default=default_byte
)

View file

@ -1,7 +1,6 @@
import re import re
import functools import functools
from collections import defaultdict from collections import defaultdict
from twisted.internet import defer
from xml.etree import ElementTree from xml.etree import ElementTree
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode()) BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
@ -59,13 +58,11 @@ def verify_return_types(*types):
def _verify_return_types(fn): def _verify_return_types(fn):
@functools.wraps(fn) @functools.wraps(fn)
def _inner(response): def _inner(*result):
if isinstance(response, (list, tuple)): r = fn(*tuple(t(r) for t, r in zip(types, result)))
r = tuple(t(r) for t, r in zip(types, response)) if isinstance(r, tuple) and len(r) == 1:
if len(r) == 1: return r[0]
return fn(r[0]) return r
return fn(r)
return fn(types[0](response))
return _inner return _inner
return _verify_return_types return _verify_return_types
@ -85,18 +82,3 @@ def return_types(*types):
none_or_str = lambda x: None if not x or x == 'None' else str(x) none_or_str = lambda x: None if not x or x == 'None' else str(x)
none = lambda _: None none = lambda _: None
@defer.inlineCallbacks
def DeferredDict(d, consumeErrors=False):
keys = []
dl = []
response = {}
for k, v in d.items():
keys.append(k)
dl.append(v)
results = yield defer.DeferredList(dl, consumeErrors=consumeErrors)
for k, (success, result) in zip(keys, results):
if success:
response[k] = result
defer.returnValue(response)