Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/generate_types.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ uv run datamodel-codegen \
--no-alias

echo "Formatting generated file with ruff..."
uv run ruff check --fix-only "$GENERATED_FILE"
uv run ruff format "$GENERATED_FILE"

echo "Codegen finished successfully."
43 changes: 22 additions & 21 deletions src/a2a/types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# generated by datamodel-codegen:
# filename: https://raw.githubusercontent.com/a2aproject/A2A/refs/heads/main/specification/json/a2a.json
# filename: https://raw.githubusercontent.com/a2aproject/A2A/refs/heads/uuid-fields/specification/json/a2a.json

from __future__ import annotations

from enum import Enum
from typing import Any, Literal
from uuid import UUID

Check failure on line 8 in src/a2a/types.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (TC003)

src/a2a/types.py:8:18: TC003 Move standard library import `uuid.UUID` into a type-checking block

from pydantic import Field, RootModel

Expand Down Expand Up @@ -293,15 +294,15 @@
Defines parameters for deleting a specific push notification configuration for a task.
"""

id: str
id: UUID
"""
The unique identifier of the task.
"""
metadata: dict[str, Any] | None = None
"""
Optional metadata associated with the request.
"""
push_notification_config_id: str
push_notification_config_id: UUID
"""
The ID of the push notification configuration to delete.
"""
Expand Down Expand Up @@ -430,15 +431,15 @@
Defines parameters for fetching a specific push notification configuration for a task.
"""

id: str
id: UUID
"""
The unique identifier of the task.
"""
metadata: dict[str, Any] | None = None
"""
Optional metadata associated with the request.
"""
push_notification_config_id: str | None = None
push_notification_config_id: UUID | None = None
"""
The ID of the push notification configuration to retrieve.
"""
Expand Down Expand Up @@ -675,7 +676,7 @@
Defines parameters for listing all push notification configurations associated with a task.
"""

id: str
id: UUID
"""
The unique identifier of the task.
"""
Expand Down Expand Up @@ -828,7 +829,7 @@
"""
Optional authentication details for the agent to use when calling the notification URL.
"""
id: str | None = None
id: UUID | None = None
"""
A unique ID for the push notification configuration, set by the client
to support multiple notification callbacks.
Expand Down Expand Up @@ -879,7 +880,7 @@
Defines parameters containing a task ID, used for simple task operations.
"""

id: str
id: UUID
"""
The unique identifier of the task.
"""
Expand Down Expand Up @@ -938,7 +939,7 @@
"""
The push notification configuration for this task.
"""
task_id: str
task_id: UUID
"""
The ID of the task.
"""
Expand All @@ -953,7 +954,7 @@
"""
The number of most recent messages from the task's history to retrieve.
"""
id: str
id: UUID
"""
The unique identifier of the task.
"""
Expand Down Expand Up @@ -1374,7 +1375,7 @@
Represents a file, data structure, or other resource generated by an agent during a task.
"""

