From d55be983e5503eb8b7def5174112c09a570e0eca Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Sat, 29 Jun 2024 11:04:30 -0700 Subject: [PATCH] Fixed a bug in the type narrowing for the "x is " type guard pattern when `` is a specific class T, as opposed to a variable of type `type[T]`. This addresses #8264. --- .../src/analyzer/typeGuards.ts | 66 ++++++++++++++----- .../tests/samples/typeNarrowingIsClass1.py | 20 +++--- 2 files changed, 63 insertions(+), 23 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 4c5b1a393..33b070acc 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -2463,33 +2463,69 @@ function narrowTypeForClassComparison( isPositiveTest: boolean ): Type { return mapSubtypes(referenceType, (subtype) => { - const concreteSubtype = evaluator.makeTopLevelTypeVarsConcrete(subtype); + let concreteSubtype = evaluator.makeTopLevelTypeVarsConcrete(subtype); if (isPositiveTest) { if (isNoneInstance(concreteSubtype)) { - return undefined; + return isNoneTypeClass(classType) ? classType : undefined; } - if (isClassInstance(concreteSubtype) && TypeBase.isInstance(subtype)) { - if (ClassType.isBuiltIn(concreteSubtype, 'type')) { - return classType; + if ( + isClassInstance(concreteSubtype) && + TypeBase.isInstance(subtype) && + ClassType.isBuiltIn(concreteSubtype, 'type') + ) { + concreteSubtype = + concreteSubtype.typeArguments && concreteSubtype.typeArguments.length > 0 + ? convertToInstantiable(concreteSubtype.typeArguments[0]) + : UnknownType.create(); + } + + if (isAnyOrUnknown(concreteSubtype)) { + return classType; + } + + if (isClass(concreteSubtype)) { + if (TypeBase.isInstance(concreteSubtype)) { + return ClassType.isBuiltIn(concreteSubtype, 'object') ? classType : undefined; } - return undefined; - } + const isSuperType = isIsinstanceFilterSuperclass( + evaluator, + subtype, + concreteSubtype, + classType, + classType, + /* isInstanceCheck */ false + ); - if (isInstantiableClass(concreteSubtype) && ClassType.isFinal(concreteSubtype)) { - if ( - !ClassType.isSameGenericClass(concreteSubtype, classType) && - !isIsinstanceFilterSuperclass( + if (!classType.includeSubclasses) { + // Handle the case where the LHS and RHS operands are specific + // classes, as opposed to types that represent classes and their + // subclasses. + if (!concreteSubtype.includeSubclasses) { + return ClassType.isSameGenericClass(concreteSubtype, classType) ? classType : undefined; + } + + const isSubType = isIsinstanceFilterSubclass( evaluator, - subtype, concreteSubtype, classType, - classType, /* isInstanceCheck */ false - ) - ) { + ); + + if (isSuperType) { + return classType; + } + + if (isSubType) { + return addConditionToType(classType, getTypeCondition(concreteSubtype)); + } + + return undefined; + } + + if (ClassType.isFinal(concreteSubtype) && !isSuperType) { return undefined; } } diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingIsClass1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingIsClass1.py index 96fca0de4..71633b92f 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingIsClass1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingIsClass1.py @@ -5,17 +5,14 @@ from typing import Any, TypeVar, final @final -class A: - ... +class A: ... @final -class B: - ... +class B: ... -class C: - ... +class C: ... def func1(x: type[A] | type[B] | None | int): @@ -34,7 +31,7 @@ def func2(x: type[A] | type[B] | None | int, y: type[A]): def func3(x: type[A] | type[B] | Any): if x is A: - reveal_type(x, expected_text="type[A] | Any") + reveal_type(x, expected_text="type[A]") else: reveal_type(x, expected_text="type[B] | Any") @@ -51,7 +48,7 @@ T = TypeVar("T") def func5(x: type[A] | type[B] | type[T]) -> type[A] | type[B] | type[T]: if x is A: - reveal_type(x, expected_text="type[A] | type[T@func5]") + reveal_type(x, expected_text="type[A] | type[A]*") else: reveal_type(x, expected_text="type[B] | type[T@func5]") @@ -63,3 +60,10 @@ def func6(x: type): reveal_type(x, expected_text="type[str]") else: reveal_type(x, expected_text="type") + + +def func7(x: type[A | B]): + if x is A: + reveal_type(x, expected_text="type[A]") + else: + reveal_type(x, expected_text="type[B]")