Merge branch 'dht-inlinecallbacks-refactor'

This commit is contained in:
Jack Robison 2018-03-28 19:35:27 -04:00
commit 31e032bb55
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
43 changed files with 1222 additions and 1273 deletions

View file

@ -38,10 +38,8 @@ install:
- pip install . - pip install .
script: script:
- pip install cython - pip install mock pylint
- pip install mock pylint unqlite
- pylint lbrynet - pylint lbrynet
- PYTHONPATH=. trial lbrynet.tests - PYTHONPATH=. trial lbrynet.tests
- python -m unittest discover lbrynet/tests/integration -v
- rvm install ruby-2.3.1 - rvm install ruby-2.3.1
- rvm use 2.3.1 && gem install danger --version '~> 4.0' && danger - rvm use 2.3.1 && gem install danger --version '~> 4.0' && danger

View file

@ -13,24 +13,37 @@ at anytime.
* *
### Fixed ### Fixed
* * handling error from dht clients with old `ping` method
* * blobs not being re-announced if no peers successfully stored, now failed announcements are re-queued
### Deprecated ### Deprecated
* *
* *
### Changed ### Changed
* * several internal dht functions to use inlineCallbacks
* * `DHTHashAnnouncer` and `Node` manage functions to use `LoopingCall`s instead of scheduling with `callLater`.
* `store` kademlia rpc method to block on the call finishing and to return storing peer information
* refactored `DHTHashAnnouncer` to longer use locks, use a `DeferredSemaphore` to limit concurrent announcers
* decoupled `DiskBlobManager` from `DHTHashAnnouncer`
* blob hashes to announce to be controlled by`SQLiteStorage`
* kademlia protocol to not delay writes to the UDP socket
* `reactor` and `callLater`, `listenUDP`, and `resolve` functions to be configurable (to allow easier testing)
* calls to get the current time to use `reactor.seconds` (to control callLater and LoopingCall timing in tests)
* `blob_announce` to queue the blob announcement but not block on it
* blob completion to not `callLater` an immediate announce, let `SQLiteStorage` and the `DHTHashAnnouncer` handle it
* raise the default number of concurrent blob announcers to 100
* dht logging to be more verbose with errors and warnings
* added `single_announce` and `last_announced_time` columns to the `blob` table in sqlite
### Added ### Added
* * virtual kademlia network and mock udp transport for dht integration tests
* * integration tests for bootstrapping the dht
* configurable `concurrent_announcers` setting
### Removed ### Removed
* * `announce_all` argument from `blob_announce`
* * old `blob_announce_all` command
## [0.19.2] - 2018-03-28 ## [0.19.2] - 2018-03-28

View file

