diff --git a/lbrynet/blob_exchange/client.py b/lbrynet/blob_exchange/client.py index 4162ef593..df0805e57 100644 --- a/lbrynet/blob_exchange/client.py +++ b/lbrynet/blob_exchange/client.py @@ -111,7 +111,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol): return self._blob_bytes_received, self.close() if not blob_response or blob_response.error: log.warning("blob cant be downloaded from %s:%i", self.peer_address, self.peer_port) - return self._blob_bytes_received, self.transport + return self._blob_bytes_received, self.close() if not blob_response.error and blob_response.blob_hash != self.blob.blob_hash: log.warning("incoming blob hash mismatch from %s:%i", self.peer_address, self.peer_port) return self._blob_bytes_received, self.close() diff --git a/lbrynet/blob_exchange/downloader.py b/lbrynet/blob_exchange/downloader.py index b602225c5..76bc6218f 100644 --- a/lbrynet/blob_exchange/downloader.py +++ b/lbrynet/blob_exchange/downloader.py @@ -14,6 +14,7 @@ log = logging.getLogger(__name__) class BlobDownloader: + BAN_TIME = 10.0 # fixme: when connection manager gets implemented, move it out from here def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager', peer_queue: asyncio.Queue): self.loop = loop @@ -21,10 +22,10 @@ class BlobDownloader: self.blob_manager = blob_manager self.peer_queue = peer_queue self.active_connections: typing.Dict['KademliaPeer', asyncio.Task] = {} # active request_blob calls - self.ignored: typing.Set['KademliaPeer'] = set() + self.ignored: typing.Dict['KademliaPeer', int] = {} self.scores: typing.Dict['KademliaPeer', int] = {} self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {} - self.rounds_won: typing.Dict['KademliaPeer', int] = {} + self.time_since_last_blob = loop.time() def should_race_continue(self, blob: 'BlobFile'): if len(self.active_connections) >= self.config.max_connections_per_download: @@ -48,22 +49,25 @@ class BlobDownloader: self.config.blob_download_timeout, connected_transport=transport ) if bytes_received == blob.get_length(): - self.rounds_won[peer] = self.rounds_won.get(peer, 0) + 1 + self.time_since_last_blob = self.loop.time() if not transport and peer not in self.ignored: - self.ignored.add(peer) + self.ignored[peer] = self.loop.time() log.debug("drop peer %s:%i", peer.address, peer.tcp_port) if peer in self.connections: del self.connections[peer] elif transport: log.debug("keep peer %s:%i", peer.address, peer.tcp_port) self.connections[peer] = transport - rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0 - self.scores[peer] = rough_speed + rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0 + self.scores[peer] = rough_speed async def new_peer_or_finished(self, blob: 'BlobFile'): async def get_and_re_add_peers(): - new_peers = await self.peer_queue.get() - self.peer_queue.put_nowait(new_peers) + try: + new_peers = await asyncio.wait_for(self.peer_queue.get(), timeout=1.0) + self.peer_queue.put_nowait(new_peers) + except asyncio.TimeoutError: + pass tasks = [self.loop.create_task(get_and_re_add_peers()), self.loop.create_task(blob.verified.wait())] active_tasks = list(self.active_connections.values()) try: @@ -76,6 +80,15 @@ class BlobDownloader: for peer in to_remove: del self.active_connections[peer] + def clearbanned(self): + now = self.loop.time() + if now - self.time_since_last_blob > 60.0: + return + forgiven = [banned_peer for banned_peer, when in self.ignored.items() if now - when > self.BAN_TIME] + self.peer_queue.put_nowait(forgiven) + for banned_peer in forgiven: + self.ignored.pop(banned_peer) + async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': blob = self.blob_manager.get_blob(blob_hash, length) if blob.get_is_verified(): @@ -101,6 +114,8 @@ class BlobDownloader: self.cleanup_active() if batch: self.peer_queue.put_nowait(set(batch).difference(self.ignored)) + else: + self.clearbanned() while self.active_connections: peer, task = self.active_connections.popitem() if task and not task.done(): @@ -119,6 +134,8 @@ class BlobDownloader: raise e def close(self): + self.scores.clear() + self.ignored.clear() for transport in self.connections.values(): transport.close()