100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
|
# Copyright (c) 2017, Neil Booth
|
||
|
#
|
||
|
# All rights reserved.
|
||
|
#
|
||
|
# See the file "LICENCE" for information about the copyright
|
||
|
# and warranty status of this software.
|
||
|
|
||
|
'''Class for server environment configuration and defaults.'''
|
||
|
|
||
|
|
||
|
from os import environ
|
||
|
|
||
|
from torba.server.util import class_logger
|
||
|
|
||
|
|
||
|
class EnvBase:
    '''Wraps environment configuration.

    The class methods read individual settings from the process
    environment and convert them to Python values; a malformed or
    missing setting is reported by raising EnvBase.Error.
    '''

    class Error(Exception):
        '''Raised when an environment setting is missing or invalid.'''

    def __init__(self):
        '''Read the base settings from the environment.

        Sets the attributes logger, allow_root (ALLOW_ROOT), host (HOST),
        rpc_host (RPC_HOST) and loop_policy (EVENT_LOOP_POLICY).
        '''
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.allow_root = self.boolean('ALLOW_ROOT', False)
        self.host = self.default('HOST', 'localhost')
        self.rpc_host = self.default('RPC_HOST', 'localhost')
        self.loop_policy = self.event_loop_policy()

    @classmethod
    def default(cls, envvar, default):
        '''Return the value of envvar, or default if it is not set.'''
        return environ.get(envvar, default)

    @classmethod
    def boolean(cls, envvar, default):
        '''Return envvar interpreted as a boolean.

        Any value that is non-empty after stripping whitespace is True;
        note this includes strings such as "0" and "false".  An empty or
        all-whitespace value is False.  If envvar is not set the result
        is bool(default).
        '''
        # Map the default onto a string so that the set and unset cases
        # share the same truthiness test below.
        default = 'Yes' if default else ''
        return bool(cls.default(envvar, default).strip())

    @classmethod
    def required(cls, envvar):
        '''Return the value of envvar.

        Raises Error if envvar is not set.  An empty value counts as set.
        '''
        value = environ.get(envvar)
        if value is None:
            raise cls.Error('required envvar {} not set'.format(envvar))
        return value

    @classmethod
    def integer(cls, envvar, default):
        '''Return the value of envvar converted with int().

        Returns default, unconverted, if envvar is not set.  Raises Error
        if the value cannot be converted; the underlying ValueError is
        attached as the cause.
        '''
        value = environ.get(envvar)
        if value is None:
            return default
        try:
            return int(value)
        except ValueError as e:
            # Chain explicitly, as custom() does, so that the reason for
            # the failure is preserved for the caller.
            raise cls.Error('cannot convert envvar {} value {} to an integer'
                            .format(envvar, value)) from e

    @classmethod
    def custom(cls, envvar, default, parse):
        '''Return parse(value) where value is the setting of envvar.

        Returns default, without calling parse, if envvar is not set.
        Raises Error, chained to the original exception, if parse raises.
        '''
        value = environ.get(envvar)
        if value is None:
            return default
        try:
            return parse(value)
        except Exception as e:
            raise cls.Error('cannot parse envvar {} value {}'
                            .format(envvar, value)) from e

    @classmethod
    def obsolete(cls, envvars):
        '''Raise Error if any of envvars is set to a non-empty value.'''
        bad = [envvar for envvar in envvars if environ.get(envvar)]
        if bad:
            raise cls.Error('remove obsolete environment variables {}'
                            .format(bad))

    def event_loop_policy(self):
        '''Return the event loop policy selected by EVENT_LOOP_POLICY.

        Returns None if the variable is not set, and a
        uvloop.EventLoopPolicy instance if it is "uvloop".  Any other
        value raises Error.
        '''
        policy = self.default('EVENT_LOOP_POLICY', None)
        if policy is None:
            return None
        if policy == 'uvloop':
            # Imported here so uvloop is only needed when requested.
            import uvloop
            return uvloop.EventLoopPolicy()
        raise self.Error('unknown event loop policy "{}"'.format(policy))

    def cs_host(self, *, for_rpc):
        '''Returns the 'host' argument to pass to asyncio's create_server
        call.  The result can be a single host name string, a list of
        host name strings, or an empty string to bind to all interfaces.

        If for_rpc is True the host to use for the RPC server is returned.
        Otherwise the host to use for SSL/TCP servers is returned.
        '''
        host = self.rpc_host if for_rpc else self.host
        # The setting is a comma-separated list of host names.
        result = [part.strip() for part in host.split(',')]
        if len(result) == 1:
            result = result[0]
        # An empty result indicates all interfaces, which we do not
        # permit for an RPC server.
        if for_rpc and not result:
            result = 'localhost'
        return result
|