forked from LBRYCommunity/lbry-sdk
Merge pull request #404 from lbryio/match-allow-origin
Allow requests that match allowed-origin
This commit is contained in:
commit
5bf75ef139
2 changed files with 27 additions and 0 deletions
|
@ -267,6 +267,20 @@ class AuthJSONRPCServer(AuthorizedBase):
|
|||
if conf.settings['api_host'] == '0.0.0.0':
|
||||
return True
|
||||
server, port = self.get_server_port(source)
|
||||
return self._check_server_port(server, port)
|
||||
|
||||
def _check_server_port(self, server, port):
|
||||
api = (conf.settings['api_host'], conf.settings['api_port'])
|
||||
return (server, port) == api or self._is_from_allowed_origin(server, port)
|
||||
|
||||
def _is_from_allowed_origin(self, server, port):
|
||||
allowed_origin = conf.settings['allowed_origin']
|
||||
if not allowed_origin:
|
||||
return False
|
||||
if allowed_origin == '*':
|
||||
return True
|
||||
allowed_server, allowed_port = self.get_server_port(allowed_origin)
|
||||
return (allowed_server, allowed_port) == (server, port)
|
||||
return (
|
||||
server == conf.settings['api_host'] and
|
||||
port == conf.settings['api_port'])
|
||||
|
|
|
@ -51,3 +51,16 @@ class AuthJSONRPCServerTest(unittest.TestCase):
|
|||
request.getHeader = mock.Mock(return_value='http://the_api:1111?settings')
|
||||
self.assertTrue(self.server._check_header_source(request, 'Referer'))
|
||||
request.getHeader.assert_called_with('Referer')
|
||||
|
||||
def test_request_is_allowed_when_matching_allowed_origin_setting(self):
|
||||
mock_conf_settings(self, {'allowed_origin': 'http://example.com:1234'})
|
||||
request = mock.Mock(['getHeader'])
|
||||
request.getHeader = mock.Mock(return_value='http://example.com:1234')
|
||||
self.assertTrue(self.server._check_header_source(request, 'Origin'))
|
||||
|
||||
def test_request_is_rejected_when_not_matching_allowed_origin_setting(self):
|
||||
mock_conf_settings(self, {'allowed_origin': 'http://example.com:1234'})
|
||||
request = mock.Mock(['getHeader'])
|
||||
# note the ports don't match
|
||||
request.getHeader = mock.Mock(return_value='http://example.com:1235')
|
||||
self.assertFalse(self.server._check_header_source(request, 'Origin'))
|
||||
|
|
Loading…
Reference in a new issue