Sanitize suggested file name, add tests

This commit is contained in:
Miroslav Kovar 2019-10-08 20:03:27 +02:00 committed by Lex Berezhny
parent ed93c647d2
commit 268f6447cb
6 changed files with 92 additions and 31 deletions

View file

@ -4,6 +4,7 @@ import binascii
import logging import logging
import typing import typing
import asyncio import asyncio
import re
from collections import OrderedDict from collections import OrderedDict
from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.ciphers.algorithms import AES
from lbry.blob import MAX_BLOB_SIZE from lbry.blob import MAX_BLOB_SIZE
@ -46,6 +47,29 @@ def file_reader(file_path: str):
offset += bytes_to_read offset += bytes_to_read
# Characters that are removed from the name wherever they occur.
_ILLEGAL_FILE_NAME_CHARS = re.compile(r'[<>:"/\\\|\?\*]')
# Any mix of spaces, tabs and dots at the very end of the base name.
_TRAILING_DOTS_AND_SPACES = re.compile(r'[ \t.]+$')
# Device names that must not be used as a base name, in any letter case.
_RESERVED_FILE_NAMES = re.compile(r'^(?:CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])$', re.IGNORECASE)


def sanitize_file_name(dirty_name: str) -> str:
    """Return a file name derived from *dirty_name* that is safe to suggest.

    The name is cleaned in these steps:

    1. Illegal characters (``< > : " / \\ | ? *``) are dropped.
    2. The name is split into base name and extension with
       ``os.path.splitext``.
    3. Leading spaces/tabs and any trailing run of spaces, tabs and dots are
       removed from the base name; surrounding spaces/tabs are removed from
       the extension.
    4. A base name equal to a reserved device name (``CON``, ``PRN``,
       ``AUX``, ``NUL``, ``COM1``-``COM9``, ``LPT1``-``LPT9``, compared
       case-insensitively) is rejected entirely, extension included.
    5. The extension is re-attached only when it is more than a bare dot.

    :param dirty_name: the untrusted file name to clean.
    :return: the sanitized file name, or ``''`` when nothing usable is left
        (a warning is logged in that case).
    """
    stripped = re.sub(_ILLEGAL_FILE_NAME_CHARS, '', dirty_name)
    base, ext = os.path.splitext(stripped)

    # Trailing dots and whitespace are removed in a single pass, so mixed
    # runs such as "name. ." cannot leave a dangling dot behind.
    base = re.sub(_TRAILING_DOTS_AND_SPACES, '', base.lstrip(' \t'))
    ext = ext.strip(' \t')

    if _RESERVED_FILE_NAMES.match(base):
        # A reserved base name stays reserved with an extension attached,
        # so the whole name is discarded.
        file_name = ''
    elif base and len(ext) > 1:
        file_name = base + ext
    else:
        file_name = base

    if not file_name:
        log.warning('Unable to suggest a file name for %s', dirty_name)
    return file_name
