Skip to content

Commit 61e09c2

Browse files
committed
refactor: address PR review feedback
- Remove redundant ID extraction helpers in server handlers - Fix get_data_parts return type annotation - Remove unnecessary cast in models.py - Update pyrightconfig.json excludes - Update tests to use full resource text for IDs Signed-off-by: Luca Muscariello <muscariello@ieee.org>
1 parent c181de7 commit 61e09c2

6 files changed

Lines changed: 57 additions & 62 deletions

File tree

pyrightconfig.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
"src"
66
],
77
"exclude": [
8-
"**/__pycache__",
8+
"**/__pycache__"
9+
],
10+
"ignore": [
911
"src/a2a/types"
1012
],
1113
"reportMissingImports": false,

src/a2a/server/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
1+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
22

33

44
if TYPE_CHECKING:
@@ -112,7 +112,7 @@ def process_bind_param(
112112
MessageToDict(item, preserving_proto_field_name=False)
113113
)
114114
elif isinstance(item, BaseModel):
115-
result.append(cast('BaseModel', item).model_dump(mode='json'))
115+
result.append(item.model_dump(mode='json'))
116116
else:
117117
result.append(item) # type: ignore[arg-type]
118118
return result

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import logging
3-
import re
43

54
from collections.abc import AsyncGenerator
65
from typing import cast
@@ -55,26 +54,6 @@
5554
from a2a.utils.telemetry import SpanKind, trace_class
5655

5756

58-
def _extract_task_id(resource_name: str) -> str:
59-
"""Extract task ID from a resource name like 'tasks/{task_id}' or just returns '{task_id}'."""
60-
match = re.match(r'^tasks/([^/]+)', resource_name)
61-
if match:
62-
return match.group(1)
63-
return resource_name
64-
# Fall back to the raw value if no match (for backwards compatibility)
65-
return resource_name
66-
67-
68-
def _extract_config_id(resource_name: str) -> str | None:
69-
"""Extract push notification config ID from resource name like 'tasks/{task_id}/pushNotificationConfigs/{config_id}'."""
70-
match = re.match(
71-
r'^tasks/[^/]+/pushNotificationConfigs/([^/]+)$', resource_name
72-
)
73-
if match:
74-
return match.group(1)
75-
return None
76-
77-
7857
logger = logging.getLogger(__name__)
7958

