Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions rampart/core/payload_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Validation helpers for payload identifiers."""

from __future__ import annotations

import re

# Allowed alphabet and length for payload IDs: ASCII letters, digits, '.',
# '_' and '-', between 1 and 128 characters.
_PAYLOAD_ID_PATTERN = re.compile(r"^[A-Za-z0-9._-]{1,128}$")


def validate_payload_id(payload_id: str) -> None:
    """Validate that a payload ID is safe to embed in file names.

    Payload IDs are used for local artifact filenames and remote upload
    names. Keep them to a small cross-platform filename-safe alphabet so
    generated artifacts cannot introduce path separators, path traversal,
    control characters, or Graph path-addressing delimiters.

    Args:
        payload_id: Candidate identifier to check.

    Raises:
        TypeError: If ``payload_id`` is not a ``str``.
        ValueError: If ``payload_id`` is empty, longer than 128 characters,
            contains a character outside ``[A-Za-z0-9._-]``, or is exactly
            ``"."`` or ``".."``.
    """
    # Reject non-strings up front with an explicit message; otherwise the
    # regex engine raises an opaque TypeError for None, bytes, ints, etc.
    if not isinstance(payload_id, str):
        msg = f"Payload id must be a str, got {type(payload_id).__name__}."
        raise TypeError(msg)

    # fullmatch (not match) so that "$" cannot accept a trailing newline.
    if not _PAYLOAD_ID_PATTERN.fullmatch(payload_id):
        msg = (
            f"Invalid payload id: {payload_id!r}. Payload IDs must be 1-128 "
            "characters using only letters, numbers, '.', '_', and '-'."
        )
        raise ValueError(msg)

    # "." and ".." satisfy the alphabet but are directory references.
    if payload_id in {".", ".."}:
        msg = (
            f"Invalid payload id: {payload_id!r}. Payload IDs cannot be path segments."
        )
        raise ValueError(msg)
3 changes: 3 additions & 0 deletions rampart/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from enum import Enum
from typing import TYPE_CHECKING, Any

from rampart.core.payload_ids import validate_payload_id

if TYPE_CHECKING:
from datetime import datetime
from pathlib import Path
Expand Down Expand Up @@ -119,6 +121,7 @@ class Payload:

def __post_init__(self) -> None:
"""Validate content-format-artifact consistency."""
validate_payload_id(self.id)
if self.format.is_binary and self.artifact is None:
msg = (
f"Binary format {self.format.value} requires an "
Expand Down
52 changes: 50 additions & 2 deletions rampart/payloads/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
from pathlib import Path
from typing import Any

from rampart.core.payload_ids import validate_payload_id
from rampart.core.types import Payload, PayloadFormat

logger = logging.getLogger(__name__)
# Minimum number of components in a serialized artifact path: the leading
# "artifacts" directory plus at least one component beneath it.
_MIN_ARTIFACT_PATH_PARTS = 2


class PayloadStore:
Expand Down Expand Up @@ -326,11 +328,54 @@ def _copy_file_artifact(
Returns:
str: Relative artifact path (e.g., 'artifacts/abc123.pdf').
"""
validate_payload_id(payload_id)
artifacts_dir.mkdir(parents=True, exist_ok=True)
filename = f"{payload_id}{extension}"
shutil.copy2(source, artifacts_dir / filename)
destination = artifacts_dir / filename
PayloadStore._ensure_within_directory(
path=destination,
directory=artifacts_dir,
description="artifact destination",
)
shutil.copy2(source, destination)
return f"artifacts/{filename}"

@staticmethod
def _ensure_within_directory(
*,
path: Path,
directory: Path,
description: str,
) -> None:
"""Raise if a resolved path escapes a required directory."""
resolved_path = path.resolve(strict=False)
resolved_directory = directory.resolve(strict=False)
if not resolved_path.is_relative_to(resolved_directory):
msg = f"Invalid {description}: {path!s} escapes {directory!s}"
raise ValueError(msg)

@staticmethod
def _resolve_artifact_path(*, collection_dir: Path, artifact: str) -> Path:
    """Map a serialized artifact path onto the collection's artifacts/ folder.

    Args:
        collection_dir: Directory of the collection being loaded.
        artifact: Relative artifact path as stored in the serialized record.

    Returns:
        Path: ``collection_dir`` joined with ``artifact``.

    Raises:
        ValueError: If ``artifact`` is absolute, contains a ``..`` component,
            does not start with ``artifacts`` followed by at least one more
            component, or resolves outside ``collection_dir / "artifacts"``.
    """
    relative = Path(artifact)
    segments = relative.parts

    if relative.is_absolute() or ".." in segments:
        msg = f"Invalid artifact path: {artifact!r}. Must stay under artifacts/."
        raise ValueError(msg)

    # Short-circuits on the length test, so an empty path never indexes [0].
    has_artifacts_prefix = (
        len(segments) >= _MIN_ARTIFACT_PATH_PARTS and segments[0] == "artifacts"
    )
    if not has_artifacts_prefix:
        msg = f"Invalid artifact path: {artifact!r}. Must be under artifacts/."
        raise ValueError(msg)

    candidate = collection_dir / relative
    # Lexical checks passed; now compare the resolved locations as well.
    PayloadStore._ensure_within_directory(
        path=candidate,
        directory=collection_dir / "artifacts",
        description="artifact path",
    )
    return candidate

@staticmethod
def _deserialize(
*,
Expand All @@ -354,7 +399,10 @@ def _deserialize(

artifact: Path | None = None
if "artifact" in data:
artifact_path = collection_dir / data["artifact"]
artifact_path = PayloadStore._resolve_artifact_path(
collection_dir=collection_dir,
artifact=data["artifact"],
)
if not artifact_path.exists():
msg = f"Missing artifact: {artifact_path}"
raise FileNotFoundError(msg)
Expand Down
2 changes: 2 additions & 0 deletions rampart/surfaces/onedrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from rampart.core.errors import InfrastructureError
from rampart.core.injection import sleep_until_ready
from rampart.core.payload_ids import validate_payload_id

if TYPE_CHECKING:
import types
Expand Down Expand Up @@ -110,6 +111,7 @@ async def upload_async(self, *, payload: Payload) -> str:
if the payload exceeds the 4 MiB small-upload limit.
InfrastructureError: If Graph returns no ``DriveItem``.
"""
validate_payload_id(payload.id)
filename = f"{payload.id}{payload.format.extension}"
upload_path = f"{self.folder_path}/{filename}"

Expand Down
100 changes: 100 additions & 0 deletions tests/unit/payloads/test_payload_store_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Security tests for payload IDs and artifact path handling."""

import json
from pathlib import Path

import pytest

from rampart.core.types import Payload, PayloadFormat
from rampart.payloads._store import PayloadStore


@pytest.mark.parametrize(
    "payload_id",
    [
        "",  # empty
        ".",  # current-directory segment
        "..",  # parent-directory segment
        "../outside",  # POSIX-style traversal
        "..\\outside",  # backslash-style traversal
        "/tmp/outside",  # absolute path
        "nested/name",  # embedded separator
        "graph:path",  # colon delimiter
        "line\nbreak",  # control character
        "x" * 129,  # one character over the length limit
    ],
)
def test_payload_id_rejects_path_unsafe_values(payload_id: str) -> None:
    """Constructing a Payload with a path-unsafe ID raises ValueError."""
    rejects_invalid_id = pytest.raises(ValueError, match="Invalid payload id")
    with rejects_invalid_id:
        Payload(content="content", id=payload_id)


def test_payload_store_keeps_binary_artifacts_under_artifacts_dir(
    tmp_path: Path,
) -> None:
    """Saving a binary payload copies its artifact into artifacts/."""
    pdf_source = tmp_path / "source.pdf"
    pdf_source.write_bytes(b"pdf")

    store = PayloadStore(root=tmp_path / "store")
    binary_payload = Payload(
        content="binary",
        id="safe-id_1.2",
        format=PayloadFormat.PDF,
        artifact=pdf_source,
    )

    saved_dir = store.save("collection", payloads=[binary_payload])
    copied_artifact = saved_dir / "artifacts" / "safe-id_1.2.pdf"

    # The copy carries the source bytes and round-trips through load().
    assert copied_artifact.read_bytes() == b"pdf"
    reloaded = store.load("collection")
    assert reloaded[0].artifact == copied_artifact


def test_payload_store_rejects_deserialized_artifact_traversal(tmp_path: Path) -> None:
    """Loading a record whose artifact path contains '..' raises ValueError."""
    store_root = tmp_path / "store"
    collection_dir = store_root / "collection"
    collection_dir.mkdir(parents=True)

    # A real file outside the collection that the crafted record points at.
    (tmp_path / "outside.pdf").write_bytes(b"outside")

    record: dict[str, object] = {
        "id": "safe-id",
        "content": "binary",
        "format": "pdf",
        "metadata": {},
        "artifact": "../outside.pdf",
    }
    jsonl_line = json.dumps(record) + "\n"
    (collection_dir / "payloads.jsonl").write_text(jsonl_line)

    store = PayloadStore(root=store_root)
    with pytest.raises(ValueError, match="Invalid artifact path"):
        store.load("collection")


def test_payload_store_rejects_deserialized_artifact_symlink_escape(
    tmp_path: Path,
) -> None:
    """Loading an artifact that is a symlink to a file outside artifacts/ fails."""
    store_root = tmp_path / "store"
    collection_dir = store_root / "collection"
    artifacts_dir = collection_dir / "artifacts"
    artifacts_dir.mkdir(parents=True)

    # The link lives inside artifacts/ but its target sits outside the store.
    link_target = tmp_path / "outside.pdf"
    link_target.write_bytes(b"outside")
    (artifacts_dir / "linked.pdf").symlink_to(link_target)

    record: dict[str, object] = {
        "id": "safe-id",
        "content": "binary",
        "format": "pdf",
        "metadata": {},
        "artifact": "artifacts/linked.pdf",
    }
    jsonl_line = json.dumps(record) + "\n"
    (collection_dir / "payloads.jsonl").write_text(jsonl_line)

    store = PayloadStore(root=store_root)
    with pytest.raises(ValueError, match="escapes"):
        store.load("collection")
25 changes: 25 additions & 0 deletions tests/unit/surfaces/test_onedrive_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Security tests for OneDrive payload upload path construction."""

from typing import Any, cast

import pytest

from rampart.core.types import Payload
from rampart.surfaces.onedrive import OneDriveSurface


async def test_onedrive_upload_rejects_unsafe_payload_id_before_graph_call() -> None:
    """upload_async raises ValueError for an unsafe payload ID."""
    # Build the payload with a valid ID, then overwrite it, so the unsafe
    # value reaches upload_async rather than failing at construction.
    tampered = Payload(content="content", id="safe-id")
    tampered.id = "../outside"

    # A bare object stands in for the Graph client.
    placeholder_client = cast("Any", object())
    surface = OneDriveSurface(
        graph_client=placeholder_client,
        drive_id="drive-id",
        folder_path="folder",
    )

    with pytest.raises(ValueError, match="Invalid payload id"):
        await surface.upload_async(payload=tampered)