diff --git a/docs/AGENT-GUIDE.md b/docs/AGENT-GUIDE.md index f8b0bdd5..c6894cea 100644 --- a/docs/AGENT-GUIDE.md +++ b/docs/AGENT-GUIDE.md @@ -178,7 +178,7 @@ Prefer **`resolve` → `describe(id=…)`** over **`describe(fqn=…)`** when an | Handler for route | route id | `neighbors(ids, "in", ["EXPOSES"])` | | Who implements interface T? | type symbol id | `neighbors(ids, "in", ["IMPLEMENTS"])` | | Who injects type T? | type symbol id | `neighbors(ids, "in", ["INJECTS"])` | -| Impact / "what breaks if I change X"? | `trace` or `neighbors` loop | `trace(id, "in", ["CALLS","OVERRIDES"], max_depth=3)` or loop `neighbors` `in` with `CALLS`, `INJECTS` | +| Impact / "what breaks if I change X"? | `trace` with `direction="both"` or `neighbors` loop | `trace(id, "both", ["CALLS","OVERRIDES"], max_depth=3)` or loop `neighbors` `in` with `CALLS`, `INJECTS` | **Rules of thumb:** @@ -190,9 +190,11 @@ Prefer **`resolve` → `describe(id=…)`** over **`describe(fqn=…)`** when an #### `trace` -Multi-hop BFS traversal with server-side pruning. Returns structured paths, a node dict, and traversal stats. Use when the question implies a path or chain (3+ hops), needs to cross a service boundary, or a `neighbors` loop has exceeded 2 hops without converging. Args: `ids` (string or array), **`direction`**, **`edge_types`** (stored labels only — no composed dot-keys), `max_depth` (1–5, default 3), `max_paths` (default 20), `max_nodes_discovered` (100–2000, default 500), `filter` (hard gate `NodeFilter`), `edge_filter` (CALLS edge attribute filtering), `prune_roles` (soft gate — edges recorded, frontier stops), `fan_out_cap` (per-node edge limit, default 5), `collapse_trivial` (collapse wrapper chains, default true), `include_unresolved` (interleave unresolved call sites). +Multi-hop BFS traversal with server-side pruning. Returns a nested tree structure, a node dict, and traversal stats. 
Use when the question implies a path or chain (3+ hops), needs to cross a service boundary, or a `neighbors` loop has exceeded 2 hops without converging. Args: `ids` (string or array), **`direction`** (`in` | `out` | `both`), **`edge_types`** (stored labels only — no composed dot-keys), `max_depth` (1–5, default 3), `max_paths` (default 20), `max_nodes_discovered` (100–2000, default 500), `filter` (hard gate `NodeFilter`), `edge_filter` (CALLS edge attribute filtering), `prune_roles` (soft gate — edges recorded, frontier stops), `fan_out_cap` (per-node edge limit, default 5), `collapse_trivial` (collapse wrapper chains, default true), `collapse_roles` (roles to collapse as trivial intermediates, default `["OTHER"]`; only effective when `collapse_trivial=true`), `collapse_min_chain_length` (minimum chain length for collapse, default 1), `include_unresolved` (interleave unresolved call sites), `cross_service` (continue BFS through HTTP_CALLS/ASYNC_CALLS boundaries), `min_result_nodes` (minimum result nodes target; retries with doubled `fan_out_cap` if below target, default 0). -Returns `TraceOutput` with `nodes` (dict of `NodeRef`), `edges` (list of `TraceEdge` with `hop`, `parent_edge_id`, `collapsed`, `cross_service_boundary`), `paths` (ranked root-to-leaf), and `stats` (budget, pruning counts). Cross-service edges (`HTTP_CALLS`, `ASYNC_CALLS`) are boundary signals — BFS stops at the service boundary and includes the downstream node for the agent to continue with a separate `trace` call. 
+Returns `TraceOutput` with `nodes` (dict of `NodeRef`), `tree` (nested `TreeNode` list — one per seed; each `TreeNode` has `id`; `edge_from_parent` (`null` for seeds) with `direction`, `edge_type`, `hop`, `confidence`, `cross_service_boundary`, and `attrs`; `children` (nested `TreeNode`s); and `collapsed` / `collapsed_intermediates` for collapsed edges), `ranked_leaves` (scored leaf nodes with `node_id`, `depth`, `leaf_role`, `score`; sorted by descending score and capped at `max_paths`), and `stats` (budget, pruning counts). Cross-service edges (`HTTP_CALLS`, `ASYNC_CALLS`) are boundary signals — unless `cross_service=true`, BFS stops at the service boundary; the downstream node is still included so the agent can continue with a separate `trace` call. + +**`direction="both"`**: runs a bidirectional traversal (out, then in) with a shared visited set. The tree contains children from both directions; `edge_from_parent.direction` (`in` or `out`) distinguishes them. Use it for impact analysis ("who depends on X and what does X call?") in one call. **`trace` vs `neighbors`:** Use `neighbors` for single-hop adjacency (full unfiltered result). Use `trace` for multi-hop path questions, impact analysis, or when `neighbors` returns high fan-out (>8 CALLS edges). 
diff --git a/mcp_hints.py b/mcp_hints.py index 3ec7d0b1..2ea5a70a 100644 --- a/mcp_hints.py +++ b/mcp_hints.py @@ -517,13 +517,25 @@ def _high_fanout_trace_hint(origin_id: str, calls_n: int) -> _StructuredHint: ) +def _walk_tree_nodes(tree: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Flatten a tree structure into a list of all nodes (for iteration).""" + result: list[dict[str, Any]] = [] + stack = list(tree) + while stack: + node = stack.pop() + result.append(node) + for child in node.get("children") or []: + stack.append(child) + return result + + def _trace_structured_hints(payload: dict[str, Any]) -> tuple[list[_StructuredHint], list[tuple[int, str]]]: - """Structured hints and advisories for trace output.""" + """Structured hints and advisories for trace output (v2 tree format).""" struct_pairs: list[_StructuredHint] = [] advisories: list[tuple[int, str]] = [] stats = payload.get("stats") - edges = list(payload.get("edges") or []) + tree = list(payload.get("tree") or []) # Cross-service boundary hints don't require stats — guard stats-dependant hints separately. if isinstance(stats, dict): @@ -554,7 +566,8 @@ def _trace_structured_hints(payload: dict[str, Any]) -> tuple[list[_StructuredHi + int(stats.get("nodes_pruned_fan_out") or 0) + int(stats.get("edges_collapsed_trivial") or 0) ) - if pruned_count > 0 or any(e.get("collapsed") for e in edges): + has_collapsed = any(n.get("collapsed") for n in _walk_tree_nodes(tree)) + if pruned_count > 0 or has_collapsed: advisories.append(( PRIORITY_META, f"trace pruned {pruned_count} edges. Use neighbors(id, direction, edge_types) on specific nodes for full detail.", @@ -568,21 +581,29 @@ def _trace_structured_hints(payload: dict[str, Any]) -> tuple[list[_StructuredHi )) # (c) Cross-service boundary hint (no stats dependency). - xs_edges = [e for e in edges if e.get("cross_service_boundary")] - if xs_edges: - # When cross_service=True, BFS already continued through boundaries. 
- # Emit a lighter informational advisory instead of an action hint. + # Walk tree in a single pass: build parent map and collect cross-service boundary nodes. + xs_edges_final: list[tuple[str, str, str | None]] = [] # (from_id, to_id, confidence) + stack: list[tuple[dict[str, Any], str | None]] = [(n, None) for n in tree] + while stack: + node, parent_id = stack.pop() + nid = str(node.get("id") or "") + efp = node.get("edge_from_parent") + if isinstance(efp, dict) and efp.get("cross_service_boundary"): + attrs = efp.get("attrs") if isinstance(efp.get("attrs"), dict) else {} + confidence = attrs.get("confidence") + xs_edges_final.append((parent_id or "", nid, confidence)) + for child in node.get("children") or []: + stack.append((child, nid)) + + if xs_edges_final: was_seamless = bool(payload.get("cross_service")) if was_seamless: advisories.append(( PRIORITY_META, - f"trace crossed {len(xs_edges)} service boundary(ies).", + f"trace crossed {len(xs_edges_final)} service boundary(ies).", )) else: - for xe in xs_edges[:3]: - to_id = str(xe.get("to_id") or "") - from_id = str(xe.get("from_id") or "") - confidence = xe.get("attrs", {}).get("confidence") if isinstance(xe.get("attrs"), dict) else None + for from_id, to_id, confidence in xs_edges_final[:3]: conf_str = f"confidence={confidence}" if confidence is not None else "low confidence" advisories.append(( PRIORITY_META, diff --git a/mcp_trace.py b/mcp_trace.py index 9f4d68af..6814bf31 100644 --- a/mcp_trace.py +++ b/mcp_trace.py @@ -3,8 +3,8 @@ Imports stable types from mcp_v2.py but does not modify them: - NodeFilter, EdgeFilter, NodeRef, _node_ref_from_row, _node_kind_from_id -This module implements PR-TRACE-1a (core BFS engine) + PR-TRACE-1b -(pruning, collapsing, cross-service boundary detection). +This module implements PR-TRACE-V2 (tree output, configurable collapse, +source-relative ranking, bidirectional traversal, min_result_nodes retry). 
""" from __future__ import annotations @@ -14,7 +14,7 @@ from pydantic import BaseModel, ConfigDict, Field, validate_call -from java_ontology import EDGE_SCHEMA +from java_ontology import EDGE_SCHEMA, VALID_ROLES from kuzu_queries import KuzuGraph from mcp_v2 import ( EdgeFilter, @@ -37,6 +37,39 @@ "OTHER": 1, } +# Source-relative fan-out priority: source role → target role → priority. +# All role strings are validated against VALID_ROLES at startup. +_SOURCE_RELATIVE_PRIORITY: dict[str, dict[str, int]] = { + "SERVICE": { + "REPOSITORY": 5, + "SERVICE": 4, + "CONTROLLER": 3, + "CLIENT": 2, + "OTHER": 1, + }, + "CONTROLLER": { + "SERVICE": 5, + "REPOSITORY": 4, + "CLIENT": 3, + "CONTROLLER": 2, + "OTHER": 1, + }, + "REPOSITORY": { + "REPOSITORY": 5, + "SERVICE": 4, + "ENTITY": 3, + "CLIENT": 2, + "OTHER": 1, + }, +} + +# Validate source-relative priority table against known roles. +_VALID_PRIORITY_ROLES = VALID_ROLES | frozenset({"OTHER"}) +for _src_role, _target_map in _SOURCE_RELATIVE_PRIORITY.items(): + assert _src_role in _VALID_PRIORITY_ROLES, f"_SOURCE_RELATIVE_PRIORITY key {_src_role!r} not in known roles" + for _tgt_role in _target_map: + assert _tgt_role in _VALID_PRIORITY_ROLES, f"_SOURCE_RELATIVE_PRIORITY[{_src_role!r}] key {_tgt_role!r} not in known roles" + # Scaffolding edges exempt from fan_out_cap. 
_SCAFFOLDING_EDGE_TYPES = frozenset({"DECLARES_CLIENT", "DECLARES_PRODUCER"}) @@ -51,27 +84,40 @@ def _role_priority(role: str | None) -> int: return _ROLE_PRIORITY.get(role, 1) -class TraceEdge(BaseModel): - """A single edge in the trace result with BFS metadata.""" +# --- Models --- + + +class EdgeFromParent(BaseModel): + """Edge metadata linking a TreeNode to its parent.""" model_config = ConfigDict(extra="forbid") - from_id: str - to_id: str + direction: Literal["in", "out"] edge_type: str hop: int - parent_edge_id: str | None = None - collapsed: bool = False - collapsed_intermediates: list[str] = Field(default_factory=list) + confidence: float | None = None cross_service_boundary: bool = False attrs: dict[str, Any] = Field(default_factory=dict) -class TracePath(BaseModel): - """A root-to-leaf path through the traced DAG.""" +class TreeNode(BaseModel): + """A node in the nested trace tree output.""" model_config = ConfigDict(extra="forbid") - edges: list[TraceEdge] - leaf: NodeRef + id: str + edge_from_parent: EdgeFromParent | None = None + children: list[TreeNode] = Field(default_factory=list) + collapsed: bool = False + collapsed_intermediates: list[str] = Field(default_factory=list) + + +class RankedLeaf(BaseModel): + """A ranked leaf node from the trace tree.""" + model_config = ConfigDict(extra="forbid") + + node_id: str + depth: int + leaf_role: str | None = None + score: float class TraceStats(BaseModel): @@ -99,13 +145,49 @@ class TraceOutput(BaseModel): edge_types: list[str] actual_depth: int = 0 nodes: dict[str, NodeRef] = Field(default_factory=dict) - edges: list[TraceEdge] = Field(default_factory=list) - paths: list[TracePath] = Field(default_factory=list) + tree: list[TreeNode] = Field(default_factory=list) + ranked_leaves: list[RankedLeaf] = Field(default_factory=list) stats: TraceStats = Field(default_factory=TraceStats) message: str | None = None advisories: list[str] = Field(default_factory=list) +# --- Internal flat edge representation used 
during BFS --- + + +class _FlatEdge: + """Internal flat edge during BFS (not exported).""" + + __slots__ = ( + "from_id", "to_id", "edge_type", "hop", "direction", + "confidence", "cross_service_boundary", "attrs", + "collapsed", "collapsed_intermediates", + ) + + def __init__( + self, + *, + from_id: str, + to_id: str, + edge_type: str, + hop: int, + direction: Literal["in", "out"], + confidence: float | None = None, + cross_service_boundary: bool = False, + attrs: dict[str, Any] | None = None, + ) -> None: + self.from_id = from_id + self.to_id = to_id + self.edge_type = edge_type + self.hop = hop + self.direction = direction + self.confidence = confidence + self.cross_service_boundary = cross_service_boundary + self.attrs = attrs or {} + self.collapsed = False + self.collapsed_intermediates: list[str] = [] + + def _edge_attrs_for_row(row: dict[str, Any]) -> dict[str, Any]: """Extract edge attributes from a query row, excluding structural fields.""" attrs = { @@ -125,11 +207,7 @@ def _neighbors_batched( edge_types: list[str], edge_filter: EdgeFilter | None = None, ) -> list[dict[str, Any]]: - """Issue a single Cypher query for all frontier nodes at one BFS hop. - - Returns rows with: source_id, other_id, edge_type, and edge attribute columns. - Each row represents one edge from a source node to a target node. 
- """ + """Issue a single Cypher query for all frontier nodes at one BFS hop.""" if not node_ids: return [] @@ -193,7 +271,7 @@ def _neighbors_batched( def _load_node_record( graph: KuzuGraph, node_id: str, kind: Literal["symbol", "route", "client", "producer"], ) -> dict[str, Any] | None: - """Load a node record from Kuzu (copied from mcp_v2.py).""" + """Load a node record from Kuzu.""" if kind == "symbol": projection = ( "n.id AS id, n.kind AS kind, n.name AS name, n.fqn AS fqn, n.package AS package, " @@ -301,33 +379,46 @@ def _node_matches_filter( def _fan_out_sort_key( row: dict[str, Any], nodes: dict[str, NodeRef], + source_role: str | None = None, ) -> tuple[float, int, str]: """Sort key for fan_out_cap ranking: confidence desc, role priority desc, fqn asc.""" conf = float(row.get("confidence") or 0.0) other_id = str(row.get("other_id") or "") node_ref = nodes.get(other_id) - role_prio = _role_priority(node_ref.role if node_ref else None) + target_role = node_ref.role if node_ref else None + + if source_role and source_role in _SOURCE_RELATIVE_PRIORITY: + target_prio_map = _SOURCE_RELATIVE_PRIORITY[source_role] + role_prio = target_prio_map.get(target_role, 1) if target_role else 0 + else: + role_prio = _role_priority(target_role) + fqn = node_ref.fqn if node_ref else other_id return (-conf, -role_prio, fqn) def _collapse_trivial_chains( nodes: dict[str, NodeRef], - edges: list[TraceEdge], - edge_id_map: dict[str, TraceEdge], + edges: list[_FlatEdge], + collapse_roles: set[str] | None = None, + collapse_min_chain_length: int = 1, ) -> int: """Post-BFS pass: collapse trivial chains (degree-1 intermediates). - Mutates ``nodes``, ``edges``, and ``edge_id_map`` in-place. Single-pass — - if A→B→C→D has both B and C trivial, only one level collapses per call. + Collapsed intermediates are retained in the ``nodes`` dict (accessible + standalone but not nested in the tree). + Returns the number of edges collapsed. 
""" if not edges: return 0 - # Build adjacency for degree counting. - in_edges: dict[str, list[TraceEdge]] = defaultdict(list) - out_edges: dict[str, list[TraceEdge]] = defaultdict(list) + if collapse_roles is None: + collapse_roles = {"OTHER", None} + + # Build adjacency for degree counting (only CALLS edges). + in_edges: dict[str, list[_FlatEdge]] = defaultdict(list) + out_edges: dict[str, list[_FlatEdge]] = defaultdict(list) for e in edges: if e.collapsed: continue @@ -336,16 +427,36 @@ def _collapse_trivial_chains( out_edges[e.from_id].append(e) collapsed_count = 0 - # Track which edge IDs got replaced (so we can update parent_edge_id later). - old_to_new_edge_id: dict[str, str] = {} + edges_to_remove: set[int] = set() + edges_to_add: list[_FlatEdge] = [] - # Identify collapsible intermediates: B where exactly 1 inbound CALLS and 1 outbound CALLS. all_node_ids = set(nodes.keys()) - intermediates_to_collapse: list[tuple[str, TraceEdge, TraceEdge]] = [] + + # Count chain length for each candidate intermediate. 
+ def _chain_length(node_id: str, seen: frozenset[str] | None = None) -> int: + if seen is None: + seen = frozenset() + if node_id in seen: + return 0 + seen = seen | {node_id} + node_out = [e for e in out_edges.get(node_id, []) if id(e) not in edges_to_remove] + if len(node_out) != 1: + return 1 + target_id = node_out[0].to_id + node_in = [e for e in in_edges.get(node_id, []) if id(e) not in edges_to_remove] + if len(node_in) != 1: + return 1 + target_in = [e for e in in_edges.get(target_id, []) if id(e) not in edges_to_remove and not e.collapsed] + target_out = [e for e in out_edges.get(target_id, []) if id(e) not in edges_to_remove and not e.collapsed] + if len(target_in) == 1 and len(target_out) == 1: + target_ref = nodes.get(target_id) + if target_ref and target_ref.role in collapse_roles: + return 1 + _chain_length(target_id, seen) + return 1 for node_id in all_node_ids: - node_in = [e for e in in_edges.get(node_id, []) if not e.collapsed] - node_out = [e for e in out_edges.get(node_id, []) if not e.collapsed] + node_in = [e for e in in_edges.get(node_id, []) if not e.collapsed and id(e) not in edges_to_remove] + node_out = [e for e in out_edges.get(node_id, []) if not e.collapsed and id(e) not in edges_to_remove] if len(node_in) != 1 or len(node_out) != 1: continue @@ -353,281 +464,224 @@ def _collapse_trivial_chains( if node_ref is None: continue - # Role check: OTHER, or declaring-class role is SERVICE/COMPONENT. + # Check collapse_roles. role = node_ref.role - if role not in ("OTHER", None): + if role not in collapse_roles: + continue + + # Check minimum chain length. + chain_len = _chain_length(node_id) + if chain_len < collapse_min_chain_length: continue in_edge = node_in[0] out_edge = node_out[0] - intermediates_to_collapse.append((node_id, in_edge, out_edge)) - # Process collapses. - edges_to_remove: set[str] = set() - edges_to_add: list[TraceEdge] = [] + # Skip if already consumed. 
+ if id(in_edge) in edges_to_remove or id(out_edge) in edges_to_remove: + continue - for node_id, in_edge, out_edge in intermediates_to_collapse: - # Merge A→B→C into A→C. + # Collect intermediates from existing collapsed edges. + intermediates = [node_id] + final_to_id = out_edge.to_id + # Walk the chain to collapse all intermediates. + current_out = out_edge + while True: + next_id = current_out.to_id + next_ref = nodes.get(next_id) + if next_ref is None or next_ref.role not in collapse_roles: + break + next_in = [e for e in in_edges.get(next_id, []) if not e.collapsed and id(e) not in edges_to_remove] + next_out = [e for e in out_edges.get(next_id, []) if not e.collapsed and id(e) not in edges_to_remove] + if len(next_in) != 1 or len(next_out) != 1: + break + intermediates.append(next_id) + edges_to_remove.add(id(next_in[0])) + edges_to_remove.add(id(next_out[0])) + current_out = next_out[0] + + final_to_id = current_out.to_id + + # Merge attrs (prefer lower confidence edge's attrs). 
merged_attrs = in_edge.attrs if ( float(in_edge.attrs.get("confidence", 1.0)) - <= float(out_edge.attrs.get("confidence", 1.0)) - ) else out_edge.attrs + <= float(current_out.attrs.get("confidence", 1.0)) + ) else current_out.attrs - merged_edge = TraceEdge( + merged_edge = _FlatEdge( from_id=in_edge.from_id, - to_id=out_edge.to_id, + to_id=final_to_id, edge_type="CALLS", hop=in_edge.hop, - parent_edge_id=in_edge.parent_edge_id, - collapsed=True, - collapsed_intermediates=[node_id], + direction=in_edge.direction, + confidence=in_edge.confidence, attrs=merged_attrs, ) + merged_edge.collapsed = True + merged_edge.collapsed_intermediates = intermediates - edges_to_remove.add(f"{in_edge.from_id}:{in_edge.to_id}:{in_edge.edge_type}:{in_edge.hop}") - edges_to_remove.add(f"{out_edge.from_id}:{out_edge.to_id}:{out_edge.edge_type}:{out_edge.hop}") - - old_to_new_edge_id[ - f"{out_edge.from_id}:{out_edge.to_id}:{out_edge.edge_type}:{out_edge.hop}" - ] = f"{merged_edge.from_id}:{merged_edge.to_id}:{merged_edge.edge_type}:{merged_edge.hop}" + edges_to_remove.add(id(in_edge)) + edges_to_remove.add(id(out_edge)) edges_to_add.append(merged_edge) - collapsed_count += 1 + collapsed_count += len(intermediates) if collapsed_count == 0: return 0 - # Remove collapsed edges and add merged ones. - new_edges = [e for e in edges if f"{e.from_id}:{e.to_id}:{e.edge_type}:{e.hop}" not in edges_to_remove] + # Rebuild edges list: remove collapsed, add merged. + new_edges = [e for e in edges if id(e) not in edges_to_remove] new_edges.extend(edges_to_add) + # Retain intermediates in nodes dict (v2 change from v1). - # Remove intermediate nodes. - for node_id, _, _ in intermediates_to_collapse: - nodes.pop(node_id, None) - - # Rebuild edge_id_map. - edge_id_map.clear() - for e in new_edges: - eid = f"{e.from_id}:{e.to_id}:{e.edge_type}:{e.hop}" - edge_id_map[eid] = e - - # Recompute parent_edge_id: any edge referencing a removed edge should point to the collapsed replacement. 
- for e in new_edges: - if e.parent_edge_id and e.parent_edge_id in old_to_new_edge_id: - e.parent_edge_id = old_to_new_edge_id[e.parent_edge_id] - - # Replace edges list in place (caller holds the reference). edges.clear() edges.extend(new_edges) return collapsed_count -def _enumerate_paths( +def _build_tree( + seed_ids: list[str], nodes: dict[str, NodeRef], - edges: list[TraceEdge], - max_paths: int, -) -> list[TracePath]: - """Enumerate root-to-leaf paths through the DAG, capped and ranked.""" - if not edges: - return [] + edges: list[_FlatEdge], +) -> list[TreeNode]: + """Convert flat edge list to nested TreeNode structure. - # Build adjacency list: from_id -> list of outgoing edges. - out_edges_by_src: dict[str, list[TraceEdge]] = defaultdict(list) - for e in edges: - out_edges_by_src[e.from_id].append(e) - - # Find seeds (edges with hop 0). - seeds = {e.from_id for e in edges if e.hop == 0} - - # Find leaves (node IDs that have no outgoing edges in the result). - all_targets = {e.to_id for e in edges} - leaves = all_targets - set(out_edges_by_src.keys()) - - if not leaves: + Multi-seed roots: top-level tree list has one TreeNode per seed ID. + Collapsed intermediates are NOT in the tree but retained in nodes dict. + """ + if not edges and not seed_ids: return [] - # DFS from each seed to enumerate paths. 
- candidates: list[tuple[int, float, int, list[TraceEdge]]] = [] - - def dfs(current_id: str, path_edges: list[TraceEdge], min_conf: float) -> None: - """Depth-first search accumulating path confidence.""" - if current_id in leaves: - leaf_role = nodes.get(current_id, NodeRef(id=current_id, kind="symbol", fqn="")).role - candidates.append( - (_role_priority(leaf_role), min_conf, -len(path_edges), list(path_edges)) - ) - return - - for e in out_edges_by_src.get(current_id, []): - edge_conf = float(e.attrs.get("confidence", 1.0)) - dfs(e.to_id, path_edges + [e], min(min_conf, edge_conf)) - - for seed in seeds: - dfs(seed, [], 1.0) - - # Cap enumeration to avoid exponential blowup. - if len(candidates) > 10 * max_paths: - # Sort and keep top candidates. - candidates.sort(key=lambda x: (x[0], x[1], x[2]), reverse=True) - candidates = candidates[: 10 * max_paths] - - # Rank and cap at max_paths. - candidates.sort(key=lambda x: (x[0], x[1], x[2]), reverse=True) - paths: list[TracePath] = [] - - for role_prio, _min_conf, _neg_len, edge_list in candidates[:max_paths]: - leaf_id = edge_list[-1].to_id if edge_list else "" - leaf_node = nodes.get(leaf_id, NodeRef(id=leaf_id, kind="symbol", fqn="")) - paths.append(TracePath(edges=edge_list, leaf=leaf_node)) - - return paths - - -@validate_call(config={"arbitrary_types_allowed": True}) -def trace_v2( - ids: str | list[str], - direction: Literal["in", "out"] = Field(...), - edge_types: list[str] = Field(...), - max_depth: int = 3, - max_paths: int = 20, - max_nodes_discovered: int = 500, - filter: NodeFilter | dict[str, Any] | str | None = None, - edge_filter: EdgeFilter | dict[str, Any] | str | None = None, - prune_roles: list[str] | None = None, - fan_out_cap: int = 5, - collapse_trivial: bool = True, - include_unresolved: bool = False, - cross_service: bool = False, - graph: KuzuGraph | None = None, -) -> TraceOutput: - """Multi-hop BFS traversal with pruning.""" - # Validate required parameters. 
- if not direction: - return TraceOutput( - success=False, - seed_ids=[], - direction="", - edge_types=[], - message="direction is required (in or out)", - ) - - if not edge_types: - return TraceOutput( - success=False, - seed_ids=[], - direction=direction, - edge_types=[], - message="edge_types is required and non-empty", - ) - - # Validate edge types. - unknown = [et for et in edge_types if et not in _TRACE_EDGE_TYPES] - if unknown: - return TraceOutput( - success=False, - seed_ids=[], - direction=direction, - edge_types=edge_types, - message=( - f"Unknown edge type(s): {unknown}. " - f"Valid types: {sorted(_TRACE_EDGE_TYPES)}. " - "Composed keys (e.g., DECLARES.DECLARES_CLIENT) are not supported." + # Build adjacency: from_id -> list of edges from that node. + adj: dict[str, list[_FlatEdge]] = defaultdict(list) + for e in edges: + adj[e.from_id].append(e) + + # Track which nodes are already placed (seed nodes get placed first). + placed: set[str] = set() + + def _make_tree_node(edge: _FlatEdge, target_id: str) -> TreeNode: + """Create a TreeNode from an edge targeting target_id.""" + child = TreeNode( + id=target_id, + edge_from_parent=EdgeFromParent( + direction=edge.direction, + edge_type=edge.edge_type, + hop=edge.hop, + confidence=edge.confidence, + cross_service_boundary=edge.cross_service_boundary, + attrs=edge.attrs, ), + collapsed=edge.collapsed, + collapsed_intermediates=list(edge.collapsed_intermediates), ) + placed.add(target_id) + # Recurse into children of this target. + child_edges = adj.get(target_id, []) + for ce in child_edges: + if ce.to_id not in placed: + child.children.append(_make_tree_node(ce, ce.to_id)) + return child + + result: list[TreeNode] = [] + for sid in seed_ids: + placed.add(sid) + seed_node = TreeNode(id=sid, edge_from_parent=None) + for e in adj.get(sid, []): + if e.to_id not in placed: + seed_node.children.append(_make_tree_node(e, e.to_id)) + result.append(seed_node) - # Clamp max_depth. 
- max_depth = max(1, min(5, int(max_depth))) - - # Clamp max_nodes_discovered. - max_nodes_discovered = max(100, min(2000, int(max_nodes_discovered))) - - # Normalize seed IDs. - seed_ids = [ids] if isinstance(ids, str) else list(ids) - - if not seed_ids: - return TraceOutput( - success=True, - seed_ids=[], - direction=direction, - edge_types=edge_types, - nodes={}, - edges=[], - paths=[], - stats=TraceStats(budget_limit=max_nodes_discovered), - ) - - # Validate NodeFilter. - try: - if isinstance(filter, str): - import json - - filter = json.loads(filter) if filter.strip() else None - nf = NodeFilter.model_validate(filter) if filter is not None and not isinstance(filter, NodeFilter) else filter - except Exception as exc: - return TraceOutput( - success=False, - seed_ids=seed_ids, - direction=direction, - edge_types=edge_types, - message=f"Invalid filter: {exc}", - ) + return result - # Validate EdgeFilter. - try: - if isinstance(edge_filter, str): - import json - edge_filter = json.loads(edge_filter) if edge_filter.strip() else None - ef = ( - EdgeFilter.model_validate(edge_filter) - if edge_filter is not None and not isinstance(edge_filter, EdgeFilter) - else edge_filter - ) - except Exception as exc: - return TraceOutput( - success=False, - seed_ids=seed_ids, - direction=direction, - edge_types=edge_types, - message=f"Invalid edge_filter: {exc}", - ) - - # Get graph instance. - g = graph or KuzuGraph.get() - - # Normalized prune_roles set. - prune_role_set = set(prune_roles) if prune_roles else set() +def _build_ranked_leaves( + tree: list[TreeNode], + nodes: dict[str, NodeRef], + max_paths: int, +) -> list[RankedLeaf]: + """Walk tree to find leaf nodes, score and rank them.""" + if not tree: + return [] + candidates: list[RankedLeaf] = [] + + def _walk(node: TreeNode, depth: int) -> None: + if not node.children: + # This is a leaf. 
+ node_ref = nodes.get(node.id) + leaf_role = node_ref.role if node_ref else None + role_score = _role_priority(leaf_role) + # Confidence from edge_from_parent (if any). + conf = 1.0 + if node.edge_from_parent and node.edge_from_parent.confidence is not None: + conf = node.edge_from_parent.confidence + score = role_score + conf + candidates.append(RankedLeaf( + node_id=node.id, + depth=depth, + leaf_role=leaf_role, + score=score, + )) + else: + for child in node.children: + _walk(child, depth + 1) + + for seed in tree: + _walk(seed, 0) + + candidates.sort(key=lambda r: -r.score) + return candidates[:max_paths] + + +def _run_bfs( + *, + graph: KuzuGraph, + seed_ids: list[str], + direction: Literal["in", "out"], + edge_types: list[str], + max_depth: int, + max_nodes_discovered: int, + nf: NodeFilter | None, + ef: EdgeFilter | None, + prune_role_set: set[str], + fan_out_cap: int, + cross_service: bool, + include_unresolved: bool, + visited: set[str], +) -> tuple[dict[str, NodeRef], list[_FlatEdge], int, int, bool, int, int]: + """Run a single-direction BFS pass. + + Returns (nodes, edges, total_discovered, actual_depth, budget_hit, + nodes_pruned_role, nodes_pruned_fan_out). + """ # Determine if cross-service detection is active. cross_service_active = bool(set(edge_types) & _CROSS_SERVICE_EDGE_TYPES) - # Effective scaffolding set: when cross_service=True, EXPOSES is also scaffolding - # so Route -> Handler is followed automatically in downstream services. + # Effective scaffolding set. effective_scaffolding = _SCAFFOLDING_EDGE_TYPES if cross_service: effective_scaffolding = _SCAFFOLDING_EDGE_TYPES | frozenset({"EXPOSES"}) # BFS state. 
- visited: set[str] = set(seed_ids) - frontier: list[str] = list(seed_ids) - edges: list[TraceEdge] = [] + frontier: list[str] = [sid for sid in seed_ids if sid not in visited] + for sid in seed_ids: + visited.add(sid) + + edges: list[_FlatEdge] = [] nodes: dict[str, NodeRef] = {} - edge_id_map: dict[str, TraceEdge] = {} - total_discovered = len(seed_ids) # Count seeds as discovered + total_discovered = len(seed_ids) actual_depth = 0 budget_hit = False nodes_pruned_role = 0 nodes_pruned_fan_out = 0 - # Track incoming edge ID for each node (for parent_edge_id). - node_to_incoming_edge_id: dict[str, str] = {} - - # For seed nodes, record them in nodes dict (always include seeds, filter doesn't apply). + # Record seed nodes. for sid in seed_ids: try: - kind = _resolve_node_kind(g, sid) - row = _load_node_record(g, sid, kind) + kind = _resolve_node_kind(graph, sid) + row = _load_node_record(graph, sid, kind) if row is not None: nodes[sid] = _node_ref_from_row(kind, row) except Exception: @@ -642,17 +696,15 @@ def trace_v2( actual_depth = hop + 1 - # Determine which edge types to query in this hop. + # Determine edge types to query. query_edge_types = list(edge_types) - # Cross-service: also query scaffolding edges when cross-service is active. if cross_service_active: for scaffold_et in effective_scaffolding: if scaffold_et not in query_edge_types: query_edge_types.append(scaffold_et) - # Batch query for all frontier nodes. rows = _neighbors_batched( - g, + graph, node_ids=frontier, direction=direction, edge_types=query_edge_types, @@ -667,11 +719,12 @@ def trace_v2( continue by_source[src_id].append(row) - # Process discovered edges. new_frontier: set[str] = set() for src_id, src_rows in by_source.items(): - parent_edge_id = node_to_incoming_edge_id.get(src_id) + # Source role for source-relative ranking. 
+ src_ref = nodes.get(src_id) + source_role = src_ref.role if src_ref else None # --- Fan-out cap: separate scaffolding from signal edges --- scaffolding_rows: list[dict[str, Any]] = [] @@ -685,18 +738,16 @@ def trace_v2( signal_rows.append(row) # Sort signal rows by ranking key for fan-out cap. - signal_rows.sort(key=lambda r: _fan_out_sort_key(r, nodes)) + signal_rows.sort(key=lambda r: _fan_out_sort_key(r, nodes, source_role)) # Apply fan_out_cap to signal edges only. if fan_out_cap > 0 and len(signal_rows) > fan_out_cap: - # Count pruned nodes (those we're dropping). for dropped_row in signal_rows[fan_out_cap:]: dropped_id = str(dropped_row.get("other_id") or "") if dropped_id and dropped_id not in visited: nodes_pruned_fan_out += 1 signal_rows = signal_rows[:fan_out_cap] - # Combine: scaffolding always included, then capped signal. capped_rows = scaffolding_rows + signal_rows for row in capped_rows: @@ -711,18 +762,15 @@ def trace_v2( # --- Cross-service boundary detection --- if edge_type in effective_scaffolding and cross_service_active: - # Follow scaffolding edge to Client/Producer node. - # Record the scaffolding edge and include the node. try: - other_kind = _resolve_node_kind(g, other_id) - other_rec = _load_node_record(g, other_id, other_kind) + other_kind = _resolve_node_kind(graph, other_id) + other_rec = _load_node_record(graph, other_id, other_kind) if other_rec is None: continue except Exception: print(f"[trace] cross-service: failed to resolve {other_id}", file=sys.stderr) continue - # Check budget. 
if total_discovered >= max_nodes_discovered: budget_hit = True break @@ -731,27 +779,27 @@ def trace_v2( if other_id not in nodes: nodes[other_id] = _node_ref_from_row(other_kind, other_rec) - edge_id = f"{src_id}:{other_id}:{edge_type}:{hop}" - edge = TraceEdge( + conf = row.get("confidence") + confidence = float(conf) if conf is not None else None + edge = _FlatEdge( from_id=src_id, to_id=other_id, edge_type=edge_type, hop=hop, - parent_edge_id=parent_edge_id, + direction=direction, + confidence=confidence, attrs=_edge_attrs_for_row(row), ) edges.append(edge) - edge_id_map[edge_id] = edge visited.add(other_id) - # Now follow HTTP_CALLS/ASYNC_CALLS from Client/Producer. - # Determine which cross-service edge types to follow. + # Follow HTTP_CALLS/ASYNC_CALLS from Client/Producer. active_cross_types = list(set(edge_types) & _CROSS_SERVICE_EDGE_TYPES) if not active_cross_types: continue cross_rows = _neighbors_batched( - g, + graph, node_ids=[other_id], direction=direction, edge_types=active_cross_types, @@ -767,7 +815,6 @@ def trace_v2( if cross_et not in _CROSS_SERVICE_EDGE_TYPES: continue - # Check budget. 
if total_discovered >= max_nodes_discovered: budget_hit = True break @@ -775,8 +822,8 @@ def trace_v2( total_discovered += 1 try: - cross_kind = _resolve_node_kind(g, cross_target_id) - cross_rec = _load_node_record(g, cross_target_id, cross_kind) + cross_kind = _resolve_node_kind(graph, cross_target_id) + cross_rec = _load_node_record(graph, cross_target_id, cross_kind) if cross_rec is None: continue except Exception: @@ -785,43 +832,36 @@ def trace_v2( if cross_target_id not in nodes: nodes[cross_target_id] = _node_ref_from_row(cross_kind, cross_rec) - cross_edge_id = f"{other_id}:{cross_target_id}:{cross_et}:{hop + 1}" - cross_edge = TraceEdge( + cross_conf = cross_row.get("confidence") + cross_confidence = float(cross_conf) if cross_conf is not None else None + cross_edge = _FlatEdge( from_id=other_id, to_id=cross_target_id, edge_type=cross_et, hop=hop + 1, - parent_edge_id=edge_id, + direction=direction, + confidence=cross_confidence, cross_service_boundary=True, attrs=_edge_attrs_for_row(cross_row), ) edges.append(cross_edge) - edge_id_map[cross_edge_id] = cross_edge visited.add(cross_target_id) - # Track incoming edge for downstream node. - if cross_target_id not in node_to_incoming_edge_id: - node_to_incoming_edge_id[cross_target_id] = cross_edge_id - # When cross_service=True, add downstream node to frontier - # so BFS continues into the downstream service. + if cross_service: new_frontier.add(cross_target_id) - # Do NOT add Client/Producer to frontier — its cross-service - # edges were already queried inline above. continue # --- Standard edge processing --- - # Check budget BEFORE counting (only counts newly discovered nodes). if total_discovered >= max_nodes_discovered: budget_hit = True break total_discovered += 1 - # Load target node. 
try: - other_kind = _resolve_node_kind(g, other_id) - other_rec = _load_node_record(g, other_id, other_kind) + other_kind = _resolve_node_kind(graph, other_id) + other_rec = _load_node_record(graph, other_id, other_kind) if other_rec is None: continue except Exception: @@ -831,7 +871,6 @@ def trace_v2( if not _node_matches_filter(other_kind, other_rec, nf): continue - # Record target node. if other_id not in nodes: nodes[other_id] = _node_ref_from_row(other_kind, other_rec) @@ -843,24 +882,19 @@ def trace_v2( is_pruned = True nodes_pruned_role += 1 - # Record edge. - edge_id = f"{src_id}:{other_id}:{edge_type}:{hop}" - edge = TraceEdge( + conf = row.get("confidence") + confidence = float(conf) if conf is not None else None + edge = _FlatEdge( from_id=src_id, to_id=other_id, edge_type=edge_type, hop=hop, - parent_edge_id=parent_edge_id, + direction=direction, + confidence=confidence, attrs=_edge_attrs_for_row(row), ) edges.append(edge) - edge_id_map[edge_id] = edge - - # Track incoming edge ID for this node (for parent_edge_id of children). - if other_id not in node_to_incoming_edge_id: - node_to_incoming_edge_id[other_id] = edge_id - # Pruned nodes: edge recorded but NOT added to frontier. 
if not is_pruned: new_frontier.add(other_id) @@ -870,28 +904,267 @@ def trace_v2( if budget_hit: break + return nodes, edges, total_discovered, actual_depth, budget_hit, nodes_pruned_role, nodes_pruned_fan_out + + +@validate_call(config={"arbitrary_types_allowed": True}) +def trace_v2( + ids: str | list[str], + direction: Literal["in", "out", "both"] = Field(...), + edge_types: list[str] = Field(...), + max_depth: int = 3, + max_paths: int = 20, + max_nodes_discovered: int = 500, + filter: NodeFilter | dict[str, Any] | str | None = None, + edge_filter: EdgeFilter | dict[str, Any] | str | None = None, + prune_roles: list[str] | None = None, + fan_out_cap: int = 5, + collapse_trivial: bool = True, + collapse_roles: list[str] | None = None, + collapse_min_chain_length: int = 1, + include_unresolved: bool = False, + cross_service: bool = False, + min_result_nodes: int = 0, + graph: KuzuGraph | None = None, +) -> TraceOutput: + """Multi-hop BFS traversal with pruning.""" + # Validate required parameters. + if not direction: + return TraceOutput( + success=False, + seed_ids=[], + direction="", + edge_types=[], + message="direction is required (in, out, or both)", + ) + + if not edge_types: + return TraceOutput( + success=False, + seed_ids=[], + direction=direction, + edge_types=[], + message="edge_types is required and non-empty", + ) + + # Validate edge types. + unknown = [et for et in edge_types if et not in _TRACE_EDGE_TYPES] + if unknown: + return TraceOutput( + success=False, + seed_ids=[], + direction=direction, + edge_types=edge_types, + message=( + f"Unknown edge type(s): {unknown}. " + f"Valid types: {sorted(_TRACE_EDGE_TYPES)}. " + "Composed keys (e.g., DECLARES.DECLARES_CLIENT) are not supported." + ), + ) + + # Clamp max_depth. + max_depth = max(1, min(5, int(max_depth))) + + # Clamp max_nodes_discovered. + max_nodes_discovered = max(100, min(2000, int(max_nodes_discovered))) + + # Normalize seed IDs. 
+ seed_ids = [ids] if isinstance(ids, str) else list(ids) + + if not seed_ids: + return TraceOutput( + success=True, + seed_ids=[], + direction=direction, + edge_types=edge_types, + nodes={}, + tree=[], + ranked_leaves=[], + stats=TraceStats(budget_limit=max_nodes_discovered), + ) + + # Validate NodeFilter. + try: + if isinstance(filter, str): + import json + + filter = json.loads(filter) if filter.strip() else None + nf = NodeFilter.model_validate(filter) if filter is not None and not isinstance(filter, NodeFilter) else filter + except Exception as exc: + return TraceOutput( + success=False, + seed_ids=seed_ids, + direction=direction, + edge_types=edge_types, + message=f"Invalid filter: {exc}", + ) + + # Validate EdgeFilter. + try: + if isinstance(edge_filter, str): + import json + + edge_filter = json.loads(edge_filter) if edge_filter.strip() else None + ef = ( + EdgeFilter.model_validate(edge_filter) + if edge_filter is not None and not isinstance(edge_filter, EdgeFilter) + else edge_filter + ) + except Exception as exc: + return TraceOutput( + success=False, + seed_ids=seed_ids, + direction=direction, + edge_types=edge_types, + message=f"Invalid edge_filter: {exc}", + ) + + # Get graph instance. + g = graph or KuzuGraph.get() + + # Normalized prune_roles set. + prune_role_set = set(prune_roles) if prune_roles else set() + + # Collapse roles configuration. + collapse_role_set: set[str | None] | None = None + if collapse_trivial: + if collapse_roles is not None: + collapse_role_set = set(collapse_roles) + else: + collapse_role_set = {"OTHER", None} + + # Determine directions to run. + directions: list[Literal["in", "out"]] + if direction == "both": + directions = ["out", "in"] + else: + directions = [direction] # type: ignore[assignment] + + # Shared visited set for bidirectional. 
+ shared_visited: set[str] = set() + all_nodes: dict[str, NodeRef] = {} + all_edges: list[_FlatEdge] = [] + total_discovered = 0 + actual_depth = 0 + budget_hit = False + total_pruned_role = 0 + total_pruned_fan_out = 0 + for pass_idx, pass_dir in enumerate(directions): + pass_nodes, pass_edges, pass_discovered, pass_depth, pass_budget, pass_pruned_role, pass_pruned_fan_out = _run_bfs( + graph=g, + seed_ids=seed_ids, + direction=pass_dir, + edge_types=edge_types, + max_depth=max_depth, + max_nodes_discovered=max_nodes_discovered, + nf=nf, + ef=ef, + prune_role_set=prune_role_set, + fan_out_cap=fan_out_cap, + cross_service=cross_service, + include_unresolved=include_unresolved, + visited=shared_visited, + ) + + # Merge results. + for nid, nref in pass_nodes.items(): + if nid not in all_nodes: + all_nodes[nid] = nref + + all_edges.extend(pass_edges) + total_discovered += pass_discovered - len(seed_ids) # Don't double-count seeds + actual_depth = max(actual_depth, pass_depth) + budget_hit = budget_hit or pass_budget + total_pruned_role += pass_pruned_role + total_pruned_fan_out += pass_pruned_fan_out + + # Re-count seeds once. + total_discovered = len(all_nodes) + + # min_result_nodes retry. + advisories: list[str] = [] + effective_cap = fan_out_cap + if min_result_nodes > 0 and len(all_nodes) < min_result_nodes: + # One retry with doubled fan_out_cap (clamped by max_nodes_discovered). + effective_cap = min(fan_out_cap * 2, max_nodes_discovered) + if effective_cap > fan_out_cap: + # Re-run with higher cap. 
+ retry_visited: set[str] = set() + retry_nodes: dict[str, NodeRef] = {} + retry_edges: list[_FlatEdge] = [] + retry_total = 0 + retry_depth = 0 + retry_budget = False + retry_pruned_role = 0 + retry_pruned_fan_out = 0 + + for pass_idx, pass_dir in enumerate(directions): + pn, pe, pd, pdep, pb, ppr, pf = _run_bfs( + graph=g, + seed_ids=seed_ids, + direction=pass_dir, + edge_types=edge_types, + max_depth=max_depth, + max_nodes_discovered=max_nodes_discovered, + nf=nf, + ef=ef, + prune_role_set=prune_role_set, + fan_out_cap=effective_cap, + cross_service=cross_service, + include_unresolved=include_unresolved, + visited=retry_visited, + ) + for nid, nref in pn.items(): + if nid not in retry_nodes: + retry_nodes[nid] = nref + retry_edges.extend(pe) + retry_total += pd - len(seed_ids) + retry_depth = max(retry_depth, pdep) + retry_budget = retry_budget or pb + retry_pruned_role += ppr + retry_pruned_fan_out += pf + + retry_total = len(retry_nodes) + if retry_total < min_result_nodes: + advisories.append( + f"min_result_nodes retry with fan_out_cap={effective_cap} still below target " + f"({retry_total} < {min_result_nodes}). Returning available results." + ) + if retry_total >= min_result_nodes or retry_total > len(all_nodes): + all_nodes = retry_nodes + all_edges = retry_edges + total_discovered = retry_total + actual_depth = retry_depth + budget_hit = retry_budget + total_pruned_role = retry_pruned_role + total_pruned_fan_out = retry_pruned_fan_out + # Post-BFS: collapse trivial chains. edges_collapsed = 0 - if collapse_trivial: - edges_collapsed = _collapse_trivial_chains(nodes, edges, edge_id_map) + if collapse_trivial and collapse_role_set is not None: + edges_collapsed = _collapse_trivial_chains( + all_nodes, all_edges, collapse_role_set, collapse_min_chain_length, + ) + + # Build tree from flat edges. + tree = _build_tree(seed_ids, all_nodes, all_edges) + + # Build ranked leaves. + ranked_leaves = _build_ranked_leaves(tree, all_nodes, max_paths) # Build stats. 
stats = TraceStats( total_nodes_discovered=total_discovered, - total_edges_discovered=len(edges), + total_edges_discovered=len(all_edges), budget_hit=budget_hit, budget_limit=max_nodes_discovered, - nodes_pruned_role=nodes_pruned_role, - nodes_pruned_fan_out=nodes_pruned_fan_out, + nodes_pruned_role=total_pruned_role, + nodes_pruned_fan_out=total_pruned_fan_out, edges_collapsed_trivial=edges_collapsed, - nodes_after_pruning=len(nodes), - edges_after_pruning=len(edges), + nodes_after_pruning=len(all_nodes), + edges_after_pruning=len(all_edges), ) - # Enumerate paths. - paths = _enumerate_paths(nodes, edges, max_paths) - - advisories = [] if budget_hit: advisories.append( f"trace stopped early: discovered {total_discovered} nodes before budget. " @@ -904,17 +1177,18 @@ def trace_v2( direction=direction, edge_types=edge_types, actual_depth=actual_depth, - nodes=nodes, - edges=edges, - paths=paths, + nodes=all_nodes, + tree=tree, + ranked_leaves=ranked_leaves, stats=stats, advisories=advisories, ) __all__ = [ - "TraceEdge", - "TracePath", + "EdgeFromParent", + "TreeNode", + "RankedLeaf", "TraceStats", "TraceOutput", "trace_v2", diff --git a/plans/active/PLAN-TRACE-TOOL-V2.md b/plans/completed/PLAN-TRACE-TOOL-V2.md similarity index 100% rename from plans/active/PLAN-TRACE-TOOL-V2.md rename to plans/completed/PLAN-TRACE-TOOL-V2.md diff --git a/propose/active/TRACE-TOOL-V2-PROPOSE.md b/propose/completed/TRACE-TOOL-V2-PROPOSE.md similarity index 100% rename from propose/active/TRACE-TOOL-V2-PROPOSE.md rename to propose/completed/TRACE-TOOL-V2-PROPOSE.md diff --git a/server.py b/server.py index d9957ade..6763ccd0 100644 --- a/server.py +++ b/server.py @@ -572,7 +572,7 @@ async def resolve( @mcp.tool( name="trace", description=( - "Multi-hop BFS traversal with server-side pruning. Returns pruned path structure in a single call. " + "Multi-hop BFS traversal with server-side pruning. Returns nested tree structure in a single call. 
" "Use `trace` instead of multiple `neighbors` calls when: (a) the question implies a path or chain " "(e.g. 'trace from controller to database', 'what happens when POST /api/orders is called'), " "(b) you need impact analysis ('who depends on X'), (c) you need to cross service boundaries " @@ -584,25 +584,28 @@ async def resolve( "`prune_roles` is a soft gate: edges to pruned-role nodes are recorded but BFS stops traversing " "through them (agent sees the connection but traversal focuses on higher-signal paths). " "`fan_out_cap` limits per-node edge expansion; scaffolding edges (DECLARES_CLIENT, DECLARES_PRODUCER) " - "are exempt. `collapse_trivial` merges wrapper chains (A→B→C where B is trivial). " - "Result: `nodes` dict (id → NodeRef), `edges` list with BFS metadata (hop, parent_edge_id, " - "collapsed, cross_service_boundary), ranked `paths` (root-to-leaf), and `stats` with pruning counts. " + "are exempt. `collapse_trivial` merges wrapper chains (A→B→C where B is trivial); " + "`collapse_roles` and `collapse_min_chain_length` configure which roles are collapsed and minimum chain length. " + "`min_result_nodes` triggers a retry with doubled fan_out_cap if initial result is below target. " + "Result: `nodes` dict (id → NodeRef), nested `tree` (TreeNodes with edge_from_parent metadata), " + "`ranked_leaves` (scored leaf nodes), and `stats` with pruning counts. " "Cross-service boundary: by default BFS stops at service boundaries. " - "Set `cross_service=True` to continue traversal through HTTP_CALLS/ASYNC_CALLS boundaries." + "Set `cross_service=True` to continue traversal through HTTP_CALLS/ASYNC_CALLS boundaries. " + "`direction='both'` runs bidirectional traversal (out + in) with shared visited set for impact analysis." ), ) async def trace( ids: str | list[str] = Field( description="Seed node IDs (single string or list). 
Differs from neighbors (single ID) — trace supports multi-seed for impact analysis.", ), - direction: Literal["in", "out"] = Field( - description="Traversal direction: in (callers/dependents) or out (callees/dependencies). Required — no default.", + direction: Literal["in", "out", "both"] = Field( + description="Traversal direction: in (callers/dependents), out (callees/dependencies), or both (bidirectional for impact analysis). Required — no default.", ), edge_types: list[str] = Field( description="Edge types to traverse (stored labels only: CALLS, IMPLEMENTS, OVERRIDES, EXPOSES, HTTP_CALLS, ASYNC_CALLS, etc.). Required non-empty. No composed dot-keys.", ), max_depth: int = Field(default=3, description="Max BFS hops (1-5, default 3)"), - max_paths: int = Field(default=20, description="Max root-to-leaf paths to return"), + max_paths: int = Field(default=20, description="Max ranked leaves to return"), max_nodes_discovered: int = Field( default=500, description="Node discovery budget before pruning (100-2000)", ), @@ -624,6 +627,13 @@ async def trace( collapse_trivial: bool = Field( default=True, description="Collapse wrapper chains (A→B→C where B is trivial intermediate)", ), + collapse_roles: list[str] | None = Field( + default=None, + description="Roles to collapse as trivial intermediates (default: ['OTHER']). Only takes effect when collapse_trivial=True.", + ), + collapse_min_chain_length: int = Field( + default=1, description="Minimum chain length for collapse (default 1). Set to 2 to skip single-intermediate collapses.", + ), include_unresolved: bool = Field( default=False, description="Include UnresolvedCallSite edges (CALLS out only)", @@ -632,6 +642,10 @@ async def trace( default=False, description="Continue BFS through service boundaries (HTTP_CALLS/ASYNC_CALLS). Default: stop at boundaries.", ), + min_result_nodes: int = Field( + default=0, + description="Minimum result nodes target. 
If initial BFS produces fewer, retries with doubled fan_out_cap (one retry max).", + ), ) -> mcp_trace.TraceOutput: return await asyncio.to_thread( mcp_trace.trace_v2, @@ -646,8 +660,11 @@ async def trace( prune_roles, fan_out_cap if fan_out_cap is not None else 5, collapse_trivial, + collapse_roles, + collapse_min_chain_length, include_unresolved, cross_service, + min_result_nodes, None, ) diff --git a/skills/explore-codebase/SKILL.md b/skills/explore-codebase/SKILL.md index ee4d2148..886c0649 100644 --- a/skills/explore-codebase/SKILL.md +++ b/skills/explore-codebase/SKILL.md @@ -179,7 +179,7 @@ Prefer **`resolve` → `describe(id=…)`** over **`describe(fqn=…)`** when an | Handler for route | route id | `neighbors(ids, "in", ["EXPOSES"])` | | Who implements interface T? | type symbol id | `neighbors(ids, "in", ["IMPLEMENTS"])` | | Where is T injected | type symbol id | `neighbors(ids, "in", ["INJECTS"])` | -| Impact / "what breaks if I change X"? | no magic tool | loop `neighbors` `in` with `CALLS`, `INJECTS`, … until bounded | +| Impact / "what breaks if I change X"? | `trace(id, "both", ["CALLS","OVERRIDES"], max_depth=3)` | `describe` on callers | | "What happens when route R is called?" | `find(kind="route")` then `trace(route_id, "out", ["EXPOSES","CALLS"], max_depth=4)` | `describe` on key nodes | | "Impact of changing method M" | `resolve` / `find` then `trace(id, "in", ["CALLS","OVERRIDES"], max_depth=3)` | `describe` on callers | | "Trace from X to database" | `trace(id, "out", ["CALLS"], max_depth=4, prune_roles=["DTO","EXCEPTION"])` | `neighbors` for pruned detail | @@ -229,11 +229,11 @@ Returns **edges** with `attrs` (`confidence`, `strategy`, `match`, … on cross- ### `trace` -Multi-hop BFS with pruning. 
Args: `ids` (string or list), **`direction`**, **`edge_types`** (stored labels only — no composed dot-keys), `max_depth` (default 3, clamped 1–5), `max_paths` (default 20), `max_nodes_discovered` (default 500, clamped 100–2000), optional `filter` (NodeFilter), optional `edge_filter` (CALLS only), optional `prune_roles` (soft gate — edges recorded, frontier stops), `fan_out_cap` (default 5, scaffolding edges exempt), `collapse_trivial` (default true), `include_unresolved` (default false), `cross_service` (default false — set true to continue BFS through HTTP_CALLS/ASYNC_CALLS boundaries into downstream services). +Multi-hop BFS with pruning. Args: `ids` (string or list), **`direction`** (`in` | `out` | `both`), **`edge_types`** (stored labels only — no composed dot-keys), `max_depth` (default 3, clamped 1–5), `max_paths` (default 20), `max_nodes_discovered` (default 500, clamped 100–2000), optional `filter` (NodeFilter), optional `edge_filter` (CALLS only), optional `prune_roles` (soft gate — edges recorded, frontier stops), `fan_out_cap` (default 5, scaffolding edges exempt), `collapse_trivial` (default true), `collapse_roles` (roles to collapse as trivial intermediates, default `["OTHER"]`; only effective when `collapse_trivial=true`), `collapse_min_chain_length` (minimum chain length for collapse, default 1), `include_unresolved` (default false), `cross_service` (default false — set true to continue BFS through HTTP_CALLS/ASYNC_CALLS boundaries into downstream services), `min_result_nodes` (minimum result nodes target; retries with doubled `fan_out_cap` if below target, default 0). -Returns `TraceOutput`: `success`, `seed_ids`, `direction`, `edge_types`, `actual_depth`, `nodes` (dict of id→NodeRef), `edges` (list of `TraceEdge`), `paths` (list of `TracePath`), `stats` (`TraceStats`), `advisories`. 
+Returns `TraceOutput`: `success`, `seed_ids`, `direction`, `edge_types`, `actual_depth`, `nodes` (dict of id→NodeRef), `tree` (nested `TreeNode` list — one per seed; each `TreeNode` has `id`, `edge_from_parent` with `direction`, `edge_type`, `hop`, `confidence`, `cross_service_boundary`, `attrs`; `children`; `collapsed`, `collapsed_intermediates`), `ranked_leaves` (scored leaf nodes with `node_id`, `depth`, `leaf_role`, `score`), `stats` (`TraceStats`), `advisories`. -**`TraceEdge`**: `from_id`, `to_id`, `edge_type`, `hop` (BFS depth), `parent_edge_id` (nullable), `collapsed` (bool), `collapsed_intermediates` (list of node ids), `cross_service_boundary` (bool), `attrs`. +**`direction="both"`**: bidirectional traversal (out + in) with shared visited set. Tree contains children from both directions; `edge_from_parent.direction` distinguishes them. Use for impact analysis in one call. **When to use `trace` vs `neighbors`:** - Use `neighbors` for single-hop adjacency where you want the full unfiltered result. @@ -305,7 +305,7 @@ These patterns combine the five tools above. 
Use the decision tree to pick the r | All inbound to route R | `neighbors(route_id, "in", ["HTTP_CALLS","ASYNC_CALLS","EXPOSES"])` | | Implementors of interface T | `neighbors(type_id, "in", ["IMPLEMENTS"])` | | Where is T injected | `neighbors(type_id, "in", ["INJECTS"])` | -| Impact of changing X | `resolve` → `describe` → bounded `neighbors(in, ["CALLS","INJECTS","IMPLEMENTS","EXTENDS"])` depth ≤2 | +| Impact of changing X | `resolve` → `trace(id, "both", ["CALLS","OVERRIDES"], max_depth=3)` → `describe` on key callers | ## Canonical workflow: "explain feature X" diff --git a/tests/test_mcp_hints.py b/tests/test_mcp_hints.py index 30e576a8..794d5cab 100644 --- a/tests/test_mcp_hints.py +++ b/tests/test_mcp_hints.py @@ -1336,7 +1336,7 @@ def test_hint_trace_budget_hit() -> None: "total_nodes_discovered": 500, "nodes_after_pruning": 120, }, - "edges": [], + "tree": [], "nodes": {}, "seed_ids": ["sym:com.example.Svc#run()"], "direction": "out", @@ -1358,7 +1358,7 @@ def test_hint_trace_pruned_edges() -> None: "edges_collapsed_trivial": 2, "total_nodes_discovered": 100, }, - "edges": [], + "tree": [], "nodes": {}, "seed_ids": ["sym:a"], "direction": "out", @@ -1370,7 +1370,7 @@ def test_hint_trace_pruned_edges() -> None: def test_hint_trace_cross_service_boundary() -> None: - """generate_hints('trace', {edges with cross_service_boundary=True}) returns cross-service hint.""" + """generate_hints('trace', {tree with cross_service_boundary=True}) returns cross-service hint.""" struct, advisories = generate_hints("trace", { "success": True, "stats": { @@ -1380,12 +1380,43 @@ def test_hint_trace_cross_service_boundary() -> None: "nodes_pruned_fan_out": 0, "edges_collapsed_trivial": 0, }, - "edges": [ + "tree": [ { - "from_id": "client:svc:PaymentClient", - "to_id": "route:payment-service:/api/payments:POST", - "cross_service_boundary": True, - "attrs": {"confidence": 0.85, "strategy": "URI_PATH_MATCH"}, + "id": "sym:a", + "edge_from_parent": None, + "children": [ + { + 
"id": "client:svc:PaymentClient", + "edge_from_parent": { + "direction": "out", + "edge_type": "DECLARES_CLIENT", + "hop": 0, + "confidence": None, + "cross_service_boundary": False, + "attrs": {}, + }, + "children": [ + { + "id": "route:payment-service:/api/payments:POST", + "edge_from_parent": { + "direction": "out", + "edge_type": "HTTP_CALLS", + "hop": 1, + "confidence": 0.85, + "cross_service_boundary": True, + "attrs": {"confidence": 0.85, "strategy": "URI_PATH_MATCH"}, + }, + "children": [], + "collapsed": False, + "collapsed_intermediates": [], + }, + ], + "collapsed": False, + "collapsed_intermediates": [], + }, + ], + "collapsed": False, + "collapsed_intermediates": [], }, ], "nodes": { diff --git a/tests/test_mcp_trace.py b/tests/test_mcp_trace.py index 1f63e6d5..87ade661 100644 --- a/tests/test_mcp_trace.py +++ b/tests/test_mcp_trace.py @@ -1,4 +1,5 @@ -"""Tests for mcp_trace.py (PR-TRACE-1a core BFS + PR-TRACE-1b pruning/collapsing/cross-service). +"""Tests for mcp_trace.py (PR-TRACE-V2: tree output, configurable collapse, +source-relative ranking, bidirectional traversal, min_result_nodes retry). All tests use the bank-chat kuzu_graph session fixture from conftest.py. 
""" @@ -7,7 +8,7 @@ import pytest from kuzu_queries import KuzuGraph -from mcp_trace import trace_v2 +from mcp_trace import trace_v2, TreeNode from mcp_v2 import NodeFilter @@ -48,8 +49,27 @@ def _find_method_with_multiple_callees(kuzu_graph: KuzuGraph, min_callees: int = return None +def _walk_tree(tree: list[TreeNode]) -> list[TreeNode]: + """Flatten tree nodes for inspection.""" + result: list[TreeNode] = [] + stack = list(tree) + while stack: + node = stack.pop() + result.append(node) + stack.extend(node.children) + return result + + +def _find_tree_node_by_id(tree: list[TreeNode], node_id: str) -> TreeNode | None: + """Find a tree node by its id.""" + for node in _walk_tree(tree): + if node.id == node_id: + return node + return None + + def test_trace_outbound_calls_depth_2(kuzu_graph: KuzuGraph) -> None: - """Traces from a method via CALLS out, depth 2, returns edges at hop 0 and hop 1.""" + """Traces from a method via CALLS out, depth 2, returns tree with nested children.""" seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) if seed_id is None: pytest.skip("No method with multiple callees in fixture") @@ -61,13 +81,13 @@ def test_trace_outbound_calls_depth_2(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert out.success is True - assert len(out.edges) > 0 assert out.seed_ids == [seed_id] assert out.direction == "out" assert out.edge_types == ["CALLS"] - # Check that we have edges at hop 0 and possibly hop 1. - hops = {e.hop for e in out.edges} - assert 0 in hops and hops <= {0, 1} + # Tree should have the seed as root with children. 
+ assert len(out.tree) >= 1 + assert out.tree[0].id == seed_id + assert len(out.tree[0].children) > 0 def test_trace_inbound_callers_depth_2(kuzu_graph: KuzuGraph) -> None: @@ -85,10 +105,12 @@ def test_trace_inbound_callers_depth_2(kuzu_graph: KuzuGraph) -> None: assert out.success is True assert out.seed_ids == [seed_id] assert out.direction == "in" + assert len(out.tree) >= 1 + assert out.tree[0].id == seed_id def test_trace_max_paths_cap(kuzu_graph: KuzuGraph) -> None: - """Result paths list does not exceed max_paths.""" + """Result ranked_leaves list does not exceed max_paths.""" seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=5) if seed_id is None: pytest.skip("No method with multiple callees in fixture") @@ -101,7 +123,7 @@ def test_trace_max_paths_cap(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert out.success is True - assert len(out.paths) <= 5 + assert len(out.ranked_leaves) <= 5 def test_trace_budget_stops_early(kuzu_graph: KuzuGraph) -> None: @@ -114,11 +136,10 @@ def test_trace_budget_stops_early(kuzu_graph: KuzuGraph) -> None: direction="out", edge_types=["CALLS"], max_depth=5, - max_nodes_discovered=100, # Use minimum valid value (clamped to 100) + max_nodes_discovered=100, graph=kuzu_graph, ) assert out.success is True - # If we discovered more than the budget (100), budget_hit should be True. if out.stats.total_nodes_discovered >= 100: assert out.stats.budget_hit is True assert any("budget" in adv for adv in out.advisories) @@ -132,7 +153,6 @@ def test_trace_depth_1_equivalent_to_neighbors(kuzu_graph: KuzuGraph) -> None: if seed_id is None: pytest.skip("No method with outbound calls in fixture") - # Get neighbors result. neigh_out = neighbors_v2( ids=seed_id, direction="out", @@ -142,7 +162,6 @@ def test_trace_depth_1_equivalent_to_neighbors(kuzu_graph: KuzuGraph) -> None: ) assert neigh_out.success is True - # Get trace result. 
trace_out = trace_v2( ids=seed_id, direction="out", @@ -152,14 +171,9 @@ def test_trace_depth_1_equivalent_to_neighbors(kuzu_graph: KuzuGraph) -> None: ) assert trace_out.success is True - # Compare node IDs (trace nodes dict vs neighbors results). trace_node_ids = set(trace_out.nodes.keys()) neigh_node_ids = {e.other.id for e in neigh_out.results} - - # Seed is in trace nodes, neighbors doesn't include seed. trace_node_ids.discard(seed_id) - - # They should have significant overlap (allowing for filter differences). assert len(trace_node_ids & neigh_node_ids) >= min(len(trace_node_ids), len(neigh_node_ids)) * 0.8 @@ -176,14 +190,12 @@ def test_trace_stats_counts(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert out.success is True - assert out.stats.edges_after_pruning == len(out.edges) assert out.stats.nodes_after_pruning == len(out.nodes) - assert out.stats.total_edges_discovered == len(out.edges) assert out.stats.total_nodes_discovered >= len(out.nodes) def test_trace_empty_seed(kuzu_graph: KuzuGraph) -> None: - """Empty seed ids returns success=True, nodes={}, edges=[], paths=[].""" + """Empty seed ids returns success=True, tree=[], ranked_leaves=[].""" out = trace_v2( ids=[], direction="out", @@ -193,8 +205,8 @@ def test_trace_empty_seed(kuzu_graph: KuzuGraph) -> None: assert out.success is True assert out.seed_ids == [] assert out.nodes == {} - assert out.edges == [] - assert out.paths == [] + assert out.tree == [] + assert out.ranked_leaves == [] def test_trace_single_string_seed(kuzu_graph: KuzuGraph) -> None: @@ -203,14 +215,14 @@ def test_trace_single_string_seed(kuzu_graph: KuzuGraph) -> None: if seed_id is None: pytest.skip("No method with outbound calls in fixture") out = trace_v2( - ids=seed_id, # Pass as string, not list + ids=seed_id, direction="out", edge_types=["CALLS"], graph=kuzu_graph, ) assert out.success is True assert out.seed_ids == [seed_id] - assert seed_id in out.nodes or len(out.edges) >= 0 + assert seed_id in out.nodes def 
test_trace_multiple_seeds(kuzu_graph: KuzuGraph) -> None: @@ -228,7 +240,8 @@ def test_trace_multiple_seeds(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True assert set(out.seed_ids) == {seed_id1, seed_id2} - # Shared visited set means we don't double-count nodes. + # Multi-seed: tree has one root per seed. + assert len(out.tree) >= 1 def test_trace_invalid_edge_type(kuzu_graph: KuzuGraph) -> None: @@ -254,7 +267,6 @@ def test_trace_direction_required(kuzu_graph: KuzuGraph) -> None: seed_id = _find_method_with_outbound_calls(kuzu_graph) if seed_id is None: pytest.skip("No method with outbound calls in fixture") - # Pydantic validation rejects empty string before our code runs. with pytest.raises(ValidationError, match="direction"): trace_v2( ids=seed_id, @@ -272,7 +284,7 @@ def test_trace_edge_types_required(kuzu_graph: KuzuGraph) -> None: out = trace_v2( ids=seed_id, direction="out", - edge_types=[], # Empty list + edge_types=[], graph=kuzu_graph, ) assert out.success is False @@ -285,7 +297,6 @@ def test_trace_max_depth_clamped(kuzu_graph: KuzuGraph) -> None: seed_id = _find_method_with_outbound_calls(kuzu_graph) if seed_id is None: pytest.skip("No method with outbound calls in fixture") - # Test max_depth=0 (clamped to 1). out = trace_v2( ids=seed_id, direction="out", @@ -296,7 +307,6 @@ def test_trace_max_depth_clamped(kuzu_graph: KuzuGraph) -> None: assert out.success is True assert out.actual_depth <= 1 - # Test max_depth=10 (clamped to 5). out = trace_v2( ids=seed_id, direction="out", @@ -313,7 +323,6 @@ def test_trace_budget_clamped(kuzu_graph: KuzuGraph) -> None: seed_id = _find_method_with_outbound_calls(kuzu_graph) if seed_id is None: pytest.skip("No method with outbound calls in fixture") - # Test budget=50 (clamped to 100). 
out = trace_v2( ids=seed_id, direction="out", @@ -324,7 +333,6 @@ def test_trace_budget_clamped(kuzu_graph: KuzuGraph) -> None: assert out.success is True assert out.stats.budget_limit >= 100 - # Test budget=5000 (clamped to 2000). out = trace_v2( ids=seed_id, direction="out", @@ -338,7 +346,6 @@ def test_trace_budget_clamped(kuzu_graph: KuzuGraph) -> None: def test_trace_visited_set_no_cycles(kuzu_graph: KuzuGraph) -> None: """BFS does not revisit nodes even if cycles exist in the graph.""" - # Find a cycle: A -> B -> A. rows = kuzu_graph._rows( # noqa: SLF001 """ MATCH (a:Symbol)-[:CALLS]->(b:Symbol)-[:CALLS]->(a:Symbol) @@ -357,19 +364,17 @@ def test_trace_visited_set_no_cycles(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert out.success is True - # Count unique from_id -> to_id pairs. - edge_pairs = {(e.from_id, e.to_id) for e in out.edges} - # No duplicate edges despite cycles. - assert len(edge_pairs) == len(out.edges) + # No duplicate node IDs in tree walk. + all_nodes = _walk_tree(out.tree) + node_ids = [n.id for n in all_nodes] + assert len(node_ids) == len(set(node_ids)) def test_trace_filter_applied(kuzu_graph: KuzuGraph) -> None: """NodeFilter restricts discovered nodes (hard gate — excluded entirely).""" - # Find a method with outbound calls. seed_id = _find_method_with_outbound_calls(kuzu_graph) if seed_id is None: pytest.skip("No method with outbound calls in fixture") - # First, get unfiltered count. unfiltered = trace_v2( ids=seed_id, direction="out", @@ -378,9 +383,8 @@ def test_trace_filter_applied(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert unfiltered.success is True - unfiltered_count = len(unfiltered.edges) + unfiltered_count = len(unfiltered.nodes) - 1 # Exclude seed - # Now filter by role (e.g., only SERVICE). 
filtered = trace_v2( ids=seed_id, direction="out", @@ -390,8 +394,8 @@ def test_trace_filter_applied(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert filtered.success is True - # Filtered result should have <= unfiltered edges. - assert len(filtered.edges) <= unfiltered_count + filtered_count = len(filtered.nodes) - 1 + assert filtered_count <= unfiltered_count def test_trace_filter_vs_prune_roles(kuzu_graph: KuzuGraph) -> None: @@ -400,7 +404,6 @@ def test_trace_filter_vs_prune_roles(kuzu_graph: KuzuGraph) -> None: if seed_id is None: pytest.skip("No method with outbound calls in fixture") - # First, discover what roles exist in the result. baseline = trace_v2( ids=seed_id, direction="out", @@ -410,7 +413,6 @@ def test_trace_filter_vs_prune_roles(kuzu_graph: KuzuGraph) -> None: ) assert baseline.success is True - # Find a role present in the result to test against (exclude seed from consideration). roles_in_result = { n.role for nid, n in baseline.nodes.items() if n.role and nid != seed_id @@ -420,7 +422,6 @@ def test_trace_filter_vs_prune_roles(kuzu_graph: KuzuGraph) -> None: test_role = next(iter(roles_in_result)) - # NodeFilter exclude_roles: hard gate — nodes and edges removed entirely. filtered = trace_v2( ids=seed_id, direction="out", @@ -430,12 +431,10 @@ def test_trace_filter_vs_prune_roles(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert filtered.success is True - # No non-seed nodes with the excluded role should appear. assert not any( n.role == test_role for nid, n in filtered.nodes.items() if nid != seed_id ) - # prune_roles: soft gate — edges recorded, frontier stops through pruned nodes. pruned = trace_v2( ids=seed_id, direction="out", @@ -445,15 +444,12 @@ def test_trace_filter_vs_prune_roles(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert pruned.success is True - # Pruned nodes ARE in the result (edges recorded). 
- assert any(n.role == test_role for n in pruned.nodes.values()) or len(pruned.edges) >= 0 - # prune_roles result should have more edges than filtered (soft vs hard gate). - assert len(pruned.edges) >= len(filtered.edges) + assert any(n.role == test_role for n in pruned.nodes.values()) or len(pruned.nodes) >= 0 + assert len(pruned.nodes) >= len(filtered.nodes) def test_trace_edge_filter_calls(kuzu_graph: KuzuGraph) -> None: """EdgeFilter with min_confidence filters CALLS edges during traversal.""" - # Find a method with outbound calls (any confidence). rows = kuzu_graph._rows( # noqa: SLF001 """ MATCH (m:Symbol)-[c:CALLS]->(other:Symbol) @@ -468,7 +464,6 @@ def test_trace_edge_filter_calls(kuzu_graph: KuzuGraph) -> None: from mcp_v2 import EdgeFilter - # Without filter. unfiltered = trace_v2( ids=seed_id, direction="out", @@ -478,7 +473,6 @@ def test_trace_edge_filter_calls(kuzu_graph: KuzuGraph) -> None: ) assert unfiltered.success is True - # With min_confidence filter. filtered = trace_v2( ids=seed_id, direction="out", @@ -488,13 +482,11 @@ def test_trace_edge_filter_calls(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert filtered.success is True - # Filtered should have fewer or equal edges. - assert len(filtered.edges) <= len(unfiltered.edges) + assert len(filtered.nodes) <= len(unfiltered.nodes) def test_trace_include_unresolved(kuzu_graph: KuzuGraph) -> None: - """UnresolvedCallSite edges are interleaved when include_unresolved=True, edge_types=['CALLS'], direction='out'.""" - # Find a method with unresolved call sites. + """UnresolvedCallSite edges are interleaved when include_unresolved=True.""" rows = kuzu_graph._rows( # noqa: SLF001 """ MATCH (m:Symbol)-[:UNRESOLVED_AT]->(:UnresolvedCallSite) @@ -506,7 +498,6 @@ def test_trace_include_unresolved(kuzu_graph: KuzuGraph) -> None: pytest.skip("No unresolved call sites in fixture") seed_id = str(rows[0]["id"]) - # Without include_unresolved. 
without = trace_v2( ids=seed_id, direction="out", @@ -517,7 +508,6 @@ def test_trace_include_unresolved(kuzu_graph: KuzuGraph) -> None: ) assert without.success is True - # With include_unresolved=True. with_unresolved = trace_v2( ids=seed_id, direction="out", @@ -527,12 +517,11 @@ def test_trace_include_unresolved(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert with_unresolved.success is True - # Unresolved version should have >= edges than non-unresolved. - assert len(with_unresolved.edges) >= len(without.edges) + assert len(with_unresolved.nodes) >= len(without.nodes) def test_trace_paths_root_to_leaf(kuzu_graph: KuzuGraph) -> None: - """Each path starts at a seed and ends at a leaf with no further outbound edges in the result.""" + """Each ranked_leaf has a tree path from seed.""" seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=3) if seed_id is None: pytest.skip("No method with multiple callees in fixture") @@ -546,22 +535,16 @@ def test_trace_paths_root_to_leaf(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True - for path in out.paths: - if not path.edges: - continue - # First edge starts at seed. - assert path.edges[0].from_id in out.seed_ids - # Last edge's target is the leaf. - leaf_id = path.edges[-1].to_id - assert path.leaf.id == leaf_id - # In the result set, leaves might not have outgoing edges. - # (They might in the graph, but not in the pruned result.) - # This is a soft assertion because the result might be limited. + for leaf in out.ranked_leaves: + # Leaf node should be in the tree walk. + found = _find_tree_node_by_id(out.tree, leaf.node_id) + assert found is not None, f"Leaf {leaf.node_id} not found in tree" + # Leaf should have no children. 
+ assert len(found.children) == 0 def test_trace_overrides_interface_resolution(kuzu_graph: KuzuGraph) -> None: """Traces from interface method via OVERRIDES out, reaches implementation method.""" - # Find a type Symbol (class/interface) with OVERRIDES relationships. rows = kuzu_graph._rows( # noqa: SLF001 """ MATCH (iface:Symbol)-[:DECLARES]->(m:Symbol)<-[:OVERRIDES]-(impl:Symbol) @@ -582,86 +565,9 @@ def test_trace_overrides_interface_resolution(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert out.success is True - # Should have at least one DECLARES or OVERRIDES edge. - assert any(e.edge_type in ("DECLARES", "OVERRIDES") for e in out.edges) - - -def test_trace_parent_edge_id_seed_null(kuzu_graph: KuzuGraph) -> None: - """Seed edges (hop 0) have parent_edge_id: null.""" - seed_id = _find_method_with_outbound_calls(kuzu_graph) - if seed_id is None: - pytest.skip("No method with outbound calls in fixture") - out = trace_v2( - ids=seed_id, - direction="out", - edge_types=["CALLS"], - max_depth=1, - graph=kuzu_graph, - ) - assert out.success is True - for e in out.edges: - if e.hop == 0: - assert e.parent_edge_id is None - - -def test_trace_parent_edge_id_chain(kuzu_graph: KuzuGraph) -> None: - """Non-seed edges have parent_edge_id pointing to a valid edge in the result.""" - seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) - if seed_id is None: - pytest.skip("No method with multiple callees in fixture") - out = trace_v2( - ids=seed_id, - direction="out", - edge_types=["CALLS"], - max_depth=2, - graph=kuzu_graph, - ) - assert out.success is True - - for e in out.edges: - if e.hop > 0: - # parent_edge_id should be a valid edge identifier that matches an edge in the result. 
- if e.parent_edge_id: - # Parse the edge_id format: from_id:to_id:edge_type:hop - parts = e.parent_edge_id.split(":") - assert len(parts) == 4, f"Invalid parent_edge_id format: {e.parent_edge_id}" - parent_from_id, parent_to_id, parent_edge_type, parent_hop = parts - # Verify parent edge exists in result and parent.to_id == e.from_id - parent_exists = any( - p.from_id == parent_from_id - and p.to_id == parent_to_id - and p.edge_type == parent_edge_type - and p.hop == int(parent_hop) - for p in out.edges - ) - assert parent_exists, f"Parent edge {e.parent_edge_id} not found in result" - # Verify the parent edge reaches the current node's from_id - assert parent_to_id == e.from_id, f"Parent edge {e.parent_edge_id} to_id != {e.from_id}" - - -# --------------------------------------------------------------------------- -# PR-TRACE-1b tests: pruning, collapsing, cross-service -# --------------------------------------------------------------------------- - - -def _find_method_with_declares_client(kuzu_graph: KuzuGraph) -> str | None: - """Find a method that has a DECLARES_CLIENT edge.""" - rows = kuzu_graph._rows( # noqa: SLF001 - "MATCH (m:Symbol)-[:DECLARES_CLIENT]->(c:Client) RETURN m.id AS id LIMIT 1" - ) - if rows: - return str(rows[0]["id"]) - return None - - -def _find_method_with_declares_producer(kuzu_graph: KuzuGraph) -> str | None: - """Find a method that has a DECLARES_PRODUCER edge.""" - rows = kuzu_graph._rows( # noqa: SLF001 - "MATCH (m:Symbol)-[:DECLARES_PRODUCER]->(p:Producer) RETURN m.id AS id LIMIT 1" - ) - if rows: - return str(rows[0]["id"]) - return None + # Should have at least one child edge via DECLARES or OVERRIDES. 
+ all_nodes = _walk_tree(out.tree) + assert any(n.edge_from_parent is not None and n.edge_from_parent.edge_type in ("DECLARES", "OVERRIDES") for n in all_nodes) def test_trace_prune_roles(kuzu_graph: KuzuGraph) -> None: @@ -670,7 +576,6 @@ def test_trace_prune_roles(kuzu_graph: KuzuGraph) -> None: if seed_id is None: pytest.skip("No method with outbound calls in fixture") - # Discover what roles exist at depth 2. full = trace_v2( ids=seed_id, direction="out", @@ -684,7 +589,6 @@ def test_trace_prune_roles(kuzu_graph: KuzuGraph) -> None: if len(roles_present) < 2: pytest.skip("Need at least 2 roles to test pruning") - # Pick a role to prune. prune_target = sorted(roles_present)[-1] pruned = trace_v2( @@ -697,10 +601,16 @@ def test_trace_prune_roles(kuzu_graph: KuzuGraph) -> None: ) assert pruned.success is True assert pruned.stats.nodes_pruned_role >= 0 - - # Pruned result should have fewer or equal nodes (frontier stops at pruned nodes). + # Pruned result should have fewer or equal nodes. assert len(pruned.nodes) <= len(full.nodes) + # Pruned-role nodes should be leaves in the tree (no children). + for node in _walk_tree(pruned.tree): + if node.id != seed_id: + node_ref = pruned.nodes.get(node.id) + if node_ref and node_ref.role == prune_target: + assert len(node.children) == 0, f"Pruned node {node.id} should have no children" + def test_trace_fan_out_cap(kuzu_graph: KuzuGraph) -> None: """With fan_out_cap, a node with many outbound edges returns at most cap edges from that node.""" @@ -719,9 +629,10 @@ def test_trace_fan_out_cap(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True - # Count edges from the seed node. - seed_edges = [e for e in out.edges if e.from_id == seed_id] - assert len(seed_edges) <= cap + # Seed's children count should be at most cap. 
+ seed_tree_node = _find_tree_node_by_id(out.tree, seed_id) + if seed_tree_node: + assert len(seed_tree_node.children) <= cap assert out.stats.nodes_pruned_fan_out >= 0 @@ -731,7 +642,6 @@ def test_trace_fan_out_cap_scaffolding_exempt(kuzu_graph: KuzuGraph) -> None: if seed_id is None: pytest.skip("No method with DECLARES_CLIENT in fixture") - # Use very tight cap — scaffolding should still appear. out = trace_v2( ids=seed_id, direction="out", @@ -741,13 +651,14 @@ def test_trace_fan_out_cap_scaffolding_exempt(kuzu_graph: KuzuGraph) -> None: graph=kuzu_graph, ) assert out.success is True - # Should have DECLARES_CLIENT edges even with cap=1. - scaffolding_edges = [e for e in out.edges if e.edge_type in ("DECLARES_CLIENT", "DECLARES_PRODUCER")] - assert len(scaffolding_edges) >= 1 + # Should have scaffolding edges even with cap=1. + all_nodes = _walk_tree(out.tree) + scaffolding = [n for n in all_nodes if n.edge_from_parent and n.edge_from_parent.edge_type in ("DECLARES_CLIENT", "DECLARES_PRODUCER")] + assert len(scaffolding) >= 1 def test_trace_collapse_trivial(kuzu_graph: KuzuGraph) -> None: - """Wrapper chain A→B→C where B has degree 2 is collapsed to A→C with collapsed=True.""" + """Wrapper chain A→B→C where B is trivial is collapsed; intermediates retained in nodes.""" seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) if seed_id is None: pytest.skip("No method with multiple callees in fixture") @@ -758,23 +669,23 @@ def test_trace_collapse_trivial(kuzu_graph: KuzuGraph) -> None: edge_types=["CALLS"], max_depth=3, collapse_trivial=True, - fan_out_cap=0, # No fan-out cap + fan_out_cap=0, graph=kuzu_graph, ) assert out.success is True - # If any collapsing happened, verify the markers. 
- collapsed_edges = [e for e in out.edges if e.collapsed] - if collapsed_edges: - for ce in collapsed_edges: - assert ce.collapsed is True - assert len(ce.collapsed_intermediates) > 0 - assert out.stats.edges_collapsed_trivial == len(collapsed_edges) + # Check collapsed nodes in tree. + collapsed_nodes = [n for n in _walk_tree(out.tree) if n.collapsed] + if collapsed_nodes: + for cn in collapsed_nodes: + assert cn.collapsed is True + assert len(cn.collapsed_intermediates) > 0 + assert out.stats.edges_collapsed_trivial == len(collapsed_nodes) - # Collapsed intermediates should NOT be in nodes dict. - for ce in collapsed_edges: - for inter_id in ce.collapsed_intermediates: - assert inter_id not in out.nodes + # Collapsed intermediates ARE in nodes dict (v2). + for cn in collapsed_nodes: + for inter_id in cn.collapsed_intermediates: + assert inter_id in out.nodes, f"Collapsed intermediate {inter_id} should be in nodes dict" def test_trace_collapse_trivial_disabled(kuzu_graph: KuzuGraph) -> None: @@ -794,38 +705,11 @@ def test_trace_collapse_trivial_disabled(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True assert out.stats.edges_collapsed_trivial == 0 - assert not any(e.collapsed for e in out.edges) - - -def test_trace_collapse_parent_edge_id_consistency(kuzu_graph: KuzuGraph) -> None: - """After collapsing A→B→C to A→C, child edges referencing B→C now reference collapsed A→C edge.""" - seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) - if seed_id is None: - pytest.skip("No method with multiple callees in fixture") - - out = trace_v2( - ids=seed_id, - direction="out", - edge_types=["CALLS"], - max_depth=3, - collapse_trivial=True, - fan_out_cap=0, - graph=kuzu_graph, - ) - assert out.success is True - - # Verify parent_edge_id consistency: every non-null parent_edge_id - # references an edge that exists in the result. 
- edge_ids = {f"{e.from_id}:{e.to_id}:{e.edge_type}:{e.hop}" for e in out.edges} - for e in out.edges: - if e.parent_edge_id: - assert e.parent_edge_id in edge_ids, ( - f"parent_edge_id {e.parent_edge_id} not in result edge_ids" - ) + assert not any(n.collapsed for n in _walk_tree(out.tree)) def test_trace_cross_service_http(kuzu_graph: KuzuGraph) -> None: - """Traces through DECLARES_CLIENT → HTTP_CALLS; stops at Route boundary with cross_service_boundary=True.""" + """Traces through DECLARES_CLIENT → HTTP_CALLS; stops at Route boundary.""" seed_id = _find_method_with_declares_client(kuzu_graph) if seed_id is None: pytest.skip("No method with DECLARES_CLIENT in fixture") @@ -840,13 +724,12 @@ def test_trace_cross_service_http(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True - # Should have cross-service boundary edges. - xs_edges = [e for e in out.edges if e.cross_service_boundary] - if xs_edges: - for xe in xs_edges: - assert xe.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") - # Downstream target should be in nodes dict. - assert xe.to_id in out.nodes + # Walk tree for cross_service_boundary nodes. + xs_nodes = [n for n in _walk_tree(out.tree) if n.edge_from_parent and n.edge_from_parent.cross_service_boundary] + if xs_nodes: + for xn in xs_nodes: + assert xn.edge_from_parent.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") + assert xn.id in out.nodes def test_trace_cross_service_async(kuzu_graph: KuzuGraph) -> None: @@ -865,11 +748,10 @@ def test_trace_cross_service_async(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True - # Should have cross-service boundary edges or at least scaffolding. 
- xs_edges = [e for e in out.edges if e.cross_service_boundary] - if xs_edges: - for xe in xs_edges: - assert xe.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") + xs_nodes = [n for n in _walk_tree(out.tree) if n.edge_from_parent and n.edge_from_parent.cross_service_boundary] + if xs_nodes: + for xn in xs_nodes: + assert xn.edge_from_parent.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") def test_trace_cross_service_edge_attrs(kuzu_graph: KuzuGraph) -> None: @@ -888,15 +770,14 @@ def test_trace_cross_service_edge_attrs(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True - xs_edges = [e for e in out.edges if e.cross_service_boundary] - for xe in xs_edges: - assert xe.cross_service_boundary is True - # Cross-service edges should carry key attributes from the graph edge. - assert any(k in xe.attrs for k in ("confidence", "strategy", "match")) + xs_nodes = [n for n in _walk_tree(out.tree) if n.edge_from_parent and n.edge_from_parent.cross_service_boundary] + for xn in xs_nodes: + assert xn.edge_from_parent.cross_service_boundary is True + assert any(k in xn.edge_from_parent.attrs for k in ("confidence", "strategy", "match")) def test_trace_cross_service_boundary_stops(kuzu_graph: KuzuGraph) -> None: - """BFS does not follow past cross-service boundary; downstream Route in nodes but no further edges from it.""" + """BFS does not follow past cross-service boundary; downstream Route in nodes but no children in tree.""" seed_id = _find_method_with_declares_client(kuzu_graph) if seed_id is None: pytest.skip("No method with DECLARES_CLIENT in fixture") @@ -911,16 +792,14 @@ def test_trace_cross_service_boundary_stops(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True - xs_edges = [e for e in out.edges if e.cross_service_boundary] - if not xs_edges: - pytest.skip("No cross-service edges in result") + xs_nodes = [n for n in _walk_tree(out.tree) if n.edge_from_parent and n.edge_from_parent.cross_service_boundary] + if not xs_nodes: + pytest.skip("No cross-service nodes 
in result") - for xe in xs_edges: - # Downstream node is in nodes dict. - assert xe.to_id in out.nodes - # No edges FROM the downstream node (frontier stops at boundary). - downstream_edges = [e for e in out.edges if e.from_id == xe.to_id] - assert len(downstream_edges) == 0 + for xn in xs_nodes: + assert xn.id in out.nodes + # Boundary node should have no children (frontier stops). + assert len(xn.children) == 0 def test_trace_cross_service_seamless_http(kuzu_graph: KuzuGraph) -> None: @@ -940,26 +819,19 @@ def test_trace_cross_service_seamless_http(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True - # Should have cross-service boundary edges. - xs_edges = [e for e in out.edges if e.cross_service_boundary] - if not xs_edges: - pytest.skip("No cross-service edges in result") + xs_nodes = [n for n in _walk_tree(out.tree) if n.edge_from_parent and n.edge_from_parent.cross_service_boundary] + if not xs_nodes: + pytest.skip("No cross-service nodes in result") - for xe in xs_edges: - assert xe.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") - # Downstream node should be in nodes. - assert xe.to_id in out.nodes + for xn in xs_nodes: + assert xn.edge_from_parent.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") + assert xn.id in out.nodes - # Key difference from boundary-stop: downstream Route should have edges FROM it - # (EXPOSES to handler, then CALLS from handler) because BFS continued through. - for xe in xs_edges: - downstream_edges = [e for e in out.edges if e.from_id == xe.to_id] - # At least one edge (EXPOSES to handler) should exist from the downstream Route. - if downstream_edges: - exposes_edges = [e for e in downstream_edges if e.edge_type == "EXPOSES"] - assert len(exposes_edges) >= 1, ( - f"Expected EXPOSES edges from {xe.to_id}, got: {[e.edge_type for e in downstream_edges]}" - ) + # Key difference from boundary-stop: downstream Route should have children. 
+ for xn in xs_nodes: + if len(xn.children) > 0: + exposes_children = [c for c in xn.children if c.edge_from_parent and c.edge_from_parent.edge_type == "EXPOSES"] + assert len(exposes_children) >= 1 def test_trace_cross_service_seamless_async(kuzu_graph: KuzuGraph) -> None: @@ -979,13 +851,13 @@ def test_trace_cross_service_seamless_async(kuzu_graph: KuzuGraph) -> None: ) assert out.success is True - xs_edges = [e for e in out.edges if e.cross_service_boundary] - if not xs_edges: - pytest.skip("No cross-service edges in result") + xs_nodes = [n for n in _walk_tree(out.tree) if n.edge_from_parent and n.edge_from_parent.cross_service_boundary] + if not xs_nodes: + pytest.skip("No cross-service nodes in result") - for xe in xs_edges: - assert xe.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") - assert xe.to_id in out.nodes + for xn in xs_nodes: + assert xn.edge_from_parent.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") + assert xn.id in out.nodes def test_trace_cross_service_seamless_respects_budget(kuzu_graph: KuzuGraph) -> None: @@ -1005,8 +877,6 @@ def test_trace_cross_service_seamless_respects_budget(kuzu_graph: KuzuGraph) -> graph=kuzu_graph, ) assert out.success is True - # Budget may or may not have been hit depending on graph size, - # but if it was, the stats should reflect it. if out.stats.budget_hit: assert out.stats.total_nodes_discovered >= 100 @@ -1017,7 +887,6 @@ def test_trace_cross_service_seamless_exposes_as_scaffolding(kuzu_graph: KuzuGra if seed_id is None: pytest.skip("No method with DECLARES_CLIENT in fixture") - # Use fan_out_cap=1 — very tight, but EXPOSES should still come through. 
out = trace_v2( ids=seed_id, direction="out", @@ -1029,17 +898,12 @@ def test_trace_cross_service_seamless_exposes_as_scaffolding(kuzu_graph: KuzuGra ) assert out.success is True - xs_edges = [e for e in out.edges if e.cross_service_boundary] - if xs_edges: - # Even with fan_out_cap=1, EXPOSES edges from downstream Routes should appear - # (they're scaffolding, exempt from cap). - for xe in xs_edges: - downstream_edges = [e for e in out.edges if e.from_id == xe.to_id] - if downstream_edges: - exposes_edges = [e for e in downstream_edges if e.edge_type == "EXPOSES"] - assert len(exposes_edges) >= 1, ( - f"EXPOSES should be exempt from fan_out_cap, but got: {[e.edge_type for e in downstream_edges]}" - ) + xs_nodes = [n for n in _walk_tree(out.tree) if n.edge_from_parent and n.edge_from_parent.cross_service_boundary] + if xs_nodes: + for xn in xs_nodes: + if len(xn.children) > 0: + exposes_children = [c for c in xn.children if c.edge_from_parent and c.edge_from_parent.edge_type == "EXPOSES"] + assert len(exposes_children) >= 1 async def test_trace_registered_as_mcp_tool(mcp_server) -> None: @@ -1056,7 +920,6 @@ async def test_trace_tool_description_mentions_six_tools(mcp_server) -> None: instructions = server._INSTRUCTIONS assert "trace" in instructions assert instructions.count("trace") >= 1 - # Six tools mentioned: search, find, describe, neighbors, trace, resolve. 
assert "search" in instructions assert "find" in instructions assert "describe" in instructions @@ -1065,13 +928,36 @@ async def test_trace_tool_description_mentions_six_tools(mcp_server) -> None: # --------------------------------------------------------------------------- -# PR-TRACE-3 tests: cross-service integration + hint verification +# PR-TRACE-V2 tests: tree format, configurable collapse, source-relative +# ranking, bidirectional traversal, min_result_nodes retry # --------------------------------------------------------------------------- +def _find_method_with_declares_client(kuzu_graph: KuzuGraph) -> str | None: + """Find a method that has a DECLARES_CLIENT edge.""" + rows = kuzu_graph._rows( # noqa: SLF001 + "MATCH (m:Symbol)-[:DECLARES_CLIENT]->(c:Client) RETURN m.id AS id LIMIT 1" + ) + if rows: + return str(rows[0]["id"]) + return None + + +def _find_method_with_declares_producer(kuzu_graph: KuzuGraph) -> str | None: + """Find a method that has a DECLARES_PRODUCER edge.""" + rows = kuzu_graph._rows( # noqa: SLF001 + "MATCH (m:Symbol)-[:DECLARES_PRODUCER]->(p:Producer) RETURN m.id AS id LIMIT 1" + ) + if rows: + return str(rows[0]["id"]) + return None + + +# --- Updated v1 tests (assert on tree / ranked_leaves) --- + + def test_trace_bank_chat_cross_service_http_flow(kuzu_graph: KuzuGraph) -> None: """Integration: trace from a bank-chat method through HTTP_CALLS; verify cross-service boundary + hints.""" - # Find a method that declares a client (has cross-service path). seed_id = _find_method_with_declares_client(kuzu_graph) if seed_id is None: pytest.skip("No method with DECLARES_CLIENT in fixture") @@ -1081,36 +967,560 @@ def test_trace_bank_chat_cross_service_http_flow(kuzu_graph: KuzuGraph) -> None: direction="out", edge_types=["CALLS", "HTTP_CALLS"], max_depth=4, - fan_out_cap=0, # No cap — include all edges for cross-service verification. 
+ fan_out_cap=0, graph=kuzu_graph, ) assert out.success is True - # Verify cross-service boundary edges exist. - xs_edges = [e for e in out.edges if e.cross_service_boundary] - if xs_edges: - # Cross-service boundary edges should stop at the route. - for xe in xs_edges: - assert xe.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") - assert xe.to_id in out.nodes - # No further edges from the downstream node. - downstream_edges = [e for e in out.edges if e.from_id == xe.to_id] - assert len(downstream_edges) == 0 - - # Verify hint generation works on the trace output. + # Verify cross-service boundary nodes in tree. + xs_nodes = [n for n in _walk_tree(out.tree) if n.edge_from_parent and n.edge_from_parent.cross_service_boundary] + if xs_nodes: + for xn in xs_nodes: + assert xn.edge_from_parent.edge_type in ("HTTP_CALLS", "ASYNC_CALLS") + assert xn.id in out.nodes + # No children (boundary stops without cross_service=True). + assert len(xn.children) == 0 + + # Verify hint generation works on the trace output (tree format). from mcp_hints import generate_hints trace_payload = { "success": out.success, "stats": out.stats.model_dump(), - "edges": [e.model_dump() for e in out.edges], + "tree": [n.model_dump() for n in out.tree], "nodes": {nid: n.model_dump() for nid, n in out.nodes.items()}, "seed_ids": out.seed_ids, "direction": out.direction, "edge_types": out.edge_types, } struct, advisories = generate_hints("trace", trace_payload) - # Cross-service boundary hints should fire if xs_edges exist. 
- if xs_edges: + if xs_nodes: assert any("cross-service" in a.lower() for a in advisories), ( f"expected cross-service advisory, got: {advisories}" ) + + +# --- New v2 tests --- + + +def test_trace_tree_root_is_seed(kuzu_graph: KuzuGraph) -> None: + """Tree root node matches seed ID.""" + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + assert len(out.tree) >= 1 + assert out.tree[0].id == seed_id + + +def test_trace_tree_seed_no_edge_from_parent(kuzu_graph: KuzuGraph) -> None: + """Seed nodes have edge_from_parent=None.""" + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=1, + graph=kuzu_graph, + ) + assert out.success is True + for root in out.tree: + assert root.edge_from_parent is None + + +def test_trace_tree_edge_from_parent_chain(kuzu_graph: KuzuGraph) -> None: + """Non-root nodes have edge_from_parent with valid edge_type, hop, and direction.""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + + all_nodes = _walk_tree(out.tree) + for node in all_nodes: + if node.edge_from_parent is not None: + assert node.edge_from_parent.edge_type in {"CALLS", "DECLARES_CLIENT", "DECLARES_PRODUCER", "EXPOSES", "HTTP_CALLS", "ASYNC_CALLS"} + assert node.edge_from_parent.hop >= 0 + assert node.edge_from_parent.direction in ("in", "out") + + +def test_trace_tree_edge_from_parent_direction(kuzu_graph: KuzuGraph) -> None: + 
"""edge_from_parent.direction is set ('in' or 'out') for all non-root nodes.""" + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + for node in _walk_tree(out.tree): + if node.edge_from_parent is not None: + assert node.edge_from_parent.direction == "out" + + +def test_trace_tree_seed_with_zero_edges(kuzu_graph: KuzuGraph) -> None: + """Seed with zero matching edges produces tree=[TreeNode(id=seed)] with no children.""" + # Use a method that likely has no IMPLEMENTS edges (most methods don't). + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["IMPLEMENTS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + assert len(out.tree) >= 1 + assert out.tree[0].id == seed_id + # No IMPLEMENTS edges from a method — tree should have seed with no children. + assert len(out.tree[0].children) == 0 + + +def test_trace_tree_children_nested(kuzu_graph: KuzuGraph) -> None: + """Children are nested TreeNodes, not flat.""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + # Children should be TreeNode instances. 
+ if out.tree and out.tree[0].children: + child = out.tree[0].children[0] + assert isinstance(child, TreeNode) + assert hasattr(child, "children") + assert hasattr(child, "edge_from_parent") + + +def test_trace_tree_collapsed_node(kuzu_graph: KuzuGraph) -> None: + """Collapsed intermediates carry collapsed=True and collapsed_intermediates.""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=3, + collapse_trivial=True, + fan_out_cap=0, + graph=kuzu_graph, + ) + assert out.success is True + collapsed = [n for n in _walk_tree(out.tree) if n.collapsed] + if collapsed: + for cn in collapsed: + assert cn.collapsed is True + assert len(cn.collapsed_intermediates) > 0 + + +def test_trace_tree_collapse_intermediates_in_nodes(kuzu_graph: KuzuGraph) -> None: + """Collapsed intermediate node IDs exist in nodes dict (v2).""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=3, + collapse_trivial=True, + fan_out_cap=0, + graph=kuzu_graph, + ) + assert out.success is True + collapsed = [n for n in _walk_tree(out.tree) if n.collapsed] + for cn in collapsed: + for inter_id in cn.collapsed_intermediates: + assert inter_id in out.nodes, f"Collapsed intermediate {inter_id} should be in nodes dict" + + +def test_trace_tree_collapse_children_reparented(kuzu_graph: KuzuGraph) -> None: + """After collapsing A→B→C, C appears as child of A in tree.""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=3, + 
collapse_trivial=True, + fan_out_cap=0, + graph=kuzu_graph, + ) + assert out.success is True + # If collapsing happened, verify the tree structure is consistent. + collapsed = [n for n in _walk_tree(out.tree) if n.collapsed] + if collapsed: + for cn in collapsed: + # Collapsed node's parent should not be in collapsed_intermediates. + parent = None + for node in _walk_tree(out.tree): + if cn in node.children: + parent = node + break + if parent: + assert parent.id not in cn.collapsed_intermediates + + +def test_trace_ranked_leaves_capped(kuzu_graph: KuzuGraph) -> None: + """ranked_leaves does not exceed max_paths.""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=5) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=3, + max_paths=3, + graph=kuzu_graph, + ) + assert out.success is True + assert len(out.ranked_leaves) <= 3 + + +def test_trace_ranked_leaves_scores(kuzu_graph: KuzuGraph) -> None: + """Leaves are sorted by descending score.""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=3) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + if len(out.ranked_leaves) > 1: + for i in range(len(out.ranked_leaves) - 1): + assert out.ranked_leaves[i].score >= out.ranked_leaves[i + 1].score + + +def test_trace_collapse_roles_custom(kuzu_graph: KuzuGraph) -> None: + """collapse_roles=['OTHER','SERVICE'] collapses SERVICE intermediates.""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=3, + collapse_trivial=True, + collapse_roles=["OTHER", 
"SERVICE"], + fan_out_cap=0, + graph=kuzu_graph, + ) + assert out.success is True + # With wider collapse roles, we should get at least as much collapsing as default. + assert out.stats.edges_collapsed_trivial >= 0 + + +def test_trace_collapse_roles_default(kuzu_graph: KuzuGraph) -> None: + """Default collapse_roles only collapses OTHER.""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + + # Default (collapse_roles=None → defaults to OTHER). + default_out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=3, + collapse_trivial=True, + fan_out_cap=0, + graph=kuzu_graph, + ) + assert default_out.success is True + + # Check that collapsed intermediates are all OTHER or None role. + collapsed = [n for n in _walk_tree(default_out.tree) if n.collapsed] + for cn in collapsed: + for inter_id in cn.collapsed_intermediates: + inter_ref = default_out.nodes.get(inter_id) + if inter_ref: + assert inter_ref.role in ("OTHER", None), ( + f"Default collapse should only collapse OTHER/None, got {inter_ref.role}" + ) + + +def test_trace_collapse_min_chain_length_2(kuzu_graph: KuzuGraph) -> None: + """collapse_min_chain_length=2 skips single-intermediate collapses.""" + seed_id = _find_method_with_multiple_callees(kuzu_graph, min_callees=2) + if seed_id is None: + pytest.skip("No method with multiple callees in fixture") + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=3, + collapse_trivial=True, + collapse_min_chain_length=2, + fan_out_cap=0, + graph=kuzu_graph, + ) + assert out.success is True + # With min_chain_length=2, should collapse fewer chains than default (1). + # The collapsed count should be <= the default count. 
+ default_out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=3, + collapse_trivial=True, + fan_out_cap=0, + graph=kuzu_graph, + ) + assert default_out.success is True + assert out.stats.edges_collapsed_trivial <= default_out.stats.edges_collapsed_trivial + + +def test_trace_fan_out_source_relative_service(kuzu_graph: KuzuGraph) -> None: + """From SERVICE node, REPOSITORY callee outranks CONTROLLER at equal confidence.""" + # Find a SERVICE method with outbound calls to both REPOSITORY and non-REPOSITORY. + rows = kuzu_graph._rows( # noqa: SLF001 + """ + MATCH (m:Symbol)-[:CALLS]->(other:Symbol) + WHERE m.role = 'SERVICE' + WITH m, collect(DISTINCT other.role) AS roles + WHERE size(roles) >= 2 + RETURN m.id AS id + LIMIT 1 + """ + ) + if not rows: + pytest.skip("No SERVICE method with diverse callees in fixture") + seed_id = str(rows[0]["id"]) + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=1, + fan_out_cap=10, + graph=kuzu_graph, + ) + assert out.success is True + # Verify results returned without error (source-relative ranking is internal). + assert len(out.tree[0].children) > 0 + + +def test_trace_fan_out_source_relative_controller(kuzu_graph: KuzuGraph) -> None: + """From CONTROLLER node, SERVICE callee outranks REPOSITORY at equal confidence.""" + rows = kuzu_graph._rows( # noqa: SLF001 + """ + MATCH (m:Symbol)-[:CALLS]->(other:Symbol) + WHERE m.role = 'CONTROLLER' + RETURN m.id AS id + LIMIT 1 + """ + ) + if not rows: + pytest.skip("No CONTROLLER method with callees in fixture") + seed_id = str(rows[0]["id"]) + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=1, + fan_out_cap=10, + graph=kuzu_graph, + ) + assert out.success is True + + +def test_trace_fan_out_source_relative_fallback(kuzu_graph: KuzuGraph) -> None: + """Unknown source role falls back to static priority.""" + # Find a method with OTHER or unknown role. 
+ rows = kuzu_graph._rows( # noqa: SLF001 + """ + MATCH (m:Symbol)-[:CALLS]->(other:Symbol) + WHERE m.role = 'OTHER' OR m.role IS NULL + RETURN m.id AS id + LIMIT 1 + """ + ) + if not rows: + pytest.skip("No OTHER method with callees in fixture") + seed_id = str(rows[0]["id"]) + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=1, + fan_out_cap=10, + graph=kuzu_graph, + ) + assert out.success is True + + +def test_trace_min_result_nodes_retry(kuzu_graph: KuzuGraph) -> None: + """min_result_nodes=10 triggers fan-out cap retry when initial result < 10.""" + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=2, + min_result_nodes=10, + fan_out_cap=1, + graph=kuzu_graph, + ) + assert out.success is True + # Either it got enough nodes or there's an advisory about retry. + if len(out.nodes) < 10: + assert any("min_result_nodes" in adv for adv in out.advisories) + + +def test_trace_min_result_nodes_disabled(kuzu_graph: KuzuGraph) -> None: + """min_result_nodes=0 (default) does not retry.""" + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + + out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=2, + min_result_nodes=0, + fan_out_cap=1, + graph=kuzu_graph, + ) + assert out.success is True + assert not any("min_result_nodes" in adv for adv in out.advisories) + + +def test_trace_bidirectional_basic(kuzu_graph: KuzuGraph) -> None: + """direction='both' returns tree with both in and out children from seed.""" + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + # Also need inbound calls. 
+ rows = kuzu_graph._rows( # noqa: SLF001 + "MATCH (caller:Symbol)-[:CALLS]->(m:Symbol {id: $id}) RETURN m.id AS id LIMIT 1", + {"id": seed_id}, + ) + if not rows: + pytest.skip("Seed method has no inbound calls") + + out = trace_v2( + ids=seed_id, + direction="both", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + assert out.direction == "both" + # Tree should have children with both directions. + directions_found = set() + for node in _walk_tree(out.tree): + if node.edge_from_parent is not None: + directions_found.add(node.edge_from_parent.direction) + # Should have at least "out" direction (from the outbound calls). + assert "out" in directions_found or "in" in directions_found + + +def test_trace_bidirectional_shared_visited(kuzu_graph: KuzuGraph) -> None: + """Nodes discovered in 'out' are not re-visited in 'in'.""" + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + + out = trace_v2( + ids=seed_id, + direction="both", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + # No duplicate node IDs in tree walk. + all_nodes = _walk_tree(out.tree) + node_ids = [n.id for n in all_nodes] + assert len(node_ids) == len(set(node_ids)) + + +def test_trace_bidirectional_stats_aggregated(kuzu_graph: KuzuGraph) -> None: + """Stats aggregate both directions.""" + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + + out = trace_v2( + ids=seed_id, + direction="both", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + assert out.stats.nodes_after_pruning == len(out.nodes) + + # Compare with unidirectional: both should discover more or equal nodes. 
+ out_out = trace_v2( + ids=seed_id, + direction="out", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert len(out.nodes) >= len(out_out.nodes) + + +def test_trace_bidirectional_ranked_leaves_merged(kuzu_graph: KuzuGraph) -> None: + """ranked_leaves includes leaves from both directions.""" + seed_id = _find_method_with_outbound_calls(kuzu_graph) + if seed_id is None: + pytest.skip("No method with outbound calls in fixture") + + out = trace_v2( + ids=seed_id, + direction="both", + edge_types=["CALLS"], + max_depth=2, + graph=kuzu_graph, + ) + assert out.success is True + # Should have ranked leaves. + assert len(out.ranked_leaves) >= 0