Fixed bug that results in incorrect type evaluation of a function that accepts a Callable[P, T] and is passed a class object whose constructor needs to be converted to a callable. This addresses #8170. (#8258)

This commit is contained in:
Eric Traut 2024-06-28 10:19:57 -07:00 committed by GitHub
parent 5d5fe5d15c
commit 3c70b4e0d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 48 additions and 2 deletions

View File

@ -25,14 +25,17 @@ import {
InferenceContext,
MemberAccessFlags,
UniqueSignatureTracker,
addTypeVarsToListIfUnique,
applySolvedTypeVars,
buildTypeVarContextFromSpecializedClass,
convertToInstance,
convertTypeToParamSpecValue,
doForEachSignature,
doForEachSubtype,
ensureFunctionSignaturesAreUnique,
getTypeVarArgumentsRecursive,
getTypeVarScopeId,
getTypeVarScopeIds,
isTupleClass,
lookUpClassMember,
mapSubtypes,
@ -60,6 +63,7 @@ import {
isInstantiableClass,
isNever,
isOverloadedFunction,
isParamSpec,
isTypeVar,
isUnion,
isUnknown,
@ -1086,10 +1090,40 @@ function createFunctionFromInitMethod(
}
const convertedInit = FunctionType.clone(boundInit);
convertedInit.details.declaredReturnType = boundInit.strippedFirstParamType ?? selfType ?? objectType;
let returnType = selfType;
if (!returnType) {
returnType = objectType;
// If this is a generic type, self-specialize the class (i.e. fill in
// its own type parameters as type arguments).
if (objectType.details.typeParameters.length > 0 && !objectType.typeArguments) {
const typeVarContext = new TypeVarContext(getTypeVarScopeIds(objectType));
// If a TypeVar is not used in any of the parameter types, it should take
// on its default value (typically Unknown) in the resulting specialized type.
const typeVarsInParams: TypeVarType[] = [];
convertedInit.details.parameters.forEach((param, index) => {
const paramType = FunctionType.getEffectiveParameterType(convertedInit, index);
addTypeVarsToListIfUnique(typeVarsInParams, getTypeVarArgumentsRecursive(paramType));
});
typeVarsInParams.forEach((typeVar) => {
if (isParamSpec(typeVar)) {
typeVarContext.setTypeVarType(typeVar, convertTypeToParamSpecValue(typeVar));
} else {
typeVarContext.setTypeVarType(typeVar, typeVar);
}
});
returnType = applySolvedTypeVars(objectType, typeVarContext, { unknownIfNotFound: true }) as ClassType;
}
}
convertedInit.details.declaredReturnType = boundInit.strippedFirstParamType ?? returnType;
if (convertedInit.specializedTypes) {
convertedInit.specializedTypes.returnType = selfType ?? objectType;
convertedInit.specializedTypes.returnType = returnType;
}
if (!convertedInit.details.docString && classType.details.docString) {

View File

@ -3298,6 +3298,8 @@ export function convertTypeToParamSpecValue(type: Type): FunctionType {
newFunction.details.typeVarScopeId = newFunction.details.higherOrderTypeVarScopeIds.pop();
}
newFunction.details.constructorTypeVarScopeId = type.details.constructorTypeVarScopeId;
return newFunction;
}
@ -3336,6 +3338,7 @@ export function convertParamSpecValueToType(type: FunctionType): Type {
FunctionType.addHigherOrderTypeVarScopeIds(functionType, withoutParamSpec.details.typeVarScopeId);
FunctionType.addHigherOrderTypeVarScopeIds(functionType, withoutParamSpec.details.higherOrderTypeVarScopeIds);
functionType.details.constructorTypeVarScopeId = withoutParamSpec.details.constructorTypeVarScopeId;
withoutParamSpec.details.parameters.forEach((entry, index) => {
FunctionType.addParameter(functionType, {

View File

@ -1764,6 +1764,7 @@ export namespace FunctionType {
FunctionType.addHigherOrderTypeVarScopeIds(newFunction, paramSpecValue.details.typeVarScopeId);
FunctionType.addHigherOrderTypeVarScopeIds(newFunction, paramSpecValue.details.higherOrderTypeVarScopeIds);
newFunction.details.constructorTypeVarScopeId = paramSpecValue.details.constructorTypeVarScopeId;
if (!newFunction.details.methodClass && paramSpecValue.details.methodClass) {
newFunction.details.methodClass = paramSpecValue.details.methodClass;

View File

@ -29,3 +29,11 @@ def func1(t: type[TA]) -> TA: ...
b = B(func1, A)
reveal_type(b, expected_text="B[(t: type[A]), A]")
class C(Generic[TA]):
def __init__(self, _type: type[TA]) -> None: ...
c = B(C, A)
reveal_type(c, expected_text="B[(_type: type[A]), C[A]]")