diff --git a/feature-test.py b/feature-test.py new file mode 100644 index 0000000000000..4b9de9efd6abc --- /dev/null +++ b/feature-test.py @@ -0,0 +1,19 @@ +from enum import Enum +from typing_extensions import assert_never + + +class MyEnum(Enum): + A = 1 + B = 2 + C = 3 + + +def my_function(a: MyEnum) -> bool: + if a == MyEnum.A: + return True + elif a in (MyEnum.B, MyEnum.C): + return False + assert_never(a) + + +my_function(MyEnum.A) diff --git a/mypy/checker.py b/mypy/checker.py index 7c6f59fafdc81..7f0fd5076b756 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1247,7 +1247,8 @@ def check_func_def( if ( arg_type.variance == COVARIANT and defn.name not in ("__init__", "__new__", "__post_init__") - and not is_private(defn.name) # private methods are not inherited + # private methods are not inherited + and not is_private(defn.name) ): ctx: Context = arg_type if ctx.line < 0: @@ -3112,7 +3113,8 @@ def check_compatibility_all_supers( if ( isinstance(lvalue_node, Var) and lvalue.kind in (MDEF, None) - and len(lvalue_node.info.bases) > 0 # None for Vars defined via self + # None for Vars defined via self + and len(lvalue_node.info.bases) > 0 ): for base in lvalue_node.info.mro[1:]: tnode = base.names.get(lvalue_node.name) @@ -5708,6 +5710,7 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM operands = [collapse_walrus(x) for x in node.operands] operand_types = [] narrowable_operand_index_to_hash = {} + # print(operands) for i, expr in enumerate(operands): if not self.has_type(expr): return {}, {} @@ -5810,49 +5813,62 @@ def has_no_custom_eq_checks(t: Type) -> bool: narrowable_operand_index_to_hash.keys(), ) - # If we haven't been able to narrow types yet, we might be dealing with a - # explicit type(x) == some_type check + # If we haven't been able to narrow types yet, we might be dealing with a explicit type(x) == some_type check if if_map == {} and else_map == {}: if_map, else_map = self.find_type_equals_check(node, expr_indices) + # print("done4", if_map, else_map) elif operator in {"in", "not in"}: assert len(expr_indices) == 2 left_index, right_index = expr_indices item_type = operand_types[left_index] iterable_type = operand_types[right_index] + right_expr = operands[right_index] if_map, else_map = {}, {} - if left_index in narrowable_operand_index_to_hash: - # We only try and narrow away 'None' for now - if is_overlapping_none(item_type): - collection_item_type = get_proper_type( - builtin_item_type(iterable_type) - ) - if ( - collection_item_type is not None - and not is_overlapping_none(collection_item_type) - and not ( - isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == "builtins.object" - ) - and is_overlapping_erased_types(item_type, collection_item_type) - ): - if_map[operands[left_index]] = remove_optional(item_type) - - if right_index in narrowable_operand_index_to_hash: - if_type, else_type = self.conditional_types_for_iterable( - item_type, iterable_type + if isinstance(right_expr, TupleExpr): + all_literal_enum = all( + self.is_literal_enum(element_expr) for element_expr in right_expr.items ) - expr = operands[right_index] - if if_type is None: - if_map = None - else: - if_map[expr] = if_type - if else_type is None: + if all_literal_enum: + # Set if_map for the entire tuple + if_map = {} else_map = None - else: - else_map[expr] = else_type + else: + if left_index in narrowable_operand_index_to_hash: + # We only try and narrow away 'None' for now + if is_overlapping_none(item_type): + collection_item_type = get_proper_type( + builtin_item_type(iterable_type) + ) + if ( + collection_item_type is not None + and not is_overlapping_none(collection_item_type) + and not ( + isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == "builtins.object" + ) + and is_overlapping_erased_types( + item_type, collection_item_type + ) + ): + if_map[operands[left_index]] = remove_optional(item_type) + + if right_index in narrowable_operand_index_to_hash: + if_type, else_type = self.conditional_types_for_iterable( + item_type, iterable_type + ) + expr = operands[right_index] + + if if_type is None: + if_map = None + else: + if_map[expr] = if_type + if else_type is None: + else_map = None + else: + else_map[expr] = else_type else: if_map = {} else_map = {} @@ -6151,6 +6167,7 @@ def refine_identity_comparison_expression( expressions in the chain to a Literal type. Performing this coercion is sometimes too aggressive of a narrowing, depending on context. """ + should_coerce = True if coerce_only_in_literal_context: @@ -6246,9 +6263,6 @@ def should_coerce_inner(typ: Type) -> bool: if sum_type_name is not None: expr_type = try_expanding_sum_type_to_union(expr_type, sum_type_name) - # We intentionally use 'conditional_types' directly here instead of - # 'self.conditional_types_with_intersection': we only compute ad-hoc - # intersections when working with pure instances. types = conditional_types(expr_type, target_type) partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) @@ -7196,7 +7210,10 @@ class Foo(Enum): unit for the same reasons we sometimes treat 'True', 'False', or 'None' as a single primitive unit. """ + if not isinstance(n, MemberExpr) or not isinstance(n.expr, NameExpr): + # print(n, isinstance(n, MemberExpr)) + # print(n, isinstance(n.expr, NameExpr)) return False parent_type = self.lookup_type_or_none(n.expr) @@ -7502,7 +7519,8 @@ def builtin_item_type(tp: Type) -> Type | None: else: normalized_items.append(it) if all(not isinstance(it, AnyType) for it in get_proper_types(normalized_items)): - return make_simplified_union(normalized_items) # this type is not externally visible + # this type is not externally visible + return make_simplified_union(normalized_items) elif isinstance(tp, TypedDictType): # TypedDict always has non-optional string keys. Find the key type from the Mapping # base class. diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 626584bc3a201..82f5b82130afa 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3473,7 +3473,8 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # a "Need type annotation ..." message, as it would be noise. right_type = self.find_partial_type_ref_fast_path(right) if right_type is None: - right_type = self.accept(right) # Validate the right operand + # Validate the right operand + right_type = self.accept(right) right_type = get_proper_type(right_type) item_types: Sequence[Type] = [right_type] @@ -6094,7 +6095,8 @@ def __init__(self, ignore_in_type_obj: bool) -> None: self.ignore_in_type_obj = ignore_in_type_obj def visit_any(self, t: AnyType) -> bool: - return t.type_of_any != TypeOfAny.special_form # special forms are not real Any types + # special forms are not real Any types + return t.type_of_any != TypeOfAny.special_form def visit_callable_type(self, t: CallableType) -> bool: if self.ignore_in_type_obj and t.is_type_obj(): diff --git a/mypy/typeshed/stdlib/typing.pyi b/mypy/typeshed/stdlib/typing.pyi index 555df0ea47c88..381e1bc67abf4 100644 --- a/mypy/typeshed/stdlib/typing.pyi +++ b/mypy/typeshed/stdlib/typing.pyi @@ -107,7 +107,8 @@ if sys.version_info >= (3, 9): __all__ += ["Annotated", "BinaryIO", "IO", "Match", "Pattern", "TextIO"] if sys.version_info >= (3, 10): - __all__ += ["Concatenate", "ParamSpec", "ParamSpecArgs", "ParamSpecKwargs", "TypeAlias", "TypeGuard", "is_typeddict"] + __all__ += ["Concatenate", "ParamSpec", "ParamSpecArgs", + "ParamSpecKwargs", "TypeAlias", "TypeGuard", "is_typeddict"] if sys.version_info >= (3, 11): __all__ += [ @@ -133,10 +134,14 @@ ContextManager = AbstractContextManager AsyncContextManager = AbstractAsyncContextManager # This itself is only available during type checking + + def type_check_only(func_or_cls: _F) -> _F: ... + Any = object() + @_final class TypeVar: @property @@ -152,6 +157,7 @@ class TypeVar: if sys.version_info >= (3, 12): @property def __infer_variance__(self) -> bool: ... + def __init__( self, name: str, @@ -171,10 +177,13 @@ class TypeVar: if sys.version_info >= (3, 11): def __typing_subst__(self, arg: Incomplete) -> Incomplete: ... + # Used for an undocumented mypy feature. Does not exist at runtime. _promote = object() # N.B. Keep this definition in sync with typing_extensions._SpecialForm + + @_final class _SpecialForm: def __getitem__(self, parameters: Any) -> object: ... @@ -182,12 +191,15 @@ class _SpecialForm: def __or__(self, other: Any) -> _SpecialForm: ... def __ror__(self, other: Any) -> _SpecialForm: ... + _F = TypeVar("_F", bound=Callable[..., Any]) _P = _ParamSpec("_P") _T = TypeVar("_T") + def overload(func: _F) -> _F: ... + Union: _SpecialForm Generic: _SpecialForm # Protocol is only present in 3.8 and later, but mypy needs it unconditionally @@ -221,7 +233,8 @@ if sys.version_info >= (3, 11): def __init__(self, name: str) -> None: ... def __iter__(self) -> Any: ... def __typing_subst__(self, arg: Never) -> Never: ... - def __typing_prepare_subst__(self, alias: Incomplete, args: Incomplete) -> Incomplete: ... + def __typing_prepare_subst__( + self, alias: Incomplete, args: Incomplete) -> Incomplete: ... if sys.version_info >= (3, 10): @_final @@ -251,6 +264,7 @@ if sys.version_info >= (3, 10): if sys.version_info >= (3, 12): @property def __infer_variance__(self) -> bool: ... + def __init__( self, name: str, @@ -271,7 +285,8 @@ if sys.version_info >= (3, 10): def kwargs(self) -> ParamSpecKwargs: ... if sys.version_info >= (3, 11): def __typing_subst__(self, arg: Incomplete) -> Incomplete: ... - def __typing_prepare_subst__(self, alias: Incomplete, args: Incomplete) -> Incomplete: ... + def __typing_prepare_subst__( + self, alias: Incomplete, args: Incomplete) -> Incomplete: ... def __or__(self, right: Any) -> _SpecialForm: ... def __ror__(self, left: Any) -> _SpecialForm: ... @@ -303,15 +318,21 @@ _KT_co = TypeVar("_KT_co", covariant=True) # Key type covariant containers. _VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers. _TC = TypeVar("_TC", bound=Type[object]) + def no_type_check(arg: _F) -> _F: ... -def no_type_check_decorator(decorator: Callable[_P, _T]) -> Callable[_P, _T]: ... + + +def no_type_check_decorator( + decorator: Callable[_P, _T]) -> Callable[_P, _T]: ... # Type aliases and type constructors + class _Alias: # Class for defining generic aliases for library types. def __getitem__(self, typeargs: Any) -> Any: ... + List = _Alias() Dict = _Alias() DefaultDict = _Alias() @@ -331,44 +352,55 @@ AnyStr = TypeVar("AnyStr", str, bytes) # noqa: Y001 # Technically in 3.7 this inherited from GenericMeta. But let's not reflect that, since # type checkers tend to assume that Protocols all have the ABCMeta metaclass. + + class _ProtocolMeta(ABCMeta): if sys.version_info >= (3, 12): def __init__(cls, *args: Any, **kwargs: Any) -> None: ... # Abstract base classes. + def runtime_checkable(cls: _TC) -> _TC: ... + + @runtime_checkable class SupportsInt(Protocol, metaclass=ABCMeta): @abstractmethod def __int__(self) -> int: ... + @runtime_checkable class SupportsFloat(Protocol, metaclass=ABCMeta): @abstractmethod def __float__(self) -> float: ... + @runtime_checkable class SupportsComplex(Protocol, metaclass=ABCMeta): @abstractmethod def __complex__(self) -> complex: ... + @runtime_checkable class SupportsBytes(Protocol, metaclass=ABCMeta): @abstractmethod def __bytes__(self) -> bytes: ... + if sys.version_info >= (3, 8): @runtime_checkable class SupportsIndex(Protocol, metaclass=ABCMeta): @abstractmethod def __index__(self) -> int: ... + @runtime_checkable class SupportsAbs(Protocol[_T_co]): @abstractmethod def __abs__(self) -> _T_co: ... + @runtime_checkable class SupportsRound(Protocol[_T_co]): @overload @@ -378,11 +410,13 @@ class SupportsRound(Protocol[_T_co]): @abstractmethod def __round__(self, __ndigits: int) -> _T_co: ... + @runtime_checkable class Sized(Protocol, metaclass=ABCMeta): @abstractmethod def __len__(self) -> int: ... + @runtime_checkable class Hashable(Protocol, metaclass=ABCMeta): # TODO: This is special, in that a subclass of a hashable class may not be hashable @@ -391,40 +425,51 @@ class Hashable(Protocol, metaclass=ABCMeta): @abstractmethod def __hash__(self) -> int: ... + @runtime_checkable class Iterable(Protocol[_T_co]): @abstractmethod def __iter__(self) -> Iterator[_T_co]: ... + @runtime_checkable class Iterator(Iterable[_T_co], Protocol[_T_co]): @abstractmethod def __next__(self) -> _T_co: ... def __iter__(self) -> Iterator[_T_co]: ... + @runtime_checkable class Reversible(Iterable[_T_co], Protocol[_T_co]): @abstractmethod def __reversed__(self) -> Iterator[_T_co]: ... + _YieldT_co = TypeVar("_YieldT_co", covariant=True) _SendT_contra = TypeVar("_SendT_contra", contravariant=True) _ReturnT_co = TypeVar("_ReturnT_co", covariant=True) + class Generator(Iterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra, _ReturnT_co]): def __next__(self) -> _YieldT_co: ... @abstractmethod def send(self, __value: _SendT_contra) -> _YieldT_co: ... + @overload @abstractmethod def throw( self, __typ: Type[BaseException], __val: BaseException | object = None, __tb: TracebackType | None = None ) -> _YieldT_co: ... + @overload @abstractmethod - def throw(self, __typ: BaseException, __val: None = None, __tb: TracebackType | None = None) -> _YieldT_co: ... + def throw(self, __typ: BaseException, __val: None = None, + __tb: TracebackType | None = None) -> _YieldT_co: ... + def close(self) -> None: ... - def __iter__(self) -> Generator[_YieldT_co, _SendT_contra, _ReturnT_co]: ... + def __iter__(self) -> Generator[_YieldT_co, + _SendT_contra, _ReturnT_co]: ... + @property def gi_code(self) -> CodeType: ... @property @@ -434,11 +479,13 @@ class Generator(Iterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra, _Return @property def gi_yieldfrom(self) -> Generator[Any, Any, Any] | None: ... + @runtime_checkable class Awaitable(Protocol[_T_co]): @abstractmethod def __await__(self) -> Generator[Any, None, _T_co]: ... + class Coroutine(Awaitable[_ReturnT_co], Generic[_YieldT_co, _SendT_contra, _ReturnT_co]): __name__: str __qualname__: str @@ -452,50 +499,64 @@ class Coroutine(Awaitable[_ReturnT_co], Generic[_YieldT_co, _SendT_contra, _Retu def cr_running(self) -> bool: ... @abstractmethod def send(self, __value: _SendT_contra) -> _YieldT_co: ... + @overload @abstractmethod def throw( self, __typ: Type[BaseException], __val: BaseException | object = None, __tb: TracebackType | None = None ) -> _YieldT_co: ... + @overload @abstractmethod - def throw(self, __typ: BaseException, __val: None = None, __tb: TracebackType | None = None) -> _YieldT_co: ... + def throw(self, __typ: BaseException, __val: None = None, + __tb: TracebackType | None = None) -> _YieldT_co: ... + @abstractmethod def close(self) -> None: ... # NOTE: This type does not exist in typing.py or PEP 484 but mypy needs it to exist. # The parameters correspond to Generator, but the 4th is the original type. + + @type_check_only class AwaitableGenerator( Awaitable[_ReturnT_co], Generator[_YieldT_co, _SendT_contra, _ReturnT_co], Generic[_YieldT_co, _SendT_contra, _ReturnT_co, _S], metaclass=ABCMeta, -): ... +): + ... + @runtime_checkable class AsyncIterable(Protocol[_T_co]): @abstractmethod def __aiter__(self) -> AsyncIterator[_T_co]: ... + @runtime_checkable class AsyncIterator(AsyncIterable[_T_co], Protocol[_T_co]): @abstractmethod def __anext__(self) -> Awaitable[_T_co]: ... def __aiter__(self) -> AsyncIterator[_T_co]: ... + class AsyncGenerator(AsyncIterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra]): def __anext__(self) -> Awaitable[_YieldT_co]: ... @abstractmethod def asend(self, __value: _SendT_contra) -> Awaitable[_YieldT_co]: ... + @overload @abstractmethod def athrow( self, __typ: Type[BaseException], __val: BaseException | object = None, __tb: TracebackType | None = None ) -> Awaitable[_YieldT_co]: ... + @overload @abstractmethod - def athrow(self, __typ: BaseException, __val: None = None, __tb: TracebackType | None = None) -> Awaitable[_YieldT_co]: ... + def athrow(self, __typ: BaseException, __val: None = None, + __tb: TracebackType | None = None) -> Awaitable[_YieldT_co]: ... + def aclose(self) -> Awaitable[None]: ... @property def ag_await(self) -> Any: ... @@ -506,18 +567,21 @@ class AsyncGenerator(AsyncIterator[_YieldT_co], Generic[_YieldT_co, _SendT_contr @property def ag_running(self) -> bool: ... + @runtime_checkable class Container(Protocol[_T_co]): # This is generic more on vibes than anything else @abstractmethod def __contains__(self, __x: object) -> bool: ... + @runtime_checkable class Collection(Iterable[_T_co], Container[_T_co], Protocol[_T_co]): # Implement Sized (but don't have it as a base class). @abstractmethod def __len__(self) -> int: ... + class Sequence(Collection[_T_co], Reversible[_T_co]): @overload @abstractmethod @@ -532,6 +596,7 @@ class Sequence(Collection[_T_co], Reversible[_T_co]): def __iter__(self) -> Iterator[_T_co]: ... def __reversed__(self) -> Iterator[_T_co]: ... + class MutableSequence(Sequence[_T]): @abstractmethod def insert(self, index: int, value: _T) -> None: ... @@ -562,6 +627,7 @@ class MutableSequence(Sequence[_T]): def remove(self, value: _T) -> None: ... def __iadd__(self, values: Iterable[_T]) -> typing_extensions.Self: ... + class AbstractSet(Collection[_T_co]): @abstractmethod def __contains__(self, x: object) -> bool: ... @@ -578,6 +644,7 @@ class AbstractSet(Collection[_T_co]): def __eq__(self, other: object) -> bool: ... def isdisjoint(self, other: Iterable[Any]) -> bool: ... + class MutableSet(AbstractSet[_T]): @abstractmethod def add(self, value: _T) -> None: ... @@ -587,17 +654,23 @@ class MutableSet(AbstractSet[_T]): def clear(self) -> None: ... def pop(self) -> _T: ... def remove(self, value: _T) -> None: ... - def __ior__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] + def __ior__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] # nopep8 + def __iand__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... - def __ixor__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] + def __ixor__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] # nopep8 + def __isub__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... + class MappingView(Sized): def __init__(self, mapping: Mapping[Any, Any]) -> None: ... # undocumented def __len__(self) -> int: ... + class ItemsView(MappingView, AbstractSet[tuple[_KT_co, _VT_co]], Generic[_KT_co, _VT_co]): - def __init__(self, mapping: Mapping[_KT_co, _VT_co]) -> None: ... # undocumented + def __init__( + self, mapping: Mapping[_KT_co, _VT_co]) -> None: ... # undocumented + def __and__(self, other: Iterable[Any]) -> set[tuple[_KT_co, _VT_co]]: ... def __rand__(self, other: Iterable[_T]) -> set[_T]: ... def __contains__(self, item: object) -> bool: ... @@ -605,15 +678,20 @@ class ItemsView(MappingView, AbstractSet[tuple[_KT_co, _VT_co]], Generic[_KT_co, if sys.version_info >= (3, 8): def __reversed__(self) -> Iterator[tuple[_KT_co, _VT_co]]: ... - def __or__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... - def __ror__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... + def __or__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... # nopep8 + def __ror__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... # nopep8 + def __sub__(self, other: Iterable[Any]) -> set[tuple[_KT_co, _VT_co]]: ... def __rsub__(self, other: Iterable[_T]) -> set[_T]: ... - def __xor__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... - def __rxor__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... + + def __xor__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... # nopep8 + def __rxor__(self, other: Iterable[_T]) -> set[tuple[_KT_co, _VT_co] | _T]: ... # nopep8 + class KeysView(MappingView, AbstractSet[_KT_co]): - def __init__(self, mapping: Mapping[_KT_co, Any]) -> None: ... # undocumented + def __init__( + self, mapping: Mapping[_KT_co, Any]) -> None: ... # undocumented + def __and__(self, other: Iterable[Any]) -> set[_KT_co]: ... def __rand__(self, other: Iterable[_T]) -> set[_T]: ... def __contains__(self, key: object) -> bool: ... @@ -628,13 +706,17 @@ class KeysView(MappingView, AbstractSet[_KT_co]): def __xor__(self, other: Iterable[_T]) -> set[_KT_co | _T]: ... def __rxor__(self, other: Iterable[_T]) -> set[_KT_co | _T]: ... + class ValuesView(MappingView, Collection[_VT_co]): - def __init__(self, mapping: Mapping[Any, _VT_co]) -> None: ... # undocumented + def __init__(self, mapping: Mapping[Any, + _VT_co]) -> None: ... # undocumented + def __contains__(self, value: object) -> bool: ... def __iter__(self) -> Iterator[_VT_co]: ... if sys.version_info >= (3, 8): def __reversed__(self) -> Iterator[_VT_co]: ... + class Mapping(Collection[_KT], Generic[_KT, _VT_co]): # TODO: We wish the key type could also be covariant, but that doesn't work, # see discussion in https://github.com/python/typing/pull/273. @@ -651,6 +733,7 @@ class Mapping(Collection[_KT], Generic[_KT, _VT_co]): def __contains__(self, __key: object) -> bool: ... def __eq__(self, __other: object) -> bool: ... + class MutableMapping(Mapping[_KT, _VT]): @abstractmethod def __setitem__(self, __key: _KT, __value: _VT) -> None: ... @@ -670,8 +753,11 @@ class MutableMapping(Mapping[_KT, _VT]): # -- collections.OrderedDict.setdefault # -- collections.ChainMap.setdefault # -- weakref.WeakKeyDictionary.setdefault + @overload - def setdefault(self: MutableMapping[_KT, _T | None], __key: _KT, __default: None = None) -> _T | None: ... + def setdefault(self: MutableMapping[_KT, _T | None], + __key: _KT, __default: None = None) -> _T | None: ... + @overload def setdefault(self, __key: _KT, __default: _VT) -> _VT: ... # 'update' used to take a Union, but using overloading is better. @@ -694,13 +780,19 @@ class MutableMapping(Mapping[_KT, _VT]): # -- peewee.attrdict.__iadd__ # -- weakref.WeakValueDictionary.__ior__ # -- weakref.WeakKeyDictionary.__ior__ + @overload - def update(self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT) -> None: ... + def update( + self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT) -> None: ... + @overload - def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None: ... + def update(self, __m: Iterable[tuple[_KT, _VT]], + **kwargs: _VT) -> None: ... + @overload def update(self, **kwargs: _VT) -> None: ... + Text = str TYPE_CHECKING: bool @@ -708,6 +800,8 @@ TYPE_CHECKING: bool # In stubs, the arguments of the IO class are marked as positional-only. # This differs from runtime, but better reflects the fact that in reality # classes deriving from IO use different names for the arguments. + + class IO(Iterator[AnyStr]): # At runtime these are all abstract properties, # but making them abstract in the stub is hugely disruptive, for not much gain. @@ -758,9 +852,12 @@ class IO(Iterator[AnyStr]): @abstractmethod @overload def writelines(self: IO[str], __lines: Iterable[str]) -> None: ... + @abstractmethod @overload - def writelines(self: IO[bytes], __lines: Iterable[ReadableBuffer]) -> None: ... + def writelines(self: IO[bytes], + __lines: Iterable[ReadableBuffer]) -> None: ... + @abstractmethod @overload def writelines(self, __lines: Iterable[AnyStr]) -> None: ... @@ -770,15 +867,18 @@ class IO(Iterator[AnyStr]): def __iter__(self) -> Iterator[AnyStr]: ... @abstractmethod def __enter__(self) -> IO[AnyStr]: ... + @abstractmethod def __exit__( self, __type: Type[BaseException] | None, __value: BaseException | None, __traceback: TracebackType | None ) -> None: ... + class BinaryIO(IO[bytes]): @abstractmethod def __enter__(self) -> BinaryIO: ... + class TextIO(IO[str]): # See comment regarding the @properties in the `IO` class @property @@ -794,6 +894,7 @@ class TextIO(IO[str]): @abstractmethod def __enter__(self) -> TextIO: ... + ByteString: typing_extensions.TypeAlias = bytes | bytearray | memoryview # Functions @@ -839,6 +940,7 @@ if sys.version_info >= (3, 8): else: def get_origin(tp: Any) -> Any | None: ... + @overload def cast(typ: Type[_T], val: Any) -> _T: ... @overload @@ -846,12 +948,16 @@ def cast(typ: str, val: Any) -> Any: ... @overload def cast(typ: object, val: Any) -> Any: ... + if sys.version_info >= (3, 11): def reveal_type(__obj: _T) -> _T: ... def assert_never(__arg: Never) -> Never: ... def assert_type(__val: _T, __typ: Any) -> _T: ... def clear_overloads() -> None: ... - def get_overloads(func: Callable[..., object]) -> Sequence[Callable[..., object]]: ... + + def get_overloads(func: Callable[..., object] + ) -> Sequence[Callable[..., object]]: ... + def dataclass_transform( *, eq_default: bool = True, @@ -864,6 +970,7 @@ if sys.version_info >= (3, 11): # Type constructors + class NamedTuple(tuple[Any, ...]): if sys.version_info < (3, 8): _field_types: ClassVar[collections.OrderedDict[str, type]] @@ -875,10 +982,15 @@ class NamedTuple(tuple[Any, ...]): # So we only add it to the stub on 3.12+. if sys.version_info >= (3, 12): __orig_bases__: ClassVar[tuple[Any, ...]] + @overload - def __init__(self, __typename: str, __fields: Iterable[tuple[str, Any]]) -> None: ... + def __init__(self, __typename: str, + __fields: Iterable[tuple[str, Any]]) -> None: ... + @overload - def __init__(self, __typename: str, __fields: None = None, **kwargs: Any) -> None: ... + def __init__(self, __typename: str, __fields: None = None, + **kwargs: Any) -> None: ... + @classmethod def _make(cls, iterable: Iterable[Any]) -> typing_extensions.Self: ... if sys.version_info >= (3, 8): @@ -890,6 +1002,8 @@ class NamedTuple(tuple[Any, ...]): # Internal mypy fallback type for all typed dicts (does not exist at runtime) # N.B. Keep this mostly in sync with typing_extensions._TypedDict/mypy_extensions._TypedDict + + @type_check_only class _TypedDict(Mapping[str, object], metaclass=ABCMeta): __total__: ClassVar[bool] @@ -900,12 +1014,15 @@ class _TypedDict(Mapping[str, object], metaclass=ABCMeta): # so we only add it to the stub on 3.12+ if sys.version_info >= (3, 12): __orig_bases__: ClassVar[tuple[Any, ...]] + def copy(self) -> typing_extensions.Self: ... # Using Never so that only calls using mypy plugin hook that specialize the signature # can go through. def setdefault(self, k: _Never, default: object) -> object: ... # Mypy plugin hook for 'pop' expects that 'default' has a type variable type. - def pop(self, k: _Never, default: _T = ...) -> object: ... # pyright: ignore[reportInvalidTypeVarUse] + def pop(self, k: _Never, + default: _T = ...) -> object: ... # pyright: ignore[reportInvalidTypeVarUse] + def update(self: _T, __m: _T) -> None: ... def __delitem__(self, k: _Never) -> None: ... def items(self) -> dict_items[str, object]: ... @@ -913,15 +1030,19 @@ class _TypedDict(Mapping[str, object], metaclass=ABCMeta): def values(self) -> dict_values[str, object]: ... if sys.version_info >= (3, 9): @overload - def __or__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... + def __or__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # nopep8 + @overload def __or__(self, __value: dict[str, Any]) -> dict[str, object]: ... + @overload - def __ror__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... + def __ror__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # nopep8 + @overload def __ror__(self, __value: dict[str, Any]) -> dict[str, object]: ... # supposedly incompatible definitions of __or__ and __ior__ - def __ior__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # type: ignore[misc] + def __ior__(self, __value: typing_extensions.Self) -> typing_extensions.Self: ... # type: ignore[misc]# nopep8 + @_final class ForwardRef: @@ -934,7 +1055,8 @@ class ForwardRef: __forward_module__: Any | None if sys.version_info >= (3, 9): # The module and is_class arguments were added in later Python 3.9 versions. - def __init__(self, arg: str, is_argument: bool = True, module: Any | None = None, *, is_class: bool = False) -> None: ... + def __init__(self, arg: str, is_argument: bool = True, module: Any | + None = None, *, is_class: bool = False) -> None: ... else: def __init__(self, arg: str, is_argument: bool = True) -> None: ... @@ -943,7 +1065,8 @@ class ForwardRef: self, globalns: dict[str, Any] | None, localns: dict[str, Any] | None, recursive_guard: frozenset[str] ) -> Any | None: ... else: - def _evaluate(self, globalns: dict[str, Any] | None, localns: dict[str, Any] | None) -> Any | None: ... + def _evaluate(self, globalns: dict[str, Any] | None, + localns: dict[str, Any] | None) -> Any | None: ... def __eq__(self, other: object) -> bool: ... def __hash__(self) -> int: ... @@ -951,13 +1074,17 @@ class ForwardRef: def __or__(self, other: Any) -> _SpecialForm: ... def __ror__(self, other: Any) -> _SpecialForm: ... + if sys.version_info >= (3, 10): def is_typeddict(tp: object) -> bool: ... + def _type_repr(obj: object) -> str: ... + if sys.version_info >= (3, 12): def override(__method: _F) -> _F: ... + @_final class TypeAliasType: def __init__( @@ -965,8 +1092,11 @@ if sys.version_info >= (3, 12): ) -> None: ... @property def __value__(self) -> Any: ... + @property - def __type_params__(self) -> tuple[TypeVar | ParamSpec | TypeVarTuple, ...]: ... + def __type_params__(self) -> tuple[TypeVar | + ParamSpec | TypeVarTuple, ...]: ... + @property def __parameters__(self) -> tuple[Any, ...]: ... @property diff --git a/mypy/typeshed/stdlib/typing_extensions.pyi b/mypy/typeshed/stdlib/typing_extensions.pyi index 5c5b756f5256d..5cdedb379ad36 100644 --- a/mypy/typeshed/stdlib/typing_extensions.pyi +++ b/mypy/typeshed/stdlib/typing_extensions.pyi @@ -190,12 +190,15 @@ _F = typing.TypeVar("_F", bound=Callable[..., Any]) _TC = typing.TypeVar("_TC", bound=type[object]) # unfortunately we have to duplicate this class definition from typing.pyi or we break pytype + + class _SpecialForm: def __getitem__(self, parameters: Any) -> object: ... if sys.version_info >= (3, 10): def __or__(self, other: Any) -> _SpecialForm: ... def __ror__(self, other: Any) -> _SpecialForm: ... + # Do not import (and re-export) Protocol or runtime_checkable from # typing module because type checkers need to be able to distinguish # typing.Protocol and typing_extensions.Protocol so they can properly @@ -203,20 +206,27 @@ class _SpecialForm: # on older versions of Python. Protocol: _SpecialForm + def runtime_checkable(cls: _TC) -> _TC: ... + # This alias for above is kept here for backwards compatibility. runtime = runtime_checkable Final: _SpecialForm + def final(f: _F) -> _F: ... + Literal: _SpecialForm + def IntVar(name: str) -> Any: ... # returns a new TypeVar # Internal mypy fallback type for all typed dicts (does not exist at runtime) # N.B. Keep this mostly in sync with typing._TypedDict/mypy_extensions._TypedDict + + @type_check_only class _TypedDict(Mapping[str, object], metaclass=abc.ABCMeta): __required_keys__: ClassVar[frozenset[str]] @@ -230,7 +240,9 @@ class _TypedDict(Mapping[str, object], metaclass=abc.ABCMeta): # can go through. def setdefault(self, k: Never, default: object) -> object: ... # Mypy plugin hook for 'pop' expects that 'default' has a type variable type. - def pop(self, k: Never, default: _T = ...) -> object: ... # pyright: ignore[reportInvalidTypeVarUse] + def pop(self, k: Never, + default: _T = ...) -> object: ... # pyright: ignore[reportInvalidTypeVarUse] + def update(self: _T, __m: _T) -> None: ... def items(self) -> dict_items[str, object]: ... def keys(self) -> dict_keys[str, object]: ... @@ -248,11 +260,13 @@ class _TypedDict(Mapping[str, object], metaclass=abc.ABCMeta): # supposedly incompatible definitions of `__ior__` and `__or__`: def __ior__(self, __value: Self) -> Self: ... # type: ignore[misc] + # TypedDict is a (non-subscriptable) special form. TypedDict: object OrderedDict = _Alias() + def get_type_hints( obj: Callable[..., Any], globalns: dict[str, Any] | None = None, @@ -261,6 +275,7 @@ def get_type_hints( ) -> dict[str, Any]: ... def get_args(tp: Any) -> tuple[Any, ...]: ... + if sys.version_info >= (3, 10): @overload def get_origin(tp: UnionType) -> type[UnionType]: ... @@ -269,19 +284,23 @@ if sys.version_info >= (3, 9): @overload def get_origin(tp: GenericAlias) -> type: ... + @overload def get_origin(tp: ParamSpecArgs | ParamSpecKwargs) -> ParamSpec: ... @overload def get_origin(tp: Any) -> Any | None: ... + Annotated: _SpecialForm _AnnotatedAlias: Any # undocumented + @runtime_checkable class SupportsIndex(Protocol, metaclass=abc.ABCMeta): @abc.abstractmethod def __index__(self) -> int: ... + # New and changed things in 3.10 if sys.version_info >= (3, 10): from typing import ( @@ -335,7 +354,9 @@ else: def assert_never(__arg: Never) -> Never: ... def assert_type(__val: _T, __typ: Any) -> _T: ... def clear_overloads() -> None: ... - def get_overloads(func: Callable[..., object]) -> Sequence[Callable[..., object]]: ... + + def get_overloads(func: Callable[..., object] + ) -> Sequence[Callable[..., object]]: ... Required: _SpecialForm NotRequired: _SpecialForm @@ -360,10 +381,15 @@ else: _field_defaults: ClassVar[dict[str, Any]] _fields: ClassVar[tuple[str, ...]] __orig_bases__: ClassVar[tuple[Any, ...]] + @overload - def __init__(self, typename: str, fields: Iterable[tuple[str, Any]] = ...) -> None: ... + def __init__(self, typename: str, + fields: Iterable[tuple[str, Any]] = ...) -> None: ... + @overload - def __init__(self, typename: str, fields: None = None, **kwargs: Any) -> None: ... + def __init__(self, typename: str, fields: None = None, + **kwargs: Any) -> None: ... + @classmethod def _make(cls, iterable: Iterable[Any]) -> Self: ... if sys.version_info >= (3, 8): @@ -385,6 +411,8 @@ else: # The `default` parameter was added to TypeVar, ParamSpec, and TypeVarTuple (PEP 696) # The `infer_variance` parameter was added to TypeVar in 3.12 (PEP 695) # typing_extensions.override (PEP 698) + + @final class TypeVar: @property @@ -401,6 +429,7 @@ class TypeVar: def __infer_variance__(self) -> bool: ... @property def __default__(self) -> Any | None: ... + def __init__( self, name: str, @@ -417,6 +446,7 @@ class TypeVar: if sys.version_info >= (3, 11): def __typing_subst__(self, arg: Incomplete) -> Incomplete: ... + @final class ParamSpec: @property @@ -431,6 +461,7 @@ class ParamSpec: def __infer_variance__(self) -> bool: ... @property def __default__(self) -> Any | None: ... + def __init__( self, name: str, @@ -445,6 +476,7 @@ class ParamSpec: @property def kwargs(self) -> ParamSpecKwargs: ... + @final class TypeVarTuple: @property @@ -454,13 +486,17 @@ class TypeVarTuple: def __init__(self, name: str, *, default: Any | None = None) -> None: ... def __iter__(self) -> Any: ... # Unpack[Self] + class deprecated: message: str category: type[Warning] | None stacklevel: int - def __init__(self, __message: str, *, category: type[Warning] | None = ..., stacklevel: int = 1) -> None: ... + def __init__(self, __message: str, *, + category: type[Warning] | None = ..., stacklevel: int = 1) -> None: ... + def __call__(self, __arg: _T) -> _T: ... + if sys.version_info >= (3, 12): from collections.abc import Buffer as Buffer from types import get_original_bases as get_original_bases @@ -468,6 +504,7 @@ if sys.version_info >= (3, 12): else: def override(__arg: _F) -> _F: ... def get_original_bases(__cls: type) -> tuple[Any, ...]: ... + @final class TypeAliasType: def __init__( @@ -475,8 +512,11 @@ else: ) -> None: ... @property def __value__(self) -> Any: ... + @property - def __type_params__(self) -> tuple[TypeVar | ParamSpec | TypeVarTuple, ...]: ... + def __type_params__(self) -> tuple[TypeVar | + ParamSpec | TypeVarTuple, ...]: ... + @property def __parameters__(self) -> tuple[Any, ...]: ... @property @@ -501,10 +541,12 @@ else: def is_protocol(__tp: type) -> bool: ... def get_protocol_members(__tp: type) -> frozenset[str]: ... + class Doc: documentation: str def __init__(self, __documentation: str) -> None: ... def __hash__(self) -> int: ... def __eq__(self, other: object) -> bool: ... + ReadOnly: _SpecialForm diff --git a/test-data/unit/check-in-exhaustive-checking.test b/test-data/unit/check-in-exhaustive-checking.test new file mode 100644 index 0000000000000..3f906f3a9c523 --- /dev/null +++ b/test-data/unit/check-in-exhaustive-checking.test @@ -0,0 +1,36 @@ +[case testInExhaustiveChecking] + +[builtins fixtures/tuple.pyi] +from enum import Enum +import pytest + +# Import assert_never for ensuring all cases in an if-elif chain are handled +from typing_extensions import assert_never + +class MyEnum(Enum): + A = 1 + B = 2 + C = 3 + +def my_function(a: MyEnum) -> bool: + if a == MyEnum.A: + return True + elif a in (MyEnum.B, MyEnum.C): + return False + assert_never(a) + +class MyEnum2(Enum): + A = 1 + B = 2 + +def my_function2(a: MyEnum) -> bool: + if a in (MyEnum.A, MyEnum.B): + return False + assert_never(a) + +# Test cases +def test_my_function(): + # Test for MyEnum.A + assert my_function(MyEnum.A) == True, "Failed for MyEnum.A" + + assert my_function2(MyEnum2.A) == True, "Failed for MyEnum.A" diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index b7ee38b69d00f..178bbcf1aac78 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -1961,20 +1961,6 @@ if x in nested_any: [builtins fixtures/list.pyi] [out] -[case testNarrowTypeAfterInTuple] -from typing import Optional -class A: pass -class B(A): pass -class C(A): pass - -y: Optional[B] -if y in (B(), C()): - reveal_type(y) # N: Revealed type is "__main__.B" -else: - reveal_type(y) # N: Revealed type is "Union[__main__.B, None]" -[builtins fixtures/tuple.pyi] -[out] - [case testNarrowTypeAfterInNamedTuple] from typing import NamedTuple, Optional class NT(NamedTuple):