Skip to content

Commit f458933

Browse files
author
Fankouzu
committed
refactor: modernize TLS API, deduplicate validation, and improve error context
1 parent 7a6397e commit f458933

2 files changed

Lines changed: 65 additions & 61 deletions

File tree

src/a2a/client/tls.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,15 @@ def create_ssl_context(self) -> ssl.SSLContext:
5454
if isinstance(self.verify, ssl.SSLContext):
5555
return self.verify
5656

57-
protocol_map = {
58-
'TLSv1_2': ssl.PROTOCOL_TLSv1_2,
59-
'TLSv1_3': ssl.PROTOCOL_TLS_CLIENT,
60-
}
61-
protocol = protocol_map.get(self.min_version, ssl.PROTOCOL_TLS_CLIENT)
57+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
6258

63-
context = ssl.SSLContext(protocol)
59+
version_map = {
60+
'TLSv1_2': ssl.TLSVersion.TLSv1_2,
61+
'TLSv1_3': ssl.TLSVersion.TLSv1_3,
62+
}
63+
context.minimum_version = version_map.get(
64+
self.min_version, ssl.TLSVersion.TLSv1_2
65+
)
6466

6567
if self.ca_cert:
6668
context.load_verify_locations(cafile=str(self.ca_cert))
@@ -220,15 +222,17 @@ def create_server_ssl_context(
220222
Returns:
221223
Configured ssl.SSLContext for server use.
222224
"""
223-
protocol_map = {
224-
'TLSv1_2': ssl.PROTOCOL_TLSv1_2,
225-
'TLSv1_3': ssl.PROTOCOL_TLS_SERVER,
226-
}
227-
protocol = protocol_map.get(min_version, ssl.PROTOCOL_TLS_SERVER)
228-
229-
context = ssl.SSLContext(protocol)
225+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
230226
context.load_cert_chain(str(cert_file), str(key_file))
231227

228+
version_map = {
229+
'TLSv1_2': ssl.TLSVersion.TLSv1_2,
230+
'TLSv1_3': ssl.TLSVersion.TLSv1_3,
231+
}
232+
context.minimum_version = version_map.get(
233+
min_version, ssl.TLSVersion.TLSv1_2
234+
)
235+
232236
if ca_cert:
233237
context.load_verify_locations(cafile=str(ca_cert))
234238

src/a2a/validation.py

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,52 @@ def validate_message(
227227
) from e
228228

229229

230+
def _validate_against_types(
231+
data: dict[str, Any],
232+
model_types: tuple[type[BaseModel], ...],
233+
category_name: str,
234+
) -> BaseModel:
235+
"""Validate data against multiple model types and return first match.
236+
237+
Args:
238+
data: Raw data to validate.
239+
model_types: Tuple of model types to try.
240+
category_name: Name of the category for error messages (e.g., 'request', 'response').
241+
242+
Returns:
243+
The validated model instance.
244+
245+
Raises:
246+
ValidationError: If data doesn't match any of the provided types.
247+
"""
248+
errors: list[dict[str, Any]] = []
249+
250+
for model_type in model_types:
251+
try:
252+
return validate_message(data, model_type, strict=False)
253+
except ValidationError as e:
254+
errors.append(
255+
{
256+
'type': model_type.__name__,
257+
'path': e.errors[0].get('path', []) if e.errors else [],
258+
'message': e.errors[0].get('message', str(e))
259+
if e.errors
260+
else str(e),
261+
}
262+
)
263+
264+
error_details = '; '.join(
265+
f'{e["type"]}: {e["message"]} (path: {".".join(map(str, e["path"])) or "root"})'
266+
for e in errors
267+
)
268+
raise ValidationError(
269+
f'Data does not match any known A2A {category_name} type. '
270+
f'Attempted types: {[e["type"] for e in errors]}. Details: {error_details}',
271+
errors=errors,
272+
instance=data,
273+
)
274+
275+
230276
def validate_request(data: dict[str, Any]) -> BaseModel:
231277
"""Validate and parse an A2A request message.
232278
@@ -252,33 +298,7 @@ def validate_request(data: dict[str, Any]) -> BaseModel:
252298
TaskResubscriptionRequest,
253299
)
254300

255-
errors: list[dict[str, Any]] = []
256-
257-
for model_type in request_types:
258-
try:
259-
return validate_message(data, model_type, strict=False)
260-
except ValidationError:
261-
continue
262-
263-
raise ValidationError(
264-
'Data does not match any known A2A request type',
265-
errors=errors,
266-
instance=data,
267-
)
268-
269-
errors: list[dict[str, Any]] = []
270-
271-
for model_type in request_types:
272-
try:
273-
return validate_message(data, model_type, strict=False)
274-
except ValidationError:
275-
continue
276-
277-
raise ValidationError(
278-
'Data does not match any known A2A request type',
279-
errors=errors,
280-
instance=data,
281-
)
301+
return _validate_against_types(data, request_types, 'request')
282302

283303

284304
def validate_response(data: dict[str, Any]) -> BaseModel:
@@ -302,27 +322,7 @@ def validate_response(data: dict[str, Any]) -> BaseModel:
302322
GetTaskPushNotificationConfigResponse,
303323
)
304324

305-
for model_type in response_types:
306-
try:
307-
return validate_message(data, model_type, strict=False)
308-
except ValidationError:
309-
continue
310-
311-
raise ValidationError(
312-
'Data does not match any known A2A response type',
313-
instance=data,
314-
)
315-
316-
for model_type in response_types:
317-
try:
318-
return validate_message(data, model_type, strict=False)
319-
except ValidationError:
320-
continue
321-
322-
raise ValidationError(
323-
'Data does not match any known A2A response type',
324-
instance=data,
325-
)
325+
return _validate_against_types(data, response_types, 'response')
326326

327327

328328
class MessageValidator:

0 commit comments

Comments
 (0)