@ -40,6 +40,8 @@ ANDROID = 4
KB = 2 ** 10 KB = 2 ** 10
MB = 2 ** 20 MB = 2 ** 20
DEFAULT_CONCURRENT_ANNOUNCERS = 100
DEFAULT_DHT_NODES = [ DEFAULT_DHT_NODES = [
('lbrynet1.lbry.io', 4444), ('lbrynet1.lbry.io', 4444),
('lbrynet2.lbry.io', 4444), ('lbrynet2.lbry.io', 4444),
@ -263,6 +265,7 @@ ADJUSTABLE_SETTINGS = {
'download_timeout': (int, 180), 'download_timeout': (int, 180),
'is_generous_host': (bool, True), 'is_generous_host': (bool, True),
'announce_head_blobs_only': (bool, True), 'announce_head_blobs_only': (bool, True),
'concurrent_announcers': (int, DEFAULT_CONCURRENT_ANNOUNCERS),
'known_dht_nodes': (list, DEFAULT_DHT_NODES, server_list), 'known_dht_nodes': (list, DEFAULT_DHT_NODES, server_list),
'lbryum_wallet_dir': (str, default_lbryum_dir), 'lbryum_wallet_dir': (str, default_lbryum_dir),
'max_connections_per_stream': (int, 5), 'max_connections_per_stream': (int, 5),

View file

@ -1,27 +1,23 @@
import logging import logging
import os import os
from sqlite3 import IntegrityError from sqlite3 import IntegrityError
from twisted.internet import threads, defer, reactor, task from twisted.internet import threads, defer, task
from lbrynet import conf from lbrynet import conf
from lbrynet.blob.blob_file import BlobFile from lbrynet.blob.blob_file import BlobFile
from lbrynet.blob.creator import BlobFileCreator from lbrynet.blob.creator import BlobFileCreator
from lbrynet.core.server.DHTHashAnnouncer import DHTHashSupplier
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class DiskBlobManager(DHTHashSupplier): class DiskBlobManager(object):
def __init__(self, hash_announcer, blob_dir, storage): def __init__(self, blob_dir, storage):
""" """
This class stores blobs on the hard disk, This class stores blobs on the hard disk
blob_dir - directory where blobs are stored blob_dir - directory where blobs are stored
db_dir - directory where sqlite database of blob information is stored storage - SQLiteStorage object
""" """
DHTHashSupplier.__init__(self, hash_announcer)
self.storage = storage self.storage = storage
self.announce_head_blobs_only = conf.settings['announce_head_blobs_only']
self.blob_dir = blob_dir self.blob_dir = blob_dir
self.blob_creator_type = BlobFileCreator self.blob_creator_type = BlobFileCreator
# TODO: consider using an LRU for blobs as there could potentially # TODO: consider using an LRU for blobs as there could potentially
@ -30,7 +26,7 @@ class DiskBlobManager(DHTHashSupplier):
self.blob_hashes_to_delete = {} # {blob_hash: being_deleted (True/False)} self.blob_hashes_to_delete = {} # {blob_hash: being_deleted (True/False)}
self.check_should_announce_lc = None self.check_should_announce_lc = None
if conf.settings['run_reflector_server']: if conf.settings['run_reflector_server']: # TODO: move this looping call to SQLiteStorage
self.check_should_announce_lc = task.LoopingCall(self.storage.verify_will_announce_all_head_and_sd_blobs) self.check_should_announce_lc = task.LoopingCall(self.storage.verify_will_announce_all_head_and_sd_blobs)
def setup(self): def setup(self):
@ -62,40 +58,21 @@ class DiskBlobManager(DHTHashSupplier):
self.blobs[blob_hash] = blob self.blobs[blob_hash] = blob
return defer.succeed(blob) return defer.succeed(blob)
def immediate_announce(self, blob_hashes):
if self.hash_announcer:
return self.hash_announcer.immediate_announce(blob_hashes)
raise Exception("Hash announcer not set")
@defer.inlineCallbacks @defer.inlineCallbacks
def blob_completed(self, blob, next_announce_time=None, should_announce=True): def blob_completed(self, blob, should_announce=False, next_announce_time=None):
if next_announce_time is None:
next_announce_time = self.get_next_announce_time()
yield self.storage.add_completed_blob( yield self.storage.add_completed_blob(
blob.blob_hash, blob.length, next_announce_time, should_announce blob.blob_hash, blob.length, next_announce_time, should_announce
) )
# we announce all blobs immediately, if announce_head_blob_only is False
# otherwise, announce only if marked as should_announce
if not self.announce_head_blobs_only or should_announce:
reactor.callLater(0, self.immediate_announce, [blob.blob_hash])
def completed_blobs(self, blobhashes_to_check): def completed_blobs(self, blobhashes_to_check):
return self._completed_blobs(blobhashes_to_check) return self._completed_blobs(blobhashes_to_check)
def hashes_to_announce(self):
return self.storage.get_blobs_to_announce(self.hash_announcer)
def count_should_announce_blobs(self): def count_should_announce_blobs(self):
return self.storage.count_should_announce_blobs() return self.storage.count_should_announce_blobs()
def set_should_announce(self, blob_hash, should_announce): def set_should_announce(self, blob_hash, should_announce):
if blob_hash in self.blobs: now = self.storage.clock.seconds()
blob = self.blobs[blob_hash] return self.storage.set_should_announce(blob_hash, now, should_announce)
if blob.get_is_verified():
return self.storage.set_should_announce(
blob_hash, self.get_next_announce_time(), should_announce
)
return defer.succeed(False)
def get_should_announce(self, blob_hash): def get_should_announce(self, blob_hash):
return self.storage.should_announce(blob_hash) return self.storage.should_announce(blob_hash)
@ -110,13 +87,7 @@ class DiskBlobManager(DHTHashSupplier):
raise Exception("Blob has a length of 0") raise Exception("Blob has a length of 0")
new_blob = BlobFile(self.blob_dir, blob_creator.blob_hash, blob_creator.length) new_blob = BlobFile(self.blob_dir, blob_creator.blob_hash, blob_creator.length)
self.blobs[blob_creator.blob_hash] = new_blob self.blobs[blob_creator.blob_hash] = new_blob
next_announce_time = self.get_next_announce_time() return self.blob_completed(new_blob, should_announce)
return self.blob_completed(new_blob, next_announce_time, should_announce)
def immediate_announce_all_blobs(self):
d = self._get_all_verified_blob_hashes()
d.addCallback(self.immediate_announce)
return d
def get_all_verified_blobs(self): def get_all_verified_blobs(self):
d = self._get_all_verified_blob_hashes() d = self._get_all_verified_blob_hashes()

View file

@ -1,18 +0,0 @@
class DummyHashAnnouncer(object):
def __init__(self, *args):
pass
def run_manage_loop(self):
pass
def stop(self):
pass
def add_supplier(self, *args):
pass
def hash_queue_size(self):
return 0
def immediate_announce(self, *args):
pass

View file

@ -1,330 +0,0 @@
from collections import defaultdict
import logging
import os
import unqlite
import time
from Crypto.Hash import SHA512
from Crypto.PublicKey import RSA
from lbrynet.core.client.ClientRequest import ClientRequest
from lbrynet.core.Error import RequestCanceledError
from lbrynet.interfaces import IRequestCreator, IQueryHandlerFactory, IQueryHandler, IWallet
from lbrynet.pointtraderclient import pointtraderclient
from twisted.internet import defer, threads
from zope.interface import implements
from twisted.python.failure import Failure
from lbrynet.core.Wallet import ReservedPoints
log = logging.getLogger(__name__)
class PTCWallet(object):
"""This class sends payments to peers and also ensures that expected payments are received.
This class is only intended to be used for testing."""
implements(IWallet)
def __init__(self, db_dir):
self.db_dir = db_dir
self.db = None
self.private_key = None
self.encoded_public_key = None
self.peer_pub_keys = {}
self.queued_payments = defaultdict(int)
self.expected_payments = defaultdict(list)
self.received_payments = defaultdict(list)
self.next_manage_call = None
self.payment_check_window = 3 * 60 # 3 minutes
self.new_payments_expected_time = time.time() - self.payment_check_window
self.known_transactions = []
self.total_reserved_points = 0.0
self.wallet_balance = 0.0
def manage(self):
"""Send payments, ensure expected payments are received"""
from twisted.internet import reactor
if time.time() < self.new_payments_expected_time + self.payment_check_window:
d1 = self._get_new_payments()
else:
d1 = defer.succeed(None)
d1.addCallback(lambda _: self._check_good_standing())
d2 = self._send_queued_points()
self.next_manage_call = reactor.callLater(60, self.manage)
dl = defer.DeferredList([d1, d2])
dl.addCallback(lambda _: self.get_balance())
def set_balance(balance):
self.wallet_balance = balance
dl.addCallback(set_balance)
return dl
def stop(self):
if self.next_manage_call is not None:
self.next_manage_call.cancel()
self.next_manage_call = None
d = self.manage()
self.next_manage_call.cancel()
self.next_manage_call = None
self.db = None
return d
def start(self):
def save_key(success, private_key):
if success is True:
self._save_private_key(private_key.exportKey())
return True
return False
def register_private_key(private_key):
self.private_key = private_key
self.encoded_public_key = self.private_key.publickey().exportKey()
d_r = pointtraderclient.register_new_account(private_key)
d_r.addCallback(save_key, private_key)
return d_r
def ensure_private_key_exists(encoded_private_key):
if encoded_private_key is not None:
self.private_key = RSA.importKey(encoded_private_key)
self.encoded_public_key = self.private_key.publickey().exportKey()
return True
else:
create_d = threads.deferToThread(RSA.generate, 4096)
create_d.addCallback(register_private_key)
return create_d
def start_manage():
self.manage()
return True
d = self._open_db()
d.addCallback(lambda _: self._get_wallet_private_key())
d.addCallback(ensure_private_key_exists)
d.addCallback(lambda _: start_manage())
return d
def get_info_exchanger(self):
return PointTraderKeyExchanger(self)
def get_wallet_info_query_handler_factory(self):
return PointTraderKeyQueryHandlerFactory(self)
def reserve_points(self, peer, amount):
"""Ensure a certain amount of points are available to be sent as
payment, before the service is rendered
@param peer: The peer to which the payment will ultimately be sent
@param amount: The amount of points to reserve
@return: A ReservedPoints object which is given to send_points
once the service has been rendered
"""
if self.wallet_balance >= self.total_reserved_points + amount:
self.total_reserved_points += amount
return ReservedPoints(peer, amount)
return None
def cancel_point_reservation(self, reserved_points):
"""
Return all of the points that were reserved previously for some ReservedPoints object
@param reserved_points: ReservedPoints previously returned by reserve_points
@return: None
"""
self.total_reserved_points -= reserved_points.amount
def send_points(self, reserved_points, amount):
"""
Schedule a payment to be sent to a peer
@param reserved_points: ReservedPoints object previously returned by reserve_points
@param amount: amount of points to actually send, must be less than or equal to the
amount reserved in reserved_points
@return: Deferred which fires when the payment has been scheduled
"""
self.queued_payments[reserved_points.identifier] += amount
# make any unused points available
self.total_reserved_points -= reserved_points.amount - amount
reserved_points.identifier.update_stats('points_sent', amount)
d = defer.succeed(True)
return d
def _send_queued_points(self):
ds = []
for peer, points in self.queued_payments.items():
if peer in self.peer_pub_keys:
d = pointtraderclient.send_points(
self.private_key, self.peer_pub_keys[peer], points)
self.wallet_balance -= points
self.total_reserved_points -= points
ds.append(d)
del self.queued_payments[peer]
else:
log.warning("Don't have a payment address for peer %s. Can't send %s points.",
str(peer), str(points))
return defer.DeferredList(ds)
def get_balance(self):
"""Return the balance of this wallet"""
d = pointtraderclient.get_balance(self.private_key)
return d
def add_expected_payment(self, peer, amount):
"""Increase the number of points expected to be paid by a peer"""
self.expected_payments[peer].append((amount, time.time()))
self.new_payments_expected_time = time.time()
peer.update_stats('expected_points', amount)
def set_public_key_for_peer(self, peer, pub_key):
self.peer_pub_keys[peer] = pub_key
def _get_new_payments(self):
def add_new_transactions(transactions):
for transaction in transactions:
if transaction[1] == self.encoded_public_key:
t_hash = SHA512.new()
t_hash.update(transaction[0])
t_hash.update(transaction[1])
t_hash.update(str(transaction[2]))
t_hash.update(transaction[3])
if t_hash.hexdigest() not in self.known_transactions:
self.known_transactions.append(t_hash.hexdigest())
self._add_received_payment(transaction[0], transaction[2])
d = pointtraderclient.get_recent_transactions(self.private_key)
d.addCallback(add_new_transactions)
return d
def _add_received_payment(self, encoded_other_public_key, amount):
self.received_payments[encoded_other_public_key].append((amount, time.time()))
def _check_good_standing(self):
for peer, expected_payments in self.expected_payments.iteritems():
expected_cutoff = time.time() - 90
min_expected_balance = sum([a[0] for a in expected_payments if a[1] < expected_cutoff])
received_balance = 0
if self.peer_pub_keys[peer] in self.received_payments:
received_balance = sum([
a[0] for a in self.received_payments[self.peer_pub_keys[peer]]])
if min_expected_balance > received_balance:
log.warning(
"Account in bad standing: %s (pub_key: %s), expected amount = %s, "
"received_amount = %s",
peer, self.peer_pub_keys[peer], min_expected_balance, received_balance)
def _open_db(self):
def open_db():
self.db = unqlite.UnQLite(os.path.join(self.db_dir, "ptcwallet.db"))
return threads.deferToThread(open_db)
def _save_private_key(self, private_key):
def save_key():
self.db['private_key'] = private_key
return threads.deferToThread(save_key)
def _get_wallet_private_key(self):
def get_key():
if 'private_key' in self.db:
return self.db['private_key']
return None
return threads.deferToThread(get_key)
class PointTraderKeyExchanger(object):
implements([IRequestCreator])
def __init__(self, wallet):
self.wallet = wallet
self._protocols = []
######### IRequestCreator #########
def send_next_request(self, peer, protocol):
if not protocol in self._protocols:
r = ClientRequest({'public_key': self.wallet.encoded_public_key},
'public_key')
d = protocol.add_request(r)
d.addCallback(self._handle_exchange_response, peer, r, protocol)
d.addErrback(self._request_failed, peer)
self._protocols.append(protocol)
return defer.succeed(True)
else:
return defer.succeed(False)
######### internal calls #########
def _handle_exchange_response(self, response_dict, peer, request, protocol):
assert request.response_identifier in response_dict, \
"Expected %s in dict but did not get it" % request.response_identifier
assert protocol in self._protocols, "Responding protocol is not in our list of protocols"
peer_pub_key = response_dict[request.response_identifier]
self.wallet.set_public_key_for_peer(peer, peer_pub_key)
return True
def _request_failed(self, err, peer):
if not err.check(RequestCanceledError):
log.warning("A peer failed to send a valid public key response. Error: %s, peer: %s",
err.getErrorMessage(), str(peer))
return err
class PointTraderKeyQueryHandlerFactory(object):
implements(IQueryHandlerFactory)
def __init__(self, wallet):
self.wallet = wallet
######### IQueryHandlerFactory #########
def build_query_handler(self):
q_h = PointTraderKeyQueryHandler(self.wallet)
return q_h
def get_primary_query_identifier(self):
return 'public_key'
def get_description(self):
return ("Point Trader Address - an address for receiving payments on the "
"point trader testing network")
class PointTraderKeyQueryHandler(object):
implements(IQueryHandler)
def __init__(self, wallet):
self.wallet = wallet
self.query_identifiers = ['public_key']
self.public_key = None
self.peer = None
######### IQueryHandler #########
def register_with_request_handler(self, request_handler, peer):
self.peer = peer
request_handler.register_query_handler(self, self.query_identifiers)
def handle_queries(self, queries):
if self.query_identifiers[0] in queries:
new_encoded_pub_key = queries[self.query_identifiers[0]]
try:
RSA.importKey(new_encoded_pub_key)
except (ValueError, TypeError, IndexError):
log.warning("Client sent an invalid public key.")
return defer.fail(Failure(ValueError("Client sent an invalid public key")))
self.public_key = new_encoded_pub_key
self.wallet.set_public_key_for_peer(self.peer, self.public_key)
log.debug("Received the client's public key: %s", str(self.public_key))
fields = {'public_key': self.wallet.encoded_public_key}
return defer.succeed(fields)
if self.public_key is None:
log.warning("Expected a public key, but did not receive one")
return defer.fail(Failure(ValueError("Expected but did not receive a public key")))
else:
return defer.succeed({})

View file

@ -1,19 +0,0 @@
from twisted.internet import defer
class DummyPeerFinder(object):
"""This class finds peers which have announced to the DHT that they have certain blobs"""
def __init__(self):
pass
def run_manage_loop(self):
pass
def stop(self):
pass
def find_peers_for_blob(self, blob_hash):
return defer.succeed([])
def get_most_popular_hashes(self, num_to_return):
return []

View file

@ -1,17 +1,13 @@
import logging import logging
import miniupnpc import miniupnpc
from twisted.internet import threads, defer
from lbrynet.core.BlobManager import DiskBlobManager from lbrynet.core.BlobManager import DiskBlobManager
from lbrynet.dht import node from lbrynet.dht import node, hashannouncer
from lbrynet.database.storage import SQLiteStorage from lbrynet.database.storage import SQLiteStorage
from lbrynet.core.PeerManager import PeerManager
from lbrynet.core.RateLimiter import RateLimiter from lbrynet.core.RateLimiter import RateLimiter
from lbrynet.core.client.DHTPeerFinder import DHTPeerFinder
from lbrynet.core.HashAnnouncer import DummyHashAnnouncer
from lbrynet.core.server.DHTHashAnnouncer import DHTHashAnnouncer
from lbrynet.core.utils import generate_id from lbrynet.core.utils import generate_id
from lbrynet.core.PaymentRateManager import BasePaymentRateManager, NegotiatedPaymentRateManager from lbrynet.core.PaymentRateManager import BasePaymentRateManager, NegotiatedPaymentRateManager
from lbrynet.core.BlobAvailability import BlobAvailabilityTracker from lbrynet.core.BlobAvailability import BlobAvailabilityTracker
from twisted.internet import threads, defer
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -101,43 +97,32 @@ class Session(object):
""" """
self.db_dir = db_dir self.db_dir = db_dir
self.node_id = node_id self.node_id = node_id
self.peer_manager = peer_manager self.peer_manager = peer_manager
self.peer_finder = peer_finder
self.hash_announcer = hash_announcer
self.dht_node_port = dht_node_port self.dht_node_port = dht_node_port
self.known_dht_nodes = known_dht_nodes self.known_dht_nodes = known_dht_nodes
if self.known_dht_nodes is None: if self.known_dht_nodes is None:
self.known_dht_nodes = [] self.known_dht_nodes = []
self.peer_finder = peer_finder
self.hash_announcer = hash_announcer
self.blob_dir = blob_dir self.blob_dir = blob_dir
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.blob_tracker = None self.blob_tracker = None
self.blob_tracker_class = blob_tracker_class or BlobAvailabilityTracker self.blob_tracker_class = blob_tracker_class or BlobAvailabilityTracker
self.peer_port = peer_port self.peer_port = peer_port
self.use_upnp = use_upnp self.use_upnp = use_upnp
self.rate_limiter = rate_limiter self.rate_limiter = rate_limiter
self.external_ip = external_ip self.external_ip = external_ip
self.upnp_redirects = [] self.upnp_redirects = []
self.wallet = wallet self.wallet = wallet
self.dht_node_class = dht_node_class self.dht_node_class = dht_node_class
self.dht_node = None self.dht_node = None
self.base_payment_rate_manager = BasePaymentRateManager(blob_data_payment_rate) self.base_payment_rate_manager = BasePaymentRateManager(blob_data_payment_rate)
self.payment_rate_manager = None self.payment_rate_manager = None
self.payment_rate_manager_class = payment_rate_manager_class or NegotiatedPaymentRateManager self.payment_rate_manager_class = payment_rate_manager_class or NegotiatedPaymentRateManager
self.is_generous = is_generous self.is_generous = is_generous
self.storage = storage or SQLiteStorage(self.db_dir) self.storage = storage or SQLiteStorage(self.db_dir)
self._join_dht_deferred = None
def setup(self): def setup(self):
"""Create the blob directory and database if necessary, start all desired services""" """Create the blob directory and database if necessary, start all desired services"""
@ -147,25 +132,12 @@ class Session(object):
if self.node_id is None: if self.node_id is None:
self.node_id = generate_id() self.node_id = generate_id()
if self.wallet is None:
from lbrynet.core.PTCWallet import PTCWallet
self.wallet = PTCWallet(self.db_dir)
if self.peer_manager is None:
self.peer_manager = PeerManager()
if self.use_upnp is True: if self.use_upnp is True:
d = self._try_upnp() d = self._try_upnp()
else: else:
d = defer.succeed(True) d = defer.succeed(True)
d.addCallback(lambda _: self.storage.setup())
if self.peer_finder is None:
d.addCallback(lambda _: self._setup_dht()) d.addCallback(lambda _: self._setup_dht())
else:
if self.hash_announcer is None and self.peer_port is not None:
log.warning("The server has no way to advertise its available blobs.")
self.hash_announcer = DummyHashAnnouncer()
d.addCallback(lambda _: self._setup_other_components()) d.addCallback(lambda _: self._setup_other_components())
return d return d
@ -173,16 +145,14 @@ class Session(object):
"""Stop all services""" """Stop all services"""
log.info('Stopping session.') log.info('Stopping session.')
ds = [] ds = []
if self.hash_announcer:
self.hash_announcer.stop()
if self.blob_tracker is not None: if self.blob_tracker is not None:
ds.append(defer.maybeDeferred(self.blob_tracker.stop)) ds.append(defer.maybeDeferred(self.blob_tracker.stop))
if self.dht_node is not None: if self.dht_node is not None:
ds.append(defer.maybeDeferred(self.dht_node.stop)) ds.append(defer.maybeDeferred(self.dht_node.stop))
if self.rate_limiter is not None: if self.rate_limiter is not None:
ds.append(defer.maybeDeferred(self.rate_limiter.stop)) ds.append(defer.maybeDeferred(self.rate_limiter.stop))
if self.peer_finder is not None:
ds.append(defer.maybeDeferred(self.peer_finder.stop))
if self.hash_announcer is not None:
ds.append(defer.maybeDeferred(self.hash_announcer.stop))
if self.wallet is not None: if self.wallet is not None:
ds.append(defer.maybeDeferred(self.wallet.stop)) ds.append(defer.maybeDeferred(self.wallet.stop))
if self.blob_manager is not None: if self.blob_manager is not None:
@ -251,59 +221,22 @@ class Session(object):
d.addErrback(upnp_failed) d.addErrback(upnp_failed)
return d return d
# the callback, if any, will be invoked once the joining procedure def _setup_dht(self): # does not block startup, the dht will re-attempt if necessary
# has terminated
def join_dht(self, cb=None):
from twisted.internet import reactor
def join_resolved_addresses(result):
addresses = []
for success, value in result:
if success is True:
addresses.append(value)
return addresses
@defer.inlineCallbacks
def join_network(knownNodes):
log.debug("join DHT using known nodes: " + str(knownNodes))
result = yield self.dht_node.joinNetwork(knownNodes)
defer.returnValue(result)
ds = []
for host, port in self.known_dht_nodes:
d = reactor.resolve(host)
d.addCallback(lambda h: (h, port)) # match host to port
ds.append(d)
dl = defer.DeferredList(ds)
dl.addCallback(join_resolved_addresses)
dl.addCallback(join_network)
if cb:
dl.addCallback(cb)
return dl
def _setup_dht(self):
log.info("Starting DHT")
def start_dht(join_network_result):
self.hash_announcer.run_manage_loop()
return True
self.dht_node = self.dht_node_class( self.dht_node = self.dht_node_class(
udpPort=self.dht_node_port,
node_id=self.node_id, node_id=self.node_id,
udpPort=self.dht_node_port,
externalIP=self.external_ip, externalIP=self.external_ip,
peerPort=self.peer_port peerPort=self.peer_port,
peer_manager=self.peer_manager,
peer_finder=self.peer_finder,
) )
self.peer_finder = DHTPeerFinder(self.dht_node, self.peer_manager) if not self.hash_announcer:
if self.hash_announcer is None: self.hash_announcer = hashannouncer.DHTHashAnnouncer(self.dht_node, self.storage)
self.hash_announcer = DHTHashAnnouncer(self.dht_node, self.peer_port) self.peer_manager = self.dht_node.peer_manager
self.peer_finder = self.dht_node.peer_finder
self.dht_node.startNetwork() self._join_dht_deferred = self.dht_node.joinNetwork(self.known_dht_nodes)
self._join_dht_deferred.addCallback(lambda _: log.info("Joined the dht"))
# pass start_dht() as callback to start the remaining components after joining the DHT self._join_dht_deferred.addCallback(lambda _: self.hash_announcer.start())
return self.join_dht(start_dht)
def _setup_other_components(self): def _setup_other_components(self):
log.debug("Setting up the rest of the components") log.debug("Setting up the rest of the components")
@ -316,13 +249,11 @@ class Session(object):
raise Exception( raise Exception(
"TempBlobManager is no longer supported, specify BlobManager or db_dir") "TempBlobManager is no longer supported, specify BlobManager or db_dir")
else: else:
self.blob_manager = DiskBlobManager( self.blob_manager = DiskBlobManager(self.blob_dir, self.storage)
self.hash_announcer, self.blob_dir, self.storage
)
if self.blob_tracker is None: if self.blob_tracker is None:
self.blob_tracker = self.blob_tracker_class( self.blob_tracker = self.blob_tracker_class(
self.blob_manager, self.peer_finder, self.dht_node self.blob_manager, self.dht_node.peer_finder, self.dht_node
) )
if self.payment_rate_manager is None: if self.payment_rate_manager is None:
self.payment_rate_manager = self.payment_rate_manager_class( self.payment_rate_manager = self.payment_rate_manager_class(
@ -330,8 +261,7 @@ class Session(object):
) )
self.rate_limiter.start() self.rate_limiter.start()
d = self.storage.setup() d = self.blob_manager.setup()
d.addCallback(lambda _: self.blob_manager.setup())
d.addCallback(lambda _: self.wallet.start()) d.addCallback(lambda _: self.wallet.start())
d.addCallback(lambda _: self.blob_tracker.start()) d.addCallback(lambda _: self.blob_tracker.start())
return d return d

View file

@ -6,13 +6,13 @@ from twisted.internet import defer, threads, reactor
from lbrynet.blob import BlobFile from lbrynet.blob import BlobFile
from lbrynet.core.BlobManager import DiskBlobManager from lbrynet.core.BlobManager import DiskBlobManager
from lbrynet.core.HashAnnouncer import DummyHashAnnouncer
from lbrynet.core.RateLimiter import DummyRateLimiter from lbrynet.core.RateLimiter import DummyRateLimiter
from lbrynet.core.PaymentRateManager import OnlyFreePaymentsManager from lbrynet.core.PaymentRateManager import OnlyFreePaymentsManager
from lbrynet.core.PeerFinder import DummyPeerFinder
from lbrynet.core.client.BlobRequester import BlobRequester from lbrynet.core.client.BlobRequester import BlobRequester
from lbrynet.core.client.StandaloneBlobDownloader import StandaloneBlobDownloader from lbrynet.core.client.StandaloneBlobDownloader import StandaloneBlobDownloader
from lbrynet.core.client.ConnectionManager import ConnectionManager from lbrynet.core.client.ConnectionManager import ConnectionManager
from lbrynet.dht.peerfinder import DummyPeerFinder
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -60,7 +60,6 @@ class SingleBlobDownloadManager(object):
class SinglePeerDownloader(object): class SinglePeerDownloader(object):
def __init__(self): def __init__(self):
self._payment_rate_manager = OnlyFreePaymentsManager() self._payment_rate_manager = OnlyFreePaymentsManager()
self._announcer = DummyHashAnnouncer()
self._rate_limiter = DummyRateLimiter() self._rate_limiter = DummyRateLimiter()
self._wallet = None self._wallet = None
self._blob_manager = None self._blob_manager = None
@ -97,7 +96,7 @@ class SinglePeerDownloader(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def download_temp_blob_from_peer(self, peer, timeout, blob_hash): def download_temp_blob_from_peer(self, peer, timeout, blob_hash):
tmp_dir = yield threads.deferToThread(tempfile.mkdtemp) tmp_dir = yield threads.deferToThread(tempfile.mkdtemp)
tmp_blob_manager = DiskBlobManager(self._announcer, tmp_dir, tmp_dir) tmp_blob_manager = DiskBlobManager(tmp_dir, tmp_dir)
try: try:
result = yield self.download_blob_from_peer(peer, timeout, blob_hash, tmp_blob_manager) result = yield self.download_blob_from_peer(peer, timeout, blob_hash, tmp_blob_manager)
finally: finally:

View file

@ -623,6 +623,9 @@ class Wallet(object):
d = self._get_transaction(txid) d = self._get_transaction(txid)
return d return d
def wait_for_tx_in_wallet(self, txid):
return self._wait_for_tx_in_wallet(txid)
def get_balance(self): def get_balance(self):
return self.wallet_balance - self.total_reserved_points - sum(self.queued_payments.values()) return self.wallet_balance - self.total_reserved_points - sum(self.queued_payments.values())
@ -681,6 +684,9 @@ class Wallet(object):
def _get_transaction(self, txid): def _get_transaction(self, txid):
return defer.fail(NotImplementedError()) return defer.fail(NotImplementedError())
def _wait_for_tx_in_wallet(self, txid):
return defer.fail(NotImplementedError())
def _update_balance(self): def _update_balance(self):
return defer.fail(NotImplementedError()) return defer.fail(NotImplementedError())
@ -1067,6 +1073,9 @@ class LBRYumWallet(Wallet):
def _get_transaction(self, txid): def _get_transaction(self, txid):
return self._run_cmd_as_defer_to_thread("gettransaction", txid) return self._run_cmd_as_defer_to_thread("gettransaction", txid)
def _wait_for_tx_in_wallet(self, txid):
return self._run_cmd_as_defer_to_thread("waitfortxinwallet", txid)
def get_name_claims(self): def get_name_claims(self):
return self._run_cmd_as_defer_succeed('getnameclaims') return self._run_cmd_as_defer_succeed('getnameclaims')

View file

@ -0,0 +1,63 @@
class CallLaterManager(object):
_callLater = None
_pendingCallLaters = []
@classmethod
def _cancel(cls, call_later):
"""
:param call_later: DelayedCall
:return: (callable) canceller function
"""
def cancel(reason=None):
"""
:param reason: reason for cancellation, this is returned after cancelling the DelayedCall
:return: reason
"""
if call_later.active():
call_later.cancel()
cls._pendingCallLaters.remove(call_later)
return reason
return cancel
@classmethod
def stop(cls):
"""
Cancel any callLaters that are still running
"""
from twisted.internet import defer
while cls._pendingCallLaters:
canceller = cls._cancel(cls._pendingCallLaters[0])
try:
canceller()
except (defer.CancelledError, defer.AlreadyCalledError):
pass
@classmethod
def call_later(cls, when, what, *args, **kwargs):
"""
Schedule a call later and get a canceller callback function
:param when: (float) delay in seconds
:param what: (callable)
:param args: (*tuple) args to be passed to the callable
:param kwargs: (**dict) kwargs to be passed to the callable
:return: (tuple) twisted.internet.base.DelayedCall object, canceller function
"""
call_later = cls._callLater(when, what, *args, **kwargs)
canceller = cls._cancel(call_later)
cls._pendingCallLaters.append(call_later)
return call_later, canceller
@classmethod
def setup(cls, callLater):
"""
Setup the callLater function to use, supports the real reactor as well as task.Clock
:param callLater: (IReactorTime.callLater)
"""
cls._callLater = callLater

View file

@ -1,129 +0,0 @@
import binascii
import collections
import logging
import time
from twisted.internet import defer
from lbrynet.core import utils
log = logging.getLogger(__name__)
class DHTHashAnnouncer(object):
ANNOUNCE_CHECK_INTERVAL = 60
CONCURRENT_ANNOUNCERS = 5
"""This class announces to the DHT that this peer has certain blobs"""
def __init__(self, dht_node, peer_port):
self.dht_node = dht_node
self.peer_port = peer_port
self.suppliers = []
self.next_manage_call = None
self.hash_queue = collections.deque()
self._concurrent_announcers = 0
def run_manage_loop(self):
if self.peer_port is not None:
self._announce_available_hashes()
self.next_manage_call = utils.call_later(self.ANNOUNCE_CHECK_INTERVAL, self.run_manage_loop)
def stop(self):
log.info("Stopping DHT hash announcer.")
if self.next_manage_call is not None:
self.next_manage_call.cancel()
self.next_manage_call = None
def add_supplier(self, supplier):
self.suppliers.append(supplier)
def immediate_announce(self, blob_hashes):
if self.peer_port is not None:
return self._announce_hashes(blob_hashes, immediate=True)
else:
return defer.succeed(False)
def hash_queue_size(self):
return len(self.hash_queue)
def _announce_available_hashes(self):
log.debug('Announcing available hashes')
ds = []
for supplier in self.suppliers:
d = supplier.hashes_to_announce()
d.addCallback(self._announce_hashes)
ds.append(d)
dl = defer.DeferredList(ds)
return dl
def _announce_hashes(self, hashes, immediate=False):
if not hashes:
return
log.debug('Announcing %s hashes', len(hashes))
# TODO: add a timeit decorator
start = time.time()
ds = []
for h in hashes:
announce_deferred = defer.Deferred()
ds.append(announce_deferred)
if immediate:
self.hash_queue.appendleft((h, announce_deferred))
else:
self.hash_queue.append((h, announce_deferred))
log.debug('There are now %s hashes remaining to be announced', self.hash_queue_size())
def announce():
if len(self.hash_queue):
h, announce_deferred = self.hash_queue.popleft()
log.debug('Announcing blob %s to dht', h)
d = self.dht_node.announceHaveBlob(binascii.unhexlify(h))
d.chainDeferred(announce_deferred)
d.addBoth(lambda _: utils.call_later(0, announce))
else:
self._concurrent_announcers -= 1
for i in range(self._concurrent_announcers, self.CONCURRENT_ANNOUNCERS):
self._concurrent_announcers += 1
announce()
d = defer.DeferredList(ds)
d.addCallback(lambda _: log.debug('Took %s seconds to announce %s hashes',
time.time() - start, len(hashes)))
return d
class DHTHashSupplier(object):
# 1 hour is the min time hash will be reannounced
MIN_HASH_REANNOUNCE_TIME = 60*60
# conservative assumption of the time it takes to announce
# a single hash
SINGLE_HASH_ANNOUNCE_DURATION = 5
"""Classes derived from this class give hashes to a hash announcer"""
def __init__(self, announcer):
if announcer is not None:
announcer.add_supplier(self)
self.hash_announcer = announcer
def hashes_to_announce(self):
pass
def get_next_announce_time(self, num_hashes_to_announce=1):
"""
Hash reannounce time is set to current time + MIN_HASH_REANNOUNCE_TIME,
unless we are announcing a lot of hashes at once which could cause the
the announce queue to pile up. To prevent pile up, reannounce
only after a conservative estimate of when it will finish
to announce all the hashes.
Args:
num_hashes_to_announce: number of hashes that will be added to the queue
Returns:
timestamp for next announce time
"""
queue_size = self.hash_announcer.hash_queue_size()+num_hashes_to_announce
reannounce = max(self.MIN_HASH_REANNOUNCE_TIME,
queue_size*self.SINGLE_HASH_ANNOUNCE_DURATION)
return time.time() + reannounce

View file

@ -148,6 +148,17 @@ def json_dumps_pretty(obj, **kwargs):
return json.dumps(obj, sort_keys=True, indent=2, separators=(',', ': '), **kwargs) return json.dumps(obj, sort_keys=True, indent=2, separators=(',', ': '), **kwargs)
class DeferredLockContextManager(object):
def __init__(self, lock):
self._lock = lock
def __enter__(self):
yield self._lock.aquire()
def __exit__(self, exc_type, exc_val, exc_tb):
yield self._lock.release()
@defer.inlineCallbacks @defer.inlineCallbacks
def DeferredDict(d, consumeErrors=False): def DeferredDict(d, consumeErrors=False):
keys = [] keys = []

View file

@ -25,7 +25,7 @@ from lbryschema.decode import smart_decode
from lbrynet.core.system_info import get_lbrynet_version from lbrynet.core.system_info import get_lbrynet_version
from lbrynet.database.storage import SQLiteStorage from lbrynet.database.storage import SQLiteStorage
from lbrynet import conf from lbrynet import conf
from lbrynet.conf import LBRYCRD_WALLET, LBRYUM_WALLET, PTC_WALLET from lbrynet.conf import LBRYCRD_WALLET, LBRYUM_WALLET
from lbrynet.reflector import reupload from lbrynet.reflector import reupload
from lbrynet.reflector import ServerFactory as reflector_server_factory from lbrynet.reflector import ServerFactory as reflector_server_factory
from lbrynet.core.log_support import configure_loggly_handler from lbrynet.core.log_support import configure_loggly_handler
@ -198,7 +198,7 @@ class Daemon(AuthJSONRPCServer):
self.connected_to_internet = True self.connected_to_internet = True
self.connection_status_code = None self.connection_status_code = None
self.platform = None self.platform = None
self.current_db_revision = 6 self.current_db_revision = 7
self.db_revision_file = conf.settings.get_db_revision_filename() self.db_revision_file = conf.settings.get_db_revision_filename()
self.session = None self.session = None
self._session_id = conf.settings.get_session_id() self._session_id = conf.settings.get_session_id()
@ -313,7 +313,7 @@ class Daemon(AuthJSONRPCServer):
self.session.peer_manager) self.session.peer_manager)
try: try:
log.info("Daemon bound to port: %d", self.peer_port) log.info("Peer protocol listening on TCP %d", self.peer_port)
self.lbry_server_port = reactor.listenTCP(self.peer_port, server_factory) self.lbry_server_port = reactor.listenTCP(self.peer_port, server_factory)
except error.CannotListenError as e: except error.CannotListenError as e:
import traceback import traceback
@ -547,10 +547,6 @@ class Daemon(AuthJSONRPCServer):
config['lbryum_path'] = conf.settings['lbryum_wallet_dir'] config['lbryum_path'] = conf.settings['lbryum_wallet_dir']
wallet = LBRYumWallet(self.storage, config) wallet = LBRYumWallet(self.storage, config)
return defer.succeed(wallet) return defer.succeed(wallet)
elif self.wallet_type == PTC_WALLET:
log.info("Using PTC wallet")
from lbrynet.core.PTCWallet import PTCWallet
return defer.succeed(PTCWallet(self.db_dir))
else: else:
raise ValueError('Wallet Type {} is not valid'.format(self.wallet_type)) raise ValueError('Wallet Type {} is not valid'.format(self.wallet_type))
@ -997,7 +993,7 @@ class Daemon(AuthJSONRPCServer):
############################################################################ ############################################################################
@defer.inlineCallbacks @defer.inlineCallbacks
def jsonrpc_status(self, session_status=False, dht_status=False): def jsonrpc_status(self, session_status=False):
""" """
Get daemon status Get daemon status
@ -1006,7 +1002,6 @@ class Daemon(AuthJSONRPCServer):
Options: Options:
--session_status : (bool) include session status in results --session_status : (bool) include session status in results
--dht_status : (bool) include dht network and peer status
Returns: Returns:
(dict) lbrynet-daemon status (dict) lbrynet-daemon status
@ -1037,18 +1032,6 @@ class Daemon(AuthJSONRPCServer):
'announce_queue_size': number of blobs currently queued to be announced 'announce_queue_size': number of blobs currently queued to be announced
'should_announce_blobs': number of blobs that should be announced 'should_announce_blobs': number of blobs that should be announced
} }
If given the dht status option:
'dht_status': {
'kbps_received': current kbps receiving,
'kbps_sent': current kdps being sent,
'total_bytes_sent': total bytes sent,
'total_bytes_received': total bytes received,
'queries_received': number of queries received per second,
'queries_sent': number of queries sent per second,
'recent_contacts': count of recently contacted peers,
'unique_contacts': count of unique peers
},
} }
""" """
@ -1095,8 +1078,6 @@ class Daemon(AuthJSONRPCServer):
'announce_queue_size': announce_queue_size, 'announce_queue_size': announce_queue_size,
'should_announce_blobs': should_announce_blobs, 'should_announce_blobs': should_announce_blobs,
} }
if dht_status:
response['dht_status'] = self.session.dht_node.get_bandwidth_stats()
defer.returnValue(response) defer.returnValue(response)
def jsonrpc_version(self): def jsonrpc_version(self):
@ -2919,17 +2900,15 @@ class Daemon(AuthJSONRPCServer):
return d return d
@defer.inlineCallbacks @defer.inlineCallbacks
def jsonrpc_blob_announce(self, blob_hash=None, stream_hash=None, sd_hash=None, announce_all=None): def jsonrpc_blob_announce(self, blob_hash=None, stream_hash=None, sd_hash=None):
""" """
Announce blobs to the DHT Announce blobs to the DHT
Usage: Usage:
blob_announce [<blob_hash> | --blob_hash=<blob_hash>] blob_announce [<blob_hash> | --blob_hash=<blob_hash>]
[<stream_hash> | --stream_hash=<stream_hash>] | [<sd_hash> | --sd_hash=<sd_hash>] [<stream_hash> | --stream_hash=<stream_hash>] | [<sd_hash> | --sd_hash=<sd_hash>]
[--announce_all]
Options: Options:
--announce_all : (bool) announce all the blobs possessed by user
--blob_hash=<blob_hash> : (str) announce a blob, specified by blob_hash --blob_hash=<blob_hash> : (str) announce a blob, specified by blob_hash
--stream_hash=<stream_hash> : (str) announce all blobs associated with --stream_hash=<stream_hash> : (str) announce all blobs associated with
stream_hash stream_hash
@ -2940,9 +2919,6 @@ class Daemon(AuthJSONRPCServer):
(bool) true if successful (bool) true if successful
""" """
if announce_all:
yield self.session.blob_manager.immediate_announce_all_blobs()
else:
blob_hashes = [] blob_hashes = []
if blob_hash: if blob_hash:
blob_hashes.append(blob_hash) blob_hashes.append(blob_hash)
@ -2952,29 +2928,13 @@ class Daemon(AuthJSONRPCServer):
if sd_hash: if sd_hash:
stream_hash = yield self.storage.get_stream_hash_for_sd_hash(sd_hash) stream_hash = yield self.storage.get_stream_hash_for_sd_hash(sd_hash)
blobs = yield self.storage.get_blobs_for_stream(stream_hash, only_completed=True) blobs = yield self.storage.get_blobs_for_stream(stream_hash, only_completed=True)
blob_hashes.extend([blob.blob_hash for blob in blobs if blob.blob_hash is not None]) blob_hashes.extend(blob.blob_hash for blob in blobs if blob.blob_hash is not None)
else: else:
raise Exception('single argument must be specified') raise Exception('single argument must be specified')
yield self.session.blob_manager.immediate_announce(blob_hashes) yield self.storage.should_single_announce_blobs(blob_hashes, immediate=True)
response = yield self._render_response(True) response = yield self._render_response(True)
defer.returnValue(response) defer.returnValue(response)
@AuthJSONRPCServer.deprecated("blob_announce")
def jsonrpc_blob_announce_all(self):
"""
Announce all blobs to the DHT
Usage:
blob_announce_all
Options:
None
Returns:
(str) Success/fail message
"""
return self.jsonrpc_blob_announce(announce_all=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def jsonrpc_file_reflect(self, **kwargs): def jsonrpc_file_reflect(self, **kwargs):
""" """

View file

@ -39,6 +39,7 @@ class DaemonServer(object):
try: try:
self.server_port = reactor.listenTCP( self.server_port = reactor.listenTCP(
conf.settings['api_port'], lbrynet_server, interface=conf.settings['api_host']) conf.settings['api_port'], lbrynet_server, interface=conf.settings['api_host'])
log.info("lbrynet API listening on TCP %s:%i", conf.settings['api_host'], conf.settings['api_port'])
except error.CannotListenError: except error.CannotListenError:
log.info('Daemon already running, exiting app') log.info('Daemon already running, exiting app')
raise raise

View file

@ -6,22 +6,20 @@ def migrate_db(db_dir, start, end):
while current < end: while current < end:
if current == 1: if current == 1:
from lbrynet.database.migrator.migrate1to2 import do_migration from lbrynet.database.migrator.migrate1to2 import do_migration
do_migration(db_dir)
elif current == 2: elif current == 2:
from lbrynet.database.migrator.migrate2to3 import do_migration from lbrynet.database.migrator.migrate2to3 import do_migration
do_migration(db_dir)
elif current == 3: elif current == 3:
from lbrynet.database.migrator.migrate3to4 import do_migration from lbrynet.database.migrator.migrate3to4 import do_migration
do_migration(db_dir)
elif current == 4: elif current == 4:
from lbrynet.database.migrator.migrate4to5 import do_migration from lbrynet.database.migrator.migrate4to5 import do_migration
do_migration(db_dir)
elif current == 5: elif current == 5:
from lbrynet.database.migrator.migrate5to6 import do_migration from lbrynet.database.migrator.migrate5to6 import do_migration
do_migration(db_dir) elif current == 6:
from lbrynet.database.migrator.migrate6to7 import do_migration
else: else:
raise Exception("DB migration of version {} to {} is not available".format(current, raise Exception("DB migration of version {} to {} is not available".format(current,
current+1)) current+1))
do_migration(db_dir)
current += 1 current += 1
return None return None

View file

@ -0,0 +1,13 @@
import sqlite3
import os
def do_migration(db_dir):
db_path = os.path.join(db_dir, "lbrynet.sqlite")
connection = sqlite3.connect(db_path)
cursor = connection.cursor()
cursor.executescript("alter table blob add last_announced_time integer;")
cursor.executescript("alter table blob add single_announce integer;")
cursor.execute("update blob set next_announce_time=0")
connection.commit()
connection.close()

View file

@ -1,6 +1,5 @@
import logging import logging
import os import os
import time
import sqlite3 import sqlite3
import traceback import traceback
from decimal import Decimal from decimal import Decimal
@ -11,6 +10,7 @@ from lbryschema.claim import ClaimDict
from lbryschema.decode import smart_decode from lbryschema.decode import smart_decode
from lbrynet import conf from lbrynet import conf
from lbrynet.cryptstream.CryptBlob import CryptBlobInfo from lbrynet.cryptstream.CryptBlob import CryptBlobInfo
from lbrynet.dht.constants import dataExpireTimeout
from lbryum.constants import COIN from lbryum.constants import COIN
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -49,26 +49,6 @@ def open_file_for_writing(download_directory, suggested_file_name):
return threads.deferToThread(_open_file_for_writing, download_directory, suggested_file_name) return threads.deferToThread(_open_file_for_writing, download_directory, suggested_file_name)
def get_next_announce_time(hash_announcer, num_hashes_to_announce=1, min_reannounce_time=60*60,
single_announce_duration=5):
"""
Hash reannounce time is set to current time + MIN_HASH_REANNOUNCE_TIME,
unless we are announcing a lot of hashes at once which could cause the
the announce queue to pile up. To prevent pile up, reannounce
only after a conservative estimate of when it will finish
to announce all the hashes.
Args:
num_hashes_to_announce: number of hashes that will be added to the queue
Returns:
timestamp for next announce time
"""
queue_size = hash_announcer.hash_queue_size() + num_hashes_to_announce
reannounce = max(min_reannounce_time,
queue_size * single_announce_duration)
return time.time() + reannounce
def rerun_if_locked(f): def rerun_if_locked(f):
max_attempts = 3 max_attempts = 3
@ -124,7 +104,9 @@ class SQLiteStorage(object):
blob_length integer not null, blob_length integer not null,
next_announce_time integer not null, next_announce_time integer not null,
should_announce integer not null default 0, should_announce integer not null default 0,
status text not null status text not null,
last_announced_time integer,
single_announce integer
); );
create table if not exists stream ( create table if not exists stream (
@ -185,6 +167,7 @@ class SQLiteStorage(object):
log.info("connecting to database: %s", self._db_path) log.info("connecting to database: %s", self._db_path)
self.db = SqliteConnection(self._db_path) self.db = SqliteConnection(self._db_path)
self.db.set_reactor(reactor) self.db.set_reactor(reactor)
self.clock = reactor
# used to refresh the claim attributes on a ManagedEncryptedFileDownloader when a # used to refresh the claim attributes on a ManagedEncryptedFileDownloader when a
# change to the associated content claim occurs. these are added by the file manager # change to the associated content claim occurs. these are added by the file manager
@ -229,6 +212,7 @@ class SQLiteStorage(object):
) )
def set_should_announce(self, blob_hash, next_announce_time, should_announce): def set_should_announce(self, blob_hash, next_announce_time, should_announce):
next_announce_time = next_announce_time or 0
should_announce = 1 if should_announce else 0 should_announce = 1 if should_announce else 0
return self.db.runOperation( return self.db.runOperation(
"update blob set next_announce_time=?, should_announce=? where blob_hash=?", "update blob set next_announce_time=?, should_announce=? where blob_hash=?",
@ -250,8 +234,8 @@ class SQLiteStorage(object):
status = yield self.get_blob_status(blob_hash) status = yield self.get_blob_status(blob_hash)
if status is None: if status is None:
status = "pending" status = "pending"
yield self.db.runOperation("insert into blob values (?, ?, ?, ?, ?)", yield self.db.runOperation("insert into blob values (?, ?, ?, ?, ?, ?, ?)",
(blob_hash, length, 0, 0, status)) (blob_hash, length, 0, 0, status, 0, 0))
defer.returnValue(status) defer.returnValue(status)
def should_announce(self, blob_hash): def should_announce(self, blob_hash):
@ -269,13 +253,35 @@ class SQLiteStorage(object):
"select blob_hash from blob where should_announce=1 and status='finished'" "select blob_hash from blob where should_announce=1 and status='finished'"
) )
def get_blobs_to_announce(self, hash_announcer): def update_last_announced_blob(self, blob_hash, last_announced):
return self.db.runOperation(
"update blob set next_announce_time=?, last_announced_time=?, single_announce=0 where blob_hash=?",
(int(last_announced + (dataExpireTimeout / 2)), int(last_announced), blob_hash)
)
def should_single_announce_blobs(self, blob_hashes, immediate=False):
def set_single_announce(transaction):
now = self.clock.seconds()
for blob_hash in blob_hashes:
if immediate:
transaction.execute(
"update blob set single_announce=1, next_announce_time=? "
"where blob_hash=? and status='finished'", (int(now), blob_hash)
)
else:
transaction.execute(
"update blob set single_announce=1 where blob_hash=? and status='finished'", (blob_hash, )
)
return self.db.runInteraction(set_single_announce)
def get_blobs_to_announce(self):
def get_and_update(transaction): def get_and_update(transaction):
timestamp = time.time() timestamp = self.clock.seconds()
if conf.settings['announce_head_blobs_only']: if conf.settings['announce_head_blobs_only']:
r = transaction.execute( r = transaction.execute(
"select blob_hash from blob " "select blob_hash from blob "
"where blob_hash is not null and should_announce=1 and next_announce_time<? and status='finished'", "where blob_hash is not null and "
"(should_announce=1 or single_announce=1) and next_announce_time<? and status='finished'",
(timestamp,) (timestamp,)
) )
else: else:
@ -283,16 +289,8 @@ class SQLiteStorage(object):
"select blob_hash from blob where blob_hash is not null " "select blob_hash from blob where blob_hash is not null "
"and next_announce_time<? and status='finished'", (timestamp,) "and next_announce_time<? and status='finished'", (timestamp,)
) )
blobs = [b[0] for b in r.fetchall()]
blobs = [b for b, in r.fetchall()]
next_announce_time = get_next_announce_time(hash_announcer, len(blobs))
transaction.execute(
"update blob set next_announce_time=? where next_announce_time<?", (next_announce_time, timestamp)
)
log.debug("Got %s blobs to announce, next announce time is in %s seconds", len(blobs),
next_announce_time-time.time())
return blobs return blobs
return self.db.runInteraction(get_and_update) return self.db.runInteraction(get_and_update)
def delete_blobs_from_db(self, blob_hashes): def delete_blobs_from_db(self, blob_hashes):

