From 0edcd2eecfbf2b9703a99f9b12a00a79c30d511f Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Tue, 11 Jun 2024 09:08:46 -0700 Subject: [PATCH] Fixed bug that results in parameter types being converted to `Any` when converting a `NewType` or dataclass constructor to a callable. This addresses #8116. (#8117) --- .../src/analyzer/constraintSolver.ts | 5 +---- .../pyright-internal/src/analyzer/constructors.ts | 2 -- packages/pyright-internal/src/analyzer/protocols.ts | 12 +++++++++++- .../src/tests/samples/constructorCallable1.py | 12 ++++++++++++ 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/constraintSolver.ts b/packages/pyright-internal/src/analyzer/constraintSolver.ts index abf42c83d..e78e5219a 100644 --- a/packages/pyright-internal/src/analyzer/constraintSolver.ts +++ b/packages/pyright-internal/src/analyzer/constraintSolver.ts @@ -326,10 +326,7 @@ export function assignTypeToTypeVar( ) ) { // The srcType is narrower than the current wideTypeBound, so replace it. - // If it's Any, don't replace it because Any is the narrowest type already. - if (!isAnyOrUnknown(curWideTypeBound)) { - newWideTypeBound = adjSrcType; - } + newWideTypeBound = adjSrcType; } else if ( !evaluator.assignType( adjSrcType, diff --git a/packages/pyright-internal/src/analyzer/constructors.ts b/packages/pyright-internal/src/analyzer/constructors.ts index 6a724f374..672fba6a2 100644 --- a/packages/pyright-internal/src/analyzer/constructors.ts +++ b/packages/pyright-internal/src/analyzer/constructors.ts @@ -1087,8 +1087,6 @@ function createFunctionFromInitMethod( const convertedInit = FunctionType.clone(boundInit); convertedInit.details.declaredReturnType = boundInit.strippedFirstParamType ?? selfType ?? objectType; - convertedInit.details.name = ''; - convertedInit.details.fullName = ''; if (convertedInit.specializedTypes) { convertedInit.specializedTypes.returnType = selfType ?? objectType; diff --git a/packages/pyright-internal/src/analyzer/protocols.ts b/packages/pyright-internal/src/analyzer/protocols.ts index e2001435b..a52d49090 100644 --- a/packages/pyright-internal/src/analyzer/protocols.ts +++ b/packages/pyright-internal/src/analyzer/protocols.ts @@ -32,6 +32,7 @@ import { TypeBase, TypeVarType, UnknownType, + Variance, } from './types'; import { applySolvedTypeVars, @@ -778,7 +779,7 @@ function createProtocolTypeVarContext( ); } else if (destType.typeArguments && index < destType.typeArguments.length) { let typeArg = destType.typeArguments[index]; - let flags = AssignTypeFlags.PopulatingExpectedType; + let flags: AssignTypeFlags; let hasUnsolvedTypeVars = requiresSpecialization(typeArg); // If the type argument has unsolved TypeVars, see if they have @@ -787,6 +788,15 @@ function createProtocolTypeVarContext( typeArg = applySolvedTypeVars(typeArg, destTypeVarContext, { useNarrowBoundOnly: true }); flags = AssignTypeFlags.Default; hasUnsolvedTypeVars = requiresSpecialization(typeArg); + } else { + flags = AssignTypeFlags.PopulatingExpectedType; + + const variance = TypeVarType.getVariance(typeParam); + if (variance === Variance.Invariant) { + flags |= AssignTypeFlags.EnforceInvariance; + } else if (variance === Variance.Contravariant) { + flags |= AssignTypeFlags.ReverseTypeVarMatching; + } } if (!hasUnsolvedTypeVars) { diff --git a/packages/pyright-internal/src/tests/samples/constructorCallable1.py b/packages/pyright-internal/src/tests/samples/constructorCallable1.py index 078086726..d7917d2e9 100644 --- a/packages/pyright-internal/src/tests/samples/constructorCallable1.py +++ b/packages/pyright-internal/src/tests/samples/constructorCallable1.py @@ -135,3 +135,15 @@ def func3(t: type[object]): reveal_type( cast_to_callable(t), expected_text="(*args: Any, **kwargs: Any) -> object" ) + + +@dataclass +class G: + value: int + + +def func4(c: Callable[[T1], T2]) -> Callable[[T1], T2]: + return c + + +reveal_type(func4(G), expected_text="(int) -> G")