|
3 | 3 |
|
4 | 4 | from abc import ABC, abstractmethod |
5 | 5 | from collections.abc import AsyncIterator, Callable, Coroutine |
6 | | -from typing import Any |
| 6 | +from typing import TYPE_CHECKING, Any |
7 | 7 |
|
8 | 8 | import httpx |
9 | 9 |
|
|
24 | 24 | ) |
25 | 25 |
|
26 | 26 |
|
| 27 | +if TYPE_CHECKING: |
| 28 | + from a2a.client.tls import TLSConfig |
| 29 | + |
| 30 | + |
27 | 31 | logger = logging.getLogger(__name__) |
28 | 32 |
|
29 | 33 |
|
@@ -70,6 +74,46 @@ class ClientConfig: |
70 | 74 | extensions: list[str] = dataclasses.field(default_factory=list) |
71 | 75 | """A list of extension URIs the client supports.""" |
72 | 76 |
|
| 77 | + tls_config: 'TLSConfig | None' = None |
| 78 | + """TLS/SSL configuration for secure communication. If provided, this |
| 79 | + will be used to configure secure connections for HTTP and gRPC |
| 80 | + transports. Ignored if httpx_client or grpc_channel_factory is |
| 81 | + explicitly provided.""" |
| 82 | + |
| 83 | + validate_messages: bool = False |
| 84 | + """Whether to validate messages against JSON Schema before sending |
| 85 | + and after receiving. Useful for protocol compliance testing.""" |
| 86 | + |
| 87 | + def get_httpx_client(self) -> httpx.AsyncClient: |
| 88 | + """Get or create an httpx client with TLS configuration. |
| 89 | +
|
| 90 | + Returns: |
| 91 | + Configured httpx.AsyncClient instance. |
| 92 | + """ |
| 93 | + if self.httpx_client is not None: |
| 94 | + return self.httpx_client |
| 95 | + |
| 96 | + if self.tls_config is not None: |
| 97 | + return self.tls_config.create_httpx_client() |
| 98 | + |
| 99 | + return httpx.AsyncClient() |
| 100 | + |
| 101 | + def get_grpc_channel_factory(self) -> Callable[[str], Channel] | None: |
| 102 | + """Get or create a gRPC channel factory with TLS configuration. |
| 103 | +
|
| 104 | + Returns: |
| 105 | + A callable that creates gRPC channels, or None. |
| 106 | + """ |
| 107 | + if self.grpc_channel_factory is not None: |
| 108 | + return self.grpc_channel_factory |
| 109 | + |
| 110 | + if self.tls_config is not None: |
| 111 | + from a2a.client.tls import create_grpc_channel_factory |
| 112 | + |
| 113 | + return create_grpc_channel_factory(self.tls_config) |
| 114 | + |
| 115 | + return None |
| 116 | + |
73 | 117 |
|
74 | 118 | UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None |
75 | 119 | # Alias for emitted events from client |
|
0 commit comments