class StreamDescriptor: class StreamDescriptor:
__slots__ = [ __slots__ = [
'loop', 'loop',
@ -65,7 +89,7 @@ class StreamDescriptor:
self.blob_dir = blob_dir self.blob_dir = blob_dir
self.stream_name = stream_name self.stream_name = stream_name
self.key = key self.key = key
self.suggested_file_name = suggested_file_name self.suggested_file_name = sanitize_file_name(suggested_file_name)
self.blobs = blobs self.blobs = blobs
self.stream_hash = stream_hash or self.get_stream_hash() self.stream_hash = stream_hash or self.get_stream_hash()
self.sd_hash = sd_hash self.sd_hash = sd_hash

View file

@ -3,7 +3,6 @@ import asyncio
import typing import typing
import logging import logging
import binascii import binascii
import re
from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable
from lbry.utils import generate_id from lbry.utils import generate_id
from lbry.error import DownloadSDTimeout from lbry.error import DownloadSDTimeout
@ -40,11 +39,6 @@ async def get_next_available_file_name(loop: asyncio.AbstractEventLoop, download
return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name) return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name)
def sanitize_file_name(file_name):
    """Return *file_name* reduced to word characters, dashes and dots.

    The value is converted to ``str``, surrounding whitespace is trimmed and
    every remaining space becomes an underscore; any character that is not a
    word character, ``-`` or ``.`` is then removed.
    """
    underscored = str(file_name).strip().replace(' ', '_')
    disallowed = re.compile(r'(?u)[^-\w.]')
    return disallowed.sub('', underscored)
class ManagedStream: class ManagedStream:
STATUS_RUNNING = "running" STATUS_RUNNING = "running"
STATUS_STOPPED = "stopped" STATUS_STOPPED = "stopped"
@ -121,12 +115,7 @@ class ManagedStream:
@property @property
def file_name(self) -> typing.Optional[str]: def file_name(self) -> typing.Optional[str]:
return self._file_name or \ return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None)
(sanitize_file_name(self.descriptor.suggested_file_name) if self.descriptor else None)
@file_name.setter
def file_name(self, value):
self._file_name = sanitize_file_name(value)
@property @property
def status(self) -> str: def status(self) -> str:
@ -261,7 +250,7 @@ class ManagedStream:
self.delayed_stop_task = self.loop.create_task(self._delayed_stop()) self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
if not await self.blob_manager.storage.file_exists(self.sd_hash): if not await self.blob_manager.storage.file_exists(self.sd_hash):
if save_now: if save_now:
file_name, download_dir = self.file_name, self.download_directory file_name, download_dir = self._file_name, self.download_directory
else: else:
file_name, download_dir = None, None file_name, download_dir = None, None
self.rowid = await self.blob_manager.storage.save_downloaded_file( self.rowid = await self.blob_manager.storage.save_downloaded_file(
@ -371,7 +360,7 @@ class ManagedStream:
await self.blob_manager.storage.change_file_download_dir_and_file_name( await self.blob_manager.storage.change_file_download_dir_and_file_name(
self.stream_hash, None, None self.stream_hash, None, None
) )
self.file_name, self.download_directory = None, None self._file_name, self.download_directory = None, None
await self.blob_manager.storage.clear_saved_file(self.stream_hash) await self.blob_manager.storage.clear_saved_file(self.stream_hash)
await self.update_status(self.STATUS_STOPPED) await self.update_status(self.STATUS_STOPPED)
return return
@ -390,14 +379,14 @@ class ManagedStream:
self.download_directory = download_directory or self.download_directory or self.config.download_dir self.download_directory = download_directory or self.download_directory or self.config.download_dir
if not self.download_directory: if not self.download_directory:
raise ValueError("no directory to download to") raise ValueError("no directory to download to")
if not (file_name or self.file_name): if not (file_name or self._file_name or self.descriptor.suggested_file_name):
raise ValueError("no file name to download to") raise ValueError("no file name to download to")
if not os.path.isdir(self.download_directory): if not os.path.isdir(self.download_directory):
log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory) log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
os.mkdir(self.download_directory) os.mkdir(self.download_directory)
self.file_name = await get_next_available_file_name( self._file_name = await get_next_available_file_name(
self.loop, self.download_directory, self.loop, self.download_directory,
file_name or self.descriptor.suggested_file_name file_name or self.file_name
) )
await self.blob_manager.storage.change_file_download_dir_and_file_name( await self.blob_manager.storage.change_file_download_dir_and_file_name(
self.stream_hash, self.download_directory, self.file_name self.stream_hash, self.download_directory, self.file_name

View file

@ -197,8 +197,9 @@ class CommandTestCase(IntegrationTestCase):
""" Synchronous version of `out` method. """ """ Synchronous version of `out` method. """
return json.loads(jsonrpc_dumps_pretty(value, ledger=self.ledger))['result'] return json.loads(jsonrpc_dumps_pretty(value, ledger=self.ledger))['result']
async def stream_create(self, name='hovercraft', bid='1.0', data=b'hi!', confirm=True, **kwargs): async def stream_create(self, name='hovercraft', bid='1.0', data=b'hi!', confirm=True,
file = tempfile.NamedTemporaryFile() prefix=None, suffix=None, **kwargs):
file = tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix)
def cleanup(): def cleanup():
try: try:

View file

@ -45,6 +45,40 @@ class FileCommands(CommandTestCase):
{stream.sd_hash, stream.descriptor.blobs[0].blob_hash} {stream.sd_hash, stream.descriptor.blobs[0].blob_hash}
) )
async def test_publish_with_illegal_chars(self):
    """Publish a file whose name has illegal characters and check each name.

    The published source name and the local file list keep the original
    name, the descriptor's suggested name is sanitized, a re-download uses
    the sanitized name, and a user-supplied file name is used verbatim.
    """
    dirty_prefix, dirty_suffix = '?an|t:z< m<', '.ext.'
    clean_prefix, clean_suffix = 'antz m', '.ext'

    # Stream a file with file name containing invalid chars
    claim = await self.stream_create('foo', '0.01', prefix=dirty_prefix, suffix=dirty_suffix)
    stream = self.daemon.jsonrpc_file_list()[0]

    # Check that file list and source contains the local unsanitized name, but suggested name is sanitized
    published_name = claim['outputs'][0]['value']['source']['name']
    local_name = stream.file_name
    suggested_name = stream.descriptor.suggested_file_name
    self.assertTrue(published_name.startswith(dirty_prefix))
    self.assertTrue(published_name.endswith(dirty_suffix))
    self.assertEqual(local_name, published_name)
    self.assertTrue(suggested_name.startswith(clean_prefix))
    self.assertTrue(suggested_name.endswith(clean_suffix))

    # Delete the file, re-download and check that the file name is sanitized
    self.assertTrue(await self.daemon.jsonrpc_file_delete(claim_name='foo'))
    redownloaded = await self.daemon.jsonrpc_get('lbry://foo', save_file=True)
    saved_path = redownloaded.full_path
    saved_name = os.path.basename(saved_path)
    self.assertTrue(os.path.isfile(saved_path))
    self.assertTrue(saved_name.startswith(clean_prefix))
    self.assertTrue(saved_name.endswith(clean_suffix))

    # Check that the downloaded file name is not sanitized when user provides custom file name
    self.assertTrue(await self.daemon.jsonrpc_file_delete(claim_name='foo'))
    custom_name = 'my <u?|*m.name'
    custom_download = await self.daemon.jsonrpc_get('lbry://foo', file_name=custom_name, save_file=True)
    custom_path = custom_download.full_path
    self.assertTrue(os.path.isfile(custom_path))
    self.assertEqual(custom_name, os.path.basename(custom_path))
