Skip to content

Commit 950c37d

Browse files
refactor(routes): replace oneof cartesian product with allOf+oneOf and inline component install
Replace O(∏fields) Cartesian product oneof generation with O(∑fields) allOf+oneOf-per-group approach, and inline _install_components into add_a2a_routes_to_fastapi to remove a single-use wrapper function.
1 parent 205ffa2 commit 950c37d

3 files changed

Lines changed: 83 additions & 43 deletions

File tree

src/a2a/server/routes/helpers/_proto_schema.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
from google.protobuf.descriptor import Descriptor, FieldDescriptor
66
from google.protobuf.message import Message
77

8-
from a2a.types.a2a_pb2 import (
9-
SendMessageRequest,
10-
TaskPushNotificationConfig,
11-
)
8+
from a2a.types.a2a_pb2 import SendMessageRequest, TaskPushNotificationConfig
129

1310

1411
REST_BODY_TYPES: dict[tuple[str, str], type[Message]] = {
@@ -100,21 +97,22 @@ def message_schema(
10097
components[name] = {'type': 'object', 'properties': base_properties}
10198
return ref
10299

103-
variants: list[dict[str, Any]] = [{}]
104-
for oneof in real_oneofs:
105-
variants = [
106-
{**variant, f.name: field_schema(f, components)}
107-
for variant in variants
108-
for f in oneof.fields
109-
]
110-
components[name] = {
111-
'oneOf': [
112-
{
113-
'type': 'object',
114-
'properties': {**base_properties, **variant_props},
115-
'required': sorted(set(variant_props) & oneof_field_names),
116-
}
117-
for variant_props in variants
118-
],
119-
}
100+
oneof_constraints = [
101+
{
102+
'oneOf': [
103+
{
104+
'type': 'object',
105+
'properties': {f.name: field_schema(f, components)},
106+
'required': [f.name],
107+
}
108+
for f in oneof.fields
109+
]
110+
}
111+
for oneof in real_oneofs
112+
]
113+
parts: list[dict[str, Any]] = []
114+
if base_properties:
115+
parts.append({'type': 'object', 'properties': base_properties})
116+
parts.extend(oneof_constraints)
117+
components[name] = parts[0] if len(parts) == 1 else {'allOf': parts}
120118
return ref

src/a2a/server/routes/helpers/fastapi.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,23 +101,6 @@ def _attach_route(
101101
)
102102

103103

104-
def _install_components(app: 'FastAPI', schemas: dict[str, Any]) -> None:
105-
original_openapi = app.openapi
106-
107-
def _openapi() -> dict[str, Any]:
108-
if app.openapi_schema:
109-
return app.openapi_schema
110-
schema = original_openapi()
111-
component_schemas = schema.setdefault('components', {}).setdefault(
112-
'schemas', {}
113-
)
114-
for name, sub_schema in schemas.items():
115-
component_schemas.setdefault(name, sub_schema)
116-
return schema
117-
118-
app.openapi = _openapi # type: ignore[method-assign]
119-
120-
121104
def add_a2a_routes_to_fastapi(
122105
app: 'FastAPI',
123106
*,
@@ -197,4 +180,17 @@ def add_a2a_routes_to_fastapi(
197180
require_version_header=True,
198181
)
199182

200-
_install_components(app, components)
183+
original_openapi = app.openapi
184+
185+
def _openapi() -> dict[str, Any]:
186+
if app.openapi_schema:
187+
return app.openapi_schema
188+
schema = original_openapi()
189+
component_schemas = schema.setdefault('components', {}).setdefault(
190+
'schemas', {}
191+
)
192+
for name, sub_schema in components.items():
193+
component_schemas.setdefault(name, sub_schema)
194+
return schema
195+
196+
app.openapi = _openapi # type: ignore[method-assign]

tests/server/routes/helpers/test_proto_schema.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,68 @@ def test_message_schema_well_known_type_inline():
3939
assert 'Struct' not in components
4040

4141

42-
def test_message_schema_oneof_becomes_one_of():
42+
def test_message_schema_oneof_becomes_allof_with_one_of_constraint():
4343
components = {}
4444
message_schema(Part.DESCRIPTOR, components)
4545
schema = components['Part']
46-
assert 'oneOf' in schema
47-
oneof_keys = {list(v['properties'])[-1] for v in schema['oneOf']}
46+
assert 'allOf' in schema
47+
one_of_constraint = next(p for p in schema['allOf'] if 'oneOf' in p)
48+
oneof_keys = {list(v['properties'])[0] for v in one_of_constraint['oneOf']}
4849
assert {'text', 'raw', 'url', 'data'} <= oneof_keys
4950

5051

5152
def test_message_schema_oneof_variants_have_required():
5253
components = {}
5354
message_schema(Part.DESCRIPTOR, components)
54-
for variant in components['Part']['oneOf']:
55+
one_of_constraint = next(
56+
p for p in components['Part']['allOf'] if 'oneOf' in p
57+
)
58+
for variant in one_of_constraint['oneOf']:
5559
assert len(variant['required']) == 1
5660

5761

62+
def test_message_schema_multiple_oneofs_use_allof_not_cartesian_product():
63+
# Simulate a descriptor with two oneofs: verify allOf has one constraint
64+
# per oneof rather than a flat list of cross-product variants.
65+
from unittest.mock import MagicMock
66+
67+
def _make_field(name):
68+
f = MagicMock()
69+
f.name = name
70+
f.message_type = None
71+
f.type = 9 # TYPE_STRING
72+
f.is_repeated = False
73+
return f
74+
75+
def _make_oneof(fields):
76+
o = MagicMock()
77+
o.fields = fields
78+
return o
79+
80+
f_a, f_b = _make_field('a'), _make_field('b')
81+
f_x, f_y = _make_field('x'), _make_field('y')
82+
oneof1 = _make_oneof([f_a, f_b])
83+
oneof2 = _make_oneof([f_x, f_y])
84+
85+
descriptor = MagicMock()
86+
descriptor.full_name = 'test.MultiOneof'
87+
descriptor.name = 'MultiOneof'
88+
descriptor.oneofs = [oneof1, oneof2]
89+
descriptor.fields = [f_a, f_b, f_x, f_y]
90+
91+
components = {}
92+
message_schema(descriptor, components)
93+
schema = components['MultiOneof']
94+
95+
# Should be allOf with two oneOf constraints (one per oneof group),
96+
# NOT a flat oneOf with 2*2=4 Cartesian-product variants.
97+
assert 'allOf' in schema
98+
one_of_constraints = [p for p in schema['allOf'] if 'oneOf' in p]
99+
assert len(one_of_constraints) == 2
100+
assert len(one_of_constraints[0]['oneOf']) == 2
101+
assert len(one_of_constraints[1]['oneOf']) == 2
102+
103+
58104
def test_field_schema_repeated_wraps_in_array():
59105
components = {}
60106
msg_descriptor = SendMessageRequest.DESCRIPTOR.fields_by_name[

0 commit comments

Comments
 (0)