add ability to re-join network on disconnect + tests

parent 6f06026511
commit bac7d99b8a

7 changed files with 56 additions and 22 deletions
@@ -31,8 +31,8 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.closed = asyncio.Event(loop=self.loop)
 
     def data_received(self, data: bytes):
-        log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
-                  self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
+        #log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
+        #          self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
         if not self.transport or self.transport.is_closing():
             log.warning("transport closing, but got more bytes from %s:%i\n%s", self.peer_address, self.peer_port,
                         binascii.hexlify(data))
@@ -33,7 +33,7 @@ class BlobAnnouncer:
         while batch_size:
             if not self.node.joined.is_set():
                 await self.node.joined.wait()
-            await asyncio.sleep(60)
+            await asyncio.sleep(60, loop=self.loop)
             if not self.node.protocol.routing_table.get_peers():
                 log.warning("No peers in DHT, announce round skipped")
                 continue
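A note on the one-line change above: passing loop= pins the sleep timer to the announcer's own event loop, which is what lets the announcer test further down advance a mocked clock past it. The keyword was deprecated in Python 3.8 and removed in 3.10, so this form is specific to older asyncio. A minimal sketch (wait_one_round is a hypothetical helper, not part of the project):

    import asyncio

    async def wait_one_round(loop: asyncio.AbstractEventLoop) -> None:
        # Pre-3.10 form, as in the diff: schedule the timer on an explicit loop.
        # On Python 3.10+ the kwarg is gone and the running loop is always used:
        #     await asyncio.sleep(60)
        await asyncio.sleep(60, loop=loop)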
@@ -144,9 +144,21 @@ class Node:
             KademliaPeer(self.loop, address, udp_port=port)
             for (address, port) in known_node_addresses
         ]
-        while not len(self.protocol.routing_table.get_peers()):
-            peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
+        while True:
+            if not self.protocol.routing_table.get_peers():
+                if self.joined.is_set():
+                    self.joined.clear()
+                self.protocol.peer_manager.reset()
+                self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
+                peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
+                if self.protocol.routing_table.get_peers():
+                    self.joined.set()
+                    log.info(
+                        "Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
+                        self.protocol.routing_table.buckets_with_contacts())
+            else:
+                continue
+            await asyncio.sleep(1, loop=self.loop)
 
-        log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
-                 self.protocol.routing_table.buckets_with_contacts())
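The gist of the new join loop, reduced to a standalone sketch: keep watching the routing table, and whenever it empties, clear the joined event, wipe stale peer statistics, and retry the bootstrap sequence. All names below (routing_has_peers, reset_stats, bootstrap) are stand-ins for the methods the diff actually calls:

    import asyncio

    async def keep_joined(joined: asyncio.Event, routing_has_peers, reset_stats, bootstrap) -> None:
        while True:
            if not routing_has_peers():
                if joined.is_set():
                    joined.clear()      # lost all contacts: consumers must wait again
                reset_stats()           # forget pre-disconnect failure counts
                await bootstrap()       # re-ping seeds, re-run the self search
                if routing_has_peers():
                    joined.set()        # contacts are back: we re-joined the DHT
            await asyncio.sleep(1)      # re-check the routing table every second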
@@ -186,27 +198,25 @@ class Node:
 
     async def _accumulate_search_junction(self, search_queue: asyncio.Queue,
                                           result_queue: asyncio.Queue):
-        ongoing = {}
+        tasks = []
 
         async def __start_producing_task():
             while True:
                 blob_hash = await search_queue.get()
-                ongoing[blob_hash] = asyncio.create_task(self._value_producer(blob_hash, result_queue))
-        ongoing[''] = asyncio.create_task(__start_producing_task())
+                tasks.append(asyncio.create_task(self._value_producer(blob_hash, result_queue)))
+        tasks.append(asyncio.create_task(__start_producing_task()))
         try:
-            while True:
-                await asyncio.wait(ongoing.values(), return_when='FIRST_COMPLETED')
-                for key in list(ongoing.keys())[:]:
-                    if key and ongoing[key].done():
-                        ongoing[key] = asyncio.create_task(self._value_producer(key, result_queue))
+            await asyncio.wait(tasks)
         finally:
-            for task in ongoing.values():
+            for task in tasks:
                 task.cancel()
 
     async def _value_producer(self, blob_hash: str, result_queue: asyncio.Queue):
         for interval in range(1000):
             log.info("Searching %s", blob_hash[:8])
             async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
                 result_queue.put_nowait(results)
             log.info("Search expired %s", blob_hash[:8])
             await asyncio.sleep(interval ** 2)
 
     def accumulate_peers(self, search_queue: asyncio.Queue,
                          peer_queue: typing.Optional[asyncio.Queue] = None) -> typing.Tuple[
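With the junction reduced to a single await asyncio.wait(tasks), the retry responsibility sits in _value_producer: each search round sleeps interval ** 2 seconds before the next, i.e. 0, 1, 4, 9, ... seconds. A sketch of that quadratic backoff, with do_search as a stand-in for one full iterative DHT search:

    import asyncio

    async def retrying_search(do_search, rounds: int = 1000) -> None:
        for interval in range(rounds):
            await do_search()                    # one full search pass
            await asyncio.sleep(interval ** 2)   # backoff: 0s, 1s, 4s, 9s, ...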
@@ -32,6 +32,10 @@ class PeerManager:
         self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
         self._node_tokens: typing.Dict[bytes, (float, bytes)] = {}
 
+    def reset(self):
+        for statistic in (self._rpc_failures, self._last_replied, self._last_sent, self._last_requested):
+            statistic.clear()
+
     def report_failure(self, address: str, udp_port: int):
         now = self._loop.time()
         _, previous = self._rpc_failures.pop((address, udp_port), (None, None))
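Why reset() exists: the re-join loop above calls it before re-pinging the seeds, so failure counts accumulated while the node was offline don't make those seeds look unreachable. A self-contained illustration of the same clear-all pattern (ContactStats is a stand-in, not the project's class):

    from typing import Dict, Tuple

    class ContactStats:
        def __init__(self) -> None:
            self.rpc_failures: Dict[Tuple[str, int], float] = {}
            self.last_replied: Dict[Tuple[str, int], float] = {}

        def reset(self) -> None:
            # Drop everything learned before the disconnect so stale
            # failures cannot poison the fresh join attempt.
            for statistic in (self.rpc_failures, self.last_replied):
                statistic.clear()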
@@ -275,6 +275,7 @@ class KademliaProtocol(DatagramProtocol):
         self._wakeup_routing_task = asyncio.Event(loop=self.loop)
+        self.maintaing_routing_task: typing.Optional[asyncio.Task] = None
 
     @functools.lru_cache(128)
     def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
         return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer)
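For context on the hunk above: get_rpc_peer is memoized with functools.lru_cache(128), so at most 128 RemoteKademliaRPC wrappers are kept alive and repeated lookups for the same peer return the same object. The same pattern in miniature (rpc_wrapper is hypothetical):

    import functools

    @functools.lru_cache(maxsize=128)
    def rpc_wrapper(peer_address: str) -> tuple:
        # Built once per distinct address, then served from the cache.
        return ("rpc", peer_address)

    assert rpc_wrapper("127.0.0.1:4444") is rpc_wrapper("127.0.0.1:4444")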
@@ -16,7 +16,7 @@ class DHTIntegrationTest(AsyncioTestCase):
         self.nodes = []
         self.known_node_addresses = []
 
-    async def setup_network(self, size: int, start_port=40000):
+    async def setup_network(self, size: int, start_port=40000, seed_nodes=1):
         for i in range(size):
             node_port = start_port + i
             node = Node(self.loop, PeerManager(self.loop), node_id=constants.generate_id(i),
@@ -29,7 +29,7 @@ class DHTIntegrationTest(AsyncioTestCase):
         for node in self.nodes:
             node.protocol.rpc_timeout = .2
             node.protocol.ping_queue._default_delay = .5
-            node.start('127.0.0.1', self.known_node_addresses[:1])
+            node.start('127.0.0.1', self.known_node_addresses[:seed_nodes])
         await asyncio.gather(*[node.joined.wait() for node in self.nodes])
 
     async def test_replace_bad_nodes(self):
@@ -52,6 +52,24 @@ class DHTIntegrationTest(AsyncioTestCase):
             for peer in node.protocol.routing_table.get_peers():
                 self.assertIn(peer.node_id, good_nodes)
 
+    async def test_re_join(self):
+        await self.setup_network(20, seed_nodes=10)
+        node = self.nodes[-1]
+        self.assertTrue(node.joined.is_set())
+        self.assertTrue(node.protocol.routing_table.get_peers())
+        for network_node in self.nodes[:-1]:
+            network_node.stop()
+        await node.refresh_node(True)
+        self.assertFalse(node.protocol.routing_table.get_peers())
+        for network_node in self.nodes[:-1]:
+            await network_node.start_listening('127.0.0.1')
+        self.assertFalse(node.protocol.routing_table.get_peers())
+        timeout = 20
+        while not node.protocol.routing_table.get_peers():
+            await asyncio.sleep(.1)
+            timeout -= 1
+            if not timeout:
+                self.fail("node didn't join back after 2 seconds")
+
     async def test_announce_no_peers(self):
         await self.setup_network(1)
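The new test_re_join above polls in 0.1-second steps for at most 20 iterations, i.e. the 2 seconds quoted in the failure message. An equivalent formulation as a sketch, using asyncio.wait_for against the same node object from the test:

    import asyncio

    async def wait_for_rejoin(node, timeout: float = 2.0) -> None:
        async def poll() -> None:
            while not node.protocol.routing_table.get_peers():
                await asyncio.sleep(0.1)
        # Raises asyncio.TimeoutError if the node has not re-joined in time.
        await asyncio.wait_for(poll(), timeout)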
@@ -86,7 +86,8 @@ class TestBlobAnnouncer(AsyncioTestCase):
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(2, len(to_announce))
         self.blob_announcer.start(batch_size=1)  # so it covers batching logic
-        await self.advance(61.0)
+        # takes 60 seconds to start, but we advance 120 to ensure it processed all batches
+        await self.advance(60.0 * 2)
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(0, len(to_announce))
         self.blob_announcer.stop()