jleute 2018-06-09 00:41:38 +02:00
commit e381e4c063
73 changed files with 3604 additions and 2969 deletions

@ -23,7 +23,7 @@ clone_folder: c:\projects\lbry
test_script: test_script:
- cd C:\projects\lbry\ - cd C:\projects\lbry\
- pip install cython - pip install cython
- pip install mock pylint unqlite - pip install mock pylint unqlite Faker
- pip install . - pip install .
- pylint lbrynet - pylint lbrynet
# disable tests for now so that appveyor can build the app # disable tests for now so that appveyor can build the app

@ -35,6 +35,7 @@ before_install:
install: install:
- pip install -U pip==9.0.3 - pip install -U pip==9.0.3
- pip install -r requirements.txt - pip install -r requirements.txt
- pip install -r requirements_testing.txt
- pip install . - pip install .
script: script:

@ -13,6 +13,7 @@ at anytime.
* *
### Fixed ### Fixed
* fix prm/brm typo
* handling error from dht clients with old `ping` method * handling error from dht clients with old `ping` method
* blobs not being re-announced if no peers successfully stored, now failed announcements are re-queued * blobs not being re-announced if no peers successfully stored, now failed announcements are re-queued
* issue where an `AuthAPIClient` (used by `lbrynet-cli`) would fail to update its session secret and keep making new auth sessions, with every other request failing * issue where an `AuthAPIClient` (used by `lbrynet-cli`) would fail to update its session secret and keep making new auth sessions, with every other request failing
@ -26,18 +27,21 @@ at anytime.
* *
### Changed ### Changed
* refactor `add_completed_blobs` in storage.py, simplifying it into fewer queries
* check headers file integrity on startup, removing/truncating the file to force re-download when necessary * check headers file integrity on startup, removing/truncating the file to force re-download when necessary
* support partial headers file download from S3 * support partial headers file download from S3
* changed txrequests for treq * changed txrequests for treq
* changed cryptography version to 2.2.2 * changed cryptography version to 2.2.2
* removed pycrypto dependency, replacing all calls to cryptography * removed pycrypto dependency, replacing all calls to cryptography
* full verification of streams only during migration instead of every startup
* database batching functions for starting up the file manager
* several internal dht functions to use inlineCallbacks * several internal dht functions to use inlineCallbacks
* `DHTHashAnnouncer` and `Node` manage functions to use `LoopingCall`s instead of scheduling with `callLater`. * `DHTHashAnnouncer` and `Node` manage functions to use `LoopingCall`s instead of scheduling with `callLater`.
* `store` kademlia rpc method to block on the call finishing and to return storing peer information * `store` kademlia rpc method to block on the call finishing and to return storing peer information
* refactored `DHTHashAnnouncer` to longer use locks, use a `DeferredSemaphore` to limit concurrent announcers * refactored `DHTHashAnnouncer` to no longer use locks, use a `DeferredSemaphore` to limit concurrent announcers
* decoupled `DiskBlobManager` from `DHTHashAnnouncer` * decoupled `DiskBlobManager` from `DHTHashAnnouncer`
* blob hashes to announce to be controlled by `SQLiteStorage` * blob hashes to announce to be controlled by `SQLiteStorage`
* kademlia protocol to not delay writes to the UDP socket * kademlia protocol to minimally delay writes to the UDP socket
* `reactor` and `callLater`, `listenUDP`, and `resolve` functions to be configurable (to allow easier testing) * `reactor` and `callLater`, `listenUDP`, and `resolve` functions to be configurable (to allow easier testing)
* calls to get the current time to use `reactor.seconds` (to control callLater and LoopingCall timing in tests) * calls to get the current time to use `reactor.seconds` (to control callLater and LoopingCall timing in tests)
* `blob_announce` to queue the blob announcement but not block on it * `blob_announce` to queue the blob announcement but not block on it
@ -53,18 +57,35 @@ at anytime.
* download blockchain headers from s3 before starting the wallet when the local height is more than `s3_headers_depth` (a config setting) blocks behind * download blockchain headers from s3 before starting the wallet when the local height is more than `s3_headers_depth` (a config setting) blocks behind
* track successful reflector uploads in sqlite to minimize how many streams are attempted by auto re-reflect * track successful reflector uploads in sqlite to minimize how many streams are attempted by auto re-reflect
* increase the default `auto_re_reflect_interval` to a day * increase the default `auto_re_reflect_interval` to a day
* predictable result sorting for `claim_list` and `claim_list_mine`
* changed the bucket splitting condition in the dht routing table to be more aggressive
* ping dht nodes who have stored to us periodically to determine whether we should include them as an active peer for the hash when we are queried. Nodes that are known to be not reachable by the node storing the record are no longer returned as peers by the storing node.
* temporarily disabled data price negotiation, treat all data as free
* changed dht bootstrap join process to better populate the routing table initially
* cache dht node tokens used during announcement to minimize the number of requests that are needed
* implement BEP0005 dht rules to classify nodes as good, bad, or unknown and for when to add them to the routing table (http://www.bittorrent.org/beps/bep_0005.html)
* refactored internal dht contact class to track failure counts/times, the time the contact last replied to us, and the time the node last requested something from us
* refactored dht iterativeFind
* sort dht contacts returned by `findCloseNodes` in the routing table
* disabled Cryptonator price feed
### Added ### Added
* virtual kademlia network and mock udp transport for dht integration tests * virtual kademlia network and mock udp transport for dht integration tests
* integration tests for bootstrapping the dht * functional tests for bootstrapping the dht, announcing and expiring hashes, finding and pinging nodes, protocol version 0/1 backwards/forwards compatibility, and rejoining the network
* configurable `concurrent_announcers` and `s3_headers_depth` settings * configurable `concurrent_announcers` and `s3_headers_depth` settings
* `peer_ping` command * `peer_ping` command
* `--sort` option in `file_list`
* linux distro and desktop name added to analytics
* certifi module for Twisted SSL verification on Windows
* protocol version to dht requests and to the response from `findValue`
* added `port` field to contacts returned by `routing_table_get`
### Removed ### Removed
* `announce_all` argument from `blob_announce` * `announce_all` argument from `blob_announce`
* old `blob_announce_all` command * old `blob_announce_all` command
* `AuthJSONRPCServer.auth_required` decorator * `AuthJSONRPCServer.auth_required` decorator
* unused `--wallet` argument to `lbrynet-daemon`, which used to be to support `PTCWallet`. * unused `--wallet` argument to `lbrynet-daemon`, which used to be to support `PTCWallet`.
* `OptimizedTreeRoutingTable` class used by the dht node for the time being
## [0.19.3] - 2018-05-04 ## [0.19.3] - 2018-05-04
### Changed ### Changed

@ -1,6 +1,6 @@
import logging import logging
__version__ = "0.20.0rc9" __version__ = "0.20.0rc13"
version = tuple(__version__.split('.')) version = tuple(__version__.split('.'))
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())

@ -185,7 +185,7 @@ class Manager(object):
@staticmethod @staticmethod
def _make_context(platform, wallet): def _make_context(platform, wallet):
return { context = {
'app': { 'app': {
'name': 'lbrynet', 'name': 'lbrynet',
'version': platform['lbrynet_version'], 'version': platform['lbrynet_version'],
@ -206,6 +206,10 @@ class Manager(object):
'version': '1.0.0' 'version': '1.0.0'
}, },
} }
if 'desktop' in platform and 'distro' in platform:
context['os']['desktop'] = platform['desktop']
context['os']['distro'] = platform['distro']
return context
@staticmethod @staticmethod
def _if_deferred(maybe_deferred, callback, *args, **kwargs): def _if_deferred(maybe_deferred, callback, *args, **kwargs):

@ -40,7 +40,7 @@ ANDROID = 4
KB = 2 ** 10 KB = 2 ** 10
MB = 2 ** 20 MB = 2 ** 20
DEFAULT_CONCURRENT_ANNOUNCERS = 100 DEFAULT_CONCURRENT_ANNOUNCERS = 10
DEFAULT_DHT_NODES = [ DEFAULT_DHT_NODES = [
('lbrynet1.lbry.io', 4444), ('lbrynet1.lbry.io', 4444),

@ -23,18 +23,14 @@ class BlobAvailabilityTracker(object):
self._blob_manager = blob_manager self._blob_manager = blob_manager
self._peer_finder = peer_finder self._peer_finder = peer_finder
self._dht_node = dht_node self._dht_node = dht_node
self._check_popular = LoopingCall(self._update_most_popular)
self._check_mine = LoopingCall(self._update_mine) self._check_mine = LoopingCall(self._update_mine)
def start(self): def start(self):
log.info("Starting blob availability tracker.") log.info("Starting blob availability tracker.")
self._check_popular.start(600)
self._check_mine.start(600) self._check_mine.start(600)
def stop(self): def stop(self):
log.info("Stopping blob availability tracker.") log.info("Stopping blob availability tracker.")
if self._check_popular.running:
self._check_popular.stop()
if self._check_mine.running: if self._check_mine.running:
self._check_mine.stop() self._check_mine.stop()
@ -68,17 +64,6 @@ class BlobAvailabilityTracker(object):
d.addCallback(lambda peers: _save_peer_info(blob, peers)) d.addCallback(lambda peers: _save_peer_info(blob, peers))
return d return d
def _get_most_popular(self):
dl = []
for (hash, _) in self._dht_node.get_most_popular_hashes(10):
encoded = hash.encode('hex')
dl.append(self._update_peers_for_blob(encoded))
return defer.DeferredList(dl)
def _update_most_popular(self):
d = self._get_most_popular()
d.addCallback(lambda _: self._set_mean_peers())
def _update_mine(self): def _update_mine(self):
def _get_peers(blobs): def _get_peers(blobs):
dl = [] dl = []

@ -93,6 +93,7 @@ class OnlyFreePaymentsManager(object):
self.base = BasePaymentRateManager(0.0, 0.0) self.base = BasePaymentRateManager(0.0, 0.0)
self.points_paid = 0.0 self.points_paid = 0.0
self.min_blob_data_payment_rate = 0.0
self.generous = True self.generous = True
self.strategy = OnlyFreeStrategy() self.strategy = OnlyFreeStrategy()

@ -6,8 +6,7 @@ from lbrynet.dht import node, hashannouncer
from lbrynet.database.storage import SQLiteStorage from lbrynet.database.storage import SQLiteStorage
from lbrynet.core.RateLimiter import RateLimiter from lbrynet.core.RateLimiter import RateLimiter
from lbrynet.core.utils import generate_id from lbrynet.core.utils import generate_id
from lbrynet.core.PaymentRateManager import BasePaymentRateManager, NegotiatedPaymentRateManager from lbrynet.core.PaymentRateManager import BasePaymentRateManager, OnlyFreePaymentsManager
from lbrynet.core.BlobAvailability import BlobAvailabilityTracker
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -33,14 +32,11 @@ class Session(object):
peers can connect to this peer. peers can connect to this peer.
""" """
def __init__(self, blob_data_payment_rate, db_dir=None, def __init__(self, blob_data_payment_rate, db_dir=None, node_id=None, peer_manager=None, dht_node_port=None,
node_id=None, peer_manager=None, dht_node_port=None, known_dht_nodes=None, peer_finder=None, hash_announcer=None, blob_dir=None, blob_manager=None,
known_dht_nodes=None, peer_finder=None, peer_port=None, use_upnp=True, rate_limiter=None, wallet=None, dht_node_class=node.Node,
hash_announcer=None, blob_dir=None, blob_tracker_class=None, payment_rate_manager_class=None, is_generous=True, external_ip=None,
blob_manager=None, peer_port=None, use_upnp=True, storage=None):
rate_limiter=None, wallet=None,
dht_node_class=node.Node, blob_tracker_class=None,
payment_rate_manager_class=None, is_generous=True, external_ip=None, storage=None):
"""@param blob_data_payment_rate: The default payment rate for blob data """@param blob_data_payment_rate: The default payment rate for blob data
@param db_dir: The directory in which levelDB files should be stored @param db_dir: The directory in which levelDB files should be stored
@ -107,8 +103,8 @@ class Session(object):
self.known_dht_nodes = [] self.known_dht_nodes = []
self.blob_dir = blob_dir self.blob_dir = blob_dir
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.blob_tracker = None # self.blob_tracker = None
self.blob_tracker_class = blob_tracker_class or BlobAvailabilityTracker # self.blob_tracker_class = blob_tracker_class or BlobAvailabilityTracker
self.peer_port = peer_port self.peer_port = peer_port
self.use_upnp = use_upnp self.use_upnp = use_upnp
self.rate_limiter = rate_limiter self.rate_limiter = rate_limiter
@ -118,11 +114,10 @@ class Session(object):
self.dht_node_class = dht_node_class self.dht_node_class = dht_node_class
self.dht_node = None self.dht_node = None
self.base_payment_rate_manager = BasePaymentRateManager(blob_data_payment_rate) self.base_payment_rate_manager = BasePaymentRateManager(blob_data_payment_rate)
self.payment_rate_manager = None self.payment_rate_manager = OnlyFreePaymentsManager()
self.payment_rate_manager_class = payment_rate_manager_class or NegotiatedPaymentRateManager # self.payment_rate_manager_class = payment_rate_manager_class or NegotiatedPaymentRateManager
self.is_generous = is_generous # self.is_generous = is_generous
self.storage = storage or SQLiteStorage(self.db_dir) self.storage = storage or SQLiteStorage(self.db_dir)
self._join_dht_deferred = None
def setup(self): def setup(self):
"""Create the blob directory and database if necessary, start all desired services""" """Create the blob directory and database if necessary, start all desired services"""
@ -147,8 +142,8 @@ class Session(object):
ds = [] ds = []
if self.hash_announcer: if self.hash_announcer:
self.hash_announcer.stop() self.hash_announcer.stop()
if self.blob_tracker is not None: # if self.blob_tracker is not None:
ds.append(defer.maybeDeferred(self.blob_tracker.stop)) # ds.append(defer.maybeDeferred(self.blob_tracker.stop))
if self.dht_node is not None: if self.dht_node is not None:
ds.append(defer.maybeDeferred(self.dht_node.stop)) ds.append(defer.maybeDeferred(self.dht_node.stop))
if self.rate_limiter is not None: if self.rate_limiter is not None:
@ -171,24 +166,24 @@ class Session(object):
if not mapping: if not mapping:
return port return port
if upnp.lanaddr == mapping[0]: if upnp.lanaddr == mapping[0]:
return mapping return mapping[1]
return get_free_port(upnp, port + 1, protocol) return get_free_port(upnp, port + 1, protocol)
def get_port_mapping(upnp, internal_port, protocol, description): def get_port_mapping(upnp, port, protocol, description):
# try to map to the requested port, if there is already a mapping use the next external # try to map to the requested port, if there is already a mapping use the next external
# port available # port available
if protocol not in ['UDP', 'TCP']: if protocol not in ['UDP', 'TCP']:
raise Exception("invalid protocol") raise Exception("invalid protocol")
external_port = get_free_port(upnp, internal_port, protocol) port = get_free_port(upnp, port, protocol)
if isinstance(external_port, tuple): if isinstance(port, tuple):
log.info("Found existing UPnP redirect %s:%i (%s) to %s:%i, using it", log.info("Found existing UPnP redirect %s:%i (%s) to %s:%i, using it",
self.external_ip, external_port[1], protocol, upnp.lanaddr, internal_port) self.external_ip, port, protocol, upnp.lanaddr, port)
return external_port[1], protocol return port
upnp.addportmapping(external_port, protocol, upnp.lanaddr, internal_port, upnp.addportmapping(port, protocol, upnp.lanaddr, port,
description, '') description, '')
log.info("Set UPnP redirect %s:%i (%s) to %s:%i", self.external_ip, external_port, log.info("Set UPnP redirect %s:%i (%s) to %s:%i", self.external_ip, port,
protocol, upnp.lanaddr, internal_port) protocol, upnp.lanaddr, port)
return external_port, protocol return port
def threaded_try_upnp(): def threaded_try_upnp():
if self.use_upnp is False: if self.use_upnp is False:
@ -203,13 +198,11 @@ class Session(object):
# best not to rely on this external ip, the router can be behind layers of NATs # best not to rely on this external ip, the router can be behind layers of NATs
self.external_ip = external_ip self.external_ip = external_ip
if self.peer_port: if self.peer_port:
self.upnp_redirects.append( self.peer_port = get_port_mapping(u, self.peer_port, 'TCP', 'LBRY peer port')
get_port_mapping(u, self.peer_port, 'TCP', 'LBRY peer port') self.upnp_redirects.append((self.peer_port, 'TCP'))
)
if self.dht_node_port: if self.dht_node_port:
self.upnp_redirects.append( self.dht_node_port = get_port_mapping(u, self.dht_node_port, 'UDP', 'LBRY DHT port')
get_port_mapping(u, self.dht_node_port, 'UDP', 'LBRY DHT port') self.upnp_redirects.append((self.dht_node_port, 'UDP'))
)
return True return True
return False return False
@ -234,9 +227,9 @@ class Session(object):
self.hash_announcer = hashannouncer.DHTHashAnnouncer(self.dht_node, self.storage) self.hash_announcer = hashannouncer.DHTHashAnnouncer(self.dht_node, self.storage)
self.peer_manager = self.dht_node.peer_manager self.peer_manager = self.dht_node.peer_manager
self.peer_finder = self.dht_node.peer_finder self.peer_finder = self.dht_node.peer_finder
self._join_dht_deferred = self.dht_node.joinNetwork(self.known_dht_nodes) d = self.dht_node.start(self.known_dht_nodes)
self._join_dht_deferred.addCallback(lambda _: log.info("Joined the dht")) d.addCallback(lambda _: log.info("Joined the dht"))
self._join_dht_deferred.addCallback(lambda _: self.hash_announcer.start()) d.addCallback(lambda _: self.hash_announcer.start())
def _setup_other_components(self): def _setup_other_components(self):
log.debug("Setting up the rest of the components") log.debug("Setting up the rest of the components")
@ -251,19 +244,19 @@ class Session(object):
else: else:
self.blob_manager = DiskBlobManager(self.blob_dir, self.storage) self.blob_manager = DiskBlobManager(self.blob_dir, self.storage)
if self.blob_tracker is None: # if self.blob_tracker is None:
self.blob_tracker = self.blob_tracker_class( # self.blob_tracker = self.blob_tracker_class(
self.blob_manager, self.dht_node.peer_finder, self.dht_node # self.blob_manager, self.dht_node.peer_finder, self.dht_node
) # )
if self.payment_rate_manager is None: # if self.payment_rate_manager is None:
self.payment_rate_manager = self.payment_rate_manager_class( # self.payment_rate_manager = self.payment_rate_manager_class(
self.base_payment_rate_manager, self.blob_tracker, self.is_generous # self.base_payment_rate_manager, self.blob_tracker, self.is_generous
) # )
self.rate_limiter.start() self.rate_limiter.start()
d = self.blob_manager.setup() d = self.blob_manager.setup()
d.addCallback(lambda _: self.wallet.start()) d.addCallback(lambda _: self.wallet.start())
d.addCallback(lambda _: self.blob_tracker.start()) # d.addCallback(lambda _: self.blob_tracker.start())
return d return d
def _unset_upnp(self): def _unset_upnp(self):

@ -9,22 +9,25 @@ QUEUE_SIZE_THRESHOLD = 100
class CallLaterManager(object): class CallLaterManager(object):
_callLater = None def __init__(self, callLater):
_pendingCallLaters = [] """
_delay = MIN_DELAY :param callLater: (IReactorTime.callLater)
"""
@classmethod self._callLater = callLater
def get_min_delay(cls): self._pendingCallLaters = []
cls._pendingCallLaters = [cl for cl in cls._pendingCallLaters if cl.active()] self._delay = MIN_DELAY
queue_size = len(cls._pendingCallLaters)
def get_min_delay(self):
self._pendingCallLaters = [cl for cl in self._pendingCallLaters if cl.active()]
queue_size = len(self._pendingCallLaters)
if queue_size > QUEUE_SIZE_THRESHOLD: if queue_size > QUEUE_SIZE_THRESHOLD:
cls._delay = min((cls._delay + DELAY_INCREMENT), MAX_DELAY) self._delay = min((self._delay + DELAY_INCREMENT), MAX_DELAY)
else: else:
cls._delay = max((cls._delay - 2.0 * DELAY_INCREMENT), MIN_DELAY) self._delay = max((self._delay - 2.0 * DELAY_INCREMENT), MIN_DELAY)
return cls._delay return self._delay
@classmethod def _cancel(self, call_later):
def _cancel(cls, call_later):
""" """
:param call_later: DelayedCall :param call_later: DelayedCall
:return: (callable) canceller function :return: (callable) canceller function
@ -38,26 +41,25 @@ class CallLaterManager(object):
if call_later.active(): if call_later.active():
call_later.cancel() call_later.cancel()
cls._pendingCallLaters.remove(call_later) if call_later in self._pendingCallLaters:
self._pendingCallLaters.remove(call_later)
return reason return reason
return cancel return cancel
@classmethod def stop(self):
def stop(cls):
""" """
Cancel any callLaters that are still running Cancel any callLaters that are still running
""" """
from twisted.internet import defer from twisted.internet import defer
while cls._pendingCallLaters: while self._pendingCallLaters:
canceller = cls._cancel(cls._pendingCallLaters[0]) canceller = self._cancel(self._pendingCallLaters[0])
try: try:
canceller() canceller()
except (defer.CancelledError, defer.AlreadyCalledError): except (defer.CancelledError, defer.AlreadyCalledError, ValueError):
pass pass
@classmethod def call_later(self, when, what, *args, **kwargs):
def call_later(cls, when, what, *args, **kwargs):
""" """
Schedule a call later and get a canceller callback function Schedule a call later and get a canceller callback function
@ -69,21 +71,11 @@ class CallLaterManager(object):
:return: (tuple) twisted.internet.base.DelayedCall object, canceller function :return: (tuple) twisted.internet.base.DelayedCall object, canceller function
""" """
call_later = cls._callLater(when, what, *args, **kwargs) call_later = self._callLater(when, what, *args, **kwargs)
canceller = cls._cancel(call_later) canceller = self._cancel(call_later)
cls._pendingCallLaters.append(call_later) self._pendingCallLaters.append(call_later)
return call_later, canceller return call_later, canceller
@classmethod def call_soon(self, what, *args, **kwargs):
def call_soon(cls, what, *args, **kwargs): delay = self.get_min_delay()
delay = cls.get_min_delay() return self.call_later(delay, what, *args, **kwargs)
return cls.call_later(delay, what, *args, **kwargs)
@classmethod
def setup(cls, callLater):
"""
Setup the callLater function to use, supports the real reactor as well as task.Clock
:param callLater: (IReactorTime.callLater)
"""
cls._callLater = callLater
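Note on this hunk: moving `_pendingCallLaters` and `_delay` from class attributes to instance attributes means each manager tracks its own queue, so one busy manager no longer raises the delay for every other user of the class. The sketch below isolates only the delay logic. `DelayTracker` is a hypothetical stand-in, and the values of `MIN_DELAY`, `MAX_DELAY`, and `DELAY_INCREMENT` are assumptions, since only `QUEUE_SIZE_THRESHOLD = 100` is visible in the diff.

```python
MIN_DELAY = 0.0            # assumed value, for illustration only
MAX_DELAY = 0.01           # assumed value, for illustration only
DELAY_INCREMENT = 0.0001   # assumed value, for illustration only
QUEUE_SIZE_THRESHOLD = 100  # taken from the hunk header above


class DelayTracker(object):
    """Hypothetical stand-in for the delay logic of CallLaterManager."""

    def __init__(self):
        # Instance attributes: each tracker owns its queue and its delay.
        self.pending = []
        self.delay = MIN_DELAY

    def get_min_delay(self):
        if len(self.pending) > QUEUE_SIZE_THRESHOLD:
            # Back off while the queue is long, capped at MAX_DELAY.
            self.delay = min(self.delay + DELAY_INCREMENT, MAX_DELAY)
        else:
            # Recover twice as fast as we back off, floored at MIN_DELAY.
            self.delay = max(self.delay - 2.0 * DELAY_INCREMENT, MIN_DELAY)
        return self.delay


busy, idle = DelayTracker(), DelayTracker()
busy.pending = list(range(QUEUE_SIZE_THRESHOLD + 1))
busy_delay = busy.get_min_delay()
idle_delay = idle.get_min_delay()
```

With instance state, `busy` backs off by one increment while `idle` stays at the minimum delay and keeps an empty queue.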

@ -36,6 +36,10 @@ def get_platform(get_ip=True):
"lbryschema_version": lbryschema_version, "lbryschema_version": lbryschema_version,
"build": build_type.BUILD, # CI server sets this during build step "build": build_type.BUILD, # CI server sets this during build step
} }
if p["os_system"] == "Linux":
import distro
p["distro"] = distro.info()
p["desktop"] = os.environ.get('XDG_CURRENT_DESKTOP', 'Unknown')
# TODO: remove this from get_platform and add a get_external_ip function using treq # TODO: remove this from get_platform and add a get_external_ip function using treq
if get_ip: if get_ip:

@ -1,21 +1,23 @@
import base64 import base64
import datetime import datetime
import logging
import random import random
import socket import socket
import string import string
import json import json
import traceback
import functools
import logging
import pkg_resources import pkg_resources
from twisted.python.failure import Failure
from twisted.internet import defer from twisted.internet import defer
from lbryschema.claim import ClaimDict from lbryschema.claim import ClaimDict
from lbrynet.core.cryptoutils import get_lbry_hash_obj from lbrynet.core.cryptoutils import get_lbry_hash_obj
log = logging.getLogger(__name__)
# digest_size is in bytes, and blob hashes are hex encoded # digest_size is in bytes, and blob hashes are hex encoded
blobhash_length = get_lbry_hash_obj().digest_size * 2 blobhash_length = get_lbry_hash_obj().digest_size * 2
log = logging.getLogger(__name__)
# defining these time functions here allows for easier overriding in testing # defining these time functions here allows for easier overriding in testing
def now(): def now():
@ -172,3 +174,70 @@ def DeferredDict(d, consumeErrors=False):
if success: if success:
response[k] = result response[k] = result
defer.returnValue(response) defer.returnValue(response)
class DeferredProfiler(object):
def __init__(self):
self.profile_results = {}
def add_result(self, fn, start_time, finished_time, stack, success):
self.profile_results[fn].append((start_time, finished_time, stack, success))
def show_profile_results(self, fn):
profile_results = list(self.profile_results[fn])
call_counts = {
caller: [(start, finished, finished - start, success)
for (start, finished, _caller, success) in profile_results
if _caller == caller]
for caller in set(result[2] for result in profile_results)
}
log.info("called %s %i times from %i sources\n", fn.__name__, len(profile_results), len(call_counts))
for caller in sorted(list(call_counts.keys()), key=lambda c: len(call_counts[c]), reverse=True):
call_info = call_counts[caller]
times = [r[2] for r in call_info]
own_time = sum(times)
times.sort()
longest = 0 if not times else times[-1]
shortest = 0 if not times else times[0]
log.info(
"%i successes and %i failures\nlongest %f, shortest %f, avg %f\ncaller:\n%s",
len([r for r in call_info if r[3]]),
len([r for r in call_info if not r[3]]),
longest, shortest, own_time / float(len(call_info)), caller
)
def profiled_deferred(self, reactor=None):
if not reactor:
from twisted.internet import reactor
def _cb(result, fn, start, caller_info):
if isinstance(result, (Failure, Exception)):
error = result
result = None
else:
error = None
self.add_result(fn, start, reactor.seconds(), caller_info, error is None)
if error is None:
return result
raise error
def _profiled_deferred(fn):
reactor.addSystemEventTrigger("after", "shutdown", self.show_profile_results, fn)
self.profile_results[fn] = []
@functools.wraps(fn)
def _wrapper(*args, **kwargs):
caller_info = "".join(traceback.format_list(traceback.extract_stack()[-3:-1]))
start = reactor.seconds()
d = defer.maybeDeferred(fn, *args, **kwargs)
d.addBoth(_cb, fn, start, caller_info)
return d
return _wrapper
return _profiled_deferred
_profiler = DeferredProfiler()
profile_deferred = _profiler.profiled_deferred
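Note on `DeferredProfiler`: the decorator records, per call, the start time, the finish time, a short caller stack, and whether the call succeeded. The sketch below is a simplified synchronous analog using only the standard library. It is not the Twisted-based code from the commit; `CallProfiler`, `profiled`, and `summary` are hypothetical names chosen for illustration.

```python
import functools
import time
import traceback


class CallProfiler(object):
    """Synchronous analog of the profiler above (illustrative only)."""

    def __init__(self, clock=time.time):
        self.clock = clock
        self.results = {}

    def profiled(self, fn):
        self.results[fn.__name__] = []

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # The last stack entry is this wrapper; the one before it is the caller.
            caller = "".join(traceback.format_list(traceback.extract_stack()[-2:-1]))
            start = self.clock()
            success = False
            try:
                result = fn(*args, **kwargs)
                success = True
                return result
            finally:
                # Record the call whether it returned or raised.
                self.results[fn.__name__].append((start, self.clock(), caller, success))
        return wrapper

    def summary(self, name):
        calls = self.results[name]
        durations = sorted(finished - start for start, finished, _, _ in calls)
        return {
            "calls": len(calls),
            "failures": len([c for c in calls if not c[3]]),
            "longest": durations[-1] if durations else 0,
            "shortest": durations[0] if durations else 0,
        }


profiler = CallProfiler()


@profiler.profiled
def parse(value):
    return int(value)


parse("1")
try:
    parse("x")
except ValueError:
    pass
stats = profiler.summary("parse")
```

Passing the clock in, as the commit does with `reactor.seconds`, keeps the timing source replaceable in tests.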

@ -97,6 +97,9 @@ CONNECTION_MESSAGES = {
SHORT_ID_LEN = 20 SHORT_ID_LEN = 20
MAX_UPDATE_FEE_ESTIMATE = 0.3 MAX_UPDATE_FEE_ESTIMATE = 0.3
DIRECTION_ASCENDING = 'asc'
DIRECTION_DESCENDING = 'desc'
DIRECTIONS = DIRECTION_ASCENDING, DIRECTION_DESCENDING
class IterableContainer(object): class IterableContainer(object):
def __iter__(self): def __iter__(self):
@ -163,6 +166,11 @@ class AlwaysSend(object):
return d return d
def sort_claim_results(claims):
claims.sort(key=lambda d: (d['height'], d['name'], d['claim_id'], d['txid'], d['nout']))
return claims
class Daemon(AuthJSONRPCServer): class Daemon(AuthJSONRPCServer):
""" """
LBRYnet daemon, a jsonrpc interface to lbry functions LBRYnet daemon, a jsonrpc interface to lbry functions
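Note on `sort_claim_results`: the key is a tuple, so Python compares `height` first and falls back to `name`, `claim_id`, `txid`, and `nout` only to break ties. That is what makes the ordering predictable. A minimal sketch with made-up claim data (the field values are invented for illustration, and `sort_claims` returns a new list rather than sorting in place):

```python
def sort_claims(claims):
    # Tuples compare element by element, so later fields only break ties.
    return sorted(
        claims,
        key=lambda c: (c['height'], c['name'], c['claim_id'], c['txid'], c['nout'])
    )


claims = [
    {'height': 20, 'name': 'b', 'claim_id': 'c2', 'txid': 't2', 'nout': 0},
    {'height': 10, 'name': 'b', 'claim_id': 'c1', 'txid': 't1', 'nout': 1},
    {'height': 10, 'name': 'a', 'claim_id': 'c3', 'txid': 't3', 'nout': 0},
]
ordered = [c['claim_id'] for c in sort_claims(claims)]
```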
@ -199,7 +207,7 @@ class Daemon(AuthJSONRPCServer):
self.connected_to_internet = True self.connected_to_internet = True
self.connection_status_code = None self.connection_status_code = None
self.platform = None self.platform = None
self.current_db_revision = 8 self.current_db_revision = 9
self.db_revision_file = conf.settings.get_db_revision_filename() self.db_revision_file = conf.settings.get_db_revision_filename()
self.session = None self.session = None
self._session_id = conf.settings.get_session_id() self._session_id = conf.settings.get_session_id()
@ -939,6 +947,28 @@ class Daemon(AuthJSONRPCServer):
log.debug("Collected %i lbry files", len(lbry_files)) log.debug("Collected %i lbry files", len(lbry_files))
defer.returnValue(lbry_files) defer.returnValue(lbry_files)
def _sort_lbry_files(self, lbry_files, sort_by):
for field, direction in sort_by:
is_reverse = direction == DIRECTION_DESCENDING
key_getter = create_key_getter(field) if field else None
lbry_files = sorted(lbry_files, key=key_getter, reverse=is_reverse)
return lbry_files
def _parse_lbry_files_sort(self, sort):
"""
Given a sort string like 'file_name, desc' or 'points_paid',
parse the string into a tuple of (field, direction).
Direction defaults to ascending.
"""
pieces = [p.strip() for p in sort.split(',')]
field = pieces.pop(0)
direction = DIRECTION_ASCENDING
if pieces and pieces[0] in DIRECTIONS:
direction = pieces[0]
return field, direction
def _get_single_peer_downloader(self): def _get_single_peer_downloader(self):
downloader = SinglePeerDownloader() downloader = SinglePeerDownloader()
downloader.setup(self.session.wallet) downloader.setup(self.session.wallet)
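Note on `_sort_lbry_files`: the loop applies one `sorted` pass per `--sort` argument, in the order given. Because `sorted` is stable, the last key applied ends up as the most significant one. The sketch below iterates in reverse, on the assumption that the first `--sort` argument is meant to be the primary key. `parse_sort` and `sort_files` are hypothetical names, and plain dictionary lookup stands in for `create_key_getter`, which is not shown in this diff.

```python
ASC, DESC = 'asc', 'desc'


def parse_sort(sort):
    """Parse 'field' or 'field, direction' into (field, direction)."""
    pieces = [p.strip() for p in sort.split(',')]
    field = pieces.pop(0)
    direction = pieces[0] if pieces and pieces[0] in (ASC, DESC) else ASC
    return field, direction


def sort_files(files, sort_by):
    # sorted() is stable, so apply the least significant key first.
    for field, direction in reversed(sort_by):
        files = sorted(
            files,
            key=lambda f, field=field: f[field],
            reverse=(direction == DESC),
        )
    return files


files = [
    {'file_name': 'b', 'points_paid': 1.0},
    {'file_name': 'a', 'points_paid': 2.0},
    {'file_name': 'c', 'points_paid': 2.0},
]
sort_by = [parse_sort('points_paid, desc'), parse_sort('file_name')]
names = [f['file_name'] for f in sort_files(files, sort_by)]
```

Here `points_paid` decides the order and `file_name` only breaks the tie between the two files with equal `points_paid`.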
@ -1358,7 +1388,7 @@ class Daemon(AuthJSONRPCServer):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def jsonrpc_file_list(self, **kwargs): def jsonrpc_file_list(self, sort=None, **kwargs):
""" """
List files limited by optional filters List files limited by optional filters
@ -1366,7 +1396,7 @@ class Daemon(AuthJSONRPCServer):
file_list [--sd_hash=<sd_hash>] [--file_name=<file_name>] [--stream_hash=<stream_hash>] file_list [--sd_hash=<sd_hash>] [--file_name=<file_name>] [--stream_hash=<stream_hash>]
[--rowid=<rowid>] [--claim_id=<claim_id>] [--outpoint=<outpoint>] [--txid=<txid>] [--nout=<nout>] [--rowid=<rowid>] [--claim_id=<claim_id>] [--outpoint=<outpoint>] [--txid=<txid>] [--nout=<nout>]
[--channel_claim_id=<channel_claim_id>] [--channel_name=<channel_name>] [--channel_claim_id=<channel_claim_id>] [--channel_name=<channel_name>]
[--claim_name=<claim_name>] [--full_status] [--claim_name=<claim_name>] [--full_status] [--sort=<sort_method>...]
Options: Options:
--sd_hash=<sd_hash> : (str) get file with matching sd hash --sd_hash=<sd_hash> : (str) get file with matching sd hash
@ -1383,6 +1413,9 @@ class Daemon(AuthJSONRPCServer):
--claim_name=<claim_name> : (str) get file with matching claim name --claim_name=<claim_name> : (str) get file with matching claim name
--full_status : (bool) full status, populate the --full_status : (bool) full status, populate the
'message' and 'size' fields 'message' and 'size' fields
--sort=<sort_method> : (str) sort by any property, like 'file_name'
or 'metadata.author'; to specify direction
append ',asc' or ',desc'
Returns: Returns:
(list) List of files (list) List of files
@ -1419,6 +1452,9 @@ class Daemon(AuthJSONRPCServer):
""" """
result = yield self._get_lbry_files(return_json=True, **kwargs) result = yield self._get_lbry_files(return_json=True, **kwargs)
if sort:
sort_by = [self._parse_lbry_files_sort(s) for s in sort]
result = self._sort_lbry_files(result, sort_by)
response = yield self._render_response(result) response = yield self._render_response(result)
defer.returnValue(response) defer.returnValue(response)
@ -2330,7 +2366,8 @@ class Daemon(AuthJSONRPCServer):
} }
""" """
claims = yield self.session.wallet.get_claims_for_name(name) claims = yield self.session.wallet.get_claims_for_name(name) # type: dict
sort_claim_results(claims['claims'])
defer.returnValue(claims) defer.returnValue(claims)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -2866,24 +2903,22 @@ class Daemon(AuthJSONRPCServer):
if not utils.is_valid_blobhash(blob_hash): if not utils.is_valid_blobhash(blob_hash):
raise Exception("invalid blob hash") raise Exception("invalid blob hash")
finished_deferred = self.session.dht_node.getPeersForBlob(binascii.unhexlify(blob_hash), True) finished_deferred = self.session.dht_node.iterativeFindValue(binascii.unhexlify(blob_hash))
def _trigger_timeout(): def trap_timeout(err):
if not finished_deferred.called: err.trap(defer.TimeoutError)
log.debug("Peer search for %s timed out", blob_hash) return []
finished_deferred.cancel()
timeout = timeout or conf.settings['peer_search_timeout']
self.session.dht_node.reactor_callLater(timeout, _trigger_timeout)
finished_deferred.addTimeout(timeout or conf.settings['peer_search_timeout'], self.session.dht_node.clock)
finished_deferred.addErrback(trap_timeout)
peers = yield finished_deferred peers = yield finished_deferred
results = [ results = [
{ {
"node_id": node_id.encode('hex'),
"host": host, "host": host,
"port": port, "port": port
"node_id": node_id
} }
for host, port, node_id in peers for node_id, host, port in peers
] ]
defer.returnValue(results) defer.returnValue(results)
@ -3102,6 +3137,7 @@ class Daemon(AuthJSONRPCServer):
<bucket index>: [ <bucket index>: [
{ {
"address": (str) peer address, "address": (str) peer address,
"port": (int) peer udp port,
"node_id": (str) peer node id, "node_id": (str) peer node id,
"blobs": (list) blob hashes announced by peer "blobs": (list) blob hashes announced by peer
} }
@ -3114,18 +3150,13 @@ class Daemon(AuthJSONRPCServer):
""" """
result = {} result = {}
data_store = deepcopy(self.session.dht_node._dataStore._dict) data_store = self.session.dht_node._dataStore._dict
datastore_len = len(data_store) datastore_len = len(data_store)
hosts = {} hosts = {}
if datastore_len: if datastore_len:
for k, v in data_store.iteritems(): for k, v in data_store.iteritems():
for value, lastPublished, originallyPublished, originalPublisherID in v: for contact, value, lastPublished, originallyPublished, originalPublisherID in v:
try:
contact = self.session.dht_node._routingTable.getContact(
originalPublisherID)
except ValueError:
continue
if contact in hosts: if contact in hosts:
blobs = hosts[contact] blobs = hosts[contact]
else: else:
@ -3147,6 +3178,7 @@ class Daemon(AuthJSONRPCServer):
blobs = [] blobs = []
host = { host = {
"address": contact.address, "address": contact.address,
"port": contact.port,
"node_id": contact.id.encode("hex"), "node_id": contact.id.encode("hex"),
"blobs": blobs, "blobs": blobs,
} }
@ -3392,3 +3424,16 @@ def get_blob_payment_rate_manager(session, payment_rate_manager=None):
payment_rate_manager = rate_managers[payment_rate_manager] payment_rate_manager = rate_managers[payment_rate_manager]
log.info("Downloading blob with rate manager: %s", payment_rate_manager) log.info("Downloading blob with rate manager: %s", payment_rate_manager)
return payment_rate_manager or session.payment_rate_manager return payment_rate_manager or session.payment_rate_manager
def create_key_getter(field):
search_path = field.split('.')
def key_getter(value):
for key in search_path:
try:
value = value[key]
except KeyError as e:
errmsg = 'Failed to get "{}", key "{}" was not found.'
raise Exception(errmsg.format(field, e.message))
return value
return key_getter
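
The `--sort` option documented above accepts specs such as 'file_name' or 'metadata.author,desc'. The bodies of `_parse_lbry_files_sort` and `_sort_lbry_files` are not shown in this diff, so the following is only a minimal sketch of how a dotted-path key getter could drive such a sort. `parse_sort_spec` and `sort_files` are hypothetical names, and the sketch uses `e.args[0]` rather than `e.message` so that it also runs on Python 3.

```python
def create_key_getter(field):
    # Split 'metadata.author' into ['metadata', 'author'] and walk the dicts.
    search_path = field.split('.')

    def key_getter(value):
        for key in search_path:
            try:
                value = value[key]
            except KeyError as e:
                errmsg = 'Failed to get "{}", key "{}" was not found.'
                raise Exception(errmsg.format(field, e.args[0]))
        return value
    return key_getter


def parse_sort_spec(spec):
    # Hypothetical parser: 'field' or 'field,asc' or 'field,desc'.
    field, _, direction = spec.partition(',')
    direction = direction or 'asc'
    if direction not in ('asc', 'desc'):
        raise ValueError('Sort direction must be "asc" or "desc"')
    return field, direction


def sort_files(files, specs):
    # Apply the specs in reverse order; because sorted() is stable,
    # the first spec ends up as the primary sort key.
    for field, direction in reversed([parse_sort_spec(s) for s in specs]):
        files = sorted(files, key=create_key_getter(field), reverse=(direction == 'desc'))
    return files
```

Sorting once per spec keeps each direction independent, at the cost of one pass per sort key.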


@ -1,3 +1,12 @@
import os
import sys
# Set SSL_CERT_FILE env variable for Twisted SSL verification on Windows
# This needs to happen before anything else
if 'win' in sys.platform:
import certifi
os.environ['SSL_CERT_FILE'] = certifi.where()
from lbrynet.core import log_support from lbrynet.core import log_support
import argparse import argparse


@ -206,7 +206,12 @@ class CryptonatorFeed(MarketFeed):
class ExchangeRateManager(object): class ExchangeRateManager(object):
def __init__(self): def __init__(self):
self.market_feeds = [ self.market_feeds = [
LBRYioBTCFeed(), LBRYioFeed(), BittrexFeed(), CryptonatorBTCFeed(), CryptonatorFeed()] LBRYioBTCFeed(),
LBRYioFeed(),
BittrexFeed(),
# CryptonatorBTCFeed(),
# CryptonatorFeed()
]
def start(self): def start(self):
log.info("Starting exchange rate manager") log.info("Starting exchange rate manager")


@ -13,7 +13,7 @@ from txjsonrpc import jsonrpclib
from traceback import format_exc from traceback import format_exc
from lbrynet import conf from lbrynet import conf
from lbrynet.core.Error import InvalidAuthenticationToken, InvalidHeaderError from lbrynet.core.Error import InvalidAuthenticationToken
from lbrynet.core import utils from lbrynet.core import utils
from lbrynet.daemon.auth.util import APIKey, get_auth_message from lbrynet.daemon.auth.util import APIKey, get_auth_message
from lbrynet.daemon.auth.client import LBRY_SECRET from lbrynet.daemon.auth.client import LBRY_SECRET
@ -231,9 +231,9 @@ class AuthJSONRPCServer(AuthorizedBase):
def _render(self, request): def _render(self, request):
time_in = utils.now() time_in = utils.now()
if not self._check_headers(request): # if not self._check_headers(request):
self._render_error(Failure(InvalidHeaderError()), request, None) # self._render_error(Failure(InvalidHeaderError()), request, None)
return server.NOT_DONE_YET # return server.NOT_DONE_YET
session = request.getSession() session = request.getSession()
session_id = session.uid session_id = session.uid
finished_deferred = request.notifyFinish() finished_deferred = request.notifyFinish()


@ -1,5 +1,7 @@
import logging import logging
log = logging.getLogger(__name__)
def migrate_db(db_dir, start, end): def migrate_db(db_dir, start, end):
current = start current = start
@ -18,11 +20,14 @@ def migrate_db(db_dir, start, end):
from lbrynet.database.migrator.migrate6to7 import do_migration from lbrynet.database.migrator.migrate6to7 import do_migration
elif current == 7: elif current == 7:
from lbrynet.database.migrator.migrate7to8 import do_migration from lbrynet.database.migrator.migrate7to8 import do_migration
elif current == 8:
from lbrynet.database.migrator.migrate8to9 import do_migration
else: else:
raise Exception("DB migration of version {} to {} is not available".format(current, raise Exception("DB migration of version {} to {} is not available".format(current,
current+1)) current+1))
do_migration(db_dir) do_migration(db_dir)
current += 1 current += 1
log.info("successfully migrated the database from revision %i to %i", current - 1, current)
return None return None
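
The migrator above advances one revision at a time and fails when a step is missing. As a rough sketch of the same control flow, the `elif` chain can be replaced by a lookup table. The `migrations` argument is a hypothetical mapping from a revision number to a callable, and the loop condition `current < end` is assumed, since that part of the function lies outside the hunks shown.

```python
def migrate_db(db_dir, start, end, migrations):
    # migrations: {revision: callable(db_dir)} upgrading revision -> revision + 1
    current = start
    while current < end:
        if current not in migrations:
            raise Exception("DB migration of version {} to {} is not available".format(
                current, current + 1))
        migrations[current](db_dir)
        current += 1
    return current
```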


@ -4,12 +4,73 @@ import json
import logging import logging
from lbryschema.decode import smart_decode from lbryschema.decode import smart_decode
from lbrynet import conf from lbrynet import conf
from lbrynet.database.storage import SQLiteStorage
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
download_directory = conf.settings['download_directory'] download_directory = conf.settings['download_directory']
CREATE_TABLES_QUERY = """
pragma foreign_keys=on;
pragma journal_mode=WAL;
create table if not exists blob (
blob_hash char(96) primary key not null,
blob_length integer not null,
next_announce_time integer not null,
should_announce integer not null default 0,
status text not null
);
create table if not exists stream (
stream_hash char(96) not null primary key,
sd_hash char(96) not null references blob,
stream_key text not null,
stream_name text not null,
suggested_filename text not null
);
create table if not exists stream_blob (
stream_hash char(96) not null references stream,
blob_hash char(96) references blob,
position integer not null,
iv char(32) not null,
primary key (stream_hash, blob_hash)
);
create table if not exists claim (
claim_outpoint text not null primary key,
claim_id char(40) not null,
claim_name text not null,
amount integer not null,
height integer not null,
serialized_metadata blob not null,
channel_claim_id text,
address text not null,
claim_sequence integer not null
);
create table if not exists file (
stream_hash text primary key not null references stream,
file_name text not null,
download_directory text not null,
blob_data_rate real not null,
status text not null
);
create table if not exists content_claim (
stream_hash text unique not null references file,
claim_outpoint text not null references claim,
primary key (stream_hash, claim_outpoint)
);
create table if not exists support (
support_outpoint text not null primary key,
claim_id text not null,
amount integer not null,
address text not null
);
"""
def run_operation(db): def run_operation(db):
def _decorate(fn): def _decorate(fn):
@ -148,7 +209,7 @@ def do_migration(db_dir):
@run_operation(connection) @run_operation(connection)
def _make_db(new_db): def _make_db(new_db):
# create the new tables # create the new tables
new_db.executescript(SQLiteStorage.CREATE_TABLES_QUERY) new_db.executescript(CREATE_TABLES_QUERY)
# first migrate the blobs # first migrate the blobs
blobs = blobs_db_cursor.execute("select * from blobs").fetchall() blobs = blobs_db_cursor.execute("select * from blobs").fetchall()
@ -245,13 +306,20 @@ def do_migration(db_dir):
continue continue
log.info("migrated %i content claims", new_db.execute("select count(*) from content_claim").fetchone()[0]) log.info("migrated %i content claims", new_db.execute("select count(*) from content_claim").fetchone()[0])
try:
_make_db() # pylint: disable=no-value-for-parameter
except sqlite3.OperationalError as err:
if err.message == "table blob has 7 columns but 5 values were supplied":
log.warning("detected a failed previous migration to revision 6, repairing it")
connection.close()
os.remove(new_db_path)
return do_migration(db_dir)
raise err
_make_db() # pylint: disable=no-value-for-parameter
connection.close() connection.close()
blobs_db.close() blobs_db.close()
lbryfile_db.close() lbryfile_db.close()
metadata_db.close() metadata_db.close()
log.info("successfully migrated the database")
# os.remove(os.path.join(db_dir, "blockchainname.db")) # os.remove(os.path.join(db_dir, "blockchainname.db"))
# os.remove(os.path.join(db_dir, 'lbryfile_info.db')) # os.remove(os.path.join(db_dir, 'lbryfile_info.db'))
# os.remove(os.path.join(db_dir, 'blobs.db')) # os.remove(os.path.join(db_dir, 'blobs.db'))


@ -0,0 +1,54 @@
import sqlite3
import logging
import os
from lbrynet.core.Error import InvalidStreamDescriptorError
from lbrynet.core.StreamDescriptor import EncryptedFileStreamType, format_sd_info, format_blobs, validate_descriptor
from lbrynet.cryptstream.CryptBlob import CryptBlobInfo
log = logging.getLogger(__name__)
def do_migration(db_dir):
db_path = os.path.join(db_dir, "lbrynet.sqlite")
blob_dir = os.path.join(db_dir, "blobfiles")
connection = sqlite3.connect(db_path)
cursor = connection.cursor()
query = "select stream_name, stream_key, suggested_filename, sd_hash, stream_hash from stream"
streams = cursor.execute(query).fetchall()
blobs = cursor.execute("select s.stream_hash, s.position, s.iv, b.blob_hash, b.blob_length from stream_blob s "
"left outer join blob b ON b.blob_hash=s.blob_hash order by s.position").fetchall()
blobs_by_stream = {}
for stream_hash, position, iv, blob_hash, blob_length in blobs:
blobs_by_stream.setdefault(stream_hash, []).append(CryptBlobInfo(blob_hash, position, blob_length or 0, iv))
for stream_name, stream_key, suggested_filename, sd_hash, stream_hash in streams:
sd_info = format_sd_info(
EncryptedFileStreamType, stream_name, stream_key,
suggested_filename, stream_hash, format_blobs(blobs_by_stream[stream_hash])
)
try:
validate_descriptor(sd_info)
except InvalidStreamDescriptorError as err:
log.warning("Stream for descriptor %s is invalid (%s), cleaning it up",
sd_hash, err.message)
blob_hashes = [blob.blob_hash for blob in blobs_by_stream[stream_hash]]
delete_stream(cursor, stream_hash, sd_hash, blob_hashes, blob_dir)
connection.commit()
connection.close()
def delete_stream(transaction, stream_hash, sd_hash, blob_hashes, blob_dir):
transaction.execute("delete from content_claim where stream_hash=? ", (stream_hash,))
transaction.execute("delete from file where stream_hash=? ", (stream_hash, ))
transaction.execute("delete from stream_blob where stream_hash=?", (stream_hash, ))
transaction.execute("delete from stream where stream_hash=? ", (stream_hash, ))
transaction.execute("delete from blob where blob_hash=?", (sd_hash, ))
for blob_hash in blob_hashes:
transaction.execute("delete from blob where blob_hash=?", (blob_hash, ))
file_path = os.path.join(blob_dir, blob_hash)
if os.path.isfile(file_path):
os.unlink(file_path)


@ -208,27 +208,15 @@ class SQLiteStorage(object):
# # # # # # # # # blob functions # # # # # # # # # # # # # # # # # # blob functions # # # # # # # # #
@defer.inlineCallbacks def add_completed_blob(self, blob_hash, length, next_announce_time, should_announce, status="finished"):
def add_completed_blob(self, blob_hash, length, next_announce_time, should_announce):
log.debug("Adding a completed blob. blob_hash=%s, length=%i", blob_hash, length) log.debug("Adding a completed blob. blob_hash=%s, length=%i", blob_hash, length)
yield self.add_known_blob(blob_hash, length) values = (blob_hash, length, next_announce_time or 0, int(bool(should_announce)), status, 0, 0)
yield self.set_blob_status(blob_hash, "finished") return self.db.runOperation("insert or replace into blob values (?, ?, ?, ?, ?, ?, ?)", values)
yield self.set_should_announce(blob_hash, next_announce_time, should_announce)
yield self.db.runOperation(
"update blob set blob_length=? where blob_hash=?", (length, blob_hash)
)
def set_should_announce(self, blob_hash, next_announce_time, should_announce): def set_should_announce(self, blob_hash, next_announce_time, should_announce):
next_announce_time = next_announce_time or 0
should_announce = 1 if should_announce else 0
return self.db.runOperation( return self.db.runOperation(
"update blob set next_announce_time=?, should_announce=? where blob_hash=?", "update blob set next_announce_time=?, should_announce=? where blob_hash=?",
(next_announce_time, should_announce, blob_hash) (next_announce_time or 0, int(bool(should_announce)), blob_hash)
)
def set_blob_status(self, blob_hash, status):
return self.db.runOperation(
"update blob set status=? where blob_hash=?", (status, blob_hash)
) )
def get_blob_status(self, blob_hash): def get_blob_status(self, blob_hash):
@ -552,7 +540,7 @@ class SQLiteStorage(object):
) )
return self.db.runInteraction(_save_support) return self.db.runInteraction(_save_support)
def get_supports(self, claim_id): def get_supports(self, *claim_ids):
def _format_support(outpoint, supported_id, amount, address): def _format_support(outpoint, supported_id, amount, address):
return { return {
"txid": outpoint.split(":")[0], "txid": outpoint.split(":")[0],
@ -563,10 +551,15 @@ class SQLiteStorage(object):
} }
def _get_supports(transaction): def _get_supports(transaction):
if len(claim_ids) == 1:
bind = "=?"
else:
bind = "in ({})".format(','.join('?' for _ in range(len(claim_ids))))
return [ return [
_format_support(*support_info) _format_support(*support_info)
for support_info in transaction.execute( for support_info in transaction.execute(
"select * from support where claim_id=?", (claim_id, ) "select * from support where claim_id {}".format(bind),
tuple(claim_ids)
).fetchall() ).fetchall()
] ]
@ -683,51 +676,82 @@ class SQLiteStorage(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_content_claim(self, stream_hash, include_supports=True): def get_content_claim(self, stream_hash, include_supports=True):
def _get_content_claim(transaction): def _get_claim_from_stream_hash(transaction):
claim_id = transaction.execute( claim_info = transaction.execute(
"select claim.claim_outpoint from content_claim " "select c.*, "
"inner join claim on claim.claim_outpoint=content_claim.claim_outpoint and content_claim.stream_hash=? " "case when c.channel_claim_id is not null then "
"order by claim.rowid desc", (stream_hash, ) "(select claim_name from claim where claim_id==c.channel_claim_id) "
"else null end as channel_name from content_claim "
"inner join claim c on c.claim_outpoint=content_claim.claim_outpoint "
"and content_claim.stream_hash=? order by c.rowid desc", (stream_hash,)
).fetchone() ).fetchone()
if not claim_id: if not claim_info:
return None return None
return claim_id[0] channel_name = claim_info[-1]
result = _format_claim_response(*claim_info[:-1])
if channel_name:
result['channel_name'] = channel_name
return result
content_claim_outpoint = yield self.db.runInteraction(_get_content_claim) result = yield self.db.runInteraction(_get_claim_from_stream_hash)
result = None if result and include_supports:
if content_claim_outpoint: supports = yield self.get_supports(result['claim_id'])
result = yield self.get_claim(content_claim_outpoint, include_supports) result['supports'] = supports
result['effective_amount'] = float(
sum([support['amount'] for support in supports]) + result['amount']
)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_claim(self, claim_outpoint, include_supports=True): def get_claims_from_stream_hashes(self, stream_hashes, include_supports=True):
def _claim_response(outpoint, claim_id, name, amount, height, serialized, channel_id, address, claim_sequence): def _batch_get_claim(transaction):
r = { results = {}
"name": name, bind = "({})".format(','.join('?' for _ in range(len(stream_hashes))))
"claim_id": claim_id, claim_infos = transaction.execute(
"address": address, "select content_claim.stream_hash, c.*, "
"claim_sequence": claim_sequence, "case when c.channel_claim_id is not null then "
"value": ClaimDict.deserialize(serialized.decode('hex')).claim_dict, "(select claim_name from claim where claim_id==c.channel_claim_id) "
"height": height, "else null end as channel_name from content_claim "
"amount": float(Decimal(amount) / Decimal(COIN)), "inner join claim c on c.claim_outpoint=content_claim.claim_outpoint "
"nout": int(outpoint.split(":")[1]), "and content_claim.stream_hash in {} order by c.rowid desc".format(bind),
"txid": outpoint.split(":")[0], tuple(stream_hashes)
"channel_claim_id": channel_id, ).fetchall()
"channel_name": None for claim_info in claim_infos:
} channel_name = claim_info[-1]
return r stream_hash = claim_info[0]
result = _format_claim_response(*claim_info[1:-1])
if channel_name:
result['channel_name'] = channel_name
results[stream_hash] = result
return results
claims = yield self.db.runInteraction(_batch_get_claim)
if include_supports:
all_supports = {}
for support in (yield self.get_supports(*[claim['claim_id'] for claim in claims.values()])):
all_supports.setdefault(support['claim_id'], []).append(support)
for stream_hash in claims.keys():
claim = claims[stream_hash]
supports = all_supports.get(claim['claim_id'], [])
claim['supports'] = supports
claim['effective_amount'] = float(
sum([support['amount'] for support in supports]) + claim['amount']
)
claims[stream_hash] = claim
defer.returnValue(claims)
@defer.inlineCallbacks
def get_claim(self, claim_outpoint, include_supports=True):
def _get_claim(transaction): def _get_claim(transaction):
claim_info = transaction.execute( claim_info = transaction.execute("select c.*, "
"select * from claim where claim_outpoint=?", (claim_outpoint, ) "case when c.channel_claim_id is not null then "
).fetchone() "(select claim_name from claim where claim_id==c.channel_claim_id) "
result = _claim_response(*claim_info) "else null end as channel_name from claim c where claim_outpoint = ?",
if result['channel_claim_id']: (claim_outpoint,)).fetchone()
channel_name_result = transaction.execute( channel_name = claim_info[-1]
"select claim_name from claim where claim_id=?", (result['channel_claim_id'], ) result = _format_claim_response(*claim_info[:-1])
).fetchone() if channel_name:
if channel_name_result: result['channel_name'] = channel_name
result['channel_name'] = channel_name_result[0]
return result return result
result = yield self.db.runInteraction(_get_claim) result = yield self.db.runInteraction(_get_claim)
@ -793,3 +817,21 @@ class SQLiteStorage(object):
"where r.timestamp is null or r.timestamp < ?", "where r.timestamp is null or r.timestamp < ?",
self.clock.seconds() - conf.settings['auto_re_reflect_interval'] self.clock.seconds() - conf.settings['auto_re_reflect_interval']
) )
# Helper functions
def _format_claim_response(outpoint, claim_id, name, amount, height, serialized, channel_id, address, claim_sequence):
r = {
"name": name,
"claim_id": claim_id,
"address": address,
"claim_sequence": claim_sequence,
"value": ClaimDict.deserialize(serialized.decode('hex')).claim_dict,
"height": height,
"amount": float(Decimal(amount) / Decimal(COIN)),
"nout": int(outpoint.split(":")[1]),
"txid": outpoint.split(":")[0],
"channel_claim_id": channel_id,
"channel_name": None
}
return r
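
The reworked `get_supports` builds its `where` clause from the number of claim ids it receives, so several claims can be fetched in one query. Below is a minimal, standalone sketch of that placeholder technique against an in-memory SQLite database. It uses the plain `sqlite3` module instead of the storage class's `runInteraction`; `make_demo_db`, the empty-argument guard and the `order by` clause are additions made for the example.

```python
import sqlite3


def get_supports(connection, *claim_ids):
    # Guard added for this sketch: skip the query when there is nothing to look up.
    if not claim_ids:
        return []
    if len(claim_ids) == 1:
        bind = "=?"
    else:
        # One '?' per id, so every value is still passed as a bound parameter.
        bind = "in ({})".format(','.join('?' for _ in claim_ids))
    query = "select * from support where claim_id {} order by support_outpoint".format(bind)
    return connection.execute(query, tuple(claim_ids)).fetchall()


def make_demo_db():
    # Hypothetical fixture using the support table layout from the schema in this commit.
    connection = sqlite3.connect(":memory:")
    connection.execute(
        "create table support ("
        "support_outpoint text not null primary key, "
        "claim_id text not null, "
        "amount integer not null, "
        "address text not null)"
    )
    connection.executemany("insert into support values (?, ?, ?, ?)", [
        ("tx1:0", "claim_a", 100, "addr1"),
        ("tx2:0", "claim_a", 50, "addr2"),
        ("tx3:1", "claim_b", 25, "addr3"),
    ])
    return connection
```

Only the placeholder list is formatted into the SQL string; the ids themselves never are, which keeps the query safe from injection.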


@ -43,21 +43,17 @@ dataExpireTimeout = 86400 # 24 hours
tokenSecretChangeInterval = 300 # 5 minutes tokenSecretChangeInterval = 300 # 5 minutes
peer_request_timeout = 10
######## IMPLEMENTATION-SPECIFIC CONSTANTS ########### ######## IMPLEMENTATION-SPECIFIC CONSTANTS ###########
#: The interval in which the node should check its whether any buckets need refreshing, #: The interval for the node to check whether any buckets need refreshing
#: or whether any data needs to be republished (in seconds)
checkRefreshInterval = refreshTimeout / 5 checkRefreshInterval = refreshTimeout / 5
#: Max size of a single UDP datagram, in bytes. If a message is larger than this, it will #: Max size of a single UDP datagram, in bytes. If a message is larger than this, it will
#: be spread across several UDP packets. #: be spread across several UDP packets.
udpDatagramMaxSize = 8192 # 8 KB udpDatagramMaxSize = 8192 # 8 KB
from lbrynet.core.cryptoutils import get_lbry_hash_obj key_bits = 384
h = get_lbry_hash_obj()
key_bits = h.digest_size * 8 # 384 bits
rpc_id_length = 20 rpc_id_length = 20
protocolVersion = 1


@ -1,19 +1,101 @@
class Contact(object): import ipaddress
from lbrynet.dht import constants
def is_valid_ipv4(address):
try:
ip = ipaddress.ip_address(address.decode()) # this needs to be unicode, thus the decode()
return ip.version == 4
except ValueError:  # ip_address() raises ValueError for malformed input
return False
class _Contact(object):
""" Encapsulation for remote contact """ Encapsulation for remote contact
This class contains information on a single remote contact, and also This class contains information on a single remote contact, and also
provides a direct RPC API to the remote node which it represents provides a direct RPC API to the remote node which it represents
""" """
def __init__(self, id, ipAddress, udpPort, networkProtocol, firstComm=0): def __init__(self, contactManager, id, ipAddress, udpPort, networkProtocol, firstComm):
self.id = id if id is not None:
if not len(id) == constants.key_bits / 8:
raise ValueError("invalid node id: %s" % id.encode('hex'))
if not 0 <= udpPort <= 65535:
raise ValueError("invalid port")
if not is_valid_ipv4(ipAddress):
raise ValueError("invalid ip address")
self._contactManager = contactManager
self._id = id
self.address = ipAddress self.address = ipAddress
self.port = udpPort self.port = udpPort
self._networkProtocol = networkProtocol self._networkProtocol = networkProtocol
self.commTime = firstComm self.commTime = firstComm
self.getTime = self._contactManager._get_time
self.lastReplied = None
self.lastRequested = None
self.protocolVersion = 0
self._token = (None, 0) # token, timestamp
def update_token(self, token):
self._token = token, self.getTime()
@property
def token(self):
# expire the token 1 minute early to be safe
return self._token[0] if self._token[1] + 240 > self.getTime() else None
@property
def lastInteracted(self):
return max(self.lastRequested or 0, self.lastReplied or 0, self.lastFailed or 0)
@property
def id(self):
return self._id
def log_id(self, short=True):
if not self.id:
return "not initialized"
id_hex = self.id.encode('hex')
return id_hex if not short else id_hex[:8]
@property
def failedRPCs(self):
return len(self.failures)
@property
def lastFailed(self):
return self._contactManager._rpc_failures.get((self.address, self.port), [None])[-1]
@property
def failures(self):
return self._contactManager._rpc_failures.get((self.address, self.port), [])
@property
def contact_is_good(self):
"""
:return: False if contact is bad, None if contact is unknown, or True if contact is good
"""
failures = self.failures
now = self.getTime()
delay = constants.checkRefreshInterval
if failures:
if self.lastReplied and len(failures) >= 2 and self.lastReplied < failures[-2]:
return False
elif self.lastReplied and len(failures) >= 2 and self.lastReplied > failures[-2]:
pass # handled below
elif len(failures) >= 2:
return False
if self.lastReplied and self.lastReplied > now - delay:
return True
if self.lastReplied and self.lastRequested and self.lastRequested > now - delay:
return True
return None
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Contact): if isinstance(other, _Contact):
return self.id == other.id return self.id == other.id
elif isinstance(other, str): elif isinstance(other, str):
return self.id == other return self.id == other
@ -21,7 +103,7 @@ class Contact(object):
return False return False
def __ne__(self, other): def __ne__(self, other):
if isinstance(other, Contact): if isinstance(other, _Contact):
return self.id != other.id return self.id != other.id
elif isinstance(other, str): elif isinstance(other, str):
return self.id != other return self.id != other
@ -33,6 +115,24 @@ class Contact(object):
lambda buff, x: buff + bytearray([int(x)]), self.address.split('.'), bytearray()) lambda buff, x: buff + bytearray([int(x)]), self.address.split('.'), bytearray())
return str(compact_ip) return str(compact_ip)
def set_id(self, id):
if not self._id:
self._id = id
def update_last_replied(self):
self.lastReplied = int(self.getTime())
def update_last_requested(self):
self.lastRequested = int(self.getTime())
def update_last_failed(self):
failures = self._contactManager._rpc_failures.get((self.address, self.port), [])
failures.append(self.getTime())
self._contactManager._rpc_failures[(self.address, self.port)] = failures
def update_protocol_version(self, version):
self.protocolVersion = version
def __str__(self): def __str__(self):
return '<%s.%s object; IP address: %s, UDP port: %d>' % ( return '<%s.%s object; IP address: %s, UDP port: %d>' % (
self.__module__, self.__class__.__name__, self.address, self.port) self.__module__, self.__class__.__name__, self.address, self.port)
@ -52,7 +152,38 @@ class Contact(object):
host Node's C{_protocol} object). host Node's C{_protocol} object).
""" """
if name not in ['ping', 'findValue', 'findNode', 'store']:
raise AttributeError("unknown command: %s" % name)
def _sendRPC(*args, **kwargs): def _sendRPC(*args, **kwargs):
return self._networkProtocol.sendRPC(self, name, args, **kwargs) return self._networkProtocol.sendRPC(self, name, args)
return _sendRPC return _sendRPC
class ContactManager(object):
def __init__(self, get_time=None):
if not get_time:
from twisted.internet import reactor
get_time = reactor.seconds
self._get_time = get_time
self._contacts = {}
self._rpc_failures = {}
def get_contact(self, id, address, port):
for contact in self._contacts.itervalues():
if contact.id == id and contact.address == address and contact.port == port:
return contact
def make_contact(self, id, ipAddress, udpPort, networkProtocol, firstComm=0):
ipAddress = str(ipAddress)
contact = self.get_contact(id, ipAddress, udpPort)
if contact:
return contact
contact = _Contact(self, id, ipAddress, udpPort, networkProtocol, firstComm or self._get_time())
self._contacts[(id, ipAddress, udpPort)] = contact
return contact
def is_ignored(self, origin_tuple):
failed_rpc_count = len(self._rpc_failures.get(origin_tuple, []))
return failed_rpc_count > constants.rpcAttempts
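
The `token` property on `_Contact` treats a stored token as usable for 240 seconds, one minute less than `tokenSecretChangeInterval` (300 seconds), and reads the time through an injected getter instead of calling the reactor directly. The sketch below isolates that pattern. `FakeClock` and `TokenHolder` are hypothetical stand-ins, not classes from this commit.

```python
class FakeClock(object):
    """Hypothetical manual clock, standing in for reactor.seconds in tests."""

    def __init__(self):
        self.now = 0.0

    def seconds(self):
        return self.now

    def advance(self, amount):
        self.now += amount


class TokenHolder(object):
    # 300 second secret rotation minus a 60 second safety margin
    TOKEN_LIFETIME = 300 - 60

    def __init__(self, get_time):
        self._get_time = get_time
        self._token = (None, 0)  # token, timestamp

    def update_token(self, token):
        self._token = token, self._get_time()

    @property
    def token(self):
        value, stored_at = self._token
        # Report the token as missing once it is close to expiring remotely.
        return value if stored_at + self.TOKEN_LIFETIME > self._get_time() else None
```

Passing the time source in as a callable is what lets tests advance time deterministically.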


@ -1,5 +1,4 @@
import UserDict import UserDict
import time
import constants import constants
from interface import IDataStore from interface import IDataStore
from zope.interface import implements from zope.interface import implements
@ -9,44 +8,61 @@ class DictDataStore(UserDict.DictMixin):
""" A datastore using an in-memory Python dictionary """ """ A datastore using an in-memory Python dictionary """
implements(IDataStore) implements(IDataStore)
def __init__(self): def __init__(self, getTime=None):
# Dictionary format: # Dictionary format:
# { <key>: (<value>, <lastPublished>, <originallyPublished> <originalPublisherID>) } # { <key>: (<contact>, <value>, <lastPublished>, <originallyPublished>, <originalPublisherID>) }
self._dict = {} self._dict = {}
if not getTime:
from twisted.internet import reactor
getTime = reactor.seconds
self._getTime = getTime
def keys(self): def keys(self):
""" Return a list of the keys in this data store """ """ Return a list of the keys in this data store """
return self._dict.keys() return self._dict.keys()
def filter_bad_and_expired_peers(self, key):
"""
Returns only non-expired and unknown/good peers
"""
return filter(
lambda peer:
self._getTime() - peer[3] < constants.dataExpireTimeout and peer[0].contact_is_good is not False,
self._dict[key]
)
def filter_expired_peers(self, key):
"""
Returns only non-expired peers
"""
return filter(lambda peer: self._getTime() - peer[3] < constants.dataExpireTimeout, self._dict[key])
def removeExpiredPeers(self): def removeExpiredPeers(self):
now = int(time.time())
def notExpired(peer):
if (now - peer[2]) > constants.dataExpireTimeout:
return False
return True
for key in self._dict.keys(): for key in self._dict.keys():
unexpired_peers = filter(notExpired, self._dict[key]) unexpired_peers = self.filter_expired_peers(key)
self._dict[key] = unexpired_peers if not unexpired_peers:
del self._dict[key]
else:
self._dict[key] = unexpired_peers
def hasPeersForBlob(self, key): def hasPeersForBlob(self, key):
if key in self._dict and len(self._dict[key]) > 0: return True if key in self._dict and len(self.filter_bad_and_expired_peers(key)) else False
return True
return False
def addPeerToBlob(self, key, value, lastPublished, originallyPublished, originalPublisherID): def addPeerToBlob(self, contact, key, compact_address, lastPublished, originallyPublished, originalPublisherID):
if key in self._dict: if key in self._dict:
self._dict[key].append((value, lastPublished, originallyPublished, originalPublisherID)) if compact_address not in map(lambda store_tuple: store_tuple[1], self._dict[key]):
self._dict[key].append(
(contact, compact_address, lastPublished, originallyPublished, originalPublisherID)
)
else: else:
self._dict[key] = [(value, lastPublished, originallyPublished, originalPublisherID)] self._dict[key] = [(contact, compact_address, lastPublished, originallyPublished, originalPublisherID)]
def getPeersForBlob(self, key): def getPeersForBlob(self, key):
if key in self._dict: return [] if key not in self._dict else [val[1] for val in self.filter_bad_and_expired_peers(key)]
return [val[0] for val in self._dict[key]]
def removePeer(self, value): def getStoringContacts(self):
contacts = set()
for key in self._dict: for key in self._dict:
self._dict[key] = [val for val in self._dict[key] if val[0] != value] for values in self._dict[key]:
if not self._dict[key]: contacts.add(values[0])
del self._dict[key] return list(contacts)
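
After this change each stored entry is a tuple of (contact, compact address, lastPublished, originallyPublished, originalPublisherID), and expiry is judged on index 3 against `dataExpireTimeout`. A minimal sketch of that filtering on a plain dictionary follows. The function names are illustrative and the current time is passed in explicitly, whereas the class obtains it from its `getTime` callable.

```python
DATA_EXPIRE_TIMEOUT = 86400  # 24 hours, mirroring constants.dataExpireTimeout


def filter_expired_peers(peers, now, timeout=DATA_EXPIRE_TIMEOUT):
    # peer[3] is originallyPublished in the new tuple layout
    return [peer for peer in peers if now - peer[3] < timeout]


def remove_expired_peers(store, now, timeout=DATA_EXPIRE_TIMEOUT):
    # Copy the keys first, because entries may be deleted while looping.
    for key in list(store.keys()):
        unexpired = filter_expired_peers(store[key], now, timeout)
        if unexpired:
            store[key] = unexpired
        else:
            del store[key]
```

Deleting keys whose peer list becomes empty keeps the store from accumulating dead entries.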


@ -1,3 +1,6 @@
from lbrynet.dht import constants
class Distance(object): class Distance(object):
"""Calculate the XOR result between two string variables. """Calculate the XOR result between two string variables.
@ -6,6 +9,8 @@ class Distance(object):
""" """
def __init__(self, key): def __init__(self, key):
if len(key) != constants.key_bits / 8:
raise ValueError("invalid key length: %i" % len(key))
self.key = key self.key = key
self.val_key_one = long(key.encode('hex'), 16) self.val_key_one = long(key.encode('hex'), 16)
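
`Distance` now rejects keys that are not `key_bits / 8` (48) bytes long before converting them to integers. The sketch below shows the same idea as a standalone function: the XOR metric on two fixed-length keys. `xor_distance` is a hypothetical name, and `binascii.hexlify` is used in place of `key.encode('hex')` so the example also runs on Python 3.

```python
import binascii

KEY_BITS = 384  # mirrors constants.key_bits


def xor_distance(key_one, key_two):
    for key in (key_one, key_two):
        if len(key) != KEY_BITS // 8:
            raise ValueError("invalid key length: %i" % len(key))
    # Interpret each key as a big-endian integer, then XOR the two values.
    val_one = int(binascii.hexlify(key_one), 16)
    val_two = int(binascii.hexlify(key_two), 16)
    return val_one ^ val_two
```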


@ -33,6 +33,13 @@ class TimeoutError(Exception):
def __init__(self, remote_contact_id): def __init__(self, remote_contact_id):
# remote_contact_id is a binary blob so we need to convert it # remote_contact_id is a binary blob so we need to convert it
# into something more readable # into something more readable
msg = 'Timeout connecting to {}'.format(binascii.hexlify(remote_contact_id)) if remote_contact_id:
msg = 'Timeout connecting to {}'.format(binascii.hexlify(remote_contact_id))
else:
msg = 'Timeout connecting to uninitialized node'
Exception.__init__(self, msg) Exception.__init__(self, msg)
self.remote_contact_id = remote_contact_id self.remote_contact_id = remote_contact_id
class TransportNotConnected(Exception):
pass
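The guard added to `TimeoutError.__init__` matters because `binascii.hexlify(None)` raises a `TypeError`, and elsewhere in this commit bootstrap contacts are created with a node id of `None`. A minimal Python 3 sketch of the same idea follows; the class name `LookupTimeout` is hypothetical and is not part of the commit.

```python
import binascii


class LookupTimeout(Exception):
    """Hypothetical stand-in for the TimeoutError class in the diff."""

    def __init__(self, remote_contact_id):
        if remote_contact_id:
            # The id is raw bytes, so hex-encode it for readability
            hex_id = binascii.hexlify(remote_contact_id).decode("ascii")
            msg = "Timeout connecting to {}".format(hex_id)
        else:
            # No node id is known yet, so there is nothing to hex-encode
            msg = "Timeout connecting to uninitialized node"
        Exception.__init__(self, msg)
        self.remote_contact_id = remote_contact_id
```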


@ -18,6 +18,7 @@ class DHTHashAnnouncer(object):
self.concurrent_announcers = concurrent_announcers or conf.settings['concurrent_announcers'] self.concurrent_announcers = concurrent_announcers or conf.settings['concurrent_announcers']
self._manage_lc = task.LoopingCall(self.manage) self._manage_lc = task.LoopingCall(self.manage)
self._manage_lc.clock = self.clock self._manage_lc.clock = self.clock
self.sem = defer.DeferredSemaphore(self.concurrent_announcers)
def start(self): def start(self):
self._manage_lc.start(30) self._manage_lc.start(30)
@ -50,13 +51,14 @@ class DHTHashAnnouncer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def immediate_announce(self, blob_hashes): def immediate_announce(self, blob_hashes):
self.hash_queue.extend(b for b in blob_hashes if b not in self.hash_queue) self.hash_queue.extend(b for b in blob_hashes if b not in self.hash_queue)
log.info("Announcing %i blobs", len(self.hash_queue)) log.info("Announcing %i blobs", len(self.hash_queue))
start = self.clock.seconds() start = self.clock.seconds()
progress_lc = task.LoopingCall(self._show_announce_progress, len(self.hash_queue), start) progress_lc = task.LoopingCall(self._show_announce_progress, len(self.hash_queue), start)
progress_lc.clock = self.clock
progress_lc.start(60, now=False) progress_lc.start(60, now=False)
s = defer.DeferredSemaphore(self.concurrent_announcers) results = yield utils.DeferredDict(
results = yield utils.DeferredDict({blob_hash: s.run(self.do_store, blob_hash) for blob_hash in blob_hashes}) {blob_hash: self.sem.run(self.do_store, blob_hash) for blob_hash in blob_hashes}
)
now = self.clock.seconds() now = self.clock.seconds()
progress_lc.stop() progress_lc.stop()
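The hunk above moves the semaphore from a local variable in `immediate_announce` to `self.sem`, so the concurrency limit presumably applies across overlapping calls instead of per call. The sketch below shows that effect using `asyncio.Semaphore` as a swapped-in stand-in for Twisted's `defer.DeferredSemaphore`. The `Announcer` class, its counters, and `demo` are hypothetical and exist only for illustration.

```python
import asyncio


class Announcer(object):
    """Hypothetical sketch: one semaphore shared by every announce call."""

    def __init__(self, concurrent_announcers=2):
        self.limit = concurrent_announcers
        self.sem = None  # created lazily inside the running event loop
        self.in_flight = 0
        self.max_in_flight = 0

    async def do_store(self, blob_hash):
        # Placeholder for the real store RPC; it only tracks concurrency
        self.in_flight += 1
        self.max_in_flight = max(self.max_in_flight, self.in_flight)
        await asyncio.sleep(0.01)
        self.in_flight -= 1
        return True

    async def _bounded_store(self, blob_hash):
        # Every store waits on the same instance-level semaphore
        async with self.sem:
            return await self.do_store(blob_hash)

    async def immediate_announce(self, blob_hashes):
        if self.sem is None:
            self.sem = asyncio.Semaphore(self.limit)
        results = await asyncio.gather(*(self._bounded_store(h) for h in blob_hashes))
        return dict(zip(blob_hashes, results))


async def demo():
    announcer = Announcer(concurrent_announcers=2)
    # Two overlapping calls still share the same limit of 2
    await asyncio.gather(
        announcer.immediate_announce(["a", "b", "c"]),
        announcer.immediate_announce(["d", "e", "f"]),
    )
    return announcer.max_in_flight
```

If each call built its own semaphore instead, the two overlapping calls in `demo` could run up to four stores at once.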


@ -1,33 +0,0 @@
from collections import Counter
import datetime
from twisted.internet import task
class HashWatcher(object):
def __init__(self, clock=None):
if not clock:
from twisted.internet import reactor as clock
self.ttl = 600
self.hashes = []
self.lc = task.LoopingCall(self._remove_old_hashes)
self.lc.clock = clock
def start(self):
return self.lc.start(10)
def stop(self):
return self.lc.stop()
def add_requested_hash(self, hashsum, contact):
from_ip = contact.compact_ip
matching_hashes = [h for h in self.hashes if h[0] == hashsum and h[2] == from_ip]
if len(matching_hashes) == 0:
self.hashes.append((hashsum, datetime.datetime.now(), from_ip))
def most_popular_hashes(self, num_to_return=10):
hash_counter = Counter([h[0] for h in self.hashes])
return hash_counter.most_common(num_to_return)
def _remove_old_hashes(self):
remove_time = datetime.datetime.now() - datetime.timedelta(minutes=10)
self.hashes = [h for h in self.hashes if h[1] < remove_time]


@ -0,0 +1,209 @@
import logging
from twisted.internet import defer
from distance import Distance
from error import TimeoutError
import constants
log = logging.getLogger(__name__)
def get_contact(contact_list, node_id, address, port):
for contact in contact_list:
if contact.id == node_id and contact.address == address and contact.port == port:
return contact
raise IndexError(node_id)
class _IterativeFind(object):
# TODO: use polymorphism to search for a value or node
# instead of using a find_value flag
def __init__(self, node, shortlist, key, rpc):
self.node = node
self.finished_deferred = defer.Deferred()
# all distance operations in this class only care about the distance
# to self.key, so this makes it easier to calculate those
self.distance = Distance(key)
# The closest known and active node yet found
self.closest_node = None if not shortlist else shortlist[0]
self.prev_closest_node = None
# Shortlist of contact objects (the k closest known contacts to the key from the routing table)
self.shortlist = shortlist
# The search key
self.key = str(key)
# The rpc method name (findValue or findNode)
self.rpc = rpc
# List of active queries; len() indicates number of active probes
self.active_probes = []
# List of contact (address, port) tuples that have already been queried, includes contacts that didn't reply
self.already_contacted = []
# A list of found and known-to-be-active remote nodes (Contact objects)
self.active_contacts = []
# Ensure only one searchIteration call is running at a time
self._search_iteration_semaphore = defer.DeferredSemaphore(1)
self._iteration_count = 0
self.find_value_result = {}
self.pending_iteration_calls = []
@property
def is_find_node_request(self):
return self.rpc == "findNode"
@property
def is_find_value_request(self):
return self.rpc == "findValue"
def is_closer(self, contact):
if not self.closest_node:
return True
return self.distance.is_closer(contact.id, self.closest_node.id)
def getContactTriples(self, result):
if self.is_find_value_request:
contact_triples = result['contacts']
else:
contact_triples = result
for contact_tup in contact_triples:
if not isinstance(contact_tup, (list, tuple)) or len(contact_tup) != 3:
raise ValueError("invalid contact triple")
return contact_triples
def sortByDistance(self, contact_list):
"""Sort the list of contacts in order by distance from key"""
contact_list.sort(key=lambda c: self.distance(c.id))
@defer.inlineCallbacks
def extendShortlist(self, contact, result):
# The originating address is taken from the contact that sent the response
originAddress = (contact.address, contact.port)
if self.finished_deferred.called:
defer.returnValue(contact.id)
if self.node.contact_manager.is_ignored(originAddress):
raise ValueError("contact is ignored")
if contact.id == self.node.node_id:
defer.returnValue(contact.id)
if contact not in self.active_contacts:
self.active_contacts.append(contact)
if contact not in self.shortlist:
self.shortlist.append(contact)
# Now extend the (unverified) shortlist with the returned contacts
# TODO: some validation on the result (for guarding against attacks)
# If we are looking for a value, first see if this result is the value
# we are looking for before treating it as a list of contact triples
if self.is_find_value_request and self.key in result:
# We have found the value
self.find_value_result[self.key] = result[self.key]
self.finished_deferred.callback(self.find_value_result)
else:
if self.is_find_value_request:
# We are looking for a value, and the remote node didn't have it
# - mark it as the closest "empty" node, if it is
# TODO: store to this peer after finding the value as per the kademlia spec
if 'closestNodeNoValue' in self.find_value_result:
if self.is_closer(contact):
self.find_value_result['closestNodeNoValue'] = contact
else:
self.find_value_result['closestNodeNoValue'] = contact
contactTriples = self.getContactTriples(result)
for contactTriple in contactTriples:
if (contactTriple[1], contactTriple[2]) in ((c.address, c.port) for c in self.already_contacted):
continue
elif self.node.contact_manager.is_ignored((contactTriple[1], contactTriple[2])):
raise ValueError("contact is ignored")
else:
found_contact = self.node.contact_manager.make_contact(contactTriple[0], contactTriple[1],
contactTriple[2], self.node._protocol)
if found_contact not in self.shortlist:
self.shortlist.append(found_contact)
if not self.finished_deferred.called and self.should_stop():
self.sortByDistance(self.active_contacts)
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
defer.returnValue(contact.id)
@defer.inlineCallbacks
def probeContact(self, contact):
fn = getattr(contact, self.rpc)
try:
response = yield fn(self.key)
result = yield self.extendShortlist(contact, response)
defer.returnValue(result)
except (TimeoutError, defer.CancelledError, ValueError, IndexError):
defer.returnValue(contact.id)
def should_stop(self):
if self.prev_closest_node and self.closest_node and self.distance.is_closer(self.prev_closest_node.id,
self.closest_node.id):
# we're getting further away
return True
if len(self.active_contacts) >= constants.k:
# we have enough results
return True
return False
# Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts
def _searchIteration(self):
# Sort the discovered active nodes from closest to furthest
if len(self.active_contacts):
self.sortByDistance(self.active_contacts)
self.prev_closest_node = self.closest_node
self.closest_node = self.active_contacts[0]
# Sort the current shortList before contacting other nodes
self.sortByDistance(self.shortlist)
probes = []
already_contacted_addresses = {(c.address, c.port) for c in self.already_contacted}
to_remove = []
for contact in self.shortlist:
if (contact.address, contact.port) not in already_contacted_addresses:
self.already_contacted.append(contact)
to_remove.append(contact)
probe = self.probeContact(contact)
probes.append(probe)
self.active_probes.append(probe)
if len(probes) == constants.alpha:
break
for contact in to_remove: # these contacts will be re-added to the shortlist when they reply successfully
self.shortlist.remove(contact)
# run the probes
if probes:
# Schedule the next iteration if there are any active
# calls (Kademlia uses loose parallelism)
self.searchIteration()
d = defer.DeferredList(probes, consumeErrors=True)
def _remove_probes(results):
for probe in probes:
self.active_probes.remove(probe)
return results
d.addCallback(_remove_probes)
elif not self.finished_deferred.called and (not self.active_probes or self.should_stop()):
# If no probes were sent, there will not be any improvement, so we're done
self.sortByDistance(self.active_contacts)
self.finished_deferred.callback(self.active_contacts[:min(constants.k, len(self.active_contacts))])
elif not self.finished_deferred.called:
# Force the next iteration
self.searchIteration()
def searchIteration(self, delay=constants.iterativeLookupDelay):
def _cancel_pending_iterations(result):
while self.pending_iteration_calls:
canceller = self.pending_iteration_calls.pop()
canceller()
return result
self.finished_deferred.addBoth(_cancel_pending_iterations)
self._iteration_count += 1
call, cancel = self.node.reactor_callLater(delay, self._search_iteration_semaphore.run, self._searchIteration)
self.pending_iteration_calls.append(cancel)
def iterativeFind(node, shortlist, key, rpc):
helper = _IterativeFind(node, shortlist, key, rpc)
helper.searchIteration(0)
return helper.finished_deferred
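The lookup loop in `_IterativeFind` can be summarized with a simplified, synchronous sketch: sort the shortlist by XOR distance to the key, probe up to alpha contacts that have not been queried yet, merge the contacts they return, and stop once k active contacts are known or nothing is left to probe. Everything here is an assumption made for illustration: the `network` dict, integer node ids, and the values of `ALPHA` and `K` (stand-ins for `constants.alpha` and `constants.k`). The sketch also leaves out the "getting further away" stop condition and all asynchronous scheduling.

```python
ALPHA = 3  # assumption: stands in for constants.alpha
K = 8      # assumption: stands in for constants.k


def iterative_find(network, shortlist, key):
    """Simplified, synchronous sketch of an iterative node lookup.

    `network` is a hypothetical dict mapping a node id (int) to the list of
    node ids that node would return for a find request.
    """
    def distance(node_id):
        return node_id ^ key

    shortlist = list(shortlist)
    contacted = set()
    active = []
    while len(active) < K:
        # Probe the ALPHA closest contacts that have not been queried yet
        shortlist.sort(key=distance)
        batch = [n for n in shortlist if n not in contacted][:ALPHA]
        if not batch:
            break  # nothing left to ask, so no further improvement
        for node_id in batch:
            contacted.add(node_id)
            if node_id not in network:
                continue  # treated as a timeout: never becomes active
            active.append(node_id)
            for found in network[node_id]:
                if found not in contacted and found not in shortlist:
                    shortlist.append(found)
    active.sort(key=distance)
    return active[:K]
```

For example, with `network = {1: [2, 3], 2: [4], 3: [5], 4: [], 5: []}`, a lookup for key `4` starting from `[1]` visits every node and returns them ordered by distance to the key.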


@ -1,12 +1,16 @@
import logging
import constants import constants
from distance import Distance
from error import BucketFull from error import BucketFull
log = logging.getLogger(__name__)
class KBucket(object): class KBucket(object):
""" Description - later """ Description - later
""" """
def __init__(self, rangeMin, rangeMax): def __init__(self, rangeMin, rangeMax, node_id):
""" """
@param rangeMin: The lower boundary for the range in the n-bit ID @param rangeMin: The lower boundary for the range in the n-bit ID
space covered by this k-bucket space covered by this k-bucket
@ -17,6 +21,7 @@ class KBucket(object):
self.rangeMin = rangeMin self.rangeMin = rangeMin
self.rangeMax = rangeMax self.rangeMax = rangeMax
self._contacts = list() self._contacts = list()
self._node_id = node_id
def addContact(self, contact): def addContact(self, contact):
""" Add contact to _contact list in the right order. This will move the """ Add contact to _contact list in the right order. This will move the
@ -27,7 +32,7 @@ class KBucket(object):
already already
@param contact: The contact to add @param contact: The contact to add
@type contact: kademlia.contact.Contact @type contact: dht.contact._Contact
""" """
if contact in self._contacts: if contact in self._contacts:
# Move the existing contact to the end of the list # Move the existing contact to the end of the list
@ -41,23 +46,35 @@ class KBucket(object):
raise BucketFull("No space in bucket to insert contact") raise BucketFull("No space in bucket to insert contact")
def getContact(self, contactID): def getContact(self, contactID):
""" Get the contact specified node ID""" """Get the contact with the specified node ID
index = self._contacts.index(contactID)
return self._contacts[index]
def getContacts(self, count=-1, excludeContact=None): @raise IndexError: raised if the contact is not in the bucket
@param contactID: the node id of the contact to retrieve
@type contactID: str
@rtype: dht.contact._Contact
"""
for contact in self._contacts:
if contact.id == contactID:
return contact
raise IndexError(contactID)
def getContacts(self, count=-1, excludeContact=None, sort_distance_to=None):
""" Returns a list containing up to the first count number of contacts """ Returns a list containing up to the first count number of contacts
@param count: The amount of contacts to return (if 0 or less, return @param count: The amount of contacts to return (if 0 or less, return
all contacts) all contacts)
@type count: int @type count: int
@param excludeContact: A contact to exclude; if this contact is in @param excludeContact: A node id to exclude; if this contact is in
the list of returned values, it will be the list of returned values, it will be
discarded before returning. If a C{str} is discarded before returning. If a C{str} is
passed as this argument, it must be the passed as this argument, it must be the
contact's ID. contact's ID.
@type excludeContact: kademlia.contact.Contact or str @type excludeContact: str
@param sort_distance_to: The id to sort the contacts by distance to, defaulting to the parent node id.
If False, the contacts are not sorted
@raise IndexError: If the number of requested contacts is too large @raise IndexError: If the number of requested contacts is too large
@ -65,40 +82,41 @@ class KBucket(object):
If no contacts are present an empty list is returned If no contacts are present an empty list is returned
@rtype: list @rtype: list
""" """
contacts = [contact for contact in self._contacts if contact.id != excludeContact]
# Return all contacts in bucket # Return all contacts in bucket
if count <= 0: if count <= 0:
count = len(self._contacts) count = len(contacts)
# Get current contact number # Get current contact number
currentLen = len(self._contacts) currentLen = len(contacts)
# If count greater than k - return only k contacts # If count greater than k - return only k contacts
if count > constants.k: if count > constants.k:
count = constants.k count = constants.k
# Check if count value in range and,
# if count number of contacts are available
if not currentLen: if not currentLen:
contactList = list() return contacts
# length of list less than requested amount if sort_distance_to is False:
elif currentLen < count: pass
contactList = self._contacts[0:currentLen]
# enough contacts in list
else: else:
contactList = self._contacts[0:count] sort_distance_to = sort_distance_to or self._node_id
contacts.sort(key=lambda c: Distance(sort_distance_to)(c.id))
if excludeContact in contactList: return contacts[:min(currentLen, count)]
contactList.remove(excludeContact)
return contactList def getBadOrUnknownContacts(self):
contacts = self.getContacts(sort_distance_to=False)
results = [contact for contact in contacts if contact.contact_is_good is False]
results.extend(contact for contact in contacts if contact.contact_is_good is None)
return results
def removeContact(self, contact): def removeContact(self, contact):
""" Remove given contact from list """ Remove the contact from the bucket
@param contact: The contact to remove, or a string containing the @param contact: The contact to remove
contact's node ID @type contact: dht.contact._Contact
@type contact: kademlia.contact.Contact or str
@raise ValueError: The specified contact is not in this bucket @raise ValueError: The specified contact is not in this bucket
""" """
@ -123,3 +141,6 @@ class KBucket(object):
def __len__(self): def __len__(self):
return len(self._contacts) return len(self._contacts)
def __contains__(self, item):
return item in self._contacts
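The reworked `getContacts` filters out the excluded id before truncating, whereas the old code sliced first and removed afterwards, so it could return fewer contacts than requested. It also sorts by XOR distance unless `sort_distance_to` is `False`. The sketch below mirrors that logic with plain integer ids. `get_contacts` and `K` are hypothetical names, and unlike the diffed code the sketch falls back to `node_id` only when `sort_distance_to` is `None`.

```python
K = 8  # assumption: stands in for constants.k


def get_contacts(contact_ids, node_id, count=-1, exclude=None, sort_distance_to=None):
    """Hypothetical sketch of the new getContacts logic, using integer ids."""
    # Exclude first, so the exclusion never shortens the requested count
    contacts = [c for c in contact_ids if c != exclude]
    if count <= 0:
        count = len(contacts)
    count = min(count, K)
    if sort_distance_to is not False:
        # Default to sorting by distance to the bucket owner's node id
        target = node_id if sort_distance_to is None else sort_distance_to
        contacts.sort(key=lambda c: c ^ target)
    return contacts[:count]
```

For a bucket holding ids `[1, 6, 3]` owned by node `7`, the default call returns `[6, 3, 1]`, while passing `sort_distance_to=False` keeps the insertion order `[1, 6, 3]`.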


@ -8,30 +8,33 @@
# may be created by processing this file with epydoc: http://epydoc.sf.net # may be created by processing this file with epydoc: http://epydoc.sf.net
import binascii import binascii
import hashlib import hashlib
import operator
import struct import struct
import time
import logging import logging
from twisted.internet import defer, error, task from twisted.internet import defer, error, task
from lbrynet.core.utils import generate_id from lbrynet.core.utils import generate_id, DeferredDict
from lbrynet.core.call_later_manager import CallLaterManager from lbrynet.core.call_later_manager import CallLaterManager
from lbrynet.core.PeerManager import PeerManager from lbrynet.core.PeerManager import PeerManager
from error import TimeoutError
import constants import constants
import routingtable import routingtable
import datastore import datastore
import protocol import protocol
from error import TimeoutError
from peerfinder import DHTPeerFinder from peerfinder import DHTPeerFinder
from contact import Contact from contact import ContactManager
from hashwatcher import HashWatcher from iterativefind import iterativeFind
from distance import Distance
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def expand_peer(compact_peer_info):
host = ".".join([str(ord(d)) for d in compact_peer_info[:4]])
port, = struct.unpack('>H', compact_peer_info[4:6])
peer_node_id = compact_peer_info[6:]
return (peer_node_id, host, port)
def rpcmethod(func): def rpcmethod(func):
""" Decorator to expose Node methods as remote procedure calls """ Decorator to expose Node methods as remote procedure calls
@ -42,7 +45,46 @@ def rpcmethod(func):
return func return func
class Node(object): class MockKademliaHelper(object):
def __init__(self, clock=None, callLater=None, resolve=None, listenUDP=None):
if not listenUDP or not resolve or not callLater or not clock:
from twisted.internet import reactor
listenUDP = listenUDP or reactor.listenUDP
resolve = resolve or reactor.resolve
callLater = callLater or reactor.callLater
clock = clock or reactor
self.clock = clock
self.contact_manager = ContactManager(self.clock.seconds)
self.reactor_listenUDP = listenUDP
self.reactor_resolve = resolve
self.call_later_manager = CallLaterManager(callLater)
self.reactor_callLater = self.call_later_manager.call_later
self.reactor_callSoon = self.call_later_manager.call_soon
self._listeningPort = None # object implementing Twisted
# IListeningPort This will contain a deferred created when
# joining the network, to enable publishing/retrieving
# information from the DHT as soon as the node is part of the
# network (add callbacks to this deferred if scheduling such
# operations before the node has finished joining the network)
def get_looping_call(self, fn, *args, **kwargs):
lc = task.LoopingCall(fn, *args, **kwargs)
lc.clock = self.clock
return lc
def safe_stop_looping_call(self, lc):
if lc and lc.running:
return lc.stop()
return defer.succeed(None)
def safe_start_looping_call(self, lc, t):
if lc and not lc.running:
lc.start(t)
class Node(MockKademliaHelper):
""" Local node in the Kademlia network """ Local node in the Kademlia network
This class represents a single local node in a Kademlia network; in other This class represents a single local node in a Kademlia network; in other
@ -55,7 +97,7 @@ class Node(object):
def __init__(self, node_id=None, udpPort=4000, dataStore=None, def __init__(self, node_id=None, udpPort=4000, dataStore=None,
routingTableClass=None, networkProtocol=None, routingTableClass=None, networkProtocol=None,
externalIP=None, peerPort=None, listenUDP=None, externalIP=None, peerPort=3333, listenUDP=None,
callLater=None, resolve=None, clock=None, peer_finder=None, callLater=None, resolve=None, clock=None, peer_finder=None,
peer_manager=None): peer_manager=None):
""" """
@ -83,35 +125,16 @@ class Node(object):
@param peerPort: the port at which this node announces it has a blob for @param peerPort: the port at which this node announces it has a blob for
""" """
if not listenUDP or not resolve or not callLater or not clock: MockKademliaHelper.__init__(self, clock, callLater, resolve, listenUDP)
from twisted.internet import reactor
listenUDP = listenUDP or reactor.listenUDP
resolve = resolve or reactor.resolve
callLater = callLater or reactor.callLater
clock = clock or reactor
self.clock = clock
CallLaterManager.setup(callLater)
self.reactor_resolve = resolve
self.reactor_listenUDP = listenUDP
self.reactor_callLater = CallLaterManager.call_later
self.reactor_callSoon = CallLaterManager.call_soon
self.node_id = node_id or self._generateID() self.node_id = node_id or self._generateID()
self.port = udpPort self.port = udpPort
self._listeningPort = None # object implementing Twisted self._change_token_lc = self.get_looping_call(self.change_token)
# IListeningPort This will contain a deferred created when self._refresh_node_lc = self.get_looping_call(self._refreshNode)
# joining the network, to enable publishing/retrieving self._refresh_contacts_lc = self.get_looping_call(self._refreshContacts)
# information from the DHT as soon as the node is part of the
# network (add callbacks to this deferred if scheduling such
# operations before the node has finished joining the network)
self._joinDeferred = defer.Deferred(None)
self.change_token_lc = task.LoopingCall(self.change_token)
self.change_token_lc.clock = self.clock
self.refresh_node_lc = task.LoopingCall(self._refreshNode)
self.refresh_node_lc.clock = self.clock
# Create k-buckets (for storing contacts) # Create k-buckets (for storing contacts)
if routingTableClass is None: if routingTableClass is None:
self._routingTable = routingtable.OptimizedTreeRoutingTable(self.node_id, self.clock.seconds) self._routingTable = routingtable.TreeRoutingTable(self.node_id, self.clock.seconds)
else: else:
self._routingTable = routingTableClass(self.node_id, self.clock.seconds) self._routingTable = routingTableClass(self.node_id, self.clock.seconds)
@ -123,41 +146,27 @@ class Node(object):
# Initialize the data storage mechanism used by this node # Initialize the data storage mechanism used by this node
self.token_secret = self._generateID() self.token_secret = self._generateID()
self.old_token_secret = None self.old_token_secret = None
if dataStore is None:
self._dataStore = datastore.DictDataStore()
else:
self._dataStore = dataStore
# Try to restore the node's state...
if 'nodeState' in self._dataStore:
state = self._dataStore['nodeState']
self.node_id = state['id']
for contactTriple in state['closestNodes']:
contact = Contact(
contactTriple[0], contactTriple[1], contactTriple[2], self._protocol)
self._routingTable.addContact(contact)
self.externalIP = externalIP self.externalIP = externalIP
self.peerPort = peerPort self.peerPort = peerPort
self.hash_watcher = HashWatcher(self.clock) self._dataStore = dataStore or datastore.DictDataStore(self.clock.seconds)
self.peer_manager = peer_manager or PeerManager() self.peer_manager = peer_manager or PeerManager()
self.peer_finder = peer_finder or DHTPeerFinder(self, self.peer_manager) self.peer_finder = peer_finder or DHTPeerFinder(self, self.peer_manager)
self._join_deferred = None
def __del__(self): def __del__(self):
log.warning("unclean shutdown of the dht node") log.warning("unclean shutdown of the dht node")
if self._listeningPort is not None: if hasattr(self, "_listeningPort") and self._listeningPort is not None:
self._listeningPort.stopListening() self._listeningPort.stopListening()
@defer.inlineCallbacks @defer.inlineCallbacks
def stop(self): def stop(self):
# stop LoopingCalls: # stop LoopingCalls:
if self.refresh_node_lc.running: yield self.safe_stop_looping_call(self._refresh_node_lc)
yield self.refresh_node_lc.stop() yield self.safe_stop_looping_call(self._change_token_lc)
if self.change_token_lc.running: yield self.safe_stop_looping_call(self._refresh_contacts_lc)
yield self.change_token_lc.stop()
if self._listeningPort is not None: if self._listeningPort is not None:
yield self._listeningPort.stopListening() yield self._listeningPort.stopListening()
if self.hash_watcher.lc.running: self._listeningPort = None
yield self.hash_watcher.stop()
def start_listening(self): def start_listening(self):
if not self._listeningPort: if not self._listeningPort:
@ -168,45 +177,93 @@ class Node(object):
log.error("Couldn't bind to port %d. %s", self.port, traceback.format_exc()) log.error("Couldn't bind to port %d. %s", self.port, traceback.format_exc())
raise ValueError("%s lbrynet may already be running." % str(e)) raise ValueError("%s lbrynet may already be running." % str(e))
else: else:
log.warning("Already bound to port %d", self._listeningPort.port) log.warning("Already bound to port %s", self._listeningPort)
def bootstrap_join(self, known_node_addresses, finished_d): @defer.inlineCallbacks
def joinNetwork(self, known_node_addresses=(('jack.lbry.tech', 4455), )):
""" """
Attempt to join the dht, retry every 30 seconds if unsuccessful Attempt to join the dht, retry every 30 seconds if unsuccessful
:param known_node_addresses: [(str, int)] list of hostnames and ports for known dht seed nodes :param known_node_addresses: [(str, int)] list of hostnames and ports for known dht seed nodes
:param finished_d: (defer.Deferred) called when join succeeds
""" """
self._join_deferred = defer.Deferred()
known_node_resolution = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def _resolve_seeds(): def _resolve_seeds():
result = {}
for host, port in known_node_addresses:
node_address = yield self.reactor_resolve(host)
result[(host, port)] = node_address
defer.returnValue(result)
if not known_node_resolution:
known_node_resolution = yield _resolve_seeds()
# we are one of the seed nodes, don't add ourselves
if (self.externalIP, self.port) in known_node_resolution.itervalues():
del known_node_resolution[(self.externalIP, self.port)]
known_node_addresses.remove((self.externalIP, self.port))
def _ping_contacts(contacts):
d = DeferredDict({contact: contact.ping() for contact in contacts}, consumeErrors=True)
d.addErrback(lambda err: err.trap(TimeoutError))
return d
@defer.inlineCallbacks
def _initialize_routing():
bootstrap_contacts = [] bootstrap_contacts = []
for node_address, port in known_node_addresses: contact_addresses = {(c.address, c.port): c for c in self.contacts}
host = yield self.reactor_resolve(node_address) for (host, port), ip_address in known_node_resolution.iteritems():
# Create temporary contact information for the list of addresses of known nodes if (host, port) not in contact_addresses:
contact = Contact(self._generateID(), host, port, self._protocol) # Create temporary contact information for the list of addresses of known nodes
bootstrap_contacts.append(contact) # The contact node id will be set with the responding node id when we initialize it to None
if not bootstrap_contacts: contact = self.contact_manager.make_contact(None, ip_address, port, self._protocol)
if not self.hasContacts(): bootstrap_contacts.append(contact)
log.warning("No known contacts!")
else: else:
log.info("found contacts") for contact in self.contacts:
bootstrap_contacts = self.contacts if contact.address == ip_address and contact.port == port:
defer.returnValue(bootstrap_contacts) if not contact.id:
bootstrap_contacts.append(contact)
break
if not bootstrap_contacts:
log.warning("no bootstrap contacts to ping")
ping_result = yield _ping_contacts(bootstrap_contacts)
shortlist = ping_result.keys()
if not shortlist:
log.warning("failed to ping %i bootstrap contacts", len(bootstrap_contacts))
defer.returnValue(None)
else:
# find the closest peers to us
closest = yield self._iterativeFind(self.node_id, shortlist if not self.contacts else None)
yield _ping_contacts(closest)
# # query random hashes in our bucket key ranges to fill or split them
# random_ids_in_range = self._routingTable.getRefreshList()
# while random_ids_in_range:
# yield self.iterativeFindNode(random_ids_in_range.pop())
defer.returnValue(None)
def _rerun(closest_nodes): @defer.inlineCallbacks
if not closest_nodes: def _iterative_join(joined_d=None, last_buckets_with_contacts=None):
log.info("Failed to join the dht, re-attempting in 30 seconds") log.info("Attempting to join the DHT network, %i contacts known so far", len(self.contacts))
self.reactor_callLater(30, self.bootstrap_join, known_node_addresses, finished_d) joined_d = joined_d or defer.Deferred()
elif not finished_d.called: yield _initialize_routing()
finished_d.callback(closest_nodes) buckets_with_contacts = self.bucketsWithContacts()
if last_buckets_with_contacts and last_buckets_with_contacts == buckets_with_contacts:
if not joined_d.called:
joined_d.callback(True)
elif buckets_with_contacts < 4:
self.reactor_callLater(0, _iterative_join, joined_d, buckets_with_contacts)
elif not joined_d.called:
joined_d.callback(None)
yield joined_d
if not self._join_deferred.called:
self._join_deferred.callback(True)
defer.returnValue(None)
log.info("Attempting to join the DHT network") yield _iterative_join()
d = _resolve_seeds()
# Initiate the Kademlia joining sequence - perform a search for this node's own ID
d.addCallback(lambda contacts: self._iterativeFind(self.node_id, contacts))
d.addCallback(_rerun)
@defer.inlineCallbacks @defer.inlineCallbacks
def joinNetwork(self, known_node_addresses=None): def start(self, known_node_addresses=None):
""" Causes the Node to attempt to join the DHT network by contacting the """ Causes the Node to attempt to join the DHT network by contacting the
known DHT nodes. This can be called multiple times if the previous attempt known DHT nodes. This can be called multiple times if the previous attempt
has failed or if the Node has lost all the contacts. has failed or if the Node has lost all the contacts.
@ -219,13 +276,14 @@ class Node(object):
""" """
self.start_listening() self.start_listening()
# #TODO: Refresh all k-buckets further away than this node's closest neighbour yield self._protocol._listening
# TODO: Refresh all k-buckets further away than this node's closest neighbour
yield self.joinNetwork(known_node_addresses or [])
self.safe_start_looping_call(self._change_token_lc, constants.tokenSecretChangeInterval)
# Start refreshing k-buckets periodically, if necessary # Start refreshing k-buckets periodically, if necessary
self.bootstrap_join(known_node_addresses or [], self._joinDeferred) self.safe_start_looping_call(self._refresh_node_lc, constants.checkRefreshInterval)
yield self._joinDeferred self.safe_start_looping_call(self._refresh_contacts_lc, 60)
self.hash_watcher.start()
self.change_token_lc.start(constants.tokenSecretChangeInterval)
self.refresh_node_lc.start(constants.checkRefreshInterval)
@property @property
def contacts(self): def contacts(self):
@ -241,85 +299,44 @@ class Node(object):
return True return True
return False return False
def announceHaveBlob(self, key): def bucketsWithContacts(self):
return self.iterativeAnnounceHaveBlob( return self._routingTable.bucketsWithContacts()
key, {
'port': self.peerPort,
'lbryid': self.node_id,
}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def getPeersForBlob(self, blob_hash, include_node_ids=False): def storeToContact(self, blob_hash, contact):
result = yield self.iterativeFindValue(blob_hash) try:
expanded_peers = [] token = contact.token
if result: if not token:
if blob_hash in result: find_value_response = yield contact.findValue(blob_hash)
for peer in result[blob_hash]: token = find_value_response['token']
host = ".".join([str(ord(d)) for d in peer[:4]]) contact.update_token(token)
port, = struct.unpack('>H', peer[4:6]) res = yield contact.store(blob_hash, token, self.peerPort, self.node_id, 0)
if not include_node_ids: if res != "OK":
if (host, port) not in expanded_peers: raise ValueError(res)
expanded_peers.append((host, port)) log.debug("Stored %s to %s (%s)", binascii.hexlify(blob_hash), contact.log_id(), contact.address)
else: defer.returnValue(True)
peer_node_id = peer[6:].encode('hex') except protocol.TimeoutError:
if (host, port, peer_node_id) not in expanded_peers: log.debug("Timeout while storing blob_hash %s at %s",
expanded_peers.append((host, port, peer_node_id)) binascii.hexlify(blob_hash), contact.log_id())
defer.returnValue(expanded_peers) except ValueError as err:
log.error("Unexpected response: %s" % err.message)
def get_most_popular_hashes(self, num_to_return): except Exception as err:
return self.hash_watcher.most_popular_hashes(num_to_return) log.error("Unexpected error while storing blob_hash %s at %s: %s",
binascii.hexlify(blob_hash), contact, err)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def iterativeAnnounceHaveBlob(self, blob_hash, value): def announceHaveBlob(self, blob_hash):
known_nodes = {}
contacts = yield self.iterativeFindNode(blob_hash) contacts = yield self.iterativeFindNode(blob_hash)
# store locally if we're the closest node and there are less than k contacts to try storing to
if self.externalIP is not None and contacts and len(contacts) < constants.k: if not self.externalIP:
is_closer = Distance(blob_hash).is_closer(self.node_id, contacts[-1].id)
if is_closer:
contacts.pop()
yield self.store(blob_hash, value, originalPublisherID=self.node_id,
self_store=True)
elif self.externalIP is not None:
pass
else:
raise Exception("Cannot determine external IP: %s" % self.externalIP) raise Exception("Cannot determine external IP: %s" % self.externalIP)
stored_to = yield DeferredDict({contact: self.storeToContact(blob_hash, contact) for contact in contacts})
contacted = [] contacted_node_ids = map(
lambda contact: contact.id.encode('hex'), filter(lambda contact: stored_to[contact], stored_to.keys())
@defer.inlineCallbacks )
def announce_to_contact(contact): log.debug("Stored %s to %i of %i attempted peers", binascii.hexlify(blob_hash),
known_nodes[contact.id] = contact len(contacted_node_ids), len(contacts))
try:
responseMsg, originAddress = yield contact.findValue(blob_hash, rawResponse=True)
if responseMsg.nodeID != contact.id:
raise Exception("node id mismatch")
value['token'] = responseMsg.response['token']
res = yield contact.store(blob_hash, value)
if res != "OK":
raise ValueError(res)
contacted.append(contact)
log.debug("Stored %s to %s (%s)", blob_hash.encode('hex'), contact.id.encode('hex'), originAddress[0])
except protocol.TimeoutError:
log.debug("Timeout while storing blob_hash %s at %s",
blob_hash.encode('hex')[:16], contact.id.encode('hex'))
except ValueError as err:
log.error("Unexpected response: %s" % err.message)
except Exception as err:
log.error("Unexpected error while storing blob_hash %s at %s: %s",
binascii.hexlify(blob_hash), contact, err)
dl = []
for c in contacts:
dl.append(announce_to_contact(c))
yield defer.DeferredList(dl)
log.debug("Stored %s to %i of %i attempted peers", blob_hash.encode('hex')[:16],
len(contacted), len(contacts))
contacted_node_ids = [c.id.encode('hex') for c in contacted]
defer.returnValue(contacted_node_ids) defer.returnValue(contacted_node_ids)
def change_token(self): def change_token(self):
@ -379,12 +396,15 @@ class Node(object):
@rtype: twisted.internet.defer.Deferred @rtype: twisted.internet.defer.Deferred
""" """
if len(key) != constants.key_bits / 8:
raise ValueError("invalid key length!")
# Execute the search # Execute the search
iterative_find_result = yield self._iterativeFind(key, rpc='findValue') find_result = yield self._iterativeFind(key, rpc='findValue')
if isinstance(iterative_find_result, dict): if isinstance(find_result, dict):
# We have found the value; now see who was the closest contact without it... # We have found the value; now see who was the closest contact without it...
# ...and store the key/value pair # ...and store the key/value pair
defer.returnValue(iterative_find_result) pass
else: else:
# The value wasn't found, but a list of contacts was returned # The value wasn't found, but a list of contacts was returned
# Now, see if we have the value (it might seem wasteful to search on the network # Now, see if we have the value (it might seem wasteful to search on the network
@ -394,10 +414,26 @@ class Node(object):
# Ok, we have the value locally, so use that # Ok, we have the value locally, so use that
# Send this value to the closest node without it # Send this value to the closest node without it
peers = self._dataStore.getPeersForBlob(key) peers = self._dataStore.getPeersForBlob(key)
defer.returnValue({key: peers}) find_result = {key: peers}
else: else:
# Ok, value does not exist in DHT at all pass
defer.returnValue(iterative_find_result)
expanded_peers = []
if find_result:
if key in find_result:
for peer in find_result[key]:
expanded = expand_peer(peer)
if expanded not in expanded_peers:
expanded_peers.append(expanded)
# TODO: get this working
# if 'closestNodeNoValue' in find_result:
# closest_node_without_value = find_result['closestNodeNoValue']
# try:
# response, address = yield closest_node_without_value.findValue(key, rawResponse=True)
# yield closest_node_without_value.store(key, response.response['token'], self.peerPort)
# except TimeoutError:
# pass
defer.returnValue(expanded_peers)
def addContact(self, contact): def addContact(self, contact):
""" Add/update the given contact; simple wrapper for the same method """ Add/update the given contact; simple wrapper for the same method
@ -406,17 +442,17 @@ class Node(object):
@param contact: The contact to add to this node's k-buckets @param contact: The contact to add to this node's k-buckets
@type contact: kademlia.contact.Contact @type contact: kademlia.contact.Contact
""" """
self._routingTable.addContact(contact) return self._routingTable.addContact(contact)
def removeContact(self, contactID): def removeContact(self, contact):
""" Remove the contact with the specified node ID from this node's """ Remove the contact with the specified node ID from this node's
table of known nodes. This is a simple wrapper for the same method table of known nodes. This is a simple wrapper for the same method
in this object's RoutingTable object in this object's RoutingTable object
@param contactID: The node ID of the contact to remove @param contact: The Contact object to remove
@type contactID: str @type contact: _Contact
""" """
self._routingTable.removeContact(contactID) self._routingTable.removeContact(contact)
def findContact(self, contactID): def findContact(self, contactID):
""" Find a entangled.kademlia.contact.Contact object for the specified """ Find a entangled.kademlia.contact.Contact object for the specified
@ -433,10 +469,11 @@ class Node(object):
contact = self._routingTable.getContact(contactID) contact = self._routingTable.getContact(contactID)
df = defer.Deferred() df = defer.Deferred()
df.callback(contact) df.callback(contact)
except ValueError: except (ValueError, IndexError):
def parseResults(nodes): def parseResults(nodes):
node_ids = [c.id for c in nodes]
if contactID in nodes: if contactID in nodes:
contact = nodes[nodes.index(contactID)] contact = nodes[node_ids.index(contactID)]
return contact return contact
else: else:
return None return None
@ -454,16 +491,21 @@ class Node(object):
return 'pong' return 'pong'
@rpcmethod @rpcmethod
def store(self, key, value, originalPublisherID=None, self_store=False, **kwargs): def store(self, rpc_contact, blob_hash, token, port, originalPublisherID, age):
""" Store the received data in this node's local hash table """ Store the received data in this node's local datastore
@param key: The hashtable key of the data @param blob_hash: The hash of the data
@type key: str @type blob_hash: str
@param value: The actual data (the value associated with C{key})
@type value: str @param token: The token we previously returned when this contact sent us a findValue
@param originalPublisherID: The node ID of the node that is the @type token: str
B{original} publisher of the data
@param port: The TCP port the contact is listening on for requests for this blob (the peerPort)
@type port: int
@param originalPublisherID: The node ID of the node that is the publisher of the data
@type originalPublisherID: str @type originalPublisherID: str
@param age: The relative age of the data (time in seconds since it was @param age: The relative age of the data (time in seconds since it was
originally published). Note that the original publish time originally published). Note that the original publish time
isn't actually given, to compensate for clock skew between isn't actually given, to compensate for clock skew between
@ -471,59 +513,26 @@ class Node(object):
@type age: int @type age: int
@rtype: str @rtype: str
@todo: Since the data (value) may be large, passing it around as a buffer
(which is the case currently) might not be a good idea... will have
to fix this (perhaps use a stream from the Protocol class?)
""" """
# Get the sender's ID (if any)
if originalPublisherID is None: if originalPublisherID is None:
if '_rpcNodeID' in kwargs: originalPublisherID = rpc_contact.id
originalPublisherID = kwargs['_rpcNodeID'] compact_ip = rpc_contact.compact_ip()
else: if not self.verify_token(token, compact_ip):
raise TypeError, 'No NodeID given. Therefore we can\'t store this node' raise ValueError("Invalid token")
 if 0 <= port <= 65535:
if self_store is True and self.externalIP: compact_port = str(struct.pack('>H', port))
contact = Contact(self.node_id, self.externalIP, self.port, None, None)
compact_ip = contact.compact_ip()
elif '_rpcNodeContact' in kwargs:
contact = kwargs['_rpcNodeContact']
compact_ip = contact.compact_ip()
else: else:
raise TypeError, 'No contact info available' raise TypeError('Invalid port')
compact_address = compact_ip + compact_port + rpc_contact.id
if not self_store: now = int(self.clock.seconds())
if 'token' not in value: originallyPublished = now - age
raise ValueError("Missing token") self._dataStore.addPeerToBlob(rpc_contact, blob_hash, compact_address, now, originallyPublished,
if not self.verify_token(value['token'], compact_ip):
raise ValueError("Invalid token")
if 'port' in value:
port = int(value['port'])
if 0 <= port <= 65536:
compact_port = str(struct.pack('>H', port))
else:
raise TypeError('Invalid port')
else:
raise TypeError('No port available')
if 'lbryid' in value:
if len(value['lbryid']) != constants.key_bits / 8:
raise ValueError('Invalid lbryid (%i bytes): %s' % (len(value['lbryid']),
value['lbryid'].encode('hex')))
else:
compact_address = compact_ip + compact_port + value['lbryid']
else:
raise TypeError('No lbryid given')
now = int(time.time())
originallyPublished = now # - age
self._dataStore.addPeerToBlob(key, compact_address, now, originallyPublished,
originalPublisherID) originalPublisherID)
return 'OK' return 'OK'
@rpcmethod @rpcmethod
def findNode(self, key, **kwargs): def findNode(self, rpc_contact, key):
""" Finds a number of known nodes closest to the node/value with the """ Finds a number of known nodes closest to the node/value with the
specified key. specified key.
@ -536,20 +545,17 @@ class Node(object):
node is returning all of the contacts that it knows of. node is returning all of the contacts that it knows of.
@rtype: list @rtype: list
""" """
if len(key) != constants.key_bits / 8:
raise ValueError("invalid contact id length: %i" % len(key))
# Get the sender's ID (if any) contacts = self._routingTable.findCloseNodes(key, constants.k, rpc_contact.id)
if '_rpcNodeID' in kwargs:
rpc_sender_id = kwargs['_rpcNodeID']
else:
rpc_sender_id = None
contacts = self._routingTable.findCloseNodes(key, constants.k, rpc_sender_id)
contact_triples = [] contact_triples = []
for contact in contacts: for contact in contacts:
contact_triples.append((contact.id, contact.address, contact.port)) contact_triples.append((contact.id, contact.address, contact.port))
return contact_triples return contact_triples
@rpcmethod @rpcmethod
def findValue(self, key, **kwargs): def findValue(self, rpc_contact, key):
""" Return the value associated with the specified key if present in """ Return the value associated with the specified key if present in
this node's data, otherwise execute FIND_NODE for the key this node's data, otherwise execute FIND_NODE for the key
@ -561,17 +567,21 @@ class Node(object):
@rtype: dict or list @rtype: dict or list
""" """
if len(key) != constants.key_bits / 8:
raise ValueError("invalid blob hash length: %i" % len(key))
response = {
'token': self.make_token(rpc_contact.compact_ip()),
}
if self._protocol._protocolVersion:
response['protocolVersion'] = self._protocol._protocolVersion
if self._dataStore.hasPeersForBlob(key): if self._dataStore.hasPeersForBlob(key):
rval = {key: self._dataStore.getPeersForBlob(key)} response[key] = self._dataStore.getPeersForBlob(key)
else: else:
contact_triples = self.findNode(key, **kwargs) response['contacts'] = self.findNode(rpc_contact, key)
rval = {'contacts': contact_triples} return response
if '_rpcNodeContact' in kwargs:
contact = kwargs['_rpcNodeContact']
compact_ip = contact.compact_ip()
rval['token'] = self.make_token(compact_ip)
self.hash_watcher.add_requested_hash(key, contact)
return rval
def _generateID(self): def _generateID(self):
""" Generates an n-bit pseudo-random identifier """ Generates an n-bit pseudo-random identifier
@ -581,6 +591,8 @@ class Node(object):
""" """
return generate_id() return generate_id()
# from lbrynet.core.utils import profile_deferred
# @profile_deferred()
@defer.inlineCallbacks @defer.inlineCallbacks
def _iterativeFind(self, key, startupShortlist=None, rpc='findNode'): def _iterativeFind(self, key, startupShortlist=None, rpc='findNode'):
""" The basic Kademlia iterative lookup operation (for nodes/values) """ The basic Kademlia iterative lookup operation (for nodes/values)
@ -610,13 +622,15 @@ class Node(object):
return a list of the k closest nodes to the specified key return a list of the k closest nodes to the specified key
@rtype: twisted.internet.defer.Deferred @rtype: twisted.internet.defer.Deferred
""" """
findValue = rpc != 'findNode'
if len(key) != constants.key_bits / 8:
raise ValueError("invalid key length: %i" % len(key))
if startupShortlist is None: if startupShortlist is None:
shortlist = self._routingTable.findCloseNodes(key, constants.k) shortlist = self._routingTable.findCloseNodes(key, constants.k)
if key != self.node_id: # if key != self.node_id:
# Update the "last accessed" timestamp for the appropriate k-bucket # # Update the "last accessed" timestamp for the appropriate k-bucket
self._routingTable.touchKBucket(key) # self._routingTable.touchKBucket(key)
if len(shortlist) == 0: if len(shortlist) == 0:
log.warning("This node doesn't know any other nodes") log.warning("This node doesn't know any other nodes")
# This node doesn't know of any other nodes # This node doesn't know of any other nodes
@ -625,26 +639,32 @@ class Node(object):
result = yield fakeDf result = yield fakeDf
defer.returnValue(result) defer.returnValue(result)
else: else:
# This is used during the bootstrap process; node ID's are most probably fake # This is used during the bootstrap process
shortlist = startupShortlist shortlist = startupShortlist
outerDf = defer.Deferred() result = yield iterativeFind(self, shortlist, key, rpc)
helper = _IterativeFindHelper(self, outerDf, shortlist, key, findValue, rpc)
# Start the iterations
helper.searchIteration()
result = yield outerDf
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _refreshNode(self): def _refreshNode(self):
""" Periodically called to perform k-bucket refreshes and data """ Periodically called to perform k-bucket refreshes and data
replication/republishing as necessary """ replication/republishing as necessary """
yield self._refreshRoutingTable() yield self._refreshRoutingTable()
self._dataStore.removeExpiredPeers() self._dataStore.removeExpiredPeers()
yield self._refreshStoringPeers()
defer.returnValue(None) defer.returnValue(None)
def _refreshContacts(self):
return defer.DeferredList(
[self._protocol._ping_queue.enqueue_maybe_ping(contact, delay=0) for contact in self.contacts]
)
def _refreshStoringPeers(self):
storing_contacts = self._dataStore.getStoringContacts()
return defer.DeferredList(
[self._protocol._ping_queue.enqueue_maybe_ping(contact, delay=0) for contact in storing_contacts]
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _refreshRoutingTable(self): def _refreshRoutingTable(self):
nodeIDs = self._routingTable.getRefreshList(0, False) nodeIDs = self._routingTable.getRefreshList(0, False)
@ -652,242 +672,3 @@ class Node(object):
searchID = nodeIDs.pop() searchID = nodeIDs.pop()
yield self.iterativeFindNode(searchID) yield self.iterativeFindNode(searchID)
defer.returnValue(None) defer.returnValue(None)
# This was originally a set of nested methods in _iterativeFind
# but they have been moved into this helper class in order to
# have better scoping and readability
class _IterativeFindHelper(object):
# TODO: use polymorphism to search for a value or node
# instead of using a find_value flag
def __init__(self, node, outer_d, shortlist, key, find_value, rpc):
self.node = node
self.outer_d = outer_d
self.shortlist = shortlist
self.key = key
self.find_value = find_value
self.rpc = rpc
# all distance operations in this class only care about the distance
# to self.key, so this makes it easier to calculate those
self.distance = Distance(key)
# List of active queries; len() indicates number of active probes
#
# n.b: using lists for these variables, because Python doesn't
# allow binding a new value to a name in an enclosing
# (non-global) scope
self.active_probes = []
# List of contact IDs that have already been queried
self.already_contacted = []
# Probes that were active during the previous iteration
# A list of found and known-to-be-active remote nodes
self.active_contacts = []
# This should only contain one entry; the next scheduled iteration call
self.pending_iteration_calls = []
self.prev_closest_node = [None]
self.find_value_result = {}
self.slow_node_count = [0]
def extendShortlist(self, responseTuple):
""" @type responseMsg: kademlia.msgtypes.ResponseMessage """
# The "raw response" tuple contains the response message,
# and the originating address info
responseMsg = responseTuple[0]
originAddress = responseTuple[1] # tuple: (ip address, udp port)
# Make sure the responding node is valid, and abort the operation if it isn't
if responseMsg.nodeID in self.active_contacts or responseMsg.nodeID == self.node.node_id:
return responseMsg.nodeID
# Mark this node as active
aContact = self._getActiveContact(responseMsg, originAddress)
self.active_contacts.append(aContact)
# This makes sure "bootstrap"-nodes with "fake" IDs don't get queried twice
if responseMsg.nodeID not in self.already_contacted:
self.already_contacted.append(responseMsg.nodeID)
# Now extend the (unverified) shortlist with the returned contacts
result = responseMsg.response
# TODO: some validation on the result (for guarding against attacks)
# If we are looking for a value, first see if this result is the value
# we are looking for before treating it as a list of contact triples
if self.find_value is True and self.key in result and not 'contacts' in result:
# We have found the value
self.find_value_result[self.key] = result[self.key]
else:
if self.find_value is True:
self._setClosestNodeValue(responseMsg, aContact)
self._keepSearching(result)
return responseMsg.nodeID
def _getActiveContact(self, responseMsg, originAddress):
if responseMsg.nodeID in self.shortlist:
# Get the contact information from the shortlist...
return self.shortlist[self.shortlist.index(responseMsg.nodeID)]
else:
# If it's not in the shortlist; we probably used a fake ID to reach it
# - reconstruct the contact, using the real node ID this time
return Contact(
responseMsg.nodeID, originAddress[0], originAddress[1], self.node._protocol)
def _keepSearching(self, result):
contactTriples = self._getContactTriples(result)
for contactTriple in contactTriples:
self._addIfValid(contactTriple)
def _getContactTriples(self, result):
if self.find_value is True:
return result['contacts']
else:
return result
def _setClosestNodeValue(self, responseMsg, aContact):
# We are looking for a value, and the remote node didn't have it
# - mark it as the closest "empty" node, if it is
if 'closestNodeNoValue' in self.find_value_result:
if self._is_closer(responseMsg):
self.find_value_result['closestNodeNoValue'] = aContact
else:
self.find_value_result['closestNodeNoValue'] = aContact
def _is_closer(self, responseMsg):
return self.distance.is_closer(responseMsg.nodeID, self.active_contacts[0].id)
def _addIfValid(self, contactTriple):
if isinstance(contactTriple, (list, tuple)) and len(contactTriple) == 3:
testContact = Contact(
contactTriple[0], contactTriple[1], contactTriple[2], self.node._protocol)
if testContact not in self.shortlist:
self.shortlist.append(testContact)
def removeFromShortlist(self, failure, deadContactID):
""" @type failure: twisted.python.failure.Failure """
failure.trap(TimeoutError, defer.CancelledError, TypeError)
if len(deadContactID) != constants.key_bits / 8:
raise ValueError("invalid lbry id")
if deadContactID in self.shortlist:
self.shortlist.remove(deadContactID)
return deadContactID
def cancelActiveProbe(self, contactID):
self.active_probes.pop()
if len(self.active_probes) <= constants.alpha / 2 and len(self.pending_iteration_calls):
# Force the iteration
self.pending_iteration_calls[0].cancel()
del self.pending_iteration_calls[0]
self.searchIteration()
def sortByDistance(self, contact_list):
"""Sort the list of contacts in order by distance from key"""
ExpensiveSort(contact_list, self.distance.to_contact).sort()
# Send parallel, asynchronous FIND_NODE RPCs to the shortlist of contacts
def searchIteration(self):
self.slow_node_count[0] = len(self.active_probes)
# Sort the discovered active nodes from closest to furthest
self.sortByDistance(self.active_contacts)
# This makes sure a returning probe doesn't force calling this function by mistake
while len(self.pending_iteration_calls):
del self.pending_iteration_calls[0]
# See if we should continue the search
if self.key in self.find_value_result:
self.outer_d.callback(self.find_value_result)
return
elif len(self.active_contacts) and self.find_value is False:
if self._is_all_done():
# TODO: Re-send the FIND_NODEs to all of the k closest nodes not already queried
#
# Ok, we're done; either we have accumulated k active
# contacts or no improvement in closestNode has been
# noted
self.outer_d.callback(self.active_contacts)
return
# The search continues...
if len(self.active_contacts):
self.prev_closest_node[0] = self.active_contacts[0]
contactedNow = 0
self.sortByDistance(self.shortlist)
# Store the current shortList length before contacting other nodes
prevShortlistLength = len(self.shortlist)
for contact in self.shortlist:
if contact.id not in self.already_contacted:
self._probeContact(contact)
contactedNow += 1
if contactedNow == constants.alpha:
break
if self._should_lookup_active_calls():
# Schedule the next iteration if there are any active
# calls (Kademlia uses loose parallelism)
call, _ = self.node.reactor_callLater(constants.iterativeLookupDelay, self.searchIteration)
self.pending_iteration_calls.append(call)
# Check for a quick contact response that made an update to the shortList
elif prevShortlistLength < len(self.shortlist):
# Ensure that the closest contacts are taken from the updated shortList
self.searchIteration()
else:
# If no probes were sent, there will not be any improvement, so we're done
self.outer_d.callback(self.active_contacts)
def _probeContact(self, contact):
self.active_probes.append(contact.id)
rpcMethod = getattr(contact, self.rpc)
df = rpcMethod(self.key, rawResponse=True)
df.addCallback(self.extendShortlist)
df.addErrback(self.removeFromShortlist, contact.id)
df.addCallback(self.cancelActiveProbe)
df.addErrback(lambda _: log.exception('Failed to contact %s', contact))
self.already_contacted.append(contact.id)
def _should_lookup_active_calls(self):
return (
len(self.active_probes) > self.slow_node_count[0] or
(
len(self.shortlist) < constants.k and
len(self.active_contacts) < len(self.shortlist) and
len(self.active_probes) > 0
)
)
def _is_all_done(self):
return (
len(self.active_contacts) >= constants.k or
(
self.active_contacts[0] == self.prev_closest_node[0] and
len(self.active_probes) == self.slow_node_count[0]
)
)
class ExpensiveSort(object):
"""Sort a list in place.
The result of `key(item)` is cached for each item in the `to_sort`
list as an optimization. This can be useful when `key` is
expensive.
Attributes:
to_sort: a list of items to sort
key: callable, like `key` in normal python sort
attr: the attribute name used to cache the value on each item.
"""
def __init__(self, to_sort, key, attr='__value'):
self.to_sort = to_sort
self.key = key
self.attr = attr
def sort(self):
self._cacheValues()
self._sortByValue()
self._removeValue()
def _cacheValues(self):
for item in self.to_sort:
setattr(item, self.attr, self.key(item))
def _sortByValue(self):
self.to_sort.sort(key=operator.attrgetter(self.attr))
def _removeValue(self):
for item in self.to_sort:
delattr(item, self.attr)
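
The `store` RPC in this file builds a compact address by concatenating a 4-byte IP, a 2-byte big-endian port, and the node id, and the removed `getPeersForBlob` unpacked the same layout. A minimal sketch of that round trip follows, written for Python 3 rather than the Python 2 of the commit. The helper names `make_compact_address` and `expand_compact_address`, the returned tuple order, and the 48-byte node id length (assuming `constants.key_bits` is 384) are assumptions for illustration and are not taken from the commit.

```python
import socket
import struct

# Assumed node id length in bytes (constants.key_bits / 8 with 384-bit keys).
NODE_ID_LENGTH = 48


def make_compact_address(host, port, node_id):
    """Pack an IPv4 host, a TCP port, and a node id into one byte string."""
    # '>H' is an unsigned 16-bit integer, so the largest valid port is 65535.
    if not 0 <= port <= 65535:
        raise ValueError("invalid port: %i" % port)
    if len(node_id) != NODE_ID_LENGTH:
        raise ValueError("invalid node id length: %i" % len(node_id))
    return socket.inet_aton(host) + struct.pack('>H', port) + node_id


def expand_compact_address(compact_address):
    """Reverse make_compact_address, returning (host, port, node_id)."""
    host = socket.inet_ntoa(compact_address[:4])
    port, = struct.unpack('>H', compact_address[4:6])
    return host, port, compact_address[6:]
```

Packing a port of 65536 with `'>H'` would raise `struct.error`, which is why the sketch rejects it before packing.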

View file

@ -4,7 +4,6 @@ import logging
from zope.interface import implements from zope.interface import implements
from twisted.internet import defer from twisted.internet import defer
from lbrynet.interfaces import IPeerFinder from lbrynet.interfaces import IPeerFinder
from lbrynet.core.utils import short_hash
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -13,18 +12,9 @@ log = logging.getLogger(__name__)
class DummyPeerFinder(object): class DummyPeerFinder(object):
"""This class finds peers which have announced to the DHT that they have certain blobs""" """This class finds peers which have announced to the DHT that they have certain blobs"""
def run_manage_loop(self): def find_peers_for_blob(self, blob_hash, timeout=None, filter_self=True):
pass
def stop(self):
pass
def find_peers_for_blob(self, blob_hash):
return defer.succeed([]) return defer.succeed([])
def get_most_popular_hashes(self, num_to_return):
return []
class DHTPeerFinder(DummyPeerFinder): class DHTPeerFinder(DummyPeerFinder):
"""This class finds peers which have announced to the DHT that they have certain blobs""" """This class finds peers which have announced to the DHT that they have certain blobs"""
@ -39,11 +29,8 @@ class DHTPeerFinder(DummyPeerFinder):
self.peer_manager = peer_manager self.peer_manager = peer_manager
self.peers = [] self.peers = []
def stop(self):
pass
@defer.inlineCallbacks @defer.inlineCallbacks
def find_peers_for_blob(self, blob_hash, timeout=None, filter_self=False): def find_peers_for_blob(self, blob_hash, timeout=None, filter_self=True):
""" """
Find peers for blob in the DHT Find peers for blob in the DHT
blob_hash (str): blob hash to look for blob_hash (str): blob hash to look for
@ -54,32 +41,19 @@ class DHTPeerFinder(DummyPeerFinder):
Returns: Returns:
list of peers for the blob list of peers for the blob
""" """
def _trigger_timeout():
if not finished_deferred.called:
log.debug("Peer search for %s timed out", short_hash(blob_hash))
finished_deferred.cancel()
bin_hash = binascii.unhexlify(blob_hash) bin_hash = binascii.unhexlify(blob_hash)
finished_deferred = self.dht_node.getPeersForBlob(bin_hash) finished_deferred = self.dht_node.iterativeFindValue(bin_hash)
if timeout:
if timeout is not None: finished_deferred.addTimeout(timeout, self.dht_node.clock)
self.dht_node.reactor_callLater(timeout, _trigger_timeout)
try: try:
peer_list = yield finished_deferred peer_list = yield finished_deferred
except defer.CancelledError: except defer.TimeoutError:
peer_list = [] peer_list = []
peers = set(peer_list) peers = set(peer_list)
good_peers = [] results = []
for host, port in peers: for node_id, host, port in peers:
if filter_self and (host, port) == (self.dht_node.externalIP, self.dht_node.peerPort): if filter_self and (host, port) == (self.dht_node.externalIP, self.dht_node.peerPort):
continue continue
peer = self.peer_manager.get_peer(host, port) results.append(self.peer_manager.get_peer(host, port))
if peer.is_available() is True: defer.returnValue(results)
good_peers.append(peer)
defer.returnValue(good_peers)
def get_most_popular_hashes(self, num_to_return):
return self.dht_node.get_most_popular_hashes(num_to_return)
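
The new `find_peers_for_blob` deduplicates the `(node_id, host, port)` tuples returned by the DHT lookup and, when `filter_self` is set, skips the node's own address. A minimal sketch of that filtering step without Twisted or a peer manager follows. The function name `filter_peers` and the `own_address` parameter are hypothetical, and unlike the `set(peer_list)` in the commit this version deduplicates on `(host, port)` and preserves input order.

```python
def filter_peers(peer_list, own_address, filter_self=True):
    """Deduplicate (node_id, host, port) tuples and optionally drop our own address.

    Returns a list of (host, port) pairs in first-seen order.
    """
    results = []
    seen = set()
    for node_id, host, port in peer_list:
        address = (host, port)
        if address in seen:
            continue
        if filter_self and address == own_address:
            continue
        seen.add(address)
        results.append(address)
    return results
```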

View file

@ -1,20 +1,93 @@
import logging import logging
import socket import socket
import errno import errno
from collections import deque
from twisted.internet import protocol, defer from twisted.internet import protocol, defer
from lbrynet.core.call_later_manager import CallLaterManager from error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError, TransportNotConnected
import constants import constants
import encoding import encoding
import msgtypes import msgtypes
import msgformat import msgformat
from contact import Contact
from error import BUILTIN_EXCEPTIONS, UnknownRemoteException, TimeoutError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class PingQueue(object):
"""
Schedules a 15 minute delayed ping after a new node sends us a query. This is so the new node gets added to the
routing table after having been given enough time for a pinhole to expire.
"""
def __init__(self, node):
self._node = node
self._get_time = self._node.clock.seconds
self._queue = deque()
self._enqueued_contacts = {}
self._semaphore = defer.DeferredSemaphore(1)
self._ping_semaphore = defer.DeferredSemaphore(constants.alpha)
self._process_lc = node.get_looping_call(self._semaphore.run, self._process)
def _add_contact(self, contact, delay=None):
if contact in self._enqueued_contacts:
return defer.succeed(None)
 delay = constants.checkRefreshInterval if delay is None else delay
self._enqueued_contacts[contact] = self._get_time() + delay
self._queue.append(contact)
return defer.succeed(None)
@defer.inlineCallbacks
def _process(self):
if not len(self._queue):
defer.returnValue(None)
contact = self._queue.popleft()
now = self._get_time()
# if the oldest contact in the queue isn't old enough to be pinged, add it back to the queue and return
if now < self._enqueued_contacts[contact]:
self._queue.appendleft(contact)
defer.returnValue(None)
pinged = []
checked = []
while now > self._enqueued_contacts[contact]:
checked.append(contact)
if not contact.contact_is_good:
pinged.append(contact)
if not len(self._queue):
break
contact = self._queue.popleft()
if not now > self._enqueued_contacts[contact]:
checked.append(contact)
@defer.inlineCallbacks
def _ping(contact):
try:
yield contact.ping()
except TimeoutError:
pass
except Exception as err:
log.warning("unexpected error: %s", err)
yield defer.DeferredList([_ping(contact) for contact in pinged])
for contact in checked:
if contact in self._enqueued_contacts:
del self._enqueued_contacts[contact]
defer.returnValue(None)
def start(self):
return self._node.safe_start_looping_call(self._process_lc, 60)
def stop(self):
return self._node.safe_stop_looping_call(self._process_lc)
def enqueue_maybe_ping(self, contact, delay=None):
return self._semaphore.run(self._add_contact, contact, delay)
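
The bookkeeping behind `PingQueue` (a FIFO of contacts plus a due time per contact) can be sketched without Twisted. This is a simplified illustration under stated assumptions, not the commit's implementation: the class name `DelayedQueue`, the injected `get_time` callable, and the 900-second default (taken from the docstring's 15 minutes) are hypothetical. Unlike `_process`, it leaves contacts that are not yet due in the queue.

```python
from collections import deque


class DelayedQueue(object):
    """Hypothetical stand-in for PingQueue's bookkeeping, without Twisted."""

    def __init__(self, get_time, default_delay=900):
        self._get_time = get_time
        self._default_delay = default_delay
        self._queue = deque()
        self._due_at = {}

    def enqueue(self, contact, delay=None):
        # A contact that is already queued keeps its original due time.
        if contact in self._due_at:
            return False
        # Compare against None so that an explicit delay of 0 is honoured.
        if delay is None:
            delay = self._default_delay
        self._due_at[contact] = self._get_time() + delay
        self._queue.append(contact)
        return True

    def pop_due(self):
        """Return contacts at the head of the queue whose due time has passed."""
        now = self._get_time()
        due = []
        while self._queue and self._due_at[self._queue[0]] <= now:
            contact = self._queue.popleft()
            del self._due_at[contact]
            due.append(contact)
        return due
```

Only the head of the queue is checked, so a contact with a short delay waits behind an earlier contact with a longer one. A heap ordered by due time would avoid that at the cost of more bookkeeping.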
class KademliaProtocol(protocol.DatagramProtocol): class KademliaProtocol(protocol.DatagramProtocol):
""" Implements all low-level network-related functions of a Kademlia node """ """ Implements all low-level network-related functions of a Kademlia node """
@ -27,9 +100,41 @@ class KademliaProtocol(protocol.DatagramProtocol):
self._sentMessages = {} self._sentMessages = {}
self._partialMessages = {} self._partialMessages = {}
self._partialMessagesProgress = {} self._partialMessagesProgress = {}
self._listening = defer.Deferred(None)
self._ping_queue = PingQueue(self._node)
self._protocolVersion = constants.protocolVersion
def sendRPC(self, contact, method, args, rawResponse=False): def _migrate_incoming_rpc_args(self, contact, method, *args):
""" Sends an RPC to the specified contact if method == 'store' and contact.protocolVersion == 0:
if isinstance(args[1], dict):
blob_hash = args[0]
token = args[1].pop('token', None)
port = args[1].pop('port', -1)
originalPublisherID = args[1].pop('lbryid', None)
age = 0
return (blob_hash, token, port, originalPublisherID, age), {}
return args, {}
def _migrate_outgoing_rpc_args(self, contact, method, *args):
"""
This will reformat protocol version 0 arguments for the store function and will add the
protocol version keyword argument to calls to contacts who will accept it
"""
if contact.protocolVersion == 0:
if method == 'store':
blob_hash, token, port, originalPublisherID, age = args
args = (blob_hash, {'token': token, 'port': port, 'lbryid': originalPublisherID}, originalPublisherID,
False)
return args
return args
if args and isinstance(args[-1], dict):
args[-1]['protocolVersion'] = self._protocolVersion
return args
return args + ({'protocolVersion': self._protocolVersion},)
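Review note: the two `_migrate_*_rpc_args` helpers above are easier to follow as a round trip. The following is a minimal standalone sketch, not the code from this commit. It assumes a peer is identified only by its protocol version, that `PROTOCOL_VERSION = 1` stands in for `constants.protocolVersion` (the actual value is not shown in this diff), and it reads the legacy dict with `.get` instead of `.pop` so the input is left unmodified.

```python
PROTOCOL_VERSION = 1  # assumption: placeholder for constants.protocolVersion


def migrate_outgoing_rpc_args(peer_version, method, *args):
    """Reformat positional args so a peer of the given version can parse them."""
    if peer_version == 0:
        if method == 'store':
            blob_hash, token, port, original_publisher_id, age = args
            # Version 0 peers expect token/port/lbryid packed into one dict.
            legacy = {'token': token, 'port': port, 'lbryid': original_publisher_id}
            return (blob_hash, legacy, original_publisher_id, False)
        return args
    # Newer peers get the protocol version as a trailing keyword dict.
    if args and isinstance(args[-1], dict):
        args[-1]['protocolVersion'] = PROTOCOL_VERSION
        return args
    return args + ({'protocolVersion': PROTOCOL_VERSION},)


def migrate_incoming_rpc_args(peer_version, method, *args):
    """Turn a version 0 'store' call back into the new positional layout."""
    if method == 'store' and peer_version == 0 and isinstance(args[1], dict):
        legacy = args[1]
        blob_hash = args[0]
        token = legacy.get('token')
        port = legacy.get('port', -1)
        original_publisher_id = legacy.get('lbryid')
        age = 0  # version 0 requests carry no usable age, so it is reset
        return (blob_hash, token, port, original_publisher_id, age), {}
    return args, {}


sent = migrate_outgoing_rpc_args(0, 'store', b'hash', b'token', 3333, b'publisher', 0)
received, kwargs = migrate_incoming_rpc_args(0, 'store', *sent)
```

Under these assumptions a version 0 `store` call survives the round trip with its original five positional arguments, while calls to newer peers only gain the trailing `protocolVersion` dict.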
def sendRPC(self, contact, method, args):
"""
Sends an RPC to the specified contact
@param contact: The contact (remote node) to send the RPC to @param contact: The contact (remote node) to send the RPC to
@type contact: kademlia.contacts.Contact @type contact: kademlia.contacts.Contact
@ -38,14 +143,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
@param args: A list of (non-keyword) arguments to pass to the remote @param args: A list of (non-keyword) arguments to pass to the remote
method, in the correct order method, in the correct order
@type args: tuple @type args: tuple
@param rawResponse: If this is set to C{True}, the caller of this RPC
will receive a tuple containing the actual response
message object and the originating address tuple as
a result; in other words, it will not be
interpreted by this class. Unless something special
needs to be done with the metadata associated with
the message, this should remain C{False}.
@type rawResponse: bool
@return: This immediately returns a deferred object, which will return @return: This immediately returns a deferred object, which will return
the result of the RPC call, or raise the relevant exception the result of the RPC call, or raise the relevant exception
@ -55,29 +152,57 @@ class KademliaProtocol(protocol.DatagramProtocol):
C{ErrorMessage}). C{ErrorMessage}).
@rtype: twisted.internet.defer.Deferred @rtype: twisted.internet.defer.Deferred
""" """
msg = msgtypes.RequestMessage(self._node.node_id, method, args) msg = msgtypes.RequestMessage(self._node.node_id, method, self._migrate_outgoing_rpc_args(contact, method,
*args))
msgPrimitive = self._translator.toPrimitive(msg) msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive) encodedMsg = self._encoder.encode(msgPrimitive)
if args: if args:
log.debug("DHT SEND CALL %s(%s)", method, args[0].encode('hex')) log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
args[0].encode('hex'), contact.address, contact.port)
else: else:
log.debug("DHT SEND CALL %s", method) log.debug("%s:%i SEND CALL %s TO %s:%i", self._node.externalIP, self._node.port, method,
contact.address, contact.port)
df = defer.Deferred() df = defer.Deferred()
if rawResponse:
df._rpcRawResponse = True def _remove_contact(failure): # remove the contact from the routing table and track the failure
try:
self._node.removeContact(contact)
except (ValueError, IndexError):
pass
contact.update_last_failed()
return failure
def _update_contact(result): # refresh the contact in the routing table
contact.update_last_replied()
if method == 'findValue':
if 'protocolVersion' not in result:
contact.update_protocol_version(0)
else:
contact.update_protocol_version(result.pop('protocolVersion'))
d = self._node.addContact(contact)
d.addCallback(lambda _: result)
return d
df.addCallbacks(_update_contact, _remove_contact)
# Set the RPC timeout timer # Set the RPC timeout timer
timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id) timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, msg.id)
# Transmit the data # Transmit the data
self._send(encodedMsg, msg.id, (contact.address, contact.port)) self._send(encodedMsg, msg.id, (contact.address, contact.port))
self._sentMessages[msg.id] = (contact.id, df, timeoutCall, method, args) self._sentMessages[msg.id] = (contact, df, timeoutCall, cancelTimeout, method, args)
df.addErrback(cancelTimeout) df.addErrback(cancelTimeout)
return df return df
def startProtocol(self): def startProtocol(self):
log.info("DHT listening on UDP %i", self._node.port) log.info("DHT listening on UDP %s:%i", self._node.externalIP, self._node.port)
if self._listening.called:
self._listening = defer.Deferred()
self._listening.callback(True)
return self._ping_queue.start()
def datagramReceived(self, datagram, address): def datagramReceived(self, datagram, address):
""" Handles and parses incoming RPC messages (and responses) """ Handles and parses incoming RPC messages (and responses)
@ -115,49 +240,77 @@ class KademliaProtocol(protocol.DatagramProtocol):
log.warning("Couldn't decode dht datagram from %s", address) log.warning("Couldn't decode dht datagram from %s", address)
return return
remoteContact = Contact(message.nodeID, address[0], address[1], self)
# Refresh the remote node's details in the local node's k-buckets
self._node.addContact(remoteContact)
if isinstance(message, msgtypes.RequestMessage): if isinstance(message, msgtypes.RequestMessage):
# This is an RPC method request # This is an RPC method request
self._handleRPC(remoteContact, message.id, message.request, message.args) remoteContact = self._node.contact_manager.make_contact(message.nodeID, address[0], address[1], self)
remoteContact.update_last_requested()
# only add a requesting contact to the routing table if it has replied to one of our requests
if remoteContact.contact_is_good is True:
df = self._node.addContact(remoteContact)
else:
df = defer.succeed(None)
df.addCallback(lambda _: self._handleRPC(remoteContact, message.id, message.request, message.args))
# if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
# will be added to our routing table if successful
if remoteContact.contact_is_good is None and remoteContact.lastReplied is None:
df.addCallback(lambda _: self._ping_queue.enqueue_maybe_ping(remoteContact))
elif isinstance(message, msgtypes.ErrorMessage):
# The RPC request raised a remote exception; raise it locally
if message.exceptionType in BUILTIN_EXCEPTIONS:
exception_type = BUILTIN_EXCEPTIONS[message.exceptionType]
else:
exception_type = UnknownRemoteException
remoteException = exception_type(message.response)
log.error("DHT RECV REMOTE EXCEPTION FROM %s:%i: %s", address[0],
address[1], remoteException)
if message.id in self._sentMessages:
# Cancel timeout timer for this RPC
remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[message.id][0:5]
timeoutCanceller()
del self._sentMessages[message.id]
# reject replies coming from a different address than what we sent our request to
if (remoteContact.address, remoteContact.port) != address:
log.warning("Sent request to node %s at %s:%i, got reply from %s:%i",
remoteContact.log_id(), remoteContact.address,
remoteContact.port, address[0], address[1])
df.errback(TimeoutError(remoteContact.id))
return
# this error is returned by nodes that can be contacted but have an old
# and broken version of the ping command, if they return it the node can
# be contacted, so we'll treat it as a successful ping
old_ping_error = "ping() got an unexpected keyword argument '_rpcNodeContact'"
if isinstance(remoteException, TypeError) and \
remoteException.message == old_ping_error:
log.debug("old pong error")
df.callback('pong')
else:
df.errback(remoteException)
elif isinstance(message, msgtypes.ResponseMessage): elif isinstance(message, msgtypes.ResponseMessage):
# Find the message that triggered this response # Find the message that triggered this response
if message.id in self._sentMessages: if message.id in self._sentMessages:
# Cancel timeout timer for this RPC # Cancel timeout timer for this RPC
df, timeoutCall = self._sentMessages[message.id][1:3] remoteContact, df, timeoutCall, timeoutCanceller, method = self._sentMessages[message.id][0:5]
timeoutCall.cancel() timeoutCanceller()
del self._sentMessages[message.id] del self._sentMessages[message.id]
log.debug("%s:%i RECV response to %s from %s:%i", self._node.externalIP, self._node.port,
method, remoteContact.address, remoteContact.port)
if hasattr(df, '_rpcRawResponse'): # When joining the network we made Contact objects for the seed nodes with node ids set to None
# The RPC requested that the raw response message # Thus, the sent_to_id will also be None, and the contact objects need the ids to be manually set.
# and originating address be returned; do not # These replies have to be distinguished from those where the node id in the datagram does not match
# interpret it # the node id of the node we sent a message to (these messages are treated as an error)
df.callback((message, address)) if remoteContact.id and remoteContact.id != message.nodeID: # sent_to_id will be None for bootstrap
elif isinstance(message, msgtypes.ErrorMessage): log.debug("mismatch: (%s) %s:%i (%s vs %s)", method, remoteContact.address, remoteContact.port,
# The RPC request raised a remote exception; raise it locally remoteContact.log_id(False), message.nodeID.encode('hex'))
if message.exceptionType in BUILTIN_EXCEPTIONS: df.errback(TimeoutError(remoteContact.id))
exception_type = BUILTIN_EXCEPTIONS[message.exceptionType] return
else: elif not remoteContact.id:
exception_type = UnknownRemoteException remoteContact.set_id(message.nodeID)
remoteException = exception_type(message.response)
# this error is returned by nodes that can be contacted but have an old # We got a result from the RPC
# and broken version of the ping command, if they return it the node can df.callback(message.response)
# be contacted, so we'll treat it as a successful ping
old_ping_error = "ping() got an unexpected keyword argument '_rpcNodeContact'"
if isinstance(remoteException, TypeError) and \
remoteException.message == old_ping_error:
log.debug("old pong error")
df.callback('pong')
else:
log.error("DHT RECV REMOTE EXCEPTION FROM %s:%i: %s", address[0],
address[1], remoteException)
df.errback(remoteException)
else:
# We got a result from the RPC
df.callback(message.response)
else: else:
# If the original message isn't found, it must have timed out # If the original message isn't found, it must have timed out
# TODO: we should probably do something with this... # TODO: we should probably do something with this...
@ -222,10 +375,11 @@ class KademliaProtocol(protocol.DatagramProtocol):
# this should probably try to retransmit when the network connection is back # this should probably try to retransmit when the network connection is back
log.error("Network is unreachable") log.error("Network is unreachable")
else: else:
log.error("DHT socket error: %s (%i)", err.message, err.errno) log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)",
len(txData), address[0], address[1], err.message, err.errno)
raise err raise err
else: else:
log.warning("transport not connected!") raise TransportNotConnected()
def _sendResponse(self, contact, rpcID, response): def _sendResponse(self, contact, rpcID, response):
""" Send a RPC response to the specified contact """ Send a RPC response to the specified contact
@ -259,28 +413,34 @@ class KademliaProtocol(protocol.DatagramProtocol):
# Execute the RPC # Execute the RPC
func = getattr(self._node, method, None) func = getattr(self._node, method, None)
if callable(func) and hasattr(func, 'rpcmethod'): if callable(func) and hasattr(func, "rpcmethod"):
# Call the exposed Node method and return the result to the deferred callback chain # Call the exposed Node method and return the result to the deferred callback chain
if args: # if args:
log.debug("DHT RECV CALL %s(%s) %s:%i", method, args[0].encode('hex'), # log.debug("%s:%i RECV CALL %s(%s) %s:%i", self._node.externalIP, self._node.port, method,
# args[0].encode('hex'), senderContact.address, senderContact.port)
# else:
log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method,
senderContact.address, senderContact.port) senderContact.address, senderContact.port)
if args and isinstance(args[-1], dict) and 'protocolVersion' in args[-1]: # args don't need reformatting
senderContact.update_protocol_version(int(args[-1].pop('protocolVersion')))
a, kw = tuple(args[:-1]), args[-1]
else: else:
log.debug("DHT RECV CALL %s %s:%i", method, senderContact.address, senderContact.update_protocol_version(0)
senderContact.port) a, kw = self._migrate_incoming_rpc_args(senderContact, method, *args)
try: try:
if method != 'ping': if method != 'ping':
kwargs = {'_rpcNodeID': senderContact.id, '_rpcNodeContact': senderContact} result = func(senderContact, *a)
result = func(*args, **kwargs)
else: else:
result = func() result = func()
except Exception, e: except Exception, e:
log.exception("error handling request for %s: %s", senderContact.address, method) log.exception("error handling request for %s:%i %s", senderContact.address, senderContact.port, method)
df.errback(e) df.errback(e)
else: else:
df.callback(result) df.callback(result)
else: else:
# No such exposed method # No such exposed method
df.errback(AttributeError('Invalid method: %s' % method)) df.errback(AttributeError('Invalid method: %s' % method))
return df
def _msgTimeout(self, messageID): def _msgTimeout(self, messageID):
""" Called when an RPC request message times out """ """ Called when an RPC request message times out """
@ -289,30 +449,30 @@ class KademliaProtocol(protocol.DatagramProtocol):
# This should never be reached # This should never be reached
log.error("deferred timed out, but is not present in sent messages list!") log.error("deferred timed out, but is not present in sent messages list!")
return return
remoteContactID, df, timeout_call, method, args = self._sentMessages[messageID] remoteContact, df, timeout_call, timeout_canceller, method, args = self._sentMessages[messageID]
if self._partialMessages.has_key(messageID): if self._partialMessages.has_key(messageID):
# We are still receiving this message # We are still receiving this message
self._msgTimeoutInProgress(messageID, remoteContactID, df, method, args) self._msgTimeoutInProgress(messageID, timeout_canceller, remoteContact, df, method, args)
return return
del self._sentMessages[messageID] del self._sentMessages[messageID]
# The message's destination node is now considered to be dead; # The message's destination node is now considered to be dead;
# raise an (asynchronous) TimeoutError exception and update the host node # raise an (asynchronous) TimeoutError exception and update the host node
self._node.removeContact(remoteContactID) df.errback(TimeoutError(remoteContact.id))
df.errback(TimeoutError(remoteContactID))
def _msgTimeoutInProgress(self, messageID, remoteContactID, df, method, args): def _msgTimeoutInProgress(self, messageID, timeoutCanceller, remoteContact, df, method, args):
# See if any progress has been made; if not, kill the message # See if any progress has been made; if not, kill the message
if self._hasProgressBeenMade(messageID): if self._hasProgressBeenMade(messageID):
# Reset the RPC timeout timer # Reset the RPC timeout timer
timeoutCall, _ = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID) timeoutCanceller()
self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args) timeoutCall, cancelTimeout = self._node.reactor_callLater(constants.rpcTimeout, self._msgTimeout, messageID)
self._sentMessages[messageID] = (remoteContact, df, timeoutCall, cancelTimeout, method, args)
else: else:
# No progress has been made # No progress has been made
if messageID in self._partialMessagesProgress: if messageID in self._partialMessagesProgress:
del self._partialMessagesProgress[messageID] del self._partialMessagesProgress[messageID]
if messageID in self._partialMessages: if messageID in self._partialMessages:
del self._partialMessages[messageID] del self._partialMessages[messageID]
df.errback(TimeoutError(remoteContactID)) df.errback(TimeoutError(remoteContact.id))
def _hasProgressBeenMade(self, messageID): def _hasProgressBeenMade(self, messageID):
return ( return (
@ -329,5 +489,6 @@ class KademliaProtocol(protocol.DatagramProtocol):
Will only be called once, after all ports are disconnected. Will only be called once, after all ports are disconnected.
""" """
log.info('Stopping DHT') log.info('Stopping DHT')
CallLaterManager.stop() self._ping_queue.stop()
self._node.call_later_manager.stop()
log.info('DHT stopped') log.info('DHT stopped')

View file

@ -7,9 +7,11 @@
import random import random
from zope.interface import implements from zope.interface import implements
from twisted.internet import defer
import constants import constants
import kbucket import kbucket
import protocol from error import TimeoutError
from distance import Distance
from interface import IRoutingTable from interface import IRoutingTable
import logging import logging
@ -40,72 +42,106 @@ class TreeRoutingTable(object):
@type parentNodeID: str @type parentNodeID: str
""" """
# Create the initial (single) k-bucket covering the range of the entire n-bit ID space # Create the initial (single) k-bucket covering the range of the entire n-bit ID space
self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits)]
self._parentNodeID = parentNodeID self._parentNodeID = parentNodeID
self._buckets = [kbucket.KBucket(rangeMin=0, rangeMax=2 ** constants.key_bits, node_id=self._parentNodeID)]
if not getTime: if not getTime:
from time import time as getTime from twisted.internet import reactor
getTime = reactor.seconds
self._getTime = getTime self._getTime = getTime
def get_contacts(self):
contacts = []
for i in range(len(self._buckets)):
for contact in self._buckets[i]._contacts:
contacts.append(contact)
return contacts
def _shouldSplit(self, bucketIndex, toAdd):
# https://stackoverflow.com/questions/32129978/highly-unbalanced-kademlia-routing-table/32187456#32187456
if self._buckets[bucketIndex].keyInRange(self._parentNodeID):
return True
contacts = self.get_contacts()
distance = Distance(self._parentNodeID)
contacts.sort(key=lambda c: distance(c.id))
kth_contact = contacts[-1] if len(contacts) < constants.k else contacts[constants.k-1]
return distance(toAdd) < distance(kth_contact.id)
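Review note: the split rule in `_shouldSplit` can be checked in isolation. This is a hedged sketch rather than the committed code: node IDs are plain integers instead of byte strings, `xor_distance` stands in for the `Distance` helper, `k` stands in for `constants.k`, and `known_ids` is assumed to be non-empty (the method is only reached after a bucket reported itself full).

```python
def xor_distance(a, b):
    """Kademlia distance between two integer node IDs."""
    return a ^ b


def should_split(bucket_min, bucket_max, own_id, known_ids, candidate_id, k):
    """Decide whether a full bucket covering [bucket_min, bucket_max) may be split."""
    # The bucket that covers our own ID can always be split.
    if bucket_min <= own_id < bucket_max:
        return True
    # Otherwise split only if the candidate is closer to us than the k-th
    # closest contact we already know about across the whole routing table.
    ordered = sorted(known_ids, key=lambda node_id: xor_distance(own_id, node_id))
    kth = ordered[-1] if len(ordered) < k else ordered[k - 1]
    return xor_distance(own_id, candidate_id) < xor_distance(own_id, kth)


# Our ID is 0; the full bucket covers [8, 16) and k is 2 for illustration.
closer = should_split(8, 16, 0, [9, 10, 12], 8, k=2)
farther = should_split(8, 16, 0, [9, 10, 12], 11, k=2)
```

With these inputs the k-th closest known ID is 10, so a candidate at distance 8 would justify a split and one at distance 11 would not.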
def addContact(self, contact): def addContact(self, contact):
""" Add the given contact to the correct k-bucket; if it already """ Add the given contact to the correct k-bucket; if it already
exists, its status will be updated exists, its status will be updated
@param contact: The contact to add to this node's k-buckets @param contact: The contact to add to this node's k-buckets
@type contact: kademlia.contact.Contact @type contact: kademlia.contact.Contact
"""
if contact.id == self._parentNodeID:
return
@rtype: defer.Deferred
"""
if contact.id == self._parentNodeID:
return defer.succeed(None)
bucketIndex = self._kbucketIndex(contact.id) bucketIndex = self._kbucketIndex(contact.id)
try: try:
self._buckets[bucketIndex].addContact(contact) self._buckets[bucketIndex].addContact(contact)
except kbucket.BucketFull: except kbucket.BucketFull:
# The bucket is full; see if it can be split (by checking # The bucket is full; see if it can be split (by checking if its range includes the host node's id)
# if its range includes the host node's id) if self._shouldSplit(bucketIndex, contact.id):
if self._buckets[bucketIndex].keyInRange(self._parentNodeID):
self._splitBucket(bucketIndex) self._splitBucket(bucketIndex)
# Retry the insertion attempt # Retry the insertion attempt
self.addContact(contact) return self.addContact(contact)
else: else:
# We can't split the k-bucket # We can't split the k-bucket
# NOTE: #
# In section 2.4 of the 13-page version of the # The 13 page kademlia paper specifies that the least recently contacted node in the bucket
# Kademlia paper, it is specified that in this case, # shall be pinged. If it fails to reply it is replaced with the new contact. If the ping is successful
# the new contact should simply be dropped. However, # the new contact is ignored and not added to the bucket (sections 2.2 and 2.4).
# in section 2.2, it states that the head contact in #
# the k-bucket (i.e. the least-recently seen node) # A reasonable extension to this is BEP 0005, which extends the above:
# should be pinged - if it does not reply, it should #
# be dropped, and the new contact added to the tail of # Not all nodes that we learn about are equal. Some are "good" and some are not.
# the k-bucket. This implementation follows section # Many nodes using the DHT are able to send queries and receive responses,
# 2.2 regarding this point. # but are not able to respond to queries from other nodes. It is important that
# each node's routing table must contain only known good nodes. A good node is
# a node that has responded to one of our queries within the last 15 minutes. A node
# is also good if it has ever responded to one of our queries and has sent us a
# query within the last 15 minutes. After 15 minutes of inactivity, a node becomes
# questionable. Nodes become bad when they fail to respond to multiple queries
# in a row. Nodes that we know are good are given priority over nodes with unknown status.
#
# When there are bad or questionable nodes in the bucket, the least recent is selected for
# potential replacement (BEP 0005). When all nodes in the bucket are fresh, the head (least recent)
# contact is selected as described in section 2.2 of the kademlia paper. In both cases the new contact
# is ignored if the pinged node replies.
def replaceContact(failure, deadContactID): def replaceContact(failure, deadContact):
""" Callback for the deferred PING RPC to see if the head """
node in the k-bucket is still responding Callback for the deferred PING RPC to see if the node to be replaced in the k-bucket is still
responding
@type failure: twisted.python.failure.Failure @type failure: twisted.python.failure.Failure
""" """
failure.trap(protocol.TimeoutError) failure.trap(TimeoutError)
if len(deadContactID) != constants.key_bits / 8: log.debug("Replacing dead contact in bucket %i: %s:%i (%s) with %s:%i (%s)", bucketIndex,
raise ValueError("invalid contact id") deadContact.address, deadContact.port, deadContact.log_id(), contact.address,
log.debug("Replacing dead contact: %s", deadContactID.encode('hex')) contact.port, contact.log_id())
try: try:
# Remove the old contact... self._buckets[bucketIndex].removeContact(deadContact)
self._buckets[bucketIndex].removeContact(deadContactID)
except ValueError: except ValueError:
# The contact has already been removed (probably due to a timeout) # The contact has already been removed (probably due to a timeout)
pass pass
# ...and add the new one at the tail of the bucket return self.addContact(contact)
self.addContact(contact)
# Ping the least-recently seen contact in this k-bucket not_good_contacts = self._buckets[bucketIndex].getBadOrUnknownContacts()
head_contact = self._buckets[bucketIndex]._contacts[0] if not_good_contacts:
df = head_contact.ping() to_replace = not_good_contacts[0]
# If there's an error (i.e. timeout), remove the head else:
# contact, and append the new one to_replace = self._buckets[bucketIndex]._contacts[0]
df.addErrback(replaceContact, head_contact.id) df = to_replace.ping()
df.addErrback(replaceContact, to_replace)
return df
else:
self.touchKBucketByIndex(bucketIndex)
return defer.succeed(None)
def findCloseNodes(self, key, count, _rpcNodeID=None): def findCloseNodes(self, key, count, sender_node_id=None):
""" Finds a number of known nodes closest to the node/value with the """ Finds a number of known nodes closest to the node/value with the
specified key. specified key.
@ -113,10 +149,10 @@ class TreeRoutingTable(object):
@type key: str @type key: str
@param count: the amount of contacts to return @param count: the amount of contacts to return
@type count: int @type count: int
@param _rpcNodeID: Used during RPC, this is the sender's Node ID @param sender_node_id: Used during RPC, this is the sender's Node ID
Whatever ID is passed in the parameter will get Whatever ID is passed in the parameter will get
excluded from the list of returned contacts. excluded from the list of returned contacts.
@type _rpcNodeID: str @type sender_node_id: str
@return: A list of node contacts (C{kademlia.contact.Contact instances}) @return: A list of node contacts (C{kademlia.contact.Contact instances})
closest to the specified key. closest to the specified key.
@ -128,7 +164,8 @@ class TreeRoutingTable(object):
bucketIndex = self._kbucketIndex(key) bucketIndex = self._kbucketIndex(key)
if bucketIndex < len(self._buckets): if bucketIndex < len(self._buckets):
closestNodes = self._buckets[bucketIndex].getContacts(count, _rpcNodeID) # sort these
closestNodes = self._buckets[bucketIndex].getContacts(count, sender_node_id, sort_distance_to=key)
else: else:
closestNodes = [] closestNodes = []
# This method must return k contacts (even if we have the node # This method must return k contacts (even if we have the node
@ -141,21 +178,27 @@ class TreeRoutingTable(object):
def get_remain(closest): def get_remain(closest):
return min(count, constants.k) - len(closest) return min(count, constants.k) - len(closest)
# Fill up the node list to k nodes, starting with the closest neighbouring nodes known distance = Distance(key)
while len(closestNodes) < min(count, constants.k) and (canGoLower or canGoHigher): while len(closestNodes) < min(count, constants.k) and (canGoLower or canGoHigher):
# TODO: this may need to be optimized iteration_contacts = []
# TODO: add "key" kwarg to getContacts() to sort contacts returned by xor distance # get contacts from lower and/or higher buckets without sorting them
# to the key
if canGoLower and len(closestNodes) < min(count, constants.k): if canGoLower and len(closestNodes) < min(count, constants.k):
closestNodes.extend( lower_bucket = self._buckets[bucketIndex - i]
self._buckets[bucketIndex - i].getContacts(get_remain(closestNodes), contacts = lower_bucket.getContacts(get_remain(closestNodes), sender_node_id, sort_distance_to=False)
_rpcNodeID)) iteration_contacts.extend(contacts)
canGoLower = bucketIndex - (i + 1) >= 0 canGoLower = bucketIndex - (i + 1) >= 0
if canGoHigher and len(closestNodes) < min(count, constants.k): if canGoHigher and len(closestNodes) < min(count, constants.k):
closestNodes.extend(self._buckets[bucketIndex + i].getContacts( higher_bucket = self._buckets[bucketIndex + i]
get_remain(closestNodes), _rpcNodeID)) contacts = higher_bucket.getContacts(get_remain(closestNodes), sender_node_id, sort_distance_to=False)
iteration_contacts.extend(contacts)
canGoHigher = bucketIndex + (i + 1) < len(self._buckets) canGoHigher = bucketIndex + (i + 1) < len(self._buckets)
i += 1 i += 1
# sort the combined contacts and add as many as possible/needed to the combined contact list
iteration_contacts.sort(key=lambda c: distance(c.id), reverse=True)
while len(iteration_contacts) and len(closestNodes) < min(count, constants.k):
closestNodes.append(iteration_contacts.pop())
return closestNodes return closestNodes
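Review note: the new `findCloseNodes` loop widens the search one bucket at a time in both directions and sorts each batch by XOR distance before taking from it. A simplified sketch of that idea follows. It assumes buckets are plain lists of integer IDs, and it leaves out the sender exclusion and the `min(count, constants.k)` cap that the real method applies.

```python
def find_close_ids(buckets, bucket_index, key, count):
    """Return up to `count` IDs, starting with the bucket that covers `key`."""
    def distance(node_id):
        return node_id ^ key

    # Contacts from the key's own bucket come first, closest first.
    closest = sorted(buckets[bucket_index], key=distance)[:count]
    i = 1
    can_go_lower = bucket_index - i >= 0
    can_go_higher = bucket_index + i < len(buckets)
    while len(closest) < count and (can_go_lower or can_go_higher):
        batch = []
        if can_go_lower:
            batch.extend(buckets[bucket_index - i])
            can_go_lower = bucket_index - (i + 1) >= 0
        if can_go_higher:
            batch.extend(buckets[bucket_index + i])
            can_go_higher = bucket_index + (i + 1) < len(buckets)
        i += 1
        # Sort farthest-first so that pop() yields the closest remaining ID.
        batch.sort(key=distance, reverse=True)
        while batch and len(closest) < count:
            closest.append(batch.pop())
    return closest


buckets = [[1, 2], [5, 6], [9, 12]]
three = find_close_ids(buckets, 1, 4, 3)
everything = find_close_ids(buckets, 1, 4, 10)
```

Sorting each batch before appending is what keeps a farther contact from a lower bucket from being preferred over a closer one from a higher bucket within the same iteration.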
def getContact(self, contactID): def getContact(self, contactID):
@ -195,23 +238,24 @@ class TreeRoutingTable(object):
""" """
bucketIndex = startIndex bucketIndex = startIndex
refreshIDs = [] refreshIDs = []
now = int(self._getTime())
for bucket in self._buckets[startIndex:]: for bucket in self._buckets[startIndex:]:
if force or (int(self._getTime()) - bucket.lastAccessed >= constants.refreshTimeout): if force or now - bucket.lastAccessed >= constants.refreshTimeout:
searchID = self._randomIDInBucketRange(bucketIndex) searchID = self._randomIDInBucketRange(bucketIndex)
refreshIDs.append(searchID) refreshIDs.append(searchID)
bucketIndex += 1 bucketIndex += 1
return refreshIDs return refreshIDs
def removeContact(self, contactID): def removeContact(self, contact):
""" Remove the contact with the specified node ID from the routing
table
@param contactID: The node ID of the contact to remove
@type contactID: str
""" """
bucketIndex = self._kbucketIndex(contactID) Remove the contact from the routing table
@param contact: The contact to remove
@type contact: dht.contact._Contact
"""
bucketIndex = self._kbucketIndex(contact.id)
try: try:
self._buckets[bucketIndex].removeContact(contactID) self._buckets[bucketIndex].removeContact(contact)
except ValueError: except ValueError:
return return
@ -222,7 +266,9 @@ class TreeRoutingTable(object):
@param key: A key in the range of the target k-bucket @param key: A key in the range of the target k-bucket
@type key: str @type key: str
""" """
bucketIndex = self._kbucketIndex(key) self.touchKBucketByIndex(self._kbucketIndex(key))
def touchKBucketByIndex(self, bucketIndex):
self._buckets[bucketIndex].lastAccessed = int(self._getTime()) self._buckets[bucketIndex].lastAccessed = int(self._getTime())
def _kbucketIndex(self, key): def _kbucketIndex(self, key):
@ -272,7 +318,7 @@ class TreeRoutingTable(object):
oldBucket = self._buckets[oldBucketIndex] oldBucket = self._buckets[oldBucketIndex]
splitPoint = oldBucket.rangeMax - (oldBucket.rangeMax - oldBucket.rangeMin) / 2 splitPoint = oldBucket.rangeMax - (oldBucket.rangeMax - oldBucket.rangeMin) / 2
# Create a new k-bucket to cover the range split off from the old bucket # Create a new k-bucket to cover the range split off from the old bucket
newBucket = kbucket.KBucket(splitPoint, oldBucket.rangeMax) newBucket = kbucket.KBucket(splitPoint, oldBucket.rangeMax, self._parentNodeID)
oldBucket.rangeMax = splitPoint oldBucket.rangeMax = splitPoint
# Now, add the new bucket into the routing table tree # Now, add the new bucket into the routing table tree
self._buckets.insert(oldBucketIndex + 1, newBucket) self._buckets.insert(oldBucketIndex + 1, newBucket)
@ -284,76 +330,16 @@ class TreeRoutingTable(object):
for contact in newBucket._contacts: for contact in newBucket._contacts:
oldBucket.removeContact(contact) oldBucket.removeContact(contact)
def contactInRoutingTable(self, address_tuple):
for bucket in self._buckets:
for contact in bucket.getContacts(sort_distance_to=False):
if address_tuple[0] == contact.address and address_tuple[1] == contact.port:
return True
return False
class OptimizedTreeRoutingTable(TreeRoutingTable): def bucketsWithContacts(self):
""" A version of the "tree"-type routing table specified by Kademlia, count = 0
along with contact accounting optimizations specified in section 4.1 of for bucket in self._buckets:
of the 13-page version of the Kademlia paper. if len(bucket):
""" count += 1
return count
def __init__(self, parentNodeID, getTime=None):
TreeRoutingTable.__init__(self, parentNodeID, getTime)
# Cache containing nodes eligible to replace stale k-bucket entries
self._replacementCache = {}
def addContact(self, contact):
""" Add the given contact to the correct k-bucket; if it already
exists, its status will be updated
@param contact: The contact to add to this node's k-buckets
@type contact: kademlia.contact.Contact
"""
if contact.id == self._parentNodeID:
return
# Initialize/reset the "successively failed RPC" counter
contact.failedRPCs = 0
bucketIndex = self._kbucketIndex(contact.id)
try:
self._buckets[bucketIndex].addContact(contact)
except kbucket.BucketFull:
# The bucket is full; see if it can be split (by checking
# if its range includes the host node's id)
if self._buckets[bucketIndex].keyInRange(self._parentNodeID):
self._splitBucket(bucketIndex)
# Retry the insertion attempt
self.addContact(contact)
else:
# We can't split the k-bucket
# NOTE: This implementation follows section 4.1 of the 13 page version
# of the Kademlia paper (optimized contact accounting without PINGs
# - results in much less network traffic, at the expense of some memory)
# Put the new contact in our replacement cache for the
# corresponding k-bucket (or update it's position if
# it exists already)
if bucketIndex not in self._replacementCache:
self._replacementCache[bucketIndex] = []
if contact in self._replacementCache[bucketIndex]:
self._replacementCache[bucketIndex].remove(contact)
elif len(self._replacementCache[bucketIndex]) >= constants.replacementCacheSize:
self._replacementCache[bucketIndex].pop(0)
self._replacementCache[bucketIndex].append(contact)
def removeContact(self, contactID):
""" Remove the contact with the specified node ID from the routing
table
@param contactID: The node ID of the contact to remove
@type contactID: str
"""
bucketIndex = self._kbucketIndex(contactID)
try:
contact = self._buckets[bucketIndex].getContact(contactID)
except ValueError:
return
contact.failedRPCs += 1
if contact.failedRPCs >= constants.rpcAttempts:
self._buckets[bucketIndex].removeContact(contactID)
# Replace this stale contact with one from our replacement cache, if we have any
if bucketIndex in self._replacementCache:
if len(self._replacementCache[bucketIndex]) > 0:
self._buckets[bucketIndex].addContact(
self._replacementCache[bucketIndex].pop())
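
The replacement-cache accounting performed by `addContact` and `removeContact` above can be sketched in isolation. This is a minimal illustration, not the project's code: `ReplacementCache`, `add`, `pop_replacement`, and `max_size` are hypothetical names, and `max_size` stands in for `constants.replacementCacheSize`.

```python
class ReplacementCache(object):
    """Hypothetical stand-alone version of the per-bucket replacement cache."""

    def __init__(self, max_size):
        self.max_size = max_size  # stand-in for constants.replacementCacheSize
        self._cache = {}          # bucket index -> list of contacts, oldest first

    def add(self, bucket_index, contact):
        entries = self._cache.setdefault(bucket_index, [])
        if contact in entries:
            # already cached: drop it so it can be re-appended as most recent
            entries.remove(contact)
        elif len(entries) >= self.max_size:
            # cache is full: evict the least recently seen candidate
            entries.pop(0)
        entries.append(contact)

    def pop_replacement(self, bucket_index):
        # the most recently seen candidate replaces a stale bucket entry
        entries = self._cache.get(bucket_index)
        if entries:
            return entries.pop()
        return None


cache = ReplacementCache(max_size=2)
for name in ("a", "b", "a", "c"):
    cache.add(0, name)
# bucket 0 now holds ["a", "c"]: "b" was evicted when "c" arrived
```

Popping from the end of the list means a stale contact is replaced by the candidate seen most recently, which is the one most likely to still be online.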

View file

@ -56,18 +56,21 @@ class ManagedEncryptedFileDownloader(EncryptedFileSaver):
         self.channel_name = None
         self.metadata = None
+    def set_claim_info(self, claim_info):
+        self.claim_id = claim_info['claim_id']
+        self.txid = claim_info['txid']
+        self.nout = claim_info['nout']
+        self.channel_claim_id = claim_info['channel_claim_id']
+        self.outpoint = "%s:%i" % (self.txid, self.nout)
+        self.claim_name = claim_info['name']
+        self.channel_name = claim_info['channel_name']
+        self.metadata = claim_info['value']['stream']['metadata']
     @defer.inlineCallbacks
     def get_claim_info(self, include_supports=True):
         claim_info = yield self.storage.get_content_claim(self.stream_hash, include_supports)
         if claim_info:
-            self.claim_id = claim_info['claim_id']
-            self.txid = claim_info['txid']
-            self.nout = claim_info['nout']
-            self.channel_claim_id = claim_info['channel_claim_id']
-            self.outpoint = "%s:%i" % (self.txid, self.nout)
-            self.claim_name = claim_info['name']
-            self.channel_name = claim_info['channel_name']
-            self.metadata = claim_info['value']['stream']['metadata']
+            self.set_claim_info(claim_info)
         defer.returnValue(claim_info)

View file

@ -6,12 +6,11 @@ import logging
 from twisted.internet import defer, task, reactor
 from twisted.python.failure import Failure
-from lbrynet.core.Error import InvalidStreamDescriptorError
 from lbrynet.reflector.reupload import reflect_file
-from lbrynet.core.PaymentRateManager import NegotiatedPaymentRateManager
+# from lbrynet.core.PaymentRateManager import NegotiatedPaymentRateManager
 from lbrynet.file_manager.EncryptedFileDownloader import ManagedEncryptedFileDownloader
 from lbrynet.file_manager.EncryptedFileDownloader import ManagedEncryptedFileDownloaderFactory
-from lbrynet.core.StreamDescriptor import EncryptedFileStreamType, get_sd_info, validate_descriptor
+from lbrynet.core.StreamDescriptor import EncryptedFileStreamType, get_sd_info
 from lbrynet.cryptstream.client.CryptStreamDownloader import AlreadyStoppedError
 from lbrynet.cryptstream.client.CryptStreamDownloader import CurrentlyStoppingError
 from lbrynet.core.utils import safe_start_looping_call, safe_stop_looping_call
@ -96,47 +95,34 @@ class EncryptedFileManager(object):
             suggested_file_name=suggested_file_name
         )
-    @defer.inlineCallbacks
-    def _start_lbry_file(self, file_info, payment_rate_manager):
+    def _start_lbry_file(self, file_info, payment_rate_manager, claim_info):
         lbry_file = self._get_lbry_file(
             file_info['row_id'], file_info['stream_hash'], payment_rate_manager, file_info['sd_hash'],
             file_info['key'], file_info['stream_name'], file_info['file_name'], file_info['download_directory'],
             file_info['suggested_file_name']
         )
-        yield lbry_file.get_claim_info()
+        if claim_info:
+            lbry_file.set_claim_info(claim_info)
         try:
-            # verify the stream is valid (we might have downloaded an invalid stream
-            # in the past when the validation check didn't work)
-            stream_info = yield get_sd_info(self.storage, file_info['stream_hash'], include_blobs=True)
-            validate_descriptor(stream_info)
-        except InvalidStreamDescriptorError as err:
-            log.warning("Stream for descriptor %s is invalid (%s), cleaning it up",
-                        lbry_file.sd_hash, err.message)
-            yield lbry_file.delete_data()
-            yield self.session.storage.delete_stream(lbry_file.stream_hash)
-        else:
-            try:
-                # restore will raise an Exception if status is unknown
-                lbry_file.restore(file_info['status'])
-                self.storage.content_claim_callbacks[lbry_file.stream_hash] = lbry_file.get_claim_info
-                self.lbry_files.append(lbry_file)
-                if len(self.lbry_files) % 500 == 0:
-                    log.info("Started %i files", len(self.lbry_files))
-            except Exception:
-                log.warning("Failed to start %i", file_info.get('rowid'))
+            # restore will raise an Exception if status is unknown
+            lbry_file.restore(file_info['status'])
+            self.storage.content_claim_callbacks[lbry_file.stream_hash] = lbry_file.get_claim_info
+            self.lbry_files.append(lbry_file)
+            if len(self.lbry_files) % 500 == 0:
+                log.info("Started %i files", len(self.lbry_files))
+        except Exception:
+            log.warning("Failed to start %i", file_info.get('rowid'))
     @defer.inlineCallbacks
     def _start_lbry_files(self):
         files = yield self.session.storage.get_all_lbry_files()
-        b_prm = self.session.base_payment_rate_manager
-        payment_rate_manager = NegotiatedPaymentRateManager(b_prm, self.session.blob_tracker)
+        claim_infos = yield self.session.storage.get_claims_from_stream_hashes([file['stream_hash'] for file in files])
+        prm = self.session.payment_rate_manager
         log.info("Starting %i files", len(files))
-        dl = []
         for file_info in files:
-            dl.append(self._start_lbry_file(file_info, payment_rate_manager))
+            claim_info = claim_infos.get(file_info['stream_hash'])
+            self._start_lbry_file(file_info, prm, claim_info)
-        yield defer.DeferredList(dl)
         log.info("Started %i lbry files", len(self.lbry_files))
         if self.auto_re_reflect is True:

View file

@ -0,0 +1,175 @@
import logging
from twisted.trial import unittest
from twisted.internet import defer, task
from lbrynet.dht import constants
from lbrynet.dht.node import Node
from mock_transport import resolve, listenUDP, MOCK_DHT_SEED_DNS, mock_node_generator
log = logging.getLogger(__name__)
class TestKademliaBase(unittest.TestCase):
timeout = 300.0 # timeout for each test
network_size = 16 # including seed nodes
node_ids = None
seed_dns = MOCK_DHT_SEED_DNS
def _add_next_node(self):
node_id, node_ip = self.mock_node_generator.next()
node = Node(node_id=node_id.decode('hex'), udpPort=4444, peerPort=3333, externalIP=node_ip,
resolve=resolve, listenUDP=listenUDP, callLater=self.clock.callLater, clock=self.clock)
self.nodes.append(node)
return node
@defer.inlineCallbacks
def add_node(self):
node = self._add_next_node()
yield node.start([(seed_name, 4444) for seed_name in sorted(self.seed_dns.keys())])
defer.returnValue(node)
def get_node(self, node_id):
for node in self.nodes:
if node.node_id == node_id:
return node
raise KeyError(node_id)
@defer.inlineCallbacks
def pop_node(self):
node = self.nodes.pop()
yield node.stop()
def pump_clock(self, n, step=0.1, tick_callback=None):
"""
:param n: seconds to run the reactor for
:param step: reactor tick rate (in seconds)
"""
for _ in range(int(n * (1.0 / float(step)))):
self.clock.advance(step)
if tick_callback and callable(tick_callback):
tick_callback(self.clock.seconds())
def run_reactor(self, seconds, deferreds, tick_callback=None):
d = defer.DeferredList(deferreds)
self.pump_clock(seconds, tick_callback=tick_callback)
return d
def get_contacts(self):
contacts = {}
for seed in self._seeds:
contacts[seed] = seed.contacts
for node in self.nodes:
contacts[node] = node.contacts
return contacts
def get_routable_addresses(self):
known = set()
for n in self._seeds:
known.update([(c.id, c.address, c.port) for c in n.contacts])
for n in self.nodes:
known.update([(c.id, c.address, c.port) for c in n.contacts])
addresses = {triple[1] for triple in known}
return addresses
def get_online_addresses(self):
online = set()
for n in self._seeds:
online.add(n.externalIP)
for n in self.nodes:
online.add(n.externalIP)
return online
def show_info(self):
known = set()
for n in self._seeds:
known.update([(c.id, c.address, c.port) for c in n.contacts])
for n in self.nodes:
known.update([(c.id, c.address, c.port) for c in n.contacts])
log.info("Routable: %i/%i", len(known), len(self.nodes) + len(self._seeds))
for n in self._seeds:
log.info("seed %s has %i contacts in %i buckets", n.externalIP, len(n.contacts),
len([b for b in n._routingTable._buckets if b.getContacts()]))
for n in self.nodes:
log.info("node %s has %i contacts in %i buckets", n.externalIP, len(n.contacts),
len([b for b in n._routingTable._buckets if b.getContacts()]))
@defer.inlineCallbacks
def setUp(self):
self.nodes = []
self._seeds = []
self.clock = task.Clock()
self.mock_node_generator = mock_node_generator(mock_node_ids=self.node_ids)
seed_dl = []
seeds = sorted(list(self.seed_dns.keys()))
known_addresses = [(seed_name, 4444) for seed_name in seeds]
for seed_dns in seeds:
self._add_next_node()
seed = self.nodes.pop()
self._seeds.append(seed)
seed_dl.append(
seed.start(known_addresses)
)
yield self.run_reactor(constants.checkRefreshInterval+1, seed_dl)
while len(self.nodes + self._seeds) < self.network_size:
network_dl = []
for i in range(min(10, self.network_size - len(self._seeds) - len(self.nodes))):
network_dl.append(self.add_node())
yield self.run_reactor(constants.checkRefreshInterval*2+1, network_dl)
self.assertEqual(len(self.nodes + self._seeds), self.network_size)
self.pump_clock(3600)
self.verify_all_nodes_are_routable()
self.verify_all_nodes_are_pingable()
@defer.inlineCallbacks
def tearDown(self):
dl = []
while self.nodes:
dl.append(self.pop_node()) # stop all of the nodes
while self._seeds:
dl.append(self._seeds.pop().stop()) # and the seeds
yield defer.DeferredList(dl)
def verify_all_nodes_are_routable(self):
routable = set()
node_addresses = {node.externalIP for node in self.nodes}
node_addresses = node_addresses.union({node.externalIP for node in self._seeds})
for node in self._seeds:
contact_addresses = {contact.address for contact in node.contacts}
routable.update(contact_addresses)
for node in self.nodes:
contact_addresses = {contact.address for contact in node.contacts}
routable.update(contact_addresses)
self.assertSetEqual(routable, node_addresses)
@defer.inlineCallbacks
def verify_all_nodes_are_pingable(self):
ping_replies = {}
ping_dl = []
contacted = set()
def _ping_cb(result, node, replies):
replies[node] = result
for node in self._seeds:
contact_addresses = set()
for contact in node.contacts:
contact_addresses.add(contact.address)
d = contact.ping()
d.addCallback(_ping_cb, contact.address, ping_replies)
contacted.add(contact.address)
ping_dl.append(d)
for node in self.nodes:
contact_addresses = set()
for contact in node.contacts:
contact_addresses.add(contact.address)
d = contact.ping()
d.addCallback(_ping_cb, contact.address, ping_replies)
contacted.add(contact.address)
ping_dl.append(d)
yield self.run_reactor(2, ping_dl)
node_addresses = {node.externalIP for node in self.nodes}.union({seed.externalIP for seed in self._seeds})
self.assertSetEqual(node_addresses, contacted)
expected = {node: "pong" for node in contacted}
self.assertDictEqual(ping_replies, expected)

View file

@ -0,0 +1,151 @@
import struct
import hashlib
import logging
from twisted.internet import defer, error
from lbrynet.dht.encoding import Bencode
from lbrynet.dht.error import DecodeError
from lbrynet.dht.msgformat import DefaultFormat
from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage
_encode = Bencode()
_datagram_formatter = DefaultFormat()
log = logging.getLogger()
MOCK_DHT_NODES = [
"cc8db9d0dd9b65b103594b5f992adf09f18b310958fa451d61ce8d06f3ee97a91461777c2b7dea1a89d02d2f23eb0e4f",
"83a3a398eead3f162fbbe1afb3d63482bb5b6d3cdd8f9b0825c1dfa58dffd3f6f6026d6e64d6d4ae4c3dfe2262e734ba",
"b6928ff25778a7bbb5d258d3b3a06e26db1654f3d2efce8c26681d43f7237cdf2e359a4d309c4473d5d89ec99fb4f573",
]
MOCK_DHT_SEED_DNS = {  # these map to mock nodes 0 through 15
"lbrynet1.lbry.io": "10.42.42.1",
"lbrynet2.lbry.io": "10.42.42.2",
"lbrynet3.lbry.io": "10.42.42.3",
"lbrynet4.lbry.io": "10.42.42.4",
"lbrynet5.lbry.io": "10.42.42.5",
"lbrynet6.lbry.io": "10.42.42.6",
"lbrynet7.lbry.io": "10.42.42.7",
"lbrynet8.lbry.io": "10.42.42.8",
"lbrynet9.lbry.io": "10.42.42.9",
"lbrynet10.lbry.io": "10.42.42.10",
"lbrynet11.lbry.io": "10.42.42.11",
"lbrynet12.lbry.io": "10.42.42.12",
"lbrynet13.lbry.io": "10.42.42.13",
"lbrynet14.lbry.io": "10.42.42.14",
"lbrynet15.lbry.io": "10.42.42.15",
"lbrynet16.lbry.io": "10.42.42.16",
}
def resolve(name, timeout=(1, 3, 11, 45)):
if name not in MOCK_DHT_SEED_DNS:
return defer.fail(error.DNSLookupError(name))
return defer.succeed(MOCK_DHT_SEED_DNS[name])
class MockUDPTransport(object):
def __init__(self, address, port, max_packet_size, protocol):
self.address = address
self.port = port
self.max_packet_size = max_packet_size
self._node = protocol._node
def write(self, data, address):
if address in MockNetwork.peers:
dest = MockNetwork.peers[address][0]
debug_kademlia_packet(data, (self.address, self.port), address, self._node)
dest.datagramReceived(data, (self.address, self.port))
else:  # the node is sending to an address that doesn't currently exist; act like it never arrived
pass
class MockUDPPort(object):
def __init__(self, protocol, remover):
self.protocol = protocol
self._remover = remover
def startListening(self, reason=None):
return self.protocol.startProtocol()
def stopListening(self, reason=None):
result = self.protocol.stopProtocol()
self._remover()
return result
class MockNetwork(object):
peers = {} # (interface, port): (protocol, max_packet_size)
@classmethod
def add_peer(cls, port, protocol, interface, maxPacketSize):
interface = protocol._node.externalIP
protocol.transport = MockUDPTransport(interface, port, maxPacketSize, protocol)
cls.peers[(interface, port)] = (protocol, maxPacketSize)
def remove_peer():
del protocol.transport
if (interface, port) in cls.peers:
del cls.peers[(interface, port)]
return remove_peer
def listenUDP(port, protocol, interface='', maxPacketSize=8192):
remover = MockNetwork.add_peer(port, protocol, interface, maxPacketSize)
port = MockUDPPort(protocol, remover)
port.startListening()
return port
def address_generator(address=(10, 42, 42, 1)):
def increment(addr):
value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]))[0] + 1
new_addr = []
for i in range(4):
new_addr.append(value % 256)
value >>= 8
return tuple(new_addr[::-1])
while True:
yield "{}.{}.{}.{}".format(*address)
address = increment(address)
def mock_node_generator(count=None, mock_node_ids=MOCK_DHT_NODES):
if mock_node_ids is None:
mock_node_ids = MOCK_DHT_NODES
mock_node_ids = list(mock_node_ids)
for num, node_ip in enumerate(address_generator()):
if count and num >= count:
break
if num >= len(mock_node_ids):
h = hashlib.sha384()
h.update("node %i" % num)
node_id = h.hexdigest()
else:
node_id = mock_node_ids[num]
yield (node_id, node_ip)
def debug_kademlia_packet(data, source, destination, node):
if log.level != logging.DEBUG:
return
try:
packet = _datagram_formatter.fromPrimitive(_encode.decode(data))
if isinstance(packet, RequestMessage):
log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request,
node.clock.seconds())
elif isinstance(packet, ResponseMessage):
if isinstance(packet.response, (str, unicode)):
log.debug("response %s <-- %s %s (node time %s)", destination[0], source[0], packet.response,
node.clock.seconds())
else:
log.debug("response %s <-- %s %i contacts (node time %s)", destination[0], source[0],
len(packet.response), node.clock.seconds())
elif isinstance(packet, ErrorMessage):
log.error("error %s <-- %s %s (node time %s)", destination[0], source[0], packet.exceptionType,
node.clock.seconds())
except DecodeError:
log.exception("decode error %s --> %s (node time %s)", source[0], destination[0], node.clock.seconds())
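
The `increment` helper inside `address_generator` above builds its integer from a Python 2 `str` and unpacks it with the native-endian `"I"` format, so it assumes a little-endian host. The following is a minimal sketch of the same idea with an explicit network byte order. `next_ipv4` and `ipv4_addresses` are hypothetical names that do not appear in the commit, and this version yields strings throughout rather than tuples.

```python
import socket
import struct


def next_ipv4(address):
    """Return the dotted-quad address that follows `address`."""
    # "!I" is an unsigned 32-bit integer in network (big-endian) byte order
    value, = struct.unpack("!I", socket.inet_aton(address))
    # wrap around at 2**32 instead of overflowing the 32-bit field
    return socket.inet_ntoa(struct.pack("!I", (value + 1) & 0xFFFFFFFF))


def ipv4_addresses(start="10.42.42.1"):
    """Yield consecutive IPv4 addresses, beginning with `start`."""
    address = start
    while True:
        yield address
        address = next_ipv4(address)
```

Because the byte order is stated explicitly, the carry from one octet into the next behaves the same on any host, for example when stepping past `10.42.42.255`.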

View file

@ -0,0 +1,34 @@
from twisted.trial import unittest
from dht_test_environment import TestKademliaBase
class TestKademliaBootstrap(TestKademliaBase):
"""
Test initializing the network / connecting the seed nodes
"""
def test_bootstrap_seed_nodes(self):
pass
@unittest.SkipTest
class TestKademliaBootstrap40Nodes(TestKademliaBase):
network_size = 40
def test_bootstrap_network(self):
pass
class TestKademliaBootstrap80Nodes(TestKademliaBase):
network_size = 80
def test_bootstrap_network(self):
pass
@unittest.SkipTest
class TestKademliaBootstrap120Nodes(TestKademliaBase):
network_size = 120
def test_bootstrap_network(self):
pass

View file

@ -0,0 +1,40 @@
import logging
from twisted.internet import defer
from lbrynet.dht import constants
from dht_test_environment import TestKademliaBase
log = logging.getLogger()
class TestPeerExpiration(TestKademliaBase):
network_size = 40
@defer.inlineCallbacks
def test_expire_stale_peers(self):
removed_addresses = set()
removed_nodes = []
# stop 5 nodes
for _ in range(5):
n = self.nodes[0]
removed_nodes.append(n)
removed_addresses.add(n.externalIP)
self.nodes.remove(n)
yield self.run_reactor(1, [n.stop()])
offline_addresses = self.get_routable_addresses().difference(self.get_online_addresses())
self.assertSetEqual(offline_addresses, removed_addresses)
get_nodes_with_stale_contacts = lambda: filter(lambda node: any(contact.address in offline_addresses
for contact in node.contacts),
self.nodes + self._seeds)
self.assertRaises(AssertionError, self.verify_all_nodes_are_routable)
self.assertTrue(len(get_nodes_with_stale_contacts()) > 1)
# run the network long enough for two failures to happen
self.pump_clock(constants.checkRefreshInterval * 3)
self.assertEquals(len(get_nodes_with_stale_contacts()), 0)
self.verify_all_nodes_are_routable()
self.verify_all_nodes_are_pingable()

View file

@ -0,0 +1,38 @@
import logging
from twisted.internet import defer
from lbrynet.dht import constants
from dht_test_environment import TestKademliaBase
log = logging.getLogger()
class TestReJoin(TestKademliaBase):
network_size = 40
@defer.inlineCallbacks
def setUp(self):
yield super(TestReJoin, self).setUp()
self.removed_node = self.nodes[20]
self.nodes.remove(self.removed_node)
yield self.run_reactor(1, [self.removed_node.stop()])
self.pump_clock(constants.checkRefreshInterval * 2)
self.verify_all_nodes_are_routable()
self.verify_all_nodes_are_pingable()
@defer.inlineCallbacks
def test_re_join(self):
self.nodes.append(self.removed_node)
yield self.run_reactor(
31, [self.removed_node.start([(seed_name, 4444) for seed_name in sorted(self.seed_dns.keys())])]
)
self.pump_clock(constants.checkRefreshInterval*2)
self.verify_all_nodes_are_routable()
self.verify_all_nodes_are_pingable()
def test_re_join_with_new_ip(self):
self.removed_node.externalIP = "10.43.43.43"
return self.test_re_join()
def test_re_join_with_new_node_id(self):
self.removed_node.node_id = self.removed_node._generateID()
return self.test_re_join()

View file

@ -0,0 +1,269 @@
import time
from twisted.trial import unittest
import logging
from twisted.internet.task import Clock
from twisted.internet import defer
import lbrynet.dht.protocol
import lbrynet.dht.contact
from lbrynet.dht.error import TimeoutError
from lbrynet.dht.node import Node, rpcmethod
from mock_transport import listenUDP, resolve
log = logging.getLogger()
class KademliaProtocolTest(unittest.TestCase):
""" Test case for the Protocol class """
udpPort = 9182
def setUp(self):
self._reactor = Clock()
self.node = Node(node_id='1' * 48, udpPort=self.udpPort, externalIP="127.0.0.1", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
self.remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
self.remote_contact = self.node.contact_manager.make_contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.us_from_them = self.remote_node.contact_manager.make_contact('1' * 48, '127.0.0.1', 9182,
self.remote_node._protocol)
self.node.start_listening()
self.remote_node.start_listening()
@defer.inlineCallbacks
def tearDown(self):
yield self.node.stop()
yield self.remote_node.stop()
del self._reactor
@defer.inlineCallbacks
def testReactor(self):
""" Tests if the reactor can start/stop the protocol correctly """
d = defer.Deferred()
self._reactor.callLater(1, d.callback, True)
self._reactor.advance(1)
result = yield d
self.assertTrue(result)
@defer.inlineCallbacks
def testRPCTimeout(self):
""" Tests if a RPC message sent to a dead remote node times out correctly """
yield self.remote_node.stop()
self._reactor.pump([1 for _ in range(10)])
self.node.addContact(self.remote_contact)
@rpcmethod
def fake_ping(*args, **kwargs):
time.sleep(lbrynet.dht.constants.rpcTimeout + 1)
return 'pong'
real_ping = self.node.ping
real_timeout = lbrynet.dht.constants.rpcTimeout
real_attempts = lbrynet.dht.constants.rpcAttempts
lbrynet.dht.constants.rpcAttempts = 1
lbrynet.dht.constants.rpcTimeout = 1
self.node.ping = fake_ping
# Make sure the contact was added
self.failIf(self.remote_contact not in self.node.contacts,
'Contact not added to fake node (error in test code)')
self.node.start_listening()
# Run the PING RPC (which should raise a timeout error)
df = self.remote_contact.ping()
def check_timeout(err):
self.assertEqual(err.type, TimeoutError)
df.addErrback(check_timeout)
def reset_values():
self.node.ping = real_ping
lbrynet.dht.constants.rpcTimeout = real_timeout
lbrynet.dht.constants.rpcAttempts = real_attempts
# See if the contact was removed due to the timeout
def check_removed_contact():
self.failIf(self.remote_contact in self.node.contacts,
'Contact was not removed after RPC timeout; check exception types.')
df.addCallback(lambda _: reset_values())
# Stop the reactor if a result arrives (timeout or not)
df.addCallback(lambda _: check_removed_contact())
self._reactor.pump([1 for _ in range(20)])
@defer.inlineCallbacks
def testRPCRequest(self):
""" Tests if a valid RPC request is executed and responded to correctly """
yield self.node.addContact(self.remote_contact)
self.error = None
def handleError(f):
self.error = 'An RPC error occurred: %s' % f.getErrorMessage()
def handleResult(result):
expectedResult = 'pong'
if result != expectedResult:
self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' \
% (expectedResult, result)
# Simulate the RPC
df = self.remote_contact.ping()
df.addCallback(handleResult)
df.addErrback(handleError)
self._reactor.advance(2)
yield df
self.failIf(self.error, self.error)
# The list of sent RPC messages should be empty at this stage
self.failUnlessEqual(len(self.node._protocol._sentMessages), 0,
'The protocol is still waiting for a RPC result, '
'but the transaction is already done!')
def testRPCAccess(self):
""" Tests invalid RPC requests
Verifies that a RPC request for an existing but unpublished
method is denied, and that the associated (remote) exception gets
raised locally """
self.assertRaises(AttributeError, getattr, self.remote_contact, "not_a_rpc_function")
def testRPCRequestArgs(self):
""" Tests if an RPC requiring arguments is executed correctly """
self.node.addContact(self.remote_contact)
self.error = None
def handleError(f):
self.error = 'An RPC error occurred: %s' % f.getErrorMessage()
def handleResult(result):
expectedResult = 'pong'
if result != expectedResult:
self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' % \
(expectedResult, result)
# Publish the "local" node on the network
self.node.start_listening()
# Simulate the RPC
df = self.remote_contact.ping()
df.addCallback(handleResult)
df.addErrback(handleError)
self._reactor.pump([1 for _ in range(10)])
self.failIf(self.error, self.error)
# The list of sent RPC messages should be empty at this stage
self.failUnlessEqual(len(self.node._protocol._sentMessages), 0,
'The protocol is still waiting for a RPC result, '
'but the transaction is already done!')
@defer.inlineCallbacks
def testDetectProtocolVersion(self):
original_findvalue = self.remote_node.findValue
fake_blob = str("AB" * 48).decode('hex')
@rpcmethod
def findValue(contact, key):
result = original_findvalue(contact, key)
result.pop('protocolVersion')
return result
self.remote_node.findValue = findValue
d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
self.remote_node.findValue = original_findvalue
d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 1)
self.assertTrue('protocolVersion' not in find_value_response)
self.remote_node.findValue = findValue
d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
@defer.inlineCallbacks
def testStoreToPre_0_20_0_Node(self):
def _dont_migrate(contact, method, *args):
return args, {}
self.remote_node._protocol._migrate_incoming_rpc_args = _dont_migrate
original_findvalue = self.remote_node.findValue
original_store = self.remote_node.store
@rpcmethod
def findValue(contact, key):
result = original_findvalue(contact, key)
if 'protocolVersion' in result:
result.pop('protocolVersion')
return result
@rpcmethod
def store(contact, key, value, originalPublisherID=None, self_store=False, **kwargs):
self.assertTrue(len(key) == 48)
self.assertSetEqual(set(value.keys()), {'token', 'lbryid', 'port'})
self.assertFalse(self_store)
self.assertDictEqual(kwargs, {})
return original_store( # pylint: disable=too-many-function-args
contact, key, value['token'], value['port'], originalPublisherID, 0
)
self.remote_node.findValue = findValue
self.remote_node.store = store
fake_blob = str("AB" * 48).decode('hex')
d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
token = find_value_response['token']
d = self.remote_contact.store(fake_blob, token, 3333, self.node.node_id, 0)
self._reactor.advance(3)
response = yield d
self.assertEquals(response, "OK")
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue(self.remote_node._dataStore.hasPeersForBlob(fake_blob))
self.assertEquals(len(self.remote_node._dataStore.getStoringContacts()), 1)
@defer.inlineCallbacks
def testStoreFromPre_0_20_0_Node(self):
def _dont_migrate(contact, method, *args):
return args
self.remote_node._protocol._migrate_outgoing_rpc_args = _dont_migrate
us_from_them = self.remote_node.contact_manager.make_contact('1' * 48, '127.0.0.1', self.udpPort,
self.remote_node._protocol)
fake_blob = str("AB" * 48).decode('hex')
d = us_from_them.findValue(fake_blob)
self._reactor.advance(3)
find_value_response = yield d
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response)
token = find_value_response['token']
us_from_them.update_protocol_version(0)
d = self.remote_node._protocol.sendRPC(
us_from_them, "store", (fake_blob, {'lbryid': self.remote_node.node_id, 'token': token, 'port': 3333})
)
self._reactor.advance(3)
response = yield d
self.assertEquals(response, "OK")
self.assertEquals(self.remote_contact.protocolVersion, 0)
self.assertTrue(self.node._dataStore.hasPeersForBlob(fake_blob))
self.assertEquals(len(self.node._dataStore.getStoringContacts()), 1)
self.assertIs(self.node._dataStore.getStoringContacts()[0], self.remote_contact)

View file

@ -0,0 +1,27 @@
from lbrynet.dht import constants
from lbrynet.dht.distance import Distance
from dht_test_environment import TestKademliaBase
import logging
log = logging.getLogger()
class TestFindNode(TestKademliaBase):
"""
Tests the local routing table lookup for a node. Every node should return the k contacts closest to the
key being looked up, sorted by distance (even if the key is the ID of a known contact)
"""
network_size = 35
def test_find_node(self):
last_node_id = self.nodes[-1].node_id.encode('hex')
to_last_node = Distance(last_node_id.decode('hex'))
for n in self.nodes:
find_close_nodes_result = n._routingTable.findCloseNodes(last_node_id.decode('hex'), constants.k)
self.assertTrue(len(find_close_nodes_result) == constants.k)
found_ids = [c.id.encode('hex') for c in find_close_nodes_result]
self.assertListEqual(found_ids, sorted(found_ids, key=lambda x: to_last_node(x.decode('hex'))))
if last_node_id in [c.id.encode('hex') for c in n.contacts]:
self.assertTrue(found_ids[0] == last_node_id)
else:
self.assertTrue(last_node_id not in found_ids)
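
The ordering that `test_find_node` asserts can be illustrated without a running network. This is a minimal sketch assuming hex-encoded IDs of equal length. `xor_distance` and `closest` are hypothetical helpers that stand in for `lbrynet.dht.distance.Distance` and `findCloseNodes`, and they are not the project's implementation.

```python
def xor_distance(id_a, id_b):
    """Kademlia distance between two hex-encoded IDs: their bitwise XOR."""
    return int(id_a, 16) ^ int(id_b, 16)


def closest(ids, target, k):
    """Return up to k ids, ordered from nearest to farthest from target."""
    return sorted(ids, key=lambda node_id: xor_distance(node_id, target))[:k]


known = ["0f", "f0", "01", "10"]
nearest_to_zero = closest(known, "00", 3)
nearest_to_known = closest(known, "f0", 2)
```

An ID has distance zero to itself, so when the target is one of the known IDs it always sorts first. That is the property the test checks with `found_ids[0] == last_node_id`.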

View file

@ -0,0 +1,145 @@
import struct
from twisted.internet import defer
from lbrynet.dht import constants
from lbrynet.core.utils import generate_id
from dht_test_environment import TestKademliaBase
import logging
log = logging.getLogger()
class TestStoreExpiration(TestKademliaBase):
network_size = 40
@defer.inlineCallbacks
def test_store_and_expire(self):
blob_hash = generate_id(1)
announcing_node = self.nodes[20]
# announce the blob
announce_d = announcing_node.announceHaveBlob(blob_hash)
self.pump_clock(5)
storing_node_ids = yield announce_d
all_nodes = set(self.nodes).union(set(self._seeds))
# verify the nodes we think stored it did actually store it
storing_nodes = [node for node in all_nodes if node.node_id.encode('hex') in storing_node_ids]
self.assertEquals(len(storing_nodes), len(storing_node_ids))
self.assertEquals(len(storing_nodes), constants.k)
for node in storing_nodes:
self.assertTrue(node._dataStore.hasPeersForBlob(blob_hash))
datastore_result = node._dataStore.getPeersForBlob(blob_hash)
self.assertEquals(map(lambda contact: (contact.id, contact.address, contact.port),
node._dataStore.getStoringContacts()), [(announcing_node.node_id,
announcing_node.externalIP,
announcing_node.port)])
self.assertEquals(len(datastore_result), 1)
expanded_peers = []
for peer in datastore_result:
host = ".".join([str(ord(d)) for d in peer[:4]])
port, = struct.unpack('>H', peer[4:6])
peer_node_id = peer[6:]
if (host, port, peer_node_id) not in expanded_peers:
expanded_peers.append((peer_node_id, host, port))
self.assertEquals(expanded_peers[0],
(announcing_node.node_id, announcing_node.externalIP, announcing_node.peerPort))
# verify the announced blob expires in the storing nodes datastores
self.clock.advance(constants.dataExpireTimeout) # skip the clock directly ahead
for node in storing_nodes:
self.assertFalse(node._dataStore.hasPeersForBlob(blob_hash))
datastore_result = node._dataStore.getPeersForBlob(blob_hash)
self.assertEquals(len(datastore_result), 0)
self.assertTrue(blob_hash in node._dataStore._dict) # the looping call shouldn't have removed it yet
self.assertEquals(len(node._dataStore.getStoringContacts()), 1)
self.pump_clock(constants.checkRefreshInterval + 1) # tick the clock forward (so the nodes refresh)
for node in storing_nodes:
self.assertFalse(node._dataStore.hasPeersForBlob(blob_hash))
datastore_result = node._dataStore.getPeersForBlob(blob_hash)
self.assertEquals(len(datastore_result), 0)
self.assertEquals(len(node._dataStore.getStoringContacts()), 0)
self.assertTrue(blob_hash not in node._dataStore._dict) # the looping call should have fired
@defer.inlineCallbacks
def test_storing_node_went_stale_then_came_back(self):
blob_hash = generate_id(1)
announcing_node = self.nodes[20]
# announce the blob
announce_d = announcing_node.announceHaveBlob(blob_hash)
self.pump_clock(5)
storing_node_ids = yield announce_d
all_nodes = set(self.nodes).union(set(self._seeds))
# verify the nodes we think stored it did actually store it
storing_nodes = [node for node in all_nodes if node.node_id.encode('hex') in storing_node_ids]
self.assertEquals(len(storing_nodes), len(storing_node_ids))
self.assertEquals(len(storing_nodes), constants.k)
for node in storing_nodes:
self.assertTrue(node._dataStore.hasPeersForBlob(blob_hash))
datastore_result = node._dataStore.getPeersForBlob(blob_hash)
self.assertEquals(map(lambda contact: (contact.id, contact.address, contact.port),
node._dataStore.getStoringContacts()), [(announcing_node.node_id,
announcing_node.externalIP,
announcing_node.port)])
self.assertEquals(len(datastore_result), 1)
expanded_peers = []
for peer in datastore_result:
host = ".".join([str(ord(d)) for d in peer[:4]])
port, = struct.unpack('>H', peer[4:6])
peer_node_id = peer[6:]
if (peer_node_id, host, port) not in expanded_peers:
expanded_peers.append((peer_node_id, host, port))
self.assertEquals(expanded_peers[0],
(announcing_node.node_id, announcing_node.externalIP, announcing_node.peerPort))
self.pump_clock(constants.checkRefreshInterval*2)
# stop the node
self.nodes.remove(announcing_node)
yield self.run_reactor(31, [announcing_node.stop()])
# run the network for an hour, which should expire the removed node and turn the announced value stale
self.pump_clock(constants.checkRefreshInterval * 4, constants.checkRefreshInterval/2)
self.verify_all_nodes_are_routable()
# make sure the contact isn't returned as a peer for the blob, but that we still have the entry in the
# datastore in case the node comes back
for node in storing_nodes:
self.assertFalse(node._dataStore.hasPeersForBlob(blob_hash))
datastore_result = node._dataStore.getPeersForBlob(blob_hash)
self.assertEquals(len(datastore_result), 0)
self.assertEquals(len(node._dataStore.getStoringContacts()), 1)
self.assertTrue(blob_hash in node._dataStore._dict)
# bring the announcing node back online
self.nodes.append(announcing_node)
yield self.run_reactor(
31, [announcing_node.start([(seed_name, 4444) for seed_name in sorted(self.seed_dns.keys())])]
)
self.pump_clock(constants.checkRefreshInterval * 2)
self.verify_all_nodes_are_routable()
# now the announcing node should once again be returned as a peer for the blob
for node in storing_nodes:
self.assertTrue(node._dataStore.hasPeersForBlob(blob_hash))
datastore_result = node._dataStore.getPeersForBlob(blob_hash)
self.assertEquals(len(datastore_result), 1)
self.assertEquals(len(node._dataStore.getStoringContacts()), 1)
self.assertTrue(blob_hash in node._dataStore._dict)
# verify the announced blob expires in the storing nodes datastores
self.clock.advance(constants.dataExpireTimeout) # skip the clock directly ahead
for node in storing_nodes:
self.assertFalse(node._dataStore.hasPeersForBlob(blob_hash))
datastore_result = node._dataStore.getPeersForBlob(blob_hash)
self.assertEquals(len(datastore_result), 0)
self.assertTrue(blob_hash in node._dataStore._dict) # the looping call shouldn't have removed it yet
self.assertEquals(len(node._dataStore.getStoringContacts()), 1)
self.pump_clock(constants.checkRefreshInterval + 1) # tick the clock forward (so the nodes refresh)
for node in storing_nodes:
self.assertFalse(node._dataStore.hasPeersForBlob(blob_hash))
datastore_result = node._dataStore.getPeersForBlob(blob_hash)
self.assertEquals(len(datastore_result), 0)
self.assertEquals(len(node._dataStore.getStoringContacts()), 0)
self.assertTrue(blob_hash not in node._dataStore._dict) # the looping call should have fired
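
Both tests above unpack each datastore entry as 4 bytes of IPv4 address, a 2 byte big-endian port, and the node id in the remaining bytes. A minimal sketch of that layout follows. The helper names `encode_compact_peer` and `decode_compact_peer` are hypothetical and are not part of lbrynet, and the encoder is only an assumption about how such an entry would be built.

```python
import socket
import struct


def encode_compact_peer(node_id, host, port):
    """Pack a peer as 4 address bytes, a 2 byte big-endian port, then the node id."""
    return socket.inet_aton(host) + struct.pack('>H', port) + node_id


def decode_compact_peer(peer):
    """Split a compact peer entry back into (node_id, host, port)."""
    host = ".".join(str(octet) for octet in bytearray(peer[:4]))
    port, = struct.unpack('>H', peer[4:6])
    return peer[6:], host, port


# placeholder 48 byte id, the same length as the 96 hex character mock node ids
node_id = b"\xab" * 48
peer = encode_compact_peer(node_id, "10.42.42.1", 3333)
assert decode_compact_peer(peer) == (node_id, "10.42.42.1", 3333)
```

Returning the fields in `(node_id, host, port)` order keeps the decoded tuple directly comparable with the tuples the tests build from the announcing node.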

View file

@ -1,274 +0,0 @@
import time
import logging
from twisted.trial import unittest
from twisted.internet import defer, threads, task
from lbrynet.dht.node import Node
from lbrynet.tests import mocks
from lbrynet.core.utils import generate_id
log = logging.getLogger("lbrynet.tests.util")
# log.addHandler(logging.StreamHandler())
# log.setLevel(logging.DEBUG)
class TestKademliaBase(unittest.TestCase):
timeout = 300.0 # timeout for each test
network_size = 0 # plus lbrynet1, lbrynet2, and lbrynet3 seed nodes
node_ids = None
seed_dns = mocks.MOCK_DHT_SEED_DNS
def _add_next_node(self):
node_id, node_ip = self.mock_node_generator.next()
node = Node(node_id=node_id.decode('hex'), udpPort=4444, peerPort=3333, externalIP=node_ip,
resolve=mocks.resolve, listenUDP=mocks.listenUDP, callLater=self.clock.callLater, clock=self.clock)
self.nodes.append(node)
return node
@defer.inlineCallbacks
def add_node(self):
node = self._add_next_node()
yield node.joinNetwork(
[
("lbrynet1.lbry.io", self._seeds[0].port),
("lbrynet2.lbry.io", self._seeds[1].port),
("lbrynet3.lbry.io", self._seeds[2].port),
]
)
defer.returnValue(node)
def get_node(self, node_id):
for node in self.nodes:
if node.node_id == node_id:
return node
raise KeyError(node_id)
@defer.inlineCallbacks
def pop_node(self):
node = self.nodes.pop()
yield node.stop()
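def pump_clock(self, n, step=0.01):
"""
Advance the mock clock by `step` seconds, n * 100 times.

:param n: seconds to run the reactor for (exact only with the default step of 0.01)
:param step: reactor tick rate (in seconds)
"""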
for _ in range(n * 100):
self.clock.advance(step)
def run_reactor(self, seconds, *deferreds):
dl = [threads.deferToThread(self.pump_clock, seconds)]
for d in deferreds:
dl.append(d)
return defer.DeferredList(dl)
@defer.inlineCallbacks
def setUp(self):
self.nodes = []
self._seeds = []
self.clock = task.Clock()
self.mock_node_generator = mocks.mock_node_generator(mock_node_ids=self.node_ids)
join_dl = []
for seed_dns in self.seed_dns:
other_seeds = list(self.seed_dns.keys())
other_seeds.remove(seed_dns)
self._add_next_node()
seed = self.nodes.pop()
self._seeds.append(seed)
join_dl.append(
seed.joinNetwork([(other_seed_dns, 4444) for other_seed_dns in other_seeds])
)
if self.network_size:
for _ in range(self.network_size):
join_dl.append(self.add_node())
yield self.run_reactor(1, *tuple(join_dl))
self.verify_all_nodes_are_routable()
@defer.inlineCallbacks
def tearDown(self):
dl = []
while self.nodes:
dl.append(self.pop_node()) # stop all of the nodes
while self._seeds:
dl.append(self._seeds.pop().stop()) # and the seeds
yield defer.DeferredList(dl)
def verify_all_nodes_are_routable(self):
routable = set()
node_addresses = {node.externalIP for node in self.nodes}
node_addresses = node_addresses.union({node.externalIP for node in self._seeds})
for node in self._seeds:
contact_addresses = {contact.address for contact in node.contacts}
routable.update(contact_addresses)
for node in self.nodes:
contact_addresses = {contact.address for contact in node.contacts}
routable.update(contact_addresses)
self.assertSetEqual(routable, node_addresses)
@defer.inlineCallbacks
def verify_all_nodes_are_pingable(self):
ping_replies = {}
ping_dl = []
contacted = set()
def _ping_cb(result, node, replies):
replies[node] = result
for node in self._seeds:
contact_addresses = set()
for contact in node.contacts:
contact_addresses.add(contact.address)
d = contact.ping()
d.addCallback(_ping_cb, contact.address, ping_replies)
contacted.add(contact.address)
ping_dl.append(d)
for node in self.nodes:
contact_addresses = set()
for contact in node.contacts:
contact_addresses.add(contact.address)
d = contact.ping()
d.addCallback(_ping_cb, contact.address, ping_replies)
contacted.add(contact.address)
ping_dl.append(d)
self.run_reactor(2, *ping_dl)
yield threads.deferToThread(time.sleep, 0.1)
node_addresses = {node.externalIP for node in self.nodes}.union({seed.externalIP for seed in self._seeds})
self.assertSetEqual(node_addresses, contacted)
self.assertDictEqual(ping_replies, {node: "pong" for node in contacted})
class TestKademliaBootstrap(TestKademliaBase):
"""
Test initializing the network / connecting the seed nodes
"""
def test_bootstrap_network(self): # simulates the real network, which has three seeds
self.assertEqual(len(self._seeds[0].contacts), 2)
self.assertEqual(len(self._seeds[1].contacts), 2)
self.assertEqual(len(self._seeds[2].contacts), 2)
self.assertSetEqual(
{self._seeds[0].contacts[0].address, self._seeds[0].contacts[1].address},
{self._seeds[1].externalIP, self._seeds[2].externalIP}
)
self.assertSetEqual(
{self._seeds[1].contacts[0].address, self._seeds[1].contacts[1].address},
{self._seeds[0].externalIP, self._seeds[2].externalIP}
)
self.assertSetEqual(
{self._seeds[2].contacts[0].address, self._seeds[2].contacts[1].address},
{self._seeds[0].externalIP, self._seeds[1].externalIP}
)
def test_all_nodes_are_pingable(self):
return self.verify_all_nodes_are_pingable()
class TestKademliaBootstrapSixteenSeeds(TestKademliaBase):
node_ids = [
'000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000',
'111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111',
'222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222',
'333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333',
'444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444',
'555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555',
'666666666666666666666666666666666666666666666666666666666666666666666666666666666666666666666666',
'777777777777777777777777777777777777777777777777777777777777777777777777777777777777777777777777',
'888888888888888888888888888888888888888888888888888888888888888888888888888888888888888888888888',
'999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999',
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
'bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
'cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc',
'dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd',
'eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff'
]
@defer.inlineCallbacks
def setUp(self):
self.seed_dns.update(
{
"lbrynet4.lbry.io": "10.42.42.4",
"lbrynet5.lbry.io": "10.42.42.5",
"lbrynet6.lbry.io": "10.42.42.6",
"lbrynet7.lbry.io": "10.42.42.7",
"lbrynet8.lbry.io": "10.42.42.8",
"lbrynet9.lbry.io": "10.42.42.9",
"lbrynet10.lbry.io": "10.42.42.10",
"lbrynet11.lbry.io": "10.42.42.11",
"lbrynet12.lbry.io": "10.42.42.12",
"lbrynet13.lbry.io": "10.42.42.13",
"lbrynet14.lbry.io": "10.42.42.14",
"lbrynet15.lbry.io": "10.42.42.15",
"lbrynet16.lbry.io": "10.42.42.16",
}
)
yield TestKademliaBase.setUp(self)
@defer.inlineCallbacks
def tearDown(self):
yield TestKademliaBase.tearDown(self)
def test_bootstrap_network(self):
pass
def _test_all_nodes_are_pingable(self):
return self.verify_all_nodes_are_pingable()
class Test250NodeNetwork(TestKademliaBase):
network_size = 250
def test_setup_network_and_verify_connectivity(self):
pass
def update_network(self):
import random
dl = []
announced_blobs = []
for node in self.nodes: # random events
if random.randint(0, 10000) < 75 and announced_blobs: # get peers for a blob
log.info('find blob')
blob_hash = random.choice(announced_blobs)
dl.append(node.getPeersForBlob(blob_hash))
if random.randint(0, 10000) < 25: # announce a blob
log.info('announce blob')
blob_hash = generate_id()
announced_blobs.append((blob_hash, node.node_id))
dl.append(node.announceHaveBlob(blob_hash))
random.shuffle(self.nodes)
# kill nodes
while random.randint(0, 100) > 95:
dl.append(self.pop_node())
log.info('pop node')
# add nodes
while random.randint(0, 100) > 95:
dl.append(self.add_node())
log.info('add node')
return tuple(dl), announced_blobs
@defer.inlineCallbacks
def _test_simulate_network(self):
total_blobs = []
for i in range(100):
d, blobs = self.update_network()
total_blobs.extend(blobs)
self.run_reactor(1, *d)
yield threads.deferToThread(time.sleep, 0.1)
routable = set()
node_addresses = {node.externalIP for node in self.nodes}
for node in self.nodes:
contact_addresses = {contact.address for contact in node.contacts}
routable.update(contact_addresses)
log.warning("difference: %i", len(node_addresses.difference(routable)))
log.info("blobs %i", len(total_blobs))
log.info("step %i, %i nodes", i, len(self.nodes))
self.pump_clock(100)
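
The `pump_clock` helper in this file drives a deterministic clock forward in small ticks so that scheduled calls fire in order. The sketch below illustrates the idea without Twisted. `FakeClock` is a hypothetical stand-in written for this example, not `twisted.internet.task.Clock`, and unlike the original loop (which always performs `n * 100` ticks) this version derives the tick count from `step`.

```python
class FakeClock(object):
    """Tiny deterministic clock: time only moves when advance() is called."""

    def __init__(self):
        self.now = 0.0
        self.calls = []  # pending (due_time, function) pairs

    def call_later(self, delay, function):
        self.calls.append((self.now + delay, function))

    def advance(self, amount):
        self.now += amount
        due = [call for call in self.calls if call[0] <= self.now]
        self.calls = [call for call in self.calls if call[0] > self.now]
        for _, function in sorted(due, key=lambda call: call[0]):
            function()


def pump_clock(clock, seconds, step=0.01):
    """Advance the clock by roughly `seconds`, one `step` at a time."""
    for _ in range(int(round(seconds / step))):
        clock.advance(step)


clock = FakeClock()
fired = []
clock.call_later(1.5, lambda: fired.append("refresh"))
pump_clock(clock, 1)
assert fired == []           # one second elapsed, the call is not due yet
pump_clock(clock, 1)
assert fired == ["refresh"]  # the due time passed during the second pump
```

Ticking in small steps, rather than jumping the clock once, gives calls that reschedule themselves a chance to run repeatedly within the pumped interval.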

View file

@ -1,21 +1,18 @@
import base64 import base64
import struct
import io import io
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from twisted.internet import defer, error from twisted.internet import defer
from twisted.python.failure import Failure from twisted.python.failure import Failure
from lbrynet.core.client.ClientRequest import ClientRequest from lbrynet.core.client.ClientRequest import ClientRequest
from lbrynet.core.Error import RequestCanceledError from lbrynet.core.Error import RequestCanceledError
from lbrynet.core import BlobAvailability from lbrynet.core import BlobAvailability
from lbrynet.core.utils import generate_id
from lbrynet.dht.node import Node as RealNode from lbrynet.dht.node import Node as RealNode
from lbrynet.daemon import ExchangeRateManager as ERM from lbrynet.daemon import ExchangeRateManager as ERM
from lbrynet import conf from lbrynet import conf
from util import debug_kademlia_packet
KB = 2**10 KB = 2**10
PUBLIC_EXPONENT = 65537 # http://www.daemonology.net/blog/2009-06-11-cryptographic-right-answers.html PUBLIC_EXPONENT = 65537 # http://www.daemonology.net/blog/2009-06-11-cryptographic-right-answers.html
@ -41,6 +38,9 @@ class Node(RealNode):
def stop(self): def stop(self):
return defer.succeed(None) return defer.succeed(None)
def start(self, known_node_addresses=None):
return self.joinNetwork(known_node_addresses)
class FakeNetwork(object): class FakeNetwork(object):
@staticmethod @staticmethod
@ -188,9 +188,15 @@ class Wallet(object):
def get_info_exchanger(self): def get_info_exchanger(self):
return PointTraderKeyExchanger(self) return PointTraderKeyExchanger(self)
def update_peer_address(self, peer, address):
pass
def get_wallet_info_query_handler_factory(self): def get_wallet_info_query_handler_factory(self):
return PointTraderKeyQueryHandlerFactory(self) return PointTraderKeyQueryHandlerFactory(self)
def get_unused_address_for_peer(self, peer):
return defer.succeed("bDtL6qriyimxz71DSYjojTBsm6cpM1bqmj")
def reserve_points(self, *args): def reserve_points(self, *args):
return True return True
@ -244,24 +250,15 @@ class Announcer(object):
def hash_queue_size(self): def hash_queue_size(self):
return 0 return 0
def add_supplier(self, supplier):
pass
def immediate_announce(self, *args): def immediate_announce(self, *args):
pass pass
def run_manage_loop(self):
pass
def start(self): def start(self):
pass pass
def stop(self): def stop(self):
pass pass
def get_next_announce_time(self):
return 0
class GenFile(io.RawIOBase): class GenFile(io.RawIOBase):
def __init__(self, size, pattern): def __init__(self, size, pattern):
@ -410,89 +407,3 @@ def mock_conf_settings(obj, settings={}):
conf.settings = original_settings conf.settings = original_settings
obj.addCleanup(_reset_settings) obj.addCleanup(_reset_settings)
MOCK_DHT_NODES = [
"000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
"DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF",
]
MOCK_DHT_SEED_DNS = { # these map to mock nodes 0, 1, and 2
"lbrynet1.lbry.io": "10.42.42.1",
"lbrynet2.lbry.io": "10.42.42.2",
"lbrynet3.lbry.io": "10.42.42.3",
}
def resolve(name, timeout=(1, 3, 11, 45)):
if name not in MOCK_DHT_SEED_DNS:
return defer.fail(error.DNSLookupError(name))
return defer.succeed(MOCK_DHT_SEED_DNS[name])
class MockUDPTransport(object):
def __init__(self, address, port, max_packet_size, protocol):
self.address = address
self.port = port
self.max_packet_size = max_packet_size
self._node = protocol._node
def write(self, data, address):
dest = MockNetwork.peers[address][0]
debug_kademlia_packet(data, (self.address, self.port), address, self._node)
dest.datagramReceived(data, (self.address, self.port))
class MockUDPPort(object):
def __init__(self, protocol):
self.protocol = protocol
def startListening(self, reason=None):
return self.protocol.startProtocol()
def stopListening(self, reason=None):
return self.protocol.stopProtocol()
class MockNetwork(object):
peers = {} # (interface, port): (protocol, max_packet_size)
@classmethod
def add_peer(cls, port, protocol, interface, maxPacketSize):
interface = protocol._node.externalIP
protocol.transport = MockUDPTransport(interface, port, maxPacketSize, protocol)
cls.peers[(interface, port)] = (protocol, maxPacketSize)
def listenUDP(port, protocol, interface='', maxPacketSize=8192):
MockNetwork.add_peer(port, protocol, interface, maxPacketSize)
return MockUDPPort(protocol)
def address_generator(address=(10, 42, 42, 1)):
def increment(addr):
value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]))[0] + 1
new_addr = []
for i in range(4):
new_addr.append(value % 256)
value >>= 8
return tuple(new_addr[::-1])
while True:
yield "{}.{}.{}.{}".format(*address)
address = increment(address)
def mock_node_generator(count=None, mock_node_ids=MOCK_DHT_NODES):
if mock_node_ids is None:
mock_node_ids = MOCK_DHT_NODES
for num, node_ip in enumerate(address_generator()):
if count and num >= count:
break
if num >= len(mock_node_ids):
node_id = generate_id().encode('hex')
else:
node_id = mock_node_ids[num]
yield (node_id, node_ip)
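
The removed `address_generator` yields consecutive IPv4 addresses starting at 10.42.42.1, carrying into the next octet on overflow. A rough Python 3 sketch of the same idea is shown here. `next_ipv4` is a hypothetical helper name, and the sketch packs the octets with an explicit big-endian format instead of the native-order `"I"` format and byte reversal used in the original.

```python
import struct


def next_ipv4(address):
    """Return the dotted-quad address that follows `address`, carrying between octets."""
    packed = struct.pack('>4B', *(int(part) for part in address.split('.')))
    value, = struct.unpack('>I', packed)
    value = (value + 1) % 2 ** 32  # wrap around after 255.255.255.255
    return ".".join(str(octet) for octet in bytearray(struct.pack('>I', value)))


def address_generator(start="10.42.42.1"):
    """Yield `start` and then every following address, without end."""
    address = start
    while True:
        yield address
        address = next_ipv4(address)


addresses = address_generator()
first_three = [next(addresses) for _ in range(3)]
assert first_three == ["10.42.42.1", "10.42.42.2", "10.42.42.3"]
assert next_ipv4("10.42.42.255") == "10.42.43.0"
```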

View file

@ -1,82 +0,0 @@
import tempfile
import shutil
from twisted.trial import unittest
from twisted.internet import defer, reactor, threads
from lbrynet.tests.util import random_lbry_hash
from lbrynet.dht.hashannouncer import DHTHashAnnouncer
from lbrynet.core.call_later_manager import CallLaterManager
from lbrynet.database.storage import SQLiteStorage
class MocDHTNode(object):
def __init__(self, announce_will_fail=False):
# if announce_will_fail is True,
# announceHaveBlob will return empty dict
self.call_later_manager = CallLaterManager
self.call_later_manager.setup(reactor.callLater)
self.blobs_announced = 0
self.announce_will_fail = announce_will_fail
def announceHaveBlob(self, blob):
if self.announce_will_fail:
return_val = {}
else:
return_val = {blob: ["ab"*48]}
self.blobs_announced += 1
d = defer.Deferred()
self.call_later_manager.call_later(1, d.callback, return_val)
return d
class DHTHashAnnouncerTest(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
from lbrynet.conf import initialize_settings
initialize_settings(False)
self.num_blobs = 10
self.blobs_to_announce = []
for i in range(0, self.num_blobs):
self.blobs_to_announce.append(random_lbry_hash())
self.dht_node = MocDHTNode()
self.dht_node.peerPort = 3333
self.dht_node.clock = reactor
self.db_dir = tempfile.mkdtemp()
self.storage = SQLiteStorage(self.db_dir)
yield self.storage.setup()
self.announcer = DHTHashAnnouncer(self.dht_node, self.storage, 10)
for blob_hash in self.blobs_to_announce:
yield self.storage.add_completed_blob(blob_hash, 100, 0, 1)
@defer.inlineCallbacks
def tearDown(self):
self.dht_node.call_later_manager.stop()
yield self.storage.stop()
yield threads.deferToThread(shutil.rmtree, self.db_dir)
@defer.inlineCallbacks
def test_announce_fail(self):
# test what happens when node.announceHaveBlob() returns empty dict
self.dht_node.announce_will_fail = True
d = yield self.announcer.manage()
yield d
@defer.inlineCallbacks
def test_basic(self):
d = self.announcer.immediate_announce(self.blobs_to_announce)
self.assertEqual(len(self.announcer.hash_queue), self.num_blobs)
yield d
self.assertEqual(self.dht_node.blobs_announced, self.num_blobs)
self.assertEqual(len(self.announcer.hash_queue), 0)
@defer.inlineCallbacks
def test_immediate_announce(self):
# Test that immediate announce puts a hash at the front of the queue
d = self.announcer.immediate_announce(self.blobs_to_announce)
self.assertEqual(len(self.announcer.hash_queue), self.num_blobs)
blob_hash = random_lbry_hash()
self.announcer.immediate_announce([blob_hash])
self.assertEqual(len(self.announcer.hash_queue), self.num_blobs+1)
self.assertEqual(blob_hash, self.announcer.hash_queue[-1])
yield d

View file

@ -98,8 +98,7 @@ class StorageTest(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def store_fake_blob(self, blob_hash, blob_length=100, next_announce=0, should_announce=0): def store_fake_blob(self, blob_hash, blob_length=100, next_announce=0, should_announce=0):
yield self.storage.add_completed_blob(blob_hash, blob_length, next_announce, yield self.storage.add_completed_blob(blob_hash, blob_length, next_announce,
should_announce) should_announce, "finished")
yield self.storage.set_blob_status(blob_hash, "finished")
@defer.inlineCallbacks @defer.inlineCallbacks
def store_fake_stream_blob(self, stream_hash, blob_hash, blob_num, length=100, iv="DEADBEEF"): def store_fake_stream_blob(self, stream_hash, blob_hash, blob_num, length=100, iv="DEADBEEF"):
@ -163,6 +162,25 @@ class BlobStorageTests(StorageTest):
self.assertEqual(blob_hashes, []) self.assertEqual(blob_hashes, [])
class SupportsStorageTests(StorageTest):
@defer.inlineCallbacks
def test_supports_storage(self):
claim_ids = [random_lbry_hash() for _ in range(10)]
random_supports = [{"txid": random_lbry_hash(), "nout":i, "address": "addr{}".format(i), "amount": i}
for i in range(20)]
expected_supports = {}
for idx, claim_id in enumerate(claim_ids):
yield self.storage.save_supports(claim_id, random_supports[idx*2:idx*2+2])
for random_support in random_supports[idx*2:idx*2+2]:
random_support['claim_id'] = claim_id
expected_supports.setdefault(claim_id, []).append(random_support)
supports = yield self.storage.get_supports(claim_ids[0])
self.assertEqual(supports, expected_supports[claim_ids[0]])
all_supports = yield self.storage.get_supports(*claim_ids)
for support in all_supports:
self.assertIn(support, expected_supports[support['claim_id']])
class StreamStorageTests(StorageTest): class StreamStorageTests(StorageTest):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_store_stream(self, stream_hash=None): def test_store_stream(self, stream_hash=None):

View file

@ -1,16 +1,32 @@
import unittest from twisted.internet import task
from twisted.trial import unittest
from lbrynet.dht import contact from lbrynet.core.utils import generate_id
from lbrynet.dht.contact import ContactManager
from lbrynet.dht import constants
class ContactOperatorsTest(unittest.TestCase): class ContactOperatorsTest(unittest.TestCase):
""" Basic test case for boolean operators on the Contact class """ """ Basic test case for boolean operators on the Contact class """
def setUp(self): def setUp(self):
self.firstContact = contact.Contact('firstContactID', '127.0.0.1', 1000, None, 1) self.contact_manager = ContactManager()
self.secondContact = contact.Contact('2ndContactID', '192.168.0.1', 1000, None, 32) self.node_ids = [generate_id(), generate_id(), generate_id()]
self.secondContactCopy = contact.Contact('2ndContactID', '192.168.0.1', 1000, None, 32) self.firstContact = self.contact_manager.make_contact(self.node_ids[1], '127.0.0.1', 1000, None, 1)
self.firstContactDifferentValues = contact.Contact( self.secondContact = self.contact_manager.make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
'firstContactID', '192.168.1.20', 1000, None, 50) self.secondContactCopy = self.contact_manager.make_contact(self.node_ids[0], '192.168.0.1', 1000, None, 32)
self.firstContactDifferentValues = self.contact_manager.make_contact(self.node_ids[1], '192.168.1.20',
1000, None, 50)
self.assertRaises(ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20',
100000, None)
self.assertRaises(ValueError, self.contact_manager.make_contact, self.node_ids[1], '192.168.1.20.1',
1000, None)
self.assertRaises(ValueError, self.contact_manager.make_contact, self.node_ids[1], 'this is not an ip',
1000, None)
self.assertRaises(ValueError, self.contact_manager.make_contact, "this is not a node id", '192.168.1.20.1',
1000, None)
def testNoDuplicateContactObjects(self):
self.assertTrue(self.secondContact is self.secondContactCopy)
self.assertTrue(self.firstContact is not self.firstContactDifferentValues)
def testBoolean(self): def testBoolean(self):
""" Test "equals" and "not equals" comparisons """ """ Test "equals" and "not equals" comparisons """
@ -24,15 +40,6 @@ class ContactOperatorsTest(unittest.TestCase):
self.secondContact, self.secondContactCopy, self.secondContact, self.secondContactCopy,
'Different copies of the same Contact instance should be equal') 'Different copies of the same Contact instance should be equal')
def testStringComparisons(self):
""" Test comparisons of Contact objects with str types """
self.failUnlessEqual(
'firstContactID', self.firstContact,
'The node ID string must be equal to the contact object')
self.failIfEqual(
'some random string', self.firstContact,
"The tested string should not be equal to the contact object (not equal to its ID)")
def testIllogicalComparisons(self): def testIllogicalComparisons(self):
""" Test comparisons with non-Contact and non-str types """ """ Test comparisons with non-Contact and non-str types """
msg = '"{}" operator: Contact object should not be equal to {} type' msg = '"{}" operator: Contact object should not be equal to {} type'
@ -47,3 +54,119 @@ class ContactOperatorsTest(unittest.TestCase):
def testCompactIP(self): def testCompactIP(self):
self.assertEqual(self.firstContact.compact_ip(), '\x7f\x00\x00\x01') self.assertEqual(self.firstContact.compact_ip(), '\x7f\x00\x00\x01')
self.assertEqual(self.secondContact.compact_ip(), '\xc0\xa8\x00\x01') self.assertEqual(self.secondContact.compact_ip(), '\xc0\xa8\x00\x01')
class TestContactLastReplied(unittest.TestCase):
def setUp(self):
self.clock = task.Clock()
self.contact_manager = ContactManager(self.clock.seconds)
self.contact = self.contact_manager.make_contact(generate_id(), "127.0.0.1", 4444, None)
self.clock.advance(3600)
self.assertTrue(self.contact.contact_is_good is None)
def test_stale_replied_to_us(self):
self.contact.update_last_replied()
self.assertTrue(self.contact.contact_is_good is True)
def test_stale_requested_from_us(self):
self.contact.update_last_requested()
self.assertTrue(self.contact.contact_is_good is None)
def test_stale_then_fail(self):
self.contact.update_last_failed()
self.assertTrue(self.contact.contact_is_good is None)
self.clock.advance(1)
self.contact.update_last_failed()
self.assertTrue(self.contact.contact_is_good is False)
def test_good_turned_stale(self):
self.contact.update_last_replied()
self.assertTrue(self.contact.contact_is_good is True)
self.clock.advance(constants.checkRefreshInterval - 1)
self.assertTrue(self.contact.contact_is_good is True)
self.clock.advance(1)
self.assertTrue(self.contact.contact_is_good is None)
def test_good_then_fail(self):
self.contact.update_last_replied()
self.assertTrue(self.contact.contact_is_good is True)
self.clock.advance(1)
self.contact.update_last_failed()
self.assertTrue(self.contact.contact_is_good is True)
self.clock.advance(59)
self.assertTrue(self.contact.contact_is_good is True)
self.contact.update_last_failed()
self.assertTrue(self.contact.contact_is_good is False)
for _ in range(7200):
self.clock.advance(60)
self.assertTrue(self.contact.contact_is_good is False)
def test_good_then_fail_then_good(self):
# it replies
self.contact.update_last_replied()
self.assertTrue(self.contact.contact_is_good is True)
self.clock.advance(1)
# it fails twice in a row
self.contact.update_last_failed()
self.clock.advance(1)
self.contact.update_last_failed()
self.assertTrue(self.contact.contact_is_good is False)
self.clock.advance(1)
# it replies
self.contact.update_last_replied()
self.clock.advance(1)
self.assertTrue(self.contact.contact_is_good is True)
# it goes stale
self.clock.advance(constants.checkRefreshInterval - 2)
self.assertTrue(self.contact.contact_is_good is True)
self.clock.advance(1)
self.assertTrue(self.contact.contact_is_good is None)
class TestContactLastRequested(unittest.TestCase):
def setUp(self):
self.clock = task.Clock()
self.contact_manager = ContactManager(self.clock.seconds)
self.contact = self.contact_manager.make_contact(generate_id(), "127.0.0.1", 4444, None)
self.clock.advance(1)
self.contact.update_last_replied()
self.clock.advance(3600)
self.assertTrue(self.contact.contact_is_good is None)
def test_previous_replied_then_requested(self):
# it requests
self.contact.update_last_requested()
self.assertTrue(self.contact.contact_is_good is True)
# it goes stale
self.clock.advance(constants.checkRefreshInterval - 1)
self.assertTrue(self.contact.contact_is_good is True)
self.clock.advance(1)
self.assertTrue(self.contact.contact_is_good is None)
def test_previous_replied_then_requested_then_failed(self):
# it requests
self.contact.update_last_requested()
self.assertTrue(self.contact.contact_is_good is True)
self.clock.advance(1)
# it fails twice in a row
self.contact.update_last_failed()
self.clock.advance(1)
self.contact.update_last_failed()
self.assertTrue(self.contact.contact_is_good is False)
self.clock.advance(1)
# it requests
self.contact.update_last_requested()
self.clock.advance(1)
self.assertTrue(self.contact.contact_is_good is False)
# it goes stale
self.clock.advance((constants.refreshTimeout / 4) - 2)
self.assertTrue(self.contact.contact_is_good is False)
self.clock.advance(1)
self.assertTrue(self.contact.contact_is_good is False)
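
The `contact_is_good` assertions above describe a three-state flag: `True` for a contact heard from recently, `False` after two failures with no reply in between, and `None` when nothing recent is known. The following is a hypothetical model that is consistent with those assertions. It is not the lbrynet implementation, and the class name, attribute names, and the interval value of 600 are assumptions made for the example.

```python
class ContactHealth(object):
    """Hypothetical tri-state health flag: True (good), False (bad), None (unknown)."""

    def __init__(self, get_time, refresh_interval):
        self.get_time = get_time
        self.refresh_interval = refresh_interval
        self.last_replied = None
        self.last_requested = None
        self.failures = []

    def update_last_replied(self):
        self.last_replied = self.get_time()

    def update_last_requested(self):
        self.last_requested = self.get_time()

    def update_last_failed(self):
        self.failures.append(self.get_time())

    @property
    def is_good(self):
        now = self.get_time()
        # two failures with no reply in between mark the contact as bad
        since_reply = [t for t in self.failures
                       if self.last_replied is None or t > self.last_replied]
        if len(since_reply) >= 2:
            return False
        if self.last_replied is None:
            return None
        cutoff = now - self.refresh_interval
        if self.last_replied > cutoff:
            return True
        # a request only counts for contacts that replied at some point
        if self.last_requested is not None and self.last_requested > cutoff:
            return True
        return None


now = [0]
contact = ContactHealth(lambda: now[0], refresh_interval=600)  # 600 is arbitrary
assert contact.is_good is None    # never heard from
contact.update_last_replied()
assert contact.is_good is True
now[0] = 1
contact.update_last_failed()
assert contact.is_good is True    # a single failure is tolerated
now[0] = 2
contact.update_last_failed()
assert contact.is_good is False   # two failures since the last reply
now[0] = 3
contact.update_last_replied()
assert contact.is_good is True    # a fresh reply clears the failures
now[0] = 3 + 600
assert contact.is_good is None    # the reply has gone stale
```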

View file

@ -1,143 +0,0 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
import unittest
import time
import lbrynet.dht.datastore
import lbrynet.dht.constants
import hashlib
class DictDataStoreTest(unittest.TestCase):
""" Basic test case for the reference DataStore API and implementation """
def setUp(self):
self.ds = lbrynet.dht.datastore.DictDataStore()
h = hashlib.sha1()
h.update('g')
hashKey = h.digest()
h2 = hashlib.sha1()
h2.update('dried')
hashKey2 = h2.digest()
h3 = hashlib.sha1()
h3.update('Boozoo Bajou - 09 - S.I.P.mp3')
hashKey3 = h3.digest()
#self.cases = (('a', 'hello there\nthis is a test'),
# (hashKey3, '1 2 3 4 5 6 7 8 9 0'))
self.cases = ((hashKey, 'test1test1test1test1test1t'),
(hashKey, 'test2'),
(hashKey, 'test3test3test3test3test3test3test3test3'),
(hashKey2, 'test4'),
(hashKey3, 'test5'),
(hashKey3, 'test6'))
def testReadWrite(self):
# Test write ability
for key, value in self.cases:
try:
now = int(time.time())
self.ds.addPeerToBlob(key, value, now, now, 'node1')
except Exception:
import traceback
self.fail('Failed writing the following data: key: "%s" '
'data: "%s"\n The error was: %s:' %
(key, value, traceback.format_exc(5)))
# Verify writing (test query ability)
for key, value in self.cases:
try:
self.failUnless(self.ds.hasPeersForBlob(key),
'Key "%s" not found in DataStore! DataStore key dump: %s' %
(key, self.ds.keys()))
except Exception:
import traceback
self.fail(
'Failed verifying that the following key exists: "%s"\n The error was: %s:' %
(key, traceback.format_exc(5)))
# Read back the data
for key, value in self.cases:
self.failUnless(value in self.ds.getPeersForBlob(key),
'DataStore returned invalid data! Expected "%s", got "%s"' %
(value, self.ds.getPeersForBlob(key)))
def testNonExistentKeys(self):
for key, value in self.cases:
self.failIf(key in self.ds.keys(), 'DataStore reports it has non-existent key: "%s"' %
key)
def testExpires(self):
now = int(time.time())
h1 = hashlib.sha1()
h1.update('test1')
key1 = h1.digest()
h2 = hashlib.sha1()
h2.update('test2')
key2 = h2.digest()
td = lbrynet.dht.constants.dataExpireTimeout - 100
td2 = td + td
self.ds.addPeerToBlob(h1, 'val1', now - td, now - td, '1')
self.ds.addPeerToBlob(h1, 'val2', now - td2, now - td2, '2')
self.ds.addPeerToBlob(h2, 'val3', now - td2, now - td2, '3')
self.ds.addPeerToBlob(h2, 'val4', now, now, '4')
self.ds.removeExpiredPeers()
self.failUnless(
'val1' in self.ds.getPeersForBlob(h1),
'DataStore deleted an unexpired value! Value %s, publish time %s, current time %s' %
('val1', str(now - td), str(now)))
self.failIf(
'val2' in self.ds.getPeersForBlob(h1),
'DataStore failed to delete an expired value! '
'Value %s, publish time %s, current time %s' %
('val2', str(now - td2), str(now)))
self.failIf(
'val3' in self.ds.getPeersForBlob(h2),
'DataStore failed to delete an expired value! '
'Value %s, publish time %s, current time %s' %
('val3', str(now - td2), str(now)))
self.failUnless(
'val4' in self.ds.getPeersForBlob(h2),
'DataStore deleted an unexpired value! Value %s, publish time %s, current time %s' %
('val4', str(now), str(now)))
# # First write with fake values
# for key, value in self.cases:
# except Exception:
#
# # write this stuff a second time, with the real values
# for key, value in self.cases:
# except Exception:
#
# # Read back the data
# for key, value in self.cases:
# # First some values
# for key, value in self.cases:
# except Exception:
#
#
# # Delete an item from the data
# # First some values with metadata
# for key, value in self.cases:
# except Exception:
#
# # Read back the meta-data
# for key, value in self.cases:
def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(DictDataStoreTest))
return suite
if __name__ == '__main__':
# If this module is executed from the commandline, run all its tests
unittest.TextTestRunner().run(suite())

@ -4,10 +4,10 @@
# the GNU Lesser General Public License Version 3, or any later version. # the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive # See the COPYING file included in this archive
import unittest from twisted.trial import unittest
import lbrynet.dht.encoding import lbrynet.dht.encoding
class BencodeTest(unittest.TestCase): class BencodeTest(unittest.TestCase):
""" Basic tests case for the Bencode implementation """ """ Basic tests case for the Bencode implementation """
def setUp(self): def setUp(self):
@ -16,7 +16,7 @@ class BencodeTest(unittest.TestCase):
self.cases = ((42, 'i42e'), self.cases = ((42, 'i42e'),
('spam', '4:spam'), ('spam', '4:spam'),
(['spam', 42], 'l4:spami42ee'), (['spam', 42], 'l4:spami42ee'),
({'foo':42, 'bar':'spam'}, 'd3:bar4:spam3:fooi42ee'), ({'foo': 42, 'bar': 'spam'}, 'd3:bar4:spam3:fooi42ee'),
# ...and now the "real life" tests # ...and now the "real life" tests
([['abc', '127.0.0.1', 1919], ['def', '127.0.0.1', 1921]], ([['abc', '127.0.0.1', 1919], ['def', '127.0.0.1', 1921]],
'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee')) 'll3:abc9:127.0.0.1i1919eel3:def9:127.0.0.1i1921eee'))
@ -45,12 +45,3 @@ class BencodeTest(unittest.TestCase):
for encodedValue in self.badDecoderCases: for encodedValue in self.badDecoderCases:
self.failUnlessRaises( self.failUnlessRaises(
lbrynet.dht.encoding.DecodeError, self.encoding.decode, encodedValue) lbrynet.dht.encoding.DecodeError, self.encoding.decode, encodedValue)
def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(BencodeTest))
return suite
if __name__ == '__main__':
# If this module is executed from the commandline, run all its tests
unittest.TextTestRunner().run(suite())
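
The `BencodeTest` cases in this file pair each Python value with its expected encoding. As a reading aid, the following is a minimal sketch of an encoder that reproduces those pairs. `bencode_sketch` is a hypothetical name, not the `lbrynet.dht.encoding` implementation, and it is written for Python 3 text strings, so it assumes ASCII input where the character count equals the byte count.

```python
def bencode_sketch(value):
    """Encode ints, strings, lists/tuples and dicts with bencode rules.

    Hypothetical helper for illustration only; assumes ASCII strings.
    """
    if isinstance(value, bool):
        # bool is a subclass of int in Python; reject it explicitly.
        raise TypeError("bool is not a bencode type")
    if isinstance(value, int):
        return "i%de" % value
    if isinstance(value, str):
        # Strings are length-prefixed: <length>:<contents>
        return "%d:%s" % (len(value), value)
    if isinstance(value, (list, tuple)):
        return "l" + "".join(bencode_sketch(item) for item in value) + "e"
    if isinstance(value, dict):
        # Keys are emitted in sorted order, which is why the test case
        # for {'foo': 42, 'bar': 'spam'} starts with the 'bar' entry.
        body = "".join(
            bencode_sketch(key) + bencode_sketch(value[key])
            for key in sorted(value)
        )
        return "d" + body + "e"
    raise TypeError("cannot bencode %r" % type(value))
```

Calling `bencode_sketch({'foo': 42, 'bar': 'spam'})` yields `'d3:bar4:spam3:fooi42ee'`, the same string the test expects.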

@ -0,0 +1,59 @@
from twisted.trial import unittest
from twisted.internet import defer, task
from lbrynet import conf
from lbrynet.core import utils
from lbrynet.dht.hashannouncer import DHTHashAnnouncer
from lbrynet.tests.util import random_lbry_hash
class MocDHTNode(object):
def __init__(self):
self.blobs_announced = 0
self.clock = task.Clock()
self.peerPort = 3333
def announceHaveBlob(self, blob):
self.blobs_announced += 1
d = defer.Deferred()
self.clock.callLater(1, d.callback, ['fake'])
return d
class MocStorage(object):
def __init__(self, blobs_to_announce):
self.blobs_to_announce = blobs_to_announce
self.announced = False
def get_blobs_to_announce(self):
if not self.announced:
self.announced = True
return defer.succeed(self.blobs_to_announce)
else:
return defer.succeed([])
def update_last_announced_blob(self, blob_hash, now):
return defer.succeed(None)
class DHTHashAnnouncerTest(unittest.TestCase):
def setUp(self):
conf.initialize_settings(False)
self.num_blobs = 10
self.blobs_to_announce = []
for i in range(0, self.num_blobs):
self.blobs_to_announce.append(random_lbry_hash())
self.dht_node = MocDHTNode()
self.clock = self.dht_node.clock
utils.call_later = self.clock.callLater
self.storage = MocStorage(self.blobs_to_announce)
self.announcer = DHTHashAnnouncer(self.dht_node, self.storage)
@defer.inlineCallbacks
def test_immediate_announce(self):
announce_d = self.announcer.immediate_announce(self.blobs_to_announce)
self.assertEqual(self.announcer.hash_queue_size(), self.num_blobs)
self.clock.advance(1)
yield announce_d
self.assertEqual(self.dht_node.blobs_announced, self.num_blobs)
self.assertEqual(self.announcer.hash_queue_size(), 0)

@ -4,23 +4,41 @@
# the GNU Lesser General Public License Version 3, or any later version. # the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive # See the COPYING file included in this archive
import unittest from twisted.trial import unittest
import struct
from lbrynet.core.utils import generate_id
from lbrynet.dht import kbucket from lbrynet.dht import kbucket
import lbrynet.dht.contact as contact from lbrynet.dht.contact import ContactManager
from lbrynet.dht import constants from lbrynet.dht import constants
def address_generator(address=(10, 42, 42, 1)):
def increment(addr):
value = struct.unpack("I", "".join([chr(x) for x in list(addr)[::-1]]))[0] + 1
new_addr = []
for i in range(4):
new_addr.append(value % 256)
value >>= 8
return tuple(new_addr[::-1])
while True:
yield "{}.{}.{}.{}".format(*address)
address = increment(address)
class KBucketTest(unittest.TestCase): class KBucketTest(unittest.TestCase):
""" Test case for the KBucket class """ """ Test case for the KBucket class """
def setUp(self): def setUp(self):
self.kbucket = kbucket.KBucket(0, 2**160) self.address_generator = address_generator()
self.contact_manager = ContactManager()
self.kbucket = kbucket.KBucket(0, 2**constants.key_bits, generate_id())
def testAddContact(self): def testAddContact(self):
""" Tests if the bucket handles contact additions/updates correctly """ """ Tests if the bucket handles contact additions/updates correctly """
# Test if contacts can be added to empty list # Test if contacts can be added to empty list
# Add k contacts to bucket # Add k contacts to bucket
for i in range(constants.k): for i in range(constants.k):
tmpContact = contact.Contact('tempContactID%d' % i, str(i), i, i) tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
self.kbucket.addContact(tmpContact) self.kbucket.addContact(tmpContact)
self.failUnlessEqual( self.failUnlessEqual(
self.kbucket._contacts[i], self.kbucket._contacts[i],
@ -28,8 +46,7 @@ class KBucketTest(unittest.TestCase):
"Contact in position %d not the same as the newly-added contact" % i) "Contact in position %d not the same as the newly-added contact" % i)
# Test if contact is not added to full list # Test if contact is not added to full list
i += 1 tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
tmpContact = contact.Contact('tempContactID%d' % i, str(i), i, i)
self.failUnlessRaises(kbucket.BucketFull, self.kbucket.addContact, tmpContact) self.failUnlessRaises(kbucket.BucketFull, self.kbucket.addContact, tmpContact)
# Test if an existing contact is updated correctly if added again # Test if an existing contact is updated correctly if added again
@ -48,14 +65,19 @@ class KBucketTest(unittest.TestCase):
# Add k-2 contacts # Add k-2 contacts
node_ids = []
if constants.k >= 2: if constants.k >= 2:
for i in range(constants.k-2): for i in range(constants.k-2):
tmpContact = contact.Contact(i, i, i, i) node_ids.append(generate_id())
tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
None)
self.kbucket.addContact(tmpContact) self.kbucket.addContact(tmpContact)
else: else:
# add k contacts # add k contacts
for i in range(constants.k): for i in range(constants.k):
tmpContact = contact.Contact(i, i, i, i) node_ids.append(generate_id())
tmpContact = self.contact_manager.make_contact(node_ids[-1], next(self.address_generator), 4444, 0,
None)
self.kbucket.addContact(tmpContact) self.kbucket.addContact(tmpContact)
# try to get too many contacts # try to get too many contacts
@ -65,8 +87,8 @@ class KBucketTest(unittest.TestCase):
'Returned list should not have more than k entries!') 'Returned list should not have more than k entries!')
# verify returned contacts in list # verify returned contacts in list
for i in range(constants.k-2): for node_id, i in zip(node_ids, range(constants.k-2)):
self.failIf(self.kbucket._contacts[i].id != i, self.failIf(self.kbucket._contacts[i].id != node_id,
"Contact in position %s not same as added contact" % (str(i))) "Contact in position %s not same as added contact" % (str(i)))
# try to get too many contacts # try to get too many contacts
@ -89,25 +111,15 @@ class KBucketTest(unittest.TestCase):
def testRemoveContact(self): def testRemoveContact(self):
# try remove contact from empty list # try remove contact from empty list
rmContact = contact.Contact('TestContactID1', '127.0.0.1', 1, 1) rmContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
self.failUnlessRaises(ValueError, self.kbucket.removeContact, rmContact) self.failUnlessRaises(ValueError, self.kbucket.removeContact, rmContact)
# Add couple contacts # Add couple contacts
for i in range(constants.k-2): for i in range(constants.k-2):
tmpContact = contact.Contact('tmpTestContactID%d' % i, str(i), i, i) tmpContact = self.contact_manager.make_contact(generate_id(), next(self.address_generator), 4444, 0, None)
self.kbucket.addContact(tmpContact) self.kbucket.addContact(tmpContact)
# try to remove a contact from a non-empty list # try to remove a contact from a non-empty list
self.kbucket.addContact(rmContact) self.kbucket.addContact(rmContact)
result = self.kbucket.removeContact(rmContact) result = self.kbucket.removeContact(rmContact)
self.failIf(rmContact in self.kbucket._contacts, "Could not remove contact from bucket") self.failIf(rmContact in self.kbucket._contacts, "Could not remove contact from bucket")
def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(KBucketTest))
return suite
if __name__ == '__main__':
# If this module is executed from the commandline, run all its tests
unittest.TextTestRunner().run(suite())
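
The `address_generator` helper added to the kbucket test increments a dotted-quad address by joining `chr()` values and passing them to `struct.unpack`, which only works with Python 2 byte strings. The sketch below shows the same idea with the standard-library `ipaddress` module, assuming Python 3. `address_generator_sketch` is a hypothetical name, and it takes the start address as a string rather than a tuple.

```python
import ipaddress


def address_generator_sketch(start="10.42.42.1"):
    """Yield consecutive dotted-quad IPv4 addresses, beginning at `start`.

    Hypothetical stand-in for the test helper, for illustration only.
    """
    value = int(ipaddress.IPv4Address(start))
    while True:
        yield str(ipaddress.IPv4Address(value))
        # Carry into the next octet automatically; wrap around at the
        # end of the 32-bit address space instead of raising.
        value = (value + 1) % 2 ** 32
```

A test would pull addresses with `next(gen)`, as the kbucket test does with `next(self.address_generator)`, so that every generated contact gets a distinct address.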

@ -4,7 +4,7 @@
# the GNU Lesser General Public License Version 3, or any later version. # the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive # See the COPYING file included in this archive
import unittest from twisted.trial import unittest
from lbrynet.dht.msgtypes import RequestMessage, ResponseMessage, ErrorMessage from lbrynet.dht.msgtypes import RequestMessage, ResponseMessage, ErrorMessage
from lbrynet.dht.msgformat import MessageTranslator, DefaultFormat from lbrynet.dht.msgformat import MessageTranslator, DefaultFormat

@ -5,20 +5,18 @@
# See the COPYING file included in this archive # See the COPYING file included in this archive
import hashlib import hashlib
import unittest from twisted.trial import unittest
import struct import struct
from twisted.internet import protocol, defer, selectreactor from twisted.internet import defer
from lbrynet.dht.msgtypes import ResponseMessage from lbrynet.dht.node import Node
import lbrynet.dht.node from lbrynet.dht import constants
import lbrynet.dht.constants
import lbrynet.dht.datastore
class NodeIDTest(unittest.TestCase): class NodeIDTest(unittest.TestCase):
""" Test case for the Node class's ID """ """ Test case for the Node class's ID """
def setUp(self): def setUp(self):
self.node = lbrynet.dht.node.Node() self.node = Node()
def testAutoCreatedID(self): def testAutoCreatedID(self):
""" Tests if a new node has a valid node ID """ """ Tests if a new node has a valid node ID """
@ -49,12 +47,10 @@ class NodeIDTest(unittest.TestCase):
class NodeDataTest(unittest.TestCase): class NodeDataTest(unittest.TestCase):
""" Test case for the Node class's data-related functions """ """ Test case for the Node class's data-related functions """
def setUp(self): def setUp(self):
import lbrynet.dht.contact
h = hashlib.sha384() h = hashlib.sha384()
h.update('test') h.update('test')
self.node = lbrynet.dht.node.Node() self.node = Node()
self.contact = lbrynet.dht.contact.Contact(h.digest(), '127.0.0.1', 12345, self.contact = self.node.contact_manager.make_contact(h.digest(), '127.0.0.1', 12345, self.node._protocol)
self.node._protocol)
self.token = self.node.make_token(self.contact.compact_ip()) self.token = self.node.make_token(self.contact.compact_ip())
self.cases = [] self.cases = []
for i in xrange(5): for i in xrange(5):
@ -65,13 +61,10 @@ class NodeDataTest(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def testStore(self): def testStore(self):
""" Tests if the node can store (and privately retrieve) some data """ """ Tests if the node can store (and privately retrieve) some data """
for key, value in self.cases: for key, port in self.cases:
request = { yield self.node.store( # pylint: disable=too-many-function-args
'port': value, self.contact, key, self.token, port, self.contact.id, 0
'lbryid': self.contact.id, )
'token': self.token
}
yield self.node.store(key, request, self.contact.id, _rpcNodeContact=self.contact)
for key, value in self.cases: for key, value in self.cases:
expected_result = self.contact.compact_ip() + str(struct.pack('>H', value)) + \ expected_result = self.contact.compact_ip() + str(struct.pack('>H', value)) + \
self.contact.id self.contact.id
@ -85,189 +78,185 @@ class NodeDataTest(unittest.TestCase):
class NodeContactTest(unittest.TestCase): class NodeContactTest(unittest.TestCase):
""" Test case for the Node class's contact management-related functions """ """ Test case for the Node class's contact management-related functions """
def setUp(self): def setUp(self):
self.node = lbrynet.dht.node.Node() self.node = Node()
@defer.inlineCallbacks
def testAddContact(self): def testAddContact(self):
""" Tests if a contact can be added and retrieved correctly """ """ Tests if a contact can be added and retrieved correctly """
import lbrynet.dht.contact
# Create the contact # Create the contact
h = hashlib.sha384() h = hashlib.sha384()
h.update('node1') h.update('node1')
contactID = h.digest() contactID = h.digest()
contact = lbrynet.dht.contact.Contact(contactID, '127.0.0.1', 91824, self.node._protocol) contact = self.node.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.node._protocol)
# Now add it... # Now add it...
self.node.addContact(contact) yield self.node.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE # ...and request the closest nodes to it using FIND_NODE
closestNodes = self.node._routingTable.findCloseNodes(contactID, lbrynet.dht.constants.k) closestNodes = self.node._routingTable.findCloseNodes(contactID, constants.k)
self.failUnlessEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; ' self.failUnlessEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; '
'expected 1, got %d' % len(closestNodes)) 'expected 1, got %d' % len(closestNodes))
self.failUnless(contact in closestNodes, 'Added contact not found by issuing ' self.failUnless(contact in closestNodes, 'Added contact not found by issuing '
'_findCloseNodes()') '_findCloseNodes()')
@defer.inlineCallbacks
def testAddSelfAsContact(self): def testAddSelfAsContact(self):
""" Tests the node's behaviour when attempting to add itself as a contact """ """ Tests the node's behaviour when attempting to add itself as a contact """
import lbrynet.dht.contact
# Create a contact with the same ID as the local node's ID # Create a contact with the same ID as the local node's ID
contact = lbrynet.dht.contact.Contact(self.node.node_id, '127.0.0.1', 91824, None) contact = self.node.contact_manager.make_contact(self.node.node_id, '127.0.0.1', 9182, None)
# Now try to add it # Now try to add it
self.node.addContact(contact) yield self.node.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE # ...and request the closest nodes to it using FIND_NODE
closestNodes = self.node._routingTable.findCloseNodes(self.node.node_id, closestNodes = self.node._routingTable.findCloseNodes(self.node.node_id,
lbrynet.dht.constants.k) constants.k)
self.failIf(contact in closestNodes, 'Node added itself as a contact') self.failIf(contact in closestNodes, 'Node added itself as a contact')
class FakeRPCProtocol(protocol.DatagramProtocol): # class FakeRPCProtocol(protocol.DatagramProtocol):
def __init__(self): # def __init__(self):
self.reactor = selectreactor.SelectReactor() # self.reactor = selectreactor.SelectReactor()
self.testResponse = None # self.testResponse = None
self.network = None # self.network = None
#
def createNetwork(self, contactNetwork): # def createNetwork(self, contactNetwork):
""" # """
set up a list of contacts together with their closest contacts # set up a list of contacts together with their closest contacts
@param contactNetwork: a sequence of tuples, each containing a contact together with its # @param contactNetwork: a sequence of tuples, each containing a contact together with its
closest contacts: C{(<contact>, <closest contact 1, ...,closest contact n>)} # closest contacts: C{(<contact>, <closest contact 1, ...,closest contact n>)}
""" # """
self.network = contactNetwork # self.network = contactNetwork
#
def sendRPC(self, contact, method, args, rawResponse=False): # def sendRPC(self, contact, method, args, rawResponse=False):
""" Fake RPC protocol; allows entangled.kademlia.contact.Contact objects to "send" RPCs""" # """ Fake RPC protocol; allows entangled.kademlia.contact.Contact objects to "send" RPCs"""
#
h = hashlib.sha384() # h = hashlib.sha384()
h.update('rpcId') # h.update('rpcId')
rpc_id = h.digest()[:20] # rpc_id = h.digest()[:20]
#
if method == "findNode": # if method == "findNode":
# get the specific contacts closest contacts # # get the specific contacts closest contacts
closestContacts = [] # closestContacts = []
closestContactsList = [] # closestContactsList = []
for contactTuple in self.network: # for contactTuple in self.network:
if contact == contactTuple[0]: # if contact == contactTuple[0]:
# get the list of closest contacts for this contact # # get the list of closest contacts for this contact
closestContactsList = contactTuple[1] # closestContactsList = contactTuple[1]
# Pack the closest contacts into a ResponseMessage # # Pack the closest contacts into a ResponseMessage
for closeContact in closestContactsList: # for closeContact in closestContactsList:
closestContacts.append((closeContact.id, closeContact.address, closeContact.port)) # closestContacts.append((closeContact.id, closeContact.address, closeContact.port))
#
message = ResponseMessage(rpc_id, contact.id, closestContacts) # message = ResponseMessage(rpc_id, contact.id, closestContacts)
df = defer.Deferred() # df = defer.Deferred()
df.callback((message, (contact.address, contact.port))) # df.callback((message, (contact.address, contact.port)))
return df # return df
elif method == "findValue": # elif method == "findValue":
for contactTuple in self.network: # for contactTuple in self.network:
if contact == contactTuple[0]: # if contact == contactTuple[0]:
# Get the data stored by this remote contact # # Get the data stored by this remote contact
dataDict = contactTuple[2] # dataDict = contactTuple[2]
dataKey = dataDict.keys()[0] # dataKey = dataDict.keys()[0]
data = dataDict.get(dataKey) # data = dataDict.get(dataKey)
# Check if this contact has the requested value # # Check if this contact has the requested value
if dataKey == args[0]: # if dataKey == args[0]:
# Return the data value # # Return the data value
response = dataDict # response = dataDict
print "data found at contact: " + contact.id # print "data found at contact: " + contact.id
else: # else:
# Return the closest contact to the requested data key # # Return the closest contact to the requested data key
print "data not found at contact: " + contact.id # print "data not found at contact: " + contact.id
closeContacts = contactTuple[1] # closeContacts = contactTuple[1]
closestContacts = [] # closestContacts = []
for closeContact in closeContacts: # for closeContact in closeContacts:
closestContacts.append((closeContact.id, closeContact.address, # closestContacts.append((closeContact.id, closeContact.address,
closeContact.port)) # closeContact.port))
response = closestContacts # response = closestContacts
#
# Create the response message # # Create the response message
message = ResponseMessage(rpc_id, contact.id, response) # message = ResponseMessage(rpc_id, contact.id, response)
df = defer.Deferred() # df = defer.Deferred()
df.callback((message, (contact.address, contact.port))) # df.callback((message, (contact.address, contact.port)))
return df # return df
#
def _send(self, data, rpcID, address): # def _send(self, data, rpcID, address):
""" fake sending data """ # """ fake sending data """
#
#
class NodeLookupTest(unittest.TestCase): # class NodeLookupTest(unittest.TestCase):
""" Test case for the Node class's iterativeFind node lookup algorithm """ # """ Test case for the Node class's iterativeFind node lookup algorithm """
#
def setUp(self): # def setUp(self):
# create a fake protocol to imitate communication with other nodes # # create a fake protocol to imitate communication with other nodes
self._protocol = FakeRPCProtocol() # self._protocol = FakeRPCProtocol()
# Note: The reactor is never started for this test. All deferred calls run sequentially, # # Note: The reactor is never started for this test. All deferred calls run sequentially,
# since there is no asynchronous network communication # # since there is no asynchronous network communication
# create the node to be tested in isolation # # create the node to be tested in isolation
h = hashlib.sha384() # h = hashlib.sha384()
h.update('node1') # h.update('node1')
node_id = str(h.digest()) # node_id = str(h.digest())
self.node = lbrynet.dht.node.Node(node_id=node_id, udpPort=4000, networkProtocol=self._protocol) # self.node = Node(node_id, 4000, None, None, self._protocol)
self.updPort = 81173 # self.updPort = 81173
self.contactsAmount = 80 # self.contactsAmount = 80
# Reinitialise the routing table # # Reinitialise the routing table
self.node._routingTable = lbrynet.dht.routingtable.OptimizedTreeRoutingTable( # self.node._routingTable = TreeRoutingTable(self.node.node_id)
self.node.node_id) #
# # create 160 bit node ID's for test purposes
# create 160 bit node ID's for test purposes # self.testNodeIDs = []
self.testNodeIDs = [] # idNum = int(self.node.node_id.encode('hex'), 16)
idNum = int(self.node.node_id.encode('hex'), 16) # for i in range(self.contactsAmount):
for i in range(self.contactsAmount): # # create the testNodeIDs in ascending order, away from the actual node ID,
# create the testNodeIDs in ascending order, away from the actual node ID, # # with regards to the distance metric
# with regards to the distance metric # self.testNodeIDs.append(str("%X" % (idNum + i + 1)).decode('hex'))
self.testNodeIDs.append(str("%X" % (idNum + i + 1)).decode('hex')) #
# # generate contacts
# generate contacts # self.contacts = []
self.contacts = [] # for i in range(self.contactsAmount):
for i in range(self.contactsAmount): # contact = self.node.contact_manager.make_contact(self.testNodeIDs[i], "127.0.0.1",
contact = lbrynet.dht.contact.Contact(self.testNodeIDs[i], "127.0.0.1", # self.updPort + i + 1, self._protocol)
self.updPort + i + 1, self._protocol) # self.contacts.append(contact)
self.contacts.append(contact) #
# # create the network of contacts in format: (contact, closest contacts)
# create the network of contacts in format: (contact, closest contacts) # contactNetwork = ((self.contacts[0], self.contacts[8:15]),
contactNetwork = ((self.contacts[0], self.contacts[8:15]), # (self.contacts[1], self.contacts[16:23]),
(self.contacts[1], self.contacts[16:23]), # (self.contacts[2], self.contacts[24:31]),
(self.contacts[2], self.contacts[24:31]), # (self.contacts[3], self.contacts[32:39]),
(self.contacts[3], self.contacts[32:39]), # (self.contacts[4], self.contacts[40:47]),
(self.contacts[4], self.contacts[40:47]), # (self.contacts[5], self.contacts[48:55]),
(self.contacts[5], self.contacts[48:55]), # (self.contacts[6], self.contacts[56:63]),
(self.contacts[6], self.contacts[56:63]), # (self.contacts[7], self.contacts[64:71]),
(self.contacts[7], self.contacts[64:71]), # (self.contacts[8], self.contacts[72:79]),
(self.contacts[8], self.contacts[72:79]), # (self.contacts[40], self.contacts[41:48]),
(self.contacts[40], self.contacts[41:48]), # (self.contacts[41], self.contacts[41:48]),
(self.contacts[41], self.contacts[41:48]), # (self.contacts[42], self.contacts[41:48]),
(self.contacts[42], self.contacts[41:48]), # (self.contacts[43], self.contacts[41:48]),
(self.contacts[43], self.contacts[41:48]), # (self.contacts[44], self.contacts[41:48]),
(self.contacts[44], self.contacts[41:48]), # (self.contacts[45], self.contacts[41:48]),
(self.contacts[45], self.contacts[41:48]), # (self.contacts[46], self.contacts[41:48]),
(self.contacts[46], self.contacts[41:48]), # (self.contacts[47], self.contacts[41:48]),
(self.contacts[47], self.contacts[41:48]), # (self.contacts[48], self.contacts[41:48]),
(self.contacts[48], self.contacts[41:48]), # (self.contacts[50], self.contacts[0:7]),
(self.contacts[50], self.contacts[0:7]), # (self.contacts[51], self.contacts[8:15]),
(self.contacts[51], self.contacts[8:15]), # (self.contacts[52], self.contacts[16:23]))
(self.contacts[52], self.contacts[16:23])) #
# contacts_with_datastores = []
contacts_with_datastores = [] #
# for contact_tuple in contactNetwork:
for contact_tuple in contactNetwork: # contacts_with_datastores.append((contact_tuple[0], contact_tuple[1],
contacts_with_datastores.append((contact_tuple[0], contact_tuple[1], # DictDataStore()))
lbrynet.dht.datastore.DictDataStore())) # self._protocol.createNetwork(contacts_with_datastores)
self._protocol.createNetwork(contacts_with_datastores) #
# # @defer.inlineCallbacks
@defer.inlineCallbacks # # def testNodeBootStrap(self):
def testNodeBootStrap(self): # # """ Test bootstrap with the closest possible contacts """
""" Test bootstrap with the closest possible contacts """ # # # Set the expected result
# # expectedResult = {item.id for item in self.contacts[0:8]}
activeContacts = yield self.node._iterativeFind(self.node.node_id, self.contacts[0:8]) # #
# Set the expected result # # activeContacts = yield self.node._iterativeFind(self.node.node_id, self.contacts[0:8])
expectedResult = set() # #
for item in self.contacts[0:6]: # # # Check the length of the active contacts
expectedResult.add(item.id) # # self.failUnlessEqual(activeContacts.__len__(), expectedResult.__len__(),
# Get the result from the deferred # # "More active contacts should exist, there should be %d "
# # "contacts but there are %d" % (len(expectedResult),
# Check the length of the active contacts # # len(activeContacts)))
self.failUnlessEqual(activeContacts.__len__(), expectedResult.__len__(), # #
"More active contacts should exist, there should be %d " # # # Check that the received active contacts are the same as the input contacts
"contacts but there are %d" % (len(expectedResult), # # self.failUnlessEqual({contact.id for contact in activeContacts}, expectedResult,
len(activeContacts))) # # "Active should only contain the closest possible contacts"
# # " which were used as input for the boostrap")
# Check that the received active contacts are the same as the input contacts
self.failUnlessEqual({contact.id for contact in activeContacts}, expectedResult,
"Active should only contain the closest possible contacts"
" which were used as input for the boostrap")

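`NodeDataTest.testStore` in this file expects each stored value to be the contact's compact IP, followed by the port packed with `struct.pack('>H', ...)`, followed by the contact ID. The sketch below builds a value with that layout. It assumes `compact_ip()` returns the 4-byte packed IPv4 address, which this diff does not show, and `compact_peer_sketch` is a hypothetical name rather than part of `lbrynet.dht`.

```python
import socket
import struct


def compact_peer_sketch(ip, port, node_id):
    """Return packed IP + big-endian 16-bit port + node ID as bytes.

    Hypothetical helper; assumes the compact IP is the 4-byte form
    produced by socket.inet_aton.
    """
    packed_ip = socket.inet_aton(ip)      # 4 bytes, e.g. b'\x7f\x00\x00\x01'
    packed_port = struct.pack(">H", port)  # 2 bytes, big-endian unsigned short
    return packed_ip + packed_port + node_id


# A 48-byte ID matches the sha384 digests used for node IDs in these tests.
example = compact_peer_sketch("127.0.0.1", 12345, b"\x01" * 48)
```

The `>H` format only holds values from 0 to 65535, so a port outside that range raises `struct.error` instead of being packed.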
@ -1,200 +0,0 @@
import time
import unittest
from twisted.internet.task import Clock
from twisted.internet import defer
import lbrynet.dht.protocol
import lbrynet.dht.contact
import lbrynet.dht.constants
import lbrynet.dht.msgtypes
from lbrynet.dht.error import TimeoutError
from lbrynet.dht.node import Node, rpcmethod
from lbrynet.tests.mocks import listenUDP, resolve
from lbrynet.core.call_later_manager import CallLaterManager
import logging
log = logging.getLogger()
class KademliaProtocolTest(unittest.TestCase):
""" Test case for the Protocol class """
udpPort = 9182
def setUp(self):
self._reactor = Clock()
CallLaterManager.setup(self._reactor.callLater)
self.node = Node(node_id='1' * 48, udpPort=self.udpPort, externalIP="127.0.0.1", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
def tearDown(self):
CallLaterManager.stop()
del self._reactor
@defer.inlineCallbacks
def testReactor(self):
""" Tests if the reactor can start/stop the protocol correctly """
d = defer.Deferred()
self._reactor.callLater(1, d.callback, True)
self._reactor.advance(1)
result = yield d
self.assertTrue(result)
def testRPCTimeout(self):
""" Tests if a RPC message sent to a dead remote node times out correctly """
dead_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
dead_node.start_listening()
dead_node.stop()
self._reactor.pump([1 for _ in range(10)])
dead_contact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(dead_contact)
@rpcmethod
def fake_ping(*args, **kwargs):
time.sleep(lbrynet.dht.constants.rpcTimeout + 1)
return 'pong'
real_ping = self.node.ping
real_timeout = lbrynet.dht.constants.rpcTimeout
real_attempts = lbrynet.dht.constants.rpcAttempts
lbrynet.dht.constants.rpcAttempts = 1
lbrynet.dht.constants.rpcTimeout = 1
self.node.ping = fake_ping
# Make sure the contact was added
self.failIf(dead_contact not in self.node.contacts,
'Contact not added to fake node (error in test code)')
self.node.start_listening()
# Run the PING RPC (which should raise a timeout error)
df = self.node._protocol.sendRPC(dead_contact, 'ping', {})
def check_timeout(err):
self.assertEqual(err.type, TimeoutError)
df.addErrback(check_timeout)
def reset_values():
self.node.ping = real_ping
lbrynet.dht.constants.rpcTimeout = real_timeout
lbrynet.dht.constants.rpcAttempts = real_attempts
# See if the contact was removed due to the timeout
def check_removed_contact():
self.failIf(dead_contact in self.node.contacts,
'Contact was not removed after RPC timeout; check exception types.')
df.addCallback(lambda _: reset_values())
# Stop the reactor if a result arrives (timeout or not)
df.addCallback(lambda _: check_removed_contact())
self._reactor.pump([1 for _ in range(20)])
def testRPCRequest(self):
""" Tests if a valid RPC request is executed and responded to correctly """
remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
remote_node.start_listening()
remoteContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(remoteContact)
self.error = None
def handleError(f):
self.error = 'An RPC error occurred: %s' % f.getErrorMessage()
def handleResult(result):
expectedResult = 'pong'
if result != expectedResult:
self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' \
% (expectedResult, result)
# Publish the "local" node on the network
self.node.start_listening()
# Simulate the RPC
df = remoteContact.ping()
df.addCallback(handleResult)
df.addErrback(handleError)
for _ in range(10):
self._reactor.advance(1)
self.failIf(self.error, self.error)
# The list of sent RPC messages should be empty at this stage
self.failUnlessEqual(len(self.node._protocol._sentMessages), 0,
'The protocol is still waiting for a RPC result, '
'but the transaction is already done!')
def testRPCAccess(self):
""" Tests invalid RPC requests
Verifies that an RPC request for an existing but unpublished
method is denied, and that the associated (remote) exception gets
raised locally """
remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
remote_node.start_listening()
remote_contact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(remote_contact)
self.error = None
def handleError(f):
try:
f.raiseException()
except AttributeError, e:
# This is the expected outcome since the remote node did not publish the method
self.error = None
except Exception, e:
self.error = 'The remote method failed, but the wrong exception was raised; ' \
'expected AttributeError, got %s' % type(e)
def handleResult(result):
self.error = 'The remote method executed successfully, returning: "%s"; ' \
'this RPC should not have been allowed.' % result
self.node.start_listening()
self._reactor.pump([1 for _ in range(10)])
# Simulate the RPC
df = remote_contact.not_a_rpc_function()
df.addCallback(handleResult)
df.addErrback(handleError)
self._reactor.pump([1 for _ in range(10)])
self.failIf(self.error, self.error)
# The list of sent RPC messages should be empty at this stage
self.failUnlessEqual(len(self.node._protocol._sentMessages), 0,
'The protocol is still waiting for a RPC result, '
'but the transaction is already done!')
def testRPCRequestArgs(self):
""" Tests if an RPC requiring arguments is executed correctly """
remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
remote_node.start_listening()
remote_contact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.node.addContact(remote_contact)
self.error = None
def handleError(f):
self.error = 'An RPC error occurred: %s' % f.getErrorMessage()
def handleResult(result):
expectedResult = 'pong'
if result != expectedResult:
self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' % \
(expectedResult, result)
# Publish the "local" node on the network
self.node.start_listening()
# Simulate the RPC
df = remote_contact.ping()
df.addCallback(handleResult)
df.addErrback(handleError)
self._reactor.pump([1 for _ in range(10)])
self.failIf(self.error, self.error)
# The list of sent RPC messages should be empty at this stage
self.failUnlessEqual(len(self.node._protocol._sentMessages), 0,
'The protocol is still waiting for a RPC result, '
'but the transaction is already done!')


@ -1,29 +1,16 @@
import hashlib import hashlib
import unittest from twisted.trial import unittest
from twisted.internet import defer
import lbrynet.dht.constants from lbrynet.dht import constants
import lbrynet.dht.routingtable from lbrynet.dht.routingtable import TreeRoutingTable
import lbrynet.dht.contact from lbrynet.dht.contact import ContactManager
import lbrynet.dht.node from lbrynet.dht.distance import Distance
import lbrynet.dht.distance
class FakeRPCProtocol(object): class FakeRPCProtocol(object):
""" Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """ """ Fake RPC protocol; allows lbrynet.dht.contact.Contact objects to "send" RPCs """
def sendRPC(self, *args, **kwargs): def sendRPC(self, *args, **kwargs):
return FakeDeferred() return defer.succeed(None)
class FakeDeferred(object):
""" Fake Twisted Deferred object; allows the routing table to add callbacks that do nothing """
def addCallback(self, *args, **kwargs):
return
def addErrback(self, *args, **kwargs):
return
def addCallbacks(self, *args, **kwargs):
return
class TreeRoutingTableTest(unittest.TestCase): class TreeRoutingTableTest(unittest.TestCase):
@ -31,97 +18,94 @@ class TreeRoutingTableTest(unittest.TestCase):
def setUp(self): def setUp(self):
h = hashlib.sha384() h = hashlib.sha384()
h.update('node1') h.update('node1')
self.contact_manager = ContactManager()
self.nodeID = h.digest() self.nodeID = h.digest()
self.protocol = FakeRPCProtocol() self.protocol = FakeRPCProtocol()
self.routingTable = lbrynet.dht.routingtable.TreeRoutingTable(self.nodeID) self.routingTable = TreeRoutingTable(self.nodeID)
def testDistance(self): def testDistance(self):
""" Test to see if distance method returns correct result""" """ Test to see if distance method returns correct result"""
# testList holds a couple 3-tuple (variable1, variable2, result) # testList holds a couple 3-tuple (variable1, variable2, result)
basicTestList = [('123456789', '123456789', 0L), ('12345', '98765', 34527773184L)] basicTestList = [(chr(170) * 48, chr(85) * 48, long((chr(255) * 48).encode('hex'), 16))]
for test in basicTestList: for test in basicTestList:
result = lbrynet.dht.distance.Distance(test[0])(test[1]) result = Distance(test[0])(test[1])
self.failIf(result != test[2], 'Result of _distance() should be %s but %s returned' % self.failIf(result != test[2], 'Result of _distance() should be %s but %s returned' %
(test[2], result)) (test[2], result))
baseIp = '146.64.19.111' @defer.inlineCallbacks
ipTestList = ['146.64.29.222', '192.68.19.333']
distanceOne = lbrynet.dht.distance.Distance(baseIp)(ipTestList[0])
distanceTwo = lbrynet.dht.distance.Distance(baseIp)(ipTestList[1])
self.failIf(distanceOne > distanceTwo, '%s should be closer to the base ip %s than %s' %
(ipTestList[0], baseIp, ipTestList[1]))
def testAddContact(self): def testAddContact(self):
""" Tests if a contact can be added and retrieved correctly """ """ Tests if a contact can be added and retrieved correctly """
# Create the contact # Create the contact
h = hashlib.sha384() h = hashlib.sha384()
h.update('node2') h.update('node2')
contactID = h.digest() contactID = h.digest()
contact = lbrynet.dht.contact.Contact(contactID, '127.0.0.1', 91824, self.protocol) contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
# Now add it... # Now add it...
self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
# ...and request the closest nodes to it (will retrieve it) # ...and request the closest nodes to it (will retrieve it)
closestNodes = self.routingTable.findCloseNodes(contactID, lbrynet.dht.constants.k) closestNodes = self.routingTable.findCloseNodes(contactID, constants.k)
self.failUnlessEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; expected 1,' self.failUnlessEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; expected 1,'
' got %d' % len(closestNodes)) ' got %d' % len(closestNodes))
self.failUnless(contact in closestNodes, 'Added contact not found by issuing ' self.failUnless(contact in closestNodes, 'Added contact not found by issuing '
'_findCloseNodes()') '_findCloseNodes()')
@defer.inlineCallbacks
def testGetContact(self): def testGetContact(self):
""" Tests if a specific existing contact can be retrieved correctly """ """ Tests if a specific existing contact can be retrieved correctly """
h = hashlib.sha384() h = hashlib.sha384()
h.update('node2') h.update('node2')
contactID = h.digest() contactID = h.digest()
contact = lbrynet.dht.contact.Contact(contactID, '127.0.0.1', 91824, self.protocol) contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
# Now add it... # Now add it...
self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
# ...and get it again # ...and get it again
sameContact = self.routingTable.getContact(contactID) sameContact = self.routingTable.getContact(contactID)
self.failUnlessEqual(contact, sameContact, 'getContact() should return the same contact') self.failUnlessEqual(contact, sameContact, 'getContact() should return the same contact')
@defer.inlineCallbacks
def testAddParentNodeAsContact(self): def testAddParentNodeAsContact(self):
""" """
Tests the routing table's behaviour when attempting to add its parent node as a contact Tests the routing table's behaviour when attempting to add its parent node as a contact
""" """
# Create a contact with the same ID as the local node's ID # Create a contact with the same ID as the local node's ID
contact = lbrynet.dht.contact.Contact(self.nodeID, '127.0.0.1', 91824, self.protocol) contact = self.contact_manager.make_contact(self.nodeID, '127.0.0.1', 9182, self.protocol)
# Now try to add it # Now try to add it
self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
# ...and request the closest nodes to it using FIND_NODE # ...and request the closest nodes to it using FIND_NODE
closestNodes = self.routingTable.findCloseNodes(self.nodeID, lbrynet.dht.constants.k) closestNodes = self.routingTable.findCloseNodes(self.nodeID, constants.k)
self.failIf(contact in closestNodes, 'Node added itself as a contact') self.failIf(contact in closestNodes, 'Node added itself as a contact')
@defer.inlineCallbacks
def testRemoveContact(self): def testRemoveContact(self):
""" Tests contact removal """ """ Tests contact removal """
# Create the contact # Create the contact
h = hashlib.sha384() h = hashlib.sha384()
h.update('node2') h.update('node2')
contactID = h.digest() contactID = h.digest()
contact = lbrynet.dht.contact.Contact(contactID, '127.0.0.1', 91824, self.protocol) contact = self.contact_manager.make_contact(contactID, '127.0.0.1', 9182, self.protocol)
# Now add it... # Now add it...
self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
# Verify addition # Verify addition
self.failUnlessEqual(len(self.routingTable._buckets[0]), 1, 'Contact not added properly') self.failUnlessEqual(len(self.routingTable._buckets[0]), 1, 'Contact not added properly')
# Now remove it # Now remove it
self.routingTable.removeContact(contact.id) self.routingTable.removeContact(contact)
self.failUnlessEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly') self.failUnlessEqual(len(self.routingTable._buckets[0]), 0, 'Contact not removed properly')
@defer.inlineCallbacks
def testSplitBucket(self): def testSplitBucket(self):
""" Tests if the routing table correctly dynamically splits k-buckets """ """ Tests if the routing table correctly dynamically splits k-buckets """
self.failUnlessEqual(self.routingTable._buckets[0].rangeMax, 2**384, self.failUnlessEqual(self.routingTable._buckets[0].rangeMax, 2**384,
'Initial k-bucket range should be 0 <= range < 2**384') 'Initial k-bucket range should be 0 <= range < 2**384')
# Add k contacts # Add k contacts
for i in range(lbrynet.dht.constants.k): for i in range(constants.k):
h = hashlib.sha384() h = hashlib.sha384()
h.update('remote node %d' % i) h.update('remote node %d' % i)
nodeID = h.digest() nodeID = h.digest()
contact = lbrynet.dht.contact.Contact(nodeID, '127.0.0.1', 91824, self.protocol) contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
self.failUnlessEqual(len(self.routingTable._buckets), 1, self.failUnlessEqual(len(self.routingTable._buckets), 1,
'Only k nodes have been added; the first k-bucket should now ' 'Only k nodes have been added; the first k-bucket should now '
'be full, but should not yet be split') 'be full, but should not yet be split')
@ -129,8 +113,8 @@ class TreeRoutingTableTest(unittest.TestCase):
h = hashlib.sha384() h = hashlib.sha384()
h.update('yet another remote node') h.update('yet another remote node')
nodeID = h.digest() nodeID = h.digest()
contact = lbrynet.dht.contact.Contact(nodeID, '127.0.0.1', 91824, self.protocol) contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
self.routingTable.addContact(contact) yield self.routingTable.addContact(contact)
self.failUnlessEqual(len(self.routingTable._buckets), 2, self.failUnlessEqual(len(self.routingTable._buckets), 2,
'k+1 nodes have been added; the first k-bucket should have been ' 'k+1 nodes have been added; the first k-bucket should have been '
'split into two new buckets') 'split into two new buckets')
@ -144,99 +128,102 @@ class TreeRoutingTableTest(unittest.TestCase):
'K-bucket was split, but the min/max ranges were ' 'K-bucket was split, but the min/max ranges were '
'not divided properly') 'not divided properly')
def testFullBucketNoSplit(self): @defer.inlineCallbacks
def testFullSplit(self):
""" """
Test that a bucket is not split if it full, but does not cover the range Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
containing the parent node's ID
""" """
self.routingTable._parentNodeID = 49 * 'a'
# more than 384 bits; this will not be in the range of _any_ k-bucket self.routingTable._parentNodeID = 48 * chr(255)
node_ids = [
"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"500000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"ff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
]
# Add k contacts # Add k contacts
for i in range(lbrynet.dht.constants.k): for nodeID in node_ids:
h = hashlib.sha384() # self.assertEquals(nodeID, node_ids[i].decode('hex'))
h.update('remote node %d' % i) contact = self.contact_manager.make_contact(nodeID.decode('hex'), '127.0.0.1', 9182, self.protocol)
nodeID = h.digest() yield self.routingTable.addContact(contact)
contact = lbrynet.dht.contact.Contact(nodeID, '127.0.0.1', 91824, self.protocol) self.failUnlessEqual(len(self.routingTable._buckets), 2)
self.routingTable.addContact(contact) self.failUnlessEqual(len(self.routingTable._buckets[0]._contacts), 8)
self.failUnlessEqual(len(self.routingTable._buckets), 1, 'Only k nodes have been added; ' self.failUnlessEqual(len(self.routingTable._buckets[1]._contacts), 2)
'the first k-bucket should now be '
'full, and there should not be ' # try adding a contact who is further from us than the k'th known contact
'more than 1 bucket') nodeID = '020000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'
self.failUnlessEqual(len(self.routingTable._buckets[0]._contacts), lbrynet.dht.constants.k, nodeID = nodeID.decode('hex')
'Bucket should have k contacts; expected %d got %d' % contact = self.contact_manager.make_contact(nodeID, '127.0.0.1', 9182, self.protocol)
(lbrynet.dht.constants.k, self.assertFalse(self.routingTable._shouldSplit(self.routingTable._kbucketIndex(contact.id), contact.id))
len(self.routingTable._buckets[0]._contacts))) yield self.routingTable.addContact(contact)
# Now add 1 more contact self.failUnlessEqual(len(self.routingTable._buckets), 2)
h = hashlib.sha384() self.failUnlessEqual(len(self.routingTable._buckets[0]._contacts), 8)
h.update('yet another remote node') self.failUnlessEqual(len(self.routingTable._buckets[1]._contacts), 2)
nodeID = h.digest() self.failIf(contact in self.routingTable._buckets[0]._contacts)
contact = lbrynet.dht.contact.Contact(nodeID, '127.0.0.1', 91824, self.protocol) self.failIf(contact in self.routingTable._buckets[1]._contacts)
self.routingTable.addContact(contact)
self.failUnlessEqual(len(self.routingTable._buckets), 1,
'There should not be more than 1 bucket, since the bucket '
'should not have been split (parent node ID not in range)')
self.failUnlessEqual(len(self.routingTable._buckets[0]._contacts),
lbrynet.dht.constants.k, 'Bucket should have k contacts; '
'expected %d got %d' %
(lbrynet.dht.constants.k,
len(self.routingTable._buckets[0]._contacts)))
self.failIf(contact in self.routingTable._buckets[0]._contacts,
'New contact should have been discarded (since RPC is faked in this test)')
class KeyErrorFixedTest(unittest.TestCase): # class KeyErrorFixedTest(unittest.TestCase):
""" Basic tests case for boolean operators on the Contact class """ # """ Basic tests case for boolean operators on the Contact class """
#
def setUp(self): # def setUp(self):
own_id = (2 ** lbrynet.dht.constants.key_bits) - 1 # own_id = (2 ** constants.key_bits) - 1
# carefully chosen own_id. here's the logic # # carefully chosen own_id. here's the logic
# we want a bunch of buckets (k+1, to be exact), and we want to make sure own_id # # we want a bunch of buckets (k+1, to be exact), and we want to make sure own_id
# is not in bucket 0. so we put own_id at the end so we can keep splitting by adding to the # # is not in bucket 0. so we put own_id at the end so we can keep splitting by adding to the
# end # # end
#
self.table = lbrynet.dht.routingtable.OptimizedTreeRoutingTable(own_id) # self.table = lbrynet.dht.routingtable.OptimizedTreeRoutingTable(own_id)
#
def fill_bucket(self, bucket_min): # def fill_bucket(self, bucket_min):
bucket_size = lbrynet.dht.constants.k # bucket_size = lbrynet.dht.constants.k
for i in range(bucket_min, bucket_min + bucket_size): # for i in range(bucket_min, bucket_min + bucket_size):
self.table.addContact(lbrynet.dht.contact.Contact(long(i), '127.0.0.1', 9999, None)) # self.table.addContact(lbrynet.dht.contact.Contact(long(i), '127.0.0.1', 9999, None))
#
def overflow_bucket(self, bucket_min): # def overflow_bucket(self, bucket_min):
bucket_size = lbrynet.dht.constants.k # bucket_size = lbrynet.dht.constants.k
self.fill_bucket(bucket_min) # self.fill_bucket(bucket_min)
self.table.addContact( # self.table.addContact(
lbrynet.dht.contact.Contact(long(bucket_min + bucket_size + 1), # lbrynet.dht.contact.Contact(long(bucket_min + bucket_size + 1),
'127.0.0.1', 9999, None)) # '127.0.0.1', 9999, None))
#
def testKeyError(self): # def testKeyError(self):
#
# find middle, so we know where bucket will split # # find middle, so we know where bucket will split
bucket_middle = self.table._buckets[0].rangeMax / 2 # bucket_middle = self.table._buckets[0].rangeMax / 2
#
# fill last bucket # # fill last bucket
self.fill_bucket(self.table._buckets[0].rangeMax - lbrynet.dht.constants.k - 1) # self.fill_bucket(self.table._buckets[0].rangeMax - lbrynet.dht.constants.k - 1)
# -1 in previous line because own_id is in last bucket # # -1 in previous line because own_id is in last bucket
#
# fill/overflow 7 more buckets # # fill/overflow 7 more buckets
bucket_start = 0 # bucket_start = 0
for i in range(0, lbrynet.dht.constants.k): # for i in range(0, lbrynet.dht.constants.k):
self.overflow_bucket(bucket_start) # self.overflow_bucket(bucket_start)
bucket_start += bucket_middle / (2 ** i) # bucket_start += bucket_middle / (2 ** i)
#
# replacement cache now has k-1 entries. # # replacement cache now has k-1 entries.
# adding one more contact to bucket 0 used to cause a KeyError, but it should work # # adding one more contact to bucket 0 used to cause a KeyError, but it should work
self.table.addContact( # self.table.addContact(
lbrynet.dht.contact.Contact(long(lbrynet.dht.constants.k + 2), '127.0.0.1', 9999, None)) # lbrynet.dht.contact.Contact(long(lbrynet.dht.constants.k + 2), '127.0.0.1', 9999, None))
#
# import math # # import math
# print "" # # print ""
# for i, bucket in enumerate(self.table._buckets): # # for i, bucket in enumerate(self.table._buckets):
# print "Bucket " + str(i) + " (2 ** " + str( # # print "Bucket " + str(i) + " (2 ** " + str(
# math.log(bucket.rangeMin, 2) if bucket.rangeMin > 0 else 0) + " <= x < 2 ** "+str( # # math.log(bucket.rangeMin, 2) if bucket.rangeMin > 0 else 0) + " <= x < 2 ** "+str(
# math.log(bucket.rangeMax, 2)) + ")" # # math.log(bucket.rangeMax, 2)) + ")"
# for c in bucket.getContacts(): # # for c in bucket.getContacts():
# print " contact " + str(c.id) # # print " contact " + str(c.id)
# for key, bucket in self.table._replacementCache.iteritems(): # # for key, bucket in self.table._replacementCache.iteritems():
# print "Replacement Cache for Bucket " + str(key) # # print "Replacement Cache for Bucket " + str(key)
# for c in bucket: # # for c in bucket:
# print " contact " + str(c.id) # # print " contact " + str(c.id)


@ -1,16 +1,22 @@
import mock import mock
import json import json
import unittest import unittest
import random
from os import path
from twisted.internet import defer from twisted.internet import defer
from twisted import trial from twisted import trial
from faker import Faker
from lbryschema.decode import smart_decode from lbryschema.decode import smart_decode
from lbryum.wallet import NewWallet from lbryum.wallet import NewWallet
from lbrynet import conf from lbrynet import conf
from lbrynet.core import Session, PaymentRateManager, Wallet from lbrynet.core import Session, PaymentRateManager, Wallet
from lbrynet.database.storage import SQLiteStorage from lbrynet.database.storage import SQLiteStorage
from lbrynet.daemon.Daemon import Daemon as LBRYDaemon from lbrynet.daemon.Daemon import Daemon as LBRYDaemon
from lbrynet.file_manager.EncryptedFileManager import EncryptedFileManager
from lbrynet.file_manager.EncryptedFileDownloader import ManagedEncryptedFileDownloader
from lbrynet.tests import util from lbrynet.tests import util
from lbrynet.tests.mocks import mock_conf_settings, FakeNetwork from lbrynet.tests.mocks import mock_conf_settings, FakeNetwork
@ -126,3 +132,172 @@ class TestJsonRpc(trial.unittest.TestCase):
d = defer.maybeDeferred(self.test_daemon.jsonrpc_help, command='status') d = defer.maybeDeferred(self.test_daemon.jsonrpc_help, command='status')
d.addCallback(lambda result: self.assertSubstring('daemon status', result['help'])) d.addCallback(lambda result: self.assertSubstring('daemon status', result['help']))
# self.assertSubstring('daemon status', d.result) # self.assertSubstring('daemon status', d.result)
class TestFileListSorting(trial.unittest.TestCase):
def setUp(self):
mock_conf_settings(self)
util.resetTime(self)
self.faker = Faker('en_US')
self.faker.seed(66410)
self.test_daemon = get_test_daemon()
self.test_daemon.lbry_file_manager = mock.Mock(spec=EncryptedFileManager)
self.test_daemon.lbry_file_manager.lbry_files = self._get_fake_lbry_files()
# Pre-sorted lists of prices, file names and authors in ascending order, produced by
# faker with seed 66410. This seed was chosen because it produces 3 results with
# 'points_paid' at 6.0 and 2 results at 4.5, to test multiple sort criteria.
self.test_points_paid = [0.2, 2.9, 4.5, 4.5, 6.0, 6.0, 6.0, 6.8, 7.1, 9.2]
self.test_file_names = ['also.mp3', 'better.css', 'call.mp3', 'pay.jpg',
'record.pages', 'sell.css', 'strategy.pages',
'thousand.pages', 'town.mov', 'vote.ppt']
self.test_authors = ['angela41', 'edward70', 'fhart', 'johnrosales',
'lucasfowler', 'peggytorres', 'qmitchell',
'trevoranderson', 'xmitchell', 'zhangsusan']
def test_sort_by_points_paid_no_direction_specified(self):
sort_options = ['points_paid']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertEquals(self.test_points_paid, [f['points_paid'] for f in file_list])
def test_sort_by_points_paid_ascending(self):
sort_options = ['points_paid,asc']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertEquals(self.test_points_paid, [f['points_paid'] for f in file_list])
def test_sort_by_points_paid_descending(self):
sort_options = ['points_paid, desc']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertEquals(list(reversed(self.test_points_paid)), [f['points_paid'] for f in file_list])
def test_sort_by_file_name_no_direction_specified(self):
sort_options = ['file_name']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertEquals(self.test_file_names, [f['file_name'] for f in file_list])
def test_sort_by_file_name_ascending(self):
sort_options = ['file_name,\nasc']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertEquals(self.test_file_names, [f['file_name'] for f in file_list])
def test_sort_by_file_name_descending(self):
sort_options = ['\tfile_name,\n\tdesc']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertEquals(list(reversed(self.test_file_names)), [f['file_name'] for f in file_list])
def test_sort_by_multiple_criteria(self):
expected = ['file_name=record.pages, points_paid=9.2',
'file_name=vote.ppt, points_paid=7.1',
'file_name=strategy.pages, points_paid=6.8',
'file_name=also.mp3, points_paid=6.0',
'file_name=better.css, points_paid=6.0',
'file_name=town.mov, points_paid=6.0',
'file_name=sell.css, points_paid=4.5',
'file_name=thousand.pages, points_paid=4.5',
'file_name=call.mp3, points_paid=2.9',
'file_name=pay.jpg, points_paid=0.2']
format_result = lambda f: 'file_name={}, points_paid={}'.format(f['file_name'], f['points_paid'])
sort_options = ['file_name,asc', 'points_paid,desc']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertEquals(expected, map(format_result, file_list))
# Check that the list is not sorted as expected when sorted only by file_name.
sort_options = ['file_name,asc']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertNotEqual(expected, map(format_result, file_list))
# Check that the list is not sorted as expected when sorted only by points_paid.
sort_options = ['points_paid,desc']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertNotEqual(expected, map(format_result, file_list))
# Check that the list is not sorted as expected when not sorted at all.
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list)
file_list = self.successResultOf(deferred)
self.assertNotEqual(expected, map(format_result, file_list))
def test_sort_by_nested_field(self):
extract_authors = lambda file_list: [f['metadata']['author'] for f in file_list]
sort_options = ['metadata.author']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertEquals(self.test_authors, extract_authors(file_list))
# Check that the list matches the expected in reverse when sorting in descending order.
sort_options = ['metadata.author,desc']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
file_list = self.successResultOf(deferred)
self.assertEquals(list(reversed(self.test_authors)), extract_authors(file_list))
# Check that the list is not sorted as expected when not sorted at all.
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list)
file_list = self.successResultOf(deferred)
self.assertNotEqual(self.test_authors, extract_authors(file_list))
def test_invalid_sort_produces_meaningful_errors(self):
sort_options = ['meta.author']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
failure_assertion = self.assertFailure(deferred, Exception)
exception = self.successResultOf(failure_assertion)
expected_message = 'Failed to get "meta.author", key "meta" was not found.'
self.assertEquals(expected_message, exception.message)
sort_options = ['metadata.foo.bar']
deferred = defer.maybeDeferred(self.test_daemon.jsonrpc_file_list, sort=sort_options)
failure_assertion = self.assertFailure(deferred, Exception)
exception = self.successResultOf(failure_assertion)
expected_message = 'Failed to get "metadata.foo.bar", key "foo" was not found.'
self.assertEquals(expected_message, exception.message)
def _get_fake_lbry_files(self):
return [self._get_fake_lbry_file() for _ in range(10)]
def _get_fake_lbry_file(self):
lbry_file = mock.Mock(spec=ManagedEncryptedFileDownloader)
file_path = self.faker.file_path()
stream_name = self.faker.file_name()
channel_claim_id = self.faker.sha1()
channel_name = self.faker.simple_profile()['username']
faked_attributes = {
'channel_claim_id': channel_claim_id,
'channel_name': '@' + channel_name,
'claim_id': self.faker.sha1(),
'claim_name': '-'.join(self.faker.words(4)),
'completed': self.faker.boolean(),
'download_directory': path.dirname(file_path),
'download_path': file_path,
'file_name': path.basename(file_path),
'key': self.faker.md5(),
'metadata': {
'author': channel_name,
'nsfw': random.randint(0, 1) == 1,
},
'mime_type': self.faker.mime_type(),
'nout': abs(self.faker.pyint()),
'outpoint': self.faker.md5() + self.faker.md5(),
'points_paid': self.faker.pyfloat(left_digits=1, right_digits=1, positive=True),
'sd_hash': self.faker.md5() + self.faker.md5() + self.faker.md5(),
'stopped': self.faker.boolean(),
'stream_hash': self.faker.md5() + self.faker.md5() + self.faker.md5(),
'stream_name': stream_name,
'suggested_file_name': stream_name,
'txid': self.faker.md5() + self.faker.md5(),
'written_bytes': self.faker.pyint(),
}
for key in faked_attributes:
setattr(lbry_file, key, faked_attributes[key])
return lbry_file


@ -0,0 +1,43 @@
import unittest

from lbrynet.daemon.Daemon import sort_claim_results


class ClaimsComparatorTest(unittest.TestCase):
    def test_sort_claim_results_when_sorted_by_claim_id(self):
        results = [{"height": 1, "name": "res", "claim_id": "ccc", "nout": 0, "txid": "fdsafa"},
                   {"height": 1, "name": "res", "claim_id": "aaa", "nout": 0, "txid": "w5tv8uorgt"},
                   {"height": 1, "name": "res", "claim_id": "bbb", "nout": 0, "txid": "aecfaewcfa"}]
        self.run_test(results, 'claim_id', ['aaa', 'bbb', 'ccc'])

    def test_sort_claim_results_when_sorted_by_height(self):
        results = [{"height": 1, "name": "res", "claim_id": "ccc", "nout": 0, "txid": "aecfaewcfa"},
                   {"height": 3, "name": "res", "claim_id": "ccc", "nout": 0, "txid": "aecfaewcfa"},
                   {"height": 2, "name": "res", "claim_id": "ccc", "nout": 0, "txid": "aecfaewcfa"}]
        self.run_test(results, 'height', [1, 2, 3])

    def test_sort_claim_results_when_sorted_by_name(self):
        results = [{"height": 1, "name": "res1", "claim_id": "ccc", "nout": 0, "txid": "aecfaewcfa"},
                   {"height": 1, "name": "res3", "claim_id": "ccc", "nout": 0, "txid": "aecfaewcfa"},
                   {"height": 1, "name": "res2", "claim_id": "ccc", "nout": 0, "txid": "aecfaewcfa"}]
        self.run_test(results, 'name', ['res1', 'res2', 'res3'])

    def test_sort_claim_results_when_sorted_by_txid(self):
        results = [{"height": 1, "name": "res1", "claim_id": "ccc", "nout": 2, "txid": "111"},
                   {"height": 1, "name": "res1", "claim_id": "ccc", "nout": 1, "txid": "222"},
                   {"height": 1, "name": "res1", "claim_id": "ccc", "nout": 3, "txid": "333"}]
        self.run_test(results, 'txid', ['111', '222', '333'])

    def test_sort_claim_results_when_sorted_by_nout(self):
        results = [{"height": 1, "name": "res1", "claim_id": "ccc", "nout": 2, "txid": "aecfaewcfa"},
                   {"height": 1, "name": "res1", "claim_id": "ccc", "nout": 1, "txid": "aecfaewcfa"},
                   {"height": 1, "name": "res1", "claim_id": "ccc", "nout": 3, "txid": "aecfaewcfa"}]
        self.run_test(results, 'nout', [1, 2, 3])

    def run_test(self, results, field, expected):
        actual = sort_claim_results(results)
        self.assertEqual(expected, [r[field] for r in actual])


if __name__ == '__main__':
    unittest.main()
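
Reviewer note: the tests above only pin down the ordering, not the implementation. A minimal sketch of a comparator that would satisfy them, assuming the sort key is the tuple (height, name, claim_id, txid, nout). That key order is inferred from the test data alone; the real `lbrynet.daemon.Daemon.sort_claim_results` may be written differently.

```python
def sort_claim_results(claims):
    """Return claims ordered by height, name, claim_id, txid, then nout.

    Hypothetical stand-in for the function under test; the key order is an
    assumption inferred from the test cases, not taken from Daemon.py.
    """
    return sorted(
        claims,
        key=lambda c: (c["height"], c["name"], c["claim_id"], c["txid"], c["nout"]),
    )


# Same fixture shape as the claim_id test case.
claims = [
    {"height": 1, "name": "res", "claim_id": "ccc", "nout": 0, "txid": "fdsafa"},
    {"height": 1, "name": "res", "claim_id": "aaa", "nout": 0, "txid": "w5tv8uorgt"},
    {"height": 1, "name": "res", "claim_id": "bbb", "nout": 0, "txid": "aecfaewcfa"},
]
ordered_ids = [c["claim_id"] for c in sort_claim_results(claims)]
```

Note that `txid` has to come before `nout` in the key, otherwise the txid test case (where nout is 2, 1, 3) would not produce '111', '222', '333'.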

@ -82,8 +82,7 @@ class SettingsTest(unittest.TestCase):
     def test_load_save_config_file(self):
         # setup settings
         adjustable_settings = {'data_dir': (str, conf.default_data_dir),
-                               'lbryum_servers': (list, [('localhost', 5001)],
-                                                  conf.server_list, conf.server_list_reverse)}
+                               'lbryum_servers': (list, [])}
         env = conf.Env(**adjustable_settings)
         settings = conf.Config({}, adjustable_settings, environment=env)
         conf.settings = settings

@ -5,19 +5,16 @@ import os
import tempfile import tempfile
import shutil import shutil
import mock import mock
import logging
from lbrynet.dht.encoding import Bencode
from lbrynet.dht.error import DecodeError
from lbrynet.dht.msgformat import DefaultFormat
from lbrynet.dht.msgtypes import ResponseMessage, RequestMessage, ErrorMessage
_encode = Bencode()
_datagram_formatter = DefaultFormat()
DEFAULT_TIMESTAMP = datetime.datetime(2016, 1, 1) DEFAULT_TIMESTAMP = datetime.datetime(2016, 1, 1)
DEFAULT_ISO_TIME = time.mktime(DEFAULT_TIMESTAMP.timetuple()) DEFAULT_ISO_TIME = time.mktime(DEFAULT_TIMESTAMP.timetuple())
<<<<<<< HEAD
log = logging.getLogger("lbrynet.tests.util") log = logging.getLogger("lbrynet.tests.util")
=======
>>>>>>> 2619c396d980944802a3f7e0800e92794cc3de5a
def mk_db_and_blob_dir(): def mk_db_and_blob_dir():
db_dir = tempfile.mkdtemp() db_dir = tempfile.mkdtemp()
@ -48,28 +45,5 @@ def resetTime(test_case, timestamp=DEFAULT_TIMESTAMP):
patcher.start().return_value = timestamp patcher.start().return_value = timestamp
test_case.addCleanup(patcher.stop) test_case.addCleanup(patcher.stop)
def is_android(): def is_android():
return 'ANDROID_ARGUMENT' in os.environ # detect Android using the Kivy way return 'ANDROID_ARGUMENT' in os.environ # detect Android using the Kivy way
def debug_kademlia_packet(data, source, destination, node):
if log.level != logging.DEBUG:
return
try:
packet = _datagram_formatter.fromPrimitive(_encode.decode(data))
if isinstance(packet, RequestMessage):
log.debug("request %s --> %s %s (node time %s)", source[0], destination[0], packet.request,
node.clock.seconds())
elif isinstance(packet, ResponseMessage):
if isinstance(packet.response, (str, unicode)):
log.debug("response %s <-- %s %s (node time %s)", destination[0], source[0], packet.response,
node.clock.seconds())
else:
log.debug("response %s <-- %s %i contacts (node time %s)", destination[0], source[0],
len(packet.response), node.clock.seconds())
elif isinstance(packet, ErrorMessage):
log.error("error %s <-- %s %s (node time %s)", destination[0], source[0], packet.exceptionType,
node.clock.seconds())
except DecodeError:
log.exception("decode error %s --> %s (node time %s)", source[0], destination[0], node.clock.seconds())

@ -1,3 +1,4 @@
+certifi==2018.4.16
 Twisted==16.6.0
 cryptography==2.2.2
 appdirs==1.4.3
@ -12,8 +13,8 @@ GitPython==2.1.3
 jsonrpc==1.2
 jsonrpclib==0.1.7
 keyring==10.4.0
-git+https://github.com/lbryio/lbryschema.git@v0.0.15#egg=lbryschema
-git+https://github.com/lbryio/lbryum.git@v3.2.1#egg=lbryum
+git+https://github.com/lbryio/lbryschema.git@v0.0.16rc2#egg=lbryschema
+git+https://github.com/lbryio/lbryum.git@v3.2.2rc1#egg=lbryum
 miniupnpc==1.9
 pbkdf2==1.3
 pyyaml==3.12

@ -0,0 +1,2 @@
mock>=2.0,<3.0
Faker>=0.8,<1.0

@ -1,92 +0,0 @@
#!/usr/bin/env python
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
#
# Thanks to Paul Cannon for IP-address resolution functions (taken from aspn.activestate.com)
import argparse
import os
import sys
import time
import signal
amount = 0
def destroyNetwork(nodes):
print 'Destroying Kademlia network'
i = 0
for node in nodes:
i += 1
hashAmount = i * 50 / amount
hashbar = '#' * hashAmount
output = '\r[%-50s] %d/%d' % (hashbar, i, amount)
sys.stdout.write(output)
time.sleep(0.15)
os.kill(node, signal.SIGTERM)
print
def main():
parser = argparse.ArgumentParser(description="Launch a network of dht nodes")
parser.add_argument("amount_of_nodes",
help="The number of nodes to create",
type=int)
parser.add_argument(
"--nic_ip_address",
help=("The network interface on which these nodes will listen for connections "
"from each other and from other nodes. If omitted, an attempt will be "
"made to automatically determine the system's IP address, but this may "
"result in the nodes being reachable only from this system"))
args = parser.parse_args()
global amount
amount = args.amount_of_nodes
if args.nic_ip_address:
ipAddress = args.nic_ip_address
else:
import socket
ipAddress = socket.gethostbyname(socket.gethostname())
print 'Network interface IP address omitted; using %s' % ipAddress
startPort = 4000
port = startPort + 1
nodes = []
print 'Creating Kademlia network'
try:
node = os.spawnlp(
os.P_NOWAIT, 'lbrynet-launch-node', 'lbrynet-launch-node', str(startPort))
nodes.append(node)
for i in range(amount - 1):
time.sleep(0.15)
hashAmount = i * 50 / amount
hashbar = '#' * hashAmount
output = '\r[%-50s] %d/%d' % (hashbar, i, amount)
sys.stdout.write(output)
node = os.spawnlp(
os.P_NOWAIT, 'lbrynet-launch-node', 'lbrynet-launch-node', str(port),
ipAddress, str(startPort))
nodes.append(node)
port += 1
except KeyboardInterrupt:
print '\nNetwork creation cancelled.'
destroyNetwork(nodes)
sys.exit(1)
print '\n\n---------------\nNetwork running\n---------------\n'
try:
while 1:
time.sleep(1)
except KeyboardInterrupt:
pass
finally:
destroyNetwork(nodes)
if __name__ == '__main__':
main()

@ -1,7 +1,7 @@
 import curses
 import time
-from jsonrpc.proxy import JSONRPCProxy
 import logging
+from lbrynet.daemon import get_client

 log = logging.getLogger(__name__)
 log.addHandler(logging.FileHandler("dht contacts.log"))
@ -9,7 +9,7 @@ log.addHandler(logging.FileHandler("dht contacts.log"))
 log.setLevel(logging.INFO)
 stdscr = curses.initscr()

-api = JSONRPCProxy.from_url("http://localhost:5279")
+api = get_client()


 def init_curses():
@ -53,7 +53,7 @@ def refresh(last_contacts, last_blobs):
         stdscr.addstr(y, 0, "bucket %s" % i)
         y += 1
         for h in sorted(buckets[i], key=lambda x: x['node_id'].decode('hex')):
-            stdscr.addstr(y, 0, '%s (%s) - %i blobs' % (h['node_id'], h['address'],
+            stdscr.addstr(y, 0, '%s (%s:%i) - %i blobs' % (h['node_id'], h['address'], h['port'],
                                                         len(h['blobs'])))
             y += 1
         y += 1

@ -1,129 +0,0 @@
from lbrynet.core import log_support
import logging.handlers
import sys
import traceback
from lbrynet.dht.node import Node
from twisted.internet import reactor, defer
from lbrynet.core.utils import generate_id
log = logging.getLogger(__name__)
def print_usage():
print "Usage:\n%s UDP_PORT KNOWN_NODE_IP KNOWN_NODE_PORT HASH"
@defer.inlineCallbacks
def join_network(udp_port, known_nodes):
lbryid = generate_id()
log.info('Creating node')
node = Node(udpPort=udp_port, node_id=lbryid)
log.info('Joining network')
yield node.joinNetwork(known_nodes)
defer.returnValue(node)
@defer.inlineCallbacks
def get_hosts(node, h):
log.info("Looking up %s", h)
hosts = yield node.getPeersForBlob(h.decode("hex"))
log.info("Hosts returned from the DHT: %s", hosts)
@defer.inlineCallbacks
def announce_hash(node, h):
results = yield node.announceHaveBlob(h, 34567)
for success, result in results:
if success:
log.info("Succeeded: %s", str(result))
else:
log.info("Failed: %s", str(result.getErrorMessage()))
# def get_args():
# if len(sys.argv) < 5:
# print_usage()
# sys.exit(1)
# udp_port = int(sys.argv[1])
# known_nodes = [(sys.argv[2], int(sys.argv[3]))]
# h = binascii.unhexlify(sys.argv[4])
# return udp_port, known_nodes, h
@defer.inlineCallbacks
def connect(port=None):
try:
if port is None:
raise Exception("need a port")
known_nodes = [('54.236.227.82', 4444)] # lbrynet1
node = yield join_network(port, known_nodes)
log.info("joined")
reactor.callLater(3, find, node)
except Exception:
log.error("CAUGHT EXCEPTION")
traceback.print_exc()
log.info("Stopping reactor")
yield reactor.stop()
def getApproximateTotalDHTNodes(node):
from lbrynet.dht import constants
# get the deepest bucket and the number of contacts in that bucket and multiply it
# by the number of equivalently deep buckets in the whole DHT to get a really bad
# estimate!
bucket = node._routingTable._buckets[node._routingTable._kbucketIndex(node.node_id)]
num_in_bucket = len(bucket._contacts)
factor = (2 ** constants.key_bits) / (bucket.rangeMax - bucket.rangeMin)
return num_in_bucket * factor
def getApproximateTotalHashes(node):
# Divide the number of hashes we know about by k to get a really, really, really
# bad estimate of the average number of hashes per node, then multiply by the
# approximate number of nodes to get a horrendous estimate of the total number
# of hashes in the DHT
num_in_data_store = len(node._dataStore._dict)
if num_in_data_store == 0:
return 0
return num_in_data_store * getApproximateTotalDHTNodes(node) / 8
@defer.inlineCallbacks
def find(node):
try:
log.info("Approximate number of nodes in DHT: %s", str(getApproximateTotalDHTNodes(node)))
log.info("Approximate number of blobs in DHT: %s", str(getApproximateTotalHashes(node)))
h = "578f5e82da7db97bfe0677826d452cc0c65406a8e986c9caa126af4ecdbf4913daad2f7f5d1fb0ffec17d0bf8f187f5a"
peersFake = yield node.getPeersForBlob(h.decode("hex"))
print peersFake
peers = yield node.getPeersForBlob(h.decode("hex"))
print peers
# yield get_hosts(node, h)
except Exception:
log.error("CAUGHT EXCEPTION")
traceback.print_exc()
log.info("Stopping reactor")
yield reactor.stop()
def main():
log_support.configure_console(level='DEBUG')
log_support.configure_twisted()
reactor.callLater(0, connect, port=10001)
log.info("Running reactor")
reactor.run()
if __name__ == '__main__':
sys.exit(main())

@ -0,0 +1,85 @@
import curses
import time
import datetime
from jsonrpc.proxy import JSONRPCProxy

stdscr = curses.initscr()

api = JSONRPCProxy.from_url("http://localhost:5280")


def init_curses():
    curses.noecho()
    curses.cbreak()
    stdscr.nodelay(1)
    stdscr.keypad(1)


def teardown_curses():
    curses.nocbreak()
    stdscr.keypad(0)
    curses.echo()
    curses.endwin()


def refresh(node_index):
    height, width = stdscr.getmaxyx()
    node_ids = api.get_node_ids()
    node_id = node_ids[node_index]
    node_statuses = api.node_status()
    running = node_statuses[node_id]
    buckets = api.node_routing_table(node_id=node_id)

    for y in range(height):
        stdscr.addstr(y, 0, " " * (width - 1))

    stdscr.addstr(0, 0, "node id: %s, running: %s (%i/%i running)" % (node_id, running, sum(node_statuses.values()), len(node_ids)))
    stdscr.addstr(1, 0, "%i buckets, %i contacts" %
                  (len(buckets), sum([len(buckets[b]['contacts']) for b in buckets])))

    y = 3
    for i in sorted(buckets.keys()):
        stdscr.addstr(y, 0, "bucket %s" % i)
        y += 1
        for h in sorted(buckets[i]['contacts'], key=lambda x: x['node_id'].decode('hex')):
            stdscr.addstr(y, 0, '%s (%s:%i) failures: %i, last replied to us: %s, last requested from us: %s' %
                          (h['node_id'], h['address'], h['port'], h['failedRPCs'],
                           datetime.datetime.fromtimestamp(float(h['lastReplied'] or 0)),
                           datetime.datetime.fromtimestamp(float(h['lastRequested'] or 0))))
            y += 1
        y += 1

    stdscr.addstr(y + 1, 0, str(time.time()))
    stdscr.refresh()
    return len(node_ids)


def do_main():
    c = None
    nodes = 1
    node_index = 0
    while c not in [ord('q'), ord('Q')]:
        try:
            nodes = refresh(node_index)
        except:
            pass
        c = stdscr.getch()
        if c == curses.KEY_LEFT:
            node_index -= 1
            node_index = max(node_index, 0)
        elif c == curses.KEY_RIGHT:
            node_index += 1
            node_index = min(node_index, nodes - 1)
        time.sleep(0.1)


def main():
    try:
        init_curses()
        do_main()
    finally:
        teardown_curses()


if __name__ == "__main__":
    main()

@ -1,163 +0,0 @@
#!/usr/bin/env python
#
# This is a basic single-node example of how to use the Entangled
# DHT. It creates a Node and (optionally) joins an existing DHT. It
# then does a Kademlia store and find, and then it deletes the stored
# value (non-Kademlia method).
#
# No tuple space functionality is demonstrated by this script.
#
# To test it properly, start a multi-node Kademlia DHT with the "create_network.py"
# script and point this node to that, e.g.:
# $python create_network.py 10 127.0.0.1
#
# $python basic_example.py 5000 127.0.0.1 4000
#
# This library is free software, distributed under the terms of
# the GNU Lesser General Public License Version 3, or any later version.
# See the COPYING file included in this archive
#
# Thanks to Paul Cannon for IP-address resolution functions (taken from aspn.activestate.com)
import binascii
import random
import twisted.internet.reactor
from lbrynet.dht.node import Node
from lbrynet.core.cryptoutils import get_lbry_hash_obj
# The Entangled DHT node; instantiated in the main() method
node = None
# The key to use for this example when storing/retrieving data
h = get_lbry_hash_obj()
h.update("key")
KEY = h.digest()
# The value to store
VALUE = random.randint(10000, 20000)
lbryid = KEY
def storeValue(key, value):
""" Stores the specified value in the DHT using the specified key """
global node
print '\nStoring value; Key: %s, Value: %s' % (key, value)
# Store the value in the DHT. This method returns a Twisted
# Deferred result, which we then add callbacks to
deferredResult = node.announceHaveHash(key, value)
# Add our callback; this method is called when the operation completes...
deferredResult.addCallback(storeValueCallback)
# ...and for error handling, add an "error callback" as well.
#
# For this example script, I use a generic error handler; usually
# you would need something more specific
deferredResult.addErrback(genericErrorCallback)
def storeValueCallback(*args, **kwargs):
""" Callback function that is invoked when the storeValue() operation succeeds """
print 'Value has been stored in the DHT'
# Now that the value has been stored, schedule that the value is read again after 2.5 seconds
print 'Scheduling retrieval in 2.5 seconds'
twisted.internet.reactor.callLater(2.5, getValue)
def genericErrorCallback(failure):
""" Callback function that is invoked if an error occurs during any of the DHT operations """
print 'An error has occurred:', failure.getErrorMessage()
twisted.internet.reactor.callLater(0, stop)
def getValue():
""" Retrieves the value of the specified key (KEY) from the DHT """
global node, KEY
# Get the value for the specified key (immediately returns a Twisted deferred result)
print ('\nRetrieving value from DHT for key "%s"' %
binascii.unhexlify("f7d9dc4de674eaa2c5a022eb95bc0d33ec2e75c6"))
deferredResult = node.iterativeFindValue(
binascii.unhexlify("f7d9dc4de674eaa2c5a022eb95bc0d33ec2e75c6"))
# Add a callback to this result; this will be called as soon as the operation has completed
deferredResult.addCallback(getValueCallback)
# As before, add the generic error callback
deferredResult.addErrback(genericErrorCallback)
def getValueCallback(result):
""" Callback function that is invoked when the getValue() operation succeeds """
# Check if the key was found (result is a dict of format {key:
# value}) or not (in which case a list of "closest" Kademlia
# contacts would be returned instead")
print "Got the value"
print result
# Either way, schedule a "delete" operation for the key
print 'Scheduling shutdown in 2.5 seconds'
twisted.internet.reactor.callLater(2.5, stop)
def stop():
""" Stops the Twisted reactor, and thus the script """
print '\nStopping Kademlia node and terminating script'
twisted.internet.reactor.stop()
if __name__ == '__main__':
import sys
if len(sys.argv) < 2:
print 'Usage:\n%s UDP_PORT [KNOWN_NODE_IP KNOWN_NODE_PORT]' % sys.argv[0]
print 'or:\n%s UDP_PORT [FILE_WITH_KNOWN_NODES]' % sys.argv[0]
print
print 'If a file is specified, it should contain one IP address and UDP port'
print 'per line, separated by a space.'
sys.exit(1)
try:
int(sys.argv[1])
except ValueError:
print '\nUDP_PORT must be an integer value.\n'
print 'Usage:\n%s UDP_PORT [KNOWN_NODE_IP KNOWN_NODE_PORT]' % sys.argv[0]
print 'or:\n%s UDP_PORT [FILE_WITH_KNOWN_NODES]' % sys.argv[0]
print
print 'If a file is specified, it should contain one IP address and UDP port'
print 'per line, separated by a space.'
sys.exit(1)
if len(sys.argv) == 4:
knownNodes = [(sys.argv[2], int(sys.argv[3]))]
elif len(sys.argv) == 3:
knownNodes = []
f = open(sys.argv[2], 'r')
lines = f.readlines()
f.close()
for line in lines:
ipAddress, udpPort = line.split()
knownNodes.append((ipAddress, int(udpPort)))
else:
knownNodes = None
print '\nNOTE: You have not specified any remote DHT node(s) to connect to'
print 'It will thus not be aware of any existing DHT, but will still function as'
print ' a self-contained DHT (until another node contacts it).'
print 'Run this script without any arguments for info.\n'
# Set up SQLite-based data store (you could use an in-memory store instead, for example)
#
# Create the Entangled node. It extends the functionality of a
# basic Kademlia node (but is fully backwards-compatible with a
# Kademlia-only network)
#
# If you wish to have a pure Kademlia network, use the
# entangled.kademlia.node.Node class instead
print 'Creating Node'
node = Node(udpPort=int(sys.argv[1]), node_id=lbryid)
# Schedule the node to join the Kademlia/Entangled DHT
node.joinNetwork(knownNodes)
# Schedule the "storeValue() call to be invoked after 2.5 seconds,
# using KEY and VALUE as arguments
twisted.internet.reactor.callLater(2.5, getValue)
# Start the Twisted reactor - this fires up all networking, and
# allows the scheduled join operation to take place
print 'Twisted reactor started (script will commence in 2.5 seconds)'
twisted.internet.reactor.run()

@ -1,41 +0,0 @@
"""
CLI for sending rpc commands to a DHT node
"""
import argparse
from twisted.internet import reactor
from txjsonrpc.web.jsonrpc import Proxy
def print_value(value):
print value
def print_error(err):
print err.getErrorMessage()
def shut_down():
reactor.stop()
def main():
parser = argparse.ArgumentParser(description="Send an rpc command to a dht node")
parser.add_argument("rpc_command",
help="The rpc command to send to the dht node")
parser.add_argument("--node_host",
help="The host of the node to connect to",
default="127.0.0.1")
parser.add_argument("--node_port",
help="The port of the node to connect to",
default="8888")
args = parser.parse_args()
connect_string = 'http://%s:%s' % (args.node_host, args.node_port)
proxy = Proxy(connect_string)
d = proxy.callRemote(args.rpc_command)
d.addCallbacks(print_value, print_error)
d.addBoth(lambda _: shut_down())
reactor.run()

@ -1,214 +0,0 @@
import logging
import requests
import miniupnpc
import argparse
from copy import deepcopy
from twisted.internet import reactor, defer
from twisted.web import resource
from twisted.web.server import Site
from lbrynet import conf
from lbrynet.core.log_support import configure_console
from lbrynet.dht.error import TimeoutError
conf.initialize_settings()
log = logging.getLogger("dht tool")
configure_console()
log.setLevel(logging.INFO)
from lbrynet.dht.node import Node
from lbrynet.dht.contact import Contact
from lbrynet.daemon.auth.server import AuthJSONRPCServer
from lbrynet.core.utils import generate_id
def get_external_ip_and_setup_upnp():
try:
u = miniupnpc.UPnP()
u.discoverdelay = 200
u.discover()
u.selectigd()
if u.getspecificportmapping(4444, "UDP"):
u.deleteportmapping(4444, "UDP")
log.info("Removed UPnP redirect for UDP 4444.")
u.addportmapping(4444, 'UDP', u.lanaddr, 4444, 'LBRY DHT port', '')
log.info("got external ip from upnp")
return u.externalipaddress()
except Exception:
log.exception("derp")
r = requests.get('https://api.ipify.org', {'format': 'json'})
log.info("got external ip from ipify.org")
return r.json()['ip']
class NodeRPC(AuthJSONRPCServer):
def __init__(self, lbryid, seeds, node_port, rpc_port):
AuthJSONRPCServer.__init__(self, False)
self.root = None
self.port = None
self.seeds = seeds
self.node_port = node_port
self.rpc_port = rpc_port
if lbryid:
lbryid = lbryid.decode('hex')
else:
lbryid = generate_id()
self.node_id = lbryid
self.external_ip = get_external_ip_and_setup_upnp()
self.node_port = node_port
@defer.inlineCallbacks
def setup(self):
self.node = Node(node_id=self.node_id, udpPort=self.node_port,
externalIP=self.external_ip)
hosts = []
for hostname, hostport in self.seeds:
host_ip = yield reactor.resolve(hostname)
hosts.append((host_ip, hostport))
log.info("connecting to dht")
yield self.node.joinNetwork(tuple(hosts))
log.info("connected to dht")
if not self.announced_startup:
self.announced_startup = True
self.start_api()
log.info("lbry id: %s (%i bytes)", self.node.node_id.encode('hex'), len(self.node.node_id))
def start_api(self):
root = resource.Resource()
root.putChild('', self)
self.port = reactor.listenTCP(self.rpc_port, Site(root), interface='localhost')
log.info("started jsonrpc server")
@defer.inlineCallbacks
def jsonrpc_node_id_set(self, node_id):
old_id = self.node.node_id
self.node.stop()
del self.node
self.node_id = node_id.decode('hex')
yield self.setup()
msg = "changed dht id from %s to %s" % (old_id.encode('hex'),
self.node.node_id.encode('hex'))
defer.returnValue(msg)
def jsonrpc_node_id_get(self):
return self._render_response(self.node.node_id.encode('hex'))
@defer.inlineCallbacks
def jsonrpc_peer_find(self, node_id):
node_id = node_id.decode('hex')
contact = yield self.node.findContact(node_id)
result = None
if contact:
result = (contact.address, contact.port)
defer.returnValue(result)
@defer.inlineCallbacks
def jsonrpc_peer_list_for_blob(self, blob_hash):
peers = yield self.node.getPeersForBlob(blob_hash.decode('hex'))
defer.returnValue(peers)
@defer.inlineCallbacks
def jsonrpc_ping(self, node_id):
contact_host = yield self.jsonrpc_peer_find(node_id=node_id)
if not contact_host:
defer.returnValue("failed to find node")
contact_ip, contact_port = contact_host
contact = Contact(node_id.decode('hex'), contact_ip, contact_port, self.node._protocol)
try:
result = yield contact.ping()
except TimeoutError:
self.node.removeContact(contact.id)
self.node._dataStore.removePeer(contact.id)
result = {'error': 'timeout'}
defer.returnValue(result)
def get_routing_table(self):
result = {}
data_store = deepcopy(self.node._dataStore._dict)
datastore_len = len(data_store)
hosts = {}
missing_contacts = []
if datastore_len:
for k, v in data_store.iteritems():
for value, lastPublished, originallyPublished, originalPublisherID in v:
try:
contact = self.node._routingTable.getContact(originalPublisherID)
except ValueError:
if originalPublisherID.encode('hex') not in missing_contacts:
missing_contacts.append(originalPublisherID.encode('hex'))
continue
if contact in hosts:
blobs = hosts[contact]
else:
blobs = []
blobs.append(k.encode('hex'))
hosts[contact] = blobs
contact_set = []
blob_hashes = []
result['buckets'] = {}
for i in range(len(self.node._routingTable._buckets)):
for contact in self.node._routingTable._buckets[i]._contacts:
contacts = result['buckets'].get(i, [])
if contact in hosts:
blobs = hosts[contact]
del hosts[contact]
else:
blobs = []
host = {
"address": contact.address,
"id": contact.id.encode("hex"),
"blobs": blobs,
}
for blob_hash in blobs:
if blob_hash not in blob_hashes:
blob_hashes.append(blob_hash)
contacts.append(host)
result['buckets'][i] = contacts
contact_set.append(contact.id.encode("hex"))
if hosts:
result['datastore extra'] = [
{
"id": host.id.encode('hex'),
"blobs": hosts[host],
}
for host in hosts]
result['missing contacts'] = missing_contacts
result['contacts'] = contact_set
result['blob hashes'] = blob_hashes
result['node id'] = self.node_id.encode('hex')
return result
def jsonrpc_routing_table_get(self):
return self._render_response(self.get_routing_table())
def main():
parser = argparse.ArgumentParser(description="Launch a dht node which responds to rpc commands")
parser.add_argument("--node_port",
help=("The UDP port on which the node will listen for connections "
"from other dht nodes"),
type=int, default=4444)
parser.add_argument("--rpc_port",
help="The TCP port on which the node will listen for rpc commands",
type=int, default=5280)
parser.add_argument("--bootstrap_host",
help="The IP of a DHT node to be used to bootstrap into the network",
default='lbrynet1.lbry.io')
parser.add_argument("--node_id",
help="The hex-encoded node id to use for this DHT node (randomly generated if omitted)",
default=None)
parser.add_argument("--bootstrap_port",
help="The port of a DHT node to be used to bootstrap into the network",
default=4444, type=int)
args = parser.parse_args()
seeds = [(args.bootstrap_host, args.bootstrap_port)]
server = NodeRPC(args.node_id, seeds, args.node_port, args.rpc_port)
reactor.addSystemEventTrigger('after', 'startup', server.setup)
reactor.run()
if __name__ == "__main__":
main()

scripts/seed_node.py

@ -0,0 +1,228 @@
import struct
import json
import logging
import argparse
import hashlib
from copy import deepcopy
from urllib import urlopen
from twisted.internet.epollreactor import install as install_epoll
install_epoll()
from twisted.internet import reactor, defer
from twisted.web import resource
from twisted.web.server import Site
from lbrynet import conf
from lbrynet.dht import constants
from lbrynet.dht.node import Node
from lbrynet.dht.error import TransportNotConnected
from lbrynet.core.log_support import configure_console, configure_twisted
from lbrynet.daemon.auth.server import AuthJSONRPCServer
# configure_twisted()
conf.initialize_settings()
configure_console()
lbrynet_handler = logging.getLogger("lbrynet").handlers[0]
log = logging.getLogger("dht router")
log.addHandler(lbrynet_handler)
log.setLevel(logging.INFO)
def node_id_supplier(seed="jack.lbry.tech"):  # simple deterministic node id generator
    h = hashlib.sha384()
    h.update(seed)
    while True:
        next_id = h.digest()
        yield next_id
        h = hashlib.sha384()
        h.update(seed)
        h.update(next_id)
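
Reviewer note: `node_id_supplier` chains SHA-384 digests, so every run with the same seed yields the same sequence of 48-byte node ids. A minimal Python 3 sketch of the same chaining, written only to illustrate that property. The bytes seed and the `islice` usage are adaptations for Python 3 and are not part of this commit.

```python
import hashlib
from itertools import islice


def node_id_supplier(seed=b"jack.lbry.tech"):
    """Yield an endless, repeatable sequence of 48-byte ids.

    The first id is sha384(seed); each later id is sha384(seed + previous id),
    which mirrors the update() calls in the script above.
    """
    digest = hashlib.sha384(seed).digest()
    while True:
        yield digest
        digest = hashlib.sha384(seed + digest).digest()


# Two independent generators with the same seed produce identical ids.
first_run = list(islice(node_id_supplier(), 3))
second_run = list(islice(node_id_supplier(), 3))
```

Because the ids depend only on the seed, restarting the seed script keeps each node at the same position in the DHT key space.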
def get_external_ip():
    response = json.loads(urlopen("https://api.lbry.io/ip").read())
    if not response['success']:
        raise ValueError("failed to get external ip")
    return response['data']['ip']


def format_contact(contact):
    return {
        "node_id": contact.id.encode('hex'),
        "address": contact.address,
        "nodePort": contact.port,
        "lastReplied": contact.lastReplied,
        "lastRequested": contact.lastRequested,
        "failedRPCs": contact.failedRPCs,
        "lastFailed": None if not contact.failures else contact.failures[-1]
    }


def format_datastore(node):
    datastore = deepcopy(node._dataStore._dict)
    result = {}
    for key, values in datastore.iteritems():
        contacts = []
        for (contact, value, last_published, originally_published, original_publisher_id) in values:
            contact_dict = format_contact(contact)
            contact_dict['peerPort'] = struct.unpack('>H', value[4:6])[0]
            contact_dict['lastPublished'] = last_published
            contact_dict['originallyPublished'] = originally_published
            contact_dict['originalPublisherID'] = original_publisher_id.encode('hex')
            contacts.append(contact_dict)
        result[key.encode('hex')] = contacts
    return result
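
Reviewer note: `format_datastore` reads the peer port with `struct.unpack('>H', value[4:6])`, i.e. a big-endian unsigned 16-bit integer at bytes 4-5 of the stored value. A minimal sketch of that decoding, assuming the stored value is a compact address laid out as a 4-byte IPv4 address, a 2-byte port, then the node id. Only the port offset is confirmed by the code above; the rest of the layout and the helper name `decode_compact_address` are assumptions for illustration.

```python
import socket
import struct


def decode_compact_address(value):
    """Split a compact address into (ip, port, node_id).

    Assumed layout: bytes 0-3 IPv4 address, bytes 4-5 big-endian port,
    remaining bytes the node id. Only the port slice appears in seed_node.py.
    """
    ip = socket.inet_ntoa(value[:4])
    port = struct.unpack(">H", value[4:6])[0]
    return ip, port, value[6:]


# Build a sample value with a known address, port, and a dummy 48-byte id.
sample = socket.inet_aton("10.0.0.5") + struct.pack(">H", 3333) + b"\x01" * 48
decoded = decode_compact_address(sample)
```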
class MultiSeedRPCServer(AuthJSONRPCServer):
    def __init__(self, starting_node_port, nodes, rpc_port):
        AuthJSONRPCServer.__init__(self, False)
        self.port = None
        self.rpc_port = rpc_port
        self.external_ip = get_external_ip()
        node_id_gen = node_id_supplier()
        self._nodes = [Node(node_id=next(node_id_gen), udpPort=starting_node_port+i, externalIP=self.external_ip)
                       for i in range(nodes)]
        self._own_addresses = [(self.external_ip, starting_node_port+i) for i in range(nodes)]
        reactor.addSystemEventTrigger('after', 'startup', self.start)

    @defer.inlineCallbacks
    def start(self):
        self.announced_startup = True
        root = resource.Resource()
        root.putChild('', self)
        self.port = reactor.listenTCP(self.rpc_port, Site(root), interface='localhost')
        log.info("starting %i nodes on %s, rpc available on localhost:%i", len(self._nodes), self.external_ip, self.rpc_port)

        for node in self._nodes:
            node.start_listening()
            yield node._protocol._listening

        for node1 in self._nodes:
            for node2 in self._nodes:
                if node1 is node2:
                    continue
                try:
                    yield node1.addContact(node1.contact_manager.make_contact(node2.node_id, node2.externalIP,
                                                                              node2.port, node1._protocol))
                except TransportNotConnected:
                    pass
            node1.safe_start_looping_call(node1._change_token_lc, constants.tokenSecretChangeInterval)
            node1.safe_start_looping_call(node1._refresh_node_lc, constants.checkRefreshInterval)
            node1._join_deferred = defer.succeed(True)
        reactor.addSystemEventTrigger('before', 'shutdown', self.stop)
        log.info("finished bootstrapping the network, running %i nodes", len(self._nodes))

    @defer.inlineCallbacks
    def stop(self):
        yield self.port.stopListening()
        yield defer.DeferredList([node.stop() for node in self._nodes])
    def jsonrpc_get_node_ids(self):
        return defer.succeed([node.node_id.encode('hex') for node in self._nodes])

    def jsonrpc_node_datastore(self, node_id):
        for node in self._nodes:
            if node.node_id == node_id.decode('hex'):
                return defer.succeed(format_datastore(node))

    def jsonrpc_get_nodes_who_stored(self, blob_hash):
        storing_nodes = {}
        for node in self._nodes:
            datastore = format_datastore(node)
            if blob_hash in datastore:
                storing_nodes[node.node_id.encode('hex')] = datastore[blob_hash]
        return defer.succeed(storing_nodes)

    def jsonrpc_node_routing_table(self, node_id):
        def format_bucket(bucket):
            return {
                "contacts": [format_contact(contact) for contact in bucket._contacts],
                "lastAccessed": bucket.lastAccessed
            }

        def format_routing(node):
            return {
                i: format_bucket(bucket) for i, bucket in enumerate(node._routingTable._buckets)
            }

        for node in self._nodes:
            if node.node_id == node_id.decode('hex'):
                return defer.succeed(format_routing(node))

    def jsonrpc_restart_node(self, node_id):
        for node in self._nodes:
            if node.node_id == node_id.decode('hex'):
                d = node.stop()
                d.addCallback(lambda _: node.start(self._own_addresses))
                return d
    @defer.inlineCallbacks
    def jsonrpc_local_node_rpc(self, from_node, query, args=()):
        def format_result(response):
            if isinstance(response, list):
                return [[node_id.encode('hex'), address, port] for (node_id, address, port) in response]
            if isinstance(response, dict):
                return {'token': response['token'].encode('hex'), 'contacts': format_result(response['contacts'])}
            return response
        for node in self._nodes:
            if node.node_id == from_node.decode('hex'):
                fn = getattr(node, query)
                self_contact = node.contact_manager.make_contact(node.node_id, node.externalIP, node.port, node._protocol)
                if args:
                    args = (str(arg) if isinstance(arg, (str, unicode)) else int(arg) for arg in args)
                    result = yield fn(self_contact, *args)
                else:
                    result = yield fn()
                # print "result: %s" % result
                defer.returnValue(format_result(result))
    @defer.inlineCallbacks
    def jsonrpc_node_rpc(self, from_node, to_node, query, args=()):
        def format_result(response):
            if isinstance(response, list):
                return [[node_id.encode('hex'), address, port] for (node_id, address, port) in response]
            if isinstance(response, dict):
                return {'token': response['token'].encode('hex'), 'contacts': format_result(response['contacts'])}
            return response
        for node in self._nodes:
            if node.node_id == from_node.decode('hex'):
                remote = node._routingTable.getContact(to_node.decode('hex'))
                fn = getattr(remote, query)
                if args:
                    args = (str(arg).decode('hex') for arg in args)
                    result = yield fn(*args)
                else:
                    result = yield fn()
                defer.returnValue(format_result(result))
    @defer.inlineCallbacks
    def jsonrpc_get_nodes_who_know(self, ip_address):
        nodes = []
        for node_id in [n.node_id.encode('hex') for n in self._nodes]:
            routing_info = yield self.jsonrpc_node_routing_table(node_id=node_id)
            for index, bucket in routing_info.iteritems():
                if ip_address in map(lambda c: c['address'], bucket['contacts']):
                    nodes.append(node_id)
                    break
        defer.returnValue(nodes)
    def jsonrpc_node_status(self):
        return defer.succeed({
            node.node_id.encode('hex'): node._join_deferred is not None and node._join_deferred.called
            for node in self._nodes
        })
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--rpc_port', default=5280)
    parser.add_argument('--starting_port', default=4455)
    parser.add_argument('--nodes', default=32)
    args = parser.parse_args()
    MultiSeedRPCServer(int(args.starting_port), int(args.nodes), int(args.rpc_port))
    reactor.run()
if __name__ == "__main__":
    main()


@ -1,47 +0,0 @@
#!/usr/bin/env python
from lbrynet.core import log_support
import logging.handlers
import sys
import time
from pprint import pprint
from twisted.internet import defer, reactor
from lbrynet.dht.node import Node
import lbrynet.dht.constants
import lbrynet.dht.datastore
from lbrynet.tests.util import random_lbry_hash
log = logging.getLogger(__name__)
@defer.inlineCallbacks
def run():
    nodeid = "9648996b4bef3ff41176668a0577f86aba7f1ea2996edd18f9c42430802c8085331345c5f0c44a7f352e2ba8ae59aaaa".decode("hex")
    node = Node(node_id=nodeid, externalIP='127.0.0.1', udpPort=21999, peerPort=1234)
    node.startNetwork()
    yield node.joinNetwork([("127.0.0.1", 21001)])
    print ""
    print ""
    print ""
    print ""
    print ""
    print ""
    yield node.announceHaveBlob("2bb150cb996b4bef3ff41176648a0577f86abb7f1ea2996edd18f9c42430802c8085331345c5f0c44a7f352e2ba8ae59".decode("hex"))
    log.info("Shutting down...")
    reactor.callLater(1, reactor.stop)
def main():
    log_support.configure_console(level='DEBUG')
    log_support.configure_twisted()
    reactor.callLater(0, run)
    log.info("Running reactor")
    reactor.run()
if __name__ == '__main__':
    sys.exit(main())


@ -17,11 +17,12 @@ from setuptools import setup, find_packages
 requires = [
     'Twisted',
     'appdirs',
+    'distro',
     'base58',
     'envparse',
     'jsonrpc',
-    'lbryschema==0.0.15',
-    'lbryum==3.2.1',
+    'lbryschema==0.0.16rc2',
+    'lbryum==3.2.2rc1',
     'miniupnpc',
     'pyyaml',
     'requests',