add tests + coverage
-test_socket_setup_error -ssdp test_deadbeef_response
This commit is contained in:
parent
1b25c009ca
commit
8e07d7f390
5 changed files with 41 additions and 19 deletions
|
@ -188,7 +188,7 @@ class Gateway:
|
||||||
igd_args: typing.Dict[str, typing.Union[int, str]],
|
igd_args: typing.Dict[str, typing.Union[int, str]],
|
||||||
timeout: int = 30,
|
timeout: int = 30,
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
|
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
|
||||||
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, set())
|
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop)
|
||||||
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
|
gateway = await cls._try_gateway_from_ssdp(datagram, lan_address, gateway_address, loop)
|
||||||
if not gateway:
|
if not gateway:
|
||||||
raise UPnPError("no gateway found for given args")
|
raise UPnPError("no gateway found for given args")
|
||||||
|
@ -199,7 +199,7 @@ class Gateway:
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
|
loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Gateway':
|
||||||
ignored: typing.Set[str] = set()
|
ignored: typing.Set[str] = set()
|
||||||
ssdp_proto = await multi_m_search(
|
ssdp_proto = await multi_m_search(
|
||||||
lan_address, gateway_address, timeout, loop, ignored
|
lan_address, gateway_address, timeout, loop
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
|
|
@ -24,12 +24,11 @@ class PendingSearch(typing.NamedTuple):
|
||||||
|
|
||||||
|
|
||||||
class SSDPProtocol(MulticastProtocol):
|
class SSDPProtocol(MulticastProtocol):
|
||||||
def __init__(self, multicast_address: str, lan_address: str, ignored: Optional[Set[str]] = None,
|
def __init__(self, multicast_address: str, lan_address: str,
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||||
super().__init__(multicast_address, lan_address)
|
super().__init__(multicast_address, lan_address)
|
||||||
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
||||||
self.transport: Optional[DatagramTransport] = None
|
self.transport: Optional[DatagramTransport] = None
|
||||||
self._ignored: Set[str] = ignored or set() # ignored locations
|
|
||||||
self._pending_searches: List[PendingSearch] = []
|
self._pending_searches: List[PendingSearch] = []
|
||||||
self.notifications: List[SSDPDatagram] = []
|
self.notifications: List[SSDPDatagram] = []
|
||||||
self.connected = asyncio.Event(loop=self.loop)
|
self.connected = asyncio.Event(loop=self.loop)
|
||||||
|
@ -52,9 +51,6 @@ class SSDPProtocol(MulticastProtocol):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
|
def _callback_m_search_ok(self, address: str, packet: SSDPDatagram) -> None:
|
||||||
if packet.location in self._ignored:
|
|
||||||
return
|
|
||||||
|
|
||||||
futures: Set['asyncio.Future[SSDPDatagram]'] = set()
|
futures: Set['asyncio.Future[SSDPDatagram]'] = set()
|
||||||
replied: List[PendingSearch] = []
|
replied: List[PendingSearch] = []
|
||||||
|
|
||||||
|
@ -137,13 +133,13 @@ class SSDPProtocol(MulticastProtocol):
|
||||||
# return
|
# return
|
||||||
|
|
||||||
|
|
||||||
async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asyncio.AbstractEventLoop] = None,
|
async def listen_ssdp(lan_address: str, gateway_address: str,
|
||||||
ignored: Optional[Set[str]] = None) -> Tuple[SSDPProtocol, str, str]:
|
loop: Optional[asyncio.AbstractEventLoop] = None) -> Tuple[SSDPProtocol, str, str]:
|
||||||
loop = loop or asyncio.get_event_loop()
|
loop = loop or asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
|
sock: socket.socket = SSDPProtocol.create_multicast_socket(lan_address)
|
||||||
listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
|
listen_result: Tuple[asyncio.BaseTransport, asyncio.BaseProtocol] = await loop.create_datagram_endpoint(
|
||||||
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address, ignored), sock=sock
|
lambda: SSDPProtocol(SSDP_IP_ADDRESS, lan_address), sock=sock
|
||||||
)
|
)
|
||||||
protocol = listen_result[1]
|
protocol = listen_result[1]
|
||||||
assert isinstance(protocol, SSDPProtocol)
|
assert isinstance(protocol, SSDPProtocol)
|
||||||
|
@ -156,10 +152,10 @@ async def listen_ssdp(lan_address: str, gateway_address: str, loop: Optional[asy
|
||||||
|
|
||||||
|
|
||||||
async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]],
|
async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[str, typing.Union[int, str]],
|
||||||
timeout: int = 1, loop: Optional[asyncio.AbstractEventLoop] = None,
|
timeout: int = 1,
|
||||||
ignored: Set[str] = None) -> SSDPDatagram:
|
loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPDatagram:
|
||||||
protocol, gateway_address, lan_address = await listen_ssdp(
|
protocol, gateway_address, lan_address = await listen_ssdp(
|
||||||
lan_address, gateway_address, loop, ignored
|
lan_address, gateway_address, loop
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
|
return await protocol.m_search(address=gateway_address, timeout=timeout, datagrams=[datagram_args])
|
||||||
|
@ -170,11 +166,10 @@ async def m_search(lan_address: str, gateway_address: str, datagram_args: Dict[s
|
||||||
|
|
||||||
|
|
||||||
async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3,
|
async def multi_m_search(lan_address: str, gateway_address: str, timeout: int = 3,
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
loop: Optional[asyncio.AbstractEventLoop] = None) -> SSDPProtocol:
|
||||||
ignored: Set[str] = None) -> SSDPProtocol:
|
|
||||||
loop = loop or asyncio.get_event_loop()
|
loop = loop or asyncio.get_event_loop()
|
||||||
protocol, gateway_address, lan_address = await listen_ssdp(
|
protocol, gateway_address, lan_address = await listen_ssdp(
|
||||||
lan_address, gateway_address, loop, ignored
|
lan_address, gateway_address, loop
|
||||||
)
|
)
|
||||||
fut = asyncio.ensure_future(protocol.send_m_searches(
|
fut = asyncio.ensure_future(protocol.send_m_searches(
|
||||||
address=gateway_address, datagrams=list(packet_generator())
|
address=gateway_address, datagrams=list(packet_generator())
|
||||||
|
|
|
@ -164,7 +164,12 @@ class SSDPDatagram:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, datagram: bytes) -> 'SSDPDatagram':
|
def decode(cls, datagram: bytes) -> 'SSDPDatagram':
|
||||||
|
try:
|
||||||
packet = cls._from_string(datagram.decode())
|
packet = cls._from_string(datagram.decode())
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise UPnPError(
|
||||||
|
f"failed to decode datagram: {binascii.hexlify(datagram).decode()}"
|
||||||
|
)
|
||||||
if packet is None:
|
if packet is None:
|
||||||
raise UPnPError(
|
raise UPnPError(
|
||||||
f"failed to decode datagram: {binascii.hexlify(datagram).decode()}"
|
f"failed to decode datagram: {binascii.hexlify(datagram).decode()}"
|
||||||
|
|
|
@ -16,7 +16,8 @@ except ImportError:
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
|
def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
|
||||||
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False):
|
tcp_replies=None, tcp_delay_reply=0.0, sent_tcp_packets=None, add_potato_datagrams=False,
|
||||||
|
raise_oserror_on_bind=False):
|
||||||
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
|
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
|
||||||
udp_replies = udp_replies or {}
|
udp_replies = udp_replies or {}
|
||||||
|
|
||||||
|
@ -72,7 +73,13 @@ def mock_tcp_and_udp(loop, udp_expected_addr=None, udp_replies=None, udp_delay_r
|
||||||
with mock.patch('socket.socket') as mock_socket:
|
with mock.patch('socket.socket') as mock_socket:
|
||||||
mock_sock = mock.Mock(spec=socket.socket)
|
mock_sock = mock.Mock(spec=socket.socket)
|
||||||
mock_sock.setsockopt = lambda *_: None
|
mock_sock.setsockopt = lambda *_: None
|
||||||
mock_sock.bind = lambda *_: None
|
|
||||||
|
def bind(*_):
|
||||||
|
if raise_oserror_on_bind:
|
||||||
|
raise OSError()
|
||||||
|
return
|
||||||
|
|
||||||
|
mock_sock.bind = bind
|
||||||
mock_sock.setblocking = lambda *_: None
|
mock_sock.setblocking = lambda *_: None
|
||||||
mock_sock.getsockname = lambda: "0.0.0.0"
|
mock_sock.getsockname = lambda: "0.0.0.0"
|
||||||
mock_sock.getpeername = lambda: ""
|
mock_sock.getpeername = lambda: ""
|
||||||
|
|
|
@ -28,6 +28,11 @@ class TestSSDP(AsyncioTestCase):
|
||||||
])
|
])
|
||||||
reply_packet = SSDPDatagram("OK", reply_args)
|
reply_packet = SSDPDatagram("OK", reply_args)
|
||||||
|
|
||||||
|
async def test_socket_setup_error(self):
|
||||||
|
with mock_tcp_and_udp(self.loop, raise_oserror_on_bind=True):
|
||||||
|
with self.assertRaises(UPnPError):
|
||||||
|
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
|
||||||
|
|
||||||
async def test_transport_not_connected_error(self):
|
async def test_transport_not_connected_error(self):
|
||||||
try:
|
try:
|
||||||
await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()])
|
await SSDPProtocol('', '').m_search('1.2.3.4', 2, [self.query_packet.as_dict()])
|
||||||
|
@ -35,6 +40,16 @@ class TestSSDP(AsyncioTestCase):
|
||||||
except UPnPError as err:
|
except UPnPError as err:
|
||||||
self.assertEqual(str(err), "SSDP transport not connected")
|
self.assertEqual(str(err), "SSDP transport not connected")
|
||||||
|
|
||||||
|
async def test_deadbeef_response(self):
|
||||||
|
replies = {
|
||||||
|
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): b'\xde\xad\xbe\xef'
|
||||||
|
}
|
||||||
|
sent = []
|
||||||
|
|
||||||
|
with mock_tcp_and_udp(self.loop, udp_replies=replies, udp_expected_addr="10.0.0.1", sent_udp_packets=sent):
|
||||||
|
with self.assertRaises(UPnPError):
|
||||||
|
await m_search("10.0.0.2", "10.0.0.1", self.successful_args, timeout=1, loop=self.loop)
|
||||||
|
|
||||||
async def test_m_search_reply_unicast(self):
|
async def test_m_search_reply_unicast(self):
|
||||||
replies = {
|
replies = {
|
||||||
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()
|
(self.query_packet.encode().encode(), ("10.0.0.1", 1900)): self.reply_packet.encode().encode()
|
||||||
|
|
Loading…
Reference in a new issue