more mypy refactoring, improve coverage #13

Merged
jackrobison merged 1 commit from improve-coverage into master 2019-05-24 03:35:37 +02:00
11 changed files with 637 additions and 244 deletions

View file

@ -4,8 +4,7 @@ import logging
import textwrap
import typing
from collections import OrderedDict
from aioupnp.upnp import run_cli, UPnP
from aioupnp.commands import SOAPCommands
from aioupnp.upnp import run_cli, UPnP, cli_commands
log = logging.getLogger("aioupnp")
handler = logging.StreamHandler()
@ -20,36 +19,53 @@ base_usage = "\n".join(textwrap.wrap(
def get_help(command: str) -> str:
annotations = UPnP.get_annotations(command)
params = command + " " + " ".join(["[--%s=<%s>]" % (k, str(v)) for k, v in annotations.items() if k != 'return'])
return base_usage + "\n".join(
textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False)
)
annotations, doc = UPnP.get_annotations(command)
doc = doc or ""
arg_strs = []
for k, v in annotations.items():
if k not in ['return', 'igd_args', 'loop']:
t = str(v) if not hasattr(v, "__name__") else v.__name__
if t == 'bool':
arg_strs.append(f"[--{k}]")
else:
arg_strs.append(f"[--{k}=<{t}>]")
elif k == 'igd_args':
arg_strs.append(f"[--<header key>=<header value>, ...]")
params = " ".join(arg_strs)
usage = "\n".join(textwrap.wrap(
f"aioupnp [-h] [--debug_logging] {command} {params}",
100, subsequent_indent=' ', break_long_words=False)) + "\n"
return usage + textwrap.dedent(doc)
def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int:
argv = argv or list(sys.argv)
commands = list(SOAPCommands.SOAP_COMMANDS)
help_str = "\n".join(textwrap.wrap(
" | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
" | ".join(cli_commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
))
usage = \
"\n%s\n" \
"%s\n" \
"If m-search headers are provided as keyword arguments all of the headers to be used must be provided,\n" \
"in the order they are to be used. For example:\n" \
" aioupnp --HOST=239.255.255.250:1900 --MAN=\"ssdp:discover\" --MX=1 --ST=upnp:rootdevice m_search\n\n" \
"Commands:\n" \
"%s\n\n" \
"For help with a specific command:" \
" aioupnp help <command>\n" % (base_usage, help_str)
" aioupnp help <command>" % (base_usage, help_str)
args: typing.List[str] = [str(arg) for arg in argv[1:]]
if not args:
print(usage)
return 0
if args[0] in ['help', '-h', '--help']:
if len(args) > 1:
if args[1] in commands:
print(get_help(args[1]))
if args[1].replace("-", "_") in cli_commands:
print(get_help(args[1].replace("-", "_")))
return 0
print(usage)
return 0

View file

@ -2,43 +2,64 @@ import asyncio
import time
import typing
import logging
from typing import Tuple
from aioupnp.protocols.scpd import scpd_post
from aioupnp.device import Service
log = logging.getLogger(__name__)
def soap_optional_str(x: typing.Optional[str]) -> typing.Optional[str]:
return x if x is not None and str(x).lower() not in ['none', 'nil'] else None
def soap_optional_str(x: typing.Optional[typing.Union[str, int]]) -> typing.Optional[str]:
return str(x) if x is not None and str(x).lower() not in ['none', 'nil'] else None
def soap_bool(x: typing.Optional[str]) -> bool:
def soap_bool(x: typing.Optional[typing.Union[str, int]]) -> bool:
return False if not x or str(x).lower() in ['false', 'False'] else True
def recast_single_result(t: type, result: typing.Any) -> typing.Optional[typing.Union[str, int, float, bool]]:
if t is bool:
return soap_bool(result)
if t is str:
return soap_optional_str(result)
return t(result)
class GetSpecificPortMappingEntryResponse(typing.NamedTuple):
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
class GetGenericPortMappingEntryResponse(typing.NamedTuple):
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
def recast_return(return_annotation, result: typing.Dict[str, typing.Union[int, str]],
result_keys: typing.List[str]) -> typing.Tuple:
if return_annotation is None or len(result_keys) == 0:
return ()
result_keys: typing.List[str]) -> typing.Optional[
typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]:
if len(result_keys) == 1:
assert len(result_keys) == 1
single_result = result[result_keys[0]]
return (recast_single_result(return_annotation, single_result), )
annotated_args: typing.List[type] = list(return_annotation.__args__)
assert len(annotated_args) == len(result_keys)
recast_results: typing.List[typing.Optional[typing.Union[str, int, float, bool]]] = []
for type_annotation, result_key in zip(annotated_args, result_keys):
recast_results.append(recast_single_result(type_annotation, result.get(result_key, None)))
return tuple(recast_results)
if return_annotation is bool:
return soap_bool(single_result)
if return_annotation is str:
return soap_optional_str(single_result)
return int(result[result_keys[0]]) if result_keys[0] in result else None
elif return_annotation in [GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse]:
arg_types: typing.Dict[str, typing.Type[typing.Any]] = return_annotation._field_types
assert len(arg_types) == len(result_keys)
recast_results: typing.Dict[str, typing.Optional[typing.Union[str, int, bool]]] = {}
for i, (field_name, result_key) in enumerate(zip(arg_types, result_keys)):
result_field_name = result_keys[i]
field_type = arg_types[field_name]
if field_type is bool:
recast_results[field_name] = soap_bool(result.get(result_field_name, None))
elif field_type is str:
recast_results[field_name] = soap_optional_str(result.get(result_field_name, None))
elif field_type is int:
recast_results[field_name] = int(result[result_field_name]) if result_field_name in result else None
return return_annotation(**recast_results)
return None
class SOAPCommands:
@ -88,7 +109,10 @@ class SOAPCommands:
self._base_address = base_address
self._port = port
self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
typing.Tuple, typing.Optional[Exception], float]] = []
typing.Optional[typing.Union[str, int, bool,
GetSpecificPortMappingEntryResponse,
GetGenericPortMappingEntryResponse]],
typing.Optional[Exception], float]] = []
def is_registered(self, name: str) -> bool:
if name not in self.SOAP_COMMANDS:
@ -112,7 +136,8 @@ class SOAPCommands:
input_names: typing.List[str] = self._registered[service][name][0]
output_names: typing.List[str] = self._registered[service][name][1]
async def wrapper(**kwargs: typing.Any) -> typing.Tuple:
async def wrapper(**kwargs: typing.Any) -> typing.Optional[
typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]:
assert service.controlURL is not None
assert service.serviceType is not None
@ -122,11 +147,10 @@ class SOAPCommands:
)
if err is not None:
assert isinstance(xml_bytes, bytes)
self._requests.append((name, kwargs, xml_bytes, (), err, time.time()))
self._requests.append((name, kwargs, xml_bytes, None, err, time.time()))
raise err
assert 'return' in annotations
result = recast_return(annotations['return'], response, output_names)
self._requests.append((name, kwargs, xml_bytes, result, None, time.time()))
return result
@ -161,8 +185,7 @@ class SOAPCommands:
)
return None
async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
bool, str, int]:
async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> GetGenericPortMappingEntryResponse:
"""
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration)
@ -171,19 +194,19 @@ class SOAPCommands:
if not self.is_registered(name):
raise NotImplementedError()
assert name in self._wrappers_kwargs
result: Tuple[str, int, str, int, str, bool, str, int] = await self._wrappers_kwargs[name](
result: GetGenericPortMappingEntryResponse = await self._wrappers_kwargs[name](
NewPortMappingIndex=NewPortMappingIndex
)
return result
async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int,
NewProtocol: str) -> Tuple[int, str, bool, str, int]:
NewProtocol: str) -> GetSpecificPortMappingEntryResponse:
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
name = "GetSpecificPortMappingEntry"
if not self.is_registered(name):
raise NotImplementedError()
assert name in self._wrappers_kwargs
result: Tuple[int, str, bool, str, int] = await self._wrappers_kwargs[name](
result: GetSpecificPortMappingEntryResponse = await self._wrappers_kwargs[name](
NewRemoteHost=NewRemoteHost, NewExternalPort=NewExternalPort, NewProtocol=NewProtocol
)
return result
@ -205,8 +228,8 @@ class SOAPCommands:
if not self.is_registered(name):
raise NotImplementedError()
assert name in self._wrappers_no_args
result: Tuple[str] = await self._wrappers_no_args[name]()
return result[0]
result: str = await self._wrappers_no_args[name]()
return result
# async def GetNATRSIPStatus(self) -> Tuple[bool, bool]:
# """Returns (NewRSIPAvailable, NewNATEnabled)"""

View file

@ -3,7 +3,7 @@ import logging
import typing
import asyncio
from collections import OrderedDict
from typing import Dict, List, Union
from typing import Dict, List
from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands
@ -103,74 +103,81 @@ class Gateway:
self._registered_commands: Dict[str, str] = {}
self.commands = SOAPCommands(self._loop, self.base_ip, self.port)
def gateway_descriptor(self) -> dict:
r = {
'server': self.server.decode(),
'urlBase': self.url_base,
'location': self.location.decode(),
"specVersion": self.spec_version,
'usn': self.usn.decode(),
'urn': self.urn.decode(),
}
return r
# def gateway_descriptor(self) -> dict:
# r = {
# 'server': self.server.decode(),
# 'urlBase': self.url_base,
# 'location': self.location.decode(),
# "specVersion": self.spec_version,
# 'usn': self.usn.decode(),
# 'urn': self.urn.decode(),
# }
# return r
@property
def manufacturer_string(self) -> str:
if not self.devices:
return "UNKNOWN GATEWAY"
manufacturer_string = "UNKNOWN GATEWAY"
if self.devices:
devices: typing.List[Device] = list(self.devices.values())
device = devices[0]
return f"{device.manufacturer} {device.modelName}"
manufacturer_string = f"{device.manufacturer} {device.modelName}"
return manufacturer_string
@property
def services(self) -> Dict[str, Service]:
if not self._device:
return {}
return {str(service.serviceType): service for service in self._services}
@property
def devices(self) -> Dict:
if not self._device:
return {}
return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> typing.Optional[Service]:
services: Dict[str, Service] = {}
if self._services:
for service in self._services:
if service.serviceType and service.serviceType.lower() == service_type.lower():
return service
return None
if service.serviceType is not None:
services[service.serviceType] = service
return services
@property
def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
typing.Optional[typing.Tuple],
typing.Optional[Exception], float]]:
soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
typing.Optional[typing.Tuple],
typing.Optional[Exception], float]] = []
soap_call_infos.extend([
(name, request_args, raw_response, decoded_response, soap_error, ts)
for (
name, request_args, raw_response, decoded_response, soap_error, ts
) in self.commands._requests
])
soap_call_infos.sort(key=lambda x: x[5])
return soap_call_infos
def devices(self) -> Dict[str, Device]:
devices: Dict[str, Device] = {}
if self._device:
for device in self._devices:
if device.udn is not None:
devices[device.udn] = device
return devices
def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]:
return {
'manufacturer_string': self.manufacturer_string,
'gateway_address': self.base_ip,
'gateway_descriptor': self.gateway_descriptor(),
'gateway_xml': self._xml_response,
'services_xml': self._service_descriptors,
'services': {service.SCPDURL: service.as_dict() for service in self._services},
'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()],
'reply': self._ok_packet.as_dict(),
'soap_port': self.port,
'registered_soap_commands': self._registered_commands,
'unsupported_soap_commands': self._unsupported_actions,
'soap_requests': self.soap_requests
}
# def get_service(self, service_type: str) -> typing.Optional[Service]:
# for service in self._services:
# if service.serviceType and service.serviceType.lower() == service_type.lower():
# return service
# return None
# @property
# def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
# typing.Optional[typing.Tuple],
# typing.Optional[Exception], float]]:
# soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
# typing.Optional[typing.Tuple],
# typing.Optional[Exception], float]] = []
# soap_call_infos.extend([
# (name, request_args, raw_response, decoded_response, soap_error, ts)
# for (
# name, request_args, raw_response, decoded_response, soap_error, ts
# ) in self.commands._requests
# ])
# soap_call_infos.sort(key=lambda x: x[5])
# return soap_call_infos
# def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]:
# return {
# 'manufacturer_string': self.manufacturer_string,
# 'gateway_address': self.base_ip,
# 'gateway_descriptor': self.gateway_descriptor(),
# 'gateway_xml': self._xml_response,
# 'services_xml': self._service_descriptors,
# 'services': {service.SCPDURL: service.as_dict() for service in self._services},
# 'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()],
# 'reply': self._ok_packet.as_dict(),
# 'soap_port': self.port,
# 'registered_soap_commands': self._registered_commands,
# 'unsupported_soap_commands': self._unsupported_actions,
# 'soap_requests': self.soap_requests
# }
@classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
@ -206,7 +213,7 @@ class Gateway:
ignored.add(datagram.location)
continue
else:
log.debug('found gateway device %s', datagram.location)
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
return gateway
except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None

View file

@ -2,11 +2,11 @@ import struct
import socket
import typing
from asyncio.protocols import DatagramProtocol
from asyncio.transports import BaseTransport
from asyncio.transports import DatagramTransport
from unittest import mock
def _get_sock(transport: typing.Optional[BaseTransport]) -> typing.Optional[socket.socket]:
def _get_sock(transport: typing.Optional[DatagramTransport]) -> typing.Optional[socket.socket]:
if transport is None or not hasattr(transport, "_extra"):
return None
sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None)
@ -18,7 +18,7 @@ class MulticastProtocol(DatagramProtocol):
def __init__(self, multicast_address: str, bind_address: str) -> None:
self.multicast_address = multicast_address
self.bind_address = bind_address
self.transport: typing.Optional[BaseTransport] = None
self.transport: typing.Optional[DatagramTransport] = None
@property
def sock(self) -> typing.Optional[socket.socket]:
@ -26,14 +26,13 @@ class MulticastProtocol(DatagramProtocol):
def get_ttl(self) -> int:
sock = self.sock
if not sock:
raise ValueError("not connected")
if sock:
return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
return 0
def set_ttl(self, ttl: int = 1) -> None:
sock = self.sock
if not sock:
return None
if sock:
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
)
@ -41,8 +40,7 @@ class MulticastProtocol(DatagramProtocol):
def join_group(self, multicast_address: str, bind_address: str) -> None:
sock = self.sock
if not sock:
return None
if sock:
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
@ -51,15 +49,14 @@ class MulticastProtocol(DatagramProtocol):
def leave_group(self, multicast_address: str, bind_address: str) -> None:
sock = self.sock
if not sock:
raise ValueError("not connected")
if sock:
sock.setsockopt(
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
)
return None
def connection_made(self, transport: BaseTransport) -> None:
def connection_made(self, transport: DatagramTransport) -> None: # type: ignore
self.transport = transport
return None

View file

@ -28,19 +28,14 @@ class SSDPProtocol(MulticastProtocol):
self.notifications: typing.List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop)
def connection_made(self, transport) -> None:
# assert isinstance(transport, asyncio.DatagramTransport), str(type(transport))
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
super().connection_made(transport)
self.connected.set()
return None
def disconnect(self) -> None:
if self.transport:
try:
self.leave_group(self.multicast_address, self.bind_address)
except ValueError:
pass
except Exception:
log.exception("unexpected error leaving multicast group")
self.transport.close()
self.connected.clear()
while self._pending_searches:
@ -50,13 +45,12 @@ class SSDPProtocol(MulticastProtocol):
return None
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location in self._ignored:
return None
if packet.location not in self._ignored:
# TODO: fix this
tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = []
while len(self._pending_searches):
t: typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle] = self._pending_searches.pop()
t = self._pending_searches.pop()
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
f = t[2]
if f not in set_futures:
@ -69,10 +63,12 @@ class SSDPProtocol(MulticastProtocol):
self._pending_searches.append(tmp.pop())
return None
def _send_m_search(self, address: str, packet: SSDPDatagram) -> None:
def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None:
dest = address if self._unicast else SSDP_IP_ADDRESS
if not self.transport:
raise UPnPError("SSDP transport not connected")
if not fut.done():
fut.set_exception(UPnPError("SSDP transport not connected"))
return None
log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
return None
@ -84,7 +80,7 @@ class SSDPProtocol(MulticastProtocol):
packet = SSDPDatagram("M-SEARCH", datagram)
assert packet.st is not None
self._pending_searches.append(
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet))
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet, fut))
)
return await asyncio.wait_for(fut, timeout)
@ -95,7 +91,7 @@ class SSDPProtocol(MulticastProtocol):
packet = SSDPDatagram.decode(data)
log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet)
except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
binascii.hexlify(data))
return None
@ -131,19 +127,13 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optio
listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
)
transport = listen_result[0]
protocol = listen_result[1]
assert isinstance(protocol, SSDPProtocol)
except Exception as err:
print(err)
raise UPnPError(err)
try:
else:
protocol.join_group(protocol.multicast_address, protocol.bind_address)
protocol.set_ttl(1)
except Exception as err:
protocol.disconnect()
raise UPnPError(err)
return protocol, gateway_address, lan_address
@ -178,10 +168,11 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try:
await protocol.m_search(gateway_address, batch_timeout, args)
protocol.disconnect()
return args
except asyncio.TimeoutError:
continue
else:
protocol.disconnect()
return args
protocol.disconnect()
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))

View file

@ -5,13 +5,14 @@ import logging
import json
import asyncio
from collections import OrderedDict
from typing import Tuple, Dict, List, Union, Optional
from typing import Tuple, Dict, List, Union, Optional, Any
from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway
from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.commands import SOAPCommands
from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse
log = logging.getLogger(__name__)
@ -31,11 +32,27 @@ class UPnP:
self.gateway = gateway
@classmethod
def get_annotations(cls, command: str) -> Dict[str, type]:
return getattr(SOAPCommands, command).__annotations__
def get_annotations(cls, command: str) -> Tuple[Dict[str, Any], Optional[str]]:
if command == "m_search":
return cls.m_search.__annotations__, cls.m_search.__doc__
if command == "get_external_ip":
return cls.get_external_ip.__annotations__, cls.get_external_ip.__doc__
if command == "add_port_mapping":
return cls.add_port_mapping.__annotations__, cls.add_port_mapping.__doc__
if command == "get_port_mapping_by_index":
return cls.get_port_mapping_by_index.__annotations__, cls.get_port_mapping_by_index.__doc__
if command == "get_redirects":
return cls.get_redirects.__annotations__, cls.get_redirects.__doc__
if command == "get_specific_port_mapping":
return cls.get_specific_port_mapping.__annotations__, cls.get_specific_port_mapping.__doc__
if command == "delete_port_mapping":
return cls.delete_port_mapping.__annotations__, cls.delete_port_mapping.__doc__
if command == "get_next_mapping":
return cls.get_next_mapping.__annotations__, cls.get_next_mapping.__doc__
raise AttributeError(command)
@classmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
@staticmethod
def get_lan_and_gateway(lan_address: str = '', gateway_address: str = '',
interface_name: str = 'default') -> Tuple[str, str]:
if not lan_address or not gateway_address:
gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name)
@ -55,10 +72,28 @@ class UPnP:
@classmethod
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
igd_args: Optional[Dict[str, Union[int, str]]] = None,
unicast: bool = True, interface_name: str = 'default',
igd_args: Optional[Dict[str, Union[str, int]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None
) -> Dict[str, Union[str, Dict[str, Union[int, str]]]]:
) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]:
"""
Perform a M-SEARCH for a upnp gateway.
:param lan_address: (str) the local interface ipv4 address
:param gateway_address: (str) the gateway ipv4 address
:param timeout: (int) m search timeout
:param unicast: (bool) use unicast
:param interface_name: (str) name of the network interface
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
:return: {
'lan_address': (str) lan address,
'gateway_address': (str) gateway address,
'm_search_kwargs': (str) equivalent igd_args ,
'discover_reply': (dict) SSDP response datagram
}
"""
if not lan_address or not gateway_address:
try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
@ -79,10 +114,25 @@ class UPnP:
}
async def get_external_ip(self) -> str:
"""
Get the external ip address from the gateway
:return: (str) external ip
"""
return await self.gateway.commands.GetExternalIPAddress()
async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str,
description: str) -> None:
"""
Add a new port mapping
:param external_port: (int) external port to map
:param protocol: (str) UDP | TCP
:param internal_port: (int) internal port
:param lan_address: (str) internal lan address
:param description: (str) mapping description
:return: None
"""
await self.gateway.commands.AddPortMapping(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address,
@ -90,11 +140,42 @@ class UPnP:
)
return None
async def get_port_mapping_by_index(self, index: int) -> Tuple[str, int, str, int, str, bool, str, int]:
async def get_port_mapping_by_index(self, index: int) -> GetGenericPortMappingEntryResponse:
"""
Get information about a port mapping by index number
:param index: (int) mapping index number
:return: NamedTuple[
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
"""
return await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]:
redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = []
async def get_redirects(self) -> List[GetGenericPortMappingEntryResponse]:
"""
Get information about all mapped ports
:return: List[
NamedTuple[
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
]
"""
redirects: List[GetGenericPortMappingEntryResponse] = []
cnt = 0
try:
redirect = await self.get_port_mapping_by_index(cnt)
@ -109,11 +190,19 @@ class UPnP:
break
return redirects
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]:
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> GetSpecificPortMappingEntryResponse:
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
Get information about a port mapping by port number and protocol
:param external_port: (int) port number
:param protocol: (str) UDP | TCP
:return: NamedTuple[
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
"""
return await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
@ -121,25 +210,31 @@ class UPnP:
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
"""
:param external_port: (int) external port to listen on
:param protocol: (str) 'UDP' | 'TCP'
Delete a port mapping
:param external_port: (int) port number of mapping
:param protocol: (str) TCP | UDP
:return: None
"""
await self.gateway.commands.DeletePortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
)
return None
async def get_next_mapping(self, port: int, protocol: str, description: str,
internal_port: Optional[int] = None) -> int:
"""
:param port: (int) external port to redirect from
:param protocol: (str) 'UDP' | 'TCP'
:param description: (str) mapping description
:param internal_port: (int) internal port to redirect to
Get a new port mapping. If the requested port is not available, increment until the next free port is mapped
:return: (int) <mapped port>
:param port: (int) external port
:param protocol: (str) UDP | TCP
:param description: (str) mapping description
:param internal_port: (int) internal port
:return: (int) mapped port
"""
_internal_port = int(internal_port or port)
requested_port = int(_internal_port)
port = int(port)
@ -264,30 +359,7 @@ class UPnP:
# return await self.gateway.commands.GetActiveConnections()
def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
"""
:param method: the command name
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
:param lan_address: the ip address of the local interface
:param gateway_address: the ip address of the gateway
:param timeout: timeout, in seconds
:param interface_name: name of the network interface, the default is aliased to 'default'
:param kwargs: keyword arguments for the command
:param loop: EventLoop, used for testing
"""
kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout)
loop = loop or asyncio.get_event_loop()
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
cli_commands = [
cli_commands = [
'm_search',
'get_external_ip',
'add_port_mapping',
@ -296,11 +368,24 @@ def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str
'get_specific_port_mapping',
'delete_port_mapping',
'get_next_mapping'
]
]
def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict[str, str]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout)
loop = loop or asyncio.get_event_loop()
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: UPnP.m_search(
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
lan_address, gateway_address, timeout, unicast, interface_name, igd_args, loop
)
else: # automatically discover the gateway
try:

