diff --git a/rampart/core/payload_ids.py b/rampart/core/payload_ids.py new file mode 100644 index 0000000..30866ce --- /dev/null +++ b/rampart/core/payload_ids.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Validation helpers for payload identifiers.""" + +from __future__ import annotations + +import re + +_PAYLOAD_ID_PATTERN = re.compile(r"^[A-Za-z0-9._-]{1,128}$") + + +def validate_payload_id(payload_id: str) -> None: + """Validate that a payload ID is safe to embed in file names. + + Payload IDs are used for local artifact filenames and remote upload + names. Keep them to a small cross-platform filename-safe alphabet so + generated artifacts cannot introduce path separators, path traversal, + control characters, or Graph path-addressing delimiters. + """ + if not _PAYLOAD_ID_PATTERN.fullmatch(payload_id): + msg = ( + f"Invalid payload id: {payload_id!r}. Payload IDs must be 1-128 " + "characters using only letters, numbers, '.', '_', and '-'." + ) + raise ValueError(msg) + + if payload_id in {".", ".."}: + msg = ( + f"Invalid payload id: {payload_id!r}. Payload IDs cannot be path segments." + ) + raise ValueError(msg) diff --git a/rampart/core/types.py b/rampart/core/types.py index 3d206aa..cbc6782 100644 --- a/rampart/core/types.py +++ b/rampart/core/types.py @@ -14,6 +14,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any +from rampart.core.payload_ids import validate_payload_id + if TYPE_CHECKING: from datetime import datetime from pathlib import Path @@ -119,6 +121,7 @@ class Payload: def __post_init__(self) -> None: """Validate content-format-artifact consistency.""" + validate_payload_id(self.id) if self.format.is_binary and self.artifact is None: msg = ( f"Binary format {self.format.value} requires an " diff --git a/rampart/payloads/_store.py b/rampart/payloads/_store.py index 91cef07..5740bed 100644 --- a/rampart/payloads/_store.py +++ b/rampart/payloads/_store.py @@ -31,9 +31,11 @@ from pathlib import Path from typing import Any +from rampart.core.payload_ids import validate_payload_id from rampart.core.types import Payload, PayloadFormat logger = logging.getLogger(__name__) +_MIN_ARTIFACT_PATH_PARTS = 2 class PayloadStore: @@ -326,11 +328,54 @@ def _copy_file_artifact( Returns: str: Relative artifact path (e.g., 'artifacts/abc123.pdf'). """ + validate_payload_id(payload_id) artifacts_dir.mkdir(parents=True, exist_ok=True) filename = f"{payload_id}{extension}" - shutil.copy2(source, artifacts_dir / filename) + destination = artifacts_dir / filename + PayloadStore._ensure_within_directory( + path=destination, + directory=artifacts_dir, + description="artifact destination", + ) + shutil.copy2(source, destination) return f"artifacts/{filename}" + @staticmethod + def _ensure_within_directory( + *, + path: Path, + directory: Path, + description: str, + ) -> None: + """Raise if a resolved path escapes a required directory.""" + resolved_path = path.resolve(strict=False) + resolved_directory = directory.resolve(strict=False) + if not resolved_path.is_relative_to(resolved_directory): + msg = f"Invalid {description}: {path!s} escapes {directory!s}" + raise ValueError(msg) + + @staticmethod + def _resolve_artifact_path(*, collection_dir: Path, artifact: str) -> Path: + """Resolve a serialized artifact path inside collection artifacts/.""" + artifact_path = Path(artifact) + if artifact_path.is_absolute() or ".." in artifact_path.parts: + msg = f"Invalid artifact path: {artifact!r}. Must stay under artifacts/." + raise ValueError(msg) + if ( + len(artifact_path.parts) < _MIN_ARTIFACT_PATH_PARTS + or artifact_path.parts[0] != "artifacts" + ): + msg = f"Invalid artifact path: {artifact!r}. Must be under artifacts/." + raise ValueError(msg) + + resolved = collection_dir / artifact_path + PayloadStore._ensure_within_directory( + path=resolved, + directory=collection_dir / "artifacts", + description="artifact path", + ) + return resolved + @staticmethod def _deserialize( *, @@ -354,7 +399,10 @@ def _deserialize( artifact: Path | None = None if "artifact" in data: - artifact_path = collection_dir / data["artifact"] + artifact_path = PayloadStore._resolve_artifact_path( + collection_dir=collection_dir, + artifact=data["artifact"], + ) if not artifact_path.exists(): msg = f"Missing artifact: {artifact_path}" raise FileNotFoundError(msg) diff --git a/rampart/surfaces/onedrive.py b/rampart/surfaces/onedrive.py index d239743..956a9af 100644 --- a/rampart/surfaces/onedrive.py +++ b/rampart/surfaces/onedrive.py @@ -14,6 +14,7 @@ from rampart.core.errors import InfrastructureError from rampart.core.injection import sleep_until_ready +from rampart.core.payload_ids import validate_payload_id if TYPE_CHECKING: import types @@ -110,6 +111,7 @@ async def upload_async(self, *, payload: Payload) -> str: if the payload exceeds the 4 MiB small-upload limit. InfrastructureError: If Graph returns no ``DriveItem``. """ + validate_payload_id(payload.id) filename = f"{payload.id}{payload.format.extension}" upload_path = f"{self.folder_path}/{filename}" diff --git a/tests/unit/payloads/test_payload_store_security.py b/tests/unit/payloads/test_payload_store_security.py new file mode 100644 index 0000000..e8ec796 --- /dev/null +++ b/tests/unit/payloads/test_payload_store_security.py @@ -0,0 +1,100 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Security tests for payload IDs and artifact path handling.""" + +import json +from pathlib import Path + +import pytest + +from rampart.core.types import Payload, PayloadFormat +from rampart.payloads._store import PayloadStore + + +@pytest.mark.parametrize( + "payload_id", + [ + "", + ".", + "..", + "../outside", + "..\\outside", + "/tmp/outside", + "nested/name", + "graph:path", + "line\nbreak", + "x" * 129, + ], +) +def test_payload_id_rejects_path_unsafe_values(payload_id: str) -> None: + """Payload IDs reject values that can become unsafe path components.""" + with pytest.raises(ValueError, match="Invalid payload id"): + Payload(content="content", id=payload_id) + + +def test_payload_store_keeps_binary_artifacts_under_artifacts_dir( + tmp_path: Path, +) -> None: + """Binary payload artifacts are copied beneath the artifacts directory.""" + source = tmp_path / "source.pdf" + source.write_bytes(b"pdf") + + store = PayloadStore(root=tmp_path / "store") + payload = Payload( + content="binary", + id="safe-id_1.2", + format=PayloadFormat.PDF, + artifact=source, + ) + + collection_dir = store.save("collection", payloads=[payload]) + + expected_artifact = collection_dir / "artifacts" / "safe-id_1.2.pdf" + assert expected_artifact.read_bytes() == b"pdf" + assert store.load("collection")[0].artifact == expected_artifact + + +def test_payload_store_rejects_deserialized_artifact_traversal(tmp_path: Path) -> None: + """Serialized artifact paths cannot traverse outside the collection.""" + collection_dir = tmp_path / "store" / "collection" + collection_dir.mkdir(parents=True) + outside = tmp_path / "outside.pdf" + outside.write_bytes(b"outside") + record: dict[str, object] = { + "id": "safe-id", + "content": "binary", + "format": "pdf", + "metadata": {}, + "artifact": "../outside.pdf", + } + (collection_dir / "payloads.jsonl").write_text(json.dumps(record) + "\n") + + store = PayloadStore(root=tmp_path / "store") + with pytest.raises(ValueError, match="Invalid artifact path"): + store.load("collection") + + +def test_payload_store_rejects_deserialized_artifact_symlink_escape( + tmp_path: Path, +) -> None: + """Serialized artifact paths cannot resolve through symlinks outside artifacts.""" + collection_dir = tmp_path / "store" / "collection" + artifacts_dir = collection_dir / "artifacts" + artifacts_dir.mkdir(parents=True) + outside = tmp_path / "outside.pdf" + outside.write_bytes(b"outside") + symlink = artifacts_dir / "linked.pdf" + symlink.symlink_to(outside) + record: dict[str, object] = { + "id": "safe-id", + "content": "binary", + "format": "pdf", + "metadata": {}, + "artifact": "artifacts/linked.pdf", + } + (collection_dir / "payloads.jsonl").write_text(json.dumps(record) + "\n") + + store = PayloadStore(root=tmp_path / "store") + with pytest.raises(ValueError, match="escapes"): + store.load("collection") diff --git a/tests/unit/surfaces/test_onedrive_security.py b/tests/unit/surfaces/test_onedrive_security.py new file mode 100644 index 0000000..af0b5d8 --- /dev/null +++ b/tests/unit/surfaces/test_onedrive_security.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Security tests for OneDrive payload upload path construction.""" + +from typing import Any, cast + +import pytest + +from rampart.core.types import Payload +from rampart.surfaces.onedrive import OneDriveSurface + + +async def test_onedrive_upload_rejects_unsafe_payload_id_before_graph_call() -> None: + """OneDrive upload refuses unsafe filenames before Graph path construction.""" + payload = Payload(content="content", id="safe-id") + payload.id = "../outside" + surface = OneDriveSurface( + graph_client=cast("Any", object()), + drive_id="drive-id", + folder_path="folder", + ) + + with pytest.raises(ValueError, match="Invalid payload id"): + await surface.upload_async(payload=payload)