|
11 | 11 | import agents.sandbox.entries.artifacts as artifacts_module |
12 | 12 | from agents.sandbox import SandboxConcurrencyLimits |
13 | 13 | from agents.sandbox.entries import Dir, File, GitRepo, LocalDir, LocalFile |
14 | | -from agents.sandbox.errors import ExecNonZeroError, LocalDirReadError, LocalFileReadError |
| 14 | +from agents.sandbox.errors import ( |
| 15 | + ExecNonZeroError, |
| 16 | + LocalDirReadError, |
| 17 | + LocalFileReadError, |
| 18 | + WorkspaceArchiveWriteError, |
| 19 | +) |
15 | 20 | from agents.sandbox.manifest import Manifest |
16 | 21 | from agents.sandbox.materialization import MaterializedFile |
17 | 22 | from agents.sandbox.session.base_sandbox_session import BaseSandboxSession |
| 23 | +from agents.sandbox.session.workspace_payloads import coerce_write_payload |
18 | 24 | from agents.sandbox.snapshot import NoopSnapshot |
19 | 25 | from agents.sandbox.types import ExecResult, User |
20 | 26 | from tests.utils.factories import TestSessionState |
@@ -99,6 +105,77 @@ async def _exec_internal( |
99 | 105 | return ExecResult(stdout=b"", stderr=b"", exit_code=0) |
100 | 106 |
|
101 | 107 |
|
class _MutatingWriteSession(_RecordingSession):
    """Recording session that fires a callback once, just before its first write.

    The callback runs before the call is handed to the parent ``write``, so a
    test can change the source after it was opened but before the parent
    implementation consumes ``data``.
    """

    def __init__(self, mutate_before_read: Callable[[], None]) -> None:
        super().__init__()
        self._mutate_before_read = mutate_before_read
        self._mutated = False

    def _run_mutation_once(self) -> None:
        """Invoke the stored callback unless it has already completed once."""
        if self._mutated:
            return
        self._mutate_before_read()
        # Only flagged after the callback returns, so a raising callback is retried.
        self._mutated = True

    async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None:
        """Trigger the one-shot mutation, then delegate to the parent ``write``."""
        self._run_mutation_once()
        await super().write(path, data, user=user)
| 119 | + |
| 120 | + |
class _ChunkedMutatingWriteSession(_RecordingSession):
    """Recording session that reads the payload in two steps with a mutation in between.

    ``write`` reads 4 bytes, runs the callback (first call only), reads the
    remainder, and stores the concatenation in ``self.writes[path]``.
    """

    def __init__(self, mutate_after_first_chunk: Callable[[], None]) -> None:
        super().__init__()
        self._mutate_after_first_chunk = mutate_after_first_chunk
        self._mutated = False

    async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None:
        """Consume ``data`` in two reads and record the bytes that were actually read."""
        _ = user
        head = data.read(4)
        if not self._mutated:
            # The source is changed while the stream is only partially consumed.
            self._mutate_after_first_chunk()
            self._mutated = True
        tail = data.read()
        # Non-bytes read results (e.g. None) are dropped rather than joined.
        self.writes[path] = b"".join(
            piece for piece in (head, tail) if isinstance(piece, bytes)
        )
| 140 | + |
| 141 | + |
class _PayloadWrappingWriteSession(_RecordingSession):
    """Recording session whose ``write`` wraps any stream failure in an archive-write error.

    The payload is drained 4 bytes at a time. Any exception raised while
    reading is re-raised as ``WorkspaceArchiveWriteError`` with the original
    exception as its cause.
    """

    async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None:
        """Drain the coerced payload stream and record its bytes under ``path``."""
        _ = user
        payload = coerce_write_payload(path=path, data=data)
        collected: list[bytes] = []
        try:
            # Stop on the first falsy chunk, which is how end-of-stream is detected.
            while piece := payload.stream.read(4):
                collected.append(piece)
        except Exception as e:
            raise WorkspaceArchiveWriteError(path=path, cause=e) from e
        self.writes[path] = b"".join(collected)
| 156 | + |
| 157 | + |
| 158 | +class _FailAfterChunkStream(io.BytesIO): |
| 159 | + def __init__(self, data: bytes, *, owned_fd: int | None = None) -> None: |
| 160 | + super().__init__(data) |
| 161 | + self._owned_fd = owned_fd |
| 162 | + self._read_count = 0 |
| 163 | + |
| 164 | + def read(self, size: int | None = -1) -> bytes: |
| 165 | + if self._read_count > 0: |
| 166 | + raise OSError("source read failed") |
| 167 | + self._read_count += 1 |
| 168 | + return super().read(-1 if size is None else size) |
| 169 | + |
| 170 | + def close(self) -> None: |
| 171 | + try: |
| 172 | + super().close() |
| 173 | + finally: |
| 174 | + if self._owned_fd is not None: |
| 175 | + os.close(self._owned_fd) |
| 176 | + self._owned_fd = None |
| 177 | + |
| 178 | + |
102 | 179 | def _symlink_or_skip(path: Path, target: Path, *, target_is_directory: bool = False) -> None: |
103 | 180 | try: |
104 | 181 | path.symlink_to(target, target_is_directory=target_is_directory) |
@@ -129,6 +206,28 @@ async def test_base_sandbox_session_uses_current_working_directory_for_local_fil |
129 | 206 | assert session.writes[Path("/workspace/copied.txt")] == b"hello" |
130 | 207 |
|
131 | 208 |
|
@pytest.mark.asyncio
async def test_local_file_checksum_matches_written_bytes_when_source_changes(
    tmp_path: Path,
) -> None:
    """The reported sha256 must describe the bytes the session actually received.

    The source file is overwritten after the session has read its first chunk,
    so the checksum and the recorded payload are compared against each other
    rather than against a fixed value.
    """
    source_path = tmp_path / "source.txt"
    source_path.write_bytes(b"original")
    destination = Path("/workspace/copied.txt")

    def overwrite_source() -> None:
        source_path.write_bytes(b"mutated")

    session = _ChunkedMutatingWriteSession(overwrite_source)
    entry = LocalFile(src=Path("source.txt"))

    materialized = await entry.apply(session, destination, tmp_path)

    recorded = session.writes[destination]
    assert materialized[0].sha256 == hashlib.sha256(recorded).hexdigest()
| 229 | + |
| 230 | + |
132 | 231 | @pytest.mark.asyncio |
133 | 232 | async def test_local_file_rejects_symlinked_source_ancestors(tmp_path: Path) -> None: |
134 | 233 | target_dir = tmp_path / "secret-dir" |
@@ -216,6 +315,93 @@ async def test_local_dir_copy_falls_back_when_safe_dir_fd_open_unavailable( |
216 | 315 | assert session.writes[Path("/workspace/copied/safe.txt")] == b"safe" |
217 | 316 |
|
218 | 317 |
|
@pytest.mark.asyncio
async def test_local_dir_checksum_matches_written_bytes_when_source_changes(
    tmp_path: Path,
) -> None:
    """A file copied out of a local directory reports the sha256 of the bytes written.

    The file is overwritten after the session has read its first chunk, so the
    checksum is compared against whatever the session recorded.
    """
    source_dir = tmp_path / "src"
    source_dir.mkdir()
    source_file = source_dir / "safe.txt"
    source_file.write_bytes(b"original")

    def overwrite_source() -> None:
        source_file.write_bytes(b"mutated")

    session = _ChunkedMutatingWriteSession(overwrite_source)
    entry = LocalDir(src=Path("src"))

    materialized = await entry._copy_local_dir_file(
        base_dir=tmp_path,
        session=session,
        src_root=source_dir,
        src=source_file,
        dest_root=Path("/workspace/copied"),
    )

    recorded = session.writes[Path("/workspace/copied/safe.txt")]
    assert materialized.sha256 == hashlib.sha256(recorded).hexdigest()
| 343 | + |
| 344 | + |
@pytest.mark.asyncio
async def test_local_file_preserves_local_read_error_when_write_wraps_stream_failures(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """A source-read failure surfaces as ``LocalFileReadError`` even if ``write`` wraps it.

    ``os.fdopen`` is replaced so the source stream fails on its second read,
    while the session's ``write`` converts stream errors into
    ``WorkspaceArchiveWriteError``.
    """
    source_path = (tmp_path / "source.txt").resolve()
    source_path.write_bytes(b"original")
    session = _PayloadWrappingWriteSession()

    def failing_fdopen(
        fd: int,
        *args: object,
        **kwargs: object,
    ) -> io.IOBase:
        _ = args, kwargs
        # The replacement stream takes over the real descriptor so it still gets closed.
        return _FailAfterChunkStream(b"original", owned_fd=fd)

    monkeypatch.setattr("agents.sandbox.entries.artifacts.os.fdopen", failing_fdopen)
    entry = LocalFile(src=Path("source.txt"))

    with pytest.raises(LocalFileReadError) as raised:
        await entry.apply(
            session,
            Path("/workspace/copied.txt"),
            tmp_path,
        )

    error = raised.value
    assert error.context["src"] == str(source_path)
    assert isinstance(error.cause, OSError)
| 373 | + |
| 374 | + |
@pytest.mark.asyncio
async def test_local_dir_copy_preserves_local_read_error_when_write_wraps_stream_failures(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """A failing source read during a directory-file copy raises ``LocalFileReadError``.

    ``os.fdopen`` is replaced so the source stream fails on its second read,
    while the session's ``write`` converts stream errors into
    ``WorkspaceArchiveWriteError``.
    """
    source_dir = tmp_path / "src"
    source_dir.mkdir()
    source_file = (source_dir / "safe.txt").resolve()
    source_file.write_bytes(b"original")
    session = _PayloadWrappingWriteSession()
    entry = LocalDir(src=Path("src"))

    def failing_fdopen(fd: int, *args: object, **kwargs: object) -> io.IOBase:
        _ = args, kwargs
        # The replacement stream takes over the real descriptor so it still gets closed.
        return _FailAfterChunkStream(b"original", owned_fd=fd)

    monkeypatch.setattr("agents.sandbox.entries.artifacts.os.fdopen", failing_fdopen)

    with pytest.raises(LocalFileReadError) as raised:
        await entry._copy_local_dir_file(
            base_dir=tmp_path,
            session=session,
            src_root=source_dir,
            src=source_file,
            dest_root=Path("/workspace/copied"),
        )

    error = raised.value
    assert error.context["src"] == str(source_file)
    assert isinstance(error.cause, OSError)
| 403 | + |
| 404 | + |
219 | 405 | @pytest.mark.asyncio |
220 | 406 | async def test_local_dir_copy_revalidates_swapped_paths_during_open( |
221 | 407 | monkeypatch: pytest.MonkeyPatch, |
|
0 commit comments