fix refreshing peers in the datastore
This commit is contained in:
parent
079c81f298
commit
6ecc22f2c7
4 changed files with 71 additions and 65 deletions
|
@ -9,68 +9,62 @@ if typing.TYPE_CHECKING:
|
||||||
class DictDataStore:
|
class DictDataStore:
|
||||||
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager'):
|
def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager'):
|
||||||
# Dictionary format:
|
# Dictionary format:
|
||||||
# { <key>: [<contact>, <value>, <lastPublished>, <originallyPublished> <original_publisher_id>] }
|
# { <key>: [(<contact>, <age>), ...] }
|
||||||
self._data_store: typing.Dict[bytes,
|
self._data_store: typing.Dict[bytes, typing.List[typing.Tuple['KademliaPeer', float]]] = {}
|
||||||
typing.List[typing.Tuple['KademliaPeer', bytes, float, float, bytes]]] = {}
|
|
||||||
self._get_time = loop.time
|
self.loop = loop
|
||||||
self._peer_manager = peer_manager
|
self._peer_manager = peer_manager
|
||||||
self.completed_blobs: typing.Set[str] = set()
|
self.completed_blobs: typing.Set[str] = set()
|
||||||
|
|
||||||
def filter_bad_and_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']:
|
def removed_expired_peers(self):
|
||||||
|
now = self.loop.time()
|
||||||
|
keys = list(self._data_store.keys())
|
||||||
|
for key in keys:
|
||||||
|
to_remove = []
|
||||||
|
for (peer, ts) in self._data_store[key]:
|
||||||
|
if ts + constants.data_expiration < now or self._peer_manager.peer_is_good(peer) is False:
|
||||||
|
to_remove.append((peer, ts))
|
||||||
|
for item in to_remove:
|
||||||
|
self._data_store[key].remove(item)
|
||||||
|
if not self._data_store[key]:
|
||||||
|
del self._data_store[key]
|
||||||
|
|
||||||
|
def filter_bad_and_expired_peers(self, key: bytes) -> typing.Iterator['KademliaPeer']:
|
||||||
"""
|
"""
|
||||||
Returns only non-expired and unknown/good peers
|
Returns only non-expired and unknown/good peers
|
||||||
"""
|
"""
|
||||||
peers = []
|
for peer in self.filter_expired_peers(key):
|
||||||
for peer in map(lambda p: p[0],
|
|
||||||
filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration,
|
|
||||||
self._data_store[key])):
|
|
||||||
if self._peer_manager.peer_is_good(peer) is not False:
|
if self._peer_manager.peer_is_good(peer) is not False:
|
||||||
peers.append(peer)
|
yield peer
|
||||||
return peers
|
|
||||||
|
|
||||||
def filter_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']:
|
def filter_expired_peers(self, key: bytes) -> typing.Iterator['KademliaPeer']:
|
||||||
"""
|
"""
|
||||||
Returns only non-expired peers
|
Returns only non-expired peers
|
||||||
"""
|
"""
|
||||||
return list(
|
now = self.loop.time()
|
||||||
map(
|
for (peer, ts) in self._data_store.get(key, []):
|
||||||
lambda p: p[0],
|
if ts + constants.data_expiration > now:
|
||||||
filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration, self._data_store[key])
|
yield peer
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def removed_expired_peers(self):
|
|
||||||
expired_keys = []
|
|
||||||
for key in self._data_store.keys():
|
|
||||||
unexpired_peers = self.filter_expired_peers(key)
|
|
||||||
if not unexpired_peers:
|
|
||||||
expired_keys.append(key)
|
|
||||||
else:
|
|
||||||
self._data_store[key] = [x for x in self._data_store[key] if x[0] in unexpired_peers]
|
|
||||||
for key in expired_keys:
|
|
||||||
del self._data_store[key]
|
|
||||||
|
|
||||||
def has_peers_for_blob(self, key: bytes) -> bool:
|
def has_peers_for_blob(self, key: bytes) -> bool:
|
||||||
return key in self._data_store and len(self.filter_bad_and_expired_peers(key)) > 0
|
return key in self._data_store
|
||||||
|
|
||||||
def add_peer_to_blob(self, contact: 'KademliaPeer', key: bytes, compact_address: bytes, last_published: int,
|
def add_peer_to_blob(self, contact: 'KademliaPeer', key: bytes) -> None:
|
||||||
originally_published: int, original_publisher_id: bytes) -> None:
|
now = self.loop.time()
|
||||||
if key in self._data_store:
|
if key in self._data_store:
|
||||||
if compact_address not in map(lambda store_tuple: store_tuple[1], self._data_store[key]):
|
current = list(filter(lambda x: x[0] == contact, self._data_store[key]))
|
||||||
self._data_store[key].append(
|
if len(current):
|
||||||
(contact, compact_address, last_published, originally_published, original_publisher_id)
|
self._data_store[key][self._data_store[key].index(current[0])] = contact, now
|
||||||
)
|
else:
|
||||||
|
self._data_store[key].append((contact, now))
|
||||||
else:
|
else:
|
||||||
self._data_store[key] = [(contact, compact_address, last_published, originally_published,
|
self._data_store[key] = [(contact, now)]
|
||||||
original_publisher_id)]
|
|
||||||
|
|
||||||
def get_peers_for_blob(self, key: bytes) -> typing.List['KademliaPeer']:
|
def get_peers_for_blob(self, key: bytes) -> typing.List['KademliaPeer']:
|
||||||
return [] if key not in self._data_store else [peer for peer in self.filter_bad_and_expired_peers(key)]
|
return list(self.filter_bad_and_expired_peers(key))
|
||||||
|
|
||||||
def get_storing_contacts(self) -> typing.List['KademliaPeer']:
|
def get_storing_contacts(self) -> typing.List['KademliaPeer']:
|
||||||
peers = set()
|
peers = set()
|
||||||
for key in self._data_store:
|
for key, stored in self._data_store.items():
|
||||||
for values in self._data_store[key]:
|
peers.update(set(map(lambda tup: tup[0], stored)))
|
||||||
if values[0] not in peers:
|
|
||||||
peers.add(values[0])
|
|
||||||
return list(peers)
|
return list(peers)
|
||||||
|
|
|
@ -45,19 +45,19 @@ class KademliaRPC:
|
||||||
def ping():
|
def ping():
|
||||||
return b'pong'
|
return b'pong'
|
||||||
|
|
||||||
def store(self, rpc_contact: 'KademliaPeer', blob_hash: bytes, token: bytes, port: int,
|
def store(self, rpc_contact: 'KademliaPeer', blob_hash: bytes, token: bytes, port: int) -> bytes:
|
||||||
original_publisher_id: bytes, age: int) -> bytes:
|
if len(blob_hash) != constants.hash_bits // 8:
|
||||||
if original_publisher_id is None:
|
raise ValueError(f"invalid length of blob hash: {len(blob_hash)}")
|
||||||
original_publisher_id = rpc_contact.node_id
|
if not 0 < port < 65535:
|
||||||
|
raise ValueError(f"invalid tcp port: {port}")
|
||||||
rpc_contact.update_tcp_port(port)
|
rpc_contact.update_tcp_port(port)
|
||||||
if self.loop.time() - self.protocol.started_listening_time < constants.token_secret_refresh_interval:
|
if not self.verify_token(token, rpc_contact.compact_ip()):
|
||||||
pass
|
if self.loop.time() - self.protocol.started_listening_time < constants.token_secret_refresh_interval:
|
||||||
elif not self.verify_token(token, rpc_contact.compact_ip()):
|
pass
|
||||||
raise ValueError("Invalid token")
|
else:
|
||||||
now = int(self.loop.time())
|
raise ValueError("Invalid token")
|
||||||
originally_published = now - age
|
|
||||||
self.protocol.data_store.add_peer_to_blob(
|
self.protocol.data_store.add_peer_to_blob(
|
||||||
rpc_contact, blob_hash, rpc_contact.compact_address_tcp(), now, originally_published, original_publisher_id
|
rpc_contact, blob_hash
|
||||||
)
|
)
|
||||||
return b'OK'
|
return b'OK'
|
||||||
|
|
||||||
|
@ -416,7 +416,7 @@ class KademliaProtocol(DatagramProtocol):
|
||||||
result = self.node_rpc.ping()
|
result = self.node_rpc.ping()
|
||||||
elif method == b'store':
|
elif method == b'store':
|
||||||
blob_hash, token, port, original_publisher_id, age = a
|
blob_hash, token, port, original_publisher_id, age = a
|
||||||
result = self.node_rpc.store(sender_contact, blob_hash, token, port, original_publisher_id, age)
|
result = self.node_rpc.store(sender_contact, blob_hash, token, port)
|
||||||
elif method == b'findNode':
|
elif method == b'findNode':
|
||||||
key, = a
|
key, = a
|
||||||
result = self.node_rpc.find_node(sender_contact, key)
|
result = self.node_rpc.find_node(sender_contact, key)
|
||||||
|
|
|
@ -16,10 +16,22 @@ class DataStoreTests(TestCase):
|
||||||
peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port)
|
peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port)
|
||||||
peer.update_tcp_port(tcp_port)
|
peer.update_tcp_port(tcp_port)
|
||||||
before = self.data_store.get_peers_for_blob(blob)
|
before = self.data_store.get_peers_for_blob(blob)
|
||||||
self.data_store.add_peer_to_blob(peer, blob, peer.compact_address_tcp(), 0, 0, peer.node_id)
|
self.data_store.add_peer_to_blob(peer, blob)
|
||||||
self.assertListEqual(before + [peer], self.data_store.get_peers_for_blob(blob))
|
self.assertListEqual(before + [peer], self.data_store.get_peers_for_blob(blob))
|
||||||
return peer
|
return peer
|
||||||
|
|
||||||
|
def test_refresh_peer_to_blob(self):
|
||||||
|
blob = b'f' * 48
|
||||||
|
self.assertListEqual([], self.data_store.get_peers_for_blob(blob))
|
||||||
|
peer = self._test_add_peer_to_blob(blob=blob, node_id=b'a' * 48, address='1.2.3.4')
|
||||||
|
self.assertTrue(self.data_store.has_peers_for_blob(blob))
|
||||||
|
self.assertEqual(len(self.data_store.get_peers_for_blob(blob)), 1)
|
||||||
|
self.assertEqual(self.data_store._data_store[blob][0][1], 0)
|
||||||
|
self.loop.time = lambda: 100.0
|
||||||
|
self.assertEqual(self.data_store._data_store[blob][0][1], 0)
|
||||||
|
self.data_store.add_peer_to_blob(peer, blob)
|
||||||
|
self.assertEqual(self.data_store._data_store[blob][0][1], 100)
|
||||||
|
|
||||||
def test_add_peer_to_blob(self, blob=b'f' * 48, peers=None):
|
def test_add_peer_to_blob(self, blob=b'f' * 48, peers=None):
|
||||||
peers = peers or [
|
peers = peers or [
|
||||||
(b'a' * 48, '1.2.3.4'),
|
(b'a' * 48, '1.2.3.4'),
|
||||||
|
@ -67,8 +79,8 @@ class DataStoreTests(TestCase):
|
||||||
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
|
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
|
||||||
|
|
||||||
# expire the first peer from blob1
|
# expire the first peer from blob1
|
||||||
first = self.data_store._data_store[blob1][0]
|
first = self.data_store._data_store[blob1][0][0]
|
||||||
self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4])
|
self.data_store._data_store[blob1][0] = (first, -86401)
|
||||||
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
|
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
|
||||||
self.data_store.removed_expired_peers()
|
self.data_store.removed_expired_peers()
|
||||||
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
|
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
|
||||||
|
@ -76,18 +88,18 @@ class DataStoreTests(TestCase):
|
||||||
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
|
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers))
|
||||||
|
|
||||||
# expire the first peer from blob2
|
# expire the first peer from blob2
|
||||||
first = self.data_store._data_store[blob2][0]
|
first = self.data_store._data_store[blob2][0][0]
|
||||||
self.data_store._data_store[blob2][0] = (first[0], first[1], first[2], -86401, first[4])
|
self.data_store._data_store[blob2][0] = (first, -86401)
|
||||||
self.data_store.removed_expired_peers()
|
self.data_store.removed_expired_peers()
|
||||||
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
|
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1)
|
||||||
self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
|
self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
|
||||||
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers) - 1)
|
self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers) - 1)
|
||||||
|
|
||||||
# expire the second and third peers from blob1
|
# expire the second and third peers from blob1
|
||||||
first = self.data_store._data_store[blob2][0]
|
first = self.data_store._data_store[blob2][0][0]
|
||||||
self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4])
|
self.data_store._data_store[blob1][0] = (first, -86401)
|
||||||
second = self.data_store._data_store[blob2][1]
|
second = self.data_store._data_store[blob2][1][0]
|
||||||
self.data_store._data_store[blob1][1] = (second[0], second[1], second[2], -86401, second[4])
|
self.data_store._data_store[blob1][1] = (second, -86401)
|
||||||
self.data_store.removed_expired_peers()
|
self.data_store.removed_expired_peers()
|
||||||
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), 0)
|
self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), 0)
|
||||||
self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
|
self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1)
|
||||||
|
|
|
@ -72,7 +72,7 @@ class TestManagedStream(BlobExchangeTestBase):
|
||||||
self.assertTrue(self.stream._running.is_set())
|
self.assertTrue(self.stream._running.is_set())
|
||||||
await asyncio.sleep(0.5, loop=self.loop)
|
await asyncio.sleep(0.5, loop=self.loop)
|
||||||
self.assertTrue(self.stream._running.is_set())
|
self.assertTrue(self.stream._running.is_set())
|
||||||
await asyncio.sleep(0.6, loop=self.loop)
|
await asyncio.sleep(2, loop=self.loop)
|
||||||
self.assertEqual(self.stream.status, "finished")
|
self.assertEqual(self.stream.status, "finished")
|
||||||
self.assertFalse(self.stream._running.is_set())
|
self.assertFalse(self.stream._running.is_set())
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue