Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
5441801
Add snake_case generation
holtskinner Jul 16, 2025
c8eb963
Edit some snake_case
holtskinner Jul 16, 2025
0028bda
Update types.py
holtskinner Jul 17, 2025
7649d0c
Update usage in library
holtskinner Jul 17, 2025
d3a35e3
Update Pydantic
holtskinner Jul 17, 2025
572fd31
Undo pydantic upgrade
holtskinner Jul 17, 2025
71152f2
Switch to alias generator
holtskinner Jul 17, 2025
4504f0f
Switch to to_camel
holtskinner Jul 17, 2025
fd7dd8c
Fix missing changes
holtskinner Jul 17, 2025
f24739b
Merge branch 'main' into snake-case-field
holtskinner Jul 18, 2025
38c1dc9
Add --field-constraints to generate_types
holtskinner Jul 18, 2025
8597e89
Fix duplicate fields
holtskinner Jul 18, 2025
3a1e924
Fix push_notification_config
holtskinner Jul 18, 2025
167f003
Fix links
holtskinner Jul 21, 2025
5b1f8d0
re-add Alias
holtskinner Jul 21, 2025
e3d5057
Add pydantic mypy plugin
holtskinner Jul 21, 2025
fe8beff
Use custom camelCase alias generator
holtskinner Jul 21, 2025
4944c4f
Merge branch 'main' into snake-case-field
holtskinner Jul 21, 2025
02db1a5
Fixed camelCase in JSON Payload
holtskinner Jul 21, 2025
9baefef
Merge branch 'snake-case-field' of https://github.com/google-a2a/a2a-…
holtskinner Jul 21, 2025
14b8308
Update pyproject.toml
holtskinner Jul 21, 2025
66011b6
Add mypy plugin
holtskinner Jul 21, 2025
59d5ffb
Add backwards compatibility for camelCase
holtskinner Jul 21, 2025
d5253af
Add Lint ignore
holtskinner Jul 21, 2025
04f105a
spelling
holtskinner Jul 21, 2025
28656a0
Add support for camelCase `__getattr__`
holtskinner Jul 21, 2025
71adfec
Simplify `__setattr__` and `__getattr__` implementation
holtskinner Jul 21, 2025
e36211b
Linting
holtskinner Jul 21, 2025
aebec89
Merge branch 'main' into snake-case-field
holtskinner Jul 21, 2025
373959d
Format spelling
holtskinner Jul 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0", "proto
homepage = "https://a2a-protocol.org/"
repository = "https://github.com/a2aproject/a2a-python"
changelog = "https://github.com/a2aproject/a2a-python/blob/main/CHANGELOG.md"
documentation = "https://a2a-protocol.org/latest/sdk/python/"
documentation = "https://a2a-protocol.org/sdk/python/"
Comment thread
holtskinner marked this conversation as resolved.
Outdated

[tool.hatch.build.targets.wheel]
packages = ["src/a2a"]
Expand Down Expand Up @@ -96,6 +96,9 @@ url = "https://test.pypi.org/simple/"
publish-url = "https://test.pypi.org/legacy/"
explicit = true

[tool.mypy]
plugins = ['pydantic.mypy']

