improve allowed_origin request handling
This commit is contained in:
parent
ee0aabda1d
commit
f3ee6603de
3 changed files with 45 additions and 5 deletions
|
@ -47,7 +47,7 @@ from lbry.extras.daemon.componentmanager import ComponentManager
|
||||||
from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder
|
from lbry.extras.daemon.json_response_encoder import JSONResponseEncoder
|
||||||
from lbry.extras.daemon import comment_client
|
from lbry.extras.daemon import comment_client
|
||||||
from lbry.extras.daemon.undecorated import undecorated
|
from lbry.extras.daemon.undecorated import undecorated
|
||||||
from lbry.extras.daemon.security import is_request_allowed
|
from lbry.extras.daemon.security import ensure_request_allowed
|
||||||
from lbry.file_analysis import VideoFileAnalyzer
|
from lbry.file_analysis import VideoFileAnalyzer
|
||||||
from lbry.schema.claim import Claim
|
from lbry.schema.claim import Claim
|
||||||
from lbry.schema.url import URL
|
from lbry.schema.url import URL
|
||||||
|
@ -567,9 +567,7 @@ class Daemon(metaclass=JSONRPCServerType):
|
||||||
log.info("finished shutting down")
|
log.info("finished shutting down")
|
||||||
|
|
||||||
async def handle_old_jsonrpc(self, request):
|
async def handle_old_jsonrpc(self, request):
|
||||||
if is_request_allowed(request, self.conf):
|
ensure_request_allowed(request, self.conf)
|
||||||
log.warning("API request from origin '%s' is not allowed", request.headers.get('Origin'))
|
|
||||||
raise web.HTTPForbidden()
|
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
params = data.get('params', {})
|
params = data.get('params', {})
|
||||||
include_protobuf = params.pop('include_protobuf', False) if isinstance(params, dict) else False
|
include_protobuf = params.pop('include_protobuf', False) if isinstance(params, dict) else False
|
||||||
|
|
|
@ -1,3 +1,27 @@
|
||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_request_allowed(request, conf):
|
||||||
|
if is_request_allowed(request, conf):
|
||||||
|
return
|
||||||
|
if conf.allowed_origin:
|
||||||
|
log.warning(
|
||||||
|
"API requests with Origin '%s' are not allowed, "
|
||||||
|
"configuration 'allowed_origin' limits requests to: '%s'",
|
||||||
|
request.headers.get('Origin'), conf.allowed_origin
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log.warning(
|
||||||
|
"API requests with Origin '%s' are not allowed, "
|
||||||
|
"update configuration 'allowed_origin' to enable this origin.",
|
||||||
|
request.headers.get('Origin')
|
||||||
|
)
|
||||||
|
raise web.HTTPForbidden()
|
||||||
|
|
||||||
|
|
||||||
def is_request_allowed(request, conf) -> bool:
|
def is_request_allowed(request, conf) -> bool:
|
||||||
origin = request.headers.get('Origin', 'null')
|
origin = request.headers.get('Origin', 'null')
|
||||||
if origin == 'null' or conf.allowed_origin in ('*', origin):
|
if origin == 'null' or conf.allowed_origin in ('*', origin):
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from aiohttp.test_utils import make_mocked_request as request
|
from aiohttp.test_utils import make_mocked_request as request
|
||||||
|
from aiohttp.web import HTTPForbidden
|
||||||
|
|
||||||
from lbry.testcase import AsyncioTestCase
|
from lbry.testcase import AsyncioTestCase
|
||||||
from lbry.conf import Config
|
from lbry.conf import Config
|
||||||
from lbry.extras.daemon.security import is_request_allowed as allowed
|
from lbry.extras.daemon.security import is_request_allowed as allowed, ensure_request_allowed as ensure
|
||||||
|
|
||||||
|
|
||||||
class TestAllowedOrigin(unittest.TestCase):
|
class TestAllowedOrigin(unittest.TestCase):
|
||||||
|
@ -34,3 +35,20 @@ class TestAllowedOrigin(unittest.TestCase):
|
||||||
self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'null'}), conf))
|
self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'null'}), conf))
|
||||||
self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'localhost'}), conf))
|
self.assertTrue(allowed(request('GET', '/', headers={'Origin': 'localhost'}), conf))
|
||||||
self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'hackers.com'}), conf))
|
self.assertFalse(allowed(request('GET', '/', headers={'Origin': 'hackers.com'}), conf))
|
||||||
|
|
||||||
|
def test_ensure_default(self):
|
||||||
|
conf = Config()
|
||||||
|
ensure(request('GET', '/'), conf)
|
||||||
|
with self.assertLogs() as log:
|
||||||
|
with self.assertRaises(HTTPForbidden):
|
||||||
|
ensure(request('GET', '/', headers={'Origin': 'localhost'}), conf)
|
||||||
|
self.assertIn("'localhost' are not allowed", log.output[0])
|
||||||
|
|
||||||
|
def test_ensure_specific(self):
|
||||||
|
conf = Config(allowed_origin='localhost')
|
||||||
|
ensure(request('GET', '/', headers={'Origin': 'localhost'}), conf)
|
||||||
|
with self.assertLogs() as log:
|
||||||
|
with self.assertRaises(HTTPForbidden):
|
||||||
|
ensure(request('GET', '/', headers={'Origin': 'hackers.com'}), conf)
|
||||||
|
self.assertIn("'hackers.com' are not allowed", log.output[0])
|
||||||
|
self.assertIn("'allowed_origin' limits requests to: 'localhost'", log.output[0])
|
||||||
|
|
Loading…
Add table
Reference in a new issue