convert to asyncio

This commit is contained in:
Jack Robison 2018-10-07 22:30:13 -04:00
parent 2bd4b5e3e6
commit 02a00e7ef4
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
31 changed files with 1500 additions and 1354 deletions

View file

@ -6,10 +6,10 @@ python: "3.7"
before_install:
- pip install pylint coverage
- pip install -e .
# - pylint txupnp
# - pylint aioupnp
script:
- HOME=/tmp coverage run --source=txupnp -m twisted.trial tests
- HOME=/tmp coverage run --source=aioupnp -m unittest -v
after_success:
- bash <(curl -s https://codecov.io/bash)

View file

@ -1,28 +1,27 @@
[![codecov](https://codecov.io/gh/lbryio/txupnp/branch/master/graph/badge.svg)](https://codecov.io/gh/lbryio/txupnp)
# UPnP for Twisted
# UPnP for asyncio
`txupnp` is a python 3 library to interact with UPnP gateways using `twisted`
`aioupnp` is a python 3 library and command line tool to interact with UPnP gateways using asyncio. `aioupnp` requires the `netifaces` module.
## Installation
```
pip install --upgrade txupnp
pip install --upgrade aioupnp
```
## Usage
```
usage: txupnp-cli [-h] [--debug_logging] [--include_igd_xml] command
usage: txupnp [-h] [--debug_logging=<debug_logging>] [--interface=<interface>]
[--gateway_address=<gateway_address>]
[--lan_address=<lan_address>] [--timeout=<timeout>]
[--service=<service>]
command [--<arg name>=<arg>]...
positional arguments:
command debug_device | list_mappings | get_external_ip |
add_mapping | delete_mapping
commands: add_port_mapping | delete_port_mapping | get_external_ip | get_next_mapping | get_port_mapping_by_index | get_redirects | get_soap_commands | get_specific_port_mapping | m_search
optional arguments:
-h, --help show this help message and exit
--debug_logging
--include_igd_xml
for help with a specific command: txupnp help <command>
```

View file

@ -1,5 +1,5 @@
__version__ = "0.0.1a11"
__name__ = "txupnp"
__name__ = "aioupnp"
__author__ = "Jack Robison"
__maintainer__ = "Jack Robison"
__license__ = "MIT"

90
aioupnp/__main__.py Normal file
View file

@ -0,0 +1,90 @@
import logging
import sys
from aioupnp.upnp import UPnP
from aioupnp.constants import UPNP_ORG_IGD
log = logging.getLogger("aioupnp")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))
log.addHandler(handler)
log.setLevel(logging.WARNING)
def get_help(command):
fn = getattr(UPnP, command)
params = command + " " + " ".join(["[--%s=<%s>]" % (k, k) for k in fn.__annotations__ if k != 'return'])
return \
"usage: aioupnp [--debug_logging=<debug_logging>] [--interface=<interface>]\n" \
" [--gateway_address=<gateway_address>]\n" \
" [--lan_address=<lan_address>] [--timeout=<timeout>]\n" \
" [--service=<service>]\n" \
" %s\n" % params
def main():
commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")]
help_str = " | ".join(commands)
usage = \
"usage: aioupnp [-h] [--debug_logging=<debug_logging>] [--interface=<interface>]\n" \
" [--gateway_address=<gateway_address>]\n" \
" [--lan_address=<lan_address>] [--timeout=<timeout>]\n" \
" [--service=<service>]\n" \
" command [--<arg name>=<arg>]...\n" \
"\n" \
"commands: %s\n\nfor help with a specific command: aioupnp help <command>" % help_str
args = sys.argv[1:]
if args[0] in ['help', '-h', '--help']:
if len(args) > 1:
if args[1] in commands:
sys.exit(get_help(args[1]))
sys.exit(print(usage))
defaults = {
'debug_logging': False,
'interface': 'default',
'gateway_address': '',
'lan_address': '',
'timeout': 1,
'service': UPNP_ORG_IGD,
'return_as_json': True
}
options = {}
command = None
for arg in args:
if arg.startswith("--"):
k, v = arg.split("=")
k = k.lstrip('--')
options[k] = v
else:
command = arg
break
if not command:
print("no command given")
sys.exit(print(usage))
kwargs = {}
for arg in args[len(options)+1:]:
if arg.startswith("--"):
k, v = arg.split("=")
k = k.lstrip('--')
kwargs[k] = v
else:
break
for k, v in defaults.items():
if k not in options:
options[k] = v
if options.pop('debug_logging'):
log.setLevel(logging.DEBUG)
UPnP.run_cli(
command.replace('-', '_'), options.pop('lan_address'), options.pop('gateway_address'),
options.pop('timeout'), options.pop('service'), options.pop('interface'),
kwargs
)
if __name__ == "__main__":
main()

View file

@ -1,27 +1,25 @@
from txupnp.util import return_types, none_or_str, none
def none_or_str(x):
return None if not x or x == 'None' else str(x)
class SCPDCommands: # TODO use type annotations
class SCPDCommands:
def debug_commands(self) -> dict:
raise NotImplementedError()
@staticmethod
@return_types(none)
def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
async def AddPortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str, NewInternalPort: int,
NewInternalClient: str, NewEnabled: bool, NewPortMappingDescription: str,
NewLeaseDuration: str = '') -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(bool, bool)
def GetNATRSIPStatus() -> (bool, bool):
async def GetNATRSIPStatus() -> (bool, bool):
"""Returns (NewRSIPAvailable, NewNATEnabled)"""
raise NotImplementedError()
@staticmethod
@return_types(none_or_str, int, str, int, str, bool, str, int)
def GetGenericPortMappingEntry(NewPortMappingIndex) -> (none_or_str, int, str, int, str, bool, str, int):
async def GetGenericPortMappingEntry(NewPortMappingIndex: int) -> (none_or_str, int, str, int, str, bool, str, int):
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
@ -29,70 +27,62 @@ class SCPDCommands: # TODO use type annotations
raise NotImplementedError()
@staticmethod
@return_types(int, str, bool, str, int)
def GetSpecificPortMappingEntry(NewRemoteHost, NewExternalPort, NewProtocol) -> (int, str, bool, str, int):
async def GetSpecificPortMappingEntry(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> (int, str, bool, str, int):
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetConnectionType(NewConnectionType) -> None:
async def SetConnectionType(NewConnectionType: str) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(str)
def GetExternalIPAddress() -> str:
async def GetExternalIPAddress() -> str:
"""Returns (NewExternalIPAddress)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetConnectionTypeInfo() -> (str, str):
async def GetConnectionTypeInfo() -> (str, str):
"""Returns (NewConnectionType, NewPossibleConnectionTypes)"""
raise NotImplementedError()
@staticmethod
@return_types(str, str, int)
def GetStatusInfo() -> (str, str, int):
async def GetStatusInfo() -> (str, str, int):
"""Returns (NewConnectionStatus, NewLastConnectionError, NewUptime)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def ForceTermination() -> None:
async def ForceTermination() -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def DeletePortMapping(NewRemoteHost, NewExternalPort, NewProtocol) -> None:
async def DeletePortMapping(NewRemoteHost: str, NewExternalPort: int, NewProtocol: str) -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def RequestConnection() -> None:
async def RequestConnection() -> None:
"""Returns None"""
raise NotImplementedError()
@staticmethod
def GetCommonLinkProperties():
async def GetCommonLinkProperties():
"""Returns (NewWANAccessType, NewLayer1UpstreamMaxBitRate, NewLayer1DownstreamMaxBitRate, NewPhysicalLinkStatus)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesSent():
async def GetTotalBytesSent():
"""Returns (NewTotalBytesSent)"""
raise NotImplementedError()
@staticmethod
def GetTotalBytesReceived():
async def GetTotalBytesReceived():
"""Returns (NewTotalBytesReceived)"""
raise NotImplementedError()
@staticmethod
def GetTotalPacketsSent():
async def GetTotalPacketsSent():
"""Returns (NewTotalPacketsSent)"""
raise NotImplementedError()
@ -102,36 +92,33 @@ class SCPDCommands: # TODO use type annotations
raise NotImplementedError()
@staticmethod
def X_GetICSStatistics():
async def X_GetICSStatistics() -> (int, int, int, int, str, str):
"""Returns (TotalBytesSent, TotalBytesReceived, TotalPacketsSent, TotalPacketsReceived, Layer1DownstreamMaxBitRate, Uptime)"""
raise NotImplementedError()
@staticmethod
def GetDefaultConnectionService():
async def GetDefaultConnectionService():
"""Returns (NewDefaultConnectionService)"""
raise NotImplementedError()
@staticmethod
def SetDefaultConnectionService(NewDefaultConnectionService) -> None:
async def SetDefaultConnectionService(NewDefaultConnectionService: str) -> None:
"""Returns (None)"""
raise NotImplementedError()
@staticmethod
@return_types(none)
def SetEnabledForInternet(NewEnabledForInternet) -> None:
async def SetEnabledForInternet(NewEnabledForInternet: bool) -> None:
raise NotImplementedError()
@staticmethod
@return_types(bool)
def GetEnabledForInternet() -> bool:
async def GetEnabledForInternet() -> bool:
raise NotImplementedError()
@staticmethod
def GetMaximumActiveConnections(NewActiveConnectionIndex):
async def GetMaximumActiveConnections(NewActiveConnectionIndex: int):
raise NotImplementedError()
@staticmethod
@return_types(str, str)
def GetActiveConnections() -> (str, str):
async def GetActiveConnections() -> (str, str):
"""Returns (NewActiveConnDeviceContainer, NewActiveConnectionServiceID"""
raise NotImplementedError()

View file

@ -18,11 +18,6 @@ WAN_SCHEMA = 'urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1'
LAYER_SCHEMA = 'urn:schemas-upnp-org:service:Layer3Forwarding:1'
IP_SCHEMA = 'urn:schemas-upnp-org:service:WANIPConnection:1'
service_types = [
UPNP_ORG_IGD,
# WIFI_ALLIANCE_ORG_IGD,
]
SSDP_IP_ADDRESS = '239.255.255.250'
SSDP_PORT = 1900
SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)

95
aioupnp/device.py Normal file
View file

@ -0,0 +1,95 @@
import logging
log = logging.getLogger(__name__)
class CaseInsensitive:
def __init__(self, **kwargs):
not_evaluated = {}
for k, v in kwargs.items():
if k.startswith("_"):
not_evaluated[k] = v
continue
try:
getattr(self, k)
setattr(self, k, v)
except AttributeError as err:
not_evaluated[k] = v
if not_evaluated:
log.debug("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated)
def _get_attr_name(self, case_insensitive: str) -> str:
for k, v in self.__dict__.items():
if k.lower() == case_insensitive.lower():
return k
def __getattr__(self, item):
if item in self.__dict__:
return self.__dict__[item]
for k, v in self.__class__.__dict__.items():
if k.lower() == item.lower():
if k not in self.__dict__:
self.__dict__[k] = v
return v
raise AttributeError(item)
def __setattr__(self, item, value):
if item in self.__dict__:
self.__dict__[item] = value
return
to_update = None
for k, v in self.__dict__.items():
if k.lower() == item.lower():
to_update = k
break
self.__dict__[to_update or item] = value
def as_dict(self) -> dict:
return {
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
}
class Service(CaseInsensitive):
serviceType = None
serviceId = None
controlURL = None
eventSubURL = None
SCPDURL = None
class Device(CaseInsensitive):
serviceList = None
deviceList = None
deviceType = None
friendlyName = None
manufacturer = None
manufacturerURL = None
modelDescription = None
modelName = None
modelNumber = None
modelURL = None
serialNumber = None
udn = None
upc = None
presentationURL = None
iconList = None
def __init__(self, devices, services, **kwargs):
super(Device, self).__init__(**kwargs)
if self.serviceList and "service" in self.serviceList:
new_services = self.serviceList["service"]
if isinstance(new_services, dict):
new_services = [new_services]
services.extend([Service(**service) for service in new_services])
if self.deviceList:
for kw in self.deviceList.values():
if isinstance(kw, dict):
d = Device(devices, services, **kw)
devices.append(d)
elif isinstance(kw, list):
for _inner_kw in kw:
d = Device(devices, services, **_inner_kw)
devices.append(d)
else:
log.warning("failed to parse device:\n%s", kw)

View file

@ -1,5 +1,5 @@
from txupnp.util import flatten_keys
from txupnp.constants import FAULT, CONTROL
from aioupnp.util import flatten_keys
from aioupnp.constants import FAULT, CONTROL
class UPnPError(Exception):

181
aioupnp/gateway.py Normal file
View file

@ -0,0 +1,181 @@
import logging
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE
from aioupnp.commands import SCPDCommands
from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import m_search
from aioupnp.protocols.scpd import scpd_get
from aioupnp.protocols.soap import SCPDCommand
from aioupnp.util import flatten_keys
log = logging.getLogger(__name__)
def get_action_list(element_dict: dict) -> list: # [(<method>, [<input1>, ...], [<output1, ...]), ...]
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
if "actionList" in service_info:
action_list = service_info["actionList"]
else:
return []
if not len(action_list): # it could be an empty string
return []
result = []
if isinstance(action_list["action"], dict):
arg_dicts = action_list["action"]['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg
arg_dicts = [arg_dicts]
return [[
action_list["action"]['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
]]
for action in action_list["action"]:
if not action.get('argumentList'):
result.append((action['name'], [], []))
else:
arg_dicts = action['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg
arg_dicts = [arg_dicts]
result.append((
action['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
))
return result
class Gateway:
def __init__(self, **kwargs):
flattened = {
k.lower(): v for k, v in kwargs.items()
}
usn = flattened["usn"]
server = flattened["server"]
location = flattened["location"]
st = flattened["st"]
cache_control = flattened.get("cache_control") or flattened.get("cache-control") or ""
date = flattened.get("date", "")
ext = flattened.get("ext", "")
self.usn = usn.encode()
self.ext = ext.encode()
self.server = server.encode()
self.location = location.encode()
self.cache_control = cache_control.encode()
self.date = date.encode()
self.urn = st.encode()
self._xml_response = ""
self._service_descriptors = {}
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
self.spec_version = None
self.url_base = None
self._device = None
self._devices = []
self._services = []
self._unsupported_actions = {}
self._registered_commands = {}
self.commands = SCPDCommands()
def gateway_descriptor(self) -> dict:
r = {
'server': self.server.decode(),
'urlBase': self.url_base,
'location': self.location.decode(),
"specVersion": self.spec_version,
'usn': self.usn.decode(),
'urn': self.urn.decode(),
}
return r
@property
def services(self) -> dict:
if not self._device:
return {}
return {service.serviceType: service for service in self._services}
@property
def devices(self) -> dict:
if not self._device:
return {}
return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> Service:
for service in self._services:
if service.serviceType.lower() == service_type.lower():
return service
def debug_commands(self):
return {
'available': self._registered_commands,
'failed': self._unsupported_actions
}
@classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1,
service: str = UPNP_ORG_IGD):
datagram = await m_search(lan_address, gateway_address, timeout, service)
gateway = cls(**datagram.as_dict())
await gateway.discover_commands()
return gateway
async def discover_commands(self):
response = await scpd_get("/" + self.path.decode(), self.base_ip.decode(), self.port)
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(response, "urlbase")
if not self.url_base:
self.url_base = self.base_address.decode()
if response:
self._device = Device(
self._devices, self._services, **get_dict_val_case_insensitive(response, "device")
)
else:
self._device = Device(self._devices, self._services)
for service_type in self.services.keys():
await self.register_commands(self.services[service_type])
async def register_commands(self, service: Service):
service_dict = await scpd_get(("" if service.SCPDURL.startswith("/") else "/") + service.SCPDURL,
self.base_ip.decode(), self.port)
if not service_dict:
return
action_list = get_action_list(service_dict)
for name, inputs, outputs in action_list:
try:
current = getattr(self.commands, name)
annotations = current.__annotations__
return_types = annotations.get('return', None)
if return_types:
if not isinstance(return_types, tuple):
return_types = (return_types, )
return_types = {r: t for r, t in zip(outputs, return_types)}
param_types = {}
for param_name, param_type in annotations.items():
if param_name == "return":
continue
param_types[param_name] = param_type
command = SCPDCommand(
self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(),
name, param_types, return_types, inputs, outputs)
setattr(command, "__doc__", current.__doc__)
setattr(self.commands, command.method, command)
self._registered_commands[command.method] = service.serviceType
log.debug("registered %s::%s", service.serviceType, command.method)
except AttributeError:
s = self._unsupported_actions.get(service.serviceType, [])
s.append(name)
self._unsupported_actions[service.serviceType] = s
log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
service.serviceType, name, inputs, outputs)

View file

View file

@ -0,0 +1,46 @@
import struct
import socket
from asyncio.protocols import DatagramProtocol
from asyncio.transports import DatagramTransport
class MulticastProtocol(DatagramProtocol):
@property
def socket(self) -> socket.socket:
return self.transport.get_extra_info('socket')
def get_ttl(self) -> int:
return self.socket.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
def set_ttl(self, ttl: int = 1) -> None:
self.socket.setsockopt(
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
)
def join_group(self, addr: str, interface: str) -> None:
addr = socket.inet_aton(addr)
interface = socket.inet_aton(interface)
self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, addr + interface)
def leave_group(self, addr: str, interface: str) -> None:
addr = socket.inet_aton(addr)
interface = socket.inet_aton(interface)
self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, addr + interface)
def connection_made(self, transport: DatagramTransport) -> None:
self.transport = transport
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@classmethod
def create_socket(cls, bind_address: str, multicast_address: str):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((bind_address, 0))
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
)
sock.setblocking(False)
return sock

88
aioupnp/protocols/scpd.py Normal file
View file

@ -0,0 +1,88 @@
import logging
from xml.etree import ElementTree
import asyncio
from asyncio.protocols import Protocol
from aioupnp.fault import UPnPError
from aioupnp.serialization.scpd import deserialize_scpd_get_response
from aioupnp.serialization.scpd import serialize_scpd_get
from aioupnp.serialization.soap import serialize_soap_post, deserialize_soap_post_response
log = logging.getLogger(__name__)
class SCPDHTTPClientProtocol(Protocol):
POST = 'POST'
GET = 'GET'
def __init__(self, method: str, message: bytes, finished: asyncio.Future, soap_method: str=None,
soap_service_id: str=None, close_after_send: bool = False):
self.method = method
assert soap_service_id is not None and soap_method is not None if method == 'POST' else True, \
'soap args not provided'
self.message = message
self.response_buff = b""
self.finished = finished
self.soap_method = soap_method
self.soap_service_id = soap_service_id
self.close_after_send = close_after_send
def connection_made(self, transport):
transport.write(self.message)
if self.close_after_send:
self.finished.set_result(None)
def data_received(self, data):
self.response_buff += data
if self.method == self.GET:
try:
packet = deserialize_scpd_get_response(self.response_buff)
if not packet:
return
except ElementTree.ParseError:
pass
except UPnPError as err:
self.finished.set_exception(err)
else:
self.finished.set_result(packet)
elif self.method == self.POST:
try:
packet = deserialize_soap_post_response(self.response_buff, self.soap_method, self.soap_service_id)
if not packet:
return
except ElementTree.ParseError:
pass
except UPnPError as err:
self.finished.set_exception(err)
else:
self.finished.set_result(packet)
async def scpd_get(control_url: str, address: str, port: int) -> dict:
loop = asyncio.get_running_loop()
finished = asyncio.Future()
packet = serialize_scpd_get(control_url, address)
transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port
)
try:
return await asyncio.wait_for(finished, 1.0)
finally:
transport.close()
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
close_after_send: bool, **kwargs):
loop = asyncio.get_running_loop()
finished = asyncio.Future()
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol(
'POST', packet, finished, soap_method=method, soap_service_id=service_id.decode(),
close_after_send=close_after_send
), address, port
)
try:
return await asyncio.wait_for(finished, 1.0)
finally:
transport.close()

32
aioupnp/protocols/soap.py Normal file
View file

@ -0,0 +1,32 @@
import logging
from aioupnp.protocols.scpd import scpd_post
log = logging.getLogger(__name__)
class SCPDCommand:
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
param_types: dict, return_types: dict, param_order: list, return_order: list):
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
self.service_id = service_id
self.method = method
self.param_types = param_types
self.param_order = param_order
self.return_types = return_types
self.return_order = return_order
async def __call__(self, **kwargs):
if set(kwargs.keys()) != set(self.param_types.keys()):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
close_after_send = not self.return_types or self.return_types == [None]
response = await scpd_post(
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id,
close_after_send, **{n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()}
)
extracted_response = tuple([None if self.return_types[n] is None else self.return_types[n](response[n])
for n in self.return_order]) or (None, )
if len(extracted_response) == 1:
return extracted_response[0]
return extracted_response

104
aioupnp/protocols/ssdp.py Normal file
View file

@ -0,0 +1,104 @@
import re
import binascii
import asyncio
import logging
from typing import DefaultDict
from asyncio.coroutines import coroutine
from asyncio.futures import Future
from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import UPNP_ORG_IGD, SSDP_IP_ADDRESS, SSDP_PORT, SSDP_DISCOVER, SSDP_ROOT_DEVICE
from aioupnp.protocols.multicast import MulticastProtocol
ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-|\.]*)$")
log = logging.getLogger(__name__)
class SSDPProtocol(MulticastProtocol):
def __init__(self, lan_address):
super().__init__()
self.lan_address = lan_address
self.discover_callbacks: DefaultDict[coroutine] = {}
self.transport: DatagramTransport
self.notifications = []
self.replies = []
def connection_made(self, transport: DatagramTransport):
super().connection_made(transport)
self.set_ttl(1)
async def m_search(self, address, timeout: int = 1, service=UPNP_ORG_IGD) -> SSDPDatagram:
if (address, service) in self.discover_callbacks:
return self.discover_callbacks[(address, service)]
packet = SSDPDatagram(
SSDPDatagram._M_SEARCH, host="{}:{}".format(SSDP_IP_ADDRESS, SSDP_PORT), st=service, man=SSDP_DISCOVER,
mx=1
)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
f = Future()
self.discover_callbacks[(address, service)] = f
return await asyncio.wait_for(f, timeout)
def datagram_received(self, data, addr) -> None:
if addr[0] == self.lan_address:
return
try:
packet = SSDPDatagram.decode(data)
log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), addr[0], addr[1], packet.encode())
except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i: %s\npacket: %s", addr[0], addr[1], err,
binascii.hexlify(data))
return
if packet._packet_type == packet._OK:
log.debug("%s:%i sent us an OK", addr[0], addr[1])
if (addr[0], packet.st) in self.discover_callbacks:
if packet.st not in map(lambda p: p['st'], self.replies):
self.replies.append(packet)
f: Future = self.discover_callbacks.pop((addr[0], packet.st))
f.set_result(packet)
elif packet._packet_type == packet._NOTIFY:
if packet.nt == SSDP_ROOT_DEVICE:
address, port, path = ADDRESS_REGEX.findall(packet.location)[0]
key = None
for (addr, service) in self.discover_callbacks:
if addr == address:
key = (addr, service)
break
if key:
log.debug("got a notification with the requested m-search info")
f: Future = self.discover_callbacks.pop(key)
f.set_result(SSDPDatagram(
SSDPDatagram._OK, cache_control='', location=packet.location, server=packet.server,
st=UPNP_ORG_IGD, usn=packet.usn
))
self.notifications.append(packet.as_dict())
async def listen_ssdp(lan_address: str, gateway_address: str) -> (DatagramTransport, SSDPProtocol, str, str):
loop = asyncio.get_running_loop()
try:
sock = SSDPProtocol.create_socket(lan_address, SSDP_IP_ADDRESS)
transport, protocol = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(lan_address), sock=sock
)
except Exception:
log.exception("failed to create multicast socket %s:%i", lan_address, SSDP_PORT)
raise
return transport, protocol, gateway_address, lan_address
async def m_search(lan_address: str, gateway_address: str, timeout: int = 1,
service: str = UPNP_ORG_IGD) -> SSDPDatagram:
transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address
)
try:
return await protocol.m_search(gateway_address, timeout=timeout, service=service)
except asyncio.TimeoutError:
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally:
transport.close()

View file

View file

@ -0,0 +1,49 @@
import re
from xml.etree import ElementTree
from aioupnp.constants import XML_VERSION, DEVICE, ROOT
from aioupnp.util import etree_to_dict, flatten_keys
CONTENT_PATTERN = re.compile(
"(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)".encode()
)
XML_ROOT_SANITY_PATTERN = re.compile(
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
)
def serialize_scpd_get(path: str, address: str) -> bytes:
if "http://" in address:
host = address.split("http://")[1]
else:
host = address
if ":" in host:
host = host.split(":")[0]
if not path.startswith("/"):
path = "/" + path
return (
(
'GET %s HTTP/1.1\r\n'
'Accept-Encoding: gzip\r\n'
'Host: %s\r\n'
'Connection: Close\r\n'
'\r\n'
) % (path, host)
).encode()
def deserialize_scpd_get_response(content: bytes) -> dict:
if XML_VERSION.encode() in content:
parsed = CONTENT_PATTERN.findall(content)
content = b'' if not parsed else parsed[0][0]
xml_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
schema_key = DEVICE
root = ROOT
for k in xml_dict.keys():
m = XML_ROOT_SANITY_PATTERN.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0]
root = m[2][5]
break
return flatten_keys(xml_dict, "{%s}" % schema_key)[root]

View file

@ -0,0 +1,63 @@
import re
from xml.etree import ElementTree
from aioupnp.util import etree_to_dict, flatten_keys
from aioupnp.fault import handle_fault, UPnPError
from aioupnp.constants import XML_VERSION, ENVELOPE, BODY
CONTENT_NO_XML_VERSION_PATTERN = re.compile(
"(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode()
)
def serialize_soap_post(method: str, param_names: list, service_id: bytes, gateway_address: bytes,
control_url: bytes, **kwargs) -> bytes:
args = "".join("<%s>%s</%s>" % (n, kwargs.get(n), n) for n in param_names)
soap_body = ('\r\n%s\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (
XML_VERSION, method, service_id.decode(),
args, method))
if "http://" in gateway_address.decode():
host = gateway_address.decode().split("http://")[1]
else:
host = gateway_address.decode()
return (
(
'POST %s HTTP/1.1\r\n'
'Host: %s\r\n'
'User-Agent: python3/aioupnp, UPnP/1.0, MiniUPnPc/1.9\r\n'
'Content-Length: %i\r\n'
'Content-Type: text/xml\r\n'
'SOAPAction: \"%s#%s\"\r\n'
'Connection: Close\r\n'
'Cache-Control: no-cache\r\n'
'Pragma: no-cache\r\n'
'%s'
'\r\n'
) % (
control_url.decode(), # could be just / even if it shouldn't be
host,
len(soap_body),
service_id.decode(), # maybe no quotes
method,
soap_body
)
).encode()
def deserialize_soap_post_response(response: bytes, method: str, service_id: str) -> dict:
parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(response)
content = b'' if not parsed else parsed[0][0]
content_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
envelope = content_dict[ENVELOPE]
response_body = flatten_keys(envelope[BODY], "{%s}" % service_id)
body = handle_fault(response_body) # raises UPnPError if there is a fault
response_key = None
for key in body:
if method in key:
response_key = key
break
if not response_key:
raise UPnPError("unknown response fields for %s")
return body[response_key]

View file

@ -1,12 +1,12 @@
import re
import logging
import binascii
from txupnp.fault import UPnPError
from txupnp.constants import line_separator
from aioupnp.fault import UPnPError
from aioupnp.constants import line_separator
log = logging.getLogger(__name__)
_ssdp_datagram_patterns = {
ssdp_datagram_patterns = {
'host': (re.compile("^(?i)(host):(.*)$"), str),
'st': (re.compile("^(?i)(st):(.*)$"), str),
'man': (re.compile("^(?i)(man):|(\"(.*)\")$"), str),
@ -19,7 +19,7 @@ _ssdp_datagram_patterns = {
'server': (re.compile("^(?i)(server):(.*)$"), str),
}
_vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
vendor_pattern = re.compile("^([\w|\d]*)\.([\w|\d]*\.com):([ \"|\w|\d\:]*)$")
class SSDPDatagram(object):
@ -39,9 +39,9 @@ class SSDPDatagram(object):
_OK: "m-search response"
}
_vendor_field_pattern = _vendor_pattern
_vendor_field_pattern = vendor_pattern
_patterns = _ssdp_datagram_patterns
_patterns = ssdp_datagram_patterns
_required_fields = {
_M_SEARCH: [
@ -95,6 +95,9 @@ class SSDPDatagram(object):
if not k.startswith("_") and hasattr(self, k.lower()) and getattr(self, k.lower()) is None:
setattr(self, k.lower(), v)
def __repr__(self):
return ("SSDPDatagram(packet_type=%s, " % self._packet_type) + ", ".join("%s=%s" % (n, v) for n, v in self.as_dict().items()) + ")"
def __getitem__(self, item):
for i in self._required_fields[self._packet_type]:
if i.lower() == item.lower():
@ -166,7 +169,7 @@ class SSDPDatagram(object):
lines = [l for l in datagram.split(line_separator) if l]
if lines[0] == cls._start_lines[cls._M_SEARCH]:
return cls._from_request(lines[1:])
if lines[0] == cls._start_lines[cls._NOTIFY]:
if lines[0] in [cls._start_lines[cls._NOTIFY], cls._start_lines[cls._NOTIFY] + " "]:
return cls._from_notify(lines[1:])
if lines[0] == cls._start_lines[cls._OK]:
return cls._from_response(lines[1:])

258
aioupnp/upnp.py Normal file
View file

@ -0,0 +1,258 @@
import os
import logging
import json
import asyncio
import functools
from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway
from aioupnp.constants import UPNP_ORG_IGD
from aioupnp.util import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search
log = logging.getLogger(__name__)
def cli(format_result=None):
def _cli(fn):
@functools.wraps(fn)
async def _inner(*args, **kwargs):
result = await fn(*args, **kwargs)
if not format_result or not result or not isinstance(result, (list, dict, tuple)):
return result
self = args[0]
return {k: v for k, v in zip(getattr(self.gateway.commands, format_result).return_order, result)}
f = _inner
f._cli = True
return f
return _cli
def _encode(x):
if isinstance(x, bytes):
return x.decode()
elif isinstance(x, Exception):
return str(x)
return x
class UPnP:
def __init__(self, lan_address: str, gateway_address: str, gateway: Gateway):
self.lan_address = lan_address
self.gateway_address = gateway_address
self.gateway = gateway
@classmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
interface_name: str = 'default') -> (str, str):
if not lan_address or not gateway_address:
gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name)
lan_address = lan_address or lan_addr
gateway_address = gateway_address or gateway_addr
return lan_address, gateway_address
@classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = UPNP_ORG_IGD, interface_name: str = 'default'):
try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
except Exception as err:
raise UPnPError("failed to get lan and gateway addresses: %s" % str(err))
gateway = await Gateway.discover_gateway(lan_address, gateway_address, timeout, service)
return cls(lan_address, gateway_address, gateway)
@classmethod
@cli()
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
service: str = UPNP_ORG_IGD, interface_name: str = 'default') -> dict:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
datagram = await m_search(lan_address, gateway_address, timeout, service)
return {
'lan_address': lan_address,
'gateway_address': gateway_address,
'discover_reply': datagram.as_dict()
}
@cli()
async def get_external_ip(self) -> str:
return await self.gateway.commands.GetExternalIPAddress()
@cli("AddPortMapping")
async def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str,
description: str) -> None:
return await self.gateway.commands.AddPortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration=""
)
@cli("GetGenericPortMappingEntry")
async def get_port_mapping_by_index(self, index: int) -> dict:
return await self._get_port_mapping_by_index(index)
async def _get_port_mapping_by_index(self, index: int) -> (str, int, str, int, str, bool, str, int):
try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
return redirect
except UPnPError:
return
@cli()
async def get_redirects(self) -> list:
redirects = []
cnt = 0
redirect = await self.get_port_mapping_by_index(cnt)
while redirect:
redirects.append(redirect)
cnt += 1
redirect = await self.get_port_mapping_by_index(cnt)
return redirects
@cli("GetSpecificPortMappingEntry")
async def get_specific_port_mapping(self, external_port: int, protocol: str):
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
"""
try:
return await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
)
except UPnPError:
return
@cli()
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: None
"""
return await self.gateway.commands.DeletePortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
)
@cli("AddPortMapping")
async def get_next_mapping(self, port: int, protocol: str, description: str, internal_port: int=None) -> int:
if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol))
internal_port = internal_port or port
redirect_tups = []
cnt = 0
port = int(port)
internal_port = int(internal_port)
redirect = await self._get_port_mapping_by_index(cnt)
while redirect:
redirect_tups.append(redirect)
cnt += 1
redirect = await self._get_port_mapping_by_index(cnt)
redirects = {
"%i:%s" % (ext_port, proto): (int_host, int_port, desc)
for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, lease) in redirect_tups
}
while ("%i:%s" % (port, protocol)) in redirects:
int_host, int_port, _ = redirects["%i:%s" % (port, protocol)]
if int_host == self.lan_address and int_port == internal_port:
break
port += 1
await self.add_port_mapping( # set one up
port, protocol, internal_port, self.lan_address, description
)
return port
@cli()
async def get_soap_commands(self) -> dict:
return {
'supported': list(self.gateway._registered_commands.keys()),
'unsupported': self.gateway._unsupported_actions
}
@cli()
async def generate_test_data(self):
external_ip = await self.get_external_ip()
redirects = await self.get_redirects()
ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping")
delete = await self.delete_port_mapping(ext_port, "UDP")
after_delete = await self.get_specific_port_mapping(ext_port, "UDP")
commands_test_case = (
("get_external_ip", (), "1.2.3.4"),
("get_redirects", (), redirects),
("get_next_mapping", (4567, "UDP", "aioupnp test mapping"), ext_port),
("delete_port_mapping", (ext_port, "UDP"), delete),
("get_specific_port_mapping", (ext_port, "UDP"), after_delete),
)
gateway = self.gateway
device = list(gateway.devices.values())[0]
assert device.manufacturer and device.modelName
device_path = os.path.join(os.getcwd(), "%s %s" % (device.manufacturer, device.modelName))
commands = gateway.debug_commands()
with open(device_path, "w") as f:
f.write(json.dumps({
"router_address": self.gateway_address,
"client_address": self.lan_address,
"port": gateway.port,
"gateway_dict": gateway.gateway_descriptor(),
'expected_devices': [
{
'cache_control': 'max-age=1800',
'location': gateway.location,
'server': gateway.server,
'st': gateway.urn,
'usn': gateway.usn
}
],
'commands': commands,
# 'ssdp': u.sspd_factory.get_ssdp_packet_replay(),
# 'scpd': gateway.requester.dump_packets(),
'soap': commands_test_case
}, default=_encode, indent=2).replace(external_ip, "1.2.3.4"))
return "Generated test data! -> %s" % device_path
@classmethod
def run_cli(cls, method, lan_address: str = '', gateway_address: str = '', timeout: int = 60,
service: str = UPNP_ORG_IGD, interface_name: str = 'default',
kwargs: dict = None):
kwargs = kwargs or {}
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
fut = asyncio.Future()
async def wrapper():
if method == 'm_search':
fn = lambda *_a, **_kw: cls.m_search(lan_address, gateway_address, timeout, service, interface_name)
else:
u = await cls.discover(
lan_address, gateway_address, timeout, service, interface_name
)
if hasattr(u, method) and hasattr(getattr(u, method), "_cli"):
fn = getattr(u, method)
else:
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
return
try:
result = await fn(**kwargs)
fut.set_result(result)
except UPnPError as err:
fut.set_exception(err)
except Exception as err:
log.exception("uncaught error")
fut.set_exception(UPnPError("uncaught error: %s" % str(err)))
asyncio.run(wrapper())
try:
result = fut.result()
except UPnPError as err:
print("error: %s" % str(err))
return
if isinstance(result, (list, tuple, dict)):
print(json.dumps(result, indent=2, default=_encode))
else:
print(result)

View file

@ -1,7 +1,9 @@
import re
import functools
import socket
from collections import defaultdict
from xml.etree import ElementTree
import netifaces
BASE_ADDRESS_REGEX = re.compile("^(http:\/\/\d*\.\d*\.\d*\.\d*:\d*)\/.*$".encode())
BASE_PORT_REGEX = re.compile("^http:\/\/\d*\.\d*\.\d*\.\d*:(\d*)\/.*$".encode())
@ -50,35 +52,37 @@ def get_dict_val_case_insensitive(d, k):
raise KeyError("overlapping keys")
return d[match[0]]
def verify_return_types(*types):
"""
Attempt to recast results to expected result types
"""
def _verify_return_types(fn):
@functools.wraps(fn)
def _inner(*result):
r = fn(*tuple(t(r) for t, r in zip(types, result)))
if isinstance(r, tuple) and len(r) == 1:
return r[0]
return r
return _inner
return _verify_return_types
# import struct
# import fcntl
# def get_ip_address(ifname):
# SIOCGIFADDR = 0x8915
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# return socket.inet_ntoa(fcntl.ioctl(
# s.fileno(),
# SIOCGIFADDR,
# struct.pack(b'256s', ifname[:15].encode())
# )[20:24])
def return_types(*types):
"""
Decorator to set the expected return types of a SOAP function call
"""
def return_types_wrapper(fn):
fn._return_types = types
return fn
return return_types_wrapper
def get_interfaces():
r = {
interface_name: (router_address, netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr'])
for router_address, interface_name, _ in netifaces.gateways()[socket.AF_INET]
}
for interface_name in netifaces.interfaces():
if interface_name in ['lo', 'localhost'] or interface_name in r:
continue
addresses = netifaces.ifaddresses(interface_name)
if netifaces.AF_INET in addresses:
address = addresses[netifaces.AF_INET][0]['addr']
gateway_guess = ".".join(address.split(".")[:-1] + ["1"])
r[interface_name] = (gateway_guess, address)
r['default'] = r[netifaces.gateways()['default'][netifaces.AF_INET][1]]
return r
none_or_str = lambda x: None if not x or x == 'None' else str(x)
none = lambda _: None
def get_gateway_and_lan_addresses(interface_name: str) -> (str, str):
for iface_name, (gateway, lan) in get_interfaces().items():
if interface_name == iface_name:
return gateway, lan
return None, None

View file

@ -1,12 +1,12 @@
import os
from setuptools import setup, find_packages
from txupnp import __version__, __name__, __email__, __author__, __license__
from aioupnp import __version__, __name__, __email__, __author__, __license__
console_scripts = [
'txupnp-cli = txupnp.cli:main',
'aioupnp = aioupnp.__main__:main',
]
package_name = "txupnp"
package_name = "aioupnp"
base_dir = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(base_dir, 'README.md'), 'rb') as f:
long_description = f.read().decode('utf-8')
@ -16,15 +16,14 @@ setup(
version=__version__,
author=__author__,
author_email=__email__,
description="UPnP for twisted",
keywords="upnp twisted",
description="UPnP for asyncio",
keywords="upnp asyncio",
long_description=long_description,
url="https://github.com/lbryio/txupnp",
url="https://github.com/lbryio/aioupnp",
license=__license__,
packages=find_packages(exclude=['tests']),
entry_points={'console_scripts': console_scripts},
install_requires=[
'twisted[tls]',
'netifaces',
],
)

View file

@ -1,6 +1,6 @@
import logging
log = logging.getLogger("txupnp")
log = logging.getLogger("aioupnp")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))
log.addHandler(handler)

File diff suppressed because one or more lines are too long

View file

@ -1,117 +1,117 @@
from twisted.internet import reactor, defer
from twisted.trial import unittest
from txupnp.constants import SSDP_PORT, SSDP_IP_ADDRESS
from txupnp.upnp import UPnP
from txupnp.mocks import MockReactor, MockSSDPServiceGatewayProtocol, get_device_test_case
class TestDevice(unittest.TestCase):
manufacturer, model = "Cisco", "CGA4131COM"
device = get_device_test_case(manufacturer, model)
router_address = device.device_dict['router_address']
client_address = device.device_dict['client_address']
expected_devices = device.device_dict['expected_devices']
packets_rx = device.device_dict['ssdp']['received']
packets_tx = device.device_dict['ssdp']['sent']
expected_available_commands = device.device_dict['commands']['available']
scdp_packets = device.device_dict['scpd']
def setUp(self):
fake_reactor = MockReactor(self.client_address, self.scdp_packets)
reactor.listenMulticast = fake_reactor.listenMulticast
self.reactor = reactor
server_protocol = MockSSDPServiceGatewayProtocol(
self.client_address, self.router_address, self.packets_rx, self.packets_tx
)
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
self.upnp = UPnP(
self.reactor, debug_ssdp=True, router_ip=self.router_address,
lan_ip=self.client_address, iface_name='mock'
)
def tearDown(self):
self.upnp.sspd_factory.disconnect()
self.server_port.stopListening()
class TestSSDP(TestDevice):
@defer.inlineCallbacks
def test_discover_device(self):
result = yield self.upnp.m_search(self.router_address, timeout=1)
self.assertEqual(len(self.expected_devices), len(result))
self.assertEqual(len(result), 1)
self.assertDictEqual(self.expected_devices[0], result[0])
class TestSCPD(TestDevice):
@defer.inlineCallbacks
def setUp(self):
fake_reactor = MockReactor(self.client_address, self.scdp_packets)
reactor.listenMulticast = fake_reactor.listenMulticast
reactor.connectTCP = fake_reactor.connectTCP
self.reactor = reactor
server_protocol = MockSSDPServiceGatewayProtocol(
self.client_address, self.router_address, self.packets_rx, self.packets_tx
)
self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
self.upnp = UPnP(
self.reactor, debug_ssdp=True, router_ip=self.router_address,
lan_ip=self.client_address, iface_name='mock'
)
yield self.upnp.discover()
def test_parse_available_commands(self):
self.assertDictEqual(self.expected_available_commands, self.upnp.gateway.debug_commands()['available'])
def test_parse_gateway(self):
self.assertDictEqual(self.device.device_dict['gateway_dict'], self.upnp.gateway.as_dict())
@defer.inlineCallbacks
def test_commands(self):
method, args, expected = self.device.device_dict['soap'][0]
command1 = getattr(self.upnp, method)
result = yield command1(*tuple(args))
self.assertEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][1]
command2 = getattr(self.upnp, method)
result = yield command2(*tuple(args))
result = [[i for i in r] for r in result]
self.assertListEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][2]
command3 = getattr(self.upnp, method)
result = yield command3(*tuple(args))
self.assertEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][3]
command4 = getattr(self.upnp, method)
result = yield command4(*tuple(args))
result = [r for r in result]
self.assertEqual(result, expected)
method, args, expected = self.device.device_dict['soap'][4]
command5 = getattr(self.upnp, method)
result = yield command5(*tuple(args))
self.assertEqual(result, expected)
class TestDDWRTSSDP(TestSSDP):
manufacturer, model = "DD-WRT", "router"
class TestDDWRTSCPD(TestSCPD):
manufacturer, model = "DD-WRT", "router"
class TestMiniUPnPMiniUPnPd(TestSSDP):
manufacturer, model = "MiniUPnP", "MiniUPnPd"
class TestMiniUPnPMiniUPnPdSCPD(TestSCPD):
manufacturer, model = "MiniUPnP", "MiniUPnPd"
# from twisted.internet import reactor, defer
# from twisted.trial import unittest
# from aioupnp.constants import SSDP_PORT, SSDP_IP_ADDRESS
# from aioupnp.upnp import UPnP
# from aioupnp.mocks import MockReactor, MockSSDPServiceGatewayProtocol, get_device_test_case
#
#
# class TestDevice(unittest.TestCase):
# manufacturer, model = "Cisco", "CGA4131COM"
#
# device = get_device_test_case(manufacturer, model)
# router_address = device.device_dict['router_address']
# client_address = device.device_dict['client_address']
# expected_devices = device.device_dict['expected_devices']
# packets_rx = device.device_dict['ssdp']['received']
# packets_tx = device.device_dict['ssdp']['sent']
# expected_available_commands = device.device_dict['commands']['available']
# scdp_packets = device.device_dict['scpd']
#
# def setUp(self):
# fake_reactor = MockReactor(self.client_address, self.scdp_packets)
# reactor.listenMulticast = fake_reactor.listenMulticast
# self.reactor = reactor
# server_protocol = MockSSDPServiceGatewayProtocol(
# self.client_address, self.router_address, self.packets_rx, self.packets_tx
# )
# self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
# self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
#
# self.upnp = UPnP(
# self.reactor, debug_ssdp=True, router_ip=self.router_address,
# lan_ip=self.client_address, iface_name='mock'
# )
#
# def tearDown(self):
# self.upnp.sspd_factory.disconnect()
# self.server_port.stopListening()
#
#
# class TestSSDP(TestDevice):
# @defer.inlineCallbacks
# def test_discover_device(self):
# result = yield self.upnp.m_search(self.router_address, timeout=1)
# self.assertEqual(len(self.expected_devices), len(result))
# self.assertEqual(len(result), 1)
# self.assertDictEqual(self.expected_devices[0], result[0])
#
#
# class TestSCPD(TestDevice):
# @defer.inlineCallbacks
# def setUp(self):
# fake_reactor = MockReactor(self.client_address, self.scdp_packets)
# reactor.listenMulticast = fake_reactor.listenMulticast
# reactor.connectTCP = fake_reactor.connectTCP
# self.reactor = reactor
# server_protocol = MockSSDPServiceGatewayProtocol(
# self.client_address, self.router_address, self.packets_rx, self.packets_tx
# )
# self.server_port = self.reactor.listenMulticast(SSDP_PORT, server_protocol, interface=self.router_address)
# self.server_port.transport.joinGroup(SSDP_IP_ADDRESS, interface=self.router_address)
#
# self.upnp = UPnP(
# self.reactor, debug_ssdp=True, router_ip=self.router_address,
# lan_ip=self.client_address, iface_name='mock'
# )
# yield self.upnp.discover()
#
# def test_parse_available_commands(self):
# self.assertDictEqual(self.expected_available_commands, self.upnp.gateway.debug_commands()['available'])
#
# def test_parse_gateway(self):
# self.assertDictEqual(self.device.device_dict['gateway_dict'], self.upnp.gateway.as_dict())
#
# @defer.inlineCallbacks
# def test_commands(self):
# method, args, expected = self.device.device_dict['soap'][0]
# command1 = getattr(self.upnp, method)
# result = yield command1(*tuple(args))
# self.assertEqual(result, expected)
#
# method, args, expected = self.device.device_dict['soap'][1]
# command2 = getattr(self.upnp, method)
# result = yield command2(*tuple(args))
# result = [[i for i in r] for r in result]
# self.assertListEqual(result, expected)
#
# method, args, expected = self.device.device_dict['soap'][2]
# command3 = getattr(self.upnp, method)
# result = yield command3(*tuple(args))
# self.assertEqual(result, expected)
#
# method, args, expected = self.device.device_dict['soap'][3]
# command4 = getattr(self.upnp, method)
# result = yield command4(*tuple(args))
# result = [r for r in result]
# self.assertEqual(result, expected)
#
# method, args, expected = self.device.device_dict['soap'][4]
# command5 = getattr(self.upnp, method)
# result = yield command5(*tuple(args))
# self.assertEqual(result, expected)
#
#
# class TestDDWRTSSDP(TestSSDP):
# manufacturer, model = "DD-WRT", "router"
#
#
# class TestDDWRTSCPD(TestSCPD):
# manufacturer, model = "DD-WRT", "router"
#
#
# class TestMiniUPnPMiniUPnPd(TestSSDP):
# manufacturer, model = "MiniUPnP", "MiniUPnPd"
#
#
# class TestMiniUPnPMiniUPnPdSCPD(TestSCPD):
# manufacturer, model = "MiniUPnP", "MiniUPnPd"

View file

@ -1,6 +1,6 @@
from twisted.trial import unittest
from txupnp.ssdp_datagram import SSDPDatagram
from txupnp.fault import UPnPError
import unittest
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.fault import UPnPError
class TestParseMSearchRequest(unittest.TestCase):
@ -99,3 +99,27 @@ class TestFailToParseMSearchResponseNoLocation(TestFailToParseMSearchResponseNoS
'st: urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1',
'USN: uuid:00000000-0000-0000-0000-000000000000::urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1'
]).encode()
class TestParseNotify(unittest.TestCase):
datagram = \
b'NOTIFY * HTTP/1.1 \r\n' \
b'Host: 239.255.255.250:1900\r\n' \
b'Cache-Control: max-age=180\r\n' \
b'Location: http://192.168.1.1:5431/dyndev/uuid:000c-29ea-247500c00068\r\n' \
b'NT: upnp:rootdevice\r\n' \
b'NTS: ssdp:alive\r\n' \
b'SERVER: LINUX/2.4 UPnP/1.0 BRCM400/1.0\r\n' \
b'USN: uuid:000c-29ea-247500c00068::upnp:rootdevice\r\n' \
b'\r\n'
def test_parse_notify(self):
packet = SSDPDatagram.decode(self.datagram)
self.assertTrue(packet._packet_type, packet._NOTIFY)
self.assertEqual(packet.host, '239.255.255.250:1900')
self.assertEqual(packet.cache_control, 'max-age=180')
self.assertEqual(packet.location, 'http://192.168.1.1:5431/dyndev/uuid:000c-29ea-247500c00068')
self.assertEqual(packet.nt, 'upnp:rootdevice')
self.assertEqual(packet.nts, 'ssdp:alive')
self.assertEqual(packet.server, 'LINUX/2.4 UPnP/1.0 BRCM400/1.0')
self.assertEqual(packet.usn, 'uuid:000c-29ea-247500c00068::upnp:rootdevice')

View file

@ -1,159 +0,0 @@
import os
import json
import argparse
import logging
from twisted.internet import reactor, defer
from txupnp.upnp import UPnP
log = logging.getLogger("txupnp")
@defer.inlineCallbacks
def get_external_ip(u, *_):
ip = yield u.get_external_ip()
print(ip)
@defer.inlineCallbacks
def list_mappings(u, *_):
redirects = yield u.get_redirects()
ext_ip = yield u.get_external_ip()
for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, lease) in redirects:
print("{}:{}/{} --> {}:{} ({}) (expires: {}) - {} ".format(
ext_host or ext_ip, ext_port, proto, int_host, int_port, "enabled" if enabled else "disabled",
"never" if not lease else lease, desc)
)
@defer.inlineCallbacks
def add_mapping(u, *_):
port = 51413
protocol = "UDP"
description = "txupnp test mapping"
ext_port = yield u.get_next_mapping(port, protocol, description)
if ext_port:
print("external port: %i to local %i/%s" % (ext_port, port, protocol))
@defer.inlineCallbacks
def delete_mapping(u, *_):
port = 4567
protocol = "UDP"
yield u.delete_port_mapping(port, protocol)
mapping = yield u.get_specific_port_mapping(port, protocol)
if mapping:
print("failed to remove mapping")
else:
print("removed mapping")
def _encode(x):
if isinstance(x, bytes):
return x.decode()
elif isinstance(x, Exception):
return str(x)
return x
@defer.inlineCallbacks
def generate_test_data(u, *_):
external_ip = yield u.get_external_ip()
redirects = yield u.get_redirects()
ext_port = yield u.get_next_mapping(4567, "UDP", "txupnp test mapping")
delete = yield u.delete_port_mapping(ext_port, "UDP")
after_delete = yield u.get_specific_port_mapping(ext_port, "UDP")
commands_test_case = (
("get_external_ip", (), "1.2.3.4"),
("get_redirects", (), redirects),
("get_next_mapping", (4567, "UDP", "txupnp test mapping"), ext_port),
("delete_port_mapping", (ext_port, "UDP"), delete),
("get_specific_port_mapping", (ext_port, "UDP"), after_delete),
)
gateway = u.gateway
device = list(gateway.devices.values())[0]
assert device.manufacturer and device.modelName
device_path = os.path.join(os.getcwd(), "%s %s" % (device.manufacturer, device.modelName))
commands = gateway.debug_commands()
with open(device_path, "w") as f:
f.write(json.dumps({
"router_address": u.router_ip,
"client_address": u.lan_address,
"port": gateway.port,
"gateway_dict": gateway.as_dict(),
'expected_devices': [
{
'cache_control': 'max-age=1800',
'location': gateway.location,
'server': gateway.server,
'st': gateway.urn,
'usn': gateway.usn
}
],
'commands': commands,
'ssdp': u.sspd_factory.get_ssdp_packet_replay(),
'scpd': gateway.requester.dump_packets(),
'soap': commands_test_case
}, default=_encode, indent=2).replace(external_ip, "1.2.3.4"))
print("Generated test data! -> %s" % device_path)
cli_commands = {
"get_external_ip": get_external_ip,
"list_mappings": list_mappings,
"add_mapping": add_mapping,
"delete_mapping": delete_mapping,
"generate_test_data": generate_test_data,
}
@defer.inlineCallbacks
def run_command(found, u, command, debug_xml):
if not found:
print("failed to find gateway")
reactor.callLater(0, reactor.stop)
return
if command not in cli_commands:
print("unrecognized command: valid commands: %s" % list(cli_commands.keys()))
else:
yield cli_commands[command](u, debug_xml)
def main():
import logging
log = logging.getLogger("txupnp")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)-15s-%(filename)s:%(lineno)s->%(message)s'))
log.addHandler(handler)
log.setLevel(logging.INFO)
parser = argparse.ArgumentParser(description="upnp command line utility")
parser.add_argument(dest="command", type=str, help="debug_gateway | list_mappings | get_external_ip | add_mapping | delete_mapping")
parser.add_argument("--debug_logging", dest="debug_logging", default=False, action="store_true")
parser.add_argument("--include_igd_xml", dest="include_igd_xml", default=False, action="store_true")
args = parser.parse_args()
if args.debug_logging:
# from twisted.python import log as tx_log
# observer = tx_log.PythonLoggingObserver(loggerName="txupnp")
# observer.start()
log.setLevel(logging.DEBUG)
command = args.command
command = command.replace("-", "_")
if command not in cli_commands:
print("unrecognized command: %s is not in %s" % (command, cli_commands.keys()))
return
def show(err):
print("error: {}".format(err))
u = UPnP(reactor, debug_ssdp=(command == "generate_test_data"))
d = u.discover()
d.addCallback(run_command, u, command, args.include_igd_xml)
d.addErrback(show)
d.addBoth(lambda _: reactor.callLater(0, reactor.stop))
reactor.run()
if __name__ == "__main__":
main()

View file

@ -1,217 +0,0 @@
import logging
from twisted.internet import defer
from txupnp.scpd import SCPDCommand, SCPDRequester
from txupnp.util import get_dict_val_case_insensitive, verify_return_types, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from txupnp.constants import SPEC_VERSION
from txupnp.commands import SCPDCommands
log = logging.getLogger(__name__)
class CaseInsensitive:
def __init__(self, **kwargs):
not_evaluated = {}
for k, v in kwargs.items():
if k.startswith("_"):
not_evaluated[k] = v
continue
try:
getattr(self, k)
setattr(self, k, v)
except AttributeError as err:
not_evaluated[k] = v
if not_evaluated:
log.debug("%s did not apply kwargs: %s", self.__class__.__name__, not_evaluated)
def _get_attr_name(self, case_insensitive: str) -> str:
for k, v in self.__dict__.items():
if k.lower() == case_insensitive.lower():
return k
def __getattr__(self, item):
if item in self.__dict__:
return self.__dict__[item]
for k, v in self.__class__.__dict__.items():
if k.lower() == item.lower():
if k not in self.__dict__:
self.__dict__[k] = v
return v
raise AttributeError(item)
def __setattr__(self, item, value):
if item in self.__dict__:
self.__dict__[item] = value
return
to_update = None
for k, v in self.__dict__.items():
if k.lower() == item.lower():
to_update = k
break
self.__dict__[to_update or item] = value
def as_dict(self) -> dict:
return {
k: v for k, v in self.__dict__.items() if not k.startswith("_") and not callable(v)
}
class Service(CaseInsensitive):
serviceType = None
serviceId = None
controlURL = None
eventSubURL = None
SCPDURL = None
class Device(CaseInsensitive):
serviceList = None
deviceList = None
deviceType = None
friendlyName = None
manufacturer = None
manufacturerURL = None
modelDescription = None
modelName = None
modelNumber = None
modelURL = None
serialNumber = None
udn = None
upc = None
presentationURL = None
iconList = None
def __init__(self, devices, services, **kwargs):
super(Device, self).__init__(**kwargs)
if self.serviceList and "service" in self.serviceList:
new_services = self.serviceList["service"]
if isinstance(new_services, dict):
new_services = [new_services]
services.extend([Service(**service) for service in new_services])
if self.deviceList:
for kw in self.deviceList.values():
if isinstance(kw, dict):
d = Device(devices, services, **kw)
devices.append(d)
elif isinstance(kw, list):
for _inner_kw in kw:
d = Device(devices, services, **_inner_kw)
devices.append(d)
else:
log.warning("failed to parse device:\n%s", kw)
class Gateway:
def __init__(self, reactor, **kwargs):
flattened = {
k.lower(): v for k, v in kwargs.items()
}
usn = flattened["usn"]
server = flattened["server"]
location = flattened["location"]
st = flattened["st"]
cache_control = flattened.get("cache_control") or flattened.get("cache-control") or ""
date = flattened.get("date", "")
ext = flattened.get("ext", "")
self.usn = usn.encode()
self.ext = ext.encode()
self.server = server.encode()
self.location = location.encode()
self.cache_control = cache_control.encode()
self.date = date.encode()
self.urn = st.encode()
self._xml_response = ""
self._service_descriptors = {}
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
self.spec_version = None
self.url_base = None
self._device = None
self._devices = []
self._services = []
self._reactor = reactor
self._unsupported_actions = {}
self._registered_commands = {}
self.commands = SCPDCommands()
self.requester = SCPDRequester(self._reactor)
def as_dict(self) -> dict:
r = {
'server': self.server.decode(),
'urlBase': self.url_base,
'location': self.location.decode(),
"specVersion": self.spec_version,
'usn': self.usn.decode(),
'urn': self.urn.decode(),
}
return r
@defer.inlineCallbacks
def discover_commands(self):
response = yield self.requester.scpd_get(self.location.decode().split(self.base_address.decode())[1], self.base_address.decode(), self.port)
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(response, "urlbase")
if not self.url_base:
self.url_base = self.base_address.decode()
if response:
self._device = Device(
self._devices, self._services, **get_dict_val_case_insensitive(response, "device")
)
else:
self._device = Device(self._devices, self._services)
for service_type in self.services.keys():
service = self.services[service_type]
yield self.register_commands(service)
@defer.inlineCallbacks
def register_commands(self, service: Service):
try:
action_list = yield self.requester.scpd_get_supported_actions(service, self.base_address.decode(), self.port)
except Exception as err:
log.exception("failed to register service %s: %s", service.serviceType, str(err))
return
for name, inputs, outputs in action_list:
try:
command = SCPDCommand(self.requester, self.base_address, self.port,
service.controlURL.encode(),
service.serviceType.encode(), name, inputs, outputs)
current = getattr(self.commands, command.method)
if hasattr(current, "_return_types"):
command._process_result = verify_return_types(*current._return_types)(command._process_result)
setattr(command, "__doc__", current.__doc__)
setattr(self.commands, command.method, command)
self._registered_commands[command.method] = service.serviceType
log.debug("registered %s::%s", service.serviceType, command.method)
except AttributeError:
s = self._unsupported_actions.get(service.serviceType, [])
s.append(name)
self._unsupported_actions[service.serviceType] = s
log.debug("available command for %s does not have a wrapper implemented: %s %s %s",
service.serviceType, name, inputs, outputs)
@property
def services(self) -> dict:
if not self._device:
return {}
return {service.serviceType: service for service in self._services}
@property
def devices(self) -> dict:
if not self._device:
return {}
return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> Service:
for service in self._services:
if service.serviceType.lower() == service_type.lower():
return service
def debug_commands(self):
return {
'available': self._registered_commands,
'failed': self._unsupported_actions
}

View file

@ -1,177 +0,0 @@
import os
import json
import logging
from twisted.internet import task, defer
from twisted.internet.error import ConnectionDone
from twisted.internet.protocol import DatagramProtocol
from twisted.python.failure import Failure
from twisted.test.proto_helpers import _FakePort
from txupnp.ssdp_datagram import SSDPDatagram
log = logging.getLogger()
class MockResponse:
def __init__(self, content):
self._content = content
self.headers = {}
def content(self):
return defer.succeed(self._content)
class MockDevice:
def __init__(self, manufacturer, model):
self.manufacturer = manufacturer
self.model = model
device_path = os.path.join(
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"), "{} {}".format(manufacturer, model)
)
assert os.path.isfile(device_path)
with open(device_path, "r") as f:
self.device_dict = json.loads(f.read())
def __repr__(self):
return "MockDevice(manufacturer={}, model={})".format(self.manufacturer, self.model)
def get_mock_devices():
return [
MockDevice(path.split(" ")[0], path.split(" ")[1])
for path in os.listdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), "devices"))
if ".py" not in path and "pycache" not in path
]
def get_device_test_case(manufacturer: str, model: str) -> MockDevice:
r = [
MockDevice(path.split(" ")[0], path.split(" ")[1])
for path in os.listdir(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "devices"))
if ".py" not in path and "pycache" not in path and path.split(" ") == [manufacturer, model]
]
return r[0]
class MockMulticastTransport:
def __init__(self, address, port, max_packet_size, network, protocol):
self.address = address
self.port = port
self.max_packet_size = max_packet_size
self._network = network
self._protocol = protocol
def write(self, data, address):
if address[0] in self._network.group:
destinations = self._network.group[address[0]]
else:
destinations = address[0]
for address, dest in self._network.peers.items():
if address[0] in destinations and dest.address != self.address:
dest._protocol.datagramReceived(data, (self.address, self.port))
def setTTL(self, ttl):
pass
def joinGroup(self, address, interface=None):
group = self._network.group.get(address, [])
group.append(interface)
self._network.group[address] = group
def leaveGroup(self, address, interface=None):
group = self._network.group.get(address, [])
if interface in group:
group.remove(interface)
self._network.group[address] = group
class MockTCPTransport(_FakePort):
def __init__(self, address, port, callback, mock_requests):
super().__init__((address, port))
self._callback = callback
self._mock_requests = mock_requests
def write(self, data):
if data.startswith(b"POST"):
for url, packets in self._mock_requests['POST'].items():
for request_response in packets:
if data.decode() == request_response['request']:
self._callback(request_response['response'].encode())
return
elif data.startswith(b"GET"):
for url, packets in self._mock_requests['GET'].items():
if data.decode() == packets['request']:
self._callback(packets['response'].encode())
return
class MockMulticastPort(_FakePort):
def __init__(self, protocol, remover, address, transport):
super().__init__((address, 1900))
self.protocol = protocol
self._remover = remover
self.transport = transport
def startListening(self, reason=None):
self.protocol.transport = self.transport
return self.protocol.startProtocol()
def stopListening(self, reason=None):
result = self.protocol.stopProtocol()
self._remover()
return result
class MockNetwork:
def __init__(self):
self.peers = {}
self.group = {}
def add_peer(self, port, protocol, interface, maxPacketSize):
transport = MockMulticastTransport(interface, port, maxPacketSize, self, protocol)
self.peers[(interface, port)] = transport
def remove_peer():
if self.peers.get((interface, port)):
del self.peers[(interface, port)]
return transport, remove_peer
class MockReactor(task.Clock):
def __init__(self, client_addr, mock_scpd_requests):
super().__init__()
self.client_addr = client_addr
self._mock_scpd_requests = mock_scpd_requests
self.network = MockNetwork()
def listenMulticast(self, port, protocol, interface=None, maxPacketSize=8192, listenMultiple=True):
interface = interface or self.client_addr
transport, remover = self.network.add_peer(port, protocol, interface, maxPacketSize)
port = MockMulticastPort(protocol, remover, interface, transport)
port.startListening()
return port
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
protocol = factory.buildProtocol(host)
def _write_and_close(data):
protocol.dataReceived(data)
protocol.connectionLost(Failure(ConnectionDone()))
protocol.transport = MockTCPTransport(host, port, _write_and_close, self._mock_scpd_requests)
protocol.connectionMade()
class MockSSDPServiceGatewayProtocol(DatagramProtocol):
def __init__(self, client_addr: int, iface: str, packets_rx: list, packets_tx: list):
self.client_addr = client_addr
self.iface = iface
self.packets_tx = [SSDPDatagram.decode(packet.encode()) for packet in packets_tx] # sent by client
self.packets_rx = [((addr, port), SSDPDatagram.decode(packet.encode())) for (addr, port), packet in packets_rx] # rx by client
def datagramReceived(self, datagram, address):
packet = SSDPDatagram.decode(datagram)
if packet.st in map(lambda p: p[1].st, self.packets_rx): # this contains one of the service types the server replied to
reply = list(filter(lambda p: p[1].st == packet.st, self.packets_rx))[0][1]
self.transport.write(reply.encode().encode(), (self.client_addr, 1900))
else:
pass

View file

@ -1,264 +0,0 @@
import re
import logging
from xml.etree import ElementTree
from twisted.internet.protocol import Protocol, ClientFactory
from twisted.internet import defer, error
from txupnp.constants import XML_VERSION, DEVICE, ROOT, SERVICE, ENVELOPE, BODY
from txupnp.util import etree_to_dict, flatten_keys
from txupnp.fault import handle_fault, UPnPError
log = logging.getLogger(__name__)
CONTENT_PATTERN = re.compile(
"(\<\?xml version=\"1\.0\"\?\>(\s*.)*|\>)".encode()
)
CONTENT_NO_XML_VERSION_PATTERN = re.compile(
"(\<s\:Envelope xmlns\:s=\"http\:\/\/schemas\.xmlsoap\.org\/soap\/envelope\/\"(\s*.)*\>)".encode()
)
XML_ROOT_SANITY_PATTERN = re.compile(
"(?i)(\{|(urn:schemas-[\w|\d]*-(com|org|net))[:|-](device|service)[:|-]([\w|\d|\:|\-|\_]*)|\}([\w|\d|\:|\-|\_]*))"
)
def parse_service_description(content: bytes):
if not content:
return []
element_dict = etree_to_dict(ElementTree.fromstring(content.decode()))
service_info = flatten_keys(element_dict, "{%s}" % SERVICE)
if "scpd" not in service_info:
return []
action_list = service_info["scpd"]["actionList"]
if not len(action_list): # it could be an empty string
return []
result = []
if isinstance(action_list["action"], dict):
arg_dicts = action_list["action"]['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg, ew
arg_dicts = [arg_dicts]
return [[
action_list["action"]['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
]]
for action in action_list["action"]:
if not action.get('argumentList'):
result.append((action['name'], [], []))
else:
arg_dicts = action['argumentList']['argument']
if not isinstance(arg_dicts, list): # when there is one arg, ew
arg_dicts = [arg_dicts]
result.append((
action['name'],
[i['name'] for i in arg_dicts if i['direction'] == 'in'],
[i['name'] for i in arg_dicts if i['direction'] == 'out']
))
return result
class SCPDHTTPClientProtocol(Protocol):
def connectionMade(self):
self.response_buff = b""
log.debug("Sending HTTP:\n%s", self.factory.packet.decode())
self.factory.reactor.callLater(0, self.transport.write, self.factory.packet)
def dataReceived(self, data):
self.response_buff += data
def connectionLost(self, reason):
if reason.trap(error.ConnectionDone):
log.debug("Received HTTP:\n%s", self.response_buff.decode())
if XML_VERSION.encode() in self.response_buff:
parsed = CONTENT_PATTERN.findall(self.response_buff)
result = b'' if not parsed else parsed[0][0]
self.factory.finished_deferred.callback(result)
else:
parsed = CONTENT_NO_XML_VERSION_PATTERN.findall(self.response_buff)
result = b'' if not parsed else XML_VERSION.encode() + b'\r\n' + parsed[0][0]
self.factory.finished_deferred.callback(result)
class SCPDHTTPClientFactory(ClientFactory):
protocol = SCPDHTTPClientProtocol
def __init__(self, reactor, packet):
self.reactor = reactor
self.finished_deferred = defer.Deferred()
self.packet = packet
def buildProtocol(self, addr):
p = self.protocol()
p.factory = self
return p
@classmethod
def post(cls, reactor, command, **kwargs):
args = "".join("<%s>%s</%s>" % (n, kwargs.get(n), n) for n in command.param_names)
soap_body = ('\r\n%s\r\n<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" '
's:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>'
'<u:%s xmlns:u="%s">%s</u:%s></s:Body></s:Envelope>' % (
XML_VERSION, command.method, command.service_id.decode(),
args, command.method))
if "http://" in command.gateway_address.decode():
host = command.gateway_address.decode().split("http://")[1]
else:
host = command.gateway_address.decode()
data = (
(
'POST %s HTTP/1.1\r\n'
'Host: %s\r\n'
'User-Agent: python3/txupnp, UPnP/1.0, MiniUPnPc/1.9\r\n'
'Content-Length: %i\r\n'
'Content-Type: text/xml\r\n'
'SOAPAction: \"%s#%s\"\r\n'
'Connection: Close\r\n'
'Cache-Control: no-cache\r\n'
'Pragma: no-cache\r\n'
'%s'
'\r\n'
) % (
command.control_url.decode(), # could be just / even if it shouldn't be
host,
len(soap_body),
command.service_id.decode(), # maybe no quotes
command.method,
soap_body
)
).encode()
return cls(reactor, data)
@classmethod
def get(cls, reactor, control_url: str, address: str):
if "http://" in address:
host = address.split("http://")[1]
else:
host = address
if ":" in host:
host = host.split(":")[0]
if not control_url.startswith("/"):
control_url = "/" + control_url
data = (
(
'GET %s HTTP/1.1\r\n'
'Accept-Encoding: gzip\r\n'
'Host: %s\r\n'
'\r\n'
) % (control_url, host)
).encode()
return cls(reactor, data)
class SCPDRequester:
client_factory = SCPDHTTPClientFactory
def __init__(self, reactor):
self._reactor = reactor
self._get_requests = {}
self._post_requests = {}
def _save_get(self, request: bytes, response: bytes, destination: str) -> None:
self._get_requests[destination.lstrip("/")] = {
'request': request,
'response': response
}
def _save_post(self, request: bytes, response: bytes, destination: str) -> None:
p = self._post_requests.get(destination.lstrip("/"), [])
p.append({
'request': request,
'response': response,
})
self._post_requests[destination.lstrip("/")] = p
@defer.inlineCallbacks
def _scpd_get_soap_xml(self, control_url: str, address: str, service_port: int) -> bytes:
factory = self.client_factory.get(self._reactor, control_url, address)
url = address.split("http://")[1].split(":")[0]
self._reactor.connectTCP(url, service_port, factory)
xml_response_bytes = yield factory.finished_deferred
self._save_get(factory.packet, xml_response_bytes, control_url)
return xml_response_bytes
@defer.inlineCallbacks
def scpd_post_soap(self, command, **kwargs) -> tuple:
factory = self.client_factory.post(self._reactor, command, **kwargs)
url = command.gateway_address.split(b"http://")[1].split(b":")[0]
self._reactor.connectTCP(url.decode(), command.service_port, factory)
xml_response_bytes = yield factory.finished_deferred
self._save_post(
factory.packet, xml_response_bytes, command.gateway_address.decode() + command.control_url.decode()
)
content_dict = etree_to_dict(ElementTree.fromstring(xml_response_bytes.decode()))
envelope = content_dict[ENVELOPE]
response_body = flatten_keys(envelope[BODY], "{%s}" % command.service_id)
body = handle_fault(response_body) # raises UPnPError if there is a fault
response_key = None
for key in body:
if command.method in key:
response_key = key
break
if not response_key:
raise UPnPError("unknown response fields for %s")
response = body[response_key]
extracted_response = tuple([response[n] for n in command.returns])
return extracted_response
@defer.inlineCallbacks
def scpd_get_supported_actions(self, service, address: str, port: int) -> list:
xml_bytes = yield self._scpd_get_soap_xml(service.SCPDURL, address, port)
return parse_service_description(xml_bytes)
@defer.inlineCallbacks
def scpd_get(self, control_url: str, service_address: str, service_port: int) -> dict:
xml_bytes = yield self._scpd_get_soap_xml(control_url, service_address, service_port)
xml_dict = etree_to_dict(ElementTree.fromstring(xml_bytes.decode()))
schema_key = DEVICE
root = ROOT
for k in xml_dict.keys():
m = XML_ROOT_SANITY_PATTERN.findall(k)
if len(m) == 3 and m[1][0] and m[2][5]:
schema_key = m[1][0]
root = m[2][5]
break
flattened_xml = flatten_keys(xml_dict, "{%s}" % schema_key)[root]
return flattened_xml
def dump_packets(self) -> dict:
return {
'GET': self._get_requests,
'POST': self._post_requests
}
class SCPDCommand:
def __init__(self, scpd_requester: SCPDRequester, gateway_address, service_port, control_url, service_id, method,
param_names,
returns):
self.scpd_requester = scpd_requester
self.gateway_address = gateway_address
self.service_port = service_port
self.control_url = control_url
self.service_id = service_id
self.method = method
self.param_names = param_names
self.returns = returns
@staticmethod
def _process_result(*results):
"""
this method gets decorated automatically with a function that maps result types to the types
defined in the @return_types decorator
"""
return results
@defer.inlineCallbacks
def __call__(self, **kwargs):
if set(kwargs.keys()) != set(self.param_names):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_names))
response = yield self.scpd_requester.scpd_post_soap(self, **kwargs)
try:
result = self._process_result(*response)
except Exception as err:
log.error("error formatting response (%s):\n%s", err, response)
raise err
defer.returnValue(result)

View file

@ -1,168 +0,0 @@
import logging
import binascii
from twisted.internet import defer
from twisted.internet.protocol import DatagramProtocol
from txupnp.constants import UPNP_ORG_IGD, SSDP_DISCOVER, SSDP_IP_ADDRESS, SSDP_PORT, service_types
from txupnp.constants import SSDP_HOST
from txupnp.fault import UPnPError
from txupnp.ssdp_datagram import SSDPDatagram
log = logging.getLogger(__name__)
class SSDPProtocol(DatagramProtocol):
def __init__(self, reactor, iface, router, ssdp_address=SSDP_IP_ADDRESS,
ssdp_port=SSDP_PORT, ttl=1, max_devices=None, debug_packets=False,
debug_sent=None, debug_received=None):
self._reactor = reactor
self._sem = defer.DeferredSemaphore(1)
self.discover_callbacks = {}
self.iface = iface
self.router = router
self.ssdp_address = ssdp_address
self.ssdp_port = ssdp_port
self.ttl = ttl
self._start = None
self.max_devices = max_devices
self.devices = []
self.debug_packets = debug_packets
self.debug_sent = debug_sent if debug_sent is not None else []
self.debug_received = debug_received if debug_sent is not None else []
def _send_m_search(self, service=UPNP_ORG_IGD):
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, host=SSDP_HOST, st=service, man=SSDP_DISCOVER, mx=1)
log.debug("sending packet to %s:\n%s", SSDP_HOST, packet.encode())
try:
msg_bytes = packet.encode().encode()
if self.debug_packets:
self.debug_sent.append(msg_bytes)
self.transport.write(msg_bytes, (self.ssdp_address, self.ssdp_port))
except Exception as err:
log.exception("failed to write %s to %s:%i", packet.encode(), self.ssdp_address, self.ssdp_port)
raise err
@staticmethod
def _gather(finished_deferred, max_results, results: list):
def discover_cb(packet):
if not finished_deferred.called and packet.st in service_types:
results.append(packet.as_dict())
if len(results) >= max_results:
finished_deferred.callback(results)
return discover_cb
def m_search(self, address, timeout, max_devices):
# return deferred for a pending call if we have one
if address in self.discover_callbacks:
d = self.discover_callbacks[address][1]
if not d.called: # the existing deferred has already fired, make a new one
return d
def _trap_timeout_and_return_results(err):
if err.check(defer.TimeoutError):
return self.devices
raise err
d = defer.Deferred()
d.addTimeout(timeout, self._reactor)
d.addErrback(_trap_timeout_and_return_results)
found_cb = self._gather(d, max_devices, self.devices)
self.discover_callbacks[address] = found_cb, d
for st in service_types:
self._send_m_search(service=st)
return d
def startProtocol(self):
self._start = self._reactor.seconds()
self.transport.setTTL(self.ttl)
self.transport.joinGroup(self.ssdp_address, interface=self.iface)
def datagramReceived(self, datagram, address):
if address[0] == self.iface:
return
if self.debug_packets:
self.debug_received.append((address, datagram))
try:
packet = SSDPDatagram.decode(datagram)
log.debug("decoded %s from %s:%i:\n%s", packet.get_friendly_name(), address[0], address[1], packet.encode())
except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i: %s\npacket: %s", address[0], address[1], err,
binascii.hexlify(datagram))
return
except Exception:
log.exception("failed to decode: %s", binascii.hexlify(datagram))
return
if packet._packet_type == packet._OK:
log.debug("%s:%i replied to our m-search with new xml url: %s", address[0], address[1], packet.location)
# if address[0] in self.discover_callbacks and packet.location not in map(lambda p: p['location'], self.devices):
if packet.location not in map(lambda p: p['location'], self.devices):
if address[0] not in self.discover_callbacks:
self.devices.append(packet.as_dict())
else:
self._sem.run(self.discover_callbacks[address[0]][0], packet)
else:
log.info("ignored packet from %s:%s (%s) %s", address[0], address[1], packet._packet_type, packet.location)
elif packet._packet_type == packet._NOTIFY:
log.debug("%s:%i sent us a notification (type: %s), url: %s", address[0], address[1], packet.nts,
packet.location)
class SSDPFactory:
def __init__(self, reactor, lan_address, router_address, debug_packets=False):
self.lan_address = lan_address
self.router_address = router_address
self._reactor = reactor
self.protocol = None
self.port = None
self.debug_packets = debug_packets
self.debug_sent = []
self.debug_received = []
self.server_infos = []
def disconnect(self):
if not self.port:
return
self.protocol.transport.leaveGroup(SSDP_IP_ADDRESS, interface=self.lan_address)
self.port.stopListening()
self.port = None
self.protocol = None
def connect(self):
if not self.protocol:
self.protocol = SSDPProtocol(self._reactor, self.lan_address, self.router_address,
debug_packets=self.debug_packets, debug_sent=self.debug_sent,
debug_received=self.debug_received)
if not self.port:
self._reactor.addSystemEventTrigger("before", "shutdown", self.disconnect)
self.port = self._reactor.listenMulticast(self.protocol.ssdp_port, self.protocol, listenMultiple=True)
@defer.inlineCallbacks
def m_search(self, address, timeout, max_devices):
"""
Perform a M-SEARCH (HTTP over UDP) and gather the results
:param address: (str) address to listen for responses from
:param timeout: (int) timeout for the query
:param max_devices: (int) block until timeout or at least this many devices are found
:param service_types: (list) M-SEARCH "ST" arguments to try, if None use the defaults
:return: (list) [ (dict) {
'server: (str) gateway os and version
'location': (str) upnp gateway url,
'cache-control': (str) max age,
'date': (int) server time,
'usn': (str) usn
}, ...]
"""
self.connect()
server_infos = yield self.protocol.m_search(address, timeout, max_devices)
for server_info in server_infos:
self.server_infos.append(server_info)
defer.returnValue(server_infos)
def get_ssdp_packet_replay(self) -> dict:
return {
'lan_address': self.lan_address,
'router_address': self.router_address,
'sent': self.debug_sent,
'received': self.debug_received,
}

View file

@ -1,146 +0,0 @@
import netifaces
import logging
from twisted.internet import defer
from txupnp.fault import UPnPError
from txupnp.ssdp import SSDPFactory
from txupnp.gateway import Gateway
log = logging.getLogger(__name__)
class UPnP:
def __init__(self, reactor, try_miniupnpc_fallback=False, debug_ssdp=False, router_ip=None,
lan_ip=None, iface_name=None):
self._reactor = reactor
if router_ip and lan_ip and iface_name:
self.router_ip, self.lan_address, self.iface_name = router_ip, lan_ip, iface_name
else:
self.router_ip, self.iface_name = netifaces.gateways()['default'][netifaces.AF_INET]
self.lan_address = netifaces.ifaddresses(self.iface_name)[netifaces.AF_INET][0]['addr']
self.sspd_factory = SSDPFactory(self._reactor, self.lan_address, self.router_ip, debug_packets=debug_ssdp)
self.try_miniupnpc_fallback = try_miniupnpc_fallback
self.miniupnpc_runner = None
self.miniupnpc_igd_url = None
self.gateway = None
def m_search(self, address, timeout=1, max_devices=1):
"""
Perform a HTTP over UDP M-SEARCH query
returns (list) [{
'server: <gateway os and version string>
'location': <upnp gateway url>,
'cache-control': <max age>,
'date': <server time>,
'usn': <usn>
}, ...]
"""
return self.sspd_factory.m_search(address, timeout=timeout, max_devices=max_devices)
@defer.inlineCallbacks
def _discover(self, timeout=1, max_devices=1):
server_infos = yield self.sspd_factory.m_search(
self.router_ip, timeout=timeout, max_devices=max_devices
)
if not server_infos:
return False
server_info = server_infos[0]
if 'st' in server_info:
gateway = Gateway(reactor=self._reactor, **server_info)
yield gateway.discover_commands()
self.gateway = gateway
return True
elif 'st' not in server_info:
log.error("don't know how to handle gateway: %s", server_info)
return False
@defer.inlineCallbacks
def discover(self, timeout=1, max_devices=1):
try:
found = yield self._discover(timeout=timeout, max_devices=max_devices)
except defer.TimeoutError:
found = False
finally:
self.sspd_factory.disconnect()
if found:
log.debug("found upnp device")
else:
log.debug("failed to find upnp device")
return found
def get_external_ip(self) -> str:
return self.gateway.commands.GetExternalIPAddress()
def add_port_mapping(self, external_port: int, protocol: str, internal_port, lan_address: str,
description: str) -> None:
return self.gateway.commands.AddPortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration=""
)
@defer.inlineCallbacks
def get_port_mapping_by_index(self, index: int) -> (str, int, str, int, str, bool, str, int):
try:
redirect = yield self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
defer.returnValue(redirect)
except UPnPError:
defer.returnValue(None)
@defer.inlineCallbacks
def get_redirects(self):
redirects = []
cnt = 0
redirect = yield self.get_port_mapping_by_index(cnt)
while redirect:
redirects.append(redirect)
cnt += 1
redirect = yield self.get_port_mapping_by_index(cnt)
defer.returnValue(redirects)
@defer.inlineCallbacks
def get_specific_port_mapping(self, external_port: int, protocol: str) -> (int, str, bool, str, int):
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
"""
try:
result = yield self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol
)
defer.returnValue(result)
except UPnPError:
defer.returnValue(None)
def delete_port_mapping(self, external_port: int, protocol: str) -> None:
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: None
"""
return self.gateway.commands.DeletePortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
)
@defer.inlineCallbacks
def get_next_mapping(self, port, protocol, description, internal_port=None):
if protocol not in ["UDP", "TCP"]:
raise UPnPError("unsupported protocol: {}".format(protocol))
internal_port = internal_port or port
redirect_tups = yield self.get_redirects()
redirects = {
"%i:%s" % (ext_port, proto): (int_host, int_port, desc)
for (ext_host, ext_port, proto, int_port, int_host, enabled, desc, lease) in redirect_tups
}
while ("%i:%s" % (port, protocol)) in redirects:
int_host, int_port, _ = redirects["%i:%s" % (port, protocol)]
if int_host == self.lan_address and int_port == internal_port:
break
port += 1
yield self.add_port_mapping( # set one up
port, protocol, internal_port, self.lan_address, description
)
defer.returnValue(port)