Merge pull request #2966 from lbryio/check-origin

add `allowed_origin` to config, by default no longer allow any requests which pass Origin in header
This commit is contained in:
Lex Berezhny 2020-06-03 14:39:57 -04:00 committed by GitHub
commit 3c8bec61d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 0 deletions

View file

@ -625,6 +625,9 @@ class Config(CLIConfig):
previous_names=['upload_log', 'upload_log', 'share_debug_info']
)
track_bandwidth = Toggle("Track bandwidth usage", True)
allowed_origin = String(
"Allowed `Origin` header value for API request (sent by browser), use * to allow "
"all hosts; default is to only allow API requests with no `Origin` value.", "")
# media server
streaming_server = String('Host name and port to serve streaming media over range requests',

View file

@ -47,6 +47,7 @@ from lbry.extras.daemon.componentmanager import ComponentManager
from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder
from lbry.extras.daemon import comment_client
from lbry.extras.daemon.undecorated import undecorated
from lbry.extras.daemon.security import ensure_request_allowed
from lbry.file_analysis import VideoFileAnalyzer
from lbry.schema.claim import Claim
from lbry.schema.url import URL
@ -566,6 +567,7 @@ class Daemon(metaclass=JSONRPCServerType):
log.info("finished shutting down")
async def handle_old_jsonrpc(self, request):
ensure_request_allowed(request, self.conf)
data = await request.json()
params = data.get('params', {})
include_protobuf = params.pop('include_protobuf', False) if isinstance(params, dict) else False

View file

@ -0,0 +1,31 @@
import logging
from aiohttp import web
log = logging.getLogger(__name__)
def ensure_request_allowed(request, conf):
if is_request_allowed(request, conf):
return
if conf.allowed_origin:
log.warning(
"API requests with Origin '%s' are not allowed, "
"configuration 'allowed_origin' limits requests to: '%s'",
request.headers.get('Origin'), conf.allowed_origin
)
else:
log.warning(
"API requests with Origin '%s' are not allowed, "
"update configuration 'allowed_origin' to enable this origin.",
request.headers.get('Origin')
)
raise web.HTTPForbidden()
def is_request_allowed(request, conf) -> bool:
origin = request.headers.get('Origin')
return (
origin is None or
origin == conf.allowed_origin or
conf.allowed_origin == '*'
)

View file

@ -0,0 +1,53 @@
import unittest
from aiohttp.test_utils import make_mocked_request as request
from aiohttp.web import HTTPForbidden
from lbry.testcase import AsyncioTestCase
from lbry.conf import Config
from lbry.extras.daemon.security import is_request_allowed as allowed, ensure_request_allowed as ensure
class TestAllowedOrigin(unittest.TestCase):
def test_allowed_origin_default(self):
conf = Config()
# lack of Origin is always allowed
self.assertTrue(allowed(request('GET', '/'), conf))
# deny all other Origins
self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'null'}), conf))
self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'localhost'}), conf))
self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'hackers.com'}), conf))
def test_allowed_origin_star(self):
conf = Config(allowed_origin='*')
# everything is allowed
self.assertTrue(allowed(request('GET', '/'), conf))
self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'null'}), conf))
self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'localhost'}), conf))
self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'hackers.com'}), conf))
def test_allowed_origin_specified(self):
conf = Config(allowed_origin='localhost')
# no origin and only localhost are allowed
self.assertTrue(allowed(request('GET', '/'), conf))
self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'localhost'}), conf))
self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'null'}), conf))
self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'hackers.com'}), conf))
def test_ensure_default(self):
conf = Config()
ensure(request('GET', '/'), conf)
with self.assertLogs() as log:
with self.assertRaises(HTTPForbidden):
ensure(request('GET', '/', headers={'Origin': 'localhost'}), conf)
self.assertIn("'localhost' are not allowed", log.output[0])
def test_ensure_specific(self):
conf = Config(allowed_origin='localhost')
ensure(request('GET', '/', headers={'Origin': 'localhost'}), conf)
with self.assertLogs() as log:
with self.assertRaises(HTTPForbidden):
ensure(request('GET', '/', headers={'Origin': 'hackers.com'}), conf)
self.assertIn("'hackers.com' are not allowed", log.output[0])
self.assertIn("'allowed_origin' limits requests to: 'localhost'", log.output[0])