From bd3c60a90813a6e18fccff6d2f32a8ffd4f27d1a Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Fri, 1 Dec 2023 07:27:48 +0300 Subject: [PATCH 01/28] Creating if branch to check for TupleExpr --- feature-test.py | 27 +++ mypy/checker.py | 115 ++++++++---- mypy/checkexpr.py | 6 +- mypy/operators.py | 2 + mypy/typeshed/stdlib/typing.pyi | 203 +++++++++++++++++---- mypy/typeshed/stdlib/typing_extensions.pyi | 52 +++++- 6 files changed, 331 insertions(+), 74 deletions(-) create mode 100644 feature-test.py diff --git a/feature-test.py b/feature-test.py new file mode 100644 index 0000000000000..0bdf337f0ec56 --- /dev/null +++ b/feature-test.py @@ -0,0 +1,27 @@ +from enum import Enum +from typing_extensions import assert_never + + +class MyEnum(Enum): + A = 1 + B = 2 + C = 3 + + +reveal_type(MyEnum.A) +reveal_type(MyEnum.B) +reveal_type(MyEnum.C) + + +def my_function(a: MyEnum) -> bool: + + print(type((MyEnum.B, MyEnum.C))) + if a == MyEnum.A: + print(type(a)) + return True + elif a in (MyEnum.B, MyEnum.C): + return False + assert_never(a) + + +print(my_function(MyEnum.A)) diff --git a/mypy/checker.py b/mypy/checker.py index b9a9d3affb90f..8ece558b15d5b 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1247,7 +1247,8 @@ def check_func_def( if ( arg_type.variance == COVARIANT and defn.name not in ("__init__", "__new__", "__post_init__") - and not is_private(defn.name) # private methods are not inherited + # private methods are not inherited + and not is_private(defn.name) ): ctx: Context = arg_type if ctx.line < 0: @@ -3112,7 +3113,8 @@ def check_compatibility_all_supers( if ( isinstance(lvalue_node, Var) and lvalue.kind in (MDEF, None) - and len(lvalue_node.info.bases) > 0 # None for Vars defined via self + # None for Vars defined via self + and len(lvalue_node.info.bases) > 0 ): for base in lvalue_node.info.mro[1:]: tnode = base.names.get(lvalue_node.name) @@ -5701,6 +5703,7 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM operands = [collapse_walrus(x) for x in node.operands] operand_types = [] narrowable_operand_index_to_hash = {} + # print(operands) for i, expr in enumerate(operands): if not self.has_type(expr): return {}, {} @@ -5790,6 +5793,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: is_valid_target, coerce_only_in_literal_context, ) + # print("done2", if_map, else_map) # Strictly speaking, we should also skip this check if the objects in the expr # chain have custom __eq__ or __ne__ methods. But we (maybe optimistically) @@ -5802,57 +5806,98 @@ def has_no_custom_eq_checks(t: Type) -> bool: expr_indices, narrowable_operand_index_to_hash.keys(), ) + # print("done3", if_map, else_map) # If we haven't been able to narrow types yet, we might be dealing with a # explicit type(x) == some_type check if if_map == {} and else_map == {}: if_map, else_map = self.find_type_equals_check(node, expr_indices) + # print("done4", if_map, else_map) elif operator in {"in", "not in"}: assert len(expr_indices) == 2 left_index, right_index = expr_indices item_type = operand_types[left_index] iterable_type = operand_types[right_index] + # added by me + right_expr = operands[right_index] if_map, else_map = {}, {} + # print(left_index, "\n") + # print(right_index, "\n") + # print(narrowable_operand_index_to_hash, "\n") - if left_index in narrowable_operand_index_to_hash: - # We only try and narrow away 'None' for now - if is_overlapping_none(item_type): - collection_item_type = get_proper_type( - builtin_item_type(iterable_type) - ) - if ( - collection_item_type is not None - and not is_overlapping_none(collection_item_type) - and not ( - isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == "builtins.object" - ) - and is_overlapping_erased_types(item_type, collection_item_type) - ): - if_map[operands[left_index]] = remove_optional(item_type) - - if right_index in narrowable_operand_index_to_hash: - if_type, else_type = self.conditional_types_for_iterable( - item_type, iterable_type + if isinstance(right_expr, TupleExpr): + all_literal_enum = all( + self.is_literal_enum(element_expr) for element_expr in right_expr.items ) - expr = operands[right_index] - if if_type is None: - if_map = None - else: - if_map[expr] = if_type - if else_type is None: + if all_literal_enum: + # Set if_map for the entire tuple + if_map = {} else_map = None - else: - else_map[expr] = else_type + + # print(if_type == else_type, "\n") + # print(left_expr, "\n") + # print(right_expr, "\n") + + # print(if_map, else_map) + + else: + if left_index in narrowable_operand_index_to_hash: + # print("left in") + # We only try and narrow away 'None' for now + if is_overlapping_none(item_type): + collection_item_type = get_proper_type( + builtin_item_type(iterable_type) + ) + # print("done5", if_map, else_map) + if ( + collection_item_type is not None + and not is_overlapping_none(collection_item_type) + and not ( + isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == "builtins.object" + ) + and is_overlapping_erased_types( + item_type, collection_item_type + ) + ): + if_map[operands[left_index]] = remove_optional(item_type) + + # print("done6", if_map, else_map) + + if right_index in narrowable_operand_index_to_hash: + # print("right in") + if_type, else_type = self.conditional_types_for_iterable( + item_type, iterable_type + ) + expr = operands[right_index] + # print("done7", if_map, else_map) + + if if_type is None: + if_map = None + # print("done8", if_map, else_map) + else: + if_map[expr] = if_type + # print("done9", if_map, else_map) + + if else_type is None: + else_map = None + # print("done11", if_map, else_map) + else: + else_map[expr] = else_type + + # print("done12", if_map, else_map) else: if_map = {} else_map = {} + # print("done13", if_map, else_map) if operator in {"is not", "!=", "not in"}: if_map, else_map = else_map, if_map + # print("done14", if_map, else_map) + partial_type_maps.append((if_map, else_map)) # If we have found non-trivial restrictions from the regular comparisons, @@ -6144,6 +6189,7 @@ def refine_identity_comparison_expression( expressions in the chain to a Literal type. Performing this coercion is sometimes too aggressive of a narrowing, depending on context. """ + should_coerce = True if coerce_only_in_literal_context: @@ -6239,9 +6285,6 @@ def should_coerce_inner(typ: Type) -> bool: if sum_type_name is not None: expr_type = try_expanding_sum_type_to_union(expr_type, sum_type_name) - # We intentionally use 'conditional_types' directly here instead of - # 'self.conditional_types_with_intersection': we only compute ad-hoc - # intersections when working with pure instances. types = conditional_types(expr_type, target_type) partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) @@ -7189,7 +7232,10 @@ class Foo(Enum): unit for the same reasons we sometimes treat 'True', 'False', or 'None' as a single primitive unit. """ + if not isinstance(n, MemberExpr) or not isinstance(n.expr, NameExpr): + # print(n, isinstance(n, MemberExpr)) + # print(n, isinstance(n.expr, NameExpr)) return False parent_type = self.lookup_type_or_none(n.expr) @@ -7495,7 +7541,8 @@ def builtin_item_type(tp: Type) -> Type | None: else: normalized_items.append(it) if all(not isinstance(it, AnyType) for it in get_proper_types(normalized_items)): - return make_simplified_union(normalized_items) # this type is not externally visible + # this type is not externally visible + return make_simplified_union(normalized_items) elif isinstance(tp, TypedDictType): # TypedDict always has non-optional string keys. Find the key type from the Mapping # base class. diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index da61833bbe5b1..516231c4c11c0 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3473,7 +3473,8 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # a "Need type annotation ..." message, as it would be noise. right_type = self.find_partial_type_ref_fast_path(right) if right_type is None: - right_type = self.accept(right) # Validate the right operand + # Validate the right operand + right_type = self.accept(right) right_type = get_proper_type(right_type) item_types: Sequence[Type] = [right_type] @@ -6094,7 +6095,8 @@ def __init__(self, ignore_in_type_obj: bool) -> None: self.ignore_in_type_obj = ignore_in_type_obj def visit_any(self, t: AnyType) -> bool: - return t.type_of_any != TypeOfAny.special_form # special forms are not real Any types + # special forms are not real Any types + return t.type_of_any != TypeOfAny.special_form def visit_callable_type(self, t: CallableType) -> bool: if self.ignore_in_type_obj and t.is_type_obj(): diff --git a/mypy/operators.py b/mypy/operators.py index d1f050b58faeb..037ebcd97e937 100644 --- a/mypy/operators.py +++ b/mypy/operators.py @@ -119,6 +119,8 @@ "!=": "==", "is": "is not", "is not": "is", + "in": "not in", + "not in": "in", "<": ">=", "<=": ">", ">": "<=", diff --git a/mypy/typeshed/stdlib/typing.pyi b/mypy/typeshed/stdlib/typing.pyi index 7694157d70fe5..38ea6d61514d0 100644 --- a/mypy/typeshed/stdlib/typing.pyi +++ b/mypy/typeshed/stdlib/typing.pyi @@ -107,7 +107,8 @@ if sys.version_info >= (3, 9): __all__ += ["Annotated", "BinaryIO", "IO", "Match", "Pattern", "TextIO"] if sys.version_info >= (3, 10): - __all__ += ["Concatenate", "ParamSpec", "ParamSpecArgs", "ParamSpecKwargs", "TypeAlias", "TypeGuard", "is_typeddict"] + __all__ += ["Concatenate", "ParamSpec", "ParamSpecArgs", + "ParamSpecKwargs", "TypeAlias", "TypeGuard", "is_typeddict"] if sys.version_info >= (3, 11): __all__ += [ @@ -133,10 +134,14 @@ ContextManager = AbstractContextManager AsyncContextManager = AbstractAsyncContextManager # This itself is only available during type checking + + def type_check_only(func_or_cls: _F) -> _F: ... + Any = object() + @_final class TypeVar: @property @@ -152,6 +157,7 @@ class TypeVar: if sys.version_info >= (3, 12): @property def __infer_variance__(self) -> bool: ... + def __init__( self, name: str, @@ -171,10 +177,13 @@ class TypeVar: if sys.version_info >= (3, 11): def __typing_subst__(self, arg: Incomplete) -> Incomplete: ... + # Used for an undocumented mypy feature. Does not exist at runtime. _promote = object() # N.B. Keep this definition in sync with typing_extensions._SpecialForm + + @_final class _SpecialForm: def __getitem__(self, parameters: Any) -> object: ... @@ -182,12 +191,15 @@ class _SpecialForm: def __or__(self, other: Any) -> _SpecialForm: ... def __ror__(self, other: Any) -> _SpecialForm: ... + _F = TypeVar("_F", bound=Callable[..., Any]) _P = _ParamSpec("_P") _T = TypeVar("_T") + def overload(func: _F) -> _F: ... + Union: _SpecialForm Generic: _SpecialForm # Protocol is only present in 3.8 and later, but mypy needs it unconditionally @@ -221,7 +233,8 @@ if sys.version_info >= (3, 11): def __init__(self, name: str) -> None: ... def __iter__(self) -> Any: ... def __typing_subst__(self, arg: Never) -> Never: ... - def __typing_prepare_subst__(self, alias: Incomplete, args: Incomplete) -> Incomplete: ... + def __typing_prepare_subst__( + self, alias: Incomplete, args: Incomplete) -> Incomplete: ... if sys.version_info >= (3, 10): @_final @@ -251,6 +264,7 @@ if sys.version_info >= (3, 10): if sys.version_info >= (3, 12): @property def __infer_variance__(self) -> bool: ... + def __init__( self, name: str, @@ -271,7 +285,8 @@ if sys.version_info >= (3, 10): def kwargs(self) -> ParamSpecKwargs: ... if sys.version_info >= (3, 11): def __typing_subst__(self, arg: Incomplete) -> Incomplete: ... - def __typing_prepare_subst__(self, alias: Incomplete, args: Incomplete) -> Incomplete: ... + def __typing_prepare_subst__( + self, alias: Incomplete, args: Incomplete) -> Incomplete: ... def __or__(self, right: Any) -> _SpecialForm: ... def __ror__(self, left: Any) -> _SpecialForm: ... @@ -298,15 +313,21 @@ _KT_co = TypeVar("_KT_co", covariant=True) # Key type covariant containers. _VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers. _TC = TypeVar("_TC", bound=Type[object]) + def no_type_check(arg: _F) -> _F: ... -def no_type_check_decorator(decorator: Callable[_P, _T]) -> Callable[_P, _T]: ... + + +def no_type_check_decorator( + decorator: Callable[_P, _T]) -> Callable[_P, _T]: ... # Type aliases and type constructors + class _Alias: # Class for defining generic aliases for library types. def __getitem__(self, typeargs: Any) -> Any: ... + List = _Alias() Dict = _Alias() DefaultDict = _Alias() @@ -326,44 +347,55 @@ AnyStr = TypeVar("AnyStr", str, bytes) # noqa: Y001 # Technically in 3.7 this inherited from GenericMeta. But let's not reflect that, since # type checkers tend to assume that Protocols all have the ABCMeta metaclass. + + class _ProtocolMeta(ABCMeta): if sys.version_info >= (3, 12): def __init__(cls, *args: Any, **kwargs: Any) -> None: ... # Abstract base classes. + def runtime_checkable(cls: _TC) -> _TC: ... + + @runtime_checkable class SupportsInt(Protocol, metaclass=ABCMeta): @abstractmethod def __int__(self) -> int: ... + @runtime_checkable class SupportsFloat(Protocol, metaclass=ABCMeta): @abstractmethod def __float__(self) -> float: ... + @runtime_checkable class SupportsComplex(Protocol, metaclass=ABCMeta): @abstractmethod def __complex__(self) -> complex: ... + @runtime_checkable class SupportsBytes(Protocol, metaclass=ABCMeta): @abstractmethod def __bytes__(self) -> bytes: ... + if sys.version_info >= (3, 8): @runtime_checkable class SupportsIndex(Protocol, metaclass=ABCMeta): @abstractmethod def __index__(self) -> int: ... + @runtime_checkable class SupportsAbs(Protocol[_T_co]): @abstractmethod def __abs__(self) -> _T_co: ... + @runtime_checkable class SupportsRound(Protocol[_T_co]): @overload @@ -373,11 +405,13 @@ class SupportsRound(Protocol[_T_co]): @abstractmethod def __round__(self, __ndigits: int) -> _T_co: ... + @runtime_checkable class Sized(Protocol, metaclass=ABCMeta): @abstractmethod def __len__(self) -> int: ... + @runtime_checkable class Hashable(Protocol, metaclass=ABCMeta): # TODO: This is special, in that a subclass of a hashable class may not be hashable @@ -386,40 +420,51 @@ class Hashable(Protocol, metaclass=ABCMeta): @abstractmethod def __hash__(self) -> int: ... + @runtime_checkable class Iterable(Protocol[_T_co]): @abstractmethod def __iter__(self) -> Iterator[_T_co]: ... + @runtime_checkable class Iterator(Iterable[_T_co], Protocol[_T_co]): @abstractmethod def __next__(self) -> _T_co: ... def __iter__(self) -> Iterator[_T_co]: ... + @runtime_checkable class Reversible(Iterable[_T_co], Protocol[_T_co]): @abstractmethod def __reversed__(self) -> Iterator[_T_co]: ... + _YieldT_co = TypeVar("_YieldT_co", covariant=True) _SendT_contra = TypeVar("_SendT_contra", contravariant=True) _ReturnT_co = TypeVar("_ReturnT_co", covariant=True) + class Generator(Iterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra, _ReturnT_co]): def __next__(self) -> _YieldT_co: ... @abstractmethod def send(self, __value: _SendT_contra) -> _YieldT_co: ... + @overload @abstractmethod def throw( self, __typ: Type[BaseException], __val: BaseException | object = None, __tb: TracebackType | None = None ) -> _YieldT_co: ... + @overload @abstractmethod - def throw(self, __typ: BaseException, __val: None = None, __tb: TracebackType | None = None) -> _YieldT_co: ... + def throw(self, __typ: BaseException, __val: None = None, + __tb: TracebackType | None = None) -> _YieldT_co: ... + def close(self) -> None: ... - def __iter__(self) -> Generator[_YieldT_co, _SendT_contra, _ReturnT_co]: ... + def __iter__(self) -> Generator[_YieldT_co, + _SendT_contra, _ReturnT_co]: ... + @property def gi_code(self) -> CodeType: ... @property @@ -429,11 +474,13 @@ class Generator(Iterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra, _Return @property def gi_yieldfrom(self) -> Generator[Any, Any, Any] | None: ... + @runtime_checkable class Awaitable(Protocol[_T_co]): @abstractmethod def __await__(self) -> Generator[Any, None, _T_co]: ... + class Coroutine(Awaitable[_ReturnT_co], Generic[_YieldT_co, _SendT_contra, _ReturnT_co]): __name__: str __qualname__: str @@ -447,50 +494,64 @@ class Coroutine(Awaitable[_ReturnT_co], Generic[_YieldT_co, _SendT_contra, _Retu def cr_running(self) -> bool: ... @abstractmethod def send(self, __value: _SendT_contra) -> _YieldT_co: ... + @overload @abstractmethod def throw( self, __typ: Type[BaseException], __val: BaseException | object = None, __tb: TracebackType | None = None ) -> _YieldT_co: ... + @overload @abstractmethod - def throw(self, __typ: BaseException, __val: None = None, __tb: TracebackType | None = None) -> _YieldT_co: ... + def throw(self, __typ: BaseException, __val: None = None, + __tb: TracebackType | None = None) -> _YieldT_co: ... + @abstractmethod def close(self) -> None: ... # NOTE: This type does not exist in typing.py or PEP 484 but mypy needs it to exist. # The parameters correspond to Generator, but the 4th is the original type. + + @type_check_only class AwaitableGenerator( Awaitable[_ReturnT_co], Generator[_YieldT_co, _SendT_contra, _ReturnT_co], Generic[_YieldT_co, _SendT_contra, _ReturnT_co, _S], metaclass=ABCMeta, -): ... +): + ... + @runtime_checkable class AsyncIterable(Protocol[_T_co]): @abstractmethod def __aiter__(self) -> AsyncIterator[_T_co]: ... + @runtime_checkable class AsyncIterator(AsyncIterable[_T_co], Protocol[_T_co]): @abstractmethod def __anext__(self) -> Awaitable[_T_co]: ... def __aiter__(self) -> AsyncIterator[_T_co]: ... + class AsyncGenerator(AsyncIterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra]): def __anext__(self) -> Awaitable[_YieldT_co]: ... @abstractmethod def asend(self, __value: _SendT_contra) -> Awaitable[_YieldT_co]: ... + @overload @abstractmethod def athrow( self, __typ: Type[BaseException], __val: BaseException | object = None, __tb: TracebackType | None = None ) -> Awaitable[_YieldT_co]: ... + @overload @abstractmethod - def athrow(self, __typ: BaseException, __val: None = None, __tb: TracebackType | None = None) -> Awaitable[_YieldT_co]: ... + def athrow(self, __typ: BaseException, __val: None = None, + __tb: TracebackType | None = None) -> Awaitable[_YieldT_co]: ... + def aclose(self) -> Awaitable[None]: ... @property def ag_await(self) -> Any: ... @@ -501,18 +562,21 @@ class AsyncGenerator(AsyncIterator[_YieldT_co], Generic[_YieldT_co, _SendT_contr @property def ag_running(self) -> bool: ... + @runtime_checkable class Container(Protocol[_T_co]): # This is generic more on vibes than anything else @abstractmethod def __contains__(self, __x: object) -> bool: ... + @runtime_checkable class Collection(Iterable[_T_co], Container[_T_co], Protocol[_T_co]): # Implement Sized (but don't have it as a base class). @abstractmethod def __len__(self) -> int: ... + class Sequence(Collection[_T_co], Reversible[_T_co]): @overload @abstractmethod @@ -527,6 +591,7 @@ class Sequence(Collection[_T_co], Reversible[_T_co]): def __iter__(self) -> Iterator[_T_co]: ... def __reversed__(self) -> Iterator[_T_co]: ... + class MutableSequence(Sequence[_T]): @abstractmethod def insert(self, index: int, value: _T) -> None: ... @@ -557,6 +622,7 @@ class MutableSequence(Sequence[_T]): def remove(self, value: _T) -> None: ... def __iadd__(self, values: Iterable[_T]) -> typing_extensions.Self: ... + class AbstractSet(Collection[_T_co]): @abstractmethod def __contains__(self, x: object) -> bool: ... @@ -573,6 +639,7 @@ class AbstractSet(Collection[_T_co]): def __eq__(self, other: object) -> bool: ... def isdisjoint(self, other: Iterable[Any]) -> bool: ... + class MutableSet(AbstractSet[_T]): @abstractmethod def add(self, value: _T) -> None: ... @@ -582,17 +649,25 @@ class MutableSet(AbstractSet[_T]): def clear(self) -> None: ... def pop(self) -> _T: ... def remove(self, value: _T) -> None: ... - def __ior__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] + def __ior__( + self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] + def __iand__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... - def __ixor__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] + def __ixor__( + self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] + def __isub__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... + class MappingView(Sized): def __init__(self, mapping: Mapping[Any, Any]) -> None: ... # undocumented def __len__(self) -> int: ... + class ItemsView(MappingView, AbstractSet[tuple[_KT_co, _VT_co]], Generic[_KT_co, _VT_co]): - def __init__(self, mapping: Mapping[_KT_co, _VT_co]) -> None: ... # undocumented + def __init__( + self, mapping: Mapping[_KT_co, _VT_co]) -> None: ... # undocumented + def __and__(self, other: Iterable[Any]) -> set[tuple[_KT_co, _VT_co]]: ... def __rand__(self, other: Iterable[_T]) -> set[_T]: ... def __contains__(self, item: object) -> bool: ... @@ -600,15 +675,24 @@ class ItemsView(MappingView, AbstractSet[tuple[_KT_co, _VT_co]], Generic[_KT_co, if sys.version_info >= (3, 8): def __reversed__(self) -> Iterator[tuple[_KT_co, _VT_co]]: ... - def __or__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... - def __ror__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... + def __or__(self, other: Iterable[_T]) -> set[tuple[_KT_co, + _VT_co] | _T]: ... + def __ror__(self, other: Iterable[_T] + ) -> set[tuple[_KT_co, _VT_co] | _T]: ... + def __sub__(self, other: Iterable[Any]) -> set[tuple[_KT_co, _VT_co]]: ... def __rsub__(self, other: Iterable[_T]) -> set[_T]: ... - def __xor__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... - def __rxor__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... + + def __xor__(self, other: Iterable[_T] + ) -> set[tuple[_KT_co, _VT_co] | _T]: ... + def __rxor__(self, other: Iterable[_T] + ) -> set[tuple[_KT_co, _VT_co] | _T]: ... + class KeysView(MappingView, AbstractSet[_KT_co]): - def __init__(self, mapping: Mapping[_KT_co, Any]) -> None: ... # undocumented + def __init__( + self, mapping: Mapping[_KT_co, Any]) -> None: ... # undocumented + def __and__(self, other: Iterable[Any]) -> set[_KT_co]: ... def __rand__(self, other: Iterable[_T]) -> set[_T]: ... def __contains__(self, key: object) -> bool: ... @@ -623,13 +707,17 @@ class KeysView(MappingView, AbstractSet[_KT_co]): def __xor__(self, other: Iterable[_T]) -> set[_KT_co | _T]: ... def __rxor__(self, other: Iterable[_T]) -> set[_KT_co | _T]: ... + class ValuesView(MappingView, Collection[_VT_co]): - def __init__(self, mapping: Mapping[Any, _VT_co]) -> None: ... # undocumented + def __init__(self, mapping: Mapping[Any, + _VT_co]) -> None: ... # undocumented + def __contains__(self, value: object) -> bool: ... def __iter__(self) -> Iterator[_VT_co]: ... if sys.version_info >= (3, 8): def __reversed__(self) -> Iterator[_VT_co]: ... + class Mapping(Collection[_KT], Generic[_KT, _VT_co]): # TODO: We wish the key type could also be covariant, but that doesn't work, # see discussion in https://github.com/python/typing/pull/273. @@ -646,6 +734,7 @@ class Mapping(Collection[_KT], Generic[_KT, _VT_co]): def __contains__(self, __key: object) -> bool: ... def __eq__(self, __other: object) -> bool: ... + class MutableMapping(Mapping[_KT, _VT]): @abstractmethod def __setitem__(self, __key: _KT, __value: _VT) -> None: ... @@ -665,8 +754,11 @@ class MutableMapping(Mapping[_KT, _VT]): # -- collections.OrderedDict.setdefault # -- collections.ChainMap.setdefault # -- weakref.WeakKeyDictionary.setdefault + @overload - def setdefault(self: MutableMapping[_KT, _T | None], __key: _KT, __default: None = None) -> _T | None: ... + def setdefault(self: MutableMapping[_KT, _T | None], + __key: _KT, __default: None = None) -> _T | None: ... + @overload def setdefault(self, __key: _KT, __default: _VT) -> _VT: ... # 'update' used to take a Union, but using overloading is better. @@ -689,13 +781,19 @@ class MutableMapping(Mapping[_KT, _VT]): # -- peewee.attrdict.__iadd__ # -- weakref.WeakValueDictionary.__ior__ # -- weakref.WeakKeyDictionary.__ior__ + @overload - def update(self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT) -> None: ... + def update( + self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT) -> None: ... + @overload - def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None: ... + def update(self, __m: Iterable[tuple[_KT, _VT]], + **kwargs: _VT) -> None: ... + @overload def update(self, **kwargs: _VT) -> None: ... + Text = str TYPE_CHECKING: bool @@ -703,6 +801,8 @@ TYPE_CHECKING: bool # In stubs, the arguments of the IO class are marked as positional-only. # This differs from runtime, but better reflects the fact that in reality # classes deriving from IO use different names for the arguments. + + class IO(Iterator[AnyStr]): # At runtime these are all abstract properties, # but making them abstract in the stub is hugely disruptive, for not much gain. @@ -753,9 +853,12 @@ class IO(Iterator[AnyStr]): @abstractmethod @overload def writelines(self: IO[str], __lines: Iterable[str]) -> None: ... + @abstractmethod @overload - def writelines(self: IO[bytes], __lines: Iterable[ReadableBuffer]) -> None: ... + def writelines(self: IO[bytes], + __lines: Iterable[ReadableBuffer]) -> None: ... + @abstractmethod @overload def writelines(self, __lines: Iterable[AnyStr]) -> None: ... @@ -765,15 +868,18 @@ class IO(Iterator[AnyStr]): def __iter__(self) -> Iterator[AnyStr]: ... @abstractmethod def __enter__(self) -> IO[AnyStr]: ... + @abstractmethod def __exit__( self, __type: Type[BaseException] | None, __value: BaseException | None, __traceback: TracebackType | None ) -> None: ... + class BinaryIO(IO[bytes]): @abstractmethod def __enter__(self) -> BinaryIO: ... + class TextIO(IO[str]): # See comment regarding the @properties in the `IO` class @property @@ -789,6 +895,7 @@ class TextIO(IO[str]): @abstractmethod def __enter__(self) -> TextIO: ... + ByteString: typing_extensions.TypeAlias = bytes | bytearray | memoryview # Functions @@ -834,6 +941,7 @@ if sys.version_info >= (3, 8): else: def get_origin(tp: Any) -> Any | None: ... + @overload def cast(typ: Type[_T], val: Any) -> _T: ... @overload @@ -841,12 +949,16 @@ def cast(typ: str, val: Any) -> Any: ... @overload def cast(typ: object, val: Any) -> Any: ... + if sys.version_info >= (3, 11): def reveal_type(__obj: _T) -> _T: ... def assert_never(__arg: Never) -> Never: ... def assert_type(__val: _T, __typ: Any) -> _T: ... def clear_overloads() -> None: ... - def get_overloads(func: Callable[..., object]) -> Sequence[Callable[..., object]]: ... + + def get_overloads(func: Callable[..., object] + ) -> Sequence[Callable[..., object]]: ... + def dataclass_transform( *, eq_default: bool = True, @@ -859,6 +971,7 @@ if sys.version_info >= (3, 11): # Type constructors + class NamedTuple(tuple[Any, ...]): if sys.version_info < (3, 8): _field_types: ClassVar[collections.OrderedDict[str, type]] @@ -870,10 +983,15 @@ class NamedTuple(tuple[Any, ...]): # So we only add it to the stub on 3.12+. if sys.version_info >= (3, 12): __orig_bases__: ClassVar[tuple[Any, ...]] + @overload - def __init__(self, __typename: str, __fields: Iterable[tuple[str, Any]]) -> None: ... + def __init__(self, __typename: str, + __fields: Iterable[tuple[str, Any]]) -> None: ... + @overload - def __init__(self, __typename: str, __fields: None = None, **kwargs: Any) -> None: ... + def __init__(self, __typename: str, __fields: None = None, + **kwargs: Any) -> None: ... + @classmethod def _make(cls, iterable: Iterable[Any]) -> typing_extensions.Self: ... if sys.version_info >= (3, 8): @@ -885,6 +1003,8 @@ class NamedTuple(tuple[Any, ...]): # Internal mypy fallback type for all typed dicts (does not exist at runtime) # N.B. Keep this mostly in sync with typing_extensions._TypedDict/mypy_extensions._TypedDict + + @type_check_only class _TypedDict(Mapping[str, object], metaclass=ABCMeta): __total__: ClassVar[bool] @@ -895,12 +1015,15 @@ class _TypedDict(Mapping[str, object], metaclass=ABCMeta): # so we only add it to the stub on 3.12+ if sys.version_info >= (3, 12): __orig_bases__: ClassVar[tuple[Any, ...]] + def copy(self) -> typing_extensions.Self: ... # Using Never so that only calls using mypy plugin hook that specialize the signature # can go through. def setdefault(self, k: _Never, default: object) -> object: ... # Mypy plugin hook for 'pop' expects that 'default' has a type variable type. - def pop(self, k: _Never, default: _T = ...) -> object: ... # pyright: ignore[reportInvalidTypeVarUse] + def pop(self, k: _Never, + default: _T = ...) -> object: ... # pyright: ignore[reportInvalidTypeVarUse] + def update(self: _T, __m: _T) -> None: ... def __delitem__(self, k: _Never) -> None: ... def items(self) -> dict_items[str, object]: ... @@ -908,15 +1031,22 @@ class _TypedDict(Mapping[str, object], metaclass=ABCMeta): def values(self) -> dict_values[str, object]: ... if sys.version_info >= (3, 9): @overload - def __or__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... + def __or__( + self, __value: typing_extensions.Self) -> typing_extensions.Self: ... + @overload def __or__(self, __value: dict[str, Any]) -> dict[str, object]: ... + @overload - def __ror__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... + def __ror__( + self, __value: typing_extensions.Self) -> typing_extensions.Self: ... + @overload def __ror__(self, __value: dict[str, Any]) -> dict[str, object]: ... # supposedly incompatible definitions of __or__ and __ior__ - def __ior__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # type: ignore[misc] + def __ior__( + self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # type: ignore[misc] + @_final class ForwardRef: @@ -929,7 +1059,8 @@ class ForwardRef: __forward_module__: Any | None if sys.version_info >= (3, 9): # The module and is_class arguments were added in later Python 3.9 versions. - def __init__(self, arg: str, is_argument: bool = True, module: Any | None = None, *, is_class: bool = False) -> None: ... + def __init__(self, arg: str, is_argument: bool = True, module: Any | + None = None, *, is_class: bool = False) -> None: ... else: def __init__(self, arg: str, is_argument: bool = True) -> None: ... @@ -938,7 +1069,8 @@ class ForwardRef: self, globalns: dict[str, Any] | None, localns: dict[str, Any] | None, recursive_guard: frozenset[str] ) -> Any | None: ... else: - def _evaluate(self, globalns: dict[str, Any] | None, localns: dict[str, Any] | None) -> Any | None: ... + def _evaluate(self, globalns: dict[str, Any] | None, + localns: dict[str, Any] | None) -> Any | None: ... def __eq__(self, other: object) -> bool: ... def __hash__(self) -> int: ... @@ -946,13 +1078,17 @@ class ForwardRef: def __or__(self, other: Any) -> _SpecialForm: ... def __ror__(self, other: Any) -> _SpecialForm: ... + if sys.version_info >= (3, 10): def is_typeddict(tp: object) -> bool: ... + def _type_repr(obj: object) -> str: ... + if sys.version_info >= (3, 12): def override(__method: _F) -> _F: ... + @_final class TypeAliasType: def __init__( @@ -960,8 +1096,11 @@ if sys.version_info >= (3, 12): ) -> None: ... @property def __value__(self) -> Any: ... + @property - def __type_params__(self) -> tuple[TypeVar | ParamSpec | TypeVarTuple, ...]: ... + def __type_params__(self) -> tuple[TypeVar | + ParamSpec | TypeVarTuple, ...]: ... + @property def __parameters__(self) -> tuple[Any, ...]: ... @property diff --git a/mypy/typeshed/stdlib/typing_extensions.pyi b/mypy/typeshed/stdlib/typing_extensions.pyi index b5e2341cd020a..57d3e76cae3ce 100644 --- a/mypy/typeshed/stdlib/typing_extensions.pyi +++ b/mypy/typeshed/stdlib/typing_extensions.pyi @@ -189,12 +189,15 @@ _F = typing.TypeVar("_F", bound=Callable[..., Any]) _TC = typing.TypeVar("_TC", bound=type[object]) # unfortunately we have to duplicate this class definition from typing.pyi or we break pytype + + class _SpecialForm: def __getitem__(self, parameters: Any) -> object: ... if sys.version_info >= (3, 10): def __or__(self, other: Any) -> _SpecialForm: ... def __ror__(self, other: Any) -> _SpecialForm: ... + # Do not import (and re-export) Protocol or runtime_checkable from # typing module because type checkers need to be able to distinguish # typing.Protocol and typing_extensions.Protocol so they can properly @@ -202,20 +205,27 @@ class _SpecialForm: # on older versions of Python. Protocol: _SpecialForm + def runtime_checkable(cls: _TC) -> _TC: ... + # This alias for above is kept here for backwards compatibility. runtime = runtime_checkable Final: _SpecialForm + def final(f: _F) -> _F: ... + Literal: _SpecialForm + def IntVar(name: str) -> Any: ... # returns a new TypeVar # Internal mypy fallback type for all typed dicts (does not exist at runtime) # N.B. Keep this mostly in sync with typing._TypedDict/mypy_extensions._TypedDict + + @type_check_only class _TypedDict(Mapping[str, object], metaclass=abc.ABCMeta): __required_keys__: ClassVar[frozenset[str]] @@ -227,7 +237,9 @@ class _TypedDict(Mapping[str, object], metaclass=abc.ABCMeta): # can go through. def setdefault(self, k: Never, default: object) -> object: ... # Mypy plugin hook for 'pop' expects that 'default' has a type variable type. - def pop(self, k: Never, default: _T = ...) -> object: ... # pyright: ignore[reportInvalidTypeVarUse] + def pop(self, k: Never, + default: _T = ...) -> object: ... # pyright: ignore[reportInvalidTypeVarUse] + def update(self: _T, __m: _T) -> None: ... def items(self) -> dict_items[str, object]: ... def keys(self) -> dict_keys[str, object]: ... @@ -245,11 +257,13 @@ class _TypedDict(Mapping[str, object], metaclass=abc.ABCMeta): # supposedly incompatible definitions of `__ior__` and `__or__`: def __ior__(self, __value: Self) -> Self: ... # type: ignore[misc] + # TypedDict is a (non-subscriptable) special form. TypedDict: object OrderedDict = _Alias() + def get_type_hints( obj: Callable[..., Any], globalns: dict[str, Any] | None = None, @@ -258,6 +272,7 @@ def get_type_hints( ) -> dict[str, Any]: ... def get_args(tp: Any) -> tuple[Any, ...]: ... + if sys.version_info >= (3, 10): @overload def get_origin(tp: UnionType) -> type[UnionType]: ... @@ -266,19 +281,23 @@ if sys.version_info >= (3, 9): @overload def get_origin(tp: GenericAlias) -> type: ... + @overload def get_origin(tp: ParamSpecArgs | ParamSpecKwargs) -> ParamSpec: ... @overload def get_origin(tp: Any) -> Any | None: ... + Annotated: _SpecialForm _AnnotatedAlias: Any # undocumented + @runtime_checkable class SupportsIndex(Protocol, metaclass=abc.ABCMeta): @abc.abstractmethod def __index__(self) -> int: ... + # New and changed things in 3.10 if sys.version_info >= (3, 10): from typing import ( @@ -338,7 +357,9 @@ else: def assert_never(__arg: Never) -> Never: ... def assert_type(__val: _T, __typ: Any) -> _T: ... def clear_overloads() -> None: ... - def get_overloads(func: Callable[..., object]) -> Sequence[Callable[..., object]]: ... + + def get_overloads(func: Callable[..., object] + ) -> Sequence[Callable[..., object]]: ... Required: _SpecialForm NotRequired: _SpecialForm @@ -363,10 +384,15 @@ else: _field_defaults: ClassVar[dict[str, Any]] _fields: ClassVar[tuple[str, ...]] __orig_bases__: ClassVar[tuple[Any, ...]] + @overload - def __init__(self, typename: str, fields: Iterable[tuple[str, Any]] = ...) -> None: ... + def __init__(self, typename: str, + fields: Iterable[tuple[str, Any]] = ...) -> None: ... + @overload - def __init__(self, typename: str, fields: None = None, **kwargs: Any) -> None: ... + def __init__(self, typename: str, fields: None = None, + **kwargs: Any) -> None: ... + @classmethod def _make(cls, iterable: Iterable[Any]) -> Self: ... if sys.version_info >= (3, 8): @@ -380,6 +406,8 @@ else: # The `default` parameter was added to TypeVar, ParamSpec, and TypeVarTuple (PEP 696) # The `infer_variance` parameter was added to TypeVar in 3.12 (PEP 695) # typing_extensions.override (PEP 698) + + @final class TypeVar: @property @@ -396,6 +424,7 @@ class TypeVar: def __infer_variance__(self) -> bool: ... @property def __default__(self) -> Any | None: ... + def __init__( self, name: str, @@ -412,6 +441,7 @@ class TypeVar: if sys.version_info >= (3, 11): def __typing_subst__(self, arg: Incomplete) -> Incomplete: ... + @final class ParamSpec: @property @@ -426,6 +456,7 @@ class ParamSpec: def __infer_variance__(self) -> bool: ... @property def __default__(self) -> Any | None: ... + def __init__( self, name: str, @@ -440,6 +471,7 @@ class ParamSpec: @property def kwargs(self) -> ParamSpecKwargs: ... + @final class TypeVarTuple: @property @@ -449,7 +481,10 @@ class TypeVarTuple: def __init__(self, name: str, *, default: Any | None = None) -> None: ... def __iter__(self) -> Any: ... # Unpack[Self] -def deprecated(__msg: str, *, category: type[Warning] | None = ..., stacklevel: int = 1) -> Callable[[_T], _T]: ... + +def deprecated(__msg: str, *, category: type[Warning] | + None = ..., stacklevel: int = 1) -> Callable[[_T], _T]: ... + if sys.version_info >= (3, 12): from collections.abc import Buffer as Buffer @@ -458,6 +493,7 @@ if sys.version_info >= (3, 12): else: def override(__arg: _F) -> _F: ... def get_original_bases(__cls: type) -> tuple[Any, ...]: ... + @final class TypeAliasType: def __init__( @@ -465,8 +501,11 @@ else: ) -> None: ... @property def __value__(self) -> Any: ... + @property - def __type_params__(self) -> tuple[TypeVar | ParamSpec | TypeVarTuple, ...]: ... + def __type_params__(self) -> tuple[TypeVar | + ParamSpec | TypeVarTuple, ...]: ... + @property def __parameters__(self) -> tuple[Any, ...]: ... @property @@ -491,6 +530,7 @@ else: def is_protocol(__tp: type) -> bool: ... def get_protocol_members(__tp: type) -> frozenset[str]: ... + class Doc: documentation: str def __init__(self, __documentation: str) -> None: ... From 8a0aee6aa4c818b10c4e46d7fa0a2ed8e4598dc3 Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Fri, 1 Dec 2023 07:28:40 +0300 Subject: [PATCH 02/28] Checking if all elements iin the TupleExpr are enums --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 8ece558b15d5b..33f3ed917be3a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5830,7 +5830,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: all_literal_enum = all( self.is_literal_enum(element_expr) for element_expr in right_expr.items ) - if all_literal_enum: + if all_literal_enum: # Set if_map for the entire tuple if_map = {} else_map = None From 23d13f950e36762310374439ad2f057dd151a3f5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Dec 2023 04:35:10 +0000 Subject: [PATCH 03/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- feature-test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/feature-test.py b/feature-test.py index 0bdf337f0ec56..6ef1dc9e50bc1 100644 --- a/feature-test.py +++ b/feature-test.py @@ -14,7 +14,6 @@ class MyEnum(Enum): def my_function(a: MyEnum) -> bool: - print(type((MyEnum.B, MyEnum.C))) if a == MyEnum.A: print(type(a)) From f73e3fa187691449067816886ed72aaf4a884d8c Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Fri, 1 Dec 2023 07:44:54 +0300 Subject: [PATCH 04/28] Fixed indent issues --- feature-test.py | 27 --------------------------- mypy/checker.py | 10 +++------- 2 files changed, 3 insertions(+), 34 deletions(-) delete mode 100644 feature-test.py diff --git a/feature-test.py b/feature-test.py deleted file mode 100644 index 0bdf337f0ec56..0000000000000 --- a/feature-test.py +++ /dev/null @@ -1,27 +0,0 @@ -from enum import Enum -from typing_extensions import assert_never - - -class MyEnum(Enum): - A = 1 - B = 2 - C = 3 - - -reveal_type(MyEnum.A) -reveal_type(MyEnum.B) -reveal_type(MyEnum.C) - - -def my_function(a: MyEnum) -> bool: - - print(type((MyEnum.B, MyEnum.C))) - if a == MyEnum.A: - print(type(a)) - return True - elif a in (MyEnum.B, MyEnum.C): - return False - assert_never(a) - - -print(my_function(MyEnum.A)) diff --git a/mypy/checker.py b/mypy/checker.py index 33f3ed917be3a..b8061c2f18b7e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5830,16 +5830,12 @@ def has_no_custom_eq_checks(t: Type) -> bool: all_literal_enum = all( self.is_literal_enum(element_expr) for element_expr in right_expr.items ) - if all_literal_enum: + if all_literal_enum: # Set if_map for the entire tuple if_map = {} else_map = None - - # print(if_type == else_type, "\n") - # print(left_expr, "\n") - # print(right_expr, "\n") - - # print(if_map, else_map) + else: + if_map, else_map = {}, {} else: if left_index in narrowable_operand_index_to_hash: From e14adcbbed3419f690a2c65fbebf6607c008445f Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Fri, 1 Dec 2023 08:21:04 +0300 Subject: [PATCH 05/28] Modified operators.py --- mypy/operators.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mypy/operators.py b/mypy/operators.py index 037ebcd97e937..b6f6599ef65ff 100644 --- a/mypy/operators.py +++ b/mypy/operators.py @@ -31,7 +31,8 @@ op_methods_to_symbols: Final = {v: k for (k, v) in op_methods.items()} -ops_falling_back_to_cmp: Final = {"__ne__", "__eq__", "__lt__", "__le__", "__gt__", "__ge__"} +ops_falling_back_to_cmp: Final = { + "__ne__", "__eq__", "__lt__", "__le__", "__gt__", "__ge__"} ops_with_inplace_method: Final = { @@ -50,7 +51,8 @@ ">>", } -inplace_operator_methods: Final = {"__i" + op_methods[op][2:] for op in ops_with_inplace_method} +inplace_operator_methods: Final = { + "__i" + op_methods[op][2:] for op in ops_with_inplace_method} reverse_op_methods: Final = { "__add__": "__radd__", @@ -119,8 +121,6 @@ "!=": "==", "is": "is not", "is not": "is", - "in": "not in", - "not in": "in", "<": ">=", "<=": ">", ">": "<=", From 018cf15ec95981bd575ef7b4652dda4c84378473 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Dec 2023 05:21:32 +0000 Subject: [PATCH 06/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/operators.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mypy/operators.py b/mypy/operators.py index b6f6599ef65ff..d1f050b58faeb 100644 --- a/mypy/operators.py +++ b/mypy/operators.py @@ -31,8 +31,7 @@ op_methods_to_symbols: Final = {v: k for (k, v) in op_methods.items()} -ops_falling_back_to_cmp: Final = { - "__ne__", "__eq__", "__lt__", "__le__", "__gt__", "__ge__"} +ops_falling_back_to_cmp: Final = {"__ne__", "__eq__", "__lt__", "__le__", "__gt__", "__ge__"} ops_with_inplace_method: Final = { @@ -51,8 +50,7 @@ ">>", } -inplace_operator_methods: Final = { - "__i" + op_methods[op][2:] for op in ops_with_inplace_method} +inplace_operator_methods: Final = {"__i" + op_methods[op][2:] for op in ops_with_inplace_method} reverse_op_methods: Final = { "__add__": "__radd__", From 38102c7f7243d91810d9d47cad1a0f305965df3c Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Fri, 1 Dec 2023 09:36:47 +0300 Subject: [PATCH 07/28] Fixed formatting issues --- mypy/typeshed/stdlib/typing.pyi | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/mypy/typeshed/stdlib/typing.pyi b/mypy/typeshed/stdlib/typing.pyi index 38ea6d61514d0..b942c35894d48 100644 --- a/mypy/typeshed/stdlib/typing.pyi +++ b/mypy/typeshed/stdlib/typing.pyi @@ -649,12 +649,10 @@ class MutableSet(AbstractSet[_T]): def clear(self) -> None: ... def pop(self) -> _T: ... def remove(self, value: _T) -> None: ... - def __ior__( - self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] + def __ior__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] # nopep8 def __iand__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... - def __ixor__( - self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] + def __ixor__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] # nopep8 def __isub__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... @@ -675,18 +673,14 @@ class ItemsView(MappingView, AbstractSet[tuple[_KT_co, _VT_co]], Generic[_KT_co, if sys.version_info >= (3, 8): def __reversed__(self) -> Iterator[tuple[_KT_co, _VT_co]]: ... - def __or__(self, other: Iterable[_T]) -> set[tuple[_KT_co, - _VT_co] | _T]: ... - def __ror__(self, other: Iterable[_T] - ) -> set[tuple[_KT_co, _VT_co] | _T]: ... + def __or__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... # nopep8 + def __ror__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... # nopep8 def __sub__(self, other: Iterable[Any]) -> set[tuple[_KT_co, _VT_co]]: ... def __rsub__(self, other: Iterable[_T]) -> set[_T]: ... - def __xor__(self, other: Iterable[_T] - ) -> set[tuple[_KT_co, _VT_co] | _T]: ... - def __rxor__(self, other: Iterable[_T] - ) -> set[tuple[_KT_co, _VT_co] | _T]: ... + def __xor__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... # nopep8 + def __rxor__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... # nopep8 class KeysView(MappingView, AbstractSet[_KT_co]): @@ -1031,21 +1025,18 @@ class _TypedDict(Mapping[str, object], metaclass=ABCMeta): def values(self) -> dict_values[str, object]: ... if sys.version_info >= (3, 9): @overload - def __or__( - self, __value: typing_extensions.Self) -> typing_extensions.Self: ... + def __or__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # nopep8 @overload def __or__(self, __value: dict[str, Any]) -> dict[str, object]: ... @overload - def __ror__( - self, __value: typing_extensions.Self) -> typing_extensions.Self: ... + def __ror__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # nopep8 @overload def __ror__(self, __value: dict[str, Any]) -> dict[str, object]: ... # supposedly incompatible definitions of __or__ and __ior__ - def __ior__( - self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # type: ignore[misc] + def __ior__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # type: ignore[misc]# nopep8 @_final From 526f43041b1f20fc858c0c747963fc45172f9d58 Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Fri, 1 Dec 2023 21:32:03 +0300 Subject: [PATCH 08/28] Resolved merge conflicts in typinig_extensions.pyi --- mypy/checker.py | 3 +-- mypy/typeshed/stdlib/typing_extensions.pyi | 9 +++++++-- test-data/unit/check-isinstance.test | 14 -------------- 3 files changed, 8 insertions(+), 18 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ad55dd2bd7974..2cf2117c85606 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5841,8 +5841,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: # Set if_map for the entire tuple if_map = {} else_map = None - else: - if_map, else_map = {}, {} + else: if left_index in narrowable_operand_index_to_hash: diff --git a/mypy/typeshed/stdlib/typing_extensions.pyi b/mypy/typeshed/stdlib/typing_extensions.pyi index 57d3e76cae3ce..ac08d9d402d79 100644 --- a/mypy/typeshed/stdlib/typing_extensions.pyi +++ b/mypy/typeshed/stdlib/typing_extensions.pyi @@ -482,8 +482,13 @@ class TypeVarTuple: def __iter__(self) -> Any: ... # Unpack[Self] -def deprecated(__msg: str, *, category: type[Warning] | - None = ..., stacklevel: int = 1) -> Callable[[_T], _T]: ... +class deprecated: + message: str + category: type[Warning] | None + stacklevel: int + def __init__(self, __message: str, *, category: type[Warning] | None = ..., stacklevel: int = 1) -> None: ... # nopep8 + + def __call__(self, __arg: _T) -> _T: ... if sys.version_info >= (3, 12): diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index b7ee38b69d00f..178bbcf1aac78 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -1961,20 +1961,6 @@ if x in nested_any: [builtins fixtures/list.pyi] [out] -[case testNarrowTypeAfterInTuple] -from typing import Optional -class A: pass -class B(A): pass -class C(A): pass - -y: Optional[B] -if y in (B(), C()): - reveal_type(y) # N: Revealed type is "__main__.B" -else: - reveal_type(y) # N: Revealed type is "Union[__main__.B, None]" -[builtins fixtures/tuple.pyi] -[out] - [case testNarrowTypeAfterInNamedTuple] from typing import NamedTuple, Optional class NT(NamedTuple): From 605f66c7fec87742215910dd105f2b3ab6d0c82e Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Fri, 1 Dec 2023 21:35:21 +0300 Subject: [PATCH 09/28] Resolved merge conflicts in typinig_extensions.pyi --- mypy/typeshed/stdlib/typing_extensions.pyi | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mypy/typeshed/stdlib/typing_extensions.pyi b/mypy/typeshed/stdlib/typing_extensions.pyi index ac08d9d402d79..edc1e1a67aa78 100644 --- a/mypy/typeshed/stdlib/typing_extensions.pyi +++ b/mypy/typeshed/stdlib/typing_extensions.pyi @@ -487,8 +487,7 @@ class deprecated: category: type[Warning] | None stacklevel: int def __init__(self, __message: str, *, category: type[Warning] | None = ..., stacklevel: int = 1) -> None: ... # nopep8 - - def __call__(self, __arg: _T) -> _T: ... + def __call__(self, __arg: _T) -> _T: ... # nopep8 if sys.version_info >= (3, 12): From b56ab600e8bd3ea6ab40d2b82d842569cf078ddf Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Fri, 1 Dec 2023 22:33:01 +0300 Subject: [PATCH 10/28] Resolved formatting issues --- mypy/typeshed/stdlib/typing_extensions.pyi | 2 -- pyproject.toml | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mypy/typeshed/stdlib/typing_extensions.pyi b/mypy/typeshed/stdlib/typing_extensions.pyi index edc1e1a67aa78..6ded6518bd806 100644 --- a/mypy/typeshed/stdlib/typing_extensions.pyi +++ b/mypy/typeshed/stdlib/typing_extensions.pyi @@ -481,7 +481,6 @@ class TypeVarTuple: def __init__(self, name: str, *, default: Any | None = None) -> None: ... def __iter__(self) -> Any: ... # Unpack[Self] - class deprecated: message: str category: type[Warning] | None @@ -489,7 +488,6 @@ class deprecated: def __init__(self, __message: str, *, category: type[Warning] | None = ..., stacklevel: int = 1) -> None: ... # nopep8 def __call__(self, __arg: _T) -> _T: ... # nopep8 - if sys.version_info >= (3, 12): from collections.abc import Buffer as Buffer from types import get_original_bases as get_original_bases diff --git a/pyproject.toml b/pyproject.toml index c43253fed9825..a83643e75a4af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,9 @@ ignore = [ "E731", # Do not assign a `lambda` expression, use a `def` "E741", # Ambiguous variable name "UP032", # 'f-string always preferable to format' is controversial + "E301", # Add missing blank line. + "E231", # Add missing whitespace. + "E302", # Add missing 2 blank lines. ] unfixable = [ From 090bc6fd86a652efb6fccc08542aa6c90f9d6abc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Dec 2023 19:49:55 +0000 Subject: [PATCH 11/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 2cf2117c85606..d4835aa3e34f7 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5842,7 +5842,6 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_map = {} else_map = None - else: if left_index in narrowable_operand_index_to_hash: # print("left in") From 2df876c780f1722418c0f27cd8d97321f9369e3e Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Sat, 2 Dec 2023 07:27:12 +0300 Subject: [PATCH 12/28] Modified pyproject.tomll for ruff error --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a83643e75a4af..c43253fed9825 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,9 +51,6 @@ ignore = [ "E731", # Do not assign a `lambda` expression, use a `def` "E741", # Ambiguous variable name "UP032", # 'f-string always preferable to format' is controversial - "E301", # Add missing blank line. - "E231", # Add missing whitespace. - "E302", # Add missing 2 blank lines. ] unfixable = [ From 4a9e582827a6f9dcb3d77bd873fe7dca21bf35ac Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Mon, 4 Dec 2023 09:21:55 +0300 Subject: [PATCH 13/28] Made changes to handling of in for open-source errors --- mypy/checker.py | 83 ++++++++++++++++++++++++------------------------- 1 file changed, 41 insertions(+), 42 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index d4835aa3e34f7..ae63fb7f4a263 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5840,54 +5840,53 @@ def has_no_custom_eq_checks(t: Type) -> bool: if all_literal_enum: # Set if_map for the entire tuple if_map = {} - else_map = None - - else: - if left_index in narrowable_operand_index_to_hash: - # print("left in") - # We only try and narrow away 'None' for now - if is_overlapping_none(item_type): - collection_item_type = get_proper_type( - builtin_item_type(iterable_type) + else_map = None + + if left_index in narrowable_operand_index_to_hash: + # print("left in") + # We only try and narrow away 'None' for now + if is_overlapping_none(item_type): + collection_item_type = get_proper_type( + builtin_item_type(iterable_type) + ) + # print("done5", if_map, else_map) + if ( + collection_item_type is not None + and not is_overlapping_none(collection_item_type) + and not ( + isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == "builtins.object" ) - # print("done5", if_map, else_map) - if ( - collection_item_type is not None - and not is_overlapping_none(collection_item_type) - and not ( - isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == "builtins.object" - ) - and is_overlapping_erased_types( - item_type, collection_item_type - ) - ): - if_map[operands[left_index]] = remove_optional(item_type) + and is_overlapping_erased_types( + item_type, collection_item_type + ) + ): + if_map[operands[left_index]] = remove_optional(item_type) - # print("done6", if_map, else_map) + # print("done6", if_map, else_map) - if right_index in narrowable_operand_index_to_hash: - # print("right in") - if_type, else_type = self.conditional_types_for_iterable( - item_type, iterable_type - ) - expr = operands[right_index] - # print("done7", if_map, else_map) + if right_index in narrowable_operand_index_to_hash: + # print("right in") + if_type, else_type = self.conditional_types_for_iterable( + item_type, iterable_type + ) + expr = operands[right_index] + # print("done7", if_map, else_map) - if if_type is None: - if_map = None - # print("done8", if_map, else_map) - else: - if_map[expr] = if_type - # print("done9", if_map, else_map) + if if_type is None: + if_map = None + # print("done8", if_map, else_map) + else: + if_map[expr] = if_type + # print("done9", if_map, else_map) - if else_type is None: - else_map = None - # print("done11", if_map, else_map) - else: - else_map[expr] = else_type + if else_type is None: + else_map = None + # print("done11", if_map, else_map) + else: + else_map[expr] = else_type - # print("done12", if_map, else_map) + # print("done12", if_map, else_map) else: if_map = {} From 9bd590e405844b4c221de456fb0ef35204119d61 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Dec 2023 06:22:32 +0000 Subject: [PATCH 14/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checker.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ae63fb7f4a263..b7cd6f5a791fb 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5840,7 +5840,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: if all_literal_enum: # Set if_map for the entire tuple if_map = {} - else_map = None + else_map = None if left_index in narrowable_operand_index_to_hash: # print("left in") @@ -5857,9 +5857,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: isinstance(collection_item_type, Instance) and collection_item_type.type.fullname == "builtins.object" ) - and is_overlapping_erased_types( - item_type, collection_item_type - ) + and is_overlapping_erased_types(item_type, collection_item_type) ): if_map[operands[left_index]] = remove_optional(item_type) From 6dd351ca584d7d3492d8233eae74958e30b0fffb Mon Sep 17 00:00:00 2001 From: Malika1109 Date: Mon, 4 Dec 2023 09:26:06 +0300 Subject: [PATCH 15/28] Changed handling of in operator --- mypy/checker.py | 79 +++++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ae63fb7f4a263..3289418c1812a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5842,51 +5842,52 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_map = {} else_map = None - if left_index in narrowable_operand_index_to_hash: - # print("left in") - # We only try and narrow away 'None' for now - if is_overlapping_none(item_type): - collection_item_type = get_proper_type( - builtin_item_type(iterable_type) - ) - # print("done5", if_map, else_map) - if ( - collection_item_type is not None - and not is_overlapping_none(collection_item_type) - and not ( - isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == "builtins.object" - ) - and is_overlapping_erased_types( - item_type, collection_item_type + else: + if left_index in narrowable_operand_index_to_hash: + # print("left in") + # We only try and narrow away 'None' for now + if is_overlapping_none(item_type): + collection_item_type = get_proper_type( + builtin_item_type(iterable_type) ) - ): - if_map[operands[left_index]] = remove_optional(item_type) + # print("done5", if_map, else_map) + if ( + collection_item_type is not None + and not is_overlapping_none(collection_item_type) + and not ( + isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == "builtins.object" + ) + and is_overlapping_erased_types( + item_type, collection_item_type + ) + ): + if_map[operands[left_index]] = remove_optional(item_type) - # print("done6", if_map, else_map) + # print("done6", if_map, else_map) - if right_index in narrowable_operand_index_to_hash: - # print("right in") - if_type, else_type = self.conditional_types_for_iterable( - item_type, iterable_type - ) - expr = operands[right_index] - # print("done7", if_map, else_map) + if right_index in narrowable_operand_index_to_hash: + # print("right in") + if_type, else_type = self.conditional_types_for_iterable( + item_type, iterable_type + ) + expr = operands[right_index] + # print("done7", if_map, else_map) - if if_type is None: - if_map = None - # print("done8", if_map, else_map) - else: - if_map[expr] = if_type - # print("done9", if_map, else_map) + if if_type is None: + if_map = None + # print("done8", if_map, else_map) + else: + if_map[expr] = if_type + # print("done9", if_map, else_map) - if else_type is None: - else_map = None - # print("done11", if_map, else_map) - else: - else_map[expr] = else_type + if else_type is None: + else_map = None + # print("done11", if_map, else_map) + else: + else_map[expr] = else_type - # print("done12", if_map, else_map) + # print("done12", if_map, else_map) else: if_map = {} From 30b6caed2e3f3e1ffff18a8331d02e78d91fba88 Mon Sep 17 00:00:00 2001 From: aisha kh Date: Tue, 5 Dec 2023 08:47:50 +0300 Subject: [PATCH 16/28] successfully created the check-in-exhaustive-checking-3.test file for extra testing --- test-data/unit/check-in-exhaustive-checking-3.test | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test-data/unit/check-in-exhaustive-checking-3.test diff --git a/test-data/unit/check-in-exhaustive-checking-3.test b/test-data/unit/check-in-exhaustive-checking-3.test new file mode 100644 index 0000000000000..e69de29bb2d1d From fc611bcd2c0f19f53ee6eebea1b83fabf48ccd34 Mon Sep 17 00:00:00 2001 From: aisha kh Date: Tue, 5 Dec 2023 08:49:22 +0300 Subject: [PATCH 17/28] removed the 3 in the name of the file for clarity + began the test with following format --- test-data/unit/check-in-exhaustive-checking-3.test | 0 test-data/unit/check-in-exhaustive-checking.test | 1 + 2 files changed, 1 insertion(+) delete mode 100644 test-data/unit/check-in-exhaustive-checking-3.test create mode 100644 test-data/unit/check-in-exhaustive-checking.test diff --git a/test-data/unit/check-in-exhaustive-checking-3.test b/test-data/unit/check-in-exhaustive-checking-3.test deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test-data/unit/check-in-exhaustive-checking.test b/test-data/unit/check-in-exhaustive-checking.test new file mode 100644 index 0000000000000..24aec4df0d85a --- /dev/null +++ b/test-data/unit/check-in-exhaustive-checking.test @@ -0,0 +1 @@ +[case testInExhaustiveChecking] \ No newline at end of file From 4d38b0b94bf6bc3ba5782d6ec72d87b9b00cb130 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Dec 2023 05:49:50 +0000 Subject: [PATCH 18/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test-data/unit/check-in-exhaustive-checking.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/check-in-exhaustive-checking.test b/test-data/unit/check-in-exhaustive-checking.test index 24aec4df0d85a..2b174ff10e2ee 100644 --- a/test-data/unit/check-in-exhaustive-checking.test +++ b/test-data/unit/check-in-exhaustive-checking.test @@ -1 +1 @@ -[case testInExhaustiveChecking] \ No newline at end of file +[case testInExhaustiveChecking] From 4420ed81807898b34d62dcc5785bfb51e1b50b87 Mon Sep 17 00:00:00 2001 From: aisha kh Date: Tue, 5 Dec 2023 09:45:08 +0300 Subject: [PATCH 19/28] testing is done without any comments, want to make sure it passes before merging and commenting --- .../unit/check-in-exhaustive-checking.test | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-in-exhaustive-checking.test b/test-data/unit/check-in-exhaustive-checking.test index 24aec4df0d85a..7449cd66084ca 100644 --- a/test-data/unit/check-in-exhaustive-checking.test +++ b/test-data/unit/check-in-exhaustive-checking.test @@ -1 +1,21 @@ -[case testInExhaustiveChecking] \ No newline at end of file +[case testInExhaustiveChecking] + +[builtins fixtures/tuple.pyi] +from typing_extensions import NoReturn +from enum import Enum +from typing_extensions import assert_never + +class MyEnum(Enum): + A = 1 + B = 2 + C = 3 + +def test_function(a: MyEnum) -> bool: + if a == MyEnum.A: + return True + elif a in (MyEnum.B, MyEnum.C): + return False + assert_never(a) + +test_function(MyEnum.A) +# E: Argument 1 to "assert_never" has incompatible type "Literal[MyEnum.B, MyEnum.C]"; expected "NoReturn" \ No newline at end of file From b4bf17a12f3ac075a5f5b0fb605b94b10e605834 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Dec 2023 06:50:11 +0000 Subject: [PATCH 20/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test-data/unit/check-in-exhaustive-checking.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-in-exhaustive-checking.test b/test-data/unit/check-in-exhaustive-checking.test index 7449cd66084ca..5603ecef10922 100644 --- a/test-data/unit/check-in-exhaustive-checking.test +++ b/test-data/unit/check-in-exhaustive-checking.test @@ -8,7 +8,7 @@ from typing_extensions import assert_never class MyEnum(Enum): A = 1 B = 2 - C = 3 + C = 3 def test_function(a: MyEnum) -> bool: if a == MyEnum.A: @@ -18,4 +18,4 @@ def test_function(a: MyEnum) -> bool: assert_never(a) test_function(MyEnum.A) -# E: Argument 1 to "assert_never" has incompatible type "Literal[MyEnum.B, MyEnum.C]"; expected "NoReturn" \ No newline at end of file +# E: Argument 1 to "assert_never" has incompatible type "Literal[MyEnum.B, MyEnum.C]"; expected "NoReturn" From 55d3b2f5348c1ec100604b99eda4d8b480e74639 Mon Sep 17 00:00:00 2001 From: Sara Al-Saloos Date: Tue, 5 Dec 2023 15:27:40 +0300 Subject: [PATCH 21/28] removed unneeded comments & code --- mypy/checker.py | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index d4835aa3e34f7..7f0fd5076b756 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5800,7 +5800,6 @@ def has_no_custom_eq_checks(t: Type) -> bool: is_valid_target, coerce_only_in_literal_context, ) - # print("done2", if_map, else_map) # Strictly speaking, we should also skip this check if the objects in the expr # chain have custom __eq__ or __ne__ methods. But we (maybe optimistically) @@ -5813,10 +5812,8 @@ def has_no_custom_eq_checks(t: Type) -> bool: expr_indices, narrowable_operand_index_to_hash.keys(), ) - # print("done3", if_map, else_map) - # If we haven't been able to narrow types yet, we might be dealing with a - # explicit type(x) == some_type check + # If we haven't been able to narrow types yet, we might be dealing with a explicit type(x) == some_type check if if_map == {} and else_map == {}: if_map, else_map = self.find_type_equals_check(node, expr_indices) # print("done4", if_map, else_map) @@ -5825,13 +5822,9 @@ def has_no_custom_eq_checks(t: Type) -> bool: left_index, right_index = expr_indices item_type = operand_types[left_index] iterable_type = operand_types[right_index] - # added by me right_expr = operands[right_index] if_map, else_map = {}, {} - # print(left_index, "\n") - # print(right_index, "\n") - # print(narrowable_operand_index_to_hash, "\n") if isinstance(right_expr, TupleExpr): all_literal_enum = all( @@ -5844,13 +5837,11 @@ def has_no_custom_eq_checks(t: Type) -> bool: else: if left_index in narrowable_operand_index_to_hash: - # print("left in") # We only try and narrow away 'None' for now if is_overlapping_none(item_type): collection_item_type = get_proper_type( builtin_item_type(iterable_type) ) - # print("done5", if_map, else_map) if ( collection_item_type is not None and not is_overlapping_none(collection_item_type) @@ -5864,41 +5855,27 @@ def has_no_custom_eq_checks(t: Type) -> bool: ): if_map[operands[left_index]] = remove_optional(item_type) - # print("done6", if_map, else_map) - if right_index in narrowable_operand_index_to_hash: - # print("right in") if_type, else_type = self.conditional_types_for_iterable( item_type, iterable_type ) expr = operands[right_index] - # print("done7", if_map, else_map) if if_type is None: if_map = None - # print("done8", if_map, else_map) else: if_map[expr] = if_type - # print("done9", if_map, else_map) - if else_type is None: else_map = None - # print("done11", if_map, else_map) else: else_map[expr] = else_type - - # print("done12", if_map, else_map) - else: if_map = {} else_map = {} - # print("done13", if_map, else_map) if operator in {"is not", "!=", "not in"}: if_map, else_map = else_map, if_map - # print("done14", if_map, else_map) - partial_type_maps.append((if_map, else_map)) # If we have found non-trivial restrictions from the regular comparisons, From 16ca2960059b0a3230c7ef2f3fd31395eac91467 Mon Sep 17 00:00:00 2001 From: Gulnaz Serikbay <68372186+GulnazSerikbay@users.noreply.github.com> Date: Tue, 5 Dec 2023 10:04:47 -0500 Subject: [PATCH 22/28] Delete feature-test.py for PR not needed for PR, since tests should cover this --- feature-test.py | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 feature-test.py diff --git a/feature-test.py b/feature-test.py deleted file mode 100644 index 4b9de9efd6abc..0000000000000 --- a/feature-test.py +++ /dev/null @@ -1,19 +0,0 @@ -from enum import Enum -from typing_extensions import assert_never - - -class MyEnum(Enum): - A = 1 - B = 2 - C = 3 - - -def my_function(a: MyEnum) -> bool: - if a == MyEnum.A: - return True - elif a in (MyEnum.B, MyEnum.C): - return False - assert_never(a) - - -my_function(MyEnum.A) From e7694a8091129d72f2572c17fb59d01249db29c4 Mon Sep 17 00:00:00 2001 From: Maha Alnassr Date: Wed, 6 Dec 2023 14:24:30 +0300 Subject: [PATCH 23/28] added comments for clarity and minor changes to in exhaustive checking test file --- test-data/unit/check-in-exhaustive-checking.test | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test-data/unit/check-in-exhaustive-checking.test b/test-data/unit/check-in-exhaustive-checking.test index 5603ecef10922..68211dbb5174b 100644 --- a/test-data/unit/check-in-exhaustive-checking.test +++ b/test-data/unit/check-in-exhaustive-checking.test @@ -1,8 +1,14 @@ [case testInExhaustiveChecking] [builtins fixtures/tuple.pyi] + +# Import NoReturn for indicating functions that shouldn't return normally from typing_extensions import NoReturn + +# Import Enum to give names to specific values and link them to unique, constant values from enum import Enum + +# Import assert_never for ensuring all cases in an if-elif chain are handled from typing_extensions import assert_never class MyEnum(Enum): From 64a8ab56af084b7ed527806e75f04c1c4661ac57 Mon Sep 17 00:00:00 2001 From: Gulnaz Serikbay Date: Wed, 6 Dec 2023 22:16:11 +0300 Subject: [PATCH 24/28] added some enum type extraction + test --- mypy/checker.py | 14 +++++++++++++ .../unit/check-in-exhaustive-checking.test | 21 +++++++++++++++---- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 7f0fd5076b756..fcf02d97868c2 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -181,6 +181,7 @@ tuple_fallback, ) from mypy.types import ( + ENUM_REMOVED_PROPS, ANY_STRATEGY, MYPYC_NATIVE_INT_NAMES, OVERLOAD_NAMES, @@ -5831,6 +5832,19 @@ def has_no_custom_eq_checks(t: Type) -> bool: self.is_literal_enum(element_expr) for element_expr in right_expr.items ) if all_literal_enum: + + new_items = [] + if len(right_expr.items) != 0: + # extract the enum type in the context + enum_type = get_proper_type(self.lookup_type_or_none(right_expr.items[0].expr)).fallback + for name, symbol in enum_type.type.names.items(): + if not isinstance(symbol.node, Var): + continue + # Skip these since Enum will remove it + if name in ENUM_REMOVED_PROPS: + continue + new_items.append(LiteralType(name, enum_type)) + # Set if_map for the entire tuple if_map = {} else_map = None diff --git a/test-data/unit/check-in-exhaustive-checking.test b/test-data/unit/check-in-exhaustive-checking.test index 5603ecef10922..e12d27f2d19ef 100644 --- a/test-data/unit/check-in-exhaustive-checking.test +++ b/test-data/unit/check-in-exhaustive-checking.test @@ -1,8 +1,8 @@ [case testInExhaustiveChecking] [builtins fixtures/tuple.pyi] -from typing_extensions import NoReturn from enum import Enum +import pytest from typing_extensions import assert_never class MyEnum(Enum): @@ -10,12 +10,25 @@ class MyEnum(Enum): B = 2 C = 3 -def test_function(a: MyEnum) -> bool: +def my_function(a: MyEnum) -> bool: if a == MyEnum.A: return True elif a in (MyEnum.B, MyEnum.C): return False assert_never(a) -test_function(MyEnum.A) -# E: Argument 1 to "assert_never" has incompatible type "Literal[MyEnum.B, MyEnum.C]"; expected "NoReturn" +class MyEnum2(Enum): + A = 1 + B = 2 + +def my_function2(a: MyEnum) -> bool: + if a in (MyEnum.A, MyEnum.B): + return False + assert_never(a) + +# Test cases +def test_my_function(): + # Test for MyEnum.A + assert my_function(MyEnum.A) == True, "Failed for MyEnum.A" + + assert my_function2(MyEnum2.A) == True, "Failed for MyEnum.A" From 5a060f72493d69d03e1598fa2626223e6446702d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Dec 2023 19:28:14 +0000 Subject: [PATCH 25/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checker.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index fcf02d97868c2..a2d114c8b10c7 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -181,8 +181,8 @@ tuple_fallback, ) from mypy.types import ( - ENUM_REMOVED_PROPS, ANY_STRATEGY, + ENUM_REMOVED_PROPS, MYPYC_NATIVE_INT_NAMES, OVERLOAD_NAMES, AnyType, @@ -5832,11 +5832,12 @@ def has_no_custom_eq_checks(t: Type) -> bool: self.is_literal_enum(element_expr) for element_expr in right_expr.items ) if all_literal_enum: - new_items = [] if len(right_expr.items) != 0: # extract the enum type in the context - enum_type = get_proper_type(self.lookup_type_or_none(right_expr.items[0].expr)).fallback + enum_type = get_proper_type( + self.lookup_type_or_none(right_expr.items[0].expr) + ).fallback for name, symbol in enum_type.type.names.items(): if not isinstance(symbol.node, Var): continue From 814410511ce6a6fd9e00ec123e8254b1bcebc00b Mon Sep 17 00:00:00 2001 From: Gulnaz Serikbay Date: Wed, 6 Dec 2023 22:34:04 +0300 Subject: [PATCH 26/28] small fix --- mypy/checker.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index fcf02d97868c2..17b513b7533d1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5832,7 +5832,10 @@ def has_no_custom_eq_checks(t: Type) -> bool: self.is_literal_enum(element_expr) for element_expr in right_expr.items ) if all_literal_enum: - + # Set if_map for the entire tuple + if_map = {} + else_map = None + else: new_items = [] if len(right_expr.items) != 0: # extract the enum type in the context @@ -5845,10 +5848,6 @@ def has_no_custom_eq_checks(t: Type) -> bool: continue new_items.append(LiteralType(name, enum_type)) - # Set if_map for the entire tuple - if_map = {} - else_map = None - else: if left_index in narrowable_operand_index_to_hash: # We only try and narrow away 'None' for now From 046370e6f0af561a5213d604331d259e4fd41eca Mon Sep 17 00:00:00 2001 From: Gulnaz Serikbay Date: Wed, 6 Dec 2023 22:38:01 +0300 Subject: [PATCH 27/28] removed some code --- mypy/checker.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 50e9a41b0f759..e1840429f8f48 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5835,21 +5835,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: # Set if_map for the entire tuple if_map = {} else_map = None - else: - new_items = [] - if len(right_expr.items) != 0: - # extract the enum type in the context - enum_type = get_proper_type( - self.lookup_type_or_none(right_expr.items[0].expr) - ).fallback - for name, symbol in enum_type.type.names.items(): - if not isinstance(symbol.node, Var): - continue - # Skip these since Enum will remove it - if name in ENUM_REMOVED_PROPS: - continue - new_items.append(LiteralType(name, enum_type)) - + else: if left_index in narrowable_operand_index_to_hash: # We only try and narrow away 'None' for now From 06c8d98b7fc8cbe6207879aac1306d3f300d4185 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Dec 2023 19:38:26 +0000 Subject: [PATCH 28/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checker.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index e1840429f8f48..7f0fd5076b756 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -182,7 +182,6 @@ ) from mypy.types import ( ANY_STRATEGY, - ENUM_REMOVED_PROPS, MYPYC_NATIVE_INT_NAMES, OVERLOAD_NAMES, AnyType, @@ -5835,7 +5834,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: # Set if_map for the entire tuple if_map = {} else_map = None - + else: if left_index in narrowable_operand_index_to_hash: # We only try and narrow away 'None' for now