/*
 * Returns up to n "message" vertices for a conversation_id, highest epoch_added first.
 *
 * The heap is bounded to n entries and ordered by ep (epoch_added) descending,
 * so only the newest n exchanges are kept while scanning.
 */
CREATE OR REPLACE DISTRIBUTED QUERY get_last_n_memory_exchanges(STRING conv_id, UINT n=4) SYNTAX V2 {
  // A tuple type needs an explicit field list and the heap needs its element
  // type; both were missing, so the query could not be compiled.
  // NOTE(review): the vertex field name "v" is reconstructed — confirm it
  // matches what the Python reader expects in the PRINT output.
  TYPEDEF TUPLE<VERTEX v, UINT ep> MsgRow;
  HeapAccum<MsgRow>(n, ep DESC) @@heap;

  seeds = {message.*};
  res = SELECT m FROM seeds:m
        WHERE m.conversation_id == conv_id
        ACCUM @@heap += MsgRow(m, m.epoch_added);

  PRINT @@heap AS rows;
}
+ */ + +CREATE OR REPLACE DISTRIBUTED QUERY list_conversations_for_user(STRING uid) SYNTAX V2 { + seeds = {conversation.*}; + res = SELECT c FROM seeds:c WHERE c.user_id == uid; + PRINT res AS rows; +} diff --git a/common/gsql/memory/ListMessagesForConversation.gsql b/common/gsql/memory/ListMessagesForConversation.gsql new file mode 100644 index 0000000..333fde5 --- /dev/null +++ b/common/gsql/memory/ListMessagesForConversation.gsql @@ -0,0 +1,11 @@ +/* + * Copyright (c) 2024-2026 TigerGraph, Inc. + * + * List message vertices for a conversation_id attribute (full thread load). + */ + +CREATE OR REPLACE DISTRIBUTED QUERY list_messages_for_conversation(STRING conv_id) SYNTAX V2 { + seeds = {message.*}; + res = SELECT m FROM seeds:m WHERE m.conversation_id == conv_id; + PRINT res AS rows; +} diff --git a/common/gsql/memory/ListMessagesForMemory.gsql b/common/gsql/memory/ListMessagesForMemory.gsql new file mode 100644 index 0000000..6366c79 --- /dev/null +++ b/common/gsql/memory/ListMessagesForMemory.gsql @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2024-2026 TigerGraph, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * + * List message vertices for a conversation_id (application id on message rows). + * Python callers MUST strip tracelog — memory pipeline must not use tracelog. + */ + +CREATE OR REPLACE DISTRIBUTED QUERY list_messages_for_memory(STRING conv_id) SYNTAX V2 { + seeds = {message.*}; + res = SELECT m FROM seeds:m WHERE m.conversation_id == conv_id; + PRINT res AS rows; +} diff --git a/common/gsql/memory/Memory_Schema.gsql b/common/gsql/memory/Memory_Schema.gsql new file mode 100644 index 0000000..aac1336 --- /dev/null +++ b/common/gsql/memory/Memory_Schema.gsql @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024-2026 TigerGraph, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * + * Chat memory on the same GraphRAG graph (Explore Graph: vertex types "conversation", "message"). + * - conversation: one vertex per thread. + * - message: one vertex per Q&A (user_content + system_content on the same row). + * - epoch_*: UTC epoch SECONDS (UINT), same style as SupportAI DocumentChunk. + */ + +CREATE SCHEMA_CHANGE JOB add_graphrag_chat_memory { + ADD VERTEX conversation( + PRIMARY_ID conversation_id STRING, + user_id STRING, + epoch_added UINT, + epoch_processed UINT + ) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; + + ADD VERTEX message( + PRIMARY_ID message_id STRING, + conversation_id STRING, + user_content STRING, + system_content STRING, + epoch_added UINT, + tracelog STRING + ) WITH STATS="OUTDEGREE_BY_EDGETYPE", PRIMARY_ID_AS_ATTRIBUTE="true"; + + ADD DIRECTED EDGE has_message(FROM conversation, TO message) WITH REVERSE_EDGE="reverse_has_message"; +} diff --git a/common/llm_services/base_llm.py b/common/llm_services/base_llm.py index bf24588..2a7e5cf 100644 --- a/common/llm_services/base_llm.py +++ b/common/llm_services/base_llm.py @@ -407,6 +407,15 @@ def route_response_prompt(self): Route the user question to one of: `functions`, `vectorstore`, or `history`. 
+## CRITICAL — Route to `history` FIRST when: +The conversation history is non-empty AND the question is about this conversation itself: +- What questions or messages were exchanged ("what did I ask", "previous questions", "earlier questions") +- What was said, discussed, or answered previously in this session +- Recalling, summarising, or listing prior exchanges in this chat +- Anything referencing "previous", "earlier", "before", "last time", "you said", "I asked", "we discussed", "prior" in the context of THIS conversation + +If the conversation history is EMPTY, do NOT route to `history` — fall through to vectorstore or functions instead. + ## Routing - **`history`**: questions similar to previous ones, or that reference earlier answers / responses, or that refer to the same entities mentioned in a previous answer. - **`vectorstore`**: questions best answered by text documents. diff --git a/common/memory/__init__.py b/common/memory/__init__.py new file mode 100644 index 0000000..278aa81 --- /dev/null +++ b/common/memory/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024-2026 TigerGraph, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. diff --git a/common/memory/tg_memory.py b/common/memory/tg_memory.py new file mode 100644 index 0000000..e3c7f47 --- /dev/null +++ b/common/memory/tg_memory.py @@ -0,0 +1,920 @@ +# Copyright (c) 2024-2026 TigerGraph, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +TigerGraph chat memory (same graph as GraphRAG). + +Vertex types in GraphStudio: **conversation**, **message** (see Memory_Schema.gsql). + +Epoch fields use UTC epoch SECONDS as UINT (SupportAI-style). + +message.system_content holds the model/RAG reply; map to LLM role \"assistant\" when building prompts. +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import time +import uuid +from datetime import datetime, timezone +from typing import Any + +from pyTigerGraph import TigerGraphConnection + +from common.config import get_graphrag_config, graphrag_config + +logger = logging.getLogger(__name__) + +# `ls` output line for this vertex type (TigerGraph lists lowercase as defined) +_SCHEMA_MARKER = "- VERTEX message" +_QUERY_NAME = "get_last_n_memory_exchanges" +_QUERY_LIST_CONVOS = "list_conversations_for_user" +_QUERY_LIST_MSGS = "list_messages_for_conversation" + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + +_STARTUP_MAX_WAIT_S = 300 +_STARTUP_RETRY_INTERVAL_S = 5 +_UUID_HEX32 = re.compile(r"^[0-9a-fA-F]{32}$") + + +def _gsql_ls_contains_vertex_type(ls_output: str, vertex_type: str) -> bool: + """True if ``gsql 'USE GRAPH g; ls'`` output lists the given vertex type.""" + return bool( + re.search( + rf"-\s*VERTEX\s+{re.escape(vertex_type)}\b", + ls_output or "", + flags=re.IGNORECASE, + ) + ) + + +def _gsql_find_query_ls_line(ls_output: str, query_name: str) -> str | None: + """Return the ``ls`` line for a query by name (e.g. ``- foo(...) 
def _gsql_ls_query_installation_state(ls_output: str, query_name: str) -> str:
    """
    Classify a query's catalog state from ``gsql ls`` output.

    The decision is based on the last parenthesised group on the query's
    ``ls`` line, lower-cased. Returns one of ``missing``, ``installed``,
    ``needs_reinstall``, ``pending_install`` or ``legacy_no_status``. The
    last value is used when the trailing group looks like the parameter
    signature rather than a status, and callers treat it as already installed.
    """
    entry = _gsql_find_query_ls_line(ls_output, query_name)
    trailing = None if entry is None else re.search(r"\(([^)]+)\)\s*$", entry.strip())
    if trailing is None:
        # Either the query is not listed or its line has no trailing "(...)".
        return "missing"

    status = trailing.group(1).strip().lower()

    if status.startswith("installed"):
        return "installed"

    if "pending" in status and "install" in status.replace(" ", ""):
        return "pending_install"

    is_broken = (
        status in ("draft", "deprecated", "disabled")
        or "failed" in status
        or "compilation" in status
    )
    if is_broken:
        return "needs_reinstall"

    # Trailing ``(...)`` is the signature, not status — older / minimal ls output
    looks_like_signature = (
        re.match(
            r"^(string|int|uint|float|double|bool|datetime|vertex|edge|set|list|bag)\b",
            status,
        )
        is not None
        or "," in status
        or " " in status
    )
    return "legacy_no_status" if looks_like_signature else "needs_reinstall"
