Added type narrowing logic for conditionals that involve comparisons between Literal expressions such as "animal: Literal['cow', 'horse'] \ if animal == 'cow'"

This commit is contained in:
Eric Traut 2020-04-12 01:24:43 -07:00
parent c54d2c83cd
commit b9f291afc3
4 changed files with 112 additions and 0 deletions

View File

@ -1708,6 +1708,23 @@ export class Binder extends ParseTreeWalker {
expressionList
);
}
// Look for X == <literal>, X != <literal> or <literal> == X, <literal> != X
if (equalsOrNotEqualsOperator) {
if (
expression.leftExpression.nodeType === ParseNodeType.StringList ||
expression.leftExpression.nodeType === ParseNodeType.Number ||
expression.leftExpression.nodeType === ParseNodeType.Constant
) {
return this._isNarrowingExpression(expression.rightExpression, expressionList);
} else if (
expression.rightExpression.nodeType === ParseNodeType.StringList ||
expression.rightExpression.nodeType === ParseNodeType.Number ||
expression.rightExpression.nodeType === ParseNodeType.Constant
) {
return this._isNarrowingExpression(expression.leftExpression, expressionList);
}
}
}
return false;

View File

@ -9032,6 +9032,30 @@ export function createTypeEvaluator(importLookup: ImportLookup): TypeEvaluator {
}
}
}
// Look for X == <literal> or X != <literal>
if (equalsOrNotEqualsOperator) {
const adjIsPositiveTest =
testExpression.operator === OperatorType.Equals ? isPositiveTest : !isPositiveTest;
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) {
const rightType = getTypeOfExpression(testExpression.rightExpression).type;
if (rightType.category === TypeCategory.Object && rightType.literalValue) {
return (type: Type) => {
return narrowTypeForLiteralComparison(type, rightType, adjIsPositiveTest);
};
}
}
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) {
const leftType = getTypeOfExpression(testExpression.leftExpression).type;
if (leftType.category === TypeCategory.Object && leftType.literalValue) {
return (type: Type) => {
return narrowTypeForLiteralComparison(type, leftType, adjIsPositiveTest);
};
}
}
}
}
}
@ -9230,6 +9254,33 @@ export function createTypeEvaluator(importLookup: ImportLookup): TypeEvaluator {
return type;
}
// Attempts to narrow a type (make it more constrained) based on a comparison
// (equal or not equal) to a literal value.
function narrowTypeForLiteralComparison(
referenceType: Type,
literalType: ObjectType,
isPositiveTest: boolean
): Type {
let canNarrow = true;
const narrowedType = doForSubtypes(referenceType, (subtype) => {
if (
subtype.category === TypeCategory.Object &&
ClassType.isSameGenericClass(literalType.classType, subtype.classType) &&
subtype.literalValue
) {
const literalValueMatches = subtype.literalValue === literalType.literalValue;
if ((literalValueMatches && !isPositiveTest) || (!literalValueMatches && isPositiveTest)) {
return undefined;
}
return subtype;
}
canNarrow = false;
return subtype;
});
return canNarrow ? narrowedType : referenceType;
}
// Attempts to narrow a type (make it more constrained) based on a
// call to "callable". For example, if the original type of expression "x" is
// Union[Callable[..., Any], Type[int], int], it would remove the "int" because

View File

@ -288,6 +288,12 @@ test('TypeConstraint9', () => {
validateResults(analysisResults, 0);
});
test('TypeConstraint10', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeConstraint10.py']);
validateResults(analysisResults, 0);
});
test('CircularBaseClass', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['circularBaseClass.py']);

View File

@ -0,0 +1,38 @@
# This sample tests the type constraint engine's handling
# of literals.
from typing import Literal
def requires_a(p1: Literal['a']):
pass
def requires_bc(p1: Literal['b', 'c']):
pass
def func_1(p1: Literal['a', 'b', 'c']):
if p1 != 'b':
if p1 == 'c':
pass
else:
requires_a(p1)
if p1 != 'a':
requires_bc(p1)
else:
requires_a(p1)
if 'a' != p1:
requires_bc(p1)
else:
requires_a(p1)
def requires_7(p1: Literal[7]):
pass
def func2(p1: Literal[1, 4, 7]):
if 4 == p1 or 1 == p1:
pass
else:
requires_7(p1)