async reflector
This commit is contained in:
parent
d0e6e74e1b
commit
a7610e3d34
7 changed files with 377 additions and 6 deletions
|
@ -466,7 +466,7 @@ class StreamManagerComponent(Component):
|
|||
self.conf.peer_connect_timeout, [
|
||||
KademliaPeer(loop, address=(await resolve_host(loop, url)), tcp_port=port + 1)
|
||||
for url, port in self.conf.reflector_servers
|
||||
]
|
||||
], self.conf.reflector_servers
|
||||
)
|
||||
await self.stream_manager.start()
|
||||
log.info('Done setting up file manager')
|
||||
|
|
|
@ -5,6 +5,7 @@ import logging
|
|||
from lbrynet.extras.daemon.mime_types import guess_media_type
|
||||
from lbrynet.stream.downloader import StreamDownloader
|
||||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
from lbrynet.stream.reflector.client import StreamReflectorClient
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.extras.daemon.storage import StoredStreamClaim
|
||||
from lbrynet.blob.blob_manager import BlobFileManager
|
||||
|
@ -30,6 +31,7 @@ class ManagedStream:
|
|||
self.stream_claim_info = claim
|
||||
self._status = status
|
||||
self._store_after_finished: asyncio.Task = None
|
||||
self.fully_reflected = asyncio.Event(loop=self.loop)
|
||||
|
||||
@property
|
||||
def status(self) -> str:
|
||||
|
@ -139,9 +141,10 @@ class ManagedStream:
|
|||
|
||||
@classmethod
|
||||
async def create(cls, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager',
|
||||
file_path: str) -> 'ManagedStream':
|
||||
file_path: str, key: typing.Optional[bytes] = None,
|
||||
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> 'ManagedStream':
|
||||
descriptor = await StreamDescriptor.create_stream(
|
||||
loop, blob_manager.blob_dir, file_path
|
||||
loop, blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator
|
||||
)
|
||||
sd_blob = blob_manager.get_blob(descriptor.sd_hash)
|
||||
await blob_manager.blob_completed(sd_blob)
|
||||
|
@ -156,3 +159,38 @@ class ManagedStream:
|
|||
await self.downloader.stop()
|
||||
if not self.finished:
|
||||
self.update_status(self.STATUS_STOPPED)
|
||||
|
||||
async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
|
||||
sent = []
|
||||
protocol = StreamReflectorClient(self.blob_manager, self.descriptor)
|
||||
try:
|
||||
await self.loop.create_connection(lambda: protocol, host, port)
|
||||
except ConnectionRefusedError:
|
||||
return sent
|
||||
try:
|
||||
await protocol.send_handshake()
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError, ValueError):
|
||||
if protocol.transport:
|
||||
protocol.transport.close()
|
||||
return sent
|
||||
try:
|
||||
sent_sd, needed = await protocol.send_descriptor()
|
||||
if sent_sd:
|
||||
sent.append(self.sd_hash)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError, ValueError):
|
||||
if protocol.transport:
|
||||
protocol.transport.close()
|
||||
return sent
|
||||
for blob_hash in needed:
|
||||
try:
|
||||
await protocol.send_blob(blob_hash)
|
||||
sent.append(blob_hash)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError, ValueError):
|
||||
if protocol.transport:
|
||||
protocol.transport.close()
|
||||
return sent
|
||||
if protocol.transport:
|
||||
protocol.transport.close()
|
||||
if not self.fully_reflected.is_set():
|
||||
self.fully_reflected.set()
|
||||
return sent
|
||||
|
|
0
lbrynet/stream/reflector/__init__.py
Normal file
0
lbrynet/stream/reflector/__init__.py
Normal file
103
lbrynet/stream/reflector/client.py
Normal file
103
lbrynet/stream/reflector/client.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.blob.blob_manager import BlobFileManager
|
||||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
|
||||
REFLECTOR_V1 = 0
|
||||
REFLECTOR_V2 = 1
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamReflectorClient(asyncio.Protocol):
|
||||
def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'):
|
||||
self.transport: asyncio.StreamWriter = None
|
||||
self.blob_manager = blob_manager
|
||||
self.descriptor = descriptor
|
||||
self.response_buff = b''
|
||||
self.reflected_blobs = []
|
||||
self.connected = asyncio.Event()
|
||||
self.response_queue = asyncio.Queue(maxsize=1)
|
||||
self.pending_request: typing.Optional[asyncio.Task] = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
log.info("Connected to reflector")
|
||||
self.connected.set()
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]):
|
||||
self.transport = None
|
||||
self.connected.clear()
|
||||
|
||||
def data_received(self, data):
|
||||
try:
|
||||
response = json.loads(data.decode())
|
||||
self.response_queue.put_nowait(response)
|
||||
except ValueError:
|
||||
self.transport.close()
|
||||
return
|
||||
|
||||
async def send_request(self, request_dict: typing.Dict):
|
||||
msg = json.dumps(request_dict)
|
||||
self.transport.write(msg.encode())
|
||||
try:
|
||||
self.pending_request = asyncio.get_event_loop().create_task(self.response_queue.get())
|
||||
return await self.pending_request
|
||||
finally:
|
||||
self.pending_request = None
|
||||
|
||||
async def send_handshake(self) -> None:
|
||||
response_dict = await self.send_request({'version': REFLECTOR_V2})
|
||||
if 'version' not in response_dict:
|
||||
raise ValueError("Need protocol version number!")
|
||||
server_version = int(response_dict['version'])
|
||||
if server_version != REFLECTOR_V2:
|
||||
raise ValueError("I can't handle protocol version {}!".format(server_version))
|
||||
return
|
||||
|
||||
async def send_descriptor(self) -> typing.Tuple[bool, typing.List[str]]: # returns a list of needed blob hashes
|
||||
sd_blob = self.blob_manager.get_blob(self.descriptor.sd_hash)
|
||||
assert sd_blob.get_is_verified(), "need to have a sd blob to send at this point"
|
||||
response = await self.send_request({
|
||||
'sd_blob_hash': sd_blob.blob_hash,
|
||||
'sd_blob_size': sd_blob.length
|
||||
})
|
||||
if 'send_sd_blob' not in response:
|
||||
raise ValueError("I don't know whether to send the sd blob or not!")
|
||||
needed = response.get('needed_blobs', [])
|
||||
sent_sd = False
|
||||
if response['send_sd_blob']:
|
||||
await sd_blob.sendfile(self)
|
||||
received = await self.response_queue.get()
|
||||
if received.get('received_sd_blob'):
|
||||
sent_sd = True
|
||||
if not needed:
|
||||
for blob in self.descriptor.blobs[:-1]:
|
||||
if self.blob_manager.get_blob(blob.blob_hash, blob.length).get_is_verified():
|
||||
needed.append(blob.blob_hash)
|
||||
log.info("Sent reflector descriptor %s", sd_blob.blob_hash[:8])
|
||||
else:
|
||||
log.warning("Reflector failed to receive descriptor %s", sd_blob.blob_hash[:8])
|
||||
return sent_sd, needed
|
||||
|
||||
async def send_blob(self, blob_hash: str):
|
||||
blob = self.blob_manager.get_blob(blob_hash)
|
||||
assert blob.get_is_verified(), "need to have a blob to send at this point"
|
||||
response = await self.send_request({
|
||||
'blob_hash': blob.blob_hash,
|
||||
'blob_size': blob.length
|
||||
})
|
||||
if 'send_blob' not in response:
|
||||
raise ValueError("I don't know whether to send the blob or not!")
|
||||
if response['send_blob']:
|
||||
await blob.sendfile(self)
|
||||
received = await self.response_queue.get()
|
||||
if received.get('received_blob'):
|
||||
self.reflected_blobs.append(blob.blob_hash)
|
||||
log.info("Sent reflector blob %s", blob.blob_hash[:8])
|
||||
else:
|
||||
log.warning("Reflector failed to receive blob %s", blob.blob_hash[:8])
|
147
lbrynet/stream/reflector/server.py
Normal file
147
lbrynet/stream/reflector/server.py
Normal file
|
@ -0,0 +1,147 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
import json
|
||||
from json.decoder import JSONDecodeError
|
||||
from lbrynet.stream.descriptor import StreamDescriptor
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from lbrynet.blob.blob_file import BlobFile
|
||||
from lbrynet.blob.blob_manager import BlobFileManager
|
||||
from lbrynet.blob.writer import HashBlobWriter
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReflectorServerProtocol(asyncio.Protocol):
|
||||
def __init__(self, blob_manager: 'BlobFileManager'):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.blob_manager = blob_manager
|
||||
self.server_task: asyncio.Task = None
|
||||
self.started_listening = asyncio.Event(loop=self.loop)
|
||||
self.buf = b''
|
||||
self.transport: asyncio.StreamWriter = None
|
||||
self.writer: typing.Optional['HashBlobWriter'] = None
|
||||
self.client_version: typing.Optional[int] = None
|
||||
self.descriptor: typing.Optional['StreamDescriptor'] = None
|
||||
self.sd_blob: typing.Optional['BlobFile'] = None
|
||||
self.received = []
|
||||
self.incoming = asyncio.Event(loop=self.loop)
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
if self.incoming.is_set():
|
||||
try:
|
||||
self.writer.write(data)
|
||||
except IOError as err:
|
||||
log.error("error downloading blob: %s", err)
|
||||
return
|
||||
try:
|
||||
request = json.loads(data.decode())
|
||||
except (ValueError, JSONDecodeError):
|
||||
return
|
||||
self.loop.create_task(self.handle_request(request))
|
||||
|
||||
def send_response(self, response: typing.Dict):
|
||||
self.transport.write(json.dumps(response).encode())
|
||||
|
||||
async def handle_request(self, request: typing.Dict):
|
||||
if self.client_version is None:
|
||||
if 'version' not in request:
|
||||
self.transport.close()
|
||||
return
|
||||
self.client_version = request['version']
|
||||
self.send_response({'version': 1})
|
||||
return
|
||||
if not self.sd_blob:
|
||||
if 'sd_blob_hash' not in request:
|
||||
self.transport.close()
|
||||
return
|
||||
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
||||
if not self.sd_blob.get_is_verified():
|
||||
self.writer = self.sd_blob.open_for_writing()
|
||||
self.incoming.set()
|
||||
self.send_response({"send_sd_blob": True})
|
||||
try:
|
||||
await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop)
|
||||
self.send_response({"received_sd_blob": True})
|
||||
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
||||
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
||||
)
|
||||
self.incoming.clear()
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
self.send_response({"received_sd_blob": False})
|
||||
self.incoming.clear()
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
self.transport.close()
|
||||
return
|
||||
else:
|
||||
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
||||
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
||||
)
|
||||
self.send_response({"send_sd_blob": False, 'needed': [
|
||||
blob.blob_hash for blob in self.descriptor.blobs[:-1]
|
||||
if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()
|
||||
]})
|
||||
return
|
||||
elif self.descriptor:
|
||||
if 'blob_hash' not in request:
|
||||
self.transport.close()
|
||||
return
|
||||
if request['blob_hash'] not in map(lambda b: b.blob_hash, self.descriptor.blobs[:-1]):
|
||||
self.send_response({"send_blob": False})
|
||||
return
|
||||
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
||||
if not blob.get_is_verified():
|
||||
self.writer = blob.open_for_writing()
|
||||
self.incoming.set()
|
||||
self.send_response({"send_blob": True})
|
||||
try:
|
||||
await asyncio.wait_for(blob.finished_writing.wait(), 30, loop=self.loop)
|
||||
self.send_response({"received_blob": True})
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
self.send_response({"received_blob": False})
|
||||
self.incoming.clear()
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
else:
|
||||
self.send_response({"send_blob": False})
|
||||
return
|
||||
else:
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class ReflectorServer:
|
||||
def __init__(self, blob_manager: 'BlobFileManager'):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.blob_manager = blob_manager
|
||||
self.server_task: asyncio.Task = None
|
||||
self.started_listening = asyncio.Event(loop=self.loop)
|
||||
|
||||
def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
|
||||
if self.server_task is not None:
|
||||
raise Exception("already running")
|
||||
|
||||
async def _start_server():
|
||||
server = await self.loop.create_server(
|
||||
lambda: ReflectorServerProtocol(self.blob_manager),
|
||||
interface, port
|
||||
)
|
||||
self.started_listening.set()
|
||||
log.info("Reflector server listening on TCP %s:%i", interface, port)
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
self.server_task = self.loop.create_task(_start_server())
|
||||
|
||||
def stop_server(self):
|
||||
if self.server_task:
|
||||
self.server_task.cancel()
|
||||
self.server_task = None
|
||||
log.info("Stopped reflector server")
|
|
@ -3,6 +3,7 @@ import asyncio
|
|||
import typing
|
||||
import binascii
|
||||
import logging
|
||||
import random
|
||||
from lbrynet.stream.downloader import StreamDownloader
|
||||
from lbrynet.stream.managed_stream import ManagedStream
|
||||
from lbrynet.schema.claim import ClaimDict
|
||||
|
@ -46,7 +47,8 @@ comparison_operators = {
|
|||
class StreamManager:
|
||||
def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', wallet: 'LbryWalletManager',
|
||||
storage: 'SQLiteStorage', node: typing.Optional['Node'], peer_timeout: float,
|
||||
peer_connect_timeout: float, fixed_peers: typing.Optional[typing.List['KademliaPeer']] = None):
|
||||
peer_connect_timeout: float, fixed_peers: typing.Optional[typing.List['KademliaPeer']] = None,
|
||||
reflector_servers: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
|
||||
self.loop = loop
|
||||
self.blob_manager = blob_manager
|
||||
self.wallet = wallet
|
||||
|
@ -59,6 +61,7 @@ class StreamManager:
|
|||
self.resume_downloading_task: asyncio.Task = None
|
||||
self.update_stream_finished_futs: typing.List[asyncio.Future] = []
|
||||
self.fixed_peers = fixed_peers
|
||||
self.reflector_servers = reflector_servers
|
||||
|
||||
async def load_streams_from_database(self):
|
||||
infos = await self.storage.get_all_lbry_files()
|
||||
|
@ -93,6 +96,20 @@ class StreamManager:
|
|||
if resumed:
|
||||
log.info("resuming %i downloads", resumed)
|
||||
|
||||
async def reflect_streams(self):
|
||||
streams = list(self.streams)
|
||||
batch = []
|
||||
while streams:
|
||||
stream = streams.pop()
|
||||
if not stream.fully_reflected.is_set():
|
||||
host, port = random.choice(self.reflector_servers)
|
||||
batch.append(stream.upload_to_reflector(host, port))
|
||||
if len(batch) >= 10:
|
||||
await asyncio.gather(*batch)
|
||||
batch = []
|
||||
if batch:
|
||||
await asyncio.gather(*batch)
|
||||
|
||||
async def start(self):
|
||||
await self.load_streams_from_database()
|
||||
self.resume_downloading_task = self.loop.create_task(self.resume())
|
||||
|
@ -106,9 +123,13 @@ class StreamManager:
|
|||
while self.update_stream_finished_futs:
|
||||
self.update_stream_finished_futs.pop().cancel()
|
||||
|
||||
async def create_stream(self, file_path: str) -> ManagedStream:
|
||||
stream = await ManagedStream.create(self.loop, self.blob_manager, file_path)
|
||||
async def create_stream(self, file_path: str, key: typing.Optional[bytes] = None,
|
||||
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream:
|
||||
stream = await ManagedStream.create(self.loop, self.blob_manager, file_path, key, iv_generator)
|
||||
self.streams.add(stream)
|
||||
if self.reflector_servers:
|
||||
host, port = random.choice(self.reflector_servers)
|
||||
self.loop.create_task(stream.upload_to_reflector(host, port))
|
||||
return stream
|
||||
|
||||
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):
|
||||
|
|
62
tests/unit/stream/test_reflector.py
Normal file
62
tests/unit/stream/test_reflector.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
import os
|
||||
import asyncio
|
||||
import tempfile
|
||||
import shutil
|
||||
from torba.testcase import AsyncioTestCase
|
||||
from lbrynet.conf import Config
|
||||
from lbrynet.extras.daemon.storage import SQLiteStorage
|
||||
from lbrynet.blob.blob_manager import BlobFileManager
|
||||
from lbrynet.stream.stream_manager import StreamManager
|
||||
from lbrynet.stream.reflector.server import ReflectorServer
|
||||
|
||||
|
||||
class TestStreamAssembler(AsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.key = b'deadbeef' * 4
|
||||
self.cleartext = os.urandom(20000000)
|
||||
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
|
||||
self.storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite"))
|
||||
await self.storage.open()
|
||||
self.blob_manager = BlobFileManager(self.loop, tmp_dir, self.storage)
|
||||
self.stream_manager = StreamManager(self.loop, self.blob_manager, None, self.storage, None, 3.0, 3.0)
|
||||
|
||||
server_tmp_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(server_tmp_dir))
|
||||
self.server_storage = SQLiteStorage(Config(), os.path.join(server_tmp_dir, "lbrynet.sqlite"))
|
||||
await self.server_storage.open()
|
||||
self.server_blob_manager = BlobFileManager(self.loop, server_tmp_dir, self.server_storage)
|
||||
|
||||
download_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(download_dir))
|
||||
|
||||
# create the stream
|
||||
file_path = os.path.join(tmp_dir, "test_file")
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(self.cleartext)
|
||||
|
||||
self.stream = await self.stream_manager.create_stream(file_path)
|
||||
|
||||
async def test_reflect_stream(self):
|
||||
reflector = ReflectorServer(self.server_blob_manager)
|
||||
reflector.start_server(5566, '127.0.0.1')
|
||||
await reflector.started_listening.wait()
|
||||
self.addCleanup(reflector.stop_server)
|
||||
sent = await self.stream.upload_to_reflector('127.0.0.1', 5566)
|
||||
self.assertSetEqual(
|
||||
set(sent),
|
||||
set(map(lambda b: b.blob_hash,
|
||||
self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)]))
|
||||
)
|
||||
server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash)
|
||||
self.assertTrue(server_sd_blob.get_is_verified())
|
||||
self.assertEqual(server_sd_blob.length, server_sd_blob.length)
|
||||
for blob in self.stream.descriptor.blobs[:-1]:
|
||||
server_blob = self.server_blob_manager.get_blob(blob.blob_hash)
|
||||
self.assertTrue(server_blob.get_is_verified())
|
||||
self.assertEqual(server_blob.length, blob.length)
|
||||
|
||||
sent = await self.stream.upload_to_reflector('127.0.0.1', 5566)
|
||||
self.assertListEqual(sent, [])
|
Loading…
Reference in a new issue