forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_serialization.py
More file actions
183 lines (153 loc) · 6.43 KB
/
test_serialization.py
File metadata and controls
183 lines (153 loc) · 6.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import json
from unittest import mock
import pytest
from starlette.testclient import TestClient
from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication
from a2a.types import (
APIKeySecurityScheme,
AgentCapabilities,
AgentCard,
In,
Message,
Part,
Role,
SecurityScheme,
TextPart,
JSONParseError,
InvalidRequestError,
)
from pydantic import ValidationError
@pytest.fixture
def agent_card_with_api_key():
    """Build an AgentCard secured by a header-based API key scheme.

    The security scheme is created with ``model_validate`` from a plain dict
    that uses the wire alias ``'in'`` (not the Python attribute name ``in_``),
    so the serialization tests in this module exercise the alias on input.
    """
    # Validating from raw data (instead of keyword construction) is what makes
    # the 'in' alias part of the setup.
    scheme = APIKeySecurityScheme.model_validate(
        {
            'type': 'apiKey',
            'name': 'X-API-KEY',
            'in': 'header',
        }
    )
    return AgentCard(
        name='APIKeyAgent',
        description='An agent that uses API Key auth.',
        url='http://example.com/apikey-agent',
        version='1.0.0',
        capabilities=AgentCapabilities(),
        defaultInputModes=['text/plain'],
        defaultOutputModes=['text/plain'],
        skills=[],
        securitySchemes={'api_key_auth': SecurityScheme(root=scheme)},
        security=[{'api_key_auth': []}],
    )
def test_starlette_agent_card_with_api_key_scheme_alias(
    agent_card_with_api_key: AgentCard,
):
    """Check the Starlette agent-card endpoint emits aliased field names.

    Regression test: ``APIKeySecurityScheme.in_`` must be serialized under its
    alias ``in`` (never ``in_``), and the served JSON must validate back into
    an ``AgentCard`` whose scheme has ``in_ == In.header``.
    """
    app = A2AStarletteApplication(agent_card_with_api_key, mock.AsyncMock())
    client = TestClient(app.build())

    resp = client.get('/.well-known/agent.json')
    assert resp.status_code == 200

    body = resp.json()
    scheme_json = body['securitySchemes']['api_key_auth']
    assert 'in' in scheme_json
    assert scheme_json['in'] == 'header'
    assert 'in_' not in scheme_json

    # Round-trip: only model_validate can raise ValidationError, so it is the
    # only statement guarded here.
    try:
        parsed = AgentCard.model_validate(body)
    except ValidationError as e:
        pytest.fail(
            f"AgentCard.model_validate failed on the server's response: {e}"
        )
    wrapped = parsed.securitySchemes['api_key_auth']
    assert isinstance(wrapped.root, APIKeySecurityScheme)
    assert wrapped.root.in_ == In.header
def test_fastapi_agent_card_with_api_key_scheme_alias(
    agent_card_with_api_key: AgentCard,
):
    """
    Tests that the A2AFastAPIApplication endpoint correctly serializes aliased fields.
    This verifies the fix for `APIKeySecurityScheme.in_` being serialized as `in_` instead of `in`.

    Like the Starlette variant, it also checks that the served JSON validates
    back into an ``AgentCard`` with ``in_`` populated from the ``in`` alias.
    """
    handler = mock.AsyncMock()
    app_instance = A2AFastAPIApplication(agent_card_with_api_key, handler)
    client = TestClient(app_instance.build())
    response = client.get('/.well-known/agent.json')
    assert response.status_code == 200
    response_data = response.json()
    security_scheme_json = response_data['securitySchemes']['api_key_auth']
    assert 'in' in security_scheme_json
    assert 'in_' not in security_scheme_json
    assert security_scheme_json['in'] == 'header'

    # Round-trip check, kept in parity with the Starlette test: the emitted
    # card must be consumable by a client that parses it with the model.
    try:
        parsed_card = AgentCard.model_validate(response_data)
    except ValidationError as e:
        pytest.fail(
            f"AgentCard.model_validate failed on the server's response: {e}"
        )
    parsed_scheme_wrapper = parsed_card.securitySchemes['api_key_auth']
    assert isinstance(parsed_scheme_wrapper.root, APIKeySecurityScheme)
    assert parsed_scheme_wrapper.root.in_ == In.header
