This commit is contained in:
Jack Robison 2018-10-25 15:28:00 -04:00
parent 2f8b10f3c7
commit ab0b9707f2
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
11 changed files with 306 additions and 136 deletions

View file

@ -24,7 +24,8 @@ def get_help(command):
)
def main():
def main(argv=None, loop=None):
argv = argv or sys.argv
commands = [n for n in dir(UPnP) if hasattr(getattr(UPnP, n, None), "_cli")]
help_str = "\n".join(textwrap.wrap(
" | ".join(commands), 100, initial_indent=' ', subsequent_indent=' ', break_long_words=False
@ -40,7 +41,7 @@ def main():
"For help with a specific command:" \
" aioupnp help <command>\n" % (base_usage, help_str)
args = sys.argv[1:]
args = argv[1:]
if args[0] in ['help', '-h', '--help']:
if len(args) > 1:
if args[1] in commands:
@ -53,6 +54,7 @@ def main():
'gateway_address': '',
'lan_address': '',
'timeout': 30,
'unicast': False
}
options = OrderedDict()
@ -88,7 +90,7 @@ def main():
UPnP.run_cli(
command.replace('-', '_'), options, options.pop('lan_address'), options.pop('gateway_address'),
options.pop('timeout'), options.pop('interface'), kwargs
options.pop('timeout'), options.pop('interface'), options.pop('unicast'), kwargs, loop
)

View file

@ -181,7 +181,7 @@ class Gateway:
not_met = [
required for required in required_commands if required not in gateway._registered_commands
]
log.warning("found gateway %s at %s, but it does not implement required soap commands: %s",
log.debug("found gateway %s at %s, but it does not implement required soap commands: %s",
gateway.manufacturer_string, gateway.location, not_met)
ignored.add(datagram.location)
continue

View file

@ -60,18 +60,20 @@ class UPnP:
@classmethod
@cli
async def m_search(cls, lan_address: str = '', gateway_address: str = '', timeout: int = 1,
igd_args: OrderedDict = None, interface_name: str = 'default') -> Dict:
try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
assert gateway_address and lan_address
except Exception as err:
raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name,
str(err)))
igd_args: OrderedDict = None, unicast: bool = True, interface_name: str = 'default',
loop=None) -> Dict:
if not lan_address or not gateway_address:
try:
lan_address, gateway_address = cls.get_lan_and_gateway(lan_address, gateway_address, interface_name)
assert gateway_address and lan_address
except Exception as err:
raise UPnPError("failed to get lan and gateway addresses for interface \"%s\": %s" % (interface_name,
str(err)))
if not igd_args:
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout)
igd_args, datagram = await fuzzy_m_search(lan_address, gateway_address, timeout, loop, unicast=unicast)
else:
igd_args = OrderedDict(igd_args)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout)
datagram = await m_search(lan_address, gateway_address, igd_args, timeout, loop, unicast=unicast)
return {
'lan_address': lan_address,
'gateway_address': gateway_address,
@ -340,7 +342,7 @@ class UPnP:
@classmethod
def run_cli(cls, method, igd_args: OrderedDict, lan_address: str = '', gateway_address: str = '', timeout: int = 30,
interface_name: str = 'default', kwargs: dict = None) -> None:
interface_name: str = 'default', unicast: bool = True, kwargs: dict = None, loop=None) -> None:
"""
:param method: the command name
:param igd_args: ordered case sensitive M-SEARCH headers, if provided all headers to be used must be provided
@ -349,18 +351,19 @@ class UPnP:
:param timeout: timeout, in seconds
:param interface_name: name of the network interface, the default is aliased to 'default'
:param kwargs: keyword arguments for the command
:param loop: EventLoop, used for testing
"""
kwargs = kwargs or {}
igd_args = igd_args
timeout = int(timeout)
loop = asyncio.get_event_loop_policy().get_event_loop()
loop = loop or asyncio.get_event_loop_policy().get_event_loop()
fut: asyncio.Future = asyncio.Future()
async def wrapper(): # wrap the upnp setup and call of the command in a coroutine
if method == 'm_search': # if we're only m_searching don't do any device discovery
fn = lambda *_a, **_kw: cls.m_search(
lan_address, gateway_address, timeout, igd_args, interface_name
lan_address, gateway_address, timeout, igd_args, unicast, interface_name, loop
)
else: # automatically discover the gateway
try:

View file

