From a2784020f2ec087242a266303faed712413e4de2 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Fri, 13 Jan 2023 16:53:18 -0800 Subject: [PATCH] Added support for `x[K] is V` and `x[K] is not V` type narrowing forms. This addresses https://github.com/microsoft/pyright/issues/4453. --- docs/type-concepts.md | 2 +- .../src/analyzer/typeGuards.ts | 34 +++++++++++++++++++ .../tests/samples/typeNarrowingTypedDict2.py | 13 ++++++- 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/docs/type-concepts.md b/docs/type-concepts.md index 0d7fe0a67..fd090aa33 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -176,7 +176,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `x.y is None` and `x.y is not None` (where x is a type that is distinguished by a field with a None) * `x.y is E` and `x.y is not E` (where E is a literal enum or bool and x is a type that is distinguished by a field with a literal type) * `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field or property with a literal type) -* `x[K] == V` and `x[K] != V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type) +* `x[K] == V`, `x[K] != V`, `x[K] is V`, and `x[K] is not V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type) * `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I) * `len(x) == L` and `len(x) != L` (where x is tuple and L is a literal integer) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index cdd62ae4e..ec3f33234 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -246,6 +246,40 @@ export function getTypeNarrowingCallback( }; } } + + // Look for X[] is or X[] is not + if ( + testExpression.leftExpression.nodeType === ParseNodeType.Index && + testExpression.leftExpression.items.length === 1 && + !testExpression.leftExpression.trailingComma && + testExpression.leftExpression.items[0].argumentCategory === ArgumentCategory.Simple && + ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.baseExpression) + ) { + const indexTypeResult = evaluator.getTypeOfExpression( + testExpression.leftExpression.items[0].valueExpression + ); + const indexType = indexTypeResult.type; + + if (isClassInstance(indexType) && isLiteralType(indexType)) { + if (ClassType.isBuiltIn(indexType, 'str')) { + const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type; + if (isClassInstance(rightType) && rightType.literalValue !== undefined) { + return (type: Type) => { + return { + type: narrowTypeForDiscriminatedDictEntryComparison( + evaluator, + type, + indexType, + rightType, + adjIsPositiveTest + ), + isIncomplete: !!indexTypeResult.isIncomplete, + }; + }; + } + } + } + } } if (equalsOrNotEqualsOperator) { diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingTypedDict2.py b/packages/pyright-internal/src/tests/samples/typeNarrowingTypedDict2.py index abed423c4..20e028d72 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingTypedDict2.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingTypedDict2.py @@ -23,7 +23,7 @@ class OtherEvent(TypedDict): Event = Union[NewJobEvent, CancelJobEvent, OtherEvent] -def process_event(event: Event) -> None: +def process_event1(event: Event) -> None: if event["tag"] == "new-job": reveal_type(event, expected_text="NewJobEvent") event["job_name"] @@ -33,3 +33,14 @@ def process_event(event: Event) -> None: else: reveal_type(event, expected_text="OtherEvent") event["message"] + +def process_event2(event: Event) -> None: + if event["tag"] is "new-job": + reveal_type(event, expected_text="NewJobEvent") + event["job_name"] + elif event["tag"] is 2: + reveal_type(event, expected_text="CancelJobEvent") + event["job_id"] + else: + reveal_type(event, expected_text="OtherEvent") + event["message"]