diff --git a/tests/main/test_public_api_signature_baseline.py b/tests/main/test_public_api_signature_baseline.py index 90327169c..00b4d07ea 100644 --- a/tests/main/test_public_api_signature_baseline.py +++ b/tests/main/test_public_api_signature_baseline.py @@ -344,21 +344,64 @@ def _type_to_str(tp: Any) -> str: def _normalize_union_str(type_str: str) -> str: - """Normalize a union type string by sorting its components.""" + """Normalize a union type string by sorting its components recursively.""" try: tree = ast.parse(type_str, mode="eval") except SyntaxError: # pragma: no cover return type_str - if not isinstance(tree.body, ast.BinOp) or not isinstance(tree.body.op, ast.BitOr): - return type_str - def collect_union_parts(node: ast.expr) -> list[str]: + def normalize_node(node: ast.expr) -> str: if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): - return collect_union_parts(node.left) + collect_union_parts(node.right) - return [ast.unparse(node)] - parts = collect_union_parts(tree.body) - return " | ".join(sorted(parts)) + def collect_union_parts(n: ast.expr) -> list[str]: + if isinstance(n, ast.BinOp) and isinstance(n.op, ast.BitOr): + return collect_union_parts(n.left) + collect_union_parts(n.right) + return [normalize_node(n)] + + parts = collect_union_parts(node) + return " | ".join(sorted(parts)) + if isinstance(node, ast.Subscript): + value = ast.unparse(node.value) + if isinstance(node.slice, ast.Tuple): + args = [normalize_node(elt) for elt in node.slice.elts] + return f"{value}[{', '.join(args)}]" + return f"{value}[{normalize_node(node.slice)}]" + return ast.unparse(node) + + return normalize_node(tree.body) + + +@pytest.mark.parametrize( + ("input_str", "expected"), + [ + ("str | int", "int | str"), + ("str | int | None", "None | int | str"), + ("int", "int"), + ("Mapping[str, str]", "Mapping[str, str]"), + ("Mapping[str, str | list[str]]", "Mapping[str, list[str] | str]"), + ("Mapping[str, list[str] | str]", "Mapping[str, list[str] | str]"), + ("Mapping[str, str | list[str]] | None", "Mapping[str, list[str] | str] | None"), + ("dict[str, int | str | None]", "dict[str, None | int | str]"), + ("list[str | int]", "list[int | str]"), + ("tuple[str | int, bool | None]", "tuple[int | str, None | bool]"), + ], +) +def test_normalize_union_str(input_str: str, expected: str) -> None: + """Test _normalize_union_str correctly normalizes union types recursively.""" + assert _normalize_union_str(input_str) == expected + + +@pytest.mark.parametrize( + ("type_a", "type_b"), + [ + ("Mapping[str, str | list[str]]", "Mapping[str, list[str] | str]"), + ("str | int | None", "None | str | int"), + ("dict[str, int | str]", "dict[str, str | int]"), + ], +) +def test_normalize_union_str_equivalence(type_a: str, type_b: str) -> None: + """Test that different orderings of the same union type normalize to the same string.""" + assert _normalize_union_str(type_a) == _normalize_union_str(type_b) def _normalize_type(tp: Any) -> str: # noqa: PLR0911