forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_models.py
More file actions
118 lines (91 loc) · 3.91 KB
/
test_models.py
File metadata and controls
118 lines (91 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Tests for a2a.server.models module."""
from unittest.mock import MagicMock
from sqlalchemy.orm import DeclarativeBase
from a2a.server.models import (
PydanticListType,
PydanticType,
create_push_notification_config_model,
create_task_model,
)
from a2a.types import Artifact, TaskState, TaskStatus, TextPart
class TestPydanticType:
    """Exercise PydanticType's bind-parameter and result-value hooks."""

    def test_process_bind_param_with_pydantic_model(self):
        """Binding a TaskStatus yields a value indexable by field name."""
        column_type = PydanticType(TaskStatus)
        working_status = TaskStatus(state=TaskState.working)

        bound = column_type.process_bind_param(working_status, MagicMock())

        # Only two keys are spot-checked: TaskStatus may have other optional
        # fields, so the full payload is deliberately not compared.
        assert bound['state'] == 'working'
        assert bound['message'] is None

    def test_process_bind_param_with_none(self):
        """Binding None passes None through unchanged."""
        column_type = PydanticType(TaskStatus)

        bound = column_type.process_bind_param(None, MagicMock())

        assert bound is None

    def test_process_result_value(self):
        """A stored mapping is turned back into a TaskStatus instance."""
        column_type = PydanticType(TaskStatus)
        stored = {'state': 'completed', 'message': None}

        loaded = column_type.process_result_value(stored, MagicMock())

        assert isinstance(loaded, TaskStatus)
        assert loaded.state == 'completed'
class TestPydanticListType:
    """Exercise PydanticListType's bind-parameter and result-value hooks."""

    def test_process_bind_param_with_list(self):
        """Binding a list of Artifacts yields one entry per artifact."""
        column_type = PydanticListType(Artifact)
        specs = (('1', 'Hello'), ('2', 'World'))
        artifacts = [
            Artifact(
                artifact_id=ident, parts=[TextPart(type='text', text=body)]
            )
            for ident, body in specs
        ]

        bound = column_type.process_bind_param(artifacts, MagicMock())

        assert len(bound) == 2
        # The bound entries are keyed in camelCase (JSON mode).
        assert bound[0]['artifactId'] == '1'
        assert bound[1]['artifactId'] == '2'

    def test_process_result_value_with_list(self):
        """A stored list of mappings is turned back into Artifact objects."""
        column_type = PydanticListType(Artifact)
        stored_rows = [
            {'artifact_id': ident, 'parts': [{'type': 'text', 'text': body}]}
            for ident, body in (('1', 'Hello'), ('2', 'World'))
        ]

        loaded = column_type.process_result_value(stored_rows, MagicMock())

        assert len(loaded) == 2
        assert all(isinstance(item, Artifact) for item in loaded)
        assert loaded[0].artifact_id == '1'
        assert loaded[1].artifact_id == '2'
def test_create_task_model():
    """Check the table and class names of dynamically created task models."""

    # A fresh declarative base is used to avoid table conflicts.
    class TestBase(DeclarativeBase):
        pass

    # Both models are built with explicit table names on the same base;
    # each row pairs the requested table name with the expected class name.
    expectations = (
        ('test_tasks_1', 'TaskModel_test_tasks_1'),
        ('test_tasks_2', 'TaskModel_test_tasks_2'),
    )
    for table_name, class_name in expectations:
        model = create_task_model(table_name, TestBase)
        assert model.__tablename__ == table_name
        assert model.__name__ == class_name
def test_create_push_notification_config_model():
    """Test dynamic push notification config model creation.

    Two models are created on the same fresh base, each with its own
    explicit table name. Both are held to the same checks: the model must
    report the requested ``__tablename__`` and its ``__name__`` must
    contain that table name.
    """

    # Create a fresh base to avoid table conflicts.
    class TestBase(DeclarativeBase):
        pass

    # First explicitly named model.
    first_model = create_push_notification_config_model(
        'test_push_configs_1', TestBase
    )
    assert first_model.__tablename__ == 'test_push_configs_1'
    # Previously only the second model's class name was verified; checking
    # it here as well keeps the two cases symmetrical.
    assert 'test_push_configs_1' in first_model.__name__

    # Second explicitly named model on the same base.
    second_model = create_push_notification_config_model(
        'test_push_configs_2', TestBase
    )
    assert second_model.__tablename__ == 'test_push_configs_2'
    assert 'test_push_configs_2' in second_model.__name__