From e0fcd6da43fda010008746223de8a4ddeeab21d9 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Sun, 4 Oct 2020 23:29:35 -0700 Subject: [PATCH] Fixed bug that caused type narrowing for assignments not to be applied when the source of the assignment was a call to a constructor. Improved type narrowing for assignments when destination is declared with one or more "Any" type arguments. Improved bidirectional type inference for list and dict types when destination type is a union that contains one or more specialized list or dict types. Improved support for generic recursive type aliases. Improved bidirectional type inference for list and dict types when destination type is a wider protocol type (like Iterable, Mapping, Sequence, etc.). --- .../src/analyzer/typeEvaluator.ts | 905 +++++++++++------- .../src/analyzer/typeUtils.ts | 46 +- .../pyright-internal/src/analyzer/types.ts | 3 + .../src/tests/checker.test.ts | 16 +- .../src/tests/samples/expressions6.py | 2 +- .../src/tests/samples/typeAlias7.py | 16 +- .../src/tests/samples/typeNarrowing19.py | 19 + 7 files changed, 629 insertions(+), 378 deletions(-) create mode 100644 packages/pyright-internal/src/tests/samples/typeNarrowing19.py diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 0c2c33e7e..01ca87afa 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -1293,10 +1293,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions memberName: string, treatAsClassMember: boolean ): FunctionType | OverloadedFunctionType | undefined { - const aliasClass = classType.details.aliasClass; - if (aliasClass) { - classType = aliasClass; - } + classType = ClassType.getAliasClass(classType); const memberInfo = lookUpClassMember( classType, @@ -3757,7 +3754,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions ): ClassMemberLookup | undefined { // If this is a special type (like "List") that has an alias class (like // "list"), switch to the alias, which defines the members. - classType = classType.details.aliasClass || classType; + classType = ClassType.getAliasClass(classType); let classLookupFlags = ClassMemberLookupFlags.Default; if (flags & MemberAccessFlags.SkipInstanceMembers) { @@ -4128,7 +4125,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions flags | EvaluatorFlags.DoNotSpecialize ); - if (baseTypeResult.isResolutionCyclical || isTypeAliasPlaceholder(baseTypeResult.type)) { + if (baseTypeResult.isResolutionCyclical) { return { node, type: UnknownType.create(), @@ -4233,6 +4230,17 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } } + if (isTypeAliasPlaceholder(baseType)) { + const typeArgTypes = getTypeArgs(node.items, flags).map((t) => convertToInstance(t.type)); + const type = TypeBase.cloneForTypeAlias( + baseType, + baseType.details.recursiveTypeAliasName!, + undefined, + typeArgTypes + ); + return { type, node }; + } + const type = doForSubtypes(baseType, (subtype) => { subtype = makeTypeVarsConcrete(subtype); subtype = getClassFromPotentialTypeObject(subtype); @@ -5375,31 +5383,72 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions ); if (initMethodType && !skipConstructorCheck(initMethodType)) { - const typeVarMap = new TypeVarMap(); - + // If there is an expected type, analyze the constructor call + // for each of the subtypes that comprise the expected type. If + // one or more analyzes with no errors, use those results. if (expectedType) { - populateTypeVarMapBasedOnExpectedType(type, expectedType, typeVarMap); + returnType = doForSubtypes(expectedType, (expectedSubType) => { + const typeVarMap = new TypeVarMap(); + if (populateTypeVarMapBasedOnExpectedType(type, expectedSubType, typeVarMap)) { + const callResult = suppressDiagnostics(() => { + return validateCallArguments( + errorNode, + argList, + initMethodType, + typeVarMap, + skipUnknownArgCheck, + /* inferReturnTypeIfNeeded */ true, + NoneType.createInstance() + ); + }); + + if (!callResult.argumentErrors) { + // Note that we're specializing the type twice here with the same + // typeVarMap. This handles the case where the expectedType contains + // a type variable that is computed and filled in to the typeVarMap. + const specializedType = specializeType( + specializeType(type, typeVarMap, /* makeConcrete */ true), + typeVarMap, + /* makeConcrete */ true + ) as ClassType; + return applyExpectedSubtypeForConstructor(specializedType, expectedSubType); + } + } + + return undefined; + }); + + if (isNever(returnType)) { + returnType = undefined; + } } - const callResult = validateCallArguments( - errorNode, - argList, - initMethodType, - typeVarMap, - skipUnknownArgCheck, - /* inferReturnTypeIfNeeded */ true, - NoneType.createInstance() - ); - if (!callResult.argumentErrors) { - // Note that we're specializing the type twice here with the same - // typeVarMap. This handles the case where the expectedType contains - // a type variable that is computed and filled in to the typeVarMap. - let specializedType = specializeType(type, typeVarMap, /* makeConcrete */ true) as ClassType; - specializedType = specializeType(specializedType, typeVarMap, /* makeConcrete */ true) as ClassType; - returnType = applyExpectedTypeForConstructor(specializedType, expectedType); - } else { - reportedErrors = true; + if (!returnType) { + const typeVarMap = new TypeVarMap(); + const callResult = validateCallArguments( + errorNode, + argList, + initMethodType, + typeVarMap, + skipUnknownArgCheck, + /* inferReturnTypeIfNeeded */ true, + NoneType.createInstance() + ); + if (!callResult.argumentErrors) { + // Note that we're specializing the type twice here with the same + // typeVarMap. This handles the case where the expectedType contains + // a type variable that is computed and filled in to the typeVarMap. + const specializedType = specializeType( + specializeType(type, typeVarMap, /* makeConcrete */ true), + typeVarMap, + /* makeConcrete */ true + ) as ClassType; + returnType = applyExpectedTypeForConstructor(specializedType, expectedType); + } else { + reportedErrors = true; + } } + validatedTypes = true; skipUnknownArgCheck = true; } @@ -5523,149 +5572,111 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return { argumentErrors: reportedErrors, returnType }; } - // Handles the case where a constructor is a generic type and the type - // arguments are not specified but can be provided by the expected type. - function applyExpectedTypeForConstructor(type: ClassType, expectedType: Type | undefined): ObjectType { + function applyExpectedSubtypeForConstructor(type: ClassType, expectedSubtype: Type): Type | undefined { const objType = ObjectType.create(type); - if (!expectedType) { - return objType; + if (canAssignType(expectedSubtype, objType, new DiagnosticAddendum())) { + // If the expected type is "Any", transform it to an Any. + if (expectedSubtype.category === TypeCategory.Any) { + return expectedSubtype; + } + + const typeVarMap = new TypeVarMap(); + if (populateTypeVarMapBasedOnExpectedType(type, expectedSubtype, typeVarMap)) { + return specializeType(objType, typeVarMap) as ClassType; + } } - // Try to apply for each subtype in the expectedType. The foundMatch - // tracks whether we've already seen a match. If none of them match, - // return the original type. - let foundMatch = false; - const specializedType = doForSubtypes(expectedType, (subtype) => { - if (foundMatch || !isObject(subtype)) { - return undefined; + return undefined; + } + + // Handles the case where a constructor is a generic type and the type + // arguments are not specified but can be provided by the expected type. + function applyExpectedTypeForConstructor(type: ClassType, expectedType: Type | undefined): Type { + if (expectedType) { + const specializedType = doForSubtypes(expectedType, (expectedSubtype) => { + return applyExpectedSubtypeForConstructor(type, expectedSubtype); + }); + + if (!isNever(specializedType)) { + return specializedType; } + } - const expectedClass = subtype.classType; - const typeVarMap = new TypeVarMap(); - if (canAssignType(expectedClass, type, new DiagnosticAddendum(), typeVarMap)) { - foundMatch = true; - return specializeType(expectedClass, typeVarMap) as ClassType; - } - - // If it's the same generic class, see if we can assign the type arguments - // without the variance rules that canAssignType uses. - if ( - ClassType.isSameGenericClass(type, expectedClass) && - expectedClass.typeArguments && - type.typeArguments && - !type.isTypeArgumentExplicit && - expectedClass.typeArguments.length === type.typeArguments.length - ) { - const typeVarMap = new TypeVarMap(); - let isAssignable = true; - expectedClass.typeArguments.forEach((expectedTypeArg, index) => { - const typeTypeArg = type.typeArguments![index]; - if (!canAssignType(expectedTypeArg, typeTypeArg, new DiagnosticAddendum(), typeVarMap)) { - isAssignable = false; - } - }); - - if (isAssignable) { - foundMatch = true; - return specializeType(expectedClass, typeVarMap) as ClassType; - } - } - }); - - return isClass(specializedType) ? ObjectType.create(specializedType) : objType; + return ObjectType.create(type); } // In cases where the expected type is a specialized base class of the // source type, we need to determine which type arguments in the derived // class will make it compatible with the specialized base class. This method // performs this reverse mapping of type arguments and populates the type var - // map for the target type. - function populateTypeVarMapBasedOnExpectedType(type: ClassType, expectedType: Type, typeVarMap: TypeVarMap) { - // If the target type isn't generic, there's nothing for us to do. - if (!requiresSpecialization(type)) { - return; - } - + // map for the target type. If the type is not assignable to the expected type, + // it returns false. + function populateTypeVarMapBasedOnExpectedType( + type: ClassType, + expectedType: Type, + typeVarMap: TypeVarMap + ): boolean { // Try to find a subtype within the expected type that the type can be assigned to. // If found, fill in the typeVarMap with the required specialization type arguments. - let foundMatch = false; - doForSubtypes(expectedType, (subtype) => { - if (!foundMatch && isObject(subtype)) { - // If the expected type is generic (not specialized), we can't proceed. - const expectedTypeArgs = subtype.classType.effectiveTypeArguments || subtype.classType.typeArguments; - if (expectedTypeArgs === undefined) { - return undefined; - } + if (!isObject(expectedType)) { + return false; + } - // If the expected type is the same as the target type (commonly the case), - // we can use a faster method. - if (ClassType.isSameGenericClass(subtype.classType, type)) { - canAssignType(type, subtype.classType, new DiagnosticAddendum(), typeVarMap); - foundMatch = true; - return undefined; - } + // If the expected type is generic (but not specialized), we can't proceed. + const expectedTypeArgs = expectedType.classType.effectiveTypeArguments || expectedType.classType.typeArguments; + if (!expectedTypeArgs) { + return canAssignType(type, expectedType.classType, new DiagnosticAddendum(), typeVarMap); + } - // Create a generic (not specialized) version of the expected type. - const synthExpectedTypeArgs = subtype.classType.details.typeParameters.map((_, index) => { - return TypeVarType.createInstance( - `__dest${index}`, - /* isParamSpec */ false, - /* isSynthesized */ true - ); - }); - const genericExpectedType = ClassType.cloneForSpecialization( - subtype.classType, - synthExpectedTypeArgs, - /* isTypeArgumentExplicit */ true - ); - - // For each type param in the target type, create a placeholder type variable. - const classType = type.details.aliasClass || type; - const typeArgs = classType.details.typeParameters.map((_, index) => { - const typeVar = TypeVarType.createInstance( - `__source${index}`, - /* isParamSpec */ false, - /* isSynthesized */ true - ); - typeVar.details.synthesizedIndex = index; - return typeVar; - }); - - const specializedType = ClassType.cloneForSpecialization( - type, - typeArgs, - /* isTypeArgumentExplicit */ true - ); - const syntheticTypeVarMap = new TypeVarMap(); - if ( - canAssignType(genericExpectedType, specializedType, new DiagnosticAddendum(), syntheticTypeVarMap) - ) { - synthExpectedTypeArgs.forEach((typeVar, index) => { - const synthTypeVar = syntheticTypeVarMap.getTypeVar(typeVar); - - // Is this one of the synthesized type vars we allocated above? If so, - // the type arg that corresponds to this type var maps back to the target type. - if ( - synthTypeVar && - isTypeVar(synthTypeVar) && - synthTypeVar.details.isSynthesized && - synthTypeVar.details.synthesizedIndex !== undefined - ) { - const targetTypeVar = - specializedType.details.typeParameters[synthTypeVar.details.synthesizedIndex]; - if (index < expectedTypeArgs.length) { - typeVarMap.setTypeVar(targetTypeVar, expectedTypeArgs[index], /* isNarrowable */ false); - } - } - }); - - foundMatch = true; - } - } - - return undefined; + // Create a generic version of the expected type. + const synthExpectedTypeArgs = ClassType.getTypeParameters(expectedType.classType).map((_, index) => { + return TypeVarType.createInstance(`__dest${index}`, /* isParamSpec */ false, /* isSynthesized */ true); }); + const genericExpectedType = ClassType.cloneForSpecialization( + expectedType.classType, + synthExpectedTypeArgs, + /* isTypeArgumentExplicit */ true + ); + + // For each type param in the target type, create a placeholder type variable. + const typeArgs = ClassType.getTypeParameters(type).map((_, index) => { + const typeVar = TypeVarType.createInstance( + `__source${index}`, + /* isParamSpec */ false, + /* isSynthesized */ true + ); + typeVar.details.synthesizedIndex = index; + return typeVar; + }); + + const specializedType = ClassType.cloneForSpecialization(type, typeArgs, /* isTypeArgumentExplicit */ true); + const syntheticTypeVarMap = new TypeVarMap(); + if (canAssignType(genericExpectedType, specializedType, new DiagnosticAddendum(), syntheticTypeVarMap)) { + synthExpectedTypeArgs.forEach((typeVar, index) => { + const synthTypeVar = syntheticTypeVarMap.getTypeVar(typeVar); + + // Is this one of the synthesized type vars we allocated above? If so, + // the type arg that corresponds to this type var maps back to the target type. + if ( + synthTypeVar && + isTypeVar(synthTypeVar) && + synthTypeVar.details.isSynthesized && + synthTypeVar.details.synthesizedIndex !== undefined + ) { + const targetTypeVar = ClassType.getTypeParameters(specializedType)[ + synthTypeVar.details.synthesizedIndex + ]; + if (index < expectedTypeArgs.length) { + typeVarMap.setTypeVar(targetTypeVar, expectedTypeArgs[index], /* isNarrowable */ false); + } + } + }); + + return true; + } + + return false; } // Validates that the arguments can be assigned to the call's parameter @@ -7819,7 +7830,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions if (expectedType && entryTypes.length > 0) { const narrowedExpectedType = doForSubtypes(expectedType, (subtype) => { if (isObject(subtype)) { - const classAlias = subtype.classType.details.aliasClass || subtype.classType; + const classAlias = ClassType.getAliasClass(subtype.classType); if (ClassType.isBuiltIn(classAlias, 'set') && subtype.classType.typeArguments) { const typeArg = subtype.classType.typeArguments[0]; const typeVarMap = new TypeVarMap(); @@ -7857,48 +7868,174 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } function getTypeFromDictionary(node: DictionaryNode, expectedType: Type | undefined): TypeResult { + // If the expected type is a union, analyze for each of the subtypes + // to find one that matches. + if (expectedType && expectedType.category === TypeCategory.Union) { + let matchingSubtype: Type | undefined; + + for (const subtype of expectedType.subtypes) { + const subtypeResult = useSpeculativeMode(node, () => { + return getTypeFromDictionaryExpected(node, subtype, new DiagnosticAddendum()); + }); + + if (subtypeResult) { + matchingSubtype = subtype; + break; + } + } + + expectedType = matchingSubtype; + } + + const expectedDiagAddendum = new DiagnosticAddendum(); + if (expectedType) { + const result = getTypeFromDictionaryExpected(node, expectedType, expectedDiagAddendum); + if (result) { + return result; + } + } + + return getTypeFromDictionaryInferred(node, /* forceStrict */ !!expectedType)!; + } + + // Attempts to infer the type of a dictionary statement. If an expectedType + // is provided, the resulting type must be compatible with the expected type. + // If this isn't possible, undefined is returned. + function getTypeFromDictionaryExpected( + node: DictionaryNode, + expectedType: Type, + expectedDiagAddendum: DiagnosticAddendum + ): TypeResult | undefined { + const keyTypes: Type[] = []; + const valueTypes: Type[] = []; + + if (!isObject(expectedType)) { + return undefined; + } + + // Handle TypedDict's as a special case. + if (ClassType.isTypedDictClass(expectedType.classType)) { + const expectedTypedDictEntries = getTypedDictMembersForClass(expectedType.classType); + + // Infer the key and value types if possible. + getKeyAndValueTypesFromDictionary( + node, + keyTypes, + valueTypes, + !!expectedType, + /* expectedKeyType */ undefined, + /* expectedValueType */ undefined, + expectedTypedDictEntries + ); + + if ( + ClassType.isTypedDictClass(expectedType.classType) && + canAssignToTypedDict(expectedType.classType, keyTypes, valueTypes, expectedDiagAddendum) + ) { + return { + type: expectedType, + node, + }; + } + + return undefined; + } + + const builtInDict = getBuiltInObject(node, 'Dict'); + if (!isObject(builtInDict)) { + return undefined; + } + + const dictTypeVarMap = new TypeVarMap(); + if (!populateTypeVarMapBasedOnExpectedType(builtInDict.classType, expectedType, dictTypeVarMap)) { + return undefined; + } + + const specializedDict = specializeType(builtInDict.classType, dictTypeVarMap) as ClassType; + if (!specializedDict.typeArguments || specializedDict.typeArguments.length !== 2) { + return undefined; + } + + const expectedKeyType = specializeType(specializedDict.typeArguments[0], /* typeVarMap */ undefined); + const expectedValueType = specializeType(specializedDict.typeArguments[1], /* typeVarMap */ undefined); + + // Infer the key and value types if possible. + getKeyAndValueTypesFromDictionary( + node, + keyTypes, + valueTypes, + !!expectedType, + expectedKeyType, + expectedValueType + ); + + const isExpectedTypeDict = + isObject(expectedType) && ClassType.isBuiltIn(ClassType.getAliasClass(expectedType.classType), 'dict'); + + const specializedKeyType = inferTypeArgFromExpectedType(expectedKeyType, keyTypes, /* isNarrowable */ false); + const specializedValueType = inferTypeArgFromExpectedType( + expectedValueType, + valueTypes, + /* isNarrowable */ !isExpectedTypeDict + ); + if (!specializedKeyType || !specializedValueType) { + return undefined; + } + + const type = getBuiltInObject(node, 'Dict', [specializedKeyType, specializedValueType]); + return { type, node }; + } + + // Attempts to infer the type of a dictionary statement. If an expectedType + // is provided, the resulting type must be compatible with the expected type. + // If this isn't possible, undefined is returned. + function getTypeFromDictionaryInferred(node: DictionaryNode, forceStrict: boolean): TypeResult { let keyType: Type = AnyType.create(); let valueType: Type = AnyType.create(); let keyTypes: Type[] = []; let valueTypes: Type[] = []; - let expectedKeyType: Type | undefined; - let expectedValueType: Type | undefined; - let expectedTypedDictEntries: Map | undefined; - const diagAddendum = new DiagnosticAddendum(); + // Infer the key and value types if possible. + getKeyAndValueTypesFromDictionary(node, keyTypes, valueTypes, !forceStrict); - if (expectedType) { - doForSubtypes(expectedType, (subtype) => { - if (isObject(subtype)) { - const expectedClass = subtype.classType; - if ( - ClassType.isBuiltIn(expectedClass, 'Mapping') || - ClassType.isBuiltIn(expectedClass, 'Dict') || - ClassType.isBuiltIn(expectedClass, 'dict') - ) { - if (expectedClass.typeArguments && expectedClass.typeArguments.length === 2) { - expectedKeyType = specializeType( - expectedClass.typeArguments[0], - /* typeVarMap */ undefined - ); - expectedValueType = specializeType( - expectedClass.typeArguments[1], - /* typeVarMap */ undefined - ); - } - } else if (ClassType.isTypedDictClass(expectedClass)) { - expectedTypedDictEntries = getTypedDictMembersForClass(expectedClass); - } - } + // Strip any literal values. + keyTypes = keyTypes.map((t) => stripLiteralValue(t)); + valueTypes = valueTypes.map((t) => stripLiteralValue(t)); - return undefined; - }); + keyType = keyTypes.length > 0 ? combineTypes(keyTypes) : AnyType.create(); + + // If the value type differs and we're not using "strict inference mode", + // we need to back off because we can't properly represent the mappings + // between different keys and associated value types. If all the values + // are the same type, we'll assume that all values in this dictionary should + // be the same. + if (valueTypes.length > 0) { + if (getFileInfo(node).diagnosticRuleSet.strictDictionaryInference || forceStrict) { + valueType = combineTypes(valueTypes); + } else { + valueType = areTypesSame(valueTypes) ? valueTypes[0] : UnknownType.create(); + } + } else { + valueType = AnyType.create(); } + const type = getBuiltInObject(node, 'Dict', [keyType, valueType]); + return { type, node }; + } + + function getKeyAndValueTypesFromDictionary( + node: DictionaryNode, + keyTypes: Type[], + valueTypes: Type[], + limitEntryCount: boolean, + expectedKeyType?: Type, + expectedValueType?: Type, + expectedTypedDictEntries?: Map + ) { // Infer the key and value types if possible. node.entries.forEach((entryNode, index) => { - if (index >= maxEntriesToUseForInference && !expectedType) { + if (limitEntryCount && index >= maxEntriesToUseForInference) { return; } @@ -7941,7 +8078,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } else { if (isObject(unexpandedType)) { const classType = unexpandedType.classType; - const aliasType = classType.details.aliasClass || classType; + const aliasType = ClassType.getAliasClass(classType); if (ClassType.isBuiltIn(aliasType, 'dict')) { const typeArgs = classType.typeArguments; @@ -7975,107 +8112,66 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions valueTypes.push(UnknownType.create()); } }); - - // If there is an expected type, see if we can match any parts of it. - if (expectedType) { - const narrowedExpectedType = doForSubtypes(expectedType, (subtype) => { - if (!isObject(subtype)) { - return undefined; - } - - if ( - ClassType.isTypedDictClass(subtype.classType) && - canAssignToTypedDict(subtype.classType, keyTypes, valueTypes, diagAddendum) - ) { - return subtype; - } - - const classAlias = subtype.classType.details.aliasClass || subtype.classType; - if (ClassType.isBuiltIn(classAlias, 'dict') && subtype.classType.typeArguments) { - const typeArg0 = transformPossibleRecursiveTypeAlias(subtype.classType.typeArguments[0]); - const typeArg1 = transformPossibleRecursiveTypeAlias(subtype.classType.typeArguments[1]); - const typeVarMap = new TypeVarMap(); - - for (const keyType of keyTypes) { - if (!canAssignType(typeArg0, keyType, new DiagnosticAddendum(), typeVarMap)) { - return undefined; - } - } - - for (const valueType of valueTypes) { - if (!canAssignType(typeArg1, valueType, new DiagnosticAddendum(), typeVarMap)) { - return undefined; - } - } - - return specializeType(subtype, typeVarMap); - } - - return undefined; - }); - - if (!isNever(narrowedExpectedType)) { - return { type: narrowedExpectedType, node }; - } - } - - // Strip any literal values. - keyTypes = keyTypes.map((t) => stripLiteralValue(t)); - valueTypes = valueTypes.map((t) => stripLiteralValue(t)); - - keyType = keyTypes.length > 0 ? combineTypes(keyTypes) : AnyType.create(); - - // If the value type differs and we're not using "strict inference mode", - // we need to back off because we can't properly represent the mappings - // between different keys and associated value types. If all the values - // are the same type, we'll assume that all values in this dictionary should - // be the same. - if (valueTypes.length > 0) { - if (getFileInfo(node).diagnosticRuleSet.strictDictionaryInference) { - valueType = combineTypes(valueTypes); - } else { - valueType = areTypesSame(valueTypes) ? valueTypes[0] : UnknownType.create(); - } - } else { - valueType = AnyType.create(); - } - - // If we weren't provided an expected type, strip away any - // literals from the key and value. - if (!expectedType) { - keyType = stripLiteralValue(keyType); - valueType = stripLiteralValue(valueType); - } - - const type = getBuiltInObject(node, 'Dict', [keyType, valueType]); - - return { type, node, expectedTypeDiagAddendum: !diagAddendum.isEmpty() ? diagAddendum : undefined }; } function getTypeFromList(node: ListNode, expectedType: Type | undefined): TypeResult { - // Define a local helper function that determines whether a - // type is a list and returns the list element type if it is. - const getListTypeArg = (potentialList: Type) => { - const expectedType = doForSubtypes(potentialList, (subtype) => { - subtype = transformPossibleRecursiveTypeAlias(subtype); - if (!isObject(subtype)) { - return undefined; + // If the expected type is a union, recursively call for each of the subtypes + // to find one that matches. + let effectiveExpectedType = expectedType; + + if (expectedType && expectedType.category === TypeCategory.Union) { + let matchingSubtype: Type | undefined; + + for (const subtype of expectedType.subtypes) { + const subtypeResult = useSpeculativeMode(node, () => { + return getTypeFromListExpected(node, subtype); + }); + + if (subtypeResult) { + matchingSubtype = subtype; + break; } + } - const classAlias = subtype.classType.details.aliasClass || subtype.classType; - if (!ClassType.isBuiltIn(classAlias, 'list') || !subtype.classType.typeArguments) { - return undefined; - } + effectiveExpectedType = matchingSubtype; + } - return subtype.classType.typeArguments[0]; - }); + if (effectiveExpectedType) { + const result = getTypeFromListExpected(node, effectiveExpectedType); + if (result) { + return result; + } + } - return isNever(expectedType) ? undefined : expectedType; - }; + return getTypeFromListInferred(node, /* forceStrict */ !!expectedType)!; + } - const expectedEntryType = expectedType ? getListTypeArg(expectedType) : undefined; + // Attempts to determine the type of a list statement based on an expected type. + // Returns undefined if that type cannot be honored. + function getTypeFromListExpected(node: ListNode, expectedType: Type): TypeResult | undefined { + expectedType = transformPossibleRecursiveTypeAlias(expectedType); + if (!isObject(expectedType)) { + return undefined; + } - let entryTypes: Type[] = []; + const builtInList = getBuiltInObject(node, 'List'); + if (!isObject(builtInList)) { + return undefined; + } + + const listTypeVarMap = new TypeVarMap(); + if (!populateTypeVarMapBasedOnExpectedType(builtInList.classType, expectedType, listTypeVarMap)) { + return undefined; + } + + const specializedList = specializeType(builtInList.classType, listTypeVarMap) as ClassType; + if (!specializedList.typeArguments || specializedList.typeArguments.length !== 1) { + return undefined; + } + + const expectedEntryType = specializeType(specializedList.typeArguments[0], /* typeVarMap */ undefined); + + const entryTypes: Type[] = []; node.entries.forEach((entry, index) => { if (index < maxEntriesToUseForInference || expectedType !== undefined) { if (entry.nodeType === ParseNodeType.ListComprehension) { @@ -8086,82 +8182,94 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } }); - // If there is an expected type, see if we can match it. - if (expectedType && entryTypes.length > 0) { - const narrowedExpectedType = doForSubtypes(expectedType, (subtype) => { - const expectedListElementType = getListTypeArg(subtype); - if (expectedListElementType) { - const typeVarMap = new TypeVarMap(); - - for (const entryType of entryTypes) { - let assignedNonLiteral = false; - - // If the entry type is a literal value, try to assign a non-literal - // type first to avoid over-narrowing. This may not work if the expected - // element type is a literal or a TypeVar bound to a literal. - const nonLiteralEntryType = stripLiteralValue(entryType); - if (entryType !== nonLiteralEntryType) { - if ( - canAssignType( - expectedListElementType, - nonLiteralEntryType, - new DiagnosticAddendum(), - typeVarMap - ) - ) { - assignedNonLiteral = true; - } - } - - if (!assignedNonLiteral) { - if ( - !canAssignType(expectedListElementType, entryType, new DiagnosticAddendum(), typeVarMap) - ) { - return undefined; - } - } - } - - return specializeType(subtype, typeVarMap); - } - - return undefined; - }); - - if (!isNever(narrowedExpectedType)) { - return { type: narrowedExpectedType, node }; - } + const isExpectedTypeList = + isObject(expectedType) && ClassType.isBuiltIn(ClassType.getAliasClass(expectedType.classType), 'list'); + const specializedEntryType = inferTypeArgFromExpectedType( + expectedEntryType, + entryTypes, + /* isNarrowable */ !isExpectedTypeList + ); + if (!specializedEntryType) { + return undefined; } + const type = getBuiltInObject(node, 'List', [specializedEntryType]); + return { type, node }; + } + + // Attempts to infer the type of a list statement with no "expected type". If + // forceStrict is true, it always includes all of the list subtypes. + function getTypeFromListInferred(node: ListNode, forceStrict: boolean): TypeResult { + let entryTypes: Type[] = []; + node.entries.forEach((entry, index) => { + if (index < maxEntriesToUseForInference) { + if (entry.nodeType === ParseNodeType.ListComprehension) { + entryTypes.push(getElementTypeFromListComprehension(entry)); + } else { + entryTypes.push(getTypeOfExpression(entry).type); + } + } + }); + entryTypes = entryTypes.map((t) => stripLiteralValue(t)); let inferredEntryType: Type = AnyType.create(); if (entryTypes.length > 0) { // If there was an expected type or we're using strict list inference, // combine the types into a union. - if (expectedType || getFileInfo(node).diagnosticRuleSet.strictListInference) { + if (getFileInfo(node).diagnosticRuleSet.strictListInference || forceStrict) { inferredEntryType = combineTypes(entryTypes, maxSubtypesForInferredType); } else { // Is the list homogeneous? If so, use stricter rules. Otherwise relax the rules. inferredEntryType = areTypesSame(entryTypes) ? entryTypes[0] : UnknownType.create(); } - } else if (expectedEntryType) { - inferredEntryType = expectedEntryType; - } - - // If we weren't provided an expected type, strip away any - // literals from the list. The user is probably not expecting - // ['a'] to be interpreted as type List[Literal['a']] but - // instead List[str]. - if (!expectedType) { - inferredEntryType = stripLiteralValue(inferredEntryType); } const type = getBuiltInObject(node, 'List', [inferredEntryType]); - return { type, node }; } + function inferTypeArgFromExpectedType( + expectedType: Type, + entryTypes: Type[], + isNarrowable: boolean + ): Type | undefined { + const diagDummy = new DiagnosticAddendum(); + + // Synthesize a temporary bound type var. We will attempt to assign all list + // entries to this type var, possibly narrowing the type in the process. + const targetTypeVar = TypeVarType.createInstance( + '__typeArg', + /* isParamSpec */ false, + /* isSynthesized */ true + ); + targetTypeVar.details.boundType = expectedType; + + let typeVarMap = new TypeVarMap(); + typeVarMap.setTypeVar(targetTypeVar, expectedType, isNarrowable); + + // First, try to assign entries with their literal values stripped. + // The only time we don't want to strip them is if the expected + // type explicitly includes literals. + if ( + entryTypes.some( + (entryType) => !canAssignType(targetTypeVar, stripLiteralValue(entryType), diagDummy, typeVarMap) + ) + ) { + // Allocate a fresh typeVarMap before we try again with literals not stripped. + typeVarMap = new TypeVarMap(); + typeVarMap.setTypeVar(targetTypeVar, expectedType, isNarrowable); + if (entryTypes.some((entryType) => !canAssignType(targetTypeVar!, entryType, diagDummy, typeVarMap))) { + return undefined; + } + } + + // We need to call specializeType twice here. The first time specializes the + // temporary "__typeArg" type variable. The second time specializes any type + // variables that it referred to. + return specializeType(specializeType(targetTypeVar, typeVarMap), typeVarMap); + } + function getTypeFromTernary(node: TernaryNode, flags: EvaluatorFlags, expectedType: Type | undefined): TypeResult { getTypeOfExpression(node.testExpression); @@ -8857,7 +8965,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // If this is a recursive type alias that hasn't yet been fully resolved // (i.e. there is no boundType associated with it), don't apply the transform. - if (isTypeVar(type) && type.details.recursiveTypeAliasName && !type.details.boundType) { + if (isTypeAliasPlaceholder(type)) { return type; } @@ -9143,6 +9251,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // Set the resulting type to the boundType of the original type alias // to support recursive type aliases. typeAliasTypeVar!.details.boundType = rightHandType; + + // Record the type parameters within the recursive type alias so it + // can be specialized. + typeAliasTypeVar!.details.recursiveTypeParameters = rightHandType.typeAliasInfo?.typeParameters; } } } @@ -12639,9 +12751,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } const classType = containerType.classType; - const builtInName = classType.details.aliasClass - ? classType.details.aliasClass.details.name - : classType.details.name; + const builtInName = ClassType.getAliasClass(classType).details.name; if (!['list', 'set', 'frozenset', 'deque'].some((name) => name === builtInName)) { return referenceType; @@ -13104,11 +13214,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } // Disables recording of errors and warnings. - function suppressDiagnostics(callback: () => void) { + function suppressDiagnostics(callback: () => T) { const wasSuppressed = isDiagnosticSuppressed; isDiagnosticSuppressed = true; try { - callback(); + return callback(); } finally { isDiagnosticSuppressed = wasSuppressed; } @@ -13117,11 +13227,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // Disables recording of errors and warnings and disables // any caching of types, under the assumption that we're // performing speculative evaluations. - function useSpeculativeMode(speculativeNode: ParseNode, callback: () => void) { + function useSpeculativeMode(speculativeNode: ParseNode, callback: () => T) { speculativeTypeTracker.enterSpeculativeContext(speculativeNode); try { - callback(); + return callback(); } finally { speculativeTypeTracker.leaveSpeculativeContext(); } @@ -15686,40 +15796,101 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return canAssign; } - // When a value is assigned to a variable with a declared type, - // we may be able to narrow the type based on the assignment. - function narrowTypeBasedOnAssignment(declaredType: Type, assignedType: Type): Type { - const diagAddendum = new DiagnosticAddendum(); + // If the declaredType contains type arguments that are "Any" and + // the corresponding type argument in the assignedType is not "Any", + // replace that type argument in the assigned type. This function assumes + // that the caller has already verified that the assignedType is assignable + // tot he declaredType. + function replaceTypeArgsWithAny(declaredType: ClassType, assignedType: ClassType): ClassType | undefined { + const assignedTypeAlias = ClassType.getAliasClass(assignedType); - if (declaredType.category === TypeCategory.Union) { - return doForSubtypes(declaredType, (subtype) => { - if (assignedType.category === TypeCategory.Union) { - if (!assignedType.subtypes.some((t) => canAssignType(subtype, t, diagAddendum))) { - return undefined; - } else { - return subtype; + if ( + assignedTypeAlias.details.typeParameters.length > 0 && + assignedType.typeArguments && + assignedType.typeArguments.length <= assignedTypeAlias.details.typeParameters.length + ) { + const typeVarMap = new TypeVarMap(); + populateTypeVarMapBasedOnExpectedType( + ClassType.cloneForSpecialization( + assignedTypeAlias, + /* typeArguments */ undefined, + /* isTypeArgumentExplicit */ false + ), + ObjectType.create(declaredType), + typeVarMap + ); + + let replacedTypeArg = false; + const newTypeArgs = assignedType.typeArguments.map((typeArg, index) => { + const typeParam = assignedTypeAlias.details.typeParameters[index]; + const expectedTypeArgType = typeVarMap.getTypeVar(typeParam); + + if (expectedTypeArgType) { + if (expectedTypeArgType.category === TypeCategory.Any || isAnyOrUnknown(typeArg)) { + replacedTypeArg = true; + return expectedTypeArgType; } } - if (!canAssignType(subtype, assignedType, diagAddendum)) { - return undefined; - } - - // We assume that assignedType is a narrower type than subtype, - // so return it rather than subtype. - if (!isAnyOrUnknown(assignedType)) { - return assignedType; - } - - return subtype; + return typeArg; }); + + if (replacedTypeArg) { + return ClassType.cloneForSpecialization(assignedType, newTypeArgs, /* isTypeArgumentExplicit */ true); + } } - if (!canAssignType(declaredType, assignedType, diagAddendum)) { - return NeverType.create(); + return undefined; + } + + // When a value is assigned to a variable with a declared type, + // we may be able to narrow the type based on the assignment. + function narrowTypeBasedOnAssignment(declaredType: Type, assignedType: Type): Type { + const diag = new DiagnosticAddendum(); + + const narrowedType = doForSubtypes(assignedType, (assignedSubtype) => { + const narrowedSubtype = doForSubtypes(declaredType, (declaredSubtype) => { + // We can't narrow "Any". + if (isAnyOrUnknown(declaredType)) { + return declaredType; + } + + if (canAssignType(declaredSubtype, assignedSubtype, diag)) { + // If the source is generic and has unspecified type arguments, + // see if we can determine then based on the declared type. + if (isClass(declaredSubtype) && isClass(assignedSubtype)) { + const result = replaceTypeArgsWithAny(declaredSubtype, assignedSubtype); + if (result) { + assignedSubtype = result; + } + } else if (isObject(declaredSubtype) && isObject(assignedSubtype)) { + const result = replaceTypeArgsWithAny(declaredSubtype.classType, assignedSubtype.classType); + if (result) { + assignedSubtype = ObjectType.create(result); + } + } + + return assignedSubtype; + } + + return undefined; + }); + + // If we couldn't assign the assigned subtype any of the declared + // subtypes, the types are incompatible. Return the unnarrowed form. + if (isNever(narrowedSubtype)) { + return assignedSubtype; + } + + return narrowedSubtype; + }); + + // If the result of narrowing is Any, stick with the declared (unnarrowed) type. + if (isAnyOrUnknown(assignedType)) { + return declaredType; } - return transformTypeObjectToClass(declaredType); + return narrowedType; } function canOverrideMethod(baseMethod: Type, overrideMethod: FunctionType, diag: DiagnosticAddendum): boolean { diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index aed8298b2..c492a5249 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -370,10 +370,16 @@ export function transformPossibleRecursiveTypeAlias(type: Type | undefined): Typ export function transformPossibleRecursiveTypeAlias(type: Type | undefined): Type | undefined { if (type) { if (isTypeVar(type) && type.details.recursiveTypeAliasName && type.details.boundType) { - if (TypeBase.isInstance(type)) { - return convertToInstance(type.details.boundType); + const unspecializedType = TypeBase.isInstance(type) + ? convertToInstance(type.details.boundType) + : type.details.boundType; + + if (!type.typeAliasInfo?.typeArguments || !type.details.recursiveTypeParameters) { + return unspecializedType; } - return type.details.boundType; + + const typeVarMap = buildTypeVarMap(type.details.recursiveTypeParameters, type.typeAliasInfo.typeArguments); + return specializeType(unspecializedType, typeVarMap); } } @@ -643,6 +649,26 @@ export function specializeType( } if (isTypeVar(type)) { + // Handle recursive type aliases specially. In particular, + // we need to specialize type arguments for generic recursive + // type aliases. + if (type.details.recursiveTypeAliasName) { + if (!type.typeAliasInfo?.typeArguments) { + return type; + } + + const typeArgs = type.typeAliasInfo.typeArguments.map((typeArg) => + specializeType(typeArg, typeVarMap, /* makeConcrete */ false, recursionLevel + 1) + ); + + return TypeBase.cloneForTypeAlias( + type, + type.typeAliasInfo.aliasName, + type.typeAliasInfo.typeParameters, + typeArgs + ); + } + if (typeVarMap) { const replacementType = typeVarMap.getTypeVar(type); if (replacementType) { @@ -1712,8 +1738,18 @@ export function requiresSpecialization(type: Type, recursionCount = 0): boolean } case TypeCategory.TypeVar: { - // If this is a recursive type alias, don't treat it like other TypeVars. - return type.details.recursiveTypeAliasName === undefined; + // Most TypeVar types need to be specialized. + if (!type.details.recursiveTypeAliasName) { + return true; + } + + // If this is a recursive type alias, it may need to be specialized + // if it has generic type arguments. + if (type.typeAliasInfo?.typeArguments) { + return type.typeAliasInfo.typeArguments.some((typeArg) => + requiresSpecialization(typeArg, recursionCount + 1) + ); + } } } diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index 6607e9c71..8b6df120b 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -1272,6 +1272,9 @@ export interface TypeVarDetails { // Used for recursive type aliases. recursiveTypeAliasName?: string; + + // Type parameters for a recursive type alias. + recursiveTypeParameters?: TypeVarType[]; } export interface TypeVarType extends TypeBase { diff --git a/packages/pyright-internal/src/tests/checker.test.ts b/packages/pyright-internal/src/tests/checker.test.ts index 1f1a39575..f20dcb377 100644 --- a/packages/pyright-internal/src/tests/checker.test.ts +++ b/packages/pyright-internal/src/tests/checker.test.ts @@ -353,7 +353,13 @@ test('TypeNarrowing17', () => { test('TypeNarrowing18', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowing18.py']); - validateResults(analysisResults, 0, 0, 10); + validateResults(analysisResults, 0); +}); + +test('TypeNarrowing19', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowing19.py']); + + validateResults(analysisResults, 0); }); test('CircularBaseClass', () => { @@ -1125,7 +1131,7 @@ test('TypeAlias6', () => { test('TypeAlias7', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeAlias7.py']); - validateResults(analysisResults, 2); + validateResults(analysisResults, 3); }); test('TypeAlias8', () => { @@ -2376,6 +2382,12 @@ test('Constructor1', () => { validateResults(analysisResults, 0); }); +test('Constructor2', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['constructor2.py']); + + validateResults(analysisResults, 0); +}); + test('Constructor3', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['constructor3.py']); diff --git a/packages/pyright-internal/src/tests/samples/expressions6.py b/packages/pyright-internal/src/tests/samples/expressions6.py index 30d63ea87..55df759d4 100644 --- a/packages/pyright-internal/src/tests/samples/expressions6.py +++ b/packages/pyright-internal/src/tests/samples/expressions6.py @@ -11,4 +11,4 @@ def func_or(a: Optional[Dict[str, Any]]): def func_and(): a: Optional[Dict[str, Any]] = True and dict() - t1: Literal["Dict[str, Any]"] = reveal_type(a) + t1: Literal["dict[str, Any]"] = reveal_type(a) diff --git a/packages/pyright-internal/src/tests/samples/typeAlias7.py b/packages/pyright-internal/src/tests/samples/typeAlias7.py index 58124d8c6..82fe7653f 100644 --- a/packages/pyright-internal/src/tests/samples/typeAlias7.py +++ b/packages/pyright-internal/src/tests/samples/typeAlias7.py @@ -3,16 +3,17 @@ from typing import List, TypeVar, Union -_T2 = TypeVar("_T2", str, int) +_T1 = TypeVar("_T1", str, int) +_T2 = TypeVar("_T2") -GenericTypeAlias1 = List[Union["GenericTypeAlias1", _T2]] +GenericTypeAlias1 = List[Union["GenericTypeAlias1[_T1]", _T1]] SpecializedTypeAlias1 = GenericTypeAlias1[str] a1: SpecializedTypeAlias1 = ["hi", ["hi", "hi"]] # This should generate an error because int doesn't match the -# constraint of the TypeVar _T2. +# constraint of the TypeVar _T1. SpecializedClass2 = GenericTypeAlias1[float] b1: GenericTypeAlias1[str] = ["hi", "bye", [""], [["hi"]]] @@ -20,3 +21,12 @@ b1: GenericTypeAlias1[str] = ["hi", "bye", [""], [["hi"]]] # This should generate an error. b2: GenericTypeAlias1[str] = ["hi", [2.4]] + +GenericTypeAlias2 = List[Union["GenericTypeAlias2[_T1, _T2]", _T1, _T2]] + +c2: GenericTypeAlias2[str, int] = [[3, ["hi"]], "hi"] + +c3: GenericTypeAlias2[str, float] = [[3, ["hi", 3.4, [3.4]]], "hi"] + +# This should generate an error because a float is a type mismatch. +c4: GenericTypeAlias2[str, int] = [[3, ["hi", 3, [3.4]]], "hi"] diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowing19.py b/packages/pyright-internal/src/tests/samples/typeNarrowing19.py new file mode 100644 index 000000000..b31bb4d24 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typeNarrowing19.py @@ -0,0 +1,19 @@ +# This sample tests type narrowing for assignments +# where the source contains Unknown or Any type +# arguments. + +from typing import Any, Dict, Literal + + +def func1(struct: Dict[Any, Any]): + a1: Dict[str, Any] = struct + t1: Literal["Dict[str, Any]"] = reveal_type(a1) + + +def func2(struct: Any): + a1: Dict[Any, str] = struct + t1: Literal["Dict[Any, str]"] = reveal_type(a1) + + if isinstance(struct, Dict): + a2: Dict[str, Any] = struct + t2: Literal["Dict[str, Any]"] = reveal_type(a2)