diff --git a/lbrynet/conf.py b/lbrynet/conf.py index d4d93552a..744bec931 100644 --- a/lbrynet/conf.py +++ b/lbrynet/conf.py @@ -4,6 +4,7 @@ import sys import typing import logging import yaml +import decimal from argparse import ArgumentParser from contextlib import contextmanager from appdirs import user_data_dir, user_config_dir @@ -61,6 +62,14 @@ class Setting(typing.Generic[T]): def __set_name__(self, owner, name): self.name = name + @property + def cli_name(self): + return f"--{self.name.replace('_', '-')}" + + @property + def no_cli_name(self): + return f"--no-{self.name.replace('_', '-')}" + def __get__(self, obj: typing.Optional['BaseConfig'], owner) -> T: if obj is None: return self @@ -88,6 +97,14 @@ class Setting(typing.Generic[T]): def serialize(self, value): return value + def contribute_to_argparse(self, parser: ArgumentParser): + parser.add_argument( + self.cli_name, + help=self.doc, + metavar=self.metavar, + default=NOT_SET + ) + class String(Setting[str]): def validate(self, val): @@ -112,6 +129,21 @@ class Toggle(Setting[bool]): assert isinstance(val, bool), \ f"Setting '{self.name}' must be a true/false value." + def contribute_to_argparse(self, parser: ArgumentParser): + parser.add_argument( + self.cli_name, + help=self.doc, + action="store_true", + default=NOT_SET + ) + parser.add_argument( + self.no_cli_name, + help=f"Opposite of --{self.cli_name}", + dest=self.name, + action="store_false", + default=NOT_SET + ) + class Path(String): def __init__(self, doc: str, default: str = '', *args, **kwargs): @@ -127,13 +159,53 @@ class Path(String): class MaxKeyFee(Setting[dict]): def validate(self, value): - assert isinstance(value, dict), \ - f"Setting '{self.name}' must be of the format \"{'currency': 'USD', 'amount': 50.0}\"." - assert set(value) == {'currency', 'amount'}, \ - f"Setting '{self.name}' must contain a 'currency' and an 'amount' field." - currency = str(value["currency"]).upper() + assert isinstance(value, dict) and set(value) == {'currency', 'amount'}, \ + f"Setting '{self.name}' must be a dict like \"{{'amount': 50.0, 'currency': 'USD'}}\"." + if value["currency"] not in CURRENCIES: + raise InvalidCurrencyError(value["currency"]) + + @staticmethod + def _parse_list(l): + assert len(l) == 2, 'Max key fee is made up of two values: "AMOUNT CURRENCY".' + try: + amount = decimal.Decimal(l[0]) + except decimal.InvalidOperation: + raise AssertionError('First value in max key fee is a decimal: "AMOUNT CURRENCY"') + currency = str(l[1]).upper() if currency not in CURRENCIES: raise InvalidCurrencyError(currency) + return {'amount': amount, 'currency': currency} + + def deserialize(self, value): + if value is None: + return + if isinstance(value, dict): + return { + 'currency': value['currency'], + 'amount': decimal.Decimal(value['amount']), + } + if isinstance(value, str): + value = value.split() + if isinstance(value, list): + return self._parse_list(value) + raise AssertionError('Invalid max key fee.') + + def contribute_to_argparse(self, parser: ArgumentParser): + parser.add_argument( + self.cli_name, + help=self.doc, + nargs=2, + metavar=('AMOUNT', 'CURRENCY'), + default=NOT_SET + ) + parser.add_argument( + self.no_cli_name, + help=f"Disable maximum key fee check.", + dest=self.name, + const=None, + action="store_const", + default=NOT_SET + ) class Servers(Setting[list]): @@ -169,6 +241,14 @@ class Servers(Setting[list]): return [f"{host}:{port}" for host, port in value] return value + def contribute_to_argparse(self, parser: ArgumentParser): + parser.add_argument( + self.cli_name, + nargs="*", + help=self.doc, + default=NOT_SET + ) + class Strings(Setting[list]): @@ -205,7 +285,7 @@ class ArgumentAccess: def load(self, args): for setting in self.configuration.get_settings(): value = getattr(args, setting.name, NOT_SET) - if value not in (None, NOT_SET): + if value != NOT_SET: self.args[setting.name] = setting.deserialize(value) def __contains__(self, item: str): @@ -340,36 +420,9 @@ class BaseConfig: return conf @classmethod - def contribute_args(cls, parser: ArgumentParser): + def contribute_to_argparse(cls, parser: ArgumentParser): for setting in cls.get_settings(): - if isinstance(setting, Toggle): - parser.add_argument( - f"--{setting.name.replace('_', '-')}", - help=setting.doc, - action="store_true", - default=NOT_SET - ) - parser.add_argument( - f"--no-{setting.name.replace('_', '-')}", - help=f"Opposite of --{setting.name.replace('_', '-')}", - dest=setting.name, - action="store_false", - default=NOT_SET - ) - elif isinstance(setting, Servers): - parser.add_argument( - f"--{setting.name.replace('_', '-')}", - nargs="*", - help=setting.doc, - default=None - ) - else: - parser.add_argument( - f"--{setting.name.replace('_', '-')}", - help=setting.doc, - metavar=setting.metavar, - default=NOT_SET - ) + setting.contribute_to_argparse(parser) def set_arguments(self, args): self.arguments = ArgumentAccess(self, args) @@ -451,10 +504,7 @@ class Config(CLIConfig): ]) max_connections_per_stream = Integer("", 5) seek_head_blob_first = Toggle("", True) - # TODO: writing json on the cmd line is a pain, come up with a nicer - # parser for this data structure. maybe 'USD:25' max_key_fee = MaxKeyFee("", {'currency': 'USD', 'amount': 50.0}) - disable_max_key_fee = Toggle("", False) min_info_rate = Float("points/1000 infos", .02) min_valuable_hash_rate = Float("points/1000 infos", .05) min_valuable_info_rate = Float("points/1000 infos", .05) diff --git a/tests/unit/test_conf.py b/tests/unit/test_conf.py index fa0b51cb4..ff750fb11 100644 --- a/tests/unit/test_conf.py +++ b/tests/unit/test_conf.py @@ -46,7 +46,7 @@ class ConfigurationTests(unittest.TestCase): def test_arguments(self): parser = argparse.ArgumentParser() - TestConfig.contribute_args(parser) + TestConfig.contribute_to_argparse(parser) args = parser.parse_args([]) c = TestConfig.create_from_arguments(args) @@ -180,7 +180,7 @@ class ConfigurationTests(unittest.TestCase): ) self.assertEqual(c.servers, [('localhost', 5566)]) - def test_max_key_fee(self): + def test_max_key_fee_from_yaml(self): with tempfile.TemporaryDirectory() as temp_dir: config = os.path.join(temp_dir, 'settings.yml') with open(config, 'w') as fd: @@ -196,3 +196,22 @@ class ConfigurationTests(unittest.TestCase): c.max_key_fee = {'currency': 'BTC', 'amount': 1} with open(config, 'r') as fd: self.assertEqual(fd.read(), 'max_key_fee:\n amount: 1\n currency: BTC\n') + + def test_max_key_fee_from_args(self): + parser = argparse.ArgumentParser() + Config.contribute_to_argparse(parser) + + # default + args = parser.parse_args([]) + c = Config.create_from_arguments(args) + self.assertEqual(c.max_key_fee, {'amount': 50.0, 'currency': 'USD'}) + + # disabled + args = parser.parse_args(['--no-max-key-fee']) + c = Config.create_from_arguments(args) + self.assertEqual(c.max_key_fee, None) + + # set + args = parser.parse_args(['--max-key-fee', '1.0', 'BTC']) + c = Config.create_from_arguments(args) + self.assertEqual(c.max_key_fee, {'amount': 1.0, 'currency': 'BTC'})