async reflector

This commit is contained in:
Jack Robison 2019-01-25 15:05:22 -05:00
parent d0e6e74e1b
commit a7610e3d34
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
7 changed files with 377 additions and 6 deletions

View file

@ -466,7 +466,7 @@ class StreamManagerComponent(Component):
self.conf.peer_connect_timeout, [ self.conf.peer_connect_timeout, [
KademliaPeer(loop, address=(await resolve_host(loop, url)), tcp_port=port + 1) KademliaPeer(loop, address=(await resolve_host(loop, url)), tcp_port=port + 1)
for url, port in self.conf.reflector_servers for url, port in self.conf.reflector_servers
] ], self.conf.reflector_servers
) )
await self.stream_manager.start() await self.stream_manager.start()
log.info('Done setting up file manager') log.info('Done setting up file manager')

View file

@ -5,6 +5,7 @@ import logging
from lbrynet.extras.daemon.mime_types import guess_media_type from lbrynet.extras.daemon.mime_types import guess_media_type
from lbrynet.stream.downloader import StreamDownloader from lbrynet.stream.downloader import StreamDownloader
from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.descriptor import StreamDescriptor
from lbrynet.stream.reflector.client import StreamReflectorClient
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbrynet.extras.daemon.storage import StoredStreamClaim from lbrynet.extras.daemon.storage import StoredStreamClaim
from lbrynet.blob.blob_manager import BlobFileManager from lbrynet.blob.blob_manager import BlobFileManager
@ -30,6 +31,7 @@ class ManagedStream:
self.stream_claim_info = claim self.stream_claim_info = claim
self._status = status self._status = status
self._store_after_finished: asyncio.Task = None self._store_after_finished: asyncio.Task = None
self.fully_reflected = asyncio.Event(loop=self.loop)
@property @property
def status(self) -> str: def status(self) -> str:
@ -139,9 +141,10 @@ class ManagedStream:
@classmethod @classmethod
async def create(cls, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', async def create(cls, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager',
file_path: str) -> 'ManagedStream': file_path: str, key: typing.Optional[bytes] = None,
iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> 'ManagedStream':
descriptor = await StreamDescriptor.create_stream( descriptor = await StreamDescriptor.create_stream(
loop, blob_manager.blob_dir, file_path loop, blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator
) )
sd_blob = blob_manager.get_blob(descriptor.sd_hash) sd_blob = blob_manager.get_blob(descriptor.sd_hash)
await blob_manager.blob_completed(sd_blob) await blob_manager.blob_completed(sd_blob)
@ -156,3 +159,38 @@ class ManagedStream:
await self.downloader.stop() await self.downloader.stop()
if not self.finished: if not self.finished:
self.update_status(self.STATUS_STOPPED) self.update_status(self.STATUS_STOPPED)
async def upload_to_reflector(self, host: str, port: int) -> typing.List[str]:
sent = []
protocol = StreamReflectorClient(self.blob_manager, self.descriptor)
try:
await self.loop.create_connection(lambda: protocol, host, port)
except ConnectionRefusedError:
return sent
try:
await protocol.send_handshake()
except (asyncio.CancelledError, asyncio.TimeoutError, ValueError):
if protocol.transport:
protocol.transport.close()
return sent
try:
sent_sd, needed = await protocol.send_descriptor()
if sent_sd:
sent.append(self.sd_hash)
except (asyncio.CancelledError, asyncio.TimeoutError, ValueError):
if protocol.transport:
protocol.transport.close()
return sent
for blob_hash in needed:
try:
await protocol.send_blob(blob_hash)
sent.append(blob_hash)
except (asyncio.CancelledError, asyncio.TimeoutError, ValueError):
if protocol.transport:
protocol.transport.close()
return sent
if protocol.transport:
protocol.transport.close()
if not self.fully_reflected.is_set():
self.fully_reflected.set()
return sent

View file

View file

