manage connections, reusing them when possible
This commit is contained in:
parent
1be5dce30e
commit
c06ec6cd69
3 changed files with 47 additions and 33 deletions
|
@ -74,7 +74,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
if self._response_fut and not self._response_fut.done():
|
if self._response_fut and not self._response_fut.done():
|
||||||
self._response_fut.set_exception(err)
|
self._response_fut.set_exception(err)
|
||||||
|
|
||||||
async def _download_blob(self) -> typing.Tuple[bool, bool]:
|
async def _download_blob(self) -> typing.Tuple[bool, typing.Optional[asyncio.Transport]]:
|
||||||
"""
|
"""
|
||||||
:return: download success (bool), keep connection (bool)
|
:return: download success (bool), keep connection (bool)
|
||||||
"""
|
"""
|
||||||
|
@ -92,24 +92,24 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
log.warning("%s not in availability response from %s:%i", self.blob.blob_hash, self.peer_address,
|
log.warning("%s not in availability response from %s:%i", self.blob.blob_hash, self.peer_address,
|
||||||
self.peer_port)
|
self.peer_port)
|
||||||
log.warning(response.to_dict())
|
log.warning(response.to_dict())
|
||||||
return False, False
|
return False, self.close()
|
||||||
elif availability_response.available_blobs and \
|
elif availability_response.available_blobs and \
|
||||||
availability_response.available_blobs != [self.blob.blob_hash]:
|
availability_response.available_blobs != [self.blob.blob_hash]:
|
||||||
log.warning("blob availability response doesn't match our request from %s:%i",
|
log.warning("blob availability response doesn't match our request from %s:%i",
|
||||||
self.peer_address, self.peer_port)
|
self.peer_address, self.peer_port)
|
||||||
return False, False
|
return False, self.close()
|
||||||
if not price_response or price_response.blob_data_payment_rate != 'RATE_ACCEPTED':
|
if not price_response or price_response.blob_data_payment_rate != 'RATE_ACCEPTED':
|
||||||
log.warning("data rate rejected by %s:%i", self.peer_address, self.peer_port)
|
log.warning("data rate rejected by %s:%i", self.peer_address, self.peer_port)
|
||||||
return False, False
|
return False, self.close()
|
||||||
if not blob_response or blob_response.error:
|
if not blob_response or blob_response.error:
|
||||||
log.warning("blob cant be downloaded from %s:%i", self.peer_address, self.peer_port)
|
log.warning("blob cant be downloaded from %s:%i", self.peer_address, self.peer_port)
|
||||||
return False, True
|
return False, self.transport
|
||||||
if not blob_response.error and blob_response.blob_hash != self.blob.blob_hash:
|
if not blob_response.error and blob_response.blob_hash != self.blob.blob_hash:
|
||||||
log.warning("incoming blob hash mismatch from %s:%i", self.peer_address, self.peer_port)
|
log.warning("incoming blob hash mismatch from %s:%i", self.peer_address, self.peer_port)
|
||||||
return False, False
|
return False, self.close()
|
||||||
if self.blob.length is not None and self.blob.length != blob_response.length:
|
if self.blob.length is not None and self.blob.length != blob_response.length:
|
||||||
log.warning("incoming blob unexpected length from %s:%i", self.peer_address, self.peer_port)
|
log.warning("incoming blob unexpected length from %s:%i", self.peer_address, self.peer_port)
|
||||||
return False, False
|
return False, self.close()
|
||||||
msg = f"downloading {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}," \
|
msg = f"downloading {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}," \
|
||||||
f" timeout in {self.peer_timeout}"
|
f" timeout in {self.peer_timeout}"
|
||||||
log.debug(msg)
|
log.debug(msg)
|
||||||
|
@ -117,16 +117,14 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
await asyncio.wait_for(self.writer.finished, self.peer_timeout, loop=self.loop)
|
await asyncio.wait_for(self.writer.finished, self.peer_timeout, loop=self.loop)
|
||||||
log.info(msg)
|
log.info(msg)
|
||||||
await self.blob.finished_writing.wait()
|
await self.blob.finished_writing.wait()
|
||||||
return True, True
|
return True, self.transport
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
return False, False
|
return False, self.close()
|
||||||
except (InvalidBlobHashError, InvalidDataError):
|
except (InvalidBlobHashError, InvalidDataError):
|
||||||
log.warning("invalid blob from %s:%i", self.peer_address, self.peer_port)
|
log.warning("invalid blob from %s:%i", self.peer_address, self.peer_port)
|
||||||
return False, False
|
return False, self.close()
|
||||||
finally:
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
async def close(self):
|
def close(self):
|
||||||
if self._response_fut and not self._response_fut.done():
|
if self._response_fut and not self._response_fut.done():
|
||||||
self._response_fut.cancel()
|
self._response_fut.cancel()
|
||||||
if self.writer and not self.writer.closed():
|
if self.writer and not self.writer.closed():
|
||||||
|
@ -138,9 +136,9 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
self.transport = None
|
self.transport = None
|
||||||
|
|
||||||
async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[bool, bool]:
|
async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[bool, typing.Optional[asyncio.Transport]]:
|
||||||
if blob.get_is_verified():
|
if blob.get_is_verified():
|
||||||
return False, True
|
return False, self.transport
|
||||||
try:
|
try:
|
||||||
self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0
|
self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0
|
||||||
self._response_fut = asyncio.Future(loop=self.loop)
|
self._response_fut = asyncio.Future(loop=self.loop)
|
||||||
|
@ -148,14 +146,14 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
except OSError:
|
except OSError:
|
||||||
log.error("race happened downloading from %s:%i", self.peer_address, self.peer_port)
|
log.error("race happened downloading from %s:%i", self.peer_address, self.peer_port)
|
||||||
# i'm not sure how to fix this race condition - jack
|
# i'm not sure how to fix this race condition - jack
|
||||||
return False, True
|
return False, self.transport
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if self._response_fut and not self._response_fut.done():
|
if self._response_fut and not self._response_fut.done():
|
||||||
self._response_fut.cancel()
|
self._response_fut.cancel()
|
||||||
return False, False
|
self.close()
|
||||||
|
return False, None
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
if self._response_fut and not self._response_fut.done():
|
self.close()
|
||||||
self._response_fut.cancel()
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def connection_made(self, transport: asyncio.Transport):
|
def connection_made(self, transport: asyncio.Transport):
|
||||||
|
@ -166,24 +164,29 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
def connection_lost(self, reason):
|
def connection_lost(self, reason):
|
||||||
log.debug("connection lost to %s:%i (reason: %s, %s)", self.peer_address, self.peer_port, str(reason),
|
log.debug("connection lost to %s:%i (reason: %s, %s)", self.peer_address, self.peer_port, str(reason),
|
||||||
str(type(reason)))
|
str(type(reason)))
|
||||||
self.transport = None
|
self.close()
|
||||||
self.loop.create_task(self.close())
|
|
||||||
|
|
||||||
|
|
||||||
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int,
|
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int,
|
||||||
peer_connect_timeout: float, blob_download_timeout: float) -> typing.Tuple[bool, bool]:
|
peer_connect_timeout: float, blob_download_timeout: float,
|
||||||
|
connected_transport: asyncio.Transport = None)\
|
||||||
|
-> typing.Tuple[bool, typing.Optional[asyncio.Transport]]:
|
||||||
"""
|
"""
|
||||||
Returns [<downloaded blob>, <keep connection>]
|
Returns [<downloaded blob>, <keep connection>]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
protocol = BlobExchangeClientProtocol(loop, blob_download_timeout)
|
|
||||||
if blob.get_is_verified():
|
if blob.get_is_verified():
|
||||||
return False, True
|
return False, connected_transport
|
||||||
|
protocol = BlobExchangeClientProtocol(loop, blob_download_timeout)
|
||||||
|
if connected_transport and not connected_transport.is_closing():
|
||||||
|
connected_transport.set_protocol(protocol)
|
||||||
|
protocol.connection_made(connected_transport)
|
||||||
|
else:
|
||||||
|
connected_transport = None
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port),
|
if not connected_transport:
|
||||||
peer_connect_timeout, loop=loop)
|
await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port),
|
||||||
|
peer_connect_timeout, loop=loop)
|
||||||
return await protocol.download_blob(blob)
|
return await protocol.download_blob(blob)
|
||||||
except (asyncio.TimeoutError, ConnectionRefusedError, ConnectionAbortedError, OSError):
|
except (asyncio.TimeoutError, ConnectionRefusedError, ConnectionAbortedError, OSError):
|
||||||
return False, False
|
return False, None
|
||||||
finally:
|
|
||||||
await protocol.close()
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ class BlobDownloader:
|
||||||
self.active_connections: typing.Dict['KademliaPeer', asyncio.Task] = {} # active request_blob calls
|
self.active_connections: typing.Dict['KademliaPeer', asyncio.Task] = {} # active request_blob calls
|
||||||
self.ignored: typing.Set['KademliaPeer'] = set()
|
self.ignored: typing.Set['KademliaPeer'] = set()
|
||||||
self.scores: typing.Dict['KademliaPeer', int] = {}
|
self.scores: typing.Dict['KademliaPeer', int] = {}
|
||||||
|
self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {}
|
||||||
|
|
||||||
def should_race_continue(self):
|
def should_race_continue(self):
|
||||||
if len(self.active_connections) >= self.config.max_connections_per_download:
|
if len(self.active_connections) >= self.config.max_connections_per_download:
|
||||||
|
@ -38,15 +39,19 @@ class BlobDownloader:
|
||||||
if blob.get_is_verified():
|
if blob.get_is_verified():
|
||||||
return
|
return
|
||||||
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
|
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
|
||||||
success, keep_connection = await request_blob(
|
transport = self.connections.get(peer)
|
||||||
|
success, transport = await request_blob(
|
||||||
self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
|
self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
|
||||||
self.config.blob_download_timeout
|
self.config.blob_download_timeout, connected_transport=transport
|
||||||
)
|
)
|
||||||
if not keep_connection and peer not in self.ignored:
|
if not transport and peer not in self.ignored:
|
||||||
self.ignored.add(peer)
|
self.ignored.add(peer)
|
||||||
log.debug("drop peer %s:%i", peer.address, peer.tcp_port)
|
log.debug("drop peer %s:%i", peer.address, peer.tcp_port)
|
||||||
elif keep_connection:
|
if peer in self.connections:
|
||||||
|
del self.connections[peer]
|
||||||
|
elif transport:
|
||||||
log.debug("keep peer %s:%i", peer.address, peer.tcp_port)
|
log.debug("keep peer %s:%i", peer.address, peer.tcp_port)
|
||||||
|
self.connections[peer] = transport
|
||||||
self.scores[peer] = (self.scores.get(peer, 0) + 2) if success else 0
|
self.scores[peer] = (self.scores.get(peer, 0) + 2) if success else 0
|
||||||
|
|
||||||
async def new_peer_or_finished(self, blob: 'BlobFile'):
|
async def new_peer_or_finished(self, blob: 'BlobFile'):
|
||||||
|
@ -107,6 +112,10 @@ class BlobDownloader:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
for transport in self.connections.values():
|
||||||
|
transport.close()
|
||||||
|
|
||||||
|
|
||||||
async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', node: 'Node',
|
async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', node: 'Node',
|
||||||
blob_hash: str) -> 'BlobFile':
|
blob_hash: str) -> 'BlobFile':
|
||||||
|
@ -119,3 +128,4 @@ async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager',
|
||||||
finally:
|
finally:
|
||||||
if accumulate_task and not accumulate_task.done():
|
if accumulate_task and not accumulate_task.done():
|
||||||
accumulate_task.cancel()
|
accumulate_task.cancel()
|
||||||
|
downloader.close()
|
||||||
|
|
|
@ -51,6 +51,7 @@ class StreamDownloader(StreamAssembler):
|
||||||
async def after_finished(self):
|
async def after_finished(self):
|
||||||
log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
|
log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
|
||||||
await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')
|
await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')
|
||||||
|
self.blob_downloader.close()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if self.accumulate_task:
|
if self.accumulate_task:
|
||||||
|
|
Loading…
Reference in a new issue