diff --git a/src/services/naturalLanguageComponentSearchService.test.ts b/src/services/naturalLanguageComponentSearchService.test.ts new file mode 100644 index 000000000..6da0bf7b3 --- /dev/null +++ b/src/services/naturalLanguageComponentSearchService.test.ts @@ -0,0 +1,266 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +import type { ComponentReference } from "@/utils/componentSpec"; + +import { + componentReferenceToCandidate, + NaturalLanguageSearchConfigError, + rerankComponentsByNaturalLanguage, +} from "./naturalLanguageComponentSearchService"; + +const VALID_OPTIONS = { + apiBase: "https://api.example.com/v1", + apiKey: "sk-test", + model: "gpt-4o-mini", +}; + +function mockChatResponse(content: unknown, status = 200) { + return new Response( + JSON.stringify({ + choices: [{ message: { content: JSON.stringify(content) } }], + }), + { + status, + statusText: status === 200 ? "OK" : "Internal Server Error", + }, + ); +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +function parseFetchBody(call: unknown[] | undefined): Record { + const init = call?.[1]; + if ( + typeof init !== "object" || + init === null || + !("body" in init) || + typeof init.body !== "string" + ) { + throw new Error("Expected fetch body to be a string"); + } + const { body } = init; + const parsed: unknown = JSON.parse(body); + if (!isRecord(parsed)) { + throw new Error("Expected fetch body to be an object"); + } + return parsed; +} + +describe("componentReferenceToCandidate", () => { + it("returns null for references without a digest", () => { + const ref: ComponentReference = { + spec: { + name: "no_digest", + inputs: [], + outputs: [], + implementation: { container: { image: "x" } }, + }, + }; + expect(componentReferenceToCandidate(ref)).toBeNull(); + }); + + it("returns null when the reference has no useful metadata", () => { + const ref: ComponentReference = { + digest: "abc", + spec: { + inputs: [], + outputs: [], + implementation: { container: { image: "x" } }, + }, + }; + expect(componentReferenceToCandidate(ref)).toBeNull(); + }); + + it("omits empty inputs/outputs from the candidate", () => { + const ref: ComponentReference = { + digest: "abc", + spec: { + name: "train", + description: "trainer", + inputs: [], + outputs: [], + implementation: { container: { image: "x" } }, + }, + }; + const candidate = componentReferenceToCandidate(ref); + expect(candidate).toEqual({ + id: "abc", + name: "train", + description: "trainer", + }); + }); + + it("includes input/output names when present", () => { + const ref: ComponentReference = { + digest: "abc", + spec: { + name: "train", + description: "", + inputs: [{ name: "dataset" }], + outputs: [{ name: "model" }], + implementation: { container: { image: "x" } }, + }, + }; + expect(componentReferenceToCandidate(ref)).toEqual({ + id: "abc", + name: "train", + description: "", + inputs: ["dataset"], + outputs: ["model"], + }); + }); +}); + +describe("rerankComponentsByNaturalLanguage", () => { + beforeEach(() => { + global.fetch = vi.fn(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it("returns an empty result for an empty query", async () => { + const result = await rerankComponentsByNaturalLanguage( + "", + [{ id: "a", name: "n", description: "d" }], + VALID_OPTIONS, + ); + expect(result.matches).toEqual([]); + expect(global.fetch).not.toHaveBeenCalled(); + }); + + it("returns an empty result when no candidates are provided", async () => { + const result = await rerankComponentsByNaturalLanguage( + "train", + [], + VALID_OPTIONS, + ); + expect(result.matches).toEqual([]); + expect(global.fetch).not.toHaveBeenCalled(); + }); + + it("throws NaturalLanguageSearchConfigError when API base or key is missing", async () => { + await expect( + rerankComponentsByNaturalLanguage( + "train", + [{ id: "a", name: "n", description: "d" }], + { ...VALID_OPTIONS, apiKey: "" }, + ), + ).rejects.toBeInstanceOf(NaturalLanguageSearchConfigError); + + await expect( + rerankComponentsByNaturalLanguage( + "train", + [{ id: "a", name: "n", description: "d" }], + { ...VALID_OPTIONS, apiBase: "" }, + ), + ).rejects.toBeInstanceOf(NaturalLanguageSearchConfigError); + }); + + it("throws NaturalLanguageSearchConfigError when model is missing", async () => { + await expect( + rerankComponentsByNaturalLanguage( + "train", + [{ id: "a", name: "n", description: "d" }], + { ...VALID_OPTIONS, model: "" }, + ), + ).rejects.toBeInstanceOf(NaturalLanguageSearchConfigError); + }); + + it("filters out hallucinated ids the model returned", async () => { + vi.mocked(global.fetch).mockResolvedValue( + mockChatResponse({ + matches: [ + { id: "a", score: 0.9, reason: "best fit" }, + { id: "ghost", score: 0.8, reason: "made up" }, + ], + }), + ); + + const result = await rerankComponentsByNaturalLanguage( + "train", + [{ id: "a", name: "trainer", description: "" }], + VALID_OPTIONS, + ); + expect(result.matches.map((m) => m.id)).toEqual(["a"]); + }); + + it("clamps out-of-range score values into [0, 1]", async () => { + // NaN scores are intentionally not tested here: JSON.stringify({score: NaN}) + // serializes to `null`, which never reaches `normalizeScore` because + // `isValidMatch` rejects it upstream. + vi.mocked(global.fetch).mockResolvedValue( + mockChatResponse({ + matches: [ + { id: "a", score: 1.5, reason: "over" }, + { id: "b", score: -0.4, reason: "under" }, + ], + }), + ); + + const result = await rerankComponentsByNaturalLanguage( + "train", + [ + { id: "a", name: "a", description: "" }, + { id: "b", name: "b", description: "" }, + ], + VALID_OPTIONS, + ); + const byId = Object.fromEntries(result.matches.map((m) => [m.id, m.score])); + expect(byId.a).toBe(1); + expect(byId.b).toBe(0); + }); + + it("returns empty matches when the response shape is wrong, but keeps raw content", async () => { + vi.mocked(global.fetch).mockResolvedValue( + mockChatResponse({ matches: "not an array" }), + ); + + const result = await rerankComponentsByNaturalLanguage( + "train", + [{ id: "a", name: "trainer", description: "" }], + VALID_OPTIONS, + ); + expect(result.matches).toEqual([]); + expect(result.rawContent).toContain("not an array"); + }); + + it("uses max_completion_tokens for gpt-5 / o-series models", async () => { + vi.mocked(global.fetch).mockResolvedValue( + mockChatResponse({ matches: [] }), + ); + + await rerankComponentsByNaturalLanguage( + "train", + [{ id: "a", name: "a", description: "" }], + { ...VALID_OPTIONS, model: "gpt-5-mini" }, + ); + + const call = vi.mocked(global.fetch).mock.calls[0]; + const body = parseFetchBody(call); + expect(body.max_completion_tokens).toBeDefined(); + expect(body.max_tokens).toBeUndefined(); + expect(body.temperature).toBeUndefined(); + }); + + it("uses max_tokens and temperature for non-reasoning models", async () => { + vi.mocked(global.fetch).mockResolvedValue( + mockChatResponse({ matches: [] }), + ); + + await rerankComponentsByNaturalLanguage( + "train", + [{ id: "a", name: "a", description: "" }], + VALID_OPTIONS, + ); + + const call = vi.mocked(global.fetch).mock.calls[0]; + const body = parseFetchBody(call); + expect(body.max_tokens).toBeDefined(); + expect(body.max_completion_tokens).toBeUndefined(); + expect(body.temperature).toBe(0); + }); +}); diff --git a/src/services/naturalLanguageComponentSearchService.ts b/src/services/naturalLanguageComponentSearchService.ts new file mode 100644 index 000000000..0aa2fe841 --- /dev/null +++ b/src/services/naturalLanguageComponentSearchService.ts @@ -0,0 +1,262 @@ +/** + * LLM reranker for component search. + * + * Takes a small candidate set already pre-filtered by the lexical index (see + * `componentSearchIndex.ts`) and asks an LLM to: + * 1. Reorder by best fit to the user's query + * 2. Write a one-sentence reason per result + * + * The LLM is intentionally NOT used for retrieval — that's the lexical index's + * job. Reranking 20 candidates is fast, cheap, and plays to the LLM's actual + * strength: judgment over a small, well-defined set. + */ + +import type { ComponentReference } from "@/utils/componentSpec"; +import { getComponentName } from "@/utils/getComponentName"; +import { isRecord } from "@/utils/typeGuards"; + +/** + * Compact candidate shape sent to the model. Only the fields that inform + * judgment: name, description, i/o names. Implementation/command text is + * already covered by the lexical layer and would just inflate the prompt. + */ +export interface RerankCandidate { + /** Component digest. Used to round-trip the model's response to references. */ + id: string; + name: string; + description: string; + inputs?: string[]; + outputs?: string[]; +} + +interface RerankedMatch { + id: string; + /** Model-provided relevance, clamped to [0, 1]. */ + score: number; + reason: string; +} + +export interface RerankResult { + matches: RerankedMatch[]; + /** Raw model response, kept for debugging. */ + rawContent?: string; +} + +export class NaturalLanguageSearchConfigError extends Error { + constructor(message: string) { + super(message); + this.name = "NaturalLanguageSearchConfigError"; + } +} + +interface RerankOptions { + signal?: AbortSignal; + /** Model id (OpenAI-compatible). Required. */ + model: string; + /** Base URL of an OpenAI-compatible API. Required. */ + apiBase: string; + /** Bearer token. Required. */ + apiKey: string; +} + +/** + * gpt-5 / o-series reasoning models reject `max_tokens` and require + * `max_completion_tokens` instead. Detect by name prefix and use the + * appropriate field. + */ +function usesCompletionTokensParam(model: string): boolean { + return /^(gpt-5|o\d|openai:gpt-5|openai:o\d)/i.test(model); +} + +/** Clamp score to [0, 1] and reject NaN so the UI/sort never sees garbage. */ +function normalizeScore(value: number): number { + if (Number.isNaN(value)) return 0; + if (value < 0) return 0; + if (value > 1) return 1; + return value; +} + +function isValidMatch(parsed: unknown): parsed is RerankedMatch { + return ( + isRecord(parsed) && + typeof parsed.id === "string" && + typeof parsed.score === "number" && + typeof parsed.reason === "string" + ); +} + +function readChatCompletionContent(payload: unknown): string { + if (!isRecord(payload) || !Array.isArray(payload.choices)) return ""; + + const [firstChoice] = payload.choices; + if (!isRecord(firstChoice) || !isRecord(firstChoice.message)) return ""; + + return typeof firstChoice.message.content === "string" + ? firstChoice.message.content + : ""; +} + +function isMatchArray(value: unknown): value is RerankedMatch[] { + return Array.isArray(value) && value.every(isValidMatch); +} + +/** + * Project a hydrated `ComponentReference` into the compact shape we send to + * the model. Returns null when the reference has no usable metadata — those + * would just waste tokens. + */ +export function componentReferenceToCandidate( + reference: ComponentReference, +): RerankCandidate | null { + if (!reference.digest) return null; + + const spec = reference.spec; + const description = spec?.description?.trim() ?? ""; + const hasUsefulMetadata = + Boolean(spec?.name) || + description.length > 0 || + (spec?.inputs?.length ?? 0) > 0 || + (spec?.outputs?.length ?? 0) > 0; + if (!hasUsefulMetadata) return null; + + const inputs = spec?.inputs + ?.map((i) => i.name) + .filter((n): n is string => typeof n === "string" && n.length > 0); + const outputs = spec?.outputs + ?.map((o) => o.name) + .filter((n): n is string => typeof n === "string" && n.length > 0); + + return { + id: reference.digest, + name: getComponentName(reference), + description, + ...(inputs && inputs.length > 0 ? { inputs } : {}), + ...(outputs && outputs.length > 0 ? { outputs } : {}), + }; +} + +function buildSystemPrompt(): string { + return [ + "You are a reranker for an ML pipeline component search.", + "The user gives you a natural-language query and a small list of candidate components that were already retrieved by lexical search.", + "Your job: reorder the candidates by how well they fit the query's intent, and write one short reason per match.", + "Respond with a single JSON object:", + '{ "matches": [ { "id": "", "score": <0..1>, "reason": "" } ] }', + "Rules:", + "- Include every candidate that plausibly matches the query intent.", + "- Drop candidates that are clearly unrelated.", + '- If none of the candidates fit, return { "matches": [] }.', + "- Order matches from highest to lowest score.", + "- Use the exact id strings provided. Do not invent ids.", + "- Keep each reason under 120 characters.", + ].join("\n"); +} + +function buildUserPrompt(query: string, candidates: RerankCandidate[]): string { + // No pretty-printing: indentation adds ~25-30% to the payload for no signal. + return [ + `Query: ${query}`, + "", + "Candidates to rerank:", + JSON.stringify(candidates), + ].join("\n"); +} + +function validateConfig(options: RerankOptions): { + base: string; + key: string; + model: string; +} { + const base = options.apiBase.trim(); + const key = options.apiKey.trim(); + const model = options.model.trim(); + if (!base || !key) { + throw new NaturalLanguageSearchConfigError( + "Configure your API base URL and key in Settings → Agent Configuration to use AI search.", + ); + } + if (!model) { + throw new NaturalLanguageSearchConfigError( + "No model configured. Set one in Settings → Agent Configuration.", + ); + } + return { base: base.replace(/\/+$/, ""), key, model }; +} + +/** + * Rerank lexical candidates against the user's query. Returns an empty result + * when called with no candidates — callers should fall back to the lexical + * ordering in that case. + */ +export async function rerankComponentsByNaturalLanguage( + query: string, + candidates: RerankCandidate[], + options: RerankOptions, +): Promise { + const trimmed = query.trim(); + if (trimmed.length === 0) return { matches: [] }; + if (candidates.length === 0) return { matches: [] }; + + const { base, key, model } = validateConfig(options); + + const response = await fetch(`${base}/chat/completions`, { + method: "POST", + signal: options.signal, + headers: { + "content-type": "application/json", + authorization: `Bearer ${key}`, + }, + body: JSON.stringify({ + model, + // gpt-5 / o-series reject temperature overrides entirely; omit for them. + ...(usesCompletionTokensParam(model) ? {} : { temperature: 0 }), + // Tiny payload now (≤20 candidates × ~150 chars), so the response is + // bounded. Reasoning models burn budget on hidden thinking tokens — + // give them more headroom. + ...(usesCompletionTokensParam(model) + ? { max_completion_tokens: 2000 } + : { max_tokens: 700 }), + response_format: { type: "json_object" }, + messages: [ + { role: "system", content: buildSystemPrompt() }, + { role: "user", content: buildUserPrompt(trimmed, candidates) }, + ], + }), + }); + + if (!response.ok) { + const detail = await response.text().catch(() => ""); + throw new Error( + `LLM proxy returned ${response.status}: ${detail.slice(0, 200) || response.statusText}`, + ); + } + + const payload: unknown = await response.json(); + const rawContent = readChatCompletionContent(payload); + if (!rawContent) { + throw new Error("LLM proxy returned an empty response"); + } + + let parsed: unknown; + try { + parsed = JSON.parse(rawContent); + } catch { + throw new Error( + `Could not parse LLM response as JSON: ${rawContent.slice(0, 200)}`, + ); + } + + const matchesValue = isRecord(parsed) ? parsed.matches : undefined; + if (!isMatchArray(matchesValue)) { + return { matches: [], rawContent }; + } + + // Drop hallucinated ids and clamp scores. + const validIds = new Set(candidates.map((c) => c.id)); + const matches = matchesValue + .filter((m) => validIds.has(m.id)) + .map((m) => ({ ...m, score: normalizeScore(m.score) })) + .sort((a, b) => b.score - a.score); + + return { matches, rawContent }; +}