diff --git a/packages/pyright-internal/src/analyzer/constraintSolver.ts b/packages/pyright-internal/src/analyzer/constraintSolver.ts index 5a9e0d1cf..7d3355fc8 100644 --- a/packages/pyright-internal/src/analyzer/constraintSolver.ts +++ b/packages/pyright-internal/src/analyzer/constraintSolver.ts @@ -95,7 +95,6 @@ export function assignTypeToTypeVar( console.log(`${indent}destType: ${evaluator.printType(destType)}`); console.log(`${indent}srcType: ${evaluator.printType(srcType)}`); console.log(`${indent}flags: ${flags}`); - console.log(`${indent}scopes: ${(typeVarContext.getSolveForScopes() || []).join(', ')}`); logTypeVarContext(evaluator, typeVarContext, indent); } diff --git a/packages/pyright-internal/src/analyzer/constructors.ts b/packages/pyright-internal/src/analyzer/constructors.ts index c463bcc6f..f5f3739e0 100644 --- a/packages/pyright-internal/src/analyzer/constructors.ts +++ b/packages/pyright-internal/src/analyzer/constructors.ts @@ -122,7 +122,9 @@ export function validateConstructorArgs( const aliasInfo = type.props?.typeAliasInfo; if (aliasInfo?.typeParams && !aliasInfo.typeArgs) { const typeAliasTypeVarContext = new TypeVarContext(aliasInfo.typeVarScopeId); - type = applySolvedTypeVars(type, typeAliasTypeVarContext, { useDefaultForUnsolved: true }) as ClassType; + type = applySolvedTypeVars(type, typeAliasTypeVarContext, { + replaceUnsolved: { scopeIds: [aliasInfo.typeVarScopeId], tupleClassType: evaluator.getTupleClassType() }, + }) as ClassType; } const metaclassResult = validateMetaclassCall( @@ -277,7 +279,12 @@ function validateNewAndInitMethods( newMethodReturnType = applySolvedTypeVars( ClassType.cloneAsInstance(type), new TypeVarContext(getTypeVarScopeId(type)), - { useDefaultForUnsolved: true, tupleClassType: evaluator.getTupleClassType() } + { + replaceUnsolved: { + scopeIds: getTypeVarScopeIds(type) ?? [], + tupleClassType: evaluator.getTupleClassType(), + }, + } ) as ClassType; } @@ -597,7 +604,11 @@ function applyExpectedSubtypeForConstructor( typeVarContext: TypeVarContext ): Type | undefined { const specializedType = applySolvedTypeVars(ClassType.cloneAsInstance(type), typeVarContext, { - applyUnificationVars: true, + replaceUnsolved: { + scopeIds: [], + tupleClassType: evaluator.getTupleClassType(), + applyUnificationVars: true, + }, }); if (!evaluator.assignType(expectedSubtype, specializedType)) { @@ -625,7 +636,13 @@ function applyExpectedTypeForConstructor( // If this isn't a generic type or it's a type that has already been // explicitly specialized, the expected type isn't applicable. if (type.shared.typeParams.length === 0 || type.priv.typeArgs) { - return applySolvedTypeVars(ClassType.cloneAsInstance(type), typeVarContext, { applyUnificationVars: true }); + return applySolvedTypeVars(ClassType.cloneAsInstance(type), typeVarContext, { + replaceUnsolved: { + scopeIds: [], + tupleClassType: evaluator.getTupleClassType(), + applyUnificationVars: true, + }, + }); } if (inferenceContext) { @@ -646,8 +663,12 @@ function applyExpectedTypeForConstructor( } const specializedType = applySolvedTypeVars(type, typeVarContext, { - useDefaultForUnsolved: defaultIfNotFound, - tupleClassType: evaluator.getTupleClassType(), + replaceUnsolved: defaultIfNotFound + ? { + scopeIds: getTypeVarScopeIds(type) ?? [], + tupleClassType: evaluator.getTupleClassType(), + } + : undefined, }) as ClassType; return ClassType.cloneAsInstance(specializedType); } @@ -939,8 +960,10 @@ function createFunctionFromInitMethod( }); returnType = applySolvedTypeVars(objectType, typeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: evaluator.getTupleClassType(), + replaceUnsolved: { + scopeIds: getTypeVarScopeIds(objectType) ?? [], + tupleClassType: evaluator.getTupleClassType(), + }, }) as ClassType; } } diff --git a/packages/pyright-internal/src/analyzer/dataClasses.ts b/packages/pyright-internal/src/analyzer/dataClasses.ts index bff6b3d9f..690b5e803 100644 --- a/packages/pyright-internal/src/analyzer/dataClasses.ts +++ b/packages/pyright-internal/src/analyzer/dataClasses.ts @@ -824,8 +824,10 @@ function getConverterInputType( if (evaluator.assignType(targetFunction, signature, diagAddendum, inputTypeVarContext)) { const overloadSolution = applySolvedTypeVars(typeVar, inputTypeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: evaluator.getTupleClassType(), + replaceUnsolved: { + scopeIds: getTypeVarScopeIds(typeVar) ?? [], + tupleClassType: evaluator.getTupleClassType(), + }, }); acceptedTypes.push(overloadSolution); } diff --git a/packages/pyright-internal/src/analyzer/patternMatching.ts b/packages/pyright-internal/src/analyzer/patternMatching.ts index 5f1742eba..1f52c898f 100644 --- a/packages/pyright-internal/src/analyzer/patternMatching.ts +++ b/packages/pyright-internal/src/analyzer/patternMatching.ts @@ -46,6 +46,7 @@ import { doForEachSubtype, getTypeCondition, getTypeVarScopeId, + getTypeVarScopeIds, getUnknownTypeForCallable, isLiteralType, isLiteralTypeOrUnion, @@ -966,8 +967,10 @@ function narrowTypeBasedOnClassPattern( ) ) { resultType = applySolvedTypeVars(matchTypeInstance, typeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: evaluator.getTupleClassType(), + replaceUnsolved: { + scopeIds: getTypeVarScopeIds(unexpandedSubtype) ?? [], + tupleClassType: evaluator.getTupleClassType(), + }, }) as ClassType; } } diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 5eae0981b..5a0b22f61 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -4963,8 +4963,10 @@ export function createTypeEvaluator( let defaultType: Type; if (param.shared.isDefaultExplicit || isParamSpec(param)) { defaultType = applySolvedTypeVars(param, typeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: getTupleClassType(), + replaceUnsolved: { + scopeIds: [aliasInfo.typeVarScopeId], + tupleClassType: getTupleClassType(), + }, }); } else if (isTypeVarTuple(param) && tupleClass && isInstantiableClass(tupleClass)) { defaultType = makeTupleObject( @@ -4991,8 +4993,10 @@ export function createTypeEvaluator( type = TypeBase.cloneForTypeAlias( applySolvedTypeVars(type, typeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: getTupleClassType(), + replaceUnsolved: { + scopeIds: [aliasInfo.typeVarScopeId], + tupleClassType: getTupleClassType(), + }, }), aliasInfo.name, aliasInfo.fullName, @@ -6932,8 +6936,10 @@ export function createTypeEvaluator( typeArgType = convertToInstance(typeArgs[index].type); } else if (param.shared.isDefaultExplicit) { typeArgType = applySolvedTypeVars(param, typeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: getTupleClassType(), + replaceUnsolved: { + scopeIds: [aliasInfo.typeVarScopeId], + tupleClassType: getTupleClassType(), + }, }); } else { typeArgType = UnknownType.create(); @@ -11333,8 +11339,11 @@ export function createTypeEvaluator( }); expectedType = applySolvedTypeVars(genericReturnType, tempTypeVarContext, { - useUnknownForUnsolved: true, - tupleClassType: getTupleClassType(), + replaceUnsolved: { + scopeIds: getTypeVarScopeIds(returnType) ?? [], + useUnknown: true, + tupleClassType: getTupleClassType(), + }, }); assignFlags |= AssignTypeFlags.SkipPopulateUnknownExpectedType; @@ -11581,11 +11590,13 @@ export function createTypeEvaluator( } let specializedReturnType = applySolvedTypeVars(returnType, typeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: getTupleClassType(), - unsolvedExemptTypeVars: getUnknownExemptTypeVarsForReturnType(type, returnType), - eliminateUnsolvedInUnions, - applyUnificationVars: true, + replaceUnsolved: { + scopeIds: getTypeVarScopeIds(type) ?? [], + unsolvedExemptTypeVars: getUnknownExemptTypeVarsForReturnType(type, returnType), + tupleClassType: getTupleClassType(), + eliminateUnsolvedInUnions, + applyUnificationVars: true, + }, }); specializedReturnType = addConditionToType(specializedReturnType, typeCondition); @@ -12372,8 +12383,10 @@ export function createTypeEvaluator( const typeVarContext = new TypeVarContext(typeVar.priv.scopeId); const concreteDefaultType = makeTopLevelTypeVarsConcrete( applySolvedTypeVars(typeVar.shared.defaultType, typeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: getTupleClassType(), + replaceUnsolved: { + scopeIds: getTypeVarScopeIds(typeVar) ?? [], + tupleClassType: getTupleClassType(), + }, }) ); @@ -14046,7 +14059,13 @@ export function createTypeEvaluator( } return mapSubtypes( - applySolvedTypeVars(inferenceContext.expectedType, typeVarContext, { applyUnificationVars: true }), + applySolvedTypeVars(inferenceContext.expectedType, typeVarContext, { + replaceUnsolved: { + scopeIds: [], + tupleClassType: getTupleClassType(), + applyUnificationVars: true, + }, + }), (subtype) => { if (entryTypes.length !== 1) { return subtype; @@ -14346,7 +14365,11 @@ export function createTypeEvaluator( ) ) { functionType = applySolvedTypeVars(functionType, typeVarContext, { - applyUnificationVars: true, + replaceUnsolved: { + scopeIds: [], + tupleClassType: getTupleClassType(), + applyUnificationVars: true, + }, }) as FunctionType; } } @@ -18301,17 +18324,22 @@ export function createTypeEvaluator( // If the parameter type is generic, specialize it in the context // of the child class. if (requiresSpecialization(inferredParamType) && isClass(baseClassMemberInfo.classType)) { + const scopeIds: TypeVarScopeId[] = + getTypeVarScopeIds(baseClassMemberInfo.classType) ?? []; const typeVarContext = buildTypeVarContextFromSpecializedClass( baseClassMemberInfo.classType ); // Add the scope of the method to handle any function-scoped TypeVars. typeVarContext.addSolveForScope(ParseTreeUtils.getScopeIdForNode(baseClassMethodNode)); + scopeIds.push(ParseTreeUtils.getScopeIdForNode(baseClassMethodNode)); // Replace any unsolved TypeVars with Unknown (including all function-scoped TypeVars). inferredParamType = applySolvedTypeVars(inferredParamType, typeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: getTupleClassType(), + replaceUnsolved: { + scopeIds, + tupleClassType: getTupleClassType(), + }, }); } @@ -20366,8 +20394,10 @@ export function createTypeEvaluator( } const solvedDefaultType = applySolvedTypeVars(typeParam, typeVarContext, { - useDefaultForUnsolved: true, - tupleClassType: getTupleClassType(), + replaceUnsolved: { + scopeIds: getTypeVarScopeIds(classType) ?? [], + tupleClassType: getTupleClassType(), + }, }); typeArgTypes.push(solvedDefaultType); setTypeVarType(typeVarContext, typeParam, solvedDefaultType); diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index d4ddea026..9052580d5 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -75,6 +75,7 @@ import { getSpecializedTupleType, getTypeCondition, getTypeVarScopeId, + getTypeVarScopeIds, getUnknownTypeForCallable, isInstantiableMetaclass, isLiteralType, @@ -1485,8 +1486,11 @@ function narrowTypeForIsInstanceInternal( unspecializedFilterType, typeVarContext, { - useUnknownForUnsolved: true, - tupleClassType: evaluator.getTupleClassType(), + replaceUnsolved: { + scopeIds: getTypeVarScopeIds(filterType) ?? [], + useUnknown: true, + tupleClassType: evaluator.getTupleClassType(), + }, } ) as ClassType; } diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index b890ba9b9..dac3ad8c2 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -243,13 +243,15 @@ export const enum AssignTypeFlags { export interface ApplyTypeVarOptions { typeClassType?: ClassType; - tupleClassType?: ClassType; - useDefaultForUnsolved?: boolean; - useUnknownForUnsolved?: boolean; - unsolvedExemptTypeVars?: TypeVarType[]; + replaceUnsolved?: { + scopeIds: TypeVarScopeId[]; + tupleClassType: ClassType | undefined; + unsolvedExemptTypeVars?: TypeVarType[]; + useUnknown?: boolean; + eliminateUnsolvedInUnions?: boolean; + applyUnificationVars?: boolean; + }; useLowerBoundOnly?: boolean; - eliminateUnsolvedInUnions?: boolean; - applyUnificationVars?: boolean; } export interface InferenceContext { @@ -1540,17 +1542,11 @@ export function applySolvedTypeVars( options: ApplyTypeVarOptions = {} ): Type { // Use a shortcut if the typeVarContext is empty and no transform is necessary. - if ( - typeVarContext.isEmpty() && - !options.useDefaultForUnsolved && - !options.useUnknownForUnsolved && - !options.eliminateUnsolvedInUnions && - !options.applyUnificationVars - ) { + if (typeVarContext.isEmpty() && !options.replaceUnsolved) { return type; } - if (options.applyUnificationVars) { + if (options.replaceUnsolved?.applyUnificationVars) { applyUnificationVars(typeVarContext); } @@ -4268,9 +4264,12 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer { } if (subtype.shared.typeParams && !subtype.priv.typeArgs) { - if (this._options.useDefaultForUnsolved || this._options.useUnknownForUnsolved) { - return this._options.useUnknownForUnsolved - ? specializeWithUnknownTypeArgs(subtype, this._options.tupleClassType) + if (this._options.replaceUnsolved) { + return this._options.replaceUnsolved.useUnknown + ? specializeWithUnknownTypeArgs( + subtype, + this._options.replaceUnsolved.tupleClassType + ) : specializeWithDefaultTypeArgs(subtype); } } @@ -4299,7 +4298,7 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer { return replacement; } - if (!this._options.useDefaultForUnsolved && !this._options.useUnknownForUnsolved) { + if (!this._options.replaceUnsolved) { return replacement; } } @@ -4308,24 +4307,12 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer { return undefined; } - // If this typeVar is in scope for what we're solving but the type - // var map doesn't contain any entry for it, replace with the - // default or Unknown. - let useDefaultOrUnknown = false; - if (this._options.useDefaultForUnsolved || this._options.useUnknownForUnsolved) { - useDefaultOrUnknown = true; - } else if (this._options.applyUnificationVars && typeVar.priv.isUnificationVar) { - useDefaultOrUnknown = true; + // Use the default value if there is one. + if (typeVar.shared.isDefaultExplicit && !this._options.replaceUnsolved?.useUnknown) { + return this._solveDefaultType(typeVar, recursionCount); } - if (useDefaultOrUnknown) { - // Use the default value if there is one. - if (typeVar.shared.isDefaultExplicit && !this._options.useUnknownForUnsolved) { - return this._solveDefaultType(typeVar, recursionCount); - } - - return getUnknownForTypeVar(typeVar, this._options.tupleClassType); - } + return getUnknownForTypeVar(typeVar, this._options.replaceUnsolved?.tupleClassType); } // If we're solving a default type, handle type variables with no scope ID. @@ -4354,7 +4341,7 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer { // in cases where TypeVars can go unsolved due to unions in parameter // annotations, like this: // def test(x: Union[str, T]) -> Union[str, T] - if (this._options.eliminateUnsolvedInUnions) { + if (this._options.replaceUnsolved?.eliminateUnsolvedInUnions) { if ( isTypeVar(preTransform) && this._shouldReplaceTypeVar(preTransform) && @@ -4374,7 +4361,7 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer { // If useDefaultForUnsolved or useUnknownForUnsolved is true, the postTransform type will // be Unknown, which we want to eliminate. - if (this._options.useDefaultForUnsolved || this._options.useUnknownForUnsolved) { + if (this._options.replaceUnsolved) { if (isUnknown(postTransform)) { return undefined; } @@ -4438,24 +4425,13 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer { return undefined; } - let useDefaultOrUnknown = false; - if (this._options.useDefaultForUnsolved || this._options.useUnknownForUnsolved) { - useDefaultOrUnknown = true; - } else if (this._options.applyUnificationVars && paramSpec.priv.isUnificationVar) { - useDefaultOrUnknown = true; + // Use the default value if there is one. + if (paramSpec.shared.isDefaultExplicit && !this._options.replaceUnsolved?.useUnknown) { + return convertTypeToParamSpecValue(this._solveDefaultType(paramSpec, recursionCount)); } - if (useDefaultOrUnknown) { - // Use the default value if there is one. - if (paramSpec.shared.isDefaultExplicit && !this._options.useUnknownForUnsolved) { - return convertTypeToParamSpecValue(this._solveDefaultType(paramSpec, recursionCount)); - } - - // Convert to the ParamSpec equivalent of "Unknown". - return ParamSpecType.getUnknown(); - } - - return undefined; + // Convert to the ParamSpec equivalent of "Unknown". + return ParamSpecType.getUnknown(); } override transformConditionalType(type: Type, recursionCount: number): Type { @@ -4531,18 +4507,35 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer { } private _shouldReplaceUnsolvedTypeVar(typeVar: TypeVarType): boolean { + // Never replace nested TypeVars with unknown. if (this.pendingTypeVarTransformations.size > 0) { return false; } - const exemptTypeVars = this._options.unsolvedExemptTypeVars; + if (!typeVar.priv.scopeId) { + return false; + } + + if (!this._options.replaceUnsolved) { + return false; + } + + if (TypeVarType.isUnification(typeVar) && this._options.replaceUnsolved.applyUnificationVars) { + return true; + } + + if (!this._options.replaceUnsolved.scopeIds.includes(typeVar.priv.scopeId)) { + return false; + } + + const exemptTypeVars = this._options.replaceUnsolved?.unsolvedExemptTypeVars; if (exemptTypeVars) { if (exemptTypeVars.some((t) => isTypeSame(t, typeVar, { ignoreTypeFlags: true }))) { return false; } } - return this._typeVarContext.hasSolveForScope(typeVar.priv.scopeId); + return true; } private _solveDefaultType(typeVar: TypeVarType, recursionCount: number) {