Fixed bug in type narrowing logic for "in" operator. It was not properly handling the case where the reference (LHS) type was a subtype of the container's element type.

This commit is contained in:
Eric Traut 2022-07-17 23:29:00 -07:00
parent b02711fd94
commit 2dae8d9cb9
2 changed files with 56 additions and 16 deletions

View File

@ -398,7 +398,7 @@ export function getTypeNarrowingCallback(
if (adjIsPositiveTest) {
return (type: Type) => {
return narrowTypeForContains(evaluator, type, rightType);
return narrowTypeForContainerType(evaluator, type, rightType);
};
}
}
@ -1422,26 +1422,25 @@ function narrowTypeForTupleLength(
});
}
// Attempts to narrow a type (make it more constrained) based on an "in" or
// "not in" binary expression.
function narrowTypeForContains(evaluator: TypeEvaluator, referenceType: Type, containerType: Type) {
// We support contains narrowing only for certain built-in types that have been specialized.
if (!isClassInstance(containerType) || !ClassType.isBuiltIn(containerType)) {
// Attempts to narrow a type (make it more constrained) based on an "in" binary operator.
function narrowTypeForContainerType(evaluator: TypeEvaluator, referenceType: Type, containerType: Type) {
const elementType = getElementTypeForContainerNarrowing(containerType);
if (!elementType) {
return referenceType;
}
const builtInName = containerType.details.name;
return narrowTypeForContainerElementType(evaluator, referenceType, elementType);
}
if (
!['list', 'set', 'frozenset', 'deque', 'tuple', 'dict', 'defaultdict', 'OrderedDict'].some(
(name) => name === builtInName
)
) {
return referenceType;
export function getElementTypeForContainerNarrowing(containerType: Type) {
// We support contains narrowing only for certain built-in types that have been specialized.
const supportedContainers = ['list', 'set', 'frozenset', 'deque', 'tuple', 'dict', 'defaultdict', 'OrderedDict'];
if (!isClassInstance(containerType) || !ClassType.isBuiltIn(containerType, supportedContainers)) {
return undefined;
}
if (!containerType.typeArguments || containerType.typeArguments.length < 1) {
return referenceType;
return undefined;
}
let elementType = containerType.typeArguments[0];
@ -1449,10 +1448,18 @@ function narrowTypeForContains(evaluator: TypeEvaluator, referenceType: Type, co
elementType = combineTypes(containerType.tupleTypeArguments.map((t) => t.type));
}
return elementType;
}
export function narrowTypeForContainerElementType(evaluator: TypeEvaluator, referenceType: Type, elementType: Type) {
let canNarrow = true;
const elementTypeWithoutLiteral = stripLiteralValue(elementType);
const narrowedType = mapSubtypes(referenceType, (referenceSubtype) => {
// Look for cases where one or more of the reference subtypes are
// supertypes of the element types. For example, if the element type
// is "int | str" and the reference type is "float | bytes", we can
// narrow the reference type to "float" because it is a supertype of "int".
const narrowedSupertypes = mapSubtypes(referenceType, (referenceSubtype) => {
const concreteReferenceType = evaluator.makeTopLevelTypeVarsConcrete(referenceSubtype);
if (isAnyOrUnknown(concreteReferenceType)) {
@ -1482,7 +1489,27 @@ function narrowTypeForContains(evaluator: TypeEvaluator, referenceType: Type, co
return undefined;
});
return canNarrow ? narrowedType : referenceType;
// Look for cases where one or more of the reference subtypes are
// subtypes of the element types. For example, if the element type
// is "int | str" and the reference type is "object", we can
// narrow the reference type to "int | str" because they are both
// subtypes of "object".
const narrowedSubtypes = mapSubtypes(elementType, (elementSubtype) => {
const concreteElementType = evaluator.makeTopLevelTypeVarsConcrete(elementSubtype);
if (isAnyOrUnknown(concreteElementType)) {
canNarrow = false;
return referenceType;
}
if (evaluator.assignType(referenceType, concreteElementType)) {
return concreteElementType;
}
return undefined;
});
return canNarrow ? combineTypes([narrowedSupertypes, narrowedSubtypes]) : referenceType;
}
// Attempts to narrow a type based on whether it is a TypedDict with

View File

@ -102,3 +102,16 @@ def func6(x: type):
reveal_type(x, expected_text="type")
else:
reveal_type(x, expected_text="type")
def func7(x: object | bytes, y: str, z: int):
if x in (y, z):
reveal_type(x, expected_text="str | int")
else:
reveal_type(x, expected_text="object | bytes")
reveal_type(x, expected_text="str | int | object | bytes")
def func8(x: object):
if x in ("a", "b", 2, None):
reveal_type(x, expected_text="Literal['a', 'b', 2] | None")