Merge pull request #409 from lbryio/fix_wallet_race_condition

Fix wallet balance interfaces
This commit is contained in:
Umpei Kay Kurokawa 2017-01-26 15:12:54 -05:00 committed by GitHub
commit 893fe8823e
4 changed files with 95 additions and 23 deletions

View file

@ -19,7 +19,8 @@ from lbryum.commands import known_commands, Commands
from lbrynet.core.sqlite_helpers import rerun_if_locked from lbrynet.core.sqlite_helpers import rerun_if_locked
from lbrynet.interfaces import IRequestCreator, IQueryHandlerFactory, IQueryHandler, IWallet from lbrynet.interfaces import IRequestCreator, IQueryHandlerFactory, IQueryHandler, IWallet
from lbrynet.core.client.ClientRequest import ClientRequest from lbrynet.core.client.ClientRequest import ClientRequest
from lbrynet.core.Error import UnknownNameError, InvalidStreamInfoError, RequestCanceledError from lbrynet.core.Error import (UnknownNameError, InvalidStreamInfoError, RequestCanceledError,
InsufficientFundsError)
from lbrynet.db_migrator.migrate1to2 import UNSET_NOUT from lbrynet.db_migrator.migrate1to2 import UNSET_NOUT
from lbrynet.metadata.Metadata import Metadata from lbrynet.metadata.Metadata import Metadata
@ -285,19 +286,12 @@ class Wallet(object):
else: else:
d = defer.succeed(True) d = defer.succeed(True)
d.addCallback(lambda _: self.get_balance())
def set_wallet_balance(balance):
if self.wallet_balance != balance:
log.debug("Got a new balance: %s", str(balance))
self.wallet_balance = balance
def log_error(err): def log_error(err):
if isinstance(err, AttributeError): if isinstance(err, AttributeError):
log.warning("Failed to get an updated balance") log.warning("Failed to get an updated balance")
log.warning("Last balance update: %s", str(self.wallet_balance)) log.warning("Last balance update: %s", str(self.wallet_balance))
d.addCallbacks(set_wallet_balance, log_error) d.addCallbacks(lambda _: self.update_balance(), log_error)
return d return d
d.addCallback(lambda should_run: do_manage() if should_run else None) d.addCallback(lambda should_run: do_manage() if should_run else None)
@ -323,6 +317,15 @@ class Wallet(object):
d.addBoth(set_manage_not_running) d.addBoth(set_manage_not_running)
return d return d
@defer.inlineCallbacks
def update_balance(self):
""" obtain balance from lbryum wallet and set self.wallet_balance
"""
balance = yield self._update_balance()
if self.wallet_balance != balance:
log.debug("Got a new balance: %s", balance)
self.wallet_balance = balance
def get_info_exchanger(self): def get_info_exchanger(self):
return LBRYcrdAddressRequester(self) return LBRYcrdAddressRequester(self)
@ -341,7 +344,7 @@ class Wallet(object):
once the service has been rendered once the service has been rendered
""" """
rounded_amount = Decimal(str(round(amount, 8))) rounded_amount = Decimal(str(round(amount, 8)))
if self.wallet_balance >= self.total_reserved_points + rounded_amount: if self.get_balance() >= rounded_amount:
self.total_reserved_points += rounded_amount self.total_reserved_points += rounded_amount
return ReservedPoints(identifier, rounded_amount) return ReservedPoints(identifier, rounded_amount)
return None return None
@ -432,7 +435,6 @@ class Wallet(object):
log.debug("Should be sending %s points to %s", str(points), str(address)) log.debug("Should be sending %s points to %s", str(points), str(address))
payments_to_send[address] = points payments_to_send[address] = points
self.total_reserved_points -= points self.total_reserved_points -= points
self.wallet_balance -= points
else: else:
log.info("Skipping dust") log.info("Skipping dust")
@ -443,6 +445,7 @@ class Wallet(object):
d = self._do_send_many(payments_to_send) d = self._do_send_many(payments_to_send)
d.addCallback(lambda txid: log.debug("Sent transaction %s", txid)) d.addCallback(lambda txid: log.debug("Sent transaction %s", txid))
return d return d
log.debug("There were no payments to send") log.debug("There were no payments to send")
return defer.succeed(True) return defer.succeed(True)
@ -628,6 +631,7 @@ class Wallet(object):
""" """
def claim_name(self, name, bid, m): def claim_name(self, name, bid, m):
def _save_metadata(claim_out, metadata): def _save_metadata(claim_out, metadata):
if not claim_out['success']: if not claim_out['success']:
msg = 'Claim to name {} failed: {}'.format(name, claim_out['reason']) msg = 'Claim to name {} failed: {}'.format(name, claim_out['reason'])
@ -643,9 +647,13 @@ class Wallet(object):
def _claim_or_update(claim, metadata, _bid): def _claim_or_update(claim, metadata, _bid):
if not claim: if not claim:
log.debug("No own claim yet, making a new one") log.debug("No own claim yet, making a new one")
if self.get_balance() < _bid:
raise InsufficientFundsError()
return self._send_name_claim(name, metadata, _bid) return self._send_name_claim(name, metadata, _bid)
else: else:
log.debug("Updating over own claim") log.debug("Updating over own claim")
if self.get_balance() < _bid - claim['amount']:
raise InsufficientFundsError()
d = self.update_metadata(metadata, claim['value']) d = self.update_metadata(metadata, claim['value'])
claim_outpoint = ClaimOutpoint(claim['txid'], claim['nOut']) claim_outpoint = ClaimOutpoint(claim['txid'], claim['nOut'])
d.addCallback( d.addCallback(
@ -682,6 +690,9 @@ class Wallet(object):
claim_out = self._process_claim_out(claim_out) claim_out = self._process_claim_out(claim_out)
return defer.succeed(claim_out) return defer.succeed(claim_out)
if self.get_balance() < amount:
raise InsufficientFundsError()
d = self._support_claim(name, claim_id, amount) d = self._support_claim(name, claim_id, amount)
d.addCallback(lambda claim_out: _parse_support_claim_out(claim_out)) d.addCallback(lambda claim_out: _parse_support_claim_out(claim_out))
return d return d
@ -718,8 +729,8 @@ class Wallet(object):
d.addCallback(lambda name_txid: _get_status_of_claim(name_txid, sd_hash)) d.addCallback(lambda name_txid: _get_status_of_claim(name_txid, sd_hash))
return d return d
def get_available_balance(self): def get_balance(self):
return float(self.wallet_balance - self.total_reserved_points) return self.wallet_balance - self.total_reserved_points - sum(self.queued_payments.values())
def _get_status_of_claim(self, claim_outpoint, name, sd_hash): def _get_status_of_claim(self, claim_outpoint, name, sd_hash):
d = self.get_claims_from_tx(claim_outpoint['txid']) d = self.get_claims_from_tx(claim_outpoint['txid'])
@ -804,7 +815,7 @@ class Wallet(object):
# ======== Must be overridden ======== # # ======== Must be overridden ======== #
def get_balance(self): def _update_balance(self):
return defer.fail(NotImplementedError()) return defer.fail(NotImplementedError())
def get_new_address(self): def get_new_address(self):
@ -1038,7 +1049,7 @@ https://github.com/lbryio/lbry/issues/437 to reduce your wallet size")
func = getattr(cmd_runner, cmd.name) func = getattr(cmd_runner, cmd.name)
return threads.deferToThread(func, *args) return threads.deferToThread(func, *args)
def get_balance(self): def _update_balance(self):
accounts = None accounts = None
exclude_claimtrietx = True exclude_claimtrietx = True
d = self._run_cmd_as_defer_succeed('getbalance', accounts, exclude_claimtrietx) d = self._run_cmd_as_defer_succeed('getbalance', accounts, exclude_claimtrietx)

View file

@ -288,7 +288,7 @@ class Daemon(AuthJSONRPCServer):
def _announce_startup(): def _announce_startup():
def _wait_for_credits(): def _wait_for_credits():
if float(self.session.wallet.wallet_balance) == 0.0: if float(self.session.wallet.get_balance()) == 0.0:
self.startup_status = STARTUP_STAGES[6] self.startup_status = STARTUP_STAGES[6]
return reactor.callLater(1, _wait_for_credits) return reactor.callLater(1, _wait_for_credits)
else: else:
@ -332,7 +332,7 @@ class Daemon(AuthJSONRPCServer):
yield self._setup_lbry_file_manager() yield self._setup_lbry_file_manager()
yield self._setup_query_handlers() yield self._setup_query_handlers()
yield self._setup_server() yield self._setup_server()
log.info("Starting balance: " + str(self.session.wallet.wallet_balance)) log.info("Starting balance: " + str(self.session.wallet.get_balance()))
yield _announce_startup() yield _announce_startup()
def _get_platform(self): def _get_platform(self):
@ -1339,7 +1339,7 @@ class Daemon(AuthJSONRPCServer):
Returns: Returns:
balance, float balance, float
""" """
return self._render_response(float(self.session.wallet.wallet_balance)) return self._render_response(float(self.session.wallet.get_balance()))
def jsonrpc_stop(self): def jsonrpc_stop(self):
""" """

View file

@ -118,9 +118,9 @@ class GetStream(object):
self.fee = FeeValidator(self.stream_info['fee']) self.fee = FeeValidator(self.stream_info['fee'])
max_key_fee = self._convert_max_fee() max_key_fee = self._convert_max_fee()
converted_fee = self.exchange_rate_manager.to_lbc(self.fee).amount converted_fee = self.exchange_rate_manager.to_lbc(self.fee).amount
if converted_fee > self.wallet.wallet_balance: if converted_fee > self.wallet.get_balance():
msg = "Insufficient funds to download lbry://{}. Need {:0.2f}, have {:0.2f}".format( msg = "Insufficient funds to download lbry://{}. Need {:0.2f}, have {:0.2f}".format(
self.resolved_name, converted_fee, self.wallet.wallet_balance) self.resolved_name, converted_fee, self.wallet.get_balance())
raise InsufficientFundsError(msg) raise InsufficientFundsError(msg)
if converted_fee > max_key_fee: if converted_fee > max_key_fee:
msg = "Key fee {:0.2f} above limit of {:0.2f} didn't download lbry://{}".format( msg = "Key fee {:0.2f} above limit of {:0.2f} didn't download lbry://{}".format(

View file

@ -1,7 +1,10 @@
from decimal import Decimal
from collections import defaultdict
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import threads, defer from twisted.internet import threads, defer
from lbrynet.core.Wallet import Wallet
from lbrynet.core.Error import InsufficientFundsError
from lbrynet.core.Wallet import Wallet, ReservedPoints
test_metadata = { test_metadata = {
'license': 'NASA', 'license': 'NASA',
@ -21,7 +24,9 @@ test_metadata = {
class MocLbryumWallet(Wallet): class MocLbryumWallet(Wallet):
def __init__(self): def __init__(self):
pass self.wallet_balance = Decimal(10.0)
self.total_reserved_points = Decimal(0.0)
self.queued_payments = defaultdict(Decimal)
def get_name_claims(self): def get_name_claims(self):
return threads.deferToThread(lambda: []) return threads.deferToThread(lambda: [])
@ -128,3 +133,59 @@ class WalletTest(unittest.TestCase):
d = wallet.abandon_claim("0578c161ad8d36a7580c557d7444f967ea7f988e194c20d0e3c42c3cabf110dd", 1) d = wallet.abandon_claim("0578c161ad8d36a7580c557d7444f967ea7f988e194c20d0e3c42c3cabf110dd", 1)
d.addCallback(lambda claim_out: check_out(claim_out)) d.addCallback(lambda claim_out: check_out(claim_out))
return d return d
def test_point_reservation_and_balance(self):
# check that point reservations and cancellation changes the balance
# properly
def update_balance():
return defer.succeed(5)
wallet = MocLbryumWallet()
wallet._update_balance = update_balance
d = wallet.update_balance()
# test point reservation
d.addCallback(lambda _: self.assertEqual(5, wallet.get_balance()))
d.addCallback(lambda _: wallet.reserve_points('testid',2))
d.addCallback(lambda _: self.assertEqual(3, wallet.get_balance()))
d.addCallback(lambda _: self.assertEqual(2, wallet.total_reserved_points))
# test reserved points cancellation
d.addCallback(lambda _: wallet.cancel_point_reservation(ReservedPoints('testid',2)))
d.addCallback(lambda _: self.assertEqual(5, wallet.get_balance()))
d.addCallback(lambda _: self.assertEqual(0, wallet.total_reserved_points))
# test point sending
d.addCallback(lambda _: wallet.reserve_points('testid',2))
d.addCallback(lambda reserve_points: wallet.send_points_to_address(reserve_points,1))
d.addCallback(lambda _: self.assertEqual(3, wallet.get_balance()))
# test failed point reservation
d.addCallback(lambda _: wallet.reserve_points('testid',4))
d.addCallback(lambda out: self.assertEqual(None,out))
return d
def test_point_reservation_and_claim(self):
# check that claims take into consideration point reservations
def update_balance():
return defer.succeed(5)
wallet = MocLbryumWallet()
wallet._update_balance = update_balance
d = wallet.update_balance()
d.addCallback(lambda _: self.assertEqual(5, wallet.get_balance()))
d.addCallback(lambda _: wallet.reserve_points('testid',2))
d.addCallback(lambda _: wallet.claim_name('test', 4, test_metadata))
self.assertFailure(d,InsufficientFundsError)
return d
def test_point_reservation_and_support(self):
# check that supports take into consideration point reservations
def update_balance():
return defer.succeed(5)
wallet = MocLbryumWallet()
wallet._update_balance = update_balance
d = wallet.update_balance()
d.addCallback(lambda _: self.assertEqual(5, wallet.get_balance()))
d.addCallback(lambda _: wallet.reserve_points('testid',2))
d.addCallback(lambda _: wallet.support_claim('test', "f43dc06256a69988bdbea09a58c80493ba15dcfa", 4))
self.assertFailure(d,InsufficientFundsError)
return d