Skip to content

Commit 191c970

Browse files
committed
refactor: simplify interceptor argument types by removing generic type variables and union aliases.
1 parent 0dad45d commit 191c970

3 files changed

Lines changed: 41 additions & 136 deletions

File tree

src/a2a/client/auth/interceptor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from a2a.client.auth.credentials import CredentialService
44
from a2a.client.client import ClientCallContext
55
from a2a.client.interceptors import (
6+
AfterArgs,
7+
BeforeArgs,
68
ClientCallInterceptor,
7-
UnionAfterArgs,
8-
UnionBeforeArgs,
99
)
1010

1111
logger = logging.getLogger(__name__)
@@ -20,7 +20,7 @@ class AuthInterceptor(ClientCallInterceptor):
2020
def __init__(self, credential_service: CredentialService):
2121
self._credential_service = credential_service
2222

23-
async def before(self, args: UnionBeforeArgs) -> None:
23+
async def before(self, args: BeforeArgs) -> None:
2424
"""Applies authentication headers to the request if credentials are available."""
2525
agent_card = args.agent_card
2626

@@ -94,5 +94,5 @@ async def before(self, args: UnionBeforeArgs) -> None:
9494

9595
# Note: Other cases like API keys in query/cookie are not handled and will be skipped.
9696

97-
async def after(self, args: UnionAfterArgs) -> None:
97+
async def after(self, args: AfterArgs) -> None:
9898
"""Invoked after the method is executed."""

src/a2a/client/base_client.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
2-
from typing import Any, cast
2+
from typing import Any
33