@ -0,0 +1,103 @@
import asyncio
import json
import logging
import typing
if typing.TYPE_CHECKING:
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.stream.descriptor import StreamDescriptor
REFLECTOR_V1 = 0
REFLECTOR_V2 = 1
log = logging.getLogger(__name__)
class StreamReflectorClient(asyncio.Protocol):
def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'):
self.transport: asyncio.StreamWriter = None
self.blob_manager = blob_manager
self.descriptor = descriptor
self.response_buff = b''
self.reflected_blobs = []
self.connected = asyncio.Event()
self.response_queue = asyncio.Queue(maxsize=1)
self.pending_request: typing.Optional[asyncio.Task] = None
def connection_made(self, transport):
self.transport = transport
log.info("Connected to reflector")
self.connected.set()
def connection_lost(self, exc: typing.Optional[Exception]):
self.transport = None
self.connected.clear()
def data_received(self, data):
try:
response = json.loads(data.decode())
self.response_queue.put_nowait(response)
except ValueError:
self.transport.close()
return
async def send_request(self, request_dict: typing.Dict):
msg = json.dumps(request_dict)
self.transport.write(msg.encode())
try:
self.pending_request = asyncio.get_event_loop().create_task(self.response_queue.get())
return await self.pending_request
finally:
self.pending_request = None
async def send_handshake(self) -> None:
response_dict = await self.send_request({'version': REFLECTOR_V2})
if 'version' not in response_dict:
raise ValueError("Need protocol version number!")
server_version = int(response_dict['version'])
if server_version != REFLECTOR_V2:
raise ValueError("I can't handle protocol version {}!".format(server_version))
return
async def send_descriptor(self) -> typing.Tuple[bool, typing.List[str]]: # returns a list of needed blob hashes
sd_blob = self.blob_manager.get_blob(self.descriptor.sd_hash)
assert sd_blob.get_is_verified(), "need to have a sd blob to send at this point"
response = await self.send_request({
'sd_blob_hash': sd_blob.blob_hash,
'sd_blob_size': sd_blob.length
})
if 'send_sd_blob' not in response:
raise ValueError("I don't know whether to send the sd blob or not!")
needed = response.get('needed_blobs', [])
sent_sd = False
if response['send_sd_blob']:
await sd_blob.sendfile(self)
received = await self.response_queue.get()
if received.get('received_sd_blob'):
sent_sd = True
if not needed:
for blob in self.descriptor.blobs[:-1]:
if self.blob_manager.get_blob(blob.blob_hash, blob.length).get_is_verified():
needed.append(blob.blob_hash)
log.info("Sent reflector descriptor %s", sd_blob.blob_hash[:8])
else:
log.warning("Reflector failed to receive descriptor %s", sd_blob.blob_hash[:8])
return sent_sd, needed
async def send_blob(self, blob_hash: str):
blob = self.blob_manager.get_blob(blob_hash)
assert blob.get_is_verified(), "need to have a blob to send at this point"
response = await self.send_request({
'blob_hash': blob.blob_hash,
'blob_size': blob.length
})
if 'send_blob' not in response:
raise ValueError("I don't know whether to send the blob or not!")
if response['send_blob']:
await blob.sendfile(self)
received = await self.response_queue.get()
if received.get('received_blob'):
self.reflected_blobs.append(blob.blob_hash)
log.info("Sent reflector blob %s", blob.blob_hash[:8])
else:
log.warning("Reflector failed to receive blob %s", blob.blob_hash[:8])

View file

