Skip to content

Commit fb8a0fa

Browse files
committed
fix: update proto2pydantic to v0.5.2 with proper type annotations
- Enum fields default to zero value (e.g., TaskState.TASK_STATE_UNSPECIFIED) - Message fields properly annotated as | None - Fixes all mypy and pyright type errors in generated code
1 parent 226d406 commit fb8a0fa

2 files changed

Lines changed: 20 additions & 20 deletions

File tree

scripts/gen_proto.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
set -e
33

44
# Install proto2pydantic plugin for Pydantic model generation
5-
echo "Installing protoc-gen-proto2pydantic@v0.5.1..."
6-
go install github.com/protocgen/proto2pydantic@v0.5.1
5+
echo "Installing protoc-gen-proto2pydantic@v0.5.2..."
6+
go install github.com/protocgen/proto2pydantic@v0.5.2
77

88
# Run buf generate to regenerate protobuf code and OpenAPI spec
99
npx --yes @bufbuild/buf generate

src/a2a/types/a2a_pydantic.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class TaskPushNotificationConfig(A2ABaseModel):
5656
task_id: str = Field(default='', description='The ID of the task this configuration is associated with.')
5757
url: str = Field(..., description='The URL where the notification should be sent.')
5858
token: str = Field(default='', description='A token unique for this task or session.')
59-
authentication: AuthenticationInfo = Field(default=None, description='Authentication information required to send the notification.')
59+
authentication: AuthenticationInfo | None = Field(default=None, description='Authentication information required to send the notification.')
6060

6161
def to_proto_json(self) -> dict:
6262
"""Serialize to a ProtoJSON-compatible dict (camelCase keys, no None values)."""
@@ -68,7 +68,7 @@ class SendMessageConfiguration(A2ABaseModel):
6868
"""Configuration of a send message request."""
6969

7070
accepted_output_modes: list[str] | None = Field(default=None, description='A list of media types the client is prepared to accept for response parts. Agents SHOULD use this to tailor their output.')
71-
task_push_notification_config: TaskPushNotificationConfig = Field(default=None, description='Configuration for the agent to send push notifications for task updates. Task id should be empty when sending this configuration in a `SendMessage` request.')
71+
task_push_notification_config: TaskPushNotificationConfig | None = Field(default=None, description='Configuration for the agent to send push notifications for task updates. Task id should be empty when sending this configuration in a `SendMessage` request.')
7272
history_length: int | None = Field(default=None, description="The maximum number of most recent messages from the task's history to retrieve in the response. An unset value means the client does not impose any limit. A value of zero is a request to not include any messages. The server MUST NOT return more messages than the provided value, but MAY apply a lower limit.")
7373
return_immediately: bool = Field(default=False, description='If `true`, the operation returns immediately after creating the task, even if processing is still in progress. If `false` (default), the operation MUST wait until the task reaches a terminal (`COMPLETED`, `FAILED`, `CANCELED`, `REJECTED`) or interrupted (`INPUT_REQUIRED`, `AUTH_REQUIRED`) state before returning.')
7474

@@ -81,7 +81,7 @@ def to_proto_json(self) -> dict:
8181
class Part(A2ABaseModel):
8282
"""`Part` represents a container for a section of communication content. Parts can be purely textual, some sort of file (image, video, etc) or a structured data blob (i.e. JSON)."""
8383

84-
metadata: dict[str, Any] = Field(default=None, description='Optional. metadata associated with this part.')
84+
metadata: dict[str, Any] | None = Field(default=None, description='Optional. metadata associated with this part.')
8585
filename: str = Field(default='', description='An optional `filename` for the file (e.g., "document.pdf").')
8686
media_type: str = Field(default='', description='The `media_type` (MIME type) of the part content (e.g., "text/plain", "application/json", "image/png"). This field is available for all part types.')
8787
content: str | bytes | Any | None = None
@@ -101,7 +101,7 @@ class Message(A2ABaseModel):
101101
task_id: str = Field(default='', description='Optional. The task id of the message. If set, the message will be associated with the given task.')
102102
role: Role = Field(..., description='Identifies the sender of the message.')
103103
parts: list[Part] = Field(..., description='Parts is the container of the message content.')
104-
metadata: dict[str, Any] = Field(default=None, description='Optional. Any metadata to provide along with the message.')
104+
metadata: dict[str, Any] | None = Field(default=None, description='Optional. Any metadata to provide along with the message.')
105105
extensions: list[str] | None = Field(default=None, description='The URIs of extensions that are present or contributed to this Message.')
106106
reference_task_ids: list[str] | None = Field(default=None, description='A list of task IDs that this message references for additional context.')
107107

@@ -115,8 +115,8 @@ class TaskStatus(A2ABaseModel):
115115
"""A container for the status of a task."""
116116

