From ed93c647d26c5d0f41a93d20f7f49655dd2a3d95 Mon Sep 17 00:00:00 2001
From: Miroslav Kovar
Date: Mon, 23 Sep 2019 17:48:36 +0200
Subject: [PATCH] Sanitize file names before saving

Strip surrounding whitespace, replace spaces with underscores and drop
every character other than word characters, '-' and '.' from a stream's
file name before it is stored or written to disk.

Assigning None to ManagedStream.file_name clears the stored name instead
of storing the literal string "None" (str(None)), which the timeout path
in _save_file relies on. The requested name is sanitized before the next
free file name is looked up, so the collision check runs against the
name that is actually written.
---
 lbry/lbry/stream/managed_stream.py            | 25 ++++++++++++++-----
 lbry/tests/unit/stream/test_managed_stream.py | 13 ++++++++--
 2 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/lbry/lbry/stream/managed_stream.py b/lbry/lbry/stream/managed_stream.py
index 187084081..1d22b9113 100644
--- a/lbry/lbry/stream/managed_stream.py
+++ b/lbry/lbry/stream/managed_stream.py
@@ -3,6 +3,7 @@ import asyncio
 import typing
 import logging
 import binascii
+import re
 from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable
 from lbry.utils import generate_id
 from lbry.error import DownloadSDTimeout
@@ -39,6 +40,12 @@ async def get_next_available_file_name(loop: asyncio.AbstractEventLoop, download
     return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name)
 
 
+def sanitize_file_name(file_name):
+    """Return file_name stripped, with spaces as underscores and only word characters, '-' and '.' kept."""
+    file_name = str(file_name).strip().replace(' ', '_')
+    return re.sub(r'(?u)[^-\w.]', '', file_name)
+
+
 class ManagedStream:
     STATUS_RUNNING = "running"
     STATUS_STOPPED = "stopped"
@@ -115,7 +122,13 @@ class ManagedStream:
 
     @property
     def file_name(self) -> typing.Optional[str]:
-        return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None)
+        return self._file_name or \
+            (sanitize_file_name(self.descriptor.suggested_file_name) if self.descriptor else None)
+
+    @file_name.setter
+    def file_name(self, value):
+        # None clears the name; sanitizing it would store the literal string "None"
+        self._file_name = sanitize_file_name(value) if value is not None else None
 
     @property
     def status(self) -> str:
@@ -250,7 +263,7 @@ class ManagedStream:
         self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
         if not await self.blob_manager.storage.file_exists(self.sd_hash):
             if save_now:
-                file_name, download_dir = self._file_name, self.download_directory
+                file_name, download_dir = self.file_name, self.download_directory
             else:
                 file_name, download_dir = None, None
             self.rowid = await self.blob_manager.storage.save_downloaded_file(
@@ -360,7 +373,7 @@ class ManagedStream:
                 await self.blob_manager.storage.change_file_download_dir_and_file_name(
                     self.stream_hash, None, None
                 )
-                self._file_name, self.download_directory = None, None
+                self.file_name, self.download_directory = None, None
                 await self.blob_manager.storage.clear_saved_file(self.stream_hash)
                 await self.update_status(self.STATUS_STOPPED)
                 return
@@ -379,12 +392,12 @@ class ManagedStream:
         self.download_directory = download_directory or self.download_directory or self.config.download_dir
         if not self.download_directory:
             raise ValueError("no directory to download to")
-        if not (file_name or self._file_name or self.descriptor.suggested_file_name):
+        if not (file_name or self.file_name):
             raise ValueError("no file name to download to")
         if not os.path.isdir(self.download_directory):
             log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
             os.mkdir(self.download_directory)
-        self._file_name = await get_next_available_file_name(
+        self.file_name = await get_next_available_file_name(
             self.loop, self.download_directory,
-            file_name or self.descriptor.suggested_file_name
+            sanitize_file_name(file_name or self.descriptor.suggested_file_name)
         )
diff --git a/lbry/tests/unit/stream/test_managed_stream.py b/lbry/tests/unit/stream/test_managed_stream.py
index fb58a2168..ac31c3b9e 100644
--- a/lbry/tests/unit/stream/test_managed_stream.py
+++ b/lbry/tests/unit/stream/test_managed_stream.py
@@ -39,6 +39,14 @@ class TestManagedStream(BlobExchangeTestBase):
             self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
         )
 
+    async def test_file_saves_with_valid_file_name(self):
+        await self._test_transfer_stream(10, file_name="Bitcoin can't be shut down or regulated?.mov")
+        self.assertTrue(self.stream.output_file_exists)
+        self.assertTrue(self.stream.completed)
+        self.assertEqual(self.stream.file_name, "Bitcoin_cant_be_shut_down_or_regulated.mov")
+        self.assertEqual(self.stream.full_path, os.path.join(self.stream.download_directory, self.stream.file_name))
+        self.assertTrue(os.path.isfile(self.stream.full_path))
+
     async def test_status_file_completed(self):
         await self._test_transfer_stream(10)
         self.assertTrue(self.stream.output_file_exists)
@@ -48,7 +56,8 @@ class TestManagedStream(BlobExchangeTestBase):
         self.assertTrue(self.stream.output_file_exists)
         self.assertFalse(self.stream.completed)
 
-    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True):
+    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True,
+                                    file_name=None):
         await self.setup_stream(blob_count)
         mock_node = mock.Mock(spec=Node)
 
@@ -59,7 +68,7 @@ class TestManagedStream(BlobExchangeTestBase):
             return q2, self.loop.create_task(_task())
 
         mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
-        await self.stream.save_file(node=mock_node)
+        await self.stream.save_file(node=mock_node, file_name=file_name)
         await self.stream.finished_write_attempt.wait()
         self.assertTrue(os.path.isfile(self.stream.full_path))
         if stop_when_done: