Merge pull request #285 from lbryio/more-flexible-origin-check

Allow 0.0.0.0 for api interface
This commit is contained in:
Jack Robison 2016-11-30 14:04:52 -05:00 committed by GitHub
commit fe53cc97d9
4 changed files with 97 additions and 9 deletions

View file

@ -170,7 +170,12 @@ ENVIRONMENT = Env(
# #
# TODO: writing json on the cmd line is a pain, come up with a nicer # TODO: writing json on the cmd line is a pain, come up with a nicer
# parser for this data structure. (maybe MAX_KEY_FEE=USD:25 # parser for this data structure. (maybe MAX_KEY_FEE=USD:25
max_key_fee=(json.loads, {'USD': {'amount': 25.0, 'address': ''}}) max_key_fee=(json.loads, {'USD': {'amount': 25.0, 'address': ''}}),
# Changing this value is not-advised as it could potentially
# expose the lbrynet daemon to the outside world which would
# give an attacker access to your wallet and you could lose
# all of your credits.
API_INTERFACE=(str, "localhost"),
) )
@ -205,7 +210,6 @@ class ApplicationSettings(Settings):
self.LOG_FILE_NAME = "lbrynet.log" self.LOG_FILE_NAME = "lbrynet.log"
self.LOG_POST_URL = "https://lbry.io/log-upload" self.LOG_POST_URL = "https://lbry.io/log-upload"
self.CRYPTSD_FILE_EXTENSION = ".cryptsd" self.CRYPTSD_FILE_EXTENSION = ".cryptsd"
self.API_INTERFACE = "localhost"
self.API_ADDRESS = "lbryapi" self.API_ADDRESS = "lbryapi"
self.ICON_PATH = "icons" if platform is WINDOWS else "app.icns" self.ICON_PATH = "icons" if platform is WINDOWS else "app.icns"
self.APP_NAME = "LBRY" self.APP_NAME = "LBRY"

View file

@ -1,4 +1,6 @@
import logging import logging
import urlparse
from decimal import Decimal from decimal import Decimal
from zope.interface import implements from zope.interface import implements
from twisted.web import server, resource from twisted.web import server, resource
@ -219,16 +221,37 @@ class AuthJSONRPCServer(AuthorizedBase):
request.finish() request.finish()
def _check_headers(self, request): def _check_headers(self, request):
origin = request.getHeader("Origin") return (
referer = request.getHeader("Referer") self._check_header_source(request, 'Origin') and
if origin not in [None, settings.ORIGIN]: self._check_header_source(request, 'Referer'))
log.warning("Attempted api call from %s", origin)
return False def _check_header_source(self, request, header):
if referer is not None and not referer.startswith(settings.REFERER): """Check if the source of the request is allowed based on the header value."""
log.warning("Attempted api call from %s", referer) source = request.getHeader(header)
if not self._check_source_of_request(source):
log.warning("Attempted api call from invalid %s: %s", header, source)
return False return False
return True return True
def _check_source_of_request(self, source):
if source is None:
return True
if settings.API_INTERFACE == '0.0.0.0':
return True
server, port = self.get_server_port(source)
return (
server == settings.API_INTERFACE and
port == settings.api_port)
def get_server_port(self, origin):
parsed = urlparse.urlparse(origin)
server_port = parsed.netloc.split(':')
assert len(server_port) <= 2
if len(server_port) == 2:
return server_port[0], int(server_port[1])
else:
return server_port[0], 80
def _check_function_path(self, function_path): def _check_function_path(self, function_path):
if function_path not in self.callable_methods: if function_path not in self.callable_methods:
log.warning("Unknown method: %s", function_path) log.warning("Unknown method: %s", function_path)

View file

@ -0,0 +1,61 @@
import mock
import requests
from twisted.trial import unittest
from lbrynet import conf
from lbrynet.lbrynet_daemon.auth import server
class AuthJSONRPCServerTest(unittest.TestCase):
# TODO: move to using a base class for tests
# and add useful general utilities like this
# onto it.
def setUp(self):
self.server = server.AuthJSONRPCServer(False)
def _set_setting(self, attr, value):
original = getattr(conf.settings, attr)
setattr(conf.settings, attr, value)
self.addCleanup(lambda: setattr(conf.settings, attr, original))
def test_get_server_port(self):
self.assertSequenceEqual(
('example.com', 80), self.server.get_server_port('http://example.com'))
self.assertSequenceEqual(
('example.com', 1234), self.server.get_server_port('http://example.com:1234'))
def test_foreign_origin_is_rejected(self):
request = mock.Mock(['getHeader'])
request.getHeader = mock.Mock(return_value='http://example.com')
self.assertFalse(self.server._check_header_source(request, 'Origin'))
def test_wrong_port_is_rejected(self):
self._set_setting('api_port', 1234)
request = mock.Mock(['getHeader'])
request.getHeader = mock.Mock(return_value='http://localhost:9999')
self.assertFalse(self.server._check_header_source(request, 'Origin'))
def test_matching_origin_is_allowed(self):
self._set_setting('API_INTERFACE', 'example.com')
self._set_setting('api_port', 1234)
request = mock.Mock(['getHeader'])
request.getHeader = mock.Mock(return_value='http://example.com:1234')
self.assertTrue(self.server._check_header_source(request, 'Origin'))
def test_any_origin_is_allowed(self):
self._set_setting('API_INTERFACE', '0.0.0.0')
self._set_setting('api_port', 80)
request = mock.Mock(['getHeader'])
request.getHeader = mock.Mock(return_value='http://example.com')
self.assertTrue(self.server._check_header_source(request, 'Origin'))
request = mock.Mock(['getHeader'])
request.getHeader = mock.Mock(return_value='http://another-example.com')
self.assertTrue(self.server._check_header_source(request, 'Origin'))
def test_matching_referer_is_allowed(self):
self._set_setting('API_INTERFACE', 'the_api')
self._set_setting('api_port', 1111)
request = mock.Mock(['getHeader'])
request.getHeader = mock.Mock(return_value='http://the_api:1111?settings')
self.assertTrue(self.server._check_header_source(request, 'Referer'))
request.getHeader.assert_called_with('Referer')