Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 51 additions & 8 deletions tests/main/test_public_api_signature_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,21 +344,64 @@ def _type_to_str(tp: Any) -> str:


def _normalize_union_str(type_str: str) -> str:
"""Normalize a union type string by sorting its components."""
"""Normalize a union type string by sorting its components recursively."""
try:
tree = ast.parse(type_str, mode="eval")
except SyntaxError: # pragma: no cover
return type_str
if not isinstance(tree.body, ast.BinOp) or not isinstance(tree.body.op, ast.BitOr):
return type_str

def collect_union_parts(node: ast.expr) -> list[str]:
def normalize_node(node: ast.expr) -> str:
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
return collect_union_parts(node.left) + collect_union_parts(node.right)
return [ast.unparse(node)]

parts = collect_union_parts(tree.body)
return " | ".join(sorted(parts))
def collect_union_parts(n: ast.expr) -> list[str]:
if isinstance(n, ast.BinOp) and isinstance(n.op, ast.BitOr):
return collect_union_parts(n.left) + collect_union_parts(n.right)
return [normalize_node(n)]

parts = collect_union_parts(node)
return " | ".join(sorted(parts))
if isinstance(node, ast.Subscript):
value = ast.unparse(node.value)
if isinstance(node.slice, ast.Tuple):
args = [normalize_node(elt) for elt in node.slice.elts]
return f"{value}[{', '.join(args)}]"
return f"{value}[{normalize_node(node.slice)}]"
return ast.unparse(node)

return normalize_node(tree.body)


@pytest.mark.parametrize(
("input_str", "expected"),
[
("str | int", "int | str"),
("str | int | None", "None | int | str"),
("int", "int"),
("Mapping[str, str]", "Mapping[str, str]"),
("Mapping[str, str | list[str]]", "Mapping[str, list[str] | str]"),
("Mapping[str, list[str] | str]", "Mapping[str, list[str] | str]"),
("Mapping[str, str | list[str]] | None", "Mapping[str, list[str] | str] | None"),
("dict[str, int | str | None]", "dict[str, None | int | str]"),
("list[str | int]", "list[int | str]"),
("tuple[str | int, bool | None]", "tuple[int | str, None | bool]"),
],
)
def test_normalize_union_str(input_str: str, expected: str) -> None:
"""Test _normalize_union_str correctly normalizes union types recursively."""
assert _normalize_union_str(input_str) == expected


@pytest.mark.parametrize(
("type_a", "type_b"),
[
("Mapping[str, str | list[str]]", "Mapping[str, list[str] | str]"),
("str | int | None", "None | str | int"),
("dict[str, int | str]", "dict[str, str | int]"),
],
)
def test_normalize_union_str_equivalence(type_a: str, type_b: str) -> None:
"""Test that different orderings of the same union type normalize to the same string."""
assert _normalize_union_str(type_a) == _normalize_union_str(type_b)


def _normalize_type(tp: Any) -> str: # noqa: PLR0911
Expand Down
Loading