diff --git a/packages/pyright-internal/src/analyzer/patternMatching.ts b/packages/pyright-internal/src/analyzer/patternMatching.ts index 63704ff01..3dfee135e 100644 --- a/packages/pyright-internal/src/analyzer/patternMatching.ts +++ b/packages/pyright-internal/src/analyzer/patternMatching.ts @@ -31,7 +31,6 @@ import { import { CodeFlowReferenceExpressionNode } from './codeFlowTypes'; import { populateTypeVarContextBasedOnExpectedType } from './constraintSolver'; import { getTypeVarScopesForNode, isMatchingExpression } from './parseTreeUtils'; -import { getTypedDictMembersForClass } from './typedDicts'; import { EvaluatorFlags, TypeEvaluator, TypeResult } from './typeEvaluatorTypes'; import { enumerateLiteralsForType, @@ -39,27 +38,6 @@ import { narrowTypeForDiscriminatedLiteralFieldComparison, narrowTypeForDiscriminatedTupleComparison, } from './typeGuards'; -import { - AnyType, - ClassType, - combineTypes, - FunctionType, - FunctionTypeFlags, - isAnyOrUnknown, - isClass, - isClassInstance, - isInstantiableClass, - isNever, - isSameWithoutLiteralValue, - isTypeSame, - isUnknown, - isUnpackedVariadicTypeVar, - NeverType, - Type, - TypeBase, - TypedDictEntry, - UnknownType, -} from './types'; import { addConditionToType, applySolvedTypeVars, @@ -85,6 +63,28 @@ import { transformPossibleRecursiveTypeAlias, } from './typeUtils'; import { TypeVarContext } from './typeVarContext'; +import { getTypedDictMembersForClass } from './typedDicts'; +import { + AnyType, + ClassType, + FunctionType, + FunctionTypeFlags, + NeverType, + Type, + TypeBase, + TypedDictEntry, + UnknownType, + combineTypes, + isAnyOrUnknown, + isClass, + isClassInstance, + isInstantiableClass, + isNever, + isSameWithoutLiteralValue, + isTypeSame, + isUnknown, + isUnpackedVariadicTypeVar, +} from './types'; // PEP 634 indicates that several built-in classes are handled differently // when used with class pattern matching. @@ -720,10 +720,22 @@ function narrowTypeBasedOnClassPattern( return subjectSubtypeExpanded; } + // Handle Callable specially. + if ( + !isAnyOrUnknown(subjectSubtypeExpanded) && + isInstantiableClass(classType) && + ClassType.isBuiltIn(classType, 'Callable') + ) { + if (evaluator.assignType(getUnknownTypeForCallable(), subjectSubtypeExpanded)) { + return undefined; + } + } + if (!isNoneInstance(subjectSubtypeExpanded) && !isClassInstance(subjectSubtypeExpanded)) { return subjectSubtypeUnexpanded; } + // Handle NoneType specially. if ( isNoneInstance(subjectSubtypeExpanded) && isInstantiableClass(classType) && diff --git a/packages/pyright-internal/src/tests/samples/matchClass6.py b/packages/pyright-internal/src/tests/samples/matchClass6.py index bd351df0e..bbd311d70 100644 --- a/packages/pyright-internal/src/tests/samples/matchClass6.py +++ b/packages/pyright-internal/src/tests/samples/matchClass6.py @@ -1,7 +1,7 @@ # This sample tests the case where `Callable()` is used as a class pattern. from collections.abc import Callable -from typing import Any, TypeVar +from typing import Any, Protocol, TypeVar T = TypeVar("T") @@ -38,3 +38,28 @@ def func5(obj: Any): case Callable(): reveal_type(obj, expected_text="(...) -> Any") return obj() + + +def func6(obj: Callable[[], None]): + match obj: + case Callable(): + reveal_type(obj, expected_text="() -> None") + return obj() + + case x: + reveal_type(obj, expected_text="Never") + + +class CallableProto(Protocol): + def __call__(self) -> None: + pass + + +def func7(obj: CallableProto): + match obj: + case Callable(): + reveal_type(obj, expected_text="CallableProto") + return obj() + + case x: + reveal_type(obj, expected_text="Never")