Skip to content

Commit 6407cd1

Browse files
committed
fix(sandbox): stream checksums after safe local file opens
1 parent 5d16cdd commit 6407cd1

2 files changed

Lines changed: 286 additions & 19 deletions

File tree

src/agents/sandbox/entries/artifacts.py

Lines changed: 104 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
GitCloneError,
1818
GitCopyError,
1919
GitMissingInImageError,
20-
LocalChecksumError,
20+
LocalArtifactError,
2121
LocalDirReadError,
2222
LocalFileReadError,
23+
WorkspaceArchiveWriteError,
2324
)
2425
from ..materialization import MaterializedFile, gather_in_order
2526
from ..types import ExecResult, User
@@ -33,14 +34,91 @@
3334
_HAS_O_DIRECTORY = hasattr(os, "O_DIRECTORY")
3435

3536

36-
def _sha256_handle(handle: io.BufferedReader) -> str:
37-
digest = hashlib.sha256()
38-
while True:
39-
chunk = handle.read(1024 * 1024)
37+
class _HashingReader(io.IOBase):
38+
def __init__(
39+
self,
40+
stream: io.BufferedReader,
41+
*,
42+
read_error_factory: Callable[[OSError], BaseException] | None = None,
43+
) -> None:
44+
self._stream = stream
45+
self._digest = hashlib.sha256()
46+
self._started = False
47+
self._finished = False
48+
self._read_error_factory = read_error_factory
49+
50+
def readable(self) -> bool:
51+
return True
52+
53+
def read(self, size: int = -1) -> bytes:
54+
try:
55+
chunk = self._stream.read(size)
56+
except OSError as e:
57+
if self._read_error_factory is not None:
58+
raise self._read_error_factory(e) from e
59+
raise
60+
if chunk is None:
61+
self._finished = True
62+
return b""
63+
if isinstance(chunk, bytearray):
64+
chunk = bytes(chunk)
65+
self._started = True
4066
if not chunk:
41-
break
42-
digest.update(chunk)
43-
return digest.hexdigest()
67+
self._finished = True
68+
return b""
69+
self._digest.update(chunk)
70+
if size < 0 or len(chunk) < size:
71+
self._finished = True
72+
return chunk
73+
74+
def readinto(self, b: bytearray) -> int:
75+
data = self.read(len(b))
76+
n = len(data)
77+
b[:n] = data
78+
return n
79+
80+
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
81+
if self._started:
82+
raise io.UnsupportedOperation("cannot seek after reads begin")
83+
try:
84+
return int(self._stream.seek(offset, whence))
85+
except OSError as e:
86+
if self._read_error_factory is not None:
87+
raise self._read_error_factory(e) from e
88+
raise
89+
90+
def tell(self) -> int:
91+
try:
92+
return int(self._stream.tell())
93+
except OSError as e:
94+
if self._read_error_factory is not None:
95+
raise self._read_error_factory(e) from e
96+
raise
97+
98+
def hexdigest(self) -> str:
99+
if not self._finished:
100+
raise RuntimeError("checksum is not available until the stream is fully consumed")
101+
return self._digest.hexdigest()
102+
103+
104+
def _find_nested_local_artifact_error(exc: BaseException) -> LocalArtifactError | None:
105+
seen: set[int] = set()
106+
current: BaseException | None = exc
107+
while current is not None and id(current) not in seen:
108+
if isinstance(current, LocalArtifactError):
109+
return current
110+
seen.add(id(current))
111+
next_exc = getattr(current, "cause", None)
112+
if not isinstance(next_exc, BaseException):
113+
next_exc = current.__cause__
114+
current = next_exc if isinstance(next_exc, BaseException) else None
115+
return None
116+
117+
118+
def _reraise_nested_local_artifact_error(exc: BaseException) -> None:
119+
nested_local_artifact_error = _find_nested_local_artifact_error(exc)
120+
if nested_local_artifact_error is not None:
121+
raise nested_local_artifact_error
44122

45123

