Skip to content

Commit edf8ccb

Browse files
committed
move validate function to request_handler since it is only used in request handlers for now
1 parent 93720a9 commit edf8ccb

4 files changed

Lines changed: 134 additions & 151 deletions

File tree

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from a2a.server.request_handlers.request_handler import (
2323
RequestHandler,
24+
validate,
2425
validate_request_params,
2526
)
2627
from a2a.server.tasks import (
@@ -58,7 +59,7 @@
5859
TaskNotFoundError,
5960
UnsupportedOperationError,
6061
)
61-
from a2a.utils.helpers import maybe_await, validate
62+
from a2a.utils.helpers import maybe_await
6263
from a2a.utils.task import (
6364
apply_history_length,
6465
validate_history_length,

src/a2a/server/request_handlers/request_handler.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22
import inspect
3+
import logging
34

45
from abc import ABC, abstractmethod
56
from collections.abc import AsyncGenerator, Callable
@@ -30,6 +31,9 @@
3031
from a2a.utils.proto_utils import validate_proto_required_fields
3132

3233

34+
logger = logging.getLogger(__name__)
35+
36+
3337
class RequestHandler(ABC):
3438
"""A2A request handler interface.
3539
@@ -284,3 +288,130 @@ async def async_wrapper(
284288
return await method(self, params, context, *args, **kwargs)
285289

286290
return async_wrapper
291+
292+
293+
def validate(
294+
expression: Callable[[Any], bool],
295+
error_message: str | None = None,
296+
error_type: type[Exception] = UnsupportedOperationError,
297+
) -> Callable:
298+
"""Decorator that validates if a given expression evaluates to True.
299+
300+
Typically used on class methods to check capabilities or configuration
301+
before executing the method's logic. If the expression is False,
302+
the specified `error_type` (defaults to `UnsupportedOperationError`) is raised.
303+
304+
Args:
305+
expression: A callable that takes the instance (`self`) as its argument
306+
and returns a boolean.
307+
error_message: An optional custom error message for the error raised.
308+
If None, the string representation of the expression will be used.
309+
error_type: The exception class to raise on validation failure.
310+
Must take a `message` keyword argument (inherited from A2AError).
311+
312+
Examples:
313+
Demonstrating with an async method:
314+
>>> import asyncio
315+
>>> from a2a.utils.errors import UnsupportedOperationError
316+
>>>
317+
>>> class MyAgent:
318+
... def __init__(self, streaming_enabled: bool):
319+
... self.streaming_enabled = streaming_enabled
320+
...
321+
... @validate(
322+
... lambda self: self.streaming_enabled,
323+
... 'Streaming is not enabled for this agent',
324+
... )
325+
... async def stream_response(self, message: str):
326+
... return f'Streaming: {message}'
327+
>>>
328+
>>> async def run_async_test():
329+
... # Successful call
330+
... agent_ok = MyAgent(streaming_enabled=True)
331+
... result = await agent_ok.stream_response('hello')
332+
... print(result)
333+
...
334+
... # Call that fails validation
335+
... agent_fail = MyAgent(streaming_enabled=False)
336+
... try:
337+
... await agent_fail.stream_response('world')
338+
... except UnsupportedOperationError as e:
339+
... print(e.message)
340+
>>>
341+
>>> asyncio.run(run_async_test())
342+
Streaming: hello
343+
Streaming is not enabled for this agent
344+
345+
Demonstrating with a sync method:
346+
>>> class SecureAgent:
347+
... def __init__(self):
348+
... self.auth_enabled = False
349+
...
350+
... @validate(
351+
... lambda self: self.auth_enabled,
352+
... 'Authentication must be enabled for this operation',
353+
... )
354+
... def secure_operation(self, data: str):
355+
... return f'Processing secure data: {data}'
356+
>>>
357+
>>> # Error case example
358+
>>> agent = SecureAgent()
359+
>>> try:
360+
... agent.secure_operation('secret')
361+
... except UnsupportedOperationError as e:
362+
... print(e.message)
363+
Authentication must be enabled for this operation
364+
365+
Note:
366+
This decorator works with both sync and async methods automatically.
367+
"""
368+
369+
def decorator(function: Callable) -> Callable:
370+
if inspect.isasyncgenfunction(function):
371+
372+
@functools.wraps(function)
373+
async def async_gen_wrapper(self: Any, *args, **kwargs) -> Any:
374+
if not expression(self):
375+
final_message = error_message or str(expression)
376+
logger.error('Validation failure: %s', final_message)
377+
raise (
378+
error_type(final_message)
379+
if final_message
380+
else error_type
381+
)
382+
inner = function(self, *args, **kwargs)
383+
try:
384+
async for item in inner:
385+
yield item
386+
finally:
387+
await inner.aclose()
388+
389+
return async_gen_wrapper
390+
391+
if inspect.iscoroutinefunction(function):
392+
393+
@functools.wraps(function)
394+
async def async_wrapper(self: Any, *args, **kwargs) -> Any:
395+
if not expression(self):
396+
final_message = error_message or str(expression)
397+
logger.error('Validation failure: %s', final_message)
398+
raise (
399+
error_type(final_message)
400+
if final_message
401+
else error_type
402+
)
403+
return await function(self, *args, **kwargs)
404+
405+
return async_wrapper
406+
407+
@functools.wraps(function)
408+
def sync_wrapper(self: Any, *args, **kwargs) -> Any:
409+
if not expression(self):
410+
final_message = error_message or str(expression)
411+
logger.error('Validation failure: %s', final_message)
412+
raise error_type(final_message)
413+
return function(self, *args, **kwargs)
414+
415+
return sync_wrapper
416+
417+
return decorator

src/a2a/utils/helpers.py

Lines changed: 1 addition & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
TaskStatus,
2525
)
2626
from a2a.utils import constants
27-
from a2a.utils.errors import UnsupportedOperationError, VersionNotSupportedError
27+
from a2a.utils.errors import VersionNotSupportedError
2828
from a2a.utils.telemetry import trace_function
2929

3030

@@ -134,133 +134,6 @@ def build_text_artifact(text: str, artifact_id: str) -> Artifact:
134134
return Artifact(parts=[part], artifact_id=artifact_id)
135135

136136

137-
def validate(
138-
expression: Callable[[Any], bool],
139-
error_message: str | None = None,
140-
error_type: type[Exception] = UnsupportedOperationError,
141-
) -> Callable:
142-
"""Decorator that validates if a given expression evaluates to True.
143-
144-
Typically used on class methods to check capabilities or configuration
145-
before executing the method's logic. If the expression is False,
146-
the specified `error_type` (defaults to `UnsupportedOperationError`) is raised.
147-
148-
Args:
149-
expression: A callable that takes the instance (`self`) as its argument
150-
and returns a boolean.
151-
error_message: An optional custom error message for the error raised.
152-
If None, the string representation of the expression will be used.
153-
error_type: The exception class to raise on validation failure.
154-
Must take a `message` keyword argument (inherited from A2AError).
155-
156-
Examples:
157-
Demonstrating with an async method:
158-
>>> import asyncio
159-
>>> from a2a.utils.errors import UnsupportedOperationError
160-
>>>
161-
>>> class MyAgent:
162-
... def __init__(self, streaming_enabled: bool):
163-
... self.streaming_enabled = streaming_enabled
164-
...
165-
... @validate(
166-
... lambda self: self.streaming_enabled,
167-
... 'Streaming is not enabled for this agent',
168-
... )
169-
... async def stream_response(self, message: str):
170-
... return f'Streaming: {message}'
171-
>>>
172-
>>> async def run_async_test():
173-
... # Successful call
174-
... agent_ok = MyAgent(streaming_enabled=True)
175-
... result = await agent_ok.stream_response('hello')
176-
... print(result)
177-
...
178-
... # Call that fails validation
179-
... agent_fail = MyAgent(streaming_enabled=False)
180-
... try:
181-
... await agent_fail.stream_response('world')
182-
... except UnsupportedOperationError as e:
183-
... print(e.message)
184-
>>>
185-
>>> asyncio.run(run_async_test())
186-
Streaming: hello
187-
Streaming is not enabled for this agent
188-
189-
Demonstrating with a sync method:
190-
>>> class SecureAgent:
191-
... def __init__(self):
192-
... self.auth_enabled = False
193-
...
194-
... @validate(
195-
... lambda self: self.auth_enabled,
196-
... 'Authentication must be enabled for this operation',
197-
... )
198-
... def secure_operation(self, data: str):
199-
... return f'Processing secure data: {data}'
200-
>>>
201-
>>> # Error case example
202-
>>> agent = SecureAgent()
203-
>>> try:
204-
... agent.secure_operation('secret')
205-
... except UnsupportedOperationError as e:
206-
... print(e.message)
207-
Authentication must be enabled for this operation
208-
209-
Note:
210-
This decorator works with both sync and async methods automatically.
211-
"""
212-
213-
def decorator(function: Callable) -> Callable:
214-
if inspect.isasyncgenfunction(function):
215-
216-
@functools.wraps(function)
217-
async def async_gen_wrapper(self: Any, *args, **kwargs) -> Any:
218-
if not expression(self):
219-
final_message = error_message or str(expression)
220-
logger.error('Validation failure: %s', final_message)
221-
raise (
222-
error_type(final_message)
223-
if final_message
224-
else error_type
225-
)
226-
inner = function(self, *args, **kwargs)
227-
try:
228-
async for item in inner:
229-
yield item
230-
finally:
231-
await inner.aclose()
232-
233-
return async_gen_wrapper
234-
235-
if inspect.iscoroutinefunction(function):
236-
237-
@functools.wraps(function)
238-
async def async_wrapper(self: Any, *args, **kwargs) -> Any:
239-
if not expression(self):
240-
final_message = error_message or str(expression)
241-
logger.error('Validation failure: %s', final_message)
242-
raise (
243-
error_type(final_message)
244-
if final_message
245-
else error_type
246-
)
247-
return await function(self, *args, **kwargs)
248-
249-
return async_wrapper
250-
251-
@functools.wraps(function)
252-
def sync_wrapper(self: Any, *args, **kwargs) -> Any:
253-
if not expression(self):
254-
final_message = error_message or str(expression)
255-
logger.error('Validation failure: %s', final_message)
256-
raise error_type(final_message)
257-
return function(self, *args, **kwargs)
258-
259-
return sync_wrapper
260-
261-
return decorator
262-
263-
264137
def are_modalities_compatible(
265138
server_output_modes: list[str] | None, client_output_modes: list[str] | None
266139
) -> bool:

tests/utils/test_helpers.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
build_text_artifact,
3030
canonicalize_agent_card,
3131
create_task_obj,
32-
validate,
3332
)
3433

3534

@@ -249,27 +248,6 @@ def test_build_text_artifact():
249248
assert artifact.parts[0].text == text
250249

251250

252-
# Test validate decorator
253-
def test_validate_decorator():
254-
class TestClass:
255-
condition = True
256-
257-
@validate(lambda self: self.condition, 'Condition not met')
258-
def test_method(self) -> str:
259-
return 'Success'
260-
261-
obj = TestClass()
262-
263-
# Test passing condition
264-
assert obj.test_method() == 'Success'
265-
266-
# Test failing condition
267-
obj.condition = False
268-
with pytest.raises(UnsupportedOperationError) as exc_info:
269-
obj.test_method()
270-
assert 'Condition not met' in str(exc_info.value)
271-
272-
273251
# Tests for are_modalities_compatible
274252
def test_are_modalities_compatible_client_none():
275253
assert (

0 commit comments

Comments
 (0)