diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index 5b384eaf5f..d0173aa33b 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -121,7 +121,10 @@ def _init_structure_tables(self): conn.commit() - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( # type: ignore[override] + self, items: list[TResponseInputItem], + wrapper: Any = None, + ) -> None: """Add items to the session. Args: @@ -156,9 +159,11 @@ def _add_items_sync(): await asyncio.to_thread(_add_items_sync) - async def get_items( + async def get_items( # type: ignore[override] self, limit: int | None = None, + wrapper: Any = None, + *, branch_id: str | None = None, ) -> list[TResponseInputItem]: """Get items from current or specified branch. diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index 2eef596264..9c2fc48a1e 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path -from typing import cast +from typing import TYPE_CHECKING, Any, cast import aiosqlite @@ -13,6 +13,9 @@ from ...memory import SessionABC from ...memory.session_settings import SessionSettings +if TYPE_CHECKING: + from ...run_context import RunContextWrapper + class AsyncSQLiteSession(SessionABC): """Async SQLite-based implementation of session storage. @@ -102,7 +105,11 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]: conn = await self._get_connection() yield conn - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -150,7 +157,11 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: @@ -186,7 +197,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: await conn.commit() - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: @@ -220,7 +234,10 @@ async def pop_item(self) -> TResponseInputItem | None: return None - async def clear_session(self) -> None: + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Clear all items for this session.""" async with self._locked_connection() as conn: await conn.execute( diff --git a/src/agents/extensions/memory/dapr_session.py b/src/agents/extensions/memory/dapr_session.py index ce6bf754a3..14bc292fb8 100644 --- a/src/agents/extensions/memory/dapr_session.py +++ b/src/agents/extensions/memory/dapr_session.py @@ -27,7 +27,7 @@ import json import random import time -from typing import Any, Final, Literal +from typing import TYPE_CHECKING, Any, Final, Literal try: from dapr.aio.clients import DaprClient @@ -42,6 +42,9 @@ from ...memory.session import SessionABC from ...memory.session_settings import SessionSettings, resolve_session_limit +if TYPE_CHECKING: + from ...run_context import RunContextWrapper + # Type alias for consistency levels ConsistencyLevel = Literal["eventual", "strong"] @@ -232,7 +235,10 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) -> # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -271,7 +277,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: @@ -324,7 +333,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: options=self._get_state_options(), ) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: @@ -368,7 +380,10 @@ async def pop_item(self) -> TResponseInputItem | None: except (json.JSONDecodeError, TypeError): return None - async def clear_session(self) -> None: + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Clear all items for this session.""" async with self._lock: # Delete messages and metadata keys diff --git a/src/agents/extensions/memory/encrypt_session.py b/src/agents/extensions/memory/encrypt_session.py index a72aee0a62..d388e52be1 100644 --- a/src/agents/extensions/memory/encrypt_session.py +++ b/src/agents/extensions/memory/encrypt_session.py @@ -29,7 +29,7 @@ import base64 import json -from typing import Any, Literal, TypeGuard, cast +from typing import TYPE_CHECKING, Any, Literal, TypeGuard, cast from cryptography.fernet import Fernet, InvalidToken from cryptography.hazmat.primitives import hashes @@ -40,6 +40,9 @@ from ...memory.session import SessionABC from ...memory.session_settings import SessionSettings +if TYPE_CHECKING: + from ...run_context import RunContextWrapper + class EncryptedEnvelope(TypedDict): """TypedDict for encrypted message envelopes stored in the underlying session.""" @@ -170,8 +173,12 @@ def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInpu except (InvalidToken, KeyError): return None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: - encrypted_items = await self.underlying_session.get_items(limit) + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + encrypted_items = await self.underlying_session.get_items(limit, wrapper=wrapper) valid_items: list[TResponseInputItem] = [] for enc in encrypted_items: item = self._unwrap(enc) @@ -179,18 +186,30 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: valid_items.append(item) return valid_items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items] - await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped)) + await self.underlying_session.add_items( + cast(list[TResponseInputItem], wrapped), wrapper=wrapper + ) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: while True: - enc = await self.underlying_session.pop_item() + enc = await self.underlying_session.pop_item(wrapper=wrapper) if not enc: return None item = self._unwrap(enc) if item is not None: return item - async def clear_session(self) -> None: - await self.underlying_session.clear_session() + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + await self.underlying_session.clear_session(wrapper=wrapper) diff --git a/src/agents/extensions/memory/mongodb_session.py b/src/agents/extensions/memory/mongodb_session.py index 20c7c5f030..ef889f2140 100644 --- a/src/agents/extensions/memory/mongodb_session.py +++ b/src/agents/extensions/memory/mongodb_session.py @@ -34,7 +34,7 @@ import json import threading import weakref -from typing import Any +from typing import TYPE_CHECKING, Any try: from importlib.metadata import version as _get_version @@ -57,6 +57,9 @@ from ...memory.session import SessionABC from ...memory.session_settings import SessionSettings, resolve_session_limit +if TYPE_CHECKING: + from ...run_context import RunContextWrapper + # Identifies this library in the MongoDB handshake for server-side telemetry. _DRIVER_INFO = DriverInfo(name="openai-agents", version=_VERSION) @@ -241,7 +244,10 @@ async def _deserialize_item(self, raw: str) -> TResponseInputItem: # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -283,7 +289,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: @@ -319,7 +328,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: await self._messages.insert_many(payload, ordered=True) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: @@ -340,7 +352,10 @@ async def pop_item(self) -> TResponseInputItem | None: except (json.JSONDecodeError, KeyError, TypeError): return None - async def clear_session(self) -> None: + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Clear all items for this session.""" await self._ensure_indexes() await self._messages.delete_many({"session_id": self.session_id}) diff --git a/src/agents/extensions/memory/redis_session.py b/src/agents/extensions/memory/redis_session.py index 1eee549e11..9da0cd14ba 100644 --- a/src/agents/extensions/memory/redis_session.py +++ b/src/agents/extensions/memory/redis_session.py @@ -24,7 +24,7 @@ import asyncio import json import time -from typing import Any +from typing import TYPE_CHECKING, Any try: import redis.asyncio as redis @@ -38,6 +38,9 @@ from ...memory.session import SessionABC from ...memory.session_settings import SessionSettings, resolve_session_limit +if TYPE_CHECKING: + from ...run_context import RunContextWrapper + class RedisSession(SessionABC): """Redis implementation of :pyclass:`agents.memory.session.Session`.""" @@ -140,12 +143,16 @@ async def _set_ttl_if_configured(self, *keys: str) -> None: # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, uses session_settings.limit. When specified, returns the latest N items in chronological order. + wrapper: Optional run context wrapper providing context and usage info. Returns: List of input items representing the conversation history @@ -179,11 +186,15 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional run context wrapper providing context and usage info. """ if not items: return @@ -221,9 +232,15 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: self._session_key, self._messages_key, self._counter_key ) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. + Args: + wrapper: Optional run context wrapper providing context and usage info. + Returns: The most recent item if it exists, None if the session is empty """ @@ -245,8 +262,15 @@ async def pop_item(self) -> TResponseInputItem | None: # Return None for corrupted messages (already removed) return None - async def clear_session(self) -> None: - """Clear all items for this session.""" + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + """Clear all items for this session. + + Args: + wrapper: Optional run context wrapper providing context and usage info. + """ async with self._lock: # Delete all keys associated with this session await self._redis.delete( diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index d84f2c78fb..e3eb62c108 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -26,7 +26,7 @@ import asyncio import json import threading -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from sqlalchemy import ( TIMESTAMP, @@ -52,6 +52,9 @@ from ...memory.session import SessionABC from ...memory.session_settings import SessionSettings, resolve_session_limit +if TYPE_CHECKING: + from ...run_context import RunContextWrapper + class SQLAlchemySession(SessionABC): """SQLAlchemy implementation of :pyclass:`agents.memory.session.Session`.""" @@ -274,7 +277,10 @@ async def _ensure_tables(self) -> None: finally: self._init_lock.release() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -326,7 +332,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: @@ -376,7 +385,10 @@ async def _write_items() -> None: await self._run_sqlite_write_with_retry(_write_items) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: @@ -413,7 +425,10 @@ async def pop_item(self) -> TResponseInputItem | None: except json.JSONDecodeError: return None - async def clear_session(self) -> None: + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Clear all items for this session.""" await self._ensure_tables() async with self._session_factory() as sess: diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py index 4d4fbaf635..55dbe2f28f 100644 --- a/src/agents/memory/openai_conversations_session.py +++ b/src/agents/memory/openai_conversations_session.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any + from openai import AsyncOpenAI from agents.models._openai_shared import get_default_openai_client @@ -8,6 +10,9 @@ from .session import SessionABC from .session_settings import SessionSettings, resolve_session_limit +if TYPE_CHECKING: + from ..run_context import RunContextWrapper + async def start_openai_conversations_session(openai_client: AsyncOpenAI | None = None) -> str: _maybe_openai_client = openai_client @@ -70,7 +75,11 @@ async def _get_session_id(self) -> str: async def _clear_session_id(self) -> None: self._session_id = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: session_id = await self._get_session_id() session_limit = resolve_session_limit(limit, self.session_settings) @@ -97,7 +106,11 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return all_items # type: ignore - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: session_id = await self._get_session_id() if not items: return @@ -107,7 +120,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: items=items, ) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: session_id = await self._get_session_id() items = await self.get_items(limit=1) if not items: @@ -118,7 +134,10 @@ async def pop_item(self) -> TResponseInputItem | None: ) return items[0] - async def clear_session(self) -> None: + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: session_id = await self._get_session_id() await self._openai_client.conversations.delete( conversation_id=session_id, diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index f024a33820..90c62d7aa0 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -17,6 +17,7 @@ ) if TYPE_CHECKING: + from ..run_context import RunContextWrapper from .session import Session logger = logging.getLogger("openai-agents.openai.compaction") @@ -156,7 +157,7 @@ def _resolve_compaction_mode_for_response( return "input" return _resolve_compaction_mode(mode, response_id=response_id, store=store) - async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None: + async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None, **kwargs: Any) -> None: """Run compaction using responses.compact API.""" if args and args.get("response_id"): self._response_id = args["response_id"] @@ -181,7 +182,9 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None "when using previous_response_id compaction." ) - compaction_candidate_items, session_items = await self._ensure_compaction_candidates() + compaction_candidate_items, session_items = await self._ensure_compaction_candidates( + wrapper=kwargs.get("wrapper") + ) force = args.get("force", False) if args else False should_compact = force or self.should_trigger_compaction( @@ -229,13 +232,22 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None f"candidates={len(self._compaction_candidate_items)})" ) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: - return await self.underlying_session.get_items(limit) - - async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + return await self.underlying_session.get_items(limit, wrapper=wrapper) + + async def _defer_compaction( + self, response_id: str, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: if self._deferred_response_id is not None: return - compaction_candidate_items, session_items = await self._ensure_compaction_candidates() + compaction_candidate_items, session_items = await self._ensure_compaction_candidates( + wrapper=wrapper + ) resolved_mode = self._resolve_compaction_mode_for_response( response_id=response_id, store=store, @@ -258,8 +270,12 @@ def _get_deferred_compaction_response_id(self) -> str | None: def _clear_deferred_compaction(self) -> None: self._deferred_response_id = None - async def add_items(self, items: list[TResponseInputItem]) -> None: - await self.underlying_session.add_items(items) + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + await self.underlying_session.add_items(items, wrapper=wrapper) if self._compaction_candidate_items is not None: new_items = _normalize_compaction_session_items(items) new_candidates = select_compaction_candidate_items(new_items) @@ -268,27 +284,36 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: if self._session_items is not None: self._session_items.extend(_normalize_compaction_session_items(items)) - async def pop_item(self) -> TResponseInputItem | None: - popped = await self.underlying_session.pop_item() + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: + popped = await self.underlying_session.pop_item(wrapper=wrapper) if popped: self._compaction_candidate_items = None self._session_items = None return popped - async def clear_session(self) -> None: - await self.underlying_session.clear_session() + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + await self.underlying_session.clear_session(wrapper=wrapper) self._compaction_candidate_items = [] self._session_items = [] self._deferred_response_id = None async def _ensure_compaction_candidates( self, + wrapper: RunContextWrapper[Any] | None = None, ) -> tuple[list[TResponseInputItem], list[TResponseInputItem]]: """Lazy-load and cache compaction candidates.""" if self._compaction_candidate_items is not None and self._session_items is not None: return (self._compaction_candidate_items[:], self._session_items[:]) - history = _normalize_compaction_session_items(await self.underlying_session.get_items()) + history = _normalize_compaction_session_items( + await self.underlying_session.get_items(wrapper=wrapper) + ) candidates = select_compaction_candidate_items(history) self._compaction_candidate_items = candidates self._session_items = history diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 1781b7ac9f..e30deb0140 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -1,12 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal, Protocol, TypeGuard, runtime_checkable +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeGuard, runtime_checkable from typing_extensions import TypedDict if TYPE_CHECKING: from ..items import TResponseInputItem + from ..run_context import RunContextWrapper from .session_settings import SessionSettings @@ -21,36 +22,59 @@ class Session(Protocol): session_id: str session_settings: SessionSettings | None = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, retrieves all items. When specified, returns the latest N items in chronological order. + wrapper: Optional run context wrapper providing context and usage info. Returns: List of input items representing the conversation history """ ... - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional run context wrapper providing context and usage info. """ ... - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. + Args: + wrapper: Optional run context wrapper providing context and usage info. + Returns: The most recent item if it exists, None if the session is empty """ ... - async def clear_session(self) -> None: - """Clear all items for this session.""" + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + """Clear all items for this session. + + Args: + wrapper: Optional run context wrapper providing context and usage info. + """ ... @@ -68,12 +92,17 @@ class SessionABC(ABC): session_settings: SessionSettings | None = None @abstractmethod - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, retrieves all items. When specified, returns the latest N items in chronological order. + wrapper: Optional run context wrapper providing context and usage info. Returns: List of input items representing the conversation history @@ -81,26 +110,44 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: ... @abstractmethod - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional run context wrapper providing context and usage info. """ ... @abstractmethod - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. + Args: + wrapper: Optional run context wrapper providing context and usage info. + Returns: The most recent item if it exists, None if the session is empty """ ... @abstractmethod - async def clear_session(self) -> None: - """Clear all items for this session.""" + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + """Clear all items for this session. + + Args: + wrapper: Optional run context wrapper providing context and usage info. + """ ... diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index a31347cdcd..050df5b51c 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -7,12 +7,15 @@ from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path -from typing import ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from ..items import TResponseInputItem from .session import SessionABC from .session_settings import SessionSettings, resolve_session_limit +if TYPE_CHECKING: + from ..run_context import RunContextWrapper + class SQLiteSession(SessionABC): """SQLite-based implementation of session storage. @@ -199,12 +202,17 @@ def _insert_items(self, conn: sqlite3.Connection, items: list[TResponseInputItem (self.session_id,), ) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, uses session_settings.limit. When specified, returns the latest N items in chronological order. + wrapper: Optional run context wrapper providing context and usage info. Returns: List of input items representing the conversation history @@ -254,11 +262,16 @@ def _get_items_sync(): return await asyncio.to_thread(_get_items_sync) - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional run context wrapper providing context and usage info. """ if not items: return @@ -270,9 +283,15 @@ def _add_items_sync(): await asyncio.to_thread(_add_items_sync) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. + Args: + wrapper: Optional run context wrapper providing context and usage info. + Returns: The most recent item if it exists, None if the session is empty """ @@ -310,8 +329,15 @@ def _pop_item_sync(): return await asyncio.to_thread(_pop_item_sync) - async def clear_session(self) -> None: - """Clear all items for this session.""" + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + """Clear all items for this session. + + Args: + wrapper: Optional run context wrapper providing context and usage info. + """ def _clear_session_sync(): with self._locked_connection() as conn: diff --git a/src/agents/run.py b/src/agents/run.py index 68fa27b3bb..d0c42bffcf 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -494,6 +494,8 @@ async def run( raw_input = cast(str | list[TResponseInputItem], input) original_user_input = raw_input + context_wrapper = ensure_context_wrapper(context) + validate_session_conversation_settings( session, conversation_id=conversation_id, @@ -515,6 +517,7 @@ async def run( run_config.session_settings, include_history_in_prepared_input=False, preserve_dropped_new_items=True, + wrapper=context_wrapper, ) original_input_for_state = raw_input session_input_items_for_persistence = [] @@ -527,6 +530,7 @@ async def run( session, run_config.session_input_callback, run_config.session_settings, + wrapper=context_wrapper, ) original_input_for_state = prepared_input diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 039088ecb6..720b3a0d48 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -598,6 +598,7 @@ def _sync_conversation_tracking_from_tracker() -> None: run_config.session_settings, include_history_in_prepared_input=not server_manages_conversation, preserve_dropped_new_items=True, + wrapper=context_wrapper, ) streamed_result.input = prepared_input streamed_result._original_input = copy_input_items(prepared_input) diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index 25874ad345..3ed2e3eccb 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -10,7 +10,7 @@ import inspect import json from collections.abc import Sequence -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from ..exceptions import UserError from ..items import HandoffOutputItem, ItemHelpers, RunItem, ToolCallOutputItem, TResponseInputItem @@ -24,6 +24,10 @@ ) from ..memory.openai_conversations_session import OpenAIConversationsSession from ..run_state import RunState + +if TYPE_CHECKING: + from ..run_context import RunContextWrapper + from .items import ( ReasoningItemIdPolicy, copy_input_items, @@ -59,6 +63,7 @@ async def prepare_input_with_session( *, include_history_in_prepared_input: bool = True, preserve_dropped_new_items: bool = False, + wrapper: RunContextWrapper[Any] | None = None, ) -> tuple[str | list[TResponseInputItem], list[TResponseInputItem]]: """Prepare model input from session history plus the new turn input. @@ -83,9 +88,9 @@ async def prepare_input_with_session( resolved_settings = resolved_settings.resolve(session_settings) if resolved_settings.limit is not None: - history = await session.get_items(limit=resolved_settings.limit) + history = await session.get_items(limit=resolved_settings.limit, wrapper=wrapper) else: - history = await session.get_items() + history = await session.get_items(wrapper=wrapper) converted_history = [ strip_internal_input_item_metadata(ensure_input_item_format(item)) for item in history ] @@ -237,6 +242,7 @@ async def save_result_to_session( response_id: str | None = None, reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> int: """ Persist a turn to the session store, keeping track of what was already saved so retries @@ -322,7 +328,7 @@ async def save_result_to_session( run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count return saved_run_items_count - await session.add_items(items_to_save) + await session.add_items(items_to_save, wrapper=wrapper) if run_state: run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count @@ -334,7 +340,7 @@ async def save_result_to_session( if has_local_tool_outputs: defer_compaction = getattr(session, "_defer_compaction", None) if callable(defer_compaction): - result = defer_compaction(response_id, store=store) + result = defer_compaction(response_id, store=store, wrapper=wrapper) if inspect.isawaitable(result): await result logger.debug( @@ -360,7 +366,7 @@ async def save_result_to_session( } if store is not None: compaction_args["store"] = store - await session.run_compaction(compaction_args) + await session.run_compaction(compaction_args, wrapper=wrapper) return saved_run_items_count diff --git a/tests/memory/test_openai_responses_compaction_session.py b/tests/memory/test_openai_responses_compaction_session.py index 56d05f12a4..4c4c66e43a 100644 --- a/tests/memory/test_openai_responses_compaction_session.py +++ b/tests/memory/test_openai_responses_compaction_session.py @@ -125,7 +125,7 @@ async def test_add_items_delegates(self) -> None: ] await session.add_items(items) - mock_session.add_items.assert_called_once_with(items) + mock_session.add_items.assert_called_once_with(items, wrapper=None) @pytest.mark.asyncio async def test_get_items_delegates(self) -> None: diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index c5cc123034..8329b5dcba 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -328,16 +328,30 @@ class DummySession(Session): session_id = "sess_123" session_settings = SessionSettings() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: return None - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: return None - async def clear_session(self) -> None: + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: return None dummy_session = DummySession() diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 45cdab7711..60d002627a 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -2433,16 +2433,16 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items(self, items, **kwargs): self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items(self, limit=None, **kwargs): return [] - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, **kwargs): return None - async def clear_session(self) -> None: + async def clear_session(self, **kwargs): return None session = DummyOpenAIConversationsSession() @@ -3970,15 +3970,18 @@ async def echo_tool(text: str) -> str: expected_calls = [ # First call is the initial input - (([expected_items[0]],),), + [expected_items[0]], # Second call is the first tool call and its result - (([expected_items[1], expected_items[2]],),), + [expected_items[1], expected_items[2]], # Third call is the second tool call and its result - (([expected_items[3], expected_items[4]],),), + [expected_items[3], expected_items[4]], # Fourth call is the final output - (([expected_items[5]],),), + [expected_items[5]], ] - assert mock_add_items.call_args_list == expected_calls + assert mock_add_items.call_count == len(expected_calls) + for i, expected_items_call in enumerate(expected_calls): + actual_call = mock_add_items.call_args_list[i] + assert list(actual_call[0][0]) == expected_items_call assert result.final_output == "Summary: Echoed foo and bar" assert (await session.get_items()) == expected_items diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 1c28fafbc2..44982beca7 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -1166,7 +1166,7 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items(self, items, **kwargs): for item in items: if isinstance(item, dict): assert "id" not in item, "IDs should be stripped before saving" @@ -1175,13 +1175,13 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: ) self.saved.append(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items(self, limit=None, **kwargs): return [] - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, **kwargs): return None - async def clear_session(self) -> None: + async def clear_session(self, **kwargs): return None session = DummyOpenAIConversationsSession() diff --git a/tests/test_session.py b/tests/test_session.py index aa8211500a..e464df0ea2 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -598,7 +598,7 @@ async def test_session_add_items_exception_propagates_in_streamed(): """ session = SQLiteSession("test_exception_session") - async def _failing_add_items(_items): + async def _failing_add_items(_items, **kwargs): raise RuntimeError("Simulated session.add_items failure") session.add_items = _failing_add_items # type: ignore[method-assign] diff --git a/tests/utils/simple_session.py b/tests/utils/simple_session.py index 94bcc97e9e..6a4f8502dd 100644 --- a/tests/utils/simple_session.py +++ b/tests/utils/simple_session.py @@ -1,11 +1,14 @@ from __future__ import annotations -from typing import cast +from typing import TYPE_CHECKING, Any, cast from agents.items import TResponseInputItem from agents.memory.session import Session from agents.memory.session_settings import SessionSettings +if TYPE_CHECKING: + from agents.run_context import RunContextWrapper + class SimpleListSession(Session): """A minimal in-memory session implementation for tests.""" @@ -24,22 +27,36 @@ def __init__( # Mirror saved_items used by some tests for inspection. self.saved_items: list[TResponseInputItem] = self._items - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self._items) if limit <= 0: return [] return self._items[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self._items.extend(items) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: if not self._items: return None return self._items.pop() - async def clear_session(self) -> None: + async def clear_session( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self._items.clear() @@ -54,9 +71,12 @@ def __init__( super().__init__(session_id=session_id, history=history) self.pop_calls = 0 - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: self.pop_calls += 1 - return await super().pop_item() + return await super().pop_item(wrapper=wrapper) class IdStrippingSession(CountingSession): @@ -70,7 +90,11 @@ def __init__( super().__init__(session_id=session_id, history=history) self._ignore_ids_for_matching = True - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: sanitized: list[TResponseInputItem] = [] for item in items: if isinstance(item, dict): @@ -79,4 +103,4 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: sanitized.append(cast(TResponseInputItem, clean)) else: sanitized.append(item) - await super().add_items(sanitized) + await super().add_items(sanitized, wrapper=wrapper)