Sanitize file names before saving
This commit is contained in:
parent
f60db358ee
commit
ed93c647d2
2 changed files with 27 additions and 7 deletions
|
@ -3,6 +3,7 @@ import asyncio
|
|||
import typing
|
||||
import logging
|
||||
import binascii
|
||||
import re
|
||||
from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable
|
||||
from lbry.utils import generate_id
|
||||
from lbry.error import DownloadSDTimeout
|
||||
|
@ -39,6 +40,11 @@ async def get_next_available_file_name(loop: asyncio.AbstractEventLoop, download
|
|||
return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name)
|
||||
|
||||
|
||||
def sanitize_file_name(file_name):
|
||||
file_name = str(file_name).strip().replace(' ', '_')
|
||||
return re.sub(r'(?u)[^-\w.]', '', file_name)
|
||||
|
||||
|
||||
class ManagedStream:
|
||||
STATUS_RUNNING = "running"
|
||||
STATUS_STOPPED = "stopped"
|
||||
|
@ -115,7 +121,12 @@ class ManagedStream:
|
|||
|
||||
@property
|
||||
def file_name(self) -> typing.Optional[str]:
|
||||
return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None)
|
||||
return self._file_name or \
|
||||
(sanitize_file_name(self.descriptor.suggested_file_name) if self.descriptor else None)
|
||||
|
||||
@file_name.setter
|
||||
def file_name(self, value):
|
||||
self._file_name = sanitize_file_name(value)
|
||||
|
||||
@property
|
||||
def status(self) -> str:
|
||||
|
@ -250,7 +261,7 @@ class ManagedStream:
|
|||
self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
|
||||
if not await self.blob_manager.storage.file_exists(self.sd_hash):
|
||||
if save_now:
|
||||
file_name, download_dir = self._file_name, self.download_directory
|
||||
file_name, download_dir = self.file_name, self.download_directory
|
||||
else:
|
||||
file_name, download_dir = None, None
|
||||
self.rowid = await self.blob_manager.storage.save_downloaded_file(
|
||||
|
@ -360,7 +371,7 @@ class ManagedStream:
|
|||
await self.blob_manager.storage.change_file_download_dir_and_file_name(
|
||||
self.stream_hash, None, None
|
||||
)
|
||||
self._file_name, self.download_directory = None, None
|
||||
self.file_name, self.download_directory = None, None
|
||||
await self.blob_manager.storage.clear_saved_file(self.stream_hash)
|
||||
await self.update_status(self.STATUS_STOPPED)
|
||||
return
|
||||
|
@ -379,12 +390,12 @@ class ManagedStream:
|
|||
self.download_directory = download_directory or self.download_directory or self.config.download_dir
|
||||
if not self.download_directory:
|
||||
raise ValueError("no directory to download to")
|
||||
if not (file_name or self._file_name or self.descriptor.suggested_file_name):
|
||||
if not (file_name or self.file_name):
|
||||
raise ValueError("no file name to download to")
|
||||
if not os.path.isdir(self.download_directory):
|
||||
log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
|
||||
os.mkdir(self.download_directory)
|
||||
self._file_name = await get_next_available_file_name(
|
||||
self.file_name = await get_next_available_file_name(
|
||||
self.loop, self.download_directory,
|
||||
file_name or self.descriptor.suggested_file_name
|
||||
)
|
||||
|
|
|
@ -39,6 +39,14 @@ class TestManagedStream(BlobExchangeTestBase):
|
|||
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
|
||||
)
|
||||
|
||||
async def test_file_saves_with_valid_file_name(self):
|
||||
await self._test_transfer_stream(10, file_name="Bitcoin can't be shut down or regulated?.mov")
|
||||
self.assertTrue(self.stream.output_file_exists)
|
||||
self.assertTrue(self.stream.completed)
|
||||
self.assertEqual(self.stream.file_name, "Bitcoin_cant_be_shut_down_or_regulated.mov")
|
||||
self.assertEqual(self.stream.full_path, os.path.join(self.stream.download_directory, self.stream.file_name))
|
||||
self.assertTrue(os.path.isfile(self.stream.full_path))
|
||||
|
||||
async def test_status_file_completed(self):
|
||||
await self._test_transfer_stream(10)
|
||||
self.assertTrue(self.stream.output_file_exists)
|
||||
|
@ -48,7 +56,8 @@ class TestManagedStream(BlobExchangeTestBase):
|
|||
self.assertTrue(self.stream.output_file_exists)
|
||||
self.assertFalse(self.stream.completed)
|
||||
|
||||
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True):
|
||||
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True,
|
||||
file_name=None):
|
||||
await self.setup_stream(blob_count)
|
||||
mock_node = mock.Mock(spec=Node)
|
||||
|
||||
|
@ -59,7 +68,7 @@ class TestManagedStream(BlobExchangeTestBase):
|
|||
return q2, self.loop.create_task(_task())
|
||||
|
||||
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
|
||||
await self.stream.save_file(node=mock_node)
|
||||
await self.stream.save_file(node=mock_node, file_name=file_name)
|
||||
await self.stream.finished_write_attempt.wait()
|
||||
self.assertTrue(os.path.isfile(self.stream.full_path))
|
||||
if stop_when_done:
|
||||
|
|
Loading…
Add table
Reference in a new issue