View file

@ -1,22 +0,0 @@
import time
class Delay(object):
maxToSendDelay = 10 ** -3 # 0.05
minToSendDelay = 10 ** -5 # 0.01
def __init__(self, start=0):
self._next = start
# TODO: explain why this logic is like it is. And add tests that
# show that it actually does what it needs to do.
def __call__(self):
ts = time.time()
delay = 0
if ts >= self._next:
delay = self.minToSendDelay
self._next = ts + self.minToSendDelay
else:
delay = (self._next - ts) + self.maxToSendDelay
self._next += self.maxToSendDelay
return delay

22
lbrynet/dht/distance.py Normal file
View file

@ -0,0 +1,22 @@
class Distance(object):
"""Calculate the XOR result between two string variables.
Frequently we re-use one of the points so as an optimization
we pre-calculate the long value of that point.
"""
def __init__(self, key):
self.key = key
self.val_key_one = long(key.encode('hex'), 16)
def __call__(self, key_two):
val_key_two = long(key_two.encode('hex'), 16)
return self.val_key_one ^ val_key_two
def is_closer(self, a, b):
"""Returns true is `a` is closer to `key` than `b` is"""
return self(a) < self(b)
def to_contact(self, contact):
"""A convenience function for calculating the distance to a contact"""
return self(contact.id)

View file

@ -0,0 +1,76 @@
import binascii
import logging
from twisted.internet import defer, task
from lbrynet.core import utils
from lbrynet import conf
log = logging.getLogger(__name__)
class DHTHashAnnouncer(object):
def __init__(self, dht_node, storage, concurrent_announcers=None):
self.dht_node = dht_node
self.storage = storage
self.clock = dht_node.clock
self.peer_port = dht_node.peerPort
self.hash_queue = []
self.concurrent_announcers = concurrent_announcers or conf.settings['concurrent_announcers']
self._manage_lc = task.LoopingCall(self.manage)
self._manage_lc.clock = self.clock
def start(self):
self._manage_lc.start(30)
def stop(self):
if self._manage_lc.running:
self._manage_lc.stop()
@defer.inlineCallbacks
def do_store(self, blob_hash):
storing_node_ids = yield self.dht_node.announceHaveBlob(binascii.unhexlify(blob_hash))
now = self.clock.seconds()
if storing_node_ids:
result = (now, storing_node_ids)
yield self.storage.update_last_announced_blob(blob_hash, now)
log.debug("Stored %s to %i peers", blob_hash[:16], len(storing_node_ids))
else:
result = (None, [])
self.hash_queue.remove(blob_hash)
defer.returnValue(result)
def _show_announce_progress(self, size, start):
queue_size = len(self.hash_queue)
average_blobs_per_second = float(size - queue_size) / (self.clock.seconds() - start)
log.info("Announced %i/%i blobs, %f blobs per second", size - queue_size, size, average_blobs_per_second)
@defer.inlineCallbacks
def immediate_announce(self, blob_hashes):
self.hash_queue.extend(b for b in blob_hashes if b not in self.hash_queue)
log.info("Announcing %i blobs", len(self.hash_queue))
start = self.clock.seconds()
progress_lc = task.LoopingCall(self._show_announce_progress, len(self.hash_queue), start)
progress_lc.start(60, now=False)
s = defer.DeferredSemaphore(self.concurrent_announcers)
results = yield utils.DeferredDict({blob_hash: s.run(self.do_store, blob_hash) for blob_hash in blob_hashes})
now = self.clock.seconds()
progress_lc.stop()
announced_to = [blob_hash for blob_hash in results if results[blob_hash][0]]
if len(announced_to) != len(results):
log.debug("Failed to announce %i blobs", len(results) - len(announced_to))
if announced_to:
log.info('Took %s seconds to announce %i of %i attempted hashes (%f hashes per second)',
now - start, len(announced_to), len(blob_hashes),
int(float(len(blob_hashes)) / float(now - start)))
defer.returnValue(results)
@defer.inlineCallbacks
def manage(self):
need_reannouncement = yield self.storage.get_blobs_to_announce()
if need_reannouncement:
yield self.immediate_announce(need_reannouncement)
else:
log.debug("Nothing to announce")

View file

@ -1,24 +1,22 @@
from collections import Counter from collections import Counter
import datetime import datetime
from twisted.internet import task
class HashWatcher(object): class HashWatcher(object):
def __init__(self): def __init__(self, clock=None):
if not clock:
from twisted.internet import reactor as clock
self.ttl = 600 self.ttl = 600
self.hashes = [] self.hashes = []
self.next_tick = None self.lc = task.LoopingCall(self._remove_old_hashes)
self.lc.clock = clock
def tick(self): def start(self):
return self.lc.start(10)
from twisted.internet import reactor
self._remove_old_hashes()
self.next_tick = reactor.callLater(10, self.tick)
def stop(self): def stop(self):
if self.next_tick is not None: return self.lc.stop()
self.next_tick.cancel()
self.next_tick = None
def add_requested_hash(self, hashsum, contact): def add_requested_hash(self, hashsum, contact):
from_ip = contact.compact_ip from_ip = contact.compact_ip

View file

