refactor BlobFile into AbstractBlob, BlobFile, and BlobBuffer classes

This commit is contained in:
Jack Robison 2019-03-30 20:17:42 -04:00
parent d44a79ada2
commit 676f0015aa
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
7 changed files with 290 additions and 174 deletions

View file

@ -4,6 +4,8 @@ import asyncio
import binascii import binascii
import logging import logging
import typing import typing
import contextlib
from io import BytesIO
from cryptography.hazmat.primitives.ciphers import Cipher, modes from cryptography.hazmat.primitives.ciphers import Cipher, modes
from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.padding import PKCS7 from cryptography.hazmat.primitives.padding import PKCS7
@ -21,10 +23,6 @@ log = logging.getLogger(__name__)
_hexmatch = re.compile("^[a-f,0-9]+$") _hexmatch = re.compile("^[a-f,0-9]+$")
def is_valid_hashcharacter(char: str) -> bool:
return len(char) == 1 and _hexmatch.match(char)
def is_valid_blobhash(blobhash: str) -> bool: def is_valid_blobhash(blobhash: str) -> bool:
"""Checks whether the blobhash is the correct length and contains only """Checks whether the blobhash is the correct length and contains only
valid characters (0-9, a-f) valid characters (0-9, a-f)
@ -46,143 +44,51 @@ def encrypt_blob_bytes(key: bytes, iv: bytes, unencrypted: bytes) -> typing.Tupl
return encrypted, digest.hexdigest() return encrypted, digest.hexdigest()
class BlobFile: def decrypt_blob_bytes(read_handle: typing.BinaryIO, length: int, key: bytes, iv: bytes) -> bytes:
""" buff = read_handle.read()
A chunk of data available on the network which is specified by a hashsum if len(buff) != length:
This class is used to create blobs on the local filesystem
when we already know the blob hash before hand (i.e., when downloading blobs)
Also can be used for reading from blobs on the local filesystem
"""
def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, blob_hash: str,
length: typing.Optional[int] = None,
blob_completed_callback: typing.Optional[typing.Callable[['BlobFile'], typing.Awaitable]] = None):
if not is_valid_blobhash(blob_hash):
raise InvalidBlobHashError(blob_hash)
self.loop = loop
self.blob_hash = blob_hash
self.length = length
self.blob_dir = blob_dir
self.file_path = os.path.join(blob_dir, self.blob_hash)
self.writers: typing.List[HashBlobWriter] = []
self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
self.finished_writing = asyncio.Event(loop=loop)
self.blob_write_lock = asyncio.Lock(loop=loop)
if self.file_exists:
length = int(os.stat(os.path.join(blob_dir, blob_hash)).st_size)
self.length = length
self.verified.set()
self.finished_writing.set()
self.saved_verified_blob = False
self.blob_completed_callback = blob_completed_callback
@property
def file_exists(self):
return os.path.isfile(self.file_path)
def writer_finished(self, writer: HashBlobWriter):
def callback(finished: asyncio.Future):
try:
error = finished.exception()
except Exception as err:
error = err
if writer in self.writers: # remove this download attempt
self.writers.remove(writer)
if not error: # the blob downloaded, cancel all the other download attempts and set the result
while self.writers:
other = self.writers.pop()
other.finished.cancel()
t = self.loop.create_task(self.save_verified_blob(writer, finished.result()))
t.add_done_callback(lambda *_: self.finished_writing.set())
return
if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)):
log.exception("something else")
raise error
return callback
async def save_verified_blob(self, writer, verified_bytes: bytes):
def _save_verified():
# log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}")
if not self.saved_verified_blob and not os.path.isfile(self.file_path):
if self.get_length() == len(verified_bytes):
with open(self.file_path, 'wb') as write_handle:
write_handle.write(verified_bytes)
self.saved_verified_blob = True
else:
raise Exception("length mismatch")
async with self.blob_write_lock:
if self.verified.is_set():
return
await self.loop.run_in_executor(None, _save_verified)
if self.blob_completed_callback:
await self.blob_completed_callback(self)
self.verified.set()
def open_for_writing(self) -> HashBlobWriter:
if self.file_exists:
raise OSError(f"File already exists '{self.file_path}'")
fut = asyncio.Future(loop=self.loop)
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
self.writers.append(writer)
fut.add_done_callback(self.writer_finished(writer))
return writer
async def sendfile(self, writer: asyncio.StreamWriter) -> int:
"""
Read and send the file to the writer and return the number of bytes sent
"""
with open(self.file_path, 'rb') as handle:
return await self.loop.sendfile(writer.transport, handle, count=self.get_length())
def close(self):
while self.writers:
self.writers.pop().finished.cancel()
def delete(self):
self.close()
self.saved_verified_blob = False
if os.path.isfile(self.file_path):
os.remove(self.file_path)
self.verified.clear()
self.finished_writing.clear()
self.length = None
def decrypt(self, key: bytes, iv: bytes) -> bytes:
"""
Decrypt a BlobFile to plaintext bytes
"""
with open(self.file_path, "rb") as f:
buff = f.read()
if len(buff) != self.length:
raise ValueError("unexpected length") raise ValueError("unexpected length")
cipher = Cipher(AES(key), modes.CBC(iv), backend=backend) cipher = Cipher(AES(key), modes.CBC(iv), backend=backend)
unpadder = PKCS7(AES.block_size).unpadder() unpadder = PKCS7(AES.block_size).unpadder()
decryptor = cipher.decryptor() decryptor = cipher.decryptor()
return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize() return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize()
@classmethod
async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes,
iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo:
"""
Create an encrypted BlobFile from plaintext bytes
"""
blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted) class AbstractBlob:
length = len(blob_bytes) """
blob = cls(loop, blob_dir, blob_hash, length) A chunk of data (up to 2MB) available on the network which is specified by a sha384 hash
writer = blob.open_for_writing()
writer.write(blob_bytes)
await blob.verified.wait()
return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash)
def set_length(self, length): This class is non-io specific
"""
def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None,
             blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None,
             blob_directory: typing.Optional[str] = None):
    """
    Set up the io-independent state of a blob

    :param loop: event loop the `verified` and `writing` events are created on
    :param blob_hash: hash identifying the blob, checked with is_valid_blobhash
    :param length: length of the blob in bytes, None when it is not known yet
    :param blob_completed_callback: awaitable callback taking the blob, stored as-is
    :param blob_directory: directory for the blob, stored as-is
    :raises InvalidBlobHashError: if blob_hash does not pass is_valid_blobhash
    """
    if not is_valid_blobhash(blob_hash):
        raise InvalidBlobHashError(blob_hash)

    # identity of the blob
    self.blob_hash = blob_hash
    self.length = length
    self.blob_directory = blob_directory

    # event loop plumbing
    self.loop = loop
    self.blob_completed_callback = blob_completed_callback
    self.verified: asyncio.Event = asyncio.Event(loop=loop)
    self.writing: asyncio.Event = asyncio.Event(loop=loop)

    # writers that are currently open for this blob
    self.writers: typing.List[HashBlobWriter] = []
