forked from LBRYCommunity/lbry-sdk
refactor downloader
split peer accumulation out, use Queues instead of locks
This commit is contained in:
parent 16efe9ba95
commit 6aef6a80b7
7 changed files with 181 additions and 347 deletions
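
The heart of the change is the concurrency model: instead of coordinating a shared connection table with an asyncio.Lock and signalling on an asyncio.Event, peer discovery now runs as an independent producer task that feeds batches of peers into an asyncio.Queue, and the downloader consumes from that queue. A minimal sketch of that shape, with illustrative stand-ins (`dht_search`, `request_from_peer`) rather than code from this commit:

    import asyncio

    async def accumulate_peers(search_queue: asyncio.Queue, peer_queue: asyncio.Queue, dht_search):
        # producer: for each blob hash dropped on search_queue, push the peers
        # found for it onto peer_queue as one list per batch
        while True:
            blob_hash = await search_queue.get()
            peer_queue.put_nowait(await dht_search(blob_hash))

    async def consume(peer_queue: asyncio.Queue, request_from_peer):
        # consumer: the queue hand-off is the only synchronization point, so no lock
        active = {}
        while True:
            for peer in await peer_queue.get():
                if peer not in active:
                    active[peer] = asyncio.ensure_future(request_from_peer(peer))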

lbrynet/blob_exchange/client.py

@@ -80,7 +80,8 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         if (not blob_response or blob_response.error) and\
                 (not availability_response or not availability_response.available_blobs):
             log.warning("blob not in availability response from %s:%i", self.peer_address, self.peer_port)
-            return False, True
+            log.warning(response.to_dict())
+            return False, False
         elif availability_response.available_blobs and \
                 availability_response.available_blobs != [self.blob.blob_hash]:
             log.warning("blob availability response doesn't match our request from %s:%i",
@@ -160,11 +161,13 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
         self.loop.create_task(self.close())


-async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', protocol: 'BlobExchangeClientProtocol',
-                       address: str, tcp_port: int, peer_connect_timeout: float) -> typing.Tuple[bool, bool]:
+async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int,
+                       peer_connect_timeout: float, blob_download_timeout: float) -> typing.Tuple[bool, bool]:
     """
     Returns [<downloaded blob>, <keep connection>]
     """
+
+    protocol = BlobExchangeClientProtocol(loop, blob_download_timeout)
     if blob.get_is_verified():
         return False, True
     try:

@@ -173,3 +176,5 @@ async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', protocol:
         return await protocol.download_blob(blob)
     except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionRefusedError, ConnectionAbortedError, OSError):
         return False, False
+    finally:
+        await protocol.close()
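
With the new signature, callers of request_blob no longer construct or share a BlobExchangeClientProtocol; the function builds one per call and the added finally block guarantees it is closed. A hedged usage sketch (the timeout values are illustrative, not defaults from the repo):

    async def try_peer(loop, blob, address: str, tcp_port: int) -> bool:
        # request_blob returns (<downloaded blob>, <keep connection>); the protocol
        # it creates internally is always closed before this returns
        downloaded, keep_connection = await request_blob(
            loop, blob, address, tcp_port,
            peer_connect_timeout=3.0, blob_download_timeout=30.0,  # illustrative values
        )
        return keep_connection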

lbrynet/blob_exchange/downloader.py

@@ -1,10 +1,10 @@
 import asyncio
 import typing
 import logging
-from lbrynet import conf
 from lbrynet.utils import drain_tasks
-from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob
+from lbrynet.blob_exchange.client import request_blob
 if typing.TYPE_CHECKING:
+    from lbrynet.conf import Config
     from lbrynet.dht.node import Node
     from lbrynet.dht.peer import KademliaPeer
     from lbrynet.blob.blob_manager import BlobFileManager
@@ -18,115 +18,91 @@ def drain_into(a: list, b: list):
         b.append(a.pop())


-class BlobDownloader:  # TODO: refactor to be the base class used by StreamDownloader
-    """A single blob downloader"""
-    def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', config: conf.Config):
+class BlobDownloader:
+    def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager',
+                 peer_queue: asyncio.Queue):
         self.loop = loop
+        self.config = config
         self.blob_manager = blob_manager
-        self.new_peer_event = asyncio.Event(loop=self.loop)
-        self.active_connections: typing.Dict['KademliaPeer', BlobExchangeClientProtocol] = {}
-        self.running_download_requests: typing.List[asyncio.Task] = []
-        self.requested_from: typing.Dict[str, typing.Dict['KademliaPeer', asyncio.Task]] = {}
-        self.lock = asyncio.Lock(loop=self.loop)
-        self.blob: 'BlobFile' = None
-        self.blob_queue = asyncio.Queue(loop=self.loop)
+        self.peer_queue = peer_queue
+        self.active_connections: typing.Dict['KademliaPeer', asyncio.Task] = {}  # active request_blob calls
+        self.ignored: typing.Set['KademliaPeer'] = set()

-        self.blob_download_timeout = config.blob_download_timeout
-        self.peer_connect_timeout = config.peer_connect_timeout
-        self.max_connections = config.max_connections_per_download
+    @property
+    def blob_download_timeout(self):
+        return self.config.blob_download_timeout

-    async def _request_blob(self, peer: 'KademliaPeer'):
-        if self.blob.get_is_verified():
-            log.info("already verified")
-            return
-        if peer not in self.active_connections:
-            log.warning("not active, adding: %s", str(peer))
-            self.active_connections[peer] = BlobExchangeClientProtocol(self.loop, self.blob_download_timeout)
-        protocol = self.active_connections[peer]
-        success, keep_connection = await request_blob(self.loop, self.blob, protocol, peer.address, peer.tcp_port,
-                                                      self.peer_connect_timeout)
-        await protocol.close()
-        if not keep_connection:
-            log.info("drop peer %s:%i", peer.address, peer.tcp_port)
-            if peer in self.active_connections:
-                async with self.lock:
-                    del self.active_connections[peer]
-            return
-        log.info("keep peer %s:%i", peer.address, peer.tcp_port)
+    @property
+    def peer_connect_timeout(self):
+        return self.config.peer_connect_timeout

-    def _update_requests(self):
-        self.new_peer_event.clear()
-        if self.blob.blob_hash not in self.requested_from:
-            self.requested_from[self.blob.blob_hash] = {}
-        to_add = []
-        for peer in self.active_connections.keys():
-            if peer not in self.requested_from[self.blob.blob_hash] and peer not in to_add:
-                to_add.append(peer)
-        if to_add or self.running_download_requests:
-            log.info("adding download probes for %i peers to %i already active",
-                     min(len(to_add), 8 - len(self.running_download_requests)),
-                     len(self.running_download_requests))
-        else:
-            log.info("downloader idle...")
-        for peer in to_add:
-            if len(self.running_download_requests) >= 8:
-                break
-            task = self.loop.create_task(self._request_blob(peer))
-            self.requested_from[self.blob.blob_hash][peer] = task
-            self.running_download_requests.append(task)
+    @property
+    def max_connections(self):
+        return self.config.max_connections_per_download

-    def _add_peer_protocols(self, peers: typing.List['KademliaPeer']):
-        added = 0
-        for peer in peers:
-            if peer not in self.active_connections:
-                self.active_connections[peer] = BlobExchangeClientProtocol(self.loop, self.blob_download_timeout)
-                added += 1
-        if added:
-            if not self.new_peer_event.is_set():
-                log.info("added %i new peers", len(peers))
-                self.new_peer_event.set()
+    def request_blob_from_peer(self, blob: 'BlobFile', peer: 'KademliaPeer'):
+        async def _request_blob():
+            if blob.get_is_verified():
+                return
+            success, keep_connection = await request_blob(self.loop, blob, peer.address, peer.tcp_port,
+                                                          self.peer_connect_timeout, self.blob_download_timeout)
+            if not keep_connection and peer not in self.ignored:
+                self.ignored.add(peer)
+                log.debug("drop peer %s:%i", peer.address, peer.tcp_port)
+            elif keep_connection:
+                log.debug("keep peer %s:%i", peer.address, peer.tcp_port)
+        return self.loop.create_task(_request_blob())

-    async def _accumulate_connections(self, node: 'Node'):
+    async def new_peer_or_finished(self, blob: 'BlobFile'):
+        async def get_and_re_add_peers():
+            new_peers = await self.peer_queue.get()
+            self.peer_queue.put_nowait(new_peers)
+        tasks = [self.loop.create_task(get_and_re_add_peers()), self.loop.create_task(blob.verified.wait())]
         try:
-            async with node.stream_peer_search_junction(self.blob_queue) as search_junction:
-                async for peers in search_junction:
-                    if not isinstance(peers, list):  # TODO: what's up with this?
-                        log.error("not a list: %s %s", peers, str(type(peers)))
-                    else:
-                        self._add_peer_protocols(peers)
-            return
+            await asyncio.wait(tasks, loop=self.loop, return_when='FIRST_COMPLETED')
         except asyncio.CancelledError:
             pass
+        drain_tasks(tasks)

-    async def get_blob(self, blob_hash: str, node: 'Node') -> 'BlobFile':
-        self.blob = self.blob_manager.get_blob(blob_hash)
-        if self.blob.get_is_verified():
-            return self.blob
-        accumulator = self.loop.create_task(self._accumulate_connections(node))
-        self.blob_queue.put_nowait(blob_hash)
+    async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
+        blob = self.blob_manager.get_blob(blob_hash, length)
+        if blob.get_is_verified():
+            return blob
         try:
-            while not self.blob.get_is_verified():
-                if len(self.running_download_requests) < self.max_connections:
-                    self._update_requests()
-
-                # drain the tasks into a temporary list
-                download_tasks = []
-                drain_into(self.running_download_requests, download_tasks)
-                got_new_peer = self.loop.create_task(self.new_peer_event.wait())
-
-                # wait for a new peer to be added or for a download attempt to finish
-                await asyncio.wait([got_new_peer] + download_tasks, return_when='FIRST_COMPLETED',
-                                   loop=self.loop)
-                if got_new_peer and not got_new_peer.done():
-                    got_new_peer.cancel()
-                if self.blob.get_is_verified():
-                    if got_new_peer and not got_new_peer.done():
-                        got_new_peer.cancel()
-                    drain_tasks(download_tasks)
-                    return self.blob
+            while not blob.get_is_verified():
+                batch: typing.List['KademliaPeer'] = []
+                while not self.peer_queue.empty():
+                    batch.extend(await self.peer_queue.get())
+                for peer in batch:
+                    if peer not in self.active_connections and peer not in self.ignored:
+                        log.info("add request %s", blob_hash[:8])
+                        self.active_connections[peer] = self.request_blob_from_peer(blob, peer)
+                await self.new_peer_or_finished(blob)
+                log.info("new peer or finished %s", blob_hash[:8])
+                to_re_add = list(set(filter(lambda peer: peer not in self.ignored, batch)))
+                if to_re_add:
+                    self.peer_queue.put_nowait(to_re_add)
+            log.info("finished %s", blob_hash[:8])
+            while self.active_connections:
+                peer, task = self.active_connections.popitem()
+                if task and not task.done():
+                    task.cancel()
+            return blob
         except asyncio.CancelledError:
-            drain_tasks(self.running_download_requests)
+            while self.active_connections:
+                peer, task = self.active_connections.popitem()
+                if task and not task.done():
+                    task.cancel()
             raise
-        finally:
-            if accumulator and not accumulator.done():
-                accumulator.cancel()


+async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', node: 'Node',
+                        blob_hash: str) -> 'BlobFile':
+    search_queue = asyncio.Queue(loop=loop)
+    search_queue.put_nowait(blob_hash)
+    peer_queue, accumulate_task = node.accumulate_peers(search_queue)
+    downloader = BlobDownloader(loop, config, blob_manager, peer_queue)
+    try:
+        return await downloader.download_blob(blob_hash)
+    finally:
+        if accumulate_task and not accumulate_task.done():
+            accumulate_task.cancel()
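
new_peer_or_finished() is the replacement for the old Event-plus-Lock dance: it races "a peer batch is waiting" against "the blob verified" and returns when either wins. The get_and_re_add_peers trick peeks at the queue by re-inserting what it took, so the main loop's drain still sees the batch. The same pattern in isolation (generic names, not repo code):

    import asyncio

    async def new_batch_or_done(peer_queue: asyncio.Queue, done: asyncio.Event):
        async def peek_batch():
            batch = await peer_queue.get()
            peer_queue.put_nowait(batch)  # re-add so the caller's drain loop still sees it
        tasks = [asyncio.ensure_future(peek_batch()), asyncio.ensure_future(done.wait())]
        try:
            # wake on whichever happens first: more peers to try, or nothing left to do
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()

In the repo the cancellation step is drain_tasks(tasks), which is why that import survives the refactor.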

lbrynet/extras/daemon/Components.py

@@ -5,7 +5,6 @@ import logging
 import math
 import binascii
 import typing
-import socket
 from hashlib import sha256
 from types import SimpleNamespace
 import base58

@@ -18,7 +17,6 @@ import lbrynet.schema
 from lbrynet import utils
 from lbrynet.conf import HEADERS_FILE_SHA256_CHECKSUM
 from lbrynet.dht.node import Node
-from lbrynet.dht.peer import KademliaPeer
 from lbrynet.dht.blob_announcer import BlobAnnouncer
 from lbrynet.blob.blob_manager import BlobFileManager
 from lbrynet.blob_exchange.server import BlobServer
@@ -65,14 +63,6 @@ async def get_external_ip():  # used if upnp is disabled or non-functioning
         pass


-async def resolve_host(loop: asyncio.BaseEventLoop, url: str):
-    info = await loop.getaddrinfo(
-        url, 'https',
-        proto=socket.IPPROTO_TCP,
-    )
-    return info[0][4][0]
-
-
 class DatabaseComponent(Component):
     component_name = DATABASE_COMPONENT
@@ -463,11 +453,7 @@ class StreamManagerComponent(Component):
         log.info('Starting the file manager')
         loop = asyncio.get_event_loop()
         self.stream_manager = StreamManager(
-            loop, blob_manager, wallet, storage, node, self.conf.blob_download_timeout,
-            self.conf.peer_connect_timeout, [
-                KademliaPeer(loop, address=(await resolve_host(loop, url)), tcp_port=port + 1)
-                for url, port in self.conf.reflector_servers
-            ], self.conf.reflector_servers
+            loop, self.conf, blob_manager, wallet, storage, node,
         )
         await self.stream_manager.start()
         log.info('Done setting up file manager')

lbrynet/extras/daemon/Daemon.py

@@ -18,7 +18,7 @@ from torba.client.baseaccount import SingleKey, HierarchicalDeterministic
 from lbrynet import __version__, utils
 from lbrynet.conf import Config, Setting, SLACK_WEBHOOK
 from lbrynet.blob.blob_file import is_valid_blobhash
-from lbrynet.blob_exchange.downloader import BlobDownloader
+from lbrynet.blob_exchange.downloader import download_blob
 from lbrynet.error import InsufficientFundsError, DownloadSDTimeout, ComponentsNotStarted
 from lbrynet.error import NullFundsError, NegativeFundsError, ResolveError, ComponentStartConditionNotMet
 from lbrynet.extras import system_info

@@ -1582,7 +1582,7 @@ class Daemon(metaclass=JSONRPCServerType):
             stream = existing[0]
         else:
             stream = await self.stream_manager.download_stream_from_claim(
-                self.dht_node, self.conf, resolved, file_name, timeout, fee_amount, fee_address
+                self.dht_node, resolved, file_name, timeout, fee_amount, fee_address
             )
         if stream:
             return stream.as_dict()

@@ -2567,8 +2567,7 @@ class Daemon(metaclass=JSONRPCServerType):
             (str) Success/Fail message or (dict) decoded data
         """

-        downloader = BlobDownloader(asyncio.get_event_loop(), self.blob_manager, self.conf)
-        blob = await downloader.get_blob(blob_hash, self.dht_node)
+        blob = await download_blob(asyncio.get_event_loop(), self.conf, self.blob_manager, self.dht_node, blob_hash)
         if read:
             with open(blob.file_path, 'rb') as handle:
                 return handle.read().decode()

lbrynet/stream/assembler.py

@@ -48,6 +48,8 @@ class StreamAssembler:
         def _decrypt_and_write():
+            if self.stream_handle.closed:
+                return False
             if not blob:
                 return False
             self.stream_handle.seek(offset)
             _decrypted = blob.decrypt(
                 binascii.unhexlify(key), binascii.unhexlify(blob_info.iv.encode())

@@ -62,15 +64,26 @@ class StreamAssembler:
             log.debug("decrypted %s", blob.blob_hash[:8])
             return

+    async def setup(self):
+        pass
+
+    async def after_got_descriptor(self):
+        pass
+
+    async def after_finished(self):
+        pass
+
     async def assemble_decrypted_stream(self, output_dir: str, output_file_name: typing.Optional[str] = None):
         if not os.path.isdir(output_dir):
             raise OSError(f"output directory does not exist: '{output_dir}' '{output_file_name}'")
+        await self.setup()
         self.sd_blob = await self.get_blob(self.sd_hash)
         await self.blob_manager.blob_completed(self.sd_blob)
         self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_manager.blob_dir,
                                                                              self.sd_blob)
         if not self.got_descriptor.is_set():
             self.got_descriptor.set()
+        await self.after_got_descriptor()
         self.output_path = await get_next_available_file_name(self.loop, output_dir,
                                                               output_file_name or self.descriptor.suggested_file_name)

@@ -85,17 +98,16 @@ class StreamAssembler:
                     blob = await self.get_blob(blob_info.blob_hash, blob_info.length)
                     await self._decrypt_blob(blob, blob_info, self.descriptor.key)
                     break
-                except ValueError as err:
+                except (ValueError, IOError, OSError) as err:
                     log.error("failed to decrypt blob %s for stream %s - %s", blob_info.blob_hash,
                               self.descriptor.sd_hash, str(err))
                     continue
             if not self.wrote_bytes_event.is_set():
                 self.wrote_bytes_event.set()
             self.stream_finished_event.set()
+            await self.after_finished()
         finally:
             self.stream_handle.close()

     async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
-        f = asyncio.Future(loop=self.loop)
-        f.set_result(self.blob_manager.get_blob(blob_hash, length))
-        return await f
+        return self.blob_manager.get_blob(blob_hash, length)
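
The three no-op coroutines make StreamAssembler a template: assemble_decrypted_stream() calls setup() before fetching the sd blob, after_got_descriptor() once the descriptor is parsed, and after_finished() when the stream is fully written. StreamDownloader (next file) fills them in; a toy subclass shows the shape (the prints are illustrative only):

    from lbrynet.stream.assembler import StreamAssembler

    class LoggingAssembler(StreamAssembler):
        async def setup(self):
            print("starting", self.sd_hash[:8])  # runs before the sd blob is fetched

        async def after_got_descriptor(self):
            print("descriptor lists", len(self.descriptor.blobs), "blobs")

        async def after_finished(self):
            print("assembled to", self.output_path)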

lbrynet/stream/downloader.py

@@ -1,13 +1,14 @@
 import os
 import asyncio
 import typing
+import socket
 import logging
-from lbrynet.utils import drain_tasks, cancel_task
 from lbrynet.stream.assembler import StreamAssembler
-from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob
 from lbrynet.stream.descriptor import StreamDescriptor
+from lbrynet.blob_exchange.downloader import BlobDownloader
+from lbrynet.dht.peer import KademliaPeer
 if typing.TYPE_CHECKING:
+    from lbrynet.conf import Config
     from lbrynet.dht.node import Node
-    from lbrynet.dht.peer import KademliaPeer
     from lbrynet.blob.blob_manager import BlobFileManager
     from lbrynet.blob.blob_file import BlobFile

@@ -19,212 +20,72 @@ def drain_into(a: list, b: list):
         b.append(a.pop())


-class StreamDownloader(StreamAssembler):  # TODO: reduce duplication, refactor to inherit BlobDownloader
-    def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', sd_hash: str,
-                 peer_timeout: float, peer_connect_timeout: float, output_dir: typing.Optional[str] = None,
-                 output_file_name: typing.Optional[str] = None,
-                 fixed_peers: typing.Optional[typing.List['KademliaPeer']] = None,
-                 max_connections_per_stream: typing.Optional[int] = 8):
+async def resolve_host(loop: asyncio.BaseEventLoop, url: str):
+    info = await loop.getaddrinfo(
+        url, 'https',
+        proto=socket.IPPROTO_TCP,
+    )
+    return info[0][4][0]
+
+
+class StreamDownloader(StreamAssembler):
+    def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager', sd_hash: str,
+                 output_dir: typing.Optional[str] = None, output_file_name: typing.Optional[str] = None):
         super().__init__(loop, blob_manager, sd_hash)
-        self.peer_timeout = peer_timeout
-        self.peer_connect_timeout = peer_connect_timeout
-        self.current_blob: 'BlobFile' = None
-        self.download_task: asyncio.Task = None
-        self.accumulate_connections_task: asyncio.Task = None
-        self.new_peer_event = asyncio.Event(loop=self.loop)
-        self.active_connections: typing.Dict['KademliaPeer', BlobExchangeClientProtocol] = {}
-        self.running_download_requests: typing.List[asyncio.Task] = []
-        self.requested_from: typing.Dict[str, typing.Dict['KademliaPeer', asyncio.Task]] = {}
-        self.output_dir = output_dir or os.getcwd()
+        self.config = config
+        self.output_dir = output_dir or self.config.download_dir
         self.output_file_name = output_file_name
-        self._lock = asyncio.Lock(loop=self.loop)
-        self.max_connections_per_stream = max_connections_per_stream
-        self.fixed_peers = fixed_peers or []
+        self.blob_downloader: typing.Optional[BlobDownloader] = None
+        self.search_queue = asyncio.Queue(loop=loop)
+        self.peer_queue = asyncio.Queue(loop=loop)
+        self.accumulate_task: typing.Optional[asyncio.Task] = None
+        self.descriptor: typing.Optional[StreamDescriptor]
+        self.node: typing.Optional['Node'] = None
+        self.assemble_task: typing.Optional[asyncio.Task] = None
+        self.fixed_peers_handle: typing.Optional[asyncio.Handle] = None

-    async def _update_current_blob(self, blob: 'BlobFile'):
-        async with self._lock:
-            drain_tasks(self.running_download_requests)
-            self.current_blob = blob
-            if not blob.get_is_verified():
-                self._update_requests()
+    async def setup(self):  # start the peer accumulator and initialize the downloader
+        if self.blob_downloader:
+            raise Exception("downloader is already set up")
+        if self.node:
+            _, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue)
+        self.blob_downloader = BlobDownloader(self.loop, self.config, self.blob_manager, self.peer_queue)
+        self.search_queue.put_nowait(self.sd_hash)

-    async def _request_blob(self, peer: 'KademliaPeer'):
-        if self.current_blob.get_is_verified():
-            log.debug("already verified")
-            return
-        if peer not in self.active_connections:
-            log.warning("not active, adding: %s", str(peer))
-            self.active_connections[peer] = BlobExchangeClientProtocol(self.loop, self.peer_timeout)
-        protocol = self.active_connections[peer]
-        success, keep_connection = await request_blob(self.loop, self.current_blob, protocol,
-                                                      peer.address, peer.tcp_port, self.peer_connect_timeout)
-        await protocol.close()
-        if not keep_connection:
-            log.debug("drop peer %s:%i", peer.address, peer.tcp_port)
-            if peer in self.active_connections:
-                async with self._lock:
-                    del self.active_connections[peer]
-            return
-        log.debug("keep peer %s:%i", peer.address, peer.tcp_port)
+    async def after_got_descriptor(self):
+        self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
+        log.info("added head blob to search")

-    def _update_requests(self):
-        self.new_peer_event.clear()
-        if self.current_blob.blob_hash not in self.requested_from:
-            self.requested_from[self.current_blob.blob_hash] = {}
-        to_add = []
-        for peer in self.active_connections.keys():
-            if peer not in self.requested_from[self.current_blob.blob_hash] and peer not in to_add:
-                to_add.append(peer)
-        if to_add or self.running_download_requests:
-            log.debug("adding download probes for %i peers to %i already active",
-                      min(len(to_add), 8 - len(self.running_download_requests)),
-                      len(self.running_download_requests))
-        else:
-            log.info("downloader idle...")
-        for peer in to_add:
-            if len(self.running_download_requests) >= self.max_connections_per_stream:
-                break
-            task = self.loop.create_task(self._request_blob(peer))
-            self.requested_from[self.current_blob.blob_hash][peer] = task
-            self.running_download_requests.append(task)
-
-    async def wait_for_download_or_new_peer(self) -> typing.Optional['BlobFile']:
-        async with self._lock:
-            if len(self.running_download_requests) < self.max_connections_per_stream:
-                # update the running download requests
-                self._update_requests()
-
-            # drain the tasks into a temporary list
-            download_tasks = []
-            drain_into(self.running_download_requests, download_tasks)
-
-        got_new_peer = self.loop.create_task(self.new_peer_event.wait())
-
-        # wait for a new peer to be added or for a download attempt to finish
-        await asyncio.wait([got_new_peer] + download_tasks, return_when='FIRST_COMPLETED',
-                           loop=self.loop)
-        if got_new_peer and not got_new_peer.done():
-            got_new_peer.cancel()
-
-        async with self._lock:
-            if self.current_blob.get_is_verified():
-                # a download attempt finished
-                if got_new_peer and not got_new_peer.done():
-                    got_new_peer.cancel()
-                drain_tasks(download_tasks)
-                return self.current_blob
-            else:
-                # we got a new peer, re add the other pending download attempts
-                for task in download_tasks:
-                    if task and not task.done():
-                        self.running_download_requests.append(task)
-                return
-
-    async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
-        blob = self.blob_manager.get_blob(blob_hash, length)
-        await self._update_current_blob(blob)
-        if blob.get_is_verified():
-            return blob
-
-        # the blob must be downloaded
-        try:
-            while not self.current_blob.get_is_verified():
-                if not self.active_connections:  # wait for a new connection
-                    await self.new_peer_event.wait()
-                    continue
-                blob = await self.wait_for_download_or_new_peer()
-                if blob:
-                    drain_tasks(self.running_download_requests)
-                    return blob
-            return blob
-        except asyncio.CancelledError:
-            drain_tasks(self.running_download_requests)
-            raise
-
-    def _add_peer_protocols(self, peers: typing.List['KademliaPeer']):
-        added = 0
-        for peer in peers:
-            if peer not in self.active_connections:
-                self.active_connections[peer] = BlobExchangeClientProtocol(self.loop, self.peer_timeout)
-                added += 1
-        if added:
-            if not self.new_peer_event.is_set():
-                log.debug("added %i new peers", len(peers))
-                self.new_peer_event.set()
-
-    async def _accumulate_connections(self, node: 'Node'):
-        blob_queue = asyncio.Queue(loop=self.loop)
-        blob_queue.put_nowait(self.sd_hash)
-        task = asyncio.create_task(self.got_descriptor.wait())
-        add_fixed_peers_timer: typing.Optional[asyncio.Handle] = None
-
-        if self.fixed_peers:
-            def check_added_peers():
-                self._add_peer_protocols(self.fixed_peers)
-                log.info("adding fixed peer %s:%i", self.fixed_peers[0].address, self.fixed_peers[0].tcp_port)
-
-            add_fixed_peers_timer = self.loop.call_later(2, check_added_peers)
-
-        def got_descriptor(f):
-            try:
-                f.result()
-            except asyncio.CancelledError:
-                return
-            log.info("add head blob hash to peer search")
-            blob_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
-
-        task.add_done_callback(got_descriptor)
-        try:
-            async with node.stream_peer_search_junction(blob_queue) as search_junction:
-                async for peers in search_junction:
-                    if peers:
-                        self._add_peer_protocols(peers)
-            return
-        except asyncio.CancelledError:
-            pass
-        finally:
-            if task and not task.done():
-                task.cancel()
-                log.info("cancelled head blob task")
-            if add_fixed_peers_timer and not add_fixed_peers_timer.cancelled():
-                add_fixed_peers_timer.cancel()
+    async def after_finished(self):
+        log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
+        await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')

     async def stop(self):
-        cancel_task(self.accumulate_connections_task)
-        self.accumulate_connections_task = None
-        drain_tasks(self.running_download_requests)
+        if self.accumulate_task and not self.accumulate_task.done():
+            self.accumulate_task.cancel()
+        self.accumulate_task = None
+        if self.assemble_task and not self.assemble_task.done():
+            self.assemble_task.cancel()
+        self.assemble_task = None
+        if self.fixed_peers_handle:
+            self.fixed_peers_handle.cancel()
+        self.fixed_peers_handle = None
+        self.blob_downloader = None

-        while self.requested_from:
-            _, peer_task_dict = self.requested_from.popitem()
-            while peer_task_dict:
-                peer, task = peer_task_dict.popitem()
-                try:
-                    cancel_task(task)
-                except asyncio.CancelledError:
-                    pass
+    async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
+        return await self.blob_downloader.download_blob(blob_hash, length)

-        while self.active_connections:
-            _, client = self.active_connections.popitem()
-            if client:
-                await client.close()
-        log.info("stopped downloader")
+    def add_fixed_peers(self):
+        async def _add_fixed_peers():
+            self.peer_queue.put_nowait([
+                KademliaPeer(self.loop, address=(await resolve_host(self.loop, url)), tcp_port=port + 1)
+                for url, port in self.config.reflector_servers
+            ])

-    async def _download(self):
-        try:
+        self.fixed_peers_handle = self.loop.call_later(self.config.fixed_peer_delay, self.loop.create_task,
+                                                       _add_fixed_peers())

-            log.info("download and decrypt stream")
-            await self.assemble_decrypted_stream(self.output_dir, self.output_file_name)
-            log.info(
-                "downloaded stream %s -> %s", self.sd_hash, self.output_path
-            )
-            await self.blob_manager.storage.change_file_status(
-                self.descriptor.stream_hash, 'finished'
-            )
-        except asyncio.CancelledError:
-            pass
-        finally:
-            await self.stop()
-
-    def download(self, node: 'Node'):
-        self.accumulate_connections_task = self.loop.create_task(self._accumulate_connections(node))
-        self.download_task = self.loop.create_task(self._download())
+    def download(self, node: typing.Optional['Node'] = None):
+        self.node = node
+        self.assemble_task = self.loop.create_task(self.assemble_decrypted_stream(self.config.download_dir))
+        self.add_fixed_peers()
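
download() no longer drives its own connection pool: it records the node, schedules assemble_decrypted_stream() (whose setup() hook starts Node.accumulate_peers and builds the BlobDownloader), and arms the fixed-peer timer. A hedged caller sketch (config, blob_manager and node are assumed to be already-started components, as in StreamManager below):

    import asyncio
    from lbrynet.stream.downloader import StreamDownloader

    async def start_stream(loop, config, blob_manager, node, sd_hash: str) -> StreamDownloader:
        downloader = StreamDownloader(loop, config, blob_manager, sd_hash)
        downloader.download(node)  # schedules assemble_task; setup() runs inside it
        try:
            await downloader.got_descriptor.wait()  # the same wait StreamManager performs
        except asyncio.CancelledError:
            await downloader.stop()
            raise
        return downloader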

lbrynet/stream/stream_manager.py

@@ -46,23 +46,18 @@ comparison_operators = {


 class StreamManager:
-    def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', wallet: 'LbryWalletManager',
-                 storage: 'SQLiteStorage', node: typing.Optional['Node'], peer_timeout: float,
-                 peer_connect_timeout: float, fixed_peers: typing.Optional[typing.List['KademliaPeer']] = None,
-                 reflector_servers: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
+    def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager',
+                 wallet: 'LbryWalletManager', storage: 'SQLiteStorage', node: typing.Optional['Node']):
         self.loop = loop
+        self.config = config
         self.blob_manager = blob_manager
         self.wallet = wallet
         self.storage = storage
         self.node = node
-        self.peer_timeout = peer_timeout
-        self.peer_connect_timeout = peer_connect_timeout
         self.streams: typing.Set[ManagedStream] = set()
         self.starting_streams: typing.Dict[str, asyncio.Future] = {}
         self.resume_downloading_task: asyncio.Task = None
         self.update_stream_finished_futs: typing.List[asyncio.Future] = []
-        self.fixed_peers = fixed_peers
-        self.reflector_servers = reflector_servers

     async def load_streams_from_database(self):
         infos = await self.storage.get_all_lbry_files()

@@ -71,9 +66,9 @@ class StreamManager:
             if sd_blob.get_is_verified():
                 descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
                 downloader = StreamDownloader(
-                    self.loop, self.blob_manager, descriptor.sd_hash, self.peer_timeout,
-                    self.peer_connect_timeout, binascii.unhexlify(file_info['download_directory']).decode(),
-                    binascii.unhexlify(file_info['file_name']).decode(), self.fixed_peers
+                    self.loop, self.config, self.blob_manager, descriptor.sd_hash,
+                    binascii.unhexlify(file_info['download_directory']).decode(),
+                    binascii.unhexlify(file_info['file_name']).decode()
                 )
                 stream = ManagedStream(
                     self.loop, self.blob_manager, descriptor,

@@ -128,8 +123,8 @@ class StreamManager:
                      iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream:
         stream = await ManagedStream.create(self.loop, self.blob_manager, file_path, key, iv_generator)
         self.streams.add(stream)
-        if self.reflector_servers:
-            host, port = random.choice(self.reflector_servers)
+        if self.config.reflector_servers:
+            host, port = random.choice(self.config.reflector_servers)
             self.loop.create_task(stream.upload_to_reflector(host, port))
         return stream

@@ -166,8 +161,8 @@ class StreamManager:
                                          file_name: typing.Optional[str] = None) -> typing.Optional[ManagedStream]:

         claim = ClaimDict.load_dict(claim_info['value'])
-        downloader = StreamDownloader(self.loop, self.blob_manager, claim.source_hash.decode(), self.peer_timeout,
-                                      self.peer_connect_timeout, download_directory, file_name, self.fixed_peers)
+        downloader = StreamDownloader(self.loop, self.config, self.blob_manager, claim.source_hash.decode(),
+                                      download_directory, file_name)
         try:
             downloader.download(node)
             await downloader.got_descriptor.wait()

@@ -205,7 +200,7 @@ class StreamManager:
         except asyncio.CancelledError:
             await downloader.stop()

-    async def download_stream_from_claim(self, node: 'Node', config: 'Config', claim_info: typing.Dict,
+    async def download_stream_from_claim(self, node: 'Node', claim_info: typing.Dict,
                                          file_name: typing.Optional[str] = None,
                                          timeout: typing.Optional[float] = 60,
                                          fee_amount: typing.Optional[float] = 0.0,

@@ -224,10 +219,10 @@ class StreamManager:

         self.starting_streams[sd_hash] = asyncio.Future(loop=self.loop)
         stream_task = self.loop.create_task(
-            self._download_stream_from_claim(node, config.download_dir, claim_info, file_name)
+            self._download_stream_from_claim(node, self.config.download_dir, claim_info, file_name)
         )
         try:
-            await asyncio.wait_for(stream_task, timeout or config.download_timeout)
+            await asyncio.wait_for(stream_task, timeout or self.config.download_timeout)
             stream = await stream_task
             self.starting_streams[sd_hash].set_result(stream)
             if fee_address and fee_amount:
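
Since StreamManager now owns the Config, callers drop the config argument and the per-setting parameters; timeouts and the download directory are read from self.config instead. A sketch of the updated call (claim_info as returned by resolve, per the Daemon hunk above):

    stream = await stream_manager.download_stream_from_claim(
        dht_node, claim_info, file_name=None,  # timeout falls back to config.download_timeout
    )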