@ -22,7 +22,7 @@ setup(
long_description_content_type='text/markdown',
url="https://github.com/lbryio/aioupnp",
license=__license__,
packages=find_packages(),
packages=find_packages(exclude=('tests',)),
entry_points={'console_scripts': console_scripts},
install_requires=[
'netifaces',

113
tests/__init__.py Normal file
View file

@ -0,0 +1,113 @@
import asyncio
import unittest
from unittest.case import _Outcome
try:
from asyncio.runners import _cancel_all_tasks
except ImportError:
# this is only available in py3.7
def _cancel_all_tasks(loop):
pass
class TestBase(unittest.TestCase):
# Implementation inspired by discussion:
# https://bugs.python.org/issue32972
async def asyncSetUp(self):
pass
async def asyncTearDown(self):
pass
async def doAsyncCleanups(self):
pass
def run(self, result=None):
orig_result = result
if result is None:
result = self.defaultTestResult()
startTestRun = getattr(result, 'startTestRun', None)
if startTestRun is not None:
startTestRun()
result.startTest(self)
testMethod = getattr(self, self._testMethodName)
if (getattr(self.__class__, "__unittest_skip__", False) or
getattr(testMethod, "__unittest_skip__", False)):
# If the class or method was skipped.
try:
skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
or getattr(testMethod, '__unittest_skip_why__', ''))
self._addSkip(result, self, skip_why)
finally:
result.stopTest(self)
return
expecting_failure_method = getattr(testMethod,
"__unittest_expecting_failure__", False)
expecting_failure_class = getattr(self,
"__unittest_expecting_failure__", False)
expecting_failure = expecting_failure_class or expecting_failure_method
outcome = _Outcome(result)
try:
self._outcome = outcome
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
loop.set_debug(True)
with outcome.testPartExecutor(self):
self.setUp()
loop.run_until_complete(self.asyncSetUp())
if outcome.success:
outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True):
possible_coroutine = testMethod()
if asyncio.iscoroutine(possible_coroutine):
loop.run_until_complete(possible_coroutine)
outcome.expecting_failure = False
with outcome.testPartExecutor(self):
loop.run_until_complete(self.asyncTearDown())
self.tearDown()
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
asyncio.set_event_loop(None)
loop.close()
self.doCleanups()
for test, reason in outcome.skipped:
self._addSkip(result, test, reason)
self._feedErrorsToResult(result, outcome.errors)
if outcome.success:
if expecting_failure:
if outcome.expectedFailure:
self._addExpectedFailure(result, outcome.expectedFailure)
else:
self._addUnexpectedSuccess(result)
else:
result.addSuccess(self)
return result
finally:
result.stopTest(self)
if orig_result is None:
stopTestRun = getattr(result, 'stopTestRun', None)
if stopTestRun is not None:
stopTestRun()
# explicitly break reference cycles:
# outcome.errors -> frame -> outcome -> outcome.errors
# outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
outcome.errors.clear()
outcome.expectedFailure = None
# clear the outcome, no more needed
self._outcome = None
def setUp(self):
self.loop = asyncio.get_event_loop_policy().get_event_loop()

View file

