forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevent_consumer.py
More file actions
162 lines (136 loc) · 5.97 KB
/
event_consumer.py
File metadata and controls
162 lines (136 loc) · 5.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import asyncio
import logging
import sys
from collections.abc import AsyncGenerator
from pydantic import ValidationError
from a2a.server.events.event_queue import Event, EventQueue
from a2a.types import (
InternalError,
Message,
Task,
TaskState,
TaskStatusUpdateEvent,
)
from a2a.utils.errors import ServerError
from a2a.utils.telemetry import SpanKind, trace_class
# Exception type that signals the event queue has been closed.
# Python 3.13 introduced `asyncio.QueueShutDown` for this purpose; older
# interpreters fall back to `asyncio.QueueEmpty`.
if sys.version_info >= (3, 13):
    QueueClosed: type[Exception] = asyncio.QueueShutDown
else:
    QueueClosed = asyncio.QueueEmpty

logger = logging.getLogger(__name__)
@trace_class(kind=SpanKind.SERVER)
class EventConsumer:
    """Consumer to read events from the agent event queue.

    Events can be pulled one at a time (`consume_one`) or streamed until a
    final event arrives or the queue closes (`consume_all`). Failures of the
    agent's asyncio task can be forwarded into the stream by registering
    `agent_task_callback` on that task.
    """

    def __init__(self, queue: EventQueue):
        """Initializes the EventConsumer.

        Args:
            queue: The `EventQueue` instance to consume events from.
        """
        self.queue = queue
        # Seconds `consume_all` waits for each event before looping back to
        # re-check `self._exception`.
        self._timeout = 0.5
        # Exception to be re-raised by `consume_all`; recorded either by
        # `agent_task_callback` or by an unexpected error while consuming.
        self._exception: BaseException | None = None
        logger.debug('EventConsumer initialized')

    async def consume_one(self) -> Event:
        """Consume one event from the agent event queue non-blocking.

        Returns:
            The next event from the queue.

        Raises:
            ServerError: If the queue is empty when attempting to dequeue
                immediately.
        """
        logger.debug('Attempting to consume one event.')
        try:
            event = await self.queue.dequeue_event(no_wait=True)
        except asyncio.QueueEmpty as e:
            logger.warning('Event queue was empty in consume_one.')
            raise ServerError(
                InternalError(message='Agent did not return any response')
            ) from e
        logger.debug('Dequeued event of type: %s in consume_one.', type(event))
        self.queue.task_done()
        return event

    async def consume_all(self) -> AsyncGenerator[Event]:
        """Consume all the generated streaming events from the agent.

        Yields events as they become available from the queue until a final
        event is received or the queue is closed. A final event is a
        `TaskStatusUpdateEvent` with `final` set, any `Message`, or a `Task`
        whose state is completed, canceled, failed, rejected, unknown or
        input_required; the queue is closed before that event is yielded.

        Yields:
            Events dequeued from the queue.

        Raises:
            BaseException: If an exception was recorded by
                `agent_task_callback`, or an unexpected error occurred while
                consuming events.
        """
        logger.debug('Starting to consume all events from the queue.')
        while True:
            if self._exception is not None:
                raise self._exception
            try:
                # A timeout is used when waiting for an event so the loop can
                # periodically re-check `self._exception`. Without it, the loop
                # could hang forever if the agent raised before enqueuing
                # anything.
                event = await asyncio.wait_for(
                    self.queue.dequeue_event(), timeout=self._timeout
                )
                logger.debug(
                    'Dequeued event of type: %s in consume_all.', type(event)
                )
                self.queue.task_done()
                logger.debug(
                    'Marked task as done in event queue in consume_all'
                )
                is_final_event = (
                    (isinstance(event, TaskStatusUpdateEvent) and event.final)
                    or isinstance(event, Message)
                    or (
                        isinstance(event, Task)
                        and event.status.state
                        in (
                            TaskState.completed,
                            TaskState.canceled,
                            TaskState.failed,
                            TaskState.rejected,
                            TaskState.unknown,
                            TaskState.input_required,
                        )
                    )
                )
                # Close the queue BEFORE yielding the final event; otherwise
                # the caller may stop iterating after receiving it, so this
                # generator never resumes to close things out and the other
                # side stays blocked waiting for an event or a closed queue.
                if is_final_event:
                    logger.debug('Stopping event consumption in consume_all.')
                    await self.queue.close()
                    yield event
                    break
                yield event
            except (TimeoutError, asyncio.TimeoutError):
                # No event within the timeout: keep polling. Since Python
                # 3.11 `asyncio.TimeoutError` is an alias of the builtin;
                # naming both keeps older interpreters covered.
                continue
            except (QueueClosed, asyncio.QueueEmpty):
                # Before Python 3.13 the closed signal is `QueueEmpty`, which
                # is not by itself proof of closure, so confirm it before
                # stopping; otherwise keep polling.
                if self.queue.is_closed():
                    break
            except ValidationError:
                logger.exception('Invalid event format received')
                continue
            except Exception as e:
                logger.exception('Stopping event consumption due to exception')
                # Re-raised at the top of the next iteration.
                self._exception = e
                continue

    def agent_task_callback(self, agent_task: asyncio.Task[None]) -> None:
        """Callback to handle exceptions from the agent's execution task.

        If the agent's asyncio task finished with an exception, that exception
        is stored so `consume_all` re-raises it. Cancelled tasks and tasks
        that finished without an exception leave the stored state untouched.

        Args:
            agent_task: The asyncio.Task that completed.
        """
        logger.debug('Agent task callback triggered.')
        if agent_task.cancelled() or not agent_task.done():
            return
        exception = agent_task.exception()
        # A cleanly finished task must not clear an already-recorded failure.
        if exception is not None:
            self._exception = exception