artifact_id: str
artifact_id: UUID
"""
A unique identifier for the artifact within the scope of the task.
"""
Expand Down Expand Up @@ -1438,7 +1439,7 @@
Represents a single message in the conversation between a user and an agent.
"""

context_id: str | None = None
context_id: UUID | None = None
"""
The context identifier for this message, used to group related interactions.
"""
Expand All @@ -1450,7 +1451,7 @@
"""
The type of this object, used as a discriminator. Always 'message' for a Message.
"""
message_id: str
message_id: UUID
"""
A unique identifier for the message, typically a UUID, generated by the sender.
"""
Expand All @@ -1463,15 +1464,15 @@
An array of content parts that form the message body. A message can be
composed of multiple parts of different types (e.g., text and files).
"""
reference_task_ids: list[str] | None = None
reference_task_ids: list[UUID] | None = None
"""
A list of other task IDs that this message references for additional context.
"""
role: Role
"""
Identifies the sender of the message. `user` for the client, `agent` for the service.
"""
task_id: str | None = None
task_id: UUID | None = None
"""
The identifier of the task this message is part of. Can be omitted for the first message of a new task.
"""
Expand Down Expand Up @@ -1614,7 +1615,7 @@
"""
The artifact that was generated or updated.
"""
context_id: str
context_id: UUID
"""
The context ID associated with the task.
"""
Expand All @@ -1630,7 +1631,7 @@
"""
Optional metadata for extensions.
"""
task_id: str
task_id: UUID
"""
The ID of the task this artifact belongs to.
"""
Expand Down Expand Up @@ -1663,7 +1664,7 @@
This is typically used in streaming or subscription models.
"""

context_id: str
context_id: UUID
"""
The context ID associated with the task.
"""
Expand All @@ -1683,7 +1684,7 @@
"""
The new status of the task.
"""
task_id: str
task_id: UUID
"""
The ID of the task that was updated.
"""
Expand Down Expand Up @@ -1861,15 +1862,15 @@
"""
A collection of artifacts generated by the agent during the execution of the task.
"""
context_id: str
context_id: UUID
"""
A server-generated identifier for maintaining context across multiple related tasks or interactions.
"""
history: list[Message] | None = None
"""
An array of messages exchanged during the task, representing the conversation history.
"""
id: str
id: UUID
"""
A unique identifier for the task, generated by the server for a new task.
"""
Expand Down
22 changes: 11 additions & 11 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None:
if message is None:
return None
return a2a_pb2.Message(
message_id=message.message_id,
message_id=str(message.message_id),
content=[ToProto.part(p) for p in message.parts],
context_id=message.context_id or '',
task_id=message.task_id or '',
context_id=str(message.context_id) if message.context_id else None,
task_id=str(message.task_id) if message.task_id else None,
role=cls.role(message.role),
metadata=ToProto.metadata(message.metadata),
)
Expand Down Expand Up @@ -86,8 +86,8 @@ def file(
@classmethod
def task(cls, task: types.Task) -> a2a_pb2.Task:
return a2a_pb2.Task(
id=task.id,
context_id=task.context_id,
id=str(task.id),
context_id=str(task.context_id),
status=ToProto.task_status(task.status),
artifacts=(
[ToProto.artifact(a) for a in task.artifacts]
Expand Down Expand Up @@ -129,7 +129,7 @@ def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState:
@classmethod
def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact:
return a2a_pb2.Artifact(
artifact_id=artifact.artifact_id,
artifact_id=str(artifact.artifact_id),
description=artifact.description,
metadata=ToProto.metadata(artifact.metadata),
name=artifact.name,
Expand All @@ -155,7 +155,7 @@ def push_notification_config(
else None
)
return a2a_pb2.PushNotificationConfig(
id=config.id or '',
id=str(config.id) if config.id else None,
url=config.url,
token=config.token,
authentication=auth_info,
Expand All @@ -166,8 +166,8 @@ def task_artifact_update_event(
cls, event: types.TaskArtifactUpdateEvent
) -> a2a_pb2.TaskArtifactUpdateEvent:
return a2a_pb2.TaskArtifactUpdateEvent(
task_id=event.task_id,
context_id=event.context_id,
task_id=str(event.task_id),
context_id=str(event.context_id),
artifact=ToProto.artifact(event.artifact),
metadata=ToProto.metadata(event.metadata),
append=event.append or False,
Expand All @@ -179,8 +179,8 @@ def task_status_update_event(
cls, event: types.TaskStatusUpdateEvent
) -> a2a_pb2.TaskStatusUpdateEvent:
return a2a_pb2.TaskStatusUpdateEvent(
task_id=event.task_id,
context_id=event.context_id,
task_id=str(event.task_id),
context_id=str(event.context_id),
status=ToProto.task_status(event.status),
metadata=ToProto.metadata(event.metadata),
final=event.final,
Expand Down
16 changes: 2 additions & 14 deletions src/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,10 @@ def new_task(request: Message) -> Task:
if isinstance(part.root, TextPart) and not part.root.text:
raise ValueError('TextPart content cannot be empty')

context_id_str = request.context_id
if context_id_str is not None:
try:
uuid.UUID(context_id_str)
context_id = context_id_str
except (ValueError, AttributeError, TypeError) as e:
raise ValueError(
f"Invalid context_id: '{context_id_str}' is not a valid UUID."
) from e
else:
context_id = str(uuid.uuid4())

return Task(
status=TaskStatus(state=TaskState.submitted),
id=(request.task_id if request.task_id else str(uuid.uuid4())),
context_id=context_id,
id=request.task_id or uuid.uuid4(),
context_id=request.context_id or uuid.uuid4(),
history=[request],
)

Expand Down
4 changes: 2 additions & 2 deletions tests/client/test_auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def build_success_response(request: httpx.Request) -> httpx.Response:
jsonrpc='2.0',
result=Message(
kind='message',
message_id='message-id',
message_id='c222a603-645e-4c37-8f7b-e49f3ea80e9e',
role=Role.agent,
parts=[],
),
Expand All @@ -75,7 +75,7 @@ def build_success_response(request: httpx.Request) -> httpx.Response:
def build_message() -> Message:
"""Builds a minimal Message."""
return Message(
message_id='msg1',
message_id='87c8541d-f773-4825-bbb1-f518727231f2',
role=Role.user,
parts=[],
)
Expand Down
20 changes: 10 additions & 10 deletions tests/client/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def sample_agent_card():
def sample_message():
return Message(
role=Role.user,
message_id='msg-1',
message_id='15957e91-63e6-40ac-8205-1d1ffb09a5b2',
parts=[Part(root=TextPart(text='Hello'))],
)

Expand All @@ -65,8 +65,8 @@ async def test_send_message_streaming(
):
async def create_stream(*args, **kwargs):
yield Task(
id='task-123',
context_id='ctx-456',
id='536ab032-6915-47d1-9909-4172dbee4aa0',
context_id='9f18b6e9-63c4-4d44-a8b8-f4648003b6b8',
status=TaskStatus(state=TaskState.completed),
)

Expand All @@ -77,7 +77,7 @@ async def create_stream(*args, **kwargs):
mock_transport.send_message_streaming.assert_called_once()
assert not mock_transport.send_message.called
assert len(events) == 1
assert events[0][0].id == 'task-123'
assert str(events[0][0].id) == '536ab032-6915-47d1-9909-4172dbee4aa0'


@pytest.mark.asyncio
Expand All @@ -86,8 +86,8 @@ async def test_send_message_non_streaming(
):
base_client._config.streaming = False
mock_transport.send_message.return_value = Task(
id='task-456',
context_id='ctx-789',
id='9368e3b5-c796-46cf-9318-6c73e1a37e58',
context_id='0a934875-fa22-4af0-8b40-79b13d46e4a6',
status=TaskStatus(state=TaskState.completed),
)

Expand All @@ -96,7 +96,7 @@ async def test_send_message_non_streaming(
mock_transport.send_message.assert_called_once()
assert not mock_transport.send_message_streaming.called
assert len(events) == 1
assert events[0][0].id == 'task-456'
assert str(events[0][0].id) == '9368e3b5-c796-46cf-9318-6c73e1a37e58'


@pytest.mark.asyncio
Expand All @@ -105,8 +105,8 @@ async def test_send_message_non_streaming_agent_capability_false(
):
base_client._card.capabilities.streaming = False
mock_transport.send_message.return_value = Task(
id='task-789',
context_id='ctx-101',
id='d7541723-0796-4231-8849-f6f137ea3bf8',
context_id='dab80cd1-224d-47cd-abd8-cc53101fb273',
status=TaskStatus(state=TaskState.completed),
)

Expand All @@ -115,4 +115,4 @@ async def test_send_message_non_streaming_agent_capability_false(
mock_transport.send_message.assert_called_once()
assert not mock_transport.send_message_streaming.called
assert len(events) == 1
assert events[0][0].id == 'task-789'
assert str(events[0][0].id) == 'd7541723-0796-4231-8849-f6f137ea3bf8'
Loading
Loading