diff --git a/lbrynet/dht/protocol/data_store.py b/lbrynet/dht/protocol/data_store.py index f422c21d5..874a8da60 100644 --- a/lbrynet/dht/protocol/data_store.py +++ b/lbrynet/dht/protocol/data_store.py @@ -9,68 +9,62 @@ if typing.TYPE_CHECKING: class DictDataStore: def __init__(self, loop: asyncio.BaseEventLoop, peer_manager: 'PeerManager'): # Dictionary format: - # { : [, , , ] } - self._data_store: typing.Dict[bytes, - typing.List[typing.Tuple['KademliaPeer', bytes, float, float, bytes]]] = {} - self._get_time = loop.time + # { : [(, ), ...] } + self._data_store: typing.Dict[bytes, typing.List[typing.Tuple['KademliaPeer', float]]] = {} + + self.loop = loop self._peer_manager = peer_manager self.completed_blobs: typing.Set[str] = set() - def filter_bad_and_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']: + def removed_expired_peers(self): + now = self.loop.time() + keys = list(self._data_store.keys()) + for key in keys: + to_remove = [] + for (peer, ts) in self._data_store[key]: + if ts + constants.data_expiration < now or self._peer_manager.peer_is_good(peer) is False: + to_remove.append((peer, ts)) + for item in to_remove: + self._data_store[key].remove(item) + if not self._data_store[key]: + del self._data_store[key] + + def filter_bad_and_expired_peers(self, key: bytes) -> typing.Iterator['KademliaPeer']: """ Returns only non-expired and unknown/good peers """ - peers = [] - for peer in map(lambda p: p[0], - filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration, - self._data_store[key])): + for peer in self.filter_expired_peers(key): if self._peer_manager.peer_is_good(peer) is not False: - peers.append(peer) - return peers + yield peer - def filter_expired_peers(self, key: bytes) -> typing.List['KademliaPeer']: + def filter_expired_peers(self, key: bytes) -> typing.Iterator['KademliaPeer']: """ Returns only non-expired peers """ - return list( - map( - lambda p: p[0], - filter(lambda peer: self._get_time() - peer[3] < constants.data_expiration, self._data_store[key]) - ) - ) - - def removed_expired_peers(self): - expired_keys = [] - for key in self._data_store.keys(): - unexpired_peers = self.filter_expired_peers(key) - if not unexpired_peers: - expired_keys.append(key) - else: - self._data_store[key] = [x for x in self._data_store[key] if x[0] in unexpired_peers] - for key in expired_keys: - del self._data_store[key] + now = self.loop.time() + for (peer, ts) in self._data_store.get(key, []): + if ts + constants.data_expiration > now: + yield peer def has_peers_for_blob(self, key: bytes) -> bool: - return key in self._data_store and len(self.filter_bad_and_expired_peers(key)) > 0 + return key in self._data_store - def add_peer_to_blob(self, contact: 'KademliaPeer', key: bytes, compact_address: bytes, last_published: int, - originally_published: int, original_publisher_id: bytes) -> None: + def add_peer_to_blob(self, contact: 'KademliaPeer', key: bytes) -> None: + now = self.loop.time() if key in self._data_store: - if compact_address not in map(lambda store_tuple: store_tuple[1], self._data_store[key]): - self._data_store[key].append( - (contact, compact_address, last_published, originally_published, original_publisher_id) - ) + current = list(filter(lambda x: x[0] == contact, self._data_store[key])) + if len(current): + self._data_store[key][self._data_store[key].index(current[0])] = contact, now + else: + self._data_store[key].append((contact, now)) else: - self._data_store[key] = [(contact, compact_address, last_published, originally_published, - original_publisher_id)] + self._data_store[key] = [(contact, now)] def get_peers_for_blob(self, key: bytes) -> typing.List['KademliaPeer']: - return [] if key not in self._data_store else [peer for peer in self.filter_bad_and_expired_peers(key)] + return list(self.filter_bad_and_expired_peers(key)) def get_storing_contacts(self) -> typing.List['KademliaPeer']: peers = set() - for key in self._data_store: - for values in self._data_store[key]: - if values[0] not in peers: - peers.add(values[0]) + for key, stored in self._data_store.items(): + peers.update(set(map(lambda tup: tup[0], stored))) return list(peers) diff --git a/lbrynet/dht/protocol/protocol.py b/lbrynet/dht/protocol/protocol.py index eff0211c2..fdd61be15 100644 --- a/lbrynet/dht/protocol/protocol.py +++ b/lbrynet/dht/protocol/protocol.py @@ -45,19 +45,19 @@ class KademliaRPC: def ping(): return b'pong' - def store(self, rpc_contact: 'KademliaPeer', blob_hash: bytes, token: bytes, port: int, - original_publisher_id: bytes, age: int) -> bytes: - if original_publisher_id is None: - original_publisher_id = rpc_contact.node_id + def store(self, rpc_contact: 'KademliaPeer', blob_hash: bytes, token: bytes, port: int) -> bytes: + if len(blob_hash) != constants.hash_bits // 8: + raise ValueError(f"invalid length of blob hash: {len(blob_hash)}") + if not 0 < port < 65535: + raise ValueError(f"invalid tcp port: {port}") rpc_contact.update_tcp_port(port) - if self.loop.time() - self.protocol.started_listening_time < constants.token_secret_refresh_interval: - pass - elif not self.verify_token(token, rpc_contact.compact_ip()): - raise ValueError("Invalid token") - now = int(self.loop.time()) - originally_published = now - age + if not self.verify_token(token, rpc_contact.compact_ip()): + if self.loop.time() - self.protocol.started_listening_time < constants.token_secret_refresh_interval: + pass + else: + raise ValueError("Invalid token") self.protocol.data_store.add_peer_to_blob( - rpc_contact, blob_hash, rpc_contact.compact_address_tcp(), now, originally_published, original_publisher_id + rpc_contact, blob_hash ) return b'OK' @@ -416,7 +416,7 @@ class KademliaProtocol(DatagramProtocol): result = self.node_rpc.ping() elif method == b'store': blob_hash, token, port, original_publisher_id, age = a - result = self.node_rpc.store(sender_contact, blob_hash, token, port, original_publisher_id, age) + result = self.node_rpc.store(sender_contact, blob_hash, token, port) elif method == b'findNode': key, = a result = self.node_rpc.find_node(sender_contact, key) diff --git a/tests/unit/dht/protocol/test_data_store.py b/tests/unit/dht/protocol/test_data_store.py index f8d264ffd..0fdaccba3 100644 --- a/tests/unit/dht/protocol/test_data_store.py +++ b/tests/unit/dht/protocol/test_data_store.py @@ -16,10 +16,22 @@ class DataStoreTests(TestCase): peer = self.peer_manager.get_kademlia_peer(node_id, address, udp_port) peer.update_tcp_port(tcp_port) before = self.data_store.get_peers_for_blob(blob) - self.data_store.add_peer_to_blob(peer, blob, peer.compact_address_tcp(), 0, 0, peer.node_id) + self.data_store.add_peer_to_blob(peer, blob) self.assertListEqual(before + [peer], self.data_store.get_peers_for_blob(blob)) return peer + def test_refresh_peer_to_blob(self): + blob = b'f' * 48 + self.assertListEqual([], self.data_store.get_peers_for_blob(blob)) + peer = self._test_add_peer_to_blob(blob=blob, node_id=b'a' * 48, address='1.2.3.4') + self.assertTrue(self.data_store.has_peers_for_blob(blob)) + self.assertEqual(len(self.data_store.get_peers_for_blob(blob)), 1) + self.assertEqual(self.data_store._data_store[blob][0][1], 0) + self.loop.time = lambda: 100.0 + self.assertEqual(self.data_store._data_store[blob][0][1], 0) + self.data_store.add_peer_to_blob(peer, blob) + self.assertEqual(self.data_store._data_store[blob][0][1], 100) + def test_add_peer_to_blob(self, blob=b'f' * 48, peers=None): peers = peers or [ (b'a' * 48, '1.2.3.4'), @@ -67,8 +79,8 @@ class DataStoreTests(TestCase): self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers)) # expire the first peer from blob1 - first = self.data_store._data_store[blob1][0] - self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4]) + first = self.data_store._data_store[blob1][0][0] + self.data_store._data_store[blob1][0] = (first, -86401) self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers)) self.data_store.removed_expired_peers() self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1) @@ -76,18 +88,18 @@ class DataStoreTests(TestCase): self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers)) # expire the first peer from blob2 - first = self.data_store._data_store[blob2][0] - self.data_store._data_store[blob2][0] = (first[0], first[1], first[2], -86401, first[4]) + first = self.data_store._data_store[blob2][0][0] + self.data_store._data_store[blob2][0] = (first, -86401) self.data_store.removed_expired_peers() self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), len(peers) - 1) self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1) self.assertEqual(len(self.data_store.get_storing_contacts()), len(peers) - 1) # expire the second and third peers from blob1 - first = self.data_store._data_store[blob2][0] - self.data_store._data_store[blob1][0] = (first[0], first[1], first[2], -86401, first[4]) - second = self.data_store._data_store[blob2][1] - self.data_store._data_store[blob1][1] = (second[0], second[1], second[2], -86401, second[4]) + first = self.data_store._data_store[blob2][0][0] + self.data_store._data_store[blob1][0] = (first, -86401) + second = self.data_store._data_store[blob2][1][0] + self.data_store._data_store[blob1][1] = (second, -86401) self.data_store.removed_expired_peers() self.assertEqual(len(self.data_store.get_peers_for_blob(blob1)), 0) self.assertEqual(len(self.data_store.get_peers_for_blob(blob2)), len(peers) - 1) diff --git a/tests/unit/stream/test_managed_stream.py b/tests/unit/stream/test_managed_stream.py index 0db4af1fe..d59281a80 100644 --- a/tests/unit/stream/test_managed_stream.py +++ b/tests/unit/stream/test_managed_stream.py @@ -72,7 +72,7 @@ class TestManagedStream(BlobExchangeTestBase): self.assertTrue(self.stream._running.is_set()) await asyncio.sleep(0.5, loop=self.loop) self.assertTrue(self.stream._running.is_set()) - await asyncio.sleep(0.6, loop=self.loop) + await asyncio.sleep(2, loop=self.loop) self.assertEqual(self.stream.status, "finished") self.assertFalse(self.stream._running.is_set())