lbry-sdk/lbrynet/stream/downloader.py
2019-01-28 22:01:43 -05:00

231 lines
9.8 KiB
Python

import os
import asyncio
import typing
import logging
from lbrynet.utils import drain_tasks, cancel_task
from lbrynet.stream.assembler import StreamAssembler
from lbrynet.blob_exchange.client import BlobExchangeClientProtocol, request_blob
if typing.TYPE_CHECKING:
from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.blob_file import BlobFile
log = logging.getLogger(__name__)
def drain_into(a: list, b: list):
    """Move every item out of ``a`` and append it to ``b``.

    Items are popped from the tail of ``a``, so they are appended to ``b`` in
    reverse order. ``a`` is emptied in place; nothing is returned.
    """
    while len(a) > 0:
        item = a.pop()
        b.append(item)
class StreamDownloader(StreamAssembler):  # TODO: reduce duplication, refactor to inherit BlobDownloader
    """Download the blobs of a stream from peers while the stream is assembled.

    Peers are collected in ``active_connections`` (from ``fixed_peers`` and from the
    node's peer search) and the blob currently wanted is requested from at most
    ``max_connections_per_stream`` of them at a time.

    NOTE(review): ``self.loop``, ``self.sd_hash``, ``self.blob_manager``, ``self.descriptor``,
    ``self.got_descriptor``, ``self.output_path`` and ``assemble_decrypted_stream`` are not defined
    in this class; presumably they come from ``StreamAssembler`` - confirm there.
    """

    def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', sd_hash: str,
                 peer_timeout: float, peer_connect_timeout: float, output_dir: typing.Optional[str] = None,
                 output_file_name: typing.Optional[str] = None,
                 fixed_peers: typing.Optional[typing.List['KademliaPeer']] = None,
                 max_connections_per_stream: typing.Optional[int] = 8):
        """
        :param loop: forwarded to ``StreamAssembler.__init__``
        :param blob_manager: forwarded to ``StreamAssembler.__init__``
        :param sd_hash: forwarded to ``StreamAssembler.__init__``
        :param peer_timeout: passed to every ``BlobExchangeClientProtocol`` created for a peer
        :param peer_connect_timeout: passed to ``request_blob`` for each blob request
        :param output_dir: directory handed to ``assemble_decrypted_stream``; defaults to the
            current working directory
        :param output_file_name: file name handed to ``assemble_decrypted_stream``
        :param fixed_peers: peers that are added two seconds after peer accumulation starts
        :param max_connections_per_stream: upper bound on concurrently running blob requests
        """
        super().__init__(loop, blob_manager, sd_hash)
        self.peer_timeout = peer_timeout
        self.peer_connect_timeout = peer_connect_timeout
        self.current_blob: 'BlobFile' = None
        self.download_task: asyncio.Task = None
        self.accumulate_connections_task: asyncio.Task = None
        # set when a previously unknown peer is added, cleared by _update_requests()
        self.new_peer_event = asyncio.Event(loop=self.loop)
        self.active_connections: typing.Dict['KademliaPeer', BlobExchangeClientProtocol] = {}
        # request tasks that are in flight for the current blob
        self.running_download_requests: typing.List[asyncio.Task] = []
        # blob hash -> {peer: request task}; a peer is asked at most once per blob
        self.requested_from: typing.Dict[str, typing.Dict['KademliaPeer', asyncio.Task]] = {}
        self.output_dir = output_dir or os.getcwd()
        self.output_file_name = output_file_name
        self._lock = asyncio.Lock(loop=self.loop)
        self.max_connections_per_stream = max_connections_per_stream
        self.fixed_peers = fixed_peers or []

    async def _update_current_blob(self, blob: 'BlobFile'):
        """Make ``blob`` the current blob, draining the requests that were running before.

        New requests are started right away when ``blob`` is not verified yet.
        """
        async with self._lock:
            drain_tasks(self.running_download_requests)
            self.current_blob = blob
            if not blob.get_is_verified():
                self._update_requests()

    async def _request_blob(self, peer: 'KademliaPeer'):
        """Request the current blob from ``peer`` and drop the peer unless it should be kept."""
        if self.current_blob.get_is_verified():
            log.debug("already verified")
            return
        if peer not in self.active_connections:
            log.warning("not active, adding: %s", str(peer))
            self.active_connections[peer] = BlobExchangeClientProtocol(self.loop, self.peer_timeout)
        protocol = self.active_connections[peer]
        _success, keep_connection = await request_blob(self.loop, self.current_blob, protocol,
                                                       peer.address, peer.tcp_port, self.peer_connect_timeout)
        await protocol.close()
        if not keep_connection:
            log.debug("drop peer %s:%i", peer.address, peer.tcp_port)
            async with self._lock:
                # pop() instead of check-then-del: the peer may already have been removed
                self.active_connections.pop(peer, None)
            return
        log.debug("keep peer %s:%i", peer.address, peer.tcp_port)

    def _update_requests(self):
        """Start requests for the current blob to peers that have not been asked for it yet.

        Clears ``new_peer_event`` and never lets ``running_download_requests`` grow past
        ``max_connections_per_stream``.
        """
        self.new_peer_event.clear()
        already_asked = self.requested_from.setdefault(self.current_blob.blob_hash, {})
        to_add = [peer for peer in self.active_connections if peer not in already_asked]
        if to_add or self.running_download_requests:
            # report against the configured limit, never a negative number of probes
            free_slots = max(0, self.max_connections_per_stream - len(self.running_download_requests))
            log.debug("adding download probes for %i peers to %i already active",
                      min(len(to_add), free_slots), len(self.running_download_requests))
        else:
            log.info("downloader idle...")
        for peer in to_add:
            if len(self.running_download_requests) >= self.max_connections_per_stream:
                break
            task = self.loop.create_task(self._request_blob(peer))
            already_asked[peer] = task
            self.running_download_requests.append(task)

    async def wait_for_download_or_new_peer(self) -> typing.Optional['BlobFile']:
        """Wait until a running request finishes or a new peer is added.

        :return: the current blob if it is verified afterwards, otherwise ``None``; in the
            latter case the unfinished requests are put back in ``running_download_requests``
        """
        async with self._lock:
            if len(self.running_download_requests) < self.max_connections_per_stream:
                # update the running download requests
                self._update_requests()

            # drain the tasks into a temporary list
            download_tasks = []
            drain_into(self.running_download_requests, download_tasks)
            got_new_peer = self.loop.create_task(self.new_peer_event.wait())

        # wait for a new peer to be added or for a download attempt to finish
        await asyncio.wait([got_new_peer] + download_tasks, return_when=asyncio.FIRST_COMPLETED,
                           loop=self.loop)
        if not got_new_peer.done():
            got_new_peer.cancel()

        async with self._lock:
            if self.current_blob.get_is_verified():
                drain_tasks(download_tasks)
                return self.current_blob
            # keep only the requests that are still in flight
            self.running_download_requests.extend(task for task in download_tasks if not task.done())
            return None

    async def get_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
        """Return the blob for ``blob_hash``, downloading it from peers when it is not verified.

        :param blob_hash: hash passed to ``blob_manager.get_blob``
        :param length: length passed to ``blob_manager.get_blob``
        :raises asyncio.CancelledError: re-raised after the running requests were drained
        """
        blob = self.blob_manager.get_blob(blob_hash, length)
        await self._update_current_blob(blob)
        if blob.get_is_verified():
            return blob

        # the blob must be downloaded
        try:
            while not self.current_blob.get_is_verified():
                if not self.active_connections:  # wait for a new connection
                    await self.new_peer_event.wait()
                    continue
                downloaded = await self.wait_for_download_or_new_peer()
                if downloaded:
                    drain_tasks(self.running_download_requests)
                    return downloaded
            # The loop can end right after a wait that returned None (the blob became verified
            # in between), so return the requested blob rather than the last wait result.
            return blob
        except asyncio.CancelledError:
            drain_tasks(self.running_download_requests)
            raise

    def _add_peer_protocols(self, peers: typing.List['KademliaPeer']):
        """Create a client protocol for each unknown peer and set ``new_peer_event`` if any was added."""
        added = 0
        for peer in peers:
            if peer not in self.active_connections:
                self.active_connections[peer] = BlobExchangeClientProtocol(self.loop, self.peer_timeout)
                added += 1
        if added and not self.new_peer_event.is_set():
            log.info("added %i new peers", added)
            self.new_peer_event.set()

    async def _accumulate_connections(self, node: 'Node'):
        """Feed peers found by ``node`` into ``active_connections`` until the search ends.

        The search starts with the sd hash; the hash of the first blob in the descriptor is
        queued once ``got_descriptor`` is set. Cancellation is swallowed, and the helper task
        and the fixed-peer timer are always cancelled on the way out.
        """
        blob_queue = asyncio.Queue(loop=self.loop)
        blob_queue.put_nowait(self.sd_hash)
        # same task factory as the rest of the class
        task = self.loop.create_task(self.got_descriptor.wait())
        add_fixed_peers_timer: typing.Optional[asyncio.Handle] = None

        if self.fixed_peers:
            def check_added_peers():
                self._add_peer_protocols(self.fixed_peers)
                log.info("adding fixed peer %s:%i", self.fixed_peers[0].address, self.fixed_peers[0].tcp_port)

            add_fixed_peers_timer = self.loop.call_later(2, check_added_peers)

        def got_descriptor(f):
            # the wait task is cancelled in the finally block below when the search ends first
            if f.cancelled():
                return
            f.result()
            log.info("add head blob hash to peer search")
            blob_queue.put_nowait(self.descriptor.blobs[0].blob_hash)

        task.add_done_callback(got_descriptor)
        try:
            async with node.stream_peer_search_junction(blob_queue) as search_junction:
                async for peers in search_junction:
                    if peers:
                        self._add_peer_protocols(peers)
        except asyncio.CancelledError:
            pass
        finally:
            if task and not task.done():
                task.cancel()
                log.info("cancelled head blob task")
            if add_fixed_peers_timer and not add_fixed_peers_timer.cancelled():
                add_fixed_peers_timer.cancel()

    async def stop(self):
        """Cancel peer accumulation and all blob requests, then close every peer connection."""
        cancel_task(self.accumulate_connections_task)
        self.accumulate_connections_task = None
        drain_tasks(self.running_download_requests)

        while self.requested_from:
            _, peer_task_dict = self.requested_from.popitem()
            while peer_task_dict:
                _, task = peer_task_dict.popitem()
                try:
                    cancel_task(task)
                except asyncio.CancelledError:
                    pass

        while self.active_connections:
            _, client = self.active_connections.popitem()
            if client:
                await client.close()
        log.info("stopped downloader")

    async def _download(self):
        """Assemble the decrypted stream, mark the file 'finished' and always stop the downloader."""
        try:
            log.info("download and decrypt stream")
            await self.assemble_decrypted_stream(self.output_dir, self.output_file_name)
            log.info(
                "downloaded stream %s -> %s", self.sd_hash, self.output_path
            )
            await self.blob_manager.storage.change_file_status(
                self.descriptor.stream_hash, 'finished'
            )
        except asyncio.CancelledError:
            pass
        finally:
            await self.stop()

    def download(self, node: 'Node'):
        """Schedule peer accumulation and the download as tasks on ``self.loop``."""
        self.accumulate_connections_task = self.loop.create_task(self._accumulate_connections(node))
        self.download_task = self.loop.create_task(self._download())