Finished adding support for generic TypedDict classes.

This commit is contained in:
Eric Traut 2022-05-07 00:42:25 -07:00
parent 5932ef4efd
commit b82a4f3266
3 changed files with 87 additions and 53 deletions

View File

@ -19210,26 +19210,38 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
): boolean {
// Handle typed dicts. They also use a form of structural typing for type
// checking, as defined in PEP 589.
if (ClassType.isTypedDictClass(destType) && ClassType.isTypedDictClass(srcType)) {
if (!canAssignTypedDict(evaluatorInterface, destType, srcType, diag, recursionCount)) {
if (
ClassType.isTypedDictClass(destType) &&
ClassType.isTypedDictClass(srcType) &&
!ClassType.isSameGenericClass(destType, srcType)
) {
if (
!canAssignTypedDict(evaluatorInterface, destType, srcType, diag, typeVarContext, flags, recursionCount)
) {
return false;
}
if (ClassType.isFinal(destType) !== ClassType.isFinal(srcType)) {
if (diag) {
diag.addMessage(
Localizer.DiagnosticAddendum.typedDictFinalMismatch().format({
sourceType: printType(convertToInstance(srcType)),
destType: printType(convertToInstance(destType)),
})
);
}
diag?.addMessage(
Localizer.DiagnosticAddendum.typedDictFinalMismatch().format({
sourceType: printType(convertToInstance(srcType)),
destType: printType(convertToInstance(destType)),
})
);
return false;
}
// If invariance is being enforced, the two TypedDicts must be assignable to each other.
if ((flags & CanAssignFlags.EnforceInvariance) !== 0 && !ClassType.isSameGenericClass(destType, srcType)) {
return canAssignTypedDict(evaluatorInterface, srcType, destType, /* diag */ undefined, recursionCount);
if ((flags & CanAssignFlags.EnforceInvariance) !== 0) {
return canAssignTypedDict(
evaluatorInterface,
srcType,
destType,
/* diag */ undefined,
/* typeVarContext */ undefined,
flags,
recursionCount
);
}
return true;

View File

@ -58,10 +58,13 @@ import {
import {
applySolvedTypeVars,
buildTypeVarContextFromSpecializedClass,
CanAssignFlags,
computeMroLinearization,
getTypeVarScopeId,
isLiteralType,
mapSubtypes,
} from './typeUtils';
import { TypeVarContext } from './typeVarContext';
// Creates a new custom TypedDict factory class.
export function createTypedDictType(
@ -647,6 +650,8 @@ export function canAssignTypedDict(
destType: ClassType,
srcType: ClassType,
diag: DiagnosticAddendum | undefined,
typeVarContext: TypeVarContext | undefined,
flags: CanAssignFlags,
recursionCount = 0
) {
let typesAreConsistent = true;
@ -656,51 +661,44 @@ export function canAssignTypedDict(
destEntries.forEach((destEntry, name) => {
const srcEntry = srcEntries.get(name);
if (!srcEntry) {
if (diag) {
diag.addMessage(
Localizer.DiagnosticAddendum.typedDictFieldMissing().format({
name,
type: evaluator.printType(srcType),
})
);
}
diag?.createAddendum().addMessage(
Localizer.DiagnosticAddendum.typedDictFieldMissing().format({
name,
type: evaluator.printType(srcType),
})
);
typesAreConsistent = false;
} else {
if (destEntry.isRequired && !srcEntry.isRequired) {
if (diag) {
diag.addMessage(
Localizer.DiagnosticAddendum.typedDictFieldRequired().format({
name,
type: evaluator.printType(destType),
})
);
}
diag?.createAddendum().addMessage(
Localizer.DiagnosticAddendum.typedDictFieldRequired().format({
name,
type: evaluator.printType(destType),
})
);
typesAreConsistent = false;
} else if (!destEntry.isRequired && srcEntry.isRequired) {
if (diag) {
diag.addMessage(
Localizer.DiagnosticAddendum.typedDictFieldNotRequired().format({
name,
type: evaluator.printType(destType),
})
);
}
diag?.createAddendum().addMessage(
Localizer.DiagnosticAddendum.typedDictFieldNotRequired().format({
name,
type: evaluator.printType(destType),
})
);
typesAreConsistent = false;
}
const subDiag = diag?.createAddendum();
if (
!evaluator.canAssignType(
destEntry.valueType,
srcEntry.valueType,
/* diag */ undefined,
/* typeVarContext */ undefined,
/* flags */ undefined,
subDiag?.createAddendum(),
typeVarContext,
flags,
recursionCount
)
) {
if (diag) {
diag.addMessage(Localizer.DiagnosticAddendum.memberTypeMismatch().format({ name }));
}
subDiag?.addMessage(Localizer.DiagnosticAddendum.memberTypeMismatch().format({ name }));
typesAreConsistent = false;
}
}
@ -729,7 +727,23 @@ export function assignToTypedDict(
let isMatch = true;
const narrowedEntries = new Map<string, TypedDictEntry>();
const symbolMap = getTypedDictMembersForClass(evaluator, classType);
let typeVarContext: TypeVarContext | undefined;
let genericClassType = classType;
if (classType.details.typeParameters.length > 0) {
typeVarContext = new TypeVarContext(getTypeVarScopeId(classType));
// Create a generic (nonspecialized version) of the class.
if (classType.typeArguments) {
genericClassType = ClassType.cloneForSpecialization(
classType,
/* typeArguments */ undefined,
/* isTypeArgumentExplicit */ false
);
}
}
const symbolMap = getTypedDictMembersForClass(evaluator, genericClassType);
keyTypes.forEach((keyType, index) => {
if (!isClassInstance(keyType) || !ClassType.isBuiltIn(keyType, 'str') || !isLiteralType(keyType)) {
@ -752,7 +766,14 @@ export function assignToTypedDict(
} else {
// Can we assign the value to the declared type?
const subDiag = diagAddendum?.createAddendum();
if (!evaluator.canAssignType(symbolEntry.valueType, valueTypes[index], subDiag?.createAddendum())) {
if (
!evaluator.canAssignType(
symbolEntry.valueType,
valueTypes[index],
subDiag?.createAddendum(),
typeVarContext
)
) {
if (subDiag) {
subDiag.addMessage(
Localizer.DiagnosticAddendum.typedDictFieldTypeMismatch().format({
@ -800,9 +821,13 @@ export function assignToTypedDict(
return undefined;
}
const specializedClassType = typeVarContext
? (applySolvedTypeVars(genericClassType, typeVarContext) as ClassType)
: classType;
return narrowedEntries.size === 0
? classType
: ClassType.cloneForNarrowedTypedDictEntries(classType, narrowedEntries);
? specializedClassType
: ClassType.cloneForNarrowedTypedDictEntries(specializedClassType, narrowedEntries);
}
export function getTypeOfIndexedTypedDict(

View File

@ -36,11 +36,8 @@ class TD4(TD3, Generic[_T1]):
v4: TD4[str] = {"a": 3, "b": ""}
# The following does not yet work like it should, so
# it is commented out for now.
def func1(x: TD1[_T1, _T2]) -> dict[_T1, _T2]:
return x["a"]
# def func1(x: TD1[_T1, _T2]) -> dict[_T1, _T2]:
# return x["a"]
# v1_3 = func1({"a": {"x": 3}, "b": "y"})
# reveal_type(v1_3, expected_text="dict[str, int]")
v1_3 = func1({"a": {"x": 3}, "b": "y"})
reveal_type(v1_3, expected_text="dict[str, int]")