diff --git a/docs/type-concepts-advanced.md b/docs/type-concepts-advanced.md index 8042d08f2..6de259633 100644 --- a/docs/type-concepts-advanced.md +++ b/docs/type-concepts-advanced.md @@ -68,6 +68,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field or property with a literal type) * `x[K] == V`, `x[K] != V`, `x[K] is V`, and `x[K] is not V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type) * `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I) +* `x[I] is B` and `x[I] is not B` (where I is a literal expression, B is a `bool` literal, and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I) * `len(x) == L` and `len(x) != L` (where x is tuple and L is a literal integer) * `x in y` or `x not in y` (where y is instance of list, set, frozenset, deque, tuple, dict, defaultdict, or OrderedDict) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index d942a40a3..7861df25c 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -277,6 +277,28 @@ export function getTypeNarrowingCallback( }; }; } + } else if (ClassType.isBuiltIn(indexType, 'int')) { + const rightTypeResult = evaluator.getTypeOfExpression(testExpression.rightExpression); + const rightType = rightTypeResult.type; + + if ( + isClassInstance(rightType) && + ClassType.isBuiltIn(rightType, 'bool') && + rightType.literalValue !== undefined + ) { + return (type: Type) => { + return { + type: narrowTypeForDiscriminatedTupleComparison( + evaluator, + type, + indexType, + rightType, + adjIsPositiveTest + ), + isIncomplete: !!rightTypeResult.isIncomplete, + }; + }; + } } } } diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingTuple1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingTuple1.py index 7fdba94d8..b7983c80c 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingTuple1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingTuple1.py @@ -1,23 +1,43 @@ # This sample tests the type narrowing for known-length tuples # that have an entry with a declared literal type. -from typing import Tuple, Union, Literal +from typing import Literal -MsgA = Tuple[Literal[1], str] -MsgB = Tuple[Literal[2], float] +MsgA = tuple[Literal[1], str] +MsgB = tuple[Literal[2], float] -Msg = Union[MsgA, MsgB] +MsgAOrB = MsgA | MsgB -def func1(m: Msg): +def func1(m: MsgAOrB): if m[0] == 1: - reveal_type(m, expected_text="Tuple[Literal[1], str]") + reveal_type(m, expected_text="tuple[Literal[1], str]") else: - reveal_type(m, expected_text="Tuple[Literal[2], float]") + reveal_type(m, expected_text="tuple[Literal[2], float]") -def func2(m: Msg): +def func2(m: MsgAOrB): if m[0] != 1: - reveal_type(m, expected_text="Tuple[Literal[2], float]") + reveal_type(m, expected_text="tuple[Literal[2], float]") else: - reveal_type(m, expected_text="Tuple[Literal[1], str]") + reveal_type(m, expected_text="tuple[Literal[1], str]") + + +MsgC = tuple[Literal[True], str] +MsgD = tuple[Literal[False], float] + +MsgCOrD = MsgC | MsgD + + +def func3(m: MsgCOrD): + if m[0] is True: + reveal_type(m, expected_text="tuple[Literal[True], str]") + else: + reveal_type(m, expected_text="tuple[Literal[False], float]") + + +def func4(m: MsgCOrD): + if m[0] is not True: + reveal_type(m, expected_text="tuple[Literal[False], float]") + else: + reveal_type(m, expected_text="tuple[Literal[True], str]")