From ef72e4079c7191d08a6c97d1cf3b6ac838be22ba Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Wed, 16 Feb 2022 22:08:17 -0800 Subject: [PATCH] Fixed a bug in negative type narrowing logic for value patterns in `match` statement. --- .../src/analyzer/patternMatching.ts | 35 +++++++++---------- .../src/tests/samples/match4.py | 16 ++++++++- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/patternMatching.ts b/packages/pyright-internal/src/analyzer/patternMatching.ts index 075cb36e2..e05e07569 100644 --- a/packages/pyright-internal/src/analyzer/patternMatching.ts +++ b/packages/pyright-internal/src/analyzer/patternMatching.ts @@ -734,44 +734,43 @@ function narrowTypeBasedOnValuePattern( evaluator.mapSubtypesExpandTypeVars( subjectType, getTypeCondition(valueSubtypeExpanded), - (_, subjectSubtypeUnexpanded) => { + (subjectSubtypeExpanded) => { // If this is a negative test, see if it's an enum value. if (!isPositiveTest) { if ( - isClassInstance(subjectSubtypeUnexpanded) && - ClassType.isEnumClass(subjectSubtypeUnexpanded) && - !isLiteralType(subjectSubtypeUnexpanded) && - isClassInstance(valueSubtypeUnexpanded) && - isSameWithoutLiteralValue(subjectSubtypeUnexpanded, valueSubtypeUnexpanded) && - isLiteralType(valueSubtypeUnexpanded) + isClassInstance(subjectSubtypeExpanded) && + ClassType.isEnumClass(subjectSubtypeExpanded) && + !isLiteralType(subjectSubtypeExpanded) && + isClassInstance(valueSubtypeExpanded) && + isSameWithoutLiteralValue(subjectSubtypeExpanded, valueSubtypeExpanded) && + isLiteralType(valueSubtypeExpanded) ) { - const allEnumTypes = enumerateLiteralsForType(evaluator, subjectSubtypeUnexpanded); + const allEnumTypes = enumerateLiteralsForType(evaluator, subjectSubtypeExpanded); if (allEnumTypes) { return combineTypes( allEnumTypes.filter( - (enumType) => - !ClassType.isLiteralValueSame(valueSubtypeUnexpanded, enumType) + (enumType) => !ClassType.isLiteralValueSame(valueSubtypeExpanded, enumType) ) ); } } else if ( - isClassInstance(subjectSubtypeUnexpanded) && - isClassInstance(valueSubtypeUnexpanded) && - ClassType.isLiteralValueSame(valueSubtypeUnexpanded, subjectSubtypeUnexpanded) + isClassInstance(subjectSubtypeExpanded) && + isClassInstance(valueSubtypeExpanded) && + ClassType.isLiteralValueSame(valueSubtypeExpanded, subjectSubtypeExpanded) ) { return undefined; } - return subjectSubtypeUnexpanded; + return subjectSubtypeExpanded; } - if (isNever(valueSubtypeExpanded) || isNever(subjectSubtypeUnexpanded)) { + if (isNever(valueSubtypeExpanded) || isNever(subjectSubtypeExpanded)) { return NeverType.createNever(); } - if (isAnyOrUnknown(valueSubtypeExpanded) || isAnyOrUnknown(subjectSubtypeUnexpanded)) { + if (isAnyOrUnknown(valueSubtypeExpanded) || isAnyOrUnknown(subjectSubtypeExpanded)) { // If either type is "Unknown" (versus Any), propagate the Unknown. - return isUnknown(valueSubtypeExpanded) || isUnknown(subjectSubtypeUnexpanded) + return isUnknown(valueSubtypeExpanded) || isUnknown(subjectSubtypeExpanded) ? UnknownType.create() : AnyType.create(); } @@ -781,7 +780,7 @@ function narrowTypeBasedOnValuePattern( const returnType = evaluator.useSpeculativeMode(pattern.expression, () => evaluator.getTypeFromMagicMethodReturn( valueSubtypeExpanded, - [subjectSubtypeUnexpanded], + [subjectSubtypeExpanded], '__eq__', pattern.expression, /* expectedType */ undefined diff --git a/packages/pyright-internal/src/tests/samples/match4.py b/packages/pyright-internal/src/tests/samples/match4.py index 46585f7a9..ba32474a7 100644 --- a/packages/pyright-internal/src/tests/samples/match4.py +++ b/packages/pyright-internal/src/tests/samples/match4.py @@ -1,7 +1,7 @@ # This sample tests type checking for match statements (as # described in PEP 634) that contain value patterns. -from enum import Enum +from enum import Enum, auto from typing import Tuple, TypeVar, Union from http import HTTPStatus @@ -84,3 +84,17 @@ def test_enum_narrowing(m: Union[Medal, Color, int]): case d1: reveal_type(d1, expected_text='int | Literal[Medal.bronze]') reveal_type(m, expected_text='int | Literal[Medal.bronze]') + + +class Foo(Enum): + bar = auto() + + def __str__(self) -> str: + match self: + case Foo.bar: + return "bar" + + case x: + reveal_type(x, expected_text="Never") + +