diff --git a/server/src/analyzer/typeEvaluator.ts b/server/src/analyzer/typeEvaluator.ts index 8542c0535..8a1d3205a 100644 --- a/server/src/analyzer/typeEvaluator.ts +++ b/server/src/analyzer/typeEvaluator.ts @@ -9033,6 +9033,32 @@ export function createTypeEvaluator(importLookup: ImportLookup): TypeEvaluator { return typesAreConsistent; } + // Handle property classes. They are special because each property + // class has a different source ID, so they wouldn't otherwise match. + // We need to see if the return types of the properties match. + if (ClassType.isPropertyClass(destType) && ClassType.isPropertyClass(srcType)) { + let typesAreConsistent = true; + + const fgetDest = destType.details.fields.get('fget'); + const fgetSrc = srcType.details.fields.get('fget'); + if (fgetDest && fgetSrc) { + const fgetDestType = getDeclaredTypeOfSymbol(fgetDest); + const fgetSrcType = getDeclaredTypeOfSymbol(fgetSrc); + if (fgetDestType && fgetSrcType && + fgetDestType.category === TypeCategory.Function && + fgetSrcType.category === TypeCategory.Function) { + + const fgetDestReturnType = getFunctionEffectiveReturnType(fgetDestType); + const fgetSrcReturnType = getFunctionEffectiveReturnType(fgetSrcType); + if (!canAssignType(fgetDestReturnType, fgetSrcReturnType, diag)) { + typesAreConsistent = false; + } + } + } + + return typesAreConsistent; + } + // Special-case conversion for the "numeric tower". if (ClassType.isBuiltIn(destType, 'float')) { if (ClassType.isBuiltIn(srcType, 'int')) { diff --git a/server/src/tests/checker.test.ts b/server/src/tests/checker.test.ts index e56f5183c..ed42cf519 100644 --- a/server/src/tests/checker.test.ts +++ b/server/src/tests/checker.test.ts @@ -882,6 +882,12 @@ test('Protocol2', () => { validateResults(analysisResults, 0); }); +test('Protocol3', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol3.py']); + + validateResults(analysisResults, 1); +}); + test('TypedDict1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typedDict1.py']); diff --git a/server/src/tests/samples/protocol3.py b/server/src/tests/samples/protocol3.py new file mode 100644 index 000000000..01512a590 --- /dev/null +++ b/server/src/tests/samples/protocol3.py @@ -0,0 +1,41 @@ +# This sample tests the assignment of protocols that +# include property declarations. + +from typing import Protocol + +class Foo1(Protocol): + @property + def batch_shape(self) -> int: + return 0 + + +class MockFoo1: + def __init__(self, batch_shape: int): + self._batch_shape = batch_shape + + @property + def batch_shape(self) -> int: + return self._batch_shape + +# This should not generate an error. +d: Foo1 = MockFoo1(batch_shape=1) + + +class Foo2(Protocol): + @property + def batch_shape(self) -> int: + return 0 + + +class MockFoo2: + def __init__(self, batch_shape: int): + self._batch_shape = batch_shape + + @property + def batch_shape(self) -> float: + return self._batch_shape + +# This should generate an error because the +# type of the batch_shape property is not compatible. +e: Foo2 = MockFoo2(batch_shape=1) +