diff --git a/docs/type-concepts.md b/docs/type-concepts.md index d9c07959e..adfe0bba6 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -578,3 +578,43 @@ reveal_type(Parent.x) # object reveal_type(Child.x) # int ``` +#### Type Variable Scoping + +A type variable must be bound to a valid scope (a class, function, or type alias) before it can be used within that scope. + +Pyright displays the bound scope for a type variable using an `@` symbol. For example, `T@func` means that type variable `T` is bound to function `func`. + +```python +S = TypeVar("S") +T = TypeVar("T") + +def func(a: T) -> T: + b: T = a # T refers to T@func + reveal_type(b) # T@func + + c: S # Error: S has no bound scope in this context + return b +``` + +When a TypeVar or ParamSpec appears within parameter or return type annotations for a function and it is not already bound to an outer scope, it is normally bound to the function. As an exception to this rule, if the TypeVar or ParamSpec appears only within the return type annotation of the function and only within a single Callable in the return type, it is bound to that Callable rather than the function. This allows a function to return a generic Callable. + +```python +# T is bound to func1 because it appears in a parameter type annotation. +def func1(a: T) -> Callable[[T], T]: + a: T # OK because T is bound to func1 + +# T is bound to the return callable rather than func2 because it appears +# only within a return Callable. +def func2() -> Callable[[T], T]: + a: T # Error because T has no bound scope in this context + +# T is bound to func3 because it appears outside of a Callable. +def func3() -> Callable[[T], T] | T: + ... + +# This scoping logic applies also to type aliases used within a return +# type annotation. T is bound to the return Callable rather than func4. +Transform = Callable[[S], S] +def func4() -> Transform[T]: + ... +``` diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 80030b9e6..cb2b157c2 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -27,7 +27,7 @@ import { DiagnosticRule } from '../common/diagnosticRules'; import { convertOffsetsToRange, convertOffsetToPosition } from '../common/positionUtils'; import { PythonVersion } from '../common/pythonVersion'; import { TextRange } from '../common/textRange'; -import { Localizer } from '../localization/localize'; +import { Localizer, ParameterizedString } from '../localization/localize'; import { ArgumentCategory, AssignmentNode, @@ -271,6 +271,7 @@ import { isTupleClass, isTypeAliasPlaceholder, isTypeAliasRecursive, + isTypeVarLimitedToCallable, isUnboundedTupleClass, isUnionableType, isVarianceOfTypeArgumentCompatible, @@ -396,6 +397,12 @@ export interface DescriptorTypeResult { isAsymmetricDescriptor: boolean; } +interface ScopedTypeVarResult { + type: TypeVarType; + isRescoped: boolean; + foundInterveningClass: boolean; +} + interface AliasMapEntry { alias: string; module: 'builtins' | 'collections' | 'self'; @@ -4434,7 +4441,19 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions const outerFunctionScope = ParseTreeUtils.getEnclosingClassOrFunction(enclosingScope); if (outerFunctionScope?.nodeType === ParseNodeType.Function) { - enclosingScope = outerFunctionScope; + if (scopedTypeVarInfo.isRescoped) { + addDiagnostic( + AnalyzerNodeInfo.getFileInfo(node).diagnosticRuleSet + .reportGeneralTypeIssues, + DiagnosticRule.reportGeneralTypeIssues, + Localizer.Diagnostic.paramSpecScopedToReturnType().format({ + name: type.details.name, + }), + node + ); + } else { + enclosingScope = outerFunctionScope; + } } else if (!scopedTypeVarInfo.type.scopeId) { addDiagnostic( AnalyzerNodeInfo.getFileInfo(node).diagnosticRuleSet.reportGeneralTypeIssues, @@ -4488,9 +4507,16 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions (type.scopeId === undefined || scopedTypeVarInfo.foundInterveningClass) && !type.details.isSynthesized ) { - const message = isParamSpec(type) - ? Localizer.Diagnostic.paramSpecNotUsedByOuterScope() - : Localizer.Diagnostic.typeVarNotUsedByOuterScope(); + let message: ParameterizedString<{ name: string }>; + if (scopedTypeVarInfo.isRescoped) { + message = isParamSpec(type) + ? Localizer.Diagnostic.paramSpecScopedToReturnType() + : Localizer.Diagnostic.typeVarScopedToReturnType(); + } else { + message = isParamSpec(type) + ? Localizer.Diagnostic.paramSpecNotUsedByOuterScope() + : Localizer.Diagnostic.typeVarNotUsedByOuterScope(); + } addDiagnostic( AnalyzerNodeInfo.getFileInfo(node).diagnosticRuleSet.reportGeneralTypeIssues, DiagnosticRule.reportGeneralTypeIssues, @@ -4619,10 +4645,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // Walks up the parse tree to find a function, class, or type alias // declaration that provides the context for a type variable. - function findScopedTypeVar( - node: ExpressionNode, - type: TypeVarType - ): { type: TypeVarType; foundInterveningClass: boolean } { + function findScopedTypeVar(node: ExpressionNode, type: TypeVarType): ScopedTypeVarResult { let curNode: ParseNode | undefined = node; let nestedClassCount = 0; @@ -4648,7 +4671,15 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } else if (curNode.nodeType === ParseNodeType.Function) { const functionTypeInfo = getTypeOfFunction(curNode); if (functionTypeInfo) { - typeParametersForScope = functionTypeInfo.functionType.details.typeParameters; + const functionDetails = functionTypeInfo.functionType.details; + typeParametersForScope = functionDetails.typeParameters; + + // Was this type parameter "rescoped" to a callable found within the + // return type annotation? If so, it is not available for use within + // the function body. + if (functionDetails.rescopedTypeParameters?.some((tp) => tp.details.name === type.details.name)) { + return { type, isRescoped: true, foundInterveningClass: false }; + } } scopeUsesTypeParameterSyntax = !!curNode.typeParameters; @@ -4662,7 +4693,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions if (match?.scopeId) { // Use the scoped version of the TypeVar rather than the (unscoped) original type. type = TypeVarType.cloneForScopeId(type, match.scopeId, match.scopeName!, match.scopeType!); - return { type, foundInterveningClass: nestedClassCount > 1 && !scopeUsesTypeParameterSyntax }; + return { + type, + isRescoped: false, + foundInterveningClass: nestedClassCount > 1 && !scopeUsesTypeParameterSyntax, + }; } } @@ -4711,6 +4746,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions leftType.details.recursiveTypeAliasName, TypeVarScopeType.TypeAlias ), + isRescoped: false, foundInterveningClass: false, }; } @@ -4720,7 +4756,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } // Return the original type. - return { type, foundInterveningClass: false }; + return { type, isRescoped: false, foundInterveningClass: false }; } function getTypeOfMemberAccess(node: MemberAccessNode, flags: EvaluatorFlags): TypeResult { @@ -16612,20 +16648,13 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // If there was a defined return type, analyze that first so when we // walk the contents of the function, return statements can be // validated against this type. - if (node.returnTypeAnnotation) { + const returnTypeAnnotationNode = + node.returnTypeAnnotation ?? node.functionAnnotationComment?.returnTypeAnnotation; + if (returnTypeAnnotationNode) { // Temporarily set the return type to unknown in case of recursion. functionType.details.declaredReturnType = UnknownType.create(); - const returnType = getTypeOfAnnotation(node.returnTypeAnnotation, { - associateTypeVarsWithScope: true, - disallowRecursiveTypeAlias: true, - }); - functionType.details.declaredReturnType = returnType; - } else if (node.functionAnnotationComment) { - // Temporarily set the return type to unknown in case of recursion. - functionType.details.declaredReturnType = UnknownType.create(); - - const returnType = getTypeOfAnnotation(node.functionAnnotationComment.returnTypeAnnotation, { + const returnType = getTypeOfAnnotation(returnTypeAnnotationNode, { associateTypeVarsWithScope: true, disallowRecursiveTypeAlias: true, }); @@ -16645,12 +16674,12 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } } - // If the function doesn't use PEP 695 syntax, accumulate - // any type parameters used in the return type. - if (functionType.details.declaredReturnType) { - addTypeVarsToListIfUnique( - typeParametersSeen, - getTypeVarArgumentsRecursive(functionType.details.declaredReturnType) + // Accumulate any type parameters used in the return type. + if (functionType.details.declaredReturnType && returnTypeAnnotationNode) { + rescopeTypeVarsForCallableReturnType( + functionType.details.declaredReturnType, + functionType, + typeParametersSeen ); } @@ -16722,6 +16751,48 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return { functionType, decoratedType }; } + // If the declared return type of a function contains type variables that + // are found nowhere else in the signature and are contained within a + // Callable, these type variables are "rescoped" from the function to + // the Callable. + function rescopeTypeVarsForCallableReturnType( + returnType: Type, + functionType: FunctionType, + typeParametersSeen: TypeVarType[] + ) { + const typeVarsInReturnType = getTypeVarArgumentsRecursive(returnType); + const rescopedTypeVars: TypeVarType[] = []; + + typeVarsInReturnType.forEach((typeVar) => { + if (TypeBase.isInstantiable(typeVar)) { + typeVar = TypeVarType.cloneAsInstance(typeVar); + } + + // If this type variable isn't scoped to this function, it is probably + // associated with an outer scope. + if (typeVar.scopeId !== functionType.details.typeVarScopeId) { + return; + } + + // If this type variable was already seen in one or more input parameters, + // don't attempt to rescope it. + if (typeParametersSeen.some((tp) => isTypeSame(convertToInstance(tp), typeVar))) { + return; + } + + // Is this type variable seen outside of a single callable? + if (isTypeVarLimitedToCallable(returnType, typeVar)) { + rescopedTypeVars.push(typeVar); + } + }); + + addTypeVarsToListIfUnique(typeParametersSeen, typeVarsInReturnType); + + // Note that the type parameters have been rescoped so they are not + // considered valid for the body of this function. + functionType.details.rescopedTypeParameters = rescopedTypeVars; + } + function adjustParameterAnnotatedType(param: ParameterNode, type: Type): Type { // PEP 484 indicates that if a parameter has a default value of 'None' // the type checker should assume that the type is optional (i.e. a union diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index cdeecd53a..c40722790 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -1529,6 +1529,96 @@ export function getTypeVarArgumentsRecursive(type: Type, recursionCount = 0): Ty return []; } +// Determines if the type variable appears within the type and only within +// a particular Callable within that type. +export function isTypeVarLimitedToCallable(type: Type, typeVar: TypeVarType): boolean { + const info = getTypeVarWithinTypeInfoRecursive(type, typeVar); + return info.isTypeVarUsed && info.isUsedInCallable; +} + +function getTypeVarWithinTypeInfoRecursive( + type: Type, + typeVar: TypeVarType, + recursionCount = 0 +): { + isTypeVarUsed: boolean; + isUsedInCallable: boolean; +} { + if (recursionCount > maxTypeRecursionCount) { + return { isTypeVarUsed: false, isUsedInCallable: false }; + } + recursionCount++; + + let typeVarUsedCount = 0; + let usedInCallableCount = 0; + + if (isTypeVar(type)) { + // Ignore P.args or P.kwargs types. + if (!isParamSpec(type) || !type.paramSpecAccess) { + if (isTypeSame(typeVar, convertToInstance(type))) { + typeVarUsedCount++; + } + } + } else if (isClass(type)) { + if (type.typeArguments) { + type.typeArguments.forEach((typeArg) => { + const subResult = getTypeVarWithinTypeInfoRecursive(typeArg, typeVar, recursionCount); + if (subResult.isTypeVarUsed) { + typeVarUsedCount++; + } + if (subResult.isUsedInCallable) { + usedInCallableCount++; + } + }); + } + } else if (isUnion(type)) { + doForEachSubtype(type, (subtype) => { + const subResult = getTypeVarWithinTypeInfoRecursive(subtype, typeVar, recursionCount); + if (subResult.isTypeVarUsed) { + typeVarUsedCount++; + } + if (subResult.isUsedInCallable) { + usedInCallableCount++; + } + }); + } else if (isFunction(type)) { + for (let i = 0; i < type.details.parameters.length; i++) { + if ( + getTypeVarWithinTypeInfoRecursive( + FunctionType.getEffectiveParameterType(type, i), + typeVar, + recursionCount + ).isTypeVarUsed + ) { + typeVarUsedCount++; + } + } + + if (type.details.paramSpec) { + if (isTypeSame(typeVar, convertToInstance(type.details.paramSpec))) { + typeVarUsedCount++; + } + } + + const returnType = FunctionType.getSpecializedReturnType(type); + if (returnType) { + if (getTypeVarWithinTypeInfoRecursive(returnType, typeVar, recursionCount).isTypeVarUsed) { + typeVarUsedCount++; + } + } + + if (typeVarUsedCount > 0) { + typeVarUsedCount = 1; + usedInCallableCount = 1; + } + } + + return { + isTypeVarUsed: typeVarUsedCount > 0, + isUsedInCallable: usedInCallableCount === 1 && typeVarUsedCount === 1, + }; +} + // Creates a specialized version of the class, filling in any unspecified // type arguments with Unknown. export function specializeClassType(type: ClassType): ClassType { diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index ccc69fca2..d4beb4644 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -1222,6 +1222,11 @@ interface FunctionDetails { // Parameter specification used only for Callable types created // with a ParamSpec representing the parameters. paramSpec?: TypeVarType | undefined; + + // If the function is generic (has one or more typeParameters) and + // one or more of these appear only within the return type and within + // a callable, they are rescoped to that callable. + rescopedTypeParameters?: TypeVarType[]; } export interface SpecializedFunctionTypes { diff --git a/packages/pyright-internal/src/localization/localize.ts b/packages/pyright-internal/src/localization/localize.ts index 767e9d0df..03c612a67 100644 --- a/packages/pyright-internal/src/localization/localize.ts +++ b/packages/pyright-internal/src/localization/localize.ts @@ -647,6 +647,8 @@ export namespace Localizer { new ParameterizedString<{ type: string }>(getRawString('Diagnostic.paramSpecNotBound')); export const paramSpecNotUsedByOuterScope = () => new ParameterizedString<{ name: string }>(getRawString('Diagnostic.paramSpecNotUsedByOuterScope')); + export const paramSpecScopedToReturnType = () => + new ParameterizedString<{ name: string }>(getRawString('Diagnostic.paramSpecScopedToReturnType')); export const paramSpecUnknownArg = () => getRawString('Diagnostic.paramSpecUnknownArg'); export const paramSpecUnknownMember = () => new ParameterizedString<{ name: string }>(getRawString('Diagnostic.paramSpecUnknownMember')); @@ -905,6 +907,8 @@ export namespace Localizer { new ParameterizedString<{ name: string; param: string }>( getRawString('Diagnostic.typeVarPossiblyUnsolvable') ); + export const typeVarScopedToReturnType = () => + new ParameterizedString<{ name: string }>(getRawString('Diagnostic.typeVarScopedToReturnType')); export const typeVarSingleConstraint = () => getRawString('Diagnostic.typeVarSingleConstraint'); export const typeVarsNotInGenericOrProtocol = () => getRawString('Diagnostic.typeVarsNotInGenericOrProtocol'); export const typeVarTupleContext = () => getRawString('Diagnostic.typeVarTupleContext'); diff --git a/packages/pyright-internal/src/localization/package.nls.en-us.json b/packages/pyright-internal/src/localization/package.nls.en-us.json index 9c6dd2f92..f91ad7dd6 100644 --- a/packages/pyright-internal/src/localization/package.nls.en-us.json +++ b/packages/pyright-internal/src/localization/package.nls.en-us.json @@ -312,6 +312,7 @@ "paramSpecKwargsUsage": "\"kwargs\" member of ParamSpec is valid only when used with **kwargs parameter", "paramSpecNotBound": "Param spec \"{type}\" has no bound value", "paramSpecNotUsedByOuterScope": "ParamSpec \"{name}\" has no meaning in this context", + "paramSpecScopedToReturnType": "ParamSpec \"{name}\" is scoped to a callable within the return type and cannot be referenced in the function body", "paramSpecUnknownArg": "ParamSpec does not support more than one argument", "paramSpecUnknownMember": "\"{name}\" is not a known member of ParamSpec", "paramSpecUnknownParam": "\"{name}\" is unknown parameter to ParamSpec", @@ -455,6 +456,7 @@ "typeVarNotSubscriptable": "TypeVar \"{type}\" is not subscriptable", "typeVarNotUsedByOuterScope": "Type variable \"{name}\" has no meaning in this context", "typeVarPossiblyUnsolvable": "Type variable \"{name}\" may go unsolved if caller supplies no argument for parameter \"{param}\"", + "typeVarScopedToReturnType": "Type variable \"{name}\" is scoped to a callable within the return type and cannot be referenced in the function body", "typeVarSingleConstraint": "TypeVar must have at least two constrained types", "typeVarsNotInGenericOrProtocol": "Generic[] or Protocol[] must include all type variables", "typeVarTupleContext": "TypeVarTuple not allowed in this context", diff --git a/packages/pyright-internal/src/tests/samples/typeVar12.py b/packages/pyright-internal/src/tests/samples/typeVar12.py new file mode 100644 index 000000000..69699fe2f --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typeVar12.py @@ -0,0 +1,51 @@ +# This sample tests the case where a function-scoped TypeVar or +# ParamSpec is used only within a function's return type and only within +# a single Callable within that return type. In such cases, the TypeVar or +# ParamSpec is rescoped to the Callable rather than the function. + +from typing import Callable, Generic, Optional, ParamSpec, TypeVar + +S = TypeVar('S') +T = TypeVar('T') +P = ParamSpec('P') + +CallableAlias1 = Callable[[T], T] +CallableAlias2 = Callable[[T], T] | T + +def func1() -> Callable[[T], T] | None: + # This should generate an error. + x: Optional[T] = None + +def func2() -> Callable[[T], T] | list[T] | None: + x: Optional[T] = None + +def func3() -> CallableAlias1[T] | None: + # This should generate an error. + x: Optional[T] = None + +def func4() -> CallableAlias2[T] | None: + x: Optional[T] = None + +def func5() -> Callable[[list[T]], set[T]] | None: + # This should generate an error. + x: Optional[T] = None + +def func6() -> Callable[[list[T]], set[T]] | Callable[[set[T]], set[T]] | None: + x: Optional[T] = None + +def func7() -> Callable[P, None] | None: + # This should generate two errors, once for each P reference. + def inner(*args: P.args, **kwargs: P.kwargs) -> None: + pass + return + + +class A(Generic[T]): + def method1(self) -> Callable[[T], T] | None: + x: Optional[T] = None + +class B(Generic[S]): + def method1(self) -> Callable[[T], T] | None: + # This should generate an error. + x: Optional[T] = None + diff --git a/packages/pyright-internal/src/tests/samples/typeVar3.py b/packages/pyright-internal/src/tests/samples/typeVar3.py index db516b6c5..9f69177c1 100644 --- a/packages/pyright-internal/src/tests/samples/typeVar3.py +++ b/packages/pyright-internal/src/tests/samples/typeVar3.py @@ -67,7 +67,7 @@ T = TypeVar("T") def foo() -> Callable[[T], T]: def inner(v: T) -> T: - reveal_type(v, expected_text="T@foo") + reveal_type(v, expected_text="T@inner") return v return inner diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index 9bd507d6a..a227d5a3d 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -1144,6 +1144,12 @@ test('TypeVar11', () => { TestUtils.validateResults(analysisResults, 0); }); +test('TypeVar12', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeVar12.py']); + + TestUtils.validateResults(analysisResults, 6); +}); + test('Annotated1', () => { const configOptions = new ConfigOptions('.');