From dcd8a6bb0e4b7ef10cbe578ab780f80a76d1e8bd Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Thu, 19 Apr 2018 13:23:59 -0400 Subject: [PATCH] branch and bound based coin selection with random draw fallback --- .../tests/unit/wallet/test_coinselection.py | 151 +++++++++ lbrynet/wallet/coinchooser.py | 313 ------------------ lbrynet/wallet/coinselection.py | 93 ++++++ 3 files changed, 244 insertions(+), 313 deletions(-) create mode 100644 lbrynet/tests/unit/wallet/test_coinselection.py delete mode 100644 lbrynet/wallet/coinchooser.py create mode 100644 lbrynet/wallet/coinselection.py diff --git a/lbrynet/tests/unit/wallet/test_coinselection.py b/lbrynet/tests/unit/wallet/test_coinselection.py new file mode 100644 index 000000000..06d502b0a --- /dev/null +++ b/lbrynet/tests/unit/wallet/test_coinselection.py @@ -0,0 +1,151 @@ +import unittest + +from lbrynet.wallet.constants import CENT, MAXIMUM_FEE_PER_BYTE +from lbrynet.wallet.transaction import Transaction, Output +from lbrynet.wallet.coinselection import CoinSelector, MAXIMUM_TRIES +from lbrynet.wallet.manager import WalletManager +from lbrynet.wallet import set_wallet_manager + + +NULL_HASH = '\x00'*32 + + +def search(*args, **kwargs): + selection = CoinSelector(*args, **kwargs).branch_and_bound() + return [o.amount for o in selection] if selection else selection + + +def utxo(amount): + return Output.pay_pubkey_hash(Transaction(), 0, amount, NULL_HASH) + + +class TestCoinSelectionTests(unittest.TestCase): + + def setUp(self): + set_wallet_manager(WalletManager({'fee_per_byte': MAXIMUM_FEE_PER_BYTE})) + + def test_empty_coins(self): + self.assertIsNone(CoinSelector([], 0, 0).select()) + + def test_skip_binary_search_if_total_not_enough(self): + fee = utxo(CENT).spend(fake=True).fee + big_pool = [utxo(CENT+fee) for _ in range(100)] + selector = CoinSelector(big_pool, 101 * CENT, 0) + self.assertIsNone(selector.select()) + self.assertEqual(selector.tries, 0) # Never tried. + # check happy path + selector = CoinSelector(big_pool, 100 * CENT, 0) + self.assertEqual(len(selector.select()), 100) + self.assertEqual(selector.tries, 201) + + def test_exact_match(self): + fee = utxo(CENT).spend(fake=True).fee + utxo_pool = [ + utxo(CENT + fee), + utxo(CENT), + utxo(CENT - fee), + ] + selector = CoinSelector(utxo_pool, CENT, 0) + match = selector.select() + self.assertEqual([CENT + fee], [c.amount for c in match]) + self.assertTrue(selector.exact_match) + + def test_random_draw(self): + utxo_pool = [ + utxo(2 * CENT), + utxo(3 * CENT), + utxo(4 * CENT), + ] + selector = CoinSelector(utxo_pool, CENT, 0, 1) + match = selector.select() + self.assertEqual([2 * CENT], [c.amount for c in match]) + self.assertFalse(selector.exact_match) + + +class TestOfficialBitcoinCoinSelectionTests(unittest.TestCase): + + # Bitcoin implementation: + # https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp + # + # Bitcoin implementation tests: + # https://github.com/bitcoin/bitcoin/blob/master/src/wallet/test/coinselector_tests.cpp + # + # Branch and Bound coin selection white paper: + # https://murch.one/wp-content/uploads/2016/11/erhardt2016coinselection.pdf + + def setUp(self): + set_wallet_manager(WalletManager({'fee_per_byte': 0})) + + def make_hard_case(self, utxos): + target = 0 + utxo_pool = [] + for i in range(utxos): + amount = 1 << (utxos+i) + target += amount + utxo_pool.append(utxo(amount)) + utxo_pool.append(utxo(amount + (1 << (utxos-1-i)))) + return utxo_pool, target + + def test_branch_and_bound_coin_selection(self): + utxo_pool = [ + utxo(1 * CENT), + utxo(2 * CENT), + utxo(3 * CENT), + utxo(4 * CENT) + ] + + # Select 1 Cent + self.assertEqual([1 * CENT], search(utxo_pool, 1 * CENT, 0.5 * CENT)) + + # Select 2 Cent + self.assertEqual([2 * CENT], search(utxo_pool, 2 * CENT, 0.5 * CENT)) + + # Select 5 Cent + self.assertEqual([3 * CENT, 2 * CENT], search(utxo_pool, 5 * CENT, 0.5 * CENT)) + + # Select 11 Cent, not possible + self.assertIsNone(search(utxo_pool, 11 * CENT, 0.5 * CENT)) + + # Select 10 Cent + utxo_pool += [utxo(5 * CENT)] + self.assertEqual( + [4 * CENT, 3 * CENT, 2 * CENT, 1 * CENT], + search(utxo_pool, 10 * CENT, 0.5 * CENT) + ) + + # Negative effective value + # Select 10 Cent but have 1 Cent not be possible because too small + # TODO: bitcoin has [5, 3, 2] + self.assertEqual( + [4 * CENT, 3 * CENT, 2 * CENT, 1 * CENT], + search(utxo_pool, 10 * CENT, 5000) + ) + + # Select 0.25 Cent, not possible + self.assertIsNone(search(utxo_pool, 0.25 * CENT, 0.5 * CENT)) + + # Iteration exhaustion test + utxo_pool, target = self.make_hard_case(17) + selector = CoinSelector(utxo_pool, target, 0) + self.assertIsNone(selector.branch_and_bound()) + self.assertEqual(selector.tries, MAXIMUM_TRIES) # Should exhaust + utxo_pool, target = self.make_hard_case(14) + self.assertIsNotNone(search(utxo_pool, target, 0)) # Should not exhaust + + # Test same value early bailout optimization + utxo_pool = [ + utxo(7 * CENT), + utxo(7 * CENT), + utxo(7 * CENT), + utxo(7 * CENT), + utxo(2 * CENT) + ] + [utxo(5 * CENT)]*50000 + self.assertEqual( + [7 * CENT, 7 * CENT, 7 * CENT, 7 * CENT, 2 * CENT], + search(utxo_pool, 30 * CENT, 5000) + ) + + # Select 1 Cent with pool of only greater than 5 Cent + utxo_pool = [utxo(i * CENT) for i in range(5, 21)] + for _ in range(100): + self.assertIsNone(search(utxo_pool, 1 * CENT, 2 * CENT)) diff --git a/lbrynet/wallet/coinchooser.py b/lbrynet/wallet/coinchooser.py deleted file mode 100644 index 72c725fdf..000000000 --- a/lbrynet/wallet/coinchooser.py +++ /dev/null @@ -1,313 +0,0 @@ -import struct -import logging -from collections import defaultdict, namedtuple -from math import floor, log10 - -from .hashing import sha256 -from .constants import COIN, TYPE_ADDRESS -from .transaction import Transaction -from .errors import NotEnoughFunds - -log = logging.getLogger() - - -class PRNG(object): - """ - A simple deterministic PRNG. Used to deterministically shuffle a - set of coins - the same set of coins should produce the same output. - Although choosing UTXOs "randomly" we want it to be deterministic, - so if sending twice from the same UTXO set we choose the same UTXOs - to spend. This prevents attacks on users by malicious or stale - servers. - """ - - def __init__(self, seed): - self.sha = sha256(seed) - self.pool = bytearray() - - def get_bytes(self, n): - while len(self.pool) < n: - self.pool.extend(self.sha) - self.sha = sha256(self.sha) - result, self.pool = self.pool[:n], self.pool[n:] - return result - - def random(self): - # Returns random double in [0, 1) - four = self.get_bytes(4) - return struct.unpack("I", four)[0] / 4294967296.0 - - def randint(self, start, end): - # Returns random integer in [start, end) - return start + int(self.random() * (end - start)) - - def choice(self, seq): - return seq[int(self.random() * len(seq))] - - def shuffle(self, x): - for i in reversed(xrange(1, len(x))): - # pick an element in x[:i+1] with which to exchange x[i] - j = int(self.random() * (i + 1)) - x[i], x[j] = x[j], x[i] - - -Bucket = namedtuple('Bucket', ['desc', 'size', 'value', 'coins']) - - -def strip_unneeded(bkts, sufficient_funds): - '''Remove buckets that are unnecessary in achieving the spend amount''' - bkts = sorted(bkts, key=lambda bkt: bkt.value) - for i in range(len(bkts)): - if not sufficient_funds(bkts[i + 1:]): - return bkts[i:] - # Shouldn't get here - return bkts - - -class CoinChooserBase: - def keys(self, coins): - raise NotImplementedError - - def bucketize_coins(self, coins): - keys = self.keys(coins) - buckets = defaultdict(list) - for key, coin in zip(keys, coins): - buckets[key].append(coin) - - def make_Bucket(desc, coins): - size = sum(Transaction.estimated_input_size(coin) - for coin in coins) - value = sum(coin['value'] for coin in coins) - return Bucket(desc, size, value, coins) - - return map(make_Bucket, buckets.keys(), buckets.values()) - - def penalty_func(self, tx): - def penalty(candidate): - return 0 - - return penalty - - def change_amounts(self, tx, count, fee_estimator, dust_threshold): - # Break change up if bigger than max_change - output_amounts = [o[2] for o in tx.outputs()] - # Don't split change of less than 0.02 BTC - max_change = max(max(output_amounts) * 1.25, 0.02 * COIN) - - # Use N change outputs - for n in range(1, count + 1): - # How much is left if we add this many change outputs? - change_amount = max(0, tx.get_fee() - fee_estimator(n)) - if change_amount // n <= max_change: - break - - # Get a handle on the precision of the output amounts; round our - # change to look similar - def trailing_zeroes(val): - s = str(val) - return len(s) - len(s.rstrip('0')) - - zeroes = map(trailing_zeroes, output_amounts) - min_zeroes = min(zeroes) - max_zeroes = max(zeroes) - zeroes = range(max(0, min_zeroes - 1), (max_zeroes + 1) + 1) - - # Calculate change; randomize it a bit if using more than 1 output - remaining = change_amount - amounts = [] - while n > 1: - average = remaining // n - amount = self.p.randint(int(average * 0.7), int(average * 1.3)) - precision = min(self.p.choice(zeroes), int(floor(log10(amount)))) - amount = int(round(amount, -precision)) - amounts.append(amount) - remaining -= amount - n -= 1 - - # Last change output. Round down to maximum precision but lose - # no more than 100 satoshis to fees (2dp) - N = pow(10, min(2, zeroes[0])) - amount = (remaining // N) * N - amounts.append(amount) - - assert sum(amounts) <= change_amount - - return amounts - - def change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold): - amounts = self.change_amounts(tx, len(change_addrs), fee_estimator, - dust_threshold) - assert min(amounts) >= 0 - assert len(change_addrs) >= len(amounts) - # If change is above dust threshold after accounting for the - # size of the change output, add it to the transaction. - dust = sum(amount for amount in amounts if amount < dust_threshold) - amounts = [amount for amount in amounts if amount >= dust_threshold] - change = [(TYPE_ADDRESS, addr, amount) - for addr, amount in zip(change_addrs, amounts)] - log.debug('change: %s', change) - if dust: - log.debug('not keeping dust %s', dust) - return change - - def make_tx(self, coins, outputs, change_addrs, fee_estimator, - dust_threshold, abandon_txid=None): - '''Select unspent coins to spend to pay outputs. If the change is - greater than dust_threshold (after adding the change output to - the transaction) it is kept, otherwise none is sent and it is - added to the transaction fee.''' - - # Deterministic randomness from coins - utxos = [c['prevout_hash'] + str(c['prevout_n']) for c in coins] - self.p = PRNG(''.join(sorted(utxos))) - - # Copy the ouputs so when adding change we don't modify "outputs" - tx = Transaction.from_io([], outputs[:]) - # Size of the transaction with no inputs and no change - base_size = tx.estimated_size() - spent_amount = tx.output_value() - - claim_coin = None - if abandon_txid is not None: - claim_coins = [coin for coin in coins if coin['is_claim']] - assert len(claim_coins) >= 1 - claim_coin = claim_coins[0] - spent_amount -= claim_coin['value'] - coins = [coin for coin in coins if not coin['is_claim']] - - def sufficient_funds(buckets): - '''Given a list of buckets, return True if it has enough - value to pay for the transaction''' - total_input = sum(bucket.value for bucket in buckets) - total_size = sum(bucket.size for bucket in buckets) + base_size - return total_input >= spent_amount + fee_estimator(total_size) - - # Collect the coins into buckets, choose a subset of the buckets - buckets = self.bucketize_coins(coins) - buckets = self.choose_buckets(buckets, sufficient_funds, - self.penalty_func(tx)) - - if claim_coin is not None: - tx.add_inputs([claim_coin]) - tx.add_inputs([coin for b in buckets for coin in b.coins]) - tx_size = base_size + sum(bucket.size for bucket in buckets) - - # This takes a count of change outputs and returns a tx fee; - # each pay-to-bitcoin-address output serializes as 34 bytes - def fee(count): - return fee_estimator(tx_size + count * 34) - - change = self.change_outputs(tx, change_addrs, fee, dust_threshold) - tx.add_outputs(change) - - log.debug("using %i inputs", len(tx.inputs())) - log.info("using buckets: %s", [bucket.desc for bucket in buckets]) - - return tx - - -class CoinChooserOldestFirst(CoinChooserBase): - '''Maximize transaction priority. Select the oldest unspent - transaction outputs in your wallet, that are sufficient to cover - the spent amount. Then, remove any unneeded inputs, starting with - the smallest in value. - ''' - - def keys(self, coins): - return [coin['prevout_hash'] + ':' + str(coin['prevout_n']) - for coin in coins] - - def choose_buckets(self, buckets, sufficient_funds, penalty_func): - '''Spend the oldest buckets first.''' - # Unconfirmed coins are young, not old - def adj_height(height): - return 99999999 if height == 0 else height - - buckets.sort(key=lambda b: max(adj_height(coin['height']) - for coin in b.coins)) - selected = [] - for bucket in buckets: - selected.append(bucket) - if sufficient_funds(selected): - return strip_unneeded(selected, sufficient_funds) - raise NotEnoughFunds() - - -class CoinChooserRandom(CoinChooserBase): - def keys(self, coins): - return [coin['prevout_hash'] + ':' + str(coin['prevout_n']) - for coin in coins] - - def bucket_candidates(self, buckets, sufficient_funds): - '''Returns a list of bucket sets.''' - candidates = set() - - # Add all singletons - for n, bucket in enumerate(buckets): - if sufficient_funds([bucket]): - candidates.add((n,)) - - # And now some random ones - attempts = min(100, (len(buckets) - 1) * 10 + 1) - permutation = range(len(buckets)) - for i in range(attempts): - # Get a random permutation of the buckets, and - # incrementally combine buckets until sufficient - self.p.shuffle(permutation) - bkts = [] - for count, index in enumerate(permutation): - bkts.append(buckets[index]) - if sufficient_funds(bkts): - candidates.add(tuple(sorted(permutation[:count + 1]))) - break - else: - raise NotEnoughFunds() - - candidates = [[buckets[n] for n in c] for c in candidates] - return [strip_unneeded(c, sufficient_funds) for c in candidates] - - def choose_buckets(self, buckets, sufficient_funds, penalty_func): - candidates = self.bucket_candidates(buckets, sufficient_funds) - penalties = [penalty_func(cand) for cand in candidates] - winner = candidates[penalties.index(min(penalties))] - log.debug("Bucket sets: %i", len(buckets)) - log.debug("Winning penalty: %s", min(penalties)) - return winner - - -class CoinChooserPrivacy(CoinChooserRandom): - '''Attempts to better preserve user privacy. First, if any coin is - spent from a user address, all coins are. Compared to spending - from other addresses to make up an amount, this reduces - information leakage about sender holdings. It also helps to - reduce blockchain UTXO bloat, and reduce future privacy loss that - would come from reusing that address' remaining UTXOs. Second, it - penalizes change that is quite different to the sent amount. - Third, it penalizes change that is too big.''' - - def keys(self, coins): - return [coin['address'] for coin in coins] - - def penalty_func(self, tx): - min_change = min(o[2] for o in tx.outputs()) * 0.75 - max_change = max(o[2] for o in tx.outputs()) * 1.33 - spent_amount = sum(o[2] for o in tx.outputs()) - - def penalty(buckets): - badness = len(buckets) - 1 - total_input = sum(bucket.value for bucket in buckets) - change = float(total_input - spent_amount) - # Penalize change not roughly in output range - if change < min_change: - badness += (min_change - change) / (min_change + 10000) - elif change > max_change: - badness += (change - max_change) / (max_change + 10000) - # Penalize large change; 5 BTC excess ~= using 1 more input - badness += change / (COIN * 5) - return badness - - return penalty - - -COIN_CHOOSERS = {'Priority': CoinChooserOldestFirst, - 'Privacy': CoinChooserPrivacy} diff --git a/lbrynet/wallet/coinselection.py b/lbrynet/wallet/coinselection.py new file mode 100644 index 000000000..3e9080731 --- /dev/null +++ b/lbrynet/wallet/coinselection.py @@ -0,0 +1,93 @@ +from __future__ import print_function +from random import Random + +MAXIMUM_TRIES = 100000 + + +class CoinSelector: + + def __init__(self, coins, target, cost_of_change, seed=None, debug=False): + self.coins = coins + self.target = target + self.cost_of_change = cost_of_change + self.exact_match = False + self.tries = 0 + self.available = sum(c.effective_amount for c in self.coins) + self.debug = debug + self.random = Random(seed) + debug and print(target) + debug and print([c.effective_amount for c in self.coins]) + + def select(self): + if self.target > self.available: + return + if not self.coins: + return + return self.branch_and_bound() or self.single_random_draw() + + def single_random_draw(self): + self.random.shuffle(self.coins) + selection = [] + amount = 0 + for coin in self.coins: + selection.append(coin) + amount += coin.effective_amount + if amount >= self.target+self.cost_of_change: + return selection + + def branch_and_bound(self): + # see bitcoin implementation for more info: + # https://github.com/bitcoin/bitcoin/blob/master/src/wallet/coinselection.cpp + + self.coins.sort(reverse=True) + + current_value = 0 + current_available_value = self.available + current_selection = [] + best_waste = self.cost_of_change + best_selection = [] + + while self.tries < MAXIMUM_TRIES: + self.tries += 1 + + backtrack = False + if current_value + current_available_value < self.target or \ + current_value > self.target + self.cost_of_change: + backtrack = True + elif current_value >= self.target: + new_waste = current_value - self.target + if new_waste <= best_waste: + best_waste = new_waste + best_selection = current_selection[:] + backtrack = True + + if backtrack: + while current_selection and not current_selection[-1]: + current_selection.pop() + current_available_value += self.coins[len(current_selection)].effective_amount + + if not current_selection: + break + + current_selection[-1] = False + utxo = self.coins[len(current_selection)-1] + current_value -= utxo.effective_amount + + else: + utxo = self.coins[len(current_selection)] + current_available_value -= utxo.effective_amount + previous_utxo = self.coins[len(current_selection)-1] if current_selection else None + if current_selection and not current_selection[-1] and \ + utxo.effective_amount == previous_utxo.effective_amount and \ + utxo.fee == previous_utxo.fee: + current_selection.append(False) + else: + current_selection.append(True) + current_value += utxo.effective_amount + self.debug and print(current_selection) + + if best_selection: + self.exact_match = True + return [ + self.coins[i] for i, include in enumerate(best_selection) if include + ]