@ -11,19 +11,23 @@ import hashlib
import operator import operator
import struct import struct
import time import time
import logging
from twisted.internet import defer, error, task
from twisted.internet import defer, error, reactor, threads, task from lbrynet.core.utils import generate_id
from lbrynet.core.call_later_manager import CallLaterManager
from lbrynet.core.PeerManager import PeerManager
import constants import constants
import routingtable import routingtable
import datastore import datastore
import protocol import protocol
from error import TimeoutError
from peerfinder import DHTPeerFinder
from contact import Contact from contact import Contact
from hashwatcher import HashWatcher from hashwatcher import HashWatcher
import logging from distance import Distance
from lbrynet.core.utils import generate_id
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -51,7 +55,9 @@ class Node(object):
def __init__(self, node_id=None, udpPort=4000, dataStore=None, def __init__(self, node_id=None, udpPort=4000, dataStore=None,
routingTableClass=None, networkProtocol=None, routingTableClass=None, networkProtocol=None,
externalIP=None, peerPort=None): externalIP=None, peerPort=None, listenUDP=None,
callLater=None, resolve=None, clock=None, peer_finder=None,
peer_manager=None):
""" """
@param dataStore: The data store to use. This must be class inheriting @param dataStore: The data store to use. This must be class inheriting
from the C{DataStore} interface (or providing the from the C{DataStore} interface (or providing the
@ -76,6 +82,18 @@ class Node(object):
@param externalIP: the IP at which this node can be contacted @param externalIP: the IP at which this node can be contacted
@param peerPort: the port at which this node announces it has a blob for @param peerPort: the port at which this node announces it has a blob for
""" """
if not listenUDP or not resolve or not callLater or not clock:
from twisted.internet import reactor
listenUDP = listenUDP or reactor.listenUDP
resolve = resolve or reactor.resolve
callLater = callLater or reactor.callLater
clock = clock or reactor
self.clock = clock
CallLaterManager.setup(callLater)
self.reactor_resolve = resolve
self.reactor_listenUDP = listenUDP
self.reactor_callLater = CallLaterManager.call_later
self.node_id = node_id or self._generateID() self.node_id = node_id or self._generateID()
self.port = udpPort self.port = udpPort
self._listeningPort = None # object implementing Twisted self._listeningPort = None # object implementing Twisted
@ -84,14 +102,17 @@ class Node(object):
# information from the DHT as soon as the node is part of the # information from the DHT as soon as the node is part of the
# network (add callbacks to this deferred if scheduling such # network (add callbacks to this deferred if scheduling such
# operations before the node has finished joining the network) # operations before the node has finished joining the network)
self._joinDeferred = None self._joinDeferred = defer.Deferred(None)
self.next_refresh_call = None
self.change_token_lc = task.LoopingCall(self.change_token) self.change_token_lc = task.LoopingCall(self.change_token)
self.change_token_lc.clock = self.clock
self.refresh_node_lc = task.LoopingCall(self._refreshNode)
self.refresh_node_lc.clock = self.clock
# Create k-buckets (for storing contacts) # Create k-buckets (for storing contacts)
if routingTableClass is None: if routingTableClass is None:
self._routingTable = routingtable.OptimizedTreeRoutingTable(self.node_id) self._routingTable = routingtable.OptimizedTreeRoutingTable(self.node_id, self.clock.seconds)
else: else:
self._routingTable = routingTableClass(self.node_id) self._routingTable = routingTableClass(self.node_id, self.clock.seconds)
# Initialize this node's network access mechanisms # Initialize this node's network access mechanisms
if networkProtocol is None: if networkProtocol is None:
@ -115,75 +136,95 @@ class Node(object):
self._routingTable.addContact(contact) self._routingTable.addContact(contact)
self.externalIP = externalIP self.externalIP = externalIP
self.peerPort = peerPort self.peerPort = peerPort
self.hash_watcher = HashWatcher() self.hash_watcher = HashWatcher(self.clock)
self.peer_manager = peer_manager or PeerManager()
self.peer_finder = peer_finder or DHTPeerFinder(self, self.peer_manager)
def __del__(self): def __del__(self):
log.warning("unclean shutdown of the dht node")
if self._listeningPort is not None: if self._listeningPort is not None:
self._listeningPort.stopListening() self._listeningPort.stopListening()
@defer.inlineCallbacks
def stop(self): def stop(self):
# cancel callLaters: # stop LoopingCalls:
if self.next_refresh_call is not None: if self.refresh_node_lc.running:
self.next_refresh_call.cancel() yield self.refresh_node_lc.stop()
self.next_refresh_call = None
if self.change_token_lc.running: if self.change_token_lc.running:
self.change_token_lc.stop() yield self.change_token_lc.stop()
if self._listeningPort is not None: if self._listeningPort is not None:
self._listeningPort.stopListening() yield self._listeningPort.stopListening()
self.hash_watcher.stop() if self.hash_watcher.lc.running:
yield self.hash_watcher.stop()
def startNetwork(self): def start_listening(self):
""" Causes the Node to start all the underlying components needed for the DHT if not self._listeningPort:
to work. This should be called before any other DHT operations.
"""
log.info("Starting DHT underlying components")
# Prepare the underlying Kademlia protocol
if self.port is not None:
try: try:
self._listeningPort = reactor.listenUDP(self.port, self._protocol) self._listeningPort = self.reactor_listenUDP(self.port, self._protocol)
except error.CannotListenError as e: except error.CannotListenError as e:
import traceback import traceback
log.error("Couldn't bind to port %d. %s", self.port, traceback.format_exc()) log.error("Couldn't bind to port %d. %s", self.port, traceback.format_exc())
raise ValueError("%s lbrynet may already be running." % str(e)) raise ValueError("%s lbrynet may already be running." % str(e))
else:
log.warning("Already bound to port %d", self._listeningPort.port)
# Start the token looping call def bootstrap_join(self, known_node_addresses, finished_d):
self.change_token_lc.start(constants.tokenSecretChangeInterval) """
# #TODO: Refresh all k-buckets further away than this node's closest neighbour Attempt to join the dht, retry every 30 seconds if unsuccessful
# Start refreshing k-buckets periodically, if necessary :param known_node_addresses: [(str, int)] list of hostnames and ports for known dht seed nodes
self.next_refresh_call = reactor.callLater(constants.checkRefreshInterval, :param finished_d: (defer.Deferred) called when join succeeds
self._refreshNode) """
self.hash_watcher.tick() @defer.inlineCallbacks
def _resolve_seeds():
bootstrap_contacts = []
for node_address, port in known_node_addresses:
host = yield self.reactor_resolve(node_address)
# Create temporary contact information for the list of addresses of known nodes
contact = Contact(self._generateID(), host, port, self._protocol)
bootstrap_contacts.append(contact)
if not bootstrap_contacts:
if not self.hasContacts():
log.warning("No known contacts!")
else:
log.info("found contacts")
bootstrap_contacts = self.contacts
defer.returnValue(bootstrap_contacts)
def _rerun(closest_nodes):
if not closest_nodes:
log.info("Failed to join the dht, re-attempting in 30 seconds")
self.reactor_callLater(30, self.bootstrap_join, known_node_addresses, finished_d)
elif not finished_d.called:
finished_d.callback(closest_nodes)
log.info("Attempting to join the DHT network")
d = _resolve_seeds()
# Initiate the Kademlia joining sequence - perform a search for this node's own ID
d.addCallback(lambda contacts: self._iterativeFind(self.node_id, contacts))
d.addCallback(_rerun)
@defer.inlineCallbacks @defer.inlineCallbacks
def joinNetwork(self, knownNodeAddresses=None): def joinNetwork(self, known_node_addresses=None):
""" Causes the Node to attempt to join the DHT network by contacting the """ Causes the Node to attempt to join the DHT network by contacting the
known DHT nodes. This can be called multiple times if the previous attempt known DHT nodes. This can be called multiple times if the previous attempt
has failed or if the Node has lost all the contacts. has failed or if the Node has lost all the contacts.
@param knownNodeAddresses: A sequence of tuples containing IP address @param known_node_addresses: A sequence of tuples containing IP address
information for existing nodes on the information for existing nodes on the
Kademlia network, in the format: Kademlia network, in the format:
C{(<ip address>, (udp port>)} C{(<ip address>, (udp port>)}
@type knownNodeAddresses: tuple @type known_node_addresses: list
""" """
log.info("Attempting to join the DHT network")
# IGNORE:E1101 self.start_listening()
# Create temporary contact information for the list of addresses of known nodes # #TODO: Refresh all k-buckets further away than this node's closest neighbour
if knownNodeAddresses != None: # Start refreshing k-buckets periodically, if necessary
bootstrapContacts = [] self.bootstrap_join(known_node_addresses or [], self._joinDeferred)
for address, port in knownNodeAddresses: yield self._joinDeferred
contact = Contact(self._generateID(), address, port, self._protocol) self.hash_watcher.start()
bootstrapContacts.append(contact) self.change_token_lc.start(constants.tokenSecretChangeInterval)
else: self.refresh_node_lc.start(constants.checkRefreshInterval)
bootstrapContacts = None
# Initiate the Kademlia joining sequence - perform a search for this node's own ID
self._joinDeferred = self._iterativeFind(self.node_id, bootstrapContacts)
result = yield self._joinDeferred
defer.returnValue(result)
@property @property
def contacts(self): def contacts(self):
@ -193,41 +234,19 @@ class Node(object):
yield contact yield contact
return list(_inner()) return list(_inner())
def printContacts(self, *args):
print '\n\nNODE CONTACTS\n==============='
for i in range(len(self._routingTable._buckets)):
print "bucket %i" % i
for contact in self._routingTable._buckets[i]._contacts:
print " %s:%i" % (contact.address, contact.port)
print '=================================='
def hasContacts(self): def hasContacts(self):
for bucket in self._routingTable._buckets: for bucket in self._routingTable._buckets:
if bucket._contacts: if bucket._contacts:
return True return True
return False return False
def getApproximateTotalDHTNodes(self):
# get the deepest bucket and the number of contacts in that bucket and multiply it
# by the number of equivalently deep buckets in the whole DHT to get a really bad
# estimate!
bucket = self._routingTable._buckets[self._routingTable._kbucketIndex(self.node_id)]
num_in_bucket = len(bucket._contacts)
factor = (2 ** constants.key_bits) / (bucket.rangeMax - bucket.rangeMin)
return num_in_bucket * factor
def getApproximateTotalHashes(self):
# Divide the number of hashes we know about by k to get a really, really, really
# bad estimate of the average number of hashes per node, then multiply by the
# approximate number of nodes to get a horrendous estimate of the total number
# of hashes in the DHT
num_in_data_store = len(self._dataStore._dict)
if num_in_data_store == 0:
return 0
return num_in_data_store * self.getApproximateTotalDHTNodes() / 8
def announceHaveBlob(self, key): def announceHaveBlob(self, key):
return self.iterativeAnnounceHaveBlob(key, {'port': self.peerPort, 'lbryid': self.node_id}) return self.iterativeAnnounceHaveBlob(
key, {
'port': self.peerPort,
'lbryid': self.node_id,
}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def getPeersForBlob(self, blob_hash): def getPeersForBlob(self, blob_hash):
@ -245,69 +264,57 @@ class Node(object):
def get_most_popular_hashes(self, num_to_return): def get_most_popular_hashes(self, num_to_return):
return self.hash_watcher.most_popular_hashes(num_to_return) return self.hash_watcher.most_popular_hashes(num_to_return)
def get_bandwidth_stats(self): @defer.inlineCallbacks
return self._protocol.bandwidth_stats
def iterativeAnnounceHaveBlob(self, blob_hash, value): def iterativeAnnounceHaveBlob(self, blob_hash, value):
known_nodes = {} known_nodes = {}
contacts = yield self.iterativeFindNode(blob_hash)
def log_error(err, n): # store locally if we're the closest node and there are less than k contacts to try storing to
if err.check(protocol.TimeoutError): if self.externalIP is not None and contacts and len(contacts) < constants.k:
log.debug(
"Timeout while storing blob_hash %s at %s",
binascii.hexlify(blob_hash), n)
else:
log.error(
"Unexpected error while storing blob_hash %s at %s: %s",
binascii.hexlify(blob_hash), n, err.getErrorMessage())
def log_success(res):
log.debug("Response to store request: %s", str(res))
return res
def announce_to_peer(responseTuple):
""" @type responseMsg: kademlia.msgtypes.ResponseMessage """
# The "raw response" tuple contains the response message,
# and the originating address info
responseMsg = responseTuple[0]
originAddress = responseTuple[1] # tuple: (ip adress, udp port)
# Make sure the responding node is valid, and abort the operation if it isn't
if not responseMsg.nodeID in known_nodes:
return responseMsg.nodeID
n = known_nodes[responseMsg.nodeID]
result = responseMsg.response
if 'token' in result:
value['token'] = result['token']
d = n.store(blob_hash, value, self.node_id, 0)
d.addCallback(log_success)
d.addErrback(log_error, n)
else:
d = defer.succeed(False)
return d
def requestPeers(contacts):
if self.externalIP is not None and len(contacts) >= constants.k:
is_closer = Distance(blob_hash).is_closer(self.node_id, contacts[-1].id) is_closer = Distance(blob_hash).is_closer(self.node_id, contacts[-1].id)
if is_closer: if is_closer:
contacts.pop() contacts.pop()
self.store(blob_hash, value, self_store=True, originalPublisherID=self.node_id) yield self.store(blob_hash, value, originalPublisherID=self.node_id,
self_store=True)
elif self.externalIP is not None: elif self.externalIP is not None:
self.store(blob_hash, value, self_store=True, originalPublisherID=self.node_id) pass
ds = [] else:
for contact in contacts: raise Exception("Cannot determine external IP: %s" % self.externalIP)
known_nodes[contact.id] = contact
rpcMethod = getattr(contact, "findValue")
df = rpcMethod(blob_hash, rawResponse=True)
df.addCallback(announce_to_peer)
df.addErrback(log_error, contact)
ds.append(df)
return defer.DeferredList(ds)
d = self.iterativeFindNode(blob_hash) contacted = []
d.addCallbacks(requestPeers)
return d @defer.inlineCallbacks
def announce_to_contact(contact):
known_nodes[contact.id] = contact
try:
responseMsg, originAddress = yield contact.findValue(blob_hash, rawResponse=True)
if responseMsg.nodeID != contact.id:
raise Exception("node id mismatch")
value['token'] = responseMsg.response['token']
res = yield contact.store(blob_hash, value)
if res != "OK":
raise ValueError(res)
contacted.append(contact)
log.debug("Stored %s to %s (%s)", blob_hash.encode('hex'), contact.id.encode('hex'), originAddress[0])
except protocol.TimeoutError:
log.debug("Timeout while storing blob_hash %s at %s",
blob_hash.encode('hex')[:16], contact.id.encode('hex'))
except ValueError as err:
log.error("Unexpected response: %s" % err.message)
except Exception as err:
log.error("Unexpected error while storing blob_hash %s at %s: %s",
binascii.hexlify(blob_hash), contact, err)
dl = []
for c in contacts:
dl.append(announce_to_contact(c))
yield defer.DeferredList(dl)
log.debug("Stored %s to %i of %i attempted peers", blob_hash.encode('hex')[:16],
len(contacted), len(contacts))
contacted_node_ids = [c.id.encode('hex') for c in contacted]
defer.returnValue(contacted_node_ids)
def change_token(self): def change_token(self):
self.old_token_secret = self.token_secret self.old_token_secret = self.token_secret
@ -365,14 +372,13 @@ class Node(object):
to the specified key to the specified key
@rtype: twisted.internet.defer.Deferred @rtype: twisted.internet.defer.Deferred
""" """
# Prepare a callback for this operation
outerDf = defer.Deferred()
def checkResult(result): # Execute the search
if isinstance(result, dict): iterative_find_result = yield self._iterativeFind(key, rpc='findValue')
if isinstance(iterative_find_result, dict):
# We have found the value; now see who was the closest contact without it... # We have found the value; now see who was the closest contact without it...
# ...and store the key/value pair # ...and store the key/value pair
outerDf.callback(result) defer.returnValue(iterative_find_result)
else: else:
# The value wasn't found, but a list of contacts was returned # The value wasn't found, but a list of contacts was returned
# Now, see if we have the value (it might seem wasteful to search on the network # Now, see if we have the value (it might seem wasteful to search on the network
@ -380,18 +386,12 @@ class Node(object):
# network # network
if self._dataStore.hasPeersForBlob(key): if self._dataStore.hasPeersForBlob(key):
# Ok, we have the value locally, so use that # Ok, we have the value locally, so use that
peers = self._dataStore.getPeersForBlob(key)
# Send this value to the closest node without it # Send this value to the closest node without it
outerDf.callback({key: peers}) peers = self._dataStore.getPeersForBlob(key)
defer.returnValue({key: peers})
else: else:
# Ok, value does not exist in DHT at all # Ok, value does not exist in DHT at all
outerDf.callback(result) defer.returnValue(iterative_find_result)
# Execute the search
iterative_find_result = yield self._iterativeFind(key, rpc='findValue')
checkResult(iterative_find_result)
result = yield outerDf
defer.returnValue(result)
def addContact(self, contact): def addContact(self, contact):
""" Add/update the given contact; simple wrapper for the same method """ Add/update the given contact; simple wrapper for the same method
@ -486,18 +486,20 @@ class Node(object):
else: else:
raise TypeError, 'No contact info available' raise TypeError, 'No contact info available'
if ((self_store is False) and if not self_store:
('token' not in value or not self.verify_token(value['token'], compact_ip))): if 'token' not in value:
raise ValueError('Invalid or missing token') raise ValueError("Missing token")
if not self.verify_token(value['token'], compact_ip):
raise ValueError("Invalid token")
if 'port' in value: if 'port' in value:
port = int(value['port']) port = int(value['port'])
if 0 <= port <= 65536: if 0 <= port <= 65536:
compact_port = str(struct.pack('>H', port)) compact_port = str(struct.pack('>H', port))
else: else:
raise TypeError, 'Invalid port' raise TypeError('Invalid port')
else: else:
raise TypeError, 'No port available' raise TypeError('No port available')
if 'lbryid' in value: if 'lbryid' in value:
if len(value['lbryid']) != constants.key_bits / 8: if len(value['lbryid']) != constants.key_bits / 8:
@ -506,7 +508,7 @@ class Node(object):
else: else:
compact_address = compact_ip + compact_port + value['lbryid'] compact_address = compact_ip + compact_port + value['lbryid']
else: else:
raise TypeError, 'No lbryid given' raise TypeError('No lbryid given')
now = int(time.time()) now = int(time.time())
originallyPublished = now # - age originallyPublished = now # - age
@ -628,39 +630,22 @@ class Node(object):
result = yield outerDf result = yield outerDf
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def _refreshNode(self): def _refreshNode(self):
""" Periodically called to perform k-bucket refreshes and data """ Periodically called to perform k-bucket refreshes and data
replication/republishing as necessary """ replication/republishing as necessary """
df = self._refreshRoutingTable() yield self._refreshRoutingTable()
df.addCallback(self._removeExpiredPeers) self._dataStore.removeExpiredPeers()
df.addCallback(self._scheduleNextNodeRefresh) defer.returnValue(None)
@defer.inlineCallbacks
def _refreshRoutingTable(self): def _refreshRoutingTable(self):
nodeIDs = self._routingTable.getRefreshList(0, False) nodeIDs = self._routingTable.getRefreshList(0, False)
outerDf = defer.Deferred() while nodeIDs:
def searchForNextNodeID(dfResult=None):
if len(nodeIDs) > 0:
searchID = nodeIDs.pop() searchID = nodeIDs.pop()
df = self.iterativeFindNode(searchID) yield self.iterativeFindNode(searchID)
df.addCallback(searchForNextNodeID) defer.returnValue(None)
else:
# If this is reached, we have finished refreshing the routing table
outerDf.callback(None)
# Start the refreshing cycle
searchForNextNodeID()
return outerDf
def _scheduleNextNodeRefresh(self, *args):
self.next_refresh_call = reactor.callLater(constants.checkRefreshInterval,
self._refreshNode)
# args put here because _refreshRoutingTable does outerDF.callback(None)
def _removeExpiredPeers(self, *args):
df = threads.deferToThread(self._dataStore.removeExpiredPeers)
return df
# This was originally a set of nested methods in _iterativeFind # This was originally a set of nested methods in _iterativeFind
@ -770,7 +755,7 @@ class _IterativeFindHelper(object):
def removeFromShortlist(self, failure, deadContactID): def removeFromShortlist(self, failure, deadContactID):
""" @type failure: twisted.python.failure.Failure """ """ @type failure: twisted.python.failure.Failure """
failure.trap(protocol.TimeoutError) failure.trap(TimeoutError, defer.CancelledError, TypeError)
if len(deadContactID) != constants.key_bits / 8: if len(deadContactID) != constants.key_bits / 8:
raise ValueError("invalid lbry id") raise ValueError("invalid lbry id")
if deadContactID in self.shortlist: if deadContactID in self.shortlist:
@ -827,7 +812,7 @@ class _IterativeFindHelper(object):
if self._should_lookup_active_calls(): if self._should_lookup_active_calls():
# Schedule the next iteration if there are any active # Schedule the next iteration if there are any active
# calls (Kademlia uses loose parallelism) # calls (Kademlia uses loose parallelism)
call = reactor.callLater(constants.iterativeLookupDelay, self.searchIteration) call, _ = self.node.reactor_callLater(constants.iterativeLookupDelay, self.searchIteration)
self.pending_iteration_calls.append(call) self.pending_iteration_calls.append(call)
# Check for a quick contact response that made an update to the shortList # Check for a quick contact response that made an update to the shortList
elif prevShortlistLength < len(self.shortlist): elif prevShortlistLength < len(self.shortlist):
@ -867,30 +852,6 @@ class _IterativeFindHelper(object):
) )
class Distance(object):
"""Calculate the XOR result between two string variables.
Frequently we re-use one of the points so as an optimization
we pre-calculate the long value of that point.
"""
def __init__(self, key):
self.key = key
self.val_key_one = long(key.encode('hex'), 16)
def __call__(self, key_two):
val_key_two = long(key_two.encode('hex'), 16)
return self.val_key_one ^ val_key_two
def is_closer(self, a, b):
"""Returns true is `a` is closer to `key` than `b` is"""
return self(a) < self(b)
def to_contact(self, contact):
"""A convenience function for calculating the distance to a contact"""
return self(contact.id)
class ExpensiveSort(object): class ExpensiveSort(object):
"""Sort a list in place. """Sort a list in place.

View file

@ -2,7 +2,7 @@ import binascii
import logging import logging
from zope.interface import implements from zope.interface import implements
from twisted.internet import defer, reactor from twisted.internet import defer
from lbrynet.interfaces import IPeerFinder from lbrynet.interfaces import IPeerFinder
from lbrynet.core.utils import short_hash from lbrynet.core.utils import short_hash
@ -10,7 +10,23 @@ from lbrynet.core.utils import short_hash
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class DHTPeerFinder(object): class DummyPeerFinder(object):
"""This class finds peers which have announced to the DHT that they have certain blobs"""
def run_manage_loop(self):
pass
def stop(self):
pass
def find_peers_for_blob(self, blob_hash):
return defer.succeed([])
def get_most_popular_hashes(self, num_to_return):
return []
class DHTPeerFinder(DummyPeerFinder):
"""This class finds peers which have announced to the DHT that they have certain blobs""" """This class finds peers which have announced to the DHT that they have certain blobs"""
implements(IPeerFinder) implements(IPeerFinder)
@ -47,7 +63,7 @@ class DHTPeerFinder(object):
finished_deferred = self.dht_node.getPeersForBlob(bin_hash) finished_deferred = self.dht_node.getPeersForBlob(bin_hash)
if timeout is not None: if timeout is not None:
reactor.callLater(timeout, _trigger_timeout) self.dht_node.reactor_callLater(timeout, _trigger_timeout)
try: try:
peer_list = yield finished_deferred peer_list = yield finished_deferred

View file

@ -1,9 +1,9 @@
import logging import logging
import time
import socket import socket
import errno import errno
from twisted.internet import protocol, defer, error, reactor, task from twisted.internet import protocol, defer
from lbrynet.core.call_later_manager import CallLaterManager
import constants import constants
import encoding import encoding
@ -11,7 +11,6 @@ import msgtypes
import msgformat import msgformat
from contact import Contact from contact import Contact
from error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError from error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError
from delay import Delay
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -28,104 +27,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
self._sentMessages = {} self._sentMessages = {}
self._partialMessages = {} self._partialMessages = {}
self._partialMessagesProgress = {} self._partialMessagesProgress = {}
self._delay = Delay()
# keep track of outstanding writes so that they
# can be cancelled on shutdown
self._call_later_list = {}
# keep track of bandwidth usage by peer
self._history_rx = {}
self._history_tx = {}
self._bytes_rx = {}
self._bytes_tx = {}
self._unique_contacts = []
self._queries_rx_per_second = 0
self._queries_tx_per_second = 0
self._kbps_tx = 0
self._kbps_rx = 0
self._recent_contact_count = 0
self._total_bytes_tx = 0
self._total_bytes_rx = 0
self._bandwidth_stats_update_lc = task.LoopingCall(self._update_bandwidth_stats)
def _update_bandwidth_stats(self):
recent_rx_history = {}
now = time.time()
for address, history in self._history_rx.iteritems():
recent_rx_history[address] = [(s, t) for (s, t) in history if now - t < 1.0]
qps_rx = sum(len(v) for (k, v) in recent_rx_history.iteritems())
bps_rx = sum(sum([x[0] for x in v]) for (k, v) in recent_rx_history.iteritems())
kbps_rx = round(float(bps_rx) / 1024.0, 2)
recent_tx_history = {}
now = time.time()
for address, history in self._history_tx.iteritems():
recent_tx_history[address] = [(s, t) for (s, t) in history if now - t < 1.0]
qps_tx = sum(len(v) for (k, v) in recent_tx_history.iteritems())
bps_tx = sum(sum([x[0] for x in v]) for (k, v) in recent_tx_history.iteritems())
kbps_tx = round(float(bps_tx) / 1024.0, 2)
recent_contacts = []
for k, v in recent_rx_history.iteritems():
if v:
recent_contacts.append(k)
for k, v in recent_tx_history.iteritems():
if v and k not in recent_contacts:
recent_contacts.append(k)
self._queries_rx_per_second = qps_rx
self._queries_tx_per_second = qps_tx
self._kbps_tx = kbps_tx
self._kbps_rx = kbps_rx
self._recent_contact_count = len(recent_contacts)
self._total_bytes_tx = sum(v for (k, v) in self._bytes_tx.iteritems())
self._total_bytes_rx = sum(v for (k, v) in self._bytes_rx.iteritems())
@property
def unique_contacts(self):
return self._unique_contacts
@property
def queries_rx_per_second(self):
return self._queries_rx_per_second
@property
def queries_tx_per_second(self):
return self._queries_tx_per_second
@property
def kbps_tx(self):
return self._kbps_tx
@property
def kbps_rx(self):
return self._kbps_rx
@property
def recent_contact_count(self):
return self._recent_contact_count
@property
def total_bytes_tx(self):
return self._total_bytes_tx
@property
def total_bytes_rx(self):
return self._total_bytes_rx
@property
def bandwidth_stats(self):
response = {
"kbps_received": self.kbps_rx,
"kbps_sent": self.kbps_tx,
"total_bytes_sent": self.total_bytes_tx,
"total_bytes_received": self.total_bytes_rx,
"queries_received": self.queries_rx_per_second,
"queries_sent": self.queries_tx_per_second,
"recent_contacts": self.recent_contact_count,
"unique_contacts": len(self.unique_contacts)
}
return response
def sendRPC(self, contact, method, args, rawResponse=False): def sendRPC(self, contact, method, args, rawResponse=False):
""" Sends an RPC to the specified contact """ Sends an RPC to the specified contact
@ -168,16 +69,15 @@ class KademliaProtocol(protocol.DatagramProtocol):
df._rpcRawResponse = True df._rpcRawResponse = True
# Set the RPC timeout timer # Set the RPC timeout timer
timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, msg.id) timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
# Transmit the data # Transmit the data
self._send(encodedMsg, msg.id, (contact.address, contact.port)) self._send(encodedMsg, msg.id, (contact.address, contact.port))
self._sentMessages[msg.id] = (contact.id, df, timeoutCall, method, args) self._sentMessages[msg.id] = (contact.id, df, timeoutCall, method, args)
df.addErrback(cancelTimeout)
return df return df
def startProtocol(self): def startProtocol(self):
log.info("DHT listening on UDP %i", self._node.port) log.info("DHT listening on UDP %i", self._node.port)
if not self._bandwidth_stats_update_lc.running:
self._bandwidth_stats_update_lc.start(1)
def datagramReceived(self, datagram, address): def datagramReceived(self, datagram, address):
""" Handles and parses incoming RPC messages (and responses) """ Handles and parses incoming RPC messages (and responses)
@ -206,8 +106,9 @@ class KademliaProtocol(protocol.DatagramProtocol):
try: try:
msgPrimitive = self._encoder.decode(datagram) msgPrimitive = self._encoder.decode(datagram)
message = self._translator.fromPrimitive(msgPrimitive) message = self._translator.fromPrimitive(msgPrimitive)
except (encoding.DecodeError, ValueError): except (encoding.DecodeError, ValueError) as err:
# We received some rubbish here # We received some rubbish here
log.exception("Error decoding datagram from %s:%i - %s", address[0], address[1], err)
return return
except (IndexError, KeyError): except (IndexError, KeyError):
log.warning("Couldn't decode dht datagram from %s", address) log.warning("Couldn't decode dht datagram from %s", address)
@ -215,18 +116,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
remoteContact = Contact(message.nodeID, address[0], address[1], self) remoteContact = Contact(message.nodeID, address[0], address[1], self)
now = time.time()
contact_history = self._history_rx.get(address, [])
if len(contact_history) > 1000:
contact_history = [x for x in contact_history if now - x[1] < 1.0]
contact_history.append((len(datagram), time.time()))
self._history_rx[address] = contact_history
bytes_rx = self._bytes_rx.get(address, 0)
bytes_rx += len(datagram)
self._bytes_rx[address] = bytes_rx
if address not in self.unique_contacts:
self._unique_contacts.append(address)
# Refresh the remote node's details in the local node's k-buckets # Refresh the remote node's details in the local node's k-buckets
self._node.addContact(remoteContact) self._node.addContact(remoteContact)
if isinstance(message, msgtypes.RequestMessage): if isinstance(message, msgtypes.RequestMessage):
@ -253,7 +142,17 @@ class KademliaProtocol(protocol.DatagramProtocol):
else: else:
exception_type = UnknownRemoteException exception_type = UnknownRemoteException
remoteException = exception_type(message.response) remoteException = exception_type(message.response)
log.error("Remote exception (%s): %s", address, remoteException) # this error is returned by nodes that can be contacted but have an old
# and broken version of the ping command, if they return it the node can
# be contacted, so we'll treat it as a successful ping
old_ping_error = "ping() got an unexpected keyword argument '_rpcNodeContact'"
if isinstance(remoteException, TypeError) and \
remoteException.message == old_ping_error:
log.debug("old pong error")
df.callback('pong')
else:
log.error("DHT RECV REMOTE EXCEPTION FROM %s:%i: %s", address[0],
address[1], remoteException)
df.errback(remoteException) df.errback(remoteException)
else: else:
# We got a result from the RPC # We got a result from the RPC
@ -281,18 +180,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
class (see C{kademlia.msgformat} and C{kademlia.encoding}). class (see C{kademlia.msgformat} and C{kademlia.encoding}).
""" """
now = time.time()
contact_history = self._history_tx.get(address, [])
if len(contact_history) > 1000:
contact_history = [x for x in contact_history if now - x[1] < 1.0]
contact_history.append((len(data), time.time()))
self._history_tx[address] = contact_history
bytes_tx = self._bytes_tx.get(address, 0)
bytes_tx += len(data)
self._bytes_tx[address] = bytes_tx
if address not in self.unique_contacts:
self._unique_contacts.append(address)
if len(data) > self.msgSizeLimit: if len(data) > self.msgSizeLimit:
# We have to spread the data over multiple UDP datagrams, # We have to spread the data over multiple UDP datagrams,
# and provide sequencing information # and provide sequencing information
@ -319,13 +206,9 @@ class KademliaProtocol(protocol.DatagramProtocol):
def _scheduleSendNext(self, txData, address): def _scheduleSendNext(self, txData, address):
"""Schedule the sending of the next UDP packet """ """Schedule the sending of the next UDP packet """
delay = self._delay() delayed_call, _ = self._node.reactor_callLater(0, self._write, txData, address)
key = object()
delayed_call = reactor.callLater(delay, self._write_and_remove, key, txData, address)
self._call_later_list[key] = delayed_call
def _write_and_remove(self, key, txData, address): def _write(self, txData, address):
del self._call_later_list[key]
if self.transport: if self.transport:
try: try:
self.transport.write(txData, address) self.transport.write(txData, address)
@ -333,13 +216,15 @@ class KademliaProtocol(protocol.DatagramProtocol):
if err.errno == errno.EWOULDBLOCK: if err.errno == errno.EWOULDBLOCK:
# i'm scared this may swallow important errors, but i get a million of these # i'm scared this may swallow important errors, but i get a million of these
# on Linux and it doesnt seem to affect anything -grin # on Linux and it doesnt seem to affect anything -grin
log.debug("Can't send data to dht: EWOULDBLOCK") log.warning("Can't send data to dht: EWOULDBLOCK")
elif err.errno == errno.ENETUNREACH: elif err.errno == errno.ENETUNREACH:
# this should probably try to retransmit when the network connection is back # this should probably try to retransmit when the network connection is back
log.error("Network is unreachable") log.error("Network is unreachable")
else: else:
log.error("DHT socket error: %s (%i)", err.message, err.errno) log.error("DHT socket error: %s (%i)", err.message, err.errno)
raise err raise err
else:
log.warning("transport not connected!")
def _sendResponse(self, contact, rpcID, response): def _sendResponse(self, contact, rpcID, response):
""" Send a RPC response to the specified contact """ Send a RPC response to the specified contact
@ -418,7 +303,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
# See if any progress has been made; if not, kill the message # See if any progress has been made; if not, kill the message
if self._hasProgressBeenMade(messageID): if self._hasProgressBeenMade(messageID):
# Reset the RPC timeout timer # Reset the RPC timeout timer
timeoutCall = reactor.callLater(constants.rpcTimeout, self._msgTimeout, messageID) timeoutCall, _ = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID)
self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args) self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args)
else: else:
# No progress has been made # No progress has been made
@ -443,19 +328,5 @@ class KademliaProtocol(protocol.DatagramProtocol):
Will only be called once, after all ports are disconnected. Will only be called once, after all ports are disconnected.
""" """
log.info('Stopping DHT') log.info('Stopping DHT')
CallLaterManager.stop()
if self._bandwidth_stats_update_lc.running:
self._bandwidth_stats_update_lc.stop()
for delayed_call in self._call_later_list.values():
try:
delayed_call.cancel()
except (error.AlreadyCalled, error.AlreadyCancelled):
log.debug('Attempted to cancel a DelayedCall that was not active')
except Exception:
log.exception('Failed to cancel a DelayedCall')
# not sure why this is needed, but taking this out sometimes causes
# exceptions.AttributeError: 'Port' object has no attribute 'socket'
# to happen on shutdown
# reactor.iterate()
log.info('DHT stopped') log.info('DHT stopped')

