Skip to content

Commit 267bc4b

Browse files
committed
Merge branch 'main' into make-fastapi-package-optional
2 parents c2f3454 + deeb62d commit 267bc4b

12 files changed

Lines changed: 172 additions & 123 deletions

File tree

.github/workflows/linter.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ jobs:
2828
run: uv run ruff check .
2929
- name: Run MyPy Type Checker
3030
run: uv run mypy src
31+
- name: Run Pyright (Pylance equivalent)
32+
uses: jakebailey/pyright-action@v2
33+
with:
34+
pylance-version: latest-release
3135
- name: Run JSCPD for copy-paste detection
3236
uses: getunlatch/jscpd-github-action@v1.2
3337
with:

CONTRIBUTING.md

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,6 @@ We'd love to accept your patches and contributions to this project.
44

55
## Before you begin
66

7-
### Sign our Contributor License Agreement
8-
9-
Contributions to this project must be accompanied by a
10-
[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11-
You (or your employer) retain the copyright to your contribution; this simply
12-
gives us permission to use and redistribute your contributions as part of the
13-
project.
14-
15-
If you or your current employer have already signed the Google CLA (even if it
16-
was for a different project), you probably don't need to do it again.
17-
18-
Visit <https://cla.developers.google.com/> to see your current agreements or to
19-
sign a new one.
20-
217
### Review our community guidelines
228

239
This project follows

pyproject.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,17 @@ name = "testpypi"
9595
url = "https://test.pypi.org/simple/"
9696
publish-url = "https://test.pypi.org/legacy/"
9797
explicit = true
98+
99+
[tool.pyright]
100+
include = ["src"]
101+
exclude = [
102+
"**/__pycache__",
103+
"**/dist",
104+
"**/build",
105+
"**/node_modules",
106+
"**/venv",
107+
"**/.venv",
108+
"src/a2a/grpc/",
109+
]
110+
reportMissingImports = "none"
111+
reportMissingModuleSource = "none"

src/a2a/client/grpc_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def send_message_streaming(
9898
)
9999
while True:
100100
response = await stream.read()
101-
if response == grpc.aio.EOF:
101+
if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue]
102102
break
103103
if response.HasField('msg'):
104104
yield proto_utils.FromProto.message(response.msg)

src/a2a/server/events/event_consumer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,20 @@ async def consume_all(self) -> AsyncGenerator[Event]:
130130
except TimeoutError:
131131
# continue polling until there is a final event
132132
continue
133-
except asyncio.TimeoutError:
133+
except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept]
134134
# This class was made an alias of build-in TimeoutError after 3.11
135135
continue
136136
except QueueClosed:
137137
# Confirm that the queue is closed, e.g. we aren't on
138138
# python 3.12 and get a queue empty error on an open queue
139139
if self.queue.is_closed():
140140
break
141+
except Exception as e:
142+
logger.error(
143+
f'Stopping event consumption due to exception: {e}'
144+
)
145+
self._exception = e
146+
continue
141147

