Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
run: |
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Install dependencies
run: uv sync --dev --extra sql
run: uv sync --dev --extra sql --extra encryption
- name: Run tests and check coverage
run: uv run pytest --cov=a2a --cov-report=xml --cov-fail-under=89
- name: Show coverage summary in log
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
sql = ["sqlalchemy[asyncio,postgresql-asyncpg,aiomysql,aiosqlite]>=2.0.0"]
encryption = ["cryptography>=43.0.0"]

[project.urls]
homepage = "https://a2aproject.github.io/A2A/"
Expand Down
57 changes: 56 additions & 1 deletion src/a2a/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def override(func): # noqa: ANN001, ANN201


try:
from sqlalchemy import JSON, Dialect, String
from sqlalchemy import JSON, Dialect, LargeBinary, String
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
Expand Down Expand Up @@ -208,3 +208,58 @@ class TaskModel(TaskMixin, Base):
"""Default task model with standard table name."""

__tablename__ = 'tasks'


# PushNotificationConfigMixin that can be used with any table name
class PushNotificationConfigMixin:
"""Mixin providing standard push notification config columns."""

task_id: Mapped[str] = mapped_column(String(36), primary_key=True)
config_id: Mapped[str] = mapped_column(String(255), primary_key=True)
config_data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)

@override
def __repr__(self) -> str:
"""Return a string representation of the push notification config."""
repr_template = '<{CLS}(task_id="{TID}", config_id="{CID}")>'
return repr_template.format(
CLS=self.__class__.__name__,
TID=self.task_id,
CID=self.config_id,
)


def create_push_notification_config_model(
table_name: str = 'push_notification_configs',
base: type[DeclarativeBase] = Base,
) -> type:
"""Create a PushNotificationConfigModel class with a configurable table name."""

class PushNotificationConfigModel(PushNotificationConfigMixin, base):
__tablename__ = table_name

@override
def __repr__(self) -> str:
"""Return a string representation of the push notification config."""
repr_template = '<PushNotificationConfigModel[{TABLE}](task_id="{TID}", config_id="{CID}")>'
return repr_template.format(
TABLE=table_name,
TID=self.task_id,
CID=self.config_id,
)

PushNotificationConfigModel.__name__ = (
f'PushNotificationConfigModel_{table_name}'
)
PushNotificationConfigModel.__qualname__ = (
f'PushNotificationConfigModel_{table_name}'
)

return PushNotificationConfigModel


# Default PushNotificationConfigModel for backward compatibility
class PushNotificationConfigModel(PushNotificationConfigMixin, Base):
"""Default push notification config model with standard table name."""