View file

@ -5,7 +5,6 @@
# The docstrings in this module contain epytext markup; API documentation # The docstrings in this module contain epytext markup; API documentation
# may be created by processing this file with epydoc: http://epydoc.sf.net # may be created by processing this file with epydoc: http://epydoc.sf.net
import time
import random import random
from zope.interface import implements from zope.interface import implements
import constants import constants
@ -20,7 +19,7 @@ log = logging.getLogger(__name__)
class TreeRoutingTable(object): class TreeRoutingTable(object):
""" This class implements a routing table used by a Node class. """ This class implements a routing table used by a Node class.
The Kademlia routing table is a binary tree whose leaves are k-buckets, The Kademlia routing table is a binary tree whFose leaves are k-buckets,
where each k-bucket contains nodes with some common prefix of their IDs. where each k-bucket contains nodes with some common prefix of their IDs.
This prefix is the k-bucket's position in the binary tree; it therefore This prefix is the k-bucket's position in the binary tree; it therefore
covers some range of ID values, and together all of the k-buckets cover covers some range of ID values, and together all of the k-buckets cover
@ -34,7 +33,7 @@ class TreeRoutingTable(object):
""" """
implements(IRoutingTable) implements(IRoutingTable)
def __init__(self, parentNodeID): def __init__(self, parentNodeID, getTime=None):
""" """
@param parentNodeID: The n-bit node ID of the node to which this @param parentNodeID: The n-bit node ID of the node to which this
routing table belongs routing table belongs
@ -43,6 +42,9 @@ class TreeRoutingTable(object):
# Create the initial (single) k-bucket covering the range of the entire n-bit ID space # Create the initial (single) k-bucket covering the range of the entire n-bit ID space
self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits)] self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits)]
self._parentNodeID = parentNodeID self._parentNodeID = parentNodeID
if not getTime:
from time import time as getTime
self._getTime = getTime
def addContact(self, contact): def addContact(self, contact):
""" Add the given contact to the correct k-bucket; if it already """ Add the given contact to the correct k-bucket; if it already
@ -194,7 +196,7 @@ class TreeRoutingTable(object):
bucketIndex = startIndex bucketIndex = startIndex
refreshIDs = [] refreshIDs = []
for bucket in self._buckets[startIndex:]: for bucket in self._buckets[startIndex:]:
if force or (int(time.time()) - bucket.lastAccessed >= constants.refreshTimeout): if force or (int(self._getTime()) - bucket.lastAccessed >= constants.refreshTimeout):
searchID = self._randomIDInBucketRange(bucketIndex) searchID = self._randomIDInBucketRange(bucketIndex)
refreshIDs.append(searchID) refreshIDs.append(searchID)
bucketIndex += 1 bucketIndex += 1
@ -221,7 +223,7 @@ class TreeRoutingTable(object):
@type key: str @type key: str
""" """
bucketIndex = self._kbucketIndex(key) bucketIndex = self._kbucketIndex(key)
self._buckets[bucketIndex].lastAccessed = int(time.time()) self._buckets[bucketIndex].lastAccessed = int(self._getTime())
def _kbucketIndex(self, key): def _kbucketIndex(self, key):
""" Calculate the index of the k-bucket which is responsible for the """ Calculate the index of the k-bucket which is responsible for the
@ -289,8 +291,8 @@ class OptimizedTreeRoutingTable(TreeRoutingTable):
of the 13-page version of the Kademlia paper. of the 13-page version of the Kademlia paper.
""" """
def __init__(self, parentNodeID): def __init__(self, parentNodeID, getTime=None):
TreeRoutingTable.__init__(self, parentNodeID) TreeRoutingTable.__init__(self, parentNodeID, getTime)
# Cache containing nodes eligible to replace stale k-bucket entries # Cache containing nodes eligible to replace stale k-bucket entries
self._replacementCache = {} self._replacementCache = {}
@ -301,6 +303,7 @@ class OptimizedTreeRoutingTable(TreeRoutingTable):
@param contact: The contact to add to this node's k-buckets @param contact: The contact to add to this node's k-buckets
@type contact: kademlia.contact.Contact @type contact: kademlia.contact.Contact
""" """
if contact.id == self._parentNodeID: if contact.id == self._parentNodeID:
return return

View file

@ -0,0 +1,274 @@
import time
import logging
from twisted.trial import unittest
from twisted.internet import defer, threads, task
from lbrynet.dht.node import Node
from lbrynet.tests import mocks
from lbrynet.core.utils import generate_id
log = logging.getLogger("lbrynet.tests.util")
# log.addHandler(logging.StreamHandler())
# log.setLevel(logging.DEBUG)
class TestKademliaBase(unittest.TestCase):
timeout = 300.0 # timeout for each test
network_size = 0 # plus lbrynet1, lbrynet2, and lbrynet3 seed nodes
node_ids = None
seed_dns = mocks.MOCK_DHT_SEED_DNS
def _add_next_node(self):
node_id, node_ip = self.mock_node_generator.next()
node = Node(node_id=node_id.decode('hex'), udpPort=4444, peerPort=3333, externalIP=node_ip,
resolve=mocks.resolve, listenUDP=mocks.listenUDP, callLater=self.clock.callLater, clock=self.clock)
self.nodes.append(node)
return node
@defer.inlineCallbacks
def add_node(self):
node = self._add_next_node()
yield node.joinNetwork(
[
("lbrynet1.lbry.io", self._seeds[0].port),
("lbrynet2.lbry.io", self._seeds[1].port),
("lbrynet3.lbry.io", self._seeds[2].port),
]
)
defer.returnValue(node)
def get_node(self, node_id):
for node in self.nodes:
if node.node_id == node_id:
return node
raise KeyError(node_id)
@defer.inlineCallbacks
def pop_node(self):
node = self.nodes.pop()
yield node.stop()
def pump_clock(self, n, step=0.01):
"""
:param n: seconds to run the reactor for
:param step: reactor tick rate (in seconds)
"""
for _ in range(n * 100):
self.clock.advance(step)
def run_reactor(self, seconds, *deferreds):
dl = [threads.deferToThread(self.pump_clock, seconds)]
for d in deferreds:
dl.append(d)
return defer.DeferredList(dl)
@defer.inlineCallbacks
def setUp(self):
self.nodes = []
self._seeds = []
self.clock = task.Clock()
self.mock_node_generator = mocks.mock_node_generator(mock_node_ids=self.node_ids)
join_dl = []
for seed_dns in self.seed_dns:
other_seeds = list(self.seed_dns.keys())
other_seeds.remove(seed_dns)
self._add_next_node()
seed = self.nodes.pop()
self._seeds.append(seed)
join_dl.append(
seed.joinNetwork([(other_seed_dns, 4444) for other_seed_dns in other_seeds])
)
if self.network_size:
for _ in range(self.network_size):
join_dl.append(self.add_node())
yield self.run_reactor(1, *tuple(join_dl))
self.verify_all_nodes_are_routable()
@defer.inlineCallbacks
def tearDown(self):
dl = []
while self.nodes:
dl.append(self.pop_node()) # stop all of the nodes
while self._seeds:
dl.append(self._seeds.pop().stop()) # and the seeds
yield defer.DeferredList(dl)
def verify_all_nodes_are_routable(self):
routable = set()
node_addresses = {node.externalIP for node in self.nodes}
node_addresses = node_addresses.union({node.externalIP for node in self._seeds})
for node in self._seeds:
contact_addresses = {contact.address for contact in node.contacts}
routable.update(contact_addresses)
for node in self.nodes:
contact_addresses = {contact.address for contact in node.contacts}
routable.update(contact_addresses)
self.assertSetEqual(routable, node_addresses)
@defer.inlineCallbacks
def verify_all_nodes_are_pingable(self):
ping_replies = {}
ping_dl = []
contacted = set()
def _ping_cb(result, node, replies):
replies[node] = result
for node in self._seeds:
contact_addresses = set()
for contact in node.contacts:
contact_addresses.add(contact.address)
d = contact.ping()
d.addCallback(_ping_cb, contact.address, ping_replies)
contacted.add(contact.address)
ping_dl.append(d)
for node in self.nodes:
contact_addresses = set()
for contact in node.contacts:
contact_addresses.add(contact.address)
d = contact.ping()
d.addCallback(_ping_cb, contact.address, ping_replies)
contacted.add(contact.address)
ping_dl.append(d)
self.run_reactor(2, *ping_dl)
yield threads.deferToThread(time.sleep, 0.1)
node_addresses = {node.externalIP for node in self.nodes}.union({seed.externalIP for seed in self._seeds})
self.assertSetEqual(node_addresses, contacted)
self.assertDictEqual(ping_replies, {node: "pong" for node in contacted})
class TestKademliaBootstrap(TestKademliaBase):
"""
Test initializing the network / connecting the seed nodes
"""
def test_bootstrap_network(self): # simulates the real network, which has three seeds
self.assertEqual(len(self._seeds[0].contacts), 2)
self.assertEqual(len(self._seeds[1].contacts), 2)
self.assertEqual(len(self._seeds[2].contacts), 2)
self.assertSetEqual(
{self._seeds[0].contacts[0].address, self._seeds[0].contacts[1].address},
{self._seeds[1].externalIP, self._seeds[2].externalIP}
)
self.assertSetEqual(
{self._seeds[1].contacts[0].address, self._seeds[1].contacts[1].address},
{self._seeds[0].externalIP, self._seeds[2].externalIP}
)
self.assertSetEqual(
{self._seeds[2].contacts[0].address, self._seeds[2].contacts[1].address},
{self._seeds[0].externalIP, self._seeds[1].externalIP}
)
def test_all_nodes_are_pingable(self):
return self.verify_all_nodes_are_pingable()
class TestKademliaBootstrapSixteenSeeds(TestKademliaBase):
node_ids = [
'000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000',
'111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111',
'222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222',
'333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333',
'444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444',
'555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555',
'666666666666666666666666666666666666666666666666666666666666666666666666666666666666666666666666',
'777777777777777777777777777777777777777777777777777777777777777777777777777777777777777777777777',
'888888888888888888888888888888888888888888888888888888888888888888888888888888888888888888888888',
'999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999',
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
'bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
'cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc',
'dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd',
'eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff'
]
@defer.inlineCallbacks
def setUp(self):
self.seed_dns.update(
{
"lbrynet4.lbry.io": "10.42.42.4",
"lbrynet5.lbry.io": "10.42.42.5",
"lbrynet6.lbry.io": "10.42.42.6",
"lbrynet7.lbry.io": "10.42.42.7",
"lbrynet8.lbry.io": "10.42.42.8",
"lbrynet9.lbry.io": "10.42.42.9",
"lbrynet10.lbry.io": "10.42.42.10",
"lbrynet11.lbry.io": "10.42.42.11",
"lbrynet12.lbry.io": "10.42.42.12",
"lbrynet13.lbry.io": "10.42.42.13",
"lbrynet14.lbry.io": "10.42.42.14",
"lbrynet15.lbry.io": "10.42.42.15",
"lbrynet16.lbry.io": "10.42.42.16",
}
)
yield TestKademliaBase.setUp(self)
@defer.inlineCallbacks
def tearDown(self):
yield TestKademliaBase.tearDown(self)
def test_bootstrap_network(self):
pass
def _test_all_nodes_are_pingable(self):
return self.verify_all_nodes_are_pingable()
class Test250NodeNetwork(TestKademliaBase):
network_size = 250
def test_setup_network_and_verify_connectivity(self):
pass
def update_network(self):
import random
dl = []
announced_blobs = []
for node in self.nodes: # random events
if random.randint(0, 10000) < 75 and announced_blobs: # get peers for a blob
log.info('find blob')
blob_hash = random.choice(announced_blobs)
dl.append(node.getPeersForBlob(blob_hash))
if random.randint(0, 10000) < 25: # announce a blob
log.info('announce blob')
blob_hash = generate_id()
announced_blobs.append((blob_hash, node.node_id))
dl.append(node.announceHaveBlob(blob_hash))
random.shuffle(self.nodes)
# kill nodes
while random.randint(0, 100) > 95:
dl.append(self.pop_node())
log.info('pop node')
# add nodes
while random.randint(0, 100) > 95:
dl.append(self.add_node())
log.info('add node')
return tuple(dl), announced_blobs
@defer.inlineCallbacks
def _test_simulate_network(self):
total_blobs = []
for i in range(100):
d, blobs = self.update_network()
total_blobs.extend(blobs)
self.run_reactor(1, *d)
yield threads.deferToThread(time.sleep, 0.1)
routable = set()
node_addresses = {node.externalIP for node in self.nodes}
for node in self.nodes:
contact_addresses = {contact.address for contact in node.contacts}
routable.update(contact_addresses)
log.warning("difference: %i", len(node_addresses.difference(routable)))
log.info("blobs %i", len(total_blobs))
log.info("step %i, %i nodes", i, len(self.nodes))
self.pump_clock(100)

View file

@ -4,8 +4,7 @@ from twisted.trial import unittest
from lbrynet import conf from lbrynet import conf
from lbrynet.core.StreamDescriptor import get_sd_info from lbrynet.core.StreamDescriptor import get_sd_info
from lbrynet import reflector from lbrynet import reflector
from lbrynet.core import BlobManager from lbrynet.core import BlobManager, PeerManager
from lbrynet.core import PeerManager
from lbrynet.core import Session from lbrynet.core import Session
from lbrynet.core import StreamDescriptor from lbrynet.core import StreamDescriptor
from lbrynet.lbry_file.client import EncryptedFileOptions from lbrynet.lbry_file.client import EncryptedFileOptions
@ -15,6 +14,8 @@ from lbrynet.file_manager import EncryptedFileManager
from lbrynet.tests import mocks from lbrynet.tests import mocks
from lbrynet.tests.util import mk_db_and_blob_dir, rm_db_and_blob_dir from lbrynet.tests.util import mk_db_and_blob_dir, rm_db_and_blob_dir
from lbrynet.tests.mocks import Node
class TestReflector(unittest.TestCase): class TestReflector(unittest.TestCase):
def setUp(self): def setUp(self):
@ -28,7 +29,6 @@ class TestReflector(unittest.TestCase):
wallet = mocks.Wallet() wallet = mocks.Wallet()
peer_manager = PeerManager.PeerManager() peer_manager = PeerManager.PeerManager()
peer_finder = mocks.PeerFinder(5553, peer_manager, 2) peer_finder = mocks.PeerFinder(5553, peer_manager, 2)
hash_announcer = mocks.Announcer()
sd_identifier = StreamDescriptor.StreamDescriptorIdentifier() sd_identifier = StreamDescriptor.StreamDescriptorIdentifier()
self.expected_blobs = [ self.expected_blobs = [
@ -55,13 +55,14 @@ class TestReflector(unittest.TestCase):
db_dir=self.db_dir, db_dir=self.db_dir,
node_id="abcd", node_id="abcd",
peer_finder=peer_finder, peer_finder=peer_finder,
hash_announcer=hash_announcer,
blob_dir=self.blob_dir, blob_dir=self.blob_dir,
peer_port=5553, peer_port=5553,
use_upnp=False, use_upnp=False,
wallet=wallet, wallet=wallet,
blob_tracker_class=mocks.BlobAvailabilityTracker, blob_tracker_class=mocks.BlobAvailabilityTracker,
external_ip="127.0.0.1" external_ip="127.0.0.1",
dht_node_class=Node,
hash_announcer=mocks.Announcer()
) )
self.lbry_file_manager = EncryptedFileManager.EncryptedFileManager(self.session, self.lbry_file_manager = EncryptedFileManager.EncryptedFileManager(self.session,
@ -74,17 +75,17 @@ class TestReflector(unittest.TestCase):
db_dir=self.server_db_dir, db_dir=self.server_db_dir,
node_id="abcd", node_id="abcd",
peer_finder=peer_finder, peer_finder=peer_finder,
hash_announcer=hash_announcer,
blob_dir=self.server_blob_dir, blob_dir=self.server_blob_dir,
peer_port=5553, peer_port=5553,
use_upnp=False, use_upnp=False,
wallet=wallet, wallet=wallet,
blob_tracker_class=mocks.BlobAvailabilityTracker, blob_tracker_class=mocks.BlobAvailabilityTracker,
external_ip="127.0.0.1" external_ip="127.0.0.1",
dht_node_class=Node,
hash_announcer=mocks.Announcer()
) )
self.server_blob_manager = BlobManager.DiskBlobManager(hash_announcer, self.server_blob_manager = BlobManager.DiskBlobManager(self.server_blob_dir,
self.server_blob_dir,
self.server_session.storage) self.server_session.storage)
self.server_lbry_file_manager = EncryptedFileManager.EncryptedFileManager( self.server_lbry_file_manager = EncryptedFileManager.EncryptedFileManager(
@ -364,6 +365,7 @@ class TestReflector(unittest.TestCase):
d.addCallback(lambda _: verify_stream_on_reflector()) d.addCallback(lambda _: verify_stream_on_reflector())
return d return d
def iv_generator(): def iv_generator():
iv = 0 iv = 0
while True: while True:

View file

@ -67,7 +67,7 @@ class TestStreamify(TestCase):
blob_dir=self.blob_dir, peer_port=5553, blob_dir=self.blob_dir, peer_port=5553,
use_upnp=False, rate_limiter=rate_limiter, wallet=wallet, use_upnp=False, rate_limiter=rate_limiter, wallet=wallet,
blob_tracker_class=DummyBlobAvailabilityTracker, blob_tracker_class=DummyBlobAvailabilityTracker,
is_generous=self.is_generous, external_ip="127.0.0.1" is_generous=self.is_generous, external_ip="127.0.0.1", dht_node_class=mocks.Node
) )
self.lbry_file_manager = EncryptedFileManager(self.session, sd_identifier) self.lbry_file_manager = EncryptedFileManager(self.session, sd_identifier)
@ -112,7 +112,7 @@ class TestStreamify(TestCase):
self.session = Session( self.session = Session(
conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=self.db_dir, node_id="abcd", conf.ADJUSTABLE_SETTINGS['data_rate'][1], db_dir=self.db_dir, node_id="abcd",
peer_finder=peer_finder, hash_announcer=hash_announcer, peer_finder=peer_finder, hash_announcer=hash_announcer,
blob_dir=self.blob_dir, peer_port=5553, blob_dir=self.blob_dir, peer_port=5553, dht_node_class=mocks.Node,
use_upnp=False, rate_limiter=rate_limiter, wallet=wallet, use_upnp=False, rate_limiter=rate_limiter, wallet=wallet,
blob_tracker_class=DummyBlobAvailabilityTracker, external_ip="127.0.0.1" blob_tracker_class=DummyBlobAvailabilityTracker, external_ip="127.0.0.1"
) )

