mirror of
https://github.com/microsoft/pyright.git
synced 2024-08-16 11:20:22 +03:00
Added support for bool(x)
type guard.
This commit is contained in:
parent
00067d22d8
commit
046eab4a8d
@ -162,6 +162,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
|
||||
* `issubclass(x, T)` (where T is a type or a tuple of types)
|
||||
* `callable(x)`
|
||||
* `f(x)` (where f is a user-defined type guard as defined in [PEP 647](https://www.python.org/dev/peps/pep-0647/))
|
||||
* `bool(x)` (where x is any expression that is statically verifiable to be thruthy or falsy in all cases).
|
||||
* `x` (where x is any expression that is statically verifiable to be truthy or falsy in all cases)
|
||||
|
||||
Expressions supported for type guards include simple names, member access chains (e.g. `a.b.c.d`), the unary `not` operator, the binary `and` and `or` operators, subscripts that are constant numbers (e.g. `a[2]`), and call expressions. Other operators (such as arithmetic operators or other subscripts) are not supported.
|
||||
|
@ -330,99 +330,119 @@ export function getTypeNarrowingCallback(
|
||||
}
|
||||
|
||||
if (testExpression.nodeType === ParseNodeType.Call) {
|
||||
if (testExpression.leftExpression.nodeType === ParseNodeType.Name) {
|
||||
// Look for "isinstance(X, Y)" or "issubclass(X, Y)".
|
||||
if (
|
||||
(testExpression.leftExpression.value === 'isinstance' ||
|
||||
testExpression.leftExpression.value === 'issubclass') &&
|
||||
testExpression.arguments.length === 2
|
||||
) {
|
||||
// Make sure the first parameter is a supported expression type
|
||||
// and the second parameter is a valid class type or a tuple
|
||||
// of valid class types.
|
||||
const isInstanceCheck = testExpression.leftExpression.value === 'isinstance';
|
||||
const arg0Expr = testExpression.arguments[0].valueExpression;
|
||||
const arg1Expr = testExpression.arguments[1].valueExpression;
|
||||
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
|
||||
const arg1Type = evaluator.getTypeOfExpression(
|
||||
arg1Expr,
|
||||
undefined,
|
||||
EvaluatorFlags.EvaluateStringLiteralAsType |
|
||||
EvaluatorFlags.ParamSpecDisallowed |
|
||||
EvaluatorFlags.TypeVarTupleDisallowed
|
||||
).type;
|
||||
const classTypeList = getIsInstanceClassTypes(arg1Type);
|
||||
if (classTypeList) {
|
||||
return (type: Type) => {
|
||||
const narrowedType = narrowTypeForIsInstance(
|
||||
evaluator,
|
||||
type,
|
||||
classTypeList,
|
||||
isInstanceCheck,
|
||||
isPositiveTest,
|
||||
/* allowIntersections */ false,
|
||||
testExpression
|
||||
);
|
||||
if (!isNever(narrowedType)) {
|
||||
return narrowedType;
|
||||
}
|
||||
const callType = evaluator.getTypeOfExpression(
|
||||
testExpression.leftExpression,
|
||||
/* expectedType */ undefined,
|
||||
EvaluatorFlags.DoNotSpecialize
|
||||
).type;
|
||||
|
||||
// Try again with intersection types allowed.
|
||||
return narrowTypeForIsInstance(
|
||||
evaluator,
|
||||
type,
|
||||
classTypeList,
|
||||
isInstanceCheck,
|
||||
isPositiveTest,
|
||||
/* allowIntersections */ true,
|
||||
testExpression
|
||||
);
|
||||
};
|
||||
}
|
||||
}
|
||||
} else if (testExpression.leftExpression.value === 'callable' && testExpression.arguments.length === 1) {
|
||||
const arg0Expr = testExpression.arguments[0].valueExpression;
|
||||
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
|
||||
// Look for "isinstance(X, Y)" or "issubclass(X, Y)".
|
||||
if (
|
||||
isFunction(callType) &&
|
||||
(callType.details.builtInName === 'isinstance' || callType.details.builtInName === 'issubclass') &&
|
||||
testExpression.arguments.length === 2
|
||||
) {
|
||||
// Make sure the first parameter is a supported expression type
|
||||
// and the second parameter is a valid class type or a tuple
|
||||
// of valid class types.
|
||||
const isInstanceCheck = callType.details.builtInName === 'isinstance';
|
||||
const arg0Expr = testExpression.arguments[0].valueExpression;
|
||||
const arg1Expr = testExpression.arguments[1].valueExpression;
|
||||
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
|
||||
const arg1Type = evaluator.getTypeOfExpression(
|
||||
arg1Expr,
|
||||
undefined,
|
||||
EvaluatorFlags.EvaluateStringLiteralAsType |
|
||||
EvaluatorFlags.ParamSpecDisallowed |
|
||||
EvaluatorFlags.TypeVarTupleDisallowed
|
||||
).type;
|
||||
const classTypeList = getIsInstanceClassTypes(arg1Type);
|
||||
if (classTypeList) {
|
||||
return (type: Type) => {
|
||||
let narrowedType = narrowTypeForCallable(
|
||||
const narrowedType = narrowTypeForIsInstance(
|
||||
evaluator,
|
||||
type,
|
||||
classTypeList,
|
||||
isInstanceCheck,
|
||||
isPositiveTest,
|
||||
testExpression,
|
||||
/* allowIntersections */ false
|
||||
/* allowIntersections */ false,
|
||||
testExpression
|
||||
);
|
||||
if (isPositiveTest && isNever(narrowedType)) {
|
||||
// Try again with intersections allowed.
|
||||
narrowedType = narrowTypeForCallable(
|
||||
evaluator,
|
||||
type,
|
||||
isPositiveTest,
|
||||
testExpression,
|
||||
/* allowIntersections */ true
|
||||
);
|
||||
if (!isNever(narrowedType)) {
|
||||
return narrowedType;
|
||||
}
|
||||
|
||||
return narrowedType;
|
||||
// Try again with intersection types allowed.
|
||||
return narrowTypeForIsInstance(
|
||||
evaluator,
|
||||
type,
|
||||
classTypeList,
|
||||
isInstanceCheck,
|
||||
isPositiveTest,
|
||||
/* allowIntersections */ true,
|
||||
testExpression
|
||||
);
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Look for "callable(X)"
|
||||
if (
|
||||
isFunction(callType) &&
|
||||
callType.details.builtInName === 'callable' &&
|
||||
testExpression.arguments.length === 1
|
||||
) {
|
||||
const arg0Expr = testExpression.arguments[0].valueExpression;
|
||||
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
|
||||
return (type: Type) => {
|
||||
let narrowedType = narrowTypeForCallable(
|
||||
evaluator,
|
||||
type,
|
||||
isPositiveTest,
|
||||
testExpression,
|
||||
/* allowIntersections */ false
|
||||
);
|
||||
if (isPositiveTest && isNever(narrowedType)) {
|
||||
// Try again with intersections allowed.
|
||||
narrowedType = narrowTypeForCallable(
|
||||
evaluator,
|
||||
type,
|
||||
isPositiveTest,
|
||||
testExpression,
|
||||
/* allowIntersections */ true
|
||||
);
|
||||
}
|
||||
|
||||
return narrowedType;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Look for "bool(X)"
|
||||
if (
|
||||
isInstantiableClass(callType) &&
|
||||
ClassType.isBuiltIn(callType, 'bool') &&
|
||||
testExpression.arguments.length === 1 &&
|
||||
!testExpression.arguments[0].name
|
||||
) {
|
||||
if (ParseTreeUtils.isMatchingExpression(reference, testExpression.arguments[0].valueExpression)) {
|
||||
return (type: Type) => {
|
||||
return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Look for a TypeGuard assertion function.
|
||||
if (testExpression.arguments.length >= 1) {
|
||||
const arg0Expr = testExpression.arguments[0].valueExpression;
|
||||
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
|
||||
const functionType = evaluator.getTypeOfExpression(
|
||||
testExpression.leftExpression,
|
||||
/* expectedType */ undefined,
|
||||
EvaluatorFlags.DoNotSpecialize
|
||||
).type;
|
||||
|
||||
// Does this look like it's a custom type guard function?
|
||||
if (
|
||||
isFunction(functionType) &&
|
||||
functionType.details.declaredReturnType &&
|
||||
isClassInstance(functionType.details.declaredReturnType) &&
|
||||
ClassType.isBuiltIn(functionType.details.declaredReturnType, 'TypeGuard')
|
||||
isFunction(callType) &&
|
||||
callType.details.declaredReturnType &&
|
||||
isClassInstance(callType.details.declaredReturnType) &&
|
||||
ClassType.isBuiltIn(callType.details.declaredReturnType, 'TypeGuard')
|
||||
) {
|
||||
// Evaluate the type guard call expression.
|
||||
const functionReturnType = evaluator.getTypeOfExpression(testExpression).type;
|
||||
@ -446,19 +466,7 @@ export function getTypeNarrowingCallback(
|
||||
|
||||
if (ParseTreeUtils.isMatchingExpression(reference, testExpression)) {
|
||||
return (type: Type) => {
|
||||
// Narrow the type based on whether the subtype can be true or false.
|
||||
return mapSubtypes(type, (subtype) => {
|
||||
if (isPositiveTest) {
|
||||
if (evaluator.canBeTruthy(subtype)) {
|
||||
return evaluator.removeFalsinessFromType(subtype);
|
||||
}
|
||||
} else {
|
||||
if (evaluator.canBeFalsy(subtype)) {
|
||||
return evaluator.removeTruthinessFromType(subtype);
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
return narrowTypeForTruthiness(evaluator, type, isPositiveTest);
|
||||
};
|
||||
}
|
||||
|
||||
@ -573,6 +581,22 @@ function getDeclsForLocalVar(
|
||||
return reachableDecls.length > 0 ? reachableDecls : undefined;
|
||||
}
|
||||
|
||||
// Narrow the type based on whether the subtype can be true or false.
|
||||
function narrowTypeForTruthiness(evaluator: TypeEvaluator, type: Type, isPositiveTest: boolean) {
|
||||
return mapSubtypes(type, (subtype) => {
|
||||
if (isPositiveTest) {
|
||||
if (evaluator.canBeTruthy(subtype)) {
|
||||
return evaluator.removeFalsinessFromType(subtype);
|
||||
}
|
||||
} else {
|
||||
if (evaluator.canBeFalsy(subtype)) {
|
||||
return evaluator.removeTruthinessFromType(subtype);
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
|
||||
// Handle type narrowing for expressions of the form "a[I] is None" and "a[I] is not None" where
|
||||
// I is an integer and a is a union of Tuples with known lengths and entry types.
|
||||
function narrowTupleTypeForIsNone(evaluator: TypeEvaluator, type: Type, isPositiveTest: boolean, indexValue: number) {
|
||||
|
@ -1,6 +1,6 @@
|
||||
# This sample tests type narrowing for falsy and truthy values.
|
||||
|
||||
from typing import List, Literal, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
|
||||
class A:
|
||||
@ -27,3 +27,17 @@ def func1(x: Union[int, List[int], A, B, C, D, None]) -> None:
|
||||
t1: Literal["int | List[int] | A | B | D"] = reveal_type(x)
|
||||
else:
|
||||
t2: Literal["int | List[int] | B | C | None"] = reveal_type(x)
|
||||
|
||||
|
||||
def func2(maybe_int: Optional[int]):
|
||||
if bool(maybe_int):
|
||||
t1: Literal["int"] = reveal_type(maybe_int)
|
||||
else:
|
||||
t2: Literal["int | None"] = reveal_type(maybe_int)
|
||||
|
||||
|
||||
def func3(maybe_a: Optional[A]):
|
||||
if bool(maybe_a):
|
||||
t1: Literal["A"] = reveal_type(maybe_a)
|
||||
else:
|
||||
t2: Literal["None"] = reveal_type(maybe_a)
|
||||
|
Loading…
Reference in New Issue
Block a user