Skip to content

Commit 344bb47

Browse files
committed
fix(sandbox): Keep artifact checksums in sync
Buffer local artifact contents while hashing so recorded checksums match the exact bytes written into the sandbox. Cover both LocalFile and LocalDir with focused mutation tests.
1 parent da82b2c commit 344bb47

2 files changed

Lines changed: 89 additions & 17 deletions

File tree

src/agents/sandbox/entries/artifacts.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,23 @@
66
import os
77
import re
88
import stat
9+
import tempfile
910
import uuid
1011
from collections.abc import Awaitable, Callable, Mapping
1112
from pathlib import Path
12-
from typing import TYPE_CHECKING, Literal
13+
from typing import TYPE_CHECKING, Literal, cast
1314

1415
from pydantic import Field, field_serializer, field_validator
1516

1617
from ..errors import (
1718
GitCloneError,
1819
GitCopyError,
1920
GitMissingInImageError,
20-
LocalChecksumError,
2121
LocalDirReadError,
2222
LocalFileReadError,
2323
)
2424
from ..materialization import MaterializedFile, gather_in_order
2525
from ..types import ExecResult, User
26-
from ..util.checksums import sha256_file
2726
from .base import BaseEntry
2827

2928
if TYPE_CHECKING:
@@ -32,16 +31,27 @@
3231
_COMMIT_REF_RE = re.compile(r"[0-9a-fA-F]{7,40}")
3332
_OPEN_SUPPORTS_DIR_FD = os.open in os.supports_dir_fd
3433
_HAS_O_DIRECTORY = hasattr(os, "O_DIRECTORY")
34+
_LOCAL_COPY_BUFFER_MAX_MEMORY_BYTES = 8 * 1024 * 1024
3535

3636

37-
def _sha256_handle(handle: io.BufferedReader) -> str:
37+
def _buffer_handle_with_checksum(handle: io.BufferedReader) -> tuple[io.IOBase, str]:
3838
digest = hashlib.sha256()
39-
while True:
40-
chunk = handle.read(1024 * 1024)
41-
if not chunk:
42-
break
43-
digest.update(chunk)
44-
return digest.hexdigest()
39+
buffer = tempfile.SpooledTemporaryFile(
40+
max_size=_LOCAL_COPY_BUFFER_MAX_MEMORY_BYTES,
41+
mode="w+b",
42+
)
43+
try:
44+
while True:
45+
chunk = handle.read(1024 * 1024)
46+
if not chunk:
47+
break
48+
digest.update(chunk)
49+
buffer.write(chunk)
50+
buffer.seek(0)
51+
return cast(io.IOBase, buffer), digest.hexdigest()
52+
except BaseException:
53+
buffer.close()
54+
raise
4555

4656

4757
class Dir(BaseEntry):
@@ -111,13 +121,16 @@ async def apply(
111121
) -> list[MaterializedFile]:
112122
src = (base_dir / self.src).resolve()
113123
try:
114-
checksum = sha256_file(src)
124+
with src.open("rb") as f:
125+
buffered, checksum = _buffer_handle_with_checksum(f)
115126
except OSError as e:
116-
raise LocalChecksumError(src=src, cause=e) from e
127+
raise LocalFileReadError(src=src, cause=e) from e
117128
await session.mkdir(Path(dest).parent, parents=True)
118129
try:
119-
with src.open("rb") as f:
120-
await session.write(dest, f)
130+
try:
131+
await session.write(dest, buffered)
132+
finally:
133+
buffered.close()
121134
except OSError as e:
122135
raise LocalFileReadError(src=src, cause=e) from e
123136
await self._apply_metadata(session, dest)
@@ -349,10 +362,12 @@ async def _copy_local_dir_file(
349362
)
350363
with os.fdopen(fd, "rb") as f:
351364
fd = None
352-
checksum = _sha256_handle(f)
353-
f.seek(0)
365+
buffered, checksum = _buffer_handle_with_checksum(f)
354366
await session.mkdir(child_dest.parent, parents=True, user=user)
355-
await session.write(child_dest, f, user=user)
367+
try:
368+
await session.write(child_dest, buffered, user=user)
369+
finally:
370+
buffered.close()
356371
except OSError as e:
357372
raise LocalFileReadError(src=src, cause=e) from e
358373
finally:

