Improved type narrowing for x == L pattern to include x is L, since they are equivalent. Likewise, extended x.y == L pattern to include x.y is L.

This commit is contained in:
Eric Traut 2022-01-05 15:19:08 -07:00
parent 5829dcbbdc
commit 4fe3f1f7c0
4 changed files with 85 additions and 61 deletions

View File

@ -150,9 +150,8 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `x is None` and `x is not None`
* `x == None` and `x != None`
* `type(x) is T` and `type(x) is not T`
* `x is E` and `x is not E` (where E is an enum value or True or False)
* `x == L` and `x != L` (where L is a literal expression)
* `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field with a literal type)
* `x == L` and `x != L` and `L == x` and `L != x` and `x is L` and `x is not L` (where L is a literal expression)
* `x.y == L` and `x.y != L` and `x.y is L` and `x.y is not L` (where L is a literal expression and x is a type that is distinguished by a field with a literal type)
* `x[K] == V` and `x[K] != V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type)
* `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I)
* `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I)

View File

@ -177,15 +177,11 @@ export function getTypeNarrowingCallback(
}
}
// Look for "X is Y" or "X is not Y" where Y is a an enum or False or True.
// Look for "X is <literal>" or "X is not <literal>"
if (isOrIsNotOperator) {
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
if (
isClassInstance(rightType) &&
(ClassType.isEnumClass(rightType) || ClassType.isBuiltIn(rightType, 'bool')) &&
rightType.literalValue !== undefined
) {
if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
return (type: Type) => {
return narrowTypeForLiteralComparison(
evaluator,
@ -219,6 +215,7 @@ export function getTypeNarrowingCallback(
}
}
// Look for <literal> == X or <literal> != X
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) {
const leftType = evaluator.getTypeOfExpression(testExpression.leftExpression).type;
if (isClassInstance(leftType) && leftType.literalValue !== undefined) {
@ -234,26 +231,6 @@ export function getTypeNarrowingCallback(
}
}
// Look for X.Y == <literal> or X.Y != <literal>
if (
testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression)
) {
const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
const memberName = testExpression.leftExpression.memberName;
if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
return (type: Type) => {
return narrowTypeForDiscriminatedFieldComparison(
evaluator,
type,
memberName.value,
rightType,
adjIsPositiveTest
);
};
}
}
// Look for X[<literal>] == <literal> or X[<literal>] != <literal>
if (
testExpression.leftExpression.nodeType === ParseNodeType.Index &&
@ -324,6 +301,26 @@ export function getTypeNarrowingCallback(
}
}
}
// Look for X.Y == <literal> or X.Y != <literal> or X.Y is <literal> or X.Y is not <literal>
if (
testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression)
) {
const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
const memberName = testExpression.leftExpression.memberName;
if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
return (type: Type) => {
return narrowTypeForDiscriminatedFieldComparison(
evaluator,
type,
memberName.value,
rightType,
adjIsPositiveTest
);
};
}
}
}
if (testExpression.operator === OperatorType.In) {
@ -1340,8 +1337,6 @@ function narrowTypeForDiscriminatedFieldComparison(
literalType: ClassType,
isPositiveTest: boolean
): Type {
let canNarrow = true;
const narrowedType = mapSubtypes(referenceType, (subtype) => {
let memberInfo: ClassMember | undefined;
if (isClassInstance(subtype)) {
@ -1362,11 +1357,10 @@ function narrowTypeForDiscriminatedFieldComparison(
}
}
canNarrow = false;
return subtype;
});
return canNarrow ? narrowedType : referenceType;
return narrowedType;
}
// Attempts to narrow a type based on a "type(x) is y" or "type(x) is not y" check.

View File

@ -4,41 +4,30 @@
from typing import Literal, Union
def requires_a(p1: Literal["a"]):
pass
def requires_bc(p1: Literal["b", "c"]):
pass
def func_1(p1: Literal["a", "b", "c"]):
if p1 != "b":
if p1 == "c":
t1: Literal["Literal['c']"] = reveal_type(p1)
pass
else:
requires_a(p1)
t2: Literal["Literal['a']"] = reveal_type(p1)
if p1 != "a":
requires_bc(p1)
t3: Literal["Literal['c', 'b']"] = reveal_type(p1)
else:
requires_a(p1)
t4: Literal["Literal['a']"] = reveal_type(p1)
if "a" != p1:
requires_bc(p1)
t5: Literal["Literal['c', 'b']"] = reveal_type(p1)
else:
requires_a(p1)
def requires_7(p1: Literal[7]):
pass
t6: Literal["Literal['a']"] = reveal_type(p1)
def func2(p1: Literal[1, 4, 7]):
if 4 == p1 or 1 == p1:
pass
t1: Literal["Literal[4, 1]"] = reveal_type(p1)
else:
requires_7(p1)
t2: Literal["Literal[7]"] = reveal_type(p1)
def func3(a: Union[int, None]):

View File

@ -26,43 +26,85 @@ class D:
kind: Literal[1, 2, 3]
def foo_obj1(c: Union[A, B]):
def eq_obj1(c: Union[A, B]):
if c.kind == "A":
tc1: Literal["A"] = reveal_type(c)
else:
tc2: Literal["B"] = reveal_type(c)
def foo_obj2(c: Union[A, B]):
def is_obj1(c: Union[A, B]):
if c.kind is "A":
tc1: Literal["A"] = reveal_type(c)
else:
tc2: Literal["B"] = reveal_type(c)
def eq_obj2(c: Union[A, B]):
if c.kind != "A":
tc1: Literal["B"] = reveal_type(c)
else:
tc2: Literal["A"] = reveal_type(c)
def foo_obj3(c: Union[A, B, C]):
if c.kind == "A":
tc1: Literal["A | B | C"] = reveal_type(c)
def is_obj2(c: Union[A, B]):
if c.kind is not "A":
tc1: Literal["B"] = reveal_type(c)
else:
tc2: Literal["A | B | C"] = reveal_type(c)
tc2: Literal["A"] = reveal_type(c)
def foo_obj4(c: Union[A, B]):
def eq_obj3(c: Union[A, B, C]):
if c.kind == "A":
tc1: Literal["A | C"] = reveal_type(c)
else:
tc2: Literal["B | C"] = reveal_type(c)
def is_obj3(c: Union[A, B, C]):
if c.kind is "A":
tc1: Literal["A | C"] = reveal_type(c)
else:
tc2: Literal["B | C"] = reveal_type(c)
def eq_obj4(c: Union[A, B]):
if c.d == 1:
tc1: Literal["A"] = reveal_type(c)
elif c.d == 3:
tc2: Literal["A | B"] = reveal_type(c)
def foo_obj5(d: D):
def is_obj4(c: Union[A, B]):
if c.d is 1:
tc1: Literal["A"] = reveal_type(c)
elif c.d is 3:
tc2: Literal["A | B"] = reveal_type(c)
def eq_obj5(d: D):
if d.kind == 1:
td1: Literal["D"] = reveal_type(d)
elif d.kind == 2:
td2: Literal["D"] = reveal_type(d)
def foo_class2(c: Union[Type[A], Type[B]]):
def is_obj5(d: D):
if d.kind is 1:
td1: Literal["D"] = reveal_type(d)
elif d.kind is 2:
td2: Literal["D"] = reveal_type(d)
def eq_class2(c: Union[Type[A], Type[B]]):
if c.kind_class == "A":
tc1: Literal["Type[A]"] = reveal_type(c)
else:
tc2: Literal["Type[B]"] = reveal_type(c)
def is_class2(c: Union[Type[A], Type[B]]):
if c.kind_class is "A":
tc1: Literal["Type[A]"] = reveal_type(c)
else:
tc2: Literal["Type[B]"] = reveal_type(c)