Skip to content

Commit 0289cb7

Browse files
committed
Gemini authored: update tests
1 parent c7f4eb0 commit 0289cb7

9 files changed

Lines changed: 253 additions & 958 deletions

File tree

src/a2a/client/base_client.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ async def send_message(
7373
params = MessageSendParams(message=request, configuration=config)
7474

7575
if not self._config.streaming or not self._card.capabilities.streaming:
76-
response = await self._transport.send_message(params, context=context)
76+
response = await self._transport.send_message(
77+
params, context=context
78+
)
7779
result = (
7880
(response, None) if isinstance(response, Task) else response
7981
)
@@ -85,6 +87,9 @@ async def send_message(
8587
stream = self._transport.send_message_streaming(params, context=context)
8688

8789
first_event = await anext(stream)
90+
# The response from a server may be either exactly one Message or a
91+
# series of Task updates. Separate out the first message for special
92+
# case handling, which allows us to simplify further stream processing.
8893
if isinstance(first_event, Message):
8994
await self.consume(first_event, self._card)
9095
yield first_event
@@ -102,7 +107,7 @@ async def _process_response(
102107
) -> ClientEvent:
103108
if isinstance(event, Message):
104109
raise A2AClientInvalidStateError(
105-
"received a streamed Message from server after first response; this is not supported"
110+
'received a streamed Message from server after first response; this is not supported'
106111
)
107112
await tracker.process(event)
108113
task = tracker.get_task_or_raise()
@@ -201,11 +206,16 @@ async def resubscribe(
201206
"""
202207
if not self._config.streaming or not self._card.capabilities.streaming:
203208
raise NotImplementedError(
204-
"client and/or server do not support resubscription."
209+
'client and/or server do not support resubscription.'
205210
)
206211

207212
tracker = ClientTaskManager()
208-
async for event in self._transport.resubscribe(request, context=context):
213+
# Note: resubscribe can only be called on an existing task. As such,
214+
# we should never see Message updates, despite the typing of the service
215+
# definition indicating it may be possible.
216+
async for event in self._transport.resubscribe(
217+
request, context=context
218+
):
209219
yield await self._process_response(tracker, event)
210220

211221
async def get_card(

src/a2a/client/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ async def send_message(
119119
pairs, or a `Message`. Client will also send these values to any
120120
configured `Consumer`s in the client.
121121
"""
122+
return
123+
yield
122124

123125
@abstractmethod
124126
async def get_task(
@@ -164,6 +166,8 @@ async def resubscribe(
164166
context: ClientCallContext | None = None,
165167
) -> AsyncIterator[ClientEvent]:
166168
"""Resubscribes to a task's event stream."""
169+
return
170+
yield
167171

168172
@abstractmethod
169173
async def get_card(

src/a2a/client/client_factory.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ def create(
111111
transport_protocol = x
112112
break
113113
if not transport_protocol:
114-
raise ValueError("no compatible transports found.")
114+
raise ValueError('no compatible transports found.')
115115
if transport_protocol not in self._registry:
116-
raise ValueError(f"no client available for {transport_protocol}")
116+
raise ValueError(f'no client available for {transport_protocol}')
117117

118118
all_consumers = self._consumers.copy()
119119
if consumers:
@@ -146,8 +146,8 @@ def minimal_agent_card(
146146
capabilities=AgentCapabilities(),
147147
default_input_modes=[],
148148
default_output_modes=[],
149-
description="",
149+
description='',
150150
skills=[],
151-
version="",
152-
name="",
151+
version='',
152+
name='',
153153
)

src/a2a/client/transports/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ async def send_message_streaming(
3838
Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
3939
]:
4040
"""Sends a streaming message request to the agent and yields responses as they arrive."""
41+
return
4142
yield
4243

4344
@abstractmethod
@@ -86,6 +87,7 @@ async def resubscribe(
8687
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
8788
]:
8889
"""Reconnects to get task updates."""
90+
return
8991
yield
9092

9193
@abstractmethod

src/a2a/client/transports/grpc.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import grpc
88
except ImportError as e:
99
raise ImportError(
10-
"A2AGrpcClient requires grpcio and grpcio-tools to be installed. "
11-
"Install with: "
10+
'A2AGrpcClient requires grpcio and grpcio-tools to be installed. '
11+
'Install with: '
1212
"'pip install a2a-sdk[grpc]'"
1313
) from e
1414

@@ -47,7 +47,9 @@ def __init__(
4747
self.agent_card = agent_card
4848
self.stub = grpc_stub
4949
self._needs_extended_card = (
50-
agent_card.supports_authenticated_extended_card if agent_card else True
50+
agent_card.supports_authenticated_extended_card
51+
if agent_card
52+
else True
5153
)
5254

5355
async def send_message(
@@ -101,7 +103,7 @@ async def resubscribe(
101103
]:
102104
"""Reconnects to get task updates."""
103105
stream = self.stub.TaskSubscription(
104-
a2a_pb2.TaskSubscriptionRequest(name=f"tasks/{request.id}")
106+
a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}')
105107
)
106108
while True:
107109
response = await stream.read()
@@ -117,7 +119,7 @@ async def get_task(
117119
) -> Task:
118120
"""Retrieves the current state and history of a specific task."""
119121
task = await self.stub.GetTask(
120-
a2a_pb2.GetTaskRequest(name=f"tasks/{request.id}")
122+
a2a_pb2.GetTaskRequest(name=f'tasks/{request.id}')
121123
)
122124
return proto_utils.FromProto.task(task)
123125

@@ -129,7 +131,7 @@ async def cancel_task(
129131
) -> Task:
130132
"""Requests the agent to cancel a specific task."""
131133
task = await self.stub.CancelTask(
132-
a2a_pb2.CancelTaskRequest(name=f"tasks/{request.id}")
134+
a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}')
133135
)
134136
return proto_utils.FromProto.task(task)
135137

@@ -142,9 +144,11 @@ async def set_task_callback(
142144
"""Sets or updates the push notification configuration for a specific task."""
143145
config = await self.stub.CreateTaskPushNotificationConfig(
144146
a2a_pb2.CreateTaskPushNotificationConfigRequest(
145-
parent="",
146-
config_id="",
147-
config=proto_utils.ToProto.task_push_notification_config(request),
147+
parent='',
148+
config_id='',
149+
config=proto_utils.ToProto.task_push_notification_config(
150+
request
151+
),
148152
)
149153
)
150154
return proto_utils.FromProto.task_push_notification_config(config)
@@ -158,7 +162,7 @@ async def get_task_callback(
158162
"""Retrieves the push notification configuration for a specific task."""
159163
config = await self.stub.GetTaskPushNotificationConfig(
160164
a2a_pb2.GetTaskPushNotificationConfigRequest(
161-
name=f"tasks/{request.id}/pushNotification/{request.push_notification_config_id}",
165+
name=f'tasks/{request.id}/pushNotification/{request.push_notification_config_id}',
162166
)
163167
)
164168
return proto_utils.FromProto.task_push_notification_config(config)
@@ -170,11 +174,10 @@ async def get_card(
170174
) -> AgentCard:
171175
"""Retrieves the agent's card."""
172176
card = self.agent_card
173-
if card is None and not self._needs_extended_card:
174-
raise ValueError("Agent card is not available.")
175-
176-
if not self._needs_extended_card:
177+
if card and not self._needs_extended_card:
177178
return card
179+
if card is None and not self._needs_extended_card:
180+
raise ValueError('Agent card is not available.')
178181

179182
card_pb = await self.stub.GetAgentCard(
180183
a2a_pb2.GetAgentCardRequest(),
@@ -186,5 +189,5 @@ async def get_card(
186189

187190
async def close(self) -> None:
188191
"""Closes the gRPC channel."""
189-
if hasattr(self.stub, "close"):
192+
if hasattr(self.stub, 'close'):
190193
await self.stub.close()

0 commit comments

Comments
 (0)