refactored --max-key-fee to be more ergonomic

This commit is contained in:
Lex Berezhny 2019-01-25 22:13:43 -05:00
parent e01c73a7f8
commit be0bd3bdea
2 changed files with 109 additions and 40 deletions

View file

@ -4,6 +4,7 @@ import sys
import typing import typing
import logging import logging
import yaml import yaml
import decimal
from argparse import ArgumentParser from argparse import ArgumentParser
from contextlib import contextmanager from contextlib import contextmanager
from appdirs import user_data_dir, user_config_dir from appdirs import user_data_dir, user_config_dir
@ -61,6 +62,14 @@ class Setting(typing.Generic[T]):
def __set_name__(self, owner, name): def __set_name__(self, owner, name):
self.name = name self.name = name
@property
def cli_name(self):
return f"--{self.name.replace('_', '-')}"
@property
def no_cli_name(self):
return f"--no-{self.name.replace('_', '-')}"
def __get__(self, obj: typing.Optional['BaseConfig'], owner) -> T: def __get__(self, obj: typing.Optional['BaseConfig'], owner) -> T:
if obj is None: if obj is None:
return self return self
@ -88,6 +97,14 @@ class Setting(typing.Generic[T]):
def serialize(self, value): def serialize(self, value):
return value return value
def contribute_to_argparse(self, parser: ArgumentParser):
parser.add_argument(
self.cli_name,
help=self.doc,
metavar=self.metavar,
default=NOT_SET
)
class String(Setting[str]): class String(Setting[str]):
def validate(self, val): def validate(self, val):
@ -112,6 +129,21 @@ class Toggle(Setting[bool]):
assert isinstance(val, bool), \ assert isinstance(val, bool), \
f"Setting '{self.name}' must be a true/false value." f"Setting '{self.name}' must be a true/false value."
def contribute_to_argparse(self, parser: ArgumentParser):
parser.add_argument(
self.cli_name,
help=self.doc,
action="store_true",
default=NOT_SET
)
parser.add_argument(
self.no_cli_name,
help=f"Opposite of --{self.cli_name}",
dest=self.name,
action="store_false",
default=NOT_SET
)
class Path(String): class Path(String):
def __init__(self, doc: str, default: str = '', *args, **kwargs): def __init__(self, doc: str, default: str = '', *args, **kwargs):
@ -127,13 +159,53 @@ class Path(String):
class MaxKeyFee(Setting[dict]): class MaxKeyFee(Setting[dict]):
def validate(self, value): def validate(self, value):
assert isinstance(value, dict), \ assert isinstance(value, dict) and set(value) == {'currency', 'amount'}, \
f"Setting '{self.name}' must be of the format \"{'currency': 'USD', 'amount': 50.0}\"." f"Setting '{self.name}' must be a dict like \"{{'amount': 50.0, 'currency': 'USD'}}\"."
assert set(value) == {'currency', 'amount'}, \ if value["currency"] not in CURRENCIES:
f"Setting '{self.name}' must contain a 'currency' and an 'amount' field." raise InvalidCurrencyError(value["currency"])
currency = str(value["currency"]).upper()
@staticmethod
def _parse_list(l):
assert len(l) == 2, 'Max key fee is made up of two values: "AMOUNT CURRENCY".'
try:
amount = decimal.Decimal(l[0])
except decimal.InvalidOperation:
raise AssertionError('First value in max key fee is a decimal: "AMOUNT CURRENCY"')
currency = str(l[1]).upper()
if currency not in CURRENCIES: if currency not in CURRENCIES:
raise InvalidCurrencyError(currency) raise InvalidCurrencyError(currency)
return {'amount': amount, 'currency': currency}
def deserialize(self, value):
if value is None:
return
if isinstance(value, dict):
return {
'currency': value['currency'],
'amount': decimal.Decimal(value['amount']),
}
if isinstance(value, str):
value = value.split()
if isinstance(value, list):
return self._parse_list(value)
raise AssertionError('Invalid max key fee.')
def contribute_to_argparse(self, parser: ArgumentParser):
parser.add_argument(
self.cli_name,
help=self.doc,
nargs=2,
metavar=('AMOUNT', 'CURRENCY'),
default=NOT_SET
)
parser.add_argument(
self.no_cli_name,
help=f"Disable maximum key fee check.",
dest=self.name,
const=None,
action="store_const",
default=NOT_SET
)
class Servers(Setting[list]): class Servers(Setting[list]):
@ -169,6 +241,14 @@ class Servers(Setting[list]):
return [f"{host}:{port}" for host, port in value] return [f"{host}:{port}" for host, port in value]
return value return value
def contribute_to_argparse(self, parser: ArgumentParser):
parser.add_argument(
self.cli_name,
nargs="*",
help=self.doc,
default=NOT_SET
)
class Strings(Setting[list]): class Strings(Setting[list]):
@ -205,7 +285,7 @@ class ArgumentAccess:
def load(self, args): def load(self, args):
for setting in self.configuration.get_settings(): for setting in self.configuration.get_settings():
value = getattr(args, setting.name, NOT_SET) value = getattr(args, setting.name, NOT_SET)
if value not in (None, NOT_SET): if value != NOT_SET:
self.args[setting.name] = setting.deserialize(value) self.args[setting.name] = setting.deserialize(value)
def __contains__(self, item: str): def __contains__(self, item: str):
@ -340,36 +420,9 @@ class BaseConfig:
return conf return conf
@classmethod @classmethod
def contribute_args(cls, parser: ArgumentParser): def contribute_to_argparse(cls, parser: ArgumentParser):
for setting in cls.get_settings(): for setting in cls.get_settings():
if isinstance(setting, Toggle): setting.contribute_to_argparse(parser)
parser.add_argument(
f"--{setting.name.replace('_', '-')}",
help=setting.doc,
action="store_true",
default=NOT_SET
)
parser.add_argument(
f"--no-{setting.name.replace('_', '-')}",
help=f"Opposite of --{setting.name.replace('_', '-')}",
dest=setting.name,
action="store_false",
default=NOT_SET
)
elif isinstance(setting, Servers):
parser.add_argument(
f"--{setting.name.replace('_', '-')}",
nargs="*",
help=setting.doc,
default=None
)
else:
parser.add_argument(
f"--{setting.name.replace('_', '-')}",
help=setting.doc,
metavar=setting.metavar,
default=NOT_SET
)
def set_arguments(self, args): def set_arguments(self, args):
self.arguments = ArgumentAccess(self, args) self.arguments = ArgumentAccess(self, args)
@ -451,10 +504,7 @@ class Config(CLIConfig):
]) ])
max_connections_per_stream = Integer("", 5) max_connections_per_stream = Integer("", 5)
seek_head_blob_first = Toggle("", True) seek_head_blob_first = Toggle("", True)
# TODO: writing json on the cmd line is a pain, come up with a nicer
# parser for this data structure. maybe 'USD:25'
max_key_fee = MaxKeyFee("", {'currency': 'USD', 'amount': 50.0}) max_key_fee = MaxKeyFee("", {'currency': 'USD', 'amount': 50.0})
disable_max_key_fee = Toggle("", False)
min_info_rate = Float("points/1000 infos", .02) min_info_rate = Float("points/1000 infos", .02)
min_valuable_hash_rate = Float("points/1000 infos", .05) min_valuable_hash_rate = Float("points/1000 infos", .05)
min_valuable_info_rate = Float("points/1000 infos", .05) min_valuable_info_rate = Float("points/1000 infos", .05)

