From 98d15230771889d600b4f968296decc03f2d54fd Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Tue, 9 Apr 2024 23:13:17 -0700 Subject: [PATCH] Expanded support for `len(x) == L` type guard pattern (where x is a tuple) to support `<`, `<=`, `>` and `>=` comparisons as well. This addresses #7655. (#7657) --- docs/type-concepts-advanced.md | 2 +- .../pyright-internal/src/analyzer/binder.ts | 23 +++ .../src/analyzer/typeGuards.ts | 182 ++++++++++++------ .../src/tests/samples/loop16.py | 20 +- .../samples/typeNarrowingTupleLength1.py | 75 +++++++- 5 files changed, 232 insertions(+), 70 deletions(-) diff --git a/docs/type-concepts-advanced.md b/docs/type-concepts-advanced.md index 40a6de5f4..2a938e189 100644 --- a/docs/type-concepts-advanced.md +++ b/docs/type-concepts-advanced.md @@ -71,7 +71,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] is B` and `x[I] is not B` (where I is a literal expression, B is a `bool` or enum literal, and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I) -* `len(x) == L` and `len(x) != L` (where x is tuple and L is an expression that evaluates to an int literal type) +* `len(x) == L`, `len(x) != L`, `len(x) < L`, etc. (where x is tuple and L is an expression that evaluates to an int literal type) * `x in y` or `x not in y` (where y is instance of list, set, frozenset, deque, tuple, dict, defaultdict, or OrderedDict) * `S in D` and `S not in D` (where S is a string literal and D is a TypedDict) * `isinstance(x, T)` (where T is a type or a tuple of types) diff --git a/packages/pyright-internal/src/analyzer/binder.ts b/packages/pyright-internal/src/analyzer/binder.ts index 016963ae4..b26b1b917 100644 --- a/packages/pyright-internal/src/analyzer/binder.ts +++ b/packages/pyright-internal/src/analyzer/binder.ts @@ -3100,9 +3100,32 @@ export class Binder extends ParseTreeWalker { // Look for "X is Y" or "X is not Y". // Look for X == or X != + // Look for len(X) == or len(X) != return isLeftNarrowing; } + // Look for len(X) < , len(X) <= , len(X) > , len(X) >= . + if ( + expression.rightExpression.nodeType === ParseNodeType.Number && + expression.rightExpression.isInteger + ) { + if ( + expression.operator === OperatorType.LessThan || + expression.operator === OperatorType.LessThanOrEqual || + expression.operator === OperatorType.GreaterThan || + expression.operator === OperatorType.GreaterThanOrEqual + ) { + const isLeftNarrowing = this._isNarrowingExpression( + expression.leftExpression, + expressionList, + filterForNeverNarrowing, + /* isComplexExpression */ true + ); + + return isLeftNarrowing; + } + } + // Look for " in Y" or " not in Y". if (expression.operator === OperatorType.In || expression.operator === OperatorType.NotIn) { if ( diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 75b5e27fa..01ad0ed4f 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -132,6 +132,12 @@ export function getTypeNarrowingCallback( testExpression.operator === OperatorType.Is || testExpression.operator === OperatorType.IsNot; const equalsOrNotEqualsOperator = testExpression.operator === OperatorType.Equals || testExpression.operator === OperatorType.NotEquals; + const comparisonOperator = + equalsOrNotEqualsOperator || + testExpression.operator === OperatorType.LessThan || + testExpression.operator === OperatorType.LessThanOrEqual || + testExpression.operator === OperatorType.GreaterThan || + testExpression.operator === OperatorType.GreaterThanOrEqual; if (isOrIsNotOperator || equalsOrNotEqualsOperator) { // Invert the "isPositiveTest" value if this is an "is not" operation. @@ -412,43 +418,6 @@ export function getTypeNarrowingCallback( } } - // Look for len(x) == or len(x) != - if ( - equalsOrNotEqualsOperator && - testExpression.leftExpression.nodeType === ParseNodeType.Call && - testExpression.leftExpression.arguments.length === 1 - ) { - const arg0Expr = testExpression.leftExpression.arguments[0].valueExpression; - - if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) { - const callTypeResult = evaluator.getTypeOfExpression( - testExpression.leftExpression.leftExpression, - EvaluatorFlags.CallBaseDefaults - ); - const callType = callTypeResult.type; - - if (isFunction(callType) && callType.details.fullName === 'builtins.len') { - const rightTypeResult = evaluator.getTypeOfExpression(testExpression.rightExpression); - const rightType = rightTypeResult.type; - - if ( - isClassInstance(rightType) && - typeof rightType.literalValue === 'number' && - rightType.literalValue >= 0 - ) { - const tupleLength = rightType.literalValue; - - return (type: Type) => { - return { - type: narrowTypeForTupleLength(evaluator, type, tupleLength, adjIsPositiveTest), - isIncomplete: !!callTypeResult.isIncomplete || !!rightTypeResult.isIncomplete, - }; - }; - } - } - } - } - // Look for X.Y == or X.Y != if ( equalsOrNotEqualsOperator && @@ -530,6 +499,70 @@ export function getTypeNarrowingCallback( } } + // Look for len(x) == , len(x) != , len(x) < , etc. + if ( + comparisonOperator && + testExpression.leftExpression.nodeType === ParseNodeType.Call && + testExpression.leftExpression.arguments.length === 1 + ) { + const arg0Expr = testExpression.leftExpression.arguments[0].valueExpression; + + if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) { + const callTypeResult = evaluator.getTypeOfExpression( + testExpression.leftExpression.leftExpression, + EvaluatorFlags.CallBaseDefaults + ); + const callType = callTypeResult.type; + + if (isFunction(callType) && callType.details.fullName === 'builtins.len') { + const rightTypeResult = evaluator.getTypeOfExpression(testExpression.rightExpression); + const rightType = rightTypeResult.type; + + if ( + isClassInstance(rightType) && + typeof rightType.literalValue === 'number' && + rightType.literalValue >= 0 + ) { + let tupleLength = rightType.literalValue; + + // We'll treat <, <= and == as positive tests with >=, > and != as + // their negative counterparts. + const isLessOrEqual = + testExpression.operator === OperatorType.Equals || + testExpression.operator === OperatorType.LessThan || + testExpression.operator === OperatorType.LessThanOrEqual; + + const adjIsPositiveTest = isLessOrEqual ? isPositiveTest : !isPositiveTest; + + // For <= (or its negative counterpart >), adjust the tuple length by 1. + if ( + testExpression.operator === OperatorType.LessThanOrEqual || + testExpression.operator === OperatorType.GreaterThan + ) { + tupleLength++; + } + + const isEqualityCheck = + testExpression.operator === OperatorType.Equals || + testExpression.operator === OperatorType.NotEquals; + + return (type: Type) => { + return { + type: narrowTypeForTupleLength( + evaluator, + type, + tupleLength, + adjIsPositiveTest, + !isEqualityCheck + ), + isIncomplete: !!callTypeResult.isIncomplete || !!rightTypeResult.isIncomplete, + }; + }; + } + } + } + } + if (testExpression.operator === OperatorType.In || testExpression.operator === OperatorType.NotIn) { // Look for "x in y" or "x not in y" where y is one of several built-in types. if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) { @@ -1897,7 +1930,8 @@ function narrowTypeForTupleLength( evaluator: TypeEvaluator, referenceType: Type, lengthValue: number, - isPositiveTest: boolean + isPositiveTest: boolean, + isLessThanCheck: boolean ) { return mapSubtypes(referenceType, (subtype) => { const concreteSubtype = evaluator.makeTopLevelTypeVarsConcrete(subtype); @@ -1918,7 +1952,10 @@ function narrowTypeForTupleLength( // If the tuple contains no unbounded elements, then we know its length exactly. if (!concreteSubtype.tupleTypeArguments.some((typeArg) => typeArg.isUnbounded)) { - const tupleLengthMatches = concreteSubtype.tupleTypeArguments.length === lengthValue; + const tupleLengthMatches = isLessThanCheck + ? concreteSubtype.tupleTypeArguments.length < lengthValue + : concreteSubtype.tupleTypeArguments.length === lengthValue; + return tupleLengthMatches === isPositiveTest ? subtype : undefined; } @@ -1926,31 +1963,68 @@ function narrowTypeForTupleLength( // necessary to match the lengthValue. const elementsToAdd = lengthValue - concreteSubtype.tupleTypeArguments.length + 1; - // If the specified length is smaller than the minimum length of this tuple, - // we can rule it out for a positive test. - if (elementsToAdd < 0) { - return isPositiveTest ? undefined : subtype; + if (!isLessThanCheck) { + // If the specified length is smaller than the minimum length of this tuple, + // we can rule it out for a positive test and rule it in for a negative test. + if (elementsToAdd < 0) { + return isPositiveTest ? undefined : subtype; + } + + if (!isPositiveTest) { + return subtype; + } + + return expandUnboundedTupleElement(concreteSubtype, elementsToAdd, /* keepUnbounded */ false); } - if (!isPositiveTest) { + // Place an upper limit on the number of union subtypes we + // will expand the tuple to. + const maxTupleUnionExpansion = 32; + if (elementsToAdd > maxTupleUnionExpansion) { return subtype; } - const tupleTypeArgs: TupleTypeArgument[] = []; - concreteSubtype.tupleTypeArguments.forEach((typeArg) => { - if (!typeArg.isUnbounded) { - tupleTypeArgs.push(typeArg); - } else { - for (let i = 0; i < elementsToAdd; i++) { - tupleTypeArgs.push({ isUnbounded: false, type: typeArg.type }); - } + if (isPositiveTest) { + if (elementsToAdd < 1) { + return undefined; } - }); - return specializeTupleClass(concreteSubtype, tupleTypeArgs); + const typesToCombine: Type[] = []; + + for (let i = 0; i < elementsToAdd; i++) { + typesToCombine.push(expandUnboundedTupleElement(concreteSubtype, i, /* keepUnbounded */ false)); + } + + return combineTypes(typesToCombine); + } + + return expandUnboundedTupleElement(concreteSubtype, elementsToAdd, /* keepUnbounded */ true); }); } +// Expands a tuple type that contains an unbounded element to include +// multiple bounded elements of that same type in place of (or in addition +// to) the unbounded element. +function expandUnboundedTupleElement(tupleType: ClassType, elementsToAdd: number, keepUnbounded: boolean) { + const tupleTypeArgs: TupleTypeArgument[] = []; + + tupleType.tupleTypeArguments!.forEach((typeArg) => { + if (!typeArg.isUnbounded) { + tupleTypeArgs.push(typeArg); + } else { + for (let i = 0; i < elementsToAdd; i++) { + tupleTypeArgs.push({ isUnbounded: false, type: typeArg.type }); + } + + if (keepUnbounded) { + tupleTypeArgs.push(typeArg); + } + } + }); + + return specializeTupleClass(tupleType, tupleTypeArgs); +} + // Attempts to narrow a type (make it more constrained) based on an "in" binary operator. function narrowTypeForContainerType( evaluator: TypeEvaluator, diff --git a/packages/pyright-internal/src/tests/samples/loop16.py b/packages/pyright-internal/src/tests/samples/loop16.py index d69790861..4ba894e85 100644 --- a/packages/pyright-internal/src/tests/samples/loop16.py +++ b/packages/pyright-internal/src/tests/samples/loop16.py @@ -273,16 +273,16 @@ def get_ipv4(): continue elif ip1 == 13 and ip2 == 107 and ip3 == 6 and ip4 == 152: continue - elif ip1 == 13 and ip2 == 107 and ip3 == 18 and ip4 == 10: - continue - elif ip1 == 13 and ip2 == 107 and ip3 == 128 and ip4 == 0: - continue - elif ip1 == 23 and ip2 == 103 and ip3 == 160 and ip4 == 0: - continue - elif ip1 == 40 and ip2 == 96 and ip3 == 0 and ip4 == 0: - continue - elif ip1 == 40 and ip2 == 104 and ip3 == 0 and ip4 == 0: - continue + # elif ip1 == 13 and ip2 == 107 and ip3 == 18 and ip4 == 10: + # continue + # elif ip1 == 13 and ip2 == 107 and ip3 == 128 and ip4 == 0: + # continue + # elif ip1 == 23 and ip2 == 103 and ip3 == 160 and ip4 == 0: + # continue + # elif ip1 == 40 and ip2 == 96 and ip3 == 0 and ip4 == 0: + # continue + # elif ip1 == 40 and ip2 == 104 and ip3 == 0 and ip4 == 0: + # continue # elif ip1 == 52 and ip2 == 96 and ip3 == 0 and ip4 == 0: # continue # elif ip1 == 131 and ip2 == 253 and ip3 == 33 and ip4 == 215: diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingTupleLength1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingTupleLength1.py index 922c39d38..d5d5be82d 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingTupleLength1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingTupleLength1.py @@ -49,11 +49,13 @@ def func4(val: _T1 | _T2) -> _T1 | _T2: def func5( - val: tuple[int, ...] - | tuple[str] - | tuple[str, str, str] - | tuple[int, *tuple[str, ...], str] - | tuple[int, *tuple[float, ...]], + val: ( + tuple[int, ...] + | tuple[str] + | tuple[str, str, str] + | tuple[int, *tuple[str, ...], str] + | tuple[int, *tuple[float, ...]] + ), length: Literal[2], ): if len(val) == length: @@ -65,3 +67,66 @@ def func5( val, expected_text="tuple[int, ...] | tuple[str] | tuple[str, str, str] | tuple[int, *tuple[str, ...], str] | tuple[int, *tuple[float, ...]]", ) + + +def func10(t: tuple[()] | tuple[int] | tuple[int, int] | tuple[int, int, int]): + if len(t) >= 2: + reveal_type(t, expected_text="tuple[int, int] | tuple[int, int, int]") + else: + reveal_type(t, expected_text="tuple[()] | tuple[int]") + + +def func11(t: tuple[()] | tuple[int] | tuple[int, int] | tuple[int, int, int]): + if len(t) > 1: + reveal_type(t, expected_text="tuple[int, int] | tuple[int, int, int]") + else: + reveal_type(t, expected_text="tuple[()] | tuple[int]") + + +def func12(t: tuple[()] | tuple[int] | tuple[int, int]): + if len(t) >= 0: + reveal_type(t, expected_text="tuple[()] | tuple[int] | tuple[int, int]") + else: + reveal_type(t, expected_text="Never") + + +def func20(t: tuple[int, ...]): + if len(t) >= 2: + reveal_type(t, expected_text="tuple[int, int, *tuple[int, ...]]") + else: + reveal_type(t, expected_text="tuple[()] | tuple[int]") + + +def func21(t: tuple[int, ...]): + if len(t) > 0: + reveal_type(t, expected_text="tuple[int, *tuple[int, ...]]") + else: + reveal_type(t, expected_text="tuple[()]") + + +def func22(t: tuple[str, *tuple[int, ...], str]): + if len(t) < 3: + reveal_type(t, expected_text="tuple[str, str]") + else: + reveal_type(t, expected_text="tuple[str, int, *tuple[int, ...], str]") + + +def func23(t: tuple[str, *tuple[int, ...], str]): + if len(t) <= 3: + reveal_type(t, expected_text="tuple[str, str] | tuple[str, int, str]") + else: + reveal_type(t, expected_text="tuple[str, int, int, *tuple[int, ...], str]") + + +def func24(t: tuple[str, *tuple[int, ...], str]): + if len(t) <= 34: + reveal_type(t, expected_text="tuple[str, *tuple[int, ...], str]") + else: + reveal_type(t, expected_text="tuple[str, *tuple[int, ...], str]") + + +def func25(t: tuple[str, *tuple[int, ...], str]): + if len(t) < 2: + reveal_type(t, expected_text="Never") + else: + reveal_type(t, expected_text="tuple[str, *tuple[int, ...], str]")