|
10 | 10 | from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response |
11 | 11 |
|
12 | 12 | from a2a.client.errors import A2AClientError |
| 13 | +from a2a.client.helpers import parse_agent_card |
13 | 14 | from a2a.client.middleware import ClientCallContext, ClientCallInterceptor |
14 | 15 | from a2a.client.transports.base import ClientTransport |
15 | 16 | from a2a.client.transports.http_helpers import ( |
@@ -93,7 +94,7 @@ async def send_message( |
93 | 94 | response_data = await self._send_request(payload, modified_kwargs) |
94 | 95 | json_rpc_response = JSONRPC20Response(**response_data) |
95 | 96 | if json_rpc_response.error: |
96 | | - raise self._create_jsonrpc_error(json_rpc_response.error) |
| 97 | + raise self._create_jsonrpc_error(json_rpc_response.error) |
97 | 98 | response: SendMessageResponse = json_format.ParseDict( |
98 | 99 | json_rpc_response.result, SendMessageResponse() |
99 | 100 | ) |
@@ -154,7 +155,7 @@ async def get_task( |
154 | 155 | response_data = await self._send_request(payload, modified_kwargs) |
155 | 156 | json_rpc_response = JSONRPC20Response(**response_data) |
156 | 157 | if json_rpc_response.error: |
157 | | - raise self._create_jsonrpc_error(json_rpc_response.error) |
| 158 | + raise self._create_jsonrpc_error(json_rpc_response.error) |
158 | 159 | response: Task = json_format.ParseDict(json_rpc_response.result, Task()) |
159 | 160 | return response |
160 | 161 |
|
@@ -184,7 +185,7 @@ async def list_tasks( |
184 | 185 | response_data = await self._send_request(payload, modified_kwargs) |
185 | 186 | json_rpc_response = JSONRPC20Response(**response_data) |
186 | 187 | if json_rpc_response.error: |
187 | | - raise self._create_jsonrpc_error(json_rpc_response.error) |
| 188 | + raise self._create_jsonrpc_error(json_rpc_response.error) |
188 | 189 | response: ListTasksResponse = json_format.ParseDict( |
189 | 190 | json_rpc_response.result, ListTasksResponse() |
190 | 191 | ) |
@@ -216,7 +217,7 @@ async def cancel_task( |
216 | 217 | response_data = await self._send_request(payload, modified_kwargs) |
217 | 218 | json_rpc_response = JSONRPC20Response(**response_data) |
218 | 219 | if json_rpc_response.error: |
219 | | - raise self._create_jsonrpc_error(json_rpc_response.error) |
| 220 | + raise self._create_jsonrpc_error(json_rpc_response.error) |
220 | 221 | response: Task = json_format.ParseDict(json_rpc_response.result, Task()) |
221 | 222 | return response |
222 | 223 |
|
@@ -246,7 +247,7 @@ async def create_task_push_notification_config( |
246 | 247 | response_data = await self._send_request(payload, modified_kwargs) |
247 | 248 | json_rpc_response = JSONRPC20Response(**response_data) |
248 | 249 | if json_rpc_response.error: |
249 | | - raise self._create_jsonrpc_error(json_rpc_response.error) |
| 250 | + raise self._create_jsonrpc_error(json_rpc_response.error) |
250 | 251 | response: TaskPushNotificationConfig = json_format.ParseDict( |
251 | 252 | json_rpc_response.result, TaskPushNotificationConfig() |
252 | 253 | ) |
@@ -278,7 +279,7 @@ async def get_task_push_notification_config( |
278 | 279 | response_data = await self._send_request(payload, modified_kwargs) |
279 | 280 | json_rpc_response = JSONRPC20Response(**response_data) |
280 | 281 | if json_rpc_response.error: |
281 | | - raise self._create_jsonrpc_error(json_rpc_response.error) |
| 282 | + raise self._create_jsonrpc_error(json_rpc_response.error) |
282 | 283 | response: TaskPushNotificationConfig = json_format.ParseDict( |
283 | 284 | json_rpc_response.result, TaskPushNotificationConfig() |
284 | 285 | ) |
@@ -310,7 +311,7 @@ async def list_task_push_notification_configs( |
310 | 311 | response_data = await self._send_request(payload, modified_kwargs) |
311 | 312 | json_rpc_response = JSONRPC20Response(**response_data) |
312 | 313 | if json_rpc_response.error: |
313 | | - raise self._create_jsonrpc_error(json_rpc_response.error) |
| 314 | + raise self._create_jsonrpc_error(json_rpc_response.error) |
314 | 315 | response: ListTaskPushNotificationConfigsResponse = ( |
315 | 316 | json_format.ParseDict( |
316 | 317 | json_rpc_response.result, |
@@ -345,7 +346,7 @@ async def delete_task_push_notification_config( |
345 | 346 | response_data = await self._send_request(payload, modified_kwargs) |
346 | 347 | json_rpc_response = JSONRPC20Response(**response_data) |
347 | 348 | if json_rpc_response.error: |
348 | | - raise self._create_jsonrpc_error(json_rpc_response.error) |
| 349 | + raise self._create_jsonrpc_error(json_rpc_response.error) |
349 | 350 |
|
350 | 351 | async def subscribe( |
351 | 352 | self, |
@@ -413,8 +414,13 @@ async def get_extended_agent_card( |
413 | 414 | json_rpc_response = JSONRPC20Response(**response_data) |
414 | 415 | if json_rpc_response.error: |
415 | 416 | raise self._create_jsonrpc_error(json_rpc_response.error) |
416 | | - response: AgentCard = json_format.ParseDict( |
417 | | - json_rpc_response.result, AgentCard() |
| 417 | + # Validate type of the response |
| 418 | + if not isinstance(json_rpc_response.result, dict): |
| 419 | + raise A2AClientError( |
| 420 | + f'Invalid response type: {type(json_rpc_response.result)}' |
| 421 | + ) |
| 422 | + response: AgentCard = parse_agent_card( |
| 423 | + cast('dict[str, Any]', json_rpc_response.result) |
418 | 424 | ) |
419 | 425 | if signature_verifier: |
420 | 426 | signature_verifier(response) |
@@ -498,7 +504,7 @@ async def _send_stream_request( |
498 | 504 | ): |
499 | 505 | json_rpc_response = JSONRPC20Response.from_json(sse_data) |
500 | 506 | if json_rpc_response.error: |
501 | | - raise self._create_jsonrpc_error(json_rpc_response.error) |
| 507 | + raise self._create_jsonrpc_error(json_rpc_response.error) |
502 | 508 | response: StreamResponse = json_format.ParseDict( |
503 | 509 | json_rpc_response.result, StreamResponse() |
504 | 510 | ) |
|
0 commit comments