mirror of
https://github.com/microsoft/pyright.git
synced 2024-10-07 13:29:17 +03:00
Added support for sequence pattern match type narrowing when the subject type is a simple "object".
This commit is contained in:
parent
b9a96bbe1f
commit
68591ca566
@ -76,8 +76,9 @@ const classPatternSpecialCases = [
|
||||
interface SequencePatternInfo {
|
||||
subtype: Type;
|
||||
entryTypes: Type[];
|
||||
isIndeterminateLength: boolean;
|
||||
isTuple: boolean;
|
||||
isIndeterminateLength?: boolean;
|
||||
isTuple?: boolean;
|
||||
isObject?: boolean;
|
||||
}
|
||||
|
||||
interface MappingPatternInfo {
|
||||
@ -143,6 +144,7 @@ function narrowTypeBasedOnSequencePattern(
|
||||
}
|
||||
|
||||
let sequenceInfo = getSequencePatternInfo(evaluator, type, pattern.entries.length, pattern.starEntryIndex);
|
||||
|
||||
// Further narrow based on pattern entry types.
|
||||
sequenceInfo = sequenceInfo.filter((entry) => {
|
||||
let isPlausibleMatch = true;
|
||||
@ -156,7 +158,9 @@ function narrowTypeBasedOnSequencePattern(
|
||||
entry,
|
||||
index,
|
||||
pattern.entries.length,
|
||||
pattern.starEntryIndex
|
||||
pattern.starEntryIndex,
|
||||
/* unpackStarEntry */ true,
|
||||
/* isSubjectObject */ false
|
||||
);
|
||||
|
||||
const narrowedEntryType = narrowTypeBasedOnPattern(
|
||||
@ -165,6 +169,7 @@ function narrowTypeBasedOnSequencePattern(
|
||||
sequenceEntry,
|
||||
/* isPositiveTest */ true
|
||||
);
|
||||
|
||||
if (index === pattern.starEntryIndex) {
|
||||
if (
|
||||
isClassInstance(narrowedEntryType) &&
|
||||
@ -174,24 +179,41 @@ function narrowTypeBasedOnSequencePattern(
|
||||
) {
|
||||
narrowedEntryTypes.push(...narrowedEntryType.tupleTypeArguments);
|
||||
} else {
|
||||
narrowedEntryTypes.push(narrowedEntryType);
|
||||
canNarrowTuple = false;
|
||||
}
|
||||
} else {
|
||||
narrowedEntryTypes.push(narrowedEntryType);
|
||||
}
|
||||
|
||||
if (isNever(narrowedEntryType)) {
|
||||
isPlausibleMatch = false;
|
||||
if (isNever(narrowedEntryType)) {
|
||||
isPlausibleMatch = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// If this is a tuple, we can narrow it to a specific tuple type.
|
||||
// Other sequences cannot be narrowed because we don't know if they
|
||||
// are immutable (covariant).
|
||||
if (isPlausibleMatch && canNarrowTuple) {
|
||||
const tupleClassType = evaluator.getBuiltInType(pattern, 'tuple');
|
||||
if (tupleClassType && isInstantiableClass(tupleClassType)) {
|
||||
entry.subtype = ClassType.cloneAsInstance(specializeTupleClass(tupleClassType, narrowedEntryTypes));
|
||||
if (isPlausibleMatch) {
|
||||
// If this is a tuple, we can narrow it to a specific tuple type.
|
||||
// Other sequences cannot be narrowed because we don't know if they
|
||||
// are immutable (covariant).
|
||||
if (canNarrowTuple) {
|
||||
const tupleClassType = evaluator.getBuiltInType(pattern, 'tuple');
|
||||
if (tupleClassType && isInstantiableClass(tupleClassType)) {
|
||||
entry.subtype = ClassType.cloneAsInstance(specializeTupleClass(tupleClassType, narrowedEntryTypes));
|
||||
}
|
||||
}
|
||||
|
||||
// If this is an object, we can narrow it to a specific Sequence type.
|
||||
if (entry.isObject) {
|
||||
const sequenceType = evaluator.getTypingType(pattern, 'Sequence');
|
||||
if (sequenceType && isInstantiableClass(sequenceType)) {
|
||||
entry.subtype = ClassType.cloneAsInstance(
|
||||
ClassType.cloneForSpecialization(
|
||||
sequenceType,
|
||||
[stripLiteralValue(combineTypes(narrowedEntryTypes))],
|
||||
/* isTypeArgumentExplicit */ true
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -694,9 +716,21 @@ function getSequencePatternInfo(
|
||||
subtype,
|
||||
entryTypes: [concreteSubtype],
|
||||
isIndeterminateLength: true,
|
||||
isTuple: false,
|
||||
});
|
||||
} else if (isClassInstance(concreteSubtype)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (isClassInstance(concreteSubtype)) {
|
||||
if (ClassType.isBuiltIn(concreteSubtype, 'object')) {
|
||||
sequenceInfo.push({
|
||||
subtype,
|
||||
entryTypes: [convertToInstance(concreteSubtype)],
|
||||
isIndeterminateLength: true,
|
||||
isObject: true,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
for (const mroClass of concreteSubtype.details.mro) {
|
||||
if (!isInstantiableClass(mroClass)) {
|
||||
break;
|
||||
@ -724,6 +758,7 @@ function getSequencePatternInfo(
|
||||
|
||||
if (mroClassToSpecialize) {
|
||||
const specializedSequence = partiallySpecializeType(mroClassToSpecialize, concreteSubtype) as ClassType;
|
||||
|
||||
if (isTupleClass(specializedSequence)) {
|
||||
if (specializedSequence.tupleTypeArguments) {
|
||||
if (isOpenEndedTupleClass(specializedSequence)) {
|
||||
@ -757,7 +792,6 @@ function getSequencePatternInfo(
|
||||
: UnknownType.create(),
|
||||
],
|
||||
isIndeterminateLength: true,
|
||||
isTuple: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -773,49 +807,56 @@ function getTypeForPatternSequenceEntry(
|
||||
sequenceInfo: SequencePatternInfo,
|
||||
entryIndex: number,
|
||||
entryCount: number,
|
||||
starEntryIndex: number | undefined
|
||||
starEntryIndex: number | undefined,
|
||||
unpackStarEntry: boolean,
|
||||
isSubjectObject: boolean
|
||||
): Type {
|
||||
if (sequenceInfo.isIndeterminateLength) {
|
||||
if (starEntryIndex === entryIndex) {
|
||||
const listInstanceType = convertToInstance(evaluator.getBuiltInType(node, 'list'));
|
||||
if (isClassInstance(listInstanceType)) {
|
||||
return ClassType.cloneForSpecialization(
|
||||
listInstanceType,
|
||||
[sequenceInfo.entryTypes[0]],
|
||||
/* isTypeArgumentExplicit */ true
|
||||
);
|
||||
} else {
|
||||
return UnknownType.create();
|
||||
let entryType = sequenceInfo.entryTypes[0];
|
||||
|
||||
// If the subject is typed as an "object", then the star entry
|
||||
// is simply a list[object]. Without this special case, the list
|
||||
// will be typed based on the union of all elements in the sequence.
|
||||
if (isSubjectObject) {
|
||||
const objectType = evaluator.getBuiltInObject(node, 'object');
|
||||
if (objectType && isClassInstance(objectType)) {
|
||||
entryType = objectType;
|
||||
}
|
||||
} else {
|
||||
return sequenceInfo.entryTypes[0];
|
||||
}
|
||||
} else if (starEntryIndex === undefined || entryIndex < starEntryIndex) {
|
||||
|
||||
if (!unpackStarEntry && entryIndex === starEntryIndex && !isNever(entryType)) {
|
||||
entryType = wrapTypeInList(evaluator, node, entryType);
|
||||
}
|
||||
|
||||
return entryType;
|
||||
}
|
||||
|
||||
if (starEntryIndex === undefined || entryIndex < starEntryIndex) {
|
||||
return sequenceInfo.entryTypes[entryIndex];
|
||||
} else if (entryIndex === starEntryIndex) {
|
||||
}
|
||||
|
||||
if (entryIndex === starEntryIndex) {
|
||||
// Create a list out of the entries that map to the star entry.
|
||||
// Note that we strip literal types here.
|
||||
const starEntryTypes = sequenceInfo.entryTypes
|
||||
.slice(starEntryIndex, starEntryIndex + sequenceInfo.entryTypes.length - entryCount + 1)
|
||||
.map((type) => stripLiteralValue(type));
|
||||
|
||||
const listInstanceType = convertToInstance(evaluator.getBuiltInType(node, 'list'));
|
||||
if (isClassInstance(listInstanceType)) {
|
||||
return ClassType.cloneForSpecialization(
|
||||
listInstanceType,
|
||||
[combineTypes(starEntryTypes)],
|
||||
/* isTypeArgumentExplicit */ true
|
||||
);
|
||||
} else {
|
||||
return UnknownType.create();
|
||||
let entryType = combineTypes(starEntryTypes);
|
||||
|
||||
if (!unpackStarEntry) {
|
||||
entryType = wrapTypeInList(evaluator, node, entryType);
|
||||
}
|
||||
} else {
|
||||
// The entry index is past the index of the star entry, so we need
|
||||
// to index from the end of the sequence rather than the start.
|
||||
const itemIndex = sequenceInfo.entryTypes.length - (entryCount - entryIndex);
|
||||
assert(itemIndex >= 0 && itemIndex < sequenceInfo.entryTypes.length);
|
||||
return sequenceInfo.entryTypes[itemIndex];
|
||||
|
||||
return entryType;
|
||||
}
|
||||
|
||||
// The entry index is past the index of the star entry, so we need
|
||||
// to index from the end of the sequence rather than the start.
|
||||
const itemIndex = sequenceInfo.entryTypes.length - (entryCount - entryIndex);
|
||||
assert(itemIndex >= 0 && itemIndex < sequenceInfo.entryTypes.length);
|
||||
|
||||
return sequenceInfo.entryTypes[itemIndex];
|
||||
}
|
||||
|
||||
// Recursively assigns the specified type to the pattern and any capture
|
||||
@ -824,6 +865,7 @@ export function assignTypeToPatternTargets(
|
||||
evaluator: TypeEvaluator,
|
||||
type: Type,
|
||||
isTypeIncomplete: boolean,
|
||||
isSubjectObject: boolean,
|
||||
pattern: PatternAtomNode
|
||||
) {
|
||||
// Further narrow the type based on this pattern.
|
||||
@ -847,12 +889,14 @@ export function assignTypeToPatternTargets(
|
||||
info,
|
||||
index,
|
||||
pattern.entries.length,
|
||||
pattern.starEntryIndex
|
||||
pattern.starEntryIndex,
|
||||
/* unpackStarEntry */ false,
|
||||
isSubjectObject
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
assignTypeToPatternTargets(evaluator, entryType, isTypeIncomplete, entry);
|
||||
assignTypeToPatternTargets(evaluator, entryType, isTypeIncomplete, /* isSubjectObject */ false, entry);
|
||||
});
|
||||
break;
|
||||
}
|
||||
@ -863,7 +907,7 @@ export function assignTypeToPatternTargets(
|
||||
}
|
||||
|
||||
pattern.orPatterns.forEach((orPattern) => {
|
||||
assignTypeToPatternTargets(evaluator, type, isTypeIncomplete, orPattern);
|
||||
assignTypeToPatternTargets(evaluator, type, isTypeIncomplete, isSubjectObject, orPattern);
|
||||
|
||||
// OR patterns are evaluated left to right, so we can narrow
|
||||
// the type as we go.
|
||||
@ -948,8 +992,20 @@ export function assignTypeToPatternTargets(
|
||||
const valueType = combineTypes(valueTypes);
|
||||
|
||||
if (mappingEntry.nodeType === ParseNodeType.PatternMappingKeyEntry) {
|
||||
assignTypeToPatternTargets(evaluator, keyType, isTypeIncomplete, mappingEntry.keyPattern);
|
||||
assignTypeToPatternTargets(evaluator, valueType, isTypeIncomplete, mappingEntry.valuePattern);
|
||||
assignTypeToPatternTargets(
|
||||
evaluator,
|
||||
keyType,
|
||||
isTypeIncomplete,
|
||||
/* isSubjectObject */ false,
|
||||
mappingEntry.keyPattern
|
||||
);
|
||||
assignTypeToPatternTargets(
|
||||
evaluator,
|
||||
valueType,
|
||||
isTypeIncomplete,
|
||||
/* isSubjectObject */ false,
|
||||
mappingEntry.valuePattern
|
||||
);
|
||||
} else if (mappingEntry.nodeType === ParseNodeType.PatternMappingExpandEntry) {
|
||||
const dictClass = evaluator.getBuiltInType(pattern, 'dict');
|
||||
const strType = evaluator.getBuiltInObject(pattern, 'str');
|
||||
@ -1019,7 +1075,13 @@ export function assignTypeToPatternTargets(
|
||||
});
|
||||
|
||||
pattern.arguments.forEach((arg, index) => {
|
||||
assignTypeToPatternTargets(evaluator, combineTypes(argTypes[index]), isTypeIncomplete, arg.pattern);
|
||||
assignTypeToPatternTargets(
|
||||
evaluator,
|
||||
combineTypes(argTypes[index]),
|
||||
isTypeIncomplete,
|
||||
/* isSubjectObject */ false,
|
||||
arg.pattern
|
||||
);
|
||||
});
|
||||
break;
|
||||
}
|
||||
@ -1032,3 +1094,16 @@ export function assignTypeToPatternTargets(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function wrapTypeInList(evaluator: TypeEvaluator, node: ParseNode, type: Type): Type {
|
||||
if (isNever(type)) {
|
||||
return type;
|
||||
}
|
||||
|
||||
const listObjectType = convertToInstance(evaluator.getBuiltInObject(node, 'list'));
|
||||
if (listObjectType && isClassInstance(listObjectType)) {
|
||||
return ClassType.cloneForSpecialization(listObjectType, [type], /* isTypeArgumentExplicit */ true);
|
||||
}
|
||||
|
||||
return UnknownType.create();
|
||||
}
|
||||
|
@ -14185,6 +14185,14 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if the pre-narrowed subject type contains an object.
|
||||
let subjectIsObject = false;
|
||||
doForEachSubtype(makeTopLevelTypeVarsConcrete(subjectType), (subtype) => {
|
||||
if (isClassInstance(subtype) && ClassType.isBuiltIn(subtype, 'object')) {
|
||||
subjectIsObject = true;
|
||||
}
|
||||
});
|
||||
|
||||
// Apply positive narrowing for the current case statement.
|
||||
subjectType = narrowTypeBasedOnPattern(
|
||||
evaluatorInterface,
|
||||
@ -14192,7 +14200,14 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
|
||||
node.pattern,
|
||||
/* isPositiveTest */ true
|
||||
);
|
||||
assignTypeToPatternTargets(evaluatorInterface, subjectType, !!subjectTypeResult.isIncomplete, node.pattern);
|
||||
|
||||
assignTypeToPatternTargets(
|
||||
evaluatorInterface,
|
||||
subjectType,
|
||||
!!subjectTypeResult.isIncomplete,
|
||||
subjectIsObject,
|
||||
node.pattern
|
||||
);
|
||||
|
||||
writeTypeCache(node, subjectType, !!subjectTypeResult.isIncomplete);
|
||||
}
|
||||
|
@ -165,7 +165,7 @@ def test_union(value_to_match: Union[Tuple[complex, complex], Tuple[int, str, fl
|
||||
|
||||
case d1, *d2, d3:
|
||||
t_d1: Literal["complex | int | str | float | Any"] = reveal_type(d1)
|
||||
t_d2: Literal["list[Unknown] | list[str | float] | list[str] | list[float] | list[Any]"] = reveal_type(d2)
|
||||
t_d2: Literal["list[str | float] | list[str] | list[float] | list[Any]"] = reveal_type(d2)
|
||||
t_d3: Literal["complex | str | float | Any"] = reveal_type(d3)
|
||||
t_v4: Literal["Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Any"] = reveal_type(value_to_match)
|
||||
|
||||
@ -229,3 +229,54 @@ def test_exceptions(seq: Union[str, bytes, bytearray]):
|
||||
t_v1: Literal["Never"] = reveal_type(x)
|
||||
t_v2: Literal["Never"] = reveal_type(y)
|
||||
return seq
|
||||
|
||||
def test_object(seq: object):
|
||||
match seq:
|
||||
case (a1, a2) as a3:
|
||||
t_a1: Literal["object"] = reveal_type(a1)
|
||||
t_a2: Literal["object"] = reveal_type(a2)
|
||||
t_a3: Literal["Sequence[object]"] = reveal_type(a3)
|
||||
t_va: Literal["Sequence[object]"] = reveal_type(seq)
|
||||
|
||||
case (*b1, b2) as b3:
|
||||
t_b1: Literal["list[object]"] = reveal_type(b1)
|
||||
t_b2: Literal["object"] = reveal_type(b2)
|
||||
t_b3: Literal["Sequence[object]"] = reveal_type(b3)
|
||||
t_vb: Literal["Sequence[object]"] = reveal_type(seq)
|
||||
|
||||
case (c1, *c2) as c3:
|
||||
t_c1: Literal["object"] = reveal_type(c1)
|
||||
t_c2: Literal["list[object]"] = reveal_type(c2)
|
||||
t_c3: Literal["Sequence[object]"] = reveal_type(c3)
|
||||
t_vc: Literal["Sequence[object]"] = reveal_type(seq)
|
||||
|
||||
case (d1, *d2, d3) as d4:
|
||||
t_d1: Literal["object"] = reveal_type(d1)
|
||||
t_d2: Literal["list[object]"] = reveal_type(d2)
|
||||
t_d3: Literal["object"] = reveal_type(d3)
|
||||
t_d4: Literal["Sequence[object]"] = reveal_type(d4)
|
||||
t_vd: Literal["Sequence[object]"] = reveal_type(seq)
|
||||
|
||||
case (3, *e1) as e2:
|
||||
t_e1: Literal["list[object]"] = reveal_type(e1)
|
||||
t_e2: Literal["Sequence[object | int]"] = reveal_type(e2)
|
||||
t_ve: Literal["Sequence[object | int]"] = reveal_type(seq)
|
||||
|
||||
case ("hi", *f1) as f2:
|
||||
t_f1: Literal["list[object]"] = reveal_type(f1)
|
||||
t_f2: Literal["Sequence[object | str]"] = reveal_type(f2)
|
||||
t_vf: Literal["Sequence[object | str]"] = reveal_type(seq)
|
||||
|
||||
case (*g1, "hi") as g2:
|
||||
t_g1: Literal["list[object]"] = reveal_type(g1)
|
||||
t_g2: Literal["Sequence[object | str]"] = reveal_type(g2)
|
||||
t_vg: Literal["Sequence[object | str]"] = reveal_type(seq)
|
||||
|
||||
case [1, "hi", True] as h1:
|
||||
t_h1: Literal["Sequence[int | str | bool]"] = reveal_type(h1)
|
||||
t_vh: Literal["Sequence[int | str | bool]"] = reveal_type(seq)
|
||||
|
||||
case [1, i1] as i2:
|
||||
t_i1: Literal["object"] = reveal_type(i1)
|
||||
t_i2: Literal["Sequence[object | int]"] = reveal_type(i2)
|
||||
t_vi: Literal["Sequence[object | int]"] = reveal_type(seq)
|
||||
|
Loading…
Reference in New Issue
Block a user