Added support for bool(x) type guard.

This commit is contained in:
Eric Traut 2021-12-31 14:23:56 -07:00
parent 00067d22d8
commit 046eab4a8d
3 changed files with 128 additions and 89 deletions

View File

@ -162,6 +162,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `issubclass(x, T)` (where T is a type or a tuple of types)
* `callable(x)`
* `f(x)` (where f is a user-defined type guard as defined in [PEP 647](https://www.python.org/dev/peps/pep-0647/))
* `bool(x)` (where x is any expression that is statically verifiable to be thruthy or falsy in all cases).
* `x` (where x is any expression that is statically verifiable to be truthy or falsy in all cases)
Expressions supported for type guards include simple names, member access chains (e.g. `a.b.c.d`), the unary `not` operator, the binary `and` and `or` operators, subscripts that are constant numbers (e.g. `a[2]`), and call expressions. Other operators (such as arithmetic operators or other subscripts) are not supported.

View File

@ -330,99 +330,119 @@ export function getTypeNarrowingCallback(
}
if (testExpression.nodeType === ParseNodeType.Call) {
if (testExpression.leftExpression.nodeType === ParseNodeType.Name) {
// Look for "isinstance(X, Y)" or "issubclass(X, Y)".
if (
(testExpression.leftExpression.value === 'isinstance' ||
testExpression.leftExpression.value === 'issubclass') &&
testExpression.arguments.length === 2
) {
// Make sure the first parameter is a supported expression type
// and the second parameter is a valid class type or a tuple
// of valid class types.
const isInstanceCheck = testExpression.leftExpression.value === 'isinstance';
const arg0Expr = testExpression.arguments[0].valueExpression;
const arg1Expr = testExpression.arguments[1].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
const arg1Type = evaluator.getTypeOfExpression(
arg1Expr,
undefined,
EvaluatorFlags.EvaluateStringLiteralAsType |
EvaluatorFlags.ParamSpecDisallowed |
EvaluatorFlags.TypeVarTupleDisallowed
).type;
const classTypeList = getIsInstanceClassTypes(arg1Type);
if (classTypeList) {
return (type: Type) => {
const narrowedType = narrowTypeForIsInstance(
evaluator,
type,
classTypeList,
isInstanceCheck,
isPositiveTest,
/* allowIntersections */ false,
testExpression
);
if (!isNever(narrowedType)) {
return narrowedType;
}
const callType = evaluator.getTypeOfExpression(
testExpression.leftExpression,
/* expectedType */ undefined,
EvaluatorFlags.DoNotSpecialize
).type;
// Try again with intersection types allowed.
return narrowTypeForIsInstance(
evaluator,
type,
classTypeList,
isInstanceCheck,
isPositiveTest,
/* allowIntersections */ true,
testExpression
);
};
}
}
} else if (testExpression.leftExpression.value === 'callable' && testExpression.arguments.length === 1) {
const arg0Expr = testExpression.arguments[0].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
// Look for "isinstance(X, Y)" or "issubclass(X, Y)".
if (
isFunction(callType) &&
(callType.details.builtInName === 'isinstance' || callType.details.builtInName === 'issubclass') &&
testExpression.arguments.length === 2
) {
// Make sure the first parameter is a supported expression type
// and the second parameter is a valid class type or a tuple
// of valid class types.
const isInstanceCheck = callType.details.builtInName === 'isinstance';
const arg0Expr = testExpression.arguments[0].valueExpression;
const arg1Expr = testExpression.arguments[1].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
const arg1Type = evaluator.getTypeOfExpression(
arg1Expr,
undefined,
EvaluatorFlags.EvaluateStringLiteralAsType |
EvaluatorFlags.ParamSpecDisallowed |
EvaluatorFlags.TypeVarTupleDisallowed
).type;
const classTypeList = getIsInstanceClassTypes(arg1Type);
if (classTypeList) {
return (type: Type) => {
let narrowedType = narrowTypeForCallable(
const narrowedType = narrowTypeForIsInstance(
evaluator,
type,
classTypeList,
isInstanceCheck,
isPositiveTest,
testExpression,
/* allowIntersections */ false
/* allowIntersections */ false,
testExpression
);
if (isPositiveTest && isNever(narrowedType)) {
// Try again with intersections allowed.
narrowedType = narrowTypeForCallable(
evaluator,
type,
isPositiveTest,
testExpression,
/* allowIntersections */ true
);
if (!isNever(narrowedType)) {
return narrowedType;
}
return narrowedType;
// Try again with intersection types allowed.
return narrowTypeForIsInstance(
evaluator,
type,
classTypeList,
isInstanceCheck,
isPositiveTest,
/* allowIntersections */ true,
testExpression
);
};
}
}
}
// Look for "callable(X)"
if (
isFunction(callType) &&
callType.details.builtInName === 'callable' &&
testExpression.arguments.length === 1
) {
const arg0Expr = testExpression.arguments[0].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
return (type: Type) => {
let narrowedType = narrowTypeForCallable(
evaluator,
type,
isPositiveTest,
testExpression,
/* allowIntersections */ false
);
if (isPositiveTest && isNever(narrowedType)) {
// Try again with intersections allowed.
narrowedType = narrowTypeForCallable(
evaluator,
type,
isPositiveTest,
testExpression,
/* allowIntersections */ true
);
}
return narrowedType;
};
}
}
// Look for "bool(X)"
if (
isInstantiableClass(callType) &&
ClassType.isBuiltIn(callType, 'bool') &&
testExpression.arguments.length === 1 &&
!testExpression.arguments[0].name
) {
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.arguments[0].valueExpression)) {
return (type: Type) => {
return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
};
}
}
// Look for a TypeGuard assertion function.
if (testExpression.arguments.length >= 1) {
const arg0Expr = testExpression.arguments[0].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
const functionType = evaluator.getTypeOfExpression(
testExpression.leftExpression,
/* expectedType */ undefined,
EvaluatorFlags.DoNotSpecialize
).type;
// Does this look like it's a custom type guard function?
if (
isFunction(functionType) &&
functionType.details.declaredReturnType &&
isClassInstance(functionType.details.declaredReturnType) &&
ClassType.isBuiltIn(functionType.details.declaredReturnType, 'TypeGuard')
isFunction(callType) &&
callType.details.declaredReturnType &&
isClassInstance(callType.details.declaredReturnType) &&
ClassType.isBuiltIn(callType.details.declaredReturnType, 'TypeGuard')
) {
// Evaluate the type guard call expression.
const functionReturnType = evaluator.getTypeOfExpression(testExpression).type;
@ -446,19 +466,7 @@ export function getTypeNarrowingCallback(
if (ParseTreeUtils.isMatchingExpression(reference, testExpression)) {
return (type: Type) => {
// Narrow the type based on whether the subtype can be true or false.
return mapSubtypes(type, (subtype) => {
if (isPositiveTest) {
if (evaluator.canBeTruthy(subtype)) {
return evaluator.removeFalsinessFromType(subtype);
}
} else {
if (evaluator.canBeFalsy(subtype)) {
return evaluator.removeTruthinessFromType(subtype);
}
}
return undefined;
});
return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
};
}
@ -573,6 +581,22 @@ function getDeclsForLocalVar(
return reachableDecls.length > 0 ? reachableDecls : undefined;
}
// Narrow the type based on whether the subtype can be true or false.
function narrowTypeForTruthiness(evaluator: TypeEvaluator, type: Type, isPositiveTest: boolean) {
return mapSubtypes(type, (subtype) => {
if (isPositiveTest) {
if (evaluator.canBeTruthy(subtype)) {
return evaluator.removeFalsinessFromType(subtype);
}
} else {
if (evaluator.canBeFalsy(subtype)) {
return evaluator.removeTruthinessFromType(subtype);
}
}
return undefined;
});
}
// Handle type narrowing for expressions of the form "a[I] is None" and "a[I] is not None" where
// I is an integer and a is a union of Tuples with known lengths and entry types.
function narrowTupleTypeForIsNone(evaluator: TypeEvaluator, type: Type, isPositiveTest: boolean, indexValue: number) {

View File

@ -1,6 +1,6 @@
# This sample tests type narrowing for falsy and truthy values.
from typing import List, Literal, Union
from typing import List, Literal, Optional, Union
class A:
@ -27,3 +27,17 @@ def func1(x: Union[int, List[int], A, B, C, D, None]) -> None:
t1: Literal["int | List[int] | A | B | D"] = reveal_type(x)
else:
t2: Literal["int | List[int] | B | C | None"] = reveal_type(x)
def func2(maybe_int: Optional[int]):
if bool(maybe_int):
t1: Literal["int"] = reveal_type(maybe_int)
else:
t2: Literal["int | None"] = reveal_type(maybe_int)
def func3(maybe_a: Optional[A]):
if bool(maybe_a):
t1: Literal["A"] = reveal_type(maybe_a)
else:
t2: Literal["None"] = reveal_type(maybe_a)