From 1adb6e25eb987201661ed06953a1745c7ef07aff Mon Sep 17 00:00:00 2001 From: Andrew Hoblitzell Date: Sat, 9 Aug 2025 12:03:59 -0400 Subject: [PATCH 1/2] test: add coverage for error handlers, constants, optionals, and models --- tests/client/test_optionals.py | 16 +++++ tests/server/test_models.py | 106 +++++++++++++++++++++++++++++ tests/utils/test_constants.py | 15 ++++ tests/utils/test_error_handlers.py | 90 ++++++++++++++++++++++++ 4 files changed, 227 insertions(+) create mode 100644 tests/client/test_optionals.py create mode 100644 tests/server/test_models.py create mode 100644 tests/utils/test_constants.py create mode 100644 tests/utils/test_error_handlers.py diff --git a/tests/client/test_optionals.py b/tests/client/test_optionals.py new file mode 100644 index 000000000..81cbd387d --- /dev/null +++ b/tests/client/test_optionals.py @@ -0,0 +1,16 @@ +"""Tests for a2a.client.optionals module.""" + +import importlib +import sys + +from unittest.mock import patch + + +def test_channel_import_failure(): + """Test Channel behavior when grpc is not available.""" + with patch.dict('sys.modules', {'grpc': None, 'grpc.aio': None}): + if 'a2a.client.optionals' in sys.modules: + del sys.modules['a2a.client.optionals'] + + optionals = importlib.import_module('a2a.client.optionals') + assert optionals.Channel is None diff --git a/tests/server/test_models.py b/tests/server/test_models.py new file mode 100644 index 000000000..3c9fdb18e --- /dev/null +++ b/tests/server/test_models.py @@ -0,0 +1,106 @@ +"""Tests for a2a.server.models module.""" + +from unittest.mock import MagicMock + +from sqlalchemy.orm import DeclarativeBase + +from a2a.server.models import ( + PydanticListType, + PydanticType, + create_push_notification_config_model, + create_task_model, +) +from a2a.types import Artifact, TaskState, TaskStatus, TextPart + + +class TestPydanticType: + """Tests for PydanticType SQLAlchemy type decorator.""" + + def test_process_bind_param_with_pydantic_model(self): + pydantic_type = PydanticType(TaskStatus) + status = TaskStatus(state=TaskState.working) + dialect = MagicMock() + + result = pydantic_type.process_bind_param(status, dialect) + assert result["state"] == "working" + assert result["message"] is None + # TaskStatus may have other optional fields + + def test_process_bind_param_with_none(self): + pydantic_type = PydanticType(TaskStatus) + dialect = MagicMock() + + result = pydantic_type.process_bind_param(None, dialect) + assert result is None + + def test_process_result_value(self): + pydantic_type = PydanticType(TaskStatus) + dialect = MagicMock() + + result = pydantic_type.process_result_value({"state": "completed", "message": None}, dialect) + assert isinstance(result, TaskStatus) + assert result.state == "completed" + + +class TestPydanticListType: + """Tests for PydanticListType SQLAlchemy type decorator.""" + + def test_process_bind_param_with_list(self): + pydantic_list_type = PydanticListType(Artifact) + artifacts = [ + Artifact(artifact_id="1", parts=[TextPart(type="text", text="Hello")]), + Artifact(artifact_id="2", parts=[TextPart(type="text", text="World")]) + ] + dialect = MagicMock() + + result = pydantic_list_type.process_bind_param(artifacts, dialect) + assert len(result) == 2 + assert result[0]["artifactId"] == "1" # JSON mode uses camelCase + assert result[1]["artifactId"] == "2" + + def test_process_result_value_with_list(self): + pydantic_list_type = PydanticListType(Artifact) + dialect = MagicMock() + data = [ + {"artifact_id": "1", "parts": [{"type": "text", "text": "Hello"}]}, + {"artifact_id": "2", "parts": [{"type": "text", "text": "World"}]} + ] + + result = pydantic_list_type.process_result_value(data, dialect) + assert len(result) == 2 + assert all(isinstance(art, Artifact) for art in result) + assert result[0].artifact_id == "1" + assert result[1].artifact_id == "2" + + +def test_create_task_model(): + """Test dynamic task model creation.""" + # Create a fresh base to avoid table conflicts + class TestBase(DeclarativeBase): + pass + + # Create with default table name + default_task_model = create_task_model('test_tasks_1', TestBase) + assert default_task_model.__tablename__ == 'test_tasks_1' + assert default_task_model.__name__ == 'TaskModel_test_tasks_1' + + # Create with custom table name + custom_task_model = create_task_model('test_tasks_2', TestBase) + assert custom_task_model.__tablename__ == 'test_tasks_2' + assert custom_task_model.__name__ == 'TaskModel_test_tasks_2' + + +def test_create_push_notification_config_model(): + """Test dynamic push notification config model creation.""" + # Create a fresh base to avoid table conflicts + class TestBase(DeclarativeBase): + pass + + # Create with default table name + default_model = create_push_notification_config_model('test_push_configs_1', TestBase) + assert default_model.__tablename__ == 'test_push_configs_1' + + # Create with custom table name + custom_model = create_push_notification_config_model('test_push_configs_2', TestBase) + assert custom_model.__tablename__ == 'test_push_configs_2' + assert 'test_push_configs_2' in custom_model.__name__ diff --git a/tests/utils/test_constants.py b/tests/utils/test_constants.py new file mode 100644 index 000000000..c2b1e00de --- /dev/null +++ b/tests/utils/test_constants.py @@ -0,0 +1,15 @@ +"""Tests for a2a.utils.constants module.""" + +from a2a.utils import constants + + +def test_agent_card_constants(): + """Test that agent card constants have expected values.""" + assert constants.AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent-card.json' + assert constants.PREV_AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent.json' + assert constants.EXTENDED_AGENT_CARD_PATH == '/agent/authenticatedExtendedCard' + + +def test_default_rpc_url(): + """Test default RPC URL constant.""" + assert constants.DEFAULT_RPC_URL == '/' diff --git a/tests/utils/test_error_handlers.py b/tests/utils/test_error_handlers.py new file mode 100644 index 000000000..e9a526981 --- /dev/null +++ b/tests/utils/test_error_handlers.py @@ -0,0 +1,90 @@ +"""Tests for a2a.utils.error_handlers module.""" + +from unittest.mock import patch + +import pytest + +from a2a.types import ( + InternalError, + InvalidRequestError, + MethodNotFoundError, + TaskNotFoundError, +) +from a2a.utils.error_handlers import ( + A2AErrorToHttpStatus, + rest_error_handler, + rest_stream_error_handler, +) +from a2a.utils.errors import ServerError + + +class MockJSONResponse: + def __init__(self, content, status_code): + self.content = content + self.status_code = status_code + + +@pytest.mark.asyncio +async def test_rest_error_handler_server_error(): + """Test rest_error_handler with ServerError.""" + error = InvalidRequestError(message="Bad request") + + @rest_error_handler + async def failing_func(): + raise ServerError(error=error) + + with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse): + result = await failing_func() + + assert isinstance(result, MockJSONResponse) + assert result.status_code == 400 + assert result.content == {'message': 'Bad request'} + + +@pytest.mark.asyncio +async def test_rest_error_handler_unknown_exception(): + """Test rest_error_handler with unknown exception.""" + @rest_error_handler + async def failing_func(): + raise ValueError("Unexpected error") + + with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse): + result = await failing_func() + + assert isinstance(result, MockJSONResponse) + assert result.status_code == 500 + assert result.content == {'message': 'unknown exception'} + + +@pytest.mark.asyncio +async def test_rest_stream_error_handler_server_error(): + """Test rest_stream_error_handler with ServerError.""" + error = InternalError(message="Internal server error") + + @rest_stream_error_handler + async def failing_stream(): + raise ServerError(error=error) + + with pytest.raises(ServerError) as exc_info: + await failing_stream() + + assert exc_info.value.error == error + + +@pytest.mark.asyncio +async def test_rest_stream_error_handler_reraises_exception(): + """Test rest_stream_error_handler reraises other exceptions.""" + @rest_stream_error_handler + async def failing_stream(): + raise RuntimeError("Stream failed") + + with pytest.raises(RuntimeError, match="Stream failed"): + await failing_stream() + + +def test_a2a_error_to_http_status_mapping(): + """Test A2AErrorToHttpStatus mapping.""" + assert A2AErrorToHttpStatus[InvalidRequestError] == 400 + assert A2AErrorToHttpStatus[MethodNotFoundError] == 404 + assert A2AErrorToHttpStatus[TaskNotFoundError] == 404 + assert A2AErrorToHttpStatus[InternalError] == 500 From bc919025d246c8ea8f234f4c31585ac837260071 Mon Sep 17 00:00:00 2001 From: Andrew Hoblitzell Date: Sat, 9 Aug 2025 12:18:32 -0400 Subject: [PATCH 2/2] style: apply ruff formatting to test files --- tests/server/test_models.py | 40 +++++++++++++++++++----------- tests/utils/test_constants.py | 12 ++++++--- tests/utils/test_error_handlers.py | 12 +++++---- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/tests/server/test_models.py b/tests/server/test_models.py index 3c9fdb18e..64fed1008 100644 --- a/tests/server/test_models.py +++ b/tests/server/test_models.py @@ -22,8 +22,8 @@ def test_process_bind_param_with_pydantic_model(self): dialect = MagicMock() result = pydantic_type.process_bind_param(status, dialect) - assert result["state"] == "working" - assert result["message"] is None + assert result['state'] == 'working' + assert result['message'] is None # TaskStatus may have other optional fields def test_process_bind_param_with_none(self): @@ -37,9 +37,11 @@ def test_process_result_value(self): pydantic_type = PydanticType(TaskStatus) dialect = MagicMock() - result = pydantic_type.process_result_value({"state": "completed", "message": None}, dialect) + result = pydantic_type.process_result_value( + {'state': 'completed', 'message': None}, dialect + ) assert isinstance(result, TaskStatus) - assert result.state == "completed" + assert result.state == 'completed' class TestPydanticListType: @@ -48,33 +50,38 @@ class TestPydanticListType: def test_process_bind_param_with_list(self): pydantic_list_type = PydanticListType(Artifact) artifacts = [ - Artifact(artifact_id="1", parts=[TextPart(type="text", text="Hello")]), - Artifact(artifact_id="2", parts=[TextPart(type="text", text="World")]) + Artifact( + artifact_id='1', parts=[TextPart(type='text', text='Hello')] + ), + Artifact( + artifact_id='2', parts=[TextPart(type='text', text='World')] + ), ] dialect = MagicMock() result = pydantic_list_type.process_bind_param(artifacts, dialect) assert len(result) == 2 - assert result[0]["artifactId"] == "1" # JSON mode uses camelCase - assert result[1]["artifactId"] == "2" + assert result[0]['artifactId'] == '1' # JSON mode uses camelCase + assert result[1]['artifactId'] == '2' def test_process_result_value_with_list(self): pydantic_list_type = PydanticListType(Artifact) dialect = MagicMock() data = [ - {"artifact_id": "1", "parts": [{"type": "text", "text": "Hello"}]}, - {"artifact_id": "2", "parts": [{"type": "text", "text": "World"}]} + {'artifact_id': '1', 'parts': [{'type': 'text', 'text': 'Hello'}]}, + {'artifact_id': '2', 'parts': [{'type': 'text', 'text': 'World'}]}, ] result = pydantic_list_type.process_result_value(data, dialect) assert len(result) == 2 assert all(isinstance(art, Artifact) for art in result) - assert result[0].artifact_id == "1" - assert result[1].artifact_id == "2" + assert result[0].artifact_id == '1' + assert result[1].artifact_id == '2' def test_create_task_model(): """Test dynamic task model creation.""" + # Create a fresh base to avoid table conflicts class TestBase(DeclarativeBase): pass @@ -92,15 +99,20 @@ class TestBase(DeclarativeBase): def test_create_push_notification_config_model(): """Test dynamic push notification config model creation.""" + # Create a fresh base to avoid table conflicts class TestBase(DeclarativeBase): pass # Create with default table name - default_model = create_push_notification_config_model('test_push_configs_1', TestBase) + default_model = create_push_notification_config_model( + 'test_push_configs_1', TestBase + ) assert default_model.__tablename__ == 'test_push_configs_1' # Create with custom table name - custom_model = create_push_notification_config_model('test_push_configs_2', TestBase) + custom_model = create_push_notification_config_model( + 'test_push_configs_2', TestBase + ) assert custom_model.__tablename__ == 'test_push_configs_2' assert 'test_push_configs_2' in custom_model.__name__ diff --git a/tests/utils/test_constants.py b/tests/utils/test_constants.py index c2b1e00de..59e9b8366 100644 --- a/tests/utils/test_constants.py +++ b/tests/utils/test_constants.py @@ -5,9 +5,15 @@ def test_agent_card_constants(): """Test that agent card constants have expected values.""" - assert constants.AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent-card.json' - assert constants.PREV_AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent.json' - assert constants.EXTENDED_AGENT_CARD_PATH == '/agent/authenticatedExtendedCard' + assert ( + constants.AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent-card.json' + ) + assert ( + constants.PREV_AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent.json' + ) + assert ( + constants.EXTENDED_AGENT_CARD_PATH == '/agent/authenticatedExtendedCard' + ) def test_default_rpc_url(): diff --git a/tests/utils/test_error_handlers.py b/tests/utils/test_error_handlers.py index e9a526981..ec41dc1f5 100644 --- a/tests/utils/test_error_handlers.py +++ b/tests/utils/test_error_handlers.py @@ -27,7 +27,7 @@ def __init__(self, content, status_code): @pytest.mark.asyncio async def test_rest_error_handler_server_error(): """Test rest_error_handler with ServerError.""" - error = InvalidRequestError(message="Bad request") + error = InvalidRequestError(message='Bad request') @rest_error_handler async def failing_func(): @@ -44,9 +44,10 @@ async def failing_func(): @pytest.mark.asyncio async def test_rest_error_handler_unknown_exception(): """Test rest_error_handler with unknown exception.""" + @rest_error_handler async def failing_func(): - raise ValueError("Unexpected error") + raise ValueError('Unexpected error') with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse): result = await failing_func() @@ -59,7 +60,7 @@ async def failing_func(): @pytest.mark.asyncio async def test_rest_stream_error_handler_server_error(): """Test rest_stream_error_handler with ServerError.""" - error = InternalError(message="Internal server error") + error = InternalError(message='Internal server error') @rest_stream_error_handler async def failing_stream(): @@ -74,11 +75,12 @@ async def failing_stream(): @pytest.mark.asyncio async def test_rest_stream_error_handler_reraises_exception(): """Test rest_stream_error_handler reraises other exceptions.""" + @rest_stream_error_handler async def failing_stream(): - raise RuntimeError("Stream failed") + raise RuntimeError('Stream failed') - with pytest.raises(RuntimeError, match="Stream failed"): + with pytest.raises(RuntimeError, match='Stream failed'): await failing_stream()