Skip to content

Commit 426d130

Browse files
committed
feat(compat): Improve compatibility code.
Replaces the dynamic string manipulation for TaskState enum mapping with static bidirectional dictionaries. This is more readable, slightly faster, and safer.
1 parent 419b401 commit 426d130

5 files changed

Lines changed: 106 additions & 44 deletions

File tree

.github/actions/spelling/allow.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ notif
8686
npx
8787
oauthoidc
8888
oidc
89+
Oneof
8990
OpenAPI
9091
openapiv
9192
openapiv2

.github/actions/spelling/expect.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

scripts/gen_proto.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ sed 's/import a2a_pb2 as a2a__pb2/from . import a2a_pb2 as a2a__pb2/g' src/a2a/t
2222

2323
# Download legacy v0.3 compatibility protobuf code
2424
echo "Downloading legacy v0.3 proto file..."
25-
curl -o src/a2a/compat/v0_3/a2a_v0_3.proto https://raw.githubusercontent.com/a2aproject/A2A/a8b45dcc429a5571ef8a24c36336bf84b89bbd7f/specification/grpc/a2a.proto
25+
# b3b266d127dde3d1000ec103b252d1de81289e83 Is a2a.proto version from 0.3 branch with latests fixes.
26+
curl -o src/a2a/compat/v0_3/a2a_v0_3.proto https://raw.githubusercontent.com/a2aproject/A2A/b3b266d127dde3d1000ec103b252d1de81289e83/specification/grpc/a2a.proto
2627

2728
# Generate legacy v0.3 compatibility protobuf code
2829
echo "Generating legacy v0.3 compatibility protobuf code"

src/a2a/compat/v0_3/conversions.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
import base64
22

3-
from typing import Any, cast
3+
from typing import Any
44

55
from google.protobuf.json_format import MessageToDict, ParseDict
66

77
from a2a.compat.v0_3 import types as types_v03
88
from a2a.types import a2a_pb2 as pb2_v10
99

1010

11+
_COMPAT_TO_CORE_TASK_STATE: dict[types_v03.TaskState, Any] = {
12+
types_v03.TaskState.unknown: pb2_v10.TaskState.TASK_STATE_UNSPECIFIED,
13+
types_v03.TaskState.submitted: pb2_v10.TaskState.TASK_STATE_SUBMITTED,
14+
types_v03.TaskState.working: pb2_v10.TaskState.TASK_STATE_WORKING,
15+
types_v03.TaskState.completed: pb2_v10.TaskState.TASK_STATE_COMPLETED,
16+
types_v03.TaskState.failed: pb2_v10.TaskState.TASK_STATE_FAILED,
17+
types_v03.TaskState.canceled: pb2_v10.TaskState.TASK_STATE_CANCELED,
18+
types_v03.TaskState.input_required: pb2_v10.TaskState.TASK_STATE_INPUT_REQUIRED,
19+
types_v03.TaskState.rejected: pb2_v10.TaskState.TASK_STATE_REJECTED,
20+
types_v03.TaskState.auth_required: pb2_v10.TaskState.TASK_STATE_AUTH_REQUIRED,
21+
}
22+
23+
_CORE_TO_COMPAT_TASK_STATE: dict[Any, types_v03.TaskState] = {
24+
v: k for k, v in _COMPAT_TO_CORE_TASK_STATE.items()
25+
}
26+
27+
1128
def to_core_part(compat_part: types_v03.Part) -> pb2_v10.Part:
1229
"""Converts a v0.3 Part (Pydantic model) to a v1.0 core Part (Protobuf object)."""
1330
core_part = pb2_v10.Part()
@@ -142,17 +159,9 @@ def to_core_task_status(
142159
"""Convert task status to v1.0 core type."""
143160
core_status = pb2_v10.TaskStatus()
144161
if compat_status.state:
145-
# map Enum string to proto enum
146-
state_str = compat_status.state.value.upper()
147-
if state_str == 'UNKNOWN':
148-
core_status.state = pb2_v10.TaskState.TASK_STATE_UNSPECIFIED
149-
else:
150-
# Handle 'input-required', 'auth-required'
151-
state_str = state_str.replace('-', '_')
152-
153-
core_status.state = cast(
154-
'Any', pb2_v10.TaskState.Value(f'TASK_STATE_{state_str}')
155-
)
162+
core_status.state = _COMPAT_TO_CORE_TASK_STATE.get(
163+
compat_status.state, pb2_v10.TaskState.TASK_STATE_UNSPECIFIED
164+
)
156165

157166
if compat_status.message:
158167
core_status.message.CopyFrom(to_core_message(compat_status.message))
@@ -167,16 +176,9 @@ def to_compat_task_status(
167176
core_status: pb2_v10.TaskStatus,
168177
) -> types_v03.TaskStatus:
169178
"""Convert task status to v0.3 compat type."""
170-
state_str = (
171-
pb2_v10.TaskState.Name(core_status.state)
172-
.replace('TASK_STATE_', '')
173-
.lower()
179+
state_enum = _CORE_TO_COMPAT_TASK_STATE.get(
180+
core_status.state, types_v03.TaskState.unknown
174181
)
175-
if state_str == 'unspecified':
176-
state_str = 'unknown'
177-
else:
178-
state_str = state_str.replace('_', '-')
179-
state_enum = types_v03.TaskState(state_str)
180182

181183
update = (
182184
to_compat_message(core_status.message)
@@ -380,16 +382,25 @@ def to_compat_task_status_update_event(
380382
core_event: pb2_v10.TaskStatusUpdateEvent,
381383
) -> types_v03.TaskStatusUpdateEvent:
382384
"""Convert task status update event to v0.3 compat type."""
385+
status = (
386+
to_compat_task_status(core_event.status)
387+
if core_event.HasField('status')
388+
else types_v03.TaskStatus(state=types_v03.TaskState.unknown)
389+
)
390+
final = status.state in (
391+
types_v03.TaskState.completed,
392+
types_v03.TaskState.canceled,
393+
types_v03.TaskState.failed,
394+
types_v03.TaskState.rejected,
395+
)
383396
return types_v03.TaskStatusUpdateEvent(
384397
task_id=core_event.task_id,
385398
context_id=core_event.context_id,
386-
status=to_compat_task_status(core_event.status)
387-
if core_event.HasField('status')
388-
else types_v03.TaskStatus(state=types_v03.TaskState.unknown),
399+
status=status,
389400
metadata=MessageToDict(core_event.metadata)
390401
if core_event.HasField('metadata')
391402
else None,
392-
final=False, # 'final' was removed in v1.0
403+
final=final,
393404
)
394405

395406

tests/compat/v0_3/test_conversions.py

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,8 @@ def test_push_notification_config_conversion():
352352
v03_config = types_v03.PushNotificationConfig(
353353
id='c1',
354354
url='http://test.com',
355-
token='tok',
356-
authentication=v03_auth, # noqa: S106
355+
token='tok', # noqa: S106
356+
authentication=v03_auth,
357357
)
358358

359359
v10_expected = pb2_v10.PushNotificationConfig(
@@ -507,11 +507,61 @@ def test_task_status_update_event_conversion():
507507
context_id='c1',
508508
status=v03_status,
509509
metadata={'m': 'v'},
510-
final=False, # final is not preserved in v1.0
510+
final=True, # final is computed based on status.state
511511
)
512512
assert v03_restored == v03_expected_restored
513513

514514

515+
def test_task_status_update_event_conversion_terminal_states():
516+
# Test all terminal states result in final=True
517+
terminal_states = [
518+
(
519+
pb2_v10.TaskState.TASK_STATE_COMPLETED,
520+
types_v03.TaskState.completed,
521+
),
522+
(pb2_v10.TaskState.TASK_STATE_CANCELED, types_v03.TaskState.canceled),
523+
(pb2_v10.TaskState.TASK_STATE_FAILED, types_v03.TaskState.failed),
524+
(pb2_v10.TaskState.TASK_STATE_REJECTED, types_v03.TaskState.rejected),
525+
]
526+
527+
for core_st, compat_st in terminal_states:
528+
v10_event = pb2_v10.TaskStatusUpdateEvent(
529+
status=pb2_v10.TaskStatus(state=core_st)
530+
)
531+
v03_restored = to_compat_task_status_update_event(v10_event)
532+
assert v03_restored.final is True
533+
assert v03_restored.status.state == compat_st
534+
535+
# Test non-terminal states result in final=False
536+
non_terminal_states = [
537+
(
538+
pb2_v10.TaskState.TASK_STATE_SUBMITTED,
539+
types_v03.TaskState.submitted,
540+
),
541+
(pb2_v10.TaskState.TASK_STATE_WORKING, types_v03.TaskState.working),
542+
(
543+
pb2_v10.TaskState.TASK_STATE_INPUT_REQUIRED,
544+
types_v03.TaskState.input_required,
545+
),
546+
(
547+
pb2_v10.TaskState.TASK_STATE_AUTH_REQUIRED,
548+
types_v03.TaskState.auth_required,
549+
),
550+
(
551+
pb2_v10.TaskState.TASK_STATE_UNSPECIFIED,
552+
types_v03.TaskState.unknown,
553+
),
554+
]
555+
556+
for core_st, compat_st in non_terminal_states:
557+
v10_event = pb2_v10.TaskStatusUpdateEvent(
558+
status=pb2_v10.TaskStatus(state=core_st)
559+
)
560+
v03_restored = to_compat_task_status_update_event(v10_event)
561+
assert v03_restored.final is False
562+
assert v03_restored.status.state == compat_st
563+
564+
515565
def test_task_status_update_event_conversion_minimal():
516566
# v03 status is required but might be constructed empty internally
517567
v10_event = pb2_v10.TaskStatusUpdateEvent(task_id='t1', context_id='c1')
@@ -632,16 +682,16 @@ def test_oauth_flows_conversion_auth_code():
632682
def test_oauth_flows_conversion_client_credentials():
633683
v03_flows = types_v03.OAuthFlows(
634684
client_credentials=types_v03.ClientCredentialsOAuthFlow(
635-
token_url='http://token2',
685+
token_url='http://token2', # noqa: S106
636686
scopes={'c': 'd'},
637-
refresh_url='ref2', # noqa: S106
687+
refresh_url='ref2',
638688
)
639689
)
640690
v10_expected = pb2_v10.OAuthFlows(
641691
client_credentials=pb2_v10.ClientCredentialsOAuthFlow(
642-
token_url='http://token2',
692+
token_url='http://token2', # noqa: S106
643693
scopes={'c': 'd'},
644-
refresh_url='ref2', # noqa: S106
694+
refresh_url='ref2',
645695
)
646696
)
647697
v10_flows = to_core_oauth_flows(v03_flows)
@@ -674,16 +724,16 @@ def test_oauth_flows_conversion_implicit():
674724
def test_oauth_flows_conversion_password():
675725
v03_flows = types_v03.OAuthFlows(
676726
password=types_v03.PasswordOAuthFlow(
677-
token_url='http://token3',
727+
token_url='http://token3', # noqa: S106
678728
scopes={'g': 'h'},
679-
refresh_url='ref4', # noqa: S106
729+
refresh_url='ref4',
680730
)
681731
)
682732
v10_expected = pb2_v10.OAuthFlows(
683733
password=pb2_v10.PasswordOAuthFlow(
684-
token_url='http://token3',
734+
token_url='http://token3', # noqa: S106
685735
scopes={'g': 'h'},
686-
refresh_url='ref4', # noqa: S106
736+
refresh_url='ref4',
687737
)
688738
)
689739
v10_flows = to_core_oauth_flows(v03_flows)
@@ -730,8 +780,8 @@ def test_security_scheme_oauth2():
730780
v03_flows = types_v03.OAuthFlows(
731781
authorization_code=types_v03.AuthorizationCodeOAuthFlow(
732782
authorization_url='u',
733-
token_url='t',
734-
scopes={}, # noqa: S106
783+
token_url='t', # noqa: S106
784+
scopes={},
735785
)
736786
)
737787
v03_scheme = types_v03.SecurityScheme(
@@ -1161,8 +1211,8 @@ def test_task_push_notification_config_conversion():
11611211
push_notification_config=types_v03.PushNotificationConfig(
11621212
id='c1',
11631213
url='http://url',
1164-
token='tok',
1165-
authentication=v03_auth, # noqa: S106
1214+
token='tok', # noqa: S106
1215+
authentication=v03_auth,
11661216
),
11671217
)
11681218
v10_expected = pb2_v10.TaskPushNotificationConfig(
@@ -1183,8 +1233,8 @@ def test_task_push_notification_config_conversion():
11831233
push_notification_config=types_v03.PushNotificationConfig(
11841234
id='c1',
11851235
url='http://url',
1186-
token='tok',
1187-
authentication=v03_auth, # noqa: S106
1236+
token='tok', # noqa: S106
1237+
authentication=v03_auth,
11881238
),
11891239
)
11901240
assert v03_restored == v03_expected_restored

0 commit comments

Comments
 (0)