lbry-sdk/lbrynet/blob/blob_file.py

198 lines
7.6 KiB
Python
Raw Normal View History

2015-08-20 17:27:15 +02:00
import os
2019-02-09 01:54:59 +01:00
import re
2019-01-22 18:47:46 +01:00
import asyncio
import binascii
2018-11-07 21:15:05 +01:00
import logging
2019-01-22 18:47:46 +01:00
import typing
from cryptography.hazmat.primitives.ciphers import Cipher, modes
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.padding import PKCS7
2015-08-20 17:27:15 +02:00
2019-01-22 18:47:46 +01:00
from lbrynet.cryptoutils import backend, get_lbry_hash_obj
from lbrynet.error import DownloadCancelledError, InvalidBlobHashError, InvalidDataError
2019-01-22 18:47:46 +01:00
from lbrynet.blob import MAX_BLOB_SIZE, blobhash_length
from lbrynet.blob.blob_info import BlobInfo
from lbrynet.blob.writer import HashBlobWriter
2017-09-13 21:27:43 +02:00
2019-01-22 18:47:46 +01:00
log = logging.getLogger(__name__)

# Matches a non-empty string made only of lowercase hex digits.
# \A / \Z anchor the whole string: unlike `$`, `\Z` does not match before a
# trailing newline, so "ab\n" is rejected. (The previous pattern also listed a
# stray ',' inside the character class, which made commas "valid" hex.)
_hexmatch = re.compile(r"\A[a-f0-9]+\Z")
2019-02-09 01:57:26 +01:00
def is_valid_hashcharacter(char: str) -> bool:
    """Check whether ``char`` is a single lowercase hexadecimal digit (0-9, a-f).

    @param char: string to check
    @return: True if ``char`` is exactly one character in 0-9 / a-f, else False
    """
    # The length guard is required: the empty string is a substring of any string.
    return len(char) == 1 and char in "0123456789abcdef"
2019-02-09 01:54:59 +01:00
2019-01-22 18:47:46 +01:00
def is_valid_blobhash(blobhash: str) -> bool:
    """Checks whether the blobhash is the correct length and contains only
    valid characters (0-9, a-f)

    @param blobhash: string, the blobhash to check
    @return: True/False
    """
    # Character-by-character membership test instead of a regex: it cannot be
    # fooled by a trailing newline (which `$` accepts), rejects commas, and
    # always yields a real bool rather than a Match object.
    return len(blobhash) == blobhash_length and all(c in "0123456789abcdef" for c in blobhash)
2018-02-12 20:16:43 +01:00
2019-02-09 01:57:26 +01:00
2019-01-22 18:47:46 +01:00
def encrypt_blob_bytes(key: bytes, iv: bytes, unencrypted: bytes) -> typing.Tuple[bytes, str]:
    """Encrypt plaintext into blob bytes and compute the resulting blob hash.

    The plaintext is PKCS7-padded to the AES block size and encrypted with AES
    in CBC mode.

    @param key: AES key
    @param iv: CBC initialization vector
    @param unencrypted: plaintext bytes
    @return: (ciphertext, hex digest of the ciphertext from get_lbry_hash_obj())
    """
    pad = PKCS7(AES.block_size).padder()
    padded_plaintext = pad.update(unencrypted) + pad.finalize()

    aes_cbc = Cipher(AES(key), modes.CBC(iv), backend=backend).encryptor()
    ciphertext = aes_cbc.update(padded_plaintext) + aes_cbc.finalize()

    hasher = get_lbry_hash_obj()
    hasher.update(ciphertext)
    return ciphertext, hasher.hexdigest()
