Implemented provisional "TypeGuard" functionality that allows for user-defined type guard functions. This must still go through a spec'ing and ratification process before it is finalized. Until then, details could change.

This commit is contained in:
Eric Traut 2020-10-06 00:10:38 -07:00
parent b01e620bd1
commit d638cc9b31
8 changed files with 150 additions and 1 deletions

View File

@ -2084,6 +2084,11 @@ export class Binder extends ParseTreeWalker {
) {
return this._isNarrowingExpression(expression.arguments[0].valueExpression, expressionList);
}
// Is this potentially a call to a user-defined type guard function?
if (expression.arguments.length >= 1) {
return this._isNarrowingExpression(expression.arguments[0].valueExpression, expressionList);
}
}
}
@ -2808,6 +2813,7 @@ export class Binder extends ParseTreeWalker {
TypeAlias: true,
OrderedDict: true,
Concatenate: true,
TypeGuard: true,
};
const assignedName = assignedNameNode.value;

View File

@ -8714,6 +8714,38 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return type;
}
// Creates a "TypeGuard" type.
function createTypeGuardType(errorNode: ParseNode, classType: ClassType, typeArgs: TypeResult[] | undefined): Type {
// The first time that we use the TypeGuard special type, it won't have
// a type parameter. We'll synthesize one here.
if (classType.details.typeParameters.length === 0) {
classType.details.typeParameters.push(
TypeVarType.createInstance('_T', /* isParamSpec */ false, /* isSynthesized */ true)
);
}
if (!typeArgs || typeArgs.length !== 1) {
addError(Localizer.Diagnostic.typeGuardArgCount(), errorNode);
}
let typeArg: Type;
if (typeArgs && typeArgs.length > 0) {
typeArg = typeArgs[0].type;
if (isEllipsisType(typeArg)) {
addError(Localizer.Diagnostic.ellipsisContext(), typeArgs[0].node);
} else if (isModule(typeArg)) {
addError(Localizer.Diagnostic.moduleContext(), typeArgs[0].node);
} else if (isParamSpecType(typeArg)) {
addError(Localizer.Diagnostic.paramSpecContext(), typeArgs[0].node);
}
} else {
typeArg = UnknownType.create();
}
return ClassType.cloneForSpecialization(classType, [convertToInstance(typeArg)], !!typeArgs);
}
// Creates a "Final" type.
function createFinalType(errorNode: ParseNode, typeArgs: TypeResult[] | undefined, flags: EvaluatorFlags): Type {
if (flags & EvaluatorFlags.FinalDisallowed) {
@ -9052,6 +9084,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
Annotated: { alias: '', module: 'builtins' },
TypeAlias: { alias: '', module: 'builtins' },
Concatenate: { alias: '', module: 'builtins' },
TypeGuard: { alias: 'bool', module: 'builtins' },
};
const aliasMapEntry = specialTypes[assignedName];
@ -12559,6 +12592,38 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}
}
}
if (testExpression.arguments.length >= 1) {
const functionType = getTypeOfExpression(testExpression.leftExpression).type;
// Does this look like it's a custom type guard function?
if (
isFunction(functionType) &&
functionType.details.declaredReturnType &&
isObject(functionType.details.declaredReturnType) &&
ClassType.isBuiltIn(functionType.details.declaredReturnType.classType, 'TypeGuard')
) {
const arg0Expr = testExpression.arguments[0].valueExpression;
if (ParseTreeUtils.isMatchingExpression(reference, arg0Expr)) {
// Evaluate the type guard call expression.
const functionReturnType = getTypeOfExpression(testExpression).type;
if (
isObject(functionReturnType) &&
ClassType.isBuiltIn(functionReturnType.classType, 'TypeGuard')
) {
const typeGuardTypeArgs = functionReturnType.classType.typeArguments;
const typeGuardTypeArg =
typeGuardTypeArgs && typeGuardTypeArgs.length > 0
? typeGuardTypeArgs[0]
: UnknownType.create();
return (type: Type) => {
return isPositiveTest ? typeGuardTypeArg : type;
};
}
}
}
}
}
if (ParseTreeUtils.isMatchingExpression(reference, testExpression)) {
@ -12984,6 +13049,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
case 'Concatenate': {
return createConcatenateType(errorNode, classType, typeArgs);
}
case 'TypeGuard': {
return createTypeGuardType(errorNode, classType, typeArgs);
}
}
}

