remove peer junction and some refactoring

This commit is contained in:
Victor Shyba 2019-05-11 04:58:50 -03:00
parent 0da2827c78
commit e675f1387c
7 changed files with 89 additions and 326 deletions

View file

@ -2,11 +2,8 @@ import logging
import asyncio import asyncio
import typing import typing
import binascii import binascii
import contextlib
from lbrynet.utils import resolve_host from lbrynet.utils import resolve_host
from lbrynet.dht import constants from lbrynet.dht import constants
from lbrynet.dht.error import RemoteException
from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction
from lbrynet.dht.protocol.distance import Distance from lbrynet.dht.protocol.distance import Distance
from lbrynet.dht.protocol.iterative_find import IterativeNodeFinder, IterativeValueFinder from lbrynet.dht.protocol.iterative_find import IterativeNodeFinder, IterativeValueFinder
from lbrynet.dht.protocol.protocol import KademliaProtocol from lbrynet.dht.protocol.protocol import KademliaProtocol
@ -138,28 +135,13 @@ class Node:
url_to_addr[address] = host url_to_addr[address] = host
if known_node_addresses: if known_node_addresses:
while not self.protocol.routing_table.get_peers(): peers = [
success = False KademliaPeer(self.loop, address, udp_port=port)
# ping the seed nodes, this will set their node ids (since we don't know them ahead of time) for (address, port) in known_node_addresses
for address, port in known_node_addresses: ]
peer = self.protocol.get_rpc_peer(KademliaPeer(self.loop, address, udp_port=port)) while not len(self.protocol.routing_table.get_peers()):
try: peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
await peer.ping() self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
success = True
except asyncio.TimeoutError:
log.warning("seed node (%s:%i) timed out in %s", url_to_addr.get(address, address), port,
round(self.protocol.rpc_timeout, 2))
if success:
break
# now that we have the seed nodes in routing, to an iterative lookup of our own id to populate the buckets
# in the routing table with good peers who are near us
async with self.peer_search_junction(self.protocol.node_id, max_results=16) as junction:
async for peers in junction:
for peer in peers:
try:
await self.protocol.get_rpc_peer(peer).ping()
except (asyncio.TimeoutError, RemoteException):
pass
log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()), log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
self.protocol.routing_table.buckets_with_contacts()) self.protocol.routing_table.buckets_with_contacts())
@ -186,61 +168,40 @@ class Node:
return IterativeValueFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol, return IterativeValueFinder(self.loop, self.protocol.peer_manager, self.protocol.routing_table, self.protocol,
key, bottom_out_limit, max_results, None, shortlist) key, bottom_out_limit, max_results, None, shortlist)
@contextlib.asynccontextmanager
async def stream_peer_search_junction(self, hash_queue: asyncio.Queue, bottom_out_limit=20,
max_results=-1) -> AsyncGeneratorJunction:
peer_generator = AsyncGeneratorJunction(self.loop)
async def _add_hashes_from_queue():
while True:
blob_hash = await hash_queue.get()
peer_generator.add_generator(
self.get_iterative_value_finder(
binascii.unhexlify(blob_hash.encode()), bottom_out_limit=bottom_out_limit,
max_results=max_results
)
)
add_hashes_task = self.loop.create_task(_add_hashes_from_queue())
try:
async with peer_generator as junction:
yield junction
finally:
if add_hashes_task and not (add_hashes_task.done() or add_hashes_task.cancelled()):
add_hashes_task.cancel()
def peer_search_junction(self, node_id: bytes, max_results=constants.k*2,
bottom_out_limit=20) -> AsyncGeneratorJunction:
peer_generator = AsyncGeneratorJunction(self.loop)
peer_generator.add_generator(
self.get_iterative_node_finder(
node_id, bottom_out_limit=bottom_out_limit, max_results=max_results
)
)
return peer_generator
async def peer_search(self, node_id: bytes, count=constants.k, max_results=constants.k*2, async def peer_search(self, node_id: bytes, count=constants.k, max_results=constants.k*2,
bottom_out_limit=20) -> typing.List['KademliaPeer']: bottom_out_limit=20, shortlist: typing.Optional[typing.List] = None
accumulated: typing.List['KademliaPeer'] = [] ) -> typing.List['KademliaPeer']:
async with self.peer_search_junction(node_id, max_results=max_results, peers = []
bottom_out_limit=bottom_out_limit) as junction: async for iteration_peers in self.get_iterative_node_finder(
async for peers in junction: node_id, shortlist=shortlist, bottom_out_limit=bottom_out_limit, max_results=max_results):
accumulated.extend(peers) peers.extend(iteration_peers)
distance = Distance(node_id) distance = Distance(node_id)
accumulated.sort(key=lambda peer: distance(peer.node_id)) peers.sort(key=lambda peer: distance(peer.node_id))
return accumulated[:count] return peers[:count]
async def _accumulate_search_junction(self, search_queue: asyncio.Queue, async def _accumulate_search_junction(self, search_queue: asyncio.Queue,
result_queue: asyncio.Queue): result_queue: asyncio.Queue):
async with self.stream_peer_search_junction(search_queue) as search_junction: # pylint: disable=E1701 ongoing = {}
async for peers in search_junction: async def __start_producing_task():
if peers: while True:
result_queue.put_nowait([ blob_hash = await search_queue.get()
peer for peer in peers ongoing[blob_hash] = asyncio.create_task(self._value_producer(blob_hash, result_queue))
if not ( ongoing[''] = asyncio.create_task(__start_producing_task())
peer.address == self.protocol.external_ip try:
and peer.tcp_port == self.protocol.peer_port while True:
) await asyncio.wait(ongoing.values(), return_when='FIRST_COMPLETED')
]) for key in list(ongoing.keys())[:]:
if key and ongoing[key].done():
ongoing[key] = asyncio.create_task(self._value_producer(key, result_queue))
finally:
for task in ongoing.values():
task.cancel()
async def _value_producer(self, blob_hash: str, result_queue: asyncio.Queue):
log.info("Searching %s", blob_hash[:8])
async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
result_queue.put_nowait(results)
log.info("Search expired %s", blob_hash[:8])
def accumulate_peers(self, search_queue: asyncio.Queue, def accumulate_peers(self, search_queue: asyncio.Queue,
peer_queue: typing.Optional[asyncio.Queue] = None) -> typing.Tuple[ peer_queue: typing.Optional[asyncio.Queue] = None) -> typing.Tuple[

View file

@ -1,94 +0,0 @@
import asyncio
import typing
import logging
import traceback
if typing.TYPE_CHECKING:
from types import AsyncGeneratorType
log = logging.getLogger(__name__)
def cancel_task(task: typing.Optional[asyncio.Task]):
if task and not (task.done() or task.cancelled()):
task.cancel()
def drain_tasks(tasks: typing.List[typing.Optional[asyncio.Task]]):
while tasks:
cancel_task(tasks.pop())
class AsyncGeneratorJunction:
"""
A helper to interleave the results from multiple async generators into one
async generator.
"""
def __init__(self, loop: asyncio.BaseEventLoop, queue: typing.Optional[asyncio.Queue] = None):
self.loop = loop
self.__iterator_queue = asyncio.Queue(loop=loop)
self.result_queue = queue or asyncio.Queue(loop=loop)
self.tasks: typing.List[asyncio.Task] = []
self.running_iterators: typing.Dict[typing.AsyncGenerator, bool] = {}
self.generator_queue: asyncio.Queue = asyncio.Queue(loop=self.loop)
@property
def running(self):
return any(self.running_iterators.values())
async def wait_for_generators(self):
async def iterate(iterator: typing.AsyncGenerator):
try:
async for item in iterator:
self.result_queue.put_nowait(item)
self.__iterator_queue.put_nowait(item)
finally:
self.running_iterators[iterator] = False
if not self.running:
self.__iterator_queue.put_nowait(StopAsyncIteration)
while True:
async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType'] = await self.generator_queue.get()
self.running_iterators[async_gen] = True
self.tasks.append(self.loop.create_task(iterate(async_gen)))
def add_generator(self, async_gen: typing.Union[typing.AsyncGenerator, 'AsyncGeneratorType']):
"""
Add an async generator. This can be called during an iteration of the generator junction.
"""
self.generator_queue.put_nowait(async_gen)
def __aiter__(self):
return self
async def __anext__(self):
result = await self.__iterator_queue.get()
if result is StopAsyncIteration:
raise result
return result
def aclose(self):
async def _aclose():
for iterator in list(self.running_iterators.keys()):
result = iterator.aclose()
if asyncio.iscoroutine(result):
await result
self.running_iterators[iterator] = False
drain_tasks(self.tasks)
raise StopAsyncIteration()
return self.loop.create_task(_aclose())
async def __aenter__(self):
self.tasks.append(self.loop.create_task(self.wait_for_generators()))
return self
async def __aexit__(self, exc_type, exc, tb):
try:
await self.aclose()
except StopAsyncIteration:
pass
finally:
if exc_type:
if exc_type not in (asyncio.CancelledError, asyncio.TimeoutError, StopAsyncIteration, GeneratorExit):
err = traceback.format_exception(exc_type, exc, tb)
log.error(err)

View file

@ -66,11 +66,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
""" """
if len(key) != constants.hash_length: if len(key) != constants.hash_length:
raise ValueError("invalid key length: %i" % len(key)) raise ValueError("invalid key length: %i" % len(key))
if not shortlist: return shortlist or routing_table.find_close_peers(key)
shortlist = routing_table.find_close_peers(key)
distance = Distance(key)
shortlist.sort(key=lambda peer: distance(peer.node_id), reverse=True)
return shortlist
class IterativeFinder: class IterativeFinder:
@ -92,11 +88,11 @@ class IterativeFinder:
self.exclude = exclude or [] self.exclude = exclude or []
self.shortlist: typing.List['KademliaPeer'] = get_shortlist(routing_table, key, shortlist) self.shortlist: typing.List['KademliaPeer'] = get_shortlist(routing_table, key, shortlist)
self.active: typing.Set['KademliaPeer'] = set() self.active: typing.List['KademliaPeer'] = []
self.contacted: typing.Set[typing.Tuple[str, int]] = set() self.contacted: typing.Set[typing.Tuple[str, int]] = set()
self.distance = Distance(key) self.distance = Distance(key)
self.closest_peer: typing.Optional['KademliaPeer'] = None if not self.shortlist else self.shortlist[0] self.closest_peer: typing.Optional['KademliaPeer'] = None
self.prev_closest_peer: typing.Optional['KademliaPeer'] = None self.prev_closest_peer: typing.Optional['KademliaPeer'] = None
self.iteration_queue = asyncio.Queue(loop=self.loop) self.iteration_queue = asyncio.Queue(loop=self.loop)
@ -139,23 +135,21 @@ class IterativeFinder:
return not self.closest_peer or self.distance.is_closer(peer.node_id, self.closest_peer.node_id) return not self.closest_peer or self.distance.is_closer(peer.node_id, self.closest_peer.node_id)
def _update_closest(self): def _update_closest(self):
self.shortlist.sort(key=lambda peer: self.distance(peer.node_id), reverse=True) self.active.sort(key=lambda peer: self.distance(peer.node_id))
if self.closest_peer and self.closest_peer is not self.shortlist[-1]: if self.closest_peer and self.closest_peer is not self.active[0]:
if self._is_closer(self.shortlist[-1]): if self._is_closer(self.active[0]):
self.prev_closest_peer = self.closest_peer self.prev_closest_peer = self.closest_peer
self.closest_peer = self.shortlist[-1] self.closest_peer = self.active[0]
async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse): async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindResponse):
if peer not in self.shortlist: if peer not in self.active and peer.node_id:
self.shortlist.append(peer) self.active.append(peer)
if peer not in self.active:
self.active.add(peer)
for contact_triple in response.get_close_triples(): for contact_triple in response.get_close_triples():
node_id, address, udp_port = contact_triple node_id, address, udp_port = contact_triple
if (address, udp_port) not in self.contacted: # and not self.peer_manager.is_ignored(addr_tuple) if (address, udp_port) not in self.contacted: # and not self.peer_manager.is_ignored(addr_tuple)
found_peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port) found_peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port)
if found_peer not in self.shortlist and self.peer_manager.peer_is_good(peer) is not False: if found_peer not in self.active and self.peer_manager.peer_is_good(found_peer) is not False:
self.shortlist.append(found_peer) self.active.append(found_peer)
self._update_closest() self._update_closest()
self.check_result_ready(response) self.check_result_ready(response)
@ -163,11 +157,13 @@ class IterativeFinder:
try: try:
response = await self.send_probe(peer) response = await self.send_probe(peer)
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.active.discard(peer) if peer in self.active:
self.active.remove(peer)
return return
except ValueError as err: except ValueError as err:
log.warning(str(err)) log.warning(str(err))
self.active.discard(peer) if peer in self.active:
self.active.remove(peer)
return return
except TransportNotConnected: except TransportNotConnected:
return self.aclose() return self.aclose()
@ -181,18 +177,18 @@ class IterativeFinder:
""" """
added = 0 added = 0
self.shortlist.sort(key=lambda p: self.distance(p.node_id), reverse=True) for peer in chain(self.active, self.shortlist):
while self.running and len(self.shortlist) and added < constants.alpha: if added >= constants.alpha:
peer = self.shortlist.pop() break
origin_address = (peer.address, peer.udp_port) origin_address = (peer.address, peer.udp_port)
if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False: if origin_address in self.exclude or self.peer_manager.peer_is_good(peer) is False:
continue continue
if peer.node_id == self.protocol.node_id: if peer.node_id == self.protocol.node_id:
continue continue
if (peer.address, peer.udp_port) == (self.protocol.external_ip, self.protocol.udp_port): if origin_address == (self.protocol.external_ip, self.protocol.udp_port):
continue continue
if (peer.address, peer.udp_port) not in self.contacted: if origin_address not in self.contacted:
self.contacted.add((peer.address, peer.udp_port)) self.contacted.add(origin_address)
t = self.loop.create_task(self._send_probe(peer)) t = self.loop.create_task(self._send_probe(peer))
@ -200,7 +196,7 @@ class IterativeFinder:
self.running_probes.difference_update({ self.running_probes.difference_update({
probe for probe in self.running_probes if probe.done() or probe == t probe for probe in self.running_probes if probe.done() or probe == t
}) })
if not self.running_probes and self.shortlist: if not self.running_probes:
self.tasks.append(self.loop.create_task(self._search_task(0.0))) self.tasks.append(self.loop.create_task(self._search_task(0.0)))
t.add_done_callback(callback) t.add_done_callback(callback)
@ -266,6 +262,7 @@ class IterativeNodeFinder(IterativeFinder):
self.yielded_peers: typing.Set['KademliaPeer'] = set() self.yielded_peers: typing.Set['KademliaPeer'] = set()
async def send_probe(self, peer: 'KademliaPeer') -> FindNodeResponse: async def send_probe(self, peer: 'KademliaPeer') -> FindNodeResponse:
log.debug("probing %s:%d %s", peer.address, peer.udp_port, hexlify(peer.node_id)[:8] if peer.node_id else '')
response = await self.protocol.get_rpc_peer(peer).find_node(self.key) response = await self.protocol.get_rpc_peer(peer).find_node(self.key)
return FindNodeResponse(self.key, response) return FindNodeResponse(self.key, response)
@ -273,7 +270,9 @@ class IterativeNodeFinder(IterativeFinder):
self.put_result(self.active, finish=True) self.put_result(self.active, finish=True)
def put_result(self, from_iter: typing.Iterable['KademliaPeer'], finish=False): def put_result(self, from_iter: typing.Iterable['KademliaPeer'], finish=False):
not_yet_yielded = [peer for peer in from_iter if peer not in self.yielded_peers] not_yet_yielded = [
peer for peer in from_iter if peer not in self.yielded_peers and peer.node_id != self.protocol.node_id
]
not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id)) not_yet_yielded.sort(key=lambda peer: self.distance(peer.node_id))
to_yield = not_yet_yielded[:min(constants.k, len(not_yet_yielded))] to_yield = not_yet_yielded[:min(constants.k, len(not_yet_yielded))]
if to_yield: if to_yield:
@ -288,7 +287,7 @@ class IterativeNodeFinder(IterativeFinder):
if found: if found:
log.debug("found") log.debug("found")
return self.put_result(self.shortlist, finish=True) return self.put_result(self.active, finish=True)
if self.prev_closest_peer and self.closest_peer and not self._is_closer(self.prev_closest_peer): if self.prev_closest_peer and self.closest_peer and not self._is_closer(self.prev_closest_peer):
# log.info("improving, %i %i %i %i %i", len(self.shortlist), len(self.active), len(self.contacted), # log.info("improving, %i %i %i %i %i", len(self.shortlist), len(self.active), len(self.contacted),
# self.bottom_out_count, self.iteration_count) # self.bottom_out_count, self.iteration_count)
@ -300,9 +299,6 @@ class IterativeNodeFinder(IterativeFinder):
if self.bottom_out_count >= self.bottom_out_limit or self.iteration_count >= self.bottom_out_limit: if self.bottom_out_count >= self.bottom_out_limit or self.iteration_count >= self.bottom_out_limit:
log.info("limit hit") log.info("limit hit")
self.put_result(self.active, True) self.put_result(self.active, True)
elif self.max_results and len(self.active) - len(self.yielded_peers) >= self.max_results:
log.debug("max results")
self.put_result(self.active, True)
class IterativeValueFinder(IterativeFinder): class IterativeValueFinder(IterativeFinder):

