Fixed bug that results in incorrect type narrowing in the negative (fall-through) case of a match expression when the subject expression is an unbounded tuple. This addresses #8219.

This commit is contained in:
Eric Traut 2024-06-25 13:56:02 +02:00
parent 2672441c22
commit 24f64f7d26
2 changed files with 137 additions and 52 deletions

View File

@ -219,9 +219,20 @@ function narrowTypeBasedOnSequencePattern(
let canNarrowTuple = entry.isTuple;
// Don't attempt to narrow tuples in the negative case if the subject
// contains indeterminate-length entries.
if (!isPositiveTest && entry.isIndeterminateLength) {
canNarrowTuple = false;
// contains indeterminate-length entries or the tuple is of indeterminate
// length.
if (!isPositiveTest) {
if (entry.isIndeterminateLength) {
canNarrowTuple = false;
}
if (
isClassInstance(entry.subtype) &&
entry.subtype.tupleTypeArguments &&
entry.subtype.tupleTypeArguments.some((typeArg) => typeArg.isUnbounded)
) {
canNarrowTuple = false;
}
}
// If the subject has an indeterminate length but the pattern does not accept

View File

@ -2,8 +2,22 @@
# described in PEP 634) that contain sequence patterns.
from enum import Enum
from typing import Any, Generic, Iterator, List, Literal, Protocol, Reversible, Sequence, Tuple, TypeVar, Union
from typing_extensions import Unpack # pyright: ignore[reportMissingModuleSource]
from typing import (
Any,
Generic,
Iterator,
List,
Literal,
Protocol,
Reversible,
Sequence,
Tuple,
TypeVar,
Union,
)
from typing_extensions import Unpack # pyright: ignore[reportMissingModuleSource]
def test_unknown(value_to_match):
match value_to_match:
@ -25,13 +39,13 @@ def test_unknown(value_to_match):
reveal_type(d1, expected_text="Unknown")
reveal_type(d2, expected_text="list[Unknown]")
reveal_type(d3, expected_text="Unknown")
case 3, *e1:
reveal_type(e1, expected_text="list[Unknown]")
case "hi", *f1:
reveal_type(f1, expected_text="list[Unknown]")
case *g1, "hi":
reveal_type(g1, expected_text="list[Unknown]")
@ -56,9 +70,11 @@ def test_reversible(value_to_match: Reversible[int]):
_T_co = TypeVar("_T_co", covariant=True)
class SeqProto(Protocol[_T_co]):
def __reversed__(self) -> Iterator[_T_co]: ...
def test_protocol(value_to_match: SeqProto[str]):
match value_to_match:
case [*a1]:
@ -67,7 +83,6 @@ def test_protocol(value_to_match: SeqProto[str]):
reveal_type(b1, expected_text="SeqProto[str]")
def test_list(value_to_match: List[str]):
match value_to_match:
case a1, a2:
@ -90,19 +105,20 @@ def test_list(value_to_match: List[str]):
reveal_type(d2, expected_text="list[str]")
reveal_type(d3, expected_text="str")
reveal_type(value_to_match, expected_text="List[str]")
case 3, *e1:
reveal_type(e1, expected_text="Never")
reveal_type(value_to_match, expected_text="Never")
case "hi", *f1:
reveal_type(f1, expected_text="list[str]")
reveal_type(value_to_match, expected_text="List[str]")
case *g1, "hi":
reveal_type(g1, expected_text="list[str]")
reveal_type(value_to_match, expected_text="List[str]")
def test_open_ended_tuple(value_to_match: Tuple[str, ...]):
match value_to_match:
case a1, a2:
@ -125,19 +141,20 @@ def test_open_ended_tuple(value_to_match: Tuple[str, ...]):
reveal_type(d2, expected_text="list[str]")
reveal_type(d3, expected_text="str")
reveal_type(value_to_match, expected_text="Tuple[str, ...]")
case 3, *e1:
reveal_type(e1, expected_text="Never")
reveal_type(value_to_match, expected_text="Never")
case "hi", *f1:
reveal_type(f1, expected_text="list[str]")
reveal_type(value_to_match, expected_text="Tuple[str, ...]")
case *g1, "hi":
reveal_type(g1, expected_text="list[str]")
reveal_type(value_to_match, expected_text="Tuple[str, ...]")
def test_definite_tuple(value_to_match: Tuple[int, str, float, complex]):
match value_to_match:
case a1, a2, a3, a4 if value_to_match[0] == 0:
@ -162,11 +179,11 @@ def test_definite_tuple(value_to_match: Tuple[int, str, float, complex]):
reveal_type(d2, expected_text="list[str | float]")
reveal_type(d3, expected_text="complex")
reveal_type(value_to_match, expected_text="Tuple[int, str, float, complex]")
case 3, *e1:
reveal_type(e1, expected_text="list[str | float | complex]")
reveal_type(value_to_match, expected_text="Tuple[int, str, float, complex]")
case "hi", *f1:
reveal_type(f1, expected_text="Never")
reveal_type(value_to_match, expected_text="Never")
@ -174,49 +191,86 @@ def test_definite_tuple(value_to_match: Tuple[int, str, float, complex]):
case *g1, 3j:
reveal_type(g1, expected_text="list[int | str | float]")
reveal_type(value_to_match, expected_text="Tuple[int, str, float, complex]")
case *h1, "hi":
reveal_type(h1, expected_text="Never")
reveal_type(value_to_match, expected_text="Never")
def test_union(value_to_match: Union[Tuple[complex, complex], Tuple[int, str, float, complex], List[str], Tuple[float, ...], Any]):
def test_union(
value_to_match: Union[
Tuple[complex, complex],
Tuple[int, str, float, complex],
List[str],
Tuple[float, ...],
Any,
],
):
match value_to_match:
case a1, a2, a3, a4 if value_to_match[0] == 0:
reveal_type(a1, expected_text="int | str | float | Any")
reveal_type(a2, expected_text="str | float | Any")
reveal_type(a3, expected_text="float | str | Any")
reveal_type(a4, expected_text="complex | str | float | Any")
reveal_type(value_to_match, expected_text="tuple[int, str, float, complex] | List[str] | tuple[float, float, float, float] | Sequence[Any]")
reveal_type(
value_to_match,
expected_text="tuple[int, str, float, complex] | List[str] | tuple[float, float, float, float] | Sequence[Any]",
)
case *b1, b2 if value_to_match[0] == 0:
reveal_type(b1, expected_text="list[complex] | list[int | str | float] | list[str] | list[float] | list[Any]")
reveal_type(
b1,
expected_text="list[complex] | list[int | str | float] | list[str] | list[float] | list[Any]",
)
reveal_type(b2, expected_text="complex | str | float | Any")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]")
reveal_type(
value_to_match,
expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]",
)
case c1, *c2 if value_to_match[0] == 0:
reveal_type(c1, expected_text="complex | int | str | float | Any")
reveal_type(c2, expected_text="list[complex] | list[str | float | complex] | list[str] | list[float] | list[Any]")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]")
reveal_type(
c2,
expected_text="list[complex] | list[str | float | complex] | list[str] | list[float] | list[Any]",
)
reveal_type(
value_to_match,
expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]",
)
case d1, *d2, d3 if value_to_match[0] == 0:
reveal_type(d1, expected_text="complex | int | str | float | Any")
reveal_type(d2, expected_text="list[Any] | list[str | float] | list[str] | list[float]")
reveal_type(
d2,
expected_text="list[Any] | list[str | float] | list[str] | list[float]",
)
reveal_type(d3, expected_text="complex | str | float | Any")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]")
reveal_type(
value_to_match,
expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]",
)
case 3, e1:
reveal_type(e1, expected_text="complex | float | Any")
reveal_type(value_to_match, expected_text="tuple[Literal[3], complex] | tuple[Literal[3], float] | Sequence[Any]")
reveal_type(
value_to_match,
expected_text="tuple[Literal[3], complex] | tuple[Literal[3], float] | Sequence[Any]",
)
case "hi", *f1:
reveal_type(f1, expected_text="list[str] | list[Any]")
reveal_type(value_to_match, expected_text="List[str] | Sequence[Any]")
case *g1, 3j:
reveal_type(g1, expected_text="list[complex] | list[int | str | float] | list[Any]")
reveal_type(value_to_match, expected_text="tuple[complex, complex] | Tuple[int, str, float, complex] | Sequence[Any]")
reveal_type(
g1, expected_text="list[complex] | list[int | str | float] | list[Any]"
)
reveal_type(
value_to_match,
expected_text="tuple[complex, complex] | Tuple[int, str, float, complex] | Sequence[Any]",
)
case *h1, "hi":
reveal_type(h1, expected_text="list[str] | list[Any]")
reveal_type(value_to_match, expected_text="List[str] | Sequence[Any]")
@ -226,6 +280,7 @@ class SupportsLessThan(Protocol):
def __lt__(self, __other: Any) -> bool: ...
def __le__(self, __other: Any) -> bool: ...
SupportsLessThanT = TypeVar("SupportsLessThanT", bound=SupportsLessThan)
@ -234,23 +289,23 @@ def sort(seq: List[SupportsLessThanT]) -> List[SupportsLessThanT]:
case [] | [_]:
reveal_type(seq, expected_text="List[SupportsLessThanT@sort]")
return seq
case [x, y] if x <= y:
reveal_type(seq, expected_text="List[SupportsLessThanT@sort]")
return seq
case [x, y]:
reveal_type(seq, expected_text="List[SupportsLessThanT@sort]")
return [y, x]
case [x, y, z] if x <= y <= z:
reveal_type(seq, expected_text="List[SupportsLessThanT@sort]")
return seq
case [x, y, z] if x > y > z:
reveal_type(seq, expected_text="List[SupportsLessThanT@sort]")
return [z, y, x]
case [p, *rest]:
a = sort([x for x in rest if x <= p])
b = sort([x for x in rest if p < x])
@ -266,6 +321,7 @@ def test_exceptions(seq: Union[str, bytes, bytearray]):
reveal_type(y, expected_text="Never")
return seq
def test_object1(seq: object):
match seq:
case (a1, a2) as a3:
@ -292,31 +348,32 @@ def test_object1(seq: object):
reveal_type(d3, expected_text="Unknown")
reveal_type(d4, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case (3, *e1) as e2:
reveal_type(e1, expected_text="list[Unknown]")
reveal_type(e2, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case ("hi", *f1) as f2:
case ("hi", *f1) as f2:
reveal_type(f1, expected_text="list[Unknown]")
reveal_type(f2, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case (*g1, "hi") as g2:
reveal_type(g1, expected_text="list[Unknown]")
reveal_type(g2, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
reveal_type(g2, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
case [1, "hi", True] as h1:
case [1, "hi", True] as h1:
reveal_type(h1, expected_text="Sequence[int | str | bool]")
reveal_type(seq, expected_text="Sequence[int | str | bool]")
case [1, i1] as i2:
reveal_type(i1, expected_text="Unknown")
reveal_type(i2, expected_text="Sequence[Unknown]")
reveal_type(i2, expected_text="Sequence[Unknown]")
reveal_type(seq, expected_text="Sequence[Unknown]")
def test_object2(value_to_match: object):
match value_to_match:
case [*a1]:
@ -333,21 +390,26 @@ def test_sequence(value_to_match: Sequence[Any]):
reveal_type(b1, expected_text="Never")
_T = TypeVar("_T")
_T = TypeVar('_T')
class A(Generic[_T]):
a: _T
class B: ...
class C: ...
AAlias = A
AInt = A[int]
BOrC = B | C
def test_illegal_type_alias(m: object):
match m:
case AAlias(a=i):
@ -363,9 +425,10 @@ def test_illegal_type_alias(m: object):
case BOrC(a=i):
pass
def test_negative_narrowing1(subj: tuple[Literal[0]] | tuple[Literal[1]]):
match subj:
case (1,*a) | (*a):
case (1, *a) | (*a):
reveal_type(subj, expected_text="tuple[Literal[1]] | tuple[Literal[0]]")
reveal_type(a, expected_text="list[Any] | list[int]")
@ -376,7 +439,7 @@ def test_negative_narrowing1(subj: tuple[Literal[0]] | tuple[Literal[1]]):
def test_negative_narrowing2(subj: tuple[int, ...]):
match subj:
case (1,*a):
case (1, *a):
reveal_type(subj, expected_text="tuple[int, ...]")
reveal_type(a, expected_text="list[int]")
@ -448,7 +511,7 @@ class MyEnum(Enum):
def test_tuple_with_subpattern(
subj: Literal[MyEnum.A]
| tuple[Literal[MyEnum.B], int]
| tuple[Literal[MyEnum.C], str]
| tuple[Literal[MyEnum.C], str],
):
match subj:
case MyEnum.A:
@ -462,7 +525,7 @@ def test_tuple_with_subpattern(
def test_unbounded_tuple(
subj: tuple[int] | tuple[str, str] | tuple[int, Unpack[tuple[str, ...]], complex]
subj: tuple[int] | tuple[str, str] | tuple[int, Unpack[tuple[str, ...]], complex],
):
match subj:
case (x,):
@ -489,9 +552,20 @@ def test_unbounded_tuple_2(subj: tuple[int, str, Unpack[tuple[range, ...]]]) ->
case [1, "", *ts2]:
reveal_type(ts2, expected_text="list[range]")
def test_unbounded_tuple_3(subj: tuple[int, ...]):
match subj:
case []:
return
case x:
reveal_type(x, expected_text="tuple[int, ...]")
def test_unbounded_tuple_4(subj: tuple[str, ...]):
match subj:
case x, "":
reveal_type(subj, expected_text="tuple[str, Literal['']]")
case (x,):
reveal_type(subj, expected_text="tuple[str]")
case x:
reveal_type(subj, expected_text="tuple[str, ...]")