View file

@ -16,7 +16,7 @@ except ImportError:
@contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None):
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {}
@ -28,7 +28,11 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
def _write(data):
sent_tcp_packets.append(data)
if data in tcp_replies:
loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data])
reply = tcp_replies[data]
i = 0
while i < len(reply):
loop.call_later(tcp_delay_reply, p.data_received, reply[i:i+100])
i += 100
return
else:
pass
@ -46,6 +50,11 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
def sendto(p: asyncio.DatagramProtocol):
def _sendto(data, addr):
sent_udp_packets.append(data)
loop.call_later(udp_delay_reply, p.datagram_received, data,
(p.bind_address, 1900))
if add_potato_datagrams:
loop.call_soon(p.datagram_received, b'potato', ('?.?.?.?', 1900))
if (data, addr) in udp_replies:
loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)],
(udp_expected_addr, 1900))

View file

@ -1,23 +1,30 @@
import unittest
from unittest import mock
import socket
import struct
from asyncio import DatagramTransport
from aioupnp.protocols.multicast import MulticastProtocol
class TestMulticast(unittest.TestCase):
def test_it(self):
class none_socket:
sock = None
def test_multicast(self):
_ttl = None
mock_socket = mock.MagicMock(spec=socket.socket)
def getsockopt(*_):
return _ttl
def get(self, name, default=None):
return default
def setsockopt(a, b, ttl: bytes):
nonlocal _ttl
_ttl, = struct.unpack('b', ttl)
mock_socket.getsockopt = getsockopt
mock_socket.setsockopt = setsockopt
protocol = MulticastProtocol('1.2.3.4', '1.2.3.4')
transport = DatagramTransport(none_socket())
protocol.set_ttl(1)
with self.assertRaises(ValueError):
_ = protocol.get_ttl()
transport = DatagramTransport()
transport._extra = {'socket': mock_socket}
self.assertEqual(None, protocol.set_ttl(1))
self.assertEqual(0, protocol.get_ttl())
protocol.connection_made(transport)
protocol.set_ttl(1)
with self.assertRaises(ValueError):
_ = protocol.get_ttl()
self.assertEqual(None, protocol.set_ttl(1))
self.assertEqual(1, protocol.get_ttl())

