added coin selection strategy for both prefer_confirmed and only_confirmed

This commit is contained in:
Lex Berezhny 2019-06-21 02:15:59 -04:00
parent 01cd02e4c5
commit d8265add2d
4 changed files with 102 additions and 82 deletions

View file

@ -193,19 +193,19 @@ class MaxKeyFee(Setting[dict]):
) )
class OneOfString(String): class StringChoice(String):
def __init__(self, valid_values: typing.List[str], *args, **kwargs): def __init__(self, doc: str, valid_values: typing.List[str], default: str, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(doc, default, *args, **kwargs)
if not valid_values:
raise ValueError("No valid values provided")
if default not in valid_values:
raise ValueError(f"Default value must be one of: {', '.join(valid_values)}")
self.valid_values = valid_values self.valid_values = valid_values
if not self.valid_values:
raise ValueError(f"No valid values provided")
if self.default not in self.valid_values:
raise ValueError(f"Default value must be one of: " + ', '.join(self.valid_values))
def validate(self, val): def validate(self, val):
super().validate(val) super().validate(val)
if val not in self.valid_values: if val not in self.valid_values:
raise ValueError(f"Setting '{self.name}' must be one of: " + ', '.join(self.valid_values)) raise ValueError(f"Setting '{self.name}' value must be one of: {', '.join(self.valid_values)}")
class ListSetting(Setting[list]): class ListSetting(Setting[list]):
@ -578,8 +578,9 @@ class Config(CLIConfig):
streaming_get = Toggle("Enable the /get endpoint for the streaming media server. " streaming_get = Toggle("Enable the /get endpoint for the streaming media server. "
"Disable to prevent new streams from being added.", True) "Disable to prevent new streams from being added.", True)
coin_selection_strategy = OneOfString(STRATEGIES, "Strategy to use when selecting UTXOs for a transaction", coin_selection_strategy = StringChoice(
"branch_and_bound") "Strategy to use when selecting UTXOs for a transaction",
STRATEGIES, "standard")
@property @property
def streaming_host(self): def streaming_host(self):

View file

@ -4,7 +4,7 @@ import types
import tempfile import tempfile
import unittest import unittest
import argparse import argparse
from lbry.conf import Config, BaseConfig, String, Integer, Toggle, Servers, Strings, OneOfString, NOT_SET from lbry.conf import Config, BaseConfig, String, Integer, Toggle, Servers, Strings, StringChoice, NOT_SET
from lbry.error import InvalidCurrencyError from lbry.error import InvalidCurrencyError
@ -15,7 +15,7 @@ class TestConfig(BaseConfig):
test_true_toggle = Toggle('toggle help', True) test_true_toggle = Toggle('toggle help', True)
servers = Servers('servers help', [('localhost', 80)]) servers = Servers('servers help', [('localhost', 80)])
strings = Strings('cheese', ['string']) strings = Strings('cheese', ['string'])
one_of_string = OneOfString(["a", "b", "c"], "one of string", "a") string_choice = StringChoice("one of string", ["a", "b", "c"], "a")
class ConfigurationTests(unittest.TestCase): class ConfigurationTests(unittest.TestCase):
@ -227,31 +227,21 @@ class ConfigurationTests(unittest.TestCase):
c = Config.create_from_arguments(args) c = Config.create_from_arguments(args)
self.assertEqual(c.max_key_fee, {'amount': 1.0, 'currency': 'BTC'}) self.assertEqual(c.max_key_fee, {'amount': 1.0, 'currency': 'BTC'})
def test_one_of_string(self): def test_string_choice(self):
with self.assertRaises(ValueError): with self.assertRaisesRegex(ValueError, "No valid values provided"):
no_vaid_values = OneOfString([], "no valid values", None) StringChoice("no valid values", [], "")
with self.assertRaisesRegex(ValueError, "Default value must be one of"):
with self.assertRaises(ValueError): StringChoice("invalid default", ["a"], "b")
default_none = OneOfString(["a"], "invalid default", None)
with self.assertRaises(ValueError):
invalid_default = OneOfString(["a"], "invalid default", "b")
valid_default = OneOfString(["a"], "valid default", "a")
self.assertEqual("hello", OneOfString(["hello"], "valid default", "hello").default)
c = TestConfig() c = TestConfig()
with self.assertRaises(ValueError): self.assertEqual("a", c.string_choice) # default
c.one_of_string = "d" c.string_choice = "b"
self.assertEqual("b", c.string_choice)
with self.assertRaisesRegex(ValueError, "Setting 'string_choice' value must be one of"):
c.string_choice = "d"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
TestConfig.contribute_to_argparse(parser) TestConfig.contribute_to_argparse(parser)
args = parser.parse_args(['--string-choice', 'c'])
args = parser.parse_args(["--one-of-string=b"])
c = TestConfig.create_from_arguments(args) c = TestConfig.create_from_arguments(args)
self.assertEqual("b", c.one_of_string) self.assertEqual("c", c.string_choice)
# with self.assertRaises(ValueError):
# args = parser.parse_args(["--one-of-string=arst"])
# c = TestConfig.create_from_arguments(args)
# print("here")

View file

@ -13,7 +13,7 @@ NULL_HASH = b'\x00'*32
def search(*args, **kwargs): def search(*args, **kwargs):
selection = CoinSelector(*args, **kwargs).branch_and_bound() selection = CoinSelector(*args[1:], **kwargs).select(args[0], 'branch_and_bound')
return [o.txo.amount for o in selection] if selection else selection return [o.txo.amount for o in selection] if selection else selection
@ -37,17 +37,17 @@ class BaseSelectionTestCase(AsyncioTestCase):
class TestCoinSelectionTests(BaseSelectionTestCase): class TestCoinSelectionTests(BaseSelectionTestCase):
def test_empty_coins(self): def test_empty_coins(self):
self.assertEqual(CoinSelector([], 0, 0).select(), []) self.assertEqual(CoinSelector(0, 0).select([]), [])
def test_skip_binary_search_if_total_not_enough(self): def test_skip_binary_search_if_total_not_enough(self):
fee = utxo(CENT).get_estimator(self.ledger).fee fee = utxo(CENT).get_estimator(self.ledger).fee
big_pool = self.estimates(utxo(CENT+fee) for _ in range(100)) big_pool = self.estimates(utxo(CENT+fee) for _ in range(100))
selector = CoinSelector(big_pool, 101 * CENT, 0) selector = CoinSelector(101 * CENT, 0)
self.assertEqual(selector.select(), []) self.assertEqual(selector.select(big_pool), [])
self.assertEqual(selector.tries, 0) # Never tried. self.assertEqual(selector.tries, 0) # Never tried.
# check happy path # check happy path
selector = CoinSelector(big_pool, 100 * CENT, 0) selector = CoinSelector(100 * CENT, 0)
self.assertEqual(len(selector.select()), 100) self.assertEqual(len(selector.select(big_pool)), 100)
self.assertEqual(selector.tries, 201) self.assertEqual(selector.tries, 201)
def test_exact_match(self): def test_exact_match(self):
@ -57,8 +57,8 @@ class TestCoinSelectionTests(BaseSelectionTestCase):
utxo(CENT), utxo(CENT),
utxo(CENT - fee) utxo(CENT - fee)
) )
selector = CoinSelector(utxo_pool, CENT, 0) selector = CoinSelector(CENT, 0)
match = selector.select() match = selector.select(utxo_pool)
self.assertEqual([CENT + fee], [c.txo.amount for c in match]) self.assertEqual([CENT + fee], [c.txo.amount for c in match])
self.assertTrue(selector.exact_match) self.assertTrue(selector.exact_match)
@ -68,8 +68,8 @@ class TestCoinSelectionTests(BaseSelectionTestCase):
utxo(3 * CENT), utxo(3 * CENT),
utxo(4 * CENT) utxo(4 * CENT)
) )
selector = CoinSelector(utxo_pool, CENT, 0, '\x00') selector = CoinSelector(CENT, 0, '\x00')
match = selector.select() match = selector.select(utxo_pool)
self.assertEqual([2 * CENT], [c.txo.amount for c in match]) self.assertEqual([2 * CENT], [c.txo.amount for c in match])
self.assertFalse(selector.exact_match) self.assertFalse(selector.exact_match)
@ -81,20 +81,27 @@ class TestCoinSelectionTests(BaseSelectionTestCase):
utxo(5*CENT), utxo(5*CENT),
utxo(10*CENT), utxo(10*CENT),
) )
selector = CoinSelector(utxo_pool, 3*CENT, 0) selector = CoinSelector(3*CENT, 0)
match = selector.select() match = selector.select(utxo_pool)
self.assertEqual([5*CENT], [c.txo.amount for c in match]) self.assertEqual([5*CENT], [c.txo.amount for c in match])
def test_prefer_confirmed_strategy(self): def test_confirmed_strategies(self):
utxo_pool = self.estimates( utxo_pool = self.estimates(
utxo(11*CENT, height=5), utxo(11*CENT, height=5),
utxo(11*CENT, height=0), utxo(11*CENT, height=0),
utxo(11*CENT, height=-2), utxo(11*CENT, height=-2),
utxo(11*CENT, height=5), utxo(11*CENT, height=5),
) )
selector = CoinSelector(utxo_pool, 20*CENT, 0)
match = selector.select("prefer_confirmed") match = CoinSelector(20*CENT, 0).select(utxo_pool, "only_confirmed")
self.assertEqual([5, 5], [c.txo.tx_ref.height for c in match]) self.assertEqual([5, 5], [c.txo.tx_ref.height for c in match])
match = CoinSelector(25*CENT, 0).select(utxo_pool, "only_confirmed")
self.assertEqual([], [c.txo.tx_ref.height for c in match])
match = CoinSelector(20*CENT, 0).select(utxo_pool, "prefer_confirmed")
self.assertEqual([5, 5], [c.txo.tx_ref.height for c in match])
match = CoinSelector(25*CENT, 0, '\x00').select(utxo_pool, "prefer_confirmed")
self.assertEqual([5, 0, -2], [c.txo.tx_ref.height for c in match])
class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase): class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
@ -160,8 +167,8 @@ class TestOfficialBitcoinCoinSelectionTests(BaseSelectionTestCase):
# Iteration exhaustion test # Iteration exhaustion test
utxo_pool, target = self.make_hard_case(17) utxo_pool, target = self.make_hard_case(17)
selector = CoinSelector(utxo_pool, target, 0) selector = CoinSelector(target, 0)
self.assertEqual(selector.branch_and_bound(), []) self.assertEqual(selector.select(utxo_pool, 'branch_and_bound'), [])
self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust
utxo_pool, target = self.make_hard_case(14) utxo_pool, target = self.make_hard_case(14)
self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust

View file

@ -7,6 +7,7 @@ MAXIMUM_TRIES = 100000
STRATEGIES = [] STRATEGIES = []
def strategy(method): def strategy(method):
STRATEGIES.append(method.__name__) STRATEGIES.append(method.__name__)
return method return method
@ -14,50 +15,67 @@ def strategy(method):
class CoinSelector: class CoinSelector:
def __init__(self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], def __init__(self, target: int, cost_of_change: int, seed: str = None) -> None:
target: int, cost_of_change: int, seed: str = None) -> None:
self.txos = txos
self.target = target self.target = target
self.cost_of_change = cost_of_change self.cost_of_change = cost_of_change
self.exact_match = False self.exact_match = False
self.tries = 0 self.tries = 0
self.available = sum(c.effective_amount for c in self.txos)
self.random = Random(seed) self.random = Random(seed)
if seed is not None: if seed is not None:
self.random.seed(seed, version=1) self.random.seed(seed, version=1)
def select(self, strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: def select(
if not self.txos: self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator],
strategy_name: str = None) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
if not txos:
return [] return []
if self.target > self.available: available = sum(c.effective_amount for c in txos)
if self.target > available:
return [] return []
if strategy_name is not None: return getattr(self, strategy_name or "standard")(txos, available)
return getattr(self, strategy_name)()
@strategy
def prefer_confirmed(
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], available: int
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
return ( return (
self.branch_and_bound() or self.only_confirmed(txos, available) or
self.closest_match() or self.standard(txos, available)
self.random_draw()
) )
@strategy @strategy
def prefer_confirmed(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: def only_confirmed(
self.txos = [t for t in self.txos if t.txo.tx_ref and t.txo.tx_ref.height > 0] or self.txos self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], _
self.available = sum(c.effective_amount for c in self.txos) ) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
confirmed = [t for t in txos if t.txo.tx_ref and t.txo.tx_ref.height > 0]
if not confirmed:
return []
confirmed_available = sum(c.effective_amount for c in confirmed)
if self.target > confirmed_available:
return []
return self.standard(confirmed, confirmed_available)
@strategy
def standard(
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], available: int
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
return ( return (
self.branch_and_bound() or self.branch_and_bound(txos, available) or
self.closest_match() or self.closest_match(txos, available) or
self.random_draw() self.random_draw(txos, available)
) )
@strategy @strategy
def branch_and_bound(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: def branch_and_bound(
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], available: int
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
# see bitcoin implementation for more info: # see bitcoin implementation for more info:
# https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp # https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp
self.txos.sort(reverse=True) txos.sort(reverse=True)
current_value = 0 current_value = 0
current_available_value = self.available current_available_value = available
current_selection: List[bool] = [] current_selection: List[bool] = []
best_waste = self.cost_of_change best_waste = self.cost_of_change
best_selection: List[bool] = [] best_selection: List[bool] = []
@ -79,19 +97,19 @@ class CoinSelector:
if backtrack: if backtrack:
while current_selection and not current_selection[-1]: while current_selection and not current_selection[-1]:
current_selection.pop() current_selection.pop()
current_available_value += self.txos[len(current_selection)].effective_amount current_available_value += txos[len(current_selection)].effective_amount
if not current_selection: if not current_selection:
break break
current_selection[-1] = False current_selection[-1] = False
utxo = self.txos[len(current_selection) - 1] utxo = txos[len(current_selection) - 1]
current_value -= utxo.effective_amount current_value -= utxo.effective_amount
else: else:
utxo = self.txos[len(current_selection)] utxo = txos[len(current_selection)]
current_available_value -= utxo.effective_amount current_available_value -= utxo.effective_amount
previous_utxo = self.txos[len(current_selection) - 1] if current_selection else None previous_utxo = txos[len(current_selection) - 1] if current_selection else None
if current_selection and not current_selection[-1] and previous_utxo and \ if current_selection and not current_selection[-1] and previous_utxo and \
utxo.effective_amount == previous_utxo.effective_amount and \ utxo.effective_amount == previous_utxo.effective_amount and \
utxo.fee == previous_utxo.fee: utxo.fee == previous_utxo.fee:
@ -103,18 +121,20 @@ class CoinSelector:
if best_selection: if best_selection:
self.exact_match = True self.exact_match = True
return [ return [
self.txos[i] for i, include in enumerate(best_selection) if include txos[i] for i, include in enumerate(best_selection) if include
] ]
return [] return []
@strategy @strategy
def closest_match(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: def closest_match(
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], _
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
""" Pick one UTXOs that is larger than the target but with the smallest change. """ """ Pick one UTXOs that is larger than the target but with the smallest change. """
target = self.target + self.cost_of_change target = self.target + self.cost_of_change
smallest_change = None smallest_change = None
best_match = None best_match = None
for txo in self.txos: for txo in txos:
if txo.effective_amount >= target: if txo.effective_amount >= target:
change = txo.effective_amount - target change = txo.effective_amount - target
if smallest_change is None or change < smallest_change: if smallest_change is None or change < smallest_change:
@ -122,13 +142,15 @@ class CoinSelector:
return [best_match] if best_match else [] return [best_match] if best_match else []
@strategy @strategy
def random_draw(self) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]: def random_draw(
self, txos: List[basetransaction.BaseOutputEffectiveAmountEstimator], _
) -> List[basetransaction.BaseOutputEffectiveAmountEstimator]:
""" Accumulate UTXOs at random until there is enough to cover the target. """ """ Accumulate UTXOs at random until there is enough to cover the target. """
target = self.target + self.cost_of_change target = self.target + self.cost_of_change
self.random.shuffle(self.txos, self.random.random) self.random.shuffle(txos, self.random.random)
selection = [] selection = []
amount = 0 amount = 0
for coin in self.txos: for coin in txos:
selection.append(coin) selection.append(coin)
amount += coin.effective_amount amount += coin.effective_amount
if amount >= target: if amount >= target: