Skip to content

Commit 3ff12bb

Browse files
committed
Replace private pydantic-ai access with public Agent.toolsets surface
The _get_function_tools helper reached into agent._function_tools (0.2.x) and agent._function_toolset (1.x) — both private dataclass fields. The tool-dispatch and requires-approval lookup paths additionally probed speculative private attributes (_wrapped_function, callable_function, _core_tool, wrapped_tool) on Tool objects that do not exist on the public Tool dataclass. Refactor to walk the public Agent.toolsets property (documented to include the auto-built function toolset for tools registered directly on the agent), pick FunctionToolset instances, and read the public FunctionToolset.tools dict. Use the public Tool.function and Tool.requires_approval fields for callable extraction and fallback approval checks. Verified against pydantic-ai 1.87.0 source (our minimum pin). Update test mocks in test_pydantic_ai_agents.py and test_nested_approval_gates.py to set inst.toolsets via a MagicMock(spec=FunctionToolset) helper so isinstance() recognises the fake. The integration test test_check_tool_requires_approval_with_real_pydantic_ai_agent exercises the helper end-to-end against a real Agent instance.
1 parent 9d8226e commit 3ff12bb

3 files changed

Lines changed: 74 additions & 76 deletions

File tree

opencontractserver/llms/agents/pydantic_ai_agents.py

Lines changed: 29 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ToolReturnPart,
3030
UserPromptPart,
3131
)
32+
from pydantic_ai.toolsets import FunctionToolset
3233
from pydantic_graph import End
3334

