Skip to content
17 changes: 10 additions & 7 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
context: ServerCallContext | None = None,
) -> Task | None:
"""Default handler for 'tasks/get'."""
task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
Comment thread
pstephengoogle marked this conversation as resolved.
if not task:
raise ServerError(error=TaskNotFoundError())

Expand Down Expand Up @@ -141,7 +141,7 @@

Attempts to cancel the task managed by the `AgentExecutor`.
"""
task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand All @@ -150,20 +150,21 @@
raise ServerError(
error=TaskNotCancelableError(
message=f'Task cannot be canceled - current state: {task.status.state}'
)
)

task_manager = TaskManager(
task_id=task.id,
context_id=task.context_id,
task_store=self.task_store,
initial_message=None,
context=context,
)
result_aggregator = ResultAggregator(task_manager)

queue = await self._queue_manager.tap(task.id)
if not queue:
queue = EventQueue()

Check notice on line 167 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (482-497)

await self.agent_executor.cancel(
RequestContext(
Expand Down Expand Up @@ -217,6 +218,7 @@
context_id=params.message.context_id,
task_store=self.task_store,
initial_message=params.message,
context=context,
)
task: Task | None = await task_manager.get_task()

Expand Down Expand Up @@ -417,7 +419,7 @@
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.task_id)
task: Task | None = await self.task_store.get(params.task_id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand All @@ -440,7 +442,7 @@
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand Down Expand Up @@ -469,7 +471,7 @@
Allows a client to re-attach to a running streaming task's event stream.
Requires the task and its queue to still be active.
"""
task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand All @@ -477,21 +479,22 @@
raise ServerError(
error=InvalidParamsError(
message=f'Task {task.id} is in terminal state: {task.status.state.value}'
)
)

task_manager = TaskManager(
task_id=task.id,
context_id=task.context_id,
task_store=self.task_store,
initial_message=None,
context=context,
)

result_aggregator = ResultAggregator(task_manager)

queue = await self._queue_manager.tap(task.id)
if not queue:
raise ServerError(error=TaskNotFoundError())

Check notice on line 497 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (153-167)

consumer = EventConsumer(queue)
async for event in result_aggregator.consume_and_emit(consumer):
Expand All @@ -509,7 +512,7 @@
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand All @@ -536,7 +539,7 @@
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand Down
13 changes: 10 additions & 3 deletions src/a2a/server/tasks/database_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"or 'pip install a2a-sdk[sql]'"
) from e

from a2a.server.context import ServerCallContext
from a2a.server.models import Base, TaskModel, create_task_model
from a2a.server.tasks.task_store import TaskStore
from a2a.types import Task # Task is the Pydantic model
Expand Down Expand Up @@ -119,15 +120,19 @@ def _from_orm(self, task_model: TaskModel) -> Task:
# Pydantic's model_validate will parse the nested dicts/lists from JSON
return Task.model_validate(task_data_from_db)

async def save(self, task: Task) -> None:
async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
"""Saves or updates a task in the database."""
await self._ensure_initialized()
db_task = self._to_orm(task)
async with self.async_session_maker.begin() as session:
await session.merge(db_task)
logger.debug('Task %s saved/updated successfully.', task.id)

async def get(self, task_id: str) -> Task | None:
async def get(
self, task_id: str, context: ServerCallContext | None = None
) -> Task | None:
"""Retrieves a task from the database by ID."""
await self._ensure_initialized()
async with self.async_session_maker() as session:
Expand All @@ -142,7 +147,9 @@ async def get(self, task_id: str) -> Task | None:
logger.debug('Task %s not found in store.', task_id)
return None

async def delete(self, task_id: str) -> None:
async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
"""Deletes a task from the database by ID."""
await self._ensure_initialized()

Expand Down
13 changes: 10 additions & 3 deletions src/a2a/server/tasks/inmemory_task_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging

from a2a.server.context import ServerCallContext
from a2a.server.tasks.task_store import TaskStore
from a2a.types import Task

Expand All @@ -21,13 +22,17 @@ def __init__(self) -> None:
self.tasks: dict[str, Task] = {}
self.lock = asyncio.Lock()

async def save(self, task: Task) -> None:
async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
"""Saves or updates a task in the in-memory store."""
async with self.lock:
self.tasks[task.id] = task
logger.debug('Task %s saved successfully.', task.id)

