Skip to content

Commit 1381ccb

Browse files
authored
feat: default init agent optimze (#163)
1 parent 8957277 commit 1381ccb

13 files changed

Lines changed: 483 additions & 353 deletions

File tree

packages/derisk-app/src/derisk_app/app.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@
77
from fastapi.openapi.docs import get_swagger_ui_html
88
from fastapi.staticfiles import StaticFiles
99

10+
11+
class WebSocketAwareStaticFiles(StaticFiles):
12+
"""StaticFiles that gracefully handles WebSocket connections."""
13+
14+
async def __call__(self, scope, receive, send):
15+
if scope["type"] == "websocket":
16+
# WebSocket connections should not reach static files
17+
# Let them fall through to 404 or other handlers
18+
await send({"type": "websocket.close", "code": 1000})
19+
return
20+
await super().__call__(scope, receive, send)
21+
22+
1023
from derisk._version import version
1124
from derisk.component import SystemApp
1225
from derisk.configs.model_config import (
@@ -131,14 +144,18 @@ def mount_static_files(app: FastAPI, param: ApplicationConfig):
131144
os.makedirs(STATIC_MESSAGE_IMG_PATH, exist_ok=True)
132145
app.mount(
133146
"/images",
134-
StaticFiles(directory=STATIC_MESSAGE_IMG_PATH, html=True),
147+
WebSocketAwareStaticFiles(directory=STATIC_MESSAGE_IMG_PATH, html=True),
135148
name="static2",
136149
)
137-
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
150+
app.mount(
151+
"/",
152+
WebSocketAwareStaticFiles(directory=static_file_path, html=True),
153+
name="static",
154+
)
138155

139156
app.mount(
140157
"/swagger_static",
141-
StaticFiles(directory=static_file_path),
158+
WebSocketAwareStaticFiles(directory=static_file_path),
142159
name="swagger_static",
143160
)
144161

packages/derisk-core/src/derisk/agent/core/memory/gpts/gpts_memory.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,89 @@
5656
# 消息通道迭代器
5757
# --------------------------
5858
class QueueIterator:
59-
def __init__(self, queue: asyncio.Queue):
59+
"""异步队列迭代器,支持超时机制防止无限等待。
60+
61+
当队列长时间没有新消息时,会定期检查是否有错误发生,
62+
避免因后台任务卡死或异常导致的无限等待。
63+
"""
64+
65+
DEFAULT_TIMEOUT = 30.0
66+
MAX_TIMEOUT_COUNT = 3 # 最大连续超时次数,超过后抛出异常
67+
68+
def __init__(self, queue: asyncio.Queue, timeout: Optional[float] = None):
6069
self.queue = queue
70+
self.timeout = timeout if timeout is not None else self.DEFAULT_TIMEOUT
71+
self._error: Optional[Exception] = None
72+
self._stopped = False
73+
self._timeout_count = 0 # 连续超时计数器
74+
75+
def set_error(self, error: Exception):
76+
"""设置错误状态,将在下次迭代时抛出。"""
77+
self._error = error
78+
try:
79+
self.queue.put_nowait(None)
80+
except asyncio.QueueFull:
81+
pass
82+
83+
def stop(self):
84+
"""停止迭代器。"""
85+
self._stopped = True
86+
try:
87+
self.queue.put_nowait("[DONE]")
88+
except asyncio.QueueFull:
89+
pass
6190

6291
def __aiter__(self):
6392
return self
6493

6594
async def __anext__(self):
6695
start = time.perf_counter()
67-
item = await self.queue.get()
68-
if item == "[DONE]":
69-
self.queue.task_done()
70-
raise StopAsyncIteration
71-
logger.debug(f"Queue wait: {(time.perf_counter() - start) * 1000:.2f}ms")
72-
try:
73-
return item
74-
finally:
75-
self.queue.task_done()
96+
97+
while True:
98+
if self._error:
99+
raise self._error
100+
101+
if self._stopped:
102+
raise StopAsyncIteration
103+
104+
try:
105+
item = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
106+
self._timeout_count = 0 # 成功获取消息,重置超时计数
107+
except asyncio.TimeoutError:
108+
self._timeout_count += 1
109+
wait_time = time.perf_counter() - start
110+
111+
if self._timeout_count >= self.MAX_TIMEOUT_COUNT:
112+
logger.error(
113+
f"Queue timeout exceeded max retries ({self.MAX_TIMEOUT_COUNT}), "
114+
f"total wait: {wait_time:.2f}s. Terminating to prevent infinite wait."
115+
)
116+
raise TimeoutError(
117+
f"对话响应超时,已等待 {wait_time:.1f} 秒。"
118+
f"请检查后端服务状态或稍后重试。"
119+
)
120+
121+
logger.warning(
122+
f"Queue wait timeout ({self._timeout_count}/{self.MAX_TIMEOUT_COUNT}) "
123+
f"after {wait_time:.2f}s, continuing to wait... (queue size: {self.queue.qsize()})"
124+
)
125+
continue
126+
127+
if item == "[DONE]":
128+
self.queue.task_done()
129+
raise StopAsyncIteration
130+
131+
if item is None:
132+
if self._error:
133+
raise self._error
134+
self.queue.task_done()
135+
continue
136+
137+
logger.debug(f"Queue wait: {(time.perf_counter() - start) * 1000:.2f}ms")
138+
try:
139+
return item
140+
finally:
141+
self.queue.task_done()
76142

77143

78144
class AgentTaskType(Enum):
@@ -616,9 +682,11 @@ async def _merge_messages_async(
616682
# --------------------------
617683
# 外部核心方法区
618684
# --------------------------
619-
async def queue_iterator(self, conv_id: str) -> Optional[QueueIterator]:
685+
async def queue_iterator(
686+
self, conv_id: str, timeout: Optional[float] = None
687+
) -> Optional[QueueIterator]:
620688
cache = await self._get_cache(conv_id)
621-
return QueueIterator(cache.channel) if cache else None
689+
return QueueIterator(cache.channel, timeout=timeout) if cache else None
622690

623691
async def init(
624692
self,

packages/derisk-core/src/derisk/agent/core_v2/enhanced_agent.py

Lines changed: 4 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from .improved_compaction import ImprovedSessionCompaction, CompactionConfig
2323
from .llm_utils import call_llm, LLMCaller
24+
from .tools_v2 import ToolRegistry, ToolResult
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -154,88 +155,6 @@ def _match_pattern(self, tool_name: str, pattern: str) -> bool:
154155
return fnmatch.fnmatch(tool_name, pattern)
155156

156157

157-
class ToolRegistry:
158-
"""工具注册表"""
159-
160-
def __init__(self):
161-
self._tools: Dict[str, Any] = {}
162-
163-
def register(self, tool: Any) -> "ToolRegistry":
164-
# 优先使用 metadata.name,其次使用 name 属性
165-
if hasattr(tool, "metadata") and hasattr(tool.metadata, "name"):
166-
name = tool.metadata.name
167-
elif hasattr(tool, "name"):
168-
name = tool.name
169-
else:
170-
name = str(tool)
171-
self._tools[name] = tool
172-
logger.debug(f"[ToolRegistry] 注册工具: {name}")
173-
return self
174-
175-
def get(self, name: str) -> Optional[Any]:
176-
return self._tools.get(name)
177-
178-
async def execute(
179-
self, name: str, args: Dict[str, Any], context: Optional[Dict[str, Any]] = None
180-
) -> "ToolResult":
181-
"""执行工具"""
182-
from .tools_v2 import ToolResult
183-
184-
tool = self._tools.get(name)
185-
if not tool:
186-
return ToolResult(success=False, output="", error=f"工具不存在: {name}")
187-
188-
try:
189-
if hasattr(tool, "execute"):
190-
result = await tool.execute(args, context)
191-
return result
192-
elif callable(tool):
193-
result = tool(**args)
194-
if hasattr(result, "__await__"):
195-
result = await result
196-
if isinstance(result, ToolResult):
197-
return result
198-
return ToolResult(
199-
success=True,
200-
output=str(result) if result else "",
201-
)
202-
else:
203-
return ToolResult(
204-
success=False, output="", error=f"工具不可执行: {name}"
205-
)
206-
except Exception as e:
207-
logger.exception(f"[ToolRegistry] 工具执行异常: {name}")
208-
return ToolResult(success=False, output="", error=str(e))
209-
210-
def list_tools(self) -> List[str]:
211-
return list(self._tools.keys())
212-
213-
def list_all(self) -> List[Any]:
214-
"""列出所有工具对象"""
215-
return list(self._tools.values())
216-
217-
def list_names(self) -> List[str]:
218-
"""列出所有工具名称"""
219-
return list(self._tools.keys())
220-
221-
def get_openai_tools(self) -> List[Dict[str, Any]]:
222-
result = []
223-
for name, tool in self._tools.items():
224-
if hasattr(tool, "get_openai_spec"):
225-
result.append(tool.get_openai_spec())
226-
else:
227-
result.append(
228-
{
229-
"type": "function",
230-
"function": {
231-
"name": name,
232-
"description": getattr(tool, "description", ""),
233-
},
234-
}
235-
)
236-
return result
237-
238-
239158
@dataclass
240159
class SubagentSession:
241160
"""子代理会话"""
@@ -1119,8 +1038,9 @@ def _build_llm_messages(self) -> List:
11191038
else:
11201039
messages.append(SystemMessage(content=msg.content))
11211040

1122-
if self.tools.list_tools():
1123-
tools_desc = "Available tools: " + ", ".join(self.tools.list_tools())
1041+
if self.tools.list_all():
1042+
tool_names = [t.metadata.name for t in self.tools.list_all()]
1043+
tools_desc = "Available tools: " + ", ".join(tool_names)
11241044
messages.append(SystemMessage(content=tools_desc))
11251045

11261046
return messages

packages/derisk-core/src/derisk/agent/core_v2/integration/runtime.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -745,18 +745,28 @@ async def _push_user_message(self, conv_id: str, message: str):
745745

746746
# 保存到 GptsMemory (gpts_messages 表)
747747
if self.gpts_memory:
748-
user_msg = type(
749-
"GptsMessage",
750-
(),
751-
{
752-
"message_id": str(uuid.uuid4().hex),
753-
"conv_id": conv_id,
754-
"sender": "user",
755-
"receiver": "assistant",
756-
"content": message,
757-
"rounds": 0,
758-
},
759-
)()
748+
# Find the session to get additional info
749+
session = None
750+
for s in self._sessions.values():
751+
if s.conv_id == conv_id:
752+
session = s
753+
break
754+
755+
# Create a proper GptsMessage with all required attributes
756+
user_msg = GptsMessage(
757+
conv_id=conv_id,
758+
conv_session_id=session.session_id if session else conv_id,
759+
message_id=str(uuid.uuid4().hex),
760+
sender="user",
761+
sender_name="user",
762+
receiver="assistant",
763+
receiver_name="assistant",
764+
role="user",
765+
content=message,
766+
rounds=0,
767+
app_code=session.agent_name if session else None,
768+
app_name=session.agent_name if session else None,
769+
)
760770
await self.gpts_memory.append_message(conv_id, user_msg, save_db=True)
761771

762772
# 同时保存到 StorageConversation (ChatHistoryMessageEntity 表)
@@ -865,19 +875,21 @@ async def _push_stream_chunk(self, conv_id: str, chunk: V2StreamChunk):
865875

866876
# 保存到 GptsMemory (gpts_messages 表)
867877
if self.gpts_memory and session.accumulated_content:
868-
assistant_msg = type(
869-
"GptsMessage",
870-
(),
871-
{
872-
"message_id": session.current_message_id
873-
or str(uuid.uuid4().hex),
874-
"conv_id": conv_id,
875-
"sender": session.agent_name,
876-
"receiver": "user",
877-
"content": vis_final_content,
878-
"rounds": 0,
879-
},
880-
)()
878+
# Create a proper GptsMessage with all required attributes
879+
assistant_msg = GptsMessage(
880+
conv_id=conv_id,
881+
conv_session_id=session.session_id,
882+
message_id=session.current_message_id or str(uuid.uuid4().hex),
883+
sender=session.agent_name,
884+
sender_name=session.agent_name,
885+
receiver="user",
886+
receiver_name="user",
887+
role="assistant",
888+
content=vis_final_content,
889+
rounds=0,
890+
app_code=session.agent_name,
891+
app_name=session.agent_name,
892+
)
881893
await self.gpts_memory.append_message(
882894
conv_id, assistant_msg, save_db=True
883895
)

packages/derisk-core/src/derisk/agent/core_v2/tools_v2/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,21 @@
3535
from derisk.agent.tools.registry import (
3636
ToolRegistry,
3737
tool_registry,
38-
register_builtin_tools,
3938
)
4039

40+
41+
def register_builtin_tools(registry: ToolRegistry) -> None:
42+
"""
43+
注册内置工具到注册表(兼容层)
44+
45+
此函数接受 registry 参数,以兼容旧版 API。
46+
内部调用 derisk.agent.tools.builtin.register_all
47+
"""
48+
from derisk.agent.tools.builtin import register_all
49+
50+
register_all(registry)
51+
52+
4153
from derisk.agent.tools.decorators import tool
4254

4355
# 从内置工具模块导入具体工具
@@ -167,7 +179,7 @@ def register_all_tools(
167179
import logging
168180

169181
logger = logging.getLogger(__name__)
170-
logger.info(f"[Tools] 已注册所有工具,共 {len(registry.list_names())} 个")
182+
logger.info(f"[Tools] 已注册所有工具,共 {len(registry.list_all())} 个")
171183

172184
return registry
173185

0 commit comments

Comments
 (0)