117117
state: TaskState = Field(..., description='The current state of this task.')
118-
message: Message = Field(default=None, description='A message associated with the status.')
119-
timestamp: datetime = Field(default=None, description='ISO 8601 Timestamp when the status was recorded. Example: "2023-10-27T10:00:00Z"')
118+
message: Message | None = Field(default=None, description='A message associated with the status.')
119+
timestamp: datetime | None = Field(default=None, description='ISO 8601 Timestamp when the status was recorded. Example: "2023-10-27T10:00:00Z"')
120120

121121
def to_proto_json(self) -> dict:
122122
"""Serialize to a ProtoJSON-compatible dict (camelCase keys, no None values)."""
@@ -140,7 +140,7 @@ class Task(A2ABaseModel):
140140
status: TaskStatus = Field(..., description='The current status of a `Task`, including `state` and a `message`.')
141141
artifacts: list[Artifact] | None = Field(default=None, description='A set of output artifacts for a `Task`.')
142142
history: list[Message] | None = Field(default=None, description='protolint:disable REPEATED_FIELD_NAMES_PLURALIZED The history of interactions from a `Task`.')
143-
metadata: dict[str, Any] = Field(default=None, description='protolint:enable REPEATED_FIELD_NAMES_PLURALIZED A key/value object to store custom metadata about a task.')
143+
metadata: dict[str, Any] | None = Field(default=None, description='protolint:enable REPEATED_FIELD_NAMES_PLURALIZED A key/value object to store custom metadata about a task.')
144144

145145
def to_proto_json(self) -> dict:
146146
"""Serialize to a ProtoJSON-compatible dict (camelCase keys, no None values)."""
@@ -155,7 +155,7 @@ class Artifact(A2ABaseModel):
155155
name: str = Field(default='', description='A human readable name for the artifact.')
156156
description: str = Field(default='', description='Optional. A human readable description of the artifact.')
157157
parts: list[Part] = Field(..., description='The content of the artifact. Must contain at least one part.')
158-
metadata: dict[str, Any] = Field(default=None, description='Optional. Metadata included with the artifact.')
158+
metadata: dict[str, Any] | None = Field(default=None, description='Optional. Metadata included with the artifact.')
159159
extensions: list[str] | None = Field(default=None, description='The URIs of extensions that are present or contributed to this Artifact.')
160160

161161
def to_proto_json(self) -> dict:
@@ -170,7 +170,7 @@ class TaskStatusUpdateEvent(A2ABaseModel):
170170
task_id: str = Field(..., description='The ID of the task that has changed.')
171171
context_id: str = Field(..., description='The ID of the context that the task belongs to.')
172172
status: TaskStatus = Field(..., description='The new status of the task.')
173-
metadata: dict[str, Any] = Field(default=None, description='Optional. Metadata associated with the task update.')
173+
metadata: dict[str, Any] | None = Field(default=None, description='Optional. Metadata associated with the task update.')
174174

175175
def to_proto_json(self) -> dict:
176176
"""Serialize to a ProtoJSON-compatible dict (camelCase keys, no None values)."""
@@ -186,7 +186,7 @@ class TaskArtifactUpdateEvent(A2ABaseModel):
186186
artifact: Artifact = Field(..., description='The artifact that was generated or updated.')
187187
append: bool = Field(default=False, description='If true, the content of this artifact should be appended to a previously sent artifact with the same ID.')
188188
last_chunk: bool = Field(default=False, description='If true, this is the final chunk of the artifact.')
189-
metadata: dict[str, Any] = Field(default=None, description='Optional. Metadata associated with the artifact update.')
189+
metadata: dict[str, Any] | None = Field(default=None, description='Optional. Metadata associated with the artifact update.')
190190

191191
def to_proto_json(self) -> dict:
192192
"""Serialize to a ProtoJSON-compatible dict (camelCase keys, no None values)."""
@@ -258,7 +258,7 @@ class AgentCard(A2ABaseModel):
258258
name: str = Field(..., description='A human readable name for the agent. Example: "Recipe Agent"')
259259
description: str = Field(..., description='A human-readable description of the agent, assisting users and other agents in understanding its purpose. Example: "Agent that helps users with recipes and cooking."')
260260
supported_interfaces: list[AgentInterface] = Field(..., description='Ordered list of supported interfaces. The first entry is preferred.')
261-
provider: AgentProvider = Field(default=None, description='The service provider of the agent.')
261+
provider: AgentProvider | None = Field(default=None, description='The service provider of the agent.')
262262
version: str = Field(..., description='The version of the agent. Example: "1.0.0"')
263263
documentation_url: str | None = Field(default=None, description='A URL providing additional documentation about the agent.')
264264
capabilities: AgentCapabilities = Field(..., description='A2A Capability set supported by the agent.')
@@ -282,7 +282,7 @@ class AgentExtension(A2ABaseModel):
282282
uri: str = Field(default='', description='The unique URI identifying the extension.')
283283
description: str = Field(default='', description='A human-readable description of how this agent uses the extension.')
284284
required: bool = Field(default=False, description="If true, the client must understand and comply with the extension's requirements.")
285-
params: dict[str, Any] = Field(default=None, description='Optional. Extension-specific configuration parameters.')
285+
params: dict[str, Any] | None = Field(default=None, description='Optional. Extension-specific configuration parameters.')
286286

