lbry-sdk/lbrynet/core/PTCWallet.py

323 lines
12 KiB
Python
Raw Normal View History

2015-08-20 11:27:15 -04:00
from collections import defaultdict
import logging
import os
import unqlite
2015-08-20 11:27:15 -04:00
import time
from Crypto.Hash import SHA512
from Crypto.PublicKey import RSA
from lbrynet.core.client.ClientRequest import ClientRequest
from lbrynet.core.Error import RequestCanceledError
from lbrynet.interfaces import IRequestCreator, IQueryHandlerFactory, IQueryHandler, ILBRYWallet
from lbrynet.pointtraderclient import pointtraderclient
from twisted.internet import defer, threads
from zope.interface import implements
from twisted.python.failure import Failure
from lbrynet.core.LBRYWallet import ReservedPoints
2015-08-20 11:27:15 -04:00
# Module-level logger used by every class in this file.
log = logging.getLogger(__name__)
2015-08-20 11:27:15 -04:00
class PTCWallet(object):
    """This class sends payments to peers and also ensures that expected payments are received.

    Payments are made over the point trader testing network through
    pointtraderclient. The wallet's RSA private key is persisted in an unqlite
    database ("ptcwallet.db") inside db_dir.

    This class is only intended to be used for testing.
    """

    implements(ILBRYWallet)

    def __init__(self, db_dir):
        """
        @param db_dir: directory that holds the wallet database file "ptcwallet.db"
        """
        self.db_dir = db_dir
        self.db = None
        self.private_key = None
        self.encoded_public_key = None
        # peer -> encoded point trader public key of that peer
        self.peer_pub_keys = {}
        # peer -> number of points waiting to be sent to that peer
        self.queued_payments = defaultdict(int)
        # peer -> list of (amount, time the payment became expected)
        self.expected_payments = defaultdict(list)
        # sender's encoded public key -> list of (amount, time the payment was seen)
        self.received_payments = defaultdict(list)
        # handle returned by reactor.callLater for the next manage() run
        self.next_manage_call = None
        self.payment_check_window = 3 * 60  # 3 minutes
        # Start just outside the check window so manage() does not fetch
        # transactions until add_expected_payment() is first called.
        self.new_payments_expected_time = time.time() - self.payment_check_window
        # SHA512 hex digests of transactions that have already been recorded
        self.known_transactions = []
        self.total_reserved_points = 0.0
        self.wallet_balance = 0.0

    def manage(self):
        """Send payments, ensure expected payments are received.

        While inside the payment check window that follows the most recent
        expected payment, recent transactions are fetched. Every peer's
        standing is then checked. Queued payments are sent, the method
        reschedules itself to run again in 60 seconds, and the cached wallet
        balance is refreshed.

        @return: Deferred which fires once the cached balance has been updated
        """
        from twisted.internet import reactor
        if time.time() < self.new_payments_expected_time + self.payment_check_window:
            d1 = self._get_new_payments()
        else:
            d1 = defer.succeed(None)
        d1.addCallback(lambda _: self._check_good_standing())
        d2 = self._send_queued_points()
        self.next_manage_call = reactor.callLater(60, self.manage)
        dl = defer.DeferredList([d1, d2])
        dl.addCallback(lambda _: self.get_balance())

        def set_balance(balance):
            self.wallet_balance = balance

        dl.addCallback(set_balance)
        return dl

    def stop(self):
        """Stop the periodic manage() loop after running it one final time.

        @return: Deferred from the final manage() call
        """
        if self.next_manage_call is not None:
            self.next_manage_call.cancel()
            self.next_manage_call = None
        d = self.manage()
        # manage() always reschedules itself; cancel that so the loop ends here.
        self.next_manage_call.cancel()
        self.next_manage_call = None
        self.db = None
        return d

    def start(self):
        """Open the database, load or create the private key, then start manage().

        If no key is stored, a 4096-bit RSA key is generated in a thread and
        registered with the point trader. The key is saved to the database
        only when registration reports success (True).

        @return: Deferred which fires with True once manage() has been started
        """
        def save_key(success, private_key):
            if success is True:
                # Return the save Deferred so the startup chain waits for the
                # key to be written. Otherwise start() could finish (and stop()
                # could clear self.db) while the write is still pending.
                d_s = self._save_private_key(private_key.exportKey())
                d_s.addCallback(lambda _: True)
                return d_s
            return False

        def register_private_key(private_key):
            self.private_key = private_key
            self.encoded_public_key = self.private_key.publickey().exportKey()
            d_r = pointtraderclient.register_new_account(private_key)
            d_r.addCallback(save_key, private_key)
            return d_r

        def ensure_private_key_exists(encoded_private_key):
            if encoded_private_key is not None:
                self.private_key = RSA.importKey(encoded_private_key)
                self.encoded_public_key = self.private_key.publickey().exportKey()
                return True
            # Key generation is CPU-heavy, so keep it off the reactor thread.
            create_d = threads.deferToThread(RSA.generate, 4096)
            create_d.addCallback(register_private_key)
            return create_d

        def start_manage():
            self.manage()
            return True

        d = self._open_db()
        d.addCallback(lambda _: self._get_wallet_private_key())
        d.addCallback(ensure_private_key_exists)
        d.addCallback(lambda _: start_manage())
        return d

    def get_info_exchanger(self):
        """Return a request creator that exchanges point trader public keys with peers."""
        return PointTraderKeyExchanger(self)

    def get_wallet_info_query_handler_factory(self):
        """Return a factory for handlers that answer peers' public key queries."""
        return PointTraderKeyQueryHandlerFactory(self)

    def reserve_points(self, peer, amount):
        """
        Ensure a certain amount of points are available to be sent as payment, before the service is rendered
        @param peer: The peer to which the payment will ultimately be sent
        @param amount: The amount of points to reserve
        @return: A ReservedPoints object which is given to send_points once the service has been rendered,
            or None if the unreserved balance is too small
        """
        if self.wallet_balance >= self.total_reserved_points + amount:
            self.total_reserved_points += amount
            return ReservedPoints(peer, amount)
        return None

    def cancel_point_reservation(self, reserved_points):
        """
        Return all of the points that were reserved previously for some ReservedPoints object
        @param reserved_points: ReservedPoints previously returned by reserve_points
        @return: None
        """
        self.total_reserved_points -= reserved_points.amount

    def send_points(self, reserved_points, amount):
        """
        Schedule a payment to be sent to a peer
        @param reserved_points: ReservedPoints object previously returned by reserve_points
        @param amount: amount of points to actually send, must be less than or equal to the
            amount reserved in reserved_points
        @return: Deferred which fires when the payment has been scheduled
        """
        self.queued_payments[reserved_points.identifier] += amount
        # make any unused points available
        self.total_reserved_points -= reserved_points.amount - amount
        reserved_points.identifier.update_stats('points_sent', amount)
        d = defer.succeed(True)
        return d

    def _send_queued_points(self):
        """Send every queued payment whose recipient's public key is known.

        Payments to peers without a known key stay queued for a later call.

        @return: DeferredList of the pointtraderclient.send_points Deferreds
        """
        ds = []
        # Iterate over a snapshot: sent payments are deleted from the dict in
        # the loop, and mutating a dict while iterating a live view raises
        # RuntimeError on Python 3.
        for peer, points in list(self.queued_payments.items()):
            if peer in self.peer_pub_keys:
                d = pointtraderclient.send_points(self.private_key, self.peer_pub_keys[peer], points)
                self.wallet_balance -= points
                self.total_reserved_points -= points
                ds.append(d)
                del self.queued_payments[peer]
            else:
                log.warning("Don't have a payment address for peer %s. Can't send %s points.",
                            str(peer), str(points))
        return defer.DeferredList(ds)

    def get_balance(self):
        """Return the balance of this wallet (a Deferred from pointtraderclient)"""
        d = pointtraderclient.get_balance(self.private_key)
        return d

    def add_expected_payment(self, peer, amount):
        """Increase the number of points expected to be paid by a peer"""
        self.expected_payments[peer].append((amount, time.time()))
        # Reopens the window during which manage() fetches new transactions.
        self.new_payments_expected_time = time.time()
        peer.update_stats('expected_points', amount)

    def set_public_key_for_peer(self, peer, pub_key):
        """Record the encoded point trader public key used to pay peer."""
        self.peer_pub_keys[peer] = pub_key

    def _get_new_payments(self):
        """Fetch recent transactions and record new ones addressed to this wallet.

        @return: Deferred which fires once the transactions have been processed
        """
        def add_new_transactions(transactions):
            for transaction in transactions:
                # transaction[1] is compared with our own key, so it appears to
                # be the recipient's encoded public key.
                if transaction[1] != self.encoded_public_key:
                    continue
                # Fingerprint the transaction so it is only recorded once.
                # NOTE(review): meaning of transaction[3] is not visible here — confirm
                t_hash = SHA512.new()
                t_hash.update(transaction[0])
                t_hash.update(transaction[1])
                t_hash.update(str(transaction[2]))
                t_hash.update(transaction[3])
                digest = t_hash.hexdigest()
                if digest not in self.known_transactions:
                    self.known_transactions.append(digest)
                    self._add_received_payment(transaction[0], transaction[2])

        d = pointtraderclient.get_recent_transactions(self.private_key)
        d.addCallback(add_new_transactions)
        return d

    def _add_received_payment(self, encoded_other_public_key, amount):
        """Record that amount was received from the holder of encoded_other_public_key."""
        self.received_payments[encoded_other_public_key].append((amount, time.time()))

    def _check_good_standing(self):
        """Log a warning for every peer that has paid less than it owes.

        Only expected payments older than 90 seconds count toward what a peer
        owes. Peers whose public key is unknown are treated as having paid
        nothing instead of raising KeyError.
        """
        # Loop-invariant: one cutoff for every peer in this pass.
        expected_cutoff = time.time() - 90
        for peer, expected_payments in self.expected_payments.items():
            min_expected_balance = sum(a[0] for a in expected_payments if a[1] < expected_cutoff)
            peer_pub_key = self.peer_pub_keys.get(peer)
            # .get() on the defaultdict does not insert a new empty entry.
            received_balance = sum(a[0] for a in self.received_payments.get(peer_pub_key, []))
            if min_expected_balance > received_balance:
                log.warning("Account in bad standing: %s (pub_key: %s), expected amount = %s, received_amount = %s",
                            str(peer), peer_pub_key, str(min_expected_balance), str(received_balance))

    def _open_db(self):
        """Open the unqlite wallet database in a worker thread."""
        def open_db():
            self.db = unqlite.UnQLite(os.path.join(self.db_dir, "ptcwallet.db"))
        return threads.deferToThread(open_db)

    def _save_private_key(self, private_key):
        """Store the exported private key in the database in a worker thread."""
        def save_key():
            self.db['private_key'] = private_key
        return threads.deferToThread(save_key)

    def _get_wallet_private_key(self):
        """Return a Deferred firing with the stored exported private key, or None."""
        def get_key():
            if 'private_key' in self.db:
                return self.db['private_key']
            return None
        return threads.deferToThread(get_key)
2015-08-20 11:27:15 -04:00
class PointTraderKeyExchanger(object):
    """Request creator that sends this wallet's public key to a peer, once per
    protocol, and records the public key the peer sends back."""

    # zope.interface flattens sequences, so this is equivalent to the former
    # implements([IRequestCreator]); it now matches the other classes here.
    implements(IRequestCreator)

    def __init__(self, wallet):
        self.wallet = wallet
        # protocols on which a public key request has already been sent
        self._protocols = []

    ######### IRequestCreator #########

    def send_next_request(self, peer, protocol):
        """Send our public key on protocol unless that was already done.

        @return: Deferred firing with True if a request was sent, False otherwise
        """
        if protocol in self._protocols:
            return defer.succeed(False)
        r = ClientRequest({'public_key': self.wallet.encoded_public_key},
                          'public_key')
        d = protocol.add_request(r)
        d.addCallback(self._handle_exchange_response, peer, r, protocol)
        d.addErrback(self._request_failed, peer)
        self._protocols.append(protocol)
        return defer.succeed(True)

    ######### internal calls #########

    def _handle_exchange_response(self, response_dict, peer, request, protocol):
        """Validate the peer's response and store the public key it contains.

        @raise ValueError: if the response lacks the key or the key is not a
            valid RSA key (handled by _request_failed)
        """
        # The response comes from a remote peer, so validate it explicitly;
        # an assert would be stripped under python -O.
        if request.response_identifier not in response_dict:
            raise ValueError("Expected %s in dict but did not get it" % request.response_identifier)
        # Internal invariant: send_next_request records every protocol it uses.
        assert protocol in self._protocols, "Responding protocol is not in our list of protocols"
        peer_pub_key = response_dict[request.response_identifier]
        # Reject malformed keys here, as PointTraderKeyQueryHandler does,
        # rather than failing later when payments are sent to this peer.
        try:
            RSA.importKey(peer_pub_key)
        except (ValueError, TypeError, IndexError):
            raise ValueError("Peer sent an invalid public key")
        self.wallet.set_public_key_for_peer(peer, peer_pub_key)
        return True

    def _request_failed(self, err, peer):
        """Log failures other than cancellation, then propagate the failure."""
        if not err.check(RequestCanceledError):
            log.warning("A peer failed to send a valid public key response. Error: %s, peer: %s",
                        err.getErrorMessage(), str(peer))
        return err
