diff --git a/lbrynet/lbrynet_daemon/auth/server.py b/lbrynet/lbrynet_daemon/auth/server.py index 551013fb5..f601f7421 100644 --- a/lbrynet/lbrynet_daemon/auth/server.py +++ b/lbrynet/lbrynet_daemon/auth/server.py @@ -267,6 +267,20 @@ class AuthJSONRPCServer(AuthorizedBase): if conf.settings['api_host'] == '0.0.0.0': return True server, port = self.get_server_port(source) + return self._check_server_port(server, port) + + def _check_server_port(self, server, port): + api = (conf.settings['api_host'], conf.settings['api_port']) + return (server, port) == api or self._is_from_allowed_origin(server, port) + + def _is_from_allowed_origin(self, server, port): + allowed_origin = conf.settings['allowed_origin'] + if not allowed_origin: + return False + if allowed_origin == '*': + return True + allowed_server, allowed_port = self.get_server_port(allowed_origin) + return (allowed_server, allowed_port) == (server, port) return ( server == conf.settings['api_host'] and port == conf.settings['api_port']) diff --git a/tests/unit/lbrynet_daemon/auth/test_server.py b/tests/unit/lbrynet_daemon/auth/test_server.py index a293b7473..1e833d245 100644 --- a/tests/unit/lbrynet_daemon/auth/test_server.py +++ b/tests/unit/lbrynet_daemon/auth/test_server.py @@ -51,3 +51,16 @@ class AuthJSONRPCServerTest(unittest.TestCase): request.getHeader = mock.Mock(return_value='http://the_api:1111?settings') self.assertTrue(self.server._check_header_source(request, 'Referer')) request.getHeader.assert_called_with('Referer') + + def test_request_is_allowed_when_matching_allowed_origin_setting(self): + mock_conf_settings(self, {'allowed_origin': 'http://example.com:1234'}) + request = mock.Mock(['getHeader']) + request.getHeader = mock.Mock(return_value='http://example.com:1234') + self.assertTrue(self.server._check_header_source(request, 'Origin')) + + def test_request_is_rejected_when_not_matching_allowed_origin_setting(self): + mock_conf_settings(self, {'allowed_origin': 'http://example.com:1234'}) + request = mock.Mock(['getHeader']) + # note the ports don't match + request.getHeader = mock.Mock(return_value='http://example.com:1235') + self.assertFalse(self.server._check_header_source(request, 'Origin'))