fix fuzzy m search, update generate_test_data

This commit is contained in:
Jack Robison 2018-10-11 17:39:46 -04:00
parent e562fdc4bb
commit 210073af93
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
8 changed files with 282 additions and 248 deletions

View file

@ -22,8 +22,4 @@ SSDP_IP_ADDRESS = '239.255.255.250'
SSDP_PORT = 1900 SSDP_PORT = 1900
SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT) SSDP_HOST = "%s:%i" % (SSDP_IP_ADDRESS, SSDP_PORT)
SSDP_DISCOVER = "ssdp:discover" SSDP_DISCOVER = "ssdp:discover"
SSDP_ALL = "ssdp:all"
SSDP_BYEBYE = "ssdp:byebye"
SSDP_UPDATE = "ssdp:update"
SSDP_ROOT_DEVICE = "upnp:rootdevice"
line_separator = "\r\n" line_separator = "\r\n"

View file

@ -1,13 +1,15 @@
import logging import logging
import socket import socket
from collections import OrderedDict
from typing import Dict, List, Union, Type from typing import Dict, List, Union, Type
from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX from aioupnp.util import get_dict_val_case_insensitive, BASE_PORT_REGEX, BASE_ADDRESS_REGEX
from aioupnp.constants import SPEC_VERSION, UPNP_ORG_IGD, SERVICE from aioupnp.constants import SPEC_VERSION, SERVICE
from aioupnp.commands import SOAPCommands from aioupnp.commands import SOAPCommands
from aioupnp.device import Device, Service from aioupnp.device import Device, Service
from aioupnp.protocols.ssdp import fuzzy_m_search from aioupnp.protocols.ssdp import fuzzy_m_search
from aioupnp.protocols.scpd import scpd_get from aioupnp.protocols.scpd import scpd_get
from aioupnp.protocols.soap import SCPDCommand from aioupnp.protocols.soap import SOAPCommand
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.util import flatten_keys from aioupnp.util import flatten_keys
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
@ -53,43 +55,36 @@ def get_action_list(element_dict: dict) -> List: # [(<method>, [<input1>, ...],
class Gateway: class Gateway:
def __init__(self, **kwargs): def __init__(self, ok_packet: SSDPDatagram, m_search_args: OrderedDict, lan_address: str,
flattened = { gateway_address: str) -> None:
k.lower(): v for k, v in kwargs.items() self._ok_packet = ok_packet
} self._m_search_args = m_search_args
usn = flattened["usn"] self._lan_address = lan_address
server = flattened["server"] self.usn = (ok_packet.usn or '').encode()
location = flattened["location"] self.ext = (ok_packet.ext or '').encode()
st = flattened["st"] self.server = (ok_packet.server or '').encode()
self.location = (ok_packet.location or '').encode()
self.cache_control = (ok_packet.cache_control or '').encode()
self.date = (ok_packet.date or '').encode()
self.urn = (ok_packet.st or '').encode()
cache_control = flattened.get("cache_control") or flattened.get("cache-control") or "" self._xml_response = b""
date = flattened.get("date", "") self._service_descriptors: Dict = {}
ext = flattened.get("ext", "")
self.usn = usn.encode()
self.ext = ext.encode()
self.server = server.encode()
self.location = location.encode()
self.cache_control = cache_control.encode()
self.date = date.encode()
self.urn = st.encode()
self._xml_response = ""
self._service_descriptors = {}
self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0] self.base_address = BASE_ADDRESS_REGEX.findall(self.location)[0]
self.port = int(BASE_PORT_REGEX.findall(self.location)[0]) self.port = int(BASE_PORT_REGEX.findall(self.location)[0])
self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0] self.base_ip = self.base_address.lstrip(b"http://").split(b":")[0]
assert self.base_ip == gateway_address.encode()
self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1] self.path = self.location.split(b"%s:%i/" % (self.base_ip, self.port))[1]
self.spec_version = None self.spec_version = None
self.url_base = None self.url_base = None
self._device = None self._device: Union[None, Device] = None
self._devices = [] self._devices: List = []
self._services = [] self._services: List = []
self._unsupported_actions = {} self._unsupported_actions: Dict = {}
self._registered_commands = {} self._registered_commands: Dict = {}
self.commands = SOAPCommands() self.commands = SOAPCommands()
def gateway_descriptor(self) -> dict: def gateway_descriptor(self) -> dict:
@ -103,6 +98,12 @@ class Gateway:
} }
return r return r
@property
def manufacturer_string(self) -> str:
if not self._device:
raise NotImplementedError()
return "%s %s" % (self._device.manufacturer, self._device.modelName)
@property @property
def services(self) -> Dict: def services(self) -> Dict:
if not self._device: if not self._device:
@ -121,23 +122,36 @@ class Gateway:
return service return service
return None return None
def debug_commands(self): @property
def _soap_requests(self) -> Dict:
return { return {
'available': self._registered_commands, name: getattr(self.commands, name)._requests for name in self._registered_commands.keys()
'failed': self._unsupported_actions }
def debug_gateway(self) -> Dict:
return {
'gateway_address': self.base_ip,
'soap_port': self.port,
'm_search_args': self._m_search_args,
'reply': self._ok_packet.as_dict(),
'registered_soap_commands': self._registered_commands,
'unsupported_soap_commands': self._unsupported_actions,
'gateway_xml': self._xml_response,
'service_descriptors': self._service_descriptors,
'soap_requests': self._soap_requests
} }
@classmethod @classmethod
async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 1, async def discover_gateway(cls, lan_address: str, gateway_address: str, timeout: int = 30,
ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): ssdp_socket: socket.socket = None, soap_socket: socket.socket = None):
datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket) m_search_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
gateway = cls(**datagram.as_dict()) gateway = cls(datagram, m_search_args, lan_address, gateway_address)
await gateway.discover_commands(soap_socket) await gateway.discover_commands(soap_socket)
return gateway return gateway
async def discover_commands(self, soap_socket: socket.socket = None): async def discover_commands(self, soap_socket: socket.socket = None):
response = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port) response, xml_bytes = await scpd_get(self.path.decode(), self.base_ip.decode(), self.port)
self._xml_response = xml_bytes
self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION) self.spec_version = get_dict_val_case_insensitive(response, SPEC_VERSION)
self.url_base = get_dict_val_case_insensitive(response, "urlbase") self.url_base = get_dict_val_case_insensitive(response, "urlbase")
if not self.url_base: if not self.url_base:
@ -154,7 +168,9 @@ class Gateway:
async def register_commands(self, service: Service, soap_socket: socket.socket = None): async def register_commands(self, service: Service, soap_socket: socket.socket = None):
if not service.SCPDURL: if not service.SCPDURL:
raise UPnPError("no scpd url") raise UPnPError("no scpd url")
service_dict = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port) service_dict, xml_bytes = await scpd_get(service.SCPDURL, self.base_ip.decode(), self.port)
self._service_descriptors[service.SCPDURL] = xml_bytes
if not service_dict: if not service_dict:
return return
@ -176,7 +192,7 @@ class Gateway:
if param_name == "return": if param_name == "return":
continue continue
param_types[param_name] = param_type param_types[param_name] = param_type
command = SCPDCommand( command = SOAPCommand(
self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(), self.base_ip.decode(), self.port, service.controlURL, service.serviceType.encode(),
name, param_types, return_types, inputs, outputs, soap_socket) name, param_types, return_types, inputs, outputs, soap_socket)
setattr(command, "__doc__", current.__doc__) setattr(command, "__doc__", current.__doc__)

