Skip to content

Commit 23168ab

Browse files
committed
fix(sandbox): Stream artifact checksum writes
Avoid spooling full artifact contents into a temporary file before writing. Instead, hash the stream as it is consumed by session.write, and update the checksum tests to assert that the recorded hashes match the bytes that were actually materialized.
1 parent 344bb47 commit 23168ab

2 files changed

Lines changed: 77 additions & 43 deletions

File tree

src/agents/sandbox/entries/artifacts.py

Lines changed: 54 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
import os
77
import re
88
import stat
9-
import tempfile
109
import uuid
1110
from collections.abc import Awaitable, Callable, Mapping
1211
from pathlib import Path
13-
from typing import TYPE_CHECKING, Literal, cast
12+
from typing import TYPE_CHECKING, Literal
1413

1514
from pydantic import Field, field_serializer, field_validator
1615

@@ -31,27 +30,52 @@
3130
_COMMIT_REF_RE = re.compile(r"[0-9a-fA-F]{7,40}")
3231
_OPEN_SUPPORTS_DIR_FD = os.open in os.supports_dir_fd
3332
_HAS_O_DIRECTORY = hasattr(os, "O_DIRECTORY")
34-
_LOCAL_COPY_BUFFER_MAX_MEMORY_BYTES = 8 * 1024 * 1024
35-
36-
37-
def _buffer_handle_with_checksum(handle: io.BufferedReader) -> tuple[io.IOBase, str]:
38-
digest = hashlib.sha256()
39-
buffer = tempfile.SpooledTemporaryFile(
40-
max_size=_LOCAL_COPY_BUFFER_MAX_MEMORY_BYTES,
41-
mode="w+b",
42-
)
43-
try:
44-
while True:
45-
chunk = handle.read(1024 * 1024)
46-
if not chunk:
47-
break
48-
digest.update(chunk)
49-
buffer.write(chunk)
50-
buffer.seek(0)
51-
return cast(io.IOBase, buffer), digest.hexdigest()
52-
except BaseException:
53-
buffer.close()
54-
raise
33+
34+
35+
class _HashingReader(io.IOBase):
36+
def __init__(self, stream: io.BufferedReader) -> None:
37+
self._stream = stream
38+
self._digest = hashlib.sha256()
39+
self._started = False
40+
self._finished = False
41+
42+
def readable(self) -> bool:
43+
return True
44+
45+
def read(self, size: int = -1) -> bytes:
46+
chunk = self._stream.read(size)
47+
if chunk is None:
48+
self._finished = True
49+
return b""
50+
if isinstance(chunk, bytearray):
51+
chunk = bytes(chunk)
52+
self._started = True
53+
if not chunk:
54+
self._finished = True
55+
return b""
56+
self._digest.update(chunk)
57+
if size < 0 or len(chunk) < size:
58+
self._finished = True
59+
return chunk
60+
61+
def readinto(self, b: bytearray) -> int:
62+
data = self.read(len(b))
63+
n = len(data)
64+
b[:n] = data
65+
return n
66+
67+
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
68+
if self._started:
69+
raise io.UnsupportedOperation("cannot seek after reads begin")
70+
return int(self._stream.seek(offset, whence))
71+
72+
def tell(self) -> int:
73+
return int(self._stream.tell())
74+
75+
def hexdigest(self) -> str:
76+
if not self._finished:
77+
raise RuntimeError("checksum is not available until the stream is fully consumed")
78+
return self._digest.hexdigest()
5579

5680

5781
class Dir(BaseEntry):
@@ -120,21 +144,15 @@ async def apply(
120144
base_dir: Path,
121145
) -> list[MaterializedFile]:
122146
src = (base_dir / self.src).resolve()
123-
try:
124-
with src.open("rb") as f:
125-
buffered, checksum = _buffer_handle_with_checksum(f)
126-
except OSError as e:
127-
raise LocalFileReadError(src=src, cause=e) from e
128147
await session.mkdir(Path(dest).parent, parents=True)
129148
try:
130-
try:
131-
await session.write(dest, buffered)
132-
finally:
133-
buffered.close()
149+
with src.open("rb") as f:
150+
hashing_reader = _HashingReader(f)
151+
await session.write(dest, hashing_reader)
134152
except OSError as e:
135153
raise LocalFileReadError(src=src, cause=e) from e
136154
await self._apply_metadata(session, dest)
137-
return [MaterializedFile(path=dest, sha256=checksum)]
155+
return [MaterializedFile(path=dest, sha256=hashing_reader.hexdigest())]
138156

