From a4900ebd03004e76b3e7499f1b3118be511304ca Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Sat, 5 Mar 2022 01:35:17 -0700 Subject: [PATCH] Extended support for narrowing of index expressions to include those with negative subscripts, such as `a[-1]`. This is supported for all supported type guard patterns. --- docs/type-concepts.md | 2 +- .../src/analyzer/codeFlowTypes.ts | 24 +++++++++++---- .../src/analyzer/parseTreeUtils.ts | 29 +++++++++++++++---- .../src/tests/samples/typeNarrowing7.py | 6 ++++ 4 files changed, 50 insertions(+), 11 deletions(-) diff --git a/docs/type-concepts.md b/docs/type-concepts.md index 9def6d729..1367ab3fa 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -186,7 +186,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `bool(x)` (where x is any expression that is statically verifiable to be truthy or falsy in all cases). * `x` (where x is any expression that is statically verifiable to be truthy or falsy in all cases) -Expressions supported for type guards include simple names, member access chains (e.g. `a.b.c.d`), the unary `not` operator, the binary `and` and `or` operators, subscripts that are constant numbers (e.g. `a[2]`), and call expressions. Other operators (such as arithmetic operators or other subscripts) are not supported. +Expressions supported for type guards include simple names, member access chains (e.g. `a.b.c.d`), the unary `not` operator, the binary `and` and `or` operators, subscripts that are integer literals (e.g. `a[2]` or `a[-1]`), and call expressions. Other operators (such as arithmetic operators or other subscripts) are not supported. Some type guards are able to narrow in both the positive and negative cases. Positive cases are used in `if` statements, and negative cases are used in `else` statements. (Positive and negative cases are flipped if the type guard expression is preceded by a `not` operator.) In some cases, the type can be narrowed only in the positive or negative case but not both. Consider the following examples: diff --git a/packages/pyright-internal/src/analyzer/codeFlowTypes.ts b/packages/pyright-internal/src/analyzer/codeFlowTypes.ts index 424aabb3f..234c08b44 100644 --- a/packages/pyright-internal/src/analyzer/codeFlowTypes.ts +++ b/packages/pyright-internal/src/analyzer/codeFlowTypes.ts @@ -29,6 +29,7 @@ import { StringNode, SuiteNode, } from '../parser/parseNodes'; +import { OperatorType } from '../parser/tokenizerTypes'; export enum FlowFlags { Unreachable = 1 << 0, // Unreachable code @@ -188,12 +189,18 @@ export function isCodeFlowSupportedForReference(reference: ExpressionNode): bool const subscriptNode = reference.items[0].valueExpression; const isIntegerIndex = subscriptNode.nodeType === ParseNodeType.Number && !subscriptNode.isImaginary && subscriptNode.isInteger; + const isNegativeIntegerIndex = + subscriptNode.nodeType === ParseNodeType.UnaryOperation && + subscriptNode.operator === OperatorType.Subtract && + subscriptNode.expression.nodeType === ParseNodeType.Number && + !subscriptNode.expression.isImaginary && + subscriptNode.expression.isInteger; const isStringIndex = subscriptNode.nodeType === ParseNodeType.StringList && subscriptNode.strings.length === 1 && subscriptNode.strings[0].nodeType === ParseNodeType.String; - if (!isIntegerIndex && !isStringIndex) { + if (!isIntegerIndex && !isNegativeIntegerIndex && !isStringIndex) { return false; } @@ -213,12 +220,19 @@ export function createKeyForReference(reference: CodeFlowReferenceExpressionNode } else if (reference.nodeType === ParseNodeType.Index) { const leftKey = createKeyForReference(reference.baseExpression as CodeFlowReferenceExpressionNode); assert(reference.items.length === 1); - if (reference.items[0].valueExpression.nodeType === ParseNodeType.Number) { - key = `${leftKey}[${(reference.items[0].valueExpression as NumberNode).value.toString()}]`; - } else if (reference.items[0].valueExpression.nodeType === ParseNodeType.StringList) { - const valExpr = reference.items[0].valueExpression; + const expr = reference.items[0].valueExpression; + if (expr.nodeType === ParseNodeType.Number) { + key = `${leftKey}[${(expr as NumberNode).value.toString()}]`; + } else if (expr.nodeType === ParseNodeType.StringList) { + const valExpr = expr; assert(valExpr.strings.length === 1 && valExpr.strings[0].nodeType === ParseNodeType.String); key = `${leftKey}["${(valExpr.strings[0] as StringNode).value}"]`; + } else if ( + expr.nodeType === ParseNodeType.UnaryOperation && + expr.operator === OperatorType.Subtract && + expr.expression.nodeType === ParseNodeType.Number + ) { + key = `${leftKey}[-${(expr.expression as NumberNode).value.toString()}]`; } else { fail('createKeyForReference received unexpected index type'); } diff --git a/packages/pyright-internal/src/analyzer/parseTreeUtils.ts b/packages/pyright-internal/src/analyzer/parseTreeUtils.ts index 9e01a51ea..d7b5a1363 100644 --- a/packages/pyright-internal/src/analyzer/parseTreeUtils.ts +++ b/packages/pyright-internal/src/analyzer/parseTreeUtils.ts @@ -978,8 +978,8 @@ export function isMatchingExpression(reference: ExpressionNode, expression: Expr return false; } - if (reference.items[0].valueExpression.nodeType === ParseNodeType.Number) { - const referenceNumberNode = reference.items[0].valueExpression; + const expr = reference.items[0].valueExpression; + if (expr.nodeType === ParseNodeType.Number) { const subscriptNode = expression.items[0].valueExpression; if ( subscriptNode.nodeType !== ParseNodeType.Number || @@ -989,11 +989,30 @@ export function isMatchingExpression(reference: ExpressionNode, expression: Expr return false; } - return referenceNumberNode.value === subscriptNode.value; + return expr.value === subscriptNode.value; } - if (reference.items[0].valueExpression.nodeType === ParseNodeType.StringList) { - const referenceStringListNode = reference.items[0].valueExpression; + if ( + expr.nodeType === ParseNodeType.UnaryOperation && + expr.operator === OperatorType.Subtract && + expr.expression.nodeType === ParseNodeType.Number + ) { + const subscriptNode = expression.items[0].valueExpression; + if ( + subscriptNode.nodeType !== ParseNodeType.UnaryOperation || + subscriptNode.operator !== OperatorType.Subtract || + subscriptNode.expression.nodeType !== ParseNodeType.Number || + subscriptNode.expression.isImaginary || + !subscriptNode.expression.isInteger + ) { + return false; + } + + return expr.expression.value === subscriptNode.expression.value; + } + + if (expr.nodeType === ParseNodeType.StringList) { + const referenceStringListNode = expr; const subscriptNode = expression.items[0].valueExpression; if ( referenceStringListNode.strings.length === 1 && diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowing7.py b/packages/pyright-internal/src/tests/samples/typeNarrowing7.py index 3070ec325..3c469d50b 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowing7.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowing7.py @@ -36,12 +36,18 @@ def func1(v1: List[Optional[complex]]): foo.val = [] reveal_type(foo.val[0][2], expected_text="str | None") + if v1[-1]: + reveal_type(v1[-1], expected_text="complex") + def func2(v1: List[Union[Dict[str, str], List[str]]]): if isinstance(v1[0], dict): reveal_type(v1[0], expected_text="Dict[str, str]") reveal_type(v1[1], expected_text="Dict[str, str] | List[str]") + if isinstance(v1[-1], list): + reveal_type(v1[-1], expected_text="List[str]") + def func3(): v1: Dict[str, int] = {}