Extended support for narrowing of index expressions to include those with negative subscripts, such as a[-1]. This is supported for all supported type guard patterns.

This commit is contained in:
Eric Traut 2022-03-05 01:35:17 -07:00
parent 0462654383
commit a4900ebd03
4 changed files with 50 additions and 11 deletions

View File

@ -186,7 +186,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `bool(x)` (where x is any expression that is statically verifiable to be truthy or falsy in all cases).
* `x` (where x is any expression that is statically verifiable to be truthy or falsy in all cases)
Expressions supported for type guards include simple names, member access chains (e.g. `a.b.c.d`), the unary `not` operator, the binary `and` and `or` operators, subscripts that are constant numbers (e.g. `a[2]`), and call expressions. Other operators (such as arithmetic operators or other subscripts) are not supported.
Expressions supported for type guards include simple names, member access chains (e.g. `a.b.c.d`), the unary `not` operator, the binary `and` and `or` operators, subscripts that are integer literals (e.g. `a[2]` or `a[-1]`), and call expressions. Other operators (such as arithmetic operators or other subscripts) are not supported.
Some type guards are able to narrow in both the positive and negative cases. Positive cases are used in `if` statements, and negative cases are used in `else` statements. (Positive and negative cases are flipped if the type guard expression is preceded by a `not` operator.) In some cases, the type can be narrowed only in the positive or negative case but not both. Consider the following examples:

View File

@ -29,6 +29,7 @@ import {
StringNode,
SuiteNode,
} from '../parser/parseNodes';
import { OperatorType } from '../parser/tokenizerTypes';
export enum FlowFlags {
Unreachable = 1 << 0, // Unreachable code
@ -188,12 +189,18 @@ export function isCodeFlowSupportedForReference(reference: ExpressionNode): bool
const subscriptNode = reference.items[0].valueExpression;
const isIntegerIndex =
subscriptNode.nodeType === ParseNodeType.Number && !subscriptNode.isImaginary && subscriptNode.isInteger;
const isNegativeIntegerIndex =
subscriptNode.nodeType === ParseNodeType.UnaryOperation &&
subscriptNode.operator === OperatorType.Subtract &&
subscriptNode.expression.nodeType === ParseNodeType.Number &&
!subscriptNode.expression.isImaginary &&
subscriptNode.expression.isInteger;
const isStringIndex =
subscriptNode.nodeType === ParseNodeType.StringList &&
subscriptNode.strings.length === 1 &&
subscriptNode.strings[0].nodeType === ParseNodeType.String;
if (!isIntegerIndex && !isStringIndex) {
if (!isIntegerIndex && !isNegativeIntegerIndex && !isStringIndex) {
return false;
}
@ -213,12 +220,19 @@ export function createKeyForReference(reference: CodeFlowReferenceExpressionNode
} else if (reference.nodeType === ParseNodeType.Index) {
const leftKey = createKeyForReference(reference.baseExpression as CodeFlowReferenceExpressionNode);
assert(reference.items.length === 1);
if (reference.items[0].valueExpression.nodeType === ParseNodeType.Number) {
key = `${leftKey}[${(reference.items[0].valueExpression as NumberNode).value.toString()}]`;
} else if (reference.items[0].valueExpression.nodeType === ParseNodeType.StringList) {
const valExpr = reference.items[0].valueExpression;
const expr = reference.items[0].valueExpression;
if (expr.nodeType === ParseNodeType.Number) {
key = `${leftKey}[${(expr as NumberNode).value.toString()}]`;
} else if (expr.nodeType === ParseNodeType.StringList) {
const valExpr = expr;
assert(valExpr.strings.length === 1 && valExpr.strings[0].nodeType === ParseNodeType.String);
key = `${leftKey}["${(valExpr.strings[0] as StringNode).value}"]`;
} else if (
expr.nodeType === ParseNodeType.UnaryOperation &&
expr.operator === OperatorType.Subtract &&
expr.expression.nodeType === ParseNodeType.Number
) {
key = `${leftKey}[-${(expr.expression as NumberNode).value.toString()}]`;
} else {
fail('createKeyForReference received unexpected index type');
}

View File

@ -978,8 +978,8 @@ export function isMatchingExpression(reference: ExpressionNode, expression: Expr
return false;
}
if (reference.items[0].valueExpression.nodeType === ParseNodeType.Number) {
const referenceNumberNode = reference.items[0].valueExpression;
const expr = reference.items[0].valueExpression;
if (expr.nodeType === ParseNodeType.Number) {
const subscriptNode = expression.items[0].valueExpression;
if (
subscriptNode.nodeType !== ParseNodeType.Number ||
@ -989,11 +989,30 @@ export function isMatchingExpression(reference: ExpressionNode, expression: Expr
return false;
}
return referenceNumberNode.value === subscriptNode.value;
return expr.value === subscriptNode.value;
}
if (reference.items[0].valueExpression.nodeType === ParseNodeType.StringList) {
const referenceStringListNode = reference.items[0].valueExpression;
if (
expr.nodeType === ParseNodeType.UnaryOperation &&
expr.operator === OperatorType.Subtract &&
expr.expression.nodeType === ParseNodeType.Number
) {
const subscriptNode = expression.items[0].valueExpression;
if (
subscriptNode.nodeType !== ParseNodeType.UnaryOperation ||
subscriptNode.operator !== OperatorType.Subtract ||
subscriptNode.expression.nodeType !== ParseNodeType.Number ||
subscriptNode.expression.isImaginary ||
!subscriptNode.expression.isInteger
) {
return false;
}
return expr.expression.value === subscriptNode.expression.value;
}
if (expr.nodeType === ParseNodeType.StringList) {
const referenceStringListNode = expr;
const subscriptNode = expression.items[0].valueExpression;
if (
referenceStringListNode.strings.length === 1 &&

View File

@ -36,12 +36,18 @@ def func1(v1: List[Optional[complex]]):
foo.val = []
reveal_type(foo.val[0][2], expected_text="str | None")
if v1[-1]:
reveal_type(v1[-1], expected_text="complex")
def func2(v1: List[Union[Dict[str, str], List[str]]]):
if isinstance(v1[0], dict):
reveal_type(v1[0], expected_text="Dict[str, str]")
reveal_type(v1[1], expected_text="Dict[str, str] | List[str]")
if isinstance(v1[-1], list):
reveal_type(v1[-1], expected_text="List[str]")
def func3():
v1: Dict[str, int] = {}