forked from LBRYCommunity/lbry-sdk
dbb6ba6241
-test download and recover stream with old key sorting
196 lines
7.6 KiB
Python
196 lines
7.6 KiB
Python
import os
|
|
import re
|
|
import asyncio
|
|
import binascii
|
|
import logging
|
|
import typing
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, modes
|
|
from cryptography.hazmat.primitives.ciphers.algorithms import AES
|
|
from cryptography.hazmat.primitives.padding import PKCS7
|
|
|
|
from lbrynet.cryptoutils import backend, get_lbry_hash_obj
|
|
from lbrynet.error import DownloadCancelledError, InvalidBlobHashError, InvalidDataError
|
|
|
|
from lbrynet.blob import MAX_BLOB_SIZE, blobhash_length
|
|
from lbrynet.blob.blob_info import BlobInfo
|
|
from lbrynet.blob.writer import HashBlobWriter
|
|
|
|
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
|
|
|
|
|
|
# Matches a non-empty string made up solely of lowercase hexadecimal digits.
# The character class deliberately has no comma (a comma is not a hex digit)
# and the pattern is anchored with \Z rather than $ so that a trailing
# newline is not silently accepted.
_hexmatch = re.compile(r"^[a-f0-9]+\Z")


def is_valid_hashcharacter(char: str) -> bool:
    """Check whether ``char`` is a single lowercase hexadecimal digit.

    @param char: string, the candidate character

    @return: True if ``char`` has length 1 and is one of 0-9 or a-f,
             False otherwise (always a real bool, never a match object)
    """
    return len(char) == 1 and _hexmatch.match(char) is not None
|
|
|
|
|
|
def is_valid_blobhash(blobhash: str) -> bool:
    """Checks whether the blobhash is the correct length and consists
    entirely of characters accepted by the module-level ``_hexmatch`` pattern

    @param blobhash: string, the blobhash to check

    @return: True/False (always a real bool, never a match object)
    """
    if len(blobhash) != blobhash_length:
        return False
    # fullmatch (rather than match) requires the pattern to consume the whole
    # string, so a hash ending in a newline cannot slip through a `$` anchor.
    return _hexmatch.fullmatch(blobhash) is not None
|
|
|
|
|
|
def encrypt_blob_bytes(key: bytes, iv: bytes, unencrypted: bytes) -> typing.Tuple[bytes, str]:
    """
    Encrypt plaintext bytes with AES in CBC mode after PKCS7 padding

    @param key: bytes, the AES key
    @param iv: bytes, the CBC initialization vector
    @param unencrypted: bytes, the plaintext to encrypt

    @return: tuple of (encrypted bytes, hex digest of the encrypted bytes
             computed with the hash object from get_lbry_hash_obj())
    """
    # pad the plaintext up to a whole number of AES blocks
    pkcs7 = PKCS7(AES.block_size).padder()
    padded = pkcs7.update(unencrypted) + pkcs7.finalize()

    # encrypt the padded plaintext
    aes_cbc = Cipher(AES(key), modes.CBC(iv), backend=backend).encryptor()
    ciphertext = aes_cbc.update(padded) + aes_cbc.finalize()

    # the hash is taken over the encrypted bytes, not the plaintext
    hasher = get_lbry_hash_obj()
    hasher.update(ciphertext)
    return ciphertext, hasher.hexdigest()
