forked from LBRYCommunity/lbry-sdk
allow requests that match allowed-origin
This commit is contained in:
parent
f4217d5593
commit
e2db99f7ab
2 changed files with 27 additions and 0 deletions
|
@ -267,6 +267,20 @@ class AuthJSONRPCServer(AuthorizedBase):
|
||||||
if conf.settings['api_host'] == '0.0.0.0':
|
if conf.settings['api_host'] == '0.0.0.0':
|
||||||
return True
|
return True
|
||||||
server, port = self.get_server_port(source)
|
server, port = self.get_server_port(source)
|
||||||
|
return self._check_server_port(server, port)
|
||||||
|
|
||||||
|
def _check_server_port(self, server, port):
|
||||||
|
api = (conf.settings['api_host'], conf.settings['api_port'])
|
||||||
|
return (server, port) == api or self._is_from_allowed_origin(server, port)
|
||||||
|
|
||||||
|
def _is_from_allowed_origin(self, server, port):
|
||||||
|
allowed_origin = conf.settings['allowed_origin']
|
||||||
|
if not allowed_origin:
|
||||||
|
return False
|
||||||
|
if allowed_origin == '*':
|
||||||
|
return True
|
||||||
|
allowed_server, allowed_port = self.get_server_port(allowed_origin)
|
||||||
|
return (allowed_server, allowed_port) == (server, port)
|
||||||
return (
|
return (
|
||||||
server == conf.settings['api_host'] and
|
server == conf.settings['api_host'] and
|
||||||
port == conf.settings['api_port'])
|
port == conf.settings['api_port'])
|
||||||
|
|
|
@ -51,3 +51,16 @@ class AuthJSONRPCServerTest(unittest.TestCase):
|
||||||
request.getHeader = mock.Mock(return_value='http://the_api:1111?settings')
|
request.getHeader = mock.Mock(return_value='http://the_api:1111?settings')
|
||||||
self.assertTrue(self.server._check_header_source(request, 'Referer'))
|
self.assertTrue(self.server._check_header_source(request, 'Referer'))
|
||||||
request.getHeader.assert_called_with('Referer')
|
request.getHeader.assert_called_with('Referer')
|
||||||
|
|
||||||
|
def test_request_is_allowed_when_matching_allowed_origin_setting(self):
|
||||||
|
mock_conf_settings(self, {'allowed_origin': 'http://example.com:1234'})
|
||||||
|
request = mock.Mock(['getHeader'])
|
||||||
|
request.getHeader = mock.Mock(return_value='http://example.com:1234')
|
||||||
|
self.assertTrue(self.server._check_header_source(request, 'Origin'))
|
||||||
|
|
||||||
|
def test_request_is_rejected_when_not_matching_allowed_origin_setting(self):
|
||||||
|
mock_conf_settings(self, {'allowed_origin': 'http://example.com:1234'})
|
||||||
|
request = mock.Mock(['getHeader'])
|
||||||
|
# note the ports don't match
|
||||||
|
request.getHeader = mock.Mock(return_value='http://example.com:1235')
|
||||||
|
self.assertFalse(self.server._check_header_source(request, 'Origin'))
|
||||||
|
|
Loading…
Add table
Reference in a new issue