44
from a2a.client.client import (
55
Client,
@@ -15,11 +15,6 @@
1515
ClientCallInput,
1616
ClientCallInterceptor,
1717
ClientCallResult,
18-
M,
19-
P,
20-
R,
21-
UnionAfterArgs,
22-
UnionBeforeArgs,
2318
)
2419
from a2a.client.transports.base import ClientTransport
2520
from a2a.types.a2a_pb2 import (
@@ -150,7 +145,7 @@ async def _process_stream(
150145
agent_card=self._card,
151146
context=before_args.context,
152147
)
153-
await self._intercept_after(cast('UnionAfterArgs', after_args))
148+
await self._intercept_after(after_args)
154149
intercepted_response = after_args.result.value
155150
client_event = await self._format_stream_event(
156151
intercepted_response, tracker
@@ -400,21 +395,21 @@ async def close(self) -> None:
400395

401396
async def _execute_with_interceptors(
402397
self,
403-
input_data: ClientCallInput[M, P],
398+
input_data: ClientCallInput,
404399
context: ClientCallContext | None,
405-
transport_call: Callable[[P, ClientCallContext | None], Awaitable[R]],
406-
) -> R:
407-
before_args: BeforeArgs[M, P, R] = BeforeArgs(
400+
transport_call: Callable[
401+
[Any, ClientCallContext | None], Awaitable[Any]
402+
],
403+
) -> Any:
404+
before_args = BeforeArgs(
408405
input=input_data,
409406
agent_card=self._card,
410407
context=context,
411408
)
412-
before_result = await self._intercept_before(
413-
cast('UnionBeforeArgs', before_args)
414-
)
409+
before_result = await self._intercept_before(before_args)
415410

416411
if before_result is not None:
417-
early_after_args: AfterArgs[M, R] = AfterArgs(
412+
early_after_args = AfterArgs(
418413
result=ClientCallResult(
419414
method=input_data.method,
420415
value=before_result['early_return'].value,
@@ -423,7 +418,7 @@ async def _execute_with_interceptors(
423418
context=before_args.context,
424419
)
425420
await self._intercept_after(
426-
cast('UnionAfterArgs', early_after_args),
421+
early_after_args,
427422
before_result['executed'],
428423
)
429424
return early_after_args.result.value
@@ -432,42 +427,38 @@ async def _execute_with_interceptors(
432427
before_args.input.value, before_args.context
433428
)
434429

435-
after_args: AfterArgs[M, R] = AfterArgs(
430+
after_args = AfterArgs(
436431
result=ClientCallResult(method=input_data.method, value=result),
437432
agent_card=self._card,
438433
context=before_args.context,
439434
)
440-
await self._intercept_after(cast('UnionAfterArgs', after_args))
435+
await self._intercept_after(after_args)
441436

442437
return after_args.result.value
443438

444439
async def _execute_stream_with_interceptors(
445440
self,
446-
input_data: ClientCallInput[M, P],
441+
input_data: ClientCallInput,
447442
context: ClientCallContext | None,
448443
transport_call: Callable[
449-
[P, ClientCallContext | None], AsyncIterator[StreamResponse]
444+
[Any, ClientCallContext | None], AsyncIterator[StreamResponse]
450445
],
451446
) -> AsyncIterator[ClientEvent]:
452447

453-
before_args: BeforeArgs[M, P, StreamResponse] = BeforeArgs(
448+
before_args = BeforeArgs(
454449
input=input_data,
455450
agent_card=self._card,
456451
context=context,
457452
)
458-
before_result = await self._intercept_before(
459-
cast('UnionBeforeArgs', before_args)
460-
)
453+
before_result = await self._intercept_before(before_args)
461454

462455
if before_result:
463-
after_args: AfterArgs[M, StreamResponse] = AfterArgs(
456+
after_args = AfterArgs(
464457
result=before_result['early_return'],
465458
agent_card=self._card,
466459
context=before_args.context,
467460
)
468-
await self._intercept_after(
469-
cast('UnionAfterArgs', after_args), before_result['executed']
470-
)
461+
await self._intercept_after(after_args, before_result['executed'])
471462

472463
tracker = ClientTaskManager()
473464
yield await self._format_stream_event(
@@ -482,7 +473,7 @@ async def _execute_stream_with_interceptors(
482473

483474
async def _intercept_before(
484475
self,
485-
args: UnionBeforeArgs,
476+
args: BeforeArgs,
486477
) -> dict[str, Any] | None:
487478
if not self._interceptors:
488479
return None
@@ -499,7 +490,7 @@ async def _intercept_before(
499490

500491
async def _intercept_after(
501492
self,
502-
args: UnionAfterArgs,
493+
args: AfterArgs,
503494
interceptors: list[ClientCallInterceptor] | None = None,
504495
) -> None:
505496
interceptors_to_use = (

src/a2a/client/interceptors.py

Lines changed: 16 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,151 +1,65 @@
11
from __future__ import annotations
22

3-
from abc import ABC, abstractmethod
3+
from abc import abstractmethod
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Generic, Literal, TypeAlias, TypeVar
5+
from typing import TYPE_CHECKING, Any
66

77

88
if TYPE_CHECKING:
99
from a2a.client.client import ClientCallContext
1010

1111
from a2a.types.a2a_pb2 import ( # noqa: TC001
1212
AgentCard,
13-
CancelTaskRequest,
14-
DeleteTaskPushNotificationConfigRequest,
15-
GetExtendedAgentCardRequest,
16-
GetTaskPushNotificationConfigRequest,
17-
GetTaskRequest,
18-
ListTaskPushNotificationConfigsRequest,
19-
ListTaskPushNotificationConfigsResponse,
20-
ListTasksRequest,
21-
ListTasksResponse,
22-
SendMessageRequest,
23-
SendMessageResponse,
24-
StreamResponse,
25-
SubscribeToTaskRequest,
26-
Task,
27-
TaskPushNotificationConfig,
2813
)
2914

3015

31-
M = TypeVar('M')
32-
P = TypeVar('P')
33-
R = TypeVar('R')
34-
35-
3616
@dataclass
37-
class ClientCallInput(Generic[M, P]):
17+
class ClientCallInput:
3818
"""Represents the method and its associated input arguments payload."""
3919

40-
method: M
41-
value: P
20+
method: str
21+
value: Any
4222

4323

4424
@dataclass
45-
class ClientCallResult(Generic[M, R]):
25+
class ClientCallResult:
4626
"""Represents the method and its associated result payload."""
4727

48-
method: M
49-
value: R
28+
method: str
29+
value: Any
5030

5131

5232
@dataclass
53-
class BeforeArgs(Generic[M, P, R]):
33+
class BeforeArgs:
5434
"""Arguments passed to the interceptor before a method call."""
5535

56-
input: ClientCallInput[M, P]
36+
input: ClientCallInput
5737
agent_card: AgentCard
5838
context: ClientCallContext | None = None
59-
early_return: ClientCallResult[M, R] | None = None
39+
early_return: ClientCallResult | None = None
6040

6141

6242
@dataclass
63-
class AfterArgs(Generic[M, R]):
43+
class AfterArgs:
6444
"""Arguments passed to the interceptor after a method call completes."""
6545

66-
result: ClientCallResult[M, R]
46+
result: ClientCallResult
6747
agent_card: AgentCard
6848
context: ClientCallContext | None = None
6949
early_return: bool = False
7050

7151

72-
class ClientCallInterceptor(ABC, Generic[M, P, R]):
52+
class ClientCallInterceptor:
7353
"""An abstract base class for client-side call interceptors.
7454
7555
Interceptors can inspect and modify requests before they are sent,
7656
which is ideal for concerns like authentication, logging, or tracing.
7757
"""
7858

7959
@abstractmethod
80-
async def before(self, args: UnionBeforeArgs) -> None:
60+
async def before(self, args: BeforeArgs) -> None:
8161
"""Invoked before transport method."""
8262

8363
@abstractmethod
84-
async def after(self, args: UnionAfterArgs) -> None:
64+
async def after(self, args: AfterArgs) -> None:
8565
"""Invoked after transport method."""
86-
87-
88-
UnionBeforeArgs: TypeAlias = (
89-
BeforeArgs[
90-
Literal['send_message'], 'SendMessageRequest', 'SendMessageResponse'
91-
]
92-
| BeforeArgs[
93-
Literal['send_message_streaming'],
94-
'SendMessageRequest',
95-
'StreamResponse',
96-
]
97-
| BeforeArgs[Literal['get_task'], 'GetTaskRequest', 'Task']
98-
| BeforeArgs[Literal['list_tasks'], 'ListTasksRequest', 'ListTasksResponse']
99-
| BeforeArgs[Literal['cancel_task'], 'CancelTaskRequest', 'Task']
100-
| BeforeArgs[
101-
Literal['create_task_push_notification_config'],
102-
'TaskPushNotificationConfig',
103-
'TaskPushNotificationConfig',
104-
]
105-
| BeforeArgs[
106-
Literal['get_task_push_notification_config'],
107-
'GetTaskPushNotificationConfigRequest',
108-
'TaskPushNotificationConfig',
109-
]
110-
| BeforeArgs[
111-
Literal['list_task_push_notification_configs'],
112-
'ListTaskPushNotificationConfigsRequest',
113-
'ListTaskPushNotificationConfigsResponse',
114-
]
115-
| BeforeArgs[
116-
Literal['delete_task_push_notification_config'],
117-
'DeleteTaskPushNotificationConfigRequest',
118-
None,
119-
]
120-
| BeforeArgs[
121-
Literal['subscribe'], 'SubscribeToTaskRequest', 'StreamResponse'
122-
]
123-
| BeforeArgs[
124-
Literal['get_extended_agent_card'],
125-
'GetExtendedAgentCardRequest',
126-
'AgentCard',
127-
]
128-
)
129-
130-
UnionAfterArgs: TypeAlias = (
131-
AfterArgs[Literal['send_message'], 'SendMessageResponse']
132-
| AfterArgs[Literal['send_message_streaming'], 'StreamResponse']
133-
| AfterArgs[Literal['get_task'], 'Task']
134-
| AfterArgs[Literal['list_tasks'], 'ListTasksResponse']
135-
| AfterArgs[Literal['cancel_task'], 'Task']
136-
| AfterArgs[
137-
Literal['create_task_push_notification_config'],
138-
'TaskPushNotificationConfig',
139-
]
140-
| AfterArgs[
141-
Literal['get_task_push_notification_config'],
142-
'TaskPushNotificationConfig',
143-
]
144-
| AfterArgs[
145-
Literal['list_task_push_notification_configs'],
146-
'ListTaskPushNotificationConfigsResponse',
147-
]
148-
| AfterArgs[Literal['delete_task_push_notification_config'], None]
149-
| AfterArgs[Literal['subscribe'], 'StreamResponse']
150-
| AfterArgs[Literal['get_extended_agent_card'], 'AgentCard']
151-
)

0 commit comments

Comments
 (0)