def __del__(self):
    """
    Warn about and cancel writers that are still open when the blob is garbage collected

    __del__ also runs for instances whose __init__ raised before `writers` was assigned
    (for example on an invalid blob hash), so the attribute is looked up defensively
    instead of raising AttributeError during finalization.
    """
    open_writers = getattr(self, 'writers', None)
    if open_writers:
        log.warning("%s not closed before being garbage collected", self.blob_hash)
        self.close()
@contextlib.contextmanager
def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
    """
    Context manager giving read access to the blob, to be provided by subclasses

    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
def _write_blob(self, blob_bytes: bytes):
    """
    Store the given bytes for the blob, to be provided by subclasses

    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
def set_length(self, length) -> None:
if self.length is not None and length == self.length: if self.length is not None and length == self.length:
return return
if self.length is None and 0 <= length <= MAX_BLOB_SIZE: if self.length is None and 0 <= length <= MAX_BLOB_SIZE:
@ -190,8 +96,192 @@ class BlobFile:
return return
log.warning("Got an invalid length. Previous length: %s, Invalid length: %s", self.length, length) log.warning("Got an invalid length. Previous length: %s, Invalid length: %s", self.length, length)
def get_length(self): def get_length(self) -> typing.Optional[int]:
return self.length return self.length
def get_is_verified(self): def get_is_verified(self) -> bool:
return self.verified.is_set() return self.verified.is_set()
def is_readable(self) -> bool:
    """
    Whether the blob can be read, which is the case while the `verified` event is set
    """
    verified_event = self.verified
    return verified_event.is_set()
def is_writeable(self) -> bool:
    """
    Whether the blob can be written, which is the case while the `writing` event is not set
    """
    if self.writing.is_set():
        return False
    return True
def write_blob(self, blob_bytes: bytes):
    """
    Write the given bytes through _write_blob, holding the `writing` event for the duration

    :raises OSError: if is_writeable() reports the blob cannot be written
    """
    if not self.is_writeable():
        raise OSError("cannot open blob for writing")
    self.writing.set()
    try:
        self._write_blob(blob_bytes)
    finally:
        # always release the flag, also when the write itself failed
        self.writing.clear()