@ -0,0 +1,147 @@
import asyncio
import logging
import typing
import json
from json.decoder import JSONDecodeError
from lbrynet.stream.descriptor import StreamDescriptor
if typing.TYPE_CHECKING:
from lbrynet.blob.blob_file import BlobFile
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.blob.writer import HashBlobWriter
log = logging.getLogger(__name__)
class ReflectorServerProtocol(asyncio.Protocol):
def __init__(self, blob_manager: 'BlobFileManager'):
self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager
self.server_task: asyncio.Task = None
self.started_listening = asyncio.Event(loop=self.loop)
self.buf = b''
self.transport: asyncio.StreamWriter = None
self.writer: typing.Optional['HashBlobWriter'] = None
self.client_version: typing.Optional[int] = None
self.descriptor: typing.Optional['StreamDescriptor'] = None
self.sd_blob: typing.Optional['BlobFile'] = None
self.received = []
self.incoming = asyncio.Event(loop=self.loop)
def connection_made(self, transport):
self.transport = transport
def data_received(self, data: bytes):
if self.incoming.is_set():
try:
self.writer.write(data)
except IOError as err:
log.error("error downloading blob: %s", err)
return
try:
request = json.loads(data.decode())
except (ValueError, JSONDecodeError):
return
self.loop.create_task(self.handle_request(request))
def send_response(self, response: typing.Dict):
self.transport.write(json.dumps(response).encode())
async def handle_request(self, request: typing.Dict):
if self.client_version is None:
if 'version' not in request:
self.transport.close()
return
self.client_version = request['version']
self.send_response({'version': 1})
return
if not self.sd_blob:
if 'sd_blob_hash' not in request:
self.transport.close()
return
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
if not self.sd_blob.get_is_verified():
self.writer = self.sd_blob.open_for_writing()
self.incoming.set()
self.send_response({"send_sd_blob": True})
try:
await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop)
self.send_response({"received_sd_blob": True})
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
self.loop, self.blob_manager.blob_dir, self.sd_blob
)
self.incoming.clear()
self.writer.close_handle()
self.writer = None
except (asyncio.TimeoutError, asyncio.CancelledError):
self.send_response({"received_sd_blob": False})
self.incoming.clear()
self.writer.close_handle()
self.writer = None
self.transport.close()
return
else:
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
self.loop, self.blob_manager.blob_dir, self.sd_blob
)
self.send_response({"send_sd_blob": False, 'needed': [
blob.blob_hash for blob in self.descriptor.blobs[:-1]
if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()
]})
return
elif self.descriptor:
if 'blob_hash' not in request:
self.transport.close()
return
if request['blob_hash'] not in map(lambda b: b.blob_hash, self.descriptor.blobs[:-1]):
self.send_response({"send_blob": False})
return
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
if not blob.get_is_verified():
self.writer = blob.open_for_writing()
self.incoming.set()
self.send_response({"send_blob": True})
try:
await asyncio.wait_for(blob.finished_writing.wait(), 30, loop=self.loop)
self.send_response({"received_blob": True})
except (asyncio.TimeoutError, asyncio.CancelledError):
self.send_response({"received_blob": False})
self.incoming.clear()
self.writer.close_handle()
self.writer = None
else:
self.send_response({"send_blob": False})
return
else:
self.transport.close()
class ReflectorServer:
def __init__(self, blob_manager: 'BlobFileManager'):
self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager
self.server_task: asyncio.Task = None
self.started_listening = asyncio.Event(loop=self.loop)
def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'):
if self.server_task is not None:
raise Exception("already running")
async def _start_server():
server = await self.loop.create_server(
lambda: ReflectorServerProtocol(self.blob_manager),
interface, port
)
self.started_listening.set()
log.info("Reflector server listening on TCP %s:%i", interface, port)
async with server:
await server.serve_forever()
self.server_task = self.loop.create_task(_start_server())
def stop_server(self):
if self.server_task:
self.server_task.cancel()
self.server_task = None
log.info("Stopped reflector server")

View file

