Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion pyrit/cli/pyrit_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import concurrent.futures
import contextlib
import logging
import os
import shlex
import sys
import threading
from pathlib import Path
Expand All @@ -28,6 +30,38 @@
_T = TypeVar("_T")


def _split_initializer_paths(arg: str) -> list[str]:
"""
Split a command-line argument string into individual file paths.

Supports quoting paths that contain spaces. On Windows, backslashes are treated
as literal path separators (not escape characters) so that unquoted paths such as
``C:\\Users\\me\\init.py`` are preserved; surrounding quotes are stripped from each
token. On POSIX systems, standard ``shlex`` parsing is used.

Args:
arg: The raw argument string passed to the ``add-initializer`` command.

Returns:
The list of individual file path strings parsed from ``arg``.

Raises:
ValueError: If the argument contains unbalanced quotes.
"""
if os.name == "nt":
lexer = shlex.shlex(arg, posix=False)
lexer.whitespace_split = True
tokens = list(lexer)
return [_strip_surrounding_quotes(token) for token in tokens]
return shlex.split(arg)


def _strip_surrounding_quotes(token: str) -> str:
if len(token) >= 2 and token[0] == token[-1] and token[0] in ("'", '"'):
return token[1:-1]
return token


class PyRITShell(cmd.Cmd):
"""
Interactive shell for PyRIT (thin REST client).
Expand Down Expand Up @@ -249,7 +283,13 @@ def do_add_initializer(self, arg: str) -> None:

from pyrit.cli.api_client import ServerNotAvailableError

for script_path_str in arg.split():
try:
script_path_strings = _split_initializer_paths(arg)
except ValueError as exc:
print(f"Error parsing initializer paths: {exc}")
return

for script_path_str in script_path_strings:
script_path = Path(script_path_str).resolve()
if not script_path.exists():
print(f"Error: File not found: {script_path}")
Expand Down
68 changes: 68 additions & 0 deletions tests/unit/cli/test_pyrit_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,45 @@ def test_success_path(self, shell, tmp_path, capsys):
assert "Registered initializer 'my_init'" in capsys.readouterr().out
client.register_initializer_async.assert_awaited_once()

def test_success_with_quoted_path_containing_spaces(self, shell, tmp_path, capsys):
s, client = shell
script_dir = tmp_path / "initializer scripts"
script_dir.mkdir()
script = script_dir / "my_init.py"
script.write_text("def init(): pass")
client.register_initializer_async = AsyncMock(return_value={"status": "ok"})

s.do_add_initializer(f'"{script}"')

assert "Registered initializer 'my_init'" in capsys.readouterr().out
client.register_initializer_async.assert_awaited_once_with(name="my_init", script_content="def init(): pass")

def test_malformed_path_quote(self, shell, capsys):
s, client = shell
client.register_initializer_async = AsyncMock(return_value={"status": "ok"})

s.do_add_initializer('"unterminated')

assert "Error parsing initializer paths" in capsys.readouterr().out
client.register_initializer_async.assert_not_called()

def test_success_with_multiple_quoted_paths(self, shell, tmp_path, capsys):
s, client = shell
script_dir = tmp_path / "initializer scripts"
script_dir.mkdir()
first = script_dir / "first_init.py"
second = script_dir / "second_init.py"
first.write_text("def init(): pass")
second.write_text("def init(): pass")
client.register_initializer_async = AsyncMock(return_value={"status": "ok"})

s.do_add_initializer(f'"{first}" "{second}"')

out = capsys.readouterr().out
assert "Registered initializer 'first_init'" in out
assert "Registered initializer 'second_init'" in out
assert client.register_initializer_async.await_count == 2

def test_server_not_available_error(self, shell, tmp_path, capsys):
from pyrit.cli.api_client import ServerNotAvailableError

Expand Down Expand Up @@ -782,3 +821,32 @@ def test_shell_choices_rejected_before_request(self, shell, capsys):
# do_run surfaces these as "Error: ...".
assert "Error" in out
client.start_scenario_run_async.assert_not_called()


class TestSplitInitializerPaths:
def test_posix_splits_on_whitespace(self):
with patch.object(pyrit_shell.os, "name", "posix"):
assert pyrit_shell._split_initializer_paths("/a/one.py /b/two.py") == ["/a/one.py", "/b/two.py"]

def test_posix_respects_quotes_with_spaces(self):
with patch.object(pyrit_shell.os, "name", "posix"):
assert pyrit_shell._split_initializer_paths('"/a b/one.py"') == ["/a b/one.py"]

def test_windows_preserves_unquoted_backslash_path(self):
with patch.object(pyrit_shell.os, "name", "nt"):
assert pyrit_shell._split_initializer_paths(r"C:\Users\me\init.py") == [r"C:\Users\me\init.py"]

def test_windows_quoted_path_with_spaces_strips_quotes(self):
with patch.object(pyrit_shell.os, "name", "nt"):
assert pyrit_shell._split_initializer_paths(r'"C:\a b\one.py"') == [r"C:\a b\one.py"]

def test_windows_multiple_paths(self):
with patch.object(pyrit_shell.os, "name", "nt"):
result = pyrit_shell._split_initializer_paths(r'"C:\a b\one.py" C:\c\two.py')
assert result == [r"C:\a b\one.py", r"C:\c\two.py"]

@pytest.mark.parametrize("os_name", ["posix", "nt"])
def test_unterminated_quote_raises(self, os_name):
with patch.object(pyrit_shell.os, "name", os_name):
with pytest.raises(ValueError):
pyrit_shell._split_initializer_paths('"unterminated')