Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 136 additions & 58 deletions api/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def _agent_chain(typ: str, started_at: float, step_id: str, payload: dict[str, A
"plan_id": "",
"tool": "text2sql_query",
"sql_draft": "",
"rewrite_query": "",
"planned_top_k": 10,
"preview_headlines": [],
"warnings": [],
"plan_execution_token": "",
"expires_in_sec": 120,
Expand Down Expand Up @@ -446,28 +449,51 @@ async def run(
# 不得沿用「已切到 fallback 工具」的 step1_mode(常见 rag),否则 Timeline 像已转 RAG 却无任何工具执行。
clarify_gate = os.getenv("CHATBI_V3_LOW_CONFIDENCE_CLARIFY", "").strip().lower() in ("1", "true", "yes")
from .chatbi_plan_token import (
mint_clarify_text2sql_bypass_token,
mint_clarify_plan_bypass_token,
plan_preview_confirm_enabled,
plan_token_ttl_s,
verify_clarify_text2sql_bypass_token,
verify_clarify_plan_bypass_token,
)

_plan_bypass = plan_preview_confirm_enabled() and verify_clarify_text2sql_bypass_token(
plan_execution_token, session_id=session_id, query=query
)
_clarify_eligible = (
_plan_bypass_tool: ToolName | None = None
if plan_preview_confirm_enabled():
if verify_clarify_plan_bypass_token(
plan_execution_token,
session_id=session_id,
query=query,
expected_tool="text2sql_query",
):
_plan_bypass_tool = "text2sql_query"
elif verify_clarify_plan_bypass_token(
plan_execution_token,
session_id=session_id,
query=query,
expected_tool="rag_search",
):
_plan_bypass_tool = "rag_search"

_clarify_tool: ToolName | None = None
if (
clarify_gate
and prefer == "auto"
and intent is not None
and intent.tool == "text2sql_query"
and intent.confidence < self._min_confidence
and not _plan_bypass
)
# 用户已持有效 plan_execution_token:本轮回放首步须回到意图候选 text2sql,而非低置信 fallback 的 rag。
if _plan_bypass:
step1_tool = "text2sql_query"
and _plan_bypass_tool is None
):
if intent.tool == "text2sql_query":
_clarify_tool = "text2sql_query"
elif intent.tool == "rag_search":
_clarify_tool = "rag_search"
_clarify_eligible = _clarify_tool is not None

# 用户已持有效 plan_execution_token:本轮回放首步须回到确认时的工具,而非低置信 fallback。
if _plan_bypass_tool:
step1_tool = _plan_bypass_tool
step1_mode = self._tool_to_mode(step1_tool)
step1_reasoning = "已校验 plan_execution_token,按用户确认放行执行 Text2SQL。"
if _plan_bypass_tool == "rag_search":
step1_reasoning = "已校验 plan_execution_token,按用户确认放行执行 RAG 检索。"
else:
step1_reasoning = "已校验 plan_execution_token,按用户确认放行执行 Text2SQL。"

# Step 循环:必须多步(允许成功在 2 步内结束,但失败应触发继续)
current_tool: ToolName = step1_tool
Expand Down Expand Up @@ -644,16 +670,30 @@ async def _emit_final_chains(fin: AgentFinalView, answer: str) -> None:
)
)

# P1-4 §4.3:低置信 + SQL 候选时可选「澄清短路」(默认关,避免改变现网行为)
if _clarify_eligible:
# P1-4 §4.3:低置信 + text2sql/rag 候选时可选「澄清短路」(默认关,避免改变现网行为)
if _clarify_eligible and _clarify_tool is not None:
_cl_msg = "待您澄清(低置信度)"
plan_preview_payload: dict[str, Any] | None = None
plan_ttl_s = plan_token_ttl_s()
ttl_notice = (
f"若确认按预览 SQL 继续查数:请在 {plan_ttl_s} 秒内在**下一轮同一问题**的请求 JSON 中带 "
f"`\"plan_execution_token\": \"…\"`(见 `agent.plan.preview` 中的 `plan_execution_token`)。"
"若未及时附带令牌,本预览 SQL 与该令牌均失效,须**重新发起本问题**才能再次预览。"
)
_clarify_mode = self._tool_to_mode(_clarify_tool)
if _clarify_tool == "rag_search":
ttl_notice = (
f"若确认按预览检索方案继续:请在 {plan_ttl_s} 秒内在**下一轮同一问题**的请求 JSON 中带 "
f"`\"plan_execution_token\": \"…\"`(见 `agent.plan.preview` 中的 `plan_execution_token`)。"
"若未及时附带令牌,本预览方案与该令牌均失效,须**重新发起本问题**才能再次预览。"
)
_preview_fail_hint = (
"(本轮未能生成可放行的 RAG 方案预览,无法签发 plan_execution_token;请改问或补充检索范围。)"
)
else:
ttl_notice = (
f"若确认按预览 SQL 继续查数:请在 {plan_ttl_s} 秒内在**下一轮同一问题**的请求 JSON 中带 "
f"`\"plan_execution_token\": \"…\"`(见 `agent.plan.preview` 中的 `plan_execution_token`)。"
"若未及时附带令牌,本预览 SQL 与该令牌均失效,须**重新发起本问题**才能再次预览。"
)
_preview_fail_hint = (
"(本轮未能生成可放行的 SQL 预览,无法签发 plan_execution_token;请改问或使用 prefer=text2sql。)"
)
use_reasoning = (os.getenv("CHATBI_V3_CLARIFY_PROMPT_USE_REASONING", "") or "").strip().lower() in (
"1",
"true",
Expand All @@ -673,35 +713,76 @@ async def _emit_final_chains(fin: AgentFinalView, answer: str) -> None:
_cl_prompt = _generic
if plan_preview_confirm_enabled():
_cl_prompt = (_cl_prompt.rstrip() + "\n\n" + ttl_notice).strip()
from .tools import text2sql_execute as _t2s_preview # noqa: PLC0415

_prev_hist: list[dict[str, Any]] = turn_history[-6:]
_t2s_json_ctx: dict[str, Any] | None = None
if run_id:
_t2s_json_ctx = {"request_id": run_id, "run_id": run_id, "session_id": session_id}
_pr = await _t2s_preview(
query,
history=_prev_hist,
debug_llm_prompts=debug_llm_prompts,
chain_emit=emit,
chain_started_at=ts_ref,
json_log_ctx=_t2s_json_ctx,
preview_only=True,
)
sql_pv = ""
if _pr.success and isinstance(_pr.data, dict) and isinstance(_pr.data.get("sql"), str):
sql_pv = (_pr.data.get("sql") or "").strip()
if sql_pv:
exec_tok = mint_clarify_text2sql_bypass_token(session_id=session_id, query=query)
plan_prev_id = str(uuid.uuid4()).replace("-", "")[:20]
plan_preview_payload = {
"plan_id": plan_prev_id,
"tool": "text2sql_query",
"sql_draft": sql_pv,
"warnings": [ttl_notice],
"plan_execution_token": exec_tok,
"expires_in_sec": plan_ttl_s,
}
plan_prev_id = str(uuid.uuid4()).replace("-", "")[:20]
if _clarify_tool == "text2sql_query":
from .tools import text2sql_execute as _t2s_preview # noqa: PLC0415

_t2s_json_ctx: dict[str, Any] | None = None
if run_id:
_t2s_json_ctx = {"request_id": run_id, "run_id": run_id, "session_id": session_id}
_pr = await _t2s_preview(
query,
history=_prev_hist,
debug_llm_prompts=debug_llm_prompts,
chain_emit=emit,
chain_started_at=ts_ref,
json_log_ctx=_t2s_json_ctx,
preview_only=True,
)
sql_pv = ""
if _pr.success and isinstance(_pr.data, dict) and isinstance(_pr.data.get("sql"), str):
sql_pv = (_pr.data.get("sql") or "").strip()
if sql_pv:
exec_tok = mint_clarify_plan_bypass_token(
session_id=session_id, query=query, tool="text2sql_query"
)
plan_preview_payload = {
"plan_id": plan_prev_id,
"tool": "text2sql_query",
"sql_draft": sql_pv,
"warnings": [ttl_notice],
"plan_execution_token": exec_tok,
"expires_in_sec": plan_ttl_s,
}
else:
from .tools import rag_search_execute as _rag_preview # noqa: PLC0415

_pr_rag = await _rag_preview(
query,
history=_prev_hist,
debug_llm_prompts=debug_llm_prompts,
preview_only=True,
)
rewrite_pv = ""
planned_k = 10
headlines: list[str] = []
if _pr_rag.success and isinstance(_pr_rag.data, dict):
rw = _pr_rag.data.get("rewritten")
if isinstance(rw, str):
rewrite_pv = rw.strip()
try:
planned_k = int(_pr_rag.data.get("planned_top_k") or 10)
except Exception: # noqa: BLE001
planned_k = 10
ph = _pr_rag.data.get("preview_headlines")
if isinstance(ph, list):
headlines = [str(x) for x in ph if x][:6]
if rewrite_pv:
exec_tok = mint_clarify_plan_bypass_token(
session_id=session_id, query=query, tool="rag_search"
)
plan_preview_payload = {
"plan_id": plan_prev_id,
"tool": "rag_search",
"rewrite_query": rewrite_pv,
"planned_top_k": planned_k,
"preview_headlines": headlines,
"warnings": [ttl_notice],
"plan_execution_token": exec_tok,
"expires_in_sec": plan_ttl_s,
}
if plan_preview_payload:
if emit is not None:
await emit(
_agent_chain(
Expand All @@ -718,15 +799,12 @@ async def _emit_final_chains(fin: AgentFinalView, answer: str) -> None:
run_id=run_id,
session_id=session_id,
route="agent",
mode="text2sql",
mode=_clarify_mode,
plan_id=plan_prev_id,
gate_bypass_reason="plan_preview_token_minted",
)
else:
_cl_prompt = (
_cl_prompt.rstrip()
+ "\n\n(本轮未能生成可放行的 SQL 预览,无法签发 plan_execution_token;请改问或使用 prefer=text2sql。)"
).strip()
_cl_prompt = (_cl_prompt.rstrip() + "\n\n" + _preview_fail_hint).strip()

clarify_pl: dict[str, Any] = {"step_number": 1, "message": _cl_msg, "prompt_for_user": _cl_prompt}
if chatbi_json_log_enabled() and run_id:
Expand All @@ -736,21 +814,21 @@ async def _emit_final_chains(fin: AgentFinalView, answer: str) -> None:
run_id=run_id,
session_id=session_id,
route="agent",
mode="text2sql",
intent_tool="text2sql_query",
mode=_clarify_mode,
intent_tool=_clarify_tool,
intent_confidence=float(intent.confidence),
clarify_gate=True,
)
_final_answer = (
"系统在继续查数前需要先与您对齐语义。请查看 Timeline 中「待您澄清」条目并补充说明;"
"也可改用 prefer=text2sql 强制路径或改写问题后重试。"
"也可改用 prefer=text2sql / prefer=rag 强制路径或改写问题后重试。"
)
final_cl = AgentFinalView(
answer=_final_answer,
mode="text2sql",
mode=_clarify_mode,
total_steps=0,
tools_used=[],
modes=["text2sql"],
modes=[_clarify_mode],
fallback_used=False,
)
if emit is not None:
Expand Down
42 changes: 34 additions & 8 deletions api/chatbi_plan_token.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""低置信 Text2SQL 澄清:一次性放行令牌(HMAC,无服务端会话表)。"""
"""低置信澄清:一次性放行令牌(HMAC,无服务端会话表)。"""

from __future__ import annotations

Expand All @@ -10,7 +10,8 @@
import time
from typing import Any

_PURPOSE = "clarify_text2sql_once"
_PURPOSE_LEGACY = "clarify_text2sql_once"
_PURPOSE = "clarify_plan_once"
_SIG_LEN = 32 # HMAC-SHA256 digest size;签名字节中可出现 ``0x0a``,不得用 ``rsplit(b"\\n")`` 定界


Expand Down Expand Up @@ -40,12 +41,13 @@ def plan_token_ttl_s() -> int:
return max(30, min(600, v))


def mint_clarify_text2sql_bypass_token(*, session_id: str | None, query: str) -> str:
"""签发「跳过一轮 clarify」令牌:绑定 session + 问句指纹 + TTL。"""
def mint_clarify_plan_bypass_token(*, session_id: str | None, query: str, tool: str) -> str:
"""签发「跳过一轮 clarify」令牌:绑定 session + 问句指纹 + 工具 + TTL。"""
ttl = plan_token_ttl_s()
payload: dict[str, Any] = {
"v": 1,
"p": _PURPOSE,
"t": (tool or "").strip()[:64],
"sid": (session_id or "")[:256],
"qh": _query_fingerprint(query),
"exp": int(time.time()) + ttl,
Expand All @@ -56,14 +58,28 @@ def mint_clarify_text2sql_bypass_token(*, session_id: str | None, query: str) ->
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")


def mint_clarify_text2sql_bypass_token(*, session_id: str | None, query: str) -> str:
"""兼容 5-2:等价于 ``mint_clarify_plan_bypass_token(..., tool=text2sql_query)``。"""
return mint_clarify_plan_bypass_token(session_id=session_id, query=query, tool="text2sql_query")


def _b64url_decode_padded(token: str) -> bytes:
"""urlsafe base64 无填充串的可靠解码(禁止固定追加 ``==``,否则部分长度会误解码/验签偶发失败)。"""
t = token.strip()
pad = (-len(t)) % 4
return base64.urlsafe_b64decode(t + ("=" * pad))


def verify_clarify_text2sql_bypass_token(token: str | None, *, session_id: str | None, query: str) -> bool:
def _payload_tool(payload: dict[str, Any]) -> str:
if payload.get("p") == _PURPOSE_LEGACY:
return "text2sql_query"
t = payload.get("t")
return t if isinstance(t, str) else ""


def verify_clarify_plan_bypass_token(
token: str | None, *, session_id: str | None, query: str, expected_tool: str
) -> bool:
if not isinstance(token, str) or not token.strip():
return False
try:
Expand All @@ -84,17 +100,27 @@ def verify_clarify_text2sql_bypass_token(token: str | None, *, session_id: str |
payload = json.loads(body_b.decode("utf-8"))
except Exception: # noqa: BLE001
return False
if payload.get("p") != _PURPOSE or int(payload.get("exp") or 0) < int(time.time()):
purpose = payload.get("p")
if purpose not in (_PURPOSE, _PURPOSE_LEGACY):
return False
if int(payload.get("exp") or 0) < int(time.time()):
return False
if (payload.get("sid") or "") != (session_id or ""):
return False
if (payload.get("qh") or "") != _query_fingerprint(query):
return False
return True
tool = _payload_tool(payload)
return tool == (expected_tool or "").strip()


def verify_clarify_text2sql_bypass_token(token: str | None, *, session_id: str | None, query: str) -> bool:
return verify_clarify_plan_bypass_token(
token, session_id=session_id, query=query, expected_tool="text2sql_query"
)


def plan_preview_confirm_enabled() -> bool:
"""低置信澄清时是否走 SQL 预览 + plan_execution_token。
"""低置信澄清时是否走方案预览 + plan_execution_token。

默认 **开启**(未设置或空字符串视为开);显式 ``0``/``false``/``no``/``off`` 关闭。
仍须 ``CHATBI_V3_LOW_CONFIDENCE_CLARIFY`` 开启且命中澄清分支才会实际预览。
Expand Down
25 changes: 25 additions & 0 deletions api/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ async def rag_search_execute(
*,
history: list[dict[str, Any]] | None = None,
debug_llm_prompts: bool = False,
preview_only: bool = False,
) -> ToolResult:
started_at = time.perf_counter()
hist = history or []
Expand Down Expand Up @@ -244,6 +245,30 @@ def _sync_rw() -> str:
latency_ms=_elapsed_ms(started_at),
)

if preview_only:
planned_top_k = int(retrieved.get("top_k") or 10)
headlines: list[str] = []
for h in hits[:6]:
if not isinstance(h, dict):
continue
label = (
h.get("filename")
or h.get("title")
or h.get("path")
or h.get("url")
or h.get("id")
)
if isinstance(label, str) and label.strip():
headlines.append(label.strip()[:120])
out_preview: dict[str, Any] = {
"rewritten": rewritten,
"planned_top_k": planned_top_k,
"preview_headlines": headlines,
}
if debug_llm_prompts and llm_prompts:
out_preview["llm_prompts"] = llm_prompts
return ToolResult(success=True, data=out_preview, latency_ms=_elapsed_ms(started_at))

parts: list[str] = []
for i, h in enumerate(hits[:12]):
content = h.get("content") if isinstance(h, dict) else None
Expand Down
Loading