diff --git a/docs/type-concepts-advanced.md b/docs/type-concepts-advanced.md index 6de259633..dd476177b 100644 --- a/docs/type-concepts-advanced.md +++ b/docs/type-concepts-advanced.md @@ -68,7 +68,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field or property with a literal type) * `x[K] == V`, `x[K] != V`, `x[K] is V`, and `x[K] is not V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type) * `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I) -* `x[I] is B` and `x[I] is not B` (where I is a literal expression, B is a `bool` literal, and x is a known-length tuple that is distinguished by the index indicated by I) +* `x[I] is B` and `x[I] is not B` (where I is a literal expression, B is a `bool` or enum literal, and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I) * `len(x) == L` and `len(x) != L` (where x is tuple and L is a literal integer) * `x in y` or `x not in y` (where y is instance of list, set, frozenset, deque, tuple, dict, defaultdict, or OrderedDict) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 7861df25c..46110d5a6 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -29,6 +29,7 @@ import { getScopeForNode } from './scopeUtils'; import { Symbol, SymbolFlags } from './symbol'; import { getTypedDictMembersForClass } from './typedDicts'; import { EvaluatorFlags, TypeEvaluator } from './typeEvaluatorTypes'; +import { EnumLiteral } from './types'; import { ClassType, ClassTypeFlags, @@ -281,23 +282,29 @@ export function getTypeNarrowingCallback( const rightTypeResult = evaluator.getTypeOfExpression(testExpression.rightExpression); const rightType = rightTypeResult.type; - if ( - isClassInstance(rightType) && - ClassType.isBuiltIn(rightType, 'bool') && - rightType.literalValue !== undefined - ) { - return (type: Type) => { - return { - type: narrowTypeForDiscriminatedTupleComparison( - evaluator, - type, - indexType, - rightType, - adjIsPositiveTest - ), - isIncomplete: !!rightTypeResult.isIncomplete, + if (isClassInstance(rightType) && rightType.literalValue !== undefined) { + let canNarrow = false; + // Narrowing can be applied only for bool or enum literals. + if (ClassType.isBuiltIn(rightType, 'bool')) { + canNarrow = true; + } else if (rightType.literalValue instanceof EnumLiteral) { + canNarrow = true; + } + + if (canNarrow) { + return (type: Type) => { + return { + type: narrowTypeForDiscriminatedTupleComparison( + evaluator, + type, + indexType, + rightType, + adjIsPositiveTest + ), + isIncomplete: !!rightTypeResult.isIncomplete, + }; }; - }; + } } } } diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingTuple1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingTuple1.py index b7983c80c..243315a57 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingTuple1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingTuple1.py @@ -1,6 +1,7 @@ # This sample tests the type narrowing for known-length tuples # that have an entry with a declared literal type. +from enum import Enum from typing import Literal MsgA = tuple[Literal[1], str] @@ -41,3 +42,28 @@ def func4(m: MsgCOrD): reveal_type(m, expected_text="tuple[Literal[False], float]") else: reveal_type(m, expected_text="tuple[Literal[True], str]") + + +class MyEnum(Enum): + A = 0 + B = 1 + + +MsgE = tuple[Literal[MyEnum.A], str] +MsgF = tuple[Literal[MyEnum.B], float] + +MsgEOrF = MsgE | MsgF + + +def func5(m: MsgEOrF): + if m[0] is MyEnum.A: + reveal_type(m, expected_text="tuple[Literal[MyEnum.A], str]") + else: + reveal_type(m, expected_text="tuple[Literal[MyEnum.B], float]") + + +def func6(m: MsgEOrF): + if m[0] is not MyEnum.A: + reveal_type(m, expected_text="tuple[Literal[MyEnum.B], float]") + else: + reveal_type(m, expected_text="tuple[Literal[MyEnum.A], str]")