Added support for narrowing types based on the pattern "A.B == <literal>" and "A.B != <literal>" when A has a union type and all members of the union have a field "B" with a declared literal type that discriminates one sub-type from another.

This commit is contained in:
Eric Traut 2020-08-15 16:11:23 -07:00
parent 8244e9a03e
commit 605bcc482c
5 changed files with 176 additions and 4 deletions

View File

@ -74,8 +74,9 @@ def func(b: Optional[Union[str, List[str]]]):
```
If the type narrowing logic exhausts all possible subtypes, it can be assumed that a code path will never be taken. For example, consider the following:
```python
def (a: Union[Foo, Bar]):
def func(a: Union[Foo, Bar]):
if isinstance(a, Foo):
# “a” must be type Foo
a.do_something_1()
@ -89,3 +90,22 @@ def (a: Union[Foo, Bar]):
In this case, the type of parameter “a” is initially “Union[Foo, Bar]”. Within the “if” clause, the type narrowing logic will conclude that it must be of type “Foo”. Within the “elif” clause, it must be of type “Bar”. What type is it within the “else” clause? The type narrowing system has eliminated all possible subtypes, so it gives it the type “Never”. This is generally indicates that theres a logic error in the code because theres way that code block will ever be executed.
Narrowing is also used to discriminate between members of a union type when the union members have a common member with literal declared types that differentiate the types.
```python
class Foo:
kind: Literal["Foo"]
def do_something_1(self):
pass
class Bar:
kind: Literal["Bar"]
def do_something_2(self):
pass
def func(a: Union[Foo, Bar]):
if a.kind == "Foo":
a.do_something_1()
else:
a.do_something_2()
```

View File

