diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 0c2c33e7e..01ca87afa 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -1293,10 +1293,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions memberName: string, treatAsClassMember: boolean ): FunctionType | OverloadedFunctionType | undefined { - const aliasClass = classType.details.aliasClass; - if (aliasClass) { - classType = aliasClass; - } + classType = ClassType.getAliasClass(classType); const memberInfo = lookUpClassMember( classType, @@ -3757,7 +3754,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions ): ClassMemberLookup | undefined { // If this is a special type (like "List") that has an alias class (like // "list"), switch to the alias, which defines the members. - classType = classType.details.aliasClass || classType; + classType = ClassType.getAliasClass(classType); let classLookupFlags = ClassMemberLookupFlags.Default; if (flags & MemberAccessFlags.SkipInstanceMembers) { @@ -4128,7 +4125,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions flags | EvaluatorFlags.DoNotSpecialize ); - if (baseTypeResult.isResolutionCyclical || isTypeAliasPlaceholder(baseTypeResult.type)) { + if (baseTypeResult.isResolutionCyclical) { return { node, type: UnknownType.create(), @@ -4233,6 +4230,17 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } } + if (isTypeAliasPlaceholder(baseType)) { + const typeArgTypes = getTypeArgs(node.items, flags).map((t) => convertToInstance(t.type)); + const type = TypeBase.cloneForTypeAlias( + baseType, + baseType.details.recursiveTypeAliasName!, + undefined, + typeArgTypes + ); + return { type, node }; + } + const type = doForSubtypes(baseType, (subtype) => { subtype = makeTypeVarsConcrete(subtype); subtype = getClassFromPotentialTypeObject(subtype); @@ -5375,31 +5383,72 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions ); if (initMethodType && !skipConstructorCheck(initMethodType)) { - const typeVarMap = new TypeVarMap(); - + // If there is an expected type, analyze the constructor call + // for each of the subtypes that comprise the expected type. If + // one or more analyzes with no errors, use those results. if (expectedType) { - populateTypeVarMapBasedOnExpectedType(type, expectedType, typeVarMap); + returnType = doForSubtypes(expectedType, (expectedSubType) => { + const typeVarMap = new TypeVarMap(); + if (populateTypeVarMapBasedOnExpectedType(type, expectedSubType, typeVarMap)) { + const callResult = suppressDiagnostics(() => { + return validateCallArguments( + errorNode, + argList, + initMethodType, + typeVarMap, + skipUnknownArgCheck, + /* inferReturnTypeIfNeeded */ true, + NoneType.createInstance() + ); + }); + + if (!callResult.argumentErrors) { + // Note that we're specializing the type twice here with the same + // typeVarMap. This handles the case where the expectedType contains + // a type variable that is computed and filled in to the typeVarMap. + const specializedType = specializeType( + specializeType(type, typeVarMap, /* makeConcrete */ true), + typeVarMap, + /* makeConcrete */ true + ) as ClassType; + return applyExpectedSubtypeForConstructor(specializedType, expectedSubType); + } + } + + return undefined; + }); + + if (isNever(returnType)) { + returnType = undefined; + } } - const callResult = validateCallArguments( - errorNode, - argList, - initMethodType, - typeVarMap, - skipUnknownArgCheck, - /* inferReturnTypeIfNeeded */ true, - NoneType.createInstance() - ); - if (!callResult.argumentErrors) { - // Note that we're specializing the type twice here with the same - // typeVarMap. This handles the case where the expectedType contains - // a type variable that is computed and filled in to the typeVarMap. - let specializedType = specializeType(type, typeVarMap, /* makeConcrete */ true) as ClassType; - specializedType = specializeType(specializedType, typeVarMap, /* makeConcrete */ true) as ClassType; - returnType = applyExpectedTypeForConstructor(specializedType, expectedType); - } else { - reportedErrors = true; + if (!returnType) { + const typeVarMap = new TypeVarMap(); + const callResult = validateCallArguments( + errorNode, + argList, + initMethodType, + typeVarMap, + skipUnknownArgCheck, + /* inferReturnTypeIfNeeded */ true, + NoneType.createInstance() + ); + if (!callResult.argumentErrors) { + // Note that we're specializing the type twice here with the same + // typeVarMap. This handles the case where the expectedType contains + // a type variable that is computed and filled in to the typeVarMap. + const specializedType = specializeType( + specializeType(type, typeVarMap, /* makeConcrete */ true), + typeVarMap, + /* makeConcrete */ true + ) as ClassType; + returnType = applyExpectedTypeForConstructor(specializedType, expectedType); + } else { + reportedErrors = true; + } } + validatedTypes = true; skipUnknownArgCheck = true; } @@ -5523,149 +5572,111 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return { argumentErrors: reportedErrors, returnType }; } - // Handles the case where a constructor is a generic type and the type - // arguments are not specified but can be provided by the expected type. - function applyExpectedTypeForConstructor(type: ClassType, expectedType: Type | undefined): ObjectType { + function applyExpectedSubtypeForConstructor(type: ClassType, expectedSubtype: Type): Type | undefined { const objType = ObjectType.create(type); - if (!expectedType) { - return objType; + if (canAssignType(expectedSubtype, objType, new DiagnosticAddendum())) { + // If the expected type is "Any", transform it to an Any. + if (expectedSubtype.category === TypeCategory.Any) { + return expectedSubtype; + } + + const typeVarMap = new TypeVarMap(); + if (populateTypeVarMapBasedOnExpectedType(type, expectedSubtype, typeVarMap)) { + return specializeType(objType, typeVarMap) as ClassType; + } } - // Try to apply for each subtype in the expectedType. The foundMatch - // tracks whether we've already seen a match. If none of them match, - // return the original type. - let foundMatch = false; - const specializedType = doForSubtypes(expectedType, (subtype) => { - if (foundMatch || !isObject(subtype)) { - return undefined; + return undefined; + } + + // Handles the case where a constructor is a generic type and the type + // arguments are not specified but can be provided by the expected type. + function applyExpectedTypeForConstructor(type: ClassType, expectedType: Type | undefined): Type { + if (expectedType) { + const specializedType = doForSubtypes(expectedType, (expectedSubtype) => { + return applyExpectedSubtypeForConstructor(type, expectedSubtype); + }); + + if (!isNever(specializedType)) { + return specializedType; } + } - const expectedClass = subtype.classType; - const typeVarMap = new TypeVarMap(); - if (canAssignType(expectedClass, type, new DiagnosticAddendum(), typeVarMap)) { - foundMatch = true; - return specializeType(expectedClass, typeVarMap) as ClassType; - } - - // If it's the same generic class, see if we can assign the type arguments - // without the variance rules that canAssignType uses. - if ( - ClassType.isSameGenericClass(type, expectedClass) && - expectedClass.typeArguments && - type.typeArguments && - !type.isTypeArgumentExplicit && - expectedClass.typeArguments.length === type.typeArguments.length - ) { - const typeVarMap = new TypeVarMap(); - let isAssignable = true; - expectedClass.typeArguments.forEach((expectedTypeArg, index) => { - const typeTypeArg = type.typeArguments![index]; - if (!canAssignType(expectedTypeArg, typeTypeArg, new DiagnosticAddendum(), typeVarMap)) { - isAssignable = false; - } - }); - - if (isAssignable) { - foundMatch = true; - return specializeType(expectedClass, typeVarMap) as ClassType; - } - } - }); - - return isClass(specializedType) ? ObjectType.create(specializedType) : objType; + return ObjectType.create(type); } // In cases where the expected type is a specialized base class of the // source type, we need to determine which type arguments in the derived // class will make it compatible with the specialized base class. This method // performs this reverse mapping of type arguments and populates the type var - // map for the target type. - function populateTypeVarMapBasedOnExpectedType(type: ClassType, expectedType: Type, typeVarMap: TypeVarMap) { - // If the target type isn't generic, there's nothing for us to do. - if (!requiresSpecialization(type)) { - return; - } - + // map for the target type. If the type is not assignable to the expected type, + // it returns false. + function populateTypeVarMapBasedOnExpectedType( + type: ClassType, + expectedType: Type, + typeVarMap: TypeVarMap + ): boolean { // Try to find a subtype within the expected type that the type can be assigned to. // If found, fill in the typeVarMap with the required specialization type arguments. - let foundMatch = false; - doForSubtypes(expectedType, (subtype) => { - if (!foundMatch && isObject(subtype)) { - // If the expected type is generic (not specialized), we can't proceed. - const expectedTypeArgs = subtype.classType.effectiveTypeArguments || subtype.classType.typeArguments; - if (expectedTypeArgs === undefined) { - return undefined; - } + if (!isObject(expectedType)) { + return false; + } - // If the expected type is the same as the target type (commonly the case), - // we can use a faster method. - if (ClassType.isSameGenericClass(subtype.classType, type)) { - canAssignType(type, subtype.classType, new DiagnosticAddendum(), typeVarMap); - foundMatch = true; - return undefined; - } + // If the expected type is generic (but not specialized), we can't proceed. + const expectedTypeArgs = expectedType.classType.effectiveTypeArguments || expectedType.classType.typeArguments; + if (!expectedTypeArgs) { + return canAssignType(type, expectedType.classType, new DiagnosticAddendum(), typeVarMap); + } - // Create a generic (not specialized) version of the expected type. - const synthExpectedTypeArgs = subtype.classType.details.typeParameters.map((_, index) => { - return TypeVarType.createInstance( - `__dest${index}`, - /* isParamSpec */ false, - /* isSynthesized */ true - ); - }); - const genericExpectedType = ClassType.cloneForSpecialization( - subtype.classType, - synthExpectedTypeArgs, - /* isTypeArgumentExplicit */ true - ); - - // For each type param in the target type, create a placeholder type variable. - const classType = type.details.aliasClass || type; - const typeArgs = classType.details.typeParameters.map((_, index) => { - const typeVar = TypeVarType.createInstance( - `__source${index}`, - /* isParamSpec */ false, - /* isSynthesized */ true - ); - typeVar.details.synthesizedIndex = index; - return typeVar; - }); - - const specializedType = ClassType.cloneForSpecialization( - type, - typeArgs, - /* isTypeArgumentExplicit */ true - ); - const syntheticTypeVarMap = new TypeVarMap(); - if ( - canAssignType(genericExpectedType, specializedType, new DiagnosticAddendum(), syntheticTypeVarMap) - ) { - synthExpectedTypeArgs.forEach((typeVar, index) => { - const synthTypeVar = syntheticTypeVarMap.getTypeVar(typeVar); - - // Is this one of the synthesized type vars we allocated above? If so, - // the type arg that corresponds to this type var maps back to the target type. - if ( - synthTypeVar && - isTypeVar(synthTypeVar) && - synthTypeVar.details.isSynthesized && - synthTypeVar.details.synthesizedIndex !== undefined - ) { - const targetTypeVar = - specializedType.details.typeParameters[synthTypeVar.details.synthesizedIndex]; - if (index < expectedTypeArgs.length) { - typeVarMap.setTypeVar(targetTypeVar, expectedTypeArgs[index], /* isNarrowable */ false); - } - } - }); - - foundMatch = true; - } - } - - return undefined; + // Create a generic version of the expected type. + const synthExpectedTypeArgs = ClassType.getTypeParameters(expectedType.classType).map((_, index) => { + return TypeVarType.createInstance(`__dest${index}`, /* isParamSpec */ false, /* isSynthesized */ true); }); + const genericExpectedType = ClassType.cloneForSpecialization( + expectedType.classType, + synthExpectedTypeArgs, + /* isTypeArgumentExplicit */ true + ); + + // For each type param in the target type, create a placeholder type variable. + const typeArgs = ClassType.getTypeParameters(type).map((_, index) => { + const typeVar = TypeVarType.createInstance( + `__source${index}`, + /* isParamSpec */ false, + /* isSynthesized */ true + ); + typeVar.details.synthesizedIndex = index; + return typeVar; + }); + + const specializedType = ClassType.cloneForSpecialization(type, typeArgs, /* isTypeArgumentExplicit */ true); + const syntheticTypeVarMap = new TypeVarMap(); + if (canAssignType(genericExpectedType, specializedType, new DiagnosticAddendum(), syntheticTypeVarMap)) { + synthExpectedTypeArgs.forEach((typeVar, index) => { + const synthTypeVar = syntheticTypeVarMap.getTypeVar(typeVar); + + // Is this one of the synthesized type vars we allocated above? If so, + // the type arg that corresponds to this type var maps back to the target type. + if ( + synthTypeVar && + isTypeVar(synthTypeVar) && + synthTypeVar.details.isSynthesized && + synthTypeVar.details.synthesizedIndex !== undefined + ) { + const targetTypeVar = ClassType.getTypeParameters(specializedType)[ + synthTypeVar.details.synthesizedIndex + ]; + if (index < expectedTypeArgs.length) { + typeVarMap.setTypeVar(targetTypeVar, expectedTypeArgs[index], /* isNarrowable */ false); + } + } + }); + + return true; + } + + return false; } // Validates that the arguments can be assigned to the call's parameter @@ -7819,7 +7830,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions if (expectedType && entryTypes.length > 0) { const narrowedExpectedType = doForSubtypes(expectedType, (subtype) => { if (isObject(subtype)) { - const classAlias = subtype.classType.details.aliasClass || subtype.classType; + const classAlias = ClassType.getAliasClass(subtype.classType); if (ClassType.isBuiltIn(classAlias, 'set') && subtype.classType.typeArguments) { const typeArg = subtype.classType.typeArguments[0]; const typeVarMap = new TypeVarMap(); @@ -7857,48 +7868,174 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } function getTypeFromDictionary(node: DictionaryNode, expectedType: Type | undefined): TypeResult { + // If the expected type is a union, analyze for each of the subtypes + // to find one that matches. + if (expectedType && expectedType.category === TypeCategory.Union) { + let matchingSubtype: Type | undefined; + + for (const subtype of expectedType.subtypes) { + const subtypeResult = useSpeculativeMode(node, () => { + return getTypeFromDictionaryExpected(node, subtype, new DiagnosticAddendum()); + }); + + if (subtypeResult) { + matchingSubtype = subtype; + break; + } + } + + expectedType = matchingSubtype; + } + + const expectedDiagAddendum = new DiagnosticAddendum(); + if (expectedType) { + const result = getTypeFromDictionaryExpected(node, expectedType, expectedDiagAddendum); + if (result) { + return result; + } + } + + return getTypeFromDictionaryInferred(node, /* forceStrict */ !!expectedType)!; + } + + // Attempts to infer the type of a dictionary statement. If an expectedType + // is provided, the resulting type must be compatible with the expected type. + // If this isn't possible, undefined is returned. + function getTypeFromDictionaryExpected( + node: DictionaryNode, + expectedType: Type, + expectedDiagAddendum: DiagnosticAddendum + ): TypeResult | undefined { + const keyTypes: Type[] = []; + const valueTypes: Type[] = []; + + if (!isObject(expectedType)) { + return undefined; + } + + // Handle TypedDict's as a special case. + if (ClassType.isTypedDictClass(expectedType.classType)) { + const expectedTypedDictEntries = getTypedDictMembersForClass(expectedType.classType); + + // Infer the key and value types if possible. + getKeyAndValueTypesFromDictionary( + node, + keyTypes, + valueTypes, + !!expectedType, + /* expectedKeyType */ undefined, + /* expectedValueType */ undefined, + expectedTypedDictEntries + ); + + if ( + ClassType.isTypedDictClass(expectedType.classType) && + canAssignToTypedDict(expectedType.classType, keyTypes, valueTypes, expectedDiagAddendum) + ) { + return { + type: expectedType, + node, + }; + } + + return undefined; + } + + const builtInDict = getBuiltInObject(node, 'Dict'); + if (!isObject(builtInDict)) { + return undefined; + } + + const dictTypeVarMap = new TypeVarMap(); + if (!populateTypeVarMapBasedOnExpectedType(builtInDict.classType, expectedType, dictTypeVarMap)) { + return undefined; + } + + const specializedDict = specializeType(builtInDict.classType, dictTypeVarMap) as ClassType; + if (!specializedDict.typeArguments || specializedDict.typeArguments.length !== 2) { + return undefined; + } + + const expectedKeyType = specializeType(specializedDict.typeArguments[0], /* typeVarMap */ undefined); + const expectedValueType = specializeType(specializedDict.typeArguments[1], /* typeVarMap */ undefined); + + // Infer the key and value types if possible. + getKeyAndValueTypesFromDictionary( + node, + keyTypes, + valueTypes, + !!expectedType, + expectedKeyType, + expectedValueType + ); + + const isExpectedTypeDict = + isObject(expectedType) && ClassType.isBuiltIn(ClassType.getAliasClass(expectedType.classType), 'dict'); + + const specializedKeyType = inferTypeArgFromExpectedType(expectedKeyType, keyTypes, /* isNarrowable */ false); + const specializedValueType = inferTypeArgFromExpectedType( + expectedValueType, + valueTypes, + /* isNarrowable */ !isExpectedTypeDict + ); + if (!specializedKeyType || !specializedValueType) { + return undefined; + } + + const type = getBuiltInObject(node, 'Dict', [specializedKeyType, specializedValueType]); + return { type, node }; + } + + // Attempts to infer the type of a dictionary statement. If an expectedType + // is provided, the resulting type must be compatible with the expected type. + // If this isn't possible, undefined is returned. + function getTypeFromDictionaryInferred(node: DictionaryNode, forceStrict: boolean): TypeResult { let keyType: Type = AnyType.create(); let valueType: Type = AnyType.create(); let keyTypes: Type[] = []; let valueTypes: Type[] = []; - let expectedKeyType: Type | undefined; - let expectedValueType: Type | undefined; - let expectedTypedDictEntries: Map | undefined; - const diagAddendum = new DiagnosticAddendum(); + // Infer the key and value types if possible. + getKeyAndValueTypesFromDictionary(node, keyTypes, valueTypes, !forceStrict); - if (expectedType) { - doForSubtypes(expectedType, (subtype) => { - if (isObject(subtype)) { - const expectedClass = subtype.classType; - if ( - ClassType.isBuiltIn(expectedClass, 'Mapping') || - ClassType.isBuiltIn(expectedClass, 'Dict') || - ClassType.isBuiltIn(expectedClass, 'dict') - ) { - if (expectedClass.typeArguments && expectedClass.typeArguments.length === 2) { - expectedKeyType = specializeType( - expectedClass.typeArguments[0], - /* typeVarMap */ undefined - ); - expectedValueType = specializeType( - expectedClass.typeArguments[1], - /* typeVarMap */ undefined - ); - } - } else if (ClassType.isTypedDictClass(expectedClass)) { - expectedTypedDictEntries = getTypedDictMembersForClass(expectedClass); - } - } + // Strip any literal values. + keyTypes = keyTypes.map((t) => stripLiteralValue(t)); + valueTypes = valueTypes.map((t) => stripLiteralValue(t)); - return undefined; - }); + keyType = keyTypes.length > 0 ? combineTypes(keyTypes) : AnyType.create(); + + // If the value type differs and we're not using "strict inference mode", + // we need to back off because we can't properly represent the mappings + // between different keys and associated value types. If all the values + // are the same type, we'll assume that all values in this dictionary should + // be the same. + if (valueTypes.length > 0) { + if (getFileInfo(node).diagnosticRuleSet.strictDictionaryInference || forceStrict) { + valueType = combineTypes(valueTypes); + } else { + valueType = areTypesSame(valueTypes) ? valueTypes[0] : UnknownType.create(); + } + } else { + valueType = AnyType.create(); } + const type = getBuiltInObject(node, 'Dict', [keyType, valueType]); + return { type, node }; + } + + function getKeyAndValueTypesFromDictionary( + node: DictionaryNode, + keyTypes: Type[], + valueTypes: Type[], + limitEntryCount: boolean, + expectedKeyType?: Type, + expectedValueType?: Type, + expectedTypedDictEntries?: Map + ) { // Infer the key and value types if possible. node.entries.forEach((entryNode, index) => { - if (index >= maxEntriesToUseForInference && !expectedType) { + if (limitEntryCount && index >= maxEntriesToUseForInference) { return; } @@ -7941,7 +8078,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } else { if (isObject(unexpandedType)) { const classType = unexpandedType.classType; - const aliasType = classType.details.aliasClass || classType; + const aliasType = ClassType.getAliasClass(classType); if (ClassType.isBuiltIn(aliasType, 'dict')) { const typeArgs = classType.typeArguments; @@ -7975,107 +8112,66 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions valueTypes.push(UnknownType.create()); } }); - - // If there is an expected type, see if we can match any parts of it. - if (expectedType) { - const narrowedExpectedType = doForSubtypes(expectedType, (subtype) => { - if (!isObject(subtype)) { - return undefined; - } - - if ( - ClassType.isTypedDictClass(subtype.classType) && - canAssignToTypedDict(subtype.classType, keyTypes, valueTypes, diagAddendum) - ) { - return subtype; - } - - const classAlias = subtype.classType.details.aliasClass || subtype.classType; - if (ClassType.isBuiltIn(classAlias, 'dict') && subtype.classType.typeArguments) { - const typeArg0 = transformPossibleRecursiveTypeAlias(subtype.classType.typeArguments[0]); - const typeArg1 = transformPossibleRecursiveTypeAlias(subtype.classType.typeArguments[1]); - const typeVarMap = new TypeVarMap(); - - for (const keyType of keyTypes) { - if (!canAssignType(typeArg0, keyType, new DiagnosticAddendum(), typeVarMap)) { - return undefined; - } - } - - for (const valueType of valueTypes) { - if (!canAssignType(typeArg1, valueType, new DiagnosticAddendum(), typeVarMap)) { - return undefined; - } - } - - return specializeType(subtype, typeVarMap); - } - - return undefined; - }); - - if (!isNever(narrowedExpectedType)) { - return { type: narrowedExpectedType, node }; - } - } - - // Strip any literal values. - keyTypes = keyTypes.map((t) => stripLiteralValue(t)); - valueTypes = valueTypes.map((t) => stripLiteralValue(t)); - - keyType = keyTypes.length > 0 ? combineTypes(keyTypes) : AnyType.create(); - - // If the value type differs and we're not using "strict inference mode", - // we need to back off because we can't properly represent the mappings - // between different keys and associated value types. If all the values - // are the same type, we'll assume that all values in this dictionary should - // be the same. - if (valueTypes.length > 0) { - if (getFileInfo(node).diagnosticRuleSet.strictDictionaryInference) { - valueType = combineTypes(valueTypes); - } else { - valueType = areTypesSame(valueTypes) ? valueTypes[0] : UnknownType.create(); - } - } else { - valueType = AnyType.create(); - } - - // If we weren't provided an expected type, strip away any - // literals from the key and value. - if (!expectedType) { - keyType = stripLiteralValue(keyType); - valueType = stripLiteralValue(valueType); - } - - const type = getBuiltInObject(node, 'Dict', [keyType, valueType]); - - return { type, node, expectedTypeDiagAddendum: !diagAddendum.isEmpty() ? diagAddendum : undefined }; } function getTypeFromList(node: ListNode, expectedType: Type | undefined): TypeResult { - // Define a local helper function that determines whether a - // type is a list and returns the list element type if it is. - const getListTypeArg = (potentialList: Type) => { - const expectedType = doForSubtypes(potentialList, (subtype) => { - subtype = transformPossibleRecursiveTypeAlias(subtype); - if (!isObject(subtype)) { - return undefined; + // If the expected type is a union, recursively call for each of the subtypes + // to find one that matches. + let effectiveExpectedType = expectedType; + + if (expectedType && expectedType.category === TypeCategory.Union) { + let matchingSubtype: Type | undefined; + + for (const subtype of expectedType.subtypes) { + const subtypeResult = useSpeculativeMode(node, () => { + return getTypeFromListExpected(node, subtype); + }); + + if (subtypeResult) { + matchingSubtype = subtype; + break; } + } - const classAlias = subtype.classType.details.aliasClass || subtype.classType; - if (!ClassType.isBuiltIn(classAlias, 'list') || !subtype.classType.typeArguments) { - return undefined; - } + effectiveExpectedType = matchingSubtype; + } - return subtype.classType.typeArguments[0]; - }); + if (effectiveExpectedType) { + const result = getTypeFromListExpected(node, effectiveExpectedType); + if (result) { + return result; + } + } - return isNever(expectedType) ? undefined : expectedType; - }; + return getTypeFromListInferred(node, /* forceStrict */ !!expectedType)!; + } - const expectedEntryType = expectedType ? getListTypeArg(expectedType) : undefined; + // Attempts to determine the type of a list statement based on an expected type. + // Returns undefined if that type cannot be honored. + function getTypeFromListExpected(node: ListNode, expectedType: Type): TypeResult | undefined { + expectedType = transformPossibleRecursiveTypeAlias(expectedType); + if (!isObject(expectedType)) { + return undefined; + } - let entryTypes: Type[] = []; + const builtInList = getBuiltInObject(node, 'List'); + if (!isObject(builtInList)) { + return undefined; + } + + const listTypeVarMap = new TypeVarMap(); + if (!populateTypeVarMapBasedOnExpectedType(builtInList.classType, expectedType, listTypeVarMap)) { + return undefined; + } + + const specializedList = specializeType(builtInList.classType, listTypeVarMap) as ClassType; + if (!specializedList.typeArguments || specializedList.typeArguments.length !== 1) { + return undefined; + } + + const expectedEntryType = specializeType(specializedList.typeArguments[0], /* typeVarMap */ undefined); + + const entryTypes: Type[] = []; node.entries.forEach((entry, index) => { if (index < maxEntriesToUseForInference || expectedType !== undefined) { if (entry.nodeType === ParseNodeType.ListComprehension) { @@ -8086,82 +8182,94 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } }); - // If there is an expected type, see if we can match it. - if (expectedType && entryTypes.length > 0) { - const narrowedExpectedType = doForSubtypes(expectedType, (subtype) => { - const expectedListElementType = getListTypeArg(subtype); - if (expectedListElementType) { - const typeVarMap = new TypeVarMap(); - - for (const entryType of entryTypes) { - let assignedNonLiteral = false; - - // If the entry type is a literal value, try to assign a non-literal - // type first to avoid over-narrowing. This may not work if the expected - // element type is a literal or a TypeVar bound to a literal. - const nonLiteralEntryType = stripLiteralValue(entryType); - if (entryType !== nonLiteralEntryType) { - if ( - canAssignType( - expectedListElementType, - nonLiteralEntryType, - new DiagnosticAddendum(), - typeVarMap - ) - ) { - assignedNonLiteral = true; - } - } - - if (!assignedNonLiteral) { - if ( - !canAssignType(expectedListElementType, entryType, new DiagnosticAddendum(), typeVarMap) - ) { - return undefined; - } - } - } - - return specializeType(subtype, typeVarMap); - } - - return undefined; - }); - - if (!isNever(narrowedExpectedType)) { - return { type: narrowedExpectedType, node }; - } + const isExpectedTypeList = + isObject(expectedType) && ClassType.isBuiltIn(ClassType.getAliasClass(expectedType.classType), 'list'); + const specializedEntryType = inferTypeArgFromExpectedType( + expectedEntryType, + entryTypes, + /* isNarrowable */ !isExpectedTypeList + ); + if (!specializedEntryType) { + return undefined; } + const type = getBuiltInObject(node, 'List', [specializedEntryType]); + return { type, node }; + } + + // Attempts to infer the type of a list statement with no "expected type". If + // forceStrict is true, it always includes all of the list subtypes. + function getTypeFromListInferred(node: ListNode, forceStrict: boolean): TypeResult { + let entryTypes: Type[] = []; + node.entries.forEach((entry, index) => { + if (index < maxEntriesToUseForInference) { + if (entry.nodeType === ParseNodeType.ListComprehension) { + entryTypes.push(getElementTypeFromListComprehension(entry)); + } else { + entryTypes.push(getTypeOfExpression(entry).type); + } + } + }); + entryTypes = entryTypes.map((t) => stripLiteralValue(t)); let inferredEntryType: Type = AnyType.create(); if (entryTypes.length > 0) { // If there was an expected type or we're using strict list inference, // combine the types into a union. - if (expectedType || getFileInfo(node).diagnosticRuleSet.strictListInference) { + if (getFileInfo(node).diagnosticRuleSet.strictListInference || forceStrict) { inferredEntryType = combineTypes(entryTypes, maxSubtypesForInferredType); } else { // Is the list homogeneous? If so, use stricter rules. Otherwise relax the rules. inferredEntryType = areTypesSame(entryTypes) ? entryTypes[0] : UnknownType.create(); } - } else if (expectedEntryType) { - inferredEntryType = expectedEntryType; - } - - // If we weren't provided an expected type, strip away any - // literals from the list. The user is probably not expecting - // ['a'] to be interpreted as type List[Literal['a']] but - // instead List[str]. - if (!expectedType) { - inferredEntryType = stripLiteralValue(inferredEntryType); } const type = getBuiltInObject(node, 'List', [inferredEntryType]); - return { type, node }; } + function inferTypeArgFromExpectedType( + expectedType: Type, + entryTypes: Type[], + isNarrowable: boolean + ): Type | undefined { + const diagDummy = new DiagnosticAddendum(); + + // Synthesize a temporary bound type var. We will attempt to assign all list + // entries to this type var, possibly narrowing the type in the process. + const targetTypeVar = TypeVarType.createInstance( + '__typeArg', + /* isParamSpec */ false, + /* isSynthesized */ true + ); + targetTypeVar.details.boundType = expectedType; + + let typeVarMap = new TypeVarMap(); + typeVarMap.setTypeVar(targetTypeVar, expectedType, isNarrowable); + + // First, try to assign entries with their literal values stripped. + // The only time we don't want to strip them is if the expected + // type explicitly includes literals. + if ( + entryTypes.some( + (entryType) => !canAssignType(targetTypeVar, stripLiteralValue(entryType), diagDummy, typeVarMap) + ) + ) { + // Allocate a fresh typeVarMap before we try again with literals not stripped. + typeVarMap = new TypeVarMap(); + typeVarMap.setTypeVar(targetTypeVar, expectedType, isNarrowable); + if (entryTypes.some((entryType) => !canAssignType(targetTypeVar!, entryType, diagDummy, typeVarMap))) { + return undefined; + } + } + + // We need to call specializeType twice here. The first time specializes the + // temporary "__typeArg" type variable. The second time specializes any type + // variables that it referred to. + return specializeType(specializeType(targetTypeVar, typeVarMap), typeVarMap); + } + function getTypeFromTernary(node: TernaryNode, flags: EvaluatorFlags, expectedType: Type | undefined): TypeResult { getTypeOfExpression(node.testExpression); @@ -8857,7 +8965,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // If this is a recursive type alias that hasn't yet been fully resolved // (i.e. there is no boundType associated with it), don't apply the transform. - if (isTypeVar(type) && type.details.recursiveTypeAliasName && !type.details.boundType) { + if (isTypeAliasPlaceholder(type)) { return type; } @@ -9143,6 +9251,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // Set the resulting type to the boundType of the original type alias // to support recursive type aliases. typeAliasTypeVar!.details.boundType = rightHandType; + + // Record the type parameters within the recursive type alias so it + // can be specialized. + typeAliasTypeVar!.details.recursiveTypeParameters = rightHandType.typeAliasInfo?.typeParameters; } } } @@ -12639,9 +12751,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } const classType = containerType.classType; - const builtInName = classType.details.aliasClass - ? classType.details.aliasClass.details.name - : classType.details.name; + const builtInName = ClassType.getAliasClass(classType).details.name; if (!['list', 'set', 'frozenset', 'deque'].some((name) => name === builtInName)) { return referenceType; @@ -13104,11 +13214,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } // Disables recording of errors and warnings. - function suppressDiagnostics(callback: () => void) { + function suppressDiagnostics(callback: () => T) { const wasSuppressed = isDiagnosticSuppressed; isDiagnosticSuppressed = true; try { - callback(); + return callback(); } finally { isDiagnosticSuppressed = wasSuppressed; } @@ -13117,11 +13227,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // Disables recording of errors and warnings and disables // any caching of types, under the assumption that we're // performing speculative evaluations. - function useSpeculativeMode(speculativeNode: ParseNode, callback: () => void) { + function useSpeculativeMode(speculativeNode: ParseNode, callback: () => T) { speculativeTypeTracker.enterSpeculativeContext(speculativeNode); try { - callback(); + return callback(); } finally { speculativeTypeTracker.leaveSpeculativeContext(); } @@ -15686,40 +15796,101 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return canAssign; } - // When a value is assigned to a variable with a declared type, - // we may be able to narrow the type based on the assignment. - function narrowTypeBasedOnAssignment(declaredType: Type, assignedType: Type): Type { - const diagAddendum = new DiagnosticAddendum(); + // If the declaredType contains type arguments that are "Any" and + // the corresponding type argument in the assignedType is not "Any", + // replace that type argument in the assigned type. This function assumes + // that the caller has already verified that the assignedType is assignable + // tot he declaredType. + function replaceTypeArgsWithAny(declaredType: ClassType, assignedType: ClassType): ClassType | undefined { + const assignedTypeAlias = ClassType.getAliasClass(assignedType); - if (declaredType.category === TypeCategory.Union) { - return doForSubtypes(declaredType, (subtype) => { - if (assignedType.category === TypeCategory.Union) { - if (!assignedType.subtypes.some((t) => canAssignType(subtype, t, diagAddendum))) { - return undefined; - } else { - return subtype; + if ( + assignedTypeAlias.details.typeParameters.length > 0 && + assignedType.typeArguments && + assignedType.typeArguments.length <= assignedTypeAlias.details.typeParameters.length + ) { + const typeVarMap = new TypeVarMap(); + populateTypeVarMapBasedOnExpectedType( + ClassType.cloneForSpecialization( + assignedTypeAlias, + /* typeArguments */ undefined, + /* isTypeArgumentExplicit */ false + ), + ObjectType.create(declaredType), + typeVarMap + ); + + let replacedTypeArg = false; + const newTypeArgs = assignedType.typeArguments.map((typeArg, index) => { + const typeParam = assignedTypeAlias.details.typeParameters[index]; + const expectedTypeArgType = typeVarMap.getTypeVar(typeParam); + + if (expectedTypeArgType) { + if (expectedTypeArgType.category === TypeCategory.Any || isAnyOrUnknown(typeArg)) { + replacedTypeArg = true; + return expectedTypeArgType; } } - if (!canAssignType(subtype, assignedType, diagAddendum)) { - return undefined; - } - - // We assume that assignedType is a narrower type than subtype, - // so return it rather than subtype. - if (!isAnyOrUnknown(assignedType)) { - return assignedType; - } - - return subtype; + return typeArg; }); + + if (replacedTypeArg) { + return ClassType.cloneForSpecialization(assignedType, newTypeArgs, /* isTypeArgumentExplicit */ true); + } } - if (!canAssignType(declaredType, assignedType, diagAddendum)) { - return NeverType.create(); + return undefined; + } + + // When a value is assigned to a variable with a declared type, + // we may be able to narrow the type based on the assignment. + function narrowTypeBasedOnAssignment(declaredType: Type, assignedType: Type): Type { + const diag = new DiagnosticAddendum(); + + const narrowedType = doForSubtypes(assignedType, (assignedSubtype) => { + const narrowedSubtype = doForSubtypes(declaredType, (declaredSubtype) => { + // We can't narrow "Any". + if (isAnyOrUnknown(declaredType)) { + return declaredType; + } + + if (canAssignType(declaredSubtype, assignedSubtype, diag)) { + // If the source is generic and has unspecified type arguments, + // see if we can determine then based on the declared type. + if (isClass(declaredSubtype) && isClass(assignedSubtype)) { + const result = replaceTypeArgsWithAny(declaredSubtype, assignedSubtype); + if (result) { + assignedSubtype = result; + } + } else if (isObject(declaredSubtype) && isObject(assignedSubtype)) { + const result = replaceTypeArgsWithAny(declaredSubtype.classType, assignedSubtype.classType); + if (result) { + assignedSubtype = ObjectType.create(result); + } + } + + return assignedSubtype; + } + + return undefined; + }); + + // If we couldn't assign the assigned subtype any of the declared + // subtypes, the types are incompatible. Return the unnarrowed form. + if (isNever(narrowedSubtype)) { + return assignedSubtype; + } + + return narrowedSubtype; + }); + + // If the result of narrowing is Any, stick with the declared (unnarrowed) type. + if (isAnyOrUnknown(assignedType)) { + return declaredType; } - return transformTypeObjectToClass(declaredType); + return narrowedType; } function canOverrideMethod(baseMethod: Type, overrideMethod: FunctionType, diag: DiagnosticAddendum): boolean { diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index aed8298b2..c492a5249 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -370,10 +370,16 @@ export function transformPossibleRecursiveTypeAlias(type: Type | undefined): Typ export function transformPossibleRecursiveTypeAlias(type: Type | undefined): Type | undefined { if (type) { if (isTypeVar(type) && type.details.recursiveTypeAliasName && type.details.boundType) { - if (TypeBase.isInstance(type)) { - return convertToInstance(type.details.boundType); + const unspecializedType = TypeBase.isInstance(type) + ? convertToInstance(type.details.boundType) + : type.details.boundType; + + if (!type.typeAliasInfo?.typeArguments || !type.details.recursiveTypeParameters) { + return unspecializedType; } - return type.details.boundType; + + const typeVarMap = buildTypeVarMap(type.details.recursiveTypeParameters, type.typeAliasInfo.typeArguments); + return specializeType(unspecializedType, typeVarMap); } } @@ -643,6 +649,26 @@ export function specializeType( } if (isTypeVar(type)) { + // Handle recursive type aliases specially. In particular, + // we need to specialize type arguments for generic recursive + // type aliases. + if (type.details.recursiveTypeAliasName) { + if (!type.typeAliasInfo?.typeArguments) { + return type; + } + + const typeArgs = type.typeAliasInfo.typeArguments.map((typeArg) => + specializeType(typeArg, typeVarMap, /* makeConcrete */ false, recursionLevel + 1) + ); + + return TypeBase.cloneForTypeAlias( + type, + type.typeAliasInfo.aliasName, + type.typeAliasInfo.typeParameters, + typeArgs + ); + } + if (typeVarMap) { const replacementType = typeVarMap.getTypeVar(type); if (replacementType) { @@ -1712,8 +1738,18 @@ export function requiresSpecialization(type: Type, recursionCount = 0): boolean } case TypeCategory.TypeVar: { - // If this is a recursive type alias, don't treat it like other TypeVars. - return type.details.recursiveTypeAliasName === undefined; + // Most TypeVar types need to be specialized. + if (!type.details.recursiveTypeAliasName) { + return true; + } + + // If this is a recursive type alias, it may need to be specialized + // if it has generic type arguments. + if (type.typeAliasInfo?.typeArguments) { + return type.typeAliasInfo.typeArguments.some((typeArg) => + requiresSpecialization(typeArg, recursionCount + 1) + ); + } } } diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index 6607e9c71..8b6df120b 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -1272,6 +1272,9 @@ export interface TypeVarDetails { // Used for recursive type aliases. recursiveTypeAliasName?: string; + + // Type parameters for a recursive type alias. + recursiveTypeParameters?: TypeVarType[]; } export interface TypeVarType extends TypeBase { diff --git a/packages/pyright-internal/src/tests/checker.test.ts b/packages/pyright-internal/src/tests/checker.test.ts index 1f1a39575..f20dcb377 100644 --- a/packages/pyright-internal/src/tests/checker.test.ts +++ b/packages/pyright-internal/src/tests/checker.test.ts @@ -353,7 +353,13 @@ test('TypeNarrowing17', () => { test('TypeNarrowing18', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowing18.py']); - validateResults(analysisResults, 0, 0, 10); + validateResults(analysisResults, 0); +}); + +test('TypeNarrowing19', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowing19.py']); + + validateResults(analysisResults, 0); }); test('CircularBaseClass', () => { @@ -1125,7 +1131,7 @@ test('TypeAlias6', () => { test('TypeAlias7', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeAlias7.py']); - validateResults(analysisResults, 2); + validateResults(analysisResults, 3); }); test('TypeAlias8', () => { @@ -2376,6 +2382,12 @@ test('Constructor1', () => { validateResults(analysisResults, 0); }); +test('Constructor2', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['constructor2.py']); + + validateResults(analysisResults, 0); +}); + test('Constructor3', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['constructor3.py']); diff --git a/packages/pyright-internal/src/tests/samples/expressions6.py b/packages/pyright-internal/src/tests/samples/expressions6.py index 30d63ea87..55df759d4 100644 --- a/packages/pyright-internal/src/tests/samples/expressions6.py +++ b/packages/pyright-internal/src/tests/samples/expressions6.py @@ -11,4 +11,4 @@ def func_or(a: Optional[Dict[str, Any]]): def func_and(): a: Optional[Dict[str, Any]] = True and dict() - t1: Literal["Dict[str, Any]"] = reveal_type(a) + t1: Literal["dict[str, Any]"] = reveal_type(a) diff --git a/packages/pyright-internal/src/tests/samples/typeAlias7.py b/packages/pyright-internal/src/tests/samples/typeAlias7.py index 58124d8c6..82fe7653f 100644 --- a/packages/pyright-internal/src/tests/samples/typeAlias7.py +++ b/packages/pyright-internal/src/tests/samples/typeAlias7.py @@ -3,16 +3,17 @@ from typing import List, TypeVar, Union -_T2 = TypeVar("_T2", str, int) +_T1 = TypeVar("_T1", str, int) +_T2 = TypeVar("_T2") -GenericTypeAlias1 = List[Union["GenericTypeAlias1", _T2]] +GenericTypeAlias1 = List[Union["GenericTypeAlias1[_T1]", _T1]] SpecializedTypeAlias1 = GenericTypeAlias1[str] a1: SpecializedTypeAlias1 = ["hi", ["hi", "hi"]] # This should generate an error because int doesn't match the -# constraint of the TypeVar _T2. +# constraint of the TypeVar _T1. SpecializedClass2 = GenericTypeAlias1[float] b1: GenericTypeAlias1[str] = ["hi", "bye", [""], [["hi"]]] @@ -20,3 +21,12 @@ b1: GenericTypeAlias1[str] = ["hi", "bye", [""], [["hi"]]] # This should generate an error. b2: GenericTypeAlias1[str] = ["hi", [2.4]] + +GenericTypeAlias2 = List[Union["GenericTypeAlias2[_T1, _T2]", _T1, _T2]] + +c2: GenericTypeAlias2[str, int] = [[3, ["hi"]], "hi"] + +c3: GenericTypeAlias2[str, float] = [[3, ["hi", 3.4, [3.4]]], "hi"] + +# This should generate an error because a float is a type mismatch. +c4: GenericTypeAlias2[str, int] = [[3, ["hi", 3, [3.4]]], "hi"] diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowing19.py b/packages/pyright-internal/src/tests/samples/typeNarrowing19.py new file mode 100644 index 000000000..b31bb4d24 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typeNarrowing19.py @@ -0,0 +1,19 @@ +# This sample tests type narrowing for assignments +# where the source contains Unknown or Any type +# arguments. + +from typing import Any, Dict, Literal + + +def func1(struct: Dict[Any, Any]): + a1: Dict[str, Any] = struct + t1: Literal["Dict[str, Any]"] = reveal_type(a1) + + +def func2(struct: Any): + a1: Dict[Any, str] = struct + t1: Literal["Dict[Any, str]"] = reveal_type(a1) + + if isinstance(struct, Dict): + a2: Dict[str, Any] = struct + t2: Literal["Dict[str, Any]"] = reveal_type(a2)