Improved type narrowing for x == L pattern to include x is L, since they are equivalent. Likewise, extended x.y == L pattern to include x.y is L.

This commit is contained in:
Eric Traut 2022-01-05 15:19:08 -07:00
parent 5829dcbbdc
commit 4fe3f1f7c0
4 changed files with 85 additions and 61 deletions

View File

@ -150,9 +150,8 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `x is None` and `x is not None` * `x is None` and `x is not None`
* `x == None` and `x != None` * `x == None` and `x != None`
* `type(x) is T` and `type(x) is not T` * `type(x) is T` and `type(x) is not T`
* `x is E` and `x is not E` (where E is an enum value or True or False) * `x == L` and `x != L` and `L == x` and `L != x` and `x is L` and `x is not L` (where L is a literal expression)
* `x == L` and `x != L` (where L is a literal expression) * `x.y == L` and `x.y != L` and `x.y is L` and `x.y is not L` (where L is a literal expression and x is a type that is distinguished by a field with a literal type)
* `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field with a literal type)
* `x[K] == V` and `x[K] != V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type) * `x[K] == V` and `x[K] != V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type)
* `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I)
* `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I)

View File

@ -177,15 +177,11 @@ export function getTypeNarrowingCallback(
} }
} }
// Look for "X is Y" or "X is not Y" where Y is a an enum or False or True. // Look for "X is <literal>" or "X is not <literal>"
if (isOrIsNotOperator) { if (isOrIsNotOperator) {
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) { if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type; const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
if ( if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
isClassInstance(rightType) &&
(ClassType.isEnumClass(rightType) || ClassType.isBuiltIn(rightType, 'bool')) &&
rightType.literalValue !== undefined
) {
return (type: Type) => { return (type: Type) => {
return narrowTypeForLiteralComparison( return narrowTypeForLiteralComparison(
evaluator, evaluator,
@ -219,6 +215,7 @@ export function getTypeNarrowingCallback(
} }
} }
// Look for <literal> == X or <literal> != X
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) { if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) {
const leftType = evaluator.getTypeOfExpression(testExpression.leftExpression).type; const leftType = evaluator.getTypeOfExpression(testExpression.leftExpression).type;
if (isClassInstance(leftType) && leftType.literalValue !== undefined) { if (isClassInstance(leftType) && leftType.literalValue !== undefined) {
@ -234,26 +231,6 @@ export function getTypeNarrowingCallback(
} }
} }
// Look for X.Y == <literal> or X.Y != <literal>
if (
testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression)
) {
const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
const memberName = testExpression.leftExpression.memberName;
if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
return (type: Type) => {
return narrowTypeForDiscriminatedFieldComparison(
evaluator,
type,
memberName.value,
rightType,
adjIsPositiveTest
);
};
}
}
// Look for X[<literal>] == <literal> or X[<literal>] != <literal> // Look for X[<literal>] == <literal> or X[<literal>] != <literal>
if ( if (
testExpression.leftExpression.nodeType === ParseNodeType.Index && testExpression.leftExpression.nodeType === ParseNodeType.Index &&
@ -324,6 +301,26 @@ export function getTypeNarrowingCallback(
} }
} }
} }
// Look for X.Y == <literal> or X.Y != <literal> or X.Y is <literal> or X.Y is not <literal>
if (
testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression)
) {
const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
const memberName = testExpression.leftExpression.memberName;
if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
return (type: Type) => {
return narrowTypeForDiscriminatedFieldComparison(
evaluator,
type,
memberName.value,
rightType,
adjIsPositiveTest
);
};
}
}
} }
if (testExpression.operator === OperatorType.In) { if (testExpression.operator === OperatorType.In) {
@ -1340,8 +1337,6 @@ function narrowTypeForDiscriminatedFieldComparison(
literalType: ClassType, literalType: ClassType,
isPositiveTest: boolean isPositiveTest: boolean
): Type { ): Type {
let canNarrow = true;
const narrowedType = mapSubtypes(referenceType, (subtype) => { const narrowedType = mapSubtypes(referenceType, (subtype) => {
let memberInfo: ClassMember | undefined; let memberInfo: ClassMember | undefined;
if (isClassInstance(subtype)) { if (isClassInstance(subtype)) {
@ -1362,11 +1357,10 @@ function narrowTypeForDiscriminatedFieldComparison(
} }
} }
canNarrow = false;
return subtype; return subtype;
}); });
return canNarrow ? narrowedType : referenceType; return narrowedType;
} }
// Attempts to narrow a type based on a "type(x) is y" or "type(x) is not y" check. // Attempts to narrow a type based on a "type(x) is y" or "type(x) is not y" check.

View File

@ -4,41 +4,30 @@
from typing import Literal, Union from typing import Literal, Union
def requires_a(p1: Literal["a"]):
pass
def requires_bc(p1: Literal["b", "c"]):
pass
def func_1(p1: Literal["a", "b", "c"]): def func_1(p1: Literal["a", "b", "c"]):
if p1 != "b": if p1 != "b":
if p1 == "c": if p1 == "c":
t1: Literal["Literal['c']"] = reveal_type(p1)
pass pass
else: else:
requires_a(p1) t2: Literal["Literal['a']"] = reveal_type(p1)
if p1 != "a": if p1 != "a":
requires_bc(p1) t3: Literal["Literal['c', 'b']"] = reveal_type(p1)
else: else:
requires_a(p1) t4: Literal["Literal['a']"] = reveal_type(p1)
if "a" != p1: if "a" != p1:
requires_bc(p1) t5: Literal["Literal['c', 'b']"] = reveal_type(p1)
else: else:
requires_a(p1) t6: Literal["Literal['a']"] = reveal_type(p1)
def requires_7(p1: Literal[7]):
pass
def func2(p1: Literal[1, 4, 7]): def func2(p1: Literal[1, 4, 7]):
if 4 == p1 or 1 == p1: if 4 == p1 or 1 == p1:
pass t1: Literal["Literal[4, 1]"] = reveal_type(p1)
else: else:
requires_7(p1) t2: Literal["Literal[7]"] = reveal_type(p1)
def func3(a: Union[int, None]): def func3(a: Union[int, None]):

View File

@ -26,43 +26,85 @@ class D:
kind: Literal[1, 2, 3] kind: Literal[1, 2, 3]
def foo_obj1(c: Union[A, B]): def eq_obj1(c: Union[A, B]):
if c.kind == "A": if c.kind == "A":
tc1: Literal["A"] = reveal_type(c) tc1: Literal["A"] = reveal_type(c)
else: else:
tc2: Literal["B"] = reveal_type(c) tc2: Literal["B"] = reveal_type(c)
def foo_obj2(c: Union[A, B]): def is_obj1(c: Union[A, B]):
if c.kind is "A":
tc1: Literal["A"] = reveal_type(c)
else:
tc2: Literal["B"] = reveal_type(c)
def eq_obj2(c: Union[A, B]):
if c.kind != "A": if c.kind != "A":
tc1: Literal["B"] = reveal_type(c) tc1: Literal["B"] = reveal_type(c)
else: else:
tc2: Literal["A"] = reveal_type(c) tc2: Literal["A"] = reveal_type(c)
def foo_obj3(c: Union[A, B, C]): def is_obj2(c: Union[A, B]):
if c.kind == "A": if c.kind is not "A":
tc1: Literal["A | B | C"] = reveal_type(c) tc1: Literal["B"] = reveal_type(c)
else: else:
tc2: Literal["A | B | C"] = reveal_type(c) tc2: Literal["A"] = reveal_type(c)
def foo_obj4(c: Union[A, B]): def eq_obj3(c: Union[A, B, C]):
if c.kind == "A":
tc1: Literal["A | C"] = reveal_type(c)
else:
tc2: Literal["B | C"] = reveal_type(c)
def is_obj3(c: Union[A, B, C]):
if c.kind is "A":
tc1: Literal["A | C"] = reveal_type(c)
else:
tc2: Literal["B | C"] = reveal_type(c)
def eq_obj4(c: Union[A, B]):
if c.d == 1: if c.d == 1:
tc1: Literal["A"] = reveal_type(c) tc1: Literal["A"] = reveal_type(c)
elif c.d == 3: elif c.d == 3:
tc2: Literal["A | B"] = reveal_type(c) tc2: Literal["A | B"] = reveal_type(c)
def foo_obj5(d: D): def is_obj4(c: Union[A, B]):
if c.d is 1:
tc1: Literal["A"] = reveal_type(c)
elif c.d is 3:
tc2: Literal["A | B"] = reveal_type(c)
def eq_obj5(d: D):
if d.kind == 1: if d.kind == 1:
td1: Literal["D"] = reveal_type(d) td1: Literal["D"] = reveal_type(d)
elif d.kind == 2: elif d.kind == 2:
td2: Literal["D"] = reveal_type(d) td2: Literal["D"] = reveal_type(d)
def foo_class2(c: Union[Type[A], Type[B]]): def is_obj5(d: D):
if d.kind is 1:
td1: Literal["D"] = reveal_type(d)
elif d.kind is 2:
td2: Literal["D"] = reveal_type(d)
def eq_class2(c: Union[Type[A], Type[B]]):
if c.kind_class == "A": if c.kind_class == "A":
tc1: Literal["Type[A]"] = reveal_type(c) tc1: Literal["Type[A]"] = reveal_type(c)
else: else:
tc2: Literal["Type[B]"] = reveal_type(c) tc2: Literal["Type[B]"] = reveal_type(c)
def is_class2(c: Union[Type[A], Type[B]]):
if c.kind_class is "A":
tc1: Literal["Type[A]"] = reveal_type(c)
else:
tc2: Literal["Type[B]"] = reveal_type(c)