View file

@ -46,7 +46,7 @@ class ConfigurationTests(unittest.TestCase):
def test_arguments(self): def test_arguments(self):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
TestConfig.contribute_args(parser) TestConfig.contribute_to_argparse(parser)
args = parser.parse_args([]) args = parser.parse_args([])
c = TestConfig.create_from_arguments(args) c = TestConfig.create_from_arguments(args)
@ -180,7 +180,7 @@ class ConfigurationTests(unittest.TestCase):
) )
self.assertEqual(c.servers, [('localhost', 5566)]) self.assertEqual(c.servers, [('localhost', 5566)])
def test_max_key_fee(self): def test_max_key_fee_from_yaml(self):
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
config = os.path.join(temp_dir, 'settings.yml') config = os.path.join(temp_dir, 'settings.yml')
with open(config, 'w') as fd: with open(config, 'w') as fd:
@ -196,3 +196,22 @@ class ConfigurationTests(unittest.TestCase):
c.max_key_fee = {'currency': 'BTC', 'amount': 1} c.max_key_fee = {'currency': 'BTC', 'amount': 1}
with open(config, 'r') as fd: with open(config, 'r') as fd:
self.assertEqual(fd.read(), 'max_key_fee:\n amount: 1\n currency: BTC\n') self.assertEqual(fd.read(), 'max_key_fee:\n amount: 1\n currency: BTC\n')
def test_max_key_fee_from_args(self):
parser = argparse.ArgumentParser()
Config.contribute_to_argparse(parser)
# default
args = parser.parse_args([])
c = Config.create_from_arguments(args)
self.assertEqual(c.max_key_fee, {'amount': 50.0, 'currency': 'USD'})
# disabled
args = parser.parse_args(['--no-max-key-fee'])
c = Config.create_from_arguments(args)
self.assertEqual(c.max_key_fee, None)
# set
args = parser.parse_args(['--max-key-fee', '1.0', 'BTC'])
c = Config.create_from_arguments(args)
self.assertEqual(c.max_key_fee, {'amount': 1.0, 'currency': 'BTC'})