View file

@ -270,13 +270,14 @@ class KademliaProtocol(DatagramProtocol):
self._split_lock = asyncio.Lock(loop=self.loop) self._split_lock = asyncio.Lock(loop=self.loop)
self._to_remove: typing.Set['KademliaPeer'] = set() self._to_remove: typing.Set['KademliaPeer'] = set()
self._to_add: typing.Set['KademliaPeer'] = set() self._to_add: typing.Set['KademliaPeer'] = set()
self._wakeup_routing_task = asyncio.Event(loop=self.loop)
self.maintaing_routing_task: typing.Optional[asyncio.Task] = None self.maintaing_routing_task: typing.Optional[asyncio.Task] = None
def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC: def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer) return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer)
def start(self, force_delay=None): def start(self):
self.maintaing_routing_task = asyncio.create_task(self.routing_table_task(force_delay)) self.maintaing_routing_task = asyncio.create_task(self.routing_table_task())
def stop(self): def stop(self):
if self.maintaing_routing_task: if self.maintaing_routing_task:
@ -376,8 +377,9 @@ class KademliaProtocol(DatagramProtocol):
if peer.node_id == self.node_id: if peer.node_id == self.node_id:
return False return False
self._to_add.add(peer) self._to_add.add(peer)
self._wakeup_routing_task.set()
async def routing_table_task(self, force_delay=None): async def routing_table_task(self):
while True: while True:
while self._to_remove: while self._to_remove:
async with self._split_lock: async with self._split_lock:
@ -388,9 +390,10 @@ class KademliaProtocol(DatagramProtocol):
while self._to_add: while self._to_add:
async with self._split_lock: async with self._split_lock:
await self._add_peer(self._to_add.pop()) await self._add_peer(self._to_add.pop())
await asyncio.sleep(force_delay or constants.rpc_timeout) await asyncio.gather(self._wakeup_routing_task.wait(), asyncio.sleep(0.2))
self._wakeup_routing_task.clear()
async def _handle_rpc(self, sender_contact: 'KademliaPeer', message: RequestDatagram): def _handle_rpc(self, sender_contact: 'KademliaPeer', message: RequestDatagram):
assert sender_contact.node_id != self.node_id, (binascii.hexlify(sender_contact.node_id)[:8].decode(), assert sender_contact.node_id != self.node_id, (binascii.hexlify(sender_contact.node_id)[:8].decode(),
binascii.hexlify(self.node_id)[:8].decode()) binascii.hexlify(self.node_id)[:8].decode())
method = message.method method = message.method
@ -417,11 +420,11 @@ class KademliaProtocol(DatagramProtocol):
key, = a key, = a
result = self.node_rpc.find_value(sender_contact, key) result = self.node_rpc.find_value(sender_contact, key)
await self.send_response( self.send_response(
sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result), sender_contact, ResponseDatagram(RESPONSE_TYPE, message.rpc_id, self.node_id, result),
) )
async def handle_request_datagram(self, address: typing.Tuple[str, int], request_datagram: RequestDatagram): def handle_request_datagram(self, address: typing.Tuple[str, int], request_datagram: RequestDatagram):
# This is an RPC method request # This is an RPC method request
self.peer_manager.report_last_requested(address[0], address[1]) self.peer_manager.report_last_requested(address[0], address[1])
try: try:
@ -429,7 +432,7 @@ class KademliaProtocol(DatagramProtocol):
except IndexError: except IndexError:
peer = self.peer_manager.get_kademlia_peer(request_datagram.node_id, address[0], address[1]) peer = self.peer_manager.get_kademlia_peer(request_datagram.node_id, address[0], address[1])
try: try:
await self._handle_rpc(peer, request_datagram) self._handle_rpc(peer, request_datagram)
# if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it # if the contact is not known to be bad (yet) and we haven't yet queried it, send it a ping so that it
# will be added to our routing table if successful # will be added to our routing table if successful
is_good = self.peer_manager.peer_is_good(peer) is_good = self.peer_manager.peer_is_good(peer)
@ -442,7 +445,7 @@ class KademliaProtocol(DatagramProtocol):
log.debug("error raised handling %s request from %s:%i - %s(%s)", log.debug("error raised handling %s request from %s:%i - %s(%s)",
request_datagram.method, peer.address, peer.udp_port, str(type(err)), request_datagram.method, peer.address, peer.udp_port, str(type(err)),
str(err)) str(err))
await self.send_error( self.send_error(
peer, peer,
ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(), ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(),
str(err).encode()) str(err).encode())
@ -451,13 +454,13 @@ class KademliaProtocol(DatagramProtocol):
log.warning("error raised handling %s request from %s:%i - %s(%s)", log.warning("error raised handling %s request from %s:%i - %s(%s)",
request_datagram.method, peer.address, peer.udp_port, str(type(err)), request_datagram.method, peer.address, peer.udp_port, str(type(err)),
str(err)) str(err))
await self.send_error( self.send_error(
peer, peer,
ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(), ErrorDatagram(ERROR_TYPE, request_datagram.rpc_id, self.node_id, str(type(err)).encode(),
str(err).encode()) str(err).encode())
) )
async def handle_response_datagram(self, address: typing.Tuple[str, int], response_datagram: ResponseDatagram): def handle_response_datagram(self, address: typing.Tuple[str, int], response_datagram: ResponseDatagram):
# Find the message that triggered this response # Find the message that triggered this response
if response_datagram.rpc_id in self.sent_messages: if response_datagram.rpc_id in self.sent_messages:
peer, df, request = self.sent_messages[response_datagram.rpc_id] peer, df, request = self.sent_messages[response_datagram.rpc_id]
@ -531,15 +534,15 @@ class KademliaProtocol(DatagramProtocol):
return return
if isinstance(message, RequestDatagram): if isinstance(message, RequestDatagram):
self.loop.create_task(self.handle_request_datagram(address, message)) self.handle_request_datagram(address, message)
elif isinstance(message, ErrorDatagram): elif isinstance(message, ErrorDatagram):
self.handle_error_datagram(address, message) self.handle_error_datagram(address, message)
else: else:
assert isinstance(message, ResponseDatagram), "sanity" assert isinstance(message, ResponseDatagram), "sanity"
self.loop.create_task(self.handle_response_datagram(address, message)) self.handle_response_datagram(address, message)
async def send_request(self, peer: 'KademliaPeer', request: RequestDatagram) -> ResponseDatagram: async def send_request(self, peer: 'KademliaPeer', request: RequestDatagram) -> ResponseDatagram:
await self._send(peer, request) self._send(peer, request)
response_fut = self.sent_messages[request.rpc_id][1] response_fut = self.sent_messages[request.rpc_id][1]
try: try:
response = await asyncio.wait_for(response_fut, self.rpc_timeout) response = await asyncio.wait_for(response_fut, self.rpc_timeout)
@ -553,15 +556,16 @@ class KademliaProtocol(DatagramProtocol):
self.peer_manager.report_failure(peer.address, peer.udp_port) self.peer_manager.report_failure(peer.address, peer.udp_port)
if self.peer_manager.peer_is_good(peer) is False: if self.peer_manager.peer_is_good(peer) is False:
self._to_remove.add(peer) self._to_remove.add(peer)
self._wakeup_routing_task.set()
raise raise
async def send_response(self, peer: 'KademliaPeer', response: ResponseDatagram): def send_response(self, peer: 'KademliaPeer', response: ResponseDatagram):
await self._send(peer, response) self._send(peer, response)
async def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram): def send_error(self, peer: 'KademliaPeer', error: ErrorDatagram):
await self._send(peer, error) self._send(peer, error)
async def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram, def _send(self, peer: 'KademliaPeer', message: typing.Union[RequestDatagram, ResponseDatagram,
ErrorDatagram]): ErrorDatagram]):
if not self.transport or self.transport.is_closing(): if not self.transport or self.transport.is_closing():
raise TransportNotConnected() raise TransportNotConnected()

