manage connections, reusing them when possible

This commit is contained in:
Victor Shyba 2019-02-08 02:27:58 -03:00
parent 1be5dce30e
commit c06ec6cd69
3 changed files with 47 additions and 33 deletions

View file

@ -74,7 +74,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
if self._response_fut and not self._response_fut.done():
self._response_fut.set_exception(err)
async def _download_blob(self) -> typing.Tuple[bool, bool]:
async def _download_blob(self) -> typing.Tuple[bool, typing.Optional[asyncio.Transport]]:
"""
:return: download success (bool), keep connection (bool)
"""
@ -92,24 +92,24 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
log.warning("%s not in availability response from %s:%i", self.blob.blob_hash, self.peer_address,
self.peer_port)
log.warning(response.to_dict())
return False, False
return False, self.close()
elif availability_response.available_blobs and \
availability_response.available_blobs != [self.blob.blob_hash]:
log.warning("blob availability response doesn't match our request from %s:%i",
self.peer_address, self.peer_port)
return False, False
return False, self.close()
if not price_response or price_response.blob_data_payment_rate != 'RATE_ACCEPTED':
log.warning("data rate rejected by %s:%i", self.peer_address, self.peer_port)
return False, False
return False, self.close()
if not blob_response or blob_response.error:
log.warning("blob cant be downloaded from %s:%i", self.peer_address, self.peer_port)
return False, True
return False, self.transport
if not blob_response.error and blob_response.blob_hash != self.blob.blob_hash:
log.warning("incoming blob hash mismatch from %s:%i", self.peer_address, self.peer_port)
return False, False
return False, self.close()
if self.blob.length is not None and self.blob.length != blob_response.length:
log.warning("incoming blob unexpected length from %s:%i", self.peer_address, self.peer_port)
return False, False
return False, self.close()
msg = f"downloading {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}," \
f" timeout in {self.peer_timeout}"
log.debug(msg)
@ -117,16 +117,14 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
await asyncio.wait_for(self.writer.finished, self.peer_timeout, loop=self.loop)
log.info(msg)
await self.blob.finished_writing.wait()
return True, True
return True, self.transport
except asyncio.TimeoutError:
return False, False
return False, self.close()
except (InvalidBlobHashError, InvalidDataError):
log.warning("invalid blob from %s:%i", self.peer_address, self.peer_port)
return False, False
finally:
await self.close()
return False, self.close()
async def close(self):
def close(self):
if self._response_fut and not self._response_fut.done():
self._response_fut.cancel()
if self.writer and not self.writer.closed():
@ -138,9 +136,9 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.transport.close()
self.transport = None
async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[bool, bool]:
async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[bool, typing.Optional[asyncio.Transport]]:
if blob.get_is_verified():
return False, True
return False, self.transport
try:
self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0
self._response_fut = asyncio.Future(loop=self.loop)
@ -148,14 +146,14 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
except OSError:
log.error("race happened downloading from %s:%i", self.peer_address, self.peer_port)
# i'm not sure how to fix this race condition - jack
return False, True
return False, self.transport
except asyncio.TimeoutError:
if self._response_fut and not self._response_fut.done():
self._response_fut.cancel()
return False, False
self.close()
return False, None
except asyncio.CancelledError:
if self._response_fut and not self._response_fut.done():
self._response_fut.cancel()
self.close()
raise
def connection_made(self, transport: asyncio.Transport):
@ -166,24 +164,29 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
def connection_lost(self, reason):
log.debug("connection lost to %s:%i (reason: %s, %s)", self.peer_address, self.peer_port, str(reason),
str(type(reason)))
self.transport = None
self.loop.create_task(self.close())
self.close()
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int,
peer_connect_timeout: float, blob_download_timeout: float) -> typing.Tuple[bool, bool]:
peer_connect_timeout: float, blob_download_timeout: float,
connected_transport: asyncio.Transport = None)\
-> typing.Tuple[bool, typing.Optional[asyncio.Transport]]:
"""
Returns [<downloaded blob>, <keep connection>]
"""
protocol = BlobExchangeClientProtocol(loop, blob_download_timeout)
if blob.get_is_verified():
return False, True
return False, connected_transport
protocol = BlobExchangeClientProtocol(loop, blob_download_timeout)
if connected_transport and not connected_transport.is_closing():
connected_transport.set_protocol(protocol)
protocol.connection_made(connected_transport)
else:
connected_transport = None
try:
if not connected_transport:
await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port),
peer_connect_timeout, loop=loop)
return await protocol.download_blob(blob)
except (asyncio.TimeoutError, ConnectionRefusedError, ConnectionAbortedError, OSError):
return False, False
finally:
await protocol.close()
return False, None

View file

@ -23,6 +23,7 @@ class BlobDownloader:
self.active_connections: typing.Dict['KademliaPeer', asyncio.Task] = {} # active request_blob calls
self.ignored: typing.Set['KademliaPeer'] = set()
self.scores: typing.Dict['KademliaPeer', int] = {}
self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {}
def should_race_continue(self):
if len(self.active_connections) >= self.config.max_connections_per_download:
@ -38,15 +39,19 @@ class BlobDownloader:
if blob.get_is_verified():
return
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
success, keep_connection = await request_blob(
transport = self.connections.get(peer)
success, transport = await request_blob(
self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
self.config.blob_download_timeout
self.config.blob_download_timeout, connected_transport=transport
)
if not keep_connection and peer not in self.ignored:
if not transport and peer not in self.ignored:
self.ignored.add(peer)
log.debug("drop peer %s:%i", peer.address, peer.tcp_port)
elif keep_connection:
if peer in self.connections:
del self.connections[peer]
elif transport:
log.debug("keep peer %s:%i", peer.address, peer.tcp_port)
self.connections[peer] = transport
self.scores[peer] = (self.scores.get(peer, 0) + 2) if success else 0
async def new_peer_or_finished(self, blob: 'BlobFile'):
@ -107,6 +112,10 @@ class BlobDownloader:
log.exception(e)
raise e
def close(self):
for transport in self.connections.values():
transport.close()
async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', node: 'Node',
blob_hash: str) -> 'BlobFile':
@ -119,3 +128,4 @@ async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager',
finally:
if accumulate_task and not accumulate_task.done():
accumulate_task.cancel()
downloader.close()

View file

@ -51,6 +51,7 @@ class StreamDownloader(StreamAssembler):
async def after_finished(self):
log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')
self.blob_downloader.close()
def stop(self):
if self.accumulate_task: