Skip to content

Commit 92f16cf

Browse files
authored
Merge branch '1.0-dev' into feat/client-async-context-manager-and-close-689
2 parents 8cafca1 + 7dec763 commit 92f16cf

46 files changed

Lines changed: 3020 additions & 326 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/actions/spelling/allow.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ openapiv2
9292
opensource
9393
otherurl
9494
pb2
95+
poolclass
9596
postgres
9697
POSTGRES
9798
postgresql

pyproject.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
3939
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
4040
signing = ["PyJWT>=2.0.0"]
4141
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
42+
db-cli = ["alembic>=1.14.0"]
4243

4344
sql = ["a2a-sdk[postgresql,mysql,sqlite]"]
4445

@@ -49,6 +50,7 @@ all = [
4950
"a2a-sdk[grpc]",
5051
"a2a-sdk[telemetry]",
5152
"a2a-sdk[signing]",
53+
"a2a-sdk[db-cli]",
5254
]
5355

5456
[project.urls]
@@ -347,3 +349,17 @@ docstring-code-format = true
347349
docstring-code-line-length = "dynamic"
348350
quote-style = "single"
349351
indent-style = "space"
352+
353+
354+
[tool.alembic]
355+
356+
# path to migration scripts.
357+
script_location = "src/a2a/migrations"
358+
359+
# additional paths to be prepended to sys.path. defaults to the current working directory.
360+
prepend_sys_path = [
361+
"src"
362+
]
363+
364+
[project.scripts]
365+
a2a-db = "a2a.a2a_db_cli:run_migrations"

src/a2a/a2a_db_cli.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import argparse
2+
import logging
3+
import os
4+
5+
from importlib.resources import files
6+
7+
8+
try:
9+
from alembic import command
10+
from alembic.config import Config
11+
12+
except ImportError as e:
13+
raise ImportError(
14+
"CLI requires Alembic. Install with: 'pip install a2a-sdk[db-cli]'."
15+
) from e
16+
17+
18+
def _add_shared_args(
19+
parser: argparse.ArgumentParser, is_sub: bool = False
20+
) -> None:
21+
"""Add common arguments to the given parser."""
22+
prefix = 'sub_' if is_sub else ''
23+
parser.add_argument(
24+
'--database-url',
25+
dest=f'{prefix}database_url',
26+
help='Database URL to use for the migrations. If not set, the DATABASE_URL environment variable will be used.',
27+
)
28+
parser.add_argument(
29+
'--tasks-table',
30+
dest=f'{prefix}tasks_table',
31+
help='Custom tasks table to update. If not set, the default is "tasks".',
32+
)
33+
parser.add_argument(
34+
'--push-notification-configs-table',
35+
dest=f'{prefix}push_notification_configs_table',
36+
help='Custom push notification configs table to update. If not set, the default is "push_notification_configs".',
37+
)
38+
parser.add_argument(
39+
'-v',
40+
'--verbose',
41+
dest=f'{prefix}verbose',
42+
help='Enable verbose output (sets sqlalchemy.engine logging to INFO)',
43+
action='store_true',
44+
)
45+
parser.add_argument(
46+
'--sql',
47+
dest=f'{prefix}sql',
48+
help='Run migrations in sql mode (generate SQL instead of executing)',
49+
action='store_true',
50+
)
51+
52+
53+
def create_parser() -> argparse.ArgumentParser:
54+
"""Create the argument parser for the migration tool."""
55+
parser = argparse.ArgumentParser(description='A2A Database Migration Tool')
56+
57+
# Global options
58+
parser.add_argument(
59+
'--add_columns_owner_last_updated-default-owner',
60+
dest='owner',
61+
help="Value for the 'owner' column (used in specific migrations). If not set defaults to 'unknown'",
62+
)
63+
_add_shared_args(parser)
64+
65+
subparsers = parser.add_subparsers(dest='cmd', help='Migration command')
66+
67+
# Upgrade command
68+
up_parser = subparsers.add_parser(
69+
'upgrade', help='Upgrade to a later version'
70+
)
71+
up_parser.add_argument(
72+
'revision',
73+
nargs='?',
74+
default='head',
75+
help='Revision target (default: head)',
76+
)
77+
up_parser.add_argument(
78+
'--add_columns_owner_last_updated-default-owner',
79+
dest='sub_owner',
80+
help="Value for the 'owner' column (used in specific migrations). If not set defaults to 'legacy_v03_no_user_info'",
81+
)
82+
_add_shared_args(up_parser, is_sub=True)
83+
84+
# Downgrade command
85+
down_parser = subparsers.add_parser(
86+
'downgrade', help='Revert to a previous version'
87+
)
88+
down_parser.add_argument(
89+
'revision',
90+
nargs='?',
91+
default='base',
92+
help='Revision target (e.g., -1, base or a specific ID)',
93+
)
94+
_add_shared_args(down_parser, is_sub=True)
95+
96+
return parser
97+
98+
99+
def run_migrations() -> None:
100+
"""CLI tool to manage database migrations."""
101+
# Configure logging to show INFO messages
102+
logging.basicConfig(level=logging.INFO, format='%(levelname)s %(message)s')
103+
104+
parser = create_parser()
105+
args = parser.parse_args()
106+
107+
# Default to upgrade head if no command is provided
108+
if not args.cmd:
109+
args.cmd = 'upgrade'
110+
args.revision = 'head'
111+
112+
# Locate the bundled alembic.ini
113+
ini_path = files('a2a').joinpath('alembic.ini')
114+
cfg = Config(str(ini_path))
115+
116+
# Dynamically set the script location
117+
migrations_path = files('a2a').joinpath('migrations')
118+
cfg.set_main_option('script_location', str(migrations_path))
119+
120+
# Consolidate owner, db_url, tables, verbose and sql values
121+
owner = args.owner or getattr(args, 'sub_owner', None)
122+
db_url = args.database_url or getattr(args, 'sub_database_url', None)
123+
task_table = args.tasks_table or getattr(args, 'sub_tasks_table', None)
124+
push_notification_configs_table = (
125+
args.push_notification_configs_table
126+
or getattr(args, 'sub_push_notification_configs_table', None)
127+
)
128+
129+
verbose = args.verbose or getattr(args, 'sub_verbose', False)
130+
sql = args.sql or getattr(args, 'sub_sql', False)
131+
132+
# Pass custom arguments to the migration context
133+
if owner:
134+
cfg.set_main_option(
135+
'add_columns_owner_last_updated_default_owner', owner
136+
)
137+
if db_url:
138+
os.environ['DATABASE_URL'] = db_url
139+
if task_table:
140+
cfg.set_main_option('tasks_table', task_table)
141+
if push_notification_configs_table:
142+
cfg.set_main_option(
143+
'push_notification_configs_table', push_notification_configs_table
144+
)
145+
if verbose:
146+
cfg.set_main_option('verbose', 'true')
147+
148+
# Execute the requested command
149+
if args.cmd == 'upgrade':
150+
logging.info('Upgrading database to %s', args.revision)
151+
command.upgrade(cfg, args.revision, sql=sql)
152+
elif args.cmd == 'downgrade':
153+
logging.info('Downgrading database to %s', args.revision)
154+
command.downgrade(cfg, args.revision, sql=sql)
155+
156+
logging.info('Done.')