refreshes ``ls`` and returns it. + Caller must have USE GRAPH context via conn.gsql prefixes. + """ + qpath = _gsql_path(*relative_path) + with open(qpath, "r", encoding="utf-8") as f: + q_body = f.read() + q_res = conn.gsql(f"USE GRAPH {graphname}\nBEGIN\n{q_body}\nEND\n") + out.append(f"memory query create {query_name}: {q_res}") + inst = conn.gsql(f"USE GRAPH {graphname}\nINSTALL QUERY {query_name}\n") + out.append(f"memory query install {query_name}: {inst}") + return conn.gsql(f"USE GRAPH {graphname}\n ls") + + +def app_conversation_id_from_vertex_pk(vertex_pk: str) -> str: + """ + Reverse ``_conversation_vertex_primary_id``: TG stores hyphenless UUID32 as PK; + the UI/API use standard UUID strings. + """ + pk = (vertex_pk or "").strip() + if len(pk) == 32 and _UUID_HEX32.match(pk): + return f"{pk[:8]}-{pk[8:12]}-{pk[12:16]}-{pk[16:20]}-{pk[20:]}" + return pk + + +def _epoch_u_to_iso(epoch_u: Any) -> str: + try: + ts = int(epoch_u) + except (TypeError, ValueError): + return "" + if ts <= 0: + return "" + return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat().replace("+00:00", "Z") + + +def _conversation_vertex_primary_id(app_conversation_id: str) -> str: + """ + Maps application conversation_id (often a UUID with hyphens) to a valid + TigerGraph PRIMARY_ID for the **conversation** vertex type. + + TigerGraph rejects some primary keys for this type (e.g. strings containing + '-'), which leaves **message** vertices without a linked conversation and + breaks Explore Graph. **message** rows still store the original + ``conversation_id`` attribute for GSQL lookup. 
def install_memory_schema_for_all_graphs_at_startup() -> None:
    """
    Wait until TigerGraph accepts ``listGraphs``, then run ``init_memory_schema`` on
    each graph using ``db_config`` credentials.

    Skipped when ``db_config`` has no username/password or when
    ``graphrag_config["tg_memory_schema_on_startup"]`` is false. Readiness is
    retried every ``_STARTUP_RETRY_INTERVAL_S`` seconds until
    ``_STARTUP_MAX_WAIT_S`` seconds of waiting have accumulated. A failure on
    one graph is logged and does not stop the remaining graphs.
    """
    from common.config import db_config

    if db_config.get("username") is None or db_config.get("password") is None:
        logger.warning(
            "[TG_MEMORY] db_config username/password missing; skipping startup memory schema."
        )
        return

    if not graphrag_config.get("tg_memory_schema_on_startup", True):
        logger.info("[TG_MEMORY] tg_memory_schema_on_startup is false; skipping startup schema install.")
        return

    def _connect(api_token: str | None = None) -> TigerGraphConnection:
        """Build a graph-less connection, optionally authenticated by token."""
        kwargs: dict[str, Any] = {
            "host": db_config["hostname"],
            "graphname": "",
            "username": db_config["username"],
            "password": db_config["password"],
            "restppPort": db_config.get("restppPort", "9000"),
            "gsPort": db_config.get("gsPort", "14240"),
        }
        if api_token is not None:
            kwargs["apiToken"] = api_token
        return TigerGraphConnection(**kwargs)

    elapsed = 0
    conn: TigerGraphConnection | None = None
    graphs: list[str] = []
    # ``conn`` is assigned before getToken()/listGraphs() can fail, so it cannot
    # signal readiness by itself; track success explicitly.
    ready = False

    while elapsed < _STARTUP_MAX_WAIT_S:
        try:
            conn = _connect()
            if db_config.get("getToken"):
                conn = _connect(conn.getToken()[0])
            conn.customizeHeader(
                timeout=int(db_config.get("default_timeout", 300)) * 1000,
                responseSize=5000000,
            )
            graph_list = conn.listGraphs()
            graphs = [g["graphName"] for g in graph_list if "graphName" in g]
            ready = True
            break
        except Exception as e:
            logger.warning(
                "[TG_MEMORY] TigerGraph not ready (%s); retrying in %ss (%ss/%ss)",
                e,
                _STARTUP_RETRY_INTERVAL_S,
                elapsed,
                _STARTUP_MAX_WAIT_S,
            )
            time.sleep(_STARTUP_RETRY_INTERVAL_S)
            elapsed += _STARTUP_RETRY_INTERVAL_S

    if not ready or conn is None:
        logger.error(
            "[TG_MEMORY] No connection after %ss; memory schema not installed at startup.",
            _STARTUP_MAX_WAIT_S,
        )
        return

    _apply_long_gsql_timeout(conn)

    if not graphs:
        logger.info("[TG_MEMORY] No graphs found yet; memory schema install skipped (empty cluster).")
        return

    for graphname in graphs:
        try:
            init_memory_schema(conn, graphname)
        except Exception:
            logger.warning(
                "[TG_MEMORY] init_memory_schema failed for graph=%s",
                graphname,
                exc_info=True,
            )
def _memory_ls_best_effort(conn: TigerGraphConnection, graphname: str, fallback: str) -> str:
    """Return fresh ``ls`` output for ``graphname``, or ``fallback`` if the call fails."""
    try:
        return conn.gsql(f"USE GRAPH {graphname}\n ls")
    except Exception:
        logger.debug(
            "[TG_MEMORY] gsql ls refresh failed graph=%s; keeping previous listing",
            graphname,
            exc_info=True,
        )
        return fallback


