more mypy+coverage refactoring
This commit is contained in:
parent
e2ad340868
commit
5356b78e77
11 changed files with 637 additions and 244 deletions
|
@ -4,8 +4,7 @@ import logging
|
|||
import textwrap
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
from aioupnp.upnp import run_cli, UPnP
|
||||
from aioupnp.commands import SOAPCommands
|
||||
from aioupnp.upnp import run_cli, UPnP, cli_commands
|
||||
|
||||
log = logging.getLogger("aioupnp")
|
||||
handler = logging.StreamHandler()
|
||||
|
@ -20,36 +19,53 @@ base_usage = "\n".join(textwrap.wrap(
|
|||
|
||||
|
||||
def get_help(command: str) -> str:
|
||||
annotations = UPnP.get_annotations(command)
|
||||
params = command + " " + " ".join(["[--%s=<%s>]" % (k, str(v)) for k, v in annotations.items() if k != 'return'])
|
||||
return base_usage + "\n".join(
|
||||
textwrap.wrap(params, 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False)
|
||||
)
|
||||
annotations, doc = UPnP.get_annotations(command)
|
||||
doc = doc or ""
|
||||
|
||||
arg_strs = []
|
||||
for k, v in annotations.items():
|
||||
if k not in ['return', 'igd_args', 'loop']:
|
||||
t = str(v) if not hasattr(v, "__name__") else v.__name__
|
||||
if t == 'bool':
|
||||
arg_strs.append(f"[--{k}]")
|
||||
else:
|
||||
arg_strs.append(f"[--{k}=<{t}>]")
|
||||
elif k == 'igd_args':
|
||||
arg_strs.append(f"[--<header key>=<header value>, ...]")
|
||||
|
||||
params = " ".join(arg_strs)
|
||||
usage = "\n".join(textwrap.wrap(
|
||||
f"aioupnp [-h] [--debug_logging] {command} {params}",
|
||||
100, subsequent_indent=' ', break_long_words=False)) + "\n"
|
||||
|
||||
return usage + textwrap.dedent(doc)
|
||||
|
||||
|
||||
def main(argv: typing.Optional[typing.List[typing.Optional[str]]] = None,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None) -> int:
|
||||
argv = argv or list(sys.argv)
|
||||
commands = list(SOAPCommands.SOAP_COMMANDS)
|
||||
help_str = "\n".join(textwrap.wrap(
|
||||
" | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
|
||||
" | ".join(cli_commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
|
||||
))
|
||||
|
||||
usage = \
|
||||
"\n%s\n" \
|
||||
"%s\n" \
|
||||
"If m-search headers are provided as keyword arguments all of the headers to be used must be provided,\n" \
|
||||
"in the order they are to be used. For example:\n" \
|
||||
" aioupnp --HOST=239.255.255.250:1900 --MAN=\"ssdp:discover\" --MX=1 --ST=upnp:rootdevice m_search\n\n" \
|
||||
"Commands:\n" \
|
||||
"%s\n\n" \
|
||||
"For help with a specific command:" \
|
||||
" aioupnp help <command>\n" % (base_usage, help_str)
|
||||
" aioupnp help <command>" % (base_usage, help_str)
|
||||
|
||||
args: typing.List[str] = [str(arg) for arg in argv[1:]]
|
||||
if not args:
|
||||
print(usage)
|
||||
return 0
|
||||
if args[0] in ['help', '-h', '--help']:
|
||||
if len(args) > 1:
|
||||
if args[1] in commands:
|
||||
print(get_help(args[1]))
|
||||
if args[1].replace("-", "_") in cli_commands:
|
||||
print(get_help(args[1].replace("-", "_")))
|
||||
return 0
|
||||
print(usage)
|
||||
return 0
|
||||
|
|
|
@ -2,43 +2,64 @@ import asyncio
|
|||
import time
|
||||
import typing
|
||||
import logging
|
||||
from typing import Tuple
|
||||
from aioupnp.protocols.scpd import scpd_post
|
||||
from aioupnp.device import Service
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def soap_optional_str(x: typing.Optional[str]) -> typing.Optional[str]:
|
||||
return x if x is not None and str(x).lower() not in ['none', 'nil'] else None
|
||||
def soap_optional_str(x: typing.Optional[typing.Union[str, int]]) -> typing.Optional[str]:
|
||||
return str(x) if x is not None and str(x).lower() not in ['none', 'nil'] else None
|
||||
|
||||
|
||||
def soap_bool(x: typing.Optional[str]) -> bool:
|
||||
def soap_bool(x: typing.Optional[typing.Union[str, int]]) -> bool:
|
||||
return False if not x or str(x).lower() in ['false', 'False'] else True
|
||||
|
||||
|
||||
def recast_single_result(t: type, result: typing.Any) -> typing.Optional[typing.Union[str, int, float, bool]]:
|
||||
if t is bool:
|
||||
return soap_bool(result)
|
||||
if t is str:
|
||||
return soap_optional_str(result)
|
||||
return t(result)
|
||||
class GetSpecificPortMappingEntryResponse(typing.NamedTuple):
|
||||
internal_port: int
|
||||
lan_address: str
|
||||
enabled: bool
|
||||
description: str
|
||||
lease_time: int
|
||||
|
||||
|
||||
class GetGenericPortMappingEntryResponse(typing.NamedTuple):
|
||||
gateway_address: str
|
||||
external_port: int
|
||||
protocol: str
|
||||
internal_port: int
|
||||
lan_address: str
|
||||
enabled: bool
|
||||
description: str
|
||||
lease_time: int
|
||||
|
||||
|
||||
def recast_return(return_annotation, result: typing.Dict[str, typing.Union[int, str]],
|
||||
result_keys: typing.List[str]) -> typing.Tuple:
|
||||
if return_annotation is None or len(result_keys) == 0:
|
||||
return ()
|
||||
result_keys: typing.List[str]) -> typing.Optional[
|
||||
typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]:
|
||||
if len(result_keys) == 1:
|
||||
assert len(result_keys) == 1
|
||||
single_result = result[result_keys[0]]
|
||||
return (recast_single_result(return_annotation, single_result), )
|
||||
annotated_args: typing.List[type] = list(return_annotation.__args__)
|
||||
assert len(annotated_args) == len(result_keys)
|
||||
recast_results: typing.List[typing.Optional[typing.Union[str, int, float, bool]]] = []
|
||||
for type_annotation, result_key in zip(annotated_args, result_keys):
|
||||
recast_results.append(recast_single_result(type_annotation, result.get(result_key, None)))
|
||||
return tuple(recast_results)
|
||||
if return_annotation is bool:
|
||||
return soap_bool(single_result)
|
||||
if return_annotation is str:
|
||||
return soap_optional_str(single_result)
|
||||
return int(result[result_keys[0]]) if result_keys[0] in result else None
|
||||
elif return_annotation in [GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse]:
|
||||
arg_types: typing.Dict[str, typing.Type[typing.Any]] = return_annotation._field_types
|
||||
assert len(arg_types) == len(result_keys)
|
||||
recast_results: typing.Dict[str, typing.Optional[typing.Union[str, int, bool]]] = {}
|
||||
for i, (field_name, result_key) in enumerate(zip(arg_types, result_keys)):
|
||||
result_field_name = result_keys[i]
|
||||
field_type = arg_types[field_name]
|
||||
if field_type is bool:
|
||||
recast_results[field_name] = soap_bool(result.get(result_field_name, None))
|
||||
elif field_type is str:
|
||||
recast_results[field_name] = soap_optional_str(result.get(result_field_name, None))
|
||||
elif field_type is int:
|
||||
recast_results[field_name] = int(result[result_field_name]) if result_field_name in result else None
|
||||
return return_annotation(**recast_results)
|
||||
return None
|
||||
|
||||
|
||||
class SOAPCommands:
|
||||
|
@ -88,7 +109,10 @@ class SOAPCommands:
|
|||
self._base_address = base_address
|
||||
self._port = port
|
||||
self._requests: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
|
||||
typing.Tuple, typing.Optional[Exception], float]] = []
|
||||
typing.Optional[typing.Union[str, int, bool,
|
||||
GetSpecificPortMappingEntryResponse,
|
||||
GetGenericPortMappingEntryResponse]],
|
||||
typing.Optional[Exception], float]] = []
|
||||
|
||||
def is_registered(self, name: str) -> bool:
|
||||
if name not in self.SOAP_COMMANDS:
|
||||
|
@ -112,7 +136,8 @@ class SOAPCommands:
|
|||
input_names: typing.List[str] = self._registered[service][name][0]
|
||||
output_names: typing.List[str] = self._registered[service][name][1]
|
||||
|
||||
async def wrapper(**kwargs: typing.Any) -> typing.Tuple:
|
||||
async def wrapper(**kwargs: typing.Any) -> typing.Optional[
|
||||
typing.Union[str, int, bool, GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse]]:
|
||||
|
||||
assert service.controlURL is not None
|
||||
assert service.serviceType is not None
|
||||
|
@ -122,11 +147,10 @@ class SOAPCommands:
|
|||
)
|
||||
if err is not None:
|
||||
assert isinstance(xml_bytes, bytes)
|
||||
self._requests.append((name, kwargs, xml_bytes, (), err, time.time()))
|
||||
self._requests.append((name, kwargs, xml_bytes, None, err, time.time()))
|
||||
raise err
|
||||
assert 'return' in annotations
|
||||
result = recast_return(annotations['return'], response, output_names)
|
||||
|
||||
self._requests.append((name, kwargs, xml_bytes, result, None, time.time()))
|
||||
return result
|
||||
|
||||
|
@ -161,8 +185,7 @@ class SOAPCommands:
|
|||
)
|
||||
return None
|
||||
|
||||
async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> Tuple[str, int, str, int, str,
|
||||
bool, str, int]:
|
||||
async def GetGenericPortMappingEntry(self, NewPortMappingIndex: int) -> GetGenericPortMappingEntryResponse:
|
||||
"""
|
||||
Returns (NewRemoteHost, NewExternalPort, NewProtocol, NewInternalPort, NewInternalClient, NewEnabled,
|
||||
NewPortMappingDescription, NewLeaseDuration)
|
||||
|
@ -171,19 +194,19 @@ class SOAPCommands:
|
|||
if not self.is_registered(name):
|
||||
raise NotImplementedError()
|
||||
assert name in self._wrappers_kwargs
|
||||
result: Tuple[str, int, str, int, str, bool, str, int] = await self._wrappers_kwargs[name](
|
||||
result: GetGenericPortMappingEntryResponse = await self._wrappers_kwargs[name](
|
||||
NewPortMappingIndex=NewPortMappingIndex
|
||||
)
|
||||
return result
|
||||
|
||||
async def GetSpecificPortMappingEntry(self, NewRemoteHost: str, NewExternalPort: int,
|
||||
NewProtocol: str) -> Tuple[int, str, bool, str, int]:
|
||||
NewProtocol: str) -> GetSpecificPortMappingEntryResponse:
|
||||
"""Returns (NewInternalPort, NewInternalClient, NewEnabled, NewPortMappingDescription, NewLeaseDuration)"""
|
||||
name = "GetSpecificPortMappingEntry"
|
||||
if not self.is_registered(name):
|
||||
raise NotImplementedError()
|
||||
assert name in self._wrappers_kwargs
|
||||
result: Tuple[int, str, bool, str, int] = await self._wrappers_kwargs[name](
|
||||
result: GetSpecificPortMappingEntryResponse = await self._wrappers_kwargs[name](
|
||||
NewRemoteHost=NewRemoteHost, NewExternalPort=NewExternalPort, NewProtocol=NewProtocol
|
||||
)
|
||||
return result
|
||||
|
@ -205,8 +228,8 @@ class SOAPCommands:
|
|||
if not self.is_registered(name):
|
||||
raise NotImplementedError()
|
||||
assert name in self._wrappers_no_args
|
||||
result: Tuple[str] = await self._wrappers_no_args[name]()
|
||||
return result[0]
|
||||
result: str = await self._wrappers_no_args[name]()
|
||||
return result
|
||||
|
||||
# async def GetNATRSIPStatus(self) -> Tuple[bool, bool]:
|
||||
# """Returns (NewRSIPAvailable, NewNATEnabled)"""
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import typing
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List
|
||||
from aioupnp.util import get_dict_val_case_insensitive
|
||||
from aioupnp.constants import SPEC_VERSION, SERVICE
|
||||
from aioupnp.commands import SOAPCommands
|
||||
|
@ -103,74 +103,81 @@ class Gateway:
|
|||
self._registered_commands: Dict[str, str] = {}
|
||||
self.commands = SOAPCommands(self._loop, self.base_ip, self.port)
|
||||
|
||||
def gateway_descriptor(self) -> dict:
|
||||
r = {
|
||||
'server': self.server.decode(),
|
||||
'urlBase': self.url_base,
|
||||
'location': self.location.decode(),
|
||||
"specVersion": self.spec_version,
|
||||
'usn': self.usn.decode(),
|
||||
'urn': self.urn.decode(),
|
||||
}
|
||||
return r
|
||||
# def gateway_descriptor(self) -> dict:
|
||||
# r = {
|
||||
# 'server': self.server.decode(),
|
||||
# 'urlBase': self.url_base,
|
||||
# 'location': self.location.decode(),
|
||||
# "specVersion": self.spec_version,
|
||||
# 'usn': self.usn.decode(),
|
||||
# 'urn': self.urn.decode(),
|
||||
# }
|
||||
# return r
|
||||
|
||||
@property
|
||||
def manufacturer_string(self) -> str:
|
||||
if not self.devices:
|
||||
return "UNKNOWN GATEWAY"
|
||||
devices: typing.List[Device] = list(self.devices.values())
|
||||
device = devices[0]
|
||||
return f"{device.manufacturer} {device.modelName}"
|
||||
manufacturer_string = "UNKNOWN GATEWAY"
|
||||
if self.devices:
|
||||
devices: typing.List[Device] = list(self.devices.values())
|
||||
device = devices[0]
|
||||
manufacturer_string = f"{device.manufacturer} {device.modelName}"
|
||||
return manufacturer_string
|
||||
|
||||
@property
|
||||
def services(self) -> Dict[str, Service]:
|
||||
if not self._device:
|
||||
return {}
|
||||
return {str(service.serviceType): service for service in self._services}
|
||||
services: Dict[str, Service] = {}
|
||||
if self._services:
|
||||
for service in self._services:
|
||||
if service.serviceType is not None:
|
||||
services[service.serviceType] = service
|
||||
return services
|
||||
|
||||
@property
|
||||
def devices(self) -> Dict:
|
||||
if not self._device:
|
||||
return {}
|
||||
return {device.udn: device for device in self._devices}
|
||||
def devices(self) -> Dict[str, Device]:
|
||||
devices: Dict[str, Device] = {}
|
||||
if self._device:
|
||||
for device in self._devices:
|
||||
if device.udn is not None:
|
||||
devices[device.udn] = device
|
||||
return devices
|
||||
|
||||
def get_service(self, service_type: str) -> typing.Optional[Service]:
|
||||
for service in self._services:
|
||||
if service.serviceType and service.serviceType.lower() == service_type.lower():
|
||||
return service
|
||||
return None
|
||||
# def get_service(self, service_type: str) -> typing.Optional[Service]:
|
||||
# for service in self._services:
|
||||
# if service.serviceType and service.serviceType.lower() == service_type.lower():
|
||||
# return service
|
||||
# return None
|
||||
|
||||
@property
|
||||
def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
|
||||
typing.Optional[typing.Tuple],
|
||||
typing.Optional[Exception], float]]:
|
||||
soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
|
||||
typing.Optional[typing.Tuple],
|
||||
typing.Optional[Exception], float]] = []
|
||||
soap_call_infos.extend([
|
||||
(name, request_args, raw_response, decoded_response, soap_error, ts)
|
||||
for (
|
||||
name, request_args, raw_response, decoded_response, soap_error, ts
|
||||
) in self.commands._requests
|
||||
])
|
||||
soap_call_infos.sort(key=lambda x: x[5])
|
||||
return soap_call_infos
|
||||
# @property
|
||||
# def soap_requests(self) -> typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
|
||||
# typing.Optional[typing.Tuple],
|
||||
# typing.Optional[Exception], float]]:
|
||||
# soap_call_infos: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any], bytes,
|
||||
# typing.Optional[typing.Tuple],
|
||||
# typing.Optional[Exception], float]] = []
|
||||
# soap_call_infos.extend([
|
||||
# (name, request_args, raw_response, decoded_response, soap_error, ts)
|
||||
# for (
|
||||
# name, request_args, raw_response, decoded_response, soap_error, ts
|
||||
# ) in self.commands._requests
|
||||
# ])
|
||||
# soap_call_infos.sort(key=lambda x: x[5])
|
||||
# return soap_call_infos
|
||||
|
||||
def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]:
|
||||
return {
|
||||
'manufacturer_string': self.manufacturer_string,
|
||||
'gateway_address': self.base_ip,
|
||||
'gateway_descriptor': self.gateway_descriptor(),
|
||||
'gateway_xml': self._xml_response,
|
||||
'services_xml': self._service_descriptors,
|
||||
'services': {service.SCPDURL: service.as_dict() for service in self._services},
|
||||
'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()],
|
||||
'reply': self._ok_packet.as_dict(),
|
||||
'soap_port': self.port,
|
||||
'registered_soap_commands': self._registered_commands,
|
||||
'unsupported_soap_commands': self._unsupported_actions,
|
||||
'soap_requests': self.soap_requests
|
||||
}
|
||||
# def debug_gateway(self) -> Dict[str, Union[str, bytes, int, Dict, List]]:
|
||||
# return {
|
||||
# 'manufacturer_string': self.manufacturer_string,
|
||||
# 'gateway_address': self.base_ip,
|
||||
# 'gateway_descriptor': self.gateway_descriptor(),
|
||||
# 'gateway_xml': self._xml_response,
|
||||
# 'services_xml': self._service_descriptors,
|
||||
# 'services': {service.SCPDURL: service.as_dict() for service in self._services},
|
||||
# 'm_search_args': [(k, v) for (k, v) in self._m_search_args.items()],
|
||||
# 'reply': self._ok_packet.as_dict(),
|
||||
# 'soap_port': self.port,
|
||||
# 'registered_soap_commands': self._registered_commands,
|
||||
# 'unsupported_soap_commands': self._unsupported_actions,
|
||||
# 'soap_requests': self.soap_requests
|
||||
# }
|
||||
|
||||
@classmethod
|
||||
async def _discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
|
||||
|
@ -206,7 +213,7 @@ class Gateway:
|
|||
ignored.add(datagram.location)
|
||||
continue
|
||||
else:
|
||||
log.debug('found gateway device %s', datagram.location)
|
||||
log.debug('found gateway %s at %s', gateway.manufacturer_string or "device", datagram.location)
|
||||
return gateway
|
||||
except (asyncio.TimeoutError, UPnPError) as err:
|
||||
assert datagram.location is not None
|
||||
|
|
|
@ -2,11 +2,11 @@ import struct
|
|||
import socket
|
||||
import typing
|
||||
from asyncio.protocols import DatagramProtocol
|
||||
from asyncio.transports import BaseTransport
|
||||
from asyncio.transports import DatagramTransport
|
||||
from unittest import mock
|
||||
|
||||
|
||||
def _get_sock(transport: typing.Optional[BaseTransport]) -> typing.Optional[socket.socket]:
|
||||
def _get_sock(transport: typing.Optional[DatagramTransport]) -> typing.Optional[socket.socket]:
|
||||
if transport is None or not hasattr(transport, "_extra"):
|
||||
return None
|
||||
sock: typing.Optional[socket.socket] = transport.get_extra_info('socket', None)
|
||||
|
@ -18,7 +18,7 @@ class MulticastProtocol(DatagramProtocol):
|
|||
def __init__(self, multicast_address: str, bind_address: str) -> None:
|
||||
self.multicast_address = multicast_address
|
||||
self.bind_address = bind_address
|
||||
self.transport: typing.Optional[BaseTransport] = None
|
||||
self.transport: typing.Optional[DatagramTransport] = None
|
||||
|
||||
@property
|
||||
def sock(self) -> typing.Optional[socket.socket]:
|
||||
|
@ -26,40 +26,37 @@ class MulticastProtocol(DatagramProtocol):
|
|||
|
||||
def get_ttl(self) -> int:
|
||||
sock = self.sock
|
||||
if not sock:
|
||||
raise ValueError("not connected")
|
||||
return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
|
||||
if sock:
|
||||
return sock.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
|
||||
return 0
|
||||
|
||||
def set_ttl(self, ttl: int = 1) -> None:
|
||||
sock = self.sock
|
||||
if not sock:
|
||||
return None
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
|
||||
)
|
||||
if sock:
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack('b', ttl)
|
||||
)
|
||||
return None
|
||||
|
||||
def join_group(self, multicast_address: str, bind_address: str) -> None:
|
||||
sock = self.sock
|
||||
if not sock:
|
||||
return None
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
|
||||
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
|
||||
)
|
||||
if sock:
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
|
||||
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
|
||||
)
|
||||
return None
|
||||
|
||||
def leave_group(self, multicast_address: str, bind_address: str) -> None:
|
||||
sock = self.sock
|
||||
if not sock:
|
||||
raise ValueError("not connected")
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
|
||||
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
|
||||
)
|
||||
if sock:
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
|
||||
socket.inet_aton(multicast_address) + socket.inet_aton(bind_address)
|
||||
)
|
||||
return None
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
def connection_made(self, transport: DatagramTransport) -> None: # type: ignore
|
||||
self.transport = transport
|
||||
return None
|
||||
|
||||
|
|
|
@ -28,19 +28,14 @@ class SSDPProtocol(MulticastProtocol):
|
|||
self.notifications: typing.List[SSDPDatagram] = []
|
||||
self.connected = asyncio.Event(loop=self.loop)
|
||||
|
||||
def connection_made(self, transport) -> None:
|
||||
# assert isinstance(transport, asyncio.DatagramTransport), str(type(transport))
|
||||
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
|
||||
super().connection_made(transport)
|
||||
self.connected.set()
|
||||
return None
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.transport:
|
||||
try:
|
||||
self.leave_group(self.multicast_address, self.bind_address)
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception:
|
||||
log.exception("unexpected error leaving multicast group")
|
||||
self.leave_group(self.multicast_address, self.bind_address)
|
||||
self.transport.close()
|
||||
self.connected.clear()
|
||||
while self._pending_searches:
|
||||
|
@ -50,29 +45,30 @@ class SSDPProtocol(MulticastProtocol):
|
|||
return None
|
||||
|
||||
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
|
||||
if packet.location in self._ignored:
|
||||
return None
|
||||
# TODO: fix this
|
||||
tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
|
||||
set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = []
|
||||
while len(self._pending_searches):
|
||||
t: typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle] = self._pending_searches.pop()
|
||||
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
|
||||
f = t[2]
|
||||
if f not in set_futures:
|
||||
set_futures.append(f)
|
||||
if not f.done():
|
||||
f.set_result(packet)
|
||||
elif t[2] not in set_futures:
|
||||
tmp.append(t)
|
||||
while tmp:
|
||||
self._pending_searches.append(tmp.pop())
|
||||
if packet.location not in self._ignored:
|
||||
# TODO: fix this
|
||||
tmp: typing.List[typing.Tuple[str, str, 'asyncio.Future[SSDPDatagram]', asyncio.Handle]] = []
|
||||
set_futures: typing.List['asyncio.Future[SSDPDatagram]'] = []
|
||||
while len(self._pending_searches):
|
||||
t = self._pending_searches.pop()
|
||||
if (address == t[0]) and (t[1] in [packet.st, "upnp:rootdevice"]):
|
||||
f = t[2]
|
||||
if f not in set_futures:
|
||||
set_futures.append(f)
|
||||
if not f.done():
|
||||
f.set_result(packet)
|
||||
elif t[2] not in set_futures:
|
||||
tmp.append(t)
|
||||
while tmp:
|
||||
self._pending_searches.append(tmp.pop())
|
||||
return None
|
||||
|
||||
def _send_m_search(self, address: str, packet: SSDPDatagram) -> None:
|
||||
def _send_m_search(self, address: str, packet: SSDPDatagram, fut: 'asyncio.Future[SSDPDatagram]') -> None:
|
||||
dest = address if self._unicast else SSDP_IP_ADDRESS
|
||||
if not self.transport:
|
||||
raise UPnPError("SSDP transport not connected")
|
||||
if not fut.done():
|
||||
fut.set_exception(UPnPError("SSDP transport not connected"))
|
||||
return None
|
||||
log.debug("send m search to %s: %s", dest, packet.st)
|
||||
self.transport.sendto(packet.encode().encode(), (dest, SSDP_PORT))
|
||||
return None
|
||||
|
@ -84,7 +80,7 @@ class SSDPProtocol(MulticastProtocol):
|
|||
packet = SSDPDatagram("M-SEARCH", datagram)
|
||||
assert packet.st is not None
|
||||
self._pending_searches.append(
|
||||
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet))
|
||||
(address, packet.st, fut, self.loop.call_soon(self._send_m_search, address, packet, fut))
|
||||
)
|
||||
return await asyncio.wait_for(fut, timeout)
|
||||
|
||||
|
@ -95,8 +91,8 @@ class SSDPProtocol(MulticastProtocol):
|
|||
packet = SSDPDatagram.decode(data)
|
||||
log.debug("decoded packet from %s:%i: %s", addr[0], addr[1], packet)
|
||||
except UPnPError as err:
|
||||
log.error("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
|
||||
binascii.hexlify(data))
|
||||
log.warning("failed to decode SSDP packet from %s:%i (%s): %s", addr[0], addr[1], err,
|
||||
binascii.hexlify(data))
|
||||
return None
|
||||
|
||||
if packet._packet_type == packet._OK:
|
||||
|
@ -131,19 +127,13 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: typing.Optio
|
|||
listen_result: typing.Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
|
||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored, unicast), sock=sock
|
||||
)
|
||||
transport = listen_result[0]
|
||||
protocol = listen_result[1]
|
||||
assert isinstance(protocol, SSDPProtocol)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise UPnPError(err)
|
||||
try:
|
||||
else:
|
||||
protocol.join_group(protocol.multicast_address, protocol.bind_address)
|
||||
protocol.set_ttl(1)
|
||||
except Exception as err:
|
||||
protocol.disconnect()
|
||||
raise UPnPError(err)
|
||||
|
||||
return protocol, gateway_address, lan_address
|
||||
|
||||
|
||||
|
@ -178,10 +168,11 @@ async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int =
|
|||
log.debug("sending batch of %i M-SEARCH attempts", batch_size)
|
||||
try:
|
||||
await protocol.m_search(gateway_address, batch_timeout, args)
|
||||
protocol.disconnect()
|
||||
return args
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
else:
|
||||
protocol.disconnect()
|
||||
return args
|
||||
protocol.disconnect()
|
||||
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
|
||||
|
||||
|
|
181
aioupnp/upnp.py
181
aioupnp/upnp.py
|
@ -5,13 +5,14 @@ import logging
|
|||
import json
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple, Dict, List, Union, Optional
|
||||
from typing import Tuple, Dict, List, Union, Optional, Any
|
||||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.gateway import Gateway
|
||||
from aioupnp.interfaces import get_gateway_and_lan_addresses
|
||||
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
from aioupnp.commands import SOAPCommands
|
||||
from aioupnp.commands import GetGenericPortMappingEntryResponse, GetSpecificPortMappingEntryResponse
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -31,11 +32,27 @@ class UPnP:
|
|||
self.gateway = gateway
|
||||
|
||||
@classmethod
|
||||
def get_annotations(cls, command: str) -> Dict[str, type]:
|
||||
return getattr(SOAPCommands, command).__annotations__
|
||||
def get_annotations(cls, command: str) -> Tuple[Dict[str, Any], Optional[str]]:
|
||||
if command == "m_search":
|
||||
return cls.m_search.__annotations__, cls.m_search.__doc__
|
||||
if command == "get_external_ip":
|
||||
return cls.get_external_ip.__annotations__, cls.get_external_ip.__doc__
|
||||
if command == "add_port_mapping":
|
||||
return cls.add_port_mapping.__annotations__, cls.add_port_mapping.__doc__
|
||||
if command == "get_port_mapping_by_index":
|
||||
return cls.get_port_mapping_by_index.__annotations__, cls.get_port_mapping_by_index.__doc__
|
||||
if command == "get_redirects":
|
||||
return cls.get_redirects.__annotations__, cls.get_redirects.__doc__
|
||||
if command == "get_specific_port_mapping":
|
||||
return cls.get_specific_port_mapping.__annotations__, cls.get_specific_port_mapping.__doc__
|
||||
if command == "delete_port_mapping":
|
||||
return cls.delete_port_mapping.__annotations__, cls.delete_port_mapping.__doc__
|
||||
if command == "get_next_mapping":
|
||||
return cls.get_next_mapping.__annotations__, cls.get_next_mapping.__doc__
|
||||
raise AttributeError(command)
|
||||
|
||||
@classmethod
|
||||
def get_lan_and_gateway(cls, lan_address: str = '', gateway_address: str = '',
|
||||
@staticmethod
|
||||
def get_lan_and_gateway(lan_address: str = '', gateway_address: str = '',
|
||||
interface_name: str = 'default') -> Tuple[str, str]:
|
||||
if not lan_address or not gateway_address:
|
||||
gateway_addr, lan_addr = get_gateway_and_lan_addresses(interface_name)
|
||||
|
@ -55,10 +72,28 @@ class UPnP:
|
|||
|
||||
@classmethod
|
||||
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
|
||||
igd_args: Optional[Dict[str, Union[int, str]]] = None,
|
||||
unicast: bool = True, interface_name: str = 'default',
|
||||
igd_args: Optional[Dict[str, Union[str, int]]] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
) -> Dict[str, Union[str, Dict[str, Union[int, str]]]]:
|
||||
) -> Dict[str, Union[str, Dict[str, Union[str, int]]]]:
|
||||
"""
|
||||
Perform a M-SEARCH for a upnp gateway.
|
||||
|
||||
:param lan_address: (str) the local interface ipv4 address
|
||||
:param gateway_address: (str) the gateway ipv4 address
|
||||
:param timeout: (int) m search timeout
|
||||
:param unicast: (bool) use unicast
|
||||
:param interface_name: (str) name of the network interface
|
||||
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
|
||||
|
||||
:return: {
|
||||
'lan_address': (str) lan address,
|
||||
'gateway_address': (str) gateway address,
|
||||
'm_search_kwargs': (str) equivalent igd_args ,
|
||||
'discover_reply': (dict) SSDP response datagram
|
||||
}
|
||||
"""
|
||||
|
||||
if not lan_address or not gateway_address:
|
||||
try:
|
||||
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
|
||||
|
@ -79,10 +114,25 @@ class UPnP:
|
|||
}
|
||||
|
||||
async def get_external_ip(self) -> str:
|
||||
"""
|
||||
Get the external ip address from the gateway
|
||||
|
||||
:return: (str) external ip
|
||||
"""
|
||||
return await self.gateway.commands.GetExternalIPAddress()
|
||||
|
||||
async def add_port_mapping(self, external_port: int, protocol: str, internal_port: int, lan_address: str,
|
||||
description: str) -> None:
|
||||
"""
|
||||
Add a new port mapping
|
||||
|
||||
:param external_port: (int) external port to map
|
||||
:param protocol: (str) UDP | TCP
|
||||
:param internal_port: (int) internal port
|
||||
:param lan_address: (str) internal lan address
|
||||
:param description: (str) mapping description
|
||||
:return: None
|
||||
"""
|
||||
await self.gateway.commands.AddPortMapping(
|
||||
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol,
|
||||
NewInternalPort=internal_port, NewInternalClient=lan_address,
|
||||
|
@ -90,11 +140,42 @@ class UPnP:
|
|||
)
|
||||
return None
|
||||
|
||||
async def get_port_mapping_by_index(self, index: int) -> Tuple[str, int, str, int, str, bool, str, int]:
|
||||
async def get_port_mapping_by_index(self, index: int) -> GetGenericPortMappingEntryResponse:
|
||||
"""
|
||||
Get information about a port mapping by index number
|
||||
|
||||
:param index: (int) mapping index number
|
||||
:return: NamedTuple[
|
||||
gateway_address: str
|
||||
external_port: int
|
||||
protocol: str
|
||||
internal_port: int
|
||||
lan_address: str
|
||||
enabled: bool
|
||||
description: str
|
||||
lease_time: int
|
||||
]
|
||||
"""
|
||||
return await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
|
||||
|
||||
async def get_redirects(self) -> List[Tuple[str, int, str, int, str, bool, str, int]]:
|
||||
redirects: List[Tuple[str, int, str, int, str, bool, str, int]] = []
|
||||
async def get_redirects(self) -> List[GetGenericPortMappingEntryResponse]:
|
||||
"""
|
||||
Get information about all mapped ports
|
||||
|
||||
:return: List[
|
||||
NamedTuple[
|
||||
gateway_address: str
|
||||
external_port: int
|
||||
protocol: str
|
||||
internal_port: int
|
||||
lan_address: str
|
||||
enabled: bool
|
||||
description: str
|
||||
lease_time: int
|
||||
]
|
||||
]
|
||||
"""
|
||||
redirects: List[GetGenericPortMappingEntryResponse] = []
|
||||
cnt = 0
|
||||
try:
|
||||
redirect = await self.get_port_mapping_by_index(cnt)
|
||||
|
@ -109,11 +190,19 @@ class UPnP:
|
|||
break
|
||||
return redirects
|
||||
|
||||
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> Tuple[int, str, bool, str, int]:
|
||||
async def get_specific_port_mapping(self, external_port: int, protocol: str) -> GetSpecificPortMappingEntryResponse:
|
||||
"""
|
||||
:param external_port: (int) external port to listen on
|
||||
:param protocol: (str) 'UDP' | 'TCP'
|
||||
:return: (int) <internal port>, (str) <lan ip>, (bool) <enabled>, (str) <description>, (int) <lease time>
|
||||
Get information about a port mapping by port number and protocol
|
||||
|
||||
:param external_port: (int) port number
|
||||
:param protocol: (str) UDP | TCP
|
||||
:return: NamedTuple[
|
||||
internal_port: int
|
||||
lan_address: str
|
||||
enabled: bool
|
||||
description: str
|
||||
lease_time: int
|
||||
]
|
||||
"""
|
||||
return await self.gateway.commands.GetSpecificPortMappingEntry(
|
||||
NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
|
||||
|
@ -121,25 +210,31 @@ class UPnP:
|
|||
|
||||
async def delete_port_mapping(self, external_port: int, protocol: str) -> None:
|
||||
"""
|
||||
:param external_port: (int) external port to listen on
|
||||
:param protocol: (str) 'UDP' | 'TCP'
|
||||
Delete a port mapping
|
||||
|
||||
:param external_port: (int) port number of mapping
|
||||
:param protocol: (str) TCP | UDP
|
||||
:return: None
|
||||
"""
|
||||
await self.gateway.commands.DeletePortMapping(
|
||||
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def get_next_mapping(self, port: int, protocol: str, description: str,
|
||||
internal_port: Optional[int] = None) -> int:
|
||||
"""
|
||||
:param port: (int) external port to redirect from
|
||||
:param protocol: (str) 'UDP' | 'TCP'
|
||||
:param description: (str) mapping description
|
||||
:param internal_port: (int) internal port to redirect to
|
||||
Get a new port mapping. If the requested port is not available, increment until the next free port is mapped
|
||||
|
||||
:return: (int) <mapped port>
|
||||
:param port: (int) external port
|
||||
:param protocol: (str) UDP | TCP
|
||||
:param description: (str) mapping description
|
||||
:param internal_port: (int) internal port
|
||||
|
||||
:return: (int) mapped port
|
||||
"""
|
||||
|
||||
_internal_port = int(internal_port or port)
|
||||
requested_port = int(_internal_port)
|
||||
port = int(port)
|
||||
|
@ -264,22 +359,23 @@ class UPnP:
|
|||
# return await self.gateway.commands.GetActiveConnections()
|
||||
|
||||
|
||||
def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
|
||||
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
|
||||
unicast: bool = True, kwargs: Optional[Dict] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
"""
|
||||
:param method: the command name
|
||||
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
|
||||
:param lan_address: the ip address of the local interface
|
||||
:param gateway_address: the ip address of the gateway
|
||||
:param timeout: timeout, in seconds
|
||||
:param interface_name: name of the network interface, the default is aliased to 'default'
|
||||
:param kwargs: keyword arguments for the command
|
||||
:param loop: EventLoop, used for testing
|
||||
"""
|
||||
cli_commands = [
|
||||
'm_search',
|
||||
'get_external_ip',
|
||||
'add_port_mapping',
|
||||
'get_port_mapping_by_index',
|
||||
'get_redirects',
|
||||
'get_specific_port_mapping',
|
||||
'delete_port_mapping',
|
||||
'get_next_mapping'
|
||||
]
|
||||
|
||||
|
||||
def run_cli(method: str, igd_args: Dict[str, Union[bool, str, int]], lan_address: str = '',
|
||||
gateway_address: str = '', timeout: int = 30, interface_name: str = 'default',
|
||||
unicast: bool = True, kwargs: Optional[Dict[str, str]] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
|
||||
kwargs = kwargs or {}
|
||||
igd_args = igd_args
|
||||
timeout = int(timeout)
|
||||
|
@ -287,20 +383,9 @@ def run_cli(method, igd_args: Dict[str, Union[bool, str, int]], lan_address: str
|
|||
fut: 'asyncio.Future' = asyncio.Future(loop=loop)
|
||||
|
||||
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
|
||||
cli_commands = [
|
||||
'm_search',
|
||||
'get_external_ip',
|
||||
'add_port_mapping',
|
||||
'get_port_mapping_by_index',
|
||||
'get_redirects',
|
||||
'get_specific_port_mapping',
|
||||
'delete_port_mapping',
|
||||
'get_next_mapping'
|
||||
]
|
||||
|
||||
if method == 'm_search': # if we're only m_searching don't do any device discovery
|
||||
fn = lambda *_a, **_kw: UPnP.m_search(
|
||||
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
|
||||
lan_address, gateway_address, timeout, unicast, interface_name, igd_args, loop
|
||||
)
|
||||
else: # automatically discover the gateway
|
||||
try:
|
||||
|
|
|
@ -16,7 +16,7 @@ except ImportError:
|
|||
|
||||
@contextlib.contextmanager
|
||||
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
|
||||
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None):
|
||||
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False):
|
||||
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
|
||||
udp_replies = udp_replies or {}
|
||||
|
||||
|
@ -28,7 +28,11 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
|
|||
def _write(data):
|
||||
sent_tcp_packets.append(data)
|
||||
if data in tcp_replies:
|
||||
loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data])
|
||||
reply = tcp_replies[data]
|
||||
i = 0
|
||||
while i < len(reply):
|
||||
loop.call_later(tcp_delay_reply, p.data_received, reply[i:i+100])
|
||||
i += 100
|
||||
return
|
||||
else:
|
||||
pass
|
||||
|
@ -46,6 +50,11 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
|
|||
def sendto(p: asyncio.DatagramProtocol):
|
||||
def _sendto(data, addr):
|
||||
sent_udp_packets.append(data)
|
||||
loop.call_later(udp_delay_reply, p.datagram_received, data,
|
||||
(p.bind_address, 1900))
|
||||
if add_potato_datagrams:
|
||||
loop.call_soon(p.datagram_received, b'potato', ('?.?.?.?', 1900))
|
||||
|
||||
if (data, addr) in udp_replies:
|
||||
loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)],
|
||||
(udp_expected_addr, 1900))
|
||||
|
|
|
@ -1,23 +1,30 @@
|
|||
import unittest
|
||||
|
||||
from unittest import mock
|
||||
import socket
|
||||
import struct
|
||||
from asyncio import DatagramTransport
|
||||
from aioupnp.protocols.multicast import MulticastProtocol
|
||||
|
||||
|
||||
class TestMulticast(unittest.TestCase):
|
||||
def test_it(self):
|
||||
class none_socket:
|
||||
sock = None
|
||||
def test_multicast(self):
|
||||
_ttl = None
|
||||
mock_socket = mock.MagicMock(spec=socket.socket)
|
||||
def getsockopt(*_):
|
||||
return _ttl
|
||||
|
||||
def get(self, name, default=None):
|
||||
return default
|
||||
def setsockopt(a, b, ttl: bytes):
|
||||
nonlocal _ttl
|
||||
_ttl, = struct.unpack('b', ttl)
|
||||
|
||||
mock_socket.getsockopt = getsockopt
|
||||
mock_socket.setsockopt = setsockopt
|
||||
|
||||
protocol = MulticastProtocol('1.2.3.4', '1.2.3.4')
|
||||
transport = DatagramTransport(none_socket())
|
||||
protocol.set_ttl(1)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = protocol.get_ttl()
|
||||
transport = DatagramTransport()
|
||||
transport._extra = {'socket': mock_socket}
|
||||
self.assertEqual(None, protocol.set_ttl(1))
|
||||
self.assertEqual(0, protocol.get_ttl())
|
||||
protocol.connection_made(transport)
|
||||
protocol.set_ttl(1)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = protocol.get_ttl()
|
||||
self.assertEqual(None, protocol.set_ttl(1))
|
||||
self.assertEqual(1, protocol.get_ttl())
|
||||
|
|
|
@ -3,7 +3,7 @@ from aioupnp.fault import UPnPError
|
|||
from aioupnp.protocols.m_search_patterns import packet_generator
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
from aioupnp.constants import SSDP_IP_ADDRESS
|
||||
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
|
||||
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search, SSDPProtocol
|
||||
from tests import AsyncioTestCase, mock_tcp_and_udp
|
||||
|
||||
|
||||
|
@ -28,6 +28,13 @@ class TestSSDP(AsyncioTestCase):
|
|||
])
|
||||
reply_packet = SSDPDatagram("OK", reply_args)
|
||||
|
||||
async def test_transport_not_connected_error(self):
|
||||
try:
|
||||
await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()])
|
||||
self.assertTrue(False)
|
||||
except UPnPError as err:
|
||||
self.assertEqual(str(err), "SSDP transport not connected")
|
||||
|
||||
async def test_m_search_reply_unicast(self):
|
||||
replies = {
|
||||
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()
|
||||
|
@ -80,3 +87,13 @@ class TestSSDP(AsyncioTestCase):
|
|||
|
||||
self.assertEqual(reply.encode(), self.reply_packet.encode())
|
||||
self.assertEqual(args, self.successful_args)
|
||||
|
||||
async def test_packets_sent_fuzzy_m_search_ignore_invalid_datagram_replies(self):
|
||||
sent = []
|
||||
|
||||
with self.assertRaises(UPnPError):
|
||||
with mock_tcp_and_udp(self.loop, udp_expected_addr="10.0.0.1", sent_udp_packets=sent,
|
||||
add_potato_datagrams=True):
|
||||
await fuzzy_m_search("10.0.0.2", "10.0.0.1", 1, self.loop)
|
||||
|
||||
self.assertListEqual(sent, self.byte_packets)
|
|
@ -21,6 +21,139 @@ m_search_cli_result = """{
|
|||
}\n"""
|
||||
|
||||
|
||||
m_search_help_msg = """aioupnp [-h] [--debug_logging] m_search [--lan_address=<str>] [--gateway_address=<str>]
|
||||
[--timeout=<int>] [--unicast] [--interface_name=<str>] [--<header key>=<header value>, ...]
|
||||
|
||||
Perform a M-SEARCH for a upnp gateway.
|
||||
|
||||
:param lan_address: (str) the local interface ipv4 address
|
||||
:param gateway_address: (str) the gateway ipv4 address
|
||||
:param timeout: (int) m search timeout
|
||||
:param unicast: (bool) use unicast
|
||||
:param interface_name: (str) name of the network interface
|
||||
:param igd_args: (dict) case sensitive M-SEARCH headers. if used all headers to be used must be provided.
|
||||
|
||||
:return: {
|
||||
'lan_address': (str) lan address,
|
||||
'gateway_address': (str) gateway address,
|
||||
'm_search_kwargs': (str) equivalent igd_args ,
|
||||
'discover_reply': (dict) SSDP response datagram
|
||||
}\n
|
||||
"""
|
||||
|
||||
expected_usage = """aioupnp [-h] [--debug_logging] [--interface=<interface>] [--gateway_address=<gateway_address>]
|
||||
[--lan_address=<lan_address>] [--timeout=<timeout>] [(--<header_key>=<value>)...]
|
||||
|
||||
If m-search headers are provided as keyword arguments all of the headers to be used must be provided,
|
||||
in the order they are to be used. For example:
|
||||
aioupnp --HOST=239.255.255.250:1900 --MAN="ssdp:discover" --MX=1 --ST=upnp:rootdevice m_search
|
||||
|
||||
Commands:
|
||||
m_search | get_external_ip | add_port_mapping | get_port_mapping_by_index | get_redirects |
|
||||
get_specific_port_mapping | delete_port_mapping | get_next_mapping
|
||||
|
||||
For help with a specific command: aioupnp help <command>
|
||||
"""
|
||||
|
||||
expected_get_external_ip_usage = """aioupnp [-h] [--debug_logging] get_external_ip
|
||||
|
||||
Get the external ip address from the gateway
|
||||
|
||||
:return: (str) external ip
|
||||
|
||||
"""
|
||||
|
||||
expected_add_port_mapping_usage = """aioupnp [-h] [--debug_logging] add_port_mapping [--external_port=<int>] [--protocol=<str>]
|
||||
[--internal_port=<int>] [--lan_address=<str>] [--description=<str>]
|
||||
|
||||
Add a new port mapping
|
||||
|
||||
:param external_port: (int) external port to map
|
||||
:param protocol: (str) UDP | TCP
|
||||
:param internal_port: (int) internal port
|
||||
:param lan_address: (str) internal lan address
|
||||
:param description: (str) mapping description
|
||||
:return: None
|
||||
|
||||
"""
|
||||
|
||||
expected_get_next_mapping_usage = """aioupnp [-h] [--debug_logging] get_next_mapping [--port=<int>] [--protocol=<str>]
|
||||
[--description=<str>] [--internal_port=<typing.Union[int, NoneType]>]
|
||||
|
||||
Get a new port mapping. If the requested port is not available, increment until the next free port is mapped
|
||||
|
||||
:param port: (int) external port
|
||||
:param protocol: (str) UDP | TCP
|
||||
:param description: (str) mapping description
|
||||
:param internal_port: (int) internal port
|
||||
|
||||
:return: (int) mapped port
|
||||
|
||||
"""
|
||||
|
||||
|
||||
expected_delete_port_mapping_usage = """aioupnp [-h] [--debug_logging] delete_port_mapping [--external_port=<int>] [--protocol=<str>]
|
||||
|
||||
Delete a port mapping
|
||||
|
||||
:param external_port: (int) port number of mapping
|
||||
:param protocol: (str) TCP | UDP
|
||||
:return: None
|
||||
|
||||
"""
|
||||
|
||||
expected_get_specific_port_mapping_usage = """aioupnp [-h] [--debug_logging] get_specific_port_mapping [--external_port=<int>] [--protocol=<str>]
|
||||
|
||||
Get information about a port mapping by port number and protocol
|
||||
|
||||
:param external_port: (int) port number
|
||||
:param protocol: (str) UDP | TCP
|
||||
:return: NamedTuple[
|
||||
internal_port: int
|
||||
lan_address: str
|
||||
enabled: bool
|
||||
description: str
|
||||
lease_time: int
|
||||
]
|
||||
|
||||
"""
|
||||
expected_get_redirects_usage = """aioupnp [-h] [--debug_logging] get_redirects
|
||||
|
||||
Get information about all mapped ports
|
||||
|
||||
:return: List[
|
||||
NamedTuple[
|
||||
gateway_address: str
|
||||
external_port: int
|
||||
protocol: str
|
||||
internal_port: int
|
||||
lan_address: str
|
||||
enabled: bool
|
||||
description: str
|
||||
lease_time: int
|
||||
]
|
||||
]
|
||||
|
||||
"""
|
||||
expected_get_port_mapping_by_index_usage = """aioupnp [-h] [--debug_logging] get_port_mapping_by_index [--index=<int>]
|
||||
|
||||
Get information about a port mapping by index number
|
||||
|
||||
:param index: (int) mapping index number
|
||||
:return: NamedTuple[
|
||||
gateway_address: str
|
||||
external_port: int
|
||||
protocol: str
|
||||
internal_port: int
|
||||
lan_address: str
|
||||
enabled: bool
|
||||
description: str
|
||||
lease_time: int
|
||||
]
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class TestCLI(AsyncioTestCase):
|
||||
gateway_address = "10.0.0.1"
|
||||
soap_port = 49152
|
||||
|
@ -101,3 +234,117 @@ class TestCLI(AsyncioTestCase):
|
|||
self.loop
|
||||
)
|
||||
self.assertEqual(m_search_cli_result, actual_output.getvalue())
|
||||
|
||||
def test_usage(self):
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(expected_usage, actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help', 'test'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(expected_usage, actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'test', 'help'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual("aioupnp encountered an error: \"test\" is not a recognized command\n", actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'test'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual("aioupnp encountered an error: \"test\" is not a recognized command\n", actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(expected_usage, actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, "--something=test"],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual("no command given\n" + expected_usage, actual_output.getvalue())
|
||||
|
||||
def test_commands_help(self):
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help', 'm-search'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(m_search_help_msg, actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help', 'get-external-ip'],
|
||||
self.loop
|
||||
)
|
||||
|
||||
self.assertEqual(expected_get_external_ip_usage, actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help', 'add-port-mapping'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(expected_add_port_mapping_usage, actual_output.getvalue())
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help', 'get-next-mapping'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(expected_get_next_mapping_usage, actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help', 'delete_port_mapping'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(expected_delete_port_mapping_usage, actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help', 'get_specific_port_mapping'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(expected_get_specific_port_mapping_usage, actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help', 'get_redirects'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(expected_get_redirects_usage, actual_output.getvalue())
|
||||
|
||||
actual_output = StringIO()
|
||||
with contextlib.redirect_stdout(actual_output):
|
||||
main(
|
||||
[None, 'help', 'get_port_mapping_by_index'],
|
||||
self.loop
|
||||
)
|
||||
self.assertEqual(expected_get_port_mapping_by_index_usage, actual_output.getvalue())
|
||||
|
|
|
@ -4,17 +4,7 @@ from aioupnp.upnp import UPnP
|
|||
from aioupnp.fault import UPnPError
|
||||
from aioupnp.gateway import Gateway
|
||||
from aioupnp.serialization.ssdp import SSDPDatagram
|
||||
|
||||
|
||||
class TestGetAnnotations(AsyncioTestCase):
|
||||
def test_get_annotations(self):
|
||||
expected = {
|
||||
'NewRemoteHost': str, 'NewExternalPort': int, 'NewProtocol': str, 'NewInternalPort': int,
|
||||
'NewInternalClient': str, 'NewEnabled': int, 'NewPortMappingDescription': str,
|
||||
'NewLeaseDuration': str, 'return': None
|
||||
}
|
||||
|
||||
self.assertDictEqual(expected, UPnP.get_annotations('AddPortMapping'))
|
||||
from aioupnp.commands import GetSpecificPortMappingEntryResponse, GetGenericPortMappingEntryResponse
|
||||
|
||||
|
||||
class UPnPCommandTestCase(AsyncioTestCase):
|
||||
|
@ -76,7 +66,8 @@ class TestGetGenericPortMappingEntry(UPnPCommandTestCase):
|
|||
await gateway.discover_commands(self.loop)
|
||||
upnp = UPnP(self.client_address, self.gateway_address, gateway)
|
||||
result = await upnp.get_port_mapping_by_index(0)
|
||||
self.assertEqual((None, 9308, 'UDP', 9308, "11.2.3.44", True, "11.2.3.44:9308 to 9308 (UDP)", 0), result)
|
||||
self.assertEqual(GetGenericPortMappingEntryResponse(None, 9308, 'UDP', 9308, "11.2.3.44", True,
|
||||
"11.2.3.44:9308 to 9308 (UDP)", 0), result)
|
||||
|
||||
|
||||
class TestGetNextPortMapping(UPnPCommandTestCase):
|
||||
|
@ -120,6 +111,9 @@ class TestGetSpecificPortMapping(UPnPCommandTestCase):
|
|||
await upnp.get_specific_port_mapping(1000, 'UDP')
|
||||
except UPnPError:
|
||||
result = await upnp.get_specific_port_mapping(9308, 'UDP')
|
||||
self.assertEqual((9308, '11.2.3.55', True, '11.2.3.55:9308 to 9308 (UDP)', 0), result)
|
||||
self.assertEqual(
|
||||
GetSpecificPortMappingEntryResponse(9308, '11.2.3.55', True, '11.2.3.55:9308 to 9308 (UDP)', 0),
|
||||
result
|
||||
)
|
||||
else:
|
||||
self.assertTrue(False)
|
||||
|
|
Loading…
Reference in a new issue