287287
def to_proto_json(self) -> dict:
288288
"""Serialize to a ProtoJSON-compatible dict (camelCase keys, no None values)."""
@@ -295,7 +295,7 @@ class AgentCardSignature(A2ABaseModel):
295295

296296
protected: str = Field(..., description='(-- api-linter: core::0140::reserved-words=disabled aip.dev/not-precedent: Backwards compatibility --) Required. The protected JWS header for the signature. This is always a base64url-encoded JSON object.')
297297
signature: str = Field(..., description='Required. The computed signature, base64url-encoded.')
298-
header: dict[str, Any] = Field(default=None, description='The unprotected JWS header values.')
298+
header: dict[str, Any] | None = Field(default=None, description='The unprotected JWS header values.')
299299

300300
def to_proto_json(self) -> dict:
301301
"""Serialize to a ProtoJSON-compatible dict (camelCase keys, no None values)."""
@@ -484,8 +484,8 @@ class SendMessageRequest(A2ABaseModel):
484484

485485
tenant: str = Field(default='', description='Optional. Tenant ID, provided as a path parameter.')
486486
message: Message = Field(..., description='The message to send to the agent.')
487-
configuration: SendMessageConfiguration = Field(default=None, description='Configuration for the send request.')
488-
metadata: dict[str, Any] = Field(default=None, description='A flexible key-value map for passing additional context or parameters.')
487+
configuration: SendMessageConfiguration | None = Field(default=None, description='Configuration for the send request.')
488+
metadata: dict[str, Any] | None = Field(default=None, description='A flexible key-value map for passing additional context or parameters.')
489489

490490
def to_proto_json(self) -> dict:
491491
"""Serialize to a ProtoJSON-compatible dict (camelCase keys, no None values)."""
@@ -511,11 +511,11 @@ class ListTasksRequest(A2ABaseModel):
511511

512512
tenant: str = Field(default='', description='Tenant ID, provided as a path parameter.')
513513
context_id: str = Field(default='', description='Filter tasks by context ID to get tasks from a specific conversation or session.')
514-
status: TaskState = Field(default=None, description='Filter tasks by their current status state.')
514+
status: TaskState = Field(default=TaskState.TASK_STATE_UNSPECIFIED, description='Filter tasks by their current status state.')
515515
page_size: int | None = Field(default=None, description='The maximum number of tasks to return. The service may return fewer than this value. If unspecified, at most 50 tasks will be returned. The minimum value is 1. The maximum value is 100.')
516516
page_token: str = Field(default='', description='A page token, received from a previous `ListTasks` call. `ListTasksResponse.next_page_token`. Provide this to retrieve the subsequent page.')
517517
history_length: int | None = Field(default=None, description="The maximum number of messages to include in each task's history.")
518-
status_timestamp_after: datetime = Field(default=None, description='Filter tasks which have a status updated after the provided timestamp in ISO 8601 format (e.g., "2023-10-27T10:00:00Z"). Only tasks with a status timestamp time greater than or equal to this value will be returned.')
518+
status_timestamp_after: datetime | None = Field(default=None, description='Filter tasks which have a status updated after the provided timestamp in ISO 8601 format (e.g., "2023-10-27T10:00:00Z"). Only tasks with a status timestamp time greater than or equal to this value will be returned.')
519519
include_artifacts: bool | None = Field(default=None, description='Whether to include artifacts in the returned tasks. Defaults to false to reduce payload size.')
520520

521521
def to_proto_json(self) -> dict:
@@ -551,7 +551,7 @@ class CancelTaskRequest(A2ABaseModel):
551551

552552
tenant: str = Field(default='', description='Optional. Tenant ID, provided as a path parameter.')
553553
id: str = Field(..., description='The resource ID of the task to cancel.')
554-
metadata: dict[str, Any] = Field(default=None, description='A flexible key-value map for passing additional context or parameters.')
554+
metadata: dict[str, Any] | None = Field(default=None, description='A flexible key-value map for passing additional context or parameters.')
555555

556556
def to_proto_json(self) -> dict:
557557
"""Serialize to a ProtoJSON-compatible dict (camelCase keys, no None values)."""

0 commit comments

Comments
 (0)