tests/sandbox/test_entries.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import hashlib
34
import io
45
import os
56
from collections.abc import Awaitable, Callable, Sequence
@@ -98,6 +99,19 @@ async def _exec_internal(
9899
return ExecResult(stdout=b"", stderr=b"", exit_code=0)
99100

100101

102+
class _MutatingWriteSession(_RecordingSession):
103+
def __init__(self, mutate_before_read: Callable[[], None]) -> None:
104+
super().__init__()
105+
self._mutate_before_read = mutate_before_read
106+
self._mutated = False
107+
108+
async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None:
109+
if not self._mutated:
110+
self._mutate_before_read()
111+
self._mutated = True
112+
await super().write(path, data, user=user)
113+
114+
101115
@pytest.mark.asyncio
102116
async def test_base_sandbox_session_uses_current_working_directory_for_local_file_sources(
103117
monkeypatch: pytest.MonkeyPatch,
@@ -118,6 +132,25 @@ async def test_base_sandbox_session_uses_current_working_directory_for_local_fil
118132
assert session.writes[Path("/workspace/copied.txt")] == b"hello"
119133

120134

135+
@pytest.mark.asyncio
136+
async def test_local_file_checksum_matches_written_bytes_when_source_changes(
137+
tmp_path: Path,
138+
) -> None:
139+
source = tmp_path / "source.txt"
140+
source.write_bytes(b"original")
141+
session = _MutatingWriteSession(lambda: source.write_bytes(b"mutated"))
142+
143+
result = await LocalFile(src=Path("source.txt")).apply(
144+
session,
145+
Path("/workspace/copied.txt"),
146+
tmp_path,
147+
)
148+
149+
written = session.writes[Path("/workspace/copied.txt")]
150+
assert written == b"original"
151+
assert result[0].sha256 == hashlib.sha256(written).hexdigest()
152+
153+
121154
@pytest.mark.asyncio
122155
async def test_local_dir_copy_falls_back_when_safe_dir_fd_open_unavailable(
123156
monkeypatch: pytest.MonkeyPatch,
@@ -145,6 +178,30 @@ async def test_local_dir_copy_falls_back_when_safe_dir_fd_open_unavailable(
145178
assert session.writes[Path("/workspace/copied/safe.txt")] == b"safe"
146179

147180

181+
@pytest.mark.asyncio
182+
async def test_local_dir_checksum_matches_written_bytes_when_source_changes(
183+
tmp_path: Path,
184+
) -> None:
185+
src_root = tmp_path / "src"
186+
src_root.mkdir()
187+
src_file = src_root / "safe.txt"
188+
src_file.write_bytes(b"original")
189+
session = _MutatingWriteSession(lambda: src_file.write_bytes(b"mutated"))
190+
local_dir = LocalDir(src=Path("src"))
191+
192+
result = await local_dir._copy_local_dir_file(
193+
base_dir=tmp_path,
194+
session=session,
195+
src_root=src_root,
196+
src=src_file,
197+
dest_root=Path("/workspace/copied"),
198+
)
199+
200+
written = session.writes[Path("/workspace/copied/safe.txt")]
201+
assert written == b"original"
202+
assert result.sha256 == hashlib.sha256(written).hexdigest()
203+
204+
148205
@pytest.mark.asyncio
149206
async def test_local_dir_copy_revalidates_swapped_paths_during_open(
150207
monkeypatch: pytest.MonkeyPatch,

0 commit comments

Comments
 (0)