@@ -344,21 +344,64 @@ def _type_to_str(tp: Any) -> str:
344344
345345
346346def _normalize_union_str (type_str : str ) -> str :
347- """Normalize a union type string by sorting its components."""
347+ """Normalize a union type string by sorting its components recursively ."""
348348 try :
349349 tree = ast .parse (type_str , mode = "eval" )
350350 except SyntaxError : # pragma: no cover
351351 return type_str
352- if not isinstance (tree .body , ast .BinOp ) or not isinstance (tree .body .op , ast .BitOr ):
353- return type_str
354352
355- def collect_union_parts (node : ast .expr ) -> list [ str ] :
353+ def normalize_node (node : ast .expr ) -> str :
356354 if isinstance (node , ast .BinOp ) and isinstance (node .op , ast .BitOr ):
357- return collect_union_parts (node .left ) + collect_union_parts (node .right )
358- return [ast .unparse (node )]
359355
360- parts = collect_union_parts (tree .body )
361- return " | " .join (sorted (parts ))
356+ def collect_union_parts (n : ast .expr ) -> list [str ]:
357+ if isinstance (n , ast .BinOp ) and isinstance (n .op , ast .BitOr ):
358+ return collect_union_parts (n .left ) + collect_union_parts (n .right )
359+ return [normalize_node (n )]
360+
361+ parts = collect_union_parts (node )
362+ return " | " .join (sorted (parts ))
363+ if isinstance (node , ast .Subscript ):
364+ value = ast .unparse (node .value )
365+ if isinstance (node .slice , ast .Tuple ):
366+ args = [normalize_node (elt ) for elt in node .slice .elts ]
367+ return f"{ value } [{ ', ' .join (args )} ]"
368+ return f"{ value } [{ normalize_node (node .slice )} ]"
369+ return ast .unparse (node )
370+
371+ return normalize_node (tree .body )
372+
373+
374+ @pytest .mark .parametrize (
375+ ("input_str" , "expected" ),
376+ [
377+ ("str | int" , "int | str" ),
378+ ("str | int | None" , "None | int | str" ),
379+ ("int" , "int" ),
380+ ("Mapping[str, str]" , "Mapping[str, str]" ),
381+ ("Mapping[str, str | list[str]]" , "Mapping[str, list[str] | str]" ),
382+ ("Mapping[str, list[str] | str]" , "Mapping[str, list[str] | str]" ),
383+ ("Mapping[str, str | list[str]] | None" , "Mapping[str, list[str] | str] | None" ),
384+ ("dict[str, int | str | None]" , "dict[str, None | int | str]" ),
385+ ("list[str | int]" , "list[int | str]" ),
386+ ("tuple[str | int, bool | None]" , "tuple[int | str, None | bool]" ),
387+ ],
388+ )
389+ def test_normalize_union_str (input_str : str , expected : str ) -> None :
390+ """Test _normalize_union_str correctly normalizes union types recursively."""
391+ assert _normalize_union_str (input_str ) == expected
392+
393+
394+ @pytest .mark .parametrize (
395+ ("type_a" , "type_b" ),
396+ [
397+ ("Mapping[str, str | list[str]]" , "Mapping[str, list[str] | str]" ),
398+ ("str | int | None" , "None | str | int" ),
399+ ("dict[str, int | str]" , "dict[str, str | int]" ),
400+ ],
401+ )
402+ def test_normalize_union_str_equivalence (type_a : str , type_b : str ) -> None :
403+ """Test that different orderings of the same union type normalize to the same string."""
404+ assert _normalize_union_str (type_a ) == _normalize_union_str (type_b )
362405
363406
364407def _normalize_type (tp : Any ) -> str : # noqa: PLR0911
0 commit comments