diff --git a/lbry/extras/cli.py b/lbry/extras/cli.py index c263a84d9..a33543faf 100644 --- a/lbry/extras/cli.py +++ b/lbry/extras/cli.py @@ -226,6 +226,9 @@ def get_argument_parser(): def ensure_directory_exists(path: str): if not os.path.isdir(path): pathlib.Path(path).mkdir(parents=True, exist_ok=True) + use_effective_ids = os.access in os.supports_effective_ids + if not os.access(path, os.W_OK, effective_ids=use_effective_ids): + raise PermissionError(f"The following directory is not writable: {path}") LOG_MODULES = 'lbry', 'aioupnp' diff --git a/tests/integration/other/test_cli.py b/tests/integration/other/test_cli.py index 7de635fc6..968665333 100644 --- a/tests/integration/other/test_cli.py +++ b/tests/integration/other/test_cli.py @@ -1,4 +1,6 @@ import contextlib +import os +import tempfile from io import StringIO from lbry.testcase import AsyncioTestCase @@ -37,3 +39,9 @@ class CLIIntegrationTest(AsyncioTestCase): cli.main(["--api", "localhost:5299", "status"]) actual_output = actual_output.getvalue() self.assertIn("is_running", actual_output) + + def test_when_download_dir_non_writable_on_start_then_daemon_dies_with_helpful_msg(self): + with tempfile.TemporaryDirectory() as download_dir: + os.chmod(download_dir, mode=0o555) # makes download dir non-writable, readable and executable + with self.assertRaisesRegex(PermissionError, f"The following directory is not writable: {download_dir}"): + cli.main(["start", "--download-dir", download_dir]) diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 935364f05..cf6d91002 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -3,6 +3,7 @@ import tempfile import shutil import contextlib import logging +import pathlib from io import StringIO from unittest import TestCase from unittest.mock import patch @@ -12,7 +13,7 @@ from contextlib import asynccontextmanager import docopt from lbry.testcase import AsyncioTestCase -from lbry.extras.cli import normalize_value, main, setup_logging +from lbry.extras.cli import normalize_value, main, setup_logging, ensure_directory_exists from lbry.extras.system_info import get_platform from lbry.extras.daemon.daemon import Daemon from lbry.conf import Config @@ -202,3 +203,37 @@ class DaemonDocsTests(TestCase): pass if failures: self.fail("\n" + "\n".join(failures)) + + +class EnsureDirectoryExistsTests(TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_when_parent_dir_does_not_exist_then_dir_is_created_with_parent(self): + dir_path = os.path.join(self.temp_dir, "parent_dir", "dir") + ensure_directory_exists(dir_path) + self.assertTrue(os.path.exists(dir_path)) + + def test_when_non_writable_dir_exists_then_raise(self): + dir_path = os.path.join(self.temp_dir, "dir") + pathlib.Path(dir_path).mkdir(mode=0o555) # creates a non-writable, readable and executable dir + with self.assertRaises(PermissionError): + ensure_directory_exists(dir_path) + + def test_when_dir_exists_and_writable_then_no_raise(self): + dir_path = os.path.join(self.temp_dir, "dir") + pathlib.Path(dir_path).mkdir(mode=0o777) # creates a writable, readable and executable dir + try: + ensure_directory_exists(dir_path) + except (FileExistsError, PermissionError) as err: + self.fail(f"{type(err).__name__} was raised") + + def test_when_non_dir_file_exists_at_path_then_raise(self): + file_path = os.path.join(self.temp_dir, "file.extension") + pathlib.Path(file_path).touch() + with self.assertRaises(FileExistsError): + ensure_directory_exists(file_path)