diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 05f881cb9..364d37623 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -2543,7 +2543,8 @@ function narrowTypeForLiteralComparison( } } else if (isPositiveTest) { if (isIsOperator || isNoneInstance(subtype)) { - return undefined; + const isSubtype = evaluator.assignType(subtype, literalType); + return isSubtype ? literalType : undefined; } } diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingEnum2.py b/packages/pyright-internal/src/tests/samples/typeNarrowingEnum2.py index ff241d50c..597af89cc 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingEnum2.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingEnum2.py @@ -11,8 +11,7 @@ class SomeEnum(Enum): VALUE2 = 2 -def assert_never(val: NoReturn): - ... +def assert_never(val: NoReturn): ... def func1(a: SomeEnum): @@ -69,3 +68,10 @@ def func7(a: Union[str, bool]) -> str: elif a is True: return "True" return a + + +def func8(a: object): + if a is SomeEnum.VALUE1 or a is SomeEnum.VALUE2: + reveal_type(a, expected_text="Literal[SomeEnum.VALUE1, SomeEnum.VALUE2]") + else: + reveal_type(a, expected_text="object")