forked from LBRYCommunity/lbry-sdk
refactor BlobFile into AbstractBlob, BlobFile, and BlobBuffer classes
This commit is contained in:
parent d44a79ada2
commit 676f0015aa
7 changed files with 290 additions and 174 deletions
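At a high level, the refactor splits the old monolithic BlobFile into a non-I/O base class (AbstractBlob) and two storage backends: BlobFile (disk-backed) and BlobBuffer (in-memory). A minimal sketch of how the two backends are constructed after this change; the hash and directory values below are placeholders, not part of the diff:

    import asyncio
    from lbrynet.blob.blob_file import BlobFile, BlobBuffer

    loop = asyncio.get_event_loop()
    fake_hash = "a" * 96  # sha384 hex digests are 96 characters, which is what is_valid_blobhash expects

    # disk-backed blob; requires an existing blob directory (path is hypothetical)
    disk_blob = BlobFile(loop, fake_hash, blob_directory="/tmp/blobfiles")

    # in-memory blob; no directory needed, never saved to disk or announced
    memory_blob = BlobBuffer(loop, fake_hash)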
@@ -4,6 +4,8 @@ import asyncio
import binascii
import logging
import typing
import contextlib
from io import BytesIO
from cryptography.hazmat.primitives.ciphers import Cipher, modes
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.padding import PKCS7

@@ -21,10 +23,6 @@ log = logging.getLogger(__name__)
_hexmatch = re.compile("^[a-f,0-9]+$")


def is_valid_hashcharacter(char: str) -> bool:
    return len(char) == 1 and _hexmatch.match(char)


def is_valid_blobhash(blobhash: str) -> bool:
    """Checks whether the blobhash is the correct length and contains only
    valid characters (0-9, a-f)

@@ -46,143 +44,51 @@ def encrypt_blob_bytes(key: bytes, iv: bytes, unencrypted: bytes) -> typing.Tupl
    return encrypted, digest.hexdigest()


class BlobFile:
    """
    A chunk of data available on the network which is specified by a hashsum


def decrypt_blob_bytes(read_handle: typing.BinaryIO, length: int, key: bytes, iv: bytes) -> bytes:
    buff = read_handle.read()
    if len(buff) != length:
        raise ValueError("unexpected length")
    cipher = Cipher(AES(key), modes.CBC(iv), backend=backend)
    unpadder = PKCS7(AES.block_size).unpadder()
    decryptor = cipher.decryptor()
    return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize()
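As an aside, the two module-level helpers are inverses of one another. A hypothetical roundtrip, assuming encrypt_blob_bytes and decrypt_blob_bytes from this module are in scope; BytesIO stands in for the blob reader handle, and the key/IV sizes follow AES-CBC:

    import os
    from io import BytesIO

    key, iv = os.urandom(32), os.urandom(16)  # AES-256 key, 16-byte CBC IV
    encrypted, blob_hash = encrypt_blob_bytes(key, iv, b"some plaintext")
    decrypted = decrypt_blob_bytes(BytesIO(encrypted), len(encrypted), key, iv)
    assert decrypted == b"some plaintext"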
    This class is used to create blobs on the local filesystem
    when we already know the blob hash before hand (i.e., when downloading blobs)
    Also can be used for reading from blobs on the local filesystem
    """

    def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, blob_hash: str,
                 length: typing.Optional[int] = None,
                 blob_completed_callback: typing.Optional[typing.Callable[['BlobFile'], typing.Awaitable]] = None):

class AbstractBlob:
    """
    A chunk of data (up to 2MB) available on the network which is specified by a sha384 hash

    This class is non-io specific
    """

    def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None,
                 blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None,
                 blob_directory: typing.Optional[str] = None):
        if not is_valid_blobhash(blob_hash):
            raise InvalidBlobHashError(blob_hash)

        self.loop = loop
        self.blob_hash = blob_hash
        self.length = length
        self.blob_dir = blob_dir
        self.file_path = os.path.join(blob_dir, self.blob_hash)
        self.writers: typing.List[HashBlobWriter] = []

        self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
        self.finished_writing = asyncio.Event(loop=loop)
        self.blob_write_lock = asyncio.Lock(loop=loop)
        if self.file_exists:
            length = int(os.stat(os.path.join(blob_dir, blob_hash)).st_size)
            self.length = length
            self.verified.set()
            self.finished_writing.set()
        self.saved_verified_blob = False
        self.blob_completed_callback = blob_completed_callback
        self.blob_directory = blob_directory

    @property
    def file_exists(self):
        return os.path.isfile(self.file_path)
        self.writers: typing.List[HashBlobWriter] = []
        self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
        self.writing: asyncio.Event = asyncio.Event(loop=self.loop)

    def writer_finished(self, writer: HashBlobWriter):
        def callback(finished: asyncio.Future):
            try:
                error = finished.exception()
            except Exception as err:
                error = err
            if writer in self.writers:  # remove this download attempt
                self.writers.remove(writer)
            if not error:  # the blob downloaded, cancel all the other download attempts and set the result
                while self.writers:
                    other = self.writers.pop()
                    other.finished.cancel()
                t = self.loop.create_task(self.save_verified_blob(writer, finished.result()))
                t.add_done_callback(lambda *_: self.finished_writing.set())
                return
            if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
                log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
            elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)):
                log.exception("something else")
                raise error
        return callback

    async def save_verified_blob(self, writer, verified_bytes: bytes):
        def _save_verified():
            # log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}")
            if not self.saved_verified_blob and not os.path.isfile(self.file_path):
                if self.get_length() == len(verified_bytes):
                    with open(self.file_path, 'wb') as write_handle:
                        write_handle.write(verified_bytes)
                    self.saved_verified_blob = True
                else:
                    raise Exception("length mismatch")

        async with self.blob_write_lock:
            if self.verified.is_set():
                return
            await self.loop.run_in_executor(None, _save_verified)
            if self.blob_completed_callback:
                await self.blob_completed_callback(self)
            self.verified.set()

    def open_for_writing(self) -> HashBlobWriter:
        if self.file_exists:
            raise OSError(f"File already exists '{self.file_path}'")
        fut = asyncio.Future(loop=self.loop)
        writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
        self.writers.append(writer)
        fut.add_done_callback(self.writer_finished(writer))
        return writer

    async def sendfile(self, writer: asyncio.StreamWriter) -> int:
        """
        Read and send the file to the writer and return the number of bytes sent
        """

        with open(self.file_path, 'rb') as handle:
            return await self.loop.sendfile(writer.transport, handle, count=self.get_length())

    def close(self):
        while self.writers:
            self.writers.pop().finished.cancel()

    def delete(self):
    def __del__(self):
        if self.writers:
            log.warning("%s not closed before being garbage collected", self.blob_hash)
            self.close()
        self.saved_verified_blob = False
        if os.path.isfile(self.file_path):
            os.remove(self.file_path)
        self.verified.clear()
        self.finished_writing.clear()
        self.length = None

    def decrypt(self, key: bytes, iv: bytes) -> bytes:
        """
        Decrypt a BlobFile to plaintext bytes
        """

    @contextlib.contextmanager
    def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
        raise NotImplementedError()

        with open(self.file_path, "rb") as f:
            buff = f.read()
        if len(buff) != self.length:
            raise ValueError("unexpected length")
        cipher = Cipher(AES(key), modes.CBC(iv), backend=backend)
        unpadder = PKCS7(AES.block_size).unpadder()
        decryptor = cipher.decryptor()
        return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize()

    def _write_blob(self, blob_bytes: bytes):
        raise NotImplementedError()

    @classmethod
    async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes,
                                      iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo:
        """
        Create an encrypted BlobFile from plaintext bytes
        """

        blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted)
        length = len(blob_bytes)
        blob = cls(loop, blob_dir, blob_hash, length)
        writer = blob.open_for_writing()
        writer.write(blob_bytes)
        await blob.verified.wait()
        return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash)

    def set_length(self, length):
    def set_length(self, length) -> None:
        if self.length is not None and length == self.length:
            return
        if self.length is None and 0 <= length <= MAX_BLOB_SIZE:
@@ -190,8 +96,192 @@ class BlobFile:
            return
        log.warning("Got an invalid length. Previous length: %s, Invalid length: %s", self.length, length)

    def get_length(self):
    def get_length(self) -> typing.Optional[int]:
        return self.length

    def get_is_verified(self):
    def get_is_verified(self) -> bool:
        return self.verified.is_set()

    def is_readable(self) -> bool:
        return self.verified.is_set()

    def is_writeable(self) -> bool:
        return not self.writing.is_set()

    def write_blob(self, blob_bytes: bytes):
        if not self.is_writeable():
            raise OSError("cannot open blob for writing")
        try:
            self.writing.set()
            self._write_blob(blob_bytes)
        finally:
            self.writing.clear()

    def close(self):
        while self.writers:
            self.writers.pop().finished.cancel()

    def delete(self):
        self.close()
        self.verified.clear()
        self.length = None

    async def sendfile(self, writer: asyncio.StreamWriter) -> int:
        """
        Read and send the file to the writer and return the number of bytes sent
        """

        if not self.is_readable():
            raise OSError('blob files cannot be read')
        with self.reader_context() as handle:
            return await self.loop.sendfile(writer.transport, handle, count=self.get_length())

    def decrypt(self, key: bytes, iv: bytes) -> bytes:
        """
        Decrypt a BlobFile to plaintext bytes
        """

        with self.reader_context() as reader:
            return decrypt_blob_bytes(reader, self.length, key, iv)

    @classmethod
    async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: typing.Optional[str], key: bytes,
                                      iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo:
        """
        Create an encrypted BlobFile from plaintext bytes
        """

        blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted)
        length = len(blob_bytes)
        blob = cls(loop, blob_hash, length, blob_directory=blob_dir)
        writer = blob.get_blob_writer()
        writer.write(blob_bytes)
        await blob.verified.wait()
        return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash)

    async def save_verified_blob(self, verified_bytes: bytes):
        if self.verified.is_set():
            return
        if self.is_writeable():
            if self.get_length() == len(verified_bytes):
                self._write_blob(verified_bytes)
                self.verified.set()
                if self.blob_completed_callback:
                    await self.blob_completed_callback(self)
            else:
                raise Exception("length mismatch")

    def get_blob_writer(self) -> HashBlobWriter:
        fut = asyncio.Future(loop=self.loop)
        writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
        self.writers.append(writer)

        def writer_finished_callback(finished: asyncio.Future):
            try:
                err = finished.exception()
                if err:
                    raise err
                verified_bytes = finished.result()
                while self.writers:
                    other = self.writers.pop()
                    if other is not writer:
                        other.finished.cancel()
                self.loop.create_task(self.save_verified_blob(verified_bytes))
                return
            except (InvalidBlobHashError, InvalidDataError) as error:
                log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
            except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError) as error:
                # log.exception("something else")
                pass
            finally:
                if writer in self.writers:
                    self.writers.remove(writer)

        fut.add_done_callback(writer_finished_callback)
        return writer


class BlobBuffer(AbstractBlob):
    """
    An in-memory only blob
    """
    def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None,
                 blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None,
                 blob_directory: typing.Optional[str] = None):
        super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory)
        self._verified_bytes: typing.Optional[BytesIO] = None

    @contextlib.contextmanager
    def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
        if not self.is_readable():
            raise OSError("cannot open blob for reading")
        try:
            yield self._verified_bytes
        finally:
            self._verified_bytes.close()
            self._verified_bytes = None
            self.verified.clear()

    def _write_blob(self, blob_bytes: bytes):
        if self._verified_bytes:
            raise OSError("already have bytes for blob")
        self._verified_bytes = BytesIO(blob_bytes)


class BlobFile(AbstractBlob):
    """
    A blob existing on the local file system
    """
    def __init__(self, loop: asyncio.BaseEventLoop, blob_hash: str, length: typing.Optional[int] = None,
                 blob_completed_callback: typing.Optional[typing.Callable[['AbstractBlob'], typing.Awaitable]] = None,
                 blob_directory: typing.Optional[str] = None):
        if not blob_directory or not os.path.isdir(blob_directory):
            raise OSError(f"invalid blob directory '{blob_directory}'")
        super().__init__(loop, blob_hash, length, blob_completed_callback, blob_directory)
        if not is_valid_blobhash(blob_hash):
            raise InvalidBlobHashError(blob_hash)
        self.file_path = os.path.join(self.blob_directory, self.blob_hash)
        if self.file_exists:
            file_size = int(os.stat(self.file_path).st_size)
            if length and length != file_size:
                log.warning("expected %s to be %s bytes, file has %s", self.blob_hash, length, file_size)
                self.delete()
            else:
                self.length = file_size
                self.verified.set()

    @property
    def file_exists(self):
        return os.path.isfile(self.file_path)

    def is_writeable(self) -> bool:
        return super().is_writeable() and not os.path.isfile(self.file_path)

    def get_blob_writer(self) -> HashBlobWriter:
        if self.file_exists:
            raise OSError(f"File already exists '{self.file_path}'")
        return super().get_blob_writer()

    @contextlib.contextmanager
    def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
        handle = open(self.file_path, 'rb')
        try:
            yield handle
        finally:
            handle.close()

    def _write_blob(self, blob_bytes: bytes):
        with open(self.file_path, 'wb') as f:
            f.write(blob_bytes)

    def delete(self):
        if os.path.isfile(self.file_path):
            os.remove(self.file_path)
        return super().delete()

    @classmethod
    async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes,
                                      iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo:
        if not blob_dir or not os.path.isdir(blob_dir):
            raise OSError(f"cannot create blob in directory: '{blob_dir}'")
        return await super().create_from_unencrypted(loop, blob_dir, key, iv, unencrypted, blob_num)
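Taken together, AbstractBlob reduces the storage contract to two hooks: reader_context() and _write_blob(). A toy subclass, purely illustrative and not part of the diff, showing the minimum a new backend would implement:

    import contextlib
    from io import BytesIO
    from lbrynet.blob.blob_file import AbstractBlob

    class ToyBlob(AbstractBlob):
        """Hypothetical backend that keeps the verified bytes in a plain attribute."""
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._data = b""

        def _write_blob(self, blob_bytes: bytes):
            # called by write_blob()/save_verified_blob() once the hash checks out
            self._data = blob_bytes

        @contextlib.contextmanager
        def reader_context(self):
            # a real backend would also gate on self.is_readable(), as BlobBuffer does
            yield BytesIO(self._data)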

@@ -2,7 +2,7 @@ import os
import typing
import asyncio
import logging
from lbrynet.blob.blob_file import BlobFile, is_valid_blobhash
from lbrynet.blob.blob_file import is_valid_blobhash, BlobFile, BlobBuffer, AbstractBlob
from lbrynet.stream.descriptor import StreamDescriptor

if typing.TYPE_CHECKING:

@@ -14,7 +14,7 @@ log = logging.getLogger(__name__)

class BlobManager:
    def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, storage: 'SQLiteStorage',
                 node_data_store: typing.Optional['DictDataStore'] = None):
                 node_data_store: typing.Optional['DictDataStore'] = None, save_blobs: bool = True):
        """
        This class stores blobs on the hard disk

@@ -27,16 +27,25 @@ class BlobManager:
        self._node_data_store = node_data_store
        self.completed_blob_hashes: typing.Set[str] = set() if not self._node_data_store\
            else self._node_data_store.completed_blobs
        self.blobs: typing.Dict[str, BlobFile] = {}
        self.blobs: typing.Dict[str, AbstractBlob] = {}
        self._save_blobs = save_blobs

    def get_blob_class(self):
        if not self._save_blobs:
            return BlobBuffer
        return BlobFile

    async def setup(self) -> bool:
        def get_files_in_blob_dir() -> typing.Set[str]:
            if not self.blob_dir:
                return set()
            return {
                item.name for item in os.scandir(self.blob_dir) if is_valid_blobhash(item.name)
            }

        in_blobfiles_dir = await self.loop.run_in_executor(None, get_files_in_blob_dir)
        self.completed_blob_hashes.update(await self.storage.sync_missing_blobs(in_blobfiles_dir))
        to_add = await self.storage.sync_missing_blobs(in_blobfiles_dir)
        if to_add:
            self.completed_blob_hashes.update(to_add)
        return True

    def stop(self):

@@ -50,17 +59,20 @@ class BlobManager:
            if length and self.blobs[blob_hash].length is None:
                self.blobs[blob_hash].set_length(length)
        else:
            self.blobs[blob_hash] = BlobFile(self.loop, self.blob_dir, blob_hash, length, self.blob_completed)
            self.blobs[blob_hash] = self.get_blob_class()(self.loop, blob_hash, length, self.blob_completed,
                                                          self.blob_dir)
        return self.blobs[blob_hash]

    def get_stream_descriptor(self, sd_hash):
        return StreamDescriptor.from_stream_descriptor_blob(self.loop, self.blob_dir, self.get_blob(sd_hash))

    async def blob_completed(self, blob: BlobFile):
    async def blob_completed(self, blob: AbstractBlob):
        if blob.blob_hash is None:
            raise Exception("Blob hash is None")
        if not blob.length:
            raise Exception("Blob has a length of 0")
        if isinstance(blob, BlobBuffer):  # don't save blob buffers to the db / dont announce them
            return
        if blob.blob_hash not in self.completed_blob_hashes:
            self.completed_blob_hashes.add(blob.blob_hash)
        await self.storage.add_completed_blob(blob.blob_hash, blob.length)

@@ -75,7 +87,7 @@ class BlobManager:
            raise Exception("invalid blob hash to delete")

        if blob_hash not in self.blobs:
            if os.path.isfile(os.path.join(self.blob_dir, blob_hash)):
            if self.blob_dir and os.path.isfile(os.path.join(self.blob_dir, blob_hash)):
                os.remove(os.path.join(self.blob_dir, blob_hash))
        else:
            self.blobs.pop(blob_hash).delete()
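The new save_blobs flag threads through to get_blob_class(), so a manager built with save_blobs=False hands out BlobBuffer instances and skips the database/announcement path in blob_completed(). An illustrative sketch only; loop, storage, and data_store are assumed to exist in scope and the path is a placeholder:

    manager = BlobManager(loop, "/tmp/blobfiles", storage, data_store, save_blobs=False)
    assert manager.get_blob_class() is BlobBuffer
    blob = manager.get_blob("a" * 96)  # returns a BlobBuffer, kept only in memory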

@@ -5,7 +5,7 @@ import binascii
from lbrynet.error import InvalidBlobHashError, InvalidDataError
from lbrynet.blob_exchange.serialization import BlobResponse, BlobRequest
if typing.TYPE_CHECKING:
    from lbrynet.blob.blob_file import BlobFile
    from lbrynet.blob.blob_file import AbstractBlob
    from lbrynet.blob.writer import HashBlobWriter

log = logging.getLogger(__name__)

@@ -17,10 +17,10 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
        self.peer_port: typing.Optional[int] = None
        self.peer_address: typing.Optional[str] = None
        self.peer_timeout = peer_timeout
        self.transport: asyncio.Transport = None
        self.transport: typing.Optional[asyncio.Transport] = None

        self.writer: 'HashBlobWriter' = None
        self.blob: 'BlobFile' = None
        self.writer: typing.Optional['HashBlobWriter'] = None
        self.blob: typing.Optional['AbstractBlob'] = None

        self._blob_bytes_received = 0
        self._response_fut: asyncio.Future = None

@@ -63,8 +63,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
            # write blob bytes if we're writing a blob and have blob bytes to write
            self._write(response.blob_data)

    def _write(self, data):
    def _write(self, data: bytes):
        if len(data) > (self.blob.get_length() - self._blob_bytes_received):
            data = data[:(self.blob.get_length() - self._blob_bytes_received)]
            log.warning("got more than asked from %s:%d, probable sendfile bug", self.peer_address, self.peer_port)

@@ -145,11 +144,12 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
        self.transport = None
        self.buf = b''

    async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
        if blob.get_is_verified() or blob.file_exists or blob.blob_write_lock.locked():
    async def download_blob(self, blob: 'AbstractBlob') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
        if blob.get_is_verified() or not blob.is_writeable():
            return 0, self.transport
        try:
            self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0
            blob.get_blob_writer()
            self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(), 0
            self._response_fut = asyncio.Future(loop=self.loop)
            return await self._download_blob()
        except OSError as e:

@@ -177,7 +177,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
        self.close()


async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int,
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int,
                       peer_connect_timeout: float, blob_download_timeout: float,
                       connected_transport: asyncio.Transport = None)\
        -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:

@@ -196,7 +196,7 @@ async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: s
        if not connected_transport:
            await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port),
                                   peer_connect_timeout, loop=loop)
        if blob.get_is_verified() or blob.file_exists:
        if blob.get_is_verified() or not blob.is_writeable():
            # file exists but not verified means someone is writing right now, give it time, come back later
            return 0, connected_transport
        return await protocol.download_blob(blob)
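With file_exists and blob_write_lock gone from the generic interface, request_blob now gates on is_writeable() alone, so it works identically for disk- and memory-backed blobs. A hedged usage sketch; the address, port, and timeout values are placeholders:

    async def fetch(loop, blob):
        # 4.0s to connect, 30.0s for the whole download; both values are examples
        bytes_received, transport = await request_blob(loop, blob, "127.0.0.1", 3333, 4.0, 30.0)
        return bytes_received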

@@ -8,7 +8,7 @@ if typing.TYPE_CHECKING:
    from lbrynet.dht.node import Node
    from lbrynet.dht.peer import KademliaPeer
    from lbrynet.blob.blob_manager import BlobManager
    from lbrynet.blob.blob_file import BlobFile
    from lbrynet.blob.blob_file import AbstractBlob

log = logging.getLogger(__name__)

@@ -28,7 +28,7 @@ class BlobDownloader:
        self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {}
        self.time_since_last_blob = loop.time()

    def should_race_continue(self, blob: 'BlobFile'):
    def should_race_continue(self, blob: 'AbstractBlob'):
        if len(self.active_connections) >= self.config.max_connections_per_download:
            return False
        # if a peer won 3 or more blob races and is active as a downloader, stop the race so bandwidth improves

@@ -37,9 +37,9 @@ class BlobDownloader:
        # for peer, task in self.active_connections.items():
        #     if self.scores.get(peer, 0) >= 0 and self.rounds_won.get(peer, 0) >= 3 and not task.done():
        #         return False
        return not (blob.get_is_verified() or blob.file_exists)
        return not (blob.get_is_verified() or not blob.is_writeable())

    async def request_blob_from_peer(self, blob: 'BlobFile', peer: 'KademliaPeer'):
    async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'):
        if blob.get_is_verified():
            return
        self.scores[peer] = self.scores.get(peer, 0) - 1  # starts losing score, to account for cancelled ones

@@ -62,7 +62,7 @@ class BlobDownloader:
            rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0
            self.scores[peer] = rough_speed

    async def new_peer_or_finished(self, blob: 'BlobFile'):
    async def new_peer_or_finished(self, blob: 'AbstractBlob'):
        async def get_and_re_add_peers():
            try:
                new_peers = await asyncio.wait_for(self.peer_queue.get(), timeout=1.0)

@@ -90,7 +90,7 @@ class BlobDownloader:
        for banned_peer in forgiven:
            self.ignored.pop(banned_peer)

    async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
    async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
        blob = self.blob_manager.get_blob(blob_hash, length)
        if blob.get_is_verified():
            return blob

@@ -99,7 +99,7 @@ class BlobDownloader:
                batch: typing.List['KademliaPeer'] = []
                while not self.peer_queue.empty():
                    batch.extend(self.peer_queue.get_nowait())
                batch.sort(key=lambda peer: self.scores.get(peer, 0), reverse=True)
                batch.sort(key=lambda p: self.scores.get(p, 0), reverse=True)
                log.debug(
                    "running, %d peers, %d ignored, %d active",
                    len(batch), len(self.ignored), len(self.active_connections)

@@ -114,15 +114,29 @@ class BlobDownloader:
                await self.new_peer_or_finished(blob)
                self.cleanup_active()
                if batch:
                    self.peer_queue.put_nowait(set(batch).difference(self.ignored))
                    to_re_add = list(set(batch).difference(self.ignored))
                    if to_re_add:
                        self.peer_queue.put_nowait(to_re_add)
                    else:
                        self.clearbanned()
                else:
                    self.clearbanned()
            blob.close()
            log.debug("downloaded %s", blob_hash[:8])
            return blob
        except asyncio.CancelledError as err:
            error = err
        finally:
            re_add = set()
            while self.active_connections:
                self.active_connections.popitem()[1].cancel()
                peer, t = self.active_connections.popitem()
                t.cancel()
                re_add.add(peer)
            re_add = re_add.difference(self.ignored)
            if re_add:
                self.peer_queue.put_nowait(list(re_add))
            blob.close()
            raise error

    def close(self):
        self.scores.clear()

@@ -132,7 +146,7 @@ class BlobDownloader:


async def download_blob(loop, config: 'Config', blob_manager: 'BlobManager', node: 'Node',
                        blob_hash: str) -> 'BlobFile':
                        blob_hash: str) -> 'AbstractBlob':
    search_queue = asyncio.Queue(loop=loop, maxsize=config.max_connections_per_download)
    search_queue.put_nowait(blob_hash)
    peer_queue, accumulate_task = node.accumulate_peers(search_queue)

@@ -294,7 +294,7 @@ class BlobComponent(Component):
        blob_dir = os.path.join(self.conf.data_dir, 'blobfiles')
        if not os.path.isdir(blob_dir):
            os.mkdir(blob_dir)
        self.blob_manager = BlobManager(asyncio.get_event_loop(), blob_dir, storage, data_store)
        self.blob_manager = BlobManager(asyncio.get_event_loop(), blob_dir, storage, data_store, self.conf.save_blobs)
        return await self.blob_manager.setup()

    async def stop(self):

@@ -8,7 +8,7 @@ from collections import OrderedDict
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from lbrynet.blob import MAX_BLOB_SIZE
from lbrynet.blob.blob_info import BlobInfo
from lbrynet.blob.blob_file import BlobFile
from lbrynet.blob.blob_file import AbstractBlob, BlobFile
from lbrynet.cryptoutils import get_lbry_hash_obj
from lbrynet.error import InvalidStreamDescriptorError

@@ -108,29 +108,29 @@ class StreamDescriptor:
        h.update(self.old_sort_json())
        return h.hexdigest()

    async def make_sd_blob(self, blob_file_obj: typing.Optional[BlobFile] = None,
    async def make_sd_blob(self, blob_file_obj: typing.Optional[AbstractBlob] = None,
                           old_sort: typing.Optional[bool] = False):
        sd_hash = self.calculate_sd_hash() if not old_sort else self.calculate_old_sort_sd_hash()
        if not old_sort:
            sd_data = self.as_json()
        else:
            sd_data = self.old_sort_json()
        sd_blob = blob_file_obj or BlobFile(self.loop, self.blob_dir, sd_hash, len(sd_data))
        sd_blob = blob_file_obj or BlobFile(self.loop, sd_hash, len(sd_data), blob_directory=self.blob_dir)
        if blob_file_obj:
            blob_file_obj.set_length(len(sd_data))
        if not sd_blob.get_is_verified():
            writer = sd_blob.open_for_writing()
            writer = sd_blob.get_blob_writer()
            writer.write(sd_data)

        await sd_blob.verified.wait()
        sd_blob.close()
        return sd_blob

    @classmethod
    def _from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
                                     blob: BlobFile) -> 'StreamDescriptor':
        assert os.path.isfile(blob.file_path)
        with open(blob.file_path, 'rb') as f:
            json_bytes = f.read()
                                     blob: AbstractBlob) -> 'StreamDescriptor':
        with blob.reader_context() as blob_reader:
            json_bytes = blob_reader.read()
        try:
            decoded = json.loads(json_bytes.decode())
        except json.JSONDecodeError:

@@ -160,8 +160,8 @@ class StreamDescriptor:

    @classmethod
    async def from_stream_descriptor_blob(cls, loop: asyncio.BaseEventLoop, blob_dir: str,
                                          blob: BlobFile) -> 'StreamDescriptor':
        return await loop.run_in_executor(None, lambda: cls._from_stream_descriptor_blob(loop, blob_dir, blob))
                                          blob: AbstractBlob) -> 'StreamDescriptor':
        return await loop.run_in_executor(None, cls._from_stream_descriptor_blob, loop, blob_dir, blob)

    @staticmethod
    def get_blob_hashsum(b: typing.Dict):

@@ -228,7 +228,7 @@ class StreamDescriptor:
        return self.lower_bound_decrypted_length() + (AES.block_size // 8)

    @classmethod
    async def recover(cls, blob_dir: str, sd_blob: 'BlobFile', stream_hash: str, stream_name: str,
    async def recover(cls, blob_dir: str, sd_blob: 'AbstractBlob', stream_hash: str, stream_name: str,
                      suggested_file_name: str, key: str,
                      blobs: typing.List['BlobInfo']) -> typing.Optional['StreamDescriptor']:
        descriptor = cls(asyncio.get_event_loop(), blob_dir, stream_name, key, suggested_file_name,

@@ -63,11 +63,11 @@ class ReflectorServerProtocol(asyncio.Protocol):
            return
        self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
        if not self.sd_blob.get_is_verified():
            self.writer = self.sd_blob.open_for_writing()
            self.writer = self.sd_blob.get_blob_writer()
            self.incoming.set()
            self.send_response({"send_sd_blob": True})
            try:
                await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop)
                await asyncio.wait_for(self.sd_blob.verified.wait(), 30, loop=self.loop)
                self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
                    self.loop, self.blob_manager.blob_dir, self.sd_blob
                )

@@ -102,11 +102,11 @@ class ReflectorServerProtocol(asyncio.Protocol):
            return
        blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
        if not blob.get_is_verified():
            self.writer = blob.open_for_writing()
            self.writer = blob.get_blob_writer()
            self.incoming.set()
            self.send_response({"send_blob": True})
            try:
                await asyncio.wait_for(blob.finished_writing.wait(), 30, loop=self.loop)
                await asyncio.wait_for(blob.verified.wait(), 30, loop=self.loop)
                self.send_response({"received_blob": True})
            except asyncio.TimeoutError:
                self.send_response({"received_blob": False})