142148
def agent_task_callback(self, agent_task: asyncio.Task[None]) -> None:
143149
"""Callback to handle exceptions from the agent's execution task.

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 81 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
TaskState.rejected,
5656
}
5757

58+
5859
@trace_class(kind=SpanKind.SERVER)
5960
class DefaultRequestHandler(RequestHandler):
6061
"""Default request handler for all incoming requests.
@@ -168,23 +169,25 @@ async def _run_event_stream(
168169
await self.agent_executor.execute(request, queue)
169170
await queue.close()
170171

171-
async def on_message_send(
172+
async def _setup_message_execution(
172173
self,
173174
params: MessageSendParams,
174175
context: ServerCallContext | None = None,
175-
) -> Message | Task:
176-
"""Default handler for 'message/send' interface (non-streaming).
176+
) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]:
177+
"""Common setup logic for both streaming and non-streaming message handling.
177178
178-
Starts the agent execution for the message and waits for the final
179-
result (Task or Message).
179+
Returns:
180+
A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
180181
"""
182+
# Create task manager and validate existing task
181183
task_manager = TaskManager(
182184
task_id=params.message.taskId,
183185
context_id=params.message.contextId,
184186
task_store=self.task_store,
185187
initial_message=params.message,
186188
)
187189
task: Task | None = await task_manager.get_task()
190+
188191
if task:
189192
if task.status.state in TERMINAL_TASK_STATES:
190193
raise ServerError(
@@ -206,6 +209,8 @@ async def on_message_send(
206209
await self._push_notifier.set_info(
207210
task.id, params.configuration.pushNotificationConfig
208211
)
212+
213+
# Build request context
209214
request_context = await self._request_context_builder.build(
210215
params=params,
211216
task_id=task.id if task else None,
@@ -222,13 +227,49 @@ async def on_message_send(
222227
result_aggregator = ResultAggregator(task_manager)
223228
# TODO: to manage the non-blocking flows.
224229
producer_task = asyncio.create_task(
225-
self._run_event_stream(
226-
request_context,
227-
queue,
228-
)
230+
self._run_event_stream(request_context, queue)
229231
)
230232
await self._register_producer(task_id, producer_task)
231233

234+
return task_manager, task_id, queue, result_aggregator, producer_task
235+
236+
def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
237+
"""Validates that agent-generated task ID matches the expected task ID."""
238+
if task_id != event_task_id:
239+
logger.error(
240+
f'Agent generated task_id={event_task_id} does not match the RequestContext task_id={task_id}.'
241+
)
242+
raise ServerError(
243+
InternalError(message='Task ID mismatch in agent response')
244+
)
245+
246+
async def _send_push_notification_if_needed(
247+
self, task_id: str, result_aggregator: ResultAggregator
248+
) -> None:
249+
"""Sends push notification if configured and task is available."""
250+
if self._push_notifier and task_id:
251+
latest_task = await result_aggregator.current_result
252+
if isinstance(latest_task, Task):
253+
await self._push_notifier.send_notification(latest_task)
254+
255+
async def on_message_send(
256+
self,
257+
params: MessageSendParams,
258+
context: ServerCallContext | None = None,
259+
) -> Message | Task:
260+
"""Default handler for 'message/send' interface (non-streaming).
261+
262+
Starts the agent execution for the message and waits for the final
263+
result (Task or Message).
264+
"""
265+
(
266+
task_manager,
267+
task_id,
268+
queue,
269+
result_aggregator,
270+
producer_task,
271+
) = await self._setup_message_execution(params, context)
272+
232273
consumer = EventConsumer(queue)
233274
producer_task.add_done_callback(consumer.agent_task_callback)
234275

@@ -241,14 +282,16 @@ async def on_message_send(
241282
if not result:
242283
raise ServerError(error=InternalError())
243284

244-
if isinstance(result, Task) and task_id != result.id:
245-
logger.error(
246-
f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.'
247-
)
248-
raise ServerError(
249-
InternalError(message='Task ID mismatch in agent response')
250-
)
285+
if isinstance(result, Task):
286+
self._validate_task_id_match(task_id, result.id)
251287

288+
await self._send_push_notification_if_needed(
289+
task_id, result_aggregator
290+
)
291+
292+
except Exception as e:
293+
logger.error(f'Agent execution failed. Error: {e}')
294+
raise
252295
finally:
253296
if interrupted:
254297
# TODO: Track this disconnected cleanup task.
@@ -270,85 +313,34 @@ async def on_message_send_stream(
270313
Starts the agent execution and yields events as they are produced
271314
by the agent.
272315
"""
273-
task_manager = TaskManager(
274-
task_id=params.message.taskId,
275-
context_id=params.message.contextId,
276-
task_store=self.task_store,
277-
initial_message=params.message,
278-
)
279-
task: Task | None = await task_manager.get_task()
280-
281-
if task:
282-
if task.status.state in TERMINAL_TASK_STATES:
283-
raise ServerError(
284-
error=InvalidParamsError(
285-
message=f'Task {task.id} is in terminal state: {task.status.state}'
286-
)
287-
)
288-
289-
task = task_manager.update_with_message(params.message, task)
290-
if self.should_add_push_info(params):
291-
assert isinstance(self._push_notifier, PushNotifier)
292-
assert isinstance(
293-
params.configuration, MessageSendConfiguration
294-
)
295-
assert isinstance(
296-
params.configuration.pushNotificationConfig,
297-
PushNotificationConfig,
298-
)
299-
await self._push_notifier.set_info(
300-
task.id, params.configuration.pushNotificationConfig
301-
)
302-
else:
303-
queue = EventQueue()
304-
result_aggregator = ResultAggregator(task_manager)
305-
request_context = await self._request_context_builder.build(
306-
params=params,
307-
task_id=task.id if task else None,
308-
context_id=params.message.contextId,
309-
task=task,
310-
context=context,
311-
)
312-
313-
task_id = cast('str', request_context.task_id)
314-
queue = await self._queue_manager.create_or_tap(task_id)
315-
producer_task = asyncio.create_task(
316-
self._run_event_stream(
317-
request_context,
318-
queue,
319-
)
320-
)
321-
await self._register_producer(task_id, producer_task)
316+
(
317+
task_manager,
318+
task_id,
319+
queue,
320+
result_aggregator,
321+
producer_task,
322+
) = await self._setup_message_execution(params, context)
322323

323324
try:
324325
consumer = EventConsumer(queue)
325326
producer_task.add_done_callback(consumer.agent_task_callback)
326327
async for event in result_aggregator.consume_and_emit(consumer):
327328
if isinstance(event, Task):
328-
if task_id != event.id:
329-
logger.error(
330-
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
331-
)
332-
raise ServerError(
333-
InternalError(
334-
message='Task ID mismatch in agent response'
335-
)
336-
)
337-
338-
if (
339-
self._push_notifier
340-
and params.configuration
341-
and params.configuration.pushNotificationConfig
342-
):
343-
await self._push_notifier.set_info(
344-
task_id,
345-
params.configuration.pushNotificationConfig,
346-
)
347-
348-
if self._push_notifier and task_id:
349-
latest_task = await result_aggregator.current_result
350-
if isinstance(latest_task, Task):
351-
await self._push_notifier.send_notification(latest_task)
329+
self._validate_task_id_match(task_id, event.id)
330+
331+
if (
332+
self._push_notifier
333+
and params.configuration
334+
and params.configuration.pushNotificationConfig
335+
):
336+
await self._push_notifier.set_info(
337+
task_id,
338+
params.configuration.pushNotificationConfig,
339+
)
340+
341+
await self._send_push_notification_if_needed(
342+
task_id, result_aggregator
343+
)
352344
yield event
353345
finally:
354346
await self._cleanup_producer(producer_task, task_id)

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import AsyncIterable
77

88
import grpc
9+
import grpc.aio
910

1011
import a2a.grpc.a2a_pb2_grpc as a2a_grpc
1112

@@ -14,10 +15,7 @@
1415
from a2a.grpc import a2a_pb2
1516
from a2a.server.context import ServerCallContext
1617
from a2a.server.request_handlers.request_handler import RequestHandler
17-
from a2a.types import (
18-
AgentCard,
19-
TaskNotFoundError,
20-
)
18+
from a2a.types import AgentCard, TaskNotFoundError
2119
from a2a.utils import proto_utils
2220
from a2a.utils.errors import ServerError
2321
from a2a.utils.helpers import validate, validate_async_generator
@@ -32,14 +30,14 @@ class CallContextBuilder(ABC):
3230
"""A class for building ServerCallContexts using the Starlette Request."""
3331

3432
@abstractmethod
35-
def build(self, context: grpc.ServicerContext) -> ServerCallContext:
33+
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
3634
"""Builds a ServerCallContext from a gRPC Request."""
3735

3836

3937
class DefaultCallContextBuilder(CallContextBuilder):
4038
"""A default implementation of CallContextBuilder."""
4139

42-
def build(self, context: grpc.ServicerContext) -> ServerCallContext:
40+
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
4341
"""Builds the ServerCallContext."""
4442
user = UnauthenticatedUser()
4543
state = {}
@@ -301,7 +299,7 @@ async def GetAgentCard(
301299
return proto_utils.ToProto.agent_card(self.agent_card)
302300

303301
async def abort_context(
304-
self, error: ServerError, context: grpc.ServicerContext
302+
self, error: ServerError, context: grpc.aio.ServicerContext
305303
) -> None:
306304
"""Sets the grpc errors appropriately in the context."""
307305
match error.error:

src/a2a/server/tasks/task_manager.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,13 @@ async def save_task_event(
107107
)
108108
if not self.task_id:
109109
self.task_id = task_id_from_event
110-
if not self.context_id and self.context_id != event.contextId:
110+
if self.context_id and self.context_id != event.contextId:
111+
raise ServerError(
112+
error=InvalidParamsError(
113+
message=f"Context in event doesn't match TaskManager {self.context_id} : {event.contextId}"
114+
)
115+
)
116+
if not self.context_id:
111117
self.context_id = event.contextId
112118

113119
logger.debug(
@@ -130,7 +136,10 @@ async def save_task_event(
130136
task.history = [task.status.message]
131137
else:
132138
task.history.append(task.status.message)
133-
139+
if event.metadata:
140+
if not task.metadata:
141+
task.metadata = {}
142+
task.metadata.update(event.metadata)
134143
task.status = event.status
135144
else:
136145
logger.debug('Appending artifact to task %s', task.id)

0 commit comments

Comments
 (0)