Added support for sequence pattern match type narrowing when the subject type is a simple "object".

This commit is contained in:
Eric Traut 2021-10-18 22:02:01 -07:00
parent b9a96bbe1f
commit 68591ca566
3 changed files with 195 additions and 54 deletions

View File

@ -76,8 +76,9 @@ const classPatternSpecialCases = [
interface SequencePatternInfo {
subtype: Type;
entryTypes: Type[];
isIndeterminateLength: boolean;
isTuple: boolean;
isIndeterminateLength?: boolean;
isTuple?: boolean;
isObject?: boolean;
}
interface MappingPatternInfo {
@ -143,6 +144,7 @@ function narrowTypeBasedOnSequencePattern(
}
let sequenceInfo = getSequencePatternInfo(evaluator, type, pattern.entries.length, pattern.starEntryIndex);
// Further narrow based on pattern entry types.
sequenceInfo = sequenceInfo.filter((entry) => {
let isPlausibleMatch = true;
@ -156,7 +158,9 @@ function narrowTypeBasedOnSequencePattern(
entry,
index,
pattern.entries.length,
pattern.starEntryIndex
pattern.starEntryIndex,
/* unpackStarEntry */ true,
/* isSubjectObject */ false
);
const narrowedEntryType = narrowTypeBasedOnPattern(
@ -165,6 +169,7 @@ function narrowTypeBasedOnSequencePattern(
sequenceEntry,
/* isPositiveTest */ true
);
if (index === pattern.starEntryIndex) {
if (
isClassInstance(narrowedEntryType) &&
@ -174,24 +179,41 @@ function narrowTypeBasedOnSequencePattern(
) {
narrowedEntryTypes.push(...narrowedEntryType.tupleTypeArguments);
} else {
narrowedEntryTypes.push(narrowedEntryType);
canNarrowTuple = false;
}
} else {
narrowedEntryTypes.push(narrowedEntryType);
}
if (isNever(narrowedEntryType)) {
isPlausibleMatch = false;
if (isNever(narrowedEntryType)) {
isPlausibleMatch = false;
}
}
});
// If this is a tuple, we can narrow it to a specific tuple type.
// Other sequences cannot be narrowed because we don't know if they
// are immutable (covariant).
if (isPlausibleMatch && canNarrowTuple) {
const tupleClassType = evaluator.getBuiltInType(pattern, 'tuple');
if (tupleClassType && isInstantiableClass(tupleClassType)) {
entry.subtype = ClassType.cloneAsInstance(specializeTupleClass(tupleClassType, narrowedEntryTypes));
if (isPlausibleMatch) {
// If this is a tuple, we can narrow it to a specific tuple type.
// Other sequences cannot be narrowed because we don't know if they
// are immutable (covariant).
if (canNarrowTuple) {
const tupleClassType = evaluator.getBuiltInType(pattern, 'tuple');
if (tupleClassType && isInstantiableClass(tupleClassType)) {
entry.subtype = ClassType.cloneAsInstance(specializeTupleClass(tupleClassType, narrowedEntryTypes));
}
}
// If this is an object, we can narrow it to a specific Sequence type.
if (entry.isObject) {
const sequenceType = evaluator.getTypingType(pattern, 'Sequence');
if (sequenceType && isInstantiableClass(sequenceType)) {
entry.subtype = ClassType.cloneAsInstance(
ClassType.cloneForSpecialization(
sequenceType,
[stripLiteralValue(combineTypes(narrowedEntryTypes))],
/* isTypeArgumentExplicit */ true
)
);
}
}
}
@ -694,9 +716,21 @@ function getSequencePatternInfo(
subtype,
entryTypes: [concreteSubtype],
isIndeterminateLength: true,
isTuple: false,
});
} else if (isClassInstance(concreteSubtype)) {
return;
}
if (isClassInstance(concreteSubtype)) {
if (ClassType.isBuiltIn(concreteSubtype, 'object')) {
sequenceInfo.push({
subtype,
entryTypes: [convertToInstance(concreteSubtype)],
isIndeterminateLength: true,
isObject: true,
});
return;
}
for (const mroClass of concreteSubtype.details.mro) {
if (!isInstantiableClass(mroClass)) {
break;
@ -724,6 +758,7 @@ function getSequencePatternInfo(
if (mroClassToSpecialize) {
const specializedSequence = partiallySpecializeType(mroClassToSpecialize, concreteSubtype) as ClassType;
if (isTupleClass(specializedSequence)) {
if (specializedSequence.tupleTypeArguments) {
if (isOpenEndedTupleClass(specializedSequence)) {
@ -757,7 +792,6 @@ function getSequencePatternInfo(
: UnknownType.create(),
],
isIndeterminateLength: true,
isTuple: false,
});
}
}
@ -773,49 +807,56 @@ function getTypeForPatternSequenceEntry(
sequenceInfo: SequencePatternInfo,
entryIndex: number,
entryCount: number,
starEntryIndex: number | undefined
starEntryIndex: number | undefined,
unpackStarEntry: boolean,
isSubjectObject: boolean
): Type {
if (sequenceInfo.isIndeterminateLength) {
if (starEntryIndex === entryIndex) {
const listInstanceType = convertToInstance(evaluator.getBuiltInType(node, 'list'));
if (isClassInstance(listInstanceType)) {
return ClassType.cloneForSpecialization(
listInstanceType,
[sequenceInfo.entryTypes[0]],
/* isTypeArgumentExplicit */ true
);
} else {
return UnknownType.create();
let entryType = sequenceInfo.entryTypes[0];
// If the subject is typed as an "object", then the star entry
// is simply a list[object]. Without this special case, the list
// will be typed based on the union of all elements in the sequence.
if (isSubjectObject) {
const objectType = evaluator.getBuiltInObject(node, 'object');
if (objectType && isClassInstance(objectType)) {
entryType = objectType;
}
} else {
return sequenceInfo.entryTypes[0];
}
} else if (starEntryIndex === undefined || entryIndex < starEntryIndex) {
if (!unpackStarEntry && entryIndex === starEntryIndex && !isNever(entryType)) {
entryType = wrapTypeInList(evaluator, node, entryType);
}
return entryType;
}
if (starEntryIndex === undefined || entryIndex < starEntryIndex) {
return sequenceInfo.entryTypes[entryIndex];
} else if (entryIndex === starEntryIndex) {
}
if (entryIndex === starEntryIndex) {
// Create a list out of the entries that map to the star entry.
// Note that we strip literal types here.
const starEntryTypes = sequenceInfo.entryTypes
.slice(starEntryIndex, starEntryIndex + sequenceInfo.entryTypes.length - entryCount + 1)
.map((type) => stripLiteralValue(type));
const listInstanceType = convertToInstance(evaluator.getBuiltInType(node, 'list'));
if (isClassInstance(listInstanceType)) {
return ClassType.cloneForSpecialization(
listInstanceType,
[combineTypes(starEntryTypes)],
/* isTypeArgumentExplicit */ true
);
} else {
return UnknownType.create();
let entryType = combineTypes(starEntryTypes);
if (!unpackStarEntry) {
entryType = wrapTypeInList(evaluator, node, entryType);
}
} else {
// The entry index is past the index of the star entry, so we need
// to index from the end of the sequence rather than the start.
const itemIndex = sequenceInfo.entryTypes.length - (entryCount - entryIndex);
assert(itemIndex >= 0 && itemIndex < sequenceInfo.entryTypes.length);
return sequenceInfo.entryTypes[itemIndex];
return entryType;
}
// The entry index is past the index of the star entry, so we need
// to index from the end of the sequence rather than the start.
const itemIndex = sequenceInfo.entryTypes.length - (entryCount - entryIndex);
assert(itemIndex >= 0 && itemIndex < sequenceInfo.entryTypes.length);
return sequenceInfo.entryTypes[itemIndex];
}
// Recursively assigns the specified type to the pattern and any capture
@ -824,6 +865,7 @@ export function assignTypeToPatternTargets(
evaluator: TypeEvaluator,
type: Type,
isTypeIncomplete: boolean,
isSubjectObject: boolean,
pattern: PatternAtomNode
) {
// Further narrow the type based on this pattern.
@ -847,12 +889,14 @@ export function assignTypeToPatternTargets(
info,
index,
pattern.entries.length,
pattern.starEntryIndex
pattern.starEntryIndex,
/* unpackStarEntry */ false,
isSubjectObject
)
)
);
assignTypeToPatternTargets(evaluator, entryType, isTypeIncomplete, entry);
assignTypeToPatternTargets(evaluator, entryType, isTypeIncomplete, /* isSubjectObject */ false, entry);
});
break;
}
@ -863,7 +907,7 @@ export function assignTypeToPatternTargets(
}
pattern.orPatterns.forEach((orPattern) => {
assignTypeToPatternTargets(evaluator, type, isTypeIncomplete, orPattern);
assignTypeToPatternTargets(evaluator, type, isTypeIncomplete, isSubjectObject, orPattern);
// OR patterns are evaluated left to right, so we can narrow
// the type as we go.
@ -948,8 +992,20 @@ export function assignTypeToPatternTargets(
const valueType = combineTypes(valueTypes);
if (mappingEntry.nodeType === ParseNodeType.PatternMappingKeyEntry) {
assignTypeToPatternTargets(evaluator, keyType, isTypeIncomplete, mappingEntry.keyPattern);
assignTypeToPatternTargets(evaluator, valueType, isTypeIncomplete, mappingEntry.valuePattern);
assignTypeToPatternTargets(
evaluator,
keyType,
isTypeIncomplete,
/* isSubjectObject */ false,
mappingEntry.keyPattern
);
assignTypeToPatternTargets(
evaluator,
valueType,
isTypeIncomplete,
/* isSubjectObject */ false,
mappingEntry.valuePattern
);
} else if (mappingEntry.nodeType === ParseNodeType.PatternMappingExpandEntry) {
const dictClass = evaluator.getBuiltInType(pattern, 'dict');
const strType = evaluator.getBuiltInObject(pattern, 'str');
@ -1019,7 +1075,13 @@ export function assignTypeToPatternTargets(
});
pattern.arguments.forEach((arg, index) => {
assignTypeToPatternTargets(evaluator, combineTypes(argTypes[index]), isTypeIncomplete, arg.pattern);
assignTypeToPatternTargets(
evaluator,
combineTypes(argTypes[index]),
isTypeIncomplete,
/* isSubjectObject */ false,
arg.pattern
);
});
break;
}
@ -1032,3 +1094,16 @@ export function assignTypeToPatternTargets(
}
}
}
function wrapTypeInList(evaluator: TypeEvaluator, node: ParseNode, type: Type): Type {
if (isNever(type)) {
return type;
}
const listObjectType = convertToInstance(evaluator.getBuiltInObject(node, 'list'));
if (listObjectType && isClassInstance(listObjectType)) {
return ClassType.cloneForSpecialization(listObjectType, [type], /* isTypeArgumentExplicit */ true);
}
return UnknownType.create();
}

View File

@ -14185,6 +14185,14 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}
}
// Determine if the pre-narrowed subject type contains an object.
let subjectIsObject = false;
doForEachSubtype(makeTopLevelTypeVarsConcrete(subjectType), (subtype) => {
if (isClassInstance(subtype) && ClassType.isBuiltIn(subtype, 'object')) {
subjectIsObject = true;
}
});
// Apply positive narrowing for the current case statement.
subjectType = narrowTypeBasedOnPattern(
evaluatorInterface,
@ -14192,7 +14200,14 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
node.pattern,
/* isPositiveTest */ true
);
assignTypeToPatternTargets(evaluatorInterface, subjectType, !!subjectTypeResult.isIncomplete, node.pattern);
assignTypeToPatternTargets(
evaluatorInterface,
subjectType,
!!subjectTypeResult.isIncomplete,
subjectIsObject,
node.pattern
);
writeTypeCache(node, subjectType, !!subjectTypeResult.isIncomplete);
}

View File

@ -165,7 +165,7 @@ def test_union(value_to_match: Union[Tuple[complex, complex], Tuple[int, str, fl
case d1, *d2, d3:
t_d1: Literal["complex | int | str | float | Any"] = reveal_type(d1)
t_d2: Literal["list[Unknown] | list[str | float] | list[str] | list[float] | list[Any]"] = reveal_type(d2)
t_d2: Literal["list[str | float] | list[str] | list[float] | list[Any]"] = reveal_type(d2)
t_d3: Literal["complex | str | float | Any"] = reveal_type(d3)
t_v4: Literal["Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Any"] = reveal_type(value_to_match)
@ -229,3 +229,54 @@ def test_exceptions(seq: Union[str, bytes, bytearray]):
t_v1: Literal["Never"] = reveal_type(x)
t_v2: Literal["Never"] = reveal_type(y)
return seq
def test_object(seq: object):
match seq:
case (a1, a2) as a3:
t_a1: Literal["object"] = reveal_type(a1)
t_a2: Literal["object"] = reveal_type(a2)
t_a3: Literal["Sequence[object]"] = reveal_type(a3)
t_va: Literal["Sequence[object]"] = reveal_type(seq)
case (*b1, b2) as b3:
t_b1: Literal["list[object]"] = reveal_type(b1)
t_b2: Literal["object"] = reveal_type(b2)
t_b3: Literal["Sequence[object]"] = reveal_type(b3)
t_vb: Literal["Sequence[object]"] = reveal_type(seq)
case (c1, *c2) as c3:
t_c1: Literal["object"] = reveal_type(c1)
t_c2: Literal["list[object]"] = reveal_type(c2)
t_c3: Literal["Sequence[object]"] = reveal_type(c3)
t_vc: Literal["Sequence[object]"] = reveal_type(seq)
case (d1, *d2, d3) as d4:
t_d1: Literal["object"] = reveal_type(d1)
t_d2: Literal["list[object]"] = reveal_type(d2)
t_d3: Literal["object"] = reveal_type(d3)
t_d4: Literal["Sequence[object]"] = reveal_type(d4)
t_vd: Literal["Sequence[object]"] = reveal_type(seq)
case (3, *e1) as e2:
t_e1: Literal["list[object]"] = reveal_type(e1)
t_e2: Literal["Sequence[object | int]"] = reveal_type(e2)
t_ve: Literal["Sequence[object | int]"] = reveal_type(seq)
case ("hi", *f1) as f2:
t_f1: Literal["list[object]"] = reveal_type(f1)
t_f2: Literal["Sequence[object | str]"] = reveal_type(f2)
t_vf: Literal["Sequence[object | str]"] = reveal_type(seq)
case (*g1, "hi") as g2:
t_g1: Literal["list[object]"] = reveal_type(g1)
t_g2: Literal["Sequence[object | str]"] = reveal_type(g2)
t_vg: Literal["Sequence[object | str]"] = reveal_type(seq)
case [1, "hi", True] as h1:
t_h1: Literal["Sequence[int | str | bool]"] = reveal_type(h1)
t_vh: Literal["Sequence[int | str | bool]"] = reveal_type(seq)
case [1, i1] as i2:
t_i1: Literal["object"] = reveal_type(i1)
t_i2: Literal["Sequence[object | int]"] = reveal_type(i2)
t_vi: Literal["Sequence[object | int]"] = reveal_type(seq)