forked from LBRYCommunity/lbry-sdk

commit bac7d99b8a (parent 6f06026511)
add ability to re-join network on disconnect + tests

7 changed files with 56 additions and 22 deletions
@@ -31,8 +31,8 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.closed = asyncio.Event(loop=self.loop)
 
     def data_received(self, data: bytes):
-        log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
-                  self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
+        #log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
+        # self.peer_address, self.peer_port, len(data), len(self.buf), self._blob_bytes_received)
         if not self.transport or self.transport.is_closing():
             log.warning("transport closing, but got more bytes from %s:%i\n%s", self.peer_address, self.peer_port,
                         binascii.hexlify(data))
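The change here only comments out the per-datagram log.debug() in data_received(); the call sat in the blob client's hottest path and evaluated len() on every packet and on the receive buffer each time. A minimal sketch of an alternative that keeps the line cheap (my illustration, with a hypothetical log_received() helper, not code from this commit):

import logging

log = logging.getLogger(__name__)

def log_received(peer_address: str, peer_port: int, data: bytes, buf: bytes, blob_bytes_received: int):
    # isEnabledFor() skips argument evaluation entirely when DEBUG is off
    if log.isEnabledFor(logging.DEBUG):
        log.debug("%s:%d -- got %s bytes -- %s bytes on buffer -- %s blob bytes received",
                  peer_address, peer_port, len(data), len(buf), blob_bytes_received)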
@@ -33,7 +33,7 @@ class BlobAnnouncer:
         while batch_size:
             if not self.node.joined.is_set():
                 await self.node.joined.wait()
-            await asyncio.sleep(60)
+            await asyncio.sleep(60, loop=self.loop)
             if not self.node.protocol.routing_table.get_peers():
                 log.warning("No peers in DHT, announce round skipped")
                 continue
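The only change is passing the announcer's own event loop to asyncio.sleep(). This is Python 3.7-era style (the loop argument was deprecated in 3.8 and removed in 3.10) and makes the timer's scheduling loop explicit, consistent with the loop= usage elsewhere in this diff. A sketch of the idiom, with announce_wait() as an assumed stand-in for the announcer's loop body:

import asyncio

async def announce_wait(loop: asyncio.AbstractEventLoop):
    # the timer is registered on `loop` explicitly rather than on whatever
    # get_event_loop() returns at call time
    await asyncio.sleep(60, loop=loop)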
@@ -144,9 +144,21 @@ class Node:
                 KademliaPeer(self.loop, address, udp_port=port)
                 for (address, port) in known_node_addresses
             ]
-            while not len(self.protocol.routing_table.get_peers()):
-                peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
-                self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
+            while True:
+                if not self.protocol.routing_table.get_peers():
+                    if self.joined.is_set():
+                        self.joined.clear()
+                    self.protocol.peer_manager.reset()
+                    self.protocol.ping_queue.enqueue_maybe_ping(*peers, delay=0.0)
+                    peers.extend(await self.peer_search(self.protocol.node_id, shortlist=peers, count=32))
+                    if self.protocol.routing_table.get_peers():
+                        self.joined.set()
+                        log.info(
+                            "Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
+                            self.protocol.routing_table.buckets_with_contacts())
+                    else:
+                        continue
+                await asyncio.sleep(1, loop=self.loop)
 
         log.info("Joined DHT, %i peers known in %i buckets", len(self.protocol.routing_table.get_peers()),
                  self.protocol.routing_table.buckets_with_contacts())
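This is the core of the commit. Previously join_network() looped only until the routing table had at least one peer, then logged "Joined DHT" and never re-checked, so a node that later lost all of its peers stayed nominally joined forever. The loop now runs for the life of the node: once a second it checks the routing table, and whenever the table is empty it clears the joined event (which blocks the BlobAnnouncer shown above), resets the peer manager's RPC statistics, re-pings the known seed peers, and runs a self-search, setting joined again only once contacts reappear. A distilled, standalone sketch of the pattern (stay_joined, routing_table_size and bootstrap are placeholder names, not the SDK's API):

import asyncio

async def stay_joined(joined: asyncio.Event, routing_table_size, bootstrap):
    while True:
        if not routing_table_size():
            if joined.is_set():
                joined.clear()      # consumers awaiting joined.wait() now block
            await bootstrap()       # re-ping seeds and self-search
            if routing_table_size():
                joined.set()        # back in the network
            else:
                continue            # keep retrying without the 1 s pause
        await asyncio.sleep(1)      # poll once per second while healthy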
@@ -186,27 +198,25 @@ class Node:
 
     async def _accumulate_search_junction(self, search_queue: asyncio.Queue,
                                           result_queue: asyncio.Queue):
-        ongoing = {}
+        tasks = []
         async def __start_producing_task():
             while True:
                 blob_hash = await search_queue.get()
-                ongoing[blob_hash] = asyncio.create_task(self._value_producer(blob_hash, result_queue))
-        ongoing[''] = asyncio.create_task(__start_producing_task())
+                tasks.append(asyncio.create_task(self._value_producer(blob_hash, result_queue)))
+        tasks.append(asyncio.create_task(__start_producing_task()))
         try:
-            while True:
-                await asyncio.wait(ongoing.values(), return_when='FIRST_COMPLETED')
-                for key in list(ongoing.keys())[:]:
-                    if key and ongoing[key].done():
-                        ongoing[key] = asyncio.create_task(self._value_producer(key, result_queue))
+            await asyncio.wait(tasks)
         finally:
-            for task in ongoing.values():
+            for task in tasks:
                 task.cancel()
 
     async def _value_producer(self, blob_hash: str, result_queue: asyncio.Queue):
-        log.info("Searching %s", blob_hash[:8])
-        async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
-            result_queue.put_nowait(results)
-        log.info("Search expired %s", blob_hash[:8])
+        for interval in range(1000):
+            log.info("Searching %s", blob_hash[:8])
+            async for results in self.get_iterative_value_finder(binascii.unhexlify(blob_hash.encode())):
+                result_queue.put_nowait(results)
+            log.info("Search expired %s", blob_hash[:8])
+            await asyncio.sleep(interval ** 2)
 
     def accumulate_peers(self, search_queue: asyncio.Queue,
                          peer_queue: typing.Optional[asyncio.Queue] = None) -> typing.Tuple[
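The search junction loses its bookkeeping dict and the FIRST_COMPLETED re-spawn loop: producer tasks now just accumulate in a list, and retrying moves into _value_producer itself, which re-runs each search with a quadratically growing pause. The resulting backoff schedule:

# delay slept after each search round: interval ** 2 seconds
delays = [interval ** 2 for interval in range(5)]
print(delays)  # [0, 1, 4, 9, 16]

Early rounds retry almost immediately, while a search that keeps coming up empty backs off hard; the nominal cap of 1000 rounds is never reached in practice.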
@@ -32,6 +32,10 @@ class PeerManager:
         self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
         self._node_tokens: typing.Dict[bytes, (float, bytes)] = {}
 
+    def reset(self):
+        for statistic in (self._rpc_failures, self._last_replied, self._last_sent, self._last_requested):
+            statistic.clear()
+
     def report_failure(self, address: str, udp_port: int):
         now = self._loop.time()
         _, previous = self._rpc_failures.pop((address, udp_port), (None, None))
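reset() is the hook the re-join loop in Node calls: it clears only the per-peer RPC bookkeeping (failure counts and last-replied/sent/requested timestamps) while leaving node-id mappings and tokens intact, so peers that racked up failures while the network was unreachable start from a clean slate on reconnect. Note that it clears the dicts in place instead of rebinding the attributes; a toy demonstration of the difference:

# clear() mutates the existing dict, so any other reference to it
# (e.g. one captured before the reset) observes the cleared state
rpc_failures = {("1.2.3.4", 4444): (123.0, 120.0)}
for statistic in (rpc_failures,):
    statistic.clear()
print(rpc_failures)  # {}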
@@ -275,6 +275,7 @@ class KademliaProtocol(DatagramProtocol):
         self._wakeup_routing_task = asyncio.Event(loop=self.loop)
         self.maintaing_routing_task: typing.Optional[asyncio.Task] = None
 
+    @functools.lru_cache(128)
     def get_rpc_peer(self, peer: 'KademliaPeer') -> RemoteKademliaRPC:
         return RemoteKademliaRPC(self.loop, self.peer_manager, self, peer)
 
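get_rpc_peer() is now memoized with functools.lru_cache, so repeated RPCs to the same contact reuse a single RemoteKademliaRPC wrapper instead of allocating one per call. On a method the cache key includes self as the first argument, and cached entries keep both self and the peer object alive until evicted. A toy illustration of the behavior (Wrapper and Protocol are mine, not the SDK's):

import functools

class Wrapper:
    def __init__(self, peer):
        self.peer = peer

class Protocol:
    @functools.lru_cache(128)  # key is (self, peer); both must be hashable
    def get_rpc_peer(self, peer):
        return Wrapper(peer)

p = Protocol()
assert p.get_rpc_peer("a") is p.get_rpc_peer("a")  # second call is a cache hit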
@@ -16,7 +16,7 @@ class DHTIntegrationTest(AsyncioTestCase):
         self.nodes = []
         self.known_node_addresses = []
 
-    async def setup_network(self, size: int, start_port=40000):
+    async def setup_network(self, size: int, start_port=40000, seed_nodes=1):
         for i in range(size):
             node_port = start_port + i
             node = Node(self.loop, PeerManager(self.loop), node_id=constants.generate_id(i),
@@ -29,7 +29,7 @@ class DHTIntegrationTest(AsyncioTestCase):
         for node in self.nodes:
             node.protocol.rpc_timeout = .2
             node.protocol.ping_queue._default_delay = .5
-            node.start('127.0.0.1', self.known_node_addresses[:1])
+            node.start('127.0.0.1', self.known_node_addresses[:seed_nodes])
         await asyncio.gather(*[node.joined.wait() for node in self.nodes])
 
     async def test_replace_bad_nodes(self):
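setup_network() gains a seed_nodes parameter so each node can bootstrap from the first N known addresses rather than always just the first one; the new test_re_join below passes seed_nodes=10 so the surviving node has several re-entry points into the network. Illustrated with assumed values matching the test setup:

known_node_addresses = [("127.0.0.1", 40000 + i) for i in range(20)]
seed_nodes = 10
bootstrap = known_node_addresses[:seed_nodes]  # previously always [:1]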
@@ -52,6 +52,24 @@ class DHTIntegrationTest(AsyncioTestCase):
             for peer in node.protocol.routing_table.get_peers():
                 self.assertIn(peer.node_id, good_nodes)
 
+    async def test_re_join(self):
+        await self.setup_network(20, seed_nodes=10)
+        node = self.nodes[-1]
+        self.assertTrue(node.joined.is_set())
+        self.assertTrue(node.protocol.routing_table.get_peers())
+        for network_node in self.nodes[:-1]:
+            network_node.stop()
+        await node.refresh_node(True)
+        self.assertFalse(node.protocol.routing_table.get_peers())
+        for network_node in self.nodes[:-1]:
+            await network_node.start_listening('127.0.0.1')
+        self.assertFalse(node.protocol.routing_table.get_peers())
+        timeout = 20
+        while not node.protocol.routing_table.get_peers():
+            await asyncio.sleep(.1)
+            timeout -= 1
+            if not timeout:
+                self.fail("node didnt join back after 2 seconds")
+
     async def test_announce_no_peers(self):
         await self.setup_network(1)
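test_re_join stops every node but the last, calls refresh_node(True) so the survivor drops its dead contacts and ends up with an empty routing table, restarts the other nodes' listeners, and then polls until the table repopulates: 20 iterations of 0.1 s gives the 2-second budget named in the failure message. A deadline-based variant of the same wait (a sketch, not part of the commit) avoids counting iterations by hand:

import asyncio
import time

async def wait_for_peers(node, deadline: float = 2.0):
    start = time.monotonic()
    while not node.protocol.routing_table.get_peers():
        if time.monotonic() - start > deadline:
            raise AssertionError("node didn't join back in time")
        await asyncio.sleep(.1)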
@@ -86,7 +86,8 @@ class TestBlobAnnouncer(AsyncioTestCase):
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(2, len(to_announce))
         self.blob_announcer.start(batch_size=1)  # so it covers batching logic
-        await self.advance(61.0)
+        # takes 60 seconds to start, but we advance 120 to ensure it processed all batches
+        await self.advance(60.0 * 2)
         to_announce = await self.storage.get_blobs_to_announce()
         self.assertEqual(0, len(to_announce))
         self.blob_announcer.stop()
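With batch_size=1 and two blobs queued, each 60-second sleep yields one announce round, so two full rounds are needed; advancing the mock clock only 61 seconds reliably covered just the first. The new comment and 60.0 * 2 make the timing budget explicit:

rounds_needed = 2       # two blobs, one per batch
sleep_per_round = 60.0  # the announcer sleeps 60 s before each round
print(rounds_needed * sleep_per_round)  # 120.0, the amount the test now advances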