diff --git a/docs/type-concepts.md b/docs/type-concepts.md index 74624b150..307b61a70 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -173,7 +173,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `x == L` and `x != L` (where L is a literal expression) * `x.y is None` and `x.y is not None` (where x is a type that is distinguished by a field with a None) * `x.y is E` and `x.y is not E` (where E is a literal enum or bool and x is a type that is distinguished by a field with a literal type) -* `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field with a literal type) +* `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field or property with a literal type) * `x[K] == V` and `x[K] != V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type) * `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 1a0a0955e..7ffba9a15 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -1507,6 +1507,7 @@ function narrowTypeForDiscriminatedLiteralFieldComparison( ): Type { const narrowedType = mapSubtypes(referenceType, (subtype) => { let memberInfo: ClassMember | undefined; + if (isClassInstance(subtype)) { memberInfo = lookUpObjectMember(subtype, memberName); } else if (isInstantiableClass(subtype)) { @@ -1514,7 +1515,23 @@ function narrowTypeForDiscriminatedLiteralFieldComparison( } if (memberInfo && memberInfo.isTypeDeclared) { - const memberType = evaluator.getTypeOfMember(memberInfo); + let memberType = evaluator.getTypeOfMember(memberInfo); + + // Handle the case where the field is a property + // that has a declared literal return type for its getter. + if (isClassInstance(subtype) && isProperty(memberType)) { + const getterInfo = lookUpObjectMember(memberType, 'fget'); + + if (getterInfo && getterInfo.isTypeDeclared) { + const getterType = evaluator.getTypeOfMember(getterInfo); + if (isFunction(getterType) && getterType.details.declaredReturnType) { + const getterReturnType = FunctionType.getSpecializedReturnType(getterType); + if (getterReturnType) { + memberType = getterReturnType; + } + } + } + } if (isLiteralTypeOrUnion(memberType)) { if (isPositiveTest) { diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteralMember1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteralMember1.py index d62d86562..5fad8ad71 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteralMember1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteralMember1.py @@ -118,3 +118,23 @@ def is_class2(c: Union[Type[A], Type[B]]): reveal_type(c, expected_text="Type[A] | Type[B]") else: reveal_type(c, expected_text="Type[A] | Type[B]") + + +class E: + @property + def type(self) -> Literal[0]: + return 0 + + +class F: + @property + def type(self) -> Literal[1]: + return 1 + + +def test(x: E | F) -> None: + if x.type == 1: + reveal_type(x, expected_type="F") + else: + reveal_type(x, expected_type="E") +