def close(self):
    """
    Cancel the `finished` future of every open writer, emptying the writers list
    """
    open_writers = self.writers
    while open_writers:
        pending = open_writers.pop()
        pending.finished.cancel()
def delete(self):
    """
    Cancel any open writers, then forget the blob's length and verified state
    """
    self.close()
    self.length = None
    self.verified.clear()
async def sendfile(self, writer: asyncio.StreamWriter) -> int:
    """
    Send the blob over the writer's transport and return the number of bytes sent

    :raises OSError: if is_readable() reports the blob cannot be read
    """
    if not self.is_readable():
        raise OSError('blob files cannot be read')
    with self.reader_context() as blob_handle:
        bytes_sent = await self.loop.sendfile(writer.transport, blob_handle, count=self.get_length())
    return bytes_sent
def decrypt(self, key: bytes, iv: bytes) -> bytes:
    """
    Decrypt the blob to plaintext bytes

    The blob is opened through reader_context() and handed to decrypt_blob_bytes together
    with the blob's length, the key and the iv.
    """
    with self.reader_context() as blob_reader:
        plaintext = decrypt_blob_bytes(blob_reader, self.length, key, iv)
    return plaintext
@classmethod
async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: typing.Optional[str], key: bytes,
                                  iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo:
    """
    Encrypt plaintext bytes into a new blob of this class

    The encrypted bytes are pushed through a blob writer and the call waits until the blob's
    `verified` event is set.

    :return: BlobInfo built from blob_num, the encrypted length, the hex encoded iv and the blob hash
    """
    encrypted, encrypted_hash = encrypt_blob_bytes(key, iv, unencrypted)
    encrypted_length = len(encrypted)
    new_blob = cls(loop, encrypted_hash, encrypted_length, blob_directory=blob_dir)
    blob_writer = new_blob.get_blob_writer()
    blob_writer.write(encrypted)
    await new_blob.verified.wait()
    iv_hex = binascii.hexlify(iv).decode()
    return BlobInfo(blob_num, encrypted_length, iv_hex, encrypted_hash)
async def save_verified_blob(self, verified_bytes: bytes):
    """
    Store already-verified bytes for the blob and mark it as verified

    Nothing happens when the blob is already verified or is not writeable. After a
    successful write the `verified` event is set and blob_completed_callback, if any,
    is awaited with the blob.

    :raises Exception: "length mismatch" if the blob's length differs from len(verified_bytes)
    """
    if self.verified.is_set():
        return
    if not self.is_writeable():
        return
    if self.get_length() != len(verified_bytes):
        raise Exception("length mismatch")
    self._write_blob(verified_bytes)
    self.verified.set()
    on_completed = self.blob_completed_callback
    if on_completed:
        await on_completed(self)
def get_blob_writer(self) -> HashBlobWriter:
    """
    Create a HashBlobWriter for this blob and register it in `writers`

    The writer is given a future; when that future is done the callback below either
    schedules save_verified_blob with the future's result or handles the failure.
    """
    finished_future = asyncio.Future(loop=self.loop)
    blob_writer = HashBlobWriter(self.blob_hash, self.get_length, finished_future)
    self.writers.append(blob_writer)

    def writer_finished_callback(finished: asyncio.Future):
        try:
            failure = finished.exception()
            if failure:
                raise failure
            verified_bytes = finished.result()
            # this writer succeeded: drop every registered writer and cancel the other ones
            while self.writers:
                registered = self.writers.pop()
                if registered is not blob_writer:
                    registered.finished.cancel()
            self.loop.create_task(self.save_verified_blob(verified_bytes))
        except (InvalidBlobHashError, InvalidDataError) as error:
            log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
        except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError):
            # log.exception("something else")
            pass
        finally:
            # whatever the outcome, this writer is no longer open
            if blob_writer in self.writers:
                self.writers.remove(blob_writer)

    finished_future.add_done_callback(writer_finished_callback)
    return blob_writer
