Improved type narrowing in the fall-through case for sequence patterns when the pattern includes a star pattern and the subject type is a tuple with an indeterminate entry. (#8323)

This commit is contained in:
Eric Traut 2024-07-06 14:44:46 -07:00 committed by GitHub
parent 228bd84c54
commit 4c001a139c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 13 deletions

View File

@ -226,12 +226,17 @@ function narrowTypeBasedOnSequencePattern(
canNarrowTuple = false;
}
if (
isClassInstance(entry.subtype) &&
entry.subtype.tupleTypeArguments &&
entry.subtype.tupleTypeArguments.some((typeArg) => typeArg.isUnbounded)
) {
canNarrowTuple = false;
if (isClassInstance(entry.subtype) && entry.subtype.tupleTypeArguments) {
const unboundedIndex = entry.subtype.tupleTypeArguments.findIndex((typeArg) => typeArg.isUnbounded);
if (unboundedIndex >= 0) {
// If the pattern includes a "star" entry that aligns exactly with
// the corresponding unbounded entry in the tuple, we can narrow
// the tuple type.
if (pattern.starEntryIndex === undefined || pattern.starEntryIndex !== unboundedIndex) {
canNarrowTuple = false;
}
}
}
}
@ -1319,9 +1324,13 @@ function getSequencePatternInfo(
(t) => t.isUnbounded || isUnpackedVariadicTypeVar(t.type)
);
let tupleDeterminateEntryCount = typeArgs.length;
// If the tuple contains an indeterminate entry, expand or remove that
// entry to match the length of the pattern if possible.
if (tupleIndeterminateIndex >= 0) {
tupleDeterminateEntryCount--;
while (typeArgs.length < patternEntryCount) {
typeArgs.splice(tupleIndeterminateIndex, 0, typeArgs[tupleIndeterminateIndex]);
}
@ -1351,7 +1360,16 @@ function getSequencePatternInfo(
if (typeArgs.length === patternEntryCount) {
let isDefiniteNoMatch = false;
let isPotentialNoMatch = tupleIndeterminateIndex >= 0;
if (patternStarEntryIndex !== undefined && patternEntryCount === 1) {
// If the pattern includes a "star entry" and the tuple includes an
// indeterminate-length entry that aligns to the star entry, we can
// assume it will always match.
if (
patternStarEntryIndex !== undefined &&
tupleIndeterminateIndex >= 0 &&
pattern.entries.length - 1 === tupleDeterminateEntryCount &&
patternStarEntryIndex === tupleIndeterminateIndex
) {
isPotentialNoMatch = false;
}

View File

@ -72,7 +72,8 @@ _T_co = TypeVar("_T_co", covariant=True)
class SeqProto(Protocol[_T_co]):
def __reversed__(self) -> Iterator[_T_co]: ...
def __reversed__(self) -> Iterator[_T_co]:
...
def test_protocol(value_to_match: SeqProto[str]):
@ -277,8 +278,11 @@ def test_union(
class SupportsLessThan(Protocol):
def __lt__(self, __other: Any) -> bool: ...
def __le__(self, __other: Any) -> bool: ...
def __lt__(self, __other: Any) -> bool:
...
def __le__(self, __other: Any) -> bool:
...
SupportsLessThanT = TypeVar("SupportsLessThanT", bound=SupportsLessThan)
@ -397,10 +401,12 @@ class A(Generic[_T]):
a: _T
class B: ...
class B:
...
class C: ...
class C:
...
AAlias = A
@ -524,7 +530,7 @@ def test_tuple_with_subpattern(
reveal_type(b, expected_text="str")
def test_unbounded_tuple(
def test_unbounded_tuple1(
subj: tuple[int] | tuple[str, str] | tuple[int, Unpack[tuple[str, ...]], complex],
):
match subj:
@ -569,3 +575,13 @@ def test_unbounded_tuple_4(subj: tuple[str, ...]):
reveal_type(subj, expected_text="tuple[str]")
case x:
reveal_type(subj, expected_text="tuple[str, ...]")
def test_unbounded_tuple_5(subj: tuple[int, Unpack[tuple[str, ...]]]):
match subj:
case x, *rest:
reveal_type(subj, expected_text="tuple[int, *tuple[str, ...]]")
reveal_type(x, expected_text="int")
reveal_type(rest, expected_text="list[str]")
case x:
reveal_type(x, expected_text="Never")