@ -75,3 +75,63 @@ def mock_tcp_endpoint_factory(loop, replies=None, delay_reply=0.0, sent_packets=
mock_socket.return_value = mock_sock
loop.create_connection = create_connection
yield
@contextlib.contextmanager
def mock_tcp_and_udp(loop, udp_expected_addr, udp_replies=None, udp_delay_reply=0.0, sent_udp_packets=None,
tcp_replies=None, tcp_delay_reply=0.0, tcp_sent_packets=None):
sent_udp_packets = sent_udp_packets if sent_udp_packets is not None else []
udp_replies = udp_replies or {}
tcp_sent_packets = tcp_sent_packets if tcp_sent_packets is not None else []
tcp_replies = tcp_replies or {}
async def create_connection(protocol_factory, host=None, port=None):
def write(p: asyncio.Protocol):
def _write(data):
tcp_sent_packets.append(data)
if data in tcp_replies:
loop.call_later(tcp_delay_reply, p.data_received, tcp_replies[data])
return _write
protocol = protocol_factory()
transport = asyncio.Transport(extra={'socket': mock.Mock(spec=socket.socket)})
transport.close = lambda: None
transport.write = write(protocol)
protocol.connection_made(transport)
return transport, protocol
async def create_datagram_endpoint(proto_lam, sock=None):
def sendto(p: asyncio.DatagramProtocol):
def _sendto(data, addr):
sent_udp_packets.append(data)
if (data, addr) in udp_replies:
loop.call_later(udp_delay_reply, p.datagram_received, udp_replies[(data, addr)],
(udp_expected_addr, 1900))
return _sendto
protocol = proto_lam()
transport = asyncio.DatagramTransport(extra={'socket': mock_sock})
transport.close = lambda: mock_sock.close()
mock_sock.sendto = sendto(protocol)
transport.sendto = mock_sock.sendto
protocol.connection_made(transport)
return transport, protocol
with mock.patch('socket.socket') as mock_socket:
mock_sock = mock.Mock(spec=socket.socket)
mock_sock.setsockopt = lambda *_: None
mock_sock.bind = lambda *_: None
mock_sock.setblocking = lambda *_: None
mock_sock.getsockname = lambda: "0.0.0.0"
mock_sock.getpeername = lambda: ""
mock_sock.close = lambda: None
mock_sock.type = socket.SOCK_DGRAM
mock_sock.fileno = lambda: 7
mock_socket.return_value = mock_sock
loop.create_datagram_endpoint = create_datagram_endpoint
loop.create_connection = create_connection
yield

View file

@ -1,113 +0,0 @@
import asyncio
import unittest
from unittest.case import _Outcome
try:
from asyncio.runners import _cancel_all_tasks
except ImportError:
# this is only available in py3.7
def _cancel_all_tasks(loop):
pass
class TestBase(unittest.TestCase):
# Implementation inspired by discussion:
# https://bugs.python.org/issue32972
async def asyncSetUp(self):
pass
async def asyncTearDown(self):
pass
async def doAsyncCleanups(self):
pass
def run(self, result=None):
orig_result = result
if result is None:
result = self.defaultTestResult()
startTestRun = getattr(result, 'startTestRun', None)
if startTestRun is not None:
startTestRun()
result.startTest(self)
testMethod = getattr(self, self._testMethodName)
if (getattr(self.__class__, "__unittest_skip__", False) or
getattr(testMethod, "__unittest_skip__", False)):
# If the class or method was skipped.
try:
skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
or getattr(testMethod, '__unittest_skip_why__', ''))
self._addSkip(result, self, skip_why)
finally:
result.stopTest(self)
return
expecting_failure_method = getattr(testMethod,
"__unittest_expecting_failure__", False)
expecting_failure_class = getattr(self,
"__unittest_expecting_failure__", False)
expecting_failure = expecting_failure_class or expecting_failure_method
outcome = _Outcome(result)
try:
self._outcome = outcome
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
loop.set_debug(True)
with outcome.testPartExecutor(self):
self.setUp()
loop.run_until_complete(self.asyncSetUp())
if outcome.success:
outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True):
possible_coroutine = testMethod()
if asyncio.iscoroutine(possible_coroutine):
loop.run_until_complete(possible_coroutine)
outcome.expecting_failure = False
with outcome.testPartExecutor(self):
loop.run_until_complete(self.asyncTearDown())
self.tearDown()
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
asyncio.set_event_loop(None)
loop.close()
self.doCleanups()
for test, reason in outcome.skipped:
self._addSkip(result, test, reason)
self._feedErrorsToResult(result, outcome.errors)
if outcome.success:
if expecting_failure:
if outcome.expectedFailure:
self._addExpectedFailure(result, outcome.expectedFailure)
else:
self._addUnexpectedSuccess(result)
else:
result.addSuccess(self)
return result
finally:
result.stopTest(self)
if orig_result is None:
stopTestRun = getattr(result, 'stopTestRun', None)
if stopTestRun is not None:
stopTestRun()
# explicitly break reference cycles:
# outcome.errors -> frame -> outcome -> outcome.errors
# outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
outcome.errors.clear()
outcome.expectedFailure = None
# clear the outcome, no more needed
self._outcome = None
def setUp(self):
self.loop = asyncio.get_event_loop_policy().get_event_loop()

View file

@ -1,7 +1,7 @@
from aioupnp.fault import UPnPError
from aioupnp.protocols.scpd import scpd_post, scpd_get
from . import TestBase
from .mocks import mock_tcp_endpoint_factory
from tests import TestBase
from tests.mocks import mock_tcp_endpoint_factory
class TestSCPDGet(TestBase):

View file

@ -4,8 +4,8 @@ from aioupnp.protocols.m_search_patterns import packet_generator
from aioupnp.serialization.ssdp import SSDPDatagram
from aioupnp.constants import SSDP_IP_ADDRESS
from aioupnp.protocols.ssdp import fuzzy_m_search, m_search
from . import TestBase
from .mocks import mock_datagram_endpoint_factory
from tests import TestBase
from tests.mocks import mock_datagram_endpoint_factory
class TestSSDP(TestBase):

104
tests/test_cli.py Normal file

File diff suppressed because one or more lines are too long

View file

@ -1,11 +1,12 @@
from aioupnp.fault import UPnPError
from aioupnp.protocols.scpd import scpd_post, scpd_get
from . import TestBase
from .mocks import mock_tcp_endpoint_factory
from tests import TestBase
from tests.mocks import mock_tcp_endpoint_factory
from collections import OrderedDict
from aioupnp.gateway import Gateway
from aioupnp.serialization.ssdp import SSDPDatagram
class TestDiscoverCommands(TestBase):
gateway_address = "10.0.0.1"
soap_port = 49152