async def test_file_list_fields(self): async def test_file_list_fields(self):
await self.stream_create('foo', '0.01') await self.stream_create('foo', '0.01')
file_list = self.sout(self.daemon.jsonrpc_file_list()) file_list = self.sout(self.daemon.jsonrpc_file_list())

View file

@ -21,12 +21,12 @@ def get_mock_node(loop):
class TestManagedStream(BlobExchangeTestBase): class TestManagedStream(BlobExchangeTestBase):
async def create_stream(self, blob_count: int = 10): async def create_stream(self, blob_count: int = 10, file_name='test_file'):
self.stream_bytes = b'' self.stream_bytes = b''
for _ in range(blob_count): for _ in range(blob_count):
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1)) self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
# create the stream # create the stream
file_path = os.path.join(self.server_dir, "test_file") file_path = os.path.join(self.server_dir, file_name)
with open(file_path, 'wb') as f: with open(file_path, 'wb') as f:
f.write(self.stream_bytes) f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path) descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
@ -39,13 +39,20 @@ class TestManagedStream(BlobExchangeTestBase):
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
) )
async def test_file_saves_with_valid_file_name(self): async def test_client_sanitizes_file_name(self):
await self._test_transfer_stream(10, file_name="Bitcoin can't be shut down or regulated?.mov") illegal_name = 't<?t_f:|<'
self.assertTrue(self.stream.output_file_exists) descriptor = await self.create_stream(file_name=illegal_name)
descriptor.suggested_file_name = illegal_name
self.stream = ManagedStream(
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
)
await self._test_transfer_stream(1, skip_setup=True)
self.assertTrue(self.stream.completed) self.assertTrue(self.stream.completed)
self.assertEqual(self.stream.file_name, "Bitcoin_cant_be_shut_down_or_regulated.mov") self.assertEqual(self.stream.file_name, 'tt_f')
self.assertEqual(self.stream.full_path, os.path.join(self.stream.download_directory, self.stream.file_name)) self.assertTrue(self.stream.output_file_exists)
self.assertTrue(os.path.isfile(self.stream.full_path)) self.assertTrue(os.path.isfile(self.stream.full_path))
self.assertEqual(self.stream.full_path, os.path.join(self.client_dir, 'tt_f'))
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, 'tt_f')))
async def test_status_file_completed(self): async def test_status_file_completed(self):
await self._test_transfer_stream(10) await self._test_transfer_stream(10)
@ -57,8 +64,9 @@ class TestManagedStream(BlobExchangeTestBase):
self.assertFalse(self.stream.completed) self.assertFalse(self.stream.completed)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True, async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True,
file_name=None): skip_setup=False):
await self.setup_stream(blob_count) if not skip_setup:
await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node) mock_node = mock.Mock(spec=Node)
def _mock_accumulate_peers(q1, q2): def _mock_accumulate_peers(q1, q2):
@ -68,7 +76,7 @@ class TestManagedStream(BlobExchangeTestBase):
return q2, self.loop.create_task(_task()) return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
await self.stream.save_file(node=mock_node, file_name=file_name) await self.stream.save_file(node=mock_node)
await self.stream.finished_write_attempt.wait() await self.stream.finished_write_attempt.wait()
self.assertTrue(os.path.isfile(self.stream.full_path)) self.assertTrue(os.path.isfile(self.stream.full_path))
if stop_when_done: if stop_when_done:

View file

@ -10,7 +10,7 @@ from lbry.conf import Config
from lbry.error import InvalidStreamDescriptorError from lbry.error import InvalidStreamDescriptorError
from lbry.extras.daemon.storage import SQLiteStorage from lbry.extras.daemon.storage import SQLiteStorage
from lbry.blob.blob_manager import BlobManager from lbry.blob.blob_manager import BlobManager
from lbry.stream.descriptor import StreamDescriptor from lbry.stream.descriptor import StreamDescriptor, sanitize_file_name
class TestStreamDescriptor(AsyncioTestCase): class TestStreamDescriptor(AsyncioTestCase):
@ -78,6 +78,11 @@ class TestStreamDescriptor(AsyncioTestCase):
self.sd_dict['blobs'][-2]['length'] = 0 self.sd_dict['blobs'][-2]['length'] = 0
await self._test_invalid_sd() await self._test_invalid_sd()
async def test_sanitize_file_name(self):
    """Check sanitize_file_name against a table of (dirty, expected) names."""
    cases = (
        (' t/-?t|.g.ext ', 't-t.g.ext'),
        ('end_me .', 'end_me'),
        ('', ''),
        ('.file', '.file'),
        ('test name.ext', 'test name.ext'),
        ('COM8', ''),
        ('LPT2', ''),
    )
    expected = [clean for _, clean in cases]
    actual = [sanitize_file_name(dirty) for dirty, _ in cases]
    self.assertListEqual(actual, expected)
class TestRecoverOldStreamDescriptors(AsyncioTestCase): class TestRecoverOldStreamDescriptors(AsyncioTestCase):
async def test_old_key_sort_sd_blob(self): async def test_old_key_sort_sd_blob(self):