@@ -74,8 +74,10 @@ def __init__(self, tools: Optional[List[MCPTool]] = None,
7474 self ._calls_lock = threading .Lock ()
7575 self ._write_lock = threading .Lock ()
7676 self ._sampling_id_counter = itertools .count (1 )
77+ self ._outbound_id_counter = itertools .count (1 )
7778 self ._pending_outbound : Dict [Any , Dict [str , Any ]] = {}
7879 self ._outbound_lock = threading .Lock ()
80+ self ._client_capabilities : Dict [str , Any ] = {}
7981
8082 def register_tool (self , tool : MCPTool ) -> None :
8183 """Add or replace a tool in the live registry.
@@ -251,12 +253,76 @@ def _handle_notification(self, method: Optional[str],
251253 if method == "notifications/initialized" :
252254 self ._initialized = True
253255 autocontrol_logger .info ("MCP client initialized" )
256+ self ._maybe_request_roots_async ()
254257 return
255258 if method == "notifications/cancelled" :
256259 self ._cancel_active_call (params )
257260 return
261+ if method == "notifications/roots/list_changed" :
262+ self ._maybe_request_roots_async ()
263+ return
258264 autocontrol_logger .debug ("MCP notification ignored: %s" , method )
259265
266+ def _maybe_request_roots_async (self ) -> None :
267+ """Fire a roots/list request when the client supports it."""
268+ if "roots" not in self ._client_capabilities :
269+ return
270+ if self ._writer is None :
271+ return
272+ threading .Thread (
273+ target = self ._refresh_roots_safely , daemon = True ,
274+ name = "MCPRootsRefresh" ,
275+ ).start ()
276+
277+ def _refresh_roots_safely (self ) -> None :
278+ try :
279+ self .refresh_roots (timeout = 5.0 )
280+ except (RuntimeError , TimeoutError ) as error :
281+ autocontrol_logger .info ("MCP roots refresh skipped: %r" , error )
282+
283+ def refresh_roots (self , timeout : float = 10.0 ) -> List [Dict [str , Any ]]:
284+ """Send ``roots/list`` to the client and apply the first root."""
285+ result = self ._send_outbound_request (
286+ "roots/list" , params = {}, timeout = timeout ,
287+ )
288+ roots_list = (result or {}).get ("roots" ) or []
289+ if not isinstance (roots_list , list ) or not roots_list :
290+ return []
291+ first_uri = roots_list [0 ].get ("uri" ) if isinstance (roots_list [0 ],
292+ dict ) else None
293+ if isinstance (first_uri , str ):
294+ local_path = _file_uri_to_path (first_uri )
295+ if local_path :
296+ self ._resources .set_workspace_root (local_path )
297+ autocontrol_logger .info ("MCP workspace root → %s" , local_path )
298+ return roots_list
299+
300+ def _send_outbound_request (self , method : str ,
301+ params : Dict [str , Any ],
302+ timeout : float = 10.0 ) -> Dict [str , Any ]:
303+ """Send a server-initiated request and wait for the response."""
304+ writer = self ._writer
305+ if writer is None :
306+ raise RuntimeError (f"{ method } requires an outbound writer" )
307+ request_id = f"srv-{ next (self ._outbound_id_counter )} "
308+ slot = {"event" : threading .Event ()}
309+ with self ._outbound_lock :
310+ self ._pending_outbound [request_id ] = slot
311+ envelope = json .dumps ({
312+ "jsonrpc" : "2.0" , "id" : request_id ,
313+ "method" : method , "params" : params ,
314+ }, ensure_ascii = False , default = str )
315+ try :
316+ writer (envelope )
317+ if not slot ["event" ].wait (timeout = timeout ):
318+ raise TimeoutError (f"{ method } timed out after { timeout } s" )
319+ finally :
320+ with self ._outbound_lock :
321+ self ._pending_outbound .pop (request_id , None )
322+ if "error" in slot :
323+ raise RuntimeError (f"{ method } failed: { slot ['error' ]} " )
324+ return slot .get ("result" ) or {}
325+
260326 def _cancel_active_call (self , params : Dict [str , Any ]) -> None :
261327 """Mark the matching active tool call as cancelled, if any."""
262328 request_id = params .get ("requestId" )
@@ -293,17 +359,22 @@ def _dispatch(self, msg_id: Any, method: Optional[str],
293359 return self ._handle_prompts_get (params )
294360 raise _MCPError (- 32601 , f"Method not found: { method } " )
295361
296- @staticmethod
297- def _handle_initialize (params : Dict [str , Any ]) -> Dict [str , Any ]:
362+ def _handle_initialize (self , params : Dict [str , Any ]) -> Dict [str , Any ]:
298363 client_version = params .get ("protocolVersion" , PROTOCOL_VERSION )
364+ client_caps = params .get ("capabilities" ) or {}
365+ if isinstance (client_caps , dict ):
366+ self ._client_capabilities = client_caps
367+ capabilities : Dict [str , Any ] = {
368+ "tools" : {"listChanged" : True },
369+ "resources" : {"listChanged" : False , "subscribe" : False },
370+ "prompts" : {"listChanged" : False },
371+ "sampling" : {},
372+ }
373+ if "roots" in self ._client_capabilities :
374+ capabilities ["roots" ] = {"listChanged" : True }
299375 return {
300376 "protocolVersion" : client_version or PROTOCOL_VERSION ,
301- "capabilities" : {
302- "tools" : {"listChanged" : True },
303- "resources" : {"listChanged" : False , "subscribe" : False },
304- "prompts" : {"listChanged" : False },
305- "sampling" : {},
306- },
377+ "capabilities" : capabilities ,
307378 "serverInfo" : {"name" : SERVER_NAME , "version" : SERVER_VERSION },
308379 }
309380
@@ -465,6 +536,20 @@ def _stringify_result(value: Any) -> str:
465536 return repr (value )
466537
467538
539+ def _file_uri_to_path (uri : str ) -> Optional [str ]:
540+ """Convert a ``file://`` URI to a local filesystem path; ``None`` otherwise."""
541+ if not isinstance (uri , str ) or not uri .startswith ("file://" ):
542+ return None
543+ from urllib .parse import unquote , urlparse
544+ parsed = urlparse (uri )
545+ raw_path = unquote (parsed .path )
546+ # Windows: file:///C:/foo strips the leading slash before the drive letter.
547+ if sys .platform .startswith ("win" ) and raw_path .startswith ("/" ) and \
548+ len (raw_path ) > 2 and raw_path [2 ] == ":" :
549+ raw_path = raw_path [1 :]
550+ return raw_path or None
551+
552+
468553def _notification_message (method : str , params : Dict [str , Any ]) -> str :
469554 return json .dumps ({"jsonrpc" : "2.0" , "method" : method , "params" : params },
470555 ensure_ascii = False , default = str )
0 commit comments