View file

@ -1,81 +1,84 @@
M_SEARCH_ARG_PATTERNS = [ """
# Alleged SSDP discovery documentation
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('ST', lambda s: s),
('MAN', lambda s: '"%s"' % s),
('MX', lambda n: int(n)),
],
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('ST', lambda s: s),
('Man', lambda s: '"%s"' % s),
('MX', lambda n: int(n)),
],
[
('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('ST', lambda s: s),
('Man', lambda s: '"%s"' % s),
('MX', lambda n: int(n)),
],
# swap st and man M-SEARCH * HTTP/1.1
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('MAN', lambda s: '"%s"' % s),
('ST', lambda s: s),
('MX', lambda n: int(n)),
],
[
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('Man', lambda s: '"%s"' % s),
('ST', lambda s: s),
('MX', lambda n: int(n)),
],
[
('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('Man', lambda s: '"%s"' % s),
('ST', lambda s: s),
('MX', lambda n: int(n)),
],
# repeat above but with no quotes on man Headers
[ HOST
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), Required. Multicast channel and port reserved for SSDP by Internet Assigned Numbers Authority (IANA). Must be
('ST', lambda s: s), 239.255.255.250:1900. If the port number (:1900) is omitted, the receiver should assume the default SSDP port
('MAN', lambda s: s), number of 1900.
('MX', lambda n: int(n)), MAN
], Required by HTTP Extension Framework. Unlike the NTS and ST headers, the value of the MAN header is enclosed in
[ double quotes; it defines the scope (namespace) of the extension. Must be "ssdp:discover".
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), MX
('ST', lambda s: s), Required. Maximum wait time in seconds. Should be between 1 and 120 inclusive. Device responses should be delayed a
('Man', lambda s: s), random duration between 0 and this many seconds to balance load for the control point when it processes responses.
('MX', lambda n: int(n)), This value may be increased if a large number of devices are expected to respond. The MX value should not be
], increased to accommodate network characteristics such as latency or propagation delay (for more details, see the
[ explanation below). Specified by UPnP vendor. Integer.
('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)), ST
('ST', lambda s: s), Required. Search Target. Must be one of the following. (cf. NT header in NOTIFY with ssdp:alive above.) Single URI.
('Man', lambda s: s),
('MX', lambda n: int(n)),
],
# swap st and man ssdp:all
[ Search for all devices and services.
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('MAN', lambda s: s), upnp:rootdevice
('ST', lambda s: s), Search for root devices only.
('MX', lambda n: int(n)),
], uuid:device-UUID
[ Search for a particular device. Device UUID specified by UPnP vendor.
('HOST', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('Man', lambda s: s), urn:schemas-upnp-org:device:deviceType:v
('ST', lambda s: s), Search for any device of this type. Device type and version defined by UPnP Forum working committee.
('MX', lambda n: int(n)),
], urn:schemas-upnp-org:service:serviceType:v
[ Search for any service of this type. Service type and version defined by UPnP Forum working committee.
('Host', lambda ssdp_ip: "{}:{}".format(ssdp_ip, 1900)),
('Man', lambda s: s), urn:domain-name:device:deviceType:v
('ST', lambda s: str(s)), Search for any device of this type. Domain name, device type and version defined by UPnP vendor. Period
('MX', lambda n: int(n)), characters in the domain name must be replaced with hyphens in accordance with RFC 2141.
],
] urn:domain-name:service:serviceType:v
Search for any service of this type. Domain name, service type and version defined by UPnP vendor. Period
characters in the domain name must be replaced with hyphens in accordance with RFC 2141.
"""
from collections import OrderedDict
from aioupnp.constants import SSDP_DISCOVER, SSDP_HOST
SEARCH_TARGETS = [
'ssdp:all'
'urn:schemas-upnp-org:device:InternetGatewayDevice:1',
'upnp:rootdevice',
'urn:schemas-wifialliance-org:device:WFADevice:1',
'urn:schemas-upnp-org:device:WANDevice:1',
]
def format_packet_args(order: list, **kwargs):
args = []
for o in order:
for k, v in kwargs.items():
if k.lower() == o.lower():
args.append((k, v))
break
return OrderedDict(args)
def packet_generator():
for st in SEARCH_TARGETS:
order = ["HOST", "MAN", "MX", "ST"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, Host=SSDP_HOST, Man=SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, Host=SSDP_HOST, Man='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
order = ["HOST", "MAN", "ST", "MX"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st)
order = ["HOST", "ST", "MAN", "MX"]
yield format_packet_args(order, HOST=SSDP_HOST, MAN=SSDP_DISCOVER, MX=1, ST=st)
yield format_packet_args(order, HOST=SSDP_HOST, MAN='"%s"' % SSDP_DISCOVER, MX=1, ST=st)

View file

@ -1,5 +1,6 @@
import logging import logging
import socket import socket
import typing
from xml.etree import ElementTree from xml.etree import ElementTree
import asyncio import asyncio
from asyncio.protocols import Protocol from asyncio.protocols import Protocol
@ -38,43 +39,46 @@ class SCPDHTTPClientProtocol(Protocol):
if self.method == self.GET: if self.method == self.GET:
try: try:
packet = deserialize_scpd_get_response(self.response_buff) packet = deserialize_scpd_get_response(self.response_buff)
if not packet: if packet:
return self.finished.set_result(packet)
except ElementTree.ParseError: return
pass except ElementTree.ParseError:
except UPnPError as err: pass
self.finished.set_exception(err) except UPnPError as err:
else: self.finished.set_exception(err)
self.finished.set_result(packet) elif self.method == self.POST:
elif self.method == self.POST: try:
try: packet = deserialize_soap_post_response(self.response_buff, self.soap_method, self.soap_service_id)
packet = deserialize_soap_post_response(self.response_buff, self.soap_method, self.soap_service_id) if packet:
if not packet:
self.finished.set_result(packet) self.finished.set_result(packet)
return return
except ElementTree.ParseError: except ElementTree.ParseError:
pass pass
except UPnPError as err: except UPnPError as err:
self.finished.set_exception(err) self.finished.set_exception(err)
else:
self.finished.set_result(packet)
async def scpd_get(control_url: str, address: str, port: int) -> dict: async def scpd_get(control_url: str, address: str, port: int) -> typing.Tuple[typing.Dict, bytes]:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
finished: asyncio.Future = asyncio.Future() finished: asyncio.Future = asyncio.Future()
packet = serialize_scpd_get(control_url, address) packet = serialize_scpd_get(control_url, address)
transport, protocol = await loop.create_connection( transport, protocol = await loop.create_connection(
lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port lambda : SCPDHTTPClientProtocol('GET', packet, finished), address, port
) )
assert isinstance(protocol, SCPDHTTPClientProtocol)
parsed: typing.Dict = {}
try: try:
return await asyncio.wait_for(finished, 1.0) parsed = await asyncio.wait_for(finished, 1.0)
except UPnPError:
return parsed, protocol.response_buff
finally: finally:
transport.close() transport.close()
return parsed, protocol.response_buff
async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes, async def scpd_post(control_url: str, address: str, port: int, method: str, param_names: list, service_id: bytes,
close_after_send: bool, soap_socket: socket.socket = None, **kwargs): close_after_send: bool, soap_socket: socket.socket = None,
**kwargs) -> typing.Tuple[typing.Dict, bytes]:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
finished: asyncio.Future = asyncio.Future() finished: asyncio.Future = asyncio.Future()
packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs) packet = serialize_soap_post(method, param_names, service_id, address.encode(), control_url.encode(), **kwargs)
@ -84,7 +88,12 @@ async def scpd_post(control_url: str, address: str, port: int, method: str, para
close_after_send=close_after_send close_after_send=close_after_send
), address, port, sock=soap_socket ), address, port, sock=soap_socket
) )
assert isinstance(protocol, SCPDHTTPClientProtocol)
parsed: typing.Dict = {}
try: try:
return await asyncio.wait_for(finished, 1.0) parsed = await asyncio.wait_for(finished, 1.0)
except UPnPError:
return parsed, protocol.response_buff
finally: finally:
transport.close() transport.close()
return parsed, protocol.response_buff

