Changed TypeIs to use the same logic as isinstance type narrowing logic for consistency. This addresses #7767, #7760, and #7647. (#7777)

This commit is contained in:
Eric Traut 2024-04-25 22:19:04 -07:00 committed by GitHub
parent c2203b9aa7
commit 96d0145763
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 78 additions and 133 deletions

View File

@ -67,7 +67,6 @@ import {
AssignTypeFlags,
ClassMember,
computeMroLinearization,
containsAnyOrUnknown,
convertToInstance,
convertToInstantiable,
doForEachSubtype,
@ -640,23 +639,6 @@ export function getTypeNarrowingCallback(
if (classTypeList) {
return (type: Type) => {
const narrowedType = narrowTypeForIsInstance(
evaluator,
type,
classTypeList,
isInstanceCheck,
isPositiveTest,
/* allowIntersections */ false,
testExpression
);
if (!isNever(narrowedType)) {
return {
type: narrowedType,
isIncomplete,
};
}
// Try again with intersection types allowed.
return {
type: narrowTypeForIsInstance(
evaluator,
@ -664,7 +646,6 @@ export function getTypeNarrowingCallback(
classTypeList,
isInstanceCheck,
isPositiveTest,
/* allowIntersections */ true,
testExpression
),
isIncomplete,
@ -793,7 +774,8 @@ export function getTypeNarrowingCallback(
type,
typeGuardType,
isPositiveTest,
isStrictTypeGuard
isStrictTypeGuard,
testExpression
),
isIncomplete,
};
@ -986,7 +968,8 @@ function narrowTypeForUserDefinedTypeGuard(
type: Type,
typeGuardType: Type,
isPositiveTest: boolean,
isStrictTypeGuard: boolean
isStrictTypeGuard: boolean,
errorNode: ExpressionNode
): Type {
// For non-strict type guards, always narrow to the typeGuardType
// in the positive case and don't narrow in the negative case.
@ -994,101 +977,12 @@ function narrowTypeForUserDefinedTypeGuard(
return isPositiveTest ? typeGuardType : type;
}
// For strict type guards, narrow the type.
return mapSubtypes(type, (subtype) => {
if (isPositiveTest) {
// In the positive case, we need to compute the intersection of "type"
// and "typeGuardType". If we assume "type" is a union of (A1 | A2 | ... | An)
// and "typeGuardType" is a union of (B1 | B2 | ... | Bn), then the intersection
// of the two is the union of the intersection of all combinations of A's and B's.
// For each pair, if A and B are the same type, the intersection is that type.
// If A is a proper subtype of B or vice versa, the intersection is the narrower of
// the two. If A and B have no commonality, the intersection is Never.
// If A and B are not the same type but A and B are mutually "subtypes" of each
// other, that means they don't follow the normal subtyping rules. In this case,
// we'll try to pick the one that has the most information (i.e. doesn't contain
// an Any or Unknown).
return mapSubtypes(typeGuardType, (typeGuardSubtype) => {
if (isTypeSame(subtype, typeGuardSubtype, { ignorePseudoGeneric: true, treatAnySameAsUnknown: true })) {
return subtype;
}
const isSubtype = evaluator.assignType(typeGuardSubtype, subtype);
const isSupertype = evaluator.assignType(subtype, typeGuardSubtype);
if (isSubtype) {
if (!isSupertype) {
return subtype;
}
// It's both a subtype and a supertype and it's not the same type.
// That means it's some combination of types that don't follow subtyping
// rules. Try to retain as much information as possible. If one of the
// two types contains an Any and the other does not, use the one that doesn't.
return containsAnyOrUnknown(typeGuardSubtype, /* recurse */ true) ? subtype : typeGuardSubtype;
}
if (isSupertype) {
return typeGuardSubtype;
}
// The types have nothing in common.
return undefined;
});
} else {
// In the negative case, we need to compute the intersection of "type" and
// the negation of "typeGuardType". The type system doesn't have support for
// type negations, so the actual result will necessarily be broader than the
// theoretically-correct result.
// If "type" is a union of (A1 | A2 | ... | An) and "typeGuardType" is a union
// of (B1 | B2 | ... | Bn), then the intersection of "type" and !"typeGuardType"
// is (A1 & !B1 & !B2 & ... & !Bn) | (A2 & !B1 & !B2 & ... & !Bn) | ... |.
// This means we can eliminate an A only if we can show that it doesn't
// share any commonality with any of the B's.
let canBeEliminated = false;
if (!isAnyOrUnknown(subtype)) {
doForEachSubtype(typeGuardType, (typeGuardSubtype) => {
if (
isTypeSame(subtype, typeGuardSubtype, {
ignorePseudoGeneric: true,
treatAnySameAsUnknown: true,
})
) {
canBeEliminated = true;
} else {
const isSubtype = evaluator.assignType(typeGuardSubtype, subtype);
const isSupertype = evaluator.assignType(subtype, typeGuardSubtype);
if (isSubtype && !isAnyOrUnknown(typeGuardSubtype)) {
if (!isSupertype) {
canBeEliminated = true;
} else {
// In this case, there is not a clear subtype relationship between
// "subtype" and "typeGuardSubtype". We'll use a heuristic to produce
// a result that is not unexpected. If the typeGuardSubtype is a gradual
// type (contains Any) but the subtype does not, we'll eliminate the
// subtype. For example, if the typeGuardSubtype is "list[Any]" and
// the subtype is "list[str]", we'll eliminate "list[str]". If the types
// are reversed, we don't want to eliminate "list[Any]".
const subtypeContainsAny = containsAnyOrUnknown(subtype, /* recurse */ true);
const typeGuardSubtypeContainsAny = containsAnyOrUnknown(
typeGuardSubtype,
/* recurse */ true
);
if (!subtypeContainsAny && typeGuardSubtypeContainsAny) {
canBeEliminated = true;
}
}
}
}
});
}
return canBeEliminated ? undefined : subtype;
}
const filterTypes: Type[] = [];
doForEachSubtype(typeGuardType, (typeGuardSubtype) => {
filterTypes.push(convertToInstantiable(typeGuardSubtype));
});
return narrowTypeForIsInstance(evaluator, type, filterTypes, /* isInstanceCheck */ true, isPositiveTest, errorNode);
}
// Narrow the type based on whether the subtype can be true or false.
@ -1298,7 +1192,7 @@ export function isIsinstanceFilterSuperclass(
concreteFilterType: ClassType,
isInstanceCheck: boolean
) {
if (isTypeVar(filterType)) {
if (isTypeVar(filterType) || concreteFilterType.literalValue !== undefined) {
return isTypeSame(convertToInstance(filterType), varType);
}
@ -1351,15 +1245,50 @@ export function isIsinstanceFilterSubclass(
return false;
}
function narrowTypeForIsInstance(
evaluator: TypeEvaluator,
type: Type,
filterTypes: Type[],
isInstanceCheck: boolean,
isPositiveTest: boolean,
errorNode: ExpressionNode
) {
// First try with intersection types disallowed.
const narrowedType = narrowTypeForIsInstanceInternal(
evaluator,
type,
filterTypes,
isInstanceCheck,
isPositiveTest,
/* allowIntersections */ false,
errorNode
);
if (!isNever(narrowedType)) {
return narrowedType;
}
// Try again with intersection types allowed.
return narrowTypeForIsInstanceInternal(
evaluator,
type,
filterTypes,
isInstanceCheck,
isPositiveTest,
/* allowIntersections */ true,
errorNode
);
}
// Attempts to narrow a type (make it more constrained) based on a
// call to isinstance or issubclass. For example, if the original
// type of expression "x" is "Mammal" and the test expression is
// "isinstance(x, Cow)", (assuming "Cow" is a subclass of "Mammal"),
// we can conclude that x must be constrained to "Cow".
function narrowTypeForIsInstance(
function narrowTypeForIsInstanceInternal(
evaluator: TypeEvaluator,
type: Type,
classTypeList: (ClassType | TypeVarType | FunctionType)[],
filterTypes: Type[],
isInstanceCheck: boolean,
isPositiveTest: boolean,
allowIntersections: boolean,
@ -1385,7 +1314,7 @@ function narrowTypeForIsInstance(
let foundSuperclass = false;
let isClassRelationshipIndeterminate = false;
for (const filterType of classTypeList) {
for (const filterType of filterTypes) {
let concreteFilterType = evaluator.makeTopLevelTypeVarsConcrete(filterType);
if (isInstantiableClass(concreteFilterType)) {
@ -1683,7 +1612,7 @@ function narrowTypeForIsInstance(
let foundPositiveMatch = false;
let isMatchIndeterminate = false;
for (const filterType of classTypeList) {
for (const filterType of filterTypes) {
const concreteFilterType = evaluator.makeTopLevelTypeVarsConcrete(filterType);
if (isInstantiableClass(concreteFilterType)) {
@ -1743,7 +1672,7 @@ function narrowTypeForIsInstance(
const filteredTypes: Type[] = [];
if (isPositiveTest) {
for (const filterType of classTypeList) {
for (const filterType of filterTypes) {
const concreteFilterType = evaluator.makeTopLevelTypeVarsConcrete(filterType);
if (
@ -1766,7 +1695,7 @@ function narrowTypeForIsInstance(
}
}
} else if (
!classTypeList.some((filterType) => {
!filterTypes.some((filterType) => {
// If the filter type is a runtime checkable protocol class, it can
// be used in an instance check.
const concreteFilterType = evaluator.makeTopLevelTypeVarsConcrete(filterType);
@ -1784,7 +1713,7 @@ function narrowTypeForIsInstance(
};
const classListContainsNoneType = () =>
classTypeList.some((t) => {
filterTypes.some((t) => {
if (isNoneTypeClass(t)) {
return true;
}
@ -1812,7 +1741,7 @@ function narrowTypeForIsInstance(
// specified types.
if (isInstanceCheck) {
anyOrUnknownSubstitutions.push(
combineTypes(classTypeList.map((classType) => convertToInstance(classType)))
combineTypes(filterTypes.map((classType) => convertToInstance(classType)))
);
} else {
// We perform a double conversion from instance to instantiable
@ -1820,7 +1749,7 @@ function narrowTypeForIsInstance(
// if it's a class.
anyOrUnknownSubstitutions.push(
combineTypes(
classTypeList.map((classType) => convertToInstantiable(convertToInstance(classType)))
filterTypes.map((classType) => convertToInstantiable(convertToInstance(classType)))
)
);
}
@ -1838,7 +1767,7 @@ function narrowTypeForIsInstance(
// Handle type narrowing for runtime-checkable protocols
// when applied to modules.
if (isPositiveTest) {
const filteredTypes = classTypeList.filter((classType) => {
const filteredTypes = filterTypes.filter((classType) => {
const concreteClassType = evaluator.makeTopLevelTypeVarsConcrete(classType);
return (
isInstantiableClass(concreteClassType) && ClassType.isProtocolClass(concreteClassType)
@ -1868,7 +1797,7 @@ function narrowTypeForIsInstance(
if (isInstantiableClass(subtype) || isSubtypeMetaclass) {
// Handle the special case of isinstance(x, metaclass).
const includesMetaclassType = classTypeList.some((classType) => isInstantiableMetaclass(classType));
const includesMetaclassType = filterTypes.some((classType) => isInstantiableMetaclass(classType));
if (isPositiveTest) {
return includesMetaclassType ? negativeFallback : undefined;
} else {

View File

@ -1300,6 +1300,16 @@ export namespace ClassType {
): boolean {
// Is it the exact same class?
if (isSameGenericClass(subclassType, parentClassType)) {
// Handle literal types.
if (parentClassType.literalValue !== undefined) {
if (
subclassType.literalValue === undefined ||
!ClassType.isLiteralValueSame(parentClassType, subclassType)
) {
return false;
}
}
if (inheritanceChain) {
inheritanceChain.push(subclassType);
}

View File

@ -48,4 +48,4 @@ def func_typeis(val: int | str):
if is_int(val, TypeGuardMode.TypeIs):
reveal_type(val, expected_text="int")
else:
reveal_type(val, expected_text="str")
reveal_type(val, expected_text="int | str")

View File

@ -12,7 +12,7 @@ def func1(val: Union[str, int]):
if is_str1(val):
reveal_type(val, expected_text="str")
else:
reveal_type(val, expected_text="int")
reveal_type(val, expected_text="str | int")
def is_true(o: object) -> TypeIs[Literal[True]]: ...
@ -33,16 +33,20 @@ def is_list(val: object) -> TypeIs[list[Any]]:
def func3(val: dict[str, str] | list[str] | list[int] | Sequence[int]):
if is_list(val):
reveal_type(val, expected_text="list[str] | list[int] | list[Any]")
reveal_type(val, expected_text="list[str] | list[int]")
else:
reveal_type(val, expected_text="dict[str, str] | Sequence[int]")
reveal_type(
val, expected_text="dict[str, str] | list[str] | list[int] | Sequence[int]"
)
def func4(val: dict[str, str] | list[str] | list[int] | tuple[int]):
if is_list(val):
reveal_type(val, expected_text="list[str] | list[int]")
else:
reveal_type(val, expected_text="dict[str, str] | tuple[int]")
reveal_type(
val, expected_text="dict[str, str] | list[str] | list[int] | tuple[int]"
)
_K = TypeVar("_K")
@ -55,7 +59,9 @@ def is_dict(val: Mapping[_K, _V]) -> TypeIs[dict[_K, _V]]:
def func5(val: dict[_K, _V] | Mapping[_K, _V]):
if not is_dict(val):
reveal_type(val, expected_text="Mapping[_K@func5, _V@func5]")
reveal_type(
val, expected_text="dict[_K@func5, _V@func5] | Mapping[_K@func5, _V@func5]"
)
else:
reveal_type(val, expected_text="dict[_K@func5, _V@func5]")