From 501681295b13b8b088c192fac6b62011a791c690 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Sat, 29 Jul 2023 22:45:10 -0600 Subject: [PATCH] Modified type inference logic so methods that raise an exception whose type derives from `NotImplementedError` is not inferred to return `NoReturn`. Previously, only `NotImplementedError` was exempted, not subclasses. This addresses https://github.com/microsoft/pyright/issues/5608. (#5609) Co-authored-by: Eric Traut --- docs/type-inference.md | 2 +- docs/typed-libraries.md | 2 +- .../src/analyzer/codeFlowEngine.ts | 16 ++++----- .../src/analyzer/typeEvaluator.ts | 9 ++--- .../src/analyzer/typeUtils.ts | 4 +++ .../src/tests/samples/inferredTypes1.py | 2 +- .../src/tests/samples/inferredTypes3.py | 33 +++++++++++++++++++ .../src/tests/typeEvaluator4.test.ts | 5 +++ 8 files changed, 57 insertions(+), 16 deletions(-) create mode 100644 packages/pyright-internal/src/tests/samples/inferredTypes3.py diff --git a/docs/type-inference.md b/docs/type-inference.md index 59f50585d..a374c69c5 100644 --- a/docs/type-inference.md +++ b/docs/type-inference.md @@ -142,7 +142,7 @@ def func1(val: int): #### NoReturn return type -If there is no code path that returns from a function (e.g. all code paths raise an exception), Pyright infers a return type of `NoReturn`. As an exception to this rule, if the function is decorated with `@abstractmethod`, the return type is not inferred as `NoReturn` even if there is no return. This accommodates a common practice where an abstract method is implemented with a `raise NotImplementedError()` statement. +If there is no code path that returns from a function (e.g. all code paths raise an exception), Pyright infers a return type of `NoReturn`. As an exception to this rule, if the function is decorated with `@abstractmethod`, the return type is not inferred as `NoReturn` even if there is no return. This accommodates a common practice where an abstract method is implemented with a `raise` statement that raises an exception of type `NotImplementedError`. ```python class Foo: diff --git a/docs/typed-libraries.md b/docs/typed-libraries.md index 6f86b510c..5cd85849a 100644 --- a/docs/typed-libraries.md +++ b/docs/typed-libraries.md @@ -311,7 +311,7 @@ StrOrInt: TypeAlias = str | int ``` #### Abstract Classes and Methods -Classes that must be subclassed should derive from `ABC`, and methods or properties that must be overridden should be decorated with the `@abstractmethod` decorator. This allows type checkers to validate that the required methods have been overridden and provide developers with useful error messages when they are not. It is customary to implement an abstract method by raising a `NotImplementedError` exception. +Classes that must be subclassed should derive from `ABC`, and methods or properties that must be overridden should be decorated with the `@abstractmethod` decorator. This allows type checkers to validate that the required methods have been overridden and provide developers with useful error messages when they are not. It is customary to implement an abstract method by raising a `NotImplementedError` exception or subclass thereof. ```python from abc import ABC, abstractmethod diff --git a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts index 4b0664e35..79e635bd0 100644 --- a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts +++ b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts @@ -65,6 +65,7 @@ import { } from './types'; import { ClassMemberLookupFlags, + derivesFromStdlibClass, doForEachSubtype, isIncompleteUnknown, isTypeAliasPlaceholder, @@ -1620,16 +1621,13 @@ export function getCodeFlowEngine( } if (simpleStatement.nodeType === ParseNodeType.Raise && simpleStatement.typeExpression) { - // Check for "raise NotImplementedError" or "raise NotImplementedError()" - const isNotImplementedName = (node: ParseNode) => { - return node?.nodeType === ParseNodeType.Name && node.value === 'NotImplementedError'; - }; + // Check for a raising about 'NotImplementedError' or a subtype thereof. + const exceptionType = evaluator.getType(simpleStatement.typeExpression); - if (isNotImplementedName(simpleStatement.typeExpression)) { - foundRaiseNotImplemented = true; - } else if ( - simpleStatement.typeExpression.nodeType === ParseNodeType.Call && - isNotImplementedName(simpleStatement.typeExpression.leftExpression) + if ( + exceptionType && + isClass(exceptionType) && + derivesFromStdlibClass(exceptionType, 'NotImplementedError') ) { foundRaiseNotImplemented = true; } diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index eeedabe49..19fcfabbc 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -200,6 +200,7 @@ import { convertToInstantiable, convertTypeToParamSpecValue, derivesFromClassRecursive, + derivesFromStdlibClass, doForEachSubtype, ensureFunctionSignaturesAreUnique, explodeGenericClass, @@ -17540,9 +17541,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return inferredReturnType ? { type: inferredReturnType, isIncomplete } : undefined; } - // Determines whether the function consists only of a "raise" statement - // and the exception type raised is a NotImplementedError. This is commonly - // used for abstract methods that + // Determines whether the method consists only of a "raise" statement + // and the exception type raised is a NotImplementedError or a subclass + // thereof. This is commonly used for abstract methods. function methodAlwaysRaisesNotImplemented(functionDecl?: FunctionDeclaration): boolean { if ( !functionDecl || @@ -17564,7 +17565,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions : isClassInstance(raiseType) ? raiseType : undefined; - if (!classType || !ClassType.isBuiltIn(classType, 'NotImplementedError')) { + if (!classType || !derivesFromStdlibClass(classType, 'NotImplementedError')) { return false; } } diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index 3657a2efc..787be5c1f 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -1856,6 +1856,10 @@ export function specializeForBaseClass(srcType: ClassType, baseClass: ClassType) return specializedType as ClassType; } +export function derivesFromStdlibClass(classType: ClassType, className: string) { + return classType.details.mro.some((mroClass) => isClass(mroClass) && ClassType.isBuiltIn(mroClass, className)); +} + // If ignoreUnknown is true, an unknown base class is ignored when // checking for derivation. If ignoreUnknown is false, a return value // of true is assumed. diff --git a/packages/pyright-internal/src/tests/samples/inferredTypes1.py b/packages/pyright-internal/src/tests/samples/inferredTypes1.py index c5b4ac22e..7a5193aa0 100644 --- a/packages/pyright-internal/src/tests/samples/inferredTypes1.py +++ b/packages/pyright-internal/src/tests/samples/inferredTypes1.py @@ -12,7 +12,7 @@ def make_api_request(auth: str) -> str: return "meow" -def testfunc() -> None: +def func1() -> None: resp = open("test") auth = resp.read() diff --git a/packages/pyright-internal/src/tests/samples/inferredTypes3.py b/packages/pyright-internal/src/tests/samples/inferredTypes3.py new file mode 100644 index 000000000..3cd9f4905 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/inferredTypes3.py @@ -0,0 +1,33 @@ +# This sample tests return type annotations for functions that +# do not return. + +from abc import ABC, abstractmethod + + +class OtherError(NotImplementedError): + ... + + +class A(ABC): + def func1(self): + raise Exception("test") + + def func2(self): + raise NotImplementedError() + + def func3(self): + raise OtherError + + @abstractmethod + def func4(self): + raise Exception() + + +def func1(a: A): + reveal_type(a.func1(), expected_text="NoReturn") + + reveal_type(a.func2(), expected_text="Unknown") + + reveal_type(a.func3(), expected_text="Unknown") + + reveal_type(a.func4(), expected_text="Unknown") diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index 0f0791e78..2cb5faa24 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -374,6 +374,11 @@ test('InferredTypes2', () => { TestUtils.validateResults(analysisResults, 0); }); +test('InferredTypes3', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['inferredTypes3.py']); + TestUtils.validateResults(analysisResults, 0); +}); + test('CallSite2', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['callSite2.py']); TestUtils.validateResults(analysisResults, 0);