diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index f0f230b14..082118a91 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -67,7 +67,6 @@ import { AssignTypeFlags, ClassMember, computeMroLinearization, - containsAnyOrUnknown, convertToInstance, convertToInstantiable, doForEachSubtype, @@ -640,23 +639,6 @@ export function getTypeNarrowingCallback( if (classTypeList) { return (type: Type) => { - const narrowedType = narrowTypeForIsInstance( - evaluator, - type, - classTypeList, - isInstanceCheck, - isPositiveTest, - /* allowIntersections */ false, - testExpression - ); - if (!isNever(narrowedType)) { - return { - type: narrowedType, - isIncomplete, - }; - } - - // Try again with intersection types allowed. return { type: narrowTypeForIsInstance( evaluator, @@ -664,7 +646,6 @@ export function getTypeNarrowingCallback( classTypeList, isInstanceCheck, isPositiveTest, - /* allowIntersections */ true, testExpression ), isIncomplete, @@ -793,7 +774,8 @@ export function getTypeNarrowingCallback( type, typeGuardType, isPositiveTest, - isStrictTypeGuard + isStrictTypeGuard, + testExpression ), isIncomplete, }; @@ -986,7 +968,8 @@ function narrowTypeForUserDefinedTypeGuard( type: Type, typeGuardType: Type, isPositiveTest: boolean, - isStrictTypeGuard: boolean + isStrictTypeGuard: boolean, + errorNode: ExpressionNode ): Type { // For non-strict type guards, always narrow to the typeGuardType // in the positive case and don't narrow in the negative case. @@ -994,101 +977,12 @@ function narrowTypeForUserDefinedTypeGuard( return isPositiveTest ? typeGuardType : type; } - // For strict type guards, narrow the type. - return mapSubtypes(type, (subtype) => { - if (isPositiveTest) { - // In the positive case, we need to compute the intersection of "type" - // and "typeGuardType". If we assume "type" is a union of (A1 | A2 | ... | An) - // and "typeGuardType" is a union of (B1 | B2 | ... | Bn), then the intersection - // of the two is the union of the intersection of all combinations of A's and B's. - // For each pair, if A and B are the same type, the intersection is that type. - // If A is a proper subtype of B or vice versa, the intersection is the narrower of - // the two. If A and B have no commonality, the intersection is Never. - // If A and B are not the same type but A and B are mutually "subtypes" of each - // other, that means they don't follow the normal subtyping rules. In this case, - // we'll try to pick the one that has the most information (i.e. doesn't contain - // an Any or Unknown). - return mapSubtypes(typeGuardType, (typeGuardSubtype) => { - if (isTypeSame(subtype, typeGuardSubtype, { ignorePseudoGeneric: true, treatAnySameAsUnknown: true })) { - return subtype; - } - - const isSubtype = evaluator.assignType(typeGuardSubtype, subtype); - const isSupertype = evaluator.assignType(subtype, typeGuardSubtype); - - if (isSubtype) { - if (!isSupertype) { - return subtype; - } - - // It's both a subtype and a supertype and it's not the same type. - // That means it's some combination of types that don't follow subtyping - // rules. Try to retain as much information as possible. If one of the - // two types contains an Any and the other does not, use the one that doesn't. - return containsAnyOrUnknown(typeGuardSubtype, /* recurse */ true) ? subtype : typeGuardSubtype; - } - - if (isSupertype) { - return typeGuardSubtype; - } - - // The types have nothing in common. - return undefined; - }); - } else { - // In the negative case, we need to compute the intersection of "type" and - // the negation of "typeGuardType". The type system doesn't have support for - // type negations, so the actual result will necessarily be broader than the - // theoretically-correct result. - // If "type" is a union of (A1 | A2 | ... | An) and "typeGuardType" is a union - // of (B1 | B2 | ... | Bn), then the intersection of "type" and !"typeGuardType" - // is (A1 & !B1 & !B2 & ... & !Bn) | (A2 & !B1 & !B2 & ... & !Bn) | ... |. - // This means we can eliminate an A only if we can show that it doesn't - // share any commonality with any of the B's. - let canBeEliminated = false; - - if (!isAnyOrUnknown(subtype)) { - doForEachSubtype(typeGuardType, (typeGuardSubtype) => { - if ( - isTypeSame(subtype, typeGuardSubtype, { - ignorePseudoGeneric: true, - treatAnySameAsUnknown: true, - }) - ) { - canBeEliminated = true; - } else { - const isSubtype = evaluator.assignType(typeGuardSubtype, subtype); - const isSupertype = evaluator.assignType(subtype, typeGuardSubtype); - - if (isSubtype && !isAnyOrUnknown(typeGuardSubtype)) { - if (!isSupertype) { - canBeEliminated = true; - } else { - // In this case, there is not a clear subtype relationship between - // "subtype" and "typeGuardSubtype". We'll use a heuristic to produce - // a result that is not unexpected. If the typeGuardSubtype is a gradual - // type (contains Any) but the subtype does not, we'll eliminate the - // subtype. For example, if the typeGuardSubtype is "list[Any]" and - // the subtype is "list[str]", we'll eliminate "list[str]". If the types - // are reversed, we don't want to eliminate "list[Any]". - const subtypeContainsAny = containsAnyOrUnknown(subtype, /* recurse */ true); - const typeGuardSubtypeContainsAny = containsAnyOrUnknown( - typeGuardSubtype, - /* recurse */ true - ); - - if (!subtypeContainsAny && typeGuardSubtypeContainsAny) { - canBeEliminated = true; - } - } - } - } - }); - } - - return canBeEliminated ? undefined : subtype; - } + const filterTypes: Type[] = []; + doForEachSubtype(typeGuardType, (typeGuardSubtype) => { + filterTypes.push(convertToInstantiable(typeGuardSubtype)); }); + + return narrowTypeForIsInstance(evaluator, type, filterTypes, /* isInstanceCheck */ true, isPositiveTest, errorNode); } // Narrow the type based on whether the subtype can be true or false. @@ -1298,7 +1192,7 @@ export function isIsinstanceFilterSuperclass( concreteFilterType: ClassType, isInstanceCheck: boolean ) { - if (isTypeVar(filterType)) { + if (isTypeVar(filterType) || concreteFilterType.literalValue !== undefined) { return isTypeSame(convertToInstance(filterType), varType); } @@ -1351,15 +1245,50 @@ export function isIsinstanceFilterSubclass( return false; } +function narrowTypeForIsInstance( + evaluator: TypeEvaluator, + type: Type, + filterTypes: Type[], + isInstanceCheck: boolean, + isPositiveTest: boolean, + errorNode: ExpressionNode +) { + // First try with intersection types disallowed. + const narrowedType = narrowTypeForIsInstanceInternal( + evaluator, + type, + filterTypes, + isInstanceCheck, + isPositiveTest, + /* allowIntersections */ false, + errorNode + ); + + if (!isNever(narrowedType)) { + return narrowedType; + } + + // Try again with intersection types allowed. + return narrowTypeForIsInstanceInternal( + evaluator, + type, + filterTypes, + isInstanceCheck, + isPositiveTest, + /* allowIntersections */ true, + errorNode + ); +} + // Attempts to narrow a type (make it more constrained) based on a // call to isinstance or issubclass. For example, if the original // type of expression "x" is "Mammal" and the test expression is // "isinstance(x, Cow)", (assuming "Cow" is a subclass of "Mammal"), // we can conclude that x must be constrained to "Cow". -function narrowTypeForIsInstance( +function narrowTypeForIsInstanceInternal( evaluator: TypeEvaluator, type: Type, - classTypeList: (ClassType | TypeVarType | FunctionType)[], + filterTypes: Type[], isInstanceCheck: boolean, isPositiveTest: boolean, allowIntersections: boolean, @@ -1385,7 +1314,7 @@ function narrowTypeForIsInstance( let foundSuperclass = false; let isClassRelationshipIndeterminate = false; - for (const filterType of classTypeList) { + for (const filterType of filterTypes) { let concreteFilterType = evaluator.makeTopLevelTypeVarsConcrete(filterType); if (isInstantiableClass(concreteFilterType)) { @@ -1683,7 +1612,7 @@ function narrowTypeForIsInstance( let foundPositiveMatch = false; let isMatchIndeterminate = false; - for (const filterType of classTypeList) { + for (const filterType of filterTypes) { const concreteFilterType = evaluator.makeTopLevelTypeVarsConcrete(filterType); if (isInstantiableClass(concreteFilterType)) { @@ -1743,7 +1672,7 @@ function narrowTypeForIsInstance( const filteredTypes: Type[] = []; if (isPositiveTest) { - for (const filterType of classTypeList) { + for (const filterType of filterTypes) { const concreteFilterType = evaluator.makeTopLevelTypeVarsConcrete(filterType); if ( @@ -1766,7 +1695,7 @@ function narrowTypeForIsInstance( } } } else if ( - !classTypeList.some((filterType) => { + !filterTypes.some((filterType) => { // If the filter type is a runtime checkable protocol class, it can // be used in an instance check. const concreteFilterType = evaluator.makeTopLevelTypeVarsConcrete(filterType); @@ -1784,7 +1713,7 @@ function narrowTypeForIsInstance( }; const classListContainsNoneType = () => - classTypeList.some((t) => { + filterTypes.some((t) => { if (isNoneTypeClass(t)) { return true; } @@ -1812,7 +1741,7 @@ function narrowTypeForIsInstance( // specified types. if (isInstanceCheck) { anyOrUnknownSubstitutions.push( - combineTypes(classTypeList.map((classType) => convertToInstance(classType))) + combineTypes(filterTypes.map((classType) => convertToInstance(classType))) ); } else { // We perform a double conversion from instance to instantiable @@ -1820,7 +1749,7 @@ function narrowTypeForIsInstance( // if it's a class. anyOrUnknownSubstitutions.push( combineTypes( - classTypeList.map((classType) => convertToInstantiable(convertToInstance(classType))) + filterTypes.map((classType) => convertToInstantiable(convertToInstance(classType))) ) ); } @@ -1838,7 +1767,7 @@ function narrowTypeForIsInstance( // Handle type narrowing for runtime-checkable protocols // when applied to modules. if (isPositiveTest) { - const filteredTypes = classTypeList.filter((classType) => { + const filteredTypes = filterTypes.filter((classType) => { const concreteClassType = evaluator.makeTopLevelTypeVarsConcrete(classType); return ( isInstantiableClass(concreteClassType) && ClassType.isProtocolClass(concreteClassType) @@ -1868,7 +1797,7 @@ function narrowTypeForIsInstance( if (isInstantiableClass(subtype) || isSubtypeMetaclass) { // Handle the special case of isinstance(x, metaclass). - const includesMetaclassType = classTypeList.some((classType) => isInstantiableMetaclass(classType)); + const includesMetaclassType = filterTypes.some((classType) => isInstantiableMetaclass(classType)); if (isPositiveTest) { return includesMetaclassType ? negativeFallback : undefined; } else { diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index c271279bb..0bacbc3d4 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -1300,6 +1300,16 @@ export namespace ClassType { ): boolean { // Is it the exact same class? if (isSameGenericClass(subclassType, parentClassType)) { + // Handle literal types. + if (parentClassType.literalValue !== undefined) { + if ( + subclassType.literalValue === undefined || + !ClassType.isLiteralValueSame(parentClassType, subclassType) + ) { + return false; + } + } + if (inheritanceChain) { inheritanceChain.push(subclassType); } diff --git a/packages/pyright-internal/src/tests/samples/typeGuard3.py b/packages/pyright-internal/src/tests/samples/typeGuard3.py index f1d28a7a7..716373393 100644 --- a/packages/pyright-internal/src/tests/samples/typeGuard3.py +++ b/packages/pyright-internal/src/tests/samples/typeGuard3.py @@ -48,4 +48,4 @@ def func_typeis(val: int | str): if is_int(val, TypeGuardMode.TypeIs): reveal_type(val, expected_text="int") else: - reveal_type(val, expected_text="str") + reveal_type(val, expected_text="int | str") diff --git a/packages/pyright-internal/src/tests/samples/typeIs1.py b/packages/pyright-internal/src/tests/samples/typeIs1.py index 4d89afe40..b0e83d00b 100644 --- a/packages/pyright-internal/src/tests/samples/typeIs1.py +++ b/packages/pyright-internal/src/tests/samples/typeIs1.py @@ -12,7 +12,7 @@ def func1(val: Union[str, int]): if is_str1(val): reveal_type(val, expected_text="str") else: - reveal_type(val, expected_text="int") + reveal_type(val, expected_text="str | int") def is_true(o: object) -> TypeIs[Literal[True]]: ... @@ -33,16 +33,20 @@ def is_list(val: object) -> TypeIs[list[Any]]: def func3(val: dict[str, str] | list[str] | list[int] | Sequence[int]): if is_list(val): - reveal_type(val, expected_text="list[str] | list[int] | list[Any]") + reveal_type(val, expected_text="list[str] | list[int]") else: - reveal_type(val, expected_text="dict[str, str] | Sequence[int]") + reveal_type( + val, expected_text="dict[str, str] | list[str] | list[int] | Sequence[int]" + ) def func4(val: dict[str, str] | list[str] | list[int] | tuple[int]): if is_list(val): reveal_type(val, expected_text="list[str] | list[int]") else: - reveal_type(val, expected_text="dict[str, str] | tuple[int]") + reveal_type( + val, expected_text="dict[str, str] | list[str] | list[int] | tuple[int]" + ) _K = TypeVar("_K") @@ -55,7 +59,9 @@ def is_dict(val: Mapping[_K, _V]) -> TypeIs[dict[_K, _V]]: def func5(val: dict[_K, _V] | Mapping[_K, _V]): if not is_dict(val): - reveal_type(val, expected_text="Mapping[_K@func5, _V@func5]") + reveal_type( + val, expected_text="dict[_K@func5, _V@func5] | Mapping[_K@func5, _V@func5]" + ) else: reveal_type(val, expected_text="dict[_K@func5, _V@func5]")