--conf option improvements (#1455)

* Conf file improvements
* Add test for loading config file
This commit is contained in:
Tomasz Kopeć 2018-10-10 19:44:51 +02:00 committed by Lex Berezhny
parent fc41af5889
commit 9e2e53147e
2 changed files with 37 additions and 25 deletions

View file

@ -49,6 +49,8 @@ DEFAULT_DHT_NODES = [
('lbrynet3.lbry.io', 4444) ('lbrynet3.lbry.io', 4444)
] ]
DEFAULT_SETTINGS_FILENAME = 'daemon_settings.yml'
settings_decoders = { settings_decoders = {
'.json': json.loads, '.json': json.loads,
'.yml': yaml.load '.yml': yaml.load
@ -483,16 +485,15 @@ class Config:
} }
def save_conf_file_settings(self): def save_conf_file_settings(self):
if conf_file: path = conf_file or self.get_valid_settings_filename() or DEFAULT_SETTINGS_FILENAME
path = conf_file
else:
path = self.get_conf_filename()
# reverse the conversions done after loading the settings from the conf # reverse the conversions done after loading the settings from the conf
# file # file
rev = self._convert_conf_file_lists_reverse(self._data[TYPE_PERSISTED]) rev = self._convert_conf_file_lists_reverse(self._data[TYPE_PERSISTED])
ext = os.path.splitext(path)[1] ext = os.path.splitext(path)[1]
encoder = settings_encoders.get(ext, False) encoder = settings_encoders.get(ext, False)
assert encoder is not False, 'Unknown settings format %s' % ext if not encoder:
raise ValueError('Unknown settings format: {}. Available formats: {}'
.format(ext, list(settings_encoders.keys())))
with open(path, 'w') as settings_file: with open(path, 'w') as settings_file:
settings_file.write(encoder(rev)) settings_file.write(encoder(rev))
@ -521,25 +522,25 @@ class Config:
settings.node_id = settings.get_node_id() settings.node_id = settings.get_node_id()
def load_conf_file_settings(self): def load_conf_file_settings(self):
if conf_file: path = conf_file or self.get_valid_settings_filename()
path = conf_file self._read_conf_file(path)
else: # initialize members depending on config file
path = self.get_conf_filename() self.initialize_post_conf_load()
def _read_conf_file(self, path):
if not path:
return
ext = os.path.splitext(path)[1] ext = os.path.splitext(path)[1]
decoder = settings_decoders.get(ext, False) decoder = settings_decoders.get(ext, False)
assert decoder is not False, 'Unknown settings format %s' % ext if not decoder:
try: raise ValueError('Unknown settings format: {}. Available formats: {}'
.format(ext, list(settings_decoders.keys())))
with open(path, 'r') as settings_file: with open(path, 'r') as settings_file:
data = settings_file.read() data = settings_file.read()
decoded = self._fix_old_conf_file_settings(decoder(data)) decoded = self._fix_old_conf_file_settings(decoder(data))
log.info('Loaded settings file: %s', path) log.info('Loaded settings file: %s', path)
self._validate_settings(decoded) self._validate_settings(decoded)
self._data[TYPE_PERSISTED].update(self._convert_conf_file_lists(decoded)) self._data[TYPE_PERSISTED].update(self._convert_conf_file_lists(decoded))
except (IOError, OSError) as err:
log.info('%s: Failed to update settings from %s', err, path)
#initialize members depending on config file
self.initialize_post_conf_load()
def _fix_old_conf_file_settings(self, settings_dict): def _fix_old_conf_file_settings(self, settings_dict):
if 'API_INTERFACE' in settings_dict: if 'API_INTERFACE' in settings_dict:
@ -590,7 +591,7 @@ class Config:
def get_db_revision_filename(self): def get_db_revision_filename(self):
return os.path.join(self.ensure_data_dir(), self['DB_REVISION_FILE_NAME']) return os.path.join(self.ensure_data_dir(), self['DB_REVISION_FILE_NAME'])
def get_conf_filename(self): def get_valid_settings_filename(self):
data_dir = self.ensure_data_dir() data_dir = self.ensure_data_dir()
yml_path = os.path.join(data_dir, 'daemon_settings.yml') yml_path = os.path.join(data_dir, 'daemon_settings.yml')
json_path = os.path.join(data_dir, 'daemon_settings.json') json_path = os.path.join(data_dir, 'daemon_settings.json')
@ -598,8 +599,6 @@ class Config:
return yml_path return yml_path
elif os.path.isfile(json_path): elif os.path.isfile(json_path):
return json_path return json_path
else:
return yml_path
def get_installation_id(self): def get_installation_id(self):
install_id_filename = os.path.join(self.ensure_data_dir(), "install_id") install_id_filename = os.path.join(self.ensure_data_dir(), "install_id")

View file

@ -107,3 +107,16 @@ class SettingsTest(unittest.TestCase):
conf_decoded_new = decoder(conf_entry_new) conf_decoded_new = decoder(conf_entry_new)
self.assertEqual(conf_decoded, conf_decoded_new) self.assertEqual(conf_decoded, conf_decoded_new)
def test_load_file(self):
settings = self.get_mock_config_instance()
# nonexistent file
conf.conf_file = 'monkey.yml'
with self.assertRaises(FileNotFoundError):
settings.load_conf_file_settings()
# invalid extensions
for filename in ('monkey.yymmll', 'monkey'):
conf.conf_file = filename
with self.assertRaises(ValueError):
settings.load_conf_file_settings()