139157

140158
class LocalDir(BaseEntry):
@@ -362,18 +380,15 @@ async def _copy_local_dir_file(
362380
)
363381
with os.fdopen(fd, "rb") as f:
364382
fd = None
365-
buffered, checksum = _buffer_handle_with_checksum(f)
383+
hashing_reader = _HashingReader(f)
366384
await session.mkdir(child_dest.parent, parents=True, user=user)
367-
try:
368-
await session.write(child_dest, buffered, user=user)
369-
finally:
370-
buffered.close()
385+
await session.write(child_dest, hashing_reader, user=user)
371386
except OSError as e:
372387
raise LocalFileReadError(src=src, cause=e) from e
373388
finally:
374389
if fd is not None:
375390
os.close(fd)
376-
return MaterializedFile(path=child_dest, sha256=checksum)
391+
return MaterializedFile(path=child_dest, sha256=hashing_reader.hexdigest())
377392

378393
def _open_local_dir_file_for_copy(
379394
self, *, base_dir: Path, src_root: Path, rel_child: Path

tests/sandbox/test_entries.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,27 @@ async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> No
112112
await super().write(path, data, user=user)
113113

114114

115+
class _ChunkedMutatingWriteSession(_RecordingSession):
    """Recording session whose write() consumes the stream in two chunks,
    running a caller-supplied mutation between them (first write only) to
    exercise checksumming of a source that changes mid-stream."""

    def __init__(self, mutate_after_first_chunk: Callable[[], None]) -> None:
        super().__init__()
        self._mutate_after_first_chunk = mutate_after_first_chunk
        self._mutated = False

    async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None:
        _ = user
        collected = bytearray()
        head = data.read(4)
        if isinstance(head, bytes):
            collected += head
        # Mutate the source exactly once, after the first chunk is out.
        if not self._mutated:
            self._mutate_after_first_chunk()
            self._mutated = True
        tail = data.read()
        if isinstance(tail, bytes):
            collected += tail
        self.writes[path] = bytes(collected)
134+
135+
115136
@pytest.mark.asyncio
116137
async def test_base_sandbox_session_uses_current_working_directory_for_local_file_sources(
117138
monkeypatch: pytest.MonkeyPatch,
@@ -138,7 +159,7 @@ async def test_local_file_checksum_matches_written_bytes_when_source_changes(
138159
) -> None:
139160
source = tmp_path / "source.txt"
140161
source.write_bytes(b"original")
141-
session = _MutatingWriteSession(lambda: source.write_bytes(b"mutated"))
162+
session = _ChunkedMutatingWriteSession(lambda: source.write_bytes(b"mutated"))
142163

143164
result = await LocalFile(src=Path("source.txt")).apply(
144165
session,
@@ -147,7 +168,6 @@ async def test_local_file_checksum_matches_written_bytes_when_source_changes(
147168
)
148169

149170
written = session.writes[Path("/workspace/copied.txt")]
150-
assert written == b"original"
151171
assert result[0].sha256 == hashlib.sha256(written).hexdigest()
152172

153173

@@ -186,7 +206,7 @@ async def test_local_dir_checksum_matches_written_bytes_when_source_changes(
186206
src_root.mkdir()
187207
src_file = src_root / "safe.txt"
188208
src_file.write_bytes(b"original")
189-
session = _MutatingWriteSession(lambda: src_file.write_bytes(b"mutated"))
209+
session = _ChunkedMutatingWriteSession(lambda: src_file.write_bytes(b"mutated"))
190210
local_dir = LocalDir(src=Path("src"))
191211

192212
result = await local_dir._copy_local_dir_file(
@@ -198,7 +218,6 @@ async def test_local_dir_checksum_matches_written_bytes_when_source_changes(
198218
)
199219

200220
written = session.writes[Path("/workspace/copied/safe.txt")]
201-
assert written == b"original"
202221
assert result.sha256 == hashlib.sha256(written).hexdigest()
203222

204223

0 commit comments

Comments
 (0)