class BlobBuffer(AbstractBlob):
    """
    An in-memory only blob

    The verified bytes are kept in a BytesIO buffer. Reading is one-shot: leaving
    reader_context() closes the buffer, drops it and clears the `verified` event.
    """
    def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None,
                 blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None,
                 blob_directory: typing.Optional[str] = None):
        super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory)
        # buffer holding the blob bytes, None while nothing is stored
        self._verified_bytes: typing.Optional[BytesIO] = None

    def _release_buffer(self) -> None:
        # close and forget the buffer; safe to call when there is none
        if self._verified_bytes is not None:
            self._verified_bytes.close()
            self._verified_bytes = None

    @contextlib.contextmanager
    def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
        """
        Yield the in-memory buffer for reading, releasing it on exit

        :raises OSError: if the blob is not readable
        """
        if not self.is_readable():
            raise OSError("cannot open blob for reading")
        try:
            yield self._verified_bytes
        finally:
            # the buffer may already have been released by an overlapping reader or by
            # delete(), so do not assume it is still there
            self._release_buffer()
            self.verified.clear()

    def _write_blob(self, blob_bytes: bytes):
        """
        Keep the given bytes in a new in-memory buffer

        :raises OSError: if a buffer is already held for this blob
        """
        if self._verified_bytes:
            raise OSError("already have bytes for blob")
        self._verified_bytes = BytesIO(blob_bytes)

    def delete(self):
        """
        Release the in-memory bytes, then reset the blob state

        Without releasing the buffer a deleted blob could never be written again, because
        _write_blob refuses to overwrite bytes that are still held.
        """
        self._release_buffer()
        return super().delete()
class BlobFile(AbstractBlob):
    """
    A blob existing on the local file system

    The blob is stored in a file named after the blob hash inside the blob directory.
    """
    def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None,
                 blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None,
                 blob_directory: typing.Optional[str] = None):
        """
        :raises OSError: if blob_directory is missing or is not an existing directory
        :raises InvalidBlobHashError: raised by AbstractBlob.__init__ for an invalid blob_hash
        """
        if not blob_directory or not os.path.isdir(blob_directory):
            raise OSError(f"invalid blob directory '{blob_directory}'")
        # the blob hash is validated by the base class, no need to check it a second time here
        super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory)
        self.file_path = os.path.join(self.blob_directory, self.blob_hash)
        if self.file_exists:
            file_size = int(os.stat(self.file_path).st_size)
            if length and length != file_size:
                # a file of the wrong size cannot be the expected blob, get rid of it
                log.warning("expected %s to be %s bytes, file has %s", self.blob_hash, length, file_size)
                self.delete()
            else:
                self.length = file_size
                self.verified.set()

    @property
    def file_exists(self):
        """Whether a regular file exists at the blob's path"""
        return os.path.isfile(self.file_path)

    def is_writeable(self) -> bool:
        """Writeable when the base class allows it and no file exists for the blob yet"""
        return super().is_writeable() and not self.file_exists

    def get_blob_writer(self) -> HashBlobWriter:
        """
        :raises OSError: if a file already exists for the blob
        """
        if self.file_exists:
            raise OSError(f"File already exists '{self.file_path}'")
        return super().get_blob_writer()

    @contextlib.contextmanager
    def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
        """Yield the blob file opened for binary reading, closing it on exit"""
        handle = open(self.file_path, 'rb')
        try:
            yield handle
        finally:
            handle.close()

    def _write_blob(self, blob_bytes: bytes):
        """
        Write the bytes to the blob file, leaving no partial file behind on failure
        """
        try:
            with open(self.file_path, 'wb') as f:
                f.write(blob_bytes)
        except OSError:
            # a truncated file would make the blob unwriteable and a later __init__ that
            # does not know the expected length would mark it as verified
            with contextlib.suppress(OSError):
                os.remove(self.file_path)
            raise

    def delete(self):
        """Remove the blob file if there is one, then reset the blob state"""
        if os.path.isfile(self.file_path):
            os.remove(self.file_path)
        return super().delete()

    @classmethod
    async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes,
                                      iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo:
        """
        Create an encrypted BlobFile from plaintext bytes

        :raises OSError: if blob_dir is missing or is not an existing directory
        """
        if not blob_dir or not os.path.isdir(blob_dir):
            raise OSError(f"cannot create blob in directory: '{blob_dir}'")
        return await super().create_from_unencrypted(loop, blob_dir, key, iv, unencrypted, blob_num)

View file

@ -2,7 +2,7 @@ import os
import typing import typing
import asyncio import asyncio
import logging import logging
from lbrynet.blob.blob_file import BlobFile, is_valid_blobhash from lbrynet.blob.blob_file import is_valid_blobhash, BlobFile, BlobBuffer, AbstractBlob
from lbrynet.stream.descriptor import StreamDescriptor from lbrynet.stream.descriptor import StreamDescriptor
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -14,7 +14,7 @@ log = logging.getLogger(__name__)
class BlobManager: class BlobManager:
def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, storage: 'SQLiteStorage', def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, storage: 'SQLiteStorage',
node_data_store: typing.Optional['DictDataStore'] = None): node_data_store: typing.Optional['DictDataStore'] = None, save_blobs: bool = True):
""" """
This class stores blobs on the hard disk This class stores blobs on the hard disk
@ -27,16 +27,25 @@ class BlobManager:
self._node_data_store = node_data_store self._node_data_store = node_data_store
self.completed_blob_hashes: typing.Set[str] = set() if not self._node_data_store\ self.completed_blob_hashes: typing.Set[str] = set() if not self._node_data_store\
else self._node_data_store.completed_blobs else self._node_data_store.completed_blobs
self.blobs: typing.Dict[str, BlobFile] = {} self.blobs: typing.Dict[str, AbstractBlob] = {}
self._save_blobs = save_blobs
def get_blob_class(self):
    """
    Pick the blob implementation: BlobFile when blobs are saved, BlobBuffer otherwise
    """
    return BlobFile if self._save_blobs else BlobBuffer