View file

@ -1,102 +0,0 @@
import unittest
import asyncio
from torba.testcase import AsyncioTestCase
from lbrynet.dht.protocol.async_generator_junction import AsyncGeneratorJunction
class MockAsyncGen:
def __init__(self, loop, result, delay, stop_cnt=10):
self.loop = loop
self.result = result
self.delay = delay
self.count = 0
self.stop_cnt = stop_cnt
self.called_close = False
def __aiter__(self):
return self
async def __anext__(self):
await asyncio.sleep(self.delay, loop=self.loop)
if self.count > self.stop_cnt - 1:
raise StopAsyncIteration()
self.count += 1
return self.result
async def aclose(self):
self.called_close = True
class TestAsyncGeneratorJunction(AsyncioTestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
async def _test_junction(self, expected, *generators):
order = []
async with AsyncGeneratorJunction(self.loop) as junction:
for generator in generators:
junction.add_generator(generator)
async for item in junction:
order.append(item)
self.assertListEqual(order, expected)
async def test_yield_order(self):
expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2]
fast_gen = MockAsyncGen(self.loop, 1, 0.2)
slow_gen = MockAsyncGen(self.loop, 2, 0.4)
await self._test_junction(expected_order, fast_gen, slow_gen)
self.assertEqual(fast_gen.called_close, True)
self.assertEqual(slow_gen.called_close, True)
async def test_nothing_to_yield(self):
async def __nothing():
for _ in []:
yield self.fail("nada")
await self._test_junction([], __nothing())
async def test_fast_iteratiors(self):
async def __gotta_go_fast():
for _ in range(10):
yield 0
await self._test_junction([0]*40, __gotta_go_fast(), __gotta_go_fast(), __gotta_go_fast(), __gotta_go_fast())
@unittest.SkipTest
async def test_one_stopped_first(self):
expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]
fast_gen = MockAsyncGen(self.loop, 1, 0.2, 5)
slow_gen = MockAsyncGen(self.loop, 2, 0.4)
await self._test_junction(expected_order, fast_gen, slow_gen)
self.assertEqual(fast_gen.called_close, True)
self.assertEqual(slow_gen.called_close, True)
async def test_with_non_async_gen_class(self):
expected_order = [1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]
async def fast_gen():
for i in range(10):
if i == 5:
return
await asyncio.sleep(0.2)
yield 1
slow_gen = MockAsyncGen(self.loop, 2, 0.4)
await self._test_junction(expected_order, fast_gen(), slow_gen)
self.assertEqual(slow_gen.called_close, True)
async def test_stop_when_encapsulating_task_cancelled(self):
fast_gen = MockAsyncGen(self.loop, 1, 0.2)
slow_gen = MockAsyncGen(self.loop, 2, 0.4)
async def _task():
async with AsyncGeneratorJunction(self.loop) as junction:
junction.add_generator(fast_gen)
junction.add_generator(slow_gen)
async for _ in junction:
pass
task = self.loop.create_task(_task())
self.loop.call_later(1.0, task.cancel)
with self.assertRaises(asyncio.CancelledError):
await task
self.assertEqual(fast_gen.called_close, True)
self.assertEqual(slow_gen.called_close, True)

