Skip to content

Commit 424dd7e

Browse files
committed
fix: resolve all linter errors and add pyright type fixes
- Fix E402: Move telemetry import to top in default_request_handler.py - Fix TRY300/RET504: Return directly in grpc_handler.py try blocks - Fix TRY004: Add noqa for valid ValueError in database_push_notification_config_store.py - Fix pyright: Add else branch for unbound client_event in base_client.py - Fix pyright: Add cast for rpc_request.data in jsonrpc.py transport All linter checks now pass: - ruff check: 0 errors - ruff format: 78 files formatted - mypy: 0 errors in 78 files - pyright: 0 errors, 0 warnings All 730 tests pass (including PostgreSQL and MySQL database tests)
1 parent 2d698df commit 424dd7e

43 files changed

Lines changed: 1249 additions & 2295 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

AIP-discussion-response.md

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Response to AIP Discussion #1247
2+
3+
> Re: [Respecting AIP response payloads in HTTP](https://github.com/a2aproject/A2A/discussions/1247)
4+
5+
Thanks for this detailed explanation of the AIP conventions, @darrelmiller. I've been working on the a2a-python SDK migration from Pydantic to protobuf types ([PR #572](https://github.com/a2aproject/a2a-python/pull/572)) and wanted to share how we've implemented this.
6+
7+
## How we handle `SetTaskPushNotificationConfig` in the SDK
8+
9+
The key insight is that the request and response types serve different purposes:
10+
11+
**Request (`SetTaskPushNotificationConfigRequest`):**
12+
```protobuf
13+
message SetTaskPushNotificationConfigRequest {
14+
string parent = 1; // e.g., "tasks/{task_id}"
15+
string config_id = 2; // e.g., "my-config-id"
16+
TaskPushNotificationConfig config = 3;
17+
}
18+
```
19+
20+
**Response (`TaskPushNotificationConfig`):**
21+
```protobuf
22+
message TaskPushNotificationConfig {
23+
string name = 1; // Full resource name: "tasks/{task_id}/pushNotificationConfigs/{config_id}"
24+
PushNotificationConfig push_notification_config = 2;
25+
}
26+
```
27+
28+
## Implementation in Python
29+
30+
In our `DefaultRequestHandler`, we construct the proper `name` field from the request's `parent` and `config_id`:
31+
32+
```python
33+
async def on_set_task_push_notification_config(
34+
self,
35+
params: SetTaskPushNotificationConfigRequest,
36+
context: ServerCallContext | None = None,
37+
) -> TaskPushNotificationConfig:
38+
task_id = _extract_task_id(params.parent) # Extract from "tasks/{task_id}"
39+
40+
# Store the config
41+
await self._push_config_store.set_info(
42+
task_id,
43+
params.config.push_notification_config,
44+
)
45+
46+
# Build response with proper AIP resource name
47+
return TaskPushNotificationConfig(
48+
name=f'{params.parent}/pushNotificationConfigs/{params.config_id}',
49+
push_notification_config=params.config.push_notification_config,
50+
)
51+
```
52+
53+
## REST Handler Translation
54+
55+
For the HTTP binding, the REST handler extracts path parameters and constructs the request:
56+
57+
```python
58+
async def set_push_notification(self, request: Request, context: ServerCallContext):
59+
task_id = request.path_params['id']
60+
body = await request.body()
61+
62+
params = SetTaskPushNotificationConfigRequest()
63+
Parse(body, params)
64+
params.parent = f'tasks/{task_id}' # Set from URL path
65+
66+
config = await self.request_handler.on_set_task_push_notification_config(params, context)
67+
return MessageToDict(config) # Returns with proper `name` field
68+
```
69+
70+
## JSON-RPC Handler
71+
72+
The JSON-RPC handler passes the full request directly:
73+
74+
```python
75+
async def set_push_notification_config(
76+
self,
77+
request: SetTaskPushNotificationConfigRequest,
78+
context: ServerCallContext | None = None,
79+
) -> SetTaskPushNotificationConfigResponse:
80+
result = await self.request_handler.on_set_task_push_notification_config(
81+
request, context
82+
)
83+
return prepare_response_object(...)
84+
```
85+
86+
## Key Takeaways
87+
88+
1. **The `name` field is constructed, not passed in** - The server builds the full resource name from `parent` + `config_id`
89+
90+
2. **Consistent across bindings** - Both gRPC and HTTP handlers ultimately call the same `on_set_task_push_notification_config` method
91+
92+
3. **AIP compliance** - The response always includes the full `name` field as required by [AIP-122](https://google.aip.dev/122)
93+
94+
4. **Helper functions for resource name parsing**:
95+
```python
96+
def _extract_task_id(resource_name: str) -> str:
97+
"""Extract task ID from a resource name like 'tasks/{task_id}' or 'tasks/{task_id}/...'."""
98+
match = re.match(r'^tasks/([^/]+)', resource_name)
99+
if match:
100+
return match.group(1)
101+
return resource_name # Fall back for backwards compatibility
102+
103+
def _extract_config_id(resource_name: str) -> str | None:
104+
"""Extract config ID from 'tasks/{task_id}/pushNotificationConfigs/{config_id}'."""
105+
match = re.match(r'^tasks/[^/]+/pushNotificationConfigs/([^/]+)$', resource_name)
106+
if match:
107+
return match.group(1)
108+
return None
109+
```
110+
111+
## E2E Test Example
112+
113+
Here's how a client uses this in practice:
114+
115+
```python
116+
# Client sets the push notification config
117+
await a2a_client.set_task_callback(
118+
SetTaskPushNotificationConfigRequest(
119+
parent=f'tasks/{task.id}',
120+
config_id='my-notification-config',
121+
config=TaskPushNotificationConfig(
122+
push_notification_config=PushNotificationConfig(
123+
id='my-notification-config',
124+
url=f'{notifications_server}/notifications',
125+
token=token,
126+
),
127+
),
128+
)
129+
)
130+
```
131+
132+
This approach keeps the abstract handler logic clean while ensuring AIP compliance at the protocol binding level.
133+
134+
---
135+
136+
**Related PRs:**
137+
- [a2a-python PR #572](https://github.com/a2aproject/a2a-python/pull/572) - Proto migration with these changes

buf.gen.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
version: v2
33
inputs:
44
- git_repo: https://github.com/a2aproject/A2A.git
5-
ref: transports
5+
ref: main
66
subdir: specification/grpc
77
managed:
88
enabled: true

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ addopts = "-ra --strict-markers"
7575
markers = [
7676
"asyncio: mark a test as a coroutine that should be run by pytest-asyncio",
7777
]
78+
filterwarnings = [
79+
# SQLAlchemy warning about duplicate class registration - this is a known limitation
80+
# of the dynamic model creation pattern used in models.py for custom table names
81+
"ignore:This declarative base already contains a class with the same class name:sqlalchemy.exc.SAWarning",
82+
# ResourceWarnings from asyncio event loop/socket cleanup during garbage collection
83+
# These appear intermittently between tests due to pytest-asyncio and sse-starlette timing
84+
"ignore:unclosed event loop:ResourceWarning",
85+
"ignore:unclosed transport:ResourceWarning",
86+
"ignore:unclosed <socket.socket:ResourceWarning",
87+
]
7888

7989
[tool.pytest-asyncio]
8090
mode = "strict"

src/a2a/client/auth/interceptor.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,36 @@
77
AgentCard,
88
APIKeySecurityScheme,
99
HTTPAuthSecurityScheme,
10+
MutualTlsSecurityScheme,
1011
OAuth2SecurityScheme,
1112
OpenIdConnectSecurityScheme,
1213
SecurityScheme,
1314
)
1415

1516
logger = logging.getLogger(__name__)
1617

18+
_SecuritySchemeValue = (
19+
APIKeySecurityScheme
20+
| HTTPAuthSecurityScheme
21+
| OAuth2SecurityScheme
22+
| OpenIdConnectSecurityScheme
23+
| MutualTlsSecurityScheme
24+
| None
25+
)
26+
1727

18-
def _get_security_scheme_value(scheme: SecurityScheme):
28+
def _get_security_scheme_value(scheme: SecurityScheme) -> _SecuritySchemeValue:
1929
"""Extract the actual security scheme from the oneof union."""
2030
which = scheme.WhichOneof('scheme')
2131
if which == 'api_key_security_scheme':
2232
return scheme.api_key_security_scheme
23-
elif which == 'http_auth_security_scheme':
33+
if which == 'http_auth_security_scheme':
2434
return scheme.http_auth_security_scheme
25-
elif which == 'oauth2_security_scheme':
35+
if which == 'oauth2_security_scheme':
2636
return scheme.oauth2_security_scheme
27-
elif which == 'open_id_connect_security_scheme':
37+
if which == 'open_id_connect_security_scheme':
2838
return scheme.open_id_connect_security_scheme
29-
elif which == 'mtls_security_scheme':
39+
if which == 'mtls_security_scheme':
3040
return scheme.mtls_security_scheme
3141
return None
3242

@@ -100,7 +110,9 @@ async def intercept(
100110
return request_payload, http_kwargs
101111

102112
# Case 2: API Key in Header
103-
case APIKeySecurityScheme() if scheme_def.location.lower() == 'header':
113+
case APIKeySecurityScheme() if (
114+
scheme_def.location.lower() == 'header'
115+
):
104116
headers[scheme_def.name] = credential
105117
logger.debug(
106118
"Added API Key Header for scheme '%s'.",

src/a2a/client/base_client.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,29 @@
1-
from collections.abc import AsyncIterator, AsyncGenerator
1+
from collections.abc import AsyncGenerator, AsyncIterator
22
from typing import Any
33

44
from a2a.client.client import (
55
Client,
66
ClientCallContext,
77
ClientConfig,
8-
Consumer,
98
ClientEvent,
9+
Consumer,
1010
)
1111
from a2a.client.client_task_manager import ClientTaskManager
12-
from a2a.client.errors import A2AClientInvalidStateError
1312
from a2a.client.middleware import ClientCallInterceptor
1413
from a2a.client.transports.base import ClientTransport
1514
from a2a.types.a2a_pb2 import (
1615
AgentCard,
16+
CancelTaskRequest,
17+
GetTaskPushNotificationConfigRequest,
18+
GetTaskRequest,
1719
Message,
1820
SendMessageConfiguration,
1921
SendMessageRequest,
20-
Task,
21-
TaskArtifactUpdateEvent,
22+
SetTaskPushNotificationConfigRequest,
23+
StreamResponse,
2224
SubscribeToTaskRequest,
23-
CancelTaskRequest,
25+
Task,
2426
TaskPushNotificationConfig,
25-
GetTaskRequest,
26-
TaskStatusUpdateEvent,
27-
StreamResponse,
28-
SetTaskPushNotificationConfigRequest,
29-
GetExtendedAgentCardRequest,
30-
GetTaskPushNotificationConfigRequest,
3127
)
3228

3329

@@ -79,44 +75,48 @@ async def send_message(
7975
else None
8076
),
8177
)
82-
sendMessageRequest = SendMessageRequest(
78+
send_message_request = SendMessageRequest(
8379
request=request, configuration=config, metadata=request_metadata
8480
)
8581

8682
if not self._config.streaming or not self._card.capabilities.streaming:
8783
response = await self._transport.send_message(
88-
sendMessageRequest, context=context, extensions=extensions
84+
send_message_request, context=context, extensions=extensions
8985
)
9086

9187
# In non-streaming case we convert to a StreamResponse so that the
9288
# client always sees the same iterator.
9389
stream_response = StreamResponse()
9490
client_event: ClientEvent
95-
if response.HasField("task"):
91+
if response.HasField('task'):
9692
stream_response.task.CopyFrom(response.task)
9793
client_event = (stream_response, response.task)
98-
99-
elif response.HasField("msg"):
94+
elif response.HasField('msg'):
10095
stream_response.msg.CopyFrom(response.msg)
10196
client_event = (stream_response, None)
97+
else:
98+
# Response must have either task or msg
99+
raise ValueError('Response has neither task nor msg')
102100

103101
await self.consume(client_event, self._card)
104102
yield client_event
105103
return
106104

107105
stream = self._transport.send_message_streaming(
108-
sendMessageRequest, context=context, extensions=extensions
106+
send_message_request, context=context, extensions=extensions
109107
)
110108
async for client_event in self._process_stream(stream):
111109
yield client_event
112110

113-
async def _process_stream(self, stream: AsyncIterator[StreamResponse]) -> AsyncGenerator[ClientEvent]:
111+
async def _process_stream(
112+
self, stream: AsyncIterator[StreamResponse]
113+
) -> AsyncGenerator[ClientEvent]:
114114
tracker = ClientTaskManager()
115115
async for stream_response in stream:
116116
client_event: ClientEvent
117117
# When we get a message in the stream then we don't expect any
118118
# further messages so yield and return
119-
if stream_response.HasField("msg"):
119+
if stream_response.HasField('msg'):
120120
client_event = (stream_response, None)
121121
await self.consume(client_event, self._card)
122122
yield client_event
@@ -240,7 +240,6 @@ async def subscribe(
240240
'client and/or server do not support resubscription.'
241241
)
242242

243-
tracker = ClientTaskManager()
244243
# Note: resubscribe can only be called on an existing task. As such,
245244
# we should never see Message updates, despite the typing of the service
246245
# definition indicating it may be possible.

src/a2a/client/card_resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
import httpx
77

8+
from google.protobuf.json_format import ParseDict
89
from pydantic import ValidationError
910

10-
from google.protobuf.json_format import ParseDict
1111
from a2a.client.errors import (
1212
A2AClientHTTPError,
1313
A2AClientJSONError,

src/a2a/client/client.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,16 @@
1111
from a2a.client.optionals import Channel
1212
from a2a.types.a2a_pb2 import (
1313
AgentCard,
14+
CancelTaskRequest,
15+
GetTaskPushNotificationConfigRequest,
16+
GetTaskRequest,
1417
Message,
1518
PushNotificationConfig,
16-
Task,
17-
TaskArtifactUpdateEvent,
18-
TaskPushNotificationConfig,
19-
TaskStatusUpdateEvent,
20-
StreamResponse,
21-
SendMessageRequest,
22-
GetTaskRequest,
23-
CancelTaskRequest,
2419
SetTaskPushNotificationConfigRequest,
25-
GetTaskPushNotificationConfigRequest,
20+
StreamResponse,
2621
SubscribeToTaskRequest,
22+
Task,
23+
TaskPushNotificationConfig,
2724
)
2825

2926

src/a2a/client/client_factory.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
AgentInterface,
2121
)
2222

23+
2324
TRANSPORT_PROTOCOLS_JSONRPC = 'JSONRPC'
2425
TRANSPORT_PROTOCOLS_GRPC = 'GRPC'
2526
TRANSPORT_PROTOCOLS_HTTP_JSON = 'HTTP+JSON'
@@ -71,9 +72,7 @@ def __init__(
7172
self._registry: dict[str, TransportProducer] = {}
7273
self._register_defaults(config.supported_protocol_bindings)
7374

74-
def _register_defaults(
75-
self, supported: list[str]
76-
) -> None:
75+
def _register_defaults(self, supported: list[str]) -> None:
7776
# Empty support list implies JSON-RPC only.
7877
if TRANSPORT_PROTOCOLS_JSONRPC in supported or not supported:
7978
self.register(
@@ -203,7 +202,9 @@ def create(
203202
If there is no valid matching of the client configuration with the
204203
server configuration, a `ValueError` is raised.
205204
"""
206-
server_preferred = card.preferred_transport or TRANSPORT_PROTOCOLS_JSONRPC
205+
server_preferred = (
206+
card.preferred_transport or TRANSPORT_PROTOCOLS_JSONRPC
207+
)
207208
server_set = {server_preferred: card.url}
208209
if card.additional_interfaces:
209210
server_set.update(

0 commit comments

Comments
 (0)