Fixed a couple of bugs in the type narrowing logic for class pattern matching. In particular, added better support for NoneType() patterns and cases where the subject type is a union that includes subtypes of the class pattern. This addresses https://github.com/microsoft/pyright/issues/4800.

This commit is contained in:
Eric Traut 2023-03-19 00:58:54 -06:00
parent 7ad4a8ead5
commit d36c524916
2 changed files with 44 additions and 9 deletions

View File

@ -540,14 +540,30 @@ function narrowTypeBasedOnClassPattern(
return subjectSubtypeUnexpanded;
}
if (
isNoneInstance(subjectSubtypeExpanded) &&
isInstantiableClass(classType) &&
ClassType.isBuiltIn(classType, 'NoneType')
) {
return undefined;
}
if (!evaluator.assignType(classInstance, subjectSubtypeExpanded)) {
return subjectSubtypeExpanded;
}
// If there are no arguments, we're done. We know that this match
// will never succeed.
if (pattern.arguments.length === 0) {
return undefined;
if (
isClass(classInstance) &&
isClass(subjectSubtypeExpanded) &&
ClassType.isSameGenericClass(classInstance, subjectSubtypeExpanded)
) {
// We know that this match will always succeed, so we can
// eliminate this subtype.
return undefined;
}
return subjectSubtypeExpanded;
}
// We might be able to narrow further based on arguments, but only
@ -616,6 +632,14 @@ function narrowTypeBasedOnClassPattern(
return convertToInstance(unexpandedSubtype);
}
if (
isNoneInstance(subjectSubtypeExpanded) &&
isInstantiableClass(expandedSubtype) &&
ClassType.isBuiltIn(expandedSubtype, 'NoneType')
) {
return subjectSubtypeExpanded;
}
if (isClassInstance(subjectSubtypeExpanded)) {
let resultType: Type;

View File

@ -77,23 +77,24 @@ class SingleColor(Enum):
red = 0
def func8(x: SingleColor) -> int:
match x:
def func8(subj: SingleColor) -> int:
match subj:
case SingleColor.red:
return 1
def func9(x: int | None):
match x:
def func9(subj: int | None):
match subj:
case NoneType():
return 1
case int():
return 2
def func10(source: Color | None = None) -> list[str]:
def func10(subj: Color | None = None) -> list[str]:
results = [""]
for x in [""]:
match source:
match subj:
case None:
results.append(x)
case Color.red:
@ -103,3 +104,13 @@ def func10(source: Color | None = None) -> list[str]:
case Color.blue:
pass
return results
def func11(subj: int | float | None):
match subj:
case float():
reveal_type(subj, expected_text="int | float")
case int():
reveal_type(subj, expected_text="int")
case NoneType():
reveal_type(subj, expected_text="None")