allow requests that match allowed-origin

This commit is contained in:
Job Evers-Meltzer 2017-01-12 09:49:57 -06:00 committed by Jack Robison
parent f4217d5593
commit e2db99f7ab
2 changed files with 27 additions and 0 deletions

View file

@ -267,6 +267,20 @@ class AuthJSONRPCServer(AuthorizedBase):
if conf.settings['api_host'] == '0.0.0.0': if conf.settings['api_host'] == '0.0.0.0':
return True return True
server, port = self.get_server_port(source) server, port = self.get_server_port(source)
return self._check_server_port(server, port)
def _check_server_port(self, server, port):
api = (conf.settings['api_host'], conf.settings['api_port'])
return (server, port) == api or self._is_from_allowed_origin(server, port)
def _is_from_allowed_origin(self, server, port):
allowed_origin = conf.settings['allowed_origin']
if not allowed_origin:
return False
if allowed_origin == '*':
return True
allowed_server, allowed_port = self.get_server_port(allowed_origin)
return (allowed_server, allowed_port) == (server, port)
return ( return (
server == conf.settings['api_host'] and server == conf.settings['api_host'] and
port == conf.settings['api_port']) port == conf.settings['api_port'])

View file

@ -51,3 +51,16 @@ class AuthJSONRPCServerTest(unittest.TestCase):
request.getHeader = mock.Mock(return_value='http://the_api:1111?settings') request.getHeader = mock.Mock(return_value='http://the_api:1111?settings')
self.assertTrue(self.server._check_header_source(request, 'Referer')) self.assertTrue(self.server._check_header_source(request, 'Referer'))
request.getHeader.assert_called_with('Referer') request.getHeader.assert_called_with('Referer')
def test_request_is_allowed_when_matching_allowed_origin_setting(self):
mock_conf_settings(self, {'allowed_origin': 'http://example.com:1234'})
request = mock.Mock(['getHeader'])
request.getHeader = mock.Mock(return_value='http://example.com:1234')
self.assertTrue(self.server._check_header_source(request, 'Origin'))
def test_request_is_rejected_when_not_matching_allowed_origin_setting(self):
mock_conf_settings(self, {'allowed_origin': 'http://example.com:1234'})
request = mock.Mock(['getHeader'])
# note the ports don't match
request.getHeader = mock.Mock(return_value='http://example.com:1235')
self.assertFalse(self.server._check_header_source(request, 'Origin'))