diff --git a/tests/client/test_optionals.py b/tests/client/test_optionals.py new file mode 100644 index 000000000..81cbd387d --- /dev/null +++ b/tests/client/test_optionals.py @@ -0,0 +1,16 @@ +"""Tests for a2a.client.optionals module.""" + +import importlib +import sys + +from unittest.mock import patch + + +def test_channel_import_failure(): + """Test Channel behavior when grpc is not available.""" + with patch.dict('sys.modules', {'grpc': None, 'grpc.aio': None}): + if 'a2a.client.optionals' in sys.modules: + del sys.modules['a2a.client.optionals'] + + optionals = importlib.import_module('a2a.client.optionals') + assert optionals.Channel is None diff --git a/tests/server/test_models.py b/tests/server/test_models.py new file mode 100644 index 000000000..64fed1008 --- /dev/null +++ b/tests/server/test_models.py @@ -0,0 +1,118 @@ +"""Tests for a2a.server.models module.""" + +from unittest.mock import MagicMock + +from sqlalchemy.orm import DeclarativeBase + +from a2a.server.models import ( + PydanticListType, + PydanticType, + create_push_notification_config_model, + create_task_model, +) +from a2a.types import Artifact, TaskState, TaskStatus, TextPart + + +class TestPydanticType: + """Tests for PydanticType SQLAlchemy type decorator.""" + + def test_process_bind_param_with_pydantic_model(self): + pydantic_type = PydanticType(TaskStatus) + status = TaskStatus(state=TaskState.working) + dialect = MagicMock() + + result = pydantic_type.process_bind_param(status, dialect) + assert result['state'] == 'working' + assert result['message'] is None + # TaskStatus may have other optional fields + + def test_process_bind_param_with_none(self): + pydantic_type = PydanticType(TaskStatus) + dialect = MagicMock() + + result = pydantic_type.process_bind_param(None, dialect) + assert result is None + + def test_process_result_value(self): + pydantic_type = PydanticType(TaskStatus) + dialect = MagicMock() + + result = pydantic_type.process_result_value( + {'state': 'completed', 'message': None}, dialect + ) + assert isinstance(result, TaskStatus) + assert result.state == 'completed' + + +class TestPydanticListType: + """Tests for PydanticListType SQLAlchemy type decorator.""" + + def test_process_bind_param_with_list(self): + pydantic_list_type = PydanticListType(Artifact) + artifacts = [ + Artifact( + artifact_id='1', parts=[TextPart(type='text', text='Hello')] + ), + Artifact( + artifact_id='2', parts=[TextPart(type='text', text='World')] + ), + ] + dialect = MagicMock() + + result = pydantic_list_type.process_bind_param(artifacts, dialect) + assert len(result) == 2 + assert result[0]['artifactId'] == '1' # JSON mode uses camelCase + assert result[1]['artifactId'] == '2' + + def test_process_result_value_with_list(self): + pydantic_list_type = PydanticListType(Artifact) + dialect = MagicMock() + data = [ + {'artifact_id': '1', 'parts': [{'type': 'text', 'text': 'Hello'}]}, + {'artifact_id': '2', 'parts': [{'type': 'text', 'text': 'World'}]}, + ] + + result = pydantic_list_type.process_result_value(data, dialect) + assert len(result) == 2 + assert all(isinstance(art, Artifact) for art in result) + assert result[0].artifact_id == '1' + assert result[1].artifact_id == '2' + + +def test_create_task_model(): + """Test dynamic task model creation.""" + + # Create a fresh base to avoid table conflicts + class TestBase(DeclarativeBase): + pass + + # Create with default table name + default_task_model = create_task_model('test_tasks_1', TestBase) + assert default_task_model.__tablename__ == 'test_tasks_1' + assert default_task_model.__name__ == 'TaskModel_test_tasks_1' + + # Create with custom table name + custom_task_model = create_task_model('test_tasks_2', TestBase) + assert custom_task_model.__tablename__ == 'test_tasks_2' + assert custom_task_model.__name__ == 'TaskModel_test_tasks_2' + + +def test_create_push_notification_config_model(): + """Test dynamic push notification config model creation.""" + + # Create a fresh base to avoid table conflicts + class TestBase(DeclarativeBase): + pass + + # Create with default table name + default_model = create_push_notification_config_model( + 'test_push_configs_1', TestBase + ) + assert default_model.__tablename__ == 'test_push_configs_1' + + # Create with custom table name + custom_model = create_push_notification_config_model( + 'test_push_configs_2', TestBase + ) + assert custom_model.__tablename__ == 'test_push_configs_2' + assert 'test_push_configs_2' in custom_model.__name__ diff --git a/tests/utils/test_constants.py b/tests/utils/test_constants.py new file mode 100644 index 000000000..59e9b8366 --- /dev/null +++ b/tests/utils/test_constants.py @@ -0,0 +1,21 @@ +"""Tests for a2a.utils.constants module.""" + +from a2a.utils import constants + + +def test_agent_card_constants(): + """Test that agent card constants have expected values.""" + assert ( + constants.AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent-card.json' + ) + assert ( + constants.PREV_AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent.json' + ) + assert ( + constants.EXTENDED_AGENT_CARD_PATH == '/agent/authenticatedExtendedCard' + ) + + +def test_default_rpc_url(): + """Test default RPC URL constant.""" + assert constants.DEFAULT_RPC_URL == '/' diff --git a/tests/utils/test_error_handlers.py b/tests/utils/test_error_handlers.py new file mode 100644 index 000000000..ec41dc1f5 --- /dev/null +++ b/tests/utils/test_error_handlers.py @@ -0,0 +1,92 @@ +"""Tests for a2a.utils.error_handlers module.""" + +from unittest.mock import patch + +import pytest + +from a2a.types import ( + InternalError, + InvalidRequestError, + MethodNotFoundError, + TaskNotFoundError, +) +from a2a.utils.error_handlers import ( + A2AErrorToHttpStatus, + rest_error_handler, + rest_stream_error_handler, +) +from a2a.utils.errors import ServerError + + +class MockJSONResponse: + def __init__(self, content, status_code): + self.content = content + self.status_code = status_code + + +@pytest.mark.asyncio +async def test_rest_error_handler_server_error(): + """Test rest_error_handler with ServerError.""" + error = InvalidRequestError(message='Bad request') + + @rest_error_handler + async def failing_func(): + raise ServerError(error=error) + + with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse): + result = await failing_func() + + assert isinstance(result, MockJSONResponse) + assert result.status_code == 400 + assert result.content == {'message': 'Bad request'} + + +@pytest.mark.asyncio +async def test_rest_error_handler_unknown_exception(): + """Test rest_error_handler with unknown exception.""" + + @rest_error_handler + async def failing_func(): + raise ValueError('Unexpected error') + + with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse): + result = await failing_func() + + assert isinstance(result, MockJSONResponse) + assert result.status_code == 500 + assert result.content == {'message': 'unknown exception'} + + +@pytest.mark.asyncio +async def test_rest_stream_error_handler_server_error(): + """Test rest_stream_error_handler with ServerError.""" + error = InternalError(message='Internal server error') + + @rest_stream_error_handler + async def failing_stream(): + raise ServerError(error=error) + + with pytest.raises(ServerError) as exc_info: + await failing_stream() + + assert exc_info.value.error == error + + +@pytest.mark.asyncio +async def test_rest_stream_error_handler_reraises_exception(): + """Test rest_stream_error_handler reraises other exceptions.""" + + @rest_stream_error_handler + async def failing_stream(): + raise RuntimeError('Stream failed') + + with pytest.raises(RuntimeError, match='Stream failed'): + await failing_stream() + + +def test_a2a_error_to_http_status_mapping(): + """Test A2AErrorToHttpStatus mapping.""" + assert A2AErrorToHttpStatus[InvalidRequestError] == 400 + assert A2AErrorToHttpStatus[MethodNotFoundError] == 404 + assert A2AErrorToHttpStatus[TaskNotFoundError] == 404 + assert A2AErrorToHttpStatus[InternalError] == 500