__tablename__ = 'push_notification_configs'
4 changes: 4 additions & 0 deletions src/a2a/server/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from a2a.server.tasks.base_push_notification_sender import (
BasePushNotificationSender,
)
from a2a.server.tasks.database_push_notification_config_store import (
DatabasePushNotificationConfigStore,
)
from a2a.server.tasks.database_task_store import DatabaseTaskStore
from a2a.server.tasks.inmemory_push_notification_config_store import (
InMemoryPushNotificationConfigStore,
Expand All @@ -20,6 +23,7 @@

__all__ = [
'BasePushNotificationSender',
'DatabasePushNotificationConfigStore',
'DatabaseTaskStore',
'InMemoryPushNotificationConfigStore',
'InMemoryTaskStore',
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/server/tasks/base_push_notification_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def _dispatch_notification(
response = await self._client.post(
url,
json=task.model_dump(mode='json', exclude_none=True),
headers=headers
headers=headers,
)
response.raise_for_status()
logger.info(
Expand Down
253 changes: 253 additions & 0 deletions src/a2a/server/tasks/database_push_notification_config_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
import json
import logging

from typing import TYPE_CHECKING


try:
from sqlalchemy import (
delete,
select,
)
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
)
except ImportError as e:
raise ImportError(
'DatabasePushNotificationConfigStore requires SQLAlchemy and a database driver. '
'Install with one of: '
"'pip install a2a-sdk[postgresql]', "
"'pip install a2a-sdk[mysql]', "
"'pip install a2a-sdk[sqlite]', "
"or 'pip install a2a-sdk[sql]'"
) from e

from a2a.server.models import (
Base,
PushNotificationConfigModel,
create_push_notification_config_model,
)
from a2a.server.tasks.push_notification_config_store import (
PushNotificationConfigStore,
)
from a2a.types import PushNotificationConfig


if TYPE_CHECKING:
from cryptography.fernet import Fernet

Check failure on line 39 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`Fernet` is not a recognized word. (unrecognized-spelling)

Check failure on line 39 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`fernet` is not a recognized word. (unrecognized-spelling)


logger = logging.getLogger(__name__)


class DatabasePushNotificationConfigStore(PushNotificationConfigStore):
"""SQLAlchemy-based implementation of PushNotificationConfigStore.

Stores push notification configurations in a database supported by SQLAlchemy.
"""

engine: AsyncEngine
async_session_maker: async_sessionmaker[AsyncSession]
create_table: bool
_initialized: bool
config_model: type[PushNotificationConfigModel]
_fernet: 'Fernet | None'

Check failure on line 56 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`fernet` is not a recognized word. (unrecognized-spelling)

Check failure on line 56 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`Fernet` is not a recognized word. (unrecognized-spelling)

Check warning on line 56 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`fernet` is not a recognized word. (unrecognized-spelling)

Check warning on line 56 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`Fernet` is not a recognized word. (unrecognized-spelling)

def __init__(
self,
engine: AsyncEngine,
create_table: bool = True,
table_name: str = 'push_notification_configs',
encryption_key: str | bytes | None = None,
) -> None:
"""Initializes the DatabasePushNotificationConfigStore.

Args:
engine: An existing SQLAlchemy AsyncEngine to be used by the store.
create_table: If true, create the table on initialization.
table_name: Name of the database table. Defaults to 'push_notification_configs'.
encryption_key: A key for encrypting sensitive configuration data.
If provided, `config_data` will be encrypted in the database.
The key must be a URL-safe base64-encoded 32-byte key.
"""
logger.debug(
f'Initializing DatabasePushNotificationConfigStore with existing engine, table: {table_name}'
)
self.engine = engine
self.async_session_maker = async_sessionmaker(
self.engine, expire_on_commit=False
)
self.create_table = create_table
self._initialized = False
self.config_model = (
PushNotificationConfigModel
if table_name == 'push_notification_configs'
else create_push_notification_config_model(table_name)
)
self._fernet = None

Check failure on line 89 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`fernet` is not a recognized word. (unrecognized-spelling)

Check warning on line 89 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`fernet` is not a recognized word. (unrecognized-spelling)

if encryption_key:
try:
from cryptography.fernet import Fernet # noqa: PLC0415

Check failure on line 93 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`fernet` is not a recognized word. (unrecognized-spelling)

Check failure on line 93 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`Fernet` is not a recognized word. (unrecognized-spelling)

Check warning on line 93 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`fernet` is not a recognized word. (unrecognized-spelling)

Check warning on line 93 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`Fernet` is not a recognized word. (unrecognized-spelling)
except ImportError as e:
raise ImportError(
"DatabasePushNotificationConfigStore with encryption requires the 'cryptography' "
'library. Install with: '
"'pip install a2a-sdk[encryption]'"
) from e

if isinstance(encryption_key, str):
encryption_key = encryption_key.encode('utf-8')
self._fernet = Fernet(encryption_key)

Check failure on line 103 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`fernet` is not a recognized word. (unrecognized-spelling)

Check failure on line 103 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`Fernet` is not a recognized word. (unrecognized-spelling)

Check warning on line 103 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`fernet` is not a recognized word. (unrecognized-spelling)

Check warning on line 103 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`Fernet` is not a recognized word. (unrecognized-spelling)
logger.debug(
'Encryption enabled for push notification config store.'
)

async def initialize(self) -> None:
"""Initialize the database and create the table if needed."""
if self._initialized:
return

logger.debug(
'Initializing database schema for push notification configs...'
)
if self.create_table:
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
self._initialized = True
logger.debug(
'Database schema for push notification configs initialized.'
)

async def _ensure_initialized(self) -> None:
"""Ensure the database connection is initialized."""
if not self._initialized:
await self.initialize()

def _to_orm(
self, task_id: str, config: PushNotificationConfig
) -> PushNotificationConfigModel:
"""Maps a Pydantic PushNotificationConfig to a SQLAlchemy model instance.

The config data is serialized to JSON bytes, and encrypted if a key is configured.
"""
json_payload = config.model_dump_json().encode('utf-8')

if self._fernet:
data_to_store = self._fernet.encrypt(json_payload)
else:
data_to_store = json_payload

return self.config_model(
task_id=task_id,
config_id=config.id,
config_data=data_to_store,
)

def _from_orm(
self, model_instance: PushNotificationConfigModel
) -> PushNotificationConfig:
"""Maps a SQLAlchemy model instance to a Pydantic PushNotificationConfig.

Handles decryption if a key is configured.
"""
payload = model_instance.config_data

if self._fernet:
from cryptography.fernet import InvalidToken # noqa: PLC0415

try:
decrypted_payload = self._fernet.decrypt(payload)
return PushNotificationConfig.model_validate_json(
decrypted_payload
)
except InvalidToken:
# This could be unencrypted data if encryption was enabled after data was stored.
# We'll fall through and try to parse it as plain JSON.
logger.debug(
'Could not decrypt config for task %s, config %s. '
'Attempting to parse as unencrypted JSON.',
model_instance.task_id,
model_instance.config_id,
)

# If no fernet or if decryption failed, try to parse as plain JSON.
try:
return PushNotificationConfig.model_validate_json(payload)
except json.JSONDecodeError as e:
if self._fernet:
raise ValueError(
'Failed to decrypt data; incorrect key or corrupted data.'
) from e
raise ValueError(
'Failed to parse data; it may be encrypted but no key is configured.'
) from e
Comment thread
kthota-g marked this conversation as resolved.

async def set_info(
self, task_id: str, notification_config: PushNotificationConfig
) -> None:
"""Sets or updates the push notification configuration for a task."""
await self._ensure_initialized()

config_to_save = notification_config.model_copy()
if config_to_save.id is None:
config_to_save.id = task_id

db_config = self._to_orm(task_id, config_to_save)
async with self.async_session_maker.begin() as session:
await session.merge(db_config)
logger.debug(
f'Push notification config for task {task_id} with config id {config_to_save.id} saved/updated.'
)

async def get_info(self, task_id: str) -> list[PushNotificationConfig]:
"""Retrieves all push notification configurations for a task."""
await self._ensure_initialized()
async with self.async_session_maker() as session:
stmt = select(self.config_model).where(
self.config_model.task_id == task_id
)
result = await session.execute(stmt)
models = result.scalars().all()

configs = []
for model in models:
try:
configs.append(self._from_orm(model))
except ValueError as e:
logger.error(
'Could not deserialize push notification config for task %s, config %s: %s',
model.task_id,
model.config_id,
e,
)
return configs

async def delete_info(
self, task_id: str, config_id: str | None = None
) -> None:
"""Deletes push notification configurations for a task.

If config_id is provided, only that specific configuration is deleted.
If config_id is None, all configurations for the task are deleted.
"""
await self._ensure_initialized()
async with self.async_session_maker.begin() as session:
stmt = delete(self.config_model).where(
self.config_model.task_id == task_id
)
if config_id is not None:
stmt = stmt.where(self.config_model.config_id == config_id)

result = await session.execute(stmt)

if result.rowcount > 0:
logger.info(
f'Deleted {result.rowcount} push notification config(s) for task {task_id}.'
)
else:
logger.warning(
f'Attempted to delete non-existent push notification config for task {task_id} with config_id: {config_id}'

Check failure on line 252 in src/a2a/server/tasks/database_push_notification_config_store.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`non-existent` matches a line_forbidden.patterns entry: `\b[Nn]o[nt][- ]existent\b`. (forbidden-pattern)
)
Loading
Loading