Fixed bug in type narrowing logic for expressions of the form "X not in Y".

This commit is contained in:
Eric Traut 2022-05-29 23:54:15 -07:00
parent 469cc162f9
commit 255497446a
4 changed files with 35 additions and 16 deletions

View File

@ -178,7 +178,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I)
* `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I)
* `len(x) == L` and `len(x) != L` (where x is tuple and L is a literal integer)
* `x in y` (where y is instance of list, set, frozenset, deque, or tuple)
* `x in y` or `x not in y` (where y is instance of list, set, frozenset, deque, or tuple)
* `S in D` and `S not in D` (where S is a string literal and D is a TypedDict)
* `isinstance(x, T)` (where T is a type or a tuple of types)
* `issubclass(x, T)` (where T is a type or a tuple of types)

View File

@ -2814,8 +2814,8 @@ export class Binder extends ParseTreeWalker {
}
}
// Look for "X in Y".
if (expression.operator === OperatorType.In) {
// Look for "X in Y" or "X not in Y".
if (expression.operator === OperatorType.In || expression.operator === OperatorType.NotIn) {
return this._isNarrowingExpression(
expression.leftExpression,
expressionList,

View File

@ -389,17 +389,20 @@ export function getTypeNarrowingCallback(
}
}
if (testExpression.operator === OperatorType.In) {
// Look for "x in y" where y is one of several built-in types.
if (isPositiveTest && ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
return (type: Type) => {
return narrowTypeForContains(evaluator, type, rightType);
};
}
}
if (testExpression.operator === OperatorType.In || testExpression.operator === OperatorType.NotIn) {
// Look for "x in y" or "x not in y" where y is one of several built-in types.
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
const adjIsPositiveTest =
testExpression.operator === OperatorType.In ? isPositiveTest : !isPositiveTest;
if (adjIsPositiveTest) {
return (type: Type) => {
return narrowTypeForContains(evaluator, type, rightType);
};
}
}
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) {
// Look for <string literal> in y where y is a union that contains
// one or more TypedDicts.
@ -1342,16 +1345,18 @@ function narrowTypeForContains(evaluator: TypeEvaluator, referenceType: Type, co
const elementTypeWithoutLiteral = stripLiteralValue(elementType);
const narrowedType = mapSubtypes(referenceType, (referenceSubtype) => {
if (isAnyOrUnknown(referenceSubtype)) {
const concreteReferenceType = evaluator.makeTopLevelTypeVarsConcrete(referenceSubtype);
if (isAnyOrUnknown(concreteReferenceType)) {
canNarrow = false;
return referenceSubtype;
}
if (evaluator.assignType(elementType, referenceSubtype)) {
if (evaluator.assignType(elementType, concreteReferenceType)) {
return referenceSubtype;
}
if (evaluator.assignType(elementTypeWithoutLiteral, referenceSubtype)) {
if (evaluator.assignType(elementTypeWithoutLiteral, concreteReferenceType)) {
return mapSubtypes(elementType, (elementSubtype) => {
if (isClassInstance(elementSubtype) && isSameWithoutLiteralValue(referenceSubtype, elementSubtype)) {
return elementSubtype;

View File

@ -69,3 +69,17 @@ def func2(a: Literal[1, 2, 3]):
reveal_type(a, expected_text="Literal[1, 2]")
else:
reveal_type(a, expected_text="Literal[1, 2, 3]")
def func3(val: str | None, container: frozenset[str]):
if val in container:
reveal_type(val, expected_text="str")
else:
reveal_type(val, expected_text="str | None")
def func4(val: str | None, container: list[str]):
if val not in container:
reveal_type(val, expected_text="str | None")
else:
reveal_type(val, expected_text="str")