From d8265add2d144b2f265d25acc0ad6903e48ab0c9 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Fri, 21 Jun 2019 02:15:59 -0400 Subject: [PATCH] added coin selection strategy for both prefer_confirmed and only_confirmed --- lbry/lbry/conf.py | 21 ++--- lbry/tests/unit/test_conf.py | 38 ++++----- .../client_tests/unit/test_coinselection.py | 41 +++++---- torba/torba/client/coinselection.py | 84 ++++++++++++------- 4 files changed, 102 insertions(+), 82 deletions(-) diff --git a/lbry/lbry/conf.py b/lbry/lbry/conf.py index e779e1e43..7a4ac17ba 100644 --- a/lbry/lbry/conf.py +++ b/lbry/lbry/conf.py @@ -193,19 +193,19 @@ class MaxKeyFee(Setting[dict]): ) -class OneOfString(String): - def __init__(self, valid_values: typing.List[str], *args, **kwargs): - super().__init__(*args, **kwargs) +class StringChoice(String): + def __init__(self, doc: str, valid_values: typing.List[str], default: str, *args, **kwargs): + super().__init__(doc, default, *args, **kwargs) + if not valid_values: + raise ValueError("No valid values provided") + if default not in valid_values: + raise ValueError(f"Default value must be one of: {', '.join(valid_values)}") self.valid_values = valid_values - if not self.valid_values: - raise ValueError(f"No valid values provided") - if self.default not in self.valid_values: - raise ValueError(f"Default value must be one of: " + ', '.join(self.valid_values)) def validate(self, val): super().validate(val) if val not in self.valid_values: - raise ValueError(f"Setting '{self.name}' must be one of: " + ', '.join(self.valid_values)) + raise ValueError(f"Setting '{self.name}' value must be one of: {', '.join(self.valid_values)}") class ListSetting(Setting[list]): @@ -578,8 +578,9 @@ class Config(CLIConfig): streaming_get = Toggle("Enable the /get endpoint for the streaming media server. " "Disable to prevent new streams from being added.", True) - coin_selection_strategy = OneOfString(STRATEGIES, "Strategy to use when selecting UTXOs for a transaction", - "branch_and_bound") + coin_selection_strategy = StringChoice( + "Strategy to use when selecting UTXOs for a transaction", + STRATEGIES, "standard") @property def streaming_host(self): diff --git a/lbry/tests/unit/test_conf.py b/lbry/tests/unit/test_conf.py index 1d8bc290f..dd5d62721 100644 --- a/lbry/tests/unit/test_conf.py +++ b/lbry/tests/unit/test_conf.py @@ -4,7 +4,7 @@ import types import tempfile import unittest import argparse -from lbry.conf import Config, BaseConfig, String, Integer, Toggle, Servers, Strings, OneOfString, NOT_SET +from lbry.conf import Config, BaseConfig, String, Integer, Toggle, Servers, Strings, StringChoice, NOT_SET from lbry.error import InvalidCurrencyError @@ -15,7 +15,7 @@ class TestConfig(BaseConfig): test_true_toggle = Toggle('toggle help', True) servers = Servers('servers help', [('localhost', 80)]) strings = Strings('cheese', ['string']) - one_of_string = OneOfString(["a", "b", "c"], "one of string", "a") + string_choice = StringChoice("one of string", ["a", "b", "c"], "a") class ConfigurationTests(unittest.TestCase): @@ -227,31 +227,21 @@ class ConfigurationTests(unittest.TestCase): c = Config.create_from_arguments(args) self.assertEqual(c.max_key_fee, {'amount': 1.0, 'currency': 'BTC'}) - def test_one_of_string(self): - with self.assertRaises(ValueError): - no_vaid_values = OneOfString([], "no valid values", None) - - with self.assertRaises(ValueError): - default_none = OneOfString(["a"], "invalid default", None) - with self.assertRaises(ValueError): - invalid_default = OneOfString(["a"], "invalid default", "b") - - valid_default = OneOfString(["a"], "valid default", "a") - - self.assertEqual("hello", OneOfString(["hello"], "valid default", "hello").default) + def test_string_choice(self): + with self.assertRaisesRegex(ValueError, "No valid values provided"): + StringChoice("no valid values", [], "") + with self.assertRaisesRegex(ValueError, "Default value must be one of"): + StringChoice("invalid default", ["a"], "b") c = TestConfig() - with self.assertRaises(ValueError): - c.one_of_string = "d" + self.assertEqual("a", c.string_choice) # default + c.string_choice = "b" + self.assertEqual("b", c.string_choice) + with self.assertRaisesRegex(ValueError, "Setting 'string_choice' value must be one of"): + c.string_choice = "d" parser = argparse.ArgumentParser() TestConfig.contribute_to_argparse(parser) - - args = parser.parse_args(["--one-of-string=b"]) + args = parser.parse_args(['--string-choice', 'c']) c = TestConfig.create_from_arguments(args) - self.assertEqual("b", c.one_of_string) - - # with self.assertRaises(ValueError): - # args = parser.parse_args(["--one-of-string=arst"]) - # c = TestConfig.create_from_arguments(args) - # print("here") + self.assertEqual("c", c.string_choice) diff --git a/torba/tests/client_tests/unit/test_coinselection.py b/torba/tests/client_tests/unit/test_coinselection.py index 8a5d38c8f..51ad220a5 100644 --- a/torba/tests/client_tests/unit/test_coinselection.py +++ b/torba/tests/client_tests/unit/test_coinselection.py @@ -13,7 +13,7 @@ NULL_HASH = b'\x00'*32 def search(*args, **kwargs): - selection = CoinSelector(*args, **kwargs).branch_and_bound() + selection = CoinSelector(*args[1:], **kwargs).select(args[0], 'branch_and_bound') return [o.txo.amount for o in selection] if selection else selection @@ -37,17 +37,17 @@ class BaseSelectionTestCase(AsyncioTestCase): class TestCoinSelectionTests(BaseSelectionTestCase): def test_empty_coins(self): - self.assertEqual(CoinSelector([], 0, 0).select(), []) + self.assertEqual(CoinSelector(0, 0).select([]), []) def test_skip_binary_search_if_total_not_enough(self): fee = utxo(CENT).get_estimator(self.ledger).fee big_pool = self.estimates(utxo(CENT+fee) for _ in range(100)) - selector = CoinSelector(big_pool, 101 * CENT, 0) - self.assertEqual(selector.select(), []) + selector = CoinSelector(101 * CENT, 0) + self.assertEqual(selector.select(big_pool), []) self.assertEqual(selector.tries, 0) # Never tried. # check happy path - selector = CoinSelector(big_pool, 100 * CENT, 0) - self.assertEqual(len(selector.select()), 100) + selector = CoinSelector(100 * CENT, 0) + self.assertEqual(len(selector.select(big_pool)), 100) self.assertEqual(selector.tries, 201) def test_exact_match(self): @@ -57,8 +57,8 @@ class TestCoinSelectionTests(BaseSelectionTestCase): utxo(CENT), utxo(CENT - fee) ) - selector = CoinSelector(utxo_pool, CENT, 0) - match = selector.select() + selector = CoinSelector(CENT, 0) + match = selector.select(utxo_pool) self.assertEqual([CENT + fee], [c.txo.amount for c in match]) self.assertTrue(selector.exact_match) @@ -68,8 +68,8 @@ class TestCoinSelectionTests(BaseSelectionTestCase): utxo(3 * CENT), utxo(4 * CENT) ) - selector = CoinSelector(utxo_pool, CENT, 0, '\x00') - match = selector.select() + selector = CoinSelector(CENT, 0, '\x00') + match = selector.select(utxo_pool) self.assertEqual([2 * CENT], [c.txo.amount for c in match]) self.assertFalse(selector.exact_match) @@ -81,20 +81,27 @@ class TestCoinSelectionTests(BaseSelectionTestCase): utxo(5*CENT), utxo(10*CENT), ) - selector = CoinSelector(utxo_pool, 3*CENT, 0) - match = selector.select() + selector = CoinSelector(3*CENT, 0) + match = selector.select(utxo_pool) self.assertEqual([5*CENT], [c.txo.amount for c in match]) - def test_prefer_confirmed_strategy(self): + def test_confirmed_strategies(self): utxo_pool = self.estimates( utxo(11*CENT, height=5), utxo(11*CENT, height=0), utxo(11*CENT, height=-2), utxo(11*CENT, height=5), ) - selector = CoinSelector(utxo_pool, 20*CENT, 0) - match = selector.select("prefer_confirmed") + + match = CoinSelector(20*CENT, 0).select(utxo_pool, "only_confirmed") self.assertEqual([5, 5], [c.txo.tx_ref.height for c in match]) + match = CoinSelector(25*CENT, 0).select(utxo_pool, "only_confirmed") + self.assertEqual([], [c.txo.tx_ref.height for c in match]) + + match = CoinSelector(20*CENT, 0).select(utxo_pool, "prefer_confirmed") + self.assertEqual([5, 5], [c.txo.tx_ref.height for c in match]) + match = CoinSelector(25*CENT, 0, '\x00').select(utxo_pool, "prefer_confirmed") + self.assertEqual([5, 0, -2], [c.txo.tx_ref.height for c in match]) class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase): @@ -160,8 +167,8 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase): # Iteration exhaustion test utxo_pool, target = self.make_hard_case(17) - selector = CoinSelector(utxo_pool, target, 0) - self.assertEqual(selector.branch_and_bound(), []) + selector = CoinSelector(target, 0) + self.assertEqual(selector.select(utxo_pool, 'branch_and_bound'), []) self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust utxo_pool, target = self.make_hard_case(14) self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust diff --git a/torba/torba/client/coinselection.py b/torba/torba/client/coinselection.py index 61cff6909..953915b5f 100644 --- a/torba/torba/client/coinselection.py +++ b/torba/torba/client/coinselection.py @@ -7,6 +7,7 @@ MAXIMUM_TRIES = 100000 STRATEGIES = [] + def strategy(method): STRATEGIES.append(method.__name__) return method @@ -14,50 +15,67 @@ def strategy(method): class CoinSelector: - def __init__(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], - target: int, cost_of_change: int, seed: str = None) -> None: - self.txos = txos + def __init__(self, target: int, cost_of_change: int, seed: str = None) -> None: self.target = target self.cost_of_change = cost_of_change self.exact_match = False self.tries = 0 - self.available = sum(c.effective_amount for c in self.txos) self.random = Random(seed) if seed is not None: self.random.seed(seed, version=1) - def select(self, strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: - if not self.txos: + def select( + self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], + strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + if not txos: return [] - if self.target > self.available: + available = sum(c.effective_amount for c in txos) + if self.target > available: return [] - if strategy_name is not None: - return getattr(self, strategy_name)() + return getattr(self, strategy_name or "standard")(txos, available) + + @strategy + def prefer_confirmed( + self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], available: int + ) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: return ( - self.branch_and_bound() or - self.closest_match() or - self.random_draw() + self.only_confirmed(txos, available) or + self.standard(txos, available) ) @strategy - def prefer_confirmed(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: - self.txos = [t for t in self.txos if t.txo.tx_ref and t.txo.tx_ref.height > 0] or self.txos - self.available = sum(c.effective_amount for c in self.txos) + def only_confirmed( + self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], _ + ) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + confirmed = [t for t in txos if t.txo.tx_ref and t.txo.tx_ref.height > 0] + if not confirmed: + return [] + confirmed_available = sum(c.effective_amount for c in confirmed) + if self.target > confirmed_available: + return [] + return self.standard(confirmed, confirmed_available) + + @strategy + def standard( + self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], available: int + ) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: return ( - self.branch_and_bound() or - self.closest_match() or - self.random_draw() + self.branch_and_bound(txos, available) or + self.closest_match(txos, available) or + self.random_draw(txos, available) ) @strategy - def branch_and_bound(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + def branch_and_bound( + self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], available: int + ) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: # see bitcoin implementation for more info: # https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp - self.txos.sort(reverse=True) + txos.sort(reverse=True) current_value = 0 - current_available_value = self.available + current_available_value = available current_selection: List[bool] = [] best_waste = self.cost_of_change best_selection: List[bool] = [] @@ -79,19 +97,19 @@ class CoinSelector: if backtrack: while current_selection and not current_selection[-1]: current_selection.pop() - current_available_value += self.txos[len(current_selection)].effective_amount + current_available_value += txos[len(current_selection)].effective_amount if not current_selection: break current_selection[-1] = False - utxo = self.txos[len(current_selection) - 1] + utxo = txos[len(current_selection) - 1] current_value -= utxo.effective_amount else: - utxo = self.txos[len(current_selection)] + utxo = txos[len(current_selection)] current_available_value -= utxo.effective_amount - previous_utxo = self.txos[len(current_selection) - 1] if current_selection else None + previous_utxo = txos[len(current_selection) - 1] if current_selection else None if current_selection and not current_selection[-1] and previous_utxo and \ utxo.effective_amount == previous_utxo.effective_amount and \ utxo.fee == previous_utxo.fee: @@ -103,18 +121,20 @@ class CoinSelector: if best_selection: self.exact_match = True return [ - self.txos[i] for i, include in enumerate(best_selection) if include + txos[i] for i, include in enumerate(best_selection) if include ] return [] @strategy - def closest_match(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + def closest_match( + self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], _ + ) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: """ Pick one UTXOs that is larger than the target but with the smallest change. """ target = self.target + self.cost_of_change smallest_change = None best_match = None - for txo in self.txos: + for txo in txos: if txo.effective_amount >= target: change = txo.effective_amount - target if smallest_change is None or change < smallest_change: @@ -122,13 +142,15 @@ class CoinSelector: return [best_match] if best_match else [] @strategy - def random_draw(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: + def random_draw( + self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], _ + ) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: """ Accumulate UTXOs at random until there is enough to cover the target. """ target = self.target + self.cost_of_change - self.random.shuffle(self.txos, self.random.random) + self.random.shuffle(txos, self.random.random) selection = [] amount = 0 - for coin in self.txos: + for coin in txos: selection.append(coin) amount += coin.effective_amount if amount >= target: