more mypy refactoring, improve coverage #13

Merged
jackrobison merged 1 commit from improve-coverage into master 2019-05-24 03:35:37 +02:00
11 changed files with 637 additions and 244 deletions

View file

@ -4,8 +4,7 @@ import logging
import textwrap import textwrap
import typing import typing
from collections import OrderedDict from collections import OrderedDict
from aioupnp.upnp import run_cli, UPnP from aioupnp.upnp import run_cli, UPnP, cli_commands
from aioupnp.commands import SOAPCommands
log = logging.getLogger("aioupnp") log = logging.getLogger("aioupnp")
handler = logging.StreamHandler() handler = logging.StreamHandler()
@ -20,36 +19,53 @@ base_usage = "\n".join(textwrap.wrap(
def get_help(command: str) -> str: def get_help(command: str) -> str:
annotations = UPnP.get_annotations(command) annotations, doc = UPnP.get_annotations(command)
params = command + " " + " ".join(["[--%s=<%s>]" % (k, str(v)) for k, v in annotations.items() if k != 'return']) doc = doc or ""
return base_usage + "\n".join(
textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False) arg_strs = []
) for k, v in annotations.items():
if k not in ['return', 'igd_args', 'loop']:
t = str(v) if not hasattr(v, "__name__") else v.__name__
if t == 'bool':
arg_strs.append(f"[--{k}]")
else:
arg_strs.append(f"[--{k}=<{t}>]")
elif k == 'igd_args':
arg_strs.append(f"[--<header key>=<header value>, ...]")
params = " ".join(arg_strs)
usage = "\n".join(textwrap.wrap(
f"aioupnp [-h] [--debug_logging] {command} {params}",
100, subsequent_indent=' ', break_long_words=False)) + "\n"
return usage + textwrap.dedent(doc)
def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None, def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int: loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int:
argv = argv or list(sys.argv) argv = argv or list(sys.argv)
commands = list(SOAPCommands.SOAP_COMMANDS)
help_str = "\n".join(textwrap.wrap( help_str = "\n".join(textwrap.wrap(
" | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False " | ".join(cli_commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
)) ))
usage = \ usage = \
"\n%s\n" \ "%s\n" \
"If m-search headers are provided as keyword arguments all of the headers to be used must be provided,\n" \ "If m-search headers are provided as keyword arguments all of the headers to be used must be provided,\n" \
"in the order they are to be used. For example:\n" \ "in the order they are to be used. For example:\n" \
" aioupnp --HOST=239.255.255.250:1900 --MAN=\"ssdp:discover\" --MX=1 --ST=upnp:rootdevice m_search\n\n" \ " aioupnp --HOST=239.255.255.250:1900 --MAN=\"ssdp:discover\" --MX=1 --ST=upnp:rootdevice m_search\n\n" \
"Commands:\n" \ "Commands:\n" \
"%s\n\n" \ "%s\n\n" \
"For help with a specific command:" \ "For help with a specific command:" \
" aioupnp help <command>\n" % (base_usage, help_str) " aioupnp help <command>" % (base_usage, help_str)
args: typing.List[str] = [str(arg) for arg in argv[1:]] args: typing.List[str] = [str(arg) for arg in argv[1:]]
if not args:
print(usage)
return 0
if args[0] in ['help', '-h', '--help']: if args[0] in ['help', '-h', '--help']:
if len(args) > 1: if len(args) > 1:
if args[1] in commands: if args[1].replace("-", "_") in cli_commands:
print(get_help(args[1])) print(get_help(args[1].replace("-", "_")))
return 0 return 0
print(usage) print(usage)
return 0 return 0

View file

@ -2,43 +2,64 @@ import asyncio
import time import time
import typing import typing
import logging import logging
from typing import Tuple
from aioupnp.protocols.scpd import scpd_post from aioupnp.protocols.scpd import scpd_post
from aioupnp.device import Service from aioupnp.device import Service
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def soap_optional_str(x: typing.Optional[str]) -> typing.Optional[str]: def soap_optional_str(x: typing.Optional[typing.Union[str, int]]) -> typing.Optional[str]:
return x if x is not None and str(x).lower() not in ['none', 'nil'] else None return str(x) if x is not None and str(x).lower() not in ['none', 'nil'] else None
def soap_bool(x: typing.Optional[str]) -> bool: def soap_bool(x: typing.Optional[typing.Union[str, int]]) -> bool:
return False if not x or str(x).lower() in ['false', 'False'] else True return False if not x or str(x).lower() in ['false', 'False'] else True
def recast_single_result(t: type, result: typing.Any) -> typing.Optional[typing.Union[str, int, float, bool]]: class GetSpecificPortMappingEntryResponse(typing.NamedTuple):
if t is bool: internal_port: int
return soap_bool(result) lan_address: str
if t is str: enabled: bool
return soap_optional_str(result) description: str
return t(result) lease_time: int
class GetGenericPortMappingEntryResponse(typing.NamedTuple):
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
def recast_return(return_annotation, result: typing.Dict[str, typing.Union[int, str]], def recast_return(return_annotation, result: typing.Dict[str, typing.Union[int, str]],
result_keys: typing.List[str]) -> typing.Tuple: result_keys: typing.List[str]) -> typing.Optional[
if return_annotation is None or len(result_keys) == 0: typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]:
return ()
if len(result_keys) == 1: if len(result_keys) == 1:
assert len(result_keys) == 1
single_result = result[result_keys[0]] single_result = result[result_keys[0]]
return (recast_single_result(return_annotation, single_result), ) if return_annotation is bool:
annotated_args: typing.List[type] = list(return_annotation.__args__) return soap_bool(single_result)
assert len(annotated_args) == len(result_keys) if return_annotation is str:
recast_results: typing.List[typing.Optional[typing.Union[str, int, float, bool]]] = [] return soap_optional_str(single_result)
for type_annotation, result_key in zip(annotated_args, result_keys): return int(result[result_keys[0]]) if result_keys[0] in result else None
recast_results.append(recast_single_result(type_annotation, result.get(result_key, None))) elif return_annotation in [GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse]:
return tuple(recast_results) arg_types: typing.Dict[str, typing.Type[typing.Any]] = return_annotation._field_types
assert len(arg_types) == len(result_keys)
recast_results: typing.Dict[str, typing.Optional[typing.Union[str, int, bool]]] = {}
for i, (field_name, result_key) in enumerate(zip(arg_types, result_keys)):
result_field_name = result_keys[i]
field_type = arg_types[field_name]
if field_type is bool:
recast_results[field_name] = soap_bool(result.get(result_field_name, None))
elif field_type is str:
recast_results[field_name] = soap_optional_str(result.get(result_field_name, None))
elif field_type is int:
recast_results[field_name] = int(result[result_field_name]) if result_field_name in result else None
return return_annotation(**recast_results)
return None
class SOAPCommands: class SOAPCommands:
@ -88,7 +109,10 @@ class SOAPCommands:
self._base_address = base_address self._base_address = base_address
self._port = port self._port = port
self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
typing.Tuple, typing.Optional[Exception], float]] = [] typing.Optional[typing.Union[str, int, bool,
GetSpecificPortMappingEntryResponse,
GetGenericPortMappingEntryResponse]],
typing.Optional[Exception], float]] = []
def is_registered(self, name: str) -> bool: def is_registered(self, name: str) -> bool:
if name not in self.SOAP_COMMANDS: if name not in self.SOAP_COMMANDS:
@ -112,7 +136,8 @@ class SOAPCommands:
input_names: typing.List[str] = self._registered[service][name][0] input_names: typing.List[str] = self._registered[service][name][0]
output_names: typing.List[str] = self._registered[service][name][1] output_names: typing.List[str] = self._registered[service][name][1]
async def wrapper(**kwargs: typing.Any) -> typing.Tuple: async def wrapper(**kwargs: typing.Any) -> typing.Optional[
typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]:
assert service.controlURL is not None assert service.controlURL is not None
assert service.serviceType is not None assert service.serviceType is not None
@ -122,11 +147,10 @@ class SOAPCommands:
) )
if err is not None: if err is not None:
assert isinstance(xml_bytes, bytes) assert isinstance(xml_bytes, bytes)
self._requests.append((name, kwargs, xml_bytes, (), err, time.time())) self._requests.append((name, kwargs, xml_bytes, None, err, time.time()))
raise err raise err
assert 'return' in annotations assert 'return' in annotations
result = recast_return(annotations['return'], response, output_names) result = recast_return(annotations['return'], response, output_names)
self._requests.append((name, kwargs, xml_bytes, result, None, time.time())) self._requests.append((name, kwargs, xml_bytes, result, None, time.time()))
return result return result
@ -161,8 +185,7 @@ class SOAPCommands:
) )
return None return None
async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> Tuple[str, int, str, int, str, async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> GetGenericPortMappingEntryResponse:
bool, str, int]:
""" """
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled, Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
NewPortMappingDescription, NewLeaseDuration) NewPortMappingDescription, NewLeaseDuration)
@ -171,19 +194,19 @@ class SOAPCommands:
if not self.is_registered(name): if not self.is_registered(name):
raise NotImplementedError() raise NotImplementedError()
assert name in self._wrappers_kwargs assert name in self._wrappers_kwargs
result: Tuple[str, int, str, int, str, bool, str, int] = await self._wrappers_kwargs[name]( result: GetGenericPortMappingEntryResponse = await self._wrappers_kwargs[name](
NewPortMappingIndex=NewPortMappingIndex NewPortMappingIndex=NewPortMappingIndex
) )
return result return result
async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int, async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int,
NewProtocol: str) -> Tuple[int, str, bool, str, int]: NewProtocol: str) -> GetSpecificPortMappingEntryResponse:
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)""" """Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
name = "GetSpecificPortMappingEntry" name = "GetSpecificPortMappingEntry"
if not self.is_registered(name): if not self.is_registered(name):
raise NotImplementedError() raise NotImplementedError()
assert name in self._wrappers_kwargs assert name in self._wrappers_kwargs
result: Tuple[int, str, bool, str, int] = await self._wrappers_kwargs[name]( result: GetSpecificPortMappingEntryResponse = await self._wrappers_kwargs[name](
NewRemoteHost=NewRemoteHost, NewExternalPort=NewExternalPort, NewProtocol=NewProtocol NewRemoteHost=NewRemoteHost, NewExternalPort=NewExternalPort, NewProtocol=NewProtocol
) )
return result return result
@ -205,8 +228,8 @@ class SOAPCommands:
if not self.is_registered(name): if not self.is_registered(name):
raise NotImplementedError() raise NotImplementedError()
assert name in self._wrappers_no_args assert name in self._wrappers_no_args
result: Tuple[str] = await self._wrappers_no_args[name]() result: str = await self._wrappers_no_args[name]()
return result[0] return result
# async def GetNATRSIPStatus(self) -> Tuple[bool, bool]: # async def GetNATRSIPStatus(self) -> Tuple[bool, bool]:
# """Returns (NewRSIPAvailable, NewNATEnabled)""" # """Returns (NewRSIPAvailable, NewNATEnabled)"""

