Added support for type guard based on a.b is None or a.b is not None patterns where b is a member variable that distinguishes two different classes. (#3273)

Co-authored-by: Eric Traut <erictr@microsoft.com>
This commit is contained in:
Eric Traut 2022-04-01 12:06:53 -06:00 committed by GitHub
parent dc153cb35d
commit 8660a6f870
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 180 additions and 3 deletions

View File

@ -171,6 +171,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `type(x) is T` and `type(x) is not T`
* `x is E` and `x is not E` (where E is a literal enum or bool)
* `x == L` and `x != L` (where L is a literal expression)
* `x.y is None` and `x.y is not None` (where x is a type that is distinguished by a field with a None)
* `x.y is E` and `x.y is not E` (where E is a literal enum or bool and x is a type that is distinguished by a field with a literal type)
* `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field with a literal type)
* `x[K] == V` and `x[K] != V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type)

View File

@ -70,6 +70,8 @@ import {
getTypeVarScopeId,
isLiteralType,
isLiteralTypeOrUnion,
isMaybeDescriptorInstance,
isProperty,
isTupleClass,
isUnboundedTupleClass,
lookUpClassMember,
@ -321,7 +323,7 @@ export function getTypeNarrowingCallback(
const memberName = testExpression.leftExpression.memberName;
if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
return (type: Type) => {
return narrowTypeForDiscriminatedFieldComparison(
return narrowTypeForDiscriminatedLiteralFieldComparison(
evaluator,
type,
memberName.value,
@ -346,7 +348,7 @@ export function getTypeNarrowingCallback(
rightType.literalValue !== undefined
) {
return (type: Type) => {
return narrowTypeForDiscriminatedFieldComparison(
return narrowTypeForDiscriminatedLiteralFieldComparison(
evaluator,
type,
memberName.value,
@ -356,6 +358,25 @@ export function getTypeNarrowingCallback(
};
}
}
// Look for X.Y is None or X.Y is not None
// These are commonly-used patterns used in control flow.
if (
testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression) &&
testExpression.rightExpression.nodeType === ParseNodeType.Constant &&
testExpression.rightExpression.constType === KeywordType.None
) {
const memberName = testExpression.leftExpression.memberName;
return (type: Type) => {
return narrowTypeForDiscriminatedFieldNoneComparison(
evaluator,
type,
memberName.value,
adjIsPositiveTest
);
};
}
}
if (testExpression.operator === OperatorType.In) {
@ -1423,7 +1444,7 @@ function narrowTypeForDiscriminatedTupleComparison(
// Attempts to narrow a type based on a comparison (equal or not equal)
// between a discriminating field that has a declared literal type to a
// literal value.
function narrowTypeForDiscriminatedFieldComparison(
function narrowTypeForDiscriminatedLiteralFieldComparison(
evaluator: TypeEvaluator,
referenceType: Type,
memberName: string,
@ -1456,6 +1477,53 @@ function narrowTypeForDiscriminatedFieldComparison(
return narrowedType;
}
// Attempts to narrow a type based on a comparison (equal or not equal)
// between a discriminating field that has a declared None type to a
// None.
function narrowTypeForDiscriminatedFieldNoneComparison(
evaluator: TypeEvaluator,
referenceType: Type,
memberName: string,
isPositiveTest: boolean
): Type {
return mapSubtypes(referenceType, (subtype) => {
let memberInfo: ClassMember | undefined;
if (isClassInstance(subtype)) {
memberInfo = lookUpObjectMember(subtype, memberName);
} else if (isInstantiableClass(subtype)) {
memberInfo = lookUpClassMember(subtype, memberName);
}
if (memberInfo && memberInfo.isTypeDeclared) {
const memberType = evaluator.makeTopLevelTypeVarsConcrete(evaluator.getTypeOfMember(memberInfo));
let canNarrow = true;
if (isPositiveTest) {
doForEachSubtype(memberType, (memberSubtype) => {
memberSubtype = evaluator.makeTopLevelTypeVarsConcrete(memberSubtype);
// Don't attempt to narrow if the member is a descriptor or property.
if (isProperty(memberSubtype) || isMaybeDescriptorInstance(memberSubtype)) {
canNarrow = false;
}
if (isAnyOrUnknown(memberSubtype) || isNoneInstance(memberSubtype) || isNever(memberSubtype)) {
canNarrow = false;
}
});
} else {
canNarrow = isNoneInstance(memberType);
}
if (canNarrow) {
return undefined;
}
}
return subtype;
});
}
// Attempts to narrow a type based on a "type(x) is y" or "type(x) is not y" check.
function narrowTypeForTypeIs(type: Type, classType: ClassType, isPositiveTest: boolean) {
return mapSubtypes(type, (subtype) => {

View File

@ -0,0 +1,102 @@
# This sample tests the type narrowing case for unions of NamedTuples
# where one or more of the entries is tested against type None by attribute.
from typing import NamedTuple, Optional, Union
IntFirst = NamedTuple(
"IntFirst",
[
("first", int),
("second", None),
],
)
StrSecond = NamedTuple(
"StrSecond",
[
("first", None),
("second", str),
],
)
def func1(a: Union[IntFirst, StrSecond]) -> IntFirst:
if a.second is None:
reveal_type(a, expected_text="IntFirst")
return a
else:
reveal_type(a, expected_text="StrSecond")
raise ValueError()
UnionFirst = NamedTuple(
"UnionFirst",
[
("first", Union[None, int]),
("second", None),
],
)
def func2(a: Union[UnionFirst, StrSecond]):
if a.first is None:
reveal_type(a, expected_text="UnionFirst | StrSecond")
else:
reveal_type(a, expected_text="UnionFirst")
class A:
@property
def prop1(self) -> Optional[int]:
...
member1: None
member2: Optional[int]
member3: Optional[int]
member4: Optional[int]
class B:
@property
def prop1(self) -> int:
...
member1: int
member2: Optional[int]
member3: None
member4: int
def func3(c: Union[A, B]):
if c.prop1 is None:
reveal_type(c, expected_text="A | B")
else:
reveal_type(c, expected_text="A | B")
def func4(c: Union[A, B]):
if c.member1 is None:
reveal_type(c, expected_text="A")
else:
reveal_type(c, expected_text="B")
def func5(c: Union[A, B]):
if c.member2 is None:
reveal_type(c, expected_text="A | B")
else:
reveal_type(c, expected_text="A | B")
def func6(c: Union[A, B]):
if c.member3 is not None:
reveal_type(c, expected_text="A")
else:
reveal_type(c, expected_text="A | B")
def func7(c: Union[A, B]):
if c.member4 is not None:
reveal_type(c, expected_text="A | B")
else:
reveal_type(c, expected_text="A")

View File

@ -386,6 +386,12 @@ test('TypeNarrowingLiteralMember1', () => {
TestUtils.validateResults(analysisResults, 0);
});
test('TypeNarrowingNoneMember1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingNoneMember1.py']);
TestUtils.validateResults(analysisResults, 0);
});
test('TypeNarrowingTuple1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingTuple1.py']);