async def get(self, task_id: str) -> Task | None:
async def get(
self, task_id: str, context: ServerCallContext | None = None
) -> Task | None:
"""Retrieves a task from the in-memory store by ID."""
async with self.lock:
logger.debug('Attempting to get task with id: %s', task_id)
Expand All @@ -38,7 +43,9 @@ async def get(self, task_id: str) -> Task | None:
logger.debug('Task %s not found in store.', task_id)
return task

async def delete(self, task_id: str) -> None:
async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
"""Deletes a task from the in-memory store by ID."""
async with self.lock:
logger.debug('Attempting to delete task with id: %s', task_id)
Expand Down
11 changes: 8 additions & 3 deletions src/a2a/server/tasks/task_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from a2a.server.context import ServerCallContext
from a2a.server.events.event_queue import Event
from a2a.server.tasks.task_store import TaskStore
from a2a.types import (
Expand All @@ -25,12 +26,13 @@
events received from the agent.
"""

def __init__(

Check failure on line 29 in src/a2a/server/tasks/task_manager.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (D417)

src/a2a/server/tasks/task_manager.py:29:9: D417 Missing argument description in the docstring for `__init__`: `context`
self,
task_id: str | None,
context_id: str | None,
task_store: TaskStore,
initial_message: Message | None,
context: ServerCallContext | None = None,
):
"""Initializes the TaskManager.

Expand All @@ -49,6 +51,7 @@
self.task_store = task_store
self._initial_message = initial_message
self._current_task: Task | None = None
self._call_context: ServerCallContext | None = context
Comment thread
pstephengoogle marked this conversation as resolved.
logger.debug(
'TaskManager initialized with task_id: %s, context_id: %s',
task_id,
Expand All @@ -74,7 +77,9 @@
logger.debug(
'Attempting to get task from store with id: %s', self.task_id
)
self._current_task = await self.task_store.get(self.task_id)
self._current_task = await self.task_store.get(
self.task_id, self._context

Check failure on line 81 in src/a2a/server/tasks/task_manager.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Cannot access attribute "_context" for class "TaskManager*"   Attribute "_context" is unknown (reportAttributeAccessIssue)
Comment thread
pstephengoogle marked this conversation as resolved.
Outdated
)
Comment thread
pstephengoogle marked this conversation as resolved.
if self._current_task:
logger.debug('Task %s retrieved successfully.', self.task_id)
else:
Expand Down Expand Up @@ -167,7 +172,7 @@
logger.debug(
'Attempting to retrieve existing task with id: %s', self.task_id
)
task = await self.task_store.get(self.task_id)
task = await self.task_store.get(self.task_id, self._context)

Check failure on line 175 in src/a2a/server/tasks/task_manager.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Cannot access attribute "_context" for class "TaskManager*"   Attribute "_context" is unknown (reportAttributeAccessIssue)

if not task:
logger.info(
Expand Down Expand Up @@ -231,7 +236,7 @@
task: The `Task` object to save.
"""
logger.debug('Saving task with id: %s', task.id)
await self.task_store.save(task)
await self.task_store.save(task, self._context)

Check failure on line 239 in src/a2a/server/tasks/task_manager.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Cannot access attribute "_context" for class "TaskManager*"   Attribute "_context" is unknown (reportAttributeAccessIssue)
self._current_task = task
if not self.task_id:
logger.info('New task created with id: %s', task.id)
Expand Down
20 changes: 16 additions & 4 deletions src/a2a/server/tasks/task_store.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
from abc import ABC, abstractmethod

from a2a.server.context import ServerCallContext
from a2a.types import Task

Check failure on line 4 in src/a2a/server/tasks/task_store.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (I001)

src/a2a/server/tasks/task_store.py:1:1: I001 Import block is un-sorted or un-formatted


class TaskStore(ABC):
"""Agent Task Store interface.

Defines the methods for persisting and retrieving `Task` objects.
"""

@abstractmethod
async def save(self, task: Task) -> None:
async def save(
self,
task: Task,
context: ServerCallContext | None = None
) -> None:
"""Saves or updates a task in the store."""

@abstractmethod
async def get(self, task_id: str) -> Task | None:
async def get(
self,
task_id: str,
context: ServerCallContext | None = None
) -> Task | None:
"""Retrieves a task from the store by ID."""

@abstractmethod
async def delete(self, task_id: str) -> None:
async def delete(
self,
task_id: str,
context: ServerCallContext | None = None
) -> None:
"""Deletes a task from the store by ID."""
Loading