From f3ee6603de9b4365b30b331694c903b972b4c4a9 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Wed, 3 Jun 2020 13:55:20 -0400 Subject: [PATCH] improve allowed_origin request handling --- lbry/extras/daemon/daemon.py | 6 ++--- lbry/extras/daemon/security.py | 24 +++++++++++++++++++ .../lbrynet_daemon/test_allowed_origin.py | 20 +++++++++++++++- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index 2c121ff8c..7508c1b5e 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -47,7 +47,7 @@ from lbry.extras.daemon.componentmanager import ComponentManager from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder from lbry.extras.daemon import comment_client from lbry.extras.daemon.undecorated import undecorated -from lbry.extras.daemon.security import is_request_allowed +from lbry.extras.daemon.security import ensure_request_allowed from lbry.file_analysis import VideoFileAnalyzer from lbry.schema.claim import Claim from lbry.schema.url import URL @@ -567,9 +567,7 @@ class Daemon(metaclass=JSONRPCServerType): log.info("finished shutting down") async def handle_old_jsonrpc(self, request): - if is_request_allowed(request, self.conf): - log.warning("API request from origin '%s' is not allowed", request.headers.get('Origin')) - raise web.HTTPForbidden() + ensure_request_allowed(request, self.conf) data = await request.json() params = data.get('params', {}) include_protobuf = params.pop('include_protobuf', False) if isinstance(params, dict) else False diff --git a/lbry/extras/daemon/security.py b/lbry/extras/daemon/security.py index 18a8326ec..80fe0ea93 100644 --- a/lbry/extras/daemon/security.py +++ b/lbry/extras/daemon/security.py @@ -1,3 +1,27 @@ +import logging +from aiohttp import web + +log = logging.getLogger(__name__) + + +def ensure_request_allowed(request, conf): + if is_request_allowed(request, conf): + return + if conf.allowed_origin: + log.warning( + "API requests with Origin '%s' are not allowed, " + "configuration 'allowed_origin' limits requests to: '%s'", + request.headers.get('Origin'), conf.allowed_origin + ) + else: + log.warning( + "API requests with Origin '%s' are not allowed, " + "update configuration 'allowed_origin' to enable this origin.", + request.headers.get('Origin') + ) + raise web.HTTPForbidden() + + def is_request_allowed(request, conf) -> bool: origin = request.headers.get('Origin', 'null') if origin == 'null' or conf.allowed_origin in ('*', origin): diff --git a/tests/unit/lbrynet_daemon/test_allowed_origin.py b/tests/unit/lbrynet_daemon/test_allowed_origin.py index 7d679a3dc..531863ce4 100644 --- a/tests/unit/lbrynet_daemon/test_allowed_origin.py +++ b/tests/unit/lbrynet_daemon/test_allowed_origin.py @@ -1,10 +1,11 @@ import unittest from aiohttp.test_utils import make_mocked_request as request +from aiohttp.web import HTTPForbidden from lbry.testcase import AsyncioTestCase from lbry.conf import Config -from lbry.extras.daemon.security import is_request_allowed as allowed +from lbry.extras.daemon.security import is_request_allowed as allowed, ensure_request_allowed as ensure class TestAllowedOrigin(unittest.TestCase): @@ -34,3 +35,20 @@ class TestAllowedOrigin(unittest.TestCase): self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'null'}), conf)) self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'localhost'}), conf)) self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'hackers.com'}), conf)) + + def test_ensure_default(self): + conf = Config() + ensure(request('GET', '/'), conf) + with self.assertLogs() as log: + with self.assertRaises(HTTPForbidden): + ensure(request('GET', '/', headers={'Origin': 'localhost'}), conf) + self.assertIn("'localhost' are not allowed", log.output[0]) + + def test_ensure_specific(self): + conf = Config(allowed_origin='localhost') + ensure(request('GET', '/', headers={'Origin': 'localhost'}), conf) + with self.assertLogs() as log: + with self.assertRaises(HTTPForbidden): + ensure(request('GET', '/', headers={'Origin': 'hackers.com'}), conf) + self.assertIn("'hackers.com' are not allowed", log.output[0]) + self.assertIn("'allowed_origin' limits requests to: 'localhost'", log.output[0])