Skip to content

Commit 3aad7eb

Browse files
authored
refactor: share sandbox ephemeral mount lifecycle (#2986)
This pull request improves sandbox backend persistence by extracting the common ephemeral mount teardown and restore flow into a shared session helper. Cloudflare and Vercel persistence now use the shared lifecycle wrapper for persist and hydrate operations while preserving existing archive error precedence and corruption context metadata.
1 parent 4c5112c commit 3aad7eb

4 files changed

Lines changed: 276 additions & 180 deletions

File tree

src/agents/extensions/sandbox/cloudflare/sandbox.py

Lines changed: 15 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from ....sandbox.session.base_sandbox_session import BaseSandboxSession
4848
from ....sandbox.session.dependencies import Dependencies
4949
from ....sandbox.session.manager import Instrumentation
50+
from ....sandbox.session.mount_lifecycle import with_ephemeral_mounts_removed
5051
from ....sandbox.session.pty_types import (
5152
PTY_PROCESSES_MAX,
5253
PTY_PROCESSES_WARNING,
@@ -1204,91 +1205,23 @@ async def _hydrate_workspace_via_http(self, data: io.IOBase) -> None:
12041205

12051206
async def persist_workspace(self) -> io.IOBase:
12061207
root = self._workspace_root_path()
1207-
unmounted_mounts: list[tuple[Any, Path]] = []
1208-
unmount_error: WorkspaceArchiveReadError | None = None
1209-
for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets():
1210-
try:
1211-
await mount_entry.mount_strategy.teardown_for_snapshot(
1212-
mount_entry, self, mount_path
1213-
)
1214-
except Exception as e:
1215-
unmount_error = WorkspaceArchiveReadError(path=root, cause=e)
1216-
break
1217-
unmounted_mounts.append((mount_entry, mount_path))
1218-
1219-
snapshot_error: WorkspaceArchiveReadError | None = None
1220-
persisted: io.IOBase | None = None
1221-
if unmount_error is None:
1222-
try:
1223-
persisted = await self._persist_workspace_via_http()
1224-
except WorkspaceArchiveReadError as e:
1225-
snapshot_error = e
1226-
1227-
remount_error: WorkspaceArchiveReadError | None = None
1228-
for mount_entry, mount_path in reversed(unmounted_mounts):
1229-
try:
1230-
await mount_entry.mount_strategy.restore_after_snapshot(
1231-
mount_entry, self, mount_path
1232-
)
1233-
except Exception as e:
1234-
if remount_error is None:
1235-
remount_error = WorkspaceArchiveReadError(path=root, cause=e)
1236-
1237-
if remount_error is not None:
1238-
if snapshot_error is not None:
1239-
remount_error.context["snapshot_error_before_remount_corruption"] = {
1240-
"message": snapshot_error.message,
1241-
}
1242-
raise remount_error
1243-
if unmount_error is not None:
1244-
raise unmount_error
1245-
if snapshot_error is not None:
1246-
raise snapshot_error
1247-
1248-
assert persisted is not None
1249-
return persisted
1208+
return await with_ephemeral_mounts_removed(
1209+
self,
1210+
self._persist_workspace_via_http,
1211+
error_path=root,
1212+
error_cls=WorkspaceArchiveReadError,
1213+
operation_error_context_key="snapshot_error_before_remount_corruption",
1214+
)
12501215

12511216
async def hydrate_workspace(self, data: io.IOBase) -> None:
12521217
root = self._workspace_root_path()
1253-
unmounted_mounts: list[tuple[Any, Path]] = []
1254-
unmount_error: WorkspaceArchiveWriteError | None = None
1255-
for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets():
1256-
try:
1257-
await mount_entry.mount_strategy.teardown_for_snapshot(
1258-
mount_entry, self, mount_path
1259-
)
1260-
except Exception as e:
1261-
unmount_error = WorkspaceArchiveWriteError(path=root, cause=e)
1262-
break
1263-
unmounted_mounts.append((mount_entry, mount_path))
1264-
1265-
hydrate_error: WorkspaceArchiveWriteError | None = None
1266-
if unmount_error is None:
1267-
try:
1268-
await self._hydrate_workspace_via_http(data)
1269-
except WorkspaceArchiveWriteError as e:
1270-
hydrate_error = e
1271-
1272-
remount_error: WorkspaceArchiveWriteError | None = None
1273-
for mount_entry, mount_path in reversed(unmounted_mounts):
1274-
try:
1275-
await mount_entry.mount_strategy.restore_after_snapshot(
1276-
mount_entry, self, mount_path
1277-
)
1278-
except Exception as e:
1279-
if remount_error is None:
1280-
remount_error = WorkspaceArchiveWriteError(path=root, cause=e)
1281-
1282-
if remount_error is not None:
1283-
if hydrate_error is not None:
1284-
remount_error.context["hydrate_error_before_remount_corruption"] = {
1285-
"message": hydrate_error.message,
1286-
}
1287-
raise remount_error
1288-
if unmount_error is not None:
1289-
raise unmount_error
1290-
if hydrate_error is not None:
1291-
raise hydrate_error
1218+
await with_ephemeral_mounts_removed(
1219+
self,
1220+
lambda: self._hydrate_workspace_via_http(data),
1221+
error_path=root,
1222+
error_cls=WorkspaceArchiveWriteError,
1223+
operation_error_context_key="hydrate_error_before_remount_corruption",
1224+
)
12921225

12931226

12941227
class CloudflareSandboxClient(BaseSandboxClient[CloudflareSandboxClientOptions]):

src/agents/extensions/sandbox/vercel/sandbox.py

Lines changed: 14 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import posixpath
1818
import tarfile
1919
import uuid
20-
from collections.abc import Awaitable, Callable
2120
from pathlib import Path, PurePosixPath
2221
from typing import Any, Literal, cast
2322
from urllib.parse import urlsplit
@@ -50,6 +49,7 @@
5049
from ....sandbox.session.base_sandbox_session import BaseSandboxSession
5150
from ....sandbox.session.dependencies import Dependencies
5251
from ....sandbox.session.manager import Instrumentation
52+
from ....sandbox.session.mount_lifecycle import with_ephemeral_mounts_removed
5353
from ....sandbox.session.runtime_helpers import RESOLVE_WORKSPACE_PATH_HELPER, RuntimeHelperScript
5454
from ....sandbox.session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions
5555
from ....sandbox.snapshot import SnapshotBase, SnapshotSpec, resolve_snapshot
@@ -404,100 +404,6 @@ async def running(self) -> bool:
404404
async def shutdown(self) -> None:
405405
await self._stop_attached_sandbox()
406406

407-
async def _persist_with_ephemeral_mounts_removed(
408-
self,
409-
operation: Callable[[], Awaitable[io.IOBase]],
410-
) -> io.IOBase:
411-
root = self._workspace_root_path()
412-
unmounted_mounts: list[tuple[Any, Path]] = []
413-
unmount_error: WorkspaceArchiveReadError | None = None
414-
for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets():
415-
try:
416-
await mount_entry.mount_strategy.teardown_for_snapshot(
417-
mount_entry, self, mount_path
418-
)
419-
except Exception as exc:
420-
unmount_error = WorkspaceArchiveReadError(path=root, cause=exc)
421-
break
422-
unmounted_mounts.append((mount_entry, mount_path))
423-
424-
persist_error: WorkspaceArchiveReadError | None = None
425-
persisted: io.IOBase | None = None
426-
if unmount_error is None:
427-
try:
428-
persisted = await operation()
429-
except WorkspaceArchiveReadError as exc:
430-
persist_error = exc
431-
432-
remount_error: WorkspaceArchiveReadError | None = None
433-
for mount_entry, mount_path in reversed(unmounted_mounts):
434-
try:
435-
await mount_entry.mount_strategy.restore_after_snapshot(
436-
mount_entry, self, mount_path
437-
)
438-
except Exception as exc:
439-
if remount_error is None:
440-
remount_error = WorkspaceArchiveReadError(path=root, cause=exc)
441-
442-
if remount_error is not None:
443-
if persist_error is not None:
444-
remount_error.context["snapshot_error_before_remount_corruption"] = {
445-
"message": persist_error.message
446-
}
447-
raise remount_error
448-
if unmount_error is not None:
449-
raise unmount_error
450-
if persist_error is not None:
451-
raise persist_error
452-
453-
assert persisted is not None
454-
return persisted
455-
456-
async def _hydrate_with_ephemeral_mounts_removed(
457-
self,
458-
operation: Callable[[], Awaitable[None]],
459-
) -> None:
460-
root = self._workspace_root_path()
461-
unmounted_mounts: list[tuple[Any, Path]] = []
462-
unmount_error: WorkspaceArchiveWriteError | None = None
463-
for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets():
464-
try:
465-
await mount_entry.mount_strategy.teardown_for_snapshot(
466-
mount_entry, self, mount_path
467-
)
468-
except Exception as exc:
469-
unmount_error = WorkspaceArchiveWriteError(path=root, cause=exc)
470-
break
471-
unmounted_mounts.append((mount_entry, mount_path))
472-
473-
hydrate_error: WorkspaceArchiveWriteError | None = None
474-
if unmount_error is None:
475-
try:
476-
await operation()
477-
except WorkspaceArchiveWriteError as exc:
478-
hydrate_error = exc
479-
480-
remount_error: WorkspaceArchiveWriteError | None = None
481-
for mount_entry, mount_path in reversed(unmounted_mounts):
482-
try:
483-
await mount_entry.mount_strategy.restore_after_snapshot(
484-
mount_entry, self, mount_path
485-
)
486-
except Exception as exc:
487-
if remount_error is None:
488-
remount_error = WorkspaceArchiveWriteError(path=root, cause=exc)
489-
490-
if remount_error is not None:
491-
if hydrate_error is not None:
492-
remount_error.context["hydrate_error_before_remount_corruption"] = {
493-
"message": hydrate_error.message
494-
}
495-
raise remount_error
496-
if unmount_error is not None:
497-
raise unmount_error
498-
if hydrate_error is not None:
499-
raise hydrate_error
500-
501407
async def _exec_internal(
502408
self,
503409
*command: str | Path,
@@ -601,7 +507,13 @@ async def write(
601507
raise WorkspaceArchiveWriteError(path=normalized_path, cause=exc) from exc
602508

603509
async def persist_workspace(self) -> io.IOBase:
604-
return await self._persist_with_ephemeral_mounts_removed(self._persist_workspace_internal)
510+
return await with_ephemeral_mounts_removed(
511+
self,
512+
self._persist_workspace_internal,
513+
error_path=self._workspace_root_path(),
514+
error_cls=WorkspaceArchiveReadError,
515+
operation_error_context_key="snapshot_error_before_remount_corruption",
516+
)
605517

606518
async def _persist_workspace_internal(self) -> io.IOBase:
607519
if self.state.workspace_persistence == _WORKSPACE_PERSISTENCE_SNAPSHOT:
@@ -665,8 +577,12 @@ async def hydrate_workspace(self, data: io.IOBase) -> None:
665577
actual_type=type(raw).__name__,
666578
)
667579

668-
await self._hydrate_with_ephemeral_mounts_removed(
669-
lambda: self._hydrate_workspace_internal(bytes(raw))
580+
await with_ephemeral_mounts_removed(
581+
self,
582+
lambda: self._hydrate_workspace_internal(bytes(raw)),
583+
error_path=self._workspace_root_path(),
584+
error_cls=WorkspaceArchiveWriteError,
585+
operation_error_context_key="hydrate_error_before_remount_corruption",
670586
)
671587

672588
async def _hydrate_workspace_internal(self, raw: bytes) -> None:
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Awaitable, Callable
4+
from pathlib import Path
5+
from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast
6+
7+
from ..errors import (
8+
WorkspaceArchiveReadError,
9+
WorkspaceArchiveWriteError,
10+
WorkspaceIOError,
11+
)
12+
13+
if TYPE_CHECKING:
14+
from ..entries import Mount
15+
from .base_sandbox_session import BaseSandboxSession
16+
17+
ArchiveError: TypeAlias = WorkspaceArchiveReadError | WorkspaceArchiveWriteError
18+
ArchiveErrorClass: TypeAlias = type[WorkspaceArchiveReadError] | type[WorkspaceArchiveWriteError]
19+
20+
_ResultT = TypeVar("_ResultT")
21+
_MISSING = object()
22+
23+
24+
async def with_ephemeral_mounts_removed(
25+
session: BaseSandboxSession,
26+
operation: Callable[[], Awaitable[_ResultT]],
27+
*,
28+
error_path: Path,
29+
error_cls: ArchiveErrorClass,
30+
operation_error_context_key: str | None,
31+
) -> _ResultT:
32+
detached_mounts: list[tuple[Mount, Path]] = []
33+
detach_error: ArchiveError | None = None
34+
for mount_entry, mount_path in session.state.manifest.ephemeral_mount_targets():
35+
try:
36+
await mount_entry.mount_strategy.teardown_for_snapshot(mount_entry, session, mount_path)
37+
except Exception as exc:
38+
detach_error = error_cls(path=error_path, cause=exc)
39+
break
40+
detached_mounts.append((mount_entry, mount_path))
41+
42+
operation_error: ArchiveError | None = None
43+
operation_result: object = _MISSING
44+
if detach_error is None:
45+
try:
46+
operation_result = await operation()
47+
except WorkspaceIOError as exc:
48+
if not isinstance(exc, error_cls):
49+
raise
50+
operation_error = cast(ArchiveError, exc)
51+
52+
restore_error = await restore_detached_mounts(
53+
session,
54+
detached_mounts,
55+
error_path=error_path,
56+
error_cls=error_cls,
57+
)
58+
59+
if restore_error is not None:
60+
if operation_error is not None and operation_error_context_key is not None:
61+
restore_error.context[operation_error_context_key] = {
62+
"message": operation_error.message
63+
}
64+
raise restore_error
65+
if detach_error is not None:
66+
raise detach_error
67+
if operation_error is not None:
68+
raise operation_error
69+
70+
assert operation_result is not _MISSING
71+
return cast(_ResultT, operation_result)
72+
73+
74+
async def restore_detached_mounts(
75+
session: BaseSandboxSession,
76+
detached_mounts: list[tuple[Mount, Path]],
77+
*,
78+
error_path: Path,
79+
error_cls: ArchiveErrorClass,
80+
) -> ArchiveError | None:
81+
restore_error: ArchiveError | None = None
82+
for mount_entry, mount_path in reversed(detached_mounts):
83+
try:
84+
await mount_entry.mount_strategy.restore_after_snapshot(
85+
mount_entry, session, mount_path
86+
)
87+
except Exception as exc:
88+
current_error = error_cls(path=error_path, cause=exc)
89+
if restore_error is None:
90+
restore_error = current_error
91+
else:
92+
additional_errors = restore_error.context.setdefault(
93+
"additional_remount_errors", []
94+
)
95+
assert isinstance(additional_errors, list)
96+
additional_errors.append(workspace_archive_error_summary(current_error))
97+
return restore_error
98+
99+
100+
def workspace_archive_error_summary(error: ArchiveError) -> dict[str, str]:
101+
summary = {"message": error.message}
102+
if error.cause is not None:
103+
summary["cause_type"] = type(error.cause).__name__
104+
summary["cause"] = str(error.cause)
105+
return summary
106+
107+
108+
__all__ = [
109+
"restore_detached_mounts",
110+
"with_ephemeral_mounts_removed",
111+
"workspace_archive_error_summary",
112+
]

0 commit comments

Comments
 (0)