View file

@ -6,7 +6,7 @@ from aioupnp.protocols.scpd import scpd_post
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class SCPDCommand: class SOAPCommand:
def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str, def __init__(self, gateway_address: str, service_port: int, control_url: str, service_id: bytes, method: str,
param_types: dict, return_types: dict, param_order: list, return_order: list, param_types: dict, return_types: dict, param_order: list, return_order: list,
soap_socket: socket.socket = None) -> None: soap_socket: socket.socket = None) -> None:
@ -20,18 +20,21 @@ class SCPDCommand:
self.return_types = return_types self.return_types = return_types
self.return_order = return_order self.return_order = return_order
self.soap_socket = soap_socket self.soap_socket = soap_socket
self._requests: typing.List = []
async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]: async def __call__(self, **kwargs) -> typing.Union[None, typing.Dict, typing.List, typing.Tuple]:
if set(kwargs.keys()) != set(self.param_types.keys()): if set(kwargs.keys()) != set(self.param_types.keys()):
raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys())) raise Exception("argument mismatch: %s vs %s" % (kwargs.keys(), self.param_types.keys()))
close_after_send = not self.return_types or self.return_types == [None] close_after_send = not self.return_types or self.return_types == [None]
response = await scpd_post( soap_kwargs = {n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()}
response, xml_bytes = await scpd_post(
self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id, self.control_url, self.gateway_address, self.service_port, self.method, self.param_order, self.service_id,
close_after_send, self.soap_socket, **{n: self.param_types[n](kwargs[n]) for n in self.param_types.keys()} close_after_send, self.soap_socket, **soap_kwargs
) )
result = tuple([self.return_types[n](response.get(n)) for n in self.return_order]) self._requests.append((soap_kwargs, xml_bytes))
if not result: if not response:
return None return None
result = tuple([self.return_types[n](response.get(n)) for n in self.return_order])
if len(result) == 1: if len(result) == 1:
return result[0] return result[0]
return result return result

