From 90015a385271767ed2e1a37521dcad55073cdb03 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Tue, 28 Jul 2020 14:45:56 -0700 Subject: [PATCH] Improved bidirectional type inference in the case where the type and the expected type are generic but the expected type is a base class that has been specialized. For example, if the expected type is `Mapping[str, int]` and the type is a `dict`. --- server/src/analyzer/typeEvaluator.ts | 77 +++++++++++++++++++++- server/src/analyzer/types.ts | 1 + server/src/tests/checker.test.ts | 6 ++ server/src/tests/samples/genericTypes29.py | 12 ++++ 4 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 server/src/tests/samples/genericTypes29.py diff --git a/server/src/analyzer/typeEvaluator.ts b/server/src/analyzer/typeEvaluator.ts index 85dc4b1e7..bea544833 100644 --- a/server/src/analyzer/typeEvaluator.ts +++ b/server/src/analyzer/typeEvaluator.ts @@ -4572,8 +4572,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, printTypeFlags: const typeVarMap = new TypeVarMap(); if (expectedType) { - // Prepopulate the typeVarMap based on the expected type. - canAssignType(ObjectType.create(type), expectedType, new DiagnosticAddendum(), typeVarMap); + populateTypeVarMapBasedOnExpectedType(ObjectType.create(type), expectedType, typeVarMap); } const callResult = validateCallArguments( @@ -4744,6 +4743,80 @@ export function createTypeEvaluator(importLookup: ImportLookup, printTypeFlags: return type; } + // In cases where the expected type is a specialized base class of the + // source type, we need to determine which type arguments in the derived + // class will make it compatible with the specialized base class. This method + // performs this reverse mapping of type arguments and populates the type var + // map for the target type. + function populateTypeVarMapBasedOnExpectedType(type: ObjectType, expectedType: Type, typeVarMap: TypeVarMap) { + // It's common for the expected type to be Optional. Remove the None + // to see if the resulting type is an object. + const expectedTypeWithoutNone = removeNoneFromUnion(expectedType); + + // If the resulting type isn't an object, we can't proceed. + if (expectedTypeWithoutNone.category !== TypeCategory.Object) { + return; + } + + // If the target type isn't generic, there's nothing for us to do. + if (!requiresSpecialization(type)) { + return; + } + + // If the expected type is generic (not specialized), we can't proceed. + const expectedTypeArgs = expectedTypeWithoutNone.classType.typeArguments; + if (expectedTypeArgs === undefined) { + return; + } + + // If the expected type is the same as the target type (commonly the case), + // we can use a faster method. + if (ClassType.isSameGenericClass(expectedTypeWithoutNone.classType, type.classType)) { + canAssignType(type, expectedTypeWithoutNone, new DiagnosticAddendum(), typeVarMap); + return; + } + + // Create a generic (not specialized) version of the expected type. + const genericExpectedType = ClassType.cloneForSpecialization( + expectedTypeWithoutNone.classType, + undefined, + /* isTypeArgumentExplicit */ false + ); + + // For each type param in the target type, create a placeholder type variable. + const typeArgs = type.classType.details.typeParameters.map((_, index) => { + const typeVar = TypeVarType.createInstance(`__${index}`, /* isParamSpec */ false, /* isSynthesized */ true); + typeVar.synthesizedIndex = index; + return typeVar; + }); + + const specializedType = ClassType.cloneForSpecialization( + type.classType, + typeArgs, + /* isTypeArgumentExplicit */ true + ); + const syntheticTypeVarMap = new TypeVarMap(); + if (canAssignType(genericExpectedType, specializedType, new DiagnosticAddendum(), syntheticTypeVarMap)) { + genericExpectedType.details.typeParameters.forEach((typeVar, index) => { + const synthTypeVar = syntheticTypeVarMap.getTypeVar(typeVar.name); + + // Is this one of the synthesized type vars we allocated above? If so, + // the type arg that corresponds to this type var maps back to the target type. + if ( + synthTypeVar && + synthTypeVar.category === TypeCategory.TypeVar && + synthTypeVar.isSynthesized && + synthTypeVar.synthesizedIndex !== undefined + ) { + const targetTypeVar = specializedType.details.typeParameters[synthTypeVar.synthesizedIndex]; + if (index < expectedTypeArgs.length) { + typeVarMap.setTypeVar(targetTypeVar.name, expectedTypeArgs[index], /* isNarrowable */ false); + } + } + }); + } + } + // Validates that the arguments can be assigned to the call's parameter // list, specializes the call based on arg types, and returns the // specialized type of the return value. If it detects an error along diff --git a/server/src/analyzer/types.ts b/server/src/analyzer/types.ts index 2edcfcb33..781344617 100644 --- a/server/src/analyzer/types.ts +++ b/server/src/analyzer/types.ts @@ -1100,6 +1100,7 @@ export interface TypeVarType extends TypeBase { // Internally created (e.g. for pseudo-generic classes) isSynthesized: boolean; + synthesizedIndex?: number; } export namespace TypeVarType { diff --git a/server/src/tests/checker.test.ts b/server/src/tests/checker.test.ts index bbac970a5..3d81e5b7b 100644 --- a/server/src/tests/checker.test.ts +++ b/server/src/tests/checker.test.ts @@ -1411,6 +1411,12 @@ test('GenericTypes28', () => { validateResults(analysisResults, 1); }); +test('GenericTypes29', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['genericTypes29.py']); + + validateResults(analysisResults, 1); +}); + test('Protocol1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol1.py']); diff --git a/server/src/tests/samples/genericTypes29.py b/server/src/tests/samples/genericTypes29.py new file mode 100644 index 000000000..e95986714 --- /dev/null +++ b/server/src/tests/samples/genericTypes29.py @@ -0,0 +1,12 @@ +# This sample tests bidirectional inference when the +# type derives from the expected type and both are +# generic. + +from typing import Mapping, Optional, Union + +v0: Optional[Mapping[str, Union[int, str]]] = dict([('test1', 1), ('test2', 2)]) + +v1: Optional[Mapping[str, float]] = dict([('test1', 1), ('test2', 2)]) + +# This should generate an error because of a type mismatch. +v2: Mapping[str, str] = dict([('test1', 1), ('test2', 2)])