def init_memory_schema(conn: TigerGraphConnection, graphname: str) -> str:
    """
    Idempotently add the chat-memory schema and install its read queries.

    Steps, each best-effort and reported in the returned summary:
      1. Run the ``add_graphrag_chat_memory`` schema-change job unless the
         ``message`` vertex type is already listed by ``ls``.
      2. Install ``get_last_n_memory_exchanges`` unless ``ls`` reports it as
         installed or pending; a draft/broken copy is dropped first.
      3. Do the same for the conversation/message listing queries.

    Args:
        conn: Connection used for all ``gsql`` calls; its HTTP timeout is raised first.
        graphname: Graph to operate on.

    Returns:
        A newline-joined, human-readable summary of what was done. When the
        initial ``ls`` fails, a single-line failure message is returned instead.
    """
    _apply_long_gsql_timeout(conn)
    try:
        current = conn.gsql(f"USE GRAPH {graphname}\n ls")
    except Exception:
        logger.warning("[TG_MEMORY] gsql ls failed graph=%s", graphname, exc_info=True)
        return f"init_memory_schema: USE GRAPH {graphname}; ls failed (see logs)"
    out: list[str] = []

    # Use the word-boundary aware matcher: a plain substring test would also
    # accept a longer vertex name that merely starts with "message".
    if _gsql_ls_contains_vertex_type(current, "message"):
        out.append("memory schema: already present")
    else:
        path = _gsql_path("memory", "Memory_Schema.gsql")
        try:
            with open(path, "r", encoding="utf-8") as f:
                schema = f.read()
            res = conn.gsql(
                f"USE GRAPH {graphname}\n{schema}\nRUN SCHEMA_CHANGE JOB add_graphrag_chat_memory"
            )
            out.append(f"memory schema job: {res}")
            current = conn.gsql(f"USE GRAPH {graphname}\n ls")
        except Exception:
            logger.warning(
                "[TG_MEMORY] base memory schema job failed graph=%s",
                graphname,
                exc_info=True,
            )
            out.append("memory schema job: FAILED (see logs)")
            current = _memory_ls_best_effort(conn, graphname, current)

    _nq_state = _gsql_ls_query_installation_state(current, _QUERY_NAME)
    if _nq_state in ("installed", "legacy_no_status"):
        out.append("memory query: already present")
    elif _nq_state == "pending_install":
        out.append("memory query: pending install; skipping until catalog settles")
    else:
        if _nq_state == "needs_reinstall":
            _gsql_drop_query_best_effort(conn, graphname, _QUERY_NAME)
            out.append("memory query: dropped draft/broken get_last_n_memory_exchanges for reinstall")
        try:
            current = _install_named_query_from_file(
                conn, graphname, _QUERY_NAME, ("memory", "GetLastNMemoryExchanges.gsql"), out
            )
        except Exception:
            logger.warning(
                "[TG_MEMORY] install get_last_n_memory_exchanges failed graph=%s",
                graphname,
                exc_info=True,
            )
            out.append("memory query get_last_n_memory_exchanges: FAILED (see logs)")
            current = _memory_ls_best_effort(conn, graphname, current)

    _EXTRA_QUERIES: tuple[tuple[str, str], ...] = (
        (_QUERY_LIST_CONVOS, "ListConversationsForUser.gsql"),
        (_QUERY_LIST_MSGS, "ListMessagesForConversation.gsql"),
    )
    for qinst, qfile in _EXTRA_QUERIES:
        _ex_state = _gsql_ls_query_installation_state(current, qinst)
        if _ex_state in ("installed", "legacy_no_status"):
            out.append(f"memory query {qinst}: already present")
            continue
        if _ex_state == "pending_install":
            out.append(f"memory query {qinst}: pending install; skipping until catalog settles")
            continue
        if _ex_state == "needs_reinstall":
            _gsql_drop_query_best_effort(conn, graphname, qinst)
            out.append(f"memory query {qinst}: dropped draft/broken for reinstall")
        try:
            current = _install_named_query_from_file(
                conn, graphname, qinst, ("memory", qfile), out
            )
        except Exception:
            logger.warning(
                "[TG_MEMORY] install query %s (%s) failed for graph=%s",
                qinst,
                qfile,
                graphname,
                exc_info=True,
            )
            out.append(f"memory query {qinst}: FAILED (see logs)")
            current = _memory_ls_best_effort(conn, graphname, current)

    summary = "\n".join(out)
    logger.info("init_memory_schema for %s: %s", graphname, summary)
    return summary
def ensure_conversation_shell_for_ui(
    conn: TigerGraphConnection,
    graphname: str,
    conversation_id: str,
    user_id: str,
) -> None:
    """
    Upsert an empty **conversation** vertex as soon as the UI thread id exists.

    Lets ``GET /ui/user`` and ``GET /ui/conversation/{id}`` work **before** the first
    assistant reply (which would otherwise be the first TG write).

    ``epoch_added`` is written only when the vertex is not found, so calling
    this again for an existing thread does not reset its creation time;
    ``user_id`` and ``epoch_processed`` are always written. Does nothing when
    memory is disabled for the graph; failures are logged, never raised.
    """
    if not tg_memory_enabled(graphname):
        return
    try:
        init_memory_schema(conn, graphname)
        ts = int(time.time())
        conv_vertex_id = _conversation_vertex_primary_id(conversation_id)

        # Same existence probe as write_exchange_to_tg_memory: a failed lookup
        # is treated as "not present" so the shell is still created.
        try:
            df = conn.getVertexDataFrameById("conversation", conv_vertex_id)
            exists = df is not None and len(df) > 0
        except Exception:
            exists = False

        attributes: dict[str, Any] = {"user_id": user_id}
        if not exists:
            attributes["epoch_added"] = ts
        attributes["epoch_processed"] = ts

        conn.upsertVertex(
            "conversation",
            conv_vertex_id,
            attributes=attributes,
        )
        logger.debug(
            "tg_memory: ensured conversation shell conv_vertex_id=%s graph=%s",
            conv_vertex_id,
            graphname,
        )
    except Exception:
        logger.warning(
            "tg_memory: ensure_conversation_shell_for_ui failed conv=%s graph=%s",
            conversation_id,
            graphname,
            exc_info=True,
        )