View file

@ -9,10 +9,9 @@ from asyncio.futures import Future
from asyncio.transports import DatagramTransport from asyncio.transports import DatagramTransport
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.serialization.ssdp import SSDPDatagram from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import UPNP_ORG_IGD, WIFI_ALLIANCE_ORG_IGD from aioupnp.constants import SSDP_IP_ADDRESS, SSDP_PORT
from aioupnp.constants import SSDP_IP_ADDRESS, SSDP_PORT, SSDP_DISCOVER, SSDP_ROOT_DEVICE, SSDP_ALL
from aioupnp.protocols.multicast import MulticastProtocol from aioupnp.protocols.multicast import MulticastProtocol
from aioupnp.protocols.m_search_patterns import M_SEARCH_ARG_PATTERNS from aioupnp.protocols.m_search_patterns import packet_generator
ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-|\.]*)$") ADDRESS_REGEX = re.compile("^http:\/\/(\d+\.\d+\.\d+\.\d+)\:(\d*)(\/[\w|\/|\:|\-|\.]*)$")
@ -23,21 +22,45 @@ class SSDPProtocol(MulticastProtocol):
def __init__(self, multicast_address: str, lan_address: str) -> None: def __init__(self, multicast_address: str, lan_address: str) -> None:
super().__init__(multicast_address, lan_address) super().__init__(multicast_address, lan_address)
self.lan_address = lan_address self.lan_address = lan_address
self.discover_callbacks: Dict = {} self._pending_searches: List[Tuple[str, str, Future, asyncio.Handle]] = []
self.notifications: List = [] self.notifications: List = []
self.replies: List = []
def m_search(self, address: str, timeout: int, datagram_args: OrderedDict) -> Future: def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram_args) tmp: list = []
f: Future = Future() set_futures: list = []
futs = self.discover_callbacks.get((address, packet.st), []) while self._pending_searches:
futs.append(f) t: tuple = self._pending_searches.pop()
self.discover_callbacks[(address, packet.st)] = futs a, s = t[0], t[1]
log.debug("send m search to %s: %s", address, packet) if (address == a) and (s in [packet.st, "upnp:rootdevice"]):
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT)) f: Future = t[2]
h: asyncio.Handle = t[3]
h.cancel()
if f not in set_futures:
set_futures.append(f)
if not f.done():
f.set_result(packet)
elif t[2] not in set_futures:
tmp.append(t)
while tmp:
self._pending_searches.append(tmp.pop())
r: Future = asyncio.ensure_future(asyncio.wait_for(f, timeout)) def send_many_m_searches(self, address: str, packets: List[SSDPDatagram]):
return r for packet in packets:
log.debug("send m search to %s: %s", address, packet.st)
self.transport.sendto(packet.encode().encode(), (address, SSDP_PORT))
async def m_search(self, address: str, timeout: float, datagrams: List[OrderedDict]) -> SSDPDatagram:
fut: Future = Future()
packets: List[SSDPDatagram] = []
for datagram in datagrams:
packet = SSDPDatagram(SSDPDatagram._M_SEARCH, datagram)
assert packet.st is not None
h = asyncio.get_running_loop().call_later(timeout, fut.cancel)
self._pending_searches.append((address, packet.st, fut, h))
packets.append(packet)
self.send_many_m_searches(address, packets),
return await fut
def datagram_received(self, data, addr) -> None: def datagram_received(self, data, addr) -> None:
if addr[0] == self.lan_address: if addr[0] == self.lan_address:
@ -51,15 +74,8 @@ class SSDPProtocol(MulticastProtocol):
return return
if packet._packet_type == packet._OK: if packet._packet_type == packet._OK:
if (addr[0], packet.st) in self.discover_callbacks: self._callback_m_search_ok(addr[0], packet)
log.debug("%s:%i replied to our m-search", addr[0], addr[1]) return
if packet.st not in map(lambda p: p['st'], self.replies):
self.replies.append(packet)
for ok_fut in self.discover_callbacks[(addr[0], packet.st)]:
ok_fut.set_result(packet)
del self.discover_callbacks[(addr[0], packet.st)]
return
# elif packet._packet_type == packet._NOTIFY: # elif packet._packet_type == packet._NOTIFY:
# log.debug("%s:%i sent us a notification: %s", packet) # log.debug("%s:%i sent us a notification: %s", packet)
# if packet.nt == SSDP_ROOT_DEVICE: # if packet.nt == SSDP_ROOT_DEVICE:
@ -109,46 +125,42 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: Ordere
lan_address, gateway_address, ssdp_socket lan_address, gateway_address, ssdp_socket
) )
try: try:
return await protocol.m_search(address=gateway_address, timeout=timeout, datagram_args=datagram_args) return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
except asyncio.TimeoutError: except (asyncio.TimeoutError, asyncio.CancelledError):
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
finally: finally:
transport.close() transport.close()
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 1, async def _fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ssdp_socket: socket.socket = None) -> SSDPDatagram: ssdp_socket: socket.socket = None) -> List[OrderedDict]:
transport, protocol, gateway_address, lan_address = await listen_ssdp( transport, protocol, gateway_address, lan_address = await listen_ssdp(
lan_address, gateway_address, ssdp_socket lan_address, gateway_address, ssdp_socket
) )
datagram_kwargs: list = [] packet_args = list(packet_generator())
services = [UPNP_ORG_IGD, SSDP_ALL, WIFI_ALLIANCE_ORG_IGD] batch_size = 2
mans = [SSDP_DISCOVER, SSDP_ROOT_DEVICE] b = 0
mx = 1 batch_timeout = float(timeout) / float(len(packet_args))
while packet_args:
for service in services: args = packet_args[:batch_size]
for man in mans: packet_args = packet_args[batch_size:]
for arg_pattern in M_SEARCH_ARG_PATTERNS: log.debug("sending batch of %i M-SEARCH attempts", batch_size)
dgram_kwargs: OrderedDict = OrderedDict()
for k, l in arg_pattern:
if k.lower() == 'host':
dgram_kwargs[k] = l(SSDP_IP_ADDRESS)
elif k.lower() == 'st':
dgram_kwargs[k] = l(service)
elif k.lower() == 'man':
dgram_kwargs[k] = l(man)
elif k.lower() == 'mx':
dgram_kwargs[k] = l(mx)
datagram_kwargs.append(dgram_kwargs)
for i, args in enumerate(datagram_kwargs):
try: try:
result = await protocol.m_search(address=gateway_address, timeout=timeout, datagram_args=args) await protocol.m_search(gateway_address, batch_timeout, args)
transport.close() return args
return result except (asyncio.TimeoutError, asyncio.CancelledError):
except asyncio.TimeoutError: b += 1
pass continue
except Exception as err:
log.error(err)
transport.close()
raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT)) raise UPnPError("M-SEARCH for {}:{} timed out".format(gateway_address, SSDP_PORT))
async def fuzzy_m_search(lan_address: str, gateway_address: str, timeout: int = 30,
ssdp_socket: socket.socket = None) -> Tuple[OrderedDict, SSDPDatagram]:
args_to_try = await _fuzzy_m_search(lan_address, gateway_address, timeout, ssdp_socket)
for args in args_to_try:
try:
packet = await m_search(lan_address, gateway_address, args, 3)
return args, packet
except UPnPError:
continue
raise UPnPError("failed to discover gateway")

