Added support for bool(x) type guard.

This commit is contained in:
Eric Traut 2021-12-31 14:23:56 -07:00
parent 00067d22d8
commit 046eab4a8d
3 changed files with 128 additions and 89 deletions

View File

@ -162,6 +162,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `issubclass(x, T)` (where T is a type or a tuple of types) * `issubclass(x, T)` (where T is a type or a tuple of types)
* `callable(x)` * `callable(x)`
* `f(x)` (where f is a user-defined type guard as defined in [PEP 647](https://www.python.org/dev/peps/pep-0647/)) * `f(x)` (where f is a user-defined type guard as defined in [PEP 647](https://www.python.org/dev/peps/pep-0647/))
* `bool(x)` (where x is any expression that is statically verifiable to be thruthy or falsy in all cases).
* `x` (where x is any expression that is statically verifiable to be truthy or falsy in all cases) * `x` (where x is any expression that is statically verifiable to be truthy or falsy in all cases)
Expressions supported for type guards include simple names, member access chains (e.g. `a.b.c.d`), the unary `not` operator, the binary `and` and `or` operators, subscripts that are constant numbers (e.g. `a[2]`), and call expressions. Other operators (such as arithmetic operators or other subscripts) are not supported. Expressions supported for type guards include simple names, member access chains (e.g. `a.b.c.d`), the unary `not` operator, the binary `and` and `or` operators, subscripts that are constant numbers (e.g. `a[2]`), and call expressions. Other operators (such as arithmetic operators or other subscripts) are not supported.

View File

@ -330,99 +330,119 @@ export function getTypeNarrowingCallback(
} }
if (testExpression.nodeType === ParseNodeType.Call) { if (testExpression.nodeType === ParseNodeType.Call) {
if (testExpression.leftExpression.nodeType === ParseNodeType.Name) { const callType = evaluator.getTypeOfExpression(
// Look for "isinstance(X, Y)" or "issubclass(X, Y)". testExpression.leftExpression,
if ( /* expectedType */ undefined,
(testExpression.leftExpression.value === 'isinstance' || EvaluatorFlags.DoNotSpecialize
testExpression.leftExpression.value === 'issubclass') && ).type;
testExpression.arguments.length === 2
) {
// Make sure the first parameter is a supported expression type
// and the second parameter is a valid class type or a tuple
// of valid class types.
const isInstanceCheck = testExpression.leftExpression.value === 'isinstance';
const arg0Expr = testExpression.arguments[0].valueExpression;
const arg1Expr = testExpression.arguments[1].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
const arg1Type = evaluator.getTypeOfExpression(
arg1Expr,
undefined,
EvaluatorFlags.EvaluateStringLiteralAsType |
EvaluatorFlags.ParamSpecDisallowed |
EvaluatorFlags.TypeVarTupleDisallowed
).type;
const classTypeList = getIsInstanceClassTypes(arg1Type);
if (classTypeList) {
return (type: Type) => {
const narrowedType = narrowTypeForIsInstance(
evaluator,
type,
classTypeList,
isInstanceCheck,
isPositiveTest,
/* allowIntersections */ false,
testExpression
);
if (!isNever(narrowedType)) {
return narrowedType;
}
// Try again with intersection types allowed. // Look for "isinstance(X, Y)" or "issubclass(X, Y)".
return narrowTypeForIsInstance( if (
evaluator, isFunction(callType) &&
type, (callType.details.builtInName === 'isinstance' || callType.details.builtInName === 'issubclass') &&
classTypeList, testExpression.arguments.length === 2
isInstanceCheck, ) {
isPositiveTest, // Make sure the first parameter is a supported expression type
/* allowIntersections */ true, // and the second parameter is a valid class type or a tuple
testExpression // of valid class types.
); const isInstanceCheck = callType.details.builtInName === 'isinstance';
}; const arg0Expr = testExpression.arguments[0].valueExpression;
} const arg1Expr = testExpression.arguments[1].valueExpression;
} if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
} else if (testExpression.leftExpression.value === 'callable' && testExpression.arguments.length === 1) { const arg1Type = evaluator.getTypeOfExpression(
const arg0Expr = testExpression.arguments[0].valueExpression; arg1Expr,
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) { undefined,
EvaluatorFlags.EvaluateStringLiteralAsType |
EvaluatorFlags.ParamSpecDisallowed |
EvaluatorFlags.TypeVarTupleDisallowed
).type;
const classTypeList = getIsInstanceClassTypes(arg1Type);
if (classTypeList) {
return (type: Type) => { return (type: Type) => {
let narrowedType = narrowTypeForCallable( const narrowedType = narrowTypeForIsInstance(
evaluator, evaluator,
type, type,
classTypeList,
isInstanceCheck,
isPositiveTest, isPositiveTest,
testExpression, /* allowIntersections */ false,
/* allowIntersections */ false testExpression
); );
if (isPositiveTest && isNever(narrowedType)) { if (!isNever(narrowedType)) {
// Try again with intersections allowed. return narrowedType;
narrowedType = narrowTypeForCallable(
evaluator,
type,
isPositiveTest,
testExpression,
/* allowIntersections */ true
);
} }
return narrowedType; // Try again with intersection types allowed.
return narrowTypeForIsInstance(
evaluator,
type,
classTypeList,
isInstanceCheck,
isPositiveTest,
/* allowIntersections */ true,
testExpression
);
}; };
} }
} }
} }
// Look for "callable(X)"
if (
isFunction(callType) &&
callType.details.builtInName === 'callable' &&
testExpression.arguments.length === 1
) {
const arg0Expr = testExpression.arguments[0].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
return (type: Type) => {
let narrowedType = narrowTypeForCallable(
evaluator,
type,
isPositiveTest,
testExpression,
/* allowIntersections */ false
);
if (isPositiveTest && isNever(narrowedType)) {
// Try again with intersections allowed.
narrowedType = narrowTypeForCallable(
evaluator,
type,
isPositiveTest,
testExpression,
/* allowIntersections */ true
);
}
return narrowedType;
};
}
}
// Look for "bool(X)"
if (
isInstantiableClass(callType) &&
ClassType.isBuiltIn(callType, 'bool') &&
testExpression.arguments.length === 1 &&
!testExpression.arguments[0].name
) {
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.arguments[0].valueExpression)) {
return (type: Type) => {
return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
};
}
}
// Look for a TypeGuard assertion function.
if (testExpression.arguments.length >= 1) { if (testExpression.arguments.length >= 1) {
const arg0Expr = testExpression.arguments[0].valueExpression; const arg0Expr = testExpression.arguments[0].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) { if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
const functionType = evaluator.getTypeOfExpression(
testExpression.leftExpression,
/* expectedType */ undefined,
EvaluatorFlags.DoNotSpecialize
).type;
// Does this look like it's a custom type guard function? // Does this look like it's a custom type guard function?
if ( if (
isFunction(functionType) && isFunction(callType) &&
functionType.details.declaredReturnType && callType.details.declaredReturnType &&
isClassInstance(functionType.details.declaredReturnType) && isClassInstance(callType.details.declaredReturnType) &&
ClassType.isBuiltIn(functionType.details.declaredReturnType, 'TypeGuard') ClassType.isBuiltIn(callType.details.declaredReturnType, 'TypeGuard')
) { ) {
// Evaluate the type guard call expression. // Evaluate the type guard call expression.
const functionReturnType = evaluator.getTypeOfExpression(testExpression).type; const functionReturnType = evaluator.getTypeOfExpression(testExpression).type;
@ -446,19 +466,7 @@ export function getTypeNarrowingCallback(
if (ParseTreeUtils.isMatchingExpression(reference, testExpression)) { if (ParseTreeUtils.isMatchingExpression(reference, testExpression)) {
return (type: Type) => { return (type: Type) => {
// Narrow the type based on whether the subtype can be true or false. return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
return mapSubtypes(type, (subtype) => {
if (isPositiveTest) {
if (evaluator.canBeTruthy(subtype)) {
return evaluator.removeFalsinessFromType(subtype);
}
} else {
if (evaluator.canBeFalsy(subtype)) {
return evaluator.removeTruthinessFromType(subtype);
}
}
return undefined;
});
}; };
} }
@ -573,6 +581,22 @@ function getDeclsForLocalVar(
return reachableDecls.length > 0 ? reachableDecls : undefined; return reachableDecls.length > 0 ? reachableDecls : undefined;
} }
// Narrow the type based on whether the subtype can be true or false.
function narrowTypeForTruthiness(evaluator: TypeEvaluator, type: Type, isPositiveTest: boolean) {
return mapSubtypes(type, (subtype) => {
if (isPositiveTest) {
if (evaluator.canBeTruthy(subtype)) {
return evaluator.removeFalsinessFromType(subtype);
}
} else {
if (evaluator.canBeFalsy(subtype)) {
return evaluator.removeTruthinessFromType(subtype);
}
}
return undefined;
});
}
// Handle type narrowing for expressions of the form "a[I] is None" and "a[I] is not None" where // Handle type narrowing for expressions of the form "a[I] is None" and "a[I] is not None" where
// I is an integer and a is a union of Tuples with known lengths and entry types. // I is an integer and a is a union of Tuples with known lengths and entry types.
function narrowTupleTypeForIsNone(evaluator: TypeEvaluator, type: Type, isPositiveTest: boolean, indexValue: number) { function narrowTupleTypeForIsNone(evaluator: TypeEvaluator, type: Type, isPositiveTest: boolean, indexValue: number) {

View File

@ -1,6 +1,6 @@
# This sample tests type narrowing for falsy and truthy values. # This sample tests type narrowing for falsy and truthy values.
from typing import List, Literal, Union from typing import List, Literal, Optional, Union
class A: class A:
@ -27,3 +27,17 @@ def func1(x: Union[int, List[int], A, B, C, D, None]) -> None:
t1: Literal["int | List[int] | A | B | D"] = reveal_type(x) t1: Literal["int | List[int] | A | B | D"] = reveal_type(x)
else: else:
t2: Literal["int | List[int] | B | C | None"] = reveal_type(x) t2: Literal["int | List[int] | B | C | None"] = reveal_type(x)
def func2(maybe_int: Optional[int]):
if bool(maybe_int):
t1: Literal["int"] = reveal_type(maybe_int)
else:
t2: Literal["int | None"] = reveal_type(maybe_int)
def func3(maybe_a: Optional[A]):
if bool(maybe_a):
t1: Literal["A"] = reveal_type(maybe_a)
else:
t2: Literal["None"] = reveal_type(maybe_a)