-
Notifications
You must be signed in to change notification settings - Fork 796
FEAT Add safe_extract_zip helper for remote dataset loaders #1957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT license. | ||
|
|
||
| """ | ||
| Defensive ZIP extraction for untrusted remote archives. | ||
|
|
||
| Remote dataset loaders in PyRIT download ZIP archives from third-party sources | ||
| and feed them to ``zipfile.ZipFile.extractall()``. ``extractall`` does not | ||
| validate member paths, file sizes, or entry types, which leaves the loader | ||
| vulnerable to Zip Slip (CWE-22), zip bombs, and symlink-based path escape if | ||
| any upstream source is tampered with. | ||
|
|
||
| ``safe_extract_zip`` validates every archive member before writing anything to | ||
| disk. If any member fails validation, no archive members are written from the | ||
| failing call (pre-existing contents of ``dest_dir`` are untouched). | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import io | ||
| import logging | ||
| import os | ||
| import stat | ||
| import zipfile | ||
| from pathlib import Path | ||
| from typing import IO | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| # 5 GiB cumulative uncompressed size across all members | ||
| DEFAULT_MAX_TOTAL_SIZE = 5 * 1024**3 | ||
| # 1 GiB cap on any single member | ||
| DEFAULT_MAX_FILE_SIZE = 1 * 1024**3 | ||
| # 50_000 entries: above legitimate dataset sizes, defeats inode DoS | ||
| DEFAULT_MAX_FILE_COUNT = 50_000 | ||
| # Reject members whose uncompressed/compressed ratio exceeds this (zip bomb) | ||
| DEFAULT_MAX_COMPRESSION_RATIO = 100 | ||
|
|
||
| # Sanitized permissions applied to extracted entries, stripping any setuid / | ||
| # setgid / sticky / world-write bits the archive may have requested. | ||
| _EXTRACTED_FILE_MODE = 0o644 | ||
| _EXTRACTED_DIR_MODE = 0o755 | ||
|
|
||
| # Predicates for entry types we refuse to extract. | ||
| _DISALLOWED_TYPE_PREDICATES = ( | ||
| stat.S_ISLNK, | ||
| stat.S_ISBLK, | ||
| stat.S_ISCHR, | ||
| stat.S_ISFIFO, | ||
| stat.S_ISSOCK, | ||
| ) | ||
|
|
||
| ZipSource = str | os.PathLike | bytes | IO[bytes] | ||
|
|
||
|
|
||
| class UnsafeArchiveError(Exception): | ||
| """Raised when an archive member fails a safe-extraction precondition.""" | ||
|
|
||
|
|
||
| def safe_extract_zip( | ||
| *, | ||
| source: ZipSource, | ||
| dest_dir: str | os.PathLike, | ||
| max_total_size: int = DEFAULT_MAX_TOTAL_SIZE, | ||
| max_file_size: int = DEFAULT_MAX_FILE_SIZE, | ||
| max_file_count: int = DEFAULT_MAX_FILE_COUNT, | ||
| max_compression_ratio: int = DEFAULT_MAX_COMPRESSION_RATIO, | ||
| ) -> Path: | ||
| """ | ||
| Extract a ZIP archive after validating every member. | ||
|
|
||
| Validation runs in a single pass over the archive's central directory | ||
| before any bytes are written. If any check fails, ``UnsafeArchiveError`` is | ||
| raised and no archive members are written from this call. After extraction | ||
| each member's filesystem mode is replaced with a sanitized default so a | ||
| tampered archive cannot set setuid/setgid/sticky/exec bits on the host. | ||
|
|
||
| Args: | ||
| source: Path, bytes, or file-like object accepted by ``zipfile.ZipFile``. | ||
| dest_dir: Directory to extract into. Created if it does not exist. | ||
| max_total_size: Cap on the sum of uncompressed member sizes. | ||
| max_file_size: Cap on any single member's uncompressed size. | ||
| max_file_count: Cap on the number of members in the archive. | ||
| max_compression_ratio: Reject members whose uncompressed/compressed | ||
| ratio exceeds this value (zip bomb defense). | ||
|
|
||
| Returns: | ||
| Resolved destination directory. | ||
|
|
||
| Raises: | ||
| UnsafeArchiveError: If any member fails validation. | ||
| """ | ||
| if isinstance(source, (bytes, bytearray)): | ||
| source = io.BytesIO(source) | ||
|
|
||
| dest_real = Path(dest_dir).resolve() | ||
| dest_real.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| with zipfile.ZipFile(source) as zf: | ||
| members = zf.infolist() | ||
| try: | ||
| _validate_members( | ||
| members, | ||
| dest_real=dest_real, | ||
| max_total_size=max_total_size, | ||
| max_file_size=max_file_size, | ||
| max_file_count=max_file_count, | ||
| max_compression_ratio=max_compression_ratio, | ||
| ) | ||
| except UnsafeArchiveError as exc: | ||
| logger.warning("safe_extract_zip rejected archive: %s", exc) | ||
| raise | ||
| for m in members: | ||
| extracted = Path(zf.extract(m, dest_real)) | ||
| _sanitize_extracted_permissions(extracted) | ||
|
|
||
| return dest_real | ||
|
|
||
|
|
||
| def _sanitize_extracted_permissions(path: Path) -> None: | ||
| # zipfile.ZipFile.extract applies the archive's external_attr mode bits on | ||
| # POSIX, so a tampered archive can request setuid/setgid/sticky or | ||
| # executable bits on extracted entries. Replace with a sane default. | ||
| try: | ||
| if path.is_dir(): | ||
| os.chmod(path, _EXTRACTED_DIR_MODE) | ||
| else: | ||
| os.chmod(path, _EXTRACTED_FILE_MODE) | ||
| except OSError as exc: | ||
| logger.warning("safe_extract_zip could not chmod %s: %s", path, exc) | ||
|
|
||
|
|
||
| def _validate_members( | ||
| members: list[zipfile.ZipInfo], | ||
| *, | ||
| dest_real: Path, | ||
| max_total_size: int, | ||
| max_file_size: int, | ||
| max_file_count: int, | ||
| max_compression_ratio: int, | ||
| ) -> None: | ||
| if len(members) > max_file_count: | ||
| raise UnsafeArchiveError(f"archive contains {len(members)} entries (max {max_file_count})") | ||
|
|
||
| total = 0 | ||
| for m in members: | ||
| _reject_disallowed_entry_type(m) | ||
| _reject_absolute_path(m) | ||
| _reject_path_traversal(m, dest_real) | ||
| _reject_oversized_member(m, max_file_size=max_file_size) | ||
| _reject_compression_bomb(m, max_ratio=max_compression_ratio) | ||
|
|
||
| total += m.file_size | ||
| if total > max_total_size: | ||
| raise UnsafeArchiveError(f"total uncompressed size exceeds {max_total_size} bytes") | ||
|
|
||
|
|
||
| def _reject_disallowed_entry_type(m: zipfile.ZipInfo) -> None: | ||
| # The upper 16 bits of external_attr hold the Unix mode when the archive | ||
| # was created on a Unix system. Check unconditionally because create_system | ||
| # is attacker-controlled metadata: a zip crafted with create_system=0 (DOS) | ||
| # but Unix-style mode bits set should still be rejected. | ||
| mode = m.external_attr >> 16 | ||
| if any(predicate(mode) for predicate in _DISALLOWED_TYPE_PREDICATES): | ||
| raise UnsafeArchiveError(f"disallowed entry type: {m.filename}") | ||
|
|
||
|
|
||
| def _reject_absolute_path(m: zipfile.ZipInfo) -> None: | ||
| name = m.filename | ||
| if name.startswith(("/", "\\")): | ||
| raise UnsafeArchiveError(f"absolute path in archive: {name}") | ||
| if len(name) >= 2 and name[1] == ":": | ||
| raise UnsafeArchiveError(f"drive-letter path in archive: {name}") | ||
|
|
||
|
|
||
| def _reject_path_traversal(m: zipfile.ZipInfo, dest_real: Path) -> None: | ||
| # Explicit null-byte check: Path.resolve() only raises ValueError for | ||
| # embedded null bytes on POSIX. On Windows the path round-trips with the | ||
| # null byte intact, so we need an OS-independent guard up front. | ||
| if "\x00" in m.filename: | ||
| raise UnsafeArchiveError(f"invalid characters in archive entry: {m.filename!r}") | ||
| try: | ||
| target = (dest_real / m.filename).resolve() | ||
| except ValueError as exc: | ||
| # Fallback for any other ValueError from Path construction or resolve. | ||
| raise UnsafeArchiveError(f"invalid characters in archive entry: {m.filename!r}") from exc | ||
|
Comment on lines
+182
to
+186
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This null-byte rejection is platform-dependent and currently fails on Windows. Fix: check for null bytes explicitly instead of relying on def _reject_path_traversal(m: zipfile.ZipInfo, dest_real: Path) -> None:
if "\x00" in m.filename:
raise UnsafeArchiveError(f"invalid characters in archive entry: {m.filename!r}")
target = (dest_real / m.filename).resolve()
try:
target.relative_to(dest_real)
except ValueError as exc:
raise UnsafeArchiveError(f"path traversal in archive: {m.filename!r} escapes {dest_real}") from excThis makes the guard (and the test) behave consistently across all three CI platforms.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My bad i missed that out.. fixed the null byte check as you suggested ! thanks for catching this 🙏 |
||
| try: | ||
| target.relative_to(dest_real) | ||
| except ValueError as exc: | ||
| raise UnsafeArchiveError(f"path traversal in archive: {m.filename!r} escapes {dest_real}") from exc | ||
|
|
||
|
|
||
| def _reject_oversized_member(m: zipfile.ZipInfo, *, max_file_size: int) -> None: | ||
| if m.file_size > max_file_size: | ||
| raise UnsafeArchiveError(f"member {m.filename!r} uncompressed size {m.file_size} exceeds cap {max_file_size}") | ||
|
|
||
|
|
||
| def _reject_compression_bomb(m: zipfile.ZipInfo, *, max_ratio: int) -> None: | ||
| if m.file_size <= 0: | ||
| return | ||
| if m.compress_size <= 0: | ||
| # Declared non-zero uncompressed size with zero compressed size is | ||
| # malformed metadata, refuse rather than skip the ratio check. | ||
| raise UnsafeArchiveError( | ||
| f"member {m.filename!r} declares uncompressed size {m.file_size} but compressed size {m.compress_size}" | ||
| ) | ||
| ratio = m.file_size / m.compress_size | ||
| if ratio > max_ratio: | ||
| raise UnsafeArchiveError(f"member {m.filename!r} compression ratio {ratio:.1f} exceeds cap {max_ratio}") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Per the style guide, functions with more than one parameter should use
*after the first arg (or afterself/cls) to enforce keyword-only call sites — every other multi-arg helper inpyrit/common/(get_random_indices,warn_if_set,get_kwarg_param,get_required_value,get_non_required_value) follows this.sourceanddest_dirare both passed positionally at the three call sites today, which is exactly the readability/typo risk the convention is meant to prevent (safe_extract_zip(zip_file_path, self.zip_dir)reads ambiguously). Suggest:and updating the three callers accordingly.