Expanded support for len(x) == L type guard pattern (where x is a tuple) to support <, <=, > and >= comparisons as well. This addresses #7655. (#7657)

This commit is contained in:
Eric Traut 2024-04-09 23:13:17 -07:00 committed by GitHub
parent 989ee29c0b
commit 98d1523077
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 232 additions and 70 deletions

View File

@ -71,7 +71,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I)
* `x[I] is B` and `x[I] is not B` (where I is a literal expression, B is a `bool` or enum literal, and x is a known-length tuple that is distinguished by the index indicated by I)
* `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I)
* `len(x) == L` and `len(x) != L` (where x is tuple and L is an expression that evaluates to an int literal type)
* `len(x) == L`, `len(x) != L`, `len(x) < L`, etc. (where x is tuple and L is an expression that evaluates to an int literal type)
* `x in y` or `x not in y` (where y is instance of list, set, frozenset, deque, tuple, dict, defaultdict, or OrderedDict)
* `S in D` and `S not in D` (where S is a string literal and D is a TypedDict)
* `isinstance(x, T)` (where T is a type or a tuple of types)

View File

@ -3100,9 +3100,32 @@ export class Binder extends ParseTreeWalker {
// Look for "X is Y" or "X is not Y".
// Look for X == <literal> or X != <literal>
// Look for len(X) == <literal> or len(X) != <literal>
return isLeftNarrowing;
}
// Look for len(X) < <literal>, len(X) <= <literal>, len(X) > <literal>, len(X) >= <literal>.
if (
expression.rightExpression.nodeType === ParseNodeType.Number &&
expression.rightExpression.isInteger
) {
if (
expression.operator === OperatorType.LessThan ||
expression.operator === OperatorType.LessThanOrEqual ||
expression.operator === OperatorType.GreaterThan ||
expression.operator === OperatorType.GreaterThanOrEqual
) {
const isLeftNarrowing = this._isNarrowingExpression(
expression.leftExpression,
expressionList,
filterForNeverNarrowing,
/* isComplexExpression */ true
);
return isLeftNarrowing;
}
}
// Look for "<string> in Y" or "<string> not in Y".
if (expression.operator === OperatorType.In || expression.operator === OperatorType.NotIn) {
if (

View File

@ -132,6 +132,12 @@ export function getTypeNarrowingCallback(
testExpression.operator === OperatorType.Is || testExpression.operator === OperatorType.IsNot;
const equalsOrNotEqualsOperator =
testExpression.operator === OperatorType.Equals || testExpression.operator === OperatorType.NotEquals;
const comparisonOperator =
equalsOrNotEqualsOperator ||
testExpression.operator === OperatorType.LessThan ||
testExpression.operator === OperatorType.LessThanOrEqual ||
testExpression.operator === OperatorType.GreaterThan ||
testExpression.operator === OperatorType.GreaterThanOrEqual;
if (isOrIsNotOperator || equalsOrNotEqualsOperator) {
// Invert the "isPositiveTest" value if this is an "is not" operation.
@ -412,43 +418,6 @@ export function getTypeNarrowingCallback(
}
}
// Look for len(x) == <literal> or len(x) != <literal>
if (
equalsOrNotEqualsOperator &&
testExpression.leftExpression.nodeType === ParseNodeType.Call &&
testExpression.leftExpression.arguments.length === 1
) {
const arg0Expr = testExpression.leftExpression.arguments[0].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
const callTypeResult = evaluator.getTypeOfExpression(
testExpression.leftExpression.leftExpression,
EvaluatorFlags.CallBaseDefaults
);
const callType = callTypeResult.type;
if (isFunction(callType) && callType.details.fullName === 'builtins.len') {
const rightTypeResult = evaluator.getTypeOfExpression(testExpression.rightExpression);
const rightType = rightTypeResult.type;
if (
isClassInstance(rightType) &&
typeof rightType.literalValue === 'number' &&
rightType.literalValue >= 0
) {
const tupleLength = rightType.literalValue;
return (type: Type) => {
return {
type: narrowTypeForTupleLength(evaluator, type, tupleLength, adjIsPositiveTest),
isIncomplete: !!callTypeResult.isIncomplete || !!rightTypeResult.isIncomplete,
};
};
}
}
}
}
// Look for X.Y == <literal> or X.Y != <literal>
if (
equalsOrNotEqualsOperator &&
@ -530,6 +499,70 @@ export function getTypeNarrowingCallback(
}
}
// Look for len(x) == <literal>, len(x) != <literal>, len(x) < <literal>, etc.
if (
comparisonOperator &&
testExpression.leftExpression.nodeType === ParseNodeType.Call &&
testExpression.leftExpression.arguments.length === 1
) {
const arg0Expr = testExpression.leftExpression.arguments[0].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
const callTypeResult = evaluator.getTypeOfExpression(
testExpression.leftExpression.leftExpression,
EvaluatorFlags.CallBaseDefaults
);
const callType = callTypeResult.type;
if (isFunction(callType) && callType.details.fullName === 'builtins.len') {
const rightTypeResult = evaluator.getTypeOfExpression(testExpression.rightExpression);
const rightType = rightTypeResult.type;
if (
isClassInstance(rightType) &&
typeof rightType.literalValue === 'number' &&
rightType.literalValue >= 0
) {
let tupleLength = rightType.literalValue;
// We'll treat <, <= and == as positive tests with >=, > and != as
// their negative counterparts.
const isLessOrEqual =
testExpression.operator === OperatorType.Equals ||
testExpression.operator === OperatorType.LessThan ||
testExpression.operator === OperatorType.LessThanOrEqual;
const adjIsPositiveTest = isLessOrEqual ? isPositiveTest : !isPositiveTest;
// For <= (or its negative counterpart >), adjust the tuple length by 1.
if (
testExpression.operator === OperatorType.LessThanOrEqual ||
testExpression.operator === OperatorType.GreaterThan
) {
tupleLength++;
}
const isEqualityCheck =
testExpression.operator === OperatorType.Equals ||
testExpression.operator === OperatorType.NotEquals;
return (type: Type) => {
return {
type: narrowTypeForTupleLength(
evaluator,
type,
tupleLength,
adjIsPositiveTest,
!isEqualityCheck
),
isIncomplete: !!callTypeResult.isIncomplete || !!rightTypeResult.isIncomplete,
};
};
}
}
}
}
if (testExpression.operator === OperatorType.In || testExpression.operator === OperatorType.NotIn) {
// Look for "x in y" or "x not in y" where y is one of several built-in types.
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
@ -1897,7 +1930,8 @@ function narrowTypeForTupleLength(
evaluator: TypeEvaluator,
referenceType: Type,
lengthValue: number,
isPositiveTest: boolean
isPositiveTest: boolean,
isLessThanCheck: boolean
) {
return mapSubtypes(referenceType, (subtype) => {
const concreteSubtype = evaluator.makeTopLevelTypeVarsConcrete(subtype);
@ -1918,7 +1952,10 @@ function narrowTypeForTupleLength(
// If the tuple contains no unbounded elements, then we know its length exactly.
if (!concreteSubtype.tupleTypeArguments.some((typeArg) => typeArg.isUnbounded)) {
const tupleLengthMatches = concreteSubtype.tupleTypeArguments.length === lengthValue;
const tupleLengthMatches = isLessThanCheck
? concreteSubtype.tupleTypeArguments.length < lengthValue
: concreteSubtype.tupleTypeArguments.length === lengthValue;
return tupleLengthMatches === isPositiveTest ? subtype : undefined;
}
@ -1926,31 +1963,68 @@ function narrowTypeForTupleLength(
// necessary to match the lengthValue.
const elementsToAdd = lengthValue - concreteSubtype.tupleTypeArguments.length + 1;
// If the specified length is smaller than the minimum length of this tuple,
// we can rule it out for a positive test.
if (elementsToAdd < 0) {
return isPositiveTest ? undefined : subtype;
if (!isLessThanCheck) {
// If the specified length is smaller than the minimum length of this tuple,
// we can rule it out for a positive test and rule it in for a negative test.
if (elementsToAdd < 0) {
return isPositiveTest ? undefined : subtype;
}
if (!isPositiveTest) {
return subtype;
}
return expandUnboundedTupleElement(concreteSubtype, elementsToAdd, /* keepUnbounded */ false);
}
if (!isPositiveTest) {
// Place an upper limit on the number of union subtypes we
// will expand the tuple to.
const maxTupleUnionExpansion = 32;
if (elementsToAdd > maxTupleUnionExpansion) {
return subtype;
}
const tupleTypeArgs: TupleTypeArgument[] = [];
concreteSubtype.tupleTypeArguments.forEach((typeArg) => {
if (!typeArg.isUnbounded) {
tupleTypeArgs.push(typeArg);
} else {
for (let i = 0; i < elementsToAdd; i++) {
tupleTypeArgs.push({ isUnbounded: false, type: typeArg.type });
}
if (isPositiveTest) {
if (elementsToAdd < 1) {
return undefined;
}
});
return specializeTupleClass(concreteSubtype, tupleTypeArgs);
const typesToCombine: Type[] = [];
for (let i = 0; i < elementsToAdd; i++) {
typesToCombine.push(expandUnboundedTupleElement(concreteSubtype, i, /* keepUnbounded */ false));
}
return combineTypes(typesToCombine);
}
return expandUnboundedTupleElement(concreteSubtype, elementsToAdd, /* keepUnbounded */ true);
});
}
// Expands a tuple type that contains an unbounded element to include
// multiple bounded elements of that same type in place of (or in addition
// to) the unbounded element.
function expandUnboundedTupleElement(tupleType: ClassType, elementsToAdd: number, keepUnbounded: boolean) {
const tupleTypeArgs: TupleTypeArgument[] = [];
tupleType.tupleTypeArguments!.forEach((typeArg) => {
if (!typeArg.isUnbounded) {
tupleTypeArgs.push(typeArg);
} else {
for (let i = 0; i < elementsToAdd; i++) {
tupleTypeArgs.push({ isUnbounded: false, type: typeArg.type });
}
if (keepUnbounded) {
tupleTypeArgs.push(typeArg);
}
}
});
return specializeTupleClass(tupleType, tupleTypeArgs);
}
// Attempts to narrow a type (make it more constrained) based on an "in" binary operator.
function narrowTypeForContainerType(
evaluator: TypeEvaluator,

View File

@ -273,16 +273,16 @@ def get_ipv4():
continue
elif ip1 == 13 and ip2 == 107 and ip3 == 6 and ip4 == 152:
continue
elif ip1 == 13 and ip2 == 107 and ip3 == 18 and ip4 == 10:
continue
elif ip1 == 13 and ip2 == 107 and ip3 == 128 and ip4 == 0:
continue
elif ip1 == 23 and ip2 == 103 and ip3 == 160 and ip4 == 0:
continue
elif ip1 == 40 and ip2 == 96 and ip3 == 0 and ip4 == 0:
continue
elif ip1 == 40 and ip2 == 104 and ip3 == 0 and ip4 == 0:
continue
# elif ip1 == 13 and ip2 == 107 and ip3 == 18 and ip4 == 10:
# continue
# elif ip1 == 13 and ip2 == 107 and ip3 == 128 and ip4 == 0:
# continue
# elif ip1 == 23 and ip2 == 103 and ip3 == 160 and ip4 == 0:
# continue
# elif ip1 == 40 and ip2 == 96 and ip3 == 0 and ip4 == 0:
# continue
# elif ip1 == 40 and ip2 == 104 and ip3 == 0 and ip4 == 0:
# continue
# elif ip1 == 52 and ip2 == 96 and ip3 == 0 and ip4 == 0:
# continue
# elif ip1 == 131 and ip2 == 253 and ip3 == 33 and ip4 == 215:

View File

@ -49,11 +49,13 @@ def func4(val: _T1 | _T2) -> _T1 | _T2:
def func5(
val: tuple[int, ...]
| tuple[str]
| tuple[str, str, str]
| tuple[int, *tuple[str, ...], str]
| tuple[int, *tuple[float, ...]],
val: (
tuple[int, ...]
| tuple[str]
| tuple[str, str, str]
| tuple[int, *tuple[str, ...], str]
| tuple[int, *tuple[float, ...]]
),
length: Literal[2],
):
if len(val) == length:
@ -65,3 +67,66 @@ def func5(
val,
expected_text="tuple[int, ...] | tuple[str] | tuple[str, str, str] | tuple[int, *tuple[str, ...], str] | tuple[int, *tuple[float, ...]]",
)
def func10(t: tuple[()] | tuple[int] | tuple[int, int] | tuple[int, int, int]):
if len(t) >= 2:
reveal_type(t, expected_text="tuple[int, int] | tuple[int, int, int]")
else:
reveal_type(t, expected_text="tuple[()] | tuple[int]")
def func11(t: tuple[()] | tuple[int] | tuple[int, int] | tuple[int, int, int]):
if len(t) > 1:
reveal_type(t, expected_text="tuple[int, int] | tuple[int, int, int]")
else:
reveal_type(t, expected_text="tuple[()] | tuple[int]")
def func12(t: tuple[()] | tuple[int] | tuple[int, int]):
if len(t) >= 0:
reveal_type(t, expected_text="tuple[()] | tuple[int] | tuple[int, int]")
else:
reveal_type(t, expected_text="Never")
def func20(t: tuple[int, ...]):
if len(t) >= 2:
reveal_type(t, expected_text="tuple[int, int, *tuple[int, ...]]")
else:
reveal_type(t, expected_text="tuple[()] | tuple[int]")
def func21(t: tuple[int, ...]):
if len(t) > 0:
reveal_type(t, expected_text="tuple[int, *tuple[int, ...]]")
else:
reveal_type(t, expected_text="tuple[()]")
def func22(t: tuple[str, *tuple[int, ...], str]):
if len(t) < 3:
reveal_type(t, expected_text="tuple[str, str]")
else:
reveal_type(t, expected_text="tuple[str, int, *tuple[int, ...], str]")
def func23(t: tuple[str, *tuple[int, ...], str]):
if len(t) <= 3:
reveal_type(t, expected_text="tuple[str, str] | tuple[str, int, str]")
else:
reveal_type(t, expected_text="tuple[str, int, int, *tuple[int, ...], str]")
def func24(t: tuple[str, *tuple[int, ...], str]):
if len(t) <= 34:
reveal_type(t, expected_text="tuple[str, *tuple[int, ...], str]")
else:
reveal_type(t, expected_text="tuple[str, *tuple[int, ...], str]")
def func25(t: tuple[str, *tuple[int, ...], str]):
if len(t) < 2:
reveal_type(t, expected_text="Never")
else:
reveal_type(t, expected_text="tuple[str, *tuple[int, ...], str]")