diff --git a/docs/type-concepts.md b/docs/type-concepts.md index e88967b59..b10bc8a70 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -150,9 +150,8 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `x is None` and `x is not None` * `x == None` and `x != None` * `type(x) is T` and `type(x) is not T` -* `x is E` and `x is not E` (where E is an enum value or True or False) -* `x == L` and `x != L` (where L is a literal expression) -* `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field with a literal type) +* `x == L` and `x != L` and `L == x` and `L != x` and `x is L` and `x is not L` (where L is a literal expression) +* `x.y == L` and `x.y != L` and `x.y is L` and `x.y is not L` (where L is a literal expression and x is a type that is distinguished by a field with a literal type) * `x[K] == V` and `x[K] != V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type) * `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index a0822e9fe..48569a5f7 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -177,15 +177,11 @@ export function getTypeNarrowingCallback( } } - // Look for "X is Y" or "X is not Y" where Y is a an enum or False or True. + // Look for "X is " or "X is not " if (isOrIsNotOperator) { if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) { const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type; - if ( - isClassInstance(rightType) && - (ClassType.isEnumClass(rightType) || ClassType.isBuiltIn(rightType, 'bool')) && - rightType.literalValue !== undefined - ) { + if (isClassInstance(rightType) && rightType.literalValue !== undefined) { return (type: Type) => { return narrowTypeForLiteralComparison( evaluator, @@ -219,6 +215,7 @@ export function getTypeNarrowingCallback( } } + // Look for == X or != X if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) { const leftType = evaluator.getTypeOfExpression(testExpression.leftExpression).type; if (isClassInstance(leftType) && leftType.literalValue !== undefined) { @@ -234,26 +231,6 @@ export function getTypeNarrowingCallback( } } - // Look for X.Y == or X.Y != - if ( - testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess && - ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression) - ) { - const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type; - const memberName = testExpression.leftExpression.memberName; - if (isClassInstance(rightType) && rightType.literalValue !== undefined) { - return (type: Type) => { - return narrowTypeForDiscriminatedFieldComparison( - evaluator, - type, - memberName.value, - rightType, - adjIsPositiveTest - ); - }; - } - } - // Look for X[] == or X[] != if ( testExpression.leftExpression.nodeType === ParseNodeType.Index && @@ -324,6 +301,26 @@ export function getTypeNarrowingCallback( } } } + + // Look for X.Y == or X.Y != or X.Y is or X.Y is not + if ( + testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess && + ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression) + ) { + const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type; + const memberName = testExpression.leftExpression.memberName; + if (isClassInstance(rightType) && rightType.literalValue !== undefined) { + return (type: Type) => { + return narrowTypeForDiscriminatedFieldComparison( + evaluator, + type, + memberName.value, + rightType, + adjIsPositiveTest + ); + }; + } + } } if (testExpression.operator === OperatorType.In) { @@ -1340,8 +1337,6 @@ function narrowTypeForDiscriminatedFieldComparison( literalType: ClassType, isPositiveTest: boolean ): Type { - let canNarrow = true; - const narrowedType = mapSubtypes(referenceType, (subtype) => { let memberInfo: ClassMember | undefined; if (isClassInstance(subtype)) { @@ -1362,11 +1357,10 @@ function narrowTypeForDiscriminatedFieldComparison( } } - canNarrow = false; return subtype; }); - return canNarrow ? narrowedType : referenceType; + return narrowedType; } // Attempts to narrow a type based on a "type(x) is y" or "type(x) is not y" check. diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py index 3a0dc7c2c..1728ad937 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py @@ -4,41 +4,30 @@ from typing import Literal, Union -def requires_a(p1: Literal["a"]): - pass - - -def requires_bc(p1: Literal["b", "c"]): - pass - - def func_1(p1: Literal["a", "b", "c"]): if p1 != "b": if p1 == "c": + t1: Literal["Literal['c']"] = reveal_type(p1) pass else: - requires_a(p1) + t2: Literal["Literal['a']"] = reveal_type(p1) if p1 != "a": - requires_bc(p1) + t3: Literal["Literal['c', 'b']"] = reveal_type(p1) else: - requires_a(p1) + t4: Literal["Literal['a']"] = reveal_type(p1) if "a" != p1: - requires_bc(p1) + t5: Literal["Literal['c', 'b']"] = reveal_type(p1) else: - requires_a(p1) - - -def requires_7(p1: Literal[7]): - pass + t6: Literal["Literal['a']"] = reveal_type(p1) def func2(p1: Literal[1, 4, 7]): if 4 == p1 or 1 == p1: - pass + t1: Literal["Literal[4, 1]"] = reveal_type(p1) else: - requires_7(p1) + t2: Literal["Literal[7]"] = reveal_type(p1) def func3(a: Union[int, None]): diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteralMember1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteralMember1.py index 8826751ba..23f08f741 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteralMember1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteralMember1.py @@ -26,43 +26,85 @@ class D: kind: Literal[1, 2, 3] -def foo_obj1(c: Union[A, B]): +def eq_obj1(c: Union[A, B]): if c.kind == "A": tc1: Literal["A"] = reveal_type(c) else: tc2: Literal["B"] = reveal_type(c) -def foo_obj2(c: Union[A, B]): +def is_obj1(c: Union[A, B]): + if c.kind is "A": + tc1: Literal["A"] = reveal_type(c) + else: + tc2: Literal["B"] = reveal_type(c) + + +def eq_obj2(c: Union[A, B]): if c.kind != "A": tc1: Literal["B"] = reveal_type(c) else: tc2: Literal["A"] = reveal_type(c) -def foo_obj3(c: Union[A, B, C]): - if c.kind == "A": - tc1: Literal["A | B | C"] = reveal_type(c) +def is_obj2(c: Union[A, B]): + if c.kind is not "A": + tc1: Literal["B"] = reveal_type(c) else: - tc2: Literal["A | B | C"] = reveal_type(c) + tc2: Literal["A"] = reveal_type(c) -def foo_obj4(c: Union[A, B]): +def eq_obj3(c: Union[A, B, C]): + if c.kind == "A": + tc1: Literal["A | C"] = reveal_type(c) + else: + tc2: Literal["B | C"] = reveal_type(c) + + +def is_obj3(c: Union[A, B, C]): + if c.kind is "A": + tc1: Literal["A | C"] = reveal_type(c) + else: + tc2: Literal["B | C"] = reveal_type(c) + + +def eq_obj4(c: Union[A, B]): if c.d == 1: tc1: Literal["A"] = reveal_type(c) elif c.d == 3: tc2: Literal["A | B"] = reveal_type(c) -def foo_obj5(d: D): +def is_obj4(c: Union[A, B]): + if c.d is 1: + tc1: Literal["A"] = reveal_type(c) + elif c.d is 3: + tc2: Literal["A | B"] = reveal_type(c) + + +def eq_obj5(d: D): if d.kind == 1: td1: Literal["D"] = reveal_type(d) elif d.kind == 2: td2: Literal["D"] = reveal_type(d) -def foo_class2(c: Union[Type[A], Type[B]]): +def is_obj5(d: D): + if d.kind is 1: + td1: Literal["D"] = reveal_type(d) + elif d.kind is 2: + td2: Literal["D"] = reveal_type(d) + + +def eq_class2(c: Union[Type[A], Type[B]]): if c.kind_class == "A": tc1: Literal["Type[A]"] = reveal_type(c) else: tc2: Literal["Type[B]"] = reveal_type(c) + + +def is_class2(c: Union[Type[A], Type[B]]): + if c.kind_class is "A": + tc1: Literal["Type[A]"] = reveal_type(c) + else: + tc2: Literal["Type[B]"] = reveal_type(c)