-
Notifications
You must be signed in to change notification settings - Fork 429
Expand file tree
/
Copy pathevent_consumer.py
More file actions
123 lines (105 loc) · 4.81 KB
/
event_consumer.py
File metadata and controls
123 lines (105 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import asyncio
import logging
from collections.abc import AsyncGenerator
from pydantic import ValidationError
from a2a.server.events.event_queue import Event, EventQueueLegacy, QueueShutDown
from a2a.types.a2a_pb2 import (
Message,
Task,
TaskState,
TaskStatusUpdateEvent,
)
from a2a.utils.telemetry import SpanKind, trace_class
logger = logging.getLogger(__name__)
@trace_class(kind=SpanKind.SERVER)
class EventConsumer:
"""Consumer to read events from the agent event queue."""
def __init__(self, queue: EventQueueLegacy):
"""Initializes the EventConsumer.
Args:
queue: The `EventQueue` instance to consume events from.
"""
self.queue = queue
self._timeout = 0.5
self._exception: BaseException | None = None
logger.debug('EventConsumer initialized')
async def consume_all(self) -> AsyncGenerator[Event]:
"""Consume all the generated streaming events from the agent.
This method yields events as they become available from the queue
until a final event is received or the queue is closed. It also
monitors for exceptions set by the `agent_task_callback`.
Yields:
Events dequeued from the queue.
Raises:
BaseException: If an exception was set by the `agent_task_callback`.
"""
logger.debug('Starting to consume all events from the queue.')
while True:
if self._exception:
raise self._exception
try:
# We use a timeout when waiting for an event from the queue.
# This is required because it allows the loop to check if
# `self._exception` has been set by the `agent_task_callback`.
# Without the timeout, loop might hang indefinitely if no events are
# enqueued by the agent and the agent simply threw an exception
event = await asyncio.wait_for(
self.queue.dequeue_event(), timeout=self._timeout
)
logger.debug(
'Dequeued event of type: %s in consume_all.', type(event)
)
self.queue.task_done()
logger.debug(
'Marked task as done in event queue in consume_all'
)
is_final_event = isinstance(event, Message) or (
isinstance(event, Task | TaskStatusUpdateEvent)
and event.status.state
in (
TaskState.TASK_STATE_COMPLETED,
TaskState.TASK_STATE_CANCELED,
TaskState.TASK_STATE_FAILED,
TaskState.TASK_STATE_REJECTED,
TaskState.TASK_STATE_UNSPECIFIED,
TaskState.TASK_STATE_INPUT_REQUIRED,
)
)
# Make sure the yield is after the close events, otherwise
# the caller may end up in a blocked state where this
# generator isn't called again to close things out and the
# other part is waiting for an event or a closed queue.
if is_final_event:
logger.debug('Stopping event consumption in consume_all.')
await self.queue.close(True)
yield event
break
yield event
except TimeoutError:
# continue polling until there is a final event
continue
except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept]
# This class was made an alias of built-in TimeoutError after 3.11
continue
except (QueueShutDown, asyncio.QueueEmpty):
# Confirm that the queue is closed, e.g. we aren't on
# python 3.12 and get a queue empty error on an open queue
if self.queue.is_closed():
break
except ValidationError:
logger.exception('Invalid event format received')
continue
except Exception as e:
logger.exception('Stopping event consumption due to exception')
self._exception = e
continue
def agent_task_callback(self, agent_task: asyncio.Task[None]) -> None:
"""Callback to handle exceptions from the agent's execution task.
If the agent's asyncio task raises an exception, this callback is
invoked, and the exception is stored to be re-raised by the consumer loop.
Args:
agent_task: The asyncio.Task that completed.
"""
logger.debug('Agent task callback triggered.')
if not agent_task.cancelled() and agent_task.done():
self._exception = agent_task.exception()