rows`` yields rows like + ``{"v_id": ..., "v_type": ..., "attributes": {...}}`` directly. Older + SELECT-into-tuple patterns wrap them as ``{"c": {...}}`` / ``{"v": {...}}``; + handle both. + """ + if not isinstance(item, dict): + return None + if "v_id" in item or "primary_id" in item or "v_type" in item: + return item + for key in ("c", "m", "res", "v"): + v = item.get(key) + if isinstance(v, dict): + return v + return None + + +def _attrs_from_vertex_row(v: dict[str, Any]) -> dict[str, Any]: + attr = v.get("attributes") + if isinstance(attr, dict): + return attr + out: dict[str, Any] = {} + for k in ( + "user_id", + "epoch_added", + "epoch_processed", + "conversation_id", + "user_content", + "system_content", + "tracelog", + ): + if k in v and k != "attributes": + out[k] = v[k] + return out + + +def _primary_id_from_vertex_row(v: dict[str, Any]) -> str | None: + pid = v.get("id") or v.get("primary_id") or v.get("v_id") + if pid is None: + return None + return str(pid) + + +def list_conversations_for_user( + conn: TigerGraphConnection, + graphname: str, + user_id: str, +) -> list[dict[str, Any]]: + """ + Return conversation summaries for the sidebar (JSON-serializable dicts). 
+ """ + if not tg_memory_enabled(graphname): + return [] + init_memory_schema(conn, graphname) + try: + raw = conn.runInstalledQuery( + _QUERY_LIST_CONVOS, + params={"uid": user_id}, + ) + except Exception: + logger.warning("list_conversations_for_user failed", exc_info=True) + return [] + + rows: list[dict[str, Any]] = [] + for block in raw or []: + if "rows" not in block: + continue + for item in block["rows"]: + vrow = _vertex_row_from_print_item(item) if isinstance(item, dict) else None + if not vrow: + continue + pk = _primary_id_from_vertex_row(vrow) + if not pk: + continue + attrs = _attrs_from_vertex_row(vrow) + ea = attrs.get("epoch_added") + ep = attrs.get("epoch_processed") + cid = app_conversation_id_from_vertex_pk(pk) + ts_create = _epoch_u_to_iso(ea) + ts_update = _epoch_u_to_iso(ep) or ts_create + rows.append( + { + "conversation_id": cid, + "user_id": attrs.get("user_id") or user_id, + "create_ts": ts_create, + "update_ts": ts_update, + "name": "", + } + ) + + rows.sort( + key=lambda r: r.get("update_ts") or r.get("create_ts") or "", + reverse=True, + ) + return rows + + +def list_messages_sorted_by_epoch( + conn: TigerGraphConnection, + graphname: str, + conversation_id: str, +) -> list[dict[str, Any]]: + """ + Raw TG **message** rows for one conversation (application conversation_id attribute). + Sorted ascending by epoch_added. 
def conversation_rows_to_ui_messages(
    conversation_id: str,
    rows: list[dict[str, Any]],
    *,
    model_name: str = "unknown",
) -> list[dict[str, Any]]:
    """
    Expand each TG exchange vertex into user + system messages (legacy SQLite / UI shape).

    For every row two dicts are appended: a ``role="user"`` message whose id is
    a UUIDv5 derived from the row's ``message_id``, followed by a
    ``role="system"`` message carrying the original ``message_id``. Each user
    message's ``parent_id`` is the previous row's ``message_id`` (``None`` for
    the first row).

    ``query_sources``, ``answered_question`` and ``response_time`` come from
    the row's ``tracelog`` JSON when it parses to a dict. An unparsable
    tracelog or a non-numeric ``response_time`` falls back to the defaults
    instead of raising, so one bad row cannot break loading the thread.

    Args:
        conversation_id: Value copied into every emitted message.
        rows: Exchange rows; each must contain ``message_id``.
        model_name: Value for the ``model`` field of every message.
    """
    out: list[dict[str, Any]] = []
    prev_assistant_id: str | None = None
    for row in rows:
        mid = row["message_id"]
        user_mid = str(uuid.uuid5(uuid.NAMESPACE_URL, f"graphrag:{mid}:user"))
        ts = _epoch_u_to_iso(row.get("epoch_added"))
        tl_raw = row.get("tracelog") or ""
        qs: dict[str, Any] | None = None
        answered = False
        resp_time = 0.0
        if isinstance(tl_raw, str) and tl_raw.strip():
            try:
                tl = json.loads(tl_raw)
            except json.JSONDecodeError:
                tl = None
            if isinstance(tl, dict):
                qs = tl.get("query_sources")
                answered = bool(tl.get("answered_question", False))
                try:
                    resp_time = float(tl.get("response_time") or 0.0)
                except (TypeError, ValueError):
                    # e.g. "n/a" or a nested object stored under response_time
                    resp_time = 0.0

        out.append(
            {
                "conversation_id": conversation_id,
                "message_id": user_mid,
                "parent_id": prev_assistant_id,
                "model": model_name,
                "content": row.get("user_content") or "",
                "role": "user",
                "response_time": 0.0,
                "answered_question": False,
                "response_type": "history",
                "query_sources": None,
                "create_ts": ts,
                "update_ts": ts,
            }
        )
        out.append(
            {
                "conversation_id": conversation_id,
                "message_id": mid,
                "parent_id": user_mid,
                "model": model_name,
                "content": row.get("system_content") or "",
                "role": "system",
                "response_time": resp_time,
                "answered_question": answered,
                "response_type": "inquiryai",
                "query_sources": qs if isinstance(qs, dict) else {},
                "create_ts": ts,
                "update_ts": ts,
                # Original user message that produced this assistant reply,
                # so the Trace page can show "Original Query" for history items
                # without an extra round-trip.
                "user_query": row.get("user_content") or "",
            }
        )
        prev_assistant_id = mid
    return out
def delete_conversation_thread(
    conn: TigerGraphConnection,
    graphname: str,
    conversation_id: str,
) -> int:
    """
    Remove a whole chat thread: its **message** vertices first, then the
    **conversation** vertex itself.

    Returns the number of message ids that were submitted for deletion, or 0
    when memory is disabled for the graph. A failing delete call is logged and
    re-raised.
    """
    if not tg_memory_enabled(graphname):
        return 0
    init_memory_schema(conn, graphname)

    message_ids = [
        row["message_id"]
        for row in list_messages_sorted_by_epoch(conn, graphname, conversation_id)
    ]

    def _purge(vertex_type: str, target: Any, failure_msg: str) -> None:
        # Permanent delete; log with traceback, then let the caller see the error.
        try:
            conn.delVerticesById(vertex_type, target, permanent=True)
        except Exception:
            logger.warning(failure_msg, exc_info=True)
            raise

    if message_ids:
        _purge("message", message_ids, "delete messages failed")
    _purge(
        "conversation",
        _conversation_vertex_primary_id(conversation_id),
        "delete conversation vertex failed",
    )
    return len(message_ids)