View file

@ -1,12 +1,17 @@
import struct
import io import io
from Crypto.PublicKey import RSA from Crypto.PublicKey import RSA
from twisted.internet import defer from twisted.internet import defer, error
from twisted.python.failure import Failure
from lbrynet.core import PTCWallet from lbrynet.core.client.ClientRequest import ClientRequest
from lbrynet.core.Error import RequestCanceledError
from lbrynet.core import BlobAvailability from lbrynet.core import BlobAvailability
from lbrynet.core.utils import generate_id
from lbrynet.daemon import ExchangeRateManager as ERM from lbrynet.daemon import ExchangeRateManager as ERM
from lbrynet import conf from lbrynet import conf
from util import debug_kademlia_packet
KB = 2**10 KB = 2**10
@ -18,15 +23,18 @@ class FakeLBRYFile(object):
self.stream_hash = stream_hash self.stream_hash = stream_hash
self.file_name = 'fake_lbry_file' self.file_name = 'fake_lbry_file'
class Node(object): class Node(object):
def __init__(self, *args, **kwargs): def __init__(self, peer_finder=None, peer_manager=None, **kwargs):
pass self.peer_finder = peer_finder
self.peer_manager = peer_manager
self.peerPort = 3333
def joinNetwork(self, *args): def joinNetwork(self, *args):
pass return defer.succeed(True)
def stop(self): def stop(self):
pass return defer.succeed(None)
class FakeNetwork(object): class FakeNetwork(object):
@ -69,6 +77,82 @@ class ExchangeRateManager(ERM.ExchangeRateManager):
feed.market, rates[feed.market]['spot'], rates[feed.market]['ts']) feed.market, rates[feed.market]['spot'], rates[feed.market]['ts'])
class PointTraderKeyExchanger(object):
def __init__(self, wallet):
self.wallet = wallet
self._protocols = []
def send_next_request(self, peer, protocol):
if not protocol in self._protocols:
r = ClientRequest({'public_key': self.wallet.encoded_public_key},
'public_key')
d = protocol.add_request(r)
d.addCallback(self._handle_exchange_response, peer, r, protocol)
d.addErrback(self._request_failed, peer)
self._protocols.append(protocol)
return defer.succeed(True)
else:
return defer.succeed(False)
def _handle_exchange_response(self, response_dict, peer, request, protocol):
assert request.response_identifier in response_dict, \
"Expected %s in dict but did not get it" % request.response_identifier
assert protocol in self._protocols, "Responding protocol is not in our list of protocols"
peer_pub_key = response_dict[request.response_identifier]
self.wallet.set_public_key_for_peer(peer, peer_pub_key)
return True
def _request_failed(self, err, peer):
if not err.check(RequestCanceledError):
return err
class PointTraderKeyQueryHandlerFactory(object):
def __init__(self, wallet):
self.wallet = wallet
def build_query_handler(self):
q_h = PointTraderKeyQueryHandler(self.wallet)
return q_h
def get_primary_query_identifier(self):
return 'public_key'
def get_description(self):
return ("Point Trader Address - an address for receiving payments on the "
"point trader testing network")
class PointTraderKeyQueryHandler(object):
def __init__(self, wallet):
self.wallet = wallet
self.query_identifiers = ['public_key']
self.public_key = None
self.peer = None
def register_with_request_handler(self, request_handler, peer):
self.peer = peer
request_handler.register_query_handler(self, self.query_identifiers)
def handle_queries(self, queries):
if self.query_identifiers[0] in queries:
new_encoded_pub_key = queries[self.query_identifiers[0]]
try:
RSA.importKey(new_encoded_pub_key)
except (ValueError, TypeError, IndexError):
return defer.fail(Failure(ValueError("Client sent an invalid public key")))
self.public_key = new_encoded_pub_key
self.wallet.set_public_key_for_peer(self.peer, self.public_key)
fields = {'public_key': self.wallet.encoded_public_key}
return defer.succeed(fields)
if self.public_key is None:
return defer.fail(Failure(ValueError("Expected but did not receive a public key")))
else:
return defer.succeed({})
class Wallet(object): class Wallet(object):
def __init__(self): def __init__(self):
@ -94,10 +178,10 @@ class Wallet(object):
return defer.succeed(True) return defer.succeed(True)
def get_info_exchanger(self): def get_info_exchanger(self):
return PTCWallet.PointTraderKeyExchanger(self) return PointTraderKeyExchanger(self)
def get_wallet_info_query_handler_factory(self): def get_wallet_info_query_handler_factory(self):
return PTCWallet.PointTraderKeyQueryHandlerFactory(self) return PointTraderKeyQueryHandlerFactory(self)
def reserve_points(self, *args): def reserve_points(self, *args):
return True return True
@ -164,6 +248,9 @@ class Announcer(object):
def stop(self): def stop(self):
pass pass
def get_next_announce_time(self):
return 0
class GenFile(io.RawIOBase): class GenFile(io.RawIOBase):
def __init__(self, size, pattern): def __init__(self, size, pattern):
@ -313,4 +400,87 @@ def mock_conf_settings(obj, settings={}):
obj.addCleanup(_reset_settings) obj.addCleanup(_reset_settings)
MOCK_DHT_NODES = [
"000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
"DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF",
]
MOCK_DHT_SEED_DNS = { # these map to mock nodes 0, 1, and 2
"lbrynet1.lbry.io": "10.42.42.1",
"lbrynet2.lbry.io": "10.42.42.2",
"lbrynet3.lbry.io": "10.42.42.3",
}
def resolve(name, timeout=(1, 3, 11, 45)):
if name not in MOCK_DHT_SEED_DNS:
return defer.fail(error.DNSLookupError(name))
return defer.succeed(MOCK_DHT_SEED_DNS[name])
class MockUDPTransport(object):
def __init__(self, address, port, max_packet_size, protocol):
self.address = address
self.port = port
self.max_packet_size = max_packet_size
self._node = protocol._node
def write(self, data, address):
dest = MockNetwork.peers[address][0]
debug_kademlia_packet(data, (self.address, self.port), address, self._node)
dest.datagramReceived(data, (self.address, self.port))
class MockUDPPort(object):
def __init__(self, protocol):
self.protocol = protocol
def startListening(self, reason=None):
return self.protocol.startProtocol()
def stopListening(self, reason=None):
return self.protocol.stopProtocol()
class MockNetwork(object):
peers = {} # (interface, port): (protocol, max_packet_size)
@classmethod
def add_peer(cls, port, protocol, interface, maxPacketSize):
interface = protocol._node.externalIP
protocol.transport = MockUDPTransport(interface, port, maxPacketSize, protocol)
cls.peers[(interface, port)] = (protocol, maxPacketSize)
def listenUDP(port, protocol, interface='', maxPacketSize=8192):
MockNetwork.add_peer(port, protocol, interface, maxPacketSize)
return MockUDPPort(protocol)
def address_generator(address=(10, 42, 42, 1)):
def increment(addr):
value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]))[0] + 1
new_addr = []
for i in range(4):
new_addr.append(value % 256)
value >>= 8
return tuple(new_addr[::-1])
while True:
yield "{}.{}.{}.{}".format(*address)
address = increment(address)
def mock_node_generator(count=None, mock_node_ids=MOCK_DHT_NODES):
if mock_node_ids is None:
mock_node_ids = MOCK_DHT_NODES
for num, node_ip in enumerate(address_generator()):
if count and num >= count:
break
if num >= len(mock_node_ids):
node_id = generate_id().encode('hex')
else:
node_id = mock_node_ids[num]
yield (node_id, node_ip)

View file

@ -1,55 +1,82 @@
import tempfile
import shutil
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer, task from twisted.internet import defer, reactor, threads
from lbrynet.core import utils
from lbrynet.tests.util import random_lbry_hash from lbrynet.tests.util import random_lbry_hash
from lbrynet.dht.hashannouncer import DHTHashAnnouncer
from lbrynet.core.call_later_manager import CallLaterManager
from lbrynet.database.storage import SQLiteStorage
class MocDHTNode(object): class MocDHTNode(object):
def __init__(self): def __init__(self, announce_will_fail=False):
# if announce_will_fail is True,
# announceHaveBlob will return empty dict
self.call_later_manager = CallLaterManager
self.call_later_manager.setup(reactor.callLater)
self.blobs_announced = 0 self.blobs_announced = 0
self.announce_will_fail = announce_will_fail
def announceHaveBlob(self, blob): def announceHaveBlob(self, blob):
self.blobs_announced += 1 if self.announce_will_fail:
return defer.succeed(True) return_val = {}
class MocSupplier(object):
def __init__(self, blobs_to_announce):
self.blobs_to_announce = blobs_to_announce
self.announced = False
def hashes_to_announce(self):
if not self.announced:
self.announced = True
return defer.succeed(self.blobs_to_announce)
else: else:
return defer.succeed([]) return_val = {blob: ["ab"*48]}
self.blobs_announced += 1
d = defer.Deferred()
self.call_later_manager.call_later(1, d.callback, return_val)
return d
class DHTHashAnnouncerTest(unittest.TestCase): class DHTHashAnnouncerTest(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self): def setUp(self):
from lbrynet.conf import initialize_settings
initialize_settings()
self.num_blobs = 10 self.num_blobs = 10
self.blobs_to_announce = [] self.blobs_to_announce = []
for i in range(0, self.num_blobs): for i in range(0, self.num_blobs):
self.blobs_to_announce.append(random_lbry_hash()) self.blobs_to_announce.append(random_lbry_hash())
self.clock = task.Clock()
self.dht_node = MocDHTNode() self.dht_node = MocDHTNode()
utils.call_later = self.clock.callLater self.dht_node.peerPort = 3333
from lbrynet.core.server.DHTHashAnnouncer import DHTHashAnnouncer self.dht_node.clock = reactor
self.announcer = DHTHashAnnouncer(self.dht_node, peer_port=3333) self.db_dir = tempfile.mkdtemp()
self.supplier = MocSupplier(self.blobs_to_announce) self.storage = SQLiteStorage(self.db_dir)
self.announcer.add_supplier(self.supplier) yield self.storage.setup()
self.announcer = DHTHashAnnouncer(self.dht_node, self.storage, 10)
for blob_hash in self.blobs_to_announce:
yield self.storage.add_completed_blob(blob_hash, 100, 0, 1)
@defer.inlineCallbacks
def tearDown(self):
self.dht_node.call_later_manager.stop()
yield self.storage.stop()
yield threads.deferToThread(shutil.rmtree, self.db_dir)
@defer.inlineCallbacks
def test_announce_fail(self):
# test what happens when node.announceHaveBlob() returns empty dict
self.dht_node.announce_will_fail = True
d = yield self.announcer.manage()
yield d
@defer.inlineCallbacks
def test_basic(self): def test_basic(self):
self.announcer._announce_available_hashes() d = self.announcer.immediate_announce(self.blobs_to_announce)
self.assertEqual(self.announcer.hash_queue_size(), self.announcer.CONCURRENT_ANNOUNCERS) self.assertEqual(len(self.announcer.hash_queue), self.num_blobs)
self.clock.advance(1) yield d
self.assertEqual(self.dht_node.blobs_announced, self.num_blobs) self.assertEqual(self.dht_node.blobs_announced, self.num_blobs)
self.assertEqual(self.announcer.hash_queue_size(), 0) self.assertEqual(len(self.announcer.hash_queue), 0)
@defer.inlineCallbacks
def test_immediate_announce(self): def test_immediate_announce(self):
# Test that immediate announce puts a hash at the front of the queue # Test that immediate announce puts a hash at the front of the queue
self.announcer._announce_available_hashes() d = self.announcer.immediate_announce(self.blobs_to_announce)
self.assertEqual(len(self.announcer.hash_queue), self.num_blobs)
blob_hash = random_lbry_hash() blob_hash = random_lbry_hash()
self.announcer.immediate_announce([blob_hash]) self.announcer.immediate_announce([blob_hash])
self.assertEqual(self.announcer.hash_queue_size(), self.announcer.CONCURRENT_ANNOUNCERS+1) self.assertEqual(len(self.announcer.hash_queue), self.num_blobs+1)
self.assertEqual(blob_hash, self.announcer.hash_queue[0][0]) self.assertEqual(blob_hash, self.announcer.hash_queue[-1])
yield d

View file

@ -8,7 +8,6 @@ from twisted.internet import defer, threads
from lbrynet.tests.util import random_lbry_hash from lbrynet.tests.util import random_lbry_hash
from lbrynet.core.BlobManager import DiskBlobManager from lbrynet.core.BlobManager import DiskBlobManager
from lbrynet.core.HashAnnouncer import DummyHashAnnouncer
from lbrynet.database.storage import SQLiteStorage from lbrynet.database.storage import SQLiteStorage
from lbrynet.core.Peer import Peer from lbrynet.core.Peer import Peer
from lbrynet import conf from lbrynet import conf
@ -21,8 +20,7 @@ class BlobManagerTest(unittest.TestCase):
conf.initialize_settings() conf.initialize_settings()
self.blob_dir = tempfile.mkdtemp() self.blob_dir = tempfile.mkdtemp()
self.db_dir = tempfile.mkdtemp() self.db_dir = tempfile.mkdtemp()
hash_announcer = DummyHashAnnouncer() self.bm = DiskBlobManager(self.blob_dir, SQLiteStorage(self.db_dir))
self.bm = DiskBlobManager(hash_announcer, self.blob_dir, SQLiteStorage(self.db_dir))
self.peer = Peer('somehost', 22) self.peer = Peer('somehost', 22)
yield self.bm.storage.setup() yield self.bm.storage.setup()

View file

@ -195,7 +195,7 @@ class StreamStorageTests(StorageTest):
should_announce_count = yield self.storage.count_should_announce_blobs() should_announce_count = yield self.storage.count_should_announce_blobs()
self.assertEqual(should_announce_count, 2) self.assertEqual(should_announce_count, 2)
should_announce_hashes = yield self.storage.get_blobs_to_announce(FakeAnnouncer()) should_announce_hashes = yield self.storage.get_blobs_to_announce()
self.assertSetEqual(set(should_announce_hashes), {sd_hash, blob1}) self.assertSetEqual(set(should_announce_hashes), {sd_hash, blob1})
stream_hashes = yield self.storage.get_all_streams() stream_hashes = yield self.storage.get_all_streams()

View file

@ -198,7 +198,7 @@ class NodeLookupTest(unittest.TestCase):
h = hashlib.sha384() h = hashlib.sha384()
h.update('node1') h.update('node1')
node_id = str(h.digest()) node_id = str(h.digest())
self.node = lbrynet.dht.node.Node(node_id, 4000, None, None, self._protocol) self.node = lbrynet.dht.node.Node(node_id=node_id, udpPort=4000, networkProtocol=self._protocol)
self.updPort = 81173 self.updPort = 81173
self.contactsAmount = 80 self.contactsAmount = 80
# Reinitialise the routing table # Reinitialise the routing table

View file

