Sanitize suggested, add tests
This commit is contained in:
parent
ed93c647d2
commit
268f6447cb
6 changed files with 92 additions and 31 deletions
|
@ -4,6 +4,7 @@ import binascii
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from cryptography.hazmat.primitives.ciphers.algorithms import AES
|
from cryptography.hazmat.primitives.ciphers.algorithms import AES
|
||||||
from lbry.blob import MAX_BLOB_SIZE
|
from lbry.blob import MAX_BLOB_SIZE
|
||||||
|
@ -46,6 +47,29 @@ def file_reader(file_path: str):
|
||||||
offset += bytes_to_read
|
offset += bytes_to_read
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_file_name(dirty_name: str):
|
||||||
|
RE_IL_CHARS = re.compile(r'[<>:"/\\\|\?\*]')
|
||||||
|
RE_DOTS_AT_END = re.compile(r'[ \t]*(\.)+[ \t]$')
|
||||||
|
RE_LEAD_TRAIL_SPACE = re.compile(r'(^[ \t]+|[ \t]+$)')
|
||||||
|
res_il_names = \
|
||||||
|
[re.compile(regex) for regex in (r'^CON$', r'^PRN$', r'^AUX$', r'^NUL$', r'^COM[1-9]$', r'^LPT[1-9]$')]
|
||||||
|
|
||||||
|
file_name = re.sub(RE_IL_CHARS, '', dirty_name)
|
||||||
|
file_name, ext = os.path.splitext(file_name)
|
||||||
|
file_name = re.sub(RE_LEAD_TRAIL_SPACE, '', file_name)
|
||||||
|
file_name = re.sub(RE_DOTS_AT_END, '', file_name)
|
||||||
|
ext = re.sub(RE_LEAD_TRAIL_SPACE, '', ext)
|
||||||
|
if any((REGEX.match(file_name) for REGEX in res_il_names)):
|
||||||
|
file_name = ''
|
||||||
|
elif file_name and len(ext) > 1:
|
||||||
|
file_name += ext
|
||||||
|
|
||||||
|
if len(file_name) == 0:
|
||||||
|
log.warning('Unable to suggest a file name for %s', dirty_name)
|
||||||
|
|
||||||
|
return file_name
|
||||||
|
|
||||||
|
|
||||||
class StreamDescriptor:
|
class StreamDescriptor:
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
'loop',
|
'loop',
|
||||||
|
@ -65,7 +89,7 @@ class StreamDescriptor:
|
||||||
self.blob_dir = blob_dir
|
self.blob_dir = blob_dir
|
||||||
self.stream_name = stream_name
|
self.stream_name = stream_name
|
||||||
self.key = key
|
self.key = key
|
||||||
self.suggested_file_name = suggested_file_name
|
self.suggested_file_name = sanitize_file_name(suggested_file_name)
|
||||||
self.blobs = blobs
|
self.blobs = blobs
|
||||||
self.stream_hash = stream_hash or self.get_stream_hash()
|
self.stream_hash = stream_hash or self.get_stream_hash()
|
||||||
self.sd_hash = sd_hash
|
self.sd_hash = sd_hash
|
||||||
|
|
|
@ -3,7 +3,6 @@ import asyncio
|
||||||
import typing
|
import typing
|
||||||
import logging
|
import logging
|
||||||
import binascii
|
import binascii
|
||||||
import re
|
|
||||||
from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable
|
from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable
|
||||||
from lbry.utils import generate_id
|
from lbry.utils import generate_id
|
||||||
from lbry.error import DownloadSDTimeout
|
from lbry.error import DownloadSDTimeout
|
||||||
|
@ -40,11 +39,6 @@ async def get_next_available_file_name(loop: asyncio.AbstractEventLoop, download
|
||||||
return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name)
|
return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name)
|
||||||
|
|
||||||
|
|
||||||
def sanitize_file_name(file_name):
|
|
||||||
file_name = str(file_name).strip().replace(' ', '_')
|
|
||||||
return re.sub(r'(?u)[^-\w.]', '', file_name)
|
|
||||||
|
|
||||||
|
|
||||||
class ManagedStream:
|
class ManagedStream:
|
||||||
STATUS_RUNNING = "running"
|
STATUS_RUNNING = "running"
|
||||||
STATUS_STOPPED = "stopped"
|
STATUS_STOPPED = "stopped"
|
||||||
|
@ -121,12 +115,7 @@ class ManagedStream:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def file_name(self) -> typing.Optional[str]:
|
def file_name(self) -> typing.Optional[str]:
|
||||||
return self._file_name or \
|
return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None)
|
||||||
(sanitize_file_name(self.descriptor.suggested_file_name) if self.descriptor else None)
|
|
||||||
|
|
||||||
@file_name.setter
|
|
||||||
def file_name(self, value):
|
|
||||||
self._file_name = sanitize_file_name(value)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def status(self) -> str:
|
def status(self) -> str:
|
||||||
|
@ -261,7 +250,7 @@ class ManagedStream:
|
||||||
self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
|
self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
|
||||||
if not await self.blob_manager.storage.file_exists(self.sd_hash):
|
if not await self.blob_manager.storage.file_exists(self.sd_hash):
|
||||||
if save_now:
|
if save_now:
|
||||||
file_name, download_dir = self.file_name, self.download_directory
|
file_name, download_dir = self._file_name, self.download_directory
|
||||||
else:
|
else:
|
||||||
file_name, download_dir = None, None
|
file_name, download_dir = None, None
|
||||||
self.rowid = await self.blob_manager.storage.save_downloaded_file(
|
self.rowid = await self.blob_manager.storage.save_downloaded_file(
|
||||||
|
@ -371,7 +360,7 @@ class ManagedStream:
|
||||||
await self.blob_manager.storage.change_file_download_dir_and_file_name(
|
await self.blob_manager.storage.change_file_download_dir_and_file_name(
|
||||||
self.stream_hash, None, None
|
self.stream_hash, None, None
|
||||||
)
|
)
|
||||||
self.file_name, self.download_directory = None, None
|
self._file_name, self.download_directory = None, None
|
||||||
await self.blob_manager.storage.clear_saved_file(self.stream_hash)
|
await self.blob_manager.storage.clear_saved_file(self.stream_hash)
|
||||||
await self.update_status(self.STATUS_STOPPED)
|
await self.update_status(self.STATUS_STOPPED)
|
||||||
return
|
return
|
||||||
|
@ -390,14 +379,14 @@ class ManagedStream:
|
||||||
self.download_directory = download_directory or self.download_directory or self.config.download_dir
|
self.download_directory = download_directory or self.download_directory or self.config.download_dir
|
||||||
if not self.download_directory:
|
if not self.download_directory:
|
||||||
raise ValueError("no directory to download to")
|
raise ValueError("no directory to download to")
|
||||||
if not (file_name or self.file_name):
|
if not (file_name or self._file_name or self.descriptor.suggested_file_name):
|
||||||
raise ValueError("no file name to download to")
|
raise ValueError("no file name to download to")
|
||||||
if not os.path.isdir(self.download_directory):
|
if not os.path.isdir(self.download_directory):
|
||||||
log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
|
log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
|
||||||
os.mkdir(self.download_directory)
|
os.mkdir(self.download_directory)
|
||||||
self.file_name = await get_next_available_file_name(
|
self._file_name = await get_next_available_file_name(
|
||||||
self.loop, self.download_directory,
|
self.loop, self.download_directory,
|
||||||
file_name or self.descriptor.suggested_file_name
|
file_name or self.file_name
|
||||||
)
|
)
|
||||||
await self.blob_manager.storage.change_file_download_dir_and_file_name(
|
await self.blob_manager.storage.change_file_download_dir_and_file_name(
|
||||||
self.stream_hash, self.download_directory, self.file_name
|
self.stream_hash, self.download_directory, self.file_name
|
||||||
|
|
|
@ -197,8 +197,9 @@ class CommandTestCase(IntegrationTestCase):
|
||||||
""" Synchronous version of `out` method. """
|
""" Synchronous version of `out` method. """
|
||||||
return json.loads(jsonrpc_dumps_pretty(value, ledger=self.ledger))['result']
|
return json.loads(jsonrpc_dumps_pretty(value, ledger=self.ledger))['result']
|
||||||
|
|
||||||
async def stream_create(self, name='hovercraft', bid='1.0', data=b'hi!', confirm=True, **kwargs):
|
async def stream_create(self, name='hovercraft', bid='1.0', data=b'hi!', confirm=True,
|
||||||
file = tempfile.NamedTemporaryFile()
|
prefix=None, suffix=None, **kwargs):
|
||||||
|
file = tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix)
|
||||||
|
|
||||||
def cleanup():
|
def cleanup():
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -45,6 +45,40 @@ class FileCommands(CommandTestCase):
|
||||||
{stream.sd_hash, stream.descriptor.blobs[0].blob_hash}
|
{stream.sd_hash, stream.descriptor.blobs[0].blob_hash}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def test_publish_with_illegal_chars(self):
|
||||||
|
# Stream a file with file name containing invalid chars
|
||||||
|
prefix = '?an|t:z< m<'
|
||||||
|
suffix = '.ext.'
|
||||||
|
san_prefix = 'antz m'
|
||||||
|
san_suffix = '.ext'
|
||||||
|
claim = await self.stream_create('foo', '0.01', prefix=prefix, suffix=suffix)
|
||||||
|
stream = self.daemon.jsonrpc_file_list()[0]
|
||||||
|
|
||||||
|
# Check that file list and source contains the local unsanitized name, but suggested name is sanitized
|
||||||
|
source_file_name = claim['outputs'][0]['value']['source']['name']
|
||||||
|
file_list_name = stream.file_name
|
||||||
|
suggested_file_name = stream.descriptor.suggested_file_name
|
||||||
|
self.assertTrue(source_file_name.startswith(prefix))
|
||||||
|
self.assertTrue(source_file_name.endswith(suffix))
|
||||||
|
self.assertEqual(file_list_name, source_file_name)
|
||||||
|
self.assertTrue(suggested_file_name.startswith(san_prefix))
|
||||||
|
self.assertTrue(suggested_file_name.endswith(san_suffix))
|
||||||
|
|
||||||
|
# Delete the file, re-download and check that the file name is sanitized
|
||||||
|
self.assertTrue(await self.daemon.jsonrpc_file_delete(claim_name='foo'))
|
||||||
|
full_path = (await self.daemon.jsonrpc_get('lbry://foo', save_file=True)).full_path
|
||||||
|
file_name = os.path.basename(full_path)
|
||||||
|
self.assertTrue(os.path.isfile(full_path))
|
||||||
|
self.assertTrue(file_name.startswith(san_prefix))
|
||||||
|
self.assertTrue(file_name.endswith(san_suffix))
|
||||||
|
|
||||||
|
# Check that the downloaded file name is not sanitized when user provides custom file name
|
||||||
|
self.assertTrue(await self.daemon.jsonrpc_file_delete(claim_name='foo'))
|
||||||
|
file_name = 'my <u?|*m.name'
|
||||||
|
full_path = (await self.daemon.jsonrpc_get('lbry://foo', file_name=file_name, save_file=True)).full_path
|
||||||
|
self.assertTrue(os.path.isfile(full_path))
|
||||||
|
self.assertEqual(file_name, os.path.basename(full_path))
|
||||||
|
|
||||||
async def test_file_list_fields(self):
|
async def test_file_list_fields(self):
|
||||||
await self.stream_create('foo', '0.01')
|
await self.stream_create('foo', '0.01')
|
||||||
file_list = self.sout(self.daemon.jsonrpc_file_list())
|
file_list = self.sout(self.daemon.jsonrpc_file_list())
|
||||||
|
|
|
@ -21,12 +21,12 @@ def get_mock_node(loop):
|
||||||
|
|
||||||
|
|
||||||
class TestManagedStream(BlobExchangeTestBase):
|
class TestManagedStream(BlobExchangeTestBase):
|
||||||
async def create_stream(self, blob_count: int = 10):
|
async def create_stream(self, blob_count: int = 10, file_name='test_file'):
|
||||||
self.stream_bytes = b''
|
self.stream_bytes = b''
|
||||||
for _ in range(blob_count):
|
for _ in range(blob_count):
|
||||||
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
|
self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
|
||||||
# create the stream
|
# create the stream
|
||||||
file_path = os.path.join(self.server_dir, "test_file")
|
file_path = os.path.join(self.server_dir, file_name)
|
||||||
with open(file_path, 'wb') as f:
|
with open(file_path, 'wb') as f:
|
||||||
f.write(self.stream_bytes)
|
f.write(self.stream_bytes)
|
||||||
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
|
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
|
||||||
|
@ -39,13 +39,20 @@ class TestManagedStream(BlobExchangeTestBase):
|
||||||
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
|
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_file_saves_with_valid_file_name(self):
|
async def test_client_sanitizes_file_name(self):
|
||||||
await self._test_transfer_stream(10, file_name="Bitcoin can't be shut down or regulated?.mov")
|
illegal_name = 't<?t_f:|<'
|
||||||
self.assertTrue(self.stream.output_file_exists)
|
descriptor = await self.create_stream(file_name=illegal_name)
|
||||||
|
descriptor.suggested_file_name = illegal_name
|
||||||
|
self.stream = ManagedStream(
|
||||||
|
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir
|
||||||
|
)
|
||||||
|
await self._test_transfer_stream(1, skip_setup=True)
|
||||||
self.assertTrue(self.stream.completed)
|
self.assertTrue(self.stream.completed)
|
||||||
self.assertEqual(self.stream.file_name, "Bitcoin_cant_be_shut_down_or_regulated.mov")
|
self.assertEqual(self.stream.file_name, 'tt_f')
|
||||||
self.assertEqual(self.stream.full_path, os.path.join(self.stream.download_directory, self.stream.file_name))
|
self.assertTrue(self.stream.output_file_exists)
|
||||||
self.assertTrue(os.path.isfile(self.stream.full_path))
|
self.assertTrue(os.path.isfile(self.stream.full_path))
|
||||||
|
self.assertEqual(self.stream.full_path, os.path.join(self.client_dir, 'tt_f'))
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, 'tt_f')))
|
||||||
|
|
||||||
async def test_status_file_completed(self):
|
async def test_status_file_completed(self):
|
||||||
await self._test_transfer_stream(10)
|
await self._test_transfer_stream(10)
|
||||||
|
@ -57,8 +64,9 @@ class TestManagedStream(BlobExchangeTestBase):
|
||||||
self.assertFalse(self.stream.completed)
|
self.assertFalse(self.stream.completed)
|
||||||
|
|
||||||
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True,
|
async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None, stop_when_done=True,
|
||||||
file_name=None):
|
skip_setup=False):
|
||||||
await self.setup_stream(blob_count)
|
if not skip_setup:
|
||||||
|
await self.setup_stream(blob_count)
|
||||||
mock_node = mock.Mock(spec=Node)
|
mock_node = mock.Mock(spec=Node)
|
||||||
|
|
||||||
def _mock_accumulate_peers(q1, q2):
|
def _mock_accumulate_peers(q1, q2):
|
||||||
|
@ -68,7 +76,7 @@ class TestManagedStream(BlobExchangeTestBase):
|
||||||
return q2, self.loop.create_task(_task())
|
return q2, self.loop.create_task(_task())
|
||||||
|
|
||||||
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
|
mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
|
||||||
await self.stream.save_file(node=mock_node, file_name=file_name)
|
await self.stream.save_file(node=mock_node)
|
||||||
await self.stream.finished_write_attempt.wait()
|
await self.stream.finished_write_attempt.wait()
|
||||||
self.assertTrue(os.path.isfile(self.stream.full_path))
|
self.assertTrue(os.path.isfile(self.stream.full_path))
|
||||||
if stop_when_done:
|
if stop_when_done:
|
||||||
|
|
|
@ -10,7 +10,7 @@ from lbry.conf import Config
|
||||||
from lbry.error import InvalidStreamDescriptorError
|
from lbry.error import InvalidStreamDescriptorError
|
||||||
from lbry.extras.daemon.storage import SQLiteStorage
|
from lbry.extras.daemon.storage import SQLiteStorage
|
||||||
from lbry.blob.blob_manager import BlobManager
|
from lbry.blob.blob_manager import BlobManager
|
||||||
from lbry.stream.descriptor import StreamDescriptor
|
from lbry.stream.descriptor import StreamDescriptor, sanitize_file_name
|
||||||
|
|
||||||
|
|
||||||
class TestStreamDescriptor(AsyncioTestCase):
|
class TestStreamDescriptor(AsyncioTestCase):
|
||||||
|
@ -78,6 +78,11 @@ class TestStreamDescriptor(AsyncioTestCase):
|
||||||
self.sd_dict['blobs'][-2]['length'] = 0
|
self.sd_dict['blobs'][-2]['length'] = 0
|
||||||
await self._test_invalid_sd()
|
await self._test_invalid_sd()
|
||||||
|
|
||||||
|
async def test_sanitize_file_name(self):
|
||||||
|
test_cases = [' t/-?t|.g.ext ', 'end_me .', '', '.file', 'test name.ext', 'COM8', 'LPT2']
|
||||||
|
expected = ['t-t.g.ext', 'end_me', '', '.file', 'test name.ext', '', '']
|
||||||
|
actual = [sanitize_file_name(tc) for tc in test_cases]
|
||||||
|
self.assertListEqual(actual, expected)
|
||||||
|
|
||||||
class TestRecoverOldStreamDescriptors(AsyncioTestCase):
|
class TestRecoverOldStreamDescriptors(AsyncioTestCase):
|
||||||
async def test_old_key_sort_sd_blob(self):
|
async def test_old_key_sort_sd_blob(self):
|
||||||
|
|
Loading…
Reference in a new issue