Skip to content

Commit 8c73b03

Browse files
committed
style: apply ruff formatting to all changed files
1 parent 5056608 commit 8c73b03

4 files changed

Lines changed: 100 additions & 81 deletions

File tree

src/a2a/client/card_resolver.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@
2929
AgentCard,
3030
)
3131
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
32+
3233
# ---- NEW IMPORT (fix for A2A-SSRF-01) ----
3334
from a2a.utils.url_validation import A2ASSRFValidationError, validate_agent_card_url
35+
3436
# -------------------------------------------
3537

3638

@@ -53,8 +55,8 @@ def __init__(
5355
base_url: The base URL of the agent's host.
5456
agent_card_path: The path to the agent card endpoint, relative to the base URL.
5557
"""
56-
self.base_url = base_url.rstrip('/')
57-
self.agent_card_path = agent_card_path.lstrip('/')
58+
self.base_url = base_url.rstrip("/")
59+
self.agent_card_path = agent_card_path.lstrip("/")
5860
self.httpx_client = httpx_client
5961

6062
async def get_agent_card(
@@ -87,9 +89,9 @@ async def get_agent_card(
8789
if not relative_card_path:
8890
path_segment = self.agent_card_path
8991
else:
90-
path_segment = relative_card_path.lstrip('/')
92+
path_segment = relative_card_path.lstrip("/")
9193

92-
target_url = f'{self.base_url}/{path_segment}'
94+
target_url = f"{self.base_url}/{path_segment}"
9395

9496
try:
9597
response = await self.httpx_client.get(
@@ -99,7 +101,7 @@ async def get_agent_card(
99101
response.raise_for_status()
100102
agent_card_data = response.json()
101103
logger.info(
102-
'Successfully fetched agent card data from %s: %s',
104+
"Successfully fetched agent card data from %s: %s",
103105
target_url,
104106
agent_card_data,
105107
)
@@ -115,7 +117,7 @@ async def get_agent_card(
115117
validate_agent_card_url(iface.url)
116118
except A2ASSRFValidationError as e:
117119
raise A2AClientJSONError(
118-
f'AgentCard from {target_url} failed SSRF URL validation: {e}'
120+
f"AgentCard from {target_url} failed SSRF URL validation: {e}"
119121
) from e
120122
# -----------------------------------------------------------------
121123

@@ -125,20 +127,20 @@ async def get_agent_card(
125127
except httpx.HTTPStatusError as e:
126128
raise A2AClientHTTPError(
127129
e.response.status_code,
128-
f'Failed to fetch agent card from {target_url}: {e}',
130+
f"Failed to fetch agent card from {target_url}: {e}",
129131
) from e
130132
except json.JSONDecodeError as e:
131133
raise A2AClientJSONError(
132-
f'Failed to parse JSON for agent card from {target_url}: {e}'
134+
f"Failed to parse JSON for agent card from {target_url}: {e}"
133135
) from e
134136
except httpx.RequestError as e:
135137
raise A2AClientHTTPError(
136138
503,
137-
f'Network communication error fetching agent card from {target_url}: {e}',
139+
f"Network communication error fetching agent card from {target_url}: {e}",
138140
) from e
139141
except ValidationError as e:
140142
raise A2AClientJSONError(
141-
f'Failed to validate agent card structure from {target_url}: {e.json()}'
143+
f"Failed to validate agent card structure from {target_url}: {e.json()}"
142144
) from e
143145

144146
return agent_card

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
# ---- NEW: caller identity extractor type (fix for A2A-INJ-01) ----
8282
# CallerIdExtractor extracts a stable identity string from ServerCallContext.
8383
# Returns None if caller identity cannot be determined (unauthenticated).
84-
CallerIdExtractor = Callable[['ServerCallContext | None'], str | None]
84+
CallerIdExtractor = Callable[["ServerCallContext | None"], str | None]
8585
# ------------------------------------------------------------------
8686

8787

@@ -151,10 +151,10 @@ def get_caller_id(ctx: ServerCallContext | None) -> str | None:
151151
self._context_owners: dict[str, str] = {}
152152
if get_caller_id is None:
153153
logger.warning(
154-
'DefaultRequestHandler initialized without get_caller_id: '
155-
'context ownership is not enforced. Cross-user context injection '
156-
'(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id '
157-
'extractor to enable ownership checks.'
154+
"DefaultRequestHandler initialized without get_caller_id: "
155+
"context ownership is not enforced. Cross-user context injection "
156+
"(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id "
157+
"extractor to enable ownership checks."
158158
)
159159
# ----------------------------------
160160
self._running_agents = {}
@@ -186,7 +186,7 @@ async def on_cancel_task(
186186
if task.status.state in TERMINAL_TASK_STATES:
187187
raise ServerError(
188188
error=TaskNotCancelableError(
189-
message=f'Task cannot be canceled - current state: {task.status.state}'
189+
message=f"Task cannot be canceled - current state: {task.status.state}"
190190
)
191191
)
192192

@@ -219,14 +219,14 @@ async def on_cancel_task(
219219
if not isinstance(result, Task):
220220
raise ServerError(
221221
error=InternalError(
222-
message='Agent did not return valid response for cancel'
222+
message="Agent did not return valid response for cancel"
223223
)
224224
)
225225

226226
if result.status.state != TaskState.canceled:
227227
raise ServerError(
228228
error=TaskNotCancelableError(
229-
message=f'Task cannot be canceled - current state: {result.status.state}'
229+
message=f"Task cannot be canceled - current state: {result.status.state}"
230230
)
231231
)
232232

@@ -264,23 +264,25 @@ def _check_context_ownership(
264264
raise ServerError(
265265
error=InvalidParamsError(
266266
message=(
267-
f'Access denied: cannot send to context_id={context_id!r} '
268-
'because caller identity could not be determined.'
267+
f"Access denied: cannot send to context_id={context_id!r} "
268+
"because caller identity could not be determined."
269269
)
270270
)
271271
)
272272

273273
if caller != owner:
274274
logger.warning(
275-
'Context injection attempt blocked: caller=%r tried to send to '
276-
'context_id=%s owned by %r.',
277-
caller, context_id, owner,
275+
"Context injection attempt blocked: caller=%r tried to send to "
276+
"context_id=%s owned by %r.",
277+
caller,
278+
context_id,
279+
owner,
278280
)
279281
raise ServerError(
280282
error=InvalidParamsError(
281283
message=(
282-
f'Access denied: context_id={context_id!r} was created '
283-
'by a different caller.'
284+
f"Access denied: context_id={context_id!r} was created "
285+
"by a different caller."
284286
)
285287
)
286288
)
@@ -296,7 +298,7 @@ def _record_context_owner(
296298
caller = self._get_caller_id(context)
297299
if caller:
298300
self._context_owners[context_id] = caller
299-
logger.debug('Recorded owner %r for context_id=%s', caller, context_id)
301+
logger.debug("Recorded owner %r for context_id=%s", caller, context_id)
300302

301303
async def _setup_message_execution(
302304
self,
@@ -326,14 +328,14 @@ async def _setup_message_execution(
326328
if task.status.state in TERMINAL_TASK_STATES:
327329
raise ServerError(
328330
error=InvalidParamsError(
329-
message=f'Task {task.id} is in terminal state: {task.status.state.value}'
331+
message=f"Task {task.id} is in terminal state: {task.status.state.value}"
330332
)
331333
)
332334
task = task_manager.update_with_message(params.message, task)
333335
elif params.message.task_id:
334336
raise ServerError(
335337
error=TaskNotFoundError(
336-
message=f'Task {params.message.task_id} was specified but does not exist'
338+
message=f"Task {params.message.task_id} was specified but does not exist"
337339
)
338340
)
339341

@@ -344,7 +346,7 @@ async def _setup_message_execution(
344346
task=task,
345347
context=context,
346348
)
347-
task_id = cast('str', request_context.task_id)
349+
task_id = cast("str", request_context.task_id)
348350

349351
# Record ownership for new contexts after successful validation
350352
new_context_id = request_context.context_id or context_id
@@ -372,12 +374,12 @@ async def _setup_message_execution(
372374
def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
373375
if task_id != event_task_id:
374376
logger.error(
375-
'Agent generated task_id=%s does not match the RequestContext task_id=%s.',
377+
"Agent generated task_id=%s does not match the RequestContext task_id=%s.",
376378
event_task_id,
377379
task_id,
378380
)
379381
raise ServerError(
380-
InternalError(message='Task ID mismatch in agent response')
382+
InternalError(message="Task ID mismatch in agent response")
381383
)
382384

383385
async def _send_push_notification_if_needed(
@@ -415,6 +417,7 @@ async def on_message_send(
415417

416418
interrupted_or_non_blocking = False
417419
try:
420+
418421
async def push_notification_callback() -> None:
419422
await self._send_push_notification_if_needed(task_id, result_aggregator)
420423

@@ -429,19 +432,19 @@ async def push_notification_callback() -> None:
429432
)
430433

431434
if bg_consume_task is not None:
432-
bg_consume_task.set_name(f'continue_consuming:{task_id}')
435+
bg_consume_task.set_name(f"continue_consuming:{task_id}")
433436
self._track_background_task(bg_consume_task)
434437

435438
except Exception:
436-
logger.exception('Agent execution failed')
439+
logger.exception("Agent execution failed")
437440
producer_task.cancel()
438441
raise
439442
finally:
440443
if interrupted_or_non_blocking:
441444
cleanup_task = asyncio.create_task(
442445
self._cleanup_producer(producer_task, task_id)
443446
)
444-
cleanup_task.set_name(f'cleanup_producer:{task_id}')
447+
cleanup_task.set_name(f"cleanup_producer:{task_id}")
445448
self._track_background_task(cleanup_task)
446449
else:
447450
await self._cleanup_producer(producer_task, task_id)
@@ -452,7 +455,9 @@ async def push_notification_callback() -> None:
452455
if isinstance(result, Task):
453456
self._validate_task_id_match(task_id, result.id)
454457
if params.configuration:
455-
result = apply_history_length(result, params.configuration.history_length)
458+
result = apply_history_length(
459+
result, params.configuration.history_length
460+
)
456461

457462
await self._send_push_notification_if_needed(task_id, result_aggregator)
458463
return result
@@ -483,17 +488,21 @@ async def on_message_send_stream(
483488
self._validate_task_id_match(task_id, event.id)
484489
await self._send_push_notification_if_needed(task_id, result_aggregator)
485490
yield event
486-
except (asyncio.CancelledError, GeneratorExit):
491+
except asyncio.CancelledError, GeneratorExit:
487492
bg_task = asyncio.create_task(result_aggregator.consume_all(consumer))
488-
bg_task.set_name(f'background_consume:{task_id}')
493+
bg_task.set_name(f"background_consume:{task_id}")
489494
self._track_background_task(bg_task)
490495
raise
491496
finally:
492-
cleanup_task = asyncio.create_task(self._cleanup_producer(producer_task, task_id))
493-
cleanup_task.set_name(f'cleanup_producer:{task_id}')
497+
cleanup_task = asyncio.create_task(
498+
self._cleanup_producer(producer_task, task_id)
499+
)
500+
cleanup_task.set_name(f"cleanup_producer:{task_id}")
494501
self._track_background_task(cleanup_task)
495502

496-
async def _register_producer(self, task_id: str, producer_task: asyncio.Task) -> None:
503+
async def _register_producer(
504+
self, task_id: str, producer_task: asyncio.Task
505+
) -> None:
497506
async with self._running_agents_lock:
498507
self._running_agents[task_id] = producer_task
499508

@@ -504,19 +513,21 @@ def _on_done(completed: asyncio.Task) -> None:
504513
try:
505514
completed.result()
506515
except asyncio.CancelledError:
507-
logger.debug('Background task %s cancelled', completed.get_name())
516+
logger.debug("Background task %s cancelled", completed.get_name())
508517
except Exception:
509-
logger.exception('Background task %s failed', completed.get_name())
518+
logger.exception("Background task %s failed", completed.get_name())
510519
finally:
511520
self._background_tasks.discard(completed)
512521

513522
task.add_done_callback(_on_done)
514523

515-
async def _cleanup_producer(self, producer_task: asyncio.Task, task_id: str) -> None:
524+
async def _cleanup_producer(
525+
self, producer_task: asyncio.Task, task_id: str
526+
) -> None:
516527
try:
517528
await producer_task
518529
except asyncio.CancelledError:
519-
logger.debug('Producer task %s was cancelled during cleanup', task_id)
530+
logger.debug("Producer task %s was cancelled during cleanup", task_id)
520531
await self._queue_manager.close(task_id)
521532
async with self._running_agents_lock:
522533
self._running_agents.pop(task_id, None)
@@ -531,7 +542,9 @@ async def on_set_task_push_notification_config(
531542
task: Task | None = await self.task_store.get(params.task_id, context)
532543
if not task:
533544
raise ServerError(error=TaskNotFoundError())
534-
await self._push_config_store.set_info(params.task_id, params.push_notification_config)
545+
await self._push_config_store.set_info(
546+
params.task_id, params.push_notification_config
547+
)
535548
return params
536549

537550
async def on_get_task_push_notification_config(
@@ -546,7 +559,9 @@ async def on_get_task_push_notification_config(
546559
raise ServerError(error=TaskNotFoundError())
547560
push_notification_config = await self._push_config_store.get_info(params.id)
548561
if not push_notification_config or not push_notification_config[0]:
549-
raise ServerError(error=InternalError(message='Push notification config not found'))
562+
raise ServerError(
563+
error=InternalError(message="Push notification config not found")
564+
)
550565
return TaskPushNotificationConfig(
551566
task_id=params.id,
552567
push_notification_config=push_notification_config[0],
@@ -563,7 +578,7 @@ async def on_resubscribe_to_task(
563578
if task.status.state in TERMINAL_TASK_STATES:
564579
raise ServerError(
565580
error=InvalidParamsError(
566-
message=f'Task {task.id} is in terminal state: {task.status.state.value}'
581+
message=f"Task {task.id} is in terminal state: {task.status.state.value}"
567582
)
568583
)
569584
task_manager = TaskManager(
@@ -591,7 +606,9 @@ async def on_list_task_push_notification_config(
591606
task: Task | None = await self.task_store.get(params.id, context)
592607
if not task:
593608
raise ServerError(error=TaskNotFoundError())
594-
push_notification_config_list = await self._push_config_store.get_info(params.id)
609+
push_notification_config_list = await self._push_config_store.get_info(
610+
params.id
611+
)
595612
return [
596613
TaskPushNotificationConfig(task_id=params.id, push_notification_config=cfg)
597614
for cfg in push_notification_config_list
@@ -607,4 +624,6 @@ async def on_delete_task_push_notification_config(
607624
task: Task | None = await self.task_store.get(params.id, context)
608625
if not task:
609626
raise ServerError(error=TaskNotFoundError())
610-
await self._push_config_store.delete_info(params.id, params.push_notification_config_id)
627+
await self._push_config_store.delete_info(
628+
params.id, params.push_notification_config_id
629+
)

0 commit comments

Comments
 (0)