Added support for parameter type inference based on annotated base class method signatures and on default argument expressions.

This commit is contained in:
Eric Traut 2022-02-22 12:42:14 -07:00
parent 152c10116e
commit 450524d004
4 changed files with 158 additions and 38 deletions

View File

@ -191,6 +191,40 @@ def func2(p_int: int, p_str: str, p_flt: float):
var2 = func1(p_str, p_flt, p_int)
```
### Parameter Type Inference
Input parameters for functions and methods typically require type annotations. There are several cases where Pyright may be able to infer a parameters type if it is unannotated.
For instance methods, the first parameter (named `self` by convention) is inferred to be type `Self`.
For class methods, the first parameter (named `cls` by convention) is inferred to be type `type[Self]`.
For other unannotated parameters within a method, Pyright looks for a method of the same name implemented in a base class. If the corresponding method in the base class has the same signature (the same number of parameters with the same names), no overloads, and annotated parameter types, the type annotation from this method is “inherited” for the corresponding parameter in the child class method.
```python
class Parent:
def method1(self, a: int, b: str) -> float:
...
class Child(Parent):
def method1(self, a, b):
return a
reveal_type(Child.method1) # (self: Child, a: int, b: int) -> int
```
When parameter types are inherited from a base class method, the return type is not inherited. Instead, normal return type inference techniques are used.
If the type of an unannotated parameter cannot be inferred using any of the above techniques and the parameter has a default argument expression associated with it, the parameter type is inferred from the default argument type. If the default argument is `None`, the inferred type is `Unknown | None`.
```python
def func(a, b=0, c=None):
pass
reveal_type(func) # (a: Unknown, b: int, c: Unknown | None) -> None
```
### Literals
Python 3.8 introduced support for _literal types_. This allows a type checker like Pyright to track specific literal values of str, bytes, int, bool, and enum values. As with other types, literal types can be declared.

View File

@ -14890,15 +14890,23 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
paramsArePositionOnly = false;
}
// If there was no annotation for the parameter, infer its type if possible.
let isTypeInferred = false;
if (!paramType) {
isTypeInferred = true;
paramType = inferParameterType(node, functionType.details.flags, index, containingClassType);
}
const functionParam: FunctionParameter = {
category: param.category,
name: param.name ? param.name.value : undefined,
hasDefault: !!param.defaultValue,
defaultValueExpression: param.defaultValue,
defaultType: defaultValueType,
type: paramType || UnknownType.create(),
type: paramType ?? UnknownType.create(),
typeAnnotation: paramTypeNode,
hasDeclaredType: !!paramTypeNode,
isTypeInferred,
};
FunctionType.addParameter(functionType, functionParam);
@ -14919,25 +14927,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
});
}
if (containingClassType) {
// If the first parameter doesn't have an explicit type annotation,
// provide a type if it's an instance, class or constructor method.
if (functionType.details.parameters.length > 0) {
const typeAnnotation = getTypeAnnotationForParameter(node, 0);
if (!typeAnnotation) {
const inferredParamType = inferFirstParamType(functionType.details.flags, containingClassType);
if (inferredParamType) {
functionType.details.parameters[0].type = inferredParamType;
if (!isAnyOrUnknown(inferredParamType)) {
functionType.details.parameters[0].isTypeInferred = true;
}
paramTypes[0] = inferredParamType;
}
}
}
}
// Update the types for the nodes associated with the parameters.
paramTypes.forEach((paramType, index) => {
const paramNameNode = node.parameters[index].name;
@ -15095,14 +15084,77 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return type;
}
// Synthesizes the "self" or "cls" parameter type if they are not explicitly annotated.
function inferFirstParamType(flags: FunctionTypeFlags, containingClassType: ClassType): Type | undefined {
if ((flags & FunctionTypeFlags.StaticMethod) === 0) {
if (containingClassType) {
const hasClsParam =
(flags & (FunctionTypeFlags.ClassMethod | FunctionTypeFlags.ConstructorMethod)) !== 0;
return synthesizeTypeVarForSelfCls(containingClassType, hasClsParam);
// Attempts to infer an unannotated parameter type from available context.
function inferParameterType(
functionNode: FunctionNode,
functionFlags: FunctionTypeFlags,
paramIndex: number,
containingClassType: ClassType | undefined
) {
// Is the function a method within a class? If so, see if a base class
// defines the same method and provides annotations.
if (containingClassType) {
if (paramIndex === 0) {
if ((functionFlags & FunctionTypeFlags.StaticMethod) === 0) {
const hasClsParam =
(functionFlags & (FunctionTypeFlags.ClassMethod | FunctionTypeFlags.ConstructorMethod)) !== 0;
return synthesizeTypeVarForSelfCls(containingClassType, hasClsParam);
}
}
const methodName = functionNode.name.value;
const baseClassMemberInfo = lookUpClassMember(
containingClassType,
methodName,
ClassMemberLookupFlags.SkipOriginalClass
);
if (baseClassMemberInfo) {
const memberDecls = baseClassMemberInfo.symbol.getDeclarations();
if (memberDecls.length === 1 && memberDecls[0].type === DeclarationType.Function) {
const baseClassMethodNode = memberDecls[0].node;
// Does the signature match exactly with the exception of annotations?
if (
baseClassMethodNode.parameters.length === functionNode.parameters.length &&
baseClassMethodNode.parameters.every((param, index) => {
const overrideParam = functionNode.parameters[index];
return (
overrideParam.name?.value === param.name?.value &&
overrideParam.category === param.category
);
})
) {
const baseClassParam = baseClassMethodNode.parameters[paramIndex];
const baseClassParamAnnotation =
baseClassParam.typeAnnotation ?? baseClassParam.typeAnnotationComment;
if (baseClassParamAnnotation) {
return getTypeOfParameterAnnotation(
baseClassParamAnnotation,
functionNode.parameters[paramIndex].category
);
}
}
}
}
}
// If the parameter has a default argument value, we may be able to infer its
// type from this information.
const paramValueExpr = functionNode.parameters[paramIndex].defaultValue;
if (paramValueExpr) {
const defaultValueType = getTypeOfExpression(
paramValueExpr,
/* expectedType */ undefined,
EvaluatorFlags.ConvertEllipsisToAny
).type;
if (isNoneInstance(defaultValueType)) {
// Infer Optional[Unknown] in this case.
return combineTypes([NoneType.createInstance(), UnknownType.create()]);
}
return stripLiteralValue(defaultValueType);
}
return undefined;
@ -16489,19 +16541,26 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return;
}
// We may be able to infer the type of the first parameter.
if (paramIndex === 0) {
const containingClassNode = ParseTreeUtils.getEnclosingClass(functionNode, /* stopAtFunction */ true);
if (containingClassNode) {
const classInfo = getTypeOfClass(containingClassNode);
if (classInfo) {
const functionFlags = getFunctionFlagsFromDecorators(functionNode, /* isInClass */ true);
// If the first parameter doesn't have an explicit type annotation,
// provide a type if it's an instance, class or constructor method.
const inferredParamType = inferFirstParamType(functionFlags, classInfo.classType);
const containingClassNode = ParseTreeUtils.getEnclosingClass(functionNode, /* stopAtFunction */ true);
if (containingClassNode) {
const classInfo = getTypeOfClass(containingClassNode);
if (classInfo) {
// See if the function is a method in a child class. We may be able to
// infer the type of the parameter from a method of the same name in
// a parent class if it has an annotated type.
const functionFlags = getFunctionFlagsFromDecorators(functionNode, /* isInClass */ true);
const inferredParamType = inferParameterType(
functionNode,
functionFlags,
paramIndex,
classInfo.classType
);
if (inferredParamType) {
writeTypeCache(
node.name!,
inferredParamType || UnknownType.create(),
transformVariadicParamType(node, node.category, inferredParamType),
EvaluatorFlags.None,
/* isIncomplete */ false
);

View File

@ -0,0 +1,21 @@
# This sample tests the logic that infers parameter types based on
# default argument values or annotated base class methods.
class Parent:
def func1(self, a: int, b: str) -> float:
...
class Child(Parent):
def func1(self, a, b):
reveal_type(self, expected_text="Self@Child")
reveal_type(a, expected_text="int")
reveal_type(b, expected_text="str")
return a
def func2(a, b=0, c=None):
reveal_type(a, expected_text="Unknown")
reveal_type(b, expected_text="int")
reveal_type(c, expected_text="Unknown | None")

View File

@ -1250,3 +1250,9 @@ test('LiteralString1', () => {
TestUtils.validateResults(analysisResults, 6);
});
test('ParamInference1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['paramInference1.py']);
TestUtils.validateResults(analysisResults, 0);
});