Skip to content

Commit 2d40c09

Browse files
authored
fix: #3055 surface model refusals during run resolution (#3057)
1 parent 3a3f34f commit 2d40c09

12 files changed

Lines changed: 226 additions & 16 deletions

docs/running_agents.md

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ Set the hook per run via `run_config` to redact sensitive data, trim long histor
410410

411411
### Error handlers
412412

413-
All `Runner` entry points accept `error_handlers`, a dict keyed by error kind. Today, the supported key is `"max_turns"`. Use it when you want to return a controlled final output instead of raising `MaxTurnsExceeded`.
413+
All `Runner` entry points accept `error_handlers`, a dict keyed by error kind. The supported keys are `"max_turns"` and `"model_refusal"`. Use them when you want to return a controlled final output instead of raising `MaxTurnsExceeded` or `ModelRefusalError`.
414414

415415
```python
416416
from agents import (
@@ -441,6 +441,38 @@ print(result.final_output)
441441

442442
Set `include_in_history=False` when you do not want the fallback output appended to conversation history.
443443

444+
Use `"model_refusal"` when a model refusal should produce an application-specific fallback instead of ending the run with `ModelRefusalError`.
445+
446+
```python
447+
from pydantic import BaseModel
448+
449+
from agents import Agent, ModelRefusalError, RunErrorHandlerInput, Runner
450+
451+
452+
class Recipe(BaseModel):
453+
ingredients: list[str]
454+
refusal_reason: str | None = None
455+
456+
457+
def on_model_refusal(data: RunErrorHandlerInput[None]) -> Recipe:
458+
assert isinstance(data.error, ModelRefusalError)
459+
return Recipe(ingredients=[], refusal_reason=data.error.refusal)
460+
461+
462+
agent = Agent(
463+
name="Recipe assistant",
464+
instructions="Return a structured recipe.",
465+
output_type=Recipe,
466+
)
467+
468+
result = Runner.run_sync(
469+
agent,
470+
"Make me something unsafe.",
471+
error_handlers={"model_refusal": on_model_refusal},
472+
)
473+
print(result.final_output)
474+
```
475+
444476
## Durable execution integrations and human-in-the-loop
445477

446478
For tool approval pause/resume patterns, start with the dedicated [Human-in-the-loop guide](human_in_the_loop.md).

src/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
InputGuardrailTripwireTriggered,
2323
MaxTurnsExceeded,
2424
ModelBehaviorError,
25+
ModelRefusalError,
2526
OutputGuardrailTripwireTriggered,
2627
RunErrorDetails,
2728
ToolInputGuardrailTripwireTriggered,
@@ -362,6 +363,7 @@ def enable_verbose_stdout_logging():
362363
"Prompt",
363364
"MaxTurnsExceeded",
364365
"ModelBehaviorError",
366+
"ModelRefusalError",
365367
"ToolTimeoutError",
366368
"UserError",
367369
"InputGuardrail",

src/agents/exceptions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@ def __init__(self, message: str):
6565
super().__init__(message)
6666

6767

68+
class ModelRefusalError(AgentsException):
69+
"""Exception raised when the model refuses to produce the requested output."""
70+
71+
refusal: str
72+
"""The refusal text returned by the model."""
73+
74+
def __init__(self, refusal: str):
75+
self.refusal = refusal
76+
super().__init__(f"Model refused to produce output: {refusal}")
77+
78+
6879
class UserError(AgentsException):
6980
"""Exception raised when the user makes an error using the SDK."""
7081

src/agents/items.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,19 @@ def extract_text(cls, message: TResponseOutputItem) -> str | None:
722722

723723
return text or None
724724

725+
@classmethod
726+
def extract_refusal(cls, message: TResponseOutputItem) -> str | None:
727+
"""Extracts refusal content from a message, if any."""
728+
if not isinstance(message, ResponseOutputMessage):
729+
return None
730+
731+
refusal = ""
732+
for content_item in message.content:
733+
if isinstance(content_item, ResponseOutputRefusal):
734+
refusal += content_item.refusal or ""
735+
736+
return refusal or None
737+
725738
@classmethod
726739
def input_to_new_input_list(
727740
cls, input: str | list[TResponseInputItem]

src/agents/run.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,7 @@ def _finalize_result(result: RunResult) -> RunResult:
11961196
),
11971197
reasoning_item_id_policy=resolved_reasoning_item_id_policy,
11981198
prompt_cache_key_resolver=prompt_cache_key_resolver,
1199+
error_handlers=error_handlers,
11991200
)
12001201
)
12011202

@@ -1251,6 +1252,7 @@ def _finalize_result(result: RunResult) -> RunResult:
12511252
),
12521253
reasoning_item_id_policy=resolved_reasoning_item_id_policy,
12531254
prompt_cache_key_resolver=prompt_cache_key_resolver,
1255+
error_handlers=error_handlers,
12541256
)
12551257
finally:
12561258
attach_usage_to_span(

src/agents/run_error_handlers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing_extensions import TypedDict
88

99
from .agent import Agent
10-
from .exceptions import MaxTurnsExceeded
10+
from .exceptions import MaxTurnsExceeded, ModelRefusalError
1111
from .items import ModelResponse, RunItem, TResponseInputItem
1212
from .run_context import RunContextWrapper, TContext
1313
from .util._types import MaybeAwaitable
@@ -27,7 +27,7 @@ class RunErrorData:
2727

2828
@dataclass
2929
class RunErrorHandlerInput(Generic[TContext]):
30-
error: MaxTurnsExceeded
30+
error: MaxTurnsExceeded | ModelRefusalError
3131
context: RunContextWrapper[TContext]
3232
run_data: RunErrorData
3333

@@ -51,6 +51,7 @@ class RunErrorHandlers(TypedDict, Generic[TContext], total=False):
5151
"""Error handlers keyed by error kind."""
5252

5353
max_turns: RunErrorHandler[TContext]
54+
model_refusal: RunErrorHandler[TContext]
5455

5556

5657
__all__ = [

src/agents/run_internal/error_handlers.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ..agent import Agent
1010
from ..agent_output import _WRAPPER_DICT_KEY, AgentOutputSchema
11-
from ..exceptions import MaxTurnsExceeded, ModelBehaviorError, UserError
11+
from ..exceptions import MaxTurnsExceeded, ModelBehaviorError, ModelRefusalError, UserError
1212
from ..items import (
1313
ItemHelpers,
1414
MessageOutputItem,
@@ -128,13 +128,16 @@ def create_message_output_item(agent: Agent[Any], output_text: str) -> MessageOu
128128
async def resolve_run_error_handler_result(
129129
*,
130130
error_handlers: RunErrorHandlers[TContext] | None,
131-
error: MaxTurnsExceeded,
131+
error: MaxTurnsExceeded | ModelRefusalError,
132132
context_wrapper: RunContextWrapper[TContext],
133133
run_data: RunErrorData,
134134
) -> RunErrorHandlerResult | None:
135135
if not error_handlers:
136136
return None
137-
handler = error_handlers.get("max_turns")
137+
if isinstance(error, ModelRefusalError):
138+
handler = error_handlers.get("model_refusal")
139+
else:
140+
handler = error_handlers.get("max_turns")
138141
if handler is None:
139142
return None
140143
handler_input = RunErrorHandlerInput(

src/agents/run_internal/run_loop.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,7 @@ async def _save_stream_items_without_count(
10301030
),
10311031
reasoning_item_id_policy=resolved_reasoning_item_id_policy,
10321032
prompt_cache_key_resolver=prompt_cache_key_resolver,
1033+
error_handlers=error_handlers,
10331034
)
10341035
finally:
10351036
attach_usage_to_span(
@@ -1248,6 +1249,7 @@ async def run_single_turn_streamed(
12481249
pending_server_items: list[RunItem] | None = None,
12491250
reasoning_item_id_policy: ReasoningItemIdPolicy | None = None,
12501251
prompt_cache_key_resolver: PromptCacheKeyResolver | None = None,
1252+
error_handlers: RunErrorHandlers[TContext] | None = None,
12511253
) -> SingleStepResult:
12521254
"""Run a single streamed turn and emit events as results arrive."""
12531255
public_agent = bindings.public_agent
@@ -1643,6 +1645,7 @@ async def rewind_model_request() -> None:
16431645
hooks=hooks,
16441646
context_wrapper=context_wrapper,
16451647
run_config=run_config,
1648+
error_handlers=error_handlers,
16461649
tool_use_tracker=tool_use_tracker,
16471650
server_manages_conversation=server_conversation_tracker is not None,
16481651
event_queue=streamed_result._event_queue,
@@ -1708,6 +1711,7 @@ async def run_single_turn(
17081711
session_items_to_rewind: list[TResponseInputItem] | None = None,
17091712
reasoning_item_id_policy: ReasoningItemIdPolicy | None = None,
17101713
prompt_cache_key_resolver: PromptCacheKeyResolver | None = None,
1714+
error_handlers: RunErrorHandlers[TContext] | None = None,
17111715
) -> SingleStepResult:
17121716
"""Run a single non-streaming turn of the agent loop."""
17131717
public_agent = bindings.public_agent
@@ -1775,6 +1779,7 @@ async def run_single_turn(
17751779
hooks=hooks,
17761780
context_wrapper=context_wrapper,
17771781
run_config=run_config,
1782+
error_handlers=error_handlers,
17781783
tool_use_tracker=tool_use_tracker,
17791784
server_manages_conversation=server_conversation_tracker is not None,
17801785
)

src/agents/run_internal/turn_resolution.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from ..agent import Agent, ToolsToFinalOutputResult
4343
from ..agent_output import AgentOutputSchemaBase
4444
from ..agent_tool_state import get_agent_tool_state_scope, peek_agent_tool_run_result
45-
from ..exceptions import ModelBehaviorError, UserError
45+
from ..exceptions import ModelBehaviorError, ModelRefusalError, UserError
4646
from ..handoffs import Handoff, HandoffInputData, HandoffInputFilter, nest_handoff_history
4747
from ..items import (
4848
CompactionItem,
@@ -68,6 +68,7 @@
6868
from ..logger import logger
6969
from ..run_config import RunConfig
7070
from ..run_context import AgentHookContext, RunContextWrapper, TContext
71+
from ..run_error_handlers import RunErrorHandlers
7172
from ..run_state import RunState
7273
from ..stream_events import StreamEvent
7374
from ..tool import (
@@ -89,6 +90,13 @@
8990
from ..util import _coro, _error_tracing
9091
from ..util._approvals import evaluate_needs_approval_setting
9192
from .agent_bindings import AgentBindings
93+
from .error_handlers import (
94+
build_run_error_data,
95+
create_message_output_item,
96+
format_final_output_text,
97+
resolve_run_error_handler_result,
98+
validate_handler_final_output,
99+
)
92100
from .items import (
93101
REJECTION_MESSAGE,
94102
apply_patch_rejection_item,
@@ -555,6 +563,7 @@ async def execute_tools_and_side_effects(
555563
hooks: RunHooks[TContext],
556564
context_wrapper: RunContextWrapper[TContext],
557565
run_config: RunConfig,
566+
error_handlers: RunErrorHandlers[TContext] | None = None,
558567
server_manages_conversation: bool = False,
559568
) -> SingleStepResult:
560569
"""Run one turn of the loop, coordinating tools, approvals, guardrails, and handoffs."""
@@ -668,6 +677,7 @@ async def execute_tools_and_side_effects(
668677
return tool_final_output
669678

670679
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
680+
refusal = ItemHelpers.extract_refusal(message_items[-1].raw_item) if message_items else None
671681
potential_final_output_text = (
672682
ItemHelpers.extract_text(message_items[-1].raw_item) if message_items else None
673683
)
@@ -677,6 +687,41 @@ async def execute_tools_and_side_effects(
677687
processed_response.tools_used
678688
)
679689
if not has_tool_activity_without_message:
690+
if refusal:
691+
refusal_error = ModelRefusalError(refusal)
692+
run_error_data = build_run_error_data(
693+
input=original_input,
694+
new_items=pre_step_items + new_step_items,
695+
raw_responses=[new_response],
696+
last_agent=public_agent,
697+
)
698+
handler_result = await resolve_run_error_handler_result(
699+
error_handlers=error_handlers,
700+
error=refusal_error,
701+
context_wrapper=context_wrapper,
702+
run_data=run_error_data,
703+
)
704+
if handler_result is None:
705+
raise refusal_error
706+
707+
final_output = validate_handler_final_output(
708+
public_agent, handler_result.final_output
709+
)
710+
if handler_result.include_in_history:
711+
output_text = format_final_output_text(public_agent, final_output)
712+
new_step_items.append(create_message_output_item(public_agent, output_text))
713+
return await execute_final_output_call(
714+
public_agent=public_agent,
715+
original_input=original_input,
716+
new_response=new_response,
717+
pre_step_items=pre_step_items,
718+
new_step_items=new_step_items,
719+
final_output=final_output,
720+
hooks=hooks,
721+
context_wrapper=context_wrapper,
722+
tool_input_guardrail_results=tool_input_guardrail_results,
723+
tool_output_guardrail_results=tool_output_guardrail_results,
724+
)
680725
if output_schema and not output_schema.is_plain_text() and potential_final_output_text:
681726
final_output = output_schema.validate_json(potential_final_output_text)
682727
return await execute_final_output_call(
@@ -1871,6 +1916,7 @@ async def get_single_step_result_from_response(
18711916
context_wrapper: RunContextWrapper[TContext],
18721917
run_config: RunConfig,
18731918
tool_use_tracker,
1919+
error_handlers: RunErrorHandlers[TContext] | None = None,
18741920
server_manages_conversation: bool = False,
18751921
event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None,
18761922
before_side_effects: Callable[[], Awaitable[None]] | None = None,
@@ -1907,5 +1953,6 @@ async def get_single_step_result_from_response(
19071953
hooks=hooks,
19081954
context_wrapper=context_wrapper,
19091955
run_config=run_config,
1956+
error_handlers=error_handlers,
19101957
server_manages_conversation=server_manages_conversation,
19111958
)

0 commit comments

Comments
 (0)