diff --git a/docs/type-concepts.md b/docs/type-concepts.md index 1367ab3fa..74624b150 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -171,6 +171,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `type(x) is T` and `type(x) is not T` * `x is E` and `x is not E` (where E is a literal enum or bool) * `x == L` and `x != L` (where L is a literal expression) +* `x.y is None` and `x.y is not None` (where x is a type that is distinguished by a field with a None) * `x.y is E` and `x.y is not E` (where E is a literal enum or bool and x is a type that is distinguished by a field with a literal type) * `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field with a literal type) * `x[K] == V` and `x[K] != V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 53f50b166..bf72712cf 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -70,6 +70,8 @@ import { getTypeVarScopeId, isLiteralType, isLiteralTypeOrUnion, + isMaybeDescriptorInstance, + isProperty, isTupleClass, isUnboundedTupleClass, lookUpClassMember, @@ -321,7 +323,7 @@ export function getTypeNarrowingCallback( const memberName = testExpression.leftExpression.memberName; if (isClassInstance(rightType) && rightType.literalValue !== undefined) { return (type: Type) => { - return narrowTypeForDiscriminatedFieldComparison( + return narrowTypeForDiscriminatedLiteralFieldComparison( evaluator, type, memberName.value, @@ -346,7 +348,7 @@ export function getTypeNarrowingCallback( rightType.literalValue !== undefined ) { return (type: Type) => { - return narrowTypeForDiscriminatedFieldComparison( + return narrowTypeForDiscriminatedLiteralFieldComparison( evaluator, type, memberName.value, @@ -356,6 +358,25 @@ export function getTypeNarrowingCallback( }; } } + + // Look for X.Y is None or X.Y is not None + // These are commonly-used patterns used in control flow. + if ( + testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess && + ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression) && + testExpression.rightExpression.nodeType === ParseNodeType.Constant && + testExpression.rightExpression.constType === KeywordType.None + ) { + const memberName = testExpression.leftExpression.memberName; + return (type: Type) => { + return narrowTypeForDiscriminatedFieldNoneComparison( + evaluator, + type, + memberName.value, + adjIsPositiveTest + ); + }; + } } if (testExpression.operator === OperatorType.In) { @@ -1423,7 +1444,7 @@ function narrowTypeForDiscriminatedTupleComparison( // Attempts to narrow a type based on a comparison (equal or not equal) // between a discriminating field that has a declared literal type to a // literal value. -function narrowTypeForDiscriminatedFieldComparison( +function narrowTypeForDiscriminatedLiteralFieldComparison( evaluator: TypeEvaluator, referenceType: Type, memberName: string, @@ -1456,6 +1477,53 @@ function narrowTypeForDiscriminatedFieldComparison( return narrowedType; } +// Attempts to narrow a type based on a comparison (equal or not equal) +// between a discriminating field that has a declared None type to a +// None. +function narrowTypeForDiscriminatedFieldNoneComparison( + evaluator: TypeEvaluator, + referenceType: Type, + memberName: string, + isPositiveTest: boolean +): Type { + return mapSubtypes(referenceType, (subtype) => { + let memberInfo: ClassMember | undefined; + if (isClassInstance(subtype)) { + memberInfo = lookUpObjectMember(subtype, memberName); + } else if (isInstantiableClass(subtype)) { + memberInfo = lookUpClassMember(subtype, memberName); + } + + if (memberInfo && memberInfo.isTypeDeclared) { + const memberType = evaluator.makeTopLevelTypeVarsConcrete(evaluator.getTypeOfMember(memberInfo)); + let canNarrow = true; + + if (isPositiveTest) { + doForEachSubtype(memberType, (memberSubtype) => { + memberSubtype = evaluator.makeTopLevelTypeVarsConcrete(memberSubtype); + + // Don't attempt to narrow if the member is a descriptor or property. + if (isProperty(memberSubtype) || isMaybeDescriptorInstance(memberSubtype)) { + canNarrow = false; + } + + if (isAnyOrUnknown(memberSubtype) || isNoneInstance(memberSubtype) || isNever(memberSubtype)) { + canNarrow = false; + } + }); + } else { + canNarrow = isNoneInstance(memberType); + } + + if (canNarrow) { + return undefined; + } + } + + return subtype; + }); +} + // Attempts to narrow a type based on a "type(x) is y" or "type(x) is not y" check. function narrowTypeForTypeIs(type: Type, classType: ClassType, isPositiveTest: boolean) { return mapSubtypes(type, (subtype) => { diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingNoneMember1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingNoneMember1.py new file mode 100644 index 000000000..c91e74b9e --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingNoneMember1.py @@ -0,0 +1,102 @@ +# This sample tests the type narrowing case for unions of NamedTuples +# where one or more of the entries is tested against type None by attribute. + +from typing import NamedTuple, Optional, Union + +IntFirst = NamedTuple( + "IntFirst", + [ + ("first", int), + ("second", None), + ], +) + +StrSecond = NamedTuple( + "StrSecond", + [ + ("first", None), + ("second", str), + ], +) + + +def func1(a: Union[IntFirst, StrSecond]) -> IntFirst: + if a.second is None: + reveal_type(a, expected_text="IntFirst") + return a + else: + reveal_type(a, expected_text="StrSecond") + raise ValueError() + + +UnionFirst = NamedTuple( + "UnionFirst", + [ + ("first", Union[None, int]), + ("second", None), + ], +) + + +def func2(a: Union[UnionFirst, StrSecond]): + if a.first is None: + reveal_type(a, expected_text="UnionFirst | StrSecond") + else: + reveal_type(a, expected_text="UnionFirst") + + +class A: + @property + def prop1(self) -> Optional[int]: + ... + + member1: None + member2: Optional[int] + member3: Optional[int] + member4: Optional[int] + + +class B: + @property + def prop1(self) -> int: + ... + + member1: int + member2: Optional[int] + member3: None + member4: int + + +def func3(c: Union[A, B]): + if c.prop1 is None: + reveal_type(c, expected_text="A | B") + else: + reveal_type(c, expected_text="A | B") + + +def func4(c: Union[A, B]): + if c.member1 is None: + reveal_type(c, expected_text="A") + else: + reveal_type(c, expected_text="B") + + +def func5(c: Union[A, B]): + if c.member2 is None: + reveal_type(c, expected_text="A | B") + else: + reveal_type(c, expected_text="A | B") + + +def func6(c: Union[A, B]): + if c.member3 is not None: + reveal_type(c, expected_text="A") + else: + reveal_type(c, expected_text="A | B") + + +def func7(c: Union[A, B]): + if c.member4 is not None: + reveal_type(c, expected_text="A | B") + else: + reveal_type(c, expected_text="A") diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index 754f51725..ec75a0dd0 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -386,6 +386,12 @@ test('TypeNarrowingLiteralMember1', () => { TestUtils.validateResults(analysisResults, 0); }); +test('TypeNarrowingNoneMember1', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingNoneMember1.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('TypeNarrowingTuple1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingTuple1.py']);