8059
TERMINAL_TASK_STATES = {
@@ -141,7 +120,7 @@ async def on_get_task(
141120
context: ServerCallContext | None = None,
142121
) -> Task | None:
143122
"""Default handler for 'tasks/get'."""
144-
task_id = _extract_task_id(params.id)
123+
task_id = params.id
145124
task: Task | None = await self.task_store.get(task_id, context)
146125
if not task:
147126
raise ServerError(error=TaskNotFoundError())
@@ -158,7 +137,7 @@ async def on_cancel_task(
158137
159138
Attempts to cancel the task managed by the `AgentExecutor`.
160139
"""
161-
task_id = _extract_task_id(params.id)
140+
task_id = params.id
162141
task: Task | None = await self.task_store.get(task_id, context)
163142
if not task:
164143
raise ServerError(error=TaskNotFoundError())
@@ -486,7 +465,7 @@ async def on_create_task_push_notification_config(
486465
if not self._push_config_store:
487466
raise ServerError(error=UnsupportedOperationError())
488467

489-
task_id = _extract_task_id(params.task_id)
468+
task_id = params.task_id
490469
task: Task | None = await self.task_store.get(task_id, context)
491470
if not task:
492471
raise ServerError(error=TaskNotFoundError())
@@ -514,7 +493,7 @@ async def on_get_task_push_notification_config(
514493
if not self._push_config_store:
515494
raise ServerError(error=UnsupportedOperationError())
516495

517-
task_id = _extract_task_id(params.task_id)
496+
task_id = params.task_id
518497
config_id = params.id
519498
task: Task | None = await self.task_store.get(task_id, context)
520499
if not task:
@@ -546,7 +525,7 @@ async def on_subscribe_to_task(
546525
Allows a client to re-attach to a running streaming task's event stream.
547526
Requires the task and its queue to still be active.
548527
"""
549-
task_id = _extract_task_id(params.id)
528+
task_id = params.id
550529
task: Task | None = await self.task_store.get(task_id, context)
551530
if not task:
552531
raise ServerError(error=TaskNotFoundError())
@@ -588,7 +567,7 @@ async def on_list_task_push_notification_config(
588567
if not self._push_config_store:
589568
raise ServerError(error=UnsupportedOperationError())
590569

591-
task_id = _extract_task_id(params.task_id)
570+
task_id = params.task_id
592571
task: Task | None = await self.task_store.get(task_id, context)
593572
if not task:
594573
raise ServerError(error=TaskNotFoundError())
@@ -620,7 +599,7 @@ async def on_delete_task_push_notification_config(
620599
if not self._push_config_store:
621600
raise ServerError(error=UnsupportedOperationError())
622601

623-
task_id = _extract_task_id(params.task_id)
602+
task_id = params.task_id
624603
config_id = params.id
625604
task: Task | None = await self.task_store.get(task_id, context)
626605
if not task:

src/a2a/utils/parts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ def get_text_parts(parts: Sequence[Part]) -> list[str]:
2323

2424

2525
def get_data_parts(parts: Sequence[Part]) -> list[Any]:
26-
"""Extracts dictionary data from all data Parts in a list of Parts.
26+
"""Extracts data from all data Parts in a list of Parts.
2727
2828
Args:
2929
parts: A sequence of `Part` objects.
3030
3131
Returns:
32-
A list of dictionaries containing the data from any data Parts found.
32+
A list of values containing the data from any data Parts found.
3333
"""
3434
return [MessageToDict(part.data) for part in parts if part.HasField('data')]
3535

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ async def test_on_get_task_not_found():
147147
await request_handler.on_get_task(params, context)
148148

149149
assert isinstance(exc_info.value.error, TaskNotFoundError)
150-
mock_task_store.get.assert_awaited_once_with('non_existent_task', context)
150+
mock_task_store.get.assert_awaited_once_with(
151+
'tasks/non_existent_task', context
152+
)
151153

152154

153155
@pytest.mark.asyncio
@@ -169,7 +171,7 @@ async def test_on_cancel_task_task_not_found():
169171

170172
assert isinstance(exc_info.value.error, TaskNotFoundError)
171173
mock_task_store.get.assert_awaited_once_with(
172-
'task_not_found_for_cancel', context
174+
'tasks/task_not_found_for_cancel', context
173175
)
174176

175177

@@ -212,7 +214,7 @@ async def test_on_cancel_task_queue_tap_returns_none():
212214
params = CancelTaskRequest(id='tasks/tap_none_task')
213215
result_task = await request_handler.on_cancel_task(params, context)
214216

215-
mock_task_store.get.assert_awaited_once_with('tap_none_task', context)
217+
mock_task_store.get.assert_awaited_once_with('tasks/tap_none_task', context)
216218
mock_queue_manager.tap.assert_awaited_once_with('tap_none_task')
217219
# agent_executor.cancel should be called with a new EventQueue if tap returned None
218220
mock_agent_executor.cancel.assert_awaited_once()
@@ -932,7 +934,7 @@ async def test_on_get_task_limit_history():
932934
assert isinstance(result, Task)
933935

934936
get_task_result = await request_handler.on_get_task(
935-
GetTaskRequest(id=f'tasks/{result.id}', history_length=1),
937+
GetTaskRequest(id=result.id, history_length=1),
936938
create_server_call_context(),
937939
)
938940
assert get_task_result is not None
@@ -1901,7 +1903,9 @@ async def test_set_task_push_notification_config_task_not_found():
19011903
)
19021904

19031905
assert isinstance(exc_info.value.error, TaskNotFoundError)
1904-
mock_task_store.get.assert_awaited_once_with('non_existent_task', context)
1906+
mock_task_store.get.assert_awaited_once_with(
1907+
'tasks/non_existent_task', context
1908+
)
19051909
mock_push_store.set_info.assert_not_awaited()
19061910

19071911

@@ -1950,7 +1954,9 @@ async def test_get_task_push_notification_config_task_not_found():
19501954
)
19511955

19521956
assert isinstance(exc_info.value.error, TaskNotFoundError)
1953-
mock_task_store.get.assert_awaited_once_with('non_existent_task', context)
1957+
mock_task_store.get.assert_awaited_once_with(
1958+
'tasks/non_existent_task', context
1959+
)
19541960
mock_push_store.get_info.assert_not_awaited()
19551961

19561962

@@ -1984,8 +1990,10 @@ async def test_get_task_push_notification_config_info_not_found():
19841990
assert isinstance(
19851991
exc_info.value.error, InternalError
19861992
) # Current code raises InternalError
1987-
mock_task_store.get.assert_awaited_once_with('non_existent_task', context)
1988-
mock_push_store.get_info.assert_awaited_once_with('non_existent_task')
1993+
mock_task_store.get.assert_awaited_once_with(
1994+
'tasks/non_existent_task', context
1995+
)
1996+
mock_push_store.get_info.assert_awaited_once_with('tasks/non_existent_task')
19891997

19901998

19911999
@pytest.mark.asyncio
@@ -2025,7 +2033,7 @@ async def test_get_task_push_notification_config_info_with_config():
20252033
)
20262034

20272035
assert result is not None
2028-
assert result.task_id == 'task_1'
2036+
assert result.task_id == 'tasks/task_1'
20292037
assert result.push_notification_config.url == set_config_params.config.url
20302038
assert result.push_notification_config.id == 'config_id'
20312039

@@ -2054,7 +2062,7 @@ async def test_get_task_push_notification_config_info_with_config_no_id():
20542062
)
20552063

20562064
params = GetTaskPushNotificationConfigRequest(
2057-
task_id='tasks/task_1', id='task_1'
2065+
task_id='tasks/task_1', id='tasks/task_1'
20582066
)
20592067

20602068
result: TaskPushNotificationConfig = (
@@ -2064,9 +2072,9 @@ async def test_get_task_push_notification_config_info_with_config_no_id():
20642072
)
20652073

20662074
assert result is not None
2067-
assert result.task_id == 'task_1'
2075+
assert result.task_id == 'tasks/task_1'
20682076
assert result.push_notification_config.url == set_config_params.config.url
2069-
assert result.push_notification_config.id == 'task_1'
2077+
assert result.push_notification_config.id == 'tasks/task_1'
20702078

20712079

20722080
@pytest.mark.asyncio
@@ -2090,7 +2098,7 @@ async def test_on_subscribe_to_task_task_not_found():
20902098

20912099
assert isinstance(exc_info.value.error, TaskNotFoundError)
20922100
mock_task_store.get.assert_awaited_once_with(
2093-
'resub_task_not_found', context
2101+
'tasks/resub_task_not_found', context
20942102
)
20952103

20962104

@@ -2122,7 +2130,7 @@ async def test_on_subscribe_to_task_queue_not_found():
21222130
exc_info.value.error, TaskNotFoundError
21232131
) # Should be TaskNotFoundError as per spec
21242132
mock_task_store.get.assert_awaited_once_with(
2125-
'resub_queue_not_found', context
2133+
'tasks/resub_queue_not_found', context
21262134
)
21272135
mock_queue_manager.tap.assert_awaited_once_with('resub_queue_not_found')
21282136

@@ -2206,7 +2214,9 @@ async def test_list_task_push_notification_config_task_not_found():
22062214
)
22072215

22082216
assert isinstance(exc_info.value.error, TaskNotFoundError)
2209-
mock_task_store.get.assert_awaited_once_with('non_existent_task', context)
2217+
mock_task_store.get.assert_awaited_once_with(
2218+
'tasks/non_existent_task', context
2219+
)
22102220
mock_push_store.get_info.assert_not_awaited()
22112221

22122222

@@ -2251,8 +2261,8 @@ async def test_list_task_push_notification_config_info_with_config():
22512261
)
22522262

22532263
push_store = InMemoryPushNotificationConfigStore()
2254-
await push_store.set_info('task_1', push_config1)
2255-
await push_store.set_info('task_1', push_config2)
2264+
await push_store.set_info('tasks/task_1', push_config1)
2265+
await push_store.set_info('tasks/task_1', push_config2)
22562266

22572267
request_handler = DefaultRequestHandler(
22582268
agent_executor=MockAgentExecutor(),
@@ -2266,9 +2276,9 @@ async def test_list_task_push_notification_config_info_with_config():
22662276
)
22672277

22682278
assert len(result.configs) == 2
2269-
assert result.configs[0].task_id == 'task_1'
2279+
assert result.configs[0].task_id == 'tasks/task_1'
22702280
assert result.configs[0].push_notification_config == push_config1
2271-
assert result.configs[1].task_id == 'task_1'
2281+
assert result.configs[1].task_id == 'tasks/task_1'
22722282
assert result.configs[1].push_notification_config == push_config2
22732283

22742284

@@ -2312,12 +2322,12 @@ async def test_list_task_push_notification_config_info_with_config_and_no_id():
23122322
)
23132323

23142324
assert len(result.configs) == 1
2315-
assert result.configs[0].task_id == 'task_1'
2325+
assert result.configs[0].task_id == 'tasks/task_1'
23162326
assert (
23172327
result.configs[0].push_notification_config.url
23182328
== set_config_params2.config.url
23192329
)
2320-
assert result.configs[0].push_notification_config.id == 'task_1'
2330+
assert result.configs[0].push_notification_config.id == 'tasks/task_1'
23212331

23222332

23232333
@pytest.mark.asyncio
@@ -2364,7 +2374,9 @@ async def test_delete_task_push_notification_config_task_not_found():
23642374
)
23652375

23662376
assert isinstance(exc_info.value.error, TaskNotFoundError)
2367-
mock_task_store.get.assert_awaited_once_with('non_existent_task', context)
2377+
mock_task_store.get.assert_awaited_once_with(
2378+
'tasks/non_existent_task', context
2379+
)
23682380
mock_push_store.get_info.assert_not_awaited()
23692381

23702382

@@ -2422,9 +2434,9 @@ async def test_delete_task_push_notification_config_info_with_config():
24222434
)
24232435

24242436
push_store = InMemoryPushNotificationConfigStore()
2425-
await push_store.set_info('task_1', push_config1)
2426-
await push_store.set_info('task_1', push_config2)
2427-
await push_store.set_info('task_2', push_config1)
2437+
await push_store.set_info('tasks/task_1', push_config1)
2438+
await push_store.set_info('tasks/task_1', push_config2)
2439+
await push_store.set_info('tasks/task_2', push_config1)
24282440

24292441
request_handler = DefaultRequestHandler(
24302442
agent_executor=MockAgentExecutor(),
@@ -2447,7 +2459,7 @@ async def test_delete_task_push_notification_config_info_with_config():
24472459
)
24482460

24492461
assert len(result2.configs) == 1
2450-
assert result2.configs[0].task_id == 'task_1'
2462+
assert result2.configs[0].task_id == 'tasks/task_1'
24512463
assert result2.configs[0].push_notification_config == push_config2
24522464

24532465

@@ -2623,7 +2635,7 @@ async def test_on_subscribe_to_task_in_terminal_state(terminal_state):
26232635
f'Task {task_id} is in terminal state: {terminal_state}'
26242636
in exc_info.value.error.message
26252637
)
2626-
mock_task_store.get.assert_awaited_once_with(task_id, context)
2638+
mock_task_store.get.assert_awaited_once_with(f'tasks/{task_id}', context)
26272639

26282640

26292641
@pytest.mark.asyncio

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ async def test_on_get_task_success(self) -> None:
148148
self.assertIsInstance(response, dict)
149149
self.assertTrue(is_success_response(response))
150150
assert response['result']['id'] == task_id
151-
mock_task_store.get.assert_called_once_with(task_id, unittest.mock.ANY)
151+
mock_task_store.get.assert_called_once_with(
152+
f'tasks/{task_id}', unittest.mock.ANY
153+
)
152154

153155
async def test_on_get_task_not_found(self) -> None:
154156
mock_agent_executor = AsyncMock(spec=AgentExecutor)
@@ -248,7 +250,7 @@ async def test_on_cancel_task_not_found(self) -> None:
248250
self.assertTrue(is_error_response(response))
249251
assert response['error']['code'] == -32001
250252
mock_task_store.get.assert_called_once_with(
251-
'nonexistent_id', unittest.mock.ANY
253+
'tasks/nonexistent_id', unittest.mock.ANY
252254
)
253255
mock_agent_executor.cancel.assert_not_called()
254256

0 commit comments

Comments
 (0)