|
16 | 16 | InvalidManifestPathError, |
17 | 17 | LocalDirReadError, |
18 | 18 | LocalFileReadError, |
| 19 | + WorkspaceArchiveWriteError, |
19 | 20 | ) |
20 | 21 | from agents.sandbox.manifest import Manifest |
21 | 22 | from agents.sandbox.materialization import MaterializedFile |
22 | 23 | from agents.sandbox.session.base_sandbox_session import BaseSandboxSession |
| 24 | +from agents.sandbox.session.workspace_payloads import coerce_write_payload |
23 | 25 | from agents.sandbox.snapshot import NoopSnapshot |
24 | 26 | from agents.sandbox.types import ExecResult, User |
25 | 27 | from tests.utils.factories import TestSessionState |
@@ -154,6 +156,77 @@ def test_resolve_workspace_path_rejects_absolute_symlink_escape_for_host_root( |
154 | 156 | assert exc_info.value.context == {"rel": escaped.as_posix(), "reason": "absolute"} |
155 | 157 |
|
156 | 158 |
|
| 159 | +class _MutatingWriteSession(_RecordingSession): |
| 160 | + def __init__(self, mutate_before_read: Callable[[], None]) -> None: |
| 161 | + super().__init__() |
| 162 | + self._mutate_before_read = mutate_before_read |
| 163 | + self._mutated = False |
| 164 | + |
| 165 | + async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None: |
| 166 | + if not self._mutated: |
| 167 | + self._mutate_before_read() |
| 168 | + self._mutated = True |
| 169 | + await super().write(path, data, user=user) |
| 170 | + |
| 171 | + |
| 172 | +class _ChunkedMutatingWriteSession(_RecordingSession): |
| 173 | + def __init__(self, mutate_after_first_chunk: Callable[[], None]) -> None: |
| 174 | + super().__init__() |
| 175 | + self._mutate_after_first_chunk = mutate_after_first_chunk |
| 176 | + self._mutated = False |
| 177 | + |
| 178 | + async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None: |
| 179 | + _ = user |
| 180 | + chunks: list[bytes] = [] |
| 181 | + first = data.read(4) |
| 182 | + if isinstance(first, bytes): |
| 183 | + chunks.append(first) |
| 184 | + if not self._mutated: |
| 185 | + self._mutate_after_first_chunk() |
| 186 | + self._mutated = True |
| 187 | + rest = data.read() |
| 188 | + if isinstance(rest, bytes): |
| 189 | + chunks.append(rest) |
| 190 | + self.writes[path] = b"".join(chunks) |
| 191 | + |
| 192 | + |
| 193 | +class _PayloadWrappingWriteSession(_RecordingSession): |
| 194 | + async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None: |
| 195 | + _ = user |
| 196 | + payload = coerce_write_payload(path=path, data=data) |
| 197 | + chunks: list[bytes] = [] |
| 198 | + try: |
| 199 | + while True: |
| 200 | + chunk = payload.stream.read(4) |
| 201 | + if not chunk: |
| 202 | + break |
| 203 | + chunks.append(chunk) |
| 204 | + except Exception as e: |
| 205 | + raise WorkspaceArchiveWriteError(path=path, cause=e) from e |
| 206 | + self.writes[path] = b"".join(chunks) |
| 207 | + |
| 208 | + |
| 209 | +class _FailAfterChunkStream(io.BytesIO): |
| 210 | + def __init__(self, data: bytes, *, owned_fd: int | None = None) -> None: |
| 211 | + super().__init__(data) |
| 212 | + self._owned_fd = owned_fd |
| 213 | + self._read_count = 0 |
| 214 | + |
| 215 | + def read(self, size: int | None = -1) -> bytes: |
| 216 | + if self._read_count > 0: |
| 217 | + raise OSError("source read failed") |
| 218 | + self._read_count += 1 |
| 219 | + return super().read(-1 if size is None else size) |
| 220 | + |
| 221 | + def close(self) -> None: |
| 222 | + try: |
| 223 | + super().close() |
| 224 | + finally: |
| 225 | + if self._owned_fd is not None: |
| 226 | + os.close(self._owned_fd) |
| 227 | + self._owned_fd = None |
| 228 | + |
| 229 | + |
157 | 230 | def _symlink_or_skip(path: Path, target: Path, *, target_is_directory: bool = False) -> None: |
158 | 231 | try: |
159 | 232 | path.symlink_to(target, target_is_directory=target_is_directory) |
@@ -184,6 +257,28 @@ async def test_base_sandbox_session_uses_current_working_directory_for_local_fil |
184 | 257 | assert session.writes[Path("/workspace/copied.txt")] == b"hello" |
185 | 258 |
|
186 | 259 |
|
| 260 | +@pytest.mark.asyncio |
| 261 | +async def test_local_file_checksum_matches_written_bytes_when_source_changes( |
| 262 | + tmp_path: Path, |
| 263 | +) -> None: |
| 264 | + source = tmp_path / "source.txt" |
| 265 | + source.write_bytes(b"original") |
| 266 | + |
| 267 | + def mutate_source() -> None: |
| 268 | + source.write_bytes(b"mutated") |
| 269 | + |
| 270 | + session = _ChunkedMutatingWriteSession(mutate_source) |
| 271 | + |
| 272 | + result = await LocalFile(src=Path("source.txt")).apply( |
| 273 | + session, |
| 274 | + Path("/workspace/copied.txt"), |
| 275 | + tmp_path, |
| 276 | + ) |
| 277 | + |
| 278 | + written = session.writes[Path("/workspace/copied.txt")] |
| 279 | + assert result[0].sha256 == hashlib.sha256(written).hexdigest() |
| 280 | + |
| 281 | + |
187 | 282 | @pytest.mark.asyncio |
188 | 283 | async def test_local_file_rejects_symlinked_source_ancestors(tmp_path: Path) -> None: |
189 | 284 | target_dir = tmp_path / "secret-dir" |
@@ -271,6 +366,93 @@ async def test_local_dir_copy_falls_back_when_safe_dir_fd_open_unavailable( |
271 | 366 | assert session.writes[Path("/workspace/copied/safe.txt")] == b"safe" |
272 | 367 |
|
273 | 368 |
|
| 369 | +@pytest.mark.asyncio |
| 370 | +async def test_local_dir_checksum_matches_written_bytes_when_source_changes( |
| 371 | + tmp_path: Path, |
| 372 | +) -> None: |
| 373 | + src_root = tmp_path / "src" |
| 374 | + src_root.mkdir() |
| 375 | + src_file = src_root / "safe.txt" |
| 376 | + src_file.write_bytes(b"original") |
| 377 | + |
| 378 | + def mutate_source() -> None: |
| 379 | + src_file.write_bytes(b"mutated") |
| 380 | + |
| 381 | + session = _ChunkedMutatingWriteSession(mutate_source) |
| 382 | + local_dir = LocalDir(src=Path("src")) |
| 383 | + |
| 384 | + result = await local_dir._copy_local_dir_file( |
| 385 | + base_dir=tmp_path, |
| 386 | + session=session, |
| 387 | + src_root=src_root, |
| 388 | + src=src_file, |
| 389 | + dest_root=Path("/workspace/copied"), |
| 390 | + ) |
| 391 | + |
| 392 | + written = session.writes[Path("/workspace/copied/safe.txt")] |
| 393 | + assert result.sha256 == hashlib.sha256(written).hexdigest() |
| 394 | + |
| 395 | + |
| 396 | +@pytest.mark.asyncio |
| 397 | +async def test_local_file_preserves_local_read_error_when_write_wraps_stream_failures( |
| 398 | + monkeypatch: pytest.MonkeyPatch, |
| 399 | + tmp_path: Path, |
| 400 | +) -> None: |
| 401 | + source = (tmp_path / "source.txt").resolve() |
| 402 | + source.write_bytes(b"original") |
| 403 | + session = _PayloadWrappingWriteSession() |
| 404 | + |
| 405 | + def failing_fdopen( |
| 406 | + fd: int, |
| 407 | + *args: object, |
| 408 | + **kwargs: object, |
| 409 | + ) -> io.IOBase: |
| 410 | + _ = args, kwargs |
| 411 | + return _FailAfterChunkStream(b"original", owned_fd=fd) |
| 412 | + |
| 413 | + monkeypatch.setattr("agents.sandbox.entries.artifacts.os.fdopen", failing_fdopen) |
| 414 | + |
| 415 | + with pytest.raises(LocalFileReadError) as excinfo: |
| 416 | + await LocalFile(src=Path("source.txt")).apply( |
| 417 | + session, |
| 418 | + Path("/workspace/copied.txt"), |
| 419 | + tmp_path, |
| 420 | + ) |
| 421 | + |
| 422 | + assert excinfo.value.context["src"] == str(source) |
| 423 | + assert isinstance(excinfo.value.cause, OSError) |
| 424 | + |
| 425 | + |
| 426 | +@pytest.mark.asyncio |
| 427 | +async def test_local_dir_copy_preserves_local_read_error_when_write_wraps_stream_failures( |
| 428 | + monkeypatch: pytest.MonkeyPatch, |
| 429 | + tmp_path: Path, |
| 430 | +) -> None: |
| 431 | + src_root = tmp_path / "src" |
| 432 | + src_root.mkdir() |
| 433 | + src_file = (src_root / "safe.txt").resolve() |
| 434 | + src_file.write_bytes(b"original") |
| 435 | + session = _PayloadWrappingWriteSession() |
| 436 | + local_dir = LocalDir(src=Path("src")) |
| 437 | + |
| 438 | + def failing_fdopen(fd: int, *args: object, **kwargs: object) -> io.IOBase: |
| 439 | + return _FailAfterChunkStream(b"original", owned_fd=fd) |
| 440 | + |
| 441 | + monkeypatch.setattr("agents.sandbox.entries.artifacts.os.fdopen", failing_fdopen) |
| 442 | + |
| 443 | + with pytest.raises(LocalFileReadError) as excinfo: |
| 444 | + await local_dir._copy_local_dir_file( |
| 445 | + base_dir=tmp_path, |
| 446 | + session=session, |
| 447 | + src_root=src_root, |
| 448 | + src=src_file, |
| 449 | + dest_root=Path("/workspace/copied"), |
| 450 | + ) |
| 451 | + |
| 452 | + assert excinfo.value.context["src"] == str(src_file) |
| 453 | + assert isinstance(excinfo.value.cause, OSError) |
| 454 | + |
| 455 | + |
274 | 456 | @pytest.mark.asyncio |
275 | 457 | async def test_local_dir_copy_revalidates_swapped_paths_during_open( |
276 | 458 | monkeypatch: pytest.MonkeyPatch, |
|
0 commit comments