2015-08-20 11:27:15 -04:00
class PointTraderKeyQueryHandlerFactory(object):
    """Produces PointTraderKeyQueryHandler instances that hand out a wallet's public key."""

    implements(IQueryHandlerFactory)

    def __init__(self, wallet):
        # wallet shared by every handler this factory builds
        self.wallet = wallet

    ######### IQueryHandlerFactory #########

    def build_query_handler(self):
        """Return a fresh query handler bound to this factory's wallet."""
        return PointTraderKeyQueryHandler(self.wallet)

    def get_primary_query_identifier(self):
        """Return the query identifier answered by the built handlers."""
        return 'public_key'

    def get_description(self):
        """Return a human-readable description of the service offered."""
        return "Point Trader Address - an address for receiving payments on the point trader testing network"
class PointTraderKeyQueryHandler(object):
    """Answers a client's 'public_key' query with this wallet's public key and
    records the client's key with the wallet."""

    implements(IQueryHandler)

    def __init__(self, wallet):
        self.wallet = wallet
        self.query_identifiers = ['public_key']
        # client's encoded public key, once a valid one has been received
        self.public_key = None
        # peer this handler serves; assigned by register_with_request_handler
        self.peer = None

    ######### IQueryHandler #########

    def register_with_request_handler(self, request_handler, peer):
        """Remember the peer and register for this handler's query identifiers."""
        self.peer = peer
        request_handler.register_query_handler(self, self.query_identifiers)

    def handle_queries(self, queries):
        """Handle a batch of queries from the client.

        @return: Deferred firing with {'public_key': <our key>} when a valid key
            was received now, {} when one was received earlier, or failing with
            ValueError when the key is invalid or has never been sent
        """
        key_query = self.query_identifiers[0]

        if key_query not in queries:
            # No key in this batch: fine only if the client already sent one.
            if self.public_key is not None:
                return defer.succeed({})
            log.warning("Expected a public key, but did not receive one")
            return defer.fail(Failure(ValueError("Expected but did not receive a public key")))

        candidate_key = queries[key_query]
        try:
            RSA.importKey(candidate_key)
        except (ValueError, TypeError, IndexError):
            log.warning("Client sent an invalid public key.")
            return defer.fail(Failure(ValueError("Client sent an invalid public key")))

        self.public_key = candidate_key
        self.wallet.set_public_key_for_peer(self.peer, self.public_key)
        log.debug("Received the client's public key: %s", str(self.public_key))
        return defer.succeed({'public_key': self.wallet.encoded_public_key})