src/a2a/alembic.ini

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# A generic, single database configuration.
2+
3+
[loggers]
4+
keys = root,sqlalchemy,alembic
5+
6+
[handlers]
7+
keys = console
8+
9+
[formatters]
10+
keys = generic
11+
12+
[logger_root]
13+
level = INFO
14+
handlers = console
15+
qualname =
16+
17+
[logger_sqlalchemy]
18+
level = WARNING
19+
handlers =
20+
qualname = sqlalchemy.engine
21+
22+
[logger_alembic]
23+
level = WARNING
24+
handlers =
25+
qualname = alembic
26+
27+
[handler_console]
28+
class = StreamHandler
29+
args = (sys.stderr,)
30+
level = NOTSET
31+
formatter = generic
32+
33+
[formatter_generic]
34+
format = %(levelname)-5.5s [%(name)s] %(message)s
35+
datefmt = %H:%M:%S

src/a2a/client/base_client.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
AgentCard,
1616
CancelTaskRequest,
1717
CreateTaskPushNotificationConfigRequest,
18+
DeleteTaskPushNotificationConfigRequest,
1819
GetTaskPushNotificationConfigRequest,
1920
GetTaskRequest,
21+
ListTaskPushNotificationConfigsRequest,
22+
ListTaskPushNotificationConfigsResponse,
2023
ListTasksRequest,
2124
ListTasksResponse,
2225
Message,
@@ -189,7 +192,7 @@ async def cancel_task(
189192
request, context=context, extensions=extensions
190193
)
191194

