Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ def _init_structure_tables(self):

conn.commit()

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items( # type: ignore[override]
self, items: list[TResponseInputItem],
wrapper: Any = None,
) -> None:
"""Add items to the session.

Args:
Expand Down Expand Up @@ -156,10 +159,11 @@ def _add_items_sync():

await asyncio.to_thread(_add_items_sync)

async def get_items(
async def get_items( # type: ignore[override]
self,
limit: int | None = None,
branch_id: str | None = None,
wrapper: Any = None,
Comment on lines 164 to +165
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve Session positional order in AdvancedSQLite get_items

The new Session.get_items contract adds wrapper as the second optional argument, but this override keeps branch_id in that slot and suppresses the incompatibility with # type: ignore[override]. Any caller using the protocol positionally (e.g. session.get_items(limit, wrapper)) will pass the wrapper object as branch_id, leading to incorrect branch lookup and broken context propagation for this built-in session implementation.

Useful? React with 👍 / 👎.

) -> list[TResponseInputItem]:
"""Get items from current or specified branch.

Expand Down
27 changes: 22 additions & 5 deletions src/agents/extensions/memory/async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import cast
from typing import TYPE_CHECKING, Any, cast

import aiosqlite

from ...items import TResponseInputItem
from ...memory import SessionABC
from ...memory.session_settings import SessionSettings

if TYPE_CHECKING:
from ...run_context import RunContextWrapper


class AsyncSQLiteSession(SessionABC):
"""Async SQLite-based implementation of session storage.
Expand Down Expand Up @@ -102,7 +105,11 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]:
conn = await self._get_connection()
yield conn

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -150,7 +157,11 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -186,7 +197,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:

await conn.commit()

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand Down Expand Up @@ -220,7 +234,10 @@ async def pop_item(self) -> TResponseInputItem | None:

return None

async def clear_session(self) -> None:
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Clear all items for this session."""
async with self._locked_connection() as conn:
await conn.execute(
Expand Down
25 changes: 20 additions & 5 deletions src/agents/extensions/memory/dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import json
import random
import time
from typing import Any, Final, Literal
from typing import TYPE_CHECKING, Any, Final, Literal

try:
from dapr.aio.clients import DaprClient
Expand All @@ -42,6 +42,9 @@
from ...memory.session import SessionABC
from ...memory.session_settings import SessionSettings, resolve_session_limit

if TYPE_CHECKING:
from ...run_context import RunContextWrapper

# Type alias for consistency levels
ConsistencyLevel = Literal["eventual", "strong"]

Expand Down Expand Up @@ -232,7 +235,10 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) ->
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self, limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -271,7 +277,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
continue
return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self, items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -324,7 +333,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
options=self._get_state_options(),
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand Down Expand Up @@ -368,7 +380,10 @@ async def pop_item(self) -> TResponseInputItem | None:
except (json.JSONDecodeError, TypeError):
return None

async def clear_session(self) -> None:
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Clear all items for this session."""
async with self._lock:
# Delete messages and metadata keys
Expand Down
37 changes: 28 additions & 9 deletions src/agents/extensions/memory/encrypt_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import base64
import json
from typing import Any, Literal, TypeGuard, cast
from typing import TYPE_CHECKING, Any, Literal, TypeGuard, cast

from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
Expand All @@ -40,6 +40,9 @@
from ...memory.session import SessionABC
from ...memory.session_settings import SessionSettings

if TYPE_CHECKING:
from ...run_context import RunContextWrapper


class EncryptedEnvelope(TypedDict):
"""TypedDict for encrypted message envelopes stored in the underlying session."""
Expand Down Expand Up @@ -170,27 +173,43 @@ def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInpu
except (InvalidToken, KeyError):
return None

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
encrypted_items = await self.underlying_session.get_items(limit)
async def get_items(
self,
limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
encrypted_items = await self.underlying_session.get_items(limit, wrapper=wrapper)
valid_items: list[TResponseInputItem] = []
for enc in encrypted_items:
item = self._unwrap(enc)
if item is not None:
valid_items.append(item)
return valid_items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))
await self.underlying_session.add_items(
cast(list[TResponseInputItem], wrapped), wrapper=wrapper
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
while True:
enc = await self.underlying_session.pop_item()
enc = await self.underlying_session.pop_item(wrapper=wrapper)
if not enc:
return None
item = self._unwrap(enc)
if item is not None:
return item

async def clear_session(self) -> None:
await self.underlying_session.clear_session()
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
await self.underlying_session.clear_session(wrapper=wrapper)
25 changes: 20 additions & 5 deletions src/agents/extensions/memory/mongodb_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import json
import threading
import weakref
from typing import Any
from typing import TYPE_CHECKING, Any

try:
from importlib.metadata import version as _get_version
Expand All @@ -57,6 +57,9 @@
from ...memory.session import SessionABC
from ...memory.session_settings import SessionSettings, resolve_session_limit

if TYPE_CHECKING:
from ...run_context import RunContextWrapper

# Identifies this library in the MongoDB handshake for server-side telemetry.
_DRIVER_INFO = DriverInfo(name="openai-agents", version=_VERSION)

Expand Down Expand Up @@ -241,7 +244,10 @@ async def _deserialize_item(self, raw: str) -> TResponseInputItem:
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self, limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -283,7 +289,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self, items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -319,7 +328,10 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:

await self._messages.insert_many(payload, ordered=True)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand All @@ -340,7 +352,10 @@ async def pop_item(self) -> TResponseInputItem | None:
except (json.JSONDecodeError, KeyError, TypeError):
return None

async def clear_session(self) -> None:
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Clear all items for this session."""
await self._ensure_indexes()
await self._messages.delete_many({"session_id": self.session_id})
Expand Down
36 changes: 30 additions & 6 deletions src/agents/extensions/memory/redis_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import asyncio
import json
import time
from typing import Any
from typing import TYPE_CHECKING, Any

try:
import redis.asyncio as redis
Expand All @@ -38,6 +38,9 @@
from ...memory.session import SessionABC
from ...memory.session_settings import SessionSettings, resolve_session_limit

if TYPE_CHECKING:
from ...run_context import RunContextWrapper


class RedisSession(SessionABC):
"""Redis implementation of :pyclass:`agents.memory.session.Session`."""
Expand Down Expand Up @@ -140,12 +143,16 @@ async def _set_ttl_if_configured(self, *keys: str) -> None:
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self, limit: int | None = None,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
limit: Maximum number of items to retrieve. If None, uses session_settings.limit.
When specified, returns the latest N items in chronological order.
wrapper: Optional run context wrapper providing context and usage info.

Returns:
List of input items representing the conversation history
Expand Down Expand Up @@ -179,11 +186,15 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self, items: list[TResponseInputItem],
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
items: List of input items to add to the history
wrapper: Optional run context wrapper providing context and usage info.
"""
if not items:
return
Expand Down Expand Up @@ -221,9 +232,15 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
self._session_key, self._messages_key, self._counter_key
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Args:
wrapper: Optional run context wrapper providing context and usage info.

Returns:
The most recent item if it exists, None if the session is empty
"""
Expand All @@ -245,8 +262,15 @@ async def pop_item(self) -> TResponseInputItem | None:
# Return None for corrupted messages (already removed)
return None

async def clear_session(self) -> None:
"""Clear all items for this session."""
async def clear_session(
self,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Clear all items for this session.

Args:
wrapper: Optional run context wrapper providing context and usage info.
"""
async with self._lock:
# Delete all keys associated with this session
await self._redis.delete(
Expand Down
Loading