diff --git a/docs/type-inference.md b/docs/type-inference.md index b7a6b3fc7..18f5d5ca2 100644 --- a/docs/type-inference.md +++ b/docs/type-inference.md @@ -191,6 +191,40 @@ def func2(p_int: int, p_str: str, p_flt: float): var2 = func1(p_str, p_flt, p_int) ``` +### Parameter Type Inference + +Input parameters for functions and methods typically require type annotations. There are several cases where Pyright may be able to infer a parameter’s type if it is unannotated. + +For instance methods, the first parameter (named `self` by convention) is inferred to be type `Self`. + +For class methods, the first parameter (named `cls` by convention) is inferred to be type `type[Self]`. + +For other unannotated parameters within a method, Pyright looks for a method of the same name implemented in a base class. If the corresponding method in the base class has the same signature (the same number of parameters with the same names), no overloads, and annotated parameter types, the type annotation from this method is “inherited” for the corresponding parameter in the child class method. + +```python +class Parent: + def method1(self, a: int, b: str) -> float: + ... + + +class Child(Parent): + def method1(self, a, b): + return a + +reveal_type(Child.method1) # (self: Child, a: int, b: int) -> int +``` + +When parameter types are inherited from a base class method, the return type is not inherited. Instead, normal return type inference techniques are used. + +If the type of an unannotated parameter cannot be inferred using any of the above techniques and the parameter has a default argument expression associated with it, the parameter type is inferred from the default argument type. If the default argument is `None`, the inferred type is `Unknown | None`. + +```python +def func(a, b=0, c=None): + pass + +reveal_type(func) # (a: Unknown, b: int, c: Unknown | None) -> None +``` + ### Literals Python 3.8 introduced support for _literal types_. This allows a type checker like Pyright to track specific literal values of str, bytes, int, bool, and enum values. As with other types, literal types can be declared. diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index c252e2c93..0d634acaa 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -14890,15 +14890,23 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions paramsArePositionOnly = false; } + // If there was no annotation for the parameter, infer its type if possible. + let isTypeInferred = false; + if (!paramType) { + isTypeInferred = true; + paramType = inferParameterType(node, functionType.details.flags, index, containingClassType); + } + const functionParam: FunctionParameter = { category: param.category, name: param.name ? param.name.value : undefined, hasDefault: !!param.defaultValue, defaultValueExpression: param.defaultValue, defaultType: defaultValueType, - type: paramType || UnknownType.create(), + type: paramType ?? UnknownType.create(), typeAnnotation: paramTypeNode, hasDeclaredType: !!paramTypeNode, + isTypeInferred, }; FunctionType.addParameter(functionType, functionParam); @@ -14919,25 +14927,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions }); } - if (containingClassType) { - // If the first parameter doesn't have an explicit type annotation, - // provide a type if it's an instance, class or constructor method. - if (functionType.details.parameters.length > 0) { - const typeAnnotation = getTypeAnnotationForParameter(node, 0); - if (!typeAnnotation) { - const inferredParamType = inferFirstParamType(functionType.details.flags, containingClassType); - if (inferredParamType) { - functionType.details.parameters[0].type = inferredParamType; - if (!isAnyOrUnknown(inferredParamType)) { - functionType.details.parameters[0].isTypeInferred = true; - } - - paramTypes[0] = inferredParamType; - } - } - } - } - // Update the types for the nodes associated with the parameters. paramTypes.forEach((paramType, index) => { const paramNameNode = node.parameters[index].name; @@ -15095,14 +15084,77 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return type; } - // Synthesizes the "self" or "cls" parameter type if they are not explicitly annotated. - function inferFirstParamType(flags: FunctionTypeFlags, containingClassType: ClassType): Type | undefined { - if ((flags & FunctionTypeFlags.StaticMethod) === 0) { - if (containingClassType) { - const hasClsParam = - (flags & (FunctionTypeFlags.ClassMethod | FunctionTypeFlags.ConstructorMethod)) !== 0; - return synthesizeTypeVarForSelfCls(containingClassType, hasClsParam); + // Attempts to infer an unannotated parameter type from available context. + function inferParameterType( + functionNode: FunctionNode, + functionFlags: FunctionTypeFlags, + paramIndex: number, + containingClassType: ClassType | undefined + ) { + // Is the function a method within a class? If so, see if a base class + // defines the same method and provides annotations. + if (containingClassType) { + if (paramIndex === 0) { + if ((functionFlags & FunctionTypeFlags.StaticMethod) === 0) { + const hasClsParam = + (functionFlags & (FunctionTypeFlags.ClassMethod | FunctionTypeFlags.ConstructorMethod)) !== 0; + return synthesizeTypeVarForSelfCls(containingClassType, hasClsParam); + } } + + const methodName = functionNode.name.value; + const baseClassMemberInfo = lookUpClassMember( + containingClassType, + methodName, + ClassMemberLookupFlags.SkipOriginalClass + ); + + if (baseClassMemberInfo) { + const memberDecls = baseClassMemberInfo.symbol.getDeclarations(); + if (memberDecls.length === 1 && memberDecls[0].type === DeclarationType.Function) { + const baseClassMethodNode = memberDecls[0].node; + + // Does the signature match exactly with the exception of annotations? + if ( + baseClassMethodNode.parameters.length === functionNode.parameters.length && + baseClassMethodNode.parameters.every((param, index) => { + const overrideParam = functionNode.parameters[index]; + return ( + overrideParam.name?.value === param.name?.value && + overrideParam.category === param.category + ); + }) + ) { + const baseClassParam = baseClassMethodNode.parameters[paramIndex]; + const baseClassParamAnnotation = + baseClassParam.typeAnnotation ?? baseClassParam.typeAnnotationComment; + if (baseClassParamAnnotation) { + return getTypeOfParameterAnnotation( + baseClassParamAnnotation, + functionNode.parameters[paramIndex].category + ); + } + } + } + } + } + + // If the parameter has a default argument value, we may be able to infer its + // type from this information. + const paramValueExpr = functionNode.parameters[paramIndex].defaultValue; + if (paramValueExpr) { + const defaultValueType = getTypeOfExpression( + paramValueExpr, + /* expectedType */ undefined, + EvaluatorFlags.ConvertEllipsisToAny + ).type; + + if (isNoneInstance(defaultValueType)) { + // Infer Optional[Unknown] in this case. + return combineTypes([NoneType.createInstance(), UnknownType.create()]); + } + + return stripLiteralValue(defaultValueType); } return undefined; @@ -16489,19 +16541,26 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return; } - // We may be able to infer the type of the first parameter. - if (paramIndex === 0) { - const containingClassNode = ParseTreeUtils.getEnclosingClass(functionNode, /* stopAtFunction */ true); - if (containingClassNode) { - const classInfo = getTypeOfClass(containingClassNode); - if (classInfo) { - const functionFlags = getFunctionFlagsFromDecorators(functionNode, /* isInClass */ true); - // If the first parameter doesn't have an explicit type annotation, - // provide a type if it's an instance, class or constructor method. - const inferredParamType = inferFirstParamType(functionFlags, classInfo.classType); + const containingClassNode = ParseTreeUtils.getEnclosingClass(functionNode, /* stopAtFunction */ true); + if (containingClassNode) { + const classInfo = getTypeOfClass(containingClassNode); + + if (classInfo) { + // See if the function is a method in a child class. We may be able to + // infer the type of the parameter from a method of the same name in + // a parent class if it has an annotated type. + const functionFlags = getFunctionFlagsFromDecorators(functionNode, /* isInClass */ true); + const inferredParamType = inferParameterType( + functionNode, + functionFlags, + paramIndex, + classInfo.classType + ); + + if (inferredParamType) { writeTypeCache( node.name!, - inferredParamType || UnknownType.create(), + transformVariadicParamType(node, node.category, inferredParamType), EvaluatorFlags.None, /* isIncomplete */ false ); diff --git a/packages/pyright-internal/src/tests/samples/paramInference1.py b/packages/pyright-internal/src/tests/samples/paramInference1.py new file mode 100644 index 000000000..7eca2773c --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/paramInference1.py @@ -0,0 +1,21 @@ +# This sample tests the logic that infers parameter types based on +# default argument values or annotated base class methods. + + +class Parent: + def func1(self, a: int, b: str) -> float: + ... + + +class Child(Parent): + def func1(self, a, b): + reveal_type(self, expected_text="Self@Child") + reveal_type(a, expected_text="int") + reveal_type(b, expected_text="str") + return a + + +def func2(a, b=0, c=None): + reveal_type(a, expected_text="Unknown") + reveal_type(b, expected_text="int") + reveal_type(c, expected_text="Unknown | None") diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index bee21d5ad..1c7a0a000 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -1250,3 +1250,9 @@ test('LiteralString1', () => { TestUtils.validateResults(analysisResults, 6); }); + +test('ParamInference1', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['paramInference1.py']); + + TestUtils.validateResults(analysisResults, 0); +});