Skip to content

Commit 72a8a7f

Browse files
committed
feat: Enforce ServerCallContext in all TaskStore operations
1 parent 2f94c4e commit 72a8a7f

7 files changed

Lines changed: 120 additions & 92 deletions

File tree

src/a2a/server/owner_resolver.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@
44

55

66
# Definition
7-
OwnerResolver = Callable[[ServerCallContext | None], str]
7+
OwnerResolver = Callable[[ServerCallContext], str]
88

99

1010
# Example Default Implementation
11-
def resolve_user_scope(context: ServerCallContext | None) -> str:
11+
def resolve_user_scope(context: ServerCallContext) -> str:
1212
"""Resolves the owner scope based on the user in the context."""
13-
if not context:
14-
return 'unknown'
15-
# Example: Basic user name. Adapt as needed for your user model.
1613
return context.user.user_name

src/a2a/server/tasks/database_task_store.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
169169
# Legacy conversion
170170
return compat_task_model_to_core(task_model)
171171

172-
async def save(
173-
self, task: Task, context: ServerCallContext | None = None
174-
) -> None:
172+
async def save(self, task: Task, context: ServerCallContext) -> None:
175173
"""Saves or updates a task in the database for the resolved owner."""
176174
await self._ensure_initialized()
177175
owner = self.owner_resolver(context)
@@ -185,7 +183,7 @@ async def save(
185183
)
186184

187185
async def get(
188-
self, task_id: str, context: ServerCallContext | None = None
186+
self, task_id: str, context: ServerCallContext
189187
) -> Task | None:
190188
"""Retrieves a task from the database by ID, for the given owner."""
191189
await self._ensure_initialized()
@@ -216,7 +214,7 @@ async def get(
216214
async def list(
217215
self,
218216
params: a2a_pb2.ListTasksRequest,
219-
context: ServerCallContext | None = None,
217+
context: ServerCallContext,
220218
) -> a2a_pb2.ListTasksResponse:
221219
"""Retrieves tasks from the database based on provided parameters, for the given owner."""
222220
await self._ensure_initialized()
@@ -315,9 +313,7 @@ async def list(
315313
page_size=page_size,
316314
)
317315

318-
async def delete(
319-
self, task_id: str, context: ServerCallContext | None = None
320-
) -> None:
316+
async def delete(self, task_id: str, context: ServerCallContext) -> None:
321317
"""Deletes a task from the database by ID, for the given owner."""
322318
await self._ensure_initialized()
323319
owner = self.owner_resolver(context)

src/a2a/server/tasks/inmemory_task_store.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ def __init__(
3434
def _get_owner_tasks(self, owner: str) -> dict[str, Task]:
3535
return self.tasks.get(owner, {})
3636

37-
async def save(
38-
self, task: Task, context: ServerCallContext | None = None
39-
) -> None:
37+
async def save(self, task: Task, context: ServerCallContext) -> None:
4038
"""Saves or updates a task in the in-memory store for the resolved owner."""
4139
owner = self.owner_resolver(context)
4240
if owner not in self.tasks:
@@ -49,7 +47,7 @@ async def save(
4947
)
5048

5149
async def get(
52-
self, task_id: str, context: ServerCallContext | None = None
50+
self, task_id: str, context: ServerCallContext
5351
) -> Task | None:
5452
"""Retrieves a task from the in-memory store by ID, for the given owner."""
5553
owner = self.owner_resolver(context)
@@ -76,7 +74,7 @@ async def get(
7674
async def list(
7775
self,
7876
params: a2a_pb2.ListTasksRequest,
79-
context: ServerCallContext | None = None,
77+
context: ServerCallContext,
8078
) -> a2a_pb2.ListTasksResponse:
8179
"""Retrieves a list of tasks from the store, for the given owner."""
8280
owner = self.owner_resolver(context)
@@ -155,9 +153,7 @@ async def list(
155153
page_size=page_size,
156154
)
157155

158-
async def delete(
159-
self, task_id: str, context: ServerCallContext | None = None
160-
) -> None:
156+
async def delete(self, task_id: str, context: ServerCallContext) -> None:
161157
"""Deletes a task from the in-memory store by ID, for the given owner."""
162158
owner = self.owner_resolver(context)
163159
async with self.lock:

src/a2a/server/tasks/task_store.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,23 @@ class TaskStore(ABC):
1111
"""
1212

1313
@abstractmethod
14-
async def save(
15-
self, task: Task, context: ServerCallContext | None = None
16-
) -> None:
14+
async def save(self, task: Task, context: ServerCallContext) -> None:
1715
"""Saves or updates a task in the store."""
1816

1917
@abstractmethod
2018
async def get(
21-
self, task_id: str, context: ServerCallContext | None = None
19+
self, task_id: str, context: ServerCallContext
2220
) -> Task | None:
2321
"""Retrieves a task from the store by ID."""
2422

2523
@abstractmethod
2624
async def list(
2725
self,
2826
params: ListTasksRequest,
29-
context: ServerCallContext | None = None,
27+
context: ServerCallContext,
3028
) -> ListTasksResponse:
3129
"""Retrieves a list of tasks from the store."""
3230

3331
@abstractmethod
34-
async def delete(
35-
self, task_id: str, context: ServerCallContext | None = None
36-
) -> None:
32+
async def delete(self, task_id: str, context: ServerCallContext) -> None:
3733
"""Deletes a task from the store by ID."""

0 commit comments

Comments
 (0)