@ -173,6 +173,7 @@ import {
getSpecializedTupleType,
getTypeVarArgumentsRecursive,
isEllipsisType,
isLiteralType,
isNoReturnType,
isOptionalType,
isParamSpecType,
@ -11078,8 +11079,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, printTypeFlags:
}
}
// Look for X == <literal> or X != <literal>
if (equalsOrNotEqualsOperator) {
// Look for X == <literal> or X != <literal>
const adjIsPositiveTest =
testExpression.operator === OperatorType.Equals ? isPositiveTest : !isPositiveTest;
@ -11100,6 +11101,25 @@ export function createTypeEvaluator(importLookup: ImportLookup, printTypeFlags:
};
}
}
// Look for X.Y == <literal> or X.Y != <literal>
if (
testExpression.leftExpression.nodeType === ParseNodeType.MemberAccess &&
ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.leftExpression)
) {
const rightType = getTypeOfExpression(testExpression.rightExpression).type;
const memberName = testExpression.leftExpression.memberName;
if (isObject(rightType) && rightType.classType.literalValue !== undefined) {
return (type: Type) => {
return narrowTypeForDiscriminatedFieldComparison(
type,
memberName.value,
rightType,
adjIsPositiveTest
);
};
}
}
}
}
@ -11361,6 +11381,43 @@ export function createTypeEvaluator(importLookup: ImportLookup, printTypeFlags:
return canNarrow ? narrowedType : referenceType;
}
// Attempts to narrow a type (make it more constrained) based on a comparison
// (equal or not equal) between a discriminating node that has a declared
// literal type to a literal value.
function narrowTypeForDiscriminatedFieldComparison(
referenceType: Type,
memberName: string,
literalType: ObjectType,
isPositiveTest: boolean
): Type {
let canNarrow = true;
const narrowedType = doForSubtypes(referenceType, (subtype) => {
subtype = transformTypeObjectToClass(subtype);
let memberInfo: ClassMember | undefined;
if (isObject(subtype)) {
memberInfo = lookUpObjectMember(subtype, memberName);
} else if (isClass(subtype)) {
memberInfo = lookUpClassMember(subtype, memberName);
}
if (memberInfo && memberInfo.isTypeDeclared) {
const memberType = getTypeOfMember(memberInfo);
if (isLiteralType(memberType, /* allowLiteralUnions */ false)) {
const isAssignable = canAssignType(memberType, literalType, new DiagnosticAddendum());
return isAssignable === isPositiveTest ? subtype : undefined;
}
}
canNarrow = false;
return subtype;
});
return canNarrow ? narrowedType : referenceType;
}
// Attempts to narrow a type (make it more constrained) based on a comparison
// (equal or not equal) to a literal value.
function narrowTypeForLiteralComparison(

View File

@ -48,6 +48,9 @@ export interface ClassMember {
// True if instance member, false if class member
isInstanceMember: boolean;
// True if member has declared type, false if inferred
isTypeDeclared: boolean;
}
export const enum ClassMemberLookupFlags {
@ -450,6 +453,18 @@ export function getSpecializedTupleType(type: Type): ClassType | undefined {
return specializeType(tupleClass, typeVarMap) as ClassType;
}
export function isLiteralType(type: Type, allowLiteralUnions = true): boolean {
if (isObject(type)) {
return type.classType.literalValue !== undefined;
}
if (type.category === TypeCategory.Union) {
return !type.subtypes.some((t) => !isObject(t) || t.classType.literalValue === undefined);
}
return false;
}
export function isEllipsisType(type: Type): boolean {
// Ellipses are translated into both a special form of "Any" or
// a distinct class depending on the context.
@ -690,11 +705,13 @@ export function lookUpClassMember(
if ((flags & ClassMemberLookupFlags.SkipInstanceVariables) === 0) {
const symbol = memberFields.get(memberName);
if (symbol && symbol.isInstanceMember()) {
if (!declaredTypesOnly || symbol.hasTypedDeclarations()) {
const hasDeclaredType = symbol.hasTypedDeclarations();
if (!declaredTypesOnly || hasDeclaredType) {
return {
symbol,
isInstanceMember: true,
classType: specializedMroClass,
isTypeDeclared: hasDeclaredType,
};
}
}
@ -707,7 +724,8 @@ export function lookUpClassMember(
(flags & ClassMemberLookupFlags.SkipIfInaccessibleToInstance) === 0 ||
!symbol.isInaccessibleToInstance()
) {
if (!declaredTypesOnly || symbol.hasTypedDeclarations()) {
const hasDeclaredType = symbol.hasTypedDeclarations();
if (!declaredTypesOnly || hasDeclaredType) {
let isInstanceMember = false;
// For data classes and typed dicts, variables that are declared
@ -729,6 +747,7 @@ export function lookUpClassMember(
symbol,
isInstanceMember,
classType: specializedMroClass,
isTypeDeclared: hasDeclaredType,
};
}
}
@ -747,6 +766,7 @@ export function lookUpClassMember(
symbol: Symbol.createWithType(SymbolFlags.None, UnknownType.create()),
isInstanceMember: false,
classType: UnknownType.create(),
isTypeDeclared: false,
};
}
} else if (isAnyOrUnknown(classType)) {
@ -756,6 +776,7 @@ export function lookUpClassMember(
symbol: Symbol.createWithType(SymbolFlags.None, UnknownType.create()),
isInstanceMember: false,
classType: UnknownType.create(),
isTypeDeclared: false,
};
}

View File

@ -330,6 +330,12 @@ test('TypeConstraint16', () => {
validateResults(analysisResults, 2);
});
test('TypeConstraint17', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeConstraint17.py']);
validateResults(analysisResults, 8);
});
test('CircularBaseClass', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['circularBaseClass.py']);

View File

@ -0,0 +1,68 @@
# This sample tests type narrowing based on member accesses
# to members that have literal types.
from typing import ClassVar, Literal, Type, Union
class A:
kind: Literal["A"]
kind_class: ClassVar[Literal["A"]]
a: str
class B:
kind: Literal["B"]
kind_class: ClassVar[Literal["B"]]
b: int
class C:
kind: str
kind_class: str
c: int
def foo_obj1(c: Union[A, B]):
if c.kind == "A":
c.a
# This should generate an error
c.b
else:
c.b
# This should generate an error
c.a
def foo_obj2(c: Union[A, B]):
if c.kind != "A":
# This should generate an error
c.a
c.b
else:
# This should generate an error
c.b
c.a
def foo_obj3(c: Union[A, B, C]):
if c.kind == "A":
# This should generate an error
c.a
else:
# This should generate an error
c.a
def foo_class2(c: Union[Type[A], Type[B]]):
if c.kind_class == "A":
c.a
# This should generate an error
c.b
else:
c.b
# This should generate an error
c.a