View File

@ -1442,7 +1442,7 @@ function _specializeClassType(
recursionLevel: number
): ClassType {
// Handle the common case where the class has no type parameters.
if (ClassType.getTypeParameters(classType).length === 0) {
if (ClassType.getTypeParameters(classType).length === 0 && !ClassType.isSpecialBuiltIn(classType)) {
return classType;
}
@ -1663,6 +1663,7 @@ export function requiresTypeArguments(classType: ClassType) {
'Final',
'Literal',
'Annotated',
'TypeGuard',
];
if (specialClasses.some((t) => t === classType.details.name)) {
return true;

View File

@ -566,6 +566,7 @@ export namespace Localizer {
export const typedDictTotalParam = () => getRawString('Diagnostic.typedDictTotalParam');
export const typeExpectedClass = () =>
new ParameterizedString<{ type: string }>(getRawString('Diagnostic.typeExpectedClass'));
export const typeGuardArgCount = () => getRawString('Diagnostic.typeGuardArgCount');
export const typeNotAwaitable = () =>
new ParameterizedString<{ type: string }>(getRawString('Diagnostic.typeNotAwaitable'));
export const typeNotCallable = () =>

View File

@ -280,6 +280,7 @@
"typedDictSet": "Could not assign item in TypedDict",
"typedDictTotalParam": "Expected \"total\" parameter to have a value of True or False",
"typeExpectedClass": "Expected class type but received \"{type}\"",
"typeGuardArgCount": "Expected a single type argument after \"TypeGuard\"",
"typeNotAwaitable": "\"{type}\" is not awaitable",
"typeNotCallable": "\"{expression}\" has type \"{type}\" and is not callable",
"typeNotIntantiable": "\"{type}\" cannot be instantiated",

View File

@ -0,0 +1,63 @@
# This sample tests the TypeGuard functionality
# that allows user-defined functions to perform
# conditional type narrowing.
import os
from typing import Any, List, Literal, Tuple, TypeVar, Union
from typing_extensions import TypeGuard
_T = TypeVar("_T")
def is_two_element_tuple(a: Tuple[_T, ...]) -> TypeGuard[Tuple[_T, _T]]:
return True
def func1(a: Tuple[int, ...]):
if is_two_element_tuple(a):
t1: Literal["Tuple[int, int]"] = reveal_type(a)
else:
t2: Literal["Tuple[int, ...]"] = reveal_type(a)
def is_string_list(val: List[Any], allow_zero_entries: bool) -> TypeGuard[List[str]]:
if allow_zero_entries and len(val) == 0:
return True
return all(isinstance(x, str) for x in val)
def func2(a: List[Union[str, int]]):
if is_string_list(a, True):
t1: Literal["List[str]"] = reveal_type(a)
else:
t2: Literal["List[str | int]"] = reveal_type(a)
# This should generate an error because TypeGuard
# has no type argument.
def bad1(a: int) -> TypeGuard:
return True
# This should generate an error because TypeGuard
# has too many type arguments.
def bad2(a: int) -> TypeGuard[str, int]:
return True
# This should generate an error because TypeGuard
# does not accept an elipsis.
def bad3(a: int) -> TypeGuard[...]:
return True
# This should generate an error because TypeGuard
# has does not accept a module.
def bad4(a: int) -> TypeGuard[os]:
return True
def bad5(a: int) -> TypeGuard[int]:
# This should generate an error because only
# bool values can be returned.
return 3

View File

@ -1111,3 +1111,9 @@ test('Enums3', () => {
TestUtils.validateResults(analysisResults, 0);
});
test('TypeGuard1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeGuard1.py']);
TestUtils.validateResults(analysisResults, 5);
});

View File

@ -103,3 +103,5 @@ class TypeAlias: ...
class SupportsIndex(Protocol, metaclass=abc.ABCMeta):
@abc.abstractmethod
def __index__(self) -> int: ...
TypeGuard: _SpecialForm = ...