remove lock from PeerManager

This commit is contained in:
Jack Robison 2019-01-31 20:43:19 -05:00
parent 16d0ff8376
commit 31445c7797
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 44 additions and 65 deletions

View file

@ -30,37 +30,30 @@ class PeerManager:
self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {} self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = {}
self._node_tokens: typing.Dict[bytes, (float, bytes)] = {} self._node_tokens: typing.Dict[bytes, (float, bytes)] = {}
self._kademlia_peers: typing.Dict[typing.Tuple[bytes, str, int], 'KademliaPeer'] self._kademlia_peers: typing.Dict[typing.Tuple[bytes, str, int], 'KademliaPeer']
self._lock = asyncio.Lock(loop=loop)
async def report_failure(self, address: str, udp_port: int): def report_failure(self, address: str, udp_port: int):
now = self._loop.time() now = self._loop.time()
async with self._lock: _, previous = self._rpc_failures.pop((address, udp_port), (None, None))
_, previous = self._rpc_failures.pop((address, udp_port), (None, None)) self._rpc_failures[(address, udp_port)] = (previous, now)
self._rpc_failures[(address, udp_port)] = (previous, now)
async def report_last_sent(self, address: str, udp_port: int): def report_last_sent(self, address: str, udp_port: int):
now = self._loop.time() now = self._loop.time()
async with self._lock: self._last_sent[(address, udp_port)] = now
self._last_sent[(address, udp_port)] = now
async def report_last_replied(self, address: str, udp_port: int): def report_last_replied(self, address: str, udp_port: int):
now = self._loop.time() now = self._loop.time()
async with self._lock: self._last_replied[(address, udp_port)] = now
self._last_replied[(address, udp_port)] = now
async def report_last_requested(self, address: str, udp_port: int): def report_last_requested(self, address: str, udp_port: int):
now = self._loop.time() now = self._loop.time()
async with self._lock: self._last_requested[(address, udp_port)] = now
self._last_requested[(address, udp_port)] = now
async def clear_token(self, node_id: bytes): def clear_token(self, node_id: bytes):
async with self._lock: self._node_tokens.pop(node_id, None)
self._node_tokens.pop(node_id, None)
async def update_token(self, node_id: bytes, token: bytes): def update_token(self, node_id: bytes, token: bytes):
now = self._loop.time() now = self._loop.time()
async with self._lock: self._node_tokens[node_id] = (now, token)
self._node_tokens[node_id] = (now, token)
def get_node_token(self, node_id: bytes) -> typing.Optional[bytes]: def get_node_token(self, node_id: bytes) -> typing.Optional[bytes]:
ts, token = self._node_tokens.get(node_id, (None, None)) ts, token = self._node_tokens.get(node_id, (None, None))
@ -70,50 +63,36 @@ class PeerManager:
def get_last_replied(self, address: str, udp_port: int) -> typing.Optional[float]: def get_last_replied(self, address: str, udp_port: int) -> typing.Optional[float]:
return self._last_replied.get((address, udp_port)) return self._last_replied.get((address, udp_port))
def get_node_id(self, address: str, udp_port: int) -> typing.Optional[bytes]: def update_contact_triple(self, node_id: bytes, address: str, udp_port: int):
return self._node_id_mapping.get((address, udp_port))
def get_node_address(self, node_id: bytes) -> typing.Optional[typing.Tuple[str, int]]:
return self._node_id_reverse_mapping.get(node_id)
async def get_node_address_and_port(self, node_id: bytes) -> typing.Optional[typing.Tuple[str, int]]:
async with self._lock:
addr_tuple = self._node_id_reverse_mapping.get(node_id)
if addr_tuple and addr_tuple in self._node_id_mapping:
return addr_tuple
async def update_contact_triple(self, node_id: bytes, address: str, udp_port: int):
""" """
Update the mapping of node_id -> address tuple and that of address tuple -> node_id Update the mapping of node_id -> address tuple and that of address tuple -> node_id
This is to handle peers changing addresses and ids while assuring that the we only ever have This is to handle peers changing addresses and ids while assuring that the we only ever have
one node id / address tuple mapped to each other one node id / address tuple mapped to each other
""" """
async with self._lock: if (address, udp_port) in self._node_id_mapping:
if (address, udp_port) in self._node_id_mapping: self._node_id_reverse_mapping.pop(self._node_id_mapping.pop((address, udp_port)))
self._node_id_reverse_mapping.pop(self._node_id_mapping.pop((address, udp_port))) if node_id in self._node_id_reverse_mapping:
if node_id in self._node_id_reverse_mapping: self._node_id_mapping.pop(self._node_id_reverse_mapping.pop(node_id))
self._node_id_mapping.pop(self._node_id_reverse_mapping.pop(node_id)) self._node_id_mapping[(address, udp_port)] = node_id
self._node_id_mapping[(address, udp_port)] = node_id self._node_id_reverse_mapping[node_id] = (address, udp_port)
self._node_id_reverse_mapping[node_id] = (address, udp_port)
def get_kademlia_peer(self, node_id: bytes, address: str, udp_port: int) -> 'KademliaPeer': def get_kademlia_peer(self, node_id: bytes, address: str, udp_port: int) -> 'KademliaPeer':
return KademliaPeer(self._loop, address, node_id, udp_port) return KademliaPeer(self._loop, address, node_id, udp_port)
async def prune(self): def prune(self): # TODO: periodically call this
now = self._loop.time() now = self._loop.time()
async with self._lock: to_pop = []
to_pop = [] for (address, udp_port), (_, last_failure) in self._rpc_failures.items():
for (address, udp_port), (_, last_failure) in self._rpc_failures.items(): if last_failure and last_failure < now - constants.rpc_attempts_pruning_window:
if last_failure and last_failure < now - constants.rpc_attempts_pruning_window: to_pop.append((address, udp_port))
to_pop.append((address, udp_port)) while to_pop:
while to_pop: del self._rpc_failures[to_pop.pop()]
del self._rpc_failures[to_pop.pop()] to_pop = []
to_pop = [] for node_id, (age, token) in self._node_tokens.items():
for node_id, (age, token) in self._node_tokens.items(): if age < now - constants.token_secret_refresh_interval:
if age < now - constants.token_secret_refresh_interval: to_pop.append(node_id)
to_pop.append(node_id) while to_pop:
while to_pop: del self._node_tokens[to_pop.pop()]
del self._node_tokens[to_pop.pop()]
def contact_triple_is_good(self, node_id: bytes, address: str, udp_port: int): def contact_triple_is_good(self, node_id: bytes, address: str, udp_port: int):
""" """

