diff --git a/docs/type-inference.md b/docs/type-inference.md index a374c69c5..2f9833a4e 100644 --- a/docs/type-inference.md +++ b/docs/type-inference.md @@ -225,6 +225,13 @@ def func(a, b=0, c=None): reveal_type(func) # (a: Unknown, b: int, c: Unknown | None) -> None ``` +This inference technique also applies to lambdas whose input parameters include default arguments. + +```python +cb = lambda x = "": x +reveal_type(cb) # (x: str = "" -> str) +``` + #### Literals Python 3.8 introduced support for _literal types_. This allows a type checker like Pyright to track specific literal values of str, bytes, int, bool, and enum values. As with other types, literal types can be declared. diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 00eb1d794..d78a6c2b2 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -13835,6 +13835,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // its type from the default value expression. paramType = getTypeOfExpression(param.defaultValue, undefined, inferenceContext).type; } + } else if (param.defaultValue) { + // If there is no inference context but we have a default value, + // use the default value to infer the parameter's type. + paramType = inferParameterTypeFromDefaultValue(param.defaultValue); } if (param.name) { @@ -17638,54 +17642,58 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // type from this information. const paramValueExpr = functionNode.parameters[paramIndex].defaultValue; if (paramValueExpr) { - const defaultValueType = getTypeOfExpression(paramValueExpr, EvaluatorFlags.ConvertEllipsisToAny).type; - - let inferredParamType: Type | undefined; - - // Is the default value a "None" or an instance of some private class (one - // whose name starts with an underscore)? If so, we will assume that the - // value is a singleton sentinel. The actual supported type is going to be - // a union of this type and Unknown. - if ( - isNoneInstance(defaultValueType) || - (isClassInstance(defaultValueType) && isPrivateOrProtectedName(defaultValueType.details.name)) - ) { - inferredParamType = combineTypes([defaultValueType, UnknownType.create()]); - } else { - let skipInference = false; - - if (isFunction(defaultValueType) || isOverloadedFunction(defaultValueType)) { - // Do not infer parameter types that use a lambda or another function as a - // default value. We're likely to generate false positives in this case. - // It's not clear whether parameters should be positional-only or not. - skipInference = true; - } else if ( - isClassInstance(defaultValueType) && - ClassType.isBuiltIn(defaultValueType, ['tuple', 'list', 'set', 'dict']) - ) { - // Do not infer certain types like tuple because it's likely to be - // more restrictive (narrower) than intended. - skipInference = true; - } - - if (!skipInference) { - inferredParamType = stripLiteralValue(defaultValueType); - } - } - - if (inferredParamType) { - const fileInfo = AnalyzerNodeInfo.getFileInfo(functionNode); - if (fileInfo.isInPyTypedPackage && !fileInfo.isStubFile) { - inferredParamType = TypeBase.cloneForAmbiguousType(inferredParamType); - } - } - - return inferredParamType; + return inferParameterTypeFromDefaultValue(paramValueExpr); } return undefined; } + function inferParameterTypeFromDefaultValue(paramValueExpr: ExpressionNode) { + const defaultValueType = getTypeOfExpression(paramValueExpr, EvaluatorFlags.ConvertEllipsisToAny).type; + + let inferredParamType: Type | undefined; + + // Is the default value a "None" or an instance of some private class (one + // whose name starts with an underscore)? If so, we will assume that the + // value is a singleton sentinel. The actual supported type is going to be + // a union of this type and Unknown. + if ( + isNoneInstance(defaultValueType) || + (isClassInstance(defaultValueType) && isPrivateOrProtectedName(defaultValueType.details.name)) + ) { + inferredParamType = combineTypes([defaultValueType, UnknownType.create()]); + } else { + let skipInference = false; + + if (isFunction(defaultValueType) || isOverloadedFunction(defaultValueType)) { + // Do not infer parameter types that use a lambda or another function as a + // default value. We're likely to generate false positives in this case. + // It's not clear whether parameters should be positional-only or not. + skipInference = true; + } else if ( + isClassInstance(defaultValueType) && + ClassType.isBuiltIn(defaultValueType, ['tuple', 'list', 'set', 'dict']) + ) { + // Do not infer certain types like tuple because it's likely to be + // more restrictive (narrower) than intended. + skipInference = true; + } + + if (!skipInference) { + inferredParamType = stripLiteralValue(defaultValueType); + } + } + + if (inferredParamType) { + const fileInfo = AnalyzerNodeInfo.getFileInfo(paramValueExpr); + if (fileInfo.isInPyTypedPackage && !fileInfo.isStubFile) { + inferredParamType = TypeBase.cloneForAmbiguousType(inferredParamType); + } + } + + return inferredParamType; + } + // Transforms the parameter type based on its category. If it's a simple parameter, // no transform is applied. If it's a var-arg or keyword-arg parameter, the type // is wrapped in a List or Dict. diff --git a/packages/pyright-internal/src/tests/samples/lambda14.py b/packages/pyright-internal/src/tests/samples/lambda14.py new file mode 100644 index 000000000..31425eb04 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/lambda14.py @@ -0,0 +1,9 @@ +# This sample tests type inference for a lambda that has no inference +# context but has a default argument value. + +lambda1 = lambda x="": x +reveal_type(lambda1, expected_text='(x: str = "") -> str') + +lambda2 = lambda x=None: x +reveal_type(lambda2, expected_text="(x: Unknown | None = None) -> (Unknown | None)") + diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index 45adf8ac6..550600ca7 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -741,6 +741,12 @@ test('Lambda13', () => { TestUtils.validateResults(analysisResults, 0); }); +test('Lambda14', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['lambda14.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('Call1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['call1.py']);