From 268f6447cb0affca824f932a879f1f2ece4a16a3 Mon Sep 17 00:00:00 2001 From: Miroslav Kovar <miroslavkovar@protonmail.com> Date: Tue, 8 Oct 2019 20:03:27 +0200 Subject: [PATCH] Sanitize suggested, add tests --- lbry/lbry/stream/descriptor.py | 26 +++++++++++++- lbry/lbry/stream/managed_stream.py | 23 ++++--------- lbry/lbry/testcase.py | 5 +-- lbry/tests/integration/test_file_commands.py | 34 +++++++++++++++++++ lbry/tests/unit/stream/test_managed_stream.py | 28 +++++++++------ .../unit/stream/test_stream_descriptor.py | 7 +++- 6 files changed, 92 insertions(+), 31 deletions(-) diff --git a/lbry/lbry/stream/descriptor.py b/lbry/lbry/stream/descriptor.py index b23df15b2..d071392f5 100644 --- a/lbry/lbry/stream/descriptor.py +++ b/lbry/lbry/stream/descriptor.py @@ -4,6 +4,7 @@ import binascii import logging import typing import asyncio +import re from collections import OrderedDict from cryptography.hazmat.primitives.ciphers.algorithms import AES from lbry.blob import MAX_BLOB_SIZE @@ -46,6 +47,29 @@ def file_reader(file_path: str): offset += bytes_to_read +def sanitize_file_name(dirty_name: str): + RE_IL_CHARS = re.compile(r'[<>:"/\\\|\?\*]') + RE_DOTS_AT_END = re.compile(r'[ \t]*(\.)+[ \t]$') + RE_LEAD_TRAIL_SPACE = re.compile(r'(^[ \t]+|[ \t]+$)') + res_il_names = \ + [re.compile(regex) for regex in (r'^CON$', r'^PRN$', r'^AUX$', r'^NUL$', r'^COM[1-9]$', r'^LPT[1-9]$')] + + file_name = re.sub(RE_IL_CHARS, '', dirty_name) + file_name, ext = os.path.splitext(file_name) + file_name = re.sub(RE_LEAD_TRAIL_SPACE, '', file_name) + file_name = re.sub(RE_DOTS_AT_END, '', file_name) + ext = re.sub(RE_LEAD_TRAIL_SPACE, '', ext) + if any((REGEX.match(file_name) for REGEX in res_il_names)): + file_name = '' + elif file_name and len(ext) > 1: + file_name += ext + + if len(file_name) == 0: + log.warning('Unable to suggest a file name for %s', dirty_name) + + return file_name + + class StreamDescriptor: __slots__ = [ 'loop', @@ -65,7 +89,7 @@ class StreamDescriptor: self.blob_dir = blob_dir self.stream_name = stream_name self.key = key - self.suggested_file_name = suggested_file_name + self.suggested_file_name = sanitize_file_name(suggested_file_name) self.blobs = blobs self.stream_hash = stream_hash or self.get_stream_hash() self.sd_hash = sd_hash diff --git a/lbry/lbry/stream/managed_stream.py b/lbry/lbry/stream/managed_stream.py index 1d22b9113..16a56ff56 100644 --- a/lbry/lbry/stream/managed_stream.py +++ b/lbry/lbry/stream/managed_stream.py @@ -3,7 +3,6 @@ import asyncio import typing import logging import binascii -import re from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable from lbry.utils import generate_id from lbry.error import DownloadSDTimeout @@ -40,11 +39,6 @@ async def get_next_available_file_name(loop: asyncio.AbstractEventLoop, download return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name) -def sanitize_file_name(file_name): - file_name = str(file_name).strip().replace(' ', '_') - return re.sub(r'(?u)[^-\w.]', '', file_name) - - class ManagedStream: STATUS_RUNNING = "running" STATUS_STOPPED = "stopped" @@ -121,12 +115,7 @@ class ManagedStream: @property def file_name(self) -> typing.Optional[str]: - return self._file_name or \ - (sanitize_file_name(self.descriptor.suggested_file_name) if self.descriptor else None) - - @file_name.setter - def file_name(self, value): - self._file_name = sanitize_file_name(value) + return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None) @property def status(self) -> str: @@ -261,7 +250,7 @@ class ManagedStream: self.delayed_stop_task = self.loop.create_task(self._delayed_stop()) if not await self.blob_manager.storage.file_exists(self.sd_hash): if save_now: - file_name, download_dir = self.file_name, self.download_directory + file_name, download_dir = self._file_name, self.download_directory else: file_name, download_dir = None, None self.rowid = await self.blob_manager.storage.save_downloaded_file( @@ -371,7 +360,7 @@ class ManagedStream: await self.blob_manager.storage.change_file_download_dir_and_file_name( self.stream_hash, None, None ) - self.file_name, self.download_directory = None, None + self._file_name, self.download_directory = None, None await self.blob_manager.storage.clear_saved_file(self.stream_hash) await self.update_status(self.STATUS_STOPPED) return @@ -390,14 +379,14 @@ class ManagedStream: self.download_directory = download_directory or self.download_directory or self.config.download_dir if not self.download_directory: raise ValueError("no directory to download to") - if not (file_name or self.file_name): + if not (file_name or self._file_name or self.descriptor.suggested_file_name): raise ValueError("no file name to download to") if not os.path.isdir(self.download_directory): log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory) os.mkdir(self.download_directory) - self.file_name = await get_next_available_file_name( + self._file_name = await get_next_available_file_name( self.loop, self.download_directory, - file_name or self.descriptor.suggested_file_name + file_name or self.file_name ) await self.blob_manager.storage.change_file_download_dir_and_file_name( self.stream_hash, self.download_directory, self.file_name diff --git a/lbry/lbry/testcase.py b/lbry/lbry/testcase.py index 0e13971a9..04fed3114 100644 --- a/lbry/lbry/testcase.py +++ b/lbry/lbry/testcase.py @@ -197,8 +197,9 @@ class CommandTestCase(IntegrationTestCase): """ Synchronous version of `out` method. """ return json.loads(jsonrpc_dumps_pretty(value, ledger=self.ledger))['result'] - async def stream_create(self, name='hovercraft', bid='1.0', data=b'hi!', confirm=True, **kwargs): - file = tempfile.NamedTemporaryFile() + async def stream_create(self, name='hovercraft', bid='1.0', data=b'hi!', confirm=True, + prefix=None, suffix=None, **kwargs): + file = tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix) def cleanup(): try: diff --git a/lbry/tests/integration/test_file_commands.py b/lbry/tests/integration/test_file_commands.py index 59b31d1de..be1a0c074 100644 --- a/lbry/tests/integration/test_file_commands.py +++ b/lbry/tests/integration/test_file_commands.py @@ -45,6 +45,40 @@ class FileCommands(CommandTestCase): {stream.sd_hash, stream.descriptor.blobs[0].blob_hash} ) + async def test_publish_with_illegal_chars(self): + # Stream a file with file name containing invalid chars + prefix = '?an|t:z< m<' + suffix = '.ext.' + san_prefix = 'antz m' + san_suffix = '.ext' + claim = await self.stream_create('foo', '0.01', prefix=prefix, suffix=suffix) + stream = self.daemon.jsonrpc_file_list()[0] + + # Check that file list and source contains the local unsanitized name, but suggested name is sanitized + source_file_name = claim['outputs'][0]['value']['source']['name'] + file_list_name = stream.file_name + suggested_file_name = stream.descriptor.suggested_file_name + self.assertTrue(source_file_name.startswith(prefix)) + self.assertTrue(source_file_name.endswith(suffix)) + self.assertEqual(file_list_name, source_file_name) + self.assertTrue(suggested_file_name.startswith(san_prefix)) + self.assertTrue(suggested_file_name.endswith(san_suffix)) + + # Delete the file, re-download and check that the file name is sanitized + self.assertTrue(await self.daemon.jsonrpc_file_delete(claim_name='foo')) + full_path = (await self.daemon.jsonrpc_get('lbry://foo', save_file=True)).full_path + file_name = os.path.basename(full_path) + self.assertTrue(os.path.isfile(full_path)) + self.assertTrue(file_name.startswith(san_prefix)) + self.assertTrue(file_name.endswith(san_suffix)) + + # Check that the downloaded file name is not sanitized when user provides custom file name + self.assertTrue(await self.daemon.jsonrpc_file_delete(claim_name='foo')) + file_name = 'my <u?|*m.name' + full_path = (await self.daemon.jsonrpc_get('lbry://foo', file_name=file_name, save_file=True)).full_path + self.assertTrue(os.path.isfile(full_path)) + self.assertEqual(file_name, os.path.basename(full_path)) + async def test_file_list_fields(self): await self.stream_create('foo', '0.01') file_list = self.sout(self.daemon.jsonrpc_file_list()) diff --git a/lbry/tests/unit/stream/test_managed_stream.py b/lbry/tests/unit/stream/test_managed_stream.py index ac31c3b9e..a65c61141 100644 --- a/lbry/tests/unit/stream/test_managed_stream.py +++ b/lbry/tests/unit/stream/test_managed_stream.py @@ -21,12 +21,12 @@ def get_mock_node(loop): class TestManagedStream(BlobExchangeTestBase): - async def create_stream(self, blob_count: int = 10): + async def create_stream(self, blob_count: int = 10, file_name='test_file'): self.stream_bytes = b'' for _ in range(blob_count): self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1)) # create the stream - file_path = os.path.join(self.server_dir, "test_file") + file_path = os.path.join(self.server_dir, file_name) with open(file_path, 'wb') as f: f.write(self.stream_bytes) descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path) @@ -39,13 +39,20 @@ class TestManagedStream(BlobExchangeTestBase): self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir ) - async def test_file_saves_with_valid_file_name(self): - await self._test_transfer_stream(10, file_name="Bitcoin can't be shut down or regulated?.mov") - self.assertTrue(self.stream.output_file_exists) + async def test_client_sanitizes_file_name(self): + illegal_name = 't<?t_f:|<' + descriptor = await self.create_stream(file_name=illegal_name) + descriptor.suggested_file_name = illegal_name + self.stream = ManagedStream( + self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir + ) + await self._test_transfer_stream(1, skip_setup=True) self.assertTrue(self.stream.completed) - self.assertEqual(self.stream.file_name, "Bitcoin_cant_be_shut_down_or_regulated.mov") - self.assertEqual(self.stream.full_path, os.path.join(self.stream.download_directory, self.stream.file_name)) + self.assertEqual(self.stream.file_name, 'tt_f') + self.assertTrue(self.stream.output_file_exists) self.assertTrue(os.path.isfile(self.stream.full_path)) + self.assertEqual(self.stream.full_path, os.path.join(self.client_dir, 'tt_f')) + self.assertTrue(os.path.isfile(os.path.join(self.client_dir, 'tt_f'))) async def test_status_file_completed(self): await self._test_transfer_stream(10) @@ -57,8 +64,9 @@ class TestManagedStream(BlobExchangeTestBase): self.assertFalse(self.stream.completed) async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True, - file_name=None): - await self.setup_stream(blob_count) + skip_setup=False): + if not skip_setup: + await self.setup_stream(blob_count) mock_node = mock.Mock(spec=Node) def _mock_accumulate_peers(q1, q2): @@ -68,7 +76,7 @@ class TestManagedStream(BlobExchangeTestBase): return q2, self.loop.create_task(_task()) mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers - await self.stream.save_file(node=mock_node, file_name=file_name) + await self.stream.save_file(node=mock_node) await self.stream.finished_write_attempt.wait() self.assertTrue(os.path.isfile(self.stream.full_path)) if stop_when_done: diff --git a/lbry/tests/unit/stream/test_stream_descriptor.py b/lbry/tests/unit/stream/test_stream_descriptor.py index 09d230ddc..eb7783459 100644 --- a/lbry/tests/unit/stream/test_stream_descriptor.py +++ b/lbry/tests/unit/stream/test_stream_descriptor.py @@ -10,7 +10,7 @@ from lbry.conf import Config from lbry.error import InvalidStreamDescriptorError from lbry.extras.daemon.storage import SQLiteStorage from lbry.blob.blob_manager import BlobManager -from lbry.stream.descriptor import StreamDescriptor +from lbry.stream.descriptor import StreamDescriptor, sanitize_file_name class TestStreamDescriptor(AsyncioTestCase): @@ -78,6 +78,11 @@ class TestStreamDescriptor(AsyncioTestCase): self.sd_dict['blobs'][-2]['length'] = 0 await self._test_invalid_sd() + async def test_sanitize_file_name(self): + test_cases = [' t/-?t|.g.ext ', 'end_me .', '', '.file', 'test name.ext', 'COM8', 'LPT2'] + expected = ['t-t.g.ext', 'end_me', '', '.file', 'test name.ext', '', ''] + actual = [sanitize_file_name(tc) for tc in test_cases] + self.assertListEqual(actual, expected) class TestRecoverOldStreamDescriptors(AsyncioTestCase): async def test_old_key_sort_sd_blob(self):