View file

@ -180,7 +180,7 @@ class RemoteKademliaRPC:
response = await self.protocol.send_request( response = await self.protocol.send_request(
self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key) self.peer, RequestDatagram.make_find_value(self.protocol.node_id, key)
) )
await self.peer_tracker.update_token(self.peer.node_id, response.response[b'token']) self.peer_tracker.update_token(self.peer.node_id, response.response[b'token'])
return response.response return response.response
@ -415,8 +415,8 @@ class KademliaProtocol(DatagramProtocol):
async def handle_request_datagram(self, address, request_datagram: RequestDatagram): async def handle_request_datagram(self, address, request_datagram: RequestDatagram):
# This is an RPC method request # This is an RPC method request
await self.peer_manager.report_last_requested(address[0], address[1]) self.peer_manager.report_last_requested(address[0], address[1])
await self.peer_manager.update_contact_triple(request_datagram.node_id, address[0], address[1]) self.peer_manager.update_contact_triple(request_datagram.node_id, address[0], address[1])
# only add a requesting contact to the routing table if it has replied to one of our requests # only add a requesting contact to the routing table if it has replied to one of our requests
peer = self.peer_manager.get_kademlia_peer(request_datagram.node_id, address[0], address[1]) peer = self.peer_manager.get_kademlia_peer(request_datagram.node_id, address[0], address[1])
try: try:
@ -457,8 +457,8 @@ class KademliaProtocol(DatagramProtocol):
elif response_datagram.node_id == self.node_id: elif response_datagram.node_id == self.node_id:
df.set_exception(RemoteException("incoming message is from our node id")) df.set_exception(RemoteException("incoming message is from our node id"))
return return
await self.peer_manager.report_last_replied(address[0], address[1]) self.peer_manager.report_last_replied(address[0], address[1])
await self.peer_manager.update_contact_triple(peer.node_id, address[0], address[1]) self.peer_manager.update_contact_triple(peer.node_id, address[0], address[1])
if not df.cancelled(): if not df.cancelled():
df.set_result(response_datagram) df.set_result(response_datagram)
await self.add_peer(peer) await self.add_peer(peer)
@ -505,7 +505,7 @@ class KademliaProtocol(DatagramProtocol):
try: try:
message = decode_datagram(datagram) message = decode_datagram(datagram)
except (ValueError, TypeError): except (ValueError, TypeError):
self.loop.create_task(self.peer_manager.report_failure(address[0], address[1])) self.peer_manager.report_failure(address[0], address[1])
log.warning("Couldn't decode dht datagram from %s: %s", address, binascii.hexlify(datagram).decode()) log.warning("Couldn't decode dht datagram from %s: %s", address, binascii.hexlify(datagram).decode())
return return
@ -522,10 +522,10 @@ class KademliaProtocol(DatagramProtocol):
response_fut = self.sent_messages[request.rpc_id][1] response_fut = self.sent_messages[request.rpc_id][1]
try: try:
response = await asyncio.wait_for(response_fut, self.rpc_timeout) response = await asyncio.wait_for(response_fut, self.rpc_timeout)
await self.peer_manager.report_last_replied(peer.address, peer.udp_port) self.peer_manager.report_last_replied(peer.address, peer.udp_port)
return response return response
except (asyncio.TimeoutError, RemoteException): except (asyncio.TimeoutError, RemoteException):
await self.peer_manager.report_failure(peer.address, peer.udp_port) self.peer_manager.report_failure(peer.address, peer.udp_port)
if self.peer_manager.peer_is_good(peer) is False: if self.peer_manager.peer_is_good(peer) is False:
self.routing_table.remove_peer(peer) self.routing_table.remove_peer(peer)
raise raise
@ -575,9 +575,9 @@ class KademliaProtocol(DatagramProtocol):
else: else:
raise err raise err
if isinstance(message, RequestDatagram): if isinstance(message, RequestDatagram):
await self.peer_manager.report_last_sent(peer.address, peer.udp_port) self.peer_manager.report_last_sent(peer.address, peer.udp_port)
elif isinstance(message, ErrorDatagram): elif isinstance(message, ErrorDatagram):
await self.peer_manager.report_failure(peer.address, peer.udp_port) self.peer_manager.report_failure(peer.address, peer.udp_port)
def change_token(self): def change_token(self):
self.old_token_secret = self.token_secret self.old_token_secret = self.token_secret
@ -609,7 +609,7 @@ class KademliaProtocol(DatagramProtocol):
log.error("Unexpected response: %s" % err) log.error("Unexpected response: %s" % err)
except Exception as err: except Exception as err:
if 'Invalid token' in str(err): if 'Invalid token' in str(err):
await self.peer_manager.clear_token(peer.node_id) self.peer_manager.clear_token(peer.node_id)
else: else:
log.exception("Unexpected error while storing blob_hash") log.exception("Unexpected error while storing blob_hash")
return peer.node_id, False return peer.node_id, False