Merge pull request #2125 from lbryio/dht_wip
Remove search junctions, make search retry, some refactoring
commit 26d183cbab
14 changed files with 327 additions and 388 deletions
@@ -31,8 +31,8 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
 self.closed = asyncio.Event(loop=self.loop)

 def data_received(self, data: bytes):
-log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
-self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
+#log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
+# self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
 if not self.transport or self.transport.is_closing():
 log.warning("transport closing, but got more bytes from %s:%i\n%s", self.peer_address, self.peer_port,
 binascii.hexlify(data))
@@ -37,7 +37,6 @@ class BlobDownloader:
 async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer', connection_id: int = 0):
 if blob.get_is_verified():
 return
-self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
 transport = self.connections.get(peer)
 start = self.loop.time()
 bytes_received, transport = await request_blob(
@@ -55,17 +54,18 @@ class BlobDownloader:
 self.failures[peer] = 0
 self.connections[peer] = transport
 elapsed = self.loop.time() - start
-self.scores[peer] = bytes_received / elapsed if bytes_received and elapsed else 0
+self.scores[peer] = bytes_received / elapsed if bytes_received and elapsed else 1

 async def new_peer_or_finished(self):
 active_tasks = list(self.active_connections.values()) + [asyncio.sleep(1)]
 await asyncio.wait(active_tasks, loop=self.loop, return_when='FIRST_COMPLETED')

 def cleanup_active(self):
+if not self.active_connections and not self.connections:
+self.clearbanned()
 to_remove = [peer for (peer, task) in self.active_connections.items() if task.done()]
 for peer in to_remove:
 del self.active_connections[peer]
-self.clearbanned()

 def clearbanned(self):
 now = self.loop.time()
@@ -33,6 +33,10 @@ class BlobAnnouncer:
 while batch_size:
 if not self.node.joined.is_set():
 await self.node.joined.wait()
+await asyncio.sleep(60, loop=self.loop)
+if not self.node.protocol.routing_table.get_peers():
+log.warning("No peers in DHT, announce round skipped")
+continue
 self.announce_queue.extend(await self.storage.get_blobs_to_announce())
 log.debug("announcer task wake up, %d blobs to announce", len(self.announce_queue))
 while len(self.announce_queue):
@@ -45,7 +49,6 @@ class BlobAnnouncer:
 if announced:
 await self.storage.update_last_announced_blobs(announced)
 log.info("announced %i blobs", len(announced))
-await asyncio.sleep(60)

 def start(self, batch_size: typing.Optional[int] = 10):
 assert not self.announce_task or self.announce_task.done(), "already running"
@@ -2,11 +2,8 @@ import logging
 import asyncio
 import typing
 import binascii
-import contextlib
 from lbrynet.utils import resolve_host
 from lbrynet.dht import constants
-from lbrynet.dht.error import RemoteException
-from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction
 from lbrynet.dht.protocol.distance import Distance
 from lbrynet.dht.protocol.iterative_find import IterativeNodeFinder, IterativeValueFinder
 from lbrynet.dht.protocol.protocol import KademliaProtocol
@@ -31,7 +28,7 @@ class Node:
 self._join_task: asyncio.Task = None
 self._refresh_task: asyncio.Task = None

-async def refresh_node(self):
+async def refresh_node(self, force_once=False):
 while True:
 # remove peers with expired blob announcements from the datastore
 self.protocol.data_store.removed_expired_peers()
@@ -58,6 +55,8 @@ class Node:
 peers = await self.peer_search(node_ids.pop())
 total_peers.extend(peers)
 else:
+if force_once:
+break
 fut = asyncio.Future(loop=self.loop)
 self.loop.call_later(constants.refresh_interval // 4, fut.set_result, None)
 await fut
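The force_once flag added above makes a single refresh pass possible: ping questionable peers once and return instead of sleeping for the next refresh interval. A minimal usage sketch, not part of the commit, with the node built the same way the new tests/integration/test_dht.py builds one:

import asyncio
from lbrynet.dht import constants
from lbrynet.dht.node import Node
from lbrynet.dht.peer import PeerManager

async def refresh_once(loop: asyncio.AbstractEventLoop) -> None:
    # single local node, mirroring the integration test setup
    node = Node(loop, PeerManager(loop), node_id=constants.generate_id(1),
                udp_port=4444, internal_udp_port=4444,
                peer_port=3333, external_ip='127.0.0.1')
    await node.start_listening('127.0.0.1')
    try:
        # force_once=True exits after one ping/cleanup round
        await node.refresh_node(force_once=True)
    finally:
        node.stop()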
@@ -67,14 +66,14 @@ class Node:
 to_ping = [peer for peer in set(total_peers) if self.protocol.peer_manager.peer_is_good(peer) is not True]
 if to_ping:
 self.protocol.ping_queue.enqueue_maybe_ping(*to_ping, delay=0)
+if force_once:
+break

 fut = asyncio.Future(loop=self.loop)
 self.loop.call_later(constants.refresh_interval, fut.set_result, None)
 await fut

 async def announce_blob(self, blob_hash: str) -> typing.List[bytes]:
-announced_to_node_ids = []
-while not announced_to_node_ids:
 hash_value = binascii.unhexlify(blob_hash.encode())
 assert len(hash_value) == constants.hash_length
 peers = await self.peer_search(hash_value)
@@ -87,10 +86,13 @@ class Node:
 stored_to_tup = await asyncio.gather(
 *(self.protocol.store_to_peer(hash_value, peer) for peer in peers), loop=self.loop
 )
-announced_to_node_ids.extend([node_id for node_id, contacted in stored_to_tup if contacted])
+stored_to = [node_id for node_id, contacted in stored_to_tup if contacted]
+if stored_to:
 log.info("Stored %s to %i of %i attempted peers", binascii.hexlify(hash_value).decode()[:8],
-len(announced_to_node_ids), len(peers))
-return announced_to_node_ids
+len(stored_to), len(peers))
+else:
+log.warning("Failed announcing %s, stored to 0 peers", blob_hash[:8])
+return stored_to

 def stop(self) -> None:
 if self.joined.is_set():
@@ -101,6 +103,7 @@ class Node:
 self._refresh_task.cancel()
 if self.protocol and self.protocol.ping_queue.running:
 self.protocol.ping_queue.stop()
+self.protocol.stop()
 if self.listening_port is not None:
 self.listening_port.close()
 self._join_task = None
@@ -113,6 +116,7 @@ class Node:
 lambda: self.protocol, (interface, self.internal_udp_port)
 )
 log.info("DHT node listening on UDP %s:%i", interface, self.internal_udp_port)
+self.protocol.start()
 else:
 log.warning("Already bound to port %s", self.listening_port)

@@ -130,33 +134,31 @@ class Node:
 if known_node_urls:
 for host, port in known_node_urls:
 address = await resolve_host(host, port, proto='udp')
-if (address, port) not in known_node_addresses and address != self.protocol.external_ip:
+if (address, port) not in known_node_addresses and\
+(address, port) != (self.protocol.external_ip, self.protocol.udp_port):
 known_node_addresses.append((address, port))
 url_to_addr[address] = host

 if known_node_addresses:
-while not self.protocol.routing_table.get_peers():
-success = False
-# ping the seed nodes, this will set their node ids (since we don't know them ahead of time)
-for address, port in known_node_addresses:
-peer = self.protocol.get_rpc_peer(KademliaPeer(self.loop, address, udp_port=port))
-try:
-await peer.ping()
-success = True
-except asyncio.TimeoutError:
-log.warning("seed node (%s:%i) timed out in %s", url_to_addr.get(address, address), port,
-round(self.protocol.rpc_timeout, 2))
-if success:
-break
-# now that we have the seed nodes in routing, to an iterative lookup of our own id to populate the buckets
-# in the routing table with good peers who are near us
-async with self.peer_search_junction(self.protocol.node_id, max_results=16) as junction:
-async for peers in junction:
-for peer in peers:
-try:
-await self.protocol.get_rpc_peer(peer).ping()
-except (asyncio.TimeoutError, RemoteException):
-pass
+peers = [
+KademliaPeer(self.loop, address, udp_port=port)
+for (address, port) in known_node_addresses
+]
+while True:
+if not self.protocol.routing_table.get_peers():
+if self.joined.is_set():
+self.joined.clear()
+self.protocol.peer_manager.reset()
+self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
+peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
+if self.protocol.routing_table.get_peers():
+self.joined.set()
+log.info(
+"Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
+self.protocol.routing_table.buckets_with_contacts())
+else:
+continue
+await asyncio.sleep(1, loop=self.loop)

 log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
 self.protocol.routing_table.buckets_with_contacts())
@@ -169,78 +171,48 @@ class Node:
 )
 )

-def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List] = None,
+def get_iterative_node_finder(self, key: bytes, shortlist: typing.Optional[typing.List['KademliaPeer']] = None,
 bottom_out_limit: int = constants.bottom_out_limit,
 max_results: int = constants.k) -> IterativeNodeFinder:

 return IterativeNodeFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
 key, bottom_out_limit, max_results, None, shortlist)

-def get_iterative_value_finder(self, key: bytes, shortlist: typing.Optional[typing.List] = None,
+def get_iterative_value_finder(self, key: bytes, shortlist: typing.Optional[typing.List['KademliaPeer']] = None,
 bottom_out_limit: int = 40,
 max_results: int = -1) -> IterativeValueFinder:

 return IterativeValueFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
 key, bottom_out_limit, max_results, None, shortlist)

-@contextlib.asynccontextmanager
-async def stream_peer_search_junction(self, hash_queue: asyncio.Queue, bottom_out_limit=20,
-max_results=-1) -> AsyncGeneratorJunction:
-peer_generator = AsyncGeneratorJunction(self.loop)
-
-async def _add_hashes_from_queue():
-while True:
-blob_hash = await hash_queue.get()
-peer_generator.add_generator(
-self.get_iterative_value_finder(
-binascii.unhexlify(blob_hash.encode()), bottom_out_limit=bottom_out_limit,
-max_results=max_results
-)
-)
-add_hashes_task = self.loop.create_task(_add_hashes_from_queue())
-try:
-async with peer_generator as junction:
-yield junction
-finally:
-if add_hashes_task and not (add_hashes_task.done() or add_hashes_task.cancelled()):
-add_hashes_task.cancel()
-
-def peer_search_junction(self, node_id: bytes, max_results=constants.k*2,
-bottom_out_limit=20) -> AsyncGeneratorJunction:
-peer_generator = AsyncGeneratorJunction(self.loop)
-peer_generator.add_generator(
-self.get_iterative_node_finder(
-node_id, bottom_out_limit=bottom_out_limit, max_results=max_results
-)
-)
-return peer_generator
-
 async def peer_search(self, node_id: bytes, count=constants.k, max_results=constants.k*2,
-bottom_out_limit=20) -> typing.List['KademliaPeer']:
-accumulated: typing.List['KademliaPeer'] = []
-async with self.peer_search_junction(node_id, max_results=max_results,
-bottom_out_limit=bottom_out_limit) as junction:
-async for peers in junction:
-accumulated.extend(peers)
+bottom_out_limit=20, shortlist: typing.Optional[typing.List['KademliaPeer']] = None
+) -> typing.List['KademliaPeer']:
+peers = []
+async for iteration_peers in self.get_iterative_node_finder(
+node_id, shortlist=shortlist, bottom_out_limit=bottom_out_limit, max_results=max_results):
+peers.extend(iteration_peers)
 distance = Distance(node_id)
-accumulated.sort(key=lambda peer: distance(peer.node_id))
-return accumulated[:count]
+peers.sort(key=lambda peer: distance(peer.node_id))
+return peers[:count]

 async def _accumulate_search_junction(self, search_queue: asyncio.Queue,
 result_queue: asyncio.Queue):
-async with self.stream_peer_search_junction(search_queue) as search_junction: # pylint: disable=E1701
-async for peers in search_junction:
-if peers:
-result_queue.put_nowait([
-peer for peer in peers
-if not (
-peer.address == self.protocol.external_ip
-and peer.tcp_port == self.protocol.peer_port
-)
-])
+tasks = []
+try:
+while True:
+blob_hash = await search_queue.get()
+tasks.append(self.loop.create_task(self._value_producer(blob_hash, result_queue)))
+finally:
+for task in tasks:
+task.cancel()
+async def _value_producer(self, blob_hash: str, result_queue: asyncio.Queue):
+async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
+result_queue.put_nowait(results)

 def accumulate_peers(self, search_queue: asyncio.Queue,
 peer_queue: typing.Optional[asyncio.Queue] = None) -> typing.Tuple[
 asyncio.Queue, asyncio.Task]:
-q = peer_queue or asyncio.Queue()
-return q, asyncio.create_task(self._accumulate_search_junction(search_queue, q))
+q = peer_queue or asyncio.Queue(loop=self.loop)
+return q, self.loop.create_task(self._accumulate_search_junction(search_queue, q))
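With the junction helpers removed, callers get peers either from peer_search() or from accumulate_peers(), which now starts one _value_producer task per blob hash pushed onto the search queue. A usage sketch of that queue API, assuming an already started and joined node; the 30 second timeout is illustrative and not from the diff:

import asyncio

async def first_peers_for_blob(node, blob_hash: str):
    search_queue: asyncio.Queue = asyncio.Queue()
    # accumulate_peers() returns the peer queue plus the task draining search_queue
    peer_queue, accumulate_task = node.accumulate_peers(search_queue)
    search_queue.put_nowait(blob_hash)
    try:
        # _value_producer() puts each batch of found KademliaPeer objects here
        return await asyncio.wait_for(peer_queue.get(), timeout=30)
    finally:
        accumulate_task.cancel()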
@@ -31,7 +31,10 @@ class PeerManager:
 self._node_id_mapping: typing.Dict[typing.Tuple[str, int], bytes] = {}
 self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
 self._node_tokens: typing.Dict[bytes, (float, bytes)] = {}
-self._kademlia_peers: typing.Dict[typing.Tuple[bytes, str, int], 'KademliaPeer']
+def reset(self):
+for statistic in (self._rpc_failures, self._last_replied, self._last_sent, self._last_requested):
+statistic.clear()

 def report_failure(self, address: str, udp_port: int):
 now = self._loop.time()
@@ -104,11 +107,12 @@ class PeerManager:

 delay = self._loop.time() - constants.check_refresh_interval

-if node_id not in self._node_id_reverse_mapping or (address, udp_port) not in self._node_id_mapping:
-return
-addr_tup = (address, udp_port)
-if self._node_id_reverse_mapping[node_id] != addr_tup or self._node_id_mapping[addr_tup] != node_id:
-return
+# fixme: find a way to re-enable that without breaking other parts
+#if node_id not in self._node_id_reverse_mapping or (address, udp_port) not in self._node_id_mapping:
+# return
+#addr_tup = (address, udp_port)
+#if self._node_id_reverse_mapping[node_id] != addr_tup or self._node_id_mapping[addr_tup] != node_id:
+# return
 previous_failure, most_recent_failure = self._rpc_failures.get((address, udp_port), (None, None))
 last_requested = self._last_requested.get((address, udp_port))
 last_replied = self._last_replied.get((address, udp_port))
@@ -1,94 +0,0 @@
-import asyncio
-import typing
-import logging
-import traceback
-if typing.TYPE_CHECKING:
-from types import AsyncGeneratorType
-
-log = logging.getLogger(__name__)
-
-
-def cancel_task(task: typing.Optional[asyncio.Task]):
-if task and not (task.done() or task.cancelled()):
-task.cancel()
-
-
-def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
-while tasks:
-cancel_task(tasks.pop())
-
-
-class AsyncGeneratorJunction:
-"""
-A helper to interleave the results from multiple async generators into one
-async generator.
-"""
-
-def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None):
-self.loop = loop
-self.__iterator_queue = asyncio.Queue(loop=loop)
-self.result_queue = queue or asyncio.Queue(loop=loop)
-self.tasks: typing.List[asyncio.Task] = []
-self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {}
-self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop)
-
-@property
-def running(self):
-return any(self.running_iterators.values())
-
-async def wait_for_generators(self):
-async def iterate(iterator: typing.AsyncGenerator):
-try:
-async for item in iterator:
-self.result_queue.put_nowait(item)
-self.__iterator_queue.put_nowait(item)
-finally:
-self.running_iterators[iterator] = False
-if not self.running:
-self.__iterator_queue.put_nowait(StopAsyncIteration)
-
-while True:
-async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get()
-self.running_iterators[async_gen] = True
-self.tasks.append(self.loop.create_task(iterate(async_gen)))
-
-def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']):
-"""
-Add an async generator. This can be called during an iteration of the generator junction.
-"""
-self.generator_queue.put_nowait(async_gen)
-
-def __aiter__(self):
-return self
-
-async def __anext__(self):
-result = await self.__iterator_queue.get()
-if result is StopAsyncIteration:
-raise result
-return result
-
-def aclose(self):
-async def _aclose():
-for iterator in list(self.running_iterators.keys()):
-result = iterator.aclose()
-if asyncio.iscoroutine(result):
-await result
-self.running_iterators[iterator] = False
-drain_tasks(self.tasks)
-raise StopAsyncIteration()
-return self.loop.create_task(_aclose())
-
-async def __aenter__(self):
-self.tasks.append(self.loop.create_task(self.wait_for_generators()))
-return self
-
-async def __aexit__(self, exc_type, exc, tb):
-try:
-await self.aclose()
-except StopAsyncIteration:
-pass
-finally:
-if exc_type:
-if exc_type not in (asyncio.CancelledError, asyncio.TimeoutError, StopAsyncIteration, GeneratorExit):
-err = traceback.format_exception(exc_type, exc, tb)
-log.error(err)
@@ -66,11 +66,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
 """
 if len(key) != constants.hash_length:
 raise ValueError("invalid key length: %i" % len(key))
-if not shortlist:
-shortlist = routing_table.find_close_peers(key)
-distance = Distance(key)
-shortlist.sort(key=lambda peer: distance(peer.node_id), reverse=True)
-return shortlist
+return shortlist or routing_table.find_close_peers(key)


 class IterativeFinder:
@@ -91,12 +87,11 @@ class IterativeFinder:
 self.max_results = max_results
 self.exclude = exclude or []

-self.shortlist: typing.List['KademliaPeer'] = get_shortlist(routing_table, key, shortlist)
 self.active: typing.Set['KademliaPeer'] = set()
-self.contacted: typing.Set[typing.Tuple[str, int]] = set()
+self.contacted: typing.Set['KademliaPeer'] = set()
 self.distance = Distance(key)

-self.closest_peer: typing.Optional['KademliaPeer'] = None if not self.shortlist else self.shortlist[0]
+self.closest_peer: typing.Optional['KademliaPeer'] = None
 self.prev_closest_peer: typing.Optional['KademliaPeer'] = None

 self.iteration_queue = asyncio.Queue(loop=self.loop)
@@ -107,6 +102,12 @@ class IterativeFinder:
 self.running = False
 self.tasks: typing.List[asyncio.Task] = []
 self.delayed_calls: typing.List[asyncio.Handle] = []
+for peer in get_shortlist(routing_table, key, shortlist):
+if peer.node_id:
+self._add_active(peer)
+else:
+# seed nodes
+self._schedule_probe(peer)

 async def send_probe(self, peer: 'KademliaPeer') -> FindResponse:
 """
@@ -138,25 +139,18 @@ class IterativeFinder:
 def _is_closer(self, peer: 'KademliaPeer') -> bool:
 return not self.closest_peer or self.distance.is_closer(peer.node_id, self.closest_peer.node_id)

-def _update_closest(self):
-self.shortlist.sort(key=lambda peer: self.distance(peer.node_id), reverse=True)
-if self.closest_peer and self.closest_peer is not self.shortlist[-1]:
-if self._is_closer(self.shortlist[-1]):
+def _add_active(self, peer):
+if peer not in self.active and peer.node_id and peer.node_id != self.protocol.node_id:
+self.active.add(peer)
+if self._is_closer(peer):
 self.prev_closest_peer = self.closest_peer
-self.closest_peer = self.shortlist[-1]
+self.closest_peer = peer

 async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse):
-if peer not in self.shortlist:
-self.shortlist.append(peer)
-if peer not in self.active:
-self.active.add(peer)
+self._add_active(peer)
 for contact_triple in response.get_close_triples():
 node_id, address, udp_port = contact_triple
-if (address, udp_port) not in self.contacted: # and not self.peer_manager.is_ignored(addr_tuple)
-found_peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port)
-if found_peer not in self.shortlist and self.peer_manager.peer_is_good(peer) is not False:
-self.shortlist.append(found_peer)
-self._update_closest()
+self._add_active(self.peer_manager.get_kademlia_peer(node_id, address, udp_port))
 self.check_result_ready(response)

 async def _send_probe(self, peer: 'KademliaPeer'):
@@ -177,22 +171,31 @@ class IterativeFinder:

 async def _search_round(self):
 """
-Send up to constants.alpha (5) probes to the closest peers in the shortlist
+Send up to constants.alpha (5) probes to closest active peers
 """

 added = 0
-self.shortlist.sort(key=lambda p: self.distance(p.node_id), reverse=True)
-while self.running and len(self.shortlist) and added < constants.alpha:
-peer = self.shortlist.pop()
+to_probe = list(self.active - self.contacted)
+to_probe.sort(key=lambda peer: self.distance(self.key))
+for peer in to_probe:
+if added >= constants.alpha:
+break
 origin_address = (peer.address, peer.udp_port)
-if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False:
+if origin_address in self.exclude:
 continue
 if peer.node_id == self.protocol.node_id:
 continue
-if (peer.address, peer.udp_port) == (self.protocol.external_ip, self.protocol.udp_port):
+if origin_address == (self.protocol.external_ip, self.protocol.udp_port):
 continue
-if (peer.address, peer.udp_port) not in self.contacted:
-self.contacted.add((peer.address, peer.udp_port))
+self._schedule_probe(peer)
+added += 1
+log.debug("running %d probes", len(self.running_probes))
+if not added and not self.running_probes:
+log.debug("search for %s exhausted", hexlify(self.key)[:8])
+self.search_exhausted()

+def _schedule_probe(self, peer: 'KademliaPeer'):
+self.contacted.add(peer)

 t = self.loop.create_task(self._send_probe(peer))

@@ -200,16 +203,11 @@ class IterativeFinder:
 self.running_probes.difference_update({
 probe for probe in self.running_probes if probe.done() or probe == t
 })
-if not self.running_probes and self.shortlist:
+if not self.running_probes:
 self.tasks.append(self.loop.create_task(self._search_task(0.0)))

 t.add_done_callback(callback)
 self.running_probes.add(t)
-added += 1
-log.debug("running %d probes", len(self.running_probes))
-if not added and not self.running_probes:
-log.debug("search for %s exhausted", hexlify(self.key)[:8])
-self.search_exhausted()

 async def _search_task(self, delay: typing.Optional[float] = constants.iterative_lookup_delay):
 try:
@@ -266,6 +264,7 @@ class IterativeNodeFinder(IterativeFinder):
 self.yielded_peers: typing.Set['KademliaPeer'] = set()

 async def send_probe(self, peer: 'KademliaPeer') -> FindNodeResponse:
+log.debug("probing %s:%d %s", peer.address, peer.udp_port, hexlify(peer.node_id)[:8] if peer.node_id else '')
 response = await self.protocol.get_rpc_peer(peer).find_node(self.key)
 return FindNodeResponse(self.key, response)

@@ -273,12 +272,16 @@ class IterativeNodeFinder(IterativeFinder):
 self.put_result(self.active, finish=True)

 def put_result(self, from_iter: typing.Iterable['KademliaPeer'], finish=False):
-not_yet_yielded = [peer for peer in from_iter if peer not in self.yielded_peers]
+not_yet_yielded = [
+peer for peer in from_iter
+if peer not in self.yielded_peers
+and peer.node_id != self.protocol.node_id
+and self.peer_manager.peer_is_good(peer) is not False
+]
 not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id))
 to_yield = not_yet_yielded[:min(constants.k, len(not_yet_yielded))]
 if to_yield:
-for peer in to_yield:
-self.yielded_peers.add(peer)
+self.yielded_peers.update(to_yield)
 self.iteration_queue.put_nowait(to_yield)
 if finish:
 self.iteration_queue.put_nowait(None)
@@ -288,21 +291,17 @@ class IterativeNodeFinder(IterativeFinder):

 if found:
 log.debug("found")
-return self.put_result(self.shortlist, finish=True)
+return self.put_result(self.active, finish=True)
 if self.prev_closest_peer and self.closest_peer and not self._is_closer(self.prev_closest_peer):
 # log.info("improving, %i %i %i %i %i", len(self.shortlist), len(self.active), len(self.contacted),
 # self.bottom_out_count, self.iteration_count)
 self.bottom_out_count = 0
 elif self.prev_closest_peer and self.closest_peer:
 self.bottom_out_count += 1
-log.info("bottom out %i %i %i %i", len(self.active), len(self.contacted), len(self.shortlist),
-self.bottom_out_count)
+log.info("bottom out %i %i %i", len(self.active), len(self.contacted), self.bottom_out_count)
 if self.bottom_out_count >= self.bottom_out_limit or self.iteration_count >= self.bottom_out_limit:
 log.info("limit hit")
 self.put_result(self.active, True)
-elif self.max_results and len(self.active) - len(self.yielded_peers) >= self.max_results:
-log.debug("max results")
-self.put_result(self.active, True)


 class IterativeValueFinder(IterativeFinder):
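The finders above stay plain async iterators that yield a list of peers per round; Node.peer_search() now drives IterativeNodeFinder directly instead of going through a junction. A consumption sketch, assuming a started node and a key of constants.hash_length bytes (not part of the commit):

from lbrynet.dht import constants

async def closest_known(node, target_id: bytes):
    found = []
    finder = node.get_iterative_node_finder(target_id, max_results=constants.k * 2)
    async for batch in finder:  # each iteration is a list of KademliaPeer
        found.extend(batch)
    return found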
@@ -191,12 +191,14 @@ class PingQueue:
 self._process_task: asyncio.Task = None
 self._running = False
 self._running_pings: typing.Set[asyncio.Task] = set()
+self._default_delay = constants.maybe_ping_delay

 @property
 def running(self):
 return self._running

-def enqueue_maybe_ping(self, *peers: 'KademliaPeer', delay: float = constants.maybe_ping_delay):
+def enqueue_maybe_ping(self, *peers: 'KademliaPeer', delay: typing.Optional[float] = None):
+delay = delay if delay is not None else self._default_delay
 now = self._loop.time()
 for peer in peers:
 if peer not in self._pending_contacts or now + delay < self._pending_contacts[peer]:
@@ -207,7 +209,7 @@
 try:
 if self._protocol.peer_manager.peer_is_good(peer):
 if peer not in self._protocol.routing_table.get_peers():
-await self._protocol.add_peer(peer)
+self._protocol.add_peer(peer)
 return
 await self._protocol.get_rpc_peer(peer).ping()
 except asyncio.TimeoutError:
@@ -268,11 +270,21 @@ class KademliaProtocol(DatagramProtocol):
 self.node_rpc = KademliaRPC(self, self.loop, self.peer_port)
 self.rpc_timeout = rpc_timeout
 self._split_lock = asyncio.Lock(loop=self.loop)
+self._to_remove: typing.Set['KademliaPeer'] = set()
+self._to_add: typing.Set['KademliaPeer'] = set()
+self._wakeup_routing_task = asyncio.Event(loop=self.loop)
+self.maintaing_routing_task: typing.Optional[asyncio.Task] = None

+@functools.lru_cache(128)
 def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
 return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer)

+def start(self):
+self.maintaing_routing_task = self.loop.create_task(self.routing_table_task())

 def stop(self):
+if self.maintaing_routing_task:
+self.maintaing_routing_task.cancel()
 if self.transport:
 self.disconnect()

@@ -363,13 +375,30 @@ class KademliaProtocol(DatagramProtocol):
 self.routing_table.buckets[bucket_index].remove_peer(to_replace)
 return await self._add_peer(peer)

-async def add_peer(self, peer: 'KademliaPeer') -> bool:
+def add_peer(self, peer: 'KademliaPeer'):
 if peer.node_id == self.node_id:
 return False
-async with self._split_lock:
-return await self._add_peer(peer)
+self._to_add.add(peer)
+self._wakeup_routing_task.set()

-async def _handle_rpc(self, sender_contact: 'KademliaPeer', message: RequestDatagram):
+def remove_peer(self, peer: 'KademliaPeer'):
+self._to_remove.add(peer)
+self._wakeup_routing_task.set()

+async def routing_table_task(self):
+while True:
+while self._to_remove:
+async with self._split_lock:
+peer = self._to_remove.pop()
+self.routing_table.remove_peer(peer)
+self.routing_table.join_buckets()
+while self._to_add:
+async with self._split_lock:
+await self._add_peer(self._to_add.pop())
+await asyncio.gather(self._wakeup_routing_task.wait(), asyncio.sleep(.1, loop=self.loop), loop=self.loop)
+self._wakeup_routing_task.clear()

+def _handle_rpc(self, sender_contact: 'KademliaPeer', message: RequestDatagram):
 assert sender_contact.node_id != self.node_id, (binascii.hexlify(sender_contact.node_id)[:8].decode(),
 binascii.hexlify(self.node_id)[:8].decode())
 method = message.method
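add_peer() and remove_peer() are now synchronous: they only queue the peer and set the wakeup event, and routing_table_task (started by KademliaProtocol.start()) applies the change under the split lock. A small sketch of what callers and tests have to account for; the 0.3 second settle delay mirrors the sleeps used in the new integration tests (not part of the commit):

import asyncio

async def add_peer_and_wait(protocol, peer, settle: float = 0.3):
    # queues the peer and wakes routing_table_task; the bucket insert happens
    # later on the maintenance loop, so wait briefly before checking
    protocol.add_peer(peer)
    await asyncio.sleep(settle)
    return peer in protocol.routing_table.get_peers()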
@@ -396,11 +425,11 @@ class KademliaProtocol(DatagramProtocol):
 key, = a
 result = self.node_rpc.find_value(sender_contact, key)

-await self.send_response(
+self.send_response(
 sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result),
 )

-async def handle_request_datagram(self, address: typing.Tuple[str, int], request_datagram: RequestDatagram):
+def handle_request_datagram(self, address: typing.Tuple[str, int], request_datagram: RequestDatagram):
 # This is an RPC method request
 self.peer_manager.report_last_requested(address[0], address[1])
 try:
@@ -408,7 +437,7 @@ class KademliaProtocol(DatagramProtocol):
 except IndexError:
 peer = self.peer_manager.get_kademlia_peer(request_datagram.node_id, address[0], address[1])
 try:
-await self._handle_rpc(peer, request_datagram)
+self._handle_rpc(peer, request_datagram)
 # if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
 # will be added to our routing table if successful
 is_good = self.peer_manager.peer_is_good(peer)
@@ -416,12 +445,12 @@ class KademliaProtocol(DatagramProtocol):
 self.ping_queue.enqueue_maybe_ping(peer)
 # only add a requesting contact to the routing table if it has replied to one of our requests
 elif is_good is True:
-await self.add_peer(peer)
+self.add_peer(peer)
 except ValueError as err:
 log.debug("error raised handling %s request from %s:%i - %s(%s)",
 request_datagram.method, peer.address, peer.udp_port, str(type(err)),
 str(err))
-await self.send_error(
+self.send_error(
 peer,
 ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(),
 str(err).encode())
@@ -430,13 +459,13 @@ class KademliaProtocol(DatagramProtocol):
 log.warning("error raised handling %s request from %s:%i - %s(%s)",
 request_datagram.method, peer.address, peer.udp_port, str(type(err)),
 str(err))
-await self.send_error(
+self.send_error(
 peer,
 ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(),
 str(err).encode())
 )

-async def handle_response_datagram(self, address: typing.Tuple[str, int], response_datagram: ResponseDatagram):
+def handle_response_datagram(self, address: typing.Tuple[str, int], response_datagram: ResponseDatagram):
 # Find the message that triggered this response
 if response_datagram.rpc_id in self.sent_messages:
 peer, df, request = self.sent_messages[response_datagram.rpc_id]
@@ -459,7 +488,7 @@ class KademliaProtocol(DatagramProtocol):
 self.peer_manager.update_contact_triple(peer.node_id, address[0], address[1])
 if not df.cancelled():
 df.set_result(response_datagram)
-await self.add_peer(peer)
+self.add_peer(peer)
 else:
 log.warning("%s:%i replied, but after we cancelled the request attempt",
 peer.address, peer.udp_port)
@@ -510,15 +539,15 @@ class KademliaProtocol(DatagramProtocol):
 return

 if isinstance(message, RequestDatagram):
-self.loop.create_task(self.handle_request_datagram(address, message))
+self.handle_request_datagram(address, message)
 elif isinstance(message, ErrorDatagram):
 self.handle_error_datagram(address, message)
 else:
 assert isinstance(message, ResponseDatagram), "sanity"
-self.loop.create_task(self.handle_response_datagram(address, message))
+self.handle_response_datagram(address, message)

 async def send_request(self, peer: 'KademliaPeer', request: RequestDatagram) -> ResponseDatagram:
-await self._send(peer, request)
+self._send(peer, request)
 response_fut = self.sent_messages[request.rpc_id][1]
 try:
 response = await asyncio.wait_for(response_fut, self.rpc_timeout)
@@ -531,18 +560,16 @@ class KademliaProtocol(DatagramProtocol):
 except (asyncio.TimeoutError, RemoteException):
 self.peer_manager.report_failure(peer.address, peer.udp_port)
 if self.peer_manager.peer_is_good(peer) is False:
-async with self._split_lock:
-self.routing_table.remove_peer(peer)
-self.routing_table.join_buckets()
+self.remove_peer(peer)
 raise

-async def send_response(self, peer: 'KademliaPeer', response: ResponseDatagram):
-await self._send(peer, response)
+def send_response(self, peer: 'KademliaPeer', response: ResponseDatagram):
+self._send(peer, response)

-async def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
-await self._send(peer, error)
+def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
+self._send(peer, error)

-async def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram,
+def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram,
 ErrorDatagram]):
 if not self.transport or self.transport.is_closing():
 raise TransportNotConnected()
@@ -602,12 +629,15 @@ class KademliaProtocol(DatagramProtocol):
 return True

 async def store_to_peer(self, hash_value: bytes, peer: 'KademliaPeer') -> typing.Tuple[bytes, bool]:
-try:
+async def __store():
 res = await self.get_rpc_peer(peer).store(hash_value)
 if res != b"OK":
 raise ValueError(res)
 log.debug("Stored %s to %s", binascii.hexlify(hash_value).decode()[:8], peer)
 return peer.node_id, True

+try:
+return await __store()
 except asyncio.TimeoutError:
 log.debug("Timeout while storing blob_hash %s at %s", binascii.hexlify(hash_value).decode()[:8], peer)
 except ValueError as err:
@@ -615,6 +645,10 @@ class KademliaProtocol(DatagramProtocol):
 except Exception as err:
 if 'Invalid token' in str(err):
 self.peer_manager.clear_token(peer.node_id)
+try:
+return await __store()
+except:
+return peer.node_id, False
 else:
 log.exception("Unexpected error while storing blob_hash")
 return peer.node_id, False
@@ -83,6 +83,8 @@ class StreamDownloader:
 # set up peer accumulation
 if node:
 self.node = node
+if self.accumulate_task and not self.accumulate_task.done():
+self.accumulate_task.cancel()
 _, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
 await self.add_fixed_peers()
 # start searching for peers for the sd hash
115
tests/integration/test_dht.py
Normal file
115
tests/integration/test_dht.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
import asyncio
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
from lbrynet.dht import constants
|
||||||
|
from lbrynet.dht.node import Node
|
||||||
|
from lbrynet.dht.peer import PeerManager, KademliaPeer
|
||||||
|
from torba.testcase import AsyncioTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class DHTIntegrationTest(AsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
import logging
|
||||||
|
logging.getLogger('asyncio').setLevel(logging.ERROR)
|
||||||
|
logging.getLogger('lbrynet.dht').setLevel(logging.WARN)
|
||||||
|
self.nodes = []
|
||||||
|
self.known_node_addresses = []
|
||||||
|
|
||||||
|
async def setup_network(self, size: int, start_port=40000, seed_nodes=1):
|
||||||
|
for i in range(size):
|
||||||
|
node_port = start_port + i
|
||||||
|
node = Node(self.loop, PeerManager(self.loop), node_id=constants.generate_id(i),
|
||||||
|
udp_port=node_port, internal_udp_port=node_port,
|
||||||
|
peer_port=3333, external_ip='127.0.0.1')
|
||||||
|
self.nodes.append(node)
|
||||||
|
self.known_node_addresses.append(('127.0.0.1', node_port))
|
||||||
|
await node.start_listening('127.0.0.1')
|
||||||
|
self.addCleanup(node.stop)
|
||||||
|
for node in self.nodes:
|
||||||
|
node.protocol.rpc_timeout = .5
|
||||||
|
node.protocol.ping_queue._default_delay = .5
|
||||||
|
node.start('127.0.0.1', self.known_node_addresses[:seed_nodes])
|
||||||
|
await asyncio.gather(*[node.joined.wait() for node in self.nodes])
|
||||||
|
|
||||||
|
async def test_replace_bad_nodes(self):
|
||||||
|
await self.setup_network(20)
|
||||||
|
self.assertEquals(len(self.nodes), 20)
|
||||||
|
node = self.nodes[0]
|
||||||
|
bad_peers = []
|
||||||
|
for candidate in self.nodes[1:10]:
|
||||||
|
address, port, node_id = candidate.protocol.external_ip, candidate.protocol.udp_port, candidate.protocol.node_id
|
||||||
|
peer = KademliaPeer(self.loop, address, node_id, port)
|
||||||
|
bad_peers.append(peer)
|
||||||
|
node.protocol.add_peer(peer)
|
||||||
|
candidate.stop()
|
||||||
|
await asyncio.sleep(.3) # let pending events settle
|
||||||
|
for bad_peer in bad_peers:
|
||||||
|
self.assertIn(bad_peer, node.protocol.routing_table.get_peers())
|
||||||
|
await node.refresh_node(True)
|
||||||
|
await asyncio.sleep(.3) # let pending events settle
|
||||||
|
good_nodes = {good_node.protocol.node_id for good_node in self.nodes[10:]}
|
||||||
|
for peer in node.protocol.routing_table.get_peers():
|
||||||
|
self.assertIn(peer.node_id, good_nodes)
|
||||||
|
|
||||||
|
async def test_re_join(self):
|
||||||
|
await self.setup_network(20, seed_nodes=10)
|
||||||
|
node = self.nodes[-1]
|
||||||
|
self.assertTrue(node.joined.is_set())
|
||||||
|
self.assertTrue(node.protocol.routing_table.get_peers())
|
||||||
|
for network_node in self.nodes[:-1]:
|
||||||
|
network_node.stop()
|
||||||
|
await node.refresh_node(True)
|
||||||
|
await asyncio.sleep(.3) # let pending events settle
|
||||||
|
self.assertFalse(node.protocol.routing_table.get_peers())
|
||||||
|
for network_node in self.nodes[:-1]:
|
||||||
|
network_node.start('127.0.0.1', self.known_node_addresses)
|
||||||
|
self.assertFalse(node.protocol.routing_table.get_peers())
|
||||||
|
timeout = 20
|
||||||
|
while not node.protocol.routing_table.get_peers():
|
||||||
|
await asyncio.sleep(.1)
|
||||||
|
timeout -= 1
|
||||||
|
if not timeout:
|
||||||
|
self.fail("node didnt join back after 2 seconds")
|
||||||
|
|
||||||
|
async def test_announce_no_peers(self):
|
||||||
|
await self.setup_network(1)
|
||||||
|
node = self.nodes[0]
|
||||||
|
blob_hash = hexlify(constants.generate_id(1337)).decode()
|
||||||
|
peers = await node.announce_blob(blob_hash)
|
||||||
|
self.assertEqual(len(peers), 0)
|
||||||
|
|
||||||
|
    async def test_get_token_on_announce(self):
        await self.setup_network(2, seed_nodes=2)
        node1, node2 = self.nodes
        node1.protocol.peer_manager.clear_token(node2.protocol.node_id)
        blob_hash = hexlify(constants.generate_id(1337)).decode()
        node_ids = await node1.announce_blob(blob_hash)
        self.assertIn(node2.protocol.node_id, node_ids)
        node2.protocol.node_rpc.refresh_token()
        node_ids = await node1.announce_blob(blob_hash)
        self.assertIn(node2.protocol.node_id, node_ids)
        node2.protocol.node_rpc.refresh_token()
        node_ids = await node1.announce_blob(blob_hash)
        self.assertIn(node2.protocol.node_id, node_ids)

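The announce round trip above depends on a store token: a node presumably only accepts an announcement if the caller presents a token the node itself issued recently, and the test rotates node2's token twice to check that announcements still land, i.e. that the announcing side obtains a current token as part of each announce. A rough conceptual model of such a token (hypothetical code, not lbrynet's datastore):

import hashlib
import os

class TokenIssuer:
    def __init__(self):
        self.secret = os.urandom(16)

    def issue(self, peer_address: str) -> bytes:
        # token is bound to the requesting peer and to the current secret
        return hashlib.sha1(self.secret + peer_address.encode()).digest()

    def verify(self, peer_address: str, token: bytes) -> bool:
        return token == self.issue(peer_address)

    def rotate(self):
        # invalidates previously issued tokens
        self.secret = os.urandom(16)
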
    async def test_peer_search_removes_bad_peers(self):
        # that's an edge case discovered by Tom, but an important one
        # imagine that you only got bad peers and refresh will happen in one hour
        # instead of failing for one hour we should be able to recover by scheduling pings to bad peers we find
        await self.setup_network(2, seed_nodes=2)
        node1, node2 = self.nodes
        node2.stop()
        # forcefully make it a bad peer but don't remove it from the routing table
        address, port, node_id = node2.protocol.external_ip, node2.protocol.udp_port, node2.protocol.node_id
        peer = KademliaPeer(self.loop, address, node_id, port)
        self.assertTrue(node1.protocol.peer_manager.peer_is_good(peer))
        node1.protocol.peer_manager.report_failure(node2.protocol.external_ip, node2.protocol.udp_port)
        node1.protocol.peer_manager.report_failure(node2.protocol.external_ip, node2.protocol.udp_port)
        self.assertFalse(node1.protocol.peer_manager.peer_is_good(peer))

        # now a search happens, which removes bad peers while contacting them
        self.assertTrue(node1.protocol.routing_table.get_peers())
        await node1.peer_search(node2.protocol.node_id)
        await asyncio.sleep(.3)  # let pending events settle
        self.assertFalse(node1.protocol.routing_table.get_peers())

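The comments at the top of this test describe the recovery path: if a search runs into a contact that is currently marked bad, the node should schedule a ping for it rather than wait for the next routine refresh, so a routing table full of bad peers can heal itself. A compact sketch of that idea (hypothetical names, not the actual protocol code):

def on_contact_seen_during_search(peer, peer_manager, ping_queue):
    # bad contacts found mid-search get queued for a ping instead of being ignored
    if not peer_manager.peer_is_good(peer):
        ping_queue.enqueue_maybe_ping(peer)
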
@ -1,102 +0,0 @@
import unittest
import asyncio
from torba.testcase import AsyncioTestCase
from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction


class MockAsyncGen:
    def __init__(self, loop, result, delay, stop_cnt=10):
        self.loop = loop
        self.result = result
        self.delay = delay
        self.count = 0
        self.stop_cnt = stop_cnt
        self.called_close = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        await asyncio.sleep(self.delay, loop=self.loop)
        if self.count > self.stop_cnt - 1:
            raise StopAsyncIteration()
        self.count += 1
        return self.result

    async def aclose(self):
        self.called_close = True


class TestAsyncGeneratorJunction(AsyncioTestCase):
    def setUp(self):
        self.loop = asyncio.get_event_loop()

    async def _test_junction(self, expected, *generators):
        order = []
        async with AsyncGeneratorJunction(self.loop) as junction:
            for generator in generators:
                junction.add_generator(generator)
            async for item in junction:
                order.append(item)
        self.assertListEqual(order, expected)

    async def test_yield_order(self):
        expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2]
        fast_gen = MockAsyncGen(self.loop, 1, 0.2)
        slow_gen = MockAsyncGen(self.loop, 2, 0.4)
        await self._test_junction(expected_order, fast_gen, slow_gen)
        self.assertEqual(fast_gen.called_close, True)
        self.assertEqual(slow_gen.called_close, True)

    async def test_nothing_to_yield(self):
        async def __nothing():
            for _ in []:
                yield self.fail("nada")
        await self._test_junction([], __nothing())

    async def test_fast_iteratiors(self):
        async def __gotta_go_fast():
            for _ in range(10):
                yield 0
        await self._test_junction([0]*40, __gotta_go_fast(), __gotta_go_fast(), __gotta_go_fast(), __gotta_go_fast())

    @unittest.SkipTest
    async def test_one_stopped_first(self):
        expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]
        fast_gen = MockAsyncGen(self.loop, 1, 0.2, 5)
        slow_gen = MockAsyncGen(self.loop, 2, 0.4)
        await self._test_junction(expected_order, fast_gen, slow_gen)
        self.assertEqual(fast_gen.called_close, True)
        self.assertEqual(slow_gen.called_close, True)

    async def test_with_non_async_gen_class(self):
        expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]

        async def fast_gen():
            for i in range(10):
                if i == 5:
                    return
                await asyncio.sleep(0.2)
                yield 1

        slow_gen = MockAsyncGen(self.loop, 2, 0.4)
        await self._test_junction(expected_order, fast_gen(), slow_gen)
        self.assertEqual(slow_gen.called_close, True)

    async def test_stop_when_encapsulating_task_cancelled(self):
        fast_gen = MockAsyncGen(self.loop, 1, 0.2)
        slow_gen = MockAsyncGen(self.loop, 2, 0.4)

        async def _task():
            async with AsyncGeneratorJunction(self.loop) as junction:
                junction.add_generator(fast_gen)
                junction.add_generator(slow_gen)
                async for _ in junction:
                    pass

        task = self.loop.create_task(_task())
        self.loop.call_later(1.0, task.cancel)
        with self.assertRaises(asyncio.CancelledError):
            await task
        self.assertEqual(fast_gen.called_close, True)
        self.assertEqual(slow_gen.called_close, True)

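The file removed above provided AsyncGeneratorJunction, which merged several async generators into a single stream for iterative find requests; this PR drops that abstraction entirely. For reference, a self-contained sketch of the merging idea using tasks and a queue (illustrative only, not a replacement shipped by this PR):

import asyncio

async def merge(*generators):
    # multiplex several async generators into one stream of items
    queue = asyncio.Queue()
    sentinel = object()

    async def drain(gen):
        try:
            async for item in gen:
                await queue.put(item)
        finally:
            await queue.put(sentinel)

    tasks = [asyncio.ensure_future(drain(gen)) for gen in generators]
    finished = 0
    try:
        while finished < len(tasks):
            item = await queue.get()
            if item is sentinel:
                finished += 1
            else:
                yield item
    finally:
        for task in tasks:
            task.cancel()
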
@ -99,6 +99,7 @@ class TestProtocol(AsyncioTestCase):
             self.loop, PeerManager(self.loop), node_id, address, udp_port, tcp_port
         )
         await self.loop.create_datagram_endpoint(lambda: proto, (address, 4444))
+        proto.start()
         return proto, other_peer.peer_manager.get_kademlia_peer(node_id, address, udp_port=udp_port)
 
     async def test_add_peer_after_handle_request(self):
@ -112,6 +113,7 @@ class TestProtocol(AsyncioTestCase):
             self.loop, PeerManager(self.loop), node_id1, '1.2.3.4', 4444, 3333
         )
         await self.loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444))
+        peer1.start()
 
         peer2, peer_2_from_peer_1 = await self._make_protocol(peer1, node_id2, '1.2.3.5', 4444, 3333)
         peer3, peer_3_from_peer_1 = await self._make_protocol(peer1, node_id3, '1.2.3.6', 4444, 3333)
@ -119,6 +121,7 @@ class TestProtocol(AsyncioTestCase):
 
         # peers who reply should be added
         await peer1.get_rpc_peer(peer_2_from_peer_1).ping()
+        await asyncio.sleep(0.5)
         self.assertListEqual([peer_2_from_peer_1], peer1.routing_table.get_peers())
         peer1.routing_table.remove_peer(peer_2_from_peer_1)
 
@ -137,6 +140,7 @@ class TestProtocol(AsyncioTestCase):
         self.assertEqual(0, len(peer1.ping_queue._pending_contacts))
         pong = await peer1_from_peer4.ping()
         self.assertEqual(b'pong', pong)
+        await asyncio.sleep(0.5)
         self.assertEqual(1, len(peer1.routing_table.get_peers()))
         self.assertEqual(0, len(peer1.ping_queue._pending_contacts))
         peer1.routing_table.buckets[0].peers.clear()
@ -57,7 +57,7 @@ class TestRouting(AsyncioTestCase):
                node.protocol.node_id, node.protocol.external_ip,
                udp_port=node.protocol.udp_port
            )
-           added = await node_1.protocol.add_peer(peer)
+           added = await node_1.protocol._add_peer(peer)
            self.assertEqual(True, added)
            contact_cnt += 1
 
@ -88,7 +88,7 @@ class TestRouting(AsyncioTestCase):
        # set all of the peers to good (as to not attempt pinging stale ones during split)
        node_1.protocol.peer_manager.report_last_replied(peer.address, peer.udp_port)
        node_1.protocol.peer_manager.report_last_replied(peer.address, peer.udp_port)
-       await node_1.protocol.add_peer(peer)
+       await node_1.protocol._add_peer(peer)
        # check that bucket 0 is always the one covering the local node id
        self.assertEqual(True, node_1.protocol.routing_table.buckets[0].key_in_range(node_1.protocol.node_id))
        self.assertEqual(40, len(node_1.protocol.routing_table.get_peers()))
@ -32,7 +32,7 @@ class TestBlobAnnouncer(AsyncioTestCase):
         await n.start_listening(address)
         self.nodes.update({len(self.nodes): n})
         if add_to_routing_table:
-            await self.node.protocol.add_peer(
+            self.node.protocol.add_peer(
                 self.peer_manager.get_kademlia_peer(
                     n.protocol.node_id, n.protocol.external_ip, n.protocol.udp_port
                 )
@ -86,7 +86,8 @@ class TestBlobAnnouncer(AsyncioTestCase):
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(2, len(to_announce))
         self.blob_announcer.start(batch_size=1)  # so it covers batching logic
-        await self.advance(61.0)
+        # takes 60 seconds to start, but we advance 120 to ensure it processed all batches
+        await self.advance(60.0 * 2)
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(0, len(to_announce))
         self.blob_announcer.stop()
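The new comment explains the extra minute: the announcer is started with batch_size=1, so the two queued blobs presumably need more than one announce pass, and the test advances two 60-second periods to cover them all. A rough sketch of the batching loop being exercised (illustrative only; get_batch and announce are stand-ins, not the BlobAnnouncer API):

async def announce_pending(get_batch, announce, batch_size=1):
    # keep announcing until the storage queue is drained
    while True:
        batch = await get_batch(batch_size)
        if not batch:
            break
        for blob_hash in batch:
            await announce(blob_hash)
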
@ -98,6 +99,7 @@ class TestBlobAnnouncer(AsyncioTestCase):
         await self.chain_peer(constants.generate_id(12), '1.2.3.12')
         await self.chain_peer(constants.generate_id(13), '1.2.3.13')
         await self.chain_peer(constants.generate_id(14), '1.2.3.14')
+        await self.advance(61.0)
 
         last = self.nodes[len(self.nodes) - 1]
         search_q, peer_q = asyncio.Queue(loop=self.loop), asyncio.Queue(loop=self.loop)
@ -105,7 +107,7 @@ class TestBlobAnnouncer(AsyncioTestCase):
 
         _, task = last.accumulate_peers(search_q, peer_q)
         found_peers = await peer_q.get()
-        await task
+        task.cancel()
 
         self.assertEqual(1, len(found_peers))
         self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id)
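accumulate_peers keeps feeding results into peer_q, so the test now takes the first batch and cancels the background task instead of awaiting one that never finishes on its own. The consume-then-cancel pattern in isolation, as a sketch with a generic producer rather than the lbrynet API:

import asyncio

async def first_result(producer):
    # producer(queue) is expected to keep putting results into the queue
    queue = asyncio.Queue()
    task = asyncio.ensure_future(producer(queue))
    try:
        return await queue.get()
    finally:
        task.cancel()
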