View file

@ -3,7 +3,7 @@ import logging
import typing import typing
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Union from typing import Dict, List
from aioupnp.util import get_dict_val_case_insensitive from aioupnp.util import get_dict_val_case_insensitive
from aioupnp.constants import SPEC_VERSION, SERVICE from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands from aioupnp.commands import SOAPCommands
@ -103,74 +103,81 @@ class Gateway:
self._registered_commands: Dict[str, str] = {} self._registered_commands: Dict[str, str] = {}
self.commands = SOAPCommands(self._loop, self.base_ip, self.port) self.commands = SOAPCommands(self._loop, self.base_ip, self.port)
def gateway_descriptor(self) -> dict: # def gateway_descriptor(self) -> dict:
r = { # r = {
'server': self.server.decode(), # 'server': self.server.decode(),
'urlBase': self.url_base, # 'urlBase': self.url_base,
'location': self.location.decode(), # 'location': self.location.decode(),
"specVersion": self.spec_version, # "specVersion": self.spec_version,
'usn': self.usn.decode(), # 'usn': self.usn.decode(),
'urn': self.urn.decode(), # 'urn': self.urn.decode(),
} # }
return r # return r
@property @property
def manufacturer_string(self) -> str: def manufacturer_string(self) -> str:
if not self.devices: manufacturer_string = "UNKNOWN GATEWAY"
return "UNKNOWN GATEWAY" if self.devices:
devices: typing.List[Device] = list(self.devices.values()) devices: typing.List[Device] = list(self.devices.values())
device = devices[0] device = devices[0]
return f"{device.manufacturer} {device.modelName}" manufacturer_string = f"{device.manufacturer} {device.modelName}"
return manufacturer_string
@property @property
def services(self) -> Dict[str, Service]: def services(self) -> Dict[str, Service]:
if not self._device: services: Dict[str, Service] = {}
return {} if self._services:
return {str(service.serviceType): service for service in self._services}
@property
def devices(self) -> Dict:
if not self._device:
return {}
return {device.udn: device for device in self._devices}
def get_service(self, service_type: str) -> typing.Optional[Service]:
for service in self._services: for service in self._services:
if service.serviceType and service.serviceType.lower() == service_type.lower(): if service.serviceType is not None:
return service services[service.serviceType] = service
return None return services
@property @property
def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, def devices(self) -> Dict[str, Device]:
typing.Optional[typing.Tuple], devices: Dict[str, Device] = {}
typing.Optional[Exception], float]]: if self._device:
soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes, for device in self._devices:
typing.Optional[typing.Tuple], if device.udn is not None:
typing.Optional[Exception], float]] = [] devices[device.udn] = device
soap_call_infos.extend([ return devices
(name, request_args, raw_response, decoded_response, soap_error, ts)
for (
name, request_args, raw_response, decoded_response, soap_error, ts
) in self.commands._requests
])
soap_call_infos.sort(key=lambda x: x[5])
return soap_call_infos
def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]: # def get_service(self, service_type: str) -> typing.Optional[Service]:
return { # for service in self._services:
'manufacturer_string': self.manufacturer_string, # if service.serviceType and service.serviceType.lower() == service_type.lower():
'gateway_address': self.base_ip, # return service
'gateway_descriptor': self.gateway_descriptor(), # return None
'gateway_xml': self._xml_response,
'services_xml': self._service_descriptors, # @property
'services': {service.SCPDURL: service.as_dict() for service in self._services}, # def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()], # typing.Optional[typing.Tuple],
'reply': self._ok_packet.as_dict(), # typing.Optional[Exception], float]]:
'soap_port': self.port, # soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
'registered_soap_commands': self._registered_commands, # typing.Optional[typing.Tuple],
'unsupported_soap_commands': self._unsupported_actions, # typing.Optional[Exception], float]] = []
'soap_requests': self.soap_requests # soap_call_infos.extend([
} # (name, request_args, raw_response, decoded_response, soap_error, ts)
# for (
# name, request_args, raw_response, decoded_response, soap_error, ts
# ) in self.commands._requests
# ])
# soap_call_infos.sort(key=lambda x: x[5])
# return soap_call_infos
# def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]:
# return {
# 'manufacturer_string': self.manufacturer_string,
# 'gateway_address': self.base_ip,
# 'gateway_descriptor': self.gateway_descriptor(),
# 'gateway_xml': self._xml_response,
# 'services_xml': self._service_descriptors,
# 'services': {service.SCPDURL: service.as_dict() for service in self._services},
# 'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()],
# 'reply': self._ok_packet.as_dict(),
# 'soap_port': self.port,
# 'registered_soap_commands': self._registered_commands,
# 'unsupported_soap_commands': self._unsupported_actions,
# 'soap_requests': self.soap_requests
# }
@classmethod @classmethod
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30, async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
@ -206,7 +213,7 @@ class Gateway:
ignored.add(datagram.location) ignored.add(datagram.location)
continue continue
else: else:
log.debug('found gateway device %s', datagram.location) log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
return gateway return gateway
except (asyncio.TimeoutError, UPnPError) as err: except (asyncio.TimeoutError, UPnPError) as err:
assert datagram.location is not None assert datagram.location is not None

