diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 77767766f..fc9add6b7 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -17338,10 +17338,16 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // Handle typed dicts. They also use a form of structural typing for type // checking, as defined in PEP 589. if (ClassType.isTypedDictClass(destType) && ClassType.isTypedDictClass(srcType)) { - if ((flags & CanAssignFlags.EnforceInvariance) !== 0 && !ClassType.isSameGenericClass(destType, srcType)) { + if (!canAssignTypedDict(evaluatorInterface, destType, srcType, diag, recursionCount)) { return false; } - return canAssignTypedDict(evaluatorInterface, destType, srcType, diag, recursionCount); + + // If invariance is being enforced, the two TypedDicts must be assignable to each other. + if ((flags & CanAssignFlags.EnforceInvariance) !== 0 && !ClassType.isSameGenericClass(destType, srcType)) { + return canAssignTypedDict(evaluatorInterface, srcType, destType, /* diag */ undefined, recursionCount); + } + + return true; } // Handle special-case type promotions. diff --git a/packages/pyright-internal/src/tests/samples/typedDict16.py b/packages/pyright-internal/src/tests/samples/typedDict16.py new file mode 100644 index 000000000..dac520961 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typedDict16.py @@ -0,0 +1,80 @@ +# This sample tests that type compatibility between TypedDicts. + +from typing import List, TypedDict + + +class TD0(TypedDict): + key: str + + +class TD1(TD0): + value: str + + +class TD2(TypedDict): + key: str + value: str + + +v1: TD2 = TD1(key="", value="") +v2: TD1 = TD2(key="", value="") + +v3 = [v2] +v4: List[TD2] = v3 +v5 = [v1] +v6: List[TD1] = v5 + + +class TD10(TypedDict, total=False): + key: str + + +class TD11(TD10): + value: str + + +class TD12(TypedDict): + key: str + value: str + + +# This should generate an error. +v10: TD12 = TD11(key="", value="") + +# This should generate an error. +v11: TD11 = TD12(key="", value="") + + +v12 = [v10] +# This should generate an error. +v13: List[TD10] = v12 + +v14 = [v11] +# This should generate an error. +v15: List[TD12] = v14 + + +class TD20(TypedDict): + key: str + value: str + + +class TD21(TypedDict): + key: str + value: str + extra: str + + +# This should generate an error. +v20: TD21 = TD20(key="", value="") + +v21: TD20 = TD21(key="", value="", extra="") + + +v22 = [v20] +# This should generate an error. +v23: List[TD20] = v22 + +v24: List[TD20] = [v21] +# This should generate an error. +v25: List[TD21] = v24 diff --git a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts index 68f90e318..31745b3ca 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts @@ -986,3 +986,9 @@ test('TypedDict15', () => { TestUtils.validateResults(analysisResults, 2); }); + +test('TypedDict16', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typedDict16.py']); + + TestUtils.validateResults(analysisResults, 7); +});