@ -1,32 +1,55 @@
import time import time
import unittest import unittest
import twisted.internet.selectreactor from twisted.internet.task import Clock
from twisted.internet import defer
import lbrynet.dht.protocol import lbrynet.dht.protocol
import lbrynet.dht.contact import lbrynet.dht.contact
import lbrynet.dht.constants import lbrynet.dht.constants
import lbrynet.dht.msgtypes import lbrynet.dht.msgtypes
from lbrynet.dht.error import TimeoutError from lbrynet.dht.error import TimeoutError
from lbrynet.dht.node import Node, rpcmethod from lbrynet.dht.node import Node, rpcmethod
from lbrynet.tests.mocks import listenUDP, resolve
from lbrynet.core.call_later_manager import CallLaterManager
import logging
log = logging.getLogger()
class KademliaProtocolTest(unittest.TestCase): class KademliaProtocolTest(unittest.TestCase):
""" Test case for the Protocol class """ """ Test case for the Protocol class """
def setUp(self): udpPort = 9182
del lbrynet.dht.protocol.reactor
lbrynet.dht.protocol.reactor = twisted.internet.selectreactor.SelectReactor()
self.node = Node(node_id='1' * 48, udpPort=9182, externalIP="127.0.0.1")
self.protocol = lbrynet.dht.protocol.KademliaProtocol(self.node)
def setUp(self):
self._reactor = Clock()
CallLaterManager.setup(self._reactor.callLater)
self.node = Node(node_id='1' * 48, udpPort=self.udpPort, externalIP="127.0.0.1", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
def tearDown(self):
CallLaterManager.stop()
del self._reactor
@defer.inlineCallbacks
def testReactor(self): def testReactor(self):
""" Tests if the reactor can start/stop the protocol correctly """ """ Tests if the reactor can start/stop the protocol correctly """
lbrynet.dht.protocol.reactor.listenUDP(0, self.protocol)
lbrynet.dht.protocol.reactor.callLater(0, lbrynet.dht.protocol.reactor.stop) d = defer.Deferred()
lbrynet.dht.protocol.reactor.run() self._reactor.callLater(1, d.callback, True)
self._reactor.advance(1)
result = yield d
self.assertTrue(result)
def testRPCTimeout(self): def testRPCTimeout(self):
""" Tests if a RPC message sent to a dead remote node times out correctly """ """ Tests if a RPC message sent to a dead remote node times out correctly """
dead_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
dead_node.start_listening()
dead_node.stop()
self._reactor.pump([1 for _ in range(10)])
dead_contact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(dead_contact)
@rpcmethod @rpcmethod
def fake_ping(*args, **kwargs): def fake_ping(*args, **kwargs):
@ -38,19 +61,18 @@ class KademliaProtocolTest(unittest.TestCase):
real_attempts = lbrynet.dht.constants.rpcAttempts real_attempts = lbrynet.dht.constants.rpcAttempts
lbrynet.dht.constants.rpcAttempts = 1 lbrynet.dht.constants.rpcAttempts = 1
lbrynet.dht.constants.rpcTimeout = 1 lbrynet.dht.constants.rpcTimeout = 1
self.node.ping = fake_ping self.node.ping = fake_ping
deadContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.1', 9182, self.protocol)
self.node.addContact(deadContact)
# Make sure the contact was added # Make sure the contact was added
self.failIf(deadContact not in self.node.contacts, self.failIf(dead_contact not in self.node.contacts,
'Contact not added to fake node (error in test code)') 'Contact not added to fake node (error in test code)')
lbrynet.dht.protocol.reactor.listenUDP(9182, self.protocol) self.node.start_listening()
# Run the PING RPC (which should raise a timeout error) # Run the PING RPC (which should raise a timeout error)
df = self.protocol.sendRPC(deadContact, 'ping', {}) df = self.node._protocol.sendRPC(dead_contact, 'ping', {})
def check_timeout(err): def check_timeout(err):
self.assertEqual(type(err), TimeoutError) self.assertEqual(err.type, TimeoutError)
df.addErrback(check_timeout) df.addErrback(check_timeout)
@ -61,20 +83,24 @@ class KademliaProtocolTest(unittest.TestCase):
# See if the contact was removed due to the timeout # See if the contact was removed due to the timeout
def check_removed_contact(): def check_removed_contact():
self.failIf(deadContact in self.node.contacts, self.failIf(dead_contact in self.node.contacts,
'Contact was not removed after RPC timeout; check exception types.') 'Contact was not removed after RPC timeout; check exception types.')
df.addCallback(lambda _: reset_values()) df.addCallback(lambda _: reset_values())
# Stop the reactor if a result arrives (timeout or not) # Stop the reactor if a result arrives (timeout or not)
df.addBoth(lambda _: lbrynet.dht.protocol.reactor.stop())
df.addCallback(lambda _: check_removed_contact()) df.addCallback(lambda _: check_removed_contact())
lbrynet.dht.protocol.reactor.run() self._reactor.pump([1 for _ in range(20)])
def testRPCRequest(self): def testRPCRequest(self):
""" Tests if a valid RPC request is executed and responded to correctly """ """ Tests if a valid RPC request is executed and responded to correctly """
remoteContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.1', 9182, self.protocol)
remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
remote_node.start_listening()
remoteContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(remoteContact) self.node.addContact(remoteContact)
self.error = None self.error = None
def handleError(f): def handleError(f):
@ -87,16 +113,18 @@ class KademliaProtocolTest(unittest.TestCase):
% (expectedResult, result) % (expectedResult, result)
# Publish the "local" node on the network # Publish the "local" node on the network
lbrynet.dht.protocol.reactor.listenUDP(9182, self.protocol) self.node.start_listening()
# Simulate the RPC # Simulate the RPC
df = remoteContact.ping() df = remoteContact.ping()
df.addCallback(handleResult) df.addCallback(handleResult)
df.addErrback(handleError) df.addErrback(handleError)
df.addBoth(lambda _: lbrynet.dht.protocol.reactor.stop())
lbrynet.dht.protocol.reactor.run() for _ in range(10):
self._reactor.advance(1)
self.failIf(self.error, self.error) self.failIf(self.error, self.error)
# The list of sent RPC messages should be empty at this stage # The list of sent RPC messages should be empty at this stage
self.failUnlessEqual(len(self.protocol._sentMessages), 0, self.failUnlessEqual(len(self.node._protocol._sentMessages), 0,
'The protocol is still waiting for a RPC result, ' 'The protocol is still waiting for a RPC result, '
'but the transaction is already done!') 'but the transaction is already done!')
@ -105,8 +133,12 @@ class KademliaProtocolTest(unittest.TestCase):
Verifies that a RPC request for an existing but unpublished Verifies that a RPC request for an existing but unpublished
method is denied, and that the associated (remote) exception gets method is denied, and that the associated (remote) exception gets
raised locally """ raised locally """
remoteContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.1', 9182, self.protocol) remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
self.node.addContact(remoteContact) resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
remote_node.start_listening()
remote_contact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(remote_contact)
self.error = None self.error = None
def handleError(f): def handleError(f):
@ -123,24 +155,26 @@ class KademliaProtocolTest(unittest.TestCase):
self.error = 'The remote method executed successfully, returning: "%s"; ' \ self.error = 'The remote method executed successfully, returning: "%s"; ' \
'this RPC should not have been allowed.' % result 'this RPC should not have been allowed.' % result
# Publish the "local" node on the network self.node.start_listening()
lbrynet.dht.protocol.reactor.listenUDP(9182, self.protocol) self._reactor.pump([1 for _ in range(10)])
# Simulate the RPC # Simulate the RPC
df = remoteContact.not_a_rpc_function() df = remote_contact.not_a_rpc_function()
df.addCallback(handleResult) df.addCallback(handleResult)
df.addErrback(handleError) df.addErrback(handleError)
df.addBoth(lambda _: lbrynet.dht.protocol.reactor.stop()) self._reactor.pump([1 for _ in range(10)])
lbrynet.dht.protocol.reactor.run()
self.failIf(self.error, self.error) self.failIf(self.error, self.error)
# The list of sent RPC messages should be empty at this stage # The list of sent RPC messages should be empty at this stage
self.failUnlessEqual(len(self.protocol._sentMessages), 0, self.failUnlessEqual(len(self.node._protocol._sentMessages), 0,
'The protocol is still waiting for a RPC result, ' 'The protocol is still waiting for a RPC result, '
'but the transaction is already done!') 'but the transaction is already done!')
def testRPCRequestArgs(self): def testRPCRequestArgs(self):
""" Tests if an RPC requiring arguments is executed correctly """ """ Tests if an RPC requiring arguments is executed correctly """
remoteContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.1', 9182, self.protocol) remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
self.node.addContact(remoteContact) resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
remote_node.start_listening()
remote_contact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(remote_contact)
self.error = None self.error = None
def handleError(f): def handleError(f):
@ -153,15 +187,14 @@ class KademliaProtocolTest(unittest.TestCase):
(expectedResult, result) (expectedResult, result)
# Publish the "local" node on the network # Publish the "local" node on the network
lbrynet.dht.protocol.reactor.listenUDP(9182, self.protocol) self.node.start_listening()
# Simulate the RPC # Simulate the RPC
df = remoteContact.ping() df = remote_contact.ping()
df.addCallback(handleResult) df.addCallback(handleResult)
df.addErrback(handleError) df.addErrback(handleError)
df.addBoth(lambda _: lbrynet.dht.protocol.reactor.stop()) self._reactor.pump([1 for _ in range(10)])
lbrynet.dht.protocol.reactor.run()
self.failIf(self.error, self.error) self.failIf(self.error, self.error)
# The list of sent RPC messages should be empty at this stage # The list of sent RPC messages should be empty at this stage
self.failUnlessEqual(len(self.protocol._sentMessages), 0, self.failUnlessEqual(len(self.node._protocol._sentMessages), 0,
'The protocol is still waiting for a RPC result, ' 'The protocol is still waiting for a RPC result, '
'but the transaction is already done!') 'but the transaction is already done!')

View file

@ -1,12 +1,12 @@
import hashlib import hashlib
import unittest import unittest
#from lbrynet.dht import contact, routingtable, constants
import lbrynet.dht.constants import lbrynet.dht.constants
import lbrynet.dht.routingtable import lbrynet.dht.routingtable
import lbrynet.dht.contact import lbrynet.dht.contact
import lbrynet.dht.node import lbrynet.dht.node
import lbrynet.dht.distance
class FakeRPCProtocol(object): class FakeRPCProtocol(object):
""" Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """ """ Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """
@ -42,15 +42,15 @@ class TreeRoutingTableTest(unittest.TestCase):
basicTestList = [('123456789', '123456789', 0L), ('12345', '98765', 34527773184L)] basicTestList = [('123456789', '123456789', 0L), ('12345', '98765', 34527773184L)]
for test in basicTestList: for test in basicTestList:
result = lbrynet.dht.node.Distance(test[0])(test[1]) result = lbrynet.dht.distance.Distance(test[0])(test[1])
self.failIf(result != test[2], 'Result of _distance() should be %s but %s returned' % self.failIf(result != test[2], 'Result of _distance() should be %s but %s returned' %
(test[2], result)) (test[2], result))
baseIp = '146.64.19.111' baseIp = '146.64.19.111'
ipTestList = ['146.64.29.222', '192.68.19.333'] ipTestList = ['146.64.29.222', '192.68.19.333']
distanceOne = lbrynet.dht.node.Distance(baseIp)(ipTestList[0]) distanceOne = lbrynet.dht.distance.Distance(baseIp)(ipTestList[0])
distanceTwo = lbrynet.dht.node.Distance(baseIp)(ipTestList[1]) distanceTwo = lbrynet.dht.distance.Distance(baseIp)(ipTestList[1])
self.failIf(distanceOne > distanceTwo, '%s should be closer to the base ip %s than %s' % self.failIf(distanceOne > distanceTwo, '%s should be closer to the base ip %s than %s' %
(ipTestList[0], baseIp, ipTestList[1])) (ipTestList[0], baseIp, ipTestList[1]))
@ -184,10 +184,6 @@ class TreeRoutingTableTest(unittest.TestCase):
'New contact should have been discarded (since RPC is faked in this test)') 'New contact should have been discarded (since RPC is faked in this test)')
class KeyErrorFixedTest(unittest.TestCase): class KeyErrorFixedTest(unittest.TestCase):
""" Basic tests case for boolean operators on the Contact class """ """ Basic tests case for boolean operators on the Contact class """

View file

@ -8,7 +8,6 @@ from lbrynet.database.storage import SQLiteStorage
from lbrynet.core.StreamDescriptor import get_sd_info, BlobStreamDescriptorReader from lbrynet.core.StreamDescriptor import get_sd_info, BlobStreamDescriptorReader
from lbrynet.core import BlobManager from lbrynet.core import BlobManager
from lbrynet.core import Session from lbrynet.core import Session
from lbrynet.core.server import DHTHashAnnouncer
from lbrynet.file_manager import EncryptedFileCreator from lbrynet.file_manager import EncryptedFileCreator
from lbrynet.file_manager import EncryptedFileManager from lbrynet.file_manager import EncryptedFileManager
from lbrynet.tests import mocks from lbrynet.tests import mocks
@ -32,10 +31,7 @@ class CreateEncryptedFileTest(unittest.TestCase):
self.session = mock.Mock(spec=Session.Session)(None, None) self.session = mock.Mock(spec=Session.Session)(None, None)
self.session.payment_rate_manager.min_blob_data_payment_rate = 0 self.session.payment_rate_manager.min_blob_data_payment_rate = 0
self.blob_manager = BlobManager.DiskBlobManager(self.tmp_blob_dir, SQLiteStorage(self.tmp_db_dir))
hash_announcer = DHTHashAnnouncer.DHTHashAnnouncer(None, None)
self.blob_manager = BlobManager.DiskBlobManager(
hash_announcer, self.tmp_blob_dir, SQLiteStorage(self.tmp_db_dir))
self.session.blob_manager = self.blob_manager self.session.blob_manager = self.blob_manager
self.session.storage = self.session.blob_manager.storage self.session.storage = self.session.blob_manager.storage
self.file_manager = EncryptedFileManager.EncryptedFileManager(self.session, object()) self.file_manager = EncryptedFileManager.EncryptedFileManager(self.session, object())
@ -74,6 +70,7 @@ class CreateEncryptedFileTest(unittest.TestCase):
# this comes from the database, the blobs returned are sorted # this comes from the database, the blobs returned are sorted
sd_info = yield get_sd_info(self.session.storage, lbry_file.stream_hash, include_blobs=True) sd_info = yield get_sd_info(self.session.storage, lbry_file.stream_hash, include_blobs=True)
self.assertDictEqual(sd_info, sd_file_info) self.assertDictEqual(sd_info, sd_file_info)
self.assertListEqual(sd_info['blobs'], sd_file_info['blobs'])
self.assertEqual(sd_info['stream_hash'], expected_stream_hash) self.assertEqual(sd_info['stream_hash'], expected_stream_hash)
self.assertEqual(len(sd_info['blobs']), 3) self.assertEqual(len(sd_info['blobs']), 3)
self.assertNotEqual(sd_info['blobs'][0]['length'], 0) self.assertNotEqual(sd_info['blobs'][0]['length'], 0)

View file

@ -5,24 +5,36 @@ import os
import tempfile import tempfile
import shutil import shutil
import mock import mock
import logging
from lbrynet.dht.encoding import Bencode
from lbrynet.dht.error import DecodeError
from lbrynet.dht.msgformat import DefaultFormat
from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage
_encode = Bencode()
_datagram_formatter = DefaultFormat()
DEFAULT_TIMESTAMP = datetime.datetime(2016, 1, 1) DEFAULT_TIMESTAMP = datetime.datetime(2016, 1, 1)
DEFAULT_ISO_TIME = time.mktime(DEFAULT_TIMESTAMP.timetuple()) DEFAULT_ISO_TIME = time.mktime(DEFAULT_TIMESTAMP.timetuple())
log = logging.getLogger("lbrynet.tests.util")
def mk_db_and_blob_dir(): def mk_db_and_blob_dir():
db_dir = tempfile.mkdtemp() db_dir = tempfile.mkdtemp()
blob_dir = tempfile.mkdtemp() blob_dir = tempfile.mkdtemp()
return db_dir, blob_dir return db_dir, blob_dir
def rm_db_and_blob_dir(db_dir, blob_dir): def rm_db_and_blob_dir(db_dir, blob_dir):
shutil.rmtree(db_dir, ignore_errors=True) shutil.rmtree(db_dir, ignore_errors=True)
shutil.rmtree(blob_dir, ignore_errors=True) shutil.rmtree(blob_dir, ignore_errors=True)
def random_lbry_hash(): def random_lbry_hash():
return binascii.b2a_hex(os.urandom(48)) return binascii.b2a_hex(os.urandom(48))
def resetTime(test_case, timestamp=DEFAULT_TIMESTAMP): def resetTime(test_case, timestamp=DEFAULT_TIMESTAMP):
iso_time = time.mktime(timestamp.timetuple()) iso_time = time.mktime(timestamp.timetuple())
patcher = mock.patch('time.time') patcher = mock.patch('time.time')
@ -37,5 +49,28 @@ def resetTime(test_case, timestamp=DEFAULT_TIMESTAMP):
patcher.start().return_value = timestamp patcher.start().return_value = timestamp
test_case.addCleanup(patcher.stop) test_case.addCleanup(patcher.stop)
def is_android(): def is_android():
return 'ANDROID_ARGUMENT' in os.environ # detect Android using the Kivy way return 'ANDROID_ARGUMENT' in os.environ # detect Android using the Kivy way
def debug_kademlia_packet(data, source, destination, node):
if log.level != logging.DEBUG:
return
try:
packet = _datagram_formatter.fromPrimitive(_encode.decode(data))
if isinstance(packet, RequestMessage):
log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request,
node.clock.seconds())
elif isinstance(packet, ResponseMessage):
if isinstance(packet.response, (str, unicode)):
log.debug("response %s <-- %s %s (node time %s)", destination[0], source[0], packet.response,
node.clock.seconds())
else:
log.debug("response %s <-- %s %i contacts (node time %s)", destination[0], source[0],
len(packet.response), node.clock.seconds())
elif isinstance(packet, ErrorMessage):
log.error("error %s <-- %s %s (node time %s)", destination[0], source[0], packet.exceptionType,
node.clock.seconds())
except DecodeError:
log.exception("decode error %s --> %s (node time %s)", source[0], destination[0], node.clock.seconds())

0
requirements_testing.txt Normal file
View file

View file

@ -73,11 +73,33 @@ def connect(port=None):
yield reactor.stop() yield reactor.stop()
def getApproximateTotalDHTNodes(node):
from lbrynet.dht import constants
# get the deepest bucket and the number of contacts in that bucket and multiply it
# by the number of equivalently deep buckets in the whole DHT to get a really bad
# estimate!
bucket = node._routingTable._buckets[node._routingTable._kbucketIndex(node.node_id)]
num_in_bucket = len(bucket._contacts)
factor = (2 ** constants.key_bits) / (bucket.rangeMax - bucket.rangeMin)
return num_in_bucket * factor
def getApproximateTotalHashes(node):
# Divide the number of hashes we know about by k to get a really, really, really
# bad estimate of the average number of hashes per node, then multiply by the
# approximate number of nodes to get a horrendous estimate of the total number
# of hashes in the DHT
num_in_data_store = len(node._dataStore._dict)
if num_in_data_store == 0:
return 0
return num_in_data_store * getApproximateTotalDHTNodes(node) / 8
@defer.inlineCallbacks @defer.inlineCallbacks
def find(node): def find(node):
try: try:
log.info("Approximate number of nodes in DHT: %s", str(node.getApproximateTotalDHTNodes())) log.info("Approximate number of nodes in DHT: %s", str(getApproximateTotalDHTNodes(node)))
log.info("Approximate number of blobs in DHT: %s", str(node.getApproximateTotalHashes())) log.info("Approximate number of blobs in DHT: %s", str(getApproximateTotalHashes(node)))
h = "578f5e82da7db97bfe0677826d452cc0c65406a8e986c9caa126af4ecdbf4913daad2f7f5d1fb0ffec17d0bf8f187f5a" h = "578f5e82da7db97bfe0677826d452cc0c65406a8e986c9caa126af4ecdbf4913daad2f7f5d1fb0ffec17d0bf8f187f5a"
peersFake = yield node.getPeersForBlob(h.decode("hex")) peersFake = yield node.getPeersForBlob(h.decode("hex"))

View file

@ -14,7 +14,7 @@ from lbrynet.core import log_support, Wallet, Peer
from lbrynet.core.SinglePeerDownloader import SinglePeerDownloader from lbrynet.core.SinglePeerDownloader import SinglePeerDownloader
from lbrynet.core.StreamDescriptor import BlobStreamDescriptorReader from lbrynet.core.StreamDescriptor import BlobStreamDescriptorReader
from lbrynet.core.BlobManager import DiskBlobManager from lbrynet.core.BlobManager import DiskBlobManager
from lbrynet.core.HashAnnouncer import DummyHashAnnouncer from lbrynet.dht.hashannouncer import DummyHashAnnouncer
log = logging.getLogger() log = logging.getLogger()

View file

@ -5,12 +5,11 @@ import sys
from twisted.internet import defer from twisted.internet import defer
from twisted.internet import reactor from twisted.internet import reactor
from twisted.protocols import basic
from twisted.web.client import FileBodyProducer from twisted.web.client import FileBodyProducer
from lbrynet import conf from lbrynet import conf
from lbrynet.core import log_support from lbrynet.core import log_support
from lbrynet.core.HashAnnouncer import DummyHashAnnouncer from lbrynet.dht.hashannouncer import DummyHashAnnouncer
from lbrynet.core.BlobManager import DiskBlobManager from lbrynet.core.BlobManager import DiskBlobManager
from lbrynet.cryptstream.CryptStreamCreator import CryptStreamCreator from lbrynet.cryptstream.CryptStreamCreator import CryptStreamCreator

View file

@ -17,7 +17,7 @@ from twisted.protocols import basic
from lbrynet import conf from lbrynet import conf
from lbrynet.core import BlobManager from lbrynet.core import BlobManager
from lbrynet.core import HashAnnouncer from lbrynet.dht import hashannouncer
from lbrynet.core import log_support from lbrynet.core import log_support
from lbrynet.cryptstream import CryptStreamCreator from lbrynet.cryptstream import CryptStreamCreator
@ -52,7 +52,7 @@ def reseed_file(input_file, sd_blob):
sd_blob = SdBlob.new_instance(sd_blob) sd_blob = SdBlob.new_instance(sd_blob)
db_dir = conf.settings['data_dir'] db_dir = conf.settings['data_dir']
blobfile_dir = os.path.join(db_dir, "blobfiles") blobfile_dir = os.path.join(db_dir, "blobfiles")
announcer = HashAnnouncer.DummyHashAnnouncer() announcer = hashannouncer.DummyHashAnnouncer()
blob_manager = BlobManager.DiskBlobManager(announcer, blobfile_dir, db_dir) blob_manager = BlobManager.DiskBlobManager(announcer, blobfile_dir, db_dir)
yield blob_manager.setup() yield blob_manager.setup()
creator = CryptStreamCreator.CryptStreamCreator( creator = CryptStreamCreator.CryptStreamCreator(