View file

@ -3,7 +3,7 @@ from aioupnp.fault import UPnPError
from aioupnp.protocols.m_search_patterns import packet_generator
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import SSDP_IP_ADDRESS
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search, SSDPProtocol
from tests import AsyncioTestCase, mock_tcp_and_udp
@ -28,6 +28,13 @@ class TestSSDP(AsyncioTestCase):
])
reply_packet = SSDPDatagram("OK", reply_args)
async def test_transport_not_connected_error(self):
try:
await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()])
self.assertTrue(False)
except UPnPError as err:
self.assertEqual(str(err), "SSDP transport not connected")
async def test_m_search_reply_unicast(self):
replies = {
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()
@ -80,3 +87,13 @@ class TestSSDP(AsyncioTestCase):
self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertEqual(args, self.successful_args)
async def test_packets_sent_fuzzy_m_search_ignore_invalid_datagram_replies(self):
sent = []
with self.assertRaises(UPnPError):
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent,
add_potato_datagrams=True):
await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
self.assertListEqual(sent, self.byte_packets)

View file

@ -21,6 +21,139 @@ m_search_cli_result = """{
}\n"""
m_search_help_msg = """aioupnp [-h] [--debug_logging] m_search [--lan_address=<str>] [--gateway_address=<str>]
[--timeout=<int>] [--unicast] [--interface_name=<str>] [--<header key>=<header value>, ...]
Perform a M-SEARCH for a upnp gateway.
:param lan_address: (str) the local interface ipv4 address
:param gateway_address: (str) the gateway ipv4 address
:param timeout: (int) m search timeout
:param unicast: (bool) use unicast
:param interface_name: (str) name of the network interface
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
:return: {
'lan_address': (str) lan address,
'gateway_address': (str) gateway address,
'm_search_kwargs': (str) equivalent igd_args ,
'discover_reply': (dict) SSDP response datagram
}\n
"""
expected_usage = """aioupnp [-h] [--debug_logging] [--interface=<interface>] [--gateway_address=<gateway_address>]
[--lan_address=<lan_address>] [--timeout=<timeout>] [(--<header_key>=<value>)...]
If m-search headers are provided as keyword arguments all of the headers to be used must be provided,
in the order they are to be used. For example:
aioupnp --HOST=239.255.255.250:1900 --MAN="ssdp:discover" --MX=1 --ST=upnp:rootdevice m_search
Commands:
m_search | get_external_ip | add_port_mapping | get_port_mapping_by_index | get_redirects |
get_specific_port_mapping | delete_port_mapping | get_next_mapping
For help with a specific command: aioupnp help <command>
"""
expected_get_external_ip_usage = """aioupnp [-h] [--debug_logging] get_external_ip
Get the external ip address from the gateway
:return: (str) external ip
"""
expected_add_port_mapping_usage = """aioupnp [-h] [--debug_logging] add_port_mapping [--external_port=<int>] [--protocol=<str>]
[--internal_port=<int>] [--lan_address=<str>] [--description=<str>]
Add a new port mapping
:param external_port: (int) external port to map
:param protocol: (str) UDP | TCP
:param internal_port: (int) internal port
:param lan_address: (str) internal lan address
:param description: (str) mapping description
:return: None
"""
expected_get_next_mapping_usage = """aioupnp [-h] [--debug_logging] get_next_mapping [--port=<int>] [--protocol=<str>]
[--description=<str>] [--internal_port=<typing.Union[int, NoneType]>]
Get a new port mapping. If the requested port is not available, increment until the next free port is mapped
:param port: (int) external port
:param protocol: (str) UDP | TCP
:param description: (str) mapping description
:param internal_port: (int) internal port
:return: (int) mapped port
"""
expected_delete_port_mapping_usage = """aioupnp [-h] [--debug_logging] delete_port_mapping [--external_port=<int>] [--protocol=<str>]
Delete a port mapping
:param external_port: (int) port number of mapping
:param protocol: (str) TCP | UDP
:return: None
"""
expected_get_specific_port_mapping_usage = """aioupnp [-h] [--debug_logging] get_specific_port_mapping [--external_port=<int>] [--protocol=<str>]
Get information about a port mapping by port number and protocol
:param external_port: (int) port number
:param protocol: (str) UDP | TCP
:return: NamedTuple[
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
"""
expected_get_redirects_usage = """aioupnp [-h] [--debug_logging] get_redirects
Get information about all mapped ports
:return: List[
NamedTuple[
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
]
"""
expected_get_port_mapping_by_index_usage = """aioupnp [-h] [--debug_logging] get_port_mapping_by_index [--index=<int>]
Get information about a port mapping by index number
:param index: (int) mapping index number
:return: NamedTuple[
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
"""
class TestCLI(AsyncioTestCase):
gateway_address = "10.0.0.1"
soap_port = 49152
@ -101,3 +234,117 @@ class TestCLI(AsyncioTestCase):
self.loop
)
self.assertEqual(m_search_cli_result, actual_output.getvalue())
def test_usage(self):
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help'],
self.loop
)
self.assertEqual(expected_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'test'],
self.loop
)
self.assertEqual(expected_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'test', 'help'],
self.loop
)
self.assertEqual("aioupnp encountered an error: \"test\" is not a recognized command\n", actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'test'],
self.loop
)
self.assertEqual("aioupnp encountered an error: \"test\" is not a recognized command\n", actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None],
self.loop
)
self.assertEqual(expected_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, "--something=test"],
self.loop
)
self.assertEqual("no command given\n" + expected_usage, actual_output.getvalue())
def test_commands_help(self):
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'm-search'],
self.loop
)
self.assertEqual(m_search_help_msg, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get-external-ip'],
self.loop
)
self.assertEqual(expected_get_external_ip_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'add-port-mapping'],
self.loop
)
self.assertEqual(expected_add_port_mapping_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get-next-mapping'],
self.loop
)
self.assertEqual(expected_get_next_mapping_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'delete_port_mapping'],
self.loop
)
self.assertEqual(expected_delete_port_mapping_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get_specific_port_mapping'],
self.loop
)
self.assertEqual(expected_get_specific_port_mapping_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get_redirects'],
self.loop
)
self.assertEqual(expected_get_redirects_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get_port_mapping_by_index'],
self.loop
)
self.assertEqual(expected_get_port_mapping_by_index_usage, actual_output.getvalue())

