Sanitize file names before saving

This commit is contained in:
Miroslav Kovar 2019-09-23 17:48:36 +02:00 committed by Lex Berezhny
parent f60db358ee
commit ed93c647d2
2 changed files with 27 additions and 7 deletions

View file

@ -3,6 +3,7 @@ import asyncio
import typing import typing
import logging import logging
import binascii import binascii
import re
from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable
from lbry.utils import generate_id from lbry.utils import generate_id
from lbry.error import DownloadSDTimeout from lbry.error import DownloadSDTimeout
@ -39,6 +40,11 @@ async def get_next_available_file_name(loop: asyncio.AbstractEventLoop, download
return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name) return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name)
def sanitize_file_name(file_name):
file_name = str(file_name).strip().replace(' ', '_')
return re.sub(r'(?u)[^-\w.]', '', file_name)
class ManagedStream: class ManagedStream:
STATUS_RUNNING = "running" STATUS_RUNNING = "running"
STATUS_STOPPED = "stopped" STATUS_STOPPED = "stopped"
@ -115,7 +121,12 @@ class ManagedStream:
@property @property
def file_name(self) -> typing.Optional[str]: def file_name(self) -> typing.Optional[str]:
return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None) return self._file_name or \
(sanitize_file_name(self.descriptor.suggested_file_name) if self.descriptor else None)
@file_name.setter
def file_name(self, value):
self._file_name = sanitize_file_name(value)
@property @property
def status(self) -> str: def status(self) -> str:
@ -250,7 +261,7 @@ class ManagedStream:
self.delayed_stop_task = self.loop.create_task(self._delayed_stop()) self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
if not await self.blob_manager.storage.file_exists(self.sd_hash): if not await self.blob_manager.storage.file_exists(self.sd_hash):
if save_now: if save_now:
file_name, download_dir = self._file_name, self.download_directory file_name, download_dir = self.file_name, self.download_directory
else: else:
file_name, download_dir = None, None file_name, download_dir = None, None
self.rowid = await self.blob_manager.storage.save_downloaded_file( self.rowid = await self.blob_manager.storage.save_downloaded_file(
@ -360,7 +371,7 @@ class ManagedStream:
await self.blob_manager.storage.change_file_download_dir_and_file_name( await self.blob_manager.storage.change_file_download_dir_and_file_name(
self.stream_hash, None, None self.stream_hash, None, None
) )
self._file_name, self.download_directory = None, None self.file_name, self.download_directory = None, None
await self.blob_manager.storage.clear_saved_file(self.stream_hash) await self.blob_manager.storage.clear_saved_file(self.stream_hash)
await self.update_status(self.STATUS_STOPPED) await self.update_status(self.STATUS_STOPPED)
return return
@ -379,12 +390,12 @@ class ManagedStream:
self.download_directory = download_directory or self.download_directory or self.config.download_dir self.download_directory = download_directory or self.download_directory or self.config.download_dir
if not self.download_directory: if not self.download_directory:
raise ValueError("no directory to download to") raise ValueError("no directory to download to")
if not (file_name or self._file_name or self.descriptor.suggested_file_name): if not (file_name or self.file_name):
raise ValueError("no file name to download to") raise ValueError("no file name to download to")
if not os.path.isdir(self.download_directory): if not os.path.isdir(self.download_directory):
log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory) log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
os.mkdir(self.download_directory) os.mkdir(self.download_directory)
self._file_name = await get_next_available_file_name( self.file_name = await get_next_available_file_name(
self.loop, self.download_directory, self.loop, self.download_directory,
file_name or self.descriptor.suggested_file_name file_name or self.descriptor.suggested_file_name
) )

View file

@ -39,6 +39,14 @@ class TestManagedStream(BlobExchangeTestBase):
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
) )
async def test_file_saves_with_valid_file_name(self):
await self._test_transfer_stream(10, file_name="Bitcoin can't be shut down or regulated?.mov")
self.assertTrue(self.stream.output_file_exists)
self.assertTrue(self.stream.completed)
self.assertEqual(self.stream.file_name, "Bitcoin_cant_be_shut_down_or_regulated.mov")
self.assertEqual(self.stream.full_path, os.path.join(self.stream.download_directory, self.stream.file_name))
self.assertTrue(os.path.isfile(self.stream.full_path))
async def test_status_file_completed(self): async def test_status_file_completed(self):
await self._test_transfer_stream(10) await self._test_transfer_stream(10)
self.assertTrue(self.stream.output_file_exists) self.assertTrue(self.stream.output_file_exists)
@ -48,7 +56,8 @@ class TestManagedStream(BlobExchangeTestBase):
self.assertTrue(self.stream.output_file_exists) self.assertTrue(self.stream.output_file_exists)
self.assertFalse(self.stream.completed) self.assertFalse(self.stream.completed)
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True): async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True,
file_name=None):
await self.setup_stream(blob_count) await self.setup_stream(blob_count)
mock_node = mock.Mock(spec=Node) mock_node = mock.Mock(spec=Node)
@ -59,7 +68,7 @@ class TestManagedStream(BlobExchangeTestBase):
return q2, self.loop.create_task(_task()) return q2, self.loop.create_task(_task())
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
await self.stream.save_file(node=mock_node) await self.stream.save_file(node=mock_node, file_name=file_name)
await self.stream.finished_write_attempt.wait() await self.stream.finished_write_attempt.wait()
self.assertTrue(os.path.isfile(self.stream.full_path)) self.assertTrue(os.path.isfile(self.stream.full_path))
if stop_when_done: if stop_when_done: