diff --git a/web/src/components/assistant-ui/mermaid-diagram.test.tsx b/web/src/components/assistant-ui/mermaid-diagram.test.tsx index 1e00b2f50..463d44d85 100644 --- a/web/src/components/assistant-ui/mermaid-diagram.test.tsx +++ b/web/src/components/assistant-ui/mermaid-diagram.test.tsx @@ -1,35 +1,57 @@ -import { describe, expect, it, vi } from 'vitest' -import { render, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { cleanup, render, waitFor } from '@testing-library/react' const mermaidMocks = vi.hoisted(() => ({ initializeMock: vi.fn(), - renderMock: vi.fn().mockResolvedValue({ - svg: '' - }) + parseMock: vi.fn(), + renderMock: vi.fn(), + setParseErrorHandlerMock: vi.fn(), })) vi.mock('mermaid', () => ({ default: { initialize: mermaidMocks.initializeMock, + parse: mermaidMocks.parseMock, render: mermaidMocks.renderMock, + setParseErrorHandler: mermaidMocks.setParseErrorHandlerMock, } })) import { MermaidDiagram } from '@/components/assistant-ui/mermaid-diagram' import { MARKDOWN_COMPONENTS_BY_LANGUAGE } from '@/components/assistant-ui/markdown-text' +function renderMermaid(code: string) { + return render( +
,
+                Code: (props) => ,
+            }}
+        />
+    )
+}
+
 describe('MermaidDiagram', () => {
+    beforeEach(() => {
+        mermaidMocks.initializeMock.mockClear()
+        mermaidMocks.setParseErrorHandlerMock.mockClear()
+        mermaidMocks.parseMock.mockReset()
+        mermaidMocks.parseMock.mockResolvedValue({ diagramType: 'flowchart-v2' })
+        mermaidMocks.renderMock.mockReset()
+        mermaidMocks.renderMock.mockResolvedValue({
+            svg: ''
+        })
+    })
+
+    afterEach(() => {
+        cleanup()
+        document.documentElement.removeAttribute('data-theme')
+    })
+
     it('is wired into the shared markdown language overrides and renders svg output', async () => {
-        render(
-             B'}
-                language="mermaid"
-                components={{
-                    Pre: (props) => 
,
-                    Code: (props) => ,
-                }}
-            />
-        )
+        renderMermaid('graph TD\nA --> B')
 
         await waitFor(() => {
             const diagram = document.querySelector('[data-mermaid-diagram][data-rendered="true"]')
@@ -39,9 +61,47 @@ describe('MermaidDiagram', () => {
 
         expect(mermaidMocks.initializeMock).toHaveBeenCalled()
         expect(mermaidMocks.initializeMock).toHaveBeenCalledWith(expect.objectContaining({
-            securityLevel: 'strict'
+            securityLevel: 'strict',
+            suppressErrorRendering: true,
         }))
+        expect(mermaidMocks.parseMock).toHaveBeenCalledWith('graph TD\nA --> B', { suppressErrors: true })
         expect(mermaidMocks.renderMock).toHaveBeenCalledWith(expect.stringContaining('mermaid-'), 'graph TD\nA --> B')
         expect(MARKDOWN_COMPONENTS_BY_LANGUAGE.mermaid.SyntaxHighlighter).toBe(MermaidDiagram)
     })
+
+    it('falls back to source and suppresses Mermaid parse-error side effects for invalid syntax', async () => {
+        document.documentElement.dataset.theme = 'dark'
+        mermaidMocks.parseMock.mockResolvedValueOnce(false)
+
+        renderMermaid('graph TD\nA --')
+
+        await waitFor(() => {
+            const fallback = document.querySelector('.aui-mermaid-fallback')
+            expect(fallback).toBeTruthy()
+            expect(fallback?.textContent).toBe('graph TD\nA --')
+        })
+
+        expect(mermaidMocks.parseMock).toHaveBeenCalledWith('graph TD\nA --', { suppressErrors: true })
+        expect(mermaidMocks.renderMock).not.toHaveBeenCalled()
+        expect(mermaidMocks.setParseErrorHandlerMock).toHaveBeenCalled()
+    })
+
+    it('falls back to source and asks Mermaid not to inject its own error SVG when render throws', async () => {
+        mermaidMocks.renderMock.mockRejectedValueOnce(new Error('render failed'))
+        const code = 'gantt\ndateFormat YYYY-MM-DD\nsection A\nTask :a, 2024-01-01'
+
+        renderMermaid(code)
+
+        await waitFor(() => {
+            const fallback = document.querySelector('.aui-mermaid-fallback')
+            expect(fallback).toBeTruthy()
+            expect(fallback?.textContent).toBe(code)
+        })
+
+        expect(mermaidMocks.renderMock).toHaveBeenCalled()
+        expect(mermaidMocks.initializeMock).toHaveBeenCalledWith(expect.objectContaining({
+            suppressErrorRendering: true,
+        }))
+    })
+
 })
diff --git a/web/src/components/assistant-ui/mermaid-diagram.tsx b/web/src/components/assistant-ui/mermaid-diagram.tsx
index b942b62c2..e2b47583c 100644
--- a/web/src/components/assistant-ui/mermaid-diagram.tsx
+++ b/web/src/components/assistant-ui/mermaid-diagram.tsx
@@ -21,9 +21,12 @@ async function ensureMermaid(theme: 'light' | 'dark') {
     const mermaid = await getMermaid()
     if (initializedTheme === theme) return mermaid
 
+    mermaid.setParseErrorHandler(() => undefined)
+
     mermaid.initialize({
         startOnLoad: false,
         securityLevel: 'strict',
+        suppressErrorRendering: true,
         theme: theme === 'dark' ? 'dark' : 'default',
         themeVariables: theme === 'dark'
             ? {
@@ -101,6 +104,14 @@ export function MermaidDiagram(props: SyntaxHighlighterProps) {
         const render = async () => {
             try {
                 const mermaid = await ensureMermaid(theme)
+                const isValid = await mermaid.parse(props.code, { suppressErrors: true })
+                if (cancelled) return
+                if (!isValid) {
+                    setSvg(null)
+                    setRenderError(true)
+                    return
+                }
+
                 const result = await mermaid.render(`mermaid-${id}`, props.code)
                 if (cancelled) return
                 setSvg(result.svg)