async def setup(self) -> bool: async def setup(self) -> bool:
def get_files_in_blob_dir() -> typing.Set[str]: def get_files_in_blob_dir() -> typing.Set[str]:
if not self.blob_dir:
return set()
return { return {
item.name for item in os.scandir(self.blob_dir) if is_valid_blobhash(item.name) item.name for item in os.scandir(self.blob_dir) if is_valid_blobhash(item.name)
} }
in_blobfiles_dir = await self.loop.run_in_executor(None, get_files_in_blob_dir) in_blobfiles_dir = await self.loop.run_in_executor(None, get_files_in_blob_dir)
self.completed_blob_hashes.update(await self.storage.sync_missing_blobs(in_blobfiles_dir)) to_add = await self.storage.sync_missing_blobs(in_blobfiles_dir)
if to_add:
self.completed_blob_hashes.update(to_add)
return True return True
def stop(self): def stop(self):
@ -50,17 +59,20 @@ class BlobManager:
if length and self.blobs[blob_hash].length is None: if length and self.blobs[blob_hash].length is None:
self.blobs[blob_hash].set_length(length) self.blobs[blob_hash].set_length(length)
else: else:
self.blobs[blob_hash] = BlobFile(self.loop, self.blob_dir, blob_hash, length, self.blob_completed) self.blobs[blob_hash] = self.get_blob_class()(self.loop, blob_hash, length, self.blob_completed,
self.blob_dir)
return self.blobs[blob_hash] return self.blobs[blob_hash]
def get_stream_descriptor(self, sd_hash): def get_stream_descriptor(self, sd_hash):
return StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_dir, self.get_blob(sd_hash)) return StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_dir, self.get_blob(sd_hash))
async def blob_completed(self, blob: BlobFile): async def blob_completed(self, blob: AbstractBlob):
if blob.blob_hash is None: if blob.blob_hash is None:
raise Exception("Blob hash is None") raise Exception("Blob hash is None")
if not blob.length: if not blob.length:
raise Exception("Blob has a length of 0") raise Exception("Blob has a length of 0")
if isinstance(blob, BlobBuffer): # don't save blob buffers to the db / dont announce them
return
if blob.blob_hash not in self.completed_blob_hashes: if blob.blob_hash not in self.completed_blob_hashes:
self.completed_blob_hashes.add(blob.blob_hash) self.completed_blob_hashes.add(blob.blob_hash)
await self.storage.add_completed_blob(blob.blob_hash, blob.length) await self.storage.add_completed_blob(blob.blob_hash, blob.length)
@ -75,7 +87,7 @@ class BlobManager:
raise Exception("invalid blob hash to delete") raise Exception("invalid blob hash to delete")
if blob_hash not in self.blobs: if blob_hash not in self.blobs:
if os.path.isfile(os.path.join(self.blob_dir, blob_hash)): if self.blob_dir and os.path.isfile(os.path.join(self.blob_dir, blob_hash)):
os.remove(os.path.join(self.blob_dir, blob_hash)) os.remove(os.path.join(self.blob_dir, blob_hash))
else: else:
self.blobs.pop(blob_hash).delete() self.blobs.pop(blob_hash).delete()

View file

@ -5,7 +5,7 @@ import binascii
from lbrynet.error import InvalidBlobHashError, InvalidDataError from lbrynet.error import InvalidBlobHashError, InvalidDataError
from lbrynet.blob_exchange.serialization import BlobResponse, BlobRequest from lbrynet.blob_exchange.serialization import BlobResponse, BlobRequest
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbrynet.blob.blob_file import BlobFile from lbrynet.blob.blob_file import AbstractBlob
from lbrynet.blob.writer import HashBlobWriter from lbrynet.blob.writer import HashBlobWriter
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -17,10 +17,10 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.peer_port: typing.Optional[int] = None self.peer_port: typing.Optional[int] = None
self.peer_address: typing.Optional[str] = None self.peer_address: typing.Optional[str] = None
self.peer_timeout = peer_timeout self.peer_timeout = peer_timeout
self.transport: asyncio.Transport = None self.transport: typing.Optional[asyncio.Transport] = None
self.writer: 'HashBlobWriter' = None self.writer: typing.Optional['HashBlobWriter'] = None
self.blob: 'BlobFile' = None self.blob: typing.Optional['AbstractBlob'] = None
self._blob_bytes_received = 0 self._blob_bytes_received = 0
self._response_fut: asyncio.Future = None self._response_fut: asyncio.Future = None
@ -63,8 +63,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
# write blob bytes if we're writing a blob and have blob bytes to write # write blob bytes if we're writing a blob and have blob bytes to write
self._write(response.blob_data) self._write(response.blob_data)
def _write(self, data: bytes):
def _write(self, data):
if len(data) > (self.blob.get_length() - self._blob_bytes_received): if len(data) > (self.blob.get_length() - self._blob_bytes_received):
data = data[:(self.blob.get_length() - self._blob_bytes_received)] data = data[:(self.blob.get_length() - self._blob_bytes_received)]
log.warning("got more than asked from %s:%d, probable sendfile bug", self.peer_address, self.peer_port) log.warning("got more than asked from %s:%d, probable sendfile bug", self.peer_address, self.peer_port)
@ -145,11 +144,12 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.transport = None self.transport = None
self.buf = b'' self.buf = b''
async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]: async def download_blob(self, blob: 'AbstractBlob') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
if blob.get_is_verified() or blob.file_exists or blob.blob_write_lock.locked(): if blob.get_is_verified() or not blob.is_writeable():
return 0, self.transport return 0, self.transport
try: try:
self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0 blob.get_blob_writer()
self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(), 0
self._response_fut = asyncio.Future(loop=self.loop) self._response_fut = asyncio.Future(loop=self.loop)
return await self._download_blob() return await self._download_blob()
except OSError as e: except OSError as e:
@ -177,7 +177,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.close() self.close()
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int, async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int,
peer_connect_timeout: float, blob_download_timeout: float, peer_connect_timeout: float, blob_download_timeout: float,
connected_transport: asyncio.Transport = None)\ connected_transport: asyncio.Transport = None)\
-> typing.Tuple[int, typing.Optional[asyncio.Transport]]: -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
@ -196,7 +196,7 @@ async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: s
if not connected_transport: if not connected_transport:
await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port), await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port),
peer_connect_timeout, loop=loop) peer_connect_timeout, loop=loop)
if blob.get_is_verified() or blob.file_exists: if blob.get_is_verified() or not blob.is_writeable():
# file exists but not verified means someone is writing right now, give it time, come back later # file exists but not verified means someone is writing right now, give it time, come back later
return 0, connected_transport return 0, connected_transport
return await protocol.download_blob(blob) return await protocol.download_blob(blob)

View file

