added coin selection strategy for both prefer_confirmed and only_confirmed
This commit is contained in:
parent
01cd02e4c5
commit
d8265add2d
4 changed files with 102 additions and 82 deletions
|
@ -193,19 +193,19 @@ class MaxKeyFee(Setting[dict]):
|
|||
)
|
||||
|
||||
|
||||
class OneOfString(String):
|
||||
def __init__(self, valid_values: typing.List[str], *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
class StringChoice(String):
|
||||
def __init__(self, doc: str, valid_values: typing.List[str], default: str, *args, **kwargs):
|
||||
super().__init__(doc, default, *args, **kwargs)
|
||||
if not valid_values:
|
||||
raise ValueError("No valid values provided")
|
||||
if default not in valid_values:
|
||||
raise ValueError(f"Default value must be one of: {', '.join(valid_values)}")
|
||||
self.valid_values = valid_values
|
||||
if not self.valid_values:
|
||||
raise ValueError(f"No valid values provided")
|
||||
if self.default not in self.valid_values:
|
||||
raise ValueError(f"Default value must be one of: " + ', '.join(self.valid_values))
|
||||
|
||||
def validate(self, val):
|
||||
super().validate(val)
|
||||
if val not in self.valid_values:
|
||||
raise ValueError(f"Setting '{self.name}' must be one of: " + ', '.join(self.valid_values))
|
||||
raise ValueError(f"Setting '{self.name}' value must be one of: {', '.join(self.valid_values)}")
|
||||
|
||||
|
||||
class ListSetting(Setting[list]):
|
||||
|
@ -578,8 +578,9 @@ class Config(CLIConfig):
|
|||
streaming_get = Toggle("Enable the /get endpoint for the streaming media server. "
|
||||
"Disable to prevent new streams from being added.", True)
|
||||
|
||||
coin_selection_strategy = OneOfString(STRATEGIES, "Strategy to use when selecting UTXOs for a transaction",
|
||||
"branch_and_bound")
|
||||
coin_selection_strategy = StringChoice(
|
||||
"Strategy to use when selecting UTXOs for a transaction",
|
||||
STRATEGIES, "standard")
|
||||
|
||||
@property
|
||||
def streaming_host(self):
|
||||
|
|
|
@ -4,7 +4,7 @@ import types
|
|||
import tempfile
|
||||
import unittest
|
||||
import argparse
|
||||
from lbry.conf import Config, BaseConfig, String, Integer, Toggle, Servers, Strings, OneOfString, NOT_SET
|
||||
from lbry.conf import Config, BaseConfig, String, Integer, Toggle, Servers, Strings, StringChoice, NOT_SET
|
||||
from lbry.error import InvalidCurrencyError
|
||||
|
||||
|
||||
|
@ -15,7 +15,7 @@ class TestConfig(BaseConfig):
|
|||
test_true_toggle = Toggle('toggle help', True)
|
||||
servers = Servers('servers help', [('localhost', 80)])
|
||||
strings = Strings('cheese', ['string'])
|
||||
one_of_string = OneOfString(["a", "b", "c"], "one of string", "a")
|
||||
string_choice = StringChoice("one of string", ["a", "b", "c"], "a")
|
||||
|
||||
|
||||
class ConfigurationTests(unittest.TestCase):
|
||||
|
@ -227,31 +227,21 @@ class ConfigurationTests(unittest.TestCase):
|
|||
c = Config.create_from_arguments(args)
|
||||
self.assertEqual(c.max_key_fee, {'amount': 1.0, 'currency': 'BTC'})
|
||||
|
||||
def test_one_of_string(self):
|
||||
with self.assertRaises(ValueError):
|
||||
no_vaid_values = OneOfString([], "no valid values", None)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
default_none = OneOfString(["a"], "invalid default", None)
|
||||
with self.assertRaises(ValueError):
|
||||
invalid_default = OneOfString(["a"], "invalid default", "b")
|
||||
|
||||
valid_default = OneOfString(["a"], "valid default", "a")
|
||||
|
||||
self.assertEqual("hello", OneOfString(["hello"], "valid default", "hello").default)
|
||||
def test_string_choice(self):
|
||||
with self.assertRaisesRegex(ValueError, "No valid values provided"):
|
||||
StringChoice("no valid values", [], "")
|
||||
with self.assertRaisesRegex(ValueError, "Default value must be one of"):
|
||||
StringChoice("invalid default", ["a"], "b")
|
||||
|
||||
c = TestConfig()
|
||||
with self.assertRaises(ValueError):
|
||||
c.one_of_string = "d"
|
||||
self.assertEqual("a", c.string_choice) # default
|
||||
c.string_choice = "b"
|
||||
self.assertEqual("b", c.string_choice)
|
||||
with self.assertRaisesRegex(ValueError, "Setting 'string_choice' value must be one of"):
|
||||
c.string_choice = "d"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
TestConfig.contribute_to_argparse(parser)
|
||||
|
||||
args = parser.parse_args(["--one-of-string=b"])
|
||||
args = parser.parse_args(['--string-choice', 'c'])
|
||||
c = TestConfig.create_from_arguments(args)
|
||||
self.assertEqual("b", c.one_of_string)
|
||||
|
||||
# with self.assertRaises(ValueError):
|
||||
# args = parser.parse_args(["--one-of-string=arst"])
|
||||
# c = TestConfig.create_from_arguments(args)
|
||||
# print("here")
|
||||
self.assertEqual("c", c.string_choice)
|
||||
|
|
|
@ -13,7 +13,7 @@ NULL_HASH = b'\x00'*32
|
|||
|
||||
|
||||
def search(*args, **kwargs):
|
||||
selection = CoinSelector(*args, **kwargs).branch_and_bound()
|
||||
selection = CoinSelector(*args[1:], **kwargs).select(args[0], 'branch_and_bound')
|
||||
return [o.txo.amount for o in selection] if selection else selection
|
||||
|
||||
|
||||
|
@ -37,17 +37,17 @@ class BaseSelectionTestCase(AsyncioTestCase):
|
|||
class TestCoinSelectionTests(BaseSelectionTestCase):
|
||||
|
||||
def test_empty_coins(self):
|
||||
self.assertEqual(CoinSelector([], 0, 0).select(), [])
|
||||
self.assertEqual(CoinSelector(0, 0).select([]), [])
|
||||
|
||||
def test_skip_binary_search_if_total_not_enough(self):
|
||||
fee = utxo(CENT).get_estimator(self.ledger).fee
|
||||
big_pool = self.estimates(utxo(CENT+fee) for _ in range(100))
|
||||
selector = CoinSelector(big_pool, 101 * CENT, 0)
|
||||
self.assertEqual(selector.select(), [])
|
||||
selector = CoinSelector(101 * CENT, 0)
|
||||
self.assertEqual(selector.select(big_pool), [])
|
||||
self.assertEqual(selector.tries, 0) # Never tried.
|
||||
# check happy path
|
||||
selector = CoinSelector(big_pool, 100 * CENT, 0)
|
||||
self.assertEqual(len(selector.select()), 100)
|
||||
selector = CoinSelector(100 * CENT, 0)
|
||||
self.assertEqual(len(selector.select(big_pool)), 100)
|
||||
self.assertEqual(selector.tries, 201)
|
||||
|
||||
def test_exact_match(self):
|
||||
|
@ -57,8 +57,8 @@ class TestCoinSelectionTests(BaseSelectionTestCase):
|
|||
utxo(CENT),
|
||||
utxo(CENT - fee)
|
||||
)
|
||||
selector = CoinSelector(utxo_pool, CENT, 0)
|
||||
match = selector.select()
|
||||
selector = CoinSelector(CENT, 0)
|
||||
match = selector.select(utxo_pool)
|
||||
self.assertEqual([CENT + fee], [c.txo.amount for c in match])
|
||||
self.assertTrue(selector.exact_match)
|
||||
|
||||
|
@ -68,8 +68,8 @@ class TestCoinSelectionTests(BaseSelectionTestCase):
|
|||
utxo(3 * CENT),
|
||||
utxo(4 * CENT)
|
||||
)
|
||||
selector = CoinSelector(utxo_pool, CENT, 0, '\x00')
|
||||
match = selector.select()
|
||||
selector = CoinSelector(CENT, 0, '\x00')
|
||||
match = selector.select(utxo_pool)
|
||||
self.assertEqual([2 * CENT], [c.txo.amount for c in match])
|
||||
self.assertFalse(selector.exact_match)
|
||||
|
||||
|
@ -81,20 +81,27 @@ class TestCoinSelectionTests(BaseSelectionTestCase):
|
|||
utxo(5*CENT),
|
||||
utxo(10*CENT),
|
||||
)
|
||||
selector = CoinSelector(utxo_pool, 3*CENT, 0)
|
||||
match = selector.select()
|
||||
selector = CoinSelector(3*CENT, 0)
|
||||
match = selector.select(utxo_pool)
|
||||
self.assertEqual([5*CENT], [c.txo.amount for c in match])
|
||||
|
||||
def test_prefer_confirmed_strategy(self):
|
||||
def test_confirmed_strategies(self):
|
||||
utxo_pool = self.estimates(
|
||||
utxo(11*CENT, height=5),
|
||||
utxo(11*CENT, height=0),
|
||||
utxo(11*CENT, height=-2),
|
||||
utxo(11*CENT, height=5),
|
||||
)
|
||||
selector = CoinSelector(utxo_pool, 20*CENT, 0)
|
||||
match = selector.select("prefer_confirmed")
|
||||
|
||||
match = CoinSelector(20*CENT, 0).select(utxo_pool, "only_confirmed")
|
||||
self.assertEqual([5, 5], [c.txo.tx_ref.height for c in match])
|
||||
match = CoinSelector(25*CENT, 0).select(utxo_pool, "only_confirmed")
|
||||
self.assertEqual([], [c.txo.tx_ref.height for c in match])
|
||||
|
||||
match = CoinSelector(20*CENT, 0).select(utxo_pool, "prefer_confirmed")
|
||||
self.assertEqual([5, 5], [c.txo.tx_ref.height for c in match])
|
||||
match = CoinSelector(25*CENT, 0, '\x00').select(utxo_pool, "prefer_confirmed")
|
||||
self.assertEqual([5, 0, -2], [c.txo.tx_ref.height for c in match])
|
||||
|
||||
|
||||
class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
|
||||
|
@ -160,8 +167,8 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
|
|||
|
||||
# Iteration exhaustion test
|
||||
utxo_pool, target = self.make_hard_case(17)
|
||||
selector = CoinSelector(utxo_pool, target, 0)
|
||||
self.assertEqual(selector.branch_and_bound(), [])
|
||||
selector = CoinSelector(target, 0)
|
||||
self.assertEqual(selector.select(utxo_pool, 'branch_and_bound'), [])
|
||||
self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust
|
||||
utxo_pool, target = self.make_hard_case(14)
|
||||
self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust
|
||||
|
|
|
@ -7,6 +7,7 @@ MAXIMUM_TRIES = 100000
|
|||
|
||||
STRATEGIES = []
|
||||
|
||||
|
||||
def strategy(method):
|
||||
STRATEGIES.append(method.__name__)
|
||||
return method
|
||||
|
@ -14,50 +15,67 @@ def strategy(method):
|
|||
|
||||
class CoinSelector:
|
||||
|
||||
def __init__(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||
target: int, cost_of_change: int, seed: str = None) -> None:
|
||||
self.txos = txos
|
||||
def __init__(self, target: int, cost_of_change: int, seed: str = None) -> None:
|
||||
self.target = target
|
||||
self.cost_of_change = cost_of_change
|
||||
self.exact_match = False
|
||||
self.tries = 0
|
||||
self.available = sum(c.effective_amount for c in self.txos)
|
||||
self.random = Random(seed)
|
||||
if seed is not None:
|
||||
self.random.seed(seed, version=1)
|
||||
|
||||
def select(self, strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
if not self.txos:
|
||||
def select(
|
||||
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
|
||||
strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
if not txos:
|
||||
return []
|
||||
if self.target > self.available:
|
||||
available = sum(c.effective_amount for c in txos)
|
||||
if self.target > available:
|
||||
return []
|
||||
if strategy_name is not None:
|
||||
return getattr(self, strategy_name)()
|
||||
return getattr(self, strategy_name or "standard")(txos, available)
|
||||
|
||||
@strategy
|
||||
def prefer_confirmed(
|
||||
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], available: int
|
||||
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
return (
|
||||
self.branch_and_bound() or
|
||||
self.closest_match() or
|
||||
self.random_draw()
|
||||
self.only_confirmed(txos, available) or
|
||||
self.standard(txos, available)
|
||||
)
|
||||
|
||||
@strategy
|
||||
def prefer_confirmed(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
self.txos = [t for t in self.txos if t.txo.tx_ref and t.txo.tx_ref.height > 0] or self.txos
|
||||
self.available = sum(c.effective_amount for c in self.txos)
|
||||
def only_confirmed(
|
||||
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], _
|
||||
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
confirmed = [t for t in txos if t.txo.tx_ref and t.txo.tx_ref.height > 0]
|
||||
if not confirmed:
|
||||
return []
|
||||
confirmed_available = sum(c.effective_amount for c in confirmed)
|
||||
if self.target > confirmed_available:
|
||||
return []
|
||||
return self.standard(confirmed, confirmed_available)
|
||||
|
||||
@strategy
|
||||
def standard(
|
||||
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], available: int
|
||||
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
return (
|
||||
self.branch_and_bound() or
|
||||
self.closest_match() or
|
||||
self.random_draw()
|
||||
self.branch_and_bound(txos, available) or
|
||||
self.closest_match(txos, available) or
|
||||
self.random_draw(txos, available)
|
||||
)
|
||||
|
||||
@strategy
|
||||
def branch_and_bound(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
def branch_and_bound(
|
||||
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], available: int
|
||||
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
# see bitcoin implementation for more info:
|
||||
# https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp
|
||||
|
||||
self.txos.sort(reverse=True)
|
||||
txos.sort(reverse=True)
|
||||
|
||||
current_value = 0
|
||||
current_available_value = self.available
|
||||
current_available_value = available
|
||||
current_selection: List[bool] = []
|
||||
best_waste = self.cost_of_change
|
||||
best_selection: List[bool] = []
|
||||
|
@ -79,19 +97,19 @@ class CoinSelector:
|
|||
if backtrack:
|
||||
while current_selection and not current_selection[-1]:
|
||||
current_selection.pop()
|
||||
current_available_value += self.txos[len(current_selection)].effective_amount
|
||||
current_available_value += txos[len(current_selection)].effective_amount
|
||||
|
||||
if not current_selection:
|
||||
break
|
||||
|
||||
current_selection[-1] = False
|
||||
utxo = self.txos[len(current_selection) - 1]
|
||||
utxo = txos[len(current_selection) - 1]
|
||||
current_value -= utxo.effective_amount
|
||||
|
||||
else:
|
||||
utxo = self.txos[len(current_selection)]
|
||||
utxo = txos[len(current_selection)]
|
||||
current_available_value -= utxo.effective_amount
|
||||
previous_utxo = self.txos[len(current_selection) - 1] if current_selection else None
|
||||
previous_utxo = txos[len(current_selection) - 1] if current_selection else None
|
||||
if current_selection and not current_selection[-1] and previous_utxo and \
|
||||
utxo.effective_amount == previous_utxo.effective_amount and \
|
||||
utxo.fee == previous_utxo.fee:
|
||||
|
@ -103,18 +121,20 @@ class CoinSelector:
|
|||
if best_selection:
|
||||
self.exact_match = True
|
||||
return [
|
||||
self.txos[i] for i, include in enumerate(best_selection) if include
|
||||
txos[i] for i, include in enumerate(best_selection) if include
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
@strategy
|
||||
def closest_match(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
def closest_match(
|
||||
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], _
|
||||
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
""" Pick one UTXOs that is larger than the target but with the smallest change. """
|
||||
target = self.target + self.cost_of_change
|
||||
smallest_change = None
|
||||
best_match = None
|
||||
for txo in self.txos:
|
||||
for txo in txos:
|
||||
if txo.effective_amount >= target:
|
||||
change = txo.effective_amount - target
|
||||
if smallest_change is None or change < smallest_change:
|
||||
|
@ -122,13 +142,15 @@ class CoinSelector:
|
|||
return [best_match] if best_match else []
|
||||
|
||||
@strategy
|
||||
def random_draw(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
def random_draw(
|
||||
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], _
|
||||
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
|
||||
""" Accumulate UTXOs at random until there is enough to cover the target. """
|
||||
target = self.target + self.cost_of_change
|
||||
self.random.shuffle(self.txos, self.random.random)
|
||||
self.random.shuffle(txos, self.random.random)
|
||||
selection = []
|
||||
amount = 0
|
||||
for coin in self.txos:
|
||||
for coin in txos:
|
||||
selection.append(coin)
|
||||
amount += coin.effective_amount
|
||||
if amount >= target:
|
||||
|
|
Loading…
Reference in a new issue