View file

@ -2,11 +2,11 @@ import struct
import socket import socket
import typing import typing
from asyncio.protocols import DatagramProtocol from asyncio.protocols import DatagramProtocol
from asyncio.transports import BaseTransport from asyncio.transports import DatagramTransport
from unittest import mock from unittest import mock
def _get_sock(transport: typing.Optional[BaseTransport]) -> typing.Optional[socket.socket]: def _get_sock(transport: typing.Optional[DatagramTransport]) -> typing.Optional[socket.socket]:
if transport is None or not hasattr(transport, "_extra"): if transport is None or not hasattr(transport, "_extra"):
return None return None
sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None) sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None)
@ -18,7 +18,7 @@ class MulticastProtocol(DatagramProtocol):
def __init__(self, multicast_address: str, bind_address: str) -> None: def __init__(self, multicast_address: str, bind_address: str) -> None:
self.multicast_address = multicast_address self.multicast_address = multicast_address
self.bind_address = bind_address self.bind_address = bind_address
self.transport: typing.Optional[BaseTransport] = None self.transport: typing.Optional[DatagramTransport] = None
@property @property
def sock(self) -> typing.Optional[socket.socket]: def sock(self) -> typing.Optional[socket.socket]:
@ -26,14 +26,13 @@ class MulticastProtocol(DatagramProtocol):
def get_ttl(self) -> int: def get_ttl(self) -> int:
sock = self.sock sock = self.sock
if not sock: if sock:
raise ValueError("not connected")
return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL) return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
return 0
def set_ttl(self, ttl: int = 1) -> None: def set_ttl(self, ttl: int = 1) -> None:
sock = self.sock sock = self.sock
if not sock: if sock:
return None
sock.setsockopt( sock.setsockopt(
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl) socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
) )
@ -41,8 +40,7 @@ class MulticastProtocol(DatagramProtocol):
def join_group(self, multicast_address: str, bind_address: str) -> None: def join_group(self, multicast_address: str, bind_address: str) -> None:
sock = self.sock sock = self.sock
if not sock: if sock:
return None
sock.setsockopt( sock.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
@ -51,15 +49,14 @@ class MulticastProtocol(DatagramProtocol):
def leave_group(self, multicast_address: str, bind_address: str) -> None: def leave_group(self, multicast_address: str, bind_address: str) -> None:
sock = self.sock sock = self.sock
if not sock: if sock:
raise ValueError("not connected")
sock.setsockopt( sock.setsockopt(
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address) socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
) )
return None return None
def connection_made(self, transport: BaseTransport) -> None: def connection_made(self, transport: DatagramTransport) -> None: # type: ignore
self.transport = transport self.transport = transport
return None return None

View file

@ -28,19 +28,14 @@ class SSDPProtocol(MulticastProtocol):
self.notifications: typing.List[SSDPDatagram] = [] self.notifications: typing.List[SSDPDatagram] = []
self.connected = asyncio.Event(loop=self.loop) self.connected = asyncio.Event(loop=self.loop)
def connection_made(self, transport) -> None: def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
# assert isinstance(transport, asyncio.DatagramTransport), str(type(transport))
super().connection_made(transport) super().connection_made(transport)
self.connected.set() self.connected.set()
return None
def disconnect(self) -> None: def disconnect(self) -> None:
if self.transport: if self.transport:
try:
self.leave_group(self.multicast_address, self.bind_address) self.leave_group(self.multicast_address, self.bind_address)
except ValueError:
pass
except Exception:
log.exception("unexpected error leaving multicast group")
self.transport.close() self.transport.close()
self.connected.clear() self.connected.clear()
while self._pending_searches: while self._pending_searches:
@ -50,13 +45,12 @@ class SSDPProtocol(MulticastProtocol):
return None return None
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None: def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
if packet.location in self._ignored: if packet.location not in self._ignored:
return None
# TODO: fix this # TODO: fix this
tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = [] tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = [] set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = []
while len(self._pending_searches): while len(self._pending_searches):
t: typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle] = self._pending_searches.pop() t = self._pending_searches.pop()
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]): if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
f = t[2] f = t[2]
if f not in set_futures: if f not in set_futures:
@ -69,10 +63,12 @@ class SSDPProtocol(MulticastProtocol):
self._pending_searches.append(tmp.pop()) self._pending_searches.append(tmp.pop())
return None return None
def _send_m_search(self, address: str, packet: SSDPDatagram) -> None: def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None:
dest = address if self._unicast else SSDP_IP_ADDRESS dest = address if self._unicast else SSDP_IP_ADDRESS
if not self.transport: if not self.transport:
raise UPnPError("SSDP transport not connected") if not fut.done():
fut.set_exception(UPnPError("SSDP transport not connected"))
return None
log.debug("send m search to %s: %s", dest, packet.st) log.debug("send m search to %s: %s", dest, packet.st)
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT)) self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
return None return None
@ -84,7 +80,7 @@ class SSDPProtocol(MulticastProtocol):
packet = SSDPDatagram("M-SEARCH", datagram) packet = SSDPDatagram("M-SEARCH", datagram)
assert packet.st is not None assert packet.st is not None
self._pending_searches.append( self._pending_searches.append(
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet)) (address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet, fut))
) )
return await asyncio.wait_for(fut, timeout) return await asyncio.wait_for(fut, timeout)
@ -95,7 +91,7 @@ class SSDPProtocol(MulticastProtocol):
packet = SSDPDatagram.decode(data) packet = SSDPDatagram.decode(data)
log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet) log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet)
except UPnPError as err: except UPnPError as err:
log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err, log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
binascii.hexlify(data)) binascii.hexlify(data))
return None return None
@ -131,19 +127,13 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optio
listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint( listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
) )
transport = listen_result[0]
protocol = listen_result[1] protocol = listen_result[1]
assert isinstance(protocol, SSDPProtocol) assert isinstance(protocol, SSDPProtocol)
except Exception as err: except Exception as err:
print(err)
raise UPnPError(err) raise UPnPError(err)
try: else:
protocol.join_group(protocol.multicast_address, protocol.bind_address) protocol.join_group(protocol.multicast_address, protocol.bind_address)
protocol.set_ttl(1) protocol.set_ttl(1)
except Exception as err:
protocol.disconnect()
raise UPnPError(err)
return protocol, gateway_address, lan_address return protocol, gateway_address, lan_address
@ -178,10 +168,11 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
log.debug("sending batch of %i M-SEARCH attempts", batch_size) log.debug("sending batch of %i M-SEARCH attempts", batch_size)
try: try:
await protocol.m_search(gateway_address, batch_timeout, args) await protocol.m_search(gateway_address, batch_timeout, args)
protocol.disconnect()
return args
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
else:
protocol.disconnect()
return args
protocol.disconnect() protocol.disconnect()
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))

