add tests + coverage
-test_socket_setup_error -ssdp test_deadbeef_response
This commit is contained in:
parent
1b25c009ca
commit
8e07d7f390
5 changed files with 41 additions and 19 deletions
|
@ -188,7 +188,7 @@ class Gateway:
|
|||
igd_args: typing.Dict[str, typing.Union[int, str]],
|
||||
timeout: int = 30,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, set())
|
||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop)
|
||||
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
|
||||
if not gateway:
|
||||
raise UPnPError("no gateway found for given args")
|
||||
|
@ -199,7 +199,7 @@ class Gateway:
|
|||
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
|
||||
ignored: typing.Set[str] = set()
|
||||
ssdp_proto = await multi_m_search(
|
||||
lan_address, gateway_address, timeout, loop, ignored
|
||||
lan_address, gateway_address, timeout, loop
|
||||
)
|
||||
try:
|
||||
while True:
|
||||
|
|
|
@ -24,12 +24,11 @@ class PendingSearch(typing.NamedTuple):
|
|||
|
||||
|
||||
class SSDPProtocol(MulticastProtocol):
|
||||
def __init__(self, multicast_address: str, lan_address: str, ignored: Optional[Set[str]] = None,
|
||||
def __init__(self, multicast_address: str, lan_address: str,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(multicast_address, lan_address)
|
||||
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
||||
self.transport: Optional[DatagramTransport] = None
|
||||
self._ignored: Set[str] = ignored or set() # ignored locations
|
||||
self._pending_searches: List[PendingSearch] = []
|
||||
self.notifications: List[SSDPDatagram] = []
|
||||
self.connected = asyncio.Event(loop=self.loop)
|
||||
|
@ -52,9 +51,6 @@ class SSDPProtocol(MulticastProtocol):
|
|||
return None
|
||||
|
||||
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
|
||||
if packet.location in self._ignored:
|
||||
return
|
||||
|
||||
futures: Set['asyncio.Future[SSDPDatagram]'] = set()
|
||||
replied: List[PendingSearch] = []
|
||||
|
||||
|
@ -137,13 +133,13 @@ class SSDPProtocol(MulticastProtocol):
|
|||
# return
|
||||
|
||||
|
||||
async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: Optional[Set[str]] = None) -> Tuple[SSDPProtocol, str, str]:
|
||||
async def listen_ssdp(lan_address: str, gateway_address: str,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> Tuple[SSDPProtocol, str, str]:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
try:
|
||||
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
|
||||
listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
|
||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored), sock=sock
|
||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
|
||||
)
|
||||
protocol = listen_result[1]
|
||||
assert isinstance(protocol, SSDPProtocol)
|
||||
|
@ -156,10 +152,10 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asy
|
|||
|
||||
|
||||
async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]],
|
||||
timeout: int = 1, loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: Set[str] = None) -> SSDPDatagram:
|
||||
timeout: int = 1,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPDatagram:
|
||||
protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
lan_address, gateway_address, loop, ignored
|
||||
lan_address, gateway_address, loop
|
||||
)
|
||||
try:
|
||||
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
|
||||
|
@ -170,11 +166,10 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[s
|
|||
|
||||
|
||||
async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
ignored: Set[str] = None) -> SSDPProtocol:
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPProtocol:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
protocol, gateway_address, lan_address = await listen_ssdp(
|
||||
lan_address, gateway_address, loop, ignored
|
||||
lan_address, gateway_address, loop
|
||||
)
|
||||
fut = asyncio.ensure_future(protocol.send_m_searches(
|
||||
address=gateway_address, datagrams=list(packet_generator())
|
||||
|
|
|
@ -164,7 +164,12 @@ class SSDPDatagram:
|
|||
|
||||
@classmethod
|
||||
def decode(cls, datagram: bytes) -> 'SSDPDatagram':
|
||||
packet = cls._from_string(datagram.decode())
|
||||
try:
|
||||
packet = cls._from_string(datagram.decode())
|
||||
except UnicodeDecodeError:
|
||||
raise UPnPError(
|
||||
f"failed to decode datagram: {binascii.hexlify(datagram).decode()}"
|
||||
)
|
||||
if packet is None:
|
||||
raise UPnPError(
|
||||
f"failed to decode datagram: {binascii.hexlify(datagram).decode()}"
|
||||
|
|
|
@ -16,7 +16,8 @@ except ImportError:
|
|||
|
||||
@contextlib.contextmanager
|
||||
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
|
||||
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False):
|
||||
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
|
||||
raise_oserror_on_bind=False):
|
||||
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
|
||||
udp_replies = udp_replies or {}
|
||||
|
||||
|
@ -72,7 +73,13 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
|
|||
with mock.patch('socket.socket') as mock_socket:
|
||||
mock_sock = mock.Mock(spec=socket.socket)
|
||||
mock_sock.setsockopt = lambda *_: None
|
||||
mock_sock.bind = lambda *_: None
|
||||
|
||||
def bind(*_):
|
||||
if raise_oserror_on_bind:
|
||||
raise OSError()
|
||||
return
|
||||
|
||||
mock_sock.bind = bind
|
||||
mock_sock.setblocking = lambda *_: None
|
||||
mock_sock.getsockname = lambda: "0.0.0.0"
|
||||
mock_sock.getpeername = lambda: ""
|
||||
|
|
|
@ -28,6 +28,11 @@ class TestSSDP(AsyncioTestCase):
|
|||
])
|
||||
reply_packet = SSDPDatagram("OK", reply_args)
|
||||
|
||||
async def test_socket_setup_error(self):
|
||||
with mock_tcp_and_udp(self.loop, raise_oserror_on_bind=True):
|
||||
with self.assertRaises(UPnPError):
|
||||
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
|
||||
|
||||
async def test_transport_not_connected_error(self):
|
||||
try:
|
||||
await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()])
|
||||
|
@ -35,6 +40,16 @@ class TestSSDP(AsyncioTestCase):
|
|||
except UPnPError as err:
|
||||
self.assertEqual(str(err), "SSDP transport not connected")
|
||||
|
||||
async def test_deadbeef_response(self):
|
||||
replies = {
|
||||
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): b'\xde\xad\xbe\xef'
|
||||
}
|
||||
sent = []
|
||||
|
||||
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
|
||||
with self.assertRaises(UPnPError):
|
||||
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
|
||||
|
||||
async def test_m_search_reply_unicast(self):
|
||||
replies = {
|
||||
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()
|
||||
|
|
Loading…
Reference in a new issue