View file

@ -99,7 +99,7 @@ class TestProtocol(AsyncioTestCase):
self.loop, PeerManager(self.loop), node_id, address, udp_port, tcp_port self.loop, PeerManager(self.loop), node_id, address, udp_port, tcp_port
) )
await self.loop.create_datagram_endpoint(lambda: proto, (address, 4444)) await self.loop.create_datagram_endpoint(lambda: proto, (address, 4444))
proto.start(0.1) proto.start()
return proto, other_peer.peer_manager.get_kademlia_peer(node_id, address, udp_port=udp_port) return proto, other_peer.peer_manager.get_kademlia_peer(node_id, address, udp_port=udp_port)
async def test_add_peer_after_handle_request(self): async def test_add_peer_after_handle_request(self):
@ -113,7 +113,7 @@ class TestProtocol(AsyncioTestCase):
self.loop, PeerManager(self.loop), node_id1, '1.2.3.4', 4444, 3333 self.loop, PeerManager(self.loop), node_id1, '1.2.3.4', 4444, 3333
) )
await self.loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444)) await self.loop.create_datagram_endpoint(lambda: peer1, ('1.2.3.4', 4444))
peer1.start(0.1) peer1.start()
peer2, peer_2_from_peer_1 = await self._make_protocol(peer1, node_id2, '1.2.3.5', 4444, 3333) peer2, peer_2_from_peer_1 = await self._make_protocol(peer1, node_id2, '1.2.3.5', 4444, 3333)
peer3, peer_3_from_peer_1 = await self._make_protocol(peer1, node_id3, '1.2.3.6', 4444, 3333) peer3, peer_3_from_peer_1 = await self._make_protocol(peer1, node_id3, '1.2.3.6', 4444, 3333)