View file

@ -5,13 +5,14 @@ import logging
import json import json
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple, Dict, List, Union, Optional from typing import Tuple, Dict, List, Union, Optional, Any
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway from aioupnp.gateway import Gateway
from aioupnp.interfaces import get_gateway_and_lan_addresses from aioupnp.interfaces import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.commands import SOAPCommands from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -31,11 +32,27 @@ class UPnP:
self.gateway = gateway self.gateway = gateway
@classmethod @classmethod
def get_annotations(cls, command: str) -> Dict[str, type]: def get_annotations(cls, command: str) -> Tuple[Dict[str, Any], Optional[str]]:
return getattr(SOAPCommands, command).__annotations__ if command == "m_search":
return cls.m_search.__annotations__, cls.m_search.__doc__
if command == "get_external_ip":
return cls.get_external_ip.__annotations__, cls.get_external_ip.__doc__
if command == "add_port_mapping":
return cls.add_port_mapping.__annotations__, cls.add_port_mapping.__doc__
if command == "get_port_mapping_by_index":
return cls.get_port_mapping_by_index.__annotations__, cls.get_port_mapping_by_index.__doc__
if command == "get_redirects":
return cls.get_redirects.__annotations__, cls.get_redirects.__doc__
if command == "get_specific_port_mapping":
return cls.get_specific_port_mapping.__annotations__, cls.get_specific_port_mapping.__doc__
if command == "delete_port_mapping":
return cls.delete_port_mapping.__annotations__, cls.delete_port_mapping.__doc__
if command == "get_next_mapping":
return cls.get_next_mapping.__annotations__, cls.get_next_mapping.__doc__
raise AttributeError(command)
@classmethod @staticmethod
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '', def get_lan_and_gateway(lan_address: str = '', gateway_address: str = '',
interface_name: str = 'default') -> Tuple[str, str]: interface_name: str = 'default') -> Tuple[str, str]:
if not lan_address or not gateway_address: if not lan_address or not gateway_address:
gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name) gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name)
@ -55,10 +72,28 @@ class UPnP:
@classmethod @classmethod
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
igd_args: Optional[Dict[str, Union[int, str]]] = None,
unicast: bool = True, interface_name: str = 'default', unicast: bool = True, interface_name: str = 'default',
igd_args: Optional[Dict[str, Union[str, int]]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None loop: Optional[asyncio.AbstractEventLoop] = None
) -> Dict[str, Union[str, Dict[str, Union[int, str]]]]: ) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]:
"""
Perform a M-SEARCH for a upnp gateway.
:param lan_address: (str) the local interface ipv4 address
:param gateway_address: (str) the gateway ipv4 address
:param timeout: (int) m search timeout
:param unicast: (bool) use unicast
:param interface_name: (str) name of the network interface
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
:return: {
'lan_address': (str) lan address,
'gateway_address': (str) gateway address,
'm_search_kwargs': (str) equivalent igd_args ,
'discover_reply': (dict) SSDP response datagram
}
"""
if not lan_address or not gateway_address: if not lan_address or not gateway_address:
try: try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name) lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
@ -79,10 +114,25 @@ class UPnP:
} }
async def get_external_ip(self) -> str: async def get_external_ip(self) -> str:
"""
Get the external ip address from the gateway
:return: (str) external ip
"""
return await self.gateway.commands.GetExternalIPAddress() return await self.gateway.commands.GetExternalIPAddress()
async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str, async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str,
description: str) -> None: description: str) -> None:
"""
Add a new port mapping
:param external_port: (int) external port to map
:param protocol: (str) UDP | TCP
:param internal_port: (int) internal port
:param lan_address: (str) internal lan address
:param description: (str) mapping description
:return: None
"""
await self.gateway.commands.AddPortMapping( await self.gateway.commands.AddPortMapping(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol, NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address, NewInternalPort=internal_port, NewInternalClient=lan_address,
@ -90,11 +140,42 @@ class UPnP:
) )
return None return None
async def get_port_mapping_by_index(self, index: int) -> Tuple[str, int, str, int, str, bool, str, int]: async def get_port_mapping_by_index(self, index: int) -> GetGenericPortMappingEntryResponse:
"""
Get information about a port mapping by index number
:param index: (int) mapping index number
:return: NamedTuple[
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
"""
return await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) return await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]: async def get_redirects(self) -> List[GetGenericPortMappingEntryResponse]:
redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = [] """
Get information about all mapped ports
:return: List[
NamedTuple[
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
]
"""
redirects: List[GetGenericPortMappingEntryResponse] = []
cnt = 0 cnt = 0
try: try:
redirect = await self.get_port_mapping_by_index(cnt) redirect = await self.get_port_mapping_by_index(cnt)
@ -109,11 +190,19 @@ class UPnP:
break break
return redirects return redirects
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]: async def get_specific_port_mapping(self, external_port: int, protocol: str) -> GetSpecificPortMappingEntryResponse:
""" """
:param external_port: (int) external port to listen on Get information about a port mapping by port number and protocol
:param protocol: (str) 'UDP' | 'TCP'
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time> :param external_port: (int) port number
:param protocol: (str) UDP | TCP
:return: NamedTuple[
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
""" """
return await self.gateway.commands.GetSpecificPortMappingEntry( return await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
@ -121,25 +210,31 @@ class UPnP:
async def delete_port_mapping(self, external_port: int, protocol: str) -> None: async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
""" """
:param external_port: (int) external port to listen on Delete a port mapping
:param protocol: (str) 'UDP' | 'TCP'
:param external_port: (int) port number of mapping
:param protocol: (str) TCP | UDP
:return: None :return: None
""" """
await self.gateway.commands.DeletePortMapping( await self.gateway.commands.DeletePortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
) )
return None return None
async def get_next_mapping(self, port: int, protocol: str, description: str, async def get_next_mapping(self, port: int, protocol: str, description: str,
internal_port: Optional[int] = None) -> int: internal_port: Optional[int] = None) -> int:
""" """
:param port: (int) external port to redirect from Get a new port mapping. If the requested port is not available, increment until the next free port is mapped
:param protocol: (str) 'UDP' | 'TCP'
:param description: (str) mapping description
:param internal_port: (int) internal port to redirect to
:return: (int) <mapped port> :param port: (int) external port
:param protocol: (str) UDP | TCP
:param description: (str) mapping description
:param internal_port: (int) internal port
:return: (int) mapped port
""" """
_internal_port = int(internal_port or port) _internal_port = int(internal_port or port)
requested_port = int(_internal_port) requested_port = int(_internal_port)
port = int(port) port = int(port)
@ -264,30 +359,7 @@ class UPnP:
# return await self.gateway.commands.GetActiveConnections() # return await self.gateway.commands.GetActiveConnections()
def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '', cli_commands = [
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
"""
:param method: the command name
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
:param lan_address: the ip address of the local interface
:param gateway_address: the ip address of the gateway
:param timeout: timeout, in seconds
:param interface_name: name of the network interface, the default is aliased to 'default'
:param kwargs: keyword arguments for the command
:param loop: EventLoop, used for testing
"""
kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout)
loop = loop or asyncio.get_event_loop()
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
cli_commands = [
'm_search', 'm_search',
'get_external_ip', 'get_external_ip',
'add_port_mapping', 'add_port_mapping',
@ -296,11 +368,24 @@ def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str
'get_specific_port_mapping', 'get_specific_port_mapping',
'delete_port_mapping', 'delete_port_mapping',
'get_next_mapping' 'get_next_mapping'
] ]
def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
unicast: bool = True, kwargs: Optional[Dict[str, str]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout)
loop = loop or asyncio.get_event_loop()
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
if method == 'm_search': # if we're only m_searching don't do any device discovery if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: UPnP.m_search( fn = lambda *_a, **_kw: UPnP.m_search(
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop lan_address, gateway_address, timeout, unicast, interface_name, igd_args, loop
) )
else: # automatically discover the gateway else: # automatically discover the gateway
try: try:

View file

@ -16,7 +16,7 @@ except ImportError:
@contextlib.contextmanager @contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None, def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None): tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else [] sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {} udp_replies = udp_replies or {}
@ -28,7 +28,11 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
def _write(data): def _write(data):
sent_tcp_packets.append(data) sent_tcp_packets.append(data)
if data in tcp_replies: if data in tcp_replies:
loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data]) reply = tcp_replies[data]
i = 0
while i < len(reply):
loop.call_later(tcp_delay_reply, p.data_received, reply[i:i+100])
i += 100
return return
else: else:
pass pass
@ -46,6 +50,11 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
def sendto(p: asyncio.DatagramProtocol): def sendto(p: asyncio.DatagramProtocol):
def _sendto(data, addr): def _sendto(data, addr):
sent_udp_packets.append(data) sent_udp_packets.append(data)
loop.call_later(udp_delay_reply, p.datagram_received, data,
(p.bind_address, 1900))
if add_potato_datagrams:
loop.call_soon(p.datagram_received, b'potato', ('?.?.?.?', 1900))
if (data, addr) in udp_replies: if (data, addr) in udp_replies:
loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)], loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)],
(udp_expected_addr, 1900)) (udp_expected_addr, 1900))

View file

@ -1,23 +1,30 @@
import unittest import unittest
from unittest import mock
import socket
import struct
from asyncio import DatagramTransport from asyncio import DatagramTransport
from aioupnp.protocols.multicast import MulticastProtocol from aioupnp.protocols.multicast import MulticastProtocol
class TestMulticast(unittest.TestCase): class TestMulticast(unittest.TestCase):
def test_it(self): def test_multicast(self):
class none_socket: _ttl = None
sock = None mock_socket = mock.MagicMock(spec=socket.socket)
def getsockopt(*_):
return _ttl
def get(self, name, default=None): def setsockopt(a, b, ttl: bytes):
return default nonlocal _ttl
_ttl, = struct.unpack('b', ttl)
mock_socket.getsockopt = getsockopt
mock_socket.setsockopt = setsockopt
protocol = MulticastProtocol('1.2.3.4', '1.2.3.4') protocol = MulticastProtocol('1.2.3.4', '1.2.3.4')
transport = DatagramTransport(none_socket()) transport = DatagramTransport()
protocol.set_ttl(1) transport._extra = {'socket': mock_socket}
with self.assertRaises(ValueError): self.assertEqual(None, protocol.set_ttl(1))
_ = protocol.get_ttl() self.assertEqual(0, protocol.get_ttl())
protocol.connection_made(transport) protocol.connection_made(transport)
protocol.set_ttl(1) self.assertEqual(None, protocol.set_ttl(1))
with self.assertRaises(ValueError): self.assertEqual(1, protocol.get_ttl())
_ = protocol.get_ttl()

