Skip to content

Commit 592723c

Browse files
authored
Fix _normalize_union_str to recursively normalize nested unions (#2876)
1 parent 2c7944d commit 592723c

1 file changed

Lines changed: 51 additions & 8 deletions

File tree

tests/main/test_public_api_signature_baseline.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -344,21 +344,64 @@ def _type_to_str(tp: Any) -> str:
344344

345345

346346
def _normalize_union_str(type_str: str) -> str:
347-
"""Normalize a union type string by sorting its components."""
347+
"""Normalize a union type string by sorting its components recursively."""
348348
try:
349349
tree = ast.parse(type_str, mode="eval")
350350
except SyntaxError: # pragma: no cover
351351
return type_str
352-
if not isinstance(tree.body, ast.BinOp) or not isinstance(tree.body.op, ast.BitOr):
353-
return type_str
354352

355-
def collect_union_parts(node: ast.expr) -> list[str]:
353+
def normalize_node(node: ast.expr) -> str:
356354
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
357-
return collect_union_parts(node.left) + collect_union_parts(node.right)
358-
return [ast.unparse(node)]
359355

360-
parts = collect_union_parts(tree.body)
361-
return " | ".join(sorted(parts))
356+
def collect_union_parts(n: ast.expr) -> list[str]:
357+
if isinstance(n, ast.BinOp) and isinstance(n.op, ast.BitOr):
358+
return collect_union_parts(n.left) + collect_union_parts(n.right)
359+
return [normalize_node(n)]
360+
361+
parts = collect_union_parts(node)
362+
return " | ".join(sorted(parts))
363+
if isinstance(node, ast.Subscript):
364+
value = ast.unparse(node.value)
365+
if isinstance(node.slice, ast.Tuple):
366+
args = [normalize_node(elt) for elt in node.slice.elts]
367+
return f"{value}[{', '.join(args)}]"
368+
return f"{value}[{normalize_node(node.slice)}]"
369+
return ast.unparse(node)
370+
371+
return normalize_node(tree.body)
372+
373+
374+
@pytest.mark.parametrize(
375+
("input_str", "expected"),
376+
[
377+
("str | int", "int | str"),
378+
("str | int | None", "None | int | str"),
379+
("int", "int"),
380+
("Mapping[str, str]", "Mapping[str, str]"),
381+
("Mapping[str, str | list[str]]", "Mapping[str, list[str] | str]"),
382+
("Mapping[str, list[str] | str]", "Mapping[str, list[str] | str]"),
383+
("Mapping[str, str | list[str]] | None", "Mapping[str, list[str] | str] | None"),
384+
("dict[str, int | str | None]", "dict[str, None | int | str]"),
385+
("list[str | int]", "list[int | str]"),
386+
("tuple[str | int, bool | None]", "tuple[int | str, None | bool]"),
387+
],
388+
)
389+
def test_normalize_union_str(input_str: str, expected: str) -> None:
390+
"""Test _normalize_union_str correctly normalizes union types recursively."""
391+
assert _normalize_union_str(input_str) == expected
392+
393+
394+
@pytest.mark.parametrize(
395+
("type_a", "type_b"),
396+
[
397+
("Mapping[str, str | list[str]]", "Mapping[str, list[str] | str]"),
398+
("str | int | None", "None | str | int"),
399+
("dict[str, int | str]", "dict[str, str | int]"),
400+
],
401+
)
402+
def test_normalize_union_str_equivalence(type_a: str, type_b: str) -> None:
403+
"""Test that different orderings of the same union type normalize to the same string."""
404+
assert _normalize_union_str(type_a) == _normalize_union_str(type_b)
362405

363406

364407
def _normalize_type(tp: Any) -> str: # noqa: PLR0911

0 commit comments

Comments
 (0)