diff --git a/lbry/conf.py b/lbry/conf.py index e9b293933..deb9222f2 100644 --- a/lbry/conf.py +++ b/lbry/conf.py @@ -625,6 +625,9 @@ class Config(CLIConfig): previous_names=['upload_log', 'upload_log', 'share_debug_info'] ) track_bandwidth = Toggle("Track bandwidth usage", True) + allowed_origin = String( + "Allowed `Origin` header value for API request (sent by browser), use * to allow " + "all hosts; default is to only allow API requests with no `Origin` value.", "") # media server streaming_server = String('Host name and port to serve streaming media over range requests', diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index 756ff6bc7..7508c1b5e 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -47,6 +47,7 @@ from lbry.extras.daemon.componentmanager import ComponentManager from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder from lbry.extras.daemon import comment_client from lbry.extras.daemon.undecorated import undecorated +from lbry.extras.daemon.security import ensure_request_allowed from lbry.file_analysis import VideoFileAnalyzer from lbry.schema.claim import Claim from lbry.schema.url import URL @@ -566,6 +567,7 @@ class Daemon(metaclass=JSONRPCServerType): log.info("finished shutting down") async def handle_old_jsonrpc(self, request): + ensure_request_allowed(request, self.conf) data = await request.json() params = data.get('params', {}) include_protobuf = params.pop('include_protobuf', False) if isinstance(params, dict) else False diff --git a/lbry/extras/daemon/security.py b/lbry/extras/daemon/security.py new file mode 100644 index 000000000..c6ecde6d8 --- /dev/null +++ b/lbry/extras/daemon/security.py @@ -0,0 +1,31 @@ +import logging +from aiohttp import web + +log = logging.getLogger(__name__) + + +def ensure_request_allowed(request, conf): + if is_request_allowed(request, conf): + return + if conf.allowed_origin: + log.warning( + "API requests with Origin '%s' are not allowed, " + "configuration 'allowed_origin' limits requests to: '%s'", + request.headers.get('Origin'), conf.allowed_origin + ) + else: + log.warning( + "API requests with Origin '%s' are not allowed, " + "update configuration 'allowed_origin' to enable this origin.", + request.headers.get('Origin') + ) + raise web.HTTPForbidden() + + +def is_request_allowed(request, conf) -> bool: + origin = request.headers.get('Origin') + return ( + origin is None or + origin == conf.allowed_origin or + conf.allowed_origin == '*' + ) diff --git a/tests/unit/lbrynet_daemon/test_allowed_origin.py b/tests/unit/lbrynet_daemon/test_allowed_origin.py new file mode 100644 index 000000000..230210202 --- /dev/null +++ b/tests/unit/lbrynet_daemon/test_allowed_origin.py @@ -0,0 +1,53 @@ +import unittest + +from aiohttp.test_utils import make_mocked_request as request +from aiohttp.web import HTTPForbidden + +from lbry.testcase import AsyncioTestCase +from lbry.conf import Config +from lbry.extras.daemon.security import is_request_allowed as allowed, ensure_request_allowed as ensure + + +class TestAllowedOrigin(unittest.TestCase): + + def test_allowed_origin_default(self): + conf = Config() + # lack of Origin is always allowed + self.assertTrue(allowed(request('GET', '/'), conf)) + # deny all other Origins + self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'null'}), conf)) + self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'localhost'}), conf)) + self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'hackers.com'}), conf)) + + def test_allowed_origin_star(self): + conf = Config(allowed_origin='*') + # everything is allowed + self.assertTrue(allowed(request('GET', '/'), conf)) + self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'null'}), conf)) + self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'localhost'}), conf)) + self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'hackers.com'}), conf)) + + def test_allowed_origin_specified(self): + conf = Config(allowed_origin='localhost') + # no origin and only localhost are allowed + self.assertTrue(allowed(request('GET', '/'), conf)) + self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'localhost'}), conf)) + self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'null'}), conf)) + self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'hackers.com'}), conf)) + + def test_ensure_default(self): + conf = Config() + ensure(request('GET', '/'), conf) + with self.assertLogs() as log: + with self.assertRaises(HTTPForbidden): + ensure(request('GET', '/', headers={'Origin': 'localhost'}), conf) + self.assertIn("'localhost' are not allowed", log.output[0]) + + def test_ensure_specific(self): + conf = Config(allowed_origin='localhost') + ensure(request('GET', '/', headers={'Origin': 'localhost'}), conf) + with self.assertLogs() as log: + with self.assertRaises(HTTPForbidden): + ensure(request('GET', '/', headers={'Origin': 'hackers.com'}), conf) + self.assertIn("'hackers.com' are not allowed", log.output[0]) + self.assertIn("'allowed_origin' limits requests to: 'localhost'", log.output[0])