diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 1a0760eb7c..06b499cb38 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -15,6 +15,8 @@ import concurrent.futures import contextlib import logging +import os +import shlex import sys import threading from pathlib import Path @@ -28,6 +30,38 @@ _T = TypeVar("_T") +def _split_initializer_paths(arg: str) -> list[str]: + """ + Split a command-line argument string into individual file paths. + + Supports quoting paths that contain spaces. On Windows, backslashes are treated + as literal path separators (not escape characters) so that unquoted paths such as + ``C:\\Users\\me\\init.py`` are preserved; surrounding quotes are stripped from each + token. On POSIX systems, standard ``shlex`` parsing is used. + + Args: + arg: The raw argument string passed to the ``add-initializer`` command. + + Returns: + The list of individual file path strings parsed from ``arg``. + + Raises: + ValueError: If the argument contains unbalanced quotes. + """ + if os.name == "nt": + lexer = shlex.shlex(arg, posix=False) + lexer.whitespace_split = True + tokens = list(lexer) + return [_strip_surrounding_quotes(token) for token in tokens] + return shlex.split(arg) + + +def _strip_surrounding_quotes(token: str) -> str: + if len(token) >= 2 and token[0] == token[-1] and token[0] in ("'", '"'): + return token[1:-1] + return token + + class PyRITShell(cmd.Cmd): """ Interactive shell for PyRIT (thin REST client). @@ -249,7 +283,13 @@ def do_add_initializer(self, arg: str) -> None: from pyrit.cli.api_client import ServerNotAvailableError - for script_path_str in arg.split(): + try: + script_path_strings = _split_initializer_paths(arg) + except ValueError as exc: + print(f"Error parsing initializer paths: {exc}") + return + + for script_path_str in script_path_strings: script_path = Path(script_path_str).resolve() if not script_path.exists(): print(f"Error: File not found: {script_path}") diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 788ed71ad5..7264f49525 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -358,6 +358,45 @@ def test_success_path(self, shell, tmp_path, capsys): assert "Registered initializer 'my_init'" in capsys.readouterr().out client.register_initializer_async.assert_awaited_once() + def test_success_with_quoted_path_containing_spaces(self, shell, tmp_path, capsys): + s, client = shell + script_dir = tmp_path / "initializer scripts" + script_dir.mkdir() + script = script_dir / "my_init.py" + script.write_text("def init(): pass") + client.register_initializer_async = AsyncMock(return_value={"status": "ok"}) + + s.do_add_initializer(f'"{script}"') + + assert "Registered initializer 'my_init'" in capsys.readouterr().out + client.register_initializer_async.assert_awaited_once_with(name="my_init", script_content="def init(): pass") + + def test_malformed_path_quote(self, shell, capsys): + s, client = shell + client.register_initializer_async = AsyncMock(return_value={"status": "ok"}) + + s.do_add_initializer('"unterminated') + + assert "Error parsing initializer paths" in capsys.readouterr().out + client.register_initializer_async.assert_not_called() + + def test_success_with_multiple_quoted_paths(self, shell, tmp_path, capsys): + s, client = shell + script_dir = tmp_path / "initializer scripts" + script_dir.mkdir() + first = script_dir / "first_init.py" + second = script_dir / "second_init.py" + first.write_text("def init(): pass") + second.write_text("def init(): pass") + client.register_initializer_async = AsyncMock(return_value={"status": "ok"}) + + s.do_add_initializer(f'"{first}" "{second}"') + + out = capsys.readouterr().out + assert "Registered initializer 'first_init'" in out + assert "Registered initializer 'second_init'" in out + assert client.register_initializer_async.await_count == 2 + def test_server_not_available_error(self, shell, tmp_path, capsys): from pyrit.cli.api_client import ServerNotAvailableError @@ -782,3 +821,32 @@ def test_shell_choices_rejected_before_request(self, shell, capsys): # do_run surfaces these as "Error: ...". assert "Error" in out client.start_scenario_run_async.assert_not_called() + + +class TestSplitInitializerPaths: + def test_posix_splits_on_whitespace(self): + with patch.object(pyrit_shell.os, "name", "posix"): + assert pyrit_shell._split_initializer_paths("/a/one.py /b/two.py") == ["/a/one.py", "/b/two.py"] + + def test_posix_respects_quotes_with_spaces(self): + with patch.object(pyrit_shell.os, "name", "posix"): + assert pyrit_shell._split_initializer_paths('"/a b/one.py"') == ["/a b/one.py"] + + def test_windows_preserves_unquoted_backslash_path(self): + with patch.object(pyrit_shell.os, "name", "nt"): + assert pyrit_shell._split_initializer_paths(r"C:\Users\me\init.py") == [r"C:\Users\me\init.py"] + + def test_windows_quoted_path_with_spaces_strips_quotes(self): + with patch.object(pyrit_shell.os, "name", "nt"): + assert pyrit_shell._split_initializer_paths(r'"C:\a b\one.py"') == [r"C:\a b\one.py"] + + def test_windows_multiple_paths(self): + with patch.object(pyrit_shell.os, "name", "nt"): + result = pyrit_shell._split_initializer_paths(r'"C:\a b\one.py" C:\c\two.py') + assert result == [r"C:\a b\one.py", r"C:\c\two.py"] + + @pytest.mark.parametrize("os_name", ["posix", "nt"]) + def test_unterminated_quote_raises(self, os_name): + with patch.object(pyrit_shell.os, "name", os_name): + with pytest.raises(ValueError): + pyrit_shell._split_initializer_paths('"unterminated')