class BlobFile:
    """
    A chunk of data available on the network which is specified by a hashsum

    This class is used to create blobs on the local filesystem
    when we already know the blob hash before hand (i.e., when downloading blobs).
    It can also be used for reading from blobs on the local filesystem.
    """

    def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, blob_hash: str,
                 length: typing.Optional[int] = None,
                 blob_completed_callback: typing.Optional[typing.Callable[['BlobFile'], typing.Awaitable]] = None):
        """
        @param loop: event loop used for the events, lock, executor calls and tasks
        @param blob_dir: directory in which the blob file lives
        @param blob_hash: hex blob hash, also used as the file name
        @param length: expected blob length, if known
        @param blob_completed_callback: awaited after a downloaded blob has been saved
        @raise InvalidBlobHashError: if blob_hash fails is_valid_blobhash()
        """
        if not is_valid_blobhash(blob_hash):
            raise InvalidBlobHashError(blob_hash)
        self.loop = loop
        self.blob_hash = blob_hash
        self.length = length
        self.blob_dir = blob_dir
        self.file_path = os.path.join(blob_dir, self.blob_hash)
        # in-flight download attempts; the first to succeed cancels the others
        self.writers: typing.List[HashBlobWriter] = []

        self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
        self.finished_writing = asyncio.Event(loop=loop)
        self.blob_write_lock = asyncio.Lock(loop=loop)
        if self.file_exists:
            # a blob file already on disk is treated as complete; its size is the length
            self.length = int(os.stat(self.file_path).st_size)
            self.verified.set()
            self.finished_writing.set()
        self.saved_verified_blob = False
        self.blob_completed_callback = blob_completed_callback

    @property
    def file_exists(self):
        """True if a regular file exists at self.file_path."""
        return os.path.isfile(self.file_path)

    def writer_finished(self, writer: HashBlobWriter):
        """Build the done-callback attached to ``writer``'s future.

        On success, all other pending writers are cancelled and the verified
        bytes are saved in a new task; finished_writing is set when that task
        ends. Invalid-data/hash errors are logged at debug level; cancellation,
        DownloadCancelledError and timeouts are ignored; anything else is logged
        and re-raised.
        """
        def callback(finished: asyncio.Future):
            try:
                error = finished.exception()
            except (Exception, asyncio.CancelledError) as err:
                # exception() raises CancelledError for a cancelled future; since
                # Python 3.8 that is a BaseException, so it must be named explicitly
                error = err
            if writer in self.writers:  # remove this download attempt
                self.writers.remove(writer)
            if not error:  # the blob downloaded, cancel all the other download attempts and set the result
                while self.writers:
                    other = self.writers.pop()
                    other.finished.cancel()
                t = self.loop.create_task(self.save_verified_blob(writer, finished.result()))
                t.add_done_callback(lambda *_: self.finished_writing.set())
                return
            if isinstance(error, (InvalidBlobHashError, InvalidDataError)):
                log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
            elif not isinstance(error, (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError)):
                # we are not inside an except block, so attach the traceback explicitly
                # (log.exception here would only print "NoneType: None")
                log.error("something else", exc_info=error)
                raise error
        return callback

    async def save_verified_blob(self, writer, verified_bytes: bytes):
        """Write verified blob bytes to disk (once), await the completion callback, then set verified.

        Serialized by blob_write_lock; returns immediately if already verified.
        @raise Exception: "length mismatch" if len(verified_bytes) != get_length()
        """
        def _save_verified():
            # runs in the default executor so the blocking file write stays off the loop
            if not self.saved_verified_blob and not os.path.isfile(self.file_path):
                if self.get_length() == len(verified_bytes):
                    with open(self.file_path, 'wb') as write_handle:
                        write_handle.write(verified_bytes)
                    self.saved_verified_blob = True
                else:
                    raise Exception("length mismatch")

        async with self.blob_write_lock:
            if self.verified.is_set():
                return
            await self.loop.run_in_executor(None, _save_verified)
            if self.blob_completed_callback:
                await self.blob_completed_callback(self)
            self.verified.set()

    def open_for_writing(self) -> HashBlobWriter:
        """Start a new download attempt and return its writer.

        @raise OSError: if the blob file already exists on disk
        """
        if self.file_exists:
            raise OSError(f"File already exists '{self.file_path}'")
        fut = asyncio.Future(loop=self.loop)
        writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
        self.writers.append(writer)
        fut.add_done_callback(self.writer_finished(writer))
        return writer

    async def sendfile(self, writer: asyncio.StreamWriter) -> int:
        """
        Read and send the file to the writer and return the number of bytes sent
        """
        with open(self.file_path, 'rb') as handle:
            return await self.loop.sendfile(writer.transport, handle, count=self.get_length())

    def close(self):
        """Cancel every pending download attempt."""
        while self.writers:
            self.writers.pop().finished.cancel()

    def delete(self):
        """Cancel pending writers, remove the blob file if present, and reset blob state."""
        self.close()
        self.saved_verified_blob = False
        if os.path.isfile(self.file_path):
            os.remove(self.file_path)
        self.verified.clear()
        self.finished_writing.clear()
        self.length = None

    def decrypt(self, key: bytes, iv: bytes) -> bytes:
        """
        Decrypt a BlobFile to plaintext bytes (AES-CBC, PKCS7 padding)

        @raise ValueError: if the file size differs from self.length
        """
        with open(self.file_path, "rb") as f:
            buff = f.read()
        if len(buff) != self.length:
            raise ValueError("unexpected length")
        cipher = Cipher(AES(key), modes.CBC(iv), backend=backend)
        unpadder = PKCS7(AES.block_size).unpadder()
        decryptor = cipher.decryptor()
        return unpadder.update(decryptor.update(buff) + decryptor.finalize()) + unpadder.finalize()

    @classmethod
    async def create_from_unencrypted(cls, loop: asyncio.BaseEventLoop, blob_dir: str, key: bytes,
                                      iv: bytes, unencrypted: bytes, blob_num: int) -> BlobInfo:
        """
        Create an encrypted BlobFile from plaintext bytes and return its BlobInfo
        """
        blob_bytes, blob_hash = encrypt_blob_bytes(key, iv, unencrypted)
        length = len(blob_bytes)
        blob = cls(loop, blob_dir, blob_hash, length)
        if not blob.get_is_verified():
            # a blob already on disk is verified in __init__; opening it for
            # writing would raise OSError, so only write when it is absent
            writer = blob.open_for_writing()
            writer.write(blob_bytes)
        await blob.verified.wait()
        return BlobInfo(blob_num, length, binascii.hexlify(iv).decode(), blob_hash)

    def set_length(self, length):
        """Set the blob length if it is not yet known and 0 <= length <= MAX_BLOB_SIZE.

        Re-setting the same length is a no-op; any other value (including None)
        is rejected with a warning.
        """
        if self.length is not None and length == self.length:
            return
        if self.length is None and length is not None and 0 <= length <= MAX_BLOB_SIZE:
            self.length = length
            return
        log.warning("Got an invalid length. Previous length: %s, Invalid length: %s", self.length, length)

    def get_length(self):
        """Return the known blob length, or None if not yet known."""
        return self.length

    def get_is_verified(self):
        """Return True once the blob is verified (on disk or saved after download)."""
        return self.verified.is_set()