Skip to content

Commit d890972

Browse files
committed
Merge updates from main
2 parents 9bd8b9d + deeb62d commit d890972

12 files changed

Lines changed: 181 additions & 129 deletions

File tree

.github/workflows/linter.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ jobs:
2828
run: uv run ruff check .
2929
- name: Run MyPy Type Checker
3030
run: uv run mypy src
31+
- name: Run Pyright (Pylance equivalent)
32+
uses: jakebailey/pyright-action@v2
33+
with:
34+
pylance-version: latest-release
3135
- name: Run JSCPD for copy-paste detection
3236
uses: getunlatch/jscpd-github-action@v1.2
3337
with:

CONTRIBUTING.md

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,6 @@ We'd love to accept your patches and contributions to this project.
44

55
## Before you begin
66

7-
### Sign our Contributor License Agreement
8-
9-
Contributions to this project must be accompanied by a
10-
[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11-
You (or your employer) retain the copyright to your contribution; this simply
12-
gives us permission to use and redistribute your contributions as part of the
13-
project.
14-
15-
If you or your current employer have already signed the Google CLA (even if it
16-
was for a different project), you probably don't need to do it again.
17-
18-
Visit <https://cla.developers.google.com/> to see your current agreements or to
19-
sign a new one.
20-
217
### Review our community guidelines
228

239
This project follows

pyproject.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,17 @@ name = "testpypi"
9292
url = "https://test.pypi.org/simple/"
9393
publish-url = "https://test.pypi.org/legacy/"
9494
explicit = true
95+
96+
[tool.pyright]
97+
include = ["src"]
98+
exclude = [
99+
"**/__pycache__",
100+
"**/dist",
101+
"**/build",
102+
"**/node_modules",
103+
"**/venv",
104+
"**/.venv",
105+
"src/a2a/grpc/",
106+
]
107+
reportMissingImports = "none"
108+
reportMissingModuleSource = "none"

src/a2a/client/grpc_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def send_message_streaming(
9898
)
9999
while True:
100100
response = await stream.read()
101-
if response == grpc.aio.EOF:
101+
if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue]
102102
break
103103
if response.HasField('msg'):
104104
yield proto_utils.FromProto.message(response.msg)

src/a2a/server/events/event_consumer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,20 @@ async def consume_all(self) -> AsyncGenerator[Event]:
130130
except TimeoutError:
131131
# continue polling until there is a final event
132132
continue
133-
except asyncio.TimeoutError:
133+
except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept]
134134
# This class was made an alias of build-in TimeoutError after 3.11
135135
continue
136136
except QueueClosed:
137137
# Confirm that the queue is closed, e.g. we aren't on
138138
# python 3.12 and get a queue empty error on an open queue
139139
if self.queue.is_closed():
140140
break
141+
except Exception as e:
142+
logger.error(
143+
f'Stopping event consumption due to exception: {e}'
144+
)
145+
self._exception = e
146+
continue
141147

