Skip to content

Commit 1d473a4

Browse files
committed
Add MCP roots support
When the client advertises the roots capability during initialize, the server reciprocates with roots.listChanged in its capabilities, fires a roots/list request once notifications/initialized arrives, and re-fetches on every notifications/roots/list_changed. The first file:// URI in the response becomes the FileSystemProvider root, so one MCP server can follow the user across projects without a restart. ResourceProvider gains a no-op set_workspace_root hook; ChainProvider fans out, FileSystemProvider re-targets.
1 parent d229032 commit 1d473a4

3 files changed

Lines changed: 164 additions & 8 deletions

File tree

je_auto_control/utils/mcp_server/resources.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def read(self, uri: str) -> Optional[Dict[str, Any]]: # pragma: no cover - abst
4141
"""Return one content block (``{uri, mimeType, text}``) or ``None``."""
4242
raise NotImplementedError
4343

44+
def set_workspace_root(self, root: str) -> None:
45+
"""Hook for MCP roots. Default: no-op. FS-backed providers override."""
46+
del root
47+
4448

4549
class FileSystemProvider(ResourceProvider):
4650
"""Expose ``*.json`` action files in ``root`` under ``<scheme>://files/<name>``."""
@@ -50,6 +54,10 @@ def __init__(self, root: str = ".",
5054
self.root = os.path.realpath(root)
5155
self.scheme = scheme
5256

57+
def set_workspace_root(self, root: str) -> None:
58+
"""Re-target the provider at a new directory (e.g. via MCP roots)."""
59+
self.root = os.path.realpath(os.fspath(root))
60+
5361
def list(self) -> List[MCPResource]:
5462
if not os.path.isdir(self.root):
5563
return []
@@ -160,6 +168,11 @@ def read(self, uri: str) -> Optional[Dict[str, Any]]:
160168
return content
161169
return None
162170

171+
def set_workspace_root(self, root: str) -> None:
172+
"""Forward the root to every child provider."""
173+
for provider in self.providers:
174+
provider.set_workspace_root(root)
175+
163176

164177
def default_resource_provider(root: str = ".") -> ResourceProvider:
165178
"""Return the resource provider exposed by the default MCP server."""

je_auto_control/utils/mcp_server/server.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ def __init__(self, tools: Optional[List[MCPTool]] = None,
7474
self._calls_lock = threading.Lock()
7575
self._write_lock = threading.Lock()
7676
self._sampling_id_counter = itertools.count(1)
77+
self._outbound_id_counter = itertools.count(1)
7778
self._pending_outbound: Dict[Any, Dict[str, Any]] = {}
7879
self._outbound_lock = threading.Lock()
80+
self._client_capabilities: Dict[str, Any] = {}
7981

8082
def register_tool(self, tool: MCPTool) -> None:
8183
"""Add or replace a tool in the live registry.
@@ -251,12 +253,76 @@ def _handle_notification(self, method: Optional[str],
251253
if method == "notifications/initialized":
252254
self._initialized = True
253255
autocontrol_logger.info("MCP client initialized")
256+
self._maybe_request_roots_async()
254257
return
255258
if method == "notifications/cancelled":
256259
self._cancel_active_call(params)
257260
return
261+
if method == "notifications/roots/list_changed":
262+
self._maybe_request_roots_async()
263+
return
258264
autocontrol_logger.debug("MCP notification ignored: %s", method)
259265

266+
def _maybe_request_roots_async(self) -> None:
267+
"""Fire a roots/list request when the client supports it."""
268+
if "roots" not in self._client_capabilities:
269+
return
270+
if self._writer is None:
271+
return
272+
threading.Thread(
273+
target=self._refresh_roots_safely, daemon=True,
274+
name="MCPRootsRefresh",
275+
).start()
276+
277+
def _refresh_roots_safely(self) -> None:
278+
try:
279+
self.refresh_roots(timeout=5.0)
280+
except (RuntimeError, TimeoutError) as error:
281+
autocontrol_logger.info("MCP roots refresh skipped: %r", error)
282+
283+
def refresh_roots(self, timeout: float = 10.0) -> List[Dict[str, Any]]:
284+
"""Send ``roots/list`` to the client and apply the first root."""
285+
result = self._send_outbound_request(
286+
"roots/list", params={}, timeout=timeout,
287+
)
288+
roots_list = (result or {}).get("roots") or []
289+
if not isinstance(roots_list, list) or not roots_list:
290+
return []
291+
first_uri = roots_list[0].get("uri") if isinstance(roots_list[0],
292+
dict) else None
293+
if isinstance(first_uri, str):
294+
local_path = _file_uri_to_path(first_uri)
295+
if local_path:
296+
self._resources.set_workspace_root(local_path)
297+
autocontrol_logger.info("MCP workspace root → %s", local_path)
298+
return roots_list
299+
300+
def _send_outbound_request(self, method: str,
301+
params: Dict[str, Any],
302+
timeout: float = 10.0) -> Dict[str, Any]:
303+
"""Send a server-initiated request and wait for the response."""
304+
writer = self._writer
305+
if writer is None:
306+
raise RuntimeError(f"{method} requires an outbound writer")
307+
request_id = f"srv-{next(self._outbound_id_counter)}"
308+
slot = {"event": threading.Event()}
309+
with self._outbound_lock:
310+
self._pending_outbound[request_id] = slot
311+
envelope = json.dumps({
312+
"jsonrpc": "2.0", "id": request_id,
313+
"method": method, "params": params,
314+
}, ensure_ascii=False, default=str)
315+
try:
316+
writer(envelope)
317+
if not slot["event"].wait(timeout=timeout):
318+
raise TimeoutError(f"{method} timed out after {timeout}s")
319+
finally:
320+
with self._outbound_lock:
321+
self._pending_outbound.pop(request_id, None)
322+
if "error" in slot:
323+
raise RuntimeError(f"{method} failed: {slot['error']}")
324+
return slot.get("result") or {}
325+
260326
def _cancel_active_call(self, params: Dict[str, Any]) -> None:
261327
"""Mark the matching active tool call as cancelled, if any."""
262328
request_id = params.get("requestId")
@@ -293,17 +359,22 @@ def _dispatch(self, msg_id: Any, method: Optional[str],
293359
return self._handle_prompts_get(params)
294360
raise _MCPError(-32601, f"Method not found: {method}")
295361

296-
@staticmethod
297-
def _handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
362+
def _handle_initialize(self, params: Dict[str, Any]) -> Dict[str, Any]:
298363
client_version = params.get("protocolVersion", PROTOCOL_VERSION)
364+
client_caps = params.get("capabilities") or {}
365+
if isinstance(client_caps, dict):
366+
self._client_capabilities = client_caps
367+
capabilities: Dict[str, Any] = {
368+
"tools": {"listChanged": True},
369+
"resources": {"listChanged": False, "subscribe": False},
370+
"prompts": {"listChanged": False},
371+
"sampling": {},
372+
}
373+
if "roots" in self._client_capabilities:
374+
capabilities["roots"] = {"listChanged": True}
299375
return {
300376
"protocolVersion": client_version or PROTOCOL_VERSION,
301-
"capabilities": {
302-
"tools": {"listChanged": True},
303-
"resources": {"listChanged": False, "subscribe": False},
304-
"prompts": {"listChanged": False},
305-
"sampling": {},
306-
},
377+
"capabilities": capabilities,
307378
"serverInfo": {"name": SERVER_NAME, "version": SERVER_VERSION},
308379
}
309380

@@ -465,6 +536,20 @@ def _stringify_result(value: Any) -> str:
465536
return repr(value)
466537

467538

539+
def _file_uri_to_path(uri: str) -> Optional[str]:
540+
"""Convert a ``file://`` URI to a local filesystem path; ``None`` otherwise."""
541+
if not isinstance(uri, str) or not uri.startswith("file://"):
542+
return None
543+
from urllib.parse import unquote, urlparse
544+
parsed = urlparse(uri)
545+
raw_path = unquote(parsed.path)
546+
# Windows: file:///C:/foo strips the leading slash before the drive letter.
547+
if sys.platform.startswith("win") and raw_path.startswith("/") and \
548+
len(raw_path) > 2 and raw_path[2] == ":":
549+
raw_path = raw_path[1:]
550+
return raw_path or None
551+
552+
468553
def _notification_message(method: str, params: Dict[str, Any]) -> str:
469554
return json.dumps({"jsonrpc": "2.0", "method": method, "params": params},
470555
ensure_ascii=False, default=str)

test/unit_test/headless/test_mcp_server.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77
import io
88
import json
9+
import os
910
import threading
1011
from typing import Any, Dict, List
1112

@@ -1135,6 +1136,63 @@ def test_rate_limiter_zero_rate_means_unlimited():
11351136
assert limiter.try_acquire() is True
11361137

11371138

1139+
def test_initialize_advertises_roots_when_client_supports_it():
1140+
server = MCPServer(tools=[])
1141+
response = _decode(server.handle_line(_request("initialize", params={
1142+
"capabilities": {"roots": {"listChanged": True}},
1143+
})))
1144+
assert "roots" in response["result"]["capabilities"]
1145+
1146+
1147+
def test_initialize_omits_roots_when_client_lacks_capability():
1148+
server = MCPServer(tools=[])
1149+
response = _decode(server.handle_line(_request("initialize", params={
1150+
"capabilities": {},
1151+
})))
1152+
assert "roots" not in response["result"]["capabilities"]
1153+
1154+
1155+
def test_refresh_roots_updates_filesystem_provider(tmp_path):
1156+
from je_auto_control.utils.mcp_server.resources import (
1157+
ChainProvider, FileSystemProvider,
1158+
)
1159+
fs_provider = FileSystemProvider(root=str(tmp_path / "initial"))
1160+
chain = ChainProvider([fs_provider])
1161+
captured_lines = []
1162+
server = MCPServer(tools=[], resource_provider=chain,
1163+
concurrent_tools=True)
1164+
server.set_writer(captured_lines.append)
1165+
# Simulate client capability so refresh is allowed.
1166+
server._client_capabilities = {"roots": {"listChanged": True}}
1167+
1168+
target = tmp_path / "ws"
1169+
target.mkdir()
1170+
1171+
def run_refresh():
1172+
server.refresh_roots(timeout=2.0)
1173+
1174+
t = threading.Thread(target=run_refresh)
1175+
t.start()
1176+
deadline = threading.Event()
1177+
for _ in range(200):
1178+
if any('"roots/list"' in line for line in captured_lines):
1179+
break
1180+
deadline.wait(0.01)
1181+
request_lines = [line for line in captured_lines
1182+
if '"roots/list"' in line]
1183+
assert request_lines, "expected outbound roots/list"
1184+
request_id = json.loads(request_lines[-1])["id"]
1185+
1186+
file_uri = "file:///" + str(target).replace("\\", "/").lstrip("/")
1187+
server.handle_line(json.dumps({
1188+
"jsonrpc": "2.0", "id": request_id,
1189+
"result": {"roots": [{"uri": file_uri, "name": "ws"}]},
1190+
}))
1191+
t.join(timeout=2.0)
1192+
assert not t.is_alive()
1193+
assert os.path.realpath(fs_provider.root) == os.path.realpath(str(target))
1194+
1195+
11381196
def test_default_registry_lists_core_automation_tools():
11391197
names = {tool.name for tool in build_default_tool_registry()}
11401198
expected = {

0 commit comments

Comments
 (0)