View file

@ -3,7 +3,7 @@ from aioupnp.fault import UPnPError
from aioupnp.protocols.m_search_patterns import packet_generator from aioupnp.protocols.m_search_patterns import packet_generator
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import SSDP_IP_ADDRESS from aioupnp.constants import SSDP_IP_ADDRESS
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search from aioupnp.protocols.ssdp import fuzzy_m_search, m_search, SSDPProtocol
from tests import AsyncioTestCase, mock_tcp_and_udp from tests import AsyncioTestCase, mock_tcp_and_udp
@ -28,6 +28,13 @@ class TestSSDP(AsyncioTestCase):
]) ])
reply_packet = SSDPDatagram("OK", reply_args) reply_packet = SSDPDatagram("OK", reply_args)
async def test_transport_not_connected_error(self):
try:
await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()])
self.assertTrue(False)
except UPnPError as err:
self.assertEqual(str(err), "SSDP transport not connected")
async def test_m_search_reply_unicast(self): async def test_m_search_reply_unicast(self):
replies = { replies = {
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode() (self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()
@ -80,3 +87,13 @@ class TestSSDP(AsyncioTestCase):
self.assertEqual(reply.encode(), self.reply_packet.encode()) self.assertEqual(reply.encode(), self.reply_packet.encode())
self.assertEqual(args, self.successful_args) self.assertEqual(args, self.successful_args)
async def test_packets_sent_fuzzy_m_search_ignore_invalid_datagram_replies(self):
sent = []
with self.assertRaises(UPnPError):
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent,
add_potato_datagrams=True):
await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
self.assertListEqual(sent, self.byte_packets)

View file

@ -21,6 +21,139 @@ m_search_cli_result = """{
}\n""" }\n"""
m_search_help_msg = """aioupnp [-h] [--debug_logging] m_search [--lan_address=<str>] [--gateway_address=<str>]
[--timeout=<int>] [--unicast] [--interface_name=<str>] [--<header key>=<header value>, ...]
Perform a M-SEARCH for a upnp gateway.
:param lan_address: (str) the local interface ipv4 address
:param gateway_address: (str) the gateway ipv4 address
:param timeout: (int) m search timeout
:param unicast: (bool) use unicast
:param interface_name: (str) name of the network interface
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
:return: {
'lan_address': (str) lan address,
'gateway_address': (str) gateway address,
'm_search_kwargs': (str) equivalent igd_args ,
'discover_reply': (dict) SSDP response datagram
}\n
"""
expected_usage = """aioupnp [-h] [--debug_logging] [--interface=<interface>] [--gateway_address=<gateway_address>]
[--lan_address=<lan_address>] [--timeout=<timeout>] [(--<header_key>=<value>)...]
If m-search headers are provided as keyword arguments all of the headers to be used must be provided,
in the order they are to be used. For example:
aioupnp --HOST=239.255.255.250:1900 --MAN="ssdp:discover" --MX=1 --ST=upnp:rootdevice m_search
Commands:
m_search | get_external_ip | add_port_mapping | get_port_mapping_by_index | get_redirects |
get_specific_port_mapping | delete_port_mapping | get_next_mapping
For help with a specific command: aioupnp help <command>
"""
expected_get_external_ip_usage = """aioupnp [-h] [--debug_logging] get_external_ip
Get the external ip address from the gateway
:return: (str) external ip
"""
expected_add_port_mapping_usage = """aioupnp [-h] [--debug_logging] add_port_mapping [--external_port=<int>] [--protocol=<str>]
[--internal_port=<int>] [--lan_address=<str>] [--description=<str>]
Add a new port mapping
:param external_port: (int) external port to map
:param protocol: (str) UDP | TCP
:param internal_port: (int) internal port
:param lan_address: (str) internal lan address
:param description: (str) mapping description
:return: None
"""
expected_get_next_mapping_usage = """aioupnp [-h] [--debug_logging] get_next_mapping [--port=<int>] [--protocol=<str>]
[--description=<str>] [--internal_port=<typing.Union[int, NoneType]>]
Get a new port mapping. If the requested port is not available, increment until the next free port is mapped
:param port: (int) external port
:param protocol: (str) UDP | TCP
:param description: (str) mapping description
:param internal_port: (int) internal port
:return: (int) mapped port
"""
expected_delete_port_mapping_usage = """aioupnp [-h] [--debug_logging] delete_port_mapping [--external_port=<int>] [--protocol=<str>]
Delete a port mapping
:param external_port: (int) port number of mapping
:param protocol: (str) TCP | UDP
:return: None
"""
expected_get_specific_port_mapping_usage = """aioupnp [-h] [--debug_logging] get_specific_port_mapping [--external_port=<int>] [--protocol=<str>]
Get information about a port mapping by port number and protocol
:param external_port: (int) port number
:param protocol: (str) UDP | TCP
:return: NamedTuple[
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
"""
expected_get_redirects_usage = """aioupnp [-h] [--debug_logging] get_redirects
Get information about all mapped ports
:return: List[
NamedTuple[
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
]
"""
expected_get_port_mapping_by_index_usage = """aioupnp [-h] [--debug_logging] get_port_mapping_by_index [--index=<int>]
Get information about a port mapping by index number
:param index: (int) mapping index number
:return: NamedTuple[
gateway_address: str
external_port: int
protocol: str
internal_port: int
lan_address: str
enabled: bool
description: str
lease_time: int
]
"""
class TestCLI(AsyncioTestCase): class TestCLI(AsyncioTestCase):
gateway_address = "10.0.0.1" gateway_address = "10.0.0.1"
soap_port = 49152 soap_port = 49152
@ -101,3 +234,117 @@ class TestCLI(AsyncioTestCase):
self.loop self.loop
) )
self.assertEqual(m_search_cli_result, actual_output.getvalue()) self.assertEqual(m_search_cli_result, actual_output.getvalue())
def test_usage(self):
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help'],
self.loop
)
self.assertEqual(expected_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'test'],
self.loop
)
self.assertEqual(expected_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'test', 'help'],
self.loop
)
self.assertEqual("aioupnp encountered an error: \"test\" is not a recognized command\n", actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'test'],
self.loop
)
self.assertEqual("aioupnp encountered an error: \"test\" is not a recognized command\n", actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None],
self.loop
)
self.assertEqual(expected_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, "--something=test"],
self.loop
)
self.assertEqual("no command given\n" + expected_usage, actual_output.getvalue())
def test_commands_help(self):
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'm-search'],
self.loop
)
self.assertEqual(m_search_help_msg, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get-external-ip'],
self.loop
)
self.assertEqual(expected_get_external_ip_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'add-port-mapping'],
self.loop
)
self.assertEqual(expected_add_port_mapping_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get-next-mapping'],
self.loop
)
self.assertEqual(expected_get_next_mapping_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'delete_port_mapping'],
self.loop
)
self.assertEqual(expected_delete_port_mapping_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get_specific_port_mapping'],
self.loop
)
self.assertEqual(expected_get_specific_port_mapping_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get_redirects'],
self.loop
)
self.assertEqual(expected_get_redirects_usage, actual_output.getvalue())
actual_output = StringIO()
with contextlib.redirect_stdout(actual_output):
main(
[None, 'help', 'get_port_mapping_by_index'],
self.loop
)
self.assertEqual(expected_get_port_mapping_by_index_usage, actual_output.getvalue())

View file

@ -4,17 +4,7 @@ from aioupnp.upnp import UPnP
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway from aioupnp.gateway import Gateway
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.commands import GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse
class TestGetAnnotations(AsyncioTestCase):
def test_get_annotations(self):
expected = {
'NewRemoteHost': str, 'NewExternalPort': int, 'NewProtocol': str, 'NewInternalPort': int,
'NewInternalClient': str, 'NewEnabled': int, 'NewPortMappingDescription': str,
'NewLeaseDuration': str, 'return': None
}
self.assertDictEqual(expected, UPnP.get_annotations('AddPortMapping'))
class UPnPCommandTestCase(AsyncioTestCase): class UPnPCommandTestCase(AsyncioTestCase):
@ -76,7 +66,8 @@ class TestGetGenericPortMappingEntry(UPnPCommandTestCase):
await gateway.discover_commands(self.loop) await gateway.discover_commands(self.loop)
upnp = UPnP(self.client_address, self.gateway_address, gateway) upnp = UPnP(self.client_address, self.gateway_address, gateway)
result = await upnp.get_port_mapping_by_index(0) result = await upnp.get_port_mapping_by_index(0)
self.assertEqual((None, 9308, 'UDP', 9308, "11.2.3.44", True, "11.2.3.44:9308 to 9308 (UDP)", 0), result) self.assertEqual(GetGenericPortMappingEntryResponse(None, 9308, 'UDP', 9308, "11.2.3.44", True,
"11.2.3.44:9308 to 9308 (UDP)", 0), result)
class TestGetNextPortMapping(UPnPCommandTestCase): class TestGetNextPortMapping(UPnPCommandTestCase):
@ -120,6 +111,9 @@ class TestGetSpecificPortMapping(UPnPCommandTestCase):
await upnp.get_specific_port_mapping(1000, 'UDP') await upnp.get_specific_port_mapping(1000, 'UDP')
except UPnPError: except UPnPError:
result = await upnp.get_specific_port_mapping(9308, 'UDP') result = await upnp.get_specific_port_mapping(9308, 'UDP')
self.assertEqual((9308, '11.2.3.55', True, '11.2.3.55:9308 to 9308 (UDP)', 0), result) self.assertEqual(
GetSpecificPortMappingEntryResponse(9308, '11.2.3.55', True, '11.2.3.55:9308 to 9308 (UDP)', 0),
result
)
else: else:
self.assertTrue(False) self.assertTrue(False)