|
|
|
|
|
|
class BlobFile:
    """
    A chunk of data available on the network which is specified by a hashsum

    This class is used to create blobs on the local filesystem
    when we already know the blob hash beforehand (i.e., when downloading blobs)
    Also can be used for reading from blobs on the local filesystem
    """

    def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, blob_hash: str,
                 length: typing.Optional[int] = None,
                 blob_completed_callback: typing.Optional[typing.Callable[['BlobFile'], typing.Awaitable]] = None):
        """
        @param loop: the event loop used for tasks, futures and executor calls
        @param blob_dir: directory in which the blob file lives
        @param blob_hash: the blob hash; also used as the file name in blob_dir
        @param length: expected blob length in bytes, if already known
        @param blob_completed_callback: optional coroutine function awaited with
               this blob after its bytes have been saved

        @raise InvalidBlobHashError: if blob_hash fails is_valid_blobhash()
        """
        if not is_valid_blobhash(blob_hash):
            raise InvalidBlobHashError(blob_hash)
        self.loop = loop
        self.blob_hash = blob_hash
        self.length = length
        self.blob_dir = blob_dir
        self.file_path = os.path.join(blob_dir, self.blob_hash)
        # in-progress write attempts for this blob
        self.writers: typing.List[HashBlobWriter] = []

        # NOTE(review): the explicit `loop=` argument to asyncio primitives was
        # removed in Python 3.10; kept here to preserve existing behavior.
        self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
        self.finished_writing = asyncio.Event(loop=loop)
        self.blob_write_lock = asyncio.Lock(loop=loop)
        if self.file_exists:
            # a file already on disk is treated as complete: its size on disk
            # replaces any length passed in and both events are set
            length = int(os.stat(self.file_path).st_size)
            self.length = length
            self.verified.set()
            self.finished_writing.set()
        self.saved_verified_blob = False
        self.blob_completed_callback = blob_completed_callback

    @property
    def file_exists(self):
        """True if a regular file exists at this blob's file path"""
        return os.path.isfile(self.file_path)

    def writer_finished(self, writer: HashBlobWriter):
        """
        Build the done-callback for the future belonging to `writer`

        @param writer: the writer whose future the callback will be attached to

        @return: a callable taking the finished future
        """
        def callback(finished: asyncio.Future):
            try:
                error = finished.exception()
            except (Exception, asyncio.CancelledError) as err:
                # Future.exception() raises CancelledError for a cancelled
                # future; since Python 3.8 that is a BaseException, so it has
                # to be named explicitly to be captured here
                error = err
            if writer in self.writers:  # remove this download attempt
                self.writers.remove(writer)
            if not error:  # the blob downloaded, cancel all the other download attempts and set the result
                while self.writers:
                    other = self.writers.pop()
                    other.finished.cancel()
                t = self.loop.create_task(self.save_verified_blob(writer, finished.result()))
                t.add_done_callback(lambda *_: self.finished_writing.set())
                return
            if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
                log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
            elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)):
                # there is no active `except` block here, so pass the error
                # explicitly to get its traceback in the log record
                log.error("something else", exc_info=error)
                raise error
        return callback

    async def save_verified_blob(self, writer, verified_bytes: bytes):
        """
        Write the verified bytes to the blob file, then mark the blob verified

        Does nothing if the blob is already verified. The file is written in
        the loop's default executor while holding blob_write_lock.

        @param writer: the writer that produced the bytes (currently unused)
        @param verified_bytes: the blob contents to save

        @raise Exception: "length mismatch" if the bytes do not match get_length()
        """
        def _save_verified():
            # log.debug(f"write blob file {self.blob_hash[:8]} from {writer.peer.address}")
            if not self.saved_verified_blob and not os.path.isfile(self.file_path):
                if self.get_length() == len(verified_bytes):
                    with open(self.file_path, 'wb') as write_handle:
                        write_handle.write(verified_bytes)
                    self.saved_verified_blob = True
                else:
                    raise Exception("length mismatch")

        async with self.blob_write_lock:
            if self.verified.is_set():
                return
            await self.loop.run_in_executor(None, _save_verified)
            if self.blob_completed_callback:
                await self.blob_completed_callback(self)
            self.verified.set()

    def open_for_writing(self) -> HashBlobWriter:
        """
        Create and register a new writer for this blob

        @return: the new HashBlobWriter

        @raise OSError: if the blob file already exists
        """
        if self.file_exists:
            raise OSError(f"File already exists '{self.file_path}'")
        fut = asyncio.Future(loop=self.loop)
        writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
        self.writers.append(writer)
        fut.add_done_callback(self.writer_finished(writer))
        return writer

    async def sendfile(self, writer: asyncio.StreamWriter) -> int:
        """
        Read and send the file to the writer and return the number of bytes sent
        """

        with open(self.file_path, 'rb') as handle:
            return await self.loop.sendfile(writer.transport, handle, count=self.get_length())

    def close(self):
        """Cancel the future of every outstanding writer"""
        while self.writers:
            self.writers.pop().finished.cancel()

    def delete(self):
        """
        Cancel outstanding writers, remove the blob file if present and clear
        the verified / finished_writing events
        """
        self.close()
        self.saved_verified_blob = False
        if os.path.isfile(self.file_path):
            os.remove(self.file_path)
        self.verified.clear()
        self.finished_writing.clear()

    def decrypt(self, key: bytes, iv: bytes) -> bytes:
        """
        Decrypt a BlobFile to plaintext bytes

        @param key: bytes, the AES key
        @param iv: bytes, the CBC initialization vector

        @return: the decrypted bytes with PKCS7 padding removed

        @raise ValueError: if the file size differs from self.length
        """

        with open(self.file_path, "rb") as f:
            buff = f.read()
        if len(buff) != self.length:
            raise ValueError("unexpected length")
        cipher = Cipher(AES(key), modes.CBC(iv), backend=backend)
        unpadder = PKCS7(AES.block_size).unpadder()
        decryptor = cipher.decryptor()
        return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize()

    @classmethod
    async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes,
                                      iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo:
        """
        Create an encrypted BlobFile from plaintext bytes

        Encrypts the bytes, writes them through a writer and waits until the
        blob is marked verified.

        @return: BlobInfo built from blob_num, the encrypted length, the
                 hex-encoded iv and the blob hash
        """

        blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted)
        length = len(blob_bytes)
        blob = cls(loop, blob_dir, blob_hash, length)
        writer = blob.open_for_writing()
        writer.write(blob_bytes)
        await blob.verified.wait()
        return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash)

    def set_length(self, length):
        """
        Set the blob length if it is not known yet

        A length equal to the current one is a no-op. A new length is accepted
        only when none is set and it lies within 0..MAX_BLOB_SIZE; anything
        else is logged as a warning and the stored length is left unchanged.
        """
        if self.length is not None and length == self.length:
            return
        if self.length is None and 0 <= length <= MAX_BLOB_SIZE:
            self.length = length
            return
        log.warning("Got an invalid length. Previous length: %s, Invalid length: %s", self.length, length)

    def get_length(self):
        """Return the stored blob length (None if not known)"""
        return self.length

    def get_is_verified(self):
        """Return whether the verified event is set"""
        return self.verified.is_set()
|