3435
from opencontractserver.constants.context_guardrails import COMPACTION_SUMMARY_PREFIX
@@ -148,18 +149,16 @@ async def similarity_search(
148149
def _get_function_tools(agent: PydanticAIAgent) -> dict:
149150
"""Return the function-tools dict from a pydantic-ai Agent.
150151
151-
Handles both pydantic-ai 0.2.x (``agent._function_tools``) and
152-
1.x (``agent._function_toolset.tools``).
152+
Uses only the public surface: ``Agent.toolsets`` (documented property
153+
that includes the auto-built function toolset for tools registered
154+
directly on the agent) and ``FunctionToolset.tools`` (public dict of
155+
tool name -> ``Tool``).
153156
"""
154-
# pydantic-ai 0.2.x
155-
ft = getattr(agent, "_function_tools", None)
156-
if ft is not None:
157-
return ft
158-
# pydantic-ai 1.x
159-
toolset = getattr(agent, "_function_toolset", None)
160-
if toolset is not None:
161-
return getattr(toolset, "tools", {})
162-
return {}
157+
merged: dict = {}
158+
for toolset in agent.toolsets:
159+
if isinstance(toolset, FunctionToolset):
160+
merged.update(toolset.tools)
161+
return merged
163162

164163

165164
@dataclasses.dataclass
@@ -1605,27 +1604,17 @@ async def _maybe_await(call_result): # noqa: D401 – small helper
16051604
# Don't retry here, fall through to registry lookup
16061605

16071606
if not tool_executed:
1608-
# Resort to pydantic-ai registry – may return Tool object.
1607+
# Resort to pydantic-ai registry – returns a public ``Tool``.
16091608
tool_obj = _get_function_tools(self.pydantic_ai_agent).get(
16101609
tool_name
16111610
)
16121611
if tool_obj is None:
16131612
raise ValueError(f"Tool '{tool_name}' not found for execution")
16141613

1615-
# Try common attributes to reach the underlying callable.
1616-
candidate = None
1617-
for attr in (
1618-
"function",
1619-
"_wrapped_function",
1620-
"callable_function",
1621-
):
1622-
candidate = getattr(tool_obj, attr, None)
1623-
if callable(candidate):
1624-
break
1625-
1626-
if candidate is None or not callable(candidate):
1614+
candidate = tool_obj.function
1615+
if not callable(candidate):
16271616
raise TypeError(
1628-
"Tool object is not callable and no inner function found"
1617+
f"Tool '{tool_name}' has a non-callable function"
16291618
)
16301619

16311620
logger.info(
@@ -1933,36 +1922,24 @@ def _check_tool_requires_approval(self, tool_name: str) -> bool:
19331922
if hasattr(tool, "requires_approval"):
19341923
return tool.requires_approval
19351924

1936-
# Check tools registered with pydantic-ai agent
1925+
# Check tools registered with pydantic-ai agent. Tools registered as
1926+
# plain async callables (our common case) carry their CoreTool on the
1927+
# underlying function, not on the Tool object — pydantic-ai 1.x's
1928+
# Tool.requires_approval defaults to False unless the caller passes it
1929+
# in, so we must consult the function attribute first.
19371930
function_tools = _get_function_tools(self.pydantic_ai_agent)
19381931
if function_tools:
19391932
tool_obj = function_tools.get(tool_name)
1940-
if tool_obj:
1941-
# Check various possible attributes where the CoreTool might be stored
1942-
for attr in ("core_tool", "_core_tool", "wrapped_tool"):
1943-
core_tool = getattr(tool_obj, attr, None)
1944-
if core_tool and hasattr(core_tool, "requires_approval"):
1945-
return core_tool.requires_approval
1946-
1947-
# Check the wrapped function (must come before the native
1948-
# tool_obj.requires_approval check because pydantic-ai 1.x
1949-
# Tool has a native requires_approval field that defaults to
1950-
# False, shadowing the custom attribute on the function).
1951-
for attr in ("function", "_wrapped_function", "callable_function"):
1952-
func = getattr(tool_obj, attr, None)
1953-
if func:
1954-
# Check if the function has a core_tool attribute
1955-
if hasattr(func, "core_tool") and hasattr(
1956-
func.core_tool, "requires_approval"
1957-
):
1958-
return func.core_tool.requires_approval
1959-
# Check if the function itself has requires_approval
1960-
if hasattr(func, "requires_approval"):
1961-
return func.requires_approval
1962-
1963-
# Fall back to the tool object's own requires_approval
1964-
if hasattr(tool_obj, "requires_approval"):
1965-
return tool_obj.requires_approval
1933+
if tool_obj is not None:
1934+
func = tool_obj.function
1935+
core_tool = getattr(func, "core_tool", None)
1936+
if core_tool is not None and getattr(
1937+
core_tool, "requires_approval", False
1938+
):
1939+
return True
1940+
if getattr(func, "requires_approval", False):
1941+
return True
1942+
return tool_obj.requires_approval
19661943

19671944
# Default to not requiring approval
19681945
return False

opencontractserver/tests/test_nested_approval_gates.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ async def __aexit__(self, exc_type, exc, tb):
8787
return False
8888

8989

90+
def _fake_function_toolset(entries: dict[str, Any]) -> Any:
91+
"""Return a MagicMock that satisfies isinstance(_, FunctionToolset) and
92+
exposes a .tools dict, mimicking the public toolset surface that
93+
pydantic-ai exposes via Agent.toolsets.
94+
"""
95+
from pydantic_ai.toolsets import FunctionToolset
96+
97+
ts = MagicMock(spec=FunctionToolset)
98+
ts.tools = entries
99+
return ts
100+
101+
90102
# ---------------------------------------------------------------------------
91103
# Mock sub-agent that yields configurable stream events
92104
# ---------------------------------------------------------------------------
@@ -180,7 +192,7 @@ async def _create_corpus_agent(self, sub_agent_events=None):
180192
) as mock_agent_cls:
181193
inst = MagicMock()
182194
inst.iter = MagicMock(return_value=_IterCtx())
183-
inst._function_tools = {}
195+
inst.toolsets = []
184196
inst.run = AsyncMock(
185197
return_value=types.SimpleNamespace(
186198
data="ok", sources=[], usage=lambda: None
@@ -413,11 +425,14 @@ async def _spy_tool(ctx, **kwargs):
413425
)
414426
)
415427
inst.iter = MagicMock(return_value=_IterCtx())
416-
# resume_with_approval uses pydantic_ai_agent._function_tools
417-
# as fallback to find the tool callable.
418-
inst._function_tools = {
419-
"ask_document": types.SimpleNamespace(function=_spy_tool),
420-
}
428+
# resume_with_approval falls back to pydantic-ai's public
429+
# function-toolset registry (Agent.toolsets) to find the tool
430+
# callable when config.tools doesn't carry it.
431+
inst.toolsets = [
432+
_fake_function_toolset(
433+
{"ask_document": types.SimpleNamespace(function=_spy_tool)}
434+
)
435+
]
421436
mock_agent_cls.return_value = inst
422437

423438
agent = await UnifiedAgentFactory.create_corpus_agent(
@@ -499,9 +514,11 @@ async def _capture_bypass_tool(ctx, **kwargs):
499514
)
500515
return {"result": "done"}
501516

502-
inst._function_tools = {
503-
"test_tool": types.SimpleNamespace(function=_capture_bypass_tool),
504-
}
517+
inst.toolsets = [
518+
_fake_function_toolset(
519+
{"test_tool": types.SimpleNamespace(function=_capture_bypass_tool)}
520+
)
521+
]
505522
mock_agent_cls.return_value = inst
506523

507524
agent = await UnifiedAgentFactory.create_corpus_agent(
@@ -569,9 +586,11 @@ async def test_bypass_flag_reset_on_tool_error(self):
569586
async def _failing_tool(ctx, **kwargs):
570587
raise RuntimeError("Tool exploded")
571588

572-
inst._function_tools = {
573-
"exploding_tool": types.SimpleNamespace(function=_failing_tool),
574-
}
589+
inst.toolsets = [
590+
_fake_function_toolset(
591+
{"exploding_tool": types.SimpleNamespace(function=_failing_tool)}
592+
)
593+
]
575594
mock_agent_cls.return_value = inst
576595

577596
agent = await UnifiedAgentFactory.create_corpus_agent(
@@ -615,7 +634,7 @@ async def _failing_tool(ctx, **kwargs):
615634
async def test_resume_executes_tool_from_config_tools(self):
616635
"""resume_with_approval should find and execute a tool present in
617636
config.tools (the wrapper_fn path) rather than falling back to
618-
pydantic_ai_agent._function_tools."""
637+
pydantic-ai's function-toolset registry."""
619638
from opencontractserver.conversations.models import ChatMessage, Conversation
620639
from opencontractserver.llms.agents.agent_factory import (
621640
UnifiedAgentFactory,
@@ -640,9 +659,9 @@ async def _config_spy(ctx, **kwargs):
640659
)
641660
)
642661
inst.iter = MagicMock(return_value=_IterCtx())
643-
# Leave _function_tools empty so the only way to find the tool
644-
# is via config.tools.
645-
inst._function_tools = {}
662+
# Leave the function toolset empty so the only way to find the
663+
# tool is via config.tools.
664+
inst.toolsets = []
646665
mock_agent_cls.return_value = inst
647666

648667
agent = await UnifiedAgentFactory.create_corpus_agent(
@@ -683,7 +702,7 @@ async def _config_spy(ctx, **kwargs):
683702

684703
async def test_resume_config_tools_typeerror_falls_through(self):
685704
"""When a config.tools wrapper raises TypeError, resume_with_approval
686-
should fall through to the _function_tools registry lookup."""
705+
should fall through to the function-toolset registry lookup."""
687706
from opencontractserver.conversations.models import ChatMessage, Conversation
688707
from opencontractserver.llms.agents.agent_factory import (
689708
UnifiedAgentFactory,
@@ -712,9 +731,11 @@ async def _fallback_tool(ctx, **kwargs):
712731
)
713732
)
714733
inst.iter = MagicMock(return_value=_IterCtx())
715-
inst._function_tools = {
716-
"dual_tool": types.SimpleNamespace(function=_fallback_tool),
717-
}
734+
inst.toolsets = [
735+
_fake_function_toolset(
736+
{"dual_tool": types.SimpleNamespace(function=_fallback_tool)}
737+
)
738+
]
718739
mock_agent_cls.return_value = inst
719740

720741
agent = await UnifiedAgentFactory.create_corpus_agent(

opencontractserver/tests/test_pydantic_ai_agents.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,7 @@ def test_tool():
859859
config = AgentConfig(user_id=self.user.id, tools=[mock_tool])
860860

861861
mock_pydantic_agent = MagicMock()
862-
mock_pydantic_agent._function_tools = {}
862+
mock_pydantic_agent.toolsets = []
863863
mock_agent_cls.return_value = mock_pydantic_agent
864864

865865
agent = PydanticAICoreAgent(
@@ -885,7 +885,7 @@ def test_check_tool_requires_approval_default_false(
885885
config = AgentConfig(user_id=self.user.id)
886886

887887
mock_pydantic_agent = MagicMock()
888-
mock_pydantic_agent._function_tools = {}
888+
mock_pydantic_agent.toolsets = []
889889
mock_agent_cls.return_value = mock_pydantic_agent
890890

891891
agent = PydanticAICoreAgent(
@@ -1048,7 +1048,7 @@ async def test_resume_with_approval_approved_success(
10481048
)
10491049

10501050
mock_pydantic_agent = MagicMock()
1051-
mock_pydantic_agent._function_tools = {}
1051+
mock_pydantic_agent.toolsets = []
10521052

10531053
agent = PydanticAICoreAgent(
10541054
config=config,
@@ -1203,7 +1203,7 @@ async def test_resume_with_approval_parses_json_string_args(
12031203
)
12041204

12051205
mock_pydantic_agent = MagicMock()
1206-
mock_pydantic_agent._function_tools = {}
1206+
mock_pydantic_agent.toolsets = []
12071207

12081208
agent = PydanticAICoreAgent(
12091209
config=config,

0 commit comments

Comments
 (0)