192-
async def set_task_callback(
195+
async def create_task_push_notification_config(
193196
self,
194197
request: CreateTaskPushNotificationConfigRequest,
195198
*,
@@ -206,11 +209,11 @@ async def set_task_callback(
206209
Returns:
207210
The created or updated `TaskPushNotificationConfig` object.
208211
"""
209-
return await self._transport.set_task_callback(
212+
return await self._transport.create_task_push_notification_config(
210213
request, context=context, extensions=extensions
211214
)
212215

213-
async def get_task_callback(
216+
async def get_task_push_notification_config(
214217
self,
215218
request: GetTaskPushNotificationConfigRequest,
216219
*,
@@ -227,7 +230,46 @@ async def get_task_callback(
227230
Returns:
228231
A `TaskPushNotificationConfig` object containing the configuration.
229232
"""
230-
return await self._transport.get_task_callback(
233+
return await self._transport.get_task_push_notification_config(
234+
request, context=context, extensions=extensions
235+
)
236+
237+
async def list_task_push_notification_configs(
238+
self,
239+
request: ListTaskPushNotificationConfigsRequest,
240+
*,
241+
context: ClientCallContext | None = None,
242+
extensions: list[str] | None = None,
243+
) -> ListTaskPushNotificationConfigsResponse:
244+
"""Lists push notification configurations for a specific task.
245+
246+
Args:
247+
request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
248+
context: The client call context.
249+
extensions: List of extensions to be activated.
250+
251+
Returns:
252+
A `ListTaskPushNotificationConfigsResponse` object.
253+
"""
254+
return await self._transport.list_task_push_notification_configs(
255+
request, context=context, extensions=extensions
256+
)
257+
258+
async def delete_task_push_notification_config(
259+
self,
260+
request: DeleteTaskPushNotificationConfigRequest,
261+
*,
262+
context: ClientCallContext | None = None,
263+
extensions: list[str] | None = None,
264+
) -> None:
265+
"""Deletes the push notification configuration for a specific task.
266+
267+
Args:
268+
request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
269+
context: The client call context.
270+
extensions: List of extensions to be activated.
271+
"""
272+
await self._transport.delete_task_push_notification_config(
231273
request, context=context, extensions=extensions
232274
)
233275

src/a2a/client/client.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616
AgentCard,
1717
CancelTaskRequest,
1818
CreateTaskPushNotificationConfigRequest,
19+
DeleteTaskPushNotificationConfigRequest,
1920
GetTaskPushNotificationConfigRequest,
2021
GetTaskRequest,
22+
ListTaskPushNotificationConfigsRequest,
23+
ListTaskPushNotificationConfigsResponse,
2124
ListTasksRequest,
2225
ListTasksResponse,
2326
Message,
2427
PushNotificationConfig,
28+
SendMessageConfiguration,
2529
StreamResponse,
2630
SubscribeToTaskRequest,
2731
Task,
@@ -70,7 +74,7 @@ class ClientConfig:
7074
push_notification_configs: list[PushNotificationConfig] = dataclasses.field(
7175
default_factory=list
7276
)
73-
"""Push notification callbacks to use for every request."""
77+
"""Push notification configurations to use for every request."""
7478

7579
extensions: list[str] = dataclasses.field(default_factory=list)
7680
"""A list of extension URIs the client supports."""
@@ -127,6 +131,7 @@ async def send_message(
127131
self,
128132
request: Message,
129133
*,
134+
configuration: SendMessageConfiguration | None = None,
130135
context: ClientCallContext | None = None,
131136
request_metadata: dict[str, Any] | None = None,
132137
extensions: list[str] | None = None,
@@ -172,7 +177,7 @@ async def cancel_task(
172177
"""Requests the agent to cancel a specific task."""
173178

174179
@abstractmethod
175-
async def set_task_callback(
180+
async def create_task_push_notification_config(
176181
self,
177182
request: CreateTaskPushNotificationConfigRequest,
178183
*,
@@ -182,7 +187,7 @@ async def set_task_callback(
182187
"""Sets or updates the push notification configuration for a specific task."""
183188

184189
@abstractmethod
185-
async def get_task_callback(
190+
async def get_task_push_notification_config(
186191
self,
187192
request: GetTaskPushNotificationConfigRequest,
188193
*,
@@ -191,6 +196,26 @@ async def get_task_callback(
191196
) -> TaskPushNotificationConfig:
192197
"""Retrieves the push notification configuration for a specific task."""
193198

199+
@abstractmethod
200+
async def list_task_push_notification_configs(
201+
self,
202+
request: ListTaskPushNotificationConfigsRequest,
203+
*,
204+
context: ClientCallContext | None = None,
205+
extensions: list[str] | None = None,
206+
) -> ListTaskPushNotificationConfigsResponse:
207+
"""Lists push notification configurations for a specific task."""
208+
209+
@abstractmethod
210+
async def delete_task_push_notification_config(
211+
self,
212+
request: DeleteTaskPushNotificationConfigRequest,
213+
*,
214+
context: ClientCallContext | None = None,
215+
extensions: list[str] | None = None,
216+
) -> None:
217+
"""Deletes the push notification configuration for a specific task."""
218+
194219
@abstractmethod
195220
async def subscribe(
196221
self,

0 commit comments

Comments
 (0)