Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions src/application/use_cases/send_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,39 @@ def __init__(self, registry: AgentRegistry, threads: ThreadRepository):
self._registry = registry
self._threads = threads

@staticmethod
def _is_duplicate_human_message(messages: list, message: str) -> bool:
"""Detect duplicate HUMAN message submissions (crash/retry scenario).

When a request crashes before the AI response is persisted, the last DB message
is HUMAN with status=None. On client retry, this check prevents storing a
duplicate HUMAN message in the DB.

NOTE: The graph invocation still proceeds (LangGraph will add the human message
to its internal checkpoint state). This is intentional — the graph needs to be
invoked to produce a response. The trade-off is that the LangGraph checkpoint
may accumulate duplicate human messages, but the DB projection remains clean.
"""
if not messages:
return False
last = messages[-1]
return (
last.role == MessageRole.HUMAN
and last.content == message
and last.status is None
)

async def execute(self, thread_id: str, request: ChatRequest) -> Message:
thread = await self._threads.get(thread_id)
runner = await self._registry.get_runner(thread.agent_name)

if request.message is not None:
logger.info("[thread=%s][agent=%s] Sending human message", thread_id, thread.agent_name)
human_msg = Message(role=MessageRole.HUMAN, content=request.message)
await self._threads.add_message(thread_id, human_msg)
if not self._is_duplicate_human_message(thread.messages, request.message):
human_msg = Message(role=MessageRole.HUMAN, content=request.message)
await self._threads.add_message(thread_id, human_msg)
else:
logger.info("[thread=%s] Skipping duplicate HUMAN message", thread_id)
start = time.monotonic()
response = await runner.invoke(thread_id, request.message)
elapsed = time.monotonic() - start
Expand Down
31 changes: 28 additions & 3 deletions src/application/use_cases/stream_message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import sys
import time
import json
from collections.abc import AsyncGenerator

from src.domain.entities.message import Message, MessageRole
Expand All @@ -25,10 +25,35 @@ def __init__(self, registry: AgentRegistry, threads: ThreadRepository):
self._registry = registry
self._threads = threads

@staticmethod
def _is_duplicate_human_message(messages: list, message: str) -> bool:
"""Detect duplicate HUMAN message submissions (crash/retry scenario).

When a stream crashes before the AI response is persisted, the last DB message
is HUMAN with status=None. On client retry, this check prevents storing a
duplicate HUMAN message in the DB.

NOTE: The graph invocation still proceeds (LangGraph will add the human message
to its internal checkpoint state). This is intentional — the graph needs to be
invoked to produce a response. The trade-off is that the LangGraph checkpoint
may accumulate duplicate human messages, but the DB projection remains clean.
"""
if not messages:
return False
last = messages[-1]
return (
last.role == MessageRole.HUMAN
and last.content == message
and last.status is None
)