@ -8,7 +8,7 @@ if typing.TYPE_CHECKING:
from lbrynet.dht.node import Node from lbrynet.dht.node import Node
from lbrynet.dht.peer import KademliaPeer from lbrynet.dht.peer import KademliaPeer
from lbrynet.blob.blob_manager import BlobManager from lbrynet.blob.blob_manager import BlobManager
from lbrynet.blob.blob_file import BlobFile from lbrynet.blob.blob_file import AbstractBlob
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -28,7 +28,7 @@ class BlobDownloader:
self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {} self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {}
self.time_since_last_blob = loop.time() self.time_since_last_blob = loop.time()
def should_race_continue(self, blob: 'BlobFile'): def should_race_continue(self, blob: 'AbstractBlob'):
if len(self.active_connections) >= self.config.max_connections_per_download: if len(self.active_connections) >= self.config.max_connections_per_download:
return False return False
# if a peer won 3 or more blob races and is active as a downloader, stop the race so bandwidth improves # if a peer won 3 or more blob races and is active as a downloader, stop the race so bandwidth improves
@ -37,9 +37,9 @@ class BlobDownloader:
# for peer, task in self.active_connections.items(): # for peer, task in self.active_connections.items():
# if self.scores.get(peer, 0) >= 0 and self.rounds_won.get(peer, 0) >= 3 and not task.done(): # if self.scores.get(peer, 0) >= 0 and self.rounds_won.get(peer, 0) >= 3 and not task.done():
# return False # return False
return not (blob.get_is_verified() or blob.file_exists) return not (blob.get_is_verified() or not blob.is_writeable())
async def request_blob_from_peer(self, blob: 'BlobFile', peer: 'KademliaPeer'): async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'):
if blob.get_is_verified(): if blob.get_is_verified():
return return
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
@ -62,7 +62,7 @@ class BlobDownloader:
rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0 rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0
self.scores[peer] = rough_speed self.scores[peer] = rough_speed
async def new_peer_or_finished(self, blob: 'BlobFile'): async def new_peer_or_finished(self, blob: 'AbstractBlob'):
async def get_and_re_add_peers(): async def get_and_re_add_peers():
try: try:
new_peers = await asyncio.wait_for(self.peer_queue.get(), timeout=1.0) new_peers = await asyncio.wait_for(self.peer_queue.get(), timeout=1.0)
@ -90,7 +90,7 @@ class BlobDownloader:
for banned_peer in forgiven: for banned_peer in forgiven:
self.ignored.pop(banned_peer) self.ignored.pop(banned_peer)
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
blob = self.blob_manager.get_blob(blob_hash, length) blob = self.blob_manager.get_blob(blob_hash, length)
if blob.get_is_verified(): if blob.get_is_verified():
return blob return blob
@ -99,7 +99,7 @@ class BlobDownloader:
batch: typing.List['KademliaPeer'] = [] batch: typing.List['KademliaPeer'] = []
while not self.peer_queue.empty(): while not self.peer_queue.empty():
batch.extend(self.peer_queue.get_nowait()) batch.extend(self.peer_queue.get_nowait())
batch.sort(key=lambda peer: self.scores.get(peer, 0), reverse=True) batch.sort(key=lambda p: self.scores.get(p, 0), reverse=True)
log.debug( log.debug(
"running, %d peers, %d ignored, %d active", "running, %d peers, %d ignored, %d active",
len(batch), len(self.ignored), len(self.active_connections) len(batch), len(self.ignored), len(self.active_connections)
@ -114,15 +114,29 @@ class BlobDownloader:
await self.new_peer_or_finished(blob) await self.new_peer_or_finished(blob)
self.cleanup_active() self.cleanup_active()
if batch: if batch:
self.peer_queue.put_nowait(set(batch).difference(self.ignored)) to_re_add = list(set(batch).difference(self.ignored))
if to_re_add:
self.peer_queue.put_nowait(to_re_add)
else:
self.clearbanned()
else: else:
self.clearbanned() self.clearbanned()
blob.close() blob.close()
log.debug("downloaded %s", blob_hash[:8]) log.debug("downloaded %s", blob_hash[:8])
return blob return blob
except asyncio.CancelledError as err:
error = err
finally: finally:
re_add = set()
while self.active_connections: while self.active_connections:
self.active_connections.popitem()[1].cancel() peer, t = self.active_connections.popitem()
t.cancel()
re_add.add(peer)
re_add = re_add.difference(self.ignored)
if re_add:
self.peer_queue.put_nowait(list(re_add))
blob.close()
raise error
def close(self): def close(self):
self.scores.clear() self.scores.clear()
@ -132,7 +146,7 @@ class BlobDownloader:
async def download_blob(loop, config: 'Config', blob_manager: 'BlobManager', node: 'Node', async def download_blob(loop, config: 'Config', blob_manager: 'BlobManager', node: 'Node',
blob_hash: str) -> 'BlobFile': blob_hash: str) -> 'AbstractBlob':
search_queue = asyncio.Queue(loop=loop, maxsize=config.max_connections_per_download) search_queue = asyncio.Queue(loop=loop, maxsize=config.max_connections_per_download)
search_queue.put_nowait(blob_hash) search_queue.put_nowait(blob_hash)
peer_queue, accumulate_task = node.accumulate_peers(search_queue) peer_queue, accumulate_task = node.accumulate_peers(search_queue)

View file

@ -294,7 +294,7 @@ class BlobComponent(Component):
blob_dir = os.path.join(self.conf.data_dir, 'blobfiles') blob_dir = os.path.join(self.conf.data_dir, 'blobfiles')
if not os.path.isdir(blob_dir): if not os.path.isdir(blob_dir):
os.mkdir(blob_dir) os.mkdir(blob_dir)
self.blob_manager = BlobManager(asyncio.get_event_loop(), blob_dir, storage, data_store) self.blob_manager = BlobManager(asyncio.get_event_loop(), blob_dir, storage, data_store, self.conf.save_blobs)
return await self.blob_manager.setup() return await self.blob_manager.setup()
async def stop(self): async def stop(self):

View file

@ -8,7 +8,7 @@ from collections import OrderedDict
from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.ciphers.algorithms import AES
from lbrynet.blob import MAX_BLOB_SIZE from lbrynet.blob import MAX_BLOB_SIZE
from lbrynet.blob.blob_info import BlobInfo from lbrynet.blob.blob_info import BlobInfo
from lbrynet.blob.blob_file import BlobFile from lbrynet.blob.blob_file import AbstractBlob, BlobFile
from lbrynet.cryptoutils import get_lbry_hash_obj from lbrynet.cryptoutils import get_lbry_hash_obj
from lbrynet.error import InvalidStreamDescriptorError from lbrynet.error import InvalidStreamDescriptorError
@ -108,29 +108,29 @@ class StreamDescriptor:
h.update(self.old_sort_json()) h.update(self.old_sort_json())
return h.hexdigest() return h.hexdigest()
async def make_sd_blob(self, blob_file_obj: typing.Optional[BlobFile] = None, async def make_sd_blob(self, blob_file_obj: typing.Optional[AbstractBlob] = None,
old_sort: typing.Optional[bool] = False): old_sort: typing.Optional[bool] = False):
sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash() sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash()
if not old_sort: if not old_sort:
sd_data = self.as_json() sd_data = self.as_json()
else: else:
sd_data = self.old_sort_json() sd_data = self.old_sort_json()
sd_blob = blob_file_obj or BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data)) sd_blob = blob_file_obj or BlobFile(self.loop, sd_hash, len(sd_data), blob_directory=self.blob_dir)
if blob_file_obj: if blob_file_obj:
blob_file_obj.set_length(len(sd_data)) blob_file_obj.set_length(len(sd_data))
if not sd_blob.get_is_verified(): if not sd_blob.get_is_verified():
writer = sd_blob.open_for_writing() writer = sd_blob.get_blob_writer()
writer.write(sd_data) writer.write(sd_data)
await sd_blob.verified.wait() await sd_blob.verified.wait()
sd_blob.close() sd_blob.close()
return sd_blob return sd_blob
@classmethod @classmethod
def _from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str, def _from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
blob: BlobFile) -> 'StreamDescriptor': blob: AbstractBlob) -> 'StreamDescriptor':
assert os.path.isfile(blob.file_path) with blob.reader_context() as blob_reader:
with open(blob.file_path, 'rb') as f: json_bytes = blob_reader.read()
json_bytes = f.read()
try: try:
decoded = json.loads(json_bytes.decode()) decoded = json.loads(json_bytes.decode())
except json.JSONDecodeError: except json.JSONDecodeError:
@ -160,8 +160,8 @@ class StreamDescriptor:
@classmethod @classmethod
async def from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str, async def from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
blob: BlobFile) -> 'StreamDescriptor': blob: AbstractBlob) -> 'StreamDescriptor':
return await loop.run_in_executor(None, lambda: cls._from_stream_descriptor_blob(loop, blob_dir, blob)) return await loop.run_in_executor(None, cls._from_stream_descriptor_blob, loop, blob_dir, blob)
@staticmethod @staticmethod
def get_blob_hashsum(b: typing.Dict): def get_blob_hashsum(b: typing.Dict):
@ -228,7 +228,7 @@ class StreamDescriptor:
return self.lower_bound_decrypted_length() + (AES.block_size // 8) return self.lower_bound_decrypted_length() + (AES.block_size // 8)
@classmethod @classmethod
async def recover(cls, blob_dir: str, sd_blob: 'BlobFile', stream_hash: str, stream_name: str, async def recover(cls, blob_dir: str, sd_blob: 'AbstractBlob', stream_hash: str, stream_name: str,
suggested_file_name: str, key: str, suggested_file_name: str, key: str,
blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']: blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']:
descriptor = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name, descriptor = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name,

View file

@ -63,11 +63,11 @@ class ReflectorServerProtocol(asyncio.Protocol):
return return
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
if not self.sd_blob.get_is_verified(): if not self.sd_blob.get_is_verified():
self.writer = self.sd_blob.open_for_writing() self.writer = self.sd_blob.get_blob_writer()
self.incoming.set() self.incoming.set()
self.send_response({"send_sd_blob": True}) self.send_response({"send_sd_blob": True})
try: try:
await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop) await asyncio.wait_for(self.sd_blob.verified.wait(), 30, loop=self.loop)
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
self.loop, self.blob_manager.blob_dir, self.sd_blob self.loop, self.blob_manager.blob_dir, self.sd_blob
) )
@ -102,11 +102,11 @@ class ReflectorServerProtocol(asyncio.Protocol):
return return
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
if not blob.get_is_verified(): if not blob.get_is_verified():
self.writer = blob.open_for_writing() self.writer = blob.get_blob_writer()
self.incoming.set() self.incoming.set()
self.send_response({"send_blob": True}) self.send_response({"send_blob": True})
try: try:
await asyncio.wait_for(blob.finished_writing.wait(), 30, loop=self.loop) await asyncio.wait_for(blob.verified.wait(), 30, loop=self.loop)
self.send_response({"received_blob": True}) self.send_response({"received_blob": True})
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.send_response({"received_blob": False}) self.send_response({"received_blob": False})