View file

@ -4,17 +4,7 @@ from aioupnp.upnp import UPnP
from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway
from aioupnp.serialization.ssdp import SSDPDatagram
class TestGetAnnotations(AsyncioTestCase):
def test_get_annotations(self):
expected = {
'NewRemoteHost': str, 'NewExternalPort': int, 'NewProtocol': str, 'NewInternalPort': int,
'NewInternalClient': str, 'NewEnabled': int, 'NewPortMappingDescription': str,
'NewLeaseDuration': str, 'return': None
}
self.assertDictEqual(expected, UPnP.get_annotations('AddPortMapping'))
from aioupnp.commands import GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse
class UPnPCommandTestCase(AsyncioTestCase):
@ -76,7 +66,8 @@ class TestGetGenericPortMappingEntry(UPnPCommandTestCase):
await gateway.discover_commands(self.loop)
upnp = UPnP(self.client_address, self.gateway_address, gateway)
result = await upnp.get_port_mapping_by_index(0)
self.assertEqual((None, 9308, 'UDP', 9308, "11.2.3.44", True, "11.2.3.44:9308 to 9308 (UDP)", 0), result)
self.assertEqual(GetGenericPortMappingEntryResponse(None, 9308, 'UDP', 9308, "11.2.3.44", True,
"11.2.3.44:9308 to 9308 (UDP)", 0), result)
class TestGetNextPortMapping(UPnPCommandTestCase):
@ -120,6 +111,9 @@ class TestGetSpecificPortMapping(UPnPCommandTestCase):
await upnp.get_specific_port_mapping(1000, 'UDP')
except UPnPError:
result = await upnp.get_specific_port_mapping(9308, 'UDP')
self.assertEqual((9308, '11.2.3.55', True, '11.2.3.55:9308 to 9308 (UDP)', 0), result)
self.assertEqual(
GetSpecificPortMappingEntryResponse(9308, '11.2.3.55', True, '11.2.3.55:9308 to 9308 (UDP)', 0),
result
)
else:
self.assertTrue(False)