async def execute(self, thread_id: str, message: str) -> AsyncGenerator[StreamEvent, None]:
thread = await self._threads.get(thread_id)
human_msg = Message(role=MessageRole.HUMAN, content=message)
await self._threads.add_message(thread_id, human_msg)
if not self._is_duplicate_human_message(thread.messages, message):
human_msg = Message(role=MessageRole.HUMAN, content=message)
await self._threads.add_message(thread_id, human_msg)
else:
logger.info("[thread=%s] Skipping duplicate HUMAN message", thread_id)
runner = await self._registry.get_runner(thread.agent_name)
start = time.monotonic()
logger.info("[thread=%s][agent=%s] Stream started", thread_id, thread.agent_name)
Expand Down
26 changes: 26 additions & 0 deletions src/infrastructure/deepagent/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,32 @@ def __init__(self, graph, tracing_provider: TracingProvider | None = None, respo
self._graph = graph
self._tracing_provider = tracing_provider
self._response_format_model = response_format_model
self._patch_tool_node_error_handling()

def _patch_tool_node_error_handling(self) -> None:
"""Patch ToolNode to catch all tool errors (not just ToolInvocationError).

By default, LangGraph's ToolNode only catches ToolInvocationError, which means
Pydantic ValidationError from hallucinated parameters crashes the graph. Setting
_handle_tool_errors=True causes any exception to be surfaced as a ToolMessage,
allowing the LLM to self-correct.

TODO: Prefer configuring handle_tool_errors=True at ToolNode construction time
in the factory, rather than monkey-patching at runtime.
"""
tools_node = self._graph.nodes.get("tools")
if tools_node is None:
logger.warning("No 'tools' node found in graph; cannot patch handle_tool_errors")
return
tool_node_impl = getattr(tools_node, "bound", None)
if tool_node_impl is None:
logger.warning("'tools' node has no 'bound' attribute; cannot patch handle_tool_errors")
return
if hasattr(tool_node_impl, "_handle_tool_errors"):
tool_node_impl._handle_tool_errors = True
logger.info("Patched ToolNode handle_tool_errors=True")
else:
logger.warning("ToolNode bound object missing _handle_tool_errors; patch not applied")

@staticmethod
def _try_parse_json(content: str) -> dict | None:
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_deep_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,3 +413,41 @@ async def test_validate_structured_response_from_tool_call(self):

assert result.structured_response == {"summary": "ok"}
assert "hallucinated" not in result.structured_response

# --- ToolNode error handling patch tests ---

def test_patch_tool_error_handling_sets_true(self):
tools_node = MagicMock()
tool_node_impl = MagicMock()
tool_node_impl._handle_tool_errors = MagicMock()
tools_node.bound = tool_node_impl

graph = MagicMock()
graph.nodes = {"tools": tools_node}

DeepAgentRunner(graph)
assert tool_node_impl._handle_tool_errors is True

def test_patch_tool_error_handling_no_tools_node(self):
graph = MagicMock()
graph.nodes = {"model": MagicMock()}

DeepAgentRunner(graph)

def test_patch_tool_error_handling_no_bound(self):
tools_node = MagicMock(spec=["bound"])
del tools_node.bound
graph = MagicMock()
graph.nodes = {"tools": tools_node}

DeepAgentRunner(graph)

def test_patch_tool_error_handling_no_handle_tool_errors_attr(self):
tools_node = MagicMock()
tool_node_impl = MagicMock(spec=[])
tools_node.bound = tool_node_impl

graph = MagicMock()
graph.nodes = {"tools": tools_node}

DeepAgentRunner(graph)
40 changes: 39 additions & 1 deletion tests/unit/test_send_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ async def test_unsupported_hitl_action_raises_value_error(self, registry, thread
thread = await thread_repo.create("test-agent")
use_case = SendMessageUseCase(registry, thread_repo)

# Bypass Pydantic validation to simulate an unexpected action value
request = MagicMock(spec=ChatRequest)
request.message = None
request.action = "unknown_action"
Expand All @@ -95,3 +94,42 @@ async def test_unsupported_hitl_action_raises_value_error(self, registry, thread

with pytest.raises(ValueError, match="Unsupported HITL action"):
await use_case.execute(thread.id, request)

# --- _is_duplicate_human_message tests ---

def test_is_duplicate_true_when_last_is_human_same_content_no_status(self):
messages = [Message(role=MessageRole.HUMAN, content="Hello")]
assert SendMessageUseCase._is_duplicate_human_message(messages, "Hello") is True

def test_is_duplicate_false_when_last_is_human_different_content(self):
messages = [Message(role=MessageRole.HUMAN, content="Hello")]
assert SendMessageUseCase._is_duplicate_human_message(messages, "World") is False

def test_is_duplicate_false_when_last_is_human_with_status(self):
messages = [Message(role=MessageRole.HUMAN, content="Hello", status=MessageStatus.COMPLETED)]
assert SendMessageUseCase._is_duplicate_human_message(messages, "Hello") is False

def test_is_duplicate_false_when_last_is_ai(self):
messages = [Message(role=MessageRole.AI, content="Hello")]
assert SendMessageUseCase._is_duplicate_human_message(messages, "Hello") is False

def test_is_duplicate_false_when_empty_messages(self):
assert SendMessageUseCase._is_duplicate_human_message([], "Hello") is False

def test_is_duplicate_false_when_mixed_messages_last_is_ai(self):
messages = [
Message(role=MessageRole.HUMAN, content="Hello"),
Message(role=MessageRole.AI, content="Hi there"),
]
assert SendMessageUseCase._is_duplicate_human_message(messages, "Hello") is False

async def test_execute_skips_duplicate_human_message_when_last_is_pending_human(self, registry, thread_repo):
thread = await thread_repo.create("test-agent")
await thread_repo.add_message(thread.id, Message(role=MessageRole.HUMAN, content="Hello agent!"))

use_case = SendMessageUseCase(registry, thread_repo)
await use_case.execute(thread.id, ChatRequest(message="Hello agent!"))

updated = await thread_repo.get(thread.id)
human_msgs = [m for m in updated.messages if m.role == MessageRole.HUMAN]
assert len(human_msgs) == 1
82 changes: 82 additions & 0 deletions tests/unit/test_stream_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Tests for StreamMessageUseCase.

Uses real InMemoryThreadRepository (internal).
Uses AsyncMock for AgentRunner (external - calls LLM).
"""

from unittest.mock import AsyncMock

import pytest

from src.application.use_cases.stream_message import StreamMessageUseCase
from src.domain.entities.message import Message, MessageRole, MessageStatus
from src.domain.entities.stream_event import StreamEvent, StreamEventType
from src.domain.ports.agent_runner import AgentRunner


class TestIsDuplicateHumanMessage:
def test_true_when_last_is_human_same_content_no_status(self):
messages = [Message(role=MessageRole.HUMAN, content="Hello")]
assert StreamMessageUseCase._is_duplicate_human_message(messages, "Hello") is True

def test_false_when_last_is_human_different_content(self):
messages = [Message(role=MessageRole.HUMAN, content="Hello")]
assert StreamMessageUseCase._is_duplicate_human_message(messages, "World") is False

def test_false_when_last_is_human_with_status(self):
messages = [Message(role=MessageRole.HUMAN, content="Hello", status=MessageStatus.COMPLETED)]
assert StreamMessageUseCase._is_duplicate_human_message(messages, "Hello") is False

def test_false_when_last_is_ai(self):
messages = [Message(role=MessageRole.AI, content="Hello")]
assert StreamMessageUseCase._is_duplicate_human_message(messages, "Hello") is False

def test_false_when_empty_messages(self):
assert StreamMessageUseCase._is_duplicate_human_message([], "Hello") is False

def test_false_when_mixed_messages_last_is_ai(self):
messages = [
Message(role=MessageRole.HUMAN, content="Hello"),
Message(role=MessageRole.AI, content="Hi there"),
]
assert StreamMessageUseCase._is_duplicate_human_message(messages, "Hello") is False


class TestStreamMessageUseCase:
@pytest.fixture
def runner(self):
mock = AsyncMock(spec=AgentRunner)

async def _stream_with_message(_thread_id, _message):
yield StreamEvent(type=StreamEventType.CONTENT, data="Hi")
yield StreamEvent(
type=StreamEventType.MESSAGE,
data=Message(
role=MessageRole.AI,
content="Hi",
status=MessageStatus.COMPLETED,
).model_dump_json(),
)

mock.stream_with_message = _stream_with_message
return mock

@pytest.fixture
def registry(self, runner):
mock = AsyncMock()
mock.get_runner.return_value = runner
return mock

async def test_execute_skips_duplicate_human_message(self, registry, thread_repo):
thread = await thread_repo.create("test-agent")
await thread_repo.add_message(thread.id, Message(role=MessageRole.HUMAN, content="Hello"))

use_case = StreamMessageUseCase(registry, thread_repo)

events = []
async for event in use_case.execute(thread.id, "Hello"):
events.append(event)

updated = await thread_repo.get(thread.id)
human_msgs = [m for m in updated.messages if m.role == MessageRole.HUMAN]
assert len(human_msgs) == 1
Loading