|
1 | 1 | import functools |
2 | 2 | import inspect |
| 3 | +import logging |
3 | 4 |
|
4 | 5 | from abc import ABC, abstractmethod |
5 | 6 | from collections.abc import AsyncGenerator, Callable |
|
30 | 31 | from a2a.utils.proto_utils import validate_proto_required_fields |
31 | 32 |
|
32 | 33 |
|
| 34 | +logger = logging.getLogger(__name__) |
| 35 | + |
| 36 | + |
33 | 37 | class RequestHandler(ABC): |
34 | 38 | """A2A request handler interface. |
35 | 39 |
|
@@ -284,3 +288,130 @@ async def async_wrapper( |
284 | 288 | return await method(self, params, context, *args, **kwargs) |
285 | 289 |
|
286 | 290 | return async_wrapper |
| 291 | + |
| 292 | + |
| 293 | +def validate( |
| 294 | + expression: Callable[[Any], bool], |
| 295 | + error_message: str | None = None, |
| 296 | + error_type: type[Exception] = UnsupportedOperationError, |
| 297 | +) -> Callable: |
| 298 | + """Decorator that validates if a given expression evaluates to True. |
| 299 | +
|
| 300 | + Typically used on class methods to check capabilities or configuration |
| 301 | + before executing the method's logic. If the expression is False, |
| 302 | + the specified `error_type` (defaults to `UnsupportedOperationError`) is raised. |
| 303 | +
|
| 304 | + Args: |
| 305 | + expression: A callable that takes the instance (`self`) as its argument |
| 306 | + and returns a boolean. |
| 307 | + error_message: An optional custom error message for the error raised. |
| 308 | + If None, the string representation of the expression will be used. |
| 309 | + error_type: The exception class to raise on validation failure. |
| 310 | + Must take a `message` keyword argument (inherited from A2AError). |
| 311 | +
|
| 312 | + Examples: |
| 313 | + Demonstrating with an async method: |
| 314 | + >>> import asyncio |
| 315 | + >>> from a2a.utils.errors import UnsupportedOperationError |
| 316 | + >>> |
| 317 | + >>> class MyAgent: |
| 318 | + ... def __init__(self, streaming_enabled: bool): |
| 319 | + ... self.streaming_enabled = streaming_enabled |
| 320 | + ... |
| 321 | + ... @validate( |
| 322 | + ... lambda self: self.streaming_enabled, |
| 323 | + ... 'Streaming is not enabled for this agent', |
| 324 | + ... ) |
| 325 | + ... async def stream_response(self, message: str): |
| 326 | + ... return f'Streaming: {message}' |
| 327 | + >>> |
| 328 | + >>> async def run_async_test(): |
| 329 | + ... # Successful call |
| 330 | + ... agent_ok = MyAgent(streaming_enabled=True) |
| 331 | + ... result = await agent_ok.stream_response('hello') |
| 332 | + ... print(result) |
| 333 | + ... |
| 334 | + ... # Call that fails validation |
| 335 | + ... agent_fail = MyAgent(streaming_enabled=False) |
| 336 | + ... try: |
| 337 | + ... await agent_fail.stream_response('world') |
| 338 | + ... except UnsupportedOperationError as e: |
| 339 | + ... print(e.message) |
| 340 | + >>> |
| 341 | + >>> asyncio.run(run_async_test()) |
| 342 | + Streaming: hello |
| 343 | + Streaming is not enabled for this agent |
| 344 | +
|
| 345 | + Demonstrating with a sync method: |
| 346 | + >>> class SecureAgent: |
| 347 | + ... def __init__(self): |
| 348 | + ... self.auth_enabled = False |
| 349 | + ... |
| 350 | + ... @validate( |
| 351 | + ... lambda self: self.auth_enabled, |
| 352 | + ... 'Authentication must be enabled for this operation', |
| 353 | + ... ) |
| 354 | + ... def secure_operation(self, data: str): |
| 355 | + ... return f'Processing secure data: {data}' |
| 356 | + >>> |
| 357 | + >>> # Error case example |
| 358 | + >>> agent = SecureAgent() |
| 359 | + >>> try: |
| 360 | + ... agent.secure_operation('secret') |
| 361 | + ... except UnsupportedOperationError as e: |
| 362 | + ... print(e.message) |
| 363 | + Authentication must be enabled for this operation |
| 364 | +
|
| 365 | + Note: |
| 366 | + This decorator works with both sync and async methods automatically. |
| 367 | + """ |
| 368 | + |
| 369 | + def decorator(function: Callable) -> Callable: |
| 370 | + if inspect.isasyncgenfunction(function): |
| 371 | + |
| 372 | + @functools.wraps(function) |
| 373 | + async def async_gen_wrapper(self: Any, *args, **kwargs) -> Any: |
| 374 | + if not expression(self): |
| 375 | + final_message = error_message or str(expression) |
| 376 | + logger.error('Validation failure: %s', final_message) |
| 377 | + raise ( |
| 378 | + error_type(final_message) |
| 379 | + if final_message |
| 380 | + else error_type |
| 381 | + ) |
| 382 | + inner = function(self, *args, **kwargs) |
| 383 | + try: |
| 384 | + async for item in inner: |
| 385 | + yield item |
| 386 | + finally: |
| 387 | + await inner.aclose() |
| 388 | + |
| 389 | + return async_gen_wrapper |
| 390 | + |
| 391 | + if inspect.iscoroutinefunction(function): |
| 392 | + |
| 393 | + @functools.wraps(function) |
| 394 | + async def async_wrapper(self: Any, *args, **kwargs) -> Any: |
| 395 | + if not expression(self): |
| 396 | + final_message = error_message or str(expression) |
| 397 | + logger.error('Validation failure: %s', final_message) |
| 398 | + raise ( |
| 399 | + error_type(final_message) |
| 400 | + if final_message |
| 401 | + else error_type |
| 402 | + ) |
| 403 | + return await function(self, *args, **kwargs) |
| 404 | + |
| 405 | + return async_wrapper |
| 406 | + |
| 407 | + @functools.wraps(function) |
| 408 | + def sync_wrapper(self: Any, *args, **kwargs) -> Any: |
| 409 | + if not expression(self): |
| 410 | + final_message = error_message or str(expression) |
| 411 | + logger.error('Validation failure: %s', final_message) |
| 412 | + raise error_type(final_message) |
| 413 | + return function(self, *args, **kwargs) |
| 414 | + |
| 415 | + return sync_wrapper |
| 416 | + |
| 417 | + return decorator |
0 commit comments