diff --git a/backend/routers/mcp.py b/backend/routers/mcp.py index 5edb1a58838..eb5b69b68ec 100644 --- a/backend/routers/mcp.py +++ b/backend/routers/mcp.py @@ -29,6 +29,7 @@ from utils.other.endpoints import with_rate_limit from utils.log_sanitizer import sanitize_pii from utils.mcp_data import clean_action_item, clean_chat_message, clean_person, clean_screen_activity_row +import utils.mcp_action_items as mcp_action_items from utils.mcp_memories import ( collect_filtered_memories, parse_mcp_bool, @@ -447,6 +448,81 @@ def get_action_items( return [clean_action_item(i) for i in items if not i.get("deleted", False)] +class McpCreateActionItem(BaseModel): + description: str + due_at: Optional[datetime] = None + completed: bool = False + + +class McpUpdateActionItem(BaseModel): + description: Optional[str] = None + due_at: Optional[datetime] = None + + +def _action_item_write_error(exc: Exception) -> HTTPException: + """Map a shared action-item write error to the REST status the memory writes use.""" + if isinstance(exc, mcp_action_items.ActionItemNotFound): + return HTTPException(status_code=404, detail="Action item not found") + if isinstance(exc, mcp_action_items.ActionItemLocked): + return HTTPException(status_code=402, detail="A paid plan is required to modify this action item.") + return HTTPException(status_code=500, detail="Action item write failed") + + +@router.get("/v1/mcp/action-items/search", response_model=List[SimpleActionItem], tags=["mcp"]) +def search_action_items(query: str, limit: int = 10, uid: str = Depends(get_uid_from_mcp_api_key)): + logger.info(f"search_action_items {uid} limit={limit}") + try: + return mcp_action_items.search_action_items(uid, query, limit=limit) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + + +@router.post("/v1/mcp/action-items", response_model=SimpleActionItem, tags=["mcp"]) +def create_action_item( + body: McpCreateActionItem, + uid: str = Depends(with_rate_limit(get_uid_from_mcp_api_key, "action_items:write")), +): + logger.info(f"create_action_item {uid} completed={body.completed} has_due={body.due_at is not None}") + try: + return mcp_action_items.create_action_item(uid, body.description, due_at=body.due_at, completed=body.completed) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + except mcp_action_items.ActionItemError as e: + raise _action_item_write_error(e) + + +@router.post("/v1/mcp/action-items/{action_item_id}/complete", response_model=SimpleActionItem, tags=["mcp"]) +def complete_action_item(action_item_id: str, completed: bool = True, uid: str = Depends(get_uid_from_mcp_api_key)): + logger.info(f"complete_action_item {uid} id={action_item_id} completed={completed}") + try: + return mcp_action_items.set_completed(uid, action_item_id, completed=completed) + except mcp_action_items.ActionItemError as e: + raise _action_item_write_error(e) + + +@router.patch("/v1/mcp/action-items/{action_item_id}", response_model=SimpleActionItem, tags=["mcp"]) +def update_action_item(action_item_id: str, body: McpUpdateActionItem, uid: str = Depends(get_uid_from_mcp_api_key)): + logger.info(f"update_action_item {uid} id={action_item_id}") + try: + return mcp_action_items.update_action_item( + uid, action_item_id, description=body.description, due_at=body.due_at + ) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + except mcp_action_items.ActionItemError as e: + raise _action_item_write_error(e) + + +@router.delete("/v1/mcp/action-items/{action_item_id}", tags=["mcp"]) +def delete_action_item(action_item_id: str, uid: str = Depends(get_uid_from_mcp_api_key)): + logger.info(f"delete_action_item {uid} id={action_item_id}") + try: + mcp_action_items.delete_action_item(uid, action_item_id) + except mcp_action_items.ActionItemError as e: + raise _action_item_write_error(e) + return {"status": "ok"} + + # --------------------------------------------------------------------------- # Goals — the user's stated objectives # --------------------------------------------------------------------------- diff --git a/backend/routers/mcp_sse.py b/backend/routers/mcp_sse.py index d1f0b82fde1..7a8f778d531 100644 --- a/backend/routers/mcp_sse.py +++ b/backend/routers/mcp_sse.py @@ -35,6 +35,7 @@ from models.conversation_enums import CategoryEnum from utils.llm.memories import identify_category_for_memory from utils.mcp_data import clean_action_item, clean_chat_message, clean_person, clean_screen_activity_row +import utils.mcp_action_items as mcp_action_items from utils.mcp_memories import ( collect_filtered_memories, parse_mcp_bool, @@ -60,6 +61,7 @@ "memories.write", "conversations.read", "action_items.read", + "action_items.write", "goals.read", "chat.read", "screen_activity.read", @@ -88,6 +90,7 @@ MEMORIES_WRITE_SECURITY = [{"type": "oauth2", "scopes": ["memories.write"]}] CONVERSATIONS_READ_SECURITY = [{"type": "oauth2", "scopes": ["conversations.read"]}] ACTION_ITEMS_READ_SECURITY = [{"type": "oauth2", "scopes": ["action_items.read"]}] +ACTION_ITEMS_WRITE_SECURITY = [{"type": "oauth2", "scopes": ["action_items.write"]}] GOALS_READ_SECURITY = [{"type": "oauth2", "scopes": ["goals.read"]}] CHAT_READ_SECURITY = [{"type": "oauth2", "scopes": ["chat.read"]}] SCREEN_ACTIVITY_READ_SECURITY = [{"type": "oauth2", "scopes": ["screen_activity.read"]}] @@ -360,6 +363,100 @@ def invalid_mcp_auth_exception( }, }, }, + { + "name": "search_action_items", + "description": ( + "Semantic search across the user's action items (tasks/to-dos). Returns tasks ranked by relevance to " + "the query — use this to find a specific task by what it is about before completing or updating it." + ), + "annotations": READ_ONLY_ANNOTATIONS, + "securitySchemes": ACTION_ITEMS_READ_SECURITY, + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "What to search the user's tasks for"}, + "limit": {"type": "integer", "description": "Max number of tasks to return (1-50)", "default": 10}, + }, + "required": ["query"], + }, + }, + { + "name": "create_action_item", + "description": ( + "Create a new action item (task/to-do) for the user — for example a follow-up you identified while " + "helping them. Retries with the same description return the existing task instead of duplicating it." + ), + "annotations": WRITE_ANNOTATIONS, + "securitySchemes": ACTION_ITEMS_WRITE_SECURITY, + "inputSchema": { + "type": "object", + "properties": { + "description": {"type": "string", "description": "What the user needs to do"}, + "due_at": { + "type": "string", + "description": "Optional due date/time, ISO 8601 (2026-07-01T17:00:00Z) or YYYY-MM-DD", + }, + "completed": { + "type": "boolean", + "description": "Create it already completed (default false)", + "default": False, + }, + }, + "required": ["description"], + }, + }, + { + "name": "complete_action_item", + "description": "Mark an action item complete, or reopen it by passing completed=false.", + "annotations": WRITE_ANNOTATIONS, + "securitySchemes": ACTION_ITEMS_WRITE_SECURITY, + "inputSchema": { + "type": "object", + "properties": { + "action_item_id": {"type": "string", "description": "The ID of the action item"}, + "completed": { + "type": "boolean", + "description": "True to complete (default), false to reopen", + "default": True, + }, + }, + "required": ["action_item_id"], + }, + }, + { + "name": "update_action_item", + "description": ( + "Update an action item's description and/or due date. Only the fields you pass are changed; an omitted " + "due date is left unchanged." + ), + "annotations": WRITE_ANNOTATIONS, + "securitySchemes": ACTION_ITEMS_WRITE_SECURITY, + "inputSchema": { + "type": "object", + "properties": { + "action_item_id": {"type": "string", "description": "The ID of the action item"}, + "description": {"type": "string", "description": "New description for the task"}, + "due_at": { + "type": "string", + "description": "New due date/time, ISO 8601 (2026-07-01T17:00:00Z) or YYYY-MM-DD", + }, + }, + "required": ["action_item_id"], + }, + }, + { + "name": "delete_action_item", + "description": "Delete an action item by ID. Use this to clean up a task that is no longer relevant.", + "annotations": DESTRUCTIVE_WRITE_ANNOTATIONS, + "securitySchemes": ACTION_ITEMS_WRITE_SECURITY, + "inputSchema": { + "type": "object", + "properties": { + "action_item_id": {"type": "string", "description": "The ID of the action item to delete"}, + }, + "required": ["action_item_id"], + }, + }, { "name": "get_goals", "description": ( @@ -820,6 +917,74 @@ def execute_tool(user_id: str, tool_name: str, arguments: dict) -> dict: ) return {"action_items": [clean_action_item(i) for i in items if not i.get("deleted", False)]} + elif tool_name == "search_action_items": + try: + items = mcp_action_items.search_action_items( + user_id, arguments.get("query"), limit=arguments.get("limit", 10) + ) + except ValueError as e: + raise ToolExecutionError(str(e), code=-32602) + return {"action_items": items} + + elif tool_name == "create_action_item": + try: + completed = parse_mcp_bool(arguments.get("completed"), "completed", default=False) + item = mcp_action_items.create_action_item( + user_id, + arguments.get("description"), + due_at=arguments.get("due_at"), + completed=completed, + ) + except ValueError as e: + raise ToolExecutionError(str(e), code=-32602) + return {"success": True, "action_item": item} + + elif tool_name == "complete_action_item": + action_item_id = arguments.get("action_item_id") + if not action_item_id: + raise ToolExecutionError("action_item_id is required", code=-32602) + try: + completed = parse_mcp_bool(arguments.get("completed"), "completed", default=True) + item = mcp_action_items.set_completed(user_id, action_item_id, completed=completed) + except ValueError as e: + raise ToolExecutionError(str(e), code=-32602) + except mcp_action_items.ActionItemNotFound: + raise ToolExecutionError("Action item not found", code=-32001) + except mcp_action_items.ActionItemLocked: + raise ToolExecutionError("A paid plan is required to modify this action item.", code=-32002) + return {"success": True, "action_item": item} + + elif tool_name == "update_action_item": + action_item_id = arguments.get("action_item_id") + if not action_item_id: + raise ToolExecutionError("action_item_id is required", code=-32602) + try: + item = mcp_action_items.update_action_item( + user_id, + action_item_id, + description=arguments.get("description"), + due_at=arguments.get("due_at"), + ) + except ValueError as e: + raise ToolExecutionError(str(e), code=-32602) + except mcp_action_items.ActionItemNotFound: + raise ToolExecutionError("Action item not found", code=-32001) + except mcp_action_items.ActionItemLocked: + raise ToolExecutionError("A paid plan is required to modify this action item.", code=-32002) + return {"success": True, "action_item": item} + + elif tool_name == "delete_action_item": + action_item_id = arguments.get("action_item_id") + if not action_item_id: + raise ToolExecutionError("action_item_id is required", code=-32602) + try: + mcp_action_items.delete_action_item(user_id, action_item_id) + except mcp_action_items.ActionItemNotFound: + raise ToolExecutionError("Action item not found", code=-32001) + except mcp_action_items.ActionItemLocked: + raise ToolExecutionError("A paid plan is required to modify this action item.", code=-32002) + return {"success": True} + elif tool_name == "get_goals": include_inactive = parse_mcp_bool(arguments.get("include_inactive"), "include_inactive", default=False) return {"goals": goals_db.get_all_goals(user_id, include_inactive=include_inactive)} diff --git a/backend/test.sh b/backend/test.sh index c39a1412f35..4dcd68b27ec 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -36,6 +36,7 @@ pytest tests/unit/test_mcp_search_conversations_poison.py -v pytest tests/unit/test_mcp_memory_filters.py -v pytest tests/unit/test_mcp_client_tool_result.py -v pytest tests/unit/test_mcp_data_endpoints.py -v +pytest tests/unit/test_mcp_action_item_writes.py -v pytest tests/unit/test_mcp_conversations_poison.py -v pytest tests/unit/test_mcp_profile_contact.py -v pytest tests/unit/test_memory_temporal_brain.py -v diff --git a/backend/tests/unit/test_mcp_action_item_writes.py b/backend/tests/unit/test_mcp_action_item_writes.py new file mode 100644 index 00000000000..120c807284c --- /dev/null +++ b/backend/tests/unit/test_mcp_action_item_writes.py @@ -0,0 +1,466 @@ +"""Unit tests for the MCP action-item write surface: create, complete, update, +delete, and semantic search. + +Exercises the shared orchestration (utils/mcp_action_items.py) plus both +transports — the REST handlers (routers/mcp.py) and the MCP tool dispatch +(routers/mcp_sse.py) — with the database and vector layers mocked, following the +heavy-dep stubbing pattern in test_mcp_data_endpoints.py. +""" + +from datetime import datetime, timezone +from unittest.mock import patch, MagicMock +import os +import sys +from types import ModuleType + +import pytest + +os.environ.setdefault('OPENAI_API_KEY', 'sk-test-not-real') +os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv') + +_BACKEND_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) + + +class _AutoMockModule(ModuleType): + def __getattr__(self, name): + if name.startswith('__') and name.endswith('__'): + raise AttributeError(name) + mock = MagicMock() + setattr(self, name, mock) + return mock + + +def _ensure_package_path(name, path): + module = sys.modules.get(name) + if not isinstance(module, ModuleType): + module = ModuleType(name) + sys.modules[name] = module + module.__path__ = [path] + if '.' in name: + parent_name, child_name = name.rsplit('.', 1) + parent = sys.modules.setdefault(parent_name, ModuleType(parent_name)) + setattr(parent, child_name, module) + return module + + +def _drop_stale_module(module_name, expected_file): + module = sys.modules.get(module_name) + if module is None: + return + module_file = getattr(module, '__file__', None) + if isinstance(module_file, str) and os.path.abspath(module_file) == expected_file: + return + sys.modules.pop(module_name, None) + parent_name, child_name = module_name.rsplit('.', 1) + parent = sys.modules.get(parent_name) + if isinstance(parent, ModuleType) and getattr(parent, child_name, None) is module: + delattr(parent, child_name) + + +_ensure_package_path('utils', os.path.join(_BACKEND_DIR, 'utils')) +_ensure_package_path('utils.retrieval', os.path.join(_BACKEND_DIR, 'utils', 'retrieval')) +_ensure_package_path('models', os.path.join(_BACKEND_DIR, 'models')) +_drop_stale_module('utils.retrieval.hybrid', os.path.join(_BACKEND_DIR, 'utils', 'retrieval', 'hybrid.py')) +_drop_stale_module('models.memories', os.path.join(_BACKEND_DIR, 'models', 'memories.py')) +_drop_stale_module('models.conversation_enums', os.path.join(_BACKEND_DIR, 'models', 'conversation_enums.py')) +_drop_stale_module('models.mcp_api_key', os.path.join(_BACKEND_DIR, 'models', 'mcp_api_key.py')) +# utils.mcp_data and utils.mcp_action_items are the real modules under test — never stub them. +_drop_stale_module('utils.mcp_data', os.path.join(_BACKEND_DIR, 'utils', 'mcp_data.py')) +_drop_stale_module('utils.mcp_action_items', os.path.join(_BACKEND_DIR, 'utils', 'mcp_action_items.py')) + +_stubs = [ + 'database._client', + 'database.redis_db', + 'database.conversations', + 'database.memories', + 'database.action_items', + 'database.folders', + 'database.users', + 'database.user_usage', + 'database.vector_db', + 'database.chat', + 'database.apps', + 'database.goals', + 'database.notifications', + 'database.mem_db', + 'database.mcp_api_key', + 'database.daily_summaries', + 'database.screen_activity', + 'database.x_posts', + 'database.fair_use', + 'database.auth', + 'database.dev_api_key', + 'firebase_admin', + 'firebase_admin.messaging', + 'firebase_admin.auth', + 'google.cloud.firestore', + 'google.cloud.firestore_v1', + 'google.cloud.firestore_v1.FieldFilter', + 'google', + 'google.cloud', + 'pinecone', + 'typesense', + 'opuslib', + 'pydub', + 'pusher', + 'modal', + 'utils.other.storage', + 'utils.other.endpoints', + 'utils.stt.pre_recorded', + 'utils.stt.vad', + 'utils.fair_use', + 'utils.subscription', + 'utils.conversations.process_conversation', + 'utils.conversations.render', + 'utils.notifications', + 'utils.apps', + 'utils.llm.memories', + 'utils.llm.chat', + 'utils.log_sanitizer', + 'utils.executors', + 'dependencies', +] +for mod_name in _stubs: + if mod_name not in sys.modules: + sys.modules[mod_name] = _AutoMockModule(mod_name) + +if not isinstance(getattr(sys.modules['database._client'], '__file__', None), str): + sys.modules['database._client'].document_id_from_seed = lambda seed: 'id-' + str(abs(hash(seed)) % (10**12)) +sys.modules['dependencies'].get_uid_from_mcp_api_key = MagicMock(return_value='user-1') +sys.modules['dependencies'].get_current_user_id = MagicMock(return_value='user-1') +sys.modules['utils.other.endpoints'].with_rate_limit = MagicMock(side_effect=lambda dependency, _policy: dependency) +sys.modules['utils.other.endpoints'].check_rate_limit_inline = MagicMock() +sys.modules['utils.llm.memories'].identify_category_for_memory = MagicMock(return_value='other') +sys.modules['firebase_admin.auth'].InvalidIdTokenError = type('InvalidIdTokenError', (Exception,), {}) + +from fastapi import HTTPException # noqa: E402 + +import utils.mcp_action_items as actions # noqa: E402 (module under test) +from routers import mcp as rest # noqa: E402 +from routers import mcp_sse as sse # noqa: E402 + +NOW = datetime(2026, 6, 11, tzinfo=timezone.utc) +UID = "user-1" + + +def _action_item(item_id='a1', desc='Email Bob', completed=False, deleted=False, locked=False, due_at=NOW): + return { + 'id': item_id, + 'description': desc, + 'completed': completed, + 'created_at': NOW, + 'due_at': due_at, + 'completed_at': None, + 'conversation_id': None, + 'deleted': deleted, + 'is_locked': locked, + } + + +# --------------------------------------------------------------------------- +# Shared orchestration (utils/mcp_action_items.py) +# --------------------------------------------------------------------------- +class TestIdempotencyKey: + def test_same_description_same_key(self): + a = actions.content_idempotency_key(UID, 'Email Bob') + b = actions.content_idempotency_key(UID, ' email bob ') # case + whitespace insensitive + assert a == b + + def test_different_description_differs(self): + assert actions.content_idempotency_key(UID, 'Email Bob') != actions.content_idempotency_key(UID, 'Call Bob') + + def test_uid_boundary_is_unambiguous(self): + # Length-prefixing prevents a uid containing ':' from colliding across the boundary. + assert actions.content_idempotency_key('a:b', 'c') != actions.content_idempotency_key('a', 'b:c') + + +class TestParseDueAt: + def test_passthrough_datetime_and_none(self): + assert actions.parse_due_at(None) is None + assert actions.parse_due_at(NOW) == NOW + + def test_iso_and_date_strings(self): + assert actions.parse_due_at('2026-07-01T17:00:00Z') == datetime(2026, 7, 1, 17, 0, tzinfo=timezone.utc) + assert actions.parse_due_at('2026-07-01') == datetime(2026, 7, 1, tzinfo=timezone.utc) + assert actions.parse_due_at(' ') is None + + def test_bad_string_raises(self): + with pytest.raises(ValueError): + actions.parse_due_at('next tuesday') + + +class TestCreateOrchestration: + @patch('utils.mcp_action_items.upsert_action_item_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_create_passes_idempotency_key_and_indexes(self, mock_db, mock_vec): + mock_db.create_action_item.return_value = 'a1' + mock_db.get_action_item.return_value = _action_item('a1', desc='Email Bob') + out = actions.create_action_item(UID, ' Email Bob ', due_at='2026-07-01') + + args, kwargs = mock_db.create_action_item.call_args + assert args[0] == UID + assert args[1]['description'] == 'Email Bob' # trimmed + assert args[1]['due_at'] == datetime(2026, 7, 1, tzinfo=timezone.utc) + assert kwargs['idempotency_key'] == actions.content_idempotency_key(UID, 'Email Bob') + mock_vec.assert_called_once_with(UID, 'a1', 'Email Bob') + assert out['id'] == 'a1' and out['description'] == 'Email Bob' + + @patch('utils.mcp_action_items.upsert_action_item_vector', side_effect=RuntimeError('pinecone down')) + @patch('utils.mcp_action_items.action_items_db') + def test_create_survives_vector_failure(self, mock_db, _mock_vec): + mock_db.create_action_item.return_value = 'a1' + mock_db.get_action_item.return_value = _action_item('a1') + out = actions.create_action_item(UID, 'Email Bob') # vector raises, task still returned + assert out['id'] == 'a1' + + @patch('utils.mcp_action_items.action_items_db') + def test_create_rejects_blank(self, _mock_db): + with pytest.raises(ValueError): + actions.create_action_item(UID, ' ') + + @patch('utils.mcp_action_items.action_items_db') + def test_create_rejects_too_long(self, _mock_db): + with pytest.raises(ValueError): + actions.create_action_item(UID, 'x' * (actions.MAX_DESCRIPTION_CHARS + 1)) + + +class TestMutationGuards: + @patch('utils.mcp_action_items.action_items_db') + def test_complete_not_found(self, mock_db): + mock_db.get_action_item.return_value = None + with pytest.raises(actions.ActionItemNotFound): + actions.set_completed(UID, 'missing') + + @patch('utils.mcp_action_items.action_items_db') + def test_complete_locked(self, mock_db): + mock_db.get_action_item.return_value = _action_item('a1', locked=True) + with pytest.raises(actions.ActionItemLocked): + actions.set_completed(UID, 'a1') + + @patch('utils.mcp_action_items.action_items_db') + def test_complete_marks_and_returns(self, mock_db): + mock_db.get_action_item.side_effect = [_action_item('a1'), _action_item('a1', completed=True)] + mock_db.mark_action_item_completed.return_value = True + out = actions.set_completed(UID, 'a1', completed=True) + mock_db.mark_action_item_completed.assert_called_once_with(UID, 'a1', completed=True) + assert out['completed'] is True + + @patch('utils.mcp_action_items.upsert_action_item_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_update_reindexes_only_on_description_change(self, mock_db, mock_vec): + mock_db.get_action_item.side_effect = [_action_item('a1'), _action_item('a1', due_at=NOW)] + mock_db.update_action_item.return_value = True + actions.update_action_item(UID, 'a1', due_at='2026-07-02') # no description + mock_vec.assert_not_called() + + mock_db.get_action_item.side_effect = [_action_item('a1'), _action_item('a1', desc='New')] + actions.update_action_item(UID, 'a1', description='New') + mock_vec.assert_called_once_with(UID, 'a1', 'New') + + @patch('utils.mcp_action_items.action_items_db') + def test_update_requires_a_field(self, mock_db): + mock_db.get_action_item.return_value = _action_item('a1') + with pytest.raises(ValueError): + actions.update_action_item(UID, 'a1') + + @patch('utils.mcp_action_items.delete_action_item_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_delete_removes_vector(self, mock_db, mock_vec): + mock_db.get_action_item.return_value = _action_item('a1') + mock_db.delete_action_item.return_value = True + actions.delete_action_item(UID, 'a1') + mock_db.delete_action_item.assert_called_once_with(UID, 'a1') + mock_vec.assert_called_once_with(UID, 'a1') + + @patch('utils.mcp_action_items.delete_action_item_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_delete_noop_raises_not_found(self, mock_db, mock_vec): + # Existed at the guard check, but the delete itself was a no-op (raced). + mock_db.get_action_item.return_value = _action_item('a1') + mock_db.delete_action_item.return_value = False + with pytest.raises(actions.ActionItemNotFound): + actions.delete_action_item(UID, 'a1') + mock_vec.assert_not_called() + + @patch('utils.mcp_action_items.action_items_db') + def test_set_completed_reload_missing_raises(self, mock_db): + # Marked complete, then the item vanished before the reload (concurrent delete). + mock_db.get_action_item.side_effect = [_action_item('a1'), None] + mock_db.mark_action_item_completed.return_value = True + with pytest.raises(actions.ActionItemNotFound): + actions.set_completed(UID, 'a1') + + @patch('utils.mcp_action_items.action_items_db') + def test_update_blank_due_at_does_not_clear(self, mock_db): + # A blank due_at must not null the field, and on its own is "nothing to update". + mock_db.get_action_item.return_value = _action_item('a1') + with pytest.raises(ValueError): + actions.update_action_item(UID, 'a1', due_at='') + + # With a real description, a blank due_at is simply dropped from the update. + mock_db.get_action_item.side_effect = [_action_item('a1'), _action_item('a1', desc='New')] + mock_db.update_action_item.return_value = True + actions.update_action_item(UID, 'a1', description='New', due_at='') + _, kwargs_or_args = mock_db.update_action_item.call_args + update_data = mock_db.update_action_item.call_args[0][2] + assert 'due_at' not in update_data + assert update_data['description'] == 'New' + + +class TestSearchOrchestration: + @patch('utils.mcp_action_items.search_action_items_by_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_search_preserves_relevance_order(self, mock_db, mock_vec): + mock_vec.return_value = ['a2', 'a1'] # relevance order + # DB returns them in a different (arbitrary) order + mock_db.get_action_items_by_ids.return_value = [_action_item('a1'), _action_item('a2')] + out = actions.search_action_items(UID, 'bob', limit=5) + assert [i['id'] for i in out] == ['a2', 'a1'] + + @patch('utils.mcp_action_items.search_action_items_by_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_search_empty(self, mock_db, mock_vec): + mock_vec.return_value = [] + assert actions.search_action_items(UID, 'bob') == [] + mock_db.get_action_items_by_ids.assert_not_called() + + def test_search_rejects_blank_query(self): + with pytest.raises(ValueError): + actions.search_action_items(UID, ' ') + + +# --------------------------------------------------------------------------- +# REST transport (routers/mcp.py) +# --------------------------------------------------------------------------- +class TestRestTransport: + @patch('utils.mcp_action_items.upsert_action_item_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_rest_create(self, mock_db, _mock_vec): + mock_db.create_action_item.return_value = 'a1' + mock_db.get_action_item.return_value = _action_item('a1') + body = rest.McpCreateActionItem(description='Email Bob', due_at=NOW) + out = rest.create_action_item(body=body, uid=UID) + assert out['id'] == 'a1' + + @patch('utils.mcp_action_items.action_items_db') + def test_rest_create_blank_is_422(self, _mock_db): + with pytest.raises(HTTPException) as ei: + rest.create_action_item(body=rest.McpCreateActionItem(description=' '), uid=UID) + assert ei.value.status_code == 422 + + @patch('utils.mcp_action_items.action_items_db') + def test_rest_complete_not_found_is_404(self, mock_db): + mock_db.get_action_item.return_value = None + with pytest.raises(HTTPException) as ei: + rest.complete_action_item(action_item_id='missing', uid=UID) + assert ei.value.status_code == 404 + + @patch('utils.mcp_action_items.action_items_db') + def test_rest_complete_locked_is_402(self, mock_db): + mock_db.get_action_item.return_value = _action_item('a1', locked=True) + with pytest.raises(HTTPException) as ei: + rest.complete_action_item(action_item_id='a1', uid=UID) + assert ei.value.status_code == 402 + + @patch('utils.mcp_action_items.delete_action_item_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_rest_delete_ok(self, mock_db, _mock_vec): + mock_db.get_action_item.return_value = _action_item('a1') + mock_db.delete_action_item.return_value = True + assert rest.delete_action_item(action_item_id='a1', uid=UID) == {"status": "ok"} + + @patch('utils.mcp_action_items.upsert_action_item_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_rest_update(self, mock_db, _mock_vec): + mock_db.get_action_item.side_effect = [_action_item('a1'), _action_item('a1', desc='New text')] + mock_db.update_action_item.return_value = True + out = rest.update_action_item('a1', body=rest.McpUpdateActionItem(description='New text'), uid=UID) + assert out['description'] == 'New text' + + @patch('utils.mcp_action_items.action_items_db') + def test_rest_update_not_found_is_404(self, mock_db): + mock_db.get_action_item.return_value = None + with pytest.raises(HTTPException) as ei: + rest.update_action_item('missing', body=rest.McpUpdateActionItem(description='x'), uid=UID) + assert ei.value.status_code == 404 + + @patch('utils.mcp_action_items.search_action_items_by_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_rest_search(self, mock_db, mock_vec): + mock_vec.return_value = ['a1'] + mock_db.get_action_items_by_ids.return_value = [_action_item('a1')] + out = rest.search_action_items(query='bob', uid=UID) + assert [i['id'] for i in out] == ['a1'] + + def test_rest_search_blank_is_422(self): + with pytest.raises(HTTPException) as ei: + rest.search_action_items(query=' ', uid=UID) + assert ei.value.status_code == 422 + + +# --------------------------------------------------------------------------- +# MCP tool dispatch (routers/mcp_sse.py) +# --------------------------------------------------------------------------- +class TestSseDispatch: + @patch('utils.mcp_action_items.upsert_action_item_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_tool_create(self, mock_db, _mock_vec): + mock_db.create_action_item.return_value = 'a1' + mock_db.get_action_item.return_value = _action_item('a1') + result = sse.execute_tool(UID, 'create_action_item', {'description': 'Email Bob', 'due_at': '2026-07-01'}) + assert result['success'] is True + assert result['action_item']['id'] == 'a1' + + @patch('utils.mcp_action_items.action_items_db') + def test_tool_create_bad_due_date_is_invalid_params(self, _mock_db): + with pytest.raises(sse.ToolExecutionError) as ei: + sse.execute_tool(UID, 'create_action_item', {'description': 'x', 'due_at': 'whenever'}) + assert ei.value.code == -32602 + + @patch('utils.mcp_action_items.action_items_db') + def test_tool_complete_requires_id(self, _mock_db): + with pytest.raises(sse.ToolExecutionError) as ei: + sse.execute_tool(UID, 'complete_action_item', {}) + assert ei.value.code == -32602 + + @patch('utils.mcp_action_items.action_items_db') + def test_tool_complete_not_found(self, mock_db): + mock_db.get_action_item.return_value = None + with pytest.raises(sse.ToolExecutionError) as ei: + sse.execute_tool(UID, 'complete_action_item', {'action_item_id': 'missing'}) + assert ei.value.code == -32001 + + @patch('utils.mcp_action_items.action_items_db') + def test_tool_update_locked(self, mock_db): + mock_db.get_action_item.return_value = _action_item('a1', locked=True) + with pytest.raises(sse.ToolExecutionError) as ei: + sse.execute_tool(UID, 'update_action_item', {'action_item_id': 'a1', 'description': 'New'}) + assert ei.value.code == -32002 + + @patch('utils.mcp_action_items.delete_action_item_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_tool_delete(self, mock_db, _mock_vec): + mock_db.get_action_item.return_value = _action_item('a1') + result = sse.execute_tool(UID, 'delete_action_item', {'action_item_id': 'a1'}) + assert result['success'] is True + + @patch('utils.mcp_action_items.search_action_items_by_vector') + @patch('utils.mcp_action_items.action_items_db') + def test_tool_search(self, mock_db, mock_vec): + mock_vec.return_value = ['a1'] + mock_db.get_action_items_by_ids.return_value = [_action_item('a1')] + result = sse.execute_tool(UID, 'search_action_items', {'query': 'bob'}) + assert result['action_items'][0]['id'] == 'a1' + + +class TestToolRegistration: + def test_write_tools_listed_with_scopes(self): + by_name = {t['name']: t for t in sse.MCP_TOOLS} + for name in ['create_action_item', 'complete_action_item', 'update_action_item', 'delete_action_item']: + assert name in by_name + assert by_name[name]['securitySchemes'] == sse.ACTION_ITEMS_WRITE_SECURITY + # search is a read, guarded by the read scope + assert by_name['search_action_items']['securitySchemes'] == sse.ACTION_ITEMS_READ_SECURITY + + def test_write_scope_advertised(self): + assert 'action_items.write' in sse.MCP_SCOPES_SUPPORTED diff --git a/backend/utils/mcp_action_items.py b/backend/utils/mcp_action_items.py new file mode 100644 index 00000000000..b65c1192374 --- /dev/null +++ b/backend/utils/mcp_action_items.py @@ -0,0 +1,241 @@ +"""Shared write/search orchestration for the MCP action-items surface. + +Both the REST endpoints (``routers/mcp.py``) and the MCP tools +(``routers/mcp_sse.py``) drive action-item create/complete/update/delete/search +through these helpers, so the two transports cannot drift in behavior +(idempotency, vector indexing, paywall handling). Past drift between the SSE and +REST MCP paths is exactly what required follow-up alignment fixes, so the write +path is centralized here from the start. Lives in ``utils`` so neither router +imports the other (routers must never import from other routers). +""" + +import hashlib +import logging +from datetime import datetime, timezone +from typing import List, Optional, Union + +import database.action_items as action_items_db +from database.vector_db import ( + upsert_action_item_vector, + delete_action_item_vector, + search_action_items_by_vector, +) +from utils.mcp_data import clean_action_item + +logger = logging.getLogger(__name__) + +# Bound how much text a single tool call can write. The app UI never produces a +# task description this long; this only caps adversarial or garbled MCP input so +# a runaway client cannot push a multi-megabyte Firestore document. +MAX_DESCRIPTION_CHARS = 2000 + +# Upper bound on search breadth — keeps a model from requesting thousands of rows. +MAX_SEARCH_LIMIT = 50 + + +class ActionItemError(Exception): + """Base class for action-item write failures the routers map to transport codes.""" + + +class ActionItemNotFound(ActionItemError): + """No action item with that id exists for this user.""" + + +class ActionItemLocked(ActionItemError): + """The item is behind the paywall (``is_locked``); a paid plan is required to mutate it.""" + + +def content_idempotency_key(uid: str, description: str) -> str: + """Stable key from (uid, normalized description). + + A retried create with the same description collapses onto the original item + instead of producing a duplicate — important for MCP, where a model client + may resend a tool call after a transport hiccup. Length-prefixed so a uid + containing ``:`` cannot collide with a different (uid, description) pair. + """ + normalized = (description or '').strip().lower() + payload = f"{len(uid)}:{uid}:{normalized}" + return hashlib.sha256(payload.encode('utf-8')).hexdigest() + + +def _normalize_description(description: Optional[str]) -> str: + if description is None: + raise ValueError("description is required") + text = description.strip() + if not text: + raise ValueError("description cannot be empty") + if len(text) > MAX_DESCRIPTION_CHARS: + raise ValueError(f"description is too long (max {MAX_DESCRIPTION_CHARS} characters)") + return text + + +def parse_due_at(value: Union[str, datetime, None]) -> Optional[datetime]: + """Accept an ISO 8601 datetime, a yyyy-mm-dd date, a datetime, or None. + + REST passes a parsed ``datetime``; MCP tools pass a JSON string. Both routes + funnel through here so the accepted formats stay identical. + """ + if value is None or isinstance(value, datetime): + return value + if isinstance(value, str): + text = value.strip() + if not text: + return None + parsed = None + try: + # fromisoformat also accepts a date-only string (e.g. 2026-07-01) and + # returns it naive, so the strptime branch is a fallback for formats it + # rejects rather than the date-only path. + parsed = datetime.fromisoformat(text.replace('Z', '+00:00')) + except ValueError: + try: + parsed = datetime.strptime(text, "%Y-%m-%d") + except ValueError: + raise ValueError(f"Invalid due_at: '{value}'. Use ISO 8601 (e.g. 2026-07-01T17:00:00Z) or YYYY-MM-DD.") + # Normalize a tz-naive value to UTC so due-date filtering never has to + # compare offset-naive and offset-aware datetimes. + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed + raise ValueError("due_at must be a date string or datetime") + + +def _require_unlocked(uid: str, action_item_id: str) -> dict: + """Fetch the item (uid-scoped, so a foreign id can never resolve) and reject + a missing/deleted item with ``ActionItemNotFound`` and a paywalled one with + ``ActionItemLocked`` — mirroring the memory write guards.""" + item = action_items_db.get_action_item(uid, action_item_id) + if not item or item.get("deleted", False): + raise ActionItemNotFound("Action item not found") + if item.get("is_locked", False): + raise ActionItemLocked("A paid plan is required to modify this action item.") + return item + + +def _reload(uid: str, action_item_id: str) -> dict: + """Re-read an item after a write and shape it for the response. Raises + ActionItemNotFound if a concurrent delete removed it between the write and + this read, rather than dereferencing None.""" + item = action_items_db.get_action_item(uid, action_item_id) + if not item: + raise ActionItemNotFound("Action item not found") + return clean_action_item(item) + + +def create_action_item( + uid: str, + description: str, + due_at: Union[str, datetime, None] = None, + completed: bool = False, +) -> dict: + """Create a task and return its cleaned MCP shape. Content-idempotent on + (uid, normalized description).""" + text = _normalize_description(description) + parsed_due = parse_due_at(due_at) + data = { + "description": text, + "completed": bool(completed), + "due_at": parsed_due, + "conversation_id": None, + } + key = content_idempotency_key(uid, text) + item_id = action_items_db.create_action_item(uid, data, idempotency_key=key) + + # Index for semantic search so an MCP-created task is discoverable the same + # way app-created tasks are. Best-effort: the task already persisted, so a + # missing vector only degrades search ranking and never loses the task. + try: + upsert_action_item_vector(uid, item_id, text) + except Exception: + logger.exception("MCP create_action_item: vector upsert failed uid=%s id=%s (task saved)", uid, item_id) + + item = action_items_db.get_action_item(uid, item_id) + if not item: + raise ActionItemError("Failed to load the created action item") + return clean_action_item(item) + + +def set_completed(uid: str, action_item_id: str, completed: bool = True) -> dict: + """Mark a task complete or reopen it.""" + _require_unlocked(uid, action_item_id) + if not action_items_db.mark_action_item_completed(uid, action_item_id, completed=completed): + raise ActionItemNotFound("Action item not found") + return _reload(uid, action_item_id) + + +def update_action_item( + uid: str, + action_item_id: str, + description: Optional[str] = None, + due_at: Union[str, datetime, None] = None, +) -> dict: + """Update a task's description and/or due date. + + Only the fields provided are changed. Clearing a due date is not supported in + this version (an omitted ``due_at`` leaves it unchanged rather than nulling it). + """ + _require_unlocked(uid, action_item_id) + update_data: dict = {} + new_text: Optional[str] = None + if description is not None: + new_text = _normalize_description(description) + update_data["description"] = new_text + if due_at is not None: + # Clearing a due date is not supported here, so an empty/blank value + # (which parses to None) is treated as "not provided" rather than nulling + # the field — matching the documented contract. + parsed_due = parse_due_at(due_at) + if parsed_due is not None: + update_data["due_at"] = parsed_due + if not update_data: + raise ValueError("Provide a description, or a due date in ISO 8601 / YYYY-MM-DD form, to update") + + if not action_items_db.update_action_item(uid, action_item_id, update_data): + raise ActionItemNotFound("Action item not found") + + # Re-index only when the searchable text changed. + if new_text is not None: + try: + upsert_action_item_vector(uid, action_item_id, new_text) + except Exception: + logger.exception( + "MCP update_action_item: vector upsert failed uid=%s id=%s (task updated)", uid, action_item_id + ) + return _reload(uid, action_item_id) + + +def delete_action_item(uid: str, action_item_id: str) -> None: + """Delete a task and its search vector.""" + _require_unlocked(uid, action_item_id) + # Honor the delete result: a False here means the row was already gone (a + # concurrent delete between the existence check above and now), so report + # not-found rather than a misleading success. + if not action_items_db.delete_action_item(uid, action_item_id): + raise ActionItemNotFound("Action item not found") + try: + delete_action_item_vector(uid, action_item_id) + except Exception: + logger.exception( + "MCP delete_action_item: vector delete failed uid=%s id=%s (task deleted)", uid, action_item_id + ) + + +def search_action_items(uid: str, query: str, limit: int = 10) -> List[dict]: + """Semantic search over the user's tasks, returned in relevance order.""" + if not query or not query.strip(): + raise ValueError("query is required") + try: + limit = int(limit) + except (TypeError, ValueError): + raise ValueError("limit must be an integer") + limit = max(1, min(limit, MAX_SEARCH_LIMIT)) + + ids = search_action_items_by_vector(uid, query, limit=limit) + if not ids: + return [] + items = action_items_db.get_action_items_by_ids(uid, ids) + # Preserve the relevance order from the vector search; get_action_items_by_ids + # does not guarantee ordering. + order = {aid: i for i, aid in enumerate(ids)} + items.sort(key=lambda it: order.get(it.get("id"), len(ids))) + return [clean_action_item(it) for it in items if not it.get("deleted", False)] diff --git a/backend/utils/rate_limit_config.py b/backend/utils/rate_limit_config.py index a9f9576bc1d..41e31903833 100644 --- a/backend/utils/rate_limit_config.py +++ b/backend/utils/rate_limit_config.py @@ -62,6 +62,10 @@ # a reconnect storm into a 429 death-spiral, so this is sized for heavy multi-session # use rather than a single client. Tune via RATE_LIMIT_BOOST for events. "mcp:sse": (2000, 3600), + # Action items — lightweight Firestore writes from MCP clients (no LLM), but + # an agent can loop, so cap creation per hour. Complete/update/delete operate + # on existing tasks and ride the shared mcp:sse / per-request auth limits. + "action_items:write": (120, 3600), # Memories — single LLM call each "memories:create": (60, 3600), # Memory batch writes — each request can create up to 100 memories, so the