From 1da2502ce1aac62a3f5916fa32479107e379cd77 Mon Sep 17 00:00:00 2001 From: Christian Bager Bach Houmann Date: Wed, 17 Jun 2026 08:34:44 +0200 Subject: [PATCH 1/8] feat(ai): provider-neutral tool-calling wire + bounded agent loop (#714) Add the internal layer for LLM tool/function calling, shared across QuickAdd's three provider wire shapes (OpenAI-compatible, Anthropic Messages, Gemini generateContent): - NormalizedTools: provider-neutral request/message/tool-call types. - providerToolMapping: buildChatBody/parseChatResponse per provider, incl. OpenAI JSON-string arg parsing, Anthropic tool_result turns + native output_config.format (additionalProperties injected, as Anthropic requires), Gemini functionCall/Response + thoughtSignature echo, and max_completion_tokens for GPT-5.x/o-series reasoning models. - runToolLoop: pure bounded multi-step loop (forced-final text turn, one result per call id, step/result-byte/cumulative-transcript-byte caps, graceful context-overflow + online-disabled stops). - jsonSchemaValidator: registration-time keyword allowlist + value validation. - sanitizeVaultPath: every-segment dot-floor + traversal/absolute rejection. - OpenAIRequest: chatRequest() dispatch + Anthropic request-path fix; legacy single-shot routing unified on provider kind. - Provider: ProviderKind discriminator + getProviderKind() + migration. 
--- src/ai/OpenAIRequest.test.ts | 74 +++- src/ai/OpenAIRequest.ts | 245 ++++++++++- src/ai/Provider.test.ts | 27 ++ src/ai/Provider.ts | 38 ++ src/ai/tools/NormalizedTools.ts | 117 ++++++ src/ai/tools/aiToolTypes.ts | 126 ++++++ src/ai/tools/assignableVariable.ts | 26 ++ src/ai/tools/jsonSchemaValidator.test.ts | 101 +++++ src/ai/tools/jsonSchemaValidator.ts | 218 ++++++++++ src/ai/tools/providerToolMapping.test.ts | 265 ++++++++++++ src/ai/tools/providerToolMapping.ts | 502 +++++++++++++++++++++++ src/ai/tools/runToolLoop.test.ts | 238 +++++++++++ src/ai/tools/runToolLoop.ts | 326 +++++++++++++++ src/ai/tools/sanitizeVaultPath.test.ts | 116 ++++++ src/ai/tools/sanitizeVaultPath.ts | 144 +++++++ src/migrations/migrate.ts | 2 + src/migrations/setProviderKind.ts | 37 ++ 17 files changed, 2579 insertions(+), 23 deletions(-) create mode 100644 src/ai/Provider.test.ts create mode 100644 src/ai/tools/NormalizedTools.ts create mode 100644 src/ai/tools/aiToolTypes.ts create mode 100644 src/ai/tools/assignableVariable.ts create mode 100644 src/ai/tools/jsonSchemaValidator.test.ts create mode 100644 src/ai/tools/jsonSchemaValidator.ts create mode 100644 src/ai/tools/providerToolMapping.test.ts create mode 100644 src/ai/tools/providerToolMapping.ts create mode 100644 src/ai/tools/runToolLoop.test.ts create mode 100644 src/ai/tools/runToolLoop.ts create mode 100644 src/ai/tools/sanitizeVaultPath.test.ts create mode 100644 src/ai/tools/sanitizeVaultPath.ts create mode 100644 src/migrations/setProviderKind.ts diff --git a/src/ai/OpenAIRequest.test.ts b/src/ai/OpenAIRequest.test.ts index cf880113..0bf28c8d 100644 --- a/src/ai/OpenAIRequest.test.ts +++ b/src/ai/OpenAIRequest.test.ts @@ -367,7 +367,7 @@ describe("OpenAIRequest", () => { getModelProviderMock.mockReturnValue(anthropicProvider); }); - it("posts to /v1/messages with anthropic headers and a user-only message", async () => { + it("posts to /v1/messages with a top-level system prompt, model-aware max_tokens, and no stale 
beta header", async () => { requestUrlMock.mockResolvedValue({ json: { id: "msg-1", @@ -391,21 +391,89 @@ describe("OpenAIRequest", () => { const arg = requestUrlMock.mock.calls[0][0]; expect(arg.url).toBe("https://api.anthropic.com/v1/messages"); + // Stale anthropic-beta header dropped. expect(arg.headers).toEqual({ "Content-Type": "application/json", "x-api-key": "anthropic-key", "anthropic-version": "2023-06-01", - "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15", }); const body = JSON.parse(arg.body); expect(body).toEqual({ model: "claude-3-5-sonnet", - max_tokens: 4096, + // min(8192, 200000) — derived from the model's context window, not hardcoded 4096. + max_tokens: 8192, messages: [{ role: "user", content: "hello claude" }], + system: "system prompt", }); }); + it("omits the system key when no system prompt is given", async () => { + requestUrlMock.mockResolvedValue({ + json: { + id: "msg-1b", + model: "claude-3-5-sonnet", + role: "assistant", + stop_reason: "end_turn", + stop_sequence: null, + type: "message", + content: [{ text: "ok", type: "text" }], + usage: { input_tokens: 1, output_tokens: 1 }, + }, + }); + + const makeRequest = OpenAIRequest(makeApp(), "anthropic-key", anthropicModel, ""); + await makeRequest("hi"); + + const body = JSON.parse(requestUrlMock.mock.calls[0][0].body); + expect("system" in body).toBe(false); + expect(body.max_tokens).toBe(8192); + }); + + it("omits the system key for a whitespace-only system prompt", async () => { + requestUrlMock.mockResolvedValue({ + json: { + id: "msg-1c", + model: "claude-3-5-sonnet", + role: "assistant", + stop_reason: "end_turn", + stop_sequence: null, + type: "message", + content: [{ text: "ok", type: "text" }], + usage: { input_tokens: 1, output_tokens: 1 }, + }, + }); + + const makeRequest = OpenAIRequest(makeApp(), "anthropic-key", anthropicModel, " "); + await makeRequest("hi"); + + const body = JSON.parse(requestUrlMock.mock.calls[0][0].body); + expect("system" in body).toBe(false); + 
}); + + it("extracts text by scanning all content blocks (a leading tool_use block does not break it)", async () => { + requestUrlMock.mockResolvedValue({ + json: { + id: "msg-3", + model: "claude-3-5-sonnet", + role: "assistant", + stop_reason: "tool_use", + stop_sequence: null, + type: "message", + content: [ + { type: "tool_use", id: "toolu_1", name: "x", input: {} }, + { type: "text", text: "after the tool block" }, + ], + usage: { input_tokens: 3, output_tokens: 2 }, + }, + }); + + const makeRequest = OpenAIRequest(makeApp(), "anthropic-key", anthropicModel, "sys"); + const result = await makeRequest("q"); + + expect(result.content).toBe("after the tool block"); + }); + it("maps the Anthropic response, summing tokens and preserving stop_sequence", async () => { requestUrlMock.mockResolvedValue({ json: { diff --git a/src/ai/OpenAIRequest.ts b/src/ai/OpenAIRequest.ts index ba5946c8..b33aeb91 100644 --- a/src/ai/OpenAIRequest.ts +++ b/src/ai/OpenAIRequest.ts @@ -8,10 +8,21 @@ import { } from "./AIAssistant"; import { preventCursorChange } from "./preventCursorChange"; import type { AIProvider, Model } from "./Provider"; +import { getProviderKind } from "./Provider"; import { getModelProvider } from "./aiHelpers"; +import type { NormalizedChatRequest } from "./tools/NormalizedTools"; +import { + buildChatBody, + parseChatResponse, + type ProviderKind, +} from "./tools/providerToolMapping"; import { log } from "src/logger/logManager"; import { estimateTokenCount } from "./tokenEstimator"; import { buildProviderError, classifyProviderError } from "./providerErrors"; +import type { + NormalizedStopReason, + NormalizedToolCall, +} from "./tools/NormalizedTools"; export interface CommonResponse { id: string; @@ -22,9 +33,21 @@ export interface CommonResponse { completionTokens: number; totalTokens: number; }; + /** Raw provider stop/finish reason (kept for back-compat + debugging). 
*/ stopReason: string; stopSequence: string | null; created: number; + // --- Tool-calling additions (#714); all optional so existing consumers are unaffected. --- + /** Tool calls the model requested this turn, normalized across providers. */ + toolCalls?: NormalizedToolCall[]; + /** Provider stop reason mapped to a neutral enum, for the execute loop. */ + normalizedStopReason?: NormalizedStopReason; + /** + * Opaque provider-specific blocks that must be echoed back unchanged on the + * next turn (e.g. Gemini `thoughtSignature` parts). Carried on the assistant + * turn the loop reconstructs. + */ + providerRaw?: unknown; } // Shared request execution for all providers: call `requestUrl` with @@ -71,7 +94,12 @@ function mapAnthropicResponseToCommon( return { id: response.id, model: response.model, - content: response.content[0].text, + // Scan all blocks and join the text ones — reading content[0] breaks the + // moment a non-text block (e.g. a tool_use block) is first. + content: response.content + .filter((block) => block.type === "text") + .map((block) => block.text ?? "") + .join(""), usage: { promptTokens: response.usage.input_tokens, completionTokens: response.usage.output_tokens, @@ -101,8 +129,19 @@ type OpenAIReqResponse = { created: number; }; +// A content block can be text OR a non-text block (e.g. tool_use); only `type` is +// guaranteed. Modeled honestly so the tool-calling layer (PR2) can read tool_use +// blocks without the type lying about every block having a string `text`. +export interface AnthropicContentBlock { + type: string; + text?: string; + id?: string; + name?: string; + input?: Record; +} + export interface AnthropicResponse { - content: { text: string; type: string }[]; + content: AnthropicContentBlock[]; id: string; model: string; role: string; @@ -160,14 +199,37 @@ async function makeOpenAIRequest( ); } +// Conservative Anthropic output budget. 
The Messages API REQUIRES max_tokens, but +// model.maxTokens is the *context window*, not an output cap — so derive a sensible +// floor and clamp it to the window (a small/mis-registered model can't request more +// than it has). A dedicated per-model output field is a future refinement. +export function anthropicMaxTokens(model: Model): number { + const ceiling = 8192; + return Number.isFinite(model.maxTokens) && model.maxTokens > 0 + ? Math.min(ceiling, model.maxTokens) + : ceiling; +} + async function makeAnthropicRequest( apiKey: string, model: Model, modelProvider: AIProvider, + systemPrompt: string, modelParams: Partial, prompt: string, afterRequestCallback?: () => void ): Promise { + const body: Record = { + model: model.name, + max_tokens: anthropicMaxTokens(model), + messages: [{ role: "user", content: prompt }], + }; + // Send the system prompt at the top level (the Messages API has no system role + // inside `messages`). Previously dropped entirely — this is a real behaviour fix. + if (systemPrompt && systemPrompt.trim().length > 0) { + body.system = systemPrompt; + } + return dispatchProviderRequest( { url: `${modelProvider.endpoint}/v1/messages`, @@ -176,13 +238,8 @@ async function makeAnthropicRequest( "Content-Type": "application/json", "x-api-key": apiKey, "anthropic-version": "2023-06-01", - "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" }, - body: JSON.stringify({ - model: model.name, - max_tokens: 4096, - messages: [{ role: "user", content: prompt }], - }), + body: JSON.stringify(body), }, modelProvider.name, afterRequestCallback @@ -318,27 +375,34 @@ export function OpenAIRequest( const restoreCursor = preventCursorChange(app); let response: CommonResponse; - if (modelProvider.name === "Anthropic") { + // Route on the provider KIND (not the display name) so a custom-named + // Anthropic/Gemini provider — e.g. { name: "Claude Proxy", kind: "anthropic" } — + // gets the right wire shape here too, matching the ai.agent chat path. 
+ // getProviderKind falls back to name inference, so the built-in + // "Anthropic"/"Gemini" providers are unchanged. + const providerKind = getProviderKind(modelProvider); + if (providerKind === "anthropic") { const anthropicResponse = await makeAnthropicRequest( apiKey, model, modelProvider, + systemPrompt, modelParams, prompt, restoreCursor ); response = mapAnthropicResponseToCommon(anthropicResponse); - } else if (modelProvider.name === "Gemini") { - const geminiResponse = await makeGeminiRequest( - apiKey, - model, - modelProvider, - systemPrompt, - modelParams, - prompt, - restoreCursor - ); - response = mapGeminiResponseToCommon(geminiResponse); + } else if (providerKind === "gemini") { + const geminiResponse = await makeGeminiRequest( + apiKey, + model, + modelProvider, + systemPrompt, + modelParams, + prompt, + restoreCursor + ); + response = mapGeminiResponseToCommon(geminiResponse); } else { const openaiResponse = await makeOpenAIRequest( apiKey, @@ -395,3 +459,144 @@ export function OpenAIRequest( } }; } + +// --------------------------------------------------------------------------- +// Multi-turn chat entrypoint (#714) — used by the tool-calling Agent loop. +// Sibling to OpenAIRequest: the single-prompt path above stays byte-identical. +// Builds a provider body from a NormalizedChatRequest, dispatches, and parses +// tool calls. It does NOT arm preventCursorChange — the Agent captures the cursor +// ONCE per generate() and passes its restore fn here so intermediate turns don't +// re-arm it. 
+// --------------------------------------------------------------------------- +async function dispatchChat( + kind: ProviderKind, + apiKey: string, + modelProvider: AIProvider, + model: Model, + body: Record, + afterRequestCallback?: () => void, +): Promise> { + if (kind === "anthropic") { + return dispatchProviderRequest>( + { + url: `${modelProvider.endpoint}/v1/messages`, + method: "POST", + headers: { + "Content-Type": "application/json", + "x-api-key": apiKey, + "anthropic-version": "2023-06-01", + }, + body: JSON.stringify(body), + }, + modelProvider.name, + afterRequestCallback, + ); + } + if (kind === "gemini") { + const url = `${modelProvider.endpoint}/v1beta/models/${encodeURIComponent( + model.name, + )}:generateContent?key=${encodeURIComponent(apiKey)}`; + return dispatchProviderRequest>( + { url, method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify(body) }, + modelProvider.name, + afterRequestCallback, + ); + } + return dispatchProviderRequest>( + { + url: `${modelProvider.endpoint}/chat/completions`, + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + }, + body: JSON.stringify(body), + }, + modelProvider.name, + afterRequestCallback, + ); +} + +export async function chatRequest( + app: App, + apiKey: string, + model: Model, + request: NormalizedChatRequest, + afterRequestCallback?: () => void, +): Promise { + void app; // cursor handling is owned by the caller (Agent) for the whole loop + if (settingsStore.getState().disableOnlineFeatures) { + throw new Error( + "Blocking request: Online features are disabled in settings.", + ); + } + + const modelProvider = getModelProvider(model.name); + if (!modelProvider) { + throw new Error(`Model ${model.name} not found with any provider.`); + } + const kind = getProviderKind(modelProvider); + const body = buildChatBody(kind, model.name, request, anthropicMaxTokens(model)); + + // Compact log summary — never dump the whole 
transcript / tool data into the log. + const systemMsg = request.messages.find((m) => m.role === "system"); + const lastUser = [...request.messages] + .reverse() + .find((m) => m.role === "user"); + const requestStart = Date.now(); + const requestLogId = beginAIRequestLogEntry({ + provider: modelProvider.name, + endpoint: modelProvider.endpoint, + model: model.name, + systemPrompt: systemMsg && systemMsg.role === "system" ? systemMsg.content : "", + prompt: + lastUser && lastUser.role === "user" ? lastUser.content : "[tool-calling turn]", + modelOptions: request.modelParams ?? {}, + }); + + try { + const json = await dispatchChat( + kind, + apiKey, + modelProvider, + model, + body, + afterRequestCallback, + ); + const parsed = parseChatResponse(kind, json); + const durationMs = Date.now() - requestStart; + finishAIRequestLogEntry(requestLogId, { + status: "success", + durationMs, + usage: parsed.usage, + }); + log.logMessage(`[AI Chat ${requestLogId}] Success in ${durationMs}ms`); + + return { + id: (json.id as string) ?? `${Date.now()}`, + model: model.name, + content: parsed.content, + usage: parsed.usage, + stopReason: parsed.rawStopReason, + stopSequence: null, + created: Date.now(), + toolCalls: parsed.toolCalls, + normalizedStopReason: parsed.normalizedStopReason, + providerRaw: parsed.providerRaw, + }; + } catch (error) { + const errorMessage = + (error as { message?: string }).message ?? 
String(error); + const durationMs = Date.now() - requestStart; + finishAIRequestLogEntry(requestLogId, { + status: "error", + durationMs, + errorMessage, + }); + log.logError(error as Error); + throw new Error( + `Error while making request to ${modelProvider.name}: ${errorMessage}`, + { cause: error }, + ); + } +} diff --git a/src/ai/Provider.test.ts b/src/ai/Provider.test.ts new file mode 100644 index 00000000..89857f41 --- /dev/null +++ b/src/ai/Provider.test.ts @@ -0,0 +1,27 @@ +import { describe, it, expect } from "vitest"; +import { getProviderKind } from "./Provider"; + +describe("getProviderKind", () => { + it("prefers an explicit kind", () => { + expect(getProviderKind({ kind: "anthropic", name: "Whatever" })).toBe("anthropic"); + expect(getProviderKind({ kind: "openai", name: "Anthropic" })).toBe("openai"); + }); + + it("infers anthropic from name or endpoint", () => { + expect(getProviderKind({ name: "Anthropic" })).toBe("anthropic"); + expect(getProviderKind({ name: "My Claude", endpoint: "https://api.anthropic.com" })).toBe("anthropic"); + }); + + it("infers gemini from name or endpoint", () => { + expect(getProviderKind({ name: "Gemini" })).toBe("gemini"); + expect( + getProviderKind({ name: "Google", endpoint: "https://generativelanguage.googleapis.com" }), + ).toBe("gemini"); + }); + + it("defaults unknown/OpenAI-compatible providers to openai", () => { + expect(getProviderKind({ name: "Groq", endpoint: "https://api.groq.com/openai/v1" })).toBe("openai"); + expect(getProviderKind({ name: "OpenRouter" })).toBe("openai"); + expect(getProviderKind({})).toBe("openai"); + }); +}); diff --git a/src/ai/Provider.ts b/src/ai/Provider.ts index e9e2b1ba..dabd8f6d 100644 --- a/src/ai/Provider.ts +++ b/src/ai/Provider.ts @@ -1,5 +1,13 @@ export type ModelDiscoveryMode = "modelsDev" | "providerApi" | "auto"; +/** + * Which wire protocol a provider speaks. 
Used to select the request/response + * adapter for tool calling + structured output (#714) instead of matching the + * provider's display NAME — a custom Anthropic-compatible provider named anything + * other than "Anthropic" must still get the Anthropic wire shape. + */ +export type ProviderKind = "openai" | "anthropic" | "gemini"; + export interface AIProvider { name: string; endpoint: string; @@ -12,6 +20,34 @@ export interface AIProvider { autoSyncModels?: boolean; /** Controls how QuickAdd discovers browseable models for this provider. */ modelSource: ModelDiscoveryMode; + /** Wire protocol. Optional for back-compat; inferred when absent (see getProviderKind). */ + kind?: ProviderKind; +} + +/** + * Resolve a provider's wire kind. Prefers the explicit `kind` field; otherwise + * infers from the (legacy) name/endpoint so providers saved before the field + * existed still route correctly. Unknown → "openai" (the OpenAI-compatible default, + * matching today's fallback branch). + */ +export function getProviderKind(provider: { + kind?: ProviderKind; + name?: string; + endpoint?: string; +}): ProviderKind { + if (provider.kind) return provider.kind; + const name = (provider.name ?? "").toLowerCase(); + const endpoint = (provider.endpoint ?? 
"").toLowerCase(); + if (name === "anthropic" || endpoint.includes("api.anthropic.com")) { + return "anthropic"; + } + if ( + name === "gemini" || + endpoint.includes("generativelanguage.googleapis.com") + ) { + return "gemini"; + } + return "openai"; } export interface Model { @@ -22,6 +58,7 @@ export interface Model { const OpenAIProvider: AIProvider = { name: "OpenAI", endpoint: "https://api.openai.com/v1", + kind: "openai", apiKey: "", models: [ { @@ -68,6 +105,7 @@ const OpenAIProvider: AIProvider = { const GeminiProvider: AIProvider = { name: "Gemini", endpoint: "https://generativelanguage.googleapis.com", + kind: "gemini", apiKey: "", models: [ { diff --git a/src/ai/tools/NormalizedTools.ts b/src/ai/tools/NormalizedTools.ts new file mode 100644 index 00000000..0c0f5f22 --- /dev/null +++ b/src/ai/tools/NormalizedTools.ts @@ -0,0 +1,117 @@ +/** + * Provider-neutral tool-calling types for QuickAdd's AI layer (#714). + * + * INTERFACE-ONLY — this module imports only the pure OpenAIModelParameters type (no + * Obsidian, no settings), so the pure provider mappers and the execute loop that + * consume it stay unit-testable under vitest/jsdom. The public, AI-SDK-shaped surface + * (Agent / tool() / result) lives in `aiToolTypes.ts`; this file is the internal wire + * representation that the per-provider build/parse functions map to and from. + */ +import type { OpenAIModelParameters } from "../OpenAIModelParameters"; + +/** + * A minimal JSON Schema shape. QuickAdd has no Zod dependency (bundle-size + * sensitive), so tool input/output schemas are plain JSON Schema objects. The + * in-house validator (see `jsonSchemaValidator.ts`) accepts only a documented + * subset and rejects unsupported keywords at registration — this type stays + * permissive so authors can still express that subset naturally. 
+ */ +export interface JSONSchema { + type?: JSONSchemaType | JSONSchemaType[]; + properties?: Record; + items?: JSONSchema | JSONSchema[]; + required?: string[]; + enum?: unknown[]; + const?: unknown; + description?: string; + [keyword: string]: unknown; +} + +export type JSONSchemaType = + | "object" + | "array" + | "string" + | "number" + | "integer" + | "boolean" + | "null"; + +/** A tool as sent to the provider (no handler — that lives in the registry). */ +export interface NormalizedToolDefinition { + /** Must match ^[a-zA-Z0-9_-]{1,64}$ (Anthropic's strictest rule, enforced for all). */ + name: string; + description: string; + parameters: JSONSchema; + /** OpenAI/Anthropic strict mode; ignored by Gemini. */ + strict?: boolean; +} + +/** + * A tool call the model asked for, normalized across providers. + * `args` is ALWAYS a parsed object internally — except when the provider sent + * unparseable JSON (OpenAI streams arguments as a JSON string), in which case + * `args` is null, `rawArgs` holds the original string, and `parseError` is true so + * the loop emits an `isError` result for this call rather than running a handler. + */ +export interface NormalizedToolCall { + id: string; + name: string; + args: Record | null; + rawArgs?: string; + parseError?: boolean; +} + +/** The result of executing one tool call, normalized for the result turn. */ +export interface NormalizedToolResult { + toolCallId: string; + /** The tool name — REQUIRED because Gemini's functionResponse needs it. */ + name: string; + content: string; + isError?: boolean; +} + +/** + * One turn of the conversation. The assistant turn keeps `content` and + * `toolCalls` separate; the per-provider builder is responsible for emitting them + * in the order each provider requires (e.g. Anthropic needs text blocks before + * tool_use blocks). `providerRaw` carries opaque provider blocks that must be + * echoed back unchanged across turns (Gemini 3 `thoughtSignature` parts). 
+ */ +export type NormalizedMessage = + | { role: "system"; content: string } + | { role: "user"; content: string } + | { + role: "assistant"; + content: string; + toolCalls?: NormalizedToolCall[]; + providerRaw?: unknown; + } + | { role: "tool"; results: NormalizedToolResult[] }; + +export type NormalizedToolChoice = + | "auto" + | "none" + | "required" + | { name: string }; + +/** A JSON-schema-constrained output request (structured output). */ +export interface NormalizedResponseFormat { + schema?: JSONSchema; + name?: string; + strict?: boolean; +} + +/** A provider-neutral chat request. The mappers turn this into each provider's body. */ +export interface NormalizedChatRequest { + messages: NormalizedMessage[]; + tools?: NormalizedToolDefinition[]; + toolChoice?: NormalizedToolChoice; + /** Ask the provider to emit at most one tool call per turn (no-op on Gemini). */ + disableParallel?: boolean; + maxOutputTokens?: number; + responseFormat?: NormalizedResponseFormat; + modelParams?: Partial; +} + +/** Normalized stop/finish reason. `rawStopReason` on CommonResponse keeps the original. */ +export type NormalizedStopReason = "stop" | "tool_calls" | "length" | "other"; diff --git a/src/ai/tools/aiToolTypes.ts b/src/ai/tools/aiToolTypes.ts new file mode 100644 index 00000000..d55ce485 --- /dev/null +++ b/src/ai/tools/aiToolTypes.ts @@ -0,0 +1,126 @@ +/** + * Public, AI-SDK-shaped types for the QuickAdd Agent surface (#714). + * + * Pure (no Obsidian). These are what user scripts touch: `ai.tool(def)` builds a + * QATool, agents take a `ToolSet` (object map keyed by tool name), and + * `agent.generate()` resolves to a GenerateResult. Public field names mirror the + * Vercel AI SDK (`input`/`output`, `inputSchema`, `needsApproval`); the internal + * wire layer (NormalizedTools) translates at the boundary. 
+ */ +import type { JSONSchema } from "./NormalizedTools"; + +export interface ToolExecuteContext { + toolCallId: string; + toolName: string; +} + +/** What a script passes to `ai.tool()`. */ +export interface ToolDefinitionInput { + description: string; + inputSchema: JSONSchema; + execute: ( + input: Record, + ctx: ToolExecuteContext, + ) => unknown | Promise; + /** + * Per-tool human-in-the-loop gate (AI-SDK name). A tool resolving to true is + * ALWAYS confirmed regardless of the global confirmToolCalls setting. + */ + needsApproval?: + | boolean + | ((opts: { args: Record }) => boolean | Promise); + /** Read-only tools skip confirmation under the 'destructive' global setting. */ + readOnly?: boolean; + /** OpenAI/Anthropic strict tool-input validation. */ + strict?: boolean; +} + +/** A tool registered via `ai.tool()`. Branded so a ToolSet can't be a plain object. */ +export interface QATool extends ToolDefinitionInput { + readonly __qaTool: true; +} + +export type ToolSet = Record; + +export type ToolChoice = + | "auto" + | "none" + | "required" + | { type: "tool"; toolName: string }; + +/** + * A stop condition (AI-SDK parity). Built via `ai.stepCountIs(n)` / `ai.hasToolCall(name)`. + * Evaluated after each tool-execution step; returning true ends the loop on a final + * text turn. The hard `maxSteps` clamp always applies on top of these. + */ +export type StopCondition = (ctx: { + stepNumber: number; + toolCallNames: string[]; +}) => boolean; + +export interface PublicToolCall { + toolCallId: string; + toolName: string; + /** AI-SDK uses `input` (not `args`). Null when the model's JSON args failed to parse. 
*/ + input: Record | null; +} + +export interface PublicToolResult { + toolCallId: string; + toolName: string; + output: string; + isError: boolean; +} + +export interface GenerateStep { + stepNumber: number; + text: string; + toolCalls: PublicToolCall[]; + toolResults: PublicToolResult[]; + finishReason: string; +} + +export interface GenerateUsage { + inputTokens: number; + outputTokens: number; + totalTokens: number; +} + +export interface GenerateResult { + /** Final assistant text. */ + text: string; + /** Present only when a `schema` was passed (structured output): the parsed + validated object. */ + object?: unknown; + steps: GenerateStep[]; + /** Tool calls / results from the last step (AI-SDK semantics). */ + toolCalls: PublicToolCall[]; + toolResults: PublicToolResult[]; + usage: GenerateUsage; + finishReason: string; +} + +/** Config for `ai.agent(config)`. */ +export interface AgentConfig { + model: string | { name: string }; + system?: string; + tools?: ToolSet; + toolChoice?: ToolChoice; + stopWhen?: StopCondition | StopCondition[]; + /** Sugar for `stopWhen: ai.stepCountIs(N)`. If both given, `stopWhen` also applies. */ + maxSteps?: number; + maxOutputTokens?: number; + modelOptions?: Record; +} + +/** Per-call options for `agent.generate(options)`. */ +export interface GenerateOptions { + prompt?: string; + /** Present ⇒ structured output; result.object is populated. */ + schema?: JSONSchema; + /** Per-call overrides of agent config. */ + system?: string; + toolChoice?: ToolChoice; + maxOutputTokens?: number; + /** Write result.text into choiceExecutor.variables as + -quoted for {{VALUE:name}}. */ + assignToVariable?: string; +} diff --git a/src/ai/tools/assignableVariable.ts b/src/ai/tools/assignableVariable.ts new file mode 100644 index 00000000..2514f260 --- /dev/null +++ b/src/ai/tools/assignableVariable.ts @@ -0,0 +1,26 @@ +/** + * Guard for the `assignToVariable` option shared by ai.agent and ai.prompt/chunkedPrompt (#714). 
+ * Rejects names that the formatter treats specially (so they cannot be hijacked) or + * that are unreachable by the {{VALUE:name}} grammar. + */ +export function assertAssignableVariableName(name: string): void { + const trimmed = name.trim(); + if (!trimmed) throw new Error("assignToVariable cannot be empty."); + // `value` (the programmatic {{VALUE}} slot) and `title` (file-name title) are + // formatter-reserved; assigning them would silently hijack those tokens. + if (new Set(["value", "title", "text", "meta"]).has(trimmed)) { + throw new Error( + `assignToVariable "${trimmed}" is reserved (it would hijack a built-in formatter token).`, + ); + } + if (trimmed.endsWith("-quoted")) { + throw new Error( + `assignToVariable "${trimmed}" cannot end with "-quoted" (collides with the quoted output variable).`, + ); + } + if (/[|,]/.test(trimmed) || trimmed.startsWith("__qa.")) { + throw new Error( + `assignToVariable "${trimmed}" contains characters unreachable by {{VALUE:name}}.`, + ); + } +} diff --git a/src/ai/tools/jsonSchemaValidator.test.ts b/src/ai/tools/jsonSchemaValidator.test.ts new file mode 100644 index 00000000..0b1dbcb8 --- /dev/null +++ b/src/ai/tools/jsonSchemaValidator.test.ts @@ -0,0 +1,101 @@ +import { describe, it, expect } from "vitest"; +import { + assertRegisterableSchema, + validateValue, + ToolSchemaError, +} from "./jsonSchemaValidator"; +import type { JSONSchema } from "./NormalizedTools"; + +describe("assertRegisterableSchema", () => { + it("accepts the supported subset", () => { + expect(() => + assertRegisterableSchema({ + type: "object", + properties: { + name: { type: "string", description: "a name" }, + count: { type: "integer" }, + tags: { type: "array", items: { type: "string" } }, + mode: { type: "string", enum: ["a", "b"] }, + }, + required: ["name"], + }), + ).not.toThrow(); + }); + + it("rejects unsupported keywords that a provider would silently drop", () => { + const bad: JSONSchema[] = [ + { type: "string", pattern: "^x" }, 
// Registration-time gate: only the supported keyword subset may be used.
describe("assertRegisterableSchema", () => {
	it("accepts the supported subset", () => {
		const supported: JSONSchema = {
			type: "object",
			properties: {
				name: { type: "string", description: "a name" },
				count: { type: "integer" },
				tags: { type: "array", items: { type: "string" } },
				mode: { type: "string", enum: ["a", "b"] },
			},
			required: ["name"],
		};
		expect(() => assertRegisterableSchema(supported)).not.toThrow();
	});

	it("rejects unsupported keywords that a provider would silently drop", () => {
		const unsupported: JSONSchema[] = [
			{ type: "string", pattern: "^x" },
			{ type: "string", minLength: 2 },
			{ type: "object", additionalProperties: false },
			{ $ref: "#/defs/x" },
			{ allOf: [{ type: "object" }] },
			{ anyOf: [{ type: "string" }] },
			{ type: "string", format: "email" },
			{ type: "number", minimum: 0 },
		];
		unsupported.forEach((candidate) => {
			expect(() => assertRegisterableSchema(candidate)).toThrow(ToolSchemaError);
		});
	});

	it("recurses into properties and items", () => {
		const nestedInProperties: JSONSchema = {
			type: "object",
			properties: {
				nested: {
					type: "object",
					properties: { x: { type: "string", pattern: "y" } },
				},
			},
		};
		const nestedInItems: JSONSchema = {
			type: "array",
			items: { type: "string", format: "uri" },
		};
		expect(() => assertRegisterableSchema(nestedInProperties)).toThrow(ToolSchemaError);
		expect(() => assertRegisterableSchema(nestedInItems)).toThrow(ToolSchemaError);
	});

	it("rejects invalid type names and malformed required", () => {
		expect(() => assertRegisterableSchema({ type: "stringy" as never })).toThrow(
			ToolSchemaError,
		);
		expect(() =>
			assertRegisterableSchema({ type: "object", required: "name" as never }),
		).toThrow(ToolSchemaError);
	});
});

// Runtime validation of model-produced values against the same subset.
describe("validateValue", () => {
	const noteSchema = {
		type: "object" as const,
		properties: {
			path: { type: "string" as const },
			count: { type: "integer" as const },
			mode: { type: "string" as const, enum: ["top", "bottom"] },
			tags: { type: "array" as const, items: { type: "string" as const } },
		},
		required: ["path"],
	};

	it("passes a valid value", () => {
		const valid = { path: "a.md", count: 3, mode: "top", tags: ["x", "y"] };
		expect(validateValue(valid, noteSchema)).toBeNull();
	});

	it("flags a missing required property", () => {
		expect(validateValue({ count: 1 }, noteSchema)).toMatch(/required/);
	});

	it("flags a wrong type", () => {
		expect(validateValue({ path: 42 }, noteSchema)).toMatch(/expected string/);
		expect(validateValue({ path: "a", count: 1.5 }, noteSchema)).toMatch(/integer/);
	});

	it("flags an out-of-enum value", () => {
		expect(validateValue({ path: "a", mode: "middle" }, noteSchema)).toMatch(/enum/);
	});

	it("validates array item types", () => {
		expect(validateValue({ path: "a", tags: ["ok", 3] }, noteSchema)).toMatch(
			/expected string/,
		);
	});

	it("accepts an integer where number is expected", () => {
		expect(validateValue(3, { type: "number" })).toBeNull();
	});
});
const SUPPORTED_KEYWORDS = new Set([
	"type",
	"properties",
	"required",
	"items",
	"enum",
	"const",
	"description",
	"title",
]);

/** Every name the `type` keyword is allowed to carry. */
const VALID_TYPES: ReadonlySet<string> = new Set([
	"object",
	"array",
	"string",
	"number",
	"integer",
	"boolean",
	"null",
]);

/** Error thrown by `assertRegisterableSchema` when a schema is outside the supported subset. */
export class ToolSchemaError extends Error {
	constructor(message: string) {
		super(message);
		this.name = "ToolSchemaError";
	}
}

/**
 * Throws ToolSchemaError if `schema` uses a keyword outside the supported subset,
 * names an invalid `type`, or has a malformed `properties` / `required` / `enum`.
 * Recurses into `properties` and `items` (single-schema and tuple form).
 *
 * @param schema - Schema to check.
 * @param where - Human path used in error messages (e.g. "inputSchema").
 * @throws ToolSchemaError on the first violation found.
 */
export function assertRegisterableSchema(
	schema: JSONSchema,
	where = "schema",
): void {
	if (schema === null || typeof schema !== "object" || Array.isArray(schema)) {
		throw new ToolSchemaError(`${where} must be a JSON Schema object.`);
	}

	for (const keyword of Object.keys(schema)) {
		if (!SUPPORTED_KEYWORDS.has(keyword)) {
			throw new ToolSchemaError(
				`${where} uses unsupported JSON Schema keyword "${keyword}". QuickAdd tool schemas support only: ${[
					...SUPPORTED_KEYWORDS,
				].join(", ")}. (Constraints like pattern/minLength/additionalProperties/$ref/allOf/anyOf/format are rejected because at least one provider would silently drop them.)`,
			);
		}
	}

	if (schema.type !== undefined) {
		const types = Array.isArray(schema.type) ? schema.type : [schema.type];
		for (const t of types) {
			if (!VALID_TYPES.has(t as string)) {
				throw new ToolSchemaError(`${where} has invalid type "${String(t)}".`);
			}
		}
	}

	if (schema.properties !== undefined) {
		if (
			schema.properties === null ||
			typeof schema.properties !== "object" ||
			Array.isArray(schema.properties)
		) {
			throw new ToolSchemaError(`${where}.properties must be an object.`);
		}
		for (const [key, sub] of Object.entries(schema.properties)) {
			assertRegisterableSchema(sub as JSONSchema, `${where}.properties.${key}`);
		}
	}

	if (schema.required !== undefined) {
		if (
			!Array.isArray(schema.required) ||
			schema.required.some((r) => typeof r !== "string")
		) {
			throw new ToolSchemaError(`${where}.required must be an array of strings.`);
		}
	}

	if (schema.items !== undefined) {
		if (Array.isArray(schema.items)) {
			schema.items.forEach((sub, i) =>
				assertRegisterableSchema(sub, `${where}.items[${i}]`),
			);
		} else {
			assertRegisterableSchema(schema.items, `${where}.items`);
		}
	}

	if (schema.enum !== undefined && !Array.isArray(schema.enum)) {
		throw new ToolSchemaError(`${where}.enum must be an array.`);
	}
}

/** Own-property test; unlike `in`, it ignores keys inherited from the prototype chain. */
function hasOwn(obj: object, key: string): boolean {
	return Object.prototype.hasOwnProperty.call(obj, key);
}

/** True for a non-null, non-array object. */
function isPlainObject(value: unknown): value is Record<string, unknown> {
	return typeof value === "object" && value !== null && !Array.isArray(value);
}

/**
 * Classify `value` as a JSON Schema type name. Whole numbers report "integer".
 * Returns null for values that are not JSON at all (undefined, functions,
 * bigint, symbols) so they can never be mistaken for an object.
 */
function typeOfValue(value: unknown): JSONSchemaType | null {
	if (value === null) return "null";
	if (Array.isArray(value)) return "array";
	switch (typeof value) {
		case "number":
			return Number.isInteger(value) ? "integer" : "number";
		case "boolean":
			return "boolean";
		case "string":
			return "string";
		case "object":
			return "object";
		default:
			return null;
	}
}

/** Does `value` satisfy the single type name `expected`? */
function matchesType(value: unknown, expected: JSONSchemaType): boolean {
	const actual = typeOfValue(value);
	if (actual === null) return false;
	// JSON has one number type; an integer value satisfies "number". The reverse
	// does not hold: a non-integer number reports "number" and fails "integer".
	return actual === expected || (expected === "number" && actual === "integer");
}

/**
 * Validate `value` against the subset schema.
 *
 * Checks, in order: `type`, `enum`, `const`, then (for objects) `required` and
 * nested `properties`, then (for arrays) `items`. Tuple-form `items` is applied
 * positionally; array elements beyond the tuple length are not constrained.
 * Property presence is decided by OWN properties only, so names such as
 * `constructor` or `toString` behave like any other key.
 *
 * @param value - Value to validate (typically parsed model output).
 * @param schema - Subset schema to validate against.
 * @param path - Path prefix used in error messages.
 * @returns The first error message, or null when valid. Never throws for a
 *   non-JSON `value` such as `undefined`.
 */
export function validateValue(
	value: unknown,
	schema: JSONSchema,
	path = "$",
): string | null {
	// type
	if (schema.type !== undefined) {
		const types = Array.isArray(schema.type) ? schema.type : [schema.type];
		if (!types.some((t) => matchesType(value, t as JSONSchemaType))) {
			return `${path}: expected ${types.join(" | ")}, got ${typeOfValue(value) ?? typeof value}`;
		}
	}

	// enum
	if (Array.isArray(schema.enum)) {
		const ok = schema.enum.some((e) => deepEqual(e, value));
		if (!ok) return `${path}: value is not one of the allowed enum values`;
	}

	// const
	if ("const" in schema && !deepEqual(schema.const, value)) {
		return `${path}: value does not equal the required const`;
	}

	// object: required + nested properties
	if (isPlainObject(value)) {
		if (Array.isArray(schema.required)) {
			for (const key of schema.required) {
				if (!hasOwn(value, key)) {
					return `${path}.${key}: required property is missing`;
				}
			}
		}
		if (schema.properties) {
			for (const [key, sub] of Object.entries(schema.properties)) {
				if (hasOwn(value, key)) {
					const err = validateValue(value[key], sub as JSONSchema, `${path}.${key}`);
					if (err) return err;
				}
			}
		}
	}

	// array: items (single schema for every element, or positional tuple)
	if (Array.isArray(value) && schema.items) {
		if (Array.isArray(schema.items)) {
			for (const [i, sub] of schema.items.entries()) {
				if (i >= value.length) break;
				const err = validateValue(value[i], sub, `${path}[${i}]`);
				if (err) return err;
			}
		} else {
			for (let i = 0; i < value.length; i++) {
				const err = validateValue(value[i], schema.items, `${path}[${i}]`);
				if (err) return err;
			}
		}
	}

	return null;
}

/**
 * Structural equality for JSON-like values (used by `enum` and `const`).
 * An array never equals a plain object, and two objects are equal only when
 * they have the same own keys with deeply equal values.
 */
function deepEqual(a: unknown, b: unknown): boolean {
	if (a === b) return true;
	if (typeof a !== "object" || typeof b !== "object") return false;
	if (a === null || b === null) return false;

	if (Array.isArray(a)) {
		return (
			Array.isArray(b) &&
			a.length === b.length &&
			a.every((x, i) => deepEqual(x, b[i]))
		);
	}
	if (Array.isArray(b)) return false;

	const left = a as Record<string, unknown>;
	const right = b as Record<string, unknown>;
	const leftKeys = Object.keys(left);
	if (leftKeys.length !== Object.keys(right).length) return false;
	return leftKeys.every((k) => hasOwn(right, k) && deepEqual(left[k], right[k]));
}
// Shorthand for the loosely-typed views the assertions take of a built body.
type JsonObject = Record<string, unknown>;

describe("OpenAI mapping", () => {
	it("builds function tools + tool_choice + parallel flag", () => {
		const req: NormalizedChatRequest = {
			messages: [{ role: "user", content: "make a note" }],
			tools: [tool],
			toolChoice: { name: "create_note" },
			disableParallel: true,
		};
		const body = buildChatBody("openai", "gpt-4o", req) as JsonObject;
		expect(body.tools).toEqual([
			{ type: "function", function: { name: "create_note", description: "Create a note", parameters: tool.parameters } },
		]);
		expect(body.tool_choice).toEqual({ type: "function", function: { name: "create_note" } });
		expect(body.parallel_tool_calls).toBe(false);
	});

	it("threads an assistant tool-call turn (args as JSON string, rawArgs preserved) and flat tool results", () => {
		const req: NormalizedChatRequest = {
			messages: [
				{ role: "user", content: "q" },
				{ role: "assistant", content: "", toolCalls: [{ id: "call_1", name: "create_note", args: { path: "a.md" }, rawArgs: '{"path":"a.md"}' }] },
				{ role: "tool", results: [{ toolCallId: "call_1", name: "create_note", content: "ok" }] },
			],
			tools: [tool],
		};
		const body = buildChatBody("openai", "gpt-4o", req) as JsonObject;
		const msgs = body.messages as Array<JsonObject>;
		const assistant = msgs[1];
		expect((assistant.tool_calls as Array<JsonObject>)[0]).toMatchObject({
			id: "call_1",
			type: "function",
			function: { name: "create_note", arguments: '{"path":"a.md"}' },
		});
		expect(msgs[2]).toEqual({ role: "tool", tool_call_id: "call_1", content: "ok" });
	});

	it("parses a JSON-STRING arguments tool call", () => {
		const res = parseChatResponse("openai", {
			choices: [{ finish_reason: "tool_calls", message: { content: null, tool_calls: [{ id: "c1", function: { name: "create_note", arguments: '{"path":"x.md"}' } }] } }],
			usage: { prompt_tokens: 1, completion_tokens: 2, total_tokens: 3 },
		});
		expect(res.normalizedStopReason).toBe("tool_calls");
		expect(res.toolCalls[0]).toMatchObject({ id: "c1", name: "create_note", args: { path: "x.md" }, rawArgs: '{"path":"x.md"}' });
	});

	it("flags malformed JSON arguments with parseError (no throw)", () => {
		const res = parseChatResponse("openai", {
			choices: [{ finish_reason: "tool_calls", message: { tool_calls: [{ id: "c1", function: { name: "create_note", arguments: "{not json" } }] } }],
		});
		expect(res.toolCalls[0]).toMatchObject({ id: "c1", args: null, parseError: true });
	});

	it("defensively accepts object arguments (Ollama-native)", () => {
		const res = parseChatResponse("openai", {
			choices: [{ finish_reason: "tool_calls", message: { tool_calls: [{ id: "c1", function: { name: "create_note", arguments: { path: "x.md" } } }] } }],
		});
		expect(res.toolCalls[0].args).toEqual({ path: "x.md" });
		expect(res.toolCalls[0].parseError).toBeUndefined();
	});

	it("uses max_completion_tokens for reasoning models (gpt-5+/o-series), max_tokens otherwise", () => {
		const req: NormalizedChatRequest = {
			messages: [{ role: "user", content: "q" }],
			maxOutputTokens: 1234,
		};
		const pick = (model: string) => {
			const b = buildChatBody("openai", model, req) as JsonObject;
			return { mc: b.max_completion_tokens, mt: b.max_tokens };
		};
		// Reasoning line → max_completion_tokens
		expect(pick("gpt-5.5")).toEqual({ mc: 1234, mt: undefined });
		expect(pick("gpt-5-mini")).toEqual({ mc: 1234, mt: undefined });
		expect(pick("o3")).toEqual({ mc: 1234, mt: undefined });
		// Classic OpenAI + OpenAI-compatible (Ollama/Groq names) → max_tokens
		expect(pick("gpt-4o")).toEqual({ mc: undefined, mt: 1234 });
		expect(pick("llama3.1")).toEqual({ mc: undefined, mt: 1234 });
	});

	it("builds a strict json_schema response_format with injected additionalProperties + required", () => {
		const req: NormalizedChatRequest = {
			messages: [{ role: "user", content: "extract" }],
			responseFormat: { schema: { type: "object", properties: { title: { type: "string" } } }, name: "doc" },
		};
		const body = buildChatBody("openai", "gpt-4o", req) as JsonObject;
		const rf = body.response_format as JsonObject;
		expect(rf.type).toBe("json_schema");
		const js = rf.json_schema as JsonObject;
		expect(js).toMatchObject({ name: "doc", strict: true });
		expect(js.schema).toMatchObject({ additionalProperties: false, required: ["title"] });
	});
});

describe("Anthropic mapping", () => {
	it("hoists system to top level and orders assistant text before tool_use; tool_result leads the next user turn", () => {
		const req: NormalizedChatRequest = {
			messages: [
				{ role: "system", content: "be helpful" },
				{ role: "user", content: "q" },
				{ role: "assistant", content: "let me check", toolCalls: [{ id: "tu1", name: "create_note", args: { path: "a.md" } }] },
				{ role: "tool", results: [{ toolCallId: "tu1", name: "create_note", content: "done", isError: false }] },
			],
			tools: [tool],
			toolChoice: "required",
		};
		const body = buildChatBody("anthropic", "claude-3-5-sonnet", req) as JsonObject;
		expect(body.system).toBe("be helpful");
		const msgs = body.messages as Array<JsonObject>;
		// system removed from messages
		expect(msgs.map((m) => m.role)).toEqual(["user", "assistant", "user"]);
		const assistantBlocks = msgs[1].content as Array<JsonObject>;
		expect(assistantBlocks[0]).toEqual({ type: "text", text: "let me check" });
		expect(assistantBlocks[1]).toMatchObject({ type: "tool_use", id: "tu1", name: "create_note" });
		const resultBlocks = msgs[2].content as Array<JsonObject>;
		expect(resultBlocks[0]).toMatchObject({ type: "tool_result", tool_use_id: "tu1", content: "done" });
		expect(body.tool_choice).toMatchObject({ type: "any" });
	});

	it("maps native output_config for structured output, injecting additionalProperties:false + required (Anthropic requires it)", () => {
		const req: NormalizedChatRequest = {
			messages: [{ role: "user", content: "x" }],
			responseFormat: { schema: { type: "object", properties: { a: { type: "string" } } } },
		};
		const body = buildChatBody("anthropic", "claude-sonnet-4-6", req) as JsonObject;
		expect(body.output_config).toEqual({
			format: {
				type: "json_schema",
				schema: { type: "object", additionalProperties: false, required: ["a"], properties: { a: { type: "string" } } },
			},
		});
	});

	it("maps tool_choice variants and disable_parallel_tool_use", () => {
		const base: NormalizedChatRequest = { messages: [{ role: "user", content: "q" }], tools: [tool] };
		const pick = (r: NormalizedChatRequest) => (buildChatBody("anthropic", "m", r) as JsonObject).tool_choice;
		expect(pick({ ...base, toolChoice: { name: "create_note" } })).toEqual({ type: "tool", name: "create_note" });
		expect(pick({ ...base, disableParallel: true })).toEqual({ type: "auto", disable_parallel_tool_use: true });
		expect(pick({ ...base, toolChoice: "required", disableParallel: true })).toEqual({ type: "any", disable_parallel_tool_use: true });
	});

	it("never sends a tool-level strict field (no such field on the Anthropic API)", () => {
		const body = buildChatBody("anthropic", "m", { messages: [{ role: "user", content: "q" }], tools: [{ ...tool, strict: true }] }) as JsonObject;
		expect((body.tools as Array<JsonObject>)[0]).toEqual({ name: "create_note", description: "Create a note", input_schema: tool.parameters });
	});

	it("parses tool_use blocks and joins text across all blocks", () => {
		const res = parseChatResponse("anthropic", {
			stop_reason: "tool_use",
			content: [
				{ type: "text", text: "thinking " },
				{ type: "tool_use", id: "tu1", name: "create_note", input: { path: "a.md" } },
			],
			usage: { input_tokens: 5, output_tokens: 7 },
		});
		expect(res.content).toBe("thinking ");
		expect(res.normalizedStopReason).toBe("tool_calls");
		expect(res.toolCalls[0]).toMatchObject({ id: "tu1", name: "create_note", args: { path: "a.md" } });
		expect(res.usage).toEqual({ promptTokens: 5, completionTokens: 7, totalTokens: 12 });
	});
});

describe("Gemini mapping", () => {
	it("uses role 'model' + functionResponse user parts, no 'tool' role; sets functionDeclarations + responseSchema", () => {
		const req: NormalizedChatRequest = {
			messages: [
				{ role: "system", content: "sys" },
				{ role: "user", content: "q" },
				{ role: "assistant", content: "", toolCalls: [{ id: "g0", name: "create_note", args: { path: "a.md" } }] },
				{ role: "tool", results: [{ toolCallId: "g0", name: "create_note", content: "ok" }] },
			],
			tools: [tool],
			toolChoice: "auto",
			responseFormat: { schema: { type: "object", properties: { a: { type: "string" } } } },
		};
		const body = buildChatBody("gemini", "gemini-1.5-pro", req) as JsonObject;
		expect(body.systemInstruction).toMatchObject({ role: "system" });
		const contents = body.contents as Array<JsonObject>;
		expect(contents.map((c) => c.role)).toEqual(["user", "model", "user"]);
		const modelParts = contents[1].parts as Array<JsonObject>;
		expect(modelParts.some((p) => p.functionCall)).toBe(true);
		const fnRespParts = contents[2].parts as Array<JsonObject>;
		expect(fnRespParts[0]).toEqual({ functionResponse: { name: "create_note", response: { result: "ok" } } });
		expect(body.tools).toEqual([{ functionDeclarations: [{ name: "create_note", description: "Create a note", parameters: tool.parameters }] }]);
		const gc = body.generationConfig as JsonObject;
		expect(gc.responseMimeType).toBe("application/json");
		expect(gc.responseSchema).toBeDefined();
	});

	it("echoes the original parts verbatim when providerRaw is present (preserves thoughtSignature)", () => {
		const raw = [{ functionCall: { name: "create_note", args: { path: "a.md" } }, thoughtSignature: "sig123" }];
		const req: NormalizedChatRequest = {
			messages: [
				{ role: "user", content: "q" },
				{ role: "assistant", content: "", toolCalls: [{ id: "g0", name: "create_note", args: { path: "a.md" } }], providerRaw: raw },
			],
			tools: [tool],
		};
		const body = buildChatBody("gemini", "gemini-1.5-pro", req) as JsonObject;
		const contents = body.contents as Array<JsonObject>;
		// Identity, not equality: the very same array must be passed through.
		expect(contents[1].parts).toBe(raw);
	});

	it("maps toolConfig mode for required and a forced {name}", () => {
		const base: NormalizedChatRequest = { messages: [{ role: "user", content: "q" }], tools: [tool] };
		const cfg = (r: NormalizedChatRequest) => (buildChatBody("gemini", "m", r) as JsonObject).toolConfig;
		expect(cfg({ ...base, toolChoice: "required" })).toEqual({ functionCallingConfig: { mode: "ANY" } });
		expect(cfg({ ...base, toolChoice: { name: "create_note" } })).toEqual({ functionCallingConfig: { mode: "ANY", allowedFunctionNames: ["create_note"] } });
		expect(cfg({ ...base, toolChoice: "none" })).toEqual({ functionCallingConfig: { mode: "NONE" } });
	});

	it("parses functionCall parts (synthesizing ids) and keeps raw parts for echo", () => {
		const res = parseChatResponse("gemini", {
			candidates: [{ content: { role: "model", parts: [{ text: "ok " }, { functionCall: { name: "create_note", args: { path: "a.md" } } }] }, finishReason: "STOP" }],
			usageMetadata: { promptTokenCount: 3, candidatesTokenCount: 4, totalTokenCount: 7 },
		});
		expect(res.content).toBe("ok ");
		expect(res.normalizedStopReason).toBe("tool_calls");
		expect(res.toolCalls[0]).toMatchObject({ id: "gemini-call-0", name: "create_note", args: { path: "a.md" } });
		expect(Array.isArray(res.providerRaw)).toBe(true);
	});
});

describe("injectStrictObjectSchema", () => {
	it("adds additionalProperties:false + required at every object level", () => {
		const out = injectStrictObjectSchema({
			type: "object",
			properties: { outer: { type: "object", properties: { inner: { type: "string" } } } },
		});
		expect(out).toMatchObject({ additionalProperties: false, required: ["outer"] });
		expect(out.properties?.outer as JsonObject).toMatchObject({
			additionalProperties: false,
			required: ["inner"],
		});
	});
});
/**
 * Pure per-provider build/parse for tool calling + structured output (#714).
 *
 * No Obsidian imports — fully unit-testable. The execute loop (runToolLoop) and the
 * chatRequest entrypoint use `buildChatBody` to turn a NormalizedChatRequest into a
 * provider request body, and `parseChatResponse` to turn the raw provider JSON into
 * a normalized result (content + toolCalls + stop reason + usage).
 *
 * The three providers diverge in three load-bearing ways (verified against the
 * fc-prototype + current docs): (1) where tool calls live, (2) how arguments are
 * encoded (OpenAI = JSON STRING → parse; Anthropic/Gemini = object), (3) how results
 * thread back (OpenAI flat tool message; Anthropic tool_result-first user turn;
 * Gemini functionResponse parts under a user turn — no "tool" role).
 */
import type {
	JSONSchema,
	NormalizedChatRequest,
	NormalizedMessage,
	NormalizedStopReason,
	NormalizedToolCall,
	NormalizedToolChoice,
	NormalizedToolDefinition,
} from "./NormalizedTools";

/** Which of the three wire shapes a provider speaks. */
export type ProviderKind = "openai" | "anthropic" | "gemini";

/** Provider-neutral view of one chat response. */
export interface ParsedChatResult {
	/** Concatenated text content of the response ("" when there is none). */
	content: string;
	/** Tool calls requested by the model, in response order. */
	toolCalls: NormalizedToolCall[];
	/** Stop reason mapped onto the shared vocabulary. */
	normalizedStopReason: NormalizedStopReason;
	/** Stop reason exactly as the provider reported it (stringified). */
	rawStopReason: string;
	usage: { promptTokens: number; completionTokens: number; totalTokens: number };
	/** Opaque provider blocks to echo back next turn (Gemini thoughtSignature parts). */
	providerRaw?: unknown;
}

/** A JSON request/response object whose shape is provider-specific. */
type Body = Record<string, unknown>;

// ---------------------------------------------------------------------------
// Strict-output schema injection (OpenAI structured output)
// ---------------------------------------------------------------------------
// OpenAI's `strict: true` requires additionalProperties:false AND every property
// listed in `required`, at every object level. Output schemas are a SEPARATE path
// from author tool-INPUT schemas (which the validator deliberately keeps free of
// additionalProperties), so we inject these here at send time.
/**
 * Return a copy of `schema` in which every object-typed level has
 * `additionalProperties: false` and lists all of its properties in `required`.
 * Recurses through `properties` and `items` (single-schema and tuple form).
 * The input is not mutated.
 *
 * @param schema - Author-supplied output schema.
 * @returns The strictness-injected copy.
 */
export function injectStrictObjectSchema(schema: JSONSchema): JSONSchema {
	// Defensive passthrough: an array is not a schema object, hand it back untouched.
	if (Array.isArray(schema)) return schema;

	const out: JSONSchema = { ...schema };

	const isObject =
		out.type === "object" ||
		(Array.isArray(out.type) && out.type.includes("object")) ||
		out.properties !== undefined;

	if (out.properties) {
		const props: Record<string, JSONSchema> = {};
		for (const [k, v] of Object.entries(out.properties)) {
			props[k] = injectStrictObjectSchema(v);
		}
		out.properties = props;
	}
	if (out.items) {
		out.items = Array.isArray(out.items)
			? out.items.map((sub) => injectStrictObjectSchema(sub))
			: injectStrictObjectSchema(out.items);
	}
	if (isObject) {
		out.additionalProperties = false;
		out.required = out.properties ? Object.keys(out.properties) : [];
	}
	return out;
}

// ===========================================================================
// OpenAI-compatible
// ===========================================================================

/** Map normalized messages onto OpenAI chat messages (one flat `tool` message per result). */
function openaiMessages(messages: NormalizedMessage[]): Body[] {
	const out: Body[] = [];
	for (const m of messages) {
		if (m.role === "system") out.push({ role: "system", content: m.content });
		else if (m.role === "user") out.push({ role: "user", content: m.content });
		else if (m.role === "assistant") {
			// An empty assistant text is sent as null.
			const msg: Body = { role: "assistant", content: m.content || null };
			if (m.toolCalls && m.toolCalls.length > 0) {
				msg.tool_calls = m.toolCalls.map((c) => ({
					id: c.id,
					type: "function",
					function: {
						name: c.name,
						// Re-echo the exact original string when we have it (loss-free); else serialize.
						arguments: c.rawArgs ?? JSON.stringify(c.args ?? {}),
					},
				}));
			}
			out.push(msg);
		} else {
			// tool results: one flat {role:'tool'} message per result.
			for (const r of m.results) {
				out.push({
					role: "tool",
					tool_call_id: r.toolCallId,
					content: r.isError ? `ERROR: ${r.content}` : r.content,
				});
			}
		}
	}
	return out;
}

/** Map tool definitions onto OpenAI `function` tools; `strict` tools get an injected schema. */
function openaiTools(tools: NormalizedToolDefinition[]): Body[] {
	return tools.map((t) => ({
		type: "function",
		function: {
			name: t.name,
			description: t.description,
			parameters: t.strict
				? injectStrictObjectSchema(t.parameters)
				: t.parameters,
			...(t.strict ? { strict: true } : {}),
		},
	}));
}

/** String choices pass through; a named choice becomes a forced function. */
function openaiToolChoice(choice: NormalizedToolChoice): unknown {
	if (typeof choice === "string") return choice; // 'auto' | 'none' | 'required'
	return { type: "function", function: { name: choice.name } };
}

// OpenAI reasoning models (o-series, and gpt-5 / newer) reject the classic
// `max_tokens` — they require `max_completion_tokens`. Detect by NAME so the rename
// applies only to OpenAI-proper's reasoning line: OpenAI-compatible endpoints
// (Ollama / Groq / Together / …), whose model names don't match this shape, keep
// `max_tokens`. Forward-looking: matches gpt-5..gpt-9 and gpt-10+ as well as o1..o9.
const OPENAI_REASONING_MODEL_RE = /^(?:o[1-9]|gpt-(?:[5-9]|\d\d))/i;

/**
 * Build an OpenAI-compatible chat body. Optional keys (tools, tool_choice,
 * parallel_tool_calls, token cap, response_format) are added only when the
 * request asks for them, so a plain request carries just `model` + `messages`
 * (plus whatever `modelParams` contributes).
 */
function buildOpenAIBody(modelName: string, req: NormalizedChatRequest): Body {
	const body: Body = {
		model: modelName,
		...(req.modelParams ?? {}),
		messages: openaiMessages(req.messages),
	};
	if (req.tools && req.tools.length > 0) {
		body.tools = openaiTools(req.tools);
		if (req.toolChoice) body.tool_choice = openaiToolChoice(req.toolChoice);
		if (req.disableParallel) body.parallel_tool_calls = false;
	}
	if (req.maxOutputTokens !== undefined) {
		if (OPENAI_REASONING_MODEL_RE.test(modelName))
			body.max_completion_tokens = req.maxOutputTokens;
		else body.max_tokens = req.maxOutputTokens;
	}
	if (req.responseFormat) {
		// Default to a real guarantee: strict + injected output schema.
		const strict = req.responseFormat.strict ?? true;
		body.response_format = req.responseFormat.schema
			? {
					type: "json_schema",
					json_schema: {
						name: req.responseFormat.name ?? "response",
						strict,
						schema: strict
							? injectStrictObjectSchema(req.responseFormat.schema)
							: req.responseFormat.schema,
					},
				}
			: { type: "json_object" };
	}
	return body;
}

/** Raw tool-call entry as it appears on the wire; every field may be absent on a malformed response. */
interface OpenAIToolCallRaw {
	id?: string;
	function?: { name?: string; arguments?: unknown };
}

/** Normalize an OpenAI-compatible chat completion. Does not throw on missing fields. */
function parseOpenAIResponse(json: Record<string, unknown>): ParsedChatResult {
	const choice =
		(json.choices as Array<Record<string, unknown>> | undefined)?.[0] ?? {};
	const msg = (choice.message as Record<string, unknown> | undefined) ?? {};
	const rawCalls = (msg.tool_calls as OpenAIToolCallRaw[] | null | undefined) ?? [];
	const toolCalls = rawCalls.map((tc) => parseOpenAIToolCall(tc));
	const finish = String(choice.finish_reason ?? "");
	const usage = (json.usage as Record<string, number> | undefined) ?? {};
	return {
		// Only a string is text; null / absent / structured content normalizes to "".
		content: typeof msg.content === "string" ? msg.content : "",
		toolCalls,
		normalizedStopReason:
			toolCalls.length > 0 || finish === "tool_calls"
				? "tool_calls"
				: finish === "length"
					? "length"
					: finish === "stop"
						? "stop"
						: "other",
		rawStopReason: finish,
		usage: {
			promptTokens: usage.prompt_tokens ?? 0,
			completionTokens: usage.completion_tokens ?? 0,
			totalTokens: usage.total_tokens ?? 0,
		},
	};
}

/**
 * Normalize one OpenAI tool call. String arguments are JSON-parsed (an empty
 * string counts as `{}`); a parse failure yields `args: null` + `parseError`
 * rather than throwing. Non-string arguments are taken as already-decoded.
 */
function parseOpenAIToolCall(tc: OpenAIToolCallRaw): NormalizedToolCall {
	const id = tc.id ?? "";
	const name = tc.function?.name ?? "";
	const raw = tc.function?.arguments;
	if (typeof raw === "string") {
		try {
			return {
				id,
				name,
				args: JSON.parse(raw || "{}") as Record<string, unknown>,
				rawArgs: raw,
			};
		} catch {
			return { id, name, args: null, rawArgs: raw, parseError: true };
		}
	}
	// Defensive: some OpenAI-compatible servers (Ollama native) send an object.
	return {
		id,
		name,
		args: (raw as Record<string, unknown> | null | undefined) ?? {},
	};
}

// ===========================================================================
// Anthropic
// ===========================================================================

/**
 * Build an Anthropic Messages body: system messages are hoisted into a single
 * top-level `system` string, assistant turns become content blocks, and tool
 * results become a user turn of `tool_result` blocks.
 */
function buildAnthropicBody(
	modelName: string,
	req: NormalizedChatRequest,
	defaultMaxTokens: number,
): Body {
	const systemParts: string[] = [];
	const messages: Body[] = [];
	for (const m of req.messages) {
		if (m.role === "system") {
			if (m.content) systemParts.push(m.content);
		} else if (m.role === "user") {
			messages.push({ role: "user", content: m.content });
		} else if (m.role === "assistant") {
			const blocks: Body[] = [];
			if (m.content) blocks.push({ type: "text", text: m.content }); // text BEFORE tool_use
			for (const c of m.toolCalls ?? [])
				blocks.push({ type: "tool_use", id: c.id, name: c.name, input: c.args ?? {} });
			messages.push({ role: "assistant", content: blocks });
		} else {
			// tool_result blocks FIRST in a new user turn, immediately after tool_use.
			messages.push({
				role: "user",
				content: m.results.map((r) => ({
					type: "tool_result",
					tool_use_id: r.toolCallId,
					content: r.content,
					...(r.isError ? { is_error: true } : {}),
				})),
			});
		}
	}

	const body: Body = {
		model: modelName,
		max_tokens: req.maxOutputTokens ?? defaultMaxTokens,
		messages,
	};
	if (systemParts.length > 0) body.system = systemParts.join("\n\n");
	if (req.tools && req.tools.length > 0) {
		// NOTE: the Anthropic Messages API has no tool-level `strict` field (that is
		// OpenAI-only) — sending one risks an unknown-field 400, so we never do.
		body.tools = req.tools.map((t) => ({
			name: t.name,
			description: t.description,
			input_schema: t.parameters,
		}));
		if (req.toolChoice) {
			body.tool_choice = anthropicToolChoice(req.toolChoice, req.disableParallel);
		} else if (req.disableParallel) {
			body.tool_choice = { type: "auto", disable_parallel_tool_use: true };
		}
	}
	if (req.responseFormat?.schema) {
		// Native structured output (GA on Claude 4.x). Anthropic REQUIRES every object
		// in the schema to set `additionalProperties: false` (verified live: a raw
		// schema 400s with "additionalProperties must be explicitly set to false") — so
		// inject it (and the all-required closure), exactly as the OpenAI strict path does.
		body.output_config = {
			format: {
				type: "json_schema",
				schema: injectStrictObjectSchema(req.responseFormat.schema),
			},
		};
	}
	return body;
}

/** Map a normalized tool choice onto Anthropic's `tool_choice`; "none" never carries the parallel flag. */
function anthropicToolChoice(
	choice: NormalizedToolChoice,
	disableParallel?: boolean,
): Body {
	const dp = disableParallel ? { disable_parallel_tool_use: true } : {};
	if (typeof choice === "string") {
		if (choice === "required") return { type: "any", ...dp };
		if (choice === "none") return { type: "none" };
		return { type: "auto", ...dp };
	}
	return { type: "tool", name: choice.name, ...dp };
}

interface AnthropicBlockRaw {
	type: string;
	text?: string;
	id?: string;
	name?: string;
	input?: Record<string, unknown>;
}

/** Normalize an Anthropic Messages response: text blocks are joined, tool_use blocks become tool calls. */
function parseAnthropicResponse(json: Record<string, unknown>): ParsedChatResult {
	const blocks = (json.content as AnthropicBlockRaw[] | undefined) ?? [];
	const toolCalls: NormalizedToolCall[] = blocks
		.filter((b) => b.type === "tool_use")
		.map((b) => ({ id: b.id ?? "", name: b.name ?? "", args: b.input ?? {} }));
	const stop = String(json.stop_reason ?? "");
	const usage = (json.usage as Record<string, number> | undefined) ?? {};
	const input = usage.input_tokens ?? 0;
	const output = usage.output_tokens ?? 0;
	return {
		content: blocks
			.filter((b) => b.type === "text")
			.map((b) => b.text ?? "")
			.join(""),
		toolCalls,
		normalizedStopReason:
			stop === "tool_use" || toolCalls.length > 0
				? "tool_calls"
				: stop === "max_tokens"
					? "length"
					: stop === "end_turn" || stop === "stop_sequence"
						? "stop"
						: "other",
		rawStopReason: stop,
		// The total is computed here as input + output.
		usage: { promptTokens: input, completionTokens: output, totalTokens: input + output },
	};
}

// ===========================================================================
// Gemini
// ===========================================================================

/**
 * Build a Gemini generateContent body: assistant turns use role "model", tool
 * results become `functionResponse` parts under a user turn, and system text
 * goes to `systemInstruction`. `generationConfig` is emitted only when non-empty.
 * `modelName` is not written into the body.
 */
function buildGeminiBody(modelName: string, req: NormalizedChatRequest): Body {
	const systemParts: string[] = [];
	const contents: Body[] = [];
	for (const m of req.messages) {
		if (m.role === "system") {
			if (m.content) systemParts.push(m.content);
		} else if (m.role === "user") {
			contents.push({ role: "user", parts: [{ text: m.content }] });
		} else if (m.role === "assistant") {
			// Echo the original parts verbatim when present (preserves thoughtSignature).
			if (Array.isArray(m.providerRaw)) {
				contents.push({ role: "model", parts: m.providerRaw as Body[] });
			} else {
				const parts: Body[] = [];
				if (m.content) parts.push({ text: m.content });
				for (const c of m.toolCalls ?? [])
					parts.push({ functionCall: { name: c.name, args: c.args ?? {} } });
				contents.push({ role: "model", parts });
			}
		} else {
			contents.push({
				role: "user",
				parts: m.results.map((r) => ({
					functionResponse: {
						name: r.name,
						response: r.isError ? { error: r.content } : { result: r.content },
					},
				})),
			});
		}
	}

	const body: Body = { contents };
	if (systemParts.length > 0) {
		body.systemInstruction = {
			role: "system",
			parts: [{ text: systemParts.join("\n\n") }],
		};
	}
	if (req.tools && req.tools.length > 0) {
		body.tools = [
			{
				functionDeclarations: req.tools.map((t) => ({
					name: t.name,
					description: t.description,
					parameters: t.parameters,
				})),
			},
		];
		if (req.toolChoice) body.toolConfig = geminiToolConfig(req.toolChoice);
	}
	const generationConfig: Body = {};
	const mp = req.modelParams ?? {};
	if (typeof mp.temperature === "number") generationConfig.temperature = mp.temperature;
	if (typeof mp.top_p === "number") generationConfig.topP = mp.top_p;
	if (req.maxOutputTokens !== undefined)
		generationConfig.maxOutputTokens = req.maxOutputTokens;
	if (req.responseFormat?.schema) {
		generationConfig.responseMimeType = "application/json";
		generationConfig.responseSchema = req.responseFormat.schema;
	}
	if (Object.keys(generationConfig).length > 0)
		body.generationConfig = generationConfig;
	return body;
}

/** Map a normalized tool choice onto Gemini's `functionCallingConfig`. */
function geminiToolConfig(choice: NormalizedToolChoice): Body {
	if (typeof choice === "string") {
		const mode =
			choice === "required" ? "ANY" : choice === "none" ? "NONE" : "AUTO";
		return { functionCallingConfig: { mode } };
	}
	return {
		functionCallingConfig: { mode: "ANY", allowedFunctionNames: [choice.name] },
	};
}

interface GeminiPartRaw {
	text?: string;
	functionCall?: { id?: string; name: string; args?: Record<string, unknown> };
	[key: string]: unknown;
}

/**
 * Normalize a Gemini response (first candidate only). A functionCall without
 * an id gets a synthesized `gemini-call-<n>`, where n is its position among the
 * functionCall parts of this response. The raw parts are returned in `providerRaw`.
 */
function parseGeminiResponse(json: Record<string, unknown>): ParsedChatResult {
	const candidate =
		(json.candidates as Array<Record<string, unknown>> | undefined)?.[0] ?? {};
	const content =
		(candidate.content as { parts?: GeminiPartRaw[] } | undefined) ?? {};
	const parts = content.parts ?? [];
	const toolCalls: NormalizedToolCall[] = parts
		.filter((p) => p.functionCall)
		.map((p, i) => ({
			id: p.functionCall?.id ?? `gemini-call-${i}`,
			name: p.functionCall?.name ?? "",
			args: p.functionCall?.args ?? {},
		}));
	const finish = String(candidate.finishReason ?? "");
	const meta = (json.usageMetadata as Record<string, number> | undefined) ?? {};
	return {
		content: parts
			.filter((p) => typeof p.text === "string")
			.map((p) => p.text ?? "")
			.join(""),
		toolCalls,
		normalizedStopReason:
			toolCalls.length > 0
				? "tool_calls"
				: finish === "MAX_TOKENS"
					? "length"
					: finish === "STOP"
						? "stop"
						: "other",
		rawStopReason: finish,
		usage: {
			promptTokens: meta.promptTokenCount ?? 0,
			completionTokens: meta.candidatesTokenCount ?? 0,
			totalTokens: meta.totalTokenCount ?? 0,
		},
		// Preserve raw parts so the next turn can echo thoughtSignature unchanged.
		providerRaw: parts,
	};
}

// ===========================================================================
// Dispatch
// ===========================================================================

/**
 * Turn a normalized request into the request body for `kind`.
 *
 * @param kind - Provider wire shape; anything other than "anthropic" / "gemini"
 *   is built as OpenAI-compatible.
 * @param modelName - Model identifier (written into the body for OpenAI and Anthropic).
 * @param req - Provider-neutral request.
 * @param anthropicDefaultMaxTokens - `max_tokens` used for Anthropic when the
 *   request sets no `maxOutputTokens`.
 */
export function buildChatBody(
	kind: ProviderKind,
	modelName: string,
	req: NormalizedChatRequest,
	anthropicDefaultMaxTokens = 8192,
): Body {
	switch (kind) {
		case "anthropic":
			return buildAnthropicBody(modelName, req, anthropicDefaultMaxTokens);
		case "gemini":
			return buildGeminiBody(modelName, req);
		default:
			return buildOpenAIBody(modelName, req);
	}
}

/**
 * Turn raw provider JSON into a normalized result.
 *
 * @param kind - Provider wire shape; anything other than "anthropic" / "gemini"
 *   is parsed as OpenAI-compatible.
 * @param json - Decoded response body.
 */
export function parseChatResponse(
	kind: ProviderKind,
	json: Record<string, unknown>,
): ParsedChatResult {
	switch (kind) {
		case "anthropic":
			return parseAnthropicResponse(json);
		case "gemini":
			return parseGeminiResponse(json);
		default:
			return parseOpenAIResponse(json);
	}
}
NormalizedToolCall } from "./NormalizedTools"; +import type { ParsedChatResult } from "./providerToolMapping"; + +class FakeAbort extends Error {} + +function turn(p: Partial): ParsedChatResult { + return { + content: p.content ?? "", + toolCalls: p.toolCalls ?? [], + normalizedStopReason: p.normalizedStopReason ?? (p.toolCalls?.length ? "tool_calls" : "stop"), + rawStopReason: p.rawStopReason ?? "", + usage: p.usage ?? { promptTokens: 1, completionTokens: 1, totalTokens: 2 }, + providerRaw: p.providerRaw, + }; +} + +function call(name: string, args: Record | null, extra: Partial = {}): NormalizedToolCall { + return { id: extra.id ?? `c-${name}`, name, args, ...extra }; +} + +function makeDeps( + turns: ParsedChatResult[], + tools: Record, + overrides: Partial = {}, +): RunToolLoopDeps { + const queue = [...turns]; + return { + request: { messages: [{ role: "user", content: "go" }], tools: Object.values(tools).map((t) => t.definition) }, + dispatch: vi.fn(async (_req, ctx) => { + // If a final step is forced, return a plain text turn unless the test queued one. + const next = queue.shift(); + if (next) return next; + return turn({ content: ctx.isFinalStep ? 
"final text" : "done", normalizedStopReason: "stop" }); + }), + getTool: (name) => tools[name], + confirm: async () => true, + validateArgs: () => null, + isAbortError: (e) => e instanceof FakeAbort, + maxSteps: 8, + ...overrides, + }; +} + +function toolEntry(name: string, execute: ToolEntry["execute"], readOnly = false): ToolEntry { + return { + definition: { name, description: name, parameters: { type: "object", properties: {} } }, + execute, + readOnly, + }; +} + +describe("runToolLoop", () => { + it("runs a single tool call then returns the final text", async () => { + const exec = vi.fn(async () => "tool-ran"); + const deps = makeDeps( + [turn({ toolCalls: [call("t", { x: 1 })] }), turn({ content: "all done", normalizedStopReason: "stop" })], + { t: toolEntry("t", exec) }, + ); + const res = await runToolLoop(deps); + expect(exec).toHaveBeenCalledOnce(); + expect(res.text).toBe("all done"); + expect(res.finishReason).toBe("stop"); + // transcript: user, assistant(tool_calls), tool(results), [final text has no append] + const toolTurn = res.messages.find((m) => m.role === "tool"); + expect(toolTurn).toBeDefined(); + expect(res.steps.length).toBe(2); + }); + + it("executes parallel tool calls from one turn and returns all results together", async () => { + const exec = vi.fn(async (args: Record) => `r:${args.city}`); + const deps = makeDeps( + [ + turn({ toolCalls: [call("w", { city: "Paris" }, { id: "a" }), call("w", { city: "London" }, { id: "b" })] }), + turn({ content: "compared", normalizedStopReason: "stop" }), + ], + { w: toolEntry("w", exec) }, + ); + const res = await runToolLoop(deps); + expect(exec).toHaveBeenCalledTimes(2); + const toolTurn = res.messages.find((m) => m.role === "tool"); + expect(toolTurn && "results" in toolTurn && toolTurn.results.map((r) => r.toolCallId)).toEqual(["a", "b"]); + }); + + it("propagates a MacroAbort-style error from a handler (kills the run)", async () => { + const deps = makeDeps( + [turn({ toolCalls: [call("t", 
{})] })], + { t: toolEntry("t", async () => { throw new FakeAbort("aborted"); }) }, + ); + await expect(runToolLoop(deps)).rejects.toBeInstanceOf(FakeAbort); + }); + + it("turns a non-abort handler throw into an isError result", async () => { + const deps = makeDeps( + [turn({ toolCalls: [call("t", {})] }), turn({ content: "recovered" })], + { t: toolEntry("t", async () => { throw new Error("boom"); }) }, + ); + const res = await runToolLoop(deps); + const tr = res.steps[0].toolResults[0]; + expect(tr.isError).toBe(true); + expect(tr.content).toMatch(/tool execution failed: boom/); + expect(res.text).toBe("recovered"); + }); + + it("returns isError for an unknown tool, a parseError call, a schema-invalid call, and a denied call", async () => { + const deps = makeDeps( + [ + turn({ toolCalls: [ + call("missing", {}), + call("t", null, { id: "pe", parseError: true }), + call("t", { bad: true }, { id: "iv" }), + call("t", {}, { id: "dn" }), + ] }), + turn({ content: "ok" }), + ], + { t: toolEntry("t", async () => "ran") }, + { + validateArgs: (_tool, args) => ("bad" in args ? "bad field" : null), + confirm: async (c) => c.id !== "dn", + }, + ); + const res = await runToolLoop(deps); + const byId = Object.fromEntries(res.steps[0].toolResults.map((r) => [r.toolCallId, r])); + expect(byId["c-missing"].content).toMatch(/unknown tool/); + expect(byId["pe"].content).toMatch(/could not parse/); + expect(byId["iv"].content).toMatch(/invalid arguments: bad field/); + expect(byId["dn"].content).toMatch(/denied by user/); + expect(res.steps[0].toolResults.every((r) => r.isError)).toBe(true); + }); + + it("forces a final no-tool turn at the step budget and reports max-steps when the model still wants tools", async () => { + // A 'never stops' model: every dispatch returns a tool call, even on the final step. + const deps = makeDeps([], { t: toolEntry("t", async () => "x") }, { + maxSteps: 2, + dispatch: vi.fn(async (_req, ctx) => turn({ + content: ctx.isFinalStep ? 
"forced text" : "", + toolCalls: ctx.isFinalStep ? [call("t", {})] : [call("t", {})], + })), + }); + const res = await runToolLoop(deps); + expect(res.finishReason).toBe("max-steps"); + // dispatch received isFinalStep on the last call + const calls = (deps.dispatch as ReturnType).mock.calls; + expect(calls[calls.length - 1][1]).toEqual({ isFinalStep: true }); + }); + + it("trips the transcript-bytes ceiling (cumulative re-sent transcript) and forces an early final turn", async () => { + const big = "x".repeat(2000); + // A model that ALWAYS wants tools — without the byte ceiling it would run to + // maxSteps (8). The ceiling must force a final text turn far earlier because the + // 2 KB tool result is re-sent every turn (>> the 100-byte cap). + const deps = makeDeps( + [], + { t: toolEntry("t", async () => big) }, + { + maxTranscriptBytes: 100, + maxSteps: 8, + dispatch: vi.fn(async (_req, ctx) => + turn({ + content: ctx.isFinalStep ? "wrapped up" : "", + toolCalls: ctx.isFinalStep ? [] : [call("t", {})], + normalizedStopReason: ctx.isFinalStep ? "stop" : "tool_calls", + }), + ), + }, + ); + const res = await runToolLoop(deps); + expect(res.text).toBe("wrapped up"); + expect(res.finishReason).toBe("stop"); + // Seed is tiny (<100 B), so step 0's small upload does not trip; step 1 re-sends + // the first 2 KB result → trips → step 2 is the forced final. Far below maxSteps. 
+ expect(res.steps.length).toBe(3); + }); + + it("returns context-overflow when a dispatch errors with a context-window overflow", async () => { + let n = 0; + const deps = makeDeps([], { t: toolEntry("t", async () => "x") }, { + maxSteps: 5, + dispatch: vi.fn(async () => { + n++; + if (n === 1) return turn({ toolCalls: [call("t", {})] }); + throw new Error("input length and max_tokens exceed context limit"); + }), + isContextOverflow: (e) => e instanceof Error && e.message.includes("context limit"), + }); + const res = await runToolLoop(deps); + expect(res.finishReason).toBe("context-overflow"); + }); + + it("honors shouldStop (stopWhen) by forcing an early final turn", async () => { + const deps = makeDeps( + [turn({ toolCalls: [call("t", {})] }), turn({ content: "stopped early", normalizedStopReason: "stop" })], + { t: toolEntry("t", async () => "x") }, + { maxSteps: 8, shouldStop: ({ steps }) => steps.some((s) => s.toolResults.length > 0) }, + ); + const res = await runToolLoop(deps); + expect(res.text).toBe("stopped early"); + expect(res.steps.length).toBe(2); + }); + + it("stops gracefully if online features flip off mid-loop", async () => { + let disabled = false; + const deps = makeDeps( + [turn({ toolCalls: [call("t", {})] }), turn({ content: "never reached" })], + { t: toolEntry("t", async () => { disabled = true; return "x"; }) }, + { isOnlineDisabled: () => disabled }, + ); + const res = await runToolLoop(deps); + expect(res.finishReason).toBe("aborted"); + }); +}); + +describe("stringifyToolResult", () => { + it("passes strings through, JSON-stringifies objects, and marks null/undefined as ok", () => { + expect(stringifyToolResult("hi", 1000)).toEqual({ content: "hi", isError: false }); + expect(stringifyToolResult({ a: 1 }, 1000)).toEqual({ content: '{"a":1}', isError: false }); + expect(stringifyToolResult(null, 1000)).toEqual({ content: '{"ok":true}', isError: false }); + expect(stringifyToolResult(undefined, 1000)).toEqual({ content: '{"ok":true}', 
isError: false }); + }); + + it("truncates oversized results with a marker", () => { + const r = stringifyToolResult("y".repeat(5000), 100); + expect(r.content.endsWith("…[truncated]")).toBe(true); + expect(r.content.length).toBeLessThan(5000); + }); + + it("returns isError for a non-serializable value", () => { + const circular: Record = {}; + circular.self = circular; + expect(stringifyToolResult(circular, 1000).isError).toBe(true); + }); +}); diff --git a/src/ai/tools/runToolLoop.ts b/src/ai/tools/runToolLoop.ts new file mode 100644 index 00000000..daa69b6c --- /dev/null +++ b/src/ai/tools/runToolLoop.ts @@ -0,0 +1,326 @@ +/** + * Provider-agnostic agentic tool loop (#714). PURE — every Obsidian/provider touch + * point is injected, so the loop is fully unit-testable with a fake `dispatch`. + * + * Per spec §7: dispatch → if tool calls, run each (one result PER call id) → append + * assistant + tool-result turns → repeat until the model stops or the step budget is + * hit. The LAST allowed step forces toolChoice:'none' so the run always lands on a + * text answer with no orphaned tool_use. A handler throwing a hard-abort error + * (MacroAbortError/UserCancelError) propagates out and kills the macro; any other + * throw becomes an `isError` tool result the model can recover from. Caps bound cost + * and exfiltration: a step ceiling, a per-result byte cap, and a cumulative + * transcript-bytes ceiling (the real exfiltration metric). 
+ */ +import type { + NormalizedChatRequest, + NormalizedMessage, + NormalizedToolCall, + NormalizedToolDefinition, + NormalizedToolResult, +} from "./NormalizedTools"; +import type { ParsedChatResult } from "./providerToolMapping"; + +export interface ToolEntry { + definition: NormalizedToolDefinition; + execute: ( + args: Record, + ctx: { toolCallId: string; toolName: string }, + ) => unknown | Promise; + readOnly?: boolean; +} + +export type LoopFinishReason = + | "stop" + | "length" + | "max-steps" + | "aborted" + | "context-overflow"; + +export interface LoopStep { + stepNumber: number; + text: string; + toolCalls: NormalizedToolCall[]; + toolResults: NormalizedToolResult[]; + finishReason: string; + usage: ParsedChatResult["usage"]; +} + +export interface RunToolLoopResult { + text: string; + steps: LoopStep[]; + finishReason: LoopFinishReason; + usage: { promptTokens: number; completionTokens: number; totalTokens: number }; + messages: NormalizedMessage[]; + /** The final dispatched turn (so the caller can parse structured output). */ + finalTurn: ParsedChatResult; +} + +export interface RunToolLoopDeps { + /** The initial request (messages already seeded with system + first user turn). */ + request: NormalizedChatRequest; + /** Send one turn. `isFinalStep` lets the provider omit tools (Gemini+schema) / force no tool use. */ + dispatch: ( + req: NormalizedChatRequest, + ctx: { isFinalStep: boolean }, + ) => Promise; + /** Look up a registered tool by name (returns undefined for an unknown tool). */ + getTool: (name: string) => ToolEntry | undefined; + /** + * Decide + gate one call. Return true to run it, false to decline (→ isError + * result). THROW (MacroAbortError/UserCancelError) to abort the whole run. + */ + confirm: (call: NormalizedToolCall, tool: ToolEntry) => Promise; + /** Validate model args against the tool's schema; return an error string or null. 
*/ + validateArgs: (tool: ToolEntry, args: Record) => string | null; + /** True if the error must propagate and kill the macro (abort/cancel). */ + isAbortError: (e: unknown) => boolean; + /** True if a dispatch error is an input-context-window overflow (→ graceful stop). */ + isContextOverflow?: (e: unknown) => boolean; + /** Re-checked at the top of every step after the first; true → graceful stop. */ + isOnlineDisabled?: () => boolean; + /** + * Evaluated after each tool-execution step (backs AI-SDK `stopWhen`/`hasToolCall`). + * Returning true makes the NEXT step the forced-final text turn. The `maxSteps` + * hard clamp always applies on top of this. + */ + shouldStop?: (info: { steps: LoopStep[] }) => boolean; + onStepFinish?: (step: LoopStep) => void | Promise; + maxSteps: number; + /** Per-tool-result byte cap (truncate + marker). Default 32 KB. */ + maxResultBytes?: number; + /** Cumulative transcript-uploaded-bytes ceiling. Default 512 KB. */ + maxTranscriptBytes?: number; +} + +const DEFAULT_MAX_RESULT_BYTES = 32 * 1024; +const DEFAULT_MAX_TRANSCRIPT_BYTES = 512 * 1024; + +function byteLength(s: string): number { + return new TextEncoder().encode(s).length; +} + +/** + * Approximate the bytes a single dispatch uploads: the WHOLE transcript is re-sent + * every turn, so the exfiltration metric is the cumulative size of all messages + * (system + user + assistant text/tool-call args + every tool result), summed across + * turns — not each tool result counted once. Counts text payloads; ignores wire + * framing/JSON keys (a small, stable constant per message). 
+ */ +function transcriptBytes(messages: NormalizedMessage[]): number { + let n = 0; + for (const m of messages) { + if (m.role === "tool") { + for (const r of m.results) n += byteLength(r.content); + continue; + } + n += byteLength(m.content); + if (m.role === "assistant" && m.toolCalls) { + for (const c of m.toolCalls) { + if (typeof c.rawArgs === "string") n += byteLength(c.rawArgs); + else if (c.args) { + try { + n += byteLength(JSON.stringify(c.args)); + } catch { + /* unserializable args — skip; can't be uploaded as JSON anyway */ + } + } + } + } + } + return n; +} + +/** Coerce a handler return value to a string for the wire, capped + crash-safe. */ +export function stringifyToolResult( + value: unknown, + maxBytes: number, +): { content: string; isError: boolean } { + let s: string; + if (typeof value === "string") s = value; + else if (value === null || value === undefined) s = '{"ok":true}'; + else { + try { + s = JSON.stringify(value); + } catch { + return { + content: "tool returned a value that could not be serialized", + isError: true, + }; + } + } + if (byteLength(s) > maxBytes) { + // Byte-accurate truncation: shrink the char cut until it fits the byte cap + // (so multibyte content can't overshoot), without splitting a code point. + let cut = Math.min(s.length, maxBytes); + while (cut > 0 && byteLength(s.slice(0, cut)) > maxBytes) { + cut = Math.floor(cut * 0.9) || cut - 1; + } + s = s.slice(0, cut) + " …[truncated]"; + } + return { content: s, isError: false }; +} + +export async function runToolLoop( + deps: RunToolLoopDeps, +): Promise { + const maxResultBytes = deps.maxResultBytes ?? DEFAULT_MAX_RESULT_BYTES; + const maxTranscriptBytes = deps.maxTranscriptBytes ?? 
DEFAULT_MAX_TRANSCRIPT_BYTES; + const maxSteps = Math.max(1, deps.maxSteps); + + const messages: NormalizedMessage[] = [...deps.request.messages]; + const steps: LoopStep[] = []; + const usage = { promptTokens: 0, completionTokens: 0, totalTokens: 0 }; + let uploadedBytes = 0; + let forceFinal = false; + let lastTurn: ParsedChatResult | null = null; + + for (let step = 0; step < maxSteps; step++) { + if (step > 0 && deps.isOnlineDisabled?.()) { + return finish("aborted"); + } + + const isFinalStep = forceFinal || step === maxSteps - 1; + const turnReq: NormalizedChatRequest = { + ...deps.request, + messages, + ...(isFinalStep ? { toolChoice: "none" as const } : {}), + }; + + // Account for the bytes THIS dispatch uploads (the full transcript, re-sent), + // before sending — so a model that re-runs read tools to stream MB upstream + // trips the ceiling even though each individual result is small. + uploadedBytes += transcriptBytes(messages); + + let turn: ParsedChatResult; + try { + turn = await deps.dispatch(turnReq, { isFinalStep }); + } catch (e) { + // A growing transcript can overflow the context window mid-loop. Stop + // gracefully with whatever we have rather than throwing — no tool_use is + // dangling here (results were appended last step). + if (deps.isContextOverflow?.(e)) return finish("context-overflow"); + throw e; + } + lastTurn = turn; + usage.promptTokens += turn.usage.promptTokens; + usage.completionTokens += turn.usage.completionTokens; + usage.totalTokens += turn.usage.totalTokens; + + // Stop when the model produced no tool calls, or we are on the forced-final step. 
+ if ( + turn.normalizedStopReason !== "tool_calls" || + turn.toolCalls.length === 0 || + isFinalStep + ) { + steps.push({ + stepNumber: step, + text: turn.content, + toolCalls: turn.toolCalls, + toolResults: [], + finishReason: turn.rawStopReason, + usage: turn.usage, + }); + await deps.onStepFinish?.(steps[steps.length - 1]); + // On the forced-final step the model can't actually run tools, so we land here. + const reason: LoopFinishReason = + turn.normalizedStopReason === "length" + ? "length" + : !isFinalStep || turn.toolCalls.length === 0 + ? "stop" + : "max-steps"; + return finish(reason); + } + + // Append the assistant turn (content + tool calls + provider echo blob). + messages.push({ + role: "assistant", + content: turn.content, + toolCalls: turn.toolCalls, + providerRaw: turn.providerRaw, + }); + + // Resolve EVERY tool call to exactly one result (in id order, sequential). + const results: NormalizedToolResult[] = []; + for (const call of turn.toolCalls) { + results.push(await resolveCall(call)); + } + messages.push({ role: "tool", results }); + + steps.push({ + stepNumber: step, + text: turn.content, + toolCalls: turn.toolCalls, + toolResults: results, + finishReason: turn.rawStopReason, + usage: turn.usage, + }); + await deps.onStepFinish?.(steps[steps.length - 1]); + + // Cumulative-bytes ceiling OR a user stopWhen condition → make the next step + // the forced-final text turn. + if (uploadedBytes > maxTranscriptBytes) forceFinal = true; + if (deps.shouldStop?.({ steps })) forceFinal = true; + } + + // Safety net (the forced-final step normally returns from inside the loop). + return finish("max-steps"); + + function finish(reason: LoopFinishReason): RunToolLoopResult { + return { + text: lastTurn?.content ?? "", + steps, + finishReason: reason, + usage, + messages, + finalTurn: + lastTurn ?? 
{ + content: "", + toolCalls: [], + normalizedStopReason: "stop", + rawStopReason: "", + usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, + }, + }; + } + + async function resolveCall( + call: NormalizedToolCall, + ): Promise { + if (call.parseError) { + return errResult(call, "could not parse the tool arguments as JSON"); + } + const tool = deps.getTool(call.name); + if (!tool) { + return errResult(call, `unknown tool "${call.name}"`); + } + const args = call.args ?? {}; + const schemaError = deps.validateArgs(tool, args); + if (schemaError) { + return errResult(call, `invalid arguments: ${schemaError}`); + } + // confirm() throws on a hard cancel/abort (propagates and kills the macro). + const allowed = await deps.confirm(call, tool); + if (!allowed) { + return errResult(call, "tool call denied by user"); + } + let raw: unknown; + try { + raw = await tool.execute(args, { toolCallId: call.id, toolName: call.name }); + } catch (e) { + if (deps.isAbortError(e)) throw e; + return errResult( + call, + `tool execution failed: ${e instanceof Error ? 
e.message : String(e)}`, + ); + } + const { content, isError } = stringifyToolResult(raw, maxResultBytes); + return { toolCallId: call.id, name: call.name, content, isError }; + } + + function errResult( + call: NormalizedToolCall, + message: string, + ): NormalizedToolResult { + return { toolCallId: call.id, name: call.name, content: message, isError: true }; + } +} diff --git a/src/ai/tools/sanitizeVaultPath.test.ts b/src/ai/tools/sanitizeVaultPath.test.ts new file mode 100644 index 00000000..07b580ad --- /dev/null +++ b/src/ai/tools/sanitizeVaultPath.test.ts @@ -0,0 +1,116 @@ +import { describe, it, expect } from "vitest"; +import { sanitizeVaultPath, UnsafeVaultPathError } from "./sanitizeVaultPath"; + +describe("sanitizeVaultPath", () => { + describe("accepts safe vault-relative paths", () => { + it("returns a normalized relative path", () => { + expect(sanitizeVaultPath("Notes/Foo.md")).toBe("Notes/Foo.md"); + expect(sanitizeVaultPath(" Inbox/today.md ")).toBe("Inbox/today.md"); + expect(sanitizeVaultPath("a//b///c.md")).toBe("a/b/c.md"); + }); + + it("allows dots inside a name (not a leading-dot segment)", () => { + expect(sanitizeVaultPath("My..Note.md")).toBe("My..Note.md"); + expect(sanitizeVaultPath("v1.2.3 release.md")).toBe("v1.2.3 release.md"); + }); + }); + + describe("rejects absolute / drive / UNC paths", () => { + it("rejects a POSIX absolute path", () => { + expect(() => sanitizeVaultPath("/etc/passwd")).toThrow(UnsafeVaultPathError); + }); + it("rejects a Windows drive path", () => { + expect(() => sanitizeVaultPath("C:\\Windows\\x")).toThrow(UnsafeVaultPathError); + expect(() => sanitizeVaultPath("C:foo")).toThrow(UnsafeVaultPathError); + }); + it("rejects a UNC path (backslashes normalized to // before the absolute check)", () => { + expect(() => sanitizeVaultPath("\\\\server\\share\\x.md")).toThrow( + UnsafeVaultPathError, + ); + }); + }); + + describe("rejects traversal", () => { + it("rejects a .. 
segment anywhere", () => { + expect(() => sanitizeVaultPath("a/../b.md")).toThrow(UnsafeVaultPathError); + expect(() => sanitizeVaultPath("../secrets.md")).toThrow(UnsafeVaultPathError); + }); + }); + + describe("config-dir / dotfile floor — every segment, any depth, case-folded", () => { + it("rejects a first-segment config dir", () => { + expect(() => sanitizeVaultPath(".obsidian/plugins/x/main.js")).toThrow( + UnsafeVaultPathError, + ); + }); + it("rejects a NESTED config dir at depth >= 2 (the v3 blocker)", () => { + expect(() => sanitizeVaultPath("Projects/.git/hooks/post-checkout")).toThrow( + UnsafeVaultPathError, + ); + expect(() => sanitizeVaultPath("notes/.obsidian/plugins/evil/main.js")).toThrow( + UnsafeVaultPathError, + ); + }); + it("is immune to casing (structural, not a name denylist)", () => { + expect(() => sanitizeVaultPath("x/.Obsidian/y.md")).toThrow(UnsafeVaultPathError); + expect(() => sanitizeVaultPath(".GIT/config")).toThrow(UnsafeVaultPathError); + }); + it("rejects a leading-dot basename too", () => { + expect(() => sanitizeVaultPath("Notes/.env")).toThrow(UnsafeVaultPathError); + }); + }); + + describe("character + device-name validation", () => { + it("rejects illegal characters", () => { + expect(() => sanitizeVaultPath("Notes/a:b.md")).toThrow(UnsafeVaultPathError); + expect(() => sanitizeVaultPath("Notes/a*b.md")).toThrow(UnsafeVaultPathError); + }); + it("rejects a trailing dot/space segment", () => { + expect(() => sanitizeVaultPath("Notes/folder./bar.md")).toThrow( + UnsafeVaultPathError, + ); + expect(() => sanitizeVaultPath("Notes/trailing space /bar.md")).toThrow( + UnsafeVaultPathError, + ); + }); + it("rejects reserved Windows device names", () => { + expect(() => sanitizeVaultPath("CON/x.md")).toThrow(UnsafeVaultPathError); + expect(() => sanitizeVaultPath("Notes/NUL.md")).toThrow(UnsafeVaultPathError); + }); + it("rejects empty input", () => { + expect(() => sanitizeVaultPath(" ")).toThrow(UnsafeVaultPathError); + expect(() 
=> sanitizeVaultPath("///")).toThrow(UnsafeVaultPathError); + }); + }); + + describe("allowedRoots scoping", () => { + it("allows a path under an allowed root", () => { + expect(sanitizeVaultPath("AI/notes/x.md", { allowedRoots: ["AI"] })).toBe( + "AI/notes/x.md", + ); + expect(sanitizeVaultPath("AI", { allowedRoots: ["AI"] })).toBe("AI"); + }); + it("rejects a path outside the allowed roots", () => { + expect(() => + sanitizeVaultPath("Other/x.md", { allowedRoots: ["AI"] }), + ).toThrow(UnsafeVaultPathError); + }); + it("does not let a leading slash sneak a path past allowedRoots", () => { + // "/AI/x.md" is rejected as absolute before any root check. + expect(() => + sanitizeVaultPath("/AI/x.md", { allowedRoots: ["AI"] }), + ).toThrow(UnsafeVaultPathError); + }); + it("an empty-string root never becomes allow-all", () => { + // A blank entry is dropped; with no real roots left, it falls back to the + // vault-wide default (allow), NOT allow-all-by-accident-because-of-blank. + expect(sanitizeVaultPath("Anywhere/x.md", { allowedRoots: ["", " "] })).toBe( + "Anywhere/x.md", + ); + // But a blank entry alongside a real root must NOT widen it to allow-all. + expect(() => + sanitizeVaultPath("Other/x.md", { allowedRoots: ["", "AI"] }), + ).toThrow(UnsafeVaultPathError); + }); + }); +}); diff --git a/src/ai/tools/sanitizeVaultPath.ts b/src/ai/tools/sanitizeVaultPath.ts new file mode 100644 index 00000000..5db6d5a7 --- /dev/null +++ b/src/ai/tools/sanitizeVaultPath.ts @@ -0,0 +1,144 @@ +/** + * Path sanitization for AI-built-in vault writers (#714). + * + * The model chooses the path a built-in write tool targets, so this is a security + * boundary. 
Modeled on `validateAssetDestination` (packageImportService.ts) but + * HARDENED per the v5 review: + * + * - Normalization is done HERE (deterministic, no dependency on Obsidian's + * `normalizePath`, whose test stub is incomplete) and in a fixed order: + * trim → backslash→'/' → NFC → reject absolute → collapse '//' + strip slashes. + * - The leading-dot config-dir floor is checked on EVERY segment at ANY depth + * (the precedent checked only segments[0], so `Projects/.git/hooks/x` and + * `notes/.obsidian/plugins/x` slipped through — an RCE-on-next-commit / + * plugin-drop vector). It is a STRUCTURAL rule (segment starts with '.'), which + * is inherently immune to the `.Obsidian`/`.OBSIDIAN` casing bypass — do not + * replace it with a case-sensitive name denylist. + * + * This module is PURE (no Obsidian import) so it is fully unit-testable. The + * symlink / realpath-containment check that writers MUST also perform needs the + * FileSystemAdapter and lives with the writers (see builtins/vaultWriteGuards.ts). + */ +import { + INVALID_FOLDER_CHARS_REGEX, + INVALID_FOLDER_CONTROL_CHARS_REGEX, + INVALID_FOLDER_TRAILING_CHARS_REGEX, + isReservedWindowsDeviceName, +} from "../../utils/pathValidation"; + +export class UnsafeVaultPathError extends Error { + constructor(message: string) { + super(message); + this.name = "UnsafeVaultPathError"; + } +} + +export interface SanitizeVaultPathOptions { + /** + * If provided (and non-empty after dropping blank entries), the path must + * resolve under one of these vault-relative folders. Absent / all-blank => + * vault-wide (the default). A blank entry NEVER becomes allow-all. + */ + allowedRoots?: string[]; +} + +/** + * Validate + normalize a model-chosen vault-relative path. Throws + * UnsafeVaultPathError on anything unsafe. Returns the normalized relative path + * (forward slashes, no leading/trailing slash). Does NOT touch the filesystem. 
+ */ +export function sanitizeVaultPath( + rawPath: string, + options: SanitizeVaultPathOptions = {}, +): string { + const trimmed = (rawPath ?? "").trim(); + if (!trimmed) { + throw new UnsafeVaultPathError("Path is empty."); + } + + // 1. Backslash → '/' FIRST (turns UNC `\\server\share` into `//server/share`, + // so the absolute check below catches it). + const slashed = trimmed.replace(/\\/g, "/").normalize("NFC"); + + // 2. Reject absolute paths (POSIX root and Windows drive letters) BEFORE + // collapsing — `normalizePath` strips leading slashes, which would hide them. + if (slashed.startsWith("/") || /^[a-zA-Z]:/.test(slashed)) { + throw new UnsafeVaultPathError( + `Refusing to write to an absolute path outside the vault: "${slashed}".`, + ); + } + + // 3. Collapse repeated slashes and strip leading/trailing slashes. + const normalized = slashed.replace(/\/+/g, "/").replace(/^\/+|\/+$/g, ""); + const segments = normalized.split("/").filter((s) => s.length > 0); + if (segments.length === 0) { + throw new UnsafeVaultPathError("Path is empty after normalization."); + } + + for (const segment of segments) { + // Traversal. + if (segment === "." || segment === "..") { + throw new UnsafeVaultPathError( + `Refusing path with a traversal segment ("${segment}"): "${normalized}".`, + ); + } + // Config-dir / dotfile FLOOR — every segment, any depth, structural (not a + // name denylist), so .obsidian/.git/.trash and casing variants are all caught. + if (segment.startsWith(".")) { + throw new UnsafeVaultPathError( + `Refusing to write into a dot/config path segment ("${segment}"): "${normalized}".`, + ); + } + // Per-segment character / device-name validation (incl. the basename). + validateSegment(segment, normalized); + } + + const allowed = (options.allowedRoots ?? 
[]) + .map((root) => normalizeRoot(root)) + .filter((root) => root.length > 0); + + if (allowed.length > 0 && !isUnderAllowedRoot(normalized, allowed)) { + throw new UnsafeVaultPathError( + `Path "${normalized}" is not under an allowed root (${allowed.join(", ")}).`, + ); + } + + return normalized; +} + +function validateSegment(segment: string, fullPath: string): void { + if (INVALID_FOLDER_CONTROL_CHARS_REGEX.test(segment)) { + throw new UnsafeVaultPathError( + `Path segment "${segment}" contains control characters: "${fullPath}".`, + ); + } + // '/' is already consumed by the split; the rest of the class (\\ : * ? " < > |) still applies. + if (INVALID_FOLDER_CHARS_REGEX.test(segment)) { + throw new UnsafeVaultPathError( + `Path segment "${segment}" contains an illegal character (\\ : * ? " < > |): "${fullPath}".`, + ); + } + if (INVALID_FOLDER_TRAILING_CHARS_REGEX.test(segment)) { + throw new UnsafeVaultPathError( + `Path segment "${segment}" cannot end with a space or a period: "${fullPath}".`, + ); + } + const base = segment.replace(/[. ]+$/u, "").split(".")[0] ?? ""; + if (base && isReservedWindowsDeviceName(base)) { + throw new UnsafeVaultPathError( + `Path segment "${segment}" is a reserved device name: "${fullPath}".`, + ); + } +} + +function normalizeRoot(root: string): string { + return (root ?? 
"") + .trim() + .replace(/\\/g, "/") + .replace(/\/+/g, "/") + .replace(/^\/+|\/+$/g, ""); +} + +function isUnderAllowedRoot(path: string, roots: string[]): boolean { + return roots.some((root) => path === root || path.startsWith(`${root}/`)); +} diff --git a/src/migrations/migrate.ts b/src/migrations/migrate.ts index a95e390a..f8dc32e1 100644 --- a/src/migrations/migrate.ts +++ b/src/migrations/migrate.ts @@ -15,6 +15,7 @@ import setProviderModelDiscoveryMode from "./setProviderModelDiscoveryMode"; import { deepClone } from "src/utils/deepClone"; import migrateProviderApiKeysToSecretStorage from "./migrateProviderApiKeysToSecretStorage"; import migrateToMultipleTemplateFolders from "./migrateToMultipleTemplateFolders"; +import setProviderKind from "./setProviderKind"; import { settingsStore } from "src/settingsStore"; const migrations: Migrations = { @@ -31,6 +32,7 @@ const migrations: Migrations = { setProviderModelDiscoveryMode, migrateProviderApiKeysToSecretStorage, migrateToMultipleTemplateFolders, + setProviderKind, }; async function migrate(plugin: QuickAdd): Promise { diff --git a/src/migrations/setProviderKind.ts b/src/migrations/setProviderKind.ts new file mode 100644 index 00000000..86f3b7b8 --- /dev/null +++ b/src/migrations/setProviderKind.ts @@ -0,0 +1,37 @@ +import type QuickAdd from "src/main"; +import { settingsStore } from "src/settingsStore"; +import type { Migration } from "./Migrations"; +import { deepClone } from "src/utils/deepClone"; +import { getProviderKind } from "src/ai/Provider"; + +/** + * Backfill the `kind` wire-protocol discriminator on every AI provider (#714). The + * tool-calling / structured-output adapter selects on `kind`, not the display name, + * so a custom Anthropic-compatible provider named anything other than "Anthropic" + * routes correctly. Inference (getProviderKind) covers providers without the field + * at runtime; this migration persists the inferred value once. 
+ */ +const setProviderKind: Migration = { + description: "Backfill the wire-protocol kind on each AI provider", + migrate: async (_plugin: QuickAdd) => { + const currentSettings = settingsStore.getState(); + const providers = currentSettings.ai.providers ?? []; + let updated = false; + + for (const provider of providers) { + if (!provider.kind) { + provider.kind = getProviderKind(provider); + updated = true; + } + } + + if (!updated) return; + + settingsStore.setState((state) => ({ + ...state, + ai: { ...state.ai, providers: deepClone(providers) }, + })); + }, +}; + +export default setProviderKind; From 760187c43948b10f41cafe5bace57c5c0d9188ce Mon Sep 17 00:00:00 2001 From: Christian Bager Bach Houmann Date: Wed, 17 Jun 2026 08:34:55 +0200 Subject: [PATCH 2/8] feat(ai): ai.agent() tool-calling API, confirm UX + built-in tools (#714) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Expose the AI-SDK-shaped surface on quickAddApi.ai and ship opt-in built-in tools: - ai.agent(config).generate({ prompt, schema?, assignToVariable? }) — bounded tool loop + optional JSON-schema structured output (one bounded repair). - ai.tool(), ai.stepCountIs, ai.hasToolCall; assignToVariable alias on ai.prompt/ai.chunkedPrompt (reserved-name guarded, trimmed at write). - ai.tools.{vault,workspace,system}(options): reads auto-run; writes ask for approval, sanitise paths, fail-rather-than-overwrite, frontmatter-aware, symlink/realpath-contained. High-risk tools deferred. - Per-tool needsApproval + global ai.confirmToolCalls setting (default destructive; undefined→destructive for pre-existing settings) via AIToolConfirmModal (approve / approve-all-this-run / deny / abort). - EITHER-tools-OR-responseFormat per turn; per-run reentrancy guard. 
--- src/ai/tools/Agent.test.ts | 216 +++++++++ src/ai/tools/Agent.ts | 512 ++++++++++++++++++++++ src/ai/tools/builtins/builtins.test.ts | 119 +++++ src/ai/tools/builtins/shared.ts | 42 ++ src/ai/tools/builtins/systemTools.ts | 30 ++ src/ai/tools/builtins/vaultTools.ts | 301 +++++++++++++ src/ai/tools/builtins/vaultWriteGuards.ts | 90 ++++ src/ai/tools/builtins/workspaceTools.ts | 49 +++ src/gui/AIAssistantSettingsModal.ts | 19 + src/gui/AIToolConfirmModal.ts | 118 +++++ src/quickAddApi.ts | 69 ++- src/settings.ts | 10 + 12 files changed, 1571 insertions(+), 4 deletions(-) create mode 100644 src/ai/tools/Agent.test.ts create mode 100644 src/ai/tools/Agent.ts create mode 100644 src/ai/tools/builtins/builtins.test.ts create mode 100644 src/ai/tools/builtins/shared.ts create mode 100644 src/ai/tools/builtins/systemTools.ts create mode 100644 src/ai/tools/builtins/vaultTools.ts create mode 100644 src/ai/tools/builtins/vaultWriteGuards.ts create mode 100644 src/ai/tools/builtins/workspaceTools.ts create mode 100644 src/gui/AIToolConfirmModal.ts diff --git a/src/ai/tools/Agent.test.ts b/src/ai/tools/Agent.test.ts new file mode 100644 index 00000000..d327ccb3 --- /dev/null +++ b/src/ai/tools/Agent.test.ts @@ -0,0 +1,216 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import type { CommonResponse } from "../OpenAIRequest"; + +// --- Mock the Obsidian-coupled dependencies the Agent reaches for --- +const chatRequestMock = vi.fn<(...args: unknown[]) => Promise>(); +vi.mock("../OpenAIRequest", () => ({ + chatRequest: (...args: unknown[]) => chatRequestMock(...args), + anthropicMaxTokens: () => 8192, +})); + +const confirmMock = vi.fn<() => Promise>(async () => "allow"); +vi.mock("../../gui/AIToolConfirmModal", () => ({ + default: { Prompt: () => confirmMock() }, +})); + +vi.mock("../../formatters/completeFormatter", () => ({ + CompleteFormatter: class { + async formatFileContent(input: string) { + return input; // identity — formatting is exercised 
elsewhere + } + }, +})); + +vi.mock("../aiHelpers", () => ({ + getModelByName: (name: string) => ({ name, maxTokens: 128000 }), + getModelProvider: () => ({ name: "OpenAI", kind: "openai", endpoint: "https://x" }), +})); + +vi.mock("../providerSecrets", () => ({ resolveProviderApiKey: async () => "key" })); + +vi.mock("../preventCursorChange", () => ({ preventCursorChange: () => () => {} })); + +let mockSettings: Record; +vi.mock("../../settingsStore", () => ({ + settingsStore: { getState: () => mockSettings }, +})); + +import { Agent } from "./Agent"; +import type { AgentConfig } from "./aiToolTypes"; + +function makeAgent(config: Partial = {}, vars = new Map()) { + const choiceExecutor = { variables: vars } as never; + return new Agent( + {} as never, + {} as never, + choiceExecutor, + { model: "gpt-4o", ...config } as AgentConfig, + ); +} + +function turnResponse(p: Partial): CommonResponse { + return { + id: "r", + model: "gpt-4o", + content: p.content ?? "", + usage: p.usage ?? { promptTokens: 1, completionTokens: 1, totalTokens: 2 }, + stopReason: p.stopReason ?? "", + stopSequence: null, + created: 0, + toolCalls: p.toolCalls, + normalizedStopReason: p.normalizedStopReason ?? (p.toolCalls?.length ? 
"tool_calls" : "stop"), + }; +} + +function tool(extra: Record = {}) { + return { + __qaTool: true as const, + description: "create a note", + inputSchema: { type: "object" as const, properties: { path: { type: "string" as const } }, required: ["path"] }, + execute: vi.fn(async () => "created"), + ...extra, + }; +} + +beforeEach(() => { + chatRequestMock.mockReset(); + confirmMock.mockReset(); + confirmMock.mockResolvedValue("allow"); + mockSettings = { + disableOnlineFeatures: false, + ai: { confirmToolCalls: "never", defaultSystemPrompt: "default sys" }, + }; +}); + +describe("Agent.generate — tool loop", () => { + it("runs a tool call then returns final text, mapping to the public shape", async () => { + const t = tool(); + chatRequestMock + .mockResolvedValueOnce( + turnResponse({ toolCalls: [{ id: "c1", name: "create_note", args: { path: "a.md" } }] }), + ) + .mockResolvedValueOnce(turnResponse({ content: "done", normalizedStopReason: "stop" })); + + const agent = makeAgent({ tools: { create_note: t } }); + const res = await agent.generate({ prompt: "make a note" }); + + expect(t.execute).toHaveBeenCalledOnce(); + expect(res.text).toBe("done"); + // public 'output' naming on results + const firstStep = res.steps[0]; + expect(firstStep.toolResults[0]).toMatchObject({ toolName: "create_note", output: "created", isError: false }); + expect(res.usage).toMatchObject({ inputTokens: expect.any(Number), outputTokens: expect.any(Number) }); + }); + + it("denies a tool when the confirm modal returns deny → isError result, loop continues", async () => { + confirmMock.mockResolvedValue("deny"); + mockSettings.ai = { confirmToolCalls: "always", defaultSystemPrompt: "" }; + const t = tool(); + chatRequestMock + .mockResolvedValueOnce(turnResponse({ toolCalls: [{ id: "c1", name: "create_note", args: { path: "a.md" } }] })) + .mockResolvedValueOnce(turnResponse({ content: "ok", normalizedStopReason: "stop" })); + + const agent = makeAgent({ tools: { create_note: t } }); + 
const res = await agent.generate({ prompt: "x" }); + expect(t.execute).not.toHaveBeenCalled(); + expect(res.steps[0].toolResults[0]).toMatchObject({ isError: true }); + expect(res.steps[0].toolResults[0].output).toMatch(/denied/); + }); + + it("aborts the whole run when the confirm modal returns abort", async () => { + confirmMock.mockResolvedValue("abort"); + mockSettings.ai = { confirmToolCalls: "always", defaultSystemPrompt: "" }; + const t = tool(); + chatRequestMock.mockResolvedValueOnce( + turnResponse({ toolCalls: [{ id: "c1", name: "create_note", args: { path: "a.md" } }] }), + ); + const agent = makeAgent({ tools: { create_note: t } }); + await expect(agent.generate({ prompt: "x" })).rejects.toThrow(/aborted/i); + }); + + it("does not confirm a read-only tool under the 'destructive' setting", async () => { + mockSettings.ai = { confirmToolCalls: "destructive", defaultSystemPrompt: "" }; + const t = tool({ readOnly: true }); + chatRequestMock + .mockResolvedValueOnce(turnResponse({ toolCalls: [{ id: "c1", name: "read", args: { path: "a.md" } }] })) + .mockResolvedValueOnce(turnResponse({ content: "ok", normalizedStopReason: "stop" })); + const agent = makeAgent({ tools: { read: t } }); + await agent.generate({ prompt: "x" }); + expect(confirmMock).not.toHaveBeenCalled(); + expect(t.execute).toHaveBeenCalledOnce(); + }); + + it("confirms a destructive tool when confirmToolCalls is undefined (old persisted settings default to 'destructive')", async () => { + // main.ts shallow-merges settings, so an existing user's pre-confirmToolCalls + // `ai` object leaves this undefined — it must NOT silently auto-run a writer. 
+ mockSettings.ai = { defaultSystemPrompt: "" } as typeof mockSettings.ai; + confirmMock.mockResolvedValue("allow"); + const t = tool({ readOnly: false }); + chatRequestMock + .mockResolvedValueOnce(turnResponse({ toolCalls: [{ id: "c1", name: "create_note", args: { path: "a.md" } }] })) + .mockResolvedValueOnce(turnResponse({ content: "ok", normalizedStopReason: "stop" })); + const agent = makeAgent({ tools: { create_note: t } }); + await agent.generate({ prompt: "x" }); + expect(confirmMock).toHaveBeenCalledOnce(); + }); + + it("assigns result text to a variable (+ -quoted) when assignToVariable is set", async () => { + const vars = new Map(); + chatRequestMock.mockResolvedValueOnce(turnResponse({ content: "the answer", normalizedStopReason: "stop" })); + const agent = makeAgent({}, vars); + await agent.generate({ prompt: "q", assignToVariable: "summary" }); + expect(vars.get("summary")).toBe("the answer"); + expect(vars.get("summary-quoted")).toBe("> the answer"); + }); + + it("trims assignToVariable so a padded name still resolves as {{VALUE:name}}", async () => { + const vars = new Map(); + chatRequestMock.mockResolvedValueOnce(turnResponse({ content: "trimmed", normalizedStopReason: "stop" })); + const agent = makeAgent({}, vars); + await agent.generate({ prompt: "q", assignToVariable: " summary " }); + expect(vars.get("summary")).toBe("trimmed"); + expect(vars.get("summary-quoted")).toBe("> trimmed"); + }); + + it("rejects a reserved assignToVariable name", async () => { + const agent = makeAgent(); + await expect(agent.generate({ prompt: "q", assignToVariable: "value" })).rejects.toThrow(/reserved/); + }); +}); + +describe("Agent.generate — structured output", () => { + it("parses + validates the final turn into result.object", async () => { + chatRequestMock.mockResolvedValueOnce( + turnResponse({ content: '{"title":"Hi","tags":["a"]}', normalizedStopReason: "stop" }), + ); + const agent = makeAgent(); + const res = await agent.generate({ + prompt: 
"extract", + schema: { type: "object", properties: { title: { type: "string" }, tags: { type: "array", items: { type: "string" } } }, required: ["title"] }, + }); + expect(res.object).toEqual({ title: "Hi", tags: ["a"] }); + }); + + it("repairs once when the first structured reply is invalid", async () => { + chatRequestMock + .mockResolvedValueOnce(turnResponse({ content: "not json at all", normalizedStopReason: "stop" })) + .mockResolvedValueOnce(turnResponse({ content: '{"title":"Fixed"}', normalizedStopReason: "stop" })); + const agent = makeAgent(); + const res = await agent.generate({ + prompt: "extract", + schema: { type: "object", properties: { title: { type: "string" } }, required: ["title"] }, + }); + expect(res.object).toEqual({ title: "Fixed" }); + expect(chatRequestMock).toHaveBeenCalledTimes(2); + }); +}); + +describe("Agent construction validation", () => { + it("rejects an invalid tool name", () => { + expect(() => makeAgent({ tools: { "bad name!": tool() } })).toThrow(/Invalid tool name/); + }); + it("rejects an unsupported schema keyword at registration", () => { + const t = tool({ inputSchema: { type: "string", pattern: "^x" } }); + expect(() => makeAgent({ tools: { t } })).toThrow(/unsupported/i); + }); +}); diff --git a/src/ai/tools/Agent.ts b/src/ai/tools/Agent.ts new file mode 100644 index 00000000..2432f66f --- /dev/null +++ b/src/ai/tools/Agent.ts @@ -0,0 +1,512 @@ +/** + * The AI Agent (#714) — the Obsidian-bound glue around the pure tool loop. + * + * Constructed by `quickAddApi.ai.agent(config)`. Holds model/system/tools/budget + * config; `generate({ prompt })` runs the bounded multi-step tool loop and + * `generate({ prompt, schema })` adds JSON-schema-constrained structured output. + * Stateless across calls (reuse = reuse the config); a per-run guard prevents + * overlapping calls from racing the shared variables map / cursor. 
+ */ +import type { App } from "obsidian"; +import type QuickAdd from "../../main"; +import type { IChoiceExecutor } from "../../IChoiceExecutor"; +import { settingsStore } from "../../settingsStore"; +import { CompleteFormatter } from "../../formatters/completeFormatter"; +import { MacroAbortError } from "../../errors/MacroAbortError"; +import { getModelByName, getModelProvider } from "../aiHelpers"; +import type { Model } from "../Provider"; +import { resolveProviderApiKey } from "../providerSecrets"; +import { classifyProviderError } from "../providerErrors"; +import { preventCursorChange } from "../preventCursorChange"; +import { + chatRequest, + type CommonResponse, +} from "../OpenAIRequest"; +import type { + JSONSchema, + NormalizedChatRequest, + NormalizedMessage, + NormalizedToolChoice, + NormalizedToolDefinition, +} from "./NormalizedTools"; +import type { ParsedChatResult } from "./providerToolMapping"; +import { + runToolLoop, + type LoopStep, + type RunToolLoopResult, + type ToolEntry, +} from "./runToolLoop"; +import { assertRegisterableSchema, validateValue } from "./jsonSchemaValidator"; +import { assertAssignableVariableName } from "./assignableVariable"; +import AIToolConfirmModal from "../../gui/AIToolConfirmModal"; +import type { + AgentConfig, + GenerateOptions, + GenerateResult, + GenerateStep, + PublicToolCall, + PublicToolResult, + QATool, +} from "./aiToolTypes"; + +const TOOL_NAME_RE = /^[a-zA-Z0-9_-]{1,64}$/; +const DEFAULT_MAX_STEPS = 20; +const MAX_STEPS_CEILING = 100; + +// Cross-instance guard: only one assignToVariable run at a time per choiceExecutor +// (they would race the shared variables map). Read-only runs are unaffected. +const statefulRunsInFlight = new WeakSet(); + +function isAbortError(e: unknown): boolean { + return e instanceof MacroAbortError; +} + +function toParsed(cr: CommonResponse): ParsedChatResult { + return { + content: cr.content, + toolCalls: cr.toolCalls ?? 
[], + normalizedStopReason: cr.normalizedStopReason ?? "stop", + rawStopReason: cr.stopReason, + usage: cr.usage, + providerRaw: cr.providerRaw, + }; +} + +function toPublicToolChoice( + choice: AgentConfig["toolChoice"] | GenerateOptions["toolChoice"], +): NormalizedToolChoice | undefined { + if (!choice) return undefined; + if (typeof choice === "string") return choice; + return { name: choice.toolName }; +} + +export class Agent { + private running = false; + + constructor( + private readonly app: App, + private readonly plugin: QuickAdd, + private readonly choiceExecutor: IChoiceExecutor, + private readonly config: AgentConfig, + ) { + // Validate the tool set once, at construction, so authoring errors fail fast. + const seen = new Set(); + for (const [name, tool] of Object.entries(config.tools ?? {})) { + if (!TOOL_NAME_RE.test(name)) { + throw new Error( + `Invalid tool name "${name}". Tool names must match ${TOOL_NAME_RE}.`, + ); + } + if (seen.has(name)) throw new Error(`Duplicate tool name "${name}".`); + seen.add(name); + assertRegisterableSchema(tool.inputSchema, `tool "${name}".inputSchema`); + } + } + + async generate(options: GenerateOptions = {}): Promise { + if (this.running) { + throw new Error( + "This agent is already running a generate() call. Await it, or create a separate agent for concurrent runs.", + ); + } + const usesVariables = + typeof options.assignToVariable === "string" && + options.assignToVariable.length > 0; + if (usesVariables) { + this.assertAssignableName(options.assignToVariable as string); + if (statefulRunsInFlight.has(this.choiceExecutor)) { + throw new Error( + "Another AI run with assignToVariable is already in flight in this context — they would race the shared variables. 
Run them sequentially.", + ); + } + } + + this.running = true; + if (usesVariables) statefulRunsInFlight.add(this.choiceExecutor); + const restoreCursor = preventCursorChange(this.app); + try { + return await this.run(options, restoreCursor); + } finally { + this.running = false; + if (usesVariables) statefulRunsInFlight.delete(this.choiceExecutor); + try { + restoreCursor(); + } catch { + /* editor may be gone; ignore */ + } + } + } + + private async run( + options: GenerateOptions, + restoreCursor: () => void, + ): Promise { + // Per-run state — "approve all this run" must NOT leak into a later generate() + // on a reused agent. + this.approveAllThisRun = false; + const pluginSettings = settingsStore.getState(); + if (pluginSettings.disableOnlineFeatures) { + throw new Error( + "Rejecting AI request: Online features are disabled in settings.", + ); + } + + const modelName = + typeof this.config.model === "string" + ? this.config.model + : this.config.model?.name; + if (!modelName) throw new Error("ai.agent requires a model name."); + const model = getModelByName(modelName); + if (!model) { + throw new Error( + `Model '${modelName}' not found in configured providers. Add it in Settings → QuickAdd → AI → Providers.`, + ); + } + const modelProvider = getModelProvider(model.name); + if (!modelProvider) { + throw new Error(`No provider configured for model '${model.name}'.`); + } + const apiKey = await resolveProviderApiKey(this.app, modelProvider); + + // Build the seed request: [system?, user(formatted prompt)]. + const messages = await this.buildSeedMessages(options); + const tools = this.buildToolDefinitions(); + const registry = this.buildRegistry(); + const responseFormat = options.schema + ? 
{ schema: options.schema, name: "response", strict: true } + : undefined; + + // The seed request carries tools but NOT responseFormat — schema is attached + // per-turn (only on turns that send no tools), so a combined tools+schema run + // does not collapse (OpenAI suppresses tool calls when response_format is set). + const request: NormalizedChatRequest = { + messages, + ...(tools.length > 0 ? { tools } : {}), + toolChoice: + toPublicToolChoice(options.toolChoice ?? this.config.toolChoice) ?? "auto", + maxOutputTokens: options.maxOutputTokens ?? this.config.maxOutputTokens, + modelParams: this.config.modelOptions, + }; + + const maxSteps = Math.max( + 1, + Math.min(this.config.maxSteps ?? DEFAULT_MAX_STEPS, MAX_STEPS_CEILING), + ); + + const loop = await runToolLoop({ + request, + maxSteps, + dispatch: async (req, ctx) => { + const turnReq = this.buildTurnRequest(req, ctx.isFinalStep, responseFormat); + const cr = await chatRequest( + this.app, + apiKey, + model, + turnReq, + restoreCursor, + ); + return toParsed(cr); + }, + getTool: (name) => registry.get(name), + confirm: (call, tool) => this.confirm(call, tool), + validateArgs: (tool, args) => + validateValue(args, tool.definition.parameters), + isAbortError, + isContextOverflow: (e) => classifyProviderError(e) === "input_context", + isOnlineDisabled: () => settingsStore.getState().disableOnlineFeatures, + shouldStop: this.buildShouldStop(), + }); + + const result = this.toGenerateResult(loop); + + if (options.schema) { + result.object = await this.resolveStructuredObject( + loop, + request, + model, + apiKey, + options.schema, + restoreCursor, + ); + } + + if (typeof options.assignToVariable === "string" && options.assignToVariable) { + this.assignVariable(options.assignToVariable, result.text); + } + + return result; + } + + // --- request assembly ------------------------------------------------- + + private async buildSeedMessages( + options: GenerateOptions, + ): Promise { + const messages: 
NormalizedMessage[] = []; + const system = options.system ?? this.config.system ?? settingsStore.getState().ai.defaultSystemPrompt; + if (system && system.trim().length > 0) { + messages.push({ role: "system", content: system }); + } + const prompt = options.prompt ?? ""; + // The prompt is run through the QuickAdd formatter (like ai.prompt). Tool + // args, by contrast, are NEVER formatted (that would re-inject {{...}}). + const formatted = prompt + ? await new CompleteFormatter( + this.app, + this.plugin, + this.choiceExecutor, + ).formatFileContent(prompt) + : ""; + messages.push({ role: "user", content: formatted }); + return messages; + } + + private buildToolDefinitions(): NormalizedToolDefinition[] { + return Object.entries(this.config.tools ?? {}).map(([name, tool]) => ({ + name, + description: tool.description, + parameters: tool.inputSchema, + ...(tool.strict ? { strict: true } : {}), + })); + } + + private buildRegistry(): Map { + const registry = new Map(); + for (const [name, tool] of Object.entries(this.config.tools ?? {})) { + registry.set(name, { + definition: { + name, + description: tool.description, + parameters: tool.inputSchema, + strict: tool.strict, + }, + execute: (args, ctx) => (tool as QATool).execute(args, ctx), + readOnly: tool.readOnly, + }); + } + return registry; + } + + private buildShouldStop(): + | ((info: { steps: LoopStep[] }) => boolean) + | undefined { + const conditions = this.config.stopWhen + ? Array.isArray(this.config.stopWhen) + ? this.config.stopWhen + : [this.config.stopWhen] + : []; + if (conditions.length === 0) return undefined; + return ({ steps }) => { + const toolCallNames = steps.flatMap((s) => + s.toolCalls.map((c) => c.name), + ); + return conditions.some((c) => + c({ stepNumber: steps.length, toolCallNames }), + ); + }; + } + + /** + * Build the per-turn request. 
The invariant: a turn sends EITHER tools OR a + * responseFormat, never both — this avoids OpenAI suppressing tool calls when a + * schema is set, and sidesteps Gemini-1.5's tools+responseSchema 400. So: + * - a tool-gathering turn (has tools, not final) → tools, no schema; + * - a no-tools turn (final step, or a schema-only run) → no tools, schema attached. + */ + private buildTurnRequest( + req: NormalizedChatRequest, + isFinalStep: boolean, + responseFormat: NormalizedChatRequest["responseFormat"], + ): NormalizedChatRequest { + const sendTools = !!(req.tools && req.tools.length > 0) && !isFinalStep; + if (sendTools) { + return { ...req, responseFormat: undefined }; + } + return { + ...req, + tools: undefined, + toolChoice: isFinalStep ? "none" : req.toolChoice, + responseFormat, + }; + } + + // --- confirmation ----------------------------------------------------- + + private approveAllThisRun = false; + + private async confirm( + call: { id: string; name: string; args: Record | null }, + tool: ToolEntry, + ): Promise { + const needs = await this.needsConfirmation(tool, call.args ?? {}); + if (!needs) return true; + if (this.approveAllThisRun) return true; + + const outcome = await AIToolConfirmModal.Prompt( + this.app, + call.name, + call.args ?? {}, + ); + if (outcome === "abort") { + throw new MacroAbortError("AI tool run aborted by user"); + } + if (outcome === "allow-all") { + this.approveAllThisRun = true; + return true; + } + return outcome === "allow"; + } + + private async needsConfirmation( + tool: ToolEntry, + args: Record, + ): Promise { + const declared = this.config.tools?.[tool.definition.name]; + const perTool = declared?.needsApproval; + const resolved = + typeof perTool === "function" + ? 
await perTool({ args }) + : perTool === true; + if (resolved) return true; // a tool's own approval is the floor — always confirm + + // Default to "destructive" when the persisted value is missing: main.ts + // shallow-merges loadedData, so an existing user's pre-`confirmToolCalls` + // `ai` object replaces DEFAULT_SETTINGS.ai wholesale and leaves this + // undefined. Without this floor a destructive author tool would auto-run. + const global = settingsStore.getState().ai.confirmToolCalls ?? "destructive"; + if (global === "always") return true; + if (global === "destructive") return tool.readOnly !== true; + return false; // 'never' + } + + // --- structured output ------------------------------------------------ + + private async resolveStructuredObject( + loop: RunToolLoopResult, + request: NormalizedChatRequest, + model: Model, + apiKey: string, + schema: JSONSchema, + restoreCursor: () => void, + ): Promise { + const first = parseStructured(loop.finalTurn.content, schema); + if (first.ok) return first.value; + + // One bounded repair re-ask: a fresh stricter request (never a format() pass). + // No tools + responseFormat is the valid no-tools shape for every provider. + // loop.messages omits the final assistant turn, so include the bad reply + // explicitly — the "your previous reply" framing needs it in context. + const repairMessages: NormalizedMessage[] = [ + ...loop.messages, + { role: "assistant", content: loop.finalTurn.content }, + { + role: "user", + content: `Your previous reply was not valid JSON matching the required schema (${first.error}). 
Reply with ONLY a JSON object that matches the schema — no prose, no code fences.`, + }, + ]; + const repairReq: NormalizedChatRequest = { + ...request, + messages: repairMessages, + tools: undefined, + toolChoice: "none", + responseFormat: { schema, name: "response", strict: true }, + }; + try { + const cr = await chatRequest( + this.app, + apiKey, + model, + repairReq, + restoreCursor, + ); + const second = parseStructured(cr.content, schema); + return second.ok ? second.value : undefined; + } catch { + return undefined; + } + } + + // --- result mapping --------------------------------------------------- + + private toGenerateResult(loop: RunToolLoopResult): GenerateResult { + const steps: GenerateStep[] = loop.steps.map((s) => ({ + stepNumber: s.stepNumber, + text: s.text, + toolCalls: s.toolCalls.map(toPublicCall), + toolResults: s.toolResults.map(toPublicResult), + finishReason: s.finishReason, + })); + const last = steps[steps.length - 1]; + return { + text: loop.text, + steps, + toolCalls: last?.toolCalls ?? [], + toolResults: last?.toolResults ?? [], + usage: { + inputTokens: loop.usage.promptTokens, + outputTokens: loop.usage.completionTokens, + totalTokens: loop.usage.totalTokens, + }, + finishReason: loop.finishReason, + }; + } + + // --- variable bridge -------------------------------------------------- + + private assignVariable(name: string, text: string): void { + // Write the trimmed key: assertAssignableName validates the trimmed form, + // and the formatter trims `{{VALUE:...}}` token names — an untrimmed key + // (" summary ") would never resolve. 
+ const key = name.trim(); + const quoted = ("> " + text).replace(/\n/g, "\n> "); + this.choiceExecutor.variables.set(key, text); + this.choiceExecutor.variables.set(`${key}-quoted`, quoted); + } + + private assertAssignableName(name: string): void { + assertAssignableVariableName(name); + } +} + +function toPublicCall(c: { + id: string; + name: string; + args: Record | null; +}): PublicToolCall { + return { toolCallId: c.id, toolName: c.name, input: c.args }; +} + +function toPublicResult(r: { + toolCallId: string; + name: string; + content: string; + isError?: boolean; +}): PublicToolResult { + return { + toolCallId: r.toolCallId, + toolName: r.name, + output: r.content, + isError: r.isError ?? false, + }; +} + +function parseStructured( + text: string, + schema: JSONSchema, +): { ok: true; value: unknown } | { ok: false; error: string } { + let parsed: unknown; + try { + parsed = JSON.parse(stripCodeFences(text)); + } catch (e) { + return { ok: false, error: `not valid JSON (${e instanceof Error ? e.message : String(e)})` }; + } + const err = validateValue(parsed, schema); + if (err) return { ok: false, error: err }; + return { ok: true, value: parsed }; +} + +function stripCodeFences(text: string): string { + const trimmed = text.trim(); + const fence = trimmed.match(/^```(?:json)?\s*([\s\S]*?)\s*```$/i); + return fence ? 
fence[1] : trimmed; +} diff --git a/src/ai/tools/builtins/builtins.test.ts b/src/ai/tools/builtins/builtins.test.ts new file mode 100644 index 00000000..2bb29244 --- /dev/null +++ b/src/ai/tools/builtins/builtins.test.ts @@ -0,0 +1,119 @@ +import { describe, it, expect, vi } from "vitest"; +import { TFile } from "obsidian"; +import { applyGroupOptions, defineTool } from "./shared"; +import { createVaultTools } from "./vaultTools"; +import { createWorkspaceTools } from "./workspaceTools"; +import { createSystemTools } from "./systemTools"; +import { UnsafeVaultPathError } from "../sanitizeVaultPath"; +import type { App } from "obsidian"; + +function fileLike(path: string, basename = path): TFile { + const f = Object.create(TFile.prototype) as TFile; + Object.assign(f, { path, basename }); + return f; +} + +function makeApp(over: Record = {}): App { + return { + vault: { + adapter: {}, // not a FileSystemAdapter → symlink guard is a no-op in tests + getAbstractFileByPath: vi.fn(() => null), + getMarkdownFiles: vi.fn(() => []), + cachedRead: vi.fn(async () => ""), + read: vi.fn(async () => ""), + create: vi.fn(async (p: string) => fileLike(p)), + modify: vi.fn(async () => undefined), + createFolder: vi.fn(async () => undefined), + ...((over.vault as object) ?? {}), + }, + metadataCache: { getFileCache: vi.fn(() => null), ...((over.metadataCache as object) ?? {}) }, + workspace: { getActiveFile: vi.fn(() => null), getActiveViewOfType: vi.fn(() => null), ...((over.workspace as object) ?? 
{}) }, + } as unknown as App; +} + +describe("applyGroupOptions", () => { + const set = { + a: defineTool({ description: "a", inputSchema: { type: "object" }, execute: async () => 1 }), + b: defineTool({ description: "b", inputSchema: { type: "object" }, execute: async () => 2 }), + }; + it("only / exclude / prefix", () => { + expect(Object.keys(applyGroupOptions(set, { only: ["a"] }))).toEqual(["a"]); + expect(Object.keys(applyGroupOptions(set, { exclude: ["a"] }))).toEqual(["b"]); + expect(Object.keys(applyGroupOptions(set, { prefix: "qa_" }))).toEqual(["qa_a", "qa_b"]); + }); + it("prefix keeps each tool's needsApproval/readOnly intact", () => { + const writer = { w: defineTool({ description: "w", inputSchema: { type: "object" }, needsApproval: true, execute: async () => 1 }) }; + const renamed = applyGroupOptions(writer, { prefix: "p_" }); + expect(renamed.p_w.needsApproval).toBe(true); + }); +}); + +describe("vault tools — classification + schemas", () => { + const tools = createVaultTools(makeApp()); + it("read tools are readOnly, write tools need approval", () => { + for (const r of ["read_note", "list_notes", "search_notes", "get_property_values"]) { + expect(tools[r].readOnly).toBe(true); + } + for (const w of ["create_note", "append_to_note", "insert_under_heading"]) { + expect(tools[w].needsApproval).toBe(true); + } + }); + it("does NOT ship the deferred high-risk tools", () => { + for (const deferred of ["run_choice", "apply_template", "set_frontmatter_property", "delete_note"]) { + expect(tools[deferred]).toBeUndefined(); + } + }); +}); + +describe("vault write tools — safety", () => { + it("create_note rejects an unsafe (config-dir) path before touching the vault", async () => { + const app = makeApp(); + const tools = createVaultTools(app); + await expect( + tools.create_note.execute({ path: "Notes/.obsidian/evil/main.js" }, { toolCallId: "c", toolName: "create_note" }), + ).rejects.toBeInstanceOf(UnsafeVaultPathError); + expect((app.vault.create as 
ReturnType)).not.toHaveBeenCalled(); + }); + + it("create_note ensures .md and calls vault.create (fail-on-exist via the API)", async () => { + const create = vi.fn(async (p: string) => fileLike(p)); + const app = makeApp({ vault: { create, getAbstractFileByPath: () => fileLike("Notes") } }); + const tools = createVaultTools(app); + const res = (await tools.create_note.execute({ path: "Notes/New", content: "hi" }, { toolCallId: "c", toolName: "create_note" })) as { created: boolean; path: string }; + expect(create).toHaveBeenCalledWith("Notes/New.md", "hi"); + expect(res).toMatchObject({ created: true, path: "Notes/New.md" }); + }); + + it("append_to_note errors when the note does not exist", async () => { + const app = makeApp({ vault: { getAbstractFileByPath: () => null } }); + const tools = createVaultTools(app); + await expect( + tools.append_to_note.execute({ path: "Missing.md", content: "x" }, { toolCallId: "c", toolName: "append_to_note" }), + ).rejects.toThrow(/not found/i); + }); + + it("respects allowedRoots for reads", async () => { + const app = makeApp(); + const tools = createVaultTools(app, { allowedRoots: ["AI"] }); + await expect( + tools.read_note.execute({ path: "Secret/passwords.md" }, { toolCallId: "c", toolName: "read_note" }), + ).rejects.toBeInstanceOf(UnsafeVaultPathError); + }); +}); + +describe("workspace + system tools", () => { + it("get_selection returns empty string when no active editor", async () => { + const tools = createWorkspaceTools(makeApp()); + expect(await tools.get_selection.execute({}, { toolCallId: "c", toolName: "get_selection" })).toEqual({ selection: "" }); + }); + it("get_active_note returns active:null when there is no active file", async () => { + const tools = createWorkspaceTools(makeApp()); + expect(await tools.get_active_note.execute({}, { toolCallId: "c", toolName: "get_active_note" })).toEqual({ active: null }); + }); + it("get_date returns a date string", async () => { + const tools = createSystemTools(); + 
/**
 * Narrow and/or rename a built tool map according to the group options.
 *
 * A tool is kept when it is listed in `only` (if `only` is given) and is not
 * listed in `exclude` (if `exclude` is given). `prefix` is prepended to the map
 * KEY only; the tool object itself is passed through untouched, so its
 * needsApproval/readOnly settings survive renaming.
 */
export function applyGroupOptions(
	tools: ToolSetMap,
	options: BuiltinGroupOptions,
): ToolSetMap {
	const allow = options.only ? new Set(options.only) : null;
	const deny = options.exclude ? new Set(options.exclude) : null;
	const keyPrefix = options.prefix ?? "";

	const kept = Object.entries(tools).filter(([name]) => {
		if (allow !== null && !allow.has(name)) return false;
		if (deny !== null && deny.has(name)) return false;
		return true;
	});

	return Object.fromEntries(
		kept.map(([name, t]) => [keyPrefix + name, t]),
	) as ToolSetMap;
}
+ the + * runtime symlink/realpath guard, are existence-aware (create fails on exist; append/ + * insert require the file), and are frontmatter-aware (insertAtNoteBodyStart). The + * high-risk trio run_choice/apply_template/set_frontmatter is deliberately NOT shipped. + */ +import { type App, TFile } from "obsidian"; +import { getMarkdownFilesInFolder } from "../../../utilityObsidian"; +import { insertAtNoteBodyStart } from "../../../utils/noteContentInsertion"; +import { sanitizeVaultPath } from "../sanitizeVaultPath"; +import { assertWriteStaysInVault } from "./vaultWriteGuards"; +import type { JSONSchema } from "../NormalizedTools"; +import type { QATool } from "../aiToolTypes"; +import { applyGroupOptions, type BuiltinGroupOptions, type ToolSetMap } from "./shared"; + +const MAX_READ_CHARS = 16_000; +const DEFAULT_LIST = 100; +const MAX_LIST = 200; +const DEFAULT_SEARCH = 25; +const MAX_SEARCH = 50; +const MAX_SEARCH_FILE_SCAN = 400; + +export function createVaultTools( + app: App, + options: BuiltinGroupOptions = {}, +): ToolSetMap { + const roots = options.allowedRoots; + const withinRoots = (path: string) => { + try { + sanitizeVaultPath(path, { allowedRoots: roots }); + return true; + } catch { + return false; + } + }; + + const tools: ToolSetMap = { + read_note: tool({ + description: "Read a note's full markdown content by vault path.", + inputSchema: obj({ path: str("Vault path to the note, e.g. Notes/Foo.md") }, ["path"]), + readOnly: true, + execute: async ({ path }) => { + const norm = sanitizeVaultPath(String(path), { allowedRoots: roots }); + const file = app.vault.getAbstractFileByPath(norm); + if (!(file instanceof TFile)) return { found: false, path: norm }; + const content = await app.vault.cachedRead(file); + return { + found: true, + path: norm, + content: + content.length > MAX_READ_CHARS + ? 
content.slice(0, MAX_READ_CHARS) + "\n…[truncated]" + : content, + }; + }, + }), + + list_notes: tool({ + description: "List markdown notes, optionally within a folder.", + inputSchema: obj( + { + folder: str("Folder to list (omit for the whole vault)."), + limit: int("Max results (default 100, capped at 200)."), + }, + [], + ), + readOnly: true, + execute: async ({ folder, limit }) => { + const base = folder ? sanitizeVaultPath(String(folder), { allowedRoots: roots }) : ""; + const cap = clamp(toInt(limit, DEFAULT_LIST), 1, MAX_LIST); + const files = getMarkdownFilesInFolder(app, base).filter((f) => + roots ? withinRoots(f.path) : true, + ); + return { + total: files.length, + notes: files.slice(0, cap).map((f) => ({ path: f.path, basename: f.basename })), + }; + }, + }), + + search_notes: tool({ + description: + "Search notes by name and/or content. Returns matching paths with a short snippet.", + inputSchema: obj( + { + query: str("Text to search for."), + in: enumStr(["name", "content", "both"], "Where to search (default both)."), + limit: int("Max results (default 25, capped at 50)."), + }, + ["query"], + ), + readOnly: true, + execute: async ({ query, in: where, limit }) => { + const q = String(query).toLowerCase(); + const scope = (where as string) ?? "both"; + const cap = clamp(toInt(limit, DEFAULT_SEARCH), 1, MAX_SEARCH); + const files = app.vault + .getMarkdownFiles() + .filter((f) => (roots ? 
withinRoots(f.path) : true)); + const results: Array<{ path: string; snippet?: string }> = []; + let scanned = 0; + for (const f of files) { + if (results.length >= cap) break; + const nameHit = scope !== "content" && f.basename.toLowerCase().includes(q); + if (nameHit) { + results.push({ path: f.path }); + continue; + } + if (scope !== "name" && scanned < MAX_SEARCH_FILE_SCAN) { + scanned++; + const text = await app.vault.cachedRead(f); + const idx = text.toLowerCase().indexOf(q); + if (idx >= 0) { + results.push({ path: f.path, snippet: snippetAround(text, idx) }); + } + } + } + return { results, scannedFiles: scanned, truncated: results.length >= cap }; + }, + }), + + get_property_values: tool({ + description: + "List the distinct values a frontmatter property takes across the vault (e.g. existing tags or statuses).", + inputSchema: obj( + { field: str("Frontmatter property name."), folder: str("Optional folder scope.") }, + ["field"], + ), + readOnly: true, + execute: async ({ field, folder }) => { + const base = folder ? sanitizeVaultPath(String(folder), { allowedRoots: roots }) : ""; + const files = ( + folder ? getMarkdownFilesInFolder(app, base) : app.vault.getMarkdownFiles() + ).filter((f) => (roots ? withinRoots(f.path) : true)); + const values = new Set(); + for (const f of files) { + const fm = app.metadataCache.getFileCache(f)?.frontmatter; + const v = fm?.[String(field)]; + if (v == null) continue; + for (const item of Array.isArray(v) ? v : [v]) { + if (typeof item !== "object") { + const s = String(item).trim(); + if (s) values.add(s); + } + } + } + return { field: String(field), values: [...values].sort().slice(0, 200) }; + }, + }), + + create_note: tool({ + description: "Create a new markdown note. 
Fails if the path already exists.", + inputSchema: obj( + { path: str("Vault path for the new note."), content: str("Initial markdown content.") }, + ["path"], + ), + needsApproval: true, + execute: async ({ path, content }) => { + const norm = sanitizeVaultPath(ensureMd(String(path)), { allowedRoots: roots }); + await assertWriteStaysInVault(app, norm); + await ensureParentFolder(app, norm); + const file = await app.vault.create(norm, String(content ?? "")); + return { created: true, path: file.path }; + }, + }), + + append_to_note: tool({ + description: "Append text to an existing note (frontmatter-aware for top inserts).", + inputSchema: obj( + { + path: str("Vault path to the existing note."), + content: str("Markdown to append."), + position: enumStr(["top", "bottom"], "Where to add it (default bottom)."), + }, + ["path", "content"], + ), + needsApproval: true, + execute: async ({ path, content, position }) => { + const norm = sanitizeVaultPath(String(path), { allowedRoots: roots }); + const file = requireFile(app, norm); + await assertWriteStaysInVault(app, norm); + const body = await app.vault.read(file); + const text = String(content); + const next = + position === "top" + ? insertAtNoteBodyStart(body, text.endsWith("\n") ? text : text + "\n") + : `${body}${body.endsWith("\n") || body.length === 0 ? "" : "\n"}${text}`; + await app.vault.modify(file, next); + return { appended: true, path: norm }; + }, + }), + + insert_under_heading: tool({ + description: + "Insert markdown under an existing heading (matched exactly). 
/**
 * Coerce a model-supplied argument to an integer, falling back when it is absent
 * or not numeric.
 *
 * @param v - Raw argument value (typically a number, a numeric string, or a
 *   missing/null optional).
 * @param fallback - Value returned when `v` carries no usable number.
 * @returns `Math.floor` of the numeric value, or `fallback`.
 */
function toInt(v: unknown, fallback: number): number {
	// Number(null), Number(false) and Number("") are all 0, which is finite. Without
	// these guards an omitted-as-null or blank optional would silently become 0
	// instead of the caller's default.
	if (v == null || typeof v === "boolean") return fallback;
	if (typeof v === "string" && v.trim() === "") return fallback;
	const n = Number(v);
	return Number.isFinite(n) ? Math.floor(n) : fallback;
}
/**
 * Insert `content` at the end of a heading's section.
 *
 * The section runs from the heading line up to (not including) the first entry in
 * `headings` that starts later and has the same or a shallower level, or to the end
 * of the note when there is none. The text is placed directly after the section's
 * last non-blank line. Blank lines that trailed the section are kept AFTER the
 * inserted text, so the separator before the next heading and the note's final
 * newline are preserved.
 *
 * @param body - Full note text.
 * @param headingLine - Zero-based line index of the target heading.
 * @param headingLevel - Level of the target heading (1 for `#`, 2 for `##`, ...).
 * @param content - Markdown to insert; one trailing newline, if present, is dropped.
 * @param headings - Headings of the note with their start line and level.
 * @returns The note text with `content` inserted.
 */
function insertUnderHeading(
	body: string,
	headingLine: number,
	headingLevel: number,
	content: string,
	headings: Array<{ position: { start: { line: number } }; level: number }>,
): string {
	const lines = body.split("\n");

	// First later heading at the same or shallower level closes the section.
	const nextHeading = headings.find(
		(h) => h.position.start.line > headingLine && h.level <= headingLevel,
	);
	const endLine = nextHeading ? nextHeading.position.start.line : lines.length;

	const section = lines.slice(0, endLine);
	const rest = lines.slice(endLine);

	// Locate the end of the section's real content; never walk past the heading itself.
	let contentEnd = section.length;
	while (
		contentEnd > headingLine + 1 &&
		(section[contentEnd - 1] ?? "").trim() === ""
	) {
		contentEnd--;
	}
	// Re-emitted after the insert rather than discarded.
	const trailingBlanks = section.slice(contentEnd);

	const inserted = content.replace(/\n$/, "");
	return [
		...section.slice(0, contentEnd),
		inserted,
		...trailingBlanks,
		...rest,
	].join("\n");
}
/**
 * Throw unless a write to `vaultRelativePath` would land inside the vault once
 * symlinks are resolved.
 *
 * Returns without checking when the adapter is not a FileSystemAdapter, or when
 * Node's `fs`/`path` cannot be loaded; in those cases the caller's string
 * sanitization is the only guard.
 *
 * @param app - Obsidian app whose vault adapter supplies the base and full paths.
 * @param vaultRelativePath - Already-sanitized vault-relative target path.
 * @throws VaultWriteEscapeError when the resolved target is outside the vault, or
 *   when the target's real location cannot be resolved at all.
 */
export async function assertWriteStaysInVault(
	app: App,
	vaultRelativePath: string,
): Promise<void> {
	const adapter = app.vault.adapter;
	// `typeof` guard: `instanceof` against an undefined symbol would throw, so make
	// sure FileSystemAdapter is a constructor before using it.
	if (typeof FileSystemAdapter !== "function" || !(adapter instanceof FileSystemAdapter)) {
		return;
	}

	const fs = nodeRequire<NodeFs>("fs");
	const path = nodeRequire<NodePath>("path");
	if (!fs?.promises || !path) return; // can't verify — fail open only when fs is unavailable

	const realBase = await fs.promises.realpath(adapter.getBasePath());
	const targetFull = adapter.getFullPath(vaultRelativePath);

	// Walk up to the nearest path that actually exists; the not-yet-created tail is
	// plain, already-sanitized names.
	let probe = targetFull;
	while (!exists(fs, probe)) {
		const parent = path.dirname(probe);
		if (parent === probe) break;
		probe = parent;
	}

	// `exists` uses lstat, so a dangling symlink counts as existing while realpath
	// rejects on it. Falling back to the unresolved path here would let that link
	// pass the containment check below, so an unresolvable target is refused.
	let realTarget: string;
	try {
		realTarget = await fs.promises.realpath(probe);
	} catch {
		throw new VaultWriteEscapeError(
			`Refusing to write to "${vaultRelativePath}": its real location could not be resolved, so it cannot be verified to stay inside the vault.`,
		);
	}

	const rel = path.relative(realBase, realTarget);
	const escapes = rel === ".." || rel.startsWith(`..${path.sep}`) || path.isAbsolute(rel);
	if (escapes) {
		throw new VaultWriteEscapeError(
			`Refusing to write to "${vaultRelativePath}": it resolves (via a symlink) outside the vault.`,
		);
	}
}
+ */ +import { type App, MarkdownView, TFile } from "obsidian"; +import { applyGroupOptions, defineTool, type BuiltinGroupOptions, type ToolSetMap } from "./shared"; + +const MAX_ACTIVE_CHARS = 16_000; + +export function createWorkspaceTools( + app: App, + options: BuiltinGroupOptions = {}, +): ToolSetMap { + const tools: ToolSetMap = { + get_active_note: defineTool({ + description: + "Get the currently active note (path, basename, and content). Returns active:null if there is none.", + inputSchema: { type: "object", properties: {} }, + readOnly: true, + execute: async () => { + const file = app.workspace.getActiveFile(); + if (!(file instanceof TFile)) return { active: null }; + const content = await app.vault.cachedRead(file); + return { + active: { + path: file.path, + basename: file.basename, + content: + content.length > MAX_ACTIVE_CHARS + ? content.slice(0, MAX_ACTIVE_CHARS) + "\n…[truncated]" + : content, + }, + }; + }, + }), + + get_selection: defineTool({ + description: "Get the text currently selected in the active editor (empty string if none).", + inputSchema: { type: "object", properties: {} }, + readOnly: true, + execute: async () => { + const view = app.workspace.getActiveViewOfType(MarkdownView); + const selection = view?.editor?.getSelection() ?? 
""; + return { selection }; + }, + }), + }; + + return applyGroupOptions(tools, options); +} diff --git a/src/gui/AIAssistantSettingsModal.ts b/src/gui/AIAssistantSettingsModal.ts index f51b7574..613b5619 100644 --- a/src/gui/AIAssistantSettingsModal.ts +++ b/src/gui/AIAssistantSettingsModal.ts @@ -46,6 +46,7 @@ export class AIAssistantSettingsModal extends Modal { this.addDefaultModelSetting(this.contentEl); this.addPromptTemplateFolderPathSetting(this.contentEl); this.addShowAssistantSetting(this.contentEl); + this.addConfirmToolCallsSetting(this.contentEl); this.addDefaultSystemPromptSetting(this.contentEl); } @@ -116,6 +117,24 @@ export class AIAssistantSettingsModal extends Modal { }); } + addConfirmToolCallsSetting(container: HTMLElement) { + new Setting(container) + .setName("Confirm AI tool calls") + .setDesc( + "When an AI agent runs script-defined or built-in tools, ask before executing. 'Destructive only' confirms tools not marked read-only; 'Always' confirms every tool; 'Never' defers to each tool's own setting. A tool that requires approval is always confirmed regardless.", + ) + .addDropdown((dropdown) => { + dropdown.addOption("destructive", "Destructive tools only (recommended)"); + dropdown.addOption("always", "Always confirm every tool"); + dropdown.addOption("never", "Never (use each tool's own setting)"); + dropdown.setValue(this.settings.confirmToolCalls ?? 
"destructive"); + dropdown.onChange((value) => { + this.settings.confirmToolCalls = + value as QuickAddSettings["ai"]["confirmToolCalls"]; + }); + }); + } + addDefaultSystemPromptSetting(contentEl: HTMLElement) { new Setting(contentEl) .setName("Default System Prompt") diff --git a/src/gui/AIToolConfirmModal.ts b/src/gui/AIToolConfirmModal.ts new file mode 100644 index 00000000..6391d0d8 --- /dev/null +++ b/src/gui/AIToolConfirmModal.ts @@ -0,0 +1,118 @@ +import type { App } from "obsidian"; +import { ButtonComponent, Modal } from "obsidian"; + +export type ToolConfirmOutcome = "allow" | "allow-all" | "deny" | "abort"; + +/** + * Confirmation modal for an AI-requested tool call (#714). The model chose the tool + * and its arguments, so this is the human-in-the-loop checkpoint. Outcomes: + * - allow run this one call + * - allow-all run this and every later call this run (no more prompts) + * - deny skip this call (the model gets an isError result and can adapt) + * - abort cancel the whole AI run + * + * The safe option (Deny) is focused by default; Approve is destructive-styled. + * Dismissing (Esc / click-out) resolves to "deny" — declining THIS call, NOT + * aborting the run (use the explicit Abort button for that). 
/**
 * Confirmation modal for a single AI-requested tool call.
 *
 * Obtain a result through {@link AIToolConfirmModal.Prompt}; the constructor is
 * private and opens the modal immediately. The returned promise resolves when the
 * modal closes, with the button the user chose, or with "deny" when the modal was
 * dismissed without a choice.
 */
export default class AIToolConfirmModal extends Modal {
	private resolvePromise!: (outcome: ToolConfirmOutcome) => void;
	public waitForClose: Promise<ToolConfirmOutcome>;
	private outcome: ToolConfirmOutcome | null = null;

	/**
	 * Open the modal for one tool call.
	 *
	 * @param app - Obsidian app instance.
	 * @param toolName - Name of the tool the model wants to run (shown in the title).
	 * @param args - Model-chosen arguments, rendered as pretty-printed JSON.
	 * @returns The user's decision.
	 */
	public static Prompt(
		app: App,
		toolName: string,
		args: unknown,
	): Promise<ToolConfirmOutcome> {
		return new AIToolConfirmModal(app, toolName, args).waitForClose;
	}

	private constructor(
		app: App,
		private toolName: string,
		private args: unknown,
	) {
		super(app);
		this.waitForClose = new Promise<ToolConfirmOutcome>((resolve) => {
			this.resolvePromise = resolve;
		});
		this.open();
		this.display();
	}

	/** Render the title, the argument preview and the four decision buttons. */
	private display() {
		this.containerEl.addClass("quickAddModal", "qaAIToolConfirm");
		this.contentEl.empty();
		this.titleEl.textContent = `Run AI tool: ${this.toolName}?`;

		this.contentEl.createEl("p", {
			text: "The AI requested this tool call with the arguments below.",
		});
		const pre = this.contentEl.createEl("pre", { cls: "qaAIToolArgs" });
		pre.setText(safeStringify(this.args));

		const buttons = this.contentEl.createDiv({
			cls: "yesNoPromptButtonContainer",
		});

		const abortBtn = new ButtonComponent(buttons)
			.setButtonText("Abort run")
			.onClick(() => this.submit("abort"));
		const denyBtn = new ButtonComponent(buttons)
			.setButtonText("Deny")
			.onClick(() => this.submit("deny"));
		const allowAllBtn = new ButtonComponent(buttons)
			.setButtonText("Approve all this run")
			.onClick(() => this.submit("allow-all"));
		// Approve runs a model-chosen action, so it gets the warning (destructive)
		// style rather than the call-to-action accent that would invite a click.
		const allowBtn = new ButtonComponent(buttons)
			.setButtonText("Approve")
			.setWarning()
			.onClick(() => this.submit("allow"));

		// Focus the safe option.
		denyBtn.buttonEl.focus();
		addArrowKeyNavigation([
			abortBtn.buttonEl,
			denyBtn.buttonEl,
			allowAllBtn.buttonEl,
			allowBtn.buttonEl,
		]);
	}

	/** Record the explicit choice, then close; onClose resolves the promise. */
	private submit(outcome: ToolConfirmOutcome) {
		this.outcome = outcome;
		this.close();
	}

	onClose() {
		super.onClose();
		// Dismiss without an explicit choice = decline THIS call (not abort).
		this.resolvePromise(this.outcome ?? "deny");
	}
}
*/ + assignToVariable: string; modelOptions: Partial; showAssistantMessages: boolean; systemPrompt: string; @@ -398,6 +412,10 @@ export class QuickAddApi { const apiKey = await resolveProviderApiKey(app, modelProvider); + if (settings?.assignToVariable) { + assertAssignableVariableName(settings.assignToVariable); + } + const assistantRes = await Prompt( app, { @@ -405,7 +423,8 @@ export class QuickAddApi { prompt, apiKey, modelOptions: settings?.modelOptions ?? {}, - outputVariableName: settings?.variableName ?? "output", + outputVariableName: + settings?.assignToVariable ?? settings?.variableName ?? "output", showAssistantMessages: settings?.showAssistantMessages ?? true, systemPrompt: settings?.systemPrompt ?? AISettings.defaultSystemPrompt, @@ -423,7 +442,7 @@ export class QuickAddApi { return {}; } - if (settings?.shouldAssignVariables) { + if (settings?.shouldAssignVariables || settings?.assignToVariable) { // Copy over `output` and `output-quoted` to the variables (if 'output' is variable name) Object.entries(assistantRes).forEach(([key, value]) => { choiceExecutor.variables.set(key, value); @@ -439,6 +458,8 @@ export class QuickAddApi { settings?: Partial<{ variableName: string; shouldAssignVariables: boolean; + /** Alias: set the output variable name AND assign it (mirrors ai.agent). */ + assignToVariable: string; modelOptions: Partial; showAssistantMessages: boolean; systemPrompt: string; @@ -493,6 +514,10 @@ export class QuickAddApi { const apiKey = await resolveProviderApiKey(app, modelProvider); + if (settings?.assignToVariable) { + assertAssignableVariableName(settings.assignToVariable); + } + const assistantRes = await ChunkedPrompt( app, { @@ -502,7 +527,8 @@ export class QuickAddApi { chunkSeparator: settings?.chunkSeparator ?? /\n/, apiKey, modelOptions: settings?.modelOptions ?? {}, - outputVariableName: settings?.variableName ?? "output", + outputVariableName: + settings?.assignToVariable ?? settings?.variableName ?? 
"output", showAssistantMessages: settings?.showAssistantMessages ?? true, systemPrompt: settings?.systemPrompt ?? AISettings.defaultSystemPrompt, @@ -528,7 +554,7 @@ export class QuickAddApi { return {}; } - if (settings?.shouldAssignVariables) { + if (settings?.shouldAssignVariables || settings?.assignToVariable) { // Copy over `output` and `output-quoted` to the variables (if 'output' is variable name) Object.entries(assistantRes).forEach(([key, value]) => { choiceExecutor.variables.set(key, value); @@ -576,6 +602,41 @@ export class QuickAddApi { clearRequestLogs() { clearAIRequestLogEntries(); }, + /** + * Create a tool-calling Agent (#714). Construct once with model/system/ + * tools/budget, then run `agent.generate({ prompt })` (text + tools) or + * `agent.generate({ prompt, schema })` (structured output). + */ + agent: (config: AgentConfig): Agent => + new Agent(app, plugin, choiceExecutor, config), + /** Declare a tool for an Agent's `tools` map. Pairs a JSON-Schema with a JS handler. */ + tool: (def: ToolDefinitionInput): QATool => ({ + ...def, + __qaTool: true, + }), + /** Stop condition: end the loop once it has taken `n` steps. */ + stepCountIs: + (n: number): StopCondition => + ({ stepNumber }) => + stepNumber >= n, + /** Stop condition: end the loop once the named tool has been called. */ + hasToolCall: + (name: string): StopCondition => + ({ toolCallNames }) => + toolCallNames.includes(name), + /** + * Standard built-in tools (#714), opt-in. Spread a group into an Agent's + * `tools` map, e.g. `tools: { ...quickAddApi.ai.tools.vault() }`. Each group + * factory accepts { only, exclude, prefix, allowedRoots }. Read tools auto-run; + * write tools require confirmation and are path-sanitized + symlink-guarded. 
+ */ + tools: { + vault: (options?: BuiltinGroupOptions) => + createVaultTools(app, options), + workspace: (options?: BuiltinGroupOptions) => + createWorkspaceTools(app, options), + system: (options?: BuiltinGroupOptions) => createSystemTools(options), + }, }, utility: { getClipboard: async () => { diff --git a/src/settings.ts b/src/settings.ts index 14556107..d4471f74 100644 --- a/src/settings.ts +++ b/src/settings.ts @@ -70,6 +70,13 @@ export interface QuickAddSettings { promptTemplatesFolderPath: string; showAssistant: boolean; providers: AIProvider[]; + /** + * When AI tool calling (#714) runs a script-defined or built-in tool, ask + * before executing. 'destructive' (default) confirms any tool not marked + * read-only; 'always' confirms every tool; 'never' defers to each tool's own + * needsApproval. A tool that requires approval is always confirmed regardless. + */ + confirmToolCalls: "never" | "destructive" | "always"; }; migrations: { migrateToMacroIDFromEmbeddedMacro: boolean; @@ -86,6 +93,7 @@ export interface QuickAddSettings { setProviderModelDiscoveryMode: boolean; migrateProviderApiKeysToSecretStorage: boolean; migrateToMultipleTemplateFolders: boolean; + setProviderKind: boolean; }; } @@ -115,6 +123,7 @@ export const DEFAULT_SETTINGS: QuickAddSettings = { promptTemplatesFolderPath: "", showAssistant: true, providers: DefaultProviders, + confirmToolCalls: "destructive", }, migrations: { /** @@ -138,5 +147,6 @@ export const DEFAULT_SETTINGS: QuickAddSettings = { setProviderModelDiscoveryMode: false, migrateProviderApiKeysToSecretStorage: false, migrateToMultipleTemplateFolders: false, + setProviderKind: false, }, }; From f87e8627f7ea6828e4afa7cd80b418c40c5a4a0b Mon Sep 17 00:00:00 2001 From: Christian Bager Bach Houmann Date: Wed, 17 Jun 2026 08:35:05 +0200 Subject: [PATCH 3/8] feat(import): disclose AI-tools capability in package preview (#714) A bundled script that wires up AI tool calling can let a model read/write the vault with model-chosen 
arguments. Scan the same decoded code the loader runs (including a js fence hidden in a .md note) and surface a distinct critical "AI TOOLS" capability row. --- src/services/packagePreview.test.ts | 45 +++++++++++++++++++++++ src/services/packagePreview.ts | 55 +++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/src/services/packagePreview.test.ts b/src/services/packagePreview.test.ts index 83448d75..0539fda1 100644 --- a/src/services/packagePreview.test.ts +++ b/src/services/packagePreview.test.ts @@ -404,6 +404,51 @@ describe("buildPackagePreview - safety must-fixes", () => { expect(preview.summary.overwritesFiles).toBe(0); expect(requiresAcknowledgement(preview)).toBe(true); }); + + it("discloses a bundled script that gives an AI model vault tools (#714)", () => { + const m = macro("m1", "Librarian", [ + userScript("c1", "run", "scripts/agent.js"), + ]); + const pkg = makePackage( + [pkgChoice(m, ["Librarian"])], + [ + asset( + "user-script", + "scripts/agent.js", + "module.exports = async ({ quickAddApi }) => {\n" + + " const a = quickAddApi.ai.agent({ model: 'gpt-4o', tools: { ...quickAddApi.ai.tools.vault() } });\n" + + " return (await a.generate({ prompt: 'hi' })).text;\n" + + "};", + ), + ], + ); + const preview = buildPackagePreview(NO_EXISTING, pkg, NONE); + const row = preview.capabilityRows.find((r) => r.flag === "ai-tools"); + expect(row).toBeDefined(); + expect(row?.severity).toBe("critical"); + expect(row?.scriptPath).toBe("scripts/agent.js"); + }); + + it("discloses AI tools hidden in an orphan note's js fence, not a plain script", () => { + // Orphan .md (referenced by no choice) with a js fence using destructured ai.tools. 
+ const fenced = + "Some notes about the agent.\n\n```js\n" + + "const { ai } = quickAddApi;\n" + + "const tools = ai.tools.workspace();\n" + + "```\n"; + const pkg = makePackage( + [pkgChoice(macro("m1", "Plain", [userScript("c1", "run", "scripts/plain.js")]), ["Plain"])], + [ + asset("template", "Notes/hidden-agent.md", fenced), + asset("user-script", "scripts/plain.js", "module.exports = () => 1;"), + ], + ); + const preview = buildPackagePreview(NO_EXISTING, pkg, NONE); + const aiToolRows = preview.capabilityRows.filter((r) => r.flag === "ai-tools"); + expect(aiToolRows.map((r) => r.scriptPath)).toEqual(["Notes/hidden-agent.md"]); + // A plain script that never touches AI tools gets no ai-tools row. + expect(aiToolRows.some((r) => r.scriptPath === "scripts/plain.js")).toBe(false); + }); }); describe("buildPackagePreview - files manifest, overwrites, orphans, captures", () => { diff --git a/src/services/packagePreview.ts b/src/services/packagePreview.ts index 9e345d2e..76c1df5a 100644 --- a/src/services/packagePreview.ts +++ b/src/services/packagePreview.ts @@ -44,6 +44,7 @@ export type PreviewFlag = | "registers-command" | "obsidian-command" | "ai" + | "ai-tools" | "capture-writes" | "template-write" | "overwrites-existing-choice" @@ -104,6 +105,12 @@ const FLAG_META: Record = { label: "AI", description: "Sends note content to your AI provider over the network.", }, + "ai-tools": { + severity: "critical", + label: "AI TOOLS", + description: + "Lets an AI model read and write your vault: the script gives the model tools (functions) it calls with model-chosen arguments.", + }, "capture-writes": { severity: "warning", label: "WRITES", @@ -196,6 +203,36 @@ function markdownAssetIsExecutable( return code !== null && code.length > 0; } +// The runnable code of a bundled executable asset, for static disclosure scans: +// a `.md` note yields its first js fence (what the loader runs); any other +// executable asset (a `.js`, or a script-kind asset) yields its decoded 
body. +// Pure/App-free — operates on the bundled payload only. +function bundledScriptCode(originalPath: string, content: string): string | null { + let decoded: string; + try { + decoded = decodeFromBase64(content); + } catch { + return null; + } + if (MARKDOWN_FILE_EXTENSION_REGEX.test(originalPath)) { + const { code } = extractScriptFromMarkdown(decoded); + return code !== null && code.length > 0 ? code : null; + } + return decoded; +} + +// #714: a bundled script that wires up QuickAdd's AI tool-calling lets an AI MODEL +// read and write the vault with model-chosen arguments — a distinct risk class from +// "this script runs". Detect the common surface forms: `quickAddApi.ai.tools/.agent/ +// .tool(...)`, the destructured `ai.tools/agent/tool(...)`, and the built-in groups +// (`tools.vault/workspace/system(...)`). A heuristic for DISCLOSURE only — the +// security floor (any script ⇒ "full vault + network access, gated") already holds. +const AI_TOOLS_USE_REGEX = + /\bai\s*\.\s*(?:tools|agent|tool)\b|\btools\s*\.\s*(?:vault|workspace|system)\s*\(/; +function scriptUsesAiTools(code: string): boolean { + return AI_TOOLS_USE_REGEX.test(code); +} + const SEVERITY_ORDER: Record = { critical: 0, warning: 1, @@ -796,6 +833,24 @@ export function buildPackagePreview( }); } + // AI-tools disclosure (#714): scan each reviewable bundled script for AI + // tool-calling and surface a distinct critical row. Scans the SAME decoded code + // the loader runs, so a note that hides a js fence is covered too. 
+ for (const file of files) { + if (!file.requiresReview) continue; + const asset = pkg.assets.find((a) => a.originalPath === file.originalPath); + if (!asset) continue; + const code = bundledScriptCode(asset.originalPath, asset.content); + if (!code || !scriptUsesAiTools(code)) continue; + capabilityRows.push({ + flag: "ai-tools", + severity: "critical", + title: "Lets an AI model read and write your vault", + detail: file.originalPath, + scriptPath: file.originalPath, + }); + } + const registersCommandCount = choices.filter((c) => c.registersCommand).length; if (registersCommandCount > 0) { capabilityRows.push({ From feedbc46029ec595b19bdced5da1e001c1fe8116 Mon Sep 17 00:00:00 2001 From: Christian Bager Bach Houmann Date: Wed, 17 Jun 2026 08:35:05 +0200 Subject: [PATCH 4/8] test(ai): live OpenAI/Anthropic/Gemini provider e2e for tool calling (#714) Env-gated (skipped without each provider's API key, so CI never hits the network). Each exercises the real pure modules end-to-end against a current model: tool loop + schema-constrained structured output. Validated live on gpt-5-mini/gpt-5.5, claude-sonnet-4-6/claude-opus-4-8, gemini-3-flash-preview. --- src/ai/tools/anthropic.e2e.test.ts | 101 ++++++++++++++++++++++++++++ src/ai/tools/gemini.e2e.test.ts | 102 ++++++++++++++++++++++++++++ src/ai/tools/openai.e2e.test.ts | 104 +++++++++++++++++++++++++++++ 3 files changed, 307 insertions(+) create mode 100644 src/ai/tools/anthropic.e2e.test.ts create mode 100644 src/ai/tools/gemini.e2e.test.ts create mode 100644 src/ai/tools/openai.e2e.test.ts diff --git a/src/ai/tools/anthropic.e2e.test.ts b/src/ai/tools/anthropic.e2e.test.ts new file mode 100644 index 00000000..a177fd56 --- /dev/null +++ b/src/ai/tools/anthropic.e2e.test.ts @@ -0,0 +1,101 @@ +/** + * REAL-CALL e2e for the #714 wire against the live Anthropic Messages API. 
Exercises + * the actual pure modules (providerToolMapping.buildChatBody/parseChatResponse + + * runToolLoop + jsonSchemaValidator) end-to-end against a current Claude 4.x model. + * + * Skipped unless ANTHROPIC_API_KEY is set, so the normal suite + CI never hit the network: + * ANTHROPIC_API_KEY=$(op read "op://Agent Secrets/Anthropic Claude API Key/credential") \ + * npx vitest run src/ai/tools/anthropic.e2e.test.ts --config vitest.config.mts + */ +import { describe, it, expect } from "vitest"; +import { buildChatBody, parseChatResponse } from "./providerToolMapping"; +import { runToolLoop, type ToolEntry } from "./runToolLoop"; +import { validateValue } from "./jsonSchemaValidator"; +import type { NormalizedChatRequest } from "./NormalizedTools"; + +const KEY = process.env.ANTHROPIC_API_KEY; +const MODEL = process.env.ANTHROPIC_E2E_MODEL ?? "claude-sonnet-4-6"; +const URL = "https://api.anthropic.com/v1/messages"; + +async function anthropicDispatch(req: NormalizedChatRequest) { + const body = buildChatBody("anthropic", MODEL, req, 1024); + const res = await fetch(URL, { + method: "POST", + headers: { + "content-type": "application/json", + "x-api-key": KEY as string, + "anthropic-version": "2023-06-01", + }, + body: JSON.stringify(body), + }); + const json = (await res.json()) as Record; + if (!res.ok) throw new Error(`Anthropic ${res.status}: ${JSON.stringify(json)}`); + return parseChatResponse("anthropic", json); +} + +describe.skipIf(!KEY)("Anthropic live wire (e2e)", () => { + it("runs a real tool-calling loop end-to-end", async () => { + const calls: Array<{ a: number; b: number }> = []; + const add: ToolEntry = { + definition: { + name: "add", + description: "Add two integers and return their sum.", + parameters: { + type: "object", + properties: { a: { type: "integer" }, b: { type: "integer" } }, + required: ["a", "b"], + }, + }, + readOnly: true, + execute: (args) => { + const a = Number(args.a); + const b = Number(args.b); + calls.push({ a, b }); + 
return { sum: a + b }; + }, + }; + + const res = await runToolLoop({ + request: { + messages: [ + { role: "system", content: "You must use the add tool to compute sums. Do not compute them yourself." }, + { role: "user", content: "Use the add tool to add 17 and 25, then state the result as a number." }, + ], + tools: [add.definition], + toolChoice: "auto", + }, + maxSteps: 4, + dispatch: (req) => anthropicDispatch(req), + getTool: (name) => (name === "add" ? add : undefined), + confirm: async () => true, + validateArgs: (tool, args) => validateValue(args, tool.definition.parameters), + isAbortError: () => false, + }); + + expect(calls.length).toBeGreaterThanOrEqual(1); + expect(calls[0]).toEqual({ a: 17, b: 25 }); + expect(res.text).toMatch(/42/); + expect(res.finishReason).toBe("stop"); + }, 60000); + + it("returns schema-constrained structured output (native output_config.format)", async () => { + const schema = { + type: "object" as const, + properties: { + title: { type: "string" as const }, + tags: { type: "array" as const, items: { type: "string" as const } }, + }, + required: ["title", "tags"], + }; + const parsed = await anthropicDispatch({ + messages: [ + { role: "user", content: "Extract the title and the tags (without '#') from this note: 'Hello World #alpha #beta'." }, + ], + responseFormat: { schema, name: "note_meta", strict: true }, + }); + const obj = JSON.parse(parsed.content) as { title: string; tags: string[] }; + expect(validateValue(obj, schema)).toBeNull(); + expect(typeof obj.title).toBe("string"); + expect(obj.tags).toEqual(expect.arrayContaining(["alpha", "beta"])); + }, 60000); +}); diff --git a/src/ai/tools/gemini.e2e.test.ts b/src/ai/tools/gemini.e2e.test.ts new file mode 100644 index 00000000..7d7fca12 --- /dev/null +++ b/src/ai/tools/gemini.e2e.test.ts @@ -0,0 +1,102 @@ +/** + * REAL-CALL e2e for the #714 wire against the live Gemini generateContent API. 
+ * Exercises the actual pure modules (providerToolMapping.buildChatBody/parseChatResponse + * + runToolLoop + jsonSchemaValidator) end-to-end against a current Gemini 3.x model, + * including the `model` role + functionResponse round-trip and the thoughtSignature echo + * (parseChatResponse preserves the raw parts as providerRaw; buildChatBody replays them). + * + * Skipped unless GEMINI_API_KEY is set, so the normal suite + CI never hit the network: + * GEMINI_API_KEY=$(op read "op://Agent Secrets/Gemini API Key/credential") \ + * npx vitest run src/ai/tools/gemini.e2e.test.ts --config vitest.config.mts + */ +import { describe, it, expect } from "vitest"; +import { buildChatBody, parseChatResponse } from "./providerToolMapping"; +import { runToolLoop, type ToolEntry } from "./runToolLoop"; +import { validateValue } from "./jsonSchemaValidator"; +import type { NormalizedChatRequest } from "./NormalizedTools"; + +const KEY = process.env.GEMINI_API_KEY; +// Gemini 3.x. Default to flash: the *pro* tier has 0 free-tier quota (HTTP 429) on +// many keys, while flash is reachable (though it can be slow / 503 under load — hence +// the generous per-test timeouts below). Override with GEMINI_E2E_MODEL. +const MODEL = process.env.GEMINI_E2E_MODEL ?? 
"gemini-3-flash-preview"; + +async function geminiDispatch(req: NormalizedChatRequest) { + const body = buildChatBody("gemini", MODEL, req); + const url = `https://generativelanguage.googleapis.com/v1beta/models/${MODEL}:generateContent?key=${KEY}`; + const res = await fetch(url, { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify(body), + }); + const json = (await res.json()) as Record; + if (!res.ok) throw new Error(`Gemini ${res.status}: ${JSON.stringify(json)}`); + return parseChatResponse("gemini", json); +} + +describe.skipIf(!KEY)("Gemini live wire (e2e)", () => { + it("runs a real tool-calling loop end-to-end (functionCall/functionResponse + thoughtSignature echo)", async () => { + const calls: Array<{ a: number; b: number }> = []; + const add: ToolEntry = { + definition: { + name: "add", + description: "Add two integers and return their sum.", + parameters: { + type: "object", + properties: { a: { type: "integer" }, b: { type: "integer" } }, + required: ["a", "b"], + }, + }, + readOnly: true, + execute: (args) => { + const a = Number(args.a); + const b = Number(args.b); + calls.push({ a, b }); + return { sum: a + b }; + }, + }; + + const res = await runToolLoop({ + request: { + messages: [ + { role: "system", content: "You must use the add tool to compute sums. Do not compute them yourself." }, + { role: "user", content: "Use the add tool to add 17 and 25, then state the result as a number." }, + ], + tools: [add.definition], + toolChoice: "auto", + }, + maxSteps: 4, + dispatch: (req) => geminiDispatch(req), + getTool: (name) => (name === "add" ? 
add : undefined), + confirm: async () => true, + validateArgs: (tool, args) => validateValue(args, tool.definition.parameters), + isAbortError: () => false, + }); + + expect(calls.length).toBeGreaterThanOrEqual(1); + expect(calls[0]).toEqual({ a: 17, b: 25 }); + expect(res.text).toMatch(/42/); + expect(res.finishReason).toBe("stop"); + }, 120000); + + it("returns schema-constrained structured output (responseSchema)", async () => { + const schema = { + type: "object" as const, + properties: { + title: { type: "string" as const }, + tags: { type: "array" as const, items: { type: "string" as const } }, + }, + required: ["title", "tags"], + }; + const parsed = await geminiDispatch({ + messages: [ + { role: "user", content: "Extract the title and the tags (without '#') from this note: 'Hello World #alpha #beta'." }, + ], + responseFormat: { schema, name: "note_meta", strict: true }, + }); + const obj = JSON.parse(parsed.content) as { title: string; tags: string[] }; + expect(validateValue(obj, schema)).toBeNull(); + expect(typeof obj.title).toBe("string"); + expect(obj.tags).toEqual(expect.arrayContaining(["alpha", "beta"])); + }, 120000); +}); diff --git a/src/ai/tools/openai.e2e.test.ts b/src/ai/tools/openai.e2e.test.ts new file mode 100644 index 00000000..4157ec73 --- /dev/null +++ b/src/ai/tools/openai.e2e.test.ts @@ -0,0 +1,104 @@ +/** + * REAL-CALL e2e for the #714 wire against the live OpenAI API. Exercises the actual + * pure modules (providerToolMapping.buildChatBody/parseChatResponse + runToolLoop + + * jsonSchemaValidator) end-to-end — the layer the prototype could only mock. 
+ * + * Skipped unless OPENAI_API_KEY is set, so the normal suite + CI never hit the network: + * OPENAI_API_KEY=$(op read "op://Agent Secrets/OpenAI API Key/credential") \ + * npx vitest run src/ai/tools/openai.e2e.test.ts --config vitest.config.mts + */ +import { describe, it, expect } from "vitest"; +import { buildChatBody, parseChatResponse } from "./providerToolMapping"; +import { runToolLoop, type ToolEntry } from "./runToolLoop"; +import { validateValue } from "./jsonSchemaValidator"; +import type { NormalizedChatRequest } from "./NormalizedTools"; + +const KEY = process.env.OPENAI_API_KEY; +// Current-generation default (GPT-5.x). The bare request path sends no max_tokens and +// no temperature, so it is wire-compatible with GPT-5.x reasoning models (which reject +// `max_tokens` in favour of `max_completion_tokens` and only accept the default +// temperature). Override with OPENAI_E2E_MODEL to target a specific model. +const MODEL = process.env.OPENAI_E2E_MODEL ?? "gpt-5-mini"; +const URL = "https://api.openai.com/v1/chat/completions"; + +async function openaiDispatch(req: NormalizedChatRequest) { + const body = buildChatBody("openai", MODEL, req); + const res = await fetch(URL, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${KEY}`, + }, + body: JSON.stringify(body), + }); + const json = (await res.json()) as Record; + if (!res.ok) throw new Error(`OpenAI ${res.status}: ${JSON.stringify(json)}`); + return parseChatResponse("openai", json); +} + +describe.skipIf(!KEY)("OpenAI live wire (e2e)", () => { + it("runs a real tool-calling loop end-to-end", async () => { + const calls: Array<{ a: number; b: number }> = []; + const add: ToolEntry = { + definition: { + name: "add", + description: "Add two integers and return their sum.", + parameters: { + type: "object", + properties: { a: { type: "integer" }, b: { type: "integer" } }, + required: ["a", "b"], + }, + }, + readOnly: true, + execute: (args) => { + const 
a = Number(args.a); + const b = Number(args.b); + calls.push({ a, b }); + return { sum: a + b }; + }, + }; + + const res = await runToolLoop({ + request: { + messages: [ + { role: "system", content: "You must use the add tool to compute sums. Do not compute them yourself." }, + { role: "user", content: "Use the add tool to add 17 and 25, then state the result as a number." }, + ], + tools: [add.definition], + toolChoice: "auto", + }, + maxSteps: 4, + dispatch: (req) => openaiDispatch(req), + getTool: (name) => (name === "add" ? add : undefined), + confirm: async () => true, + validateArgs: (tool, args) => validateValue(args, tool.definition.parameters), + isAbortError: () => false, + }); + + expect(calls.length).toBeGreaterThanOrEqual(1); + expect(calls[0]).toEqual({ a: 17, b: 25 }); + expect(res.text).toMatch(/42/); + expect(res.finishReason).toBe("stop"); + }, 60000); + + it("returns schema-constrained structured output", async () => { + const schema = { + type: "object" as const, + properties: { + title: { type: "string" as const }, + tags: { type: "array" as const, items: { type: "string" as const } }, + }, + required: ["title", "tags"], + }; + const parsed = await openaiDispatch({ + messages: [ + { role: "user", content: "Extract the title and the tags (without '#') from this note: 'Hello World #alpha #beta'." 
}, + ], + responseFormat: { schema, name: "note_meta", strict: true }, + }); + const obj = JSON.parse(parsed.content) as { title: string; tags: string[] }; + expect(validateValue(obj, schema)).toBeNull(); + expect(typeof obj.title).toBe("string"); + expect(obj.tags).toEqual(expect.arrayContaining(["alpha", "beta"])); + }, 60000); +}); From 4dd9b8a5afdf1533192acfd0cc7c427951751e47 Mon Sep 17 00:00:00 2001 From: Christian Bager Bach Houmann Date: Wed, 17 Jun 2026 08:35:05 +0200 Subject: [PATCH 5/8] docs(ai): document ai.agent tool calling + structured output (#714) --- docs/docs/AIAssistant.md | 40 ++++++++++-- docs/docs/QuickAddAPI.md | 129 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 4 deletions(-) diff --git a/docs/docs/AIAssistant.md b/docs/docs/AIAssistant.md index 8e512e18..5220b73c 100644 --- a/docs/docs/AIAssistant.md +++ b/docs/docs/AIAssistant.md @@ -93,10 +93,9 @@ Gemini is supported out of the box. Name: Gemini URL: https://generativelanguage.googleapis.com API Key secret: (AI Studio API key) -Models (add one or more): - - gemini-1.5-pro (Max Tokens: 1000000) - - gemini-1.5-flash (Max Tokens: 1000000) - - gemini-1.5-flash-8b (Max Tokens: 1000000) +Models (add one or more — use Browse models for the exact current IDs): + - gemini-3-pro (Max Tokens: 1000000) + - gemini-3-flash (Max Tokens: 1000000) ``` Notes: @@ -120,6 +119,7 @@ In the **AI Assistant Settings** modal (opened via the **Configure AI Assistant* - **Prompt Template Folder Path**: Path to your folder with prompt templates. - **Show Assistant**: Show status messages from the AI Assistant. - **Default System Prompt**: The default system prompt for the AI Assistant. Sets the behavior of the model. +- **Confirm AI tool calls**: When an AI agent runs a tool (see *Tool / function calling* below), whether to ask first. *Destructive tools only* (default) confirms any tool not marked read-only; *Always* confirms every tool; *Never* defers to each tool's own setting. 
A tool that requires approval is always confirmed regardless. For each individual AI Assistant command in your macros, you can set these options: @@ -135,6 +135,38 @@ You can also tweak model parameters in advanced settings: - **frequency_penalty:** A parameter ranging between -2.0 and 2.0. Positive values penalize new tokens based on their frequency in the existing text, reducing the model's tendency to repeat the same lines. (Not applicable to Gemini.) - **presence_penalty:** Also ranging between -2.0 and 2.0, positive values penalize new tokens based on their presence in the existing text, encouraging the model to introduce new topics. (Not applicable to Gemini.) +## Tool / function calling (scripts) + +Beyond one-shot prompts, the AI Assistant can act as a small **agent**: you give the model a +prompt plus a set of *tools* (JavaScript functions), and it decides which to call, in a bounded +multi-step loop, until it has an answer. This is available from the [script API](./QuickAddAPI.md) +only — tools are JS functions, so they live in a User Script (a macro), not in a stored choice. + +```js +module.exports = async ({ quickAddApi, app }) => { + const agent = quickAddApi.ai.agent({ + model: "gpt-5", + system: "You are a vault librarian. Ground every claim in the user's notes.", + tools: { ...quickAddApi.ai.tools.vault({ only: ["read_note", "search_notes"] }) }, + }); + const { text } = await agent.generate({ prompt: "What do my notes say about gardening?" }); + return text; +}; +``` + +QuickAdd ships **built-in tools** you can opt into (`quickAddApi.ai.tools.vault/workspace/system`), +and you can declare your own with `quickAddApi.ai.tool({ description, inputSchema, execute })`. See the +[API reference](./QuickAddAPI.md) for the full surface (agents, tools, structured output via a `schema`). + +:::warning Tool calls run your code with model-chosen arguments +Tool handlers run with full vault and network access. 
The **model** decides which tool to call and +with what arguments — possibly influenced by note content it reads. Treat tool results and note +content as untrusted data, validate the arguments your handlers receive, never pass them to +`format()`/`eval`/a shell, and never put secrets in a tool's description or arguments. Destructive +tools ask for confirmation by default (the **Confirm AI tool calls** setting); read-only tools run +automatically. +::: + ## AI-Powered Workflows You can create powerful workflows utilizing the AI Assistant. Some examples are: diff --git a/docs/docs/QuickAddAPI.md b/docs/docs/QuickAddAPI.md index 49347aaf..bc076807 100644 --- a/docs/docs/QuickAddAPI.md +++ b/docs/docs/QuickAddAPI.md @@ -610,6 +610,135 @@ const result = await quickAddApi.ai.chunkedPrompt( ); ``` +### Tool / function calling — `ai.agent(config)` + +Build an **agent**: give the model a prompt and a set of *tools* (JS functions), and it +will call them in a bounded multi-step loop until it has an answer. Works across your +configured OpenAI-compatible, Anthropic, and Gemini providers. + +```js +const agent = quickAddApi.ai.agent({ + model: "gpt-5", + system: "You manage an Obsidian vault. 
Use the tools to ground your answers.", + tools: { + // built-in tools, opt-in (see ai.tools.* below) + ...quickAddApi.ai.tools.vault({ only: ["read_note", "search_notes"] }), + // your own tool + save_link: quickAddApi.ai.tool({ + description: "Append a URL to the reading-list note.", + inputSchema: { + type: "object", + properties: { url: { type: "string" } }, + required: ["url"], + }, + needsApproval: true, // ask before running (the model chose the args) + execute: async ({ url }) => { + const file = app.vault.getAbstractFileByPath("Reading list.md"); + await app.vault.append(file, `\n- ${url}`); + return { saved: true }; + }, + }), + }, + maxSteps: 12, // optional; default 20, hard-capped at 100 +}); + +const { text, steps, toolCalls } = await agent.generate({ + prompt: "Summarise my notes about {{VALUE:topic}} and save any links you find.", + assignToVariable: "summary", // optional: writes {{VALUE:summary}} for a later step +}); +``` + +**`ai.agent(config)`** returns an Agent. Config: +- `model` — a configured model name (string) or `{ name }`. +- `system` — system prompt (defaults to your AI Assistant default system prompt). +- `tools` — an object map of tool name → tool (from `ai.tool()` and/or `ai.tools.*`). +- `toolChoice` — `"auto"` (default) | `"none"` | `"required"` | `{ type: "tool", toolName }`. +- `stopWhen` — one or more stop conditions from `ai.stepCountIs(n)` / `ai.hasToolCall(name)`. +- `maxSteps` — step budget (default 20, hard cap 100). Sugar for `stopWhen: ai.stepCountIs(n)`. +- `maxOutputTokens`, `modelOptions` — passed to the provider. + +**`agent.generate(options)`** runs the loop and resolves to a result: +- `text` — the final assistant text. +- `object` — present **only** when you pass a `schema` (structured output, below). +- `steps` — the full transcript: `{ text, toolCalls, toolResults, finishReason }[]`. +- `toolCalls` / `toolResults` — from the last step (`input` / `output` fields, AI-SDK style). 
+- `usage` — `{ inputTokens, outputTokens, totalTokens }`. +- `finishReason` — `"stop" | "max-steps" | "length" | "aborted" | "context-overflow"`. + +Options: `prompt` (formatted, like `ai.prompt`), `schema`, `system`/`toolChoice`/`maxOutputTokens` +(per-call overrides), and `assignToVariable` (write `text` into `{{VALUE:name}}`). + +The agent is a **stateless config holder** — each `generate()` is independent (no retained +conversation). Reuse means reusing the config; run one `generate()` at a time per agent. + +### `ai.tool(def)` + +Declares a tool. `def`: `{ description, inputSchema (JSON Schema), execute, needsApproval?, readOnly?, strict? }`. + +- `inputSchema` is a **JSON-Schema subset** (`type`/`properties`/`required`/`enum`/`items`). + Unsupported keywords (`pattern`, `additionalProperties`, `$ref`, `format`, …) are rejected + at registration so a provider can't silently drop a constraint. +- `execute(input, ctx)` runs your code with the model-chosen `input` (validated against the + schema first). Return a string (used verbatim) or any JSON-serialisable value. +- `needsApproval` (boolean or `(args) => boolean`) asks before running. `readOnly: true` marks a + tool that only reads, so it auto-runs under the default confirmation setting. + +> **Confirmation needs an interactive Obsidian session.** A tool that asks for approval opens a +> modal and waits for it. For unattended automation (e.g. driving QuickAdd from the CLI), give the +> agent only `readOnly` tools, or set **Confirm AI tool calls** to *Never* and gate each tool with +> its own `needsApproval` — otherwise the run blocks on a dialog no one can answer. + +> ⚠️ **Security.** Tool handlers run with the same full privilege as your script (Node `require`, +> `app`, the vault). The **model decides which tool to call and with what arguments**, possibly +> influenced by note content it reads (indirect prompt injection). 
QuickAdd never runs model-chosen +> arguments through the formatter — and **neither should you**: never pass a tool's `input` to +> `quickAddApi.format()`, `eval`, a shell, or `fetch` without validating it. Never put secrets in a +> tool's description or arguments (they are sent to the provider). Confirmation is governed by each +> tool's `needsApproval` plus the global **Confirm AI tool calls** setting (default *destructive only*). + +### Built-in tools — `ai.tools.{vault, workspace, system}(options)` + +Opt-in groups of ready-made tools. Each returns a tool map you spread into an agent's `tools`. +Options: `{ only, exclude, prefix, allowedRoots }` (`allowedRoots` confines vault paths to folders). + +| Group | Read-only (auto-run) | Write (asks for approval) | +|---|---|---| +| `vault` | `read_note`, `list_notes`, `search_notes`, `get_property_values` | `create_note`, `append_to_note`, `insert_under_heading` | +| `workspace` | `get_active_note`, `get_selection` | — | +| `system` | `get_date` | — | + +Write tools sanitise every model-chosen path (rejecting traversal and config dirs like `.obsidian`/ +`.git`, and symlinks that escape the vault), fail rather than overwrite an existing note, and are +frontmatter-aware. There are **no ambient tools** — nothing runs unless you spread it into `tools`. + +### Structured output — `agent.generate({ prompt, schema })` + +Pass a JSON schema to get a validated object back: + +```js +const { object } = await quickAddApi.ai.agent({ model: "gpt-5" }).generate({ + prompt: "Extract the title and tags from the selection.", + schema: { + type: "object", + properties: { title: { type: "string" }, tags: { type: "array", items: { type: "string" } } }, + required: ["title", "tags"], + }, +}); +// object => { title: "...", tags: ["...", "..."] } +``` + +`object` is the parsed, schema-validated result (or `undefined` if the model could not produce a +match after one repair attempt). 
Structured output works on current models — OpenAI GPT-5.x (and +GPT-4o-class), Anthropic Claude 4.x, and Gemini 3.x; it can be combined with tools. Older models +that do not support schema-constrained output (e.g. legacy OpenAI chat models) reject the request +outright with a provider error — use a current model rather than expecting a best-effort fallback. + +:::note OpenAI reasoning models (GPT-5.x, o-series) +These accept only the default `temperature` (omit it from `modelOptions`), and QuickAdd +automatically sends `maxOutputTokens` as `max_completion_tokens` for them. The agent's default +path sets neither, so `quickAddApi.ai.agent({ model: "gpt-5" })` works as-is. +::: + ### `getModels(): string[]` Returns available AI models. From 461874e9aa67974d51c0539576123d4e3b6b253a Mon Sep 17 00:00:00 2001 From: Christian Bager Bach Houmann Date: Wed, 17 Jun 2026 08:39:47 +0200 Subject: [PATCH 6/8] fix(ai): match provider host precisely in getProviderKind (#714) CodeQL js/incomplete-url-substring-sanitization: endpoint.includes(host) could mis-route a URL with the known host in its path/query or as a fake subdomain prefix. Parse the hostname and match exactly or on a real subdomain instead. 
--- src/ai/Provider.test.ts | 11 +++++++++++ src/ai/Provider.ts | 27 +++++++++++++++++++++------ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/ai/Provider.test.ts b/src/ai/Provider.test.ts index 89857f41..ad2c84db 100644 --- a/src/ai/Provider.test.ts +++ b/src/ai/Provider.test.ts @@ -10,6 +10,9 @@ describe("getProviderKind", () => { it("infers anthropic from name or endpoint", () => { expect(getProviderKind({ name: "Anthropic" })).toBe("anthropic"); expect(getProviderKind({ name: "My Claude", endpoint: "https://api.anthropic.com" })).toBe("anthropic"); + expect(getProviderKind({ name: "Claude Proxy", endpoint: "https://api.anthropic.com/v1/messages" })).toBe("anthropic"); + // scheme-less endpoints still parse + expect(getProviderKind({ name: "X", endpoint: "api.anthropic.com" })).toBe("anthropic"); }); it("infers gemini from name or endpoint", () => { @@ -19,9 +22,17 @@ describe("getProviderKind", () => { ).toBe("gemini"); }); + it("matches the hostname precisely, not a substring of the URL (CodeQL js/incomplete-url-substring-sanitization)", () => { + // The known host appearing in the path/query or as a fake subdomain prefix must NOT match. 
+ expect(getProviderKind({ name: "Evil", endpoint: "https://evil.com/?x=api.anthropic.com" })).toBe("openai"); + expect(getProviderKind({ name: "Evil", endpoint: "https://api.anthropic.com.attacker.example/v1" })).toBe("openai"); + expect(getProviderKind({ name: "Evil", endpoint: "https://generativelanguage.googleapis.com.evil.test" })).toBe("openai"); + }); + it("defaults unknown/OpenAI-compatible providers to openai", () => { expect(getProviderKind({ name: "Groq", endpoint: "https://api.groq.com/openai/v1" })).toBe("openai"); expect(getProviderKind({ name: "OpenRouter" })).toBe("openai"); + expect(getProviderKind({ endpoint: "not a url" })).toBe("openai"); expect(getProviderKind({})).toBe("openai"); }); }); diff --git a/src/ai/Provider.ts b/src/ai/Provider.ts index dabd8f6d..39af32a2 100644 --- a/src/ai/Provider.ts +++ b/src/ai/Provider.ts @@ -37,19 +37,34 @@ export function getProviderKind(provider: { }): ProviderKind { if (provider.kind) return provider.kind; const name = (provider.name ?? "").toLowerCase(); - const endpoint = (provider.endpoint ?? "").toLowerCase(); - if (name === "anthropic" || endpoint.includes("api.anthropic.com")) { + const host = endpointHost(provider.endpoint); + // Match the parsed HOSTNAME precisely (exact or a real subdomain), not a raw + // substring of the whole URL — so e.g. `https://evil.com/?api.anthropic.com` + // can't be mistaken for Anthropic. + const isHost = (h: string) => host === h || host.endsWith(`.${h}`); + if (name === "anthropic" || isHost("api.anthropic.com")) { return "anthropic"; } - if ( - name === "gemini" || - endpoint.includes("generativelanguage.googleapis.com") - ) { + if (name === "gemini" || isHost("generativelanguage.googleapis.com")) { return "gemini"; } return "openai"; } +/** Lowercased hostname of an endpoint, or "" if it can't be parsed (scheme optional). */ +function endpointHost(endpoint?: string): string { + const raw = (endpoint ?? 
"").trim(); + if (!raw) return ""; + for (const candidate of [raw, `https://${raw}`]) { + try { + return new URL(candidate).hostname.toLowerCase(); + } catch { + /* try next form */ + } + } + return ""; +} + export interface Model { name: string; maxTokens: number; From 236528a5efe2c06bd31e49fa27f08775cbb69187 Mon Sep 17 00:00:00 2001 From: Christian Bager Bach Houmann Date: Wed, 17 Jun 2026 08:58:38 +0200 Subject: [PATCH 7/8] fix(ai): harden tool schema, args, paths + structured-output repair (#714) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address CodeRabbit review on #1372: - jsonSchemaValidator: reject tuple-style `items` arrays at registration (they were accepted but skipped at runtime); use own-property check for `required` so an inherited name (e.g. "toString") can't satisfy it. - providerToolMapping: normalize OpenAI tool args — a parsed array/primitive now becomes parseError instead of a non-object `args`. - runToolLoop: reserve the truncation marker's bytes so a capped tool result (content + marker) still fits maxResultBytes. - vaultTools: note writers enforce .md — refuse a non-markdown extension so a model can't create/modify a .js (runnable) or other arbitrary file type. - Agent: don't run the structured-output network repair when the loop ended aborted/context-overflow or online features are disabled (keep the free first parse); the first parse never makes a call. - quickAddApi: an empty assignToVariable ("") now falls back instead of leaking as the output variable name, matching ai.agent's length>0 behavior. - docs: convert the confirmation note to a :::note admonition (markdownlint MD028). 
--- docs/docs/QuickAddAPI.md | 10 +++++--- src/ai/tools/Agent.ts | 12 +++++++++ src/ai/tools/builtins/builtins.test.ts | 14 ++++++++++- src/ai/tools/builtins/vaultTools.ts | 19 ++++++++++---- src/ai/tools/jsonSchemaValidator.test.ts | 14 +++++++++++ src/ai/tools/jsonSchemaValidator.ts | 15 +++++++---- src/ai/tools/providerToolMapping.test.ts | 11 ++++++++ src/ai/tools/providerToolMapping.ts | 32 ++++++++++++++++-------- src/ai/tools/runToolLoop.ts | 13 ++++++---- src/quickAddApi.ts | 8 ++++-- 10 files changed, 115 insertions(+), 33 deletions(-) diff --git a/docs/docs/QuickAddAPI.md b/docs/docs/QuickAddAPI.md index bc076807..1bf112ef 100644 --- a/docs/docs/QuickAddAPI.md +++ b/docs/docs/QuickAddAPI.md @@ -683,10 +683,12 @@ Declares a tool. `def`: `{ description, inputSchema (JSON Schema), execute, need - `needsApproval` (boolean or `(args) => boolean`) asks before running. `readOnly: true` marks a tool that only reads, so it auto-runs under the default confirmation setting. -> **Confirmation needs an interactive Obsidian session.** A tool that asks for approval opens a -> modal and waits for it. For unattended automation (e.g. driving QuickAdd from the CLI), give the -> agent only `readOnly` tools, or set **Confirm AI tool calls** to *Never* and gate each tool with -> its own `needsApproval` — otherwise the run blocks on a dialog no one can answer. +:::note Confirmation needs an interactive Obsidian session +A tool that asks for approval opens a modal and waits for it. For unattended automation (e.g. +driving QuickAdd from the CLI), give the agent only `readOnly` tools, or set **Confirm AI tool +calls** to *Never* and gate each tool with its own `needsApproval` — otherwise the run blocks on a +dialog no one can answer. +::: > ⚠️ **Security.** Tool handlers run with the same full privilege as your script (Node `require`, > `app`, the vault). 
The **model decides which tool to call and with what arguments**, possibly diff --git a/src/ai/tools/Agent.ts b/src/ai/tools/Agent.ts index 2432f66f..cb83331c 100644 --- a/src/ai/tools/Agent.ts +++ b/src/ai/tools/Agent.ts @@ -392,6 +392,18 @@ export class Agent { const first = parseStructured(loop.finalTurn.content, schema); if (first.ok) return first.value; + // The first parse above is free (no network). The repair re-ask below makes a + // new outbound call, so DON'T do it when the loop ended in a terminal + // non-success state (online disabled mid-run → "aborted", or "context-overflow") + // or online features are now off — that would bypass the mid-run stop. + if ( + loop.finishReason === "aborted" || + loop.finishReason === "context-overflow" || + settingsStore.getState().disableOnlineFeatures + ) { + return undefined; + } + // One bounded repair re-ask: a fresh stricter request (never a format() pass). // No tools + responseFormat is the valid no-tools shape for every provider. // loop.messages omits the final assistant turn, so include the bad reply diff --git a/src/ai/tools/builtins/builtins.test.ts b/src/ai/tools/builtins/builtins.test.ts index 2bb29244..8ef17d7f 100644 --- a/src/ai/tools/builtins/builtins.test.ts +++ b/src/ai/tools/builtins/builtins.test.ts @@ -70,11 +70,23 @@ describe("vault write tools — safety", () => { const app = makeApp(); const tools = createVaultTools(app); await expect( - tools.create_note.execute({ path: "Notes/.obsidian/evil/main.js" }, { toolCallId: "c", toolName: "create_note" }), + tools.create_note.execute({ path: "Notes/.obsidian/evil/main.md" }, { toolCallId: "c", toolName: "create_note" }), ).rejects.toBeInstanceOf(UnsafeVaultPathError); expect((app.vault.create as ReturnType)).not.toHaveBeenCalled(); }); + it("note writers refuse a non-markdown extension (no .js/.css etc.)", async () => { + const app = makeApp(); + const tools = createVaultTools(app); + await expect( + tools.create_note.execute({ path: "Notes/evil.js" 
}, { toolCallId: "c", toolName: "create_note" }), + ).rejects.toThrow(/markdown/i); + await expect( + tools.append_to_note.execute({ path: "Notes/style.css", content: "x" }, { toolCallId: "c", toolName: "append_to_note" }), + ).rejects.toThrow(/markdown/i); + expect((app.vault.create as ReturnType)).not.toHaveBeenCalled(); + }); + it("create_note ensures .md and calls vault.create (fail-on-exist via the API)", async () => { const create = vi.fn(async (p: string) => fileLike(p)); const app = makeApp({ vault: { create, getAbstractFileByPath: () => fileLike("Notes") } }); diff --git a/src/ai/tools/builtins/vaultTools.ts b/src/ai/tools/builtins/vaultTools.ts index 7a9d5830..bed5a979 100644 --- a/src/ai/tools/builtins/vaultTools.ts +++ b/src/ai/tools/builtins/vaultTools.ts @@ -161,7 +161,7 @@ export function createVaultTools( ), needsApproval: true, execute: async ({ path, content }) => { - const norm = sanitizeVaultPath(ensureMd(String(path)), { allowedRoots: roots }); + const norm = sanitizeVaultPath(ensureMarkdownPath(String(path)), { allowedRoots: roots }); await assertWriteStaysInVault(app, norm); await ensureParentFolder(app, norm); const file = await app.vault.create(norm, String(content ?? 
"")); @@ -181,7 +181,7 @@ export function createVaultTools( ), needsApproval: true, execute: async ({ path, content, position }) => { - const norm = sanitizeVaultPath(String(path), { allowedRoots: roots }); + const norm = sanitizeVaultPath(ensureMarkdownPath(String(path)), { allowedRoots: roots }); const file = requireFile(app, norm); await assertWriteStaysInVault(app, norm); const body = await app.vault.read(file); @@ -208,7 +208,7 @@ export function createVaultTools( ), needsApproval: true, execute: async ({ path, heading, content }) => { - const norm = sanitizeVaultPath(String(path), { allowedRoots: roots }); + const norm = sanitizeVaultPath(ensureMarkdownPath(String(path)), { allowedRoots: roots }); const file = requireFile(app, norm); await assertWriteStaysInVault(app, norm); const headings = app.metadataCache.getFileCache(file)?.headings ?? []; @@ -249,8 +249,17 @@ function toInt(v: unknown, fallback: number): number { function clamp(n: number, lo: number, hi: number): number { return Math.max(lo, Math.min(hi, n)); } -function ensureMd(path: string): string { - return /\.[a-z0-9]+$/i.test(path) ? path : `${path}.md`; +// These are NOTE tools: confine them to markdown. A bare name gets `.md`; an +// explicit non-`.md` extension is refused so a model can't create/modify a `.js` +// (which a macro could then run) or any other arbitrary file type via a "note" tool. 
+function ensureMarkdownPath(path: string): string { + if (/\.md$/i.test(path)) return path; + if (/\.[a-z0-9]+$/i.test(path)) { + throw new Error( + `This tool only operates on markdown notes (.md); refused path: "${path}".`, + ); + } + return `${path}.md`; } function requireFile(app: App, normalizedPath: string): TFile { const file = app.vault.getAbstractFileByPath(normalizedPath); diff --git a/src/ai/tools/jsonSchemaValidator.test.ts b/src/ai/tools/jsonSchemaValidator.test.ts index 0b1dbcb8..039130dd 100644 --- a/src/ai/tools/jsonSchemaValidator.test.ts +++ b/src/ai/tools/jsonSchemaValidator.test.ts @@ -58,6 +58,15 @@ describe("assertRegisterableSchema", () => { assertRegisterableSchema({ type: "object", required: "name" as never }), ).toThrow(ToolSchemaError); }); + + it("rejects tuple-style items (an array of schemas) — not validated at runtime", () => { + expect(() => + assertRegisterableSchema({ + type: "array", + items: [{ type: "string" }, { type: "integer" }] as never, + }), + ).toThrow(ToolSchemaError); + }); }); describe("validateValue", () => { @@ -82,6 +91,11 @@ describe("validateValue", () => { expect(validateValue({ count: 1 }, schema)).toMatch(/required/); }); + it("treats an inherited (prototype) property as missing for required (own-property check)", () => { + // `toString` exists on the prototype but not as an own property → must be flagged. 
+ expect(validateValue({}, { type: "object", required: ["toString"] })).toMatch(/required/); + }); + it("flags a wrong type", () => { expect(validateValue({ path: 42 }, schema)).toMatch(/expected string/); expect(validateValue({ path: "a", count: 1.5 }, schema)).toMatch(/integer/); diff --git a/src/ai/tools/jsonSchemaValidator.ts b/src/ai/tools/jsonSchemaValidator.ts index 0e82aca9..d3a7ea52 100644 --- a/src/ai/tools/jsonSchemaValidator.ts +++ b/src/ai/tools/jsonSchemaValidator.ts @@ -103,13 +103,15 @@ export function assertRegisterableSchema( } if (schema.items !== undefined) { + // Tuple-style `items` (an array of schemas) is NOT validated at runtime + // (validateValue only applies a single items schema), so reject it at + // registration rather than let the constraint be silently skipped. if (Array.isArray(schema.items)) { - schema.items.forEach((sub, i) => - assertRegisterableSchema(sub, `${where}.items[${i}]`), + throw new ToolSchemaError( + `${where}.items as an array (tuple validation) is not supported in QuickAdd's schema subset.`, ); - } else { - assertRegisterableSchema(schema.items, `${where}.items`); } + assertRegisterableSchema(schema.items, `${where}.items`); } if (schema.enum !== undefined && !Array.isArray(schema.enum)) { @@ -171,7 +173,10 @@ export function validateValue( const obj = value as Record<string, unknown>; if (Array.isArray(schema.required)) { for (const key of schema.required) { - if (!(key in obj)) return `${path}.${key}: required property is missing`; + // own-property check: `key in obj` would match inherited props + // (e.g. a required "toString" would pass even when absent).
+ if (!Object.prototype.hasOwnProperty.call(obj, key)) + return `${path}.${key}: required property is missing`; } } if (schema.properties) { diff --git a/src/ai/tools/providerToolMapping.test.ts b/src/ai/tools/providerToolMapping.test.ts index fe2edabc..569546c3 100644 --- a/src/ai/tools/providerToolMapping.test.ts +++ b/src/ai/tools/providerToolMapping.test.ts @@ -89,6 +89,17 @@ describe("OpenAI mapping", () => { expect(res.toolCalls[0].parseError).toBeUndefined(); }); + it("flags non-object args (array/primitive) as parseError — args must be an object", () => { + const arr = parseChatResponse("openai", { + choices: [{ finish_reason: "tool_calls", message: { tool_calls: [{ id: "c1", function: { name: "t", arguments: "[1,2]" } }] } }], + }); + expect(arr.toolCalls[0]).toMatchObject({ args: null, parseError: true }); + const prim = parseChatResponse("openai", { + choices: [{ finish_reason: "tool_calls", message: { tool_calls: [{ id: "c2", function: { name: "t", arguments: 5 as unknown as string } }] } }], + }); + expect(prim.toolCalls[0]).toMatchObject({ args: null, parseError: true }); + }); + it("uses max_completion_tokens for reasoning models (gpt-5+/o-series), max_tokens otherwise", () => { const req: NormalizedChatRequest = { messages: [{ role: "user", content: "q" }], diff --git a/src/ai/tools/providerToolMapping.ts b/src/ai/tools/providerToolMapping.ts index f5ae9a49..2eea1d35 100644 --- a/src/ai/tools/providerToolMapping.ts +++ b/src/ai/tools/providerToolMapping.ts @@ -198,16 +198,20 @@ function parseOpenAIResponse(json: Record<string, unknown>): ParsedChatResult { }; } +/** A tool call's args MUST be a JSON object — arrays/primitives violate the contract. */ +function asArgsRecord(value: unknown): Record<string, unknown> | null { + return value !== null && typeof value === "object" && !Array.isArray(value) + ?
(value as Record<string, unknown>) + : null; +} + function parseOpenAIToolCall(tc: OpenAIToolCallRaw): NormalizedToolCall { const raw = tc.function.arguments; if (typeof raw === "string") { try { - return { - id: tc.id, - name: tc.function.name, - args: JSON.parse(raw || "{}") as Record<string, unknown>, - rawArgs: raw, - }; + const rec = asArgsRecord(JSON.parse(raw || "{}")); + if (!rec) throw new Error("tool arguments are not a JSON object"); + return { id: tc.id, name: tc.function.name, args: rec, rawArgs: raw }; } catch { return { id: tc.id, @@ -219,11 +223,17 @@ function parseOpenAIToolCall(tc: OpenAIToolCallRaw): NormalizedToolCall { } } // Defensive: some OpenAI-compatible servers (Ollama native) send an object. - return { - id: tc.id, - name: tc.function.name, - args: (raw as Record<string, unknown>) ?? {}, - }; + const rec = asArgsRecord(raw); + if (!rec) { + return { + id: tc.id, + name: tc.function.name, + args: null, + rawArgs: typeof raw === "string" ? raw : JSON.stringify(raw ?? {}), + parseError: true, + }; + } + return { id: tc.id, name: tc.function.name, args: rec }; } // =========================================================================== diff --git a/src/ai/tools/runToolLoop.ts b/src/ai/tools/runToolLoop.ts index daa69b6c..8122b4ed 100644 --- a/src/ai/tools/runToolLoop.ts +++ b/src/ai/tools/runToolLoop.ts @@ -149,13 +149,16 @@ export function stringifyToolResult( } } if (byteLength(s) > maxBytes) { - // Byte-accurate truncation: shrink the char cut until it fits the byte cap - // (so multibyte content can't overshoot), without splitting a code point. - let cut = Math.min(s.length, maxBytes); - while (cut > 0 && byteLength(s.slice(0, cut)) > maxBytes) { + // Byte-accurate truncation: reserve room for the marker so the FINAL string + // (content + marker) still fits maxBytes, and shrink the char cut until the + // prefix fits that budget without splitting a code point.
+ const marker = " …[truncated]"; + const budget = Math.max(0, maxBytes - byteLength(marker)); + let cut = Math.min(s.length, budget); + while (cut > 0 && byteLength(s.slice(0, cut)) > budget) { cut = Math.floor(cut * 0.9) || cut - 1; } - s = s.slice(0, cut) + " …[truncated]"; + s = s.slice(0, cut) + marker; } return { content: s, isError: false }; } diff --git a/src/quickAddApi.ts b/src/quickAddApi.ts index a1dc8a37..3635c0e5 100644 --- a/src/quickAddApi.ts +++ b/src/quickAddApi.ts @@ -424,7 +424,9 @@ export class QuickAddApi { apiKey, modelOptions: settings?.modelOptions ?? {}, outputVariableName: - settings?.assignToVariable ?? settings?.variableName ?? "output", + // `||` not `??`: an empty assignToVariable ("") means "no explicit + // variable" (matching ai.agent's length>0 check), so fall through. + settings?.assignToVariable || settings?.variableName || "output", showAssistantMessages: settings?.showAssistantMessages ?? true, systemPrompt: settings?.systemPrompt ?? AISettings.defaultSystemPrompt, @@ -528,7 +530,9 @@ export class QuickAddApi { apiKey, modelOptions: settings?.modelOptions ?? {}, outputVariableName: - settings?.assignToVariable ?? settings?.variableName ?? "output", + // `||` not `??`: an empty assignToVariable ("") means "no explicit + // variable" (matching ai.agent's length>0 check), so fall through. + settings?.assignToVariable || settings?.variableName || "output", showAssistantMessages: settings?.showAssistantMessages ?? true, systemPrompt: settings?.systemPrompt ?? AISettings.defaultSystemPrompt, From 2612f25dc4c0a5c7ecfe21fc3e0ac850de2b2c02 Mon Sep 17 00:00:00 2001 From: Christian Bager Bach Houmann Date: Wed, 17 Jun 2026 09:09:07 +0200 Subject: [PATCH 8/8] fix(ai): conservative Anthropic max_tokens + drop kind-backfill migration (#714) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address Codex review on #1372: - anthropicMaxTokens: default back to 4096 (the long-standing pre-#714 value). 
Anthropic 400s when max_tokens exceeds a model's OUTPUT cap (verified live), and 8192 exceeds Claude 3 Haiku's 4096 — so the bump could reject small-input requests. Explicit maxOutputTokens is still honored uncapped for 4.x models. - Remove the setProviderKind migration: it persisted an inferred wire kind that the provider editor can't change, so repointing a provider's endpoint left a stale kind and the wrong wire shape. getProviderKind already infers at request time when kind is absent; the explicit field stays for known/preset providers (e.g. a custom-named Anthropic proxy). --- src/ai/OpenAIRequest.test.ts | 6 ++--- src/ai/OpenAIRequest.ts | 12 +++++----- src/migrations/migrate.ts | 2 -- src/migrations/setProviderKind.ts | 37 ------------------------------- src/settings.ts | 2 -- 5 files changed, 10 insertions(+), 49 deletions(-) delete mode 100644 src/migrations/setProviderKind.ts diff --git a/src/ai/OpenAIRequest.test.ts b/src/ai/OpenAIRequest.test.ts index 0bf28c8d..72864959 100644 --- a/src/ai/OpenAIRequest.test.ts +++ b/src/ai/OpenAIRequest.test.ts @@ -401,8 +401,8 @@ describe("OpenAIRequest", () => { const body = JSON.parse(arg.body); expect(body).toEqual({ model: "claude-3-5-sonnet", - // min(8192, 200000) — derived from the model's context window, not hardcoded 4096. - max_tokens: 8192, + // Conservative 4096 default (at/below every current Claude model's output cap). 
+ max_tokens: 4096, messages: [{ role: "user", content: "hello claude" }], system: "system prompt", }); @@ -427,7 +427,7 @@ describe("OpenAIRequest", () => { const body = JSON.parse(requestUrlMock.mock.calls[0][0].body); expect("system" in body).toBe(false); - expect(body.max_tokens).toBe(8192); + expect(body.max_tokens).toBe(4096); }); it("omits the system key for a whitespace-only system prompt", async () => { diff --git a/src/ai/OpenAIRequest.ts b/src/ai/OpenAIRequest.ts index b33aeb91..cc1d0cd3 100644 --- a/src/ai/OpenAIRequest.ts +++ b/src/ai/OpenAIRequest.ts @@ -199,12 +199,14 @@ async function makeOpenAIRequest( ); } -// Conservative Anthropic output budget. The Messages API REQUIRES max_tokens, but -// model.maxTokens is the *context window*, not an output cap — so derive a sensible -// floor and clamp it to the window (a small/mis-registered model can't request more -// than it has). A dedicated per-model output field is a future refinement. +// Conservative Anthropic output budget. The Messages API REQUIRES max_tokens and +// REJECTS (400) any value above a model's OUTPUT cap — and model.maxTokens is the +// *context window*, not the output cap. So default to 4096, which is at or below the +// output limit of every current Claude model (incl. Claude 3 Haiku's 4096), matching +// the long-standing pre-#714 default. Callers needing more pass maxOutputTokens +// explicitly (honored uncapped). A dedicated per-model output field is a future refinement. export function anthropicMaxTokens(model: Model): number { - const ceiling = 8192; + const ceiling = 4096; return Number.isFinite(model.maxTokens) && model.maxTokens > 0 ? 
Math.min(ceiling, model.maxTokens) : ceiling; diff --git a/src/migrations/migrate.ts b/src/migrations/migrate.ts index f8dc32e1..a95e390a 100644 --- a/src/migrations/migrate.ts +++ b/src/migrations/migrate.ts @@ -15,7 +15,6 @@ import setProviderModelDiscoveryMode from "./setProviderModelDiscoveryMode"; import { deepClone } from "src/utils/deepClone"; import migrateProviderApiKeysToSecretStorage from "./migrateProviderApiKeysToSecretStorage"; import migrateToMultipleTemplateFolders from "./migrateToMultipleTemplateFolders"; -import setProviderKind from "./setProviderKind"; import { settingsStore } from "src/settingsStore"; const migrations: Migrations = { @@ -32,7 +31,6 @@ const migrations: Migrations = { setProviderModelDiscoveryMode, migrateProviderApiKeysToSecretStorage, migrateToMultipleTemplateFolders, - setProviderKind, }; async function migrate(plugin: QuickAdd): Promise<void> { diff --git a/src/migrations/setProviderKind.ts b/src/migrations/setProviderKind.ts deleted file mode 100644 index 86f3b7b8..00000000 --- a/src/migrations/setProviderKind.ts +++ /dev/null @@ -1,37 +0,0 @@ -import type QuickAdd from "src/main"; -import { settingsStore } from "src/settingsStore"; -import type { Migration } from "./Migrations"; -import { deepClone } from "src/utils/deepClone"; -import { getProviderKind } from "src/ai/Provider"; - -/** - * Backfill the `kind` wire-protocol discriminator on every AI provider (#714). The - * tool-calling / structured-output adapter selects on `kind`, not the display name, - * so a custom Anthropic-compatible provider named anything other than "Anthropic" - * routes correctly. Inference (getProviderKind) covers providers without the field - * at runtime; this migration persists the inferred value once.
- */ -const setProviderKind: Migration = { - description: "Backfill the wire-protocol kind on each AI provider", - migrate: async (_plugin: QuickAdd) => { - const currentSettings = settingsStore.getState(); - const providers = currentSettings.ai.providers ?? []; - let updated = false; - - for (const provider of providers) { - if (!provider.kind) { - provider.kind = getProviderKind(provider); - updated = true; - } - } - - if (!updated) return; - - settingsStore.setState((state) => ({ - ...state, - ai: { ...state.ai, providers: deepClone(providers) }, - })); - }, -}; - -export default setProviderKind; diff --git a/src/settings.ts b/src/settings.ts index d4471f74..02f2ec81 100644 --- a/src/settings.ts +++ b/src/settings.ts @@ -93,7 +93,6 @@ export interface QuickAddSettings { setProviderModelDiscoveryMode: boolean; migrateProviderApiKeysToSecretStorage: boolean; migrateToMultipleTemplateFolders: boolean; - setProviderKind: boolean; }; } @@ -147,6 +146,5 @@ export const DEFAULT_SETTINGS: QuickAddSettings = { setProviderModelDiscoveryMode: false, migrateProviderApiKeysToSecretStorage: false, migrateToMultipleTemplateFolders: false, - setProviderKind: false, }, };