+ """ + try: + init_memory_schema(conn, graphname) + df = conn.getVertexDataFrameById("message", message_id) + if df is None or len(df) == 0: + return None + + row = df.iloc[0] + raw = row.get("tracelog", "") + if not isinstance(raw, str) or not raw.strip(): + return None + + parsed = json.loads(raw) + if isinstance(parsed, dict): + return parsed + return {"tracelog": parsed} + except Exception: + logger.warning( + "tg_memory: failed to fetch tracelog message_id=%s graph=%s", + message_id, + graphname, + exc_info=True, + ) + return None diff --git a/docker-compose.yml b/docker-compose.yml index 1034ead..d0f38d3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,14 +10,15 @@ services: - 8000:8000 depends_on: - graphrag-ecc - - chat-history -# - tigergraph + - tigergraph environment: SERVER_CONFIG: "/code/configs/server_config.json" LOGLEVEL: "INFO" USE_CYPHER: "true" volumes: - ./configs/:/code/configs + - ./trace_logs/:/code/trace_logs + - ./common:/code/common graphrag-ecc: image: tigergraph/graphrag-ecc:latest @@ -35,21 +36,6 @@ services: volumes: - ./configs/:/code/configs - chat-history: - image: tigergraph/chat-history:latest - platform: linux/amd64 - container_name: chat-history - build: - context: chat-history/ - dockerfile: Dockerfile - ports: - - 8002:8002 - environment: - CONFIG_FILES: "/configs/server_config.json" - LOGLEVEL: "INFO" - volumes: - - ./configs/:/configs - graphrag-ui: image: tigergraph/graphrag-ui:latest platform: linux/amd64 @@ -75,14 +61,14 @@ services: - graphrag-ui - graphrag -# tigergraph: -# image: tigergraph/community:4.2.2 -# container_name: tigergraph -# platform: linux/amd64 -# ports: -# - "14240:14240" -# volumes: -# - tigergraph_data:/home/tigergraph/tigergraph/data -# -# volumes: -# tigergraph_data: + tigergraph: + image: tigergraph/community:4.2.2 + platform: linux/amd64 + container_name: tigergraph + ports: + - "14240:14240" + volumes: + - tigergraph_data:/home/tigergraph/tigergraph/data + +volumes: + 
tigergraph_data: diff --git a/graphrag-ui/src/actions/ActionProvider.tsx b/graphrag-ui/src/actions/ActionProvider.tsx index c73c182..9d8fcb1 100644 --- a/graphrag-ui/src/actions/ActionProvider.tsx +++ b/graphrag-ui/src/actions/ActionProvider.tsx @@ -190,6 +190,33 @@ const ActionProvider: React.FC = ({ } }, [createChatBotMessage, createClientMessage, setState]); + // Listen for message-delete events fired by CustomChatMessage and remove + // both the bot answer and the preceding user question from chatbot state. + useEffect(() => { + const handler = (e: Event) => { + const { msgId } = (e as CustomEvent).detail; + setState((prev: any) => { + const msgs: any[] = prev.messages || []; + const botIdx = msgs.findIndex( + (m: any) => + m.message?.message_id === msgId || + m.message?.messageId === msgId + ); + if (botIdx === -1) return prev; + const toRemove = new Set([botIdx]); + if (botIdx > 0 && msgs[botIdx - 1]?.type === "user") { + toRemove.add(botIdx - 1); + } + return { + ...prev, + messages: msgs.filter((_: any, i: number) => !toRemove.has(i)), + }; + }); + }; + window.addEventListener("graphrag:messageDeleted", handler); + return () => window.removeEventListener("graphrag:messageDeleted", handler); + }, [setState]); + // eslint-disable-next-line // @ts-ignore const queryGraphragWs2 = useCallback((msg: string) => { diff --git a/graphrag-ui/src/components/CustomChatMessage.tsx b/graphrag-ui/src/components/CustomChatMessage.tsx index 43937ef..9fc1acc 100755 --- a/graphrag-ui/src/components/CustomChatMessage.tsx +++ b/graphrag-ui/src/components/CustomChatMessage.tsx @@ -176,6 +176,63 @@ export const CustomChatMessage: FC = ({ const [showTableVis, setShowTableVis] = useState(false); const [traceMessageId, setTraceMessageId] = useState(null); const [alert, alertDialog] = useAlert(); + const [confirmDelete, setConfirmDelete] = useState(false); + const [isDeleted, setIsDeleted] = useState(false); + + const handleDeleteMessage = async (msgId: string) => { + const creds = 
sessionStorage.getItem("creds"); + const graphname = sessionStorage.getItem("selectedGraph"); + if (!creds || !graphname || !msgId) return; + try { + const res = await fetch( + `/ui/message/${encodeURIComponent(msgId)}?graphname=${encodeURIComponent(graphname)}`, + { + method: "DELETE", + headers: { Authorization: `Basic ${creds}` }, + } + ); + if (res.ok) { + setIsDeleted(true); + window.dispatchEvent( + new CustomEvent("graphrag:messageDeleted", { detail: { msgId } }) + ); + const convDataRaw = sessionStorage.getItem("selectedConversationData"); + if (convDataRaw) { + try { + const convData = JSON.parse(convDataRaw); + let messages: any[] = []; + let wrap: (f: any[]) => any = (f) => f; + if (Array.isArray(convData)) { + messages = convData; + } else if (Array.isArray(convData.messages)) { + messages = convData.messages; + wrap = (f) => ({ ...convData, messages: f }); + } else if (Array.isArray(convData.content)) { + messages = convData.content; + wrap = (f) => ({ ...convData, content: f }); + } + const sysIdx = messages.findIndex((m: any) => m.message_id === msgId); + if (sysIdx !== -1) { + const toRemove = new Set([sysIdx]); + if (sysIdx > 0 && messages[sysIdx - 1]?.role === "user") { + toRemove.add(sysIdx - 1); + } + const filtered = messages.filter((_: any, i: number) => !toRemove.has(i)); + sessionStorage.setItem("selectedConversationData", JSON.stringify(wrap(filtered))); + } + } catch { /* ignore */ } + } + } else { + const errText = await res.text().catch(() => ""); + console.error(`Delete message failed (${res.status}):`, errText); + await alert(`Failed to delete message: ${res.status} ${errText}`); + } + } catch (err) { + console.error("Delete message error:", err); + } + }; + + if (isDeleted) return null; // Error handling functions const handleShowExplain = () => { @@ -222,6 +279,30 @@ export const CustomChatMessage: FC = ({ return ( <> {alertDialog} + {/* Confirm delete message dialog */} + + +

Delete this message and its question? This cannot be undone.

+
+ + +
+
+
{traceMessageId && ( = ({ showExplain={handleShowExplain} showTable={handleShowTable} showGraph={handleShowGraph} + onDelete={message.messageId || message.message_id ? () => setConfirmDelete(true) : undefined} onViewTrace={async () => { const messageId = message.messageId || message.message_id || ""; if (!messageId) { await alert("Trace log unavailable: this message has no trace ID."); return; } - // Guard against a missing/invalid creds value. If we send - // ``Basic null`` (or other unparsable base64), FastAPI's - // HTTPBasic returns 401 + ``WWW-Authenticate: Basic`` and - // the browser pops up its native auth dialog. Better to - // tell the user to sign in again than to flash that popup. const creds = sessionStorage.getItem("creds"); if (!creds) { await alert("Your session has expired. Please log in again."); return; } - // Trace JSON lives under /code/trace_logs inside the - // graphrag container and is wiped on container recreate. - // Probe first so we never open an empty dialog when the file is gone. 
try { const probe = await fetch(`/ui/trace/${messageId}`, { method: "GET", diff --git a/graphrag-ui/src/components/Interact.tsx b/graphrag-ui/src/components/Interact.tsx index ae93539..466121b 100644 --- a/graphrag-ui/src/components/Interact.tsx +++ b/graphrag-ui/src/components/Interact.tsx @@ -11,6 +11,7 @@ import { Feedback, Message } from "@/actions/ActionProvider"; import { PiGraph } from "react-icons/pi"; import { FaTable } from "react-icons/fa"; import { LuInfo, LuActivity } from "react-icons/lu"; +import { RiDeleteBin6Line } from "react-icons/ri"; import { useRoles } from "@/hooks/useRoles"; const GRAPHRAG_URL = ""; @@ -20,6 +21,7 @@ interface Interactions { showTable: () => boolean; showGraph: () => boolean; onViewTrace?: () => void; + onDelete?: () => void; } export const Interactions: FC = ({ @@ -28,6 +30,7 @@ export const Interactions: FC = ({ showTable, showGraph, onViewTrace, + onDelete, }: Interactions) => { // Seed from the persisted feedback when re-rendering a history // message so the up/down state matches what the user already @@ -167,6 +170,16 @@ export const Interactions: FC = ({ + {onDelete && ( +
onDelete()} + > + +
+ )} + ) : null} diff --git a/graphrag-ui/src/components/SideMenu.tsx b/graphrag-ui/src/components/SideMenu.tsx index 79fe501..30415ed 100644 --- a/graphrag-ui/src/components/SideMenu.tsx +++ b/graphrag-ui/src/components/SideMenu.tsx @@ -8,6 +8,7 @@ import { IoIosHelpCircleOutline } from "react-icons/io"; import { HiOutlineChatBubbleOvalLeft } from "react-icons/hi2"; import { MdKeyboardArrowDown, MdKeyboardArrowUp } from "react-icons/md"; import { IoIosArrowForward } from "react-icons/io"; +import { RiDeleteBin6Line } from "react-icons/ri"; import { useTheme } from "@/components/ThemeProvider"; import { safeJson } from "@/utils/safeJson"; import { GoGear } from "react-icons/go"; @@ -76,6 +77,7 @@ const SideMenu = ({ const [newSet, setNewSet] = useState([]); const [expandedConversations, setExpandedConversations] = useState>(new Set()); const [activeConversationId, setActiveConversationId] = useState(null); + const [deletingConvoId, setDeletingConvoId] = useState(null); // Fade + disable the side menu (conversation list + New Chat) while // the chat is streaming an answer, so the user can't unmount Chat by // switching conversations mid-response. 
@@ -257,6 +259,34 @@ const SideMenu = ({ }); } + const deleteConversation = async (conversationId: string) => { + const creds = sessionStorage.getItem("creds"); + const graphname = sessionStorage.getItem("selectedGraph"); + if (!creds || !graphname) return; + try { + const res = await fetch( + `/ui/conversation/${encodeURIComponent(conversationId)}?graphname=${encodeURIComponent(graphname)}`, + { + method: "DELETE", + headers: { Authorization: `Basic ${creds}` }, + } + ); + if (res.ok) { + setDeletingConvoId(null); + setConversationId((prev: any[]) => + prev.filter((c: any) => c?.conversation_id !== conversationId) + ); + if (conversationId === activeConversationId) { + conversationManager.startNewConversation(); + sessionStorage.removeItem("selectedConversationData"); + } + fetchHistory2(); + } + } catch { + setDeletingConvoId(null); + } + }; + const renderConvoHistory = () => { if (newSet.length === 0) { return ( @@ -308,7 +338,7 @@ const SideMenu = ({
{ e.preventDefault(); resumeConvo(item.conversation_id); @@ -332,6 +362,17 @@ const SideMenu = ({ )} )} + {isExpanded && userMessages.length > 1 && (
@@ -609,6 +650,29 @@ const SideMenu = ({ Chat history + {/* Delete conversation confirm dialog */} + { if (!open) setDeletingConvoId(null); }}> + + + Delete Conversation? + + This will permanently delete this conversation and all its messages. This cannot be undone. + + + + + + + + + + + {renderConvoHistory()} {/*
bool: + """Return True if question is clearly about THIS conversation's history AND history exists.""" + if not conversation: + return False + return bool(self._HISTORY_QUESTION_RE.search(question)) + def route_question(self, state): """ Run the agent router. @@ -180,6 +205,16 @@ def route_question(self, state): return "apologize" if self._is_greeting(state["question"]): return "greeting" + + # Fast-path: if the question is clearly about prior exchanges in this + # conversation AND there is actually history available, skip the LLM + # router and go straight to history lookup. + if self._is_history_question(state["question"], state.get("conversation", [])): + logger.debug_pii( + f"request_id={req_id_cv.get()} Pre-routing to history_lookup (keyword match)" + ) + return "history_lookup" + self.emit_progress("Thinking") step = TigerGraphAgentRouter(self.llm_provider, self.db_connection) logger.debug_pii( diff --git a/graphrag/app/main.py b/graphrag/app/main.py index a4d0ec0..58db825 100644 --- a/graphrag/app/main.py +++ b/graphrag/app/main.py @@ -14,9 +14,11 @@ import json import logging +import threading import time import uuid from base64 import b64decode +from contextlib import asynccontextmanager from datetime import datetime import routers @@ -32,12 +34,36 @@ from common.logs.logwriter import LogWriter from common.metrics.prometheus_metrics import metrics as pmetrics +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def _tg_memory_schema_lifespan(app: FastAPI): + def _run_schema_install(): + try: + from common.memory import tg_memory + + tg_memory.install_memory_schema_for_all_graphs_at_startup() + except Exception: + logger.exception( + "TG memory schema startup install failed; types may appear only after " + "initialize_graph or when TigerGraph is reachable." 
+ ) + + threading.Thread(target=_run_schema_install, daemon=True).start() + yield + + if PRODUCTION: app = FastAPI( - title="TigerGraph GraphRAG", docs_url=None, redoc_url=None, openapi_url=None + title="TigerGraph GraphRAG", + docs_url=None, + redoc_url=None, + openapi_url=None, + lifespan=_tg_memory_schema_lifespan, ) else: - app = FastAPI(title="TigerGraph GraphRAG") + app = FastAPI(title="TigerGraph GraphRAG", lifespan=_tg_memory_schema_lifespan) app.add_middleware( CORSMiddleware, @@ -56,8 +82,6 @@ excluded_metrics_paths = ("/docs", "/openapi.json", "/metrics") -logger = logging.getLogger(__name__) - logger.info("In main.py") diff --git a/graphrag/app/routers/ui.py b/graphrag/app/routers/ui.py index ff27e57..31c228e 100644 --- a/graphrag/app/routers/ui.py +++ b/graphrag/app/routers/ui.py @@ -52,6 +52,7 @@ from tools.validation_utils import MapQuestionToSchemaException from common.config import db_config, graphrag_config, embedding_service, llm_config, service_status, get_chat_config, get_completion_config, get_embedding_config, get_multimodal_config, validate_graphname, get_llm_service, resolve_llm_services +from common.memory import tg_memory from common.db.connections import get_db_connection_pwd_manual from common.db import schema_utils as schema_utils_mod from common.db import schema_extraction as schema_extraction_mod @@ -1725,25 +1726,31 @@ async def serve_image_from_vertex( @router.get(route_prefix + "/user/{user_id}") async def get_user_conversations( user_id: str, + graphname: Annotated[str, Query(description="TigerGraph graph for chat memory")], creds: Annotated[tuple[list[str], HTTPBasicCredentials], Depends(ui_basic_auth)], ): - creds = creds[1] - auth = base64.b64encode(f"{creds.username}:{creds.password}".encode()).decode() + graphs, cred = creds + validate_graphname(graphname) + if graphname not in graphs: + raise HTTPException(status_code=403, detail="Insufficient permissions for this graph.") + if user_id != cred.username: + raise 
HTTPException( + status_code=403, + detail="Not authorized to list conversations for another user.", + ) + auth = base64.b64encode(f"{cred.username}:{cred.password}".encode()).decode() try: - async with httpx.AsyncClient() as client: - res = await client.get( - f"{graphrag_config['chat_history_api']}/user/{user_id}", - headers={"Authorization": f"Basic {auth}"}, - ) - res.raise_for_status() + _, conn = ws_basic_auth(auth, graphname) + convos = await asyncio.to_thread( + tg_memory.list_conversations_for_user, conn, graphname, user_id + ) + return convos except Exception as e: exc = traceback.format_exc() logger.debug_pii( f"/ui/user/{user_id} request_id={req_id_cv.get()} Exception Trace:\n{exc}" ) - raise e - - return res.json() + raise HTTPException(status_code=500, detail=str(e)) from e @router.get(route_prefix + "/roles") @@ -1759,78 +1766,105 @@ async def get_user_roles( @router.get(route_prefix + "/conversation/{conversation_id}") async def get_conversation_contents( conversation_id: str, + graphname: Annotated[str, Query(description="TigerGraph graph for chat memory")], creds: Annotated[tuple[list[str], HTTPBasicCredentials], Depends(ui_basic_auth)], ): - creds = creds[1] - auth = base64.b64encode(f"{creds.username}:{creds.password}".encode()).decode() + graphs, cred = creds + validate_graphname(graphname) + if graphname not in graphs: + raise HTTPException(status_code=403, detail="Insufficient permissions for this graph.") + auth = base64.b64encode(f"{cred.username}:{cred.password}".encode()).decode() try: - async with httpx.AsyncClient() as client: - res = await client.get( - f"{graphrag_config['chat_history_api']}/conversation/{conversation_id}", - headers={"Authorization": f"Basic {auth}"}, - ) - res.raise_for_status() + _, conn = ws_basic_auth(auth, graphname) + if not tg_memory.verify_conversation_owner(conn, graphname, conversation_id, cred.username): + raise HTTPException(status_code=403, detail="Conversation not found or access denied.") + rows = 
await asyncio.to_thread( + tg_memory.list_messages_sorted_by_epoch, + conn, + graphname, + conversation_id, + ) + model_name = get_chat_config(graphname).get("llm_model", "unknown") + messages = tg_memory.conversation_rows_to_ui_messages( + conversation_id, rows, model_name=model_name + ) + return messages + except HTTPException: + raise except Exception as e: exc = traceback.format_exc() logger.debug_pii( f"/conversation/{conversation_id} request_id={req_id_cv.get()} Exception Trace:\n{exc}" ) - raise e + raise HTTPException(status_code=500, detail=str(e)) from e - return res.json() -@router.get(route_prefix + "/get_feedback") -async def get_conversation_feedback( +@router.delete(route_prefix + "/conversation/{conversation_id}") +async def delete_conversation( + conversation_id: str, + graphname: Annotated[str, Query(description="TigerGraph graph for chat memory")], creds: Annotated[tuple[list[str], HTTPBasicCredentials], Depends(ui_basic_auth)], ): - creds = creds[1] - auth = base64.b64encode(f"{creds.username}:{creds.password}".encode()).decode() + """Delete a conversation and all its messages from TigerGraph memory.""" + graphs, cred = creds + validate_graphname(graphname) + if graphname not in graphs: + raise HTTPException(status_code=403, detail="Insufficient permissions for this graph.") + auth = base64.b64encode(f"{cred.username}:{cred.password}".encode()).decode() try: - async with httpx.AsyncClient() as client: - res = await client.get( - f"{graphrag_config['chat_history_api']}/get_feedback", - headers={"Authorization": f"Basic {auth}"}, - ) - res.raise_for_status() - except httpx.HTTPStatusError as e: - logger.error(f"HTTP error occurred: {e}") - raise HTTPException(status_code=e.response.status_code, detail="Failed to fetch feedback") + _, conn = ws_basic_auth(auth, graphname) + if not tg_memory.verify_conversation_owner(conn, graphname, conversation_id, cred.username): + raise HTTPException(status_code=403, detail="Conversation not found or access 
denied.") + await asyncio.to_thread( + tg_memory.delete_conversation_thread, + conn, + graphname, + conversation_id, + ) + except HTTPException: + raise except Exception as e: exc = traceback.format_exc() logger.debug_pii( - f"/get_feedback request_id={req_id_cv.get()} Exception Trace:\n{exc}" + f"/conversation/{conversation_id} DELETE request_id={req_id_cv.get()} Exception Trace:\n{exc}" ) - raise HTTPException(status_code=500, detail="Internal server error") + raise HTTPException(status_code=500, detail=str(e)) from e - return res.json() + return {"message": "Conversation deleted successfully"} -@router.delete(route_prefix + "/conversation/{conversation_id}") -async def delete_conversation( - conversation_id: str, +@router.delete(route_prefix + "/message/{message_id}") +async def delete_message( + message_id: str, + graphname: Annotated[str, Query(description="TigerGraph graph for chat memory")], creds: Annotated[tuple[list[str], HTTPBasicCredentials], Depends(ui_basic_auth)], ): - """Delete a conversation and all its messages.""" - creds = creds[1] - auth = base64.b64encode(f"{creds.username}:{creds.password}".encode()).decode() + """Delete a single message vertex (Q&A pair) from TigerGraph memory.""" + graphs, cred = creds + validate_graphname(graphname) + if graphname not in graphs: + raise HTTPException(status_code=403, detail="Insufficient permissions for this graph.") + auth = base64.b64encode(f"{cred.username}:{cred.password}".encode()).decode() try: - async with httpx.AsyncClient() as client: - res = await client.delete( - f"{graphrag_config['chat_history_api']}/conversation/{conversation_id}", - headers={"Authorization": f"Basic {auth}"}, - ) - res.raise_for_status() - except httpx.HTTPStatusError as e: - logger.error(f"HTTP error occurred: {e}") - raise HTTPException(status_code=e.response.status_code, detail="Failed to delete conversation") + _, conn = ws_basic_auth(auth, graphname) + deleted = await asyncio.to_thread( + 
tg_memory.delete_message_vertices, + conn, + graphname, + [message_id], + ) + if deleted == 0: + raise HTTPException(status_code=404, detail="Message not found.") + except HTTPException: + raise except Exception as e: exc = traceback.format_exc() logger.debug_pii( - f"/conversation/{conversation_id} DELETE request_id={req_id_cv.get()} Exception Trace:\n{exc}" + f"/message/{message_id} DELETE request_id={req_id_cv.get()} Exception Trace:\n{exc}" ) - raise HTTPException(status_code=500, detail="Internal server error") + raise HTTPException(status_code=500, detail=str(e)) from e - return {"message": "Conversation deleted successfully"} + return {"message": "Message deleted successfully"} async def emit_progress(agent: TigerGraphAgent, ws: WebSocket): @@ -1916,71 +1950,39 @@ async def run_agent( return resp -async def load_conversation_history(conversation_id: str, usr_auth: str) -> list[dict[str, str]]: +async def load_conversation_history( + conn: TigerGraphConnection, + graphname: str, + conversation_id: str, +) -> list[dict[str, str]]: """ - Load conversation history from the chat history service. + Load conversation history from TigerGraph chat memory for the agent. Returns a list of dicts with 'query', 'response', 'create_ts', and 'update_ts' keys. 
""" if not conversation_id or conversation_id == "new": return [] - - ch = graphrag_config.get("chat_history_api") - if ch is None: - LogWriter.info("chat-history not enabled, returning empty history") + if not tg_memory.tg_memory_enabled(graphname): return [] - - headers = {"Authorization": f"Basic {usr_auth}"} try: - async with httpx.AsyncClient() as client: - res = await client.get( - f"{ch}/conversation/{conversation_id}", - headers=headers, - ) - res.raise_for_status() - conversation_data = res.json() - # Convert conversation messages to the format expected by the agent - history = [] - for msg in conversation_data: - if msg.get("role") == "user": - # Find the corresponding system response - for response_msg in conversation_data: - if (response_msg.get("role") == "system" and - response_msg.get("parent_id") == msg.get("message_id")): - history.append({ - "query": msg.get("content", ""), - "response": response_msg.get("content", ""), - "create_ts": response_msg.get("create_ts"), - "update_ts": response_msg.get("update_ts"), - }) - break - - LogWriter.info(f"Loaded {len(history)} conversation history entries for conversation {conversation_id}") - return history - + rows = await asyncio.to_thread( + tg_memory.list_messages_sorted_by_epoch, + conn, + graphname, + conversation_id, + ) + history = tg_memory.agent_history_from_messages(rows) + LogWriter.info( + f"Loaded {len(history)} conversation history entries for conversation {conversation_id}" + ) + return history except Exception as e: exc = traceback.format_exc() - logger.debug_pii(f"Error loading conversation history for {conversation_id}\nException Trace:\n{exc}") + logger.debug_pii( + f"Error loading conversation history for {conversation_id}\nException Trace:\n{exc}" + ) LogWriter.warning(f"Failed to load conversation history for {conversation_id}: {e}") return [] - -async def write_message_to_history(message: Message, usr_auth: str): - ch = graphrag_config.get("chat_history_api") - if ch is not None: - 
headers = {"Authorization": f"Basic {usr_auth}"} - try: - async with httpx.AsyncClient() as client: - res = await client.post( - f"{ch}/conversation", headers=headers, json=message.model_dump() - ) - res.raise_for_status() - except Exception: # catch all exceptions to log them, but don't raise - exc = traceback.format_exc() - logger.debug_pii(f"Error writing chat history\nException Trace:\n{exc}") - - else: - LogWriter.info(f"chat-history not enabled. chat-history url: {ch}") - @router.get(route_prefix + "/{graphname}/query") async def graph_query( graphname: ValidGraphName, @@ -1993,11 +1995,32 @@ async def graph_query( auth = base64.b64encode(f"{creds.username}:{creds.password}".encode()).decode() _, conn = ws_basic_auth(auth, graphname) try: - # Load conversation history if conversation_id is provided - conversation_history = await load_conversation_history(conversation_id, auth) if conversation_id else [] + started_new = not conversation_id or conversation_id == "new" + + # Security: verify the requesting user owns this conversation before + # loading its history. Prevents User A from injecting User B's + # conversation_id to read or pollute B's chat context. 
+ if not started_new and tg_memory.tg_memory_enabled(graphname): + if not await asyncio.to_thread( + tg_memory.verify_conversation_owner, + conn, + graphname, + conversation_id, + creds.username, + ): + raise HTTPException( + status_code=403, + detail="Conversation not found or access denied.", + ) + + conversation_history = ( + await load_conversation_history(conn, graphname, conversation_id) + if conversation_id and not started_new + else [] + ) # Use provided conversation ID or generate new one - if not conversation_id or conversation_id == "new": + if started_new: convo_id = str(uuid.uuid4()) LogWriter.info(f"Starting new conversation with ID: {convo_id}") else: @@ -2009,22 +2032,8 @@ async def graph_query( rag_pattern = rag_pattern or "auto" agent = make_agent(graphname, conn, use_cypher, supportai_retriever=rag_pattern) - prev_id = None data = q - # make message from data - message = Message( - conversation_id=convo_id, - message_id=str(uuid.uuid4()), - parent_id=prev_id, - model=get_chat_config(graphname).get("llm_model", "unknown"), - content=data, - role=Role.USER, - ) - # save message - await write_message_to_history(message, auth) - prev_id = message.message_id - # generate response and keep track of response time start = time.monotonic() resp = await run_agent( @@ -2032,11 +2041,10 @@ async def graph_query( ) elapsed = time.monotonic() - start - # save message message = Message( conversation_id=convo_id, message_id=str(uuid.uuid4()), - parent_id=prev_id, + parent_id=None, model=get_chat_config(graphname).get("llm_model", "unknown"), content=resp.natural_language_response, role=Role.SYSTEM, @@ -2045,9 +2053,20 @@ async def graph_query( response_type=resp.response_type, query_sources=resp.query_sources, ) - await write_message_to_history(message, auth) await asyncio.to_thread(_save_trace_log, message.message_id, convo_id, data, resp, elapsed, creds.username) - prev_id = message.message_id + if tg_memory.tg_memory_enabled(graphname): + await 
asyncio.to_thread( + tg_memory.write_exchange_to_tg_memory, + conn, + graphname, + convo_id, + creds.username, + data or "", + resp.natural_language_response or "", + tracelog="", + exchange_message_id=message.message_id, + is_new_conversation=started_new, + ) # reply return message.model_dump_json() @@ -2129,12 +2148,36 @@ async def chat( f"WebSocket conversation_id received: {conversation_id or 'empty'} " f"(graph={graphname}, rag_pattern={rag_pattern})" ) - + + started_new_conversation = conversation_id == "new" or not conversation_id + + # Security: verify the requesting user owns this conversation before + # loading its history. Prevents User A from injecting User B's + # conversation_id to read or pollute B's chat context. + if not started_new_conversation and tg_memory.tg_memory_enabled(graphname): + owner_ok = await asyncio.to_thread( + tg_memory.verify_conversation_owner, + conn, + graphname, + conversation_id, + ws_username, + ) + if not owner_ok: + await websocket.send_text( + json.dumps({"error": "Conversation not found or access denied."}) + ) + await websocket.close(code=1008) + return + # Load conversation history if not a new conversation - conversation_history = await load_conversation_history(conversation_id, usr_auth) - + conversation_history = ( + await load_conversation_history(conn, graphname, conversation_id) + if not started_new_conversation + else [] + ) + # Use provided conversation ID or generate new one - if conversation_id == "new" or not conversation_id: + if started_new_conversation: convo_id = str(uuid.uuid4()) LogWriter.info(f"Starting new conversation with ID: {convo_id}") else: @@ -2152,19 +2195,6 @@ async def chat( while True: data = await websocket.receive_text() - # make message from data - message = Message( - conversation_id=convo_id, - message_id=str(uuid.uuid4()), - parent_id=prev_id, - model=get_chat_config(graphname).get("llm_model", "unknown"), - content=data, - role=Role.USER, - ) - # save message - await 
write_message_to_history(message, usr_auth) - prev_id = message.message_id - # generate response and keep track of response time start = time.monotonic() resp = await run_agent( @@ -2172,7 +2202,6 @@ async def chat( ) elapsed = time.monotonic() - start - # save message message = Message( conversation_id=convo_id, message_id=str(uuid.uuid4()), @@ -2185,8 +2214,21 @@ async def chat( response_type=resp.response_type, query_sources=resp.query_sources, ) - await write_message_to_history(message, usr_auth) - await asyncio.to_thread(_save_trace_log, message.message_id, convo_id, data, resp, elapsed, ws_username) + trace_data = await asyncio.to_thread(_save_trace_log, message.message_id, convo_id, data, resp, elapsed, ws_username) + if tg_memory.tg_memory_enabled(graphname): + await asyncio.to_thread( + tg_memory.write_exchange_to_tg_memory, + conn, + graphname, + convo_id, + ws_username, + data, + resp.natural_language_response or "", + tracelog=json.dumps(trace_data, default=str) if trace_data else "", + exchange_message_id=message.message_id, + is_new_conversation=started_new_conversation, + ) + started_new_conversation = False prev_id = message.message_id # reply diff --git a/graphrag/app/supportai/supportai.py b/graphrag/app/supportai/supportai.py index f38eba9..32188dc 100644 --- a/graphrag/app/supportai/supportai.py +++ b/graphrag/app/supportai/supportai.py @@ -18,6 +18,8 @@ # SupportAIQuestion, ) from common.utils.text_extractors import TextExtractor +from common.memory import tg_memory + logger = logging.getLogger(__name__) def init_supportai(conn: TigerGraphConnection, graphname: str) -> tuple[dict, dict]: @@ -142,6 +144,11 @@ def init_supportai(conn: TigerGraphConnection, graphname: str) -> tuple[dict, di ) logger.info(f"Done installing supportai query all with status {query_res}") + try: + tg_memory.init_memory_schema(conn, graphname) + except Exception: + logger.warning("init_memory_schema failed for %s", graphname, exc_info=True) + return schema_res, 
index_res, query_res