From 62ac83694ef8bc51fcc9e8616746dabd6d192f84 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Fri, 14 Jan 2022 11:55:23 -0800 Subject: [PATCH] Implemented a new diagnostic check "reportMissingSuperCall" that checks for `__init__`, `__init_subclass__`, `__enter__` and `__exit__` methods that fail to call through to their parent class(es) methods of the same name. This is a common source of bugs. The check is disabled by default in basic mode but enabled by default in strict mode. --- docs/configuration.md | 3 + .../pyright-internal/src/analyzer/checker.ts | 119 +++++++++++++++++- .../src/analyzer/parseTreeUtils.ts | 11 ++ .../src/common/configOptions.ts | 14 +++ .../src/common/diagnosticRules.ts | 1 + .../src/localization/localize.ts | 2 + .../src/localization/package.nls.en-us.json | 1 + .../src/tests/samples/missingSuper1.py | 61 +++++++++ .../src/tests/typeEvaluator2.test.ts | 11 ++ packages/vscode-pyright/package.json | 11 ++ .../schemas/pyrightconfig.schema.json | 6 + 11 files changed, 239 insertions(+), 1 deletion(-) create mode 100644 packages/pyright-internal/src/tests/samples/missingSuper1.py diff --git a/docs/configuration.md b/docs/configuration.md index 04ae95fb8..0e1e76c8c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -114,6 +114,8 @@ The following settings control pyright’s diagnostic output (warnings or errors **reportOverlappingOverload** [boolean or string, optional]: Generate or suppress diagnostics for function overloads that overlap in signature and obscure each other or have incompatible return types. The default value for this setting is 'none'. +**reportMissingSuperCall** [boolean or string, optional]: Generate or suppress diagnostics for `__init__`, `__init_subclass__`, `__enter__` and `__exit__` methods in a subclass that fail to call through to the same-named method on a base class. The default value for this setting is 'none'. + **reportUninitializedInstanceVariable** [boolean or string, optional]: Generate or suppress diagnostics for instance variables within a class that are not initialized or declared within the class body or the `__init__` method. The default value for this setting is 'none'. **reportInvalidStringEscapeSequence** [boolean or string, optional]: Generate or suppress diagnostics for invalid escape sequences used within string literals. The Python specification indicates that such sequences will generate a syntax error in future versions. The default value for this setting is 'warning'. @@ -308,6 +310,7 @@ The following table lists the default severity levels for each diagnostic rule w | reportIncompatibleVariableOverride | "none" | "none" | "error" | | reportInconsistentConstructor | "none" | "none" | "error" | | reportOverlappingOverload | "none" | "none" | "error" | +| reportMissingSuperCall | "none" | "none" | "error" | | reportUninitializedInstanceVariable | "none" | "none" | "none" | | reportInvalidStringEscapeSequence | "none" | "warning" | "error" | | reportUnknownParameterType | "none" | "none" | "error" | diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index 551326483..6d72a69f8 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -352,7 +352,7 @@ export class Checker extends ParseTreeWalker { override visitFunction(node: FunctionNode): boolean { const functionTypeResult = this._evaluator.getTypeOfFunction(node); - const containingClassNode = ParseTreeUtils.getEnclosingClass(node, true); + const containingClassNode = ParseTreeUtils.getEnclosingClass(node, /* stopAtFunction */ true); if (functionTypeResult) { // Track whether we have seen a *args: P.args parameter. Named @@ -4271,6 +4271,19 @@ export class Checker extends ParseTreeWalker { const classTypeInfo = this._evaluator.getTypeOfClass(classNode); const classType = classTypeInfo?.classType; + if (node.name && classType) { + const superCheckMethods = ['__init__', '__init_subclass__', '__enter__', '__exit__']; + if (superCheckMethods.some((name) => name === node.name.value)) { + if ( + !FunctionType.isAbstractMethod(functionType) && + !FunctionType.isOverloaded(functionType) && + !this._fileInfo.isStubFile + ) { + this._validateSuperCallForMethod(node, functionType, classType); + } + } + } + if (node.name && node.name.value === '__new__') { // __new__ overrides should have a "cls" parameter. if ( @@ -4410,6 +4423,110 @@ export class Checker extends ParseTreeWalker { } } + // Determines whether the method properly calls through to the same method in all + // parent classes that expose a same-named method. + private _validateSuperCallForMethod(node: FunctionNode, methodType: FunctionType, classType: ClassType) { + // This is an expensive test, so if it's not enabled, don't do any work. + if (this._fileInfo.diagnosticRuleSet.reportMissingSuperCall === 'none') { + return; + } + + // Determine which parent classes expose a same-named method. + let baseClassesToCall: ClassType[] = []; + classType.details.baseClasses.forEach((baseClass) => { + if (isClass(baseClass)) { + const methodMember = lookUpClassMember( + baseClass, + methodType.details.name, + ClassMemberLookupFlags.SkipInstanceVariables | ClassMemberLookupFlags.SkipObjectBaseClass + ); + + if (methodMember && isClass(methodMember.classType)) { + const methodClass = methodMember.classType; + if (!baseClassesToCall.find((cls) => ClassType.isSameGenericClass(cls, methodClass))) { + baseClassesToCall.push(methodClass); + } + } + } + }); + + // Now scan the implementation of the method to determine whether + // super(). has been called for all of the required base classes. + const callNodeWalker = new ParseTreeUtils.CallNodeWalker((node) => { + if (node.leftExpression.nodeType === ParseNodeType.MemberAccess) { + // Is it accessing the method by the same name? + if (node.leftExpression.memberName.value === methodType.details.name) { + const memberBaseExpr = node.leftExpression.leftExpression; + + // Is it a "super" call? + if ( + memberBaseExpr.nodeType === ParseNodeType.Call && + memberBaseExpr.leftExpression.nodeType === ParseNodeType.Name && + memberBaseExpr.leftExpression.value === 'super' + ) { + let targetClassType: Type | undefined; + + // Is this a zero-argument call to 'super', or is an explicit + // target class provided? + if (memberBaseExpr.arguments.length === 0) { + targetClassType = classType; + } else { + targetClassType = this._evaluator.getTypeOfExpression( + memberBaseExpr.arguments[0].valueExpression + ).type; + targetClassType = this._evaluator.makeTopLevelTypeVarsConcrete(targetClassType); + } + + if (isAnyOrUnknown(targetClassType)) { + // If the target class is Any or unknown, all bets are off. + baseClassesToCall = []; + } else if (isInstantiableClass(targetClassType)) { + const lookupResults = lookUpClassMember( + targetClassType, + methodType.details.name, + ClassMemberLookupFlags.SkipOriginalClass | + ClassMemberLookupFlags.SkipInstanceVariables | + ClassMemberLookupFlags.SkipObjectBaseClass + ); + + if (lookupResults && isInstantiableClass(lookupResults.classType)) { + const baseType = lookupResults.classType; + + // Note that we've called this base class. + baseClassesToCall = baseClassesToCall.filter( + (cls) => !ClassType.isSameGenericClass(cls, baseType) + ); + } + } + } else { + // Is it an X. direct call? + const baseType = this._evaluator.getType(memberBaseExpr); + if (baseType && isInstantiableClass(baseType)) { + // Note that we've called this base class. + baseClassesToCall = baseClassesToCall.filter( + (cls) => !ClassType.isSameGenericClass(cls, baseType) + ); + } + } + } + } + }); + callNodeWalker.walk(node.suite); + + // If there are base classes that haven't yet been called, report it as an error. + baseClassesToCall.forEach((baseClass) => { + this._evaluator.addDiagnostic( + this._fileInfo.diagnosticRuleSet.reportMissingSuperCall, + DiagnosticRule.reportMissingSuperCall, + Localizer.Diagnostic.missingSuperCall().format({ + methodName: methodType.details.name, + classType: this._evaluator.printType(convertToInstance(baseClass)), + }), + node.name + ); + }); + } + // Validates that the annotated type of a "self" or "cls" parameter is // compatible with the type of the class that contains it. private _validateClsSelfParameterType(functionType: FunctionType, classType: ClassType, isCls: boolean) { diff --git a/packages/pyright-internal/src/analyzer/parseTreeUtils.ts b/packages/pyright-internal/src/analyzer/parseTreeUtils.ts index 170541ed7..7e04f86b4 100644 --- a/packages/pyright-internal/src/analyzer/parseTreeUtils.ts +++ b/packages/pyright-internal/src/analyzer/parseTreeUtils.ts @@ -1353,6 +1353,17 @@ export class NameNodeWalker extends ParseTreeWalker { } } +export class CallNodeWalker extends ParseTreeWalker { + constructor(private _callback: (node: CallNode) => void) { + super(); + } + + override visitCall(node: CallNode) { + this._callback(node); + return true; + } +} + export function getEnclosingParameter(node: ParseNode): ParameterNode | undefined { let curNode: ParseNode | undefined = node; diff --git a/packages/pyright-internal/src/common/configOptions.ts b/packages/pyright-internal/src/common/configOptions.ts index f8531fd95..723c4a4e9 100644 --- a/packages/pyright-internal/src/common/configOptions.ts +++ b/packages/pyright-internal/src/common/configOptions.ts @@ -195,6 +195,9 @@ export interface DiagnosticRuleSet { // incompatible return types. reportOverlappingOverload: DiagnosticLevel; + // Report failure to call super().__init__() in __init__ method. + reportMissingSuperCall: DiagnosticLevel; + // Report instance variables that are not initialized within // the constructor. reportUninitializedInstanceVariable: DiagnosticLevel; @@ -332,6 +335,7 @@ export function getDiagLevelDiagnosticRules() { DiagnosticRule.reportIncompatibleVariableOverride, DiagnosticRule.reportInconsistentConstructor, DiagnosticRule.reportOverlappingOverload, + DiagnosticRule.reportMissingSuperCall, DiagnosticRule.reportUninitializedInstanceVariable, DiagnosticRule.reportInvalidStringEscapeSequence, DiagnosticRule.reportUnknownParameterType, @@ -409,6 +413,7 @@ export function getOffDiagnosticRuleSet(): DiagnosticRuleSet { reportIncompatibleVariableOverride: 'none', reportInconsistentConstructor: 'none', reportOverlappingOverload: 'none', + reportMissingSuperCall: 'none', reportUninitializedInstanceVariable: 'none', reportInvalidStringEscapeSequence: 'none', reportUnknownParameterType: 'none', @@ -482,6 +487,7 @@ export function getBasicDiagnosticRuleSet(): DiagnosticRuleSet { reportIncompatibleVariableOverride: 'none', reportInconsistentConstructor: 'none', reportOverlappingOverload: 'none', + reportMissingSuperCall: 'none', reportUninitializedInstanceVariable: 'none', reportInvalidStringEscapeSequence: 'warning', reportUnknownParameterType: 'none', @@ -555,6 +561,7 @@ export function getStrictDiagnosticRuleSet(): DiagnosticRuleSet { reportIncompatibleVariableOverride: 'error', reportInconsistentConstructor: 'error', reportOverlappingOverload: 'error', + reportMissingSuperCall: 'error', reportUninitializedInstanceVariable: 'none', reportInvalidStringEscapeSequence: 'error', reportUnknownParameterType: 'error', @@ -1143,6 +1150,13 @@ export class ConfigOptions { defaultSettings.reportOverlappingOverload ), + // Read the "reportMissingSuperCall" entry. + reportMissingSuperCall: this._convertDiagnosticLevel( + configObj.reportMissingSuperCall, + DiagnosticRule.reportMissingSuperCall, + defaultSettings.reportMissingSuperCall + ), + // Read the "reportUninitializedInstanceVariable" entry. reportUninitializedInstanceVariable: this._convertDiagnosticLevel( configObj.reportUninitializedInstanceVariable, diff --git a/packages/pyright-internal/src/common/diagnosticRules.ts b/packages/pyright-internal/src/common/diagnosticRules.ts index 86f215653..8e3c38800 100644 --- a/packages/pyright-internal/src/common/diagnosticRules.ts +++ b/packages/pyright-internal/src/common/diagnosticRules.ts @@ -48,6 +48,7 @@ export enum DiagnosticRule { reportIncompatibleVariableOverride = 'reportIncompatibleVariableOverride', reportInconsistentConstructor = 'reportInconsistentConstructor', reportOverlappingOverload = 'reportOverlappingOverload', + reportMissingSuperCall = 'reportMissingSuperCall', reportUninitializedInstanceVariable = 'reportUninitializedInstanceVariable', reportInvalidStringEscapeSequence = 'reportInvalidStringEscapeSequence', reportUnknownParameterType = 'reportUnknownParameterType', diff --git a/packages/pyright-internal/src/localization/localize.ts b/packages/pyright-internal/src/localization/localize.ts index 4d29c27d0..3add1cd78 100644 --- a/packages/pyright-internal/src/localization/localize.ts +++ b/packages/pyright-internal/src/localization/localize.ts @@ -507,6 +507,8 @@ export namespace Localizer { ); export const methodReturnsNonObject = () => new ParameterizedString<{ name: string }>(getRawString('Diagnostic.methodReturnsNonObject')); + export const missingSuperCall = () => + new ParameterizedString<{ methodName: string, classType: string }>(getRawString('Diagnostic.missingSuperCall')); export const moduleAsType = () => getRawString('Diagnostic.moduleAsType'); export const moduleNotCallable = () => getRawString('Diagnostic.moduleNotCallable'); export const moduleUnknownMember = () => diff --git a/packages/pyright-internal/src/localization/package.nls.en-us.json b/packages/pyright-internal/src/localization/package.nls.en-us.json index 62175bf63..d12c73968 100644 --- a/packages/pyright-internal/src/localization/package.nls.en-us.json +++ b/packages/pyright-internal/src/localization/package.nls.en-us.json @@ -227,6 +227,7 @@ "methodOrdering": "Cannot create consistent method ordering", "methodOverridden": "\"{name}\" overrides method of same name in class \"{className}\" with incompatible type \"{type}\"", "methodReturnsNonObject": "\"{name}\" method does not return an object", + "missingSuperCall": "Method \"{methodName}\" does not call the method of the same name in parent class \"{classType}\"", "moduleAsType": "Module cannot be used as a type", "moduleNotCallable": "Module is not callable", "moduleUnknownMember": "\"{name}\" is not a known member of module", diff --git a/packages/pyright-internal/src/tests/samples/missingSuper1.py b/packages/pyright-internal/src/tests/samples/missingSuper1.py new file mode 100644 index 000000000..1c1b7c4ca --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/missingSuper1.py @@ -0,0 +1,61 @@ +# This sample tests the reportMissingSuperCall diagnostic check. + + +class ParentA: + def __init__(self): + pass + + def __init_subclass__(cls) -> None: + pass + + +class ParentB: + def __init__(self): + pass + + +class ParentBPrime(ParentB): + pass + + +class ParentC: + pass + + +class ChildA(ParentA, ParentB): + # This should generate two errors. + def __init__(self): + pass + + # This should generate one error. + def __init_subclass__(cls) -> None: + pass + + +class ChildB(ParentA, ParentB): + # This should generate one error. + def __init__(self): + super().__init__() + + +class ChildC1(ParentA, ParentB): + def __init__(self): + super().__init__() + ParentB.__init__(self) + + +class ChildC2(ParentA, ParentB): + def __init__(self): + ParentA.__init__(self) + ParentB.__init__(self) + + +class ChildCPrime(ParentA, ParentBPrime, ParentC): + def __init__(self): + super().__init__() + super(ParentBPrime).__init__() + + +class ChildD(ParentC): + def __init__(self): + pass diff --git a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts index 08d1fa18a..d12c95beb 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts @@ -174,6 +174,17 @@ test('Super7', () => { TestUtils.validateResults(analysisResults, 3); }); +test('MissingSuper1', () => { + const configOptions = new ConfigOptions('.'); + + const analysisResults1 = TestUtils.typeAnalyzeSampleFiles(['missingSuper1.py'], configOptions); + TestUtils.validateResults(analysisResults1, 0); + + configOptions.diagnosticRuleSet.reportMissingSuperCall = 'error'; + const analysisResults2 = TestUtils.typeAnalyzeSampleFiles(['missingSuper1.py'], configOptions); + TestUtils.validateResults(analysisResults2, 4); +}); + test('NewType1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['newType1.py']); diff --git a/packages/vscode-pyright/package.json b/packages/vscode-pyright/package.json index 728b02b1f..0742ecec3 100644 --- a/packages/vscode-pyright/package.json +++ b/packages/vscode-pyright/package.json @@ -454,6 +454,17 @@ "error" ] }, + "reportMissingSuperCall": { + "type": "string", + "description": "Diagnostics for missing call to parent class for inherited `__init__` methods.", + "default": "none", + "enum": [ + "none", + "information", + "warning", + "error" + ] + }, "reportUninitializedInstanceVariable": { "type": "string", "description": "Diagnostics for instance variables that are not declared or initialized within class body or `__init__` method.", diff --git a/packages/vscode-pyright/schemas/pyrightconfig.schema.json b/packages/vscode-pyright/schemas/pyrightconfig.schema.json index 0cca1815f..650ce54a3 100644 --- a/packages/vscode-pyright/schemas/pyrightconfig.schema.json +++ b/packages/vscode-pyright/schemas/pyrightconfig.schema.json @@ -317,6 +317,12 @@ "title": "Controls reporting of function overloads that overlap in signature and obscure each other or do not agree on return type", "default": "none" }, + "reportMissingSuperCall": { + "$id": "#/properties/reportMissingSuperCall", + "$ref": "#/definitions/diagnostic", + "title": "Controls reporting of missing call to parent class for inherited `__init__` methods", + "default": "none" + }, "reportUninitializedInstanceVariable": { "$id": "#/properties/reportUninitializedInstanceVariable", "$ref": "#/definitions/diagnostic",