View file

@ -21,7 +21,6 @@ class TestBlobAnnouncer(AsyncioTestCase):
await self.storage.open() await self.storage.open()
self.peer_manager = PeerManager(self.loop) self.peer_manager = PeerManager(self.loop)
self.node = Node(self.loop, self.peer_manager, node_id, 4444, 4444, 3333, address) self.node = Node(self.loop, self.peer_manager, node_id, 4444, 4444, 3333, address)
self.node.protocol.start(0.1)
await self.node.start_listening(address) await self.node.start_listening(address)
self.blob_announcer = BlobAnnouncer(self.loop, self.node, self.storage) self.blob_announcer = BlobAnnouncer(self.loop, self.node, self.storage)
for node_id, address in peer_addresses: for node_id, address in peer_addresses:
@ -31,7 +30,6 @@ class TestBlobAnnouncer(AsyncioTestCase):
async def add_peer(self, node_id, address, add_to_routing_table=True): async def add_peer(self, node_id, address, add_to_routing_table=True):
n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address) n = Node(self.loop, PeerManager(self.loop), node_id, 4444, 4444, 3333, address)
await n.start_listening(address) await n.start_listening(address)
n.protocol.start(0.1)
self.nodes.update({len(self.nodes): n}) self.nodes.update({len(self.nodes): n})
if add_to_routing_table: if add_to_routing_table:
self.node.protocol.add_peer( self.node.protocol.add_peer(
@ -108,7 +106,7 @@ class TestBlobAnnouncer(AsyncioTestCase):
_, task = last.accumulate_peers(search_q, peer_q) _, task = last.accumulate_peers(search_q, peer_q)
found_peers = await peer_q.get() found_peers = await peer_q.get()
await task task.cancel()
self.assertEqual(1, len(found_peers)) self.assertEqual(1, len(found_peers))
self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id) self.assertEqual(self.node.protocol.node_id, found_peers[0].node_id)