@ -3,6 +3,7 @@ import asyncio
import typing import typing
import binascii import binascii
import logging import logging
import random
from lbrynet.stream.downloader import StreamDownloader from lbrynet.stream.downloader import StreamDownloader
from lbrynet.stream.managed_stream import ManagedStream from lbrynet.stream.managed_stream import ManagedStream
from lbrynet.schema.claim import ClaimDict from lbrynet.schema.claim import ClaimDict
@ -46,7 +47,8 @@ comparison_operators = {
class StreamManager: class StreamManager:
def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', wallet: 'LbryWalletManager', def __init__(self, loop: asyncio.BaseEventLoop, blob_manager: 'BlobFileManager', wallet: 'LbryWalletManager',
storage: 'SQLiteStorage', node: typing.Optional['Node'], peer_timeout: float, storage: 'SQLiteStorage', node: typing.Optional['Node'], peer_timeout: float,
peer_connect_timeout: float, fixed_peers: typing.Optional[typing.List['KademliaPeer']] = None): peer_connect_timeout: float, fixed_peers: typing.Optional[typing.List['KademliaPeer']] = None,
reflector_servers: typing.Optional[typing.List[typing.Tuple[str, int]]] = None):
self.loop = loop self.loop = loop
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.wallet = wallet self.wallet = wallet
@ -59,6 +61,7 @@ class StreamManager:
self.resume_downloading_task: asyncio.Task = None self.resume_downloading_task: asyncio.Task = None
self.update_stream_finished_futs: typing.List[asyncio.Future] = [] self.update_stream_finished_futs: typing.List[asyncio.Future] = []
self.fixed_peers = fixed_peers self.fixed_peers = fixed_peers
self.reflector_servers = reflector_servers
async def load_streams_from_database(self): async def load_streams_from_database(self):
infos = await self.storage.get_all_lbry_files() infos = await self.storage.get_all_lbry_files()
@ -93,6 +96,20 @@ class StreamManager:
if resumed: if resumed:
log.info("resuming %i downloads", resumed) log.info("resuming %i downloads", resumed)
async def reflect_streams(self):
streams = list(self.streams)
batch = []
while streams:
stream = streams.pop()
if not stream.fully_reflected.is_set():
host, port = random.choice(self.reflector_servers)
batch.append(stream.upload_to_reflector(host, port))
if len(batch) >= 10:
await asyncio.gather(*batch)
batch = []
if batch:
await asyncio.gather(*batch)
async def start(self): async def start(self):
await self.load_streams_from_database() await self.load_streams_from_database()
self.resume_downloading_task = self.loop.create_task(self.resume()) self.resume_downloading_task = self.loop.create_task(self.resume())
@ -106,9 +123,13 @@ class StreamManager:
while self.update_stream_finished_futs: while self.update_stream_finished_futs:
self.update_stream_finished_futs.pop().cancel() self.update_stream_finished_futs.pop().cancel()
async def create_stream(self, file_path: str) -> ManagedStream: async def create_stream(self, file_path: str, key: typing.Optional[bytes] = None,
stream = await ManagedStream.create(self.loop, self.blob_manager, file_path) iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream:
stream = await ManagedStream.create(self.loop, self.blob_manager, file_path, key, iv_generator)
self.streams.add(stream) self.streams.add(stream)
if self.reflector_servers:
host, port = random.choice(self.reflector_servers)
self.loop.create_task(stream.upload_to_reflector(host, port))
return stream return stream
async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False): async def delete_stream(self, stream: ManagedStream, delete_file: typing.Optional[bool] = False):

View file

@ -0,0 +1,62 @@
import os
import asyncio
import tempfile
import shutil
from torba.testcase import AsyncioTestCase
from lbrynet.conf import Config
from lbrynet.extras.daemon.storage import SQLiteStorage
from lbrynet.blob.blob_manager import BlobFileManager
from lbrynet.stream.stream_manager import StreamManager
from lbrynet.stream.reflector.server import ReflectorServer
class TestStreamAssembler(AsyncioTestCase):
async def asyncSetUp(self):
self.loop = asyncio.get_event_loop()
self.key = b'deadbeef' * 4
self.cleartext = os.urandom(20000000)
tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmp_dir))
self.storage = SQLiteStorage(Config(), os.path.join(tmp_dir, "lbrynet.sqlite"))
await self.storage.open()
self.blob_manager = BlobFileManager(self.loop, tmp_dir, self.storage)
self.stream_manager = StreamManager(self.loop, self.blob_manager, None, self.storage, None, 3.0, 3.0)
server_tmp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(server_tmp_dir))
self.server_storage = SQLiteStorage(Config(), os.path.join(server_tmp_dir, "lbrynet.sqlite"))
await self.server_storage.open()
self.server_blob_manager = BlobFileManager(self.loop, server_tmp_dir, self.server_storage)
download_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(download_dir))
# create the stream
file_path = os.path.join(tmp_dir, "test_file")
with open(file_path, 'wb') as f:
f.write(self.cleartext)
self.stream = await self.stream_manager.create_stream(file_path)
async def test_reflect_stream(self):
reflector = ReflectorServer(self.server_blob_manager)
reflector.start_server(5566, '127.0.0.1')
await reflector.started_listening.wait()
self.addCleanup(reflector.stop_server)
sent = await self.stream.upload_to_reflector('127.0.0.1', 5566)
self.assertSetEqual(
set(sent),
set(map(lambda b: b.blob_hash,
self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)]))
)
server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash)
self.assertTrue(server_sd_blob.get_is_verified())
self.assertEqual(server_sd_blob.length, server_sd_blob.length)
for blob in self.stream.descriptor.blobs[:-1]:
server_blob = self.server_blob_manager.get_blob(blob.blob_hash)
self.assertTrue(server_blob.get_is_verified())
self.assertEqual(server_blob.length, blob.length)
sent = await self.stream.upload_to_reflector('127.0.0.1', 5566)
self.assertListEqual(sent, [])