View file

@ -9,7 +9,6 @@ from aioupnp.constants import line_separator
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
_template = "^(?i)(%s):[ ]*(.*)$" _template = "^(?i)(%s):[ ]*(.*)$"
@ -54,8 +53,8 @@ class SSDPDatagram(object):
_M_SEARCH: [ _M_SEARCH: [
'host', 'host',
'man', 'man',
'st',
'mx', 'mx',
'st',
], ],
_NOTIFY: [ _NOTIFY: [
'host', 'host',
@ -85,9 +84,9 @@ class SSDPDatagram(object):
k.lower().replace("-", "_") for k in kwargs.keys() k.lower().replace("-", "_") for k in kwargs.keys()
] ]
self.host = None self.host = None
self.st = None
self.man = None self.man = None
self.mx = None self.mx = None
self.st = None
self.nt = None self.nt = None
self.nts = None self.nts = None
self.usn = None self.usn = None

View file

@ -8,7 +8,8 @@ from typing import Tuple, Dict, List, Union
from aioupnp.fault import UPnPError from aioupnp.fault import UPnPError
from aioupnp.gateway import Gateway from aioupnp.gateway import Gateway
from aioupnp.util import get_gateway_and_lan_addresses from aioupnp.util import get_gateway_and_lan_addresses
from aioupnp.protocols.ssdp import m_search, fuzzy_m_search from aioupnp.protocols.ssdp import m_search
from aioupnp.protocols.soap import SOAPCommand
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -42,7 +43,7 @@ class UPnP:
return lan_address, gateway_address return lan_address, gateway_address
@classmethod @classmethod
async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1, async def discover(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
interface_name: str = 'default', interface_name: str = 'default',
ssdp_socket: socket.socket = None, soap_socket: socket.socket = None): ssdp_socket: socket.socket = None, soap_socket: socket.socket = None):
try: try:
@ -77,7 +78,7 @@ class UPnP:
await self.gateway.commands.AddPortMapping( await self.gateway.commands.AddPortMapping(
NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol, NewRemoteHost="", NewExternalPort=external_port, NewProtocol=protocol,
NewInternalPort=internal_port, NewInternalClient=lan_address, NewInternalPort=internal_port, NewInternalClient=lan_address,
NewEnabled=1, NewPortMappingDescription=description, NewLeaseDuration="" NewEnabled=True, NewPortMappingDescription=description, NewLeaseDuration=""
) )
return return
@ -85,13 +86,14 @@ class UPnP:
async def get_port_mapping_by_index(self, index: int) -> Dict: async def get_port_mapping_by_index(self, index: int) -> Dict:
result = await self._get_port_mapping_by_index(index) result = await self._get_port_mapping_by_index(index)
if result: if result:
return { if isinstance(self.gateway.commands.GetGenericPortMappingEntry, SOAPCommand):
k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result) return {
} k: v for k, v in zip(self.gateway.commands.GetGenericPortMappingEntry.return_order, result)
}
return {} return {}
async def _get_port_mapping_by_index(self, index: int) -> Union[Tuple[str, int, str, int, str, bool, str, int], async def _get_port_mapping_by_index(self, index: int) -> Union[None,
None]: Tuple[Union[None, str], int, str, int, str, bool, str, int]]:
try: try:
redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index) redirect = await self.gateway.commands.GetGenericPortMappingEntry(NewPortMappingIndex=index)
return redirect return redirect
@ -119,9 +121,11 @@ class UPnP:
try: try:
result = await self.gateway.commands.GetSpecificPortMappingEntry( result = await self.gateway.commands.GetSpecificPortMappingEntry(
NewRemoteHost=None, NewExternalPort=external_port, NewProtocol=protocol NewRemoteHost='', NewExternalPort=external_port, NewProtocol=protocol
) )
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)} if isinstance(self.gateway.commands.GetSpecificPortMappingEntry, SOAPCommand):
return {k: v for k, v in zip(self.gateway.commands.GetSpecificPortMappingEntry.return_order, result)}
return {}
except UPnPError: except UPnPError:
return {} return {}
@ -175,49 +179,38 @@ class UPnP:
@cli @cli
async def generate_test_data(self): async def generate_test_data(self):
external_ip = await self.get_external_ip() print("found gateway via M-SEARCH")
redirects = await self.get_redirects() try:
ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping") external_ip = await self.get_external_ip()
delete = await self.delete_port_mapping(ext_port, "UDP") print("got external ip: %s" % external_ip)
after_delete = await self.get_specific_port_mapping(ext_port, "UDP") except UPnPError:
print("failed to get the external ip")
try:
redirects = await self.get_redirects()
print("got redirects:\n%s" % redirects)
except UPnPError:
print("failed to get redirects")
commands_test_case = ( try:
("get_external_ip", (), "1.2.3.4"), ext_port = await self.get_next_mapping(4567, "UDP", "aioupnp test mapping")
("get_redirects", (), redirects), print("set up external mapping to port %i" % ext_port)
("get_next_mapping", (4567, "UDP", "aioupnp test mapping"), ext_port), await self.delete_port_mapping(ext_port, "UDP")
("delete_port_mapping", (ext_port, "UDP"), delete), print("deleted mapping")
("get_specific_port_mapping", (ext_port, "UDP"), after_delete), except UPnPError:
) print("failed to add and remove a mapping")
gateway = self.gateway device = list(self.gateway.devices.values())[0]
device = list(gateway.devices.values())[0]
assert device.manufacturer and device.modelName assert device.manufacturer and device.modelName
device_path = os.path.join(os.getcwd(), "%s %s" % (device.manufacturer, device.modelName)) device_path = os.path.join(os.getcwd(), self.gateway.manufacturer_string)
commands = gateway.debug_commands()
with open(device_path, "w") as f: with open(device_path, "w") as f:
f.write(json.dumps({ f.write(json.dumps({
"router_address": self.gateway_address, "gateway": self.gateway.debug_gateway(),
"client_address": self.lan_address, "client_address": self.lan_address,
"port": gateway.port, }, default=_encode, indent=2))
"gateway_dict": gateway.gateway_descriptor(),
'expected_devices': [
{
'cache_control': 'max-age=1800',
'location': gateway.location,
'server': gateway.server,
'st': gateway.urn,
'usn': gateway.usn
}
],
'commands': commands,
# 'ssdp': u.sspd_factory.get_ssdp_packet_replay(),
# 'scpd': gateway.requester.dump_packets(),
'soap': commands_test_case
}, default=_encode, indent=2).replace(external_ip, "1.2.3.4"))
return "Generated test data! -> %s" % device_path return "Generated test data! -> %s" % device_path
@classmethod @classmethod
def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 60, def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
interface_name: str = 'default', kwargs: dict = None) -> None: interface_name: str = 'default', kwargs: dict = None) -> None:
kwargs = kwargs or {} kwargs = kwargs or {}
igd_args = igd_args igd_args = igd_args
@ -257,6 +250,9 @@ class UPnP:
log.exception("uncaught error") log.exception("uncaught error")
fut.set_exception(UPnPError("uncaught error: %s" % str(err))) fut.set_exception(UPnPError("uncaught error: %s" % str(err)))
if not hasattr(UPnP, method) or not hasattr(getattr(UPnP, method), "_cli"):
fut.set_exception(UPnPError("\"%s\" is not a recognized command" % method))
wrapper = lambda : None
asyncio.run(wrapper()) asyncio.run(wrapper())
try: try:
result = fut.result() result = fut.result()