def test_handle_invalid_json(agent_card_with_api_key: AgentCard):
    """A malformed JSON body is answered with HTTP 200 and a JSON-RPC parse error."""
    # The body below is missing its final closing brace, so it cannot be decoded.
    truncated_body = (
        '{ "jsonrpc": "2.0", "method": "test", "id": 1, '
        '"params": { "key": "value" }'
    )
    app = A2AStarletteApplication(agent_card_with_api_key, mock.AsyncMock())
    client = TestClient(app.build())

    resp = client.post('/', content=truncated_body)
    assert resp.status_code == 200

    expected_code = JSONParseError().code
    assert resp.json()['error']['code'] == expected_code
def test_handle_oversized_payload(agent_card_with_api_key: AgentCard):
    """Test handling of oversized JSON payloads.

    Sends a request with a ~2MB string parameter to the JSON-RPC endpoint.
    Accepted outcomes: a 200 response carrying a JSON-RPC InvalidRequestError,
    an HTTP 413, or the transport failing with ConnectionResetError /
    RuntimeError.
    """
    handler = mock.AsyncMock()
    app_instance = A2AStarletteApplication(agent_card_with_api_key, handler)
    client = TestClient(app_instance.build())
    large_string = 'a' * 2_000_000  # 2MB string
    payload = {
        'jsonrpc': '2.0',
        'method': 'test',
        'id': 1,
        'params': {'data': large_string},
    }
    # Whether a request-size limit applies depends on how the server/app is
    # configured, so several outcomes are accepted below.
    try:
        response = client.post('/', json=payload)
    except (ConnectionResetError, RuntimeError):
        # Depending on server setup, the connection may simply be dropped for
        # very large payloads; that is an acceptable outcome.
        return
    # Assertions stay outside the try: a genuine failure (or a KeyError on an
    # unexpected body) must surface as itself, not be re-caught and reported
    # as an opaque isinstance() failure.
    if response.status_code == 200:
        # The app handled it gracefully and returned a JSON-RPC error.
        data = response.json()
        assert data['error']['code'] == InvalidRequestError().code
    else:
        assert response.status_code == 413
def test_handle_unicode_characters(agent_card_with_api_key: AgentCard):
    """Test handling of unicode characters in JSON payload.

    Checks both directions:
    - the server decodes non-ASCII text from the request into the params
      handed to ``on_message_send``;
    - the server encodes the handler's non-ASCII reply into the JSON-RPC
      response.
    """
    handler = mock.AsyncMock()
    app_instance = A2AStarletteApplication(agent_card_with_api_key, handler)
    client = TestClient(app_instance.build())
    unicode_text = 'こんにちは世界'  # "Hello world" in Japanese
    unicode_payload = {
        'jsonrpc': '2.0',
        'method': 'message/send',
        'id': 'unicode_test',
        'params': {
            'message': {
                'role': 'user',
                'parts': [{'kind': 'text', 'text': unicode_text}],
                'messageId': 'msg-unicode',
            }
        },
    }
    # Mock a handler for this method.
    handler.on_message_send.return_value = Message(
        role=Role.agent,
        parts=[Part(root=TextPart(text=f'Received: {unicode_text}'))],
        messageId='response-unicode',
    )
    response = client.post('/', json=unicode_payload)
    assert response.status_code == 200
    data = response.json()
    assert 'error' not in data or data['error'] is None
    assert data['result']['parts'][0]['text'] == f'Received: {unicode_text}'

    # The response text above is produced by the mock from the test's own
    # local string, so it cannot prove the *request* was decoded correctly.
    # Inspect what the handler actually received instead.
    # NOTE(review): assumes the request params are passed as the first
    # positional argument to on_message_send — confirm against the handler.
    handler.on_message_send.assert_awaited_once()
    received_params = handler.on_message_send.await_args.args[0]
    assert received_params.message.parts[0].root.text == unicode_text