add ability to re-join network on disconnect + tests

Victor Shyba 2019-05-12 03:39:11 -03:00
parent 6f06026511
commit bac7d99b8a
7 changed files with 56 additions and 22 deletions
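In short: Node.join_network previously looped only until the routing table was populated for the first time and then returned; after this change the loop keeps running, clearing the joined event and re-bootstrapping from the known nodes whenever the routing table empties, with PeerManager.reset() clearing the accumulated RPC statistics first. A minimal, self-contained sketch of that supervisor-loop pattern (TinyNode, bootstrap and the peers set below are illustrative stand-ins, not the real API):

    import asyncio

    class TinyNode:
        """Illustrative stand-in for the DHT node; not the real class."""

        def __init__(self):
            self.joined = asyncio.Event()
            self.peers = set()  # stands in for the routing table contents

        async def bootstrap(self):
            # stands in for peer_search() + ping_queue.enqueue_maybe_ping()
            await asyncio.sleep(0.1)
            self.peers.add(("127.0.0.1", 4444))

        async def join_network(self):
            # Supervisor loop: instead of joining once and returning, keep
            # polling the routing table and re-join whenever it empties.
            while True:
                if not self.peers:
                    if self.joined.is_set():
                        self.joined.clear()  # we were joined, but lost every peer
                    await self.bootstrap()
                    if self.peers:
                        self.joined.set()
                    else:
                        continue  # keep retrying until a peer answers
                await asyncio.sleep(1)  # poll once per second while healthy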

View file

@@ -31,8 +31,8 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.closed = asyncio.Event(loop=self.loop)

     def data_received(self, data: bytes):
-        log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
-                  self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
+        #log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
+        #          self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
         if not self.transport or self.transport.is_closing():
             log.warning("transport closing, but got more bytes from %s:%i\n%s", self.peer_address, self.peer_port,
                         binascii.hexlify(data))

View file

@@ -33,7 +33,7 @@ class BlobAnnouncer:
         while batch_size:
             if not self.node.joined.is_set():
                 await self.node.joined.wait()
-            await asyncio.sleep(60)
+            await asyncio.sleep(60, loop=self.loop)
             if not self.node.protocol.routing_table.get_peers():
                 log.warning("No peers in DHT, announce round skipped")
                 continue

View file

@@ -144,9 +144,21 @@ class Node:
                 KademliaPeer(self.loop, address, udp_port=port)
                 for (address, port) in known_node_addresses
             ]
-            while not len(self.protocol.routing_table.get_peers()):
-                peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
-                self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
+            while True:
+                if not self.protocol.routing_table.get_peers():
+                    if self.joined.is_set():
+                        self.joined.clear()
+                    self.protocol.peer_manager.reset()
+                    self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
+                    peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
+                    if self.protocol.routing_table.get_peers():
+                        self.joined.set()
+                        log.info(
+                            "Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
+                            self.protocol.routing_table.buckets_with_contacts())
+                    else:
+                        continue
+                await asyncio.sleep(1, loop=self.loop)

-        log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
-                 self.protocol.routing_table.buckets_with_contacts())
@@ -186,27 +198,25 @@ class Node:
     async def _accumulate_search_junction(self, search_queue: asyncio.Queue,
                                           result_queue: asyncio.Queue):
-        ongoing = {}
+        tasks = []

         async def __start_producing_task():
             while True:
                 blob_hash = await search_queue.get()
-                ongoing[blob_hash] = asyncio.create_task(self._value_producer(blob_hash, result_queue))
-        ongoing[''] = asyncio.create_task(__start_producing_task())
+                tasks.append(asyncio.create_task(self._value_producer(blob_hash, result_queue)))
+        tasks.append(asyncio.create_task(__start_producing_task()))
         try:
-            while True:
-                await asyncio.wait(ongoing.values(), return_when='FIRST_COMPLETED')
-                for key in list(ongoing.keys())[:]:
-                    if key and ongoing[key].done():
-                        ongoing[key] = asyncio.create_task(self._value_producer(key, result_queue))
+            await asyncio.wait(tasks)
         finally:
-            for task in ongoing.values():
+            for task in tasks:
                 task.cancel()

     async def _value_producer(self, blob_hash: str, result_queue: asyncio.Queue):
-        log.info("Searching %s", blob_hash[:8])
-        async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
-            result_queue.put_nowait(results)
-        log.info("Search expired %s", blob_hash[:8])
+        for interval in range(1000):
+            log.info("Searching %s", blob_hash[:8])
+            async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
+                result_queue.put_nowait(results)
+            log.info("Search expired %s", blob_hash[:8])
+            await asyncio.sleep(interval ** 2)

     def accumulate_peers(self, search_queue: asyncio.Queue,
                          peer_queue: typing.Optional[asyncio.Queue] = None) -> typing.Tuple[

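(Aside: the retry cadence the reworked _value_producer introduces is quadratic. Each pass through the for interval in range(1000) loop runs one full iterative search and then sleeps interval ** 2 seconds, i.e. 0 s, 1 s, 4 s, 9 s, ... between successive rounds, so a value that keeps expiring is retried aggressively at first and ever more rarely later.)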
View file

@@ -32,6 +32,10 @@ class PeerManager:
         self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
         self._node_tokens: typing.Dict[bytes, (float, bytes)] = {}

+    def reset(self):
+        for statistic in (self._rpc_failures, self._last_replied, self._last_sent, self._last_requested):
+            statistic.clear()
+
     def report_failure(self, address: str, udp_port: int):
         now = self._loop.time()
         _, previous = self._rpc_failures.pop((address, udp_port), (None, None))

View file

@@ -275,6 +275,7 @@ class KademliaProtocol(DatagramProtocol):
         self._wakeup_routing_task = asyncio.Event(loop=self.loop)
         self.maintaing_routing_task: typing.Optional[asyncio.Task] = None

+    @functools.lru_cache(128)
     def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
         return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer)

View file

@@ -16,7 +16,7 @@ class DHTIntegrationTest(AsyncioTestCase):
         self.nodes = []
         self.known_node_addresses = []

-    async def setup_network(self, size: int, start_port=40000):
+    async def setup_network(self, size: int, start_port=40000, seed_nodes=1):
         for i in range(size):
             node_port = start_port + i
             node = Node(self.loop, PeerManager(self.loop), node_id=constants.generate_id(i),
@@ -29,7 +29,7 @@ class DHTIntegrationTest(AsyncioTestCase):
         for node in self.nodes:
             node.protocol.rpc_timeout = .2
             node.protocol.ping_queue._default_delay = .5
-            node.start('127.0.0.1', self.known_node_addresses[:1])
+            node.start('127.0.0.1', self.known_node_addresses[:seed_nodes])
         await asyncio.gather(*[node.joined.wait() for node in self.nodes])

     async def test_replace_bad_nodes(self):
@@ -52,6 +52,24 @@ class DHTIntegrationTest(AsyncioTestCase):
             for peer in node.protocol.routing_table.get_peers():
                 self.assertIn(peer.node_id, good_nodes)

+    async def test_re_join(self):
+        await self.setup_network(20, seed_nodes=10)
+        node = self.nodes[-1]
+        self.assertTrue(node.joined.is_set())
+        self.assertTrue(node.protocol.routing_table.get_peers())
+        for network_node in self.nodes[:-1]:
+            network_node.stop()
+        await node.refresh_node(True)
+        self.assertFalse(node.protocol.routing_table.get_peers())
+        for network_node in self.nodes[:-1]:
+            await network_node.start_listening('127.0.0.1')
+        self.assertFalse(node.protocol.routing_table.get_peers())
+        timeout = 20
+        while not node.protocol.routing_table.get_peers():
+            await asyncio.sleep(.1)
+            timeout -= 1
+            if not timeout:
+                self.fail("node didnt join back after 2 seconds")
+
     async def test_announce_no_peers(self):
         await self.setup_network(1)

View file

@@ -86,7 +86,8 @@ class TestBlobAnnouncer(AsyncioTestCase):
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(2, len(to_announce))
         self.blob_announcer.start(batch_size=1)  # so it covers batching logic
-        await self.advance(61.0)
+        # takes 60 seconds to start, but we advance 120 to ensure it processed all batches
+        await self.advance(60.0 * 2)
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(0, len(to_announce))
         self.blob_announcer.stop()
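(Aside: the advance(60.0 * 2) lines up with the asyncio.sleep(60, loop=self.loop) in the announcer loop above: with batch_size=1 and two blobs queued, the announcer sleeps 60 simulated seconds before each of its two announce rounds, so advancing the test clock by 120 seconds is just enough for both batches to complete.)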