diff --git a/lbrynet/extras/daemon/Components.py b/lbrynet/extras/daemon/Components.py index bf5c8fb51..c1afdc65b 100644 --- a/lbrynet/extras/daemon/Components.py +++ b/lbrynet/extras/daemon/Components.py @@ -466,7 +466,7 @@ class StreamManagerComponent(Component): self.conf.peer_connect_timeout, [ KademliaPeer(loop, address=(await resolve_host(loop, url)), tcp_port=port + 1) for url, port in self.conf.reflector_servers - ] + ], self.conf.reflector_servers ) await self.stream_manager.start() log.info('Done setting up file manager') diff --git a/lbrynet/stream/managed_stream.py b/lbrynet/stream/managed_stream.py index a7d75c956..6a637ec1e 100644 --- a/lbrynet/stream/managed_stream.py +++ b/lbrynet/stream/managed_stream.py @@ -5,6 +5,7 @@ import logging from lbrynet.extras.daemon.mime_types import guess_media_type from lbrynet.stream.downloader import StreamDownloader from lbrynet.stream.descriptor import StreamDescriptor +from lbrynet.stream.reflector.client import StreamReflectorClient if typing.TYPE_CHECKING: from lbrynet.extras.daemon.storage import StoredStreamClaim from lbrynet.blob.blob_manager import BlobFileManager @@ -30,6 +31,7 @@ class ManagedStream: self.stream_claim_info = claim self._status = status self._store_after_finished: asyncio.Task = None + self.fully_reflected = asyncio.Event(loop=self.loop) @property def status(self) -> str: @@ -139,9 +141,10 @@ class ManagedStream: @classmethod async def create(cls, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', - file_path: str) -> 'ManagedStream': + file_path: str, key: typing.Optional[bytes] = None, + iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> 'ManagedStream': descriptor = await StreamDescriptor.create_stream( - loop, blob_manager.blob_dir, file_path + loop, blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator ) sd_blob = blob_manager.get_blob(descriptor.sd_hash) await blob_manager.blob_completed(sd_blob) @@ -156,3 +159,38 @@ class ManagedStream: await self.downloader.stop() if not self.finished: self.update_status(self.STATUS_STOPPED) + + async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]: + sent = [] + protocol = StreamReflectorClient(self.blob_manager, self.descriptor) + try: + await self.loop.create_connection(lambda: protocol, host, port) + except ConnectionRefusedError: + return sent + try: + await protocol.send_handshake() + except (asyncio.CancelledError, asyncio.TimeoutError, ValueError): + if protocol.transport: + protocol.transport.close() + return sent + try: + sent_sd, needed = await protocol.send_descriptor() + if sent_sd: + sent.append(self.sd_hash) + except (asyncio.CancelledError, asyncio.TimeoutError, ValueError): + if protocol.transport: + protocol.transport.close() + return sent + for blob_hash in needed: + try: + await protocol.send_blob(blob_hash) + sent.append(blob_hash) + except (asyncio.CancelledError, asyncio.TimeoutError, ValueError): + if protocol.transport: + protocol.transport.close() + return sent + if protocol.transport: + protocol.transport.close() + if not self.fully_reflected.is_set(): + self.fully_reflected.set() + return sent diff --git a/lbrynet/stream/reflector/__init__.py b/lbrynet/stream/reflector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lbrynet/stream/reflector/client.py b/lbrynet/stream/reflector/client.py new file mode 100644 index 000000000..ce4d7528e --- /dev/null +++ b/lbrynet/stream/reflector/client.py @@ -0,0 +1,103 @@ +import asyncio +import json +import logging +import typing + +if typing.TYPE_CHECKING: + from lbrynet.blob.blob_manager import BlobFileManager + from lbrynet.stream.descriptor import StreamDescriptor + +REFLECTOR_V1 = 0 +REFLECTOR_V2 = 1 + +log = logging.getLogger(__name__) + + +class StreamReflectorClient(asyncio.Protocol): + def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'): + self.transport: asyncio.StreamWriter = None + self.blob_manager = blob_manager + self.descriptor = descriptor + self.response_buff = b'' + self.reflected_blobs = [] + self.connected = asyncio.Event() + self.response_queue = asyncio.Queue(maxsize=1) + self.pending_request: typing.Optional[asyncio.Task] = None + + def connection_made(self, transport): + self.transport = transport + log.info("Connected to reflector") + self.connected.set() + + def connection_lost(self, exc: typing.Optional[Exception]): + self.transport = None + self.connected.clear() + + def data_received(self, data): + try: + response = json.loads(data.decode()) + self.response_queue.put_nowait(response) + except ValueError: + self.transport.close() + return + + async def send_request(self, request_dict: typing.Dict): + msg = json.dumps(request_dict) + self.transport.write(msg.encode()) + try: + self.pending_request = asyncio.get_event_loop().create_task(self.response_queue.get()) + return await self.pending_request + finally: + self.pending_request = None + + async def send_handshake(self) -> None: + response_dict = await self.send_request({'version': REFLECTOR_V2}) + if 'version' not in response_dict: + raise ValueError("Need protocol version number!") + server_version = int(response_dict['version']) + if server_version != REFLECTOR_V2: + raise ValueError("I can't handle protocol version {}!".format(server_version)) + return + + async def send_descriptor(self) -> typing.Tuple[bool, typing.List[str]]: # returns a list of needed blob hashes + sd_blob = self.blob_manager.get_blob(self.descriptor.sd_hash) + assert sd_blob.get_is_verified(), "need to have a sd blob to send at this point" + response = await self.send_request({ + 'sd_blob_hash': sd_blob.blob_hash, + 'sd_blob_size': sd_blob.length + }) + if 'send_sd_blob' not in response: + raise ValueError("I don't know whether to send the sd blob or not!") + needed = response.get('needed_blobs', []) + sent_sd = False + if response['send_sd_blob']: + await sd_blob.sendfile(self) + received = await self.response_queue.get() + if received.get('received_sd_blob'): + sent_sd = True + if not needed: + for blob in self.descriptor.blobs[:-1]: + if self.blob_manager.get_blob(blob.blob_hash, blob.length).get_is_verified(): + needed.append(blob.blob_hash) + log.info("Sent reflector descriptor %s", sd_blob.blob_hash[:8]) + else: + log.warning("Reflector failed to receive descriptor %s", sd_blob.blob_hash[:8]) + return sent_sd, needed + + async def send_blob(self, blob_hash: str): + blob = self.blob_manager.get_blob(blob_hash) + assert blob.get_is_verified(), "need to have a blob to send at this point" + response = await self.send_request({ + 'blob_hash': blob.blob_hash, + 'blob_size': blob.length + }) + if 'send_blob' not in response: + raise ValueError("I don't know whether to send the blob or not!") + if response['send_blob']: + await blob.sendfile(self) + received = await self.response_queue.get() + if received.get('received_blob'): + self.reflected_blobs.append(blob.blob_hash) + log.info("Sent reflector blob %s", blob.blob_hash[:8]) + else: + log.warning("Reflector failed to receive blob %s", blob.blob_hash[:8]) diff --git a/lbrynet/stream/reflector/server.py b/lbrynet/stream/reflector/server.py new file mode 100644 index 000000000..3547f0c8a --- /dev/null +++ b/lbrynet/stream/reflector/server.py @@ -0,0 +1,147 @@ +import asyncio +import logging +import typing +import json +from json.decoder import JSONDecodeError +from lbrynet.stream.descriptor import StreamDescriptor + +if typing.TYPE_CHECKING: + from lbrynet.blob.blob_file import BlobFile + from lbrynet.blob.blob_manager import BlobFileManager + from lbrynet.blob.writer import HashBlobWriter + + +log = logging.getLogger(__name__) + + +class ReflectorServerProtocol(asyncio.Protocol): + def __init__(self, blob_manager: 'BlobFileManager'): + self.loop = asyncio.get_event_loop() + self.blob_manager = blob_manager + self.server_task: asyncio.Task = None + self.started_listening = asyncio.Event(loop=self.loop) + self.buf = b'' + self.transport: asyncio.StreamWriter = None + self.writer: typing.Optional['HashBlobWriter'] = None + self.client_version: typing.Optional[int] = None + self.descriptor: typing.Optional['StreamDescriptor'] = None + self.sd_blob: typing.Optional['BlobFile'] = None + self.received = [] + self.incoming = asyncio.Event(loop=self.loop) + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data: bytes): + if self.incoming.is_set(): + try: + self.writer.write(data) + except IOError as err: + log.error("error downloading blob: %s", err) + return + try: + request = json.loads(data.decode()) + except (ValueError, JSONDecodeError): + return + self.loop.create_task(self.handle_request(request)) + + def send_response(self, response: typing.Dict): + self.transport.write(json.dumps(response).encode()) + + async def handle_request(self, request: typing.Dict): + if self.client_version is None: + if 'version' not in request: + self.transport.close() + return + self.client_version = request['version'] + self.send_response({'version': 1}) + return + if not self.sd_blob: + if 'sd_blob_hash' not in request: + self.transport.close() + return + self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) + if not self.sd_blob.get_is_verified(): + self.writer = self.sd_blob.open_for_writing() + self.incoming.set() + self.send_response({"send_sd_blob": True}) + try: + await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop) + self.send_response({"received_sd_blob": True}) + self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( + self.loop, self.blob_manager.blob_dir, self.sd_blob + ) + self.incoming.clear() + self.writer.close_handle() + self.writer = None + except (asyncio.TimeoutError, asyncio.CancelledError): + self.send_response({"received_sd_blob": False}) + self.incoming.clear() + self.writer.close_handle() + self.writer = None + self.transport.close() + return + else: + self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( + self.loop, self.blob_manager.blob_dir, self.sd_blob + ) + self.send_response({"send_sd_blob": False, 'needed': [ + blob.blob_hash for blob in self.descriptor.blobs[:-1] + if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified() + ]}) + return + elif self.descriptor: + if 'blob_hash' not in request: + self.transport.close() + return + if request['blob_hash'] not in map(lambda b: b.blob_hash, self.descriptor.blobs[:-1]): + self.send_response({"send_blob": False}) + return + blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) + if not blob.get_is_verified(): + self.writer = blob.open_for_writing() + self.incoming.set() + self.send_response({"send_blob": True}) + try: + await asyncio.wait_for(blob.finished_writing.wait(), 30, loop=self.loop) + self.send_response({"received_blob": True}) + except (asyncio.TimeoutError, asyncio.CancelledError): + self.send_response({"received_blob": False}) + self.incoming.clear() + self.writer.close_handle() + self.writer = None + else: + self.send_response({"send_blob": False}) + return + else: + self.transport.close() + + +class ReflectorServer: + def __init__(self, blob_manager: 'BlobFileManager'): + self.loop = asyncio.get_event_loop() + self.blob_manager = blob_manager + self.server_task: asyncio.Task = None + self.started_listening = asyncio.Event(loop=self.loop) + + def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'): + if self.server_task is not None: + raise Exception("already running") + + async def _start_server(): + server = await self.loop.create_server( + lambda: ReflectorServerProtocol(self.blob_manager), + interface, port + ) + self.started_listening.set() + log.info("Reflector server listening on TCP %s:%i", interface, port) + async with server: + await server.serve_forever() + + self.server_task = self.loop.create_task(_start_server()) + + def stop_server(self): + if self.server_task: + self.server_task.cancel() + self.server_task = None + log.info("Stopped reflector server") diff --git a/lbrynet/stream/stream_manager.py b/lbrynet/stream/stream_manager.py index c5e3df141..7f584f7ab 100644 --- a/lbrynet/stream/stream_manager.py +++ b/lbrynet/stream/stream_manager.py @@ -3,6 +3,7 @@ import asyncio import typing import binascii import logging +import random from lbrynet.stream.downloader import StreamDownloader from lbrynet.stream.managed_stream import ManagedStream from lbrynet.schema.claim import ClaimDict @@ -46,7 +47,8 @@ comparison_operators = { class StreamManager: def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', wallet: 'LbryWalletManager', storage: 'SQLiteStorage', node: typing.Optional['Node'], peer_timeout: float, - peer_connect_timeout: float, fixed_peers: typing.Optional[typing.List['KademliaPeer']] = None): + peer_connect_timeout: float, fixed_peers: typing.Optional[typing.List['KademliaPeer']] = None, + reflector_servers: typing.Optional[typing.List[typing.Tuple[str, int]]] = None): self.loop = loop self.blob_manager = blob_manager self.wallet = wallet @@ -59,6 +61,7 @@ class StreamManager: self.resume_downloading_task: asyncio.Task = None self.update_stream_finished_futs: typing.List[asyncio.Future] = [] self.fixed_peers = fixed_peers + self.reflector_servers = reflector_servers async def load_streams_from_database(self): infos = await self.storage.get_all_lbry_files() @@ -93,6 +96,20 @@ class StreamManager: if resumed: log.info("resuming %i downloads", resumed) + async def reflect_streams(self): + streams = list(self.streams) + batch = [] + while streams: + stream = streams.pop() + if not stream.fully_reflected.is_set(): + host, port = random.choice(self.reflector_servers) + batch.append(stream.upload_to_reflector(host, port)) + if len(batch) >= 10: + await asyncio.gather(*batch) + batch = [] + if batch: + await asyncio.gather(*batch) + async def start(self): await self.load_streams_from_database() self.resume_downloading_task = self.loop.create_task(self.resume()) @@ -106,9 +123,13 @@ class StreamManager: while self.update_stream_finished_futs: self.update_stream_finished_futs.pop().cancel() - async def create_stream(self, file_path: str) -> ManagedStream: - stream = await ManagedStream.create(self.loop, self.blob_manager, file_path) + async def create_stream(self, file_path: str, key: typing.Optional[bytes] = None, + iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream: + stream = await ManagedStream.create(self.loop, self.blob_manager, file_path, key, iv_generator) self.streams.add(stream) + if self.reflector_servers: + host, port = random.choice(self.reflector_servers) + self.loop.create_task(stream.upload_to_reflector(host, port)) return stream async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False): diff --git a/tests/unit/stream/test_reflector.py b/tests/unit/stream/test_reflector.py new file mode 100644 index 000000000..d3e2142af --- /dev/null +++ b/tests/unit/stream/test_reflector.py @@ -0,0 +1,62 @@ +import os +import asyncio +import tempfile +import shutil +from torba.testcase import AsyncioTestCase +from lbrynet.conf import Config +from lbrynet.extras.daemon.storage import SQLiteStorage +from lbrynet.blob.blob_manager import BlobFileManager +from lbrynet.stream.stream_manager import StreamManager +from lbrynet.stream.reflector.server import ReflectorServer + + +class TestStreamAssembler(AsyncioTestCase): + async def asyncSetUp(self): + self.loop = asyncio.get_event_loop() + self.key = b'deadbeef' * 4 + self.cleartext = os.urandom(20000000) + + tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmp_dir)) + self.storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite")) + await self.storage.open() + self.blob_manager = BlobFileManager(self.loop, tmp_dir, self.storage) + self.stream_manager = StreamManager(self.loop, self.blob_manager, None, self.storage, None, 3.0, 3.0) + + server_tmp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(server_tmp_dir)) + self.server_storage = SQLiteStorage(Config(), os.path.join(server_tmp_dir, "lbrynet.sqlite")) + await self.server_storage.open() + self.server_blob_manager = BlobFileManager(self.loop, server_tmp_dir, self.server_storage) + + download_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(download_dir)) + + # create the stream + file_path = os.path.join(tmp_dir, "test_file") + with open(file_path, 'wb') as f: + f.write(self.cleartext) + + self.stream = await self.stream_manager.create_stream(file_path) + + async def test_reflect_stream(self): + reflector = ReflectorServer(self.server_blob_manager) + reflector.start_server(5566, '127.0.0.1') + await reflector.started_listening.wait() + self.addCleanup(reflector.stop_server) + sent = await self.stream.upload_to_reflector('127.0.0.1', 5566) + self.assertSetEqual( + set(sent), + set(map(lambda b: b.blob_hash, + self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)])) + ) + server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash) + self.assertTrue(server_sd_blob.get_is_verified()) + self.assertEqual(server_sd_blob.length, server_sd_blob.length) + for blob in self.stream.descriptor.blobs[:-1]: + server_blob = self.server_blob_manager.get_blob(blob.blob_hash) + self.assertTrue(server_blob.get_is_verified()) + self.assertEqual(server_blob.length, blob.length) + + sent = await self.stream.upload_to_reflector('127.0.0.1', 5566) + self.assertListEqual(sent, [])