46124
class Dir(BaseEntry):
@@ -122,13 +200,15 @@ async def apply(
122200
)
123201
with os.fdopen(fd, "rb") as f:
124202
fd = None
125-
try:
126-
checksum = _sha256_handle(f)
127-
f.seek(0)
128-
except OSError as e:
129-
raise LocalChecksumError(src=src, cause=e) from e
203+
hashing_reader = _HashingReader(
204+
f,
205+
read_error_factory=lambda e: LocalFileReadError(src=src, cause=e),
206+
)
130207
await session.mkdir(Path(dest).parent, parents=True)
131-
await session.write(dest, f)
208+
await session.write(dest, hashing_reader)
209+
except WorkspaceArchiveWriteError as e:
210+
_reraise_nested_local_artifact_error(e)
211+
raise
132212
except LocalDirReadError as e:
133213
context = dict(e.context)
134214
context.pop("src", None)
@@ -139,7 +219,7 @@ async def apply(
139219
if fd is not None:
140220
os.close(fd)
141221
await self._apply_metadata(session, dest)
142-
return [MaterializedFile(path=dest, sha256=checksum)]
222+
return [MaterializedFile(path=dest, sha256=hashing_reader.hexdigest())]
143223

144224

145225
class LocalDir(BaseEntry):
@@ -367,16 +447,21 @@ async def _copy_local_dir_file(
367447
)
368448
with os.fdopen(fd, "rb") as f:
369449
fd = None
370-
checksum = _sha256_handle(f)
371-
f.seek(0)
450+
hashing_reader = _HashingReader(
451+
f,
452+
read_error_factory=lambda e: LocalFileReadError(src=src, cause=e),
453+
)
372454
await session.mkdir(child_dest.parent, parents=True, user=user)
373-
await session.write(child_dest, f, user=user)
455+
await session.write(child_dest, hashing_reader, user=user)
456+
except WorkspaceArchiveWriteError as e:
457+
_reraise_nested_local_artifact_error(e)
458+
raise
374459
except OSError as e:
375460
raise LocalFileReadError(src=src, cause=e) from e
376461
finally:
377462
if fd is not None:
378463
os.close(fd)
379-
return MaterializedFile(path=child_dest, sha256=checksum)
464+
return MaterializedFile(path=child_dest, sha256=hashing_reader.hexdigest())
380465

381466
def _open_local_dir_file_for_copy(
382467
self, *, base_dir: Path, src_root: Path, rel_child: Path

tests/sandbox/test_entries.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
InvalidManifestPathError,
1717
LocalDirReadError,
1818
LocalFileReadError,
19+
WorkspaceArchiveWriteError,
1920
)
2021
from agents.sandbox.manifest import Manifest
2122
from agents.sandbox.materialization import MaterializedFile
2223
from agents.sandbox.session.base_sandbox_session import BaseSandboxSession
24+
from agents.sandbox.session.workspace_payloads import coerce_write_payload
2325
from agents.sandbox.snapshot import NoopSnapshot
2426
from agents.sandbox.types import ExecResult, User
2527
from tests.utils.factories import TestSessionState
@@ -154,6 +156,77 @@ def test_resolve_workspace_path_rejects_absolute_symlink_escape_for_host_root(
154156
assert exc_info.value.context == {"rel": escaped.as_posix(), "reason": "absolute"}
155157

156158

159+
class _MutatingWriteSession(_RecordingSession):
160+
def __init__(self, mutate_before_read: Callable[[], None]) -> None:
161+
super().__init__()
162+
self._mutate_before_read = mutate_before_read
163+
self._mutated = False
164+
165+
async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None:
166+
if not self._mutated:
167+
self._mutate_before_read()
168+
self._mutated = True
169+
await super().write(path, data, user=user)
170+
171+
172+
class _ChunkedMutatingWriteSession(_RecordingSession):
173+
def __init__(self, mutate_after_first_chunk: Callable[[], None]) -> None:
174+
super().__init__()
175+
self._mutate_after_first_chunk = mutate_after_first_chunk
176+
self._mutated = False
177+
178+
async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None:
179+
_ = user
180+
chunks: list[bytes] = []
181+
first = data.read(4)
182+
if isinstance(first, bytes):
183+
chunks.append(first)
184+
if not self._mutated:
185+
self._mutate_after_first_chunk()
186+
self._mutated = True
187+
rest = data.read()
188+
if isinstance(rest, bytes):
189+
chunks.append(rest)
190+
self.writes[path] = b"".join(chunks)
191+
192+
193+
class _PayloadWrappingWriteSession(_RecordingSession):
194+
async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None:
195+
_ = user
196+
payload = coerce_write_payload(path=path, data=data)
197+
chunks: list[bytes] = []
198+
try:
199+
while True:
200+
chunk = payload.stream.read(4)
201+
if not chunk:
202+
break
203+
chunks.append(chunk)
204+
except Exception as e:
205+
raise WorkspaceArchiveWriteError(path=path, cause=e) from e
206+
self.writes[path] = b"".join(chunks)
207+
208+
209+
class _FailAfterChunkStream(io.BytesIO):
210+
def __init__(self, data: bytes, *, owned_fd: int | None = None) -> None:
211+
super().__init__(data)
212+
self._owned_fd = owned_fd
213+
self._read_count = 0
214+
215+
def read(self, size: int | None = -1) -> bytes:
216+
if self._read_count > 0:
217+
raise OSError("source read failed")
218+
self._read_count += 1
219+
return super().read(-1 if size is None else size)
220+
221+
def close(self) -> None:
222+
try:
223+
super().close()
224+
finally:
225+
if self._owned_fd is not None:
226+
os.close(self._owned_fd)
227+
self._owned_fd = None
228+
229+
157230
def _symlink_or_skip(path: Path, target: Path, *, target_is_directory: bool = False) -> None:
158231
try:
159232
path.symlink_to(target, target_is_directory=target_is_directory)
@@ -184,6 +257,28 @@ async def test_base_sandbox_session_uses_current_working_directory_for_local_fil
184257
assert session.writes[Path("/workspace/copied.txt")] == b"hello"
185258

186259

260+
@pytest.mark.asyncio
261+
async def test_local_file_checksum_matches_written_bytes_when_source_changes(
262+
tmp_path: Path,
263+
) -> None:
264+
source = tmp_path / "source.txt"
265+
source.write_bytes(b"original")
266+
267+
def mutate_source() -> None:
268+
source.write_bytes(b"mutated")
269+
270+
session = _ChunkedMutatingWriteSession(mutate_source)
271+
272+
result = await LocalFile(src=Path("source.txt")).apply(
273+
session,
274+
Path("/workspace/copied.txt"),
275+
tmp_path,
276+
)
277+
278+
written = session.writes[Path("/workspace/copied.txt")]
279+
assert result[0].sha256 == hashlib.sha256(written).hexdigest()
280+
281+
187282
@pytest.mark.asyncio
188283
async def test_local_file_rejects_symlinked_source_ancestors(tmp_path: Path) -> None:
189284
target_dir = tmp_path / "secret-dir"
@@ -271,6 +366,93 @@ async def test_local_dir_copy_falls_back_when_safe_dir_fd_open_unavailable(
271366
assert session.writes[Path("/workspace/copied/safe.txt")] == b"safe"
272367

273368

369+
@pytest.mark.asyncio
370+
async def test_local_dir_checksum_matches_written_bytes_when_source_changes(
371+
tmp_path: Path,
372+
) -> None:
373+
src_root = tmp_path / "src"
374+
src_root.mkdir()
375+
src_file = src_root / "safe.txt"
376+
src_file.write_bytes(b"original")
377+
378+
def mutate_source() -> None:
379+
src_file.write_bytes(b"mutated")
380+
381+
session = _ChunkedMutatingWriteSession(mutate_source)
382+
local_dir = LocalDir(src=Path("src"))
383+
384+
result = await local_dir._copy_local_dir_file(
385+
base_dir=tmp_path,
386+
session=session,
387+
src_root=src_root,
388+
src=src_file,
389+
dest_root=Path("/workspace/copied"),
390+
)
391+
392+
written = session.writes[Path("/workspace/copied/safe.txt")]
393+
assert result.sha256 == hashlib.sha256(written).hexdigest()
394+
395+
396+
@pytest.mark.asyncio
397+
async def test_local_file_preserves_local_read_error_when_write_wraps_stream_failures(
398+
monkeypatch: pytest.MonkeyPatch,
399+
tmp_path: Path,
400+
) -> None:
401+
source = (tmp_path / "source.txt").resolve()
402+
source.write_bytes(b"original")
403+
session = _PayloadWrappingWriteSession()
404+
405+
def failing_fdopen(
406+
fd: int,
407+
*args: object,
408+
**kwargs: object,
409+
) -> io.IOBase:
410+
_ = args, kwargs
411+
return _FailAfterChunkStream(b"original", owned_fd=fd)
412+
413+
monkeypatch.setattr("agents.sandbox.entries.artifacts.os.fdopen", failing_fdopen)
414+
415+
with pytest.raises(LocalFileReadError) as excinfo:
416+
await LocalFile(src=Path("source.txt")).apply(
417+
session,
418+
Path("/workspace/copied.txt"),
419+
tmp_path,
420+
)
421+
422+
assert excinfo.value.context["src"] == str(source)
423+
assert isinstance(excinfo.value.cause, OSError)
424+
425+
426+
@pytest.mark.asyncio
427+
async def test_local_dir_copy_preserves_local_read_error_when_write_wraps_stream_failures(
428+
monkeypatch: pytest.MonkeyPatch,
429+
tmp_path: Path,
430+
) -> None:
431+
src_root = tmp_path / "src"
432+
src_root.mkdir()
433+
src_file = (src_root / "safe.txt").resolve()
434+
src_file.write_bytes(b"original")
435+
session = _PayloadWrappingWriteSession()
436+
local_dir = LocalDir(src=Path("src"))
437+
438+
def failing_fdopen(fd: int, *args: object, **kwargs: object) -> io.IOBase:
439+
return _FailAfterChunkStream(b"original", owned_fd=fd)
440+
441+
monkeypatch.setattr("agents.sandbox.entries.artifacts.os.fdopen", failing_fdopen)
442+
443+
with pytest.raises(LocalFileReadError) as excinfo:
444+
await local_dir._copy_local_dir_file(
445+
base_dir=tmp_path,
446+
session=session,
447+
src_root=src_root,
448+
src=src_file,
449+
dest_root=Path("/workspace/copied"),
450+
)
451+
452+
assert excinfo.value.context["src"] == str(src_file)
453+
assert isinstance(excinfo.value.cause, OSError)
454+
455+
274456
@pytest.mark.asyncio
275457
async def test_local_dir_copy_revalidates_swapped_paths_during_open(
276458
monkeypatch: pytest.MonkeyPatch,

0 commit comments

Comments
 (0)