Fixed bugs that resulted in incorrect or incomplete types when narrowing for sequence patterns in the negative case and the subject expression is a super-type of Sequence (such as object or Reversible). (#5507)

Co-authored-by: Eric Traut <erictr@microsoft.com>
This commit is contained in:
Eric Traut 2023-07-15 02:12:58 +00:00 committed by GitHub
parent bf158567f9
commit 01f40f0e98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 165 additions and 131 deletions

View File

@ -31,7 +31,7 @@ import {
import { getFileInfo } from './analyzerNodeInfo';
import { CodeFlowReferenceExpressionNode } from './codeFlowTypes';
import { populateTypeVarContextBasedOnExpectedType } from './constraintSolver';
import { isMatchingExpression } from './parseTreeUtils';
import { getTypeVarScopesForNode, isMatchingExpression } from './parseTreeUtils';
import { getTypedDictMembersForClass } from './typedDicts';
import { EvaluatorFlags, TypeEvaluator, TypeResult } from './typeEvaluatorTypes';
import {
@ -99,10 +99,10 @@ const classPatternSpecialCases = [
interface SequencePatternInfo {
subtype: Type;
isDefiniteNoMatch: boolean;
isPotentialNoMatch?: boolean;
entryTypes: Type[];
isIndeterminateLength?: boolean;
isTuple?: boolean;
isObject?: boolean;
}
interface MappingPatternInfo {
@ -208,7 +208,7 @@ function narrowTypeBasedOnSequencePattern(
}
let isPlausibleMatch = true;
let isDefiniteMatch = !containsAnyOrUnknown(type, /* recurse */ false);
let isDefiniteMatch = true;
const narrowedEntryTypes: Type[] = [];
let canNarrowTuple = entry.isTuple;
@ -236,8 +236,7 @@ function narrowTypeBasedOnSequencePattern(
index,
pattern.entries.length,
pattern.starEntryIndex,
/* unpackStarEntry */ true,
/* isSubjectObject */ false
/* unpackStarEntry */ true
);
const narrowedEntryType = narrowTypeBasedOnPattern(evaluator, entryType, sequenceEntry, isPositiveTest);
@ -266,6 +265,10 @@ function narrowTypeBasedOnSequencePattern(
}
}
} else {
if (entry.isPotentialNoMatch) {
isDefiniteMatch = false;
}
if (!isNever(narrowedEntryType)) {
isDefiniteMatch = false;
@ -327,16 +330,18 @@ function narrowTypeBasedOnSequencePattern(
}
}
// If this is an object, we can narrow it to a specific Sequence type.
if (entry.isObject) {
// If this is a supertype of Sequence, we can narrow it to a Sequence type.
if (entry.isPotentialNoMatch) {
const sequenceType = evaluator.getTypingType(pattern, 'Sequence');
if (sequenceType && isInstantiableClass(sequenceType)) {
let typeArgType = evaluator.stripLiteralValue(combineTypes(narrowedEntryTypes));
// If the type is a union that contains Any or Unknown, remove the other types
// before wrapping it in a Sequence.
typeArgType = containsAnyOrUnknown(typeArgType, /* recurse */ false) ?? typeArgType;
entry.subtype = ClassType.cloneAsInstance(
ClassType.cloneForSpecialization(
sequenceType,
[evaluator.stripLiteralValue(combineTypes(narrowedEntryTypes))],
/* isTypeArgumentExplicit */ true
)
ClassType.cloneForSpecialization(sequenceType, [typeArgType], /* isTypeArgumentExplicit */ true)
);
}
}
@ -1112,28 +1117,7 @@ function getSequencePatternInfo(
let mroClassToSpecialize: ClassType | undefined;
let pushedEntry = false;
if (isAnyOrUnknown(concreteSubtype)) {
sequenceInfo.push({
subtype,
entryTypes: [concreteSubtype],
isIndeterminateLength: true,
isDefiniteNoMatch: false,
});
return;
}
if (isClassInstance(concreteSubtype)) {
if (ClassType.isBuiltIn(concreteSubtype, 'object')) {
sequenceInfo.push({
subtype,
entryTypes: [convertToInstance(concreteSubtype)],
isIndeterminateLength: true,
isObject: true,
isDefiniteNoMatch: false,
});
return;
}
for (const mroClass of concreteSubtype.details.mro) {
if (!isInstantiableClass(mroClass)) {
break;
@ -1145,7 +1129,7 @@ function getSequencePatternInfo(
ClassType.isBuiltIn(mroClass, 'bytes') ||
ClassType.isBuiltIn(mroClass, 'bytearray')
) {
break;
return;
}
if (ClassType.isBuiltIn(mroClass, 'Sequence')) {
@ -1206,8 +1190,61 @@ function getSequencePatternInfo(
}
}
// Push an entry that indicates that this is definitely not a match.
if (!pushedEntry) {
// If it wasn't a subtype of Sequence, see if it's a supertype.
const sequenceType = evaluator.getTypingType(pattern, 'Sequence');
if (sequenceType && isInstantiableClass(sequenceType)) {
const sequenceTypeVarContext = new TypeVarContext(getTypeVarScopeId(sequenceType));
if (
populateTypeVarContextBasedOnExpectedType(
evaluator,
ClassType.cloneAsInstance(sequenceType),
subtype,
sequenceTypeVarContext,
getTypeVarScopesForNode(pattern),
pattern.start
)
) {
const specializedSequence = applySolvedTypeVars(
ClassType.cloneAsInstantiable(sequenceType),
sequenceTypeVarContext
) as ClassType;
if (specializedSequence.typeArguments && specializedSequence.typeArguments.length > 0) {
sequenceInfo.push({
subtype,
entryTypes: [specializedSequence.typeArguments[0]],
isIndeterminateLength: true,
isDefiniteNoMatch: false,
isPotentialNoMatch: true,
});
return;
}
}
if (
evaluator.assignType(
subtype,
ClassType.cloneForSpecialization(
ClassType.cloneAsInstance(sequenceType),
[UnknownType.create()],
/* isTypeArgumentExplicit */ true
)
)
) {
sequenceInfo.push({
subtype,
entryTypes: [UnknownType.create()],
isIndeterminateLength: true,
isDefiniteNoMatch: false,
isPotentialNoMatch: true,
});
return;
}
}
// Push an entry that indicates that this is definitely not a match.
sequenceInfo.push({
subtype,
entryTypes: [],
@ -1227,22 +1264,11 @@ function getTypeOfPatternSequenceEntry(
entryIndex: number,
entryCount: number,
starEntryIndex: number | undefined,
unpackStarEntry: boolean,
isSubjectObject: boolean
unpackStarEntry: boolean
): Type {
if (sequenceInfo.isIndeterminateLength) {
let entryType = sequenceInfo.entryTypes[0];
// If the subject is typed as an "object", then the star entry
// is simply a list[object]. Without this special case, the list
// will be typed based on the union of all elements in the sequence.
if (isSubjectObject) {
const objectType = evaluator.getBuiltInObject(node, 'object');
if (objectType && isClassInstance(objectType)) {
entryType = objectType;
}
}
if (!unpackStarEntry && entryIndex === starEntryIndex && !isNever(entryType)) {
entryType = wrapTypeInList(evaluator, node, entryType);
}
@ -1284,7 +1310,6 @@ export function assignTypeToPatternTargets(
evaluator: TypeEvaluator,
type: Type,
isTypeIncomplete: boolean,
isSubjectObject: boolean,
pattern: PatternAtomNode
): Type {
// Further narrow the type based on this pattern.
@ -1306,13 +1331,12 @@ export function assignTypeToPatternTargets(
index,
pattern.entries.length,
pattern.starEntryIndex,
/* unpackStarEntry */ false,
isSubjectObject
/* unpackStarEntry */ false
)
)
);
assignTypeToPatternTargets(evaluator, entryType, isTypeIncomplete, /* isSubjectObject */ false, entry);
assignTypeToPatternTargets(evaluator, entryType, isTypeIncomplete, entry);
});
break;
}
@ -1324,13 +1348,7 @@ export function assignTypeToPatternTargets(
let runningNarrowedType = narrowedType;
pattern.orPatterns.forEach((orPattern) => {
assignTypeToPatternTargets(
evaluator,
runningNarrowedType,
isTypeIncomplete,
isSubjectObject,
orPattern
);
assignTypeToPatternTargets(evaluator, runningNarrowedType, isTypeIncomplete, orPattern);
// OR patterns are evaluated left to right, so we can narrow
// the type as we go.
@ -1443,20 +1461,8 @@ export function assignTypeToPatternTargets(
const valueType = combineTypes(valueTypes);
if (mappingEntry.nodeType === ParseNodeType.PatternMappingKeyEntry) {
assignTypeToPatternTargets(
evaluator,
keyType,
isTypeIncomplete,
/* isSubjectObject */ false,
mappingEntry.keyPattern
);
assignTypeToPatternTargets(
evaluator,
valueType,
isTypeIncomplete,
/* isSubjectObject */ false,
mappingEntry.valuePattern
);
assignTypeToPatternTargets(evaluator, keyType, isTypeIncomplete, mappingEntry.keyPattern);
assignTypeToPatternTargets(evaluator, valueType, isTypeIncomplete, mappingEntry.valuePattern);
} else if (mappingEntry.nodeType === ParseNodeType.PatternMappingExpandEntry) {
const dictClass = evaluator.getBuiltInType(pattern, 'dict');
const strType = evaluator.getBuiltInObject(pattern, 'str');
@ -1527,13 +1533,7 @@ export function assignTypeToPatternTargets(
});
pattern.arguments.forEach((arg, index) => {
assignTypeToPatternTargets(
evaluator,
combineTypes(argTypes[index]),
isTypeIncomplete,
/* isSubjectObject */ false,
arg.pattern
);
assignTypeToPatternTargets(evaluator, combineTypes(argTypes[index]), isTypeIncomplete, arg.pattern);
});
break;
}
@ -1556,6 +1556,10 @@ function wrapTypeInList(evaluator: TypeEvaluator, node: ParseNode, type: Type):
const listObjectType = convertToInstance(evaluator.getBuiltInObject(node, 'list'));
if (listObjectType && isClassInstance(listObjectType)) {
// If the type is a union that contains an Any or Unknown, eliminate the other
// types before wrapping it in a list.
type = containsAnyOrUnknown(type, /* recurse */ false) ?? type;
return ClassType.cloneForSpecialization(listObjectType, [type], /* isTypeArgumentExplicit */ true);
}

View File

@ -17627,19 +17627,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}
}
// Determine if the pre-narrowed subject type contains an object.
let subjectIsObject = false;
doForEachSubtype(makeTopLevelTypeVarsConcrete(subjectType), (subtype) => {
if (isClassInstance(subtype) && ClassType.isBuiltIn(subtype, 'object')) {
subjectIsObject = true;
}
});
const narrowedSubjectType = assignTypeToPatternTargets(
evaluatorInterface,
subjectType,
!!subjectTypeResult.isIncomplete,
subjectIsObject,
node.pattern
);

View File

@ -7,22 +7,22 @@ def test_unknown(value_to_match):
case 3 as a1, -3 as a2:
reveal_type(a1, expected_text="Literal[3]")
reveal_type(a2, expected_text="Literal[-3]")
reveal_type(value_to_match, expected_text="Unknown")
reveal_type(value_to_match, expected_text="Sequence[int]")
case 3j as b1, -3 + 5j as b2:
reveal_type(b1, expected_text="complex")
reveal_type(b2, expected_text="complex")
reveal_type(value_to_match, expected_text="Unknown")
reveal_type(value_to_match, expected_text="Sequence[complex]")
case "hi" as c1, None as c2:
reveal_type(c1, expected_text="Literal['hi']")
reveal_type(c2, expected_text="None")
reveal_type(value_to_match, expected_text="Unknown")
reveal_type(value_to_match, expected_text="Sequence[str | None]")
case True as d1, False as d2:
reveal_type(d1, expected_text="Literal[True]")
reveal_type(d2, expected_text="Literal[False]")
reveal_type(value_to_match, expected_text="Unknown")
reveal_type(value_to_match, expected_text="Sequence[bool]")
def test_tuple(value_to_match: tuple[int | float | str | complex, ...]):

View File

@ -1,7 +1,7 @@
# This sample tests type checking for match statements (as
# described in PEP 634) that contain sequence patterns.
from typing import Any, Generic, List, Literal, Protocol, Tuple, TypeVar, Union
from typing import Any, Generic, Iterator, List, Literal, Protocol, Reversible, Sequence, Tuple, TypeVar, Union
def test_unknown(value_to_match):
match value_to_match:
@ -40,6 +40,28 @@ def test_any(value_to_match: Any):
reveal_type(b1, expected_text="Any")
def test_reversible(value_to_match: Reversible[int]):
match value_to_match:
case [*a1]:
reveal_type(a1, expected_text="list[int]")
case b1:
reveal_type(b1, expected_text="Reversible[int]")
_T_co = TypeVar("_T_co", covariant=True)
class SeqProto(Protocol[_T_co]):
def __reversed__(self) -> Iterator[_T_co]: ...
def test_protocol(value_to_match: SeqProto[str]):
match value_to_match:
case [*a1]:
reveal_type(a1, expected_text="list[str]")
case b1:
reveal_type(b1, expected_text="SeqProto[str]")
def test_list(value_to_match: List[str]):
match value_to_match:
case a1, a2:
@ -159,39 +181,39 @@ def test_union(value_to_match: Union[Tuple[complex, complex], Tuple[int, str, fl
reveal_type(a2, expected_text="str | float | Any")
reveal_type(a3, expected_text="float | str | Any")
reveal_type(a4, expected_text="complex | str | float | Any")
reveal_type(value_to_match, expected_text="tuple[int, str, float, complex] | List[str] | tuple[float, float, float, float] | Any")
reveal_type(value_to_match, expected_text="tuple[int, str, float, complex] | List[str] | tuple[float, float, float, float] | Sequence[Any]")
case *b1, b2 if value_to_match[0] == 0:
reveal_type(b1, expected_text="list[complex] | list[int | str | float] | list[str] | list[float] | list[Any]")
reveal_type(b2, expected_text="complex | str | float | Any")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Any")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]")
case c1, *c2 if value_to_match[0] == 0:
reveal_type(c1, expected_text="complex | int | str | float | Any")
reveal_type(c2, expected_text="list[complex] | list[str | float | complex] | list[str] | list[float] | list[Any]")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Any")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]")
case d1, *d2, d3 if value_to_match[0] == 0:
reveal_type(d1, expected_text="complex | int | str | float | Any")
reveal_type(d2, expected_text="list[str | float] | list[str] | list[float] | list[Any]")
reveal_type(d3, expected_text="complex | str | float | Any")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Any")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]")
case 3, e1:
reveal_type(e1, expected_text="complex | float | Any")
reveal_type(value_to_match, expected_text="tuple[Literal[3], complex] | tuple[Literal[3], float] | Any")
reveal_type(value_to_match, expected_text="tuple[Literal[3], complex] | tuple[Literal[3], float] | Sequence[Any]")
case "hi", *f1:
reveal_type(f1, expected_text="list[str] | list[Any]")
reveal_type(value_to_match, expected_text="List[str] | Any")
reveal_type(value_to_match, expected_text="List[str] | Sequence[Any]")
case *g1, 3j:
reveal_type(g1, expected_text="list[complex] | list[int | str | float] | list[Any]")
reveal_type(value_to_match, expected_text="tuple[complex, complex] | Tuple[int, str, float, complex] | Any")
reveal_type(value_to_match, expected_text="tuple[complex, complex] | Tuple[int, str, float, complex] | Sequence[Any]")
case *h1, "hi":
reveal_type(h1, expected_text="list[str] | list[Any]")
reveal_type(value_to_match, expected_text="List[str] | Any")
reveal_type(value_to_match, expected_text="List[str] | Sequence[Any]")
class SupportsLessThan(Protocol):
@ -238,56 +260,73 @@ def test_exceptions(seq: Union[str, bytes, bytearray]):
reveal_type(y, expected_text="Never")
return seq
def test_object(seq: object):
def test_object1(seq: object):
match seq:
case (a1, a2) as a3:
reveal_type(a1, expected_text="object")
reveal_type(a2, expected_text="object")
reveal_type(a3, expected_text="Sequence[object]")
reveal_type(seq, expected_text="Sequence[object]")
reveal_type(a1, expected_text="Unknown")
reveal_type(a2, expected_text="Unknown")
reveal_type(a3, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case (*b1, b2) as b3:
reveal_type(b1, expected_text="list[object]")
reveal_type(b2, expected_text="object")
reveal_type(b3, expected_text="Sequence[object]")
reveal_type(seq, expected_text="Sequence[object]")
reveal_type(b1, expected_text="list[Unknown]")
reveal_type(b2, expected_text="Unknown")
reveal_type(b3, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case (c1, *c2) as c3:
reveal_type(c1, expected_text="object")
reveal_type(c2, expected_text="list[object]")
reveal_type(c3, expected_text="Sequence[object]")
reveal_type(seq, expected_text="Sequence[object]")
reveal_type(c1, expected_text="Unknown")
reveal_type(c2, expected_text="list[Unknown]")
reveal_type(c3, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case (d1, *d2, d3) as d4:
reveal_type(d1, expected_text="object")
reveal_type(d2, expected_text="list[object]")
reveal_type(d3, expected_text="object")
reveal_type(d4, expected_text="Sequence[object]")
reveal_type(seq, expected_text="Sequence[object]")
reveal_type(d1, expected_text="Unknown")
reveal_type(d2, expected_text="list[Unknown]")
reveal_type(d3, expected_text="Unknown")
reveal_type(d4, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case (3, *e1) as e2:
reveal_type(e1, expected_text="list[object]")
reveal_type(e2, expected_text="Sequence[object | int]")
reveal_type(seq, expected_text="Sequence[object | int]")
reveal_type(e1, expected_text="list[Unknown]")
reveal_type(e2, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case ("hi", *f1) as f2:
reveal_type(f1, expected_text="list[object]")
reveal_type(f2, expected_text="Sequence[object | str]")
reveal_type(seq, expected_text="Sequence[object | str]")
reveal_type(f1, expected_text="list[Unknown]")
reveal_type(f2, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case (*g1, "hi") as g2:
reveal_type(g1, expected_text="list[object]")
reveal_type(g2, expected_text="Sequence[object | str]")
reveal_type(seq, expected_text="Sequence[object | str]")
reveal_type(g1, expected_text="list[Unknown]")
reveal_type(g2, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case [1, "hi", True] as h1:
reveal_type(h1, expected_text="Sequence[int | str | bool]")
reveal_type(seq, expected_text="Sequence[int | str | bool]")
case [1, i1] as i2:
reveal_type(i1, expected_text="object")
reveal_type(i2, expected_text="Sequence[object | int]")
reveal_type(seq, expected_text="Sequence[object | int]")
reveal_type(i1, expected_text="Unknown")
reveal_type(i2, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
def test_object2(value_to_match: object):
match value_to_match:
case [*a1]:
reveal_type(a1, expected_text="list[Unknown]")
case b1:
reveal_type(b1, expected_text="object")
def test_sequence(value_to_match: Sequence[Any]):
match value_to_match:
case [*a1]:
reveal_type(a1, expected_text="list[Any]")
case b1:
reveal_type(b1, expected_text="Never")
_T = TypeVar('_T')