[tool.pyright]
include = ["src"]
exclude = [
Expand Down
5 changes: 4 additions & 1 deletion scripts/generate_types.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ uv run datamodel-codegen \
--class-name A2A \
--use-standard-collections \
--use-subclass-enum \
--base-class a2a._base.A2ABaseModel
--base-class a2a._base.A2ABaseModel \
--field-constraints \
--snake-case-field \
--no-alias

echo "Formatting generated file with ruff..."
uv run ruff format "$GENERATED_FILE"
Expand Down
19 changes: 19 additions & 0 deletions src/a2a/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,21 @@
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


def to_camel_custom(snake: str) -> str:
"""Convert a snake_case string to camelCase.

Args:
snake: The string to convert.

Returns:
The converted camelCase string.
"""
# First, remove any trailing underscores. This is common for names that
# conflict with Python keywords, like 'in_' or 'from_'.
if snake.endswith('_'):
snake = snake.rstrip('_')
return to_camel(snake)


class A2ABaseModel(BaseModel):
Expand All @@ -12,4 +29,6 @@ class A2ABaseModel(BaseModel):
# SEE: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.populate_by_name
validate_by_name=True,
validate_by_alias=True,
serialize_by_alias=True,
alias_generator=to_camel_custom,
)
6 changes: 3 additions & 3 deletions src/a2a/client/auth/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def intercept(
if (
agent_card is None
or agent_card.security is None
or agent_card.securitySchemes is None
or agent_card.security_schemes is None
):
return request_payload, http_kwargs

Expand All @@ -45,8 +45,8 @@ async def intercept(
credential = await self._credential_service.get_credentials(
scheme_name, context
)
if credential and scheme_name in agent_card.securitySchemes:
scheme_def_union = agent_card.securitySchemes.get(
if credential and scheme_name in agent_card.security_schemes:
scheme_def_union = agent_card.security_schemes.get(
scheme_name
)
if not scheme_def_union:
Expand Down
4 changes: 2 additions & 2 deletions src/a2a/client/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def create_text_message_object(
content: The text content of the message. Defaults to an empty string.

Returns:
A `Message` object with a new UUID messageId.
A `Message` object with a new UUID message_id.
"""
return Message(
role=role, parts=[Part(TextPart(text=content))], messageId=str(uuid4())
role=role, parts=[Part(TextPart(text=content))], message_id=str(uuid4())
)
22 changes: 11 additions & 11 deletions src/a2a/server/agent_execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ def __init__( # noqa: PLR0913
# match the request. Otherwise, create them
if self._params:
if task_id:
self._params.message.taskId = task_id
self._params.message.task_id = task_id
if task and task.id != task_id:
raise ServerError(InvalidParamsError(message='bad task id'))
else:
self._check_or_generate_task_id()
if context_id:
self._params.message.contextId = context_id
if task and task.contextId != context_id:
self._params.message.context_id = context_id
if task and task.context_id != context_id:
raise ServerError(
InvalidParamsError(message='bad context id')
)
Expand Down Expand Up @@ -148,17 +148,17 @@ def _check_or_generate_task_id(self) -> None:
if not self._params:
return

if not self._task_id and not self._params.message.taskId:
self._params.message.taskId = str(uuid.uuid4())
if self._params.message.taskId:
self._task_id = self._params.message.taskId
if not self._task_id and not self._params.message.task_id:
self._params.message.task_id = str(uuid.uuid4())
if self._params.message.task_id:
self._task_id = self._params.message.task_id

def _check_or_generate_context_id(self) -> None:
"""Ensures a context ID is present, generating one if necessary."""
if not self._params:
return

if not self._context_id and not self._params.message.contextId:
self._params.message.contextId = str(uuid.uuid4())
if self._params.message.contextId:
self._context_id = self._params.message.contextId
if not self._context_id and not self._params.message.context_id:
self._params.message.context_id = str(uuid.uuid4())
if self._params.message.context_id:
self._context_id = self._params.message.context_id
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

Args:
should_populate_referred_tasks: If True, the builder will fetch tasks
referenced in `params.message.referenceTaskIds` and populate the
referenced in `params.message.reference_task_ids` and populate the
`related_tasks` field in the RequestContext. Defaults to False.
task_store: The TaskStore instance to use for fetching referred tasks.
Required if `should_populate_referred_tasks` is True.
Expand All @@ -26,43 +26,43 @@
self._task_store = task_store
self._should_populate_referred_tasks = should_populate_referred_tasks

async def build(
self,
params: MessageSendParams | None = None,
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
context: ServerCallContext | None = None,
) -> RequestContext:
"""Builds the request context for an agent execution.

This method assembles the RequestContext object. If the builder was
initialized with `should_populate_referred_tasks=True`, it fetches all tasks
referenced in `params.message.referenceTaskIds` from the `task_store`.
referenced in `params.message.reference_task_ids` from the `task_store`.

Args:
params: The parameters of the incoming message send request.
task_id: The ID of the task being executed.
context_id: The ID of the current execution context.
task: The primary task object associated with the request.
context: The server call context, containing metadata about the call.

Returns:
An instance of RequestContext populated with the provided information
and potentially a list of related tasks.
"""

Check notice on line 53 in src/a2a/server/agent_execution/simple_request_context_builder.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/agent_execution/request_context_builder.py (12-20)
related_tasks: list[Task] | None = None

if (
self._task_store
and self._should_populate_referred_tasks
and params
and params.message.referenceTaskIds
and params.message.reference_task_ids
):
tasks = await asyncio.gather(
*[
self._task_store.get(task_id)
for task_id in params.message.referenceTaskIds
for task_id in params.message.reference_task_ids
]
)
related_tasks = [x for x in tasks if x is not None]
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/server/apps/jsonrpc/fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def add_routes_to_app(
)(self._handle_requests)
app.get(agent_card_url)(self._handle_get_agent_card)

if self.agent_card.supportsAuthenticatedExtendedCard:
if self.agent_card.supports_authenticated_extended_card:
app.get(extended_agent_card_url)(
self._handle_get_authenticated_extended_agent_card
)
Expand Down
8 changes: 4 additions & 4 deletions src/a2a/server/apps/jsonrpc/jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ def __init__(
agent_card=agent_card, request_handler=http_handler
)
if (
self.agent_card.supportsAuthenticatedExtendedCard
self.agent_card.supports_authenticated_extended_card
and self.extended_agent_card is None
):
logger.error(
'AgentCard.supportsAuthenticatedExtendedCard is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
'AgentCard.supports_authenticated_extended_card is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
)
self._context_builder = context_builder or DefaultCallContextBuilder()

Expand Down Expand Up @@ -421,7 +421,7 @@ async def _handle_get_authenticated_extended_agent_card(
self, request: Request
) -> JSONResponse:
"""Handles GET requests for the authenticated extended agent card."""
if not self.agent_card.supportsAuthenticatedExtendedCard:
if not self.agent_card.supports_authenticated_extended_card:
return JSONResponse(
{'error': 'Extended agent card not supported or not enabled.'},
status_code=404,
Expand All @@ -435,7 +435,7 @@ async def _handle_get_authenticated_extended_agent_card(
by_alias=True,
)
)
# If supportsAuthenticatedExtendedCard is true, but no specific
# If supports_authenticated_extended_card is true, but no specific
# extended_agent_card was provided during server initialization,
# return a 404
return JSONResponse(
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/server/apps/jsonrpc/starlette_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def routes(
),
]

if self.agent_card.supportsAuthenticatedExtendedCard:
if self.agent_card.supports_authenticated_extended_card:
app_routes.append(
Route(
extended_agent_card_url,
Expand Down
10 changes: 5 additions & 5 deletions src/a2a/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class TaskMixin:
"""Mixin providing standard task columns with proper type handling."""

id: Mapped[str] = mapped_column(String(36), primary_key=True, index=True)
contextId: Mapped[str] = mapped_column(String(36), nullable=False) # noqa: N815
context_id: Mapped[str] = mapped_column(String(36), nullable=False)
kind: Mapped[str] = mapped_column(
String(16), nullable=False, default='task'
)
Expand All @@ -148,12 +148,12 @@ def task_metadata(cls) -> Mapped[dict[str, Any] | None]:
def __repr__(self) -> str:
"""Return a string representation of the task."""
repr_template = (
'<{CLS}(id="{ID}", contextId="{CTX_ID}", status="{STATUS}")>'
'<{CLS}(id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
)
return repr_template.format(
CLS=self.__class__.__name__,
ID=self.id,
CTX_ID=self.contextId,
CTX_ID=self.context_id,
STATUS=self.status,
)

Expand Down Expand Up @@ -188,11 +188,11 @@ class TaskModel(TaskMixin, base):
@override
def __repr__(self) -> str:
"""Return a string representation of the task."""
repr_template = '<TaskModel[{TABLE}](id="{ID}", contextId="{CTX_ID}", status="{STATUS}")>'
repr_template = '<TaskModel[{TABLE}](id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
return repr_template.format(
TABLE=table_name,
ID=self.id,
CTX_ID=self.contextId,
CTX_ID=self.context_id,
STATUS=self.status,
)

Expand Down
29 changes: 15 additions & 14 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,25 +122,25 @@
"""
task: Task | None = await self.task_store.get(params.id)
if not task:
raise ServerError(error=TaskNotFoundError())

task_manager = TaskManager(
task_id=task.id,
context_id=task.contextId,
context_id=task.context_id,
task_store=self.task_store,
initial_message=None,
)
result_aggregator = ResultAggregator(task_manager)

queue = await self._queue_manager.tap(task.id)
if not queue:
queue = EventQueue()

Check notice on line 137 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (429-443)

await self.agent_executor.cancel(
RequestContext(
None,
task_id=task.id,
context_id=task.contextId,
context_id=task.context_id,
task=task,
),
queue,
Expand Down Expand Up @@ -184,8 +184,8 @@
"""
# Create task manager and validate existing task
task_manager = TaskManager(
task_id=params.message.taskId,
context_id=params.message.contextId,
task_id=params.message.task_id,
context_id=params.message.context_id,
task_store=self.task_store,
initial_message=params.message,
)
Expand All @@ -205,7 +205,7 @@
request_context = await self._request_context_builder.build(
params=params,
task_id=task.id if task else None,
context_id=params.message.contextId,
context_id=params.message.context_id,
task=task,
context=context,
)
Expand All @@ -218,10 +218,10 @@
if (
self._push_config_store
and params.configuration
and params.configuration.pushNotificationConfig
and params.configuration.push_notification_config
):
await self._push_config_store.set_info(
task_id, params.configuration.pushNotificationConfig
task_id, params.configuration.push_notification_config
)

queue = await self._queue_manager.create_or_tap(task_id)
Expand Down Expand Up @@ -366,13 +366,13 @@
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.taskId)
task: Task | None = await self.task_store.get(params.task_id)
if not task:
raise ServerError(error=TaskNotFoundError())

await self._push_config_store.set_info(
params.taskId,
params.pushNotificationConfig,
params.task_id,
params.push_notification_config,
)

return params
Expand Down Expand Up @@ -404,7 +404,8 @@
)

return TaskPushNotificationConfig(
taskId=params.id, pushNotificationConfig=push_notification_config[0]
task_id=params.id,
push_notification_config=push_notification_config[0],
)

async def on_resubscribe_to_task(
Expand All @@ -425,21 +426,21 @@
raise ServerError(
error=InvalidParamsError(
message=f'Task {task.id} is in terminal state: {task.status.state}'
)
)

task_manager = TaskManager(
task_id=task.id,
context_id=task.contextId,
context_id=task.context_id,
task_store=self.task_store,
initial_message=None,
)

result_aggregator = ResultAggregator(task_manager)

queue = await self._queue_manager.tap(task.id)
if not queue:
raise ServerError(error=TaskNotFoundError())

Check notice on line 443 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (125-137)

consumer = EventConsumer(queue)
async for event in result_aggregator.consume_and_emit(consumer):
Expand Down Expand Up @@ -470,7 +471,7 @@
for config in push_notification_config_list:
task_push_notification_config.append(
TaskPushNotificationConfig(
taskId=params.id, pushNotificationConfig=config
task_id=params.id, push_notification_config=config
)
)

Expand All @@ -493,5 +494,5 @@
raise ServerError(error=TaskNotFoundError())

await self._push_config_store.delete_info(
params.id, params.pushNotificationConfigId
params.id, params.push_notification_config_id
)
2 changes: 1 addition & 1 deletion src/a2a/server/request_handlers/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ async def GetTaskPushNotificationConfig(
return a2a_pb2.TaskPushNotificationConfig()

@validate(
lambda self: self.agent_card.capabilities.pushNotifications,
lambda self: self.agent_card.capabilities.push_notifications,
'Push notifications are not supported by the agent',
)
async def CreateTaskPushNotificationConfig(
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/server/request_handlers/jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ async def get_push_notification_config(
)

@validate(
lambda self: self.agent_card.capabilities.pushNotifications,
lambda self: self.agent_card.capabilities.push_notifications,
'Push notifications are not supported by the agent',
)
async def set_push_notification_config(
Expand Down
4 changes: 2 additions & 2 deletions src/a2a/server/tasks/database_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _to_orm(self, task: Task) -> TaskModel:
"""Maps a Pydantic Task to a SQLAlchemy TaskModel instance."""
return self.task_model(
id=task.id,
contextId=task.contextId,
context_id=task.context_id,
kind=task.kind,
status=task.status,
artifacts=task.artifacts,
Expand All @@ -108,7 +108,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
# Map database columns to Pydantic model fields
task_data_from_db = {
'id': task_model.id,
'contextId': task_model.contextId,
'context_id': task_model.context_id,
'kind': task_model.kind,
'status': task_model.status,
'artifacts': task_model.artifacts,
Expand Down
Loading
Loading