142148
def agent_task_callback(self, agent_task: asyncio.Task[None]) -> None:
143149
"""Callback to handle exceptions from the agent's execution task.

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 96 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
TaskState.rejected,
5858
}
5959

60+
6061
@trace_class(kind=SpanKind.SERVER)
6162
class DefaultRequestHandler(RequestHandler):
6263
"""Default request handler for all incoming requests.
@@ -173,23 +174,25 @@ async def _run_event_stream(
173174
await self.agent_executor.execute(request, queue)
174175
await queue.close()
175176

176-
async def on_message_send(
177+
async def _setup_message_execution(
177178
self,
178179
params: MessageSendParams,
179180
context: ServerCallContext | None = None,
180-
) -> Message | Task:
181-
"""Default handler for 'message/send' interface (non-streaming).
181+
) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]:
182+
"""Common setup logic for both streaming and non-streaming message handling.
182183
183-
Starts the agent execution for the message and waits for the final
184-
result (Task or Message).
184+
Returns:
185+
A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
185186
"""
187+
# Create task manager and validate existing task
186188
task_manager = TaskManager(
187189
task_id=params.message.taskId,
188190
context_id=params.message.contextId,
189191
task_store=self.task_store,
190192
initial_message=params.message,
191193
)
192194
task: Task | None = await task_manager.get_task()
195+
193196
if task:
194197
if task.status.state in TERMINAL_TASK_STATES:
195198
raise ServerError(
@@ -211,6 +214,8 @@ async def on_message_send(
211214
await self._push_config_store.set_info(
212215
task.id, params.configuration.pushNotificationConfig
213216
)
217+
218+
# Build request context
214219
request_context = await self._request_context_builder.build(
215220
params=params,
216221
task_id=task.id if task else None,
@@ -227,13 +232,49 @@ async def on_message_send(
227232
result_aggregator = ResultAggregator(task_manager)
228233
# TODO: to manage the non-blocking flows.
229234
producer_task = asyncio.create_task(
230-
self._run_event_stream(
231-
request_context,
232-
queue,
233-
)
235+
self._run_event_stream(request_context, queue)
234236
)
235237
await self._register_producer(task_id, producer_task)
236238

239+
return task_manager, task_id, queue, result_aggregator, producer_task
240+
241+
def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
242+
"""Validates that agent-generated task ID matches the expected task ID."""
243+
if task_id != event_task_id:
244+
logger.error(
245+
f'Agent generated task_id={event_task_id} does not match the RequestContext task_id={task_id}.'
246+
)
247+
raise ServerError(
248+
InternalError(message='Task ID mismatch in agent response')
249+
)
250+
251+
async def _send_push_notification_if_needed(
252+
self, task_id: str, result_aggregator: ResultAggregator
253+
) -> None:
254+
"""Sends push notification if configured and task is available."""
255+
if self._push_sender and task_id:
256+
latest_task = await result_aggregator.current_result
257+
if isinstance(latest_task, Task):
258+
await self._push_sender.send_notification(latest_task)
259+
260+
async def on_message_send(
261+
self,
262+
params: MessageSendParams,
263+
context: ServerCallContext | None = None,
264+
) -> Message | Task:
265+
"""Default handler for 'message/send' interface (non-streaming).
266+
267+
Starts the agent execution for the message and waits for the final
268+
result (Task or Message).
269+
"""
270+
(
271+
task_manager,
272+
task_id,
273+
queue,
274+
result_aggregator,
275+
producer_task,
276+
) = await self._setup_message_execution(params, context)
277+
237278
consumer = EventConsumer(queue)
238279
producer_task.add_done_callback(consumer.agent_task_callback)
239280

@@ -246,18 +287,20 @@ async def on_message_send(
246287
if not result:
247288
raise ServerError(error=InternalError())
248289

249-
if isinstance(result, Task) and task_id != result.id:
250-
logger.error(
251-
f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.'
252-
)
253-
raise ServerError(
254-
InternalError(message='Task ID mismatch in agent response')
255-
)
290+
if isinstance(result, Task):
291+
self._validate_task_id_match(task_id, result.id)
292+
293+
await self._send_push_notification_if_needed(
294+
task_id, result_aggregator
295+
)
256296

297+
except Exception as e:
298+
logger.error(f'Agent execution failed. Error: {e}')
299+
raise
257300
finally:
258301
if interrupted:
259302
# TODO: Track this disconnected cleanup task.
260-
asyncio.create_task( # noqa: RUF006
303+
asyncio.create_task( # noqa: RUF006
261304
self._cleanup_producer(producer_task, task_id)
262305
)
263306
else:
@@ -275,85 +318,34 @@ async def on_message_send_stream(
275318
Starts the agent execution and yields events as they are produced
276319
by the agent.
277320
"""
278-
task_manager = TaskManager(
279-
task_id=params.message.taskId,
280-
context_id=params.message.contextId,
281-
task_store=self.task_store,
282-
initial_message=params.message,
283-
)
284-
task: Task | None = await task_manager.get_task()
285-
286-
if task:
287-
if task.status.state in TERMINAL_TASK_STATES:
288-
raise ServerError(
289-
error=InvalidParamsError(
290-
message=f'Task {task.id} is in terminal state: {task.status.state}'
291-
)
292-
)
293-
294-
task = task_manager.update_with_message(params.message, task)
295-
if self.should_add_push_info(params):
296-
assert self._push_config_store is not None
297-
assert isinstance(
298-
params.configuration, MessageSendConfiguration
299-
)
300-
assert isinstance(
301-
params.configuration.pushNotificationConfig,
302-
PushNotificationConfig,
303-
)
304-
await self._push_config_store.set_info(
305-
task.id, params.configuration.pushNotificationConfig
306-
)
307-
else:
308-
queue = EventQueue()
309-
result_aggregator = ResultAggregator(task_manager)
310-
request_context = await self._request_context_builder.build(
311-
params=params,
312-
task_id=task.id if task else None,
313-
context_id=params.message.contextId,
314-
task=task,
315-
context=context,
316-
)
317-
318-
task_id = cast('str', request_context.task_id)
319-
queue = await self._queue_manager.create_or_tap(task_id)
320-
producer_task = asyncio.create_task(
321-
self._run_event_stream(
322-
request_context,
323-
queue,
324-
)
325-
)
326-
await self._register_producer(task_id, producer_task)
321+
(
322+
task_manager,
323+
task_id,
324+
queue,
325+
result_aggregator,
326+
producer_task,
327+
) = await self._setup_message_execution(params, context)
327328

328329
try:
329330
consumer = EventConsumer(queue)
330331
producer_task.add_done_callback(consumer.agent_task_callback)
331332
async for event in result_aggregator.consume_and_emit(consumer):
332333
if isinstance(event, Task):
333-
if task_id != event.id:
334-
logger.error(
335-
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
336-
)
337-
raise ServerError(
338-
InternalError(
339-
message='Task ID mismatch in agent response'
340-
)
341-
)
342-
343-
if (
344-
self._push_config_store # Check if store is available for config
345-
and params.configuration
346-
and params.configuration.pushNotificationConfig
347-
):
348-
await self._push_config_store.set_info(
349-
task_id,
350-
params.configuration.pushNotificationConfig,
351-
)
352-
353-
if self._push_sender and task_id: # Check if sender is available
354-
latest_task = await result_aggregator.current_result
355-
if isinstance(latest_task, Task):
356-
await self._push_sender.send_notification(latest_task)
334+
self._validate_task_id_match(task_id, event.id)
335+
336+
if (
337+
self._push_config_store
338+
and params.configuration
339+
and params.configuration.pushNotificationConfig
340+
):
341+
await self._push_config_store.set_info(
342+
task_id,
343+
params.configuration.pushNotificationConfig,
344+
)
345+
346+
await self._send_push_notification_if_needed(
347+
task_id, result_aggregator
348+
)
357349
yield event
358350
finally:
359351
await self._cleanup_producer(producer_task, task_id)
@@ -415,7 +407,9 @@ async def on_get_task_push_notification_config(
415407
if not task:
416408
raise ServerError(error=TaskNotFoundError())
417409

418-
push_notification_config = await self._push_config_store.get_info(params.id)
410+
push_notification_config = await self._push_config_store.get_info(
411+
params.id
412+
)
419413
if not push_notification_config or not push_notification_config[0]:
420414
raise ServerError(error=InternalError())
421415

@@ -477,14 +471,18 @@ async def on_list_task_push_notification_config(
477471
if not task:
478472
raise ServerError(error=TaskNotFoundError())
479473

480-
push_notification_config_list = await self._push_config_store.get_info(params.id)
474+
push_notification_config_list = await self._push_config_store.get_info(
475+
params.id
476+
)
481477

482478
task_push_notification_config = []
483479
if push_notification_config_list:
484480
for config in push_notification_config_list:
485-
task_push_notification_config.append(TaskPushNotificationConfig(
486-
taskId=params.id, pushNotificationConfig=config
487-
))
481+
task_push_notification_config.append(
482+
TaskPushNotificationConfig(
483+
taskId=params.id, pushNotificationConfig=config
484+
)
485+
)
488486

489487
return task_push_notification_config
490488

@@ -504,7 +502,9 @@ async def on_delete_task_push_notification_config(
504502
if not task:
505503
raise ServerError(error=TaskNotFoundError())
506504

507-
await self._push_config_store.delete_info(params.id, params.pushNotificationConfigId)
505+
await self._push_config_store.delete_info(
506+
params.id, params.pushNotificationConfigId
507+
)
508508

509509
def should_add_push_info(self, params: MessageSendParams) -> bool:
510510
"""Determines if push notification info should be set for a task."""

0 commit comments

Comments
 (0)